@@ -87,6 +87,37 @@ __global__ void fastgrnn_cuda_backward_kernel(
8787 d_nu[n][c] = h_prime[n][c] * grad_h[n][c] * d_nu_sigmoid[0 ][0 ];
8888 }
8989}
90+
// One-timestep backward kernel for the unrolled FastGRNN.
//
// Expected launch layout: blockIdx.y = batch index, and
// blockIdx.x * blockDim.x + threadIdx.x = hidden-unit index (bounds-checked
// below against old_h.size(1)).
//
// d_precomp and d_old_h are overwritten on every call; the parameter-gradient
// buffers (d_bias_z, d_bias_h_prime, d_zeta, d_nu) are accumulated with +=
// so the host can reuse them across the time loop and reduce once at the end.
//
// Convention: the derivative helpers take the *output* of the forward
// non-linearity (e.g. d_tanh(y) = 1 - y*y), which is why they are applied
// directly to z and h_prime. The candidate h_prime is always produced with
// tanh in the forward pass, so d_tanh is hard-coded for it; only the gate
// derivative is a template parameter.
template <typename scalar_t, scalar_t (*d_non_linearity)(scalar_t)>
__global__ void fastgrnn_unroll_cuda_backward_kernel(
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_precomp,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_h,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_z,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_h_prime,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_nu,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_zeta,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> z,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> h_prime,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> zeta,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> nu,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_zeta_sigmoid,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_nu_sigmoid,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h) {
  const int n = blockIdx.y;
  const int c = blockIdx.x * blockDim.x + threadIdx.x;
  if (c < old_h.size(1)) {
    // FIX: use a scalar_t constant instead of the bare double literal 1.0,
    // which silently promoted the float instantiation's arithmetic to double.
    const scalar_t one = static_cast<scalar_t>(1.0);
    d_old_h[n][c] = z[n][c] * grad_h[n][c];
    scalar_t temp_bias_h_prime =
        (zeta[0][0] * (one - z[n][c]) + nu[0][0]) * d_tanh(h_prime[n][c]) * grad_h[n][c];
    scalar_t temp_bias_z =
        (old_h[n][c] - zeta[0][0] * h_prime[n][c]) * d_non_linearity(z[n][c]) * grad_h[n][c];
    d_bias_h_prime[n][c] += temp_bias_h_prime;
    d_bias_z[n][c] += temp_bias_z;
    // Gate and candidate share the same pre-activation, so their gradients
    // sum into d_precomp.
    d_precomp[n][c] = temp_bias_z + temp_bias_h_prime;
    d_zeta[n][c] += (one - z[n][c]) * h_prime[n][c] * grad_h[n][c] * d_zeta_sigmoid[0][0];
    d_nu[n][c] += h_prime[n][c] * grad_h[n][c] * d_nu_sigmoid[0][0];
  }
}
120+
90121} // namespace
91122
92123std::vector<torch::Tensor> fastgrnn_cuda_forward (
@@ -246,3 +277,202 @@ std::vector<torch::Tensor> fastgrnn_cuda_backward(
246277
247278 return {d_input, d_w, d_u, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h};
248279}
280+
// Runs the FastGRNN recurrence over all timesteps on the GPU.
//
// input:        (timesteps, batch, input_size)
// initial_h:    (batch, state_size)
// zeta, nu:     raw scalar parameters; sigmoid is applied here before use
// z_non_linearity: 0 = sigmoid, 1 = relu, 2 = tanh (gate non-linearity)
//
// Returns {hidden_states, z_s, h_prime_s}, each (timesteps, batch,
// state_size); z_s and h_prime_s are saved for the backward pass.
std::vector<torch::Tensor> fastgrnn_unroll_cuda_forward(
    torch::Tensor input,
    torch::Tensor w,
    torch::Tensor u,
    torch::Tensor bias_z,
    torch::Tensor bias_h_prime,
    torch::Tensor zeta,
    torch::Tensor nu,
    torch::Tensor initial_h,
    int z_non_linearity) {
  auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device().type());
  const auto timesteps = input.size(0);
  const auto batch_size = initial_h.size(0);
  const auto state_size = initial_h.size(1);

  auto hidden_states = torch::zeros({timesteps, batch_size, state_size}, options);
  auto z_s = torch::zeros_like(hidden_states);
  auto h_prime_s = torch::zeros_like(hidden_states);

  auto prev_h = initial_h;
  auto new_h = torch::zeros_like(prev_h);
  auto z = torch::zeros_like(prev_h);
  auto h_prime = torch::zeros_like(prev_h);
  auto pre_comp = torch::zeros_like(prev_h);

  const int threads = 1024;
  const dim3 blocks((state_size + threads - 1) / threads, batch_size);

  w = w.transpose(0, 1);
  u = u.transpose(0, 1);
  zeta = torch::sigmoid(zeta);
  nu = torch::sigmoid(nu);

// Dispatch over dtype and launch one forward step with the gate
// non-linearity NON_LIN selected at compile time. Factored into a macro so
// the three z_non_linearity branches cannot drift apart.
// FIX: .scalar_type() replaces the deprecated Tensor::type() in the dispatch.
#define FASTGRNN_UNROLL_FWD_LAUNCH(NON_LIN)                                          \
  AT_DISPATCH_FLOATING_TYPES(pre_comp.scalar_type(), "fastgrnn_forward_cuda", ([&] { \
    fastgrnn_cuda_forward_kernel<scalar_t, NON_LIN><<<blocks, threads>>>(            \
        new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),         \
        z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),             \
        h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),       \
        pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),      \
        bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),        \
        bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),  \
        nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),            \
        zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),          \
        prev_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());       \
  }))

  for (int t = 0; t < timesteps; t++) {
    // Shared pre-activation for gate and candidate: input[t] @ w + prev_h @ u
    // (w and u were transposed above).
    pre_comp = torch::addmm(torch::mm(input[t], w), prev_h, u);

    if (z_non_linearity == 0)
      FASTGRNN_UNROLL_FWD_LAUNCH(sigmoid);
    else if (z_non_linearity == 1)
      FASTGRNN_UNROLL_FWD_LAUNCH(relu);
    else if (z_non_linearity == 2)
      FASTGRNN_UNROLL_FWD_LAUNCH(tanh);
    // NOTE(review): any other z_non_linearity value silently skips the launch
    // and the step produces stale new_h — consider validating upstream.

    hidden_states[t] = new_h;
    z_s[t] = z;
    h_prime_s[t] = h_prime;
    // prev_h aliases new_h from here on; presumably safe because each kernel
    // thread reads old_h[n][c] before writing new_h[n][c] — TODO confirm
    // against fastgrnn_cuda_forward_kernel.
    prev_h = new_h;
  }
