From 0ce8156221ce89f580565aa5d3753fe59551ea7b Mon Sep 17 00:00:00 2001 From: Roy Eric <139973278+Randomidous@users.noreply.github.com> Date: Sat, 24 Jan 2026 13:56:57 +0100 Subject: [PATCH 1/9] [ENH] add zapline plus --- meegkit/dss_zapline.py | 965 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 965 insertions(+) create mode 100644 meegkit/dss_zapline.py diff --git a/meegkit/dss_zapline.py b/meegkit/dss_zapline.py new file mode 100644 index 0000000..8545c62 --- /dev/null +++ b/meegkit/dss_zapline.py @@ -0,0 +1,965 @@ +"""Zapline-plus for automatic removal of frequency-specific noise artifacts. + +This module implements Zapline-plus, an extension of the Zapline algorithm +that enables fully automatic removal of line noise and other frequency-specific +artifacts from M/EEG data. + +Based on: +Klug, M., & Kloosterman, N. A. (2022). Zapline-plus: A Zapline extension for +automatic and adaptive removal of frequency-specific noise artifacts in M/EEG. +Human Brain Mapping, 43(9), 2743-2758. + +Original Zapline by: +de Cheveigné, A. (2020). ZapLine: A simple and effective method to remove +power line artifacts. NeuroImage, 207, 116356. + + +Differences from Matlab implementation: + +Finding noise frequencies: + - one iteration returning all frequencies + +Adaptive chunking: + - merged chunks at edges if too short + + + + +""" + +import logging +from typing import dict, list, optional, tuple, union + +import matplotlib.pyplot as plt +import numpy as np +from scipy import signal + +from .dss import dss_line + + +def zapline_plus( + data: np.ndarray, + sfreq: float, + fline: optional[union[float, list[float]]] = None, + nkeep: int = 0, + adaptiveNremove: bool = True, + fixedNremove: int = 1, + minfreq: float = 17.0, + maxfreq: float = 99.0, + chunkLength: float = 0.0, + minChunkLength: float = 30.0, + noiseCompDetectSigma: float = 3.0, + adaptiveSigma: bool = True, + minsigma: float = 2.5, + maxsigma: float = 4.0, + detectionWinsize: float = 6.0, + coarseFreqDetectPowerDiff: float = 4.0, + coarseFreqDetectLowerPowerDiff: float = 1.76, + searchIndividualNoise: bool = True, + freqDetectMultFine: float = 2.0, + detailedFreqBoundsUpper: tuple[float, float] = (0.05, 0.05), + detailedFreqBoundsLower: tuple[float, float] = (0.4, 0.1), + maxProportionAboveUpper: float = 0.005, + maxProportionBelowLower: float = 0.005, + plotResults: bool = False, + figsize: tuple[int, int] = (14, 10), + vanilla_mode: bool = False, +) -> tuple[np.ndarray, dict]: + """Remove line noise and other frequency-specific artifacts using Zapline-plus. + + Parameters + ---------- + data : array, shape=(n_chans, n_times) + Input data. + sfreq : float + Sampling frequency in Hz. + fline : float | list of float | None + Noise frequency or frequencies to remove. If None, frequencies are + detected automatically. Defaults to None. + nkeep : int + Number of principal components to keep in DSS. If 0, no dimensionality + reduction is applied. Defaults to 0. + adaptiveNremove : bool + If True, automatically detect the number of components to remove. + If False, use fixedNremove for all chunks. Defaults to True. + fixedNremove : int + Fixed number of components to remove per chunk. Used when + adaptiveNremove=False, or as minimum when adaptiveNremove=True. + Defaults to 1. + minfreq : float + Minimum frequency (Hz) to consider when detecting noise automatically. + Defaults to 17.0. + maxfreq : float + Maximum frequency (Hz) to consider when detecting noise automatically. + Defaults to 99.0. 
+ chunkLength : float + Length of chunks (seconds) for cleaning. If 0, adaptive chunking based + on noise covariance stability is used. Set to -1 via vanilla_mode to + process the entire recording as a single chunk. Defaults to 0.0. + minChunkLength : float + Minimum chunk length (seconds) when using adaptive chunking. + Defaults to 30.0. + noiseCompDetectSigma : float + Initial SD threshold for iterative outlier detection of noise components. + Defaults to 3.0. + adaptiveSigma : bool + If True, automatically adapt noiseCompDetectSigma and fixedNremove + based on cleaning results. Defaults to True. + minsigma : float + Minimum SD threshold when adapting noiseCompDetectSigma. + Defaults to 2.5. + maxsigma : float + Maximum SD threshold when adapting noiseCompDetectSigma. + Defaults to 4.0. + detectionWinsize : float + Window size (Hz) for noise frequency detection. Defaults to 6.0. + coarseFreqDetectPowerDiff : float + Threshold (10*log10) above center power to detect a peak as noise. + Defaults to 4.0. + coarseFreqDetectLowerPowerDiff : float + Threshold (10*log10) above center power to detect end of noise peak. + Defaults to 1.76. + searchIndividualNoise : bool + If True, search for individual noise peaks in each chunk. + Defaults to True. + freqDetectMultFine : float + Multiplier for fine noise frequency detection threshold. Defaults to 2.0. + detailedFreqBoundsUpper : tuple of float + Frequency boundaries (Hz) for fine threshold of too weak cleaning. + Defaults to (0.05, 0.05). + detailedFreqBoundsLower : tuple of float + Frequency boundaries (Hz) for fine threshold of too strong cleaning. + Defaults to (0.4, 0.1). + maxProportionAboveUpper : float + Maximum proportion of samples above upper threshold before adapting. + Defaults to 0.005. + maxProportionBelowLower : float + Maximum proportion of samples below lower threshold before adapting. + Defaults to 0.005. + plotResults : bool + If True, generate diagnostic plots for each cleaned frequency. + Defaults to False. + figsize : tuple of int + Figure size for diagnostic plots. Defaults to (14, 10). + vanilla_mode : bool + If True, disable all Zapline-plus features and use vanilla Zapline behavior: + - Process entire dataset as single chunk + - Use fixed component removal (no adaptive detection) + - No individual chunk frequency detection + - No adaptive parameter tuning + Requires fline to be specified (not None). Defaults to False. + + Returns + ------- + clean_data : array, shape=(n_chans, n_times) + Cleaned data. + config : dict + Configuration dictionary containing all parameters and analytics. + + Notes + ----- + The algorithm proceeds as follows: + 1. Detect noise frequencies (if not provided) + 2. Segment data into chunks with stable noise topography + 3. Apply Zapline to each chunk + 4. Automatically detect and remove noise components + 5. Adapt parameters if cleaning is too weak or too strong + + Examples + -------- + Remove 50 Hz line noise automatically: + >>> clean_data, config = zapline_plus(data, sfreq=500, fline=50) + + Remove line noise with automatic frequency detection: + >>> clean_data, config = zapline_plus(data, sfreq=500) + + """ + n_chans, n_times = data.shape + + # Handle vanilla mode + if vanilla_mode: + logging.warning( + "vanilla_mode=True: Using vanilla Zapline behavior. " + "All adaptive features disabled." 
+ ) + if fline is None: + raise ValueError("vanilla_mode requires fline to be specified (not None)") + + for param_name in [ + "adaptiveNremove", + "adaptiveSigma", + "searchIndividualNoise", + ]: + if locals()[param_name]: + logging.warning(f"vanilla_mode=True: Overriding {param_name} to False.") + + # Override all adaptive features + adaptiveNremove = False + adaptiveSigma = False + searchIndividualNoise = False + chunkLength = -1 # Zapline vanilla deals with single chunk + + # check for globally flat channels + diff_data = np.diff(data, axis=1) + global_flat = np.where(np.all(diff_data == 0, axis=1))[0] + + if len(global_flat) > 0: + logging.warning( + f"Detected {len(global_flat)} globally flat channels: {global_flat}. " + f"Removing for processing, will add back after." + ) + flat_data = data[global_flat, :] + active_channels = np.setdiff1d(np.arange(n_chans), global_flat) + data = data[active_channels, :] + else: + active_channels = np.arange(n_chans) + flat_data = None + + # Initialize configuration + config = { + "sfreq": sfreq, + "fline": fline, + "nkeep": nkeep, + "adaptiveNremove": adaptiveNremove, + "fixedNremove": fixedNremove, + "minfreq": minfreq, + "maxfreq": maxfreq, + "chunkLength": chunkLength, + "minChunkLength": minChunkLength, + "noiseCompDetectSigma": noiseCompDetectSigma, + "adaptiveSigma": adaptiveSigma, + "minsigma": minsigma, + "maxsigma": maxsigma, + "detectionWinsize": detectionWinsize, + "coarseFreqDetectPowerDiff": coarseFreqDetectPowerDiff, + "coarseFreqDetectLowerPowerDiff": coarseFreqDetectLowerPowerDiff, + "searchIndividualNoise": searchIndividualNoise, + "freqDetectMultFine": freqDetectMultFine, + "detailedFreqBoundsUpper": detailedFreqBoundsUpper, + "detailedFreqBoundsLower": detailedFreqBoundsLower, + "maxProportionAboveUpper": maxProportionAboveUpper, + "maxProportionBelowLower": maxProportionBelowLower, + "analytics": {}, + } + + # Detect noise frequencies if not provided + if fline is None: + fline = _detect_noise_frequencies( + data, + sfreq, + minfreq, + maxfreq, + detectionWinsize, + coarseFreqDetectPowerDiff, + coarseFreqDetectLowerPowerDiff, + ) + elif not isinstance(fline, list): + fline = [fline] + + if len(fline) == 0: + logging.warning("No noise frequencies detected. 
Returning original data.") + return data.copy(), config + + config["detected_fline"] = fline + + # retain input data + clean_data = data.copy() + + # Process each noise frequency + for freq_idx, target_freq in enumerate(fline): + print(f"Processing noise frequency: {target_freq:.2f} Hz") + + # Adaptive chunking or fixed chunks + if chunkLength == -1: + # single chunk + chunks = [(0, n_times)] + elif chunkLength == 0: + chunks = _adaptive_chunking(clean_data, sfreq, target_freq, minChunkLength) + else: + chunk_samples = int(chunkLength * sfreq) + chunks = [ + (i, min(i + chunk_samples, n_times)) + for i in range(0, n_times, chunk_samples) + ] + + # Initialize tracking variables + current_sigma = noiseCompDetectSigma + current_fixed = fixedNremove + too_strong_once = False + iteration = 0 + max_iterations = 20 + + while iteration < max_iterations: + iteration += 1 + + # Clean each chunk + chunk_results = [] + for chunk_start, chunk_end in chunks: + chunk_data = clean_data[:, chunk_start:chunk_end] + + # Detect chunk-specific noise frequency + if searchIndividualNoise: + chunk_freq, has_noise = _detect_chunk_noise_frequency( + chunk_data, + sfreq, + target_freq, + detectionWinsize, + freqDetectMultFine, + ) + else: + chunk_freq = target_freq + has_noise = True + + # Apply Zapline to chunk + if has_noise: + if adaptiveNremove: + n_remove = _detect_noise_components( + chunk_data, sfreq, chunk_freq, current_sigma, nkeep + ) + n_remove = max(n_remove, current_fixed) + else: + n_remove = current_fixed + + # Cap at 1/5 of components + n_remove = min(n_remove, n_chans // 5) + else: + n_remove = current_fixed + + # clean chunk + cleaned_chunk = _apply_zapline_to_chunk( + chunk_data, sfreq, chunk_freq, n_remove, nkeep + ) + + chunk_results.append( + { + "start": chunk_start, + "end": chunk_end, + "freq": chunk_freq, + "n_remove": n_remove, + "has_noise": has_noise, + "data": cleaned_chunk, + } + ) + + # Reconstruct cleaned data + temp_clean = clean_data.copy() + for result in chunk_results: + temp_clean[:, result["start"] : result["end"]] = result["data"] + + # Check if cleaning is optimal + cleaning_status = _check_cleaning_quality( + data, + temp_clean, + sfreq, + target_freq, + detectionWinsize, + freqDetectMultFine, + detailedFreqBoundsUpper, + detailedFreqBoundsLower, + maxProportionAboveUpper, + maxProportionBelowLower, + ) + + # Store analytics + config["analytics"][f"freq_{freq_idx}"] = { + "target_freq": target_freq, + "iteration": iteration, + "sigma": current_sigma, + "fixed_nremove": current_fixed, + "n_chunks": len(chunks), + "chunk_results": chunk_results, + "cleaning_status": cleaning_status, + } + + # Check if we need to adapt + if cleaning_status == "good": + clean_data = temp_clean + break + elif cleaning_status == "too_weak" and not too_strong_once: + current_sigma = max(current_sigma - 0.25, minsigma) + current_fixed += 1 + print( + f" Cleaning too weak. Adjusting sigma to {current_sigma:.2f}, " + f"fixed removal to {current_fixed}" + ) + elif cleaning_status == "too_strong": + too_strong_once = True + current_sigma = min(current_sigma + 0.25, maxsigma) + current_fixed = max(current_fixed - 1, fixedNremove) + print( + f" Cleaning too strong. 
Adjusting sigma to {current_sigma:.2f}, " + f"fixed removal to {current_fixed}" + ) + else: + # Too strong takes precedence, or we can't improve further + clean_data = temp_clean + break + + # Generate diagnostic plot + if plotResults: + _plot_cleaning_results( + data, + clean_data, + sfreq, + target_freq, + config["analytics"][f"freq_{freq_idx}"], + figsize, + ) + + # add flat channels back to data, if present + if flat_data is not None: + full_clean = np.zeros((n_chans, n_times)) + full_clean[active_channels, :] = clean_data + full_clean[global_flat, :] = flat_data + clean_data = full_clean + + return clean_data, config + + +def _detect_noise_frequencies( + data, sfreq, minfreq, maxfreq, winsize, power_diff_high, power_diff_low +): + """ + Detect noise frequencies. + + This is an exact implementation of find_next_noisefreq.m with the only difference + that all peaks are returned instead of this being called iteratively. + + How it works + ------------ + 1. Compute PSD and log-transform. + 2. Slide a window across frequencies from minfreq to maxfreq. + 3. For each frequency, compute center power as mean of left and right thirds. + 4. Use a state machine to detect peaks: + - SEARCHING: If current power - center power > power_diff_high, + mark peak start and switch to IN_PEAK. + - IN_PEAK: If current power - center power <= power_diff_low, + mark peak end, find max within peak, record frequency, + and switch to SEARCHING. + 5. Return list of detected noise frequencies. + """ + # Compute PSD + freqs, psd = _compute_psd(data, sfreq) + log_psd = 10 * np.log10(np.mean(psd, axis=0)) + + # State machine variables + in_peak = False + peak_start_idx = None + noise_freqs = [] + + # Search bounds + start_idx = np.searchsorted(freqs, minfreq) + end_idx = np.searchsorted(freqs, maxfreq) + + # Window size in samples + freq_resolution = freqs[1] - freqs[0] + win_samples = int(winsize / freq_resolution) + + idx = start_idx + while idx < end_idx: + # Get window around current frequency + win_start = max(0, idx - win_samples // 2) + win_end = min(len(freqs), idx + win_samples // 2) + win_psd = log_psd[win_start:win_end] + + if len(win_psd) < 3: + idx += 1 + continue + + # Compute center power (mean of left and right thirds) + n_third = len(win_psd) // 3 + if n_third < 1: + idx += 1 + continue + + left_third = win_psd[:n_third] + right_third = win_psd[-n_third:] + center_power = np.mean([np.mean(left_third), np.mean(right_third)]) + + current_power = log_psd[idx] + + # State machine logic + if not in_peak: + # State: SEARCHING - Check for peak start + if current_power - center_power > power_diff_high: + in_peak = True + peak_start_idx = idx + + else: + # State: IN_PEAK - Check for peak end + if current_power - center_power <= power_diff_low: + in_peak = False + peak_end_idx = idx + + # Find the actual maximum within the peak + if peak_start_idx is not None and peak_end_idx > peak_start_idx: + peak_region = log_psd[peak_start_idx:peak_end_idx] + max_offset = np.argmax(peak_region) + max_idx = peak_start_idx + max_offset + noise_freqs.append(freqs[max_idx]) + + # Skip past this peak to avoid re-detection + idx = peak_end_idx + continue + + idx += 1 + + return noise_freqs + + +def _adaptive_chunking( + data, + sfreq, + target_freq, + min_chunk_length, + detection_winsize=6.0, + prominence_quantile=0.95, +): + """Segment data into chunks with stable noise topography.""" + n_chans, n_times = data.shape + + if n_times < sfreq * min_chunk_length: + logging.warning("Data too short for adaptive chunking. 
Using single chunk.") + return [(0, n_times)] + + n_chans, n_times = data.shape + + # Narrow-band filter around target frequency + bandwidth = detection_winsize / 2.0 + filtered = _narrowband_filter(data, sfreq, target_freq, bandwidth=bandwidth) + + # Compute covariance matrices for 1-second epochs + epoch_length = int(sfreq) + n_epochs = n_times // epoch_length + + distances = np.zeros(n_epochs) + prev_cov = None + + for i in range(n_epochs): + start = i * epoch_length + end = start + epoch_length + epoch = filtered[:, start:end] + cov = np.cov(epoch) + + if prev_cov is not None: + # Frobenius norm of difference + distances[i] = np.linalg.norm(cov - prev_cov, "fro") + # else: distance[i] already 0 from initialization + + prev_cov = cov + + if len(distances) < 2: + return [(0, n_times)] + + # find all peaks to get prominence distribution + peaks_all, properties_all = signal.find_peaks(distances, prominence=0) + + if len(peaks_all) == 0 or "prominences" not in properties_all: + # No peaks found + logging.warning("No peaks found in distance signal. Using single chunk.") + return [(0, n_times)] + + prominences = properties_all["prominences"] + + # filter by prominence quantile + min_prominence = np.quantile(prominences, prominence_quantile) + min_distance_epochs = int(min_chunk_length) # Convert seconds to epochs + + peaks, properties = signal.find_peaks( + distances, prominence=min_prominence, distance=min_distance_epochs + ) + + # cconvert peak locations (in epochs) to sample indices + chunk_starts = [0] + for peak in peaks: + chunk_start_sample = peak * epoch_length + chunk_starts.append(chunk_start_sample) + chunk_starts.append(n_times) + + # create chunk list + chunks = [] + for i in range(len(chunk_starts) - 1): + start = chunk_starts[i] + end = chunk_starts[i + 1] + chunks.append((start, end)) + + # ensure minimum chunk length at edges + min_chunk_samples = int(min_chunk_length * sfreq) + + if len(chunks) > 1: + # check first chunk + if chunks[0][1] - chunks[0][0] < min_chunk_samples: + # merge with next + chunks[1] = (chunks[0][0], chunks[1][1]) + chunks.pop(0) + + if len(chunks) > 1: + # check last chunk + if chunks[-1][1] - chunks[-1][0] < min_chunk_samples: + # merge with previous + chunks[-2] = (chunks[-2][0], chunks[-1][1]) + chunks.pop(-1) + + return chunks + + +def _detect_chunk_noise_frequency(data, sfreq, target_freq, winsize, mult_fine): + """Detect chunk-specific noise frequency around target.""" + freqs, psd = _compute_psd(data, sfreq) + log_psd = 10 * np.log10(np.mean(psd, axis=0)) + + # Search in ±0.05 Hz range + search_mask = (freqs >= target_freq - 0.05) & (freqs <= target_freq + 0.05) + if not np.any(search_mask): + return target_freq, False + + search_freqs = freqs[search_mask] + search_psd = log_psd[search_mask] + + # Find peak + peak_idx = np.argmax(search_psd) + peak_freq = search_freqs[peak_idx] + peak_power = search_psd[peak_idx] + + # Compute threshold + win_mask = (freqs >= target_freq - winsize / 2) & (freqs <= target_freq + winsize / 2) + win_psd = log_psd[win_mask] + + n_third = len(win_psd) // 3 + left_third = win_psd[:n_third] + right_third = win_psd[-n_third:] + center = np.mean([np.mean(left_third), np.mean(right_third)]) + + # Compute deviation (lower 5% quantiles) + lower_quant_left = np.percentile(left_third, 5) + lower_quant_right = np.percentile(right_third, 5) + deviation = center - np.mean([lower_quant_left, lower_quant_right]) + + threshold = center + mult_fine * deviation + + has_noise = peak_power > threshold + + return peak_freq, has_noise + 
+ +def _detect_noise_components(data, sfreq, target_freq, sigma, nkeep): + """Detect number of noise components to remove using outlier detection.""" + # Apply DSS to get component scores + _, scores = dss_line(data, target_freq, sfreq, nkeep=nkeep) + + if scores is None or len(scores) == 0: + return 1 + + # Sort scores in descending order + sorted_scores = np.sort(scores)[::-1] + + # Iterative outlier detection + n_remove = 0 + remaining = sorted_scores.copy() + + while len(remaining) > 1: + mean_val = np.mean(remaining) + std_val = np.std(remaining) + threshold = mean_val + sigma * std_val + + if remaining[0] > threshold: + n_remove += 1 + remaining = remaining[1:] + else: + break + + return max(n_remove, 1) + + +def _apply_zapline_to_chunk(chunk_data, sfreq, chunk_freq, n_remove, nkeep): + """Apply Zapline to a single chunk, handling flat channels.""" + n_chans, n_samples = chunk_data.shape + + # Detect flat channels (zero variance) + diff_chunk = np.diff(chunk_data, axis=1) + flat_channels = np.where(np.all(diff_chunk == 0, axis=1))[0] + + if len(flat_channels) > 0: + logging.warning( + f"Detected {len(flat_channels)} flat channels in chunk: {flat_channels}. " + f"Removing temporarily for processing." + ) + + # store flat channel data + flat_channel_data = chunk_data[flat_channels, :] + + # remove flat channels from processing + active_channels = np.setdiff1d(np.arange(n_chans), flat_channels) + chunk_data_active = chunk_data[active_channels, :] + + # process only active channels + cleaned_active, _ = dss_line( + chunk_data_active, + fline=chunk_freq, + sfreq=sfreq, + nremove=n_remove, + nkeep=nkeep, + ) + + # Reconstruct full data with flat channels + cleaned_chunk = np.zeros_like(chunk_data) + cleaned_chunk[active_channels, :] = cleaned_active + cleaned_chunk[flat_channels, :] = ( + flat_channel_data # Add flat channels back unchanged + ) + + else: + # no flat channels, process normally + cleaned_chunk, _ = dss_line( + chunk_data, + fline=chunk_freq, + sfreq=sfreq, + nremove=n_remove, + nkeep=nkeep, + ) + + return cleaned_chunk + + +def _check_cleaning_quality( + original_data, + cleaned_data, + sfreq, + target_freq, + winsize, + mult_fine, + bounds_upper, + bounds_lower, + max_prop_above, + max_prop_below, +): + """Check if cleaning is too weak, too strong, or good.""" + # Compute PSDs + freqs, psd_clean = _compute_psd(cleaned_data, sfreq) + log_psd_clean = 10 * np.log10(np.mean(psd_clean, axis=0)) + + # Compute fine thresholds + win_mask = (freqs >= target_freq - winsize / 2) & (freqs <= target_freq + winsize / 2) + win_psd = log_psd_clean[win_mask] + + n_third = len(win_psd) // 3 + left_third = win_psd[:n_third] + right_third = win_psd[-n_third:] + center = np.mean([np.mean(left_third), np.mean(right_third)]) + + # Deviation from lower quantiles + lower_quant_left = np.percentile(left_third, 5) + lower_quant_right = np.percentile(right_third, 5) + deviation = center - np.mean([lower_quant_left, lower_quant_right]) + + # Upper threshold (too weak cleaning) + upper_mask = (freqs >= target_freq - bounds_upper[0]) & ( + freqs <= target_freq + bounds_upper[1] + ) + upper_threshold = center + mult_fine * deviation + upper_psd = log_psd_clean[upper_mask] + prop_above = np.mean(upper_psd > upper_threshold) + + # Lower threshold (too strong cleaning) + lower_mask = (freqs >= target_freq - bounds_lower[0]) & ( + freqs <= target_freq + bounds_lower[1] + ) + lower_threshold = center - mult_fine * deviation + lower_psd = log_psd_clean[lower_mask] + prop_below = np.mean(lower_psd < 
lower_threshold) + + if prop_below > max_prop_below: + return "too_strong" + elif prop_above > max_prop_above: + return "too_weak" + else: + return "good" + + +def _compute_psd(data, sfreq, nperseg=None): + """Compute power spectral density using Welch's method.""" + if nperseg is None: + nperseg = int(sfreq * 4) # 4-second windows + + freqs, psd = signal.welch( + data, + fs=sfreq, + window="hann", + nperseg=nperseg, + axis=-1, + ) + + return freqs, psd + + +def _narrowband_filter(data, sfreq, center_freq, bandwidth=3.0): + """Apply narrow-band filter around center frequency.""" + nyq = sfreq / 2 + low = (center_freq - bandwidth) / nyq + high = (center_freq + bandwidth) / nyq + + # Ensure valid frequency range + low = max(low, 0.001) + high = min(high, 0.999) + + sos = signal.butter(4, [low, high], btype="band", output="sos") + filtered = signal.sosfiltfilt(sos, data, axis=-1) + + return filtered + + +def _plot_cleaning_results(original, cleaned, sfreq, target_freq, analytics, figsize): + """Generate diagnostic plots for cleaning results.""" + fig = plt.figure(figsize=figsize) + gs = fig.add_gridspec(2, 4, hspace=0.3, wspace=0.3) + + # Compute PSDs + freqs, psd_orig = _compute_psd(original, sfreq) + _, psd_clean = _compute_psd(cleaned, sfreq) + + log_psd_orig = 10 * np.log10(np.mean(psd_orig, axis=0)) + log_psd_clean = 10 * np.log10(np.mean(psd_clean, axis=0)) + + # 1. Zoomed spectrum around noise frequency + ax1 = fig.add_subplot(gs[0, 0]) + zoom_mask = (freqs >= target_freq - 1.1) & (freqs <= target_freq + 1.1) + ax1.plot(freqs[zoom_mask], log_psd_orig[zoom_mask], "k-", label="Original") + ax1.set_xlabel("Frequency (Hz)") + ax1.set_ylabel("Power (dB)") + ax1.set_title(f"Detected frequency: {target_freq:.2f} Hz") + ax1.legend() + ax1.grid(True, alpha=0.3) + + # 2. Number of removed components per chunk + ax2 = fig.add_subplot(gs[0, 1]) + chunk_results = analytics["chunk_results"] + n_removes = [cr["n_remove"] for cr in chunk_results] + ax2.bar(range(len(n_removes)), n_removes) + ax2.set_xlabel("Chunk") + ax2.set_ylabel("# Components removed") + ax2.set_title(f"Removed components (mean={np.mean(n_removes):.1f})") + ax2.grid(True, alpha=0.3) + + # 3. Individual noise frequencies per chunk + ax3 = fig.add_subplot(gs[0, 2]) + chunk_freqs = [cr["freq"] for cr in chunk_results] + time_min = np.array([cr["start"] for cr in chunk_results]) / sfreq / 60 + ax3.plot(time_min, chunk_freqs, "o-") + ax3.set_xlabel("Time (minutes)") + ax3.set_ylabel("Frequency (Hz)") + ax3.set_title("Individual noise frequencies") + ax3.grid(True, alpha=0.3) + + # 4. Component scores (would need actual scores from DSS) + ax4 = fig.add_subplot(gs[0, 3]) + ax4.text( + 0.5, + 0.5, + "Component scores\n(requires DSS output)", + ha="center", + va="center", + transform=ax4.transAxes, + ) + ax4.set_title("Mean artifact scores") + + # 5. Cleaned spectrum (zoomed) + ax5 = fig.add_subplot(gs[1, 0]) + ax5.plot(freqs[zoom_mask], log_psd_clean[zoom_mask], "g-", label="Cleaned") + ax5.set_xlabel("Frequency (Hz)") + ax5.set_ylabel("Power (dB)") + ax5.set_title("Cleaned spectrum") + ax5.legend() + ax5.grid(True, alpha=0.3) + + # 6. 
Full spectrum + ax6 = fig.add_subplot(gs[1, 1]) + ax6.plot(freqs, log_psd_orig, "k-", alpha=0.5, label="Original") + ax6.plot(freqs, log_psd_clean, "g-", label="Cleaned") + ax6.axvline(target_freq, color="r", linestyle="--", alpha=0.5) + ax6.set_xlabel("Frequency (Hz)") + ax6.set_ylabel("Power (dB)") + ax6.set_title("Full power spectrum") + ax6.legend() + ax6.grid(True, alpha=0.3) + ax6.set_xlim([0, 100]) + + # 7. Removed power (ratio) + ax7 = fig.add_subplot(gs[1, 2]) + noise_mask = (freqs >= target_freq - 0.05) & (freqs <= target_freq + 0.05) + ratio_orig = np.mean(psd_orig[:, noise_mask]) / np.mean(psd_orig) + ratio_clean = np.mean(psd_clean[:, noise_mask]) / np.mean(psd_clean) + + ax7.text( + 0.5, + 0.6, + f"Original ratio: {ratio_orig:.2f}", + ha="center", + transform=ax7.transAxes, + ) + ax7.text( + 0.5, + 0.4, + f"Cleaned ratio: {ratio_clean:.2f}", + ha="center", + transform=ax7.transAxes, + ) + ax7.set_title("Noise/surroundings ratio") + ax7.axis("off") + + # 8. Below noise frequencies + ax8 = fig.add_subplot(gs[1, 3]) + below_mask = (freqs >= target_freq - 11) & (freqs <= target_freq - 1) + ax8.plot( + freqs[below_mask], log_psd_orig[below_mask], "k-", alpha=0.5, label="Original" + ) + ax8.plot(freqs[below_mask], log_psd_clean[below_mask], "g-", label="Cleaned") + ax8.set_xlabel("Frequency (Hz)") + ax8.set_ylabel("Power (dB)") + ax8.set_title("Power below noise frequency") + ax8.legend() + ax8.grid(True, alpha=0.3) + + plt.suptitle( + f"Zapline-plus cleaning results: {target_freq:.2f} Hz " + f"(iteration {analytics['iteration']})", + fontsize=14, + y=0.98, + ) + + plt.show() + + return fig + + +# Convenience function with simpler interface +def remove_line_noise( + data: np.ndarray, sfreq: float, fline: optional[float] = None, **kwargs +) -> np.ndarray: + """Remove line noise from data using Zapline-plus. + + This is a simplified interface to zapline_plus() that returns only + the cleaned data. + + Parameters + ---------- + data : array, shape=(n_chans, n_times) + Input data. + sfreq : float + Sampling frequency in Hz. + fline : float | None + Line noise frequency. If None, automatically detected. + **kwargs + Additional arguments passed to zapline_plus(). + + Returns + ------- + clean_data : array, shape=(n_chans, n_times) + Cleaned data. 
+ + Examples + -------- + >>> clean = remove_line_noise(data, sfreq=500, fline=50) + + """ + clean_data, _ = zapline_plus(data, sfreq, fline=fline, **kwargs) + return clean_data From 4b0bc2cbec86a37faa3f78395148ef3a519960b4 Mon Sep 17 00:00:00 2001 From: Roy Eric <139973278+Randomidous@users.noreply.github.com> Date: Sun, 25 Jan 2026 14:52:52 +0100 Subject: [PATCH 2/9] =?UTF-8?q?change=20(n=5Fchans,=20n=5Ftimes)=20?= =?UTF-8?q?=E2=86=92=20(n=5Ftimes,=20n=5Fchans)=20to=20adhere=20to=20codeb?= =?UTF-8?q?ase?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- meegkit/dss_zapline.py | 169 ++++++++++++++++++++++++----------------- 1 file changed, 100 insertions(+), 69 deletions(-) diff --git a/meegkit/dss_zapline.py b/meegkit/dss_zapline.py index 8545c62..684a459 100644 --- a/meegkit/dss_zapline.py +++ b/meegkit/dss_zapline.py @@ -17,30 +17,31 @@ Differences from Matlab implementation: Finding noise frequencies: - - one iteration returning all frequencies +- one iteration returning all frequencies Adaptive chunking: - - merged chunks at edges if too short +- merged chunks at edges if too short +Plotting: +- only once per frequency after cleaning """ import logging -from typing import dict, list, optional, tuple, union import matplotlib.pyplot as plt import numpy as np from scipy import signal -from .dss import dss_line +from meegkit.dss import dss_line def zapline_plus( data: np.ndarray, sfreq: float, - fline: optional[union[float, list[float]]] = None, + fline: float | list[float] | None = None, nkeep: int = 0, adaptiveNremove: bool = True, fixedNremove: int = 1, @@ -69,7 +70,7 @@ def zapline_plus( Parameters ---------- - data : array, shape=(n_chans, n_times) + data : array, shape=(n_times, n_chans) Input data. sfreq : float Sampling frequency in Hz. @@ -151,7 +152,7 @@ def zapline_plus( Returns ------- - clean_data : array, shape=(n_chans, n_times) + clean_data : array, shape=(n_times, n_chans) Cleaned data. config : dict Configuration dictionary containing all parameters and analytics. @@ -174,9 +175,9 @@ def zapline_plus( >>> clean_data, config = zapline_plus(data, sfreq=500) """ - n_chans, n_times = data.shape + n_times, n_chans = data.shape - # Handle vanilla mode + # Handle vanilla mode (ZapLine without plus) if vanilla_mode: logging.warning( "vanilla_mode=True: Using vanilla Zapline behavior. " @@ -199,23 +200,27 @@ def zapline_plus( searchIndividualNoise = False chunkLength = -1 # Zapline vanilla deals with single chunk - # check for globally flat channels - diff_data = np.diff(data, axis=1) - global_flat = np.where(np.all(diff_data == 0, axis=1))[0] + # if nothing is adaptive, only one iteration per frequency + if not (adaptiveNremove and adaptiveSigma): + max_iterations = 1 + # check for globally flat channels + # will be omitted during processing and reintroduced later + diff_data = np.diff(data, axis=0) + global_flat = np.where(np.all(diff_data == 0, axis=0))[0] if len(global_flat) > 0: logging.warning( f"Detected {len(global_flat)} globally flat channels: {global_flat}. " f"Removing for processing, will add back after." 
) - flat_data = data[global_flat, :] + flat_data = data[:, global_flat] active_channels = np.setdiff1d(np.arange(n_chans), global_flat) - data = data[active_channels, :] + data = data[:, active_channels] else: active_channels = np.arange(n_chans) flat_data = None - # Initialize configuration + # initialize configuration config = { "sfreq": sfreq, "fline": fline, @@ -242,7 +247,7 @@ def zapline_plus( "analytics": {}, } - # Detect noise frequencies if not provided + # detect noise frequencies if not provided if fline is None: fline = _detect_noise_frequencies( data, @@ -257,7 +262,7 @@ def zapline_plus( fline = [fline] if len(fline) == 0: - logging.warning("No noise frequencies detected. Returning original data.") + logging.info("No noise frequencies detected. Returning original data.") return data.copy(), config config["detected_fline"] = fline @@ -269,20 +274,21 @@ def zapline_plus( for freq_idx, target_freq in enumerate(fline): print(f"Processing noise frequency: {target_freq:.2f} Hz") - # Adaptive chunking or fixed chunks if chunkLength == -1: # single chunk chunks = [(0, n_times)] elif chunkLength == 0: + # adaptive chunking chunks = _adaptive_chunking(clean_data, sfreq, target_freq, minChunkLength) else: + # fixed-length chunks chunk_samples = int(chunkLength * sfreq) chunks = [ (i, min(i + chunk_samples, n_times)) for i in range(0, n_times, chunk_samples) ] - # Initialize tracking variables + # initialize tracking variables current_sigma = noiseCompDetectSigma current_fixed = fixedNremove too_strong_once = False @@ -295,7 +301,7 @@ def zapline_plus( # Clean each chunk chunk_results = [] for chunk_start, chunk_end in chunks: - chunk_data = clean_data[:, chunk_start:chunk_end] + chunk_data = clean_data[chunk_start:chunk_end, :] # Detect chunk-specific noise frequency if searchIndividualNoise: @@ -305,6 +311,7 @@ def zapline_plus( target_freq, detectionWinsize, freqDetectMultFine, + detailed_freq_bounds=detailedFreqBoundsUpper, ) else: chunk_freq = target_freq @@ -341,12 +348,12 @@ def zapline_plus( } ) - # Reconstruct cleaned data + # reconstruct cleaned data temp_clean = clean_data.copy() for result in chunk_results: - temp_clean[:, result["start"] : result["end"]] = result["data"] + temp_clean[result["start"] : result["end"], :] = result["data"] - # Check if cleaning is optimal + # check if cleaning is optimal cleaning_status = _check_cleaning_quality( data, temp_clean, @@ -360,7 +367,7 @@ def zapline_plus( maxProportionBelowLower, ) - # Store analytics + # store analytics config["analytics"][f"freq_{freq_idx}"] = { "target_freq": target_freq, "iteration": iteration, @@ -371,25 +378,38 @@ def zapline_plus( "cleaning_status": cleaning_status, } - # Check if we need to adapt + # check if we need to adapt if cleaning_status == "good": clean_data = temp_clean break + elif cleaning_status == "too_weak" and not too_strong_once: - current_sigma = max(current_sigma - 0.25, minsigma) - current_fixed += 1 - print( - f" Cleaning too weak. Adjusting sigma to {current_sigma:.2f}, " - f"fixed removal to {current_fixed}" - ) + if current_sigma > minsigma: + current_sigma = max(current_sigma - 0.25, minsigma) + current_fixed += 1 + logging.info( + f"Cleaning too weak. 
Adjusting sigma to {current_sigma:.2f}, " + f"fixed removal to {current_fixed}" + ) + else: + logging.info("At minimum sigma, accepting result") + clean_data = temp_clean + break + elif cleaning_status == "too_strong": too_strong_once = True - current_sigma = min(current_sigma + 0.25, maxsigma) - current_fixed = max(current_fixed - 1, fixedNremove) - print( - f" Cleaning too strong. Adjusting sigma to {current_sigma:.2f}, " - f"fixed removal to {current_fixed}" - ) + if current_sigma < maxsigma: + current_sigma = min(current_sigma + 0.25, maxsigma) + current_fixed = max(current_fixed - 1, fixedNremove) + logging.info( + f"Cleaning too strong. Adjusting sigma to {current_sigma:.2f}, " + f"fixed removal to {current_fixed}" + ) + else: + logging.info("At maximum sigma, accepting result") + clean_data = temp_clean + break + else: # Too strong takes precedence, or we can't improve further clean_data = temp_clean @@ -408,9 +428,9 @@ def zapline_plus( # add flat channels back to data, if present if flat_data is not None: - full_clean = np.zeros((n_chans, n_times)) - full_clean[active_channels, :] = clean_data - full_clean[global_flat, :] = flat_data + full_clean = np.zeros((n_times, n_chans)) + full_clean[:, active_channels] = clean_data + full_clean[:, global_flat] = flat_data clean_data = full_clean return clean_data, config @@ -440,7 +460,7 @@ def _detect_noise_frequencies( """ # Compute PSD freqs, psd = _compute_psd(data, sfreq) - log_psd = 10 * np.log10(np.mean(psd, axis=0)) + log_psd = 10 * np.log10(np.mean(psd, axis=1)) # State machine variables in_peak = False @@ -516,14 +536,12 @@ def _adaptive_chunking( prominence_quantile=0.95, ): """Segment data into chunks with stable noise topography.""" - n_chans, n_times = data.shape + n_times, n_chans = data.shape if n_times < sfreq * min_chunk_length: logging.warning("Data too short for adaptive chunking. 
Using single chunk.") return [(0, n_times)] - n_chans, n_times = data.shape - # Narrow-band filter around target frequency bandwidth = detection_winsize / 2.0 filtered = _narrowband_filter(data, sfreq, target_freq, bandwidth=bandwidth) @@ -538,8 +556,8 @@ def _adaptive_chunking( for i in range(n_epochs): start = i * epoch_length end = start + epoch_length - epoch = filtered[:, start:end] - cov = np.cov(epoch) + epoch = filtered[start:end, :] + cov = np.cov(epoch, rowvar=False) if prev_cov is not None: # Frobenius norm of difference @@ -603,25 +621,35 @@ def _adaptive_chunking( return chunks -def _detect_chunk_noise_frequency(data, sfreq, target_freq, winsize, mult_fine): +def _detect_chunk_noise_frequency( + data, + sfreq, + target_freq, + winsize, + mult_fine, + detailed_freq_bounds=(-0.05, 0.05), # ← Add this parameter +): """Detect chunk-specific noise frequency around target.""" freqs, psd = _compute_psd(data, sfreq) - log_psd = 10 * np.log10(np.mean(psd, axis=0)) + log_psd = 10 * np.log10(np.mean(psd, axis=1)) + + # get frequency mask + search_mask = (freqs >= target_freq + detailed_freq_bounds[0]) & ( + freqs <= target_freq + detailed_freq_bounds[1] + ) - # Search in ±0.05 Hz range - search_mask = (freqs >= target_freq - 0.05) & (freqs <= target_freq + 0.05) if not np.any(search_mask): return target_freq, False search_freqs = freqs[search_mask] search_psd = log_psd[search_mask] - # Find peak + # find peak peak_idx = np.argmax(search_psd) peak_freq = search_freqs[peak_idx] peak_power = search_psd[peak_idx] - # Compute threshold + # Compute threshold (uses broader window) win_mask = (freqs >= target_freq - winsize / 2) & (freqs <= target_freq + winsize / 2) win_psd = log_psd[win_mask] @@ -673,11 +701,11 @@ def _detect_noise_components(data, sfreq, target_freq, sigma, nkeep): def _apply_zapline_to_chunk(chunk_data, sfreq, chunk_freq, n_remove, nkeep): """Apply Zapline to a single chunk, handling flat channels.""" - n_chans, n_samples = chunk_data.shape + n_samples, n_chans = chunk_data.shape # Detect flat channels (zero variance) - diff_chunk = np.diff(chunk_data, axis=1) - flat_channels = np.where(np.all(diff_chunk == 0, axis=1))[0] + diff_chunk = np.diff(chunk_data, axis=0) + flat_channels = np.where(np.all(diff_chunk == 0, axis=0))[0] if len(flat_channels) > 0: logging.warning( @@ -686,11 +714,11 @@ def _apply_zapline_to_chunk(chunk_data, sfreq, chunk_freq, n_remove, nkeep): ) # store flat channel data - flat_channel_data = chunk_data[flat_channels, :] + flat_channel_data = chunk_data[:, flat_channels] # remove flat channels from processing active_channels = np.setdiff1d(np.arange(n_chans), flat_channels) - chunk_data_active = chunk_data[active_channels, :] + chunk_data_active = chunk_data[:, active_channels] # process only active channels cleaned_active, _ = dss_line( @@ -703,8 +731,8 @@ def _apply_zapline_to_chunk(chunk_data, sfreq, chunk_freq, n_remove, nkeep): # Reconstruct full data with flat channels cleaned_chunk = np.zeros_like(chunk_data) - cleaned_chunk[active_channels, :] = cleaned_active - cleaned_chunk[flat_channels, :] = ( + cleaned_chunk[:, active_channels] = cleaned_active + cleaned_chunk[:, flat_channels] = ( flat_channel_data # Add flat channels back unchanged ) @@ -736,7 +764,7 @@ def _check_cleaning_quality( """Check if cleaning is too weak, too strong, or good.""" # Compute PSDs freqs, psd_clean = _compute_psd(cleaned_data, sfreq) - log_psd_clean = 10 * np.log10(np.mean(psd_clean, axis=0)) + log_psd_clean = 10 * np.log10(np.mean(psd_clean, axis=1)) # Compute fine 
thresholds win_mask = (freqs >= target_freq - winsize / 2) & (freqs <= target_freq + winsize / 2) @@ -786,7 +814,7 @@ def _compute_psd(data, sfreq, nperseg=None): fs=sfreq, window="hann", nperseg=nperseg, - axis=-1, + axis=0, ) return freqs, psd @@ -803,7 +831,7 @@ def _narrowband_filter(data, sfreq, center_freq, bandwidth=3.0): high = min(high, 0.999) sos = signal.butter(4, [low, high], btype="band", output="sos") - filtered = signal.sosfiltfilt(sos, data, axis=-1) + filtered = signal.sosfiltfilt(sos, data, axis=0) return filtered @@ -817,8 +845,8 @@ def _plot_cleaning_results(original, cleaned, sfreq, target_freq, analytics, fig freqs, psd_orig = _compute_psd(original, sfreq) _, psd_clean = _compute_psd(cleaned, sfreq) - log_psd_orig = 10 * np.log10(np.mean(psd_orig, axis=0)) - log_psd_clean = 10 * np.log10(np.mean(psd_clean, axis=0)) + log_psd_orig = 10 * np.log10(np.mean(psd_orig, axis=1)) + log_psd_clean = 10 * np.log10(np.mean(psd_clean, axis=1)) # 1. Zoomed spectrum around noise frequency ax1 = fig.add_subplot(gs[0, 0]) @@ -886,8 +914,8 @@ def _plot_cleaning_results(original, cleaned, sfreq, target_freq, analytics, fig # 7. Removed power (ratio) ax7 = fig.add_subplot(gs[1, 2]) noise_mask = (freqs >= target_freq - 0.05) & (freqs <= target_freq + 0.05) - ratio_orig = np.mean(psd_orig[:, noise_mask]) / np.mean(psd_orig) - ratio_clean = np.mean(psd_clean[:, noise_mask]) / np.mean(psd_clean) + ratio_orig = np.mean(psd_orig[noise_mask, :]) / np.mean(psd_orig) + ratio_clean = np.mean(psd_clean[noise_mask, :]) / np.mean(psd_clean) ax7.text( 0.5, @@ -933,7 +961,10 @@ def _plot_cleaning_results(original, cleaned, sfreq, target_freq, analytics, fig # Convenience function with simpler interface def remove_line_noise( - data: np.ndarray, sfreq: float, fline: optional[float] = None, **kwargs + data: np.ndarray, + sfreq: float, + fline: float | None = None, + **kwargs, ) -> np.ndarray: """Remove line noise from data using Zapline-plus. @@ -942,7 +973,7 @@ def remove_line_noise( Parameters ---------- - data : array, shape=(n_chans, n_times) + data : array, shape=(n_times, n_chans) Input data. sfreq : float Sampling frequency in Hz. @@ -953,7 +984,7 @@ def remove_line_noise( Returns ------- - clean_data : array, shape=(n_chans, n_times) + clean_data : array, shape=(n_times, n_chans) Cleaned data. Examples From 94f15e632f312e5ad9609316199793c273fd88c5 Mon Sep 17 00:00:00 2001 From: Roy Eric <139973278+Randomidous@users.noreply.github.com> Date: Sun, 25 Jan 2026 15:09:18 +0100 Subject: [PATCH 3/9] remove convenience interface; add 'optional' tag for inputs --- meegkit/dss_zapline.py | 106 ++++++++++++++--------------------------- 1 file changed, 36 insertions(+), 70 deletions(-) diff --git a/meegkit/dss_zapline.py b/meegkit/dss_zapline.py index 684a459..f806fd7 100644 --- a/meegkit/dss_zapline.py +++ b/meegkit/dss_zapline.py @@ -14,19 +14,14 @@ power line artifacts. NeuroImage, 207, 116356. -Differences from Matlab implementation: +Differences from Matlab implementation +-------------------------------------- Finding noise frequencies: - one iteration returning all frequencies -Adaptive chunking: -- merged chunks at edges if too short - Plotting: - only once per frequency after cleaning - - - """ import logging @@ -38,7 +33,7 @@ from meegkit.dss import dss_line -def zapline_plus( +def dss_line_plus( data: np.ndarray, sfreq: float, fline: float | list[float] | None = None, @@ -71,78 +66,78 @@ def zapline_plus( Parameters ---------- data : array, shape=(n_times, n_chans) - Input data. + Input data. 
Note that data is expected in time x channels format. sfreq : float Sampling frequency in Hz. fline : float | list of float | None Noise frequency or frequencies to remove. If None, frequencies are detected automatically. Defaults to None. - nkeep : int + nkeep : int | None Number of principal components to keep in DSS. If 0, no dimensionality reduction is applied. Defaults to 0. - adaptiveNremove : bool + adaptiveNremove : bool | None If True, automatically detect the number of components to remove. If False, use fixedNremove for all chunks. Defaults to True. - fixedNremove : int + fixedNremove : int | None Fixed number of components to remove per chunk. Used when adaptiveNremove=False, or as minimum when adaptiveNremove=True. Defaults to 1. - minfreq : float + minfreq : float | None Minimum frequency (Hz) to consider when detecting noise automatically. Defaults to 17.0. - maxfreq : float + maxfreq : float | None Maximum frequency (Hz) to consider when detecting noise automatically. Defaults to 99.0. - chunkLength : float + chunkLength : float | None Length of chunks (seconds) for cleaning. If 0, adaptive chunking based on noise covariance stability is used. Set to -1 via vanilla_mode to process the entire recording as a single chunk. Defaults to 0.0. - minChunkLength : float + minChunkLength : float | None Minimum chunk length (seconds) when using adaptive chunking. Defaults to 30.0. - noiseCompDetectSigma : float + noiseCompDetectSigma : float | None Initial SD threshold for iterative outlier detection of noise components. Defaults to 3.0. - adaptiveSigma : bool + adaptiveSigma : bool | None If True, automatically adapt noiseCompDetectSigma and fixedNremove based on cleaning results. Defaults to True. - minsigma : float + minsigma : float | None Minimum SD threshold when adapting noiseCompDetectSigma. Defaults to 2.5. - maxsigma : float + maxsigma : float | None Maximum SD threshold when adapting noiseCompDetectSigma. Defaults to 4.0. - detectionWinsize : float + detectionWinsize : float | None Window size (Hz) for noise frequency detection. Defaults to 6.0. - coarseFreqDetectPowerDiff : float + coarseFreqDetectPowerDiff : float | None Threshold (10*log10) above center power to detect a peak as noise. Defaults to 4.0. - coarseFreqDetectLowerPowerDiff : float + coarseFreqDetectLowerPowerDiff : float | None Threshold (10*log10) above center power to detect end of noise peak. Defaults to 1.76. - searchIndividualNoise : bool + searchIndividualNoise : bool | None If True, search for individual noise peaks in each chunk. Defaults to True. - freqDetectMultFine : float + freqDetectMultFine : float | None Multiplier for fine noise frequency detection threshold. Defaults to 2.0. - detailedFreqBoundsUpper : tuple of float + detailedFreqBoundsUpper : tuple of float | None Frequency boundaries (Hz) for fine threshold of too weak cleaning. Defaults to (0.05, 0.05). - detailedFreqBoundsLower : tuple of float + detailedFreqBoundsLower : tuple of float | None Frequency boundaries (Hz) for fine threshold of too strong cleaning. Defaults to (0.4, 0.1). - maxProportionAboveUpper : float + maxProportionAboveUpper : float | None Maximum proportion of samples above upper threshold before adapting. Defaults to 0.005. - maxProportionBelowLower : float + maxProportionBelowLower : float | None Maximum proportion of samples below lower threshold before adapting. Defaults to 0.005. - plotResults : bool + plotResults : bool | None If True, generate diagnostic plots for each cleaned frequency. Defaults to False. 
figsize : tuple of int Figure size for diagnostic plots. Defaults to (14, 10). - vanilla_mode : bool + vanilla_mode : bool | None If True, disable all Zapline-plus features and use vanilla Zapline behavior: - Process entire dataset as single chunk - Use fixed component removal (no adaptive detection) @@ -169,10 +164,10 @@ def zapline_plus( Examples -------- Remove 50 Hz line noise automatically: - >>> clean_data, config = zapline_plus(data, sfreq=500, fline=50) + >>> clean_data, config = dss_line_plus(data, sfreq=500, fline=50) Remove line noise with automatic frequency detection: - >>> clean_data, config = zapline_plus(data, sfreq=500) + >>> clean_data, config = dss_line_plus(data, sfreq=500) """ n_times, n_chans = data.shape @@ -672,6 +667,10 @@ def _detect_chunk_noise_frequency( def _detect_noise_components(data, sfreq, target_freq, sigma, nkeep): """Detect number of noise components to remove using outlier detection.""" + # Convert nkeep=0 to None for dss_line (0 means no reduction) + if nkeep == 0: + nkeep = None + # Apply DSS to get component scores _, scores = dss_line(data, target_freq, sfreq, nkeep=nkeep) @@ -703,6 +702,10 @@ def _apply_zapline_to_chunk(chunk_data, sfreq, chunk_freq, n_remove, nkeep): """Apply Zapline to a single chunk, handling flat channels.""" n_samples, n_chans = chunk_data.shape + # Convert nkeep=0 to None for dss_line (0 means no reduction) + if nkeep == 0: + nkeep = None + # Detect flat channels (zero variance) diff_chunk = np.diff(chunk_data, axis=0) flat_channels = np.where(np.all(diff_chunk == 0, axis=0))[0] @@ -957,40 +960,3 @@ def _plot_cleaning_results(original, cleaned, sfreq, target_freq, analytics, fig plt.show() return fig - - -# Convenience function with simpler interface -def remove_line_noise( - data: np.ndarray, - sfreq: float, - fline: float | None = None, - **kwargs, -) -> np.ndarray: - """Remove line noise from data using Zapline-plus. - - This is a simplified interface to zapline_plus() that returns only - the cleaned data. - - Parameters - ---------- - data : array, shape=(n_times, n_chans) - Input data. - sfreq : float - Sampling frequency in Hz. - fline : float | None - Line noise frequency. If None, automatically detected. - **kwargs - Additional arguments passed to zapline_plus(). - - Returns - ------- - clean_data : array, shape=(n_times, n_chans) - Cleaned data. - - Examples - -------- - >>> clean = remove_line_noise(data, sfreq=500, fline=50) - - """ - clean_data, _ = zapline_plus(data, sfreq, fline=fline, **kwargs) - return clean_data From 4b598d15bc2e7ac62869f7f87a9d28cb6d9bfe4a Mon Sep 17 00:00:00 2001 From: Roy Eric <139973278+Randomidous@users.noreply.github.com> Date: Sat, 24 Jan 2026 13:56:57 +0100 Subject: [PATCH 4/9] [ENH] add zapline plus --- meegkit/dss_zapline.py | 965 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 965 insertions(+) create mode 100644 meegkit/dss_zapline.py diff --git a/meegkit/dss_zapline.py b/meegkit/dss_zapline.py new file mode 100644 index 0000000..8545c62 --- /dev/null +++ b/meegkit/dss_zapline.py @@ -0,0 +1,965 @@ +"""Zapline-plus for automatic removal of frequency-specific noise artifacts. + +This module implements Zapline-plus, an extension of the Zapline algorithm +that enables fully automatic removal of line noise and other frequency-specific +artifacts from M/EEG data. + +Based on: +Klug, M., & Kloosterman, N. A. (2022). Zapline-plus: A Zapline extension for +automatic and adaptive removal of frequency-specific noise artifacts in M/EEG. +Human Brain Mapping, 43(9), 2743-2758. 
+ +Original Zapline by: +de Cheveigné, A. (2020). ZapLine: A simple and effective method to remove +power line artifacts. NeuroImage, 207, 116356. + + +Differences from Matlab implementation: + +Finding noise frequencies: + - one iteration returning all frequencies + +Adaptive chunking: + - merged chunks at edges if too short + + + + +""" + +import logging +from typing import dict, list, optional, tuple, union + +import matplotlib.pyplot as plt +import numpy as np +from scipy import signal + +from .dss import dss_line + + +def zapline_plus( + data: np.ndarray, + sfreq: float, + fline: optional[union[float, list[float]]] = None, + nkeep: int = 0, + adaptiveNremove: bool = True, + fixedNremove: int = 1, + minfreq: float = 17.0, + maxfreq: float = 99.0, + chunkLength: float = 0.0, + minChunkLength: float = 30.0, + noiseCompDetectSigma: float = 3.0, + adaptiveSigma: bool = True, + minsigma: float = 2.5, + maxsigma: float = 4.0, + detectionWinsize: float = 6.0, + coarseFreqDetectPowerDiff: float = 4.0, + coarseFreqDetectLowerPowerDiff: float = 1.76, + searchIndividualNoise: bool = True, + freqDetectMultFine: float = 2.0, + detailedFreqBoundsUpper: tuple[float, float] = (0.05, 0.05), + detailedFreqBoundsLower: tuple[float, float] = (0.4, 0.1), + maxProportionAboveUpper: float = 0.005, + maxProportionBelowLower: float = 0.005, + plotResults: bool = False, + figsize: tuple[int, int] = (14, 10), + vanilla_mode: bool = False, +) -> tuple[np.ndarray, dict]: + """Remove line noise and other frequency-specific artifacts using Zapline-plus. + + Parameters + ---------- + data : array, shape=(n_chans, n_times) + Input data. + sfreq : float + Sampling frequency in Hz. + fline : float | list of float | None + Noise frequency or frequencies to remove. If None, frequencies are + detected automatically. Defaults to None. + nkeep : int + Number of principal components to keep in DSS. If 0, no dimensionality + reduction is applied. Defaults to 0. + adaptiveNremove : bool + If True, automatically detect the number of components to remove. + If False, use fixedNremove for all chunks. Defaults to True. + fixedNremove : int + Fixed number of components to remove per chunk. Used when + adaptiveNremove=False, or as minimum when adaptiveNremove=True. + Defaults to 1. + minfreq : float + Minimum frequency (Hz) to consider when detecting noise automatically. + Defaults to 17.0. + maxfreq : float + Maximum frequency (Hz) to consider when detecting noise automatically. + Defaults to 99.0. + chunkLength : float + Length of chunks (seconds) for cleaning. If 0, adaptive chunking based + on noise covariance stability is used. Set to -1 via vanilla_mode to + process the entire recording as a single chunk. Defaults to 0.0. + minChunkLength : float + Minimum chunk length (seconds) when using adaptive chunking. + Defaults to 30.0. + noiseCompDetectSigma : float + Initial SD threshold for iterative outlier detection of noise components. + Defaults to 3.0. + adaptiveSigma : bool + If True, automatically adapt noiseCompDetectSigma and fixedNremove + based on cleaning results. Defaults to True. + minsigma : float + Minimum SD threshold when adapting noiseCompDetectSigma. + Defaults to 2.5. + maxsigma : float + Maximum SD threshold when adapting noiseCompDetectSigma. + Defaults to 4.0. + detectionWinsize : float + Window size (Hz) for noise frequency detection. Defaults to 6.0. + coarseFreqDetectPowerDiff : float + Threshold (10*log10) above center power to detect a peak as noise. + Defaults to 4.0. 
+ coarseFreqDetectLowerPowerDiff : float + Threshold (10*log10) above center power to detect end of noise peak. + Defaults to 1.76. + searchIndividualNoise : bool + If True, search for individual noise peaks in each chunk. + Defaults to True. + freqDetectMultFine : float + Multiplier for fine noise frequency detection threshold. Defaults to 2.0. + detailedFreqBoundsUpper : tuple of float + Frequency boundaries (Hz) for fine threshold of too weak cleaning. + Defaults to (0.05, 0.05). + detailedFreqBoundsLower : tuple of float + Frequency boundaries (Hz) for fine threshold of too strong cleaning. + Defaults to (0.4, 0.1). + maxProportionAboveUpper : float + Maximum proportion of samples above upper threshold before adapting. + Defaults to 0.005. + maxProportionBelowLower : float + Maximum proportion of samples below lower threshold before adapting. + Defaults to 0.005. + plotResults : bool + If True, generate diagnostic plots for each cleaned frequency. + Defaults to False. + figsize : tuple of int + Figure size for diagnostic plots. Defaults to (14, 10). + vanilla_mode : bool + If True, disable all Zapline-plus features and use vanilla Zapline behavior: + - Process entire dataset as single chunk + - Use fixed component removal (no adaptive detection) + - No individual chunk frequency detection + - No adaptive parameter tuning + Requires fline to be specified (not None). Defaults to False. + + Returns + ------- + clean_data : array, shape=(n_chans, n_times) + Cleaned data. + config : dict + Configuration dictionary containing all parameters and analytics. + + Notes + ----- + The algorithm proceeds as follows: + 1. Detect noise frequencies (if not provided) + 2. Segment data into chunks with stable noise topography + 3. Apply Zapline to each chunk + 4. Automatically detect and remove noise components + 5. Adapt parameters if cleaning is too weak or too strong + + Examples + -------- + Remove 50 Hz line noise automatically: + >>> clean_data, config = zapline_plus(data, sfreq=500, fline=50) + + Remove line noise with automatic frequency detection: + >>> clean_data, config = zapline_plus(data, sfreq=500) + + """ + n_chans, n_times = data.shape + + # Handle vanilla mode + if vanilla_mode: + logging.warning( + "vanilla_mode=True: Using vanilla Zapline behavior. " + "All adaptive features disabled." + ) + if fline is None: + raise ValueError("vanilla_mode requires fline to be specified (not None)") + + for param_name in [ + "adaptiveNremove", + "adaptiveSigma", + "searchIndividualNoise", + ]: + if locals()[param_name]: + logging.warning(f"vanilla_mode=True: Overriding {param_name} to False.") + + # Override all adaptive features + adaptiveNremove = False + adaptiveSigma = False + searchIndividualNoise = False + chunkLength = -1 # Zapline vanilla deals with single chunk + + # check for globally flat channels + diff_data = np.diff(data, axis=1) + global_flat = np.where(np.all(diff_data == 0, axis=1))[0] + + if len(global_flat) > 0: + logging.warning( + f"Detected {len(global_flat)} globally flat channels: {global_flat}. " + f"Removing for processing, will add back after." 
+ ) + flat_data = data[global_flat, :] + active_channels = np.setdiff1d(np.arange(n_chans), global_flat) + data = data[active_channels, :] + else: + active_channels = np.arange(n_chans) + flat_data = None + + # Initialize configuration + config = { + "sfreq": sfreq, + "fline": fline, + "nkeep": nkeep, + "adaptiveNremove": adaptiveNremove, + "fixedNremove": fixedNremove, + "minfreq": minfreq, + "maxfreq": maxfreq, + "chunkLength": chunkLength, + "minChunkLength": minChunkLength, + "noiseCompDetectSigma": noiseCompDetectSigma, + "adaptiveSigma": adaptiveSigma, + "minsigma": minsigma, + "maxsigma": maxsigma, + "detectionWinsize": detectionWinsize, + "coarseFreqDetectPowerDiff": coarseFreqDetectPowerDiff, + "coarseFreqDetectLowerPowerDiff": coarseFreqDetectLowerPowerDiff, + "searchIndividualNoise": searchIndividualNoise, + "freqDetectMultFine": freqDetectMultFine, + "detailedFreqBoundsUpper": detailedFreqBoundsUpper, + "detailedFreqBoundsLower": detailedFreqBoundsLower, + "maxProportionAboveUpper": maxProportionAboveUpper, + "maxProportionBelowLower": maxProportionBelowLower, + "analytics": {}, + } + + # Detect noise frequencies if not provided + if fline is None: + fline = _detect_noise_frequencies( + data, + sfreq, + minfreq, + maxfreq, + detectionWinsize, + coarseFreqDetectPowerDiff, + coarseFreqDetectLowerPowerDiff, + ) + elif not isinstance(fline, list): + fline = [fline] + + if len(fline) == 0: + logging.warning("No noise frequencies detected. Returning original data.") + return data.copy(), config + + config["detected_fline"] = fline + + # retain input data + clean_data = data.copy() + + # Process each noise frequency + for freq_idx, target_freq in enumerate(fline): + print(f"Processing noise frequency: {target_freq:.2f} Hz") + + # Adaptive chunking or fixed chunks + if chunkLength == -1: + # single chunk + chunks = [(0, n_times)] + elif chunkLength == 0: + chunks = _adaptive_chunking(clean_data, sfreq, target_freq, minChunkLength) + else: + chunk_samples = int(chunkLength * sfreq) + chunks = [ + (i, min(i + chunk_samples, n_times)) + for i in range(0, n_times, chunk_samples) + ] + + # Initialize tracking variables + current_sigma = noiseCompDetectSigma + current_fixed = fixedNremove + too_strong_once = False + iteration = 0 + max_iterations = 20 + + while iteration < max_iterations: + iteration += 1 + + # Clean each chunk + chunk_results = [] + for chunk_start, chunk_end in chunks: + chunk_data = clean_data[:, chunk_start:chunk_end] + + # Detect chunk-specific noise frequency + if searchIndividualNoise: + chunk_freq, has_noise = _detect_chunk_noise_frequency( + chunk_data, + sfreq, + target_freq, + detectionWinsize, + freqDetectMultFine, + ) + else: + chunk_freq = target_freq + has_noise = True + + # Apply Zapline to chunk + if has_noise: + if adaptiveNremove: + n_remove = _detect_noise_components( + chunk_data, sfreq, chunk_freq, current_sigma, nkeep + ) + n_remove = max(n_remove, current_fixed) + else: + n_remove = current_fixed + + # Cap at 1/5 of components + n_remove = min(n_remove, n_chans // 5) + else: + n_remove = current_fixed + + # clean chunk + cleaned_chunk = _apply_zapline_to_chunk( + chunk_data, sfreq, chunk_freq, n_remove, nkeep + ) + + chunk_results.append( + { + "start": chunk_start, + "end": chunk_end, + "freq": chunk_freq, + "n_remove": n_remove, + "has_noise": has_noise, + "data": cleaned_chunk, + } + ) + + # Reconstruct cleaned data + temp_clean = clean_data.copy() + for result in chunk_results: + temp_clean[:, result["start"] : result["end"]] = result["data"] + + 
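# A worked example of the quality check below, with illustrative
+            # numbers (not taken from real data): if the surrounding spectrum
+            # has a center of 10 dB, a lower-quantile deviation of 2 dB, and
+            # freqDetectMultFine=2, the "too weak" bound is 10 + 2 * 2 = 14 dB
+            # and the "too strong" bound is 10 - 2 * 2 = 6 dB; the result is
+            # "good" only if few enough samples near the target frequency
+            # cross either bound.
+            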
# Check if cleaning is optimal + cleaning_status = _check_cleaning_quality( + data, + temp_clean, + sfreq, + target_freq, + detectionWinsize, + freqDetectMultFine, + detailedFreqBoundsUpper, + detailedFreqBoundsLower, + maxProportionAboveUpper, + maxProportionBelowLower, + ) + + # Store analytics + config["analytics"][f"freq_{freq_idx}"] = { + "target_freq": target_freq, + "iteration": iteration, + "sigma": current_sigma, + "fixed_nremove": current_fixed, + "n_chunks": len(chunks), + "chunk_results": chunk_results, + "cleaning_status": cleaning_status, + } + + # Check if we need to adapt + if cleaning_status == "good": + clean_data = temp_clean + break + elif cleaning_status == "too_weak" and not too_strong_once: + current_sigma = max(current_sigma - 0.25, minsigma) + current_fixed += 1 + print( + f" Cleaning too weak. Adjusting sigma to {current_sigma:.2f}, " + f"fixed removal to {current_fixed}" + ) + elif cleaning_status == "too_strong": + too_strong_once = True + current_sigma = min(current_sigma + 0.25, maxsigma) + current_fixed = max(current_fixed - 1, fixedNremove) + print( + f" Cleaning too strong. Adjusting sigma to {current_sigma:.2f}, " + f"fixed removal to {current_fixed}" + ) + else: + # Too strong takes precedence, or we can't improve further + clean_data = temp_clean + break + + # Generate diagnostic plot + if plotResults: + _plot_cleaning_results( + data, + clean_data, + sfreq, + target_freq, + config["analytics"][f"freq_{freq_idx}"], + figsize, + ) + + # add flat channels back to data, if present + if flat_data is not None: + full_clean = np.zeros((n_chans, n_times)) + full_clean[active_channels, :] = clean_data + full_clean[global_flat, :] = flat_data + clean_data = full_clean + + return clean_data, config + + +def _detect_noise_frequencies( + data, sfreq, minfreq, maxfreq, winsize, power_diff_high, power_diff_low +): + """ + Detect noise frequencies. + + This is an exact implementation of find_next_noisefreq.m with the only difference + that all peaks are returned instead of this being called iteratively. + + How it works + ------------ + 1. Compute PSD and log-transform. + 2. Slide a window across frequencies from minfreq to maxfreq. + 3. For each frequency, compute center power as mean of left and right thirds. + 4. Use a state machine to detect peaks: + - SEARCHING: If current power - center power > power_diff_high, + mark peak start and switch to IN_PEAK. + - IN_PEAK: If current power - center power <= power_diff_low, + mark peak end, find max within peak, record frequency, + and switch to SEARCHING. + 5. Return list of detected noise frequencies. 
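+
+    A minimal sketch of expected usage on synthetic channels-x-times data
+    (shape convention and thresholds taken from this module's defaults):
+
+    >>> import numpy as np
+    >>> rng = np.random.default_rng(0)
+    >>> t = np.arange(60 * 500) / 500  # 60 s at 500 Hz
+    >>> x = rng.standard_normal((4, t.size)) + np.sin(2 * np.pi * 50 * t)
+    >>> peaks = _detect_noise_frequencies(x, 500, 17, 99, 6.0, 4.0, 1.76)
+    >>> # expect a value near 50.0 in peaks when the injected line is strong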
+ """ + # Compute PSD + freqs, psd = _compute_psd(data, sfreq) + log_psd = 10 * np.log10(np.mean(psd, axis=0)) + + # State machine variables + in_peak = False + peak_start_idx = None + noise_freqs = [] + + # Search bounds + start_idx = np.searchsorted(freqs, minfreq) + end_idx = np.searchsorted(freqs, maxfreq) + + # Window size in samples + freq_resolution = freqs[1] - freqs[0] + win_samples = int(winsize / freq_resolution) + + idx = start_idx + while idx < end_idx: + # Get window around current frequency + win_start = max(0, idx - win_samples // 2) + win_end = min(len(freqs), idx + win_samples // 2) + win_psd = log_psd[win_start:win_end] + + if len(win_psd) < 3: + idx += 1 + continue + + # Compute center power (mean of left and right thirds) + n_third = len(win_psd) // 3 + if n_third < 1: + idx += 1 + continue + + left_third = win_psd[:n_third] + right_third = win_psd[-n_third:] + center_power = np.mean([np.mean(left_third), np.mean(right_third)]) + + current_power = log_psd[idx] + + # State machine logic + if not in_peak: + # State: SEARCHING - Check for peak start + if current_power - center_power > power_diff_high: + in_peak = True + peak_start_idx = idx + + else: + # State: IN_PEAK - Check for peak end + if current_power - center_power <= power_diff_low: + in_peak = False + peak_end_idx = idx + + # Find the actual maximum within the peak + if peak_start_idx is not None and peak_end_idx > peak_start_idx: + peak_region = log_psd[peak_start_idx:peak_end_idx] + max_offset = np.argmax(peak_region) + max_idx = peak_start_idx + max_offset + noise_freqs.append(freqs[max_idx]) + + # Skip past this peak to avoid re-detection + idx = peak_end_idx + continue + + idx += 1 + + return noise_freqs + + +def _adaptive_chunking( + data, + sfreq, + target_freq, + min_chunk_length, + detection_winsize=6.0, + prominence_quantile=0.95, +): + """Segment data into chunks with stable noise topography.""" + n_chans, n_times = data.shape + + if n_times < sfreq * min_chunk_length: + logging.warning("Data too short for adaptive chunking. Using single chunk.") + return [(0, n_times)] + + n_chans, n_times = data.shape + + # Narrow-band filter around target frequency + bandwidth = detection_winsize / 2.0 + filtered = _narrowband_filter(data, sfreq, target_freq, bandwidth=bandwidth) + + # Compute covariance matrices for 1-second epochs + epoch_length = int(sfreq) + n_epochs = n_times // epoch_length + + distances = np.zeros(n_epochs) + prev_cov = None + + for i in range(n_epochs): + start = i * epoch_length + end = start + epoch_length + epoch = filtered[:, start:end] + cov = np.cov(epoch) + + if prev_cov is not None: + # Frobenius norm of difference + distances[i] = np.linalg.norm(cov - prev_cov, "fro") + # else: distance[i] already 0 from initialization + + prev_cov = cov + + if len(distances) < 2: + return [(0, n_times)] + + # find all peaks to get prominence distribution + peaks_all, properties_all = signal.find_peaks(distances, prominence=0) + + if len(peaks_all) == 0 or "prominences" not in properties_all: + # No peaks found + logging.warning("No peaks found in distance signal. 
Using single chunk.") + return [(0, n_times)] + + prominences = properties_all["prominences"] + + # filter by prominence quantile + min_prominence = np.quantile(prominences, prominence_quantile) + min_distance_epochs = int(min_chunk_length) # Convert seconds to epochs + + peaks, properties = signal.find_peaks( + distances, prominence=min_prominence, distance=min_distance_epochs + ) + + # cconvert peak locations (in epochs) to sample indices + chunk_starts = [0] + for peak in peaks: + chunk_start_sample = peak * epoch_length + chunk_starts.append(chunk_start_sample) + chunk_starts.append(n_times) + + # create chunk list + chunks = [] + for i in range(len(chunk_starts) - 1): + start = chunk_starts[i] + end = chunk_starts[i + 1] + chunks.append((start, end)) + + # ensure minimum chunk length at edges + min_chunk_samples = int(min_chunk_length * sfreq) + + if len(chunks) > 1: + # check first chunk + if chunks[0][1] - chunks[0][0] < min_chunk_samples: + # merge with next + chunks[1] = (chunks[0][0], chunks[1][1]) + chunks.pop(0) + + if len(chunks) > 1: + # check last chunk + if chunks[-1][1] - chunks[-1][0] < min_chunk_samples: + # merge with previous + chunks[-2] = (chunks[-2][0], chunks[-1][1]) + chunks.pop(-1) + + return chunks + + +def _detect_chunk_noise_frequency(data, sfreq, target_freq, winsize, mult_fine): + """Detect chunk-specific noise frequency around target.""" + freqs, psd = _compute_psd(data, sfreq) + log_psd = 10 * np.log10(np.mean(psd, axis=0)) + + # Search in ±0.05 Hz range + search_mask = (freqs >= target_freq - 0.05) & (freqs <= target_freq + 0.05) + if not np.any(search_mask): + return target_freq, False + + search_freqs = freqs[search_mask] + search_psd = log_psd[search_mask] + + # Find peak + peak_idx = np.argmax(search_psd) + peak_freq = search_freqs[peak_idx] + peak_power = search_psd[peak_idx] + + # Compute threshold + win_mask = (freqs >= target_freq - winsize / 2) & (freqs <= target_freq + winsize / 2) + win_psd = log_psd[win_mask] + + n_third = len(win_psd) // 3 + left_third = win_psd[:n_third] + right_third = win_psd[-n_third:] + center = np.mean([np.mean(left_third), np.mean(right_third)]) + + # Compute deviation (lower 5% quantiles) + lower_quant_left = np.percentile(left_third, 5) + lower_quant_right = np.percentile(right_third, 5) + deviation = center - np.mean([lower_quant_left, lower_quant_right]) + + threshold = center + mult_fine * deviation + + has_noise = peak_power > threshold + + return peak_freq, has_noise + + +def _detect_noise_components(data, sfreq, target_freq, sigma, nkeep): + """Detect number of noise components to remove using outlier detection.""" + # Apply DSS to get component scores + _, scores = dss_line(data, target_freq, sfreq, nkeep=nkeep) + + if scores is None or len(scores) == 0: + return 1 + + # Sort scores in descending order + sorted_scores = np.sort(scores)[::-1] + + # Iterative outlier detection + n_remove = 0 + remaining = sorted_scores.copy() + + while len(remaining) > 1: + mean_val = np.mean(remaining) + std_val = np.std(remaining) + threshold = mean_val + sigma * std_val + + if remaining[0] > threshold: + n_remove += 1 + remaining = remaining[1:] + else: + break + + return max(n_remove, 1) + + +def _apply_zapline_to_chunk(chunk_data, sfreq, chunk_freq, n_remove, nkeep): + """Apply Zapline to a single chunk, handling flat channels.""" + n_chans, n_samples = chunk_data.shape + + # Detect flat channels (zero variance) + diff_chunk = np.diff(chunk_data, axis=1) + flat_channels = np.where(np.all(diff_chunk == 0, axis=1))[0] + + if 
len(flat_channels) > 0: + logging.warning( + f"Detected {len(flat_channels)} flat channels in chunk: {flat_channels}. " + f"Removing temporarily for processing." + ) + + # store flat channel data + flat_channel_data = chunk_data[flat_channels, :] + + # remove flat channels from processing + active_channels = np.setdiff1d(np.arange(n_chans), flat_channels) + chunk_data_active = chunk_data[active_channels, :] + + # process only active channels + cleaned_active, _ = dss_line( + chunk_data_active, + fline=chunk_freq, + sfreq=sfreq, + nremove=n_remove, + nkeep=nkeep, + ) + + # Reconstruct full data with flat channels + cleaned_chunk = np.zeros_like(chunk_data) + cleaned_chunk[active_channels, :] = cleaned_active + cleaned_chunk[flat_channels, :] = ( + flat_channel_data # Add flat channels back unchanged + ) + + else: + # no flat channels, process normally + cleaned_chunk, _ = dss_line( + chunk_data, + fline=chunk_freq, + sfreq=sfreq, + nremove=n_remove, + nkeep=nkeep, + ) + + return cleaned_chunk + + +def _check_cleaning_quality( + original_data, + cleaned_data, + sfreq, + target_freq, + winsize, + mult_fine, + bounds_upper, + bounds_lower, + max_prop_above, + max_prop_below, +): + """Check if cleaning is too weak, too strong, or good.""" + # Compute PSDs + freqs, psd_clean = _compute_psd(cleaned_data, sfreq) + log_psd_clean = 10 * np.log10(np.mean(psd_clean, axis=0)) + + # Compute fine thresholds + win_mask = (freqs >= target_freq - winsize / 2) & (freqs <= target_freq + winsize / 2) + win_psd = log_psd_clean[win_mask] + + n_third = len(win_psd) // 3 + left_third = win_psd[:n_third] + right_third = win_psd[-n_third:] + center = np.mean([np.mean(left_third), np.mean(right_third)]) + + # Deviation from lower quantiles + lower_quant_left = np.percentile(left_third, 5) + lower_quant_right = np.percentile(right_third, 5) + deviation = center - np.mean([lower_quant_left, lower_quant_right]) + + # Upper threshold (too weak cleaning) + upper_mask = (freqs >= target_freq - bounds_upper[0]) & ( + freqs <= target_freq + bounds_upper[1] + ) + upper_threshold = center + mult_fine * deviation + upper_psd = log_psd_clean[upper_mask] + prop_above = np.mean(upper_psd > upper_threshold) + + # Lower threshold (too strong cleaning) + lower_mask = (freqs >= target_freq - bounds_lower[0]) & ( + freqs <= target_freq + bounds_lower[1] + ) + lower_threshold = center - mult_fine * deviation + lower_psd = log_psd_clean[lower_mask] + prop_below = np.mean(lower_psd < lower_threshold) + + if prop_below > max_prop_below: + return "too_strong" + elif prop_above > max_prop_above: + return "too_weak" + else: + return "good" + + +def _compute_psd(data, sfreq, nperseg=None): + """Compute power spectral density using Welch's method.""" + if nperseg is None: + nperseg = int(sfreq * 4) # 4-second windows + + freqs, psd = signal.welch( + data, + fs=sfreq, + window="hann", + nperseg=nperseg, + axis=-1, + ) + + return freqs, psd + + +def _narrowband_filter(data, sfreq, center_freq, bandwidth=3.0): + """Apply narrow-band filter around center frequency.""" + nyq = sfreq / 2 + low = (center_freq - bandwidth) / nyq + high = (center_freq + bandwidth) / nyq + + # Ensure valid frequency range + low = max(low, 0.001) + high = min(high, 0.999) + + sos = signal.butter(4, [low, high], btype="band", output="sos") + filtered = signal.sosfiltfilt(sos, data, axis=-1) + + return filtered + + +def _plot_cleaning_results(original, cleaned, sfreq, target_freq, analytics, figsize): + """Generate diagnostic plots for cleaning results.""" + fig = 
plt.figure(figsize=figsize) + gs = fig.add_gridspec(2, 4, hspace=0.3, wspace=0.3) + + # Compute PSDs + freqs, psd_orig = _compute_psd(original, sfreq) + _, psd_clean = _compute_psd(cleaned, sfreq) + + log_psd_orig = 10 * np.log10(np.mean(psd_orig, axis=0)) + log_psd_clean = 10 * np.log10(np.mean(psd_clean, axis=0)) + + # 1. Zoomed spectrum around noise frequency + ax1 = fig.add_subplot(gs[0, 0]) + zoom_mask = (freqs >= target_freq - 1.1) & (freqs <= target_freq + 1.1) + ax1.plot(freqs[zoom_mask], log_psd_orig[zoom_mask], "k-", label="Original") + ax1.set_xlabel("Frequency (Hz)") + ax1.set_ylabel("Power (dB)") + ax1.set_title(f"Detected frequency: {target_freq:.2f} Hz") + ax1.legend() + ax1.grid(True, alpha=0.3) + + # 2. Number of removed components per chunk + ax2 = fig.add_subplot(gs[0, 1]) + chunk_results = analytics["chunk_results"] + n_removes = [cr["n_remove"] for cr in chunk_results] + ax2.bar(range(len(n_removes)), n_removes) + ax2.set_xlabel("Chunk") + ax2.set_ylabel("# Components removed") + ax2.set_title(f"Removed components (mean={np.mean(n_removes):.1f})") + ax2.grid(True, alpha=0.3) + + # 3. Individual noise frequencies per chunk + ax3 = fig.add_subplot(gs[0, 2]) + chunk_freqs = [cr["freq"] for cr in chunk_results] + time_min = np.array([cr["start"] for cr in chunk_results]) / sfreq / 60 + ax3.plot(time_min, chunk_freqs, "o-") + ax3.set_xlabel("Time (minutes)") + ax3.set_ylabel("Frequency (Hz)") + ax3.set_title("Individual noise frequencies") + ax3.grid(True, alpha=0.3) + + # 4. Component scores (would need actual scores from DSS) + ax4 = fig.add_subplot(gs[0, 3]) + ax4.text( + 0.5, + 0.5, + "Component scores\n(requires DSS output)", + ha="center", + va="center", + transform=ax4.transAxes, + ) + ax4.set_title("Mean artifact scores") + + # 5. Cleaned spectrum (zoomed) + ax5 = fig.add_subplot(gs[1, 0]) + ax5.plot(freqs[zoom_mask], log_psd_clean[zoom_mask], "g-", label="Cleaned") + ax5.set_xlabel("Frequency (Hz)") + ax5.set_ylabel("Power (dB)") + ax5.set_title("Cleaned spectrum") + ax5.legend() + ax5.grid(True, alpha=0.3) + + # 6. Full spectrum + ax6 = fig.add_subplot(gs[1, 1]) + ax6.plot(freqs, log_psd_orig, "k-", alpha=0.5, label="Original") + ax6.plot(freqs, log_psd_clean, "g-", label="Cleaned") + ax6.axvline(target_freq, color="r", linestyle="--", alpha=0.5) + ax6.set_xlabel("Frequency (Hz)") + ax6.set_ylabel("Power (dB)") + ax6.set_title("Full power spectrum") + ax6.legend() + ax6.grid(True, alpha=0.3) + ax6.set_xlim([0, 100]) + + # 7. Removed power (ratio) + ax7 = fig.add_subplot(gs[1, 2]) + noise_mask = (freqs >= target_freq - 0.05) & (freqs <= target_freq + 0.05) + ratio_orig = np.mean(psd_orig[:, noise_mask]) / np.mean(psd_orig) + ratio_clean = np.mean(psd_clean[:, noise_mask]) / np.mean(psd_clean) + + ax7.text( + 0.5, + 0.6, + f"Original ratio: {ratio_orig:.2f}", + ha="center", + transform=ax7.transAxes, + ) + ax7.text( + 0.5, + 0.4, + f"Cleaned ratio: {ratio_clean:.2f}", + ha="center", + transform=ax7.transAxes, + ) + ax7.set_title("Noise/surroundings ratio") + ax7.axis("off") + + # 8. 
Below noise frequencies + ax8 = fig.add_subplot(gs[1, 3]) + below_mask = (freqs >= target_freq - 11) & (freqs <= target_freq - 1) + ax8.plot( + freqs[below_mask], log_psd_orig[below_mask], "k-", alpha=0.5, label="Original" + ) + ax8.plot(freqs[below_mask], log_psd_clean[below_mask], "g-", label="Cleaned") + ax8.set_xlabel("Frequency (Hz)") + ax8.set_ylabel("Power (dB)") + ax8.set_title("Power below noise frequency") + ax8.legend() + ax8.grid(True, alpha=0.3) + + plt.suptitle( + f"Zapline-plus cleaning results: {target_freq:.2f} Hz " + f"(iteration {analytics['iteration']})", + fontsize=14, + y=0.98, + ) + + plt.show() + + return fig + + +# Convenience function with simpler interface +def remove_line_noise( + data: np.ndarray, sfreq: float, fline: optional[float] = None, **kwargs +) -> np.ndarray: + """Remove line noise from data using Zapline-plus. + + This is a simplified interface to zapline_plus() that returns only + the cleaned data. + + Parameters + ---------- + data : array, shape=(n_chans, n_times) + Input data. + sfreq : float + Sampling frequency in Hz. + fline : float | None + Line noise frequency. If None, automatically detected. + **kwargs + Additional arguments passed to zapline_plus(). + + Returns + ------- + clean_data : array, shape=(n_chans, n_times) + Cleaned data. + + Examples + -------- + >>> clean = remove_line_noise(data, sfreq=500, fline=50) + + """ + clean_data, _ = zapline_plus(data, sfreq, fline=fline, **kwargs) + return clean_data From 71af49643a1890760ba9c41edfe0538b13067083 Mon Sep 17 00:00:00 2001 From: Roy Eric <139973278+Randomidous@users.noreply.github.com> Date: Sun, 25 Jan 2026 14:52:52 +0100 Subject: [PATCH 5/9] =?UTF-8?q?change=20(n=5Fchans,=20n=5Ftimes)=20?= =?UTF-8?q?=E2=86=92=20(n=5Ftimes,=20n=5Fchans)=20to=20adhere=20to=20codeb?= =?UTF-8?q?ase?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- meegkit/dss_zapline.py | 169 ++++++++++++++++++++++++----------------- 1 file changed, 100 insertions(+), 69 deletions(-) diff --git a/meegkit/dss_zapline.py b/meegkit/dss_zapline.py index 8545c62..684a459 100644 --- a/meegkit/dss_zapline.py +++ b/meegkit/dss_zapline.py @@ -17,30 +17,31 @@ Differences from Matlab implementation: Finding noise frequencies: - - one iteration returning all frequencies +- one iteration returning all frequencies Adaptive chunking: - - merged chunks at edges if too short +- merged chunks at edges if too short +Plotting: +- only once per frequency after cleaning """ import logging -from typing import dict, list, optional, tuple, union import matplotlib.pyplot as plt import numpy as np from scipy import signal -from .dss import dss_line +from meegkit.dss import dss_line def zapline_plus( data: np.ndarray, sfreq: float, - fline: optional[union[float, list[float]]] = None, + fline: float | list[float] | None = None, nkeep: int = 0, adaptiveNremove: bool = True, fixedNremove: int = 1, @@ -69,7 +70,7 @@ def zapline_plus( Parameters ---------- - data : array, shape=(n_chans, n_times) + data : array, shape=(n_times, n_chans) Input data. sfreq : float Sampling frequency in Hz. @@ -151,7 +152,7 @@ def zapline_plus( Returns ------- - clean_data : array, shape=(n_chans, n_times) + clean_data : array, shape=(n_times, n_chans) Cleaned data. config : dict Configuration dictionary containing all parameters and analytics. 
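
With the new (n_times, n_chans) orientation, callers holding channels-first
arrays need a transpose on the way in and out. A minimal usage sketch, assuming
a hypothetical MNE Raw object named raw (mne.io.Raw.get_data() returns data as
(n_chans, n_times)):

    x = raw.get_data().T                    # -> (n_times, n_chans)
    clean, config = zapline_plus(x, sfreq=raw.info["sfreq"], fline=50)
    clean_chans_first = clean.T             # back to (n_chans, n_times)
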
@@ -174,9 +175,9 @@ def zapline_plus( >>> clean_data, config = zapline_plus(data, sfreq=500) """ - n_chans, n_times = data.shape + n_times, n_chans = data.shape - # Handle vanilla mode + # Handle vanilla mode (ZapLine without plus) if vanilla_mode: logging.warning( "vanilla_mode=True: Using vanilla Zapline behavior. " @@ -199,23 +200,27 @@ def zapline_plus( searchIndividualNoise = False chunkLength = -1 # Zapline vanilla deals with single chunk - # check for globally flat channels - diff_data = np.diff(data, axis=1) - global_flat = np.where(np.all(diff_data == 0, axis=1))[0] + # if nothing is adaptive, only one iteration per frequency + if not (adaptiveNremove and adaptiveSigma): + max_iterations = 1 + # check for globally flat channels + # will be omitted during processing and reintroduced later + diff_data = np.diff(data, axis=0) + global_flat = np.where(np.all(diff_data == 0, axis=0))[0] if len(global_flat) > 0: logging.warning( f"Detected {len(global_flat)} globally flat channels: {global_flat}. " f"Removing for processing, will add back after." ) - flat_data = data[global_flat, :] + flat_data = data[:, global_flat] active_channels = np.setdiff1d(np.arange(n_chans), global_flat) - data = data[active_channels, :] + data = data[:, active_channels] else: active_channels = np.arange(n_chans) flat_data = None - # Initialize configuration + # initialize configuration config = { "sfreq": sfreq, "fline": fline, @@ -242,7 +247,7 @@ def zapline_plus( "analytics": {}, } - # Detect noise frequencies if not provided + # detect noise frequencies if not provided if fline is None: fline = _detect_noise_frequencies( data, @@ -257,7 +262,7 @@ def zapline_plus( fline = [fline] if len(fline) == 0: - logging.warning("No noise frequencies detected. Returning original data.") + logging.info("No noise frequencies detected. 
Returning original data.") return data.copy(), config config["detected_fline"] = fline @@ -269,20 +274,21 @@ def zapline_plus( for freq_idx, target_freq in enumerate(fline): print(f"Processing noise frequency: {target_freq:.2f} Hz") - # Adaptive chunking or fixed chunks if chunkLength == -1: # single chunk chunks = [(0, n_times)] elif chunkLength == 0: + # adaptive chunking chunks = _adaptive_chunking(clean_data, sfreq, target_freq, minChunkLength) else: + # fixed-length chunks chunk_samples = int(chunkLength * sfreq) chunks = [ (i, min(i + chunk_samples, n_times)) for i in range(0, n_times, chunk_samples) ] - # Initialize tracking variables + # initialize tracking variables current_sigma = noiseCompDetectSigma current_fixed = fixedNremove too_strong_once = False @@ -295,7 +301,7 @@ def zapline_plus( # Clean each chunk chunk_results = [] for chunk_start, chunk_end in chunks: - chunk_data = clean_data[:, chunk_start:chunk_end] + chunk_data = clean_data[chunk_start:chunk_end, :] # Detect chunk-specific noise frequency if searchIndividualNoise: @@ -305,6 +311,7 @@ def zapline_plus( target_freq, detectionWinsize, freqDetectMultFine, + detailed_freq_bounds=detailedFreqBoundsUpper, ) else: chunk_freq = target_freq @@ -341,12 +348,12 @@ def zapline_plus( } ) - # Reconstruct cleaned data + # reconstruct cleaned data temp_clean = clean_data.copy() for result in chunk_results: - temp_clean[:, result["start"] : result["end"]] = result["data"] + temp_clean[result["start"] : result["end"], :] = result["data"] - # Check if cleaning is optimal + # check if cleaning is optimal cleaning_status = _check_cleaning_quality( data, temp_clean, @@ -360,7 +367,7 @@ def zapline_plus( maxProportionBelowLower, ) - # Store analytics + # store analytics config["analytics"][f"freq_{freq_idx}"] = { "target_freq": target_freq, "iteration": iteration, @@ -371,25 +378,38 @@ def zapline_plus( "cleaning_status": cleaning_status, } - # Check if we need to adapt + # check if we need to adapt if cleaning_status == "good": clean_data = temp_clean break + elif cleaning_status == "too_weak" and not too_strong_once: - current_sigma = max(current_sigma - 0.25, minsigma) - current_fixed += 1 - print( - f" Cleaning too weak. Adjusting sigma to {current_sigma:.2f}, " - f"fixed removal to {current_fixed}" - ) + if current_sigma > minsigma: + current_sigma = max(current_sigma - 0.25, minsigma) + current_fixed += 1 + logging.info( + f"Cleaning too weak. Adjusting sigma to {current_sigma:.2f}, " + f"fixed removal to {current_fixed}" + ) + else: + logging.info("At minimum sigma, accepting result") + clean_data = temp_clean + break + elif cleaning_status == "too_strong": too_strong_once = True - current_sigma = min(current_sigma + 0.25, maxsigma) - current_fixed = max(current_fixed - 1, fixedNremove) - print( - f" Cleaning too strong. Adjusting sigma to {current_sigma:.2f}, " - f"fixed removal to {current_fixed}" - ) + if current_sigma < maxsigma: + current_sigma = min(current_sigma + 0.25, maxsigma) + current_fixed = max(current_fixed - 1, fixedNremove) + logging.info( + f"Cleaning too strong. 
Adjusting sigma to {current_sigma:.2f}, "
+                        f"fixed removal to {current_fixed}"
+                    )
+                else:
+                    logging.info("At maximum sigma, accepting result")
+                    clean_data = temp_clean
+                    break
+
             else:
                 # Too strong takes precedence, or we can't improve further
                 clean_data = temp_clean
                 break
@@ -408,9 +428,9 @@ def zapline_plus(
 
     # add flat channels back to data, if present
     if flat_data is not None:
-        full_clean = np.zeros((n_chans, n_times))
-        full_clean[active_channels, :] = clean_data
-        full_clean[global_flat, :] = flat_data
+        full_clean = np.zeros((n_times, n_chans))
+        full_clean[:, active_channels] = clean_data
+        full_clean[:, global_flat] = flat_data
         clean_data = full_clean
 
     return clean_data, config
@@ -440,7 +460,7 @@ def _detect_noise_frequencies(
     """
     # Compute PSD
     freqs, psd = _compute_psd(data, sfreq)
-    log_psd = 10 * np.log10(np.mean(psd, axis=0))
+    log_psd = 10 * np.log10(np.mean(psd, axis=1))
 
     # State machine variables
     in_peak = False
@@ -516,14 +536,12 @@ def _adaptive_chunking(
     prominence_quantile=0.95,
 ):
     """Segment data into chunks with stable noise topography."""
-    n_chans, n_times = data.shape
+    n_times, n_chans = data.shape
 
     if n_times < sfreq * min_chunk_length:
         logging.warning("Data too short for adaptive chunking. Using single chunk.")
         return [(0, n_times)]
 
-    n_chans, n_times = data.shape
-
     # Narrow-band filter around target frequency
     bandwidth = detection_winsize / 2.0
     filtered = _narrowband_filter(data, sfreq, target_freq, bandwidth=bandwidth)
@@ -538,8 +556,8 @@ def _adaptive_chunking(
     for i in range(n_epochs):
         start = i * epoch_length
         end = start + epoch_length
-        epoch = filtered[:, start:end]
-        cov = np.cov(epoch)
+        epoch = filtered[start:end, :]
+        cov = np.cov(epoch, rowvar=False)
 
         if prev_cov is not None:
             # Frobenius norm of difference
@@ -603,25 +621,35 @@ def _adaptive_chunking(
     return chunks
 
 
-def _detect_chunk_noise_frequency(data, sfreq, target_freq, winsize, mult_fine):
+def _detect_chunk_noise_frequency(
+    data,
+    sfreq,
+    target_freq,
+    winsize,
+    mult_fine,
+    detailed_freq_bounds=(-0.05, 0.05),
+):
     """Detect chunk-specific noise frequency around target."""
     freqs, psd = _compute_psd(data, sfreq)
-    log_psd = 10 * np.log10(np.mean(psd, axis=0))
+    log_psd = 10 * np.log10(np.mean(psd, axis=1))
+
+    # get frequency mask
+    search_mask = (freqs >= target_freq + detailed_freq_bounds[0]) & (
+        freqs <= target_freq + detailed_freq_bounds[1]
+    )
 
-    # Search in ±0.05 Hz range
-    search_mask = (freqs >= target_freq - 0.05) & (freqs <= target_freq + 0.05)
     if not np.any(search_mask):
         return target_freq, False
 
     search_freqs = freqs[search_mask]
     search_psd = log_psd[search_mask]
 
-    # Find peak
+    # find peak
     peak_idx = np.argmax(search_psd)
     peak_freq = search_freqs[peak_idx]
     peak_power = search_psd[peak_idx]
 
-    # Compute threshold
+    # Compute threshold (uses broader window)
     win_mask = (freqs >= target_freq - winsize / 2) & (freqs <= target_freq + winsize / 2)
     win_psd = log_psd[win_mask]
 
@@ -673,11 +701,11 @@ def _detect_noise_components(data, sfreq, target_freq, sigma, nkeep):
 
 def _apply_zapline_to_chunk(chunk_data, sfreq, chunk_freq, n_remove, nkeep):
     """Apply Zapline to a single chunk, handling flat channels."""
-    n_chans, n_samples = chunk_data.shape
+    n_samples, n_chans = chunk_data.shape
 
     # Detect flat channels (zero variance)
-    diff_chunk = np.diff(chunk_data, axis=1)
-    flat_channels = np.where(np.all(diff_chunk == 0, axis=1))[0]
+    diff_chunk = np.diff(chunk_data, axis=0)
+    flat_channels = np.where(np.all(diff_chunk == 0, axis=0))[0]
 
     if len(flat_channels) > 0: 
logging.warning( @@ -686,11 +714,11 @@ def _apply_zapline_to_chunk(chunk_data, sfreq, chunk_freq, n_remove, nkeep): ) # store flat channel data - flat_channel_data = chunk_data[flat_channels, :] + flat_channel_data = chunk_data[:, flat_channels] # remove flat channels from processing active_channels = np.setdiff1d(np.arange(n_chans), flat_channels) - chunk_data_active = chunk_data[active_channels, :] + chunk_data_active = chunk_data[:, active_channels] # process only active channels cleaned_active, _ = dss_line( @@ -703,8 +731,8 @@ def _apply_zapline_to_chunk(chunk_data, sfreq, chunk_freq, n_remove, nkeep): # Reconstruct full data with flat channels cleaned_chunk = np.zeros_like(chunk_data) - cleaned_chunk[active_channels, :] = cleaned_active - cleaned_chunk[flat_channels, :] = ( + cleaned_chunk[:, active_channels] = cleaned_active + cleaned_chunk[:, flat_channels] = ( flat_channel_data # Add flat channels back unchanged ) @@ -736,7 +764,7 @@ def _check_cleaning_quality( """Check if cleaning is too weak, too strong, or good.""" # Compute PSDs freqs, psd_clean = _compute_psd(cleaned_data, sfreq) - log_psd_clean = 10 * np.log10(np.mean(psd_clean, axis=0)) + log_psd_clean = 10 * np.log10(np.mean(psd_clean, axis=1)) # Compute fine thresholds win_mask = (freqs >= target_freq - winsize / 2) & (freqs <= target_freq + winsize / 2) @@ -786,7 +814,7 @@ def _compute_psd(data, sfreq, nperseg=None): fs=sfreq, window="hann", nperseg=nperseg, - axis=-1, + axis=0, ) return freqs, psd @@ -803,7 +831,7 @@ def _narrowband_filter(data, sfreq, center_freq, bandwidth=3.0): high = min(high, 0.999) sos = signal.butter(4, [low, high], btype="band", output="sos") - filtered = signal.sosfiltfilt(sos, data, axis=-1) + filtered = signal.sosfiltfilt(sos, data, axis=0) return filtered @@ -817,8 +845,8 @@ def _plot_cleaning_results(original, cleaned, sfreq, target_freq, analytics, fig freqs, psd_orig = _compute_psd(original, sfreq) _, psd_clean = _compute_psd(cleaned, sfreq) - log_psd_orig = 10 * np.log10(np.mean(psd_orig, axis=0)) - log_psd_clean = 10 * np.log10(np.mean(psd_clean, axis=0)) + log_psd_orig = 10 * np.log10(np.mean(psd_orig, axis=1)) + log_psd_clean = 10 * np.log10(np.mean(psd_clean, axis=1)) # 1. Zoomed spectrum around noise frequency ax1 = fig.add_subplot(gs[0, 0]) @@ -886,8 +914,8 @@ def _plot_cleaning_results(original, cleaned, sfreq, target_freq, analytics, fig # 7. Removed power (ratio) ax7 = fig.add_subplot(gs[1, 2]) noise_mask = (freqs >= target_freq - 0.05) & (freqs <= target_freq + 0.05) - ratio_orig = np.mean(psd_orig[:, noise_mask]) / np.mean(psd_orig) - ratio_clean = np.mean(psd_clean[:, noise_mask]) / np.mean(psd_clean) + ratio_orig = np.mean(psd_orig[noise_mask, :]) / np.mean(psd_orig) + ratio_clean = np.mean(psd_clean[noise_mask, :]) / np.mean(psd_clean) ax7.text( 0.5, @@ -933,7 +961,10 @@ def _plot_cleaning_results(original, cleaned, sfreq, target_freq, analytics, fig # Convenience function with simpler interface def remove_line_noise( - data: np.ndarray, sfreq: float, fline: optional[float] = None, **kwargs + data: np.ndarray, + sfreq: float, + fline: float | None = None, + **kwargs, ) -> np.ndarray: """Remove line noise from data using Zapline-plus. @@ -942,7 +973,7 @@ def remove_line_noise( Parameters ---------- - data : array, shape=(n_chans, n_times) + data : array, shape=(n_times, n_chans) Input data. sfreq : float Sampling frequency in Hz. 
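
The axis changes above hinge on scipy's convention: with axis=0, welch returns
the PSD as (n_freqs, n_chans), so channel averaging becomes axis=1 throughout.
A quick shape check as a sketch (synthetic array, parameters arbitrary):

    import numpy as np
    from scipy.signal import welch

    x = np.random.randn(5000, 8)         # (n_times, n_chans)
    freqs, psd = welch(x, fs=500, nperseg=2000, axis=0)
    assert psd.shape == (len(freqs), 8)  # frequencies x channels
    mean_psd = psd.mean(axis=1)          # average across channels
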
@@ -953,7 +984,7 @@ def remove_line_noise( Returns ------- - clean_data : array, shape=(n_chans, n_times) + clean_data : array, shape=(n_times, n_chans) Cleaned data. Examples From 1c67f184fa025bd3dd34d38a3db7c3b8d0808274 Mon Sep 17 00:00:00 2001 From: Roy Eric <139973278+Randomidous@users.noreply.github.com> Date: Sun, 25 Jan 2026 15:09:18 +0100 Subject: [PATCH 6/9] remove convenience interface; add 'optional' tag for inputs --- meegkit/dss_zapline.py | 106 ++++++++++++++--------------------------- 1 file changed, 36 insertions(+), 70 deletions(-) diff --git a/meegkit/dss_zapline.py b/meegkit/dss_zapline.py index 684a459..f806fd7 100644 --- a/meegkit/dss_zapline.py +++ b/meegkit/dss_zapline.py @@ -14,19 +14,14 @@ power line artifacts. NeuroImage, 207, 116356. -Differences from Matlab implementation: +Differences from Matlab implementation +-------------------------------------- Finding noise frequencies: - one iteration returning all frequencies -Adaptive chunking: -- merged chunks at edges if too short - Plotting: - only once per frequency after cleaning - - - """ import logging @@ -38,7 +33,7 @@ from meegkit.dss import dss_line -def zapline_plus( +def dss_line_plus( data: np.ndarray, sfreq: float, fline: float | list[float] | None = None, @@ -71,78 +66,78 @@ def zapline_plus( Parameters ---------- data : array, shape=(n_times, n_chans) - Input data. + Input data. Note that data is expected in time x channels format. sfreq : float Sampling frequency in Hz. fline : float | list of float | None Noise frequency or frequencies to remove. If None, frequencies are detected automatically. Defaults to None. - nkeep : int + nkeep : int | None Number of principal components to keep in DSS. If 0, no dimensionality reduction is applied. Defaults to 0. - adaptiveNremove : bool + adaptiveNremove : bool | None If True, automatically detect the number of components to remove. If False, use fixedNremove for all chunks. Defaults to True. - fixedNremove : int + fixedNremove : int | None Fixed number of components to remove per chunk. Used when adaptiveNremove=False, or as minimum when adaptiveNremove=True. Defaults to 1. - minfreq : float + minfreq : float | None Minimum frequency (Hz) to consider when detecting noise automatically. Defaults to 17.0. - maxfreq : float + maxfreq : float | None Maximum frequency (Hz) to consider when detecting noise automatically. Defaults to 99.0. - chunkLength : float + chunkLength : float | None Length of chunks (seconds) for cleaning. If 0, adaptive chunking based on noise covariance stability is used. Set to -1 via vanilla_mode to process the entire recording as a single chunk. Defaults to 0.0. - minChunkLength : float + minChunkLength : float | None Minimum chunk length (seconds) when using adaptive chunking. Defaults to 30.0. - noiseCompDetectSigma : float + noiseCompDetectSigma : float | None Initial SD threshold for iterative outlier detection of noise components. Defaults to 3.0. - adaptiveSigma : bool + adaptiveSigma : bool | None If True, automatically adapt noiseCompDetectSigma and fixedNremove based on cleaning results. Defaults to True. - minsigma : float + minsigma : float | None Minimum SD threshold when adapting noiseCompDetectSigma. Defaults to 2.5. - maxsigma : float + maxsigma : float | None Maximum SD threshold when adapting noiseCompDetectSigma. Defaults to 4.0. - detectionWinsize : float + detectionWinsize : float | None Window size (Hz) for noise frequency detection. Defaults to 6.0. 
- coarseFreqDetectPowerDiff : float + coarseFreqDetectPowerDiff : float | None Threshold (10*log10) above center power to detect a peak as noise. Defaults to 4.0. - coarseFreqDetectLowerPowerDiff : float + coarseFreqDetectLowerPowerDiff : float | None Threshold (10*log10) above center power to detect end of noise peak. Defaults to 1.76. - searchIndividualNoise : bool + searchIndividualNoise : bool | None If True, search for individual noise peaks in each chunk. Defaults to True. - freqDetectMultFine : float + freqDetectMultFine : float | None Multiplier for fine noise frequency detection threshold. Defaults to 2.0. - detailedFreqBoundsUpper : tuple of float + detailedFreqBoundsUpper : tuple of float | None Frequency boundaries (Hz) for fine threshold of too weak cleaning. Defaults to (0.05, 0.05). - detailedFreqBoundsLower : tuple of float + detailedFreqBoundsLower : tuple of float | None Frequency boundaries (Hz) for fine threshold of too strong cleaning. Defaults to (0.4, 0.1). - maxProportionAboveUpper : float + maxProportionAboveUpper : float | None Maximum proportion of samples above upper threshold before adapting. Defaults to 0.005. - maxProportionBelowLower : float + maxProportionBelowLower : float | None Maximum proportion of samples below lower threshold before adapting. Defaults to 0.005. - plotResults : bool + plotResults : bool | None If True, generate diagnostic plots for each cleaned frequency. Defaults to False. figsize : tuple of int Figure size for diagnostic plots. Defaults to (14, 10). - vanilla_mode : bool + vanilla_mode : bool | None If True, disable all Zapline-plus features and use vanilla Zapline behavior: - Process entire dataset as single chunk - Use fixed component removal (no adaptive detection) @@ -169,10 +164,10 @@ def zapline_plus( Examples -------- Remove 50 Hz line noise automatically: - >>> clean_data, config = zapline_plus(data, sfreq=500, fline=50) + >>> clean_data, config = dss_line_plus(data, sfreq=500, fline=50) Remove line noise with automatic frequency detection: - >>> clean_data, config = zapline_plus(data, sfreq=500) + >>> clean_data, config = dss_line_plus(data, sfreq=500) """ n_times, n_chans = data.shape @@ -672,6 +667,10 @@ def _detect_chunk_noise_frequency( def _detect_noise_components(data, sfreq, target_freq, sigma, nkeep): """Detect number of noise components to remove using outlier detection.""" + # Convert nkeep=0 to None for dss_line (0 means no reduction) + if nkeep == 0: + nkeep = None + # Apply DSS to get component scores _, scores = dss_line(data, target_freq, sfreq, nkeep=nkeep) @@ -703,6 +702,10 @@ def _apply_zapline_to_chunk(chunk_data, sfreq, chunk_freq, n_remove, nkeep): """Apply Zapline to a single chunk, handling flat channels.""" n_samples, n_chans = chunk_data.shape + # Convert nkeep=0 to None for dss_line (0 means no reduction) + if nkeep == 0: + nkeep = None + # Detect flat channels (zero variance) diff_chunk = np.diff(chunk_data, axis=0) flat_channels = np.where(np.all(diff_chunk == 0, axis=0))[0] @@ -957,40 +960,3 @@ def _plot_cleaning_results(original, cleaned, sfreq, target_freq, analytics, fig plt.show() return fig - - -# Convenience function with simpler interface -def remove_line_noise( - data: np.ndarray, - sfreq: float, - fline: float | None = None, - **kwargs, -) -> np.ndarray: - """Remove line noise from data using Zapline-plus. - - This is a simplified interface to zapline_plus() that returns only - the cleaned data. - - Parameters - ---------- - data : array, shape=(n_times, n_chans) - Input data. 
- sfreq : float - Sampling frequency in Hz. - fline : float | None - Line noise frequency. If None, automatically detected. - **kwargs - Additional arguments passed to zapline_plus(). - - Returns - ------- - clean_data : array, shape=(n_times, n_chans) - Cleaned data. - - Examples - -------- - >>> clean = remove_line_noise(data, sfreq=500, fline=50) - - """ - clean_data, _ = zapline_plus(data, sfreq, fline=fline, **kwargs) - return clean_data From 9c1e03fd33033e3b07af1b0ff3840ef56f0bb23e Mon Sep 17 00:00:00 2001 From: Roy Eric <139973278+Randomidous@users.noreply.github.com> Date: Sun, 25 Jan 2026 15:41:14 +0100 Subject: [PATCH 7/9] add plus to dss module --- meegkit/dss.py | 941 +++++++++++++++++++++++++++++++++++++++- meegkit/dss_zapline.py | 962 ----------------------------------------- 2 files changed, 936 insertions(+), 967 deletions(-) delete mode 100644 meegkit/dss_zapline.py diff --git a/meegkit/dss.py b/meegkit/dss.py index 7297b7e..7f93261 100644 --- a/meegkit/dss.py +++ b/meegkit/dss.py @@ -1,12 +1,14 @@ """Denoising source separation.""" # Authors: Nicolas Barascud # Maciej Szul +import logging from pathlib import Path +import matplotlib.pyplot as plt import numpy as np from numpy.lib.stride_tricks import sliding_window_view from scipy import linalg -from scipy.signal import welch +from scipy.signal import butter, find_peaks, sosfiltfilt, welch from .tspca import tsr from .utils import ( @@ -291,10 +293,10 @@ def dss_line_iter(data, fline, sfreq, win_sz=10, spot_sz=2.5, show: bool Produce a visual output of each iteration (default=False). dirname: str - Path to the directory where visual outputs are saved when show is 'True'. + Path to the directory where visual outputs are saved when show is 'True'. If 'None', does not save the outputs. (default=None) extension: str - Extension of the images filenames. Must be compatible with plt.savefig() + Extension of the images filenames. Must be compatible with plt.savefig() function. (default=".png") n_iter_max : int Maximum number of iterations (default=100). 
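
With this patch, the plus variant lives next to dss_line and dss_line_iter in
meegkit.dss. A hedged usage sketch (synthetic data; both calls use only
parameters shown in this file):

    import numpy as np
    from meegkit.dss import dss_line, dss_line_plus

    x = np.random.randn(30000, 16)  # 60 s at 500 Hz, (n_times, n_chans)
    # vanilla Zapline: fixed frequency and number of removed components
    clean_vanilla, _ = dss_line(x, fline=50, sfreq=500, nremove=1)
    # Zapline-plus: adaptive chunking, component count, and sigma tuning
    clean_plus, config = dss_line_plus(x, sfreq=500, fline=50)
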
@@ -317,7 +319,7 @@ def nan_basic_interp(array): freq_sp = [fline - spot_sz, fline + spot_sz] freq, psd = welch(data, fs=sfreq, nfft=nfft, axis=0) - freq_rn_ix = np.logical_and(freq >= freq_rn[0], + freq_rn_ix = np.logical_and(freq >= freq_rn[0], freq <= freq_rn[1]) freq_used = freq[freq_rn_ix] freq_sp_ix = np.logical_and(freq_used >= freq_sp[0], @@ -366,7 +368,7 @@ def nan_basic_interp(array): ax.flat[0].set_xlabel("Frequency (Hz)") ax.flat[0].set_ylabel("Power") - ax.flat[1].plot(freq_used, mean_psd_tf, c="gray", + ax.flat[1].plot(freq_used, mean_psd_tf, c="gray", label="Interpolated mean PSD") ax.flat[1].plot(freq_used, mean_psd, c="blue", label="Mean PSD") ax.flat[1].plot(freq_used, clean_fit_line, c="red", label="Fitted polynomial") @@ -405,3 +407,932 @@ def nan_basic_interp(array): "maximum number of iterations") return data, iterations + + +def dss_line_plus( + data: np.ndarray, + sfreq: float, + fline: float | list[float] | None = None, + nkeep: int = 0, + adaptiveNremove: bool = True, + fixedNremove: int = 1, + minfreq: float = 17.0, + maxfreq: float = 99.0, + chunkLength: float = 0.0, + minChunkLength: float = 30.0, + noiseCompDetectSigma: float = 3.0, + adaptiveSigma: bool = True, + minsigma: float = 2.5, + maxsigma: float = 4.0, + detectionWinsize: float = 6.0, + coarseFreqDetectPowerDiff: float = 4.0, + coarseFreqDetectLowerPowerDiff: float = 1.76, + searchIndividualNoise: bool = True, + freqDetectMultFine: float = 2.0, + detailedFreqBoundsUpper: tuple[float, float] = (0.05, 0.05), + detailedFreqBoundsLower: tuple[float, float] = (0.4, 0.1), + maxProportionAboveUpper: float = 0.005, + maxProportionBelowLower: float = 0.005, + plotResults: bool = False, + figsize: tuple[int, int] = (14, 10), + vanilla_mode: bool = False, +) -> tuple[np.ndarray, dict]: + """Remove line noise and other frequency-specific artifacts using Zapline-plus. + + Parameters + ---------- + data : array, shape=(n_times, n_chans) + Input data. Note that data is expected in time x channels format. + sfreq : float + Sampling frequency in Hz. + fline : float | list of float | None + Noise frequency or frequencies to remove. If None, frequencies are + detected automatically. Defaults to None. + nkeep : int | None + Number of principal components to keep in DSS. If 0, no dimensionality + reduction is applied. Defaults to 0. + adaptiveNremove : bool | None + If True, automatically detect the number of components to remove. + If False, use fixedNremove for all chunks. Defaults to True. + fixedNremove : int | None + Fixed number of components to remove per chunk. Used when + adaptiveNremove=False, or as minimum when adaptiveNremove=True. + Defaults to 1. + minfreq : float | None + Minimum frequency (Hz) to consider when detecting noise automatically. + Defaults to 17.0. + maxfreq : float | None + Maximum frequency (Hz) to consider when detecting noise automatically. + Defaults to 99.0. + chunkLength : float | None + Length of chunks (seconds) for cleaning. If 0, adaptive chunking based + on noise covariance stability is used. Set to -1 via vanilla_mode to + process the entire recording as a single chunk. Defaults to 0.0. + minChunkLength : float | None + Minimum chunk length (seconds) when using adaptive chunking. + Defaults to 30.0. + noiseCompDetectSigma : float | None + Initial SD threshold for iterative outlier detection of noise components. + Defaults to 3.0. + adaptiveSigma : bool | None + If True, automatically adapt noiseCompDetectSigma and fixedNremove + based on cleaning results. Defaults to True. 
+ minsigma : float | None + Minimum SD threshold when adapting noiseCompDetectSigma. + Defaults to 2.5. + maxsigma : float | None + Maximum SD threshold when adapting noiseCompDetectSigma. + Defaults to 4.0. + detectionWinsize : float | None + Window size (Hz) for noise frequency detection. Defaults to 6.0. + coarseFreqDetectPowerDiff : float | None + Threshold (10*log10) above center power to detect a peak as noise. + Defaults to 4.0. + coarseFreqDetectLowerPowerDiff : float | None + Threshold (10*log10) above center power to detect end of noise peak. + Defaults to 1.76. + searchIndividualNoise : bool | None + If True, search for individual noise peaks in each chunk. + Defaults to True. + freqDetectMultFine : float | None + Multiplier for fine noise frequency detection threshold. Defaults to 2.0. + detailedFreqBoundsUpper : tuple of float | None + Frequency boundaries (Hz) for fine threshold of too weak cleaning. + Defaults to (0.05, 0.05). + detailedFreqBoundsLower : tuple of float | None + Frequency boundaries (Hz) for fine threshold of too strong cleaning. + Defaults to (0.4, 0.1). + maxProportionAboveUpper : float | None + Maximum proportion of samples above upper threshold before adapting. + Defaults to 0.005. + maxProportionBelowLower : float | None + Maximum proportion of samples below lower threshold before adapting. + Defaults to 0.005. + plotResults : bool | None + If True, generate diagnostic plots for each cleaned frequency. + Defaults to False. + figsize : tuple of int + Figure size for diagnostic plots. Defaults to (14, 10). + vanilla_mode : bool | None + If True, disable all Zapline-plus features and use vanilla Zapline behavior: + - Process entire dataset as single chunk + - Use fixed component removal (no adaptive detection) + - No individual chunk frequency detection + - No adaptive parameter tuning + Requires fline to be specified (not None). Defaults to False. + + Returns + ------- + clean_data : array, shape=(n_times, n_chans) + Cleaned data. + config : dict + Configuration dictionary containing all parameters and analytics. + + Notes + ----- + The algorithm proceeds as follows: + 1. Detect noise frequencies (if not provided) + 2. Segment data into chunks with stable noise topography + 3. Apply Zapline to each chunk + 4. Automatically detect and remove noise components + 5. Adapt parameters if cleaning is too weak or too strong + + Examples + -------- + Remove 50 Hz line noise automatically: + >>> clean_data, config = dss_line_plus(data, sfreq=500, fline=50) + + Remove line noise with automatic frequency detection: + >>> clean_data, config = dss_line_plus(data, sfreq=500) + + """ + n_times, n_chans = data.shape + + # Handle vanilla mode (ZapLine without plus) + if vanilla_mode: + logging.warning( + "vanilla_mode=True: Using vanilla Zapline behavior. " + "All adaptive features disabled." 
+ ) + if fline is None: + raise ValueError("vanilla_mode requires fline to be specified (not None)") + + for param_name in [ + "adaptiveNremove", + "adaptiveSigma", + "searchIndividualNoise", + ]: + if locals()[param_name]: + logging.warning(f"vanilla_mode=True: Overriding {param_name} to False.") + + # Override all adaptive features + adaptiveNremove = False + adaptiveSigma = False + searchIndividualNoise = False + chunkLength = -1 # Zapline vanilla deals with single chunk + + # if nothing is adaptive, only one iteration per frequency + if not (adaptiveNremove and adaptiveSigma): + max_iterations = 1 + + # check for globally flat channels + # will be omitted during processing and reintroduced later + diff_data = np.diff(data, axis=0) + global_flat = np.where(np.all(diff_data == 0, axis=0))[0] + if len(global_flat) > 0: + logging.warning( + f"Detected {len(global_flat)} globally flat channels: {global_flat}. " + f"Removing for processing, will add back after." + ) + flat_data = data[:, global_flat] + active_channels = np.setdiff1d(np.arange(n_chans), global_flat) + data = data[:, active_channels] + else: + active_channels = np.arange(n_chans) + flat_data = None + + # initialize configuration + config = { + "sfreq": sfreq, + "fline": fline, + "nkeep": nkeep, + "adaptiveNremove": adaptiveNremove, + "fixedNremove": fixedNremove, + "minfreq": minfreq, + "maxfreq": maxfreq, + "chunkLength": chunkLength, + "minChunkLength": minChunkLength, + "noiseCompDetectSigma": noiseCompDetectSigma, + "adaptiveSigma": adaptiveSigma, + "minsigma": minsigma, + "maxsigma": maxsigma, + "detectionWinsize": detectionWinsize, + "coarseFreqDetectPowerDiff": coarseFreqDetectPowerDiff, + "coarseFreqDetectLowerPowerDiff": coarseFreqDetectLowerPowerDiff, + "searchIndividualNoise": searchIndividualNoise, + "freqDetectMultFine": freqDetectMultFine, + "detailedFreqBoundsUpper": detailedFreqBoundsUpper, + "detailedFreqBoundsLower": detailedFreqBoundsLower, + "maxProportionAboveUpper": maxProportionAboveUpper, + "maxProportionBelowLower": maxProportionBelowLower, + "analytics": {}, + } + + # detect noise frequencies if not provided + if fline is None: + fline = _detect_noise_frequencies( + data, + sfreq, + minfreq, + maxfreq, + detectionWinsize, + coarseFreqDetectPowerDiff, + coarseFreqDetectLowerPowerDiff, + ) + elif not isinstance(fline, list): + fline = [fline] + + if len(fline) == 0: + logging.info("No noise frequencies detected. 
Returning original data.")
+        return data.copy(), config
+
+    config["detected_fline"] = fline
+
+    # retain input data
+    clean_data = data.copy()
+
+    # Process each noise frequency
+    for freq_idx, target_freq in enumerate(fline):
+        print(f"Processing noise frequency: {target_freq:.2f} Hz")
+
+        if chunkLength == -1:
+            # single chunk
+            chunks = [(0, n_times)]
+        elif chunkLength == 0:
+            # adaptive chunking
+            chunks = _adaptive_chunking(clean_data, sfreq, target_freq, minChunkLength)
+        else:
+            # fixed-length chunks
+            chunk_samples = int(chunkLength * sfreq)
+            chunks = [
+                (i, min(i + chunk_samples, n_times))
+                for i in range(0, n_times, chunk_samples)
+            ]
+
+        # initialize tracking variables
+        current_sigma = noiseCompDetectSigma
+        current_fixed = fixedNremove
+        too_strong_once = False
+        iteration = 0
+        # honor the single-iteration mode selected above: without adaptive
+        # cleaning, repeated passes cannot change anything
+        max_iterations = 20 if (adaptiveNremove and adaptiveSigma) else 1
+
+        while iteration < max_iterations:
+            iteration += 1
+
+            # Clean each chunk
+            chunk_results = []
+            for chunk_start, chunk_end in chunks:
+                chunk_data = clean_data[chunk_start:chunk_end, :]
+
+                # Detect chunk-specific noise frequency
+                if searchIndividualNoise:
+                    chunk_freq, has_noise = _detect_chunk_noise_frequency(
+                        chunk_data,
+                        sfreq,
+                        target_freq,
+                        detectionWinsize,
+                        freqDetectMultFine,
+                        detailed_freq_bounds=detailedFreqBoundsUpper,
+                    )
+                else:
+                    chunk_freq = target_freq
+                    has_noise = True
+
+                # Apply Zapline to chunk
+                if has_noise:
+                    if adaptiveNremove:
+                        n_remove = _detect_noise_components(
+                            chunk_data, sfreq, chunk_freq, current_sigma, nkeep
+                        )
+                        n_remove = max(n_remove, current_fixed)
+                    else:
+                        n_remove = current_fixed
+
+                    # Cap at 1/5 of components
+                    n_remove = min(n_remove, n_chans // 5)
+                else:
+                    n_remove = current_fixed
+
+                # clean chunk
+                cleaned_chunk = _apply_zapline_to_chunk(
+                    chunk_data, sfreq, chunk_freq, n_remove, nkeep
+                )
+
+                chunk_results.append(
+                    {
+                        "start": chunk_start,
+                        "end": chunk_end,
+                        "freq": chunk_freq,
+                        "n_remove": n_remove,
+                        "has_noise": has_noise,
+                        "data": cleaned_chunk,
+                    }
+                )
+
+            # reconstruct cleaned data
+            temp_clean = clean_data.copy()
+            for result in chunk_results:
+                temp_clean[result["start"] : result["end"], :] = result["data"]
+
+            # check if cleaning is optimal
+            cleaning_status = _check_cleaning_quality(
+                data,
+                temp_clean,
+                sfreq,
+                target_freq,
+                detectionWinsize,
+                freqDetectMultFine,
+                detailedFreqBoundsUpper,
+                detailedFreqBoundsLower,
+                maxProportionAboveUpper,
+                maxProportionBelowLower,
+            )
+
+            # store analytics
+            config["analytics"][f"freq_{freq_idx}"] = {
+                "target_freq": target_freq,
+                "iteration": iteration,
+                "sigma": current_sigma,
+                "fixed_nremove": current_fixed,
+                "n_chunks": len(chunks),
+                "chunk_results": chunk_results,
+                "cleaning_status": cleaning_status,
+            }
+
+            # check if we need to adapt
+            if cleaning_status == "good":
+                clean_data = temp_clean
+                break
+
+            elif cleaning_status == "too_weak" and not too_strong_once:
+                if current_sigma > minsigma:
+                    current_sigma = max(current_sigma - 0.25, minsigma)
+                    current_fixed += 1
+                    logging.info(
+                        f"Cleaning too weak. Adjusting sigma to {current_sigma:.2f}, "
+                        f"fixed removal to {current_fixed}"
+                    )
+                else:
+                    logging.info("At minimum sigma, accepting result")
+                    clean_data = temp_clean
+                    break
+
+            elif cleaning_status == "too_strong":
+                too_strong_once = True
+                if current_sigma < maxsigma:
+                    current_sigma = min(current_sigma + 0.25, maxsigma)
+                    current_fixed = max(current_fixed - 1, fixedNremove)
+                    logging.info(
+                        f"Cleaning too strong. 
Adjusting sigma to {current_sigma:.2f}, " + f"fixed removal to {current_fixed}" + ) + else: + logging.info("At maximum sigma, accepting result") + clean_data = temp_clean + break + + else: + # Too strong takes precedence, or we can't improve further + clean_data = temp_clean + break + + # Generate diagnostic plot + if plotResults: + _plot_cleaning_results( + data, + clean_data, + sfreq, + target_freq, + config["analytics"][f"freq_{freq_idx}"], + figsize, + ) + + # add flat channels back to data, if present + if flat_data is not None: + full_clean = np.zeros((n_times, n_chans)) + full_clean[:, active_channels] = clean_data + full_clean[:, global_flat] = flat_data + clean_data = full_clean + + return clean_data, config + + +def _detect_noise_frequencies( + data, sfreq, minfreq, maxfreq, winsize, power_diff_high, power_diff_low +): + """ + Detect noise frequencies. + + This is an exact implementation of find_next_noisefreq.m with the only difference + that all peaks are returned instead of this being called iteratively. + + How it works + ------------ + 1. Compute PSD and log-transform. + 2. Slide a window across frequencies from minfreq to maxfreq. + 3. For each frequency, compute center power as mean of left and right thirds. + 4. Use a state machine to detect peaks: + - SEARCHING: If current power - center power > power_diff_high, + mark peak start and switch to IN_PEAK. + - IN_PEAK: If current power - center power <= power_diff_low, + mark peak end, find max within peak, record frequency, + and switch to SEARCHING. + 5. Return list of detected noise frequencies. + """ + # Compute PSD + freqs, psd = _compute_psd(data, sfreq) + log_psd = 10 * np.log10(np.mean(psd, axis=1)) + + # State machine variables + in_peak = False + peak_start_idx = None + noise_freqs = [] + + # Search bounds + start_idx = np.searchsorted(freqs, minfreq) + end_idx = np.searchsorted(freqs, maxfreq) + + # Window size in samples + freq_resolution = freqs[1] - freqs[0] + win_samples = int(winsize / freq_resolution) + + idx = start_idx + while idx < end_idx: + # Get window around current frequency + win_start = max(0, idx - win_samples // 2) + win_end = min(len(freqs), idx + win_samples // 2) + win_psd = log_psd[win_start:win_end] + + if len(win_psd) < 3: + idx += 1 + continue + + # Compute center power (mean of left and right thirds) + n_third = len(win_psd) // 3 + if n_third < 1: + idx += 1 + continue + + left_third = win_psd[:n_third] + right_third = win_psd[-n_third:] + center_power = np.mean([np.mean(left_third), np.mean(right_third)]) + + current_power = log_psd[idx] + + # State machine logic + if not in_peak: + # State: SEARCHING - Check for peak start + if current_power - center_power > power_diff_high: + in_peak = True + peak_start_idx = idx + + else: + # State: IN_PEAK - Check for peak end + if current_power - center_power <= power_diff_low: + in_peak = False + peak_end_idx = idx + + # Find the actual maximum within the peak + if peak_start_idx is not None and peak_end_idx > peak_start_idx: + peak_region = log_psd[peak_start_idx:peak_end_idx] + max_offset = np.argmax(peak_region) + max_idx = peak_start_idx + max_offset + noise_freqs.append(freqs[max_idx]) + + # Skip past this peak to avoid re-detection + idx = peak_end_idx + continue + + idx += 1 + + return noise_freqs + + +def _adaptive_chunking( + data, + sfreq, + target_freq, + min_chunk_length, + detection_winsize=6.0, + prominence_quantile=0.95, +): + """Segment data into chunks with stable noise topography.""" + n_times, n_chans = data.shape + + if n_times 
< sfreq * min_chunk_length: + logging.warning("Data too short for adaptive chunking. Using single chunk.") + return [(0, n_times)] + + # Narrow-band filter around target frequency + bandwidth = detection_winsize / 2.0 + filtered = _narrowband_filter(data, sfreq, target_freq, bandwidth=bandwidth) + + # Compute covariance matrices for 1-second epochs + epoch_length = int(sfreq) + n_epochs = n_times // epoch_length + + distances = np.zeros(n_epochs) + prev_cov = None + + for i in range(n_epochs): + start = i * epoch_length + end = start + epoch_length + epoch = filtered[start:end, :] + cov = np.cov(epoch, rowvar=False) + + if prev_cov is not None: + # Frobenius norm of difference + distances[i] = np.linalg.norm(cov - prev_cov, "fro") + # else: distance[i] already 0 from initialization + + prev_cov = cov + + if len(distances) < 2: + return [(0, n_times)] + + # find all peaks to get prominence distribution + peaks_all, properties_all = signal.find_peaks(distances, prominence=0) + + if len(peaks_all) == 0 or "prominences" not in properties_all: + # No peaks found + logging.warning("No peaks found in distance signal. Using single chunk.") + return [(0, n_times)] + + prominences = properties_all["prominences"] + + # filter by prominence quantile + min_prominence = np.quantile(prominences, prominence_quantile) + min_distance_epochs = int(min_chunk_length) # Convert seconds to epochs + + peaks, properties = signal.find_peaks( + distances, prominence=min_prominence, distance=min_distance_epochs + ) + + # convert peak locations (in epochs) to sample indices + chunk_starts = [0] + for peak in peaks: + chunk_start_sample = peak * epoch_length + chunk_starts.append(chunk_start_sample) + chunk_starts.append(n_times) + + # create chunk list + chunks = [] + for i in range(len(chunk_starts) - 1): + start = chunk_starts[i] + end = chunk_starts[i + 1] + chunks.append((start, end)) + + # ensure minimum chunk length at edges + min_chunk_samples = int(min_chunk_length * sfreq) + + if len(chunks) > 1: + # check first chunk + if chunks[0][1] - chunks[0][0] < min_chunk_samples: + # merge with next + chunks[1] = (chunks[0][0], chunks[1][1]) + chunks.pop(0) + + if len(chunks) > 1: + # check last chunk + if chunks[-1][1] - chunks[-1][0] < min_chunk_samples: + # merge with previous + chunks[-2] = (chunks[-2][0], chunks[-1][1]) + chunks.pop(-1) + + return chunks + + +def _detect_chunk_noise_frequency( + data, + sfreq, + target_freq, + winsize, + mult_fine, + detailed_freq_bounds=(0.05, 0.05), +): + """Detect chunk-specific noise frequency around target.""" + freqs, psd = _compute_psd(data, sfreq) + log_psd = 10 * np.log10(np.mean(psd, axis=1)) + + # frequency mask; bounds are (below, above) offsets around the target, + # matching the convention used in _check_cleaning_quality + search_mask = (freqs >= target_freq - detailed_freq_bounds[0]) & ( + freqs <= target_freq + detailed_freq_bounds[1] + ) + + if not np.any(search_mask): + return target_freq, False + + search_freqs = freqs[search_mask] + search_psd = log_psd[search_mask] + + # find peak + peak_idx = np.argmax(search_psd) + peak_freq = search_freqs[peak_idx] + peak_power = search_psd[peak_idx] + + # Compute threshold (uses broader window) + win_mask = (freqs >= target_freq - winsize / 2) & (freqs <= target_freq + winsize / 2) + win_psd = log_psd[win_mask] + + n_third = len(win_psd) // 3 + left_third = win_psd[:n_third] + right_third = win_psd[-n_third:] + center = np.mean([np.mean(left_third), np.mean(right_third)]) + + # Compute deviation (lower 5% quantiles) + lower_quant_left = np.percentile(left_third, 5) + lower_quant_right = np.percentile(right_third, 5) +
deviation = center - np.mean([lower_quant_left, lower_quant_right]) + + threshold = center + mult_fine * deviation + + has_noise = peak_power > threshold + + return peak_freq, has_noise + + +def _detect_noise_components(data, sfreq, target_freq, sigma, nkeep): + """Detect number of noise components to remove using outlier detection.""" + # Convert nkeep=0 to None for dss_line (0 means no reduction) + if nkeep == 0: + nkeep = None + + # Apply DSS to get component scores + _, scores = dss_line(data, target_freq, sfreq, nkeep=nkeep) + + if scores is None or len(scores) == 0: + return 1 + + # Sort scores in descending order + sorted_scores = np.sort(scores)[::-1] + + # Iterative outlier detection + n_remove = 0 + remaining = sorted_scores.copy() + + while len(remaining) > 1: + mean_val = np.mean(remaining) + std_val = np.std(remaining) + threshold = mean_val + sigma * std_val + + if remaining[0] > threshold: + n_remove += 1 + remaining = remaining[1:] + else: + break + + return max(n_remove, 1) + + +def _apply_zapline_to_chunk(chunk_data, sfreq, chunk_freq, n_remove, nkeep): + """Apply Zapline to a single chunk, handling flat channels.""" + n_samples, n_chans = chunk_data.shape + + # Convert nkeep=0 to None for dss_line (0 means no reduction) + if nkeep == 0: + nkeep = None + + # Detect flat channels (zero variance) + diff_chunk = np.diff(chunk_data, axis=0) + flat_channels = np.where(np.all(diff_chunk == 0, axis=0))[0] + + if len(flat_channels) > 0: + logging.warning( + f"Detected {len(flat_channels)} flat channels in chunk: {flat_channels}. " + f"Removing temporarily for processing." + ) + + # store flat channel data + flat_channel_data = chunk_data[:, flat_channels] + + # remove flat channels from processing + active_channels = np.setdiff1d(np.arange(n_chans), flat_channels) + chunk_data_active = chunk_data[:, active_channels] + + # process only active channels + cleaned_active, _ = dss_line( + chunk_data_active, + fline=chunk_freq, + sfreq=sfreq, + nremove=n_remove, + nkeep=nkeep, + ) + + # Reconstruct full data with flat channels + cleaned_chunk = np.zeros_like(chunk_data) + cleaned_chunk[:, active_channels] = cleaned_active + cleaned_chunk[:, flat_channels] = ( + flat_channel_data # Add flat channels back unchanged + ) + + else: + # no flat channels, process normally + cleaned_chunk, _ = dss_line( + chunk_data, + fline=chunk_freq, + sfreq=sfreq, + nremove=n_remove, + nkeep=nkeep, + ) + + return cleaned_chunk + + +def _check_cleaning_quality( + original_data, + cleaned_data, + sfreq, + target_freq, + winsize, + mult_fine, + bounds_upper, + bounds_lower, + max_prop_above, + max_prop_below, +): + """Check if cleaning is too weak, too strong, or good.""" + # Compute PSDs + freqs, psd_clean = _compute_psd(cleaned_data, sfreq) + log_psd_clean = 10 * np.log10(np.mean(psd_clean, axis=1)) + + # Compute fine thresholds + win_mask = (freqs >= target_freq - winsize / 2) & (freqs <= target_freq + winsize / 2) + win_psd = log_psd_clean[win_mask] + + n_third = len(win_psd) // 3 + left_third = win_psd[:n_third] + right_third = win_psd[-n_third:] + center = np.mean([np.mean(left_third), np.mean(right_third)]) + + # Deviation from lower quantiles + lower_quant_left = np.percentile(left_third, 5) + lower_quant_right = np.percentile(right_third, 5) + deviation = center - np.mean([lower_quant_left, lower_quant_right]) + + # Upper threshold (too weak cleaning) + upper_mask = (freqs >= target_freq - bounds_upper[0]) & ( + freqs <= target_freq + bounds_upper[1] + ) + upper_threshold = center + mult_fine * 
deviation + upper_psd = log_psd_clean[upper_mask] + prop_above = np.mean(upper_psd > upper_threshold) + + # Lower threshold (too strong cleaning) + lower_mask = (freqs >= target_freq - bounds_lower[0]) & ( + freqs <= target_freq + bounds_lower[1] + ) + lower_threshold = center - mult_fine * deviation + lower_psd = log_psd_clean[lower_mask] + prop_below = np.mean(lower_psd < lower_threshold) + + if prop_below > max_prop_below: + return "too_strong" + elif prop_above > max_prop_above: + return "too_weak" + else: + return "good" + + +def _compute_psd(data, sfreq, nperseg=None): + """Compute power spectral density using Welch's method.""" + if nperseg is None: + nperseg = int(sfreq * 4) # 4-second windows + + freqs, psd = signal.welch( + data, + fs=sfreq, + window="hann", + nperseg=nperseg, + axis=0, + ) + + return freqs, psd + + +def _narrowband_filter(data, sfreq, center_freq, bandwidth=3.0): + """Apply narrow-band filter around center frequency.""" + nyq = sfreq / 2 + low = (center_freq - bandwidth) / nyq + high = (center_freq + bandwidth) / nyq + + # Ensure valid frequency range + low = max(low, 0.001) + high = min(high, 0.999) + + sos = signal.butter(4, [low, high], btype="band", output="sos") + filtered = signal.sosfiltfilt(sos, data, axis=0) + + return filtered + + +def _plot_cleaning_results(original, cleaned, sfreq, target_freq, analytics, figsize): + """Generate diagnostic plots for cleaning results.""" + fig = plt.figure(figsize=figsize) + gs = fig.add_gridspec(2, 4, hspace=0.3, wspace=0.3) + + # Compute PSDs + freqs, psd_orig = _compute_psd(original, sfreq) + _, psd_clean = _compute_psd(cleaned, sfreq) + + log_psd_orig = 10 * np.log10(np.mean(psd_orig, axis=1)) + log_psd_clean = 10 * np.log10(np.mean(psd_clean, axis=1)) + + # 1. Zoomed spectrum around noise frequency + ax1 = fig.add_subplot(gs[0, 0]) + zoom_mask = (freqs >= target_freq - 1.1) & (freqs <= target_freq + 1.1) + ax1.plot(freqs[zoom_mask], log_psd_orig[zoom_mask], "k-", label="Original") + ax1.set_xlabel("Frequency (Hz)") + ax1.set_ylabel("Power (dB)") + ax1.set_title(f"Detected frequency: {target_freq:.2f} Hz") + ax1.legend() + ax1.grid(True, alpha=0.3) + + # 2. Number of removed components per chunk + ax2 = fig.add_subplot(gs[0, 1]) + chunk_results = analytics["chunk_results"] + n_removes = [cr["n_remove"] for cr in chunk_results] + ax2.bar(range(len(n_removes)), n_removes) + ax2.set_xlabel("Chunk") + ax2.set_ylabel("# Components removed") + ax2.set_title(f"Removed components (mean={np.mean(n_removes):.1f})") + ax2.grid(True, alpha=0.3) + + # 3. Individual noise frequencies per chunk + ax3 = fig.add_subplot(gs[0, 2]) + chunk_freqs = [cr["freq"] for cr in chunk_results] + time_min = np.array([cr["start"] for cr in chunk_results]) / sfreq / 60 + ax3.plot(time_min, chunk_freqs, "o-") + ax3.set_xlabel("Time (minutes)") + ax3.set_ylabel("Frequency (Hz)") + ax3.set_title("Individual noise frequencies") + ax3.grid(True, alpha=0.3) + + # 4. Component scores (would need actual scores from DSS) + ax4 = fig.add_subplot(gs[0, 3]) + ax4.text( + 0.5, + 0.5, + "Component scores\n(requires DSS output)", + ha="center", + va="center", + transform=ax4.transAxes, + ) + ax4.set_title("Mean artifact scores") + + # 5. Cleaned spectrum (zoomed) + ax5 = fig.add_subplot(gs[1, 0]) + ax5.plot(freqs[zoom_mask], log_psd_clean[zoom_mask], "g-", label="Cleaned") + ax5.set_xlabel("Frequency (Hz)") + ax5.set_ylabel("Power (dB)") + ax5.set_title("Cleaned spectrum") + ax5.legend() + ax5.grid(True, alpha=0.3) + + # 6.
Full spectrum + ax6 = fig.add_subplot(gs[1, 1]) + ax6.plot(freqs, log_psd_orig, "k-", alpha=0.5, label="Original") + ax6.plot(freqs, log_psd_clean, "g-", label="Cleaned") + ax6.axvline(target_freq, color="r", linestyle="--", alpha=0.5) + ax6.set_xlabel("Frequency (Hz)") + ax6.set_ylabel("Power (dB)") + ax6.set_title("Full power spectrum") + ax6.legend() + ax6.grid(True, alpha=0.3) + ax6.set_xlim([0, 100]) + + # 7. Removed power (ratio) + ax7 = fig.add_subplot(gs[1, 2]) + noise_mask = (freqs >= target_freq - 0.05) & (freqs <= target_freq + 0.05) + ratio_orig = np.mean(psd_orig[noise_mask, :]) / np.mean(psd_orig) + ratio_clean = np.mean(psd_clean[noise_mask, :]) / np.mean(psd_clean) + + ax7.text( + 0.5, + 0.6, + f"Original ratio: {ratio_orig:.2f}", + ha="center", + transform=ax7.transAxes, + ) + ax7.text( + 0.5, + 0.4, + f"Cleaned ratio: {ratio_clean:.2f}", + ha="center", + transform=ax7.transAxes, + ) + ax7.set_title("Noise/surroundings ratio") + ax7.axis("off") + + # 8. Below noise frequencies + ax8 = fig.add_subplot(gs[1, 3]) + below_mask = (freqs >= target_freq - 11) & (freqs <= target_freq - 1) + ax8.plot( + freqs[below_mask], log_psd_orig[below_mask], "k-", alpha=0.5, label="Original" + ) + ax8.plot(freqs[below_mask], log_psd_clean[below_mask], "g-", label="Cleaned") + ax8.set_xlabel("Frequency (Hz)") + ax8.set_ylabel("Power (dB)") + ax8.set_title("Power below noise frequency") + ax8.legend() + ax8.grid(True, alpha=0.3) + + plt.suptitle( + f"Zapline-plus cleaning results: {target_freq:.2f} Hz " + f"(iteration {analytics['iteration']})", + fontsize=14, + y=0.98, + ) + + plt.show() + + return fig diff --git a/meegkit/dss_zapline.py b/meegkit/dss_zapline.py deleted file mode 100644 index f806fd7..0000000 --- a/meegkit/dss_zapline.py +++ /dev/null @@ -1,962 +0,0 @@ -"""Zapline-plus for automatic removal of frequency-specific noise artifacts. - -This module implements Zapline-plus, an extension of the Zapline algorithm -that enables fully automatic removal of line noise and other frequency-specific -artifacts from M/EEG data. - -Based on: -Klug, M., & Kloosterman, N. A. (2022). Zapline-plus: A Zapline extension for -automatic and adaptive removal of frequency-specific noise artifacts in M/EEG. -Human Brain Mapping, 43(9), 2743-2758. - -Original Zapline by: -de Cheveigné, A. (2020). ZapLine: A simple and effective method to remove -power line artifacts. NeuroImage, 207, 116356. 
- - -Differences from Matlab implementation --------------------------------------- - -Finding noise frequencies: -- one iteration returning all frequencies - -Plotting: -- only once per frequency after cleaning -""" - -import logging - -import matplotlib.pyplot as plt -import numpy as np -from scipy import signal - -from meegkit.dss import dss_line - - -def dss_line_plus( - data: np.ndarray, - sfreq: float, - fline: float | list[float] | None = None, - nkeep: int = 0, - adaptiveNremove: bool = True, - fixedNremove: int = 1, - minfreq: float = 17.0, - maxfreq: float = 99.0, - chunkLength: float = 0.0, - minChunkLength: float = 30.0, - noiseCompDetectSigma: float = 3.0, - adaptiveSigma: bool = True, - minsigma: float = 2.5, - maxsigma: float = 4.0, - detectionWinsize: float = 6.0, - coarseFreqDetectPowerDiff: float = 4.0, - coarseFreqDetectLowerPowerDiff: float = 1.76, - searchIndividualNoise: bool = True, - freqDetectMultFine: float = 2.0, - detailedFreqBoundsUpper: tuple[float, float] = (0.05, 0.05), - detailedFreqBoundsLower: tuple[float, float] = (0.4, 0.1), - maxProportionAboveUpper: float = 0.005, - maxProportionBelowLower: float = 0.005, - plotResults: bool = False, - figsize: tuple[int, int] = (14, 10), - vanilla_mode: bool = False, -) -> tuple[np.ndarray, dict]: - """Remove line noise and other frequency-specific artifacts using Zapline-plus. - - Parameters - ---------- - data : array, shape=(n_times, n_chans) - Input data. Note that data is expected in time x channels format. - sfreq : float - Sampling frequency in Hz. - fline : float | list of float | None - Noise frequency or frequencies to remove. If None, frequencies are - detected automatically. Defaults to None. - nkeep : int | None - Number of principal components to keep in DSS. If 0, no dimensionality - reduction is applied. Defaults to 0. - adaptiveNremove : bool | None - If True, automatically detect the number of components to remove. - If False, use fixedNremove for all chunks. Defaults to True. - fixedNremove : int | None - Fixed number of components to remove per chunk. Used when - adaptiveNremove=False, or as minimum when adaptiveNremove=True. - Defaults to 1. - minfreq : float | None - Minimum frequency (Hz) to consider when detecting noise automatically. - Defaults to 17.0. - maxfreq : float | None - Maximum frequency (Hz) to consider when detecting noise automatically. - Defaults to 99.0. - chunkLength : float | None - Length of chunks (seconds) for cleaning. If 0, adaptive chunking based - on noise covariance stability is used. Set to -1 via vanilla_mode to - process the entire recording as a single chunk. Defaults to 0.0. - minChunkLength : float | None - Minimum chunk length (seconds) when using adaptive chunking. - Defaults to 30.0. - noiseCompDetectSigma : float | None - Initial SD threshold for iterative outlier detection of noise components. - Defaults to 3.0. - adaptiveSigma : bool | None - If True, automatically adapt noiseCompDetectSigma and fixedNremove - based on cleaning results. Defaults to True. - minsigma : float | None - Minimum SD threshold when adapting noiseCompDetectSigma. - Defaults to 2.5. - maxsigma : float | None - Maximum SD threshold when adapting noiseCompDetectSigma. - Defaults to 4.0. - detectionWinsize : float | None - Window size (Hz) for noise frequency detection. Defaults to 6.0. - coarseFreqDetectPowerDiff : float | None - Threshold (10*log10) above center power to detect a peak as noise. - Defaults to 4.0. 
- coarseFreqDetectLowerPowerDiff : float | None - Threshold (10*log10) above center power to detect end of noise peak. - Defaults to 1.76. - searchIndividualNoise : bool | None - If True, search for individual noise peaks in each chunk. - Defaults to True. - freqDetectMultFine : float | None - Multiplier for fine noise frequency detection threshold. Defaults to 2.0. - detailedFreqBoundsUpper : tuple of float | None - Frequency boundaries (Hz) for fine threshold of too weak cleaning. - Defaults to (0.05, 0.05). - detailedFreqBoundsLower : tuple of float | None - Frequency boundaries (Hz) for fine threshold of too strong cleaning. - Defaults to (0.4, 0.1). - maxProportionAboveUpper : float | None - Maximum proportion of samples above upper threshold before adapting. - Defaults to 0.005. - maxProportionBelowLower : float | None - Maximum proportion of samples below lower threshold before adapting. - Defaults to 0.005. - plotResults : bool | None - If True, generate diagnostic plots for each cleaned frequency. - Defaults to False. - figsize : tuple of int - Figure size for diagnostic plots. Defaults to (14, 10). - vanilla_mode : bool | None - If True, disable all Zapline-plus features and use vanilla Zapline behavior: - - Process entire dataset as single chunk - - Use fixed component removal (no adaptive detection) - - No individual chunk frequency detection - - No adaptive parameter tuning - Requires fline to be specified (not None). Defaults to False. - - Returns - ------- - clean_data : array, shape=(n_times, n_chans) - Cleaned data. - config : dict - Configuration dictionary containing all parameters and analytics. - - Notes - ----- - The algorithm proceeds as follows: - 1. Detect noise frequencies (if not provided) - 2. Segment data into chunks with stable noise topography - 3. Apply Zapline to each chunk - 4. Automatically detect and remove noise components - 5. Adapt parameters if cleaning is too weak or too strong - - Examples - -------- - Remove 50 Hz line noise automatically: - >>> clean_data, config = dss_line_plus(data, sfreq=500, fline=50) - - Remove line noise with automatic frequency detection: - >>> clean_data, config = dss_line_plus(data, sfreq=500) - - """ - n_times, n_chans = data.shape - - # Handle vanilla mode (ZapLine without plus) - if vanilla_mode: - logging.warning( - "vanilla_mode=True: Using vanilla Zapline behavior. " - "All adaptive features disabled." - ) - if fline is None: - raise ValueError("vanilla_mode requires fline to be specified (not None)") - - for param_name in [ - "adaptiveNremove", - "adaptiveSigma", - "searchIndividualNoise", - ]: - if locals()[param_name]: - logging.warning(f"vanilla_mode=True: Overriding {param_name} to False.") - - # Override all adaptive features - adaptiveNremove = False - adaptiveSigma = False - searchIndividualNoise = False - chunkLength = -1 # Zapline vanilla deals with single chunk - - # if nothing is adaptive, only one iteration per frequency - if not (adaptiveNremove and adaptiveSigma): - max_iterations = 1 - - # check for globally flat channels - # will be omitted during processing and reintroduced later - diff_data = np.diff(data, axis=0) - global_flat = np.where(np.all(diff_data == 0, axis=0))[0] - if len(global_flat) > 0: - logging.warning( - f"Detected {len(global_flat)} globally flat channels: {global_flat}. " - f"Removing for processing, will add back after." 
- ) - flat_data = data[:, global_flat] - active_channels = np.setdiff1d(np.arange(n_chans), global_flat) - data = data[:, active_channels] - else: - active_channels = np.arange(n_chans) - flat_data = None - - # initialize configuration - config = { - "sfreq": sfreq, - "fline": fline, - "nkeep": nkeep, - "adaptiveNremove": adaptiveNremove, - "fixedNremove": fixedNremove, - "minfreq": minfreq, - "maxfreq": maxfreq, - "chunkLength": chunkLength, - "minChunkLength": minChunkLength, - "noiseCompDetectSigma": noiseCompDetectSigma, - "adaptiveSigma": adaptiveSigma, - "minsigma": minsigma, - "maxsigma": maxsigma, - "detectionWinsize": detectionWinsize, - "coarseFreqDetectPowerDiff": coarseFreqDetectPowerDiff, - "coarseFreqDetectLowerPowerDiff": coarseFreqDetectLowerPowerDiff, - "searchIndividualNoise": searchIndividualNoise, - "freqDetectMultFine": freqDetectMultFine, - "detailedFreqBoundsUpper": detailedFreqBoundsUpper, - "detailedFreqBoundsLower": detailedFreqBoundsLower, - "maxProportionAboveUpper": maxProportionAboveUpper, - "maxProportionBelowLower": maxProportionBelowLower, - "analytics": {}, - } - - # detect noise frequencies if not provided - if fline is None: - fline = _detect_noise_frequencies( - data, - sfreq, - minfreq, - maxfreq, - detectionWinsize, - coarseFreqDetectPowerDiff, - coarseFreqDetectLowerPowerDiff, - ) - elif not isinstance(fline, list): - fline = [fline] - - if len(fline) == 0: - logging.info("No noise frequencies detected. Returning original data.") - return data.copy(), config - - config["detected_fline"] = fline - - # retain input data - clean_data = data.copy() - - # Process each noise frequency - for freq_idx, target_freq in enumerate(fline): - print(f"Processing noise frequency: {target_freq:.2f} Hz") - - if chunkLength == -1: - # single chunk - chunks = [(0, n_times)] - elif chunkLength == 0: - # adaptive chunking - chunks = _adaptive_chunking(clean_data, sfreq, target_freq, minChunkLength) - else: - # fixed-length chunks - chunk_samples = int(chunkLength * sfreq) - chunks = [ - (i, min(i + chunk_samples, n_times)) - for i in range(0, n_times, chunk_samples) - ] - - # initialize tracking variables - current_sigma = noiseCompDetectSigma - current_fixed = fixedNremove - too_strong_once = False - iteration = 0 - max_iterations = 20 - - while iteration < max_iterations: - iteration += 1 - - # Clean each chunk - chunk_results = [] - for chunk_start, chunk_end in chunks: - chunk_data = clean_data[chunk_start:chunk_end, :] - - # Detect chunk-specific noise frequency - if searchIndividualNoise: - chunk_freq, has_noise = _detect_chunk_noise_frequency( - chunk_data, - sfreq, - target_freq, - detectionWinsize, - freqDetectMultFine, - detailed_freq_bounds=detailedFreqBoundsUpper, - ) - else: - chunk_freq = target_freq - has_noise = True - - # Apply Zapline to chunk - if has_noise: - if adaptiveNremove: - n_remove = _detect_noise_components( - chunk_data, sfreq, chunk_freq, current_sigma, nkeep - ) - n_remove = max(n_remove, current_fixed) - else: - n_remove = current_fixed - - # Cap at 1/5 of components - n_remove = min(n_remove, n_chans // 5) - else: - n_remove = current_fixed - - # clean chunk - cleaned_chunk = _apply_zapline_to_chunk( - chunk_data, sfreq, chunk_freq, n_remove, nkeep - ) - - chunk_results.append( - { - "start": chunk_start, - "end": chunk_end, - "freq": chunk_freq, - "n_remove": n_remove, - "has_noise": has_noise, - "data": cleaned_chunk, - } - ) - - # reconstruct cleaned data - temp_clean = clean_data.copy() - for result in chunk_results: - 
temp_clean[result["start"] : result["end"], :] = result["data"] - - # check if cleaning is optimal - cleaning_status = _check_cleaning_quality( - data, - temp_clean, - sfreq, - target_freq, - detectionWinsize, - freqDetectMultFine, - detailedFreqBoundsUpper, - detailedFreqBoundsLower, - maxProportionAboveUpper, - maxProportionBelowLower, - ) - - # store analytics - config["analytics"][f"freq_{freq_idx}"] = { - "target_freq": target_freq, - "iteration": iteration, - "sigma": current_sigma, - "fixed_nremove": current_fixed, - "n_chunks": len(chunks), - "chunk_results": chunk_results, - "cleaning_status": cleaning_status, - } - - # check if we need to adapt - if cleaning_status == "good": - clean_data = temp_clean - break - - elif cleaning_status == "too_weak" and not too_strong_once: - if current_sigma > minsigma: - current_sigma = max(current_sigma - 0.25, minsigma) - current_fixed += 1 - logging.info( - f"Cleaning too weak. Adjusting sigma to {current_sigma:.2f}, " - f"fixed removal to {current_fixed}" - ) - else: - logging.info("At minimum sigma, accepting result") - clean_data = temp_clean - break - - elif cleaning_status == "too_strong": - too_strong_once = True - if current_sigma < maxsigma: - current_sigma = min(current_sigma + 0.25, maxsigma) - current_fixed = max(current_fixed - 1, fixedNremove) - logging.info( - f"Cleaning too strong. Adjusting sigma to {current_sigma:.2f}, " - f"fixed removal to {current_fixed}" - ) - else: - logging.info("At maximum sigma, accepting result") - clean_data = temp_clean - break - - else: - # Too strong takes precedence, or we can't improve further - clean_data = temp_clean - break - - # Generate diagnostic plot - if plotResults: - _plot_cleaning_results( - data, - clean_data, - sfreq, - target_freq, - config["analytics"][f"freq_{freq_idx}"], - figsize, - ) - - # add flat channels back to data, if present - if flat_data is not None: - full_clean = np.zeros((n_times, n_chans)) - full_clean[:, active_channels] = clean_data - full_clean[:, global_flat] = flat_data - clean_data = full_clean - - return clean_data, config - - -def _detect_noise_frequencies( - data, sfreq, minfreq, maxfreq, winsize, power_diff_high, power_diff_low -): - """ - Detect noise frequencies. - - This is an exact implementation of find_next_noisefreq.m with the only difference - that all peaks are returned instead of this being called iteratively. - - How it works - ------------ - 1. Compute PSD and log-transform. - 2. Slide a window across frequencies from minfreq to maxfreq. - 3. For each frequency, compute center power as mean of left and right thirds. - 4. Use a state machine to detect peaks: - - SEARCHING: If current power - center power > power_diff_high, - mark peak start and switch to IN_PEAK. - - IN_PEAK: If current power - center power <= power_diff_low, - mark peak end, find max within peak, record frequency, - and switch to SEARCHING. - 5. Return list of detected noise frequencies. 
- """ - # Compute PSD - freqs, psd = _compute_psd(data, sfreq) - log_psd = 10 * np.log10(np.mean(psd, axis=1)) - - # State machine variables - in_peak = False - peak_start_idx = None - noise_freqs = [] - - # Search bounds - start_idx = np.searchsorted(freqs, minfreq) - end_idx = np.searchsorted(freqs, maxfreq) - - # Window size in samples - freq_resolution = freqs[1] - freqs[0] - win_samples = int(winsize / freq_resolution) - - idx = start_idx - while idx < end_idx: - # Get window around current frequency - win_start = max(0, idx - win_samples // 2) - win_end = min(len(freqs), idx + win_samples // 2) - win_psd = log_psd[win_start:win_end] - - if len(win_psd) < 3: - idx += 1 - continue - - # Compute center power (mean of left and right thirds) - n_third = len(win_psd) // 3 - if n_third < 1: - idx += 1 - continue - - left_third = win_psd[:n_third] - right_third = win_psd[-n_third:] - center_power = np.mean([np.mean(left_third), np.mean(right_third)]) - - current_power = log_psd[idx] - - # State machine logic - if not in_peak: - # State: SEARCHING - Check for peak start - if current_power - center_power > power_diff_high: - in_peak = True - peak_start_idx = idx - - else: - # State: IN_PEAK - Check for peak end - if current_power - center_power <= power_diff_low: - in_peak = False - peak_end_idx = idx - - # Find the actual maximum within the peak - if peak_start_idx is not None and peak_end_idx > peak_start_idx: - peak_region = log_psd[peak_start_idx:peak_end_idx] - max_offset = np.argmax(peak_region) - max_idx = peak_start_idx + max_offset - noise_freqs.append(freqs[max_idx]) - - # Skip past this peak to avoid re-detection - idx = peak_end_idx - continue - - idx += 1 - - return noise_freqs - - -def _adaptive_chunking( - data, - sfreq, - target_freq, - min_chunk_length, - detection_winsize=6.0, - prominence_quantile=0.95, -): - """Segment data into chunks with stable noise topography.""" - n_times, n_chans = data.shape - - if n_times < sfreq * min_chunk_length: - logging.warning("Data too short for adaptive chunking. Using single chunk.") - return [(0, n_times)] - - # Narrow-band filter around target frequency - bandwidth = detection_winsize / 2.0 - filtered = _narrowband_filter(data, sfreq, target_freq, bandwidth=bandwidth) - - # Compute covariance matrices for 1-second epochs - epoch_length = int(sfreq) - n_epochs = n_times // epoch_length - - distances = np.zeros(n_epochs) - prev_cov = None - - for i in range(n_epochs): - start = i * epoch_length - end = start + epoch_length - epoch = filtered[start:end, :] - cov = np.cov(epoch, rowvar=False) - - if prev_cov is not None: - # Frobenius norm of difference - distances[i] = np.linalg.norm(cov - prev_cov, "fro") - # else: distance[i] already 0 from initialization - - prev_cov = cov - - if len(distances) < 2: - return [(0, n_times)] - - # find all peaks to get prominence distribution - peaks_all, properties_all = signal.find_peaks(distances, prominence=0) - - if len(peaks_all) == 0 or "prominences" not in properties_all: - # No peaks found - logging.warning("No peaks found in distance signal. 
Using single chunk.") - return [(0, n_times)] - - prominences = properties_all["prominences"] - - # filter by prominence quantile - min_prominence = np.quantile(prominences, prominence_quantile) - min_distance_epochs = int(min_chunk_length) # Convert seconds to epochs - - peaks, properties = signal.find_peaks( - distances, prominence=min_prominence, distance=min_distance_epochs - ) - - # cconvert peak locations (in epochs) to sample indices - chunk_starts = [0] - for peak in peaks: - chunk_start_sample = peak * epoch_length - chunk_starts.append(chunk_start_sample) - chunk_starts.append(n_times) - - # create chunk list - chunks = [] - for i in range(len(chunk_starts) - 1): - start = chunk_starts[i] - end = chunk_starts[i + 1] - chunks.append((start, end)) - - # ensure minimum chunk length at edges - min_chunk_samples = int(min_chunk_length * sfreq) - - if len(chunks) > 1: - # check first chunk - if chunks[0][1] - chunks[0][0] < min_chunk_samples: - # merge with next - chunks[1] = (chunks[0][0], chunks[1][1]) - chunks.pop(0) - - if len(chunks) > 1: - # check last chunk - if chunks[-1][1] - chunks[-1][0] < min_chunk_samples: - # merge with previous - chunks[-2] = (chunks[-2][0], chunks[-1][1]) - chunks.pop(-1) - - return chunks - - -def _detect_chunk_noise_frequency( - data, - sfreq, - target_freq, - winsize, - mult_fine, - detailed_freq_bounds=(-0.05, 0.05), # ← Add this parameter -): - """Detect chunk-specific noise frequency around target.""" - freqs, psd = _compute_psd(data, sfreq) - log_psd = 10 * np.log10(np.mean(psd, axis=1)) - - # get frequency mask - search_mask = (freqs >= target_freq + detailed_freq_bounds[0]) & ( - freqs <= target_freq + detailed_freq_bounds[1] - ) - - if not np.any(search_mask): - return target_freq, False - - search_freqs = freqs[search_mask] - search_psd = log_psd[search_mask] - - # find peak - peak_idx = np.argmax(search_psd) - peak_freq = search_freqs[peak_idx] - peak_power = search_psd[peak_idx] - - # Compute threshold (uses broader window) - win_mask = (freqs >= target_freq - winsize / 2) & (freqs <= target_freq + winsize / 2) - win_psd = log_psd[win_mask] - - n_third = len(win_psd) // 3 - left_third = win_psd[:n_third] - right_third = win_psd[-n_third:] - center = np.mean([np.mean(left_third), np.mean(right_third)]) - - # Compute deviation (lower 5% quantiles) - lower_quant_left = np.percentile(left_third, 5) - lower_quant_right = np.percentile(right_third, 5) - deviation = center - np.mean([lower_quant_left, lower_quant_right]) - - threshold = center + mult_fine * deviation - - has_noise = peak_power > threshold - - return peak_freq, has_noise - - -def _detect_noise_components(data, sfreq, target_freq, sigma, nkeep): - """Detect number of noise components to remove using outlier detection.""" - # Convert nkeep=0 to None for dss_line (0 means no reduction) - if nkeep == 0: - nkeep = None - - # Apply DSS to get component scores - _, scores = dss_line(data, target_freq, sfreq, nkeep=nkeep) - - if scores is None or len(scores) == 0: - return 1 - - # Sort scores in descending order - sorted_scores = np.sort(scores)[::-1] - - # Iterative outlier detection - n_remove = 0 - remaining = sorted_scores.copy() - - while len(remaining) > 1: - mean_val = np.mean(remaining) - std_val = np.std(remaining) - threshold = mean_val + sigma * std_val - - if remaining[0] > threshold: - n_remove += 1 - remaining = remaining[1:] - else: - break - - return max(n_remove, 1) - - -def _apply_zapline_to_chunk(chunk_data, sfreq, chunk_freq, n_remove, nkeep): - """Apply Zapline to a 
single chunk, handling flat channels.""" - n_samples, n_chans = chunk_data.shape - - # Convert nkeep=0 to None for dss_line (0 means no reduction) - if nkeep == 0: - nkeep = None - - # Detect flat channels (zero variance) - diff_chunk = np.diff(chunk_data, axis=0) - flat_channels = np.where(np.all(diff_chunk == 0, axis=0))[0] - - if len(flat_channels) > 0: - logging.warning( - f"Detected {len(flat_channels)} flat channels in chunk: {flat_channels}. " - f"Removing temporarily for processing." - ) - - # store flat channel data - flat_channel_data = chunk_data[:, flat_channels] - - # remove flat channels from processing - active_channels = np.setdiff1d(np.arange(n_chans), flat_channels) - chunk_data_active = chunk_data[:, active_channels] - - # process only active channels - cleaned_active, _ = dss_line( - chunk_data_active, - fline=chunk_freq, - sfreq=sfreq, - nremove=n_remove, - nkeep=nkeep, - ) - - # Reconstruct full data with flat channels - cleaned_chunk = np.zeros_like(chunk_data) - cleaned_chunk[:, active_channels] = cleaned_active - cleaned_chunk[:, flat_channels] = ( - flat_channel_data # Add flat channels back unchanged - ) - - else: - # no flat channels, process normally - cleaned_chunk, _ = dss_line( - chunk_data, - fline=chunk_freq, - sfreq=sfreq, - nremove=n_remove, - nkeep=nkeep, - ) - - return cleaned_chunk - - -def _check_cleaning_quality( - original_data, - cleaned_data, - sfreq, - target_freq, - winsize, - mult_fine, - bounds_upper, - bounds_lower, - max_prop_above, - max_prop_below, -): - """Check if cleaning is too weak, too strong, or good.""" - # Compute PSDs - freqs, psd_clean = _compute_psd(cleaned_data, sfreq) - log_psd_clean = 10 * np.log10(np.mean(psd_clean, axis=1)) - - # Compute fine thresholds - win_mask = (freqs >= target_freq - winsize / 2) & (freqs <= target_freq + winsize / 2) - win_psd = log_psd_clean[win_mask] - - n_third = len(win_psd) // 3 - left_third = win_psd[:n_third] - right_third = win_psd[-n_third:] - center = np.mean([np.mean(left_third), np.mean(right_third)]) - - # Deviation from lower quantiles - lower_quant_left = np.percentile(left_third, 5) - lower_quant_right = np.percentile(right_third, 5) - deviation = center - np.mean([lower_quant_left, lower_quant_right]) - - # Upper threshold (too weak cleaning) - upper_mask = (freqs >= target_freq - bounds_upper[0]) & ( - freqs <= target_freq + bounds_upper[1] - ) - upper_threshold = center + mult_fine * deviation - upper_psd = log_psd_clean[upper_mask] - prop_above = np.mean(upper_psd > upper_threshold) - - # Lower threshold (too strong cleaning) - lower_mask = (freqs >= target_freq - bounds_lower[0]) & ( - freqs <= target_freq + bounds_lower[1] - ) - lower_threshold = center - mult_fine * deviation - lower_psd = log_psd_clean[lower_mask] - prop_below = np.mean(lower_psd < lower_threshold) - - if prop_below > max_prop_below: - return "too_strong" - elif prop_above > max_prop_above: - return "too_weak" - else: - return "good" - - -def _compute_psd(data, sfreq, nperseg=None): - """Compute power spectral density using Welch's method.""" - if nperseg is None: - nperseg = int(sfreq * 4) # 4-second windows - - freqs, psd = signal.welch( - data, - fs=sfreq, - window="hann", - nperseg=nperseg, - axis=0, - ) - - return freqs, psd - - -def _narrowband_filter(data, sfreq, center_freq, bandwidth=3.0): - """Apply narrow-band filter around center frequency.""" - nyq = sfreq / 2 - low = (center_freq - bandwidth) / nyq - high = (center_freq + bandwidth) / nyq - - # Ensure valid frequency range - low = max(low, 
0.001) - high = min(high, 0.999) - - sos = signal.butter(4, [low, high], btype="band", output="sos") - filtered = signal.sosfiltfilt(sos, data, axis=0) - - return filtered - - -def _plot_cleaning_results(original, cleaned, sfreq, target_freq, analytics, figsize): - """Generate diagnostic plots for cleaning results.""" - fig = plt.figure(figsize=figsize) - gs = fig.add_gridspec(2, 4, hspace=0.3, wspace=0.3) - - # Compute PSDs - freqs, psd_orig = _compute_psd(original, sfreq) - _, psd_clean = _compute_psd(cleaned, sfreq) - - log_psd_orig = 10 * np.log10(np.mean(psd_orig, axis=1)) - log_psd_clean = 10 * np.log10(np.mean(psd_clean, axis=1)) - - # 1. Zoomed spectrum around noise frequency - ax1 = fig.add_subplot(gs[0, 0]) - zoom_mask = (freqs >= target_freq - 1.1) & (freqs <= target_freq + 1.1) - ax1.plot(freqs[zoom_mask], log_psd_orig[zoom_mask], "k-", label="Original") - ax1.set_xlabel("Frequency (Hz)") - ax1.set_ylabel("Power (dB)") - ax1.set_title(f"Detected frequency: {target_freq:.2f} Hz") - ax1.legend() - ax1.grid(True, alpha=0.3) - - # 2. Number of removed components per chunk - ax2 = fig.add_subplot(gs[0, 1]) - chunk_results = analytics["chunk_results"] - n_removes = [cr["n_remove"] for cr in chunk_results] - ax2.bar(range(len(n_removes)), n_removes) - ax2.set_xlabel("Chunk") - ax2.set_ylabel("# Components removed") - ax2.set_title(f"Removed components (mean={np.mean(n_removes):.1f})") - ax2.grid(True, alpha=0.3) - - # 3. Individual noise frequencies per chunk - ax3 = fig.add_subplot(gs[0, 2]) - chunk_freqs = [cr["freq"] for cr in chunk_results] - time_min = np.array([cr["start"] for cr in chunk_results]) / sfreq / 60 - ax3.plot(time_min, chunk_freqs, "o-") - ax3.set_xlabel("Time (minutes)") - ax3.set_ylabel("Frequency (Hz)") - ax3.set_title("Individual noise frequencies") - ax3.grid(True, alpha=0.3) - - # 4. Component scores (would need actual scores from DSS) - ax4 = fig.add_subplot(gs[0, 3]) - ax4.text( - 0.5, - 0.5, - "Component scores\n(requires DSS output)", - ha="center", - va="center", - transform=ax4.transAxes, - ) - ax4.set_title("Mean artifact scores") - - # 5. Cleaned spectrum (zoomed) - ax5 = fig.add_subplot(gs[1, 0]) - ax5.plot(freqs[zoom_mask], log_psd_clean[zoom_mask], "g-", label="Cleaned") - ax5.set_xlabel("Frequency (Hz)") - ax5.set_ylabel("Power (dB)") - ax5.set_title("Cleaned spectrum") - ax5.legend() - ax5.grid(True, alpha=0.3) - - # 6. Full spectrum - ax6 = fig.add_subplot(gs[1, 1]) - ax6.plot(freqs, log_psd_orig, "k-", alpha=0.5, label="Original") - ax6.plot(freqs, log_psd_clean, "g-", label="Cleaned") - ax6.axvline(target_freq, color="r", linestyle="--", alpha=0.5) - ax6.set_xlabel("Frequency (Hz)") - ax6.set_ylabel("Power (dB)") - ax6.set_title("Full power spectrum") - ax6.legend() - ax6.grid(True, alpha=0.3) - ax6.set_xlim([0, 100]) - - # 7. Removed power (ratio) - ax7 = fig.add_subplot(gs[1, 2]) - noise_mask = (freqs >= target_freq - 0.05) & (freqs <= target_freq + 0.05) - ratio_orig = np.mean(psd_orig[noise_mask, :]) / np.mean(psd_orig) - ratio_clean = np.mean(psd_clean[noise_mask, :]) / np.mean(psd_clean) - - ax7.text( - 0.5, - 0.6, - f"Original ratio: {ratio_orig:.2f}", - ha="center", - transform=ax7.transAxes, - ) - ax7.text( - 0.5, - 0.4, - f"Cleaned ratio: {ratio_clean:.2f}", - ha="center", - transform=ax7.transAxes, - ) - ax7.set_title("Noise/surroundings ratio") - ax7.axis("off") - - # 8. 
Below noise frequencies - ax8 = fig.add_subplot(gs[1, 3]) - below_mask = (freqs >= target_freq - 11) & (freqs <= target_freq - 1) - ax8.plot( - freqs[below_mask], log_psd_orig[below_mask], "k-", alpha=0.5, label="Original" - ) - ax8.plot(freqs[below_mask], log_psd_clean[below_mask], "g-", label="Cleaned") - ax8.set_xlabel("Frequency (Hz)") - ax8.set_ylabel("Power (dB)") - ax8.set_title("Power below noise frequency") - ax8.legend() - ax8.grid(True, alpha=0.3) - - plt.suptitle( - f"Zapline-plus cleaning results: {target_freq:.2f} Hz " - f"(iteration {analytics['iteration']})", - fontsize=14, - y=0.98, - ) - - plt.show() - - return fig From 70ec70839a9bba33bf9fabb2691627fc6f6d0f98 Mon Sep 17 00:00:00 2001 From: Roy Eric <139973278+Randomidous@users.noreply.github.com> Date: Sun, 25 Jan 2026 15:53:00 +0100 Subject: [PATCH 8/9] add directory for plot saving --- meegkit/dss.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/meegkit/dss.py b/meegkit/dss.py index 7f93261..1713d33 100644 --- a/meegkit/dss.py +++ b/meegkit/dss.py @@ -436,6 +436,7 @@ def dss_line_plus( plotResults: bool = False, figsize: tuple[int, int] = (14, 10), vanilla_mode: bool = False, + dirname: str | None = None, ) -> tuple[np.ndarray, dict]: """Remove line noise and other frequency-specific artifacts using Zapline-plus. @@ -520,6 +521,9 @@ def dss_line_plus( - No individual chunk frequency detection - No adaptive parameter tuning Requires fline to be specified (not None). Defaults to False. + dirname : str | None + Path to the directory where diagnostic plots are saved when plotResults + is True. If None, the plots are not saved. Defaults to None. Returns ------- @@ -795,6 +799,7 @@ def dss_line_plus( target_freq, config["analytics"][f"freq_{freq_idx}"], figsize, + dirname, ) # add flat channels back to data, if present @@ -1215,7 +1220,15 @@ def _narrowband_filter(data, sfreq, center_freq, bandwidth=3.0): return filtered -def _plot_cleaning_results(original, cleaned, sfreq, target_freq, analytics, figsize): +def _plot_cleaning_results( + original, + cleaned, + sfreq, + target_freq, + analytics, + figsize, + dirname, +): """Generate diagnostic plots for cleaning results.""" fig = plt.figure(figsize=figsize) gs = fig.add_gridspec(2, 4, hspace=0.3, wspace=0.3) @@ -1335,4 +1348,7 @@ def _plot_cleaning_results(original, cleaned, sfreq, target_freq, analytics, fig plt.show() + if dirname is not None: + fig.savefig(f"{dirname}/dss_line_plus_results.png") + return fig From 9aed25ee08ad933dd9145f89a3ef988455253664 Mon Sep 17 00:00:00 2001 From: Roy Eric <139973278+Randomidous@users.noreply.github.com> Date: Sun, 25 Jan 2026 15:53:30 +0100 Subject: [PATCH 9/9] update test to show and save plots in temp path --- tests/test_dss.py | 75 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 74 insertions(+), 1 deletion(-) diff --git a/tests/test_dss.py b/tests/test_dss.py index 3816b82..c1ea0a3 100644 --- a/tests/test_dss.py +++ b/tests/test_dss.py @@ -1,5 +1,4 @@ """Test DSS functions.""" -import os from tempfile import TemporaryDirectory import matplotlib.pyplot as plt @@ -125,6 +124,7 @@ def _plot(x): out, _ = dss.dss_line(s, fline, sr, nremove=1) plt.close("all") + def test_dss_line_iter(): """Test line noise removal.""" @@ -199,6 +199,78 @@ def profile_dss_line(nkeep): ps.print_stats() print(s.getvalue()) + +@pytest.mark.parametrize("mode", ["vanilla", "adaptive"]) +def test_dss_line_plus(mode): + """Test zapline-plus line noise removal.""" + + sfreq = 250 + freq = 50 + fline = freq / sfreq + + n_samples
= 30 * sfreq + n_chans = 16 + + # create synthetic data (time × chan × trial) + data, _ = create_line_data( + n_samples=n_samples, + n_chans=n_chans, + n_trials=1, + fline=fline, + SNR=1.0, + noise_dim=20, + ) + + # zapline-plus expects 2D + data = data[..., 0] + + if mode == "vanilla": + with TemporaryDirectory() as tmpdir: + clean, _ = dss.dss_line_plus( + data, + sfreq, + fline=freq, + vanilla_mode=True, + fixedNremove=1, + plotResults=True, + dirname=tmpdir, + ) + else: + with TemporaryDirectory() as tmpdir: + clean, _ = dss.dss_line_plus( + data, + sfreq, + fline=None, + adaptiveNremove=True, + adaptiveSigma=True, + chunkLength=0, + minChunkLength=10, + plotResults=True, + dirname=tmpdir, + ) + + assert clean.shape == data.shape + + # PSD comparison + freqs, psd_orig = signal.welch(data, fs=sfreq, nperseg=sfreq * 2, axis=0) + _, psd_clean = signal.welch(clean, fs=sfreq, nperseg=sfreq * 2, axis=0) + + psd_orig = psd_orig.mean(axis=1) + psd_clean = psd_clean.mean(axis=1) + + idx = np.argmin(np.abs(freqs - freq)) + + reduction = 1 - psd_clean[idx] / psd_orig[idx] + + # Zapline should remove the vast majority of line noise + assert reduction > 0.85 + + # Broadband signal should be preserved + band = (freqs > 5) & (freqs < 40) + ratio = psd_clean[band].mean() / psd_orig[band].mean() + + assert 0.7 < ratio < 1.3 + if __name__ == "__main__": pytest.main([__file__]) # create_data(SNR=5, show=True) @@ -206,3 +278,4 @@ def profile_dss_line(nkeep): # test_dss_line(2) # test_dss_line_iter() # profile_dss_line(None) + # test_dss_line_plus("adaptive")
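
For reference, a minimal usage sketch of the `dss_line_plus` API introduced by this series (a sketch only: it assumes the branch is installed so that `from meegkit import dss` exposes the function, uses the time x channels data orientation the test above uses, and generates illustrative synthetic data rather than anything from the test suite):

    import numpy as np
    from meegkit import dss

    sfreq = 250
    rng = np.random.default_rng(0)
    t = np.arange(60 * sfreq) / sfreq

    # 60 s of broadband noise on 8 channels (time x channels)
    data = rng.standard_normal((t.size, 8))
    # add a shared 50 Hz line-noise component to every channel
    data += 0.5 * np.sin(2 * np.pi * 50 * t)[:, None]

    # clean with a known target frequency; fline=None would trigger
    # automatic detection between minfreq and maxfreq instead
    clean, config = dss.dss_line_plus(data, sfreq, fline=50)
    print(config["analytics"]["freq_0"]["cleaning_status"])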