@@ -329,7 +329,7 @@ def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", zetaIn
         self._nuInit = nuInit
         self._name = name
         self._gate_non_linearity = NON_LINEARITY[gate_non_linearity]
-        self.W = nn.Parameter(0.1 * torch.randn([input_size, hidden_size]))
+        self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size]))
         self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))
 
         self.bias_gate = nn.Parameter(torch.ones([1, hidden_size]))
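Why the swap matters: with W stored as [hidden_size, input_size] it matches U's [hidden_size, hidden_size] layout, so both terms of the gate pre-activation land in the hidden dimension. A minimal shape sketch in plain PyTorch (illustrative only; the CUDA kernel's internal layout convention is assumed here, not shown in this diff):

import torch

batch, input_size, hidden_size = 4, 16, 32
x = torch.randn(batch, input_size)               # one timestep of input
h = torch.randn(batch, hidden_size)              # previous hidden state
W = 0.1 * torch.randn(hidden_size, input_size)   # [hidden_size, input_size], as in the fix
U = 0.1 * torch.randn(hidden_size, hidden_size)

pre_act = x @ W.t() + h @ U.t()                  # both terms -> [batch, hidden_size]
assert pre_act.shape == (batch, hidden_size)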
@@ -1065,7 +1065,8 @@ def forward(self, input, hiddenState=None, cellState=None):
 
 class FastGRNNCUDA(nn.Module):
     """Unrolled implementation of the FastGRNNCUDACell"""
-    def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
+    def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
+                 update_nonlinearity="tanh", wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
         super(FastGRNNCUDA, self).__init__()
         if utils.findCUDA() is None:
             raise Exception('FastGRNNCUDA is supported only on GPU devices.')
@@ -1075,7 +1076,34 @@ def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", zetaIn
         self._zetaInit = zetaInit
         self._nuInit = nuInit
         self._name = name
-        self._gate_non_linearity = NON_LINEARITY[gate_non_linearity]
+        # number of weight matrices: one each for W and U, two when low-rank factors are used
+        self._num_W_matrices = 1
+        self._num_U_matrices = 1
+        self._num_weight_matrices = [self._num_W_matrices, self._num_U_matrices]
+        if wRank is not None:
+            self._num_W_matrices += 1
+            self._num_weight_matrices[0] = self._num_W_matrices
+        if uRank is not None:
+            self._num_U_matrices += 1
+            self._num_weight_matrices[1] = self._num_U_matrices
+
+        if wRank is None:
+            self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size]))
+            self.W1 = torch.empty(0)
+            self.W2 = torch.empty(0)
+        else:
+            self.W = torch.empty(0)
+            self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size]))
+            self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank]))
+
+        if uRank is None:
+            self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))
+            self.U1 = torch.empty(0)
+            self.U2 = torch.empty(0)
+        else:
+            self.U = torch.empty(0)
+            self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size]))
+            self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank]))
+
+        self._gate_non_linearity = NON_LINEARITY[gate_nonlinearity]
-        self.W = nn.Parameter(0.1 * torch.randn([input_size, hidden_size]))
-        self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))
 
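For context on the wRank/uRank path: the full matrices are replaced by two factors whose product reconstructs them, W ≈ W2·W1 and U ≈ U2·U1, cutting the parameter count from hidden_size·input_size down to wRank·(input_size + hidden_size). A hedged sketch of that reconstruction, with shapes taken from the parameters above (the actual multiply happens inside the CUDA kernel):

import torch

input_size, hidden_size, wRank = 64, 128, 8
W1 = 0.1 * torch.randn(wRank, input_size)        # [wRank, input_size]
W2 = 0.1 * torch.randn(hidden_size, wRank)       # [hidden_size, wRank]

W_full = W2 @ W1                                 # [hidden_size, input_size]
print(hidden_size * input_size)                  # 8192 parameters, full-rank
print(wRank * (input_size + hidden_size))        # 1536 parameters, low-rank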
@@ -1086,9 +1114,12 @@ def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", zetaIn
10861114
     def forward(self, input, h_state, cell_state=None):
         # input: [timesteps, batch, features, state_size]
-        return FastGRNNUnrollFunction.apply(input, self.W, self.U, self.bias_gate, self.bias_update, self.zeta, self.nu, h_state, self._gate_non_linearity)
+        return FastGRNNUnrollFunction.apply(input, self.bias_gate, self.bias_update, self.zeta, self.nu, h_state,
+                                            self.W, self.U, self.W1, self.W2, self.U1, self.U2, self._gate_non_linearity)
 
     def getVars(self):
+        if self._num_W_matrices != 1:
+            return [self.W1, self.W2, self.U1, self.U2, self.bias_gate, self.bias_update, self.zeta, self.nu]
         return [self.W, self.U, self.bias_gate, self.bias_update, self.zeta, self.nu]
 
 class SRNN2(nn.Module):
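A hypothetical end-to-end use of the class as changed above (shapes follow the comment in forward; this assumes a machine with the compiled fastgrnn_cuda extension and a CUDA device):

import torch

rnn = FastGRNNCUDA(input_size=32, hidden_size=64, gate_nonlinearity="sigmoid",
                   update_nonlinearity="tanh", wRank=8, uRank=8).cuda()
x = torch.randn(100, 16, 32, device="cuda")      # [timesteps, batch, features]
h0 = torch.zeros(16, 64, device="cuda")          # initial hidden state
h_all = rnn(x, h0)                               # hidden state at every timestep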
@@ -1225,10 +1256,10 @@ def backward(ctx, grad_h):
 
 class FastGRNNUnrollFunction(Function):
     @staticmethod
-    def forward(ctx, input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity):
-        outputs = fastgrnn_cuda.forward_unroll(input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity)
+    def forward(ctx, input, bias_gate, bias_update, zeta, nu, old_h, w, u, w1, w2, u1, u2, gate_non_linearity):
+        outputs = fastgrnn_cuda.forward_unroll(input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity, w1, w2, u1, u2)
         hidden_states = outputs[0]
-        variables = [input, hidden_states, zeta, nu, w, u] + outputs[1:] + [old_h]
+        variables = [input, hidden_states, zeta, nu, w, u] + outputs[1:] + [old_h, w1, w2, u1, u2]
         ctx.save_for_backward(*variables)
         ctx.gate_non_linearity = gate_non_linearity
         return hidden_states
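A note on ordering: backward_unroll consumes the saved tensors positionally via *ctx.saved_variables, so the save list here is an implicit ABI with the CUDA binding; appending w1, w2, u1, u2 after old_h only works if the C++ side was updated to read them in exactly that order.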
@@ -1237,5 +1268,4 @@ def forward(ctx, input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non
     def backward(ctx, grad_h):
         outputs = fastgrnn_cuda.backward_unroll(
             grad_h.contiguous(), *ctx.saved_variables, ctx.gate_non_linearity)
-        d_input, d_w, d_u, d_bias_gate, d_bias_update, d_zeta, d_nu, d_old_h = outputs
-        return d_input, d_w, d_u, d_bias_gate, d_bias_update, d_zeta, d_nu, d_old_h
+        return tuple(outputs + [None])
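The tuple(outputs + [None]) return follows autograd's contract: backward must yield exactly one gradient per argument of forward, and the trailing gate_non_linearity argument is a plain int, so its gradient slot is None (this also requires backward_unroll to return twelve gradients, one per tensor argument). A minimal, self-contained illustration of the same pattern:

import torch
from torch.autograd import Function

class Scale(Function):
    @staticmethod
    def forward(ctx, x, factor):        # factor is a non-tensor argument
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # one gradient per forward argument; None for the non-tensor one
        return grad_out * ctx.factor, None

x = torch.ones(3, requires_grad=True)
Scale.apply(x, 5.0).sum().backward()    # x.grad == tensor([5., 5., 5.])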