Skip to content

Commit fe42a49

Browse files
authored
Merge pull request #29 from cortex-lab/ibl_tests
Alf export IBL
2 parents 3e4dfb9 + 6d98a80 commit fe42a49

File tree

8 files changed

+115
-75
lines changed

8 files changed

+115
-75
lines changed

phylib/io/alf.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@
2727
# File utils
2828
#------------------------------------------------------------------------------
2929

30-
NCH_WAVEFORMS = 32 # number of channels to be saved in templates.waveforms and channels.waveforms
30+
NSAMPLE_WAVEFORMS = 500 # number of waveforrms sampled out of the raw data
3131

3232
_FILE_RENAMES = [ # file_in, file_out, squeeze (bool to squeeze vector from matlab in npy)
3333
('params.py', 'params.py', None),
34-
('cluster_metrics.csv', 'clusters.metrics.csv', None),
34+
('cluster_KSLabel.tsv', 'cluster_KSLabel.tsv', None),
3535
('spike_clusters.npy', 'spikes.clusters.npy', True),
3636
('spike_templates.npy', 'spikes.templates.npy', True),
3737
('channel_positions.npy', 'channels.localCoordinates.npy', False),
@@ -42,6 +42,9 @@
4242
('_phy_spikes_subset.channels.npy', '_phy_spikes_subset.channels.npy', False),
4343
('_phy_spikes_subset.spikes.npy', '_phy_spikes_subset.spikes.npy', False),
4444
('_phy_spikes_subset.waveforms.npy', '_phy_spikes_subset.waveforms.npy', False),
45+
('drift.depth_scale.npy', 'drift.depth_scale.npy', False),
46+
('drift.time_scale.npy', 'drift.time_scale.npy', False),
47+
('drift.um.npy', 'drift.um.npy', False),
4548
# ('cluster_group.tsv', 'ks2/clusters.phyAnnotation.tsv', False), # todo check indexing, add2QC
4649
]
4750

@@ -116,21 +119,23 @@ def convert(self, out_path, force=False, label='', ampfactor=1):
116119
if not self.out_path.exists():
117120
self.out_path.mkdir()
118121

119-
with tqdm(desc="Converting to ALF", total=95) as bar:
120-
self.copy_files(force=force)
121-
bar.update(10)
122-
self.make_spike_times_amplitudes()
122+
with tqdm(desc="Converting to ALF", total=125) as bar:
123123
bar.update(10)
124124
self.make_cluster_objects()
125125
bar.update(10)
126126
self.make_channel_objects()
127127
bar.update(5)
128+
self.make_template_and_spikes_objects()
129+
bar.update(30)
130+
self.model.save_spikes_subset_waveforms(
131+
NSAMPLE_WAVEFORMS, sample2unit=self.ampfactor)
132+
bar.update(50)
128133
self.make_depths()
129134
bar.update(20)
130-
self.make_template_object()
131-
bar.update(30)
132135
self.rm_files()
133136
bar.update(10)
137+
self.copy_files(force=force)
138+
bar.update(10)
134139
self.rename_with_label()
135140

136141
# Return the TemplateModel of the converted ALF dataset if the params.py file exists.
@@ -165,16 +170,8 @@ def _save_npy(self, filename, arr):
165170
"""Save an array into a .npy file."""
166171
np.save(self.out_path / filename, arr)
167172

168-
def make_spike_times_amplitudes(self):
169-
"""We cannot just rename/copy spike_times.npy because it is in unit of
170-
*samples*, and not in seconds."""
171-
self._save_npy('spikes.times.npy', self.model.spike_times)
172-
self._save_npy('spikes.samples.npy', self.model.spike_samples)
173-
self._save_npy('spikes.amps.npy', self.model.get_amplitudes_true() * self.ampfactor)
174-
175173
def make_cluster_objects(self):
176174
"""Create clusters.channels, clusters.waveformsDuration and clusters.amps"""
177-
178175
peak_channel_path = self.dir_path / 'clusters.channels.npy'
179176
if not peak_channel_path.exists():
180177
self._save_npy(peak_channel_path.name, self.model.templates_channels)
@@ -184,8 +181,8 @@ def make_cluster_objects(self):
184181
self._save_npy(waveform_duration_path.name, self.model.templates_waveforms_durations)
185182

