4 changes: 3 additions & 1 deletion roll/distributed/strategy/deepspeed_strategy.py
@@ -23,7 +23,7 @@
 from roll.third_party.deepspeed.model_update import DeepSpeedWeightUpdater
 from roll.third_party.deepspeed.offload_states_patch import bind_deepspeed_offload_states_func
 from roll.utils.collective import collective
-from roll.utils.context_parallel import get_ulysses_group, set_upg_manager
+from roll.utils.context_parallel import apply_vision_dp_patch, get_ulysses_group, set_upg_manager
 from roll.utils.deepspeed_utils import get_optimizer_grouped_parameters
 from roll.utils.functionals import append_to_dict, entropy_from_logits, log_probs_from_logits
 from roll.utils.constants import IGNORE_INDEX
@@ -69,6 +69,7 @@ def initialize(self, model_provider):
         if (cp_size := self.worker_config.model_args.ulysses_size) > 1:
             if current_platform.apply_ulysses_patch() is not None:
                 set_upg_manager(ulysses_size=cp_size, rank=global_rank, world_size=world_size)
+                apply_vision_dp_patch()
         else:
             cp_size = 1
 
@@ -332,6 +333,7 @@ def initialize(self, model_provider):
         if (cp_size := self.worker_config.model_args.ulysses_size) > 1:
             current_platform.apply_ulysses_patch()
             set_upg_manager(ulysses_size=cp_size, rank=global_rank, world_size=world_size)
+            apply_vision_dp_patch()
 
         self.worker.rank_info.dp_rank = global_rank // cp_size
         self.worker.rank_info.dp_size = world_size // cp_size
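
The rank bookkeeping in the last hunk is plain integer arithmetic: consecutive ranks share one Ulysses SP group, and the data-parallel coordinates follow by division. A worked example (standalone Python; variable names mirror the diff, nothing here is ROLL API):

# Worked example of the dp_rank / dp_size arithmetic from the hunk above,
# for world_size=8 and ulysses_size (cp_size)=4.
world_size, cp_size = 8, 4
dp_size = world_size // cp_size            # 2 data-parallel replicas
for global_rank in range(world_size):
    dp_rank = global_rank // cp_size       # ranks 0-3 -> 0, ranks 4-7 -> 1
    print(f"rank {global_rank}: dp_rank={dp_rank}, dp_size={dp_size}")
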
17 changes: 15 additions & 2 deletions roll/utils/context_parallel/__init__.py
@@ -1,4 +1,17 @@
 from roll.utils.context_parallel.globals import get_ulysses_group, set_upg_manager
-from roll.utils.context_parallel.monkey_patch import apply_ulysses_patch, unapply_ulysses_patch
+from roll.utils.context_parallel.monkey_patch import (
+    apply_ulysses_patch,
+    apply_vision_dp_patch,
+    unapply_ulysses_patch,
+    unapply_vision_dp_patch,
+)
 
-__all__ = ["set_upg_manager", "get_ulysses_group", "apply_ulysses_patch", "unapply_ulysses_patch"]
+__all__ = [
+    "set_upg_manager",
+    "get_ulysses_group",
+    "apply_ulysses_patch",
+    "apply_vision_dp_patch",
+    "unapply_ulysses_patch",
+    "unapply_vision_dp_patch",
+]
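
With the expanded __all__, the vision patch pair is importable from the package root. A minimal usage sketch (the try/finally pairing is a suggested convention, not code from this PR):

# Sketch: apply the Vision DP patch around model construction and training,
# then restore the stock forwards. Assumes apply/unapply are global toggles,
# mirroring the existing apply_ulysses_patch/unapply_ulysses_patch pair.
from roll.utils.context_parallel import apply_vision_dp_patch, unapply_vision_dp_patch

apply_vision_dp_patch()
try:
    ...  # build the Qwen-VL model and train; patched vision forwards are active
finally:
    unapply_vision_dp_patch()  # put back the original transformers forwards
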
98 changes: 98 additions & 0 deletions roll/utils/context_parallel/monkey_patch.py
@@ -13,6 +13,9 @@
 else:
     old_update_causal_mask = None
 
+# Store original vision forwards for unapply
+_original_vision_forwards = {}
+
 
 def apply_ulysses_patch():
     from .ulysses_attention import _flash_attention_forward, _update_causal_mask
@@ -35,6 +38,100 @@ def apply_ulysses_patch():
     return patch_info
 
 
+def apply_vision_dp_patch():
+    """Patch VisionTransformer.forward for Vision Data Parallel.
+
+    Distributes whole images across Ulysses SP ranks for parallelized ViT computation.
+    Each rank processes 1/sp_size of images, then all-gathers embeddings.
+
+    This reduces ViT peak memory by ~sp_size x (e.g. SP=4 -> ~4x reduction).
+    """
+    from .vision_dp import create_dp_vision_forward
+
+    # Patch Qwen2-VL VisionTransformer
+    try:
+        from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
+
+        original = Qwen2VisionTransformerPretrainedModel.forward
+        _original_vision_forwards["qwen2_vl"] = original
+        Qwen2VisionTransformerPretrainedModel.forward = create_dp_vision_forward(original)
+        logger.info("Monkey patch Qwen2VisionTransformerPretrainedModel.forward for Vision DP")
+    except ImportError as e:
+        logger.debug(f"Qwen2-VL not available for Vision DP patch: {e}")
+
+    # Patch Qwen2.5-VL VisionTransformer
+    try:
+        from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+            Qwen2_5_VisionTransformerPretrainedModel,
+        )
+
+        original = Qwen2_5_VisionTransformerPretrainedModel.forward
+        _original_vision_forwards["qwen2_5_vl"] = original
+        Qwen2_5_VisionTransformerPretrainedModel.forward = create_dp_vision_forward(original)
+        logger.info("Monkey patch Qwen2_5_VisionTransformerPretrainedModel.forward for Vision DP")
+    except ImportError as e:
+        logger.debug(f"Qwen2.5-VL not available for Vision DP patch: {e}")
+
+    # Patch Qwen3-VL VisionModel
+    try:
+        from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel
+
+        original = Qwen3VLVisionModel.forward
+        _original_vision_forwards["qwen3_vl"] = original
+        Qwen3VLVisionModel.forward = create_dp_vision_forward(original)
+        logger.info("Monkey patch Qwen3VLVisionModel.forward for Vision DP")
+    except ImportError as e:
+        logger.debug(f"Qwen3-VL not available for Vision DP patch: {e}")
+
+    # Patch Qwen3-VL-MoE VisionModel
+    try:
+        from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeVisionModel
+
+        original = Qwen3VLMoeVisionModel.forward
+        _original_vision_forwards["qwen3_vl_moe"] = original
+        Qwen3VLMoeVisionModel.forward = create_dp_vision_forward(original)
+        logger.info("Monkey patch Qwen3VLMoeVisionModel.forward for Vision DP")
+    except ImportError as e:
+        logger.debug(f"Qwen3-VL-MoE not available for Vision DP patch: {e}")
+
+
+def unapply_vision_dp_patch():
+    """Restore original VisionTransformer.forward methods."""
+    if "qwen2_vl" in _original_vision_forwards:
+        try:
+            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
+
+            Qwen2VisionTransformerPretrainedModel.forward = _original_vision_forwards.pop("qwen2_vl")
+        except ImportError:
+            pass
+
+    if "qwen2_5_vl" in _original_vision_forwards:
+        try:
+            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+                Qwen2_5_VisionTransformerPretrainedModel,
+            )
+
+            Qwen2_5_VisionTransformerPretrainedModel.forward = _original_vision_forwards.pop("qwen2_5_vl")
+        except ImportError:
+            pass
+
+    if "qwen3_vl" in _original_vision_forwards:
+        try:
+            from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel
+
+            Qwen3VLVisionModel.forward = _original_vision_forwards.pop("qwen3_vl")
+        except ImportError:
+            pass
+
+    if "qwen3_vl_moe" in _original_vision_forwards:
+        try:
+            from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeVisionModel
+
+            Qwen3VLMoeVisionModel.forward = _original_vision_forwards.pop("qwen3_vl_moe")
+        except ImportError:
+            pass
+
+
 def unapply_ulysses_patch():
     global old_flash_attention_forward, old_update_causal_mask
     ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = old_flash_attention_forward
@@ -47,3 +144,4 @@ def unapply_ulysses_patch():
         unapply_hf_flash_attention_ulysses_patch()
     except Exception:
         pass
+    unapply_vision_dp_patch()
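
The actual sharding logic lives in roll/utils/context_parallel/vision_dp.py, which this diff does not show. As a rough mental model only, a create_dp_vision_forward wrapper consistent with the docstring could look like the sketch below; the forward signature (hidden_states, grid_thw) follows the Qwen2-VL vision tower in transformers, while the contiguous image split, the zero-argument get_ulysses_group() call, and the object-based all-gather are illustrative assumptions, not the PR's implementation.

# Hypothetical sketch of a Vision DP forward wrapper (NOT the PR's vision_dp.py).
import torch
import torch.distributed as dist

from roll.utils.context_parallel import get_ulysses_group


def create_dp_vision_forward(original_forward):
    def dp_vision_forward(self, hidden_states, grid_thw, **kwargs):
        group = get_ulysses_group()
        sp_size = dist.get_world_size(group)
        if sp_size == 1:
            return original_forward(self, hidden_states, grid_thw, **kwargs)
        rank = dist.get_rank(group)

        # Each grid_thw row is (t, h, w); its product is that image's patch count.
        patches = (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).tolist()
        offsets = [0]
        for p in patches:
            offsets.append(offsets[-1] + p)

        # Assign whole images to ranks in contiguous blocks, so no image is
        # ever split across ranks. (A real implementation must also handle
        # ranks that receive no images, e.g. with a dummy forward.)
        n_images = grid_thw.size(0)
        per_rank = (n_images + sp_size - 1) // sp_size
        start = min(rank * per_rank, n_images)
        end = min(start + per_rank, n_images)

        local_out = original_forward(
            self, hidden_states[offsets[start]:offsets[end]], grid_thw[start:end], **kwargs
        )

        # Simple but slow gather of variable-length embeddings; production code
        # would likely pad and use dist.all_gather on-device instead.
        bucket = [None] * sp_size
        dist.all_gather_object(bucket, local_out.cpu(), group=group)
        return torch.cat([t.to(local_out.device) for t in bucket], dim=0)

    return dp_vision_forward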