Skip to content

Commit dd1297e

Browse files
committed
linear2d_layer: forgot a file
1 parent b5a600a commit dd1297e

File tree

1 file changed

+14
-19
lines changed

1 file changed

+14
-19
lines changed

src/nf/nf_linear2d_layer_submodule.f90

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,22 @@
33
implicit none
44
contains
55
module function linear2d_layer_cons(&
6-
sequence_length, in_features, out_features, batch_size&
6+
sequence_length, in_features, out_features&
77
) result(res)
8-
integer, intent(in) :: batch_size, sequence_length, in_features, out_features
8+
integer, intent(in) :: sequence_length, in_features, out_features
99
type(linear2d_layer) :: res
1010

1111
res % in_features = in_features
1212
res % out_features = out_features
1313
res % sequence_length = sequence_length
14-
res % batch_size = batch_size
1514
end function linear2d_layer_cons
1615

1716
module subroutine init(self, input_shape)
1817
class(linear2d_layer), intent(in out) :: self
1918
integer, intent(in) :: input_shape(:)
2019

21-
allocate(self % output(self % sequence_length, self % out_features, self % batch_size))
22-
allocate(self % gradient(self % sequence_length, self % in_features, self % batch_size))
20+
allocate(self % output(self % sequence_length, self % out_features))
21+
allocate(self % gradient(self % sequence_length, self % in_features))
2322

2423
allocate(self % weights(self % in_features, self % out_features))
2524
self % weights = 0.1
@@ -35,30 +34,26 @@ end subroutine init
3534

3635
pure module subroutine forward(self, input)
3736
class(linear2d_layer), intent(in out) :: self
38-
real, intent(in) :: input(:, :, :)
39-
integer :: i, j
37+
real, intent(in) :: input(:, :)
38+
integer :: i
4039

41-
do concurrent(i = 1: self % batch_size)
42-
self % output(:, :, i) = matmul(input(:, :, i), self % weights)
43-
end do
44-
do concurrent(i = 1: self % batch_size, j = 1: self % sequence_length)
45-
self % output(j, :, i) = self % output(j, :, i) + self % biases
40+
self % output(:, :) = matmul(input(:, :), self % weights)
41+
do concurrent(i = 1: self % sequence_length)
42+
self % output(i, :) = self % output(i, :) + self % biases
4643
end do
4744
end subroutine forward
4845

4946
pure module subroutine backward(self, input, gradient)
5047
class(linear2d_layer), intent(in out) :: self
51-
real, intent(in) :: input(:, :, :)
52-
real, intent(in) :: gradient(:, :, :)
48+
real, intent(in) :: input(:, :)
49+
real, intent(in) :: gradient(:, :)
5350
real :: db(self % out_features)
5451
real :: dw(self % in_features, self % out_features)
5552
integer :: i
5653

57-
do concurrent(i = 1: self % batch_size)
58-
self % dw = self % dw + matmul(transpose(input(:, :, i)), gradient(:, :, i))
59-
self % db = self % db + sum(gradient(:, :, i), 1)
60-
self % gradient(:, :, i) = matmul(gradient(:, :, i), transpose(self % weights))
61-
end do
54+
self % dw = self % dw + matmul(transpose(input(:, :)), gradient(:, :))
55+
self % db = self % db + sum(gradient(:, :), 1)
56+
self % gradient(:, :) = matmul(gradient(:, :), transpose(self % weights))
6257
end subroutine backward
6358

6459
pure module function get_num_params(self) result(num_params)

0 commit comments

Comments
 (0)