[Tunix] Add multi-modal (vision) support to Gemma 4.#1545

Open
msghik wants to merge 1 commit into
google:mainfrom
msghik:add-gemma4-vision-support

Conversation

@msghik

@msghik msghik commented May 27, 2026


Resolves #1543

Enables SFT / LoRA / RL of Gemma 4 with text + image inputs, mirroring the SigLIP-based integration already present in Gemma 3. The training and sampling code paths (peft_trainer, sampler, rl/common) already forward an images kwarg to the model — they just had no Gemma 4 model that accepted it. This change closes that gap at the model level so they work end-to-end on Gemma 4.

What changed

tunix/models/gemma4/model.py

  • ModelConfig gains an optional vision_config: SigLIPConfig | None, exposed via a text_only: bool = True parameter on the gemma4_e2b / gemma4_e4b / gemma4_31b / gemma4_26b_a4b factory classmethods (matches the existing gemma3_*_pt/it(text_only=...) pattern).
  • ShardingConfig gains a siglip field for the vision encoder's sharding.
  • RMSNorm now accepts either a ShardingConfig (existing call sites) or a sharding tuple (new vision call site) — fully backward compatible.
  • Embedder optionally builds the mm_input_projection and mm_soft_embedding_norm layers and exposes encode_vision.
  • Gemma4.__init__ constructs a SigLiP encoder when vision_config is set; Gemma4.__call__ accepts an images kwarg and merges the soft vision tokens into the text embeddings at the placeholder positions.
  • Adds get_attention_mask (bidirectional over image spans, causal over text) and get_model_input (used for LoRA tracing) on the model.
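To make the mask semantics concrete, here is a minimal NumPy sketch of a mask that is causal over text and bidirectional within each contiguous image span. It is not the Tunix implementation (the PR reuses the Gemma 3 helper for that); the function name `make_causal_bidirectional_mask` and the placeholder id `IMAGE_TOKEN_ID` are assumptions made for illustration, and padding tokens are ignored.

```python
import numpy as np

IMAGE_TOKEN_ID = 99  # hypothetical placeholder id, not the real Gemma 4 value


def make_causal_bidirectional_mask(tokens: np.ndarray) -> np.ndarray:
    """Returns a [B, L, L] bool mask; True means query i may attend to key j."""
    _, length = tokens.shape
    # Standard causal mask: each position sees itself and earlier positions.
    causal = np.tril(np.ones((length, length), dtype=bool))[None, :, :]

    is_image = tokens == IMAGE_TOKEN_ID  # [B, L]
    # Number each contiguous image span 1, 2, ...; text positions get 0.
    prev_is_image = np.pad(is_image, ((0, 0), (1, 0)))[:, :-1]
    span_start = is_image & ~prev_is_image
    span_id = np.cumsum(span_start, axis=1) * is_image  # [B, L]

    # Two positions attend to each other in both directions
    # only if they belong to the same image span.
    same_span = (span_id[:, :, None] == span_id[:, None, :]) & (
        span_id[:, :, None] > 0
    )
    return causal | same_span


tokens = np.array([[1, 99, 99, 2, 99, 3]])
mask = make_causal_bidirectional_mask(tokens)
```

In this toy sequence, position 1 may attend to position 2 because both sit in the same image span, while text position 0 still cannot see any later position.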

tunix/models/gemma4/params_safetensors.py

  • _get_key_and_transform_mapping adds vision-tower and multi-modal projector mappings when vision_config is set (no change for text-only configs).

tests/models/gemma4/model_test.py

  • 4 new tests: multi-modal forward pass, text-only construction has no vision encoder, helpful error when images are passed to a text-only model, and the text-only attention-mask shape.
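As an illustration of the "helpful error" case, the following sketch uses a toy stand-in rather than the real `Gemma4` class. The exception type (`ValueError`) and the message wording are assumptions; the PR description only states that an informative error is raised.

```python
class ToyModel:
    """Stand-in for a model with an optional vision encoder (not Gemma4)."""

    def __init__(self, text_only: bool = True):
        # A real model would build the SigLIP encoder here.
        self.vision_encoder = None if text_only else object()

    def __call__(self, tokens, images=None):
        if images is not None and self.vision_encoder is None:
            raise ValueError(
                "images were provided, but this model was built with "
                "text_only=True; construct the config with text_only=False."
            )
        return len(tokens)  # placeholder for the real forward pass


def raises_value_error(fn) -> bool:
    """Returns True if calling fn raises ValueError."""
    try:
        fn()
    except ValueError:
        return True
    return False


text_only_rejects_images = raises_value_error(
    lambda: ToyModel(text_only=True)([1, 2], images=[0])
)
```

Failing fast at call time keeps a misconfigured training run from silently dropping its image inputs.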

Design notes

  • The SigLIP encoder (vision.py), the embedding-merge helper (merge_embeddings.py), and the bidirectional-causal attention mask (utils.py) are reused from tunix.models.gemma3 to avoid ~900 lines of duplication; they have no Gemma 3-specific assumptions. Happy to refactor these into a shared tunix.models.common namespace if reviewers prefer.
  • The change is strictly additive: vision_config defaults to None, images defaults to None. Existing Gemma 4 call sites and tests are unaffected.
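For readers unfamiliar with the embedding-merge step, here is a minimal NumPy sketch of the idea: soft vision tokens overwrite the text embeddings at placeholder positions, in order. It is not the reused `merge_embeddings.py` helper; the function name, the placeholder id, and the assumption that each row contains exactly as many placeholders as soft tokens are all hypothetical.

```python
import numpy as np

IMAGE_TOKEN_ID = 99  # hypothetical placeholder id, not the real Gemma 4 value


def merge_vision_embeddings(
    text_emb: np.ndarray, vision_emb: np.ndarray, tokens: np.ndarray
) -> np.ndarray:
    """Merges soft tokens into text embeddings.

    text_emb: [B, L, D], vision_emb: [B, N, D], tokens: [B, L].
    The k-th placeholder in each row receives the k-th soft token.
    """
    is_image = tokens == IMAGE_TOKEN_ID  # [B, L]
    # Index of each placeholder among the placeholders of its row: 0, 1, ...
    slot = np.cumsum(is_image, axis=1) - 1
    # Clip so text positions (slot == -1) index validly; they are masked below.
    slot = np.clip(slot, 0, vision_emb.shape[1] - 1)
    gathered = np.take_along_axis(vision_emb, slot[:, :, None], axis=1)
    return np.where(is_image[:, :, None], gathered, text_emb)


tokens = np.array([[1, 99, 99, 2]])
text_emb = np.zeros((1, 4, 2))
vision_emb = np.array([[[1.0, 1.0], [2.0, 2.0]]])
merged = merge_vision_embeddings(text_emb, vision_emb, tokens)
```

Because the merge is a gather followed by a select, the output shape matches the text embeddings and the rest of the decoder needs no changes.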

Usage

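import jax.numpy as jnp

from tunix.models.gemma4 import model as model_lib
from tunix.models.gemma4 import params_safetensors

# ckpt_path, mesh, tokens, positions, and images are assumed to be defined by the caller.
config = model_lib.ModelConfig.gemma4_e4b(text_only=False)  # enables SigLIP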
model = params_safetensors.create_model_from_safe_tensors(
    file_dir=ckpt_path, config=config, mesh=mesh, dtype=jnp.bfloat16,
)

logits, _ = model(
    tokens,
    positions=positions,
    attention_mask=model.get_attention_mask(tokens),
    images=images,
)

Enables SFT/LoRA/RL of Gemma 4 with text+image inputs, mirroring the
SigLIP-based integration already present in Gemma 3. The training and
sampling infrastructures (peft_trainer, sampler, rl/common) already
forward an `images` kwarg to the model; this change closes the gap at
the model level so they actually work end-to-end on Gemma 4.

Specifically:
- `ModelConfig` gains an optional `vision_config: SigLIPConfig | None`,
  exposed via a `text_only=True` parameter on the e2b/e4b/31b/26b-a4b
  factory classmethods (matching gemma3_*_pt/it).
- `ShardingConfig` gains a `siglip` field for the encoder's sharding.
- `Embedder` optionally builds the `mm_input_projection` and
  `mm_soft_embedding_norm` layers and exposes `encode_vision`.
- `Gemma4.__init__` constructs a `SigLiP` encoder when a vision config
  is set; `__call__` accepts an `images` kwarg and merges the soft
  vision tokens into the text embeddings at the placeholder positions.
- Adds `get_attention_mask` (bidirectional over image spans) and
  `get_model_input` (used for LoRA tracing) on the model.
- `params_safetensors._get_key_and_transform_mapping` adds vision-tower
  and multi-modal projector mappings when vision is enabled.

The SigLIP encoder, embedding-merge utility and attention-mask helper
are reused from `tunix.models.gemma3` to avoid ~900 lines of
duplication; they have no Gemma 3-specific assumptions.

Tests: 4 new tests cover text-only construction, multi-modal forward
pass, the helpful error when images are passed without a vision encoder,
and the text-only attention mask shape. All existing Gemma 3/4 tests
continue to pass.
@tianshub

tianshub commented Jun 2, 2026

Collaborator

Hi @msghik thanks for adding the vision support. Can you paste a sample output for the multi-modal change?


@msghik

msghik commented Jun 21, 2026

Author

Hi @tianshub, sorry for the late response; I have been busy for the past couple of weeks. I found a bug in this pull request that makes inference slow on my A100 GPU, at roughly 3 tokens/sec. I have fixed the issue and opened another pull request (Link); I believe we should continue our conversation there.

Regarding your question:

Here's an end-to-end sample from the multimodal path: google/gemma-4-E2B-it captioning a real photo (the standard COCO image of two cats sleeping on a pink couch with two TV remotes):

$ python examples/gemma4/multimodal_generate.py \
    --ckpt ~/gemma-4-E2B-it --image cats.jpg \
    --prompt "What is in this image?"

Detected Gemma4 text variant: E2B
Image: 266 soft tokens (2394 raw patches)
Prompt: 'What is in this image?' (prompt length: 273 tokens)

=== Caption ===
This image contains two tabby cats lying on a pink surface. There are also
some remote controls visible in the background.

The image features, text prompt, and KV cache all flow through correctly — it picks up both cats, the pink surface, and the remotes. With a more detailed prompt it
gives a full paragraph (per-cat coat patterns, positions, etc.).

I've also run the same code path on the large google/gemma-4-31B-it (60 layers) with 4-way tensor parallelism across 4×A100, so it works end-to-end across the Gemma4
size range.


Development

Successfully merging this pull request may close these issues.

MM SFT of Gemma4

3 participants