Commit d8d7355

multihead_attention: update tests
1 parent 25d523c commit d8d7355

1 file changed: +14 −28

test/test_multihead_attention_layer.f90

Lines changed: 14 additions & 28 deletions
@@ -8,22 +8,15 @@ program test_multihead_attention_layer
   type(multihead_attention_layer) :: attention
   real :: sample_input(3, 4, 1) = reshape([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12], [3, 4, 1])
   real :: split_heads_output(2, 3, 2, 1)
-  real :: raw_attention_matrix(2, 3, 3, 1)
-  real :: normalized_attention_matrix(2, 3, 3, 1)
-  real :: scaled_dp_att(2, 3, 2, 1)
-  real :: scaled_dp_att_reshaped(1, 3, 2, 2)
-  real :: combined_attention(3, 4, 1)

   attention = multihead_attention_layer(batch_size=1, sequence_length=3, model_dimension=4, n_heads=2)
   call attention % init([0])

   call test_multihead_attention_split_heads(attention, sample_input, ok, split_heads_output)
-  call test_multihead_attention_create_attention_matrix(attention, split_heads_output, ok, raw_attention_matrix)
-  call test_multihead_attention_normalization(attention, raw_attention_matrix, ok, normalized_attention_matrix)
-  call test_multihead_attention_scaled_dot_product_attention(&
-      attention, normalized_attention_matrix, split_heads_output, ok, scaled_dp_att&
-  )
-  call test_multihead_attention_combine_heads(attention, scaled_dp_att, ok)
+  call test_multihead_attention_create_attention_matrix(attention, split_heads_output, ok)
+  call test_multihead_attention_normalization(attention, ok)
+  call test_multihead_attention_scaled_dot_product_attention(attention, split_heads_output, ok)
+  call test_multihead_attention_combine_heads(attention, attention % sdpa, ok)
   call test_multihead_attention_forward(attention, ok)
   call test_multihead_attention_forward_reallife_shape(ok)

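The change is the same across all three hunks: create_attention_matrix, normalize_attention_matrix, and scaled_dot_product_attention are now subroutines that store their results on the layer itself (attention % attention_matrix, attention % sdpa) instead of functions returning arrays, so the driver no longer threads intermediate arrays between test calls. A minimal sketch of the resulting pipeline, using only the type-bound calls visible in this diff:

  ! Sketch: each step writes into the layer's own components; nothing
  ! is returned, and downstream steps read the stored state directly.
  call attention % create_attention_matrix(split_heads_output, split_heads_output)  ! query x key
  call attention % normalize_attention_matrix()  ! normalizes attention % attention_matrix in place
  call attention % scaled_dot_product_attention(split_heads_output)  ! fills attention % sdpa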
@@ -52,11 +45,10 @@ subroutine test_multihead_attention_split_heads(attention, input, ok, output)
     end if
   end subroutine test_multihead_attention_split_heads

-  subroutine test_multihead_attention_create_attention_matrix(attention, input, ok, attention_matrix)
+  subroutine test_multihead_attention_create_attention_matrix(attention, input, ok)
     type(multihead_attention_layer), intent(in) :: attention
     real, intent(in) :: input(:, :, :, :)
     logical, intent(in out) :: ok
-    real, intent(in out) :: attention_matrix(2, 3, 3, 1)
     real :: attention_matrix_shape(4)
     real :: attention_matrix_flat(18)
     real :: expected_shape(4) = [2, 3, 3, 1]
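For context on that expected shape: the driver constructs the layer with batch_size=1, sequence_length=3, model_dimension=4, and n_heads=2, giving head_size = 4 / 2 = 2. The output declaration removed later in this diff spells out the dimension order as (n_heads, sequence_length, head_size, batch_size), which is why the split-heads output is (2, 3, 2, 1) and the per-head attention matrix, pairing every position with every other, is (n_heads, sequence_length, sequence_length, batch_size) = (2, 3, 3, 1).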
@@ -69,58 +61,52 @@ subroutine test_multihead_attention_create_attention_matrix(attention, input, ok
       0.573199987, 0.289999992, 0.654400051&
     ]

-    attention_matrix = attention % create_attention_matrix(input, input)
+    call attention % create_attention_matrix(input, input)

-    attention_matrix_shape = shape(attention_matrix)
+    attention_matrix_shape = shape(attention % attention_matrix)
     if (.not. all(attention_matrix_shape.eq.expected_shape)) then
       ok = .false.
       write(stderr, '(a)') 'create_attention_matrix returned incorrect shape.. failed'
     end if
-    attention_matrix_flat = reshape(attention_matrix, shape(expected_attention_matrix_flat))
+    attention_matrix_flat = reshape(attention % attention_matrix, shape(expected_attention_matrix_flat))
     if (.not. all(attention_matrix_flat.eq.expected_attention_matrix_flat)) then
       ok = .false.
       write(stderr, '(a)') 'create_attention_matrix returned incorrect values.. failed'
     end if
   end subroutine test_multihead_attention_create_attention_matrix

-  subroutine test_multihead_attention_normalization(attention, input, ok, output)
+  subroutine test_multihead_attention_normalization(attention, ok)
     type(multihead_attention_layer), intent(in) :: attention
-    real, intent(in) :: input(:, :, :, :)
     logical, intent(in out) :: ok
-    real, intent(out) :: output(2, 3, 3, 1)
     real :: output_flat(18)
     real :: expected_output_flat(18) = [&
       0.326287806, 0.435975075, 0.321620107, 0.330339342, 0.316976935, 0.329200655,&
       0.333283335, 0.275134116, 0.333194464, 0.326415271, 0.333061278, 0.325773478,&
       0.340428889, 0.288890868, 0.345185399, 0.343245387, 0.349961787, 0.345025837&
     ]

-    output = attention % normalize_attention_matrix(input)
+    call attention % normalize_attention_matrix()

-    output_flat = reshape(output, shape(output_flat))
+    output_flat = reshape(attention % attention_matrix, shape(output_flat))
     if (.not. all(output_flat.eq.expected_output_flat)) then
       ok = .false.
       write(stderr, '(a)') 'normalize_attention_matrix returned incorrect values.. failed'
     end if
   end subroutine test_multihead_attention_normalization

-  subroutine test_multihead_attention_scaled_dot_product_attention(attention, attention_matrix, value, ok, output)
+  subroutine test_multihead_attention_scaled_dot_product_attention(attention, value, ok)
     type(multihead_attention_layer), intent(in) :: attention
-    real, intent(in) :: attention_matrix(:, :, :, :)
     real, intent(in) :: value(:, :, :, :)
     logical, intent(in out) :: ok
-    real, intent(out) :: output(&
-        attention % n_heads, attention % sequence_length, attention % head_size, attention % batch_size&
-    )
     real :: output_flat(12)
     real :: expected_output_flat(12) = [&
       0.101414114, 0.685291648, 0.102356531, 0.701290607, 0.103298485, 0.701582491,&
       0.401414126, 0.457309216, 0.402356505, 0.374400526, 0.403298497, 0.373518765&
     ]

-    output = attention % scaled_dot_product_attention(attention_matrix, value)
+    call attention % scaled_dot_product_attention(value)

-    output_flat = reshape(output, shape(output_flat))
+    output_flat = reshape(attention % sdpa, shape(output_flat))
     if (.not. all(output_flat.eq.expected_output_flat)) then
       ok = .false.
       write(stderr, '(a)') 'scaled_dot_product_attention returned incorrect values.. failed'

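One thing the refactor keeps unchanged: every check compares reals with exact equality (all(output_flat.eq.expected_output_flat)), which passes only while the computation stays bit-for-bit reproducible. If that ever becomes flaky, a tolerance-based comparison is the usual hardening step; a sketch, reusing the variables from the last test above, where the tol value is an assumption and not part of this commit:

  real, parameter :: tol = 1e-6  ! assumed tolerance, not from this commit
  output_flat = reshape(attention % sdpa, shape(output_flat))
  if (.not. all(abs(output_flat - expected_output_flat) < tol)) then
    ok = .false.
    write(stderr, '(a)') 'scaled_dot_product_attention returned incorrect values.. failed'
  end if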