@@ -8,22 +8,15 @@ program test_multihead_attention_layer
   type(multihead_attention_layer) :: attention
   real :: sample_input(3, 4, 1) = reshape([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12], [3, 4, 1])
   real :: split_heads_output(2, 3, 2, 1)
-  real :: raw_attention_matrix(2, 3, 3, 1)
-  real :: normalized_attention_matrix(2, 3, 3, 1)
-  real :: scaled_dp_att(2, 3, 2, 1)
-  real :: scaled_dp_att_reshaped(1, 3, 2, 2)
-  real :: combined_attention(3, 4, 1)
 
   attention = multihead_attention_layer(batch_size=1, sequence_length=3, model_dimension=4, n_heads=2)
   call attention % init([0])
 
   call test_multihead_attention_split_heads(attention, sample_input, ok, split_heads_output)
-  call test_multihead_attention_create_attention_matrix(attention, split_heads_output, ok, raw_attention_matrix)
-  call test_multihead_attention_normalization(attention, raw_attention_matrix, ok, normalized_attention_matrix)
-  call test_multihead_attention_scaled_dot_product_attention(&
-      attention, normalized_attention_matrix, split_heads_output, ok, scaled_dp_att&
-  )
-  call test_multihead_attention_combine_heads(attention, scaled_dp_att, ok)
+  call test_multihead_attention_create_attention_matrix(attention, split_heads_output, ok)
+  call test_multihead_attention_normalization(attention, ok)
+  call test_multihead_attention_scaled_dot_product_attention(attention, split_heads_output, ok)
+  call test_multihead_attention_combine_heads(attention, attention % sdpa, ok)
   call test_multihead_attention_forward(attention, ok)
   call test_multihead_attention_forward_reallife_shape(ok)
 
@@ -52,11 +45,10 @@ subroutine test_multihead_attention_split_heads(attention, input, ok, output)
   end if
 end subroutine test_multihead_attention_split_heads
 
-subroutine test_multihead_attention_create_attention_matrix(attention, input, ok, attention_matrix)
+subroutine test_multihead_attention_create_attention_matrix(attention, input, ok)
   type(multihead_attention_layer), intent(in) :: attention
   real, intent(in) :: input(:, :, :, :)
   logical, intent(in out) :: ok
-  real, intent(in out) :: attention_matrix(2, 3, 3, 1)
   real :: attention_matrix_shape(4)
   real :: attention_matrix_flat(18)
   real :: expected_shape(4) = [2, 3, 3, 1]
@@ -69,58 +61,52 @@ subroutine test_multihead_attention_create_attention_matrix(attention, input, ok
       0.573199987, 0.289999992, 0.654400051 &
   ]
 
-  attention_matrix = attention % create_attention_matrix(input, input)
+  call attention % create_attention_matrix(input, input)
 
-  attention_matrix_shape = shape(attention_matrix)
+  attention_matrix_shape = shape(attention % attention_matrix)
   if (.not. all(attention_matrix_shape .eq. expected_shape)) then
     ok = .false.
     write(stderr, '(a)') 'create_attention_matrix returned incorrect shape.. failed'
   end if
-  attention_matrix_flat = reshape(attention_matrix, shape(expected_attention_matrix_flat))
+  attention_matrix_flat = reshape(attention % attention_matrix, shape(expected_attention_matrix_flat))
   if (.not. all(attention_matrix_flat .eq. expected_attention_matrix_flat)) then
     ok = .false.
     write(stderr, '(a)') 'create_attention_matrix returned incorrect values.. failed'
   end if
 end subroutine test_multihead_attention_create_attention_matrix
 
-subroutine test_multihead_attention_normalization(attention, input, ok, output)
+subroutine test_multihead_attention_normalization(attention, ok)
   type(multihead_attention_layer), intent(in) :: attention
-  real, intent(in) :: input(:, :, :, :)
   logical, intent(in out) :: ok
-  real, intent(out) :: output(2, 3, 3, 1)
   real :: output_flat(18)
   real :: expected_output_flat(18) = [&
       0.326287806, 0.435975075, 0.321620107, 0.330339342, 0.316976935, 0.329200655,&
       0.333283335, 0.275134116, 0.333194464, 0.326415271, 0.333061278, 0.325773478,&
       0.340428889, 0.288890868, 0.345185399, 0.343245387, 0.349961787, 0.345025837 &
   ]
 
-  output = attention % normalize_attention_matrix(input)
+  call attention % normalize_attention_matrix()
 
-  output_flat = reshape(output, shape(output_flat))
+  output_flat = reshape(attention % attention_matrix, shape(output_flat))
   if (.not. all(output_flat .eq. expected_output_flat)) then
     ok = .false.
     write(stderr, '(a)') 'normalize_attention_matrix returned incorrect values.. failed'
   end if
 end subroutine test_multihead_attention_normalization
 
-subroutine test_multihead_attention_scaled_dot_product_attention(attention, attention_matrix, value, ok, output)
+subroutine test_multihead_attention_scaled_dot_product_attention(attention, value, ok)
   type(multihead_attention_layer), intent(in) :: attention
-  real, intent(in) :: attention_matrix(:, :, :, :)
   real, intent(in) :: value(:, :, :, :)
   logical, intent(in out) :: ok
-  real, intent(out) :: output(&
-      attention % n_heads, attention % sequence_length, attention % head_size, attention % batch_size&
-  )
   real :: output_flat(12)
   real :: expected_output_flat(12) = [&
       0.101414114, 0.685291648, 0.102356531, 0.701290607, 0.103298485, 0.701582491,&
       0.401414126, 0.457309216, 0.402356505, 0.374400526, 0.403298497, 0.373518765 &
   ]
 
-  output = attention % scaled_dot_product_attention(attention_matrix, value)
+  call attention % scaled_dot_product_attention(value)
 
-  output_flat = reshape(output, shape(output_flat))
+  output_flat = reshape(attention % sdpa, shape(output_flat))
   if (.not. all(output_flat .eq. expected_output_flat)) then
     ok = .false.
     write(stderr, '(a)') 'scaled_dot_product_attention returned incorrect values.. failed'
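
For readers tracking the API change: after this refactor the intermediate tensors are no longer returned by functions but stored on the layer itself, which is why the tests read attention % attention_matrix and attention % sdpa after each call. Below is a minimal compilable sketch of what these calls imply. The component and procedure names are taken from the diff above; the module name, declarations, shapes, and intents are assumptions for illustration, not the library's actual definition.

! Sketch only: a skeleton inferred from the calls in this diff,
! not the real neural-fortran source.
module multihead_attention_sketch
  implicit none

  type :: multihead_attention_layer
    integer :: batch_size, sequence_length, model_dimension, n_heads, head_size
    ! Intermediate results now live on the layer instead of being
    ! threaded through argument lists:
    real, allocatable :: attention_matrix(:, :, :, :)  ! written by create_attention_matrix, then normalized in place
    real, allocatable :: sdpa(:, :, :, :)              ! written by scaled_dot_product_attention
  contains
    procedure :: create_attention_matrix
    procedure :: normalize_attention_matrix
    procedure :: scaled_dot_product_attention
  end type multihead_attention_layer

contains

  subroutine create_attention_matrix(self, query, key)
    class(multihead_attention_layer), intent(in out) :: self
    real, intent(in) :: query(:, :, :, :), key(:, :, :, :)
    ! ... batched query-key product stored in self % attention_matrix ...
  end subroutine create_attention_matrix

  subroutine normalize_attention_matrix(self)
    class(multihead_attention_layer), intent(in out) :: self
    ! ... scaling and softmax applied to self % attention_matrix in place ...
  end subroutine normalize_attention_matrix

  subroutine scaled_dot_product_attention(self, value)
    class(multihead_attention_layer), intent(in out) :: self
    real, intent(in) :: value(:, :, :, :)
    ! ... product of self % attention_matrix and value stored in self % sdpa ...
  end subroutine scaled_dot_product_attention

end module multihead_attention_sketch

One consequence of holding state on the layer: each of these type-bound subroutines modifies its passed-object dummy, so any dummy argument through which the layer is mutated (such as the tests' attention argument) must permit modification, e.g. intent(in out) rather than intent(in).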