Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 83 additions & 58 deletions gatetools/bin/gt_phsp_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,55 @@

logger = logging.getLogger(__name__)

CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"])


@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument('filenames', nargs=-1)
@click.option('-n', default=float(1e5), help='Use -1 to read all data')
@click.option('--keys', '-k', help='Plot the given keys (as a str list such that "X Y Z")', default='')
@click.option('--skip', multiple=True, help='(string) Dont plot if this str is contained in a branch name')
@click.option('--quantile', '-q', default=float(0), help='Restrict histogram to quantile')
@click.option('--nb_bins', '-b', default=int(100), help='Number of bins')
@click.option('--tree', '-t', default='PhaseSpace', help='Name of the tree in the root file')
@click.option('--shuffle', '-s', default=False, is_flag=True, help='shuffle samples when loading')
@click.option('--output', '-o', type=str, help='Do not plot, only output a pdf with the given name')
@click.option('--plot2d',
type=(str, str),
help='Add 2D plots (key1,key2), such as --plot2d X Ekine --plot2d X Y ', multiple=True)
@click.argument("filenames", nargs=-1)
@click.option("-n", default=float(1e5), help="Use -1 to read all data")
@click.option(
"--keys",
"-k",
help='Plot the given keys (as a str list such that "X Y Z")',
default="",
)
@click.option(
"--skip",
multiple=True,
help="(string) Dont plot if this str is contained in a branch name",
)
@click.option(
"--quantile", "-q", default=float(0), help="Restrict histogram to quantile"
)
@click.option("--nb_bins", "-b", default=int(100), help="Number of bins")
@click.option(
"--tree", "-t", default="PhaseSpace", help="Name of the tree in the root file"
)
@click.option(
"--shuffle", "-s", default=False, is_flag=True, help="shuffle samples when loading"
)
@click.option(
"--output",
"-o",
type=str,
help="Do not plot, only output a pdf with the given name",
)
@click.option(
"--plot2d",
type=(str, str),
help="Add 2D plots (key1,key2), such as --plot2d X Ekine --plot2d X Y ",
multiple=True,
)
@click.option(
"--ui",
"-ui",
is_flag=True,
help="Launch the interactive Streamlit dashboard UI",
)
@gt.add_options(gt.common_options)
def gt_phsp_plot(filenames, keys, n, quantile, tree, nb_bins, plot2d, shuffle, skip, output, **kwargs):
def gt_phsp_plot(
filenames, keys, n, quantile, tree, nb_bins, plot2d, shuffle, skip, output, ui, **kwargs
):
"""
\b
Plot histograms
Expand All @@ -44,6 +75,22 @@ def gt_phsp_plot(filenames, keys, n, quantile, tree, nb_bins, plot2d, shuffle, s
WARNING: if several filenames, they must have the same keys
"""

if ui:
import subprocess
import sys
import os
try:
import streamlit
except ImportError:
print("Error: streamlit is not installed in the current environment.")
print("Please install it with: pip install streamlit plotly pandas")
return

ui_script = os.path.join(os.path.dirname(__file__), "..", "phsp", "phsp_plot_ui.py")
cmd = [sys.executable, "-m", "streamlit", "run", ui_script, "--"] + list(filenames)
subprocess.run(cmd)
return

# logger
gt.logging_conf(**kwargs)

