From 92b476f2c367d52c65f9c6f3513e4d78b23f9f53 Mon Sep 17 00:00:00 2001
From: Sunil Kumar
Date: Wed, 19 Mar 2025 19:18:19 -0700
Subject: [PATCH 01/11] handled completion messages and dict of outputs properly. prototype for completion masking

---
 trl/trainer/qwen_grpo_trainer.py | 43 +++++++++++++++++++++++++++-----
 1 file changed, 37 insertions(+), 6 deletions(-)

diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py
index 6730793810d..bb13921dead 100644
--- a/trl/trainer/qwen_grpo_trainer.py
+++ b/trl/trainer/qwen_grpo_trainer.py
@@ -621,16 +621,21 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
                 if self.env is None:
                     raise ValueError("No environment provided. Only supporting envs now.")
                 else:
-                    completion_ids = self.env.generate(
+                    generated_output = self.env.generate(
                         conversations=all_conversations,
                         vlm_inputs=all_env_inputs,
                         vlm=self.vlm,
                         sampling_params=self.sampling_params,
                     )
+                    completion_ids = generated_output['ids']
+                    completion_messages = generated_output.get('messages', None)
+                    completion_mask = generated_output.get('completion_mask', None)
+
+
             else:
                 completion_ids = [None] * len(all_env_inputs)
-
+                completion_messages = [None] * len(all_env_inputs)
             # Broadcast the completions from the main process to all processes, ensuring each process receives its
             # corresponding slice.
             completion_ids = broadcast_object_list(completion_ids, from_process=0)
@@ -640,17 +645,42 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
             )
             completion_ids = completion_ids[process_slice]

-            eos_idx = torch.tensor([len(ids) - 1 for ids in completion_ids], device=device)
-            # Pad completion_ids to uniform length, mask from last output token (EOS)
             completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
             completion_ids = pad(completion_ids, padding_value=self.processing_class.tokenizer.pad_token_id)
-            sequence_indices = torch.arange(completion_ids.size(1), device=device).expand(completion_ids.size(0), -1)
-            completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
+
+            # Handle completion mask: broadcast from main process to all processes if available
+            if completion_mask is not None:
+                # Broadcast the completion_mask from the main process to all processes
+                completion_mask = broadcast_object_list(completion_mask, from_process=0)
+
+                # Each process takes its corresponding slice based on process index
+                process_slice = slice(
+                    self.accelerator.process_index * len(inputs),
+                    (self.accelerator.process_index + 1) * len(inputs),
+                )
+                completion_mask = completion_mask[process_slice]
+
+                # Convert mask elements to tensors and move to correct device
+                completion_mask = [torch.tensor(mask, device=device) for mask in completion_mask]
+                # Pad masks to uniform length
+                completion_mask = pad(completion_mask, padding_value=0)
+            else:
+                print("No completion mask provided. Computing mask based on EOS positions.")
+                # Fallback: compute mask based on EOS positions if not provided
+                eos_idx = torch.tensor([len(ids) - 1 for ids in completion_ids], device=device)
+                sequence_indices = torch.arange(completion_ids.size(1), device=device).expand(completion_ids.size(0), -1)
+                completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
+
             prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
         else:
             raise ValueError("Attempted to generate with HF. Only supporting vllm now.")

+        if self.accelerator.is_main_process:
+            print("SHAPE CHECK")
+            print(completion_ids[0].shape)
+            print(completion_mask[0].shape)
+
         # Concatenate prompt_mask with completion_mask for logit computation
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B*G, P+C)
@@ -717,6 +747,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
                 keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
                 reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
                 reward_kwargs["prompts_text"] = prompts_text
+                reward_kwargs["completions_messages"] = completion_messages

                 output_reward_func = reward_func(prompts=conversations, completions=completions, **reward_kwargs)
                 rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
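[Note] Patch 01 changes the contract with the environment: env.generate() is now expected to return a dict rather than a bare list of token-id lists. A minimal sketch of that return value as this patch reads it, with invented example data; only the key names ('ids', 'messages', 'completion_mask') come from the patch, and patch 02 below renames the mask key to 'mask':

    # hypothetical output of self.env.generate(...) for a batch of two prompts
    generated_output = {
        "ids": [[151644, 872, 198], [151644, 872, 198, 13]],   # one token-id list per prompt (required)
        "messages": [                                           # optional rollout messages
            [{"role": "assistant", "content": [{"type": "text", "text": "..."}]}],
            [{"role": "assistant", "content": [{"type": "text", "text": "..."}]}],
        ],
        "completion_mask": [[1, 1, 1], [1, 1, 0, 0]],            # optional per-token mask (0 = not trained on)
    }

    completion_ids = generated_output["ids"]
    completion_messages = generated_output.get("messages", None)
    completion_mask = generated_output.get("completion_mask", None)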
From 38a35dafd4346932f5ed8ef3fe5090d40ddee723 Mon Sep 17 00:00:00 2001
From: Sunil Kumar
Date: Wed, 19 Mar 2025 19:41:35 -0700
Subject: [PATCH 02/11] masking seems to be working properly? at least it's masking out the proper tokens. Still need to see a full step.

---
 trl/trainer/qwen_grpo_trainer.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py
index bb13921dead..2c502fe0d9d 100644
--- a/trl/trainer/qwen_grpo_trainer.py
+++ b/trl/trainer/qwen_grpo_trainer.py
@@ -630,12 +630,14 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
                     completion_ids = generated_output['ids']
                     completion_messages = generated_output.get('messages', None)
-                    completion_mask = generated_output.get('completion_mask', None)
+                    completion_mask = generated_output.get('mask', None)

             else:
                 completion_ids = [None] * len(all_env_inputs)
                 completion_messages = [None] * len(all_env_inputs)
+                completion_mask = [None] * len(all_env_inputs)
+
             # Broadcast the completions from the main process to all processes, ensuring each process receives its
             # corresponding slice.
             completion_ids = broadcast_object_list(completion_ids, from_process=0)
@@ -676,11 +678,6 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
         else:
             raise ValueError("Attempted to generate with HF. Only supporting vllm now.")

-        if self.accelerator.is_main_process:
-            print("SHAPE CHECK")
-            print(completion_ids[0].shape)
-            print(completion_mask[0].shape)
-
         # Concatenate prompt_mask with completion_mask for logit computation
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B*G, P+C)

From 6aa33a01a403ef51e9316ac263606ff53531e732 Mon Sep 17 00:00:00 2001
From: Sunil Kumar
Date: Wed, 19 Mar 2025 20:53:50 -0700
Subject: [PATCH 03/11] Don't pop when doing bootstrap combination. Added broadcasting for completion messages and ensured proper slicing. Introduced error handling for unsupported reward function types.

---
 trl/trainer/qwen_grpo_trainer.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py
index 2c502fe0d9d..c4dfaaca043 100644
--- a/trl/trainer/qwen_grpo_trainer.py
+++ b/trl/trainer/qwen_grpo_trainer.py
@@ -651,16 +651,16 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
             completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
             completion_ids = pad(completion_ids, padding_value=self.processing_class.tokenizer.pad_token_id)

+            # broadcast and slice completion messages too.
+            completion_messages = broadcast_object_list(completion_messages, from_process=0)
+            completion_messages = completion_messages[process_slice]
+
             # Handle completion mask: broadcast from main process to all processes if available
             if completion_mask is not None:
                 # Broadcast the completion_mask from the main process to all processes
                 completion_mask = broadcast_object_list(completion_mask, from_process=0)

                 # Each process takes its corresponding slice based on process index
-                process_slice = slice(
-                    self.accelerator.process_index * len(inputs),
-                    (self.accelerator.process_index + 1) * len(inputs),
-                )
                 completion_mask = completion_mask[process_slice]

                 # Convert mask elements to tensors and move to correct device
@@ -709,7 +709,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
         if is_conversational(inputs[0]):
             completions = []
             for prompt, completion in zip(conversations, completions_text):
-                bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
+                bootstrap = prompt[-1]["content"] if prompt[-1]["role"] == "assistant" else ""
                 if isinstance(bootstrap, list):
                     if len(bootstrap) > 1:
                         raise ValueError("Only one bootstrap is supported for now.")
@@ -724,6 +724,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
             zip(self.reward_funcs, self.reward_processing_classes)
         ):
             if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
+                raise NotImplementedError("Models as reward functions are not supported yet.")
                 if is_conversational(inputs[0]):
                     messages = [{"messages": p + c} for p, c in zip(conversations, completions)]
                     texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
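[Note] Patches 02-03 rely on the broadcast-then-slice pattern used throughout this trainer: only the main process runs vLLM for the whole global batch, every rank then receives the full object list and keeps the slice matching its local inputs. A self-contained sketch of that pattern (accelerate's broadcast_object_list is the real API; the accelerator handle, batch sizes and data are illustrative):

    from accelerate import Accelerator
    from accelerate.utils import broadcast_object_list

    accelerator = Accelerator()
    local_batch_size = 4
    global_batch_size = local_batch_size * accelerator.num_processes

    # only rank 0 holds real rollouts; other ranks pass placeholders of the same length
    rollouts = [[1, 2, 3]] * global_batch_size if accelerator.is_main_process else [None] * global_batch_size

    # after the broadcast every rank holds the identical full list ...
    rollouts = broadcast_object_list(rollouts, from_process=0)

    # ... and keeps only its own contiguous slice
    process_slice = slice(
        accelerator.process_index * local_batch_size,
        (accelerator.process_index + 1) * local_batch_size,
    )
    rollouts = rollouts[process_slice]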
From 2c10fb3b963a5cbf3fa7c7dc557c6f9d0929c2d7 Mon Sep 17 00:00:00 2001
From: ROIM1998
Date: Tue, 25 Mar 2025 16:54:48 -0700
Subject: [PATCH 04/11] handle the new images added to the message by the environment (tools)

---
 trl/trainer/qwen_grpo_trainer.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py
index c4dfaaca043..0d556f18e2a 100644
--- a/trl/trainer/qwen_grpo_trainer.py
+++ b/trl/trainer/qwen_grpo_trainer.py
@@ -675,6 +675,33 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
                 completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

             prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+
+            # Handle the potential new images generated from the environment (tool) in completion_messages
+            new_images = []
+            for i, completion_message in enumerate(completion_messages):
+                if completion_message is not None:
+                    for message in completion_message:
+                        for content in message["content"]:
+                            if content.get("type", None) == "image":
+                                new_images.append(content["image"])
+
+            if len(new_images) > 0:
+                # use the processor to get pixel_values and image_grid_thw for the new images
+                new_images_info = self.processing_class(
+                    text='',
+                    images=new_images,
+                    return_tensors='pt',
+                    padding=True,
+                )
+                new_pixel_values = new_images_info["pixel_values"]
+                new_image_grid_thw = new_images_info["image_grid_thw"]
+
+                # Concatenate the new pixel_values and image_grid_thw with the existing ones
+                # make sure pixel_values and new_pixel_values are on the same device. same for image_grid_thw and new_image_grid_thw
+                new_pixel_values = new_pixel_values.to(device)
+                new_image_grid_thw = new_image_grid_thw.to(device)
+                pixel_values = torch.cat([pixel_values, new_pixel_values], dim=0)
+                image_grid_thw = torch.cat([image_grid_thw, new_image_grid_thw], dim=0)
         else:
             raise ValueError("Attempted to generate with HF. Only supporting vllm now.")
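[Note] Patch 04 walks completion_messages looking for image content blocks that a tool call added during the rollout. A sketch of the message shape that loop assumes (Qwen2-VL style content lists; the exact roles and fields are an assumption about the environment, not taken from its code):

    from PIL import Image

    completion_message = [
        {"role": "assistant", "content": [{"type": "text", "text": "Let me crop the region."}]},
        {
            "role": "tool",
            "content": [
                {"type": "image", "image": Image.new("RGB", (224, 224))},  # image produced by the tool
                {"type": "text", "text": "cropped view"},
            ],
        },
    ]

    # same collection logic as the patch, written as a comprehension
    new_images = [
        content["image"]
        for message in completion_message
        for content in message["content"]
        if content.get("type", None) == "image"
    ]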
From c9ea32f721e610d72bedcbdcdd03481197cb5284 Mon Sep 17 00:00:00 2001
From: ROIM1998
Date: Tue, 25 Mar 2025 16:55:44 -0700
Subject: [PATCH 05/11] make the limitation of num images/videos per prompt configurable

---
 trl/trainer/qwen_grpo_trainer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py
index 0d556f18e2a..5725b81d312 100644
--- a/trl/trainer/qwen_grpo_trainer.py
+++ b/trl/trainer/qwen_grpo_trainer.py
@@ -228,6 +228,8 @@ def __init__(
         optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
         peft_config: Optional["PeftConfig"] = None,
         shuffle_dataset: bool = True,
+        limit_image_per_prompt: int = 1,
+        limit_video_per_prompt: int = 0,
     ):
         # Args
         if args is None:
@@ -439,7 +441,7 @@ def data_collator(features):  # No data collation is needed in GRPO
             enable_prefix_caching=True,
             max_model_len=self.args.vllm_max_model_len,
             # Setting this to 1 as we only have one image per prompt for now. Setting it longer requires more resources, which is wasteful until we need it.
-            limit_mm_per_prompt={"image": 1, "video": 0},
+            limit_mm_per_prompt={"image": limit_image_per_prompt, "video": limit_video_per_prompt},
         )
         self.sampling_params = SamplingParams(
             temperature=args.temperature,

From 58db45b1a6eb37cbff6ab4524b6b36569766784b Mon Sep 17 00:00:00 2001
From: ROIM1998
Date: Tue, 25 Mar 2025 17:04:19 -0700
Subject: [PATCH 06/11] move the limitation args to GRPOConfig

---
 trl/trainer/grpo_config.py       | 8 ++++++++
 trl/trainer/qwen_grpo_trainer.py | 4 +---
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index cd6cc917489..ac2ee2f9e40 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -249,3 +249,11 @@ class GRPOConfig(TrainingArguments):
         default=False,
         metadata={"help": "Whether to log the completions during training."},
     )
+    limit_image_per_prompt: int = field(
+        default=1,
+        metadata={"help": "Limit the number of images per prompt for vllm generation."},
+    )
+    limit_video_per_prompt: int = field(
+        default=0,
+        metadata={"help": "Limit the number of videos per prompt for vllm generation."},
+    )
\ No newline at end of file
diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py
index 5725b81d312..a9638c63bc1 100644
--- a/trl/trainer/qwen_grpo_trainer.py
+++ b/trl/trainer/qwen_grpo_trainer.py
@@ -228,8 +228,6 @@ def __init__(
         optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
         peft_config: Optional["PeftConfig"] = None,
         shuffle_dataset: bool = True,
-        limit_image_per_prompt: int = 1,
-        limit_video_per_prompt: int = 0,
     ):
         # Args
         if args is None:
@@ -441,7 +439,7 @@ def data_collator(features):  # No data collation is needed in GRPO
             enable_prefix_caching=True,
             max_model_len=self.args.vllm_max_model_len,
             # Setting this to 1 as we only have one image per prompt for now. Setting it longer requires more resources, which is wasteful until we need it.
-            limit_mm_per_prompt={"image": limit_image_per_prompt, "video": limit_video_per_prompt},
+            limit_mm_per_prompt={"image": self.args.limit_image_per_prompt, "video": self.args.limit_video_per_prompt},
         )
         self.sampling_params = SamplingParams(
             temperature=args.temperature,
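[Note] After patches 05-06 the per-prompt media limits live on GRPOConfig, so they travel with the rest of the training arguments instead of being trainer-constructor keywords. A usage sketch (the two field names come from the patch; everything else is a placeholder):

    from trl import GRPOConfig

    config = GRPOConfig(
        output_dir="qwen-grpo-tools",
        limit_image_per_prompt=2,   # allow up to two images per prompt in vLLM
        limit_video_per_prompt=0,   # keep videos disabled
    )
    # inside the trainer this feeds:
    # limit_mm_per_prompt={"image": self.args.limit_image_per_prompt, "video": self.args.limit_video_per_prompt}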
From 6dd8939280d3ec636316a0211ad5e89977c01df7 Mon Sep 17 00:00:00 2001
From: ROIM1998
Date: Thu, 3 Apr 2025 12:28:13 -0700
Subject: [PATCH 07/11] add logs for reward_per_func

---
 trl/trainer/qwen_grpo_trainer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py
index a9638c63bc1..74d06ce0d29 100644
--- a/trl/trainer/qwen_grpo_trainer.py
+++ b/trl/trainer/qwen_grpo_trainer.py
@@ -840,6 +840,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
                 "prompt": gather_object(prompts_text),
                 "completion": gather_object(completions_text),
                 "reward": rewards.tolist(),
+                "reward_per_func": rewards_per_func.tolist(),
             }
             df = pd.DataFrame(table)

From 2521d2ea6fb1d0a250aa1a21be80947e7d4865a7 Mon Sep 17 00:00:00 2001
From: ROIM1998
Date: Thu, 3 Apr 2025 12:47:27 -0700
Subject: [PATCH 08/11] add the number of image_pad_ids in logging for debug purposes

---
 trl/trainer/qwen_grpo_trainer.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py
index 74d06ce0d29..f52eb059540 100644
--- a/trl/trainer/qwen_grpo_trainer.py
+++ b/trl/trainer/qwen_grpo_trainer.py
@@ -228,6 +228,7 @@ def __init__(
         optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
         peft_config: Optional["PeftConfig"] = None,
         shuffle_dataset: bool = True,
+        image_pad_id: int = 151655,
     ):
         # Args
         if args is None:
@@ -478,6 +479,8 @@ def data_collator(features):  # No data collation is needed in GRPO
             if isinstance(reward_func, PreTrainedModel):
                 self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)

+        self.image_pad_id = image_pad_id
+
     def _set_signature_columns_if_needed(self):
         # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
         # By default, this method sets `self._signature_columns` to the model's expected inputs.
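[Note] Patch 08 counts image placeholder tokens (id 151655, the Qwen2-VL <|image_pad|> token) in the completions as a debugging signal. The reduction axis matters: with completion_ids of shape (B*G, C), .sum(dim=1) yields one count per completion while .sum(dim=0) counts per position; patches 10-11 below keep reworking exactly this reduction. A toy shape check (numbers invented):

    import torch

    image_pad_id = 151655
    completion_ids = torch.tensor([
        [151655, 151655, 42, 0],   # completion with two image pads
        [7, 151655, 9, 0],         # completion with one image pad
    ])  # shape (B*G, C) = (2, 4)

    per_completion = (completion_ids == image_pad_id).sum(dim=1)  # tensor([2, 1])
    per_position = (completion_ids == image_pad_id).sum(dim=0)    # tensor([1, 2, 0, 0])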
@@ -733,6 +736,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s

         # Decode the generated completions
         completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
+        num_image_pad_ids = (completion_ids == self.image_pad_id).sum(dim=1)
         if is_conversational(inputs[0]):
             completions = []
             for prompt, completion in zip(conversations, completions_text):
@@ -841,6 +845,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
                 "reward": rewards.tolist(),
                 "reward_per_func": rewards_per_func.tolist(),
+                "num_image_pad_ids": num_image_pad_ids.tolist(),
             }
             df = pd.DataFrame(table)

From 030179dad31511adad4c6de651a672aa7fb79e0e Mon Sep 17 00:00:00 2001
From: ROIM1998
Date: Thu, 3 Apr 2025 12:52:18 -0700
Subject: [PATCH 09/11] add the option to add inputs to be logged

---
 trl/trainer/qwen_grpo_trainer.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py
index f52eb059540..593eaa04297 100644
--- a/trl/trainer/qwen_grpo_trainer.py
+++ b/trl/trainer/qwen_grpo_trainer.py
@@ -229,6 +229,7 @@ def __init__(
         peft_config: Optional["PeftConfig"] = None,
         shuffle_dataset: bool = True,
         image_pad_id: int = 151655,
+        inputs_to_log: list[str] = [],
     ):
         # Args
         if args is None:
@@ -480,6 +481,7 @@ def data_collator(features):  # No data collation is needed in GRPO
             self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)

         self.image_pad_id = image_pad_id
+        self.inputs_to_log = inputs_to_log

     def _set_signature_columns_if_needed(self):
         # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
@@ -839,6 +841,12 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
             import pandas as pd

             # For logging
+            inputs_data_to_log = {key: gather_object(inputs[key]) for key in self.inputs_to_log}
+            # if the value is torch.Tensor, convert it to a list
+            for key, value in inputs_data_to_log.items():
+                if isinstance(value, torch.Tensor):
+                    inputs_data_to_log[key] = value.tolist()
+
             table = {
                 "step": [str(self.state.global_step)] * len(rewards),
                 "prompt": gather_object(prompts_text),
@@ -846,6 +854,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
                 "reward": rewards.tolist(),
                 "reward_per_func": rewards_per_func.tolist(),
                 "num_image_pad_ids": num_image_pad_ids.tolist(),
+                **inputs_data_to_log,
             }
             df = pd.DataFrame(table)

From 1ae78857e1c5a11c0f1156158f70980d10307772 Mon Sep 17 00:00:00 2001
From: ROIM1998
Date: Thu, 3 Apr 2025 13:18:42 -0700
Subject: [PATCH 10/11] fix type and shape issues

---
 trl/trainer/qwen_grpo_trainer.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py
index 593eaa04297..0c04b2c4a8b 100644
--- a/trl/trainer/qwen_grpo_trainer.py
+++ b/trl/trainer/qwen_grpo_trainer.py
@@ -738,7 +738,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
         # Decode the generated completions
         completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
-        num_image_pad_ids = (completion_ids == self.image_pad_id).sum(dim=1)
+        num_image_pad_ids = (completion_ids == self.image_pad_id).sum(dim=0)
         if is_conversational(inputs[0]):
             completions = []
             for prompt, completion in zip(conversations, completions_text):
@@ -841,7 +841,11 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
             import pandas as pd

             # For logging
-            inputs_data_to_log = {key: gather_object(inputs[key]) for key in self.inputs_to_log}
+            inputs_data_to_log = {
+                key: gather_object(
+                    [i[key] for i in inputs if key in i]
+                ) for key in self.inputs_to_log
+            }
             # if the value is torch.Tensor, convert it to a list
             for key, value in inputs_data_to_log.items():
                 if isinstance(value, torch.Tensor):
                     inputs_data_to_log[key] = value.tolist()
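[Note] The type fix in patch 10 exists because `inputs` is a list of per-example dicts, so `inputs[key]` raises TypeError; indexing each example and skipping the ones that lack the key also makes optional columns safe to log. A single-process toy (gather_object dropped for brevity; data invented):

    inputs = [
        {"prompt": "p0", "ground_truth": "a"},
        {"prompt": "p1"},                      # this example lacks the optional column
    ]
    inputs_to_log = ["ground_truth"]

    # inputs["ground_truth"]  # would raise: TypeError: list indices must be integers or slices, not str
    inputs_data_to_log = {
        key: [example[key] for example in inputs if key in example]
        for key in inputs_to_log
    }
    # {'ground_truth': ['a']}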
From 671de07f700569f4f535df3bc6d18193482862e2 Mon Sep 17 00:00:00 2001
From: ROIM1998
Date: Thu, 3 Apr 2025 13:59:50 -0700
Subject: [PATCH 11/11] fix handling the num_image_pad_ids calculation

---
 trl/trainer/qwen_grpo_trainer.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py
index 0c04b2c4a8b..d0e2efca84f 100644
--- a/trl/trainer/qwen_grpo_trainer.py
+++ b/trl/trainer/qwen_grpo_trainer.py
@@ -738,7 +738,6 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s

         # Decode the generated completions
         completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
-        num_image_pad_ids = (completion_ids == self.image_pad_id).sum(dim=0)
         if is_conversational(inputs[0]):
             completions = []
             for prompt, completion in zip(conversations, completions_text):
@@ -851,13 +850,19 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
                 if isinstance(value, torch.Tensor):
                     inputs_data_to_log[key] = value.tolist()

+            # gather completion_ids and get num_image_pad_ids
+            # completion_ids shape: (B*G, C) B is batch size, G is number of generations, C is completion length
+            gathered_completion_ids = gather_object(completion_ids)
+            # after gathering, there will be B*G items and each item is a tensor of shape their own (C,)
+            # handle each item one by one
+            num_image_pad_ids = [(ids == self.image_pad_id).sum().item() for ids in gathered_completion_ids]
             table = {
                 "step": [str(self.state.global_step)] * len(rewards),
                 "prompt": gather_object(prompts_text),
                 "completion": gather_object(completions_text),
                 "reward": rewards.tolist(),
                 "reward_per_func": rewards_per_func.tolist(),
-                "num_image_pad_ids": num_image_pad_ids.tolist(),
+                "num_image_pad_ids": num_image_pad_ids,
                 **inputs_data_to_log,
             }
             df = pd.DataFrame(table)
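[Note] Taken together, the series leaves three new knobs: image_pad_id and inputs_to_log on the trainer constructor (patches 08-10) and the per-prompt media limits on GRPOConfig (patch 06), while the environment must return the dict described under patch 01. A construction sketch under those assumptions (the class name QwenGRPOTrainer, the reward callables and the remaining arguments are placeholders for this fork, not the upstream TRL API):

    config = GRPOConfig(
        output_dir="qwen-grpo-tools",
        limit_image_per_prompt=2,
        limit_video_per_prompt=0,
    )

    trainer = QwenGRPOTrainer(
        model=model,
        args=config,
        reward_funcs=[format_reward, correctness_reward],  # plain callables; nn.Module rewards now raise NotImplementedError
        train_dataset=train_dataset,
        processing_class=processor,
        env=tool_env,                     # generate() must return {'ids', 'messages', 'mask'}
        image_pad_id=151655,              # Qwen2-VL <|image_pad|>
        inputs_to_log=["ground_truth"],   # extra dataset columns copied into the logged completions table
    )
    trainer.train()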