From 994198113f4fae0482dc3bb5e26863374bfb00cd Mon Sep 17 00:00:00 2001 From: locnguyen1986 Date: Thu, 18 Dec 2025 23:24:04 +0700 Subject: [PATCH 1/2] model context size checking --- .../handlers/chathandler/chat_handler.go | 113 ++- .../handlers/chathandler/message_trimmer.go | 734 ++++++++++++++++-- .../httpserver/responses/chat/chat.go | 4 +- .../routes/v1/chat/completion_route.go | 27 +- .../internal/utils/platformerrors/errors.go | 5 + .../internal/domain/llm/message_trimmer.go | 46 +- 6 files changed, 821 insertions(+), 108 deletions(-) diff --git a/services/llm-api/internal/interfaces/httpserver/handlers/chathandler/chat_handler.go b/services/llm-api/internal/interfaces/httpserver/handlers/chathandler/chat_handler.go index 50aeefae..bf1f7734 100644 --- a/services/llm-api/internal/interfaces/httpserver/handlers/chathandler/chat_handler.go +++ b/services/llm-api/internal/interfaces/httpserver/handlers/chathandler/chat_handler.go @@ -18,6 +18,7 @@ import ( "jan-server/services/llm-api/internal/domain/prompt" "jan-server/services/llm-api/internal/domain/usersettings" "jan-server/services/llm-api/internal/infrastructure/inference" + "jan-server/services/llm-api/internal/infrastructure/logger" "jan-server/services/llm-api/internal/infrastructure/mediaresolver" memclient "jan-server/services/llm-api/internal/infrastructure/memory" "jan-server/services/llm-api/internal/infrastructure/metrics" @@ -41,6 +42,7 @@ type ChatCompletionResult struct { Response *openai.ChatCompletionResponse ConversationID string ConversationTitle *string + Trimmed bool // True if messages were trimmed to fit context } // ChatHandler handles chat completion requests @@ -289,20 +291,112 @@ func (h *ChatHandler) CreateChatCompletion( return nil, platformerrors.AsError(ctx, platformerrors.LayerHandler, err, "failed to create chat client") } - // Trim messages to fit within model's context length + // Build token budget for context management contextLength := DefaultContextLength if modelCatalog 
!= nil && modelCatalog.ContextLength != nil && *modelCatalog.ContextLength > 0 { contextLength = *modelCatalog.ContextLength } - trimResult := TrimMessagesToFitContext(request.Messages, contextLength) - if trimResult.TrimmedCount > 0 { - observability.AddSpanEvent(ctx, "messages_trimmed", - attribute.Int("trimmed_count", trimResult.TrimmedCount), - attribute.Int("estimated_tokens", trimResult.EstimatedTokens), - attribute.Int("context_length", contextLength), - ) - request.Messages = trimResult.Messages + + // Validate user input size BEFORE any processing + // This returns an error if the current user input exceeds MaxUserContentTokens + if err := ValidateUserInputSize(request.Messages); err != nil { + observability.RecordError(ctx, err) + return nil, platformerrors.NewError(ctx, platformerrors.LayerHandler, platformerrors.ErrorTypeValidation, err.Error(), nil, "a1b2c3d4-e5f6-7890-abcd-ef1234567890") + } + + // Get max_tokens from request (0 if not set) + maxCompletionTokens := 0 + if request.MaxTokens > 0 { + maxCompletionTokens = request.MaxTokens + } + + // Track whether any trimming occurred + wasTrimmed := false + + // Build and validate token budget + budget := BuildTokenBudget(contextLength, request.Tools, maxCompletionTokens) + if err := budget.Validate(); err != nil { + budgetLog := logger.GetLogger() + budgetLog.Warn(). + Err(err). + Int("context_length", budget.ContextLength). + Int("tools_tokens", budget.ToolsTokens). + Int("max_completion_tokens", budget.MaxCompletionTokens). 
+ Msg("token budget validation failed, using fallback trimming") + // Fall back to legacy trimming if budget validation fails + trimResult := TrimMessagesToFitContext(request.Messages, contextLength) + if trimResult.TrimmedCount > 0 { + wasTrimmed = true + observability.AddSpanEvent(ctx, "messages_trimmed", + attribute.Int("trimmed_count", trimResult.TrimmedCount), + attribute.Int("estimated_tokens", trimResult.EstimatedTokens), + attribute.Int("context_length", contextLength), + ) + request.Messages = trimResult.Messages + } + } else { + // Log budget for debugging + budgetLog := logger.GetLogger() + budgetLog.Info(). + Int("context_length", budget.ContextLength). + Int("tools_tokens", budget.ToolsTokens). + Int("response_reserve", budget.ResponseReserve). + Int("available_for_messages", budget.AvailableForMessages). + Msg("token budget computed") + + // First, truncate oversized user content in HISTORICAL messages (not current input) + userTruncatedMessages, userTruncEvents := TruncateLargeUserContent(request.Messages) + if len(userTruncEvents) > 0 { + wasTrimmed = true + observability.AddSpanEvent(ctx, "user_content_truncated", + attribute.Int("truncation_count", len(userTruncEvents)), + ) + request.Messages = userTruncatedMessages + } + + // Second, truncate oversized tool content (with JSON-aware parsing) + truncatedMessages, truncEvents := TruncateLargeToolContent(request.Messages) + if len(truncEvents) > 0 { + wasTrimmed = true + observability.AddSpanEvent(ctx, "tool_content_truncated", + attribute.Int("truncation_count", len(truncEvents)), + ) + request.Messages = truncatedMessages + } + + // Then trim messages using the validated budget (oldest items first) + trimResult := TrimMessagesToFitBudget(request.Messages, budget) + if trimResult.TrimmedCount > 0 { + wasTrimmed = true + observability.AddSpanEvent(ctx, "messages_trimmed", + attribute.Int("trimmed_count", trimResult.TrimmedCount), + attribute.Int("estimated_tokens", trimResult.EstimatedTokens), + 
attribute.Int("context_length", contextLength), + attribute.Int("tools_tokens", budget.ToolsTokens), + ) + request.Messages = trimResult.Messages + } + } + + // Log final content size AFTER all trimming for accurate debugging + finalContentLength := 0 + for _, msg := range request.Messages { + finalContentLength += len(msg.Content) + for _, part := range msg.MultiContent { + finalContentLength += len(part.Text) + } } + trimLog := logger.GetLogger() + trimLog.Info(). + Str("route", "/v1/chat/completions"). + Str("model", request.Model). + Str("conversation_id", conversationID). + Int("messages_after_trim", len(request.Messages)). + Int("content_length_after_trim", finalContentLength). + Int("context_length", contextLength). + Bool("stream", request.Stream). + Bool("trimmed", wasTrimmed). + Msg("chat completion ready for LLM (after trimming)") var response *openai.ChatCompletionResponse @@ -414,6 +508,7 @@ func (h *ChatHandler) CreateChatCompletion( Response: response, ConversationID: conversationID, ConversationTitle: conversationTitle, + Trimmed: wasTrimmed, }, nil } diff --git a/services/llm-api/internal/interfaces/httpserver/handlers/chathandler/message_trimmer.go b/services/llm-api/internal/interfaces/httpserver/handlers/chathandler/message_trimmer.go index 73bae97f..62306321 100644 --- a/services/llm-api/internal/interfaces/httpserver/handlers/chathandler/message_trimmer.go +++ b/services/llm-api/internal/interfaces/httpserver/handlers/chathandler/message_trimmer.go @@ -2,6 +2,9 @@ package chathandler import ( "encoding/json" + "fmt" + "strconv" + "strings" "unicode/utf8" "jan-server/services/llm-api/internal/infrastructure/logger" @@ -11,19 +14,272 @@ import ( const ( // DefaultContextLength is used when model context length is unknown. - DefaultContextLength = 128000 // 128k tokens as fallback + DefaultContextLength = 220000 // 220k tokens as fallback // TokenEstimateRatio estimates ~4 characters per token (conservative estimate). 
TokenEstimateRatio = 4 + // TokenEstimateRatioCJK estimates ~1.5 characters per token for CJK content. + TokenEstimateRatioCJK = 1.5 + // MinMessagesToKeep ensures we always keep system prompt + at least one user message. MinMessagesToKeep = 2 - // SafetyMarginRatio reserves space for response and overhead (20% margin). - SafetyMarginRatio = 0.80 + // MinMessagesTokenFloor is the hard minimum tokens required for messages. + MinMessagesTokenFloor = 1000 + + // SafetyMarginRatio reserves space for response and overhead (15% margin for response). + SafetyMarginRatio = 0.75 + + // FixedOverheadTokens is fixed overhead for API request structure. + FixedOverheadTokens = 100 + + // MaxToolSchemaBytes caps tool parameter schema size to prevent runaway serialization. + MaxToolSchemaBytes = 16384 // 16KB + + // MaxToolResultTokens is max tokens per tool result before truncation. + MaxToolResultTokens = 20000 + + // MaxToolArgumentTokens is max tokens for tool call arguments. + MaxToolArgumentTokens = 2000 + + // MaxUserContentTokens is max tokens per user message text content before truncation. + MaxUserContentTokens = 24000 + + // MaxMultiContentTextTokens is max tokens per text part in multi-content arrays. + MaxMultiContentTextTokens = 6000 + + // Image token estimates (conservative for safety) + ImageTokensLowRes = 85 // Low resolution image + ImageTokensHighRes = 850 // High resolution image (average) ) +// =============================== +// TokenBudget - Central budget management +// =============================== + +// TokenBudget represents the complete token budget for a request. +// This struct flows through the trimmer so callers don't recompute. 
+type TokenBudget struct { + ContextLength int // Total context window size + ToolsTokens int // Tokens consumed by tool definitions + MaxCompletionTokens int // User-requested max_tokens (0 = use default margin) + FixedOverhead int // Fixed overhead (API structure, formatting) + + // Computed fields (set by Validate()) + AvailableForMessages int // Tokens available for message content + ResponseReserve int // Tokens reserved for response +} + +// Validate checks the budget and computes available space. +// Returns error if budget is invalid (e.g., max_tokens exceeds context). +func (b *TokenBudget) Validate() error { + if b.ContextLength <= 0 { + return fmt.Errorf("invalid context length: %d (must be positive)", b.ContextLength) + } + + // Calculate response reserve + if b.MaxCompletionTokens > 0 { + b.ResponseReserve = b.MaxCompletionTokens + } else { + b.ResponseReserve = int(float64(b.ContextLength) * (1 - SafetyMarginRatio)) + } + + // Calculate available space for messages + b.AvailableForMessages = b.ContextLength - b.ToolsTokens - b.ResponseReserve - b.FixedOverhead + + // Hard floor check: if available space is too small, return error + if b.AvailableForMessages < MinMessagesTokenFloor { + return fmt.Errorf( + "token budget exhausted: context=%d, tools=%d, response_reserve=%d, overhead=%d → only %d tokens available (minimum required: %d). Reduce max_tokens, use fewer tools, or choose a model with larger context", + b.ContextLength, b.ToolsTokens, b.ResponseReserve, b.FixedOverhead, + b.AvailableForMessages, MinMessagesTokenFloor, + ) + } + + return nil +} + +// =============================== +// Tool Token Estimation +// =============================== + +// EstimateToolsTokens estimates tokens for tool definitions. +// Logs warnings for marshal errors and caps schema size. 
+func EstimateToolsTokens(tools []openai.Tool) int { + if len(tools) == 0 { + return 0 + } + + log := logger.GetLogger() + total := 50 // Base overhead for tools array structure + + for _, tool := range tools { + total += 20 // Overhead per tool + if tool.Function != nil { + total += estimateTokenCount(tool.Function.Name) + total += estimateTokenCount(tool.Function.Description) + + // Parameters schema can be large - cap and handle errors + if tool.Function.Parameters != nil { + paramsJSON, err := json.Marshal(tool.Function.Parameters) + if err != nil { + log.Warn(). + Str("tool", tool.Function.Name). + Err(err). + Msg("failed to marshal tool parameters, using fallback estimate") + total += 200 // Conservative fallback + continue + } + + // Cap schema size to prevent extremely large schemas + if len(paramsJSON) > MaxToolSchemaBytes { + log.Warn(). + Str("tool", tool.Function.Name). + Int("schema_bytes", len(paramsJSON)). + Int("cap_bytes", MaxToolSchemaBytes). + Msg("tool schema exceeds size cap, truncating estimate") + paramsJSON = paramsJSON[:MaxToolSchemaBytes] + } + + total += estimateTokenCount(string(paramsJSON)) + } + } + } + + log.Debug(). + Int("tool_count", len(tools)). + Int("estimated_tokens", total). + Msg("estimated tools tokens") + + return total +} + +// =============================== +// Image Token Estimation +// =============================== + +// estimateImageTokens estimates tokens for an image based on detail level. +// Missing or empty detail is normalized to "high" for conservative estimation. 
+func estimateImageTokens(imageURL *openai.ChatMessageImageURL) int { + if imageURL == nil { + return 0 + } + + // Normalize: treat empty/missing detail as "high" for safety + detail := imageURL.Detail + if detail == "" { + detail = openai.ImageURLDetailHigh + } + + switch detail { + case openai.ImageURLDetailLow: + return ImageTokensLowRes + case openai.ImageURLDetailHigh: + return ImageTokensHighRes + case openai.ImageURLDetailAuto: + return ImageTokensHighRes + default: + return ImageTokensHighRes + } +} + +// estimateMultiContentTokens handles different content part types. +func estimateMultiContentTokens(parts []openai.ChatMessagePart) int { + total := 0 + for _, part := range parts { + switch part.Type { + case openai.ChatMessagePartTypeText: + total += estimateTokenCount(part.Text) + case openai.ChatMessagePartTypeImageURL: + total += estimateImageTokens(part.ImageURL) + } + } + return total +} + +// countImagesInToolResult detects images embedded in tool result content. +// Uses lightweight JSON sniffing to avoid false positives/negatives. +func countImagesInToolResult(content string) int { + if len(content) == 0 { + return 0 + } + + // Quick pre-check: if content doesn't look like JSON array, skip parsing + trimmed := strings.TrimSpace(content) + if len(trimmed) == 0 || trimmed[0] != '[' { + return countDataURLImages(content) + } + + // Try to parse as JSON array + var items []map[string]any + if err := json.Unmarshal([]byte(content), &items); err != nil { + return countDataURLImages(content) + } + + imageCount := 0 + for _, item := range items { + if isImageType(item) { + imageCount++ + continue + } + if hasImageDataURL(item) { + imageCount++ + } + } + + return imageCount +} + +// isImageType checks if a map represents an image content type. 
func isImageType(item map[string]any) bool {
	for _, key := range []string{"type", "kind", "contentType"} {
		val, ok := item[key].(string)
		if !ok {
			// Missing key or non-string value: not a type marker.
			continue
		}
		// Accept the bare word "image" or any image MIME type.
		if val == "image" || strings.HasPrefix(val, "image/") {
			return true
		}
	}
	return false
}

// hasImageDataURL checks if a map contains an image data URL in any of the
// "data", "url", "src" or "imageUrl" fields. Only string values that start
// with "data:image/" qualify.
func hasImageDataURL(item map[string]any) bool {
	for _, key := range []string{"data", "url", "src", "imageUrl"} {
		if val, ok := item[key].(string); ok && strings.HasPrefix(val, "data:image/") {
			return true
		}
	}
	return false
}

// countDataURLImages counts data URL images in non-JSON content (fallback).
//
// Every occurrence of the "data:image/" prefix is counted, whatever the image
// subtype. This is the same prefix hasImageDataURL matches, so an embedded
// image is recognised the same way whether or not the surrounding content
// parses as JSON.
func countDataURLImages(content string) int {
	return strings.Count(content, "data:image/")
}

// ===============================
// CJK Character Detection
// ===============================

// isCJK checks if a rune is a CJK character: a Han ideograph (including the
// compatibility block and the supplementary ideographic planes), Hiragana,
// Katakana, or a precomposed Hangul syllable.
func isCJK(r rune) bool {
	return (r >= 0x4E00 && r <= 0x9FFF) || // CJK Unified Ideographs
		(r >= 0x3400 && r <= 0x4DBF) || // CJK Unified Ideographs Extension A
		(r >= 0xF900 && r <= 0xFAFF) || // CJK Compatibility Ideographs
		(r >= 0x20000 && r <= 0x3FFFF) || // Supplementary/Tertiary Ideographic Planes (Extension B and later)
		(r >= 0x3040 && r <= 0x309F) || // Hiragana
		(r >= 0x30A0 && r <= 0x30FF) || // Katakana
		(r >= 0xAC00 && r <= 0xD7AF) // Hangul Syllables
}

// estimateTokenCount provides a rough estimate of token count for content.
// Handles CJK characters with adjusted ratio for better accuracy.
func estimateTokenCount(content interface{}) int { var text string switch v := content.(type) { @@ -32,13 +288,39 @@ func estimateTokenCount(content interface{}) int { case nil: return 0 default: - bytes, _ := json.Marshal(v) + bytes, err := json.Marshal(v) + if err != nil { + return 50 // Fallback estimate for marshal errors + } text = string(bytes) } - return utf8.RuneCountInString(text) / TokenEstimateRatio + + if len(text) == 0 { + return 0 + } + + runeCount := utf8.RuneCountInString(text) + + // Count CJK characters for adjusted estimation + cjkCount := 0 + for _, r := range text { + if isCJK(r) { + cjkCount++ + } + } + + // If more than 30% CJK, use CJK ratio for that portion + if runeCount > 0 && float64(cjkCount)/float64(runeCount) > 0.3 { + cjkTokens := float64(cjkCount) / TokenEstimateRatioCJK + otherTokens := float64(runeCount-cjkCount) / float64(TokenEstimateRatio) + return int(cjkTokens + otherTokens) + } + + return runeCount / TokenEstimateRatio } // estimateMessagesTokenCount estimates total tokens across all messages. +// Includes proper handling for images in MultiContent and tool results. func estimateMessagesTokenCount(messages []openai.ChatCompletionMessage) int { total := 0 for _, msg := range messages { @@ -46,11 +328,15 @@ func estimateMessagesTokenCount(messages []openai.ChatCompletionMessage) int { total += 10 total += estimateTokenCount(msg.Content) - // Handle multipart content + // Handle multipart content with image support if len(msg.MultiContent) > 0 { - for _, part := range msg.MultiContent { - total += estimateTokenCount(part.Text) - } + total += estimateMultiContentTokens(msg.MultiContent) + } + + // Count images in tool results (browser screenshots, etc.) 
+ if msg.Role == "tool" && msg.Content != "" { + imageCount := countImagesInToolResult(msg.Content) + total += imageCount * ImageTokensHighRes } // Add tokens for tool calls @@ -65,6 +351,320 @@ func estimateMessagesTokenCount(messages []openai.ChatCompletionMessage) int { return total } +// =============================== +// Tool Content Truncation +// =============================== + +// TruncationEvent represents a truncation for logging/metrics. +type TruncationEvent struct { + MessageIndex int + ToolName string + ToolCallID string + OriginalTokens int + TruncatedTokens int + TruncationType string // "tool_result" or "tool_argument" +} + +// truncateTextPreservingJSON truncates text content while trying to preserve JSON structure. +// If content is JSON-stringified MultiContent, it parses and truncates the nested text fields. +func truncateTextPreservingJSON(content string, maxTokens int) (string, bool) { + maxChars := maxTokens * TokenEstimateRatio + trimmed := strings.TrimSpace(content) + + // Check if it looks like a JSON array (MultiContent format) + if len(trimmed) > 0 && trimmed[0] == '[' { + var parts []map[string]interface{} + if err := json.Unmarshal([]byte(content), &parts); err == nil { + // Successfully parsed as JSON array - truncate nested text fields + modified := false + for i := range parts { + if textVal, ok := parts[i]["text"]; ok { + if textStr, ok := textVal.(string); ok { + textTokens := estimateTokenCount(textStr) + if textTokens > maxTokens { + textRunes := []rune(textStr) + if len(textRunes) > maxChars { + parts[i]["text"] = string(textRunes[:maxChars]) + "\n\n[Content truncated]" + modified = true + } + } + } + } + } + if modified { + if newContent, err := json.Marshal(parts); err == nil { + return string(newContent), true + } + } + return content, false + } + } + + // Check if it looks like a JSON object with nested content + if len(trimmed) > 0 && trimmed[0] == '{' { + var obj map[string]interface{} + if err := 
json.Unmarshal([]byte(content), &obj); err == nil { + modified := false + // Truncate common large text fields + for _, key := range []string{"text", "content", "markdown", "raw_text", "body"} { + if textVal, ok := obj[key]; ok { + if textStr, ok := textVal.(string); ok { + textTokens := estimateTokenCount(textStr) + if textTokens > maxTokens { + textRunes := []rune(textStr) + if len(textRunes) > maxChars { + obj[key] = string(textRunes[:maxChars]) + "\n\n[Content truncated]" + modified = true + } + } + } + } + } + if modified { + if newContent, err := json.Marshal(obj); err == nil { + return string(newContent), true + } + } + return content, false + } + } + + // Plain text - simple truncation + runes := []rune(content) + if len(runes) > maxChars { + return string(runes[:maxChars]) + "\n\n[Content truncated - exceeded " + strconv.Itoa(maxTokens) + " token limit]", true + } + return content, false +} + +// TruncateLargeToolContent reduces oversized tool results AND arguments. +// Now with MultiContent-aware JSON parsing to truncate nested text fields properly. +// Returns the modified messages and a list of truncation events for logging. 
+func TruncateLargeToolContent(messages []openai.ChatCompletionMessage) ([]openai.ChatCompletionMessage, []TruncationEvent) { + log := logger.GetLogger() + result := make([]openai.ChatCompletionMessage, len(messages)) + copy(result, messages) + + var events []TruncationEvent + + for i := range result { + // Truncate tool results with MultiContent-aware parsing + if result[i].Role == "tool" && result[i].Content != "" { + originalTokens := estimateTokenCount(result[i].Content) + if originalTokens > MaxToolResultTokens { + truncatedContent, didTruncate := truncateTextPreservingJSON(result[i].Content, MaxToolResultTokens) + if didTruncate { + result[i].Content = truncatedContent + + event := TruncationEvent{ + MessageIndex: i, + ToolCallID: result[i].ToolCallID, + OriginalTokens: originalTokens, + TruncatedTokens: MaxToolResultTokens, + TruncationType: "tool_result", + } + events = append(events, event) + + log.Warn(). + Str("tool_call_id", result[i].ToolCallID). + Int("original_tokens", originalTokens). + Int("truncated_to", MaxToolResultTokens). + Msg("truncated large tool result (JSON-aware)") + } + } + } + + // Truncate tool call arguments (in assistant messages) + if result[i].ToolCalls != nil { + for j := range result[i].ToolCalls { + tc := &result[i].ToolCalls[j] + originalTokens := estimateTokenCount(tc.Function.Arguments) + if originalTokens > MaxToolArgumentTokens { + maxChars := MaxToolArgumentTokens * TokenEstimateRatio + runes := []rune(tc.Function.Arguments) + if len(runes) > maxChars { + tc.Function.Arguments = string(runes[:maxChars]) + "...[truncated]" + + event := TruncationEvent{ + MessageIndex: i, + ToolName: tc.Function.Name, + ToolCallID: tc.ID, + OriginalTokens: originalTokens, + TruncatedTokens: MaxToolArgumentTokens, + TruncationType: "tool_argument", + } + events = append(events, event) + + log.Warn(). + Str("tool_name", tc.Function.Name). + Str("tool_call_id", tc.ID). + Int("original_tokens", originalTokens). 
+ Int("truncated_to", MaxToolArgumentTokens). + Msg("truncated large tool arguments") + } + } + } + } + } + + if len(events) > 0 { + log.Info(). + Int("total_truncations", len(events)). + Msg("tool content truncation summary") + } + + return result, events +} + +// UserInputValidationError represents an error when user input exceeds token limits. +type UserInputValidationError struct { + EstimatedTokens int + MaxTokens int + Message string +} + +func (e *UserInputValidationError) Error() string { + return e.Message +} + +// ValidateUserInputSize checks if the last user message (current input) exceeds MaxUserContentTokens. +// Returns an error if the user input is too large, preventing the request from proceeding. +// This only validates the LAST user message (current input), not historical messages. +func ValidateUserInputSize(messages []openai.ChatCompletionMessage) error { + if len(messages) == 0 { + return nil + } + + // Find the last user message (current user input) + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role != "user" { + continue + } + + // Check plain string content + if messages[i].Content != "" && len(messages[i].MultiContent) == 0 { + tokens := estimateTokenCount(messages[i].Content) + if tokens > MaxUserContentTokens { + return &UserInputValidationError{ + EstimatedTokens: tokens, + MaxTokens: MaxUserContentTokens, + Message: fmt.Sprintf( + "User input too large: estimated %d tokens exceeds maximum allowed %d tokens. 
Please reduce your message size.", + tokens, MaxUserContentTokens, + ), + } + } + } + + // Check MultiContent array + if len(messages[i].MultiContent) > 0 { + totalTextTokens := 0 + for _, part := range messages[i].MultiContent { + if part.Type == openai.ChatMessagePartTypeText && part.Text != "" { + totalTextTokens += estimateTokenCount(part.Text) + } + } + if totalTextTokens > MaxUserContentTokens { + return &UserInputValidationError{ + EstimatedTokens: totalTextTokens, + MaxTokens: MaxUserContentTokens, + Message: fmt.Sprintf( + "User input too large: estimated %d tokens exceeds maximum allowed %d tokens. Please reduce your message size.", + totalTextTokens, MaxUserContentTokens, + ), + } + } + } + + // Only check the last user message + break + } + + return nil +} + +// TruncateLargeUserContent reduces oversized user message content. +// Handles both plain string content and MultiContent arrays. +// Returns the modified messages and a list of truncation events for logging. +func TruncateLargeUserContent(messages []openai.ChatCompletionMessage) ([]openai.ChatCompletionMessage, []TruncationEvent) { + log := logger.GetLogger() + result := make([]openai.ChatCompletionMessage, len(messages)) + copy(result, messages) + + var events []TruncationEvent + + for i := range result { + if result[i].Role != "user" { + continue + } + + // Handle plain string content + if result[i].Content != "" && len(result[i].MultiContent) == 0 { + originalTokens := estimateTokenCount(result[i].Content) + if originalTokens > MaxUserContentTokens { + truncatedContent, didTruncate := truncateTextPreservingJSON(result[i].Content, MaxUserContentTokens) + if didTruncate { + result[i].Content = truncatedContent + + event := TruncationEvent{ + MessageIndex: i, + OriginalTokens: originalTokens, + TruncatedTokens: MaxUserContentTokens, + TruncationType: "user_content", + } + events = append(events, event) + + log.Warn(). + Int("message_index", i). + Int("original_tokens", originalTokens). 
+ Int("truncated_to", MaxUserContentTokens). + Msg("truncated large user content") + } + } + } + + // Handle MultiContent array + if len(result[i].MultiContent) > 0 { + for j := range result[i].MultiContent { + part := &result[i].MultiContent[j] + if part.Type == openai.ChatMessagePartTypeText && part.Text != "" { + originalTokens := estimateTokenCount(part.Text) + if originalTokens > MaxMultiContentTextTokens { + maxChars := MaxMultiContentTextTokens * TokenEstimateRatio + runes := []rune(part.Text) + if len(runes) > maxChars { + part.Text = string(runes[:maxChars]) + "\n\n[Content truncated - exceeded " + strconv.Itoa(MaxMultiContentTextTokens) + " token limit]" + + event := TruncationEvent{ + MessageIndex: i, + OriginalTokens: originalTokens, + TruncatedTokens: MaxMultiContentTextTokens, + TruncationType: "user_multicontent_text", + } + events = append(events, event) + + log.Warn(). + Int("message_index", i). + Int("part_index", j). + Int("original_tokens", originalTokens). + Int("truncated_to", MaxMultiContentTextTokens). + Msg("truncated large user multi-content text part") + } + } + } + } + } + } + + if len(events) > 0 { + log.Info(). + Int("total_user_truncations", len(events)). + Msg("user content truncation summary") + } + + return result, events +} + // TrimMessagesResult contains the result of trimming messages. type TrimMessagesResult struct { Messages []openai.ChatCompletionMessage @@ -72,6 +672,12 @@ type TrimMessagesResult struct { EstimatedTokens int } +// TrimMessagesToFitBudget trims messages using the provided TokenBudget. +// The budget must be validated before calling this function. +func TrimMessagesToFitBudget(messages []openai.ChatCompletionMessage, budget *TokenBudget) TrimMessagesResult { + return trimMessagesInternal(messages, budget.AvailableForMessages) +} + // TrimMessagesToFitContext removes oldest tool results and assistant messages // to fit within the context length limit. 
// Priority order for removal (oldest first): @@ -79,6 +685,8 @@ type TrimMessagesResult struct { // 2. Assistant messages with tool calls // 3. Regular assistant messages // Never removes: system prompts, user messages +// +// Deprecated: Use TrimMessagesToFitBudget with a validated TokenBudget instead. func TrimMessagesToFitContext(messages []openai.ChatCompletionMessage, contextLength int) TrimMessagesResult { if contextLength <= 0 { contextLength = DefaultContextLength @@ -86,7 +694,12 @@ func TrimMessagesToFitContext(messages []openai.ChatCompletionMessage, contextLe // Apply safety margin maxTokens := int(float64(contextLength) * SafetyMarginRatio) + return trimMessagesInternal(messages, maxTokens) +} +// trimMessagesInternal is the core trimming logic used by both public functions. +// Removes oldest conversation items first (any role except system) to fit within token budget. +func trimMessagesInternal(messages []openai.ChatCompletionMessage, maxTokens int) TrimMessagesResult { currentTokens := estimateMessagesTokenCount(messages) if currentTokens <= maxTokens { return TrimMessagesResult{ @@ -101,7 +714,6 @@ func TrimMessagesToFitContext(messages []openai.ChatCompletionMessage, contextLe Int("initial_messages", len(messages)). Int("initial_tokens", currentTokens). Int("max_tokens", maxTokens). - Int("context_length", contextLength). 
Msg("starting message trimming") // Create a working copy @@ -110,79 +722,42 @@ func TrimMessagesToFitContext(messages []openai.ChatCompletionMessage, contextLe trimmedCount := 0 // Build a token count cache for efficient removal - // This avoids O(n²) complexity from recalculating all tokens on each removal messageTokens := make([]int, len(result)) for i := range result { - tokens := 10 // Overhead for role and structure - tokens += estimateTokenCount(result[i].Content) - - if len(result[i].MultiContent) > 0 { - for _, part := range result[i].MultiContent { - tokens += estimateTokenCount(part.Text) - } - } - - if result[i].ToolCalls != nil { - for _, tc := range result[i].ToolCalls { - tokens += 20 - tokens += estimateTokenCount(tc.Function.Name) - tokens += estimateTokenCount(tc.Function.Arguments) - } - } - messageTokens[i] = tokens + messageTokens[i] = estimateSingleMessageTokens(&result[i]) } - // Find indices of messages that can be removed (in order of priority) - // We iterate from oldest to newest (excluding system prompt at index 0) + // Remove oldest items first (any role except system at index 0) + // This approach removes conversation items chronologically from oldest to newest for currentTokens > maxTokens && len(result) > MinMessagesToKeep { + // Find the oldest removable message (skip index 0 which is system prompt) removedIdx := -1 - - // Phase 1: Remove oldest tool result message for i := 1; i < len(result); i++ { - if result[i].Role == "tool" { - removedIdx = i - break - } - } - - // Phase 2: Remove oldest assistant message with tool calls (and its following tool results) - if removedIdx == -1 { - for i := 1; i < len(result); i++ { - if result[i].Role == "assistant" && len(result[i].ToolCalls) > 0 { - removedIdx = i - break - } - } - } - - // Phase 3: Remove oldest regular assistant message - if removedIdx == -1 { - for i := 1; i < len(result); i++ { - if result[i].Role == "assistant" { - removedIdx = i - break - } + // Skip system messages 
anywhere in the conversation + if result[i].Role == "system" { + continue } + // Remove the oldest non-system message + removedIdx = i + break } - // If no removable message found, stop if removedIdx == -1 { + // No removable messages found break } - // Decrement token count by the removed message's tokens removedTokens := messageTokens[removedIdx] currentTokens -= removedTokens - + log.Debug(). Str("role", result[removedIdx].Role). Int("index", removedIdx). Int("message_tokens", removedTokens). Int("remaining_tokens", currentTokens). Int("remaining_messages", len(result)-1). - Msg("trimmed message") - - // Remove the message and its token count from caches + Msg("trimmed oldest message") + result = append(result[:removedIdx], result[removedIdx+1:]...) messageTokens = append(messageTokens[:removedIdx], messageTokens[removedIdx+1:]...) trimmedCount++ @@ -192,7 +767,6 @@ func TrimMessagesToFitContext(messages []openai.ChatCompletionMessage, contextLe Int("trimmed_count", trimmedCount). Int("final_messages", len(result)). Int("final_tokens", currentTokens). - Int("tokens_freed", estimateMessagesTokenCount(messages)-currentTokens). Msg("message trimming completed") return TrimMessagesResult{ @@ -201,3 +775,39 @@ func TrimMessagesToFitContext(messages []openai.ChatCompletionMessage, contextLe EstimatedTokens: currentTokens, } } + +// estimateSingleMessageTokens calculates tokens for a single message. 
+func estimateSingleMessageTokens(msg *openai.ChatCompletionMessage) int { + tokens := 10 // Overhead for role and structure + tokens += estimateTokenCount(msg.Content) + + if len(msg.MultiContent) > 0 { + tokens += estimateMultiContentTokens(msg.MultiContent) + } + + // Count images in tool results + if msg.Role == "tool" && msg.Content != "" { + imageCount := countImagesInToolResult(msg.Content) + tokens += imageCount * ImageTokensHighRes + } + + if msg.ToolCalls != nil { + for _, tc := range msg.ToolCalls { + tokens += 20 + tokens += estimateTokenCount(tc.Function.Name) + tokens += estimateTokenCount(tc.Function.Arguments) + } + } + + return tokens +} + +// BuildTokenBudget creates a TokenBudget from request parameters. +func BuildTokenBudget(contextLength int, tools []openai.Tool, maxCompletionTokens int) *TokenBudget { + return &TokenBudget{ + ContextLength: contextLength, + ToolsTokens: EstimateToolsTokens(tools), + MaxCompletionTokens: maxCompletionTokens, + FixedOverhead: FixedOverheadTokens, + } +} diff --git a/services/llm-api/internal/interfaces/httpserver/responses/chat/chat.go b/services/llm-api/internal/interfaces/httpserver/responses/chat/chat.go index 9cf605c5..59382054 100644 --- a/services/llm-api/internal/interfaces/httpserver/responses/chat/chat.go +++ b/services/llm-api/internal/interfaces/httpserver/responses/chat/chat.go @@ -8,6 +8,7 @@ import ( type ChatCompletionResponse struct { openai.ChatCompletionResponse Conversation *ConversationContext `json:"conversation,omitempty"` + Trimmed bool `json:"trimmed,omitempty"` // True if messages were trimmed to fit context } // ConversationContext represents the conversation associated with this response @@ -17,9 +18,10 @@ type ConversationContext struct { } // NewChatCompletionResponse creates a response with optional conversation context -func NewChatCompletionResponse(openaiResp *openai.ChatCompletionResponse, conversationID string, conversationTitle *string) *ChatCompletionResponse { +func 
NewChatCompletionResponse(openaiResp *openai.ChatCompletionResponse, conversationID string, conversationTitle *string, trimmed bool) *ChatCompletionResponse { resp := &ChatCompletionResponse{ ChatCompletionResponse: *openaiResp, + Trimmed: trimmed, } if conversationID != "" { diff --git a/services/llm-api/internal/interfaces/httpserver/routes/v1/chat/completion_route.go b/services/llm-api/internal/interfaces/httpserver/routes/v1/chat/completion_route.go index 8f475d42..054b17c2 100644 --- a/services/llm-api/internal/interfaces/httpserver/routes/v1/chat/completion_route.go +++ b/services/llm-api/internal/interfaces/httpserver/routes/v1/chat/completion_route.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/gin-gonic/gin" + "github.com/rs/zerolog/log" "jan-server/services/llm-api/internal/interfaces/httpserver/handlers/authhandler" "jan-server/services/llm-api/internal/interfaces/httpserver/handlers/chathandler" @@ -81,16 +82,38 @@ func (chatCompletionRoute *ChatCompletionRoute) PostCompletion(reqCtx *gin.Conte } var request chatrequests.ChatCompletionRequest + contentLength := reqCtx.Request.ContentLength if err := reqCtx.ShouldBindJSON(&request); err != nil { responses.HandleError(reqCtx, err, "Invalid request body") return } + conversationID := "" + if request.Conversation != nil { + conversationID = request.Conversation.GetID() + } + + log.Info(). + Str("route", "/v1/chat/completions"). + Str("model", request.Model). + Str("conversation_id", conversationID). + Int("messages", len(request.Messages)). + Int64("content_length", contentLength). + Bool("stream", request.Stream). 
+ Msg("chat completion request received") + // Delegate to chat handler result, err := chatCompletionRoute.chatHandler.CreateChatCompletion(reqCtx.Request.Context(), reqCtx, user.ID, request) if err != nil { + // Check if it's a validation error (user input too large) + if platformerrors.IsValidationError(err) { + responses.HandleError(reqCtx, err, err.Error()) + return + } + + // For other errors, return fallback response fallback := chatCompletionRoute.chatHandler.BuildFallbackResponse(request.Model) - chatResponse := chatresponses.NewChatCompletionResponse(fallback, "", nil) + chatResponse := chatresponses.NewChatCompletionResponse(fallback, "", nil, false) reqCtx.JSON(http.StatusOK, chatResponse) return } @@ -98,7 +121,7 @@ func (chatCompletionRoute *ChatCompletionRoute) PostCompletion(reqCtx *gin.Conte // For non-streaming requests, return the response with conversation context if !request.Stream { // Wrap the OpenAI response with conversation context (including title) - chatResponse := chatresponses.NewChatCompletionResponse(result.Response, result.ConversationID, result.ConversationTitle) + chatResponse := chatresponses.NewChatCompletionResponse(result.Response, result.ConversationID, result.ConversationTitle, result.Trimmed) reqCtx.JSON(http.StatusOK, chatResponse) } diff --git a/services/llm-api/internal/utils/platformerrors/errors.go b/services/llm-api/internal/utils/platformerrors/errors.go index 37fb2afc..25b4e498 100644 --- a/services/llm-api/internal/utils/platformerrors/errors.go +++ b/services/llm-api/internal/utils/platformerrors/errors.go @@ -206,6 +206,11 @@ func IsErrorType(err error, errorType ErrorType) bool { return false } +// IsValidationError checks if an error is a validation error +func IsValidationError(err error) bool { + return IsErrorType(err, ErrorTypeValidation) +} + // LogError logs a platform error with proper structure func LogError(logger zerolog.Logger, err *PlatformError) { if err == nil { diff --git 
a/services/response-api/internal/domain/llm/message_trimmer.go b/services/response-api/internal/domain/llm/message_trimmer.go index 39f22dc0..cbe27275 100644 --- a/services/response-api/internal/domain/llm/message_trimmer.go +++ b/services/response-api/internal/domain/llm/message_trimmer.go @@ -60,13 +60,9 @@ type TrimMessagesResult struct { EstimatedTokens int } -// TrimMessagesToFitContext removes oldest tool results and assistant messages -// to fit within the context length limit. -// Priority order for removal (oldest first): -// 1. Tool result messages (role="tool") -// 2. Assistant messages with tool calls -// 3. Regular assistant messages -// Never removes: system prompts, user messages +// TrimMessagesToFitContext removes oldest conversation items to fit within the context length limit. +// Removes oldest non-system messages first, regardless of role (user, assistant, tool). +// Never removes: system prompts at index 0 func TrimMessagesToFitContext(messages []ChatMessage, contextLength int) TrimMessagesResult { if contextLength <= 0 { contextLength = DefaultContextLength @@ -89,37 +85,19 @@ func TrimMessagesToFitContext(messages []ChatMessage, contextLength int) TrimMes copy(result, messages) trimmedCount := 0 - // Find indices of messages that can be removed (in order of priority) - // We iterate from oldest to newest (excluding system prompt at index 0) + // Remove oldest items first (any role except system at index 0) + // This approach removes conversation items chronologically from oldest to newest for currentTokens > maxTokens && len(result) > MinMessagesToKeep { + // Find the oldest removable message (skip index 0 which is system prompt) removedIdx := -1 - - // Phase 1: Remove oldest tool result message for i := 1; i < len(result); i++ { - if result[i].Role == "tool" { - removedIdx = i - break - } - } - - // Phase 2: Remove oldest assistant message with tool calls (and its following tool results) - if removedIdx == -1 { - for i := 1; i < 
len(result); i++ { - if result[i].Role == "assistant" && len(result[i].ToolCalls) > 0 { - removedIdx = i - break - } - } - } - - // Phase 3: Remove oldest regular assistant message - if removedIdx == -1 { - for i := 1; i < len(result); i++ { - if result[i].Role == "assistant" { - removedIdx = i - break - } + // Skip system messages anywhere in the conversation + if result[i].Role == "system" { + continue } + // Remove the oldest non-system message + removedIdx = i + break } // If no removable message found, stop From ba9e68bfcb836868d1be4a1d6b3efb7335e19055 Mon Sep 17 00:00:00 2001 From: locnguyen1986 Date: Fri, 19 Dec 2025 00:23:27 +0700 Subject: [PATCH 2/2] add branching --- services/llm-api/cmd/server/wire_gen.go | 7 +- .../domain/conversation/conversation.go | 21 +- .../conversation/conversation_service.go | 24 +- .../internal/domain/conversation/item.go | 1 + .../conversation/message_action_service.go | 311 ++++++++ services/llm-api/internal/domain/provider.go | 1 + .../conversation_repository.go | 383 ++++++++- .../handlers/chathandler/chat_handler.go | 66 +- .../conversationhandler/branch_handler.go | 323 ++++++++ .../conversation_handler.go | 55 +- .../requests/conversation/conversation.go | 1 + .../httpserver/routes/routes_provider.go | 2 + .../routes/v1/conversation/branch_route.go | 307 +++++++ .../v1/conversation/conversation_route.go | 24 +- .../httpserver/routes/v1/v1_route.go | 4 + .../conversations-postman-scripts.json | 748 ++++++++++++++++++ 16 files changed, 2190 insertions(+), 88 deletions(-) create mode 100644 services/llm-api/internal/domain/conversation/message_action_service.go create mode 100644 services/llm-api/internal/interfaces/httpserver/handlers/conversationhandler/branch_handler.go create mode 100644 services/llm-api/internal/interfaces/httpserver/routes/v1/conversation/branch_route.go diff --git a/services/llm-api/cmd/server/wire_gen.go b/services/llm-api/cmd/server/wire_gen.go index 950da3b1..e5ac8fb4 100644 --- 
a/services/llm-api/cmd/server/wire_gen.go +++ b/services/llm-api/cmd/server/wire_gen.go @@ -90,10 +90,11 @@ func CreateApplication() (*Application, error) { providerHandler := modelhandler.NewProviderHandler(providerService, providerModelService, inferenceProvider) conversationRepository := conversationrepo.NewConversationGormRepository(database) conversationService := conversation.NewConversationService(conversationRepository) + messageActionService := conversation.NewMessageActionService(conversationRepository) projectRepository := projectrepo.NewProjectGormRepository(db) projectService := project.NewProjectService(projectRepository) shareRepository := sharerepo.NewShareGormRepository(database) - conversationHandler := conversationhandler.NewConversationHandler(conversationService, projectService, shareRepository) + conversationHandler := conversationhandler.NewConversationHandler(conversationService, messageActionService, projectService, shareRepository) client := infrastructure.ProvideKeycloakClient(config, zerologLogger) resolver := infrastructure.ProvideMediaResolver(config, zerologLogger, client) processorConfig := domain.ProvidePromptProcessorConfig(config, zerologLogger) @@ -108,6 +109,8 @@ func CreateApplication() (*Application, error) { chatCompletionRoute := chat.NewChatCompletionRoute(chatHandler, authHandler) chatRoute := chat.NewChatRoute(chatCompletionRoute) conversationRoute := conversation2.NewConversationRoute(conversationHandler, authHandler) + branchHandler := conversationhandler.NewBranchHandler(conversationService, messageActionService, conversationRepository) + branchRoute := conversation2.NewBranchRoute(conversationHandler, branchHandler, authHandler) projectHandler := projecthandler.NewProjectHandler(projectService) projectRoute := projects.NewProjectRoute(projectHandler, authHandler) providerModelHandler := modelhandler.NewProviderModelHandler(providerModelService, providerService, modelCatalogService) @@ -126,7 +129,7 @@ func 
CreateApplication() (*Application, error) { shareHandler := sharehandler.NewShareHandler(shareService, conversationHandler, config) shareRoute := share2.NewShareRoute(shareHandler, authHandler, conversationHandler) publicShareRoute := public.NewPublicShareRoute(shareHandler) - v1Route := v1.NewV1Route(modelRoute, chatRoute, conversationRoute, projectRoute, adminRoute, usersRoute, promptTemplateHandler, shareRoute, publicShareRoute) + v1Route := v1.NewV1Route(modelRoute, chatRoute, conversationRoute, branchRoute, projectRoute, adminRoute, usersRoute, promptTemplateHandler, shareRoute, publicShareRoute) guestHandler := guestauth.NewGuestHandler(client, zerologLogger) upgradeHandler := guestauth.NewUpgradeHandler(client, zerologLogger) tokenHandler := authhandler.NewTokenHandler(client, zerologLogger) diff --git a/services/llm-api/internal/domain/conversation/conversation.go b/services/llm-api/internal/domain/conversation/conversation.go index 6a47ed76..df3bee4b 100644 --- a/services/llm-api/internal/domain/conversation/conversation.go +++ b/services/llm-api/internal/domain/conversation/conversation.go @@ -119,6 +119,11 @@ type ConversationRepository interface { // TODO: Implement forking functionality for conversation editing ForkBranch(ctx context.Context, conversationID uint, sourceBranch, newBranch string, fromItemID string, description *string) error + // SwapBranchToMain swaps a branch with MAIN - the given branch becomes MAIN + // and the old MAIN is renamed to a backup branch name. This is used for edit/regenerate + // operations where the new content should become the primary conversation. 
+ SwapBranchToMain(ctx context.Context, conversationID uint, branchToPromote string) (oldMainBackupName string, err error) + // Item rating operations - TODO: Implement item rating/feedback system RateItem(ctx context.Context, conversationID uint, itemID string, rating ItemRating, comment *string) error GetItemRating(ctx context.Context, conversationID uint, itemID string) (*ItemRating, error) @@ -299,8 +304,22 @@ func (c *Conversation) CreateBranch(newBranchName, sourceBranch, fromItemID stri return nil } +// CreateBranchMetadata creates metadata for a new branch +func (c *Conversation) CreateBranchMetadata(name string, parentBranch *string, forkFromItemID *string, description *string) BranchMetadata { + now := time.Now() + return BranchMetadata{ + Name: name, + Description: description, + ParentBranch: parentBranch, + ForkedAt: &now, + ForkedFromItemID: forkFromItemID, + ItemCount: 0, + CreatedAt: now, + UpdatedAt: now, + } +} + // GenerateEditBranchName generates a unique branch name for conversation edits -// TODO: Currently unused - will be needed when implementing conversation branching UI func GenerateEditBranchName(conversationID uint) string { return fmt.Sprintf("EDIT_%d_%d", conversationID, time.Now().Unix()) } diff --git a/services/llm-api/internal/domain/conversation/conversation_service.go b/services/llm-api/internal/domain/conversation/conversation_service.go index 4783d046..c6d95d2d 100644 --- a/services/llm-api/internal/domain/conversation/conversation_service.go +++ b/services/llm-api/internal/domain/conversation/conversation_service.go @@ -200,9 +200,17 @@ func (s *ConversationService) AddItemsToConversation(ctx context.Context, conv * return []Item{}, nil } - // Validate branch exists (for now, only MAIN is supported) + // Default to MAIN branch if not specified + if branchName == "" { + branchName = BranchMain + } + + // Validate branch exists for non-MAIN branches if branchName != BranchMain { - return nil, platformerrors.NewError(ctx, 
platformerrors.LayerDomain, platformerrors.ErrorTypeNotFound, fmt.Sprintf("branch not found: %s", branchName), nil, "e5f6a7b8-c9d0-4e1f-2a3b-4c5d6e7f8a9b") + branch, err := s.repo.GetBranch(ctx, conv.ID, branchName) + if err != nil || branch == nil { + return nil, platformerrors.NewError(ctx, platformerrors.LayerDomain, platformerrors.ErrorTypeNotFound, fmt.Sprintf("branch not found: %s", branchName), nil, "e5f6a7b8-c9d0-4e1f-2a3b-4c5d6e7f8a9b") + } } // Get current item count to determine starting sequence number @@ -229,15 +237,9 @@ func (s *ConversationService) AddItemsToConversation(ctx context.Context, conv * itemPtrs[i] = &items[i] } - // Add items to repository - if branchName == BranchMain || branchName == "" { - if err := s.repo.BulkAddItems(ctx, conv.ID, itemPtrs); err != nil { - return nil, platformerrors.AsError(ctx, platformerrors.LayerDomain, err, "failed to add items") - } - } else { - if err := s.repo.BulkAddItemsToBranch(ctx, conv.ID, branchName, itemPtrs); err != nil { - return nil, platformerrors.AsError(ctx, platformerrors.LayerDomain, err, "failed to add items to branch") - } + // Add items to repository - use branch-aware method for all branches + if err := s.repo.BulkAddItemsToBranch(ctx, conv.ID, branchName, itemPtrs); err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerDomain, err, "failed to add items to branch") } // Update conversation's updated_at timestamp diff --git a/services/llm-api/internal/domain/conversation/item.go b/services/llm-api/internal/domain/conversation/item.go index 397a5d31..b40e434c 100644 --- a/services/llm-api/internal/domain/conversation/item.go +++ b/services/llm-api/internal/domain/conversation/item.go @@ -684,6 +684,7 @@ type ItemFilter struct { ConversationID *uint Role *ItemRole ResponseID *uint + Branch *string // Filter by branch name } type ItemRepository interface { diff --git a/services/llm-api/internal/domain/conversation/message_action_service.go 
b/services/llm-api/internal/domain/conversation/message_action_service.go new file mode 100644 index 00000000..ed0d355a --- /dev/null +++ b/services/llm-api/internal/domain/conversation/message_action_service.go @@ -0,0 +1,311 @@ +package conversation + +import ( + "context" + "time" + + "jan-server/services/llm-api/internal/utils/idgen" + "jan-server/services/llm-api/internal/utils/platformerrors" +) + +// MessageActionService handles message edit, regenerate, and delete operations +type MessageActionService struct { + convRepo ConversationRepository +} + +// NewMessageActionService creates a new message action service +func NewMessageActionService(convRepo ConversationRepository) *MessageActionService { + return &MessageActionService{ + convRepo: convRepo, + } +} + +// EditResult contains the result of an edit message operation +type EditResult struct { + NewBranch string `json:"new_branch"` // Always "MAIN" after swap + OldMainBackup string `json:"old_main_backup"` // Backup name for old MAIN + UserItem *Item `json:"user_item"` + ConversationID string `json:"conversation_id"` +} + +// RegenerateResult contains the result of a regenerate operation +type RegenerateResult struct { + NewBranch string `json:"new_branch"` // Always "MAIN" after swap + OldMainBackup string `json:"old_main_backup"` // Backup name for old MAIN + ConvID string `json:"conversation_id"` + UserItemID string `json:"user_item_id"` // The user message to regenerate from +} + +// EditMessage creates a new branch from the edited message point +// It creates a fork of the conversation at the specified item with new content +func (s *MessageActionService) EditMessage(ctx context.Context, conv *Conversation, itemPublicID string, newContent string) (*EditResult, error) { + // Get the original item to verify it exists and is a user message + originalItem, err := s.convRepo.GetItemByPublicID(ctx, conv.ID, itemPublicID) + if err != nil { + return nil, platformerrors.AsError(ctx, 
platformerrors.LayerDomain, err, "item not found") + } + + // Verify it's a user message + if originalItem.Role == nil || *originalItem.Role != ItemRoleUser { + return nil, platformerrors.NewError(ctx, platformerrors.LayerDomain, platformerrors.ErrorTypeValidation, "can only edit user messages", nil, "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d") + } + + // Generate new branch name + newBranchName := GenerateEditBranchName(conv.ID) + + // Fork the branch at the item before this one (parent item) + // We need to find the previous item in the sequence + branchItems, err := s.convRepo.GetBranchItems(ctx, conv.ID, conv.ActiveBranch, nil) + if err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerDomain, err, "failed to get branch items") + } + + // Find the item and determine fork point (one item before) + forkFromItemID := "" + for i, item := range branchItems { + if item.PublicID == itemPublicID { + if i > 0 { + forkFromItemID = branchItems[i-1].PublicID + } + break + } + } + + // Create the new branch + now := time.Now() + description := "Edited message branch" + metadata := &BranchMetadata{ + Name: newBranchName, + Description: &description, + ParentBranch: &conv.ActiveBranch, + ForkedAt: &now, + ForkedFromItemID: &forkFromItemID, + ItemCount: 0, + CreatedAt: now, + UpdatedAt: now, + } + + // Fork branch copies items up to fork point + if forkFromItemID != "" { + if err := s.convRepo.ForkBranch(ctx, conv.ID, conv.ActiveBranch, newBranchName, forkFromItemID, &description); err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerDomain, err, "failed to fork branch") + } + } else { + // If no fork point (editing first message), just create empty branch + if err := s.convRepo.CreateBranch(ctx, conv.ID, newBranchName, metadata); err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerDomain, err, "failed to create branch") + } + } + + // Create new user item with edited content + newPublicID, err := 
idgen.GenerateSecureID("msg", 16) + if err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerDomain, err, "failed to generate item ID") + } + + // Get item count in new branch for sequence number + itemCount, err := s.convRepo.CountItems(ctx, conv.ID, newBranchName) + if err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerDomain, err, "failed to count items") + } + + newItem := &Item{ + PublicID: newPublicID, + Object: "conversation.item", + Type: ItemTypeMessage, + Role: originalItem.Role, + Content: []Content{{Type: "input_text", TextString: &newContent}}, + ConversationID: conv.ID, + Branch: newBranchName, + SequenceNumber: itemCount + 1, + CreatedAt: now, + } + + // Add the new item to the branch + if err := s.convRepo.AddItemToBranch(ctx, conv.ID, newBranchName, newItem); err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerDomain, err, "failed to add edited item") + } + + // Swap the new branch to become MAIN (old MAIN becomes a backup) + oldMainBackup, err := s.convRepo.SwapBranchToMain(ctx, conv.ID, newBranchName) + if err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerDomain, err, "failed to swap branch to MAIN") + } + + // Update the item's branch to MAIN after swap + newItem.Branch = "MAIN" + + return &EditResult{ + NewBranch: "MAIN", + OldMainBackup: oldMainBackup, + UserItem: newItem, + ConversationID: conv.PublicID, + }, nil +} + +// RegenerateResponse creates a new branch and prepares for regenerating the assistant response +// Returns the user message that should be used to regenerate a response +func (s *MessageActionService) RegenerateResponse(ctx context.Context, conv *Conversation, assistantItemPublicID string) (*RegenerateResult, error) { + // Get the assistant item + assistantItem, err := s.convRepo.GetItemByPublicID(ctx, conv.ID, assistantItemPublicID) + if err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerDomain, err, "item not 
found") + } + + // Verify it's an assistant message + if assistantItem.Role == nil || *assistantItem.Role != ItemRoleAssistant { + return nil, platformerrors.NewError(ctx, platformerrors.LayerDomain, platformerrors.ErrorTypeValidation, "can only regenerate assistant messages", nil, "b2c3d4e5-f6a7-4b8c-9d0e-1f2a3b4c5d6e") + } + + // Use the item's branch to find sibling items (not active branch, as item may be in a different branch) + itemBranch := assistantItem.Branch + if itemBranch == "" { + itemBranch = "MAIN" + } + + // Get all branch items to find the user message before this assistant message + branchItems, err := s.convRepo.GetBranchItems(ctx, conv.ID, itemBranch, nil) + if err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerDomain, err, "failed to get branch items") + } + + // Find the corresponding user message (the item right before the assistant message) + var userItem *Item + for i, item := range branchItems { + if item.PublicID == assistantItemPublicID { + // Look for the user message before this assistant message + for j := i - 1; j >= 0; j-- { + if branchItems[j].Role != nil && *branchItems[j].Role == ItemRoleUser { + userItem = branchItems[j] + break + } + } + break + } + } + + if userItem == nil { + return nil, platformerrors.NewError(ctx, platformerrors.LayerDomain, platformerrors.ErrorTypeNotFound, "user message not found for regeneration", nil, "c3d4e5f6-a7b8-4c9d-0e1f-2a3b4c5d6e7f") + } + + // Generate new branch name + newBranchName := GenerateRegenBranchName(conv.ID) + + // Fork at the user message (so we keep history up to and including user message) + // Use the item's branch as source, not active branch + description := "Regenerated response branch" + if err := s.convRepo.ForkBranch(ctx, conv.ID, itemBranch, newBranchName, userItem.PublicID, &description); err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerDomain, err, "failed to fork branch for regeneration") + } + + // Swap the new branch to 
become MAIN (old MAIN becomes a backup) + oldMainBackup, err := s.convRepo.SwapBranchToMain(ctx, conv.ID, newBranchName) + if err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerDomain, err, "failed to swap branch to MAIN") + } + + // Get the new user item ID from MAIN (it was copied during fork) + mainItems, err := s.convRepo.GetBranchItems(ctx, conv.ID, "MAIN", nil) + if err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerDomain, err, "failed to get MAIN items") + } + + // Find the last user message in MAIN (should be the one we forked to) + var newUserItemID string + for i := len(mainItems) - 1; i >= 0; i-- { + if mainItems[i].Role != nil && *mainItems[i].Role == ItemRoleUser { + newUserItemID = mainItems[i].PublicID + break + } + } + + return &RegenerateResult{ + NewBranch: "MAIN", + OldMainBackup: oldMainBackup, + ConvID: conv.PublicID, + UserItemID: newUserItemID, + }, nil +} + +// DeleteResult contains the result of a delete message operation +type DeleteResult struct { + NewBranch string `json:"new_branch"` // Always "MAIN" after swap + OldMainBackup string `json:"old_main_backup"` // Backup name for old MAIN +} + +// DeleteMessage deletes a message by creating a new branch without it +// The new branch becomes MAIN and the old MAIN is preserved as a backup +func (s *MessageActionService) DeleteMessage(ctx context.Context, conv *Conversation, itemPublicID string) (*DeleteResult, error) { + // Get the item to verify it exists + item, err := s.convRepo.GetItemByPublicID(ctx, conv.ID, itemPublicID) + if err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerDomain, err, "item not found") + } + + // Get the item's branch + itemBranch := item.Branch + if itemBranch == "" { + itemBranch = "MAIN" + } + + // Get all branch items to find the item before the one to delete + branchItems, err := s.convRepo.GetBranchItems(ctx, conv.ID, itemBranch, nil) + if err != nil { + return nil, 
platformerrors.AsError(ctx, platformerrors.LayerDomain, err, "failed to get branch items") + } + + // Find the item to delete and determine fork point (one item before) + forkFromItemID := "" + for i, branchItem := range branchItems { + if branchItem.PublicID == itemPublicID { + if i > 0 { + forkFromItemID = branchItems[i-1].PublicID + } + break + } + } + + // Generate new branch name + newBranchName := generateBranchNameWithPrefix(conv.ID, "DELETE") + + // Fork the branch at the item before the deleted one + description := "Deleted message branch" + if forkFromItemID != "" { + if err := s.convRepo.ForkBranch(ctx, conv.ID, itemBranch, newBranchName, forkFromItemID, &description); err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerDomain, err, "failed to fork branch for delete") + } + } else { + // If deleting the first message, create an empty branch + now := time.Now() + metadata := &BranchMetadata{ + Name: newBranchName, + Description: &description, + CreatedAt: now, + UpdatedAt: now, + } + if err := s.convRepo.CreateBranch(ctx, conv.ID, newBranchName, metadata); err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerDomain, err, "failed to create branch for delete") + } + } + + // Swap the new branch to become MAIN (old MAIN becomes a backup) + oldMainBackup, err := s.convRepo.SwapBranchToMain(ctx, conv.ID, newBranchName) + if err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerDomain, err, "failed to swap branch to MAIN") + } + + return &DeleteResult{ + NewBranch: "MAIN", + OldMainBackup: oldMainBackup, + }, nil +} + +// GenerateRegenBranchName generates a unique branch name for regenerated responses +func GenerateRegenBranchName(conversationID uint) string { + return generateBranchNameWithPrefix(conversationID, "REGEN") +} + +// generateBranchNameWithPrefix generates a unique branch name with a prefix +func generateBranchNameWithPrefix(conversationID uint, prefix string) string { + return 
prefix + "_" + time.Now().Format("20060102150405") +} diff --git a/services/llm-api/internal/domain/provider.go b/services/llm-api/internal/domain/provider.go index 1e7fac4b..f6f5ce0c 100644 --- a/services/llm-api/internal/domain/provider.go +++ b/services/llm-api/internal/domain/provider.go @@ -20,6 +20,7 @@ import ( var ServiceProvider = wire.NewSet( // Conversation domain conversation.NewConversationService, + conversation.NewMessageActionService, // Project domain project.NewProjectService, diff --git a/services/llm-api/internal/infrastructure/database/repository/conversationrepo/conversation_repository.go b/services/llm-api/internal/infrastructure/database/repository/conversationrepo/conversation_repository.go index 1784385b..98769a98 100644 --- a/services/llm-api/internal/infrastructure/database/repository/conversationrepo/conversation_repository.go +++ b/services/llm-api/internal/infrastructure/database/repository/conversationrepo/conversation_repository.go @@ -2,6 +2,7 @@ package conversationrepo import ( "context" + "time" "jan-server/services/llm-api/internal/domain/conversation" "jan-server/services/llm-api/internal/domain/query" @@ -9,6 +10,7 @@ import ( "jan-server/services/llm-api/internal/infrastructure/database/gormgen" "jan-server/services/llm-api/internal/infrastructure/database/transaction" "jan-server/services/llm-api/internal/utils/functional" + "jan-server/services/llm-api/internal/utils/idgen" "jan-server/services/llm-api/internal/utils/platformerrors" ) @@ -304,13 +306,17 @@ func (repo *ConversationGormRepository) DeleteItem(ctx context.Context, conversa func (repo *ConversationGormRepository) CountItems(ctx context.Context, conversationID uint, branchName string) (int, error) { q := repo.db.GetQuery(ctx) sql := q.ConversationItem.WithContext(ctx) - sql = repo.applyItemFilter(q, sql, conversation.ItemFilter{ + + // Apply filter with branch name for proper per-branch counting + filter := conversation.ItemFilter{ ConversationID: 
&conversationID, - }) + } + if branchName != "" { + filter.Branch = &branchName + } + sql = repo.applyItemFilter(q, sql, filter) - // For now, we count all items since branch filtering isn't fully implemented in gormgen count, err := sql.Count() - if err != nil { return 0, platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "failed to count items") } @@ -327,63 +333,126 @@ func (repo *ConversationGormRepository) CreateBranch(ctx context.Context, conver return platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "conversation not found") } - // TODO: Implement branch storage in database - // For now, return not implemented error - return platformerrors.NewError(ctx, platformerrors.LayerRepository, platformerrors.ErrorTypeNotImplemented, "branch operations not yet implemented in database layer", nil, "b4c5d6e7-f8a9-4b0c-1d2e-3f4a5b6c7d8e") + // Create branch in database + branch := dbschema.NewSchemaConversationBranch(conversationID, *metadata) + q := repo.db.GetQuery(ctx) + if err := q.ConversationBranch.WithContext(ctx).Create(branch); err != nil { + return platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "failed to create branch") + } + return nil } // GetBranch implements conversation.ConversationRepository. func (repo *ConversationGormRepository) GetBranch(ctx context.Context, conversationID uint, branchName string) (*conversation.BranchMetadata, error) { - return nil, platformerrors.NewError(ctx, platformerrors.LayerRepository, platformerrors.ErrorTypeNotImplemented, "branch operations not yet implemented in database layer", nil, "c5d6e7f8-a9b0-4c1d-2e3f-4a5b6c7d8e9f") + q := repo.db.GetQuery(ctx) + branch, err := q.ConversationBranch.WithContext(ctx). + Where(q.ConversationBranch.ConversationID.Eq(conversationID)). + Where(q.ConversationBranch.Name.Eq(branchName)). 
+ First() + if err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "branch not found") + } + result := branch.EtoD() + return &result, nil } // ListBranches implements conversation.ConversationRepository. func (repo *ConversationGormRepository) ListBranches(ctx context.Context, conversationID uint) ([]*conversation.BranchMetadata, error) { - return nil, platformerrors.NewError(ctx, platformerrors.LayerRepository, platformerrors.ErrorTypeNotImplemented, "branch operations not yet implemented in database layer", nil, "d6e7f8a9-b0c1-4d2e-3f4a-5b6c7d8e9f0a") + q := repo.db.GetQuery(ctx) + branches, err := q.ConversationBranch.WithContext(ctx). + Where(q.ConversationBranch.ConversationID.Eq(conversationID)). + Order(q.ConversationBranch.CreatedAt.Asc()). + Find() + if err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "failed to list branches") + } + + result := make([]*conversation.BranchMetadata, len(branches)) + for i, branch := range branches { + meta := branch.EtoD() + result[i] = &meta + } + return result, nil } // DeleteBranch implements conversation.ConversationRepository. func (repo *ConversationGormRepository) DeleteBranch(ctx context.Context, conversationID uint, branchName string) error { - return platformerrors.NewError(ctx, platformerrors.LayerRepository, platformerrors.ErrorTypeNotImplemented, "branch operations not yet implemented in database layer", nil, "e7f8a9b0-c1d2-4e3f-4a5b-6c7d8e9f0a1b") + // Don't allow deleting MAIN branch + if branchName == "MAIN" { + return platformerrors.NewError(ctx, platformerrors.LayerRepository, platformerrors.ErrorTypeValidation, "cannot delete MAIN branch", nil, "e7f8a9b0-c1d2-4e3f-4a5b-6c7d8e9f0a1b") + } + + q := repo.db.GetQuery(ctx) + + // Delete all items in this branch first + _, err := q.ConversationItem.WithContext(ctx). + Where(q.ConversationItem.ConversationID.Eq(conversationID)). + Where(q.ConversationItem.Branch.Eq(branchName)). 
+ Delete() + if err != nil { + return platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "failed to delete branch items") + } + + // Delete the branch metadata + _, err = q.ConversationBranch.WithContext(ctx). + Where(q.ConversationBranch.ConversationID.Eq(conversationID)). + Where(q.ConversationBranch.Name.Eq(branchName)). + Delete() + if err != nil { + return platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "failed to delete branch") + } + return nil } // SetActiveBranch implements conversation.ConversationRepository. func (repo *ConversationGormRepository) SetActiveBranch(ctx context.Context, conversationID uint, branchName string) error { - return platformerrors.NewError(ctx, platformerrors.LayerRepository, platformerrors.ErrorTypeNotImplemented, "branch operations not yet implemented in database layer", nil, "f8a9b0c1-d2e3-4f4a-5b6c-7d8e9f0a1b2c") + q := repo.db.GetQuery(ctx) + _, err := q.Conversation.WithContext(ctx). + Where(q.Conversation.ID.Eq(conversationID)). + Update(q.Conversation.ActiveBranch, branchName) + if err != nil { + return platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "failed to set active branch") + } + return nil } // Branch item operations // AddItemToBranch implements conversation.ConversationRepository. 
func (repo *ConversationGormRepository) AddItemToBranch(ctx context.Context, conversationID uint, branchName string, item *conversation.Item) error {
-	// For now, branch operations are not implemented
-	// Default to MAIN branch behavior
-	if branchName == "MAIN" || branchName == "" {
-		return repo.AddItem(ctx, conversationID, item)
+	// Tag the item with the target branch (an empty name means MAIN); this mutates the caller's item.
+	item.Branch = branchName
+	if branchName == "" {
+		item.Branch = "MAIN"
 	}
-	return platformerrors.NewError(ctx, platformerrors.LayerRepository, platformerrors.ErrorTypeNotImplemented, "branch operations not yet implemented in database layer", nil, "a9b0c1d2-e3f4-4a5b-6c7d-8e9f0a1b2c3d")
+	return repo.AddItem(ctx, conversationID, item)
 }
 
 // GetBranchItems implements conversation.ConversationRepository.
 func (repo *ConversationGormRepository) GetBranchItems(ctx context.Context, conversationID uint, branchName string, pagination *query.Pagination) ([]*conversation.Item, error) {
-	// For now, return items for MAIN branch with pagination support
-	if branchName == "MAIN" || branchName == "" {
-		q := repo.db.GetQuery(ctx)
-		sql := q.ConversationItem.WithContext(ctx)
-		sql = repo.applyItemFilter(q, sql, conversation.ItemFilter{
-			ConversationID: &conversationID,
-		})
-		sql = repo.applyItemPagination(q, sql, pagination)
+	// An empty branch name selects MAIN.
+	if branchName == "" {
+		branchName = "MAIN"
+	}
 
-		rows, err := sql.Find()
-		if err != nil {
-			return nil, platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "failed to get branch items")
-		}
+	q := repo.db.GetQuery(ctx)
+	sql := q.ConversationItem.WithContext(ctx)
+
+	// Restrict the query to this conversation and branch, then apply pagination.
+	filter := conversation.ItemFilter{
+		ConversationID: &conversationID,
+		Branch:         &branchName,
+	}
+	sql = repo.applyItemFilter(q, sql, filter)
+	sql = repo.applyItemPagination(q, sql, pagination)
 
-		return functional.Map(rows, func(item *dbschema.ConversationItem) *conversation.Item {
-			return item.EtoD()
-		}), nil
+	rows, err := sql.Find()
+	if err != nil {
+		return nil, platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "failed to get branch items")
 	}
-	return nil, platformerrors.NewError(ctx, platformerrors.LayerRepository, platformerrors.ErrorTypeNotImplemented, "branch operations not yet implemented in database layer", nil, "b0c1d2e3-f4a5-4b6c-7d8e-9f0a1b2c3d4e")
+
+	return functional.Map(rows, func(item *dbschema.ConversationItem) *conversation.Item {
+		return item.EtoD()
+	}), nil
 }
 
 // applyItemPagination applies pagination to item queries
