44 use nf_conv2d_layer, only: conv2d_layer
55 use nf_dense_layer, only: dense_layer
66 use nf_flatten_layer, only: flatten_layer
7+ use nf_flatten2d_layer, only: flatten2d_layer
78 use nf_input1d_layer, only: input1d_layer
89 use nf_input2d_layer, only: input2d_layer
910 use nf_input3d_layer, only: input3d_layer
@@ -48,8 +49,16 @@ pure module subroutine backward_1d(self, previous, gradient)
4849 call this_layer % backward(prev_layer % output, gradient)
4950 type is (maxpool2d_layer)
5051 call this_layer % backward(prev_layer % output, gradient)
51- ! type is(linear2d_layer)
52- ! call this_layer % backward(prev_layer % output, gradient)
52+ end select
53+
54+ type is (flatten2d_layer)
55+
56+ ! Upstream layers permitted: linear2d_layer
57+ select type (prev_layer = > previous % p)
58+ type is (linear2d_layer)
59+ call this_layer % backward(prev_layer % output, gradient)
60+ type is (input2d_layer)
61+ call this_layer % backward(prev_layer % output, gradient)
5362 end select
5463
5564 end select
@@ -63,8 +72,6 @@ pure module subroutine backward_2d(self, previous, gradient)
6372 class(layer), intent (in ) :: previous
6473 real , intent (in ) :: gradient(:,:)
6574
66- ! Backward pass from a 2-d layer downstream currently implemented
67- ! only for input2d and linear2d layers
6875 select type (this_layer = > self % p)
6976
7077 type is (linear2d_layer)
@@ -197,8 +204,14 @@ pure module subroutine forward(self, input)
197204 call this_layer % forward(prev_layer % output)
198205 type is (reshape3d_layer)
199206 call this_layer % forward(prev_layer % output)
200- ! type is(linear2d_layer)
201- ! call this_layer % forward(prev_layer % output)
207+ end select
208+
209+ type is (flatten2d_layer)
210+ select type (prev_layer = > input % p)
211+ type is (linear2d_layer)
212+ call this_layer % forward(prev_layer % output)
213+ type is (input2d_layer)
214+ call this_layer % forward(prev_layer % output)
202215 end select
203216
204217 type is (reshape3d_layer)
@@ -241,6 +254,8 @@ pure module subroutine get_output_1d(self, output)
241254 allocate (output, source= this_layer % output)
242255 type is (flatten_layer)
243256 allocate (output, source= this_layer % output)
257+ type is (flatten2d_layer)
258+ allocate (output, source= this_layer % output)
244259 class default
245260 error stop ' 1-d output can only be read from an input1d, dense, or flatten layer.'
246261
@@ -312,9 +327,11 @@ impure elemental module subroutine init(self, input)
312327 self % layer_shape = shape (this_layer % output)
313328 type is (flatten_layer)
314329 self % layer_shape = shape (this_layer % output)
330+ type is (flatten2d_layer)
331+ self % layer_shape = shape (this_layer % output)
315332 end select
316333
317- self % input_layer_shape = input % layer_shape
334+ self % input_layer_shape = input % layer_shape
318335 self % initialized = .true.
319336
320337 end subroutine init
@@ -355,6 +372,8 @@ elemental module function get_num_params(self) result(num_params)
355372 num_params = 0
356373 type is (flatten_layer)
357374 num_params = 0
375+ type is (flatten2d_layer)
376+ num_params = 0
358377 type is (reshape3d_layer)
359378 num_params = 0
360379 type is (linear2d_layer)
@@ -384,6 +403,8 @@ module function get_params(self) result(params)
384403 ! No parameters to get.
385404 type is (flatten_layer)
386405 ! No parameters to get.
406+ type is (flatten2d_layer)
407+ ! No parameters to get.
387408 type is (reshape3d_layer)
388409 ! No parameters to get.
389410 type is (linear2d_layer)
@@ -412,6 +433,8 @@ module function get_gradients(self) result(gradients)
412433 type is (maxpool2d_layer)
413434 ! No gradients to get.
414435 type is (flatten_layer)
436+ ! No parameters to get.
437+ type is (flatten2d_layer)
415438 ! No gradients to get.
416439 type is (reshape3d_layer)
417440 ! No gradients to get.
@@ -477,6 +500,11 @@ module subroutine set_params(self, params)
477500 write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
478501 // ' on a zero-parameter layer; nothing to do.'
479502
503+ type is (flatten2d_layer)
504+ ! No parameters to set.
505+ write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
506+ // ' on a zero-parameter layer; nothing to do.'
507+
480508 type is (reshape3d_layer)
481509 ! No parameters to set.
482510 write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
0 commit comments