diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index fa5e59cf..2b2b10df 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,7 +8,7 @@ ci:
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.5.0
+    rev: v6.0.0
     hooks:
       - id: check-yaml
       - id: check-case-conflict
@@ -18,20 +18,20 @@ repos:
       - id: requirements-txt-fixer
   - repo: https://github.com/PyCQA/autoflake
-    rev: v2.0.2
+    rev: v2.3.3
     hooks:
       - id: autoflake
        args: [--remove-all-unused-imports, --in-place]
   - repo: https://github.com/PyCQA/isort
-    rev: 5.13.2
+    rev: 8.0.1
     hooks:
       - id: isort
         name: Format imports
         exclude: docs/
-  - repo: https://github.com/psf/black
-    rev: 24.3.0
+  - repo: https://github.com/psf/black-pre-commit-mirror
+    rev: 26.3.1
     hooks:
       - id: black
         name: Format code
diff --git a/openrlhf/cli/train_dpo.py b/openrlhf/cli/train_dpo.py
index a43b1f2b..9703b68a 100644
--- a/openrlhf/cli/train_dpo.py
+++ b/openrlhf/cli/train_dpo.py
@@ -119,7 +119,7 @@ def train(args):
     )
 
     # strategy prepare
-    ((model, optim, scheduler), ref_model) = strategy.prepare((model, optim, scheduler), ref_model)
+    (model, optim, scheduler), ref_model = strategy.prepare((model, optim, scheduler), ref_model)
 
     # load checkpoint
     consumed_samples = 0
diff --git a/openrlhf/cli/train_kd.py b/openrlhf/cli/train_kd.py
index ca30038d..9862841c 100644
--- a/openrlhf/cli/train_kd.py
+++ b/openrlhf/cli/train_kd.py
@@ -104,7 +104,7 @@ def train(args):
     )
 
     # prepare models
-    ((model, optim, scheduler), teacher_model) = strategy.prepare((model, optim, scheduler), teacher_model)
+    (model, optim, scheduler), teacher_model = strategy.prepare((model, optim, scheduler), teacher_model)
 
     # load checkpoint
     consumed_samples = 0
diff --git a/openrlhf/cli/train_kto.py b/openrlhf/cli/train_kto.py
index e5b343f5..ded8d09e 100644
--- a/openrlhf/cli/train_kto.py
+++ b/openrlhf/cli/train_kto.py
@@ -99,7 +99,7 @@ def train(args):
     )
 
     # strategy prepare
-    ((model, optim, scheduler), ref_model) = strategy.prepare((model, optim, scheduler), ref_model)
+    (model, optim, scheduler), ref_model = strategy.prepare((model, optim, scheduler), ref_model)
 
     # load checkpoint
     consumed_samples = 0
diff --git a/openrlhf/cli/train_ppo.py b/openrlhf/cli/train_ppo.py
index f19bdccd..488fb989 100644
--- a/openrlhf/cli/train_ppo.py
+++ b/openrlhf/cli/train_ppo.py
@@ -344,8 +344,8 @@ def train(args):
     )
     parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer")
     parser.add_argument("--reward_clip_range", type=float, nargs=2, default=(-10, 10), help="Reward clip range")
-    parser.add_argument("--max_pixels",type=int,default=640*28*28,help="Max pixels for image")
-    parser.add_argument("--min_pixels",type=int,default=4*28*28,help="Min pixels for image")
+    parser.add_argument("--max_pixels", type=int, default=640 * 28 * 28, help="Max pixels for image")
+    parser.add_argument("--min_pixels", type=int, default=4 * 28 * 28, help="Min pixels for image")
     # DeepSpeed
     parser.add_argument("--seed", type=int, default=42)
     parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed")
@@ -362,8 +362,12 @@ def train(args):
     parser.add_argument("--overlap_comm", action="store_true", default=False)
     parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False)
     parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False)
-    parser.add_argument("--freeze_prefix", type=str, nargs="+", default=None,
-                        help="List of parameter name prefixes to freeze during training"
training" + parser.add_argument( + "--freeze_prefix", + type=str, + nargs="+", + default=None, + help="List of parameter name prefixes to freeze during training", ) parser.add_argument("--drop_maxlen", action="store_true", default=False) diff --git a/openrlhf/cli/train_ppo_ray.py b/openrlhf/cli/train_ppo_ray.py index 631765a5..dcbc8b86 100644 --- a/openrlhf/cli/train_ppo_ray.py +++ b/openrlhf/cli/train_ppo_ray.py @@ -337,12 +337,16 @@ def train(args): parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss") parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer") parser.add_argument("--reward_clip_range", type=float, nargs=2, default=(-10, 10), help="Reward clip range") - parser.add_argument("--freeze_prefix", type=str, nargs="+", default=None, - help="List of parameter name prefixes to freeze during training" + parser.add_argument( + "--freeze_prefix", + type=str, + nargs="+", + default=None, + help="List of parameter name prefixes to freeze during training", ) parser.add_argument("--drop_maxlen", action="store_true", default=False) - parser.add_argument("--max_pixels",type=int,default=640*28*28,help="Max pixels for image") - parser.add_argument("--min_pixels",type=int,default=4*28*28,help="Min pixels for image") + parser.add_argument("--max_pixels", type=int, default=640 * 28 * 28, help="Max pixels for image") + parser.add_argument("--min_pixels", type=int, default=4 * 28 * 28, help="Min pixels for image") # Reinforce parser.add_argument( "--advantage_estimator", diff --git a/openrlhf/cli/train_prm.py b/openrlhf/cli/train_prm.py index b8d8278c..284ffab6 100644 --- a/openrlhf/cli/train_prm.py +++ b/openrlhf/cli/train_prm.py @@ -87,7 +87,7 @@ def train(args): ) # prepare models - (model, optim, scheduler) = strategy.prepare((model, optim, scheduler)) + model, optim, scheduler = strategy.prepare((model, optim, scheduler)) # load checkpoint consumed_samples = 0 diff --git a/openrlhf/cli/train_rm.py b/openrlhf/cli/train_rm.py index eb6e4f84..ded8243d 100644 --- a/openrlhf/cli/train_rm.py +++ b/openrlhf/cli/train_rm.py @@ -106,7 +106,7 @@ def train(args): ) # strategy prepare - (model, optim, scheduler) = strategy.prepare((model, optim, scheduler)) + model, optim, scheduler = strategy.prepare((model, optim, scheduler)) # load checkpoint consumed_samples = 0 diff --git a/openrlhf/cli/train_sft.py b/openrlhf/cli/train_sft.py index cb6e7619..3ef4fa67 100644 --- a/openrlhf/cli/train_sft.py +++ b/openrlhf/cli/train_sft.py @@ -105,7 +105,7 @@ def train(args): ) # prepare models - (model, optim, scheduler) = strategy.prepare((model, optim, scheduler)) + model, optim, scheduler = strategy.prepare((model, optim, scheduler)) # load checkpoint consumed_samples = 0 diff --git a/openrlhf/models/actor.py b/openrlhf/models/actor.py index 1d63fc2d..936560a1 100644 --- a/openrlhf/models/actor.py +++ b/openrlhf/models/actor.py @@ -3,17 +3,18 @@ import torch import torch.distributed as dist import torch.nn as nn -from torch.nn import functional as F +from flash_attn.utils.distributed import all_gather from peft import LoraConfig, TaskType, get_peft_model from peft.tuners.lora import LoraLayer -from transformers import BitsAndBytesConfig, AutoConfig +from torch.nn import functional as F +from transformers import AutoConfig, BitsAndBytesConfig from transformers.integrations.deepspeed import HfDeepSpeedConfig -from flash_attn.utils.distributed import all_gather -from .ring_attn_utils import convert_ring_attn_params, 
-from .utils import log_probs_from_logits, reset_position_ids
 from openrlhf.models.lmm_kits.utils import get_generation_cls
 
+from .ring_attn_utils import clear_hacked_position_ids, convert_ring_attn_params, set_hacked_position_ids
+from .utils import log_probs_from_logits, reset_position_ids
+
 
 class Actor(nn.Module):
     """
@@ -73,7 +74,7 @@ def __init__(
         else:
             nf4_config = None
 
-        #There is no AutoModelForConditionalGeneration in transformers. We manually implement it.
+        # There is no AutoModelForConditionalGeneration in transformers. We manually implement it.
         config = AutoConfig.from_pretrained(pretrain_or_model)
         model_cls = get_generation_cls(config)
         self.model = model_cls.from_pretrained(
@@ -200,27 +201,31 @@ def forward(
         """Returns action log probs"""
         if visual_inputs is None:
             visual_inputs = {}
-        '''
+        """
         for k,v in visual_inputs.items():
             if v.dtype == torch.float32:
                 visual_inputs[k] = v.to(self.model.get_input_embeddings().weight.dtype)
-        '''
+        """
         inputs_embeds = self.model.get_inputs_embeds(sequences, **visual_inputs)
         if not self.packing_samples:
             # https://github.com/OpenRLHF/OpenRLHF/issues/217
-            #position_ids = attention_mask.long().cumsum(-1) - 1
-            #position_ids.masked_fill_(attention_mask == 0, 1)
-            position_ids = self.model.get_position_ids(sequences,attention_mask=attention_mask, **visual_inputs)
+            # position_ids = attention_mask.long().cumsum(-1) - 1
+            # position_ids.masked_fill_(attention_mask == 0, 1)
+            position_ids = self.model.get_position_ids(sequences, attention_mask=attention_mask, **visual_inputs)
         else:
             # convert attention_mask to position_ids
             packed_position_ids = self.model.get_position_ids(sequences, **visual_inputs)
             if ring_attn_group is not None:
                 labels = sequences
-                sequences, attention_mask, hacked_position_ids, inputs_embeds, split_position_ids = convert_ring_attn_params(
-                    sequences, attention_mask, packed_seq_lens, ring_attn_group, inputs_embeds, packed_position_ids
+                sequences, attention_mask, hacked_position_ids, inputs_embeds, split_position_ids = (
+                    convert_ring_attn_params(
+                        sequences, attention_mask, packed_seq_lens, ring_attn_group, inputs_embeds, packed_position_ids
+                    )
                 )
-                position_ids = self.model.offset_split_position_ids(split_position_ids, hacked_position_ids) # this is true position_ids
-                #position_ids is directly hacked into flash_attn_forward to distinguish between different sequences
+                position_ids = self.model.offset_split_position_ids(
+                    split_position_ids, hacked_position_ids
+                )  # this is true position_ids
+                # position_ids is directly hacked into flash_attn_forward to distinguish between different sequences
             else:
                 hacked_position_ids = reset_position_ids(attention_mask)
                 position_ids = self.model.offset_split_position_ids(packed_position_ids, hacked_position_ids)
@@ -228,7 +233,9 @@ def forward(
             set_hacked_position_ids(hacked_position_ids)
             # explicitly ignore attention_mask for packing_samples
             attention_mask = None
-        output = self.model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, **visual_inputs)
+        output = self.model(
+            inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, **visual_inputs
+        )
         clear_hacked_position_ids()
         # https://github.com/OpenRLHF/OpenRLHF/pull/634
         output["logits"] = output["logits"].to(torch.float32)
diff --git a/openrlhf/models/lmm_kits/base/data_processor.py b/openrlhf/models/lmm_kits/base/data_processor.py
index 40dbf835..2531c651 100644
--- a/openrlhf/models/lmm_kits/base/data_processor.py
+++ b/openrlhf/models/lmm_kits/base/data_processor.py
@@ -1,17 +1,19 @@
 import json
-import os
 from abc import ABC, abstractmethod
-from typing import List, Optional, Union, Dict
+from typing import Dict, List, Optional, Union
+
 import torch
-from transformers.processing_utils import ProcessorMixin
 from qwen_vl_utils import process_vision_info
+from transformers.processing_utils import ProcessorMixin
+
 
 class BaseDataProcessor(ABC):
-    def __init__(self, processor: ProcessorMixin,min_pixels:int,max_pixels:int):
+    def __init__(self, processor: ProcessorMixin, min_pixels: int, max_pixels: int):
         super().__init__()
         self.processor = processor
         self.min_pixels = min_pixels
         self.max_pixels = max_pixels
+
     @abstractmethod
     def __call__(
         self,
@@ -25,24 +27,24 @@ def __call__(
     ) -> Dict:
         raise NotImplementedError
 
-    def _add_pixel_bounds(self,messages:List[Dict]) -> List[Dict]:
-      DEFAULT_MIN_PIXELS = self.min_pixels
-      DEFAULT_MAX_PIXELS = self.max_pixels
-
-      def process_content(content):
-        if isinstance(content, list):
-          for item in content:
-            if isinstance(item, dict) and item.get("type") == "image":
-              if "min_pixels" not in item:
-                item["min_pixels"] = DEFAULT_MIN_PIXELS
-              if "max_pixels" not in item:
-                item["max_pixels"] = DEFAULT_MAX_PIXELS
-        return content
-
-      for message in messages:
-        for msg in message:
-          msg["content"] = process_content(msg["content"])
-      return messages
+    def _add_pixel_bounds(self, messages: List[Dict]) -> List[Dict]:
+        DEFAULT_MIN_PIXELS = self.min_pixels
+        DEFAULT_MAX_PIXELS = self.max_pixels
+
+        def process_content(content):
+            if isinstance(content, list):
+                for item in content:
+                    if isinstance(item, dict) and item.get("type") == "image":
+                        if "min_pixels" not in item:
+                            item["min_pixels"] = DEFAULT_MIN_PIXELS
+                        if "max_pixels" not in item:
+                            item["max_pixels"] = DEFAULT_MAX_PIXELS
+            return content
+
+        for message in messages:
+            for msg in message:
+                msg["content"] = process_content(msg["content"])
+        return messages
 
     @abstractmethod
     def make_input_batch(self, inputs: List[Dict]) -> Dict:
@@ -70,19 +72,16 @@ def apply_chat_template(
         add_generation_prompt: bool = True,
     ) -> List[str]:
         messages = self._format_messages(messages)
-
+
         return self.processor.apply_chat_template(
             messages, tokenize=tokenize, add_generation_prompt=add_generation_prompt
         )
 
-    def get_images_from_messages(
-        self, messages: Union[Dict, List[str], str]
-    ) -> List[Dict]:
+    def get_images_from_messages(self, messages: Union[Dict, List[str], str]) -> List[Dict]:
         messages = self._format_messages(messages)
         image_inputs, _ = process_vision_info(messages)
         return image_inputs
 
-
     @property
     def pad_token_id(self) -> int:
         return self.processor.tokenizer.pad_token_id
@@ -93,4 +92,4 @@ def eos_token_id(self) -> int:
 
     @property
     def tokenizer(self):
-        return self.processor.tokenizer
\ No newline at end of file
+        return self.processor.tokenizer
diff --git a/openrlhf/models/lmm_kits/base/patch.py b/openrlhf/models/lmm_kits/base/patch.py
index 561250d5..843aa6e2 100644
--- a/openrlhf/models/lmm_kits/base/patch.py
+++ b/openrlhf/models/lmm_kits/base/patch.py
@@ -1,38 +1,40 @@
 from abc import ABC, abstractmethod
 
+
 class BasePatch(ABC):
     def __init__(self):
         self.loaded = False
+
     @abstractmethod
     def _add_get_inputs_embeds():
-        '''
-        Add a `get_inputs_embeds(*args,**kwargs)` method to the model class,
+        """
+        Add a `get_inputs_embeds(*args,**kwargs)` method to the model class,
         which embeds image embeddings into the text embeddings and return the results.
-        '''
+        """
         return NotImplementedError
 
     @abstractmethod
     def _add_get_position_ids():
-        '''
-        Add a `get_posiiton_ids(*args,**kwargs)` method to the model class,
+        """
+        Add a `get_position_ids(*args,**kwargs)` method to the model class,
         which return the position_ids of the given inputs.
-        '''
+        """
        return NotImplementedError

     @abstractmethod
     def _add_offset_split_position_ids():
-        '''
-        Add a `offset_split_position_ids(*args,**kwargs)` method to the model class,
+        """
+        Add a `offset_split_position_ids(*args,**kwargs)` method to the model class,
         which offset the split position_ids to true position_ids.
-        '''
+        """
         return NotImplementedError
 
     @classmethod
     @abstractmethod
     def _load_all_patches(cls):
-        '''
+        """
         Load all patches.
-        '''
+        """
         return NotImplementedError
 
     def load_all_patches(self):
diff --git a/openrlhf/models/lmm_kits/qwen2_5_vl/data_processor.py b/openrlhf/models/lmm_kits/qwen2_5_vl/data_processor.py
index e3ea7d1e..3f3aea31 100644
--- a/openrlhf/models/lmm_kits/qwen2_5_vl/data_processor.py
+++ b/openrlhf/models/lmm_kits/qwen2_5_vl/data_processor.py
@@ -1,5 +1,4 @@
-import os
-from typing import List, Dict
+from typing import Dict, List
 
 import torch
 from qwen_vl_utils import process_vision_info
@@ -20,9 +19,7 @@ def __call__(
     ) -> Dict:
         messages = self._format_messages(messages)
         processor = self.processor
-        texts = processor.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
+        texts = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
         image_inputs, video_inputs = process_vision_info(messages)
 
         batch = processor(
@@ -44,7 +41,7 @@ def make_input_batch(self, inputs: List[Dict]) -> Dict:
         batch = {}
         # collect all keys
         for inp in inputs:
-            batch.update({k:None for k,v in inp.items() if v is not None})
+            batch.update({k: None for k, v in inp.items() if v is not None})
         for k in batch.keys():
             if k in ["input_ids", "attention_mask"]:
                 batch[k] = torch.stack([inp[k] for inp in inputs if k in inp], dim=0)
@@ -67,12 +64,8 @@ def split_input_batch(self, batch: Dict) -> List[Dict]:
             for i in range(batch_size):
                 batch_kwargs[i][k] = None
 
-        if "pixel_values" in keys and (
-            "input_ids" not in keys or "image_grid_thw" not in keys
-        ):
-            raise ValueError(
-                "Cannot split batch with pixel_values without input_ids and image_grid_thw"
-            )
+        if "pixel_values" in keys and ("input_ids" not in keys or "image_grid_thw" not in keys):
+            raise ValueError("Cannot split batch with pixel_values without input_ids and image_grid_thw")
         if "image_grid_thw" in keys and ("input_ids" not in keys):
             raise ValueError("Cannot split batch with image_grid_thw without input_ids")
         for k in ["input_ids", "attention_mask"]:
@@ -114,7 +107,8 @@ def split_input_batch(self, batch: Dict) -> List[Dict]:
         assert len(thws) == 0
         assert len(pixel_values) == 0
         return batch_kwargs
-
+
+
 DataProcessor = Qwen2_5_VLDataProcessor
 
-__all__ = ["DataProcessor"]
\ No newline at end of file
+__all__ = ["DataProcessor"]
diff --git a/openrlhf/models/lmm_kits/qwen2_5_vl/patch.py b/openrlhf/models/lmm_kits/qwen2_5_vl/patch.py
index 83c4a4e5..4c05bb4b 100644
--- a/openrlhf/models/lmm_kits/qwen2_5_vl/patch.py
+++ b/openrlhf/models/lmm_kits/qwen2_5_vl/patch.py
@@ -1,10 +1,21 @@
-from ..base.patch import BasePatch
 import torch
 
+from ..base.patch import BasePatch
+
+
 class Qwen2_5_VLPatch(BasePatch):
     def _add_get_inputs_embeds():
         from transformers import Qwen2_5_VLForConditionalGeneration
-        def get_inputs_embeds(self, input_ids, image_grid_thw=None, video_grid_thw=None, pixel_values=None, pixel_values_videos=None, **kwargs):
+
+        def get_inputs_embeds(
+            self,
+            input_ids,
+            image_grid_thw=None,
+            video_grid_thw=None,
+            pixel_values=None,
+            pixel_values_videos=None,
+            **kwargs,
+        ):
             inputs_embeds = self.model.embed_tokens(input_ids)
             if pixel_values is not None:
                 pixel_values = pixel_values.type(self.visual.dtype)
@@ -47,33 +58,47 @@ def get_inputs_embeds(self, input_ids, image_grid_thw=None, video_grid_thw=None,
 
     def _add_get_position_ids():
         from transformers import Qwen2_5_VLForConditionalGeneration
+
         def get_position_ids(self, input_ids, image_grid_thw=None, video_grid_thw=None, attention_mask=None, **kwargs):
-            position_ids,mrope_position_deltas = self.get_rope_index(input_ids=input_ids, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, attention_mask=attention_mask)
+            position_ids, mrope_position_deltas = self.get_rope_index(
+                input_ids=input_ids,
+                image_grid_thw=image_grid_thw,
+                video_grid_thw=video_grid_thw,
+                attention_mask=attention_mask,
+            )
             return position_ids
+
         Qwen2_5_VLForConditionalGeneration.get_position_ids = get_position_ids
 
     def _add_offset_split_position_ids():
         from transformers import Qwen2_5_VLForConditionalGeneration
-        def offset_split_position_ids(self,position_ids,hacked_position_ids):
+
+        def offset_split_position_ids(self, position_ids, hacked_position_ids):
             new_position_ids = position_ids.clone()
             for i in range(hacked_position_ids.size(0)):
-                seq_idxes = torch.nonzero(hacked_position_ids[i]==0)[:,0]
-                seq_idxes = torch.cat([seq_idxes, torch.tensor([hacked_position_ids.size(1)],device=seq_idxes.device)], dim=0)
+                seq_idxes = torch.nonzero(hacked_position_ids[i] == 0)[:, 0]
+                seq_idxes = torch.cat(
+                    [seq_idxes, torch.tensor([hacked_position_ids.size(1)], device=seq_idxes.device)], dim=0
+                )
                 st = 0
                 for seq_idx in seq_idxes:
                     if st == 0 and seq_idx == 0:
                         continue
-                    #shape: [3,bs,seq_len]
-                    raw_seq_position_ids = position_ids[:,i,st:seq_idx]
-                    new_position_ids[:,i,st:seq_idx] = raw_seq_position_ids - raw_seq_position_ids[:,:1] + hacked_position_ids[i,st]
+                    # shape: [3,bs,seq_len]
+                    raw_seq_position_ids = position_ids[:, i, st:seq_idx]
+                    new_position_ids[:, i, st:seq_idx] = (
+                        raw_seq_position_ids - raw_seq_position_ids[:, :1] + hacked_position_ids[i, st]
+                    )
                     st = seq_idx
             return new_position_ids
+
         Qwen2_5_VLForConditionalGeneration.offset_split_position_ids = offset_split_position_ids
-
+
     @classmethod
     def _load_all_patches(cls):
         cls._add_get_inputs_embeds()
         cls._add_get_position_ids()
         cls._add_offset_split_position_ids()
 
-Patch = Qwen2_5_VLPatch()
\ No newline at end of file
+
+Patch = Qwen2_5_VLPatch()
diff --git a/openrlhf/models/lmm_kits/utils.py b/openrlhf/models/lmm_kits/utils.py
index 1227480e..6cae1454 100644
--- a/openrlhf/models/lmm_kits/utils.py
+++ b/openrlhf/models/lmm_kits/utils.py
@@ -1,18 +1,22 @@
-from transformers import AutoConfig, AutoProcessor, AutoModel
 import importlib
-import os
 
-def _get_kit_root_path(pretrain_or_model=None,model_type=None):
+from transformers import AutoConfig, AutoModel, AutoProcessor
+
+
+def _get_kit_root_path(pretrain_or_model=None, model_type=None):
     if model_type is None:
         config = AutoConfig.from_pretrained(pretrain_or_model)
         model_type = config.model_type
     root_path = f".models.lmm_kits.{model_type}"
     return root_path
 
+
 def _get_hf_processor(pretrain, model, padding_side="left", strategy=None, use_fast=True):
     min_pixels = strategy.args.min_pixels
     max_pixels = strategy.args.max_pixels
-    processor = AutoProcessor.from_pretrained(pretrain, trust_remote_code=True, use_fast=use_fast, min_pixels=min_pixels, max_pixels=max_pixels)
+    processor = AutoProcessor.from_pretrained(
+        pretrain, trust_remote_code=True, use_fast=use_fast, min_pixels=min_pixels, max_pixels=max_pixels
+    )
     tokenizer = processor.tokenizer
     tokenizer.padding_side = padding_side
     # NOTE: When enable vLLM, do not resize_token_embeddings, or the vocab size will mismatch with vLLM.
@@ -23,34 +27,41 @@ def _get_hf_processor(pretrain, model, padding_side="left", strategy=None, use_f
         model.config.pad_token_id = tokenizer.pad_token_id
     return processor
 
+
 def get_data_processor(pretrain_or_model, model, padding_side="left", strategy=None, use_fast=True):
     root_path = _get_kit_root_path(pretrain_or_model)
-    module = importlib.import_module(f"{root_path}.data_processor",package="openrlhf")
+    module = importlib.import_module(f"{root_path}.data_processor", package="openrlhf")
     data_processor_cls = getattr(module, "DataProcessor")
-    hf_processor = _get_hf_processor(pretrain_or_model, model, padding_side, strategy,use_fast=use_fast)
-    data_processor = data_processor_cls(hf_processor,min_pixels=strategy.args.min_pixels,max_pixels=strategy.args.max_pixels)
+    hf_processor = _get_hf_processor(pretrain_or_model, model, padding_side, strategy, use_fast=use_fast)
+    data_processor = data_processor_cls(
+        hf_processor, min_pixels=strategy.args.min_pixels, max_pixels=strategy.args.max_pixels
+    )
     return data_processor
 
-def load_patch(pretrain_or_model=None,model_type=None):
-    root_path = _get_kit_root_path(pretrain_or_model,model_type)
-    module = importlib.import_module(f"{root_path}.patch",package="openrlhf")
+
+def load_patch(pretrain_or_model=None, model_type=None):
+    root_path = _get_kit_root_path(pretrain_or_model, model_type)
+    module = importlib.import_module(f"{root_path}.patch", package="openrlhf")
     Patch = getattr(module, "Patch")
     Patch.load_all_patches()
 
+
 def get_generation_cls(config):
     model_type = config.model_type
     load_patch(model_type=model_type)
     model_arch = AutoModel._model_mapping[type(config)].__name__
-    if model_arch.endswith("ForCausalLM") or \
-        model_arch.endswith("ForConditionalGeneration"):
+    if model_arch.endswith("ForCausalLM") or model_arch.endswith("ForConditionalGeneration"):
         return AutoModel._model_mapping[type(config)]
     elif model_arch.endswith("Model"):
-        possible_arch = [model_arch.replace("Model", "ForCausalLM"), model_arch.replace("Model", "ForConditionalGeneration")]
-        module = importlib.import_module(f".models.{model_type}.modeling_{model_type}",package="transformers")
+        possible_arch = [
+            model_arch.replace("Model", "ForCausalLM"),
+            model_arch.replace("Model", "ForConditionalGeneration"),
+        ]
+        module = importlib.import_module(f".models.{model_type}.modeling_{model_type}", package="transformers")
        for arch in possible_arch:
            model_cls = getattr(module, arch, None)
            if model_cls is not None:
                return model_cls
        raise ValueError(f"Cannot find ForCausalLM or ForConditionalGeneration class for {model_arch}")
     else:
-        raise ValueError(f"Unexpected model architecture {model_arch}")
\ No newline at end of file
+        raise ValueError(f"Unexpected model architecture {model_arch}")
diff --git a/openrlhf/models/model.py b/openrlhf/models/model.py
index 6b7dd054..3477fe0b 100644
--- a/openrlhf/models/model.py
+++ b/openrlhf/models/model.py
@@ -6,14 +6,14 @@
 from flash_attn.utils.distributed import all_gather
 from peft import LoraConfig, get_peft_model
 from peft.tuners.lora import LoraLayer
-from transformers import AutoConfig, AutoModel, BitsAndBytesConfig
+from transformers import AutoConfig, BitsAndBytesConfig
 from transformers.integrations.deepspeed import HfDeepSpeedConfig
 
+from openrlhf.models.lmm_kits.utils import get_generation_cls
 from openrlhf.utils.logging_utils import init_logger
 
-from .ring_attn_utils import convert_ring_attn_params, set_hacked_position_ids, clear_hacked_position_ids
+from .ring_attn_utils import clear_hacked_position_ids, convert_ring_attn_params, set_hacked_position_ids
 from .utils import reset_position_ids
-from openrlhf.models.lmm_kits.utils import get_generation_cls
 
 logger = init_logger(__name__)
@@ -193,18 +193,25 @@ def forward(
         inputs_embeds = super().get_inputs_embeds(input_ids, **visual_inputs)
         if not self.packing_samples:
             # https://github.com/OpenRLHF/OpenRLHF/issues/217
-            #position_ids = attention_mask.long().cumsum(-1) - 1
-            #position_ids.masked_fill_(attention_mask == 0, 1)
-            position_ids = super().get_position_ids(input_ids,attention_mask=attention_mask, **visual_inputs)
+            # position_ids = attention_mask.long().cumsum(-1) - 1
+            # position_ids.masked_fill_(attention_mask == 0, 1)
+            position_ids = super().get_position_ids(input_ids, attention_mask=attention_mask, **visual_inputs)
         else:
             # convert attention_mask to position_ids
             packed_position_ids = super().get_position_ids(input_ids, **visual_inputs)
             if ring_attn_group is not None:
-                input_ids, attention_mask, hacked_position_ids, inputs_embeds, split_position_ids = convert_ring_attn_params(
-                    input_ids, attention_mask, packed_seq_lens, ring_attn_group, inputs_embeds, packed_position_ids
+                input_ids, attention_mask, hacked_position_ids, inputs_embeds, split_position_ids = (
+                    convert_ring_attn_params(
+                        input_ids,
+                        attention_mask,
+                        packed_seq_lens,
+                        ring_attn_group,
+                        inputs_embeds,
+                        packed_position_ids,
+                    )
                 )
                 position_ids = super().offset_split_position_ids(split_position_ids, hacked_position_ids)
-
+
             else:
                 hacked_position_ids = reset_position_ids(attention_mask)
                 position_ids = super().offset_split_position_ids(packed_position_ids, hacked_position_ids)
@@ -214,7 +221,11 @@ def forward(
             attention_mask = None
 
         outputs = super().forward(
-            inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids,output_hidden_states=True, **visual_inputs
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            output_hidden_states=True,
+            **visual_inputs,
         )
         clear_hacked_position_ids()
         if "last_hidden_state" in outputs:
@@ -288,17 +299,24 @@ def forward(
             # https://github.com/OpenRLHF/OpenRLHF/issues/217
             # position_ids = attention_mask.long().cumsum(-1) - 1
             # position_ids.masked_fill_(attention_mask == 0, 1)
-            position_ids = super().get_position_ids(input_ids,attention_mask=attention_mask, **visual_inputs)
+            position_ids = super().get_position_ids(input_ids, attention_mask=attention_mask, **visual_inputs)
         else:
             # convert attention_mask to position_ids
             packed_position_ids = super().get_position_ids(input_ids, **visual_inputs)
             if ring_attn_group is not None:
-
-                input_ids, attention_mask, hacked_position_ids, inputs_embeds, split_position_ids = convert_ring_attn_params(
-                    input_ids, attention_mask, packed_seq_lens, ring_attn_group, inputs_embeds, packed_position_ids
+
+                input_ids, attention_mask, hacked_position_ids, inputs_embeds, split_position_ids = (
+                    convert_ring_attn_params(
+                        input_ids,
+                        attention_mask,
+                        packed_seq_lens,
+                        ring_attn_group,
+                        inputs_embeds,
+                        packed_position_ids,
+                    )
                 )
                 position_ids = super().offset_split_position_ids(split_position_ids, hacked_position_ids)
-
+
             else:
                 hacked_position_ids = reset_position_ids(attention_mask)
                 position_ids = super().offset_split_position_ids(packed_position_ids, hacked_position_ids)
@@ -308,7 +326,11 @@ def forward(
             attention_mask = None
 
         outputs = super().forward(
-            inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids,output_hidden_states=True, **visual_inputs
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            output_hidden_states=True,
+            **visual_inputs,
         )
         clear_hacked_position_ids()
         if "last_hidden_state" in outputs:
diff --git a/openrlhf/models/remote_rm/math_verifier.py b/openrlhf/models/remote_rm/math_verifier.py
index 232e0619..d39ce27f 100644
--- a/openrlhf/models/remote_rm/math_verifier.py
+++ b/openrlhf/models/remote_rm/math_verifier.py
@@ -1,17 +1,15 @@
 import json
-import os
 import random
 import re
 from argparse import ArgumentParser
-from multiprocessing import Process, Queue
+from concurrent import futures
 
 import Levenshtein
 from flask import Flask, jsonify, request
 from latex2sympy2_extended import NormalizationConfig
+from loguru import logger
 from math_verify import LatexExtractionConfig, parse, verify
-from loguru import logger
-from concurrent import futures
 
 app = Flask(__name__)
 problem_to_answer = {}
@@ -40,7 +38,6 @@ def verify_format(content):
     return bool(re.match(format_pattern, content, re.DOTALL)) and think_count == 1 and answer_count == 1
 
-
 def find_similar_problem(problem):
     max_sim = -1
     target_problem = None
@@ -52,7 +49,7 @@ def find_similar_problem(problem):
     return target_problem
 
-def verify_math(content,sol):
+def verify_math(content, sol):
     gold_parsed = parse(
         sol,
         extraction_mode="first_match",
@@ -102,7 +99,7 @@ def get_reward():
     rewards = []
     format_rewards = []
     acc_rewards_futures = []
-    for q,problem in zip(data["query"],data["prompts"]):
+    for q, problem in zip(data["query"], data["prompts"]):
         if problem is None:
             return jsonify({"error": f"problem not found from {q}"}), 400
         if problem not in problem_to_answer:
@@ -115,39 +112,33 @@ def get_reward():
             return jsonify({"error": f"response not found from {q}"}), 400
         format_reward = float(verify_format(response)) * 0.5
         acc_reward_future = math_verify_executor.submit(verify_math, response, answer)
-
+
         do_print = random.randint(1, 20) == 1
         if do_print:
-            info=f"Query: {q}\n\nProblem: {problem}\n\n Answer: {answer}\n\n Response: {response}\n\n Format Reward: {format_reward}\n\n Acc Reward: {acc_reward_future.result()}\n\n"
-            info = re.sub(r"<\|.*?\|>","",info)
+            info = f"Query: {q}\n\nProblem: {problem}\n\n Answer: {answer}\n\n Response: {response}\n\n Format Reward: {format_reward}\n\n Acc Reward: {acc_reward_future.result()}\n\n"
+            info = re.sub(r"<\|.*?\|>", "", info)
             logger.info(info)
-
+
         format_rewards.append(format_reward)
         acc_rewards_futures.append(acc_reward_future)
     acc_rewards = [f.result() for f in acc_rewards_futures]
     rewards = [f + a for f, a in zip(format_rewards, acc_rewards)]
     # Return the response containing the rewards
-    return jsonify({"rewards": rewards,"format_rewards":format_rewards,"acc_rewards":acc_rewards})
+    return jsonify({"rewards": rewards, "format_rewards": format_rewards, "acc_rewards": acc_rewards})
 
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument(
-        "--dataset", type=str, default=None, help="Datasets to use (comma separated)", required=True
-    )
-    parser.add_argument(
-        "--prompt-template", type=str, default=None, help="Prompt template", required=True
-    )
-    parser.add_argument(
-        "--input_key", type=str, default="prompt", help="The key name of prompt."
-    )
+    parser.add_argument("--dataset", type=str, default=None, help="Datasets to use (comma separated)", required=True)
+    parser.add_argument("--prompt-template", type=str, default=None, help="Prompt template", required=True)
+    parser.add_argument("--input_key", type=str, default="prompt", help="The key name of prompt.")
     parser.add_argument("--log_file", type=str, default="remote_rm.log", help="Log file path")
     args = parser.parse_args()
     logger.remove()
     logger.add(args.log_file)
 
     # Split dataset paths and load all datasets
     dataset = []
-    for dataset_path in args.dataset.split(','):
+    for dataset_path in args.dataset.split(","):
         dataset_path = dataset_path.strip()
         if dataset_path.endswith("json"):
             with open(dataset_path, "r") as f:
@@ -160,13 +151,13 @@ def get_reward():
 
     format_pattern = r"^<think>(?:(?!</think>).)*</think><answer>(?:(?!</answer>).)*</answer>\Z"
 
-    if args.prompt_template=="chatml":
+    if args.prompt_template == "chatml":
         problem_pattern = r"<\|im_start\|>user\n(.*?)<\|im_end\|>"
         response_prefix = r"<\|im_start\|>assistant\n"
-    elif args.prompt_template=="qwen1":
+    elif args.prompt_template == "qwen1":
         problem_pattern = r"<|User|>(.*?)<|Assistant|>"
         response_prefix = r"<|Assistant|>"
-    elif args.prompt_template=="base":
+    elif args.prompt_template == "base":
         problem_pattern = r"User: (.*?)\n\nAssistant:"
         response_prefix = r"Assistant: "
     else:
@@ -184,4 +175,4 @@ def get_reward():
 
     math_verify_executor = futures.ProcessPoolExecutor(max_workers=16)
     app.run(host="0.0.0.0", port=5000, debug=False, use_reloader=False)
-    math_verify_executor.shutdown()
\ No newline at end of file
+    math_verify_executor.shutdown()
diff --git a/openrlhf/models/ring_attn_utils.py b/openrlhf/models/ring_attn_utils.py
index 4fd0eadf..df32b722 100644
--- a/openrlhf/models/ring_attn_utils.py
+++ b/openrlhf/models/ring_attn_utils.py
@@ -3,7 +3,6 @@
 import torch.nn.functional as F
 import transformers
 
-
 RING_ATTN_GROUP = None
 
@@ -71,7 +70,7 @@ def convert_ring_attn_params(sequences, attention_mask, packed_seq_lens, ring_at
     sequences = sequences[:, start:end]
     inputs_embeds = inputs_embeds[:, start:end]
     attention_mask = attention_mask[:, start:end]
-    position_ids = position_ids[..., start:end] #qwen2_5_vl has position_ids shape: [3,bs,seq_len]
+    position_ids = position_ids[..., start:end]  # qwen2_5_vl has position_ids shape: [3,bs,seq_len]
     hacked_position_ids = reset_ring_attn_position_ids(start, end, packed_seq_lens)
     update_ring_attn_params(packed_seq_lens, total_seq_len)
     return sequences, attention_mask, hacked_position_ids, inputs_embeds, position_ids
@@ -125,27 +124,33 @@ def unpad_sequences(
         kl = kl[:, :-pad_len]
     return sequences, attention_mask, num_actions, packed_seq_lens, action_log_probs, values, kl
 
+
 HACKED_POSITION_IDS = None
-#Both ring and our hack substitute flash_attn. This func must be called after ring's substitue_hf_flash_attn.
+
+# Both ring and our hack substitute flash_attn. This func must be called after ring's substitute_hf_flash_attn.
 def substitute_ring_flash_attn():
     raw_flash_attention_forward = transformers.modeling_flash_attention_utils._flash_attention_forward
-    def _hacked_flash_attention_forward(*args,**kwargs):
+
+    def _hacked_flash_attention_forward(*args, **kwargs):
         global HACKED_POSITION_IDS
         if HACKED_POSITION_IDS is not None:
-            kwargs['position_ids'] = HACKED_POSITION_IDS
-        return raw_flash_attention_forward(*args,**kwargs)
+            kwargs["position_ids"] = HACKED_POSITION_IDS
+        return raw_flash_attention_forward(*args, **kwargs)
 
     transformers.modeling_flash_attention_utils._flash_attention_forward = _hacked_flash_attention_forward
 
+
 def set_hacked_position_ids(position_ids):
     global HACKED_POSITION_IDS
     HACKED_POSITION_IDS = position_ids
 
+
 def clear_hacked_position_ids():
     global HACKED_POSITION_IDS
     HACKED_POSITION_IDS = None
 
+
 def get_hacked_position_ids():
     global HACKED_POSITION_IDS
     return HACKED_POSITION_IDS
diff --git a/openrlhf/trainer/ppo_trainer.py b/openrlhf/trainer/ppo_trainer.py
index c0d7dd32..ffaad07e 100644
--- a/openrlhf/trainer/ppo_trainer.py
+++ b/openrlhf/trainer/ppo_trainer.py
@@ -4,8 +4,8 @@
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
-import torch.nn as nn
 import torch.distributed as dist
+import torch.nn as nn
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 from tqdm import tqdm
@@ -107,7 +107,6 @@ def __init__(
 
         self.tokenizer = data_processor.tokenizer
         self.processor = data_processor.processor
-
         self.generate_kwargs = generate_kwargs
         self.dataloader_pin_memory = dataloader_pin_memory
         self.max_norm = max_norm
@@ -159,8 +158,12 @@ def __init__(
         )
         packing_samples = getattr(self.args, "packing_samples", False)
         self.replay_buffer = NaiveReplayBuffer(
-            micro_train_batch_size, self.data_processor, buffer_limit, buffer_cpu_offload, packing_samples,
-            drop_maxlen=self.args.drop_maxlen,
+            micro_train_batch_size,
+            self.data_processor,
+            buffer_limit,
+            buffer_cpu_offload,
+            packing_samples,
+            drop_maxlen=self.args.drop_maxlen,
             maxlen=self.args.generate_max_len + prompt_max_len,
         )
 
@@ -359,11 +362,7 @@ def training_step_actor(self, experience: Experience) -> Dict[str, float]:
             # pad seq makes the sequence a multiple of ring_attention_size.
             if self.strategy.ring_attn_group is not None:
                 pad_len, sequences, attention_mask, num_actions, packed_seq_lens = pad_sequences(
-                    sequences,
-                    attention_mask,
-                    num_actions,
-                    packed_seq_lens,
-                    self.strategy.ring_attn_group
+                    sequences, attention_mask, num_actions, packed_seq_lens, self.strategy.ring_attn_group
                 )
             if self.args.use_kl_loss and experience.base_action_log_probs is not None:
                 base_action_log_probs = torch.cat(experience.base_action_log_probs, dim=0).unsqueeze(0)
@@ -387,7 +386,7 @@ def training_step_actor(self, experience: Experience) -> Dict[str, float]:
                 ring_attn_group=self.strategy.ring_attn_group,
                 logps_allgather=True,
                 packed_seq_lens=packed_seq_lens,
-                visual_inputs=visual_inputs
+                visual_inputs=visual_inputs,
             )
             # unpad sequence ensures that pad tokens do not contribute to the loss calculation.
             if self.strategy.ring_attn_group is not None:
@@ -501,11 +500,7 @@ def training_step_critic(self, experience: Experience) -> Dict[str, float]:
             # pad seq makes the sequence len a multiple of ring_attention_size.
             if self.strategy.ring_attn_group is not None:
                 pad_len, sequences, attention_mask, num_actions, packed_seq_lens = pad_sequences(
-                    sequences,
-                    attention_mask,
-                    num_actions,
-                    packed_seq_lens,
-                    self.strategy.ring_attn_group
+                    sequences, attention_mask, num_actions, packed_seq_lens, self.strategy.ring_attn_group
                 )
 
             else:
@@ -584,6 +579,7 @@ def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict={}, c
                 if self.experience_maker.perf_stats is not None:
                     logs.update({f"perf/experience_maker/{k}": v for k, v in self.experience_maker.perf_stats.items()})
                 from wandb import Histogram
+
                 response_length_list = Histogram(response_length_list)
                 logs["response_length_dist"] = response_length_list
                 self._wandb.log(logs)
diff --git a/openrlhf/trainer/ppo_utils/experience_maker.py b/openrlhf/trainer/ppo_utils/experience_maker.py
index 52456921..b797dfff 100644
--- a/openrlhf/trainer/ppo_utils/experience_maker.py
+++ b/openrlhf/trainer/ppo_utils/experience_maker.py
@@ -1,22 +1,18 @@
-import os
 import time
 from abc import ABC
 from copy import deepcopy
 from dataclasses import dataclass, field
-from typing import List, Optional, Tuple, Union, Dict
+from typing import Dict, List, Optional, Tuple, Union
 
 import ray
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-import torch.distributed as dist
-import torch.nn.functional as F
 from tqdm import tqdm
 
 from openrlhf.models.actor import Actor
 from openrlhf.models.ring_attn_utils import pad_sequences, unpad_sequences
 from openrlhf.models.utils import compute_approx_kl, compute_reward, masked_mean, unpacking_samples
-from openrlhf.models.ring_attn_utils import pad_sequences, unpad_sequences
 from openrlhf.utils.logging_utils import init_logger
 from openrlhf.utils.remote_rm_utils import remote_rm_fn, remote_rm_fn_ray
@@ -182,7 +178,6 @@ def __init__(
             spec.loader.exec_module(reward_module)
             self.custom_reward_func = reward_module.reward_func
 
-
     @torch.no_grad()
     def make_experience_list(
         self, all_prompts: Union[str, List[str]], all_labels, **generate_kwargs
@@ -298,14 +293,13 @@ def generate_samples(self, all_prompts: List[str], all_labels, **generate_kwargs
         samples_list = []
         for i in range(0, len(all_prompts), args.micro_rollout_batch_size):
             prompts = all_prompts[i : i + args.micro_rollout_batch_size]
-
+
             inputs = self.data_processor(prompts, self.prompt_max_len, device="cuda")
             visual_inputs = {}
-            for k,v in inputs.items():
+            for k, v in inputs.items():
                 if k not in ["input_ids", "attention_mask"]:
                     visual_inputs[k] = v
-
             labels = all_labels[i : i + args.micro_rollout_batch_size]
             sequences, attention_mask, action_mask = self.actor.generate(**inputs, **generate_kwargs)
             samples = Samples(
@@ -349,7 +343,9 @@ def make_experience(self, samples: Samples) -> Experience:
 
         # init log probs
         if self.initial_model is not None:
-            base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask, visual_inputs=visual_inputs)
+            base_action_log_probs = self.initial_model(
+                sequences, num_actions, attention_mask, visual_inputs=visual_inputs
+            )
         else:
             base_action_log_probs = None
 
@@ -385,7 +381,7 @@ def make_experience(self, samples: Samples) -> Experience:
         else:
             kl = torch.zeros_like(action_log_probs, dtype=action_log_probs.dtype, device=action_log_probs.device)
 
-        assert isinstance(r,dict)
+        assert isinstance(r, dict)
         total_reward = r.pop("rewards")
         specific_rewards = r
 
@@ -395,7 +391,7 @@ def make_experience(self, samples: Samples) -> Experience:
             "response_length": samples.response_length,
             "total_length": samples.total_length,
"num_actions": num_actions, - **specific_rewards + **specific_rewards, } # reset model state self.actor.train() @@ -413,7 +409,7 @@ def make_experience(self, samples: Samples) -> Experience: action_mask, info, kl, - visual_inputs=visual_inputs + visual_inputs=visual_inputs, ) @torch.no_grad() @@ -621,17 +617,17 @@ def make_experience(self, samples: Samples) -> Experience: ) visual_inputs_cpu = None if visual_inputs is not None: - visual_inputs_cpu = {k: v.to("cpu") for k, v in visual_inputs.items()} + visual_inputs_cpu = {k: v.to("cpu") for k, v in visual_inputs.items()} # init log probs if self.initial_model is not None: base_action_log_probs_ref = self.initial_model.forward.remote( - sequences_cpu, - num_actions, - attention_mask_cpu, - logps_allgather=True, - packed_seq_lens=packed_seq_lens, - visual_inputs=visual_inputs_cpu - ) + sequences_cpu, + num_actions, + attention_mask_cpu, + logps_allgather=True, + packed_seq_lens=packed_seq_lens, + visual_inputs=visual_inputs_cpu, + ) if args.colocate_actor_ref or args.colocate_all_models: ray.get([base_action_log_probs_ref]) @@ -642,7 +638,11 @@ def make_experience(self, samples: Samples) -> Experience: # values if self.critic is not None: value_ref = self.critic.forward.remote( - sequences_cpu, num_actions, attention_mask_cpu, packed_seq_lens=packed_seq_lens, visual_inputs=visual_inputs_cpu + sequences_cpu, + num_actions, + attention_mask_cpu, + packed_seq_lens=packed_seq_lens, + visual_inputs=visual_inputs_cpu, ) # avoid CUDA OOM when colocate models if args.colocate_critic_reward or args.colocate_all_models: @@ -658,7 +658,11 @@ def make_experience(self, samples: Samples) -> Experience: for rm in self.reward_model: r_refs.append( rm.forward.remote( - sequences_cpu, attention_mask_cpu, packed_seq_lens=packed_seq_lens, pad_sequence=True, visual_inputs=visual_inputs_cpu + sequences_cpu, + attention_mask_cpu, + packed_seq_lens=packed_seq_lens, + pad_sequence=True, + visual_inputs=visual_inputs_cpu, ) ) else: @@ -688,13 +692,13 @@ def make_experience(self, samples: Samples) -> Experience: # log probs action_log_probs = self.actor( - sequences, - num_actions, - attention_mask, + sequences, + num_actions, + attention_mask, ring_attn_group=self.strategy.ring_attn_group, logps_allgather=True, packed_seq_lens=packed_seq_lens, - visual_inputs=visual_inputs + visual_inputs=visual_inputs, ) actor_value_rm_time = time.time() - start @@ -709,10 +713,10 @@ def make_experience(self, samples: Samples) -> Experience: if value is not None: value = value.to(device) - total_rewards = [r.pop('rewards').to(device) if isinstance(r,dict) else r.to(device) for r in rewards] + total_rewards = [r.pop("rewards").to(device) if isinstance(r, dict) else r.to(device) for r in rewards] specific_rewards = {} for r in rewards: - if isinstance(r,dict): + if isinstance(r, dict): for k in r.keys(): r[k] = r[k].to(device) specific_rewards.update(r) @@ -774,7 +778,7 @@ def make_experience(self, samples: Samples) -> Experience: "response_length": samples.response_length, "total_length": samples.total_length, "num_actions": num_actions, - **specific_rewards + **specific_rewards, } if self.strategy.args.perf: @@ -792,7 +796,7 @@ def make_experience(self, samples: Samples) -> Experience: action_mask, info, kl, - visual_inputs=visual_inputs + visual_inputs=visual_inputs, ) self.actor.train() # reset model state @@ -800,6 +804,7 @@ def make_experience(self, samples: Samples) -> Experience: def _generate_vllm(self, all_prompts: List[str], all_labels, **kwargs) -> List[Samples]: from vllm 
+
         # round-robin load balance
         rank = torch.distributed.get_rank() // self.strategy.ring_attn_size
         world_size = torch.distributed.get_world_size() // self.strategy.ring_attn_size
@@ -835,19 +840,21 @@ def _generate_vllm(self, all_prompts: List[str], all_labels, **kwargs) -> List[S
             if messages:
                 prompts = self.data_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                 images = [self.data_processor.get_images_from_messages(m) for m in messages]
-                vllm_inputs = [{
+                vllm_inputs = [
+                    {
                         "prompt": p,
-                        "multi_modal_data":{"image": imgs} if imgs else None,
+                        "multi_modal_data": {"image": imgs} if imgs else None,
                         "mm_processor_kwargs": {
-                            "min_pixels": kwargs.get("min_pixels", 4*28*28),
-                            "max_pixels": kwargs.get("max_pixels", 640*28*28),
+                            "min_pixels": kwargs.get("min_pixels", 4 * 28 * 28),
+                            "max_pixels": kwargs.get("max_pixels", 640 * 28 * 28),
                         },
-                } for p, imgs in zip(prompts,images)]
+                    }
+                    for p, imgs in zip(prompts, images)
+                ]
                 refs.append(
                     llm.add_requests.remote(rank, sampling_params=sampling_params, vllm_vision_input=vllm_inputs)
                 )
-
         ray.get(refs)
 
         # Make sure all requests are sent.
@@ -901,7 +908,7 @@ def _generate_vllm(self, all_prompts: List[str], all_labels, **kwargs) -> List[S
             attention_mask = attention_mask.to("cuda")
             action_mask = action_mask.to("cuda")
             # Collect for visual input
-
+
             visual_inputs = self.data_processor(prompts, self.prompt_max_len, device="cuda")
             visual_inputs.pop("input_ids")
             visual_inputs.pop("attention_mask")
diff --git a/openrlhf/trainer/ppo_utils/replay_buffer.py b/openrlhf/trainer/ppo_utils/replay_buffer.py
index 481644d2..3c9f31bf 100644
--- a/openrlhf/trainer/ppo_utils/replay_buffer.py
+++ b/openrlhf/trainer/ppo_utils/replay_buffer.py
@@ -6,9 +6,10 @@
 import torch
 import torch.nn.functional as F
 
+from openrlhf.models.lmm_kits.base.data_processor import BaseDataProcessor
+
 from .experience_maker import Experience
-from openrlhf.models.lmm_kits.base.data_processor import BaseDataProcessor
+
 
 @dataclass
 class BufferItem:
@@ -64,15 +65,14 @@ def split_experience_batch(experience: Experience, data_processor: Optional[Base
         assert batch_size == len(vals)
         for i, v in enumerate(vals):
             batch_kwargs[i][key] = v
-
+
     visual_inputs_batch = experience.visual_inputs
-    visual_inputs_batch['input_ids'] = experience.sequences
+    visual_inputs_batch["input_ids"] = experience.sequences
     visual_inputs_chunks = data_processor.split_input_batch(visual_inputs_batch)
     for i, visual_inputs in enumerate(visual_inputs_chunks):
-        visual_inputs.pop('input_ids')
+        visual_inputs.pop("input_ids")
         batch_kwargs[i]["visual_inputs"] = visual_inputs
 
-
     for i in range(batch_size):
         batch_kwargs[i]["info"] = {}
         for k, v in experience.info.items():
@@ -99,7 +99,9 @@ def zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left") -> tor
     return torch.stack(padded_sequences, dim=0)
 
-def make_experience_batch(items: List[BufferItem], data_processor: Optional[BaseDataProcessor], packing_samples=False) -> Experience:
+def make_experience_batch(
+    items: List[BufferItem], data_processor: Optional[BaseDataProcessor], packing_samples=False
+) -> Experience:
     kwargs = {}
     keys = (
         "sequences",
@@ -123,7 +125,7 @@ def make_experience_batch(items: List[BufferItem], data_processor: Optional[Base
     for key in items[0].info.keys():
         vals = torch.tensor([item.info[key] for item in items])
         kwargs["info"][key] = vals
-
+
     kwargs["visual_inputs"] = data_processor.make_input_batch([item.visual_inputs for item in items])
     return Experience(**kwargs)
 
@@ -177,11 +179,11 @@ class NaiveReplayBuffer(ABC):
     """
 
     def __init__(
-        self, 
-        sample_batch_size: int, 
-        data_processor: Optional[BaseDataProcessor] = None, 
-        limit: int = 0, 
-        cpu_offload: bool = True, 
+        self,
+        sample_batch_size: int,
+        data_processor: Optional[BaseDataProcessor] = None,
+        limit: int = 0,
+        cpu_offload: bool = True,
         packing_samples: bool = False,
         drop_maxlen: bool = False,
         maxlen: int = 10**8,
diff --git a/openrlhf/trainer/ray/launcher.py b/openrlhf/trainer/ray/launcher.py
index 9dece250..aecac635 100644
--- a/openrlhf/trainer/ray/launcher.py
+++ b/openrlhf/trainer/ray/launcher.py
@@ -5,7 +5,6 @@
 
 import ray
 import torch
-import torch.distributed as dist
 from ray.util.placement_group import PlacementGroup, placement_group
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
@@ -95,7 +94,7 @@ def forward(
             visual_inputs = {}
         device = torch.cuda.current_device()
         with torch.no_grad():
-            visual_inputs = {k:v.to(device) for k,v in visual_inputs.items()}
+            visual_inputs = {k: v.to(device) for k, v in visual_inputs.items()}
             log_probs = self.model(
                 sequences.to(device),
                 num_actions,
@@ -148,7 +147,7 @@ def forward(
         device = torch.cuda.current_device()
         if visual_inputs is None:
             visual_inputs = {}
-        visual_inputs = {k:v.to(device) for k,v in visual_inputs.items()}
+        visual_inputs = {k: v.to(device) for k, v in visual_inputs.items()}
         with torch.no_grad():
             reward = self.model(
                 sequences.to(device),
diff --git a/openrlhf/trainer/ray/ppo_actor.py b/openrlhf/trainer/ray/ppo_actor.py
index 982ce991..96f3c5c1 100644
--- a/openrlhf/trainer/ray/ppo_actor.py
+++ b/openrlhf/trainer/ray/ppo_actor.py
@@ -324,10 +324,12 @@ def init_model_from_pretrained(self, strategy: DeepspeedStrategy, pretrain):
                 if any(name.startswith(prefix) for prefix in strategy.args.freeze_prefix):
                     param.requires_grad = False
                     frozen_count += 1
-            strategy.print(f"Froze {frozen_count}/{total_params} parameters based on prefixes: {strategy.args.freeze_prefix}")
+            strategy.print(
+                f"Froze {frozen_count}/{total_params} parameters based on prefixes: {strategy.args.freeze_prefix}"
+            )
 
         # configure tokenizer
-
+
         self.data_processor = get_data_processor(
             pretrain, actor.model, "left", strategy, use_fast=not strategy.args.disable_fast_tokenizer
         )
diff --git a/openrlhf/trainer/ray/ppo_critic.py b/openrlhf/trainer/ray/ppo_critic.py
index b95f638b..6d56b787 100644
--- a/openrlhf/trainer/ray/ppo_critic.py
+++ b/openrlhf/trainer/ray/ppo_critic.py
@@ -9,9 +9,9 @@
 from transformers.trainer import get_scheduler
 
 from openrlhf.models import get_llm_for_sequence_regression
+from openrlhf.models.lmm_kits.utils import get_data_processor
 from openrlhf.trainer import PPOTrainer
 from openrlhf.trainer.ppo_utils import Experience
-from openrlhf.models.lmm_kits.utils import get_data_processor
 from openrlhf.utils.deepspeed import DeepspeedStrategy
 from openrlhf.utils.deepspeed.deepspeed_utils import offload_deepspeed_states, reload_deepspeed_states
 
@@ -151,7 +151,7 @@ def init_model_from_pretrained(self, strategy: DeepspeedStrategy, pretrain, max_
             prompt_max_len=args.prompt_max_len,
             value_clip=args.value_clip,
             eps_clip=args.eps_clip,
-            data_processor=self.data_processor
+            data_processor=self.data_processor,
         )
 
     def forward(
@@ -207,7 +207,6 @@ def save_model(self):
             args.save_path + "_critic",
         )
 
-
     def save_checkpoint(self, tag):
         args = self.strategy.args
         self.strategy.save_ckpt(
diff --git a/openrlhf/utils/__init__.py b/openrlhf/utils/__init__.py
index aec56564..60803c08 100644
--- a/openrlhf/utils/__init__.py
+++ b/openrlhf/utils/__init__.py
@@ -1,10 +1,4 @@
 from .processor import get_processor, reward_normalization
 from .utils import blending_datasets, get_strategy, get_tokenizer
 
-__all__ = [
-    "get_processor",
-    "reward_normalization",
-    "blending_datasets",
-    "get_strategy",
-    "get_tokenizer"
-]
+__all__ = ["get_processor", "reward_normalization", "blending_datasets", "get_strategy", "get_tokenizer"]
diff --git a/openrlhf/utils/distributed_sampler.py b/openrlhf/utils/distributed_sampler.py
index 1f765820..df4ddd99 100644
--- a/openrlhf/utils/distributed_sampler.py
+++ b/openrlhf/utils/distributed_sampler.py
@@ -6,7 +6,6 @@
 from torch.utils.data.dataset import Dataset
 from torch.utils.data.sampler import Sampler
 
-
 __all__ = ["DistributedSampler"]
 
diff --git a/openrlhf/utils/logging_utils.py b/openrlhf/utils/logging_utils.py
index eb39f39a..7b19681e 100644
--- a/openrlhf/utils/logging_utils.py
+++ b/openrlhf/utils/logging_utils.py
@@ -1,6 +1,7 @@
 # Adapted from
 # https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
 """Logging configuration for vLLM."""
+
 import logging
 import sys
 
diff --git a/openrlhf/utils/remote_rm_utils.py b/openrlhf/utils/remote_rm_utils.py
index bff5a85f..dd33523d 100644
--- a/openrlhf/utils/remote_rm_utils.py
+++ b/openrlhf/utils/remote_rm_utils.py
@@ -37,7 +37,7 @@ def remote_rm_fn(api_url, queries, prompts, labels, score_key="rewards"):
         score_key: RM score key
     """
     responses = request_api_wrapper(api_url, {"query": queries, "prompts": prompts, "labels": labels}, score_key)
-    return {k:torch.tensor(v) for k,v in responses.items()}
+    return {k: torch.tensor(v) for k, v in responses.items()}
 
 
 @ray.remote
@@ -49,4 +49,4 @@ def remote_rm_fn_ray(api_url, queries, prompts, labels, score_key="rewards"):
     # test utils
     url = "http:xxx/get_rm_score"
     score = remote_rm_fn(url, ["example query"], ["example response"])
-    print(score)
\ No newline at end of file
+    print(score)
diff --git a/openrlhf/utils/utils.py b/openrlhf/utils/utils.py
index 08b0b160..6698ef88 100644
--- a/openrlhf/utils/utils.py
+++ b/openrlhf/utils/utils.py
@@ -3,6 +3,7 @@
 from datasets import interleave_datasets, load_dataset, load_from_disk
 from transformers import AutoTokenizer
 
+
 def get_tokenizer(pretrain, model, padding_side="left", strategy=None, use_fast=True):
     tokenizer = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True, use_fast=use_fast)
     tokenizer.padding_side = padding_side
@@ -126,4 +127,4 @@ def convert_token_to_id(token, tokenizer):
         assert len(token) == 1
         return token[0]
     else:
-        raise ValueError("token should be int or str")
\ No newline at end of file
+        raise ValueError("token should be int or str")
diff --git a/requirements.txt b/requirements.txt
index 03f9ae27..441e8f28 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,9 +6,10 @@ einops
 flask
 isort
 jsonlines
+levenshtein
+loguru
 loralib
 math-verify==0.5.2
-levenshtein
 optimum
 packaging
 peft
@@ -16,11 +17,10 @@ pynvml>=12.0.0
 qwen_vl_utils
 tensorboard
 torch==2.5.1
-torchvision
 torchmetrics
+torchvision
 tqdm
 transformers @ git+https://github.com/huggingface/transformers@2ab7bdc40333b230b642f09e8334fb8e1a92d2a4
 transformers_stream_generator
 wandb
 wheel
-loguru
\ No newline at end of file
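
For reference, a minimal sketch of the <think>/<answer> format gate that
openrlhf/models/remote_rm/math_verifier.py enforces. The regex is the
format_pattern from the patch above; the sample responses are illustrative
assumptions, not strings taken from the repository.

import re

format_pattern = r"^<think>(?:(?!</think>).)*</think><answer>(?:(?!</answer>).)*</answer>\Z"

def verify_format(content: str) -> bool:
    # Exactly one <think>...</think> block immediately followed by one
    # <answer>...</answer> block, with nothing before, between, or after them.
    think_count = content.count("<think>")
    answer_count = content.count("<answer>")
    return bool(re.match(format_pattern, content, re.DOTALL)) and think_count == 1 and answer_count == 1

assert verify_format("<think>2 + 2 = 4</think><answer>4</answer>")
assert not verify_format("<answer>4</answer>")  # missing the <think> block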