diff --git a/agent.py b/agent.py index bf37c75..1a23916 100644 --- a/agent.py +++ b/agent.py @@ -40,12 +40,16 @@ class AgentState: class ToolStart: name: str inputs: dict + tool_id: str = "" @dataclass class ToolEnd: name: str result: str permitted: bool = True + duration: float = 0.0 + tool_id: str = "" + inputs: dict = field(default_factory=dict) @dataclass class TurnDone: @@ -201,12 +205,21 @@ def run( if not assistant_turn.tool_calls: break # No tools → conversation turn complete - # ── Execute tools (parallel when safe) ──────────────────────────── - tool_calls = assistant_turn.tool_calls + # ── Uniquify ids to prevent GC collisions ───────────────────────── + from id_uniquify import uniquify_tool_call_ids + uniquify_tool_call_ids(assistant_turn.tool_calls, state) - # Check permissions first (must be sequential — may prompt user) + # Deduplicate tool calls by ID (model may echo duplicates) + _seen_ids: set[str] = set() + tool_calls = [tc for tc in assistant_turn.tool_calls + if tc["id"] not in _seen_ids and not _seen_ids.add(tc["id"])] + state.messages[-1]["tool_calls"] = tool_calls + + # ── Check permissions (sequential — may prompt user) ────────────── permissions: dict[str, bool] = {} + denied_results: dict[str, str] = {} for tc in tool_calls: + yield ToolStart(tc["name"], tc["input"], tool_id=tc["id"]) permitted = _check_permission(tc, config) if not permitted: if config.get("permission_mode") == "plan": @@ -216,73 +229,35 @@ def run( yield req permitted = req.granted permissions[tc["id"]] = permitted - - # Determine which tools can run in parallel - from tool_registry import get_tool as _get_tool - parallel_batch = [] - sequential_batch = [] - for tc in tool_calls: - if not permissions[tc["id"]]: - sequential_batch.append(tc) - continue - tdef = _get_tool(tc["name"]) - if tdef and tdef.concurrent_safe and len(tool_calls) > 1: - parallel_batch.append(tc) - else: - sequential_batch.append(tc) - - def _exec_one(tc): - """Execute a single tool call, return 
(tc, result, permitted).""" - tid = tc["id"] - permitted = permissions[tid] if not permitted: if config.get("permission_mode") == "plan": plan_file = runtime.get_ctx(config).plan_file or "" - result = ( + denied_results[tc["id"]] = ( f"[Plan mode] Write operations are blocked except to the plan file: {plan_file}\n" "Finish your analysis and write the plan to the plan file. " "The user will run /plan done to exit plan mode and begin implementation." ) else: - result = "Denied: user rejected this operation" - else: - result = execute_tool( - tc["name"], tc["input"], - permission_mode="accept-all", - config=config, - ) - return tc, result, permitted - - results_ordered = [] - - # Run parallel batch concurrently - if parallel_batch: - from concurrent.futures import ThreadPoolExecutor - for tc in parallel_batch: - yield ToolStart(tc["name"], tc["input"]) - with ThreadPoolExecutor(max_workers=min(len(parallel_batch), 8)) as pool: - futures = {pool.submit(_exec_one, tc): tc for tc in parallel_batch} - for future in futures: - tc, result, permitted = future.result() - _log.debug("tool_end", session_id=session_id, - tool=tc["name"], permitted=permitted, - result_len=len(result)) - results_ordered.append((tc, result, permitted)) - - # Run sequential batch one by one - for tc in sequential_batch: - yield ToolStart(tc["name"], tc["input"]) - _log.debug("tool_start", session_id=session_id, - tool=tc["name"], input_keys=list(tc["input"].keys())) - tc, result, permitted = _exec_one(tc) - _log.debug("tool_end", session_id=session_id, - tool=tc["name"], permitted=permitted, - result_len=len(result)) - results_ordered.append((tc, result, permitted)) + denied_results[tc["id"]] = "Denied: user rejected this operation" + + # ── Execute tools via DAG (parallel when safe) ──────────────────── + from dag import _build_dag_levels, _execute_level + + permitted_tcs = [tc for tc in tool_calls if permissions[tc["id"]]] + results: dict[str, str] = dict(denied_results) + durations: dict[str, 
def _build_alias_map(tool_calls: list[dict]) -> dict[str, str]:
    """Map each ``tool_call_alias`` declared in this turn to its call id.

    The alias key is popped from every call's ``input`` dict, so the dicts
    are mutated in place. When two calls declare the same alias, the later
    call wins.
    """
    alias_to_id: dict[str, str] = {}
    for tc in tool_calls:
        alias = tc["input"].pop("tool_call_alias", None)
        if alias:
            alias_to_id[alias] = tc["id"]
    return alias_to_id


def _build_dag_levels(tool_calls: list[dict]) -> tuple[list[list[dict]], dict[str, set[str]]]:
    """Return (levels, deps) for one turn's tool calls.

    ``levels`` is a list of batches: every call in a batch has all of its
    prerequisites in earlier batches. Inside a batch, calls keep the order in
    which they appear in ``tool_calls`` so that execution order is
    deterministic. ``deps`` maps tc_id -> the set of tc_ids it must wait for.

    Side effect: ``tool_call_alias`` and ``depends_on`` are popped from each
    call's ``input`` dict.

    If the dependencies contain a cycle, every call that cannot be scheduled
    is placed together in one final level instead of raising.
    """
    if not tool_calls:
        return [], {}

    by_id: dict[str, dict] = {tc["id"]: tc for tc in tool_calls}
    alias_to_id = _build_alias_map(tool_calls)

    deps: dict[str, set[str]] = {}
    for tc in tool_calls:
        raw = tc["input"].pop("depends_on", None) or []
        if isinstance(raw, str):
            # A bare string is a single reference, not a sequence of characters.
            raw = [raw]
        resolved = (alias_to_id.get(ref, ref) for ref in raw)
        # Keep only references to calls of this turn; a call depending on
        # itself is ignored, otherwise it would look like a cycle.
        deps[tc["id"]] = {ref for ref in resolved
                          if ref in by_id and ref != tc["id"]}

    _add_implicit_write_deps(tool_calls, deps)

    order = list(by_id)  # unique ids, first-seen order
    remaining = set(order)
    levels: list[list[dict]] = []
    while remaining:
        ready = [nid for nid in order
                 if nid in remaining and not (deps.get(nid, set()) & remaining)]
        if not ready:
            # Dependency cycle: nothing can be scheduled, so emit the rest as
            # one level, still in the original order.
            levels.append([by_id[nid] for nid in order if nid in remaining])
            break
        levels.append([by_id[nid] for nid in ready])
        remaining.difference_update(ready)
    return levels, deps


def _compute_downstream(deps: dict[str, set[str]], seed_ids: set[str]) -> set[str]:
    """Return seed_ids plus every tc_id transitively depending on any seed."""
    # Invert the prerequisite map: prerequisite -> calls that wait on it.
    dependents: dict[str, set[str]] = {}
    for tc_id, prereqs in deps.items():
        for prereq in prereqs:
            dependents.setdefault(prereq, set()).add(tc_id)

    downstream = set(seed_ids)
    frontier = list(seed_ids)
    while frontier:
        node = frontier.pop()
        for child in dependents.get(node, ()):
            if child not in downstream:
                downstream.add(child)
                frontier.append(child)
    return downstream


def _add_implicit_write_deps(
    tool_calls: list[dict],
    deps: dict[str, set[str]],
) -> None:
    """Serialize writes that target the same file.

    For every ``Write`` / ``Edit`` / ``NotebookEdit`` call, a dependency on the
    previous such call with the same normalized path is added to ``deps``
    (mutated in place). Calls without a usable path are left untouched.
    """
    last_write: dict[str, str] = {}
    for tc in tool_calls:
        if tc["name"] not in ("Write", "Edit", "NotebookEdit"):
            continue
        params = tc["input"]
        raw_path = params.get("file_path") or params.get("notebook_path")
        # Check before normalizing: os.path.normpath("") returns ".", which
        # would make all path-less writes look like writes to the same file.
        if not raw_path or not isinstance(raw_path, str):
            continue
        fp = os.path.normpath(raw_path)
        prev = last_write.get(fp)
        if prev is not None and prev != tc["id"]:
            deps.setdefault(tc["id"], set()).add(prev)
        last_write[fp] = tc["id"]
def _timed_execute(tc: dict, config: dict) -> tuple[str, float]:
    """Run one tool call and return ``(result, elapsed_seconds)``.

    The clock starts immediately before ``execute_tool`` is invoked, so the
    elapsed time covers only the tool itself.
    """
    t0 = time.monotonic()
    result = execute_tool(
        tc["name"], tc["input"],
        permission_mode="accept-all", config=config)
    return result, time.monotonic() - t0


def _execute_level(
    level: list[dict],
    results: dict[str, str],
    durations: dict[str, float],
    config: dict,
) -> None:
    """Execute every tool call of one DAG level.

    ``results`` and ``durations`` are filled in place, keyed by tool-call id.
    A single-call level runs directly. Otherwise the calls whose tool is
    reported concurrent-safe by ``tool_registry.is_concurrent_safe`` run
    first, in parallel, and the remaining calls run afterwards one at a time
    in the order given.
    """
    # Imported at call time, as in the original code.
    from tool_registry import is_concurrent_safe

    if len(level) == 1:
        tc = level[0]
        results[tc["id"]], durations[tc["id"]] = _timed_execute(tc, config)
        return

    parallel_tcs: list[dict] = []
    sequential_tcs: list[dict] = []
    for tc in level:
        bucket = parallel_tcs if is_concurrent_safe(tc["name"]) else sequential_tcs
        bucket.append(tc)

    if parallel_tcs:
        _run_parallel(parallel_tcs, results, durations, config)

    for tc in sequential_tcs:
        results[tc["id"]], durations[tc["id"]] = _timed_execute(tc, config)


def _run_parallel(
    tcs: list[dict],
    results: dict[str, str],
    durations: dict[str, float],
    config: dict,
) -> None:
    """Run ``tcs`` concurrently, one worker thread per call.

    ``results`` and ``durations`` are filled in place, keyed by tool-call id.
    Each duration is measured inside the worker. If a tool raises, the
    exception propagates from here and the pool is released without waiting
    for the calls that are still running.
    """
    if not tcs:
        # ThreadPoolExecutor rejects max_workers=0.
        return
    if len(tcs) == 1:
        tc = tcs[0]
        results[tc["id"]], durations[tc["id"]] = _timed_execute(tc, config)
        return

    from concurrent.futures import FIRST_COMPLETED, wait

    pool = ThreadPoolExecutor(max_workers=len(tcs))
    try:
        futures = {pool.submit(_timed_execute, tc, config): tc for tc in tcs}
        pending = set(futures)
        while pending:
            # Returns as soon as any call finishes. The short timeout keeps
            # the waiting thread waking up regularly, as the original
            # polling loop did.
            done, pending = wait(pending, timeout=0.1,
                                 return_when=FIRST_COMPLETED)
            for fut in done:
                tc = futures[fut]
                results[tc["id"]], durations[tc["id"]] = fut.result()
    finally:
        pool.shutdown(wait=False, cancel_futures=True)
def _collect_used_ids(state) -> set[str]:
    """Gather every tool-call id that ``state`` already knows about.

    Sources: ``gc_state.trashed_ids`` and the keys of ``gc_state.snippets``
    (when ``state`` has a non-None ``gc_state``), the ids of tool calls on
    assistant messages, and the ``tool_call_id`` of tool messages.
    """
    used: set[str] = set()
    gc_state = getattr(state, "gc_state", None)
    if gc_state is not None:
        used.update(gc_state.trashed_ids)
        used.update(gc_state.snippets.keys())
    for msg in state.messages:
        role = msg.get("role")
        if role == "assistant":
            for tc in msg.get("tool_calls") or []:
                tid = tc.get("id")
                if tid:
                    used.add(tid)
        elif role == "tool":
            tid = msg.get("tool_call_id")
            if tid:
                used.add(tid)
    return used


def _pick_fresh_id(original: str, turn: int, used: set[str]) -> str:
    """Return ``t{turn}_{original}``, adding ``_2``, ``_3``, ... until unused."""
    candidate = f"t{turn}_{original}"
    if candidate not in used:
        return candidate
    suffix = 2
    while f"{candidate}_{suffix}" in used:
        suffix += 1
    return f"{candidate}_{suffix}"


def uniquify_tool_call_ids(tool_calls: list, state) -> dict[str, str]:
    """Rewrite colliding tool_call ids in-place and rewrite same-turn depends_on refs.

    Only ids that already exist in state.messages or in gc_state are remapped;
    ids that have never been used pass through unchanged.

    An id that appears more than once within ``tool_calls`` is an echoed
    duplicate, not a collision with history: every repeat receives the same
    final id as the first occurrence. That keeps id-based de-duplication
    working downstream and keeps ``depends_on`` references pointing at the
    first call.

    Returns the mapping original id -> new id (empty when nothing changed).
    """
    if not tool_calls:
        return {}
    used = _collect_used_ids(state)
    remap: dict[str, str] = {}
    seen_this_turn: set[str] = set()
    for tc in tool_calls:
        original = tc.get("id")
        if not original:
            continue
        if original in seen_this_turn:
            # Same-turn duplicate: follow whatever the first occurrence got.
            tc["id"] = remap.get(original, original)
            continue
        seen_this_turn.add(original)
        if original not in used:
            used.add(original)
            continue
        fresh = _pick_fresh_id(original, state.turn_count, used)
        remap[original] = fresh
        tc["id"] = fresh
        used.add(fresh)

    if not remap:
        return {}
    for tc in tool_calls:
        params = tc.get("input") or {}
        deps = params.get("depends_on")
        if not deps:
            continue
        if isinstance(deps, str):
            # A single reference given as a plain string stays a string.
            params["depends_on"] = remap.get(deps, deps)
        else:
            params["depends_on"] = [remap.get(d, d) for d in deps]
    return remap
end, + }) + return f"ok:{tool_name}" + return _fake + + +def _clear_log(): + with _log_lock: + _execution_log.clear() + + +def _overlaps(a: dict, b: dict) -> bool: + """True if two execution intervals overlap in time.""" + return a["start"] < b["end"] and b["start"] < a["end"] + + +# --------------------------------------------------------------------------- +# Fixtures — monkeypatch execute_tool and is_concurrent_safe +# --------------------------------------------------------------------------- + +_CONCURRENT_SAFE_TOOLS = {"Read", "Grep", "Glob", "GetDiagnostics", "ContextGC"} + + +@pytest.fixture(autouse=True) +def _patch_tools(monkeypatch): + """Replace execute_tool with a fake that logs timing, and patch is_concurrent_safe.""" + fake = _make_fake_tool("generic", duration=0.05) + monkeypatch.setattr("dag.execute_tool", fake) + + def _fake_is_concurrent_safe(name: str) -> bool: + return name in _CONCURRENT_SAFE_TOOLS + + monkeypatch.setattr("tool_registry.is_concurrent_safe", _fake_is_concurrent_safe) + _clear_log() + yield + _clear_log() + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +def test_single_tool_runs_directly(): + """Single tool in a level runs without ThreadPoolExecutor.""" + level = [{"id": "t1", "name": "Read", "input": {"_tc_id": "t1"}}] + results, durations = {}, {} + _execute_level(level, results, durations, config={}) + + assert results["t1"] == "ok:Read" + assert durations["t1"] > 0 + assert len(_execution_log) == 1 + + +def test_all_concurrent_safe_run_in_parallel(): + """Multiple concurrent_safe tools run in parallel (overlapping times).""" + level = [ + {"id": f"t{i}", "name": "Read", "input": {"_tc_id": f"t{i}"}} + for i in range(4) + ] + results, durations = {}, {} + _execute_level(level, results, durations, config={}) + + assert len(results) == 4 + assert all(v.startswith("ok:") for v in results.values()) + # With 
4 parallel tools each sleeping 50ms, total should be ~50-100ms not ~200ms + logs = sorted(_execution_log, key=lambda x: x["start"]) + assert any(_overlaps(logs[i], logs[j]) + for i in range(len(logs)) for j in range(i + 1, len(logs))), \ + "Expected overlapping execution for concurrent_safe tools" + + +def test_all_sequential_run_one_at_a_time(): + """Multiple non-concurrent_safe tools run sequentially (no overlap).""" + level = [ + {"id": f"t{i}", "name": "Bash", "input": {"_tc_id": f"t{i}"}} + for i in range(3) + ] + results, durations = {}, {} + _execute_level(level, results, durations, config={}) + + assert len(results) == 3 + logs = sorted(_execution_log, key=lambda x: x["start"]) + for i in range(len(logs) - 1): + assert not _overlaps(logs[i], logs[i + 1]), \ + f"Sequential tools {logs[i]['id']} and {logs[i+1]['id']} should not overlap" + + +def test_mixed_level_parallel_then_sequential(): + """Mixed level: concurrent_safe tools run in parallel first, then sequential tools one at a time.""" + level = [ + {"id": "r1", "name": "Read", "input": {"_tc_id": "r1"}}, + {"id": "r2", "name": "Grep", "input": {"_tc_id": "r2"}}, + {"id": "r3", "name": "Glob", "input": {"_tc_id": "r3"}}, + {"id": "s1", "name": "Bash", "input": {"_tc_id": "s1"}}, + {"id": "s2", "name": "Skill", "input": {"_tc_id": "s2"}}, + ] + results, durations = {}, {} + _execute_level(level, results, durations, config={}) + + assert len(results) == 5 + + logs_by_id = {e["id"]: e for e in _execution_log} + parallel_logs = [logs_by_id[k] for k in ("r1", "r2", "r3")] + seq_logs = sorted([logs_by_id[k] for k in ("s1", "s2")], key=lambda x: x["start"]) + + # Parallel tools should all finish before sequential tools start + parallel_end = max(e["end"] for e in parallel_logs) + seq_start = min(e["start"] for e in seq_logs) + assert parallel_end <= seq_start + 0.01, \ + "Sequential tools should start after parallel tools finish" + + # Sequential tools should not overlap each other + assert not 
_overlaps(seq_logs[0], seq_logs[1]), \ + "Sequential tools should not overlap" + + +def test_only_sequential_tools_no_parallel_batch(): + """When all tools are non-concurrent_safe, no parallel batch runs.""" + level = [ + {"id": "w1", "name": "Write", "input": {"_tc_id": "w1"}}, + {"id": "w2", "name": "Edit", "input": {"_tc_id": "w2"}}, + ] + results, durations = {}, {} + _execute_level(level, results, durations, config={}) + + assert len(results) == 2 + logs = sorted(_execution_log, key=lambda x: x["start"]) + assert not _overlaps(logs[0], logs[1]) + + +def test_build_dag_levels_all_independent(): + """All independent tools land in a single level.""" + tcs = [ + {"id": "t1", "name": "Read", "input": {}}, + {"id": "t2", "name": "Grep", "input": {}}, + {"id": "t3", "name": "Skill", "input": {}}, + ] + levels, deps = _build_dag_levels(tcs) + assert len(levels) == 1 + assert len(levels[0]) == 3 + + +def test_build_dag_levels_with_deps(): + """Tools with depends_on are ordered into separate levels.""" + tcs = [ + {"id": "t1", "name": "Write", "input": {"tool_call_alias": "w1"}}, + {"id": "t2", "name": "Bash", "input": {"depends_on": ["w1"]}}, + ] + levels, deps = _build_dag_levels(tcs) + assert len(levels) == 2 + assert levels[0][0]["id"] == "t1" + assert levels[1][0]["id"] == "t2" diff --git a/tests/test_id_uniquify.py b/tests/test_id_uniquify.py new file mode 100644 index 0000000..00ee644 --- /dev/null +++ b/tests/test_id_uniquify.py @@ -0,0 +1,141 @@ +"""Tests for id_uniquify — prevent GC auto-stubbing of fresh tool_results.""" +from __future__ import annotations + +import os +import sys + +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +import pytest +from id_uniquify import uniquify_tool_call_ids, _collect_used_ids, _pick_fresh_id + + +# ── Helpers ────────────────────────────────────────────────────────────── + +class _FakeGCState: + def __init__(self, trashed_ids=None, snippets=None): + self.trashed_ids = trashed_ids or set() + self.snippets = 
snippets or {} + + +class _FakeState: + def __init__(self, messages=None, gc_state=None, turn_count=1): + self.messages = messages or [] + self.gc_state = gc_state or _FakeGCState() + self.turn_count = turn_count + + +# ── collect_used_ids ───────────────────────────────────────────────────── + +class TestCollectUsedIds: + def test_empty_state(self): + state = _FakeState() + assert _collect_used_ids(state) == set() + + def test_collects_from_messages(self): + state = _FakeState(messages=[ + {"role": "assistant", "tool_calls": [{"id": "r1"}]}, + {"role": "tool", "tool_call_id": "r1"}, + ]) + assert "r1" in _collect_used_ids(state) + + def test_collects_from_trashed_ids(self): + gc = _FakeGCState(trashed_ids={"old1", "old2"}) + state = _FakeState(gc_state=gc) + used = _collect_used_ids(state) + assert "old1" in used + assert "old2" in used + + def test_collects_from_snippets(self): + gc = _FakeGCState(snippets={"s1": {"keep_after": "foo"}}) + state = _FakeState(gc_state=gc) + assert "s1" in _collect_used_ids(state) + + def test_no_gc_state(self): + """Gracefully handles state without gc_state.""" + state = _FakeState() + state.gc_state = None + assert _collect_used_ids(state) == set() + + +# ── pick_fresh_id ──────────────────────────────────────────────────────── + +class TestPickFreshId: + def test_basic(self): + assert _pick_fresh_id("r1", 2, set()) == "t2_r1" + + def test_conflict_adds_suffix(self): + used = {"t2_r1"} + assert _pick_fresh_id("r1", 2, used) == "t2_r1_2" + + def test_multiple_conflicts(self): + used = {"t3_x", "t3_x_2", "t3_x_3"} + assert _pick_fresh_id("x", 3, used) == "t3_x_4" + + +# ── uniquify_tool_call_ids ─────────────────────────────────────────────── + +class TestUniquifyToolCallIds: + def test_no_collision_no_remap(self): + tcs = [{"id": "r1", "name": "Read", "input": {}}] + state = _FakeState() + remap = uniquify_tool_call_ids(tcs, state) + assert remap == {} + assert tcs[0]["id"] == "r1" + + def test_collision_with_trashed_id(self): + gc 
= _FakeGCState(trashed_ids={"r1"}) + state = _FakeState(gc_state=gc, turn_count=2) + tcs = [{"id": "r1", "name": "Read", "input": {}}] + remap = uniquify_tool_call_ids(tcs, state) + assert remap == {"r1": "t2_r1"} + assert tcs[0]["id"] == "t2_r1" + + def test_collision_with_existing_message(self): + state = _FakeState( + messages=[ + {"role": "assistant", "tool_calls": [{"id": "r1"}]}, + {"role": "tool", "tool_call_id": "r1"}, + ], + turn_count=3, + ) + tcs = [{"id": "r1", "name": "Read", "input": {}}] + remap = uniquify_tool_call_ids(tcs, state) + assert remap == {"r1": "t3_r1"} + assert tcs[0]["id"] == "t3_r1" + + def test_depends_on_rewritten(self): + gc = _FakeGCState(trashed_ids={"w1"}) + state = _FakeState(gc_state=gc, turn_count=2) + tcs = [ + {"id": "w1", "name": "Write", "input": {}}, + {"id": "b1", "name": "Bash", "input": {"depends_on": ["w1"]}}, + ] + remap = uniquify_tool_call_ids(tcs, state) + assert tcs[0]["id"] == "t2_w1" + assert tcs[1]["input"]["depends_on"] == ["t2_w1"] + + def test_multiple_collisions(self): + gc = _FakeGCState(trashed_ids={"r1", "r2"}) + state = _FakeState(gc_state=gc, turn_count=4) + tcs = [ + {"id": "r1", "name": "Read", "input": {}}, + {"id": "r2", "name": "Read", "input": {}}, + {"id": "r3", "name": "Read", "input": {}}, + ] + remap = uniquify_tool_call_ids(tcs, state) + assert "r1" in remap + assert "r2" in remap + assert "r3" not in remap + assert tcs[2]["id"] == "r3" + + def test_empty_tool_calls(self): + state = _FakeState() + assert uniquify_tool_call_ids([], state) == {} + + def test_no_gc_state_graceful(self): + state = _FakeState() + state.gc_state = None + tcs = [{"id": "r1", "name": "Read", "input": {}}] + remap = uniquify_tool_call_ids(tcs, state) + assert remap == {} diff --git a/tool_registry.py b/tool_registry.py index f0a66c2..8f39660 100644 --- a/tool_registry.py +++ b/tool_registry.py @@ -64,6 +64,12 @@ def get_tool(name: str) -> Optional[ToolDef]: return _registry.get(name) +def is_concurrent_safe(name: 
def is_concurrent_safe(name: str) -> bool:
    """Report whether the tool registered under *name* may run in parallel.

    A name with no (or a falsy) registry entry is reported as not safe.
    """
    entry = _registry.get(name)
    if not entry:
        return False
    return entry.concurrent_safe


def get_all_tools() -> List[ToolDef]:
    """Return a new list of every registered tool, in registry iteration order."""
    return [*_registry.values()]