Skip to content

Commit d4731a1

Browse files
committed
embedding_layer: implementation of embedding layer
1 parent 1c54cf0 commit d4731a1

File tree

3 files changed

+56
-25
lines changed

3 files changed

+56
-25
lines changed

src/nf/nf_embedding_layer.f90

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ module nf_embedding_layer
99
public :: embedding_layer
1010

1111
type, extends(base_layer) :: embedding_layer
12+
!! Embedding Layer
13+
!! Stores inputs as a trainable lookup table. Inputs are
14+
!! integer indices in a dictionary of `vocab_size`.
15+
!! This layer converts them into a table of shape
16+
!! (`sequence_length`, `model_dimension`)
1217
integer :: sequence_length, vocab_size, model_dimension
1318

1419
real, allocatable :: weights(:, :)
@@ -29,24 +34,25 @@ module nf_embedding_layer
2934
end type embedding_layer
3035

3136
interface embedding_layer
32-
module function embedding_layer_cons(&
33-
sequence_length, vocab_size, model_dimension&
34-
) result(res)
35-
integer, intent(in) :: sequence_length, vocab_size, model_dimension
37+
module function embedding_layer_cons(vocab_size, model_dimension) result(res)
38+
integer, intent(in) :: vocab_size, model_dimension
3639
type(embedding_layer) :: res
3740
end function embedding_layer_cons
3841
end interface embedding_layer
3942

4043
interface
4144
pure module subroutine forward(self, input)
45+
!! Get vectors by indices in the dictionary
4246
class(embedding_layer), intent(in out) :: self
4347
integer, intent(in) :: input(:)
4448
end subroutine forward
4549

4650
pure module subroutine backward(self, input, gradient)
51+
!! Update gradient at `input` indices
52+
!! dw_i = dw_i + d_output_i
4753
class(embedding_layer), intent(in out) :: self
4854
integer, intent(in) :: input(:)
49-
real, intent(in) :: gradient(:)
55+
real, intent(in) :: gradient(:, :)
5056
end subroutine backward
5157

5258
module subroutine init(self, input_shape)

src/nf/nf_embedding_submodule.f90

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,20 @@
22
use nf_base_layer, only: base_layer
33
implicit none
44
contains
5-
module function embedding_layer_cons(&
6-
sequence_length, vocab_size, model_dimension&
7-
) result(res)
8-
integer, intent(in) :: sequence_length, vocab_size, model_dimension
5+
module function embedding_layer_cons(vocab_size, model_dimension) result(res)
6+
integer, intent(in) :: vocab_size, model_dimension
97
type(embedding_layer) :: res
108

119
res % vocab_size = vocab_size
1210
res % model_dimension = model_dimension
13-
res % sequence_length = sequence_length
1411
end function embedding_layer_cons
1512

1613
module subroutine init(self, input_shape)
1714
class(embedding_layer), intent(in out) :: self
1815
integer, intent(in) :: input_shape(:)
1916

17+
self % sequence_length = input_shape(1)
18+
2019
allocate(self % output(self % sequence_length, self % model_dimension))
2120
allocate(self % gradient(self % sequence_length, self % vocab_size))
2221

@@ -30,32 +29,34 @@ end subroutine init
3029
pure module subroutine forward(self, input)
3130
class(embedding_layer), intent(in out) :: self
3231
integer, intent(in) :: input(:)
33-
integer :: i
32+
integer :: i, index
3433

3534
do concurrent(i = 1: self % sequence_length)
36-
self % output(i, :) = self % weights(input(i), :)
35+
index = input(i)
36+
if (index > size(self % weights, 1)) then
37+
index = 1
38+
end if
39+
self % output(i, :) = self % weights(index, :)
3740
end do
3841
end subroutine forward
3942

4043
pure module subroutine backward(self, input, gradient)
4144
class(embedding_layer), intent(in out) :: self
4245
integer, intent(in) :: input(:)
43-
real, intent(in) :: gradient(:)
44-
real :: db(self % model_dimension)
45-
real :: dw(self % vocab_size, self % model_dimension)
46+
real, intent(in) :: gradient(:, :)
4647
integer :: i
48+
49+
do concurrent(i = 1: self % sequence_length)
50+
self % dw(input(i), :) = self % dw(input(i), :) + gradient(i, :)
51+
end do
4752
end subroutine backward
4853

4954
pure module function get_num_params(self) result(num_params)
5055
class(embedding_layer), intent(in) :: self
5156
integer :: num_params
52-
53-
! Number of weigths times number of biases
54-
num_params = self % vocab_size * self % model_dimension + self % model_dimension
55-
57+
num_params = self % vocab_size * self % model_dimension
5658
end function get_num_params
5759

58-
5960
module function get_params(self) result(params)
6061
class(embedding_layer), intent(in), target :: self
6162
real, allocatable :: params(:)
@@ -65,7 +66,6 @@ module function get_params(self) result(params)
6566
params = [w_]
6667
end function get_params
6768

68-
6969
module function get_gradients(self) result(gradients)
7070
class(embedding_layer), intent(in), target :: self
7171
real, allocatable :: gradients(:)
@@ -75,7 +75,6 @@ module function get_gradients(self) result(gradients)
7575
gradients = [dw_]
7676
end function get_gradients
7777

78-
7978
module subroutine set_params(self, params)
8079
class(embedding_layer), intent(in out) :: self
8180
real, intent(in), target :: params(:)

test/test_embedding_layer.f90

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,36 @@ program test_embedding_layer
55

66
logical :: ok = .true.
77
integer :: sample_input(3) = [2, 1, 3]
8+
real :: sample_gradient(3, 2) = reshape([0.1, 0.2, 0.3, 0.4, 0.6, 0.6], [3, 2])
9+
real :: output_flat(6)
10+
real :: expected_output_flat(6) = reshape([0.3, 0.1, 0.5, 0.4, 0.2, 0.6], [6])
11+
real :: dw_flat(8)
12+
real :: expected_dw_flat(8) = reshape([0.2, 0.1, 0.3, 0., 0.6, 0.4, 0.6, 0.], [8])
813
type(embedding_layer) :: embedding
914

10-
embedding = embedding_layer(sequence_length=3, vocab_size=4, model_dimension=2)
11-
call embedding % init([0])
15+
embedding = embedding_layer(vocab_size=4, model_dimension=2)
16+
call embedding % init([3])
1217
embedding % weights = reshape([0.1, 0.3, 0.5, 0.7, 0.2, 0.4, 0.6, 0.8], [4, 2])
18+
1319
call embedding % forward(sample_input)
14-
end program test_embedding_layer
20+
21+
output_flat = reshape(embedding % output, [6])
22+
if (.not. all(output_flat.eq.expected_output_flat)) then
23+
ok = .false.
24+
write(stderr, '(a)') 'forward returned incorrect values.. failed'
25+
end if
26+
27+
call embedding % backward(sample_input, sample_gradient)
28+
dw_flat = reshape(embedding % dw, shape(dw_flat))
29+
if (.not. all(dw_flat.eq.expected_dw_flat)) then
30+
ok = .false.
31+
write(stderr, '(a)') 'backward returned incorrect dw values.. failed'
32+
end if
33+
34+
if (ok) then
35+
print '(a)', 'test_embedding_layer: All tests passed.'
36+
else
37+
write(stderr, '(a)') 'test_embedding_layer: One or more tests failed.'
38+
stop 1
39+
end if
40+
end program test_embedding_layer

0 commit comments

Comments
 (0)