Skip to content

Commit b5a600a

Browse files
committed
linear2d_layer: make it 2d
1 parent 07750db commit b5a600a

File tree

2 files changed: +30 additions, −36 deletions

src/nf/nf_linear2d_layer.f90

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ module nf_linear2d_layer
1313

1414
real, allocatable :: weights(:, :)
1515
real, allocatable :: biases(:)
16-
real, allocatable :: output(:, :, :)
17-
real, allocatable :: gradient(:, :, :) ! input gradient
16+
real, allocatable :: output(:, :)
17+
real, allocatable :: gradient(:, :) ! input gradient
1818
real, allocatable :: dw(:, :) ! weight gradients
1919
real, allocatable :: db(:) ! bias gradients
2020

@@ -32,23 +32,23 @@ module nf_linear2d_layer
3232

3333
interface linear2d_layer
3434
module function linear2d_layer_cons(&
35-
sequence_length, in_features, out_features, batch_size&
35+
sequence_length, in_features, out_features&
3636
) result(res)
37-
integer, intent(in) :: batch_size, sequence_length, in_features, out_features
37+
integer, intent(in) :: sequence_length, in_features, out_features
3838
type(linear2d_layer) :: res
3939
end function linear2d_layer_cons
4040
end interface linear2d_layer
4141

4242
interface
4343
pure module subroutine forward(self, input)
4444
class(linear2d_layer), intent(in out) :: self
45-
real, intent(in) :: input(:, :, :)
45+
real, intent(in) :: input(:, :)
4646
end subroutine forward
4747

4848
pure module subroutine backward(self, input, gradient)
4949
class(linear2d_layer), intent(in out) :: self
50-
real, intent(in) :: input(:, :, :)
51-
real, intent(in) :: gradient(:, :, :)
50+
real, intent(in) :: input(:, :)
51+
real, intent(in) :: gradient(:, :)
5252
end subroutine backward
5353

5454
module subroutine init(self, input_shape)

test/test_linear2d_layer.f90

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@ program test_linear2d_layer
44
implicit none
55

66
logical :: ok = .true.
7-
real :: sample_input(3, 4, 2) = reshape(&
8-
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,&
9-
0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2],&
10-
[3, 4, 2]) ! first batch are 0.1, second 0.2
11-
real :: sample_gradient(3, 1, 2) = reshape([2., 2., 2., 3., 3., 3.], [3, 1, 2])
7+
real :: sample_input(3, 4) = reshape(&
8+
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2],&
9+
[3, 4]) ! first batch are 0.1, second 0.2
10+
real :: sample_gradient(3, 1) = reshape([2., 2., 3.], [3, 1])
1211
type(linear2d_layer) :: linear
1312

14-
linear = linear2d_layer(sequence_length=3, in_features=4, out_features=1, batch_size=2)
13+
linear = linear2d_layer(sequence_length=3, in_features=4, out_features=1)
1514
call linear % init([4])
1615

1716
call test_linear2d_layer_forward(linear, ok, sample_input)
@@ -22,11 +21,11 @@ program test_linear2d_layer
2221
subroutine test_linear2d_layer_forward(linear, ok, input)
2322
type(linear2d_layer), intent(in out) :: linear
2423
logical, intent(in out) :: ok
25-
real, intent(in) :: input(3, 4, 2)
26-
real :: output_shape(3)
27-
real :: output_flat(6)
28-
real :: expected_shape(3) = [3, 1, 2]
29-
real :: expected_output_flat(6) = [0.15, 0.15, 0.15, 0.19, 0.19, 0.19]
24+
real, intent(in) :: input(3, 4)
25+
real :: output_shape(2)
26+
real :: output_flat(3)
27+
real :: expected_shape(2) = [3, 1]
28+
real :: expected_output_flat(3) = [0.17, 0.17, 0.17]
3029

3130
call linear % forward(input)
3231

@@ -45,28 +44,23 @@ end subroutine test_linear2d_layer_forward
4544
subroutine test_linear2d_layer_backward(linear, ok, input, gradient)
4645
type(linear2d_layer), intent(in out) :: linear
4746
logical, intent(in out) :: ok
48-
real, intent(in) :: input(3, 4, 2)
49-
real, intent(in) :: gradient(3, 1, 2)
50-
real :: gradient_shape(3)
47+
real, intent(in) :: input(3, 4)
48+
real, intent(in) :: gradient(3, 1)
49+
real :: gradient_shape(2)
5150
real :: dw_shape(2)
5251
real :: db_shape(1)
53-
real :: gradient_flat(24)
52+
real :: gradient_flat(12)
5453
real :: dw_flat(4)
55-
real :: expected_gradient_shape(3) = [3, 4, 2]
54+
real :: expected_gradient_shape(2) = [3, 4]
5655
real :: expected_dw_shape(2) = [4, 1]
5756
real :: expected_db_shape(1) = [1]
58-
real :: expected_gradient_flat(24) = [&
59-
0.200000003, 0.200000003, 0.200000003, 0.200000003,&
60-
0.200000003, 0.200000003, 0.200000003, 0.200000003,&
61-
0.200000003, 0.200000003, 0.200000003, 0.200000003,&
62-
0.300000012, 0.300000012, 0.300000012, 0.300000012,&
63-
0.300000012, 0.300000012, 0.300000012, 0.300000012,&
64-
0.300000012, 0.300000012, 0.300000012, 0.300000012&
57+
real :: expected_gradient_flat(12) = [&
58+
0.2, 0.2, 0.3, 0.2,&
59+
0.2, 0.3, 0.2, 0.2,&
60+
0.3, 0.2, 0.2, 0.3&
6561
]
66-
real :: expected_dw_flat(4)
67-
real :: expected_db(1) = [15.0]
68-
69-
expected_dw_flat = 2.40000010
62+
real :: expected_dw_flat(4) = [0.7, 0.7, 1.4, 1.4]
63+
real :: expected_db(1) = [7]
7064

7165
call linear % backward(input, gradient)
7266

@@ -104,8 +98,8 @@ end subroutine test_linear2d_layer_backward
10498

10599
subroutine test_linear2d_layer_gradient_updates(ok)
106100
logical, intent(in out) :: ok
107-
real :: input(3, 4, 1) = reshape([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12], [3, 4, 1])
108-
real :: gradient(3, 2, 1) = reshape([0.0, 10., 0.2, 3., 0.4, 1.], [3, 2, 1])
101+
real :: input(3, 4) = reshape([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12], [3, 4])
102+
real :: gradient(3, 2) = reshape([0.0, 10., 0.2, 3., 0.4, 1.], [3, 2])
109103
type(linear2d_layer) :: linear
110104

111105
integer :: num_parameters

0 commit comments

Comments
 (0)