Expand All @@ -62,7 +109,7 @@ def gt_phsp_plot(filenames, keys, n, quantile, tree, nb_bins, plot2d, shuffle, s
data, read_keys, m = phsp.load(filename, tree, n, shuffle)
if n == -1:
n = m
print(f'Reading {n}/{m}')
print(f"Reading {n}/{m}")

# get keys
ckeys = phsp.str_keys_to_array_keys(keys)
Expand All @@ -78,7 +125,7 @@ def gt_phsp_plot(filenames, keys, n, quantile, tree, nb_bins, plot2d, shuffle, s
add = True
for sk in skip_branches:
if sk in k:
print('Skip branch ', k)
print("Skip branch ", k)
add = False
if add:
fk.append(k)
Expand All @@ -102,67 +149,45 @@ def gt_phsp_plot(filenames, keys, n, quantile, tree, nb_bins, plot2d, shuffle, s
nfig = 0
for k in first_keys:
if k not in read_keys:
print(f'Skip key {k}: not in the first list of keys')
print(f"Skip key {k}: not in the first list of keys")
continue

# get data
index = read_keys.index(k)
x = data[:, index]

if len(x) < 1:
print(f'Skip key {k}: empty')
continue

# check validity
if type(x[0]) == str:
print(f'Skip key {k} : str')
continue
try:
a = int(x[0])
except:
print(f'Skip key {k}: not numeric? x[0] = {x[0]}')
continue
# sometimes, if x is a str (from a root file), x[0] will be 'NULL'
# (probably not the best method ; to be changed)
if x[0] == 'NULL':
print(f'Skip key {k} : not numeric? x[0] = NUL')
# clean data
x = phsp.clean_column(x, k)
if x is None:
continue

# get mean to check if nan
xmean = np.mean(x)
xmax = np.max(x)
xmin = np.min(x)
print(f'Key {k} min/mean/max: {xmin} {xmean} {xmax}')
print(f"Key {k} min/mean/max: {xmin} {xmean} {xmax}")
if np.isnan(xmean):
print(f'Skip key {k} : nan ?')
print(f"Skip key {k} : nan ?")
continue

a = phsp.fig_get_sub_fig(ax, i)
q1 = quantile
q2 = 1.0 - quantile
if filename == filenames[0]:
q[k] = (np.quantile(x, q1), np.quantile(x, q2))
if k not in q:
q[k] = (np.quantile(x, q1), np.quantile(x, q2))

label = ' {} $\\mu$={:.2f} $\\sigma$={:.2f}'.format(k, np.mean(x), np.std(x))
a.hist(x, nb_bins,
# density=True,
histtype='stepfilled',
range=q[k],
# facecolor='g',
alpha=0.5,
label=label)
# a.set_ylabel('Probability')
a.set_ylabel('Counts')
a.legend()
# plot
q = phsp.plot_column_histogram(
ax,
i,
x,
k,
nb_bins,
quantile,
filename == filenames[0],
q,
)
i = i + 1
nfig += 1

# 2D
for k in keys_2D:
a = phsp.fig_get_sub_fig(ax, i)
phsp.fig_histo2D(a, data, read_keys, k, nb_bins, 'g')
phsp.fig_histo2D(a, data, read_keys, k, nb_bins, "g")
i = i + 1

if nb_fig == 0:
Expand All @@ -179,7 +204,7 @@ def gt_phsp_plot(filenames, keys, n, quantile, tree, nb_bins, plot2d, shuffle, s
n = int(n)
m = int(m)
# plt.subplots_adjust(top=0.7)
plt.suptitle(f'Values: {n}/{m}')
plt.suptitle(f"Values: {n}/{m}")
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
if output:
plt.savefig(output)
Expand All @@ -189,5 +214,5 @@ def gt_phsp_plot(filenames, keys, n, quantile, tree, nb_bins, plot2d, shuffle, s


# --------------------------------------------------------------------------
if __name__ == '__main__':
if __name__ == "__main__":
gt_phsp_plot()
91 changes: 91 additions & 0 deletions gatetools/phsp/phsp_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,29 @@ def fig_histo2D(ax, data, keys, k, nbins, color='g'):
x = data[:, i1]
i2 = keys.index(k[1])
y = data[:, i2]

# Convert to float arrays to support object arrays and detect NaNs cleanly
try:
x = x.astype(float)
except (ValueError, TypeError):
pass
try:
y = y.astype(float)
except (ValueError, TypeError):
pass

# Filter out NaNs from both variables alignment-wise
try:
mask = ~np.isnan(x) & ~np.isnan(y)
x = x[mask]
y = y[mask]
except TypeError:
pass

if len(x) < 1 or len(y) < 1:
print(f"Skip 2D plot of {k} because data is empty (all NaN)")
return

if color == 'g':
cmap = plt.cm.Greens
if color == 'r':
Expand All @@ -383,6 +406,74 @@ def fig_histo2D(ax, data, keys, k, nbins, color='g'):
ax.set_ylabel(k[1])


# -----------------------------------------------------------------------------
def clean_column(x, key_name):
    """
    Clean a column from PhaseSpace data before it is plotted.

    The checks are applied in this order:
    - try to convert the array to float (handles object arrays containing
      numeric values and nan); if the conversion fails, the array is kept
      unchanged
    - filter out NaN values, when np.isnan supports the array dtype
    - skip the column if it is empty, if its first element is a string
      (plain ``str`` or a numpy string scalar), or if its first element
      cannot be converted to an int

    :param x: numpy array holding the values of one column
    :param key_name: name of the column, only used in the printed messages
    :return: the cleaned numpy array, or None if the column must be skipped
    """
    try:
        x = x.astype(float)
    except (ValueError, TypeError):
        # not convertible as a whole: keep the raw values, the checks below
        # decide whether the column is usable
        pass

    try:
        x = x[~np.isnan(x)]
    except TypeError:
        # np.isnan is not defined for this dtype (e.g. str or object arrays)
        pass

    if len(x) < 1:
        print(f"Skip key {key_name}: empty (or all NaN)")
        return None

    first = x[0]

    # isinstance (rather than an exact type comparison) also catches numpy
    # string scalars, which subclass str; otherwise a unicode array whose
    # first element looks numeric (e.g. ["1", "a"]) would be accepted
    if isinstance(first, str):
        print(f"Skip key {key_name} : str")
        return None

    try:
        int(first)
    except (ValueError, TypeError, OverflowError):
        # OverflowError: int() of an infinite float
        print(f"Skip key {key_name}: not numeric? x[0] = {first}")
        return None

    # defensive fallback kept from the previous checks; string values are
    # already rejected above
    if first == "NULL":
        print(f"Skip key {key_name} : not numeric? x[0] = NUL")
        return None

    return x


# -----------------------------------------------------------------------------
def plot_column_histogram(ax, i, x, k, nb_bins, quantile, is_first_file, q):
    """
    Draw the histogram of the cleaned column x in the i-th sub-figure.

    The histogram range for key k is read from the dictionary q. The entry
    q[k] is (re)computed from x as the (quantile, 1 - quantile) pair when
    is_first_file is true or when k has no entry yet; otherwise the stored
    range is reused.

    :param ax: axes container, forwarded to fig_get_sub_fig
    :param i: index of the sub-figure, forwarded to fig_get_sub_fig
    :param x: values to plot
    :param k: key (column name), used in the legend and as index in q
    :param nb_bins: number of bins, forwarded to the hist call
    :param quantile: lower quantile; the upper one is 1.0 - quantile
    :param is_first_file: if true, the range stored for k is overwritten
    :param q: dictionary key -> (low, high) range; modified in place
    :return: the dictionary q
    """
    sub_fig = fig_get_sub_fig(ax, i)

    # the stored range is only refreshed for the first file or for a new key
    if is_first_file or k not in q:
        low = np.quantile(x, quantile)
        high = np.quantile(x, 1.0 - quantile)
        q[k] = (low, high)

    mu = np.mean(x)
    sigma = np.std(x)
    legend_text = " {} $\\mu$={:.2f} $\\sigma$={:.2f}".format(k, mu, sigma)

    hist_options = dict(
        histtype="stepfilled",
        range=q[k],
        alpha=0.5,
        label=legend_text,
    )
    sub_fig.hist(x, nb_bins, **hist_options)
    sub_fig.set_ylabel("Counts")
    sub_fig.legend()
    return q


#####################################################################################
import unittest
import hashlib
Expand Down
Loading
Loading