Skip to content

Commit 022436f

Browse files
committed
alf conversion after merge
1 parent 00a09b0 commit 022436f

File tree

2 files changed

+140
-32
lines changed

2 files changed

+140
-32
lines changed

phylib/io/alf.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -174,16 +174,19 @@ def make_cluster_objects(self):
174174
"""Create clusters.channels, clusters.waveformsDuration and clusters.amps"""
175175
peak_channel_path = self.dir_path / 'clusters.channels.npy'
176176
if not peak_channel_path.exists():
177-
self._save_npy(peak_channel_path.name, self.model.templates_channels)
177+
# self._save_npy(peak_channel_path.name, self.model.templates_channels)
178+
self._save_npy(peak_channel_path.name, self.model.clusters_channels)
178179

179180
waveform_duration_path = self.dir_path / 'clusters.peakToTrough.npy'
180181
if not waveform_duration_path.exists():
181-
self._save_npy(waveform_duration_path.name, self.model.templates_waveforms_durations)
182+
# self._save_npy(waveform_duration_path.name, self.model.templates_waveforms_durations)
183+
self._save_npy(waveform_duration_path.name, self.model.clusters_waveforms_durations)
182184

183185
# group by average over cluster number
184-
camps = np.zeros(self.model.templates_channels.shape[0],) * np.nan
185-
camps[self.cluster_ids] = self.model.templates_amplitudes
186-
amps_path = self.dir_path / 'clusters.amps.npy'
186+
# camps = np.zeros(self.model.templates_channels.shape[0],) * np.nan
187+
camps = np.zeros(self.model.clusters_channels.shape[0], ) * np.nan
188+
camps[self.cluster_ids] = self.model.clusters_amplitudes
189+
amps_path = self.dir_path / 'clusters.amps.npy' # TODO these amplitudes are not on the same scale as the spike amps problem?
187190
self._save_npy(amps_path.name, camps * self.ampfactor)
188191

189192
# clusters uuids
@@ -233,7 +236,7 @@ def make_template_and_spikes_objects(self):
233236
# and not seconds
234237
self._save_npy('spikes.times.npy', self.model.spike_times)
235238
self._save_npy('spikes.samples.npy', self.model.spike_samples)
236-
spike_amps, templates_v, template_amps = self.model.get_amplitudes_true(self.ampfactor)
239+
spike_amps, templates_v, template_amps = self.model.get_amplitudes_true(self.ampfactor, use='templates')
237240
self._save_npy('spikes.amps.npy', spike_amps)
238241
self._save_npy('templates.amps.npy', template_amps)
239242

@@ -257,9 +260,35 @@ def make_template_and_spikes_objects(self):
257260
templates[t, ...] = templates_v[t, :][:, templates_inds[t, :]]
258261
np.save(self.out_path.joinpath('templates.waveforms'), templates)
259262
np.save(self.out_path.joinpath('templates.waveformsChannels'), templates_inds)
263+
264+
_, clusters_v, cluster_amps = self.model.get_amplitudes_true(self.ampfactor, use='clusters')
265+
n_clusters, n_wavsamps, nchall = clusters_v.shape
266+
# for some datasets, 32 may be too much
267+
ncw = min(self.model.n_closest_channels, nchall)
268+
assert(n_clusters == self.model.n_clusters)
269+
templates = np.zeros((n_clusters, n_wavsamps, ncw), dtype=np.float32)
270+
templates_inds = np.zeros((n_clusters, ncw), dtype=np.int32)
271+
# for each template, find the nearest channels to keep (one the same probe...)
272+
for t in np.arange(n_clusters):
273+
# here we need to fill with nans if it doesn't exists, but then can no longet be int (sorry) # or have it all 0
274+
channels = self.model.clusters_channels
275+
276+
current_probe = self.model.channel_probes[channels[t]]
277+
channel_distance = np.sum(np.abs(
278+
self.model.channel_positions -
279+
self.model.channel_positions[channels[t]]), axis=1)
280+
channel_distance[self.model.channel_probes != current_probe] += np.inf
281+
templates_inds[t, :] = np.argsort(channel_distance)[:ncw]
282+
templates[t, ...] = clusters_v[t, :][:, templates_inds[t, :]]
260283
np.save(self.out_path.joinpath('clusters.waveforms'), templates)
261284
np.save(self.out_path.joinpath('clusters.waveformsChannels'), templates_inds)
262285

286+
# This should really be here
287+
np.save(self.out_path.joinpath('clusters.amps'), cluster_amps)
288+
289+
# np.save(self.out_path.joinpath('clusters.waveforms'), templates)
290+
# np.save(self.out_path.joinpath('clusters.waveformsChannels'), templates_inds)
291+
263292
def rename_with_label(self):
264293
"""add the label as an ALF part name before the extension if any label provided"""
265294
if not self.label:

