@@ -397,6 +397,10 @@ std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
397397 auto d_nu = torch::zeros_like (initial_h);
398398 auto d_bias_z = torch::zeros_like (initial_h);
399399 auto d_bias_h_prime = torch::zeros_like (initial_h);
400+ auto d_w1 = torch::empty (0 );
401+ auto d_w2 = torch::empty (0 );
402+ auto d_u1 = torch::empty (0 );
403+ auto d_u2 = torch::empty (0 );
400404
401405 bool w_low_rank = w1.size (0 ) != 0 ;
402406 bool u_low_rank = u1.size (0 ) != 0 ;
@@ -501,20 +505,14 @@ std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
501505 d_zeta = (d_zeta.sum (0 , true )).sum (1 , true );
502506 d_nu = (d_nu.sum (0 , true )).sum (1 , true );
503507 if (w_low_rank) {
504- auto d_w1 = torch::mm (w2.transpose (0 , 1 ), d_w);
505- auto d_w2 = torch::mm (d_w, w1.transpose (0 , 1 ));
508+ d_w1 = torch::mm (w2.transpose (0 , 1 ), d_w);
509+ d_w2 = torch::mm (d_w, w1.transpose (0 , 1 ));
506510 d_w = torch::empty (0 );
507- } else {
508- auto d_w1 = torch::empty (0 );
509- auto d_w2 = torch::empty (0 );
510511 }
511512 if (u_low_rank) {
512- auto d_u1 = torch::mm (u2.transpose (0 , 1 ), d_u);
513- auto d_u2 = torch::mm (d_u, u1.transpose (0 , 1 ));
513+ d_u1 = torch::mm (u2.transpose (0 , 1 ), d_u);
514+ d_u2 = torch::mm (d_u, u1.transpose (0 , 1 ));
514515 d_u = torch::empty (0 );
515- } else {
516- auto d_u1 = torch::empty (0 );
517- auto d_u2 = torch::empty (0 );
518516 }
519517 return {d_input, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h, d_w, d_u, d_w1, d_w2, d_u1, d_u2};
520518}
0 commit comments