@@ -26,9 +26,12 @@ module nf_multihead_attention_layer
     real, allocatable :: sdpa(:, :, :, :)
     real, allocatable :: output(:, :, :)

+    real, allocatable :: q_input(:, :, :)
+    real, allocatable :: k_input(:, :, :)
+    real, allocatable :: v_input(:, :, :)
   contains

-    ! procedure :: backward
+    procedure :: backward
     procedure :: forward
     procedure :: split_heads
     procedure :: create_attention_matrix
@@ -49,15 +52,15 @@ end function multihead_attention_layer_cons

   interface

-    pure module subroutine backward(self, input, gradient)
+    module subroutine backward(self, input, gradient)
       !! Apply the backward gradient descent pass.
       !! Only weight and bias gradients are updated in this subroutine,
       !! while the weights and biases themselves are untouched.
       class(multihead_attention_layer), intent(in out) :: self
         !! Multi-head attention layer instance
-      real, intent(in) :: input(:)
+      real, intent(in) :: input(:, :, :)
         !! Input from the previous layer
-      real, intent(in) :: gradient(:)
+      real, intent(in) :: gradient(:, :, :)
         !! Gradient from the next layer
     end subroutine backward

@@ -109,6 +112,20 @@ module function multihead_attention_layer_cons(&
     res % softmax_func = softmax()
   end function multihead_attention_layer_cons

+  module subroutine backward(self, input, gradient)
+    class(multihead_attention_layer), intent(in out) :: self
+    real, intent(in) :: input(:, :, :)
+    real, intent(in) :: gradient(:, :, :)
+
+    call self % output_layer % backward(input, gradient)
+
+    ! FIXME: calculate gradient for softmax
+
+    call self % value_layer % backward(self % v_input, self % output_layer % gradient)
+    call self % key_layer % backward(self % k_input, self % output_layer % gradient)
+    call self % query_layer % backward(self % q_input, self % output_layer % gradient)
+  end subroutine backward
+
   module subroutine forward(self, query, key, value)
     class(multihead_attention_layer), intent(in out) :: self
     real, intent(in) :: query(:, :, :), key(:, :, :), value(:, :, :)
@@ -119,6 +136,10 @@ module subroutine forward(self, query, key, value)
     real :: attention_matrix(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)
     real :: dot_product_attention(self % n_heads, self % sequence_length, self % head_size, self % batch_size)

+    self % q_input = query
+    self % k_input = key
+    self % v_input = value
+
     call self % query_layer % forward(query)
     call self % key_layer % forward(key)
     call self % value_layer % forward(value)
@@ -237,5 +258,9 @@ module subroutine init(self, input_shape)
       self % n_heads, self % sequence_length, self % head_size, self % batch_size&
     ))
     allocate(self % output(self % sequence_length, self % model_dimension, self % batch_size))
+
+    allocate(self % q_input(self % sequence_length, self % model_dimension, self % batch_size))
+    allocate(self % k_input(self % sequence_length, self % model_dimension, self % batch_size))
+    allocate(self % v_input(self % sequence_length, self % model_dimension, self % batch_size))
   end subroutine init
 end module nf_multihead_attention_layer
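
Note on the FIXME in the new backward routine: the missing piece is propagating the gradient through the softmax that produces the attention weights. The following is a minimal, illustrative sketch of the per-row Jacobian-vector product that step needs; the module and procedure names (softmax_grad_sketch, softmax_gradient) are hypothetical and not part of this commit.

module softmax_grad_sketch
  implicit none
contains

  pure function softmax_gradient(s, g) result(dx)
    ! Jacobian-vector product of softmax for one row of the attention matrix:
    !   dL/dx_i = s_i * (g_i - sum_j g_j * s_j)
    real, intent(in) :: s(:)  ! softmax output (attention weights) for the row
    real, intent(in) :: g(:)  ! upstream gradient arriving at that row
    real :: dx(size(s))
    dx = s * (g - sum(g * s))
  end function softmax_gradient

end module softmax_grad_sketch

Chaining this through the scaled dot-product step would then yield separate gradients for the query, key, and value projections, instead of reusing self % output_layer % gradient for all three backward calls as the interim code above does.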