@@ -31,9 +31,10 @@ module nf_linear2d_layer
3131 end type linear2d_layer
3232
3333 interface linear2d_layer
34- module function linear2d_layer_cons (in_features , out_features ) &
35- result(res)
36- integer , intent (in ) :: in_features, out_features
34+ module function linear2d_layer_cons (&
35+ batch_size , sequence_length , in_features , out_features&
36+ ) result(res)
37+ integer , intent (in ) :: batch_size, sequence_length, in_features, out_features
3738 type (linear2d_layer) :: res
3839 end function linear2d_layer_cons
3940 end interface linear2d_layer
@@ -44,133 +45,35 @@ pure module subroutine forward(self, input)
4445 real , intent (in ) :: input(:, :, :)
4546 end subroutine forward
4647
48+ pure module subroutine backward(self, input, gradient)
49+ class(linear2d_layer), intent (in out ) :: self
50+ real , intent (in ) :: input(:, :, :)
51+ real , intent (in ) :: gradient(:, :, :)
52+ end subroutine backward
53+
4754 module subroutine init (self , input_shape )
4855 class(linear2d_layer), intent (in out ) :: self
4956 integer , intent (in ) :: input_shape(:)
5057 end subroutine init
51- end interface
52-
53- contains
54- module function linear2d_layer_cons (&
55- batch_size , sequence_length , in_features , out_features&
56- ) result(res)
57- integer , intent (in ) :: batch_size, sequence_length, in_features, out_features
58- type (linear2d_layer) :: res
59-
60- res % in_features = in_features
61- res % out_features = out_features
62- res % sequence_length = sequence_length
63- res % batch_size = batch_size
64- end function linear2d_layer_cons
65-
66- module subroutine init (self , input_shape )
67- class(linear2d_layer), intent (in out ) :: self
68- integer , intent (in ) :: input_shape(:)
69-
70- allocate (self % output(self % batch_size, self % sequence_length, self % out_features))
71- allocate (self % gradient(self % batch_size, self % sequence_length, self % in_features))
72-
73- allocate (self % weights(self % in_features, self % out_features))
74- self % weights = 0.1
75-
76- allocate (self % biases(self % out_features))
77- self% biases = 0.11
78-
79- allocate (self % dw(self % in_features, self % out_features))
80- self % dw = 0.0
81- allocate (self % db(self % out_features))
82- self % db = 0.0
83- end subroutine init
84-
85- pure module subroutine forward(self, input)
86- class(linear2d_layer), intent (in out ) :: self
87- real , intent (in ) :: input(:, :, :)
88- integer :: i, j
89-
90- do concurrent(i = 1 : self % batch_size)
91- self % output(i, :, :) = matmul (input(i, :, :), self % weights)
92- end do
93- do concurrent(i = 1 : self % batch_size, j = 1 : self % sequence_length)
94- self % output(i, j, :) = self % output(i, j, :) + self % biases
95- end do
96- end subroutine forward
97-
98- pure module subroutine backward(self, input, gradient)
99- class(linear2d_layer), intent (in out ) :: self
100- real , intent (in ) :: input(:, :, :)
101- real , intent (in ) :: gradient(:, :, :)
102- real :: db(self % out_features)
103- real :: dw(self % in_features, self % out_features)
104- integer :: i
105-
106- do concurrent(i = 1 : self % batch_size)
107- self % dw = self % dw + matmul (transpose (input(i, :, :)), gradient(i, :, :))
108- self % db = self % db + sum (gradient(i, :, :), 1 )
109- self % gradient(i, :, :) = matmul (gradient(i, :, :), transpose (self % weights))
110- end do
111- end subroutine backward
112-
113- pure module function get_num_params(self) result(num_params)
114- class(linear2d_layer), intent (in ) :: self
115- integer :: num_params
116-
117- ! Number of weigths times number of biases
118- num_params = self % in_features * self % out_features + self % out_features
119-
120- end function get_num_params
12158
59+ pure module function get_num_params(self) result(num_params)
60+ class(linear2d_layer), intent (in ) :: self
61+ integer :: num_params
62+ end function get_num_params
12263
123- module function get_params (self ) result(params)
124- class(linear2d_layer), intent (in ), target :: self
125- real , allocatable :: params(:)
64+ module function get_params (self ) result(params)
65+ class(linear2d_layer), intent (in ), target :: self
66+ real , allocatable :: params(:)
67+ end function get_params
12668
127- real , pointer :: w_(:) = > null ()
69+ module function get_gradients (self ) result(gradients)
70+ class(linear2d_layer), intent (in ), target :: self
71+ real , allocatable :: gradients(:)
72+ end function get_gradients
12873
129- w_(1 :size (self % weights)) = > self % weights
130-
131- params = [ &
132- w_, &
133- self % biases &
134- ]
135-
136- end function get_params
137-
138-
139- module function get_gradients (self ) result(gradients)
140- class(linear2d_layer), intent (in ), target :: self
141- real , allocatable :: gradients(:)
142-
143- real , pointer :: dw_(:) = > null ()
144-
145- dw_(1 :size (self % dw)) = > self % dw
146-
147- gradients = [ &
148- dw_, &
149- self % db &
150- ]
151-
152- end function get_gradients
153-
154-
155- module subroutine set_params (self , params )
156- class(linear2d_layer), intent (in out ) :: self
157- real , intent (in ), target :: params(:)
158-
159- real , pointer :: p_(:,:) = > null ()
160-
161- ! check if the number of parameters is correct
162- if (size (params) /= self % get_num_params()) then
163- error stop ' Error: number of parameters does not match'
164- end if
165-
166- associate(n = > self % in_features * self % out_features)
167- ! reshape the weights
168- p_(1 :self % in_features, 1 :self % out_features) = > params(1 : n)
169- self % weights = p_
170-
171- ! reshape the biases
172- self % biases = params(n + 1 : n + self % out_features)
173- end associate
174-
175- end subroutine set_params
74+ module subroutine set_params (self , params )
75+ class(linear2d_layer), intent (in out ) :: self
76+ real , intent (in ), target :: params(:)
77+ end subroutine set_params
78+ end interface
17679end module nf_linear2d_layer
0 commit comments