Skip to content

Commit 26278cd

Browse files
committed
embedding_layer: plumbing
1 parent 2bd9874 commit 26278cd

File tree

5 files changed

+51
-5
lines changed

5 files changed

+51
-5
lines changed

src/nf.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module nf
33
use nf_datasets_mnist, only: label_digits, load_mnist
44
use nf_layer, only: layer
55
use nf_layer_constructors, only: &
6-
conv2d, dense, flatten, input, maxpool2d, reshape, linear2d
6+
conv2d, dense, flatten, input, maxpool2d, reshape, linear2d, embedding
77
use nf_loss, only: mse, quadratic
88
use nf_metrics, only: corr, maxabs
99
use nf_network, only: network

src/nf/nf_layer_constructors.f90

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ module nf_layer_constructors
88
implicit none
99

1010
private
11-
public :: conv2d, dense, flatten, input, maxpool2d, reshape, linear2d
11+
public :: conv2d, dense, flatten, input, maxpool2d, reshape, linear2d, embedding
1212

1313
interface input
1414

@@ -195,6 +195,11 @@ module function linear2d(out_features) result(res)
195195
!! Resulting layer instance
196196
end function linear2d
197197

198+
module function embedding(sequence_length, vocab_size, model_dimension) result(res)
199+
integer, intent(in) :: sequence_length, vocab_size, model_dimension
200+
type(layer) :: res
201+
end function embedding
202+
198203
end interface
199204

200205
end module nf_layer_constructors

src/nf/nf_layer_constructors_submodule.f90

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
use nf_maxpool2d_layer, only: maxpool2d_layer
1111
use nf_reshape_layer, only: reshape3d_layer
1212
use nf_linear2d_layer, only: linear2d_layer
13+
use nf_embedding_layer, only: embedding_layer
1314
use nf_activation, only: activation_function, relu, sigmoid
1415

1516
implicit none
@@ -160,4 +161,20 @@ module function linear2d(out_features) result(res)
160161

161162
end function linear2d
162163

164+
165+
module function embedding(sequence_length, vocab_size, model_dimension) result(res)
166+
integer, intent(in) :: sequence_length, vocab_size, model_dimension
167+
type(layer) :: res
168+
type(embedding_layer) :: embedding_layer_instance
169+
170+
embedding_layer_instance = embedding_layer(vocab_size, model_dimension)
171+
call embedding_layer_instance % init([sequence_length])
172+
res % name = 'embedding'
173+
res % layer_shape = [sequence_length, model_dimension]
174+
res % input_layer_shape = [integer ::]
175+
allocate(res % p, source=embedding_layer_instance)
176+
res % initialized = .true.
177+
178+
end function embedding
179+
163180
end submodule nf_layer_constructors_submodule

src/nf/nf_layer_submodule.f90

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
use nf_maxpool2d_layer, only: maxpool2d_layer
1111
use nf_reshape_layer, only: reshape3d_layer
1212
use nf_linear2d_layer, only: linear2d_layer
13+
use nf_embedding_layer, only: embedding_layer
1314
use nf_optimizers, only: optimizer_base_type
1415

1516
contains
@@ -50,6 +51,8 @@ pure module subroutine backward_1d(self, previous, gradient)
5051
call this_layer % backward(prev_layer % output, gradient)
5152
type is(linear2d_layer)
5253
call this_layer % backward(prev_layer % output, gradient)
54+
type is(embedding_layer)
55+
call this_layer % backward(prev_layer % output, gradient)
5356
end select
5457

