From bccd7a1f9ad5d92fdad762cec2f09def71cc7ba6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Thu, 16 Mar 2023 15:57:00 +0100 Subject: [PATCH 01/17] Implement merging of AP and LFP channels --- spikeinterface/preprocessing/__init__.py | 2 + spikeinterface/preprocessing/merge_ap_lfp.py | 84 +++++++++++++++++++ .../preprocessing/tests/test_merge_ap_lfp.py | 43 ++++++++++ 3 files changed, 129 insertions(+) create mode 100644 spikeinterface/preprocessing/merge_ap_lfp.py create mode 100644 spikeinterface/preprocessing/tests/test_merge_ap_lfp.py diff --git a/spikeinterface/preprocessing/__init__.py b/spikeinterface/preprocessing/__init__.py index d891d9e9d3..f5272d7299 100644 --- a/spikeinterface/preprocessing/__init__.py +++ b/spikeinterface/preprocessing/__init__.py @@ -4,6 +4,8 @@ from .detect_bad_channels import detect_bad_channels from .correct_lsb import correct_lsb +from .merge_ap_lfp import generate_RC_filter, MergeApLfpRecording + #for snippets from .align_snippets import AlignSnippets \ No newline at end of file diff --git a/spikeinterface/preprocessing/merge_ap_lfp.py b/spikeinterface/preprocessing/merge_ap_lfp.py new file mode 100644 index 0000000000..62417c3310 --- /dev/null +++ b/spikeinterface/preprocessing/merge_ap_lfp.py @@ -0,0 +1,84 @@ +from typing import List, Union +import numpy as np + +from ..core import BaseRecording, BaseRecordingSegment + + +class MergeApLfpRecording(BaseRecording): + """ + Add cool description here. + + Parameters + ---------- + ap_recording: BaseRecording + The recording of the AP channels. + lfp_recording: BaseRecording + The recording of the LFP channels. + + Returns + -------- + merged_ap_lfp_recording: MergeApLfpRecording + The result of the merge of both channels (with the whole frequency spectrum). + """ + + def __init__(self, ap_recording: BaseRecording, lfp_recording: BaseRecording) -> None: + BaseRecording.__init__(self, ap_recording.sampling_frequency, ap_recording.channel_ids, ap_recording.dtype) + ap_recording.copy_metadata(self) + + for segment_index in range(ap_recording.get_num_segments()): + recording_segment = MergeApLfpRecordingSegment(ap_recording._recording_segments[segment_index], lfp_recording._recording_segments[segment_index]) + self.add_recording_segment(recording_segment) + + self._kwargs = { + 'ap_recording': ap_recording.to_dict(), + 'lfp_recording': lfp_recording.to_dict() + } + + +class MergeApLfpRecordingSegment(BaseRecordingSegment): + + def __init__(self, ap_recording_segment: BaseRecordingSegment, lfp_recording_segment: BaseRecordingSegment) -> None: + self.ap_recording = ap_recording_segment + self.lfp_recording = lfp_recording_segment + + + def get_traces(self, start_frame: Union[int, None] = None, end_frame: Union[int, None] = None, + channel_indices: Union[List, None] = None) -> np.ndarray: + # TODO + return self.ap_recording.get_traces(start_frame, end_frame, channel_indices) + + +def generate_RC_filter(frequencies: np.ndarray, cut: Union[float, List[float]], btype: str = "bandpass") -> np.ndarray: + """ + Generates the transfer function of a single pole RC filter. + + Parameters + ---------- + frequencies: np.ndarray + The frequencies (in Hz) for which to generate the transfer function. + cut: float | list[float] + The cutoff frequency/frequencies (in Hz). + Should be a float for lowpass/highpass and a list of 2 floats for bandpass. + btype: str + The type of filter. In "lowpass", "highpass", "bandpass". + + Returns + ------- + transfer_function: np.ndarray + The transfer function of the filter for each frequencies. + """ + + highpass = np.ones(len(frequencies), dtype=np.complex128) + lowpass = np.ones(len(frequencies), dtype=np.complex128) + + if btype == "lowpass": + lowpass = 1 / (1 + 1j * frequencies / cut) + elif btype == "highpass": + highpass = (frequencies / cut) / (1 + 1j * frequencies / cut) + elif btype == "bandpass": + highpass = generate_RC_filter(frequencies, cut[0], btype="highpass") + lowpass = generate_RC_filter(frequencies, cut[1], btype="lowpass") + else: + raise AttributeError(f"btype '{btype}' is invalid for generate_RC_filter.") + + return lowpass * highpass diff --git a/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py b/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py new file mode 100644 index 0000000000..a56938dce9 --- /dev/null +++ b/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py @@ -0,0 +1,43 @@ +import numpy as np + +from spikeinterface.core import NumpyRecording +from spikeinterface.preprocessing import generate_RC_filter, MergeApLfpRecording + + +def test_generate_RC_filter(): + frequencies = np.arange(0, 15001, 1, dtype=np.float64) + transfer_func = np.abs(generate_RC_filter(frequencies, [300, 10000])) + + assert abs(transfer_func[300] - 10**(-3/20)) <= 1e-2 + assert abs(transfer_func[10000] - 10**(-3/20)) <= 1e-2 + assert abs(transfer_func[10] / transfer_func[1] - 10.0) <= 1e-2 + + +def test_MergeApLfpRecording(): + sf = 30000 + + # Generate a 1-second 2-channels white noise recording. + original_trace = np.array([np.random.normal(loc=0.0, scale=1.0, size=sf), np.random.normal(loc=0.0, scale=1.0, size=sf)]).T + original_fourier = np.fft.rfft(original_trace, axis=0) + freq = np.fft.rfftfreq(original_trace.shape[0], d=1/sf) + + ap_filter = generate_RC_filter(freq, [300, 10000]) + lfp_filter = generate_RC_filter(freq, [0.5, 500]) + + fourier_ap = original_fourier * ap_filter[:, None] + fourier_lfp = original_fourier * lfp_filter[:, None] + + trace_ap = np.fft.irfft(fourier_ap, axis=0) + trace_lfp = np.fft.irfft(fourier_lfp, axis=0)[::12] + + ap_recording = NumpyRecording(trace_ap, sf) + lfp_recording = NumpyRecording(trace_lfp, sf/12) + + merged_recording = MergeApLfpRecording(ap_recording, lfp_recording) + + # TODO: Test the get_traces. + + +if __name__ == '__main__': + test_generate_RC_filter() + test_MergeApLfpRecording() From 0e2af43496de31d87bf7c58f705634c1419d3463 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Thu, 16 Mar 2023 17:40:38 +0100 Subject: [PATCH 02/17] WIP on merging AP LFP channels --- spikeinterface/core/numpyextractors.py | 4 ++ spikeinterface/core/recording_tools.py | 2 +- spikeinterface/preprocessing/__init__.py | 2 +- spikeinterface/preprocessing/merge_ap_lfp.py | 71 ++++++++++++++++--- .../preprocessing/tests/test_merge_ap_lfp.py | 12 ++-- 5 files changed, 73 insertions(+), 18 deletions(-) diff --git a/spikeinterface/core/numpyextractors.py b/spikeinterface/core/numpyextractors.py index 159da4d399..40232a3c38 100644 --- a/spikeinterface/core/numpyextractors.py +++ b/spikeinterface/core/numpyextractors.py @@ -69,6 +69,10 @@ def __init__(self, traces, sampling_frequency, t_start): def get_num_samples(self): return self._traces.shape[0] + @property + def dtype(self): + return self._traces.dtype + def get_traces(self, start_frame, end_frame, channel_indices): traces = self._traces[start_frame:end_frame, :] if channel_indices is not None: diff --git a/spikeinterface/core/recording_tools.py b/spikeinterface/core/recording_tools.py index 53fe2c7148..d80d3a8dae 100644 --- a/spikeinterface/core/recording_tools.py +++ b/spikeinterface/core/recording_tools.py @@ -128,7 +128,7 @@ def get_chunk_with_margin(rec_segment, start_frame, end_frame, channel_indices = slice(None) if not add_zeros: - assert not window_on_margin, 'window_mon_margin can be used only for add_zeros=True' + assert not window_on_margin, 'window_on_margin can be used only for add_zeros=True' if start_frame is None: left_margin = 0 start_frame = 0 diff --git a/spikeinterface/preprocessing/__init__.py b/spikeinterface/preprocessing/__init__.py index f5272d7299..75c5a0017c 100644 --- a/spikeinterface/preprocessing/__init__.py +++ b/spikeinterface/preprocessing/__init__.py @@ -4,7 +4,7 @@ from .detect_bad_channels import detect_bad_channels from .correct_lsb import correct_lsb -from .merge_ap_lfp import generate_RC_filter, MergeApLfpRecording +from .merge_ap_lfp import generate_RC_filter, MergeApLfpRecording, MergeNeuropixels1Recording #for snippets diff --git a/spikeinterface/preprocessing/merge_ap_lfp.py b/spikeinterface/preprocessing/merge_ap_lfp.py index 62417c3310..2d7424e84b 100644 --- a/spikeinterface/preprocessing/merge_ap_lfp.py +++ b/spikeinterface/preprocessing/merge_ap_lfp.py @@ -1,7 +1,7 @@ -from typing import List, Union +from typing import Callable, List, Union import numpy as np -from ..core import BaseRecording, BaseRecordingSegment +from ..core import BaseRecording, BaseRecordingSegment, get_chunk_with_margin class MergeApLfpRecording(BaseRecording): @@ -14,6 +14,14 @@ class MergeApLfpRecording(BaseRecording): The recording of the AP channels. lfp_recording: BaseRecording The recording of the LFP channels. + ap_filter: Callable + Transfer function of the filter used in the ap_recording. + Takes the frequencies as parameter, and outputs the transfer function. + lfp_filter: Callable + Transfer function of the filter used in the lfp_recording. + margin: int + The margin (in samples) to use when extracting the trace. + Takes the frequencies as parameter, and outputs the transfer function. Returns -------- @@ -21,31 +29,74 @@ class MergeApLfpRecording(BaseRecording): The result of the merge of both channels (with the whole frequency spectrum). """ - def __init__(self, ap_recording: BaseRecording, lfp_recording: BaseRecording) -> None: + def __init__(self, ap_recording: BaseRecording, lfp_recording: BaseRecording, ap_filter: Callable[[np.ndarray], np.ndarray], + lfp_filter: Callable[[np.ndarray], np.ndarray], margin: int = 1000) -> None: BaseRecording.__init__(self, ap_recording.sampling_frequency, ap_recording.channel_ids, ap_recording.dtype) ap_recording.copy_metadata(self) for segment_index in range(ap_recording.get_num_segments()): - recording_segment = MergeApLfpRecordingSegment(ap_recording._recording_segments[segment_index], lfp_recording._recording_segments[segment_index]) - self.add_recording_segment(recording_segment) + ap_recording_segment = ap_recording._recording_segments[segment_index] + lfp_recording_segment = lfp_recording._recording_segments[segment_index] + self.add_recording_segment(MergeApLfpRecordingSegment(ap_recording_segment, lfp_recording_segment, ap_filter, lfp_filter, margin)) - self._kwargs = { + self._kwargs = { # TODO: Is callable serializable? 'ap_recording': ap_recording.to_dict(), - 'lfp_recording': lfp_recording.to_dict() + 'lfp_recording': lfp_recording.to_dict(), + 'margin': margin } class MergeApLfpRecordingSegment(BaseRecordingSegment): - def __init__(self, ap_recording_segment: BaseRecordingSegment, lfp_recording_segment: BaseRecordingSegment) -> None: + def __init__(self, ap_recording_segment: BaseRecordingSegment, lfp_recording_segment: BaseRecordingSegment, + ap_filter: Callable[[np.ndarray], np.ndarray], lfp_filter: Callable[[np.ndarray], np.ndarray], margin: int) -> None: self.ap_recording = ap_recording_segment self.lfp_recording = lfp_recording_segment + self.ap_filter = ap_filter + self.lfp_filter = lfp_filter + self.margin = margin def get_traces(self, start_frame: Union[int, None] = None, end_frame: Union[int, None] = None, channel_indices: Union[List, None] = None) -> np.ndarray: - # TODO - return self.ap_recording.get_traces(start_frame, end_frame, channel_indices) + ap_traces, left_margin_ap, right_margin_ap = get_chunk_with_margin(self.ap_recording, start_frame, end_frame, channel_indices, self.margin) + lfp_traces, left_margin_lfp, right_margin_lfp = get_chunk_with_margin(self.lfp_recording, start_frame, end_frame, channel_indices, self.margin) + + ap_fourier = np.fft.rfft(ap_traces, axis=0) + lfp_fourier = np.fft.rfft(lfp_traces, axis=0) + ap_freq = np.fft.rfftfreq(ap_traces.shape[0], d=1/self.ap_recording.sampling_frequency) + lfp_freq = np.fft.rfftfreq(lfp_traces.shape[0], d=1/self.lfp_recording.sampling_frequency) + + ap_filter = self.ap_filter(ap_freq) + lfp_filter = self.lfp_filter(lfp_freq) + ap_filter = np.where(ap_filter == 0, 1.0, ap_filter) + lfp_filter = np.where(lfp_filter == 0, 1.0, lfp_filter) + + reconstructed_ap_fourier = ap_fourier / ap_filter[:, None] + reconstructed_lfp_fourier = lfp_fourier / lfp_filter[:, None] + + # TODO: LFP anti-aliasing + # TODO: reconstruct using both files + + reconstructed_traces = np.fft.irfft(reconstructed_ap_fourier, axis=0) # TODO: is a placeholder + + if right_margin_ap == 0: + right_margin_ap = -reconstructed_traces.shape[0] + + reconstructed_traces = reconstructed_traces[left_margin_ap : -right_margin_ap] + + return reconstructed_traces.astype(self.ap_recording.dtype) + + +class MergeNeuropixels1Recording(MergeApLfpRecording): + """ + + """ + + def __init__(self, ap_recording: BaseRecording, lfp_recording: BaseRecording, margin: int = 1000) -> None: + ap_filter = lambda f : generate_RC_filter(f, [300, 10000]) + lfp_filter = lambda f : generate_RC_filter(f, [0.5, 500]) + MergeApLfpRecording.__init__(self, ap_recording, lfp_recording, ap_filter, lfp_filter, margin) def generate_RC_filter(frequencies: np.ndarray, cut: Union[float, List[float]], btype: str = "bandpass") -> np.ndarray: diff --git a/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py b/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py index a56938dce9..97b9c99529 100644 --- a/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py +++ b/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py @@ -1,7 +1,7 @@ import numpy as np from spikeinterface.core import NumpyRecording -from spikeinterface.preprocessing import generate_RC_filter, MergeApLfpRecording +from spikeinterface.preprocessing import generate_RC_filter, MergeNeuropixels1Recording def test_generate_RC_filter(): @@ -17,9 +17,9 @@ def test_MergeApLfpRecording(): sf = 30000 # Generate a 1-second 2-channels white noise recording. - original_trace = np.array([np.random.normal(loc=0.0, scale=1.0, size=sf), np.random.normal(loc=0.0, scale=1.0, size=sf)]).T - original_fourier = np.fft.rfft(original_trace, axis=0) - freq = np.fft.rfftfreq(original_trace.shape[0], d=1/sf) + original_traces = np.array([np.random.normal(loc=0.0, scale=1.0, size=sf), np.random.normal(loc=0.0, scale=1.0, size=sf)]).T + original_fourier = np.fft.rfft(original_traces, axis=0) + freq = np.fft.rfftfreq(original_traces.shape[0], d=1/sf) ap_filter = generate_RC_filter(freq, [300, 10000]) lfp_filter = generate_RC_filter(freq, [0.5, 500]) @@ -33,9 +33,9 @@ def test_MergeApLfpRecording(): ap_recording = NumpyRecording(trace_ap, sf) lfp_recording = NumpyRecording(trace_lfp, sf/12) - merged_recording = MergeApLfpRecording(ap_recording, lfp_recording) + merged_recording = MergeNeuropixels1Recording(ap_recording, lfp_recording) - # TODO: Test the get_traces. + assert original_traces.shape == merged_recording.get_traces().shape if __name__ == '__main__': From f9fe209b4b09cc2167af60b4e0c4e76d347a4574 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Fri, 17 Mar 2023 15:08:47 +0100 Subject: [PATCH 03/17] WIP - Merging AP/LFP channels --- spikeinterface/preprocessing/merge_ap_lfp.py | 65 +++++++++++++++---- .../preprocessing/tests/test_merge_ap_lfp.py | 40 +++++++++++- 2 files changed, 91 insertions(+), 14 deletions(-) diff --git a/spikeinterface/preprocessing/merge_ap_lfp.py b/spikeinterface/preprocessing/merge_ap_lfp.py index 2d7424e84b..02ec07f892 100644 --- a/spikeinterface/preprocessing/merge_ap_lfp.py +++ b/spikeinterface/preprocessing/merge_ap_lfp.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Union +from typing import Callable, ClassVar, List, Union import numpy as np from ..core import BaseRecording, BaseRecordingSegment, get_chunk_with_margin @@ -30,7 +30,7 @@ class MergeApLfpRecording(BaseRecording): """ def __init__(self, ap_recording: BaseRecording, lfp_recording: BaseRecording, ap_filter: Callable[[np.ndarray], np.ndarray], - lfp_filter: Callable[[np.ndarray], np.ndarray], margin: int = 1000) -> None: + lfp_filter: Callable[[np.ndarray], np.ndarray], margin: int = 60_000) -> None: BaseRecording.__init__(self, ap_recording.sampling_frequency, ap_recording.channel_ids, ap_recording.dtype) ap_recording.copy_metadata(self) @@ -57,10 +57,28 @@ def __init__(self, ap_recording_segment: BaseRecordingSegment, lfp_recording_seg self.margin = margin + def get_num_samples(self) -> int: + return self.ap_recording.get_num_samples() + + def get_traces(self, start_frame: Union[int, None] = None, end_frame: Union[int, None] = None, channel_indices: Union[List, None] = None) -> np.ndarray: - ap_traces, left_margin_ap, right_margin_ap = get_chunk_with_margin(self.ap_recording, start_frame, end_frame, channel_indices, self.margin) - lfp_traces, left_margin_lfp, right_margin_lfp = get_chunk_with_margin(self.lfp_recording, start_frame, end_frame, channel_indices, self.margin) + AP_TO_LFP = int(round(self.ap_recording.sampling_frequency / self.lfp_recording.sampling_frequency)) + if start_frame is None: + start_frame = 0 + if end_frame is None: + end_frame = self.get_num_samples() + + assert end_frame % AP_TO_LFP == 0 # Fix this. + + ap_traces, left_margin, right_margin = get_chunk_with_margin(self.ap_recording, start_frame, end_frame, channel_indices, self.margin + AP_TO_LFP) + + left_leftover = (AP_TO_LFP - (start_frame - left_margin) % AP_TO_LFP) % AP_TO_LFP + left_margin -= left_leftover + + ap_traces = ap_traces[left_leftover:] + + lfp_traces = self.lfp_recording.get_traces((start_frame - left_margin) // AP_TO_LFP, (end_frame + right_margin) // AP_TO_LFP, channel_indices) ap_fourier = np.fft.rfft(ap_traces, axis=0) lfp_fourier = np.fft.rfft(lfp_traces, axis=0) @@ -75,15 +93,40 @@ def get_traces(self, start_frame: Union[int, None] = None, end_frame: Union[int, reconstructed_ap_fourier = ap_fourier / ap_filter[:, None] reconstructed_lfp_fourier = lfp_fourier / lfp_filter[:, None] - # TODO: LFP anti-aliasing - # TODO: reconstruct using both files + # Compute aliasing of high frequencies on LFP channels + # TODO: There may be a faster way than computing the Fourier transform + lfp_nyquist = self.lfp_recording.sampling_frequency / 2 + fourier_aliased = reconstructed_ap_fourier.copy() + fourier_aliased[ap_freq <= lfp_nyquist] = 0.0 + fourier_aliased *= self.lfp_filter(ap_freq)[:, None] + traces_aliased = np.fft.irfft(fourier_aliased, axis=0)[::AP_TO_LFP] + fourier_aliased = np.fft.rfft(traces_aliased, axis=0) / lfp_filter[:, None] + fourier_aliased = fourier_aliased[:np.searchsorted(ap_freq, lfp_nyquist, side="right")] + lfp_aa_fourier = reconstructed_lfp_fourier - fourier_aliased + + # Reconstruct using both AP and LFP channels + # TODO: Have some flexibility on the ratio + lfp_filt = self.lfp_filter(ap_freq) + ratio = np.abs(lfp_filt[1:]) / (np.abs(lfp_filt[1:]) + np.abs(ap_filter[1:])) + ratio = 1 / (1 + np.exp(-6 * np.tan(np.pi * (ratio - 0.5)))) + ratio = ratio[:, None] + + fourier_reconstructed = np.empty(reconstructed_ap_fourier.shape, dtype=np.complex128) + idx = np.searchsorted(ap_freq, lfp_nyquist, side="right") + fourier_reconstructed[idx:] = reconstructed_ap_fourier[idx:] + fourier_reconstructed[:idx] = AP_TO_LFP * lfp_aa_fourier * ratio[:idx] + reconstructed_ap_fourier[:idx] * (1 - ratio[:idx]) + + # To get back to the 0.5 - 10,000 Hz original filter + # filter_reconstructed = generate_RC_filter(ap_freq, [0.5, 10000])[:, None] + # fourier_reconstructed *= filter_reconstructed + + reconstructed_traces = np.fft.irfft(fourier_reconstructed, axis=0) - reconstructed_traces = np.fft.irfft(reconstructed_ap_fourier, axis=0) # TODO: is a placeholder - if right_margin_ap == 0: - right_margin_ap = -reconstructed_traces.shape[0] + if right_margin == 0: + right_margin = -reconstructed_traces.shape[0] - reconstructed_traces = reconstructed_traces[left_margin_ap : -right_margin_ap] + reconstructed_traces = reconstructed_traces[left_margin : -right_margin] return reconstructed_traces.astype(self.ap_recording.dtype) @@ -93,7 +136,7 @@ class MergeNeuropixels1Recording(MergeApLfpRecording): """ - def __init__(self, ap_recording: BaseRecording, lfp_recording: BaseRecording, margin: int = 1000) -> None: + def __init__(self, ap_recording: BaseRecording, lfp_recording: BaseRecording, margin: int = 60_000) -> None: ap_filter = lambda f : generate_RC_filter(f, [300, 10000]) lfp_filter = lambda f : generate_RC_filter(f, [0.5, 500]) MergeApLfpRecording.__init__(self, ap_recording, lfp_recording, ap_filter, lfp_filter, margin) diff --git a/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py b/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py index 97b9c99529..11e343e0a4 100644 --- a/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py +++ b/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py @@ -15,12 +15,19 @@ def test_generate_RC_filter(): def test_MergeApLfpRecording(): sf = 30000 + T = 10 - # Generate a 1-second 2-channels white noise recording. - original_traces = np.array([np.random.normal(loc=0.0, scale=1.0, size=sf), np.random.normal(loc=0.0, scale=1.0, size=sf)]).T + # Generate a 10-seconds 2-channels white noise recording. + rng = np.random.RandomState(seed=420) + original_traces = np.array([rng.normal(loc=0.0, scale=1.0, size=T*sf), np.random.normal(loc=0.0, scale=1.0, size=T*sf)]).T original_fourier = np.fft.rfft(original_traces, axis=0) freq = np.fft.rfftfreq(original_traces.shape[0], d=1/sf) + # Remove 0Hz (can't be reconstructed) and Nyquist frequency (behave weirdly). + original_fourier[0] = 0.0 + original_fourier[-1] = 0.0 + original_traces = np.fft.irfft(original_fourier, axis=0) + ap_filter = generate_RC_filter(freq, [300, 10000]) lfp_filter = generate_RC_filter(freq, [0.5, 500]) @@ -34,8 +41,35 @@ def test_MergeApLfpRecording(): lfp_recording = NumpyRecording(trace_lfp, sf/12) merged_recording = MergeNeuropixels1Recording(ap_recording, lfp_recording) + merged_traces = merged_recording.get_traces() + + assert original_traces.shape == merged_traces.shape + assert np.allclose(original_traces, merged_traces, rtol=1e-3, atol=1e-4) + + traces = merged_recording.get_traces(start_frame=100, end_frame=30000) + + # print(original_traces[1000:1010, 0]) + # print(traces[900:910, 0]) + # print(traces[900:910, :]) + + # import plotly.graph_objects as go + # fig = go.Figure() + + # fig.add_trace(go.Scattergl( + # x=np.fft.rfftfreq(len(original_traces[100:30000, 0]), d=1/sf), + # y=np.abs(np.fft.rfft(original_traces[100:30000, 0])), + # mode="lines", + # name="Original" + # )) + # fig.add_trace(go.Scattergl( + # x=np.fft.rfftfreq(traces.shape[0], d=1/sf), + # y=np.abs(np.fft.rfft(traces[:, 0])), + # mode="lines", + # name="Merged" + # )) - assert original_traces.shape == merged_recording.get_traces().shape + # fig.update_xaxes(type="log") + # fig.show() if __name__ == '__main__': From 4d93f4da96f24a0d28d8f3bb8f31b8fb2b7e5118 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Mon, 27 Mar 2023 11:13:10 +0200 Subject: [PATCH 04/17] WIP on merging AP/LFP channels --- spikeinterface/preprocessing/merge_ap_lfp.py | 2 + .../preprocessing/tests/test_merge_ap_lfp.py | 70 +++++++++++++------ 2 files changed, 49 insertions(+), 23 deletions(-) diff --git a/spikeinterface/preprocessing/merge_ap_lfp.py b/spikeinterface/preprocessing/merge_ap_lfp.py index 02ec07f892..6d4cd798bf 100644 --- a/spikeinterface/preprocessing/merge_ap_lfp.py +++ b/spikeinterface/preprocessing/merge_ap_lfp.py @@ -50,6 +50,8 @@ class MergeApLfpRecordingSegment(BaseRecordingSegment): def __init__(self, ap_recording_segment: BaseRecordingSegment, lfp_recording_segment: BaseRecordingSegment, ap_filter: Callable[[np.ndarray], np.ndarray], lfp_filter: Callable[[np.ndarray], np.ndarray], margin: int) -> None: + BaseRecordingSegment.__init__(self, ap_recording_segment.sampling_frequency, ap_recording_segment.t_start) + self.ap_recording = ap_recording_segment self.lfp_recording = lfp_recording_segment self.ap_filter = ap_filter diff --git a/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py b/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py index 11e343e0a4..5e5169c67d 100644 --- a/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py +++ b/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py @@ -1,9 +1,20 @@ import numpy as np +import pytest -from spikeinterface.core import NumpyRecording +from spikeinterface.core import NumpyRecording, load_extractor, set_global_tmp_folder +from spikeinterface.core.testing import check_recordings_equal from spikeinterface.preprocessing import generate_RC_filter, MergeNeuropixels1Recording +if hasattr(pytest, "global_test_folder"): + cache_folder = pytest.global_test_folder / "preprocessing" / "merge_ap_lfp" +else: + cache_folder = Path("cache_folder") / "preprocessing" / "merge_ap_lfp" + +set_global_tmp_folder(cache_folder) +cache_folder.mkdir(parents=True, exist_ok=True) + + def test_generate_RC_filter(): frequencies = np.arange(0, 15001, 1, dtype=np.float64) transfer_func = np.abs(generate_RC_filter(frequencies, [300, 10000])) @@ -46,30 +57,43 @@ def test_MergeApLfpRecording(): assert original_traces.shape == merged_traces.shape assert np.allclose(original_traces, merged_traces, rtol=1e-3, atol=1e-4) - traces = merged_recording.get_traces(start_frame=100, end_frame=30000) - - # print(original_traces[1000:1010, 0]) - # print(traces[900:910, 0]) - # print(traces[900:910, :]) - - # import plotly.graph_objects as go - # fig = go.Figure() - - # fig.add_trace(go.Scattergl( - # x=np.fft.rfftfreq(len(original_traces[100:30000, 0]), d=1/sf), - # y=np.abs(np.fft.rfft(original_traces[100:30000, 0])), - # mode="lines", - # name="Original" - # )) - # fig.add_trace(go.Scattergl( - # x=np.fft.rfftfreq(traces.shape[0], d=1/sf), - # y=np.abs(np.fft.rfft(traces[:, 0])), - # mode="lines", - # name="Merged" - # )) + # Check dumpability + saved_loaded = load_extractor(merged_recording.to_dict()) + check_recordings_equal(merged_recording, saved_loaded, return_scaled=False) + + # Check chunks + chunked_recording = merged_recording.save(folder=cache_folder / "chunked", n_jobs=2, chunk_duration='1s') + chunked_traces = chunked_recording.get_traces() + + assert np.all(np.abs(merged_traces - chunked_traces)[500:-500] < 0.05) + + import plotly.graph_objects as go + fig = go.Figure() + + fig.add_trace(go.Scatter( + x=np.arange(sf*T), + y=merged_traces[:, 0], + mode="lines", + name="Non-chunked" + )) + fig.add_trace(go.Scatter( + x=np.arange(sf*T), + y=chunked_traces[:, 0], + mode="lines", + name="Chunked" + )) + fig.add_trace(go.Scatter( + x=np.arange(sf*T), + y=merged_traces[:, 0] - chunked_traces[:, 0], + mode="lines", + name="Difference" + )) + + for i in range(1, T): + fig.add_vline(x=i*sf, line_dash="dash", line_color="rgba(0, 0, 0, 0.3)") # fig.update_xaxes(type="log") - # fig.show() + fig.show() if __name__ == '__main__': From e62cc32d0eee1b7f5035d883fc7a2cbba34a64d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Mon, 27 Mar 2023 17:07:40 +0200 Subject: [PATCH 05/17] WIP for merging AP/LFP channels --- spikeinterface/preprocessing/merge_ap_lfp.py | 23 +++++---- .../preprocessing/tests/test_merge_ap_lfp.py | 50 +++++++++---------- 2 files changed, 38 insertions(+), 35 deletions(-) diff --git a/spikeinterface/preprocessing/merge_ap_lfp.py b/spikeinterface/preprocessing/merge_ap_lfp.py index 6d4cd798bf..da93a56eb3 100644 --- a/spikeinterface/preprocessing/merge_ap_lfp.py +++ b/spikeinterface/preprocessing/merge_ap_lfp.py @@ -58,29 +58,33 @@ def __init__(self, ap_recording_segment: BaseRecordingSegment, lfp_recording_seg self.lfp_filter = lfp_filter self.margin = margin + self.AP_TO_LFP = int(round(ap_recording_segment.sampling_frequency / lfp_recording_segment.sampling_frequency)) + def get_num_samples(self) -> int: - return self.ap_recording.get_num_samples() + # Trunk the recording to have a number of samples that is a multiple of 'AP_TO_LFP'. + return self.ap_recording.get_num_samples() - (self.ap_recording.get_num_samples() % self.AP_TO_LFP) def get_traces(self, start_frame: Union[int, None] = None, end_frame: Union[int, None] = None, channel_indices: Union[List, None] = None) -> np.ndarray: - AP_TO_LFP = int(round(self.ap_recording.sampling_frequency / self.lfp_recording.sampling_frequency)) if start_frame is None: start_frame = 0 if end_frame is None: end_frame = self.get_num_samples() - assert end_frame % AP_TO_LFP == 0 # Fix this. - - ap_traces, left_margin, right_margin = get_chunk_with_margin(self.ap_recording, start_frame, end_frame, channel_indices, self.margin + AP_TO_LFP) + ap_traces, left_margin, right_margin = get_chunk_with_margin(self.ap_recording, start_frame, end_frame, channel_indices, self.margin + self.AP_TO_LFP) - left_leftover = (AP_TO_LFP - (start_frame - left_margin) % AP_TO_LFP) % AP_TO_LFP + left_leftover = (self.AP_TO_LFP - (start_frame - left_margin) % self.AP_TO_LFP) % self.AP_TO_LFP left_margin -= left_leftover + right_leftover = (end_frame + right_margin) % self.AP_TO_LFP + right_margin -= right_leftover + if right_leftover > 0: + ap_traces = ap_traces[:right_leftover] ap_traces = ap_traces[left_leftover:] - lfp_traces = self.lfp_recording.get_traces((start_frame - left_margin) // AP_TO_LFP, (end_frame + right_margin) // AP_TO_LFP, channel_indices) + lfp_traces = self.lfp_recording.get_traces((start_frame - left_margin) // self.AP_TO_LFP, (end_frame + right_margin) // self.AP_TO_LFP, channel_indices) ap_fourier = np.fft.rfft(ap_traces, axis=0) lfp_fourier = np.fft.rfft(lfp_traces, axis=0) @@ -96,12 +100,11 @@ def get_traces(self, start_frame: Union[int, None] = None, end_frame: Union[int, reconstructed_lfp_fourier = lfp_fourier / lfp_filter[:, None] # Compute aliasing of high frequencies on LFP channels - # TODO: There may be a faster way than computing the Fourier transform lfp_nyquist = self.lfp_recording.sampling_frequency / 2 fourier_aliased = reconstructed_ap_fourier.copy() fourier_aliased[ap_freq <= lfp_nyquist] = 0.0 fourier_aliased *= self.lfp_filter(ap_freq)[:, None] - traces_aliased = np.fft.irfft(fourier_aliased, axis=0)[::AP_TO_LFP] + traces_aliased = np.fft.irfft(fourier_aliased, axis=0)[::self.AP_TO_LFP] fourier_aliased = np.fft.rfft(traces_aliased, axis=0) / lfp_filter[:, None] fourier_aliased = fourier_aliased[:np.searchsorted(ap_freq, lfp_nyquist, side="right")] lfp_aa_fourier = reconstructed_lfp_fourier - fourier_aliased @@ -116,7 +119,7 @@ def get_traces(self, start_frame: Union[int, None] = None, end_frame: Union[int, fourier_reconstructed = np.empty(reconstructed_ap_fourier.shape, dtype=np.complex128) idx = np.searchsorted(ap_freq, lfp_nyquist, side="right") fourier_reconstructed[idx:] = reconstructed_ap_fourier[idx:] - fourier_reconstructed[:idx] = AP_TO_LFP * lfp_aa_fourier * ratio[:idx] + reconstructed_ap_fourier[:idx] * (1 - ratio[:idx]) + fourier_reconstructed[:idx] = self.AP_TO_LFP * lfp_aa_fourier * ratio[:idx] + reconstructed_ap_fourier[:idx] * (1 - ratio[:idx]) # To get back to the 0.5 - 10,000 Hz original filter # filter_reconstructed = generate_RC_filter(ap_freq, [0.5, 10000])[:, None] diff --git a/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py b/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py index 5e5169c67d..ea27880004 100644 --- a/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py +++ b/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py @@ -67,33 +67,33 @@ def test_MergeApLfpRecording(): assert np.all(np.abs(merged_traces - chunked_traces)[500:-500] < 0.05) - import plotly.graph_objects as go - fig = go.Figure() - - fig.add_trace(go.Scatter( - x=np.arange(sf*T), - y=merged_traces[:, 0], - mode="lines", - name="Non-chunked" - )) - fig.add_trace(go.Scatter( - x=np.arange(sf*T), - y=chunked_traces[:, 0], - mode="lines", - name="Chunked" - )) - fig.add_trace(go.Scatter( - x=np.arange(sf*T), - y=merged_traces[:, 0] - chunked_traces[:, 0], - mode="lines", - name="Difference" - )) - - for i in range(1, T): - fig.add_vline(x=i*sf, line_dash="dash", line_color="rgba(0, 0, 0, 0.3)") + # import plotly.graph_objects as go + # fig = go.Figure() + + # fig.add_trace(go.Scatter( + # x=np.arange(sf*T), + # y=merged_traces[:, 0], + # mode="lines", + # name="Non-chunked" + # )) + # fig.add_trace(go.Scatter( + # x=np.arange(sf*T), + # y=chunked_traces[:, 0], + # mode="lines", + # name="Chunked" + # )) + # fig.add_trace(go.Scatter( + # x=np.arange(sf*T), + # y=merged_traces[:, 0] - chunked_traces[:, 0], + # mode="lines", + # name="Difference" + # )) + + # for i in range(1, T): + # fig.add_vline(x=i*sf, line_dash="dash", line_color="rgba(0, 0, 0, 0.3)") # fig.update_xaxes(type="log") - fig.show() + # fig.show() if __name__ == '__main__': From 23fd0b1d82892f7ace1901a2a7a3326fffb5bbbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Fri, 7 Apr 2023 17:08:21 +0200 Subject: [PATCH 06/17] Remove unnecessary random generator --- src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py b/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py index ea27880004..f452b1f8e8 100644 --- a/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py +++ b/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py @@ -29,8 +29,7 @@ def test_MergeApLfpRecording(): T = 10 # Generate a 10-seconds 2-channels white noise recording. - rng = np.random.RandomState(seed=420) - original_traces = np.array([rng.normal(loc=0.0, scale=1.0, size=T*sf), np.random.normal(loc=0.0, scale=1.0, size=T*sf)]).T + original_traces = np.array([np.random.normal(loc=0.0, scale=1.0, size=T*sf), np.random.normal(loc=0.0, scale=1.0, size=T*sf)]).T original_fourier = np.fft.rfft(original_traces, axis=0) freq = np.fft.rfftfreq(original_traces.shape[0], d=1/sf) From 447af6336bc4c129b7914c67586739753248ce08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 12 Apr 2023 11:01:18 +0200 Subject: [PATCH 07/17] Little fix --- src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py b/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py index f452b1f8e8..ed103f6b6a 100644 --- a/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py +++ b/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py @@ -33,7 +33,7 @@ def test_MergeApLfpRecording(): original_fourier = np.fft.rfft(original_traces, axis=0) freq = np.fft.rfftfreq(original_traces.shape[0], d=1/sf) - # Remove 0Hz (can't be reconstructed) and Nyquist frequency (behave weirdly). + # Remove 0Hz (can't be reconstructed) and Nyquist frequency (behaves weirdly). original_fourier[0] = 0.0 original_fourier[-1] = 0.0 original_traces = np.fft.irfft(original_fourier, axis=0) @@ -64,7 +64,7 @@ def test_MergeApLfpRecording(): chunked_recording = merged_recording.save(folder=cache_folder / "chunked", n_jobs=2, chunk_duration='1s') chunked_traces = chunked_recording.get_traces() - assert np.all(np.abs(merged_traces - chunked_traces)[500:-500] < 0.05) + assert np.all(np.abs(merged_traces - chunked_traces)[1000:-1000] < 0.04) # import plotly.graph_objects as go # fig = go.Figure() From 2ab2e44a1a08708630ae6705e0b978af19612964 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 May 2023 22:34:35 +0000 Subject: [PATCH 08/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../preprocessing/merge_ap_lfp.py | 94 +++++++++++-------- .../preprocessing/tests/test_merge_ap_lfp.py | 26 ++--- 2 files changed, 70 insertions(+), 50 deletions(-) diff --git a/src/spikeinterface/preprocessing/merge_ap_lfp.py b/src/spikeinterface/preprocessing/merge_ap_lfp.py index da93a56eb3..f6ec9f24fc 100644 --- a/src/spikeinterface/preprocessing/merge_ap_lfp.py +++ b/src/spikeinterface/preprocessing/merge_ap_lfp.py @@ -29,53 +29,70 @@ class MergeApLfpRecording(BaseRecording): The result of the merge of both channels (with the whole frequency spectrum). """ - def __init__(self, ap_recording: BaseRecording, lfp_recording: BaseRecording, ap_filter: Callable[[np.ndarray], np.ndarray], - lfp_filter: Callable[[np.ndarray], np.ndarray], margin: int = 60_000) -> None: + def __init__( + self, + ap_recording: BaseRecording, + lfp_recording: BaseRecording, + ap_filter: Callable[[np.ndarray], np.ndarray], + lfp_filter: Callable[[np.ndarray], np.ndarray], + margin: int = 60_000, + ) -> None: BaseRecording.__init__(self, ap_recording.sampling_frequency, ap_recording.channel_ids, ap_recording.dtype) ap_recording.copy_metadata(self) for segment_index in range(ap_recording.get_num_segments()): - ap_recording_segment = ap_recording._recording_segments[segment_index] + ap_recording_segment = ap_recording._recording_segments[segment_index] lfp_recording_segment = lfp_recording._recording_segments[segment_index] - self.add_recording_segment(MergeApLfpRecordingSegment(ap_recording_segment, lfp_recording_segment, ap_filter, lfp_filter, margin)) + self.add_recording_segment( + MergeApLfpRecordingSegment(ap_recording_segment, lfp_recording_segment, ap_filter, lfp_filter, margin) + ) self._kwargs = { # TODO: Is callable serializable? - 'ap_recording': ap_recording.to_dict(), - 'lfp_recording': lfp_recording.to_dict(), - 'margin': margin + "ap_recording": ap_recording.to_dict(), + "lfp_recording": lfp_recording.to_dict(), + "margin": margin, } class MergeApLfpRecordingSegment(BaseRecordingSegment): - - def __init__(self, ap_recording_segment: BaseRecordingSegment, lfp_recording_segment: BaseRecordingSegment, - ap_filter: Callable[[np.ndarray], np.ndarray], lfp_filter: Callable[[np.ndarray], np.ndarray], margin: int) -> None: + def __init__( + self, + ap_recording_segment: BaseRecordingSegment, + lfp_recording_segment: BaseRecordingSegment, + ap_filter: Callable[[np.ndarray], np.ndarray], + lfp_filter: Callable[[np.ndarray], np.ndarray], + margin: int, + ) -> None: BaseRecordingSegment.__init__(self, ap_recording_segment.sampling_frequency, ap_recording_segment.t_start) - self.ap_recording = ap_recording_segment + self.ap_recording = ap_recording_segment self.lfp_recording = lfp_recording_segment - self.ap_filter = ap_filter + self.ap_filter = ap_filter self.lfp_filter = lfp_filter self.margin = margin self.AP_TO_LFP = int(round(ap_recording_segment.sampling_frequency / lfp_recording_segment.sampling_frequency)) - def get_num_samples(self) -> int: # Trunk the recording to have a number of samples that is a multiple of 'AP_TO_LFP'. return self.ap_recording.get_num_samples() - (self.ap_recording.get_num_samples() % self.AP_TO_LFP) - - def get_traces(self, start_frame: Union[int, None] = None, end_frame: Union[int, None] = None, - channel_indices: Union[List, None] = None) -> np.ndarray: + def get_traces( + self, + start_frame: Union[int, None] = None, + end_frame: Union[int, None] = None, + channel_indices: Union[List, None] = None, + ) -> np.ndarray: if start_frame is None: start_frame = 0 if end_frame is None: end_frame = self.get_num_samples() - ap_traces, left_margin, right_margin = get_chunk_with_margin(self.ap_recording, start_frame, end_frame, channel_indices, self.margin + self.AP_TO_LFP) - - left_leftover = (self.AP_TO_LFP - (start_frame - left_margin) % self.AP_TO_LFP) % self.AP_TO_LFP + ap_traces, left_margin, right_margin = get_chunk_with_margin( + self.ap_recording, start_frame, end_frame, channel_indices, self.margin + self.AP_TO_LFP + ) + + left_leftover = (self.AP_TO_LFP - (start_frame - left_margin) % self.AP_TO_LFP) % self.AP_TO_LFP left_margin -= left_leftover right_leftover = (end_frame + right_margin) % self.AP_TO_LFP right_margin -= right_leftover @@ -84,19 +101,21 @@ def get_traces(self, start_frame: Union[int, None] = None, end_frame: Union[int, ap_traces = ap_traces[:right_leftover] ap_traces = ap_traces[left_leftover:] - lfp_traces = self.lfp_recording.get_traces((start_frame - left_margin) // self.AP_TO_LFP, (end_frame + right_margin) // self.AP_TO_LFP, channel_indices) + lfp_traces = self.lfp_recording.get_traces( + (start_frame - left_margin) // self.AP_TO_LFP, (end_frame + right_margin) // self.AP_TO_LFP, channel_indices + ) - ap_fourier = np.fft.rfft(ap_traces, axis=0) + ap_fourier = np.fft.rfft(ap_traces, axis=0) lfp_fourier = np.fft.rfft(lfp_traces, axis=0) - ap_freq = np.fft.rfftfreq(ap_traces.shape[0], d=1/self.ap_recording.sampling_frequency) - lfp_freq = np.fft.rfftfreq(lfp_traces.shape[0], d=1/self.lfp_recording.sampling_frequency) + ap_freq = np.fft.rfftfreq(ap_traces.shape[0], d=1 / self.ap_recording.sampling_frequency) + lfp_freq = np.fft.rfftfreq(lfp_traces.shape[0], d=1 / self.lfp_recording.sampling_frequency) - ap_filter = self.ap_filter(ap_freq) + ap_filter = self.ap_filter(ap_freq) lfp_filter = self.lfp_filter(lfp_freq) - ap_filter = np.where(ap_filter == 0, 1.0, ap_filter) + ap_filter = np.where(ap_filter == 0, 1.0, ap_filter) lfp_filter = np.where(lfp_filter == 0, 1.0, lfp_filter) - reconstructed_ap_fourier = ap_fourier / ap_filter[:, None] + reconstructed_ap_fourier = ap_fourier / ap_filter[:, None] reconstructed_lfp_fourier = lfp_fourier / lfp_filter[:, None] # Compute aliasing of high frequencies on LFP channels @@ -104,9 +123,9 @@ def get_traces(self, start_frame: Union[int, None] = None, end_frame: Union[int, fourier_aliased = reconstructed_ap_fourier.copy() fourier_aliased[ap_freq <= lfp_nyquist] = 0.0 fourier_aliased *= self.lfp_filter(ap_freq)[:, None] - traces_aliased = np.fft.irfft(fourier_aliased, axis=0)[::self.AP_TO_LFP] + traces_aliased = np.fft.irfft(fourier_aliased, axis=0)[:: self.AP_TO_LFP] fourier_aliased = np.fft.rfft(traces_aliased, axis=0) / lfp_filter[:, None] - fourier_aliased = fourier_aliased[:np.searchsorted(ap_freq, lfp_nyquist, side="right")] + fourier_aliased = fourier_aliased[: np.searchsorted(ap_freq, lfp_nyquist, side="right")] lfp_aa_fourier = reconstructed_lfp_fourier - fourier_aliased # Reconstruct using both AP and LFP channels @@ -119,7 +138,9 @@ def get_traces(self, start_frame: Union[int, None] = None, end_frame: Union[int, fourier_reconstructed = np.empty(reconstructed_ap_fourier.shape, dtype=np.complex128) idx = np.searchsorted(ap_freq, lfp_nyquist, side="right") fourier_reconstructed[idx:] = reconstructed_ap_fourier[idx:] - fourier_reconstructed[:idx] = self.AP_TO_LFP * lfp_aa_fourier * ratio[:idx] + reconstructed_ap_fourier[:idx] * (1 - ratio[:idx]) + fourier_reconstructed[:idx] = self.AP_TO_LFP * lfp_aa_fourier * ratio[:idx] + reconstructed_ap_fourier[:idx] * ( + 1 - ratio[:idx] + ) # To get back to the 0.5 - 10,000 Hz original filter # filter_reconstructed = generate_RC_filter(ap_freq, [0.5, 10000])[:, None] @@ -127,23 +148,20 @@ def get_traces(self, start_frame: Union[int, None] = None, end_frame: Union[int, reconstructed_traces = np.fft.irfft(fourier_reconstructed, axis=0) - if right_margin == 0: right_margin = -reconstructed_traces.shape[0] - reconstructed_traces = reconstructed_traces[left_margin : -right_margin] + reconstructed_traces = reconstructed_traces[left_margin:-right_margin] return reconstructed_traces.astype(self.ap_recording.dtype) class MergeNeuropixels1Recording(MergeApLfpRecording): - """ - - """ + """ """ def __init__(self, ap_recording: BaseRecording, lfp_recording: BaseRecording, margin: int = 60_000) -> None: - ap_filter = lambda f : generate_RC_filter(f, [300, 10000]) - lfp_filter = lambda f : generate_RC_filter(f, [0.5, 500]) + ap_filter = lambda f: generate_RC_filter(f, [300, 10000]) + lfp_filter = lambda f: generate_RC_filter(f, [0.5, 500]) MergeApLfpRecording.__init__(self, ap_recording, lfp_recording, ap_filter, lfp_filter, margin) @@ -168,7 +186,7 @@ def generate_RC_filter(frequencies: np.ndarray, cut: Union[float, List[float]], """ highpass = np.ones(len(frequencies), dtype=np.complex128) - lowpass = np.ones(len(frequencies), dtype=np.complex128) + lowpass = np.ones(len(frequencies), dtype=np.complex128) if btype == "lowpass": lowpass = 1 / (1 + 1j * frequencies / cut) @@ -176,7 +194,7 @@ def generate_RC_filter(frequencies: np.ndarray, cut: Union[float, List[float]], highpass = (frequencies / cut) / (1 + 1j * frequencies / cut) elif btype == "bandpass": highpass = generate_RC_filter(frequencies, cut[0], btype="highpass") - lowpass = generate_RC_filter(frequencies, cut[1], btype="lowpass") + lowpass = generate_RC_filter(frequencies, cut[1], btype="lowpass") else: raise AttributeError(f"btype '{btype}' is invalid for generate_RC_filter.") diff --git a/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py b/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py index ed103f6b6a..49363e6672 100644 --- a/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py +++ b/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py @@ -18,9 +18,9 @@ def test_generate_RC_filter(): frequencies = np.arange(0, 15001, 1, dtype=np.float64) transfer_func = np.abs(generate_RC_filter(frequencies, [300, 10000])) - - assert abs(transfer_func[300] - 10**(-3/20)) <= 1e-2 - assert abs(transfer_func[10000] - 10**(-3/20)) <= 1e-2 + + assert abs(transfer_func[300] - 10 ** (-3 / 20)) <= 1e-2 + assert abs(transfer_func[10000] - 10 ** (-3 / 20)) <= 1e-2 assert abs(transfer_func[10] / transfer_func[1] - 10.0) <= 1e-2 @@ -29,26 +29,28 @@ def test_MergeApLfpRecording(): T = 10 # Generate a 10-seconds 2-channels white noise recording. - original_traces = np.array([np.random.normal(loc=0.0, scale=1.0, size=T*sf), np.random.normal(loc=0.0, scale=1.0, size=T*sf)]).T + original_traces = np.array( + [np.random.normal(loc=0.0, scale=1.0, size=T * sf), np.random.normal(loc=0.0, scale=1.0, size=T * sf)] + ).T original_fourier = np.fft.rfft(original_traces, axis=0) - freq = np.fft.rfftfreq(original_traces.shape[0], d=1/sf) + freq = np.fft.rfftfreq(original_traces.shape[0], d=1 / sf) # Remove 0Hz (can't be reconstructed) and Nyquist frequency (behaves weirdly). original_fourier[0] = 0.0 original_fourier[-1] = 0.0 original_traces = np.fft.irfft(original_fourier, axis=0) - ap_filter = generate_RC_filter(freq, [300, 10000]) + ap_filter = generate_RC_filter(freq, [300, 10000]) lfp_filter = generate_RC_filter(freq, [0.5, 500]) - fourier_ap = original_fourier * ap_filter[:, None] + fourier_ap = original_fourier * ap_filter[:, None] fourier_lfp = original_fourier * lfp_filter[:, None] - trace_ap = np.fft.irfft(fourier_ap, axis=0) + trace_ap = np.fft.irfft(fourier_ap, axis=0) trace_lfp = np.fft.irfft(fourier_lfp, axis=0)[::12] - ap_recording = NumpyRecording(trace_ap, sf) - lfp_recording = NumpyRecording(trace_lfp, sf/12) + ap_recording = NumpyRecording(trace_ap, sf) + lfp_recording = NumpyRecording(trace_lfp, sf / 12) merged_recording = MergeNeuropixels1Recording(ap_recording, lfp_recording) merged_traces = merged_recording.get_traces() @@ -61,7 +63,7 @@ def test_MergeApLfpRecording(): check_recordings_equal(merged_recording, saved_loaded, return_scaled=False) # Check chunks - chunked_recording = merged_recording.save(folder=cache_folder / "chunked", n_jobs=2, chunk_duration='1s') + chunked_recording = merged_recording.save(folder=cache_folder / "chunked", n_jobs=2, chunk_duration="1s") chunked_traces = chunked_recording.get_traces() assert np.all(np.abs(merged_traces - chunked_traces)[1000:-1000] < 0.04) @@ -95,6 +97,6 @@ def test_MergeApLfpRecording(): # fig.show() -if __name__ == '__main__': +if __name__ == "__main__": test_generate_RC_filter() test_MergeApLfpRecording() From 1a2e3453767820c18150485bf72836b8515d211f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 24 Apr 2024 13:20:16 +0200 Subject: [PATCH 09/17] Added delay check between AP and LFP --- .../preprocessing/merge_ap_lfp.py | 54 +++++++++++++++++-- .../preprocessing/tests/test_merge_ap_lfp.py | 12 ++--- 2 files changed, 54 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/preprocessing/merge_ap_lfp.py b/src/spikeinterface/preprocessing/merge_ap_lfp.py index f6ec9f24fc..07e33389e7 100644 --- a/src/spikeinterface/preprocessing/merge_ap_lfp.py +++ b/src/spikeinterface/preprocessing/merge_ap_lfp.py @@ -1,5 +1,7 @@ +import math from typing import Callable, ClassVar, List, Union import numpy as np +from scipy.optimize import minimize from ..core import BaseRecording, BaseRecordingSegment, get_chunk_with_margin @@ -19,9 +21,9 @@ class MergeApLfpRecording(BaseRecording): Takes the frequencies as parameter, and outputs the transfer function. lfp_filter: Callable Transfer function of the filter used in the lfp_recording. + Takes the frequencies as parameter, and outputs the transfer function. margin: int The margin (in samples) to use when extracting the trace. - Takes the frequencies as parameter, and outputs the transfer function. Returns -------- @@ -35,7 +37,7 @@ def __init__( lfp_recording: BaseRecording, ap_filter: Callable[[np.ndarray], np.ndarray], lfp_filter: Callable[[np.ndarray], np.ndarray], - margin: int = 60_000, + margin: int = 15_000, ) -> None: BaseRecording.__init__(self, ap_recording.sampling_frequency, ap_recording.channel_ids, ap_recording.dtype) ap_recording.copy_metadata(self) @@ -47,7 +49,7 @@ def __init__( MergeApLfpRecordingSegment(ap_recording_segment, lfp_recording_segment, ap_filter, lfp_filter, margin) ) - self._kwargs = { # TODO: Is callable serializable? + self._kwargs = { # TODO: Is callable serializable? (missing ap_filter & lfp_filter at the moment) "ap_recording": ap_recording.to_dict(), "lfp_recording": lfp_recording.to_dict(), "margin": margin, @@ -118,6 +120,20 @@ def get_traces( reconstructed_ap_fourier = ap_fourier / ap_filter[:, None] reconstructed_lfp_fourier = lfp_fourier / lfp_filter[:, None] + # Compute time shift between AP and LFP (this varies in time!!!) + freq_slice = slice(np.searchsorted(ap_freq, 100), np.searchsorted(ap_freq, 600)) + ap_fft = reconstructed_ap_fourier[freq_slice, :] + lfp_fft = reconstructed_lfp_fourier[freq_slice, :] + + t_axis = np.arange(-2000, 2000, 40) * 1e-6 + errors = [_time_shift_error(t, ap_fft, lfp_fft, ap_freq[freq_slice]) for t in t_axis] + shift_estimate = t_axis[np.argmin(errors)] + + minimization = minimize(_time_shift_error, method="Powell", x0=[shift_estimate], args=(ap_fft, lfp_fft, ap_freq[freq_slice]), bounds=[(shift_estimate-1e-4, shift_estimate+1e-4)], tol=1e-10) + shift_estimate = minimization.x[0] + + reshifted_lfp_fourier = reconstructed_lfp_fourier / np.exp(-2j*math.pi * lfp_freq[:, None] * shift_estimate) + # Compute aliasing of high frequencies on LFP channels lfp_nyquist = self.lfp_recording.sampling_frequency / 2 fourier_aliased = reconstructed_ap_fourier.copy() @@ -126,13 +142,13 @@ def get_traces( traces_aliased = np.fft.irfft(fourier_aliased, axis=0)[:: self.AP_TO_LFP] fourier_aliased = np.fft.rfft(traces_aliased, axis=0) / lfp_filter[:, None] fourier_aliased = fourier_aliased[: np.searchsorted(ap_freq, lfp_nyquist, side="right")] - lfp_aa_fourier = reconstructed_lfp_fourier - fourier_aliased + lfp_aa_fourier = reshifted_lfp_fourier - fourier_aliased # Reconstruct using both AP and LFP channels # TODO: Have some flexibility on the ratio lfp_filt = self.lfp_filter(ap_freq) ratio = np.abs(lfp_filt[1:]) / (np.abs(lfp_filt[1:]) + np.abs(ap_filter[1:])) - ratio = 1 / (1 + np.exp(-6 * np.tan(np.pi * (ratio - 0.5)))) + ratio = 1 / (1 + np.exp(-6 * np.tan(math.pi * (ratio - 0.5)))) ratio = ratio[:, None] fourier_reconstructed = np.empty(reconstructed_ap_fourier.shape, dtype=np.complex128) @@ -199,3 +215,31 @@ def generate_RC_filter(frequencies: np.ndarray, cut: Union[float, List[float]], raise AttributeError(f"btype '{btype}' is invalid for generate_RC_filter.") return lowpass * highpass + + +def _time_shift_error(delay: float, ap_fft: np.ndarray, lfp_fft: np.ndarray, freq: np.ndarray) -> float: + """ + Computes the error for a given delay between ap and lfp traces. + + Parameters + ---------- + delay: float + The delay (in s) between AP and LFP. + Positive values indicate that lfp comes after ap. + ap_fft: np.ndarray (n_freq, n_channels) + The AP trace in the Fourier domain after unfiltering. + lfp_fft: np.ndarray (n_freq, n_channels) + The LFP trace in the Fourier domain after unfiltering. + freq: np.ndarray (n_freq) + The frequencies (in Hz). + + Returns + ------- + error: float + The error computed for the given delay. + """ + + expected_phase = -2 * math.pi * freq[:, None] * delay + errors = np.angle(lfp_fft / ap_fft / np.exp(1j * expected_phase)) + + return np.sum(np.abs(errors)) diff --git a/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py b/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py index 49363e6672..cd88149c1a 100644 --- a/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py +++ b/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py @@ -29,11 +29,9 @@ def test_MergeApLfpRecording(): T = 10 # Generate a 10-seconds 2-channels white noise recording. - original_traces = np.array( - [np.random.normal(loc=0.0, scale=1.0, size=T * sf), np.random.normal(loc=0.0, scale=1.0, size=T * sf)] - ).T + original_traces = np.random.normal(loc=0.0, scale=1.0, size=(T*sf, 2)) original_fourier = np.fft.rfft(original_traces, axis=0) - freq = np.fft.rfftfreq(original_traces.shape[0], d=1 / sf) + freq = np.fft.rfftfreq(original_traces.shape[0], d=1/sf) # Remove 0Hz (can't be reconstructed) and Nyquist frequency (behaves weirdly). original_fourier[0] = 0.0 @@ -50,13 +48,13 @@ def test_MergeApLfpRecording(): trace_lfp = np.fft.irfft(fourier_lfp, axis=0)[::12] ap_recording = NumpyRecording(trace_ap, sf) - lfp_recording = NumpyRecording(trace_lfp, sf / 12) + lfp_recording = NumpyRecording(trace_lfp, sf/12) merged_recording = MergeNeuropixels1Recording(ap_recording, lfp_recording) merged_traces = merged_recording.get_traces() assert original_traces.shape == merged_traces.shape - assert np.allclose(original_traces, merged_traces, rtol=1e-3, atol=1e-4) + assert np.allclose(original_traces, merged_traces, rtol=1e-2, atol=1e-2) # Check dumpability saved_loaded = load_extractor(merged_recording.to_dict()) @@ -66,7 +64,7 @@ def test_MergeApLfpRecording(): chunked_recording = merged_recording.save(folder=cache_folder / "chunked", n_jobs=2, chunk_duration="1s") chunked_traces = chunked_recording.get_traces() - assert np.all(np.abs(merged_traces - chunked_traces)[1000:-1000] < 0.04) + assert np.allclose(merged_traces[1000:-1000], chunked_traces[1000:-1000], rtol=1, atol=0.04) # import plotly.graph_objects as go # fig = go.Figure() From 1030057cb8d4089d4183043933ff675b76a3fbb6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Apr 2024 11:20:49 +0000 Subject: [PATCH 10/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/merge_ap_lfp.py | 11 +++++++++-- .../preprocessing/tests/test_merge_ap_lfp.py | 6 +++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/preprocessing/merge_ap_lfp.py b/src/spikeinterface/preprocessing/merge_ap_lfp.py index 07e33389e7..691db5d0ef 100644 --- a/src/spikeinterface/preprocessing/merge_ap_lfp.py +++ b/src/spikeinterface/preprocessing/merge_ap_lfp.py @@ -129,10 +129,17 @@ def get_traces( errors = [_time_shift_error(t, ap_fft, lfp_fft, ap_freq[freq_slice]) for t in t_axis] shift_estimate = t_axis[np.argmin(errors)] - minimization = minimize(_time_shift_error, method="Powell", x0=[shift_estimate], args=(ap_fft, lfp_fft, ap_freq[freq_slice]), bounds=[(shift_estimate-1e-4, shift_estimate+1e-4)], tol=1e-10) + minimization = minimize( + _time_shift_error, + method="Powell", + x0=[shift_estimate], + args=(ap_fft, lfp_fft, ap_freq[freq_slice]), + bounds=[(shift_estimate - 1e-4, shift_estimate + 1e-4)], + tol=1e-10, + ) shift_estimate = minimization.x[0] - reshifted_lfp_fourier = reconstructed_lfp_fourier / np.exp(-2j*math.pi * lfp_freq[:, None] * shift_estimate) + reshifted_lfp_fourier = reconstructed_lfp_fourier / np.exp(-2j * math.pi * lfp_freq[:, None] * shift_estimate) # Compute aliasing of high frequencies on LFP channels lfp_nyquist = self.lfp_recording.sampling_frequency / 2 diff --git a/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py b/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py index cd88149c1a..c9f02762d9 100644 --- a/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py +++ b/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py @@ -29,9 +29,9 @@ def test_MergeApLfpRecording(): T = 10 # Generate a 10-seconds 2-channels white noise recording. - original_traces = np.random.normal(loc=0.0, scale=1.0, size=(T*sf, 2)) + original_traces = np.random.normal(loc=0.0, scale=1.0, size=(T * sf, 2)) original_fourier = np.fft.rfft(original_traces, axis=0) - freq = np.fft.rfftfreq(original_traces.shape[0], d=1/sf) + freq = np.fft.rfftfreq(original_traces.shape[0], d=1 / sf) # Remove 0Hz (can't be reconstructed) and Nyquist frequency (behaves weirdly). original_fourier[0] = 0.0 @@ -48,7 +48,7 @@ def test_MergeApLfpRecording(): trace_lfp = np.fft.irfft(fourier_lfp, axis=0)[::12] ap_recording = NumpyRecording(trace_ap, sf) - lfp_recording = NumpyRecording(trace_lfp, sf/12) + lfp_recording = NumpyRecording(trace_lfp, sf / 12) merged_recording = MergeNeuropixels1Recording(ap_recording, lfp_recording) merged_traces = merged_recording.get_traces() From 5dddf3355934e16906b02ea119e95e3b578b8666 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 24 Apr 2024 13:31:33 +0200 Subject: [PATCH 11/17] Moved scipy import --- src/spikeinterface/preprocessing/merge_ap_lfp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/merge_ap_lfp.py b/src/spikeinterface/preprocessing/merge_ap_lfp.py index 691db5d0ef..57080e38b7 100644 --- a/src/spikeinterface/preprocessing/merge_ap_lfp.py +++ b/src/spikeinterface/preprocessing/merge_ap_lfp.py @@ -1,7 +1,6 @@ import math from typing import Callable, ClassVar, List, Union import numpy as np -from scipy.optimize import minimize from ..core import BaseRecording, BaseRecordingSegment, get_chunk_with_margin @@ -85,6 +84,8 @@ def get_traces( end_frame: Union[int, None] = None, channel_indices: Union[List, None] = None, ) -> np.ndarray: + from scipy.optimize import minimize + if start_frame is None: start_frame = 0 if end_frame is None: From 42a5365b9a2f17c98059c3da45ae4432e2674f7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Thu, 25 Apr 2024 12:04:20 +0200 Subject: [PATCH 12/17] Necessary tweaks for Neuropixels --- src/spikeinterface/core/sortinganalyzer.py | 3 +- .../preprocessing/merge_ap_lfp.py | 35 +++++++++++++------ .../preprocessing/tests/test_merge_ap_lfp.py | 3 +- 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 85ea9b8438..3ca479d2c6 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -4,6 +4,7 @@ from pathlib import Path import os import json +import math import pickle import weakref import shutil @@ -236,7 +237,7 @@ def create( return_scaled=True, ): # some checks - assert sorting.sampling_frequency == recording.sampling_frequency + assert math.isclose(sorting.sampling_frequency, recording.sampling_frequency, abs_tol=1e-2, rel_tol=1e-5) # check that multiple probes are non-overlapping all_probes = recording.get_probegroup().probes check_probe_do_not_overlap(all_probes) diff --git a/src/spikeinterface/preprocessing/merge_ap_lfp.py b/src/spikeinterface/preprocessing/merge_ap_lfp.py index 57080e38b7..ca6a2ddf1a 100644 --- a/src/spikeinterface/preprocessing/merge_ap_lfp.py +++ b/src/spikeinterface/preprocessing/merge_ap_lfp.py @@ -36,16 +36,25 @@ def __init__( lfp_recording: BaseRecording, ap_filter: Callable[[np.ndarray], np.ndarray], lfp_filter: Callable[[np.ndarray], np.ndarray], - margin: int = 15_000, + margin: int = 6_000, ) -> None: BaseRecording.__init__(self, ap_recording.sampling_frequency, ap_recording.channel_ids, ap_recording.dtype) ap_recording.copy_metadata(self) + if ap_recording.has_scaled(): + ap_gain = ap_recording.get_property("gain_to_uV") + else: + ap_gain = np.ones(ap_recording.get_num_channels(), dtype=np.float32) + if lfp_recording.has_scaled(): + lfp_gain = lfp_recording.get_property("gain_to_uV") + else: + lfp_gain = np.ones(lfp_recording.get_num_channels(), dtype=np.float32) + for segment_index in range(ap_recording.get_num_segments()): ap_recording_segment = ap_recording._recording_segments[segment_index] lfp_recording_segment = lfp_recording._recording_segments[segment_index] self.add_recording_segment( - MergeApLfpRecordingSegment(ap_recording_segment, lfp_recording_segment, ap_filter, lfp_filter, margin) + MergeApLfpRecordingSegment(ap_recording_segment, lfp_recording_segment, ap_filter, lfp_filter, margin, lfp_gain/ap_gain, ap_recording.get_dtype()) ) self._kwargs = { # TODO: Is callable serializable? (missing ap_filter & lfp_filter at the moment) @@ -63,6 +72,8 @@ def __init__( ap_filter: Callable[[np.ndarray], np.ndarray], lfp_filter: Callable[[np.ndarray], np.ndarray], margin: int, + lfp_to_ap_gain: np.ndarray, + dtype ) -> None: BaseRecordingSegment.__init__(self, ap_recording_segment.sampling_frequency, ap_recording_segment.t_start) @@ -71,6 +82,8 @@ def __init__( self.ap_filter = ap_filter self.lfp_filter = lfp_filter self.margin = margin + self.lfp_to_ap_gain = lfp_to_ap_gain + self.dtype = dtype self.AP_TO_LFP = int(round(ap_recording_segment.sampling_frequency / lfp_recording_segment.sampling_frequency)) @@ -85,6 +98,7 @@ def get_traces( channel_indices: Union[List, None] = None, ) -> np.ndarray: from scipy.optimize import minimize + import time if start_frame is None: start_frame = 0 @@ -94,6 +108,7 @@ def get_traces( ap_traces, left_margin, right_margin = get_chunk_with_margin( self.ap_recording, start_frame, end_frame, channel_indices, self.margin + self.AP_TO_LFP ) + t15 = time.perf_counter() left_leftover = (self.AP_TO_LFP - (start_frame - left_margin) % self.AP_TO_LFP) % self.AP_TO_LFP left_margin -= left_leftover @@ -101,12 +116,12 @@ def get_traces( right_margin -= right_leftover if right_leftover > 0: - ap_traces = ap_traces[:right_leftover] + ap_traces = ap_traces[:-right_leftover] ap_traces = ap_traces[left_leftover:] lfp_traces = self.lfp_recording.get_traces( (start_frame - left_margin) // self.AP_TO_LFP, (end_frame + right_margin) // self.AP_TO_LFP, channel_indices - ) + ) * self.lfp_to_ap_gain ap_fourier = np.fft.rfft(ap_traces, axis=0) lfp_fourier = np.fft.rfft(lfp_traces, axis=0) @@ -126,7 +141,7 @@ def get_traces( ap_fft = reconstructed_ap_fourier[freq_slice, :] lfp_fft = reconstructed_lfp_fourier[freq_slice, :] - t_axis = np.arange(-2000, 2000, 40) * 1e-6 + t_axis = np.arange(-2000, 2000, 60) * 1e-6 errors = [_time_shift_error(t, ap_fft, lfp_fft, ap_freq[freq_slice]) for t in t_axis] shift_estimate = t_axis[np.argmin(errors)] @@ -136,7 +151,7 @@ def get_traces( x0=[shift_estimate], args=(ap_fft, lfp_fft, ap_freq[freq_slice]), bounds=[(shift_estimate - 1e-4, shift_estimate + 1e-4)], - tol=1e-10, + tol=1e-6, ) shift_estimate = minimization.x[0] @@ -149,7 +164,7 @@ def get_traces( fourier_aliased *= self.lfp_filter(ap_freq)[:, None] traces_aliased = np.fft.irfft(fourier_aliased, axis=0)[:: self.AP_TO_LFP] fourier_aliased = np.fft.rfft(traces_aliased, axis=0) / lfp_filter[:, None] - fourier_aliased = fourier_aliased[: np.searchsorted(ap_freq, lfp_nyquist, side="right")] + fourier_aliased = fourier_aliased[: np.searchsorted(ap_freq, lfp_nyquist+1e-6, side="right")] lfp_aa_fourier = reshifted_lfp_fourier - fourier_aliased # Reconstruct using both AP and LFP channels @@ -160,7 +175,7 @@ def get_traces( ratio = ratio[:, None] fourier_reconstructed = np.empty(reconstructed_ap_fourier.shape, dtype=np.complex128) - idx = np.searchsorted(ap_freq, lfp_nyquist, side="right") + idx = np.searchsorted(ap_freq, lfp_nyquist+1e-6, side="right") fourier_reconstructed[idx:] = reconstructed_ap_fourier[idx:] fourier_reconstructed[:idx] = self.AP_TO_LFP * lfp_aa_fourier * ratio[:idx] + reconstructed_ap_fourier[:idx] * ( 1 - ratio[:idx] @@ -177,13 +192,13 @@ def get_traces( reconstructed_traces = reconstructed_traces[left_margin:-right_margin] - return reconstructed_traces.astype(self.ap_recording.dtype) + return reconstructed_traces.astype(self.dtype) class MergeNeuropixels1Recording(MergeApLfpRecording): """ """ - def __init__(self, ap_recording: BaseRecording, lfp_recording: BaseRecording, margin: int = 60_000) -> None: + def __init__(self, ap_recording: BaseRecording, lfp_recording: BaseRecording, margin: int = 6_000) -> None: ap_filter = lambda f: generate_RC_filter(f, [300, 10000]) lfp_filter = lambda f: generate_RC_filter(f, [0.5, 500]) MergeApLfpRecording.__init__(self, ap_recording, lfp_recording, ap_filter, lfp_filter, margin) diff --git a/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py b/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py index c9f02762d9..ae24a0bede 100644 --- a/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py +++ b/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py @@ -64,7 +64,7 @@ def test_MergeApLfpRecording(): chunked_recording = merged_recording.save(folder=cache_folder / "chunked", n_jobs=2, chunk_duration="1s") chunked_traces = chunked_recording.get_traces() - assert np.allclose(merged_traces[1000:-1000], chunked_traces[1000:-1000], rtol=1, atol=0.04) + assert np.allclose(merged_traces[5000:-5000], chunked_traces[5000:-5000], rtol=1, atol=2e-2) # import plotly.graph_objects as go # fig = go.Figure() @@ -91,7 +91,6 @@ def test_MergeApLfpRecording(): # for i in range(1, T): # fig.add_vline(x=i*sf, line_dash="dash", line_color="rgba(0, 0, 0, 0.3)") - # fig.update_xaxes(type="log") # fig.show() From eef45a2639f915bdc5606cacb8e7f9f0e4cf8144 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Apr 2024 10:05:14 +0000 Subject: [PATCH 13/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../preprocessing/merge_ap_lfp.py | 27 ++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/preprocessing/merge_ap_lfp.py b/src/spikeinterface/preprocessing/merge_ap_lfp.py index ca6a2ddf1a..a73689907b 100644 --- a/src/spikeinterface/preprocessing/merge_ap_lfp.py +++ b/src/spikeinterface/preprocessing/merge_ap_lfp.py @@ -54,7 +54,15 @@ def __init__( ap_recording_segment = ap_recording._recording_segments[segment_index] lfp_recording_segment = lfp_recording._recording_segments[segment_index] self.add_recording_segment( - MergeApLfpRecordingSegment(ap_recording_segment, lfp_recording_segment, ap_filter, lfp_filter, margin, lfp_gain/ap_gain, ap_recording.get_dtype()) + MergeApLfpRecordingSegment( + ap_recording_segment, + lfp_recording_segment, + ap_filter, + lfp_filter, + margin, + lfp_gain / ap_gain, + ap_recording.get_dtype(), + ) ) self._kwargs = { # TODO: Is callable serializable? (missing ap_filter & lfp_filter at the moment) @@ -73,7 +81,7 @@ def __init__( lfp_filter: Callable[[np.ndarray], np.ndarray], margin: int, lfp_to_ap_gain: np.ndarray, - dtype + dtype, ) -> None: BaseRecordingSegment.__init__(self, ap_recording_segment.sampling_frequency, ap_recording_segment.t_start) @@ -119,9 +127,14 @@ def get_traces( ap_traces = ap_traces[:-right_leftover] ap_traces = ap_traces[left_leftover:] - lfp_traces = self.lfp_recording.get_traces( - (start_frame - left_margin) // self.AP_TO_LFP, (end_frame + right_margin) // self.AP_TO_LFP, channel_indices - ) * self.lfp_to_ap_gain + lfp_traces = ( + self.lfp_recording.get_traces( + (start_frame - left_margin) // self.AP_TO_LFP, + (end_frame + right_margin) // self.AP_TO_LFP, + channel_indices, + ) + * self.lfp_to_ap_gain + ) ap_fourier = np.fft.rfft(ap_traces, axis=0) lfp_fourier = np.fft.rfft(lfp_traces, axis=0) @@ -164,7 +177,7 @@ def get_traces( fourier_aliased *= self.lfp_filter(ap_freq)[:, None] traces_aliased = np.fft.irfft(fourier_aliased, axis=0)[:: self.AP_TO_LFP] fourier_aliased = np.fft.rfft(traces_aliased, axis=0) / lfp_filter[:, None] - fourier_aliased = fourier_aliased[: np.searchsorted(ap_freq, lfp_nyquist+1e-6, side="right")] + fourier_aliased = fourier_aliased[: np.searchsorted(ap_freq, lfp_nyquist + 1e-6, side="right")] lfp_aa_fourier = reshifted_lfp_fourier - fourier_aliased # Reconstruct using both AP and LFP channels @@ -175,7 +188,7 @@ def get_traces( ratio = ratio[:, None] fourier_reconstructed = np.empty(reconstructed_ap_fourier.shape, dtype=np.complex128) - idx = np.searchsorted(ap_freq, lfp_nyquist+1e-6, side="right") + idx = np.searchsorted(ap_freq, lfp_nyquist + 1e-6, side="right") fourier_reconstructed[idx:] = reconstructed_ap_fourier[idx:] fourier_reconstructed[:idx] = self.AP_TO_LFP * lfp_aa_fourier * ratio[:idx] + reconstructed_ap_fourier[:idx] * ( 1 - ratio[:idx] From b2d4848205d422486218926f8f139aa821d2de9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Thu, 25 Apr 2024 13:18:24 +0200 Subject: [PATCH 14/17] Better memory managment --- .../preprocessing/merge_ap_lfp.py | 33 ++++++++----------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/preprocessing/merge_ap_lfp.py b/src/spikeinterface/preprocessing/merge_ap_lfp.py index a73689907b..7743845492 100644 --- a/src/spikeinterface/preprocessing/merge_ap_lfp.py +++ b/src/spikeinterface/preprocessing/merge_ap_lfp.py @@ -143,42 +143,38 @@ def get_traces( ap_filter = self.ap_filter(ap_freq) lfp_filter = self.lfp_filter(lfp_freq) - ap_filter = np.where(ap_filter == 0, 1.0, ap_filter) - lfp_filter = np.where(lfp_filter == 0, 1.0, lfp_filter) + ap_filter[0] = lfp_filter[0] = 1.0 # Don't reconstruct 0 Hz. - reconstructed_ap_fourier = ap_fourier / ap_filter[:, None] - reconstructed_lfp_fourier = lfp_fourier / lfp_filter[:, None] + ap_fourier /= ap_filter[:, None] + lfp_fourier /= lfp_filter[:, None] # Compute time shift between AP and LFP (this varies in time!!!) freq_slice = slice(np.searchsorted(ap_freq, 100), np.searchsorted(ap_freq, 600)) - ap_fft = reconstructed_ap_fourier[freq_slice, :] - lfp_fft = reconstructed_lfp_fourier[freq_slice, :] t_axis = np.arange(-2000, 2000, 60) * 1e-6 - errors = [_time_shift_error(t, ap_fft, lfp_fft, ap_freq[freq_slice]) for t in t_axis] + errors = [_time_shift_error(t, ap_fourier[freq_slice, :], lfp_fourier[freq_slice, :], ap_freq[freq_slice]) for t in t_axis] shift_estimate = t_axis[np.argmin(errors)] minimization = minimize( _time_shift_error, method="Powell", x0=[shift_estimate], - args=(ap_fft, lfp_fft, ap_freq[freq_slice]), + args=(ap_fourier[freq_slice, :], lfp_fourier[freq_slice, :], ap_freq[freq_slice]), bounds=[(shift_estimate - 1e-4, shift_estimate + 1e-4)], tol=1e-6, ) shift_estimate = minimization.x[0] - - reshifted_lfp_fourier = reconstructed_lfp_fourier / np.exp(-2j * math.pi * lfp_freq[:, None] * shift_estimate) + lfp_fourier /= np.exp(-2j * math.pi * lfp_freq[:, None] * shift_estimate) # Compute aliasing of high frequencies on LFP channels lfp_nyquist = self.lfp_recording.sampling_frequency / 2 - fourier_aliased = reconstructed_ap_fourier.copy() - fourier_aliased[ap_freq <= lfp_nyquist] = 0.0 + nyquist_index = np.searchsorted(ap_freq, lfp_nyquist + 1e-6, side="right") + fourier_aliased = ap_fourier.copy() + fourier_aliased[:nyquist_index] = 0.0 fourier_aliased *= self.lfp_filter(ap_freq)[:, None] traces_aliased = np.fft.irfft(fourier_aliased, axis=0)[:: self.AP_TO_LFP] fourier_aliased = np.fft.rfft(traces_aliased, axis=0) / lfp_filter[:, None] - fourier_aliased = fourier_aliased[: np.searchsorted(ap_freq, lfp_nyquist + 1e-6, side="right")] - lfp_aa_fourier = reshifted_lfp_fourier - fourier_aliased + lfp_fourier -= fourier_aliased[:nyquist_index] # Reconstruct using both AP and LFP channels # TODO: Have some flexibility on the ratio @@ -187,11 +183,10 @@ def get_traces( ratio = 1 / (1 + np.exp(-6 * np.tan(math.pi * (ratio - 0.5)))) ratio = ratio[:, None] - fourier_reconstructed = np.empty(reconstructed_ap_fourier.shape, dtype=np.complex128) - idx = np.searchsorted(ap_freq, lfp_nyquist + 1e-6, side="right") - fourier_reconstructed[idx:] = reconstructed_ap_fourier[idx:] - fourier_reconstructed[:idx] = self.AP_TO_LFP * lfp_aa_fourier * ratio[:idx] + reconstructed_ap_fourier[:idx] * ( - 1 - ratio[:idx] + fourier_reconstructed = np.empty(ap_fourier.shape, dtype=np.complex128) + fourier_reconstructed[nyquist_index:] = ap_fourier[nyquist_index:] + fourier_reconstructed[:nyquist_index] = self.AP_TO_LFP * lfp_fourier * ratio[:nyquist_index] + ap_fourier[:nyquist_index] * ( + 1 - ratio[:nyquist_index] ) # To get back to the 0.5 - 10,000 Hz original filter From b67d19c4a6a55c4a5a2a757f62d718661990050c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Apr 2024 11:21:06 +0000 Subject: [PATCH 15/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/merge_ap_lfp.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/preprocessing/merge_ap_lfp.py b/src/spikeinterface/preprocessing/merge_ap_lfp.py index 7743845492..8669d809de 100644 --- a/src/spikeinterface/preprocessing/merge_ap_lfp.py +++ b/src/spikeinterface/preprocessing/merge_ap_lfp.py @@ -152,7 +152,10 @@ def get_traces( freq_slice = slice(np.searchsorted(ap_freq, 100), np.searchsorted(ap_freq, 600)) t_axis = np.arange(-2000, 2000, 60) * 1e-6 - errors = [_time_shift_error(t, ap_fourier[freq_slice, :], lfp_fourier[freq_slice, :], ap_freq[freq_slice]) for t in t_axis] + errors = [ + _time_shift_error(t, ap_fourier[freq_slice, :], lfp_fourier[freq_slice, :], ap_freq[freq_slice]) + for t in t_axis + ] shift_estimate = t_axis[np.argmin(errors)] minimization = minimize( @@ -185,9 +188,9 @@ def get_traces( fourier_reconstructed = np.empty(ap_fourier.shape, dtype=np.complex128) fourier_reconstructed[nyquist_index:] = ap_fourier[nyquist_index:] - fourier_reconstructed[:nyquist_index] = self.AP_TO_LFP * lfp_fourier * ratio[:nyquist_index] + ap_fourier[:nyquist_index] * ( - 1 - ratio[:nyquist_index] - ) + fourier_reconstructed[:nyquist_index] = self.AP_TO_LFP * lfp_fourier * ratio[:nyquist_index] + ap_fourier[ + :nyquist_index + ] * (1 - ratio[:nyquist_index]) # To get back to the 0.5 - 10,000 Hz original filter # filter_reconstructed = generate_RC_filter(ap_freq, [0.5, 10000])[:, None] From 52b83b145e12aaf7c8ec52110263fcda76f2be13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Fri, 26 Apr 2024 16:19:54 +0200 Subject: [PATCH 16/17] Fixed problems with merging AP and LFP --- .../preprocessing/merge_ap_lfp.py | 17 ++++---- .../preprocessing/tests/test_merge_ap_lfp.py | 39 +++---------------- 2 files changed, 13 insertions(+), 43 deletions(-) diff --git a/src/spikeinterface/preprocessing/merge_ap_lfp.py b/src/spikeinterface/preprocessing/merge_ap_lfp.py index 8669d809de..4806dfad2d 100644 --- a/src/spikeinterface/preprocessing/merge_ap_lfp.py +++ b/src/spikeinterface/preprocessing/merge_ap_lfp.py @@ -127,14 +127,11 @@ def get_traces( ap_traces = ap_traces[:-right_leftover] ap_traces = ap_traces[left_leftover:] - lfp_traces = ( - self.lfp_recording.get_traces( + lfp_traces = self.lfp_recording.get_traces( (start_frame - left_margin) // self.AP_TO_LFP, (end_frame + right_margin) // self.AP_TO_LFP, channel_indices, - ) - * self.lfp_to_ap_gain - ) + ) * self.lfp_to_ap_gain[channel_indices] ap_fourier = np.fft.rfft(ap_traces, axis=0) lfp_fourier = np.fft.rfft(lfp_traces, axis=0) @@ -148,7 +145,7 @@ def get_traces( ap_fourier /= ap_filter[:, None] lfp_fourier /= lfp_filter[:, None] - # Compute time shift between AP and LFP (this varies in time!!!) + # Compute time shift between AP and LFP (TODO: Compute once and store?) freq_slice = slice(np.searchsorted(ap_freq, 100), np.searchsorted(ap_freq, 600)) t_axis = np.arange(-2000, 2000, 60) * 1e-6 @@ -171,13 +168,13 @@ def get_traces( # Compute aliasing of high frequencies on LFP channels lfp_nyquist = self.lfp_recording.sampling_frequency / 2 - nyquist_index = np.searchsorted(ap_freq, lfp_nyquist + 1e-6, side="right") - fourier_aliased = ap_fourier.copy() + nyquist_index = len(lfp_freq) + fourier_aliased = ap_fourier * np.exp(-2j * math.pi * ap_freq[:, None] * shift_estimate) fourier_aliased[:nyquist_index] = 0.0 fourier_aliased *= self.lfp_filter(ap_freq)[:, None] traces_aliased = np.fft.irfft(fourier_aliased, axis=0)[:: self.AP_TO_LFP] fourier_aliased = np.fft.rfft(traces_aliased, axis=0) / lfp_filter[:, None] - lfp_fourier -= fourier_aliased[:nyquist_index] + lfp_fourier -= fourier_aliased / np.exp(-2j * math.pi * lfp_freq[:, None] * shift_estimate) # Reconstruct using both AP and LFP channels # TODO: Have some flexibility on the ratio @@ -241,7 +238,7 @@ def generate_RC_filter(frequencies: np.ndarray, cut: Union[float, List[float]], if btype == "lowpass": lowpass = 1 / (1 + 1j * frequencies / cut) elif btype == "highpass": - highpass = (frequencies / cut) / (1 + 1j * frequencies / cut) + highpass = (1j * frequencies / cut) / (1 + 1j * frequencies / cut) elif btype == "bandpass": highpass = generate_RC_filter(frequencies, cut[0], btype="highpass") lowpass = generate_RC_filter(frequencies, cut[1], btype="lowpass") diff --git a/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py b/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py index ae24a0bede..a681927e05 100644 --- a/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py +++ b/src/spikeinterface/preprocessing/tests/test_merge_ap_lfp.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from spikeinterface.core import NumpyRecording, load_extractor, set_global_tmp_folder +from spikeinterface.core import NumpyRecording, load_extractor, normal_pdf, set_global_tmp_folder from spikeinterface.core.testing import check_recordings_equal from spikeinterface.preprocessing import generate_RC_filter, MergeNeuropixels1Recording @@ -26,10 +26,10 @@ def test_generate_RC_filter(): def test_MergeApLfpRecording(): sf = 30000 - T = 10 + T = 5 - # Generate a 10-seconds 2-channels white noise recording. - original_traces = np.random.normal(loc=0.0, scale=1.0, size=(T * sf, 2)) + # Generate a 5-seconds 10-channels white noise recording. + original_traces = np.random.normal(loc=0.0, scale=1.0, size=(T * sf, 10)) original_fourier = np.fft.rfft(original_traces, axis=0) freq = np.fft.rfftfreq(original_traces.shape[0], d=1 / sf) @@ -45,7 +45,7 @@ def test_MergeApLfpRecording(): fourier_lfp = original_fourier * lfp_filter[:, None] trace_ap = np.fft.irfft(fourier_ap, axis=0) - trace_lfp = np.fft.irfft(fourier_lfp, axis=0)[::12] + trace_lfp = np.fft.irfft(fourier_lfp, axis=0)[1::12] # Shifted LFP trace ap_recording = NumpyRecording(trace_ap, sf) lfp_recording = NumpyRecording(trace_lfp, sf / 12) @@ -64,34 +64,7 @@ def test_MergeApLfpRecording(): chunked_recording = merged_recording.save(folder=cache_folder / "chunked", n_jobs=2, chunk_duration="1s") chunked_traces = chunked_recording.get_traces() - assert np.allclose(merged_traces[5000:-5000], chunked_traces[5000:-5000], rtol=1, atol=2e-2) - - # import plotly.graph_objects as go - # fig = go.Figure() - - # fig.add_trace(go.Scatter( - # x=np.arange(sf*T), - # y=merged_traces[:, 0], - # mode="lines", - # name="Non-chunked" - # )) - # fig.add_trace(go.Scatter( - # x=np.arange(sf*T), - # y=chunked_traces[:, 0], - # mode="lines", - # name="Chunked" - # )) - # fig.add_trace(go.Scatter( - # x=np.arange(sf*T), - # y=merged_traces[:, 0] - chunked_traces[:, 0], - # mode="lines", - # name="Difference" - # )) - - # for i in range(1, T): - # fig.add_vline(x=i*sf, line_dash="dash", line_color="rgba(0, 0, 0, 0.3)") - - # fig.show() + assert np.allclose(merged_traces[5000:-5000], chunked_traces[5000:-5000], rtol=1, atol=0.3) if __name__ == "__main__": From a926d459142071bd5692a175c4190f2b6a71c9a9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 26 Apr 2024 14:20:26 +0000 Subject: [PATCH 17/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/merge_ap_lfp.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/merge_ap_lfp.py b/src/spikeinterface/preprocessing/merge_ap_lfp.py index 4806dfad2d..2b764c03f6 100644 --- a/src/spikeinterface/preprocessing/merge_ap_lfp.py +++ b/src/spikeinterface/preprocessing/merge_ap_lfp.py @@ -127,11 +127,14 @@ def get_traces( ap_traces = ap_traces[:-right_leftover] ap_traces = ap_traces[left_leftover:] - lfp_traces = self.lfp_recording.get_traces( + lfp_traces = ( + self.lfp_recording.get_traces( (start_frame - left_margin) // self.AP_TO_LFP, (end_frame + right_margin) // self.AP_TO_LFP, channel_indices, - ) * self.lfp_to_ap_gain[channel_indices] + ) + * self.lfp_to_ap_gain[channel_indices] + ) ap_fourier = np.fft.rfft(ap_traces, axis=0) lfp_fourier = np.fft.rfft(lfp_traces, axis=0)