1212 use nf_reshape_layer, only: reshape3d_layer
1313 use nf_linear2d_layer, only: linear2d_layer
1414 use nf_self_attention_layer, only: self_attention_layer
15+ use nf_embedding_layer, only: embedding_layer
1516 use nf_optimizers, only: optimizer_base_type
1617
1718contains
@@ -60,6 +61,8 @@ pure module subroutine backward_1d(self, previous, gradient)
6061 call this_layer % backward(prev_layer % output, gradient)
6162 type is (self_attention_layer)
6263 call this_layer % backward(prev_layer % output, gradient)
64+ type is (embedding_layer)
65+ call this_layer % backward(prev_layer % output, gradient)
6366 end select
6467
6568 end select
@@ -80,6 +83,8 @@ pure module subroutine backward_2d(self, previous, gradient)
8083 select type (prev_layer = > previous % p)
8184 type is (input2d_layer)
8285 call this_layer % backward(prev_layer % output, gradient)
86+ type is (embedding_layer)
87+ call this_layer % backward(prev_layer % output, gradient)
8388 type is (linear2d_layer)
8489 call this_layer % backward(prev_layer % output, gradient)
8590 type is (self_attention_layer)
@@ -91,6 +96,8 @@ pure module subroutine backward_2d(self, previous, gradient)
9196 select type (prev_layer = > previous % p)
9297 type is (input2d_layer)
9398 call this_layer % backward(prev_layer % output, gradient)
99+ type is (embedding_layer)
100+ call this_layer % backward(prev_layer % output, gradient)
94101 type is (linear2d_layer)
95102 call this_layer % backward(prev_layer % output, gradient)
96103 type is (self_attention_layer)
@@ -254,6 +261,8 @@ module subroutine forward(self, input)
254261 select type (prev_layer = > input % p)
255262 type is (input2d_layer)
256263 call this_layer % forward(prev_layer % output)
264+ type is (embedding_layer)
265+ call this_layer % forward(prev_layer % output)
257266 type is (linear2d_layer)
258267 call this_layer % forward(prev_layer % output)
259268 type is (self_attention_layer)
@@ -266,6 +275,8 @@ module subroutine forward(self, input)
266275 select type (prev_layer = > input % p)
267276 type is (input2d_layer)
268277 call this_layer % forward(prev_layer % output)
278+ type is (embedding_layer)
279+ call this_layer % forward(prev_layer % output)
269280 type is (linear2d_layer)
270281 call this_layer % forward(prev_layer % output)
271282 type is (self_attention_layer)
@@ -307,6 +318,8 @@ pure module subroutine get_output_2d(self, output)
307318
308319 type is (input2d_layer)
309320 allocate (output, source= this_layer % output)
321+ type is (embedding_layer)
322+ allocate (output, source= this_layer % output)
310323 type is (linear2d_layer)
311324 allocate (output, source= this_layer % output)
312325 type is (self_attention_layer)
@@ -425,6 +438,8 @@ elemental module function get_num_params(self) result(num_params)
425438 num_params = this_layer % get_num_params()
426439 type is (self_attention_layer)
427440 num_params = this_layer % get_num_params()
441+ type is (embedding_layer)
442+ num_params = this_layer % get_num_params()
428443 class default
429444 error stop ' Unknown layer type.'
430445 end select
@@ -458,6 +473,8 @@ module function get_params(self) result(params)
458473 params = this_layer % get_params()
459474 type is (self_attention_layer)
460475 params = this_layer % get_params()
476+ type is (embedding_layer)
477+ params = this_layer % get_params()
461478 class default
462479 error stop ' Unknown layer type.'
463480 end select
@@ -491,6 +508,8 @@ module function get_gradients(self) result(gradients)
491508 gradients = this_layer % get_gradients()
492509 type is (self_attention_layer)
493510 gradients = this_layer % get_gradients()
511+ type is (embedding_layer)
512+ gradients = this_layer % get_gradients()
494513 class default
495514 error stop ' Unknown layer type.'
496515 end select
@@ -548,6 +567,8 @@ module subroutine set_params(self, params)
548567
549568 type is (self_attention_layer)
550569 call this_layer % set_params(params)
570+ type is (embedding_layer)
571+ call this_layer % set_params(params)
551572
552573 type is (maxpool2d_layer)
553574 ! No parameters to set.
0 commit comments