Commit 5a66192

multihead_attention: complete backward implementation
1 parent 3144673 commit 5a66192

File tree

1 file changed: +56 −4 lines changed


src/nf/nf_multihead_attention.f90

Lines changed: 56 additions & 4 deletions
@@ -113,17 +113,69 @@ module function multihead_attention_layer_cons(&
   end function multihead_attention_layer_cons
 
   module subroutine backward(self, input, gradient)
+    !! General backprop for MultiHead Attention mechanism
     class(multihead_attention_layer), intent(in out) :: self
     real, intent(in) :: input(:, :, :)
     real, intent(in) :: gradient(:, :, :)
 
+    real :: d_output(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
+    real :: v_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
+    real :: k_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
+    real :: q_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
+    real :: d_sdpa(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)
+    real :: jacobian(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)
+    real :: d_normalize(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)
+    real :: d_attn_matrix(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
+    real :: dk(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
+    integer :: batch, head, i, j
+
+    ! calculate output layer delta
     call self % output_layer % backward(input, gradient)
 
-    ! FIXME: calculate gradient for softmax
+    ! split heads from output gradient
+    d_output = self % split_heads(self % output_layer % gradient)
+    v_heads = self % split_heads(self % value_layer % output)
+    k_heads = self % split_heads(self % key_layer % output)
+    q_heads = self % split_heads(self % query_layer % output)
 
-    call self % value_layer % backward(self % v_input, self % output_layer % gradient)
-    call self % key_layer % backward(self % k_input, self % output_layer % gradient)
-    call self % query_layer % backward(self % q_input, self % output_layer % gradient)
+    ! iterate over heads to calculate deltas for each of them
+    do concurrent(batch = 1: self % batch_size, head = 1: self % n_heads)
+      ! calculate delta for value
+      d_sdpa(head, :, :, batch) = matmul(d_output(head, :, :, batch), transpose(v_heads(head, :, :, batch)))
+
+      ! this monstrosity is scaled derivative of softmax
+      do concurrent(i = 1: self % sequence_length, j = 1: self % sequence_length)
+        ! jacobian matrix is used to calculate derivative of softmax (temporary storage)
+        ! the idea behind this if-else is that for diagonal elements, the jacobian temp
+        ! should be: `softmax(x) * (1 - softmax(x))`
+        ! for off-diagonal: `-softmax^2(x)`
+        ! For computational efficiency (avoid more temp storages), scaling is also done here
+        if (i == j) then
+          jacobian(head, i, j, batch) = &
+              self % attention_matrix(head, i, j, batch) &
+              * (1 - self % attention_matrix(head, i, j, batch)) &
+              * sqrt(1 / real(self % head_size))
+        else
+          jacobian(head, i, j, batch) = &
+              - self % attention_matrix(head, i, j, batch) &
+              * self % attention_matrix(head, i, j, batch) &
+              * sqrt(1 / real(self % head_size))
+        end if
+      end do
+      ! attention normalization delta, the last step of softmax derivative:
+      ! multiply temp jacobian matrix by the output of softmax
+      d_normalize(head, :, :, batch) = matmul(d_sdpa(head, :, :, batch), jacobian(head, :, :, batch))
+
+      ! calculate delta for query
+      d_attn_matrix(head, :, :, batch) = matmul(d_normalize(head, :, :, batch), k_heads(head, :, :, batch))
+      ! calculate delta for key, attention matrix should be transposed unlike for query
+      dk(head, :, :, batch) = matmul(transpose(d_normalize(head, :, :, batch)), q_heads(head, :, :, batch))
+    end do
+
+    ! calculate deltas for input layers
+    call self % value_layer % backward(self % v_input, self % combine_heads(d_sdpa))
+    call self % key_layer % backward(self % k_input, self % combine_heads(dk))
+    call self % query_layer % backward(self % q_input, self % combine_heads(d_attn_matrix))
   end subroutine backward
 
   module subroutine forward(self, query, key, value)
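
For reference (editorial note, not part of the commit): below is a sketch of the textbook gradients for scaled dot-product attention that the comments in the inner loop allude to. The notation is assumed rather than taken from the source: per head, Q, K, V are the query/key/value activations, d_head is the head size, S = Q Kᵀ / sqrt(d_head), A = softmax(S) applied row-wise (the attention_matrix), and O = A V is the head output.

% Gradients of the head output O = A V:
\frac{\partial L}{\partial V} = A^{\top} \frac{\partial L}{\partial O},
\qquad
\frac{\partial L}{\partial A} = \frac{\partial L}{\partial O} \, V^{\top}

% Softmax Jacobian for one row a of A (the diagonal / off-diagonal split in the if-else):
\frac{\partial a_i}{\partial s_j} =
  \begin{cases}
    a_i \, (1 - a_i) & i = j \\
    -\, a_i \, a_j   & i \neq j
  \end{cases}

% Chain rule through S = Q K^{\top} / \sqrt{d_\mathrm{head}}:
\frac{\partial L}{\partial Q} = \frac{1}{\sqrt{d_\mathrm{head}}} \, \frac{\partial L}{\partial S} \, K,
\qquad
\frac{\partial L}{\partial K} = \frac{1}{\sqrt{d_\mathrm{head}}} \left( \frac{\partial L}{\partial S} \right)^{\top} Q

In the diff, the sqrt(1 / real(self % head_size)) factor is folded into the Jacobian entries so no separate scaling step is needed afterwards, and the matmul with k_heads and the transposed matmul with q_heads play the role of the query and key deltas above.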
