Skip to content

[Tunix] Add multi-modal (vision) support to Gemma 4.#1608

Open
msghik wants to merge 17 commits into
google:mainfrom
msghik:gemma4-vision-multimodal
Open

[Tunix] Add multi-modal (vision) support to Gemma 4.#1608
msghik wants to merge 17 commits into
google:mainfrom
msghik:gemma4-vision-multimodal

Conversation

@msghik

@msghik msghik commented Jun 21, 2026

Copy link
Copy Markdown

Resolves #1543

Enables SFT / LoRA / RL of Gemma 4 with text + image inputs, mirroring the SigLIP-based integration already present in Gemma 3. The training and sampling code paths (peft_trainer, sampler, rl/common) already forward an images kwarg to the model — they just had no Gemma 4 model that accepted it. This change closes that gap at the model level so they work end-to-end on Gemma 4.

What changed

tunix/models/gemma4/model.py

ModelConfig gains an optional vision_config: SigLIPConfig | None, exposed via a text_only: bool = True parameter on the gemma4_e2b / gemma4_e4b / gemma4_31b / gemma4_26b_a4b factory classmethods (matches the existing gemma3_*_pt/it(text_only=...) pattern).
ShardingConfig gains a siglip field for the vision encoder's sharding.
RMSNorm now accepts either a ShardingConfig (existing call sites) or a sharding tuple (new vision call site) — fully backward compatible.
Embedder optionally builds the mm_input_projection and mm_soft_embedding_norm layers and exposes encode_vision.
Gemma4.init constructs a SigLiP encoder when vision_config is set; Gemma4.call accepts an images kwarg and merges the soft vision tokens into the text embeddings at the placeholder positions.
Adds get_attention_mask (bidirectional over image spans, causal over text) and get_model_input (used for LoRA tracing) on the model.
tunix/models/gemma4/params_safetensors.py

_get_key_and_transform_mapping adds vision-tower and multi-modal projector mappings when vision_config is set (no change for text-only configs).
tests/models/gemma4/model_test.py

4 new tests: multi-modal forward pass, text-only construction has no vision encoder, helpful error when images are passed to a text-only model, and the text-only attention-mask shape.
Design notes

The SigLIP encoder (vision.py), the embedding-merge helper (merge_embeddings.py), and the bidirectional-causal attention mask (utils.py) are reused from tunix.models.gemma3 to avoid ~900 lines of duplication; they have no Gemma 3-specific assumptions. Happy to refactor these into a shared tunix.models.common namespace if reviewers prefer.
The change is strictly additive: vision_config defaults to None, images defaults to None. Existing Gemma 4 call sites and tests are unaffected.
Usage

from tunix.models.gemma4 import model as model_lib
from tunix.models.gemma4 import params_safetensors

config = model_lib.ModelConfig.gemma4_e4b(text_only=False) # enables SigLIP
model = params_safetensors.create_model_from_safe_tensors(
file_dir=ckpt_path, config=config, mesh=mesh, dtype=jnp.bfloat16,
)

logits, _ = model(
tokens,
positions=positions,
attention_mask=model.get_attention_mask(tokens),
images=images,
)

msghik added 17 commits May 27, 2026 20:14
Enables SFT/LoRA/RL of Gemma 4 with text+image inputs, mirroring the
SigLIP-based integration already present in Gemma 3. The training and
sampling infrastructures (peft_trainer, sampler, rl/common) already
forward an `images` kwarg to the model; this change closes the gap at
the model level so they actually work end-to-end on Gemma 4.

Specifically:
- `ModelConfig` gains an optional `vision_config: SigLIPConfig | None`,
  exposed via a `text_only=True` parameter on the e2b/e4b/31b/26b-a4b
  factory classmethods (matching gemma3_*_pt/it).
- `ShardingConfig` gains a `siglip` field for the encoder's sharding.
- `Embedder` optionally builds the `mm_input_projection` and
  `mm_soft_embedding_norm` layers and exposes `encode_vision`.
