
#include <vector>

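+ // Host-side derivative helpers. Both take the *already activated* value:
+ // d_sigmoid expects sigmoid(x) and returns dsigmoid/dx = (1 - s) * s;
+ // d_tanh expects tanh(x) and returns dtanh/dx = 1 - t^2. The device
+ // overloads below follow the same convention.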
+ __forceinline__ torch::Tensor d_sigmoid(torch::Tensor z) {
+   return (1 - z) * z;
+ }
+
+ __forceinline__ torch::Tensor d_tanh(torch::Tensor z) {
+   return 1 - z.pow(2);
+ }
+
+
namespace {
template <typename scalar_t>
__device__ __forceinline__ scalar_t sigmoid(scalar_t z) {
  return 1.0 / (1.0 + exp(-z));
}

template <typename scalar_t>
- __device__ __forceinline__ scalar_t d_sigmoid(scalar_t z) {
-   const auto s = sigmoid(z);
-   return (1.0 - s) * s;
+ __device__ __forceinline__ scalar_t d_sigmoid(scalar_t sig_z) {
+   return (1.0 - sig_z) * sig_z;
}

template <typename scalar_t>
- __device__ __forceinline__ scalar_t d_tanh(scalar_t z) {
-   const auto t = tanh(z);
-   return 1 - (t * t);
+ __device__ __forceinline__ scalar_t d_tanh(scalar_t tan_z) {
+   return 1 - (tan_z * tan_z);
}

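+ // FastGRNN cell, computed elementwise per (batch row n, state column c):
+ //   z     = sigmoid(pre_comp + bias_z)
+ //   h'    = tanh(pre_comp + bias_h_prime)
+ //   new_h = (sigmoid(zeta) * (1 - z) + sigmoid(nu)) * h' + z * old_h
+ // zeta and nu are passed through sigmoid once on the host, so the kernel
+ // uses them directly. blockIdx.y indexes the batch; threads along x cover
+ // the state columns.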
template <typename scalar_t>
__global__ void fastgrnn_cuda_forward_kernel(
+     torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
+     torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> z,
+     torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> h_prime,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> pre_comp,
-     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h,
-     torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
-     torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> z_t,
-     torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> h_prime_t,
-     torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_z,
-     torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_h_prime,
-     torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> zeta,
-     torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> nu) {
-   // batch index
+     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_z,
+     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_h_prime,
+     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> nu,
+     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> zeta,
+     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h) {
  const int n = blockIdx.y;
-   // column index
  const int c = blockIdx.x * blockDim.x + threadIdx.x;
-   if (c < pre_comp.size(1)){
-     z_t[n][c] = sigmoid(pre_comp[n][c] + bias_z[n][c]);
-     h_prime_t[n][c] = tanh(pre_comp[n][c] + bias_h_prime[n][c]);
-
-     new_h[n][c] = (sigmoid(zeta[0][0]) * (1 - z_t[n][c]) + sigmoid(nu[0][0])) * h_prime_t[n][c] + z_t[n][c] * old_h[n][c];
+   if (c < old_h.size(1)){
+     z[n][c] = sigmoid(pre_comp[n][c] + bias_z[0][c]);
+     h_prime[n][c] = tanh(pre_comp[n][c] + bias_h_prime[0][c]);
+     new_h[n][c] = (zeta[0][0] * (1.0 - z[n][c]) + nu[0][0]) * h_prime[n][c] + old_h[n][c] * z[n][c];
  }
}

+
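+ // Chain rule through the update above, using the activated-value forms of
+ // the derivatives (z and h_prime are already sigmoid/tanh outputs):
+ //   d_old_h        = z * grad_h
+ //   d_bias_z       = (old_h - zeta * h') * dsigmoid(z) * grad_h
+ //   d_bias_h_prime = (zeta * (1 - z) + nu) * dtanh(h') * grad_h
+ //   d_precomp      = d_bias_z + d_bias_h_prime
+ // d_zeta_sigmoid and d_nu_sigmoid carry dsigmoid(zeta)/dzeta and
+ // dsigmoid(nu)/dnu, precomputed on the host; the per-element d_zeta and
+ // d_nu are reduced to scalars after the kernel returns.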
template <typename scalar_t>
__global__ void fastgrnn_cuda_backward_kernel(
-     torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_zeta,
-     torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_nu,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_precomp,
-     torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_z,
-     torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_h_prime_t,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_h,
+     torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_z,
+     torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_h_prime,
+     torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_nu,
+     torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_zeta,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
-     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h,
-     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> z_t,
-     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> h_prime_t,
-     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> pre_comp,
-     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_z,
-     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_h_prime,
+     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> z,
+     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> h_prime,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> zeta,
-     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> nu) {
-   // batch index
+     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> nu,
+     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_zeta_sigmoid,
+     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_nu_sigmoid,
+     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h) {
  const int n = blockIdx.y;
-   // column index
  const int c = blockIdx.x * blockDim.x + threadIdx.x;
-   if (c < d_precomp.size(1)){
-     auto temp_grad = grad_h[n][c] * h_prime_t[n][c];
-     d_zeta[0][0] = temp_grad * (1 - z_t[n][c]) * d_sigmoid(zeta[0][0]);
-     d_nu[0][0] = temp_grad * d_sigmoid(nu[0][0]);
-     d_bias_z[n][c] = grad_h[n][c] * (sigmoid(zeta[0][0]) * -1 * h_prime_t[n][c] + old_h[n][c]) * d_sigmoid(pre_comp[n][c] + bias_z[n][c]);;
-     d_bias_h_prime_t[n][c] = grad_h[n][c] * (sigmoid(zeta[0][0]) * (1 - z_t[n][c]) + sigmoid(nu[0][0])) * d_tanh(pre_comp[n][c] + bias_h_prime[n][c]);
-     d_old_h[n][c] = grad_h[n][c] * z_t[n][c];
-     d_precomp[n][c] = d_bias_z[n][c] + d_bias_h_prime_t[n][c];
+   if (c < old_h.size(1)){
+     d_old_h[n][c] = z[n][c] * grad_h[n][c];
+     d_bias_h_prime[n][c] = (zeta[0][0] * (1.0 - z[n][c]) + nu[0][0]) * d_tanh(h_prime[n][c]) * grad_h[n][c];
+     d_bias_z[n][c] = (old_h[n][c] - zeta[0][0] * h_prime[n][c]) * d_sigmoid(z[n][c]) * grad_h[n][c];
+     d_precomp[n][c] = d_bias_z[n][c] + d_bias_h_prime[n][c];
+     d_zeta[n][c] = (1.0 - z[n][c]) * h_prime[n][c] * grad_h[n][c] * d_zeta_sigmoid[0][0];
+     d_nu[n][c] = h_prime[n][c] * grad_h[n][c] * d_nu_sigmoid[0][0];
  }
}
} // namespace
@@ -85,88 +86,86 @@ std::vector<torch::Tensor> fastgrnn_cuda_forward(
    torch::Tensor u,
    torch::Tensor bias_z,
    torch::Tensor bias_h_prime,
-     torch::Tensor old_h,
    torch::Tensor zeta,
-     torch::Tensor nu) {
-   auto w_comp = torch::mm(input, w);
-   auto u_comp = torch::mm(old_h, u);
-   auto pre_comp = torch::add(u_comp, w_comp);
-
+     torch::Tensor nu,
+     torch::Tensor old_h) {
+
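+   // pre_comp = input * W^T + old_h * U^T, shared by both gates; zeta and
+   // nu are activated once here rather than per thread inside the kernel.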
+   auto pre_comp = torch::addmm(torch::mm(input, w.transpose(0, 1)), old_h, u.transpose(0, 1));
+   nu = torch::sigmoid(nu);
+   zeta = torch::sigmoid(zeta);
  const auto batch_size = old_h.size(0);
  const auto state_size = old_h.size(1);
-
  auto new_h = torch::zeros_like(old_h);
-   auto z_t = torch::zeros_like(old_h);
-   auto h_prime_t = torch::zeros_like(old_h);
-
+   auto z = torch::zeros_like(old_h);
+   auto h_prime = torch::zeros_like(old_h);
  const int threads = 1024;
  const dim3 blocks((state_size + threads - 1) / threads, batch_size);
-
  AT_DISPATCH_FLOATING_TYPES(pre_comp.type(), "fastgrnn_forward_cuda", ([&] {
    fastgrnn_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
-       pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-       old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
      new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-       z_t.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-       h_prime_t.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+       z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+       h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+       pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
      bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
      bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+       nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
      zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-       nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
+       old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
  }));
-
-   return {new_h, z_t, h_prime_t, pre_comp};
+   return {new_h, z, h_prime};
}

std::vector<torch::Tensor> fastgrnn_cuda_backward(
    torch::Tensor grad_h,
    torch::Tensor input,
    torch::Tensor old_h,
-     torch::Tensor z_t,
-     torch::Tensor h_prime_t,
-     torch::Tensor pre_comp,
+     torch::Tensor zeta,
+     torch::Tensor nu,
    torch::Tensor w,
    torch::Tensor u,
-     torch::Tensor bias_z,
-     torch::Tensor bias_h_prime,
-     torch::Tensor zeta,
-     torch::Tensor nu) {
-   auto d_precomp = torch::zeros_like(pre_comp);
-   auto d_old_h = torch::zeros_like(old_h);
-   auto d_zeta = torch::zeros_like(zeta);
-   auto d_nu = torch::zeros_like(nu);
-   auto d_bias_z = torch::zeros_like(bias_z);
-   auto d_bias_h_prime = torch::zeros_like(bias_h_prime);
-
-   const auto batch_size = old_h.size(0);
-   const auto state_size = old_h.size(1);
-
-   const int threads = 1024;
-   const dim3 blocks((state_size + threads - 1) / threads, batch_size);
-
-   AT_DISPATCH_FLOATING_TYPES(pre_comp.type(), "fastgrnn_forward_cuda", ([&] {
+     torch::Tensor z,
+     torch::Tensor h_prime) {
+   auto d_precomp = torch::zeros_like(old_h);
+   auto d_bias_z = torch::zeros_like(old_h);
+   auto d_bias_h_prime = torch::zeros_like(old_h);
+   auto d_nu = torch::zeros_like(old_h);
+   auto d_zeta = torch::zeros_like(old_h);
+   auto d_old_h = torch::zeros_like(old_h);
+   zeta = torch::sigmoid(zeta);
+   nu = torch::sigmoid(nu);
+   auto d_nu_sigmoid = d_sigmoid(nu);
+   auto d_zeta_sigmoid = d_sigmoid(zeta);
+   const auto batch_size = old_h.size(0);
+   const auto state_size = old_h.size(1);
+
+   const int threads = 1024;
+   const dim3 blocks((state_size + threads - 1) / threads, batch_size);
+   AT_DISPATCH_FLOATING_TYPES(old_h.type(), "fastgrnn_backward_cuda", ([&] {
    fastgrnn_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
-       d_zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-       d_nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
      d_precomp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+       d_old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
      d_bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
      d_bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-       d_old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+       d_nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+       d_zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
      grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-       old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-       z_t.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-       h_prime_t.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-       pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-       bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-       bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+       z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+       h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
      zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-       nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
+       nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+       d_zeta_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+       d_nu_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+       old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
  }));

-   d_old_h = torch::add(d_old_h, torch::mm(torch::add(d_bias_h_prime, d_bias_z), u.transpose(0, 1)));
-   auto d_input = torch::mm(d_precomp, w.transpose(0, 1));
-   auto d_w = torch::mm(input.transpose(0, 1), d_precomp);
-   auto d_u = torch::mm(old_h.transpose(0, 1), d_precomp);
-
-   return {d_old_h, d_input, d_w, d_u, d_bias_z, d_bias_h_prime, d_nu, d_zeta};
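+   // Fold the shared pre-activation gradient back through the dense maps,
+   // then reduce: bias gradients sum over the batch dimension, and the
+   // scalar gate parameters zeta and nu sum over batch and state to [1, 1].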
+   d_old_h = torch::addmm(d_old_h, d_precomp, u);
+   auto d_input = torch::mm(d_precomp, w);
+   auto d_w = torch::mm(d_precomp.transpose(0, 1), input);
+   auto d_u = torch::mm(d_precomp.transpose(0, 1), old_h);
+   d_bias_z = d_bias_z.sum(0, true);
+   d_bias_h_prime = d_bias_h_prime.sum(0, true);
+   d_zeta = (d_zeta.sum(0, true)).sum(1, true);
+   d_nu = (d_nu.sum(0, true)).sum(1, true);
+
+   return {d_input, d_w, d_u, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h};
}
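+
+ // Illustrative host-side call (a sketch; shapes inferred from the kernels,
+ // not stated in this file): input [batch, in_dim], w [state, in_dim],
+ // u [state, state], biases [1, state], zeta/nu [1, 1], old_h [batch, state].
+ //   auto outs = fastgrnn_cuda_forward(input, w, u, bias_z, bias_h_prime,
+ //                                     zeta, nu, old_h);
+ //   auto new_h = outs[0], z = outs[1], h_prime = outs[2];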