diff --git a/pkg/runtime/event.go b/pkg/runtime/event.go index d4b141d1e..1b6f94fde 100644 --- a/pkg/runtime/event.go +++ b/pkg/runtime/event.go @@ -333,17 +333,19 @@ func SessionTitle(sessionID, title string) Event { type SessionSummaryEvent struct { AgentContext - Type string `json:"type"` - SessionID string `json:"session_id"` - Summary string `json:"summary"` + Type string `json:"type"` + SessionID string `json:"session_id"` + Summary string `json:"summary"` + FirstKeptEntry int `json:"first_kept_entry,omitempty"` } -func SessionSummary(sessionID, summary, agentName string) Event { +func SessionSummary(sessionID, summary, agentName string, firstKeptEntry int) Event { return &SessionSummaryEvent{ - Type: "session_summary", - SessionID: sessionID, - Summary: summary, - AgentContext: newAgentContext(agentName), + Type: "session_summary", + SessionID: sessionID, + Summary: summary, + FirstKeptEntry: firstKeptEntry, + AgentContext: newAgentContext(agentName), } } diff --git a/pkg/runtime/persistent_runtime.go b/pkg/runtime/persistent_runtime.go index cf93e8427..c6b704da7 100644 --- a/pkg/runtime/persistent_runtime.go +++ b/pkg/runtime/persistent_runtime.go @@ -131,7 +131,7 @@ func (r *PersistentRuntime) handleEvent(ctx context.Context, sess *session.Sessi } case *SessionSummaryEvent: - if err := r.sessionStore.AddSummary(ctx, e.SessionID, e.Summary); err != nil { + if err := r.sessionStore.AddSummary(ctx, e.SessionID, e.Summary, e.FirstKeptEntry); err != nil { slog.Warn("Failed to persist summary", "session_id", e.SessionID, "error", err) } diff --git a/pkg/runtime/remote_runtime.go b/pkg/runtime/remote_runtime.go index 9f8a0db87..5f03297cd 100644 --- a/pkg/runtime/remote_runtime.go +++ b/pkg/runtime/remote_runtime.go @@ -228,7 +228,7 @@ func (r *RemoteRuntime) Resume(ctx context.Context, req ResumeRequest) { // Summarize generates a summary for the session func (r *RemoteRuntime) Summarize(_ context.Context, sess *session.Session, _ string, events chan Event) { slog.Debug("Summarize not yet implemented for remote runtime", "session_id", r.sessionID) - events <- SessionSummary(sess.ID, "Summary generation not yet implemented for remote runtime", r.currentAgent) + events <- SessionSummary(sess.ID, "Summary generation not yet implemented for remote runtime", r.currentAgent, 0) } func (r *RemoteRuntime) convertSessionMessages(sess *session.Session) []api.Message { diff --git a/pkg/runtime/session_compaction.go b/pkg/runtime/session_compaction.go index 714038e4c..afa015aed 100644 --- a/pkg/runtime/session_compaction.go +++ b/pkg/runtime/session_compaction.go @@ -17,6 +17,11 @@ import ( const maxSummaryTokens = 16_000 +// maxKeepTokens is the maximum number of tokens to preserve from the end of +// the conversation during compaction. These recent messages are kept verbatim +// so the LLM can continue naturally after compaction. +const maxKeepTokens = 20_000 + // doCompact runs compaction on a session and applies the result (events, // persistence, token count updates). The agent is used to extract the // conversation from the session and to obtain the model for summarization. @@ -41,8 +46,8 @@ func (r *LocalRuntime) doCompact(ctx context.Context, sess *session.Session, a * compactionAgent := agent.New("root", compaction.SystemPrompt, agent.WithModel(summaryModel)) - // Compute the messages to compact. - messages := extractMessagesToCompact(sess, compactionAgent, int64(m.Limit.Context), additionalPrompt) + // Compute the messages to compact, keeping recent messages aside. + messages, firstKeptEntry := extractMessagesToCompact(sess, compactionAgent, int64(m.Limit.Context), additionalPrompt) // Run the compaction. compactionSession := session.New( @@ -72,16 +77,21 @@ func (r *LocalRuntime) doCompact(ctx context.Context, sess *session.Session, a * sess.InputTokens = compactionSession.OutputTokens sess.OutputTokens = 0 sess.Messages = append(sess.Messages, session.Item{ - Summary: summary, - Cost: compactionSession.TotalCost(), + Summary: summary, + FirstKeptEntry: firstKeptEntry, + Cost: compactionSession.TotalCost(), }) _ = r.sessionStore.UpdateSession(ctx, sess) slog.Debug("Generated session summary", "session_id", sess.ID, "summary_length", len(summary)) - events <- SessionSummary(sess.ID, summary, a.Name()) + events <- SessionSummary(sess.ID, summary, a.Name(), firstKeptEntry) } -func extractMessagesToCompact(sess *session.Session, compactionAgent *agent.Agent, contextLimit int64, additionalPrompt string) []chat.Message { +// extractMessagesToCompact returns the messages to send to the compaction model +// and the index (into sess.Messages) of the first message that was kept aside. +// Recent messages (up to maxKeepTokens) are excluded from compaction so they +// can be preserved verbatim in the session after summarization. +func extractMessagesToCompact(sess *session.Session, compactionAgent *agent.Agent, contextLimit int64, additionalPrompt string) ([]chat.Message, int) { // Add all the existing messages. var messages []chat.Message for _, msg := range sess.GetMessages(compactionAgent) { @@ -95,6 +105,17 @@ func extractMessagesToCompact(sess *session.Session, compactionAgent *agent.Agen messages = append(messages, msg) } + // Split: keep the last N tokens of messages aside so the LLM retains + // recent context after compaction. + splitIdx := splitIndexForKeep(messages, maxKeepTokens) + messagesToCompact := messages[:splitIdx] + // Compute firstKeptEntry: index into sess.Messages of the first kept message. + // The kept messages start at splitIdx in the non-system filtered list. We + // need to map this back to the original sess.Messages index. + firstKeptEntry := mapToSessionIndex(sess, splitIdx) + + messages = messagesToCompact + // Prepare the first (system) message. systemPromptMessage := chat.Message{ Role: chat.MessageRoleSystem, @@ -131,7 +152,49 @@ func extractMessagesToCompact(sess *session.Session, compactionAgent *agent.Agen // Append the last (user) message. messages = append(messages, userPromptMessage) - return messages + return messages, firstKeptEntry +} + +// splitIndexForKeep returns the index that splits messages into [0:idx] (to +// compact) and [idx:] (to keep). It walks backwards accumulating tokens up to +// maxTokens, snapping to user/assistant boundaries. +func splitIndexForKeep(messages []chat.Message, maxTokens int64) int { + if len(messages) == 0 { + return 0 + } + + var tokens int64 + // Walk from the end; find the earliest index whose suffix fits in maxTokens. + lastValidBoundary := len(messages) + for i := len(messages) - 1; i >= 0; i-- { + tokens += compaction.EstimateMessageTokens(&messages[i]) + if tokens > maxTokens { + return lastValidBoundary + } + role := messages[i].Role + if role == chat.MessageRoleUser || role == chat.MessageRoleAssistant { + lastValidBoundary = i + } + } + // All messages fit within maxTokens — don't keep any aside (compact everything). + return len(messages) +} + +// mapToSessionIndex maps an index in the non-system-filtered message list back +// to the corresponding index in sess.Messages. It counts only message items +// that are not system messages. +func mapToSessionIndex(sess *session.Session, filteredIdx int) int { + count := 0 + for i, item := range sess.Messages { + if item.IsMessage() && item.Message.Message.Role != chat.MessageRoleSystem { + if count == filteredIdx { + return i + } + count++ + } + } + // filteredIdx is past the end — no messages to keep. + return len(sess.Messages) } func firstMessageToKeep(messages []chat.Message, contextLimit int64) int { diff --git a/pkg/runtime/session_compaction_test.go b/pkg/runtime/session_compaction_test.go index 57d3c73f3..2387907ab 100644 --- a/pkg/runtime/session_compaction_test.go +++ b/pkg/runtime/session_compaction_test.go @@ -1,9 +1,11 @@ package runtime import ( + "strings" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/docker/docker-agent/pkg/agent" "github.com/docker/docker-agent/pkg/chat" @@ -96,7 +98,7 @@ func TestExtractMessagesToCompact(t *testing.T) { sess := session.New(session.WithMessages(tt.messages)) a := agent.New("test", "test prompt") - result := extractMessagesToCompact(sess, a, tt.contextLimit, tt.additionalPrompt) + result, _ := extractMessagesToCompact(sess, a, tt.contextLimit, tt.additionalPrompt) assert.GreaterOrEqual(t, len(result), tt.wantConversationMsgCount+2) assert.Equal(t, chat.MessageRoleSystem, result[0].Role) @@ -121,3 +123,169 @@ func TestExtractMessagesToCompact(t *testing.T) { }) } } + +func TestSplitIndexForKeep(t *testing.T) { + msg := func(role chat.MessageRole, content string) chat.Message { + return chat.Message{Role: role, Content: content} + } + + tests := []struct { + name string + messages []chat.Message + maxTokens int64 + wantSplit int // expected split index + }{ + { + name: "empty messages", + messages: nil, + maxTokens: 1000, + wantSplit: 0, + }, + { + name: "all messages fit in keep budget - compact everything", + messages: []chat.Message{ + msg(chat.MessageRoleUser, "short"), + msg(chat.MessageRoleAssistant, "short"), + }, + maxTokens: 100_000, + wantSplit: 2, // all fit → compact everything + }, + { + name: "recent messages kept, older ones compacted", + messages: []chat.Message{ + msg(chat.MessageRoleUser, strings.Repeat("a", 40000)), // ~10005 tokens + msg(chat.MessageRoleAssistant, strings.Repeat("b", 40000)), // ~10005 tokens + msg(chat.MessageRoleUser, strings.Repeat("c", 40000)), // ~10005 tokens + msg(chat.MessageRoleAssistant, strings.Repeat("d", 40000)), // ~10005 tokens + msg(chat.MessageRoleUser, strings.Repeat("e", 40000)), // ~10005 tokens + msg(chat.MessageRoleAssistant, strings.Repeat("f", 40000)), // ~10005 tokens + }, + maxTokens: 20_100, // enough for exactly 2 messages + wantSplit: 4, // last 2 messages are kept + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := splitIndexForKeep(tt.messages, tt.maxTokens) + assert.Equal(t, tt.wantSplit, got) + }) + } +} + +func TestExtractMessagesToCompact_KeepsRecentMessages(t *testing.T) { + // Create a session with many messages, some large enough that the last + // ~20k tokens are kept aside. + var items []session.Item + for range 10 { + items = append(items, session.NewMessageItem(&session.Message{ + Message: chat.Message{ + Role: chat.MessageRoleUser, + Content: strings.Repeat("x", 20000), // ~5k tokens each + }, + }), session.NewMessageItem(&session.Message{ + Message: chat.Message{ + Role: chat.MessageRoleAssistant, + Content: strings.Repeat("y", 20000), // ~5k tokens each + }, + })) + } + + sess := session.New(session.WithMessages(items)) + a := agent.New("test", "test prompt") + + result, firstKeptEntry := extractMessagesToCompact(sess, a, 200_000, "") + + // The kept messages should not appear in the compaction result + // (only system + compacted messages + user prompt). + // Total: 20 messages × ~5k tokens = ~100k tokens. + // Keep budget: 20k tokens → ~4 messages kept. + // So compacted messages should be 20 - 4 = 16. + compactedMsgCount := len(result) - 2 // minus system and user prompt + assert.Less(t, compactedMsgCount, 20, "some messages should have been kept aside") + assert.Positive(t, compactedMsgCount, "some messages should be compacted") + + // firstKeptEntry should point into sess.Messages + assert.Positive(t, firstKeptEntry, "firstKeptEntry should be > 0") + assert.Less(t, firstKeptEntry, len(sess.Messages), "firstKeptEntry should be within bounds") +} + +func TestSessionGetMessages_WithFirstKeptEntry(t *testing.T) { + // Build a session with some messages, then add a summary with FirstKeptEntry. + items := []session.Item{ + session.NewMessageItem(&session.Message{ + Message: chat.Message{Role: chat.MessageRoleUser, Content: "m1"}, + }), + session.NewMessageItem(&session.Message{ + Message: chat.Message{Role: chat.MessageRoleAssistant, Content: "m2"}, + }), + session.NewMessageItem(&session.Message{ + Message: chat.Message{Role: chat.MessageRoleUser, Content: "m3"}, + }), + session.NewMessageItem(&session.Message{ + Message: chat.Message{Role: chat.MessageRoleAssistant, Content: "m4"}, + }), + session.NewMessageItem(&session.Message{ + Message: chat.Message{Role: chat.MessageRoleUser, Content: "m5"}, + }), + } + + // Add summary that says "first kept entry is index 3" (m4). + // So we expect: [system...] + [summary] + [m4, m5] + items = append(items, session.Item{ + Summary: "This is a summary of m1-m3", + FirstKeptEntry: 3, // index of m4 in the Messages slice + }) + + sess := session.New(session.WithMessages(items)) + a := agent.New("test", "test instruction") + + messages := sess.GetMessages(a) + + // Extract just the non-system messages + var conversationMessages []chat.Message + for _, msg := range messages { + if msg.Role != chat.MessageRoleSystem { + conversationMessages = append(conversationMessages, msg) + } + } + + // Should have: summary (as user message), m4, m5 + require.Len(t, conversationMessages, 3, "expected summary + 2 kept messages") + assert.Contains(t, conversationMessages[0].Content, "Session Summary:") + assert.Equal(t, "m4", conversationMessages[1].Content) + assert.Equal(t, "m5", conversationMessages[2].Content) +} + +func TestSessionGetMessages_SummaryWithoutFirstKeptEntry(t *testing.T) { + // Backward compatibility: summary without FirstKeptEntry should work as before. + items := []session.Item{ + session.NewMessageItem(&session.Message{ + Message: chat.Message{Role: chat.MessageRoleUser, Content: "m1"}, + }), + session.NewMessageItem(&session.Message{ + Message: chat.Message{Role: chat.MessageRoleAssistant, Content: "m2"}, + }), + {Summary: "This is a summary"}, + session.NewMessageItem(&session.Message{ + Message: chat.Message{Role: chat.MessageRoleUser, Content: "m3"}, + }), + } + + sess := session.New(session.WithMessages(items)) + a := agent.New("test", "test instruction") + + messages := sess.GetMessages(a) + + var conversationMessages []chat.Message + for _, msg := range messages { + if msg.Role != chat.MessageRoleSystem { + conversationMessages = append(conversationMessages, msg) + } + } + + // Should have: summary + m3 (messages after the summary) + require.Len(t, conversationMessages, 2) + assert.Contains(t, conversationMessages[0].Content, "Session Summary:") + assert.Equal(t, "m3", conversationMessages[1].Content) +} diff --git a/pkg/session/migrations.go b/pkg/session/migrations.go index 77a93bbf2..1a32aa010 100644 --- a/pkg/session/migrations.go +++ b/pkg/session/migrations.go @@ -350,6 +350,12 @@ func getAllMigrations() []Migration { Description: "Drop the legacy messages JSON column now that all data lives in session_items", UpSQL: `ALTER TABLE sessions DROP COLUMN messages`, }, + { + ID: 21, + Name: "021_add_first_kept_entry_column", + Description: "Add first_kept_entry column to session_items for compaction-preserved messages", + UpSQL: `ALTER TABLE session_items ADD COLUMN first_kept_entry INTEGER DEFAULT 0`, + }, } } diff --git a/pkg/session/session.go b/pkg/session/session.go index 89b0324e4..0c5b3213f 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -38,6 +38,13 @@ type Item struct { // Summary is a summary of the session up until this point Summary string `json:"summary,omitempty"` + // FirstKeptEntry is the index (into the session's Messages slice) of the + // first message that was kept verbatim during compaction. Messages from + // this index onward (up to the summary item itself) are appended after + // the summary when reconstructing the conversation. A value of -1 (or 0 + // with no summary) means no messages were kept. + FirstKeptEntry int `json:"first_kept_entry,omitempty"` + // Cost tracks the cost of operations associated with this item that // don't produce a regular message (e.g., compaction/summarization). Cost float64 `json:"cost,omitempty"` @@ -732,7 +739,11 @@ func buildContextSpecificSystemMessages(a *agent.Agent, s *Session) []chat.Messa // buildSessionSummaryMessages builds system messages containing the session summary // if one exists. Session summaries are context-specific per session and thus should not have a checkpoint (they will be cached alongside the first user message anyway) // -// lastSummaryIndex is the index of the last summary item in s.Messages, or -1 if none exists. +// startIndex is the index in items from which conversation messages should be +// emitted. When a summary with FirstKeptEntry is present, this points to the +// first kept message so that recent context is preserved after compaction. +// Otherwise it is lastSummaryIndex+1 (i.e. right after the summary item), or +// 0 when there is no summary. func buildSessionSummaryMessages(items []Item) ([]chat.Message, int) { var messages []chat.Message // Find the last summary index to determine where conversation messages start @@ -753,7 +764,18 @@ func buildSessionSummaryMessages(items []Item) ([]chat.Message, int) { }) } - return messages, lastSummaryIndex + // Determine where conversation messages should start. + // If the summary has a FirstKeptEntry, we start from there so that + // messages kept during compaction are included after the summary. + startIndex := lastSummaryIndex + 1 + if lastSummaryIndex >= 0 { + kept := items[lastSummaryIndex].FirstKeptEntry + if kept > 0 && kept < lastSummaryIndex { + startIndex = kept + } + } + + return messages, startIndex } func (s *Session) GetMessages(a *agent.Agent) []chat.Message { @@ -781,15 +803,13 @@ func (s *Session) GetMessages(a *agent.Agent) []chat.Message { s.mu.RUnlock() // Build session summary messages (vary per session) - summaryMessages, lastSummaryIndex := buildSessionSummaryMessages(items) + summaryMessages, startIndex := buildSessionSummaryMessages(items) var messages []chat.Message messages = append(messages, invariantMessages...) messages = append(messages, contextMessages...) messages = append(messages, summaryMessages...) - startIndex := lastSummaryIndex + 1 - // Begin adding conversation messages for i := startIndex; i < len(items); i++ { item := items[i] diff --git a/pkg/session/store.go b/pkg/session/store.go index 0982588be..775a5e689 100644 --- a/pkg/session/store.go +++ b/pkg/session/store.go @@ -98,8 +98,9 @@ type Store interface { // The sub-session is stored as a separate session row with parent_id set. AddSubSession(ctx context.Context, parentSessionID string, subSession *Session) error - // AddSummary adds a summary item to a session at the next position - AddSummary(ctx context.Context, sessionID, summary string) error + // AddSummary adds a summary item to a session at the next position. + // firstKeptEntry is the index of the first message kept verbatim during compaction. + AddSummary(ctx context.Context, sessionID, summary string, firstKeptEntry int) error // === Granular metadata updates === @@ -303,7 +304,7 @@ func (s *InMemorySessionStore) AddSubSession(_ context.Context, parentSessionID } // AddSummary adds a summary item to a session at the next position. -func (s *InMemorySessionStore) AddSummary(_ context.Context, sessionID, summary string) error { +func (s *InMemorySessionStore) AddSummary(_ context.Context, sessionID, summary string, firstKeptEntry int) error { if sessionID == "" { return ErrEmptyID } @@ -312,7 +313,7 @@ func (s *InMemorySessionStore) AddSummary(_ context.Context, sessionID, summary return ErrNotFound } session.mu.Lock() - session.Messages = append(session.Messages, Item{Summary: summary}) + session.Messages = append(session.Messages, Item{Summary: summary, FirstKeptEntry: firstKeptEntry}) session.mu.Unlock() return nil } @@ -658,13 +659,14 @@ func (s *SQLiteSessionStore) GetSession(ctx context.Context, id string) (*Sessio // sessionItemRow holds the raw data from a session_items row type sessionItemRow struct { - position int - itemType string - agentName sql.NullString - messageJSON sql.NullString - implicit bool - subsessionID sql.NullString - summaryText sql.NullString + position int + itemType string + agentName sql.NullString + messageJSON sql.NullString + implicit bool + subsessionID sql.NullString + summaryText sql.NullString + firstKeptEntry int } // loadSessionItems loads all items for a session from the session_items table. @@ -675,7 +677,7 @@ func (s *SQLiteSessionStore) loadSessionItems(ctx context.Context, sessionID str // loadSessionItemsWith loads items using the provided querier (db or tx). func (s *SQLiteSessionStore) loadSessionItemsWith(ctx context.Context, q querier, sessionID string) ([]Item, error) { rows, err := q.QueryContext(ctx, - `SELECT position, item_type, agent_name, message_json, implicit, subsession_id, summary_text + `SELECT position, item_type, agent_name, message_json, implicit, subsession_id, summary_text, COALESCE(first_kept_entry, 0) FROM session_items WHERE session_id = ? ORDER BY position`, sessionID) if err != nil { return nil, err @@ -687,7 +689,7 @@ func (s *SQLiteSessionStore) loadSessionItemsWith(ctx context.Context, q querier var rawRows []sessionItemRow for rows.Next() { var row sessionItemRow - if err := rows.Scan(&row.position, &row.itemType, &row.agentName, &row.messageJSON, &row.implicit, &row.subsessionID, &row.summaryText); err != nil { + if err := rows.Scan(&row.position, &row.itemType, &row.agentName, &row.messageJSON, &row.implicit, &row.subsessionID, &row.summaryText, &row.firstKeptEntry); err != nil { return nil, err } rawRows = append(rawRows, row) @@ -737,7 +739,7 @@ func (s *SQLiteSessionStore) loadSessionItemsWith(ctx context.Context, q querier items = append(items, Item{SubSession: subSession}) case "summary": - items = append(items, Item{Summary: row.summaryText.String}) + items = append(items, Item{Summary: row.summaryText.String, FirstKeptEntry: row.firstKeptEntry}) } } @@ -1166,9 +1168,9 @@ func (s *SQLiteSessionStore) addItemTx(ctx context.Context, tx *sql.Tx, sessionI case item.Summary != "": _, err := tx.ExecContext(ctx, - `INSERT INTO session_items (session_id, position, item_type, summary_text) - VALUES (?, ?, 'summary', ?)`, - sessionID, position, item.Summary) + `INSERT INTO session_items (session_id, position, item_type, summary_text, first_kept_entry) + VALUES (?, ?, 'summary', ?, ?)`, + sessionID, position, item.Summary, item.FirstKeptEntry) return err default: @@ -1177,15 +1179,15 @@ func (s *SQLiteSessionStore) addItemTx(ctx context.Context, tx *sql.Tx, sessionI } // AddSummary adds a summary item to a session at the next position. -func (s *SQLiteSessionStore) AddSummary(ctx context.Context, sessionID, summary string) error { +func (s *SQLiteSessionStore) AddSummary(ctx context.Context, sessionID, summary string, firstKeptEntry int) error { if sessionID == "" { return ErrEmptyID } _, err := s.db.ExecContext(ctx, - `INSERT INTO session_items (session_id, position, item_type, summary_text) - VALUES (?, (SELECT COALESCE(MAX(position), -1) + 1 FROM session_items WHERE session_id = ?), 'summary', ?)`, - sessionID, sessionID, summary) + `INSERT INTO session_items (session_id, position, item_type, summary_text, first_kept_entry) + VALUES (?, (SELECT COALESCE(MAX(position), -1) + 1 FROM session_items WHERE session_id = ?), 'summary', ?, ?)`, + sessionID, sessionID, summary, firstKeptEntry) if err != nil { return err }