diff --git a/docs/gateway-context.md b/docs/gateway-context.md new file mode 100644 index 000000000..62aab57b5 --- /dev/null +++ b/docs/gateway-context.md @@ -0,0 +1,44 @@ +# Gateway Context Providers + +OpenAB gateway adapters can optionally enrich an admitted user turn with recent chat context that the bot would otherwise miss because of mention gating or platform-specific admission rules. + +The gateway-level `ContextProvider` abstraction keeps this behavior shared across platforms: + +- `observe()` records a message that was seen by an adapter but not dispatched to the agent. +- `fetch_context()` returns recent context for an admitted turn. +- `inject_context()` prepends that context with a clear boundary before the current message. + +## Provider Types + +| Provider | Intended platforms | Data source | +|---|---|---| +| `BufferedContextProvider` | LINE, Telegram, WeChat/WeCom, Feishu fallback | webhook observe -> local bounded buffer | +| `ApiFetchContextProvider` | Discord, Slack, Teams, Google Chat where available | platform history API | +| Hybrid provider | Google Chat and other mixed-permission platforms | API fetch when possible, buffer fallback otherwise | + +The first implementation wires LINE group/room text to `BufferedContextProvider`. Future adapters can reuse the same trait without changing the prompt injection format. + +## Defaults + +Context capture is disabled by default. + +| Setting | Default | +|---|---| +| `enabled` | `false` | +| `ttl` | `24h` | +| `max_messages` | `50` | +| `max_chars` | `8000` | + +Gateway-wide environment variables use the `GATEWAY_CONTEXT_*` prefix. Platform-specific settings can override them; for example LINE uses `LINE_GROUP_CONTEXT_*`. + +## Scope And Storage + +Buffered context is: + +- scoped by platform, channel, optional thread, and bot id +- in-memory only +- drained after injection +- bounded by TTL, message count, and total characters +- not long-term memory, retrieval storage, or GBrain state + +This is intentionally short-term conversational continuity. Platforms with reliable history APIs can later implement API-backed or hybrid providers instead of relying only on local buffers. diff --git a/docs/line.md b/docs/line.md index c5890f2a6..796724c45 100644 --- a/docs/line.md +++ b/docs/line.md @@ -83,7 +83,7 @@ In the LINE Developers Console → **Messaging API** tab → scan the QR code wi - **1:1 chat** — send a message to the bot, get an AI agent response - **Inbound voice messages in 1:1 chat** — LINE-hosted audio messages are downloaded through the LINE Content API and forwarded to OpenAB as `audio` attachments, so the existing STT flow can transcribe them. This requires `[stt] enabled = true` in OpenAB core. See [STT (Speech-to-Text)](stt.md). -- **Group chat** — add the bot to a group; it responds only when @-mentioned (see @mention gating below) +- **Group chat** — add the bot to a group; it responds when directly @-mentioned. Deployments can opt in to folding recent unmentioned text into the next direct-mention turn through the gateway ContextProvider buffer (see @mention gating below) - **Inbound images** — user-sent LINE images are downloaded through the LINE Content API and forwarded to OpenAB as image attachments - **Webhook signature validation** — HMAC-SHA256 via `LINE_CHANNEL_SECRET` @@ -96,11 +96,12 @@ In the LINE Developers Console → **Messaging API** tab → scan the QR code wi - **Threads** — LINE has no thread/topic concept. All messages in a chat share one agent session. - **Reactions** — LINE Bot API does not support message reactions. -- **@mention gating** — Supported (zero-config). In group/room chats the gateway only forwards messages where the bot is explicitly @-mentioned (LINE's native `mentionees[].isSelf` signal). 1:1 DMs are always forwarded. No env var is needed. - - *Limitation — non-text messages*: LINE only attaches mention data to text messages. Images, videos, stickers, files, and location messages in groups are silently dropped because they cannot carry an @-mention. - - *Limitation — group voice messages*: LINE voice/audio messages in groups and rooms are also dropped today because audio messages do not carry mention metadata. This PR only enables inbound voice STT for 1:1 chats. +- **@mention gating** — Supported (zero-config). In group/room chats the gateway only dispatches a visible bot reply when the bot is explicitly @-mentioned (LINE's native `mentionees[].isSelf` signal). 1:1 DMs are always forwarded. No env var is needed. + - *Optional short-term text buffering*: when `LINE_GROUP_CONTEXT_ENABLED=true`, unmentioned **text** messages in groups/rooms are observed by the gateway ContextProvider and buffered locally for up to 24 hours. The next directly @-mentioned text turn for the same chat drains that buffer and prepends it as short-term context. The buffer is capped per chat by message count and total characters. This improves conversational continuity without making the bot reply to every group message. + - *Limitation — non-text messages*: LINE only attaches mention data to text messages. Images, videos, stickers, files, and location messages in groups are still dropped when not directly @-mentioned because they do not enter the short-term text buffer. + - *Limitation — group voice messages*: LINE voice/audio messages in groups and rooms are also dropped today because audio messages do not carry mention metadata. LINE inbound voice STT is currently for 1:1 chats. - *Limitation — `@All`*: A group-wide `@All` mention does **not** trigger the bot; only a direct `@BotName` mention does. - - *Breaking change*: This gating is always active. Deployments that previously relied on the bot responding to all group messages will need to @-mention the bot after upgrading. + - *Behavior note*: the short-term context buffer is local, bounded, temporary, scoped by platform/chat/thread/bot, and drained after injection. It is not a long-term chat archive or GBrain memory store. - **Markdown rendering** — LINE uses its own text formatting. Agent replies are sent as plain text. - **External-content images** — LINE image messages backed by `contentProvider.type = "external"` are not downloaded yet. - **External-content audio** — LINE audio messages backed by `contentProvider.type = "external"` are not downloaded yet. @@ -111,6 +112,13 @@ In the LINE Developers Console → **Messaging API** tab → scan the QR code wi |---|---|---| | `LINE_CHANNEL_SECRET` | Yes | Channel secret for webhook signature validation | | `LINE_CHANNEL_ACCESS_TOKEN` | Yes | Channel access token for Reply/Push Message API and LINE-hosted image/audio downloads | +| `LINE_GROUP_CONTEXT_ENABLED` | No | Opt in to buffering unmentioned group/room text for the next direct mention. Falls back to `GATEWAY_CONTEXT_ENABLED`. Default: `false` | +| `LINE_GROUP_CONTEXT_TTL_HOURS` | No | Hours to keep unmentioned group/room text eligible for the next direct mention. Falls back to `GATEWAY_CONTEXT_TTL_HOURS`. Default: `24` | +| `LINE_GROUP_CONTEXT_MAX_MESSAGES` | No | Maximum buffered unmentioned text messages per group/room. Falls back to `GATEWAY_CONTEXT_MAX_MESSAGES`. Default: `50` | +| `LINE_GROUP_CONTEXT_MAX_CHARS` | No | Maximum total buffered text characters per group/room. Falls back to `GATEWAY_CONTEXT_MAX_CHARS`. Default: `8000` | +| `LINE_CONTEXT_BOT_ID` | No | Stable bot identity used for ContextProvider isolation when multiple bots share a LINE group. Falls back to `LINE_BOT_ID`, then `line-default-bot` | + +See [Gateway Context Providers](./gateway-context.md) for the shared context buffering model and future API-fetch/hybrid provider direction. ## Troubleshooting diff --git a/gateway/Cargo.lock b/gateway/Cargo.lock index c1567f997..f84917a97 100644 --- a/gateway/Cargo.lock +++ b/gateway/Cargo.lock @@ -53,6 +53,17 @@ dependencies = [ "serde_json", ] +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -1116,6 +1127,7 @@ version = "0.5.4" dependencies = [ "aes", "anyhow", + "async-trait", "axum", "base64", "cbc", @@ -1124,6 +1136,7 @@ dependencies = [ "hmac", "image", "jsonwebtoken", + "parking_lot", "prost", "quick-xml", "reqwest", diff --git a/gateway/Cargo.toml b/gateway/Cargo.toml index e40160f8a..a9b1ce817 100644 --- a/gateway/Cargo.toml +++ b/gateway/Cargo.toml @@ -14,6 +14,7 @@ reqwest = { version = "0.12", default-features = false, features = ["rustls-tls" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } anyhow = "1" +async-trait = "0.1" uuid = { version = "1", features = ["v4"] } chrono = { version = "0.4", features = ["serde"] } hmac = "0.12" diff --git a/gateway/src/adapters/line.rs b/gateway/src/adapters/line.rs index c0981f6c5..c4432c871 100644 --- a/gateway/src/adapters/line.rs +++ b/gateway/src/adapters/line.rs @@ -1,3 +1,4 @@ +use crate::context::{inject_context, ContextFetchRequest, ContextObserveRequest, ContextScope}; use crate::media::{ audio_extension, format_bytes, resize_and_compress, AUDIO_MAX_DOWNLOAD, IMAGE_MAX_DOWNLOAD, }; @@ -154,12 +155,21 @@ async fn process_line_webhook_events( // - Guardrail: a shared semaphore bounds how many LINE payloads can enter the // post-ack path concurrently. When saturated, new webhooks wait for capacity // before spawning background work so bursts do not create unbounded backlog. + let line_context_provider = state.context_providers.get("line").cloned(); + let line_context_bot_id = state + .context_bot_ids + .get("line") + .cloned() + .unwrap_or_else(|| "line-default-bot".into()); + for event in webhook_body.events { let Some(gateway_event) = build_gateway_event_from_line_event( &event, &state.client, state.line_access_token.as_deref(), LINE_DATA_API_BASE, + line_context_provider.clone(), + &line_context_bot_id, ) .await else { @@ -206,6 +216,8 @@ async fn build_gateway_event_from_line_event( client: &reqwest::Client, line_access_token: Option<&str>, data_api_base: &str, + line_context_provider: Option>, + line_context_bot_id: &str, ) -> Option { if event.event_type != "message" { return None; @@ -315,7 +327,7 @@ async fn build_gateway_event_from_line_event( } } - let event_text = text; + let mut event_text = text.to_string(); if event_text.trim().is_empty() && attachments.is_empty() { return None; @@ -369,7 +381,46 @@ async fn build_gateway_event_from_line_event( // LINE sets isSelf=true on the mentionee that is the bot itself — no env var needed. // 1:1 DMs always pass through. let is_group = channel_type == "group" || channel_type == "room"; - if is_group && !mentionees.iter().any(|m| m.is_self) { + let bot_mentioned = mentionees.iter().any(|m| m.is_self); + if is_group + && msg.message_type == "text" + && line_context_provider + .as_ref() + .is_some_and(|p| p.is_enabled()) + { + let scope = ContextScope::new("line", &channel_id, None, line_context_bot_id); + if !bot_mentioned { + let observed = line_context_provider + .as_ref() + .expect("checked provider above") + .observe(ContextObserveRequest { + scope, + sender_id: user_id.to_string(), + sender_label: user_id.to_string(), + text: event_text.clone(), + }) + .await; + info!( + channel = %channel_id, + observed, + "line group text buffered (bot not mentioned)" + ); + return None; + } + if let Some(context) = line_context_provider + .as_ref() + .expect("checked provider above") + .fetch_context(ContextFetchRequest { scope, limit: None }) + .await + { + info!( + channel = %channel_id, + buffered_messages = context.len(), + "line group context injected into direct mention" + ); + event_text = inject_context(&context, &event_text); + } + } else if is_group && !bot_mentioned { info!( channel = %channel_id, "line group message dropped (@mention gating: bot not mentioned)" @@ -390,7 +441,7 @@ async fn build_gateway_event_from_line_event( display_name: user_id.into(), is_bot: false, }, - event_text, + &event_text, &msg.id, mention_ids, ); @@ -739,6 +790,55 @@ mod tests { use wiremock::matchers::{header, method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; + const TEST_LINE_CONTEXT_BOT_ID: &str = "line-default-bot"; + + fn test_line_context_config() -> crate::context::ContextConfig { + crate::context::ContextConfig::default() + } + + fn enabled_line_context_config() -> crate::context::ContextConfig { + crate::context::ContextConfig { + enabled: true, + ..crate::context::ContextConfig::default() + } + } + + fn disabled_line_context_provider() -> Arc { + Arc::new(crate::context::BufferedContextProvider::new( + test_line_context_config(), + )) + } + + fn enabled_line_context_provider() -> Arc { + Arc::new(crate::context::BufferedContextProvider::new( + enabled_line_context_config(), + )) + } + + fn as_context_provider( + provider: &Arc, + ) -> Arc { + provider.clone() + } + + fn line_context_scope(channel_id: &str) -> crate::context::ContextScope { + crate::context::ContextScope::new("line", channel_id, None, TEST_LINE_CONTEXT_BOT_ID) + } + + fn context_provider_registry( + provider: Arc, + ) -> crate::ContextProviderRegistry { + let mut providers = HashMap::new(); + providers.insert("line".into(), provider); + Arc::new(providers) + } + + fn context_bot_id_registry() -> crate::ContextBotIdRegistry { + let mut bot_ids = HashMap::new(); + bot_ids.insert("line".into(), TEST_LINE_CONTEXT_BOT_ID.into()); + Arc::new(bot_ids) + } + #[tokio::test] async fn download_line_image_resizes_and_returns_attachment() { let server = MockServer::start().await; @@ -814,6 +914,8 @@ mod tests { &reqwest::Client::new(), Some("line_token"), &server.uri(), + None, + TEST_LINE_CONTEXT_BOT_ID, ) .await .expect("image event should produce a gateway event"); @@ -864,6 +966,8 @@ mod tests { &reqwest::Client::new(), Some("line_token"), &server.uri(), + None, + TEST_LINE_CONTEXT_BOT_ID, ) .await .expect("audio event should produce a gateway event"); @@ -919,6 +1023,8 @@ mod tests { &reqwest::Client::new(), Some("line_token"), &server.uri(), + None, + TEST_LINE_CONTEXT_BOT_ID, ) .await .expect("audio event should produce a gateway event"); @@ -1032,6 +1138,8 @@ mod tests { &reqwest::Client::new(), None, LINE_DATA_API_BASE, + None, + TEST_LINE_CONTEXT_BOT_ID, ) .await; @@ -1068,6 +1176,8 @@ mod tests { &reqwest::Client::new(), None, LINE_DATA_API_BASE, + None, + TEST_LINE_CONTEXT_BOT_ID, ) .await; @@ -1102,6 +1212,8 @@ mod tests { &reqwest::Client::new(), None, // no access token LINE_DATA_API_BASE, + None, + TEST_LINE_CONTEXT_BOT_ID, ) .await; @@ -1131,6 +1243,8 @@ mod tests { &reqwest::Client::new(), None, LINE_DATA_API_BASE, + None, + TEST_LINE_CONTEXT_BOT_ID, ) .await; @@ -1161,6 +1275,10 @@ mod tests { event_tx, reply_token_cache: Arc::new(std::sync::Mutex::new(HashMap::new())), line_webhook_semaphore: Arc::new(Semaphore::new(crate::LINE_WEBHOOK_CONCURRENCY_MAX)), + context_providers: context_provider_registry(as_context_provider( + &disabled_line_context_provider(), + )), + context_bot_ids: context_bot_id_registry(), client: reqwest::Client::new(), }); @@ -1228,6 +1346,8 @@ mod tests { &reqwest::Client::new(), None, LINE_DATA_API_BASE, + None, + TEST_LINE_CONTEXT_BOT_ID, ) .await; assert!(result.is_some()); @@ -1236,20 +1356,44 @@ mod tests { } #[tokio::test] - async fn group_message_dropped_when_bot_not_mentioned() { + async fn group_message_drops_without_buffer_when_context_disabled() { + let provider = disabled_line_context_provider(); let event = make_group_text_event("hey everyone", false); let result = build_gateway_event_from_line_event( &event, &reqwest::Client::new(), None, LINE_DATA_API_BASE, + Some(as_context_provider(&provider)), + TEST_LINE_CONTEXT_BOT_ID, ) .await; assert!(result.is_none()); + assert_eq!(provider.buffered_len(&line_context_scope("C001")), 0); } #[tokio::test] - async fn group_message_dropped_when_no_mention_at_all() { + async fn group_message_buffers_when_bot_not_mentioned_and_context_enabled() { + let provider = enabled_line_context_provider(); + let event = make_group_text_event("hey everyone", false); + let result = build_gateway_event_from_line_event( + &event, + &reqwest::Client::new(), + None, + LINE_DATA_API_BASE, + Some(as_context_provider(&provider)), + TEST_LINE_CONTEXT_BOT_ID, + ) + .await; + assert!(result.is_none()); + assert_eq!( + provider.buffered_texts(&line_context_scope("C001")), + vec!["hey everyone"] + ); + } + + #[tokio::test] + async fn group_message_buffers_when_no_mention_at_all() { let event: LineEvent = serde_json::from_value(serde_json::json!({ "type": "message", "source": {"type": "group", "groupId": "C001", "userId": "U_sender"}, @@ -1261,6 +1405,8 @@ mod tests { &reqwest::Client::new(), None, LINE_DATA_API_BASE, + Some(as_context_provider(&enabled_line_context_provider())), + TEST_LINE_CONTEXT_BOT_ID, ) .await; assert!(result.is_none()); @@ -1279,8 +1425,234 @@ mod tests { &reqwest::Client::new(), None, LINE_DATA_API_BASE, + None, + TEST_LINE_CONTEXT_BOT_ID, ) .await; assert!(result.is_some()); } + + #[tokio::test] + async fn group_message_buffers_then_injects_context_on_later_mention() { + let provider = enabled_line_context_provider(); + + let first = make_group_text_event("今天下午兩點開會", false); + let first_result = build_gateway_event_from_line_event( + &first, + &reqwest::Client::new(), + None, + LINE_DATA_API_BASE, + Some(as_context_provider(&provider)), + TEST_LINE_CONTEXT_BOT_ID, + ) + .await; + assert!( + first_result.is_none(), + "unmentioned message should only be buffered" + ); + + let second = make_group_text_event("@Bot 幫我總結一下", true); + let second_result = build_gateway_event_from_line_event( + &second, + &reqwest::Client::new(), + None, + LINE_DATA_API_BASE, + Some(as_context_provider(&provider)), + TEST_LINE_CONTEXT_BOT_ID, + ) + .await + .expect("mentioned message should produce an event"); + + assert!(second_result + .content + .text + .contains("[Recent conversation context before this trigger]")); + assert!(second_result + .content + .text + .contains("U_sender: 今天下午兩點開會")); + assert!(second_result + .content + .text + .contains("[Current message - respond to this]")); + assert!(second_result.content.text.contains("@Bot 幫我總結一下")); + assert_eq!(provider.buffered_len(&line_context_scope("C001")), 0); + } + + #[tokio::test] + async fn direct_mention_without_buffer_keeps_original_text() { + let provider = enabled_line_context_provider(); + + let event = make_group_text_event("@Bot 現在狀況如何", true); + let result = build_gateway_event_from_line_event( + &event, + &reqwest::Client::new(), + None, + LINE_DATA_API_BASE, + Some(as_context_provider(&provider)), + TEST_LINE_CONTEXT_BOT_ID, + ) + .await + .expect("mentioned message should produce an event"); + + assert_eq!(result.content.text, "@Bot 現在狀況如何"); + } + + #[tokio::test] + async fn multiple_buffered_messages_preserve_order_on_injection() { + let provider = enabled_line_context_provider(); + + let first: LineEvent = serde_json::from_value(serde_json::json!({ + "type": "message", + "source": {"type": "group", "groupId": "C001", "userId": "U_alice"}, + "message": {"id": "msg001", "type": "text", "text": "第一句前文"} + })) + .unwrap(); + let second: LineEvent = serde_json::from_value(serde_json::json!({ + "type": "message", + "source": {"type": "group", "groupId": "C001", "userId": "U_bob"}, + "message": {"id": "msg002", "type": "text", "text": "第二句前文"} + })) + .unwrap(); + + assert!(build_gateway_event_from_line_event( + &first, + &reqwest::Client::new(), + None, + LINE_DATA_API_BASE, + Some(as_context_provider(&provider)), + TEST_LINE_CONTEXT_BOT_ID, + ) + .await + .is_none()); + assert!(build_gateway_event_from_line_event( + &second, + &reqwest::Client::new(), + None, + LINE_DATA_API_BASE, + Some(as_context_provider(&provider)), + TEST_LINE_CONTEXT_BOT_ID, + ) + .await + .is_none()); + + let mention = make_group_text_event("@Bot 幫我整理一下", true); + let result = build_gateway_event_from_line_event( + &mention, + &reqwest::Client::new(), + None, + LINE_DATA_API_BASE, + Some(as_context_provider(&provider)), + TEST_LINE_CONTEXT_BOT_ID, + ) + .await + .expect("mentioned message should produce an event"); + + let first_idx = result + .content + .text + .find("U_alice: 第一句前文") + .expect("first buffered line present"); + let second_idx = result + .content + .text + .find("U_bob: 第二句前文") + .expect("second buffered line present"); + let current_idx = result + .content + .text + .find("[Current message - respond to this]") + .expect("current message header present"); + + assert!( + first_idx < second_idx, + "buffered lines should preserve arrival order" + ); + assert!( + second_idx < current_idx, + "buffered context should appear before current message" + ); + } + + #[tokio::test] + async fn buffer_is_chat_local_and_not_reused_after_drain() { + let provider = enabled_line_context_provider(); + + let buffered: LineEvent = serde_json::from_value(serde_json::json!({ + "type": "message", + "source": {"type": "group", "groupId": "C001", "userId": "U_sender"}, + "message": {"id": "msg001", "type": "text", "text": "只屬於 C001 的前文"} + })) + .unwrap(); + + assert!(build_gateway_event_from_line_event( + &buffered, + &reqwest::Client::new(), + None, + LINE_DATA_API_BASE, + Some(as_context_provider(&provider)), + TEST_LINE_CONTEXT_BOT_ID, + ) + .await + .is_none()); + + let other_chat: LineEvent = serde_json::from_value(serde_json::json!({ + "type": "message", + "source": {"type": "group", "groupId": "C002", "userId": "U_sender"}, + "message": { + "id": "msg010", + "type": "text", + "text": "@Bot 另一個群組的 mention", + "mention": {"mentionees": [{"userId": "Ubot123", "type": "user", "isSelf": true}]} + } + })) + .unwrap(); + + let other_result = build_gateway_event_from_line_event( + &other_chat, + &reqwest::Client::new(), + None, + LINE_DATA_API_BASE, + Some(as_context_provider(&provider)), + TEST_LINE_CONTEXT_BOT_ID, + ) + .await + .expect("other chat mention should produce an event"); + assert!( + !other_result.content.text.contains("只屬於 C001 的前文"), + "buffered context must not leak across chats" + ); + + let same_chat = make_group_text_event("@Bot C001 的 mention", true); + let same_result = build_gateway_event_from_line_event( + &same_chat, + &reqwest::Client::new(), + None, + LINE_DATA_API_BASE, + Some(as_context_provider(&provider)), + TEST_LINE_CONTEXT_BOT_ID, + ) + .await + .expect("same chat mention should produce an event"); + assert!(same_result + .content + .text + .contains("U_sender: 只屬於 C001 的前文")); + + let second_same_chat = make_group_text_event("@Bot 再問一次", true); + let drained_result = build_gateway_event_from_line_event( + &second_same_chat, + &reqwest::Client::new(), + None, + LINE_DATA_API_BASE, + Some(as_context_provider(&provider)), + TEST_LINE_CONTEXT_BOT_ID, + ) + .await + .expect("second same-chat mention should produce an event"); + assert!( + !drained_result.content.text.contains("只屬於 C001 的前文"), + "buffer should be one-shot and drain after injection" + ); + } } diff --git a/gateway/src/adapters/teams.rs b/gateway/src/adapters/teams.rs index 09ac09df8..d70d6c6e2 100644 --- a/gateway/src/adapters/teams.rs +++ b/gateway/src/adapters/teams.rs @@ -666,6 +666,8 @@ mod tests { event_tx, reply_token_cache: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())), line_webhook_semaphore: Arc::new(tokio::sync::Semaphore::new(crate::LINE_WEBHOOK_CONCURRENCY_MAX)), + context_providers: Arc::new(std::collections::HashMap::new()), + context_bot_ids: Arc::new(std::collections::HashMap::new()), client: reqwest::Client::new(), }) } diff --git a/gateway/src/context/api_fetch.rs b/gateway/src/context/api_fetch.rs new file mode 100644 index 000000000..fce4ead80 --- /dev/null +++ b/gateway/src/context/api_fetch.rs @@ -0,0 +1,20 @@ +use super::{ContextFetchRequest, ContextMessage, ContextObserveRequest, ContextProvider}; + +#[derive(Clone, Debug, Default)] +#[allow(dead_code)] +pub struct ApiFetchContextProvider; + +#[async_trait::async_trait] +impl ContextProvider for ApiFetchContextProvider { + fn is_enabled(&self) -> bool { + false + } + + async fn observe(&self, _request: ContextObserveRequest) -> bool { + false + } + + async fn fetch_context(&self, _request: ContextFetchRequest) -> Option> { + None + } +} diff --git a/gateway/src/context/buffered.rs b/gateway/src/context/buffered.rs new file mode 100644 index 000000000..19b802e5a --- /dev/null +++ b/gateway/src/context/buffered.rs @@ -0,0 +1,243 @@ +use super::{ + ContextConfig, ContextFetchRequest, ContextMessage, ContextObserveRequest, ContextProvider, + ContextScope, +}; +use std::collections::{HashMap, VecDeque}; +use std::sync::Arc; +use std::time::Instant; + +#[derive(Clone, Debug)] +struct BufferedContextMessage { + message: ContextMessage, + observed_at: Instant, +} + +#[derive(Clone, Debug)] +pub struct BufferedContextProvider { + config: ContextConfig, + buffers: Arc>>>, +} + +impl BufferedContextProvider { + pub fn new(config: ContextConfig) -> Self { + Self { + config, + buffers: Arc::new(std::sync::Mutex::new(HashMap::new())), + } + } + + #[cfg(test)] + pub fn buffered_texts(&self, scope: &ContextScope) -> Vec { + let guard = self.buffers.lock().unwrap_or_else(|e| e.into_inner()); + guard + .get(scope) + .map(|entry| { + entry + .iter() + .map(|message| message.message.text.clone()) + .collect() + }) + .unwrap_or_default() + } + + #[cfg(test)] + pub fn buffered_len(&self, scope: &ContextScope) -> usize { + let guard = self.buffers.lock().unwrap_or_else(|e| e.into_inner()); + guard.get(scope).map(VecDeque::len).unwrap_or_default() + } + + fn prune_expired(&self, entry: &mut VecDeque) { + let now = Instant::now(); + entry.retain(|m| now.duration_since(m.observed_at).as_secs() < self.config.ttl_secs); + } + + fn enforce_limits(&self, entry: &mut VecDeque, max_messages: usize) { + while entry.len() > max_messages { + entry.pop_front(); + } + while entry.len() > 1 && context_char_count(entry) > self.config.max_chars { + entry.pop_front(); + } + } +} + +#[async_trait::async_trait] +impl ContextProvider for BufferedContextProvider { + fn is_enabled(&self) -> bool { + self.config.enabled + } + + async fn observe(&self, request: ContextObserveRequest) -> bool { + let trimmed = request.text.trim(); + if !self.config.enabled + || self.config.max_messages == 0 + || self.config.max_chars == 0 + || trimmed.is_empty() + { + return false; + } + + let mut guard = self.buffers.lock().unwrap_or_else(|e| e.into_inner()); + let entry = guard.entry(request.scope).or_default(); + self.prune_expired(entry); + + let bounded_text: String = trimmed.chars().take(self.config.max_chars).collect(); + entry.push_back(BufferedContextMessage { + message: ContextMessage { + sender_id: request.sender_id, + sender_label: request.sender_label, + text: bounded_text, + }, + observed_at: Instant::now(), + }); + self.enforce_limits(entry, self.config.max_messages); + true + } + + async fn fetch_context(&self, request: ContextFetchRequest) -> Option> { + if !self.config.enabled || self.config.max_messages == 0 || self.config.max_chars == 0 { + return None; + } + + let mut guard = self.buffers.lock().unwrap_or_else(|e| e.into_inner()); + let mut entry = guard.remove(&request.scope)?; + self.prune_expired(&mut entry); + + let max_messages = request + .limit + .unwrap_or(self.config.max_messages) + .min(self.config.max_messages); + self.enforce_limits(&mut entry, max_messages); + + if entry.is_empty() { + None + } else { + Some(entry.into_iter().map(|message| message.message).collect()) + } + } +} + +fn context_char_count(entry: &VecDeque) -> usize { + entry + .iter() + .map(|m| { + m.message.sender_label.chars().count() + + m.message.sender_id.chars().count() + + m.message.text.chars().count() + + 2 + }) + .sum() +} + +#[cfg(test)] +mod tests { + use super::*; + + fn enabled_provider(max_messages: usize, max_chars: usize) -> BufferedContextProvider { + BufferedContextProvider::new(ContextConfig { + enabled: true, + ttl_secs: 24 * 60 * 60, + max_messages, + max_chars, + }) + } + + fn scope(channel: &str) -> ContextScope { + ContextScope::new("line", channel, None, "line-default-bot") + } + + async fn observe(provider: &BufferedContextProvider, scope: ContextScope, text: &str) { + provider + .observe(ContextObserveRequest { + scope, + sender_id: "U_sender".into(), + sender_label: "U_sender".into(), + text: text.into(), + }) + .await; + } + + #[tokio::test] + async fn observe_fetches_and_drains_context() { + let provider = enabled_provider(50, 8_000); + let scope = scope("C001"); + + observe(&provider, scope.clone(), "hello").await; + let context = provider + .fetch_context(ContextFetchRequest { + scope: scope.clone(), + limit: None, + }) + .await + .expect("context should be returned"); + + assert_eq!(context.len(), 1); + assert_eq!(context[0].text, "hello"); + assert!(provider + .fetch_context(ContextFetchRequest { scope, limit: None }) + .await + .is_none()); + } + + #[tokio::test] + async fn observe_is_noop_when_disabled() { + let provider = BufferedContextProvider::new(ContextConfig::default()); + let scope = scope("C001"); + + observe(&provider, scope.clone(), "hello").await; + + assert_eq!(provider.buffered_len(&scope), 0); + } + + #[tokio::test] + async fn max_messages_keeps_latest_context() { + let provider = enabled_provider(2, 8_000); + let scope = scope("C001"); + + observe(&provider, scope.clone(), "first").await; + observe(&provider, scope.clone(), "second").await; + observe(&provider, scope.clone(), "third").await; + + assert_eq!(provider.buffered_texts(&scope), vec!["second", "third"]); + } + + #[tokio::test] + async fn max_chars_trims_old_context() { + let provider = enabled_provider(10, 20); + let scope = scope("C001"); + + observe(&provider, scope.clone(), "first long message").await; + observe(&provider, scope.clone(), "second long message").await; + + let texts = provider.buffered_texts(&scope); + assert!(texts.len() <= 2); + assert_eq!( + texts.last().map(String::as_str), + Some("second long message") + ); + } + + #[tokio::test] + async fn scope_isolation_prevents_cross_chat_leakage() { + let provider = enabled_provider(50, 8_000); + let first = scope("C001"); + let second = scope("C002"); + + observe(&provider, first.clone(), "first chat").await; + + assert!(provider + .fetch_context(ContextFetchRequest { + scope: second, + limit: None, + }) + .await + .is_none()); + assert!(provider + .fetch_context(ContextFetchRequest { + scope: first, + limit: None, + }) + .await + .is_some()); + } +} diff --git a/gateway/src/context/config.rs b/gateway/src/context/config.rs new file mode 100644 index 000000000..a9f27bcc0 --- /dev/null +++ b/gateway/src/context/config.rs @@ -0,0 +1,98 @@ +pub const DEFAULT_CONTEXT_TTL_HOURS: u64 = 24; +pub const DEFAULT_CONTEXT_MAX_MESSAGES: usize = 50; +pub const DEFAULT_CONTEXT_MAX_CHARS: usize = 8_000; + +#[derive(Clone, Debug)] +pub struct ContextConfig { + pub enabled: bool, + pub ttl_secs: u64, + pub max_messages: usize, + pub max_chars: usize, +} + +impl Default for ContextConfig { + fn default() -> Self { + Self { + enabled: false, + ttl_secs: DEFAULT_CONTEXT_TTL_HOURS * 60 * 60, + max_messages: DEFAULT_CONTEXT_MAX_MESSAGES, + max_chars: DEFAULT_CONTEXT_MAX_CHARS, + } + } +} + +impl ContextConfig { + pub fn from_env_with_prefixes(prefixes: &[&str]) -> Self { + let defaults = Self::default(); + let ttl_hours = + read_positive_env_u64(prefixes, "CONTEXT_TTL_HOURS", DEFAULT_CONTEXT_TTL_HOURS); + + Self { + enabled: read_bool_env(prefixes, "CONTEXT_ENABLED", defaults.enabled), + ttl_secs: ttl_hours.saturating_mul(60 * 60), + max_messages: read_positive_env_usize( + prefixes, + "CONTEXT_MAX_MESSAGES", + defaults.max_messages, + ), + max_chars: read_positive_env_usize(prefixes, "CONTEXT_MAX_CHARS", defaults.max_chars), + } + } +} + +fn read_bool_env(prefixes: &[&str], suffix: &str, default: bool) -> bool { + env_names(prefixes, suffix) + .into_iter() + .find_map(|name| { + std::env::var(name) + .ok() + .map(|v| v == "1" || v.eq_ignore_ascii_case("true")) + }) + .unwrap_or(default) +} + +fn read_positive_env_u64(prefixes: &[&str], suffix: &str, default: u64) -> u64 { + env_names(prefixes, suffix) + .into_iter() + .find_map(|name| { + std::env::var(name) + .ok() + .and_then(|v| v.parse::().ok()) + .filter(|v| *v > 0) + }) + .unwrap_or(default) +} + +fn read_positive_env_usize(prefixes: &[&str], suffix: &str, default: usize) -> usize { + env_names(prefixes, suffix) + .into_iter() + .find_map(|name| { + std::env::var(name) + .ok() + .and_then(|v| v.parse::().ok()) + .filter(|v| *v > 0) + }) + .unwrap_or(default) +} + +fn env_names(prefixes: &[&str], suffix: &str) -> Vec { + prefixes + .iter() + .map(|prefix| format!("{prefix}_{suffix}")) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_context_config_is_disabled_and_bounded() { + let config = ContextConfig::default(); + + assert!(!config.enabled); + assert_eq!(config.ttl_secs, 24 * 60 * 60); + assert_eq!(config.max_messages, 50); + assert_eq!(config.max_chars, 8_000); + } +} diff --git a/gateway/src/context/mod.rs b/gateway/src/context/mod.rs new file mode 100644 index 000000000..d728745a4 --- /dev/null +++ b/gateway/src/context/mod.rs @@ -0,0 +1,116 @@ +pub mod api_fetch; +pub mod buffered; +pub mod config; + +pub use buffered::BufferedContextProvider; +pub use config::ContextConfig; + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub struct ContextScope { + pub platform: String, + pub channel_id: String, + pub thread_id: Option, + pub bot_id: String, +} + +impl ContextScope { + pub fn new( + platform: impl Into, + channel_id: impl Into, + thread_id: Option, + bot_id: impl Into, + ) -> Self { + Self { + platform: platform.into(), + channel_id: channel_id.into(), + thread_id, + bot_id: bot_id.into(), + } + } +} + +#[derive(Clone, Debug)] +pub struct ContextMessage { + pub sender_id: String, + pub sender_label: String, + pub text: String, +} + +#[derive(Clone, Debug)] +pub struct ContextObserveRequest { + pub scope: ContextScope, + pub sender_id: String, + pub sender_label: String, + pub text: String, +} + +#[derive(Clone, Debug)] +pub struct ContextFetchRequest { + pub scope: ContextScope, + pub limit: Option, +} + +#[async_trait::async_trait] +pub trait ContextProvider: Send + Sync { + fn is_enabled(&self) -> bool; + + async fn observe(&self, request: ContextObserveRequest) -> bool; + + async fn fetch_context(&self, request: ContextFetchRequest) -> Option>; +} + +pub fn inject_context(context: &[ContextMessage], current_text: &str) -> String { + if context.is_empty() { + return current_text.to_string(); + } + + let mut lines = Vec::with_capacity(context.len() + 3); + lines.push("[Recent conversation context before this trigger]".to_string()); + for message in context { + let label = if message.sender_label.trim().is_empty() { + message.sender_id.as_str() + } else { + message.sender_label.as_str() + }; + lines.push(format!("{}: {}", label, message.text)); + } + lines.push(String::new()); + lines.push("[Current message - respond to this]".to_string()); + lines.push(current_text.to_string()); + lines.join("\n") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn inject_context_wraps_history_and_current_message() { + let text = inject_context( + &[ + ContextMessage { + sender_id: "U1".into(), + sender_label: "Alice".into(), + text: "first".into(), + }, + ContextMessage { + sender_id: "U2".into(), + sender_label: "Bob".into(), + text: "second".into(), + }, + ], + "@Bot summarize", + ); + + assert!(text.contains("[Recent conversation context before this trigger]")); + assert!(text.contains("Alice: first")); + assert!(text.contains("Bob: second")); + assert!(text.contains("[Current message - respond to this]")); + assert!(text.contains("@Bot summarize")); + } + + #[test] + fn inject_context_keeps_current_message_when_history_empty() { + assert_eq!(inject_context(&[], "hello"), "hello"); + } +} diff --git a/gateway/src/main.rs b/gateway/src/main.rs index c5dfec24f..8aba99f61 100644 --- a/gateway/src/main.rs +++ b/gateway/src/main.rs @@ -1,4 +1,5 @@ mod adapters; +mod context; mod media; mod schema; pub mod store; @@ -38,6 +39,9 @@ pub const REPLY_TOKEN_CACHE_MAX: usize = 10_000; /// fast 200 OK response path. pub const LINE_WEBHOOK_CONCURRENCY_MAX: usize = 8; +pub type ContextProviderRegistry = Arc>>; +pub type ContextBotIdRegistry = Arc>; + // --- App state (shared across all adapters) --- pub struct AppState { @@ -72,6 +76,10 @@ pub struct AppState { /// Limits concurrent post-ack LINE webhook processing so image bursts do not /// turn into unbounded download/decode work. pub line_webhook_semaphore: Arc, + /// Gateway-level context providers keyed by platform name. + pub context_providers: ContextProviderRegistry, + /// Stable bot identity per platform for context isolation in shared chats. + pub context_bot_ids: ContextBotIdRegistry, /// Shared HTTP client for media downloads and API calls pub client: reqwest::Client, } @@ -236,6 +244,26 @@ async fn main() -> Result<()> { let (event_tx, _) = broadcast::channel::(256); let reply_token_cache: ReplyTokenCache = Arc::new(std::sync::Mutex::new(HashMap::new())); + let line_context_config = + context::ContextConfig::from_env_with_prefixes(&["LINE_GROUP", "GATEWAY"]); + let line_context_bot_id = std::env::var("LINE_CONTEXT_BOT_ID") + .or_else(|_| std::env::var("LINE_BOT_ID")) + .unwrap_or_else(|_| "line-default-bot".into()); + if line_context_config.enabled { + info!( + ttl_secs = line_context_config.ttl_secs, + max_messages = line_context_config.max_messages, + max_chars = line_context_config.max_chars, + "line buffered context provider enabled" + ); + } + let mut context_providers = HashMap::>::new(); + context_providers.insert( + "line".into(), + Arc::new(context::BufferedContextProvider::new(line_context_config)), + ); + let mut context_bot_ids = HashMap::new(); + context_bot_ids.insert("line".into(), line_context_bot_id); let mut app = Router::new() .route("/ws", get(ws_handler)) @@ -410,6 +438,8 @@ async fn main() -> Result<()> { event_tx, reply_token_cache, line_webhook_semaphore: Arc::new(Semaphore::new(LINE_WEBHOOK_CONCURRENCY_MAX)), + context_providers: Arc::new(context_providers), + context_bot_ids: Arc::new(context_bot_ids), client, });