 import torch.nn.functional as F
 import torch.optim as optim
 import torch.onnx
+import random

 from torch.autograd import Variable, Function
 from torch.utils.data import Dataset, DataLoader
@@ -306,7 +307,7 @@ class AudioDataset(Dataset):
     mini-batch training.
     """

-    def __init__(self, filename, config, keywords):
+    def __init__(self, filename, config, keywords, training=False):
         """ Initialize the AudioDataset from the given *.npz file """
         self.dataset = np.load(filename)
@@ -331,34 +332,59 @@ def __init__(self, filename, config, keywords):
         else:
             self.mean = None
             self.std = None
+
         self.label_names = self.dataset["labels"]
         self.keywords = keywords
         self.num_keywords = len(self.keywords)
         self.labels = self.to_long_vector()
+
+        self.keywords_idx = None
+        self.non_keywords_idx = None
+        if training and config.sample_non_kw is not None:
+            self.keywords_idx, self.non_keywords_idx = self.get_keyword_idx(config.sample_non_kw)
+            self.sample_non_kw_probability = config.sample_non_kw_probability
+
         msg = "Loaded dataset {} and found sample rate {}, audio_size {}, input_size {}, window_size {} and shift {}"
         print(msg.format(os.path.basename(filename), self.sample_rate, self.audio_size, self.input_size,
                          self.window_size, self.shift))

     def get_data_loader(self, batch_size):
         """ Get a DataLoader that can enumerate shuffled batches of data in this dataset """
         return DataLoader(self, batch_size=batch_size, shuffle=True, drop_last=True)
-
+
     def to_long_vector(self):
         """ convert the expected labels to a list of integer indexes into the array of keywords """
         indexer = [(0 if x == "<null>" else self.keywords.index(x)) for x in self.label_names]
         return np.array(indexer, dtype=np.longlong)

+    def get_keyword_idx(self, non_kw_label):
+        """ Find the indexes of the keyword rows and the non-keyword rows """
+        indexer = [ids for ids, label in enumerate(self.label_names) if label != non_kw_label]
+        non_indexer = [ids for ids, label in enumerate(self.label_names) if label == non_kw_label]
+        return (np.array(indexer, dtype=np.longlong), np.array(non_indexer, dtype=np.longlong))
+
     def __len__(self):
         """ Return the number of rows in this Dataset """
-        return self.num_rows
+        if self.non_keywords_idx is None:
+            return self.num_rows
+        else:
+            return int(len(self.keywords_idx) / (1 - self.sample_non_kw_probability))

     def __getitem__(self, idx):
         """ Return a single labelled sample here as a tuple """
-        audio = self.features[idx]  # batch index is second dimension
-        label = self.labels[idx]
+        if self.non_keywords_idx is None:
+            updated_idx = idx
+        else:
+            if idx < len(self.keywords_idx):
+                updated_idx = self.keywords_idx[idx]
+            else:
+                updated_idx = np.random.choice(self.non_keywords_idx)
+        audio = self.features[updated_idx]  # batch index is second dimension
+        label = self.labels[updated_idx]
         sample = (audio, label)
         return sample

+

 def create_model(model_config, input_size, num_keywords):
     ModelClass = get_model_class(KeywordSpotter)
@@ -453,7 +479,7 @@ def train(config, evaluate_only=False, outdir=".", detail=False, azureml=False):
     log = None
     if not evaluate_only:
         print("Loading {}...".format(training_file))
-        training_data = AudioDataset(training_file, config.dataset, keywords)
+        training_data = AudioDataset(training_file, config.dataset, keywords, training=True)

         print("Loading {}...".format(validation_file))
         validation_data = AudioDataset(validation_file, config.dataset, keywords)
@@ -556,6 +582,8 @@ def str2bool(v):
     parser.add_argument("--rolling", help="Whether to train model in rolling fashion or not", action="store_true")
     parser.add_argument("--max_rolling_length", help="Max number of epochs you want to roll the rolling training"
                         " default is 100", type=int)
+    parser.add_argument("--sample_non_kw", "-sl", type=str, help="Sample data for this label with probability sample_non_kw_probability")
+    parser.add_argument("--sample_non_kw_probability", "-spr", type=float, help="Sample from the sample_non_kw label with this probability")

     # arguments for fastgrnn
     parser.add_argument("--wRank", "-wr", help="Rank of W in 1st layer of FastGRNN default is None", type=int)
@@ -645,6 +673,15 @@ def str2bool(v):
         config.dataset.categories = args.categories
     if args.dataset:
         config.dataset.path = args.dataset
+    if args.sample_non_kw:
+        config.dataset.sample_non_kw = args.sample_non_kw
+        if args.sample_non_kw_probability is None:
+            config.dataset.sample_non_kw_probability = 0.5
+        else:
+            config.dataset.sample_non_kw_probability = args.sample_non_kw_probability
+    else:
+        config.dataset.sample_non_kw = None
+
     if args.wRank:
         config.model.wRank = args.wRank
     if args.uRank:
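
Note on the sampling arithmetic (not part of the diff): the new __len__ and __getitem__ report a dataset length of K / (1 - p), where K is the number of keyword rows and p is sample_non_kw_probability; indexes past the keyword block each draw a random non-keyword row, so roughly a fraction p of every shuffled epoch comes from the non-keyword label. A minimal standalone sketch with hypothetical values:

    # Hypothetical values; illustrates the epoch-length math used by __len__ above.
    K = 800                        # number of keyword rows (hypothetical)
    p = 0.5                        # sample_non_kw_probability (hypothetical)
    epoch_len = int(K / (1 - p))   # what __len__ returns: 1600

    # Indexes 0..K-1 map straight to keyword rows via keywords_idx; the remaining
    # epoch_len - K indexes each pick a random row from non_keywords_idx.
    non_kw_fraction = (epoch_len - K) / epoch_len
    print(non_kw_fraction)         # 0.5, i.e. approximately p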