@@ -179,6 +179,7 @@ module subroutine backward(self, input, gradient)
  end subroutine backward

  module subroutine forward(self, query, key, value)
+    !! General forward prop for the MultiHead Attention mechanism
    class(multihead_attention_layer), intent(in out) :: self
    real, intent(in) :: query(:, :, :), key(:, :, :), value(:, :, :)

@@ -192,16 +193,21 @@ module subroutine forward(self, query, key, value)
    self % k_input = key
    self % v_input = value

+    ! run inputs through linear layers (trainable params)
    call self % query_layer % forward(query)
    call self % key_layer % forward(key)
    call self % value_layer % forward(value)

+    ! split the projected inputs into separate attention heads
    q = self % split_heads(self % query_layer % output)
    k = self % split_heads(self % key_layer % output)
    v = self % split_heads(self % value_layer % output)

+    ! create the attention matrix: query times transposed key
    call self % create_attention_matrix(q, k)
+    ! apply softmax and scaling
    call self % normalize_attention_matrix()
+    ! multiply attention matrix by value
    call self % scaled_dot_product_attention(v)

    call self % output_layer % forward(self % combine_heads(self % sdpa))
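
The three commented steps above (attention matrix, softmax with scaling, weighting by value) amount to softmax(Q K^T / sqrt(d_k)) V computed per head, after which combine_heads and output_layer merge the heads back into the layer output. The standalone Fortran program below is a minimal sketch of that single-head computation; the program name, array names, and sizes are illustrative assumptions and not part of this layer's API.

program single_head_attention_sketch
  ! hypothetical example: scaled dot-product attention for one head
  implicit none
  integer, parameter :: seq_len = 4, d_k = 8
  real :: q(seq_len, d_k), k(seq_len, d_k), v(seq_len, d_k)
  real :: attn(seq_len, seq_len), attn_output(seq_len, d_k)
  integer :: i

  call random_number(q)
  call random_number(k)
  call random_number(v)

  ! attention matrix: query times transposed key
  attn = matmul(q, transpose(k))

  ! scale by 1 / sqrt(d_k) and apply a row-wise softmax
  attn = attn / sqrt(real(d_k))
  do i = 1, seq_len
    attn(i, :) = exp(attn(i, :) - maxval(attn(i, :)))
    attn(i, :) = attn(i, :) / sum(attn(i, :))
  end do

  ! multiply attention matrix by value
  attn_output = matmul(attn, v)

  print *, 'attention output shape:', shape(attn_output)
end program single_head_attention_sketch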