diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index cd6cc917489..ac2ee2f9e40 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -249,3 +249,11 @@ class GRPOConfig(TrainingArguments): default=False, metadata={"help": "Whether to log the completions during training."}, ) + limit_image_per_prompt: int = field( + default=1, + metadata={"help": "Limit the number of images per prompt for vllm generation."}, + ) + limit_video_per_prompt: int = field( + default=0, + metadata={"help": "Limit the number of videos per prompt for vllm generation."}, + ) \ No newline at end of file diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 6730793810d..d0e2efca84f 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -228,6 +228,8 @@ def __init__( optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), peft_config: Optional["PeftConfig"] = None, shuffle_dataset: bool = True, + image_pad_id: int = 151655, + inputs_to_log: list[str] = [], ): # Args if args is None: @@ -439,7 +441,7 @@ def data_collator(features): # No data collation is needed in GRPO enable_prefix_caching=True, max_model_len=self.args.vllm_max_model_len, # Setting this to 1 as we only have one image per prompt for now. Setting it longer requires more resources, which is wasteful until we need it. - limit_mm_per_prompt={"image": 1, "video": 0}, + limit_mm_per_prompt={"image": self.args.limit_image_per_prompt, "video": self.args.limit_video_per_prompt}, ) self.sampling_params = SamplingParams( temperature=args.temperature, @@ -478,6 +480,9 @@ def data_collator(features): # No data collation is needed in GRPO if isinstance(reward_func, PreTrainedModel): self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True) + self.image_pad_id = image_pad_id + self.inputs_to_log = inputs_to_log + def _set_signature_columns_if_needed(self): # If `self.args.remove_unused_columns` is True, non-signature columns are removed. # By default, this method sets `self._signature_columns` to the model's expected inputs. @@ -621,15 +626,22 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s if self.env is None: raise ValueError("No environment provided. Only supporting envs now.") else: - completion_ids = self.env.generate( + generated_output = self.env.generate( conversations=all_conversations, vlm_inputs=all_env_inputs, vlm=self.vlm, sampling_params=self.sampling_params, ) + completion_ids = generated_output['ids'] + completion_messages = generated_output.get('messages', None) + completion_mask = generated_output.get('mask', None) + + else: completion_ids = [None] * len(all_env_inputs) + completion_messages = [None] * len(all_env_inputs) + completion_mask = [None] * len(all_env_inputs) # Broadcast the completions from the main process to all processes, ensuring each process receives its # corresponding slice. @@ -640,14 +652,61 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s ) completion_ids = completion_ids[process_slice] - eos_idx = torch.tensor([len(ids) - 1 for ids in completion_ids], device=device) - # Pad completion_ids to uniform length, mask from last output token (EOS) completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] completion_ids = pad(completion_ids, padding_value=self.processing_class.tokenizer.pad_token_id) - sequence_indices = torch.arange(completion_ids.size(1), device=device).expand(completion_ids.size(0), -1) - completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + + # broadcast and slice completion messages too. + completion_messages = broadcast_object_list(completion_messages, from_process=0) + completion_messages = completion_messages[process_slice] + + # Handle completion mask: broadcast from main process to all processes if available + if completion_mask is not None: + # Broadcast the completion_mask from the main process to all processes + completion_mask = broadcast_object_list(completion_mask, from_process=0) + + # Each process takes its corresponding slice based on process index + completion_mask = completion_mask[process_slice] + + # Convert mask elements to tensors and move to correct device + completion_mask = [torch.tensor(mask, device=device) for mask in completion_mask] + # Pad masks to uniform length + completion_mask = pad(completion_mask, padding_value=0) + else: + print("No completion mask provided. Computing mask based on EOS positions.") + # Fallback: compute mask based on EOS positions if not provided + eos_idx = torch.tensor([len(ids) - 1 for ids in completion_ids], device=device) + sequence_indices = torch.arange(completion_ids.size(1), device=device).expand(completion_ids.size(0), -1) + completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) + + # Handle the potential new images generated from the environment (tool) in completion_messages + new_images = [] + for i, completion_message in enumerate(completion_messages): + if completion_message is not None: + for message in completion_message: + for content in message["content"]: + if content.get("type", None) == "image": + new_images.append(content["image"]) + + if len(new_images) > 0: + # use the processor to get pixel_values and image_grid_thw for the new images + new_images_info = self.processing_class( + text='', + images=new_images, + return_tensors='pt', + padding=True, + ) + new_pixel_values = new_images_info["pixel_values"] + new_image_grid_thw = new_images_info["image_grid_thw"] + + # Concatenate the new pixel_values and image_grid_thw with the existing ones + # make sure pixel_values and new_pixel_values are on the same device. same for image_grid_thw and new_image_grid_thw + new_pixel_values = new_pixel_values.to(device) + new_image_grid_thw = new_image_grid_thw.to(device) + pixel_values = torch.cat([pixel_values, new_pixel_values], dim=0) + image_grid_thw = torch.cat([image_grid_thw, new_image_grid_thw], dim=0) else: raise ValueError("Attempted to generate with HF. Only supporting vllm now.") @@ -682,7 +741,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s if is_conversational(inputs[0]): completions = [] for prompt, completion in zip(conversations, completions_text): - bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" + bootstrap = prompt[-1]["content"] if prompt[-1]["role"] == "assistant" else "" if isinstance(bootstrap, list): if len(bootstrap) > 1: raise ValueError("Only one bootstrap is supported for now.") @@ -697,6 +756,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s zip(self.reward_funcs, self.reward_processing_classes) ): if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models + raise NotImplementedError("Models as reward functions are not supported yet.") if is_conversational(inputs[0]): messages = [{"messages": p + c} for p, c in zip(conversations, completions)] texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] @@ -717,6 +777,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s keys = [key for key in inputs[0] if key not in ["prompt", "completion"]] reward_kwargs = {key: [example[key] for example in inputs] for key in keys} reward_kwargs["prompts_text"] = prompts_text + reward_kwargs["completions_messages"] = completion_messages output_reward_func = reward_func(prompts=conversations, completions=completions, **reward_kwargs) rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) @@ -779,11 +840,30 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s import pandas as pd # For logging + inputs_data_to_log = { + key: gather_object( + [i[key] for i in inputs if key in i] + ) for key in self.inputs_to_log + } + # if the value is torch.Tensor, convert it to a list + for key, value in inputs_data_to_log.items(): + if isinstance(value, torch.Tensor): + inputs_data_to_log[key] = value.tolist() + + # gather completion_ids and get num_image_pad_ids + # completion_ids shape: (B*G, C) B is batch size, G is number of generations, C is completion length + gathered_completion_ids = gather_object(completion_ids) + # after gathering, there will be B*G items and each item is a tensor of shape their own(C,) + # handle each item one by one + num_image_pad_ids = [(ids == self.image_pad_id).sum().item() for ids in gathered_completion_ids] table = { "step": [str(self.state.global_step)] * len(rewards), "prompt": gather_object(prompts_text), "completion": gather_object(completions_text), "reward": rewards.tolist(), + "reward_per_func": rewards_per_func.tolist(), + "num_image_pad_ids": num_image_pad_ids, + **inputs_data_to_log, } df = pd.DataFrame(table)