- `Gemma4.__init__` constructs a `SigLiP` encoder when a vision config
  is set; `__call__` accepts an `images` kwarg and merges the soft
  vision tokens into the text embeddings at the placeholder positions.
- Adds `get_attention_mask` (bidirectional over image spans) and
  `get_model_input` (used for LoRA tracing) on the model.
- `params_safetensors._get_key_and_transform_mapping` adds vision-tower
  and multi-modal projector mappings when vision is enabled.

The SigLIP encoder, embedding-merge utility and attention-mask helper
are reused from `tunix.models.gemma3` to avoid ~900 lines of
duplication; they have no Gemma 3-specific assumptions.

Tests: 4 new tests cover text-only construction, multi-modal forward
pass, the helpful error when images are passed without a vision encoder,
and the text-only attention mask shape. All existing Gemma 3/4 tests
continue to pass.
Add multi-modal (vision) support to Gemma 4
Enables SFT/LoRA/RL of Gemma 4 with text+image inputs, mirroring the
SigLIP-based integration already present in Gemma 3. The training and
sampling infrastructures (peft_trainer, sampler, rl/common) already
forward an `images` kwarg to the model; this change closes the gap at
the model level so they actually work end-to-end on Gemma 4.

Specifically:
- `ModelConfig` gains an optional `vision_config: SigLIPConfig | None`,
  exposed via a `text_only=True` parameter on the e2b/e4b/31b/26b-a4b
  factory classmethods (matching gemma3_*_pt/it).
- `ShardingConfig` gains a `siglip` field for the encoder's sharding.
- `Embedder` optionally builds the `mm_input_projection` and
  `mm_soft_embedding_norm` layers and exposes `encode_vision`.
- `Gemma4.__init__` constructs a `SigLiP` encoder when a vision config
  is set; `__call__` accepts an `images` kwarg and merges the soft
  vision tokens into the text embeddings at the placeholder positions.
- Adds `get_attention_mask` (bidirectional over image spans) and
  `get_model_input` (used for LoRA tracing) on the model.
- `params_safetensors._get_key_and_transform_mapping` adds vision-tower
  and multi-modal projector mappings when vision is enabled.

The SigLIP encoder, embedding-merge utility and attention-mask helper
are reused from `tunix.models.gemma3` to avoid ~900 lines of
duplication; they have no Gemma 3-specific assumptions.

Tests: 4 new tests cover text-only construction, multi-modal forward
pass, the helpful error when images are passed without a vision encoder,
and the text-only attention mask shape. All existing Gemma 3/4 tests
continue to pass.
…pt it.

Fixes google#1539.

Commit 49b63f7 ("Agentic GRPO improvements") started passing a derived
non-pad mask as ``segment_ids`` to the model from rl/common.py for every
RL training run. Only Qwen3's ``__call__`` was updated to accept the
keyword; Gemma 2/3/4 and Llama 3 still don't, so GRPO training with
those reference/policy models crashes with::

  TypeError: Gemma3.__call__() got an unexpected keyword argument
    'segment_ids'

at the model call site in ``compute_per_token_logps``.

A later commit (9b4a4c6) added an inline ``inspect.signature``
workaround in ``compute_per_token_logps`` only. ``compute_score`` was
missed and still unconditionally forwarded the auto-derived
``input_seg_ids``, so reward-model paths in GRPO/GSPO remain broken
even on HEAD.

This change:

- Extracts ``_model_accepts_segment_ids(model)`` so the
  signature-introspection logic lives in one place, with a docstring
  explaining why this gate exists (until every model accepts
  ``segment_ids`` natively).
- Applies the gate consistently in both ``compute_per_token_logps`` and
  ``compute_score``: caller-supplied ``segment_ids`` is always passed
  through (matches pre-bug behavior); the auto-derived
  ``input_seg_ids`` is only forwarded to models whose signature accepts
  it.
