Skip to content

Commit 1bec531

Browse files
committed
linear2d_layer: add flatten2d layer
1 parent b39e6da commit 1bec531

File tree

5 files changed

+240
-1
lines changed

5 files changed

+240
-1
lines changed

src/nf/nf_flatten2d_layer.f90

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
module nf_flatten2d_layer
2+
3+
!! This module provides the concrete flatten2d layer type.
4+
!! It is used internally by the layer type.
5+
!! It is not intended to be used directly by the user.
6+
7+
use nf_base_layer, only: base_layer
8+
9+
implicit none
10+
11+
private
12+
public :: flatten2d_layer
13+
14+
type, extends(base_layer) :: flatten2d_layer
15+
16+
!! Concrete implementation of a flatten2d (2-d to 1-d) layer.
17+
18+
integer, allocatable :: input_shape(:)
19+
integer :: output_size
20+
21+
real, allocatable :: gradient(:,:)
22+
real, allocatable :: output(:)
23+
24+
contains
25+
26+
procedure :: backward
27+
procedure :: forward
28+
procedure :: init
29+
30+
end type flatten2d_layer
31+
32+
interface flatten2d_layer
33+
elemental module function flatten2d_layer_cons() result(res)
34+
!! This function returns the `flatten2d_layer` instance.
35+
type(flatten2d_layer) :: res
36+
!! `flatten2d_layer` instance
37+
end function flatten2d_layer_cons
38+
end interface flatten2d_layer
39+
40+
interface
41+
42+
pure module subroutine backward(self, input, gradient)
43+
!! Apply the backward pass to the flatten2d layer.
44+
!! This is a reshape operation from 1-d gradient to 2-d input.
45+
class(flatten2d_layer), intent(in out) :: self
46+
!! flatten2d layer instance
47+
real, intent(in) :: input(:,:)
48+
!! Input from the previous layer
49+
real, intent(in) :: gradient(:)
50+
!! Gradient from the next layer
51+
end subroutine backward
52+
53+
pure module subroutine forward(self, input)
54+
!! Propagate forward the layer.
55+
!! Calling this subroutine updates the values of a few data components
56+
!! of `flatten2d_layer` that are needed for the backward pass.
57+
class(flatten2d_layer), intent(in out) :: self
58+
!! Dense layer instance
59+
real, intent(in) :: input(:,:)
60+
!! Input from the previous layer
61+
end subroutine forward
62+
63+
module subroutine init(self, input_shape)
64+
!! Initialize the layer data structures.
65+
!!
66+
!! This is a deferred procedure from the `base_layer` abstract type.
67+
class(flatten2d_layer), intent(in out) :: self
68+
!! Dense layer instance
69+
integer, intent(in) :: input_shape(:)
70+
!! Shape of the input layer
71+
end subroutine init
72+
73+
end interface
74+
75+
end module nf_flatten2d_layer
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
submodule(nf_flatten2d_layer) nf_flatten2d_layer_submodule
2+
3+
!! This module provides the concrete flatten2d layer type.
4+
!! It is used internally by the layer type.
5+
!! It is not intended to be used directly by the user.
6+
7+
use nf_base_layer, only: base_layer
8+
9+
implicit none
10+
11+
contains
12+
13+
elemental module function flatten2d_layer_cons() result(res)
14+
type(flatten2d_layer) :: res
15+
end function flatten2d_layer_cons
16+
17+
18+
pure module subroutine backward(self, input, gradient)
19+
class(flatten2d_layer), intent(in out) :: self
20+
real, intent(in) :: input(:,:)
21+
real, intent(in) :: gradient(:)
22+
self % gradient = reshape(gradient, shape(input))
23+
end subroutine backward
24+
25+
26+
pure module subroutine forward(self, input)
27+
class(flatten2d_layer), intent(in out) :: self
28+
real, intent(in) :: input(:,:)
29+
self % output = pack(input, .true.)
30+
end subroutine forward
31+
32+
33+
module subroutine init(self, input_shape)
34+
class(flatten2d_layer), intent(in out) :: self
35+
integer, intent(in) :: input_shape(:)
36+
37+
self % input_shape = input_shape
38+
self % output_size = product(input_shape)
39+
40+
allocate(self % gradient(input_shape(1), input_shape(2)))
41+
self % gradient = 0
42+
43+
allocate(self % output(self % output_size))
44+
self % output = 0
45+
46+
end subroutine init
47+
48+
end submodule nf_flatten2d_layer_submodule

src/nf/nf_layer_constructors.f90

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ module nf_layer_constructors
88
implicit none
99

1010
private
11-
public :: conv2d, dense, flatten, input, maxpool2d, reshape, linear2d
11+
public :: conv2d, dense, flatten, flatten2d, input, maxpool2d, reshape, linear2d
1212

1313
interface input
1414

@@ -125,6 +125,25 @@ module function flatten() result(res)
125125
!! Resulting layer instance
126126
end function flatten
127127

