@@ -14,7 +14,7 @@ program test_multihead_attention_layer
 
   attention = multihead_attention_layer(sequence_length=3, model_dimension=4, n_heads=2)
   call attention % init([0])
-  !
+
   call test_multihead_attention_split_heads(attention, sample_input, ok, split_heads_output)
   call test_multihead_attention_create_attention_matrix(attention, split_heads_output, ok)
   call test_multihead_attention_normalization(attention, ok)
@@ -23,7 +23,7 @@ program test_multihead_attention_layer
   call test_multihead_attention_forward(attention, ok)
   call test_multihead_attention_backward(attention, ok)
   call test_multihead_attention_update_gradients(attention, ok)
-  ! call test_multihead_attention_forward_reallife_shape(ok)
+  call test_multihead_attention_forward_reallife_shape(ok)
 
 contains
   subroutine test_multihead_attention_split_heads(attention, input, ok, output)
@@ -156,7 +156,7 @@ subroutine test_multihead_attention_forward(attention, ok)
       0.447508544, 0.464612424, 0.464721352, 0.473546445, 0.512576580, 0.513393998 &
     ]
 
-    call attention % forward(input, input, input)
+    call attention % common_forward(input, input, input)
 
     output_shape = shape(attention % output)
     if (.not. all(output_shape .eq. expected_shape)) then
@@ -196,7 +196,7 @@ subroutine test_multihead_attention_forward_reallife_shape(ok)
     attention = multihead_attention_layer(sequence_length=148, model_dimension=512, n_heads=8)
     call attention % init([0])
 
-    call attention % forward(input, input, input)
+    call attention % common_forward(input, input, input)
 
     output_shape = shape(attention % output)
     if (.not. all(output_shape .eq. expected_shape)) then
@@ -221,7 +221,7 @@ subroutine test_multihead_attention_backward(attention, ok)
     real :: output_flat(12)
     real :: output_shape(2)
 
-    call attention % backward(input, gradient)
+    call attention % common_backward(input, gradient)
 
     ! sample for Self Attention: sum of output gradients
     ! FIXME: remove reshapes when linear2d situation is resolved
@@ -271,7 +271,7 @@ subroutine test_multihead_attention_update_gradients(attention, ok)
     call optim % minimize(parameters, attention % get_gradients())
     call attention % set_params(parameters)
 
-    call attention % forward(&
+    call attention % common_forward(&
       reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]),&
       reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]),&
       reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4])&
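
A minimal sketch of how the renamed entry points are exercised, based only on what this diff shows (constructor keywords, `init([0])`, and the `common_forward`/`common_backward` names); the module name in the use statement is an assumption, not confirmed by the diff:

! Minimal usage sketch: drive the renamed common_forward/common_backward
! type-bound procedures with a 3x4 self-attention input, as the tests above do.
program sketch_multihead_attention_rename
  use nf_multihead_attention_layer, only: multihead_attention_layer  ! assumed module name
  implicit none
  type(multihead_attention_layer) :: attention
  real :: input(3, 4), gradient(3, 4)

  input = reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4])
  gradient = 1.0

  attention = multihead_attention_layer(sequence_length=3, model_dimension=4, n_heads=2)
  call attention % init([0])

  ! query, key and value all receive the same array for self-attention
  call attention % common_forward(input, input, input)
  call attention % common_backward(input, gradient)

  print *, shape(attention % output)  ! expected: 3 4
end program sketch_multihead_attention_rename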