
Commit 8a526db

multihead_attention: remove reshape crutches
1 parent e798a95 commit 8a526db

File tree

1 file changed: +20 additions, -40 deletions


src/nf/nf_multihead_attention.f90

Lines changed: 20 additions & 40 deletions
@@ -102,10 +102,10 @@ module function multihead_attention_layer_cons(sequence_length, model_dimension,
     end if
     res % head_size = model_dimension / n_heads

-    res % query_layer = linear2d_layer(sequence_length, model_dimension, model_dimension, 1)
-    res % key_layer = linear2d_layer(sequence_length, model_dimension, model_dimension, 1)
-    res % value_layer = linear2d_layer(sequence_length, model_dimension, model_dimension, 1)
-    res % output_layer = linear2d_layer(sequence_length, model_dimension, model_dimension, 1)
+    res % query_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
+    res % key_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
+    res % value_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
+    res % output_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
     call res % query_layer % init([0])
     call res % key_layer % init([0])
     call res % value_layer % init([0])
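For orientation, here is a minimal construction sketch under the new signature. Only the constructor arguments and the init([0]) call are taken from the hunk above; the surrounding declarations (seq_len, d_model, proj) are hypothetical names for illustration.

    ! Hypothetical fragment, not part of the commit: build a single
    ! rank-2 projection the same way the constructor above now does.
    integer, parameter :: seq_len = 3, d_model = 4
    type(linear2d_layer) :: proj

    proj = linear2d_layer(seq_len, d_model, d_model)  ! no trailing batch argument
    call proj % init([0])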
@@ -145,20 +145,13 @@ module subroutine backward(self, input, gradient)
     allocate(dk(self % sequence_length, self % head_size, self % n_heads))

     ! calculate output layer delta
-    ! FIXME: remove reshapes when linear2d situation is resolved
-    call self % output_layer % backward(&
-        reshape(self % o_input, [self % sequence_length, self % model_dimension, 1]),&
-        reshape(gradient, [self % sequence_length, self % model_dimension, 1])&
-    )
+    call self % output_layer % backward(self % o_input, gradient)

     ! split heads from output gradient
-    ! FIXME: remove reshapes when linear2d situation is resolved
-    d_output = self % split_heads(&
-        reshape(self % output_layer % gradient, [self % sequence_length, self % model_dimension]))
-    v_heads = self % split_heads(&
-        reshape(self % value_layer % output, [self % sequence_length, self % model_dimension]))
-    k_heads = self % split_heads(reshape(self % key_layer % output, [self % sequence_length, self % model_dimension]))
-    q_heads = self % split_heads(reshape(self % query_layer % output, [self % sequence_length, self % model_dimension]))
+    d_output = self % split_heads(self % output_layer % gradient)
+    v_heads = self % split_heads(self % value_layer % output)
+    k_heads = self % split_heads(self % key_layer % output)
+    q_heads = self % split_heads(self % query_layer % output)

     ! iterate over heads to calculate deltas for each of them
     do concurrent(head = 1: self % n_heads)
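The split_heads calls above now consume the rank-2 (sequence_length, model_dimension) outputs directly. A shape-level sketch of what that split amounts to, assuming the (sequence_length, head_size, n_heads) layout used by the dk/dv allocations; the actual split_heads implementation lives elsewhere in this file.

    ! Shape sketch only: regroup a (sequence_length, model_dimension)
    ! array into per-head slices, with head_size = model_dimension / n_heads.
    q_heads = reshape(self % query_layer % output, &
        [self % sequence_length, self % head_size, self % n_heads])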
@@ -203,19 +196,9 @@ module subroutine backward(self, input, gradient)
     end do

     ! calculate deltas for input layers
-    ! FIXME: remove reshapes when linear2d situation is resolved
-    call self % value_layer % backward(&
-        reshape(self % v_input, [self % sequence_length, self % model_dimension, 1]),&
-        reshape(self % combine_heads(dv), [self % sequence_length, self % model_dimension, 1])&
-    )
-    call self % key_layer % backward(&
-        reshape(self % k_input, [self % sequence_length, self % model_dimension, 1]),&
-        reshape(self % combine_heads(dk), [self % sequence_length, self % model_dimension, 1])&
-    )
-    call self % query_layer % backward(&
-        reshape(self % q_input, [self % sequence_length, self % model_dimension, 1]),&
-        reshape(self % combine_heads(dq), [self % sequence_length, self % model_dimension, 1])&
-    )
+    call self % value_layer % backward(self % v_input, self % combine_heads(dv))
+    call self % key_layer % backward(self % k_input, self % combine_heads(dk))
+    call self % query_layer % backward(self % q_input, self % combine_heads(dq))

     ! free temporary storages
     deallocate(d_output)
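Correspondingly, combine_heads folds the per-head deltas back into a single (sequence_length, model_dimension) array before they reach the rank-2 backward calls above. A shape-level sketch, with dv_flat as a hypothetical temporary; the real combine_heads is defined elsewhere in this file.

    ! Shape sketch only: merge (sequence_length, head_size, n_heads)
    ! back into (sequence_length, model_dimension).
    dv_flat = reshape(dv, [self % sequence_length, self % model_dimension])
    call self % value_layer % backward(self % v_input, dv_flat)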
@@ -247,16 +230,14 @@ module subroutine forward(self, query, key, value)
     self % v_input = value

     ! run inputs through linear layers (trainable params)
-    ! FIXME: remove reshapes when linear2d situation is resolved
-    call self % query_layer % forward(reshape(query, [self % sequence_length, self % model_dimension, 1]))
-    call self % key_layer % forward(reshape(key, [self % sequence_length, self % model_dimension, 1]))
-    call self % value_layer % forward(reshape(value, [self % sequence_length, self % model_dimension, 1]))
+    call self % query_layer % forward(query)
+    call self % key_layer % forward(key)
+    call self % value_layer % forward(value)

     ! split attention heads for more efficient computation
-    ! FIXME: remove reshapes when linear2d situation is resolved
-    q = self % split_heads(reshape(self % query_layer % output, [self % sequence_length, self % model_dimension]))
-    k = self % split_heads(reshape(self % key_layer % output, [self % sequence_length, self % model_dimension]))
-    v = self % split_heads(reshape(self % value_layer % output, [self % sequence_length, self % model_dimension]))
+    q = self % split_heads(self % query_layer % output)
+    k = self % split_heads(self % key_layer % output)
+    v = self % split_heads(self % value_layer % output)

     ! create key by value matrix
     call self % create_attention_matrix(q, k)
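For reference, the create_attention_matrix and scaled_dot_product_attention calls that follow the split presumably implement the usual per-head scaled dot-product attention; in the notation of this routine, with per-head slices q_h, k_h, v_h and head size d_h = model_dimension / n_heads:

    attention(q_h, k_h, v_h) = softmax(q_h k_h^T / sqrt(d_h)) v_h

This formula is the standard definition and is not spelled out in the diff itself; the exact scaling and masking behaviour is whatever those two routines already do.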
@@ -265,10 +246,9 @@ module subroutine forward(self, query, key, value)
     ! multiply attention matrix by value
     call self % scaled_dot_product_attention(v)

-    ! FIXME: remove reshapes when linear2d situation is resolved
     self % o_input = self % combine_heads(self % sdpa)
-    call self % output_layer % forward(reshape(self % o_input, [self % sequence_length, self % model_dimension, 1]))
-    self % output = reshape(self % output_layer % output, [self % sequence_length, self % model_dimension])
+    call self % output_layer % forward(self % o_input)
+    self % output = self % output_layer % output

     ! free temp vars from memory
     deallocate(q)
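End to end, the forward pass is now rank-2 in and rank-2 out. A hypothetical caller-side sketch follows; attn, seq_len, and d_model are illustrative names, and only forward's argument list and the 2-D output come from this diff.

    ! Hypothetical usage fragment: query/key/value and the layer output
    ! are all plain (sequence_length, model_dimension) arrays now.
    real, allocatable :: query(:,:), key(:,:), value(:,:)

    allocate(query(seq_len, d_model), key(seq_len, d_model), value(seq_len, d_model))
    call random_number(query)
    call random_number(key)
    call random_number(value)

    call attn % forward(query, key, value)
    print *, shape(attn % output)   ! expected: seq_len, d_model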
