diff --git a/conf/base.yaml b/conf/base.yaml index 5418f846..af146c3b 100644 --- a/conf/base.yaml +++ b/conf/base.yaml @@ -22,6 +22,8 @@ actor: result_queue_size: 64 throughput_window_size: 50 shared_memory_entry_size: 10000000 + # Maximum number of entries to retain in the actor data stream (Redis only for now) + max_stream_size: 1000000 environment: null preprocess: input: actor @@ -30,7 +32,7 @@ preprocess: chunk_n_groups: 2 # queue for loaded raw groups raw_queue_size: 8 - # queue for processed chunks of multiple groups + # queue for processed chunks of multiple groups input_queue_size: 32 # queue for ready chunks for multiple groups output_queue_size: 32 @@ -40,9 +42,11 @@ preprocess: ring_buffer_size: 128 # "virtual" sample queue per lead trainer max_ready_samples_per_lead: 64 - pop_old_data: ${..pop_old_data} + pop_old_data: ${..pop_old_data} shared_memory_entry_size: 100000000 log_every_n_samples: 128 + # Maximum number of entries to retain in the training data stream (Redis only for now) + max_stream_size: 1000000 llm: parameters: # changed @@ -96,6 +100,9 @@ eval_every_n_versions: 78000 # changed model_path: Qwen/Qwen2.5-7B +# Processor configuration for vision-language models (multimodal) +mm_processor_kwargs: {} + # will use default based on the chosen backend accelerate_config: null use_deepspeed: true diff --git a/conf/chartqa.yaml b/conf/chartqa.yaml index 154db7ca..2ed7b4cb 100644 --- a/conf/chartqa.yaml +++ b/conf/chartqa.yaml @@ -8,6 +8,9 @@ finetune: seq_length: 8000 gradient_accumulation_passes: 512 seq_packing: false + rl: + epsilon_high: 4.0 + epsilon_low: 0.0 llm: parameters: @@ -24,13 +27,15 @@ actor: system_prompt: You are an expert at analyzing charts and graphs. Please examine the chart carefully and answer the question accurately. Remember to provide your final answer in a boxed format, like \\boxed{{your answer}}. task_template: |- Question: {question} - + Please analyze the chart step by step and put your final answer within \\boxed{{}}. llm_max_rollouts: 16 shared_memory_entry_size: 2000000000 + max_stream_size: 1000 preprocess: shared_memory_entry_size: 2000000000 + max_stream_size: 1000 environment: null @@ -45,9 +50,18 @@ test_dataset_names: # Use vision-language model for multimodal support model_path: Qwen/Qwen2.5-VL-3B-Instruct +eval_every_n_versions: 12500 + +# Processor configuration for vision-language models (shared between training and inference) +mm_processor_kwargs: + min_pixels: 784 # 28*28 + max_pixels: 1003520 # 1280*28*28 + use_fast: true + # Override vLLM config for multimodal support vllm_config: use_v1: true vllm_kwargs: max-num-seqs: 64 max-num-batched-tokens: 32768 + mm-processor-kwargs: '{"min_pixels": 784, "max_pixels": 1003520, "use_fast": true}' # 28*28 to 1280*28*28 diff --git a/conf/finetune/base.yaml b/conf/finetune/base.yaml index e707a1a9..bc5d9539 100644 --- a/conf/finetune/base.yaml +++ b/conf/finetune/base.yaml @@ -3,6 +3,8 @@ data: null model_class: causal-language-modeling # Model name or path of model to be trained. 
config_name: ${..model_path} +# Freeze vision tower for vision-language models (only applicable for vision2seq-language-modeling) +freeze_vision_tower: false # Optimizer type, supported: adamw_torch, adafactor, cpuadam, lion optim: adamw_torch # use half precision training, full bf16 without mixed precision copies at all diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index 6d3317af..add29e02 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -439,7 +439,7 @@ def run(self, dataset: list[tuple[str, dict]]): logger.info(f"Start {'train' if self.is_training else 'test'} actor loop") with ( - write_to_streams(self.data_stream, "a") as data_stream_writer, + write_to_streams(self.data_stream, "a", max_stream_size=self.cfg.actor.max_stream_size) as data_stream_writer, write_to_streams(self.stats_stream, "a") as stats_writer, ): while True: @@ -625,6 +625,7 @@ def run_actor_loop(cfg: DictConfig): tokenizer_name=str(actor_model_path), parameters=cfg.llm.parameters, collect_logprobs=True, + mm_processor_kwargs=cfg.get("mm_processor_kwargs", {}), ) for url in llm_urls ] @@ -635,6 +636,7 @@ def run_actor_loop(cfg: DictConfig): tokenizer_name=str(actor_model_path), parameters=cfg.test_llm.parameters, collect_logprobs=True, + mm_processor_kwargs=cfg.get("mm_processor_kwargs", {}), ) for url in llm_urls ] diff --git a/pipelinerl/async_llm.py b/pipelinerl/async_llm.py index 4e78ebf9..a8b45b26 100644 --- a/pipelinerl/async_llm.py +++ b/pipelinerl/async_llm.py @@ -4,12 +4,13 @@ import aiohttp import numpy as np +import torch from PIL import Image from pipelinerl.llm import LLMCall, LLMOutput, Prompt, TokenLogprob, TrainableLLM from pipelinerl.finetune.data import MASKED_TOKEN_ID from pipelinerl.rollouts import TrainingText -from pipelinerl.processor_factory import get_processor +from pipelinerl.vision_processor_utils import get_mm_processor from omegaconf import DictConfig, ListConfig, OmegaConf logger = logging.getLogger(__name__) @@ -157,7 +158,7 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText: if use_processor: # Use processor for vision-language models - processor = get_processor(llm.model_name) + processor = get_mm_processor(llm.model_name, mm_processor_kwargs=llm.mm_processor_kwargs) try: # Apply chat template using processor for proper image token handling @@ -189,11 +190,11 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText: processed = processor( text=[prompt_text], images=images, padding=True, return_tensors=None ) + # Convert PyTorch tensors to numpy arrays visual_features = { - key: value + key: value.cpu().numpy() if torch.is_tensor(value) else value for key, value in processed.items() - if isinstance(value, np.ndarray) - and key not in ["input_ids", "attention_mask"] + if key not in ["input_ids", "attention_mask"] } except Exception as e: diff --git a/pipelinerl/finetune/checkpoints.py b/pipelinerl/finetune/checkpoints.py index 1dda0b99..c090fdc5 100644 --- a/pipelinerl/finetune/checkpoints.py +++ b/pipelinerl/finetune/checkpoints.py @@ -210,6 +210,28 @@ def load_model(args, model_class, current_dir): gradient_checkpointing_kwargs={"use_reentrant": args.reentrant_checkpointing} ) + # Freeze vision tower if specified + freeze_vision_tower = getattr(args, "freeze_vision_tower", False) + if freeze_vision_tower: + # Try to get vision tower module from the model + vision_tower = None + if hasattr(model, "visual"): + vision_tower = model.visual # Qwen-VL, Qwen2-VL, Qwen2.5-VL, Qwen3-VL + elif hasattr(model, 
"vision_tower"): + vision_tower = model.vision_tower # LLaVA + elif hasattr(model, "vision_model"): + vision_tower = model.vision_model # BLIP-2, InstructBLIP + + if vision_tower is not None: + vision_tower.requires_grad_(False) + logger.info("Vision tower parameters frozen successfully (i.e. its parameters will be excluded from optimizer)") + else: + logger.warning( + "freeze_vision_tower=True but could not find vision tower. " + "Checked attributes: model.visual (Qwen*-VL), model.vision_tower (LlaVA), model.vision_model (BLIP-2, InstructBLIP). " + "So setting this parameter does not have any effect." + ) + get_accelerator().wait_for_everyone() return model diff --git a/pipelinerl/finetune/data.py b/pipelinerl/finetune/data.py index 4e395e3b..833f6490 100644 --- a/pipelinerl/finetune/data.py +++ b/pipelinerl/finetune/data.py @@ -15,6 +15,7 @@ from pipelinerl.finetune.utils import create_sentinel_example from pipelinerl.rollouts import TrainingText +from pipelinerl.vision_processor_utils import collate_visual_features from .context import get_accelerator, logger from .rl import RL_DATA_COLUMNS, prepare_rl_fields @@ -172,17 +173,13 @@ def collate( if seq_length % pad_to_multiple_of: seq_length += pad_to_multiple_of - (seq_length % pad_to_multiple_of) result = {} - - # Visual feature fields that should be stacked, not padded - if "visual_features" in example_dict and isinstance(example_dict["visual_features"][0], dict): - for k, seq_list in example_dict["visual_features"][0].items(): - if k == "image_grid_thw": - # image_grid_thw should remain as a list - result[k] = seq_list - else: - # Other visual fields like pixel_values can be stacked as tensors - valid_tensors = [torch.tensor(seq) for seq in seq_list] - result[k] = torch.stack(valid_tensors) + + # Handle visual features with dynamic batching + if "visual_features" in example_dict: + visual_features_list = example_dict["visual_features"] + batched_visual_features = collate_visual_features(visual_features_list) + if batched_visual_features: + result["visual_features"] = batched_visual_features for k, seq_list in example_dict.items(): if k == "model_version": diff --git a/pipelinerl/finetune/optim.py b/pipelinerl/finetune/optim.py index 268e88cf..b1a7f738 100644 --- a/pipelinerl/finetune/optim.py +++ b/pipelinerl/finetune/optim.py @@ -12,6 +12,9 @@ def get_grouped_params( ): params_with_wd, params_without_wd = [], [] for n, p in model.named_parameters(): + # Skip frozen parameters + if not p.requires_grad: + continue if any(nd in n for nd in no_decay): params_without_wd.append(p) else: diff --git a/pipelinerl/finetune/rl/__init__.py b/pipelinerl/finetune/rl/__init__.py index 9118078e..cd99bf19 100644 --- a/pipelinerl/finetune/rl/__init__.py +++ b/pipelinerl/finetune/rl/__init__.py @@ -190,13 +190,12 @@ def rl_step( } if batch.is_packed: model_inputs["position_ids"] = batch.position_ids - + # Add visual features if present (for multimodal models) - if hasattr(batch, 'pixel_values') and batch.pixel_values is not None: - model_inputs["pixel_values"] = batch.pixel_values - if hasattr(batch, 'image_grid_thw') and batch.image_grid_thw is not None: - model_inputs["image_grid_thw"] = batch.image_grid_thw #torch.tensor(.reshape((1, 3)) - + # Unpack all visual features from the dict (e.g., pixel_values, image_grid_thw, image_sizes) + if hasattr(batch, 'visual_features') and batch.visual_features is not None: + model_inputs.update(batch.visual_features) + outputs = model(**model_inputs) # compute log probs and entropy diff --git 
a/pipelinerl/finetune/types.py b/pipelinerl/finetune/types.py index a3c16f2e..eda751f7 100644 --- a/pipelinerl/finetune/types.py +++ b/pipelinerl/finetune/types.py @@ -70,11 +70,11 @@ class PipelineBatchEncoding(BaseModel): is_packed: bool = False seq_boundaries: torch.IntTensor | None = None # Required when seq_packing=True - # Visual feature fields (optional, for multimodal models) - pixel_values: torch.FloatTensor | None = None - image_grid_thw: torch.LongTensor | None = None + # Visual features (optional, for multimodal models) + # Dict containing model-specific visual features (e.g., pixel_values, image_grid_thw, image_sizes) + visual_features: dict[str, torch.Tensor] | None = None - @field_validator('input_ids', 'attention_mask', 'labels', 'position_ids', 'image_grid_thw', 'segment_ids', mode='before') + @field_validator('input_ids', 'attention_mask', 'labels', 'position_ids', 'segment_ids', mode='before') @classmethod def convert_to_long_tensor(cls, v: List[int] | torch.Tensor | None) -> torch.LongTensor | None: """Handle initialization of long tensors from different types.""" @@ -95,9 +95,8 @@ def convert_to_int_tensor(cls, v: List[int] | torch.Tensor | None) -> torch.IntT if isinstance(v, torch.Tensor): return v.int() # type: ignore return torch.tensor(v, dtype=torch.int) - - # TODO: am i needed? - @field_validator('rewards', 'advantages', 'ref_logprobs', 'old_logprobs', 'group_tokens', 'num_labels', 'overflow', 'pixel_values', mode='before') + + @field_validator('rewards', 'advantages', 'ref_logprobs', 'old_logprobs', 'group_tokens', 'num_labels', 'overflow', mode='before') @classmethod def convert_to_float_tensor(cls, v: List[float] | torch.Tensor | None) -> torch.FloatTensor | None: """Handle initialization of float tensors from different types.""" @@ -111,10 +110,16 @@ def convert_to_float_tensor(cls, v: List[float] | torch.Tensor | None) -> torch. 
def to_device(self, device: Union[str, torch.device]) -> 'PipelineBatchEncoding': """Move all tensors to the specified device and return updated instance.""" - for field_name in self.model_fields: + for field_name in type(self).model_fields: field_value = getattr(self, field_name) if isinstance(field_value, torch.Tensor): setattr(self, field_name, field_value.to(device)) + elif isinstance(field_value, dict): + setattr( + self, + field_name, + {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in field_value.items()} + ) return self @classmethod @@ -173,8 +178,7 @@ def make_slices(self, num_slices: int) -> list['PipelineBatchEncoding']: "is_packed": self.is_packed, "padding": self.padding, "seq_boundaries": self.seq_boundaries, - "pixel_values": self.pixel_values, - "image_grid_thw": self.image_grid_thw + "visual_features": self.visual_features } slices.append(PipelineBatchEncoding(**result)) return slices diff --git a/pipelinerl/finetune_loop.py b/pipelinerl/finetune_loop.py index b23d6429..623c877c 100644 --- a/pipelinerl/finetune_loop.py +++ b/pipelinerl/finetune_loop.py @@ -353,9 +353,26 @@ def run_finetuning_loop( logger.info("About to load model") model = load_model(args, args.model_class, current_dir) logger.info(f"Model loaded in dtype {model.dtype}") - dt = log_time(dt, time_stats, "finetune/model_load") + # Fix for multimodal models (e.g., Apriel, Qwen3-VL) with Accelerate+DeepSpeed + # Accelerate's _prepare_deepspeed() doesn't check text_config.hidden_size + if not hasattr(model.config, "hidden_size") and not hasattr(model.config, "hidden_sizes"): + if hasattr(model.config, "text_config"): + hidden_size = None + if hasattr(model.config.text_config, "hidden_size"): + hidden_size = model.config.text_config.hidden_size + elif hasattr(model.config.text_config, "hidden_sizes"): + hidden_size = max(model.config.text_config.hidden_sizes) + + if hidden_size is not None: + if get_accelerator().is_main_process: + logger.info( + f"Detected multimodal model with text_config.hidden_size={hidden_size}. " + f"Setting config.hidden_size to enable DeepSpeed auto-configuration." 
+ ) + model.config.hidden_size = hidden_size + data_stream = SingleStreamSpec( exp_path=exp_root_dir, topic=args.input, diff --git a/pipelinerl/launch.py b/pipelinerl/launch.py index 56bb9f8a..817e31ee 100644 --- a/pipelinerl/launch.py +++ b/pipelinerl/launch.py @@ -56,12 +56,8 @@ def validate_config(cfg: DictConfig): # Check for vision language model constraints if cfg.finetune.model_class == "vision2seq-language-modeling": - if "Qwen2.5-VL" not in cfg.model_path: - raise ValueError("Only Qwen2.5-VL models are supported for vision language modeling") if cfg.finetune.seq_packing: raise ValueError("Vision language models cannot use sequence packing (seq_packing must be false)") - if cfg.finetune.train_batch_size > 1: - raise ValueError("Vision language models cannot use batch size > 1 (train_batch_size must be 1)") if cfg.finetune.seq_parallel > 1: if not cfg.finetune.seq_packing: @@ -132,7 +128,7 @@ def run_ref_llm(cfg: DictConfig, preprocessor_llm_idx: int, local_idx: int, gpus os.makedirs(log_dir, exist_ok=True) cmd = [ - "python", + sys.executable, "-m", "vllm.entrypoints.openai.api_server", "--model", @@ -187,7 +183,7 @@ def run_actor_llm( "pipelinerl.entrypoints.run_vllm0" ) cmd = [ - "python", + sys.executable, "-m", entrypoint, "--model", @@ -240,7 +236,7 @@ def run_actor(world_map: WorldMap, actor_idx: int, exp_dir: Path): raise NotImplementedError("Can only do 1 actor yet") llm_urls = "+".join(world_map.get_actor_urls()) cmd = [ - "python", + sys.executable, "-m", "pipelinerl.entrypoints.run_actor", "--config-dir", @@ -265,7 +261,7 @@ def run_environment(cfg: DictConfig, job: Job): # run in a subprocess like in the rest of the code run_dir = Path(cfg.output_dir) / f"environment_{job.replica_idx}" cmd = [ - "python", + sys.executable, "-m", "pipelinerl.entrypoints.run_environment", "--config-dir", @@ -296,7 +292,7 @@ def run_finetune(cfg: DictConfig, world_map: WorldMap, gpus: list[int], exp_dir: if cfg.use_fsdp and cfg.use_deepspeed: raise ValueError("Cannot use both FSDP and DeepSpeed") cmd = [ - "python", + sys.executable, "-m", "accelerate.commands.launch", ] @@ -393,7 +389,7 @@ def run_preprocess(world_map: WorldMap, preprocessor_idx: int, exp_dir: Path): raise NotImplementedError("Can only do 1 preprocessor yet") llm_urls = "+".join(world_map.get_preprocessor_urls()) cmd = [ - "python", + sys.executable, "-m", "pipelinerl.entrypoints.run_preprocess", "--config-dir", @@ -544,6 +540,7 @@ def debug_link_streams(cfg: DictConfig, topics: list[str]): if not cfg.debug.streams_from: raise ValueError("Need to specify streams_from for debug mode") stream_dir = Path(cfg.output_dir) / "streams" + stream_dir.mkdir(parents=True, exist_ok=True) for topic in topics: source_topic_dir = Path(cfg.debug.streams_from) / "streams" / topic target_topic_dir = stream_dir / topic diff --git a/pipelinerl/llm.py b/pipelinerl/llm.py index 96950231..57d662bd 100644 --- a/pipelinerl/llm.py +++ b/pipelinerl/llm.py @@ -370,6 +370,7 @@ class TrainableLLM(LLM): max_parallel_requests: int = 32 max_retries: int = 5 base_delay: float = 0.5 + mm_processor_kwargs: dict = Field(default_factory=dict) _semaphore: asyncio.Semaphore def model_post_init(self, __context): diff --git a/pipelinerl/preprocess.py b/pipelinerl/preprocess.py index 0a6015e4..de0cd065 100644 --- a/pipelinerl/preprocess.py +++ b/pipelinerl/preprocess.py @@ -419,6 +419,7 @@ def run_preprocessing_loop( model_name=cfg.finetune.config_name, tokenizer_name=cfg.finetune.config_name, parameters=cfg.llm.parameters, + 
mm_processor_kwargs=cfg.get("mm_processor_kwargs", {}), ) for url in llm_urls ] @@ -462,7 +463,8 @@ def run_preprocessing_loop( # Per-trainer sample tracking (similar to finetune_loop.py) total_filtered_out = 0 # Track total filtered samples across all batches - with write_to_streams(output_stream) as data_writer, write_to_streams(stats_streams) as stats_writer: + max_stream_size = cfg.preprocess.max_stream_size + with write_to_streams(output_stream, max_stream_size=max_stream_size) as data_writer, write_to_streams(stats_streams) as stats_writer: with SharedMemoryManager() as smm: # Create shared memory queues without the manager parameter input_queue = SharedMemoryQueue(smm, cfg.preprocess.input_queue_size, cfg.preprocess.shared_memory_entry_size) diff --git a/pipelinerl/processor_factory.py b/pipelinerl/processor_factory.py deleted file mode 100644 index 06f0fc2b..00000000 --- a/pipelinerl/processor_factory.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Simple cache for AutoProcessor instances.""" -from typing import Dict -from transformers import AutoProcessor -import logging -logger = logging.getLogger(__name__) - -_processors: Dict[str, AutoProcessor] = {} - -def get_processor(model_name: str) -> AutoProcessor: - """Get or create an AutoProcessor for the given model.""" - if model_name not in _processors: - logger.info(f"Loading processor for model: {model_name}") - #TODO: should be args - _processors[model_name] = AutoProcessor.from_pretrained(model_name, min_pixels=28*28, max_pixels=1280*28*28) - return _processors[model_name] - -def clear_cache() -> None: - """Clear all cached processors.""" - _processors.clear() \ No newline at end of file diff --git a/pipelinerl/streams.py b/pipelinerl/streams.py index 632b760e..b4ac7851 100644 --- a/pipelinerl/streams.py +++ b/pipelinerl/streams.py @@ -110,7 +110,7 @@ def connect_to_redis(config: RedisConfig): logger.debug(f"Trying to connect to Redis server at {config.host}:{config.port}") client = redis.Redis(host=config.host, port=config.port) client.ping() - logger.info(f"Connected to Redis server") + logger.debug("Connected to Redis server") return client except (redis.exceptions.TimeoutError, redis.ConnectionError) as e: logger.warning(f"Waiting for Redis server ({type(e)}). 
Retrying in 5 seconds.") @@ -118,8 +118,15 @@ def connect_to_redis(config: RedisConfig): class RedisStreamWriter(StreamWriter): - def __init__(self, stream: SingleStreamSpec, mode: Literal["w", "a"] = "a"): + def __init__(self, stream: SingleStreamSpec, mode: Literal["w", "a"] = "a", max_stream_size: int = 1000000): + """ + Args: + stream: The stream specification + mode: Write mode - 'w' for write (new stream) or 'a' for append + max_stream_size: Maximum number of entries to retain in the stream (Redis only) + """ self.stream = stream + self.max_stream_size = max_stream_size assert isinstance(_backend, RedisConfig) self._stream_name = str(self.stream) self._redis = connect_to_redis(_backend) @@ -155,7 +162,7 @@ def write(self, data, partition: int | None = None): if isinstance(data, BaseModel): data = data.model_dump() data = pickle.dumps(data) - self._redis.xadd(self._stream_name, {"index": self._index, "data": data}, maxlen=1000000, approximate=True) + self._redis.xadd(self._stream_name, {"index": self._index, "data": data}, maxlen=self.max_stream_size, approximate=True) self._index += 1 @@ -195,7 +202,13 @@ def read(self): class RoundRobinRedisStreamWriter(StreamWriter): # TODO: share the connection across writers - def __init__(self, streams: StreamRangeSpec, mode: Literal["w", "a"] = "a"): + def __init__(self, streams: StreamRangeSpec, mode: Literal["w", "a"] = "a", max_stream_size: int = 1000000): + """ + Args: + streams: The stream range specification + mode: Write mode - 'w' for write (new stream) or 'a' for append + max_stream_size: Maximum number of entries to retain in the stream (Redis only) + """ self.streams = streams self._next_stream = 0 self._writers = [ @@ -207,6 +220,7 @@ def __init__(self, streams: StreamRangeSpec, mode: Literal["w", "a"] = "a"): partition=i, ), mode=mode, + max_stream_size=max_stream_size, ) for i in range(*self.streams.partition_range) ] @@ -400,16 +414,23 @@ def read_stream(stream: SingleStreamSpec) -> StreamReader: assert False -def write_to_streams(streams: StreamSpec, mode: Literal["w", "a"] = "a") -> StreamWriter: - """Append to the end of the stream.""" +def write_to_streams(streams: StreamSpec, mode: Literal["w", "a"] = "a", max_stream_size: int = 1000000) -> StreamWriter: + """ + Append to the end of the stream. + + Args: + streams: The stream specification + mode: Write mode - 'w' for write (new stream) or 'a' for append + max_stream_size: Maximum number of entries to retain in the stream (Redis only) + """ raise_if_backend_not_set() if not isinstance(streams, (SingleStreamSpec, StreamRangeSpec)): raise ValueError(f"Invalid stream spec: {streams}") if isinstance(_backend, RedisConfig): if isinstance(streams, SingleStreamSpec): - return RedisStreamWriter(streams, mode) + return RedisStreamWriter(streams, mode, max_stream_size) elif isinstance(streams, StreamRangeSpec): - return RoundRobinRedisStreamWriter(streams, mode) + return RoundRobinRedisStreamWriter(streams, mode, max_stream_size) else: assert False elif _backend == "files": diff --git a/pipelinerl/vision_processor_utils.py b/pipelinerl/vision_processor_utils.py new file mode 100644 index 00000000..ee2e8635 --- /dev/null +++ b/pipelinerl/vision_processor_utils.py @@ -0,0 +1,122 @@ +""" +Vision processor utilities for multimodal models. + +This module provides processor caching and management for vision-language models. 
+ +Supported models: +- Qwen2.5-VL: Uses image_grid_thw (B, 3) and flattened pixel_values +- Pixtral/Apriel: Uses image_sizes (B, 2) and standard pixel_values (B, C, H, W) +""" +import logging +from typing import Dict +import torch +from transformers import AutoProcessor + +logger = logging.getLogger(__name__) + +# Processor cache +_processors: Dict[str, AutoProcessor] = {} + + +def get_mm_processor(model_name: str, mm_processor_kwargs: dict | None = None) -> AutoProcessor: + """ + Get or create an AutoProcessor for multimodal models. + + Args: + model_name: HuggingFace model identifier + mm_processor_kwargs: Optional kwargs to pass to AutoProcessor.from_pretrained() + + Returns: + AutoProcessor instance + """ + if model_name not in _processors: + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + + logger.info(f"Loading processor for model: {model_name} with kwargs: {mm_processor_kwargs}") + _processors[model_name] = AutoProcessor.from_pretrained( + model_name, **mm_processor_kwargs + ) + return _processors[model_name] + + +def clear_cache() -> None: + """Clear all cached processors.""" + _processors.clear() + + +def collate_visual_features(visual_features_list: list[dict]) -> dict[str, torch.Tensor]: + """ + Collate visual features from multiple samples into batched tensors. + + Handles different formats: + - Metadata (image_grid_thw, image_sizes): Concatenate along image dimension + - Qwen pixel_values (2D): Concatenate flattened features + - Pixtral pixel_values (4D): Pad to max_num_images + - Other features: Pad to max_num_images + + Args: + visual_features_list: List of visual feature dicts from individual samples + + Returns: + Dict mapping feature names to batched tensors + """ + if not visual_features_list or visual_features_list[0] is None: + return {} + + first_vf = visual_features_list[0] + batched_visual_features = {} + + for key in first_vf.keys(): + if key in ("image_grid_thw", "image_sizes"): + # Concatenate metadata arrays (image_grid_thw or image_sizes) + # Each sample has shape (num_images, 2 or 3), concatenate along image dimension + all_metadata = [torch.as_tensor(vf[key]) for vf in visual_features_list] + batched_visual_features[key] = torch.cat(all_metadata, dim=0) + + elif key == "pixel_values": + # Handle pixel_values - format differs by model: + # - Qwen: (total_pixels, hidden_dim) - flattened, concatenate along pixel dimension + # - Pixtral: (num_images, C, H, W) - standard, needs padding to max_num_images + all_tensors = [torch.as_tensor(vf[key]) for vf in visual_features_list] + + # Check if this is flattened format (2D) or image format (4D) + if all_tensors[0].ndim == 2: + # Qwen format: (total_pixels, hidden_dim) - just concatenate + batched_visual_features[key] = torch.cat(all_tensors, dim=0) + elif all_tensors[0].ndim == 4: + # Pixtral format: (num_images, C, H, W) - pad to max_num_images + max_num_images = max(t.shape[0] for t in all_tensors) + single_shape = all_tensors[0].shape[1:] # (C, H, W) + dtype = all_tensors[0].dtype + + # Pre-allocate: (batch_size, max_num_images, C, H, W) + batch_shape = (len(all_tensors), max_num_images) + single_shape + batched = torch.zeros(batch_shape, dtype=dtype) + + # Fill in actual data + for i, tensor in enumerate(all_tensors): + num_images = tensor.shape[0] + batched[i, :num_images] = tensor + + batched_visual_features[key] = batched + else: + raise ValueError(f"Unexpected pixel_values shape: {all_tensors[0].shape}") + + else: + # Other visual features - assume they need padding like Pixtral pixel_values + 
all_tensors = [torch.as_tensor(vf[key]) for vf in visual_features_list] + max_num_images = max(t.shape[0] for t in all_tensors) + single_shape = all_tensors[0].shape[1:] + dtype = all_tensors[0].dtype + + batch_shape = (len(all_tensors), max_num_images) + single_shape + batched = torch.zeros(batch_shape, dtype=dtype) + + for i, tensor in enumerate(all_tensors): + num_images = tensor.shape[0] + batched[i, :num_images] = tensor + + batched_visual_features[key] = batched + + return batched_visual_features
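
For reference (not part of the diff): a minimal sketch of how the new collate_visual_features helper is expected to batch Pixtral/Apriel-style visual features, where per-sample pixel_values are (num_images, C, H, W) and image_sizes are (num_images, 2). The shapes below are illustrative assumptions, not values taken from the configs above.

# Illustrative usage of collate_visual_features from pipelinerl/vision_processor_utils.py.
import torch
from pipelinerl.vision_processor_utils import collate_visual_features

# Two samples with a different number of images each (Pixtral/Apriel-style features).
sample_a = {
    "pixel_values": torch.randn(1, 3, 64, 64),    # one image: (num_images, C, H, W)
    "image_sizes": torch.tensor([[64, 64]]),      # (num_images, 2)
}
sample_b = {
    "pixel_values": torch.randn(2, 3, 64, 64),    # two images
    "image_sizes": torch.tensor([[64, 64], [64, 64]]),
}

batched = collate_visual_features([sample_a, sample_b])
# 4D pixel_values are zero-padded to the per-batch maximum number of images...
assert batched["pixel_values"].shape == (2, 2, 3, 64, 64)
# ...while image_sizes metadata is concatenated along the image dimension.
assert batched["image_sizes"].shape == (3, 2)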