Merged
39 commits
055e426
chore: bump sglang 0.5.5.post1 -> 0.5.12.post1 (FSDP path)
WWWjiahui May 29, 2026
b1bf6de
Merge pull request #5 from WWWjiahui/chore/bump-sglang-0.5.12
haizhongzheng May 29, 2026
705112b
fix: relax RaaS engine health watchdog for sglang 0.5.12
haizhongzheng May 29, 2026
9234802
docs: fix CUDA 13 install steps for flash-attn build and sglang resol…
haizhongzheng May 29, 2026
ea86fa2
fix: pin kernels<0.13 for transformers 5.6.1 compatibility
haizhongzheng May 29, 2026
851b513
feat(megatron): streaming Megatron->HF per-tensor weight export
jsw-zorro May 29, 2026
88c5068
feat(megatron): weight sync via HF-space buffer (PP/EP/VPP)
jsw-zorro May 29, 2026
84e6b19
feat(examples): add Qwen3-8B Megatron math RL recipe
jsw-zorro May 29, 2026
7ebc1d4
perf(megatron): direct-DMA weight offload (~23x faster)
jsw-zorro May 29, 2026
77410df
fix: auto-select sglang attention backend and norm path by GPU arch
haizhongzheng May 30, 2026
02c4335
docs: align docker run shm-size and add --ulimit nofile
haizhongzheng May 30, 2026
23fe945
Merge pull request #9 from jsw-zorro/feat/megatron-weight-sync-dev
haizhongzheng Jun 2, 2026
5c5d91f
feat: add spawn-sub-agents workflow for math RL
haizhongzheng May 26, 2026
9b49df6
feat: add TextCraft recursive-agent workflow + Qwen3-4B recipe
haizhongzheng May 27, 2026
d4065c8
feat: add Oolong recursive-agent workflow + Qwen3-4B recipe
haizhongzheng May 27, 2026
bd5e94e
feat: add TextCraft recursive-agent recipe variants (gen4k, lr5e6)
haizhongzheng May 27, 2026
adbb4fb
feat: add minimal LLM-as-judge library at astraEnv/judge.py
haizhongzheng May 27, 2026
b0e2a31
feat: add reward_mode selector + LLM judge for Oolong sub-agents
haizhongzheng May 27, 2026
b3b6424
feat: add CMU RAG search client and ai-rubric checklist grader to ast…
haizhongzheng May 28, 2026
9bcfff7
feat: add DeepDive recursive-agent workflow + Qwen3-4B recipe
haizhongzheng May 28, 2026
42481f3
feat: implement oolong-real D&D grader + fix sub-agent reward routing
haizhongzheng May 28, 2026
b5b4042
feat: let workflows opt out of producer's default group-reward stats
haizhongzheng May 28, 2026
bf27d70
feat: enrich DeepDive dumps and add qwen3-4b-recursive-v7 recipe
haizhongzheng Jun 2, 2026
d2b05bd
fix: normalize apply_chat_template output to token ids for transforme…
haizhongzheng Jun 3, 2026
f427cd9
fix: cap textcraft recursive max_concurrent_rollouts at 512
haizhongzheng Jun 3, 2026
cf7b270
chore: remove oolong workflow, reward, dataset, and recipes
haizhongzheng Jun 3, 2026
48f5205
chore: remove unused deepdive and textcraft variant recipes
haizhongzheng Jun 3, 2026
0a4ed98
chore: bump version to 0.1.1
haizhongzheng Jun 3, 2026
c1e9f52
Merge pull request #12 from Infini-AI-Lab/spawn-solution
haizhongzheng Jun 3, 2026
5ed2499
refactor: rename textcraft example dir to textcraft-recursive-agent
haizhongzheng Jun 3, 2026
aab15a2
docs: add 8-agent textcraft recursive episode for animations
haizhongzheng Jun 3, 2026
2ba1ecd
fix(megatron): build Transformer Engine from source for CUDA 13
jsw-zorro Jun 3, 2026
6e42bd3
docs: document the pre-built Megatron Docker image
haizhongzheng Jun 5, 2026
ab8fa84
fix: harden Megatron weight offload and textcraft reward default
haizhongzheng Jun 5, 2026
1e0d5d6
docs: add TextCraft recursive-agent recipe README and docs
haizhongzheng Jun 5, 2026
7baeaab
docs: move textcraft recipe README to parent example folder
haizhongzheng Jun 5, 2026
0844c6e
docs: add textcraft recipe README at parent example folder
haizhongzheng Jun 5, 2026
6d24d31
docs: announce dynamic recursive-agent recipe in README news
haizhongzheng Jun 5, 2026
1f11b72
Merge pull request #14 from jsw-zorro/fix/megatron-cuda13-te-build
haizhongzheng Jun 5, 2026
7 changes: 7 additions & 0 deletions .gitignore
@@ -12,6 +12,7 @@ data-data
data-log
tmp-yaml/
issue-draft/
claude-doc/
.pip-tmp/
*.pid
*.pdf
@@ -230,3 +231,9 @@ evaluation/data/AReaL-boba-2-RL-Code
tmp*
torchelastic_*
torchinductor_*

# Oolong HF dataset cache (auto-downloaded, multi-GB)
astraflow/core/workflow/impl/oolong/oolong_*.jsonl

# DeepDive HF dataset cache (auto-downloaded)
astraflow/core/workflow/impl/deepdive/deepdive_*.jsonl
2 changes: 2 additions & 0 deletions README.md
@@ -43,6 +43,8 @@ AstraFlow **natively** supports the following for LLM RL training **without any
<!-- <p align="center"><i>AstraFlow training a multi-policy workflow on an elastic, heterogeneous, cross-region rollout pool — all at once, with no feature-specific code.</i></p> -->

## News
- **[2026/06]** New recipe: **dynamic recursive agent** on TextCraft — a multi-turn agent that recursively spawns sub-agents sharing inventory under a team reward. See the [recipe docs](https://Infini-AI-Lab.github.io/astraflow/docs/en/recipes/textcraft-recursive.html).
- **[2026/06]** AstraFlow **v0.1.1** released — CUDA 13 image, SGLang 0.5.12, Megatron weight-sync training backend, and transformers 5 support. See the [project website](https://Infini-AI-Lab.github.io/astraflow/).
- **[2026/05]** AstraFlow **v0.1.0** released — first public release of the full system. See the [project website](https://Infini-AI-Lab.github.io/astraflow/).
- **[2026/05]** AstraFlow paper is on [arXiv](https://arxiv.org/abs/2605.15565).

172 changes: 172 additions & 0 deletions astraEnv/checklist.py
@@ -0,0 +1,172 @@
"""Single-call auto-checklist grader: local replacement for
``ai-rubric``'s ``rubric.core.checklist.RubricChecklistFast``.

Matches the upstream package's behavior:

- **One LLM call** (not two): the model generates the checklist and
  scores every item in a single response.
- **Continuous per-item scores** (0-1, not binary pass/fail) so the LLM
  can reflect partial satisfaction.
- **Holistic ``overall_score``** chosen by the LLM, not mechanical
  ``passed / total`` (lets critical items dominate non-critical ones).
- **No caching**: fresh checklist every call (matches upstream).

System prompt ported verbatim from
``ai_rubric-0.2.4/rubric/prompts/generate-rubric-checklist-fast-system.jinja``.

Usage::

    from astraEnv.checklist import ChecklistGrader

    grader = ChecklistGrader(goal="find the actor's birth year")
    score, reason = await grader.aevaluate(context=trajectory_text)

Uses ``astraEnv.judge.judge`` for the LLM call so we keep our retry / key
handling. Temperature defaults to 1.0 for parity with the upstream package.
"""

from __future__ import annotations

from typing import Any

from astraEnv.judge import extract_json, judge


# Verbatim from ai-rubric 0.2.4:
# rubric/prompts/generate-rubric-checklist-fast-system.jinja
_SYSTEM_PROMPT = (
    "We are building a rubric to evaluate a task. We will do this by "
    "decomposing success criteria for the task into a checklist\n"
    "and reasoning about the task success using this checklist. The "
    "checklist should comprehensively test that the task is successfully "
    "completed.\n\n"
    "The rubric checklist should be as comprehensive as possible, and "
    "should be able to evaluate the task in a way that is fair and accurate.\n\n"
    "The rubric checklist should be as concise as possible, and should be "
    "able to be easily understood by a human.\n\n"
    "The rubric checklist should be as easy to evaluate as possible.\n\n"
    "To evaluate a task on a checklist, you may consider the following "
    "procedure:\n"
    "1. For each criterion, reason whether it is critical or non-critical.\n"
    "2. For each criterion, provide a score between 0 and 1 for how well "
    "the task satisfies the criterion.\n"
    "3. Consider the overall progress towards task completion and allow "
    "for partial credit when generating the overall score.\n\n"
    "# Output Format\n"
    "```json\n"
    "{\n"
    ' "checklist": [\n'
    ' "...", // a list of strings\n'
    " ],\n"
    ' "checklist_scores": [\n'
    " 0.0, // between 0 and 1\n"
    " ],\n"
    ' "reasoning": "...",\n'
    ' "overall_score": 0.0 // between 0 and 1\n'
    "}\n"
    "```"
)


def _build_user_prompt(task: str, context: str) -> str:
    """Mirrors generate-rubric-checklist-fast-user.jinja."""
    return f"# Task\n{task}\n\n{context}\n\n# Your Evaluation Output"


class ChecklistGrader:
    """Single-call checklist grader matching ai-rubric's RubricChecklistFast.

    Parameters
    ----------
    goal : str
        The task goal the agent was given.
    judge_model : str | None
        Optional override for the judge model. None = astraEnv.judge default.
    temperature : float
        Sampling temperature. 1.0 matches the upstream package's default.
    """

    def __init__(
        self,
        goal: str,
        *,
        judge_model: str | None = None,
        temperature: float = 1.0,
    ):
        self.goal = goal
        self.judge_model = judge_model
        self.temperature = temperature
        # Most recent parsed response, exposed for inspection / debugging.
        self.last_checklist: list[str] = []
        self.last_checklist_scores: list[float] = []
        self.last_reasoning: str = ""
        self.last_overall_score: float | None = None

    def _judge_kwargs(self) -> dict[str, Any]:
        kw: dict[str, Any] = {"temperature": self.temperature}
        if self.judge_model:
            kw["model"] = self.judge_model
        return kw

    async def aevaluate(self, *, context: str) -> tuple[float, str]:
        """Run one LLM call that generates and scores the checklist.

        Returns
        -------
        score : float in [0, 1]
            The LLM's holistic ``overall_score``, clamped to [0, 1].
        reason : str
            The LLM's reasoning, or an error message on failure.

        On failure (network error, unparseable response, non-numeric or
        NaN score) returns ``(0.0, error_message)``; never raises.
        """
        user = _build_user_prompt(self.goal, context)
        try:
            raw = await judge(
                system=_SYSTEM_PROMPT, user=user, **self._judge_kwargs()
            )
        except Exception as e:
            return 0.0, f"checklist call failed: {e}"

        try:
            parsed = extract_json(raw)
        except Exception as e:
            return 0.0, f"checklist response unparseable: {e}"
        if not isinstance(parsed, dict):
            # extract_json can return any JSON value; only objects are usable.
            return 0.0, "checklist response is not a JSON object"

        try:
            overall = float(parsed.get("overall_score", 0.0))
        except (TypeError, ValueError) as e:
            return 0.0, f"overall_score not a number: {e}"
        if overall != overall:
            # NaN is the only float unequal to itself; without this guard
            # max(0.0, min(1.0, nan)) would evaluate to 1.0.
            return 0.0, "overall_score is NaN"

        # Clamp defensively; the upstream package raises if out of [0,1],
        # but we prefer to clamp and continue so a flaky judge response
        # never crashes the rollout.
        overall = max(0.0, min(1.0, overall))

        # Stash for inspection. Non-list values are ignored so that a
        # malformed response cannot raise while iterating.
        checklist = parsed.get("checklist")
        scores = parsed.get("checklist_scores")
        if not isinstance(checklist, list):
            checklist = []
        if not isinstance(scores, list):
            scores = []
        self.last_checklist = [
            str(x) for x in checklist if isinstance(x, (str, int, float))
        ]
        self.last_checklist_scores = []
        for s in scores:
            try:
                self.last_checklist_scores.append(float(s))
            except (TypeError, ValueError):
                continue
        self.last_reasoning = str(parsed.get("reasoning", ""))
        self.last_overall_score = overall

        return overall, self.last_reasoning


async def grade_with_checklist(
    goal: str,
    context: str,
    *,
    judge_model: str | None = None,
    temperature: float = 1.0,
) -> tuple[float, str]:
    """Convenience wrapper: build a grader and evaluate in one call."""
    grader = ChecklistGrader(goal, judge_model=judge_model, temperature=temperature)
    return await grader.aevaluate(context=context)
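Score coercion of this kind has a few edge cases that are easier to test in isolation than through a live judge call: non-numeric values, out-of-range values, and NaN, which `max(0.0, min(1.0, x))` maps to 1.0 without complaint. The sketch below is a standalone illustration, not code from this PR. `coerce_score` and `score_from_response` are hypothetical names, and unlike `aevaluate` the sketch treats a missing `overall_score` as a failure instead of defaulting it to 0.0.

```python
import json
import math
from typing import Any, Optional


def coerce_score(value: Any) -> Optional[float]:
    """Return value as a float clamped to [0, 1], or None if unusable."""
    try:
        score = float(value)
    except (TypeError, ValueError):
        return None
    if math.isnan(score):
        # Plain clamping would turn NaN into 1.0, so reject it explicitly.
        return None
    return max(0.0, min(1.0, score))


def score_from_response(raw: str) -> tuple:
    """Follow the (score, reason) contract: 0.0 plus a message on failure."""
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError as exc:
        return 0.0, f"unparseable: {exc}"
    if not isinstance(parsed, dict):
        return 0.0, "response is not a JSON object"
    score = coerce_score(parsed.get("overall_score"))
    if score is None:
        return 0.0, "overall_score missing, non-numeric, or NaN"
    return score, str(parsed.get("reasoning", ""))


if __name__ == "__main__":
    # An out-of-range score is clamped rather than rejected.
    print(score_from_response('{"overall_score": 1.7, "reasoning": "over"}'))
    # Python's json module accepts the NaN literal by default.
    print(score_from_response('{"overall_score": NaN}'))
```

Returning `None` from the coercion step keeps "unusable" distinct from a legitimate score of 0.0, which can matter when the value feeds a reward signal.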
151 changes: 151 additions & 0 deletions astraEnv/judge.py
@@ -0,0 +1,151 @@
"""Minimal LLM-as-a-judge utility.

Two functions. Both stateless.

- `judge(system, user, ...)` posts a (system, user) pair to Fireworks and
  returns the raw assistant content string.
- `extract_json(text)` parses JSON out of an LLM response, tolerating
  common code-fence wrapping.

Callers write their own rubric prompts and parse what they expect.
See claude-doc/minimal-llm-judge-plan.md for the design rationale.

Usage:
    from astraEnv.judge import judge, extract_json

    response = await judge(
        system='You grade outputs. Return JSON {"score", "reason"}.',
        user=f"Goal: {goal}\\n\\nOutput: {output}",
    )
    parsed = extract_json(response)
    score = float(parsed["score"])

Requires the env var `FIREWORKS_API_KEY`.
"""

from __future__ import annotations

import asyncio
import json
import os
import re
from typing import Any

import httpx

_API_URL = "https://api.fireworks.ai/inference/v1/chat/completions"
_DEFAULT_MODEL = "accounts/fireworks/models/gpt-oss-120b"
_RETRY_STATUSES = {429, 500, 502, 503, 504}
_MAX_ATTEMPTS = 3


class JudgeError(RuntimeError):
    """Raised when the judge call cannot return a usable response."""


async def judge(
    system: str,
    user: str,
    *,
    model: str = _DEFAULT_MODEL,
    temperature: float = 0.0,
    max_tokens: int = 2048,
    timeout_s: float = 60.0,
) -> str:
    """Send (system, user) to Fireworks; return the raw assistant content.

    Makes up to `_MAX_ATTEMPTS` (3) attempts with exponential backoff on
    transient failures (HTTP 429, 500, 502, 503, 504, and network errors).
    Raises JudgeError on persistent failure, on any other non-200 status,
    and on a 200 response with an unexpected shape or empty content.

    Default `max_tokens` is set generously (2048) because reasoning models
    like gpt-oss-120b consume tokens for internal chain-of-thought before
    emitting the final answer; too-tight budgets truncate before content.

    For reasoning models that put their chain-of-thought into a separate
    `reasoning_content` field, this function returns `content` if non-empty,
    otherwise falls back to `reasoning_content`. extract_json() handles
    both shapes.
    """
    api_key = os.environ.get("FIREWORKS_API_KEY")
    if not api_key:
        raise JudgeError("FIREWORKS_API_KEY environment variable is not set")

    payload = {
        "model": model,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
    }
    headers = {"Authorization": f"Bearer {api_key}"}

    last_err: Exception | None = None
    async with httpx.AsyncClient(timeout=timeout_s) as client:
        for attempt in range(_MAX_ATTEMPTS):
            try:
                resp = await client.post(_API_URL, json=payload, headers=headers)
            except httpx.RequestError as exc:
                # Network-level failure: back off (1 s, 2 s, 4 s) and retry.
                last_err = exc
                await asyncio.sleep(2**attempt)
                continue

            if resp.status_code == 200:
                try:
                    message = resp.json()["choices"][0]["message"]
                except (KeyError, IndexError, ValueError) as exc:
                    raise JudgeError(
                        f"Unexpected response shape: {resp.text[:500]}"
                    ) from exc
                # Prefer the canonical `content` field. Reasoning models
                # (e.g. gpt-oss-120b) may emit only `reasoning_content`
                # when truncated; fall back to that so extract_json can
                # still find a JSON snippet inside the chain-of-thought.
                content = message.get("content") or message.get("reasoning_content")
                if not content:
                    raise JudgeError(
                        f"Empty assistant content: {resp.text[:500]}"
                    )
                return content

            if resp.status_code in _RETRY_STATUSES:
                last_err = JudgeError(
                    f"Fireworks returned {resp.status_code}: {resp.text[:200]}"
                )
                await asyncio.sleep(2**attempt)
                continue

            # Any other status (e.g. 400, 401) is not retried.
            raise JudgeError(
                f"Fireworks returned {resp.status_code}: {resp.text[:500]}"
            )

    raise JudgeError(
        f"judge() failed after {_MAX_ATTEMPTS} attempts: {last_err}"
    ) from last_err


def extract_json(text: str) -> dict[str, Any]:
    """Parse JSON out of an LLM response, tolerating common fence wrapping.

    Strategy:
    1. json.loads on the trimmed text
    2. otherwise, parse the contents of the first ```json ... ``` fence
    3. otherwise, parse the contents of the first plain ``` ... ``` fence
    4. otherwise, raise the JSONDecodeError for the unfenced text

    A fence that is found but does not contain valid JSON raises
    immediately; later steps are not tried. The result is whatever JSON
    value was parsed, so callers that need an object should check for dict.
    """
    text = text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    fenced = re.search(r"```json\s*(.*?)\s*```", text, re.DOTALL | re.IGNORECASE)
    if fenced:
        return json.loads(fenced.group(1).strip())

    fenced = re.search(r"```\s*(.*?)\s*```", text, re.DOTALL)
    if fenced:
        return json.loads(fenced.group(1).strip())

    # No fence found: parse again so the original JSONDecodeError propagates.
    return json.loads(text)
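To see how the fallback order behaves on typical judge outputs, here is a standalone sketch of the same idea. It is a variation and not the PR's function: `extract_json_sketch` is a hypothetical name, it tries every candidate in turn and skips the ones that fail, and it rejects JSON values that are not objects. The regular expressions are the two used in `extract_json`.

```python
import json
import re
from typing import Any

# Same two fence patterns as extract_json, tried in this order.
_FENCE_PATTERNS = (
    re.compile(r"```json\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE),
    re.compile(r"```\s*(.*?)\s*```", re.DOTALL),
)


def extract_json_sketch(text: str) -> dict[str, Any]:
    """Return the first candidate that parses to a JSON object."""
    text = text.strip()
    candidates = [text]
    for pattern in _FENCE_PATTERNS:
        match = pattern.search(text)
        if match:
            candidates.append(match.group(1).strip())

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue  # try the next candidate instead of failing early
        if isinstance(parsed, dict):
            return parsed
    raise ValueError("no JSON object found in response")


if __name__ == "__main__":
    reply = 'Here is my grade:\n```json\n{"score": 0.5, "reason": "partial"}\n```'
    print(extract_json_sketch(reply))
```

Collecting candidates first and parsing them in one loop means an invalid fenced block cannot hide a valid candidate that comes later, at the cost of a less specific error message when everything fails.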