phylib/io/model.py

Lines changed: 105 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,18 @@ def _load_data(self):
412412
self.n_samples_waveforms = 0
413413
self.n_channels_loc = 0
414414

415+
# Clusters waveforms
416+
if np.all(self.spike_clusters == self.spike_templates):
417+
self.merge_map = {}
418+
self.nan_clusters = []
419+
self.sparse_clusters = self.sparse_templates
420+
self.n_clusters = self.spike_templates.max() + 1
421+
else:
422+
if self.sparse_templates.cols is None:
423+
self.merge_map, self.nan_clusters = self.get_merge_map()
424+
self.sparse_clusters = self.cluster_waveforms()
425+
self.n_clusters = self.spike_clusters.max() + 1
426+
415427
# Spike waveforms (optional, otherwise fetched from raw data as needed).
416428
self.spike_waveforms = self._load_spike_waveforms()
417429

@@ -861,12 +873,12 @@ def _template_n_channels(self, template_id, n_channels):
861873
channel_ids += [-1] * (n_channels - len(channel_ids))
862874
return channel_ids
863875

864-
def _get_template_dense(self, template_id, channel_ids=None, amplitude_threshold=None):
876+
def _get_template_dense(self, template_id, channel_ids=None, amplitude_threshold=None, unwhiten=True):
865877
"""Return data for one template."""
866878
if not self.sparse_templates:
867879
return
868880
template_w = self.sparse_templates.data[template_id, ...]
869-
template = self._unwhiten(template_w).astype(np.float32)
881+
template = self._unwhiten(template_w).astype(np.float32) if unwhiten else template_w
870882
assert template.ndim == 2
871883
channel_ids_, amplitude, best_channel = self._find_best_channels(
872884
template, amplitude_threshold=amplitude_threshold)
@@ -881,7 +893,7 @@ def _get_template_dense(self, template_id, channel_ids=None, amplitude_threshold
881893
channel_ids=channel_ids,
882894
)
883895

884-
def _get_template_sparse(self, template_id):
896+
def _get_template_sparse(self, template_id, unwhiten=True):
885897
data, cols = self.sparse_templates.data, self.sparse_templates.cols
886898
assert cols is not None
887899
template_w, channel_ids = data[template_id], cols[template_id]
@@ -902,7 +914,7 @@ def _get_template_sparse(self, template_id):
902914
channel_ids = channel_ids.astype(np.uint32)
903915

904916
# Unwhiten.
905-
template = self._unwhiten(template_w, channel_ids=channel_ids)
917+
template = self._unwhiten(template_w, channel_ids=channel_ids) if unwhiten else template_w
906918
template = template.astype(np.float32)
907919
assert template.ndim == 2
908920
assert template.shape[1] == len(channel_ids)
@@ -920,17 +932,31 @@ def _get_template_sparse(self, template_id):
920932
)
921933
return out
922934

935+
def get_merge_map(self):
936+
""""Gets the merge mapping for between spikes.clusters and spikes.templates"""
937+
inverse_mapping_dict = {key: [] for key in range(np.max(self.spike_clusters) + 1)}
938+
for temp in np.unique(self.spike_templates):
939+
idx = np.where(self.spike_templates == temp)[0]
940+
new_idx = self.spike_clusters[idx]
941+
mapping = np.unique(new_idx)
942+
for n in mapping:
943+
inverse_mapping_dict[n].append(temp)
944+
945+
nan_idx = np.array([idx for idx, val in inverse_mapping_dict.items() if len(val) == 0])
946+
947+
return inverse_mapping_dict, nan_idx
948+
923949
#--------------------------------------------------------------------------
924950
# Data access methods
925951
#--------------------------------------------------------------------------
926952

927-
def get_template(self, template_id, channel_ids=None, amplitude_threshold=None):
953+
def get_template(self, template_id, channel_ids=None, amplitude_threshold=None, unwhiten=True):
928954
"""Get data about a template."""
929955
if self.sparse_templates and self.sparse_templates.cols is not None:
930-
return self._get_template_sparse(template_id)
956+
return self._get_template_sparse(template_id, unwhiten=unwhiten)
931957
else:
932958
return self._get_template_dense(
933-
template_id, channel_ids=channel_ids, amplitude_threshold=amplitude_threshold)
959+
template_id, channel_ids=channel_ids, amplitude_threshold=amplitude_threshold, unwhiten=unwhiten)
934960

