@@ -18,8 +18,8 @@ program test_multihead_attention_layer
1818 call test_multihead_attention_scaled_dot_product_attention(attention, split_heads_output, ok)
1919 call test_multihead_attention_combine_heads(attention, attention % sdpa, ok)
2020 call test_multihead_attention_forward(attention, ok)
21- call test_multihead_attention_forward_reallife_shape(ok)
2221 call test_multihead_attention_backward(attention, ok)
22+ call test_multihead_attention_forward_reallife_shape(ok)
2323
2424contains
2525 subroutine test_multihead_attention_split_heads (attention , input , ok , output )
@@ -189,18 +189,31 @@ subroutine test_multihead_attention_backward(attention, ok)
189189 type (multihead_attention_layer), intent (in out ) :: attention
190190 logical , intent (in out ) :: ok
191191 real :: input(3 , 4 , 1 ) = reshape ([0.0 , 10.1 , 0.2 , 10.3 , 0.4 , 10.5 , 0.6 , 10.7 , 10.8 , 0.9 , 0.11 , 0.12 ], [3 , 4 , 1 ])
192- real :: gradient(3 , 4 , 1 ) = reshape (&
193- [.1 , .1 , .1 , 3 ., 3 ., 3 ., 2 ., .1 , 2 ., 3 ., .1 , 3 ., 2 ., 2 ., .1 , 3 ., 3 ., 3 .], [3 , 4 , 1 ]&
194- )
192+ real :: gradient(3 , 4 , 1 ) = reshape ([0.1 , 3 . , 2 . , 0.1 , 3 . , 3 . , 0.1 , 2 . , 0.1 , 3 . , 0.1 , 3 . ], [3 , 4 , 1 ])
193+ real :: expected_output_flat(12 ) = [&
194+ 0.489710003 , 0.240968466 , - 3.35404873E-02 , 0.489710003 ,&
195+ 0.240968466 , - 3.35404873E-02 , 0.489710003 , 0.240968466 ,&
196+ - 3.35404873E-02 , 0.489710003 , 0.240968466 , - 3.35404873E-02 &
197+ ]
195198 real :: expected_shape(3 ) = [3 , 4 , 1 ]
199+ real :: output(3 , 4 , 1 )
200+ real :: output_flat(12 )
196201 real :: output_shape(3 )
197202
198203 call attention % backward(input, gradient)
199204
200- output_shape = shape (attention % output_layer % gradient)
205+ ! sample for Self Attention: sum of output gradients
206+ output = attention % query_layer % gradient + attention % key_layer % gradient + attention % value_layer % gradient
207+
208+ output_shape = shape (output)
201209 if (.not. all (output_shape.eq. expected_shape)) then
202210 ok = .false.
203211 write (stderr, ' (a)' ) ' backward returned incorrect shape.. failed'
204212 end if
213+ output_flat = reshape (output, shape (output_flat))
214+ if (.not. all (output_flat.eq. expected_output_flat)) then
215+ ok = .false.
216+ write (stderr, ' (a)' ) ' backward returned incorrect values.. failed'
217+ end if
205218 end subroutine test_multihead_attention_backward
206219end program test_multihead_attention_layer
0 commit comments