@@ -102,10 +102,10 @@ module function multihead_attention_layer_cons(sequence_length, model_dimension,
   end if
   res % head_size = model_dimension / n_heads

-  res % query_layer = linear2d_layer(sequence_length, model_dimension, model_dimension, 1)
-  res % key_layer = linear2d_layer(sequence_length, model_dimension, model_dimension, 1)
-  res % value_layer = linear2d_layer(sequence_length, model_dimension, model_dimension, 1)
-  res % output_layer = linear2d_layer(sequence_length, model_dimension, model_dimension, 1)
+  res % query_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
+  res % key_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
+  res % value_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
+  res % output_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
   call res % query_layer % init([0])
   call res % key_layer % init([0])
   call res % value_layer % init([0])
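
For context, a minimal sketch of the simplified constructor call, assuming the linear2d_layer(sequence_length, input_size, output_size) interface shown in this hunk; the module name nf_linear2d_layer is an assumption, not confirmed by the diff:

program construct_projection
  use nf_linear2d_layer, only: linear2d_layer  ! module name assumed
  implicit none
  type(linear2d_layer) :: projection

  ! no trailing batch-size argument anymore; the layer is purely 2-d
  projection = linear2d_layer(8, 64, 64)
  call projection % init([0])
end program construct_projection
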
@@ -145,20 +145,13 @@ module subroutine backward(self, input, gradient)
   allocate(dk(self % sequence_length, self % head_size, self % n_heads))

   ! calculate output layer delta
-  ! FIXME: remove reshapes when linear2d situation is resolved
-  call self % output_layer % backward(&
-      reshape(self % o_input, [self % sequence_length, self % model_dimension, 1]),&
-      reshape(gradient, [self % sequence_length, self % model_dimension, 1])&
-  )
+  call self % output_layer % backward(self % o_input, gradient)

   ! split heads from output gradient
-  ! FIXME: remove reshapes when linear2d situation is resolved
-  d_output = self % split_heads(&
-      reshape(self % output_layer % gradient, [self % sequence_length, self % model_dimension]))
-  v_heads = self % split_heads(&
-      reshape(self % value_layer % output, [self % sequence_length, self % model_dimension]))
-  k_heads = self % split_heads(reshape(self % key_layer % output, [self % sequence_length, self % model_dimension]))
-  q_heads = self % split_heads(reshape(self % query_layer % output, [self % sequence_length, self % model_dimension]))
+  d_output = self % split_heads(self % output_layer % gradient)
+  v_heads = self % split_heads(self % value_layer % output)
+  k_heads = self % split_heads(self % key_layer % output)
+  q_heads = self % split_heads(self % query_layer % output)

   ! iterate over heads to calculate deltas for each of them
   do concurrent(head = 1: self % n_heads)
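
The split_heads helper itself is not part of this diff; as a reference point, here is a standalone sketch of the rank-2 to rank-3 reshape it presumably performs, with illustrative names:

pure function split_heads_sketch(x, head_size, n_heads) result(res)
  ! (sequence_length, model_dimension) -> (sequence_length, head_size, n_heads)
  real, intent(in) :: x(:, :)
  integer, intent(in) :: head_size, n_heads
  real :: res(size(x, 1), head_size, n_heads)
  res = reshape(x, [size(x, 1), head_size, n_heads])
end function split_heads_sketch
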
@@ -203,19 +196,9 @@ module subroutine backward(self, input, gradient)
   end do

   ! calculate deltas for input layers
-  ! FIXME: remove reshapes when linear2d situation is resolved
-  call self % value_layer % backward(&
-      reshape(self % v_input, [self % sequence_length, self % model_dimension, 1]),&
-      reshape(self % combine_heads(dv), [self % sequence_length, self % model_dimension, 1])&
-  )
-  call self % key_layer % backward(&
-      reshape(self % k_input, [self % sequence_length, self % model_dimension, 1]),&
-      reshape(self % combine_heads(dk), [self % sequence_length, self % model_dimension, 1])&
-  )
-  call self % query_layer % backward(&
-      reshape(self % q_input, [self % sequence_length, self % model_dimension, 1]),&
-      reshape(self % combine_heads(dq), [self % sequence_length, self % model_dimension, 1])&
-  )
+  call self % value_layer % backward(self % v_input, self % combine_heads(dv))
+  call self % key_layer % backward(self % k_input, self % combine_heads(dk))
+  call self % query_layer % backward(self % q_input, self % combine_heads(dq))

   ! free temporary storages
   deallocate(d_output)
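
combine_heads is the inverse of split_heads and is likewise outside this diff; a hypothetical sketch of the merge back to rank 2:

pure function combine_heads_sketch(x) result(res)
  ! (sequence_length, head_size, n_heads) -> (sequence_length, model_dimension)
  real, intent(in) :: x(:, :, :)
  real :: res(size(x, 1), size(x, 2) * size(x, 3))
  res = reshape(x, [size(x, 1), size(x, 2) * size(x, 3)])
end function combine_heads_sketch
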
@@ -247,16 +230,14 @@ module subroutine forward(self, query, key, value)
   self % v_input = value

   ! run inputs through linear layers (trainable params)
-  ! FIXME: remove reshapes when linear2d situation is resolved
-  call self % query_layer % forward(reshape(query, [self % sequence_length, self % model_dimension, 1]))
-  call self % key_layer % forward(reshape(key, [self % sequence_length, self % model_dimension, 1]))
-  call self % value_layer % forward(reshape(value, [self % sequence_length, self % model_dimension, 1]))
+  call self % query_layer % forward(query)
+  call self % key_layer % forward(key)
+  call self % value_layer % forward(value)

   ! split attention heads for more efficient computation
-  ! FIXME: remove reshapes when linear2d situation is resolved
-  q = self % split_heads(reshape(self % query_layer % output, [self % sequence_length, self % model_dimension]))
-  k = self % split_heads(reshape(self % key_layer % output, [self % sequence_length, self % model_dimension]))
-  v = self % split_heads(reshape(self % value_layer % output, [self % sequence_length, self % model_dimension]))
+  q = self % split_heads(self % query_layer % output)
+  k = self % split_heads(self % key_layer % output)
+  v = self % split_heads(self % value_layer % output)

   ! create key by value matrix
   call self % create_attention_matrix(q, k)
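
create_attention_matrix is also not shown here; what it is assumed to compute is the per-head query-key product, sketched below with the conventional 1/sqrt(head_size) scaling folded in for illustration (the actual scaling may live in a separate step):

pure function attention_scores_sketch(q, k) result(scores)
  ! q, k: (sequence_length, head_size, n_heads); illustrative only
  real, intent(in) :: q(:, :, :), k(:, :, :)
  real :: scores(size(q, 1), size(k, 1), size(q, 3))
  integer :: head
  do concurrent(head = 1: size(q, 3))
    scores(:, :, head) = matmul(q(:, :, head), transpose(k(:, :, head))) &
        / sqrt(real(size(q, 2)))
  end do
end function attention_scores_sketch
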
@@ -265,10 +246,9 @@ module subroutine forward(self, query, key, value)
   ! multiply attention matrix by value
   call self % scaled_dot_product_attention(v)

-  ! FIXME: remove reshapes when linear2d situation is resolved
   self % o_input = self % combine_heads(self % sdpa)
-  call self % output_layer % forward(reshape(self % o_input, [self % sequence_length, self % model_dimension, 1]))
-  self % output = reshape(self % output_layer % output, [self % sequence_length, self % model_dimension])
+  call self % output_layer % forward(self % o_input)
+  self % output = self % output_layer % output

   ! free temp vars from memory
   deallocate(q)
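
With every hunk applied, callers pass rank-2 arrays straight through, with no reshape to a dummy batch axis. A hypothetical end-to-end usage sketch, assuming the module function is bound to a multihead_attention_layer generic and that the module is named nf_multihead_attention_layer (neither is shown in this diff):

program attention_usage
  use nf_multihead_attention_layer, only: multihead_attention_layer  ! module name assumed
  implicit none
  type(multihead_attention_layer) :: mha
  real :: x(8, 64), d_out(8, 64)

  x = 0.0
  d_out = 0.0
  mha = multihead_attention_layer(sequence_length=8, model_dimension=64, n_heads=4)
  call mha % forward(x, x, x)    ! self-attention: query = key = value
  call mha % backward(x, d_out)  ! gradient w.r.t. the layer output
end program attention_usage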