Skip to content

Commit 1d09d3b

Browse files
committed
tests for merged output
1 parent 022436f commit 1d09d3b

File tree

5 files changed

+106
-10
lines changed

5 files changed

+106
-10
lines changed

phylib/io/alf.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,6 @@ def make_template_and_spikes_objects(self):
270270
templates_inds = np.zeros((n_clusters, ncw), dtype=np.int32)
271271
# for each template, find the nearest channels to keep (on the same probe...)
272272
for t in np.arange(n_clusters):
273-
# here we need to fill with nans if it doesn't exist, but then can no longer be int (sorry) # or have it all 0
274273
channels = self.model.clusters_channels
275274

276275
current_probe = self.model.channel_probes[channels[t]]

phylib/io/model.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -413,16 +413,14 @@ def _load_data(self):
413413
self.n_channels_loc = 0
414414

415415
# Clusters waveforms
416-
if np.all(self.spike_clusters == self.spike_templates):
416+
if not np.all(self.spike_clusters == self.spike_templates) and self.sparse_templates.cols is None:
417+
self.merge_map, _ = self.get_merge_map()
418+
self.sparse_clusters = self.cluster_waveforms()
419+
self.n_clusters = self.spike_clusters.max() + 1
420+
else:
417421
self.merge_map = {}
418-
self.nan_clusters = []
419422
self.sparse_clusters = self.sparse_templates
420423
self.n_clusters = self.spike_templates.max() + 1
421-
else:
422-
if self.sparse_templates.cols is None:
423-
self.merge_map, self.nan_clusters = self.get_merge_map()
424-
self.sparse_clusters = self.cluster_waveforms()
425-
self.n_clusters = self.spike_clusters.max() + 1
426424

427425
# Spike waveforms (optional, otherwise fetched from raw data as needed).
428426
self.spike_waveforms = self._load_spike_waveforms()

phylib/io/tests/conftest.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,19 @@ def _make_dataset(tempdir, param='dense', has_spike_attributes=True):
9898
_remove(tempdir / 'whitening_mat_inv.npy')
9999
_remove(tempdir / 'sim_binary.dat')
100100

101+
if param == 'merged':
102+
# remove this file to make templates dense
103+
_remove(tempdir / 'template_ind.npy')
104+
clus = np.load(tempdir / 'spike_clusters.npy')
105+
max_clus = np.max(clus)
106+
# merge cluster 0 and 1
107+
clus[np.bitwise_or(clus == 0, clus == 1)] = max_clus + 1
108+
# split cluster 9 into two clusters
109+
idx = np.where(clus == 9)[0]
110+
clus[idx[0:3]] = max_clus + 2
111+
clus[idx[3:]] = max_clus + 3
112+
np.save(tempdir / 'spike_clusters.npy', clus)
113+
101114
# Spike attributes.
102115
if has_spike_attributes:
103116
write_array(tempdir / 'spike_fail.npy', np.full(10, np.nan)) # wrong number of spikes
@@ -120,7 +133,7 @@ def _make_dataset(tempdir, param='dense', has_spike_attributes=True):
120133
return template_path
121134

122135

123-
@fixture(scope='function', params=('dense', 'sparse', 'misc'))
136+
@fixture(scope='function', params=('dense', 'sparse', 'misc', 'merged'))
124137
def template_path_full(tempdir, request):
125138
return _make_dataset(tempdir, request.param)
126139