186183
# group by average over cluster number
187-
camps = np.zeros(np.max(self.cluster_ids) - np.min(self.cluster_ids) + 1,) * np.nan
188-
camps[self.cluster_ids - np.min(self.cluster_ids)] = self.model.templates_amplitudes
184+
camps = np.zeros(self.model.templates_channels.shape[0],) * np.nan
185+
camps[self.cluster_ids] = self.model.templates_amplitudes
189186
amps_path = self.dir_path / 'clusters.amps.npy'
190187
self._save_npy(amps_path.name, camps * self.ampfactor)
191188

@@ -198,7 +195,7 @@ def make_cluster_objects(self):
198195
def make_channel_objects(self):
199196
"""If there is no rawInd file, create it"""
200197
rawInd_path = self.dir_path / 'channels.rawInd.npy'
201-
rawInd = np.zeros_like(self.model.channel_probes).astype(np.int)
198+
rawInd = np.zeros_like(self.model.channel_probes).astype(int)
202199
channel_offset = 0
203200
for probe in np.unique(self.model.channel_probes):
204201
ind = self.model.channel_probes == probe
@@ -225,20 +222,27 @@ def make_depths(self):
225222
spikes_depths = clusters_depths[spike_clusters]
226223
else:
227224
spikes_depths = self.model.get_depths()
228-
# if PC features are provided, compute the depth as the weighted sum of coordinates
229-
230225
self._save_npy('spikes.depths.npy', spikes_depths)
231226
self._save_npy('clusters.depths.npy', clusters_depths)
232227

233-
def make_template_object(self):
228+
def make_template_and_spikes_objects(self):
234229
"""Creates the template waveforms sparse object
235230
Without manual curation, it also corresponds to clusters waveforms objects.
236231
"""
232+
# "We cannot just rename/copy spike_times.npy because it is in unit of samples,
233+
# and not seconds
234+
self._save_npy('spikes.times.npy', self.model.spike_times)
235+
self._save_npy('spikes.samples.npy', self.model.spike_samples)
236+
spike_amps, templates_v, template_amps = self.model.get_amplitudes_true(self.ampfactor)
237+
self._save_npy('spikes.amps.npy', spike_amps)
238+
self._save_npy('templates.amps.npy', template_amps)
239+
237240
if self.model.sparse_templates.cols:
238241
raise(NotImplementedError("Sparse template export to ALF not implemented yet"))
239242
else:
240-
n_templates, n_wavsamps, nchall = self.model.sparse_templates.data.shape
241-
ncw = min(NCH_WAVEFORMS, nchall) # for some datasets, 32 may be too much
243+
n_templates, n_wavsamps, nchall = templates_v.shape
244+
# for some datasets, 32 may be too much
245+
ncw = min(self.model.n_closest_channels, nchall)
242246
assert(n_templates == self.model.n_templates)
243247
templates = np.zeros((n_templates, n_wavsamps, ncw), dtype=np.float32)
244248
templates_inds = np.zeros((n_templates, ncw), dtype=np.int32)
@@ -250,10 +254,10 @@ def make_template_object(self):
250254
self.model.channel_positions[self.model.templates_channels[t]]), axis=1)
251255
channel_distance[self.model.channel_probes != current_probe] += np.inf
252256
templates_inds[t, :] = np.argsort(channel_distance)[:ncw]
253-
templates[t, ...] = self.model.sparse_templates.data[t, :][:, templates_inds[t, :]]
254-
np.save(self.out_path.joinpath('templates.waveforms'), templates * self.ampfactor)
257+
templates[t, ...] = templates_v[t, :][:, templates_inds[t, :]]
258+
np.save(self.out_path.joinpath('templates.waveforms'), templates)
255259
np.save(self.out_path.joinpath('templates.waveformsChannels'), templates_inds)
256-
np.save(self.out_path.joinpath('clusters.waveforms'), templates * self.ampfactor)
260+
np.save(self.out_path.joinpath('clusters.waveforms'), templates)
257261
np.save(self.out_path.joinpath('clusters.waveformsChannels'), templates_inds)
258262

259263
def rename_with_label(self):

