Skip to content

Commit 77ccb16

Browse files
snimuwillccbb
andauthored
add trajectory_id to TrajectoryStep (#675)
* add trajectory_id to TrajectoryStep * rename key * update test --------- Co-authored-by: William Brown <williambrown97@gmail.com>
1 parent 7f04c69 commit 77ccb16

File tree

10 files changed

+43
-0
lines changed

10 files changed

+43
-0
lines changed

docs/release/TRAJECTORIES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,8 @@ async def add_model_response(
359359
tokens=tokens,
360360
reward=None,
361361
advantage=None,
362+
is_truncated=False,
363+
trajectory_id=state["current_trajectory_id"],
362364
extras={},
363365
)
364366
state["trajectory"].append(trajectory_step)

tests/test_environment.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ async def rollout(
5656
tokens=tokens,
5757
reward=None,
5858
advantage=None,
59+
is_truncated=False,
60+
trajectory_id=state["trajectory_id"],
5961
extras={},
6062
)
6163
state["trajectory"].append(trajectory_step)

tests/test_environment_extra.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ async def rollout(
6767
tokens=tokens,
6868
reward=None,
6969
advantage=None,
70+
is_truncated=False,
71+
trajectory_id=state["trajectory_id"],
7072
extras={},
7173
)
7274
state["trajectory"].append(trajectory_step)

tests/test_rlm_env.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,6 +1189,8 @@ async def test_prepends_trajectory_steps_during_cleanup(self, rlm_env):
11891189
tokens=None,
11901190
reward=None,
11911191
advantage=None,
1192+
is_truncated=False,
1193+
trajectory_id="sub_batch1_req1",
11921194
extras={"is_sub_llm_call": True, "timestamp": 1.0},
11931195
)
11941196
sub_step2 = TrajectoryStep(
@@ -1198,6 +1200,8 @@ async def test_prepends_trajectory_steps_during_cleanup(self, rlm_env):
11981200
tokens=None,
11991201
reward=None,
12001202
advantage=None,
1203+
is_truncated=False,
1204+
trajectory_id="sub_batch1_req2",
12011205
extras={"is_sub_llm_call": True, "timestamp": 2.0},
12021206
)
12031207
rlm_env.active_rollouts[rollout_id] = {
@@ -1213,6 +1217,8 @@ async def test_prepends_trajectory_steps_during_cleanup(self, rlm_env):
12131217
tokens=None,
12141218
reward=None,
12151219
advantage=None,
1220+
is_truncated=False,
1221+
trajectory_id="main_trajectory",
12161222
extras={},
12171223
)
12181224
state = {"rollout_id": rollout_id, "trajectory": [main_step]}
@@ -1251,6 +1257,8 @@ async def test_no_prepend_when_disabled(self, mock_sandbox_client, mock_dataset)
12511257
tokens=None,
12521258
reward=None,
12531259
advantage=None,
1260+
is_truncated=False,
1261+
trajectory_id="sub_batch1_req1",
12541262
extras={"is_sub_llm_call": True, "timestamp": 1.0},
12551263
)
12561264
env.active_rollouts[rollout_id] = {
@@ -1265,6 +1273,8 @@ async def test_no_prepend_when_disabled(self, mock_sandbox_client, mock_dataset)
12651273
tokens=None,
12661274
reward=None,
12671275
advantage=None,
1276+
is_truncated=False,
1277+
trajectory_id="main_trajectory",
12681278
extras={},
12691279
)
12701280
state = {"rollout_id": rollout_id, "trajectory": [main_step]}

