Description
I've been trying out training, and I hit the same issue whether I train on 2 or 4 GPUs. I've tried all the training scripts, and they all fail in a similar way.
When running training on the digit recognition environment, the process fails with a KeyError raised when completion_ids is indexed with process_slice.
Error Details
  File "src/r1_vlm/environments/digit_recognition_env/train.py", line 91, in <module>
    trainer.train()
  File "site-packages/transformers/trainer.py", line 2241, in train
    return inner_training_loop(
  File "site-packages/transformers/trainer.py", line 2548, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "site-packages/transformers/trainer.py", line 3692, in training_step
    inputs = self._prepare_inputs(inputs)
  File "site-packages/trl/trl/trainer/qwen_grpo_trainer.py", line 641, in _prepare_inputs
    completion_ids = completion_ids[process_slice]
KeyError: slice(0, 30, None)
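For context: in upstream TRL's GRPOTrainer, which QwenGRPOTrainer appears to be adapted from (an assumption based on the file path), process_slice selects the current process's block of the completions gathered from all processes, roughly process_index * n up to (process_index + 1) * n. The sketch below uses a hypothetical helper, process_slice_for, to show that arithmetic. Under that assumption, slice(0, 30, None) would be process 0 asking for the first 30 entries. The slice itself looks well-formed, which points at the container being indexed rather than at the slice.

```python
def process_slice_for(process_index: int, num_local: int) -> slice:
    """Hypothetical helper: the block of a gathered, process-ordered list owned by one process.

    Assumes every process contributed exactly `num_local` items and that the
    gathered list is ordered by process index.
    """
    start = process_index * num_local
    return slice(start, start + num_local)


# With 30 items per process, process 0 owns slice(0, 30) and process 1 owns slice(30, 60).
print(process_slice_for(0, 30))  # slice(0, 30, None)
print(process_slice_for(1, 30))  # slice(30, 60, None)

# Applied to a plain list, the slice selects that process's share as intended.
gathered = list(range(60))
print(gathered[process_slice_for(1, 30)][:3])  # [30, 31, 32]
```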
Steps to Reproduce
- Run the training script for the digit recognition environment:
  CUDA_VISIBLE_DEVICES=0,1 uv run src/r1_vlm/environments/digit_recognition_env/train.py
- The error is raised during training, in the QwenGRPOTrainer._prepare_inputs method.
A similar failure occurs with the digits tool-use environment on 4 GPUs:
CUDA_VISIBLE_DEVICES=0,1,2,3 uv run accelerate launch --config_file src/r1_vlm/deepspeed_configs/multi_gpu_3only.yaml src/r1_vlm/environments/digits_tool_use/train.py
Additional Context
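Based on the error, completion_ids appears to be a mapping (such as a dict) rather than a list or tensor. Indexing a mapping with a slice treats the slice as a lookup key, so Python raises KeyError: slice(0, 30, None) instead of returning the first 30 elements. This may indicate an issue with how completion_ids is built or accessed in the QwenGRPOTrainer implementation.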
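A minimal, self-contained sketch of that behavior, using a hypothetical describe_slicing helper and made-up token IDs: slicing a list works, while "slicing" a dict performs a key lookup. Slice objects have been hashable since Python 3.12, so on 3.12+ the dict case raises KeyError with the slice as the missing key, matching the traceback. On older versions the same line raises TypeError: unhashable type: 'slice'.

```python
def describe_slicing(container, s: slice) -> str:
    """Try container[s] and report either the result or the exception raised."""
    try:
        return f"ok: {container[s]!r}"
    except (KeyError, TypeError) as err:
        return f"{type(err).__name__}: {err}"


completion_list = [[1, 2], [3, 4], [5, 6]]
completion_dict = {"ids": completion_list}  # hypothetical dict-shaped output

print(describe_slicing(completion_list, slice(0, 2)))  # ok: [[1, 2], [3, 4]]
# Python 3.12+: "KeyError: slice(0, 2, None)"; older: "TypeError: unhashable type: 'slice'"
print(describe_slicing(completion_dict, slice(0, 2)))
```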
Possible Solution
The issue appears to be in the _prepare_inputs method in qwen_grpo_trainer.py, where completion_ids is indexed with process_slice. It is worth checking whether completion_ids is actually a dictionary rather than a sequence that supports slicing, or whether the process_slice value is invalid.
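One way to confirm the first hypothesis is to guard the failing line so it reports the actual type and keys. The sketch below is a hypothetical diagnostic (slice_for_process is not part of TRL or this repo) that would stand in for `completion_ids[process_slice]`. It does not fix the root cause. It only replaces the opaque KeyError with a message showing what completion_ids actually contains.

```python
from collections.abc import Mapping
from typing import Any


def slice_for_process(completion_ids: Any, process_slice: slice) -> Any:
    """Hypothetical guard around `completion_ids[process_slice]`.

    Raises a descriptive TypeError when completion_ids is a mapping
    (dict, UserDict, etc.) instead of a list or tensor.
    """
    if isinstance(completion_ids, Mapping):
        raise TypeError(
            f"completion_ids is a {type(completion_ids).__name__} with keys "
            f"{list(completion_ids.keys())!r}; expected a list or tensor "
            f"that supports slicing with {process_slice!r}"
        )
    # Lists, tuples and torch tensors all accept slice indexing.
    return completion_ids[process_slice]


ids = [[1, 2], [3, 4], [5, 6], [7, 8]]
print(slice_for_process(ids, slice(2, 4)))  # [[5, 6], [7, 8]]
try:
    slice_for_process({"ids": ids}, slice(0, 2))
except TypeError as err:
    print(err)
```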