Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 40 additions & 60 deletions agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,16 @@ class AgentState:
class ToolStart:
name: str
inputs: dict
tool_id: str = ""

@dataclass
class ToolEnd:
name: str
result: str
permitted: bool = True
duration: float = 0.0
tool_id: str = ""
inputs: dict = field(default_factory=dict)

@dataclass
class TurnDone:
Expand Down Expand Up @@ -201,12 +205,21 @@ def run(
if not assistant_turn.tool_calls:
break # No tools → conversation turn complete

# ── Execute tools (parallel when safe) ────────────────────────────
tool_calls = assistant_turn.tool_calls
# ── Uniquify ids to prevent GC collisions ─────────────────────────
from id_uniquify import uniquify_tool_call_ids
uniquify_tool_call_ids(assistant_turn.tool_calls, state)

# Check permissions first (must be sequential — may prompt user)
# Deduplicate tool calls by ID (model may echo duplicates)
_seen_ids: set[str] = set()
tool_calls = [tc for tc in assistant_turn.tool_calls
if tc["id"] not in _seen_ids and not _seen_ids.add(tc["id"])]
state.messages[-1]["tool_calls"] = tool_calls

# ── Check permissions (sequential — may prompt user) ──────────────
permissions: dict[str, bool] = {}
denied_results: dict[str, str] = {}
for tc in tool_calls:
yield ToolStart(tc["name"], tc["input"], tool_id=tc["id"])
permitted = _check_permission(tc, config)
if not permitted:
if config.get("permission_mode") == "plan":
Expand All @@ -216,80 +229,47 @@ def run(
yield req
permitted = req.granted
permissions[tc["id"]] = permitted

# Determine which tools can run in parallel
from tool_registry import get_tool as _get_tool
parallel_batch = []
sequential_batch = []
for tc in tool_calls:
if not permissions[tc["id"]]:
sequential_batch.append(tc)
continue
tdef = _get_tool(tc["name"])
if tdef and tdef.concurrent_safe and len(tool_calls) > 1:
parallel_batch.append(tc)
else:
sequential_batch.append(tc)

def _exec_one(tc):
"""Execute a single tool call, return (tc, result, permitted)."""
tid = tc["id"]
permitted = permissions[tid]
if not permitted:
if config.get("permission_mode") == "plan":
plan_file = runtime.get_ctx(config).plan_file or ""
result = (
denied_results[tc["id"]] = (
f"[Plan mode] Write operations are blocked except to the plan file: {plan_file}\n"
"Finish your analysis and write the plan to the plan file. "
"The user will run /plan done to exit plan mode and begin implementation."
)
else:
result = "Denied: user rejected this operation"
else:
result = execute_tool(
tc["name"], tc["input"],
permission_mode="accept-all",
config=config,
)
return tc, result, permitted

results_ordered = []

# Run parallel batch concurrently
if parallel_batch:
from concurrent.futures import ThreadPoolExecutor
for tc in parallel_batch:
yield ToolStart(tc["name"], tc["input"])
with ThreadPoolExecutor(max_workers=min(len(parallel_batch), 8)) as pool:
futures = {pool.submit(_exec_one, tc): tc for tc in parallel_batch}
for future in futures:
tc, result, permitted = future.result()
_log.debug("tool_end", session_id=session_id,
tool=tc["name"], permitted=permitted,
result_len=len(result))
results_ordered.append((tc, result, permitted))

# Run sequential batch one by one
for tc in sequential_batch:
yield ToolStart(tc["name"], tc["input"])
_log.debug("tool_start", session_id=session_id,
tool=tc["name"], input_keys=list(tc["input"].keys()))
tc, result, permitted = _exec_one(tc)
_log.debug("tool_end", session_id=session_id,
tool=tc["name"], permitted=permitted,
result_len=len(result))
results_ordered.append((tc, result, permitted))
denied_results[tc["id"]] = "Denied: user rejected this operation"

# ── Execute tools via DAG (parallel when safe) ────────────────────
from dag import _build_dag_levels, _execute_level

permitted_tcs = [tc for tc in tool_calls if permissions[tc["id"]]]
results: dict[str, str] = dict(denied_results)
durations: dict[str, float] = {tc["id"]: 0.0 for tc in tool_calls}

levels, deps = _build_dag_levels(permitted_tcs)
for level in levels:
_execute_level(level, results, durations, config)

# Yield results and append to state in original order
for tc, result, permitted in results_ordered:
yield ToolEnd(tc["name"], result, permitted)
for tc in tool_calls:
if tc["id"] not in results:
continue
result = results[tc["id"]]
yield ToolEnd(tc["name"], result, permissions[tc["id"]],
durations[tc["id"]], tool_id=tc["id"], inputs=tc["input"])
state.messages.append({
"role": "tool",
"tool_call_id": tc["id"],
"name": tc["name"],
"content": result,
})

# ContextGC is terminal — if it's the only thing called,
# don't loop back to the LLM or it re-calls GC forever.
if tool_calls and all(tc["name"] == "ContextGC" for tc in tool_calls):
break


# ── Helpers ───────────────────────────────────────────────────────────────

Expand Down
157 changes: 157 additions & 0 deletions dag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
"""Build a dependency DAG from tool calls and execute them level-by-level with parallelism.

Respects concurrent_safe flags: safe tools run in parallel within a level,
unsafe tools run sequentially after the parallel batch completes.
"""
from __future__ import annotations

import os
import time
from concurrent.futures import ThreadPoolExecutor

from tools import execute_tool


def _build_alias_map(tool_calls: list[dict]) -> dict[str, str]:
alias_to_id: dict[str, str] = {}
for tc in tool_calls:
alias = tc["input"].pop("tool_call_alias", None)
if alias:
alias_to_id[alias] = tc["id"]
return alias_to_id


def _build_dag_levels(tool_calls: list[dict]) -> tuple[list[list[dict]], dict[str, set[str]]]:
"""Return (levels, deps). Levels order tcs by dependency; deps maps tc_id -> its prerequisites."""
if not tool_calls:
return [], {}

by_id: dict[str, dict] = {tc["id"]: tc for tc in tool_calls}
tc_ids_in_turn = set(by_id.keys())

alias_to_id = _build_alias_map(tool_calls)

deps: dict[str, set[str]] = {}
for tc in tool_calls:
raw = tc["input"].pop("depends_on", None) or []
resolved = [alias_to_id.get(d, d) for d in raw]
deps[tc["id"]] = {d for d in resolved if d in tc_ids_in_turn}

_add_implicit_write_deps(tool_calls, deps)

remaining = set(by_id.keys())
levels: list[list[dict]] = []
while remaining:
ready = {nid for nid in remaining
if not (deps.get(nid, set()) & remaining)}
if not ready:
levels.append([by_id[nid] for nid in remaining])
break
levels.append([by_id[nid] for nid in ready])
remaining -= ready
return levels, deps


def _compute_downstream(deps: dict[str, set[str]], seed_ids: set[str]) -> set[str]:
"""Return seed_ids plus every tc_id transitively depending on any seed."""
dependents: dict[str, set[str]] = {}
for tc_id, prereqs in deps.items():
for prereq in prereqs:
dependents.setdefault(prereq, set()).add(tc_id)

downstream = set(seed_ids)
frontier = list(seed_ids)
while frontier:
node = frontier.pop()
for child in dependents.get(node, ()):
if child not in downstream:
downstream.add(child)
frontier.append(child)
return downstream


def _add_implicit_write_deps(
tool_calls: list[dict],
deps: dict[str, set[str]],
) -> None:
last_write: dict[str, str] = {}
for tc in tool_calls:
if tc["name"] not in ("Write", "Edit", "NotebookEdit"):
continue
fp = os.path.normpath(tc["input"].get("file_path", tc["input"].get("notebook_path", "")))
if not fp:
continue
prev = last_write.get(fp)
if prev is not None:
deps.setdefault(tc["id"], set()).add(prev)
last_write[fp] = tc["id"]


def _execute_level(
level: list[dict],
results: dict[str, str],
durations: dict[str, float],
config: dict,
) -> None:
from tool_registry import is_concurrent_safe

if len(level) == 1:
tc = level[0]
t0 = time.monotonic()
results[tc["id"]] = execute_tool(
tc["name"], tc["input"],
permission_mode="accept-all", config=config)
durations[tc["id"]] = time.monotonic() - t0
return

parallel_tcs = [tc for tc in level if is_concurrent_safe(tc["name"])]
sequential_tcs = [tc for tc in level if not is_concurrent_safe(tc["name"])]

if parallel_tcs:
_run_parallel(parallel_tcs, results, durations, config)

for tc in sequential_tcs:
t0 = time.monotonic()
results[tc["id"]] = execute_tool(
tc["name"], tc["input"],
permission_mode="accept-all", config=config)
durations[tc["id"]] = time.monotonic() - t0


def _run_parallel(
tcs: list[dict],
results: dict[str, str],
durations: dict[str, float],
config: dict,
) -> None:
if len(tcs) == 1:
tc = tcs[0]
t0 = time.monotonic()
results[tc["id"]] = execute_tool(
tc["name"], tc["input"],
permission_mode="accept-all", config=config)
durations[tc["id"]] = time.monotonic() - t0
return

pool = ThreadPoolExecutor(max_workers=len(tcs))
start_times: dict[str, float] = {}
for tc in tcs:
start_times[tc["id"]] = time.monotonic()
futures = {
pool.submit(execute_tool, tc["name"], tc["input"],
"accept-all", None, config): tc
for tc in tcs
}
try:
remaining_futs = set(futures)
while remaining_futs:
newly_done = {f for f in remaining_futs if f.done()}
for fut in newly_done:
tc = futures[fut]
results[tc["id"]] = fut.result()
durations[tc["id"]] = time.monotonic() - start_times[tc["id"]]
remaining_futs -= newly_done
if remaining_futs:
time.sleep(0.1)
finally:
pool.shutdown(wait=False, cancel_futures=True)
76 changes: 76 additions & 0 deletions id_uniquify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Prevent ContextGC auto-stubbing of fresh tool_results.

The model picks short ids (e.g. `r1`, `w2`) per the XML tool protocol and
reuses them freely across turns. Once an id lands in `gc_state.trashed_ids`,
any future tool_result with the same id is stubbed by `apply_gc`, even though
it is a completely different tool call. The model then sees `[Read result --
trashed by model]` instead of the file it just asked for.

The safe fix is to uniquify the id on ingest. If the incoming id clashes with
any id already present in `state.messages` or in `gc_state.trashed_ids`, we
rewrite it to `t{turn}_{original}` (with a numeric suffix if that still
clashes). Same-turn `depends_on` references are rewritten in lockstep so the
DAG resolves correctly.
"""
from __future__ import annotations


def _collect_used_ids(state) -> set[str]:
used: set[str] = set()
gc_state = getattr(state, "gc_state", None)
if gc_state is not None:
used.update(gc_state.trashed_ids)
used.update(gc_state.snippets.keys())
for msg in state.messages:
role = msg.get("role")
if role == "assistant":
for tc in msg.get("tool_calls") or []:
tid = tc.get("id")
if tid:
used.add(tid)
elif role == "tool":
tid = msg.get("tool_call_id")
if tid:
used.add(tid)
return used


def _pick_fresh_id(original: str, turn: int, used: set[str]) -> str:
candidate = f"t{turn}_{original}"
if candidate not in used:
return candidate
suffix = 2
while f"{candidate}_{suffix}" in used:
suffix += 1
return f"{candidate}_{suffix}"


def uniquify_tool_call_ids(tool_calls: list, state) -> dict[str, str]:
"""Rewrite colliding tool_call ids in-place and rewrite same-turn depends_on refs.

Only ids that already exist in state.messages or gc_state.trashed_ids are
remapped. Ids the model has never used before pass through unchanged,
preserving behavior for simple sessions and existing tests."""
if not tool_calls:
return {}
used = _collect_used_ids(state)
remap: dict[str, str] = {}
for tc in tool_calls:
original = tc.get("id")
if not original or original not in used:
if original:
used.add(original)
continue
fresh = _pick_fresh_id(original, state.turn_count, used)
remap[original] = fresh
tc["id"] = fresh
used.add(fresh)

if not remap:
return {}
for tc in tool_calls:
params = tc.get("input") or {}
deps = params.get("depends_on")
if deps:
params["depends_on"] = [remap.get(d, d) for d in deps]
return remap
Loading
Loading