From e3ed0b8ee57d6a485af0977c11d73fcbd7e81230 Mon Sep 17 00:00:00 2001 From: ehsk Date: Wed, 26 Nov 2025 20:39:27 +0000 Subject: [PATCH 01/12] argument for freezing vision encoder parameters added --- conf/chartqa.yaml | 1 + conf/finetune/base.yaml | 6 +++++ pipelinerl/finetune_loop.py | 51 +++++++++++++++++++++++++++++++++++++ 3 files changed, 58 insertions(+) diff --git a/conf/chartqa.yaml b/conf/chartqa.yaml index 154db7ca..4b9c3875 100644 --- a/conf/chartqa.yaml +++ b/conf/chartqa.yaml @@ -8,6 +8,7 @@ finetune: seq_length: 8000 gradient_accumulation_passes: 512 seq_packing: false + freeze_vision_tower: true llm: parameters: diff --git a/conf/finetune/base.yaml b/conf/finetune/base.yaml index 4998a8bf..c80c3423 100644 --- a/conf/finetune/base.yaml +++ b/conf/finetune/base.yaml @@ -3,6 +3,12 @@ data: null model_class: causal-language-modeling # Model name or path of model to be trained. config_name: ${..model_path} +# Freeze vision tower for vision-language models (only applicable for vision2seq-language-modeling) +# Auto-detects common patterns: visual., vision_tower., vision_model., vision_embed_tokens., vit., qformer. +freeze_vision_tower: false +# Optional: Manually specify parameter prefixes to freeze (e.g., ["visual.", "qformer."]) +# If null, auto-detection will be used when freeze_vision_tower=true +vision_encoder_prefixes: null # Optimizer type, supported: adamw_torch, adafactor, cpuadam, lion optim: adamw_torch # use half precision training, full bf16 without mixed precision copies at all diff --git a/pipelinerl/finetune_loop.py b/pipelinerl/finetune_loop.py index da39938e..75018932 100644 --- a/pipelinerl/finetune_loop.py +++ b/pipelinerl/finetune_loop.py @@ -352,6 +352,57 @@ def run_finetuning_loop( model = load_model(args, args.model_class, current_dir) logger.info(f"Model loaded in dtype {model.dtype}") + # Freeze vision tower if specified + freeze_vision_tower = getattr(args, "freeze_vision_tower", False) + vision_encoder_prefixes = getattr(args, "vision_encoder_prefixes", None) + + if freeze_vision_tower: + # Auto-detect common vision encoder patterns if not specified + if vision_encoder_prefixes is None: + common_prefixes = [ + "visual.", # Qwen-VL, Qwen2-VL + "vision_tower.", # LLaVA + "vision_model.", # InstructBLIP, BLIP-2 + "vision_embed_tokens.", # Phi-3-Vision + "vit.", # CogVLM + "qformer.", # BLIP-2 Q-Former + ] + vision_encoder_prefixes = common_prefixes + + vision_encoder_parameters = set() + + # Check which prefixes exist in the model + for prefix in common_prefixes: + if any(name.startswith(prefix) for name, _ in model.named_parameters()): + vision_encoder_parameters.add(prefix) + + if not vision_encoder_parameters: + logger.warning( + "freeze_vision_tower=True but could not auto-detect vision encoder. " + "No parameters matching common patterns: " + ", ".join(common_prefixes) + ". " + "Set 'vision_encoder_prefixes' in config to specify manually." 
+ ) + else: + logger.debug(f"Freezing vision encoder with prefixes: {vision_encoder_prefixes}") + total_params = 0 + frozen_params = 0 + frozen_param_names = [] + + for name, param in model.named_parameters(): + total_params += param.numel() + if name in vision_encoder_parameters: + param.requires_grad = False + frozen_params += param.numel() + frozen_param_names.append(name) + + trainable_params = total_params - frozen_params + logger.info( + f"Frozen vision encoder: {frozen_params:,} params | " + f"Trainable: {trainable_params:,} params | " + f"Total: {total_params:,} params | " + f"Trainable%: {100 * trainable_params / total_params:.2f}%" + ) + dt = log_time(dt, time_stats, "finetune/model_load") data_stream = SingleStreamSpec( From 97943e21ba6fb92cf58efdbb68c7e25be4bf453b Mon Sep 17 00:00:00 2001 From: ehsk Date: Wed, 26 Nov 2025 21:12:33 +0000 Subject: [PATCH 02/12] freezing vision tower code simplified --- conf/finetune/base.yaml | 4 --- pipelinerl/finetune/checkpoints.py | 33 +++++++++++++++++++ pipelinerl/finetune_loop.py | 52 ------------------------------ 3 files changed, 33 insertions(+), 56 deletions(-) diff --git a/conf/finetune/base.yaml b/conf/finetune/base.yaml index c80c3423..0741249d 100644 --- a/conf/finetune/base.yaml +++ b/conf/finetune/base.yaml @@ -4,11 +4,7 @@ model_class: causal-language-modeling # Model name or path of model to be trained. config_name: ${..model_path} # Freeze vision tower for vision-language models (only applicable for vision2seq-language-modeling) -# Auto-detects common patterns: visual., vision_tower., vision_model., vision_embed_tokens., vit., qformer. freeze_vision_tower: false -# Optional: Manually specify parameter prefixes to freeze (e.g., ["visual.", "qformer."]) -# If null, auto-detection will be used when freeze_vision_tower=true -vision_encoder_prefixes: null # Optimizer type, supported: adamw_torch, adafactor, cpuadam, lion optim: adamw_torch # use half precision training, full bf16 without mixed precision copies at all diff --git a/pipelinerl/finetune/checkpoints.py b/pipelinerl/finetune/checkpoints.py index 9d949e17..f6ddc521 100644 --- a/pipelinerl/finetune/checkpoints.py +++ b/pipelinerl/finetune/checkpoints.py @@ -129,6 +129,39 @@ def load_model(args, model_class, current_dir): gradient_checkpointing_kwargs={"use_reentrant": args.reentrant_checkpointing} ) + # Freeze vision tower if specified + freeze_vision_tower = getattr(args, "freeze_vision_tower", False) + if freeze_vision_tower: + # Try to get vision tower module from the model + vision_tower = None + if hasattr(model, "visual"): + vision_tower = model.visual # Qwen-VL, Qwen2-VL, Qwen2.5-VL, Qwen3-VL + elif hasattr(model, "vision_tower"): + vision_tower = model.vision_tower # LLaVA + elif hasattr(model, "vision_model"): + vision_tower = model.vision_model # BLIP-2, InstructBLIP + + if vision_tower is not None: + vision_tower.requires_grad_(False) + + # Count frozen parameters + total_params = sum(p.numel() for p in model.parameters()) + frozen_params = sum(p.numel() for p in vision_tower.parameters()) + trainable_params = total_params - frozen_params + + logger.info( + f"Vision tower frozen: {frozen_params:,} params | " + f"Trainable: {trainable_params:,} params | " + f"Total: {total_params:,} params | " + f"Trainable%: {trainable_params / total_params:.2%}" + ) + else: + logger.warning( + "freeze_vision_tower=True but could not find vision tower. " + "Checked attributes: model.visual (Qwen*-VL), model.vision_tower (LlaVA), model.vision_model (BLIP-2, InstructBLIP). 
" + "So setting this parameter does not have any effect." + ) + get_accelerator().wait_for_everyone() return model diff --git a/pipelinerl/finetune_loop.py b/pipelinerl/finetune_loop.py index 75018932..1cc9d2b1 100644 --- a/pipelinerl/finetune_loop.py +++ b/pipelinerl/finetune_loop.py @@ -351,58 +351,6 @@ def run_finetuning_loop( logger.info("About to load model") model = load_model(args, args.model_class, current_dir) logger.info(f"Model loaded in dtype {model.dtype}") - - # Freeze vision tower if specified - freeze_vision_tower = getattr(args, "freeze_vision_tower", False) - vision_encoder_prefixes = getattr(args, "vision_encoder_prefixes", None) - - if freeze_vision_tower: - # Auto-detect common vision encoder patterns if not specified - if vision_encoder_prefixes is None: - common_prefixes = [ - "visual.", # Qwen-VL, Qwen2-VL - "vision_tower.", # LLaVA - "vision_model.", # InstructBLIP, BLIP-2 - "vision_embed_tokens.", # Phi-3-Vision - "vit.", # CogVLM - "qformer.", # BLIP-2 Q-Former - ] - vision_encoder_prefixes = common_prefixes - - vision_encoder_parameters = set() - - # Check which prefixes exist in the model - for prefix in common_prefixes: - if any(name.startswith(prefix) for name, _ in model.named_parameters()): - vision_encoder_parameters.add(prefix) - - if not vision_encoder_parameters: - logger.warning( - "freeze_vision_tower=True but could not auto-detect vision encoder. " - "No parameters matching common patterns: " + ", ".join(common_prefixes) + ". " - "Set 'vision_encoder_prefixes' in config to specify manually." - ) - else: - logger.debug(f"Freezing vision encoder with prefixes: {vision_encoder_prefixes}") - total_params = 0 - frozen_params = 0 - frozen_param_names = [] - - for name, param in model.named_parameters(): - total_params += param.numel() - if name in vision_encoder_parameters: - param.requires_grad = False - frozen_params += param.numel() - frozen_param_names.append(name) - - trainable_params = total_params - frozen_params - logger.info( - f"Frozen vision encoder: {frozen_params:,} params | " - f"Trainable: {trainable_params:,} params | " - f"Total: {total_params:,} params | " - f"Trainable%: {100 * trainable_params / total_params:.2f}%" - ) - dt = log_time(dt, time_stats, "finetune/model_load") data_stream = SingleStreamSpec( From 0e336be664dbe23b32f242a6776a5fda102dab3d Mon Sep 17 00:00:00 2001 From: ehsk Date: Wed, 26 Nov 2025 21:18:14 +0000 Subject: [PATCH 03/12] minor issue fixed --- pipelinerl/finetune/checkpoints.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/pipelinerl/finetune/checkpoints.py b/pipelinerl/finetune/checkpoints.py index f6ddc521..d7a5d040 100644 --- a/pipelinerl/finetune/checkpoints.py +++ b/pipelinerl/finetune/checkpoints.py @@ -149,12 +149,18 @@ def load_model(args, model_class, current_dir): frozen_params = sum(p.numel() for p in vision_tower.parameters()) trainable_params = total_params - frozen_params - logger.info( - f"Vision tower frozen: {frozen_params:,} params | " - f"Trainable: {trainable_params:,} params | " - f"Total: {total_params:,} params | " - f"Trainable%: {trainable_params / total_params:.2%}" - ) + if total_params > 0: + logger.info( + f"Vision tower frozen: {frozen_params:,} params | " + f"Trainable: {trainable_params:,} params | " + f"Total: {total_params:,} params | " + f"Trainable%: {trainable_params / total_params:.2%}" + ) + else: + logger.warning( + "Total parameters is 0, cannot compute trainable percentage. 
" + "This indicates freeze_vision_tower may not have been applied correctly or the model has a different structure than expected." + ) else: logger.warning( "freeze_vision_tower=True but could not find vision tower. " From 7d17b5ac98c2332c971c4ef2dd3f2a99a1ee167b Mon Sep 17 00:00:00 2001 From: ehsk Date: Thu, 27 Nov 2025 19:35:22 +0000 Subject: [PATCH 04/12] removed unnecessary logs --- pipelinerl/finetune/checkpoints.py | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/pipelinerl/finetune/checkpoints.py b/pipelinerl/finetune/checkpoints.py index d7a5d040..210d1029 100644 --- a/pipelinerl/finetune/checkpoints.py +++ b/pipelinerl/finetune/checkpoints.py @@ -143,24 +143,7 @@ def load_model(args, model_class, current_dir): if vision_tower is not None: vision_tower.requires_grad_(False) - - # Count frozen parameters - total_params = sum(p.numel() for p in model.parameters()) - frozen_params = sum(p.numel() for p in vision_tower.parameters()) - trainable_params = total_params - frozen_params - - if total_params > 0: - logger.info( - f"Vision tower frozen: {frozen_params:,} params | " - f"Trainable: {trainable_params:,} params | " - f"Total: {total_params:,} params | " - f"Trainable%: {trainable_params / total_params:.2%}" - ) - else: - logger.warning( - "Total parameters is 0, cannot compute trainable percentage. " - "This indicates freeze_vision_tower may not have been applied correctly or the model has a different structure than expected." - ) + logger.info("Vision tower parameters frozen successfully (i.e. its parameters will be excluded from optimizer)") else: logger.warning( "freeze_vision_tower=True but could not find vision tower. " From 691073dcfbffbc52cf6601f4d85eb0eb6bfa681f Mon Sep 17 00:00:00 2001 From: ehsk Date: Thu, 27 Nov 2025 19:35:54 +0000 Subject: [PATCH 05/12] non-trainable parameters excluded from grouped_parameters --- pipelinerl/finetune/optim.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pipelinerl/finetune/optim.py b/pipelinerl/finetune/optim.py index 268e88cf..b1a7f738 100644 --- a/pipelinerl/finetune/optim.py +++ b/pipelinerl/finetune/optim.py @@ -12,6 +12,9 @@ def get_grouped_params( ): params_with_wd, params_without_wd = [], [] for n, p in model.named_parameters(): + # Skip frozen parameters + if not p.requires_grad: + continue if any(nd in n for nd in no_decay): params_without_wd.append(p) else: From 32d8985c376594bc5cba63923de020a025de1b2d Mon Sep 17 00:00:00 2001 From: ehsk Date: Fri, 28 Nov 2025 15:42:06 +0000 Subject: [PATCH 06/12] replace "python" with current executable python depending on current env --- pipelinerl/launch.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pipelinerl/launch.py b/pipelinerl/launch.py index 77017602..a7973856 100644 --- a/pipelinerl/launch.py +++ b/pipelinerl/launch.py @@ -88,7 +88,7 @@ def run_ref_llm(cfg: DictConfig, preprocessor_llm_idx: int, local_idx: int, gpus os.makedirs(log_dir, exist_ok=True) cmd = [ - "python", + sys.executable, "-m", "vllm.entrypoints.openai.api_server", "--model", @@ -140,7 +140,7 @@ def run_actor_llm( "pipelinerl.entrypoints.run_vllm0" ) cmd = [ - "python", + sys.executable, "-m", entrypoint, "--model", @@ -190,7 +190,7 @@ def run_actor(world_map: WorldMap, actor_idx: int, exp_dir: Path): raise NotImplementedError("Can only do 1 actor yet") llm_urls = "+".join(world_map.get_actor_urls()) cmd = [ - "python", + sys.executable, "-m", "pipelinerl.entrypoints.run_actor", "--config-dir", @@ -215,7 +215,7 @@ def 
run_environment(cfg: DictConfig, job: Job): # run in a subprocess like in the rest of the code run_dir = Path(cfg.output_dir) / f"environment_{job.replica_idx}" cmd = [ - "python", + sys.executable, "-m", "pipelinerl.entrypoints.run_environment", "--config-dir", @@ -246,7 +246,7 @@ def run_finetune(cfg: DictConfig, world_map: WorldMap, gpus: list[int], exp_dir: if cfg.use_fsdp and cfg.use_deepspeed: raise ValueError("Cannot use both FSDP and DeepSpeed") cmd = [ - "python", + sys.executable, "-m", "accelerate.commands.launch", ] @@ -343,7 +343,7 @@ def run_preprocess(world_map: WorldMap, preprocessor_idx: int, exp_dir: Path): raise NotImplementedError("Can only do 1 preprocessor yet") llm_urls = "+".join(world_map.get_preprocessor_urls()) cmd = [ - "python", + sys.executable, "-m", "pipelinerl.entrypoints.run_preprocess", "--config-dir", From 71db4a883049dd01209314321854a84e9c615533 Mon Sep 17 00:00:00 2001 From: ehsk Date: Fri, 28 Nov 2025 15:42:20 +0000 Subject: [PATCH 07/12] add processor args to vllm --- conf/chartqa.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/conf/chartqa.yaml b/conf/chartqa.yaml index 4b9c3875..95d5e23c 100644 --- a/conf/chartqa.yaml +++ b/conf/chartqa.yaml @@ -52,3 +52,4 @@ vllm_config: vllm_kwargs: max-num-seqs: 64 max-num-batched-tokens: 32768 + mm-processor-kwargs: '{"min_pixels": 784, "max_pixels": 1003520, "use_fast": true}' # 28*28 to 1280*28*28 From 14b017b2c6ebdebee00978ab83b07d2f29bb1e2b Mon Sep 17 00:00:00 2001 From: ehsk Date: Tue, 2 Dec 2025 20:52:43 +0000 Subject: [PATCH 08/12] epsilons for chartqa added --- conf/chartqa.yaml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/conf/chartqa.yaml b/conf/chartqa.yaml index 95d5e23c..c839dbf5 100644 --- a/conf/chartqa.yaml +++ b/conf/chartqa.yaml @@ -8,7 +8,8 @@ finetune: seq_length: 8000 gradient_accumulation_passes: 512 seq_packing: false - freeze_vision_tower: true + epsilon_high: 4.0 + epsilon_low: 0.0 llm: parameters: @@ -25,13 +26,15 @@ actor: system_prompt: You are an expert at analyzing charts and graphs. Please examine the chart carefully and answer the question accurately. Remember to provide your final answer in a boxed format, like \\boxed{{your answer}}. task_template: |- Question: {question} - + Please analyze the chart step by step and put your final answer within \\boxed{{}}. 
llm_max_rollouts: 16 shared_memory_entry_size: 2000000000 + max_stream_size: 1000 preprocess: shared_memory_entry_size: 2000000000 + max_stream_size: 1000 environment: null From fe00beab5b83d35b75f7365369eec9b39a895a51 Mon Sep 17 00:00:00 2001 From: ehsk Date: Tue, 2 Dec 2025 20:57:30 +0000 Subject: [PATCH 09/12] max_stream_size added for redis to avoid OOM --- conf/base.yaml | 8 ++++++-- pipelinerl/actor.py | 2 +- pipelinerl/preprocess.py | 18 ++++++++++++++++-- pipelinerl/streams.py | 37 +++++++++++++++++++++++++++++-------- 4 files changed, 52 insertions(+), 13 deletions(-) diff --git a/conf/base.yaml b/conf/base.yaml index e3122f5a..b99b8c07 100644 --- a/conf/base.yaml +++ b/conf/base.yaml @@ -18,6 +18,8 @@ actor: result_queue_size: 64 throughput_window_size: 50 shared_memory_entry_size: 10000000 + # Maximum number of entries to retain in the actor data stream (Redis only for now) + max_stream_size: 1000000 environment: null preprocess: input: actor @@ -26,7 +28,7 @@ preprocess: chunk_n_groups: 2 # queue for loaded raw groups raw_queue_size: 8 - # queue for processed chunks of multiple groups + # queue for processed chunks of multiple groups input_queue_size: 32 # queue for ready chunks for multiple groups output_queue_size: 32 @@ -36,9 +38,11 @@ preprocess: ring_buffer_size: 128 # "virtual" sample queue per lead trainer max_ready_samples_per_lead: 64 - pop_old_data: ${..pop_old_data} + pop_old_data: ${..pop_old_data} shared_memory_entry_size: 100000000 log_every_n_samples: 128 + # Maximum number of entries to retain in the training data stream (Redis only for now) + max_stream_size: 1000000 llm: parameters: # changed diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index 1c238ff9..1c0e5af3 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -438,7 +438,7 @@ def run(self, dataset: list[tuple[str, dict]]): logger.info(f"Start {'train' if self.is_training else 'test'} actor loop") with ( - write_to_streams(self.data_stream, "a") as data_stream_writer, + write_to_streams(self.data_stream, "a", max_stream_size=self.cfg.actor.max_stream_size) as data_stream_writer, write_to_streams(self.stats_stream, "a") as stats_writer, ): while True: diff --git a/pipelinerl/preprocess.py b/pipelinerl/preprocess.py index 65fcee47..1c5d666a 100644 --- a/pipelinerl/preprocess.py +++ b/pipelinerl/preprocess.py @@ -157,7 +157,20 @@ def preprocess_dataset( entry["step_index"] = entry["metadata"]["step_index"] if not isinstance(tokenizer.eos_token_id, int): raise ValueError(f"Tokenizer {tokenizer} does not have an eos_token_id") - dataset = populate_rl_data(dataset=dataset, eos_token_id=tokenizer.eos_token_id, config=rl_config) + try: + dataset = populate_rl_data(dataset=dataset, eos_token_id=tokenizer.eos_token_id, config=rl_config) + except Exception as e: + logger.exception( + "Error in populate_rl_data: {}".format({ + "Data": data, + "Dataset": dataset, + "Tokenizer eos_token_id": tokenizer.eos_token_id, + "RL config": rl_config, + "LLM": llm, + "Seq length": seq_length, + }), + ) + raise e return dataset @@ -450,7 +463,8 @@ def run_preprocessing_loop( # Per-trainer sample tracking (similar to finetune_loop.py) total_filtered_out = 0 # Track total filtered samples across all batches - with write_to_streams(output_stream) as data_writer, write_to_streams(stats_streams) as stats_writer: + max_stream_size = cfg.preprocess.max_stream_size + with write_to_streams(output_stream, max_stream_size=max_stream_size) as data_writer, write_to_streams(stats_streams) as stats_writer: with 
SharedMemoryManager() as smm: # Create shared memory queues without the manager parameter input_queue = SharedMemoryQueue(smm, cfg.preprocess.input_queue_size, cfg.preprocess.shared_memory_entry_size) diff --git a/pipelinerl/streams.py b/pipelinerl/streams.py index 632b760e..b4ac7851 100644 --- a/pipelinerl/streams.py +++ b/pipelinerl/streams.py @@ -110,7 +110,7 @@ def connect_to_redis(config: RedisConfig): logger.debug(f"Trying to connect to Redis server at {config.host}:{config.port}") client = redis.Redis(host=config.host, port=config.port) client.ping() - logger.info(f"Connected to Redis server") + logger.debug("Connected to Redis server") return client except (redis.exceptions.TimeoutError, redis.ConnectionError) as e: logger.warning(f"Waiting for Redis server ({type(e)}). Retrying in 5 seconds.") @@ -118,8 +118,15 @@ def connect_to_redis(config: RedisConfig): class RedisStreamWriter(StreamWriter): - def __init__(self, stream: SingleStreamSpec, mode: Literal["w", "a"] = "a"): + def __init__(self, stream: SingleStreamSpec, mode: Literal["w", "a"] = "a", max_stream_size: int = 1000000): + """ + Args: + stream: The stream specification + mode: Write mode - 'w' for write (new stream) or 'a' for append + max_stream_size: Maximum number of entries to retain in the stream (Redis only) + """ self.stream = stream + self.max_stream_size = max_stream_size assert isinstance(_backend, RedisConfig) self._stream_name = str(self.stream) self._redis = connect_to_redis(_backend) @@ -155,7 +162,7 @@ def write(self, data, partition: int | None = None): if isinstance(data, BaseModel): data = data.model_dump() data = pickle.dumps(data) - self._redis.xadd(self._stream_name, {"index": self._index, "data": data}, maxlen=1000000, approximate=True) + self._redis.xadd(self._stream_name, {"index": self._index, "data": data}, maxlen=self.max_stream_size, approximate=True) self._index += 1 @@ -195,7 +202,13 @@ def read(self): class RoundRobinRedisStreamWriter(StreamWriter): # TODO: share the connection across writers - def __init__(self, streams: StreamRangeSpec, mode: Literal["w", "a"] = "a"): + def __init__(self, streams: StreamRangeSpec, mode: Literal["w", "a"] = "a", max_stream_size: int = 1000000): + """ + Args: + streams: The stream range specification + mode: Write mode - 'w' for write (new stream) or 'a' for append + max_stream_size: Maximum number of entries to retain in the stream (Redis only) + """ self.streams = streams self._next_stream = 0 self._writers = [ @@ -207,6 +220,7 @@ def __init__(self, streams: StreamRangeSpec, mode: Literal["w", "a"] = "a"): partition=i, ), mode=mode, + max_stream_size=max_stream_size, ) for i in range(*self.streams.partition_range) ] @@ -400,16 +414,23 @@ def read_stream(stream: SingleStreamSpec) -> StreamReader: assert False -def write_to_streams(streams: StreamSpec, mode: Literal["w", "a"] = "a") -> StreamWriter: - """Append to the end of the stream.""" +def write_to_streams(streams: StreamSpec, mode: Literal["w", "a"] = "a", max_stream_size: int = 1000000) -> StreamWriter: + """ + Append to the end of the stream. 
+ + Args: + streams: The stream specification + mode: Write mode - 'w' for write (new stream) or 'a' for append + max_stream_size: Maximum number of entries to retain in the stream (Redis only) + """ raise_if_backend_not_set() if not isinstance(streams, (SingleStreamSpec, StreamRangeSpec)): raise ValueError(f"Invalid stream spec: {streams}") if isinstance(_backend, RedisConfig): if isinstance(streams, SingleStreamSpec): - return RedisStreamWriter(streams, mode) + return RedisStreamWriter(streams, mode, max_stream_size) elif isinstance(streams, StreamRangeSpec): - return RoundRobinRedisStreamWriter(streams, mode) + return RoundRobinRedisStreamWriter(streams, mode, max_stream_size) else: assert False elif _backend == "files": From 84c43494c4487dcf190965ee6c4d521f797c9889 Mon Sep 17 00:00:00 2001 From: ehsk Date: Fri, 5 Dec 2025 03:17:15 +0000 Subject: [PATCH 10/12] mini-batch size can be greater than 1 --- pipelinerl/finetune/data.py | 46 ++++++++++++++++++++++++++++--------- pipelinerl/launch.py | 3 +-- 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/pipelinerl/finetune/data.py b/pipelinerl/finetune/data.py index 4e395e3b..e12b2821 100644 --- a/pipelinerl/finetune/data.py +++ b/pipelinerl/finetune/data.py @@ -172,17 +172,41 @@ def collate( if seq_length % pad_to_multiple_of: seq_length += pad_to_multiple_of - (seq_length % pad_to_multiple_of) result = {} - - # Visual feature fields that should be stacked, not padded - if "visual_features" in example_dict and isinstance(example_dict["visual_features"][0], dict): - for k, seq_list in example_dict["visual_features"][0].items(): - if k == "image_grid_thw": - # image_grid_thw should remain as a list - result[k] = seq_list - else: - # Other visual fields like pixel_values can be stacked as tensors - valid_tensors = [torch.tensor(seq) for seq in seq_list] - result[k] = torch.stack(valid_tensors) + + # Handle visual features with dynamic batching + if "visual_features" in example_dict: + visual_features_list = example_dict["visual_features"] + + if visual_features_list and visual_features_list[0] is not None: + first_vf = visual_features_list[0] + + for key in first_vf.keys(): + if key == "image_grid_thw": + # Concatenate all image_grid_thw arrays into a single tensor + # Each sample has shape (num_images, 3), concatenate along image dimension + all_grids = [torch.as_tensor(vf[key]) for vf in visual_features_list] + result[key] = torch.cat(all_grids, dim=0) + else: + # Convert to torch tensors (zero-copy for numpy arrays) + all_tensors = [torch.as_tensor(vf[key]) for vf in visual_features_list] + + # Find max number of images in this batch + max_num_images = max(t.shape[0] for t in all_tensors) + + # Get shape of single image: (C, H, W) + single_shape = all_tensors[0].shape[1:] + dtype = all_tensors[0].dtype + + # Pre-allocate batch tensor: (batch_size, max_num_images, C, H, W) + batch_shape = (len(all_tensors), max_num_images) + single_shape + batched = torch.zeros(batch_shape, dtype=dtype) + + # Fill in actual data (padding is already zeros) + for i, tensor in enumerate(all_tensors): + num_images = tensor.shape[0] + batched[i, :num_images] = tensor + + result[key] = batched for k, seq_list in example_dict.items(): if k == "model_version": diff --git a/pipelinerl/launch.py b/pipelinerl/launch.py index a7973856..19c05999 100644 --- a/pipelinerl/launch.py +++ b/pipelinerl/launch.py @@ -60,8 +60,6 @@ def validate_config(cfg: DictConfig): raise ValueError("Only Qwen2.5-VL models are supported for vision language modeling") if 
cfg.finetune.seq_packing: raise ValueError("Vision language models cannot use sequence packing (seq_packing must be false)") - if cfg.finetune.train_batch_size > 1: - raise ValueError("Vision language models cannot use batch size > 1 (train_batch_size must be 1)") if cfg.finetune.seq_parallel > 1: if not cfg.finetune.seq_packing: @@ -494,6 +492,7 @@ def debug_link_streams(cfg: DictConfig, topics: list[str]): if not cfg.debug.streams_from: raise ValueError("Need to specify streams_from for debug mode") stream_dir = Path(cfg.output_dir) / "streams" + stream_dir.mkdir(parents=True, exist_ok=True) for topic in topics: source_topic_dir = Path(cfg.debug.streams_from) / "streams" / topic target_topic_dir = stream_dir / topic From 3b5e2a06e77cb6e956c3dd659c856e6d03c0ce81 Mon Sep 17 00:00:00 2001 From: ehsk Date: Thu, 18 Dec 2025 16:51:17 +0000 Subject: [PATCH 11/12] a fix for configs in VLMs like Qwen3-VL or Apriel where there's a text_config inside config that creates problems for deepspeed integration --- pipelinerl/finetune_loop.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/pipelinerl/finetune_loop.py b/pipelinerl/finetune_loop.py index 1cc9d2b1..282cae90 100644 --- a/pipelinerl/finetune_loop.py +++ b/pipelinerl/finetune_loop.py @@ -353,6 +353,24 @@ def run_finetuning_loop( logger.info(f"Model loaded in dtype {model.dtype}") dt = log_time(dt, time_stats, "finetune/model_load") + # Fix for multimodal models (e.g., Apriel, Qwen3-VL) with Accelerate+DeepSpeed + # Accelerate's _prepare_deepspeed() doesn't check text_config.hidden_size + if not hasattr(model.config, "hidden_size") and not hasattr(model.config, "hidden_sizes"): + if hasattr(model.config, "text_config"): + hidden_size = None + if hasattr(model.config.text_config, "hidden_size"): + hidden_size = model.config.text_config.hidden_size + elif hasattr(model.config.text_config, "hidden_sizes"): + hidden_size = max(model.config.text_config.hidden_sizes) + + if hidden_size is not None: + if get_accelerator().is_main_process: + logger.info( + f"Detected multimodal model with text_config.hidden_size={hidden_size}. " + f"Setting config.hidden_size to enable DeepSpeed auto-configuration." 
+ ) + model.config.hidden_size = hidden_size + data_stream = SingleStreamSpec( exp_path=exp_root_dir, topic=args.input, From c3a30398c4dadfcd29bf14fb258f55a5b853e019 Mon Sep 17 00:00:00 2001 From: ehsk Date: Mon, 26 Jan 2026 17:52:45 +0000 Subject: [PATCH 12/12] refactorings and improvements --- conf/base.yaml | 3 + conf/chartqa.yaml | 13 ++- pipelinerl/actor.py | 2 + pipelinerl/async_llm.py | 11 +-- pipelinerl/finetune/data.py | 35 +------- pipelinerl/finetune/rl/__init__.py | 11 ++- pipelinerl/finetune/types.py | 24 +++--- pipelinerl/launch.py | 2 - pipelinerl/llm.py | 1 + pipelinerl/preprocess.py | 1 + pipelinerl/processor_factory.py | 19 ----- pipelinerl/vision_processor_utils.py | 122 +++++++++++++++++++++++++++ 12 files changed, 169 insertions(+), 75 deletions(-) delete mode 100644 pipelinerl/processor_factory.py create mode 100644 pipelinerl/vision_processor_utils.py diff --git a/conf/base.yaml b/conf/base.yaml index b99b8c07..c80eb08a 100644 --- a/conf/base.yaml +++ b/conf/base.yaml @@ -91,6 +91,9 @@ eval_every_n_versions: 78000 # changed model_path: Qwen/Qwen2.5-7B +# Processor configuration for vision-language models (multimodal) +mm_processor_kwargs: {} + # will use default based on the chosen backend accelerate_config: null use_deepspeed: true diff --git a/conf/chartqa.yaml b/conf/chartqa.yaml index c839dbf5..2ed7b4cb 100644 --- a/conf/chartqa.yaml +++ b/conf/chartqa.yaml @@ -8,8 +8,9 @@ finetune: seq_length: 8000 gradient_accumulation_passes: 512 seq_packing: false - epsilon_high: 4.0 - epsilon_low: 0.0 + rl: + epsilon_high: 4.0 + epsilon_low: 0.0 llm: parameters: @@ -49,6 +50,14 @@ test_dataset_names: # Use vision-language model for multimodal support model_path: Qwen/Qwen2.5-VL-3B-Instruct +eval_every_n_versions: 12500 + +# Processor configuration for vision-language models (shared between training and inference) +mm_processor_kwargs: + min_pixels: 784 # 28*28 + max_pixels: 1003520 # 1280*28*28 + use_fast: true + # Override vLLM config for multimodal support vllm_config: use_v1: true diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index 1c0e5af3..8c321b82 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -622,6 +622,7 @@ def run_actor_loop(cfg: DictConfig): tokenizer_name=str(actor_model_path), parameters=cfg.llm.parameters, collect_logprobs=True, + mm_processor_kwargs=cfg.get("mm_processor_kwargs", {}), ) for url in llm_urls ] @@ -632,6 +633,7 @@ def run_actor_loop(cfg: DictConfig): tokenizer_name=str(actor_model_path), parameters=cfg.test_llm.parameters, collect_logprobs=True, + mm_processor_kwargs=cfg.get("mm_processor_kwargs", {}), ) for url in llm_urls ] diff --git a/pipelinerl/async_llm.py b/pipelinerl/async_llm.py index 4e78ebf9..a8b45b26 100644 --- a/pipelinerl/async_llm.py +++ b/pipelinerl/async_llm.py @@ -4,12 +4,13 @@ import aiohttp import numpy as np +import torch from PIL import Image from pipelinerl.llm import LLMCall, LLMOutput, Prompt, TokenLogprob, TrainableLLM from pipelinerl.finetune.data import MASKED_TOKEN_ID from pipelinerl.rollouts import TrainingText -from pipelinerl.processor_factory import get_processor +from pipelinerl.vision_processor_utils import get_mm_processor from omegaconf import DictConfig, ListConfig, OmegaConf logger = logging.getLogger(__name__) @@ -157,7 +158,7 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText: if use_processor: # Use processor for vision-language models - processor = get_processor(llm.model_name) + processor = get_mm_processor(llm.model_name, 
mm_processor_kwargs=llm.mm_processor_kwargs) try: # Apply chat template using processor for proper image token handling @@ -189,11 +190,11 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText: processed = processor( text=[prompt_text], images=images, padding=True, return_tensors=None ) + # Convert PyTorch tensors to numpy arrays visual_features = { - key: value + key: value.cpu().numpy() if torch.is_tensor(value) else value for key, value in processed.items() - if isinstance(value, np.ndarray) - and key not in ["input_ids", "attention_mask"] + if key not in ["input_ids", "attention_mask"] } except Exception as e: diff --git a/pipelinerl/finetune/data.py b/pipelinerl/finetune/data.py index e12b2821..833f6490 100644 --- a/pipelinerl/finetune/data.py +++ b/pipelinerl/finetune/data.py @@ -15,6 +15,7 @@ from pipelinerl.finetune.utils import create_sentinel_example from pipelinerl.rollouts import TrainingText +from pipelinerl.vision_processor_utils import collate_visual_features from .context import get_accelerator, logger from .rl import RL_DATA_COLUMNS, prepare_rl_fields @@ -176,37 +177,9 @@ def collate( # Handle visual features with dynamic batching if "visual_features" in example_dict: visual_features_list = example_dict["visual_features"] - - if visual_features_list and visual_features_list[0] is not None: - first_vf = visual_features_list[0] - - for key in first_vf.keys(): - if key == "image_grid_thw": - # Concatenate all image_grid_thw arrays into a single tensor - # Each sample has shape (num_images, 3), concatenate along image dimension - all_grids = [torch.as_tensor(vf[key]) for vf in visual_features_list] - result[key] = torch.cat(all_grids, dim=0) - else: - # Convert to torch tensors (zero-copy for numpy arrays) - all_tensors = [torch.as_tensor(vf[key]) for vf in visual_features_list] - - # Find max number of images in this batch - max_num_images = max(t.shape[0] for t in all_tensors) - - # Get shape of single image: (C, H, W) - single_shape = all_tensors[0].shape[1:] - dtype = all_tensors[0].dtype - - # Pre-allocate batch tensor: (batch_size, max_num_images, C, H, W) - batch_shape = (len(all_tensors), max_num_images) + single_shape - batched = torch.zeros(batch_shape, dtype=dtype) - - # Fill in actual data (padding is already zeros) - for i, tensor in enumerate(all_tensors): - num_images = tensor.shape[0] - batched[i, :num_images] = tensor - - result[key] = batched + batched_visual_features = collate_visual_features(visual_features_list) + if batched_visual_features: + result["visual_features"] = batched_visual_features for k, seq_list in example_dict.items(): if k == "model_version": diff --git a/pipelinerl/finetune/rl/__init__.py b/pipelinerl/finetune/rl/__init__.py index d33e1961..adcba384 100644 --- a/pipelinerl/finetune/rl/__init__.py +++ b/pipelinerl/finetune/rl/__init__.py @@ -190,13 +190,12 @@ def rl_step( } if batch.is_packed: model_inputs["position_ids"] = batch.position_ids - + # Add visual features if present (for multimodal models) - if hasattr(batch, 'pixel_values') and batch.pixel_values is not None: - model_inputs["pixel_values"] = batch.pixel_values - if hasattr(batch, 'image_grid_thw') and batch.image_grid_thw is not None: - model_inputs["image_grid_thw"] = batch.image_grid_thw #torch.tensor(.reshape((1, 3)) - + # Unpack all visual features from the dict (e.g., pixel_values, image_grid_thw, image_sizes) + if hasattr(batch, 'visual_features') and batch.visual_features is not None: + model_inputs.update(batch.visual_features) + outputs = 
model(**model_inputs) # compute log probs and entropy diff --git a/pipelinerl/finetune/types.py b/pipelinerl/finetune/types.py index a3c16f2e..eda751f7 100644 --- a/pipelinerl/finetune/types.py +++ b/pipelinerl/finetune/types.py @@ -70,11 +70,11 @@ class PipelineBatchEncoding(BaseModel): is_packed: bool = False seq_boundaries: torch.IntTensor | None = None # Required when seq_packing=True - # Visual feature fields (optional, for multimodal models) - pixel_values: torch.FloatTensor | None = None - image_grid_thw: torch.LongTensor | None = None + # Visual features (optional, for multimodal models) + # Dict containing model-specific visual features (e.g., pixel_values, image_grid_thw, image_sizes) + visual_features: dict[str, torch.Tensor] | None = None - @field_validator('input_ids', 'attention_mask', 'labels', 'position_ids', 'image_grid_thw', 'segment_ids', mode='before') + @field_validator('input_ids', 'attention_mask', 'labels', 'position_ids', 'segment_ids', mode='before') @classmethod def convert_to_long_tensor(cls, v: List[int] | torch.Tensor | None) -> torch.LongTensor | None: """Handle initialization of long tensors from different types.""" @@ -95,9 +95,8 @@ def convert_to_int_tensor(cls, v: List[int] | torch.Tensor | None) -> torch.IntT if isinstance(v, torch.Tensor): return v.int() # type: ignore return torch.tensor(v, dtype=torch.int) - - # TODO: am i needed? - @field_validator('rewards', 'advantages', 'ref_logprobs', 'old_logprobs', 'group_tokens', 'num_labels', 'overflow', 'pixel_values', mode='before') + + @field_validator('rewards', 'advantages', 'ref_logprobs', 'old_logprobs', 'group_tokens', 'num_labels', 'overflow', mode='before') @classmethod def convert_to_float_tensor(cls, v: List[float] | torch.Tensor | None) -> torch.FloatTensor | None: """Handle initialization of float tensors from different types.""" @@ -111,10 +110,16 @@ def convert_to_float_tensor(cls, v: List[float] | torch.Tensor | None) -> torch. 
def to_device(self, device: Union[str, torch.device]) -> 'PipelineBatchEncoding': """Move all tensors to the specified device and return updated instance.""" - for field_name in self.model_fields: + for field_name in type(self).model_fields: field_value = getattr(self, field_name) if isinstance(field_value, torch.Tensor): setattr(self, field_name, field_value.to(device)) + elif isinstance(field_value, dict): + setattr( + self, + field_name, + {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in field_value.items()} + ) return self @classmethod @@ -173,8 +178,7 @@ def make_slices(self, num_slices: int) -> list['PipelineBatchEncoding']: "is_packed": self.is_packed, "padding": self.padding, "seq_boundaries": self.seq_boundaries, - "pixel_values": self.pixel_values, - "image_grid_thw": self.image_grid_thw + "visual_features": self.visual_features } slices.append(PipelineBatchEncoding(**result)) return slices diff --git a/pipelinerl/launch.py b/pipelinerl/launch.py index 19c05999..a0df771b 100644 --- a/pipelinerl/launch.py +++ b/pipelinerl/launch.py @@ -56,8 +56,6 @@ def validate_config(cfg: DictConfig): # Check for vision language model constraints if cfg.finetune.model_class == "vision2seq-language-modeling": - if "Qwen2.5-VL" not in cfg.model_path: - raise ValueError("Only Qwen2.5-VL models are supported for vision language modeling") if cfg.finetune.seq_packing: raise ValueError("Vision language models cannot use sequence packing (seq_packing must be false)") diff --git a/pipelinerl/llm.py b/pipelinerl/llm.py index cc099c15..9d76a848 100644 --- a/pipelinerl/llm.py +++ b/pipelinerl/llm.py @@ -403,6 +403,7 @@ class TrainableLLM(LLM): max_parallel_requests: int = 32 max_retries: int = 5 base_delay: float = 0.5 + mm_processor_kwargs: dict = Field(default_factory=dict) _semaphore: asyncio.Semaphore def model_post_init(self, __context): diff --git a/pipelinerl/preprocess.py b/pipelinerl/preprocess.py index 1c5d666a..3d402049 100644 --- a/pipelinerl/preprocess.py +++ b/pipelinerl/preprocess.py @@ -420,6 +420,7 @@ def run_preprocessing_loop( model_name=cfg.finetune.config_name, tokenizer_name=cfg.finetune.config_name, parameters=cfg.llm.parameters, + mm_processor_kwargs=cfg.get("mm_processor_kwargs", {}), ) for url in llm_urls ] diff --git a/pipelinerl/processor_factory.py b/pipelinerl/processor_factory.py deleted file mode 100644 index 06f0fc2b..00000000 --- a/pipelinerl/processor_factory.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Simple cache for AutoProcessor instances.""" -from typing import Dict -from transformers import AutoProcessor -import logging -logger = logging.getLogger(__name__) - -_processors: Dict[str, AutoProcessor] = {} - -def get_processor(model_name: str) -> AutoProcessor: - """Get or create an AutoProcessor for the given model.""" - if model_name not in _processors: - logger.info(f"Loading processor for model: {model_name}") - #TODO: should be args - _processors[model_name] = AutoProcessor.from_pretrained(model_name, min_pixels=28*28, max_pixels=1280*28*28) - return _processors[model_name] - -def clear_cache() -> None: - """Clear all cached processors.""" - _processors.clear() \ No newline at end of file diff --git a/pipelinerl/vision_processor_utils.py b/pipelinerl/vision_processor_utils.py new file mode 100644 index 00000000..ee2e8635 --- /dev/null +++ b/pipelinerl/vision_processor_utils.py @@ -0,0 +1,122 @@ +""" +Vision processor utilities for multimodal models. + +This module provides processor caching and management for vision-language models. 
+ +Supported models: +- Qwen2.5-VL: Uses image_grid_thw (B, 3) and flattened pixel_values +- Pixtral/Apriel: Uses image_sizes (B, 2) and standard pixel_values (B, C, H, W) +""" +import logging +from typing import Dict +import torch +from transformers import AutoProcessor + +logger = logging.getLogger(__name__) + +# Processor cache +_processors: Dict[str, AutoProcessor] = {} + + +def get_mm_processor(model_name: str, mm_processor_kwargs: dict | None = None) -> AutoProcessor: + """ + Get or create an AutoProcessor for multimodal models. + + Args: + model_name: HuggingFace model identifier + mm_processor_kwargs: Optional kwargs to pass to AutoProcessor.from_pretrained() + + Returns: + AutoProcessor instance + """ + if model_name not in _processors: + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + + logger.info(f"Loading processor for model: {model_name} with kwargs: {mm_processor_kwargs}") + _processors[model_name] = AutoProcessor.from_pretrained( + model_name, **mm_processor_kwargs + ) + return _processors[model_name] + + +def clear_cache() -> None: + """Clear all cached processors.""" + _processors.clear() + + +def collate_visual_features(visual_features_list: list[dict]) -> dict[str, torch.Tensor]: + """ + Collate visual features from multiple samples into batched tensors. + + Handles different formats: + - Metadata (image_grid_thw, image_sizes): Concatenate along image dimension + - Qwen pixel_values (2D): Concatenate flattened features + - Pixtral pixel_values (4D): Pad to max_num_images + - Other features: Pad to max_num_images + + Args: + visual_features_list: List of visual feature dicts from individual samples + + Returns: + Dict mapping feature names to batched tensors + """ + if not visual_features_list or visual_features_list[0] is None: + return {} + + first_vf = visual_features_list[0] + batched_visual_features = {} + + for key in first_vf.keys(): + if key in ("image_grid_thw", "image_sizes"): + # Concatenate metadata arrays (image_grid_thw or image_sizes) + # Each sample has shape (num_images, 2 or 3), concatenate along image dimension + all_metadata = [torch.as_tensor(vf[key]) for vf in visual_features_list] + batched_visual_features[key] = torch.cat(all_metadata, dim=0) + + elif key == "pixel_values": + # Handle pixel_values - format differs by model: + # - Qwen: (total_pixels, hidden_dim) - flattened, concatenate along pixel dimension + # - Pixtral: (num_images, C, H, W) - standard, needs padding to max_num_images + all_tensors = [torch.as_tensor(vf[key]) for vf in visual_features_list] + + # Check if this is flattened format (2D) or image format (4D) + if all_tensors[0].ndim == 2: + # Qwen format: (total_pixels, hidden_dim) - just concatenate + batched_visual_features[key] = torch.cat(all_tensors, dim=0) + elif all_tensors[0].ndim == 4: + # Pixtral format: (num_images, C, H, W) - pad to max_num_images + max_num_images = max(t.shape[0] for t in all_tensors) + single_shape = all_tensors[0].shape[1:] # (C, H, W) + dtype = all_tensors[0].dtype + + # Pre-allocate: (batch_size, max_num_images, C, H, W) + batch_shape = (len(all_tensors), max_num_images) + single_shape + batched = torch.zeros(batch_shape, dtype=dtype) + + # Fill in actual data + for i, tensor in enumerate(all_tensors): + num_images = tensor.shape[0] + batched[i, :num_images] = tensor + + batched_visual_features[key] = batched + else: + raise ValueError(f"Unexpected pixel_values shape: {all_tensors[0].shape}") + + else: + # Other visual features - assume they need padding like Pixtral pixel_values + 
all_tensors = [torch.as_tensor(vf[key]) for vf in visual_features_list] + max_num_images = max(t.shape[0] for t in all_tensors) + single_shape = all_tensors[0].shape[1:] + dtype = all_tensors[0].dtype + + batch_shape = (len(all_tensors), max_num_images) + single_shape + batched = torch.zeros(batch_shape, dtype=dtype) + + for i, tensor in enumerate(all_tensors): + num_images = tensor.shape[0] + batched[i, :num_images] = tensor + + batched_visual_features[key] = batched + + return batched_visual_features
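
Patches 02-05 split the vision-tower freezing into two cooperating pieces: load_model() flips requires_grad off on model.visual / model.vision_tower / model.vision_model when freeze_vision_tower is set, and get_grouped_params() then skips every parameter with requires_grad == False, so the frozen weights never reach the optimizer. A small self-contained sketch of that contract, using a toy module in place of a real VLM (names and sizes are illustrative only):

    import torch
    from torch import nn

    class ToyVLM(nn.Module):
        def __init__(self):
            super().__init__()
            self.visual = nn.Linear(8, 8)    # stands in for the vision tower (model.visual on Qwen*-VL)
            self.lm_head = nn.Linear(8, 8)

    model = ToyVLM()
    model.visual.requires_grad_(False)       # what load_model() does when freeze_vision_tower=True

    # Mirrors the new check in get_grouped_params(): frozen parameters are dropped
    # before the weight-decay grouping, so they are excluded from the optimizer.
    trainable = [p for _, p in model.named_parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=1e-5)
    print(sum(p.numel() for p in trainable))  # only the lm_head parameters remain (72)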
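
The max_stream_size knob added in patch 09 only affects the Redis backend: RedisStreamWriter.write() forwards it as an approximate MAXLEN cap on XADD, so old entries are trimmed instead of accumulating until Redis runs out of memory. A rough stand-alone equivalent of one such write, assuming a local Redis instance and a hypothetical stream name:

    import pickle
    import redis

    client = redis.Redis(host="localhost", port=6379)  # assumed local Redis, as started by the launcher
    max_stream_size = 1000                             # conf/chartqa.yaml override; conf/base.yaml defaults to 1000000

    entry = pickle.dumps({"rollout": "..."})           # hypothetical payload
    # approximate=True lets Redis trim lazily (XADD ... MAXLEN ~ 1000),
    # which is the behaviour the writer relies on to bound memory use.
    client.xadd("actor_0", {"index": 0, "data": entry}, maxlen=max_stream_size, approximate=True)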
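
The new pipelinerl/vision_processor_utils.py (patch 12) is the single place where per-sample visual features get batched: make_training_text() filters the processor output down to the model-specific keys, and collate_visual_features() later merges those dicts into the tensors that rl_step() unpacks via model(**model_inputs). A minimal sketch of the collation step, assuming Qwen-style inputs (flattened pixel_values plus one (t, h, w) grid row per image; the 1176 feature width is illustrative):

    import numpy as np

    from pipelinerl.vision_processor_utils import collate_visual_features

    # Two samples in the assumed Qwen2.5-VL layout: flattened patch features
    # plus one grid row per image, as produced by get_mm_processor().
    samples = [
        {"pixel_values": np.random.rand(64, 1176).astype(np.float32), "image_grid_thw": np.array([[1, 8, 8]])},
        {"pixel_values": np.random.rand(16, 1176).astype(np.float32), "image_grid_thw": np.array([[1, 4, 4]])},
    ]

    batched = collate_visual_features(samples)
    # pixel_values are concatenated along the patch dimension, image_grid_thw along the image dimension.
    print(batched["pixel_values"].shape)    # torch.Size([80, 1176])
    print(batched["image_grid_thw"].shape)  # torch.Size([2, 3])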