Skip to content

Commit 1c54cf0

Browse files
committed
embedding_layer: initial forward implementation
1 parent ed8b340 commit 1c54cf0

File tree

3 files changed

+188
-0
lines changed

3 files changed

+188
-0
lines changed

src/nf/nf_embedding_layer.f90

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
module nf_embedding_layer

  !! Embedding layer: maps a sequence of integer token indices onto
  !! learned dense vectors of length `model_dimension` by table lookup.
  !! (Removed `use nf_activation`: `activation_function` was imported
  !! but never referenced anywhere in this module.)

  use nf_base_layer, only: base_layer

  implicit none

  private
  public :: embedding_layer

  type, extends(base_layer) :: embedding_layer
    !! Trainable lookup table; row `i` of `weights` is the embedding
    !! vector for token index `i`.
    integer :: sequence_length, vocab_size, model_dimension

    real, allocatable :: weights(:, :)   ! (vocab_size, model_dimension) lookup table
    real, allocatable :: output(:, :)    ! (sequence_length, model_dimension) forward result
    real, allocatable :: gradient(:, :)  ! input gradient
    real, allocatable :: dw(:, :)        ! weight gradients

  contains

    procedure :: backward
    procedure :: forward
    procedure :: init
    procedure :: get_num_params
    procedure :: get_params
    procedure :: get_gradients
    procedure :: set_params

  end type embedding_layer

  interface embedding_layer
    module function embedding_layer_cons(&
        sequence_length, vocab_size, model_dimension&
    ) result(res)
      !! Construct an embedding layer; arrays are allocated later by `init`.
      integer, intent(in) :: sequence_length, vocab_size, model_dimension
      type(embedding_layer) :: res
    end function embedding_layer_cons
  end interface embedding_layer

  interface

    pure module subroutine forward(self, input)
      !! Look up the embedding vector of each token index in `input`
      !! and store the result in `self % output`.
      class(embedding_layer), intent(in out) :: self
      integer, intent(in) :: input(:)
    end subroutine forward

    pure module subroutine backward(self, input, gradient)
      !! Backward pass (implementation pending in the submodule).
      !! NOTE(review): `gradient` is declared rank-1 while `output` is
      !! rank-2 (sequence_length, model_dimension) — confirm the intended
      !! shape when the backward pass is implemented.
      class(embedding_layer), intent(in out) :: self
      integer, intent(in) :: input(:)
      real, intent(in) :: gradient(:)
    end subroutine backward

    module subroutine init(self, input_shape)
      !! Allocate `output`, `gradient`, `weights`, and `dw`.
      class(embedding_layer), intent(in out) :: self
      integer, intent(in) :: input_shape(:)
    end subroutine init

    pure module function get_num_params(self) result(num_params)
      !! Return the total number of trainable parameters.
      class(embedding_layer), intent(in) :: self
      integer :: num_params
    end function get_num_params

    module function get_params(self) result(params)
      !! Return the trainable parameters flattened into a rank-1 array.
      class(embedding_layer), intent(in), target :: self
      real, allocatable :: params(:)
    end function get_params

    module function get_gradients(self) result(gradients)
      !! Return the parameter gradients flattened into a rank-1 array.
      class(embedding_layer), intent(in), target :: self
      real, allocatable :: gradients(:)
    end function get_gradients

    module subroutine set_params(self, params)
      !! Overwrite the trainable parameters from a flat array; the size
      !! must equal `get_num_params()`.
      class(embedding_layer), intent(in out) :: self
      real, intent(in), target :: params(:)
    end subroutine set_params

  end interface

end module nf_embedding_layer

src/nf/nf_embedding_submodule.f90

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
submodule(nf_embedding_layer) nf_embedding_layer_submodule

  !! Implementation of the embedding layer procedures.

  use nf_base_layer, only: base_layer
  implicit none

