@@ -26,6 +26,8 @@ module nf_multihead_attention_layer
     real, allocatable :: sdpa(:, :, :, :)
     real, allocatable :: output(:, :, :)
 
+    real :: scaling_factor
+
     real, allocatable :: q_input(:, :, :)
     real, allocatable :: k_input(:, :, :)
     real, allocatable :: v_input(:, :, :)
@@ -154,12 +156,12 @@ module subroutine backward(self, input, gradient)
                 jacobian(head, i, j, batch) = &
                     self % attention_matrix(head, i, j, batch) &
                     * (1 - self % attention_matrix(head, i, j, batch)) &
-                    * sqrt(1 / real(self % head_size))
+                    * self % scaling_factor
               else
                 jacobian(head, i, j, batch) = &
                     - self % attention_matrix(head, i, j, batch) &
                     * self % attention_matrix(head, i, j, batch) &
-                    * sqrt(1 / real(self % head_size))
+                    * self % scaling_factor
               end if
             end do
             ! attention normalization delta, the last step of softmax derivative:
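For reference, the two branches in the hunk above are the entries of the softmax Jacobian, each carrying the same 1/sqrt(head_size) factor that scales the raw attention scores (now stored once as self % scaling_factor). In generic index notation, independent of the array layout used by the layer, the softmax derivative is:

$$
\frac{\partial s_i}{\partial x_j} = s_i\,(\delta_{ij} - s_j) =
\begin{cases}
  s_i\,(1 - s_i) & \text{if } i = j,\\
  -\,s_i\,s_j    & \text{if } i \neq j.
\end{cases}
$$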
@@ -267,7 +269,7 @@ module subroutine normalize_attention_matrix(self, attention_mask)
     allocate(output(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size))
 
     ! scale down by the square root of each head's size
-    self % attention_matrix = self % attention_matrix * sqrt(1 / real(self % head_size))
+    self % attention_matrix = self % attention_matrix * self % scaling_factor
     ! attention mask is used to mask out some of the tokens if necessary
     if (present(attention_mask)) then
       self % attention_matrix = self % attention_matrix + attention_mask
@@ -317,6 +319,8 @@ module subroutine init(self, input_shape)
       ))
     allocate(self % output(self % sequence_length, self % model_dimension, self % batch_size))
 
+    self % scaling_factor = sqrt(1 / real(self % head_size))
+
     allocate(self % q_input(self % sequence_length, self % model_dimension, self % batch_size))
     allocate(self % k_input(self % sequence_length, self % model_dimension, self % batch_size))
     allocate(self % v_input(self % sequence_length, self % model_dimension, self % batch_size))
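Taken together, the change precomputes the scaling once in init and reuses it both when normalizing the attention matrix and inside the backward-pass Jacobian. A minimal standalone sketch of the identity being relied on, with head_size = 64 as an assumed value for illustration only (it is not taken from the diff):

program check_scaling_factor
  ! Sketch only: confirms that sqrt(1 / head_size) is the usual
  ! 1 / sqrt(d_k) factor of scaled dot-product attention.
  implicit none
  integer, parameter :: head_size = 64   ! assumed head size, for illustration
  real :: scaling_factor

  scaling_factor = sqrt(1.0 / real(head_size))
  print *, scaling_factor                ! prints 0.125
  print *, 1.0 / sqrt(real(head_size))   ! same value, 1 / sqrt(d_k)
end program check_scaling_factor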