Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions frontend/src/hooks/useAgentChat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ interface UseAgentChatOptions {
onSessionDead?: (sessionId: string) => void;
}

function textFromUIMessage(message: UIMessage): string {
return message.parts
.filter((p): p is Extract<typeof p, { type: 'text' }> => p.type === 'text')
.map(p => p.text)
.join('');
}

export function useAgentChat({ sessionId, isActive, isProcessing = false, onReady, onError, onSessionDead }: UseAgentChatOptions) {
const callbacksRef = useRef({ onReady, onError, onSessionDead });
callbacksRef.current = { onReady, onError, onSessionDead };
Expand Down Expand Up @@ -348,6 +355,97 @@ export function useAgentChat({ sessionId, isActive, isProcessing = false, onRead
useUsageStore.getState().applyUsageEvent(sessionId, eventType, data);
},
onInterrupted: () => { /* no-op — handled by stop() caller */ },
onRecoverMessages: async ({
submittedText,
currentMessageCount,
currentUserMessageCount,
sessionInfo,
}) => {
try {
let msgsRes: Response;
let info = sessionInfo;

if (sessionInfo) {
msgsRes = await apiFetch(`/api/session/${sessionId}/messages`);
} else {
const [fetchedMsgsRes, infoRes] = await Promise.all([
apiFetch(`/api/session/${sessionId}/messages`),
apiFetch(`/api/session/${sessionId}`),
]);
msgsRes = fetchedMsgsRes;

if (infoRes.status === 404 && msgsRes.status === 404) {
callbacksRef.current.onSessionDead?.(sessionId);
return false;
}
if (infoRes.ok) {
info = await infoRes.json();
}
}

if (sessionInfo && msgsRes.status === 404) {
callbacksRef.current.onSessionDead?.(sessionId);
return false;
}
if (!msgsRes.ok) return false;

const data = await msgsRes.json();
if (!Array.isArray(data) || data.length === 0) return false;
saveBackendMessages(sessionId, data);

let pendingIds: Set<string> | undefined;
let backendIsProcessing = false;
if (info) {
backendIsProcessing = !!info.is_processing;
if (info.pending_approval && Array.isArray(info.pending_approval)) {
pendingIds = new Set(
info.pending_approval.map((t: { tool_call_id: string }) => t.tool_call_id)
);
if (pendingIds.size > 0) setNeedsAttention(sessionId, true);
}
if (info.auto_approval) {
updateSessionYolo(sessionId, info.auto_approval);
}
}

const uiMsgs = llmMessagesToUIMessages(
data,
pendingIds,
chatActionsRef.current.messages,
);
const backendAdvanced = uiMsgs.length > currentMessageCount;
let submittedTurnAccepted = false;
if (submittedText) {
const userMessages = uiMsgs.filter((m) => m.role === 'user');
const lastUser = userMessages[userMessages.length - 1];
submittedTurnAccepted = (
userMessages.length >= currentUserMessageCount &&
!!lastUser &&
textFromUIMessage(lastUser).trim() === submittedText.trim()
);
}

const setMsgs = chatActionsRef.current.setMessages;
if (setMsgs && uiMsgs.length >= currentMessageCount) {
setMsgs(uiMsgs);
saveMessages(sessionId, uiMsgs);
}

if (backendIsProcessing) {
setProcessingState(true, { activityStatus: { type: 'thinking' } });
return false;
}
if (pendingIds && pendingIds.size > 0) {
setProcessingState(false, { activityStatus: { type: 'waiting-approval' } });
} else {
setProcessingState(false);
}

return backendAdvanced || submittedTurnAccepted;
} catch {
return false;
}
},
}),
// eslint-disable-next-line react-hooks/exhaustive-deps
[sessionId, setProcessingState],
Expand Down
167 changes: 144 additions & 23 deletions frontend/src/lib/sse-chat-transport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,25 @@ export interface SideChannelCallbacks {
onToolRunning: (toolName: string, description?: string) => void;
onUsageEvent: (eventType: 'llm_call' | 'hf_job_complete', data: Record<string, unknown>) => void;
onInterrupted: () => void;
onRecoverMessages: (context: MessageRecoveryContext) => Promise<boolean>;
}

export interface MessageRecoveryContext {
submittedText?: string;
currentMessageCount: number;
currentUserMessageCount: number;
sessionInfo?: RecoverySessionInfo;
}

export interface RecoverySessionInfo {
is_processing?: boolean;
pending_approval?: Array<{ tool_call_id: string }> | null;
auto_approval?: {
enabled: boolean;
cost_cap_usd?: number | null;
estimated_spend_usd?: number;
remaining_usd?: number | null;
} | null;
}

// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -70,6 +89,34 @@ async function readErrorResponse(response: Response): Promise<string> {
}
}

function isAbortError(error: unknown, signal?: AbortSignal): boolean {
if (signal?.aborted) return true;
if (!(error instanceof Error)) return false;
return error.name === 'AbortError';
}

function isRecoverableFetchError(error: unknown): boolean {
if (!(error instanceof Error)) return false;

const name = error.name.toLowerCase();
const message = error.message.toLowerCase();
const networkFailureMessages = [
'load failed',
'failed to fetch',
'fetch failed',
'networkerror',
'network error',
'network request failed',
'network connection was lost',
'internet connection appears to be offline',
];

return (
name === 'networkerror' ||
(name === 'typeerror' && networkFailureMessages.some((pattern) => message.includes(pattern)))
);
}

