Skip to content
This repository was archived by the owner on Apr 12, 2026. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ ci:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v6.0.0
hooks:
- id: check-yaml
- id: check-case-conflict
Expand All @@ -18,20 +18,20 @@ repos:
- id: requirements-txt-fixer

- repo: https://github.com/PyCQA/autoflake
rev: v2.0.2
rev: v2.3.3
hooks:
- id: autoflake
args: [--remove-all-unused-imports, --in-place]

- repo: https://github.com/PyCQA/isort
rev: 5.13.2
rev: 8.0.1
hooks:
- id: isort
name: Format imports
exclude: docs/

- repo: https://github.com/psf/black
rev: 24.3.0
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 26.3.1
hooks:
- id: black
name: Format code
Expand Down
2 changes: 1 addition & 1 deletion openrlhf/cli/train_dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def train(args):
)

# strategy prepare
((model, optim, scheduler), ref_model) = strategy.prepare((model, optim, scheduler), ref_model)
(model, optim, scheduler), ref_model = strategy.prepare((model, optim, scheduler), ref_model)

# load checkpoint
consumed_samples = 0
Expand Down
2 changes: 1 addition & 1 deletion openrlhf/cli/train_kd.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def train(args):
)

# prepare models
((model, optim, scheduler), teacher_model) = strategy.prepare((model, optim, scheduler), teacher_model)
(model, optim, scheduler), teacher_model = strategy.prepare((model, optim, scheduler), teacher_model)

# load checkpoint
consumed_samples = 0
Expand Down
2 changes: 1 addition & 1 deletion openrlhf/cli/train_kto.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def train(args):
)

# strategy prepare
((model, optim, scheduler), ref_model) = strategy.prepare((model, optim, scheduler), ref_model)
(model, optim, scheduler), ref_model = strategy.prepare((model, optim, scheduler), ref_model)

# load checkpoint
consumed_samples = 0
Expand Down
12 changes: 8 additions & 4 deletions openrlhf/cli/train_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,8 @@ def train(args):
)
parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer")
parser.add_argument("--reward_clip_range", type=float, nargs=2, default=(-10, 10), help="Reward clip range")
parser.add_argument("--max_pixels",type=int,default=640*28*28,help="Max pixels for image")
parser.add_argument("--min_pixels",type=int,default=4*28*28,help="Min pixels for image")
parser.add_argument("--max_pixels", type=int, default=640 * 28 * 28, help="Max pixels for image")
parser.add_argument("--min_pixels", type=int, default=4 * 28 * 28, help="Min pixels for image")
# DeepSpeed
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed")
Expand All @@ -362,8 +362,12 @@ def train(args):
parser.add_argument("--overlap_comm", action="store_true", default=False)
parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False)
parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False)
parser.add_argument("--freeze_prefix", type=str, nargs="+", default=None,
help="List of parameter name prefixes to freeze during training"
parser.add_argument(
"--freeze_prefix",
type=str,
nargs="+",
default=None,
help="List of parameter name prefixes to freeze during training",
)
parser.add_argument("--drop_maxlen", action="store_true", default=False)

Expand Down
12 changes: 8 additions & 4 deletions openrlhf/cli/train_ppo_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,12 +337,16 @@ def train(args):
parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss")
parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer")
parser.add_argument("--reward_clip_range", type=float, nargs=2, default=(-10, 10), help="Reward clip range")
parser.add_argument("--freeze_prefix", type=str, nargs="+", default=None,
help="List of parameter name prefixes to freeze during training"
parser.add_argument(
"--freeze_prefix",
type=str,
nargs="+",
default=None,
help="List of parameter name prefixes to freeze during training",
)
parser.add_argument("--drop_maxlen", action="store_true", default=False)
parser.add_argument("--max_pixels",type=int,default=640*28*28,help="Max pixels for image")
parser.add_argument("--min_pixels",type=int,default=4*28*28,help="Min pixels for image")
parser.add_argument("--max_pixels", type=int, default=640 * 28 * 28, help="Max pixels for image")
parser.add_argument("--min_pixels", type=int, default=4 * 28 * 28, help="Min pixels for image")
# Reinforce
parser.add_argument(
"--advantage_estimator",
Expand Down
2 changes: 1 addition & 1 deletion openrlhf/cli/train_prm.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def train(args):
)

# prepare models
(model, optim, scheduler) = strategy.prepare((model, optim, scheduler))
model, optim, scheduler = strategy.prepare((model, optim, scheduler))

# load checkpoint
consumed_samples = 0
Expand Down
2 changes: 1 addition & 1 deletion openrlhf/cli/train_rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def train(args):
)

# strategy prepare
(model, optim, scheduler) = strategy.prepare((model, optim, scheduler))
model, optim, scheduler = strategy.prepare((model, optim, scheduler))

# load checkpoint
consumed_samples = 0
Expand Down
2 changes: 1 addition & 1 deletion openrlhf/cli/train_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def train(args):
)