5558
end select
@@ -70,6 +73,8 @@ pure module subroutine backward_2d(self, previous, gradient)
7073
select type(prev_layer => previous % p)
7174
type is(input2d_layer)
7275
call this_layer % backward(prev_layer % output, gradient)
76+
type is(embedding_layer)
77+
call this_layer % backward(prev_layer % output, gradient)
7378
type is(linear2d_layer)
7479
call this_layer % backward(prev_layer % output, gradient)
7580
end select
@@ -217,6 +222,8 @@ pure module subroutine forward(self, input)
217222
select type(prev_layer => input % p)
218223
type is(input2d_layer)
219224
call this_layer % forward(prev_layer % output)
225+
type is(embedding_layer)
226+
call this_layer % forward(prev_layer % output)
220227
type is(linear2d_layer)
221228
call this_layer % forward(prev_layer % output)
222229
end select
@@ -256,6 +263,8 @@ pure module subroutine get_output_2d(self, output)
256263

257264
type is(input2d_layer)
258265
allocate(output, source=this_layer % output)
266+
type is(embedding_layer)
267+
allocate(output, source=this_layer % output)
259268
type is(linear2d_layer)
260269
allocate(output, source=this_layer % output)
261270
class default
@@ -359,6 +368,8 @@ elemental module function get_num_params(self) result(num_params)
359368
num_params = 0
360369
type is (linear2d_layer)
361370
num_params = this_layer % get_num_params()
371+
type is (embedding_layer)
372+
num_params = this_layer % get_num_params()
362373
class default
363374
error stop 'Unknown layer type.'
364375
end select
@@ -388,6 +399,8 @@ module function get_params(self) result(params)
388399
! No parameters to get.
389400
type is (linear2d_layer)
390401
params = this_layer % get_params()
402+
type is (embedding_layer)
403+
params = this_layer % get_params()
391404
class default
392405
error stop 'Unknown layer type.'
393406
end select
@@ -417,6 +430,8 @@ module function get_gradients(self) result(gradients)
417430
! No gradients to get.
418431
type is (linear2d_layer)
419432
gradients = this_layer % get_gradients()
433+
type is (embedding_layer)
434+
gradients = this_layer % get_gradients()
420435
class default
421436
error stop 'Unknown layer type.'
422437
end select
@@ -467,6 +482,9 @@ module subroutine set_params(self, params)
467482
type is (linear2d_layer)
468483
call this_layer % set_params(params)
469484

485+
type is (embedding_layer)
486+
call this_layer % set_params(params)
487+
470488
type is (maxpool2d_layer)
471489
! No parameters to set.
472490
write(stderr, '(a)') 'Warning: calling set_params() ' &

src/nf/nf_network_submodule.f90

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
use nf_maxpool2d_layer, only: maxpool2d_layer
1010
use nf_reshape_layer, only: reshape3d_layer
1111
use nf_linear2d_layer, only: linear2d_layer
12+
use nf_embedding_layer, only: embedding_layer
1213
use nf_layer, only: layer
1314
use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape
1415
use nf_loss, only: quadratic
@@ -44,7 +45,7 @@ module function network_from_layers(layers) result(res)
4445
error stop 'Error: A network must have at least 2 layers.'
4546

4647
! The first layer must be an input layer
47-
if (.not. layers(1) % name == 'input') &
48+
if (.not. layers(1) % name == 'input' .and. .not. layers(1) % name == 'embedding') &
4849
error stop 'Error: First layer in the network must be an input layer.'
4950

5051
!TODO Ensure that the layers are in allowed sequence:
@@ -158,6 +159,8 @@ module subroutine backward(self, output, loss)
158159
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
159160
type is(linear2d_layer)
160161
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
162+
! type is(embedding_layer)
163+
! call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
161164
end select
162165
end if
163166

@@ -202,8 +205,11 @@ module subroutine forward_1d(self, input)
202205
integer :: n
203206

204207
! Set the input array into the input layer
205-
select type(input_layer => self % layers(1) % p); type is(input1d_layer)
206-
call input_layer % set(input)
208+
select type(input_layer => self % layers(1) % p)
209+
type is(input1d_layer)
210+
call input_layer % set(input)
211+
type is(embedding_layer)
212+
call input_layer % forward(nint(input))
207213
end select
208214

209215
do n = 2, size(self % layers)

0 commit comments

Comments
 (0)