128+
module function flatten2d() result(res)
129+
!! Flatten (2-d -> 1-d) layer constructor.
130+
!!
131+
!! Use this layer to chain layers with 2-d outputs to layers with 2-d
132+
!! inputs.
133+
!!
134+
!! A flatten layer must not be the first layer in the network.
135+
!!
136+
!! Example:
137+
!!
138+
!! ```
139+
!! use nf, only :: flatten, layer
140+
!! type(layer) :: flatten_layer
141+
!! flatten_layer = flatten()
142+
!! ```
143+
type(layer) :: res
144+
!! Resulting layer instance
145+
end function flatten2d
146+
128147
module function conv2d(filters, kernel_size, activation) result(res)
129148
!! 2-d convolutional layer constructor.
130149
!!

src/nf/nf_layer_constructors_submodule.f90

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
use nf_conv2d_layer, only: conv2d_layer
55
use nf_dense_layer, only: dense_layer
66
use nf_flatten_layer, only: flatten_layer
7+
use nf_flatten2d_layer, only: flatten2d_layer
78
use nf_input1d_layer, only: input1d_layer
89
use nf_input2d_layer, only: input2d_layer
910
use nf_input3d_layer, only: input3d_layer
@@ -72,6 +73,13 @@ module function flatten() result(res)
7273
end function flatten
7374

7475

76+
module function flatten2d() result(res)
77+
type(layer) :: res
78+
res % name = 'flatten2d'
79+
allocate(res % p, source=flatten2d_layer())
80+
end function flatten2d
81+
82+
7583
module function input1d(layer_size) result(res)
7684
integer, intent(in) :: layer_size
7785
type(layer) :: res

test/test_flatten2d_layer.f90

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
program test_flatten2d_layer
2+
3+
use iso_fortran_env, only: stderr => error_unit
4+
use nf, only: dense, flatten2d, input, layer, network
5+
use nf_flatten2d_layer, only: flatten2d_layer
6+
use nf_input2d_layer, only: input2d_layer
7+
8+
implicit none
9+
10+
type(layer) :: test_layer, input_layer
11+
type(network) :: net
12+
real, allocatable :: gradient(:,:)
13+
real, allocatable :: output(:)
14+
logical :: ok = .true.
15+
16+
test_layer = flatten2d()
17+
18+
if (.not. test_layer % name == 'flatten2d') then
19+
ok = .false.
20+
write(stderr, '(a)') 'flatten2d layer has its name set correctly.. failed'
21+
end if
22+
23+
if (test_layer % initialized) then
24+
ok = .false.
25+
write(stderr, '(a)') 'flatten2d layer is not initialized yet.. failed'
26+
end if
27+
28+
input_layer = input(1, 2)
29+
call test_layer % init(input_layer)
30+
31+
if (.not. test_layer % initialized) then
32+
ok = .false.
33+
write(stderr, '(a)') 'flatten2d layer is now initialized.. failed'
34+
end if
35+
36+
if (.not. all(test_layer % layer_shape == [2])) then
37+
ok = .false.
38+
write(stderr, '(a)') 'flatten2d layer has an incorrect output shape.. failed'
39+
end if
40+
41+
! Test forward pass - reshaping from 2-d to 1-d
42+
43+
select type(this_layer => input_layer % p); type is(input2d_layer)
44+
call this_layer % set(reshape(real([1, 2, 3, 4]), [2, 2]))
45+
end select
46+
47+
call test_layer % forward(input_layer)
48+
call test_layer % get_output(output)
49+
50+
if (.not. all(output == [1, 2, 3, 4])) then
51+
ok = .false.
52+
write(stderr, '(a)') 'flatten2d layer correctly propagates forward.. failed'
53+
end if
54+
55+
! Test backward pass - reshaping from 1-d to 2-d
56+
57+
! Calling backward() will set the values on the gradient component
58+
! input_layer is used only to determine shape
59+
call test_layer % backward(input_layer, real([1, 2, 3, 4]))
60+
61+
select type(this_layer => test_layer % p); type is(flatten2d_layer)
62+
gradient = this_layer % gradient
63+
end select
64+
65+
if (.not. all(gradient == reshape(real([1, 2, 3, 4]), [2, 2]))) then
66+
ok = .false.
67+
write(stderr, '(a)') 'flatten2d layer correctly propagates backward.. failed'
68+
end if
69+
70+
net = network([ &
71+
input(28, 28), &
72+
flatten2d(), &
73+
dense(10) &
74+
])
75+
76+
! Test that the output layer receives 784 elements in the input
77+
if (.not. all(net % layers(3) % input_layer_shape == [784])) then
78+
ok = .false.
79+
write(stderr, '(a)') 'flatten2d layer correctly chains input2d to dense.. failed'
80+
end if
81+
82+
if (ok) then
83+
print '(a)', 'test_flatten2d_layer: All tests passed.'
84+
else
85+
write(stderr, '(a)') 'test_flatten2d_layer: One or more tests failed.'
86+
stop 1
87+
end if
88+
89+
end program test_flatten2d_layer

0 commit comments

Comments
 (0)