11 changes: 9 additions & 2 deletions conf/base.yaml
@@ -22,6 +22,8 @@ actor:
result_queue_size: 64
throughput_window_size: 50
shared_memory_entry_size: 10000000
# Maximum number of entries to retain in the actor data stream (Redis only for now)
max_stream_size: 1000000
environment: null
preprocess:
input: actor
@@ -30,7 +32,7 @@ preprocess:
chunk_n_groups: 2
# queue for loaded raw groups
raw_queue_size: 8
# queue for processed chunks of multiple groups
input_queue_size: 32
# queue for ready chunks for multiple groups
output_queue_size: 32
@@ -40,9 +42,11 @@ preprocess:
ring_buffer_size: 128
# "virtual" sample queue per lead trainer
max_ready_samples_per_lead: 64
pop_old_data: ${..pop_old_data}
shared_memory_entry_size: 100000000
log_every_n_samples: 128
# Maximum number of entries to retain in the training data stream (Redis only for now)
max_stream_size: 1000000
llm:
parameters:
# changed
@@ -96,6 +100,9 @@ eval_every_n_versions: 78000
# changed
model_path: Qwen/Qwen2.5-7B

# Processor configuration for vision-language models (multimodal)
mm_processor_kwargs: {}

# will use default based on the chosen backend
accelerate_config: null
use_deepspeed: true
16 changes: 15 additions & 1 deletion conf/chartqa.yaml
@@ -8,6 +8,9 @@ finetune:
seq_length: 8000
gradient_accumulation_passes: 512
seq_packing: false
rl:
epsilon_high: 4.0
epsilon_low: 0.0

llm:
parameters:
@@ -24,13 +27,15 @@ actor:
system_prompt: You are an expert at analyzing charts and graphs. Please examine the chart carefully and answer the question accurately. Remember to provide your final answer in a boxed format, like \\boxed{{your answer}}.
task_template: |-
Question: {question}

Please analyze the chart step by step and put your final answer within \\boxed{{}}.
llm_max_rollouts: 16
shared_memory_entry_size: 2000000000
max_stream_size: 1000

preprocess:
shared_memory_entry_size: 2000000000
max_stream_size: 1000

environment: null

@@ -45,9 +50,18 @@ test_dataset_names:
# Use vision-language model for multimodal support
model_path: Qwen/Qwen2.5-VL-3B-Instruct

eval_every_n_versions: 12500

# Processor configuration for vision-language models (shared between training and inference)
mm_processor_kwargs:
min_pixels: 784 # 28*28
max_pixels: 1003520 # 1280*28*28
use_fast: true

# Override vLLM config for multimodal support
vllm_config:
use_v1: true
vllm_kwargs:
max-num-seqs: 64
max-num-batched-tokens: 32768
mm-processor-kwargs: '{"min_pixels": 784, "max_pixels": 1003520, "use_fast": true}' # 28*28 to 1280*28*28
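
Note: min_pixels/max_pixels bound each image's resolution before patch extraction; for Qwen2.5-VL one visual token covers a 28x28 pixel area (14x14 patches merged 2x2), so this range allows roughly 1 to 1280 visual tokens per image. A minimal sketch of how such kwargs are typically forwarded to the Hugging Face processor (the repo's own get_mm_processor helper presumably does something similar):

from transformers import AutoProcessor

mm_processor_kwargs = {"min_pixels": 784, "max_pixels": 1003520, "use_fast": True}

# AutoProcessor forwards these kwargs to the image processor, which then
# resizes images into the [min_pixels, max_pixels] budget before tokenization.
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct", **mm_processor_kwargs
)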
2 changes: 2 additions & 0 deletions conf/finetune/base.yaml
@@ -3,6 +3,8 @@ data: null
model_class: causal-language-modeling
# Model name or path of model to be trained.
config_name: ${..model_path}
# Freeze vision tower for vision-language models (only applicable for vision2seq-language-modeling)
freeze_vision_tower: false
# Optimizer type, supported: adamw_torch, adafactor, cpuadam, lion
optim: adamw_torch
# use half precision training, full bf16 without mixed precision copies at all
4 changes: 3 additions & 1 deletion pipelinerl/actor.py
@@ -439,7 +439,7 @@ def run(self, dataset: list[tuple[str, dict]]):

logger.info(f"Start {'train' if self.is_training else 'test'} actor loop")
with (
write_to_streams(self.data_stream, "a") as data_stream_writer,
write_to_streams(self.data_stream, "a", max_stream_size=self.cfg.actor.max_stream_size) as data_stream_writer,
write_to_streams(self.stats_stream, "a") as stats_writer,
):
while True:
@@ -625,6 +625,7 @@ def run_actor_loop(cfg: DictConfig):
tokenizer_name=str(actor_model_path),
parameters=cfg.llm.parameters,
collect_logprobs=True,
mm_processor_kwargs=cfg.get("mm_processor_kwargs", {}),
)
for url in llm_urls
]
@@ -635,6 +636,7 @@ def run_actor_loop(cfg: DictConfig):
tokenizer_name=str(actor_model_path),
parameters=cfg.test_llm.parameters,
collect_logprobs=True,
mm_processor_kwargs=cfg.get("mm_processor_kwargs", {}),
)
for url in llm_urls
]
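
Note on max_stream_size: this caps the actor's output data stream, which the config comments mark as Redis-only for now. A rough sketch of the underlying mechanism, assuming write_to_streams maps the cap onto Redis's XADD MAXLEN trimming (the actual helper lives in pipelinerl's streams module and is not shown in this diff):

import json
import redis

client = redis.Redis()

def append_capped(stream: str, payload: dict, max_stream_size: int = 1_000_000) -> None:
    # Hypothetical illustration: XADD with maxlen trims the oldest entries
    # once the stream exceeds the cap; approximate=True uses the "~" form
    # so Redis can trim in batches instead of on every append.
    client.xadd(
        stream,
        {"data": json.dumps(payload)},
        maxlen=max_stream_size,
        approximate=True,
    )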
11 changes: 6 additions & 5 deletions pipelinerl/async_llm.py
@@ -4,12 +4,13 @@

import aiohttp
import numpy as np
import torch
from PIL import Image
from pipelinerl.llm import LLMCall, LLMOutput, Prompt, TokenLogprob, TrainableLLM

from pipelinerl.finetune.data import MASKED_TOKEN_ID
from pipelinerl.rollouts import TrainingText
- from pipelinerl.processor_factory import get_processor
+ from pipelinerl.vision_processor_utils import get_mm_processor
from omegaconf import DictConfig, ListConfig, OmegaConf

logger = logging.getLogger(__name__)
@@ -157,7 +158,7 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText:

if use_processor:
# Use processor for vision-language models
- processor = get_processor(llm.model_name)
+ processor = get_mm_processor(llm.model_name, mm_processor_kwargs=llm.mm_processor_kwargs)

try:
# Apply chat template using processor for proper image token handling
Expand Down Expand Up @@ -189,11 +190,11 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText:
processed = processor(
text=[prompt_text], images=images, padding=True, return_tensors=None
)
+ # Convert PyTorch tensors to numpy arrays
visual_features = {
- key: value
+ key: value.cpu().numpy() if torch.is_tensor(value) else value
for key, value in processed.items()
- if isinstance(value, np.ndarray)
- and key not in ["input_ids", "attention_mask"]
+ if key not in ["input_ids", "attention_mask"]
}

except Exception as e:
22 changes: 22 additions & 0 deletions pipelinerl/finetune/checkpoints.py
@@ -210,6 +210,28 @@ def load_model(args, model_class, current_dir):
gradient_checkpointing_kwargs={"use_reentrant": args.reentrant_checkpointing}
)

# Freeze vision tower if specified
freeze_vision_tower = getattr(args, "freeze_vision_tower", False)
if freeze_vision_tower:
# Try to get vision tower module from the model
vision_tower = None
if hasattr(model, "visual"):
vision_tower = model.visual # Qwen-VL, Qwen2-VL, Qwen2.5-VL, Qwen3-VL
elif hasattr(model, "vision_tower"):
vision_tower = model.vision_tower # LLaVA
elif hasattr(model, "vision_model"):
vision_tower = model.vision_model # BLIP-2, InstructBLIP

if vision_tower is not None:
vision_tower.requires_grad_(False)
logger.info("Vision tower parameters frozen successfully (i.e. its parameters will be excluded from optimizer)")
else:
logger.warning(
"freeze_vision_tower=True but could not find a vision tower. "
"Checked attributes: model.visual (Qwen*-VL), model.vision_tower (LLaVA), model.vision_model (BLIP-2, InstructBLIP). "
"Setting this parameter therefore has no effect."
)

get_accelerator().wait_for_everyone()
return model

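
A quick way to sanity-check the freeze after load_model returns (hypothetical debugging snippet, not part of the PR):

# With freeze_vision_tower: true, the vision tower's parameters should
# drop out of the trainable count reported here.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.1f}%)")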
19 changes: 8 additions & 11 deletions pipelinerl/finetune/data.py
@@ -15,6 +15,7 @@

from pipelinerl.finetune.utils import create_sentinel_example
from pipelinerl.rollouts import TrainingText
from pipelinerl.vision_processor_utils import collate_visual_features

from .context import get_accelerator, logger
from .rl import RL_DATA_COLUMNS, prepare_rl_fields
@@ -172,17 +173,13 @@ def collate(
if seq_length % pad_to_multiple_of:
seq_length += pad_to_multiple_of - (seq_length % pad_to_multiple_of)
result = {}

- # Visual feature fields that should be stacked, not padded
- if "visual_features" in example_dict and isinstance(example_dict["visual_features"][0], dict):
- for k, seq_list in example_dict["visual_features"][0].items():
- if k == "image_grid_thw":
- # image_grid_thw should remain as a list
- result[k] = seq_list
- else:
- # Other visual fields like pixel_values can be stacked as tensors
- valid_tensors = [torch.tensor(seq) for seq in seq_list]
- result[k] = torch.stack(valid_tensors)

+ # Handle visual features with dynamic batching
+ if "visual_features" in example_dict:
+ visual_features_list = example_dict["visual_features"]
+ batched_visual_features = collate_visual_features(visual_features_list)
+ if batched_visual_features:
+ result["visual_features"] = batched_visual_features

for k, seq_list in example_dict.items():
if k == "model_version":
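
collate_visual_features comes from the new pipelinerl.vision_processor_utils module, which is not included in this diff. Judging from the inline code it replaces, a plausible sketch is to merge the per-example feature dicts along the batch axis (an assumption; the real helper may handle model-specific cases differently):

import torch

def collate_visual_features(visual_features_list: list[dict]) -> dict[str, torch.Tensor]:
    # Merge per-example visual feature dicts into one batched dict.
    batched: dict[str, torch.Tensor] = {}
    keys = {k for feats in visual_features_list for k in feats}
    for k in keys:
        values = [torch.as_tensor(feats[k]) for feats in visual_features_list if k in feats]
        # pixel_values holds a variable number of patch rows per image and
        # image_grid_thw one row per image, so concatenating along dim 0
        # batches both correctly for Qwen2-VL-style models.
        batched[k] = torch.cat(values, dim=0)
    return batched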
3 changes: 3 additions & 0 deletions pipelinerl/finetune/optim.py
@@ -12,6 +12,9 @@ def get_grouped_params(
):
params_with_wd, params_without_wd = [], []
for n, p in model.named_parameters():
# Skip frozen parameters
if not p.requires_grad:
continue
if any(nd in n for nd in no_decay):
params_without_wd.append(p)
else:
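
Skipping requires_grad=False parameters means a frozen vision tower contributes nothing to either parameter group, so the optimizer allocates no state for it. A minimal usage sketch (assuming get_grouped_params returns the two lists built above; the lr and weight-decay values are illustrative):

import torch

params_with_wd, params_without_wd = get_grouped_params(model, no_decay=["bias", "norm"])
optimizer = torch.optim.AdamW(
    [
        {"params": params_with_wd, "weight_decay": 0.01},
        {"params": params_without_wd, "weight_decay": 0.0},
    ],
    lr=1e-6,
)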
11 changes: 5 additions & 6 deletions pipelinerl/finetune/rl/__init__.py
@@ -190,13 +190,12 @@ def rl_step(
}
if batch.is_packed:
model_inputs["position_ids"] = batch.position_ids

- # Add visual features if present (for multimodal models)
- if hasattr(batch, 'pixel_values') and batch.pixel_values is not None:
- model_inputs["pixel_values"] = batch.pixel_values
- if hasattr(batch, 'image_grid_thw') and batch.image_grid_thw is not None:
- model_inputs["image_grid_thw"] = batch.image_grid_thw #torch.tensor(.reshape((1, 3))

+ # Unpack all visual features from the dict (e.g., pixel_values, image_grid_thw, image_sizes)
+ if hasattr(batch, 'visual_features') and batch.visual_features is not None:
+ model_inputs.update(batch.visual_features)

outputs = model(**model_inputs)

# compute log probs and entropy
24 changes: 14 additions & 10 deletions pipelinerl/finetune/types.py
@@ -70,11 +70,11 @@ class PipelineBatchEncoding(BaseModel):
is_packed: bool = False
seq_boundaries: torch.IntTensor | None = None # Required when seq_packing=True

- # Visual feature fields (optional, for multimodal models)
- pixel_values: torch.FloatTensor | None = None
- image_grid_thw: torch.LongTensor | None = None
+ # Visual features (optional, for multimodal models)
+ # Dict containing model-specific visual features (e.g., pixel_values, image_grid_thw, image_sizes)
+ visual_features: dict[str, torch.Tensor] | None = None

- @field_validator('input_ids', 'attention_mask', 'labels', 'position_ids', 'image_grid_thw', 'segment_ids', mode='before')
+ @field_validator('input_ids', 'attention_mask', 'labels', 'position_ids', 'segment_ids', mode='before')
@classmethod
def convert_to_long_tensor(cls, v: List[int] | torch.Tensor | None) -> torch.LongTensor | None:
"""Handle initialization of long tensors from different types."""
@@ -95,9 +95,8 @@ def convert_to_int_tensor(cls, v: List[int] | torch.Tensor | None) -> torch.IntT
if isinstance(v, torch.Tensor):
return v.int() # type: ignore
return torch.tensor(v, dtype=torch.int)

- # TODO: am i needed?
- @field_validator('rewards', 'advantages', 'ref_logprobs', 'old_logprobs', 'group_tokens', 'num_labels', 'overflow', 'pixel_values', mode='before')
+ @field_validator('rewards', 'advantages', 'ref_logprobs', 'old_logprobs', 'group_tokens', 'num_labels', 'overflow', mode='before')
@classmethod
def convert_to_float_tensor(cls, v: List[float] | torch.Tensor | None) -> torch.FloatTensor | None:
"""Handle initialization of float tensors from different types."""
@@ -111,10 +110,16 @@ def convert_to_float_tensor(cls, v: List[float] | torch.Tensor | None) -> torch.

def to_device(self, device: Union[str, torch.device]) -> 'PipelineBatchEncoding':
"""Move all tensors to the specified device and return updated instance."""
- for field_name in self.model_fields:
+ for field_name in type(self).model_fields:
field_value = getattr(self, field_name)
if isinstance(field_value, torch.Tensor):
setattr(self, field_name, field_value.to(device))
elif isinstance(field_value, dict):
setattr(
self,
field_name,
{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in field_value.items()}
)
return self

@classmethod
@@ -173,8 +178,7 @@ def make_slices(self, num_slices: int) -> list['PipelineBatchEncoding']:
"is_packed": self.is_packed,
"padding": self.padding,
"seq_boundaries": self.seq_boundaries,
"pixel_values": self.pixel_values,
"image_grid_thw": self.image_grid_thw
"visual_features": self.visual_features
}
slices.append(PipelineBatchEncoding(**result))
return slices
19 changes: 18 additions & 1 deletion pipelinerl/finetune_loop.py
@@ -353,9 +353,26 @@ def run_finetuning_loop(
logger.info("About to load model")
model = load_model(args, args.model_class, current_dir)
logger.info(f"Model loaded in dtype {model.dtype}")

dt = log_time(dt, time_stats, "finetune/model_load")

# Fix for multimodal models (e.g., Apriel, Qwen3-VL) with Accelerate+DeepSpeed
# Accelerate's _prepare_deepspeed() doesn't check text_config.hidden_size
if not hasattr(model.config, "hidden_size") and not hasattr(model.config, "hidden_sizes"):
if hasattr(model.config, "text_config"):
hidden_size = None
if hasattr(model.config.text_config, "hidden_size"):
hidden_size = model.config.text_config.hidden_size
elif hasattr(model.config.text_config, "hidden_sizes"):
hidden_size = max(model.config.text_config.hidden_sizes)

if hidden_size is not None:
if get_accelerator().is_main_process:
logger.info(
f"Detected multimodal model with text_config.hidden_size={hidden_size}. "
f"Setting config.hidden_size to enable DeepSpeed auto-configuration."
)
model.config.hidden_size = hidden_size

data_stream = SingleStreamSpec(
exp_path=exp_root_dir,
topic=args.input,
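
For context on the fix above: Accelerate's _prepare_deepspeed() reads config.hidden_size to fill DeepSpeed ZeRO-3 "auto" buffer sizes. Roughly, as a sketch of the derived values (formulas follow Accelerate's DeepSpeed integration; the hidden size here is illustrative):

# Multimodal configs often expose the hidden size only under text_config,
# which is what the patch above copies up to config.hidden_size.
hidden_size = 3584  # e.g. a 7B-class text backbone

ds_auto_values = {
    "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
    "zero_optimization.stage3_prefetch_bucket_size": int(0.9 * hidden_size * hidden_size),
    "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
}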