[feat] add LLaVA-OneVision2 (8B-Instruct) training support#170

Merged: yiyexy merged 4 commits into main from feat/llava-onevision2 on May 14, 2026

@kcz358 (Collaborator) commented May 14, 2026

Motivation

Add training support for the LMMs-Lab LLaVA-OneVision2 8B-Instruct
checkpoint. OV2 ships its modeling and processor code via
`auto_map` (trust_remote_code) and pairs a custom OneVision vision
encoder with a stock Qwen3-8B language model, so most of the
machinery (liger, rmpad, fused LCE) can be reused from the existing
qwen3 plumbing — the OV2-specific work is mostly glue.

Modifications

Model loading

  • `mapping_func.create_model_from_pretrained` now forwards
    `trust_remote_code` to `AutoConfig.from_pretrained` and falls
    back to the AutoModelFor* class declared in the checkpoint's
    `auto_map` when the config class is not registered in any HF
    model mapping (the case for all trust_remote_code checkpoints).
  • `runner._build_model` passes `extra_kwargs.trust_remote_code`
    through.
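
The `auto_map` fallback described above can be sketched roughly as follows. This is a minimal illustration, not the code in `mapping_func`: it assumes `auto_map` is a plain dict from auto-class name to a `module.ClassName` string, and both `pick_auto_class_name` and the preference order are hypothetical. The class paths in the example dict are illustrative values, not the ones OV2 declares.

```python
from typing import Mapping, Optional, Sequence

# Hypothetical preference order; the real loader may consult a different list.
PREFERRED_AUTO_CLASSES: Sequence[str] = (
    "AutoModelForCausalLM",
    "AutoModelForImageTextToText",
    "AutoModelForVision2Seq",
    "AutoModel",
)


def pick_auto_class_name(
    auto_map: Optional[Mapping[str, str]],
    preferred: Sequence[str] = PREFERRED_AUTO_CLASSES,
) -> Optional[str]:
    """Return the first AutoModel* key declared in ``auto_map``, or None."""
    if not auto_map:
        return None
    for name in preferred:
        if name in auto_map:
            return name
    # Accept any other AutoModel* entry; this skips AutoConfig, AutoProcessor, etc.
    for name in auto_map:
        if name.startswith("AutoModel"):
            return name
    return None


# Illustrative auto_map; the module and class names are made up for the example.
example_auto_map = {
    "AutoConfig": "configuration_example.ExampleConfig",
    "AutoModelForCausalLM": "modeling_example.ExampleForConditionalGeneration",
}
print(pick_auto_class_name(example_auto_map))
```

The returned name could then be resolved with `getattr(transformers, name)` and loaded through `from_pretrained(..., trust_remote_code=True)`.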

Monkey patch infra

  • `MonkeyPatcher.apply_monkey_patch[_to_instance]` now skips
    patch_types that are not registered for a given model_type instead
    of raising `KeyError`. This lets us auto-append new patch_types
    from the runner without breaking existing models.
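
The skip-instead-of-raise behaviour can be illustrated with a small stand-in registry. This is a sketch under assumptions: the two-level dict layout (`model_type` to `patch_type` to function) follows the commit message's mention of an "inner dict", but `register`, `apply_patches`, and the toy model are hypothetical names, not the `MonkeyPatcher` API.

```python
import logging
from typing import Callable, Dict, Iterable, List

logger = logging.getLogger(__name__)

# Hypothetical registry layout: model_type -> patch_type -> patch function.
PatchFn = Callable[[dict], None]
_REGISTRY: Dict[str, Dict[str, PatchFn]] = {}


def register(model_type: str, patch_type: str, fn: PatchFn) -> None:
    """Register a patch function for one (model_type, patch_type) pair."""
    _REGISTRY.setdefault(model_type, {})[patch_type] = fn


def apply_patches(model: dict, model_type: str, patch_types: Iterable[str]) -> List[str]:
    """Apply registered patches in order; log and skip unregistered ones."""
    applied: List[str] = []
    available = _REGISTRY.get(model_type, {})
    for patch_type in patch_types:
        fn = available.get(patch_type)
        if fn is None:
            # Previously this lookup would have raised KeyError.
            logger.info("No %r patch registered for %r; skipping", patch_type, model_type)
            continue
        fn(model)
        applied.append(patch_type)
    return applied


# A toy model type that only registers "liger"; a dict stands in for the model.
register("toy_model", "liger", lambda model: model.update(liger=True))
toy = {}
print(apply_patches(toy, "toy_model", ["liger", "rmpad"]))
```

With this shape, a runner can append `rmpad` unconditionally and models without an `rmpad` entry are left untouched.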

OV2 monkey patches (instance-level only, because the classes are loaded via auto_map)

  • New `models/llava_onevision2/` with two patch_types:
    • `liger`: RoPE, RMSNorm, SwiGLU on inner Qwen3 layers; LayerNorm
      on OV2 vision encoder; binds OV2 `causal_lm_forward` with
      `loss_fn="lce"`, `use_rmpad=False`.
    • `rmpad`: class-level qwen3 attention/decoder/model_forward
      rmpad patches, OV2 outer `model_forward` binding, and a rebind
      of `causal_lm_forward` with `use_rmpad=True` that detects the
      liger-bound `loss_fn` (if any) to preserve fused LCE.
  • Runner auto-appends `rmpad` patch_type when
    `trainer_args.use_rmpad=True`.
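  • Stacking matrix:

    | use_liger_kernel | use_rmpad | Behaviour                              |
    | ---------------- | --------- | -------------------------------------- |
    | True             | True      | rmpad + fused LCE (historical default) |
    | True             | False     | fused LCE, no unpad                    |
    | False            | True      | unpad + standard CE                    |
    | False            | False     | stock HF forward                       |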
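
The stacking matrix above can be read as a small lookup plus an ordering rule. The sketch below is illustrative only: the flag-to-row assignment is inferred from the behaviour descriptions, and `resolve_patch_types` is a hypothetical helper (the description only states that the runner appends `rmpad` when `trainer_args.use_rmpad=True`; how `liger` enters the list is assumed here).

```python
from typing import Dict, List, Tuple

# (use_liger_kernel, use_rmpad) -> behaviour, as inferred from the matrix.
BEHAVIOUR: Dict[Tuple[bool, bool], str] = {
    (True, True): "rmpad + fused LCE",
    (True, False): "fused LCE, no unpad",
    (False, True): "unpad + standard CE",
    (False, False): "stock HF forward",
}


def resolve_patch_types(use_liger_kernel: bool, use_rmpad: bool) -> List[str]:
    """Return patch types in application order.

    "rmpad" goes last so that its rebind of causal_lm_forward can see a
    loss_fn that the liger patch has already bound.
    """
    patch_types: List[str] = []
    if use_liger_kernel:
        patch_types.append("liger")
    if use_rmpad:
        patch_types.append("rmpad")
    return patch_types


for flags, behaviour in BEHAVIOUR.items():
    print(flags, resolve_patch_types(*flags), behaviour)
```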

Shared LM loss helper

  • New `models/common_ops/loss.py::compute_lm_loss` factors out the
    reusable next-token loss path (LCE / CE, optional rmpad seq-wise
    shifting, Ulysses SP gather). OV2 uses it; future RFCs that split
    liger/rmpad on other models can reuse it too.
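
One plausible reading of "rmpad seq-wise shifting" is that next-token targets are shifted inside each packed sequence rather than across the whole token stream. The sketch below illustrates that idea in plain Python; it is an assumption about the technique, not the body of `compute_lm_loss`, and `shift_labels_seqwise` and the `cu_seqlens` argument are hypothetical names.

```python
from typing import List, Sequence

IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss


def shift_labels_seqwise(labels: Sequence[int], cu_seqlens: Sequence[int]) -> List[int]:
    """Build next-token targets for a packed (unpadded) token stream.

    ``cu_seqlens`` holds cumulative sequence boundaries, e.g. [0, 3, 5] for two
    sequences of length 3 and 2. Position i is trained to predict token i + 1
    of the same sequence; the last position of every sequence is masked so it
    never predicts the first token of the next sequence.
    """
    shifted = [IGNORE_INDEX] * len(labels)
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        for i in range(start, end - 1):
            shifted[i] = labels[i + 1]
    return shifted


print(shift_labels_seqwise([10, 11, 12, 20, 21], [0, 3, 5]))
# [11, 12, -100, 21, -100]
```

A naive global shift would instead make token `12` predict token `20`, leaking a target across the sequence boundary.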

Data processor

  • New `datasets/processor/llava_onevision2_processor.py` inheriting
    `Qwen3_VLDataProcessor`. Loads the OV2 `AutoProcessor` with
    `trust_remote_code=True`, rewrites
    `<vision_start><video_pad><vision_end>` into per-frame
    `<X.X seconds><vision_start><image_pad>*n<vision_end>` blocks,
    aliases video tensors into the image path (expands
    `video_grid_thw[T,H,W]` into T rows of `[1,H,W]`), builds
    block-layout `patch_positions`, and normalizes `qwen_vl_utils`
    CHW float frames to HWC uint8 so OV2's video processor's
    `PIL.Image.fromarray` branch works.
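
Two of the rewrites above, the grid expansion and the per-frame placeholder blocks, can be sketched in a few lines. This is an illustration under assumptions: the placeholder strings are copied as written in this description (the real tokenizer may spell its special tokens differently), the timestamps and the number of `<image_pad>` tokens per frame are passed in rather than derived, and both function names are hypothetical.

```python
from typing import List, Sequence

# Placeholder as written in the PR description; actual token spelling may differ.
VIDEO_BLOCK = "<vision_start><video_pad><vision_end>"


def expand_video_grid(video_grid_thw: Sequence[Sequence[int]]) -> List[List[int]]:
    """Turn each [T, H, W] video row into T image rows of [1, H, W]."""
    rows: List[List[int]] = []
    for t, h, w in video_grid_thw:
        rows.extend([1, h, w] for _ in range(t))
    return rows


def rewrite_video_block(text: str, timestamps: Sequence[float], pads_per_frame: int) -> str:
    """Replace the first video placeholder with one timestamped image block per frame."""
    frames = "".join(
        f"<{ts:.1f} seconds><vision_start>{'<image_pad>' * pads_per_frame}<vision_end>"
        for ts in timestamps
    )
    return text.replace(VIDEO_BLOCK, frames, 1)


print(expand_video_grid([[2, 4, 6]]))
print(rewrite_video_block("Describe: " + VIDEO_BLOCK, [0.0, 0.5], 2))
```

After this rewrite the video looks like a sequence of timestamped still images, which matches the description of video tensors being aliased into the image path.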

Misc

  • `models/utils.py`: route `llava_onevision2` flops through the
    qwen2 estimator using `config.text_config` (its LM is Qwen3).
  • New `examples/llava_onevision2/{example.yaml,run.sh}` end-to-end
    training scaffold.
  • New `docs/models/llava_onevision2.md` describing the patch
    composition + data processor; registered in
    `docs/models/index.rst` (also picks up the existing-but-unlinked
    `llava_onevision1_5.md`).

Testing

  • Manual run on a single-node 4-GPU box: torchrun + FSDP2 + rmpad +
    liger + sequence packing kicks off training successfully against
    `LLaVA-Video-178K` parquet shards; loss decreases over a short
    smoke run with no NaNs.
  • All pre-commit hooks (black + isort) pass.

Commit log

  • `feat(model_loading): forward trust_remote_code for auto_map checkpoints`
  • `feat(monkey_patch): skip unregistered patch_types gracefully`
  • `feat: add LLaVA-OneVision2 model + data processor`
  • `docs: add LLaVA-OneVision2 example config and model guide`

Checklist

  • Follow commit message convention
  • Run `pre-commit run --all-files` and ensure all checks pass
  • Format with `black` (line-length=120) and `isort`
  • Add unit tests for new functionality (manual smoke test only for now)
  • Update documentation

kcz358 added 4 commits May 13, 2026 20:39

AutoConfig.from_pretrained was called without trust_remote_code so any
checkpoint that registers its config via auto_map raised before model class
resolution. Even after fixing that, the resolved config class is not in any
HF model mapping, so add an auto_map fallback that picks the AutoModelFor*
class declared by the checkpoint.

When a model_type registers only a subset of patch_types, looking up the
missing entry in the inner dict raised KeyError. Log and return instead so
multi-patch invocations can probe different models without raising.

OV2 ships its modeling code via auto_map / trust_remote_code, so all
patches are applied at the model instance level. Inner LM is stock
Qwen3, so the qwen3 class-level liger/rmpad patches do most of the work;
the OV2-specific bits are:

- llava_onevision2/monkey_patch.py registers two patch_types:
  - liger: bind causal_lm_forward with loss_fn=lce + module-level liger
    swaps (rms_norm/swiglu in Qwen3 layers, layer_norm in OV2 vision)
  - rmpad: bind OV2 model_forward + rebind causal_lm_forward with
    use_rmpad=True, preserving liger's loss_fn if already bound
- llava_onevision2/llava_onevision2_ops.py: replacement forwards for the
  outer LlavaOnevision2Model and ForConditionalGeneration; loss is
  delegated to the shared compute_lm_loss helper
- common_ops/loss.py: shared next-token loss with optional fused LCE,
  rmpad seq-wise shifting, and Ulysses SP gather
- datasets/processor/llava_onevision2_processor.py: inherits
  Qwen3_VLDataProcessor; rewrites the chat template's
  <vision_start><video_pad><vision_end> into per-frame
  <X.X seconds><vision_start><image_pad>*n<vision_end> blocks and
  aliases video tensors into the image path. Includes a normalizer for
  CHW float frames from qwen_vl_utils so they hit OV2 video processor's
  list[np.ndarray] branch.
- runner.py: auto-append 'rmpad' patch_type when trainer_args.use_rmpad
- models/utils.py: route llava_onevision2 flops through the qwen2 path
  using text_config (its LM is Qwen3)

Adds examples/llava_onevision2/{example.yaml,run.sh} and
docs/models/llava_onevision2.md describing the auto_map / trust_remote_code
loading flow, the split liger / rmpad monkey-patch composition, and the
data processor's video-as-multi-image rewrite. Also registers the new doc
in the models index (and adds the missing llava_onevision1_5 entry that
was already present as a file).

@yiyexy merged commit f1279f0 into main May 14, 2026
3 checks passed
@kcz358 deleted the feat/llava-onevision2 branch May 14, 2026 11:11