@@ -16,6 +16,7 @@ program test_multihead_attention_layer
 
   attention = multihead_attention_layer(sequence_length=3, model_dimension=4, n_heads=2)
   call attention % init_base([0])
+  call set_weights(attention)
 
   call test_multihead_attention_split_heads(attention, sample_input, ok, split_heads_output)
   call test_multihead_attention_create_attention_matrix(attention, split_heads_output, ok)
@@ -30,6 +31,20 @@ program test_multihead_attention_layer
   call test_cross_attention(ok)
 
 contains
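+  ! Helper that replaces the random initial weights and biases with fixed
+  ! values, so the forward passes below produce deterministic outputs.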
+  subroutine set_weights(attention)
+    type(multihead_attention_layer), intent(in out) :: attention
+    attention % query_layer % weights = 0.1
+    attention % key_layer % weights = 0.1
+    attention % value_layer % weights = 0.1
+    attention % output_layer % weights = 0.1
+    attention % query_layer % biases = 0.11
+    attention % key_layer % biases = 0.11
+    attention % value_layer % biases = 0.11
+    attention % output_layer % biases = 0.11
+  end subroutine set_weights
+
   subroutine test_multihead_attention_split_heads(attention, input, ok, output)
     type(multihead_attention_layer), intent(in) :: attention
     real, intent(in) :: input(:, :)
@@ -199,6 +214,7 @@ subroutine test_multihead_attention_forward_reallife_shape(ok)
 
     attention = multihead_attention_layer(sequence_length=148, model_dimension=512, n_heads=8)
     call attention % init_base([0])
+    call set_weights(attention)
 
     call attention % common_forward(input, input, input)
 
@@ -305,6 +321,15 @@ subroutine test_self_attention(ok)
 
     attention = self_attention_layer(sequence_length=2, model_dimension=3, n_heads=1)
     call attention % init([0])
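+    ! Fix the weights and biases so the forward pass below is deterministic.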
+    attention % query_layer % weights = 0.1
+    attention % key_layer % weights = 0.1
+    attention % value_layer % weights = 0.1
+    attention % output_layer % weights = 0.1
+    attention % query_layer % biases = 0.11
+    attention % key_layer % biases = 0.11
+    attention % value_layer % biases = 0.11
+    attention % output_layer % biases = 0.11
 
     call attention % forward(input)
     output_flat = reshape(attention % output, shape(output_flat))
@@ -346,6 +371,15 @@ subroutine test_cross_attention(ok)
 
     attention = cross_attention_layer(sequence_length=2, model_dimension=3, n_heads=1)
     call attention % init([0])
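+    ! Same fixed parameters as above, for the cross-attention layer.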
+    attention % query_layer % weights = 0.1
+    attention % key_layer % weights = 0.1
+    attention % value_layer % weights = 0.1
+    attention % output_layer % weights = 0.1
+    attention % query_layer % biases = 0.11
+    attention % key_layer % biases = 0.11
+    attention % value_layer % biases = 0.11
+    attention % output_layer % biases = 0.11
 
     call attention % forward(input)
     output_flat = reshape(attention % output, shape(output_flat))