diff --git a/cmd/afs/afs_live_workspace.go b/cmd/afs/afs_live_workspace.go index 6be4e72..88ddde6 100644 --- a/cmd/afs/afs_live_workspace.go +++ b/cmd/afs/afs_live_workspace.go @@ -120,9 +120,13 @@ func saveLiveWorkspaceCheckpoint(ctx context.Context, store *afsStore, workspace if err != nil { return false, err } + var metadata controlplane.SaveCheckpointFromLiveOptions + if len(options) > 0 { + metadata = options[0] + } if dirty, known, err := store.workspaceRootDirtyState(ctx, workspace); err != nil { return false, err - } else if known && !dirty { + } else if known && !dirty && !metadata.AllowUnchanged { if printResult { fmt.Println("No changes to save") } diff --git a/cmd/afs/afs_mcp.go b/cmd/afs/afs_mcp.go index b502d2f..457886a 100644 --- a/cmd/afs/afs_mcp.go +++ b/cmd/afs/afs_mcp.go @@ -826,8 +826,9 @@ func (s *afsMCPServer) toolCheckpointCreate(ctx context.Context, args map[string return nil, err } saved, err := saveAFSWorkspaceOrLiveRoot(ctx, s.cfg, s.store, workspace, checkpointID, false, controlplane.SaveCheckpointFromLiveOptions{ - Kind: controlplane.CheckpointKindManual, - Source: controlplane.CheckpointSourceMCP, + Kind: controlplane.CheckpointKindManual, + Source: controlplane.CheckpointSourceMCP, + AllowUnchanged: true, }) if err != nil { return nil, err diff --git a/cmd/afs/afs_mcp_test.go b/cmd/afs/afs_mcp_test.go index 0e76969..609f007 100644 --- a/cmd/afs/afs_mcp_test.go +++ b/cmd/afs/afs_mcp_test.go @@ -219,6 +219,43 @@ func TestAFSMCPCheckpointCreatePersistsPendingWrite(t *testing.T) { } } +func TestAFSMCPCheckpointCreateAllowsUnchangedWorkspace(t *testing.T) { + t.Helper() + + server, closeFn := setupAFSMCPTestServer(t) + defer closeFn() + server.profile = controlplane.MCPProfileWorkspaceRWCheckpoint + + checkpointResult := server.callTool(context.Background(), "checkpoint_create", map[string]any{ + "checkpoint": "unchanged-head", + }) + if checkpointResult.IsError { + t.Fatalf("checkpoint_create on unchanged workspace returned error result: %+v", checkpointResult) + } + + var checkpointPayload map[string]any + if err := decodeStructuredContent(checkpointResult.StructuredContent, &checkpointPayload); err != nil { + t.Fatalf("decodeStructuredContent(checkpoint unchanged) returned error: %v", err) + } + if created, _ := checkpointPayload["created"].(bool); !created { + t.Fatalf("checkpoint_create created = %#v, want true", checkpointPayload["created"]) + } + if checkpoint, _ := checkpointPayload["checkpoint"].(string); checkpoint != "unchanged-head" { + t.Fatalf("checkpoint_create checkpoint = %#v, want %q", checkpointPayload["checkpoint"], "unchanged-head") + } + + if _, err := server.store.getSavepointMeta(context.Background(), "repo", "unchanged-head"); err != nil { + t.Fatalf("getSavepointMeta(unchanged-head) returned error: %v", err) + } + + restoreResult := server.callTool(context.Background(), "checkpoint_restore", map[string]any{ + "checkpoint": "unchanged-head", + }) + if restoreResult.IsError { + t.Fatalf("checkpoint_restore after unchanged create returned error result: %+v", restoreResult) + } +} + func TestAFSMCPFileWriteDoesNotRematerializeLocalWorkspaceCache(t *testing.T) { t.Helper() diff --git a/cmd/afs/config_state.go b/cmd/afs/config_state.go index 5a5bd86..af74c9e 100644 --- a/cmd/afs/config_state.go +++ b/cmd/afs/config_state.go @@ -2,6 +2,7 @@ package main import ( "crypto/rand" + "crypto/sha256" "encoding/hex" "encoding/json" "errors" @@ -19,6 +20,10 @@ func configPath() string { if cfgPathOverride != "" { return cfgPathOverride } + return defaultConfigPath() +} + +func defaultConfigPath() string { exe, err := executablePath() if err != nil { return "afs.config.json" @@ -477,24 +482,50 @@ func defaultWorkRoot() string { return filepath.Join(stateDir(), "workspaces") } -func statePath() string { +func defaultStatePath() string { return filepath.Join(stateDir(), "state.json") } +func statePathForConfig(configFile string) string { + cleanConfig := cleanConfigPath(configFile) + if cleanConfig == "" || cleanConfig == cleanConfigPath(defaultConfigPath()) { + return defaultStatePath() + } + sum := sha256.Sum256([]byte(cleanConfig)) + return filepath.Join(stateDir(), "configs", hex.EncodeToString(sum[:8])+".json") +} + +func statePath() string { + return statePathForConfig(configPath()) +} + func saveState(st state) error { - if err := os.MkdirAll(stateDir(), 0o700); err != nil { + target := statePath() + if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil { return err } b, err := json.MarshalIndent(st, "", " ") if err != nil { return err } - return os.WriteFile(statePath(), b, 0o600) + return os.WriteFile(target, b, 0o600) } func loadState() (state, error) { + if st, err := loadStateFromPath(statePath()); err == nil { + return st, nil + } else if !errors.Is(err, os.ErrNotExist) { + return state{}, err + } + if sameConfigPath(statePath(), defaultStatePath()) { + return state{}, os.ErrNotExist + } + return loadStateFromPath(defaultStatePath()) +} + +func loadStateFromPath(path string) (state, error) { var st state - b, err := os.ReadFile(statePath()) + b, err := os.ReadFile(path) if err != nil { return st, err } diff --git a/cmd/afs/sync_integration_test.go b/cmd/afs/sync_integration_test.go index 8919c59..a218737 100644 --- a/cmd/afs/sync_integration_test.go +++ b/cmd/afs/sync_integration_test.go @@ -1194,6 +1194,77 @@ func TestCmdFileCreateExclusiveRoundTrip(t *testing.T) { } } +func TestCmdFileCreateExclusiveUsesConfigScopedState(t *testing.T) { + t.Helper() + + env := newSyncTestEnv(t) + env.startDaemon(t) + defer env.stopDaemon() + + oldCfgPathOverride := cfgPathOverride + cfgPathOverride = filepath.Join(t.TempDir(), "afs.config.json") + t.Cleanup(func() { + cfgPathOverride = oldCfgPathOverride + }) + + cfg := defaultConfig() + cfg.ProductMode = productModeLocal + cfg.Mode = modeSync + cfg.RedisAddr = env.mr.Addr() + cfg.RedisDB = 0 + cfg.LocalPath = env.localRoot + cfg.CurrentWorkspace = env.workspace + if err := saveConfig(cfg); err != nil { + t.Fatalf("saveConfig() returned error: %v", err) + } + + st := state{ + ProductMode: productModeLocal, + RedisAddr: env.mr.Addr(), + RedisDB: 0, + CurrentWorkspace: env.workspace, + LocalPath: env.localRoot, + Mode: modeSync, + SyncPID: os.Getpid(), + } + if err := saveState(st); err != nil { + t.Fatalf("saveState() returned error: %v", err) + } + if sameConfigPath(statePath(), defaultStatePath()) { + t.Fatalf("statePath() = %q, want config-scoped path distinct from legacy %q", statePath(), defaultStatePath()) + } + + legacyState := state{ + ProductMode: productModeLocal, + RedisAddr: "127.0.0.1:1", + RedisDB: 99, + CurrentWorkspace: "legacy-workspace", + LocalPath: t.TempDir(), + Mode: modeSync, + SyncPID: os.Getpid(), + } + rawLegacyState, err := json.MarshalIndent(legacyState, "", " ") + if err != nil { + t.Fatalf("json.MarshalIndent(legacyState) returned error: %v", err) + } + if err := os.MkdirAll(filepath.Dir(defaultStatePath()), 0o700); err != nil { + t.Fatalf("MkdirAll(defaultStatePath dir) returned error: %v", err) + } + if err := os.WriteFile(defaultStatePath(), rawLegacyState, 0o600); err != nil { + t.Fatalf("WriteFile(defaultStatePath) returned error: %v", err) + } + + if err := cmdFS([]string{"fs", "create-exclusive", "--content", "agent-c\n", "/tasks/003.claim"}); err != nil { + t.Fatalf("cmdFS(create-exclusive with config-scoped state) returned error: %v", err) + } + assertEventually(t, 3*time.Second, "remote 003.claim", func() bool { + return env.remoteExists(t, "tasks/003.claim") + }) + if got := env.readRemoteFile(t, "tasks/003.claim"); got != "agent-c\n" { + t.Fatalf("remote content = %q, want %q", got, "agent-c\n") + } +} + // Scenario 1 (burst variant): a batch of files written before startup all // land remotely, and the steady-state has no spurious echo loops. func TestSyncStartupUploadBurst(t *testing.T) { diff --git a/internal/controlplane/mcp_hosted.go b/internal/controlplane/mcp_hosted.go index 7a4c18c..377f2ce 100644 --- a/internal/controlplane/mcp_hosted.go +++ b/internal/controlplane/mcp_hosted.go @@ -549,8 +549,9 @@ func (p *hostedMCPProvider) callWorkspaceTool(ctx context.Context, name string, if err = validateHostedMCPName("checkpoint", checkpointID); err == nil { var saved bool saved, err = p.manager.SaveCheckpointFromLiveWithOptions(ctx, p.databaseID, p.workspace, checkpointID, SaveCheckpointFromLiveOptions{ - Kind: CheckpointKindManual, - Source: CheckpointSourceMCP, + Kind: CheckpointKindManual, + Source: CheckpointSourceMCP, + AllowUnchanged: true, }) value = map[string]any{ "workspace": p.workspace, diff --git a/internal/controlplane/mcp_hosted_test.go b/internal/controlplane/mcp_hosted_test.go index d13bc4c..81868dc 100644 --- a/internal/controlplane/mcp_hosted_test.go +++ b/internal/controlplane/mcp_hosted_test.go @@ -79,6 +79,43 @@ func TestHostedMCPFileCreateExclusiveFailsWhenFileExists(t *testing.T) { } } +func TestHostedMCPCheckpointCreateAllowsUnchangedWorkspace(t *testing.T) { + t.Helper() + + manager, databaseID := newTestManager(t) + provider := &hostedMCPProvider{ + manager: manager, + databaseID: databaseID, + workspace: "repo", + profile: MCPProfileWorkspaceRWCheckpoint, + } + + checkpointResult := provider.CallTool(context.Background(), "checkpoint_create", map[string]any{ + "checkpoint": "unchanged-head", + }) + if checkpointResult.IsError { + t.Fatalf("checkpoint_create on unchanged workspace returned error result: %+v", checkpointResult) + } + + var checkpointPayload map[string]any + if err := decodeHostedStructuredContent(checkpointResult.StructuredContent, &checkpointPayload); err != nil { + t.Fatalf("decodeHostedStructuredContent(checkpoint unchanged) returned error: %v", err) + } + if created, _ := checkpointPayload["created"].(bool); !created { + t.Fatalf("checkpoint_create created = %#v, want true", checkpointPayload["created"]) + } + if checkpoint, _ := checkpointPayload["checkpoint"].(string); checkpoint != "unchanged-head" { + t.Fatalf("checkpoint_create checkpoint = %#v, want %q", checkpointPayload["checkpoint"], "unchanged-head") + } + + restoreResult := provider.CallTool(context.Background(), "checkpoint_restore", map[string]any{ + "checkpoint": "unchanged-head", + }) + if restoreResult.IsError { + t.Fatalf("checkpoint_restore after unchanged create returned error result: %+v", restoreResult) + } +} + func TestHostedMCPWorkspaceVersioningPolicyToolsRoundTrip(t *testing.T) { t.Helper() diff --git a/sdk/python/tests/test_client.py b/sdk/python/tests/test_client.py index e7c1fef..3772767 100644 --- a/sdk/python/tests/test_client.py +++ b/sdk/python/tests/test_client.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import patch -from redis_afs.client import AFSError, FSClient, MCPHttpClient, MountedFS, _MountedWorkspace, _normalize_mcp_endpoint +from redis_afs.client import AFSError, CheckpointClient, FSClient, MCPHttpClient, MountedFS, _MountedWorkspace, _normalize_mcp_endpoint class FakeMCP: @@ -32,6 +32,8 @@ def call_tool(self, name, arguments=None): return {"entries": entries} if name == "checkpoint_create": return {"workspace": "workspace", "checkpoint": arguments.get("checkpoint") or "auto", "created": True} + if name == "checkpoint_restore": + return {"workspace": "workspace", "checkpoint": arguments["checkpoint"], "restored": True} raise AssertionError(f"unexpected tool {name}") @@ -90,6 +92,8 @@ def call_tool(self, name, arguments=None): return {"entries": entries} if name == "checkpoint_create": return {"workspace": "workspace", "checkpoint": arguments.get("checkpoint") or "auto", "created": True} + if name == "checkpoint_restore": + return {"workspace": "workspace", "checkpoint": arguments["checkpoint"], "restored": True} raise AssertionError(f"unexpected tool {name}") @@ -144,6 +148,17 @@ def test_fs_mount_issues_workspace_token_and_reads_and_writes_files(self): class EndpointTest(unittest.TestCase): + def test_checkpoint_create_and_restore_round_trip(self): + checkpoint = CheckpointClient(FakeMCP()) + + created = checkpoint.create(workspace="repo", checkpoint="unchanged-head") + restored = checkpoint.restore(workspace="repo", checkpoint="unchanged-head") + + self.assertTrue(created["created"]) + self.assertEqual(created["checkpoint"], "unchanged-head") + self.assertTrue(restored["restored"]) + self.assertEqual(restored["checkpoint"], "unchanged-head") + def test_normalizes_mcp_endpoint(self): self.assertEqual(_normalize_mcp_endpoint("https://afs.cloud"), "https://afs.cloud/mcp") self.assertEqual(_normalize_mcp_endpoint("https://afs.cloud/mcp"), "https://afs.cloud/mcp") diff --git a/sdk/typescript/test/sdk.test.mjs b/sdk/typescript/test/sdk.test.mjs index a827877..f7e4b72 100644 --- a/sdk/typescript/test/sdk.test.mjs +++ b/sdk/typescript/test/sdk.test.mjs @@ -56,3 +56,51 @@ test("single-workspace mounts allow workspace-relative paths", async () => { assert.equal(await fs.readFile("/foobar/src/README.md"), "hello"); assert.deepEqual(fs.workspaceNames, ["foobar"]); }); + +test("checkpoint.create and checkpoint.restore round-trip through MCP", async () => { + const calls = []; + const afs = new AFS({ + apiKey: "test", + baseUrl: "https://afs.cloud", + fetch: async (_url, init) => { + const body = JSON.parse(String(init.body)); + calls.push(body); + let structuredContent; + if (body.params.name === "checkpoint_create") { + structuredContent = { + workspace: body.params.arguments.workspace, + checkpoint: body.params.arguments.checkpoint, + created: true, + }; + } else if (body.params.name === "checkpoint_restore") { + structuredContent = { + workspace: body.params.arguments.workspace, + checkpoint: body.params.arguments.checkpoint, + restored: true, + }; + } else { + throw new Error(`unexpected tool ${body.params.name}`); + } + return new Response( + JSON.stringify({ + jsonrpc: "2.0", + id: body.id, + result: { structuredContent }, + }), + { status: 200, headers: { "content-type": "application/json" } }, + ); + }, + }); + + const created = await afs.checkpoint.create({ workspace: "repo", checkpoint: "unchanged-head" }); + const restored = await afs.checkpoint.restore({ workspace: "repo", checkpoint: "unchanged-head" }); + + assert.equal(created.created, true); + assert.equal(created.checkpoint, "unchanged-head"); + assert.equal(restored.restored, true); + assert.equal(restored.checkpoint, "unchanged-head"); + assert.deepEqual( + calls.map((call) => call.params.name), + ["checkpoint_create", "checkpoint_restore"], + ); +});