@@ -64,8 +64,20 @@ pure module subroutine backward_2d(self, previous, gradient)
6464 real , intent (in ) :: gradient(:,:)
6565
6666 ! Backward pass from a 2-d layer downstream currently implemented
67- ! only for dense and flatten layers
68- ! CURRENTLY NO LAYERS, tbd: pull/197 and pull/199
67+ ! only for input2d and linear2d layers
68+ select type (this_layer = > self % p)
69+
70+ type is (linear2d_layer)
71+
72+ select type (prev_layer = > previous % p)
73+ type is (input2d_layer)
74+ call this_layer % backward(prev_layer % output, gradient)
75+ type is (linear2d_layer)
76+ call this_layer % backward(prev_layer % output, gradient)
77+ end select
78+
79+ end select
80+
6981 end subroutine backward_2d
7082
7183
@@ -119,12 +131,6 @@ pure module subroutine backward_3d(self, previous, gradient)
119131 call this_layer % backward(prev_layer % output, gradient)
120132 end select
121133
122- ! type is(linear2d_layer)
123- ! select type(prev_layer => previous % p)
124- ! type is(input3d_layer)
125- ! call this_layer % backward(prev_layer % output, gradient)
126- ! end select
127-
128134 end select
129135
130136 end subroutine backward_3d
@@ -207,13 +213,15 @@ pure module subroutine forward(self, input)
207213 call this_layer % forward(prev_layer % output)
208214 end select
209215
210- ! type is(linear2d_layer)
211- ! select type(prev_layer => input % p)
212- ! type is(input3d_layer)
213- ! call this_layer % forward(prev_layer % output)
214- ! type is(linear2d_layer)
215- ! call this_layer % forward(prev_layer % output)
216- ! end select
216+ type is (linear2d_layer)
217+
218+ ! Upstream layers permitted: input2d, linear2d
219+ select type (prev_layer = > input % p)
220+ type is (input2d_layer)
221+ call this_layer % forward(prev_layer % output)
222+ type is (linear2d_layer)
223+ call this_layer % forward(prev_layer % output)
224+ end select
217225
218226 end select
219227
@@ -250,8 +258,10 @@ pure module subroutine get_output_2d(self, output)
250258
251259 type is (input2d_layer)
252260 allocate (output, source= this_layer % output)
261+ type is (linear2d_layer)
262+ allocate (output, source= this_layer % output)
253263 class default
254- error stop ' 1 -d output can only be read from an input1d, dense, or flatten layer.'
264+ error stop ' 2 -d output can only be read from an input2d or linear2d layer.'
255265
256266 end select
257267
@@ -347,8 +357,8 @@ elemental module function get_num_params(self) result(num_params)
347357 num_params = 0
348358 type is (reshape3d_layer)
349359 num_params = 0
350- ! type is (linear2d_layer)
351- ! num_params = this_layer % get_num_params()
360+ type is (linear2d_layer)
361+ num_params = this_layer % get_num_params()
352362 class default
353363 error stop ' Unknown layer type.'
354364 end select
@@ -376,8 +386,8 @@ module function get_params(self) result(params)
376386 ! No parameters to get.
377387 type is (reshape3d_layer)
378388 ! No parameters to get.
379- ! type is (linear2d_layer)
380- ! params = this_layer % get_params()
389+ type is (linear2d_layer)
390+ params = this_layer % get_params()
381391 class default
382392 error stop ' Unknown layer type.'
383393 end select
@@ -405,8 +415,8 @@ module function get_gradients(self) result(gradients)
405415 ! No gradients to get.
406416 type is (reshape3d_layer)
407417 ! No gradients to get.
408- ! type is (linear2d_layer)
409- ! gradients = this_layer % get_gradients()
418+ type is (linear2d_layer)
419+ gradients = this_layer % get_gradients()
410420 class default
411421 error stop ' Unknown layer type.'
412422 end select
@@ -454,6 +464,9 @@ module subroutine set_params(self, params)
454464 type is (conv2d_layer)
455465 call this_layer % set_params(params)
456466
467+ type is (linear2d_layer)
468+ call this_layer % set_params(params)
469+
457470 type is (maxpool2d_layer)
458471 ! No parameters to set.
459472 write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
@@ -472,8 +485,6 @@ module subroutine set_params(self, params)
472485 class default
473486 error stop ' Unknown layer type.'
474487
475- ! type is (linear2d_layer)
476- ! call this_layer % set_params(params)
477488 end select
478489
479490 end subroutine set_params
0 commit comments