diff --git a/crates/coglet/src/orchestrator.rs b/crates/coglet/src/orchestrator.rs index c0aec8b9ed..8fb6735eb9 100644 --- a/crates/coglet/src/orchestrator.rs +++ b/crates/coglet/src/orchestrator.rs @@ -7,7 +7,7 @@ //! 4. Run event loop routing responses to predictions //! 5. On worker crash: fail all predictions, shut down -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::process::Stdio; use std::sync::Arc; use std::sync::Mutex as StdMutex; @@ -353,15 +353,18 @@ pub struct OrchestratorReady { pub setup_logs: String, } +type RegisterPredictionMessage = ( + SlotId, + Arc>, + tokio::sync::oneshot::Sender, + tokio::sync::oneshot::Sender<()>, +); + pub struct OrchestratorHandle { child: Child, ctrl_writer: Arc>>>, - register_tx: mpsc::Sender<( - SlotId, - Arc>, - tokio::sync::oneshot::Sender, - )>, + register_tx: mpsc::Sender, healthcheck_tx: mpsc::Sender>, cancel_tx: mpsc::Sender, slot_ids: Vec, @@ -375,10 +378,12 @@ impl Orchestrator for OrchestratorHandle { prediction: Arc>, idle_sender: tokio::sync::oneshot::Sender, ) { + let (ack_tx, ack_rx) = tokio::sync::oneshot::channel(); let _ = self .register_tx - .send((slot_id, prediction, idle_sender)) + .send((slot_id, prediction, idle_sender, ack_tx)) .await; + let _ = ack_rx.await; } async fn cancel_by_prediction_id(&self, prediction_id: &str) -> Result<(), OrchestratorError> { @@ -698,11 +703,7 @@ async fn run_event_loop( SlotId, FramedRead>, )>, - mut register_rx: mpsc::Receiver<( - SlotId, - Arc>, - tokio::sync::oneshot::Sender, - )>, + mut register_rx: mpsc::Receiver, mut healthcheck_rx: mpsc::Receiver>, mut cancel_rx: mpsc::Receiver, pool: Arc, @@ -718,6 +719,7 @@ async fn run_event_loop( let mut pending_healthchecks: Vec> = Vec::new(); let mut healthcheck_counter: u64 = 0; let mut pending_uploads: HashMap>> = HashMap::new(); + let mut pending_cancellations: HashSet = HashSet::new(); let (slot_msg_tx, mut slot_msg_rx) = mpsc::channel::<(SlotId, Result)>(100); @@ -923,17 +925,19 @@ async fn run_event_loop( } } None => { - tracing::debug!(%prediction_id, "Cancel requested for unknown prediction (may have already completed)"); + tracing::debug!(%prediction_id, "Cancel requested for unknown prediction; storing pending cancellation"); + pending_cancellations.insert(prediction_id); } } } - Some((slot_id, prediction, idle_sender)) = register_rx.recv() => { + Some((slot_id, prediction, idle_sender, registered_tx)) = register_rx.recv() => { let prediction_id = match try_lock_prediction(&prediction) { Some(p) => p.id().to_string(), None => { // Mutex poisoned during registration - prediction already failed tracing::error!(%slot_id, "Prediction mutex poisoned during registration"); + let _ = registered_tx.send(()); continue; } }; @@ -949,6 +953,24 @@ async fn run_event_loop( ); tracing::debug!(%slot_id, %prediction_id, "Registered prediction"); predictions.insert(slot_id, prediction); + let pending_cancel = pending_cancellations.remove(&prediction_id); + let _ = registered_tx.send(()); + if pending_cancel { + tracing::info!( + target: "coglet::prediction", + %prediction_id, + %slot_id, + "Applying pending cancellation" + ); + let mut writer = ctrl_writer.lock().await; + if let Err(e) = writer.send(ControlRequest::Cancel { slot: slot_id }).await { + tracing::error!( + %slot_id, + error = %e, + "Failed to send pending cancel request to worker" + ); + } + } } Some((slot_id, result)) = slot_msg_rx.recv() => { @@ -966,7 +988,7 @@ async fn run_event_loop( Ok(SlotResponse::LogLine { source, data }) => { let (prediction_id, poisoned) = if let Some(pred) = predictions.get(&slot_id) { if let Some(mut p) = try_lock_prediction(pred) { - p.append_log(&data); + p.append_log_source(source, &data); (Some(p.id().to_string()), false) } else { (None, true) @@ -1015,10 +1037,10 @@ async fn run_event_loop( predictions.remove(&slot_id); } } - Ok(SlotResponse::OutputChunk { output, index: _ }) => { + Ok(SlotResponse::OutputChunk { output, index }) => { let poisoned = if let Some(pred) = predictions.get(&slot_id) { if let Some(mut p) = try_lock_prediction(pred) { - p.append_output(output); + p.append_output_chunk(output, index); false } else { true diff --git a/crates/coglet/src/prediction.rs b/crates/coglet/src/prediction.rs index 81ab018b40..10c87fe83e 100644 --- a/crates/coglet/src/prediction.rs +++ b/crates/coglet/src/prediction.rs @@ -7,9 +7,11 @@ use std::time::Instant; use tokio::sync::Notify; pub use tokio_util::sync::CancellationToken; -use crate::bridge::protocol::MetricMode; +use crate::bridge::protocol::{LogSource, MetricMode}; use crate::webhook::{WebhookEventType, WebhookSender}; +const MAX_STREAM_HISTORY_EVENTS: usize = 1024; + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PredictionStatus { Starting, @@ -64,6 +66,71 @@ impl PredictionOutput { } } +#[derive(Debug, Clone)] +pub enum PredictionStreamEvent { + Start { + id: String, + status: String, + }, + Output { + chunk: serde_json::Value, + index: u64, + }, + Log { + source: LogSource, + data: String, + }, + Metric { + name: String, + value: serde_json::Value, + mode: MetricMode, + }, + Completed { + payload: serde_json::Value, + }, +} + +pub struct PredictionStreamReplay { + pub replay: Vec, + pub skipped: u64, + pub receiver: tokio::sync::broadcast::Receiver, +} + +impl PredictionStreamEvent { + pub fn event_name(&self) -> &'static str { + match self { + Self::Start { .. } => "start", + Self::Output { .. } => "output", + Self::Log { .. } => "log", + Self::Metric { .. } => "metric", + Self::Completed { .. } => "completed", + } + } + + pub fn json_data(&self) -> serde_json::Value { + match self { + Self::Start { id, status } => serde_json::json!({ + "id": id, + "status": status, + }), + Self::Output { chunk, index } => serde_json::json!({ + "chunk": chunk, + "index": index, + }), + Self::Log { source, data } => serde_json::json!({ + "source": source, + "data": data, + }), + Self::Metric { name, value, mode } => serde_json::json!({ + "name": name, + "value": value, + "mode": mode, + }), + Self::Completed { payload } => payload.clone(), + } + } +} + /// Prediction lifecycle state. pub struct Prediction { id: String, @@ -76,12 +143,17 @@ pub struct Prediction { error: Option, webhook: Option, completion: Arc, + stream_tx: tokio::sync::broadcast::Sender, + stream_history: Vec, + stream_history_skipped: u64, /// User-emitted metrics. Merged with system metrics (predict_time) in terminal response. metrics: HashMap, } impl Prediction { pub fn new(id: String, webhook: Option) -> Self { + let (stream_tx, _) = tokio::sync::broadcast::channel(1024); + Self { id, cancel_token: CancellationToken::new(), @@ -93,6 +165,9 @@ impl Prediction { error: None, webhook, completion: Arc::new(Notify::new()), + stream_tx, + stream_history: Vec::new(), + stream_history_skipped: 0, metrics: HashMap::new(), } } @@ -105,6 +180,31 @@ impl Prediction { self.cancel_token.clone() } + pub fn subscribe_stream(&self) -> tokio::sync::broadcast::Receiver { + self.stream_tx.subscribe() + } + + pub fn subscribe_stream_replay(&self) -> PredictionStreamReplay { + PredictionStreamReplay { + replay: self.stream_history.clone(), + skipped: self.stream_history_skipped, + receiver: self.stream_tx.subscribe(), + } + } + + pub fn stream_receiver_count(&self) -> usize { + self.stream_tx.receiver_count() + } + + fn emit_stream_event(&mut self, event: PredictionStreamEvent) { + if self.stream_history.len() == MAX_STREAM_HISTORY_EVENTS { + self.stream_history.remove(0); + self.stream_history_skipped += 1; + } + self.stream_history.push(event.clone()); + let _ = self.stream_tx.send(event); + } + pub fn is_canceled(&self) -> bool { self.cancel_token.is_cancelled() } @@ -119,6 +219,10 @@ impl Prediction { pub fn set_processing(&mut self) { self.status = PredictionStatus::Processing; + self.emit_stream_event(PredictionStreamEvent::Start { + id: self.id.clone(), + status: self.status.as_str().to_string(), + }); self.fire_webhook(WebhookEventType::Start); } @@ -128,6 +232,9 @@ impl Prediction { } self.status = PredictionStatus::Succeeded; self.output = Some(output); + self.emit_stream_event(PredictionStreamEvent::Completed { + payload: self.build_state_snapshot(), + }); self.fire_terminal_webhook(); // notify_one stores a permit so a future .notified().await will // consume it immediately. notify_waiters only wakes currently- @@ -144,6 +251,9 @@ impl Prediction { } self.status = PredictionStatus::Failed; self.error = Some(error); + self.emit_stream_event(PredictionStreamEvent::Completed { + payload: self.build_state_snapshot(), + }); self.fire_terminal_webhook(); self.completion.notify_one(); } @@ -153,6 +263,9 @@ impl Prediction { return; } self.status = PredictionStatus::Canceled; + self.emit_stream_event(PredictionStreamEvent::Completed { + payload: self.build_state_snapshot(), + }); self.fire_terminal_webhook(); self.completion.notify_one(); } @@ -162,7 +275,15 @@ impl Prediction { } pub fn append_log(&mut self, data: &str) { + self.append_log_source(LogSource::Stdout, data); + } + + pub fn append_log_source(&mut self, source: LogSource, data: &str) { self.logs.push_str(data); + self.emit_stream_event(PredictionStreamEvent::Log { + source, + data: data.to_string(), + }); self.fire_webhook(WebhookEventType::Logs); } @@ -184,6 +305,12 @@ impl Prediction { return; } + self.emit_stream_event(PredictionStreamEvent::Metric { + name: name.clone(), + value: value.clone(), + mode, + }); + // Dot-path resolution: "a.b.c" → nested objects let parts: Vec<&str> = name.split('.').collect(); if parts.len() > 1 { @@ -297,7 +424,16 @@ impl Prediction { } pub fn append_output(&mut self, output: serde_json::Value) { - self.outputs.push(output); + let index = self.outputs.len() as u64; + self.append_output_chunk(output, index); + } + + pub fn append_output_chunk(&mut self, output: serde_json::Value, index: u64) { + self.outputs.push(output.clone()); + self.emit_stream_event(PredictionStreamEvent::Output { + chunk: output, + index, + }); self.fire_webhook(WebhookEventType::Output); } @@ -484,6 +620,132 @@ mod tests { assert_eq!(pred.outputs().len(), 2); } + #[tokio::test] + async fn prediction_stream_emits_start_output_log_and_completed() { + let mut prediction = Prediction::new("pred_stream".to_string(), None); + let mut rx = prediction.subscribe_stream(); + + prediction.set_processing(); + prediction.append_output_chunk(serde_json::json!("hello"), 0); + prediction.append_log("loading\n"); + prediction.set_succeeded(PredictionOutput::Stream(vec![serde_json::json!("hello")])); + + let start = rx.recv().await.unwrap(); + assert_eq!(start.event_name(), "start"); + assert_eq!( + start.json_data(), + serde_json::json!({"id":"pred_stream","status":"processing"}) + ); + + let output = rx.recv().await.unwrap(); + assert_eq!(output.event_name(), "output"); + assert_eq!( + output.json_data(), + serde_json::json!({"chunk":"hello","index":0}) + ); + + let log = rx.recv().await.unwrap(); + assert_eq!(log.event_name(), "log"); + assert_eq!( + log.json_data(), + serde_json::json!({"source":"stdout","data":"loading\n"}) + ); + + let completed = rx.recv().await.unwrap(); + assert_eq!(completed.event_name(), "completed"); + assert_eq!(completed.json_data()["id"], "pred_stream"); + assert_eq!(completed.json_data()["status"], "succeeded"); + assert_eq!( + completed.json_data()["output"], + serde_json::json!(["hello"]) + ); + } + + #[tokio::test] + async fn prediction_stream_emits_metric_event() { + let mut prediction = Prediction::new("pred_metric".to_string(), None); + let mut rx = prediction.subscribe_stream(); + + prediction.set_metric( + "tokens".to_string(), + serde_json::json!(1), + MetricMode::Increment, + ); + + let event = rx.recv().await.unwrap(); + assert_eq!(event.event_name(), "metric"); + assert_eq!( + event.json_data(), + serde_json::json!({ + "name":"tokens", + "value":1, + "mode":"increment" + }) + ); + } + + #[tokio::test] + async fn prediction_stream_preserves_log_source() { + let mut prediction = Prediction::new("pred_log_source".to_string(), None); + let mut rx = prediction.subscribe_stream(); + + prediction.append_log_source(crate::bridge::protocol::LogSource::Stderr, "warning\n"); + + let event = rx.recv().await.unwrap(); + assert_eq!(event.event_name(), "log"); + assert_eq!( + event.json_data(), + serde_json::json!({"source":"stderr","data":"warning\n"}) + ); + } + + #[tokio::test] + async fn prediction_stream_replay_includes_already_emitted_events() { + let mut prediction = Prediction::new("pred_replay".to_string(), None); + + prediction.set_processing(); + prediction.append_output_chunk(serde_json::json!("hello"), 0); + prediction.set_succeeded(PredictionOutput::Stream(vec![serde_json::json!("hello")])); + + let replay = prediction.subscribe_stream_replay(); + let events: Vec<&str> = replay + .replay + .iter() + .map(|event| event.event_name()) + .collect(); + + assert_eq!(events, vec!["start", "output", "completed"]); + assert_eq!(replay.skipped, 0); + assert_eq!( + replay.replay[1].json_data(), + serde_json::json!({"chunk":"hello","index":0}) + ); + assert_eq!(replay.replay[2].json_data()["status"], "succeeded"); + } + + #[tokio::test] + async fn prediction_stream_replay_is_bounded_to_recent_events() { + let mut prediction = Prediction::new("pred_replay_bounded".to_string(), None); + + prediction.set_processing(); + for index in 0..1100 { + prediction.append_output_chunk(serde_json::json!(index), index); + } + + let replay = prediction.subscribe_stream_replay(); + + assert_eq!(replay.replay.len(), MAX_STREAM_HISTORY_EVENTS); + assert_eq!(replay.skipped, 77); + assert_eq!( + replay.replay[0].json_data(), + serde_json::json!({"chunk":76,"index":76}) + ); + assert_eq!( + replay.replay[MAX_STREAM_HISTORY_EVENTS - 1].json_data(), + serde_json::json!({"chunk":1099,"index":1099}) + ); + } + #[tokio::test] async fn wait_returns_immediately_if_terminal() { let mut pred = Prediction::new("test".to_string(), None); diff --git a/crates/coglet/src/service.rs b/crates/coglet/src/service.rs index 300ccab9e7..30a1c85163 100644 --- a/crates/coglet/src/service.rs +++ b/crates/coglet/src/service.rs @@ -79,6 +79,7 @@ struct PredictionEntry { prediction: Arc>, cancel_token: CancellationToken, input: serde_json::Value, + cancel_on_stream_drop: bool, } /// Handle to a submitted prediction for cancellation on disconnect. @@ -106,6 +107,51 @@ impl PredictionHandle { } } +pub struct PredictionStreamSubscription { + id: String, + replay: Vec, + skipped: u64, + receiver: tokio::sync::broadcast::Receiver, + guard: PredictionStreamGuard, +} + +impl PredictionStreamSubscription { + pub fn prediction_id(&self) -> &str { + &self.id + } + + pub fn into_parts( + self, + ) -> ( + Vec, + u64, + tokio::sync::broadcast::Receiver, + PredictionStreamGuard, + ) { + (self.replay, self.skipped, self.receiver, self.guard) + } +} + +pub struct PredictionStreamGuard { + id: String, + service: Arc, + cancel_on_stream_drop: bool, +} + +impl Drop for PredictionStreamGuard { + fn drop(&mut self) { + if !self.cancel_on_stream_drop { + return; + } + + if self.service.stream_receiver_count(&self.id) == 0 + && !self.service.prediction_is_terminal(&self.id) + { + self.service.cancel(&self.id); + } + } +} + /// Guard for sync predictions - cancels on drop unless disarmed. /// /// When the HTTP connection drops (client disconnect), axum drops the @@ -415,6 +461,7 @@ impl PredictionService { id: String, input: serde_json::Value, webhook: Option, + cancel_on_stream_drop: bool, ) -> Result<(PredictionHandle, UnregisteredPredictionSlot), CreatePredictionError> { let health = *self.health.read().await; if health != Health::Ready { @@ -442,6 +489,7 @@ impl PredictionService { prediction: prediction_arc, cancel_token: cancel_token.clone(), input, + cancel_on_stream_drop, }, ); @@ -469,6 +517,46 @@ impl PredictionService { Some(response) } + pub fn subscribe_prediction_stream( + self: &Arc, + id: &str, + ) -> Option { + let entry = self.predictions.get(id)?; + let stream = entry.prediction.lock().ok()?.subscribe_stream_replay(); + let cancel_on_stream_drop = entry.cancel_on_stream_drop; + Some(PredictionStreamSubscription { + id: id.to_string(), + replay: stream.replay, + skipped: stream.skipped, + receiver: stream.receiver, + guard: PredictionStreamGuard { + id: id.to_string(), + service: Arc::clone(self), + cancel_on_stream_drop, + }, + }) + } + + fn stream_receiver_count(&self, id: &str) -> usize { + self.predictions + .get(id) + .and_then(|entry| { + entry + .prediction + .lock() + .ok() + .map(|p| p.stream_receiver_count()) + }) + .unwrap_or(0) + } + + fn prediction_is_terminal(&self, id: &str) -> bool { + self.predictions + .get(id) + .and_then(|entry| entry.prediction.lock().ok().map(|p| p.is_terminal())) + .unwrap_or(true) + } + /// Run a prediction to completion via orchestrator. pub async fn predict( &self, @@ -541,6 +629,22 @@ impl PredictionService { ))); } + let was_cancelled_before_send = try_lock_prediction(&prediction_arc) + .map(|p| p.is_canceled()) + .unwrap_or(false); + if was_cancelled_before_send + && let Err(e) = state + .orchestrator + .cancel_by_prediction_id(&prediction_id) + .await + { + tracing::error!( + prediction_id = %prediction_id, + error = %e, + "Failed to forward pending cancellation after registration" + ); + } + // Wait for prediction to complete // Check if already terminal first to avoid race with fast completions let (already_terminal, completion) = { @@ -760,6 +864,103 @@ mod tests { } } + struct CountingCancelOrchestrator { + cancel_count: AtomicUsize, + } + + impl CountingCancelOrchestrator { + fn new() -> Self { + Self { + cancel_count: AtomicUsize::new(0), + } + } + + fn cancel_count(&self) -> usize { + self.cancel_count.load(Ordering::SeqCst) + } + } + + #[async_trait::async_trait] + impl Orchestrator for CountingCancelOrchestrator { + async fn register_prediction( + &self, + _slot_id: SlotId, + _prediction: Arc>, + _idle_sender: tokio::sync::oneshot::Sender, + ) { + } + + async fn cancel_by_prediction_id( + &self, + _prediction_id: &str, + ) -> Result<(), crate::orchestrator::OrchestratorError> { + self.cancel_count.fetch_add(1, Ordering::SeqCst); + Ok(()) + } + + async fn healthcheck( + &self, + ) -> Result { + Ok(HealthcheckResult::healthy()) + } + + async fn shutdown(&self) -> Result<(), crate::orchestrator::OrchestratorError> { + Ok(()) + } + } + + struct CancelRecordingOrchestrator { + cancel_count: AtomicUsize, + prediction: std::sync::Mutex>>>, + } + + impl CancelRecordingOrchestrator { + fn new() -> Self { + Self { + cancel_count: AtomicUsize::new(0), + prediction: std::sync::Mutex::new(None), + } + } + + fn cancel_count(&self) -> usize { + self.cancel_count.load(Ordering::SeqCst) + } + } + + #[async_trait::async_trait] + impl Orchestrator for CancelRecordingOrchestrator { + async fn register_prediction( + &self, + slot_id: SlotId, + prediction: Arc>, + idle_sender: tokio::sync::oneshot::Sender, + ) { + *self.prediction.lock().unwrap() = Some(prediction); + let _ = idle_sender.send(InactiveSlotIdleToken::new(slot_id).activate()); + } + + async fn cancel_by_prediction_id( + &self, + _prediction_id: &str, + ) -> Result<(), crate::orchestrator::OrchestratorError> { + self.cancel_count.fetch_add(1, Ordering::SeqCst); + if let Some(prediction) = self.prediction.lock().unwrap().as_ref() { + prediction.lock().unwrap().set_canceled(); + } + Ok(()) + } + + async fn healthcheck( + &self, + ) -> Result { + Ok(HealthcheckResult::healthy()) + } + + async fn shutdown(&self) -> Result<(), crate::orchestrator::OrchestratorError> { + Ok(()) + } + } + async fn create_test_pool(num_slots: usize) -> Arc { use crate::bridge::codec::JsonCodec; use crate::bridge::protocol::SlotRequest; @@ -863,7 +1064,7 @@ mod tests { let svc = PredictionService::new_no_pool(); let result = svc - .submit_prediction("test".to_string(), serde_json::json!({}), None) + .submit_prediction("test".to_string(), serde_json::json!({}), None, false) .await; assert!(matches!(result, Err(CreatePredictionError::NotReady))); } @@ -908,7 +1109,7 @@ mod tests { svc.set_health(Health::Ready).await; let (handle, _slot) = svc - .submit_prediction("test-1".to_string(), serde_json::json!({}), None) + .submit_prediction("test-1".to_string(), serde_json::json!({}), None, false) .await .unwrap(); @@ -916,6 +1117,137 @@ mod tests { assert!(svc.prediction_exists("test-1")); } + #[tokio::test] + async fn subscribe_prediction_stream_returns_receiver_for_existing_prediction() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(MockOrchestrator::new()); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (_handle, _slot) = svc + .submit_prediction("stream-test".to_string(), serde_json::json!({}), None, true) + .await + .unwrap(); + + let subscription = svc.subscribe_prediction_stream("stream-test").unwrap(); + assert_eq!(subscription.prediction_id(), "stream-test"); + } + + #[tokio::test] + async fn dropping_only_sync_stream_subscription_cancels_prediction() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(CountingCancelOrchestrator::new()); + let orchestrator_ref = Arc::clone(&orchestrator); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (_handle, _slot) = svc + .submit_prediction("sync-stream".to_string(), serde_json::json!({}), None, true) + .await + .unwrap(); + + let subscription = svc.subscribe_prediction_stream("sync-stream").unwrap(); + drop(subscription); + tokio::time::sleep(Duration::from_millis(25)).await; + + assert_eq!(orchestrator_ref.cancel_count(), 1); + } + + #[tokio::test] + async fn dropping_async_json_stream_subscription_does_not_cancel_prediction() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(CountingCancelOrchestrator::new()); + let orchestrator_ref = Arc::clone(&orchestrator); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (_handle, _slot) = svc + .submit_prediction( + "async-json-stream".to_string(), + serde_json::json!({}), + None, + false, + ) + .await + .unwrap(); + + let subscription = svc + .subscribe_prediction_stream("async-json-stream") + .unwrap(); + drop(subscription); + tokio::time::sleep(Duration::from_millis(25)).await; + + assert_eq!(orchestrator_ref.cancel_count(), 0); + } + + #[tokio::test] + async fn dropping_live_sse_stream_subscription_cancels_prediction() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(CountingCancelOrchestrator::new()); + let orchestrator_ref = Arc::clone(&orchestrator); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (_handle, _slot) = svc + .submit_prediction( + "live-sse-stream".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + + let subscription = svc.subscribe_prediction_stream("live-sse-stream").unwrap(); + drop(subscription); + tokio::time::sleep(Duration::from_millis(25)).await; + + assert_eq!(orchestrator_ref.cancel_count(), 1); + } + + #[tokio::test] + async fn dropping_completed_sync_stream_subscription_does_not_cancel_prediction() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(CountingCancelOrchestrator::new()); + let orchestrator_ref = Arc::clone(&orchestrator); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (_handle, _slot) = svc + .submit_prediction( + "completed-sync-stream".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + + { + let entry = svc.predictions.get("completed-sync-stream").unwrap(); + let mut prediction = entry.prediction.lock().unwrap(); + prediction.set_succeeded(crate::PredictionOutput::Single(serde_json::json!("done"))); + } + + let subscription = svc + .subscribe_prediction_stream("completed-sync-stream") + .unwrap(); + drop(subscription); + tokio::time::sleep(Duration::from_millis(25)).await; + + assert_eq!(orchestrator_ref.cancel_count(), 0); + } + #[tokio::test] async fn submit_returns_at_capacity_when_no_slots() { let svc = PredictionService::new_no_pool(); @@ -927,13 +1259,13 @@ mod tests { // First prediction takes the only slot let (_handle1, _slot1) = svc - .submit_prediction("test-1".to_string(), serde_json::json!({}), None) + .submit_prediction("test-1".to_string(), serde_json::json!({}), None, false) .await .unwrap(); // Second should fail with AtCapacity let result = svc - .submit_prediction("test-2".to_string(), serde_json::json!({}), None) + .submit_prediction("test-2".to_string(), serde_json::json!({}), None, false) .await; assert!(matches!(result, Err(CreatePredictionError::AtCapacity))); } @@ -953,6 +1285,7 @@ mod tests { "test-1".to_string(), serde_json::json!({"prompt": "hello"}), None, + false, ) .await .unwrap(); @@ -970,6 +1303,42 @@ mod tests { assert_eq!(orch_ref.register_count(), 1); } + #[tokio::test] + async fn predict_forwards_cancel_token_set_before_registration() { + let svc = PredictionService::new_no_pool(); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(CancelRecordingOrchestrator::new()); + let orchestrator_ref = Arc::clone(&orchestrator); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (handle, slot) = svc + .submit_prediction( + "pre-register-cancel".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + handle.cancel_token().cancel(); + + let result = tokio::time::timeout( + Duration::from_millis(100), + svc.predict( + slot, + serde_json::json!({}), + std::collections::HashMap::new(), + ), + ) + .await + .expect("prediction should observe cancellation after registration"); + + assert!(matches!(result, Err(PredictionError::Cancelled))); + assert_eq!(orchestrator_ref.cancel_count(), 1); + } + #[tokio::test] async fn health_shows_busy_when_all_slots_used() { let svc = PredictionService::new_no_pool(); @@ -986,7 +1355,7 @@ mod tests { // After acquiring slot let (_handle, _slot) = svc - .submit_prediction("test-1".to_string(), serde_json::json!({}), None) + .submit_prediction("test-1".to_string(), serde_json::json!({}), None, false) .await .unwrap(); let health = svc.health().await; @@ -1009,6 +1378,7 @@ mod tests { "test-1".to_string(), serde_json::json!({"prompt": "hello"}), None, + false, ) .await .unwrap(); @@ -1048,6 +1418,7 @@ mod tests { "test-1".to_string(), serde_json::json!({"prompt": "hello"}), None, + false, ) .await .unwrap(); @@ -1087,6 +1458,7 @@ mod tests { "test-1".to_string(), serde_json::json!({"prompt": "hello"}), None, + false, ) .await .unwrap(); @@ -1113,7 +1485,12 @@ mod tests { svc.set_health(Health::Ready).await; let (handle, _slot) = svc - .submit_prediction("test-cancel".to_string(), serde_json::json!({}), None) + .submit_prediction( + "test-cancel".to_string(), + serde_json::json!({}), + None, + false, + ) .await .unwrap(); @@ -1139,7 +1516,7 @@ mod tests { svc.set_health(Health::Ready).await; let (handle, _slot) = svc - .submit_prediction("test-guard".to_string(), serde_json::json!({}), None) + .submit_prediction("test-guard".to_string(), serde_json::json!({}), None, false) .await .unwrap(); @@ -1163,7 +1540,12 @@ mod tests { svc.set_health(Health::Ready).await; let (handle, _slot) = svc - .submit_prediction("test-disarm".to_string(), serde_json::json!({}), None) + .submit_prediction( + "test-disarm".to_string(), + serde_json::json!({}), + None, + false, + ) .await .unwrap(); @@ -1187,7 +1569,12 @@ mod tests { svc.set_health(Health::Ready).await; let (_handle, _slot) = svc - .submit_prediction("test-remove".to_string(), serde_json::json!({}), None) + .submit_prediction( + "test-remove".to_string(), + serde_json::json!({}), + None, + false, + ) .await .unwrap(); diff --git a/crates/coglet/src/transport/http/routes.rs b/crates/coglet/src/transport/http/routes.rs index aa9875340e..869f0307a6 100644 --- a/crates/coglet/src/transport/http/routes.rs +++ b/crates/coglet/src/transport/http/routes.rs @@ -1,12 +1,17 @@ //! HTTP route handlers. +use std::convert::Infallible; use std::sync::Arc; +use std::time::Duration; use axum::{ Router, extract::{DefaultBodyLimit, Path, State}, http::{HeaderMap, StatusCode}, - response::{IntoResponse, Json}, + response::{ + IntoResponse, Json, Response, + sse::{Event, KeepAlive, Sse}, + }, routing::{get, post, put}, }; use serde::{Deserialize, Serialize}; @@ -15,7 +20,9 @@ use serde::{Deserialize, Serialize}; use crate::health::Health; use crate::health::{HealthResponse, SetupResult}; use crate::predictor::PredictionError; -use crate::service::{CreatePredictionError, HealthSnapshot, PredictionService}; +use crate::service::{ + CreatePredictionError, HealthSnapshot, PredictionService, PredictionStreamSubscription, +}; use crate::version::VersionInfo; use crate::webhook::{TraceContext, WebhookConfig, WebhookEventType, WebhookSender}; @@ -209,6 +216,43 @@ fn should_respond_async(headers: &HeaderMap) -> bool { .unwrap_or(false) } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PredictionResponseMode { + SyncJson, + AsyncJson, + AsyncSse, +} + +fn wants_sse(headers: &HeaderMap) -> bool { + headers + .get(axum::http::header::ACCEPT) + .and_then(|value| value.to_str().ok()) + .map(|accept| { + accept + .split(',') + .any(|part| part.trim().split(';').next() == Some("text/event-stream")) + }) + .unwrap_or(false) +} + +fn prediction_response_mode(headers: &HeaderMap) -> PredictionResponseMode { + if wants_sse(headers) { + PredictionResponseMode::AsyncSse + } else if should_respond_async(headers) { + PredictionResponseMode::AsyncJson + } else { + PredictionResponseMode::SyncJson + } +} + +fn json_response_mode(headers: &HeaderMap) -> PredictionResponseMode { + if should_respond_async(headers) { + PredictionResponseMode::AsyncJson + } else { + PredictionResponseMode::SyncJson + } +} + fn extract_trace_context(headers: &HeaderMap) -> TraceContext { TraceContext { traceparent: headers @@ -226,7 +270,7 @@ async fn create_prediction( State(service): State>, headers: HeaderMap, body: Option>, -) -> impl IntoResponse { +) -> Response { let request = body.map(|Json(r)| r).unwrap_or_else(|| PredictionRequest { id: None, input: serde_json::json!({}), @@ -235,7 +279,7 @@ async fn create_prediction( webhook_events_filter: default_webhook_events_filter(), }); let prediction_id = request.id.unwrap_or_else(generate_prediction_id); - let respond_async = should_respond_async(&headers); + let response_mode = prediction_response_mode(&headers); let trace_context = extract_trace_context(&headers); create_prediction_with_id( service, @@ -244,7 +288,7 @@ async fn create_prediction( request.context, request.webhook, request.webhook_events_filter, - respond_async, + response_mode, trace_context, false, ) @@ -256,7 +300,7 @@ async fn create_prediction_idempotent( Path(prediction_id): Path, headers: HeaderMap, body: Option>, -) -> impl IntoResponse { +) -> Response { let request = body.map(|Json(r)| r).unwrap_or_else(|| PredictionRequest { id: None, input: serde_json::json!({}), @@ -277,15 +321,20 @@ async fn create_prediction_idempotent( "type": "value_error" }] })), - ); + ) + .into_response(); } + let response_mode = prediction_response_mode(&headers); + // Check if prediction with this ID is already in-flight if let Some(response) = service.get_prediction_response(&prediction_id) { - return (StatusCode::ACCEPTED, Json(response)); + if response_mode == PredictionResponseMode::AsyncSse { + return stream_prediction_response(service, &prediction_id); + } + return (StatusCode::ACCEPTED, Json(response)).into_response(); } - let respond_async = should_respond_async(&headers); let trace_context = extract_trace_context(&headers); create_prediction_with_id( service, @@ -294,7 +343,7 @@ async fn create_prediction_idempotent( request.context, request.webhook, request.webhook_events_filter, - respond_async, + response_mode, trace_context, false, ) @@ -333,10 +382,10 @@ async fn create_prediction_with_id( context: std::collections::HashMap, webhook: Option, webhook_events_filter: Vec, - respond_async: bool, + response_mode: PredictionResponseMode, trace_context: TraceContext, is_training: bool, -) -> (StatusCode, Json) { +) -> Response { // Strip unknown fields and validate in one pass. Unknown inputs are // silently dropped to match Replicate's historical API behavior. let (stripped, validation_result) = if is_training { @@ -365,7 +414,8 @@ async fn create_prediction_with_id( return ( StatusCode::UNPROCESSABLE_ENTITY, Json(serde_json::json!({ "detail": detail })), - ); + ) + .into_response(); } let webhook_sender = build_webhook_sender( @@ -376,7 +426,12 @@ async fn create_prediction_with_id( // Submit prediction: creates Prediction, acquires slot, registers in service let (handle, unregistered_slot) = match service - .submit_prediction(prediction_id.clone(), input.clone(), webhook_sender) + .submit_prediction( + prediction_id.clone(), + input.clone(), + webhook_sender, + response_mode != PredictionResponseMode::AsyncJson, + ) .await { Ok(r) => r, @@ -388,7 +443,8 @@ async fn create_prediction_with_id( "error": msg, "status": "failed" })), - ); + ) + .into_response(); } Err(CreatePredictionError::AtCapacity) => { return ( @@ -397,14 +453,21 @@ async fn create_prediction_with_id( "error": "At capacity - all prediction slots busy", "status": "failed" })), - ); + ) + .into_response(); } }; let prediction = unregistered_slot.prediction(); // Async mode: spawn background task, return immediately - if respond_async { + if response_mode != PredictionResponseMode::SyncJson { + let sse_subscription = if response_mode == PredictionResponseMode::AsyncSse { + service.subscribe_prediction_stream(&prediction_id) + } else { + None + }; + let service_clone = Arc::clone(&service); let id_for_cleanup = prediction_id.clone(); let context_async = context.clone(); @@ -417,13 +480,25 @@ async fn create_prediction_with_id( service_clone.remove_prediction(&id_for_cleanup); }); + if response_mode == PredictionResponseMode::AsyncSse { + let Some(subscription) = sse_subscription else { + return ( + StatusCode::NOT_FOUND, + Json(serde_json::json!({"error": "Prediction not found"})), + ) + .into_response(); + }; + return stream_prediction_subscription_response(subscription); + } + return ( StatusCode::ACCEPTED, Json(serde_json::json!({ "id": prediction_id, "status": "starting" })), - ); + ) + .into_response(); } // Sync mode: spawn prediction into a background task so the slot lifetime @@ -489,6 +564,7 @@ async fn create_prediction_with_id( "metrics": metrics })), ) + .into_response() } Err(PredictionError::InvalidInput(msg)) => { let metrics = build_metrics(&user_metrics); @@ -502,6 +578,7 @@ async fn create_prediction_with_id( "metrics": metrics })), ) + .into_response() } Err(PredictionError::NotReady) => { let msg = PredictionError::NotReady.to_string(); @@ -514,6 +591,7 @@ async fn create_prediction_with_id( "status": "failed" })), ) + .into_response() } Err(PredictionError::Failed(msg)) => { let metrics = build_metrics(&user_metrics); @@ -528,6 +606,7 @@ async fn create_prediction_with_id( "metrics": metrics })), ) + .into_response() } Err(PredictionError::Cancelled) => { let metrics = build_metrics(&user_metrics); @@ -540,6 +619,7 @@ async fn create_prediction_with_id( "metrics": metrics })), ) + .into_response() } } } @@ -557,6 +637,104 @@ async fn cancel_prediction( } } +fn stream_event_to_sse(event: crate::prediction::PredictionStreamEvent) -> Event { + Event::default() + .event(event.event_name()) + .json_data(event.json_data()) + .expect("prediction stream events serialize to JSON") +} + +fn prediction_sse_stream( + subscription: PredictionStreamSubscription, +) -> impl futures::Stream> { + let (replay, replay_skipped, receiver, guard) = subscription.into_parts(); + + struct StreamState { + replay: std::collections::VecDeque, + replay_skipped: u64, + receiver: tokio::sync::broadcast::Receiver, + _guard: crate::service::PredictionStreamGuard, + done: bool, + } + + futures::stream::unfold( + StreamState { + replay: replay.into(), + replay_skipped, + receiver, + _guard: guard, + done: false, + }, + |mut state| async move { + if state.done { + return None; + } + + if state.replay_skipped > 0 { + let skipped = state.replay_skipped; + state.replay_skipped = 0; + state.done = true; + let event = Event::default() + .event("error") + .json_data(serde_json::json!({ + "error": "SSE stream replay truncated; events were dropped", + "skipped": skipped, + })) + .expect("SSE replay truncation error serializes to JSON"); + return Some((Ok(event), state)); + } + + if let Some(event) = state.replay.pop_front() { + state.done = event.event_name() == "completed"; + return Some((Ok(stream_event_to_sse(event)), state)); + } + + match state.receiver.recv().await { + Ok(event) => { + state.done = event.event_name() == "completed"; + Some((Ok(stream_event_to_sse(event)), state)) + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => { + tracing::warn!(skipped, "SSE prediction stream receiver lagged"); + state.done = true; + // In the future, this could become backpressure or cursor-based replay. + let event = Event::default() + .event("error") + .json_data(serde_json::json!({ + "error": "SSE stream lagged; events were dropped", + "skipped": skipped, + })) + .expect("SSE lag error serializes to JSON"); + Some((Ok(event), state)) + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => None, + } + }, + ) +} + +fn stream_prediction_response(service: Arc, prediction_id: &str) -> Response { + let Some(subscription) = service.subscribe_prediction_stream(prediction_id) else { + return ( + StatusCode::NOT_FOUND, + Json(serde_json::json!({"error": "Prediction not found"})), + ) + .into_response(); + }; + + stream_prediction_subscription_response(subscription) +} + +fn stream_prediction_subscription_response(subscription: PredictionStreamSubscription) -> Response { + Sse::new(prediction_sse_stream(subscription)) + .keep_alive( + KeepAlive::new() + .interval(Duration::from_secs(15)) + .text("keep-alive"), + ) + .into_response() +} + async fn shutdown(State(service): State>) -> impl IntoResponse { tracing::info!("Shutdown requested via HTTP"); service.trigger_shutdown(); @@ -582,7 +760,7 @@ async fn create_training( State(service): State>, headers: HeaderMap, body: Option>, -) -> impl IntoResponse { +) -> Response { let request = body.map(|Json(r)| r).unwrap_or_else(|| PredictionRequest { id: None, input: serde_json::json!({}), @@ -591,7 +769,7 @@ async fn create_training( webhook_events_filter: default_webhook_events_filter(), }); let prediction_id = request.id.unwrap_or_else(generate_prediction_id); - let respond_async = should_respond_async(&headers); + let response_mode = json_response_mode(&headers); let trace_context = extract_trace_context(&headers); create_prediction_with_id( service, @@ -600,7 +778,7 @@ async fn create_training( request.context, request.webhook, request.webhook_events_filter, - respond_async, + response_mode, trace_context, true, ) @@ -612,7 +790,7 @@ async fn create_training_idempotent( Path(training_id): Path, headers: HeaderMap, body: Option>, -) -> impl IntoResponse { +) -> Response { let request = body.map(|Json(r)| r).unwrap_or_else(|| PredictionRequest { id: None, input: serde_json::json!({}), @@ -633,15 +811,16 @@ async fn create_training_idempotent( "type": "value_error" }] })), - ); + ) + .into_response(); } // Idempotent: return existing state if already submitted if let Some(response) = service.get_prediction_response(&training_id) { - return (StatusCode::ACCEPTED, Json(response)); + return (StatusCode::ACCEPTED, Json(response)).into_response(); } - let respond_async = should_respond_async(&headers); + let response_mode = json_response_mode(&headers); let trace_context = extract_trace_context(&headers); create_prediction_with_id( service, @@ -650,7 +829,7 @@ async fn create_training_idempotent( request.context, request.webhook, request.webhook_events_filter, - respond_async, + response_mode, trace_context, true, ) @@ -955,6 +1134,217 @@ mod tests { assert_eq!(json["status"], "starting"); } + #[tokio::test] + async fn prediction_post_with_sse_accept_returns_sse() { + let service = create_ready_service().await; + let app = routes(service); + + let response = app + .oneshot( + Request::post("/predictions") + .header("content-type", "application/json") + .header("accept", "text/event-stream") + .body(Body::from(r#"{"input":{}}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let content_type = response.headers().get("content-type").unwrap(); + assert!( + content_type + .to_str() + .unwrap() + .starts_with("text/event-stream"), + "unexpected content-type: {:?}", + content_type + ); + + let body = response.into_body(); + let bytes = body.collect().await.unwrap().to_bytes(); + let sse = String::from_utf8(bytes.to_vec()).unwrap(); + assert!(sse.contains("event: completed"), "SSE body: {sse}"); + assert!(sse.contains(r#""status":"succeeded""#), "SSE body: {sse}"); + } + + #[tokio::test] + async fn lagged_prediction_sse_stream_emits_error_and_closes() { + let service = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(MockOrchestrator::never_complete()); + service.set_orchestrator(pool, orchestrator).await; + service.set_health(Health::Ready).await; + + let (_handle, slot) = service + .submit_prediction( + "lagged-stream".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + let subscription = service + .subscribe_prediction_stream("lagged-stream") + .unwrap(); + + { + let prediction = slot.prediction(); + let mut prediction = prediction.lock().unwrap(); + for index in 0..1030 { + prediction.append_output_chunk(serde_json::json!(index), index); + } + } + + let response = Sse::new(prediction_sse_stream(subscription)).into_response(); + let collected = + tokio::time::timeout(Duration::from_millis(100), response.into_body().collect()) + .await + .expect("lagged SSE stream should close after emitting an error") + .unwrap(); + let sse = String::from_utf8(collected.to_bytes().to_vec()).unwrap(); + assert!(sse.contains("event: error"), "SSE body: {sse}"); + assert!(sse.contains("SSE stream lagged"), "SSE body: {sse}"); + assert!(sse.contains("skipped"), "SSE body: {sse}"); + } + + #[tokio::test] + async fn truncated_replay_prediction_sse_stream_emits_error_and_closes() { + let service = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(MockOrchestrator::never_complete()); + service.set_orchestrator(pool, orchestrator).await; + service.set_health(Health::Ready).await; + + let (_handle, slot) = service + .submit_prediction( + "truncated-replay".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + + { + let prediction = slot.prediction(); + let mut prediction = prediction.lock().unwrap(); + for index in 0..1030 { + prediction.append_output_chunk(serde_json::json!(index), index); + } + } + + let subscription = service + .subscribe_prediction_stream("truncated-replay") + .unwrap(); + let response = Sse::new(prediction_sse_stream(subscription)).into_response(); + let collected = + tokio::time::timeout(Duration::from_millis(100), response.into_body().collect()) + .await + .expect("truncated replay SSE stream should close after emitting an error") + .unwrap(); + let sse = String::from_utf8(collected.to_bytes().to_vec()).unwrap(); + assert!(sse.contains("event: error"), "SSE body: {sse}"); + assert!( + sse.contains("SSE stream replay truncated"), + "SSE body: {sse}" + ); + assert!(sse.contains("skipped"), "SSE body: {sse}"); + } + + #[tokio::test] + async fn prediction_put_with_sse_accept_returns_sse() { + let service = create_ready_service().await; + let app = routes(service); + + let response = app + .oneshot( + Request::put("/predictions/sse-put") + .header("content-type", "application/json") + .header("accept", "text/event-stream") + .body(Body::from(r#"{"input":{}}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let content_type = response.headers().get("content-type").unwrap(); + assert!( + content_type + .to_str() + .unwrap() + .starts_with("text/event-stream"), + "unexpected content-type: {:?}", + content_type + ); + } + + #[tokio::test] + async fn prediction_put_existing_with_sse_accept_returns_sse() { + let service = create_ready_service().await; + let (_handle, _slot) = service + .submit_prediction( + "existing-sse-put".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + let app = routes(service); + + let response = app + .oneshot( + Request::put("/predictions/existing-sse-put") + .header("content-type", "application/json") + .header("accept", "text/event-stream") + .body(Body::from(r#"{"input":{}}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let content_type = response.headers().get("content-type").unwrap(); + assert!( + content_type + .to_str() + .unwrap() + .starts_with("text/event-stream"), + "unexpected content-type: {:?}", + content_type + ); + } + + #[tokio::test] + async fn stream_prediction_route_is_removed() { + let service = create_ready_service().await; + let (_handle, _slot) = service + .submit_prediction( + "removed-stream-route".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + let app = routes(service); + + let response = app + .oneshot( + Request::get("/predictions/removed-stream-route/stream") + .header("accept", "text/event-stream") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + } + #[tokio::test] async fn prediction_with_custom_id() { let service = create_ready_service().await; diff --git a/examples/streaming-text/.dockerignore b/examples/streaming-text/.dockerignore new file mode 100644 index 0000000000..9118d0e055 --- /dev/null +++ b/examples/streaming-text/.dockerignore @@ -0,0 +1,3 @@ +.cog/ +__pycache__/ +*.pyc diff --git a/examples/streaming-text/.gitignore b/examples/streaming-text/.gitignore new file mode 100644 index 0000000000..9118d0e055 --- /dev/null +++ b/examples/streaming-text/.gitignore @@ -0,0 +1,3 @@ +.cog/ +__pycache__/ +*.pyc diff --git a/examples/streaming-text/README.md b/examples/streaming-text/README.md new file mode 100644 index 0000000000..645664289b --- /dev/null +++ b/examples/streaming-text/README.md @@ -0,0 +1,51 @@ +# examples/streaming-text + +Streaming text generation with `HuggingFaceTB/SmolLM2-135M-Instruct`. + +This example shows how a Cog predictor can yield text chunks as a model generates them, and how to consume those chunks with Server-Sent Events. + +## Run a normal prediction + +From this directory: + +```sh +cog predict -i prompt="Write a short haiku about databases" +``` + +This returns the final accumulated output after the prediction completes. + +## Stream output over HTTP + +Start the server: + +```sh +cog serve +``` + +Create a prediction and request an SSE response: + +```sh +curl -N -X PUT http://localhost:5000/predictions/streaming-demo \ + -H 'Content-Type: application/json' \ + -H 'Accept: text/event-stream' \ + -d '{"input":{"prompt":"Write a short haiku about databases","max_new_tokens":96}}' +``` + +The response includes `output` events as chunks are generated, followed by a `completed` event: + +```text +event: output +data: {"chunk":"Silent","index":0} + +event: output +data: {"chunk":" rows","index":1} + +event: completed +data: {"id":"streaming-demo","status":"succeeded",...} +``` + +## How it works + +`predict.py` returns `Iterator[str]`. Each `yield` becomes one streamed output chunk. The example uses Hugging Face `TextIteratorStreamer` to receive generated text from `model.generate()` while generation is still running. + +The normal prediction response still contains the accumulated output for compatibility. Requesting `Accept: text/event-stream` is useful when clients want to display tokens as they arrive. diff --git a/examples/streaming-text/cog.yaml b/examples/streaming-text/cog.yaml new file mode 100644 index 0000000000..68866b7aeb --- /dev/null +++ b/examples/streaming-text/cog.yaml @@ -0,0 +1,7 @@ +# Streaming text generation example using a small open-weight language model. + +build: + python_version: "3.12" + python_requirements: requirements.txt + +predict: "predict.py:Predictor" diff --git a/examples/streaming-text/predict.py b/examples/streaming-text/predict.py new file mode 100644 index 0000000000..3ed8772e54 --- /dev/null +++ b/examples/streaming-text/predict.py @@ -0,0 +1,64 @@ +from threading import Thread +from typing import Iterator + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer + +from cog import BasePredictor, Input + +MODEL_NAME = "HuggingFaceTB/SmolLM2-135M-Instruct" + + +class Predictor(BasePredictor): + def setup(self) -> None: + self.device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = torch.float16 if self.device == "cuda" else torch.float32 + + self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + self.model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + torch_dtype=dtype, + ).to(self.device) + self.model.eval() + + def predict( + self, + prompt: str = Input(description="Prompt to complete"), + max_new_tokens: int = Input( + description="Maximum number of tokens to generate", + default=128, + ge=1, + le=512, + ), + ) -> Iterator[str]: + messages = [{"role": "user", "content": prompt}] + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + inputs = self.tokenizer([text], return_tensors="pt").to(self.device) + streamer = TextIteratorStreamer( + self.tokenizer, + skip_prompt=True, + skip_special_tokens=True, + ) + + generation_kwargs = { + **inputs, + "streamer": streamer, + "max_new_tokens": max_new_tokens, + "do_sample": True, + "temperature": 0.7, + "top_p": 0.9, + "pad_token_id": self.tokenizer.eos_token_id, + } + + thread = Thread(target=self.model.generate, kwargs=generation_kwargs) + thread.start() + + for chunk in streamer: + if chunk: + yield chunk + + thread.join() diff --git a/examples/streaming-text/requirements.txt b/examples/streaming-text/requirements.txt new file mode 100644 index 0000000000..916933a3de --- /dev/null +++ b/examples/streaming-text/requirements.txt @@ -0,0 +1,3 @@ +torch==2.7.1 +transformers==4.51.3 +accelerate==1.6.0 diff --git a/integration-tests/tests/sse_streaming_output.txtar b/integration-tests/tests/sse_streaming_output.txtar new file mode 100644 index 0000000000..32c757008f --- /dev/null +++ b/integration-tests/tests/sse_streaming_output.txtar @@ -0,0 +1,32 @@ +# Test that async generator output is available when predictions are created with SSE accept. + +[short] skip 'requires Docker build' + +cog serve --upload-url http://unused/ + +curl -H Accept:text/event-stream PUT /predictions/sse-stream-test '{"id":"sse-stream-test","input":{}}' +stdout 'event: output' +stdout 'data: {"chunk":"chunk-1","index":0}' +stdout 'event: output' +stdout 'data: {"chunk":"chunk-2","index":1}' +stdout 'event: completed' +stdout '"status":"succeeded"' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +import time +from typing import Iterator + +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self) -> Iterator[str]: + time.sleep(0.25) + yield "chunk-1" + time.sleep(0.25) + yield "chunk-2"