Skip to content

Commit 91531fc

Browse files
committed
multihead_attention: adjust expected test values for updated scaling
1 parent c005e09 commit 91531fc

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

test/test_multihead_attention_layer.f90

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,9 @@ subroutine test_multihead_attention_normalization(attention, ok)
8181
logical, intent(in out) :: ok
8282
real :: output_flat(18)
8383
real :: expected_output_flat(18) = [&
84-
0.326287806, 0.435975075, 0.321620107, 0.330339342, 0.316976935, 0.329200655,&
85-
0.333283335, 0.275134116, 0.333194464, 0.326415271, 0.333061278, 0.325773478,&
86-
0.340428889, 0.288890868, 0.345185399, 0.343245387, 0.349961787, 0.345025837&
84+
0.326287806, 0.435975075, 0.321620107, 0.330339372, 0.316976935, 0.329200655,&
85+
0.333283335, 0.275134116, 0.333194494, 0.326415271, 0.333061278, 0.325773478,&
86+
0.340428889, 0.288890868, 0.345185429, 0.343245387, 0.349961787, 0.345025837&
8787
]
8888

8989
call attention % normalize_attention_matrix()
@@ -101,8 +101,8 @@ subroutine test_multihead_attention_scaled_dot_product_attention(attention, valu
101101
logical, intent(in out) :: ok
102102
real :: output_flat(12)
103103
real :: expected_output_flat(12) = [&
104-
0.101414114, 0.685291648, 0.102356531, 0.701290607, 0.103298485, 0.701582491,&
105-
0.401414126, 0.457309216, 0.402356505, 0.374400526, 0.403298497, 0.373518765&
104+
0.101414114, 0.685291648, 0.102356538, 0.701290667, 0.103298485, 0.701582491,&
105+
0.401414126, 0.457309216, 0.402356565, 0.374400556, 0.403298497, 0.373518765&
106106
]
107107

108108
call attention % scaled_dot_product_attention(value)
@@ -121,8 +121,8 @@ subroutine test_multihead_attention_combine_heads(attention, scaled_dp_att, ok)
121121
real :: output(attention % sequence_length, attention % model_dimension, attention % batch_size)
122122
real :: output_flat(12)
123123
real :: expected_output_flat(12) = [&
124-
0.101414114, 0.102356531, 0.103298485, 0.401414126, 0.402356505, 0.403298497,&
125-
0.685291648, 0.701290607, 0.701582491, 0.457309216, 0.374400526, 0.373518765&
124+
0.101414114, 0.102356538, 0.103298485, 0.401414126, 0.402356565, 0.403298497,&
125+
0.685291648, 0.701290667, 0.701582491, 0.457309216, 0.374400556, 0.373518765&
126126
]
127127

128128
output = attention % combine_heads(scaled_dp_att)

0 commit comments

Comments
 (0)