@@ -412,6 +412,18 @@ def _load_data(self):
412412 self .n_samples_waveforms = 0
413413 self .n_channels_loc = 0
414414
415+ # Clusters waveforms
416+ if np .all (self .spike_clusters == self .spike_templates ):
417+ self .merge_map = {}
418+ self .nan_clusters = []
419+ self .sparse_clusters = self .sparse_templates
420+ self .n_clusters = self .spike_templates .max () + 1
421+ else :
422+ if self .sparse_templates .cols is None :
423+ self .merge_map , self .nan_clusters = self .get_merge_map ()
424+ self .sparse_clusters = self .cluster_waveforms ()
425+ self .n_clusters = self .spike_clusters .max () + 1
426+
415427 # Spike waveforms (optional, otherwise fetched from raw data as needed).
416428 self .spike_waveforms = self ._load_spike_waveforms ()
417429
@@ -861,12 +873,12 @@ def _template_n_channels(self, template_id, n_channels):
861873 channel_ids += [- 1 ] * (n_channels - len (channel_ids ))
862874 return channel_ids
863875
864- def _get_template_dense (self , template_id , channel_ids = None , amplitude_threshold = None ):
876+ def _get_template_dense (self , template_id , channel_ids = None , amplitude_threshold = None , unwhiten = True ):
865877 """Return data for one template."""
866878 if not self .sparse_templates :
867879 return
868880 template_w = self .sparse_templates .data [template_id , ...]
869- template = self ._unwhiten (template_w ).astype (np .float32 )
881+ template = self ._unwhiten (template_w ).astype (np .float32 ) if unwhiten else template_w
870882 assert template .ndim == 2
871883 channel_ids_ , amplitude , best_channel = self ._find_best_channels (
872884 template , amplitude_threshold = amplitude_threshold )
@@ -881,7 +893,7 @@ def _get_template_dense(self, template_id, channel_ids=None, amplitude_threshold
881893 channel_ids = channel_ids ,
882894 )
883895
884- def _get_template_sparse (self , template_id ):
896+ def _get_template_sparse (self , template_id , unwhiten = True ):
885897 data , cols = self .sparse_templates .data , self .sparse_templates .cols
886898 assert cols is not None
887899 template_w , channel_ids = data [template_id ], cols [template_id ]
@@ -902,7 +914,7 @@ def _get_template_sparse(self, template_id):
902914 channel_ids = channel_ids .astype (np .uint32 )
903915
904916 # Unwhiten.
905- template = self ._unwhiten (template_w , channel_ids = channel_ids )
917+ template = self ._unwhiten (template_w , channel_ids = channel_ids ) if unwhiten else template_w
906918 template = template .astype (np .float32 )
907919 assert template .ndim == 2
908920 assert template .shape [1 ] == len (channel_ids )
@@ -920,17 +932,31 @@ def _get_template_sparse(self, template_id):
920932 )
921933 return out
922934
def get_merge_map(self):
    """Map every cluster id to the template ids its spikes originate from.

    `self.spike_clusters` and `self.spike_templates` are compared spike by spike.

    :return: tuple `(merge_map, nan_clusters)` where

        - `merge_map` is a dict with one key per integer from 0 to
          `max(spike_clusters)`; each value is the ascending list of template ids
          that contribute at least one spike to that cluster (an empty list when
          no spike carries that cluster id);
        - `nan_clusters` is an integer array with the cluster ids whose list is
          empty.
    """
    spike_clusters = np.asarray(self.spike_clusters).ravel().astype(np.int64)
    spike_templates = np.asarray(self.spike_templates).ravel().astype(np.int64)
    if spike_clusters.size == 0:
        # No spikes at all: nothing to map (np.max raises on an empty array).
        return {}, np.array([], dtype=np.int64)

    merge_map = {cluster: [] for cluster in range(int(spike_clusters.max()) + 1)}

    # One row per distinct (cluster, template) pair. The rows come back sorted
    # lexicographically, so every list is filled in ascending template order
    # without re-scanning the spike arrays once per template.
    pairs = np.unique(np.stack((spike_clusters, spike_templates), axis=1), axis=0)
    for cluster, template in pairs.tolist():
        merge_map[cluster].append(template)

    nan_clusters = np.array(
        [cluster for cluster, templates in merge_map.items() if not templates],
        dtype=np.int64)

    return merge_map, nan_clusters
