
feat(nnx): add GQA support to MultiHeadAttention #5259

Open
ayulockedin wants to merge 2 commits into google:main from ayulockedin:feat/gqa-support-mha

Conversation

@ayulockedin
Contributor

@ayulockedin commented Feb 18, 2026

This PR adds support for Grouped-Query Attention (GQA) to the nnx.MultiHeadAttention module.

Currently, nnx.MultiHeadAttention enforces symmetry between the number of query heads and key/value heads. This update introduces the num_key_value_heads argument, allowing for asymmetric configurations where multiple query heads share a single key/value head. This is a critical architectural feature for modern LLMs (e.g., Llama 3, Mistral) to reduce KV cache memory usage during inference.
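
For illustration, a minimal usage sketch of the proposed API; the num_key_value_heads keyword is the addition in this PR, while the remaining constructor arguments and the call pattern are the module's existing ones:

```python
import jax.numpy as jnp
from flax import nnx

# GQA configuration: 32 query heads sharing 8 key/value heads (4 query heads per KV head).
layer = nnx.MultiHeadAttention(
    num_heads=32,
    in_features=4096,
    qkv_features=4096,
    num_key_value_heads=8,  # new in this PR; defaults to num_heads (plain MHA)
    decode=False,
    rngs=nnx.Rngs(0),
)

x = jnp.ones((2, 16, 4096))  # (batch, seq_len, features)
y = layer(x)                 # self-attention; output shape stays (2, 16, 4096)
```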

Technical Changes:
API Update: Added num_key_value_heads to __init__. It defaults to None (which sets it equal to num_heads), ensuring full backward compatibility with standard Multi-Head Attention.

Projection Logic: Decoupled the LinearGeneral layers. The Query projection now uses num_heads, while Key and Value projections use num_key_value_heads.

Head Dimension: Logic updated to ensure head_dim is consistently derived from qkv_features // num_heads, preserving the correct embedding size across all heads.

Decoding & Cache: Updated __call__ (decoding path), init_cache, and set_view to allocate and validate cache shapes based on the smaller num_key_value_heads count rather than the query head count (a rough sketch of these changes follows).
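
To make the decoupling concrete, here is a rough sketch of how the pieces fit together (illustrative only, not the actual diff; the attribute names and exact wiring in attention.py may differ):

```python
from flax import nnx

in_features, qkv_features = 4096, 4096
num_heads, num_key_value_heads = 32, 8
rngs = nnx.Rngs(0)

head_dim = qkv_features // num_heads             # 128, shared by all heads
num_kv_heads = num_key_value_heads or num_heads  # None falls back to plain MHA
assert num_heads % num_kv_heads == 0, 'num_heads must be divisible by num_key_value_heads'

# The query projection keeps the full head count ...
query_proj = nnx.LinearGeneral(in_features, (num_heads, head_dim), rngs=rngs)
# ... while the key/value projections use the smaller KV head count.
key_proj = nnx.LinearGeneral(in_features, (num_kv_heads, head_dim), rngs=rngs)
value_proj = nnx.LinearGeneral(in_features, (num_kv_heads, head_dim), rngs=rngs)

# In the decoding path, cached_key / cached_value would then be allocated as
# (batch, max_length, num_kv_heads, head_dim) rather than with num_heads.
```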

Verification:
Added a new test case test_gqa_multihead_attention in tests/nnx/nn/attention_test.py that verifies the following (a rough sketch of these checks appears after the list):

Correct projection shapes (e.g., 32 Query heads vs 8 KV heads).

Successful forward pass (broadcasting).

Correct KV cache initialization and autoregressive decoding updates.
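
For reference, a rough sketch of the kind of checks such a test can make (hypothetical code, not the actual test; the layer.query / .kernel attribute paths follow the current nnx implementation and are assumed here):

```python
import jax.numpy as jnp
from flax import nnx

num_heads, num_kv_heads, features = 32, 8, 4096
head_dim = features // num_heads  # 128

layer = nnx.MultiHeadAttention(
    num_heads=num_heads,
    in_features=features,
    qkv_features=features,
    num_key_value_heads=num_kv_heads,  # argument proposed in this PR
    decode=False,
    rngs=nnx.Rngs(0),
)

# Projection shapes: query keeps 32 heads, key/value keep only 8.
assert layer.query.kernel.value.shape == (features, num_heads, head_dim)
assert layer.key.kernel.value.shape == (features, num_kv_heads, head_dim)
assert layer.value.kernel.value.shape == (features, num_kv_heads, head_dim)

# Forward pass: KV heads are broadcast across the query groups, output shape unchanged.
x = jnp.ones((2, 16, features))
assert layer(x).shape == (2, 16, features)

# The actual test additionally initializes the cache and checks that it is
# allocated with num_kv_heads during autoregressive decoding.
```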

Fixes #5198

Checklist

  • This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other checks if that's the case).
  • This change is discussed in a Github issue/discussion (please add a link).
  • The documentation and docstrings adhere to the documentation guidelines.
  • This change includes necessary high-coverage tests. (No quality testing = no merge!)

@gemini-code-assist
Contributor

Summary of Changes

Hello @ayulockedin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the nnx.MultiHeadAttention module by integrating Grouped-Query Attention (GQA). This architectural improvement allows for a reduced number of key/value heads compared to query heads, which is crucial for optimizing memory consumption in large language models during inference. The changes involve modifying the attention mechanism's initialization, projection layers, and cache management to correctly handle the asymmetric head configuration, ensuring efficient and backward-compatible operation.
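
As a back-of-the-envelope illustration of that memory effect (numbers are illustrative, using a Llama-3-8B-style head layout, and are not taken from the PR):

```python
# Per-layer KV-cache size: fp16, batch 1, 8192-token context,
# Llama-3-8B-style heads (32 query heads, 8 KV heads, head_dim 128).
seq_len, head_dim, bytes_per_elem = 8192, 128, 2
num_heads, num_kv_heads = 32, 8

def kv_cache_mib(n_heads: int) -> float:
    # Key and value caches: 2 tensors of shape (seq_len, n_heads, head_dim).
    return 2 * seq_len * n_heads * head_dim * bytes_per_elem / 2**20

print(kv_cache_mib(num_heads))     # 128.0 MiB per layer with full MHA
print(kv_cache_mib(num_kv_heads))  # 32.0 MiB per layer with GQA, a 4x reduction
```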

Highlights

  • GQA Support: Introduced support for Grouped-Query Attention (GQA) in nnx.MultiHeadAttention by adding a num_key_value_heads argument, allowing for asymmetric configurations where multiple query heads share key/value heads.
  • API Update: The __init__ method of MultiHeadAttention now accepts num_key_value_heads, which defaults to num_heads for backward compatibility. A validation check ensures num_heads is divisible by num_key_value_heads.
  • Projection Logic: Decoupled the LinearGeneral layers for query, key, and value projections. Query projection now uses num_heads, while key and value projections use the new num_key_value_heads.
  • Cache Management: Updated the __call__ method (decoding path), init_cache, and set_view to correctly allocate and validate cache shapes based on num_key_value_heads, optimizing KV cache memory usage.
  • Testing: Added a new test case, test_gqa_multihead_attention, to verify correct projection shapes, successful forward passes, and proper KV cache initialization and updates for GQA.


Changelog
  • flax/nnx/nn/attention.py
    • Added num_key_value_heads parameter to the MultiHeadAttention constructor.
    • Implemented logic to initialize self.num_key_value_heads and added a validation check for divisibility by num_heads.
    • Modified LinearGeneral partial application to allow dynamic out_features for query, key, and value projections.
    • Separated the out_features for query, key, and value projections, using self.num_heads for query and self.num_key_value_heads for key/value.
    • Updated the cache shape validation in the __call__ method to use num_kv_heads instead of num_heads.
    • Adjusted the cache_shape calculation in init_cache to use self.num_key_value_heads.
    • Modified the cache_shape calculation in set_view to use self.num_key_value_heads.
  • tests/nnx/nn/attention_test.py
    • Added test_gqa_multihead_attention to validate GQA functionality, including projection shapes, forward pass, and KV cache behavior.

Activity
  • The change was discussed in a GitHub issue or discussion, as indicated by the checklist.
  • Documentation and docstrings have been updated to adhere to guidelines.
  • High-coverage tests have been included to ensure the quality of the changes.

@gemini-code-assist
Contributor

Warning

Gemini encountered an error creating the review. You can try again by commenting /gemini review.

@ayulockedin
Contributor Author

/gemini review

@gemini-code-assist bot left a comment

Code Review

This pull request introduces support for Grouped-Query Attention (GQA) in nnx.MultiHeadAttention by adding the num_key_value_heads parameter. The changes are well-implemented, covering the module's initialization, projection layers, and cache handling for autoregressive decoding. Backward compatibility is maintained by correctly defaulting num_key_value_heads to num_heads. The addition of a dedicated test case for GQA is a great inclusion, verifying the correctness of projection shapes, forward pass, and cache logic. The code is clear and the changes are robust. I have one minor suggestion to improve code consistency.

@ayulockedin
Contributor Author

/gemini review

@gemini-code-assist bot left a comment

Code Review

This pull request effectively adds support for Grouped-Query Attention (GQA) to the nnx.MultiHeadAttention module. The implementation is clean and correct, introducing the num_key_value_heads parameter and updating the projection layers and cache logic accordingly while maintaining backward compatibility. The accompanying test case is thorough, validating the new functionality for both standard and autoregressive decoding paths. My only suggestion is to update the class docstring to reflect the addition of the new parameter.

