diff --git a/apps/code/src/renderer/features/settings/stores/settingsStore.test.ts b/apps/code/src/renderer/features/settings/stores/settingsStore.test.ts new file mode 100644 index 000000000..6e546a8d9 --- /dev/null +++ b/apps/code/src/renderer/features/settings/stores/settingsStore.test.ts @@ -0,0 +1,68 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; + +const { getItem, setItem, removeItem } = vi.hoisted(() => ({ + getItem: vi.fn(), + setItem: vi.fn(), + removeItem: vi.fn(), +})); + +vi.mock("@renderer/trpc/client", () => ({ + trpcClient: { + secureStore: { + getItem: { query: getItem }, + setItem: { query: setItem }, + removeItem: { query: removeItem }, + }, + }, +})); + +import { useSettingsStore } from "./settingsStore"; + +describe("feature settingsStore cloud selections", () => { + beforeEach(() => { + getItem.mockReset(); + setItem.mockReset(); + removeItem.mockReset(); + getItem.mockResolvedValue(null); + setItem.mockResolvedValue(undefined); + removeItem.mockResolvedValue(undefined); + + useSettingsStore.setState({ + lastUsedCloudRepository: null, + }); + }); + + it("persists the last used cloud repository", async () => { + useSettingsStore.getState().setLastUsedCloudRepository("posthog/posthog"); + + await vi.waitFor(() => { + expect(setItem).toHaveBeenCalled(); + }); + + const lastCall = setItem.mock.calls[setItem.mock.calls.length - 1]; + const persisted = JSON.parse(lastCall[0].value); + + expect(persisted.state.lastUsedCloudRepository).toBe("posthog/posthog"); + }); + + it("rehydrates the last used cloud repository", async () => { + getItem.mockResolvedValue( + JSON.stringify({ + state: { + lastUsedCloudRepository: "posthog/posthog", + }, + version: 0, + }), + ); + + useSettingsStore.setState({ + lastUsedCloudRepository: null, + }); + + await useSettingsStore.persist.rehydrate(); + + expect(useSettingsStore.getState().lastUsedCloudRepository).toBe( + "posthog/posthog", + ); + }); +}); diff --git a/apps/code/src/renderer/features/settings/stores/settingsStore.ts b/apps/code/src/renderer/features/settings/stores/settingsStore.ts index 12162f0cf..ab31fed54 100644 --- a/apps/code/src/renderer/features/settings/stores/settingsStore.ts +++ b/apps/code/src/renderer/features/settings/stores/settingsStore.ts @@ -25,6 +25,7 @@ interface SettingsStore { lastUsedWorkspaceMode: WorkspaceMode; lastUsedAdapter: AgentAdapter; lastUsedModel: string | null; + lastUsedCloudRepository: string | null; lastUsedEnvironments: Record; desktopNotifications: boolean; dockBadgeNotifications: boolean; @@ -57,6 +58,7 @@ interface SettingsStore { setLastUsedWorkspaceMode: (mode: WorkspaceMode) => void; setLastUsedAdapter: (adapter: AgentAdapter) => void; setLastUsedModel: (model: string) => void; + setLastUsedCloudRepository: (repo: string | null) => void; setLastUsedEnvironment: ( repoPath: string, environmentId: string | null, @@ -88,6 +90,7 @@ export const useSettingsStore = create()( lastUsedWorkspaceMode: "local", lastUsedAdapter: "claude", lastUsedModel: null, + lastUsedCloudRepository: null, lastUsedEnvironments: {}, desktopNotifications: true, dockBadgeNotifications: true, @@ -143,6 +146,8 @@ export const useSettingsStore = create()( setLastUsedWorkspaceMode: (mode) => set({ lastUsedWorkspaceMode: mode }), setLastUsedAdapter: (adapter) => set({ lastUsedAdapter: adapter }), setLastUsedModel: (model) => set({ lastUsedModel: model }), + setLastUsedCloudRepository: (repo) => + set({ lastUsedCloudRepository: repo }), setLastUsedEnvironment: (repoPath, environmentId) => set((state) => { const next = { ...state.lastUsedEnvironments }; @@ -190,6 +195,7 @@ export const useSettingsStore = create()( lastUsedWorkspaceMode: state.lastUsedWorkspaceMode, lastUsedAdapter: state.lastUsedAdapter, lastUsedModel: state.lastUsedModel, + lastUsedCloudRepository: state.lastUsedCloudRepository, lastUsedEnvironments: state.lastUsedEnvironments, desktopNotifications: state.desktopNotifications, dockBadgeNotifications: state.dockBadgeNotifications, diff --git a/apps/code/src/renderer/features/task-detail/components/TaskInput.tsx b/apps/code/src/renderer/features/task-detail/components/TaskInput.tsx index 0c055e51f..4cf514e0c 100644 --- a/apps/code/src/renderer/features/task-detail/components/TaskInput.tsx +++ b/apps/code/src/renderer/features/task-detail/components/TaskInput.tsx @@ -29,7 +29,7 @@ import { useAuthStore } from "@renderer/features/auth/stores/authStore"; import { useTRPC } from "@renderer/trpc/client"; import { useNavigationStore } from "@stores/navigationStore"; import { useQuery } from "@tanstack/react-query"; -import { useCallback, useEffect, useRef, useState } from "react"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { useHotkeys } from "react-hotkeys-hook"; import { usePreviewConfig } from "../hooks/usePreviewConfig"; import { useTaskCreation } from "../hooks/useTaskCreation"; @@ -59,6 +59,8 @@ export function TaskInput({ setLastUsedWorkspaceMode, lastUsedAdapter, setLastUsedAdapter, + lastUsedCloudRepository, + setLastUsedCloudRepository, allowBypassPermissions, setLastUsedEnvironment, getLastUsedEnvironment, @@ -103,13 +105,18 @@ export function TaskInput({ const { githubIntegration, repositories, isLoadingRepos } = useRepositoryIntegration(); const [selectedRepository, setSelectedRepository] = useState( - null, + () => lastUsedCloudRepository?.toLowerCase() ?? null, ); + const selectedCloudRepository = useMemo(() => { + if (!selectedRepository) return null; + const lower = selectedRepository.toLowerCase(); + return repositories.includes(lower) ? lower : null; + }, [selectedRepository, repositories]); const { currentBranch, branchLoading, defaultBranch } = useGitQueries(selectedDirectory); const { data: cloudBranchData, isPending: cloudBranchesLoading } = - useGithubBranches(githubIntegration?.id, selectedRepository); + useGithubBranches(githubIntegration?.id, selectedCloudRepository); const cloudBranches = cloudBranchData?.branches; const cloudDefaultBranch = cloudBranchData?.defaultBranch ?? null; @@ -149,6 +156,15 @@ export function TaskInput({ } }, [selectedDirectory, newBranchName, gitActions]); + const handleRepositorySelect = useCallback( + (repo: string) => { + const normalizedRepo = repo.toLowerCase(); + setSelectedRepository(normalizedRepo); + setLastUsedCloudRepository(normalizedRepo); + }, + [setLastUsedCloudRepository], + ); + const { modeOption, modelOption, @@ -159,6 +175,37 @@ export function TaskInput({ const { folders } = useFolders(); + useEffect(() => { + if (selectedRepository || !lastUsedCloudRepository) { + return; + } + + setSelectedRepository(lastUsedCloudRepository.toLowerCase()); + }, [lastUsedCloudRepository, selectedRepository]); + + useEffect(() => { + if ( + isLoadingRepos || + !githubIntegration || + !selectedRepository || + selectedCloudRepository + ) { + return; + } + + setSelectedRepository(null); + if (lastUsedCloudRepository === selectedRepository) { + setLastUsedCloudRepository(null); + } + }, [ + githubIntegration, + isLoadingRepos, + lastUsedCloudRepository, + selectedCloudRepository, + selectedRepository, + setLastUsedCloudRepository, + ]); + useEffect(() => { if (view.folderId) { const folder = folders.find((f) => f.id === view.folderId); @@ -169,7 +216,7 @@ export function TaskInput({ }, [view.folderId, folders]); const effectiveRepoPath = - workspaceMode === "cloud" ? selectedRepository : selectedDirectory; + workspaceMode === "cloud" ? selectedCloudRepository : selectedDirectory; const setSelectedEnvironment = useCallback( (envId: string | null) => { @@ -183,6 +230,7 @@ export function TaskInput({ useEffect(() => { setSelectedBranch(null); + if (effectiveRepoPath) { setSelectedEnvironmentRaw(getLastUsedEnvironment(effectiveRepoPath)); } else { @@ -212,7 +260,7 @@ export function TaskInput({ const { isCreatingTask, canSubmit, handleSubmit } = useTaskCreation({ editorRef, selectedDirectory, - selectedRepository, + selectedRepository: selectedCloudRepository, githubIntegrationId: githubIntegration?.id, workspaceMode: effectiveWorkspaceMode, branch: branchForTaskCreation, @@ -373,7 +421,7 @@ export function TaskInput({ {workspaceMode === "cloud" ? ( vi.fn()); +const mockWorkspaceDelete = vi.hoisted(() => vi.fn()); +const mockGetTaskDirectory = vi.hoisted(() => vi.fn()); + +vi.mock("@renderer/trpc", () => ({ + trpcClient: { + workspace: { + create: { mutate: mockWorkspaceCreate }, + delete: { mutate: mockWorkspaceDelete }, + }, + }, +})); + +vi.mock("@hooks/useRepositoryDirectory", () => ({ + getTaskDirectory: mockGetTaskDirectory, +})); + +vi.mock("@features/provisioning/stores/provisioningStore", () => ({ + useProvisioningStore: { + getState: () => ({ + setActive: vi.fn(), + clear: vi.fn(), + }), + }, +})); + +vi.mock("@features/panels/store/panelLayoutStore", () => ({ + usePanelLayoutStore: { + getState: () => ({ + addActionTab: vi.fn(), + }), + }, +})); + +vi.mock("@features/sessions/service/service", () => ({ + getSessionService: () => ({ + updateSessionTaskTitle: vi.fn(), + }), +})); + +vi.mock("@renderer/utils/generateTitle", () => ({ + generateTitle: vi.fn(async () => null), +})); + +vi.mock("@utils/queryClient", () => ({ + queryClient: { + setQueriesData: vi.fn(), + }, +})); + +vi.mock("@utils/logger", () => ({ + logger: { + scope: () => ({ + info: vi.fn(), + debug: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + }), + }, +})); + +import { TaskCreationSaga } from "./task-creation"; + +const createTask = (overrides: Partial = {}): Task => ({ + id: "task-123", + task_number: 1, + slug: "task-123", + title: "Test task", + description: "Ship the fix", + origin_product: "user_created", + repository: "posthog/posthog", + created_at: "2026-04-03T00:00:00Z", + updated_at: "2026-04-03T00:00:00Z", + ...overrides, +}); + +const createRun = (overrides: Partial = {}): TaskRun => ({ + id: "run-123", + task: "task-123", + team: 1, + branch: "release/remembered-branch", + environment: "cloud", + status: "started", + log_url: "https://example.com/logs/run-123", + error_message: null, + output: null, + state: {}, + created_at: "2026-04-03T00:00:00Z", + updated_at: "2026-04-03T00:00:00Z", + completed_at: null, + ...overrides, +}); + +describe("TaskCreationSaga", () => { + beforeEach(() => { + vi.clearAllMocks(); + mockWorkspaceCreate.mockResolvedValue(undefined); + mockWorkspaceDelete.mockResolvedValue(undefined); + mockGetTaskDirectory.mockResolvedValue(null); + }); + + it("waits for the cloud run response before surfacing the task", async () => { + const createdTask = createTask(); + const startedTask = createTask({ latest_run: createRun() }); + const createTaskMock = vi.fn().mockResolvedValue(createdTask); + const runTaskInCloudMock = vi.fn().mockResolvedValue(startedTask); + const onTaskReady = vi.fn(); + + const saga = new TaskCreationSaga({ + posthogClient: { + createTask: createTaskMock, + deleteTask: vi.fn(), + getTask: vi.fn(), + runTaskInCloud: runTaskInCloudMock, + updateTask: vi.fn(), + } as never, + onTaskReady, + }); + + const result = await saga.run({ + content: "Ship the fix", + repository: "posthog/posthog", + workspaceMode: "cloud", + branch: "release/remembered-branch", + }); + + expect(result.success).toBe(true); + if (!result.success) { + throw new Error("Expected task creation to succeed"); + } + + expect(runTaskInCloudMock).toHaveBeenCalledWith( + "task-123", + "release/remembered-branch", + undefined, + undefined, + ); + expect(onTaskReady).toHaveBeenCalledTimes(1); + expect(onTaskReady.mock.calls[0][0].task.latest_run?.branch).toBe( + "release/remembered-branch", + ); + expect(result.data.task.latest_run?.branch).toBe( + "release/remembered-branch", + ); + expect(runTaskInCloudMock.mock.invocationCallOrder[0]).toBeLessThan( + onTaskReady.mock.invocationCallOrder[0], + ); + }); +}); diff --git a/apps/code/src/renderer/sagas/task/task-creation.ts b/apps/code/src/renderer/sagas/task/task-creation.ts index bc6371f78..79293799f 100644 --- a/apps/code/src/renderer/sagas/task/task-creation.ts +++ b/apps/code/src/renderer/sagas/task/task-creation.ts @@ -107,7 +107,7 @@ export class TaskCreationSaga extends Saga< ? this.resolveFolder(input.repoPath) : undefined; - const task = taskId + let task = taskId ? await this.readOnlyStep("fetch_task", () => this.deps.posthogClient.getTask(taskId), ) @@ -230,7 +230,9 @@ export class TaskCreationSaga extends Saga< }; } - if (!hasProvisioning && this.deps.onTaskReady) { + const shouldStartCloudRun = workspaceMode === "cloud" && !task.latest_run; + + if (!hasProvisioning && !shouldStartCloudRun && this.deps.onTaskReady) { this.deps.onTaskReady({ task, workspace }); } @@ -253,8 +255,8 @@ export class TaskCreationSaga extends Saga< } // Step 5: Start cloud run (only for new cloud tasks) - if (workspaceMode === "cloud" && !task.latest_run) { - await this.step({ + if (shouldStartCloudRun) { + task = await this.step({ name: "cloud_run", execute: () => this.deps.posthogClient.runTaskInCloud( @@ -267,6 +269,10 @@ export class TaskCreationSaga extends Saga< log.info("Rolling back: cloud run (no-op)", { taskId: task.id }); }, }); + + if (!hasProvisioning && this.deps.onTaskReady) { + this.deps.onTaskReady({ task, workspace }); + } } // Step 7: Connect to session