phylib/io/tests/test_alf.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,3 +210,72 @@ def read_after_write():
210210
c.convert(out_path, label='probe00')
211211
check_conversion_output()
212212
read_after_write()
213+
214+
215+
def test_merger(dataset):
216+
217+
path = Path(dataset.tmp_dir)
218+
out_path = path / 'alf'
219+
220+
model = TemplateModel(
221+
dir_path=path, dat_path=dataset.dat_path, sample_rate=2000, n_channels_dat=dataset.nc)
222+
223+
c = EphysAlfCreator(model)
224+
c.convert(out_path)
225+
226+
model.close()
227+
228+
# path.joinpath('_phy_spikes_subset.channels.npy').unlink()
229+
# path.joinpath('_phy_spikes_subset.waveforms.npy').unlink()
230+
# path.joinpath('_phy_spikes_subset.spikes.npy').unlink()
231+
232+
out_path_merge = path / 'alf_merge'
233+
spike_clusters = dataset._load('spike_clusters.npy')
234+
clu, n_clu = np.unique(spike_clusters, return_counts=True)
235+
236+
# merge the first two clusters
237+
merge_clu = clu[0:2]
238+
spike_clusters[np.bitwise_or(spike_clusters == clu[0], spike_clusters == clu[1])] = np.max(clu) + 1
239+
# split the cluster with the most spikes
240+
split_clu = clu[-1]
241+
idx = np.where(spike_clusters == split_clu)[0]
242+
spike_clusters[idx[0:int(n_clu[-1] / 2)]] = np.max(clu) + 2
243+
spike_clusters[idx[int(n_clu[-1] / 2):]] = np.max(clu) + 3
244+
245+
np.save(path / 'spike_clusters.npy', spike_clusters)
246+
247+
model = TemplateModel(
248+
dir_path=path, dat_path=dataset.dat_path, sample_rate=2000, n_channels_dat=dataset.nc)
249+
print(model.merge_map)
250+
c = EphysAlfCreator(model)
251+
c.convert(out_path_merge)
252+
253+
# Test that the split are the same for the expected datasets
254+
clu_old = np.load(next(out_path.glob('clusters.peakToTrough.npy')))
255+
clu_new = np.load(next(out_path_merge.glob('clusters.peakToTrough.npy')))
256+
assert clu_old[split_clu] == clu_new[np.max(clu) + 2]
257+
assert clu_old[split_clu] == clu_new[np.max(clu) + 3]
258+
assert clu_new[split_clu] == 0
259+
assert clu_new[merge_clu[0]] == 0
260+
assert clu_new[merge_clu[1]] == 0
261+
262+
clu_old = np.load(next(out_path.glob('clusters.channels.npy')))
263+
clu_new = np.load(next(out_path_merge.glob('clusters.channels.npy')))
264+
assert clu_old[split_clu] == clu_new[np.max(clu) + 2]
265+
assert clu_old[split_clu] == clu_new[np.max(clu) + 3]
266+
assert clu_new[split_clu] == 0
267+
assert clu_new[merge_clu[0]] == 0
268+
assert clu_new[merge_clu[1]] == 0
269+
270+
clu_old = np.load(next(out_path.glob('clusters.depths.npy')))
271+
clu_new = np.load(next(out_path_merge.glob('clusters.depths.npy')))
272+
assert clu_old[split_clu] == clu_new[np.max(clu) + 2]
273+
assert clu_old[split_clu] == clu_new[np.max(clu) + 3]
274+
assert clu_new[split_clu] == 0
275+
assert clu_new[merge_clu[0]] == 0
276+
assert clu_new[merge_clu[1]] == 0
277+
278+
clu_old = np.load(next(out_path.glob('clusters.waveformsChannels.npy')))
279+
clu_new = np.load(next(out_path_merge.glob('clusters.waveformsChannels.npy')))
280+
assert np.array_equal(clu_old[split_clu], clu_new[np.max(clu) + 2])
281+
assert np.array_equal(clu_old[split_clu], clu_new[np.max(clu) + 3])

phylib/io/tests/test_model.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
# from phylib.utils import Bunch
1616
from phylib.utils.testing import captured_output
17-
from ..model import from_sparse, load_model
17+
# from ..model import from_sparse, load_model
18+
from phylib.io.model import from_sparse, load_model
1819

1920
logger = logging.getLogger(__name__)
2021

@@ -113,6 +114,22 @@ def test_model_depth(template_model):
113114
assert depths.shape == (template_model.n_spikes,)
114115

115116

117+
def test_model_merge(template_model_full):
118+
m = template_model_full
119+
120+
# This is the case where we can do the merging
121+
if not np.all(m.spike_templates == m.spike_clusters) and m.sparse_clusters.cols is None:
122+
assert len(m.merge_map) > 0
123+
assert not np.array_equal(m.sparse_clusters.data, m.sparse_templates.data)
124+
assert m.sparse_clusters.data.shape[0] == m.n_clusters
125+
assert m.sparse_templates.data.shape[0] == m.n_templates
126+
127+
else:
128+
assert len(m.merge_map) == 0
129+
assert np.array_equal(m.sparse_clusters.data, m.sparse_templates.data)
130+
assert np.array_equal(m.n_templates, m.n_clusters)
131+
132+
116133
def test_model_save(template_model_full):
117134
m = template_model_full
118135
m.save_metadata('test', {1: 1})

0 commit comments

Comments
 (0)