diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index fa5e59cf..2b2b10df 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,7 +8,7 @@ ci:
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.5.0
+ rev: v6.0.0
hooks:
- id: check-yaml
- id: check-case-conflict
@@ -18,20 +18,20 @@ repos:
- id: requirements-txt-fixer
- repo: https://github.com/PyCQA/autoflake
- rev: v2.0.2
+ rev: v2.3.3
hooks:
- id: autoflake
args: [--remove-all-unused-imports, --in-place]
- repo: https://github.com/PyCQA/isort
- rev: 5.13.2
+ rev: 8.0.1
hooks:
- id: isort
name: Format imports
exclude: docs/
- - repo: https://github.com/psf/black
- rev: 24.3.0
+ - repo: https://github.com/psf/black-pre-commit-mirror
+ rev: 26.3.1
hooks:
- id: black
name: Format code
diff --git a/openrlhf/cli/train_dpo.py b/openrlhf/cli/train_dpo.py
index a43b1f2b..9703b68a 100644
--- a/openrlhf/cli/train_dpo.py
+++ b/openrlhf/cli/train_dpo.py
@@ -119,7 +119,7 @@ def train(args):
)
# strategy prepare
- ((model, optim, scheduler), ref_model) = strategy.prepare((model, optim, scheduler), ref_model)
+ (model, optim, scheduler), ref_model = strategy.prepare((model, optim, scheduler), ref_model)
# load checkpoint
consumed_samples = 0
diff --git a/openrlhf/cli/train_kd.py b/openrlhf/cli/train_kd.py
index ca30038d..9862841c 100644
--- a/openrlhf/cli/train_kd.py
+++ b/openrlhf/cli/train_kd.py
@@ -104,7 +104,7 @@ def train(args):
)
# prepare models
- ((model, optim, scheduler), teacher_model) = strategy.prepare((model, optim, scheduler), teacher_model)
+ (model, optim, scheduler), teacher_model = strategy.prepare((model, optim, scheduler), teacher_model)
# load checkpoint
consumed_samples = 0
diff --git a/openrlhf/cli/train_kto.py b/openrlhf/cli/train_kto.py
index e5b343f5..ded8d09e 100644
--- a/openrlhf/cli/train_kto.py
+++ b/openrlhf/cli/train_kto.py
@@ -99,7 +99,7 @@ def train(args):
)
# strategy prepare
- ((model, optim, scheduler), ref_model) = strategy.prepare((model, optim, scheduler), ref_model)
+ (model, optim, scheduler), ref_model = strategy.prepare((model, optim, scheduler), ref_model)
# load checkpoint
consumed_samples = 0
diff --git a/openrlhf/cli/train_ppo.py b/openrlhf/cli/train_ppo.py
index f19bdccd..488fb989 100644
--- a/openrlhf/cli/train_ppo.py
+++ b/openrlhf/cli/train_ppo.py
@@ -344,8 +344,8 @@ def train(args):
)
parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer")
parser.add_argument("--reward_clip_range", type=float, nargs=2, default=(-10, 10), help="Reward clip range")
- parser.add_argument("--max_pixels",type=int,default=640*28*28,help="Max pixels for image")
- parser.add_argument("--min_pixels",type=int,default=4*28*28,help="Min pixels for image")
+ parser.add_argument("--max_pixels", type=int, default=640 * 28 * 28, help="Max pixels for image")
+ parser.add_argument("--min_pixels", type=int, default=4 * 28 * 28, help="Min pixels for image")
# DeepSpeed
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed")
@@ -362,8 +362,12 @@ def train(args):
parser.add_argument("--overlap_comm", action="store_true", default=False)
parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False)
parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False)
- parser.add_argument("--freeze_prefix", type=str, nargs="+", default=None,
- help="List of parameter name prefixes to freeze during training"
+ parser.add_argument(
+ "--freeze_prefix",
+ type=str,
+ nargs="+",
+ default=None,
+ help="List of parameter name prefixes to freeze during training",
)
parser.add_argument("--drop_maxlen", action="store_true", default=False)
diff --git a/openrlhf/cli/train_ppo_ray.py b/openrlhf/cli/train_ppo_ray.py
index 631765a5..dcbc8b86 100644
--- a/openrlhf/cli/train_ppo_ray.py
+++ b/openrlhf/cli/train_ppo_ray.py
@@ -337,12 +337,16 @@ def train(args):
parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss")
parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer")
parser.add_argument("--reward_clip_range", type=float, nargs=2, default=(-10, 10), help="Reward clip range")
- parser.add_argument("--freeze_prefix", type=str, nargs="+", default=None,
- help="List of parameter name prefixes to freeze during training"
+ parser.add_argument(
+ "--freeze_prefix",
+ type=str,
+ nargs="+",
+ default=None,
+ help="List of parameter name prefixes to freeze during training",
)
parser.add_argument("--drop_maxlen", action="store_true", default=False)
- parser.add_argument("--max_pixels",type=int,default=640*28*28,help="Max pixels for image")
- parser.add_argument("--min_pixels",type=int,default=4*28*28,help="Min pixels for image")
+ parser.add_argument("--max_pixels", type=int, default=640 * 28 * 28, help="Max pixels for image")
+ parser.add_argument("--min_pixels", type=int, default=4 * 28 * 28, help="Min pixels for image")
# Reinforce
parser.add_argument(
"--advantage_estimator",
diff --git a/openrlhf/cli/train_prm.py b/openrlhf/cli/train_prm.py
index b8d8278c..284ffab6 100644
--- a/openrlhf/cli/train_prm.py
+++ b/openrlhf/cli/train_prm.py
@@ -87,7 +87,7 @@ def train(args):
)
# prepare models
- (model, optim, scheduler) = strategy.prepare((model, optim, scheduler))
+ model, optim, scheduler = strategy.prepare((model, optim, scheduler))
# load checkpoint
consumed_samples = 0
diff --git a/openrlhf/cli/train_rm.py b/openrlhf/cli/train_rm.py
index eb6e4f84..ded8243d 100644
--- a/openrlhf/cli/train_rm.py
+++ b/openrlhf/cli/train_rm.py
@@ -106,7 +106,7 @@ def train(args):
)
# strategy prepare
- (model, optim, scheduler) = strategy.prepare((model, optim, scheduler))
+ model, optim, scheduler = strategy.prepare((model, optim, scheduler))
# load checkpoint
consumed_samples = 0
diff --git a/openrlhf/cli/train_sft.py b/openrlhf/cli/train_sft.py
index cb6e7619..3ef4fa67 100644
--- a/openrlhf/cli/train_sft.py
+++ b/openrlhf/cli/train_sft.py
@@ -105,7 +105,7 @@ def train(args):
)
# prepare models
- (model, optim, scheduler) = strategy.prepare((model, optim, scheduler))
+ model, optim, scheduler = strategy.prepare((model, optim, scheduler))
# load checkpoint
consumed_samples = 0
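
The `strategy.prepare` hunks above (train_dpo, train_kd, train_kto, train_prm, train_rm, train_sft) are purely stylistic: Python unpacks nested tuples the same way with or without the outer parentheses. A minimal sketch, with a stand-in `prepare` that, like `strategy.prepare`, returns its arguments in the nesting it received them:

```python
# Stand-in for strategy.prepare: returns wrapped objects with the same nesting.
def prepare(model_group, ref_model):
    return model_group, ref_model

((m1, o1, s1), ref1) = prepare(("model", "optim", "sched"), "ref")  # old spelling
(m2, o2, s2), ref2 = prepare(("model", "optim", "sched"), "ref")    # new spelling
assert (m1, o1, s1, ref1) == (m2, o2, s2, ref2)  # identical unpacking
```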
diff --git a/openrlhf/models/actor.py b/openrlhf/models/actor.py
index 1d63fc2d..936560a1 100644
--- a/openrlhf/models/actor.py
+++ b/openrlhf/models/actor.py
@@ -3,17 +3,18 @@
import torch
import torch.distributed as dist
import torch.nn as nn
-from torch.nn import functional as F
+from flash_attn.utils.distributed import all_gather
from peft import LoraConfig, TaskType, get_peft_model
from peft.tuners.lora import LoraLayer
-from transformers import BitsAndBytesConfig, AutoConfig
+from torch.nn import functional as F
+from transformers import AutoConfig, BitsAndBytesConfig
from transformers.integrations.deepspeed import HfDeepSpeedConfig
-from flash_attn.utils.distributed import all_gather
-from .ring_attn_utils import convert_ring_attn_params, set_hacked_position_ids, clear_hacked_position_ids
-from .utils import log_probs_from_logits, reset_position_ids
from openrlhf.models.lmm_kits.utils import get_generation_cls
+from .ring_attn_utils import clear_hacked_position_ids, convert_ring_attn_params, set_hacked_position_ids
+from .utils import log_probs_from_logits, reset_position_ids
+
class Actor(nn.Module):
"""
@@ -73,7 +74,7 @@ def __init__(
else:
nf4_config = None
- #There is no AutoModelForConditionalGeneration in transformers. We manually implement it.
+ # There is no AutoModelForConditionalGeneration in transformers. We manually implement it.
config = AutoConfig.from_pretrained(pretrain_or_model)
model_cls = get_generation_cls(config)
self.model = model_cls.from_pretrained(
@@ -200,27 +201,31 @@ def forward(
"""Returns action log probs"""
if visual_inputs is None:
visual_inputs = {}
- '''
+ """
for k,v in visual_inputs.items():
if v.dtype == torch.float32:
visual_inputs[k] = v.to(self.model.get_input_embeddings().weight.dtype)
- '''
+ """
inputs_embeds = self.model.get_inputs_embeds(sequences, **visual_inputs)
if not self.packing_samples:
# https://github.com/OpenRLHF/OpenRLHF/issues/217
- #position_ids = attention_mask.long().cumsum(-1) - 1
- #position_ids.masked_fill_(attention_mask == 0, 1)
- position_ids = self.model.get_position_ids(sequences,attention_mask=attention_mask, **visual_inputs)
+ # position_ids = attention_mask.long().cumsum(-1) - 1
+ # position_ids.masked_fill_(attention_mask == 0, 1)
+ position_ids = self.model.get_position_ids(sequences, attention_mask=attention_mask, **visual_inputs)
else:
# convert attention_mask to position_ids
packed_position_ids = self.model.get_position_ids(sequences, **visual_inputs)
if ring_attn_group is not None:
labels = sequences
- sequences, attention_mask, hacked_position_ids, inputs_embeds, split_position_ids = convert_ring_attn_params(
- sequences, attention_mask, packed_seq_lens, ring_attn_group, inputs_embeds, packed_position_ids
+ sequences, attention_mask, hacked_position_ids, inputs_embeds, split_position_ids = (
+ convert_ring_attn_params(
+ sequences, attention_mask, packed_seq_lens, ring_attn_group, inputs_embeds, packed_position_ids
+ )
)
- position_ids = self.model.offset_split_position_ids(split_position_ids, hacked_position_ids) # this is true position_ids
- #position_ids is directly hacked into flash_attn_forward to distinguish between different sequences
+ position_ids = self.model.offset_split_position_ids(
+ split_position_ids, hacked_position_ids
+ ) # this is true position_ids
+ # position_ids is directly hacked into flash_attn_forward to distinguish between different sequences
else:
hacked_position_ids = reset_position_ids(attention_mask)
position_ids = self.model.offset_split_position_ids(packed_position_ids, hacked_position_ids)
@@ -228,7 +233,9 @@ def forward(
set_hacked_position_ids(hacked_position_ids)
# explicitly ignore attention_mask for packing_samples
attention_mask = None
- output = self.model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, **visual_inputs)
+ output = self.model(
+ inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, **visual_inputs
+ )
clear_hacked_position_ids()
# https://github.com/OpenRLHF/OpenRLHF/pull/634
output["logits"] = output["logits"].to(torch.float32)
diff --git a/openrlhf/models/lmm_kits/base/data_processor.py b/openrlhf/models/lmm_kits/base/data_processor.py
index 40dbf835..2531c651 100644
--- a/openrlhf/models/lmm_kits/base/data_processor.py
+++ b/openrlhf/models/lmm_kits/base/data_processor.py
@@ -1,17 +1,19 @@
import json
-import os
from abc import ABC, abstractmethod
-from typing import List, Optional, Union, Dict
+from typing import Dict, List, Optional, Union
+
import torch
-from transformers.processing_utils import ProcessorMixin
from qwen_vl_utils import process_vision_info
+from transformers.processing_utils import ProcessorMixin
+
class BaseDataProcessor(ABC):
- def __init__(self, processor: ProcessorMixin,min_pixels:int,max_pixels:int):
+ def __init__(self, processor: ProcessorMixin, min_pixels: int, max_pixels: int):
super().__init__()
self.processor = processor
self.min_pixels = min_pixels
self.max_pixels = max_pixels
+
@abstractmethod
def __call__(
self,
@@ -25,24 +27,24 @@ def __call__(
) -> Dict:
raise NotImplementedError
- def _add_pixel_bounds(self,messages:List[Dict]) -> List[Dict]:
- DEFAULT_MIN_PIXELS = self.min_pixels
- DEFAULT_MAX_PIXELS = self.max_pixels
-
- def process_content(content):
- if isinstance(content, list):
- for item in content:
- if isinstance(item, dict) and item.get("type") == "image":
- if "min_pixels" not in item:
- item["min_pixels"] = DEFAULT_MIN_PIXELS
- if "max_pixels" not in item:
- item["max_pixels"] = DEFAULT_MAX_PIXELS
- return content
-
- for message in messages:
- for msg in message:
- msg["content"] = process_content(msg["content"])
- return messages
+ def _add_pixel_bounds(self, messages: List[Dict]) -> List[Dict]:
+ DEFAULT_MIN_PIXELS = self.min_pixels
+ DEFAULT_MAX_PIXELS = self.max_pixels
+
+ def process_content(content):
+ if isinstance(content, list):
+ for item in content:
+ if isinstance(item, dict) and item.get("type") == "image":
+ if "min_pixels" not in item:
+ item["min_pixels"] = DEFAULT_MIN_PIXELS
+ if "max_pixels" not in item:
+ item["max_pixels"] = DEFAULT_MAX_PIXELS
+ return content
+
+ for message in messages:
+ for msg in message:
+ msg["content"] = process_content(msg["content"])
+ return messages
@abstractmethod
def make_input_batch(self, inputs: List[Dict]) -> Dict:
@@ -70,19 +72,16 @@ def apply_chat_template(
add_generation_prompt: bool = True,
) -> List[str]:
messages = self._format_messages(messages)
-
+
return self.processor.apply_chat_template(
messages, tokenize=tokenize, add_generation_prompt=add_generation_prompt
)
- def get_images_from_messages(
- self, messages: Union[Dict, List[str], str]
- ) -> List[Dict]:
+ def get_images_from_messages(self, messages: Union[Dict, List[str], str]) -> List[Dict]:
messages = self._format_messages(messages)
image_inputs, _ = process_vision_info(messages)
return image_inputs
-
@property
def pad_token_id(self) -> int:
return self.processor.tokenizer.pad_token_id
@@ -93,4 +92,4 @@ def eos_token_id(self) -> int:
@property
def tokenizer(self):
- return self.processor.tokenizer
\ No newline at end of file
+ return self.processor.tokenizer
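
The re-indented `_add_pixel_bounds` above only fills in defaults. A self-contained sketch of the same logic; the image path is hypothetical, and the defaults mirror the `--min_pixels`/`--max_pixels` CLI flags added in this PR:

```python
DEFAULT_MIN_PIXELS = 4 * 28 * 28
DEFAULT_MAX_PIXELS = 640 * 28 * 28

def add_pixel_bounds(messages):
    # Inject default pixel bounds into every image item that lacks them.
    for message in messages:
        for msg in message:
            content = msg["content"]
            if isinstance(content, list):
                for item in content:
                    if isinstance(item, dict) and item.get("type") == "image":
                        item.setdefault("min_pixels", DEFAULT_MIN_PIXELS)
                        item.setdefault("max_pixels", DEFAULT_MAX_PIXELS)
    return messages

messages = [[{
    "role": "user",
    "content": [
        {"type": "image", "image": "file:///tmp/demo.png"},
        {"type": "text", "text": "Describe the image."},
    ],
}]]
print(add_pixel_bounds(messages)[0][0]["content"][0]["min_pixels"])  # 3136
```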
diff --git a/openrlhf/models/lmm_kits/base/patch.py b/openrlhf/models/lmm_kits/base/patch.py
index 561250d5..843aa6e2 100644
--- a/openrlhf/models/lmm_kits/base/patch.py
+++ b/openrlhf/models/lmm_kits/base/patch.py
@@ -1,38 +1,40 @@
from abc import ABC, abstractmethod
+
class BasePatch(ABC):
def __init__(self):
self.loaded = False
+
@abstractmethod
def _add_get_inputs_embeds():
- '''
- Add a `get_inputs_embeds(*args,**kwargs)` method to the model class,
+ """
+ Add a `get_inputs_embeds(*args,**kwargs)` method to the model class,
which embeds image embeddings into the text embeddings and return the results.
- '''
+ """
return NotImplementedError
@abstractmethod
def _add_get_position_ids():
- '''
- Add a `get_posiiton_ids(*args,**kwargs)` method to the model class,
+ """
+    Add a `get_position_ids(*args,**kwargs)` method to the model class,
which return the position_ids of the given inputs.
- '''
+ """
return NotImplementedError
@abstractmethod
def _add_offset_split_position_ids():
- '''
- Add a `offset_split_position_ids(*args,**kwargs)` method to the model class,
+ """
+ Add a `offset_split_position_ids(*args,**kwargs)` method to the model class,
which offset the split position_ids to true position_ids.
- '''
+ """
return NotImplementedError
@classmethod
@abstractmethod
def _load_all_patches(cls):
- '''
+ """
Load all patches.
- '''
+ """
return NotImplementedError
def load_all_patches(self):
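
To make the `BasePatch` contract concrete: each subclass attaches helper methods to an existing model class, and `load_all_patches` guards against double-loading. A toy illustration with made-up classes, not the real Qwen patch:

```python
class ToyModel:
    pass

class ToyPatch:
    def __init__(self):
        self.loaded = False

    @staticmethod
    def _add_get_inputs_embeds():
        def get_inputs_embeds(self, input_ids, **kwargs):
            return input_ids  # stand-in for the real embedding lookup

        ToyModel.get_inputs_embeds = get_inputs_embeds

    @classmethod
    def _load_all_patches(cls):
        cls._add_get_inputs_embeds()

    def load_all_patches(self):
        if not self.loaded:
            self._load_all_patches()
            self.loaded = True

ToyPatch().load_all_patches()
print(ToyModel().get_inputs_embeds([1, 2, 3]))  # [1, 2, 3]
```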
diff --git a/openrlhf/models/lmm_kits/qwen2_5_vl/data_processor.py b/openrlhf/models/lmm_kits/qwen2_5_vl/data_processor.py
index e3ea7d1e..3f3aea31 100644
--- a/openrlhf/models/lmm_kits/qwen2_5_vl/data_processor.py
+++ b/openrlhf/models/lmm_kits/qwen2_5_vl/data_processor.py
@@ -1,5 +1,4 @@
-import os
-from typing import List, Dict
+from typing import Dict, List
import torch
from qwen_vl_utils import process_vision_info
@@ -20,9 +19,7 @@ def __call__(
) -> Dict:
messages = self._format_messages(messages)
processor = self.processor
- texts = processor.apply_chat_template(
- messages, tokenize=False, add_generation_prompt=True
- )
+ texts = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
batch = processor(
@@ -44,7 +41,7 @@ def make_input_batch(self, inputs: List[Dict]) -> Dict:
batch = {}
# collect all keys
for inp in inputs:
- batch.update({k:None for k,v in inp.items() if v is not None})
+ batch.update({k: None for k, v in inp.items() if v is not None})
for k in batch.keys():
if k in ["input_ids", "attention_mask"]:
batch[k] = torch.stack([inp[k] for inp in inputs if k in inp], dim=0)
@@ -67,12 +64,8 @@ def split_input_batch(self, batch: Dict) -> List[Dict]:
for i in range(batch_size):
batch_kwargs[i][k] = None
- if "pixel_values" in keys and (
- "input_ids" not in keys or "image_grid_thw" not in keys
- ):
- raise ValueError(
- "Cannot split batch with pixel_values without input_ids and image_grid_thw"
- )
+ if "pixel_values" in keys and ("input_ids" not in keys or "image_grid_thw" not in keys):
+ raise ValueError("Cannot split batch with pixel_values without input_ids and image_grid_thw")
if "image_grid_thw" in keys and ("input_ids" not in keys):
raise ValueError("Cannot split batch with image_grid_thw without input_ids")
for k in ["input_ids", "attention_mask"]:
@@ -114,7 +107,8 @@ def split_input_batch(self, batch: Dict) -> List[Dict]:
assert len(thws) == 0
assert len(pixel_values) == 0
return batch_kwargs
-
+
+
DataProcessor = Qwen2_5_VLDataProcessor
-__all__ = ["DataProcessor"]
\ No newline at end of file
+__all__ = ["DataProcessor"]
diff --git a/openrlhf/models/lmm_kits/qwen2_5_vl/patch.py b/openrlhf/models/lmm_kits/qwen2_5_vl/patch.py
index 83c4a4e5..4c05bb4b 100644
--- a/openrlhf/models/lmm_kits/qwen2_5_vl/patch.py
+++ b/openrlhf/models/lmm_kits/qwen2_5_vl/patch.py
@@ -1,10 +1,21 @@
-from ..base.patch import BasePatch
import torch
+from ..base.patch import BasePatch
+
+
class Qwen2_5_VLPatch(BasePatch):
def _add_get_inputs_embeds():
from transformers import Qwen2_5_VLForConditionalGeneration
- def get_inputs_embeds(self, input_ids, image_grid_thw=None, video_grid_thw=None, pixel_values=None, pixel_values_videos=None, **kwargs):
+
+ def get_inputs_embeds(
+ self,
+ input_ids,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ pixel_values=None,
+ pixel_values_videos=None,
+ **kwargs,
+ ):
inputs_embeds = self.model.embed_tokens(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.dtype)
@@ -47,33 +58,47 @@ def get_inputs_embeds(self, input_ids, image_grid_thw=None, video_grid_thw=None,
def _add_get_position_ids():
from transformers import Qwen2_5_VLForConditionalGeneration
+
def get_position_ids(self, input_ids, image_grid_thw=None, video_grid_thw=None, attention_mask=None, **kwargs):
- position_ids,mrope_position_deltas = self.get_rope_index(input_ids=input_ids, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, attention_mask=attention_mask)
+ position_ids, mrope_position_deltas = self.get_rope_index(
+ input_ids=input_ids,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ attention_mask=attention_mask,
+ )
return position_ids
+
Qwen2_5_VLForConditionalGeneration.get_position_ids = get_position_ids
def _add_offset_split_position_ids():
from transformers import Qwen2_5_VLForConditionalGeneration
- def offset_split_position_ids(self,position_ids,hacked_position_ids):
+
+ def offset_split_position_ids(self, position_ids, hacked_position_ids):
new_position_ids = position_ids.clone()
for i in range(hacked_position_ids.size(0)):
- seq_idxes = torch.nonzero(hacked_position_ids[i]==0)[:,0]
- seq_idxes = torch.cat([seq_idxes, torch.tensor([hacked_position_ids.size(1)],device=seq_idxes.device)], dim=0)
+ seq_idxes = torch.nonzero(hacked_position_ids[i] == 0)[:, 0]
+ seq_idxes = torch.cat(
+ [seq_idxes, torch.tensor([hacked_position_ids.size(1)], device=seq_idxes.device)], dim=0
+ )
st = 0
for seq_idx in seq_idxes:
if st == 0 and seq_idx == 0:
continue
- #shape: [3,bs,seq_len]
- raw_seq_position_ids = position_ids[:,i,st:seq_idx]
- new_position_ids[:,i,st:seq_idx] = raw_seq_position_ids - raw_seq_position_ids[:,:1] + hacked_position_ids[i,st]
+ # shape: [3,bs,seq_len]
+ raw_seq_position_ids = position_ids[:, i, st:seq_idx]
+ new_position_ids[:, i, st:seq_idx] = (
+ raw_seq_position_ids - raw_seq_position_ids[:, :1] + hacked_position_ids[i, st]
+ )
st = seq_idx
return new_position_ids
+
Qwen2_5_VLForConditionalGeneration.offset_split_position_ids = offset_split_position_ids
-
+
@classmethod
def _load_all_patches(cls):
cls._add_get_inputs_embeds()
cls._add_get_position_ids()
cls._add_offset_split_position_ids()
-Patch = Qwen2_5_VLPatch()
\ No newline at end of file
+
+Patch = Qwen2_5_VLPatch()
diff --git a/openrlhf/models/lmm_kits/utils.py b/openrlhf/models/lmm_kits/utils.py
index 1227480e..6cae1454 100644
--- a/openrlhf/models/lmm_kits/utils.py
+++ b/openrlhf/models/lmm_kits/utils.py
@@ -1,18 +1,22 @@
-from transformers import AutoConfig, AutoProcessor, AutoModel
import importlib
-import os
-def _get_kit_root_path(pretrain_or_model=None,model_type=None):
+from transformers import AutoConfig, AutoModel, AutoProcessor
+
+
+def _get_kit_root_path(pretrain_or_model=None, model_type=None):
if model_type is None:
config = AutoConfig.from_pretrained(pretrain_or_model)
model_type = config.model_type
root_path = f".models.lmm_kits.{model_type}"
return root_path
+
def _get_hf_processor(pretrain, model, padding_side="left", strategy=None, use_fast=True):
min_pixels = strategy.args.min_pixels
max_pixels = strategy.args.max_pixels
- processor = AutoProcessor.from_pretrained(pretrain, trust_remote_code=True, use_fast=use_fast, min_pixels=min_pixels, max_pixels=max_pixels)
+ processor = AutoProcessor.from_pretrained(
+ pretrain, trust_remote_code=True, use_fast=use_fast, min_pixels=min_pixels, max_pixels=max_pixels
+ )
tokenizer = processor.tokenizer
tokenizer.padding_side = padding_side
# NOTE: When enable vLLM, do not resize_token_embeddings, or the vocab size will mismatch with vLLM.
@@ -23,34 +27,41 @@ def _get_hf_processor(pretrain, model, padding_side="left", strategy=None, use_f
model.config.pad_token_id = tokenizer.pad_token_id
return processor
+
def get_data_processor(pretrain_or_model, model, padding_side="left", strategy=None, use_fast=True):
root_path = _get_kit_root_path(pretrain_or_model)
- module = importlib.import_module(f"{root_path}.data_processor",package="openrlhf")
+ module = importlib.import_module(f"{root_path}.data_processor", package="openrlhf")
data_processor_cls = getattr(module, "DataProcessor")
- hf_processor = _get_hf_processor(pretrain_or_model, model, padding_side, strategy,use_fast=use_fast)
- data_processor = data_processor_cls(hf_processor,min_pixels=strategy.args.min_pixels,max_pixels=strategy.args.max_pixels)
+ hf_processor = _get_hf_processor(pretrain_or_model, model, padding_side, strategy, use_fast=use_fast)
+ data_processor = data_processor_cls(
+ hf_processor, min_pixels=strategy.args.min_pixels, max_pixels=strategy.args.max_pixels
+ )
return data_processor
-def load_patch(pretrain_or_model=None,model_type=None):
- root_path = _get_kit_root_path(pretrain_or_model,model_type)
- module = importlib.import_module(f"{root_path}.patch",package="openrlhf")
+
+def load_patch(pretrain_or_model=None, model_type=None):
+ root_path = _get_kit_root_path(pretrain_or_model, model_type)
+ module = importlib.import_module(f"{root_path}.patch", package="openrlhf")
Patch = getattr(module, "Patch")
Patch.load_all_patches()
+
def get_generation_cls(config):
model_type = config.model_type
load_patch(model_type=model_type)
model_arch = AutoModel._model_mapping[type(config)].__name__
- if model_arch.endswith("ForCausalLM") or \
- model_arch.endswith("ForConditionalGeneration"):
+ if model_arch.endswith("ForCausalLM") or model_arch.endswith("ForConditionalGeneration"):
return AutoModel._model_mapping[type(config)]
elif model_arch.endswith("Model"):
- possible_arch = [model_arch.replace("Model", "ForCausalLM"), model_arch.replace("Model", "ForConditionalGeneration")]
- module = importlib.import_module(f".models.{model_type}.modeling_{model_type}",package="transformers")
+ possible_arch = [
+ model_arch.replace("Model", "ForCausalLM"),
+ model_arch.replace("Model", "ForConditionalGeneration"),
+ ]
+ module = importlib.import_module(f".models.{model_type}.modeling_{model_type}", package="transformers")
for arch in possible_arch:
model_cls = getattr(module, arch, None)
if model_cls is not None:
return model_cls
raise ValueError(f"Cannot find ForCausalLM or ForConditionalGeneration class for {model_arch}")
else:
- raise ValueError(f"Unexpected model architecture {model_arch}")
\ No newline at end of file
+ raise ValueError(f"Unexpected model architecture {model_arch}")
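
Putting `get_generation_cls` to use looks roughly like this (hedged sketch; the model id is illustrative):

```python
from transformers import AutoConfig

from openrlhf.models.lmm_kits.utils import get_generation_cls

config = AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
model_cls = get_generation_cls(config)  # e.g. Qwen2_5_VLForConditionalGeneration
model = model_cls.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
```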
diff --git a/openrlhf/models/model.py b/openrlhf/models/model.py
index 6b7dd054..3477fe0b 100644
--- a/openrlhf/models/model.py
+++ b/openrlhf/models/model.py
@@ -6,14 +6,14 @@
from flash_attn.utils.distributed import all_gather
from peft import LoraConfig, get_peft_model
from peft.tuners.lora import LoraLayer
-from transformers import AutoConfig, AutoModel, BitsAndBytesConfig
+from transformers import AutoConfig, BitsAndBytesConfig
from transformers.integrations.deepspeed import HfDeepSpeedConfig
+from openrlhf.models.lmm_kits.utils import get_generation_cls
from openrlhf.utils.logging_utils import init_logger
-from .ring_attn_utils import convert_ring_attn_params, set_hacked_position_ids, clear_hacked_position_ids
+from .ring_attn_utils import clear_hacked_position_ids, convert_ring_attn_params, set_hacked_position_ids
from .utils import reset_position_ids
-from openrlhf.models.lmm_kits.utils import get_generation_cls
logger = init_logger(__name__)
@@ -193,18 +193,25 @@ def forward(
inputs_embeds = super().get_inputs_embeds(input_ids, **visual_inputs)
if not self.packing_samples:
# https://github.com/OpenRLHF/OpenRLHF/issues/217
- #position_ids = attention_mask.long().cumsum(-1) - 1
- #position_ids.masked_fill_(attention_mask == 0, 1)
- position_ids = super().get_position_ids(input_ids,attention_mask=attention_mask, **visual_inputs)
+ # position_ids = attention_mask.long().cumsum(-1) - 1
+ # position_ids.masked_fill_(attention_mask == 0, 1)
+ position_ids = super().get_position_ids(input_ids, attention_mask=attention_mask, **visual_inputs)
else:
# convert attention_mask to position_ids
packed_position_ids = super().get_position_ids(input_ids, **visual_inputs)
if ring_attn_group is not None:
- input_ids, attention_mask, hacked_position_ids, inputs_embeds, split_position_ids = convert_ring_attn_params(
- input_ids, attention_mask, packed_seq_lens, ring_attn_group, inputs_embeds, packed_position_ids
+ input_ids, attention_mask, hacked_position_ids, inputs_embeds, split_position_ids = (
+ convert_ring_attn_params(
+ input_ids,
+ attention_mask,
+ packed_seq_lens,
+ ring_attn_group,
+ inputs_embeds,
+ packed_position_ids,
+ )
)
position_ids = super().offset_split_position_ids(split_position_ids, hacked_position_ids)
-
+
else:
hacked_position_ids = reset_position_ids(attention_mask)
position_ids = super().offset_split_position_ids(packed_position_ids, hacked_position_ids)
@@ -214,7 +221,11 @@ def forward(
attention_mask = None
outputs = super().forward(
- inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids,output_hidden_states=True, **visual_inputs
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_hidden_states=True,
+ **visual_inputs,
)
clear_hacked_position_ids()
if "last_hidden_state" in outputs:
@@ -288,17 +299,24 @@ def forward(
# https://github.com/OpenRLHF/OpenRLHF/issues/217
# position_ids = attention_mask.long().cumsum(-1) - 1
# position_ids.masked_fill_(attention_mask == 0, 1)
- position_ids = super().get_position_ids(input_ids,attention_mask=attention_mask, **visual_inputs)
+ position_ids = super().get_position_ids(input_ids, attention_mask=attention_mask, **visual_inputs)
else:
# convert attention_mask to position_ids
packed_position_ids = super().get_position_ids(input_ids, **visual_inputs)
if ring_attn_group is not None:
-
- input_ids, attention_mask, hacked_position_ids, inputs_embeds, split_position_ids = convert_ring_attn_params(
- input_ids, attention_mask, packed_seq_lens, ring_attn_group, inputs_embeds, packed_position_ids
+
+ input_ids, attention_mask, hacked_position_ids, inputs_embeds, split_position_ids = (
+ convert_ring_attn_params(
+ input_ids,
+ attention_mask,
+ packed_seq_lens,
+ ring_attn_group,
+ inputs_embeds,
+ packed_position_ids,
+ )
)
position_ids = super().offset_split_position_ids(split_position_ids, hacked_position_ids)
-
+
else:
hacked_position_ids = reset_position_ids(attention_mask)
position_ids = super().offset_split_position_ids(packed_position_ids, hacked_position_ids)
@@ -308,7 +326,11 @@ def forward(
attention_mask = None
outputs = super().forward(
- inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids,output_hidden_states=True, **visual_inputs
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_hidden_states=True,
+ **visual_inputs,
)
clear_hacked_position_ids()
if "last_hidden_state" in outputs:
diff --git a/openrlhf/models/remote_rm/math_verifier.py b/openrlhf/models/remote_rm/math_verifier.py
index 232e0619..d39ce27f 100644
--- a/openrlhf/models/remote_rm/math_verifier.py
+++ b/openrlhf/models/remote_rm/math_verifier.py
@@ -1,17 +1,15 @@
import json
-import os
import random
import re
from argparse import ArgumentParser
-from multiprocessing import Process, Queue
+from concurrent import futures
import Levenshtein
from flask import Flask, jsonify, request
from latex2sympy2_extended import NormalizationConfig
+from loguru import logger
from math_verify import LatexExtractionConfig, parse, verify
-from loguru import logger
-from concurrent import futures
app = Flask(__name__)
problem_to_answer = {}
@@ -40,7 +38,6 @@ def verify_format(content):
return bool(re.match(format_pattern, content, re.DOTALL)) and think_count == 1 and answer_count == 1
-
def find_similar_problem(problem):
max_sim = -1
target_problem = None
@@ -52,7 +49,7 @@ def find_similar_problem(problem):
return target_problem
-def verify_math(content,sol):
+def verify_math(content, sol):
gold_parsed = parse(
sol,
extraction_mode="first_match",
@@ -102,7 +99,7 @@ def get_reward():
rewards = []
format_rewards = []
acc_rewards_futures = []
- for q,problem in zip(data["query"],data["prompts"]):
+ for q, problem in zip(data["query"], data["prompts"]):
if problem is None:
return jsonify({"error": f"problem not found from {q}"}), 400
if problem not in problem_to_answer:
@@ -115,39 +112,33 @@ def get_reward():
return jsonify({"error": f"response not found from {q}"}), 400
format_reward = float(verify_format(response)) * 0.5
acc_reward_future = math_verify_executor.submit(verify_math, response, answer)
-
+
do_print = random.randint(1, 20) == 1
if do_print:
- info=f"Query: {q}\n\nProblem: {problem}\n\n Answer: {answer}\n\n Response: {response}\n\n Format Reward: {format_reward}\n\n Acc Reward: {acc_reward_future.result()}\n\n"
- info = re.sub(r"<\|.*?\|>","",info)
+ info = f"Query: {q}\n\nProblem: {problem}\n\n Answer: {answer}\n\n Response: {response}\n\n Format Reward: {format_reward}\n\n Acc Reward: {acc_reward_future.result()}\n\n"
+ info = re.sub(r"<\|.*?\|>", "", info)
logger.info(info)
-
+
format_rewards.append(format_reward)
acc_rewards_futures.append(acc_reward_future)
acc_rewards = [f.result() for f in acc_rewards_futures]
rewards = [f + a for f, a in zip(format_rewards, acc_rewards)]
    # Return the response containing the rewards
- return jsonify({"rewards": rewards,"format_rewards":format_rewards,"acc_rewards":acc_rewards})
+ return jsonify({"rewards": rewards, "format_rewards": format_rewards, "acc_rewards": acc_rewards})
if __name__ == "__main__":
parser = ArgumentParser()
- parser.add_argument(
- "--dataset", type=str, default=None, help="Datasets to use (comma separated)", required=True
- )
- parser.add_argument(
- "--prompt-template", type=str, default=None, help="Prompt template", required=True
- )
- parser.add_argument(
- "--input_key", type=str, default="prompt", help="The key name of prompt."
- )
+ parser.add_argument("--dataset", type=str, default=None, help="Datasets to use (comma separated)", required=True)
+ parser.add_argument("--prompt-template", type=str, default=None, help="Prompt template", required=True)
+ parser.add_argument("--input_key", type=str, default="prompt", help="The key name of prompt.")
parser.add_argument("--log_file", type=str, default="remote_rm.log", help="Log file path")
args = parser.parse_args()
logger.remove()
logger.add(args.log_file)
# Split dataset paths and load all datasets
dataset = []
- for dataset_path in args.dataset.split(','):
+ for dataset_path in args.dataset.split(","):
dataset_path = dataset_path.strip()
if dataset_path.endswith("json"):
with open(dataset_path, "r") as f:
@@ -160,13 +151,13 @@ def get_reward():
    format_pattern = r"^<think>(?:(?!</think>).)*</think><answer>(?:(?!</answer>).)*</answer>\Z"
- if args.prompt_template=="chatml":
+ if args.prompt_template == "chatml":
problem_pattern = r"<\|im_start\|>user\n(.*?)<\|im_end\|>"
response_prefix = r"<\|im_start\|>assistant\n"
- elif args.prompt_template=="qwen1":
+ elif args.prompt_template == "qwen1":
problem_pattern = r"|User|>(.*?)<|Assistant|>"
response_prefix = r"<|Assistant|>"
- elif args.prompt_template=="base":
+ elif args.prompt_template == "base":
problem_pattern = r"User: (.*?)\n\nAssistant:"
response_prefix = r"Assistant: "
else:
@@ -184,4 +175,4 @@ def get_reward():
math_verify_executor = futures.ProcessPoolExecutor(max_workers=16)
app.run(host="0.0.0.0", port=5000, debug=False, use_reloader=False)
- math_verify_executor.shutdown()
\ No newline at end of file
+ math_verify_executor.shutdown()
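
A quick sanity check of the `format_pattern` above (restored here after its angle-bracket tags were mangled by tag stripping): it accepts exactly one `<think>` block followed immediately by one `<answer>` block.

```python
import re

format_pattern = r"^<think>(?:(?!</think>).)*</think><answer>(?:(?!</answer>).)*</answer>\Z"

good = "<think>2 + 2 = 4</think><answer>4</answer>"
bad = "<answer>4</answer>"
print(bool(re.match(format_pattern, good, re.DOTALL)))  # True
print(bool(re.match(format_pattern, bad, re.DOTALL)))   # False
```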
diff --git a/openrlhf/models/ring_attn_utils.py b/openrlhf/models/ring_attn_utils.py
index 4fd0eadf..df32b722 100644
--- a/openrlhf/models/ring_attn_utils.py
+++ b/openrlhf/models/ring_attn_utils.py
@@ -3,7 +3,6 @@
import torch.nn.functional as F
import transformers
-
RING_ATTN_GROUP = None
@@ -71,7 +70,7 @@ def convert_ring_attn_params(sequences, attention_mask, packed_seq_lens, ring_at
sequences = sequences[:, start:end]
inputs_embeds = inputs_embeds[:, start:end]
attention_mask = attention_mask[:, start:end]
- position_ids = position_ids[..., start:end] #qwen2_5_vl has position_ids shape: [3,bs,seq_len]
+ position_ids = position_ids[..., start:end] # qwen2_5_vl has position_ids shape: [3,bs,seq_len]
hacked_position_ids = reset_ring_attn_position_ids(start, end, packed_seq_lens)
update_ring_attn_params(packed_seq_lens, total_seq_len)
return sequences, attention_mask, hacked_position_ids, inputs_embeds, position_ids
@@ -125,27 +124,33 @@ def unpad_sequences(
kl = kl[:, :-pad_len]
return sequences, attention_mask, num_actions, packed_seq_lens, action_log_probs, values, kl
+
HACKED_POSITION_IDS = None
-#Both ring and our hack substitute flash_attn. This func must be called after ring's substitue_hf_flash_attn.
+
+# Both ring and our hack substitute flash_attn. This func must be called after ring's substitute_hf_flash_attn.
def substitute_ring_flash_attn():
raw_flash_attention_forward = transformers.modeling_flash_attention_utils._flash_attention_forward
- def _hacked_flash_attention_forward(*args,**kwargs):
+
+ def _hacked_flash_attention_forward(*args, **kwargs):
global HACKED_POSITION_IDS
if HACKED_POSITION_IDS is not None:
- kwargs['position_ids'] = HACKED_POSITION_IDS
- return raw_flash_attention_forward(*args,**kwargs)
+ kwargs["position_ids"] = HACKED_POSITION_IDS
+ return raw_flash_attention_forward(*args, **kwargs)
transformers.modeling_flash_attention_utils._flash_attention_forward = _hacked_flash_attention_forward
+
def set_hacked_position_ids(position_ids):
global HACKED_POSITION_IDS
HACKED_POSITION_IDS = position_ids
+
def clear_hacked_position_ids():
global HACKED_POSITION_IDS
HACKED_POSITION_IDS = None
+
def get_hacked_position_ids():
global HACKED_POSITION_IDS
return HACKED_POSITION_IDS
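
The position-ids hack above follows a simple set/use/clear protocol. A minimal, dependency-free sketch of that pattern (names are invented for illustration):

```python
_OVERRIDE = None  # plays the role of HACKED_POSITION_IDS

def set_override(position_ids):
    global _OVERRIDE
    _OVERRIDE = position_ids

def clear_override():
    global _OVERRIDE
    _OVERRIDE = None

def patched_forward(**kwargs):
    # Mirrors _hacked_flash_attention_forward: swap in the stashed ids if set.
    if _OVERRIDE is not None:
        kwargs["position_ids"] = _OVERRIDE
    return kwargs["position_ids"]

set_override([0, 1, 2, 0, 1])
try:
    out = patched_forward(position_ids=None)
finally:
    clear_override()  # always clear, as the forward() call sites above do
print(out)  # [0, 1, 2, 0, 1]
```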
diff --git a/openrlhf/trainer/ppo_trainer.py b/openrlhf/trainer/ppo_trainer.py
index c0d7dd32..ffaad07e 100644
--- a/openrlhf/trainer/ppo_trainer.py
+++ b/openrlhf/trainer/ppo_trainer.py
@@ -4,8 +4,8 @@
from typing import Any, Callable, Dict, List, Optional
import torch
-import torch.nn as nn
import torch.distributed as dist
+import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm
@@ -107,7 +107,6 @@ def __init__(
self.tokenizer = data_processor.tokenizer
self.processor = data_processor.processor
-
self.generate_kwargs = generate_kwargs
self.dataloader_pin_memory = dataloader_pin_memory
self.max_norm = max_norm
@@ -159,8 +158,12 @@ def __init__(
)
packing_samples = getattr(self.args, "packing_samples", False)
self.replay_buffer = NaiveReplayBuffer(
- micro_train_batch_size, self.data_processor, buffer_limit, buffer_cpu_offload, packing_samples,
- drop_maxlen=self.args.drop_maxlen,
+ micro_train_batch_size,
+ self.data_processor,
+ buffer_limit,
+ buffer_cpu_offload,
+ packing_samples,
+ drop_maxlen=self.args.drop_maxlen,
maxlen=self.args.generate_max_len + prompt_max_len,
)
@@ -359,11 +362,7 @@ def training_step_actor(self, experience: Experience) -> Dict[str, float]:
# pad seq makes the sequence a multiple of ring_attention_size.
if self.strategy.ring_attn_group is not None:
pad_len, sequences, attention_mask, num_actions, packed_seq_lens = pad_sequences(
- sequences,
- attention_mask,
- num_actions,
- packed_seq_lens,
- self.strategy.ring_attn_group
+ sequences, attention_mask, num_actions, packed_seq_lens, self.strategy.ring_attn_group
)
if self.args.use_kl_loss and experience.base_action_log_probs is not None:
base_action_log_probs = torch.cat(experience.base_action_log_probs, dim=0).unsqueeze(0)
@@ -387,7 +386,7 @@ def training_step_actor(self, experience: Experience) -> Dict[str, float]:
ring_attn_group=self.strategy.ring_attn_group,
logps_allgather=True,
packed_seq_lens=packed_seq_lens,
- visual_inputs=visual_inputs
+ visual_inputs=visual_inputs,
)
# unpad sequence ensures that pad tokens do not contribute to the loss calculation.
if self.strategy.ring_attn_group is not None:
@@ -501,11 +500,7 @@ def training_step_critic(self, experience: Experience) -> Dict[str, float]:
# pad seq makes the sequence len a multiple of ring_attention_size.
if self.strategy.ring_attn_group is not None:
pad_len, sequences, attention_mask, num_actions, packed_seq_lens = pad_sequences(
- sequences,
- attention_mask,
- num_actions,
- packed_seq_lens,
- self.strategy.ring_attn_group
+ sequences, attention_mask, num_actions, packed_seq_lens, self.strategy.ring_attn_group
)
else:
@@ -584,6 +579,7 @@ def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict={}, c
if self.experience_maker.perf_stats is not None:
logs.update({f"perf/experience_maker/{k}": v for k, v in self.experience_maker.perf_stats.items()})
from wandb import Histogram
+
response_length_list = Histogram(response_length_list)
logs["response_length_dist"] = response_length_list
self._wandb.log(logs)
diff --git a/openrlhf/trainer/ppo_utils/experience_maker.py b/openrlhf/trainer/ppo_utils/experience_maker.py
index 52456921..b797dfff 100644
--- a/openrlhf/trainer/ppo_utils/experience_maker.py
+++ b/openrlhf/trainer/ppo_utils/experience_maker.py
@@ -1,22 +1,18 @@
-import os
import time
from abc import ABC
from copy import deepcopy
from dataclasses import dataclass, field
-from typing import List, Optional, Tuple, Union, Dict
+from typing import Dict, List, Optional, Tuple, Union
import ray
import torch
import torch.distributed as dist
import torch.nn as nn
-import torch.distributed as dist
-import torch.nn.functional as F
from tqdm import tqdm
from openrlhf.models.actor import Actor
from openrlhf.models.ring_attn_utils import pad_sequences, unpad_sequences
from openrlhf.models.utils import compute_approx_kl, compute_reward, masked_mean, unpacking_samples
-from openrlhf.models.ring_attn_utils import pad_sequences, unpad_sequences
from openrlhf.utils.logging_utils import init_logger
from openrlhf.utils.remote_rm_utils import remote_rm_fn, remote_rm_fn_ray
@@ -182,7 +178,6 @@ def __init__(
spec.loader.exec_module(reward_module)
self.custom_reward_func = reward_module.reward_func
-
@torch.no_grad()
def make_experience_list(
self, all_prompts: Union[str, List[str]], all_labels, **generate_kwargs
@@ -298,14 +293,13 @@ def generate_samples(self, all_prompts: List[str], all_labels, **generate_kwargs
samples_list = []
for i in range(0, len(all_prompts), args.micro_rollout_batch_size):
prompts = all_prompts[i : i + args.micro_rollout_batch_size]
-
+
inputs = self.data_processor(prompts, self.prompt_max_len, device="cuda")
visual_inputs = {}
- for k,v in inputs.items():
+ for k, v in inputs.items():
if k not in ["input_ids", "attention_mask"]:
visual_inputs[k] = v
-
labels = all_labels[i : i + args.micro_rollout_batch_size]
sequences, attention_mask, action_mask = self.actor.generate(**inputs, **generate_kwargs)
samples = Samples(
@@ -349,7 +343,9 @@ def make_experience(self, samples: Samples) -> Experience:
# init log probs
if self.initial_model is not None:
- base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask, visual_inputs=visual_inputs)
+ base_action_log_probs = self.initial_model(
+ sequences, num_actions, attention_mask, visual_inputs=visual_inputs
+ )
else:
base_action_log_probs = None
@@ -385,7 +381,7 @@ def make_experience(self, samples: Samples) -> Experience:
else:
kl = torch.zeros_like(action_log_probs, dtype=action_log_probs.dtype, device=action_log_probs.device)
- assert isinstance(r,dict)
+ assert isinstance(r, dict)
total_reward = r.pop("rewards")
specific_rewards = r
@@ -395,7 +391,7 @@ def make_experience(self, samples: Samples) -> Experience:
"response_length": samples.response_length,
"total_length": samples.total_length,
"num_actions": num_actions,
- **specific_rewards
+ **specific_rewards,
}
# reset model state
self.actor.train()
@@ -413,7 +409,7 @@ def make_experience(self, samples: Samples) -> Experience:
action_mask,
info,
kl,
- visual_inputs=visual_inputs
+ visual_inputs=visual_inputs,
)
@torch.no_grad()
@@ -621,17 +617,17 @@ def make_experience(self, samples: Samples) -> Experience:
)
visual_inputs_cpu = None
if visual_inputs is not None:
- visual_inputs_cpu = {k: v.to("cpu") for k, v in visual_inputs.items()}
+ visual_inputs_cpu = {k: v.to("cpu") for k, v in visual_inputs.items()}
# init log probs
if self.initial_model is not None:
base_action_log_probs_ref = self.initial_model.forward.remote(
- sequences_cpu,
- num_actions,
- attention_mask_cpu,
- logps_allgather=True,
- packed_seq_lens=packed_seq_lens,
- visual_inputs=visual_inputs_cpu
- )
+ sequences_cpu,
+ num_actions,
+ attention_mask_cpu,
+ logps_allgather=True,
+ packed_seq_lens=packed_seq_lens,
+ visual_inputs=visual_inputs_cpu,
+ )
if args.colocate_actor_ref or args.colocate_all_models:
ray.get([base_action_log_probs_ref])
@@ -642,7 +638,11 @@ def make_experience(self, samples: Samples) -> Experience:
# values
if self.critic is not None:
value_ref = self.critic.forward.remote(
- sequences_cpu, num_actions, attention_mask_cpu, packed_seq_lens=packed_seq_lens, visual_inputs=visual_inputs_cpu
+ sequences_cpu,
+ num_actions,
+ attention_mask_cpu,
+ packed_seq_lens=packed_seq_lens,
+ visual_inputs=visual_inputs_cpu,
)
# avoid CUDA OOM when colocate models
if args.colocate_critic_reward or args.colocate_all_models:
@@ -658,7 +658,11 @@ def make_experience(self, samples: Samples) -> Experience:
for rm in self.reward_model:
r_refs.append(
rm.forward.remote(
- sequences_cpu, attention_mask_cpu, packed_seq_lens=packed_seq_lens, pad_sequence=True, visual_inputs=visual_inputs_cpu
+ sequences_cpu,
+ attention_mask_cpu,
+ packed_seq_lens=packed_seq_lens,
+ pad_sequence=True,
+ visual_inputs=visual_inputs_cpu,
)
)
else:
@@ -688,13 +692,13 @@ def make_experience(self, samples: Samples) -> Experience:
# log probs
action_log_probs = self.actor(
- sequences,
- num_actions,
- attention_mask,
+ sequences,
+ num_actions,
+ attention_mask,
ring_attn_group=self.strategy.ring_attn_group,
logps_allgather=True,
packed_seq_lens=packed_seq_lens,
- visual_inputs=visual_inputs
+ visual_inputs=visual_inputs,
)
actor_value_rm_time = time.time() - start
@@ -709,10 +713,10 @@ def make_experience(self, samples: Samples) -> Experience:
if value is not None:
value = value.to(device)
- total_rewards = [r.pop('rewards').to(device) if isinstance(r,dict) else r.to(device) for r in rewards]
+ total_rewards = [r.pop("rewards").to(device) if isinstance(r, dict) else r.to(device) for r in rewards]
specific_rewards = {}
for r in rewards:
- if isinstance(r,dict):
+ if isinstance(r, dict):
for k in r.keys():
r[k] = r[k].to(device)
specific_rewards.update(r)
@@ -774,7 +778,7 @@ def make_experience(self, samples: Samples) -> Experience:
"response_length": samples.response_length,
"total_length": samples.total_length,
"num_actions": num_actions,
- **specific_rewards
+ **specific_rewards,
}
if self.strategy.args.perf:
@@ -792,7 +796,7 @@ def make_experience(self, samples: Samples) -> Experience:
action_mask,
info,
kl,
- visual_inputs=visual_inputs
+ visual_inputs=visual_inputs,
)
self.actor.train() # reset model state
@@ -800,6 +804,7 @@ def make_experience(self, samples: Samples) -> Experience:
def _generate_vllm(self, all_prompts: List[str], all_labels, **kwargs) -> List[Samples]:
from vllm import SamplingParams
+
# round-robin load balance
rank = torch.distributed.get_rank() // self.strategy.ring_attn_size
world_size = torch.distributed.get_world_size() // self.strategy.ring_attn_size
@@ -835,19 +840,21 @@ def _generate_vllm(self, all_prompts: List[str], all_labels, **kwargs) -> List[S
if messages:
prompts = self.data_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
images = [self.data_processor.get_images_from_messages(m) for m in messages]
- vllm_inputs = [{
+ vllm_inputs = [
+ {
"prompt": p,
- "multi_modal_data":{"image": imgs} if imgs else None,
+ "multi_modal_data": {"image": imgs} if imgs else None,
"mm_processor_kwargs": {
- "min_pixels": kwargs.get("min_pixels", 4*28*28),
- "max_pixels": kwargs.get("max_pixels", 640*28*28),
+ "min_pixels": kwargs.get("min_pixels", 4 * 28 * 28),
+ "max_pixels": kwargs.get("max_pixels", 640 * 28 * 28),
},
- } for p, imgs in zip(prompts,images)]
+ }
+ for p, imgs in zip(prompts, images)
+ ]
refs.append(
llm.add_requests.remote(rank, sampling_params=sampling_params, vllm_vision_input=vllm_inputs)
)
-
ray.get(refs)
# Make sure all requests are sent.
@@ -901,7 +908,7 @@ def _generate_vllm(self, all_prompts: List[str], all_labels, **kwargs) -> List[S
attention_mask = attention_mask.to("cuda")
action_mask = action_mask.to("cuda")
# Collect for visual input
-
+
visual_inputs = self.data_processor(prompts, self.prompt_max_len, device="cuda")
visual_inputs.pop("input_ids")
visual_inputs.pop("attention_mask")
diff --git a/openrlhf/trainer/ppo_utils/replay_buffer.py b/openrlhf/trainer/ppo_utils/replay_buffer.py
index 481644d2..3c9f31bf 100644
--- a/openrlhf/trainer/ppo_utils/replay_buffer.py
+++ b/openrlhf/trainer/ppo_utils/replay_buffer.py
@@ -6,9 +6,10 @@
import torch
import torch.nn.functional as F
+from openrlhf.models.lmm_kits.base.data_processor import BaseDataProcessor
from .experience_maker import Experience
-from openrlhf.models.lmm_kits.base.data_processor import BaseDataProcessor
+
@dataclass
class BufferItem:
@@ -64,15 +65,14 @@ def split_experience_batch(experience: Experience, data_processor: Optional[Base
assert batch_size == len(vals)
for i, v in enumerate(vals):
batch_kwargs[i][key] = v
-
+
visual_inputs_batch = experience.visual_inputs
- visual_inputs_batch['input_ids'] = experience.sequences
+ visual_inputs_batch["input_ids"] = experience.sequences
visual_inputs_chunks = data_processor.split_input_batch(visual_inputs_batch)
for i, visual_inputs in enumerate(visual_inputs_chunks):
- visual_inputs.pop('input_ids')
+ visual_inputs.pop("input_ids")
batch_kwargs[i]["visual_inputs"] = visual_inputs
-
for i in range(batch_size):
batch_kwargs[i]["info"] = {}
for k, v in experience.info.items():
@@ -99,7 +99,9 @@ def zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left") -> tor
return torch.stack(padded_sequences, dim=0)
-def make_experience_batch(items: List[BufferItem], data_processor: Optional[BaseDataProcessor], packing_samples=False) -> Experience:
+def make_experience_batch(
+ items: List[BufferItem], data_processor: Optional[BaseDataProcessor], packing_samples=False
+) -> Experience:
kwargs = {}
keys = (
"sequences",
@@ -123,7 +125,7 @@ def make_experience_batch(items: List[BufferItem], data_processor: Optional[Base
for key in items[0].info.keys():
vals = torch.tensor([item.info[key] for item in items])
kwargs["info"][key] = vals
-
+
kwargs["visual_inputs"] = data_processor.make_input_batch([item.visual_inputs for item in items])
return Experience(**kwargs)
@@ -177,11 +179,11 @@ class NaiveReplayBuffer(ABC):
"""
def __init__(
- self,
- sample_batch_size: int,
- data_processor: Optional[BaseDataProcessor] = None,
- limit: int = 0,
- cpu_offload: bool = True,
+ self,
+ sample_batch_size: int,
+ data_processor: Optional[BaseDataProcessor] = None,
+ limit: int = 0,
+ cpu_offload: bool = True,
packing_samples: bool = False,
drop_maxlen: bool = False,
maxlen: int = 10**8,
diff --git a/openrlhf/trainer/ray/launcher.py b/openrlhf/trainer/ray/launcher.py
index 9dece250..aecac635 100644
--- a/openrlhf/trainer/ray/launcher.py
+++ b/openrlhf/trainer/ray/launcher.py
@@ -5,7 +5,6 @@
import ray
import torch
-import torch.distributed as dist
from ray.util.placement_group import PlacementGroup, placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -95,7 +94,7 @@ def forward(
visual_inputs = {}
device = torch.cuda.current_device()
with torch.no_grad():
- visual_inputs = {k:v.to(device) for k,v in visual_inputs.items()}
+ visual_inputs = {k: v.to(device) for k, v in visual_inputs.items()}
log_probs = self.model(
sequences.to(device),
num_actions,
@@ -148,7 +147,7 @@ def forward(
device = torch.cuda.current_device()
if visual_inputs is None:
visual_inputs = {}
- visual_inputs = {k:v.to(device) for k,v in visual_inputs.items()}
+ visual_inputs = {k: v.to(device) for k, v in visual_inputs.items()}
with torch.no_grad():
reward = self.model(
sequences.to(device),
diff --git a/openrlhf/trainer/ray/ppo_actor.py b/openrlhf/trainer/ray/ppo_actor.py
index 982ce991..96f3c5c1 100644
--- a/openrlhf/trainer/ray/ppo_actor.py
+++ b/openrlhf/trainer/ray/ppo_actor.py
@@ -324,10 +324,12 @@ def init_model_from_pretrained(self, strategy: DeepspeedStrategy, pretrain):
if any(name.startswith(prefix) for prefix in strategy.args.freeze_prefix):
param.requires_grad = False
frozen_count += 1
- strategy.print(f"Froze {frozen_count}/{total_params} parameters based on prefixes: {strategy.args.freeze_prefix}")
+ strategy.print(
+ f"Froze {frozen_count}/{total_params} parameters based on prefixes: {strategy.args.freeze_prefix}"
+ )
# configure tokenizer
-
+
self.data_processor = get_data_processor(
pretrain, actor.model, "left", strategy, use_fast=not strategy.args.disable_fast_tokenizer
)
diff --git a/openrlhf/trainer/ray/ppo_critic.py b/openrlhf/trainer/ray/ppo_critic.py
index b95f638b..6d56b787 100644
--- a/openrlhf/trainer/ray/ppo_critic.py
+++ b/openrlhf/trainer/ray/ppo_critic.py
@@ -9,9 +9,9 @@
from transformers.trainer import get_scheduler
from openrlhf.models import get_llm_for_sequence_regression
+from openrlhf.models.lmm_kits.utils import get_data_processor
from openrlhf.trainer import PPOTrainer
from openrlhf.trainer.ppo_utils import Experience
-from openrlhf.models.lmm_kits.utils import get_data_processor
from openrlhf.utils.deepspeed import DeepspeedStrategy
from openrlhf.utils.deepspeed.deepspeed_utils import offload_deepspeed_states, reload_deepspeed_states
@@ -151,7 +151,7 @@ def init_model_from_pretrained(self, strategy: DeepspeedStrategy, pretrain, max_
prompt_max_len=args.prompt_max_len,
value_clip=args.value_clip,
eps_clip=args.eps_clip,
- data_processor=self.data_processor
+ data_processor=self.data_processor,
)
def forward(
@@ -207,7 +207,6 @@ def save_model(self):
args.save_path + "_critic",
)
-
def save_checkpoint(self, tag):
args = self.strategy.args
self.strategy.save_ckpt(
diff --git a/openrlhf/utils/__init__.py b/openrlhf/utils/__init__.py
index aec56564..60803c08 100644
--- a/openrlhf/utils/__init__.py
+++ b/openrlhf/utils/__init__.py
@@ -1,10 +1,4 @@
from .processor import get_processor, reward_normalization
from .utils import blending_datasets, get_strategy, get_tokenizer
-__all__ = [
- "get_processor",
- "reward_normalization",
- "blending_datasets",
- "get_strategy",
- "get_tokenizer"
-]
+__all__ = ["get_processor", "reward_normalization", "blending_datasets", "get_strategy", "get_tokenizer"]
diff --git a/openrlhf/utils/distributed_sampler.py b/openrlhf/utils/distributed_sampler.py
index 1f765820..df4ddd99 100644
--- a/openrlhf/utils/distributed_sampler.py
+++ b/openrlhf/utils/distributed_sampler.py
@@ -6,7 +6,6 @@
from torch.utils.data.dataset import Dataset
from torch.utils.data.sampler import Sampler
-
__all__ = ["DistributedSampler"]
diff --git a/openrlhf/utils/logging_utils.py b/openrlhf/utils/logging_utils.py
index eb39f39a..7b19681e 100644
--- a/openrlhf/utils/logging_utils.py
+++ b/openrlhf/utils/logging_utils.py
@@ -1,6 +1,7 @@
# Adapted from
# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
"""Logging configuration for vLLM."""
+
import logging
import sys
diff --git a/openrlhf/utils/remote_rm_utils.py b/openrlhf/utils/remote_rm_utils.py
index bff5a85f..dd33523d 100644
--- a/openrlhf/utils/remote_rm_utils.py
+++ b/openrlhf/utils/remote_rm_utils.py
@@ -37,7 +37,7 @@ def remote_rm_fn(api_url, queries, prompts, labels, score_key="rewards"):
score_key: RM score key
"""
responses = request_api_wrapper(api_url, {"query": queries, "prompts": prompts, "labels": labels}, score_key)
- return {k:torch.tensor(v) for k,v in responses.items()}
+ return {k: torch.tensor(v) for k, v in responses.items()}
@ray.remote
@@ -49,4 +49,4 @@ def remote_rm_fn_ray(api_url, queries, prompts, labels, score_key="rewards"):
# test utils
url = "http:xxx/get_rm_score"
score = remote_rm_fn(url, ["example query"], ["example response"])
- print(score)
\ No newline at end of file
+ print(score)
diff --git a/openrlhf/utils/utils.py b/openrlhf/utils/utils.py
index 08b0b160..6698ef88 100644
--- a/openrlhf/utils/utils.py
+++ b/openrlhf/utils/utils.py
@@ -3,6 +3,7 @@
from datasets import interleave_datasets, load_dataset, load_from_disk
from transformers import AutoTokenizer
+
def get_tokenizer(pretrain, model, padding_side="left", strategy=None, use_fast=True):
tokenizer = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True, use_fast=use_fast)
tokenizer.padding_side = padding_side
@@ -126,4 +127,4 @@ def convert_token_to_id(token, tokenizer):
assert len(token) == 1
return token[0]
else:
- raise ValueError("token should be int or str")
\ No newline at end of file
+ raise ValueError("token should be int or str")
diff --git a/requirements.txt b/requirements.txt
index 03f9ae27..441e8f28 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,9 +6,10 @@ einops
flask
isort
jsonlines
+levenshtein
+loguru
loralib
math-verify==0.5.2
-levenshtein
optimum
packaging
peft
@@ -16,11 +17,10 @@ pynvml>=12.0.0
qwen_vl_utils
tensorboard
torch==2.5.1
-torchvision
torchmetrics
+torchvision
tqdm
transformers @ git+https://github.com/huggingface/transformers@2ab7bdc40333b230b642f09e8334fb8e1a92d2a4
transformers_stream_generator
wandb
wheel
-loguru
\ No newline at end of file