# prepare models
(model, optim, scheduler) = strategy.prepare((model, optim, scheduler))
model, optim, scheduler = strategy.prepare((model, optim, scheduler))

# load checkpoint
consumed_samples = 0
Expand Down
39 changes: 23 additions & 16 deletions openrlhf/models/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn import functional as F
from flash_attn.utils.distributed import all_gather
from peft import LoraConfig, TaskType, get_peft_model
from peft.tuners.lora import LoraLayer
from transformers import BitsAndBytesConfig, AutoConfig
from torch.nn import functional as F
from transformers import AutoConfig, BitsAndBytesConfig
from transformers.integrations.deepspeed import HfDeepSpeedConfig
from flash_attn.utils.distributed import all_gather

from .ring_attn_utils import convert_ring_attn_params, set_hacked_position_ids, clear_hacked_position_ids
from .utils import log_probs_from_logits, reset_position_ids
from openrlhf.models.lmm_kits.utils import get_generation_cls

from .ring_attn_utils import clear_hacked_position_ids, convert_ring_attn_params, set_hacked_position_ids
from .utils import log_probs_from_logits, reset_position_ids


class Actor(nn.Module):
"""
Expand Down Expand Up @@ -73,7 +74,7 @@ def __init__(
else:
nf4_config = None

#There is no AutoModelForConditionalGeneration in transformers. We manually implement it.
# There is no AutoModelForConditionalGeneration in transformers. We manually implement it.
config = AutoConfig.from_pretrained(pretrain_or_model)
model_cls = get_generation_cls(config)
self.model = model_cls.from_pretrained(
Expand Down Expand Up @@ -200,35 +201,41 @@ def forward(
"""Returns action log probs"""
if visual_inputs is None:
visual_inputs = {}
'''
"""
for k,v in visual_inputs.items():
if v.dtype == torch.float32:
visual_inputs[k] = v.to(self.model.get_input_embeddings().weight.dtype)
'''
"""
inputs_embeds = self.model.get_inputs_embeds(sequences, **visual_inputs)
if not self.packing_samples:
# https://github.com/OpenRLHF/OpenRLHF/issues/217
#position_ids = attention_mask.long().cumsum(-1) - 1
#position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = self.model.get_position_ids(sequences,attention_mask=attention_mask, **visual_inputs)
# position_ids = attention_mask.long().cumsum(-1) - 1
# position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = self.model.get_position_ids(sequences, attention_mask=attention_mask, **visual_inputs)
else:
# convert attention_mask to position_ids
packed_position_ids = self.model.get_position_ids(sequences, **visual_inputs)
if ring_attn_group is not None:
labels = sequences
sequences, attention_mask, hacked_position_ids, inputs_embeds, split_position_ids = convert_ring_attn_params(
sequences, attention_mask, packed_seq_lens, ring_attn_group, inputs_embeds, packed_position_ids
sequences, attention_mask, hacked_position_ids, inputs_embeds, split_position_ids = (
convert_ring_attn_params(
sequences, attention_mask, packed_seq_lens, ring_attn_group, inputs_embeds, packed_position_ids
)
)
position_ids = self.model.offset_split_position_ids(split_position_ids, hacked_position_ids) # this is true position_ids
#position_ids is directly hacked into flash_attn_forward to distinguish between different sequences
position_ids = self.model.offset_split_position_ids(
split_position_ids, hacked_position_ids
) # this is true position_ids
# position_ids is directly hacked into flash_attn_forward to distinguish between different sequences
else:
hacked_position_ids = reset_position_ids(attention_mask)
position_ids = self.model.offset_split_position_ids(packed_position_ids, hacked_position_ids)

set_hacked_position_ids(hacked_position_ids)
# explicitly ignore attention_mask for packing_samples
attention_mask = None
output = self.model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, **visual_inputs)
output = self.model(
inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, **visual_inputs
)
clear_hacked_position_ids()
# https://github.com/OpenRLHF/OpenRLHF/pull/634
output["logits"] = output["logits"].to(torch.float32)
Expand Down
55 changes: 27 additions & 28 deletions openrlhf/models/lmm_kits/base/data_processor.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import json
import os
from abc import ABC, abstractmethod
from typing import List, Optional, Union, Dict
from typing import Dict, List, Optional, Union

import torch
from transformers.processing_utils import ProcessorMixin
from qwen_vl_utils import process_vision_info
from transformers.processing_utils import ProcessorMixin


class BaseDataProcessor(ABC):
def __init__(self, processor: ProcessorMixin,min_pixels:int,max_pixels:int):
def __init__(self, processor: ProcessorMixin, min_pixels: int, max_pixels: int):
super().__init__()
self.processor = processor
self.min_pixels = min_pixels
self.max_pixels = max_pixels

@abstractmethod
def __call__(
self,
Expand All @@ -25,24 +27,24 @@ def __call__(
) -> Dict:
raise NotImplementedError

def _add_pixel_bounds(self,messages:List[Dict]) -> List[Dict]:
DEFAULT_MIN_PIXELS = self.min_pixels
DEFAULT_MAX_PIXELS = self.max_pixels

def process_content(content):
if isinstance(content, list):
for item in content:
if isinstance(item, dict) and item.get("type") == "image":
if "min_pixels" not in item:
item["min_pixels"] = DEFAULT_MIN_PIXELS
if "max_pixels" not in item:
item["max_pixels"] = DEFAULT_MAX_PIXELS
return content

for message in messages:
for msg in message:
msg["content"] = process_content(msg["content"])
return messages
def _add_pixel_bounds(self, messages: List[Dict]) -> List[Dict]:
DEFAULT_MIN_PIXELS = self.min_pixels
DEFAULT_MAX_PIXELS = self.max_pixels

def process_content(content):
if isinstance(content, list):
for item in content:
if isinstance(item, dict) and item.get("type") == "image":
if "min_pixels" not in item:
item["min_pixels"] = DEFAULT_MIN_PIXELS
if "max_pixels" not in item:
item["max_pixels"] = DEFAULT_MAX_PIXELS
return content

for message in messages:
for msg in message:
msg["content"] = process_content(msg["content"])
return messages

@abstractmethod
def make_input_batch(self, inputs: List[Dict]) -> Dict:
Expand Down Expand Up @@ -70,19 +72,16 @@ def apply_chat_template(
add_generation_prompt: bool = True,
) -> List[str]:
messages = self._format_messages(messages)

return self.processor.apply_chat_template(
messages, tokenize=tokenize, add_generation_prompt=add_generation_prompt
)

def get_images_from_messages(
self, messages: Union[Dict, List[str], str]
) -> List[Dict]:
def get_images_from_messages(self, messages: Union[Dict, List[str], str]) -> List[Dict]:
messages = self._format_messages(messages)
image_inputs, _ = process_vision_info(messages)
return image_inputs


@property
def pad_token_id(self) -> int:
return self.processor.tokenizer.pad_token_id
Expand All @@ -93,4 +92,4 @@ def eos_token_id(self) -> int:

@property
def tokenizer(self):
return self.processor.tokenizer
return self.processor.tokenizer
24 changes: 13 additions & 11 deletions openrlhf/models/lmm_kits/base/patch.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,40 @@
from abc import ABC, abstractmethod


class BasePatch(ABC):
def __init__(self):
self.loaded = False

@abstractmethod
def _add_get_inputs_embeds():
'''
Add a `get_inputs_embeds(*args,**kwargs)` method to the model class,
"""
Add a `get_inputs_embeds(*args,**kwargs)` method to the model class,
which embeds image embeddings into the text embeddings and returns the result.
'''
"""
return NotImplementedError

@abstractmethod
def _add_get_position_ids():
'''
Add a `get_position_ids(*args,**kwargs)` method to the model class,
"""
Add a `get_position_ids(*args,**kwargs)` method to the model class,
which returns the position_ids of the given inputs.
'''
"""
return NotImplementedError

@abstractmethod
def _add_offset_split_position_ids():
'''
Add a `offset_split_position_ids(*args,**kwargs)` method to the model class,
"""
Add a `offset_split_position_ids(*args,**kwargs)` method to the model class,
which offsets the split position_ids to the true position_ids.
'''
"""
return NotImplementedError

@classmethod
@abstractmethod
def _load_all_patches(cls):
'''
"""
Load all patches.
'''
"""
return NotImplementedError

def load_all_patches(self):
Expand Down
Loading