# --- FastGRNNCUDACell.__init__, post-image fragment (diff hunks @@ -330,7 and
# @@ -340,29; lines between hunks are elided and marked below). ---
self._zetaInit = zetaInit
self._nuInit = nuInit
self._name = name
# NOTE(review): device is hard-coded to CUDA; construction fails on CPU-only
# hosts. Presumably intended for a *CUDA* cell, but confirm against callers.
self.device = torch.device("cuda")

if wRank is not None:
    self._num_W_matrices += 1
    self._num_weight_matrices[0] = self._num_W_matrices
# ... (file lines 337-340 fall outside the diff hunks and are not shown) ...
self._name = name

# BUG FIX: `device` is a keyword-only argument of torch.randn/torch.ones;
# the diff passed `self.device` positionally (e.g. torch.randn([h, i],
# self.device)), which raises a TypeError at construction time. Use
# device=... throughout — consistent with the FastGRNNCUDA.__init__ hunk
# later in this same diff.
if wRank is None:
    # Full-rank input-to-hidden matrix W: (hidden_size, input_size).
    self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size], device=self.device))
    self.W1 = torch.empty(0)
    self.W2 = torch.empty(0)
else:
    # Low-rank factorization W ~ W2 @ W1 with inner rank wRank.
    self.W = torch.empty(0)
    self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size], device=self.device))
    self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank], device=self.device))

if uRank is None:
    # Full-rank hidden-to-hidden matrix U: (hidden_size, hidden_size).
    self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size], device=self.device))
    self.U1 = torch.empty(0)
    self.U2 = torch.empty(0)
else:
    # Low-rank factorization U ~ U2 @ U1 with inner rank uRank.
    self.U = torch.empty(0)
    self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size], device=self.device))
    self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank], device=self.device))

self._gate_non_linearity = NON_LINEARITY[gate_nonlinearity]

# Gate/update biases start at 1; zeta and nu are scalar residual-gate
# parameters initialized from the constructor arguments.
self.bias_gate = nn.Parameter(torch.ones([1, hidden_size], device=self.device))
self.bias_update = nn.Parameter(torch.ones([1, hidden_size], device=self.device))
self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1], device=self.device))
self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1], device=self.device))
367368 @property
368369 def name (self ):
@@ -374,7 +375,11 @@ def cellType(self):
374375
def forward(self, input, state):
    """One step of the FastGRNN cell via the custom CUDA autograd function.

    input: input tensor for this timestep; state: previous hidden state.
    Both are moved to the module's CUDA device if they arrive on CPU.
    Returns whatever FastGRNNFunction.apply produces (the next hidden state).
    """
    # BUG FIX: Tensor.to() is NOT in-place — the diff called
    # `input.to(self.device)` and discarded the result, so CPU tensors were
    # never actually moved to the GPU. Rebind the names instead.
    if not input.is_cuda:
        input = input.to(self.device)
    if not state.is_cuda:
        state = state.to(self.device)
    # The diff also correctly replaced the undefined `h_state` with `state`.
    return FastGRNNFunction.apply(input, self.bias_gate, self.bias_update,
                                  self.zeta, self.nu, state,
                                  self.W, self.U, self.W1, self.W2,
                                  self.U1, self.U2,
                                  self._gate_non_linearity)
