Commit 43d2447

multihead_attention: rename common forward and backward calls
1 parent 8a526db commit 43d2447

2 files changed: +16 −16 lines

src/nf/nf_multihead_attention.f90 (10 additions, 10 deletions)
@@ -34,8 +34,8 @@ module nf_multihead_attention_layer
     real, allocatable :: o_input(:, :)
   contains

-    procedure :: backward
-    procedure :: forward
+    procedure :: common_backward
+    procedure :: common_forward
     procedure :: split_heads
     procedure :: create_attention_matrix
     procedure :: normalize_attention_matrix
@@ -59,24 +59,24 @@ end function multihead_attention_layer_cons

   interface

-    module subroutine backward(self, input, gradient)
+    module subroutine common_backward(self, input, gradient)
       !! General backprop for MultiHead Attention mechanism
       !! Might be used for both Self and Cross Attention
       !! Self Attention: sum output gradients
       !! Cross Attention: use them separately
       class(multihead_attention_layer), intent(in out) :: self
       real, intent(in) :: input(:, :)
       real, intent(in) :: gradient(:, :)
-    end subroutine backward
+    end subroutine common_backward

-    module subroutine forward(self, query, key, value)
+    module subroutine common_forward(self, query, key, value)
       !! General forward propagation for MultiHead Attention Mechanism
       !! Might be used for both Self and Cross Attention
       !! Self Attention: pass the same value thrice
       !! Cross Attention: pass three values for your query, key and value
       class(multihead_attention_layer), intent(in out) :: self
       real, intent(in) :: query(:, :), key(:, :), value(:, :)
-    end subroutine forward
+    end subroutine common_forward

     module subroutine init(self, input_shape)
       !! Initialize the layer data structures.
@@ -114,7 +114,7 @@ module function multihead_attention_layer_cons(sequence_length, model_dimension,
     res % softmax_func = softmax()
   end function multihead_attention_layer_cons

-  module subroutine backward(self, input, gradient)
+  module subroutine common_backward(self, input, gradient)
     class(multihead_attention_layer), intent(in out) :: self
     real, intent(in) :: input(:, :)
     real, intent(in) :: gradient(:, :)
@@ -210,9 +210,9 @@ module subroutine backward(self, input, gradient)
     deallocate(d_normalize)
     deallocate(dq)
     deallocate(dk)
-  end subroutine backward
+  end subroutine common_backward

-  module subroutine forward(self, query, key, value)
+  module subroutine common_forward(self, query, key, value)
     class(multihead_attention_layer), intent(in out) :: self
     real, intent(in) :: query(:, :), key(:, :), value(:, :)

@@ -254,7 +254,7 @@ module subroutine forward(self, query, key, value)
     deallocate(q)
     deallocate(k)
     deallocate(v)
-  end subroutine forward
+  end subroutine common_forward

   module function split_heads(self, input) result(output)
     !! Split inputs into heads
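
The interface doc comments above describe two calling modes for the renamed routines. A minimal usage sketch of both modes follows; it assumes the constructor and init call shown in the test file, that the module exports the multihead_attention_layer type under that name, and uses random placeholder inputs.

  program mha_usage_sketch
    ! Illustrative sketch only; assumes nf_multihead_attention_layer exposes
    ! the multihead_attention_layer type as used in test_multihead_attention_layer.f90.
    use nf_multihead_attention_layer, only: multihead_attention_layer
    implicit none

    type(multihead_attention_layer) :: attention
    real :: x(3, 4)            ! Self Attention input (sequence_length x model_dimension)
    real :: q(3, 4), kv(3, 4)  ! Cross Attention: query source and key/value source

    attention = multihead_attention_layer(sequence_length=3, model_dimension=4, n_heads=2)
    call attention % init([0])

    ! Placeholder data for the sketch.
    call random_number(x)
    call random_number(q)
    call random_number(kv)

    ! Self Attention: pass the same value thrice.
    call attention % common_forward(x, x, x)

    ! Cross Attention: pass separate query, key and value.
    call attention % common_forward(q, kv, kv)
  end program mha_usage_sketch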

test/test_multihead_attention_layer.f90 (6 additions, 6 deletions)
@@ -14,7 +14,7 @@ program test_multihead_attention_layer

   attention = multihead_attention_layer(sequence_length=3, model_dimension=4, n_heads=2)
   call attention % init([0])
-  !
+
   call test_multihead_attention_split_heads(attention, sample_input, ok, split_heads_output)
   call test_multihead_attention_create_attention_matrix(attention, split_heads_output, ok)
   call test_multihead_attention_normalization(attention, ok)
@@ -23,7 +23,7 @@ program test_multihead_attention_layer
   call test_multihead_attention_forward(attention, ok)
   call test_multihead_attention_backward(attention, ok)
   call test_multihead_attention_update_gradients(attention, ok)
-  ! call test_multihead_attention_forward_reallife_shape(ok)
+  call test_multihead_attention_forward_reallife_shape(ok)

 contains
   subroutine test_multihead_attention_split_heads(attention, input, ok, output)
@@ -156,7 +156,7 @@ subroutine test_multihead_attention_forward(attention, ok)
       0.447508544, 0.464612424, 0.464721352, 0.473546445, 0.512576580, 0.513393998&
     ]

-    call attention % forward(input, input, input)
+    call attention % common_forward(input, input, input)

     output_shape = shape(attention % output)
     if (.not. all(output_shape.eq.expected_shape)) then
@@ -196,7 +196,7 @@ subroutine test_multihead_attention_forward_reallife_shape(ok)
     attention = multihead_attention_layer(sequence_length=148, model_dimension=512, n_heads=8)
     call attention % init([0])

-    call attention % forward(input, input, input)
+    call attention % common_forward(input, input, input)

     output_shape = shape(attention % output)
     if (.not. all(output_shape.eq.expected_shape)) then
@@ -221,7 +221,7 @@ subroutine test_multihead_attention_backward(attention, ok)
     real :: output_flat(12)
     real :: output_shape(2)

-    call attention % backward(input, gradient)
+    call attention % common_backward(input, gradient)

     ! sample for Self Attention: sum of output gradients
     ! FIXME: remove reshapes when linear2d situation is resolved
@@ -271,7 +271,7 @@ subroutine test_multihead_attention_update_gradients(attention, ok)
     call optim % minimize(parameters, attention % get_gradients())
     call attention % set_params(parameters)

-    call attention % forward(&
+    call attention % common_forward(&
       reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]),&
       reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]),&
       reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4])&