
Commit 6950dba

Fix RFT training (#630)
* fix RFT
* added v1 model

Co-authored-by: Abhishek Govindarasu <abhishekgovindarasu@gmail.com>
1 parent 7bab34e commit 6950dba

File tree

9 files changed (+551, -36 lines)


src/judgeval/trainer/config.py

Lines changed: 1 addition & 6 deletions
@@ -1,12 +1,9 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Optional, Dict, Any, TYPE_CHECKING
+from typing import Optional, Dict, Any
 import json
 
-if TYPE_CHECKING:
-    from fireworks.llm.llm_reinforcement_step import ReinforcementAcceleratorTypeLiteral  # type: ignore[import-not-found]
-
 
 @dataclass
 class TrainerConfig:
@@ -23,8 +20,6 @@ class TrainerConfig:
     concurrency: int = 100
     epochs: int = 1
     learning_rate: float = 1e-5
-    accelerator_count: int = 1
-    accelerator_type: ReinforcementAcceleratorTypeLiteral = "NVIDIA_A100_80GB"
     temperature: float = 1.5
     max_tokens: int = 50
     enable_addons: bool = True
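With the accelerator fields removed, a TrainerConfig is built from the remaining fields only. A minimal sketch, assuming the required fields mirror the v1 copy of this class added later in this commit (all values below are placeholders):

from judgeval.trainer.config import TrainerConfig

# Placeholder identifiers; accelerator_count / accelerator_type are no longer
# valid keyword arguments after this change.
config = TrainerConfig(
    deployment_id="my-deployment",
    user_id="example-user",
    model_id="example-model",
    num_steps=5,
    learning_rate=1e-5,
)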

src/judgeval/trainer/fireworks_trainer.py

Lines changed: 0 additions & 4 deletions
@@ -292,8 +292,6 @@ async def run_reinforcement_learning(
             "num_generations_per_prompt": self.config.num_generations_per_prompt,
             "epochs": self.config.epochs,
             "learning_rate": self.config.learning_rate,
-            "accelerator_count": self.config.accelerator_count,
-            "accelerator_type": self.config.accelerator_type,
             "temperature": self.config.temperature,
             "max_tokens": self.config.max_tokens,
         }
@@ -357,8 +355,6 @@ async def run_reinforcement_learning(
             self.config.num_steps,
         )
 
-        dataset.delete()
-
         _print_progress("All training steps completed!")
 
         with _spinner_progress("Deploying final trained model"):

src/judgeval/trainer/trainable_model.py

Lines changed: 24 additions & 15 deletions
@@ -1,3 +1,4 @@
+import time
 from fireworks import LLM  # type: ignore[import-not-found]
 from .config import TrainerConfig, ModelConfig
 from typing import Optional, Dict, Any, Callable
@@ -159,31 +160,39 @@ def advance_to_next_step(self, step: int):
                 f"Failed to advance to training step {step}: {str(e)}"
             ) from e
 
-    def perform_reinforcement_step(self, dataset, step: int):
+    def perform_reinforcement_step(
+        self, dataset, step: int, max_retries: int = 3, initial_backoff: float = 1.0
+    ):
         """
         Perform a reinforcement learning step using the current model.
 
         Args:
             dataset: Training dataset for the reinforcement step
             step: Current step number for output model naming
+            max_retries: Maximum number of retry attempts (default: 3)
+            initial_backoff: Initial backoff time in seconds for exponential backoff (default: 1.0)
 
         Returns:
             Training job object
         """
-        try:
-            model_name = f"{self.config.model_id}-v{step + 1}"
-            return self._current_model.reinforcement_step(
-                dataset=dataset,
-                output_model=model_name,
-                epochs=self.config.epochs,
-                learning_rate=self.config.learning_rate,
-                accelerator_count=self.config.accelerator_count,
-                accelerator_type=self.config.accelerator_type,
-            )
-        except Exception as e:
-            raise JudgmentRuntimeError(
-                f"Failed to start reinforcement learning step {step + 1}: {str(e)}"
-            ) from e
+        model_name = f"{self.config.model_id}-v{step + 1}"
+
+        for attempt in range(max_retries):
+            try:
+                return self._current_model.reinforcement_step(
+                    dataset=dataset,
+                    output_model=model_name,
+                    epochs=self.config.epochs,
+                    learning_rate=self.config.learning_rate,
+                )
+            except Exception as e:
+                if attempt < max_retries - 1:
+                    backoff_time = initial_backoff * (2**attempt)
+                    time.sleep(backoff_time)
+                else:
+                    raise JudgmentRuntimeError(
+                        f"Failed to start reinforcement learning step {step + 1} after {max_retries} attempts: {str(e)}"
+                    ) from e
 
     def get_model_config(
         self, training_params: Optional[Dict[str, Any]] = None
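perform_reinforcement_step now retries the Fireworks call with exponential backoff instead of failing on the first error. The same pattern in isolation, as a rough sketch (the helper name and the bare callable are illustrative, not part of judgeval):

import time


def call_with_backoff(fn, max_retries: int = 3, initial_backoff: float = 1.0):
    """Call fn(), retrying on any exception with exponential backoff.

    Waits initial_backoff, then 2x, 4x, ... seconds between attempts and
    re-raises the last exception once max_retries is exhausted.
    """
    for attempt in range(max_retries):
        try:
            return fn()
        except Exception:
            if attempt == max_retries - 1:
                raise
            time.sleep(initial_backoff * (2**attempt))


# With the defaults this makes up to 3 attempts, sleeping 1s then 2s between them.
result = call_with_backoff(lambda: "ok")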

src/judgeval/v1/trainers/base_trainer.py

Lines changed: 2 additions & 2 deletions
@@ -5,8 +5,8 @@
 
 if TYPE_CHECKING:
     from judgeval.v1.tracer.tracer import Tracer
-    from judgeval.trainer.trainable_model import TrainableModel
-    from judgeval.trainer.config import TrainerConfig, ModelConfig
+    from judgeval.v1.trainers.trainable_model import TrainableModel
+    from judgeval.v1.trainers.config import TrainerConfig, ModelConfig
     from judgeval.v1.scorers.base_scorer import BaseScorer
 
 
src/judgeval/v1/trainers/config.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Optional, Dict, Any
+import json
+
+
+@dataclass
+class TrainerConfig:
+    """Configuration class for JudgmentTrainer parameters."""
+
+    deployment_id: str
+    user_id: str
+    model_id: str
+    base_model_name: str = "qwen2p5-7b-instruct"
+    rft_provider: str = "fireworks"  # Supported: "fireworks", "verifiers" (future)
+    num_steps: int = 5
+    num_generations_per_prompt: int = 4
+    num_prompts_per_step: int = 4
+    concurrency: int = 100
+    epochs: int = 1
+    learning_rate: float = 1e-5
+    temperature: float = 1.5
+    max_tokens: int = 50
+    enable_addons: bool = True
+
+
+@dataclass
+class ModelConfig:
+    """
+    Configuration class for storing and loading trained model state.
+
+    This class enables persistence of trained models so they can be loaded
+    and used later without retraining.
+
+    Example usage:
+        trainer = JudgmentTrainer(config)
+        model_config = trainer.train(agent_function, scorers, prompts)
+
+        # Save the trained model configuration
+        model_config.save_to_file("my_trained_model.json")
+
+        # Later, load and use the trained model
+        loaded_config = ModelConfig.load_from_file("my_trained_model.json")
+        trained_model = TrainableModel.from_model_config(loaded_config)
+
+        # Use the trained model for inference
+        response = trained_model.chat.completions.create(
+            model="current",  # Uses the loaded trained model
+            messages=[{"role": "user", "content": "Hello!"}]
+        )
+    """
+
+    # Base model configuration
+    base_model_name: str
+    deployment_id: str
+    user_id: str
+    model_id: str
+    enable_addons: bool
+
+    # Training state
+    current_step: int
+    total_steps: int
+
+    # Current model information
+    current_model_name: Optional[str] = None
+    is_trained: bool = False
+
+    # Training parameters used (for reference)
+    training_params: Optional[Dict[str, Any]] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert ModelConfig to dictionary for serialization."""
+        return {
+            "base_model_name": self.base_model_name,
+            "deployment_id": self.deployment_id,
+            "user_id": self.user_id,
+            "model_id": self.model_id,
+            "enable_addons": self.enable_addons,
+            "current_step": self.current_step,
+            "total_steps": self.total_steps,
+            "current_model_name": self.current_model_name,
+            "is_trained": self.is_trained,
+            "training_params": self.training_params,
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> ModelConfig:
+        """Create ModelConfig from dictionary."""
+        return cls(
+            base_model_name=data.get("base_model_name", "qwen2p5-7b-instruct"),
+            deployment_id=data.get("deployment_id", "my-base-deployment"),
+            user_id=data.get("user_id", ""),
+            model_id=data.get("model_id", ""),
+            enable_addons=data.get("enable_addons", True),
+            current_step=data.get("current_step", 0),
+            total_steps=data.get("total_steps", 0),
+            current_model_name=data.get("current_model_name"),
+            is_trained=data.get("is_trained", False),
+            training_params=data.get("training_params"),
+        )
+
+    def to_json(self) -> str:
+        """Convert ModelConfig to JSON string."""
+        return json.dumps(self.to_dict(), indent=2)
+
+    @classmethod
+    def from_json(cls, json_str: str) -> ModelConfig:
+        """Create ModelConfig from JSON string."""
+        data = json.loads(json_str)
+        return cls.from_dict(data)
+
+    def save_to_file(self, filepath: str):
+        """Save ModelConfig to a JSON file."""
+        with open(filepath, "w") as f:
+            f.write(self.to_json())
+
+    @classmethod
+    def load_from_file(cls, filepath: str) -> ModelConfig:
+        """Load ModelConfig from a JSON file."""
+        with open(filepath, "r") as f:
+            json_str = f.read()
+        return cls.from_json(json_str)
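ModelConfig is new in this commit; a minimal save/load round trip using only the methods defined above (field values are placeholders):

from judgeval.v1.trainers.config import ModelConfig

# Placeholder values for illustration.
config = ModelConfig(
    base_model_name="qwen2p5-7b-instruct",
    deployment_id="my-base-deployment",
    user_id="example-user",
    model_id="example-model",
    enable_addons=True,
    current_step=2,
    total_steps=5,
    current_model_name="example-model-v2",
    is_trained=True,
)

config.save_to_file("my_trained_model.json")
restored = ModelConfig.load_from_file("my_trained_model.json")
assert restored == config  # dataclass equality compares every serialized field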
Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
+from contextlib import contextmanager
+from typing import Optional
+import sys
+import os
+from judgeval.utils.decorators.use_once import use_once
+
+
+@use_once
+def _is_jupyter_environment():
+    """Check if we're running in a Jupyter notebook or similar environment."""
+    try:
+        # Check for IPython kernel
+        if "ipykernel" in sys.modules or "IPython" in sys.modules:
+            return True
+        # Check for Jupyter environment variables
+        if "JPY_PARENT_PID" in os.environ:
+            return True
+        # Check if we're in Google Colab
+        if "google.colab" in sys.modules:
+            return True
+        return False
+    except Exception:
+        return False
+
+
+IS_JUPYTER = _is_jupyter_environment()
+
+if not IS_JUPYTER:
+    try:
+        from rich.console import Console
+        from rich.spinner import Spinner
+        from rich.live import Live
+        from rich.text import Text
+
+        shared_console = Console()
+        RICH_AVAILABLE = True
+    except ImportError:
+        RICH_AVAILABLE = False
+else:
+    RICH_AVAILABLE = False
+
+
+class SimpleSpinner:
+    def __init__(self, name, text):
+        self.text = text
+
+
+class SimpleLive:
+    def __init__(self, spinner, console=None, refresh_per_second=None):
+        self.spinner = spinner
+
+    def __enter__(self):
+        print(f"🔄 {self.spinner.text}")
+        return self
+
+    def __exit__(self, *args):
+        pass
+
+    def update(self, spinner):
+        print(f"🔄 {spinner.text}")
+
+
+def safe_print(message, style=None):
+    """Safe print function that works in all environments."""
+    if RICH_AVAILABLE and not IS_JUPYTER:
+        shared_console.print(message, style=style)
+    else:
+        if style == "green":
+            print(f"✅ {message}")
+        elif style == "yellow":
+            print(f"⚠️ {message}")
+        elif style == "blue":
+            print(f"🔵 {message}")
+        elif style == "cyan":
+            print(f"🔷 {message}")
+        else:
+            print(message)
+
+
+@contextmanager
+def _spinner_progress(
+    message: str, step: Optional[int] = None, total_steps: Optional[int] = None
+):
+    """Context manager for spinner-based progress display."""
+    if step is not None and total_steps is not None:
+        full_message = f"[Step {step}/{total_steps}] {message}"
+    else:
+        full_message = f"[Training] {message}"
+
+    if RICH_AVAILABLE and not IS_JUPYTER:
+        spinner = Spinner("dots", text=Text(full_message, style="cyan"))
+        with Live(spinner, console=shared_console, refresh_per_second=10):
+            yield
+    else:
+        print(f"🔄 {full_message}")
+        try:
+            yield
+        finally:
+            print(f"✅ {full_message} - Complete")
+
+
+@contextmanager
+def _model_spinner_progress(message: str):
+    """Context manager for model operation spinner-based progress display."""
+    if RICH_AVAILABLE and not IS_JUPYTER:
+        spinner = Spinner("dots", text=Text(f"[Model] {message}", style="blue"))
+        with Live(spinner, console=shared_console, refresh_per_second=10) as live:
+
+            def update_progress(progress_message: str):
+                """Update the spinner with a new progress message."""
+                new_text = f"[Model] {message}\n └─ {progress_message}"
+                spinner.text = Text(new_text, style="blue")
+                live.update(spinner)
+
+            yield update_progress
+    else:
+        print(f"🔵 [Model] {message}")
+
+        def update_progress(progress_message: str):
+            print(f" └─ {progress_message}")
+
+        yield update_progress
+
+
+def _print_progress(
+    message: str, step: Optional[int] = None, total_steps: Optional[int] = None
+):
+    """Print progress message with consistent formatting."""
+    if step is not None and total_steps is not None:
+        safe_print(f"[Step {step}/{total_steps}] {message}", style="green")
+    else:
+        safe_print(f"[Training] {message}", style="green")
+
+
+def _print_progress_update(
+    message: str, step: Optional[int] = None, total_steps: Optional[int] = None
+):
+    """Print progress update message (for status changes during long operations)."""
+    safe_print(f" └─ {message}", style="yellow")
+
+
+def _print_model_progress(message: str):
+    """Print model progress message with consistent formatting."""
+    safe_print(f"[Model] {message}", style="blue")
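These new console helpers use rich spinners on a regular terminal and fall back to plain print() output in Jupyter/Colab or when rich is not installed. A short usage sketch, assuming the helpers above are imported from wherever this new module lives (its path is not shown in this view):

import time

# Step-scoped spinner: rich spinner on a TTY, otherwise a start line plus a
# "- Complete" line when the block exits.
with _spinner_progress("Generating rollouts", step=1, total_steps=5):
    time.sleep(0.1)  # stand-in for real work

# Model-scoped spinner that yields a callback for sub-step updates.
with _model_spinner_progress("Deploying trained model") as update:
    update("waiting for deployment to become ready")
    time.sleep(0.1)

_print_progress("All training steps completed!")
safe_print("Done", style="green")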

0 commit comments
