Skip to content

Commit 5da87a6

Browse files
committed
multihead_attention: use heap-allocated arrays during back prop
1 parent 4d2fda2 commit 5da87a6

File tree

1 file changed: 31 additions and 9 deletions

src/nf/nf_multihead_attention.f90

Lines changed: 31 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -120,17 +120,28 @@ module subroutine backward(self, input, gradient)
120 120
real, intent(in) :: input(:, :, :)
121 121
real, intent(in) :: gradient(:, :, :)
122 122

123-
real :: d_output(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
124-
real :: v_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
125-
real :: k_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
126-
real :: q_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
127-
real :: d_sdpa(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)
128-
real :: jacobian(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)
129-
real :: d_normalize(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)
130-
real :: d_attn_matrix(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
131-
real :: dk(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
123+
real, allocatable :: d_output(:, :, :, :)
124+
real, allocatable :: v_heads(:, :, :, :)
125+
real, allocatable :: k_heads(:, :, :, :)
126+
real, allocatable :: q_heads(:, :, :, :)
127+
real, allocatable :: d_sdpa(:, :, :, :)
128+
real, allocatable :: jacobian(:, :, :, :)
129+
real, allocatable :: d_normalize(:, :, :, :)
130+
real, allocatable :: d_attn_matrix(:, :, :, :)
131+
real, allocatable :: dk(:, :, :, :)
132 132
integer :: batch, head, i, j
133 133

134+
! allocate temporary storages for backward computation
135+
allocate(d_output(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
136+
allocate(v_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
137+
allocate(k_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
138+
allocate(q_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
139+
allocate(d_sdpa(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size))
140+
allocate(jacobian(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size))
141+
allocate(d_normalize(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size))
142+
allocate(d_attn_matrix(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
143+
allocate(dk(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
144+
134 145
! calculate output layer delta
135 146
call self % output_layer % backward(input, gradient)
136 147

@@ -178,6 +189,17 @@ module subroutine backward(self, input, gradient)
178 189
call self % value_layer % backward(self % v_input, self % combine_heads(d_sdpa))
179 190
call self % key_layer % backward(self % k_input, self % combine_heads(dk))
180 191
call self % query_layer % backward(self % q_input, self % combine_heads(d_attn_matrix))
192+
193+
! free temporary storages
194+
deallocate(d_output)
195+
deallocate(v_heads)
196+
deallocate(k_heads)
197+
deallocate(q_heads)
198+
deallocate(d_sdpa)
199+
deallocate(jacobian)
200+
deallocate(d_normalize)
201+
deallocate(d_attn_matrix)
202+
deallocate(dk)
181 203
end subroutine backward
182 204

183 205
module subroutine forward(self, query, key, value)

0 commit comments

Comments
 (0)