Skip to content

Commit d01a174

Browse files
committed
linear2d_layer: make linear2d layer work with input2d and flatten2d
1 parent 1bec531 commit d01a174

File tree

3 files changed

+40
-9
lines changed

3 files changed

+40
-9
lines changed

src/nf.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module nf
33
use nf_datasets_mnist, only: label_digits, load_mnist
44
use nf_layer, only: layer
55
use nf_layer_constructors, only: &
6-
conv2d, dense, flatten, input, maxpool2d, reshape, linear2d
6+
conv2d, dense, flatten, flatten2d, input, maxpool2d, reshape, linear2d
77
use nf_loss, only: mse, quadratic
88
use nf_metrics, only: corr, maxabs
99
use nf_network, only: network

src/nf/nf_layer_submodule.f90

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
use nf_conv2d_layer, only: conv2d_layer
55
use nf_dense_layer, only: dense_layer
66
use nf_flatten_layer, only: flatten_layer
7+
use nf_flatten2d_layer, only: flatten2d_layer
78
use nf_input1d_layer, only: input1d_layer
89
use nf_input2d_layer, only: input2d_layer
910
use nf_input3d_layer, only: input3d_layer
@@ -48,8 +49,16 @@ pure module subroutine backward_1d(self, previous, gradient)
4849
call this_layer % backward(prev_layer % output, gradient)
4950
type is(maxpool2d_layer)
5051
call this_layer % backward(prev_layer % output, gradient)
51-
! type is(linear2d_layer)
52-
! call this_layer % backward(prev_layer % output, gradient)
52+
end select
53+
54+
type is(flatten2d_layer)
55+
56+
! Upstream layers permitted: linear2d_layer
57+
select type(prev_layer => previous % p)
58+
type is(linear2d_layer)
59+
call this_layer % backward(prev_layer % output, gradient)
60+
type is(input2d_layer)
61+
call this_layer % backward(prev_layer % output, gradient)
5362
end select
5463

5564
end select
@@ -63,8 +72,6 @@ pure module subroutine backward_2d(self, previous, gradient)
6372
class(layer), intent(in) :: previous
6473
real, intent(in) :: gradient(:,:)
6574

66-
! Backward pass from a 2-d layer downstream currently implemented
67-
! only for input2d and linear2d layers
6875
select type(this_layer => self % p)
6976

7077
type is(linear2d_layer)
@@ -197,8 +204,14 @@ pure module subroutine forward(self, input)
197204
call this_layer % forward(prev_layer % output)
198205
type is(reshape3d_layer)
199206
call this_layer % forward(prev_layer % output)
200-
! type is(linear2d_layer)
201-
! call this_layer % forward(prev_layer % output)
207+
end select
208+
209+
type is(flatten2d_layer)
210+
select type(prev_layer => input % p)
211+
type is(linear2d_layer)
212+
call this_layer % forward(prev_layer % output)
213+
type is(input2d_layer)
214+
call this_layer % forward(prev_layer % output)
202215
end select
203216

204217
type is(reshape3d_layer)
@@ -241,6 +254,8 @@ pure module subroutine get_output_1d(self, output)
241254
allocate(output, source=this_layer % output)
242255
type is(flatten_layer)
243256
allocate(output, source=this_layer % output)
257+
type is(flatten2d_layer)
258+
allocate(output, source=this_layer % output)
244259
class default
245260
error stop '1-d output can only be read from an input1d, dense, or flatten layer.'
246261

@@ -312,9 +327,11 @@ impure elemental module subroutine init(self, input)
312327
self % layer_shape = shape(this_layer % output)
313328
type is(flatten_layer)
314329
self % layer_shape = shape(this_layer % output)
330+
type is(flatten2d_layer)
331+
self % layer_shape = shape(this_layer % output)
315332
end select
316333

317-
self % input_layer_shape = input % layer_shape
334+
self % input_layer_shape = input % layer_shape
318335
self % initialized = .true.
319336

320337
end subroutine init
@@ -355,6 +372,8 @@ elemental module function get_num_params(self) result(num_params)
355372
num_params = 0
356373
type is (flatten_layer)
357374
num_params = 0
375+
type is (flatten2d_layer)
376+
num_params = 0
358377
type is (reshape3d_layer)
359378
num_params = 0
360379
type is (linear2d_layer)
@@ -384,6 +403,8 @@ module function get_params(self) result(params)
384403
! No parameters to get.
385404
type is (flatten_layer)
386405
! No parameters to get.
406+
type is (flatten2d_layer)
407+
! No parameters to get.
387408
type is (reshape3d_layer)
388409
! No parameters to get.
389410
type is (linear2d_layer)
@@ -412,6 +433,8 @@ module function get_gradients(self) result(gradients)
412433
type is (maxpool2d_layer)
413434
! No gradients to get.
414435
type is (flatten_layer)
436+
! No parameters to get.
437+
type is (flatten2d_layer)
415438
! No gradients to get.
416439
type is (reshape3d_layer)
417440
! No gradients to get.
@@ -477,6 +500,11 @@ module subroutine set_params(self, params)
477500
write(stderr, '(a)') 'Warning: calling set_params() ' &
478501
// 'on a zero-parameter layer; nothing to do.'
479502

503+
type is (flatten2d_layer)
504+
! No parameters to set.
505+
write(stderr, '(a)') 'Warning: calling set_params() ' &
506+
// 'on a zero-parameter layer; nothing to do.'
507+
480508
type is (reshape3d_layer)
481509
! No parameters to set.
482510
write(stderr, '(a)') 'Warning: calling set_params() ' &

src/nf/nf_network_submodule.f90

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,6 @@ module subroutine backward(self, output, loss)
151151
else
152152
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_3d)
153153
end if
154-
155154
type is(maxpool2d_layer)
156155
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
157156

@@ -283,6 +282,10 @@ module function predict_2d(self, input) result(res)
283282
select type(output_layer => self % layers(num_layers) % p)
284283
type is(dense_layer)
285284
res = output_layer % output
285+
type is(flatten_layer)
286+
res = output_layer % output
287+
class default
288+
error stop 'network % output not implemented for this output layer'
286289
end select
287290

288291
end function predict_2d

0 commit comments

Comments
 (0)