1010 use nf_maxpool2d_layer, only: maxpool2d_layer
1111 use nf_reshape_layer, only: reshape3d_layer
1212 use nf_linear2d_layer, only: linear2d_layer
13+ use nf_embedding_layer, only: embedding_layer
1314 use nf_optimizers, only: optimizer_base_type
1415
1516contains
@@ -50,6 +51,8 @@ pure module subroutine backward_1d(self, previous, gradient)
5051 call this_layer % backward(prev_layer % output, gradient)
5152 type is (linear2d_layer)
5253 call this_layer % backward(prev_layer % output, gradient)
54+ type is (embedding_layer)
55+ call this_layer % backward(prev_layer % output, gradient)
5356 end select
5457
5558 end select
@@ -70,6 +73,8 @@ pure module subroutine backward_2d(self, previous, gradient)
7073 select type (prev_layer = > previous % p)
7174 type is (input2d_layer)
7275 call this_layer % backward(prev_layer % output, gradient)
76+ type is (embedding_layer)
77+ call this_layer % backward(prev_layer % output, gradient)
7378 type is (linear2d_layer)
7479 call this_layer % backward(prev_layer % output, gradient)
7580 end select
@@ -217,6 +222,8 @@ pure module subroutine forward(self, input)
217222 select type (prev_layer = > input % p)
218223 type is (input2d_layer)
219224 call this_layer % forward(prev_layer % output)
225+ type is (embedding_layer)
226+ call this_layer % forward(prev_layer % output)
220227 type is (linear2d_layer)
221228 call this_layer % forward(prev_layer % output)
222229 end select
@@ -256,6 +263,8 @@ pure module subroutine get_output_2d(self, output)
256263
257264 type is (input2d_layer)
258265 allocate (output, source= this_layer % output)
266+ type is (embedding_layer)
267+ allocate (output, source= this_layer % output)
259268 type is (linear2d_layer)
260269 allocate (output, source= this_layer % output)
261270 class default
@@ -359,6 +368,8 @@ elemental module function get_num_params(self) result(num_params)
359368 num_params = 0
360369 type is (linear2d_layer)
361370 num_params = this_layer % get_num_params()
371+ type is (embedding_layer)
372+ num_params = this_layer % get_num_params()
362373 class default
363374 error stop ' Unknown layer type.'
364375 end select
@@ -388,6 +399,8 @@ module function get_params(self) result(params)
388399 ! No parameters to get.
389400 type is (linear2d_layer)
390401 params = this_layer % get_params()
402+ type is (embedding_layer)
403+ params = this_layer % get_params()
391404 class default
392405 error stop ' Unknown layer type.'
393406 end select
@@ -417,6 +430,8 @@ module function get_gradients(self) result(gradients)
417430 ! No gradients to get.
418431 type is (linear2d_layer)
419432 gradients = this_layer % get_gradients()
433+ type is (embedding_layer)
434+ gradients = this_layer % get_gradients()
420435 class default
421436 error stop ' Unknown layer type.'
422437 end select
@@ -467,6 +482,9 @@ module subroutine set_params(self, params)
467482 type is (linear2d_layer)
468483 call this_layer % set_params(params)
469484
485+ type is (embedding_layer)
486+ call this_layer % set_params(params)
487+
470488 type is (maxpool2d_layer)
471489 ! No parameters to set.
472490 write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
0 commit comments