
Commit 4005a30

multihead_attention: proof of concept backward (works, but not mathematically correct)
1 parent abb02eb
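
As the commit message says, this backward pass is a proof of concept: it only propagates the gradient through the linear query/key/value/output layers and skips the attention itself. For reference, a mathematically complete backward pass for scaled dot-product attention would also need the following standard gradients (not part of this commit). Let $S = \frac{QK^\top}{\sqrt{d_k}}$, $A = \mathrm{softmax}(S)$ (row-wise), and $O = AV$. Given the upstream gradient $dO$:

$$
\begin{aligned}
dV &= A^\top \, dO, \\
dA &= dO \, V^\top, \\
dS &= A \odot \bigl( dA - \mathrm{rowsum}(dA \odot A) \bigr), \\
dQ &= \frac{dS \, K}{\sqrt{d_k}}, \qquad dK = \frac{dS^\top \, Q}{\sqrt{d_k}},
\end{aligned}
$$

where $\mathrm{rowsum}$ sums over the key dimension (broadcast back over each row) and $\odot$ is the element-wise product. These per-head gradients would then be combined across heads and fed into the backward passes of the query, key, and value linear layers.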

2 files changed: +49 −4 lines changed

src/nf/nf_multihead_attention.f90

Lines changed: 29 additions & 4 deletions
@@ -26,9 +26,12 @@ module nf_multihead_attention_layer
     real, allocatable :: sdpa(:, :, :, :)
     real, allocatable :: output(:, :, :)
 
+    real, allocatable :: q_input(:, :, :)
+    real, allocatable :: k_input(:, :, :)
+    real, allocatable :: v_input(:, :, :)
   contains
 
-    ! procedure :: backward
+    procedure :: backward
     procedure :: forward
     procedure :: split_heads
     procedure :: create_attention_matrix
@@ -49,15 +52,15 @@ end function multihead_attention_layer_cons
 
   interface
 
-    pure module subroutine backward(self, input, gradient)
+    module subroutine backward(self, input, gradient)
       !! Apply the backward gradient descent pass.
       !! Only weight and bias gradients are updated in this subroutine,
       !! while the weights and biases themselves are untouched.
       class(multihead_attention_layer), intent(in out) :: self
         !! Dense layer instance
-      real, intent(in) :: input(:)
+      real, intent(in) :: input(:, :, :)
         !! Input from the previous layer
-      real, intent(in) :: gradient(:)
+      real, intent(in) :: gradient(:, :, :)
         !! Gradient from the next layer
     end subroutine backward
 
@@ -109,6 +112,20 @@ module function multihead_attention_layer_cons(&
     res % softmax_func = softmax()
   end function multihead_attention_layer_cons
 
+  module subroutine backward(self, input, gradient)
+    class(multihead_attention_layer), intent(in out) :: self
+    real, intent(in) :: input(:, :, :)
+    real, intent(in) :: gradient(:, :, :)
+
+    call self % output_layer % backward(input, gradient)
+
+    ! FIXME: calculate gradient for softmax
+
+    call self % value_layer % backward(self % v_input, self % output_layer % gradient)
+    call self % key_layer % backward(self % k_input, self % output_layer % gradient)
+    call self % query_layer % backward(self % q_input, self % output_layer % gradient)
+  end subroutine backward
+
   module subroutine forward(self, query, key, value)
     class(multihead_attention_layer), intent(in out) :: self
     real, intent(in) :: query(:, :, :), key(:, :, :), value(:, :, :)
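
Note on the FIXME above: the proof-of-concept backward feeds output_layer % gradient straight into the value, key, and query layers, skipping the gradient of the softmax and of the scaled dot-product. The snippet below is a minimal, self-contained sketch of the missing softmax piece, a Jacobian-vector product for a rank-1 softmax; it is illustrative only and does not use the nf softmax type or the layer's actual arrays.

program softmax_gradient_sketch
  ! Standalone sketch: gradient of softmax via its Jacobian-vector product.
  ! Not part of the commit; names and shapes are illustrative.
  implicit none
  real :: scores(4), upstream(4), s(4), d_scores(4)

  scores = [1.0, 2.0, 0.5, -1.0]
  upstream = [0.1, -0.2, 0.3, 0.05]

  ! Forward softmax, stabilized by subtracting the maximum
  s = exp(scores - maxval(scores))
  s = s / sum(s)

  ! Backward: d_scores_i = s_i * (upstream_i - sum_j upstream_j * s_j)
  d_scores = s * (upstream - sum(upstream * s))

  print *, 'softmax:  ', s
  print *, 'd_scores: ', d_scores
end program softmax_gradient_sketch

In a full attention backward, this operation would be applied row-wise to each head's attention matrix before propagating the result into the query and key projections.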
@@ -119,6 +136,10 @@ module subroutine forward(self, query, key, value)
     real :: attention_matrix(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)
     real :: dot_product_attention(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
 
+    self % q_input = query
+    self % k_input = key
+    self % v_input = value
+
     call self % query_layer % forward(query)
     call self % key_layer % forward(key)
     call self % value_layer % forward(value)
@@ -237,5 +258,9 @@ module subroutine init(self, input_shape)
       self % n_heads, self % sequence_length, self % head_size, self % batch_size&
     ))
     allocate(self % output(self % sequence_length, self % model_dimension, self % batch_size))
+
+    allocate(self % q_input(self % sequence_length, self % model_dimension, self % batch_size))
+    allocate(self % k_input(self % sequence_length, self % model_dimension, self % batch_size))
+    allocate(self % v_input(self % sequence_length, self % model_dimension, self % batch_size))
   end subroutine init
 end module nf_multihead_attention_layer

test/test_multihead_attention_layer.f90

Lines changed: 20 additions & 0 deletions
@@ -19,6 +19,7 @@ program test_multihead_attention_layer
   call test_multihead_attention_combine_heads(attention, attention % sdpa, ok)
   call test_multihead_attention_forward(attention, ok)
   call test_multihead_attention_forward_reallife_shape(ok)
+  call test_multihead_attention_backward(attention, ok)
 
 contains
   subroutine test_multihead_attention_split_heads(attention, input, ok, output)
@@ -183,4 +184,23 @@ subroutine test_multihead_attention_forward_reallife_shape(ok)
       write(stderr, '(a)') 'forward returned incorrect shape.. failed'
     end if
   end subroutine test_multihead_attention_forward_reallife_shape
+
+  subroutine test_multihead_attention_backward(attention, ok)
+    type(multihead_attention_layer), intent(in out) :: attention
+    logical, intent(in out) :: ok
+    real :: input(3, 4, 1) = reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4, 1])
+    real :: gradient(3, 4, 1) = reshape(&
+        [.1, .1, .1, 3., 3., 3., 2., .1, 2., 3., .1, 3., 2., 2., .1, 3., 3., 3.], [3, 4, 1]&
+    )
+    real :: expected_shape(3) = [3, 4, 1]
+    real :: output_shape(3)
+
+    call attention % backward(input, gradient)
+
+    output_shape = shape(attention % output_layer % gradient)
+    if (.not. all(output_shape.eq.expected_shape)) then
+      ok = .false.
+      write(stderr, '(a)') 'backward returned incorrect shape.. failed'
+    end if
+  end subroutine test_multihead_attention_backward
 end program test_multihead_attention_layer
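
The new test only checks that output_layer % gradient comes back with the expected shape. Once the backward pass is mathematically complete, a finite-difference check is a common way to also verify the gradient values. The sketch below shows the idea on a standalone scalar function, f(x) = sum(x**2); wiring it up to the attention layer would depend on its constructor and parameter accessors, so everything here is illustrative rather than part of the test suite.

program finite_difference_sketch
  ! Standalone sketch of a central-difference gradient check.
  ! Illustrative only; not tied to the multihead_attention_layer API.
  implicit none
  real, parameter :: eps = 1.0e-2
  real :: x(3), analytic(3), numeric(3), xp(3), xm(3)
  integer :: i

  x = [0.5, -1.0, 2.0]
  analytic = 2.0 * x   ! analytic gradient of f(x) = sum(x**2)

  do i = 1, size(x)
    xp = x; xp(i) = xp(i) + eps
    xm = x; xm(i) = xm(i) - eps
    numeric(i) = (f(xp) - f(xm)) / (2.0 * eps)
  end do

  print *, 'max abs error between analytic and numeric gradient:', &
    maxval(abs(numeric - analytic))

contains

  pure real function f(v)
    real, intent(in) :: v(:)
    f = sum(v**2)
  end function f

end program finite_difference_sketch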
