diff --git a/pyproject.toml b/pyproject.toml
index b46e51697..ce03ab9a2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -73,6 +73,7 @@ rl = [
     "liger-kernel>=0.5.10",
     "deepspeed>=0.17.6",
     "flash-attn>=2.8.3",
+    "unsloth>=2025.12.9",
 ]
 envs = [
     "math-verify>=0.8.0",
diff --git a/verifiers/__init__.py b/verifiers/__init__.py
index d7e2fdb61..ac74ff91b 100644
--- a/verifiers/__init__.py
+++ b/verifiers/__init__.py
@@ -110,6 +110,9 @@ def setup_logging(
     "get_model_and_tokenizer",
+    "unsloth_get_model_and_tokenizer",
+    "unsloth_prepare_peft_model",
     "RLTrainer",
     "RLConfig",
+    "UnslothConfig",
     "GRPOTrainer",
     "GRPOConfig",
     "grpo_defaults",
@@ -123,6 +126,9 @@ def setup_logging(
     "get_model": "verifiers.rl.trainer.utils:get_model",
     "get_model_and_tokenizer": "verifiers.rl.trainer.utils:get_model_and_tokenizer",
+    "unsloth_get_model_and_tokenizer": "verifiers.rl.trainer.utils:unsloth_get_model_and_tokenizer",
+    "unsloth_prepare_peft_model": "verifiers.rl.trainer.utils:unsloth_prepare_peft_model",
     "RLConfig": "verifiers.rl.trainer:RLConfig",
+    "UnslothConfig": "verifiers.rl.configs.unsloth_config:UnslothConfig",
     "RLTrainer": "verifiers.rl.trainer:RLTrainer",
     "GRPOTrainer": "verifiers.rl.trainer:GRPOTrainer",
     "GRPOConfig": "verifiers.rl.trainer:GRPOConfig",
@@ -170,6 +176,9 @@ def __getattr__(name: str):
         grpo_defaults,
         lora_defaults,
     )
+    from .rl.configs.unsloth_config import UnslothConfig  # noqa: F401
     from .rl.trainer.utils import (  # noqa: F401
         get_model,
         get_model_and_tokenizer,
+        unsloth_get_model_and_tokenizer,
+        unsloth_prepare_peft_model,
diff --git a/verifiers/rl/configs/unsloth_config.py b/verifiers/rl/configs/unsloth_config.py
new file mode 100644
index 000000000..0b09e96fe
--- /dev/null
+++ b/verifiers/rl/configs/unsloth_config.py
@@ -0,0 +1,66 @@
+from dataclasses import dataclass, field
+from typing import Optional
+
+
+@dataclass
+class UnslothConfig:
+    """
+    Configuration for loading and fine-tuning a model with Unsloth.
+    """
+
+    # Model loading parameters
+    load_in_4bit: bool = field(
+        default=False,
+        metadata={"help": "Whether to load model weights in 4-bit precision."},
+    )
+
+    load_in_8bit: bool = field(
+        default=False,
+        metadata={"help": "Whether to load model weights in 8-bit precision."},
+    )
+
+    load_in_16bit: bool = field(
+        default=True,
+        metadata={"help": "Whether to load model weights in 16-bit precision."},
+    )
+
+    full_finetuning: bool = field(
+        default=False,
+        metadata={"help": "Whether to fine-tune the entire model."},
+    )
+
+    use_exact_model_name: bool = field(
+        default=False,
+        metadata={"help": "Whether to use the exact model name without mapping."},
+    )
+
+    gpu_memory_utilization: float = field(
+        default=0.8,
+        metadata={"help": "Target GPU memory utilization for model loading."},
+    )
+
+    random_state: int = field(
+        default=3407,
+        metadata={"help": "Random seed for reproducibility."},
+    )
+
+    max_lora_rank: int = field(
+        default=64,
+        metadata={"help": "Maximum allowable rank for LoRA adapters."},
+    )
+
+    token: Optional[str] = field(
+        default=None,
+        metadata={"help": "Hugging Face token for private model access."},
+    )
+
+    # Additional LoRA parameters
+    use_gradient_checkpointing: str = field(
+        default="unsloth",
+        metadata={"help": "Gradient checkpointing strategy."},
+    )
+
+    loftq_config: Optional[dict] = field(
+        default=None,
+        metadata={"help": "LoftQ configuration for quantized LoRA initialization."},
+    )
diff --git a/verifiers/rl/trainer/config.py b/verifiers/rl/trainer/config.py
index 49c1824f6..f505f8ab6 100644
--- a/verifiers/rl/trainer/config.py
+++ b/verifiers/rl/trainer/config.py
@@ -5,6 +5,8 @@
 from transformers import TrainingArguments
 from transformers.trainer_utils import SchedulerType
 
+from verifiers.rl.configs.unsloth_config import UnslothConfig
+
 
 @dataclass
 class RLConfig(TrainingArguments):
@@ -188,10 +190,12 @@ class RLConfig(TrainingArguments):
         default="0.0.0.0",
         metadata={"help": "Host of the vLLM server to connect to."},
"Host of the vLLM server to connect to."}, ) + vllm_server_port: int = field( default=8000, metadata={"help": "Port of the vLLM server to connect to."}, ) + vllm_server_timeout: float = field( default=300.0, metadata={ @@ -273,6 +277,28 @@ class RLConfig(TrainingArguments): metadata={"help": "Whether to shuffle the training dataset."}, ) + use_unsloth: bool = field( + default=False, + metadata={"help": "Whether to use UnslothConfig for additional model training parameters."}, + ) + + unsloth_config: Optional[UnslothConfig] = field( + default=None, + metadata={"help": "UnslothConfig instance for additional model training parameters."}, + ) + + unsloth_base_model_args: dict = field( + init=False, + default_factory=dict, + metadata={"help": "Arguments for loading the base model with Unsloth."}, + ) + + unsloth_lora_args: dict = field( + init=False, + default_factory=dict, + metadata={"help": "Additional arguments for LoRA configuration with Unsloth."}, + ) + def __post_init__(self): # configure output dir if self.output_dir is None: @@ -325,6 +351,40 @@ def __post_init__(self): }, } self.gradient_accumulation_steps = 1 + + if self.use_unsloth: + self.unsloth_base_model_args = { + "load_in_4bit": self.unsloth_config.load_in_4bit, + "load_in_8bit": self.unsloth_config.load_in_8bit, + "load_in_16bit": self.unsloth_config.load_in_16bit, + "full_finetuning": self.unsloth_config.full_finetuning, + "use_exact_model_name": self.unsloth_config.use_exact_model_name, + "gpu_memory_utilization": self.unsloth_config.gpu_memory_utilization, + "token": self.unsloth_config.token, + } + + self.unsloth_lora_args = { + "r": self.lora_rank, + "lora_alpha": self.lora_alpha, + "target_modules": self.lora_target_modules, + "lora_dropout": self.lora_dropout, + "use_rslora": self.lora_use_rslora, + "loftq_config": self.unsloth_config.loftq_config, + "random_state": self.unsloth_config.random_state, + "use_gradient_checkpointing": self.unsloth_config.use_gradient_checkpointing, + } + + self.unsloth_config.r, + self.unsloth_config.lora_alpha, + self.unsloth_config.target_modules, + self.unsloth_config.lora_dropout = ( + self.lora_rank, + self.lora_rank, + self.lora_alpha, + self.lora_target_modules, + self.lora_dropout, + ) + super().__post_init__() num_processes = self.world_size diff --git a/verifiers/rl/trainer/trainer.py b/verifiers/rl/trainer/trainer.py index 63c79594a..e1784b8dc 100644 --- a/verifiers/rl/trainer/trainer.py +++ b/verifiers/rl/trainer/trainer.py @@ -53,12 +53,24 @@ def __init__( # model + tokenizer if isinstance(model, str): model_name = model - model, processing_class = vf.get_model_and_tokenizer(model) + if args.use_unsloth and args.unsloth_config is not None: + model, processing_class = vf.unsloth_get_model_and_tokenizer( + model_name, + unsloth_config=args.unsloth_base_model_args, + ) + else: + model, processing_class = vf.get_model_and_tokenizer(model_name) else: model_name = model.config._name_or_path assert isinstance(model, PreTrainedModel) - if args.use_lora and isinstance(args.lora_config, PeftConfig): - model = prepare_peft_model(model, args.lora_config, args) + if args.use_lora: + if args.use_unsloth and args.unsloth_lora_args is not None: + model = vf.unsloth_prepare_peft_model( + model, + unsloth_config=args.unsloth_lora_args, + ) + elif isinstance(args.lora_config, PeftConfig): + model = prepare_peft_model(model, args.lora_config, args) model.warnings_issued["estimate_tokens"] = True # suppress warning super().__init__( @@ -88,6 +100,7 @@ def __init__( if 
             host = args.vllm_server_host
             port = args.vllm_server_port
+
             self.client = VLLMClient(
                 host=host, port=port, connection_timeout=args.vllm_server_timeout
             )
diff --git a/verifiers/rl/trainer/utils.py b/verifiers/rl/trainer/utils.py
index 4b7f975e0..9ac851df2 100644
--- a/verifiers/rl/trainer/utils.py
+++ b/verifiers/rl/trainer/utils.py
@@ -43,6 +43,17 @@ def get_model_and_tokenizer(
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     return model, tokenizer
 
 
+def unsloth_get_model_and_tokenizer(
+    model_name: str,
+    unsloth_config: dict[str, Any],
+) -> tuple[Any, Any]:
+    """Loads a model and tokenizer via Unsloth's FastLanguageModel."""
+    from unsloth import FastLanguageModel
+
+    model, tokenizer = FastLanguageModel.from_pretrained(model_name, **unsloth_config)
+    return model, tokenizer
+
+
 def pad(
     tensors: list[torch.Tensor],
@@ -168,6 +179,20 @@ def prepare_peft_model(
 
     return model
 
 
+def unsloth_prepare_peft_model(
+    model: PreTrainedModel,
+    unsloth_config: dict[str, Any],
+) -> PreTrainedModel:
+    """Prepares a model for PEFT training using Unsloth."""
+    # TODO: check additional args and kwargs
+    from unsloth import FastLanguageModel
+
+    model = cast(
+        PreTrainedModel, FastLanguageModel.get_peft_model(model, **unsloth_config)
+    )
+    return model
+
+
 def init_stat_tracker(device: torch.device) -> dict[str, torch.Tensor]:
     zero = torch.zeros((), device=device, dtype=torch.float32)
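
Usage note: a minimal sketch of how the pieces in this patch fit together, assuming the `RLConfig`/`UnslothConfig` exports added above; the output path and the quantization choices are illustrative, not part of this change.

```python
from verifiers import RLConfig, UnslothConfig

# Enable Unsloth-backed model loading plus LoRA. 4-bit here is an
# illustrative QLoRA-style choice; the output_dir is hypothetical.
args = RLConfig(
    output_dir="outputs/unsloth-demo",
    use_lora=True,
    use_unsloth=True,
    unsloth_config=UnslothConfig(
        load_in_4bit=True,
        load_in_16bit=False,
        gpu_memory_utilization=0.8,
    ),
)

# RLConfig.__post_init__ flattens the config into the kwarg dicts that
# unsloth_get_model_and_tokenizer / unsloth_prepare_peft_model consume.
assert args.unsloth_base_model_args["load_in_4bit"] is True
assert args.unsloth_lora_args["r"] == args.lora_rank  # LoRA rank comes from RLConfig
```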