@@ -207,11 +207,14 @@ module subroutine forward(self, query, key, value)
207207 class(multihead_attention_layer), intent (in out ) :: self
208208 real , intent (in ) :: query(:, :, :), key(:, :, :), value(:, :, :)
209209
210- real :: q(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
211- real :: k(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
212- real :: v(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
213- real :: attention_matrix(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)
214- real :: dot_product_attention(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
210+ real , allocatable :: q(:, :, :, :)
211+ real , allocatable :: k(:, :, :, :)
212+ real , allocatable :: v(:, :, :, :)
213+
214+ ! allocate storage for intermidiate stages
215+ allocate (q(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
216+ allocate (k(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
217+ allocate (v(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
215218
216219 self % q_input = query
217220 self % k_input = key
@@ -236,6 +239,11 @@ module subroutine forward(self, query, key, value)
236239
237240 call self % output_layer % forward(self % combine_heads(self % sdpa))
238241 self % output = self % output_layer % output
242+
243+ ! free temp vars from memory
244+ deallocate (q)
245+ deallocate (k)
246+ deallocate (v)
239247 end subroutine forward
240248
241249 module function split_heads (self , input ) result(output)
0 commit comments