
Commit bdefd02

layernorm: add gradient updates
1 parent c4a3e3c commit bdefd02

3 files changed: +125 additions, -0 deletions

src/nf/nf_layernorm.f90

Lines changed: 27 additions & 0 deletions
@@ -32,6 +32,10 @@ module nf_layernorm_layer
     procedure :: forward
     procedure :: backward
     procedure :: init
+    procedure :: get_num_params
+    procedure :: get_params
+    procedure :: get_gradients
+    procedure :: set_params
   end type layernorm_layer
 
   interface layernorm_layer
@@ -57,5 +61,28 @@ module subroutine init(self, input_shape)
       class(layernorm_layer), intent(in out) :: self
       integer, intent(in) :: input_shape(:)
     end subroutine init
+
+    pure module function get_num_params(self) result(num_params)
+      class(layernorm_layer), intent(in) :: self
+      integer :: num_params
+    end function get_num_params
+
+
+    module function get_params(self) result(params)
+      class(layernorm_layer), intent(in), target :: self
+      real, allocatable :: params(:)
+    end function get_params
+
+
+    module function get_gradients(self) result(gradients)
+      class(layernorm_layer), intent(in), target :: self
+      real, allocatable :: gradients(:)
+    end function get_gradients
+
+
+    module subroutine set_params(self, params)
+      class(layernorm_layer), intent(in out) :: self
+      real, intent(in), target :: params(:)
+    end subroutine set_params
   end interface
 end module nf_layernorm_layer
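
Taken together, these additions give layernorm_layer a flat parameter interface: get_num_params reports the total count, get_params and get_gradients flatten gamma and beta (and their gradients) into a single rank-1 array, and set_params writes such an array back. A minimal round-trip sketch under the same shapes as the test file further down; the program name is illustrative only and not part of this commit:

program layernorm_params_roundtrip
  ! Sketch only: exercise the new accessors on a freshly initialized layer.
  use nf_layernorm_layer, only: layernorm_layer
  implicit none

  type(layernorm_layer) :: norm
  real, allocatable :: params(:)

  norm = layernorm_layer()
  call norm % init([3, 4])   ! sequence_length = 3, model_dimension = 4

  ! gamma and beta together: 2 * model_dimension = 8
  print '(a, i0)', 'number of parameters: ', norm % get_num_params()

  ! Flattened as [gamma, beta]; writing the same array back leaves the layer unchanged
  params = norm % get_params()
  call norm % set_params(params)
end program layernorm_params_roundtrip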

src/nf/nf_layernorm_submodule.f90

Lines changed: 48 additions & 0 deletions
@@ -107,4 +107,52 @@ module subroutine init(self, input_shape)
 
     allocate(self % output(self % sequence_length, self % model_dimension))
   end subroutine init
+
+  pure module function get_num_params(self) result(num_params)
+    class(layernorm_layer), intent(in) :: self
+    integer :: num_params
+
+    ! gamma and beta each hold model_dimension parameters
+    num_params = 2 * self % model_dimension
+
+  end function get_num_params
+
+
+  module function get_params(self) result(params)
+    class(layernorm_layer), intent(in), target :: self
+    real, allocatable :: params(:)
+
+    params = [ &
+      self % gamma, &
+      self % beta &
+    ]
+
+  end function get_params
+
+
+  module function get_gradients(self) result(gradients)
+    class(layernorm_layer), intent(in), target :: self
+    real, allocatable :: gradients(:)
+
+    gradients = [ &
+      self % d_gamma, &
+      self % d_beta &
+    ]
+
+  end function get_gradients
+
+
+  module subroutine set_params(self, params)
+    class(layernorm_layer), intent(in out) :: self
+    real, intent(in), target :: params(:)
+
+    ! Check that the number of parameters matches the layer
+    if (size(params) /= self % get_num_params()) then
+      error stop 'Error: number of parameters does not match'
+    end if
+
+    self % gamma = params(1: self % model_dimension)
+    self % beta = params(self % model_dimension + 1: 2 * self % model_dimension)
+
+  end subroutine set_params
 end submodule nf_layernorm_layer_submodule
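
Because get_params and get_gradients share the [gamma, beta] layout, a plain gradient-descent step can be expressed directly on the flattened arrays. A minimal sketch, assuming forward and backward have already run so that d_gamma and d_beta are populated; the subroutine name is hypothetical:

subroutine descend_layernorm(norm, learning_rate)
  ! Sketch: one plain gradient-descent step on gamma and beta.
  use nf_layernorm_layer, only: layernorm_layer
  implicit none
  type(layernorm_layer), intent(in out) :: norm
  real, intent(in) :: learning_rate
  real, allocatable :: params(:)

  ! Both arrays are laid out as [gamma(1:d), beta(1:d)] with d = model_dimension
  params = norm % get_params() - learning_rate * norm % get_gradients()
  call norm % set_params(params)
end subroutine descend_layernorm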

test/test_layernorm.f90

Lines changed: 50 additions & 0 deletions
@@ -1,6 +1,7 @@
 program test_layernorm
   use iso_fortran_env, only: stderr => error_unit
   use nf_layernorm_layer, only: layernorm_layer
+  use nf, only: sgd
   implicit none
 
   logical :: ok = .true.
@@ -13,6 +14,7 @@ program test_layernorm
 
   call test_layernorm_forward(layernorm, sample_input, ok)
   call test_layernorm_backward(layernorm, sample_input, sample_gradient, ok)
+  call test_layernorm_gradients(sample_input, sample_gradient, ok)
 
   if (ok) then
     print '(a)', 'test_layernorm_layer: All tests passed.'
@@ -90,4 +92,52 @@ subroutine test_layernorm_backward(layernorm, input, gradient, ok)
     end if
   end subroutine test_layernorm_backward
 
+  subroutine test_layernorm_gradients(input, gradient, ok)
+    real, intent(in out) :: input(:, :)
+    real, intent(in out) :: gradient(:, :)
+    logical, intent(in out) :: ok
+    type(layernorm_layer) :: layernorm
+    type(sgd) :: optim
+
+    real :: parameters(8)
+    real :: expected_parameters(8)
+    real :: updated_output(12)
+    real :: expected_updated_output(12) = [ &
+      -0.738849819, 0.881645918, -1.03555739, &
+      1.66299772, -1.02966857, 0.908487320, &
+      -0.562230229, 1.01311040, 0.984123051, &
+      -0.564699769, -1.13543355, -1.11444426 &
+    ]
+
+    layernorm = layernorm_layer()
+    call layernorm % init([3, 4])
+
+    call layernorm % forward(input)
+    call layernorm % backward(input, gradient)
+
+    if (layernorm % get_num_params() /= 8) then
+      ok = .false.
+      write(stderr, '(a)') 'incorrect number of parameters.. failed'
+    end if
+
+    expected_parameters(1: 4) = 1.
+    expected_parameters(5: 8) = 0.
+    parameters = layernorm % get_params()
+    if (.not. all(parameters.eq.expected_parameters)) then
+      ok = .false.
+      write(stderr, '(a)') 'incorrect parameters.. failed'
+    end if
+
+    optim = SGD(learning_rate=0.01)
+    call optim % minimize(parameters, layernorm % get_gradients())
+    call layernorm % set_params(parameters)
+
+    call layernorm % forward(input)
+
+    updated_output = reshape(layernorm % output, [12])
+    if (.not. all(updated_output.eq.expected_updated_output)) then
+      ok = .false.
+      write(stderr, '(a)') 'incorrect output after parameters update.. failed'
+    end if
+  end subroutine test_layernorm_gradients
 end program test_layernorm
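
For reference, the expected_updated_output values rely on SGD with no momentum, so the optimizer call in the test is assumed to reduce to subtracting learning_rate times the gradients before the parameters are written back. A sketch of that assumption, reusing the test's variables; it describes the intent, not the optimizer's actual internals:

! Assumed equivalent of:
!   optim = SGD(learning_rate=0.01)
!   call optim % minimize(parameters, layernorm % get_gradients())
parameters = parameters - 0.01 * layernorm % get_gradients()
call layernorm % set_params(parameters)
call layernorm % forward(input)   ! produces the output compared against expected_updated_output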
