
Commit 9c70efb (1 parent: 5a66192)

multihead_attention: add comments for forward prop

1 file changed: +6 -0 lines changed


src/nf/nf_multihead_attention.f90

Lines changed: 6 additions & 0 deletions
@@ -179,6 +179,7 @@ module subroutine backward(self, input, gradient)
   end subroutine backward

   module subroutine forward(self, query, key, value)
+    !! General forward prop for MultiHead Attention Mechanism
     class(multihead_attention_layer), intent(in out) :: self
     real, intent(in) :: query(:, :, :), key(:, :, :), value(:, :, :)

@@ -192,16 +193,21 @@ module subroutine forward(self, query, key, value)
     self % k_input = key
     self % v_input = value

+    ! run inputs through linear layers (trainable params)
     call self % query_layer % forward(query)
     call self % key_layer % forward(key)
     call self % value_layer % forward(value)

+    ! split attention heads for more efficient computation
     q = self % split_heads(self % query_layer % output)
     k = self % split_heads(self % key_layer % output)
     v = self % split_heads(self % value_layer % output)

+    ! create attention matrix (query times transposed key)
     call self % create_attention_matrix(q, k)
+    ! apply scaling and softmax
     call self % normalize_attention_matrix()
+    ! multiply attention matrix by value
     call self % scaled_dot_product_attention(v)

     call self % output_layer % forward(self % combine_heads(self % sdpa))
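
For context on the new comments: the three commented calls correspond to scaled dot-product attention, softmax(Q K^T / sqrt(d_k)) V, applied per attention head. Below is a minimal single-head sketch of those steps in plain Fortran; the program name, array names, and the dimensions seq_len and d_k are made up for illustration and do not correspond to the layer's actual fields or routines.

program sdpa_sketch
  implicit none
  ! hypothetical dimensions for illustration only
  integer, parameter :: seq_len = 4, d_k = 8
  real :: q(seq_len, d_k), k(seq_len, d_k), v(seq_len, d_k)
  real :: scores(seq_len, seq_len), attn(seq_len, seq_len), out(seq_len, d_k)
  integer :: i

  call random_number(q)
  call random_number(k)
  call random_number(v)

  ! create attention matrix: query times transposed key
  scores = matmul(q, transpose(k))

  ! apply scaling and a row-wise softmax
  scores = scores / sqrt(real(d_k))
  do i = 1, seq_len
    attn(i, :) = exp(scores(i, :) - maxval(scores(i, :)))
    attn(i, :) = attn(i, :) / sum(attn(i, :))
  end do

  ! multiply attention matrix by value
  out = matmul(attn, v)

  print *, 'output shape:', shape(out)
end program sdpa_sketch

Subtracting the row maximum before exponentiating is only for numerical stability; it does not change the softmax result.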
