55 changes: 43 additions & 12 deletions flax/nnx/nn/attention.py
@@ -352,6 +352,10 @@ class MultiHeadAttention(Module):
qkv_features: dimension of the key, query, and value.
out_features: dimension of the last projection.
in_kv_features: number of input features for computing key and value.
num_key_value_heads: number of key and value heads. If None, it defaults to
``num_heads``. If set to a value smaller than ``num_heads``, Grouped Query
Attention (GQA) is used. ``num_heads`` must be divisible by
``num_key_value_heads``.
dtype: the dtype of the computation (default: infer from inputs and params)
param_dtype: the dtype passed to parameter initializers (default: float32)
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
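
For context (not part of this diff), a minimal usage sketch of the new parameter: eight query heads sharing two key/value heads. The hyperparameter values here are illustrative assumptions, not values from the PR.

    import jax.numpy as jnp
    from flax import nnx

    layer = nnx.MultiHeadAttention(
        num_heads=8,
        in_features=64,
        qkv_features=64,
        num_key_value_heads=2,  # 8 % 2 == 0, so GQA applies
        rngs=nnx.Rngs(0),
    )
    y = layer(jnp.ones((1, 16, 64)), decode=False)  # output shape: (1, 16, 64)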
@@ -406,6 +410,7 @@ def __init__(
in_features: int,
qkv_features: int | None = None,
out_features: int | None = None,
num_key_value_heads: int | None = None,
in_kv_features: int | None = None,
*,
dtype: Dtype | None = None,
@@ -450,6 +455,16 @@ def __init__(
self.in_kv_features = (
in_kv_features if in_kv_features is not None else in_features
)
self.num_key_value_heads = (
num_key_value_heads if num_key_value_heads is not None else num_heads
)

if self.num_heads % self.num_key_value_heads != 0:
raise ValueError(
f"num_heads ({self.num_heads}) must be divisible by "
f"num_key_value_heads ({self.num_key_value_heads})."
)

self.dtype = dtype
self.param_dtype = param_dtype
self.broadcast_dropout = broadcast_dropout
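
A quick sketch (with illustrative values) of the new validation above: a key/value head count that does not divide ``num_heads`` evenly is rejected at construction time.

    from flax import nnx

    try:
        nnx.MultiHeadAttention(
            num_heads=32,
            in_features=128,
            num_key_value_heads=5,  # 32 % 5 != 0
            rngs=nnx.Rngs(0),
        )
    except ValueError as e:
        print(e)  # num_heads (32) must be divisible by num_key_value_heads (5).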
@@ -478,7 +493,7 @@ def __init__(

linear_general = functools.partial(
LinearGeneral,
out_features=(self.num_heads, self.head_dim),
# out_features is removed here so we can customize it below for GQA
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=kernel_init,
Expand All @@ -491,11 +506,28 @@ def __init__(
kernel_metadata=kernel_metadata,
bias_metadata=bias_metadata,
)

# project inputs_q to multi-headed q/k/v
# dimensions are then [batch..., length, n_heads, n_features_per_head]
# (with GQA, key/value use num_key_value_heads in place of n_heads)
self.query = linear_general(self.in_features, rngs=rngs)
self.key = linear_general(self.in_kv_features, rngs=rngs)
self.value = linear_general(self.in_kv_features, rngs=rngs)

# 1. Query projection: uses the full num_heads
self.query = linear_general(
self.in_features,
out_features=(self.num_heads, self.head_dim),
rngs=rngs
)

# 2. Key/value projections: use num_key_value_heads (GQA)
self.key = linear_general(
self.in_kv_features,
out_features=(self.num_key_value_heads, self.head_dim),
rngs=rngs
)
self.value = linear_general(
self.in_kv_features,
out_features=(self.num_key_value_heads, self.head_dim),
rngs=rngs
)

self.query_ln: LayerNorm | None
self.key_ln: LayerNorm | None
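
To make the grouping concrete: with 32 query heads and 8 kv heads, each kv head serves a group of 32 // 8 = 4 query heads. The helper below is a purely illustrative reformulation of that semantics (the attention kernel performs the grouping internally; ``repeat_kv`` is not part of this PR).

    import jax.numpy as jnp

    def repeat_kv(kv, num_heads):
        # kv: [batch, length, num_kv_heads, head_dim]
        num_kv_heads = kv.shape[-2]
        assert num_heads % num_kv_heads == 0
        # Tile each kv head across its group of query heads.
        return jnp.repeat(kv, num_heads // num_kv_heads, axis=-2)

    k = jnp.zeros((2, 10, 8, 64))
    print(repeat_kv(k, 32).shape)  # (2, 10, 32, 64)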
@@ -646,16 +678,15 @@ def __call__(
(
*batch_dims,
max_length,
num_heads,
num_kv_heads,
depth_per_head,
) = self.cached_key.shape
# shape check of cached keys against query input
expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
if expected_shape != query.shape:
# shape check of cached keys against key input
expected_shape = tuple(batch_dims) + (1, num_kv_heads, depth_per_head)
if expected_shape != key.shape:
raise ValueError(
'Autoregressive cache shape error, '
'expected query shape %s instead got %s.'
% (expected_shape, query.shape)
f'expected key shape {expected_shape} instead got {key.shape}.'
)
# update key, value caches with our new 1d spatial slices
cur_index = self.cache_index[...]
@@ -747,7 +778,7 @@ def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32):
>>> model_nnx.init_cache(x.shape)
>>> out_nnx = model_nnx(x)
"""
cache_shape = (*input_shape[:-1], self.num_heads, self.head_dim)
cache_shape = (*input_shape[:-1], self.num_key_value_heads, self.head_dim)
self.cached_key = nnx.Cache(jnp.zeros(cache_shape, dtype))
self.cached_value = nnx.Cache(jnp.zeros(cache_shape, dtype))
self.cache_index = nnx.Cache(jnp.array(0, dtype=jnp.int32))
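
A decode-time sketch mirroring the new test below: with 32 query heads, 8 kv heads, and head_dim = 2048 // 32 = 64, the autoregressive cache is now allocated per kv head rather than per query head.

    import jax.numpy as jnp
    from flax import nnx

    model = nnx.MultiHeadAttention(
        num_heads=32,
        in_features=128,
        qkv_features=2048,
        num_key_value_heads=8,
        rngs=nnx.Rngs(0),
    )
    model.init_cache((1, 10, 128))
    assert model.cached_key.shape == (1, 10, 8, 64)  # 8 kv heads, not 32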
@@ -791,7 +822,7 @@ def set_view(
batch_size = (batch_size,)

# initialize cache
cache_shape = (*batch_size, max_length, self.num_heads, self.head_dim)
cache_shape = (*batch_size, max_length, self.num_key_value_heads, self.head_dim)
self.cached_key = nnx.Cache(jnp.zeros(cache_shape, self.dtype))
self.cached_value = nnx.Cache(jnp.zeros(cache_shape, self.dtype))
self.cache_index = nnx.Cache(jnp.array(0, dtype=jnp.int32))
31 changes: 31 additions & 0 deletions tests/nnx/nn/attention_test.py
@@ -324,6 +324,37 @@ def test_gqa_invalid_heads(self):
with self.assertRaisesRegex(ValueError, "must be a multiple"):
nnx.dot_product_attention(query, key, value)

def test_gqa_multihead_attention(self):
in_feat = 128
n_heads = 32
n_kv_heads = 8
qkv_feat = 2048
head_dim = qkv_feat // n_heads

model = nnx.MultiHeadAttention(
num_heads=n_heads,
in_features=in_feat,
qkv_features=qkv_feat,
num_key_value_heads=n_kv_heads,
rngs=nnx.Rngs(0),
)

assert model.query.kernel.shape == (in_feat, n_heads, head_dim)
assert model.key.kernel.shape == (in_feat, n_kv_heads, head_dim)
assert model.value.kernel.shape == (in_feat, n_kv_heads, head_dim)

x = jnp.ones((1, 10, in_feat))
y = model(x, decode=False)
assert y.shape == (1, 10, in_feat)

model.init_cache((1, 10, in_feat))
assert model.cached_key.shape == (1, 10, n_kv_heads, head_dim)

x_token = jnp.ones((1, 1, in_feat))
y_token = model(x_token, decode=True)
assert y_token.shape == (1, 1, in_feat)
assert model.cache_index == 1

def test_gqa_parity_with_jax(self):
class DummyModule(nnx.Module):
pass