Commit c005e09 (1 parent: 9c70efb)

multihead_attention: add tests for backward

1 file changed: test/test_multihead_attention_layer.f90 (18 additions, 5 deletions)
@@ -18,8 +18,8 @@ program test_multihead_attention_layer
   call test_multihead_attention_scaled_dot_product_attention(attention, split_heads_output, ok)
   call test_multihead_attention_combine_heads(attention, attention % sdpa, ok)
   call test_multihead_attention_forward(attention, ok)
-  call test_multihead_attention_forward_reallife_shape(ok)
   call test_multihead_attention_backward(attention, ok)
+  call test_multihead_attention_forward_reallife_shape(ok)

 contains
   subroutine test_multihead_attention_split_heads(attention, input, ok, output)
@@ -189,18 +189,31 @@ subroutine test_multihead_attention_backward(attention, ok)
     type(multihead_attention_layer), intent(in out) :: attention
     logical, intent(in out) :: ok
     real :: input(3, 4, 1) = reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4, 1])
-    real :: gradient(3, 4, 1) = reshape(&
-        [.1, .1, .1, 3., 3., 3., 2., .1, 2., 3., .1, 3., 2., 2., .1, 3., 3., 3.], [3, 4, 1]&
-    )
+    real :: gradient(3, 4, 1) = reshape([0.1, 3., 2., 0.1, 3., 3., 0.1, 2., 0.1, 3., 0.1, 3.], [3, 4, 1])
+    real :: expected_output_flat(12) = [&
+        0.489710003, 0.240968466, -3.35404873E-02, 0.489710003,&
+        0.240968466, -3.35404873E-02, 0.489710003, 0.240968466,&
+        -3.35404873E-02, 0.489710003, 0.240968466, -3.35404873E-02&
+    ]
     real :: expected_shape(3) = [3, 4, 1]
+    real :: output(3, 4, 1)
+    real :: output_flat(12)
     real :: output_shape(3)

     call attention % backward(input, gradient)

-    output_shape = shape(attention % output_layer % gradient)
+    ! sample for Self Attention: sum of output gradients
+    output = attention % query_layer % gradient + attention % key_layer % gradient + attention % value_layer % gradient
+
+    output_shape = shape(output)
     if (.not. all(output_shape.eq.expected_shape)) then
       ok = .false.
       write(stderr, '(a)') 'backward returned incorrect shape.. failed'
     end if
+    output_flat = reshape(output, shape(output_flat))
+    if (.not. all(output_flat.eq.expected_output_flat)) then
+      ok = .false.
+      write(stderr, '(a)') 'backward returned incorrect values.. failed'
+    end if
   end subroutine test_multihead_attention_backward
 end program test_multihead_attention_layer
