From 23918aeae0d58291e5e52eb6824a0659787c8698 Mon Sep 17 00:00:00 2001
From: Ayush Kumar
Date: Wed, 18 Feb 2026 15:29:49 +0000
Subject: [PATCH 1/2] feat(nnx): add GQA support to MultiHeadAttention

---
 flax/nnx/nn/attention.py       | 52 ++++++++++++++++++++++++++--------
 tests/nnx/nn/attention_test.py | 31 ++++++++++++++++++++
 2 files changed, 71 insertions(+), 12 deletions(-)

diff --git a/flax/nnx/nn/attention.py b/flax/nnx/nn/attention.py
index 58f5bfe96..5e8cfabf9 100644
--- a/flax/nnx/nn/attention.py
+++ b/flax/nnx/nn/attention.py
@@ -406,6 +406,7 @@ def __init__(
     in_features: int,
     qkv_features: int | None = None,
     out_features: int | None = None,
+    num_key_value_heads: int | None = None,
     in_kv_features: int | None = None,
     *,
     dtype: Dtype | None = None,
@@ -450,6 +451,16 @@ def __init__(
     self.in_kv_features = (
       in_kv_features if in_kv_features is not None else in_features
     )
+    self.num_key_value_heads = (
+      num_key_value_heads if num_key_value_heads is not None else num_heads
+    )
+
+    if self.num_heads % self.num_key_value_heads != 0:
+      raise ValueError(
+        f"num_heads ({self.num_heads}) must be divisible by "
+        f"num_key_value_heads ({self.num_key_value_heads})."
+      )
+
     self.dtype = dtype
     self.param_dtype = param_dtype
     self.broadcast_dropout = broadcast_dropout
@@ -478,7 +489,7 @@ def __init__(

     linear_general = functools.partial(
       LinearGeneral,
-      out_features=(self.num_heads, self.head_dim),
+      # out_features is removed here so we can customize it below for GQA
       dtype=self.dtype,
       param_dtype=self.param_dtype,
       kernel_init=kernel_init,
@@ -491,11 +502,28 @@ def __init__(
       kernel_metadata=kernel_metadata,
       bias_metadata=bias_metadata,
     )
+
     # project inputs_q to multi-headed q/k/v
     # dimensions are then [batch..., length, n_heads, n_features_per_head]
-    self.query = linear_general(self.in_features, rngs=rngs)
-    self.key = linear_general(self.in_kv_features, rngs=rngs)
-    self.value = linear_general(self.in_kv_features, rngs=rngs)
+
+    # 1. Query Projection: Uses standard num_heads
+    self.query = linear_general(
+      self.in_features,
+      out_features=(self.num_heads, self.head_dim),
+      rngs=rngs
+    )
+
+    # 2. Key/Value Projections: Uses num_key_value_heads (GQA)
+    self.key = linear_general(
+      self.in_kv_features,
+      out_features=(self.num_key_value_heads, self.head_dim),
+      rngs=rngs
+    )
+    self.value = linear_general(
+      self.in_kv_features,
+      out_features=(self.num_key_value_heads, self.head_dim),
+      rngs=rngs
+    )

     self.query_ln: LayerNorm | None
     self.key_ln: LayerNorm | None
@@ -646,16 +674,16 @@ def __call__(
       (
         *batch_dims,
         max_length,
-        num_heads,
+        num_kv_heads,
         depth_per_head,
       ) = self.cached_key.shape
-      # shape check of cached keys against query input
-      expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
-      if expected_shape != query.shape:
+      # shape check of cached keys against key input
+      expected_shape = tuple(batch_dims) + (1, num_kv_heads, depth_per_head)
+      if expected_shape != key.shape:
         raise ValueError(
           'Autoregressive cache shape error, '
-          'expected query shape %s instead got %s.'
-          % (expected_shape, query.shape)
+          'expected key shape %s instead got %s.'
+          % (expected_shape, key.shape)
         )
       # update key, value caches with our new 1d spatial slices
       cur_index = self.cache_index[...]
@@ -747,7 +775,7 @@ def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32):
       >>> model_nnx.init_cache(x.shape)
       >>> out_nnx = model_nnx(x)
     """
-    cache_shape = (*input_shape[:-1], self.num_heads, self.head_dim)
+    cache_shape = (*input_shape[:-1], self.num_key_value_heads, self.head_dim)
     self.cached_key = nnx.Cache(jnp.zeros(cache_shape, dtype))
     self.cached_value = nnx.Cache(jnp.zeros(cache_shape, dtype))
     self.cache_index = nnx.Cache(jnp.array(0, dtype=jnp.int32))
@@ -791,7 +819,7 @@ def set_view(
       batch_size = (batch_size,)

     # initialize cache
-    cache_shape = (*batch_size, max_length, self.num_heads, self.head_dim)
+    cache_shape = (*batch_size, max_length, self.num_key_value_heads, self.head_dim)
     self.cached_key = nnx.Cache(jnp.zeros(cache_shape, self.dtype))
     self.cached_value = nnx.Cache(jnp.zeros(cache_shape, self.dtype))
     self.cache_index = nnx.Cache(jnp.array(0, dtype=jnp.int32))
diff --git a/tests/nnx/nn/attention_test.py b/tests/nnx/nn/attention_test.py
index bbc48847a..37e0452e1 100644
--- a/tests/nnx/nn/attention_test.py
+++ b/tests/nnx/nn/attention_test.py
@@ -324,6 +324,37 @@ def test_gqa_invalid_heads(self):
     with self.assertRaisesRegex(ValueError, "must be a multiple"):
       nnx.dot_product_attention(query, key, value)

+  def test_gqa_multihead_attention(self):
+    in_feat = 128
+    n_heads = 32
+    n_kv_heads = 8
+    qkv_feat = 2048
+    head_dim = qkv_feat // n_heads
+
+    model = nnx.MultiHeadAttention(
+      num_heads=n_heads,
+      in_features=in_feat,
+      qkv_features=qkv_feat,
+      num_key_value_heads=n_kv_heads,
+      rngs=nnx.Rngs(0),
+    )
+
+    assert model.query.kernel.shape == (in_feat, n_heads, head_dim)
+    assert model.key.kernel.shape == (in_feat, n_kv_heads, head_dim)
+    assert model.value.kernel.shape == (in_feat, n_kv_heads, head_dim)
+
+    x = jnp.ones((1, 10, in_feat))
+    y = model(x, decode=False)
+    assert y.shape == (1, 10, in_feat)
+
+    model.init_cache((1, 10, in_feat))
+    assert model.cached_key.shape == (1, 10, n_kv_heads, head_dim)
+
+    x_token = jnp.ones((1, 1, in_feat))
+    y_token = model(x_token, decode=True)
+    assert y_token.shape == (1, 1, in_feat)
+    assert model.cache_index == 1
+
   def test_gqa_parity_with_jax(self):
     class DummyModule(nnx.Module):
       pass

From 91fea70eff40316dd772288124d48b433e3cc05b Mon Sep 17 00:00:00 2001
From: Ayush Kumar
Date: Wed, 18 Feb 2026 16:32:51 +0000
Subject: [PATCH 2/2] Implemented f'string

---
 flax/nnx/nn/attention.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/flax/nnx/nn/attention.py b/flax/nnx/nn/attention.py
index 5e8cfabf9..18f33562c 100644
--- a/flax/nnx/nn/attention.py
+++ b/flax/nnx/nn/attention.py
@@ -352,6 +352,10 @@ class MultiHeadAttention(Module):
     qkv_features: dimension of the key, query, and value.
     out_features: dimension of the last projection.
     in_kv_features: number of input features for computing key and value.
+    num_key_value_heads: number of key and value heads. If None, it defaults to
+      ``num_heads``. If set to a value smaller than ``num_heads``, Grouped Query
+      Attention (GQA) is used. ``num_heads`` must be divisible by
+      ``num_key_value_heads``.
     dtype: the dtype of the computation (default: infer from inputs and params)
     param_dtype: the dtype passed to parameter initializers (default: float32)
     broadcast_dropout: bool: use a broadcasted dropout along batch dims.
@@ -682,8 +686,7 @@ def __call__(
       if expected_shape != key.shape:
         raise ValueError(
           'Autoregressive cache shape error, '
-          'expected key shape %s instead got %s.'
-          % (expected_shape, key.shape)
+          f'expected key shape {expected_shape} instead got {key.shape}.'
         )
       # update key, value caches with our new 1d spatial slices
       cur_index = self.cache_index[...]
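
For readers who want to try the feature, here is a minimal usage sketch (not part of the patches above). It assumes the two patches are applied on top of current flax.nnx and mirrors the shapes exercised by test_gqa_multihead_attention; all concrete sizes below are illustrative.

  import jax.numpy as jnp
  from flax import nnx

  # A GQA layer: 8 query heads share 2 key/value heads, i.e. one key/value
  # head per group of 4 query heads (head_dim = 128 // 8 = 16).
  layer = nnx.MultiHeadAttention(
    num_heads=8,
    in_features=64,
    qkv_features=128,
    num_key_value_heads=2,  # new argument from PATCH 1/2
    rngs=nnx.Rngs(0),
  )

  x = jnp.ones((2, 16, 64))       # [batch, length, features]
  y = layer(x, decode=False)      # -> (2, 16, 64)

  # The autoregressive cache is allocated with num_key_value_heads, not
  # num_heads, so cached_key here is (2, 16, 2, 16).
  layer.init_cache((2, 16, 64))
  tok = jnp.ones((2, 1, 64))
  step = layer(tok, decode=True)  # -> (2, 1, 64); cache_index advances to 1

Relative to standard multi-head attention with the same num_heads, the key/value projection kernels and the decode-time cache shrink by a factor of num_heads / num_key_value_heads (4x in this sketch), while the layer's output shape is unchanged.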