contains

  module function embedding_layer_cons(&
      sequence_length, vocab_size, model_dimension&
  ) result(res)
    !! Store the layer dimensions; arrays are allocated in `init`.
    integer, intent(in) :: sequence_length, vocab_size, model_dimension
    type(embedding_layer) :: res

    res % vocab_size = vocab_size
    res % model_dimension = model_dimension
    res % sequence_length = sequence_length
  end function embedding_layer_cons

  module subroutine init(self, input_shape)
    !! Allocate and initialize the layer arrays.
    !! `input_shape` is required by the base-layer interface but is not
    !! used here: all sizes come from the constructor arguments.
    class(embedding_layer), intent(in out) :: self
    integer, intent(in) :: input_shape(:)

    allocate(self % output(self % sequence_length, self % model_dimension))
    allocate(self % gradient(self % sequence_length, self % vocab_size))

    allocate(self % weights(self % vocab_size, self % model_dimension))
    self % weights = 0.1

    allocate(self % dw(self % vocab_size, self % model_dimension))
    self % dw = 0.0
  end subroutine init

  pure module subroutine forward(self, input)
    !! Forward pass: row `i` of the output is the embedding vector of
    !! token `input(i)`, i.e. row `input(i)` of the weights table.
    class(embedding_layer), intent(in out) :: self
    integer, intent(in) :: input(:)
    integer :: i

    do concurrent(i = 1: self % sequence_length)
      self % output(i, :) = self % weights(input(i), :)
    end do
  end subroutine forward

  pure module subroutine backward(self, input, gradient)
    !! Backward pass — not yet implemented (this is a forward-only commit).
    !! TODO: accumulate `self % dw` from `gradient` using the token
    !! indices in `input`.  (Removed unused stub locals `db`, `dw`, `i`.)
    class(embedding_layer), intent(in out) :: self
    integer, intent(in) :: input(:)
    real, intent(in) :: gradient(:)
  end subroutine backward

  pure module function get_num_params(self) result(num_params)
    !! Return the total number of trainable parameters.
    class(embedding_layer), intent(in) :: self
    integer :: num_params

    ! The layer has only the weights table and no bias component, so
    ! this count must equal the size of the flat array produced by
    ! get_params and consumed by set_params.  The previous
    ! `+ model_dimension` ("biases") term made
    ! `set_params(get_params())` always fail the size check.
    num_params = self % vocab_size * self % model_dimension

  end function get_num_params

  module function get_params(self) result(params)
    !! Flatten the weights table into a rank-1 array (no copy of the
    !! storage until the array constructor; pointer rank remapping).
    class(embedding_layer), intent(in), target :: self
    real, allocatable :: params(:)
    real, pointer :: w_(:) => null()

    w_(1: product(shape(self % weights))) => self % weights
    params = [w_]
  end function get_params

  module function get_gradients(self) result(gradients)
    !! Flatten the weight gradients into a rank-1 array.
    class(embedding_layer), intent(in), target :: self
    real, allocatable :: gradients(:)
    real, pointer :: dw_(:) => null()

    dw_(1: product(shape(self % dw))) => self % dw
    gradients = [dw_]
  end function get_gradients

  module subroutine set_params(self, params)
    !! Overwrite the weights table from a flat parameter array.
    class(embedding_layer), intent(in out) :: self
    real, intent(in), target :: params(:)

    real, pointer :: p_(:,:) => null()

    ! Check that the number of parameters is correct.
    if (size(params) /= self % get_num_params()) then
      error stop 'Error: number of parameters does not match'
    end if

    associate(n => self % vocab_size * self % model_dimension)
      ! Remap the flat array onto the weights shape (column-major).
      p_(1:self % vocab_size, 1:self % model_dimension) => params(1 : n)
      self % weights = p_
    end associate

  end subroutine set_params

end submodule nf_embedding_layer_submodule

test/test_embedding_layer.f90

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
program test_embedding_layer

  !! Checks that the embedding layer's forward pass returns, for each
  !! token index, the matching row of the weights table.  The original
  !! version declared `ok` but asserted nothing, so it always passed.

  use iso_fortran_env, only: stderr => error_unit
  use nf_embedding_layer, only: embedding_layer
  implicit none

  logical :: ok = .true.
  integer :: sample_input(3) = [2, 1, 3]
  real :: expected_output(3, 2)
  type(embedding_layer) :: embedding

  embedding = embedding_layer(sequence_length=3, vocab_size=4, model_dimension=2)
  call embedding % init([0])
  embedding % weights = reshape([0.1, 0.3, 0.5, 0.7, 0.2, 0.4, 0.6, 0.8], [4, 2])

  call embedding % forward(sample_input)

  ! Row i of the output must equal row sample_input(i) of the weights:
  ! tokens [2, 1, 3] select rows [0.3,0.4], [0.1,0.2], [0.5,0.6].
  expected_output = reshape([0.3, 0.1, 0.5, 0.4, 0.2, 0.6], [3, 2])

  if (any(abs(embedding % output - expected_output) > 1e-6)) then
    ok = .false.
    write(stderr, '(a)') 'forward returned incorrect values.. failed'
  end if

  if (ok) then
    print '(a)', 'test_embedding_layer: All tests passed.'
  else
    write(stderr, '(a)') 'One or more tests of embedding_layer failed.'
    error stop 1
  end if

end program test_embedding_layer

0 commit comments

Comments
 (0)