@@ -239,7 +239,7 @@ def compute_features(waveforms):
239239 return features
240240
241241
242- #-----------------------------------------------------------------------------
242+ #------------------------------------------------------------------------------
243243# I/O util functions
244244#------------------------------------------------------------------------------
245245
@@ -683,10 +683,13 @@ def _load_templates(self):
683683 try :
684684 path = self ._find_path (
685685 'templates.npy' , 'templates.waveforms.npy' , 'templates.waveforms.*.npy' )
686- data = self ._read_array (path , mmap_mode = 'r' )
686+ data = self ._read_array (path , mmap_mode = 'r+ ' )
687687 data = np .atleast_3d (data )
688688 assert data .ndim == 3
689689 assert data .dtype in (np .float32 , np .float64 )
690+ # WARNING: this will load the full array in memory, might cause memory problems
691+ empty_templates = np .all (np .all (np .isnan (data ), axis = 1 ), axis = 1 )
692+ data [empty_templates , ...] = 0
690693 n_templates , n_samples , n_channels_loc = data .shape
691694 except IOError :
692695 return
@@ -818,6 +821,7 @@ def _find_best_channels(self, template, amplitude_threshold=None):
818821 # Compute the template amplitude on each channel.
819822 assert template .ndim == 2 # shape: (n_samples, n_channels)
820823 amplitude = template .max (axis = 0 ) - template .min (axis = 0 )
824+ assert not np .all (np .isnan (amplitude )), "Template is all NaN!"
821825 assert amplitude .ndim == 1 # shape: (n_channels,)
822826 # Find the peak channel.
823827 best_channel = np .argmax (amplitude )
0 commit comments