
Commit 475cd06

multihead_attention: update attention in accordance with linear2d
1 parent 3c3f185 commit 475cd06

4 files changed: +54 −24 lines
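All three attention constructors change in the same way: linear2d_layer is now constructed with only the output dimension, and the input shape is passed through init instead of a placeholder. A condensed before/after sketch of that pattern, taken from the hunks below (the surrounding neural-fortran module code is elided):

    ! Before: the constructor took the sequence length and both dimensions,
    ! and init received a placeholder shape.
    res % query_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
    call res % query_layer % init([0])

    ! After: the constructor takes only the output size (model_dimension);
    ! the input shape [sequence_length, model_dimension] goes to init.
    res % query_layer = linear2d_layer(model_dimension)
    call res % query_layer % init([sequence_length, model_dimension])

The same pair of lines is repeated for the key, value, and output projections in each constructor.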

src/nf/nf_cross_attention_layer.f90

Lines changed: 8 additions & 8 deletions

@@ -42,14 +42,14 @@ module function cross_attention_layer_cons(sequence_length, model_dimension, n_h
     end if
     res % head_size = model_dimension / n_heads
 
-    res % query_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
-    res % key_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
-    res % value_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
-    res % output_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
-    call res % query_layer % init([0])
-    call res % key_layer % init([0])
-    call res % value_layer % init([0])
-    call res % output_layer % init([0])
+    res % query_layer = linear2d_layer(model_dimension)
+    res % key_layer = linear2d_layer(model_dimension)
+    res % value_layer = linear2d_layer(model_dimension)
+    res % output_layer = linear2d_layer(model_dimension)
+    call res % query_layer % init([sequence_length, model_dimension])
+    call res % key_layer % init([sequence_length, model_dimension])
+    call res % value_layer % init([sequence_length, model_dimension])
+    call res % output_layer % init([sequence_length, model_dimension])
 
     res % softmax_func = softmax()
   end function cross_attention_layer_cons

src/nf/nf_multihead_attention_submodule.f90

Lines changed: 8 additions & 8 deletions

@@ -20,14 +20,14 @@ module function multihead_attention_layer_cons(sequence_length, model_dimension,
     end if
     res % head_size = model_dimension / n_heads
 
-    res % query_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
-    res % key_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
-    res % value_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
-    res % output_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
-    call res % query_layer % init([0])
-    call res % key_layer % init([0])
-    call res % value_layer % init([0])
-    call res % output_layer % init([0])
+    res % query_layer = linear2d_layer(model_dimension)
+    res % key_layer = linear2d_layer(model_dimension)
+    res % value_layer = linear2d_layer(model_dimension)
+    res % output_layer = linear2d_layer(model_dimension)
+    call res % query_layer % init([sequence_length, model_dimension])
+    call res % key_layer % init([sequence_length, model_dimension])
+    call res % value_layer % init([sequence_length, model_dimension])
+    call res % output_layer % init([sequence_length, model_dimension])
 
     res % softmax_func = softmax()
   end function multihead_attention_layer_cons

src/nf/nf_self_attention_layer.f90

Lines changed: 8 additions & 8 deletions

@@ -42,14 +42,14 @@ module function self_attention_layer_cons(sequence_length, model_dimension, n_he
     end if
     res % head_size = model_dimension / n_heads
 
-    res % query_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
-    res % key_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
-    res % value_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
-    res % output_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
-    call res % query_layer % init([0])
-    call res % key_layer % init([0])
-    call res % value_layer % init([0])
-    call res % output_layer % init([0])
+    res % query_layer = linear2d_layer(model_dimension)
+    res % key_layer = linear2d_layer(model_dimension)
+    res % value_layer = linear2d_layer(model_dimension)
+    res % output_layer = linear2d_layer(model_dimension)
+    call res % query_layer % init([sequence_length, model_dimension])
+    call res % key_layer % init([sequence_length, model_dimension])
+    call res % value_layer % init([sequence_length, model_dimension])
+    call res % output_layer % init([sequence_length, model_dimension])
 
     res % softmax_func = softmax()
   end function self_attention_layer_cons

test/test_multihead_attention_layer.f90

Lines changed: 30 additions & 0 deletions

@@ -16,6 +16,7 @@ program test_multihead_attention_layer
 
   attention = multihead_attention_layer(sequence_length=3, model_dimension=4, n_heads=2)
   call attention % init_base([0])
+  call set_weights(attention)
 
   call test_multihead_attention_split_heads(attention, sample_input, ok, split_heads_output)
   call test_multihead_attention_create_attention_matrix(attention, split_heads_output, ok)

@@ -30,6 +31,18 @@ program test_multihead_attention_layer
   call test_cross_attention(ok)
 
 contains
+  subroutine set_weights(attention)
+    type(multihead_attention_layer), intent(in out) :: attention
+    attention % query_layer % weights = 0.1
+    attention % key_layer % weights = 0.1
+    attention % value_layer % weights = 0.1
+    attention % output_layer % weights = 0.1
+    attention % query_layer % biases = 0.11
+    attention % key_layer % biases = 0.11
+    attention % value_layer % biases = 0.11
+    attention % output_layer % biases = 0.11
+  end subroutine set_weights
+
   subroutine test_multihead_attention_split_heads(attention, input, ok, output)
     type(multihead_attention_layer), intent(in) :: attention
     real, intent(in) :: input(:, :)

@@ -199,6 +212,7 @@ subroutine test_multihead_attention_forward_reallife_shape(ok)
 
   attention = multihead_attention_layer(sequence_length=148, model_dimension=512, n_heads=8)
   call attention % init_base([0])
+  call set_weights(attention)
 
   call attention % common_forward(input, input, input)
 

@@ -305,6 +319,14 @@ subroutine test_self_attention(ok)
 
   attention = self_attention_layer(sequence_length=2, model_dimension=3, n_heads=1)
  call attention % init([0])
+  attention % query_layer % weights = 0.1
+  attention % key_layer % weights = 0.1
+  attention % value_layer % weights = 0.1
+  attention % output_layer % weights = 0.1
+  attention % query_layer % biases = 0.11
+  attention % key_layer % biases = 0.11
+  attention % value_layer % biases = 0.11
+  attention % output_layer % biases = 0.11
 
   call attention % forward(input)
   output_flat = reshape(attention % output, shape(output_flat))

@@ -346,6 +368,14 @@ subroutine test_cross_attention(ok)
 
   attention = cross_attention_layer(sequence_length=2, model_dimension=3, n_heads=1)
   call attention % init([0])
+  attention % query_layer % weights = 0.1
+  attention % key_layer % weights = 0.1
+  attention % value_layer % weights = 0.1
+  attention % output_layer % weights = 0.1
+  attention % query_layer % biases = 0.11
+  attention % key_layer % biases = 0.11
+  attention % value_layer % biases = 0.11
+  attention % output_layer % biases = 0.11
 
   call attention % forward(input)
   output_flat = reshape(attention % output, shape(output_flat))
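The test additions pin every projection's weights to 0.1 and biases to 0.11 right after initialization, presumably so the forward outputs remain deterministic under the new linear2d_layer initialization. A condensed sketch of the setup flow from the hunks above; `input` stands for whatever sample array the test already defines:

    attention = multihead_attention_layer(sequence_length=3, model_dimension=4, n_heads=2)
    call attention % init_base([0])
    call set_weights(attention)   ! fixed values: weights = 0.1, biases = 0.11 (see set_weights above)
    call attention % common_forward(input, input, input)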

0 commit comments