/** Parse an SSE text stream into AgentEvent objects. */
function createSSEParserStream(sessionId: string): TransformStream<string, AgentEvent> {
let buffer = '';
Expand Down Expand Up @@ -130,6 +177,15 @@ function createSSEParserStream(sessionId: string): TransformStream<string, Agent
});
}

function createRecoveredFinishedStream(): ReadableStream<UIMessageChunk> {
return new ReadableStream<UIMessageChunk>({
start(controller) {
controller.enqueue({ type: 'finish', finishReason: 'stop' });
controller.close();
},
});
}

/** Transform AgentEvent objects into UIMessageChunk objects for the Vercel AI SDK. */
function createEventToChunkStream(sideChannel: SideChannelCallbacks): TransformStream<AgentEvent, UIMessageChunk> {
let textPartId: string | null = null;
Expand Down Expand Up @@ -369,6 +425,69 @@ export class SSEChatTransport implements ChatTransport<UIMessage> {
// Nothing to clean up — no persistent connections
}

private async connectToEventStream(): Promise<ReadableStream<UIMessageChunk> | null> {
const lastSeq = localStorage.getItem(lastEventKey(this.sessionId));
const qs = lastSeq ? `?after=${encodeURIComponent(lastSeq)}` : '';
const response = await apiFetch(`/api/events/${this.sessionId}${qs}`, {
headers: { 'Accept': 'text/event-stream' },
});
if (!response.ok || !response.body) return null;

this.sideChannel.onProcessing();

return response.body
.pipeThrough(new TextDecoderStream())
.pipeThrough(createSSEParserStream(this.sessionId))
.pipeThrough(createEventToChunkStream(this.sideChannel));
}

private async recoverFailedSend(
context: MessageRecoveryContext,
): Promise<ReadableStream<UIMessageChunk>> {
let infoRes: Response;
try {
infoRes = await apiFetch(`/api/session/${this.sessionId}`);
} catch {
throw new Error(
'Connection to the Space was interrupted before the message was accepted. Please retry.',
);
}

if (infoRes.status === 404) {
this.sideChannel.onSessionDead(this.sessionId);
throw new Error('Session not found or inactive');
}
if (!infoRes.ok) {
throw new Error(
'Connection to the Space was interrupted before the message was accepted. Please retry.',
);
}

const info = await infoRes.json() as RecoverySessionInfo;
if (info.is_processing) {
try {
const stream = await this.connectToEventStream();
if (stream) return stream;
} catch {
// Fall through to message hydration; the turn may have completed
// between the status probe and the event-stream reconnect.
}
}

const recovered = await this.sideChannel.onRecoverMessages({
...context,
sessionInfo: info.is_processing ? undefined : info,
});
if (recovered) {
this.sideChannel.onProcessingDone();
return createRecoveredFinishedStream();
}

throw new Error(
'Connection to the Space was interrupted before the message was accepted. Please retry.',
);
}

// -- ChatTransport interface ---------------------------------------------

async sendMessages(
Expand All @@ -391,6 +510,7 @@ export class SSEChatTransport implements ChatTransport<UIMessage> {
) || [];

let body: Record<string, unknown>;
let submittedText: string | undefined;
if (approvedParts.length > 0) {
// Approval continuation — extract approval decisions
const approvals = approvedParts.map((p) => {
Expand All @@ -415,19 +535,33 @@ export class SSEChatTransport implements ChatTransport<UIMessage> {
.map(p => p.text)
.join('')
: '';
submittedText = text;
body = { text };
}

// POST to SSE endpoint
const response = await apiFetch(`/api/chat/${sessionId}`, {
method: 'POST',
body: JSON.stringify(body),
signal: options.abortSignal,
headers: {
'Content-Type': 'application/json',
'Accept': 'text/event-stream',
},
});
let response: Response;
try {
response = await apiFetch(`/api/chat/${sessionId}`, {
method: 'POST',
body: JSON.stringify(body),
signal: options.abortSignal,
headers: {
'Content-Type': 'application/json',
'Accept': 'text/event-stream',
},
});
} catch (error) {
if (isAbortError(error, options.abortSignal) || !isRecoverableFetchError(error)) {
throw error;
}
logger.warn('Chat POST failed; attempting session recovery:', error);
return this.recoverFailedSend({
submittedText,
currentMessageCount: options.messages.length,
currentUserMessageCount: options.messages.filter(m => m.role === 'user').length,
});
}

if (response.status === 404) {
// Backend lost this session (e.g. Space restart). Signal the UI so
Expand Down Expand Up @@ -460,20 +594,7 @@ export class SSEChatTransport implements ChatTransport<UIMessage> {
const info = await infoRes.json();
if (!info.is_processing) return null;

// Session is mid-turn — subscribe to its event broadcast.
const lastSeq = localStorage.getItem(lastEventKey(this.sessionId));
const qs = lastSeq ? `?after=${encodeURIComponent(lastSeq)}` : '';
const response = await apiFetch(`/api/events/${this.sessionId}${qs}`, {
headers: { 'Accept': 'text/event-stream' },
});
if (!response.ok || !response.body) return null;

this.sideChannel.onProcessing();

return response.body
.pipeThrough(new TextDecoderStream())
.pipeThrough(createSSEParserStream(this.sessionId))
.pipeThrough(createEventToChunkStream(this.sideChannel));
return this.connectToEventStream();
} catch {
return null;
}
Expand Down
Loading