Add text-backbone support for Gemma 4 12B and 12B-it #1570

Open
deep-diver wants to merge 1 commit into google:main from deep-diver:add-gemma4-12b-support

@deep-diver

Summary

This PR adds Tunix text-backbone support for:

  • google/gemma-4-12b
  • google/gemma-4-12b-it

It adds Gemma 4 12B model configs, registry entries, and safetensors mapping tests based on the existing Gemma 4 implementation.

This change covers the Gemma 4 12B language-model/text backbone only. Full Gemma 4 Unified multimodal image/audio/video support is out of scope.

What Changed

  • Added ModelConfig.gemma4_12b()
  • Added ModelConfig.gemma4_12b_it()
  • Registered:
    • google/gemma-4-12b
    • google/gemma-4-12b-it
  • Added tests for:
    • config values
    • registry lookup
    • safetensors local/global attention mapping
    • AutoModel resolution

Implementation Notes

The 12B config follows the HF google/gemma-4-12B-it text config:

  • 48 layers
  • hidden size 3840
  • intermediate size 15360
  • 16 attention heads
  • 8 KV heads
  • 1 global KV head
  • global head dim 512
  • sliding window 1024
  • 5 local sliding layers followed by 1 global layer
  • global attention uses k_eq_v
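
The 5:1 local/global interleave in the list above can be made concrete with a small sketch. The class and field names here (`TextConfigSketch`, `attention_pattern`) are hypothetical and are not Tunix's `ModelConfig` API; only the numeric values come from the list. It also assumes the pattern starts at layer 0, so every sixth layer is global.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TextConfigSketch:
    """Illustrative container; field names are hypothetical, not Tunix's."""

    num_layers: int = 48
    hidden_size: int = 3840
    intermediate_size: int = 15360
    num_heads: int = 16
    num_kv_heads: int = 8
    num_global_kv_heads: int = 1
    global_head_dim: int = 512
    sliding_window: int = 1024
    local_per_global: int = 5  # 5 sliding-window layers, then 1 global layer

    def attention_pattern(self) -> list[str]:
        # Assumption: the cycle begins at layer 0, so layers 5, 11, ... are global.
        period = self.local_per_global + 1
        return [
            "global" if (i + 1) % period == 0 else "local"
            for i in range(self.num_layers)
        ]


cfg = TextConfigSketch()
pattern = cfg.attention_pattern()
print(pattern[:6])
print(pattern.count("local"), pattern.count("global"))
```

Under that assumption, 48 layers split into 40 local and 8 global layers, which is a quick sanity check to run against whatever layer layout the real config produces.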

Validation

Local tests:

pytest tests/models/gemma4/model_test.py tests/models/gemma4/params_safetensors_test.py

Result:

10 passed

Registry / naming / AutoModel tests:

586 passed, 61 deselected

GPU validation on A100-80GB:

  • Loaded full google/gemma-4-12B-it safetensors checkpoint in Tunix/JAX
  • Ran full forward pass
  • Ran text generation through tunix.generate.sampler
  • Verified HF-vs-Tunix short-prompt float32 parity
  • Ran MMLU-Pro-style HF-vs-Tunix parity probes

Selected parity evidence:

Short prompt float32 last-token logits cosine: 0.9999985
Top-50 overlap: 50/50
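
For readers unfamiliar with the two metrics above, here is a minimal sketch of how a last-token logits cosine and a top-k overlap might be computed. It is not the script used for this PR; the function names and the toy logits are illustrative assumptions, and a real comparison would run over the full vocabulary with k=50.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length logit vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k_overlap(a, b, k):
    """Number of token ids that appear in the top-k of both vectors."""
    top_a = set(sorted(range(len(a)), key=a.__getitem__, reverse=True)[:k])
    top_b = set(sorted(range(len(b)), key=b.__getitem__, reverse=True)[:k])
    return len(top_a & top_b)


# Toy logits standing in for the two frameworks' last-token outputs.
hf_logits = [2.0, -1.0, 0.5, 3.0, 0.0]
tunix_logits = [2.001, -1.0, 0.499, 3.0, 0.001]

print(cosine_similarity(hf_logits, tunix_logits))
print(top_k_overlap(hf_logits, tunix_logits, k=3))
```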

500-sample MMLU-Pro-style parity probe:

Prediction match: 481/500 = 96.2%
Wilson 95% CI: 94.14% - 97.55%
HF accuracy: 75/500 = 15.0%
Tunix accuracy: 75/500 = 15.0%
Delta: 0.0 pp

Note: this is a benchmark-style parity probe, not an official MMLU-Pro benchmark reproduction.
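
The Wilson interval reported above can be checked with a few lines of standard-library Python. This is a sketch of the textbook Wilson score formula, assuming z = 1.96 for a 95% interval; the PR does not say which tool or z value produced the reported numbers.

```python
import math


def wilson_interval(successes, n, z=1.96):
    """Wilson score interval for a binomial proportion."""
    p = successes / n
    denom = 1 + z * z / n
    center = (p + z * z / (2 * n)) / denom
    half = z * math.sqrt(p * (1 - p) / n + z * z / (4 * n * n)) / denom
    return center - half, center + half


# 481 matching predictions out of 500 samples, as reported above.
low, high = wilson_interval(481, 500)
print(f"{low:.2%} - {high:.2%}")
```

With these inputs the bounds agree with the reported 94.14% - 97.55% to two decimal places.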

Generation smoke tests:

  • Arithmetic
  • Explanation
  • Code generation
  • Structured extraction
  • Gemma 4 thinking-template prompt

Tunix generated coherent text outputs and matched HF parsed outputs on representative text-only samples.

Scope

This PR does not add full Gemma 4 Unified multimodal support.

The current Tunix Gemma4 model path supports text tokens only. The unified checkpoint contains multimodal weights such as:

model.embed_audio.embedding_projection.weight
model.embed_vision.embedding_projection.weight
model.vision_embedder.*

These are intentionally out of scope for this PR and are not loaded by the current text-backbone implementation.
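
As an illustration of what skipping these weights could look like, the sketch below filters checkpoint key names before loading. It is not Tunix's loader: `is_text_backbone_key` is a hypothetical helper, and the two sample keys marked hypothetical are made up for the example. Only the patterns come from the names listed above.

```python
from fnmatch import fnmatchcase

# Patterns copied from the multimodal weight names listed above.
MULTIMODAL_PATTERNS = (
    "model.embed_audio.embedding_projection.weight",
    "model.embed_vision.embedding_projection.weight",
    "model.vision_embedder.*",
)


def is_text_backbone_key(key: str) -> bool:
    """Return True if the key matches none of the multimodal patterns."""
    return not any(fnmatchcase(key, pat) for pat in MULTIMODAL_PATTERNS)


keys = [
    "model.embed_audio.embedding_projection.weight",
    "model.embed_vision.embedding_projection.weight",
    "model.vision_embedder.example.weight",  # hypothetical key name
    "model.language_model.example.weight",  # hypothetical key name
]
text_keys = [k for k in keys if is_text_backbone_key(k)]
print(text_keys)
```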

@google-cla

google-cla Bot commented Jun 9, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@deep-diver force-pushed the add-gemma4-12b-support branch from 9f259ee to 490a1b6 on June 9, 2026 05:34
@tianshub

Collaborator

@deep-diver Thanks a lot for adding the support. A follow-up question: how is the MMLU result being measured? Is it using Tunix's sampler? Meanwhile, can you also include a sampling output in the test plan?