@@ -418,35 +487,259 @@ func (repo *ConversationGormRepository) applyItemPagination(q *gormgen.Query, sq
 
 // BulkAddItemsToBranch implements conversation.ConversationRepository.
 func (repo *ConversationGormRepository) BulkAddItemsToBranch(ctx context.Context, conversationID uint, branchName string, items []*conversation.Item) error {
-	// For now, branch operations are not implemented
-	// Default to MAIN branch behavior
-	if branchName == "MAIN" || branchName == "" {
-		return repo.BulkAddItems(ctx, conversationID, items)
+	if len(items) == 0 {
+		return nil
+	}
+
+	// An empty branch name selects MAIN.
+	if branchName == "" {
+		branchName = "MAIN"
 	}
-	return platformerrors.NewError(ctx, platformerrors.LayerRepository, platformerrors.ErrorTypeNotImplemented, "branch operations not yet implemented in database layer", nil, "c1d2e3f4-a5b6-4c7d-8e9f-0a1b2c3d4e5f")
+
+	// Tag every item with the target branch; this mutates the caller's items.
+	for _, item := range items {
+		item.Branch = branchName
+	}
+
+	// Delegate the insert to BulkAddItems now that each item carries its branch.
+	return repo.BulkAddItems(ctx, conversationID, items)
 }
 
 // ForkBranch implements conversation.ConversationRepository.
func (repo *ConversationGormRepository) ForkBranch(ctx context.Context, conversationID uint, sourceBranch, newBranch string, fromItemID string, description *string) error { - return platformerrors.NewError(ctx, platformerrors.LayerRepository, platformerrors.ErrorTypeNotImplemented, "branch operations not yet implemented in database layer", nil, "d2e3f4a5-b6c7-4d8e-9f0a-1b2c3d4e5f6a") + // Get source branch items up to the fork point + sourceItems, err := repo.GetBranchItems(ctx, conversationID, sourceBranch, nil) + if err != nil { + return platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "failed to get source branch items") + } + + // Find fork point + forkIndex := -1 + for i, item := range sourceItems { + if item.PublicID == fromItemID { + forkIndex = i + break + } + } + + if forkIndex == -1 && fromItemID != "" { + return platformerrors.NewError(ctx, platformerrors.LayerRepository, platformerrors.ErrorTypeNotFound, "fork item not found", nil, "d2e3f4a5-b6c7-4d8e-9f0a-1b2c3d4e5f6a") + } + + // Create branch metadata + now := time.Now() + metadata := &conversation.BranchMetadata{ + Name: newBranch, + Description: description, + ParentBranch: &sourceBranch, + ForkedAt: &now, + ForkedFromItemID: &fromItemID, + ItemCount: 0, + CreatedAt: now, + UpdatedAt: now, + } + + if err := repo.CreateBranch(ctx, conversationID, newBranch, metadata); err != nil { + return platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "failed to create branch") + } + + // Copy items up to fork point to new branch + if forkIndex >= 0 { + itemsToCopy := make([]*conversation.Item, forkIndex+1) + for i := 0; i <= forkIndex; i++ { + itemCopy := *sourceItems[i] + itemCopy.ID = 0 // Reset ID for new insert + // Generate new PublicID for the copied item (PublicID has unique constraint) + newPublicID, err := idgen.GenerateSecureID("msg", 16) + if err != nil { + return platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "failed to generate item ID") + } + 
itemCopy.PublicID = newPublicID + itemCopy.Branch = newBranch + itemCopy.SequenceNumber = i + 1 + itemsToCopy[i] = &itemCopy + } + + if err := repo.BulkAddItemsToBranch(ctx, conversationID, newBranch, itemsToCopy); err != nil { + return platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "failed to copy items to new branch") + } + + // Update branch item count + q := repo.db.GetQuery(ctx) + _, err = q.ConversationBranch.WithContext(ctx). + Where(q.ConversationBranch.ConversationID.Eq(conversationID)). + Where(q.ConversationBranch.Name.Eq(newBranch)). + Update(q.ConversationBranch.ItemCount, len(itemsToCopy)) + if err != nil { + return platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "failed to update branch item count") + } + } + + return nil +} + +// SwapBranchToMain implements conversation.ConversationRepository. +// It promotes the given branch to become MAIN by: +// 1. Creating a backup for the old MAIN items (if they exist) +// 2. Renaming the given branch to MAIN +// 3. Setting MAIN as the active branch +func (repo *ConversationGormRepository) SwapBranchToMain(ctx context.Context, conversationID uint, branchToPromote string) (string, error) { + if branchToPromote == "MAIN" { + // Already MAIN, nothing to do + return "", nil + } + + q := repo.db.GetQuery(ctx) + + // Generate backup name for old MAIN + oldMainBackupName := "MAIN_" + time.Now().Format("20060102150405") + + // Check if MAIN branch record exists in the database + mainBranch, err := q.ConversationBranch.WithContext(ctx). + Where(q.ConversationBranch.ConversationID.Eq(conversationID)). + Where(q.ConversationBranch.Name.Eq("MAIN")). + First() + + if err == nil && mainBranch != nil { + // MAIN branch record exists - rename it to backup + _, err = q.ConversationBranch.WithContext(ctx). + Where(q.ConversationBranch.ConversationID.Eq(conversationID)). + Where(q.ConversationBranch.Name.Eq("MAIN")). 
+ Update(q.ConversationBranch.Name, oldMainBackupName) + if err != nil { + return "", platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "failed to rename MAIN branch to backup") + } + } else { + // MAIN branch record doesn't exist - create backup branch for existing MAIN items + // Count existing MAIN items + count, err := q.ConversationItem.WithContext(ctx). + Where(q.ConversationItem.ConversationID.Eq(conversationID)). + Where(q.ConversationItem.Branch.Eq("MAIN")). + Count() + if err != nil { + return "", platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "failed to count MAIN items") + } + + if count > 0 { + // Create a branch record for the backup + now := time.Now() + description := "Backup of original MAIN branch" + backupBranch := &dbschema.ConversationBranch{ + ConversationID: conversationID, + Name: oldMainBackupName, + Description: &description, + ItemCount: int(count), + } + backupBranch.CreatedAt = now + backupBranch.UpdatedAt = now + + if err := q.ConversationBranch.WithContext(ctx).Create(backupBranch); err != nil { + return "", platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "failed to create backup branch") + } + } else { + // No MAIN items exist, no backup needed + oldMainBackupName = "" + } + } + + // Update all items in old MAIN to use backup name (if backup was created) + if oldMainBackupName != "" { + _, err = q.ConversationItem.WithContext(ctx). + Where(q.ConversationItem.ConversationID.Eq(conversationID)). + Where(q.ConversationItem.Branch.Eq("MAIN")). + Update(q.ConversationItem.Branch, oldMainBackupName) + if err != nil { + return "", platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "failed to update MAIN items to backup branch") + } + } + + // Rename the promoted branch metadata to MAIN + _, err = q.ConversationBranch.WithContext(ctx). + Where(q.ConversationBranch.ConversationID.Eq(conversationID)). + Where(q.ConversationBranch.Name.Eq(branchToPromote)). 
+ Update(q.ConversationBranch.Name, "MAIN") + if err != nil { + return "", platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "failed to rename branch to MAIN") + } + + // Update all items in promoted branch to use MAIN + _, err = q.ConversationItem.WithContext(ctx). + Where(q.ConversationItem.ConversationID.Eq(conversationID)). + Where(q.ConversationItem.Branch.Eq(branchToPromote)). + Update(q.ConversationItem.Branch, "MAIN") + if err != nil { + return "", platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "failed to update promoted items to MAIN") + } + + // Set MAIN as active branch + _, err = q.Conversation.WithContext(ctx). + Where(q.Conversation.ID.Eq(conversationID)). + Update(q.Conversation.ActiveBranch, "MAIN") + if err != nil { + return "", platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "failed to set active branch to MAIN") + } + + return oldMainBackupName, nil } // Item rating operations // RateItem implements conversation.ConversationRepository. func (repo *ConversationGormRepository) RateItem(ctx context.Context, conversationID uint, itemID string, rating conversation.ItemRating, comment *string) error { - // TODO: Implement rating storage in database - // For now, return not implemented error - return platformerrors.NewError(ctx, platformerrors.LayerRepository, platformerrors.ErrorTypeNotImplemented, "rating operations not yet implemented in database layer", nil, "e3f4a5b6-c7d8-4e9f-0a1b-2c3d4e5f6a7b") + q := repo.db.GetQuery(ctx) + ratingStr := string(rating) + now := time.Now() + + updates := map[string]interface{}{ + "rating": ratingStr, + "rated_at": now, + } + if comment != nil { + updates["rating_comment"] = *comment + } + + _, err := q.ConversationItem.WithContext(ctx). + Where(q.ConversationItem.ConversationID.Eq(conversationID)). + Where(q.ConversationItem.PublicID.Eq(itemID)). 
+ Updates(updates) + if err != nil { + return platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "failed to rate item") + } + return nil } // GetItemRating implements conversation.ConversationRepository. func (repo *ConversationGormRepository) GetItemRating(ctx context.Context, conversationID uint, itemID string) (*conversation.ItemRating, error) { - return nil, platformerrors.NewError(ctx, platformerrors.LayerRepository, platformerrors.ErrorTypeNotImplemented, "rating operations not yet implemented in database layer", nil, "f4a5b6c7-d8e9-4f0a-1b2c-3d4e5f6a7b8c") + q := repo.db.GetQuery(ctx) + item, err := q.ConversationItem.WithContext(ctx). + Where(q.ConversationItem.ConversationID.Eq(conversationID)). + Where(q.ConversationItem.PublicID.Eq(itemID)). + First() + if err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "item not found") + } + if item.Rating == nil { + return nil, nil + } + rating := conversation.ItemRating(*item.Rating) + return &rating, nil } // RemoveItemRating implements conversation.ConversationRepository. func (repo *ConversationGormRepository) RemoveItemRating(ctx context.Context, conversationID uint, itemID string) error { - return platformerrors.NewError(ctx, platformerrors.LayerRepository, platformerrors.ErrorTypeNotImplemented, "rating operations not yet implemented in database layer", nil, "a5b6c7d8-e9f0-4a1b-2c3d-4e5f6a7b8c9d") + q := repo.db.GetQuery(ctx) + updates := map[string]interface{}{ + "rating": nil, + "rated_at": nil, + "rating_comment": nil, + } + _, err := q.ConversationItem.WithContext(ctx). + Where(q.ConversationItem.ConversationID.Eq(conversationID)). + Where(q.ConversationItem.PublicID.Eq(itemID)). 
+ Updates(updates) + if err != nil { + return platformerrors.AsError(ctx, platformerrors.LayerRepository, err, "failed to remove item rating") + } + return nil } // applyFilter applies filter conditions to the query @@ -485,6 +778,10 @@ func (repo *ConversationGormRepository) applyItemFilter(q *gormgen.Query, sql go if filter.ResponseID != nil { sql = sql.Where(q.ConversationItem.ResponseID.Eq(*filter.ResponseID)) } + // Filter by branch name + if filter.Branch != nil && *filter.Branch != "" { + sql = sql.Where(q.ConversationItem.Branch.Eq(*filter.Branch)) + } return sql } diff --git a/services/llm-api/internal/interfaces/httpserver/handlers/chathandler/chat_handler.go b/services/llm-api/internal/interfaces/httpserver/handlers/chathandler/chat_handler.go index 50aeefae..15f47336 100644 --- a/services/llm-api/internal/interfaces/httpserver/handlers/chathandler/chat_handler.go +++ b/services/llm-api/internal/interfaces/httpserver/handlers/chathandler/chat_handler.go @@ -1013,10 +1013,39 @@ func (h *ChatHandler) addCompletionToConversation( return nil } + // Use conversation's active branch instead of hardcoded MAIN + branchName := conv.ActiveBranch + if branchName == "" { + branchName = conversation.BranchMain + } + items := make([]conversation.Item, 0, 2) - if item := h.buildInputConversationItem(newMessages, storeReasoning, askItemID); item != nil { - items = append(items, *item) + // Build the user input item + userItem := h.buildInputConversationItem(newMessages, storeReasoning, askItemID) + + // Check if we should skip adding the user message (avoid duplicates after regenerate) + // This happens when regenerate creates a branch with the user message, then frontend + // triggers a new completion which would add the same user message again + if userItem != nil { + skipUserItem := false + + // Get the last item in the branch to check for duplicates + existingItems, err := h.conversationService.GetConversationItems(ctx, conv, branchName, nil) + if err == nil && 
len(existingItems) > 0 { + lastItem := existingItems[len(existingItems)-1] + // If the last item is a user message, check if it has the same content + if lastItem.Role != nil && *lastItem.Role == conversation.ItemRoleUser { + // Compare content - if it's the same, skip adding + if h.isSameMessageContent(userItem, &lastItem) { + skipUserItem = true + } + } + } + + if !skipUserItem { + items = append(items, *userItem) + } } if item := h.buildAssistantConversationItem(response, storeReasoning, completionItemID); item != nil { @@ -1036,13 +1065,44 @@ func (h *ChatHandler) addCompletionToConversation( return nil } - if _, err := h.conversationService.AddItemsToConversation(ctx, conv, conversation.BranchMain, items); err != nil { + if _, err := h.conversationService.AddItemsToConversation(ctx, conv, branchName, items); err != nil { return platformerrors.AsError(ctx, platformerrors.LayerHandler, err, "failed to add items to conversation") } return nil } +// isSameMessageContent checks if two items have the same text content +// Used to detect duplicate user messages after regenerate +func (h *ChatHandler) isSameMessageContent(newItem *conversation.Item, existingItem *conversation.Item) bool { + if newItem == nil || existingItem == nil { + return false + } + + // Extract text content from both items + newText := extractTextFromContent(newItem.Content) + existingText := extractTextFromContent(existingItem.Content) + + // Compare normalized text (trim whitespace) + return strings.TrimSpace(newText) == strings.TrimSpace(existingText) +} + +// extractTextFromContent extracts the text content from a slice of Content +func extractTextFromContent(contents []conversation.Content) string { + for _, c := range contents { + if c.TextString != nil && *c.TextString != "" { + return *c.TextString + } + if c.Text != nil && c.Text.Text != "" { + return c.Text.Text + } + if c.OutputText != nil && c.OutputText.Text != "" { + return c.OutputText.Text + } + } + return "" +} + func (h 
*ChatHandler) buildInputConversationItem( messages []openai.ChatCompletionMessage, storeReasoning bool, diff --git a/services/llm-api/internal/interfaces/httpserver/handlers/conversationhandler/branch_handler.go b/services/llm-api/internal/interfaces/httpserver/handlers/conversationhandler/branch_handler.go new file mode 100644 index 00000000..ad26bfb6 --- /dev/null +++ b/services/llm-api/internal/interfaces/httpserver/handlers/conversationhandler/branch_handler.go @@ -0,0 +1,323 @@ +package conversationhandler + +import ( + "context" + + "jan-server/services/llm-api/internal/domain/conversation" + "jan-server/services/llm-api/internal/utils/platformerrors" +) + +// BranchHandler handles branch-related HTTP requests +type BranchHandler struct { + conversationService *conversation.ConversationService + messageActionService *conversation.MessageActionService + repo conversation.ConversationRepository +} + +// NewBranchHandler creates a new branch handler +func NewBranchHandler( + conversationService *conversation.ConversationService, + messageActionService *conversation.MessageActionService, + repo conversation.ConversationRepository, +) *BranchHandler { + return &BranchHandler{ + conversationService: conversationService, + messageActionService: messageActionService, + repo: repo, + } +} + +// =============================================== +// Request/Response Types +// =============================================== + +// CreateBranchRequest represents the request to create a branch +type CreateBranchRequest struct { + Name string `json:"name" binding:"required"` + ParentBranch *string `json:"parent_branch,omitempty"` + ForkFromItemID *string `json:"fork_from_item_id,omitempty"` + Description *string `json:"description,omitempty"` +} + +// EditMessageRequest represents the request to edit a message +type EditMessageRequest struct { + Content string `json:"content" binding:"required"` + Regenerate *bool `json:"regenerate,omitempty"` // Auto-trigger new response 
(default: true) +} + +// RegenerateMessageRequest represents the request to regenerate a message +type RegenerateMessageRequest struct { + Model *string `json:"model,omitempty"` // Override model + Temperature *float32 `json:"temperature,omitempty"` // Override temperature + MaxTokens *int `json:"max_tokens,omitempty"` // Override max tokens +} + +// BranchResponse represents a branch in API responses +type BranchResponse struct { + Name string `json:"name"` + Description *string `json:"description,omitempty"` + ParentBranch *string `json:"parent_branch,omitempty"` + ForkedAt *int64 `json:"forked_at,omitempty"` + ForkedFromItemID *string `json:"forked_from_item_id,omitempty"` + ItemCount int `json:"item_count"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + IsActive bool `json:"is_active"` +} + +// ListBranchesResponse represents the response for listing branches +type ListBranchesResponse struct { + Object string `json:"object"` // "list" + Data []BranchResponse `json:"data"` + ActiveBranch string `json:"active_branch"` +} + +// EditMessageResponse represents the response for editing a message +type EditMessageResponse struct { + Branch string `json:"branch"` // Always "MAIN" after swap + OldMainBackup string `json:"old_main_backup"` // Backup name for old MAIN + BranchCreated bool `json:"branch_created"` + NewBranch *BranchResponse `json:"new_branch,omitempty"` + UserItem *conversation.Item `json:"user_item"` +} + +// RegenerateMessageResponse represents the response for regenerating a message +type RegenerateMessageResponse struct { + Branch string `json:"branch"` // Always "MAIN" after swap + OldMainBackup string `json:"old_main_backup"` // Backup name for old MAIN + BranchCreated bool `json:"branch_created"` + NewBranch *BranchResponse `json:"new_branch,omitempty"` + UserItemID string `json:"user_item_id"` +} + +// DeleteMessageResponse represents the response for deleting a message +type DeleteMessageResponse struct { + Branch 
string `json:"branch"` // Always "MAIN" after swap + OldMainBackup string `json:"old_main_backup"` // Backup name for old MAIN + BranchCreated bool `json:"branch_created"` + Deleted bool `json:"deleted"` +} + +// ActivateBranchResponse represents the response for activating a branch +type ActivateBranchResponse struct { + ActiveBranch string `json:"active_branch"` + Message string `json:"message"` +} + +// =============================================== +// Handler Methods +// =============================================== + +// ListBranches lists all branches for a conversation +func (h *BranchHandler) ListBranches(ctx context.Context, conv *conversation.Conversation) (*ListBranchesResponse, error) { + branches, err := h.repo.ListBranches(ctx, conv.ID) + if err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerHandler, err, "failed to list branches") + } + + data := make([]BranchResponse, len(branches)) + for i, branch := range branches { + data[i] = toBranchResponse(branch, conv.ActiveBranch) + } + + // If no branches exist, return MAIN as default + if len(data) == 0 { + data = []BranchResponse{{ + Name: "MAIN", + ItemCount: 0, + IsActive: true, + }} + } + + return &ListBranchesResponse{ + Object: "list", + Data: data, + ActiveBranch: conv.ActiveBranch, + }, nil +} + +// CreateBranch creates a new branch +func (h *BranchHandler) CreateBranch(ctx context.Context, conv *conversation.Conversation, req CreateBranchRequest) (*BranchResponse, error) { + // Validate branch name + if req.Name == "" { + return nil, platformerrors.NewError(ctx, platformerrors.LayerHandler, platformerrors.ErrorTypeValidation, "branch name is required", nil, "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d") + } + + if req.Name == "MAIN" { + return nil, platformerrors.NewError(ctx, platformerrors.LayerHandler, platformerrors.ErrorTypeValidation, "cannot create branch named MAIN", nil, "b2c3d4e5-f6a7-4b8c-9d0e-1f2a3b4c5d6e") + } + + // Set default parent branch + parentBranch := 
conv.ActiveBranch + if req.ParentBranch != nil && *req.ParentBranch != "" { + parentBranch = *req.ParentBranch + } + + // Fork the branch if fork point is specified + if req.ForkFromItemID != nil && *req.ForkFromItemID != "" { + if err := h.repo.ForkBranch(ctx, conv.ID, parentBranch, req.Name, *req.ForkFromItemID, req.Description); err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerHandler, err, "failed to fork branch") + } + } else { + // Create empty branch + metadata := conv.CreateBranchMetadata(req.Name, &parentBranch, nil, req.Description) + if err := h.repo.CreateBranch(ctx, conv.ID, req.Name, &metadata); err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerHandler, err, "failed to create branch") + } + } + + // Get the created branch + branch, err := h.repo.GetBranch(ctx, conv.ID, req.Name) + if err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerHandler, err, "failed to get created branch") + } + + response := toBranchResponse(branch, conv.ActiveBranch) + return &response, nil +} + +// GetBranch gets a branch by name +func (h *BranchHandler) GetBranch(ctx context.Context, conv *conversation.Conversation, branchName string) (*BranchResponse, error) { + branch, err := h.repo.GetBranch(ctx, conv.ID, branchName) + if err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerHandler, err, "branch not found") + } + + response := toBranchResponse(branch, conv.ActiveBranch) + return &response, nil +} + +// DeleteBranch deletes a branch +func (h *BranchHandler) DeleteBranch(ctx context.Context, conv *conversation.Conversation, branchName string) error { + // Normalize "main" to "MAIN" for case-insensitive matching + if branchName == "main" { + branchName = "MAIN" + } + + if branchName == "MAIN" { + return platformerrors.NewError(ctx, platformerrors.LayerHandler, platformerrors.ErrorTypeValidation, "cannot delete MAIN branch", nil, "c3d4e5f6-a7b8-4c9d-0e1f-2a3b4c5d6e7f") + } + + if 
branchName == conv.ActiveBranch { + return platformerrors.NewError(ctx, platformerrors.LayerHandler, platformerrors.ErrorTypeValidation, "cannot delete active branch", nil, "d4e5f6a7-b8c9-4d0e-1f2a-3b4c5d6e7f8a") + } + + if err := h.repo.DeleteBranch(ctx, conv.ID, branchName); err != nil { + return platformerrors.AsError(ctx, platformerrors.LayerHandler, err, "failed to delete branch") + } + + return nil +} + +// ActivateBranch sets a branch as active +func (h *BranchHandler) ActivateBranch(ctx context.Context, conv *conversation.Conversation, branchName string) (*ActivateBranchResponse, error) { + // Normalize "main" to "MAIN" for case-insensitive matching + if branchName == "main" { + branchName = "MAIN" + } + + // Verify branch exists (for non-MAIN branches) + if branchName != "MAIN" { + _, err := h.repo.GetBranch(ctx, conv.ID, branchName) + if err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerHandler, err, "branch not found") + } + } + + if err := h.repo.SetActiveBranch(ctx, conv.ID, branchName); err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerHandler, err, "failed to activate branch") + } + + return &ActivateBranchResponse{ + ActiveBranch: branchName, + Message: "Branch activated successfully", + }, nil +} + +// EditMessage edits a message and creates a new branch that becomes MAIN +func (h *BranchHandler) EditMessage(ctx context.Context, conv *conversation.Conversation, itemID string, req EditMessageRequest) (*EditMessageResponse, error) { + result, err := h.messageActionService.EditMessage(ctx, conv, itemID, req.Content) + if err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerHandler, err, "failed to edit message") + } + + response := &EditMessageResponse{ + Branch: result.NewBranch, // Always "MAIN" + OldMainBackup: result.OldMainBackup, + BranchCreated: true, // Edit always creates a new branch (which becomes MAIN) + UserItem: result.UserItem, + } + + // Fetch the MAIN branch 
details (the new branch was swapped to MAIN) + if branch, err := h.repo.GetBranch(ctx, conv.ID, "MAIN"); err == nil { + branchResp := toBranchResponse(branch, "MAIN") + response.NewBranch = &branchResp + } + + return response, nil +} + +// RegenerateMessage regenerates an assistant response, creating a new MAIN branch +func (h *BranchHandler) RegenerateMessage(ctx context.Context, conv *conversation.Conversation, itemID string, req RegenerateMessageRequest) (*RegenerateMessageResponse, error) { + result, err := h.messageActionService.RegenerateResponse(ctx, conv, itemID) + if err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerHandler, err, "failed to regenerate message") + } + + response := &RegenerateMessageResponse{ + Branch: result.NewBranch, // Always "MAIN" + OldMainBackup: result.OldMainBackup, + BranchCreated: true, // Regenerate always creates a new branch (which becomes MAIN) + UserItemID: result.UserItemID, + } + + // Fetch the MAIN branch details (the new branch was swapped to MAIN) + if branch, err := h.repo.GetBranch(ctx, conv.ID, "MAIN"); err == nil { + branchResp := toBranchResponse(branch, "MAIN") + response.NewBranch = &branchResp + } + + return response, nil +} + +// DeleteMessage deletes a message by creating a new MAIN branch without it +func (h *BranchHandler) DeleteMessage(ctx context.Context, conv *conversation.Conversation, itemID string) (*DeleteMessageResponse, error) { + result, err := h.messageActionService.DeleteMessage(ctx, conv, itemID) + if err != nil { + return nil, platformerrors.AsError(ctx, platformerrors.LayerHandler, err, "failed to delete message") + } + + return &DeleteMessageResponse{ + Branch: result.NewBranch, // Always "MAIN" + OldMainBackup: result.OldMainBackup, + BranchCreated: true, + Deleted: true, + }, nil +} +// =============================================== +// Helper Functions +// =============================================== + +func toBranchResponse(branch *conversation.BranchMetadata, 
activeBranch string) BranchResponse { + response := BranchResponse{ + Name: branch.Name, + Description: branch.Description, + ItemCount: branch.ItemCount, + CreatedAt: branch.CreatedAt.Unix(), + UpdatedAt: branch.UpdatedAt.Unix(), + IsActive: branch.Name == activeBranch, + } + + if branch.ParentBranch != nil { + response.ParentBranch = branch.ParentBranch + } + if branch.ForkedAt != nil { + ts := branch.ForkedAt.Unix() + response.ForkedAt = &ts + } + if branch.ForkedFromItemID != nil { + response.ForkedFromItemID = branch.ForkedFromItemID + } + + return response +} diff --git a/services/llm-api/internal/interfaces/httpserver/handlers/conversationhandler/conversation_handler.go b/services/llm-api/internal/interfaces/httpserver/handlers/conversationhandler/conversation_handler.go index c8974755..29781127 100644 --- a/services/llm-api/internal/interfaces/httpserver/handlers/conversationhandler/conversation_handler.go +++ b/services/llm-api/internal/interfaces/httpserver/handlers/conversationhandler/conversation_handler.go @@ -31,23 +31,26 @@ const ( // ConversationHandler handles conversation-related HTTP requests type ConversationHandler struct { - conversationService *conversation.ConversationService - projectService *project.ProjectService - itemValidator *conversation.ItemValidator - shareRepo share.ShareRepository + conversationService *conversation.ConversationService + messageActionService *conversation.MessageActionService + projectService *project.ProjectService + itemValidator *conversation.ItemValidator + shareRepo share.ShareRepository } // NewConversationHandler creates a new conversation handler func NewConversationHandler( conversationService *conversation.ConversationService, + messageActionService *conversation.MessageActionService, projectService *project.ProjectService, shareRepo share.ShareRepository, ) *ConversationHandler { return &ConversationHandler{ - conversationService: conversationService, - projectService: projectService, - itemValidator: 
conversation.NewItemValidator(conversation.DefaultItemValidationConfig()), - shareRepo: shareRepo, + conversationService: conversationService, + messageActionService: messageActionService, + projectService: projectService, + itemValidator: conversation.NewItemValidator(conversation.DefaultItemValidationConfig()), + shareRepo: shareRepo, } } @@ -276,6 +279,7 @@ func (h *ConversationHandler) ListItems( ctx context.Context, userID uint, conversationID string, + branchName *string, pagination *query.Pagination, ) ([]conversation.Item, error) { // Verify conversation ownership @@ -284,8 +288,14 @@ func (h *ConversationHandler) ListItems( return nil, platformerrors.AsError(ctx, platformerrors.LayerHandler, err, "failed to get conversation") } - // Get items from repository for the active branch - items, err := h.conversationService.GetConversationItems(ctx, conv, conv.ActiveBranch, pagination) + // Use specified branch or fall back to active branch + branch := conv.ActiveBranch + if branchName != nil && *branchName != "" { + branch = *branchName + } + + // Get items from repository for the specified branch + items, err := h.conversationService.GetConversationItems(ctx, conv, branch, pagination) if err != nil { return nil, platformerrors.AsError(ctx, platformerrors.LayerHandler, err, "failed to list items") } @@ -371,26 +381,39 @@ func (h *ConversationHandler) GetItem( return item, nil } -// DeleteItem deletes an item from a conversation +// DeleteItemResponse represents the response for deleting a message +type DeleteItemResponse struct { + Branch string `json:"branch"` // Always "MAIN" after swap + OldMainBackup string `json:"old_main_backup"` // Backup name for old MAIN + BranchCreated bool `json:"branch_created"` + Deleted bool `json:"deleted"` +} + +// DeleteItem deletes an item from a conversation by creating a new MAIN branch without it func (h *ConversationHandler) DeleteItem( ctx context.Context, userID uint, conversationID string, itemID string, -) 
(*conversationresponses.ConversationResponse, error) { +) (*DeleteItemResponse, error) { // Verify conversation ownership conv, err := h.conversationService.GetConversationByPublicIDAndUserID(ctx, conversationID, userID) if err != nil { return nil, platformerrors.AsError(ctx, platformerrors.LayerHandler, err, "failed to get conversation") } - // Delete item - if err := h.conversationService.DeleteConversationItem(ctx, conv, itemID); err != nil { + // Delete item using branch swap approach + result, err := h.messageActionService.DeleteMessage(ctx, conv, itemID) + if err != nil { return nil, platformerrors.AsError(ctx, platformerrors.LayerHandler, err, "failed to delete item") } - // Return the conversation (per OpenAI spec) - return conversationresponses.NewConversationResponse(conv), nil + return &DeleteItemResponse{ + Branch: result.NewBranch, + OldMainBackup: result.OldMainBackup, + BranchCreated: true, + Deleted: true, + }, nil } // UpdateItemByCallID updates an existing mcp_call item with tool execution results diff --git a/services/llm-api/internal/interfaces/httpserver/requests/conversation/conversation.go b/services/llm-api/internal/interfaces/httpserver/requests/conversation/conversation.go index 4b74795e..0549a10a 100644 --- a/services/llm-api/internal/interfaces/httpserver/requests/conversation/conversation.go +++ b/services/llm-api/internal/interfaces/httpserver/requests/conversation/conversation.go @@ -39,6 +39,7 @@ type ListItemsQueryParams struct { Include []string `form:"include"` Limit *int `form:"limit"` Order *string `form:"order"` + Branch *string `form:"branch"` // Filter by branch name (defaults to active branch) } // GetItemQueryParams represents query parameters for getting a single item diff --git a/services/llm-api/internal/interfaces/httpserver/routes/routes_provider.go b/services/llm-api/internal/interfaces/httpserver/routes/routes_provider.go index 1ea7ae41..43cc1f3a 100644 --- 
a/services/llm-api/internal/interfaces/httpserver/routes/routes_provider.go +++ b/services/llm-api/internal/interfaces/httpserver/routes/routes_provider.go @@ -40,6 +40,7 @@ var RouteProvider = wire.NewSet( handlers.ProvideMemoryHandler, chathandler.NewChatHandler, conversationhandler.NewConversationHandler, + conversationhandler.NewBranchHandler, guestauth.NewGuestHandler, guestauth.NewUpgradeHandler, modelhandler.NewProviderHandler, @@ -66,6 +67,7 @@ var RouteProvider = wire.NewSet( chat.NewChatRoute, chat.NewChatCompletionRoute, conversation.NewConversationRoute, + conversation.NewBranchRoute, projects.NewProjectRoute, model.NewModelRoute, modelProvider.NewModelProviderRoute, diff --git a/services/llm-api/internal/interfaces/httpserver/routes/v1/conversation/branch_route.go b/services/llm-api/internal/interfaces/httpserver/routes/v1/conversation/branch_route.go new file mode 100644 index 00000000..042e3513 --- /dev/null +++ b/services/llm-api/internal/interfaces/httpserver/routes/v1/conversation/branch_route.go @@ -0,0 +1,307 @@ +package conversation + +import ( + "net/http" + + "github.com/gin-gonic/gin" + + "jan-server/services/llm-api/internal/interfaces/httpserver/handlers/authhandler" + "jan-server/services/llm-api/internal/interfaces/httpserver/handlers/conversationhandler" + "jan-server/services/llm-api/internal/interfaces/httpserver/responses" + "jan-server/services/llm-api/internal/utils/platformerrors" +) + +type BranchRoute struct { + handler *conversationhandler.ConversationHandler + branchHandler *conversationhandler.BranchHandler + authHandler *authhandler.AuthHandler +} + +func NewBranchRoute( + handler *conversationhandler.ConversationHandler, + branchHandler *conversationhandler.BranchHandler, + authHandler *authhandler.AuthHandler, +) *BranchRoute { + return &BranchRoute{ + handler: handler, + branchHandler: branchHandler, + authHandler: authHandler, + } +} + +func (route *BranchRoute) RegisterRouter(router gin.IRouter) { + conversations := 
router.Group("/conversations") + + // Branch CRUD endpoints + conversations.GET("/:conv_public_id/branches", route.authHandler.WithAppUserAuthChain(route.handler.ConversationMiddleware(), route.listBranches)...) + conversations.POST("/:conv_public_id/branches", route.authHandler.WithAppUserAuthChain(route.handler.ConversationMiddleware(), route.createBranch)...) + conversations.GET("/:conv_public_id/branches/:branch_name", route.authHandler.WithAppUserAuthChain(route.handler.ConversationMiddleware(), route.getBranch)...) + conversations.DELETE("/:conv_public_id/branches/:branch_name", route.authHandler.WithAppUserAuthChain(route.handler.ConversationMiddleware(), route.deleteBranch)...) + conversations.POST("/:conv_public_id/branches/:branch_name/activate", route.authHandler.WithAppUserAuthChain(route.handler.ConversationMiddleware(), route.activateBranch)...) + + // Message action endpoints + conversations.POST("/:conv_public_id/items/:item_id/edit", route.authHandler.WithAppUserAuthChain(route.handler.ConversationMiddleware(), route.editMessage)...) + conversations.POST("/:conv_public_id/items/:item_id/regenerate", route.authHandler.WithAppUserAuthChain(route.handler.ConversationMiddleware(), route.regenerateMessage)...) 
+} + +// listBranches godoc +// @Summary List branches +// @Description List all branches for a conversation +// @Tags Conversation Branches +// @Security BearerAuth +// @Produce json +// @Param conv_public_id path string true "Conversation ID (format: conv_xxxxx)" +// @Success 200 {object} conversationhandler.ListBranchesResponse "Successfully retrieved branches" +// @Failure 401 {object} responses.ErrorResponse "Unauthorized" +// @Failure 404 {object} responses.ErrorResponse "Conversation not found" +// @Router /v1/conversations/{conv_public_id}/branches [get] +func (route *BranchRoute) listBranches(reqCtx *gin.Context) { + ctx := reqCtx.Request.Context() + + conv, ok := conversationhandler.GetConversationFromContext(reqCtx) + if !ok { + responses.HandleNewError(reqCtx, platformerrors.ErrorTypeInternal, "conversation not found in context", "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d") + return + } + + response, err := route.branchHandler.ListBranches(ctx, conv) + if err != nil { + responses.HandleError(reqCtx, err, "Failed to list branches") + return + } + + reqCtx.JSON(http.StatusOK, response) +} + +// createBranch godoc +// @Summary Create a branch +// @Description Create a new branch in a conversation, optionally forking from an existing item +// @Tags Conversation Branches +// @Security BearerAuth +// @Accept json +// @Produce json +// @Param conv_public_id path string true "Conversation ID (format: conv_xxxxx)" +// @Param request body conversationhandler.CreateBranchRequest true "Create branch request" +// @Success 201 {object} conversationhandler.BranchResponse "Successfully created branch" +// @Failure 400 {object} responses.ErrorResponse "Invalid request" +// @Failure 401 {object} responses.ErrorResponse "Unauthorized" +// @Failure 404 {object} responses.ErrorResponse "Conversation not found" +// @Router /v1/conversations/{conv_public_id}/branches [post] +func (route *BranchRoute) createBranch(reqCtx *gin.Context) { + ctx := reqCtx.Request.Context() + + conv, 
ok := conversationhandler.GetConversationFromContext(reqCtx) + if !ok { + responses.HandleNewError(reqCtx, platformerrors.ErrorTypeInternal, "conversation not found in context", "b2c3d4e5-f6a7-4b8c-9d0e-1f2a3b4c5d6e") + return + } + + var req conversationhandler.CreateBranchRequest + if err := reqCtx.ShouldBindJSON(&req); err != nil { + responses.HandleNewError(reqCtx, platformerrors.ErrorTypeValidation, "invalid request body", "c3d4e5f6-a7b8-4c9d-0e1f-2a3b4c5d6e7f") + return + } + + response, err := route.branchHandler.CreateBranch(ctx, conv, req) + if err != nil { + responses.HandleError(reqCtx, err, "Failed to create branch") + return + } + + reqCtx.JSON(http.StatusCreated, response) +} + +// getBranch godoc +// @Summary Get branch details +// @Description Get details of a specific branch +// @Tags Conversation Branches +// @Security BearerAuth +// @Produce json +// @Param conv_public_id path string true "Conversation ID (format: conv_xxxxx)" +// @Param branch_name path string true "Branch name" +// @Success 200 {object} conversationhandler.BranchResponse "Successfully retrieved branch" +// @Failure 401 {object} responses.ErrorResponse "Unauthorized" +// @Failure 404 {object} responses.ErrorResponse "Branch not found" +// @Router /v1/conversations/{conv_public_id}/branches/{branch_name} [get] +func (route *BranchRoute) getBranch(reqCtx *gin.Context) { + ctx := reqCtx.Request.Context() + + conv, ok := conversationhandler.GetConversationFromContext(reqCtx) + if !ok { + responses.HandleNewError(reqCtx, platformerrors.ErrorTypeInternal, "conversation not found in context", "d4e5f6a7-b8c9-4d0e-1f2a-3b4c5d6e7f8a") + return + } + + branchName := reqCtx.Param("branch_name") + if branchName == "" { + responses.HandleNewError(reqCtx, platformerrors.ErrorTypeValidation, "branch name is required", "e5f6a7b8-c9d0-4e1f-2a3b-4c5d6e7f8a9b") + return + } + + response, err := route.branchHandler.GetBranch(ctx, conv, branchName) + if err != nil { + responses.HandleError(reqCtx, 
err, "Failed to get branch") + return + } + + reqCtx.JSON(http.StatusOK, response) +} + +// deleteBranch godoc +// @Summary Delete a branch +// @Description Delete a branch from a conversation (cannot delete MAIN or active branch) +// @Tags Conversation Branches +// @Security BearerAuth +// @Param conv_public_id path string true "Conversation ID (format: conv_xxxxx)" +// @Param branch_name path string true "Branch name" +// @Success 204 "Branch deleted successfully" +// @Failure 400 {object} responses.ErrorResponse "Cannot delete MAIN or active branch" +// @Failure 401 {object} responses.ErrorResponse "Unauthorized" +// @Failure 404 {object} responses.ErrorResponse "Branch not found" +// @Router /v1/conversations/{conv_public_id}/branches/{branch_name} [delete] +func (route *BranchRoute) deleteBranch(reqCtx *gin.Context) { + ctx := reqCtx.Request.Context() + + conv, ok := conversationhandler.GetConversationFromContext(reqCtx) + if !ok { + responses.HandleNewError(reqCtx, platformerrors.ErrorTypeInternal, "conversation not found in context", "f6a7b8c9-d0e1-4f2a-3b4c-5d6e7f8a9b0c") + return + } + + branchName := reqCtx.Param("branch_name") + if branchName == "" { + responses.HandleNewError(reqCtx, platformerrors.ErrorTypeValidation, "branch name is required", "a7b8c9d0-e1f2-4a3b-4c5d-6e7f8a9b0c1d") + return + } + + if err := route.branchHandler.DeleteBranch(ctx, conv, branchName); err != nil { + responses.HandleError(reqCtx, err, "Failed to delete branch") + return + } + + reqCtx.Status(http.StatusNoContent) +} + +// activateBranch godoc +// @Summary Activate a branch +// @Description Set a branch as the active branch for a conversation +// @Tags Conversation Branches +// @Security BearerAuth +// @Produce json +// @Param conv_public_id path string true "Conversation ID (format: conv_xxxxx)" +// @Param branch_name path string true "Branch name" +// @Success 200 {object} conversationhandler.ActivateBranchResponse "Branch activated successfully" +// @Failure 401 
{object} responses.ErrorResponse "Unauthorized" +// @Failure 404 {object} responses.ErrorResponse "Branch not found" +// @Router /v1/conversations/{conv_public_id}/branches/{branch_name}/activate [post] +func (route *BranchRoute) activateBranch(reqCtx *gin.Context) { + ctx := reqCtx.Request.Context() + + conv, ok := conversationhandler.GetConversationFromContext(reqCtx) + if !ok { + responses.HandleNewError(reqCtx, platformerrors.ErrorTypeInternal, "conversation not found in context", "b8c9d0e1-f2a3-4b4c-5d6e-7f8a9b0c1d2e") + return + } + + branchName := reqCtx.Param("branch_name") + if branchName == "" { + responses.HandleNewError(reqCtx, platformerrors.ErrorTypeValidation, "branch name is required", "c9d0e1f2-a3b4-4c5d-6e7f-8a9b0c1d2e3f") + return + } + + response, err := route.branchHandler.ActivateBranch(ctx, conv, branchName) + if err != nil { + responses.HandleError(reqCtx, err, "Failed to activate branch") + return + } + + reqCtx.JSON(http.StatusOK, response) +} + +// editMessage godoc +// @Summary Edit a message +// @Description Edit a user message and create a new branch with the edited content +// @Tags Message Actions +// @Security BearerAuth +// @Accept json +// @Produce json +// @Param conv_public_id path string true "Conversation ID (format: conv_xxxxx)" +// @Param item_id path string true "Message item ID (format: msg_xxxxx)" +// @Param request body conversationhandler.EditMessageRequest true "Edit message request" +// @Success 200 {object} conversationhandler.EditMessageResponse "Message edited successfully" +// @Failure 400 {object} responses.ErrorResponse "Invalid request or not a user message" +// @Failure 401 {object} responses.ErrorResponse "Unauthorized" +// @Failure 404 {object} responses.ErrorResponse "Message not found" +// @Router /v1/conversations/{conv_public_id}/items/{item_id}/edit [post] +func (route *BranchRoute) editMessage(reqCtx *gin.Context) { + ctx := reqCtx.Request.Context() + + conv, ok := 
conversationhandler.GetConversationFromContext(reqCtx) + if !ok { + responses.HandleNewError(reqCtx, platformerrors.ErrorTypeInternal, "conversation not found in context", "d0e1f2a3-b4c5-4d6e-7f8a-9b0c1d2e3f4a") + return + } + + itemID := reqCtx.Param("item_id") + if itemID == "" { + responses.HandleNewError(reqCtx, platformerrors.ErrorTypeValidation, "item ID is required", "e1f2a3b4-c5d6-4e7f-8a9b-0c1d2e3f4a5b") + return + } + + var req conversationhandler.EditMessageRequest + if err := reqCtx.ShouldBindJSON(&req); err != nil { + responses.HandleNewError(reqCtx, platformerrors.ErrorTypeValidation, "invalid request body", "f2a3b4c5-d6e7-4f8a-9b0c-1d2e3f4a5b6c") + return + } + + response, err := route.branchHandler.EditMessage(ctx, conv, itemID, req) + if err != nil { + responses.HandleError(reqCtx, err, "Failed to edit message") + return + } + + reqCtx.JSON(http.StatusOK, response) +} + +// regenerateMessage godoc +// @Summary Regenerate a response +// @Description Regenerate an assistant response by creating a new branch +// @Tags Message Actions +// @Security BearerAuth +// @Accept json +// @Produce json +// @Param conv_public_id path string true "Conversation ID (format: conv_xxxxx)" +// @Param item_id path string true "Assistant message item ID (format: msg_xxxxx)" +// @Param request body conversationhandler.RegenerateMessageRequest false "Regenerate options" +// @Success 200 {object} conversationhandler.RegenerateMessageResponse "Regeneration initiated" +// @Failure 400 {object} responses.ErrorResponse "Invalid request or not an assistant message" +// @Failure 401 {object} responses.ErrorResponse "Unauthorized" +// @Failure 404 {object} responses.ErrorResponse "Message not found" +// @Router /v1/conversations/{conv_public_id}/items/{item_id}/regenerate [post] +func (route *BranchRoute) regenerateMessage(reqCtx *gin.Context) { + ctx := reqCtx.Request.Context() + + conv, ok := conversationhandler.GetConversationFromContext(reqCtx) + if !ok { + 
responses.HandleNewError(reqCtx, platformerrors.ErrorTypeInternal, "conversation not found in context", "a3b4c5d6-e7f8-4a9b-0c1d-2e3f4a5b6c7d") + return + } + + itemID := reqCtx.Param("item_id") + if itemID == "" { + responses.HandleNewError(reqCtx, platformerrors.ErrorTypeValidation, "item ID is required", "b4c5d6e7-f8a9-4b0c-1d2e-3f4a5b6c7d8e") + return + } + + var req conversationhandler.RegenerateMessageRequest + // Body is optional for regenerate + _ = reqCtx.ShouldBindJSON(&req) + + response, err := route.branchHandler.RegenerateMessage(ctx, conv, itemID, req) + if err != nil { + responses.HandleError(reqCtx, err, "Failed to regenerate message") + return + } + + reqCtx.JSON(http.StatusOK, response) +} diff --git a/services/llm-api/internal/interfaces/httpserver/routes/v1/conversation/conversation_route.go b/services/llm-api/internal/interfaces/httpserver/routes/v1/conversation/conversation_route.go index 5fd33208..678db7b2 100644 --- a/services/llm-api/internal/interfaces/httpserver/routes/v1/conversation/conversation_route.go +++ b/services/llm-api/internal/interfaces/httpserver/routes/v1/conversation/conversation_route.go @@ -404,8 +404,8 @@ func (route *ConversationRoute) listItems(reqCtx *gin.Context) { fetchLimit := requestedLimit + 1 pagination.Limit = &fetchLimit - // Get items from handler - items, err := route.handler.ListItems(ctx, user.ID, conv.PublicID, pagination) + // Get items from handler with optional branch filter + items, err := route.handler.ListItems(ctx, user.ID, conv.PublicID, params.Branch, pagination) if err != nil { responses.HandleError(reqCtx, err, "Failed to list items") return @@ -556,28 +556,28 @@ func (route *ConversationRoute) getItem(reqCtx *gin.Context) { // deleteItem godoc // @Summary Delete a conversation item -// @Description Delete an item from a conversation. The item will be removed from the conversation. +// @Description Delete an item from a conversation by creating a new MAIN branch without it. 
+// @Description The old MAIN branch is preserved as a backup. // @Description // @Description **Features:** -// @Description - Remove specific item from conversation +// @Description - Creates a new branch without the deleted item +// @Description - New branch becomes MAIN, old MAIN becomes backup // @Description - Automatic ownership verification -// @Description - Returns updated conversation object after deletion -// @Description - Items are permanently removed (not soft delete) +// @Description - Preserves conversation history in backup branch // @Description // @Description **Important:** -// @Description - Deleting an item may affect conversation flow -// @Description - Item IDs are not reused after deletion -// @Description - Other items in conversation remain unchanged -// @Description - Consider creating a new branch instead of deleting items +// @Description - The old MAIN branch is renamed to MAIN_YYYYMMDDHHMMSS +// @Description - You can switch back to the backup branch if needed +// @Description - This is a non-destructive delete operation // @Description // @Description **Response:** -// @Description Returns the conversation object (not the deleted item) +// @Description Returns branch information including the backup branch name // @Tags Conversations API // @Security BearerAuth // @Produce json // @Param conv_public_id path string true "Conversation ID (format: conv_xxxxx)" // @Param item_id path string true "Item ID to delete (format: msg_xxxxx)" -// @Success 200 {object} conversationresponses.ConversationResponse "Successfully deleted item, returns conversation" +// @Success 200 {object} conversationhandler.DeleteItemResponse "Successfully deleted item, returns branch info" // @Failure 400 {object} responses.ErrorResponse "Invalid conversation ID or item ID format" // @Failure 401 {object} responses.ErrorResponse "Unauthorized - missing or invalid authentication" // @Failure 404 {object} responses.ErrorResponse "Conversation or item not found, or 
access denied" diff --git a/services/llm-api/internal/interfaces/httpserver/routes/v1/v1_route.go b/services/llm-api/internal/interfaces/httpserver/routes/v1/v1_route.go index 58e82e3f..a4fd699e 100644 --- a/services/llm-api/internal/interfaces/httpserver/routes/v1/v1_route.go +++ b/services/llm-api/internal/interfaces/httpserver/routes/v1/v1_route.go @@ -21,6 +21,7 @@ type V1Route struct { model *model.ModelRoute chat *chat.ChatRoute conversation *conversation.ConversationRoute + branch *conversation.BranchRoute project *projects.ProjectRoute adminRoute *admin.AdminRoute users *users.UsersRoute @@ -33,6 +34,7 @@ func NewV1Route( model *model.ModelRoute, chat *chat.ChatRoute, conversation *conversation.ConversationRoute, + branch *conversation.BranchRoute, project *projects.ProjectRoute, adminRoute *admin.AdminRoute, users *users.UsersRoute, @@ -44,6 +46,7 @@ func NewV1Route( model, chat, conversation, + branch, project, adminRoute, users, @@ -63,6 +66,7 @@ func (v1Route *V1Route) RegisterRouter(router gin.IRouter) { v1Route.model.RegisterRouter(v1Router) v1Route.chat.RegisterRouter(v1Router) v1Route.conversation.RegisterRouter(v1Router) + v1Route.branch.RegisterRouter(v1Router) v1Route.project.RegisterRoutes(v1Router) v1Route.users.RegisterRouter(v1Router) diff --git a/tests/automation/conversations-postman-scripts.json b/tests/automation/conversations-postman-scripts.json index 50d2a1c0..c914277e 100644 --- a/tests/automation/conversations-postman-scripts.json +++ b/tests/automation/conversations-postman-scripts.json @@ -2111,6 +2111,724 @@ } ] }, + { + "name": "Conversation Branching", + "description": "Tests for conversation branching: create branches, list branches, switch active branch, get branch items, edit message (creates branch), and delete branch", + "item": [ + { + "name": "Branch-1: Create Conversation for Branch Tests", + "request": { + "method": "POST", + "header": [ + { + "key": "Content-Type", + "value": "application/json" + }, + { + "key": 
"Authorization", + "value": "Bearer {{access_token}}" + } + ], + "body": { + "mode": "raw", + "raw": "{\n \"title\": \"Branch Test Conversation\"\n}" + }, + "url": "{{kong_url}}/v1/conversations" + }, + "event": [ + { + "listen": "test", + "script": { + "type": "text/javascript", + "exec": [ + "pm.test('[Branch-1] Branch test conversation created', function () {", + " pm.expect(pm.response.code).to.be.oneOf([200, 201]);", + " const response = pm.response.json();", + " pm.collectionVariables.set('branch_test_conv_id', response.id);", + " console.log('Created branch test conversation:', response.id);", + "});" + ] + } + } + ] + }, + { + "name": "Branch-2: Add Messages to Create History", + "request": { + "method": "POST", + "header": [ + { + "key": "Content-Type", + "value": "application/json" + }, + { + "key": "Authorization", + "value": "Bearer {{access_token}}" + } + ], + "body": { + "mode": "raw", + "raw": "{\n \"model\": \"{{model_id}}\",\n \"messages\": [\n {\n \"role\": \"user\",\n \"content\": \"What is 2 + 2?\"\n }\n ],\n \"conversation\": {\n \"id\": \"{{branch_test_conv_id}}\"\n },\n \"stream\": false,\n \"max_tokens\": 100\n}" + }, + "url": "{{kong_url}}/v1/chat/completions" + }, + "event": [ + { + "listen": "test", + "script": { + "type": "text/javascript", + "exec": [ + "pm.test('[Branch-2] First message exchange created', function () {", + " pm.response.to.have.status(200);", + " const response = pm.response.json();", + " pm.expect(response.choices).to.be.an('array').that.is.not.empty;", + " console.log('✓ First user/assistant exchange created');", + "});" + ] + } + } + ] + }, + { + "name": "Branch-3: Add Second Message Exchange", + "request": { + "method": "POST", + "header": [ + { + "key": "Content-Type", + "value": "application/json" + }, + { + "key": "Authorization", + "value": "Bearer {{access_token}}" + } + ], + "body": { + "mode": "raw", + "raw": "{\n \"model\": \"{{model_id}}\",\n \"messages\": [\n {\n \"role\": \"user\",\n \"content\": \"Now 
what is 3 + 3?\"\n }\n ],\n \"conversation\": {\n \"id\": \"{{branch_test_conv_id}}\"\n },\n \"stream\": false,\n \"max_tokens\": 100\n}" + }, + "url": "{{kong_url}}/v1/chat/completions" + }, + "event": [ + { + "listen": "test", + "script": { + "type": "text/javascript", + "exec": [ + "pm.test('[Branch-3] Second message exchange created', function () {", + " pm.response.to.have.status(200);", + " console.log('✓ Second user/assistant exchange created');", + "});" + ] + } + } + ] + }, + { + "name": "Branch-4: Get Conversation Items (MAIN branch)", + "request": { + "method": "GET", + "header": [ + { + "key": "Authorization", + "value": "Bearer {{access_token}}" + } + ], + "url": { + "raw": "{{kong_url}}/v1/conversations/{{branch_test_conv_id}}/items", + "host": ["{{kong_url}}"], + "path": ["v1", "conversations", "{{branch_test_conv_id}}", "items"] + } + }, + "event": [ + { + "listen": "test", + "script": { + "type": "text/javascript", + "exec": [ + "pm.test('[Branch-4] Items retrieved from MAIN branch', function () {", + " pm.response.to.have.status(200);", + " const response = pm.response.json();", + " pm.expect(response).to.have.property('data');", + " pm.expect(response.data).to.be.an('array');", + " pm.expect(response.data.length).to.be.at.least(4);", + " console.log('Found ' + response.data.length + ' items in MAIN branch');", + "});", + "", + "// Use data[3] which should be first user message (desc order: 0=assistant2, 1=user2, 2=assistant1, 3=user1)", + "pm.collectionVariables.set('edit_target_item_id', pm.response.json().data[3].id);", + "pm.collectionVariables.set('regen_target_item_id', pm.response.json().data[2].id);" + ] + } + } + ] + }, + { + "name": "Branch-5: List Branches (should only have MAIN)", + "request": { + "method": "GET", + "header": [ + { + "key": "Authorization", + "value": "Bearer {{access_token}}" + } + ], + "url": { + "raw": "{{kong_url}}/v1/conversations/{{branch_test_conv_id}}/branches", + "host": ["{{kong_url}}"], + "path": ["v1", 
"conversations", "{{branch_test_conv_id}}", "branches"] + } + }, + "event": [ + { + "listen": "test", + "script": { + "type": "text/javascript", + "exec": [ + "pm.test('[Branch-5] Branches list retrieved', function () {", + " pm.response.to.have.status(200);", + " const response = pm.response.json();", + " pm.expect(response).to.have.property('data');", + " pm.expect(response.data).to.be.an('array');", + "});", + "", + "pm.test('[Branch-5] MAIN branch exists', function () {", + " const response = pm.response.json();", + " const mainBranch = response.data.find(b => b.name === 'main' || b.name === 'MAIN');", + " pm.expect(mainBranch).to.exist;", + " pm.expect(mainBranch.is_active).to.be.true;", + " console.log('✓ MAIN branch is active');", + "});" + ] + } + } + ] + }, + { + "name": "Branch-6: Create New Branch", + "request": { + "method": "POST", + "header": [ + { + "key": "Content-Type", + "value": "application/json" + }, + { + "key": "Authorization", + "value": "Bearer {{access_token}}" + } + ], + "body": { + "mode": "raw", + "raw": "{\n \"name\": \"edit-branch-1\",\n \"parent_branch\": \"MAIN\",\n \"fork_from_item_id\": \"{{edit_target_item_id}}\"\n}" + }, + "url": { + "raw": "{{kong_url}}/v1/conversations/{{branch_test_conv_id}}/branches", + "host": ["{{kong_url}}"], + "path": ["v1", "conversations", "{{branch_test_conv_id}}", "branches"] + } + }, + "event": [ + { + "listen": "test", + "script": { + "type": "text/javascript", + "exec": [ + "pm.test('[Branch-6] New branch created', function () {", + " pm.expect(pm.response.code).to.be.oneOf([200, 201]);", + " const response = pm.response.json();", + " pm.expect(response).to.have.property('name', 'edit-branch-1');", + " pm.collectionVariables.set('new_branch_name', response.name);", + " console.log('✓ Created branch:', response.name);", + "});" + ] + } + } + ] + }, + { + "name": "Branch-7: List Branches (should have 2)", + "request": { + "method": "GET", + "header": [ + { + "key": "Authorization", + "value": 
"Bearer {{access_token}}" + } + ], + "url": { + "raw": "{{kong_url}}/v1/conversations/{{branch_test_conv_id}}/branches", + "host": ["{{kong_url}}"], + "path": ["v1", "conversations", "{{branch_test_conv_id}}", "branches"] + } + }, + "event": [ + { + "listen": "test", + "script": { + "type": "text/javascript", + "exec": [ + "pm.test('[Branch-7] Two branches now exist', function () {", + " pm.response.to.have.status(200);", + " const response = pm.response.json();", + " pm.expect(response.data.length).to.equal(2);", + " console.log('✓ Conversation now has 2 branches');", + "});" + ] + } + } + ] + }, + { + "name": "Branch-8: Activate New Branch", + "request": { + "method": "POST", + "header": [ + { + "key": "Authorization", + "value": "Bearer {{access_token}}" + } + ], + "url": { + "raw": "{{kong_url}}/v1/conversations/{{branch_test_conv_id}}/branches/{{new_branch_name}}/activate", + "host": ["{{kong_url}}"], + "path": ["v1", "conversations", "{{branch_test_conv_id}}", "branches", "{{new_branch_name}}", "activate"] + } + }, + "event": [ + { + "listen": "test", + "script": { + "type": "text/javascript", + "exec": [ + "pm.test('[Branch-8] Branch activated successfully', function () {", + " pm.response.to.have.status(200);", + " const response = pm.response.json();", + " pm.expect(response.active_branch).to.equal(pm.collectionVariables.get('new_branch_name'));", + " console.log('✓ Active branch switched to:', response.active_branch);", + "});" + ] + } + } + ] + }, + { + "name": "Branch-9: Get Items with Branch Parameter", + "request": { + "method": "GET", + "header": [ + { + "key": "Authorization", + "value": "Bearer {{access_token}}" + } + ], + "url": { + "raw": "{{kong_url}}/v1/conversations/{{branch_test_conv_id}}/items?branch={{new_branch_name}}", + "host": ["{{kong_url}}"], + "path": ["v1", "conversations", "{{branch_test_conv_id}}", "items"], + "query": [ + { + "key": "branch", + "value": "{{new_branch_name}}" + } + ] + } + }, + "event": [ + { + "listen": "test", + 
"script": { + "type": "text/javascript", + "exec": [ + "pm.test('[Branch-9] Items retrieved from new branch', function () {", + " pm.response.to.have.status(200);", + " const response = pm.response.json();", + " pm.expect(response).to.have.property('data');", + " pm.expect(response.data).to.be.an('array');", + " console.log('Found ' + response.data.length + ' items in branch: ' + pm.collectionVariables.get('new_branch_name'));", + "});" + ] + } + } + ] + }, + { + "name": "Branch-10: Add Message to New Branch", + "request": { + "method": "POST", + "header": [ + { + "key": "Content-Type", + "value": "application/json" + }, + { + "key": "Authorization", + "value": "Bearer {{access_token}}" + } + ], + "body": { + "mode": "raw", + "raw": "{\n \"model\": \"{{model_id}}\",\n \"messages\": [\n {\n \"role\": \"user\",\n \"content\": \"Actually, what is 5 + 5?\"\n }\n ],\n \"conversation\": {\n \"id\": \"{{branch_test_conv_id}}\",\n \"branch\": \"{{new_branch_name}}\"\n },\n \"stream\": false,\n \"max_tokens\": 100\n}" + }, + "url": "{{kong_url}}/v1/chat/completions" + }, + "event": [ + { + "listen": "test", + "script": { + "type": "text/javascript", + "exec": [ + "pm.test('[Branch-10] Message added to new branch', function () {", + " pm.response.to.have.status(200);", + " const response = pm.response.json();", + " pm.expect(response.choices).to.be.an('array').that.is.not.empty;", + " console.log('✓ Message added to branch:', pm.collectionVariables.get('new_branch_name'));", + "});" + ] + } + } + ] + }, + { + "name": "Branch-11: Verify Branch Has More Items Than MAIN", + "request": { + "method": "GET", + "header": [ + { + "key": "Authorization", + "value": "Bearer {{access_token}}" + } + ], + "url": { + "raw": "{{kong_url}}/v1/conversations/{{branch_test_conv_id}}/items?branch={{new_branch_name}}", + "host": ["{{kong_url}}"], + "path": ["v1", "conversations", "{{branch_test_conv_id}}", "items"], + "query": [ + { + "key": "branch", + "value": "{{new_branch_name}}" + } + ] + } 
+ }, + "event": [ + { + "listen": "test", + "script": { + "type": "text/javascript", + "exec": [ + "pm.test('[Branch-11] Branch diverged from MAIN', function () {", + " pm.response.to.have.status(200);", + " const response = pm.response.json();", + " // Branch should have items from fork point + new items", + " pm.expect(response.data.length).to.be.at.least(2);", + " console.log('Branch has ' + response.data.length + ' items after divergence');", + "});" + ] + } + } + ] + }, + { + "name": "Branch-12: Edit Message (via API)", + "request": { + "method": "POST", + "header": [ + { + "key": "Content-Type", + "value": "application/json" + }, + { + "key": "Authorization", + "value": "Bearer {{access_token}}" + } + ], + "body": { + "mode": "raw", + "raw": "{\n \"content\": \"What is 10 + 10?\"\n}" + }, + "url": { + "raw": "{{kong_url}}/v1/conversations/{{branch_test_conv_id}}/items/{{edit_target_item_id}}/edit", + "host": ["{{kong_url}}"], + "path": ["v1", "conversations", "{{branch_test_conv_id}}", "items", "{{edit_target_item_id}}", "edit"] + } + }, + "event": [ + { + "listen": "test", + "script": { + "type": "text/javascript", + "exec": [ + "pm.test('[Branch-12] Message edited (creates new branch swapped to MAIN)', function () {", + " pm.expect(pm.response.code).to.be.oneOf([200, 201]);", + " const response = pm.response.json();", + " ", + " // Verify response structure - branch should always be MAIN after swap", + " pm.expect(response.branch).to.equal('MAIN');", + " pm.expect(response.branch_created).to.be.true;", + " pm.expect(response.user_item).to.be.an('object');", + " pm.expect(response.old_main_backup).to.be.a('string');", + " ", + " // Verify new_branch details point to MAIN", + " if (response.new_branch) {", + " pm.expect(response.new_branch.name).to.equal('MAIN');", + " pm.expect(response.new_branch.item_count).to.be.a('number');", + " console.log('✓ Branch promoted to MAIN, old MAIN backup:', response.old_main_backup);", + " }", + " ", + " 
pm.collectionVariables.set('edit_created_branch', response.branch);", + " pm.collectionVariables.set('edit_old_main_backup', response.old_main_backup);", + " console.log('✓ Edit swapped to MAIN, backup branch:', response.old_main_backup);", + "});" + ] + } + } + ] + }, + { + "name": "Branch-13: Regenerate Assistant Message", + "request": { + "method": "POST", + "header": [ + { + "key": "Authorization", + "value": "Bearer {{access_token}}" + } + ], + "body": { + "mode": "raw", + "raw": "{\n \"model\": \"{{model_id}}\",\n \"max_tokens\": 100\n}" + }, + "url": { + "raw": "{{kong_url}}/v1/conversations/{{branch_test_conv_id}}/items/{{regen_target_item_id}}/regenerate", + "host": ["{{kong_url}}"], + "path": ["v1", "conversations", "{{branch_test_conv_id}}", "items", "{{regen_target_item_id}}", "regenerate"] + } + }, + "event": [ + { + "listen": "test", + "script": { + "type": "text/javascript", + "exec": [ + "pm.test('[Branch-13] Message regenerated (swapped to MAIN)', function () {", + " pm.expect(pm.response.code).to.be.oneOf([200, 201]);", + " const response = pm.response.json();", + " ", + " // Verify response structure - branch should always be MAIN after swap", + " pm.expect(response.branch).to.equal('MAIN');", + " pm.expect(response.branch_created).to.be.true;", + " pm.expect(response.user_item_id).to.be.a('string');", + " pm.expect(response.old_main_backup).to.be.a('string');", + " ", + " // Verify new_branch details point to MAIN", + " if (response.new_branch) {", + " pm.expect(response.new_branch.name).to.equal('MAIN');", + " pm.expect(response.new_branch.item_count).to.be.a('number');", + " console.log('✓ Branch promoted to MAIN, old MAIN backup:', response.old_main_backup);", + " }", + " ", + " pm.collectionVariables.set('regen_created_branch', response.branch);", + " pm.collectionVariables.set('regen_old_main_backup', response.old_main_backup);", + " console.log('✓ Regenerate swapped to MAIN, backup branch:', response.old_main_backup);", + "});" + ] + } + } 
+ ] + }, + { + "name": "Branch-14: Get Branch Details", + "request": { + "method": "GET", + "header": [ + { + "key": "Authorization", + "value": "Bearer {{access_token}}" + } + ], + "url": { + "raw": "{{kong_url}}/v1/conversations/{{branch_test_conv_id}}/branches/{{new_branch_name}}", + "host": ["{{kong_url}}"], + "path": ["v1", "conversations", "{{branch_test_conv_id}}", "branches", "{{new_branch_name}}"] + } + }, + "event": [ + { + "listen": "test", + "script": { + "type": "text/javascript", + "exec": [ + "pm.test('[Branch-14] Branch details retrieved', function () {", + " pm.response.to.have.status(200);", + " const response = pm.response.json();", + " pm.expect(response).to.have.property('name');", + " pm.expect(response).to.have.property('created_at');", + " console.log('✓ Branch details:', JSON.stringify(response));", + "});" + ] + } + } + ] + }, + { + "name": "Branch-15: Switch Back to MAIN", + "request": { + "method": "POST", + "header": [ + { + "key": "Authorization", + "value": "Bearer {{access_token}}" + } + ], + "url": { + "raw": "{{kong_url}}/v1/conversations/{{branch_test_conv_id}}/branches/main/activate", + "host": ["{{kong_url}}"], + "path": ["v1", "conversations", "{{branch_test_conv_id}}", "branches", "main", "activate"] + } + }, + "event": [ + { + "listen": "test", + "script": { + "type": "text/javascript", + "exec": [ + "pm.test('[Branch-15] Switched back to MAIN branch', function () {", + " pm.response.to.have.status(200);", + " const response = pm.response.json();", + " pm.expect(response.active_branch.toLowerCase()).to.equal('main');", + " console.log('✓ Switched back to MAIN branch');", + "});" + ] + } + } + ] + }, + { + "name": "Branch-16: Delete Non-Active Branch", + "request": { + "method": "DELETE", + "header": [ + { + "key": "Authorization", + "value": "Bearer {{access_token}}" + } + ], + "url": { + "raw": "{{kong_url}}/v1/conversations/{{branch_test_conv_id}}/branches/{{new_branch_name}}", + "host": ["{{kong_url}}"], + "path": ["v1", 
"conversations", "{{branch_test_conv_id}}", "branches", "{{new_branch_name}}"] + } + }, + "event": [ + { + "listen": "test", + "script": { + "type": "text/javascript", + "exec": [ + "pm.test('[Branch-16] Branch deleted successfully', function () {", + " pm.expect(pm.response.code).to.be.oneOf([200, 204]);", + " console.log('✓ Branch deleted:', pm.collectionVariables.get('new_branch_name'));", + "});" + ] + } + } + ] + }, + { + "name": "Branch-17: Verify Branch Deleted", + "request": { + "method": "GET", + "header": [ + { + "key": "Authorization", + "value": "Bearer {{access_token}}" + } + ], + "url": { + "raw": "{{kong_url}}/v1/conversations/{{branch_test_conv_id}}/branches", + "host": ["{{kong_url}}"], + "path": ["v1", "conversations", "{{branch_test_conv_id}}", "branches"] + } + }, + "event": [ + { + "listen": "test", + "script": { + "type": "text/javascript", + "exec": [ + "pm.test('[Branch-17] Deleted branch no longer listed', function () {", + " pm.response.to.have.status(200);", + " const response = pm.response.json();", + " const deletedBranch = response.data.find(b => b.name === pm.collectionVariables.get('new_branch_name'));", + " pm.expect(deletedBranch).to.be.undefined;", + " console.log('✓ Deleted branch no longer exists');", + "});" + ] + } + } + ] + }, + { + "name": "Branch-18: Attempt Delete MAIN Branch (should fail)", + "request": { + "method": "DELETE", + "header": [ + { + "key": "Authorization", + "value": "Bearer {{access_token}}" + } + ], + "url": { + "raw": "{{kong_url}}/v1/conversations/{{branch_test_conv_id}}/branches/main", + "host": ["{{kong_url}}"], + "path": ["v1", "conversations", "{{branch_test_conv_id}}", "branches", "main"] + } + }, + "event": [ + { + "listen": "test", + "script": { + "type": "text/javascript", + "exec": [ + "pm.test('[Branch-18] MAIN branch deletion rejected', function () {", + " pm.expect(pm.response.code).to.be.oneOf([400, 403, 422]);", + " console.log('✓ MAIN branch correctly protected from deletion');", + "});" + ] + }
+ } + ] + }, + { + "name": "Branch-19: Cleanup - Delete Branch Test Conversation", + "request": { + "method": "DELETE", + "header": [ + { + "key": "Authorization", + "value": "Bearer {{access_token}}" + } + ], + "url": "{{kong_url}}/v1/conversations/{{branch_test_conv_id}}" + }, + "event": [ + { + "listen": "test", + "script": { + "type": "text/javascript", + "exec": [ + "pm.test('[Branch-19] Branch test conversation cleaned up', function () {", + " pm.expect(pm.response.code).to.be.oneOf([200, 204, 404]);", + " console.log('✓ Branch test cleanup complete');", + "});" + ] + } + } + ] + } + ] + }, { "name": "Cleanup", "item": [ @@ -2343,6 +3061,36 @@ "value": "", "type": "string", "description": "Share slug for cascade delete test" + }, + { + "key": "branch_test_conv_id", + "value": "", + "type": "string", + "description": "Conversation ID for branch tests" + }, + { + "key": "edit_target_item_id", + "value": "", + "type": "string", + "description": "User item ID to use for edit tests" + }, + { + "key": "regen_target_item_id", + "value": "", + "type": "string", + "description": "Assistant item ID to use for regenerate tests" + }, + { + "key": "new_branch_name", + "value": "", + "type": "string", + "description": "Name of the newly created branch" + }, + { + "key": "edit_created_branch", + "value": "", + "type": "string", + "description": "Branch created by edit operation" } ] }