@@ -124,20 +124,23 @@ module subroutine backward(self, input, gradient)
     real, allocatable :: v_heads(:, :, :, :)
     real, allocatable :: k_heads(:, :, :, :)
     real, allocatable :: q_heads(:, :, :, :)
+    real, allocatable :: dv(:, :, :, :)
     real, allocatable :: d_sdpa(:, :, :, :)
-    real, allocatable :: jacobian(:, :, :, :)
+    real, allocatable :: jacobian(:, :, :)
     real, allocatable :: d_normalize(:, :, :, :)
     real, allocatable :: d_attn_matrix(:, :, :, :)
     real, allocatable :: dk(:, :, :, :)
-    integer :: batch, head, i, j
+    integer :: batch, head, seq, i, j
 
     ! allocate temporary storages for backward computation
     allocate(d_output(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
     allocate(v_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
     allocate(k_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
     allocate(q_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
+
+    allocate(dv(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
     allocate(d_sdpa(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size))
-    allocate(jacobian(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size))
+    allocate(jacobian(self % sequence_length, self % sequence_length, self % sequence_length))
     allocate(d_normalize(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size))
     allocate(d_attn_matrix(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
     allocate(dk(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
@@ -153,40 +156,49 @@ module subroutine backward(self, input, gradient)
 
     ! iterate over heads to calculate deltas for each of them
     do concurrent(batch = 1 : self % batch_size, head = 1 : self % n_heads)
-      ! calculate delta for value
+      dv(head, :, :, batch) = matmul(transpose(self % attention_matrix(head, :, :, batch)), d_output(head, :, :, batch))
+
+      ! calculate delta for attention matrix
       d_sdpa(head, :, :, batch) = matmul(d_output(head, :, :, batch), transpose(v_heads(head, :, :, batch)))
 
-      ! this monstrosity is scaled derivative of softmax
-      do concurrent(i = 1 : self % sequence_length, j = 1 : self % sequence_length)
+      ! this monstrosity below is the scaled derivative of softmax
+      do concurrent(seq = 1 : self % sequence_length, i = 1 : self % sequence_length, j = 1 : self % sequence_length)
         ! jacobian matrix is used to calculate derivative of softmax (temporary storage)
         ! the idea behind this if-else is that for diagonal elements, the jacobian temp
-        ! should be: `softmax(x) * (1 - softmax(x))`
-        ! for off-diagonal: `-softmax^2(x)`
+        ! should be: `softmax(x_i) * (1 - softmax(x_i))`
+        ! for off-diagonal: `-softmax(x_i) * softmax(x_j)`
         ! For computational efficiency (avoid more temp storages), scaling is also done here
         if (i == j) then
-          jacobian(head, i, j, batch) = &
-            self % attention_matrix(head, i, j, batch) &
-            * (1 - self % attention_matrix(head, i, j, batch)) &
+          jacobian(seq, i, j) = &
+            self % attention_matrix(head, seq, i, batch) &
+            * (1 - self % attention_matrix(head, seq, i, batch)) &
             * self % scaling_factor
         else
-          jacobian(head, i, j, batch) = &
-            - self % attention_matrix(head, i, j, batch) &
-            * self % attention_matrix(head, i, j, batch) &
+          jacobian(seq, i, j) = &
+            - self % attention_matrix(head, seq, i, batch) &
+            * self % attention_matrix(head, seq, j, batch) &
             * self % scaling_factor
         end if
       end do
+
       ! attention normalization delta, the last step of the softmax derivative:
       ! multiply the incoming gradient by the temp jacobian matrix
-      d_normalize(head, :, :, batch) = matmul(d_sdpa(head, :, :, batch), jacobian(head, :, :, batch))
+      do concurrent(seq = 1 : self % sequence_length)
+        d_normalize(head, seq, :, batch) = reshape(matmul( &
+          reshape(d_sdpa(head, seq, :, batch), [1, self % sequence_length]), &
+          jacobian(seq, :, :) &
+        ), [self % sequence_length])
+      end do
 
       ! calculate delta for query
       d_attn_matrix(head, :, :, batch) = matmul(d_normalize(head, :, :, batch), k_heads(head, :, :, batch))
+
       ! calculate delta for key, attention matrix should be transposed unlike for query
       dk(head, :, :, batch) = matmul(transpose(d_normalize(head, :, :, batch)), q_heads(head, :, :, batch))
     end do
 
     ! calculate deltas for input layers
-    call self % value_layer % backward(self % v_input, self % combine_heads(d_sdpa))
+    call self % value_layer % backward(self % v_input, self % combine_heads(dv))
     call self % key_layer % backward(self % k_input, self % combine_heads(dk))
     call self % query_layer % backward(self % q_input, self % combine_heads(d_attn_matrix))
 
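A note on the softmax step, since it is the only non-obvious part of this backward pass. Per attention row, the if-else builds the softmax Jacobian J(i, j) = a(i) * (delta_ij - a(j)) (times the scaling factor), and the row-by-row matmul then computes g . J for the incoming gradient g. That product collapses to the familiar closed form a(j) * (g(j) - sum(g * a)). The standalone program below, with made-up values and names that are not part of the commit, checks the two forms against each other:

program check_softmax_jvp
  implicit none
  integer, parameter :: n = 4
  real :: a(n), g(n), jacobian(n, n), via_jacobian(n), via_identity(n)
  integer :: i, j

  ! a plausible softmax output (positive, sums to 1) and an arbitrary upstream gradient
  a = [0.1, 0.2, 0.3, 0.4]
  g = [1.0, -2.0, 0.5, 3.0]

  ! explicit per-row Jacobian, same if-else as the diff (scaling factor omitted)
  do i = 1, n
    do j = 1, n
      if (i == j) then
        jacobian(i, j) = a(i) * (1 - a(i))
      else
        jacobian(i, j) = -a(i) * a(j)
      end if
    end do
  end do

  ! row vector times Jacobian, the product the diff takes per sequence position
  via_jacobian = reshape(matmul(reshape(g, [1, n]), jacobian), [n])

  ! closed form of the same product: a * (g - sum(g * a))
  via_identity = a * (g - sum(g * a))

  print *, 'max difference:', maxval(abs(via_jacobian - via_identity))
end program check_softmax_jvp

The two results agree to rounding error, and the closed form touches O(n) elements per row instead of O(n^2), which becomes relevant as sequence_length grows.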
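For reference, here is the whole per-head computation in one place, as a hypothetical single-head, single-batch sketch; the subroutine name and shapes are assumptions for illustration, not the library's API. With o = matmul(a, v) and a the softmax of the scaled scores, the four deltas are dv = a^T dO, da = dO v^T, ds = the scaled softmax backward of da, dq = ds k, and dk = ds^T q:

! hypothetical single-head sketch of the gradients computed in the diff
subroutine sdpa_backward(a, q, k, v, d_out, scale, dq, dk, dv)
  implicit none
  real, intent(in) :: a(:, :)                    ! attention matrix, seq x seq
  real, intent(in) :: q(:, :), k(:, :), v(:, :)  ! head slices, seq x head_size
  real, intent(in) :: d_out(:, :)                ! gradient w.r.t. the head output
  real, intent(in) :: scale                      ! the softmax scaling factor
  real, intent(out) :: dq(:, :), dk(:, :), dv(:, :)
  real, allocatable :: d_a(:, :), d_s(:, :)
  integer :: row, n

  n = size(a, 1)
  allocate(d_a(n, n), d_s(n, n))

  ! the head output is o = matmul(a, v), so:
  dv = matmul(transpose(a), d_out)   ! dL/dv = a^T dL/do
  d_a = matmul(d_out, transpose(v))  ! dL/da = dL/do v^T

  ! scaled softmax backward, one attention row at a time,
  ! using the closed form instead of an explicit Jacobian
  do row = 1, n
    d_s(row, :) = scale * a(row, :) * (d_a(row, :) - sum(d_a(row, :) * a(row, :)))
  end do

  ! the scores fed to softmax came from q and k, so:
  dq = matmul(d_s, k)             ! dL/dq = dL/ds k
  dk = matmul(transpose(d_s), q)  ! dL/dk = (dL/ds)^T q
end subroutine sdpa_backward

Because each iteration writes only its own row of d_s, nothing is shared between iterations; the commit achieves the same independence by indexing its jacobian temporary by seq. After the per-head loop, the diff folds dv, dk, and d_attn_matrix back through combine_heads into the value, key, and query projection layers.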