[Tunix] Add multi-modal (vision) support to Gemma 4.#1608
Open
msghik wants to merge 17 commits into
Open
Conversation
Enables SFT/LoRA/RL of Gemma 4 with text+image inputs, mirroring the SigLIP-based integration already present in Gemma 3. The training and sampling infrastructures (peft_trainer, sampler, rl/common) already forward an `images` kwarg to the model; this change closes the gap at the model level so they actually work end-to-end on Gemma 4. Specifically: - `ModelConfig` gains an optional `vision_config: SigLIPConfig | None`, exposed via a `text_only=True` parameter on the e2b/e4b/31b/26b-a4b factory classmethods (matching gemma3_*_pt/it). - `ShardingConfig` gains a `siglip` field for the encoder's sharding. - `Embedder` optionally builds the `mm_input_projection` and `mm_soft_embedding_norm` layers and exposes `encode_vision`. - `Gemma4.__init__` constructs a `SigLiP` encoder when a vision config is set; `__call__` accepts an `images` kwarg and merges the soft vision tokens into the text embeddings at the placeholder positions. - Adds `get_attention_mask` (bidirectional over image spans) and `get_model_input` (used for LoRA tracing) on the model. - `params_safetensors._get_key_and_transform_mapping` adds vision-tower and multi-modal projector mappings when vision is enabled. The SigLIP encoder, embedding-merge utility and attention-mask helper are reused from `tunix.models.gemma3` to avoid ~900 lines of duplication; they have no Gemma 3-specific assumptions. Tests: 4 new tests cover text-only construction, multi-modal forward pass, the helpful error when images are passed without a vision encoder, and the text-only attention mask shape. All existing Gemma 3/4 tests continue to pass.
Add multi-modal (vision) support to Gemma 4
Enables SFT/LoRA/RL of Gemma 4 with text+image inputs, mirroring the SigLIP-based integration already present in Gemma 3. The training and sampling infrastructures (peft_trainer, sampler, rl/common) already forward an `images` kwarg to the model; this change closes the gap at the model level so they actually work end-to-end on Gemma 4. Specifically: - `ModelConfig` gains an optional `vision_config: SigLIPConfig | None`, exposed via a `text_only=True` parameter on the e2b/e4b/31b/26b-a4b factory classmethods (matching gemma3_*_pt/it). - `ShardingConfig` gains a `siglip` field for the encoder's sharding. - `Embedder` optionally builds the `mm_input_projection` and `mm_soft_embedding_norm` layers and exposes `encode_vision`. - `Gemma4.__init__` constructs a `SigLiP` encoder when a vision config is set; `__call__` accepts an `images` kwarg and merges the soft vision tokens into the text embeddings at the placeholder positions. - Adds `get_attention_mask` (bidirectional over image spans) and `get_model_input` (used for LoRA tracing) on the model. - `params_safetensors._get_key_and_transform_mapping` adds vision-tower and multi-modal projector mappings when vision is enabled. The SigLIP encoder, embedding-merge utility and attention-mask helper are reused from `tunix.models.gemma3` to avoid ~900 lines of duplication; they have no Gemma 3-specific assumptions. Tests: 4 new tests cover text-only construction, multi-modal forward pass, the helpful error when images are passed without a vision encoder, and the text-only attention mask shape. All existing Gemma 3/4 tests continue to pass.
…pt it. Fixes google#1539. Commit 49b63f7 ("Agentic GRPO improvements") started passing a derived non-pad mask as ``segment_ids`` to the model from rl/common.py for every RL training run. Only Qwen3's ``__call__`` was updated to accept the keyword; Gemma 2/3/4 and Llama 3 still don't, so GRPO training with those reference/policy models crashes with:: TypeError: Gemma3.__call__() got an unexpected keyword argument 'segment_ids' at the model call site in ``compute_per_token_logps``. A later commit (9b4a4c6) added an inline ``inspect.signature`` workaround in ``compute_per_token_logps`` only. ``compute_score`` was missed and still unconditionally forwarded the auto-derived ``input_seg_ids``, so reward-model paths in GRPO/GSPO remain broken even on HEAD. This change: - Extracts ``_model_accepts_segment_ids(model)`` so the signature-introspection logic lives in one place, with a docstring explaining why this gate exists (until every model accepts ``segment_ids`` natively). - Applies the gate consistently in both ``compute_per_token_logps`` and ``compute_score``: caller-supplied ``segment_ids`` is always passed through (matches pre-bug behavior); the auto-derived ``input_seg_ids`` is only forwarded to models whose signature accepts it. - Adds regression tests using mock modules whose ``__call__`` does NOT declare ``segment_ids``, which is the shape of the failure mode for Gemma 2/3/4 / Llama. The existing tests only used ``ToyTransformer``, whose ``__call__`` already accepts ``segment_ids`` — which is how this slipped past CI. Manually verified that ``compute_per_token_logps`` no longer raises when called with a real ``Gemma3`` reference model: graphdef, state = nnx.split(Gemma3(ModelConfig.gemma3_270m(), rngs=nnx.Rngs(0))) common.compute_per_token_logps( graphdef, state, prompt, completion, pad_id=0, eos_id=-1)
…'t accept it (Gemma/Llama) Fix TypeError: don't pass auto-derived segment_ids to models that don't accept it (Gemma/Llama)
…pes)
The earlier multi-modal PR reused Gemma 3's SigLIP encoder, but the released
google/gemma-4-e2b-it checkpoint declares its own vision architecture
(model_type: gemma4_vision): gated MLPs, 2D RoPE, per-head q/k/v RMSNorm, a
4-norm sandwich block, a factored 2D position table, and a spatial pooler, plus
a single-tensor embed_vision projector (no multi_modal_projector). SigLIP cannot
load or run these weights.
This is Stage 1 of a faithful from-source port of HF
transformers.models.gemma4.modeling_gemma4:
* tunix/models/gemma4/vision_real.py - nnx port of Gemma4VisionModel
(PatchEmbedder, 2D RotaryEmbedding, Attention with q/k/v-norm, gated MLP,
sandwich EncoderLayer, Encoder, Pooler) and the embed_vision projector
(Gemma4MultimodalEmbedder). The module tree mirrors the HF/checkpoint names
so weights map 1:1 (linears transpose, norms direct).
* tests/models/gemma4/vision_real_test.py - build + forward-shape tests
(144 patches -> 16 soft tokens -> projected to text dim 1536) and a
param-path coverage test pinning the nnx side of the checkpoint key map.
* docs/gemma4_vision_port.md - reverse-engineered spec, the full
checkpoint->nnx key mapping, Gemma-4-specific gotchas (RMSNorm scales by
weight not 1+weight; ClippableLinear .linear. nesting; attention scaling
1.0), and the staged plan.
STATUS: wiring and shapes are verified; NUMERIC PARITY with HF is NOT yet
validated (needs torch + the real checkpoint -- Stage 3). Audio tower is out of
scope; the loader should skip model.audio_tower.* / model.embed_audio.* keys.
Adds the pieces needed to load real google/gemma-4-*-it vision weights and feed
images to the Stage-1 tower:
* tunix/models/gemma4/image_processing.py - NumPy port of HF
Gemma4ImageProcessor: aspect-ratio target size, patchify to
[B, P, 3*patch^2] (channel-last in patch), (x, y) position ids row-major to
match patch order, padding to max_patches with positions=-1. Resize uses PIL
(documented as not bit-exact with torchvision antialias; the parity harness
should reuse identical pixel_values to isolate vision-tower numerics).
* tunix/models/gemma4/vision_real.py - Gemma4VisionStack (vision_tower +
embed_vision under the checkpoint's own submodule names) and a rotary
embedding that recomputes inv_freq per call so no derived buffer lands in
the param tree (which would otherwise break the loader).
* tunix/models/gemma4/vision_params_safetensors.py - vision_key_mapping
(optional model. prefix, $-anchored, exactly-one-match) mapping
vision_tower/embed_vision keys onto the stack (linears transpose, scaled
norms rename weight->scale; v_norm and the projector pre-norm are unscaled
and absent; audio_tower/embed_audio and e2b clip buffers are intentionally
skipped), plus create_vision_stack_from_safe_tensors.
Tests (all passing, no real weights needed):
* vision_params_safetensors_test.py - every real vision key maps to an
existing param; linears transpose and norms don't; audio + clip buffers are
skipped; no double-matches; every loadable param is covered (no
uninitialised params).
* image_processing_test.py - patchify shape/order, xy position ids, padding,
and processor-output -> Gemma4VisionStack end-to-end shapes.
STATUS: structure + key coverage verified. Numeric parity with HF and exact
resize parity remain (Stage 3, needs torch + the checkpoint).
examples/gemma4/vision_parity_check.py is a self-contained CLI that compares
per-layer activations of this JAX port against the HuggingFace torch reference:
* Torch side: selective load_state_dict of just model.vision_tower.* and
model.embed_vision.* from the real checkpoint, so it doesn't drag the 2.5B
language_model + audio_tower into memory.
* JAX side: vision_params_safetensors.create_vision_stack_from_safe_tensors
loads the same weights into Gemma4VisionStack, then an instrumented forward
captures activations at each tap.
Inputs are synthetic, all-valid (no padding), and shared bit-exactly between
the two stacks, so the harness isolates port numerics from torchvision-vs-PIL
resize differences. All taps are cast to fp32 before diffing.
Reported per checkpoint (after_patch_embed, after_layer_NN for each layer,
after_pool, after_vision_tower_gathered, after_projector, + an exact equality
check on the pool mask): max-abs / mean-abs / median-relative diff. Default
failure threshold 5e-2 (set for bf16; tighten if both stacks run fp32). The
process exits 0 iff every checkpoint passes.
Verified in this commit only what can be verified without torch + the real
weights: a dry-run of the JAX-side instrumented forward on random init produces
the expected shapes at every tap (1x36 patches -> 4 soft tokens -> 4x1536 after
the projector) and the _diff_stats helper returns sane numbers. The real parity
run requires the user to:
pip install torch transformers safetensors
python examples/gemma4/vision_parity_check.py --ckpt ~/gemma4-e2b
Until that returns "PARITY PASSED" on a real checkpoint, this port should not
be claimed to work.
… HF) Adds a checkpoint-free parity harness and records the validation result. examples/gemma4/vision_parity_random_weights.py builds a small HF Gemma4VisionModel + Gemma4MultimodalEmbedder with random weights, serializes them under real checkpoint key names, loads them into the JAX Gemma4VisionStack through the production loader (vision_params_safetensors), and compares per-layer activations in fp32. It validates the port MATH + key mapping + loader against the HF reference WITHOUT needing a model download. Result (transformers 5.9.0 + torch 2.12, fp32, 4 layers): after_patch_embed max_abs=0.000e+00 after_layer_00..03 max_abs ~1e-6 tower max_abs=1.526e-05 proj max_abs=3.576e-07 => PARITY PASSED This proves the gemma4_vision port (patch embedder with factored 2D position embeddings, 2D RoPE, q/k/v-norm attention, gated MLP, sandwich norms, spatial pooler) and the embed_vision projector are numerically equivalent to HF. The real-checkpoint harness (vision_parity_check.py) additionally validates loading the actual .safetensors files and is the last step before Stage 4 wiring. Note: transformers 5.10.1 + some torch builds hit a torch._dynamo "Duplicate dispatch rule" import error; transformers==5.9.0 is a known-good combo (and is the version this port was reverse-engineered from).
Composes the text Gemma4 with the ported Gemma4VisionStack for end-to-end
multimodal forward + a greedy caption demo.
* model.py: Gemma4.__call__ gains a non-breaking `input_embeddings` override
so a multimodal wrapper can inject merged (text + vision) embeddings without
touching the layer loop or the legacy SigLIP path.
* multimodal.py: Gemma4Multimodal embeds tokens, runs the vision stack, and
scatters the projected soft tokens into `tokens == image_token_id` positions
(HF masked_scatter equivalent via merge_embeddings), then runs the text
transformer on the merged embeddings. Bidirectional attention over each
image's soft-token span. Plus create_multimodal_from_safe_tensors, which
loads text + vision from one checkpoint (each loader skips the other's keys).
* examples/gemma4/multimodal_generate.py: single-image, no-padding, eager
greedy caption demo for a real checkpoint.
Tests (multimodal_test.py, JAX-only) verify the merge places soft tokens exactly
at image positions, leaves text positions unchanged, that a different image
changes downstream logits, and the mask is bidirectional over the image span.
A sandbox dry-run exercises the prompt builder, unpadded-patch slicing, and the
greedy loop on a small random model. 26/26 gemma4 tests pass.
Honest limits (documented in docs/gemma4_vision_port.md): no full-model numeric
parity vs HF Gemma4Model.forward yet (per-layer-input behavior on merged
embeddings unverified); single non-padded image only (merge assumes
#valid-soft-tokens == #placeholders); real caption requires the checkpoint.
Closes the per-layer-input question raised against HF Gemma4Model.forward and
adds a full-model parity harness.
Findings (validated by examples/gemma4/multimodal_parity_random_weights.py,
which transfers a tiny random HF Gemma4ForConditionalGeneration into Tunix via
the production loader):
* PLE token-identity branch: bit-exact (max=0) once we substitute
image_token_id -> pad_token_id in the ids handed to embed_tokens_per_layer.
HF Gemma4Model.forward calls get_per_layer_inputs(llm_input_ids, ...) which
ignores its inputs_embeds arg; the context-projection branch (inside
project_per_layer_inputs) then runs on the MERGED inputs_embeds (vision at
image positions). Gemma4Multimodal._compute_per_layer_inputs mirrors this
exactly.
* Bidirectional-vs-causal mask: HF Gemma4TextConfig.use_bidirectional_attention
== "vision" toggles the Gemma-3-style bidirectional mask over each image's
soft-token span; smaller checkpoints default to causal. Gemma4Multimodal now
exposes a `bidirectional_image_span` flag (default False = HF small-model
behavior) and `create_multimodal_from_safe_tensors` accepts/forwards it.
* Residual divergence at text positions in the parity harness (~9e-2 max)
reproduces in a pure-text parity (same HF weights, no image, no wrapper),
so it is a pre-existing Tunix-vs-HF divergence in the gemma4 text-model
arithmetic -- not introduced by the vision port and out of scope here.
Changes:
* model.py: Gemma4.__call__ accepts an optional `per_layer_inputs=` override
so a multimodal wrapper can precompute HF-style PLE. Skips the internal
computation when provided.
* multimodal.py: _compute_per_layer_inputs substitutes image_token_id with
pad_token_id ONLY in the token-identity branch; the context branch sees
the merged embeddings (matches HF). Adds the `bidirectional_image_span`
flag (default False) and threads pad_token_id through the constructor +
create_multimodal_from_safe_tensors.
* tests: locks in (a) PLE pad-substitution matches encode_per_layer_input
with substituted ids and the substitution is actually necessary (image
positions differ from non-substituted PLE), and (b) the bidirectional flag
flips the mask shape.
* examples/gemma4/multimodal_parity_random_weights.py: checkpoint-free
full-model harness. Result on this commit: PLE bit-exact, BOS 1.6e-7,
image-position logit diff ~4e-2 (down 5x from before the PLE fix).
28/28 tests in tests/models/gemma4/ pass.
Replace the eager non-JIT greedy loop (one full forward per token, growing sequence = JIT recompile every step) with a proper prefill/decode split: - _prefill: nnx.jit, static max_new_tokens, initializes KV cache, runs full multimodal forward once (image + prompt), returns first next-token + cache - _decode_step: nnx.jit, single token, uses model.text_model directly (vision stack skipped on decode steps), causal mask over cache positions - Both are module-level so JIT compilation is cached across calls On 4x A100 80GB: first call ~1-2 min (compile), subsequent tokens ~5 ms each.
Set config.dtype = config.param_dtype = bfloat16 in multimodal_generate.py so that the model computes in bfloat16 (matching checkpoint weights) and the KV cache is initialized with the same dtype as key/value projections. Previously: config.dtype defaulted to float32, causing a dtype mismatch in dynamic_update_slice (bfloat16 cache vs float32 value_proj). Also: derive cache_dtype from model.text_model.config.dtype instead of hardcoding, so the script is robust to different dtype configurations.
… throughput Replace per-step nnx.jit decode loop with a single lax.scan dispatch. Root cause of slow generation (5-16 tok/s on A100 80GB for a 2B model): - nnx.jit traverses ~500 NNX parameter tensors every decode step to extract and restore the module tree, costing ~55ms Python overhead per step. - Only ~5ms was actual GPU compute; 11x overhead ratio. Fix: _decode_n_tokens() runs all n_steps inside a single nnx.jit call via jax.lax.scan, amortising the 55ms Python cost over all tokens. Benchmark on A100 (30-step decode, Gemma4 2B bfloat16, seq=271, cache=335): Per-step nnx.jit loop: 54ms/step = 18 tok/s lax.scan (this fix): 8ms/step = 131 tok/s (+7x)
…orrectly Gemma4 instruction-tuned models use two stop tokens: token 1 = true EOS token 106 = <end_of_turn> (conversation turn separator) Previously eos_ids only contained token 1, so the decode loop ran past <end_of_turn> and produced empty or garbled captions. Now eos_ids collects both: tokenizer.eos_token_id (handles lists/ints) plus <end_of_turn> via convert_tokens_to_ids.
Add a tensor-parallel mesh to the multimodal generation demo. Weights already
carry per-array ShardingConfig metadata; the safetensors loader reads it via
nnx.get_named_sharding(mesh) and device_puts each tensor onto its target shards
as it loads. We build a (1, tp) mesh over ('fsdp', 'tp'), switch the text config
to the is_sampling sharding (Megatron-style: residual replicated, inner
attn/FFN dims sharded along 'tp'), and run generation inside the mesh context so
the activation-level shard() constraints resolve.
Default is --tp-size 1 (single device). Measured warmed-up decode on 4x A100
(Gemma4 2B bf16, batch=1, seq=271, cache=335):
TP=1: 136 tok/s (7.4 ms/token)
TP=4: 46 tok/s (21.8 ms/token) <- 3x SLOWER
At batch=1 each token incurs ~70 sequential all-reduces (2 per layer x 35
layers); the collective launch latency dominates and the matmuls become too
small to use the extra GPUs efficiently. Since the model fits on one device,
single-GPU is optimal. TP is only a win for models too large for one GPU or at
large batch sizes, so it is left opt-in via --tp-size.
…e 31B
Make the multimodal demo work across Gemma4 sizes instead of hardcoding E2B.
The text config is now auto-selected from the checkpoint, keyed by
(num_hidden_layers, hidden_size):
(35, 1536) -> E2B (42, 2560) -> E4B
(60, 5376) -> 31B (30, 2816) -> 26B-A4B (MoE)
Unknown variants (e.g. the 12B: 48 layers / 3840 hidden) exit with a clear
message — there is no tunix ModelConfig preset for it yet.
Validated google/gemma-4-31B-it end-to-end on 4x A100 with --tp-size 4:
- 62.5 GB sharded across 4 GPUs, loaded in 57s (cannot fit one 80GB GPU's
default 75% preallocation, so sharding is exercised for real)
- 60-layer multimodal prefill + lax.scan decode produce a correct caption
- exercises the 31B-only attention path (32 heads / 16 KV / 4 global-KV,
k_eq_v_global=True) that E2B never touches
Also switch `with mesh:` to `with jax.set_mesh(mesh):` (the former is
deprecated in JAX 0.10.x).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Resolves #1543
Enables SFT / LoRA / RL of Gemma 4 with text + image inputs, mirroring the SigLIP-based integration already present in Gemma 3. The training and sampling code paths (peft_trainer, sampler, rl/common) already forward an images kwarg to the model — they just had no Gemma 4 model that accepted it. This change closes that gap at the model level so they work end-to-end on Gemma 4.
What changed
tunix/models/gemma4/model.py
ModelConfig gains an optional vision_config: SigLIPConfig | None, exposed via a text_only: bool = True parameter on the gemma4_e2b / gemma4_e4b / gemma4_31b / gemma4_26b_a4b factory classmethods (matches the existing gemma3_*_pt/it(text_only=...) pattern).
ShardingConfig gains a siglip field for the vision encoder's sharding.
RMSNorm now accepts either a ShardingConfig (existing call sites) or a sharding tuple (new vision call site) — fully backward compatible.
Embedder optionally builds the mm_input_projection and mm_soft_embedding_norm layers and exposes encode_vision.
Gemma4.init constructs a SigLiP encoder when vision_config is set; Gemma4.call accepts an images kwarg and merges the soft vision tokens into the text embeddings at the placeholder positions.
Adds get_attention_mask (bidirectional over image spans, causal over text) and get_model_input (used for LoRA tracing) on the model.
tunix/models/gemma4/params_safetensors.py
_get_key_and_transform_mapping adds vision-tower and multi-modal projector mappings when vision_config is set (no change for text-only configs).
tests/models/gemma4/model_test.py
4 new tests: multi-modal forward pass, text-only construction has no vision encoder, helpful error when images are passed to a text-only model, and the text-only attention-mask shape.
Design notes
The SigLIP encoder (vision.py), the embedding-merge helper (merge_embeddings.py), and the bidirectional-causal attention mask (utils.py) are reused from tunix.models.gemma3 to avoid ~900 lines of duplication; they have no Gemma 3-specific assumptions. Happy to refactor these into a shared tunix.models.common namespace if reviewers prefer.
The change is strictly additive: vision_config defaults to None, images defaults to None. Existing Gemma 4 call sites and tests are unaffected.
Usage
from tunix.models.gemma4 import model as model_lib
from tunix.models.gemma4 import params_safetensors
config = model_lib.ModelConfig.gemma4_e4b(text_only=False) # enables SigLIP
model = params_safetensors.create_model_from_safe_tensors(
file_dir=ckpt_path, config=config, mesh=mesh, dtype=jnp.bfloat16,
)
logits, _ = model(
tokens,
positions=positions,
attention_mask=model.get_attention_mask(tokens),
images=images,
)