tests/test_singleturn_env.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ async def test_is_completed_method(self, mock_singleturn_env):
7777
tokens=None,
7878
reward=None,
7979
advantage=None,
80+
is_truncated=False,
81+
trajectory_id="test_trajectory",
8082
extras={},
8183
)
8284
],
@@ -487,6 +489,8 @@ async def test_singleturn_stops_after_one_response(
487489
tokens=None,
488490
reward=None,
489491
advantage=None,
492+
is_truncated=False,
493+
trajectory_id="test_trajectory",
490494
extras={},
491495
)
492496
]
@@ -514,6 +518,8 @@ async def test_singleturn_stops_after_one_response(
514518
tokens=None,
515519
reward=None,
516520
advantage=None,
521+
is_truncated=False,
522+
trajectory_id="test_trajectory",
517523
extras={},
518524
),
519525
TrajectoryStep(
@@ -523,6 +529,8 @@ async def test_singleturn_stops_after_one_response(
523529
tokens=None,
524530
reward=None,
525531
advantage=None,
532+
is_truncated=False,
533+
trajectory_id="test_trajectory",
526534
extras={},
527535
),
528536
]

tests/test_trajectory_processing.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,13 @@ def test_process_trajectory_steps_for_training():
110110
completion_ids=[3, 4],
111111
completion_mask=[1, 1],
112112
completion_logprobs=[-0.1, -0.2],
113+
overlong_prompt=False,
114+
is_truncated=False,
113115
),
114116
reward=1.0,
115117
advantage=None,
118+
is_truncated=False,
119+
trajectory_id="test_trajectory",
116120
extras={},
117121
)
118122
]
@@ -135,9 +139,13 @@ def test_process_trajectory_steps_for_training():
135139
completion_ids=[6, 7, 8],
136140
completion_mask=[1, 1, 1],
137141
completion_logprobs=[-0.3, -0.4, -0.5],
142+
overlong_prompt=False,
143+
is_truncated=False,
138144
),
139145
reward=0.5,
140146
advantage=None,
147+
is_truncated=False,
148+
trajectory_id="test_trajectory",
141149
extras={},
142150
)
143151
]
@@ -192,6 +200,8 @@ def test_process_trajectory_steps_skip_missing_tokens():
192200
tokens=None,
193201
reward=1.0,
194202
advantage=None,
203+
is_truncated=False,
204+
trajectory_id="test_trajectory",
195205
extras={},
196206
),
197207
TrajectoryStep(
@@ -204,9 +214,13 @@ def test_process_trajectory_steps_skip_missing_tokens():
204214
completion_ids=[2, 3],
205215
completion_mask=[1, 1],
206216
completion_logprobs=[-0.1, -0.2],
217+
overlong_prompt=False,
218+
is_truncated=False,
207219
),
208220
reward=0.5,
209221
advantage=None,
222+
is_truncated=False,
223+
trajectory_id="test_trajectory",
210224
extras={},
211225
),
212226
]

verifiers/envs/environment.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import logging
77
import signal
88
import time
9+
import uuid
910
from abc import ABC, abstractmethod
1011
from concurrent.futures import ThreadPoolExecutor
1112
from copy import deepcopy
@@ -597,6 +598,7 @@ async def init_state(
597598
else:
598599
state["oai_tools"] = []
599600
state["trajectory"] = []
601+
state["trajectory_id"] = uuid.uuid4().hex
600602
state["reward"] = None
601603
state["metrics"] = None
602604
state["error"] = None

verifiers/envs/experimental/rlm_env.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,7 @@ async def _handle_sub_llm_request(self, request: Any) -> Any:
939939
reward=None,
940940
advantage=None,
941941
is_truncated=is_truncated,
942+
trajectory_id=f"{batch_id}_{request_id}",
942943
extras={
943944
"is_sub_llm_call": True,
944945
"parent_turn": parent_turn,

verifiers/envs/multiturn_env.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ async def add_model_response(
8585
reward=None,
8686
advantage=None,
8787
is_truncated=is_truncated,
88+
trajectory_id=state["trajectory_id"],
8889
extras={},
8990
)
9091
trajectory_step["completion"] = completion_messages

verifiers/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ class TrajectoryStep(TypedDict):
6868
reward: float | None
6969
advantage: float | None
7070
is_truncated: bool
71+
trajectory_id: str
7172
extras: dict[str, Any]
7273

7374

0 commit comments

Comments
 (0)