@@ -319,8 +319,9 @@ class FastGRNNCUDACell(RNNCell):
 
     '''
     def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
-                 update_nonlinearity="tanh", wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
-        super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_non_linearity, update_nonlinearity, 1, 1, 2, wRank, uRank)
+                 update_nonlinearity="tanh", wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, wSparsity=1.0, uSparsity=1.0, name="FastGRNNCUDACell"):
+        super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_nonlinearity, update_nonlinearity,
+                                               1, 1, 2, wRank, uRank, wSparsity, uSparsity)
         if utils.findCUDA() is None:
             raise Exception('FastGRNNCUDA is supported only on GPU devices.')
         NON_LINEARITY = {"sigmoid": 0, "relu": 1, "tanh": 2}
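For orientation, a hedged usage sketch of the widened constructor follows. The `edgeml_pytorch.graph.rnn` import path, the single-step `cell(x, h)` call, and the reading of `wSparsity`/`uSparsity` as the fraction of entries retained under hard thresholding (1.0 = fully dense, matching the defaults) are assumptions for illustration, not facts established by this diff.

import torch
from edgeml_pytorch.graph.rnn import FastGRNNCUDACell  # assumed module path

# Build a low-rank, sparsifiable cell; a CUDA device is required.
cell = FastGRNNCUDACell(
    input_size=32,
    hidden_size=64,
    wRank=16,        # low-rank factorization of the input matrix W
    uRank=16,        # low-rank factorization of the recurrent matrix U
    wSparsity=0.1,   # assumed: keep ~10% of W's entries when sparsifying
    uSparsity=0.1,   # assumed: keep ~10% of U's entries when sparsifying
).cuda()

x = torch.randn(128, 32, device="cuda")  # one time step for a batch of 128
h = torch.zeros(128, 64, device="cuda")  # initial hidden state
h_next = cell(x, h)                      # assumed RNNCell-style single step

Because the constructor raises when utils.findCUDA() returns None, the sketch assumes a CUDA-capable machine.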
@@ -1166,6 +1167,54 @@ def getVars(self):
         Vars.extend([self.bias_gate, self.bias_update, self.zeta, self.nu])
         return Vars
 
+    def get_model_size(self):
+        '''
+        Function to get the expected model size (in bytes) once the target sparsity is applied.
+        '''
+        mats = self.getVars()
+        endW = self._num_W_matrices
+        endU = endW + self._num_U_matrices
+
+        totalnnz = 2  # for zeta and nu
+        for i in range(0, endW):
+            device = mats[i].device
+            totalnnz += utils.countNNZ(mats[i].cpu(), self._wSparsity)
+            mats[i].to(device)
+        for i in range(endW, endU):
+            device = mats[i].device
+            totalnnz += utils.countNNZ(mats[i].cpu(), self._uSparsity)
+            mats[i].to(device)
+        for i in range(endU, len(mats)):
+            device = mats[i].device
+            totalnnz += utils.countNNZ(mats[i].cpu(), False)
+            mats[i].to(device)
+        return totalnnz * 4  # 4 bytes per float32 parameter
+
+    def copy_previous_UW(self):  # snapshot W/U so sparsifyWithSupport() can reuse their support
+        mats = self.getVars()
+        num_mats = self._num_W_matrices + self._num_U_matrices
+        if len(self.oldmats) != num_mats:
+            for i in range(num_mats):
+                self.oldmats.append(torch.FloatTensor())
+        for i in range(num_mats):
+            self.oldmats[i] = torch.FloatTensor(mats[i].detach().clone().to(mats[i].device))
+
+    def sparsify(self):
+        mats = self.getVars()
+        endW = self._num_W_matrices
+        endU = endW + self._num_U_matrices
+        for i in range(0, endW):
+            mats[i] = utils.hardThreshold(mats[i], self._wSparsity)
+        for i in range(endW, endU):
+            mats[i] = utils.hardThreshold(mats[i], self._uSparsity)
+        self.copy_previous_UW()  # remember the new support for later support-based thresholding
+
+    def sparsifyWithSupport(self):
+        mats = self.getVars()
+        endU = self._num_W_matrices + self._num_U_matrices
+        for i in range(0, endU):
+            mats[i] = utils.supportBasedThreshold(mats[i], self.oldmats[i])  # keep only entries on the saved support
+
 class SRNN2(nn.Module):
 
     def __init__(self, inputDim, outputDim, hiddenDim0, hiddenDim1, cellType,
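The new hooks are presumably driven from the training loop: train dense, then call sparsify() after optimizer steps to re-impose the hard threshold and record the surviving support, then switch to sparsifyWithSupport() so later updates stay on that frozen support, and finally report get_model_size(). Below is a hedged sketch of such a schedule; the three-phase epoch split, the Adam optimizer, and the run_batches/compute_loss helpers are illustrative assumptions, not part of this commit.

import torch

def train_with_sparsification(cell, compute_loss, run_batches, total_epochs=300):
    # Hypothetical driver: dense training, then iterative hard thresholding,
    # then sparse retraining restricted to the support saved by sparsify().
    optimizer = torch.optim.Adam(cell.parameters(), lr=1e-3)
    dense_end = total_epochs // 3
    iht_end = 2 * total_epochs // 3
    for epoch in range(total_epochs):
        for x, h0, y in run_batches():
            optimizer.zero_grad()
            compute_loss(cell, x, h0, y).backward()
            optimizer.step()
            if dense_end <= epoch < iht_end:
                cell.sparsify()             # re-apply the hard threshold, record support
            elif epoch >= iht_end:
                cell.sparsifyWithSupport()  # zero weights outside the saved support
    print("expected model size (bytes):", cell.get_model_size())

Whether the projection is applied every batch or every few epochs is a tuning choice; the sketch applies it per batch only to keep the control flow short.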