
Commit abb02eb

multihead_attention: concurrency
1 parent d8d7355 commit abb02eb

File tree

1 file changed: +8 additions, -22 deletions


src/nf/nf_multihead_attention.f90

Lines changed: 8 additions & 22 deletions
@@ -169,10 +169,8 @@ module subroutine create_attention_matrix(self, query, key)
     real :: key(:, :, :, :)
     integer :: i, j
     ! create attention matrix for each sequence in each batch
-    do i = 1, self % batch_size
-      do j = 1, self % n_heads
-        self % attention_matrix(j, :, :, i) = matmul(query(j, :, :, i), transpose(key(j, :, :, i)))
-      end do
+    do concurrent(i = 1: self % batch_size, j = 1: self % n_heads)
+      self % attention_matrix(j, :, :, i) = matmul(query(j, :, :, i), transpose(key(j, :, :, i)))
     end do
   end subroutine create_attention_matrix
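The rewrite is valid because each (i, j) iteration writes only its own slice of attention_matrix, so the iterations carry no dependencies on one another. A minimal standalone sketch of the same pattern follows; the names and sizes (a, b, c, d) are illustrative, not taken from the library:

! Minimal sketch: a nested pair of independent loops collapsed into one
! do concurrent construct, as in the hunk above.
program do_concurrent_sketch
  implicit none
  integer, parameter :: n_heads = 4, batch_size = 8, d = 3
  real :: a(n_heads, d, d, batch_size)
  real :: b(n_heads, d, d, batch_size), c(n_heads, d, d, batch_size)
  integer :: i, j

  call random_number(b)
  call random_number(c)

  ! Each (i, j) iteration writes only a(j, :, :, i), so the iterations are
  ! independent and the compiler may evaluate them in any order
  ! (or in parallel, depending on the compiler and flags).
  do concurrent(i = 1:batch_size, j = 1:n_heads)
    a(j, :, :, i) = matmul(b(j, :, :, i), transpose(c(j, :, :, i)))
  end do

  print *, 'checksum:', sum(a)
end program do_concurrent_sketch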

@@ -196,14 +194,8 @@ module subroutine normalize_attention_matrix(self, attention_mask)
       self % attention_matrix = self % attention_matrix + attention_mask
     end if
     ! softmax by last sequnce_length
-    do batch = 1, self % batch_size
-      do head = 1, self % n_heads
-        do seq = 1, self % sequence_length
-          output(head, seq, :, batch) = self % softmax_func % eval_1d(&
-              self % attention_matrix(head, seq, :, batch)&
-          )
-        end do
-      end do
+    do concurrent(batch = 1: self % batch_size, head = 1: self % n_heads, seq = 1: self % sequence_length)
+      output(head, seq, :, batch) = self % softmax_func % eval_1d(self % attention_matrix(head, seq, :, batch))
     end do
     self % attention_matrix = output
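One subtlety of this hunk: the Fortran standard requires any procedure referenced inside a do concurrent construct to be pure, so the call to eval_1d is legal only if that softmax implementation is pure (which this patch appears to assume for the library). A small sketch of the idea, using an illustrative pure softmax rather than the library's own:

! Sketch: a pure 1-d softmax referenced from a do concurrent body.
! The names softmax_1d, scores, probs are illustrative only.
module softmax_sketch
  implicit none
contains
  pure function softmax_1d(x) result(y)
    real, intent(in) :: x(:)
    real :: y(size(x))
    y = exp(x - maxval(x))   ! shift by the maximum for numerical stability
    y = y / sum(y)
  end function softmax_1d
end module softmax_sketch

program use_softmax_sketch
  use softmax_sketch
  implicit none
  real :: scores(4, 5), probs(4, 5)
  integer :: i

  call random_number(scores)

  ! Each row is normalized independently; the call is permitted inside
  ! do concurrent only because softmax_1d is declared pure.
  do concurrent(i = 1:4)
    probs(i, :) = softmax_1d(scores(i, :))
  end do

  print *, sum(probs)   ! ~4.0, since each of the 4 rows sums to 1
end program use_softmax_sketch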

@@ -217,10 +209,8 @@ module subroutine scaled_dot_product_attention(self, value)
     real :: value(:, :, :, :)
     integer :: batch, head

-    do batch = 1, self % batch_size
-      do head = 1, self % n_heads
-        self % sdpa(head, :, :, batch) = matmul(self % attention_matrix(head, :, :, batch), value(head, :, :, batch))
-      end do
+    do concurrent(batch = 1: self % batch_size, head = 1: self % n_heads)
+      self % sdpa(head, :, :, batch) = matmul(self % attention_matrix(head, :, :, batch), value(head, :, :, batch))
     end do
   end subroutine scaled_dot_product_attention

@@ -231,12 +221,8 @@ module function combine_heads(self, input) result(output)
     real :: output(self % sequence_length, self % model_dimension, self % batch_size)
     integer :: batch, seq

-    do batch = 1, self % batch_size
-      do seq = 1, self % sequence_length
-        output(seq, :, batch) = reshape(&
-            transpose(input(:, seq, :, batch)), [self % model_dimension]&
-        )
-      end do
+    do concurrent(batch = 1: self % batch_size, seq = 1: self % sequence_length)
+      output(seq, :, batch) = reshape(transpose(input(:, seq, :, batch)), [self % model_dimension])
     end do
   end function combine_heads
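All four hunks follow the same rule: do concurrent asserts that the iterations are order-independent, and whether they actually run in parallel is left to the compiler (for example, nvfortran's -stdpar flag can map do concurrent to multicore or GPU execution, while most compilers run it serially by default). For contrast, an illustrative counter-example with a loop-carried dependency, which must stay an ordinary do loop:

! Counter-example (not from this patch): iteration i reads the result of
! iteration i - 1, so rewriting this as do concurrent would be invalid.
program carried_dependency
  implicit none
  real :: x(10)
  integer :: i

  x = 1.0

  ! Fine as an ordinary do loop; the iterations are order-dependent.
  do i = 2, 10
    x(i) = x(i) + x(i - 1)
  end do

  print *, x(10)   ! running sum of ones: prints 10.0
end program carried_dependency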