phylib/io/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def _index_of(arr, lookup):
115115
# values
116116
lookup = np.asarray(lookup, dtype=np.int32)
117117
m = (lookup.max() if len(lookup) else 0) + 1
118-
tmp = np.zeros(m + 1, dtype=np.int)
118+
tmp = np.zeros(m + 1, dtype=int)
119119
# Ensure that -1 values are kept.
120120
tmp[-1] = -1
121121
if len(lookup):
@@ -327,7 +327,7 @@ def get_excerpts(data, n_excerpts=None, excerpt_size=None):
327327
def _spikes_in_clusters(spike_clusters, clusters):
328328
"""Return the ids of all spikes belonging to the specified clusters."""
329329
if len(spike_clusters) == 0 or len(clusters) == 0:
330-
return np.array([], dtype=np.int)
330+
return np.array([], dtype=int)
331331
return np.nonzero(np.in1d(spike_clusters, clusters))[0]
332332

333333

phylib/io/model.py

Lines changed: 59 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,44 +1035,70 @@ def get_template_features(self, spike_ids):
10351035
def get_depths(self):
10361036
"""Compute spike depths based on spike pc features and probe depths."""
10371037
# compute the depth as the weighted sum of coordinates
1038-
batch_sz = 50000 # number of spikes per batch
1038+
# if PC features are provided, compute the depth as the weighted sum of coordinates
1039+
nbatch = 50000
10391040
c = 0
1040-
spike_depths = np.zeros_like(self.spike_times)
1041-
nspi = spike_depths.shape[0]
1041+
spikes_depths = np.zeros_like(self.spike_times) * np.nan
1042+
nspi = spikes_depths.shape[0]
10421043
if self.sparse_features is None or self.sparse_features.data.shape[0] != self.n_spikes:
10431044
return None
10441045
while True:
1045-
ispi = np.arange(c, min(c + batch_sz, nspi))
1046+
ispi = np.arange(c, min(c + nbatch, nspi))
10461047
# take only first component
1047-
features = np.square(self.sparse_features.data[ispi, :, 0])
1048-
ichannels = self.sparse_features.cols[self.spike_clusters[ispi]].astype(np.int64)
1048+
features = self.sparse_features.data[ispi, :, 0]
1049+
features = np.maximum(features, 0) ** 2 # takes only positive values into account
1050+
ichannels = self.sparse_features.cols[self.spike_clusters[ispi]].astype(np.uint32)
10491051
ypos = self.channel_positions[ichannels, 1]
1050-
1051-
spike_depths[ispi] = np.sum(np.transpose(ypos * features) /
1052-
np.sum(features, axis=1), axis=0)
1053-
c += batch_sz
1052+
with np.errstate(divide='ignore'):
1053+
spikes_depths[ispi] = (np.sum(np.transpose(ypos * features) /
1054+
np.sum(features, axis=1), axis=0))
1055+
c += nbatch
10541056
if c >= nspi:
10551057
break
1058+
return spikes_depths
10561059