- Adds regression tests using mock modules whose ``__call__`` does NOT
  declare ``segment_ids``, which is the shape of the failure mode for
  Gemma 2/3/4 / Llama. The existing tests only used
  ``ToyTransformer``, whose ``__call__`` already accepts
  ``segment_ids`` — which is how this slipped past CI.

Manually verified that ``compute_per_token_logps`` no longer raises
when called with a real ``Gemma3`` reference model:

  graphdef, state = nnx.split(Gemma3(ModelConfig.gemma3_270m(),
                                     rngs=nnx.Rngs(0)))
  common.compute_per_token_logps(
      graphdef, state, prompt, completion, pad_id=0, eos_id=-1)
…'t accept it (Gemma/Llama)

Fix TypeError: don't pass auto-derived segment_ids to models that don't accept it (Gemma/Llama)
…pes)

The earlier multi-modal PR reused Gemma 3's SigLIP encoder, but the released
google/gemma-4-e2b-it checkpoint declares its own vision architecture
(model_type: gemma4_vision): gated MLPs, 2D RoPE, per-head q/k/v RMSNorm, a
4-norm sandwich block, a factored 2D position table, and a spatial pooler, plus
a single-tensor embed_vision projector (no multi_modal_projector). SigLIP cannot
load or run these weights.

This is Stage 1 of a faithful from-source port of HF
transformers.models.gemma4.modeling_gemma4:

  * tunix/models/gemma4/vision_real.py - nnx port of Gemma4VisionModel
    (PatchEmbedder, 2D RotaryEmbedding, Attention with q/k/v-norm, gated MLP,
    sandwich EncoderLayer, Encoder, Pooler) and the embed_vision projector
    (Gemma4MultimodalEmbedder). The module tree mirrors the HF/checkpoint names
    so weights map 1:1 (linears transpose, norms direct).
  * tests/models/gemma4/vision_real_test.py - build + forward-shape tests
    (144 patches -> 16 soft tokens -> projected to text dim 1536) and a
    param-path coverage test pinning the nnx side of the checkpoint key map.
  * docs/gemma4_vision_port.md - reverse-engineered spec, the full
    checkpoint->nnx key mapping, Gemma-4-specific gotchas (RMSNorm scales by
    weight not 1+weight; ClippableLinear .linear. nesting; attention scaling
    1.0), and the staged plan.

