@@ -120,17 +120,28 @@ module subroutine backward(self, input, gradient)
120120 real , intent (in ) :: input(:, :, :)
121121 real , intent (in ) :: gradient(:, :, :)
122122
123- real :: d_output(self % n_heads, self % sequence_length, self % head_size, self % batch_size )
124- real :: v_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size )
125- real :: k_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size )
126- real :: q_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size )
127- real :: d_sdpa(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size )
128- real :: jacobian(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size )
129- real :: d_normalize(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size )
130- real :: d_attn_matrix(self % n_heads, self % sequence_length, self % head_size, self % batch_size )
131- real :: dk(self % n_heads, self % sequence_length, self % head_size, self % batch_size )
123+ real , allocatable :: d_output(:, :, :, : )
124+ real , allocatable :: v_heads(:, :, :, : )
125+ real , allocatable :: k_heads(:, :, :, : )
126+ real , allocatable :: q_heads(:, :, :, : )
127+ real , allocatable :: d_sdpa(:, :, :, : )
128+ real , allocatable :: jacobian(:, :, :, : )
129+ real , allocatable :: d_normalize(:, :, :, : )
130+ real , allocatable :: d_attn_matrix(:, :, :, : )
131+ real , allocatable :: dk(:, :, :, : )
132132 integer :: batch, head, i, j
133133
134+ ! allocate temporary storages for backward computation
135+ allocate (d_output(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
136+ allocate (v_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
137+ allocate (k_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
138+ allocate (q_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
139+ allocate (d_sdpa(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size))
140+ allocate (jacobian(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size))
141+ allocate (d_normalize(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size))
142+ allocate (d_attn_matrix(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
143+ allocate (dk(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
144+
134145 ! calculate output layer delta
135146 call self % output_layer % backward(input, gradient)
136147
@@ -178,6 +189,17 @@ module subroutine backward(self, input, gradient)
178189 call self % value_layer % backward(self % v_input, self % combine_heads(d_sdpa))
179190 call self % key_layer % backward(self % k_input, self % combine_heads(dk))
180191 call self % query_layer % backward(self % q_input, self % combine_heads(d_attn_matrix))
192+
193+ ! free temporary storages
194+ deallocate (d_output)
195+ deallocate (v_heads)
196+ deallocate (k_heads)
197+ deallocate (q_heads)
198+ deallocate (d_sdpa)
199+ deallocate (jacobian)
200+ deallocate (d_normalize)
201+ deallocate (d_attn_matrix)
202+ deallocate (dk)
181203 end subroutine backward
182204
183205 module subroutine forward (self , query , key , value )
0 commit comments