@@ -318,30 +318,51 @@ class FastGRNNCUDACell(RNNCell):
     h_t = z_t*h_{t-1} + (sigmoid(zeta)(1-z_t) + sigmoid(nu))*h_t^
 
     '''
-    def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
-        super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_non_linearity, "tanh", 1, 1, 2)
+    def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
+                 update_nonlinearity="tanh", wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
+        super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_nonlinearity, update_nonlinearity, 1, 1, 2, wRank, uRank)
         if utils.findCUDA() is None:
-            raise Exception('FastGRNNCUDACell is supported only on GPU devices.')
+            raise Exception('FastGRNNCUDA is supported only on GPU devices.')
         NON_LINEARITY = {"sigmoid": 0, "relu": 1, "tanh": 2}
         self._input_size = input_size
         self._hidden_size = hidden_size
         self._zetaInit = zetaInit
         self._nuInit = nuInit
         self._name = name
-        self._gate_non_linearity = NON_LINEARITY[gate_non_linearity]
-        self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size]))
-        self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))
+
+        if wRank is not None:
+            self._num_W_matrices += 1
+            self._num_weight_matrices[0] = self._num_W_matrices
+        if uRank is not None:
+            self._num_U_matrices += 1
+            self._num_weight_matrices[1] = self._num_U_matrices
+        self._name = name
+
+        if wRank is None:
+            self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size]))
+            self.W1 = torch.empty(0)
+            self.W2 = torch.empty(0)
+        else:
+            self.W = torch.empty(0)
+            self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size]))
+            self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank]))
+
+        if uRank is None:
+            self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))
+            self.U1 = torch.empty(0)
+            self.U2 = torch.empty(0)
+        else:
+            self.U = torch.empty(0)
+            self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size]))
+            self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank]))
+
+        self._gate_non_linearity = NON_LINEARITY[gate_nonlinearity]
 
         self.bias_gate = nn.Parameter(torch.ones([1, hidden_size]))
         self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
         self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1]))
         self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1]))
 
-    def reset_parameters(self):
-        stdv = 1.0 / math.sqrt(self.state_size)
-        for weight in self.parameters():
-            weight.data.uniform_(-stdv, +stdv)
-
     @property
     def name(self):
         return self._name
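
The hunk above swaps the cell's fixed dense `W` and `U` for the optional low-rank factorization the other FastGRNN cells already use: when `wRank` is given, `W` is never materialized and the input projection goes through the pair `W1`/`W2`, so the effective matrix is `W2 @ W1`. A minimal standalone sketch of the saving, with illustrative sizes (not EdgeML code):

```python
import torch
import torch.nn as nn

input_size, hidden_size, wRank = 256, 128, 16

# Full-rank input projection: 128 * 256 = 32768 parameters.
W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size]))

# Low-rank pair W ~= W2 @ W1: 16*256 + 128*16 = 6144 parameters.
W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size]))
W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank]))

x = torch.randn(32, input_size)       # a batch of 32 inputs
dense = x @ W.t()                     # shape [32, 128]
factored = (x @ W1.t()) @ W2.t()      # same shape, ~5x fewer parameters
assert dense.shape == factored.shape
```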
@@ -352,10 +373,23 @@ def cellType(self):
 
     def forward(self, input, state):
         # Calls the custom autograd function, which invokes the CUDA implementation
-        return FastGRNNFunction.apply(input, self.W, self.U, self.bias_gate, self.bias_update, self.zeta, self.nu, state, self._gate_non_linearity)
+        return FastGRNNFunction.apply(input, self.bias_gate, self.bias_update, self.zeta, self.nu, state,
+                                      self.W, self.U, self.W1, self.W2, self.U1, self.U2, self._gate_non_linearity)
 
     def getVars(self):
-        return [self.W, self.U, self.bias_gate, self.bias_update, self.zeta, self.nu]
+        Vars = []
+        if self._num_W_matrices == 1:
+            Vars.append(self.W)
+        else:
+            Vars.extend([self.W1, self.W2])
+
+        if self._num_U_matrices == 1:
+            Vars.append(self.U)
+        else:
+            Vars.extend([self.U1, self.U2])
+
+        Vars.extend([self.bias_gate, self.bias_update, self.zeta, self.nu])
+        return Vars
 
 class FastRNNCell(RNNCell):
     '''
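
After this change `getVars()` reports only the tensors that are actually trainable, so its length depends on the chosen ranks. A usage sketch (hypothetical sizes, assuming a working `fastgrnn_cuda` build on a GPU machine, since the constructor raises otherwise):

```python
full_rank = FastGRNNCUDACell(32, 64)
print(len(full_rank.getVars()))   # 6: W, U, bias_gate, bias_update, zeta, nu

low_rank = FastGRNNCUDACell(32, 64, wRank=8, uRank=8)
print(len(low_rank.getVars()))    # 8: W1, W2, U1, U2, biases, zeta, nu
```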
@@ -1104,8 +1138,6 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
         self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank]))
 
         self._gate_non_linearity = NON_LINEARITY[gate_nonlinearity]
-        self.W = nn.Parameter(0.1 * torch.randn([input_size, hidden_size]))
-        self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))
 
         self.bias_gate = nn.Parameter(torch.ones([1, hidden_size]))
         self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
@@ -1118,9 +1150,19 @@ def forward(self, input, h_state, cell_state=None):
                                       self.W, self.U, self.W1, self.W2, self.U1, self.U2, self._gate_non_linearity)
 
     def getVars(self):
-        if self._num_W_matrices != 1:
-            return [self.W1, self.W2, self.U1, self.U2, self.bias_gate, self.bias_update, self.zeta, self.nu]
-        return [self.W, self.U, self.bias_gate, self.bias_update, self.zeta, self.nu]
+        Vars = []
+        if self._num_W_matrices == 1:
+            Vars.append(self.W)
+        else:
+            Vars.extend([self.W1, self.W2])
+
+        if self._num_U_matrices == 1:
+            Vars.append(self.U)
+        else:
+            Vars.extend([self.U1, self.U2])
+
+        Vars.extend([self.bias_gate, self.bias_update, self.zeta, self.nu])
+        return Vars
 
 class SRNN2(nn.Module):
 
@@ -1239,10 +1281,10 @@ def forward(self, x, brickSize):
 
 class FastGRNNFunction(Function):
     @staticmethod
-    def forward(ctx, input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity):
-        outputs = fastgrnn_cuda.forward(input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity)
+    def forward(ctx, input, bias_gate, bias_update, zeta, nu, old_h, w, u, w1, w2, u1, u2, gate_non_linearity):
+        outputs = fastgrnn_cuda.forward(input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity, w1, w2, u1, u2)
         new_h = outputs[0]
-        variables = [input, old_h, zeta, nu, w, u] + outputs[1:]
+        variables = [input, old_h, zeta, nu, w, u] + outputs[1:] + [w1, w2, u1, u2]
         ctx.save_for_backward(*variables)
         ctx.non_linearity = gate_non_linearity
         return new_h
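
For readers unfamiliar with custom autograd Functions: `forward` stashes whatever the backward pass will need via `ctx.save_for_backward`, and `backward` must later return one gradient per `forward` argument, in the same order. A self-contained toy example of that contract (illustrative only, not the FastGRNN kernel):

```python
import torch
from torch.autograd import Function

class ScaledTanh(Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.save_for_backward(x, scale)   # tensors needed in backward
        return torch.tanh(x) * scale

    @staticmethod
    def backward(ctx, grad_out):
        x, scale = ctx.saved_tensors
        grad_x = grad_out * scale * (1 - torch.tanh(x) ** 2)
        grad_scale = (grad_out * torch.tanh(x)).sum()
        return grad_x, grad_scale         # one gradient per forward input

x = torch.randn(4, requires_grad=True)
s = torch.tensor(2.0, requires_grad=True)
ScaledTanh.apply(x, s).sum().backward()
```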
@@ -1251,8 +1293,7 @@ def forward(ctx, input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_
     def backward(ctx, grad_h):
         outputs = fastgrnn_cuda.backward(
             grad_h.contiguous(), *ctx.saved_variables, ctx.non_linearity)
-        d_input, d_w, d_u, d_bias_gate, d_bias_update, d_zeta, d_nu, d_old_h = outputs
-        return d_input, d_w, d_u, d_bias_gate, d_bias_update, d_zeta, d_nu, d_old_h, None
+        return tuple(outputs + [None])
 
 class FastGRNNUnrollFunction(Function):
     @staticmethod
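
The simplified `backward` leans on two autograd rules: the returned tuple must contain exactly one entry per `forward` argument, and non-tensor arguments (here the integer `gate_non_linearity` flag) take `None`. It also assumes `fastgrnn_cuda.backward` now emits gradients for all twelve tensor inputs in `forward`'s argument order. The `None` convention in isolation:

```python
import torch
from torch.autograd import Function

class AddConst(Function):
    @staticmethod
    def forward(ctx, x, c):      # c is a plain Python number
        return x + c

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out, None    # None fills the slot of the non-tensor arg

x = torch.ones(3, requires_grad=True)
AddConst.apply(x, 5.0).sum().backward()
print(x.grad)                    # tensor([1., 1., 1.])
```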