STATUS: wiring and shapes are verified; NUMERIC PARITY with HF is NOT yet
validated (needs torch + the real checkpoint -- Stage 3). Audio tower is out of
scope; the loader should skip model.audio_tower.* / model.embed_audio.* keys.
Adds the pieces needed to load real google/gemma-4-*-it vision weights and feed
images to the Stage-1 tower:

  * tunix/models/gemma4/image_processing.py - NumPy port of HF
    Gemma4ImageProcessor: aspect-ratio target size, patchify to
    [B, P, 3*patch^2] (channel-last in patch), (x, y) position ids row-major to
    match patch order, padding to max_patches with positions=-1. Resize uses PIL
    (documented as not bit-exact with torchvision antialias; the parity harness
    should reuse identical pixel_values to isolate vision-tower numerics).
  * tunix/models/gemma4/vision_real.py - Gemma4VisionStack (vision_tower +
    embed_vision under the checkpoint's own submodule names) and a rotary
    embedding that recomputes inv_freq per call so no derived buffer lands in
    the param tree (which would otherwise break the loader).
  * tunix/models/gemma4/vision_params_safetensors.py - vision_key_mapping
    (optional model. prefix, $-anchored, exactly-one-match) mapping
    vision_tower/embed_vision keys onto the stack (linears transpose, scaled
    norms rename weight->scale; v_norm and the projector pre-norm are unscaled
    and absent; audio_tower/embed_audio and e2b clip buffers are intentionally
    skipped), plus create_vision_stack_from_safe_tensors.

Tests (all passing, no real weights needed):
  * vision_params_safetensors_test.py - every real vision key maps to an
    existing param; linears transpose and norms don't; audio + clip buffers are
    skipped; no double-matches; every loadable param is covered (no
    uninitialised params).
  * image_processing_test.py - patchify shape/order, xy position ids, padding,
    and processor-output -> Gemma4VisionStack end-to-end shapes.

STATUS: structure + key coverage verified. Numeric parity with HF and exact
resize parity remain (Stage 3, needs torch + the checkpoint).
examples/gemma4/vision_parity_check.py is a self-contained CLI that compares
per-layer activations of this JAX port against the HuggingFace torch reference:

  * Torch side: selective load_state_dict of just model.vision_tower.* and
    model.embed_vision.* from the real checkpoint, so it doesn't drag the 2.5B
    language_model + audio_tower into memory.
  * JAX side: vision_params_safetensors.create_vision_stack_from_safe_tensors
    loads the same weights into Gemma4VisionStack, then an instrumented forward
    captures activations at each tap.

Inputs are synthetic, all-valid (no padding), and shared bit-exactly between
the two stacks, so the harness isolates port numerics from torchvision-vs-PIL
resize differences. All taps are cast to fp32 before diffing.

Reported per checkpoint (after_patch_embed, after_layer_NN for each layer,
after_pool, after_vision_tower_gathered, after_projector, + an exact equality
check on the pool mask): max-abs / mean-abs / median-relative diff. Default
failure threshold 5e-2 (set for bf16; tighten if both stacks run fp32). The
process exits 0 iff every checkpoint passes.

Verified in this commit only what can be verified without torch + the real
weights: a dry-run of the JAX-side instrumented forward on random init produces
the expected shapes at every tap (1x36 patches -> 4 soft tokens -> 4x1536 after
the projector) and the _diff_stats helper returns sane numbers. The real parity
run requires the user to:

  pip install torch transformers safetensors
  python examples/gemma4/vision_parity_check.py --ckpt ~/gemma4-e2b

Until that returns "PARITY PASSED" on a real checkpoint, this port should not
be claimed to work.
… HF)

Adds a checkpoint-free parity harness and records the validation result.

examples/gemma4/vision_parity_random_weights.py builds a small HF
Gemma4VisionModel + Gemma4MultimodalEmbedder with random weights, serializes
them under real checkpoint key names, loads them into the JAX Gemma4VisionStack
through the production loader (vision_params_safetensors), and compares
per-layer activations in fp32. It validates the port MATH + key mapping +
loader against the HF reference WITHOUT needing a model download.

Result (transformers 5.9.0 + torch 2.12, fp32, 4 layers):
  after_patch_embed  max_abs=0.000e+00
  after_layer_00..03 max_abs ~1e-6
  tower              max_abs=1.526e-05
  proj               max_abs=3.576e-07
  => PARITY PASSED

This proves the gemma4_vision port (patch embedder with factored 2D position
embeddings, 2D RoPE, q/k/v-norm attention, gated MLP, sandwich norms, spatial
pooler) and the embed_vision projector are numerically equivalent to HF. The
real-checkpoint harness (vision_parity_check.py) additionally validates loading
the actual .safetensors files and is the last step before Stage 4 wiring.

