44
55std::vector<torch::Tensor> fastgrnn_cuda_forward (
66 torch::Tensor input,
7- torch::Tensor w ,
8- torch::Tensor u ,
7+ torch::Tensor W ,
8+ torch::Tensor U ,
99 torch::Tensor bias_gate,
1010 torch::Tensor bias_update,
1111 torch::Tensor zeta,
1212 torch::Tensor nu,
13- torch::Tensor old_h);
13+ torch::Tensor old_h,
14+ int z_non_linearity);
1415
1516std::vector<torch::Tensor> fastgrnn_cuda_backward (
1617 torch::Tensor grad_h,
1718 torch::Tensor input,
1819 torch::Tensor old_h,
1920 torch::Tensor zeta,
2021 torch::Tensor nu,
21- torch::Tensor w,
22- torch::Tensor u,
22+ torch::Tensor W,
23+ torch::Tensor U,
24+ int z_non_linearity,
2325 torch::Tensor z,
2426 torch::Tensor h_prime);
2527
@@ -29,23 +31,24 @@ std::vector<torch::Tensor> fastgrnn_cuda_backward(
2931
3032std::vector<torch::Tensor> fastgrnn_forward (
3133 torch::Tensor input,
32- torch::Tensor w ,
33- torch::Tensor u ,
34+ torch::Tensor W ,
35+ torch::Tensor U ,
3436 torch::Tensor bias_gate,
3537 torch::Tensor bias_update,
3638 torch::Tensor zeta,
3739 torch::Tensor nu,
38- torch::Tensor old_h) {
40+ torch::Tensor old_h,
41+ int z_non_linearity) {
3942 CHECK_INPUT (input);
40- CHECK_INPUT (w );
41- CHECK_INPUT (u );
43+ CHECK_INPUT (W );
44+ CHECK_INPUT (U );
4245 CHECK_INPUT (bias_gate);
4346 CHECK_INPUT (bias_update);
4447 CHECK_INPUT (zeta);
4548 CHECK_INPUT (nu);
4649 CHECK_INPUT (old_h);
4750
48- return fastgrnn_cuda_forward (input, w, u , bias_gate, bias_update, zeta, nu, old_h);
51+ return fastgrnn_cuda_forward (input, W, U , bias_gate, bias_update, zeta, nu, old_h, z_non_linearity );
4952}
5053
5154std::vector<torch::Tensor> fastgrnn_backward (
@@ -54,21 +57,22 @@ std::vector<torch::Tensor> fastgrnn_backward(
5457 torch::Tensor old_h,
5558 torch::Tensor zeta,
5659 torch::Tensor nu,
57- torch::Tensor w ,
58- torch::Tensor u ,
60+ torch::Tensor W ,
61+ torch::Tensor U ,
5962 torch::Tensor z,
60- torch::Tensor h_prime) {
63+ torch::Tensor h_prime,
64+ int z_non_linearity) {
6165 CHECK_INPUT (grad_h);
6266 CHECK_INPUT (input);
6367 CHECK_INPUT (old_h);
6468 CHECK_INPUT (zeta);
6569 CHECK_INPUT (nu);
6670 CHECK_INPUT (z);
6771 CHECK_INPUT (h_prime);
68- CHECK_INPUT (w );
69- CHECK_INPUT (u );
72+ CHECK_INPUT (W );
73+ CHECK_INPUT (U );
7074
71- return fastgrnn_cuda_backward (grad_h, input, old_h, zeta, nu, w, u , z, h_prime);
75+ return fastgrnn_cuda_backward (grad_h, input, old_h, zeta, nu, W, U, z_non_linearity , z, h_prime);
7276}
7377
7478PYBIND11_MODULE (TORCH_EXTENSION_NAME, m) {
0 commit comments