Commit 00f94d0

layernorm: initial implementation
1 parent d516437 commit 00f94d0

File tree

2 files changed: +245 -0 lines changed


src/nf/layernorm.f90

Lines changed: 159 additions & 0 deletions
module nf_layernorm_layer
  use nf_activation, only: activation_function
  use nf_base_layer, only: base_layer

  implicit none

  private
  public :: layernorm_layer

  type, extends(base_layer) :: layernorm_layer
    !! Layer Normalization
    !! (x - mean(x)) / sqrt(variance(x) + eps) * gamma + beta
    !! Based upon `Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton (2016)`:
    !! https://arxiv.org/abs/1607.06450v1
    integer :: sequence_length
    integer :: model_dimension

    real :: eps
    real, allocatable :: gamma(:)
    real, allocatable :: beta(:)

    real, allocatable :: d_gamma(:)
    real, allocatable :: d_beta(:)
    real, allocatable :: gradient(:, :)

    real, allocatable :: mu(:, :)
    real, allocatable :: sigma(:)

    real, allocatable :: output(:, :)

  contains
    procedure :: forward
    procedure :: backward
    procedure :: spread_by_sequence
    procedure :: spread_by_model_dim
    procedure :: init
  end type layernorm_layer

  interface layernorm_layer
    module function layernorm_layer_cons(sequence_length, model_dimension) &
        result(res)
      integer, intent(in) :: sequence_length, model_dimension
      type(layernorm_layer) :: res
    end function layernorm_layer_cons
  end interface layernorm_layer

contains

  module function layernorm_layer_cons(sequence_length, model_dimension) &
      result(res)
    integer, intent(in) :: sequence_length, model_dimension
    type(layernorm_layer) :: res

    res % sequence_length = sequence_length
    res % model_dimension = model_dimension
    res % eps = 1e-5
  end function layernorm_layer_cons

  pure module subroutine forward(self, input)
    class(layernorm_layer), intent(in out) :: self
    real, intent(in) :: input(:, :)
    real, allocatable :: normalized(:, :)
    integer :: i

    allocate(normalized(self % sequence_length, self % model_dimension))

    ! mu = x - mean(x) along the last (model) dimension
    do concurrent(i = 1: self % model_dimension)
      self % mu(:, i) = input(:, i) - (sum(input, dim=2) / self % model_dimension)
    end do

    ! square root of the variance, shifted by eps
    self % sigma = sqrt((sum(self % mu ** 2, dim=2) / self % model_dimension) + self % eps)

    ! normalize mu by sigma
    do concurrent(i = 1: self % model_dimension)
      normalized(:, i) = self % mu(:, i) / self % sigma
    end do

    ! forward through the trainable parameters gamma and beta
    do concurrent(i = 1: self % sequence_length)
      self % output(i, :) = normalized(i, :) * self % gamma + self % beta
    end do

    deallocate(normalized)
  end subroutine forward

  pure module subroutine backward(self, input, gradient)
    class(layernorm_layer), intent(in out) :: self
    real, intent(in) :: input(:, :)
    real, intent(in) :: gradient(:, :)
    real, allocatable :: one_over_sigma(:, :)
    real, allocatable :: gradient_by_gamma_over_sigma(:, :)

    allocate(one_over_sigma(self % sequence_length, self % model_dimension))
    allocate(gradient_by_gamma_over_sigma(self % sequence_length, self % model_dimension))

    one_over_sigma = (1 / self % spread_by_model_dim(self % sigma))
    gradient_by_gamma_over_sigma = gradient * self % spread_by_sequence(self % gamma) * one_over_sigma

    ! d_output/d_gamma = sum(d_output/d_y * mu/sigma)
    self % d_gamma = sum(gradient * self % mu * one_over_sigma, dim=1)

    ! d_output/d_beta = sum(d_output/d_y) * 1
    self % d_beta = sum(gradient, dim=1)

    ! From this article:
    ! https://robotchinwag.com/posts/layer-normalization-deriving-the-gradient-for-the-backward-pass/
    ! d_output/d_x = d_output/d_y * gamma/sigma
    !   - sum(d_output/d_y * gamma/sigma) / len
    !   - mu * sum(d_output/d_y * gamma * mu * sigma^(-3)) / len
    self % gradient = &
        gradient_by_gamma_over_sigma &
        - self % spread_by_model_dim(sum(gradient_by_gamma_over_sigma, dim=2)) / self % model_dimension &
        - self % mu * self % spread_by_model_dim(sum(&
            gradient_by_gamma_over_sigma * self % mu * (one_over_sigma ** 2),&
            dim=2)&
          ) / self % model_dimension

    deallocate(one_over_sigma)
    deallocate(gradient_by_gamma_over_sigma)
  end subroutine backward

  pure function spread_by_sequence(self, input) result(output)
    class(layernorm_layer), intent(in) :: self
    real, intent(in) :: input(:)
    real :: output(self % sequence_length, self % model_dimension)

    output = spread(input, dim=1, ncopies=self % sequence_length)
  end function spread_by_sequence

  pure function spread_by_model_dim(self, input) result(output)
    class(layernorm_layer), intent(in) :: self
    real, intent(in) :: input(:)
    real :: output(self % sequence_length, self % model_dimension)

    output = spread(input, dim=2, ncopies=self % model_dimension)
  end function spread_by_model_dim

  module subroutine init(self, input_shape)
    class(layernorm_layer), intent(in out) :: self
    integer, intent(in) :: input_shape(:)

    ! default initialization from PyTorch
    allocate(self % gamma(self % model_dimension))
    self % gamma = 1.
    allocate(self % beta(self % model_dimension))
    self % beta = 0.

    allocate(self % d_gamma(self % model_dimension))
    allocate(self % d_beta(self % model_dimension))
    allocate(self % gradient(self % sequence_length, self % model_dimension))

    allocate(self % mu(self % sequence_length, self % model_dimension))
    allocate(self % sigma(self % sequence_length))

    allocate(self % output(self % sequence_length, self % model_dimension))
  end subroutine init

end module nf_layernorm_layer
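
For reference, per sequence position the forward and backward passes above compute the following, with d = model_dimension and sums taken over the model dimension j (this only restates the code comments and the linked derivation):

$$\mu_i = x_i - \tfrac{1}{d}\sum_j x_j, \qquad \sigma = \sqrt{\tfrac{1}{d}\sum_j \mu_j^2 + \varepsilon}$$

$$y_i = \gamma_i \, \frac{\mu_i}{\sigma} + \beta_i$$

$$\frac{\partial L}{\partial x_i} = \frac{\gamma_i}{\sigma}\frac{\partial L}{\partial y_i} - \frac{1}{d\,\sigma}\sum_j \gamma_j \frac{\partial L}{\partial y_j} - \frac{\mu_i}{d\,\sigma^3}\sum_j \gamma_j \frac{\partial L}{\partial y_j}\,\mu_j$$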

test/test_layernorm.f90

Lines changed: 86 additions & 0 deletions
program test_layernorm
  use iso_fortran_env, only: stderr => error_unit
  use nf_layernorm_layer, only: layernorm_layer
  implicit none

  logical :: ok = .true.
  type(layernorm_layer) :: layernorm
  real :: sample_input(3, 4) = reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4])
  real :: sample_gradient(3, 4) = reshape([0.1, 3., 2., 0.1, 3., 3., 0.1, 2., 0.1, 3., 0.1, 3.], [3, 4])

  layernorm = layernorm_layer(3, 4)
  call layernorm % init([0])

  call test_layernorm_forward(layernorm, sample_input, ok)
  call test_layernorm_backward(layernorm, sample_input, sample_gradient, ok)

  ! report the overall status (ok is set by the test subroutines above)
  if (ok) then
    print '(a)', 'test_layernorm: All tests passed.'
  else
    write(stderr, '(a)') 'test_layernorm: One or more tests failed.'
    stop 1
  end if

contains

  subroutine test_layernorm_forward(layernorm, input, ok)
    type(layernorm_layer), intent(in out) :: layernorm
    real, intent(in out) :: input(:, :)
    logical, intent(in out) :: ok
    real :: output_shape(2)
    real :: output_flat(12)
    real :: expected_shape(2) = [3, 4]
    real :: expected_output_flat(12) = [&
        -0.693158746, 0.939844191, -0.992156327, 1.72702277, -0.970368207, 0.971188426,&
        -0.552177250, 1.05800152, 1.02837324, -0.481686622, -1.02747762, -1.00740564&
    ]

    call layernorm % forward(input)

    output_shape = shape(layernorm % output)
    if (.not. all(output_shape.eq.expected_shape)) then
      ok = .false.
      write(stderr, '(a)') 'forward returned incorrect shape.. failed'
    end if
    output_flat = reshape(layernorm % output, shape(output_flat))
    if (.not. all(output_flat.eq.expected_output_flat)) then
      ok = .false.
      write(stderr, '(a)') 'forward returned incorrect values.. failed'
    end if
  end subroutine test_layernorm_forward

  subroutine test_layernorm_backward(layernorm, input, gradient, ok)
    type(layernorm_layer), intent(in out) :: layernorm
    real, intent(in out) :: input(:, :)
    real, intent(in out) :: gradient(:, :)
    logical, intent(in out) :: ok

    real :: gradient_shape(2)
    real :: gradient_flat(12)
    real :: expected_gradient_shape(2) = [3, 4]
    real :: expected_gradient_flat(12) = [&
        -0.227230772, 0.103088334, -9.88590196E-02, -2.86390483E-02, 0.283811331, 0.277955681,&
        -0.215662330, -0.105019525, -0.269407451, 0.471532196, -0.281880081, 9.03107598E-02&
    ]

    real :: d_gamma(4)
    real :: expected_d_gamma(4) = [0.765904069, 0.175162792, 2.16362262, -4.57002449]
    real :: d_beta(4)
    real :: expected_d_beta(4) = [5.09999990, 6.09999990, 2.19999981, 6.09999990]

    call layernorm % backward(input, gradient)

    gradient_shape = shape(layernorm % gradient)
    if (.not. all(gradient_shape.eq.expected_gradient_shape)) then
      ok = .false.
      write(stderr, '(a)') 'backward returned incorrect gradient shape.. failed'
    end if
    gradient_flat = reshape(layernorm % gradient, shape(gradient_flat))
    if (.not. all(gradient_flat.eq.expected_gradient_flat)) then
      ok = .false.
      write(stderr, '(a)') 'backward returned incorrect gradient values.. failed'
    end if

    if (.not. all(layernorm % d_gamma.eq.expected_d_gamma)) then
      ok = .false.
      write(stderr, '(a)') 'backward returned incorrect d_gamma values.. failed'
    end if
    if (.not. all(layernorm % d_beta.eq.expected_d_beta)) then
      ok = .false.
      write(stderr, '(a)') 'backward returned incorrect d_beta values.. failed'
    end if
  end subroutine test_layernorm_backward

end program test_layernorm
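
The analytic input gradient exercised above can also be sanity-checked against central finite differences. Below is a minimal sketch of such a check, not part of this commit; the program name layernorm_gradcheck is hypothetical and only the layernorm_layer API shown in this diff is assumed.

program layernorm_gradcheck
  ! Illustrative sketch: compare the analytic gradient from backward()
  ! against a central finite difference of the loss L = sum(output).
  use nf_layernorm_layer, only: layernorm_layer
  implicit none

  type(layernorm_layer) :: ln
  real :: x(3, 4), x_pert(3, 4), dy(3, 4)
  real :: loss_plus, loss_minus, numeric_grad
  real, parameter :: h = 1e-3
  integer, parameter :: i = 1, j = 1

  call random_number(x)
  dy = 1.  ! dL/dy for L = sum(output)

  ln = layernorm_layer(3, 4)
  call ln % init([0])

  ! analytic gradient of L with respect to the input
  call ln % forward(x)
  call ln % backward(x, dy)

  ! central finite difference with respect to a single input element
  x_pert = x
  x_pert(i, j) = x(i, j) + h
  call ln % forward(x_pert)
  loss_plus = sum(ln % output)
  x_pert(i, j) = x(i, j) - h
  call ln % forward(x_pert)
  loss_minus = sum(ln % output)
  numeric_grad = (loss_plus - loss_minus) / (2 * h)

  print *, 'analytic: ', ln % gradient(i, j), ' numeric: ', numeric_grad
end program layernorm_gradcheck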

0 commit comments