Note: transformers 5.10.1 + some torch builds hit a torch._dynamo "Duplicate
dispatch rule" import error; transformers==5.9.0 is a known-good combo (and is
the version this port was reverse-engineered from).
Composes the text Gemma4 with the ported Gemma4VisionStack for end-to-end
multimodal forward + a greedy caption demo.

  * model.py: Gemma4.__call__ gains a non-breaking `input_embeddings` override
    so a multimodal wrapper can inject merged (text + vision) embeddings without
    touching the layer loop or the legacy SigLIP path.
  * multimodal.py: Gemma4Multimodal embeds tokens, runs the vision stack, and
    scatters the projected soft tokens into `tokens == image_token_id` positions
    (HF masked_scatter equivalent via merge_embeddings), then runs the text
    transformer on the merged embeddings. Bidirectional attention over each
    image's soft-token span. Plus create_multimodal_from_safe_tensors, which
    loads text + vision from one checkpoint (each loader skips the other's keys).
  * examples/gemma4/multimodal_generate.py: single-image, no-padding, eager
    greedy caption demo for a real checkpoint.

Tests (multimodal_test.py, JAX-only) verify the merge places soft tokens exactly
at image positions, leaves text positions unchanged, that a different image
changes downstream logits, and the mask is bidirectional over the image span.
A sandbox dry-run exercises the prompt builder, unpadded-patch slicing, and the
greedy loop on a small random model. 26/26 gemma4 tests pass.

Honest limits (documented in docs/gemma4_vision_port.md): no full-model numeric
parity vs HF Gemma4Model.forward yet (per-layer-input behavior on merged
embeddings unverified); single non-padded image only (merge assumes
#valid-soft-tokens == #placeholders); real caption requires the checkpoint.
Closes the per-layer-input question raised against HF Gemma4Model.forward and
adds a full-model parity harness.

Findings (validated by examples/gemma4/multimodal_parity_random_weights.py,
which transfers a tiny random HF Gemma4ForConditionalGeneration into Tunix via
the production loader):

  * PLE token-identity branch: bit-exact (max=0) once we substitute
    image_token_id -> pad_token_id in the ids handed to embed_tokens_per_layer.
    HF Gemma4Model.forward calls get_per_layer_inputs(llm_input_ids, ...) which
    ignores its inputs_embeds arg; the context-projection branch (inside
    project_per_layer_inputs) then runs on the MERGED inputs_embeds (vision at
    image positions). Gemma4Multimodal._compute_per_layer_inputs mirrors this
    exactly.
  * Bidirectional-vs-causal mask: HF Gemma4TextConfig.use_bidirectional_attention
    == "vision" toggles the Gemma-3-style bidirectional mask over each image's
    soft-token span; smaller checkpoints default to causal. Gemma4Multimodal now
    exposes a `bidirectional_image_span` flag (default False = HF small-model
    behavior) and `create_multimodal_from_safe_tensors` accepts/forwards it.
  * Residual divergence at text positions in the parity harness (~9e-2 max)
    reproduces in a pure-text parity (same HF weights, no image, no wrapper),
    so it is a pre-existing Tunix-vs-HF divergence in the gemma4 text-model
    arithmetic -- not introduced by the vision port and out of scope here.

Changes:
  * model.py: Gemma4.__call__ accepts an optional `per_layer_inputs=` override
    so a multimodal wrapper can precompute HF-style PLE. Skips the internal
    computation when provided.
  * multimodal.py: _compute_per_layer_inputs substitutes image_token_id with
    pad_token_id ONLY in the token-identity branch; the context branch sees
    the merged embeddings (matches HF). Adds the `bidirectional_image_span`
    flag (default False) and threads pad_token_id through the constructor +
    create_multimodal_from_safe_tensors.
  * tests: locks in (a) PLE pad-substitution matches encode_per_layer_input
    with substituted ids and the substitution is actually necessary (image
    positions differ from non-substituted PLE), and (b) the bidirectional flag
    flips the mask shape.
  * examples/gemma4/multimodal_parity_random_weights.py: checkpoint-free
    full-model harness. Result on this commit: PLE bit-exact, BOS 1.6e-7,
    image-position logit diff ~4e-2 (down 5x from before the PLE fix).

28/28 tests in tests/models/gemma4/ pass.
Replace the eager non-JIT greedy loop (one full forward per token, growing
sequence = JIT recompile every step) with a proper prefill/decode split:

- _prefill: nnx.jit, static max_new_tokens, initializes KV cache, runs full
  multimodal forward once (image + prompt), returns first next-token + cache
- _decode_step: nnx.jit, single token, uses model.text_model directly
  (vision stack skipped on decode steps), causal mask over cache positions
- Both are module-level so JIT compilation is cached across calls

On 4x A100 80GB: first call ~1-2 min (compile), subsequent tokens ~5 ms each.
Set config.dtype = config.param_dtype = bfloat16 in multimodal_generate.py
so that the model computes in bfloat16 (matching checkpoint weights) and the
KV cache is initialized with the same dtype as key/value projections.

Previously: config.dtype defaulted to float32, causing a dtype mismatch in
dynamic_update_slice (bfloat16 cache vs float32 value_proj).

Also: derive cache_dtype from model.text_model.config.dtype instead of
hardcoding, so the script is robust to different dtype configurations.
… throughput

Replace per-step nnx.jit decode loop with a single lax.scan dispatch.

Root cause of slow generation (5-16 tok/s on A100 80GB for a 2B model):
- nnx.jit traverses ~500 NNX parameter tensors every decode step to extract
  and restore the module tree, costing ~55ms Python overhead per step.
- Only ~5ms was actual GPU compute; 11x overhead ratio.

Fix: _decode_n_tokens() runs all n_steps inside a single nnx.jit call via
jax.lax.scan, amortising the 55ms Python cost over all tokens.

Benchmark on A100 (30-step decode, Gemma4 2B bfloat16, seq=271, cache=335):
  Per-step nnx.jit loop:  54ms/step  =  18 tok/s
  lax.scan (this fix):     8ms/step  = 131 tok/s   (+7x)
…orrectly

Gemma4 instruction-tuned models use two stop tokens:
  token 1  = true EOS
  token 106 = <end_of_turn> (conversation turn separator)

Previously eos_ids only contained token 1, so the decode loop ran past
<end_of_turn> and produced empty or garbled captions.

Now eos_ids collects both: tokenizer.eos_token_id (handles lists/ints) plus
<end_of_turn> via convert_tokens_to_ids.
Add a tensor-parallel mesh to the multimodal generation demo. Weights already
carry per-array ShardingConfig metadata; the safetensors loader reads it via
nnx.get_named_sharding(mesh) and device_puts each tensor onto its target shards
as it loads. We build a (1, tp) mesh over ('fsdp', 'tp'), switch the text config
to the is_sampling sharding (Megatron-style: residual replicated, inner
attn/FFN dims sharded along 'tp'), and run generation inside the mesh context so
the activation-level shard() constraints resolve.

Default is --tp-size 1 (single device). Measured warmed-up decode on 4x A100
(Gemma4 2B bf16, batch=1, seq=271, cache=335):

  TP=1:  136 tok/s   (7.4 ms/token)
  TP=4:   46 tok/s  (21.8 ms/token)   <- 3x SLOWER

At batch=1 each token incurs ~70 sequential all-reduces (2 per layer x 35
layers); the collective launch latency dominates and the matmuls become too
small to use the extra GPUs efficiently. Since the model fits on one device,
single-GPU is optimal. TP is only a win for models too large for one GPU or at
large batch sizes, so it is left opt-in via --tp-size.
…e 31B

Make the multimodal demo work across Gemma4 sizes instead of hardcoding E2B.
The text config is now auto-selected from the checkpoint, keyed by
(num_hidden_layers, hidden_size):

  (35, 1536) -> E2B      (42, 2560) -> E4B
  (60, 5376) -> 31B      (30, 2816) -> 26B-A4B (MoE)

Unknown variants (e.g. the 12B: 48 layers / 3840 hidden) exit with a clear
message — there is no tunix ModelConfig preset for it yet.

Validated google/gemma-4-31B-it end-to-end on 4x A100 with --tp-size 4:
  - 62.5 GB sharded across 4 GPUs, loaded in 57s (cannot fit one 80GB GPU's
    default 75% preallocation, so sharding is exercised for real)
  - 60-layer multimodal prefill + lax.scan decode produce a correct caption
  - exercises the 31B-only attention path (32 heads / 16 KV / 4 global-KV,
    k_eq_v_global=True) that E2B never touches

Also switch `with mesh:` to `with jax.set_mesh(mesh):` (the former is
deprecated in JAX 0.10.x).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

MM SFT of Gemma4

2 participants