
Commit a7ed20a

multihead_attention: fix issues with shapes (softmax prime became even more monstrous)
1 parent c6a89de commit a7ed20a

File tree

1 file changed: +28 -16 lines changed

src/nf/nf_multihead_attention.f90

Lines changed: 28 additions & 16 deletions
@@ -124,20 +124,23 @@ module subroutine backward(self, input, gradient)
     real, allocatable :: v_heads(:, :, :, :)
     real, allocatable :: k_heads(:, :, :, :)
     real, allocatable :: q_heads(:, :, :, :)
+    real, allocatable :: dv(:, :, :, :)
     real, allocatable :: d_sdpa(:, :, :, :)
-    real, allocatable :: jacobian(:, :, :, :)
+    real, allocatable :: jacobian(:, :, :)
     real, allocatable :: d_normalize(:, :, :, :)
     real, allocatable :: d_attn_matrix(:, :, :, :)
     real, allocatable :: dk(:, :, :, :)
-    integer :: batch, head, i, j
+    integer :: batch, head, seq, i, j

     ! allocate temporary storages for backward computation
     allocate(d_output(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
     allocate(v_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
     allocate(k_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
     allocate(q_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
+
+    allocate(dv(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
     allocate(d_sdpa(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size))
-    allocate(jacobian(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size))
+    allocate(jacobian(self % sequence_length, self % sequence_length, self % sequence_length))
     allocate(d_normalize(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size))
     allocate(d_attn_matrix(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
     allocate(dk(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
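
Aside on the shape fix above: `jacobian` is no longer kept per head and per batch; it is a scratch array refilled inside the head/batch loop, with `jacobian(seq, :, :)` holding the scaled softmax Jacobian of attention row `seq`. Writing a_i as shorthand for `self % attention_matrix(head, seq, i, batch)` and s for `self % scaling_factor` (shorthands only, not names from the code), the stored entries are the standard softmax Jacobian scaled by s:

J_seq(i, j) = s \, a_i (\delta_{ij} - a_j) =
  \begin{cases}
    s \, a_i (1 - a_i), & i = j \\
    -\, s \, a_i a_j,   & i \ne j
  \end{cases}

which matches the if-else in the second hunk below.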
@@ -153,40 +156,49 @@ module subroutine backward(self, input, gradient)
 
     ! iterate over heads to calculate deltas for each of them
     do concurrent(batch = 1: self % batch_size, head = 1: self % n_heads)
-      ! calculate delta for value
+      dv(head, :, :, batch) = matmul(transpose(self % attention_matrix(head, :, :, batch)), d_output(head, :, :, batch))
+
+      ! calculate delta for attention matrix
       d_sdpa(head, :, :, batch) = matmul(d_output(head, :, :, batch), transpose(v_heads(head, :, :, batch)))

-      ! this monstrosity is scaled derivative of softmax
-      do concurrent(i = 1: self % sequence_length, j = 1: self % sequence_length)
+      ! this monstrosity below is scaled derivative of softmax
+      do concurrent(seq = 1: self % sequence_length, i = 1: self % sequence_length, j = 1: self % sequence_length)
         ! jacobian matrix is used to calculate derivative of softmax (temporary storage)
         ! the idea behind this if-else is that for diagonal elements, the jacobian temp
-        ! should be: `softmax(x) * (1 - softmax(x))`
-        ! for off-diagonal: `-softmax^2(x)`
+        ! should be: `softmax(x_i) * (1 - softmax(x_i))`
+        ! for off-diagonal: `-softmax(x_i) * softmax(x_j)`
         ! For computational efficiency (avoid more temp storages), scaling is also done here
         if (i == j) then
-          jacobian(head, i, j, batch) = &
-            self % attention_matrix(head, i, j, batch) &
-            * (1 - self % attention_matrix(head, i, j, batch)) &
+          jacobian(seq, i, j) = &
+            self % attention_matrix(head, seq, i, batch) &
+            * (1 - self % attention_matrix(head, seq, i, batch)) &
             * self % scaling_factor
         else
-          jacobian(head, i, j, batch) = &
-            - self % attention_matrix(head, i, j, batch) &
-            * self % attention_matrix(head, i, j, batch) &
+          jacobian(seq, i, j) = &
+            - self % attention_matrix(head, seq, i, batch) &
+            * self % attention_matrix(head, seq, j, batch) &
             * self % scaling_factor
         end if
       end do
+
       ! attention normalization delta, the last step of softmax derivative:
       ! multiply temp jacobian matrix by the output of softmax
-      d_normalize(head, :, :, batch) = matmul(d_sdpa(head, :, :, batch), jacobian(head, :, :, batch))
+      do concurrent(seq = 1: self % sequence_length)
+        d_normalize(head, seq, :, batch) = reshape(matmul(&
+          reshape(d_sdpa(head, seq, :, batch), [1, self % sequence_length]),&
+          jacobian(seq, :, :)&
+        ), [self % sequence_length])
+      end do

       ! calculate delta for query
       d_attn_matrix(head, :, :, batch) = matmul(d_normalize(head, :, :, batch), k_heads(head, :, :, batch))
+
       ! calculate delta for key, attention matrix should be transposed unlike for query
       dk(head, :, :, batch) = matmul(transpose(d_normalize(head, :, :, batch)), q_heads(head, :, :, batch))
     end do

     ! calculate deltas for input layers
-    call self % value_layer % backward(self % v_input, self % combine_heads(d_sdpa))
+    call self % value_layer % backward(self % v_input, self % combine_heads(dv))
     call self % key_layer % backward(self % k_input, self % combine_heads(dk))
     call self % query_layer % backward(self % q_input, self % combine_heads(d_attn_matrix))
 
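For a quick sanity check of the row-wise Jacobian and the reshape/matmul used for `d_normalize`, here is a minimal standalone sketch (not part of neural-fortran; the program name and the variables `a`, `g`, `scale`, `n` are made up for illustration). It fills one scaled Jacobian with the same diagonal/off-diagonal split as above, applies it to a gradient row the way each `seq` row is handled, and compares against the closed form `scale * a * (g - dot_product(g, a))`:

program softmax_prime_check
  ! Standalone sketch: check the scaled softmax Jacobian construction
  ! against the closed-form backward pass for a single attention row.
  implicit none
  integer, parameter :: n = 5
  real :: s(n), a(n), g(n), jac(n, n), via_jacobian(n), closed_form(n)
  real :: scale
  integer :: i, j

  call random_number(s)
  call random_number(g)
  scale = 1.0 / sqrt(real(n))

  ! softmax of one row of raw scores
  a = exp(s - maxval(s))
  a = a / sum(a)

  ! scaled Jacobian, same if-else as in the diff
  do concurrent(i = 1:n, j = 1:n)
    if (i == j) then
      jac(i, j) = a(i) * (1 - a(i)) * scale
    else
      jac(i, j) = - a(i) * a(j) * scale
    end if
  end do

  ! row vector times Jacobian, as done per `seq` for d_normalize
  via_jacobian = reshape(matmul(reshape(g, [1, n]), jac), [n])

  ! equivalent closed form: scale * a * (g - <g, a>)
  closed_form = scale * a * (g - dot_product(g, a))

  print *, 'max abs difference:', maxval(abs(via_jacobian - closed_form))
end program softmax_prime_check

Both results should agree to roundoff, which confirms the per-row matmul computes g^T J for each attention row.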