diff --git a/process/io/plot_scans.py b/process/io/plot_scans.py index 931a80f19..af9c0c707 100644 --- a/process/io/plot_scans.py +++ b/process/io/plot_scans.py @@ -591,18 +591,7 @@ def main(args=None): # ----------- for index, output_name in enumerate(output_names): if stack_plots: - # check stack plots will work - if len(output_names) <= 1: - raise ValueError( - "For stack plots to be used need more than 1 output variable" - ) - fig, axs = plt.subplots( - len(output_names), - 1, - figsize=(8.0, (3.5 + (1 * len(output_names)))), - sharex=True, - ) - fig.subplots_adjust(hspace=0.0) + pass else: fig, ax = plt.subplots() if output_names2 != []: @@ -684,11 +673,25 @@ def main(args=None): plt.tight_layout() else: if stack_plots: - axs[output_names.index(output_name)].plot( + # check stack plots will work + if len(output_names) <= 1: + raise ValueError( + "For stack plots need more than 1 output variable" + ) + # Create subplots only once for the first output + if index == 0: + fig, axs = plt.subplots( + len(output_names), + 1, + figsize=(8.0, (3.5 + (1 * len(output_names)))), + sharex=True, + ) + fig.subplots_adjust(hspace=0.0) + + axs[index].plot( scan_var_array[input_file], output_arrays[input_file][output_name], "--o", - color="blue" if output_names2 != [] else None, label=labl, ) if y_axis_range != []: