2727# File utils
2828#------------------------------------------------------------------------------
2929
30- NCH_WAVEFORMS = 32 # number of channels to be saved in templates.waveforms and channels.waveforms
30+ NSAMPLE_WAVEFORMS = 500 # number of waveforrms sampled out of the raw data
3131
3232_FILE_RENAMES = [ # file_in, file_out, squeeze (bool to squeeze vector from matlab in npy)
3333 ('params.py' , 'params.py' , None ),
34- ('cluster_metrics.csv ' , 'clusters.metrics.csv ' , None ),
34+ ('cluster_KSLabel.tsv ' , 'cluster_KSLabel.tsv ' , None ),
3535 ('spike_clusters.npy' , 'spikes.clusters.npy' , True ),
3636 ('spike_templates.npy' , 'spikes.templates.npy' , True ),
3737 ('channel_positions.npy' , 'channels.localCoordinates.npy' , False ),
4242 ('_phy_spikes_subset.channels.npy' , '_phy_spikes_subset.channels.npy' , False ),
4343 ('_phy_spikes_subset.spikes.npy' , '_phy_spikes_subset.spikes.npy' , False ),
4444 ('_phy_spikes_subset.waveforms.npy' , '_phy_spikes_subset.waveforms.npy' , False ),
45+ ('drift.depth_scale.npy' , 'drift.depth_scale.npy' , False ),
46+ ('drift.time_scale.npy' , 'drift.time_scale.npy' , False ),
47+ ('drift.um.npy' , 'drift.um.npy' , False ),
4548 # ('cluster_group.tsv', 'ks2/clusters.phyAnnotation.tsv', False), # todo check indexing, add2QC
4649]
4750
@@ -116,21 +119,23 @@ def convert(self, out_path, force=False, label='', ampfactor=1):
116119 if not self .out_path .exists ():
117120 self .out_path .mkdir ()
118121
119- with tqdm (desc = "Converting to ALF" , total = 95 ) as bar :
120- self .copy_files (force = force )
121- bar .update (10 )
122- self .make_spike_times_amplitudes ()
122+ with tqdm (desc = "Converting to ALF" , total = 125 ) as bar :
123123 bar .update (10 )
124124 self .make_cluster_objects ()
125125 bar .update (10 )
126126 self .make_channel_objects ()
127127 bar .update (5 )
128+ self .make_template_and_spikes_objects ()
129+ bar .update (30 )
130+ self .model .save_spikes_subset_waveforms (
131+ NSAMPLE_WAVEFORMS , sample2unit = self .ampfactor )
132+ bar .update (50 )
128133 self .make_depths ()
129134 bar .update (20 )
130- self .make_template_object ()
131- bar .update (30 )
132135 self .rm_files ()
133136 bar .update (10 )
137+ self .copy_files (force = force )
138+ bar .update (10 )
134139 self .rename_with_label ()
135140
136141 # Return the TemplateModel of the converted ALF dataset if the params.py file exists.
@@ -165,16 +170,8 @@ def _save_npy(self, filename, arr):
165170 """Save an array into a .npy file."""
166171 np .save (self .out_path / filename , arr )
167172
168- def make_spike_times_amplitudes (self ):
169- """We cannot just rename/copy spike_times.npy because it is in unit of
170- *samples*, and not in seconds."""
171- self ._save_npy ('spikes.times.npy' , self .model .spike_times )
172- self ._save_npy ('spikes.samples.npy' , self .model .spike_samples )
173- self ._save_npy ('spikes.amps.npy' , self .model .get_amplitudes_true () * self .ampfactor )
174-
175173 def make_cluster_objects (self ):
176174 """Create clusters.channels, clusters.waveformsDuration and clusters.amps"""
177-
178175 peak_channel_path = self .dir_path / 'clusters.channels.npy'
179176 if not peak_channel_path .exists ():
180177 self ._save_npy (peak_channel_path .name , self .model .templates_channels )
@@ -184,8 +181,8 @@ def make_cluster_objects(self):
184181 self ._save_npy (waveform_duration_path .name , self .model .templates_waveforms_durations )
185182
186183 # group by average over cluster number
187- camps = np .zeros (np . max ( self .cluster_ids ) - np . min ( self . cluster_ids ) + 1 ,) * np .nan
188- camps [self .cluster_ids - np . min ( self . cluster_ids ) ] = self .model .templates_amplitudes
184+ camps = np .zeros (self .model . templates_channels . shape [ 0 ] ,) * np .nan
185+ camps [self .cluster_ids ] = self .model .templates_amplitudes
189186 amps_path = self .dir_path / 'clusters.amps.npy'
190187 self ._save_npy (amps_path .name , camps * self .ampfactor )
191188
@@ -198,7 +195,7 @@ def make_cluster_objects(self):
198195 def make_channel_objects (self ):
199196 """If there is no rawInd file, create it"""
200197 rawInd_path = self .dir_path / 'channels.rawInd.npy'
201- rawInd = np .zeros_like (self .model .channel_probes ).astype (np . int )
198+ rawInd = np .zeros_like (self .model .channel_probes ).astype (int )
202199 channel_offset = 0
203200 for probe in np .unique (self .model .channel_probes ):
204201 ind = self .model .channel_probes == probe
@@ -225,20 +222,27 @@ def make_depths(self):
225222 spikes_depths = clusters_depths [spike_clusters ]
226223 else :
227224 spikes_depths = self .model .get_depths ()
228- # if PC features are provided, compute the depth as the weighted sum of coordinates
229-
230225 self ._save_npy ('spikes.depths.npy' , spikes_depths )
231226 self ._save_npy ('clusters.depths.npy' , clusters_depths )
232227
233- def make_template_object (self ):
228+ def make_template_and_spikes_objects (self ):
234229 """Creates the template waveforms sparse object
235230 Without manual curation, it also corresponds to clusters waveforms objects.
236231 """
232+ # "We cannot just rename/copy spike_times.npy because it is in unit of samples,
233+ # and not seconds
234+ self ._save_npy ('spikes.times.npy' , self .model .spike_times )
235+ self ._save_npy ('spikes.samples.npy' , self .model .spike_samples )
236+ spike_amps , templates_v , template_amps = self .model .get_amplitudes_true (self .ampfactor )
237+ self ._save_npy ('spikes.amps.npy' , spike_amps )
238+ self ._save_npy ('templates.amps.npy' , template_amps )
239+
237240 if self .model .sparse_templates .cols :
238241 raise (NotImplementedError ("Sparse template export to ALF not implemented yet" ))
239242 else :
240- n_templates , n_wavsamps , nchall = self .model .sparse_templates .data .shape
241- ncw = min (NCH_WAVEFORMS , nchall ) # for some datasets, 32 may be too much
243+ n_templates , n_wavsamps , nchall = templates_v .shape
244+ # for some datasets, 32 may be too much
245+ ncw = min (self .model .n_closest_channels , nchall )
242246 assert (n_templates == self .model .n_templates )
243247 templates = np .zeros ((n_templates , n_wavsamps , ncw ), dtype = np .float32 )
244248 templates_inds = np .zeros ((n_templates , ncw ), dtype = np .int32 )
@@ -250,10 +254,10 @@ def make_template_object(self):
250254 self .model .channel_positions [self .model .templates_channels [t ]]), axis = 1 )
251255 channel_distance [self .model .channel_probes != current_probe ] += np .inf
252256 templates_inds [t , :] = np .argsort (channel_distance )[:ncw ]
253- templates [t , ...] = self . model . sparse_templates . data [t , :][:, templates_inds [t , :]]
254- np .save (self .out_path .joinpath ('templates.waveforms' ), templates * self . ampfactor )
257+ templates [t , ...] = templates_v [t , :][:, templates_inds [t , :]]
258+ np .save (self .out_path .joinpath ('templates.waveforms' ), templates )
255259 np .save (self .out_path .joinpath ('templates.waveformsChannels' ), templates_inds )
256- np .save (self .out_path .joinpath ('clusters.waveforms' ), templates * self . ampfactor )
260+ np .save (self .out_path .joinpath ('clusters.waveforms' ), templates )
257261 np .save (self .out_path .joinpath ('clusters.waveformsChannels' ), templates_inds )
258262
259263 def rename_with_label (self ):
0 commit comments