#undef FASTGRNN_UNROLL_FWD_LAUNCH

  return {hidden_states, z_s, h_prime_s};
}
363+
// Backward pass for the unrolled FastGRNN: walks timesteps in reverse,
// launching one fused kernel per step and accumulating parameter gradients.
//
// grad_h:        (timesteps, batch, state_size) upstream gradients
// hidden_states, z, h_prime: per-step activations saved by the forward pass
// zeta, nu:      raw scalar parameters (sigmoid re-applied here, as in forward)
// z_non_linearity: 0 = sigmoid, 1 = relu, 2 = tanh (gate non-linearity)
//
// Returns {d_input, d_w, d_u, d_bias_z, d_bias_h_prime, d_zeta, d_nu,
// d_old_h}, where d_old_h is the gradient w.r.t. initial_h.
std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
    torch::Tensor grad_h,
    torch::Tensor input,
    torch::Tensor hidden_states,
    torch::Tensor zeta,
    torch::Tensor nu,
    torch::Tensor w,
    torch::Tensor u,
    torch::Tensor z,
    torch::Tensor h_prime,
    torch::Tensor initial_h,
    int z_non_linearity) {

  auto d_input = torch::zeros_like(input);
  auto d_w = torch::zeros_like(w);
  auto d_u = torch::zeros_like(u);
  auto d_zeta = torch::zeros_like(initial_h);
  auto d_nu = torch::zeros_like(initial_h);
  auto d_bias_z = torch::zeros_like(initial_h);
  auto d_bias_h_prime = torch::zeros_like(initial_h);

  // Forward used sigmoid(zeta)/sigmoid(nu); recompute them plus the sigmoid
  // derivatives needed to push gradients back to the raw parameters.
  zeta = torch::sigmoid(zeta);
  nu = torch::sigmoid(nu);
  auto d_nu_sigmoid = d_sigmoid(nu);
  auto d_zeta_sigmoid = d_sigmoid(zeta);

  auto grad_curr_h = torch::zeros_like(initial_h);
  auto d_precomp = torch::zeros_like(initial_h);
  auto d_old_h = torch::zeros_like(initial_h);
  auto prev_h_ = hidden_states[0];
  auto z_t_ = torch::zeros_like(initial_h);
  auto h_prime_t_ = torch::zeros_like(initial_h);

  const auto batch_size = hidden_states.size(1);
  const auto state_size = hidden_states.size(2);

  const int threads = 1024;
  const dim3 blocks((state_size + threads - 1) / threads, batch_size);

// Dispatch over dtype and launch one backward step with the gate derivative
// D_NON_LIN selected at compile time. Factored into a macro so the three
// z_non_linearity branches cannot drift apart.
// FIX: dispatch label corrected (was the copy-pasted "fastgrnn_forward_cuda")
// and .scalar_type() replaces the deprecated Tensor::type().
#define FASTGRNN_UNROLL_BWD_LAUNCH(D_NON_LIN)                                          \
  AT_DISPATCH_FLOATING_TYPES(z_t_.scalar_type(), "fastgrnn_backward_cuda", ([&] {      \
    fastgrnn_unroll_cuda_backward_kernel<scalar_t, D_NON_LIN><<<blocks, threads>>>(    \
        d_precomp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),       \
        d_old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),         \
        d_bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),        \
        d_bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),  \
        d_nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),            \
        d_zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),          \
        grad_curr_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),     \
        z_t_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),            \
        h_prime_t_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),      \
        zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),            \
        nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),              \
        d_zeta_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),  \
        d_nu_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),    \
        prev_h_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());        \
  }))

  for (auto t = hidden_states.size(0) - 1; t >= 0; t--) {
    // Total gradient at step t: upstream grad plus the carry d_old_h
    // produced by step t+1.
    grad_curr_h = torch::add(grad_h[t], d_old_h);
    z_t_ = z[t];
    h_prime_t_ = h_prime[t];

    if (t == 0)
      prev_h_ = initial_h;
    else
      prev_h_ = hidden_states[t - 1];

    if (z_non_linearity == 0)
      FASTGRNN_UNROLL_BWD_LAUNCH(d_sigmoid);
    else if (z_non_linearity == 1)
      FASTGRNN_UNROLL_BWD_LAUNCH(d_relu);
    else if (z_non_linearity == 2)
      // BUGFIX: this branch previously dispatched d_sigmoid, but the forward
      // gate for mode 2 is tanh, so the gate gradient must use d_tanh.
      FASTGRNN_UNROLL_BWD_LAUNCH(d_tanh);

    // Finish the chain rule on the host: route d_precomp through the
    // recurrent and input projections.
    d_old_h = torch::addmm(d_old_h, d_precomp, u);
    d_input[t] = torch::mm(d_precomp, w);
    d_w = torch::addmm(d_w, d_precomp.transpose(0, 1), input[t]);
    d_u = torch::addmm(d_u, d_precomp.transpose(0, 1), prev_h_);
  }
#undef FASTGRNN_UNROLL_BWD_LAUNCH

  // Collapse per-element accumulators to parameter shapes: biases sum over
  // batch; the scalar zeta/nu sum over batch and state.
  d_bias_z = d_bias_z.sum(0, true);
  d_bias_h_prime = d_bias_h_prime.sum(0, true);
  d_zeta = (d_zeta.sum(0, true)).sum(1, true);
  d_nu = (d_nu.sum(0, true)).sum(1, true);
  return {d_input, d_w, d_u, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h};
}
0 commit comments