935961
def get_waveforms(self, spike_ids, channel_ids=None):
936962
"""Return spike waveforms on specified channels."""
@@ -1047,7 +1073,7 @@ def get_depths(self):
10471073
# take only first component
10481074
features = self.sparse_features.data[ispi, :, 0]
10491075
features = np.maximum(features, 0) ** 2 # takes only positive values into account
1050-
ichannels = self.sparse_features.cols[self.spike_clusters[ispi]].astype(np.uint32)
1076+
ichannels = self.sparse_features.cols[self.spike_templates[ispi]].astype(np.uint32) ## TODO this should be templates, otherwise won't work
10511077
# features = np.square(self.sparse_features.data[ispi, :, 0])
10521078
# ichannels = self.sparse_features.cols[self.spike_templates[ispi]].astype(np.int64)
10531079
ypos = self.channel_positions[ichannels, 1]
@@ -1059,7 +1085,7 @@ def get_depths(self):
10591085
break
10601086
return spikes_depths
10611087

1062-
def get_amplitudes_true(self, sample2unit=1.):
1088+
def get_amplitudes_true(self, sample2unit=1., use='templates'):
10631089
"""Convert spike amplitude values to input amplitudes units
10641090
via scaling by unwhitened template waveform.
10651091
:param sample2unit float: factor to convert the raw data to a physical unit (defaults 1.)
@@ -1074,26 +1100,35 @@ def get_amplitudes_true(self, sample2unit=1.):
10741100
# spike_amp = ks2_spike_amps * maxmin(inv_whitening(ks2_template_amps))
10751101
# to rescale the template,
10761102

1103+
if use == 'clusters':
1104+
sparse = self.sparse_clusters
1105+
spikes = self.spike_clusters
1106+
n_wav = self.n_clusters
1107+
else:
1108+
sparse = self.sparse_templates
1109+
spikes = self.spike_templates
1110+
n_wav = self.n_templates
1111+
10771112
# unwhiten template waveforms on their channels of max amplitude
1078-
if self.sparse_templates.cols:
1113+
if sparse.cols:
10791114
raise NotImplementedError
10801115
# apply the inverse whitening matrix to the template
1081-
templates_wfs = np.zeros_like(self.sparse_templates.data) # nt, ns, nc
1082-
for n in np.arange(self.n_templates):
1083-
templates_wfs[n, :, :] = np.matmul(self.sparse_templates.data[n, :, :], self.wmi)
1116+
templates_wfs = np.zeros_like(sparse.data) # nt, ns, nc
1117+
for n in np.arange(n_wav):
1118+
templates_wfs[n, :, :] = np.matmul(sparse.data[n, :, :], self.wmi)
10841119

10851120
# The amplitude on each channel is the positive peak minus the negative
10861121
templates_ch_amps = np.max(templates_wfs, axis=1) - np.min(templates_wfs, axis=1)
10871122

10881123
# The template arbitrary unit amplitude is the amplitude of its largest channel
10891124
# (but see below for true tempAmps)
10901125
templates_amps_au = np.max(templates_ch_amps, axis=1)
1091-
spike_amps = templates_amps_au[self.spike_templates] * self.amplitudes
1126+
spike_amps = templates_amps_au[spikes] * self.amplitudes
10921127

10931128
with np.errstate(divide='ignore', invalid='ignore'):
10941129
# take the average spike amplitude per template
1095-
templates_amps_v = (np.bincount(self.spike_templates, weights=spike_amps) /
1096-
np.bincount(self.spike_templates))
1130+
templates_amps_v = (np.bincount(spikes, weights=spike_amps) /
1131+
np.bincount(spikes))
10971132
# scale back the template according to the spikes units
10981133
templates_physical_unit = templates_wfs * (templates_amps_v / templates_amps_au
10991134
)[:, np.newaxis, np.newaxis]
@@ -1167,18 +1202,18 @@ def get_template_waveforms(self, template_id):
11671202
template = self.get_template(template_id)
11681203
return template.template if template else None
11691204

1170-
def get_cluster_mean_waveforms(self, cluster_id):
1205+
def get_cluster_mean_waveforms(self, cluster_id, unwhiten=True):
11711206
"""Return the mean template waveforms of a cluster, as a weighted average of the
11721207
template waveforms from which the cluster originates from."""
11731208
count = self.get_template_counts(cluster_id)
11741209
best_template = np.argmax(count)
11751210
template_ids = np.nonzero(count)[0]
11761211
count = count[template_ids]
11771212
# Get local channels of the best template for the given cluster.
1178-
template = self.get_template(best_template)
1213+
template = self.get_template(best_template, unwhiten=unwhiten)
11791214
channel_ids = template.channel_ids
11801215
# Get all templates from which this cluster stems from.
1181-
templates = [self.get_template(template_id) for template_id in template_ids]
1216+
templates = [self.get_template(template_id, unwhiten=unwhiten) for template_id in template_ids]
11821217
# Construct the waveforms array.
11831218
ns = self.n_samples_waveforms
11841219
data = np.zeros((len(template_ids), ns, self.n_channels))
@@ -1205,16 +1240,27 @@ def get_cluster_spike_waveforms(self, cluster_id):
12051240
@property
12061241
def templates_channels(self):
12071242
"""Returns a vector of peak channels for all templates"""
1208-
tmp = self.sparse_templates.data
1243+
return self._channels(self.sparse_templates)
1244+
1245+
@property
1246+
def clusters_channels(self):
1247+
"""Returns a vector of peak channels for all templates"""
1248+
channels = self._channels(self.sparse_clusters)
1249+
return channels
1250+
1251+
def _channels(self, sparse):
1252+
# TODO document and better name
1253+
tmp = sparse.data
12091254
n_templates, n_samples, n_channels = tmp.shape
1210-
if self.sparse_templates.cols is None:
1255+
if sparse.cols is None:
12111256
template_peak_channels = np.argmax(tmp.max(axis=1) - tmp.min(axis=1), axis=1)
12121257
else:
12131258
# when the templates are sparse, the first channel is the highest amplitude channel
1214-
template_peak_channels = self.sparse_templates.cols[:, 0]
1259+
template_peak_channels = sparse.cols[:, 0]
12151260
assert template_peak_channels.shape == (n_templates,)
12161261
return template_peak_channels
12171262

1263+
12181264
@property
12191265
def templates_probes(self):
12201266
"""Returns a vector of probe index for all templates"""
@@ -1223,16 +1269,32 @@ def templates_probes(self):
12231269
@property
12241270
def templates_amplitudes(self):
12251271
"""Returns the average amplitude per cluster"""
1226-
tid = np.unique(self.spike_templates)
1227-
n = np.bincount(self.spike_templates)[tid]
1228-
a = np.bincount(self.spike_templates, weights=self.amplitudes)[tid]
1272+
return self._amplitudes(self.spike_templates)
1273+
1274+
@property
1275+
def clusters_amplitudes(self):
1276+
"""Returns the average amplitude per cluster"""
1277+
return self._amplitudes(self.spike_clusters)
1278+
1279+
def _amplitudes(self, tmp):
1280+
tid = np.unique(tmp)
1281+
n = np.bincount(tmp)[tid]
1282+
a = np.bincount(tmp, weights=self.amplitudes)[tid]
12291283
n[np.isnan(n)] = 1
12301284
return a / n
12311285

12321286
@property
12331287
def templates_waveforms_durations(self):
12341288
"""Returns a vector of waveform durations (ms) for all templates"""
1235-
tmp = self.sparse_templates.data
1289+
return self._waveform_durations(self.sparse_templates.data)
1290+
1291+
@property
1292+
def clusters_waveforms_durations(self):
1293+
"""Returns a vector of waveform durations (ms) for all templates"""
1294+
waveform_duration = self._waveform_durations(self.sparse_clusters.data)
1295+
return waveform_duration
1296+
1297+
def _waveform_durations(self, tmp):
12361298
n_templates, n_samples, n_channels = tmp.shape
12371299
# Compute the peak channels for each template.
12381300
template_peak_channels = np.argmax(tmp.max(axis=1) - tmp.min(axis=1), axis=1)
@@ -1241,6 +1303,23 @@ def templates_waveforms_durations(self):
12411303
(n_templates, n_channels), mode='raise', order='C')
12421304
return durations.flatten()[ind].astype(np.float64) / self.sample_rate * 1e3
12431305

1306+
def cluster_waveforms(self):
1307+
"""
1308+
Computes the cluster waveforms for split and merged clusters
1309+
:return:
1310+
"""
1311+
# Only non sparse implementation
1312+
ns = self.n_samples_waveforms # TODO put not implemented warning
1313+
data = np.zeros((np.max(self.cluster_ids) + 1, ns, self.n_channels)) # TODO can be self.n_clusters
1314+
for clust, val in self.merge_map.items():
1315+
if len(val) > 1:
1316+
mean_waveform = self.get_cluster_mean_waveforms(clust, unwhiten=False)
1317+
data[clust, :, mean_waveform.channel_ids] = np.swapaxes(mean_waveform.mean_waveforms, 0, 1)
1318+
elif len(val) == 1:
1319+
data[clust, :, :] = self.sparse_templates.data[val[0], :, :]
1320+
1321+
return Bunch(data=data, cols=None)
1322+
12441323
#--------------------------------------------------------------------------
12451324
# Saving methods
12461325
#--------------------------------------------------------------------------

0 commit comments

Comments
 (0)