diff --git a/roll/distributed/strategy/deepspeed_strategy.py b/roll/distributed/strategy/deepspeed_strategy.py
index 58b7e1b4..f240ecfe 100644
--- a/roll/distributed/strategy/deepspeed_strategy.py
+++ b/roll/distributed/strategy/deepspeed_strategy.py
@@ -23,7 +23,7 @@
 from roll.third_party.deepspeed.model_update import DeepSpeedWeightUpdater
 from roll.third_party.deepspeed.offload_states_patch import bind_deepspeed_offload_states_func
 from roll.utils.collective import collective
-from roll.utils.context_parallel import get_ulysses_group, set_upg_manager
+from roll.utils.context_parallel import apply_vision_dp_patch, get_ulysses_group, set_upg_manager
 from roll.utils.deepspeed_utils import get_optimizer_grouped_parameters
 from roll.utils.functionals import append_to_dict, entropy_from_logits, log_probs_from_logits
 from roll.utils.constants import IGNORE_INDEX
@@ -69,6 +69,7 @@ def initialize(self, model_provider):
         if (cp_size := self.worker_config.model_args.ulysses_size) > 1:
             if current_platform.apply_ulysses_patch() is not None:
                 set_upg_manager(ulysses_size=cp_size, rank=global_rank, world_size=world_size)
+                apply_vision_dp_patch()
             else:
                 cp_size = 1
@@ -332,6 +333,7 @@ def initialize(self, model_provider):
         if (cp_size := self.worker_config.model_args.ulysses_size) > 1:
             current_platform.apply_ulysses_patch()
             set_upg_manager(ulysses_size=cp_size, rank=global_rank, world_size=world_size)
+            apply_vision_dp_patch()
         self.worker.rank_info.dp_rank = global_rank // cp_size
         self.worker.rank_info.dp_size = world_size // cp_size
diff --git a/roll/utils/context_parallel/__init__.py b/roll/utils/context_parallel/__init__.py
index 8112b8d2..cd3f0101 100644
--- a/roll/utils/context_parallel/__init__.py
+++ b/roll/utils/context_parallel/__init__.py
@@ -1,4 +1,17 @@
 from roll.utils.context_parallel.globals import get_ulysses_group, set_upg_manager
-from roll.utils.context_parallel.monkey_patch import apply_ulysses_patch, unapply_ulysses_patch
+from roll.utils.context_parallel.monkey_patch import (
+    apply_ulysses_patch,
+    apply_vision_dp_patch,
+    unapply_ulysses_patch,
+    unapply_vision_dp_patch,
+)
 
-__all__ = ["set_upg_manager", "get_ulysses_group", "apply_ulysses_patch", "unapply_ulysses_patch"]
+
+__all__ = [
+    "set_upg_manager",
+    "get_ulysses_group",
+    "apply_ulysses_patch",
+    "apply_vision_dp_patch",
+    "unapply_ulysses_patch",
+    "unapply_vision_dp_patch",
+]
diff --git a/roll/utils/context_parallel/monkey_patch.py b/roll/utils/context_parallel/monkey_patch.py
index a98ec66d..bf668139 100644
--- a/roll/utils/context_parallel/monkey_patch.py
+++ b/roll/utils/context_parallel/monkey_patch.py
@@ -13,6 +13,9 @@
 else:
     old_update_causal_mask = None
 
+# Store original vision forwards for unapply
+_original_vision_forwards = {}
+
 
 def apply_ulysses_patch():
     from .ulysses_attention import _flash_attention_forward, _update_causal_mask
@@ -35,6 +38,100 @@ def apply_ulysses_patch():
     return patch_info
+
+
+def apply_vision_dp_patch():
+    """Patch VisionTransformer.forward for Vision Data Parallel.
+
+    Distributes whole images across Ulysses SP ranks for parallelized ViT computation.
+    Each rank processes 1/sp_size of images, then all-gathers embeddings.
+
+    This reduces ViT peak memory by ~sp_size x (e.g. SP=4 -> ~4x reduction).
+    """
+    from .vision_dp import create_dp_vision_forward
+
+    # Patch Qwen2-VL VisionTransformer
+    try:
+        from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
+
+        original = Qwen2VisionTransformerPretrainedModel.forward
+        _original_vision_forwards["qwen2_vl"] = original
+        Qwen2VisionTransformerPretrainedModel.forward = create_dp_vision_forward(original)
+        logger.info("Monkey patch Qwen2VisionTransformerPretrainedModel.forward for Vision DP")
+    except ImportError as e:
+        logger.debug(f"Qwen2-VL not available for Vision DP patch: {e}")
+
+    # Patch Qwen2.5-VL VisionTransformer
+    try:
+        from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+            Qwen2_5_VisionTransformerPretrainedModel,
+        )
+
+        original = Qwen2_5_VisionTransformerPretrainedModel.forward
+        _original_vision_forwards["qwen2_5_vl"] = original
+        Qwen2_5_VisionTransformerPretrainedModel.forward = create_dp_vision_forward(original)
+        logger.info("Monkey patch Qwen2_5_VisionTransformerPretrainedModel.forward for Vision DP")
+    except ImportError as e:
+        logger.debug(f"Qwen2.5-VL not available for Vision DP patch: {e}")
+
+    # Patch Qwen3-VL VisionModel
+    try:
+        from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel
+
+        original = Qwen3VLVisionModel.forward
+        _original_vision_forwards["qwen3_vl"] = original
+        Qwen3VLVisionModel.forward = create_dp_vision_forward(original)
+        logger.info("Monkey patch Qwen3VLVisionModel.forward for Vision DP")
+    except ImportError as e:
+        logger.debug(f"Qwen3-VL not available for Vision DP patch: {e}")
+
+    # Patch Qwen3-VL-MoE VisionModel
+    try:
+        from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeVisionModel
+
+        original = Qwen3VLMoeVisionModel.forward
+        _original_vision_forwards["qwen3_vl_moe"] = original
+        Qwen3VLMoeVisionModel.forward = create_dp_vision_forward(original)
+        logger.info("Monkey patch Qwen3VLMoeVisionModel.forward for Vision DP")
+    except ImportError as e:
+        logger.debug(f"Qwen3-VL-MoE not available for Vision DP patch: {e}")
+
+
+def unapply_vision_dp_patch():
+    """Restore original VisionTransformer.forward methods."""
+    if "qwen2_vl" in _original_vision_forwards:
+        try:
+            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
+
+            Qwen2VisionTransformerPretrainedModel.forward = _original_vision_forwards.pop("qwen2_vl")
+        except ImportError:
+            pass
+
+    if "qwen2_5_vl" in _original_vision_forwards:
+        try:
+            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+                Qwen2_5_VisionTransformerPretrainedModel,
+            )
+
+            Qwen2_5_VisionTransformerPretrainedModel.forward = _original_vision_forwards.pop("qwen2_5_vl")
+        except ImportError:
+            pass
+
+    if "qwen3_vl" in _original_vision_forwards:
+        try:
+            from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel
+
+            Qwen3VLVisionModel.forward = _original_vision_forwards.pop("qwen3_vl")
+        except ImportError:
+            pass
+
+    if "qwen3_vl_moe" in _original_vision_forwards:
+        try:
+            from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeVisionModel
+
+            Qwen3VLMoeVisionModel.forward = _original_vision_forwards.pop("qwen3_vl_moe")
+        except ImportError:
+            pass
+
+
 def unapply_ulysses_patch():
     global old_flash_attention_forward, old_update_causal_mask
     ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = old_flash_attention_forward
@@ -47,3 +144,4 @@ def unapply_ulysses_patch():
         unapply_hf_flash_attention_ulysses_patch()
     except Exception:
         pass
+    unapply_vision_dp_patch()
diff --git a/roll/utils/context_parallel/vision_dp.py b/roll/utils/context_parallel/vision_dp.py
new file mode 100644
index 00000000..0f0c0116
--- /dev/null
+++ b/roll/utils/context_parallel/vision_dp.py
@@ -0,0 +1,352 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2025 Alibaba Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Vision Data Parallel utilities for distributing ViT computation across Ulysses SP ranks.
+
+Ported from verl (https://github.com/verl-project/verl/pull/5230).
+
+Strategy: Distribute whole images across DP ranks, not patches within images.
+This avoids breaking cu_seqlens semantics while parallelizing ViT computation.
+
+Key difference from text SP:
+- Text SP: Split sequence within attention layers, all-to-all per layer
+- Vision DP: Split images across ranks, all_gather once at the end
+"""
+
+import torch
+import torch.distributed as dist
+from torch.autograd import Function
+
+from roll.utils.context_parallel.globals import get_ulysses_group, get_ulysses_size
+
+
+def get_image_patch_counts(grid_thw: torch.Tensor) -> list[int]:
+    """Compute number of patches per image from grid_thw.
+
+    Args:
+        grid_thw: Tensor of shape (num_images, 3) where each row is [t, h, w].
+
+    Returns:
+        List of patch counts per image.
+    """
+    if grid_thw.numel() == 0:
+        return []
+    return (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).tolist()
+
+
+def get_image_embedding_counts(grid_thw: torch.Tensor, spatial_merge_size: int = 1) -> list[int]:
+    """Compute number of embeddings per image after spatial merging.
+
+    Args:
+        grid_thw: Tensor of shape (num_images, 3) where each row is [t, h, w].
+        spatial_merge_size: Spatial merge factor (typically 2 for Qwen-VL).
+
+    Returns:
+        List of embedding counts per image.
+    """
+    if grid_thw.numel() == 0:
+        return []
+    if spatial_merge_size == 1:
+        return (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).tolist()
+    t = grid_thw[:, 0]
+    h = grid_thw[:, 1] // spatial_merge_size
+    w = grid_thw[:, 2] // spatial_merge_size
+    return (t * h * w).tolist()
+
+
+def assign_images_to_dp_ranks(
+    patch_counts: list[int],
+    dp_size: int,
+) -> tuple[list[list[int]], list[int]]:
+    """Assign whole images to DP ranks using contiguous distribution.
+
+    Rank 0 gets images [0, 1, ...], rank 1 gets next chunk, etc.
+    This ensures no reordering is needed after all-gather.
+
+    Args:
+        patch_counts: Number of patches per image.
+        dp_size: Number of DP ranks.
+
+    Returns:
+        Tuple of (image_assignments, rank_loads) where:
+        - image_assignments[rank] = list of image indices assigned to that rank
+        - rank_loads[rank] = total patches assigned to that rank
+    """
+    num_images = len(patch_counts)
+    if num_images == 0:
+        return [[] for _ in range(dp_size)], [0] * dp_size
+
+    image_assignments: list[list[int]] = [[] for _ in range(dp_size)]
+    rank_loads = [0] * dp_size
+
+    base_size = num_images // dp_size
+    remainder = num_images % dp_size
+
+    start = 0
+    for rank in range(dp_size):
+        chunk_size = base_size + (1 if rank < remainder else 0)
+        end = start + chunk_size
+        for img_idx in range(start, end):
+            image_assignments[rank].append(img_idx)
+            rank_loads[rank] += patch_counts[img_idx]
+        start = end
+
+    return image_assignments, rank_loads
+
+
+def prepare_local_vision_inputs(
+    pixel_values: torch.Tensor,
+    grid_thw: torch.Tensor,
+    image_assignments: list[list[int]],
+    dp_rank: int,
+) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
+    """Extract pixel values and grid_thw for this DP rank's assigned images.
+
+    Args:
+        pixel_values: All pixel values concatenated, shape (total_patches, dim).
+        grid_thw: Grid dimensions per image, shape (num_images, 3).
+        image_assignments: Per-rank image index assignments.
+        dp_rank: This rank's index in the DP group.
+
+    Returns:
+        Tuple of (local_pixel_values, local_grid_thw, local_indices).
+    """
+    local_indices = image_assignments[dp_rank]
+
+    if len(local_indices) == 0:
+        return (
+            torch.empty(
+                (0, pixel_values.shape[1]) if pixel_values.dim() > 1 else (0,),
+                dtype=pixel_values.dtype,
+                device=pixel_values.device,
+            ),
+            torch.empty((0, 3), dtype=grid_thw.dtype, device=grid_thw.device),
+            [],
+        )
+
+    patch_counts = (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).tolist()
+    cumsum = [0]
+    for c in patch_counts:
+        cumsum.append(cumsum[-1] + c)
+
+    local_patches = []
+    local_grids = []
+    for idx in local_indices:
+        start, end = cumsum[idx], cumsum[idx + 1]
+        local_patches.append(pixel_values[start:end])
+        local_grids.append(grid_thw[idx : idx + 1])
+
+    local_pixel_values = torch.cat(local_patches, dim=0)
+    local_grid_thw = torch.cat(local_grids, dim=0)
+
+    expected_patches = sum(patch_counts[idx] for idx in local_indices)
+    assert local_pixel_values.shape[0] == expected_patches
+
+    return local_pixel_values, local_grid_thw, local_indices
+
+
+class GatherVisionEmbeddings(Function):
+    """All-gather vision embeddings with gradient support.
+
+    Contiguous assignment means simple concat without reordering.
+    Backward: scales gradients by dp_size to compensate for partial processing.
+    """
+
+    @staticmethod
+    def forward(ctx, local_embeddings, dp_group, grad_scaler=True):
+        ctx.grad_scaler = grad_scaler
+        dp_size = dist.get_world_size(dp_group)
+        dp_rank = dist.get_rank(dp_group)
+        ctx.dp_size = dp_size
+
+        if dp_size == 1:
+            return local_embeddings
+
+        local_count = torch.tensor(
+            [local_embeddings.shape[0]], dtype=torch.long, device=local_embeddings.device
+        )
+        all_counts = [torch.zeros_like(local_count) for _ in range(dp_size)]
+        dist.all_gather(all_counts, local_count, group=dp_group)
+        all_counts = [c.item() for c in all_counts]
+        ctx.all_counts = all_counts
+        ctx.dp_rank = dp_rank
+
+        max_count = max(all_counts) if all_counts else 0
+        if max_count == 0:
+            return local_embeddings
+
+        hidden_size = local_embeddings.shape[1] if local_embeddings.dim() > 1 else 1
+        ctx.hidden_size = hidden_size
+
+        if local_embeddings.shape[0] < max_count:
+            pad_size = max_count - local_embeddings.shape[0]
+            padding = torch.zeros(
+                (pad_size, hidden_size),
+                dtype=local_embeddings.dtype,
+                device=local_embeddings.device,
+            )
+            local_padded = torch.cat([local_embeddings, padding], dim=0)
+        else:
+            local_padded = local_embeddings
+
+        gathered = [torch.empty_like(local_padded) for _ in range(dp_size)]
+        dist.all_gather(gathered, local_padded, group=dp_group)
+
+        result_chunks = [gathered[r][: all_counts[r]] for r in range(dp_size)]
+        result = torch.cat(result_chunks, dim=0)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        dp_size = ctx.dp_size
+        grad_scaler = ctx.grad_scaler
+
+        if dp_size == 1:
+            return grad_output, None, None
+
+        all_counts = ctx.all_counts
+        dp_rank = ctx.dp_rank
+
+        if grad_scaler:
+            grad_output = grad_output * dp_size
+
+        start = sum(all_counts[:dp_rank])
+        end = start + all_counts[dp_rank]
+        local_grad = grad_output[start:end]
+        return local_grad, None, None
+
+
+def gather_vision_embeddings(local_embeddings, dp_group=None, grad_scaler=True):
+    """All-gather vision embeddings from all DP ranks.
+
+    Args:
+        local_embeddings: This rank's vision embeddings.
+        dp_group: Process group for all-gather. Defaults to Ulysses group.
+        grad_scaler: Whether to scale gradients in backward pass.
+
+    Returns:
+        All-gathered embeddings concatenated across ranks.
+    """
+    dp_group = get_ulysses_group() if dp_group is None else dp_group
+    if dp_group is None or dist.get_world_size(dp_group) == 1:
+        return local_embeddings
+    return GatherVisionEmbeddings.apply(local_embeddings, dp_group, grad_scaler)
+
+
+def create_dp_vision_forward(original_forward):
+    """Wrap VisionTransformer.forward for Vision DP.
+
+    Model-agnostic wrapper for any VisionTransformer with
+    ``forward(self, hidden_states, grid_thw, **kwargs) -> Tensor`` signature.
+
+    When Ulysses SP size > 1, distributes images across SP ranks and
+    all-gathers the embeddings after ViT computation.
+
+    Args:
+        original_forward: The original VisionTransformer.forward method.
+
+    Returns:
+        Wrapped forward method with Vision DP support.
+    """
+
+    def dp_vision_forward(self, hidden_states, grid_thw, **kwargs):
+        dp_size = get_ulysses_size()
+        if dp_size is None or dp_size <= 1:
+            return original_forward(self, hidden_states, grid_thw, **kwargs)
+
+        dp_group = get_ulysses_group()
+        dp_rank = dist.get_rank(dp_group)
+
+        # Step 1: Get image assignment
+        patch_counts = get_image_patch_counts(grid_thw)
+        total_patches = sum(patch_counts)
+        assert hidden_states.shape[0] == total_patches
+
+        spatial_merge_size = 1
+        if hasattr(self, "merger") and hasattr(self.merger, "spatial_merge_size"):
+            spatial_merge_size = self.merger.spatial_merge_size
+        elif hasattr(self, "spatial_merge_size"):
+            spatial_merge_size = self.spatial_merge_size
+
+        embedding_counts = get_image_embedding_counts(grid_thw, spatial_merge_size)
+        total_embeddings = sum(embedding_counts)
+
+        image_assignments, rank_loads = assign_images_to_dp_ranks(patch_counts, dp_size)
+
+        # Step 2: Extract local inputs
+        local_pixels, local_grid_thw, local_indices = prepare_local_vision_inputs(
+            hidden_states, grid_thw, image_assignments, dp_rank
+        )
+
+        # Step 3: Process local images
+        if local_pixels.shape[0] > 0:
+            local_embeddings = original_forward(self, local_pixels, local_grid_thw, **kwargs)
+        else:
+            # Determine hidden_size for empty tensor
+            if hasattr(self, "merger") and hasattr(self.merger, "ln_q"):
+                ln_q = self.merger.ln_q
+                if hasattr(ln_q, "normalized_shape"):
+                    hidden_size = ln_q.normalized_shape[0]
+                elif hasattr(ln_q, "weight"):
+                    hidden_size = ln_q.weight.shape[0]
+                else:
+                    raise RuntimeError(
+                        "Cannot determine hidden_size from merger.ln_q: "
+                        "no 'normalized_shape' or 'weight' attribute found"
+                    )
+            elif hasattr(self, "out_hidden_size"):
+                hidden_size = self.out_hidden_size
+            elif hasattr(self, "config") and hasattr(self.config, "hidden_size"):
+                hidden_size = self.config.hidden_size
+            else:
+                raise RuntimeError(
+                    "Cannot determine hidden_size for empty Vision DP output. "
+                    "Expected one of: self.merger.ln_q, self.out_hidden_size, self.config.hidden_size"
+                )
+
+            local_embeddings = torch.empty(
+                (0, hidden_size),
+                dtype=hidden_states.dtype,
+                device=hidden_states.device,
+            )
+
+        # Handle Qwen3-VL which returns (embeddings, deepstack_embeddings)
+        deepstack_outputs = None
+        if isinstance(local_embeddings, tuple):
+            local_embeddings, deepstack_outputs = local_embeddings[0], local_embeddings[1:]
+
+        # Step 4: All-gather
+        all_embeddings = gather_vision_embeddings(local_embeddings, dp_group)
+        assert all_embeddings.shape[0] == total_embeddings
+
+        if deepstack_outputs is not None:
+            # All-gather deepstack embeddings too
+            gathered_deepstack = []
+            for ds_emb in deepstack_outputs:
+                if isinstance(ds_emb, list):
+                    # List of tensors (one per deepstack layer)
+                    gathered_list = []
+                    for single_emb in ds_emb:
+                        gathered_list.append(gather_vision_embeddings(single_emb, dp_group))
+                    gathered_deepstack.append(gathered_list)
+                elif isinstance(ds_emb, torch.Tensor):
+                    gathered_deepstack.append(gather_vision_embeddings(ds_emb, dp_group))
+                else:
+                    gathered_deepstack.append(ds_emb)
+            return (all_embeddings, *gathered_deepstack)
+
+        return all_embeddings
+
+    return dp_vision_forward
diff --git a/tests/utils/test_vision_dp_on_cpu.py b/tests/utils/test_vision_dp_on_cpu.py
new file mode 100644
index 00000000..86b6410a
--- /dev/null
+++ b/tests/utils/test_vision_dp_on_cpu.py
@@ -0,0 +1,235 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2025 Alibaba Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Unit tests for Vision Data Parallel utilities.
+Ported from verl (https://github.com/verl-project/verl/pull/5230).
+"""
+
+import pytest
+import torch
+
+from roll.utils.context_parallel.vision_dp import (
+    assign_images_to_dp_ranks,
+    get_image_patch_counts,
+    prepare_local_vision_inputs,
+)
+
+
+class TestGetImagePatchCounts:
+    """Tests for get_image_patch_counts function."""
+
+    def test_basic_patch_counts(self):
+        grid_thw = torch.tensor([
+            [2, 4, 4],  # 2*4*4 = 32
+            [1, 2, 2],  # 1*2*2 = 4
+            [1, 8, 8],  # 1*8*8 = 64
+        ])
+        counts = get_image_patch_counts(grid_thw)
+        assert counts == [32, 4, 64]
+
+    def test_single_image(self):
+        grid_thw = torch.tensor([[1, 4, 4]])  # 16 patches
+        counts = get_image_patch_counts(grid_thw)
+        assert counts == [16]
+
+    def test_empty_input(self):
+        grid_thw = torch.empty((0, 3), dtype=torch.long)
+        counts = get_image_patch_counts(grid_thw)
+        assert counts == []
+
+    def test_video_frames(self):
+        grid_thw = torch.tensor([[4, 4, 4]])  # 4 frames, 4*4 patches each = 64
+        counts = get_image_patch_counts(grid_thw)
+        assert counts == [64]
+
+
+class TestAssignImagesToDpRanks:
+    """Tests for assign_images_to_dp_ranks function."""
+
+    def test_balanced_assignment(self):
+        patch_counts = [100, 100, 100, 100]
+        assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=2)
+        assert len(assignments[0]) == 2
+        assert len(assignments[1]) == 2
+        assert loads[0] == 200
+        assert loads[1] == 200
+
+    def test_imbalanced_images(self):
+        patch_counts = [500, 100, 100, 100]
+        assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=2)
+        total_assigned = sum(len(a) for a in assignments)
+        assert total_assigned == 4
+        assert 0 in assignments[0] or 0 in assignments[1]
+
+    def test_fewer_images_than_ranks(self):
+        patch_counts = [100, 200]
+        assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=4)
+        non_empty_ranks = sum(1 for a in assignments if len(a) > 0)
+        assert non_empty_ranks == 2
+        all_assigned = set()
+        for a in assignments:
+            all_assigned.update(a)
+        assert all_assigned == {0, 1}
+
+    def test_empty_input(self):
+        patch_counts = []
+        assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=4)
+        assert all(len(a) == 0 for a in assignments)
+        assert all(load == 0 for load in loads)
+
+    def test_single_rank(self):
+        patch_counts = [100, 200, 300]
+        assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=1)
+        assert assignments == [[0, 1, 2]]
+        assert loads == [600]
+
+    def test_equal_images_equal_size(self):
+        patch_counts = [100, 100, 100, 100, 100, 100]  # 6 images
+        assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=3)
+        assert all(len(a) == 2 for a in assignments)
+        assert all(load == 200 for load in loads)
+
+    def test_image_order_preserved(self):
+        patch_counts = [10, 20, 30, 40, 50]
+        assignments, _ = assign_images_to_dp_ranks(patch_counts, dp_size=2)
+        for rank_assignment in assignments:
+            assert rank_assignment == sorted(rank_assignment)
+
+
+class TestPrepareLocalVisionInputs:
+    """Tests for prepare_local_vision_inputs function."""
+
+    def test_basic_extraction(self):
+        pixel_values = torch.randn(100, 768)
+        grid_thw = torch.tensor([
+            [1, 6, 6],  # 36 patches (indices 0-35)
+            [1, 8, 8],  # 64 patches (indices 36-99)
+        ])
+        image_assignments = [[0], [1]]
+
+        local_pix, local_grid, local_indices = prepare_local_vision_inputs(
+            pixel_values, grid_thw, image_assignments, dp_rank=0
+        )
+        assert local_pix.shape[0] == 36
+        assert local_grid.shape[0] == 1
+        assert local_indices == [0]
+        assert torch.allclose(local_pix, pixel_values[:36])
+
+        local_pix, local_grid, local_indices = prepare_local_vision_inputs(
+            pixel_values, grid_thw, image_assignments, dp_rank=1
+        )
+        assert local_pix.shape[0] == 64
+        assert local_grid.shape[0] == 1
+        assert local_indices == [1]
+        assert torch.allclose(local_pix, pixel_values[36:100])
+
+    def test_multiple_images_per_rank(self):
+        pixel_values = torch.randn(200, 768)
+        grid_thw = torch.tensor([
+            [1, 5, 10],  # 50 patches
+            [1, 5, 10],  # 50 patches
+            [1, 5, 10],  # 50 patches
+            [1, 5, 10],  # 50 patches
+        ])
+        image_assignments = [[0, 2], [1, 3]]
+
+        local_pix, local_grid, local_indices = prepare_local_vision_inputs(
+            pixel_values, grid_thw, image_assignments, dp_rank=0
+        )
+        assert local_pix.shape[0] == 100
+        assert local_grid.shape[0] == 2
+        assert local_indices == [0, 2]
+        expected = torch.cat([pixel_values[0:50], pixel_values[100:150]], dim=0)
+        assert torch.allclose(local_pix, expected)
+
+    def test_empty_rank(self):
+        pixel_values = torch.randn(100, 768)
+        grid_thw = torch.tensor([[1, 10, 10]])
+        image_assignments = [[0], []]
+
+        local_pix, local_grid, local_indices = prepare_local_vision_inputs(
+            pixel_values, grid_thw, image_assignments, dp_rank=1
+        )
+        assert local_pix.shape[0] == 0
+        assert local_grid.shape[0] == 0
+        assert local_indices == []
+
+    def test_grid_thw_preserved(self):
+        pixel_values = torch.randn(150, 768)
+        grid_thw = torch.tensor([
+            [1, 5, 5],  # 25 patches
+            [2, 5, 5],  # 50 patches
+            [3, 5, 5],  # 75 patches
+        ])
+        image_assignments = [[0, 2], [1]]
+
+        _, local_grid, _ = prepare_local_vision_inputs(
+            pixel_values, grid_thw, image_assignments, dp_rank=0
+        )
+        assert local_grid.shape == (2, 3)
+        assert torch.equal(local_grid[0], grid_thw[0])
+        assert torch.equal(local_grid[1], grid_thw[2])
+
+
+class TestIntegration:
+    """Integration tests combining multiple functions."""
+
+    def test_full_workflow(self):
+        grid_thw = torch.tensor([
+            [1, 4, 4],  # 16 patches
+            [1, 8, 8],  # 64 patches
+            [1, 4, 4],  # 16 patches
+            [1, 6, 6],  # 36 patches
+            [1, 4, 4],  # 16 patches
+        ])
+        total_patches = 16 + 64 + 16 + 36 + 16  # 148
+        pixel_values = torch.randn(total_patches, 768)
+
+        patch_counts = get_image_patch_counts(grid_thw)
+        assert patch_counts == [16, 64, 16, 36, 16]
+
+        assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=2)
+        all_assigned = []
+        for a in assignments:
+            all_assigned.extend(a)
+        assert sorted(all_assigned) == [0, 1, 2, 3, 4]
+
+        total_local_patches = 0
+        for rank in range(2):
+            local_pix, local_grid, local_indices = prepare_local_vision_inputs(
+                pixel_values, grid_thw, assignments, dp_rank=rank
+            )
+            expected_patches = sum(patch_counts[i] for i in local_indices)
+            assert local_pix.shape[0] == expected_patches
+            assert local_grid.shape[0] == len(local_indices)
+            total_local_patches += local_pix.shape[0]
+
+        assert total_local_patches == total_patches
+
+    def test_same_size_images(self):
+        num_images = 50
+        patch_per_image = 64
+        grid_thw = torch.tensor([[1, 8, 8]] * num_images)
+        total_patches = num_images * patch_per_image
+        _ = torch.randn(total_patches, 768)
+
+        patch_counts = get_image_patch_counts(grid_thw)
+        assert all(c == 64 for c in patch_counts)
+
+        assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=4)
+        for rank in range(4):
+            assert 12 <= len(assignments[rank]) <= 13
+        for load in loads:
+            assert load in [768, 832]
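Beyond the tests above, here is a small single-process sketch of the property the GatherVisionEmbeddings docstring relies on: because image assignment is contiguous, concatenating each rank's local slice in rank order reproduces the original patch order, so the all-gather needs no reindexing. The shapes below are arbitrary and the snippet is illustrative, not part of the patch.

```python
import torch

from roll.utils.context_parallel.vision_dp import (
    assign_images_to_dp_ranks,
    get_image_patch_counts,
    prepare_local_vision_inputs,
)

grid_thw = torch.tensor([[1, 4, 4], [1, 8, 8], [1, 6, 6]])  # 16 + 64 + 36 = 116 patches
pixel_values = torch.randn(116, 32)

patch_counts = get_image_patch_counts(grid_thw)
assignments, _ = assign_images_to_dp_ranks(patch_counts, dp_size=2)

local_chunks = [
    prepare_local_vision_inputs(pixel_values, grid_thw, assignments, dp_rank=r)[0]
    for r in range(2)
]
# Rank-ordered concat recovers the original layout exactly.
assert torch.equal(torch.cat(local_chunks, dim=0), pixel_values)
```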