[Tunix] Add multi-modal (vision) support to Gemma 4. #1545
Enables SFT/LoRA/RL of Gemma 4 with text+image inputs, mirroring the SigLIP-based integration already present in Gemma 3. The training and sampling infrastructures (peft_trainer, sampler, rl/common) already forward an `images` kwarg to the model; this change closes the gap at the model level so they actually work end-to-end on Gemma 4.

Specifically:

- `ModelConfig` gains an optional `vision_config: SigLIPConfig | None`, exposed via a `text_only=True` parameter on the e2b/e4b/31b/26b-a4b factory classmethods (matching `gemma3_*_pt/it`).
- `ShardingConfig` gains a `siglip` field for the encoder's sharding.
- `Embedder` optionally builds the `mm_input_projection` and `mm_soft_embedding_norm` layers and exposes `encode_vision`.
- `Gemma4.__init__` constructs a `SigLiP` encoder when a vision config is set; `__call__` accepts an `images` kwarg and merges the soft vision tokens into the text embeddings at the placeholder positions.
- Adds `get_attention_mask` (bidirectional over image spans) and `get_model_input` (used for LoRA tracing) on the model.
- `params_safetensors._get_key_and_transform_mapping` adds vision-tower and multi-modal projector mappings when vision is enabled.

The SigLIP encoder, embedding-merge utility and attention-mask helper are reused from `tunix.models.gemma3` to avoid ~900 lines of duplication; they have no Gemma 3-specific assumptions.

Tests: 4 new tests cover text-only construction, the multi-modal forward pass, the helpful error when images are passed without a vision encoder, and the text-only attention mask shape. All existing Gemma 3/4 tests continue to pass.
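To make the "merges the soft vision tokens into the text embeddings at the placeholder positions" step concrete, here is a minimal NumPy sketch for a single, unbatched sequence. It is an illustration only, not the `merge_embeddings` helper reused from `tunix.models.gemma3`: the function name, the argument layout and the placeholder id `99` are all assumptions made for the example.

```python
import numpy as np


def merge_vision_embeddings(text_emb, vision_emb, tokens, placeholder_id):
    """Return a copy of text_emb with placeholder rows replaced by vision_emb.

    text_emb:       [seq_len, dim] embeddings of the full token sequence
    vision_emb:     [num_soft_tokens, dim] projected vision-encoder outputs
    tokens:         [seq_len] integer token ids
    placeholder_id: id marking image positions (hypothetical value below)
    """
    is_placeholder = tokens == placeholder_id
    if int(is_placeholder.sum()) != vision_emb.shape[0]:
        raise ValueError("placeholder count must equal number of soft tokens")
    merged = text_emb.copy()
    # Boolean-mask assignment fills placeholder rows in sequence order.
    merged[is_placeholder] = vision_emb
    return merged


PLACEHOLDER = 99  # assumption: illustrative id, not the real tokenizer value
tokens = np.array([5, PLACEHOLDER, PLACEHOLDER, 7])
text_emb = np.zeros((4, 2))
vision_emb = np.array([[1.0, 1.0], [2.0, 2.0]])
merged = merge_vision_embeddings(text_emb, vision_emb, tokens, PLACEHOLDER)
```

In this toy input, rows 1 and 2 of `merged` take the two vision vectors while the text rows stay untouched. A real implementation would also need to handle the batch dimension and run under `jax.jit`, which this sketch does not attempt.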
Hi @msghik, thanks for adding the vision support. Can you paste a sample output for the multi-modal change?
Hi @tianshub, sorry for the late response; I have been busy for the past couple of weeks. I found a bug in this pull request that makes inference slow on my A100 GPU, around 3 tokens/sec. I resolved the issue and opened another pull request Link; I believe we should continue our conversation there.

Regarding your question: here's an end-to-end sample from the multimodal path, google/gemma-4-E2B-it captioning a real photo (the standard COCO image, Detected Gemma4 text variant: E2B === Caption === The image features, text prompt, and KV cache all flow through correctly: it picks up both cats, the pink surface, and the remotes. With a more detailed prompt it

I've also run the same code path on the large google/gemma-4-31B-it (60 layers) with 4-way tensor parallelism across 4×A100, so it works end-to-end across the Gemma4
Resolves #1543
Enables SFT / LoRA / RL of Gemma 4 with text + image inputs, mirroring the SigLIP-based integration already present in Gemma 3. The training and sampling code paths (`peft_trainer`, `sampler`, `rl/common`) already forward an `images` kwarg to the model; they just had no Gemma 4 model that accepted it. This change closes that gap at the model level so they work end-to-end on Gemma 4.

What changed

`tunix/models/gemma4/model.py`

- `ModelConfig` gains an optional `vision_config: SigLIPConfig | None`, exposed via a `text_only: bool = True` parameter on the `gemma4_e2b` / `gemma4_e4b` / `gemma4_31b` / `gemma4_26b_a4b` factory classmethods (matches the existing `gemma3_*_pt/it(text_only=...)` pattern).
- `ShardingConfig` gains a `siglip` field for the vision encoder's sharding.
- `RMSNorm` now accepts either a `ShardingConfig` (existing call sites) or a sharding tuple (new vision call site); fully backward compatible.
- `Embedder` optionally builds the `mm_input_projection` and `mm_soft_embedding_norm` layers and exposes `encode_vision`.
- `Gemma4.__init__` constructs a `SigLiP` encoder when `vision_config` is set; `Gemma4.__call__` accepts an `images` kwarg and merges the soft vision tokens into the text embeddings at the placeholder positions.
- Adds `get_attention_mask` (bidirectional over image spans, causal over text) and `get_model_input` (used for LoRA tracing) on the model.

`tunix/models/gemma4/params_safetensors.py`

- `_get_key_and_transform_mapping` adds vision-tower and multi-modal projector mappings when `vision_config` is set (no change for text-only configs).

`tests/models/gemma4/model_test.py`

Design notes

- The SigLIP encoder (`vision.py`), the embedding-merge helper (`merge_embeddings.py`), and the bidirectional-causal attention mask (`utils.py`) are reused from `tunix.models.gemma3` to avoid ~900 lines of duplication; they have no Gemma 3-specific assumptions. Happy to refactor these into a shared `tunix.models.common` namespace if reviewers prefer.
- `vision_config` defaults to `None`, `images` defaults to `None`. Existing Gemma 4 call sites and tests are unaffected.

Usage
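
As a small illustration of the masking rule described for `get_attention_mask` (bidirectional over image spans, causal over text), here is a NumPy sketch for one unbatched sequence. It is not the helper reused from `tunix.models.gemma3`: the function name, the boolean `is_image` input and the choice to treat each contiguous image span separately are assumptions made for the example.

```python
import numpy as np


def image_aware_mask(is_image):
    """Build a [seq_len, seq_len] boolean attention mask.

    mask[q, k] is True when query position q may attend to key position k.
    Text positions attend causally; positions inside the same contiguous
    image span additionally attend to each other in both directions.
    """
    is_image = np.asarray(is_image, dtype=bool)
    n = is_image.shape[0]
    causal = np.tril(np.ones((n, n), dtype=bool))
    # Give each contiguous image span its own id (1, 2, ...); text gets 0.
    prev = np.concatenate(([False], is_image[:-1]))
    span_starts = is_image & ~prev
    span_id = np.cumsum(span_starts) * is_image
    both_image = is_image[:, None] & is_image[None, :]
    same_span = both_image & (span_id[:, None] == span_id[None, :])
    return causal | same_span


# One text token, a two-token image span, then one more text token.
mask = image_aware_mask([False, True, True, False])
```

Here the first image token can see the second one (`mask[1, 2]` is `True`) even though it comes later in the sequence, while the leading text token still cannot look ahead.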