948+
923949 #--------------------------------------------------------------------------
924950 # Data access methods
925951 #--------------------------------------------------------------------------
926952
def get_template(self, template_id, channel_ids=None, amplitude_threshold=None, unwhiten=True):
    """Return the data of one template, read from the sparse or the dense store.

    :param template_id: index of the requested template
    :param channel_ids: forwarded to the dense getter only
    :param amplitude_threshold: forwarded to the dense getter only
    :param unwhiten: forwarded to whichever getter is used
    :return: whatever `_get_template_sparse` / `_get_template_dense` returns
    """
    templates = self.sparse_templates
    stored_sparse = bool(templates) and templates.cols is not None
    if not stored_sparse:
        # Dense storage (or no templates at all): the dense getter decides.
        return self._get_template_dense(
            template_id,
            channel_ids=channel_ids,
            amplitude_threshold=amplitude_threshold,
            unwhiten=unwhiten,
        )
    return self._get_template_sparse(template_id, unwhiten=unwhiten)
934960
935961 def get_waveforms (self , spike_ids , channel_ids = None ):
936962 """Return spike waveforms on specified channels."""
@@ -1047,7 +1073,7 @@ def get_depths(self):
10471073 # take only first component
10481074 features = self .sparse_features .data [ispi , :, 0 ]
10491075 features = np .maximum (features , 0 ) ** 2 # takes only positive values into account
1050- ichannels = self .sparse_features .cols [self .spike_clusters [ispi ]].astype (np .uint32 )
1076+ ichannels = self .sparse_features .cols [self .spike_templates [ispi ]].astype (np .uint32 ) ## TODO this should be templates, otherwise won't work
10511077 # features = np.square(self.sparse_features.data[ispi, :, 0])
10521078 # ichannels = self.sparse_features.cols[self.spike_templates[ispi]].astype(np.int64)
10531079 ypos = self .channel_positions [ichannels , 1 ]
@@ -1059,7 +1085,7 @@ def get_depths(self):
10591085 break
10601086 return spikes_depths
10611087
1062- def get_amplitudes_true (self , sample2unit = 1. ):
1088+ def get_amplitudes_true (self , sample2unit = 1. , use = 'templates' ):
10631089 """Convert spike amplitude values to input amplitudes units
10641090 via scaling by unwhitened template waveform.
10651091 :param sample2unit float: factor to convert the raw data to a physical unit (defaults 1.)
@@ -1074,26 +1100,35 @@ def get_amplitudes_true(self, sample2unit=1.):
10741100 # spike_amp = ks2_spike_amps * maxmin(inv_whitening(ks2_template_amps))
10751101 # to rescale the template,
10761102
1103+ if use == 'clusters' :
1104+ sparse = self .sparse_clusters
1105+ spikes = self .spike_clusters
1106+ n_wav = self .n_clusters
1107+ else :
1108+ sparse = self .sparse_templates
1109+ spikes = self .spike_templates
1110+ n_wav = self .n_templates
1111+
10771112 # unwhiten template waveforms on their channels of max amplitude
1078- if self . sparse_templates .cols :
1113+ if sparse .cols :
10791114 raise NotImplementedError
10801115 # apply the inverse whitening matrix to the template
1081- templates_wfs = np .zeros_like (self . sparse_templates .data ) # nt, ns, nc
1082- for n in np .arange (self . n_templates ):
1083- templates_wfs [n , :, :] = np .matmul (self . sparse_templates .data [n , :, :], self .wmi )
1116+ templates_wfs = np .zeros_like (sparse .data ) # nt, ns, nc
1117+ for n in np .arange (n_wav ):
1118+ templates_wfs [n , :, :] = np .matmul (sparse .data [n , :, :], self .wmi )
10841119
10851120 # The amplitude on each channel is the positive peak minus the negative
10861121 templates_ch_amps = np .max (templates_wfs , axis = 1 ) - np .min (templates_wfs , axis = 1 )
10871122
10881123 # The template arbitrary unit amplitude is the amplitude of its largest channel
10891124 # (but see below for true tempAmps)
10901125 templates_amps_au = np .max (templates_ch_amps , axis = 1 )
1091- spike_amps = templates_amps_au [self . spike_templates ] * self .amplitudes
1126+ spike_amps = templates_amps_au [spikes ] * self .amplitudes
10921127
10931128 with np .errstate (divide = 'ignore' , invalid = 'ignore' ):
10941129 # take the average spike amplitude per template
1095- templates_amps_v = (np .bincount (self . spike_templates , weights = spike_amps ) /
1096- np .bincount (self . spike_templates ))
1130+ templates_amps_v = (np .bincount (spikes , weights = spike_amps ) /
1131+ np .bincount (spikes ))
10971132 # scale back the template according to the spikes units
10981133 templates_physical_unit = templates_wfs * (templates_amps_v / templates_amps_au
10991134 )[:, np .newaxis , np .newaxis ]
@@ -1167,18 +1202,18 @@ def get_template_waveforms(self, template_id):
11671202 template = self .get_template (template_id )
11681203 return template .template if template else None
11691204
1170- def get_cluster_mean_waveforms (self , cluster_id ):
1205+ def get_cluster_mean_waveforms (self , cluster_id , unwhiten = True ):
11711206 """Return the mean template waveforms of a cluster, as a weighted average of the
11721207 template waveforms from which the cluster originates from."""
11731208 count = self .get_template_counts (cluster_id )
11741209 best_template = np .argmax (count )
11751210 template_ids = np .nonzero (count )[0 ]
11761211 count = count [template_ids ]
11771212 # Get local channels of the best template for the given cluster.
1178- template = self .get_template (best_template )
1213+ template = self .get_template (best_template , unwhiten = unwhiten )
11791214 channel_ids = template .channel_ids
11801215 # Get all templates from which this cluster stems from.
1181- templates = [self .get_template (template_id ) for template_id in template_ids ]
1216+ templates = [self .get_template (template_id , unwhiten = unwhiten ) for template_id in template_ids ]
11821217 # Construct the waveforms array.
11831218 ns = self .n_samples_waveforms
11841219 data = np .zeros ((len (template_ids ), ns , self .n_channels ))
@@ -1205,16 +1240,27 @@ def get_cluster_spike_waveforms(self, cluster_id):
@property
def templates_channels(self):
    """Peak channel of every template, as computed by `_channels`."""
    templates = self.sparse_templates
    return self._channels(templates)
1244+
@property
def clusters_channels(self):
    """Peak channel of every cluster waveform, as computed by `_channels`."""
    return self._channels(self.sparse_clusters)
1250+
def _channels(self, sparse):
    """Return the peak channel of every waveform held by `sparse`.

    :param sparse: object exposing `data`, a 3D array whose first axis indexes the
        waveforms, and `cols`, either None or a 2D array with one row per waveform
    :return: vector with one channel per waveform
    """
    waveforms = sparse.data
    n_waveforms, _, _ = waveforms.shape
    if sparse.cols is not None:
        # when the templates are sparse, the first channel is the highest amplitude channel
        peak_channels = sparse.cols[:, 0]
    else:
        # Dense storage: pick the channel with the largest max-minus-min over axis 1.
        peak_to_peak = waveforms.max(axis=1) - waveforms.min(axis=1)
        peak_channels = np.argmax(peak_to_peak, axis=1)
    assert peak_channels.shape == (n_waveforms,)
    return peak_channels
12171262
1263+
12181264 @property
12191265 def templates_probes (self ):
12201266 """Returns a vector of probe index for all templates"""
@@ -1223,16 +1269,32 @@ def templates_probes(self):
@property
def templates_amplitudes(self):
    """Average spike amplitude of every template, as computed by `_amplitudes`."""
    spike_templates = self.spike_templates
    return self._amplitudes(spike_templates)
1273+
@property
def clusters_amplitudes(self):
    """Average spike amplitude of every cluster, as computed by `_amplitudes`."""
    spike_clusters = self.spike_clusters
    return self._amplitudes(spike_clusters)
1278+
def _amplitudes(self, tmp):
    """Return the mean of `self.amplitudes` for every id present in `tmp`.

    :param tmp: vector of non-negative integer ids, one per spike (for instance
        `self.spike_templates` or `self.spike_clusters`)
    :return: vector with one mean amplitude per distinct id of `tmp`, in ascending
        id order
    """
    ids = np.unique(tmp)
    # Restricting both bincounts to the ids that actually occur guarantees every
    # count is at least 1, so the division below can never hit a zero (integer
    # counts cannot be NaN either, hence no NaN patching is required).
    counts = np.bincount(tmp)[ids]
    sums = np.bincount(tmp, weights=self.amplitudes)[ids]
    return sums / counts
12311285
@property
def templates_waveforms_durations(self):
    """Waveform duration (ms) of every template, as computed by `_waveform_durations`."""
    template_data = self.sparse_templates.data
    return self._waveform_durations(template_data)
1290+
@property
def clusters_waveforms_durations(self):
    """Waveform duration (ms) of every cluster, as computed by `_waveform_durations`."""
    return self._waveform_durations(self.sparse_clusters.data)
1296+
1297+ def _waveform_durations (self , tmp ):
12361298 n_templates , n_samples , n_channels = tmp .shape
12371299 # Compute the peak channels for each template.
12381300 template_peak_channels = np .argmax (tmp .max (axis = 1 ) - tmp .min (axis = 1 ), axis = 1 )
@@ -1241,6 +1303,23 @@ def templates_waveforms_durations(self):
12411303 (n_templates , n_channels ), mode = 'raise' , order = 'C' )
12421304 return durations .flatten ()[ind ].astype (np .float64 ) / self .sample_rate * 1e3
12431305
def cluster_waveforms(self):
    """Build one dense waveform per cluster from the template waveforms.

    The clusters are taken from `self.merge_map`:

    - a cluster that stems from several templates receives the result of
      `get_cluster_mean_waveforms(cluster, unwhiten=False)`, written on the
      channels that result reports;
    - a cluster that stems from a single template receives that template's
      waveform as stored in `self.sparse_templates.data`;
    - a cluster without any template is left filled with zeros.

    :return: Bunch with `data`, an array of shape
        `(max(cluster_ids) + 1, n_samples_waveforms, n_channels)`, and `cols=None`
    :raises NotImplementedError: when the templates are stored sparsely
        (`self.sparse_templates.cols` is not None)
    """
    if self.sparse_templates.cols is not None:
        # Sparse templates only hold a subset of channels per template, so their
        # data cannot be copied channel-for-channel into the dense output below.
        raise NotImplementedError(
            "cluster_waveforms is only implemented for dense templates")

    n_samples = self.n_samples_waveforms
    # TODO can be self.n_clusters
    data = np.zeros((np.max(self.cluster_ids) + 1, n_samples, self.n_channels))
    template_data = self.sparse_templates.data
    for cluster, template_ids in self.merge_map.items():
        if len(template_ids) > 1:
            mean = self.get_cluster_mean_waveforms(cluster, unwhiten=False)
            # data[cluster] is a (samples, channels) view into `data`, so writing
            # the listed channel columns through it fills the output in place.
            data[cluster][:, mean.channel_ids] = mean.mean_waveforms
        elif len(template_ids) == 1:
            data[cluster, :, :] = template_data[template_ids[0], :, :]

    return Bunch(data=data, cols=None)
1322+
12441323 #--------------------------------------------------------------------------
12451324 # Saving methods
12461325 #--------------------------------------------------------------------------
0 commit comments