@@ -169,10 +169,8 @@ module subroutine create_attention_matrix(self, query, key)
169169 real :: key(:, :, :, :)
170170 integer :: i, j
171171 ! create attention matrix for each sequence in each batch
172- do i = 1 , self % batch_size
173- do j = 1 , self % n_heads
174- self % attention_matrix(j, :, :, i) = matmul (query(j, :, :, i), transpose (key(j, :, :, i)))
175- end do
172+ do concurrent(i = 1 : self % batch_size, j = 1 : self % n_heads)
173+ self % attention_matrix(j, :, :, i) = matmul (query(j, :, :, i), transpose (key(j, :, :, i)))
176174 end do
177175 end subroutine create_attention_matrix
178176
@@ -196,14 +194,8 @@ module subroutine normalize_attention_matrix(self, attention_mask)
196194 self % attention_matrix = self % attention_matrix + attention_mask
197195 end if
198196 ! softmax by last sequnce_length
199- do batch = 1 , self % batch_size
200- do head = 1 , self % n_heads
201- do seq = 1 , self % sequence_length
202- output(head, seq, :, batch) = self % softmax_func % eval_1d(&
203- self % attention_matrix(head, seq, :, batch)&
204- )
205- end do
206- end do
197+ do concurrent(batch = 1 : self % batch_size, head = 1 : self % n_heads, seq = 1 : self % sequence_length)
198+ output(head, seq, :, batch) = self % softmax_func % eval_1d(self % attention_matrix(head, seq, :, batch))
207199 end do
208200 self % attention_matrix = output
209201
@@ -217,10 +209,8 @@ module subroutine scaled_dot_product_attention(self, value)
217209 real :: value(:, :, :, :)
218210 integer :: batch, head
219211
220- do batch = 1 , self % batch_size
221- do head = 1 , self % n_heads
222- self % sdpa(head, :, :, batch) = matmul (self % attention_matrix(head, :, :, batch), value(head, :, :, batch))
223- end do
212+ do concurrent(batch = 1 : self % batch_size, head = 1 : self % n_heads)
213+ self % sdpa(head, :, :, batch) = matmul (self % attention_matrix(head, :, :, batch), value(head, :, :, batch))
224214 end do
225215 end subroutine scaled_dot_product_attention
226216
@@ -231,12 +221,8 @@ module function combine_heads(self, input) result(output)
231221 real :: output(self % sequence_length, self % model_dimension, self % batch_size)
232222 integer :: batch, seq
233223
234- do batch = 1 , self % batch_size
235- do seq = 1 , self % sequence_length
236- output(seq, :, batch) = reshape (&
237- transpose (input(:, seq, :, batch)), [self % model_dimension]&
238- )
239- end do
224+ do concurrent(batch = 1 : self % batch_size, seq = 1 : self % sequence_length)
225+ output(seq, :, batch) = reshape (transpose (input(:, seq, :, batch)), [self % model_dimension])
240226 end do
241227 end function combine_heads
242228
0 commit comments