379384
380385 def getVars (self ):
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
             update_nonlinearity="tanh", wRank=None, uRank=None,
             wSparsity=1.0, uSparsity=1.0, zetaInit=1.0, nuInit=-4.0,
             batch_first=False, name="FastGRNNCUDA"):
    """Unrolled FastGRNN layer backed by a fused CUDA kernel.

    W is optionally factored as W2 @ W1 (inner rank wRank); likewise
    U as U2 @ U1 (inner rank uRank). zeta and nu are scalar residual-gate
    parameters; gate/update biases are initialized to 1. batch_first
    controls whether forward() expects (batch, time, ...) input.

    Raises:
        RuntimeError: if no CUDA device is available.
    """
    super(FastGRNNCUDA, self).__init__()
    if utils.findCUDA() is None:
        # RuntimeError is more precise than a bare Exception and remains
        # backward-compatible with existing `except Exception` handlers.
        raise RuntimeError('FastGRNNCUDA is supported only on GPU devices.')
    # ... (file lines between the diff hunks are elided here) ...
    self._zetaInit = zetaInit
    self._nuInit = nuInit
    self._name = name
    # Bookkeeping used by getVars()/sparsity logic: counts of W matrices,
    # U matrices, and biases (zeta and nu are tracked separately below).
    self._num_W_matrices = 1
    self._num_U_matrices = 1
    self._num_biases = 2
    self._num_weight_matrices = [self._num_W_matrices, self._num_U_matrices,
                                 self._num_biases]
    self._wRank = wRank
    self._uRank = uRank
    self._wSparsity = wSparsity
    self._uSparsity = uSparsity
    self.oldmats = []
    # NOTE(review): device is hard-coded to CUDA — consistent with the guard
    # above, but it prevents CPU instantiation even for inspection.
    self.device = torch.device("cuda")
    self.batch_first = batch_first
    if wRank is not None:
        # Low-rank W splits into two matrices; record the new count.
        self._num_W_matrices += 1
        self._num_weight_matrices[0] = self._num_W_matrices
    # ... (the matching uRank bookkeeping falls outside this diff hunk) ...
    self._name = name

    if wRank is None:
        # Full-rank input-to-hidden matrix W: (hidden_size, input_size).
        self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size], device=self.device))
        self.W1 = torch.empty(0)
        self.W2 = torch.empty(0)
    else:
        # Low-rank factorization W ~ W2 @ W1 with inner rank wRank.
        self.W = torch.empty(0)
        self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size], device=self.device))
        self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank], device=self.device))

    if uRank is None:
        # Full-rank hidden-to-hidden matrix U: (hidden_size, hidden_size).
        self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size], device=self.device))
        self.U1 = torch.empty(0)
        self.U2 = torch.empty(0)
    else:
        # Low-rank factorization U ~ U2 @ U1 with inner rank uRank.
        self.U = torch.empty(0)
        self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size], device=self.device))
        self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank], device=self.device))

    self._gate_non_linearity = NON_LINEARITY[gate_nonlinearity]

    self.bias_gate = nn.Parameter(torch.ones([1, hidden_size], device=self.device))
    self.bias_update = nn.Parameter(torch.ones([1, hidden_size], device=self.device))
    self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1], device=self.device))
    self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1], device=self.device))
def forward(self, input, hiddenState, cell_state=None):
    """Run the fused CUDA FastGRNN over a full input sequence.

    input: [timesteps, batch, features] tensor, or [batch, timesteps,
        features] when self.batch_first is True (it is transposed to
        time-major before the kernel call).
    hiddenState: initial hidden state of shape [batch, hidden_size],
        or None to start from zeros.
    cell_state: unused; kept for interface compatibility.
    """
    if self.batch_first:
        input = input.transpose(0, 1)
    if not input.is_cuda:
        input = input.to(self.device)
    if hiddenState is None:
        # Allocate the zero state directly on the GPU rather than creating
        # it on CPU and copying with .to() — same result, one fewer copy.
        hiddenState = torch.zeros([input.shape[1], self.hidden_size],
                                  device=self.device)
    if not hiddenState.is_cuda:
        hiddenState = hiddenState.to(self.device)
    # NOTE(review): when batch_first is True the result is NOT transposed
    # back to batch-first layout — confirm callers expect time-major output,
    # or transpose the kernel's result symmetrically.
    return FastGRNNUnrollFunction.apply(input, self.bias_gate,
                                        self.bias_update, self.zeta, self.nu,
                                        hiddenState,
                                        self.W, self.U, self.W1, self.W2,
                                        self.U1, self.U2,
                                        self._gate_non_linearity)
11551179 def getVars (self ):
0 commit comments