diff --git a/checkpoint/store.py b/checkpoint/store.py
index ec770fc..93f3e24 100644
--- a/checkpoint/store.py
+++ b/checkpoint/store.py
@@ -12,6 +12,7 @@
 import json
 import os
 import shutil
+import sys
 import time
 from datetime import datetime, timedelta
 from pathlib import Path
@@ -97,7 +98,7 @@ def track_file_edit(session_id: str, file_path: str) -> str | None:
     except OSError:
         return None
     if size > _MAX_FILE_SIZE:
-        print(f"[checkpoint] skipping large file ({size} bytes): {file_path}")
+        print(f"[checkpoint] skipping large file ({size} bytes): {file_path}", file=sys.stderr)
         return None
 
     # Copy file to backups/
@@ -107,7 +108,7 @@ def track_file_edit(session_id: str, file_path: str) -> str | None:
     try:
         shutil.copy2(str(p), str(backup_path))
     except Exception as e:
-        print(f"[checkpoint] backup failed for {file_path}: {e}")
+        print(f"[checkpoint] backup failed for {file_path}: {e}", file=sys.stderr)
         return None
 
     return backup_name
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..f935fc6
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,14 @@
+"""Shared pytest fixtures for all tests."""
+
+from __future__ import annotations
+
+import pytest
+
+
+# --------------- quota stub (avoids ImportError on CI for calc_cost) --------
+
+@pytest.fixture(autouse=True)
+def _no_quota(monkeypatch):
+    """Disable quota.record_usage so tests never hit the real billing path."""
+    import quota
+    monkeypatch.setattr(quota, "record_usage", lambda *a, **kw: None)
diff --git a/tests/helpers.py b/tests/helpers.py
new file mode 100644
index 0000000..7d37a79
--- /dev/null
+++ b/tests/helpers.py
@@ -0,0 +1,38 @@
+"""Reusable test helpers (importable from any test module)."""
+
+from __future__ import annotations
+
+from providers import AssistantTurn
+
+
+def scripted_stream(captured_schemas: list, turns: list[dict]):
+    """Return a fake ``stream()`` callable that yields pre-defined turns.
+
+    *captured_schemas* receives the ``tool_schemas`` kwarg from each call,
+    letting tests assert on schema injection. *turns* is a list of dicts,
+    each with optional ``text`` and ``tool_calls`` keys.
+
+    Each generator returned by the fake consumes the next entry of *turns*
+    when iterated; running past the end of the script raises
+    ``AssertionError`` instead of an opaque ``RuntimeError``.
+    """
+    cursor = iter(turns)
+
+    def fake_stream(**kwargs):
+        captured_schemas.append(kwargs.get("tool_schemas") or [])
+        try:
+            spec = next(cursor)
+        except StopIteration:
+            # A StopIteration escaping a generator body is re-raised as an
+            # opaque RuntimeError (PEP 479); fail with a readable message.
+            raise AssertionError(
+                "scripted stream exhausted: no turns left"
+            ) from None
+        yield AssistantTurn(
+            text=spec.get("text", ""),
+            tool_calls=spec.get("tool_calls") or [],
+            in_tokens=1,
+            out_tokens=1,
+        )
+
+    return fake_stream
diff --git a/tests/test_checkpoint_e2e.py b/tests/test_checkpoint_e2e.py
new file mode 100644
index 0000000..09ff441
--- /dev/null
+++ b/tests/test_checkpoint_e2e.py
@@ -0,0 +1,120 @@
+"""End-to-end: drive a real agent.run() conversation where the LLM calls Write,
+and verify the checkpoint hook intercepts the call and files a backup to disk.
+
+Only the LLM provider is mocked (via monkeypatching agent.stream). The Write
+tool, checkpoint hooks and checkpoint store all run for real against tmp_path.
+"""
+from __future__ import annotations
+
+import pytest
+
+import tools as _tools_init  # noqa: F401 - force built-in tool registration
+from agent import AgentState, run
+from providers import AssistantTurn
+from checkpoint import hooks as checkpoint_hooks
+from checkpoint import store as checkpoint_store
+
+
+def _scripted_stream(turns):
+    """Return a fake ``stream()`` that yields one scripted turn per call."""
+    cursor = iter(turns)
+
+    def fake_stream(**_kwargs):
+        try:
+            spec = next(cursor)
+        except StopIteration:
+            # A StopIteration escaping a generator body is re-raised as an
+            # opaque RuntimeError (PEP 479); fail with a readable message.
+            raise AssertionError(
+                "scripted stream exhausted: no turns left"
+            ) from None
+        yield AssistantTurn(
+            text=spec.get("text", ""),
+            tool_calls=spec.get("tool_calls") or [],
+            in_tokens=1, out_tokens=1,
+        )
+
+    return fake_stream
+
+
+@pytest.fixture
+def sandboxed_checkpoints(tmp_path, monkeypatch):
+    """Run checkpoint store against tmp_path and install hooks on built-in tools."""
+    monkeypatch.setattr(
+        checkpoint_store, "_checkpoints_root", lambda: tmp_path / ".checkpoints"
+    )
+    checkpoint_store.reset_file_versions()
+    checkpoint_hooks.set_session("e2e-session")
+    checkpoint_hooks.reset_tracked()
+    checkpoint_hooks.install_hooks()
+    yield tmp_path
+    checkpoint_hooks.reset_tracked()
+
+
+def test_llm_write_triggers_checkpoint_backup(monkeypatch, sandboxed_checkpoints):
+    """When the LLM calls Write, the checkpoint hook must back the pre-edit file up.
+
+    Pre-populate a small file, then let the LLM overwrite it via the Write
+    tool. The hook should copy the old content into checkpoints/.../backups/
+    before the Write executes, so the backup holds the original bytes.
+    """
+    target = sandboxed_checkpoints / "hello.py"
+    target.write_text("print('before')\n", encoding="utf-8")
+
+    turns = [
+        {"tool_calls": [{
+            "id": "w1",
+            "name": "Write",
+            "input": {"file_path": str(target), "content": "print('after')\n"},
+        }]},
+        {"text": "done"},
+    ]
+    monkeypatch.setattr("agent.stream", _scripted_stream(turns))
+
+    state = AgentState()
+    config = {"model": "test", "permission_mode": "accept-all",
+              "_session_id": "e2e-session"}
+    list(run("overwrite the file", state, config, "system prompt"))
+
+    # After the turn: Write applied the new content
+    assert target.read_text(encoding="utf-8") == "print('after')\n"
+
+    # And the checkpoint hook filed a backup with the pre-edit content
+    backups_dir = sandboxed_checkpoints / ".checkpoints" / "e2e-session" / "backups"
+    backups = list(backups_dir.iterdir())
+    assert backups, "checkpoint hook did not create a backup file"
+    assert any(b.read_text(encoding="utf-8") == "print('before')\n" for b in backups)
+
+
+def test_oversized_write_logs_to_stderr_not_stdout(
+    monkeypatch, sandboxed_checkpoints, capfd
+):
+    """Over the _MAX_FILE_SIZE threshold the hook skips + logs — to stderr only.
+
+    This is the actual user-visible contract of PR #47: checkpoint skips must
+    not pollute stdout (which carries the conversation transcript), they must
+    land on stderr where operators look.
+    """
+    monkeypatch.setattr(checkpoint_store, "_MAX_FILE_SIZE", 20)
+    big = sandboxed_checkpoints / "big.py"
+    big.write_text("x" * 100, encoding="utf-8")
+
+    turns = [
+        {"tool_calls": [{
+            "id": "w1",
+            "name": "Write",
+            "input": {"file_path": str(big), "content": "y" * 100},
+        }]},
+        {"text": "ok"},
+    ]
+    monkeypatch.setattr("agent.stream", _scripted_stream(turns))
+
+    state = AgentState()
+    list(run("rewrite", state, {"model": "test", "permission_mode": "accept-all",
+                                "_session_id": "e2e-session",
+                                "disabled_tools": ["Agent"]},
+             "sys"))
+
+    out, errtxt = capfd.readouterr()
+    assert "[checkpoint] skipping large file" in errtxt
+    assert "[checkpoint] skipping large file" not in out
diff --git a/tests/test_checkpoint_store.py b/tests/test_checkpoint_store.py
new file mode 100644
index 0000000..13160a2
--- /dev/null
+++ b/tests/test_checkpoint_store.py
@@ -0,0 +1,64 @@
+"""Integration tests for checkpoint store: stderr capture + large file skip."""
+from __future__ import annotations
+
+import pytest
+
+import checkpoint.store as store
+
+
+@pytest.fixture(autouse=True)
+def isolate_store(tmp_path, monkeypatch):
+    """Redirect checkpoint root to tmp_path and reset global state."""
+    monkeypatch.setattr(store, "_checkpoints_root", lambda: tmp_path / "checkpoints")
+    store.reset_file_versions()
+
+
+def test_large_file_skipped_and_logged_to_stderr(tmp_path, monkeypatch, capsys):
+    """A file above _MAX_FILE_SIZE is skipped; the notice goes to stderr only."""
+    monkeypatch.setattr(store, "_MAX_FILE_SIZE", 50)
+    big_file = tmp_path / "big.txt"
+    big_file.write_bytes(b"x" * 100)
+
+    result = store.track_file_edit("test-session", str(big_file))
+
+    assert result is None
+    captured = capsys.readouterr()
+    assert "[checkpoint] skipping large file" in captured.err
+    assert "100 bytes" in captured.err
+    assert captured.out == ""
+
+
+def test_normal_file_backed_up(tmp_path, capsys):
+    """A file under the size limit is copied to backups/ with empty stderr."""
+    small_file = tmp_path / "small.txt"
+    content = b"hello world"
+    small_file.write_bytes(content)
+
+    result = store.track_file_edit("test-session", str(small_file))
+
+    assert result is not None
+    backup_dir = tmp_path / "checkpoints" / "test-session" / "backups"
+    backup_path = backup_dir / result
+    assert backup_path.exists()
+    assert backup_path.read_bytes() == content
+    captured = capsys.readouterr()
+    assert captured.err == ""
+
+
+def test_backup_failure_logged_to_stderr(tmp_path, monkeypatch, capsys):
+    """A failing shutil.copy2 yields None and is reported on stderr only."""
+    normal_file = tmp_path / "normal.txt"
+    normal_file.write_bytes(b"some data")
+
+    def failing_copy(*args, **kwargs):
+        raise PermissionError("access denied")
+
+    monkeypatch.setattr(store.shutil, "copy2", failing_copy)
+
+    result = store.track_file_edit("test-session", str(normal_file))
+
+    assert result is None
+    captured = capsys.readouterr()
+    assert "[checkpoint] backup failed" in captured.err
+    assert "access denied" in captured.err
+    assert captured.out == ""