diff --git a/agentex-ui/components/primary-content/prompt-input.tsx b/agentex-ui/components/primary-content/prompt-input.tsx index 1ae7475f..b85deb90 100644 --- a/agentex-ui/components/primary-content/prompt-input.tsx +++ b/agentex-ui/components/primary-content/prompt-input.tsx @@ -20,6 +20,7 @@ import { useSafeSearchParams, } from '@/hooks/use-safe-search-params'; import { useSendMessage } from '@/hooks/use-task-messages'; +import { useTask } from '@/hooks/use-tasks'; type PromptInputProps = { prompt: string; @@ -52,10 +53,16 @@ export function PromptInput({ prompt, setPrompt }: PromptInputProps) { const createTaskMutation = useCreateTask({ agentexClient }); const sendMessageMutation = useSendMessage({ agentexClient }); + const { data: task } = useTask({ agentexClient, taskId: taskID ?? '' }); const textInputRef = useRef(null); const codeMirrorViewRef = useRef(null); + const isTaskTerminal = useMemo(() => { + if (!taskID || !task) return false; + return task.status != null && task.status !== 'RUNNING'; + }, [taskID, task]); + const handleSetJson = useCallback( (value: boolean) => { if (value && !prompt.trim()) { @@ -86,8 +93,8 @@ export function PromptInput({ prompt, setPrompt }: PromptInputProps) { }, [taskID, isClient, isSendingJSON]); const isDisabled = useMemo( - () => !agentName || !isClient, - [agentName, isClient] + () => !agentName || !isClient || isTaskTerminal, + [agentName, isClient, isTaskTerminal] ); const handleSendPrompt = useCallback(async () => { @@ -171,6 +178,8 @@ export function PromptInput({ prompt, setPrompt }: PromptInputProps) { prompt={prompt} setPrompt={setPrompt} isDisabled={isDisabled} + isTaskTerminal={isTaskTerminal} + taskStatus={task?.status} handleSendPrompt={handleSendPrompt} inputRef={textInputRef} /> @@ -205,12 +214,16 @@ const TextInput = ({ prompt, setPrompt, isDisabled, + isTaskTerminal, + taskStatus, handleSendPrompt, inputRef, }: { prompt: string; setPrompt: (prompt: string) => void; isDisabled: boolean; + isTaskTerminal: boolean; + taskStatus: string | null | undefined; handleSendPrompt: () => void; inputRef: React.RefObject; }) => { @@ -230,7 +243,11 @@ const TextInput = ({ }} disabled={isDisabled} placeholder={ - isDisabled ? 'Select an agent to start' : 'Enter your prompt' + isTaskTerminal + ? `Task ${taskStatus?.toLowerCase() ?? 'ended'}` + : isDisabled + ? 'Select an agent to start' + : 'Enter your prompt' } className="mr-2 flex-1 outline-none focus:ring-0 focus:outline-none" style={{ diff --git a/agentex-ui/components/task-messages/task-messages.tsx b/agentex-ui/components/task-messages/task-messages.tsx index 67330839..5e658a7c 100644 --- a/agentex-ui/components/task-messages/task-messages.tsx +++ b/agentex-ui/components/task-messages/task-messages.tsx @@ -23,7 +23,7 @@ type TaskMessagesProps = { }; type MessagePair = { id: string; - userMessage: TaskMessage; + userMessage: TaskMessage | null; agentMessages: TaskMessage[]; }; @@ -58,36 +58,41 @@ function TaskMessagesImpl({ taskId, headerRef }: TaskMessagesProps) { const pairs: MessagePair[] = []; let currentUserMessage: TaskMessage | null = null; let currentAgentMessages: TaskMessage[] = []; + let pairStarted = false; for (const message of messages) { const isUserMessage = message.content.author === 'user'; if (isUserMessage) { - if (currentUserMessage) { + if (pairStarted) { pairs.push({ - id: currentUserMessage.id || `pair-${pairs.length}`, + id: + currentUserMessage?.id || + currentAgentMessages[0]?.id || + `pair-${pairs.length}`, userMessage: currentUserMessage, agentMessages: currentAgentMessages, }); } currentUserMessage = message; currentAgentMessages = []; + pairStarted = true; } else { - if (currentUserMessage) { - currentAgentMessages.push(message); - } else { - pairs.push({ - id: message.id || `pair-${pairs.length}`, - userMessage: message, - agentMessages: [], - }); + if (!pairStarted) { + currentUserMessage = null; + currentAgentMessages = []; + pairStarted = true; } + currentAgentMessages.push(message); } } - if (currentUserMessage) { + if (pairStarted) { pairs.push({ - id: currentUserMessage.id || `pair-${pairs.length}`, + id: + currentUserMessage?.id || + currentAgentMessages[0]?.id || + `pair-${pairs.length}`, userMessage: currentUserMessage, agentMessages: currentAgentMessages, }); @@ -101,10 +106,13 @@ function TaskMessagesImpl({ taskId, headerRef }: TaskMessagesProps) { const lastPair = messagePairs[messagePairs.length - 1]!; const hasNoAgentMessages = lastPair.agentMessages.length === 0; + const hasUserMessage = lastPair.userMessage !== null; const rpcStatus = queryData?.rpcStatus; return ( - hasNoAgentMessages && (rpcStatus === 'pending' || rpcStatus === 'success') + hasUserMessage && + hasNoAgentMessages && + (rpcStatus === 'pending' || rpcStatus === 'success') ); }, [messagePairs, queryData?.rpcStatus]); @@ -191,7 +199,7 @@ function TaskMessagesImpl({ taskId, headerRef }: TaskMessagesProps) { containerHeight={containerHeight} > - {renderMessage(pair.userMessage)} + {pair.userMessage && renderMessage(pair.userMessage)} {pair.agentMessages.map(agentMessage => ( {renderMessage(agentMessage)} diff --git a/agentex/src/api/routes/tasks.py b/agentex/src/api/routes/tasks.py index 0cccdc5e..9eef026b 100644 --- a/agentex/src/api/routes/tasks.py +++ b/agentex/src/api/routes/tasks.py @@ -13,6 +13,7 @@ Task, TaskRelationships, TaskResponse, + TaskStatusReasonRequest, UpdateTaskRequest, ) from src.domain.services.authorization_service import DAuthorizationService @@ -169,6 +170,91 @@ async def update_task_by_name( return Task.model_validate(updated_task_entity) +@router.post( + "/{task_id}/complete", + response_model=Task, + summary="Complete Task", + description="Mark a running task as completed.", +) +async def complete_task( + task_id: DAuthorizedId(AgentexResourceType.task, AuthorizedOperationType.update), + task_use_case: DTaskUseCase, + request: TaskStatusReasonRequest | None = None, +) -> Task: + updated = await task_use_case.complete_task( + id=task_id, reason=request.reason if request else None + ) + return Task.model_validate(updated) + + +@router.post( + "/{task_id}/fail", + response_model=Task, + summary="Fail Task", + description="Mark a running task as failed.", +) +async def fail_task( + task_id: DAuthorizedId(AgentexResourceType.task, AuthorizedOperationType.update), + task_use_case: DTaskUseCase, + request: TaskStatusReasonRequest | None = None, +) -> Task: + updated = await task_use_case.fail_task( + id=task_id, reason=request.reason if request else None + ) + return Task.model_validate(updated) + + +@router.post( + "/{task_id}/cancel", + response_model=Task, + summary="Cancel Task", + description="Mark a running task as canceled.", +) +async def cancel_task( + task_id: DAuthorizedId(AgentexResourceType.task, AuthorizedOperationType.update), + task_use_case: DTaskUseCase, + request: TaskStatusReasonRequest | None = None, +) -> Task: + updated = await task_use_case.cancel_task( + id=task_id, reason=request.reason if request else None + ) + return Task.model_validate(updated) + + +@router.post( + "/{task_id}/terminate", + response_model=Task, + summary="Terminate Task", + description="Mark a running task as terminated.", +) +async def terminate_task( + task_id: DAuthorizedId(AgentexResourceType.task, AuthorizedOperationType.update), + task_use_case: DTaskUseCase, + request: TaskStatusReasonRequest | None = None, +) -> Task: + updated = await task_use_case.terminate_task( + id=task_id, reason=request.reason if request else None + ) + return Task.model_validate(updated) + + +@router.post( + "/{task_id}/timeout", + response_model=Task, + summary="Timeout Task", + description="Mark a running task as timed out.", +) +async def timeout_task( + task_id: DAuthorizedId(AgentexResourceType.task, AuthorizedOperationType.update), + task_use_case: DTaskUseCase, + request: TaskStatusReasonRequest | None = None, +) -> Task: + updated = await task_use_case.timeout_task( + id=task_id, reason=request.reason if request else None + ) + return Task.model_validate(updated) + + @router.get( "/{task_id}/stream", summary="Stream Task Events by ID", diff --git a/agentex/src/api/schemas/tasks.py b/agentex/src/api/schemas/tasks.py index df5ad42a..4055cadc 100644 --- a/agentex/src/api/schemas/tasks.py +++ b/agentex/src/api/schemas/tasks.py @@ -73,3 +73,10 @@ class UpdateTaskRequest(BaseModel): None, title="If provided, replaces task_metadata with this value", ) + + +class TaskStatusReasonRequest(BaseModel): + reason: str | None = Field( + None, + title="Optional reason for the status change", + ) diff --git a/agentex/src/domain/repositories/task_repository.py b/agentex/src/domain/repositories/task_repository.py index 53898eb6..fe3e2eae 100644 --- a/agentex/src/domain/repositories/task_repository.py +++ b/agentex/src/domain/repositories/task_repository.py @@ -2,7 +2,7 @@ from typing import Annotated, Literal from fastapi import Depends -from sqlalchemy import select +from sqlalchemy import select, update from sqlalchemy.orm import selectinload from src.adapters.crud_store.adapter_postgres import ( ColumnPrimitiveValue, @@ -139,5 +139,37 @@ async def update(self, task: TaskEntity) -> TaskEntity: # Return with agents populated return TaskEntity.model_validate(modified_orm) + async def transition_status( + self, + task_id: str, + expected_status: TaskStatus, + new_status: TaskStatus, + status_reason: str, + task_metadata: dict | None = None, + ) -> TaskEntity | None: + """Atomically transition task status. Returns None if the expected status didn't match (i.e. lost the race).""" + + async with ( + self.start_async_db_session(True) as session, + async_sql_exception_handler(), + ): + values: dict = {"status": new_status, "status_reason": status_reason} + if task_metadata is not None: + values["task_metadata"] = task_metadata + + stmt = ( + update(TaskORM) + .where(TaskORM.id == task_id, TaskORM.status == expected_status) + .values(**values) + ) + result = await session.execute(stmt) + await session.commit() + + if result.rowcount == 0: + return None + + refreshed = await session.get(TaskORM, task_id) + return TaskEntity.model_validate(refreshed) + DTaskRepository = Annotated[TaskRepository, Depends(TaskRepository)] diff --git a/agentex/src/domain/services/task_service.py b/agentex/src/domain/services/task_service.py index c31b3561..83923ff9 100644 --- a/agentex/src/domain/services/task_service.py +++ b/agentex/src/domain/services/task_service.py @@ -144,6 +144,44 @@ async def get_task( id=id, name=name, relationships=relationships ) + async def transition_task_status( + self, + task_id: str, + expected_status: TaskStatus, + new_status: TaskStatus, + status_reason: str, + task_metadata: dict | None = None, + ) -> TaskEntity | None: + """ + Atomically transition task status. Returns None if the expected status didn't match. + Publishes a task_updated event on success. + """ + updated_task = await self.task_repository.transition_status( + task_id=task_id, + expected_status=expected_status, + new_status=new_status, + status_reason=status_reason, + task_metadata=task_metadata, + ) + if updated_task is None: + return None + + try: + topic = get_task_event_stream_topic(task_id=task_id) + await self.stream_repository.send_data( + topic, + TaskStreamTaskUpdatedEventEntity( + type="task_updated", task=updated_task + ).model_dump(mode="json"), + ) + logger.info(f"task_updated event published to topic: {topic}") + except Exception as e: + logger.error( + f"Error sending task_updated event to stream: {e}", exc_info=True + ) + + return updated_task + async def update_task(self, task: TaskEntity) -> TaskEntity: """ Update a task in the repository. diff --git a/agentex/src/domain/use_cases/tasks_use_case.py b/agentex/src/domain/use_cases/tasks_use_case.py index 954f7cc3..f358cf1b 100644 --- a/agentex/src/domain/use_cases/tasks_use_case.py +++ b/agentex/src/domain/use_cases/tasks_use_case.py @@ -114,5 +114,79 @@ async def update_mutable_fields_on_task( updated_task_entity = await self.task_service.update_task(task=task_entity) return updated_task_entity + async def _transition_to_terminal( + self, + target_status: TaskStatus, + id: str | None = None, + name: str | None = None, + reason: str | None = None, + ) -> TaskEntity: + """Atomically transition a running task to a terminal status.""" + if not id and not name: + raise ClientError("Either id or name must be provided") + + task_entity = await self.task_service.get_task(id=id, name=name) + if task_entity.status == TaskStatus.DELETED: + raise ItemDoesNotExist(f"Task {id or name} not found") + if task_entity.status != TaskStatus.RUNNING: + raise ClientError( + f"Task {task_entity.id} is not running (current status: {task_entity.status}). " + f"Only running tasks can have their status updated." + ) + + status_reason = reason or f"Task {target_status.value.lower()}" + updated = await self.task_service.transition_task_status( + task_id=task_entity.id, + expected_status=TaskStatus.RUNNING, + new_status=target_status, + status_reason=status_reason, + ) + if updated is None: + raise ClientError( + f"Task {task_entity.id} status was concurrently modified. " + f"Please retry the request." + ) + return updated + + async def complete_task( + self, id: str | None = None, name: str | None = None, reason: str | None = None + ) -> TaskEntity: + """Mark a running task as completed.""" + return await self._transition_to_terminal( + TaskStatus.COMPLETED, id=id, name=name, reason=reason + ) + + async def fail_task( + self, id: str | None = None, name: str | None = None, reason: str | None = None + ) -> TaskEntity: + """Mark a running task as failed.""" + return await self._transition_to_terminal( + TaskStatus.FAILED, id=id, name=name, reason=reason + ) + + async def cancel_task( + self, id: str | None = None, name: str | None = None, reason: str | None = None + ) -> TaskEntity: + """Mark a running task as canceled.""" + return await self._transition_to_terminal( + TaskStatus.CANCELED, id=id, name=name, reason=reason + ) + + async def terminate_task( + self, id: str | None = None, name: str | None = None, reason: str | None = None + ) -> TaskEntity: + """Mark a running task as terminated.""" + return await self._transition_to_terminal( + TaskStatus.TERMINATED, id=id, name=name, reason=reason + ) + + async def timeout_task( + self, id: str | None = None, name: str | None = None, reason: str | None = None + ) -> TaskEntity: + """Mark a running task as timed out.""" + return await self._transition_to_terminal( + TaskStatus.TIMED_OUT, id=id, name=name, reason=reason + ) + DTaskUseCase = Annotated[TasksUseCase, Depends(TasksUseCase)] diff --git a/agentex/tests/integration/api/tasks/test_tasks_api.py b/agentex/tests/integration/api/tasks/test_tasks_api.py index 3fd4874c..934e079f 100644 --- a/agentex/tests/integration/api/tasks/test_tasks_api.py +++ b/agentex/tests/integration/api/tasks/test_tasks_api.py @@ -1381,3 +1381,101 @@ async def test_list_tasks_filters_work_with_views( assert "agents" in task_data assert len(task_data["agents"]) == 1 assert task_data["agents"][0]["name"] == "target-filter-agent" + + async def test_complete_task_endpoint(self, isolated_client, test_task): + """Test POST /tasks/{task_id}/complete transitions RUNNING to COMPLETED""" + # When + response = await isolated_client.post( + f"/tasks/{test_task.id}/complete", + json={"reason": "Agent finished"}, + ) + + # Then + assert response.status_code == 200 + task_data = response.json() + assert task_data["status"] == "COMPLETED" + assert task_data["status_reason"] == "Agent finished" + + async def test_fail_task_endpoint(self, isolated_client, test_task): + """Test POST /tasks/{task_id}/fail transitions RUNNING to FAILED""" + # When + response = await isolated_client.post( + f"/tasks/{test_task.id}/fail", + json={"reason": "Something went wrong"}, + ) + + # Then + assert response.status_code == 200 + task_data = response.json() + assert task_data["status"] == "FAILED" + assert task_data["status_reason"] == "Something went wrong" + + async def test_cancel_task_endpoint(self, isolated_client, test_task): + """Test POST /tasks/{task_id}/cancel transitions RUNNING to CANCELED""" + # When + response = await isolated_client.post( + f"/tasks/{test_task.id}/cancel", + json={"reason": "User requested cancellation"}, + ) + + # Then + assert response.status_code == 200 + task_data = response.json() + assert task_data["status"] == "CANCELED" + assert task_data["status_reason"] == "User requested cancellation" + + async def test_terminate_task_endpoint(self, isolated_client, test_task): + """Test POST /tasks/{task_id}/terminate transitions RUNNING to TERMINATED""" + # When + response = await isolated_client.post( + f"/tasks/{test_task.id}/terminate", + json={"reason": "Workflow killed"}, + ) + + # Then + assert response.status_code == 200 + task_data = response.json() + assert task_data["status"] == "TERMINATED" + assert task_data["status_reason"] == "Workflow killed" + + async def test_timeout_task_endpoint(self, isolated_client, test_task): + """Test POST /tasks/{task_id}/timeout transitions RUNNING to TIMED_OUT""" + # When + response = await isolated_client.post( + f"/tasks/{test_task.id}/timeout", + ) + + # Then + assert response.status_code == 200 + task_data = response.json() + assert task_data["status"] == "TIMED_OUT" + assert task_data["status_reason"] == "Task timed_out" + + async def test_complete_task_with_default_reason(self, isolated_client, test_task): + """Test POST /tasks/{task_id}/complete without a reason uses default""" + # When + response = await isolated_client.post( + f"/tasks/{test_task.id}/complete", + ) + + # Then + assert response.status_code == 200 + task_data = response.json() + assert task_data["status"] == "COMPLETED" + assert task_data["status_reason"] == "Task completed" + + async def test_cannot_transition_non_running_task(self, isolated_client, test_task): + """Test that a completed task cannot be transitioned again""" + # Given - Complete the task first + response = await isolated_client.post( + f"/tasks/{test_task.id}/complete", + ) + assert response.status_code == 200 + + # When - Try to terminate the already-completed task + response = await isolated_client.post( + f"/tasks/{test_task.id}/terminate", + ) + + # Then - Should fail + assert response.status_code == 400 diff --git a/agentex/tests/unit/use_cases/test_tasks_use_case.py b/agentex/tests/unit/use_cases/test_tasks_use_case.py new file mode 100644 index 00000000..3de88eea --- /dev/null +++ b/agentex/tests/unit/use_cases/test_tasks_use_case.py @@ -0,0 +1,457 @@ +""" +Unit tests for TasksUseCase - status transition logic via explicit status +methods (complete_task, fail_task, etc.) and metadata updates. +""" + +from uuid import uuid4 + +import pytest +from src.adapters.crud_store.exceptions import DuplicateItemError, ItemDoesNotExist +from src.domain.entities.agents import ACPType, AgentEntity, AgentStatus +from src.domain.entities.tasks import TaskStatus +from src.domain.exceptions import ClientError +from src.domain.repositories.agent_repository import AgentRepository +from src.domain.repositories.task_repository import TaskRepository +from src.domain.use_cases.tasks_use_case import TasksUseCase + + +async def create_or_get_agent(agent_repository, agent): + """Helper to create agent or get existing one if name already exists""" + try: + return await agent_repository.create(agent) + except DuplicateItemError: + existing_agent = await agent_repository.get(name=agent.name) + agent.id = existing_agent.id + return existing_agent + + +@pytest.fixture +def agent_repository(postgres_session_maker): + """Real AgentRepository using test PostgreSQL database""" + return AgentRepository(postgres_session_maker, postgres_session_maker) + + +@pytest.fixture +def task_repository(postgres_session_maker): + """Real TaskRepository using test PostgreSQL database""" + return TaskRepository(postgres_session_maker, postgres_session_maker) + + +@pytest.fixture +def tasks_use_case(task_service): + """TasksUseCase with real task_service""" + return TasksUseCase(task_service=task_service) + + +@pytest.fixture +def sample_agent(): + """Sample agent entity for testing""" + return AgentEntity( + id=str(uuid4()), + name="test-agent-use-case", + description="A test agent for use case testing", + status=AgentStatus.READY, + acp_type=ACPType.ASYNC, + acp_url="http://test-acp.example.com", + ) + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestTasksUseCaseStatusTransitions: + """Test suite for task status transitions via explicit status methods""" + + # --- Happy-path transitions (RUNNING -> terminal) --- + + async def test_complete_running_task( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a RUNNING task can be transitioned to COMPLETED""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="complete-test" + ) + assert task.status == TaskStatus.RUNNING + + # When + updated = await tasks_use_case.complete_task( + id=task.id, reason="Agent finished" + ) + + # Then + assert updated.status == TaskStatus.COMPLETED + assert updated.status_reason == "Agent finished" + + async def test_fail_running_task( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a RUNNING task can be transitioned to FAILED""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task(agent=sample_agent, task_name="fail-test") + + # When + updated = await tasks_use_case.fail_task( + id=task.id, reason="Something went wrong" + ) + + # Then + assert updated.status == TaskStatus.FAILED + assert updated.status_reason == "Something went wrong" + + async def test_cancel_running_task( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a RUNNING task can be transitioned to CANCELED""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="cancel-test" + ) + + # When + updated = await tasks_use_case.cancel_task( + id=task.id, reason="User requested cancellation" + ) + + # Then + assert updated.status == TaskStatus.CANCELED + assert updated.status_reason == "User requested cancellation" + + async def test_terminate_running_task( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a RUNNING task can be transitioned to TERMINATED""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="terminate-test" + ) + + # When + updated = await tasks_use_case.terminate_task( + id=task.id, reason="Workflow killed" + ) + + # Then + assert updated.status == TaskStatus.TERMINATED + assert updated.status_reason == "Workflow killed" + + async def test_timeout_running_task( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a RUNNING task can be transitioned to TIMED_OUT""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="timeout-test" + ) + + # When + updated = await tasks_use_case.timeout_task(id=task.id) + + # Then + assert updated.status == TaskStatus.TIMED_OUT + assert updated.status_reason == "Task timed_out" + + # --- Default reason for each transition --- + + @pytest.mark.parametrize( + "method,expected_reason", + [ + ("complete_task", "Task completed"), + ("fail_task", "Task failed"), + ("cancel_task", "Task canceled"), + ("terminate_task", "Task terminated"), + ("timeout_task", "Task timed_out"), + ], + ) + async def test_default_status_reason( + self, + tasks_use_case, + task_service, + agent_repository, + sample_agent, + method, + expected_reason, + ): + """Test that each transition method sets a default reason when none provided""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name=f"default-reason-{method}" + ) + + # When + updated = await getattr(tasks_use_case, method)(id=task.id) + + # Then + assert updated.status_reason == expected_reason + + # --- Transition by name --- + + async def test_complete_task_by_name( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a task can be transitioned using name instead of id""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="complete-by-name-test" + ) + + # When + updated = await tasks_use_case.complete_task( + name=task.name, reason="Done by name" + ) + + # Then + assert updated.status == TaskStatus.COMPLETED + assert updated.status_reason == "Done by name" + + # --- Blocked transitions from each terminal state --- + + async def test_cannot_transition_completed_task( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a COMPLETED task cannot be transitioned again""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="double-complete-test" + ) + await tasks_use_case.complete_task(id=task.id) + + # When / Then + with pytest.raises(ClientError, match="not running"): + await tasks_use_case.terminate_task(id=task.id) + + async def test_cannot_transition_failed_task( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a FAILED task cannot be transitioned""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="fail-block-test" + ) + await tasks_use_case.fail_task(id=task.id) + + # When / Then + with pytest.raises(ClientError, match="not running"): + await tasks_use_case.complete_task(id=task.id) + + async def test_cannot_transition_canceled_task( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a CANCELED task cannot be transitioned""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="cancel-block-test" + ) + await tasks_use_case.cancel_task(id=task.id) + + # When / Then + with pytest.raises(ClientError, match="not running"): + await tasks_use_case.complete_task(id=task.id) + + async def test_cannot_transition_terminated_task( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a TERMINATED task cannot be transitioned""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="terminate-block-test" + ) + await tasks_use_case.terminate_task(id=task.id) + + # When / Then + with pytest.raises(ClientError, match="not running"): + await tasks_use_case.complete_task(id=task.id) + + async def test_cannot_transition_timed_out_task( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a TIMED_OUT task cannot be transitioned""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="timeout-block-test" + ) + await tasks_use_case.timeout_task(id=task.id) + + # When / Then + with pytest.raises(ClientError, match="not running"): + await tasks_use_case.complete_task(id=task.id) + + async def test_cannot_transition_deleted_task( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a DELETED task raises not found""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="delete-block-test" + ) + await tasks_use_case.delete_task(id=task.id) + + # When / Then + with pytest.raises(ItemDoesNotExist): + await tasks_use_case.complete_task(id=task.id) + + # --- Validation --- + + async def test_transition_requires_id_or_name(self, tasks_use_case): + """Test that transitioning without id or name raises ClientError""" + with pytest.raises(ClientError, match="Either id or name must be provided"): + await tasks_use_case.complete_task() + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestTasksUseCaseMetadataUpdate: + """Test suite for update_mutable_fields_on_task""" + + async def test_update_metadata( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that task_metadata is replaced with the provided value""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="metadata-update-test" + ) + + # When + updated = await tasks_use_case.update_mutable_fields_on_task( + id=task.id, task_metadata={"key": "value"} + ) + + # Then + assert updated.task_metadata == {"key": "value"} + assert updated.status == TaskStatus.RUNNING + + async def test_update_metadata_does_not_change_status( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that updating metadata leaves status unchanged""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="metadata-status-test" + ) + + # When + updated = await tasks_use_case.update_mutable_fields_on_task( + id=task.id, task_metadata={"new": "data"} + ) + + # Then + assert updated.status == TaskStatus.RUNNING + + async def test_update_metadata_replaces_existing( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that metadata is fully replaced, not merged""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="metadata-replace-test" + ) + await tasks_use_case.update_mutable_fields_on_task( + id=task.id, task_metadata={"original": "data", "keep": "this"} + ) + + # When + updated = await tasks_use_case.update_mutable_fields_on_task( + id=task.id, task_metadata={"replaced": "entirely"} + ) + + # Then + assert updated.task_metadata == {"replaced": "entirely"} + assert "original" not in updated.task_metadata + + async def test_update_metadata_with_empty_dict( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that metadata can be set to an empty dict""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="metadata-empty-test" + ) + await tasks_use_case.update_mutable_fields_on_task( + id=task.id, task_metadata={"some": "data"} + ) + + # When + updated = await tasks_use_case.update_mutable_fields_on_task( + id=task.id, task_metadata={} + ) + + # Then + assert updated.task_metadata == {} + + async def test_update_metadata_noop_when_none( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that passing task_metadata=None is a no-op""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="metadata-noop-test" + ) + + # When + updated = await tasks_use_case.update_mutable_fields_on_task( + id=task.id, task_metadata=None + ) + + # Then + assert updated.id == task.id + assert updated.task_metadata == task.task_metadata + + async def test_update_metadata_by_name( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that metadata can be updated using task name""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="metadata-by-name-test" + ) + + # When + updated = await tasks_use_case.update_mutable_fields_on_task( + name=task.name, task_metadata={"via": "name"} + ) + + # Then + assert updated.task_metadata == {"via": "name"} + + async def test_update_metadata_on_deleted_task_raises( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that updating metadata on a deleted task raises not found""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="metadata-deleted-test" + ) + await tasks_use_case.delete_task(id=task.id) + + # When / Then + with pytest.raises(ItemDoesNotExist): + await tasks_use_case.update_mutable_fields_on_task( + id=task.id, task_metadata={"should": "fail"} + ) + + async def test_update_metadata_requires_id_or_name(self, tasks_use_case): + """Test that updating metadata without id or name raises ClientError""" + with pytest.raises(ClientError, match="Either id or name must be provided"): + await tasks_use_case.update_mutable_fields_on_task( + task_metadata={"key": "value"} + )