diff --git a/meegkit/dss.py b/meegkit/dss.py index 7297b7e..0877670 100644 --- a/meegkit/dss.py +++ b/meegkit/dss.py @@ -291,10 +291,10 @@ def dss_line_iter(data, fline, sfreq, win_sz=10, spot_sz=2.5, show: bool Produce a visual output of each iteration (default=False). dirname: str - Path to the directory where visual outputs are saved when show is 'True'. + Path to the directory where visual outputs are saved when show is 'True'. If 'None', does not save the outputs. (default=None) extension: str - Extension of the images filenames. Must be compatible with plt.savefig() + Extension of the images filenames. Must be compatible with plt.savefig() function. (default=".png") n_iter_max : int Maximum number of iterations (default=100). @@ -313,11 +313,13 @@ def nan_basic_interp(array): array[nans] = np.interp(ix(nans), ix(~nans), array[~nans]) return array + data_clean = data.copy() + freq_rn = [fline - win_sz, fline + win_sz] freq_sp = [fline - spot_sz, fline + spot_sz] freq, psd = welch(data, fs=sfreq, nfft=nfft, axis=0) - freq_rn_ix = np.logical_and(freq >= freq_rn[0], + freq_rn_ix = np.logical_and(freq >= freq_rn[0], freq <= freq_rn[1]) freq_used = freq[freq_rn_ix] freq_sp_ix = np.logical_and(freq_used >= freq_sp[0], @@ -338,8 +340,8 @@ def nan_basic_interp(array): aggr_resid = [] iterations = 0 while iterations < n_iter_max: - data, _ = dss_line(data, fline, sfreq, nfft=nfft, nremove=1) - freq, psd = welch(data, fs=sfreq, nfft=nfft, axis=0) + data_clean, _ = dss_line(data_clean, fline, sfreq, nfft=nfft, nremove=1) + freq, psd = welch(data_clean, fs=sfreq, nfft=nfft, axis=0) if psd.ndim == 3: mean_psd = np.mean(psd, axis=(1, 2))[freq_rn_ix] elif psd.ndim == 2: @@ -366,7 +368,7 @@ def nan_basic_interp(array): ax.flat[0].set_xlabel("Frequency (Hz)") ax.flat[0].set_ylabel("Power") - ax.flat[1].plot(freq_used, mean_psd_tf, c="gray", + ax.flat[1].plot(freq_used, mean_psd_tf, c="gray", label="Interpolated mean PSD") ax.flat[1].plot(freq_used, mean_psd, c="blue", label="Mean PSD") ax.flat[1].plot(freq_used, clean_fit_line, c="red", label="Fitted polynomial") @@ -396,6 +398,9 @@ def nan_basic_interp(array): plt.show() if mean_score <= 0: + # Return original data if score is negative on first iteration + if iterations == 0: + return data, 0 break iterations += 1 @@ -404,4 +409,4 @@ def nan_basic_interp(array): raise RuntimeError("Could not converge. Consider increasing the " "maximum number of iterations") - return data, iterations + return data_clean, iterations diff --git a/tests/test_dss.py b/tests/test_dss.py index 3816b82..bbe1e8e 100644 --- a/tests/test_dss.py +++ b/tests/test_dss.py @@ -176,6 +176,30 @@ def _plot(before, after): plt.close("all") +def test_dss_line_iter_no_noise(): + """ + Test that dss_line_iter returns original data unchanged when DSS + cannot improve the signal. + """ + sr = 200 + fline = 50 + n_samples = 9000 + n_chans = 10 + rng = np.random.RandomState(42) + + # create data without line noise at target frequency + x = rng.randn(n_samples, n_chans) + x_original = x.copy() + + x_out, n_iters = dss.dss_line_iter(x, fline, sr, n_iter_max=10) + + assert n_iters == 0, f"Expected 0 iterations (no improvement), got {n_iters}" + assert np.allclose(x_out, x_original), ( + "When DSS cannot improve signal, should return original data unchanged" + ) + assert np.allclose(x, x_original), "Input data should never be mutated" + + def profile_dss_line(nkeep): """Test line noise removal.""" import cProfile