diff --git a/docs/models/index.rst b/docs/models/index.rst index 3fec34f7..c3e15741 100644 --- a/docs/models/index.rst +++ b/docs/models/index.rst @@ -20,3 +20,5 @@ Documentation for available models and model architectures. dllm rae_siglip sit + llava_onevision1_5 + llava_onevision2 diff --git a/docs/models/llava_onevision2.md b/docs/models/llava_onevision2.md new file mode 100644 index 00000000..3b64b381 --- /dev/null +++ b/docs/models/llava_onevision2.md @@ -0,0 +1,106 @@ +# LLaVA-OneVision2 Training + +## Overview + +LLaVA-OneVision2 (OV2) is the LMMs-Lab successor to LLaVA-OneVision 1.5. The +8B-Instruct checkpoint pairs a custom OneVision vision encoder (SigLIP-like +ViT with 3D RoPE and a patch-merger) with a stock **Qwen3-8B** language +model. Modeling code is shipped via Hugging Face ``auto_map`` and is loaded +at runtime through ``trust_remote_code``. + +## Supported Features + +| Feature | Support | +|---------|---------| +| **FSDP2** | ✅ | +| **FlashAttention 2** | ✅ | +| **Liger Kernel** | ✅ | +| **RMPAD (sequence packing)** | ✅ | +| **Packing** | ✅ | +| **Ulysses Sequence Parallel** | ✅ (via Qwen3 inner LM) | + +## Quick Start + +- **Example Config**: [examples/llava_onevision2/example.yaml](../../examples/llava_onevision2/example.yaml) +- **Run Script**: [examples/llava_onevision2/run.sh](../../examples/llava_onevision2/run.sh) + +```bash +bash examples/llava_onevision2/run.sh +``` + +## How Monkey Patching Works + +Because OV2's modeling classes are loaded dynamically (no shared import +path), patches are applied at the **model instance** level. Two patch_types +are registered for ``model_type == "llava_onevision2"``: + +* ``"liger"`` – Liger kernels: RoPE, RMSNorm, SwiGLU MLP (inner Qwen3 LM), + LayerNorm (OV2 vision encoder), plus a fused linear cross-entropy bound + onto OV2's ``ForConditionalGeneration.forward``. +* ``"rmpad"`` – Sequence-packing (unpadded) attention path: class-level + patches to inner Qwen3 attention/decoder/model forwards so they consume + ``cu_seq_lens``/``indices``, and an outer ``model_forward`` that wires + rmpad metadata through to ``causal_lm_forward``. + +The runner appends them in order based on ``trainer_args``: + +| `use_liger_kernel` | `use_rmpad` | Resulting behaviour | +|---|---|---| +| ✅ | ✅ | rmpad + fused LCE (historical default) | +| ✅ | ❌ | fused LCE, no unpadding | +| ❌ | ✅ | unpadded attention, standard CE | +| ❌ | ❌ | stock HF forward | + +## Key Configuration + +```yaml +model_config: + load_from_pretrained_path: lmms-lab-ov2/LLaVA-OneVision2-8B-Instruct + attn_implementation: flash_attention_2 + torch_dtype: bfloat16 + model_type: llava_onevision2 + extra_kwargs: + trust_remote_code: true # required: OV2 ships modeling via auto_map + +dataset_config: + dataset_type: qwen3_vl_iterable + processor_config: + processor_name: lmms-lab-ov2/LLaVA-OneVision2-8B-Instruct + processor_type: llava_onevision2 + packing: true + packing_length: 8192 + +trainer_args: + use_liger_kernel: true + use_rmpad: true + fsdp2: true + fsdp_config: + transformer_layer_cls_to_wrap: + - Qwen3DecoderLayer # inner LM (stock Qwen3) + - OneVisionEncoderEncoderLayer # OV2 vision tower +``` + +## Data Processor + +``LlavaOnevision2DataProcessor`` inherits from ``Qwen3_VLDataProcessor`` +and: + +1. Uses the OV2 ``AutoProcessor`` (image_processor + video_processor) + loaded with ``trust_remote_code=True``. +2. Rewrites each chat-template ```` + into a sequence of per-frame blocks + ``*n`` and aliases the + video patch tensors into the image path (every frame becomes a + ``[1, H, W]`` row of ``image_grid_thw``). +3. Computes the block-layout ``patch_positions`` tensor required by the + OV2 vision tower's 3D RoPE. +4. Normalizes per-frame numpy arrays from ``qwen_vl_utils`` (CHW float) to + HWC uint8 so OV2's video processor can PIL-convert them. + +## Implementation Pointers + +* Monkey patches: ``src/lmms_engine/models/llava_onevision2/monkey_patch.py`` +* OV2 forward replacements: ``src/lmms_engine/models/llava_onevision2/llava_onevision2_ops.py`` +* Shared LM loss helper (LCE / CE, rmpad shift, Ulysses SP): + ``src/lmms_engine/models/common_ops/loss.py`` +* Data processor: ``src/lmms_engine/datasets/processor/llava_onevision2_processor.py`` diff --git a/examples/llava_onevision2/example.yaml b/examples/llava_onevision2/example.yaml new file mode 100644 index 00000000..fc08878b --- /dev/null +++ b/examples/llava_onevision2/example.yaml @@ -0,0 +1,90 @@ +# LLaVA-OneVision2 (8B-Instruct) training example. +# +# This config drives the LMMs-Lab LLaVA-OneVision2 checkpoint, which ships +# its modeling and processor code via ``auto_map`` (trust_remote_code). +# The model_config.extra_kwargs.trust_remote_code flag is required so the +# runner forwards it through AutoConfig / AutoModelFor*ImageTextToText. + +trainer_type: fsdp2_trainer + +dataset_config: + dataset_type: qwen3_vl_iterable + dataset_format: yaml + datasets: + - path: data/LLaVA-Video-178K/llava_video_0_30_s_cap_oe.parquet + data_folder: /path/to/LLaVA-Video-178K + data_type: parquet + processor_config: + processor_name: lmms-lab-ov2/LLaVA-OneVision2-8B-Instruct + processor_type: llava_onevision2 + extra_kwargs: + image_max_pixels: 360448 + image_min_pixels: 28800 + video_max_pixels: 360448 + video_min_pixels: 28800 + packing: true + packing_strategy: balanced + packing_length: 8192 + shuffle: true + filter_overlong: true + filter_overlong_workers: 8 + video_sampling_strategy: fps + video_max_pixels: 360448 + video_max_frames: 64 + frame_num: 32 + fps: 1 + video_backend: qwen_vl_utils + extra_kwargs: + packing_kwargs: + num_buckets: 2 + +model_config: + load_from_pretrained_path: lmms-lab-ov2/LLaVA-OneVision2-8B-Instruct + attn_implementation: flash_attention_2 + torch_dtype: bfloat16 + model_type: llava_onevision2 + extra_kwargs: + trust_remote_code: true + +trainer_args: + output_dir: ./output/llava_onevision2_training + do_train: true + do_eval: false + per_device_train_batch_size: 1 + gradient_accumulation_steps: 1 + learning_rate: 1.0e-05 + num_train_epochs: 1 + max_steps: 1000 + lr_scheduler_type: cosine + warmup_ratio: 0.03 + logging_steps: 10 + save_strategy: steps + save_steps: 500 + save_total_limit: 2 + bf16: true + tf32: true + dataloader_drop_last: true + dataloader_num_workers: 4 + dataloader_prefetch_factor: 2 + remove_unused_columns: false + gradient_checkpointing: true + # Liger kernel + sequence packing (rmpad). The OV2 monkey patch registers + # 'liger' and 'rmpad' patch_types independently; the runner stacks them + # so the final causal_lm_forward runs with loss_fn=lce + use_rmpad=True. + use_liger_kernel: true + use_rmpad: true + fsdp2: true + fsdp_config: + transformer_layer_cls_to_wrap: + # Inner LM is stock Qwen3; OV2 ships its own vision encoder block. + - Qwen3DecoderLayer + - OneVisionEncoderEncoderLayer + reshard_after_forward: false + min_num_params: 0 + sp_ulysses_degree: 1 + reduce_dtype: bfloat16 + output_dtype: bfloat16 + report_to: none + seed: 42 + optim: adamw_torch_fused + run_name: llava_onevision2_training diff --git a/examples/llava_onevision2/run.sh b/examples/llava_onevision2/run.sh new file mode 100755 index 00000000..d5b3ecfe --- /dev/null +++ b/examples/llava_onevision2/run.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +################################################################################ +# LLaVA-OneVision2 (8B-Instruct) Training with FSDP2 +################################################################################ +# +# DESCRIPTION: +# Train the LMMs-Lab LLaVA-OneVision2 checkpoint with FSDP2, sequence +# packing (rmpad), and Liger fused linear cross-entropy. +# +# KEY NOTES: +# - OV2 ships its modeling + processor code via auto_map. We forward +# trust_remote_code through AutoConfig / AutoModelFor*ImageTextToText +# so the remote code path is honored. The yaml sets: +# model_config.extra_kwargs.trust_remote_code: true +# - Inner LM is stock Qwen3, so most liger / rmpad work is delegated to +# the qwen3 monkey patch. OV2-specific bits (outer model.forward, +# vision LayerNorm, video token expansion) live under +# ``src/lmms_engine/models/llava_onevision2``. +# - Video frames go through the same image path as multi-image inputs; +# the data processor rewrites into per-frame +# ``*n`` blocks. +# +# REQUIREMENTS: +# - 8x GPUs (A100/H100 with 80GB recommended) +# - flash-attn: pip install flash-attn --no-build-isolation +# - liger-kernel: pip install liger-kernel +# +# DATASET: +# OpenAI chat format (JSONL/Arrow/Parquet); see docs/user_guide/data_prep.md. +# +################################################################################ + +NGPUS=8 + +# Auto-accept trust_remote_code prompts triggered by transitive HF auto_* +# loads (the explicit kwarg we pass should already cover the main path). +export HF_HUB_TRUST_REMOTE_CODE=1 +export TRUST_REMOTE_CODE=1 + +torchrun --nproc_per_node=${NGPUS} \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=12356 \ + -m lmms_engine.launch.cli \ + config_yaml=examples/llava_onevision2/example.yaml diff --git a/src/lmms_engine/datasets/processor/__init__.py b/src/lmms_engine/datasets/processor/__init__.py index e9730d18..60f787c5 100644 --- a/src/lmms_engine/datasets/processor/__init__.py +++ b/src/lmms_engine/datasets/processor/__init__.py @@ -2,6 +2,7 @@ from .bagel_processor import BagelDataProcessor from .base_qwen2_5_processor import BaseQwen2_5_DataProcessor from .config import ProcessorConfig +from .llava_onevision2_processor import LlavaOnevision2DataProcessor from .llava_processor import LLaVADataProcessor from .llava_video_processor import LLaVAVideoDataProcessor from .nanovlm_processor import NanovlmDataProcessor @@ -34,4 +35,5 @@ "RaeSiglipDataProcessor", "SitDataProcessor", "Qwen3_VLDataProcessor", + "LlavaOnevision2DataProcessor", ] diff --git a/src/lmms_engine/datasets/processor/llava_onevision2_processor.py b/src/lmms_engine/datasets/processor/llava_onevision2_processor.py new file mode 100644 index 00000000..d3d1e792 --- /dev/null +++ b/src/lmms_engine/datasets/processor/llava_onevision2_processor.py @@ -0,0 +1,358 @@ +"""Data processor for LLaVA-OneVision2 (8B-Instruct, trust_remote_code). + +OV2 reuses the Qwen2VL image processor (patches in 2x2 block order) and a +custom video processor that emits per-frame patches + ``patch_positions``. +Token-side, every video is rewritten as a sequence of per-frame blocks of +the form ``<|vision_start|><|image_pad|>*n<|vision_end|>``, +so the model only ever sees the *image* path (videos are aliased into +``pixel_values`` / ``image_grid_thw`` / ``patch_positions``). + +This processor inherits Qwen3-VL's image-side logic and overrides the video +expansion to produce the OV2 block format. +""" + +from typing import List, Optional + +import numpy as np +import torch +from PIL.Image import Image +from transformers import AutoProcessor + +from lmms_engine.mapping_func import register_processor +from lmms_engine.utils import DataUtilities + +from .qwen3_vl_processor import Qwen3_VLDataProcessor + + +@register_processor("llava_onevision2") +class LlavaOnevision2DataProcessor(Qwen3_VLDataProcessor): + def _build_processor(self): + # OV2 ships its processor via auto_map / trust_remote_code. + processor = AutoProcessor.from_pretrained(self.config.processor_name, trust_remote_code=True) + + # Optional pixel-budget overrides (consistent with Qwen3VL). + image_max_pixels = self.config.extra_kwargs.get("image_max_pixels", None) + image_min_pixels = self.config.extra_kwargs.get("image_min_pixels", None) + if image_max_pixels is not None or image_min_pixels is not None: + self._set_vision_processor_size(processor.image_processor, image_min_pixels, image_max_pixels) + + video_max_pixels = self.config.extra_kwargs.get("video_max_pixels", None) + video_min_pixels = self.config.extra_kwargs.get("video_min_pixels", None) + if video_processor := getattr(processor, "video_processor", None): + if video_max_pixels is not None: + video_processor.max_pixels = int(video_max_pixels) + if video_min_pixels is not None: + video_processor.min_pixels = int(video_min_pixels) + return processor + + # ------------------------------------------------------------------ process + + def process( + self, + images: List[Image], + hf_messages, + audios: Optional[List[np.ndarray]] = None, + sampling_rate: Optional[int] = None, + videos=None, + system_message: str = "You are a helpful assistant", + add_system_prompt: bool = True, + add_generation_prompt: bool = False, + **kwargs, + ): + assert audios is None, "LlavaOnevision2DataProcessor does not support audio" + + # ---------------- Image branch ---------------- + if images is not None: + image_inputs = self.processor.image_processor(images=images, return_tensors="pt") + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + # ---------------- Video branch ---------------- + # OV2 video_processor expects a list of videos as pre-decoded frame + # arrays in **HWC uint8** format (so it can call PIL.Image.fromarray). + # The iterable dataset (qwen3_vl_iterable / qwen_vl_utils) hands us + # **CHW float32 / uint8** numpy arrays of shape ``[T, 3, H, W]``, so + # convert here. + if videos is not None: + videos = [self._normalize_video_for_ov2(v) for v in videos] + videos_inputs = self.processor.video_processor(videos=videos, return_tensors="pt") + video_grid_thw = videos_inputs["video_grid_thw"] # [num_videos, 3] + frame_timestamps = videos_inputs["frame_timestamps"] # list[list[float]] + video_pixel_values = videos_inputs["pixel_values_videos"] + video_patch_positions = videos_inputs["patch_positions"] + else: + video_grid_thw = None + frame_timestamps = None + video_pixel_values = None + video_patch_positions = None + + # Token-count math (Qwen2VL-style, per merge-block). + merge_length = self.processor.image_processor.merge_size**2 + + if image_grid_thw is not None: + num_image_tokens = [int(g.prod()) // merge_length for g in image_grid_thw] + else: + num_image_tokens = None + + # For videos: each frame becomes one "image". Token count per frame is + # (H_p * W_p) // merge_length (T is always 1 after we split rows). + if video_grid_thw is not None: + num_video_tokens_per_frame = [int(g[1] * g[2]) // merge_length for g in video_grid_thw] + else: + num_video_tokens_per_frame = None + + # Build text/labels using OV2-specific video expansion. + inputs = self._get_ov2_template_labels( + hf_messages=hf_messages, + num_image_tokens=num_image_tokens, + num_video_tokens_per_frame=num_video_tokens_per_frame, + frame_timestamps=frame_timestamps, + video_grid_thw=video_grid_thw, + system_message=system_message, + add_system_prompt=add_system_prompt, + add_generation_prompt=add_generation_prompt, + ) + + # ---------------- Build patch_positions for IMAGE inputs ---------------- + # Pull build_patch_positions from the OV2 remote module via the + # processor (already imported in trust_remote_code load). + build_patch_positions = self._get_build_patch_positions() + sms = int(self.processor.image_processor.merge_size) + + # ---------------- Alias videos -> image path --------------------------- + # Each video row [T, H, W] -> T rows of [1, H, W]; concat with images. + pixel_values_parts = [] + image_grid_thw_parts = [] + patch_positions_parts = [] + + if image_grid_thw is not None: + pixel_values_parts.append(image_inputs["pixel_values"]) + image_grid_thw_parts.append(image_grid_thw) + patch_positions_parts.append(build_patch_positions(image_grid_thw, spatial_merge_size=sms)) + + if video_grid_thw is not None: + pixel_values_parts.append(video_pixel_values) + expanded_rows = [] + for row in video_grid_thw: + T_v, H_v, W_v = int(row[0]), int(row[1]), int(row[2]) + expanded_rows.extend([[1, H_v, W_v]] * T_v) + image_grid_thw_parts.append(torch.tensor(expanded_rows, dtype=video_grid_thw.dtype)) + # Video processor already produced block-layout patch_positions + # using REAL frame indices for t — preserve that and just concat. + patch_positions_parts.append(video_patch_positions) + + if pixel_values_parts: + inputs["pixel_values"] = torch.cat(pixel_values_parts, dim=0) + inputs["image_grid_thw"] = torch.cat(image_grid_thw_parts, dim=0) + inputs["patch_positions"] = torch.cat(patch_positions_parts, dim=0) + + return inputs + + # ----------------------------------------------------------------- helpers + + @staticmethod + def _normalize_video_for_ov2(video): + """Coerce decoder output into a list[np.ndarray HWC uint8]. + + OV2's ``LlavaOnevision2VideoProcessor._coerce_video_input`` only + accepts list[PIL.Image], list[np.ndarray HWC uint8], or a path. The + Qwen VL utils backend returns a torch tensor / numpy array shaped + ``[T, 3, H, W]`` in CHW order with float or uint8 dtype, so we need + to permute + cast it. + """ + import torch + + if isinstance(video, str): + return video + + if isinstance(video, torch.Tensor): + arr = video.detach().cpu().numpy() + elif isinstance(video, np.ndarray): + arr = video + elif isinstance(video, list): + # Already a list[PIL.Image] / list[np.ndarray frame]: pass through. + return video + else: + return video + + # CHW -> HWC if the leading inner dim looks like channels. + if arr.ndim == 4 and arr.shape[1] in (1, 3, 4): + arr = np.transpose(arr, (0, 2, 3, 1)) + + # Cast to uint8 for PIL. + if arr.dtype != np.uint8: + arr_max = float(arr.max()) if arr.size else 0.0 + arr_min = float(arr.min()) if arr.size else 0.0 + if arr_max <= 1.5 and arr_min >= -0.01: + # Looks like a [0,1] float tensor. + arr = (arr * 255.0).clip(0, 255).astype(np.uint8) + else: + arr = arr.clip(0, 255).astype(np.uint8) + + # Hand back a list of per-frame HWC arrays so OV2's coercion takes the + # ``list[np.ndarray]`` branch. + return [arr[i] for i in range(arr.shape[0])] + + def _get_build_patch_positions(self): + """Resolve ``build_patch_positions`` from the dynamically-loaded + video_processing module shipped with the OV2 checkpoint.""" + if not hasattr(self, "_cached_build_patch_positions"): + video_proc = self.processor.video_processor + import sys + + mod = sys.modules[type(video_proc).__module__] + self._cached_build_patch_positions = mod.build_patch_positions + return self._cached_build_patch_positions + + @property + def vision_start_token_id(self) -> int: + return self.processor.tokenizer.convert_tokens_to_ids("<|vision_start|>") + + @property + def vision_end_token_id(self) -> int: + return self.processor.tokenizer.convert_tokens_to_ids("<|vision_end|>") + + # OV2's processor does not expose ``image_token`` / ``video_token`` string + # attributes (unlike Qwen3VLProcessor), so the parent class properties + # return None. Resolve straight from the special tokens used by the OV2 + # chat template. + @property + def image_token_id(self) -> int: + if not hasattr(self, "_image_token_id"): + self._image_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") + return self._image_token_id + + @property + def video_token_id(self) -> int: + if not hasattr(self, "_video_token_id"): + self._video_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|video_pad|>") + return self._video_token_id + + # ----------------------------------------------------- label construction + + def _get_ov2_template_labels( + self, + hf_messages, + num_image_tokens: Optional[List[int]], + num_video_tokens_per_frame: Optional[List[int]], + frame_timestamps: Optional[List[List[float]]], + video_grid_thw=None, + system_message: str = "You are a helpful assistant", + add_system_prompt: bool = True, + add_generation_prompt: bool = False, + ): + unmask_tokens_idx = [self.processor.tokenizer.convert_tokens_to_ids(t) for t in self.special_tokens] + input_id, target = [], [] + image_start_from = 0 + video_start_from = 0 + + if add_system_prompt and hf_messages[0]["role"] != "system": + input_id += DataUtilities.apply_chat_template( + self.processor, + [{"role": "system", "content": [{"type": "text", "text": system_message}]}], + ) + target += [-100] * len(input_id) + + for message in hf_messages: + role = message["role"] + encode_id = DataUtilities.apply_chat_template(self.processor, [message]) + + if self.image_token_id in encode_id and num_image_tokens is not None: + encode_id, used_images = self._expand_encode_id_image_tokens( + encode_id, num_image_tokens, image_start_from + ) + image_start_from += used_images + + if ( + self.video_token_id in encode_id + and num_video_tokens_per_frame is not None + and frame_timestamps is not None + ): + encode_id, used_video = self._expand_encode_id_video_tokens_ov2( + encode_id, + num_video_tokens_per_frame, + video_start_from, + frame_timestamps, + video_grid_thw, + ) + video_start_from += used_video + + input_id += encode_id + if role in ["user", "system"]: + target += [-100] * len(encode_id) + else: + encode_id[:3] = [-100] * 3 # mask out the assistant header + target += encode_id + + if add_generation_prompt: + generation_tokens = self.processor.tokenizer.encode("<|im_start|>assistant\n") + input_id += generation_tokens + target += [-100] * len(generation_tokens) + + assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}" + for idx, tok in enumerate(input_id): + if tok in unmask_tokens_idx: + target[idx] = tok + if tok == self.image_token_id: + target[idx] = -100 + if tok == self.video_token_id: + target[idx] = -100 + + return dict( + input_ids=torch.tensor(input_id, dtype=torch.long), + labels=torch.tensor(target, dtype=torch.long), + ) + + def _expand_encode_id_video_tokens_ov2( + self, + encode_id: List[int], + num_video_tokens_per_frame: List[int], + start_from: int, + frame_timestamps: List[List[float]], + video_grid_thw, + ): + """Rewrite each ``<|vision_start|><|video_pad|><|vision_end|>`` triplet + into per-frame ``<|vision_start|><|image_pad|>*n<|vision_end|>`` + blocks (OV2 native format). + """ + video_pos = [i for i, x in enumerate(encode_id) if x == self.video_token_id] + expanded = [] + prev = 0 + tokenizer = self.processor.tokenizer + vs_id = self.vision_start_token_id + ve_id = self.vision_end_token_id + + for idx, pos in enumerate(video_pos): + vidx = idx + start_from + T_v = int(video_grid_thw[vidx, 0]) + n_per_frame = num_video_tokens_per_frame[vidx] + seconds_seq = frame_timestamps[vidx] + # Defensive pad/truncate to T_v. + if len(seconds_seq) < T_v: + pad_val = seconds_seq[-1] if seconds_seq else 0.0 + seconds_seq = list(seconds_seq) + [pad_val] * (T_v - len(seconds_seq)) + elif len(seconds_seq) > T_v: + seconds_seq = list(seconds_seq[:T_v]) + + # Strip the original surrounding <|vision_start|>/<|vision_end|>: + # chat_template produces <|vision_start|><|video_pad|><|vision_end|> + # at positions (pos-1, pos, pos+1). The per-frame blocks below + # re-emit their own. + expanded.extend(encode_id[prev : pos - 1]) + + for t in range(T_v): + ts_token = f"<{float(seconds_seq[t]):.1f} seconds>" + ts_ids = tokenizer.encode(ts_token, add_special_tokens=False) + expanded.extend(ts_ids) + expanded.append(vs_id) + expanded.extend([self.image_token_id] * n_per_frame) + expanded.append(ve_id) + + prev = pos + 2 # skip <|vision_end|> + + if idx == len(video_pos) - 1: + expanded.extend(encode_id[prev:]) + + return expanded, len(video_pos) diff --git a/src/lmms_engine/mapping_func.py b/src/lmms_engine/mapping_func.py index fc48417a..ce7635ad 100644 --- a/src/lmms_engine/mapping_func.py +++ b/src/lmms_engine/mapping_func.py @@ -61,7 +61,11 @@ def register_model( AUTO_REGISTER_MODEL_MAPPING[model_general_type].register(model_config, model_class) -def create_model_from_pretrained(load_from_pretrained_path, model_general_type: str | None = None): +def create_model_from_pretrained( + load_from_pretrained_path, + model_general_type: str | None = None, + trust_remote_code: bool = False, +): """Pick an HF Auto* class for ``load_from_pretrained_path``. Args: @@ -72,9 +76,11 @@ def create_model_from_pretrained(load_from_pretrained_path, model_general_type: to disambiguate when the same config is registered under multiple AutoModel mappings (e.g. Qwen3.5 registers under both ``causal_lm`` and ``image_text_to_text``). + trust_remote_code: forwarded to ``AutoConfig.from_pretrained``; needed + for checkpoints that ship custom modeling code via ``auto_map``. """ # Handle both config object and model name/path - config = AutoConfig.from_pretrained(load_from_pretrained_path) + config = AutoConfig.from_pretrained(load_from_pretrained_path, trust_remote_code=trust_remote_code) if model_general_type is not None: if model_general_type not in AUTO_REGISTER_MODEL_MAPPING: @@ -93,7 +99,22 @@ def create_model_from_pretrained(load_from_pretrained_path, model_general_type: elif type(config) in AutoModel._model_mapping.keys(): model_class = AutoModel else: - raise ValueError(f"Model {load_from_pretrained_path} is not supported.") + # Fallback for trust_remote_code checkpoints: the config class is loaded + # dynamically via auto_map and won't be in any HF model mapping. Pick + # the AutoModelFor* class declared in the config's auto_map; that class + # will resolve the remote modeling code itself when from_pretrained + # is called with trust_remote_code=True. + auto_map = getattr(config, "auto_map", None) or {} + if "AutoModelForImageTextToText" in auto_map: + model_class = AutoModelForImageTextToText + elif "AutoModelForCausalLM" in auto_map: + model_class = AutoModelForCausalLM + elif "AutoModelForMaskedLM" in auto_map: + model_class = AutoModelForMaskedLM + elif "AutoModel" in auto_map: + model_class = AutoModel + else: + raise ValueError(f"Model {load_from_pretrained_path} is not supported.") return model_class diff --git a/src/lmms_engine/models/__init__.py b/src/lmms_engine/models/__init__.py index e71204a6..1d296b3e 100644 --- a/src/lmms_engine/models/__init__.py +++ b/src/lmms_engine/models/__init__.py @@ -6,6 +6,7 @@ from .bagel import Bagel, BagelConfig from .config import ModelConfig from .llava_onevision import apply_liger_kernel_to_llava_onevision +from .llava_onevision2 import apply_liger_kernel_to_llava_onevision2 from .monkey_patch import MONKEY_PATCHER from .nanovlm import NanovlmConfig, NanovlmForConditionalGeneration from .qwen2 import apply_liger_kernel_to_qwen2 @@ -42,6 +43,7 @@ "ModelConfig", "AeroProcessor", "apply_liger_kernel_to_llava_onevision", + "apply_liger_kernel_to_llava_onevision2", "apply_liger_kernel_to_qwen2", "apply_liger_kernel_to_qwen3", "Qwen2_5OmniThinkerConfig", diff --git a/src/lmms_engine/models/common_ops/loss.py b/src/lmms_engine/models/common_ops/loss.py new file mode 100644 index 00000000..ac6afb84 --- /dev/null +++ b/src/lmms_engine/models/common_ops/loss.py @@ -0,0 +1,123 @@ +"""Common LM loss helpers shared across model wrappers. + +Used by custom CausalLM forwards (e.g. LlavaOnevision2) to compute the next- +token loss with the right combination of fused linear CE (liger), packed +``rmpad`` shifting, and Ulysses sequence parallelism. +""" + +from typing import Optional + +import torch +import torch.distributed as dist + +from lmms_engine.parallel.sequence_parallel.ulysses import ( + calculate_seq_len_per_rank, + gather_outputs_and_unpad, + get_ulysses_sequence_parallel_group, + get_ulysses_sequence_parallel_world_size, + pad_to_max_across_ranks, + slice_input_tensor, +) + +try: + from liger_kernel.transformers.fused_linear_cross_entropy import ( + LigerFusedLinearCrossEntropyLoss, + ) +except ImportError: + LigerFusedLinearCrossEntropyLoss = None + + +def compute_lm_loss( + hidden_states: torch.Tensor, + labels: torch.Tensor, + lm_head_weight: torch.Tensor, + hidden_size: int, + *, + loss_fn: str = "lce", + use_rmpad: bool = False, + seq_lens: Optional[torch.Tensor] = None, + word_idx: Optional[torch.Tensor] = None, + kwargs: Optional[dict] = None, +) -> torch.Tensor: + """Next-token LM loss with optional fused LCE / rmpad / Ulysses SP. + + Args: + hidden_states: ``[B, L, H]`` (when ``use_rmpad=False``) or + ``[total_tokens, H]`` (when ``use_rmpad=True``). + labels: padded ``[B, L]`` token labels. When ``use_rmpad=True`` and + ``word_idx`` is provided, labels are gathered via ``word_idx``. + lm_head_weight: the LM head weight tensor used either by the fused + LCE kernel or by a plain ``F.linear`` for ``loss_fn="ce"``. + hidden_size: text-model hidden size, for reshaping shifted hidden states. + loss_fn: ``"lce"`` (Liger fused linear CE) or ``"ce"`` (materialized + logits + ``F.cross_entropy``). + use_rmpad: when True, shift inside each packed seq using ``seq_lens``. + seq_lens: cumulative seq lens of packed sequences (rmpad metadata). + word_idx: indices into the flattened padded ``labels`` for unpadding. + kwargs: forwarded model kwargs; we look for ``num_items_in_batch`` to + decide reduction mode. + + Returns: + Scalar loss tensor. + """ + if kwargs is None: + kwargs = {} + sp_size = get_ulysses_sequence_parallel_world_size() + + # Align labels with hidden_states layout. + if use_rmpad and word_idx is not None: + labels_use = labels.view(-1)[word_idx.long()] + else: + labels_use = labels + + if sp_size > 1: + if seq_lens is not None: + seq_lens = calculate_seq_len_per_rank(seq_lens.tolist()) + labels_use = slice_input_tensor(labels_use, dim=0, padding=True) + + # Shift hidden states / labels for next-token prediction. + if use_rmpad and seq_lens is not None: + shift_h, shift_l = [], [] + for i in range(len(seq_lens) - 1): + ch = hidden_states[seq_lens[i] : seq_lens[i + 1], :] + cl = labels_use[seq_lens[i] : seq_lens[i + 1]] + shift_h.append(ch[:-1, :].contiguous()) + shift_l.append(cl[1:].contiguous()) + shift_hidden = torch.cat(shift_h, dim=0) + shift_labels = torch.cat(shift_l, dim=0) + else: + shift_hidden = hidden_states[..., :-1, :].contiguous() + shift_labels = labels_use[..., 1:].contiguous() + + shift_hidden = shift_hidden.view(-1, hidden_size) + shift_labels = shift_labels.view(-1) + + reduction = "sum" if "num_items_in_batch" in kwargs else "mean" + if sp_size > 1: + reduction = "none" + + if loss_fn == "lce": + if LigerFusedLinearCrossEntropyLoss is None: + raise RuntimeError("loss_fn='lce' requires liger-kernel; install it or use loss_fn='ce'.") + lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction) + loss = lce(lm_head_weight, shift_hidden, shift_labels) + elif loss_fn == "ce": + logits = torch.nn.functional.linear(shift_hidden, lm_head_weight) + loss = torch.nn.functional.cross_entropy(logits.float(), shift_labels, reduction=reduction, ignore_index=-100) + else: + raise ValueError(f"Unknown loss_fn={loss_fn!r}; expected 'lce' or 'ce'.") + + # Ulysses SP gather. + if sp_size > 1: + loss, total_padding = pad_to_max_across_ranks(loss, dim=0) + loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=total_padding) + num_valid_tokens = (shift_labels != -100).sum().float() + sp_group = get_ulysses_sequence_parallel_group() + if sp_group is not None: + dist.all_reduce(num_valid_tokens, op=dist.ReduceOp.SUM, group=sp_group) + loss = torch.sum(loss) / (num_valid_tokens + 1e-8) + + if reduction == "sum": + loss = loss / kwargs["num_items_in_batch"] + + return loss diff --git a/src/lmms_engine/models/llava_onevision2/__init__.py b/src/lmms_engine/models/llava_onevision2/__init__.py new file mode 100644 index 00000000..f8c9468f --- /dev/null +++ b/src/lmms_engine/models/llava_onevision2/__init__.py @@ -0,0 +1,9 @@ +from .monkey_patch import ( + apply_liger_kernel_to_llava_onevision2, + apply_rmpad_to_llava_onevision2, +) + +__all__ = [ + "apply_liger_kernel_to_llava_onevision2", + "apply_rmpad_to_llava_onevision2", +] diff --git a/src/lmms_engine/models/llava_onevision2/llava_onevision2_ops.py b/src/lmms_engine/models/llava_onevision2/llava_onevision2_ops.py new file mode 100644 index 00000000..4f78e5e4 --- /dev/null +++ b/src/lmms_engine/models/llava_onevision2/llava_onevision2_ops.py @@ -0,0 +1,248 @@ +"""Forward overrides for LlavaOnevision2 model instances. + +Because the OV2 modeling code is loaded via ``auto_map`` (trust_remote_code), +we cannot patch class objects in a shared module. These functions are bound +onto the OV2 model **instances** at load time by ``monkey_patch.py`` using +``types.MethodType``. + +Provides: +- ``model_forward``: replacement for ``LlavaOnevision2Model.forward`` that + performs sequence unpadding (rmpad) before invoking the inner ``Qwen3Model`` + (whose forward is already patched class-level by the qwen3 monkey patch). +- ``causal_lm_forward``: replacement for + ``LlavaOnevision2ForConditionalGeneration.forward`` that adds Liger fused + linear cross-entropy support, mirroring ``qwen3_vl_lce_forward``. +""" + +from typing import List, Optional, Tuple, Union + +import torch +from transformers.cache_utils import Cache + +from ..common_ops.loss import compute_lm_loss +from ..sequence_packing_utils import _unpad_input + +# Filled in by monkey_patch.apply_liger_kernel_to_llava_onevision2 when it +# binds these forwards to a model instance. We cannot look up the OV2 output +# dataclasses from ``type(self).__module__`` at call time because FSDP wraps +# the module and replaces its class with an internal one. +_OV2_MODULES = {} + + +def _register_ov2_module(model): + """Cache the OV2 modeling module so forwards can locate ModelOutput classes + even after FSDP wraps the model.""" + import sys + + inner = getattr(model, "model", model) + cls = type(inner) + mod = sys.modules.get(cls.__module__) + if mod is None: + return + _OV2_MODULES["modeling"] = mod + _OV2_MODULES["ModelOutputWithPast"] = getattr(mod, "LlavaOnevision2ModelOutputWithPast", None) + _OV2_MODULES["CausalLMOutputWithPast"] = getattr(mod, "LlavaOnevision2CausalLMOutputWithPast", None) + + +def model_forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + patch_positions: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + **kwargs, +): + """Drop-in replacement for ``LlavaOnevision2Model.forward`` with rmpad. + + Steps: + 1. Build padded ``inputs_embeds`` and inject image / video features just + like the original OV2 forward (multi-image path; video aliased into + the image path is handled by the data processor + image entry below). + 2. Unpad to ``(total_tokens,)`` and pass ``cu_seq_lens`` / ``indices`` + down to the inner Qwen3 language model (its forward is already + patched to consume those kwargs). + """ + return_dict = True if return_dict is None else return_dict + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + # --- Vision injection (still on padded tensors) --------------------------- + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_grid_thw, patch_positions=patch_positions) + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _ = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, patch_positions=patch_positions) + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + # --- Unpad ---------------------------------------------------------------- + # NB: the patched ``Qwen3Model.forward`` (qwen3_ops.model_forward) ALSO + # unpads internally if ``cu_seq_lens`` is None. We just let it handle that: + # forward ``inputs_embeds`` + ``attention_mask`` straight through. The + # qwen3 model_forward will do _unpad_input itself and return + # ``BaseModelOutputWithPastAndRmpad`` carrying seq_lens / word_idx, which + # we propagate upward so the OV2 LCE forward can slice labels. + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + # Reuse OV2 output dataclass to stay drop-in compatible, but stash rmpad + # info on it. The OV2 ModelOutput dataclass accepts arbitrary kwargs + # because it's a ``ModelOutput`` (just attribute assignment after init). + ModelOutputCls = _OV2_MODULES.get("ModelOutputWithPast") + + out = ModelOutputCls( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=getattr(outputs, "hidden_states", None), + attentions=getattr(outputs, "attentions", None), + ) + # Stash rmpad metadata (set as attributes; ModelOutput supports __setattr__) + out["seq_lens"] = getattr(outputs, "seq_lens", None) + out["word_idx"] = getattr(outputs, "word_idx", None) + return out if return_dict else out.to_tuple() + + +def _compute_loss( + hidden_states: torch.Tensor, + labels: torch.Tensor, + seq_lens: Optional[torch.Tensor], + word_idx: Optional[torch.Tensor], + lm_head_weight: torch.Tensor, + loss_fn: str, + use_rmpad: bool, + text_config, + kwargs: dict, +) -> torch.Tensor: + """Thin wrapper over :func:`compute_lm_loss` that pulls ``hidden_size`` + from OV2's ``text_config``.""" + return compute_lm_loss( + hidden_states=hidden_states, + labels=labels, + lm_head_weight=lm_head_weight, + hidden_size=text_config.hidden_size, + loss_fn=loss_fn, + use_rmpad=use_rmpad, + seq_lens=seq_lens, + word_idx=word_idx, + kwargs=kwargs, + ) + + +def causal_lm_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + patch_positions: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + # ---- patch-time options (fixed via ``functools.partial`` in monkey patch) ---- + loss_fn: str = "lce", + use_rmpad: bool = False, + **kwargs, +): + """Drop-in replacement for ``LlavaOnevision2ForConditionalGeneration.forward``. + + Behaviour depends on the patch-time options ``loss_fn`` and ``use_rmpad``: + + * ``loss_fn="lce"``: fused linear cross-entropy via Liger (no logits + materialized). Falls back to ``loss_fn="ce"`` if liger is unavailable. + * ``loss_fn="ce"``: standard cross-entropy on materialized logits. + * ``use_rmpad=True``: assumes the inner LM ran with rmpad and the output + ``hidden_states`` is a packed ``[total_tokens, H]`` tensor; shifts + per-seq using ``seq_lens`` from the LM output. + * Inference (``labels is None``) always materializes logits. + """ + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + patch_positions=patch_positions, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + + seq_lens = outputs.get("seq_lens", None) if hasattr(outputs, "get") else None + word_idx = outputs.get("word_idx", None) if hasattr(outputs, "get") else None + + loss = None + logits = None + text_config = getattr(self.config, "text_config", self.config) + + if labels is not None: + loss = _compute_loss( + hidden_states=hidden_states, + labels=labels, + seq_lens=seq_lens, + word_idx=word_idx, + lm_head_weight=self.lm_head.weight, + loss_fn=loss_fn, + use_rmpad=use_rmpad, + text_config=text_config, + kwargs=kwargs, + ) + else: + # Pure inference path. + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + OutputCls = _OV2_MODULES.get("CausalLMOutputWithPast") + return OutputCls( + loss=loss, + logits=logits, + past_key_values=getattr(outputs, "past_key_values", None), + hidden_states=getattr(outputs, "hidden_states", None), + attentions=getattr(outputs, "attentions", None), + ) diff --git a/src/lmms_engine/models/llava_onevision2/monkey_patch.py b/src/lmms_engine/models/llava_onevision2/monkey_patch.py new file mode 100644 index 00000000..8ab1e551 --- /dev/null +++ b/src/lmms_engine/models/llava_onevision2/monkey_patch.py @@ -0,0 +1,266 @@ +"""Monkey patches for LlavaOnevision2 (auto_map / trust_remote_code model). + +Because OV2 modeling classes are loaded dynamically from the checkpoint, +we cannot patch their classes module-globally. All patches here are applied +at the **instance** level on a materialized model. + +Two independent patch_types are registered: + + * ``"liger"`` – Liger kernels: RoPE, RMSNorm, SwiGLU MLP, vision LayerNorm, + and the fused linear cross-entropy loss bound onto the + OV2 ``ForConditionalGeneration.forward``. + * ``"rmpad"`` – Sequence-packing (unpadded) attention path: class-level + patches to the inner Qwen3 layers (so attention/decoder/ + model forwards consume ``cu_seq_lens`` + ``indices``) and + a CE-loss-flavoured ``causal_lm_forward`` that shifts + per-seq using ``seq_lens``. + +The runner applies these in order: ``["liger", "rmpad"]`` when both are +requested, so ``rmpad``'s rebinding runs *after* ``liger``'s and detects the +already-installed ``loss_fn="lce"`` to preserve it. + +Stacked behaviour: + + * ``liger`` alone → fused LCE, no unpadding. + * ``rmpad`` alone → unpacked attention, standard CE loss. + * ``liger`` + ``rmpad`` (runner order, rmpad rebinds last) → unpacked + + fused LCE (the historical default of this codebase). + * Neither → stock HF forward. +""" + +from functools import partial +from types import MethodType + +from loguru import logger +from transformers import PreTrainedModel + +try: + from liger_kernel.transformers.monkey_patch import ( + _patch_layer_norm_module, + _patch_rms_norm_module, + _patch_swiglu_module, + ) + from liger_kernel.transformers.swiglu import LigerSwiGLUMLP +except ImportError: + _patch_layer_norm_module = None + _patch_rms_norm_module = None + _patch_swiglu_module = None + LigerSwiGLUMLP = None + logger.warning("liger kernel not installed; OV2 liger patch will be a no-op.") + +from lmms_engine.models.monkey_patch import MONKEY_PATCHER + +from .llava_onevision2_ops import _register_ov2_module, causal_lm_forward, model_forward + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _bind_causal_lm_forward(model, *, loss_fn: str, use_rmpad: bool) -> None: + """Bind ``causal_lm_forward`` onto the OV2 CausalLM instance with patch-time + options fixed via :func:`functools.partial`. + + If called multiple times (e.g. once by ``liger`` then by ``rmpad``), the + last call wins — the matrix in the module docstring relies on this. + """ + bound = partial(causal_lm_forward, loss_fn=loss_fn, use_rmpad=use_rmpad) + + # ``MethodType`` requires a function; wrap the partial to expose ``self``. + def _forward(self, *args, **kwargs): + return bound(self, *args, **kwargs) + + model.forward = MethodType(_forward, model) + + +def _bind_outer_model_forward(model) -> None: + """Bind OV2 ``LlavaOnevision2Model.forward`` (vision injection + LM call).""" + ov2_model = getattr(model, "model", None) + if ov2_model is None: + logger.warning("OV2: model.model not found; cannot bind outer model_forward.") + return + _register_ov2_module(model) + ov2_model.forward = MethodType(model_forward, ov2_model) + + +def _patch_qwen3_text_submodules(language_model, *, rms_norm: bool, swiglu: bool) -> None: + """Instance-level swaps of RMSNorm / SwiGLU modules in an already-loaded + Qwen3Model. The class-level Liger patches only affect *future* construction; + we must mutate existing instances explicitly.""" + if language_model is None: + return + if _patch_rms_norm_module is None: + return + + if rms_norm and hasattr(language_model, "norm"): + _patch_rms_norm_module(language_model.norm) + + for decoder_layer in getattr(language_model, "layers", []): + if rms_norm: + if hasattr(decoder_layer, "input_layernorm"): + _patch_rms_norm_module(decoder_layer.input_layernorm) + if hasattr(decoder_layer, "post_attention_layernorm"): + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) + self_attn = getattr(decoder_layer, "self_attn", None) + if self_attn is not None: + if getattr(self_attn, "q_norm", None) is not None: + _patch_rms_norm_module(self_attn.q_norm) + if getattr(self_attn, "k_norm", None) is not None: + _patch_rms_norm_module(self_attn.k_norm) + if swiglu and _patch_swiglu_module is not None and LigerSwiGLUMLP is not None: + if hasattr(decoder_layer, "mlp"): + _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP) + + +def _patch_ov2_vision_layer_norms(visual) -> None: + """Replace LayerNorm modules inside the OV2 vision encoder with Liger's.""" + if visual is None or _patch_layer_norm_module is None: + return + if hasattr(visual, "layernorm_pre"): + _patch_layer_norm_module(visual.layernorm_pre) + if getattr(visual, "layernorm_post", None) is not None: + _patch_layer_norm_module(visual.layernorm_post) + encoder = getattr(visual, "encoder", None) + if encoder is None: + return + for vlayer in getattr(encoder, "layers", []): + if hasattr(vlayer, "layer_norm1"): + _patch_layer_norm_module(vlayer.layer_norm1) + if hasattr(vlayer, "layer_norm2"): + _patch_layer_norm_module(vlayer.layer_norm2) + + +# --------------------------------------------------------------------------- +# Public entry points (registered with MONKEY_PATCHER) +# --------------------------------------------------------------------------- + + +@MONKEY_PATCHER.register("llava_onevision2", "liger") +def apply_liger_kernel_to_llava_onevision2( + rope: bool = True, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + swiglu: bool = True, + layer_norm: bool = True, + model: PreTrainedModel = None, +) -> None: + """Apply *only* the Liger-kernel patches to an OV2 model instance. + + Does not touch the attention / decoder forwards (those are owned by the + ``"rmpad"`` patch). Binds ``causal_lm_forward(loss_fn="lce", use_rmpad=False)`` + so the LM head loss runs as fused linear CE without materializing logits. + + ``model`` is required: OV2 is auto_map / trust_remote_code, so we have no + shared class to mutate. + """ + if model is None: + logger.warning("OV2 liger patch skipped: no model instance passed.") + return + + # ----- 1. Class-level Liger patches for the inner Qwen3 LM ---------------- + # Reuse the qwen3 patch but force its rmpad / fused-LCE bits OFF so we + # don't accidentally rebind ``Qwen3ForCausalLM.forward`` or set up unpad + # ops we don't want yet. + from lmms_engine.models.qwen3.monkey_patch import apply_liger_kernel_to_qwen3 + + apply_liger_kernel_to_qwen3( + rope=rope, + cross_entropy=cross_entropy, + fused_linear_cross_entropy=False, # we bind OV2's CausalLM forward below + rms_norm=rms_norm, + swiglu=swiglu, + model=None, + use_rmpad=False, + ) + + # ----- 2. Instance-level Liger swaps in already-loaded submodules --------- + ov2_model = getattr(model, "model", None) + if ov2_model is None: + logger.warning("OV2 liger patch: model.model not found; aborting.") + return + _patch_qwen3_text_submodules( + getattr(ov2_model, "language_model", None), + rms_norm=rms_norm, + swiglu=swiglu, + ) + if layer_norm: + _patch_ov2_vision_layer_norms(getattr(ov2_model, "visual", None)) + + # ----- 3. Bind OV2 causal LM forward with fused LCE ----------------------- + # The ``"rmpad"`` patch (if requested) is applied separately *after* this + # one by the runner; it will rebind ``causal_lm_forward`` with + # ``use_rmpad=True`` while detecting and preserving ``loss_fn="lce"``. + if fused_linear_cross_entropy: + _register_ov2_module(model) # cache OV2 output classes pre-FSDP + _bind_causal_lm_forward(model, loss_fn="lce", use_rmpad=False) + + +@MONKEY_PATCHER.register("llava_onevision2", "rmpad") +def apply_rmpad_to_llava_onevision2( + model: PreTrainedModel = None, +) -> None: + """Apply *only* the rmpad (sequence-packing) patches to an OV2 model. + + Patches the inner Qwen3 attention/decoder/model forwards class-level to + consume ``cu_seq_lens`` + ``indices``, binds OV2 ``model_forward`` (which + propagates rmpad metadata out to ``causal_lm_forward``), and binds + ``causal_lm_forward(loss_fn="ce", use_rmpad=True)``. + + When stacked on top of the ``"liger"`` patch, this overrides the latter's + ``causal_lm_forward`` binding so the loss becomes ``loss_fn="lce" + + use_rmpad=True`` (handled by the caller passing both patches; the stacking + behaviour is documented in the matrix in the module docstring). + """ + if model is None: + logger.warning("OV2 rmpad patch skipped: no model instance passed.") + return + + # ----- 1. Class-level rmpad patches for inner Qwen3 layers ---------------- + # We piggy-back on qwen3's apply function with everything else disabled. + from lmms_engine.models.qwen3.monkey_patch import apply_liger_kernel_to_qwen3 + + apply_liger_kernel_to_qwen3( + rope=False, + cross_entropy=False, + fused_linear_cross_entropy=False, + rms_norm=False, + swiglu=False, + model=None, + use_rmpad=True, # this is the bit we actually want + ) + + # ----- 2. Outer OV2 model_forward + causal_lm_forward bindings ------------ + _bind_outer_model_forward(model) + + # If liger already bound a fused-LCE forward, preserve that and just flip + # ``use_rmpad=True``. Otherwise bind a plain-CE rmpad forward. + current_loss_fn = _detect_bound_loss_fn(model) + _bind_causal_lm_forward( + model, + loss_fn=current_loss_fn or "ce", + use_rmpad=True, + ) + + +def _detect_bound_loss_fn(model) -> str: + """Inspect ``model.forward`` to recover ``loss_fn`` if we previously bound + it via :func:`_bind_causal_lm_forward`. Returns ``None`` if no prior bind + can be detected (e.g. stock HF forward).""" + fwd = getattr(model, "forward", None) + if fwd is None: + return None + # ``MethodType(_forward, model)`` exposes the underlying function on .__func__. + inner = getattr(fwd, "__func__", None) + if inner is None: + return None + # The closure of ``_forward`` captures ``bound = partial(...)``; pull it out. + closure = getattr(inner, "__closure__", None) or () + for cell in closure: + try: + val = cell.cell_contents + except ValueError: + continue + if isinstance(val, partial): + return val.keywords.get("loss_fn") + return None diff --git a/src/lmms_engine/models/monkey_patch.py b/src/lmms_engine/models/monkey_patch.py index be90f618..d67ed6e9 100644 --- a/src/lmms_engine/models/monkey_patch.py +++ b/src/lmms_engine/models/monkey_patch.py @@ -40,6 +40,12 @@ def apply_monkey_patch(self, model_type, patch_type, **kwargs): f"There are currently no patches supported for model type: {model_type} with patch type: {patch_type}. Available model types: {self._dict.keys()}" ) return + if patch_type not in self._dict[model_type]: + logger.info( + f"Patch type {patch_type!r} not registered for model type {model_type!r}; skipping. " + f"Available patch types: {list(self._dict[model_type].keys())}" + ) + return apply_fn = self._dict[model_type][patch_type] apply_fn_signature = inspect.signature(apply_fn) @@ -68,6 +74,12 @@ def apply_monkey_patch_to_instance(self, model: PreTrainedModel, patch_type, **k f"There are currently no patches supported for model type: {model_type} with patch type: {patch_type}. Available model types: {self._dict.keys()}" ) return + if patch_type not in self._dict[model_type]: + logger.info( + f"Patch type {patch_type!r} not registered for model type {model_type!r}; skipping. " + f"Available patch types: {list(self._dict[model_type].keys())}" + ) + return apply_fn = self._dict[model_type][patch_type] diff --git a/src/lmms_engine/models/utils.py b/src/lmms_engine/models/utils.py index 8082bfc2..ca23eba9 100644 --- a/src/lmms_engine/models/utils.py +++ b/src/lmms_engine/models/utils.py @@ -40,6 +40,7 @@ "minicpmv", "minicpmo", "llava_onevision", + "llava_onevision2", } @@ -78,10 +79,12 @@ def __init__(self, config: PretrainedConfig): "minicpmv": self._estimate_qwen2_flops, "minicpmo": self._estimate_qwen2_flops, "llava_onevision": self._estimate_qwen2_flops, + "llava_onevision2": self._estimate_qwen2_flops, "bagel": self._estimate_qwen2_flops, } if config.model_type in [ "llava_onevision", + "llava_onevision2", "qwen2_5_vl", "qwen3_5", "qwen3_vl", diff --git a/src/lmms_engine/train/runner.py b/src/lmms_engine/train/runner.py index 51dfc2b5..e3a3be41 100644 --- a/src/lmms_engine/train/runner.py +++ b/src/lmms_engine/train/runner.py @@ -70,6 +70,7 @@ def _build_model(self): model_class = create_model_from_pretrained( load_from_pretrained_path, model_general_type=self.model_config.model_general_type, + trust_remote_code=bool(model_kwargs.get("trust_remote_code", False)), ) model = model_class.from_pretrained( load_from_pretrained_path, @@ -108,6 +109,9 @@ def _apply_monkey_patch(self): # Overwrite the use_liger_kernel to False as we already apply the liger kernel by ourselves self.config.trainer_args.use_liger_kernel = False + if self.config.trainer_args.use_rmpad: + kwargs["patch_type"].append("rmpad") + if self.model_config.monkey_patch_kwargs: patch_type = getattr(self.model_config.monkey_patch_kwargs, "patch_type", []) kwargs["patch_type"].extend(patch_type)