feat(vision): add Vision DP for parallel ViT computation across Ulysses SP ranks #357
Open
aoshen524 wants to merge 2 commits into alibaba:main from
Conversation
feat(vision): add Vision DP for parallel ViT computation across Ulysses SP ranks

Distribute whole images across Ulysses SP ranks for parallelized ViT computation, reducing ViT peak memory by ~sp_size x (e.g. SP=4 -> ~4x ViT memory reduction).

Key changes:
- Add roll/utils/context_parallel/vision_dp.py with image distribution utilities, a GatherVisionEmbeddings autograd function, and a model-agnostic VisionTransformer wrapper
- Add apply_vision_dp_patch() in monkey_patch.py for the Qwen2-VL, Qwen2.5-VL, Qwen3-VL, and Qwen3-VL-MoE VisionTransformer classes
- Integrate into the DeepSpeed strategy (both inference and training workers)
- Add 17 unit tests covering all utility functions, edge cases, and integration workflows

Ported from verl (verl-project/verl#5230).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
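The gather-with-gradient-scaling pattern that the GatherVisionEmbeddings function described above relies on can be sketched in a single process. This is an illustration, not the PR's implementation: the real forward would all_gather local embeddings across SP ranks (shown here as a copy), and scaling the backward gradient by dp_size is an assumption inferred from the commit description.

```python
import torch


class GatherVisionEmbeddingsSketch(torch.autograd.Function):
    """Single-process sketch of gathering vision embeddings with gradient scaling."""

    @staticmethod
    def forward(ctx, local_embeds: torch.Tensor, dp_size: int) -> torch.Tensor:
        ctx.dp_size = dp_size
        # Real patch: all_gather each rank's embeddings and concatenate.
        # With one process, the gather degenerates to a copy of the input.
        return local_embeds.clone()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Each rank computed only ~1/dp_size of the images, so the incoming
        # gradient is scaled by dp_size to keep magnitudes consistent after
        # the later reduction (assumed behavior, inferred from the PR text).
        return grad_output * ctx.dp_size, None


x = torch.ones(3, requires_grad=True)
y = GatherVisionEmbeddingsSketch.apply(x, 4)
y.sum().backward()
print(x.grad)  # each element scaled by dp_size = 4
```

The autograd.Function boundary matters here: the collective in forward is not differentiable by itself, so the backward rule (scale and route the gradient back to the local shard) has to be written by hand.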
Integrate upstream hf_flash_attention_patch for transformers>=4.53.0 alongside Vision DP patches.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
- Each SP rank processes 1/sp_size of the images, reducing ViT peak memory by ~sp_size x (e.g. SP=4 → ~4x ViT memory reduction)
- When ulysses_size > 1, each rank processes a subset of images independently, then all-gathers embeddings once at the end
- The create_dp_vision_forward() wrapper supports any VisionTransformer with a forward(self, hidden_states, grid_thw) signature
- GatherVisionEmbeddings is a custom autograd function with proper gradient scaling for distributed training compatibility

Why this matters
In VLM RL training with Ulysses SP, the ViT (VisionTransformer) is a major memory bottleneck. Text SP splits the sequence across ranks at each attention layer, but the ViT runs on the full set of images on every rank — meaning ViT memory usage is completely unaffected by SP. For scenarios with many images (e.g. multi-turn GUI agent training with screenshots), ViT activation memory can dominate.
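A back-of-envelope calculation makes the claim concrete. All constants below are made up for illustration; only the linear relationship between patches processed per rank and activation memory matters.

```python
# Illustrative constants only: ViT activation memory grows roughly linearly
# with the number of image patches a rank must run through the ViT.
bytes_per_patch = 5_000      # hypothetical activation bytes per patch
total_patches = 400_000      # e.g. many screenshots in one rollout batch
sp_size = 4

# Without Vision DP every rank pays the full cost; with it, ~1/sp_size each.
without_vision_dp = total_patches * bytes_per_patch
with_vision_dp = (total_patches // sp_size) * bytes_per_patch

print(without_vision_dp // with_vision_dp)  # 4
```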
Vision DP solves this by distributing images at the ViT level:
- Before: every rank runs the ViT over all images → ViT memory = O(total_images)
- After: each of N ranks processes total_images/N images → ViT memory = O(total_images/N)

Key design choices
- Images are distributed whole, never split across ranks, which keeps per-image cu_seqlens tracking intact
- Gradients are scaled by dp_size to compensate for partial image processing before reduction
- Handles the (embeddings, deepstack_embeddings) tuple output from the Qwen3-VL VisionModel

Files changed
- roll/utils/context_parallel/vision_dp.py: image distribution utilities, GatherVisionEmbeddings, and the model-agnostic wrapper
- roll/utils/context_parallel/monkey_patch.py: apply_vision_dp_patch()/unapply_vision_dp_patch() for Qwen2/2.5/3-VL VisionTransformers
- roll/utils/context_parallel/__init__.py: exports for the new utilities
- roll/distributed/strategy/deepspeed_strategy.py: applies apply_vision_dp_patch() in both inference and training workers when ulysses_size > 1
- tests/utils/test_vision_dp_on_cpu.py: CPU-only unit tests

Test plan
- 17 unit tests covering get_image_patch_counts, assign_images_to_dp_ranks, and prepare_local_vision_inputs, plus edge cases and integration workflows

🤖 Generated with Claude Code
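The two distribution utilities named in the test plan can be sketched as follows. The function names mirror the PR's get_image_patch_counts and assign_images_to_dp_ranks, but the signatures, the t*h*w patch-count convention (taken from how Qwen2-VL-style models describe images via grid_thw), and the greedy balancing strategy are assumptions for illustration, not the actual implementation.

```python
def get_image_patch_counts(grid_thw):
    """Patches per image from (t, h, w) grid rows: t * h * w each."""
    return [t * h * w for t, h, w in grid_thw]


def assign_images_to_dp_ranks(patch_counts, dp_size):
    """Greedily assign whole images to the lightest-loaded rank so far.

    Images are never split, so each rank's cu_seqlens stay per-image exact.
    Returns one list of image indices per DP rank.
    """
    assignments = [[] for _ in range(dp_size)]
    loads = [0] * dp_size
    # Place the largest images first for a tighter balance.
    for idx in sorted(range(len(patch_counts)), key=lambda i: -patch_counts[i]):
        rank = loads.index(min(loads))
        assignments[rank].append(idx)
        loads[rank] += patch_counts[idx]
    return assignments


# Four images: 1x16x16, 2x8x8, 1x12x12, 1x10x10 patch grids.
counts = get_image_patch_counts([(1, 16, 16), (2, 8, 8), (1, 12, 12), (1, 10, 10)])
ranks = assign_images_to_dp_ranks(counts, dp_size=2)
```

A prepare_local_vision_inputs step would then slice each rank's pixel values and grid_thw rows down to its assigned indices before calling the ViT, and the gathered embeddings would be reordered back to the original image order.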