1057-
return spike_depths
1058-
1059-
def get_amplitudes_true(self):
1060+
def get_amplitudes_true(self, sample2unit=1.):
10601061
"""Convert spike amplitude values to input amplitudes units
1061-
via scaling by unwhitened template waveform."""
1062-
# unwhiten template waveforms on their channels of max amplitude
1063-
templates_chs = self.templates_channels
1064-
templates_wfs = self.sparse_templates.data[np.arange(self.n_templates), :, templates_chs]
1065-
templates_wfs_unw = templates_wfs.T * self.wmi[templates_chs, templates_chs]
1066-
templates_amps = np.abs(
1067-
np.max(templates_wfs_unw, axis=0) - np.min(templates_wfs_unw, axis=0))
1062+
via scaling by unwhitened template waveform.
1063+
:param sample2unit float: factor to convert the raw data to a physical unit (defaults 1.)
1064+
:returns: spike_amplitudes_volts: np.array [nspikes] spike amplitudes in raw data units
1065+
:returns: templates_volts: np.array[ntemplates, nsamples, nchannels]: templates
1066+
in raw data units
1067+
:returns: template_amps_volts: np.array[ntemplates]: average templates amplitudes
1068+
in raw data units
1069+
To scale the template for template matching,
1070+
raw_data_volts = templates_volts * spike_amplitudes_volts / template_amps_volts
1071+
"""
1072+
# spike_amp = ks2_spike_amps * maxmin(inv_whitening(ks2_template_amps))
1073+
# to rescale the template,
10681074

1069-
# scale the spike amplitude values by the template amplitude values
1070-
amplitudes_v = np.zeros_like(self.amplitudes)
1071-
for t in range(self.n_templates):
1072-
idxs = self.get_template_spikes(t)
1073-
amplitudes_v[idxs] = self.amplitudes[idxs] * templates_amps[t]
1074-
1075-
return amplitudes_v
1075+
# unwhiten template waveforms on their channels of max amplitude
1076+
if self.sparse_templates.cols:
1077+
raise NotImplementedError
1078+
# apply the inverse whitening matrix to the template
1079+
templates_wfs = np.zeros_like(self.sparse_templates.data) # nt, ns, nc
1080+
for n in np.arange(self.n_templates):
1081+
templates_wfs[n, :, :] = np.matmul(self.sparse_templates.data[n, :, :], self.wmi)
1082+
1083+
# The amplitude on each channel is the positive peak minus the negative
1084+
templates_ch_amps = np.max(templates_wfs, axis=1) - np.min(templates_wfs, axis=1)
1085+
1086+
# The template arbitrary unit amplitude is the amplitude of its largest channel
1087+
# (but see below for true tempAmps)
1088+
templates_amps_au = np.max(templates_ch_amps, axis=1)
1089+
spike_amps = templates_amps_au[self.spike_templates] * self.amplitudes
1090+
1091+
with np.errstate(divide='ignore'):
1092+
# take the average spike amplitude per template
1093+
templates_amps_v = (np.bincount(self.spike_templates, weights=spike_amps) /
1094+
np.bincount(self.spike_templates))
1095+
# scale back the template according to the spikes units
1096+
templates_physical_unit = templates_wfs * (templates_amps_v / templates_amps_au
1097+
)[:, np.newaxis, np.newaxis]
1098+
1099+
return (spike_amps * sample2unit,
1100+
templates_physical_unit * sample2unit,
1101+
templates_amps_v * sample2unit)
10761102

10771103
#--------------------------------------------------------------------------
10781104
# Internal helper methods for public high-level methods
@@ -1232,15 +1258,18 @@ def save_spike_clusters(self, spike_clusters):
12321258
logger.debug("Save spike clusters to `%s`.", path)
12331259
np.save(path, spike_clusters)
12341260

1235-
def save_spikes_subset_waveforms(self, max_n_spikes_per_template=None, max_n_channels=None):
1261+
def save_spikes_subset_waveforms(self, max_n_spikes_per_template=None, max_n_channels=None,
1262+
sample2unit=1.):
12361263
if self.traces is None:
12371264
logger.warning(
12381265
"Spike waveforms could not be extracted as the raw data file is not available.")
12391266
return
12401267

12411268
n_chunks_kept = 20 # TODO: better choice
12421269
nst = max_n_spikes_per_template
1243-
nc = max_n_channels
1270+
nc = max_n_channels or self.n_closest_channels
1271+
nc = max(nc, self.n_closest_channels)
1272+
12441273
assert nst > 0
12451274
assert nc > 0
12461275

@@ -1275,7 +1304,7 @@ def save_spikes_subset_waveforms(self, max_n_spikes_per_template=None, max_n_cha
12751304
# Extract waveforms from the raw data on a chunk by chunk basis.
12761305
export_waveforms(
12771306
path, self.traces, self.spike_samples[spike_ids], spike_channels,
1278-
n_samples_waveforms=self.n_samples_waveforms)
1307+
n_samples_waveforms=self.n_samples_waveforms, sample2unit=sample2unit)
12791308

12801309
# Reload spike waveforms.
12811310
self.spike_waveforms = self._load_spike_waveforms()

phylib/io/tests/test_alf.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self, tempdir):
3636
self.nt = 5
3737
self.ncd = 1000
3838
np.save(p / 'spike_times.npy', .01 * np.cumsum(nr.exponential(size=self.ns)))
39-
np.save(p / 'spike_clusters.npy', nr.randint(low=0, high=self.nt, size=self.ns))
39+
np.save(p / 'spike_clusters.npy', nr.randint(low=1, high=self.nt, size=self.ns))
4040
shutil.copy(p / 'spike_clusters.npy', p / 'spike_templates.npy')
4141
np.save(p / 'amplitudes.npy', nr.uniform(low=0.5, high=1.5, size=self.ns))
4242
np.save(p / 'channel_positions.npy', np.c_[np.arange(self.nc), np.zeros(self.nc)])
@@ -174,16 +174,22 @@ def check_conversion_output():
174174
assert f.exists()
175175

176176
# makes sure the output dimensions match (especially clusters which should be 4)
177-
cl_shape = [np.load(f).shape[0] for f in new_files if f.name.startswith('clusters.') and
178-
f.name.endswith('.npy')]
177+
cl_shape = []
178+
for f in new_files:
179+
if f.name.startswith('clusters.') and f.name.endswith('.npy'):
180+
cl_shape.append(np.load(f).shape[0])
181+
elif f.name.startswith('clusters.') and f.name.endswith('.csv'):
182+
with open(f) as fid:
183+
cl_shape.append(len(fid.readlines()) - 1)
179184
sp_shape = [np.load(f).shape[0] for f in new_files if f.name.startswith('spikes.')]
180185
ch_shape = [np.load(f).shape[0] for f in new_files if f.name.startswith('channels.')]
186+
181187
assert len(set(cl_shape)) == 1
182188
assert len(set(sp_shape)) == 1
183189
assert len(set(ch_shape)) == 1
184190

185191
dur = np.load(next(out_path.glob('clusters.peakToTrough*.npy')))
186-
assert np.all(dur == np.array([18., -1., 9.5, 2.5, -2.]))
192+
assert np.all(dur == np.array([-9.5, 3., 13., -4.5, -2.5]))
187193

188194
def read_after_write():
189195
model = TemplateModel(dir_path=out_path, dat_path=dataset.dat_path,

phylib/io/traces.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -641,20 +641,21 @@ def iter_waveforms(traces, spike_samples, spike_channels, n_samples_waveforms=No
641641

642642

643643
def export_waveforms(
644-
path, traces, spike_samples, spike_channels, n_samples_waveforms=None, cache=False):
644+
path, traces, spike_samples, spike_channels, n_samples_waveforms=None, cache=False,
645+
sample2unit=1):
645646
"""Export a selection of spike waveforms to a npy file by iterating over the data on a chunk
646647
by chunk basis."""
647648
n_spikes = len(spike_samples)
648649
spike_channels = np.asarray(spike_channels, dtype=np.int32)
649650
n_channels_loc = spike_channels.shape[1]
650651
shape = (n_spikes, n_samples_waveforms, n_channels_loc)
651-
652-
writer = NpyWriter(path, shape, traces.dtype)
652+
dtype = traces.dtype if sample2unit is None else float
653+
writer = NpyWriter(path, shape, dtype)
653654
size_written = 0
654655
for waveforms in iter_waveforms(
655656
traces, spike_samples, spike_channels, n_samples_waveforms=n_samples_waveforms,
656657
cache=cache):
657-
writer.append(waveforms)
658+
writer.append(waveforms * sample2unit)
658659
size_written += waveforms.size
659660
writer.close()
660661
assert prod(shape) == size_written

phylib/stats/ccg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def correlograms(
148148

149149
# At a given shift, the mask precises which spikes have matching spikes
150150
# within the correlogram time window.
151-
mask = np.ones_like(spike_samples, dtype=np.bool)
151+
mask = np.ones_like(spike_samples, dtype=bool)
152152

153153
correlograms = _create_correlograms_array(n_clusters, winsize_bins)
154154

phylib/utils/_types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
#------------------------------------------------------------------------------
1616

1717
_ACCEPTED_ARRAY_DTYPES = (
18-
np.float, np.float32, np.float64, np.int, np.int8, np.int16, np.uint8, np.uint16,
19-
np.int32, np.int64, np.uint32, np.uint64, np.bool)
18+
float, np.float32, np.float64, int, np.int8, np.int16, np.uint8, np.uint16,
19+
np.int32, np.int64, np.uint32, np.uint64, bool)
2020

2121

2222
class Bunch(dict):

phylib/utils/tests/test_types.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,13 @@ def _check(arr):
8888
_check(_as_array(3.))
8989
_check(_as_array([3]))
9090

91-
_check(_as_array(3, np.float))
92-
_check(_as_array(3., np.float))
93-
_check(_as_array([3], np.float))
91+
_check(_as_array(3, float))
92+
_check(_as_array(3., float))
93+
_check(_as_array([3], float))
9494
_check(_as_array(np.array([3])))
9595
with raises(ValueError):
96-
_check(_as_array(np.array([3]), dtype=np.object))
97-
_check(_as_array(np.array([3]), np.float))
96+
_check(_as_array(np.array([3]), dtype=object))
97+
_check(_as_array(np.array([3]), float))
9898

9999
assert _as_array(None) is None
100100
assert not _is_array_like(None)

0 commit comments

Comments
 (0)