Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions agentex-ui/components/primary-content/prompt-input.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import {
useSafeSearchParams,
} from '@/hooks/use-safe-search-params';
import { useSendMessage } from '@/hooks/use-task-messages';
import { useTask } from '@/hooks/use-tasks';

type PromptInputProps = {
prompt: string;
Expand Down Expand Up @@ -52,10 +53,16 @@ export function PromptInput({ prompt, setPrompt }: PromptInputProps) {

const createTaskMutation = useCreateTask({ agentexClient });
const sendMessageMutation = useSendMessage({ agentexClient });
const { data: task } = useTask({ agentexClient, taskId: taskID ?? '' });

const textInputRef = useRef<HTMLInputElement>(null);
const codeMirrorViewRef = useRef<EditorView | null>(null);

const isTaskTerminal = useMemo(() => {
if (!taskID || !task) return false;
return task.status != null && task.status !== 'RUNNING';
}, [taskID, task]);

const handleSetJson = useCallback(
(value: boolean) => {
if (value && !prompt.trim()) {
Expand Down Expand Up @@ -86,8 +93,8 @@ export function PromptInput({ prompt, setPrompt }: PromptInputProps) {
}, [taskID, isClient, isSendingJSON]);

const isDisabled = useMemo(
() => !agentName || !isClient,
[agentName, isClient]
() => !agentName || !isClient || isTaskTerminal,
[agentName, isClient, isTaskTerminal]
);

const handleSendPrompt = useCallback(async () => {
Expand Down Expand Up @@ -171,6 +178,8 @@ export function PromptInput({ prompt, setPrompt }: PromptInputProps) {
prompt={prompt}
setPrompt={setPrompt}
isDisabled={isDisabled}
isTaskTerminal={isTaskTerminal}
taskStatus={task?.status}
handleSendPrompt={handleSendPrompt}
inputRef={textInputRef}
/>
Expand Down Expand Up @@ -205,12 +214,16 @@ const TextInput = ({
prompt,
setPrompt,
isDisabled,
isTaskTerminal,
taskStatus,
handleSendPrompt,
inputRef,
}: {
prompt: string;
setPrompt: (prompt: string) => void;
isDisabled: boolean;
isTaskTerminal: boolean;
taskStatus: string | null | undefined;
handleSendPrompt: () => void;
inputRef: React.RefObject<HTMLInputElement | null>;
}) => {
Expand All @@ -230,7 +243,11 @@ const TextInput = ({
}}
disabled={isDisabled}
placeholder={
isDisabled ? 'Select an agent to start' : 'Enter your prompt'
isTaskTerminal
? `Task ${taskStatus?.toLowerCase() ?? 'ended'}`
: isDisabled
? 'Select an agent to start'
: 'Enter your prompt'
}
className="mr-2 flex-1 outline-none focus:ring-0 focus:outline-none"
style={{
Expand Down
38 changes: 23 additions & 15 deletions agentex-ui/components/task-messages/task-messages.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type TaskMessagesProps = {
};
type MessagePair = {
id: string;
userMessage: TaskMessage;
userMessage: TaskMessage | null;
agentMessages: TaskMessage[];
};

Expand Down Expand Up @@ -58,36 +58,41 @@ function TaskMessagesImpl({ taskId, headerRef }: TaskMessagesProps) {
const pairs: MessagePair[] = [];
let currentUserMessage: TaskMessage | null = null;
let currentAgentMessages: TaskMessage[] = [];
let pairStarted = false;

for (const message of messages) {
const isUserMessage = message.content.author === 'user';

if (isUserMessage) {
if (currentUserMessage) {
if (pairStarted) {
pairs.push({
id: currentUserMessage.id || `pair-${pairs.length}`,
id:
currentUserMessage?.id ||
currentAgentMessages[0]?.id ||
`pair-${pairs.length}`,
userMessage: currentUserMessage,
agentMessages: currentAgentMessages,
});
}
currentUserMessage = message;
currentAgentMessages = [];
pairStarted = true;
} else {
if (currentUserMessage) {
currentAgentMessages.push(message);
} else {
pairs.push({
id: message.id || `pair-${pairs.length}`,
userMessage: message,
agentMessages: [],
});
if (!pairStarted) {
currentUserMessage = null;
currentAgentMessages = [];
pairStarted = true;
}
currentAgentMessages.push(message);
}
}

if (currentUserMessage) {
if (pairStarted) {
pairs.push({
id: currentUserMessage.id || `pair-${pairs.length}`,
id:
currentUserMessage?.id ||
currentAgentMessages[0]?.id ||
`pair-${pairs.length}`,
userMessage: currentUserMessage,
agentMessages: currentAgentMessages,
});
Expand All @@ -101,10 +106,13 @@ function TaskMessagesImpl({ taskId, headerRef }: TaskMessagesProps) {

const lastPair = messagePairs[messagePairs.length - 1]!;
const hasNoAgentMessages = lastPair.agentMessages.length === 0;
const hasUserMessage = lastPair.userMessage !== null;
const rpcStatus = queryData?.rpcStatus;

return (
hasNoAgentMessages && (rpcStatus === 'pending' || rpcStatus === 'success')
hasUserMessage &&
hasNoAgentMessages &&
(rpcStatus === 'pending' || rpcStatus === 'success')
);
}, [messagePairs, queryData?.rpcStatus]);

Expand Down Expand Up @@ -191,7 +199,7 @@ function TaskMessagesImpl({ taskId, headerRef }: TaskMessagesProps) {
containerHeight={containerHeight}
>
<AnimatePresence>
{renderMessage(pair.userMessage)}
{pair.userMessage && renderMessage(pair.userMessage)}
{pair.agentMessages.map(agentMessage => (
<Fragment key={agentMessage.id}>
{renderMessage(agentMessage)}
Expand Down
86 changes: 86 additions & 0 deletions agentex/src/api/routes/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Task,
TaskRelationships,
TaskResponse,
TaskStatusReasonRequest,
UpdateTaskRequest,
)
from src.domain.services.authorization_service import DAuthorizationService
Expand Down Expand Up @@ -169,6 +170,91 @@ async def update_task_by_name(
return Task.model_validate(updated_task_entity)


@router.post(
"/{task_id}/complete",
response_model=Task,
summary="Complete Task",
description="Mark a running task as completed.",
)
async def complete_task(
task_id: DAuthorizedId(AgentexResourceType.task, AuthorizedOperationType.update),
task_use_case: DTaskUseCase,
request: TaskStatusReasonRequest | None = None,
) -> Task:
updated = await task_use_case.complete_task(
id=task_id, reason=request.reason if request else None
)
return Task.model_validate(updated)


@router.post(
"/{task_id}/fail",
response_model=Task,
summary="Fail Task",
description="Mark a running task as failed.",
)
async def fail_task(
task_id: DAuthorizedId(AgentexResourceType.task, AuthorizedOperationType.update),
task_use_case: DTaskUseCase,
request: TaskStatusReasonRequest | None = None,
) -> Task:
updated = await task_use_case.fail_task(
id=task_id, reason=request.reason if request else None
)
return Task.model_validate(updated)


@router.post(
"/{task_id}/cancel",
response_model=Task,
summary="Cancel Task",
description="Mark a running task as canceled.",
)
async def cancel_task(
task_id: DAuthorizedId(AgentexResourceType.task, AuthorizedOperationType.update),
task_use_case: DTaskUseCase,
request: TaskStatusReasonRequest | None = None,
) -> Task:
updated = await task_use_case.cancel_task(
id=task_id, reason=request.reason if request else None
)
return Task.model_validate(updated)


@router.post(
"/{task_id}/terminate",
response_model=Task,
summary="Terminate Task",
description="Mark a running task as terminated.",
)
async def terminate_task(
task_id: DAuthorizedId(AgentexResourceType.task, AuthorizedOperationType.update),
task_use_case: DTaskUseCase,
request: TaskStatusReasonRequest | None = None,
) -> Task:
updated = await task_use_case.terminate_task(
id=task_id, reason=request.reason if request else None
)
return Task.model_validate(updated)


@router.post(
"/{task_id}/timeout",
response_model=Task,
summary="Timeout Task",
description="Mark a running task as timed out.",
)
async def timeout_task(
task_id: DAuthorizedId(AgentexResourceType.task, AuthorizedOperationType.update),
task_use_case: DTaskUseCase,
request: TaskStatusReasonRequest | None = None,
) -> Task:
updated = await task_use_case.timeout_task(
id=task_id, reason=request.reason if request else None
)
return Task.model_validate(updated)


@router.get(
"/{task_id}/stream",
summary="Stream Task Events by ID",
Expand Down
7 changes: 7 additions & 0 deletions agentex/src/api/schemas/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,10 @@ class UpdateTaskRequest(BaseModel):
None,
title="If provided, replaces task_metadata with this value",
)


class TaskStatusReasonRequest(BaseModel):
reason: str | None = Field(
None,
title="Optional reason for the status change",
)
34 changes: 33 additions & 1 deletion agentex/src/domain/repositories/task_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Annotated, Literal

from fastapi import Depends
from sqlalchemy import select
from sqlalchemy import select, update
from sqlalchemy.orm import selectinload
from src.adapters.crud_store.adapter_postgres import (
ColumnPrimitiveValue,
Expand Down Expand Up @@ -139,5 +139,37 @@ async def update(self, task: TaskEntity) -> TaskEntity:
# Return with agents populated
return TaskEntity.model_validate(modified_orm)

async def transition_status(
self,
task_id: str,
expected_status: TaskStatus,
new_status: TaskStatus,
status_reason: str,
task_metadata: dict | None = None,
) -> TaskEntity | None:
"""Atomically transition task status. Returns None if the expected status didn't match (i.e. lost the race)."""

async with (
self.start_async_db_session(True) as session,
async_sql_exception_handler(),
):
values: dict = {"status": new_status, "status_reason": status_reason}
if task_metadata is not None:
values["task_metadata"] = task_metadata

stmt = (
update(TaskORM)
.where(TaskORM.id == task_id, TaskORM.status == expected_status)
.values(**values)
)
result = await session.execute(stmt)
await session.commit()

if result.rowcount == 0:
return None

refreshed = await session.get(TaskORM, task_id)
return TaskEntity.model_validate(refreshed)


DTaskRepository = Annotated[TaskRepository, Depends(TaskRepository)]
38 changes: 38 additions & 0 deletions agentex/src/domain/services/task_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,44 @@ async def get_task(
id=id, name=name, relationships=relationships
)

async def transition_task_status(
self,
task_id: str,
expected_status: TaskStatus,
new_status: TaskStatus,
status_reason: str,
task_metadata: dict | None = None,
) -> TaskEntity | None:
"""
Atomically transition task status. Returns None if the expected status didn't match.
Publishes a task_updated event on success.
"""
updated_task = await self.task_repository.transition_status(
task_id=task_id,
expected_status=expected_status,
new_status=new_status,
status_reason=status_reason,
task_metadata=task_metadata,
)
if updated_task is None:
return None

try:
topic = get_task_event_stream_topic(task_id=task_id)
await self.stream_repository.send_data(
topic,
TaskStreamTaskUpdatedEventEntity(
type="task_updated", task=updated_task
).model_dump(mode="json"),
)
logger.info(f"task_updated event published to topic: {topic}")
except Exception as e:
logger.error(
f"Error sending task_updated event to stream: {e}", exc_info=True
)

return updated_task

async def update_task(self, task: TaskEntity) -> TaskEntity:
"""
Update a task in the repository.
Expand Down
Loading
Loading