Skip to content

Commit 8aa8278

Browse files
committed
update tests for linear2d_layer
1 parent 06afae3 commit 8aa8278

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

test/test_linear2d_layer.f90

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ program test_linear2d_layer
88
[0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2,&
99
0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2],&
1010
[2, 3, 4]) ! first batch are 0.1, second 0.2
11-
real :: sample_gradient(2, 3, 1) = reshape([2., 2., 2., 2., 2., 2.], [2, 3, 1])
11+
real :: sample_gradient(2, 3, 1) = reshape([2., 2., 2., 3., 3., 3.], [2, 3, 1])
1212
type(linear2d_layer) :: linear
1313

1414
linear = linear2d_layer(batch_size=2, sequence_length=3, in_features=4, out_features=1)
@@ -54,12 +54,18 @@ subroutine test_linear2d_layer_backward(linear, ok, input, gradient)
5454
real :: expected_gradient_shape(3) = [2, 3, 4]
5555
real :: expected_dw_shape(2) = [4, 1]
5656
real :: expected_db_shape(1) = [1]
57-
real :: expected_gradient_flat(24)
57+
real :: expected_gradient_flat(24) = [&
58+
0.200000003, 0.200000003, 0.200000003, 0.300000012,&
59+
0.300000012, 0.300000012, 0.200000003, 0.200000003,&
60+
0.200000003, 0.300000012, 0.300000012, 0.300000012,&
61+
0.200000003, 0.200000003, 0.200000003, 0.300000012,&
62+
0.300000012, 0.300000012, 0.200000003, 0.200000003,&
63+
0.200000003, 0.300000012, 0.300000012, 0.300000012&
64+
]
5865
real :: expected_dw_flat(4)
59-
real :: expected_db(1) = [12.0]
66+
real :: expected_db(1) = [15.0]
6067

61-
expected_gradient_flat = 0.200000003
62-
expected_dw_flat = 1.80000007
68+
expected_dw_flat = 2.29999995
6369

6470
call linear % backward(input, gradient)
6571

0 commit comments

Comments
 (0)