Commit 3144673

multihead_attention: fix minor scaling issue
1 parent 4005a30 commit 3144673

File tree

1 file changed: +1 −1 lines changed


src/nf/nf_multihead_attention.f90

Lines changed: 1 addition & 1 deletion
@@ -209,7 +209,7 @@ module subroutine normalize_attention_matrix(self, attention_mask)
     allocate(output(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size))

     ! scale dowm by square root of each head's size
-    self % attention_matrix = self % attention_matrix / sqrt(real(self % head_size))
+    self % attention_matrix = self % attention_matrix * sqrt(1 / real(self % head_size))
     ! attention mask is used to mask out some of the tokens if necessary
     if (present(attention_mask)) then
       self % attention_matrix = self % attention_matrix + attention_mask
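For context, scaled dot-product attention divides the raw attention scores by the square root of the head size; the two expressions in this diff, `scores / sqrt(head_size)` and `scores * sqrt(1 / head_size)`, are mathematically equivalent. A minimal NumPy sketch (not from this repository; names are illustrative) of the scaling step:

```python
import numpy as np

def scale_attention_scores(scores, head_size):
    # Matches the formulation on the new (+) side of the diff:
    # multiply by sqrt(1 / head_size) instead of dividing by sqrt(head_size).
    return scores * np.sqrt(1.0 / head_size)

rng = np.random.default_rng(0)
scores = rng.standard_normal((4, 4))  # toy attention score matrix
head_size = 64

scaled_new = scale_attention_scores(scores, head_size)
scaled_old = scores / np.sqrt(head_size)  # the old (-) side of the diff

# The two scalings agree to floating-point tolerance.
print(np.allclose(scaled_new, scaled_old))  # True
```

The change is a refactoring of the same 1/sqrt(head_size) factor rather than a change in the scaling value itself.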

0 commit comments
