33 implicit none
44contains
55 module function linear2d_layer_cons (&
6- sequence_length , in_features , out_features , batch_size &
6+ sequence_length , in_features , out_features&
77 ) result(res)
8- integer , intent (in ) :: batch_size, sequence_length, in_features, out_features
8+ integer , intent (in ) :: sequence_length, in_features, out_features
99 type (linear2d_layer) :: res
1010
1111 res % in_features = in_features
1212 res % out_features = out_features
1313 res % sequence_length = sequence_length
14- res % batch_size = batch_size
1514 end function linear2d_layer_cons
1615
1716 module subroutine init (self , input_shape )
1817 class(linear2d_layer), intent (in out ) :: self
1918 integer , intent (in ) :: input_shape(:)
2019
21- allocate (self % output(self % sequence_length, self % out_features, self % batch_size ))
22- allocate (self % gradient(self % sequence_length, self % in_features, self % batch_size ))
20+ allocate (self % output(self % sequence_length, self % out_features))
21+ allocate (self % gradient(self % sequence_length, self % in_features))
2322
2423 allocate (self % weights(self % in_features, self % out_features))
2524 self % weights = 0.1
@@ -35,30 +34,26 @@ end subroutine init
3534
3635 pure module subroutine forward(self, input)
3736 class(linear2d_layer), intent (in out ) :: self
38- real , intent (in ) :: input(:, :, : )
39- integer :: i, j
37+ real , intent (in ) :: input(:, :)
38+ integer :: i
4039
41- do concurrent(i = 1 : self % batch_size)
42- self % output(:, :, i) = matmul (input(:, :, i), self % weights)
43- end do
44- do concurrent(i = 1 : self % batch_size, j = 1 : self % sequence_length)
45- self % output(j, :, i) = self % output(j, :, i) + self % biases
40+ self % output(:, :) = matmul (input(:, :), self % weights)
41+ do concurrent(i = 1 : self % sequence_length)
42+ self % output(i, :) = self % output(i, :) + self % biases
4643 end do
4744 end subroutine forward
4845
4946 pure module subroutine backward(self, input, gradient)
5047 class(linear2d_layer), intent (in out ) :: self
51- real , intent (in ) :: input(:, :, : )
52- real , intent (in ) :: gradient(:, :, : )
48+ real , intent (in ) :: input(:, :)
49+ real , intent (in ) :: gradient(:, :)
5350 real :: db(self % out_features)
5451 real :: dw(self % in_features, self % out_features)
5552 integer :: i
5653
57- do concurrent(i = 1 : self % batch_size)
58- self % dw = self % dw + matmul (transpose (input(:, :, i)), gradient(:, :, i))
59- self % db = self % db + sum (gradient(:, :, i), 1 )
60- self % gradient(:, :, i) = matmul (gradient(:, :, i), transpose (self % weights))
61- end do
54+ self % dw = self % dw + matmul (transpose (input(:, :)), gradient(:, :))
55+ self % db = self % db + sum (gradient(:, :), 1 )
56+ self % gradient(:, :) = matmul (gradient(:, :), transpose (self % weights))
6257 end subroutine backward
6358
6459 pure module function get_num_params(self) result(num_params)
0 commit comments