feat(nnx): add GQA support to MultiHeadAttention #5259
ayulockedin wants to merge 2 commits into google:main from
Conversation
Code Review
This pull request introduces support for Grouped-Query Attention (GQA) in nnx.MultiHeadAttention by adding the num_key_value_heads parameter. The changes are well-implemented, covering the module's initialization, projection layers, and cache handling for autoregressive decoding. Backward compatibility is maintained by correctly defaulting num_key_value_heads to num_heads. The addition of a dedicated test case for GQA is a great inclusion, verifying the correctness of projection shapes, forward pass, and cache logic. The code is clear and the changes are robust. I have one minor suggestion to improve code consistency.
Code Review
This pull request effectively adds support for Grouped-Query Attention (GQA) to the nnx.MultiHeadAttention module. The implementation is clean and correct, introducing the num_key_value_heads parameter and updating the projection layers and cache logic accordingly while maintaining backward compatibility. The accompanying test case is thorough, validating the new functionality for both standard and autoregressive decoding paths. My only suggestion is to update the class docstring to reflect the addition of the new parameter.
fdf6dff to 91fea70
This PR adds support for Grouped-Query Attention (GQA) to the nnx.MultiHeadAttention module.
Currently, nnx.MultiHeadAttention requires the number of key/value heads to equal the number of query heads. This update introduces the num_key_value_heads argument, allowing asymmetric configurations in which groups of query heads share a single key/value head. This is a key architectural feature of modern LLMs (e.g., Llama 3, Mistral) for reducing KV cache memory usage during inference.
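For illustration, here is a minimal usage sketch of the proposed API; num_key_value_heads is the new argument added by this PR, the remaining arguments already exist in nnx.MultiHeadAttention, and the concrete sizes are only examples:

```python
import jax.numpy as jnp
from flax import nnx

# 32 query heads attend with only 8 shared key/value heads (GQA).
layer = nnx.MultiHeadAttention(
    num_heads=32,
    num_key_value_heads=8,   # new argument introduced by this PR
    in_features=4096,
    qkv_features=4096,       # head_dim = 4096 // 32 = 128
    decode=False,
    rngs=nnx.Rngs(0),
)

x = jnp.ones((2, 16, 4096))  # (batch, seq_len, features)
y = layer(x)                 # self-attention; y.shape == (2, 16, 4096)
```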
Technical Changes:
API Update: Added num_key_value_heads to __init__. It defaults to None (which sets it equal to num_heads), ensuring full backward compatibility with standard Multi-Head Attention.
Projection Logic: Decoupled the LinearGeneral layers. The Query projection now uses num_heads, while Key and Value projections use num_key_value_heads.
Head Dimension: head_dim is now consistently derived as qkv_features // num_heads, so query and key/value heads all share the same per-head width.
Decoding & Cache: Updated __call__ (decoding path), init_cache, and set_view to allocate and validate cache shapes based on the smaller num_key_value_heads count rather than the query head count (see the shape sketch after this list).
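As a rough sketch of the decoupling described above, the query projection keeps the full head count while the key/value projections (and the decode cache) shrink to num_key_value_heads. The snippet below uses standalone nnx.LinearGeneral layers with assumed sizes, not the module's actual internals:

```python
import jax.numpy as jnp
from flax import nnx

in_features, qkv_features = 4096, 4096
num_heads, num_kv_heads = 32, 8
head_dim = qkv_features // num_heads  # 128, shared by all heads

rngs = nnx.Rngs(0)
q_proj = nnx.LinearGeneral(in_features, (num_heads, head_dim), rngs=rngs)
k_proj = nnx.LinearGeneral(in_features, (num_kv_heads, head_dim), rngs=rngs)
v_proj = nnx.LinearGeneral(in_features, (num_kv_heads, head_dim), rngs=rngs)

x = jnp.ones((2, 16, in_features))
print(q_proj(x).shape)  # (2, 16, 32, 128) -> query heads
print(k_proj(x).shape)  # (2, 16, 8, 128)  -> key/value heads
# During decoding the KV cache is allocated per key/value head, e.g.
# (batch, max_len, num_kv_heads, head_dim), which is where the memory
# saving over standard MHA comes from.
```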
Verification:
Added a new test case test_gqa_multihead_attention in tests/nnx/nn/attention_test.py that verifies:
Correct projection shapes (e.g., 32 Query heads vs 8 KV heads).
Successful forward pass (query heads correctly broadcast over the shared key/value heads).
Correct KV cache initialization and autoregressive decoding updates (sketched below).
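For reference, the checks look roughly like the sketch below; the exact dimensions and cache attribute names are assumptions based on the current nnx cache layout, not the test's literal code:

```python
import jax.numpy as jnp
from flax import nnx

B, T, D = 2, 8, 256
num_heads, num_kv_heads = 32, 8
head_dim = D // num_heads

layer = nnx.MultiHeadAttention(
    num_heads=num_heads,
    num_key_value_heads=num_kv_heads,  # new argument from this PR
    in_features=D,
    decode=False,
    rngs=nnx.Rngs(0),
)

# Forward pass: queries broadcast over the shared key/value heads.
x = jnp.ones((B, T, D))
assert layer(x).shape == (B, T, D)

# Cache allocation for decoding should follow num_key_value_heads.
layer.init_cache(x.shape, dtype=x.dtype)
assert layer.cached_key.value.shape == (B, T, num_kv_heads, head_dim)
assert layer.cached_value.value.shape == (B, T, num_kv_heads, head_dim)
```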
Fixes #5198
Checklist