
Commit 4d2fda2

multihead_attention: calculate scaling factor only once
1 parent 91531fc commit 4d2fda2
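
For context, the factor being hoisted here is the standard scaled dot-product attention scaling (Vaswani et al., 2017), in which each head's raw scores are multiplied by 1/sqrt(d_k) before the softmax:

    \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left( \frac{Q K^\top}{\sqrt{d_k}} \right) V

In this code, sqrt(1 / real(head_size)) is exactly 1/sqrt(d_k) with d_k = head_size. Since head_size is fixed when the layer is constructed, the factor is invariant and can be computed once in init() rather than re-evaluated in the forward and backward passes.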

File tree

1 file changed: +7 −3 lines


src/nf/nf_multihead_attention.f90

Lines changed: 7 additions & 3 deletions
@@ -26,6 +26,8 @@ module nf_multihead_attention_layer
     real, allocatable :: sdpa(:, :, :, :)
     real, allocatable :: output(:, :, :)
 
+    real :: scaling_factor
+
     real, allocatable :: q_input(:, :, :)
     real, allocatable :: k_input(:, :, :)
     real, allocatable :: v_input(:, :, :)
@@ -154,12 +156,12 @@ module subroutine backward(self, input, gradient)
             jacobian(head, i, j, batch) = &
               self % attention_matrix(head, i, j, batch) &
               * (1 - self % attention_matrix(head, i, j, batch)) &
-              * sqrt(1 / real(self % head_size))
+              * self % scaling_factor
           else
             jacobian(head, i, j, batch) = &
               - self % attention_matrix(head, i, j, batch) &
               * self % attention_matrix(head, i, j, batch) &
-              * sqrt(1 / real(self % head_size))
+              * self % scaling_factor
           end if
         end do
         ! attention normalization delta, the last step of softmax derivative:
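
The two branches above match the standard softmax Jacobian: for a softmax output s with pre-softmax input x,

    \frac{\partial s_i}{\partial x_j} = s_i (\delta_{ij} - s_j)
      = \begin{cases} s_i (1 - s_i) & i = j \\ -\, s_i \, s_j & i \neq j \end{cases}

The additional multiplication by the scaling factor is the chain-rule contribution of the constant 1/sqrt(d_k) applied to the scores before the softmax, so precomputing it in init() is numerically identical to the old per-element sqrt(1 / real(self % head_size)).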
@@ -267,7 +269,7 @@ module subroutine normalize_attention_matrix(self, attention_mask)
     allocate(output(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size))
 
     ! scale down by square root of each head's size
-    self % attention_matrix = self % attention_matrix * sqrt(1 / real(self % head_size))
+    self % attention_matrix = self % attention_matrix * self % scaling_factor
     ! attention mask is used to mask out some of the tokens if necessary
     if (present(attention_mask)) then
       self % attention_matrix = self % attention_matrix + attention_mask
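
The mask is additive and applied before the softmax: by convention, masked positions carry large negative values so they normalize to (near) zero probability. A minimal sketch of the pattern with made-up values, not taken from the repository:

    program additive_mask_demo
      implicit none
      real :: scores(3), mask(3)

      ! Raw attention scores for one query over three key positions.
      scores = [0.5, 1.0, 1.5]

      ! Additive mask: 0.0 keeps a position, a large negative value hides it.
      mask = [0.0, 0.0, -1e9]

      ! Same pattern as normalize_attention_matrix above.
      scores = scores + mask
      print *, scores
    end program additive_mask_demo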
@@ -317,6 +319,8 @@ module subroutine init(self, input_shape)
       ))
     allocate(self % output(self % sequence_length, self % model_dimension, self % batch_size))
 
+    self % scaling_factor = sqrt(1 / real(self % head_size))
+
     allocate(self % q_input(self % sequence_length, self % model_dimension, self % batch_size))
     allocate(self % k_input(self % sequence_length, self % model_dimension, self % batch_size))
     allocate(self % v_input(self % sequence_length, self % model_dimension, self % batch_size))
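
Taken together, the commit is a classic hoisting of a loop-invariant computation: the factor depends only on head_size, which is set once at initialization. A self-contained sketch of the idea (the program and its names are illustrative, not part of the repository):

    program scaling_factor_demo
      implicit none
      integer, parameter :: head_size = 64
      real :: scaling_factor
      real :: scores(4, 4)

      ! Computed once, as init() now does for the attention layer.
      scaling_factor = sqrt(1 / real(head_size))

      call random_number(scores)

      ! Reused wherever sqrt(1 / real(head_size)) used to be re-evaluated
      ! on every forward and backward pass.
      scores = scores * scaling_factor

      print *, 'scaling factor:', scaling_factor   ! prints 0.125 for head_size = 64
    end program scaling_factor_demo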
