
Commit 79abce3

move linear2d layer logic into submodule
1 parent a27ec09 commit 79abce3
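
For readers unfamiliar with Fortran submodules, the pattern this commit adopts is: the module keeps only the interface declarations of its separate module procedures, and a submodule supplies the implementations, so clients of the module depend only on the interface. A minimal, hypothetical sketch of that pattern (the names demo_mod and add_one are illustrative and not part of this commit):

module demo_mod
  implicit none
  interface
    ! Only the interface of the separate module procedure lives here.
    module function add_one(x) result(y)
      integer, intent(in) :: x
      integer :: y
    end function add_one
  end interface
end module demo_mod

submodule(demo_mod) demo_mod_submodule
  implicit none
contains
  ! The implementation lives in the submodule; changing it does not
  ! force recompilation of the module's users.
  module function add_one(x) result(y)
    integer, intent(in) :: x
    integer :: y
    y = x + 1
  end function add_one
end submodule demo_mod_submodule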

2 files changed: +154 -124 lines changed

src/nf/nf_linear2d_layer.f90

Lines changed: 27 additions & 124 deletions
@@ -31,9 +31,10 @@ module nf_linear2d_layer
   end type linear2d_layer
 
   interface linear2d_layer
-    module function linear2d_layer_cons(in_features, out_features) &
-        result(res)
-      integer, intent(in) :: in_features, out_features
+    module function linear2d_layer_cons(&
+        batch_size, sequence_length, in_features, out_features&
+    ) result(res)
+      integer, intent(in) :: batch_size, sequence_length, in_features, out_features
       type(linear2d_layer) :: res
     end function linear2d_layer_cons
   end interface linear2d_layer

@@ -44,133 +45,35 @@ pure module subroutine forward(self, input)
       real, intent(in) :: input(:, :, :)
     end subroutine forward
 
+    pure module subroutine backward(self, input, gradient)
+      class(linear2d_layer), intent(in out) :: self
+      real, intent(in) :: input(:, :, :)
+      real, intent(in) :: gradient(:, :, :)
+    end subroutine backward
+
     module subroutine init(self, input_shape)
       class(linear2d_layer), intent(in out) :: self
       integer, intent(in) :: input_shape(:)
     end subroutine init
-  end interface
-
-contains
-  module function linear2d_layer_cons(&
-      batch_size, sequence_length, in_features, out_features&
-  ) result(res)
-    integer, intent(in) :: batch_size, sequence_length, in_features, out_features
-    type(linear2d_layer) :: res
-
-    res % in_features = in_features
-    res % out_features = out_features
-    res % sequence_length = sequence_length
-    res % batch_size = batch_size
-  end function linear2d_layer_cons
-
-  module subroutine init(self, input_shape)
-    class(linear2d_layer), intent(in out) :: self
-    integer, intent(in) :: input_shape(:)
-
-    allocate(self % output(self % batch_size, self % sequence_length, self % out_features))
-    allocate(self % gradient(self % batch_size, self % sequence_length, self % in_features))
-
-    allocate(self % weights(self % in_features, self % out_features))
-    self % weights = 0.1
-
-    allocate(self % biases(self % out_features))
-    self % biases = 0.11
-
-    allocate(self % dw(self % in_features, self % out_features))
-    self % dw = 0.0
-    allocate(self % db(self % out_features))
-    self % db = 0.0
-  end subroutine init
-
-  pure module subroutine forward(self, input)
-    class(linear2d_layer), intent(in out) :: self
-    real, intent(in) :: input(:, :, :)
-    integer :: i, j
-
-    do concurrent(i = 1: self % batch_size)
-      self % output(i, :, :) = matmul(input(i, :, :), self % weights)
-    end do
-    do concurrent(i = 1: self % batch_size, j = 1: self % sequence_length)
-      self % output(i, j, :) = self % output(i, j, :) + self % biases
-    end do
-  end subroutine forward
-
-  pure module subroutine backward(self, input, gradient)
-    class(linear2d_layer), intent(in out) :: self
-    real, intent(in) :: input(:, :, :)
-    real, intent(in) :: gradient(:, :, :)
-    real :: db(self % out_features)
-    real :: dw(self % in_features, self % out_features)
-    integer :: i
-
-    do concurrent(i = 1: self % batch_size)
-      self % dw = self % dw + matmul(transpose(input(i, :, :)), gradient(i, :, :))
-      self % db = self % db + sum(gradient(i, :, :), 1)
-      self % gradient(i, :, :) = matmul(gradient(i, :, :), transpose(self % weights))
-    end do
-  end subroutine backward
-
-  pure module function get_num_params(self) result(num_params)
-    class(linear2d_layer), intent(in) :: self
-    integer :: num_params
-
-    ! Number of weigths times number of biases
-    num_params = self % in_features * self % out_features + self % out_features
-
-  end function get_num_params
 
+    pure module function get_num_params(self) result(num_params)
+      class(linear2d_layer), intent(in) :: self
+      integer :: num_params
+    end function get_num_params
 
-  module function get_params(self) result(params)
-    class(linear2d_layer), intent(in), target :: self
-    real, allocatable :: params(:)
+    module function get_params(self) result(params)
+      class(linear2d_layer), intent(in), target :: self
+      real, allocatable :: params(:)
+    end function get_params
 
-    real, pointer :: w_(:) => null()
+    module function get_gradients(self) result(gradients)
+      class(linear2d_layer), intent(in), target :: self
+      real, allocatable :: gradients(:)
+    end function get_gradients
 
-    w_(1:size(self % weights)) => self % weights
-
-    params = [ &
-      w_, &
-      self % biases &
-    ]
-
-  end function get_params
-
-
-  module function get_gradients(self) result(gradients)
-    class(linear2d_layer), intent(in), target :: self
-    real, allocatable :: gradients(:)
-
-    real, pointer :: dw_(:) => null()
-
-    dw_(1:size(self % dw)) => self % dw
-
-    gradients = [ &
-      dw_, &
-      self % db &
-    ]
-
-  end function get_gradients
-
-
-  module subroutine set_params(self, params)
-    class(linear2d_layer), intent(in out) :: self
-    real, intent(in), target :: params(:)
-
-    real, pointer :: p_(:,:) => null()
-
-    ! check if the number of parameters is correct
-    if (size(params) /= self % get_num_params()) then
-      error stop 'Error: number of parameters does not match'
-    end if
-
-    associate(n => self % in_features * self % out_features)
-      ! reshape the weights
-      p_(1:self % in_features, 1:self % out_features) => params(1 : n)
-      self % weights = p_
-
-      ! reshape the biases
-      self % biases = params(n + 1 : n + self % out_features)
-    end associate
-
-  end subroutine set_params
+    module subroutine set_params(self, params)
+      class(linear2d_layer), intent(in out) :: self
+      real, intent(in), target :: params(:)
+    end subroutine set_params
+  end interface
 
 end module nf_linear2d_layer
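
A note on the get_params, get_gradients, and set_params bodies moved out of this module: they use Fortran pointer rank remapping to view the 2-D weights (and their gradients) as a flat 1-D parameter vector and back. A standalone sketch of that technique, separate from this commit (the names w, params, w_flat, and w_2d are illustrative):

program flatten_demo
  implicit none
  real, target  :: w(2, 3)              ! a small weight matrix
  real, target  :: params(6)            ! flat parameter vector
  real, pointer :: w_flat(:) => null()
  real, pointer :: w_2d(:, :) => null()

  w = reshape([1., 2., 3., 4., 5., 6.], [2, 3])

  ! Flatten: remap a 1-D pointer onto the 2-D target (as get_params does).
  w_flat(1:size(w)) => w
  params = w_flat

  ! Unflatten: remap a 2-D pointer onto the 1-D vector (as set_params does).
  w_2d(1:2, 1:3) => params
  print *, w_2d(2, 3)                   ! prints 6.0 (column-major storage)
end program flatten_demo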
Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
submodule(nf_linear2d_layer) nf_linear2d_layer_submodule
  use nf_base_layer, only: base_layer
  implicit none
contains
  module function linear2d_layer_cons(&
      batch_size, sequence_length, in_features, out_features&
  ) result(res)
    integer, intent(in) :: batch_size, sequence_length, in_features, out_features
    type(linear2d_layer) :: res

    res % in_features = in_features
    res % out_features = out_features
    res % sequence_length = sequence_length
    res % batch_size = batch_size
  end function linear2d_layer_cons

  module subroutine init(self, input_shape)
    class(linear2d_layer), intent(in out) :: self
    integer, intent(in) :: input_shape(:)

    allocate(self % output(self % batch_size, self % sequence_length, self % out_features))
    allocate(self % gradient(self % batch_size, self % sequence_length, self % in_features))

    allocate(self % weights(self % in_features, self % out_features))
    self % weights = 0.1

    allocate(self % biases(self % out_features))
    self % biases = 0.11

    allocate(self % dw(self % in_features, self % out_features))
    self % dw = 0.0
    allocate(self % db(self % out_features))
    self % db = 0.0
  end subroutine init

  pure module subroutine forward(self, input)
    class(linear2d_layer), intent(in out) :: self
    real, intent(in) :: input(:, :, :)
    integer :: i, j

    do concurrent(i = 1: self % batch_size)
      self % output(i, :, :) = matmul(input(i, :, :), self % weights)
    end do
    do concurrent(i = 1: self % batch_size, j = 1: self % sequence_length)
      self % output(i, j, :) = self % output(i, j, :) + self % biases
    end do
  end subroutine forward

  pure module subroutine backward(self, input, gradient)
    class(linear2d_layer), intent(in out) :: self
    real, intent(in) :: input(:, :, :)
    real, intent(in) :: gradient(:, :, :)
    real :: db(self % out_features)
    real :: dw(self % in_features, self % out_features)
    integer :: i

    do concurrent(i = 1: self % batch_size)
      self % dw = self % dw + matmul(transpose(input(i, :, :)), gradient(i, :, :))
      self % db = self % db + sum(gradient(i, :, :), 1)
      self % gradient(i, :, :) = matmul(gradient(i, :, :), transpose(self % weights))
    end do
  end subroutine backward

  pure module function get_num_params(self) result(num_params)
    class(linear2d_layer), intent(in) :: self
    integer :: num_params

    ! Number of weights plus number of biases
    num_params = self % in_features * self % out_features + self % out_features

  end function get_num_params

  module function get_params(self) result(params)
    class(linear2d_layer), intent(in), target :: self
    real, allocatable :: params(:)

    real, pointer :: w_(:) => null()

    w_(1:size(self % weights)) => self % weights

    params = [ &
      w_, &
      self % biases &
    ]

  end function get_params

  module function get_gradients(self) result(gradients)
    class(linear2d_layer), intent(in), target :: self
    real, allocatable :: gradients(:)

    real, pointer :: dw_(:) => null()

    dw_(1:size(self % dw)) => self % dw

    gradients = [ &
      dw_, &
      self % db &
    ]

  end function get_gradients

  module subroutine set_params(self, params)
    class(linear2d_layer), intent(in out) :: self
    real, intent(in), target :: params(:)

    real, pointer :: p_(:,:) => null()

    ! check if the number of parameters is correct
    if (size(params) /= self % get_num_params()) then
      error stop 'Error: number of parameters does not match'
    end if

    associate(n => self % in_features * self % out_features)
      ! reshape the weights
      p_(1:self % in_features, 1:self % out_features) => params(1 : n)
      self % weights = p_

      ! reshape the biases
      self % biases = params(n + 1 : n + self % out_features)
    end associate

  end subroutine set_params
end submodule nf_linear2d_layer_submodule
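
As a quick sanity check of the constructor signature introduced here, a usage sketch might look like the one below. It assumes that init and forward are exposed as type-bound procedures on linear2d_layer (as other nf layers do) and that the surrounding type definition compiles as shown; the driver program itself is illustrative and not part of this commit.

program linear2d_demo
  use nf_linear2d_layer, only: linear2d_layer
  implicit none

  type(linear2d_layer) :: layer
  real :: x(2, 3, 4)  ! (batch_size, sequence_length, in_features)

  ! The four-argument constructor from this commit.
  layer = linear2d_layer(batch_size=2, sequence_length=3, in_features=4, out_features=5)

  ! init allocates output, gradient, weights, biases, dw, and db;
  ! the input_shape argument is not used by the implementation above.
  call layer % init([2, 3, 4])

  call random_number(x)
  call layer % forward(x)

  print *, shape(layer % output)  ! expected: 2 3 5
end program linear2d_demo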
