22 use nf_base_layer, only: base_layer
33 implicit none
44contains
5- module function embedding_layer_cons (&
6- sequence_length , vocab_size , model_dimension&
7- ) result(res)
8- integer , intent (in ) :: sequence_length, vocab_size, model_dimension
5+ module function embedding_layer_cons (vocab_size , model_dimension ) result(res)
6+ integer , intent (in ) :: vocab_size, model_dimension
97 type (embedding_layer) :: res
108
119 res % vocab_size = vocab_size
1210 res % model_dimension = model_dimension
13- res % sequence_length = sequence_length
1411 end function embedding_layer_cons
1512
1613 module subroutine init (self , input_shape )
1714 class(embedding_layer), intent (in out ) :: self
1815 integer , intent (in ) :: input_shape(:)
1916
17+ self % sequence_length = input_shape(1 )
18+
2019 allocate (self % output(self % sequence_length, self % model_dimension))
2120 allocate (self % gradient(self % sequence_length, self % vocab_size))
2221
@@ -30,32 +29,34 @@ end subroutine init
3029 pure module subroutine forward(self, input)
3130 class(embedding_layer), intent (in out ) :: self
3231 integer , intent (in ) :: input(:)
33- integer :: i
32+ integer :: i, index
3433
3534 do concurrent(i = 1 : self % sequence_length)
36- self % output(i, :) = self % weights(input(i), :)
35+ index = input(i)
36+ if (index > size (self % weights, 1 )) then
37+ index = 1
38+ end if
39+ self % output(i, :) = self % weights(index, :)
3740 end do
3841 end subroutine forward
3942
4043 pure module subroutine backward(self, input, gradient)
4144 class(embedding_layer), intent (in out ) :: self
4245 integer , intent (in ) :: input(:)
43- real , intent (in ) :: gradient(:)
44- real :: db(self % model_dimension)
45- real :: dw(self % vocab_size, self % model_dimension)
46+ real , intent (in ) :: gradient(:, :)
4647 integer :: i
48+
49+ do concurrent(i = 1 : self % sequence_length)
50+ self % dw(input(i), :) = self % dw(input(i), :) + gradient(i, :)
51+ end do
4752 end subroutine backward
4853
4954 pure module function get_num_params(self) result(num_params)
5055 class(embedding_layer), intent (in ) :: self
5156 integer :: num_params
52-
53- ! Number of weigths times number of biases
54- num_params = self % vocab_size * self % model_dimension + self % model_dimension
55-
57+ num_params = self % vocab_size * self % model_dimension
5658 end function get_num_params
5759
58-
5960 module function get_params (self ) result(params)
6061 class(embedding_layer), intent (in ), target :: self
6162 real , allocatable :: params(:)
@@ -65,7 +66,6 @@ module function get_params(self) result(params)
6566 params = [w_]
6667 end function get_params
6768
68-
6969 module function get_gradients (self ) result(gradients)
7070 class(embedding_layer), intent (in ), target :: self
7171 real , allocatable :: gradients(:)
@@ -75,7 +75,6 @@ module function get_gradients(self) result(gradients)
7575 gradients = [dw_]
7676 end function get_gradients
7777
78-
7978 module subroutine set_params (self , params )
8079 class(embedding_layer), intent (in out ) :: self
8180 real , intent (in ), target :: params(:)
0 commit comments