@@ -113,17 +113,69 @@ module function multihead_attention_layer_cons(&
end function multihead_attention_layer_cons

module subroutine backward(self, input, gradient)
+ !! General backprop for the MultiHead Attention mechanism
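+ !! Computes deltas for the output projection, the scaled softmax
+ !! attention weights, and the query, key and value projections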
class(multihead_attention_layer), intent(in out) :: self
real, intent(in) :: input(:, :, :)
real, intent(in) :: gradient(:, :, :)

+ real :: d_output(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
+ real :: v_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
+ real :: k_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
+ real :: q_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
+ real :: d_sdpa(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)
+ real :: jacobian(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)
+ real :: d_normalize(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)
+ real :: d_attn_matrix(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
+ real :: dk(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
+ integer :: batch, head, i, j
+
+ ! calculate output layer delta
call self % output_layer % backward(input, gradient)
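+ ! self % output_layer % gradient now holds the delta w.r.t. the concatenated
+ ! attention output; it is split back into per-head slices below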

- ! FIXME: calculate gradient for softmax
+ ! split the output gradient and the value, key and query outputs into heads
+ d_output = self % split_heads(self % output_layer % gradient)
+ v_heads = self % split_heads(self % value_layer % output)
+ k_heads = self % split_heads(self % key_layer % output)
+ q_heads = self % split_heads(self % query_layer % output)
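+ ! each split tensor has shape (n_heads, sequence_length, head_size, batch_size),
+ ! the same per-head layout declared for d_output above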

- call self % value_layer % backward(self % v_input, self % output_layer % gradient)
- call self % key_layer % backward(self % k_input, self % output_layer % gradient)
- call self % query_layer % backward(self % q_input, self % output_layer % gradient)
+ ! iterate over heads to calculate deltas for each of them
+ do concurrent(batch = 1 : self % batch_size, head = 1 : self % n_heads)
+ ! calculate delta for value
+ d_sdpa(head, :, :, batch) = matmul(d_output(head, :, :, batch), transpose(v_heads(head, :, :, batch)))
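+ ! i.e. d_sdpa = d_output * V^T, one (sequence_length x sequence_length) matrix per head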
+
+ ! this monstrosity is the scaled derivative of softmax
+ do concurrent(i = 1 : self % sequence_length, j = 1 : self % sequence_length)
+ ! the jacobian matrix is used to calculate the derivative of softmax (temporary storage)
+ ! the idea behind this if-else is that for diagonal elements the jacobian temp
+ ! should be: `softmax(x) * (1 - softmax(x))`
+ ! and for off-diagonal elements: `-softmax^2(x)`
+ ! for computational efficiency (to avoid more temporary storage), scaling is also done here
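+ ! the sqrt(1 / head_size) factor below is the usual 1 / sqrt(d_k) scaling of
+ ! scaled dot-product attention, folded into the softmax derivative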
+ if (i == j) then
+ jacobian(head, i, j, batch) = &
+ self % attention_matrix(head, i, j, batch) &
+ * (1 - self % attention_matrix(head, i, j, batch)) &
+ * sqrt(1 / real(self % head_size))
+ else
+ jacobian(head, i, j, batch) = &
+ - self % attention_matrix(head, i, j, batch) &
+ * self % attention_matrix(head, i, j, batch) &
+ * sqrt(1 / real(self % head_size))
+ end if
+ end do
+ ! attention normalization delta, the last step of the softmax derivative:
+ ! multiply the incoming delta by the temporary jacobian matrix
+ d_normalize(head, :, :, batch) = matmul(d_sdpa(head, :, :, batch), jacobian(head, :, :, batch))
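+ ! d_normalize now carries the delta propagated back through the softmax normalization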
+
+ ! calculate delta for query
+ d_attn_matrix(head, :, :, batch) = matmul(d_normalize(head, :, :, batch), k_heads(head, :, :, batch))
+ ! calculate delta for key; here the normalization delta is transposed, unlike for query
+ dk(head, :, :, batch) = matmul(transpose(d_normalize(head, :, :, batch)), q_heads(head, :, :, batch))
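+ ! since the raw scores are Q * K^T, the chain rule gives
+ ! dQ = d_normalize * K and dK = d_normalize^T * Q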
+ end do
+
+ ! calculate deltas for input layers
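+ ! combine_heads merges the per-head deltas back into a single tensor
+ ! before the linear layers' backward passes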
+ call self % value_layer % backward(self % v_input, self % combine_heads(d_sdpa))
+ call self % key_layer % backward(self % k_input, self % combine_heads(dk))
+ call self % query_layer % backward(self % q_input, self % combine_heads(d_attn_matrix))
end subroutine backward

module subroutine forward(self, query, key, value)