@@ -62,8 +62,20 @@ pure module subroutine backward_2d(self, previous, gradient)
6262 real , intent (in ) :: gradient(:,:)
6363
6464 ! Backward pass from a 2-d layer downstream currently implemented
65- ! only for dense and flatten layers
66- ! CURRENTLY NO LAYERS, tbd: pull/197 and pull/199
65+ ! only for input2d and linear2d layers
66+ select type (this_layer = > self % p)
67+
68+ type is (linear2d_layer)
69+
70+ select type (prev_layer = > previous % p)
71+ type is (input2d_layer)
72+ call this_layer % backward(prev_layer % output, gradient)
73+ type is (linear2d_layer)
74+ call this_layer % backward(prev_layer % output, gradient)
75+ end select
76+
77+ end select
78+
6779 end subroutine backward_2d
6880
6981
@@ -117,12 +129,6 @@ pure module subroutine backward_3d(self, previous, gradient)
117129 call this_layer % backward(prev_layer % output, gradient)
118130 end select
119131
120- ! type is(linear2d_layer)
121- ! select type(prev_layer => previous % p)
122- ! type is(input3d_layer)
123- ! call this_layer % backward(prev_layer % output, gradient)
124- ! end select
125-
126132 end select
127133
128134 end subroutine backward_3d
@@ -203,13 +209,15 @@ pure module subroutine forward(self, input)
203209 call this_layer % forward(prev_layer % output)
204210 end select
205211
206- ! type is(linear2d_layer)
207- ! select type(prev_layer => input % p)
208- ! type is(input3d_layer)
209- ! call this_layer % forward(prev_layer % output)
210- ! type is(linear2d_layer)
211- ! call this_layer % forward(prev_layer % output)
212- ! end select
212+ type is (linear2d_layer)
213+
214+ ! Upstream layers permitted: input2d, linear2d
215+ select type (prev_layer = > input % p)
216+ type is (input2d_layer)
217+ call this_layer % forward(prev_layer % output)
218+ type is (linear2d_layer)
219+ call this_layer % forward(prev_layer % output)
220+ end select
213221
214222 end select
215223
@@ -246,8 +254,10 @@ pure module subroutine get_output_2d(self, output)
246254
247255 type is (input2d_layer)
248256 allocate (output, source= this_layer % output)
257+ type is (linear2d_layer)
258+ allocate (output, source= this_layer % output)
249259 class default
250- error stop ' 1 -d output can only be read from an input1d, dense, or flatten layer.'
260+ error stop ' 2 -d output can only be read from an input2d or linear2d layer.'
251261
252262 end select
253263
@@ -343,8 +353,8 @@ elemental module function get_num_params(self) result(num_params)
343353 num_params = 0
344354 type is (reshape3d_layer)
345355 num_params = 0
346- ! type is (linear2d_layer)
347- ! num_params = this_layer % get_num_params()
356+ type is (linear2d_layer)
357+ num_params = this_layer % get_num_params()
348358 class default
349359 error stop ' Unknown layer type.'
350360 end select
@@ -372,8 +382,8 @@ module function get_params(self) result(params)
372382 ! No parameters to get.
373383 type is (reshape3d_layer)
374384 ! No parameters to get.
375- ! type is (linear2d_layer)
376- ! params = this_layer % get_params()
385+ type is (linear2d_layer)
386+ params = this_layer % get_params()
377387 class default
378388 error stop ' Unknown layer type.'
379389 end select
@@ -401,8 +411,8 @@ module function get_gradients(self) result(gradients)
401411 ! No gradients to get.
402412 type is (reshape3d_layer)
403413 ! No gradients to get.
404- ! type is (linear2d_layer)
405- ! gradients = this_layer % get_gradients()
414+ type is (linear2d_layer)
415+ gradients = this_layer % get_gradients()
406416 class default
407417 error stop ' Unknown layer type.'
408418 end select
@@ -450,6 +460,9 @@ module subroutine set_params(self, params)
450460 type is (conv2d_layer)
451461 call this_layer % set_params(params)
452462
463+ type is (linear2d_layer)
464+ call this_layer % set_params(params)
465+
453466 type is (maxpool2d_layer)
454467 ! No parameters to set.
455468 write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
@@ -468,8 +481,6 @@ module subroutine set_params(self, params)
468481 class default
469482 error stop ' Unknown layer type.'
470483
471- ! type is (linear2d_layer)
472- ! call this_layer % set_params(params)
473484 end select
474485
475486 end subroutine set_params
0 commit comments