Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions meegkit/dss.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,10 +291,10 @@ def dss_line_iter(data, fline, sfreq, win_sz=10, spot_sz=2.5,
show: bool
Produce a visual output of each iteration (default=False).
dirname: str
Path to the directory where visual outputs are saved when show is 'True'.
Path to the directory where visual outputs are saved when show is 'True'.
If 'None', does not save the outputs. (default=None)
extension: str
Extension of the images filenames. Must be compatible with plt.savefig()
Extension of the images filenames. Must be compatible with plt.savefig()
function. (default=".png")
n_iter_max : int
Maximum number of iterations (default=100).
Expand All @@ -313,11 +313,13 @@ def nan_basic_interp(array):
array[nans] = np.interp(ix(nans), ix(~nans), array[~nans])
return array

data_clean = data.copy()

freq_rn = [fline - win_sz, fline + win_sz]
freq_sp = [fline - spot_sz, fline + spot_sz]
freq, psd = welch(data, fs=sfreq, nfft=nfft, axis=0)

freq_rn_ix = np.logical_and(freq >= freq_rn[0],
freq_rn_ix = np.logical_and(freq >= freq_rn[0],
freq <= freq_rn[1])
freq_used = freq[freq_rn_ix]
freq_sp_ix = np.logical_and(freq_used >= freq_sp[0],
Expand All @@ -338,8 +340,8 @@ def nan_basic_interp(array):
aggr_resid = []
iterations = 0
while iterations < n_iter_max:
data, _ = dss_line(data, fline, sfreq, nfft=nfft, nremove=1)
freq, psd = welch(data, fs=sfreq, nfft=nfft, axis=0)
data_clean, _ = dss_line(data_clean, fline, sfreq, nfft=nfft, nremove=1)
freq, psd = welch(data_clean, fs=sfreq, nfft=nfft, axis=0)
if psd.ndim == 3:
mean_psd = np.mean(psd, axis=(1, 2))[freq_rn_ix]
elif psd.ndim == 2:
Expand All @@ -366,7 +368,7 @@ def nan_basic_interp(array):
ax.flat[0].set_xlabel("Frequency (Hz)")
ax.flat[0].set_ylabel("Power")

ax.flat[1].plot(freq_used, mean_psd_tf, c="gray",
ax.flat[1].plot(freq_used, mean_psd_tf, c="gray",
label="Interpolated mean PSD")
ax.flat[1].plot(freq_used, mean_psd, c="blue", label="Mean PSD")
ax.flat[1].plot(freq_used, clean_fit_line, c="red", label="Fitted polynomial")
Expand Down Expand Up @@ -396,6 +398,9 @@ def nan_basic_interp(array):
plt.show()

if mean_score <= 0:
# Return original data if score is negative on first iteration
if iterations == 0:
return data, 0
break

iterations += 1
Expand All @@ -404,4 +409,4 @@ def nan_basic_interp(array):
raise RuntimeError("Could not converge. Consider increasing the "
"maximum number of iterations")

return data, iterations
return data_clean, iterations
24 changes: 24 additions & 0 deletions tests/test_dss.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,30 @@ def _plot(before, after):
plt.close("all")


def test_dss_line_iter_no_noise():
"""
Test that dss_line_iter returns original data unchanged when DSS
cannot improve the signal.
"""
sr = 200
fline = 50
n_samples = 9000
n_chans = 10
rng = np.random.RandomState(42)

# create data without line noise at target frequency
x = rng.randn(n_samples, n_chans)
x_original = x.copy()

x_out, n_iters = dss.dss_line_iter(x, fline, sr, n_iter_max=10)

assert n_iters == 0, f"Expected 0 iterations (no improvement), got {n_iters}"
assert np.allclose(x_out, x_original), (
"When DSS cannot improve signal, should return original data unchanged"
)
assert np.allclose(x, x_original), "Input data should never be mutated"


def profile_dss_line(nkeep):
"""Test line noise removal."""
import cProfile
Expand Down