diff --git a/Cargo.lock b/Cargo.lock index e83b480c9f..d4faa4d3ec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -18538,7 +18538,6 @@ dependencies = [ "quickcheck", "quickcheck_macros", "ractor", - "ractor-supervisor", "rodio", "sentry", "serde", diff --git a/apps/desktop/src/components/main/body/sessions/outer-header/listen.tsx b/apps/desktop/src/components/main/body/sessions/outer-header/listen.tsx index c002177087..fda0197689 100644 --- a/apps/desktop/src/components/main/body/sessions/outer-header/listen.tsx +++ b/apps/desktop/src/components/main/body/sessions/outer-header/listen.tsx @@ -1,7 +1,8 @@ import { useHover } from "@uidotdev/usehooks"; -import { MicOff } from "lucide-react"; -import { useCallback, useEffect, useRef } from "react"; +import { AlertTriangle, MicOff } from "lucide-react"; +import { useCallback, useEffect, useMemo, useRef } from "react"; +import type { DegradedError } from "@hypr/plugin-listener"; import { Tooltip, TooltipContent, @@ -188,24 +189,58 @@ function StartButton({ sessionId }: { sessionId: string }) { ); } +function formatDegradedError(error: DegradedError): string { + switch (error.type) { + case "authentication_failed": + return `Authentication failed for ${error.provider}`; + case "upstream_unavailable": + return error.message; + case "connection_timeout": + return "Connection timed out"; + case "stream_error": + return error.message; + case "channel_overflow": + return "Audio channel overflow"; + default: + return "Transcription unavailable"; + } +} + function InMeetingIndicator({ sessionId }: { sessionId: string }) { const [ref, hovered] = useHover(); + const openNew = useTabs((state) => state.openNew); - const { mode, stop, amplitude, muted } = useListener((state) => ({ - mode: state.getSessionMode(sessionId), - stop: state.stop, - amplitude: state.live.amplitude, - muted: state.live.muted, - })); + const { mode, stop, amplitude, muted, degradedError } = useListener( + (state) => ({ + mode: state.getSessionMode(sessionId), + stop: state.stop, + amplitude: state.live.amplitude, + muted: state.live.muted, + degradedError: state.live.degradedError, + }), + ); const active = mode === "active" || mode === "finalizing"; const finalizing = mode === "finalizing"; + const isDegraded = !!degradedError; + + const degradedMessage = useMemo( + () => + degradedError + ? `Transcription degraded: ${formatDegradedError(degradedError)}` + : null, + [degradedError], + ); + + const handleConfigureAction = useCallback(() => { + openNew({ type: "ai", state: { tab: "transcription" } }); + }, [openNew]); if (!active) { return null; } - return ( + const button = ( ); + + if (!isDegraded) { + return button; + } + + return ( + + + {button} + + + + + + ); } diff --git a/apps/desktop/src/components/main/body/sessions/shared.tsx b/apps/desktop/src/components/main/body/sessions/shared.tsx index 268389ff22..67b053f34c 100644 --- a/apps/desktop/src/components/main/body/sessions/shared.tsx +++ b/apps/desktop/src/components/main/body/sessions/shared.tsx @@ -58,7 +58,7 @@ export function RecordingIcon() { export function useListenButtonState(sessionId: string) { const sessionMode = useListener((state) => state.getSessionMode(sessionId)); - const lastError = useListener((state) => state.live.lastError); + const degradedError = useListener((state) => state.live.degradedError); const active = sessionMode === "active" || sessionMode === "finalizing"; const batching = sessionMode === "running_batch"; @@ -81,8 +81,18 @@ export function useListenButtonState(sessionId: string) { isOfflineWithCloudModel; let warningMessage = ""; - if (lastError) { - warningMessage = `Session failed: ${lastError}`; + if (degradedError) { + const errorMessage = + degradedError.type === "authentication_failed" + ? `Authentication failed for ${degradedError.provider}` + : degradedError.type === "upstream_unavailable" + ? degradedError.message + : degradedError.type === "connection_timeout" + ? "Connection timed out" + : degradedError.type === "stream_error" + ? degradedError.message + : "Transcription unavailable"; + warningMessage = `Transcription degraded: ${errorMessage}`; } else if (isLocalServerLoading) { warningMessage = "Local STT server is starting up..."; } else if (isOfflineWithCloudModel) { diff --git a/apps/desktop/src/store/zustand/listener/general.ts b/apps/desktop/src/store/zustand/listener/general.ts index 318f158a6c..a7449c1871 100644 --- a/apps/desktop/src/store/zustand/listener/general.ts +++ b/apps/desktop/src/store/zustand/listener/general.ts @@ -7,10 +7,10 @@ import { commands as detectCommands } from "@hypr/plugin-detect"; import { commands as hooksCommands } from "@hypr/plugin-hooks"; import { commands as iconCommands } from "@hypr/plugin-icon"; import { + type DegradedError, commands as listenerCommands, events as listenerEvents, type SessionDataEvent, - type SessionErrorEvent, type SessionLifecycleEvent, type SessionParams, type SessionProgressEvent, @@ -49,7 +49,7 @@ export type GeneralState = { intervalId?: NodeJS.Timeout; sessionId: string | null; muted: boolean; - lastError: string | null; + degradedError: DegradedError | null; device: string | null; }; }; @@ -77,7 +77,7 @@ const initialState: GeneralState = { seconds: 0, sessionId: null, muted: false, - lastError: null, + degradedError: null, device: null, }, }; @@ -85,7 +85,6 @@ const initialState: GeneralState = { type EventListeners = { lifecycle: (payload: SessionLifecycleEvent) => void; progress: (payload: SessionProgressEvent) => void; - error: (payload: SessionErrorEvent) => void; data: (payload: SessionDataEvent) => void; }; @@ -101,9 +100,6 @@ const listenToAllSessionEvents = ( listenerEvents.sessionProgressEvent.listen(({ payload }) => handlers.progress(payload), ), - listenerEvents.sessionErrorEvent.listen(({ payload }) => - handlers.error(payload), - ), listenerEvents.sessionDataEvent.listen(({ payload }) => handlers.data(payload), ), @@ -184,6 +180,7 @@ export const createGeneralSlice = < draft.live.seconds = 0; draft.live.intervalId = intervalId; draft.live.sessionId = targetSessionId; + draft.live.degradedError = payload.error ?? null; }), ); } else if (payload.type === "finalizing") { @@ -212,7 +209,7 @@ export const createGeneralSlice = < draft.live.loadingPhase = "idle"; draft.live.sessionId = null; draft.live.eventUnlisteners = undefined; - draft.live.lastError = payload.error ?? null; + draft.live.degradedError = null; draft.live.device = null; }), ); @@ -230,7 +227,7 @@ export const createGeneralSlice = < set((state) => mutate(state, (draft) => { draft.live.loadingPhase = "audio_initializing"; - draft.live.lastError = null; + draft.live.degradedError = null; }), ); } else if (payload.type === "audio_ready") { @@ -255,29 +252,6 @@ export const createGeneralSlice = < } }; - const handleErrorEvent = (payload: SessionErrorEvent) => { - if (payload.session_id !== targetSessionId) { - return; - } - - if (payload.type === "audio_error") { - set((state) => - mutate(state, (draft) => { - draft.live.lastError = payload.error; - if (payload.is_fatal) { - draft.live.loading = false; - } - }), - ); - } else if (payload.type === "connection_error") { - set((state) => - mutate(state, (draft) => { - draft.live.lastError = payload.error; - }), - ); - } - }; - const handleDataEvent = (payload: SessionDataEvent) => { if (payload.session_id !== targetSessionId) { return; @@ -308,7 +282,6 @@ export const createGeneralSlice = < const unlisteners = yield* listenToAllSessionEvents({ lifecycle: handleLifecycleEvent, progress: handleProgressEvent, - error: handleErrorEvent, data: handleDataEvent, }); @@ -384,7 +357,7 @@ export const createGeneralSlice = < draft.live.seconds = 0; draft.live.sessionId = null; draft.live.muted = initialState.live.muted; - draft.live.lastError = null; + draft.live.degradedError = null; draft.live.device = null; }), ); diff --git a/plugins/listener/Cargo.toml b/plugins/listener/Cargo.toml index afb6dc0abd..b9a9cc2651 100644 --- a/plugins/listener/Cargo.toml +++ b/plugins/listener/Cargo.toml @@ -65,7 +65,6 @@ hound = { workspace = true } vorbis_rs = { workspace = true } ractor = { workspace = true } -ractor-supervisor = { workspace = true } futures-util = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } diff --git a/plugins/listener/js/bindings.gen.ts b/plugins/listener/js/bindings.gen.ts index 78bdab0f91..f4a78386eb 100644 --- a/plugins/listener/js/bindings.gen.ts +++ b/plugins/listener/js/bindings.gen.ts @@ -93,12 +93,10 @@ async listDocumentedLanguageCodesLive() : Promise> { export const events = __makeEvents__<{ sessionDataEvent: SessionDataEvent, -sessionErrorEvent: SessionErrorEvent, sessionLifecycleEvent: SessionLifecycleEvent, sessionProgressEvent: SessionProgressEvent }>({ sessionDataEvent: "plugin:listener:session-data-event", -sessionErrorEvent: "plugin:listener:session-error-event", sessionLifecycleEvent: "plugin:listener:session-lifecycle-event", sessionProgressEvent: "plugin:listener:session-progress-event" }) @@ -109,9 +107,10 @@ sessionProgressEvent: "plugin:listener:session-progress-event" /** user-defined types **/ +export type CriticalError = { message: string } +export type DegradedError = { type: "authentication_failed"; provider: string } | { type: "upstream_unavailable"; message: string } | { type: "connection_timeout" } | { type: "stream_error"; message: string } | { type: "channel_overflow" } export type SessionDataEvent = { type: "audio_amplitude"; session_id: string; mic: number; speaker: number } | { type: "mic_muted"; session_id: string; value: boolean } | { type: "stream_response"; session_id: string; response: StreamResponse } -export type SessionErrorEvent = { type: "audio_error"; session_id: string; error: string; device: string | null; is_fatal: boolean } | { type: "connection_error"; session_id: string; error: string } -export type SessionLifecycleEvent = { type: "inactive"; session_id: string; error: string | null } | { type: "active"; session_id: string } | { type: "finalizing"; session_id: string } +export type SessionLifecycleEvent = { type: "inactive"; session_id: string; error?: CriticalError | null } | { type: "active"; session_id: string; error?: DegradedError | null } | { type: "finalizing"; session_id: string } export type SessionParams = { session_id: string; languages: string[]; onboarding: boolean; record_enabled: boolean; model: string; base_url: string; api_key: string; keywords: string[] } export type SessionProgressEvent = { type: "audio_initializing"; session_id: string } | { type: "audio_ready"; session_id: string; device: string | null } | { type: "connecting"; session_id: string } | { type: "connected"; session_id: string; adapter: string } export type StreamAlternatives = { transcript: string; words: StreamWord[]; confidence: number; languages?: string[] } diff --git a/plugins/listener/src/actors/listener/adapters.rs b/plugins/listener/src/actors/listener/adapters.rs index e548f0f27a..8cefa8266c 100644 --- a/plugins/listener/src/actors/listener/adapters.rs +++ b/plugins/listener/src/actors/listener/adapters.rs @@ -2,7 +2,6 @@ use std::time::{Duration, UNIX_EPOCH}; use bytes::Bytes; use ractor::{ActorProcessingErr, ActorRef}; -use tauri_specta::Event; use owhisper_client::{ AdapterKind, ArgmaxAdapter, AssemblyAIAdapter, DeepgramAdapter, ElevenLabsAdapter, @@ -11,12 +10,26 @@ use owhisper_client::{ use owhisper_interface::stream::Extra; use owhisper_interface::{ControlMessage, MixedMessage}; -use super::stream::process_stream; -use super::{ - ChannelSender, DEVICE_FINGERPRINT_HEADER, LISTEN_CONNECT_TIMEOUT, ListenerArgs, ListenerMsg, - actor_error, -}; -use crate::SessionErrorEvent; +use super::stream::{ChannelSender, LISTEN_CONNECT_TIMEOUT, process_stream}; +use super::{ListenerArgs, ListenerMsg}; +use crate::DegradedError; + +const DEVICE_FINGERPRINT_HEADER: &str = "x-device-fingerprint"; + +#[derive(Debug)] +struct ListenerInitError(String); + +impl std::fmt::Display for ListenerInitError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl std::error::Error for ListenerInitError {} + +fn actor_error(msg: impl Into) -> ActorProcessingErr { + Box::new(ListenerInitError(msg.into())) +} pub(super) async fn spawn_rx_task( args: ListenerArgs, @@ -88,12 +101,12 @@ pub(super) async fn spawn_rx_task( Ok((result.0, result.1, result.2, adapter_kind.to_string())) } -fn build_listen_params(args: &ListenerArgs) -> owhisper_interface::ListenParams { +pub(super) fn build_listen_params(args: &ListenerArgs) -> owhisper_interface::ListenParams { let redemption_time_ms = if args.onboarding { "60" } else { "400" }; owhisper_interface::ListenParams { model: Some(args.model.clone()), languages: args.languages.clone(), - sample_rate: super::super::SAMPLE_RATE, + sample_rate: crate::actors::SAMPLE_RATE, keywords: args.keywords.clone(), custom_query: Some(std::collections::HashMap::from([( "redemption_time_ms".to_string(), @@ -103,7 +116,7 @@ fn build_listen_params(args: &ListenerArgs) -> owhisper_interface::ListenParams } } -fn build_extra(args: &ListenerArgs) -> (f64, Extra) { +pub(super) fn build_extra(args: &ListenerArgs) -> (f64, Extra) { let session_offset_secs = args.session_started_at.elapsed().as_secs_f64(); let started_unix_millis = args .session_started_at_unix @@ -156,21 +169,19 @@ async fn spawn_rx_task_single_with_adapter( timeout_secs = LISTEN_CONNECT_TIMEOUT.as_secs_f32(), "listen_ws_connect_timeout(single)" ); - let _ = (SessionErrorEvent::ConnectionError { - session_id: args.session_id.clone(), - error: "listen_ws_connect_timeout".to_string(), - }) - .emit(&args.app); - return Err(actor_error("listen_ws_connect_timeout")); + return Err(actor_error( + serde_json::to_string(&DegradedError::ConnectionTimeout) + .unwrap_or_else(|_| "connection_timeout".to_string()), + )); } Ok(Err(e)) => { tracing::error!(session_id = %args.session_id, error = ?e, "listen_ws_connect_failed(single)"); - let _ = (SessionErrorEvent::ConnectionError { - session_id: args.session_id.clone(), - error: format!("listen_ws_connect_failed: {:?}", e), - }) - .emit(&args.app); - return Err(actor_error(format!("listen_ws_connect_failed: {:?}", e))); + return Err(actor_error( + serde_json::to_string(&DegradedError::UpstreamUnavailable { + message: format!("{:?}", e), + }) + .unwrap_or_else(|_| format!("{:?}", e)), + )); } Ok(Ok(res)) => res, }; @@ -228,21 +239,19 @@ async fn spawn_rx_task_dual_with_adapter( timeout_secs = LISTEN_CONNECT_TIMEOUT.as_secs_f32(), "listen_ws_connect_timeout(dual)" ); - let _ = (SessionErrorEvent::ConnectionError { - session_id: args.session_id.clone(), - error: "listen_ws_connect_timeout".to_string(), - }) - .emit(&args.app); - return Err(actor_error("listen_ws_connect_timeout")); + return Err(actor_error( + serde_json::to_string(&DegradedError::ConnectionTimeout) + .unwrap_or_else(|_| "connection_timeout".to_string()), + )); } Ok(Err(e)) => { tracing::error!(session_id = %args.session_id, error = ?e, "listen_ws_connect_failed(dual)"); - let _ = (SessionErrorEvent::ConnectionError { - session_id: args.session_id.clone(), - error: format!("listen_ws_connect_failed: {:?}", e), - }) - .emit(&args.app); - return Err(actor_error(format!("listen_ws_connect_failed: {:?}", e))); + return Err(actor_error( + serde_json::to_string(&DegradedError::UpstreamUnavailable { + message: format!("{:?}", e), + }) + .unwrap_or_else(|_| format!("{:?}", e)), + )); } Ok(Ok(res)) => res, }; diff --git a/plugins/listener/src/actors/listener/mod.rs b/plugins/listener/src/actors/listener/mod.rs index e47a390e78..d97e391378 100644 --- a/plugins/listener/src/actors/listener/mod.rs +++ b/plugins/listener/src/actors/listener/mod.rs @@ -1,25 +1,24 @@ mod adapters; mod stream; -use std::time::{Duration, Instant, SystemTime}; +use std::time::{Instant, SystemTime}; use bytes::Bytes; -use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef, SupervisionEvent}; +use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef}; use tauri_specta::Event; use tokio::time::error::Elapsed; use tracing::Instrument; +use owhisper_interface::MixedMessage; use owhisper_interface::stream::StreamResponse; -use owhisper_interface::{ControlMessage, MixedMessage}; -use super::root::session_span; -use crate::{SessionDataEvent, SessionErrorEvent, SessionProgressEvent}; +use super::session::session_span; +use crate::{DegradedError, SessionDataEvent, SessionProgressEvent}; use adapters::spawn_rx_task; +use stream::ChannelSender; -pub(super) const LISTEN_STREAM_TIMEOUT: Duration = Duration::from_secs(15 * 60); -pub(super) const LISTEN_CONNECT_TIMEOUT: Duration = Duration::from_secs(5); -pub(super) const DEVICE_FINGERPRINT_HEADER: &str = "x-device-fingerprint"; +const AUDIO_SEND_FAILURE_THRESHOLD: u32 = 10; pub enum ListenerMsg { AudioSingle(Bytes), @@ -50,11 +49,7 @@ pub struct ListenerState { tx: ChannelSender, rx_task: tokio::task::JoinHandle<()>, shutdown_tx: Option>, -} - -pub(super) enum ChannelSender { - Single(tokio::sync::mpsc::Sender>), - Dual(tokio::sync::mpsc::Sender>), + consecutive_send_failures: u32, } pub struct ListenerActor; @@ -65,21 +60,6 @@ impl ListenerActor { } } -#[derive(Debug)] -pub(super) struct ListenerInitError(pub(super) String); - -impl std::fmt::Display for ListenerInitError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -impl std::error::Error for ListenerInitError {} - -pub(super) fn actor_error(msg: impl Into) -> ActorProcessingErr { - Box::new(ListenerInitError(msg.into())) -} - #[ractor::async_trait] impl Actor for ListenerActor { type Msg = ListenerMsg; @@ -120,6 +100,7 @@ impl Actor for ListenerActor { tx, rx_task, shutdown_tx: Some(shutdown_tx), + consecutive_send_failures: 0, }; Ok(state) @@ -152,13 +133,53 @@ impl Actor for ListenerActor { match message { ListenerMsg::AudioSingle(audio) => { if let ChannelSender::Single(tx) = &state.tx { - let _ = tx.try_send(MixedMessage::Audio(audio)); + match tx.try_send(MixedMessage::Audio(audio)) { + Ok(()) => { + state.consecutive_send_failures = 0; + } + Err(e) => { + state.consecutive_send_failures += 1; + if state.consecutive_send_failures >= AUDIO_SEND_FAILURE_THRESHOLD { + tracing::error!( + consecutive_failures = state.consecutive_send_failures, + error = ?e, + "audio_send_failures_exceeded_threshold_entering_degraded_mode" + ); + stop_with_degraded_error(myself, DegradedError::ChannelOverflow); + return Ok(()); + } + tracing::warn!( + consecutive_failures = state.consecutive_send_failures, + "audio_send_failed" + ); + } + } } } ListenerMsg::AudioDual(mic, spk) => { if let ChannelSender::Dual(tx) = &state.tx { - let _ = tx.try_send(MixedMessage::Audio((mic, spk))); + match tx.try_send(MixedMessage::Audio((mic, spk))) { + Ok(()) => { + state.consecutive_send_failures = 0; + } + Err(e) => { + state.consecutive_send_failures += 1; + if state.consecutive_send_failures >= AUDIO_SEND_FAILURE_THRESHOLD { + tracing::error!( + consecutive_failures = state.consecutive_send_failures, + error = ?e, + "audio_send_failures_exceeded_threshold_entering_degraded_mode" + ); + stop_with_degraded_error(myself, DegradedError::ChannelOverflow); + return Ok(()); + } + tracing::warn!( + consecutive_failures = state.consecutive_send_failures, + "audio_send_failed" + ); + } + } } } @@ -175,19 +196,12 @@ impl Actor for ListenerActor { %provider, "stream_provider_error" ); - let _ = (SessionErrorEvent::ConnectionError { - session_id: state.args.session_id.clone(), - error: format!( - "[{}] {} (code: {})", - provider, - error_message, - error_code - .map(|c| c.to_string()) - .unwrap_or_else(|| "none".to_string()) - ), - }) - .emit(&state.args.app); - myself.stop(Some(format!("{}: {}", provider, error_message))); + stop_with_degraded_error( + myself, + DegradedError::AuthenticationFailed { + provider: provider.clone(), + }, + ); return Ok(()); } @@ -212,40 +226,30 @@ impl Actor for ListenerActor { } ListenerMsg::StreamError(error) => { - tracing::info!("listen_stream_error: {}", error); - myself.stop(None); + tracing::warn!("listen_stream_error: {}", error); + stop_with_degraded_error(myself, DegradedError::StreamError { message: error }); } ListenerMsg::StreamEnded => { - tracing::info!("listen_stream_ended"); - myself.stop(None); + tracing::warn!("listen_stream_ended_unexpectedly"); + stop_with_degraded_error( + myself, + DegradedError::UpstreamUnavailable { + message: "stream ended unexpectedly".to_string(), + }, + ); } - ListenerMsg::StreamTimeout(elapsed) => { - tracing::info!("listen_stream_timeout: {}", elapsed); - myself.stop(None); + ListenerMsg::StreamTimeout(_elapsed) => { + tracing::warn!("listen_stream_timeout"); + stop_with_degraded_error(myself, DegradedError::ConnectionTimeout); } } Ok(()) } +} - async fn handle_supervisor_evt( - &self, - myself: ActorRef, - message: SupervisionEvent, - state: &mut Self::State, - ) -> Result<(), ActorProcessingErr> { - let span = session_span(&state.args.session_id); - let _guard = span.enter(); - tracing::info!("supervisor_event: {:?}", message); - - match message { - SupervisionEvent::ActorStarted(_) | SupervisionEvent::ProcessGroupChanged(_) => {} - SupervisionEvent::ActorTerminated(_, _, _) => {} - SupervisionEvent::ActorFailed(_cell, _) => { - myself.stop(None); - } - } - Ok(()) - } +fn stop_with_degraded_error(myself: ActorRef, error: DegradedError) { + let reason = serde_json::to_string(&error).ok(); + myself.stop(reason); } diff --git a/plugins/listener/src/actors/listener/stream.rs b/plugins/listener/src/actors/listener/stream.rs index 3778403683..c17fc20923 100644 --- a/plugins/listener/src/actors/listener/stream.rs +++ b/plugins/listener/src/actors/listener/stream.rs @@ -1,11 +1,21 @@ use std::time::Duration; +use bytes::Bytes; use futures_util::StreamExt; -use owhisper_client::FinalizeHandle; -use owhisper_interface::stream::{Extra, StreamResponse}; use ractor::ActorRef; -use super::{LISTEN_STREAM_TIMEOUT, ListenerMsg}; +use owhisper_interface::stream::{Extra, StreamResponse}; +use owhisper_interface::{ControlMessage, MixedMessage}; + +use super::ListenerMsg; + +pub(super) const LISTEN_STREAM_TIMEOUT: Duration = Duration::from_secs(5 * 60); +pub(super) const LISTEN_CONNECT_TIMEOUT: Duration = Duration::from_secs(5); + +pub(super) enum ChannelSender { + Single(tokio::sync::mpsc::Sender>), + Dual(tokio::sync::mpsc::Sender>), +} pub(super) async fn process_stream( mut listen_stream: std::pin::Pin<&mut S>, @@ -17,7 +27,7 @@ pub(super) async fn process_stream( ) where S: futures_util::Stream>, E: std::fmt::Debug, - H: FinalizeHandle, + H: owhisper_client::FinalizeHandle, { loop { tokio::select! { diff --git a/plugins/listener/src/actors/mod.rs b/plugins/listener/src/actors/mod.rs index 7bf02216ec..aa2ccb4daf 100644 --- a/plugins/listener/src/actors/mod.rs +++ b/plugins/listener/src/actors/mod.rs @@ -1,13 +1,11 @@ mod listener; mod recorder; -mod root; -mod session; +pub(crate) mod session; mod source; pub use listener::*; pub use recorder::*; -pub use root::*; -pub use session::*; +pub use session::{SessionActor, SessionArgs, SessionMsg, SessionParams}; pub use source::*; #[cfg(target_os = "macos")] diff --git a/plugins/listener/src/actors/root.rs b/plugins/listener/src/actors/root.rs deleted file mode 100644 index e105d6abaa..0000000000 --- a/plugins/listener/src/actors/root.rs +++ /dev/null @@ -1,298 +0,0 @@ -use std::collections::BTreeMap; -use std::time::{Instant, SystemTime}; - -use ractor::{Actor, ActorCell, ActorProcessingErr, ActorRef, RpcReplyPort, SupervisionEvent}; -use tauri_plugin_settings::SettingsPluginExt; -use tauri_specta::Event; -use tracing::Instrument; - -use crate::SessionLifecycleEvent; -use crate::actors::{SessionContext, SessionParams, spawn_session_supervisor}; - -/// Creates a tracing span with session context that child events will inherit -pub(crate) fn session_span(session_id: &str) -> tracing::Span { - tracing::info_span!("session", session_id = %session_id) -} - -pub enum RootMsg { - StartSession(SessionParams, RpcReplyPort), - StopSession(RpcReplyPort<()>), - GetState(RpcReplyPort), -} - -pub struct RootArgs { - pub app: tauri::AppHandle, -} - -pub struct RootState { - app: tauri::AppHandle, - session_id: Option, - supervisor: Option, - finalizing: bool, -} - -pub struct RootActor; - -impl RootActor { - pub fn name() -> ractor::ActorName { - "listener_root_actor".into() - } -} - -#[ractor::async_trait] -impl Actor for RootActor { - type Msg = RootMsg; - type State = RootState; - type Arguments = RootArgs; - - async fn pre_start( - &self, - _myself: ActorRef, - args: Self::Arguments, - ) -> Result { - Ok(RootState { - app: args.app, - session_id: None, - supervisor: None, - finalizing: false, - }) - } - - async fn handle( - &self, - myself: ActorRef, - message: Self::Msg, - state: &mut Self::State, - ) -> Result<(), ActorProcessingErr> { - match message { - RootMsg::StartSession(params, reply) => { - let success = start_session_impl(myself.get_cell(), params, state).await; - let _ = reply.send(success); - } - RootMsg::StopSession(reply) => { - stop_session_impl(state).await; - let _ = reply.send(()); - } - RootMsg::GetState(reply) => { - let fsm_state = if state.finalizing { - crate::fsm::State::Finalizing - } else if state.supervisor.is_some() { - crate::fsm::State::Active - } else { - crate::fsm::State::Inactive - }; - let _ = reply.send(fsm_state); - } - } - Ok(()) - } - - async fn handle_supervisor_evt( - &self, - _myself: ActorRef, - message: SupervisionEvent, - state: &mut Self::State, - ) -> Result<(), ActorProcessingErr> { - match message { - SupervisionEvent::ActorStarted(_) | SupervisionEvent::ProcessGroupChanged(_) => {} - SupervisionEvent::ActorTerminated(cell, _, reason) => { - if let Some(supervisor) = &state.supervisor - && cell.get_id() == supervisor.get_id() - { - let session_id = state.session_id.take().unwrap_or_default(); - let span = session_span(&session_id); - let _guard = span.enter(); - tracing::info!(?reason, "session_supervisor_terminated"); - state.supervisor = None; - state.finalizing = false; - emit_session_ended(&state.app, &session_id, None); - } - } - SupervisionEvent::ActorFailed(cell, error) => { - if let Some(supervisor) = &state.supervisor - && cell.get_id() == supervisor.get_id() - { - let session_id = state.session_id.take().unwrap_or_default(); - let span = session_span(&session_id); - let _guard = span.enter(); - tracing::warn!(?error, "session_supervisor_failed"); - state.supervisor = None; - state.finalizing = false; - emit_session_ended(&state.app, &session_id, Some(format!("{:?}", error))); - } - } - } - Ok(()) - } -} - -async fn start_session_impl( - root_cell: ActorCell, - params: SessionParams, - state: &mut RootState, -) -> bool { - let session_id = params.session_id.clone(); - let span = session_span(&session_id); - - async { - if state.supervisor.is_some() { - tracing::warn!("session_already_running"); - return false; - } - - configure_sentry_session_context(¶ms); - - let app_dir = match state.app.settings().settings_base() { - Ok(base) => base.join("sessions"), - Err(e) => { - tracing::error!(error = ?e, "failed_to_resolve_sessions_base_dir"); - clear_sentry_session_context(); - return false; - } - }; - - { - use tauri_plugin_tray::TrayPluginExt; - let _ = state.app.tray().set_start_disabled(true); - } - - let ctx = SessionContext { - app: state.app.clone(), - params: params.clone(), - app_dir, - started_at_instant: Instant::now(), - started_at_system: SystemTime::now(), - }; - - match spawn_session_supervisor(ctx).await { - Ok((supervisor_cell, _handle)) => { - supervisor_cell.link(root_cell); - - state.session_id = Some(params.session_id.clone()); - state.supervisor = Some(supervisor_cell); - - if let Err(error) = (SessionLifecycleEvent::Active { - session_id: params.session_id, - }) - .emit(&state.app) - { - tracing::error!(?error, "failed_to_emit_active"); - } - - tracing::info!("session_started"); - true - } - Err(e) => { - tracing::error!(error = ?e, "failed_to_start_session"); - clear_sentry_session_context(); - - use tauri_plugin_tray::TrayPluginExt; - let _ = state.app.tray().set_start_disabled(false); - false - } - } - } - .instrument(span) - .await -} - -fn configure_sentry_session_context(params: &SessionParams) { - sentry::configure_scope(|scope| { - scope.set_tag("session_id", ¶ms.session_id); - scope.set_tag( - "session_type", - if params.onboarding { - "onboarding" - } else { - "production" - }, - ); - - let mut session_context = BTreeMap::new(); - session_context.insert("session_id".to_string(), params.session_id.clone().into()); - session_context.insert("model".to_string(), params.model.clone().into()); - session_context.insert("record_enabled".to_string(), params.record_enabled.into()); - session_context.insert("onboarding".to_string(), params.onboarding.into()); - session_context.insert( - "languages".to_string(), - format!("{:?}", params.languages).into(), - ); - scope.set_context("session", sentry::protocol::Context::Other(session_context)); - }); -} - -async fn stop_session_impl(state: &mut RootState) { - if let Some(supervisor) = &state.supervisor { - state.finalizing = true; - - if let Some(session_id) = &state.session_id { - let span = session_span(session_id); - let _guard = span.enter(); - tracing::info!("session_finalizing"); - - if let Err(error) = (SessionLifecycleEvent::Finalizing { - session_id: session_id.clone(), - }) - .emit(&state.app) - { - tracing::error!(?error, "failed_to_emit_finalizing"); - } - } - - // TO make sure post_stop is called. - stop_actor_by_name_and_wait(crate::actors::RecorderActor::name(), "session_stop").await; - - supervisor.stop(None); - } -} - -async fn stop_actor_by_name_and_wait(actor_name: ractor::ActorName, reason: &str) { - if let Some(cell) = ractor::registry::where_is(actor_name.clone()) { - cell.stop(Some(reason.to_string())); - wait_for_actor_shutdown(actor_name).await; - } -} - -async fn wait_for_actor_shutdown(actor_name: ractor::ActorName) { - for _ in 0..50 { - if ractor::registry::where_is(actor_name.clone()).is_none() { - break; - } - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - } -} - -fn emit_session_ended(app: &tauri::AppHandle, session_id: &str, failure_reason: Option) { - let span = session_span(session_id); - let _guard = span.enter(); - - { - use tauri_plugin_tray::TrayPluginExt; - let _ = app.tray().set_start_disabled(false); - } - - if let Err(error) = (SessionLifecycleEvent::Inactive { - session_id: session_id.to_string(), - error: failure_reason.clone(), - }) - .emit(app) - { - tracing::error!(?error, "failed_to_emit_inactive"); - } - - if let Some(reason) = failure_reason { - tracing::info!(failure_reason = %reason, "session_stopped"); - } else { - tracing::info!("session_stopped"); - } - - clear_sentry_session_context(); -} - -fn clear_sentry_session_context() { - sentry::configure_scope(|scope| { - scope.remove_tag("session_id"); - scope.remove_tag("session_type"); - scope.remove_context("session"); - }); -} diff --git a/plugins/listener/src/actors/session.rs b/plugins/listener/src/actors/session.rs deleted file mode 100644 index dac82518d6..0000000000 --- a/plugins/listener/src/actors/session.rs +++ /dev/null @@ -1,162 +0,0 @@ -use std::path::PathBuf; -use std::time::{Instant, SystemTime}; - -use ractor::concurrency::Duration; -use ractor::{Actor, ActorCell, ActorProcessingErr}; -use ractor_supervisor::SupervisorStrategy; -use ractor_supervisor::core::{ChildBackoffFn, ChildSpec, Restart, SpawnFn}; -use ractor_supervisor::supervisor::{Supervisor, SupervisorArguments, SupervisorOptions}; - -use crate::actors::{ - ChannelMode, ListenerActor, ListenerArgs, RecArgs, RecorderActor, SourceActor, SourceArgs, -}; - -pub const SESSION_SUPERVISOR_PREFIX: &str = "session_supervisor_"; - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, specta::Type)] -pub struct SessionParams { - pub session_id: String, - pub languages: Vec, - pub onboarding: bool, - pub record_enabled: bool, - pub model: String, - pub base_url: String, - pub api_key: String, - pub keywords: Vec, -} - -#[derive(Clone)] -pub struct SessionContext { - pub app: tauri::AppHandle, - pub params: SessionParams, - pub app_dir: PathBuf, - pub started_at_instant: Instant, - pub started_at_system: SystemTime, -} - -pub fn session_supervisor_name(session_id: &str) -> String { - format!("{}{}", SESSION_SUPERVISOR_PREFIX, session_id) -} - -fn make_supervisor_options() -> SupervisorOptions { - SupervisorOptions { - strategy: SupervisorStrategy::RestForOne, - max_restarts: 3, - max_window: Duration::from_secs(15), - reset_after: Some(Duration::from_secs(30)), - } -} - -fn make_listener_backoff() -> ChildBackoffFn { - ChildBackoffFn::new(|_id, count, _, _| { - if count == 0 { - None - } else { - Some(Duration::from_millis(500)) - } - }) -} - -pub async fn spawn_session_supervisor( - ctx: SessionContext, -) -> Result<(ActorCell, tokio::task::JoinHandle<()>), ActorProcessingErr> { - let supervisor_name = session_supervisor_name(&ctx.params.session_id); - - let mut child_specs = Vec::new(); - - let ctx_source = ctx.clone(); - child_specs.push(ChildSpec { - id: SourceActor::name().to_string(), - restart: Restart::Permanent, - spawn_fn: SpawnFn::new(move |supervisor_cell, _id| { - let ctx = ctx_source.clone(); - async move { - let (actor_ref, _) = Actor::spawn_linked( - Some(SourceActor::name()), - SourceActor, - SourceArgs { - mic_device: None, - onboarding: ctx.params.onboarding, - app: ctx.app.clone(), - session_id: ctx.params.session_id.clone(), - }, - supervisor_cell, - ) - .await?; - Ok(actor_ref.get_cell()) - } - }), - backoff_fn: None, - reset_after: Some(Duration::from_secs(30)), - }); - - let ctx_listener = ctx.clone(); - child_specs.push(ChildSpec { - id: ListenerActor::name().to_string(), - restart: Restart::Permanent, - spawn_fn: SpawnFn::new(move |supervisor_cell, _id| { - let ctx = ctx_listener.clone(); - async move { - let mode = ChannelMode::determine(ctx.params.onboarding); - - let (actor_ref, _) = Actor::spawn_linked( - Some(ListenerActor::name()), - ListenerActor, - ListenerArgs { - app: ctx.app.clone(), - languages: ctx.params.languages.clone(), - onboarding: ctx.params.onboarding, - model: ctx.params.model.clone(), - base_url: ctx.params.base_url.clone(), - api_key: ctx.params.api_key.clone(), - keywords: ctx.params.keywords.clone(), - mode, - session_started_at: ctx.started_at_instant, - session_started_at_unix: ctx.started_at_system, - session_id: ctx.params.session_id.clone(), - }, - supervisor_cell, - ) - .await?; - Ok(actor_ref.get_cell()) - } - }), - backoff_fn: Some(make_listener_backoff()), - reset_after: Some(Duration::from_secs(30)), - }); - - if ctx.params.record_enabled { - let ctx_recorder = ctx.clone(); - child_specs.push(ChildSpec { - id: RecorderActor::name().to_string(), - restart: Restart::Transient, - spawn_fn: SpawnFn::new(move |supervisor_cell, _id| { - let ctx = ctx_recorder.clone(); - async move { - let (actor_ref, _) = Actor::spawn_linked( - Some(RecorderActor::name()), - RecorderActor, - RecArgs { - app_dir: ctx.app_dir.clone(), - session_id: ctx.params.session_id.clone(), - }, - supervisor_cell, - ) - .await?; - Ok(actor_ref.get_cell()) - } - }), - backoff_fn: None, - reset_after: None, - }); - } - - let args = SupervisorArguments { - child_specs, - options: make_supervisor_options(), - }; - - let (supervisor_ref, handle) = Supervisor::spawn(supervisor_name, args).await?; - - Ok((supervisor_ref.get_cell(), handle)) -} diff --git a/plugins/listener/src/actors/session/lifecycle.rs b/plugins/listener/src/actors/session/lifecycle.rs new file mode 100644 index 0000000000..110705a334 --- /dev/null +++ b/plugins/listener/src/actors/session/lifecycle.rs @@ -0,0 +1,331 @@ +use std::collections::BTreeMap; +use std::time::{Duration, Instant, SystemTime}; + +use ractor::{Actor, ActorProcessingErr, ActorRef}; +use tauri_plugin_settings::SettingsPluginExt; +use tauri_specta::Event; +use tracing::Instrument; + +use super::supervision::RestartState; +use super::{ActiveSession, SessionActorState, SessionMsg, SessionParams}; +use crate::actors::{ + ChannelMode, ListenerActor, ListenerArgs, RecArgs, RecorderActor, SourceActor, SourceArgs, +}; +use crate::{CriticalError, DegradedError, SessionLifecycleEvent}; + +use super::session_span; + +pub(super) async fn start_session_impl( + myself: ActorRef, + params: SessionParams, + state: &mut SessionActorState, +) -> bool { + let session_id = params.session_id.clone(); + let span = session_span(&session_id); + + async { + if state.active.is_some() { + tracing::warn!("session_already_running"); + return false; + } + + configure_sentry_session_context(¶ms); + + let app_dir = match state.app.settings().settings_base() { + Ok(base) => base.join("sessions"), + Err(e) => { + tracing::error!(error = ?e, "failed_to_resolve_sessions_base_dir"); + clear_sentry_session_context(); + return false; + } + }; + + { + use tauri_plugin_tray::TrayPluginExt; + let _ = state.app.tray().set_start_disabled(true); + } + + let mut active = ActiveSession { + session_id: params.session_id.clone(), + app_dir, + params: params.clone(), + started_at_instant: Instant::now(), + started_at_system: SystemTime::now(), + source: None, + recorder: None, + listener: None, + listener_degraded: None, + source_restart: RestartState::new(), + recorder_restart: RestartState::new(), + }; + + if let Err(e) = spawn_source(&myself, &mut active, &state.app).await { + tracing::error!(error = ?e, "failed_to_spawn_source"); + cleanup_failed_start(state, ¶ms).await; + return false; + } + + if params.record_enabled { + if let Err(e) = spawn_recorder(&myself, &mut active).await { + tracing::error!(error = ?e, "failed_to_spawn_recorder"); + if let Some(source) = &active.source { + source.stop(Some("startup_failure".to_string())); + } + cleanup_failed_start(state, ¶ms).await; + return false; + } + } + + if let Err(e) = spawn_listener(&myself, &mut active, &state.app).await { + tracing::warn!(error = ?e, "failed_to_spawn_listener_continuing_degraded"); + active.listener_degraded = Some(DegradedError::UpstreamUnavailable { + message: format!("{:?}", e), + }); + } + + state.active = Some(active); + + let degraded_error = state + .active + .as_ref() + .and_then(|s| s.listener_degraded.clone()); + + if let Err(error) = (SessionLifecycleEvent::Active { + session_id: params.session_id, + error: degraded_error, + }) + .emit(&state.app) + { + tracing::error!(?error, "failed_to_emit_active"); + } + + tracing::info!("session_started"); + true + } + .instrument(span) + .await +} + +pub(super) async fn stop_session_impl(state: &mut SessionActorState) { + let Some(active) = &state.active else { + return; + }; + + state.finalizing = true; + + let span = session_span(&active.session_id); + let _guard = span.enter(); + tracing::info!("session_finalizing"); + + if let Err(error) = (SessionLifecycleEvent::Finalizing { + session_id: active.session_id.clone(), + }) + .emit(&state.app) + { + tracing::error!(?error, "failed_to_emit_finalizing"); + } + + stop_actor_by_name_and_wait(RecorderActor::name(), "session_stop").await; + + if let Some(source) = &active.source { + source.stop(Some("session_stop".to_string())); + } + if let Some(listener) = &active.listener { + listener.stop(Some("session_stop".to_string())); + } + + wait_for_all_actors_shutdown(active).await; + + let session_id = active.session_id.clone(); + state.active = None; + state.finalizing = false; + + emit_session_ended(&state.app, &session_id, None); +} + +async fn cleanup_failed_start(state: &mut SessionActorState, params: &SessionParams) { + clear_sentry_session_context(); + + use tauri_plugin_tray::TrayPluginExt; + let _ = state.app.tray().set_start_disabled(false); + + let _ = (SessionLifecycleEvent::Inactive { + session_id: params.session_id.clone(), + error: Some(CriticalError { + message: "Failed to start session".to_string(), + }), + }) + .emit(&state.app); +} + +pub(super) async fn spawn_source( + myself: &ActorRef, + active: &mut ActiveSession, + app: &tauri::AppHandle, +) -> Result<(), ActorProcessingErr> { + let (actor_ref, _) = Actor::spawn_linked( + Some(SourceActor::name()), + SourceActor, + SourceArgs { + mic_device: None, + onboarding: active.params.onboarding, + app: app.clone(), + session_id: active.session_id.clone(), + }, + myself.get_cell(), + ) + .await?; + active.source = Some(actor_ref.get_cell()); + Ok(()) +} + +pub(super) async fn spawn_recorder( + myself: &ActorRef, + active: &mut ActiveSession, +) -> Result<(), ActorProcessingErr> { + let (actor_ref, _) = Actor::spawn_linked( + Some(RecorderActor::name()), + RecorderActor, + RecArgs { + app_dir: active.app_dir.clone(), + session_id: active.session_id.clone(), + }, + myself.get_cell(), + ) + .await?; + active.recorder = Some(actor_ref.get_cell()); + Ok(()) +} + +pub(super) async fn spawn_listener( + myself: &ActorRef, + active: &mut ActiveSession, + app: &tauri::AppHandle, +) -> Result<(), ActorProcessingErr> { + let mode = ChannelMode::determine(active.params.onboarding); + + let (actor_ref, _) = Actor::spawn_linked( + Some(ListenerActor::name()), + ListenerActor, + ListenerArgs { + app: app.clone(), + languages: active.params.languages.clone(), + onboarding: active.params.onboarding, + model: active.params.model.clone(), + base_url: active.params.base_url.clone(), + api_key: active.params.api_key.clone(), + keywords: active.params.keywords.clone(), + mode, + session_started_at: active.started_at_instant, + session_started_at_unix: active.started_at_system, + session_id: active.session_id.clone(), + }, + myself.get_cell(), + ) + .await?; + active.listener = Some(actor_ref.get_cell()); + Ok(()) +} + +fn configure_sentry_session_context(params: &SessionParams) { + sentry::configure_scope(|scope| { + scope.set_tag("session_id", ¶ms.session_id); + scope.set_tag( + "session_type", + if params.onboarding { + "onboarding" + } else { + "production" + }, + ); + + let mut session_context = BTreeMap::new(); + session_context.insert("session_id".to_string(), params.session_id.clone().into()); + session_context.insert("model".to_string(), params.model.clone().into()); + session_context.insert("record_enabled".to_string(), params.record_enabled.into()); + session_context.insert("onboarding".to_string(), params.onboarding.into()); + session_context.insert( + "languages".to_string(), + format!("{:?}", params.languages).into(), + ); + scope.set_context("session", sentry::protocol::Context::Other(session_context)); + }); +} + +pub(super) fn clear_sentry_session_context() { + sentry::configure_scope(|scope| { + scope.remove_tag("session_id"); + scope.remove_tag("session_type"); + scope.remove_context("session"); + }); +} + +async fn wait_for_all_actors_shutdown(active: &ActiveSession) { + wait_for_actor_shutdown(SourceActor::name()).await; + wait_for_actor_shutdown(RecorderActor::name()).await; + wait_for_actor_shutdown(ListenerActor::name()).await; + + let _ = active; +} + +async fn stop_actor_by_name_and_wait(actor_name: ractor::ActorName, reason: &str) { + if let Some(cell) = ractor::registry::where_is(actor_name.clone()) { + cell.stop(Some(reason.to_string())); + wait_for_actor_shutdown(actor_name).await; + } +} + +async fn wait_for_actor_shutdown(actor_name: ractor::ActorName) { + for _ in 0..50 { + if ractor::registry::where_is(actor_name.clone()).is_none() { + break; + } + tokio::time::sleep(Duration::from_millis(100)).await; + } +} + +pub(super) fn emit_degraded(app: &tauri::AppHandle, session_id: &str, error: DegradedError) { + if let Err(e) = (SessionLifecycleEvent::Active { + session_id: session_id.to_string(), + error: Some(error), + }) + .emit(app) + { + tracing::error!(?e, "failed_to_emit_degraded"); + } +} + +pub(super) fn emit_session_ended( + app: &tauri::AppHandle, + session_id: &str, + failure_reason: Option, +) { + let span = session_span(session_id); + let _guard = span.enter(); + + { + use tauri_plugin_tray::TrayPluginExt; + let _ = app.tray().set_start_disabled(false); + } + + let error = failure_reason.as_ref().map(|msg| CriticalError { + message: msg.clone(), + }); + + if let Err(e) = (SessionLifecycleEvent::Inactive { + session_id: session_id.to_string(), + error, + }) + .emit(app) + { + tracing::error!(?e, "failed_to_emit_inactive"); + } + + if let Some(reason) = failure_reason { + tracing::info!(failure_reason = %reason, "session_stopped"); + } else { + tracing::info!("session_stopped"); + } + + clear_sentry_session_context(); +} diff --git a/plugins/listener/src/actors/session/mod.rs b/plugins/listener/src/actors/session/mod.rs new file mode 100644 index 0000000000..c68399989f --- /dev/null +++ b/plugins/listener/src/actors/session/mod.rs @@ -0,0 +1,127 @@ +mod lifecycle; +mod supervision; + +use std::path::PathBuf; +use std::time::SystemTime; + +use ractor::{Actor, ActorCell, ActorProcessingErr, ActorRef, RpcReplyPort, SupervisionEvent}; + +use crate::DegradedError; + +use lifecycle::{start_session_impl, stop_session_impl}; +use supervision::{RestartState, handle_supervisor_evt}; + +pub(crate) fn session_span(session_id: &str) -> tracing::Span { + tracing::info_span!("session", session_id = %session_id) +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, specta::Type)] +pub struct SessionParams { + pub session_id: String, + pub languages: Vec, + pub onboarding: bool, + pub record_enabled: bool, + pub model: String, + pub base_url: String, + pub api_key: String, + pub keywords: Vec, +} + +pub enum SessionMsg { + Start(SessionParams, RpcReplyPort), + Stop(RpcReplyPort<()>), + GetState(RpcReplyPort), +} + +pub struct SessionArgs { + pub app: tauri::AppHandle, +} + +pub(super) struct ActiveSession { + pub(super) session_id: String, + pub(super) app_dir: PathBuf, + pub(super) params: SessionParams, + pub(super) started_at_instant: std::time::Instant, + pub(super) started_at_system: SystemTime, + + pub(super) source: Option, + pub(super) recorder: Option, + pub(super) listener: Option, + + pub(super) listener_degraded: Option, + + pub(super) source_restart: RestartState, + pub(super) recorder_restart: RestartState, +} + +pub struct SessionActorState { + pub(super) app: tauri::AppHandle, + pub(super) active: Option, + pub(super) finalizing: bool, +} + +pub struct SessionActor; + +impl SessionActor { + pub fn name() -> ractor::ActorName { + "session_actor".into() + } +} + +#[ractor::async_trait] +impl Actor for SessionActor { + type Msg = SessionMsg; + type State = SessionActorState; + type Arguments = SessionArgs; + + async fn pre_start( + &self, + _myself: ActorRef, + args: Self::Arguments, + ) -> Result { + Ok(SessionActorState { + app: args.app, + active: None, + finalizing: false, + }) + } + + async fn handle( + &self, + myself: ActorRef, + message: Self::Msg, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + match message { + SessionMsg::Start(params, reply) => { + let success = start_session_impl(myself.clone(), params, state).await; + let _ = reply.send(success); + } + SessionMsg::Stop(reply) => { + stop_session_impl(state).await; + let _ = reply.send(()); + } + SessionMsg::GetState(reply) => { + let fsm_state = if state.finalizing { + crate::fsm::State::Finalizing + } else if state.active.is_some() { + crate::fsm::State::Active + } else { + crate::fsm::State::Inactive + }; + let _ = reply.send(fsm_state); + } + } + Ok(()) + } + + async fn handle_supervisor_evt( + &self, + myself: ActorRef, + message: SupervisionEvent, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + handle_supervisor_evt(myself, message, state).await; + Ok(()) + } +} diff --git a/plugins/listener/src/actors/session/supervision.rs b/plugins/listener/src/actors/session/supervision.rs new file mode 100644 index 0000000000..12945678a8 --- /dev/null +++ b/plugins/listener/src/actors/session/supervision.rs @@ -0,0 +1,252 @@ +use std::time::{Duration, Instant}; + +use ractor::{ActorRef, SupervisionEvent}; + +use super::lifecycle::{emit_degraded, emit_session_ended, spawn_recorder, spawn_source}; +use super::{SessionActorState, SessionMsg, session_span}; +use crate::DegradedError; + +const MAX_RESTARTS: u32 = 3; +const RESTART_WINDOW: Duration = Duration::from_secs(15); + +pub(crate) struct RestartState { + count: u32, + window_start: Instant, +} + +impl RestartState { + pub fn new() -> Self { + Self { + count: 0, + window_start: Instant::now(), + } + } + + pub fn should_restart(&mut self) -> bool { + let now = Instant::now(); + if now.duration_since(self.window_start) > RESTART_WINDOW { + self.count = 0; + self.window_start = now; + } + if self.count < MAX_RESTARTS { + self.count += 1; + true + } else { + false + } + } + + pub fn count(&self) -> u32 { + self.count + } +} + +enum RestartableChild { + Source, + Recorder, +} + +impl RestartableChild { + fn name(&self) -> &'static str { + match self { + Self::Source => "source", + Self::Recorder => "recorder", + } + } + + fn label(&self) -> &'static str { + match self { + Self::Source => "Source", + Self::Recorder => "Recorder", + } + } +} + +async fn try_restart_child( + child: RestartableChild, + myself: &ActorRef, + state: &mut SessionActorState, +) { + if state.finalizing { + return; + } + + let app = state.app.clone(); + let name = child.name(); + let label = child.label(); + + let can_restart = { + let Some(active) = state.active.as_mut() else { + return; + }; + match child { + RestartableChild::Source => active.source_restart.should_restart(), + RestartableChild::Recorder => active.recorder_restart.should_restart(), + } + }; + + if !can_restart { + tracing::error!("{name}_restart_limit_exceeded_meltdown"); + trigger_meltdown(state, &format!("{label} restart limit exceeded")).await; + return; + } + + let count = { + let active = state.active.as_ref().unwrap(); + match child { + RestartableChild::Source => active.source_restart.count(), + RestartableChild::Recorder => active.recorder_restart.count(), + } + }; + tracing::info!(restart_count = count, "restarting_{name}"); + + let spawn_result = { + let active = state.active.as_mut().unwrap(); + match child { + RestartableChild::Source => spawn_source(myself, active, &app).await, + RestartableChild::Recorder => spawn_recorder(myself, active).await, + } + }; + + if let Err(e) = spawn_result { + tracing::error!(error = ?e, "{name}_restart_failed_meltdown"); + trigger_meltdown(state, &format!("{label} restart failed")).await; + } +} + +pub(super) async fn handle_supervisor_evt( + myself: ActorRef, + message: SupervisionEvent, + state: &mut SessionActorState, +) { + if state.active.is_none() { + return; + } + + let session_id = state.active.as_ref().unwrap().session_id.clone(); + let span = session_span(&session_id); + let _guard = span.enter(); + + match message { + SupervisionEvent::ActorStarted(_) | SupervisionEvent::ProcessGroupChanged(_) => {} + + SupervisionEvent::ActorTerminated(cell, _, reason) => { + let actor_id = cell.get_id(); + + let is_listener = state + .active + .as_ref() + .unwrap() + .listener + .as_ref() + .is_some_and(|c| c.get_id() == actor_id); + let is_source = state + .active + .as_ref() + .unwrap() + .source + .as_ref() + .is_some_and(|c| c.get_id() == actor_id); + let is_recorder = state + .active + .as_ref() + .unwrap() + .recorder + .as_ref() + .is_some_and(|c| c.get_id() == actor_id); + + if is_listener { + let active = state.active.as_mut().unwrap(); + active.listener = None; + + let degraded_error = reason + .as_ref() + .and_then(|r| serde_json::from_str::(r).ok()); + + if let Some(error) = degraded_error { + tracing::info!(?error, "listener_terminated_entering_degraded_mode"); + active.listener_degraded = Some(error.clone()); + emit_degraded(&state.app, &session_id, error); + } else { + tracing::info!(?reason, "listener_terminated"); + } + } else if is_source { + state.active.as_mut().unwrap().source = None; + tracing::info!(?reason, "source_terminated"); + try_restart_child(RestartableChild::Source, &myself, state).await; + } else if is_recorder { + state.active.as_mut().unwrap().recorder = None; + tracing::info!(?reason, "recorder_terminated"); + try_restart_child(RestartableChild::Recorder, &myself, state).await; + } + } + + SupervisionEvent::ActorFailed(cell, error) => { + let actor_id = cell.get_id(); + + let is_listener = state + .active + .as_ref() + .unwrap() + .listener + .as_ref() + .is_some_and(|c| c.get_id() == actor_id); + let is_source = state + .active + .as_ref() + .unwrap() + .source + .as_ref() + .is_some_and(|c| c.get_id() == actor_id); + let is_recorder = state + .active + .as_ref() + .unwrap() + .recorder + .as_ref() + .is_some_and(|c| c.get_id() == actor_id); + + if is_listener { + let active = state.active.as_mut().unwrap(); + active.listener = None; + tracing::warn!(?error, "listener_failed_entering_degraded_mode"); + let err = DegradedError::UpstreamUnavailable { + message: format!("{}", error), + }; + active.listener_degraded = Some(err.clone()); + emit_degraded(&state.app, &session_id, err); + } else if is_source { + state.active.as_mut().unwrap().source = None; + tracing::warn!(?error, "source_failed"); + try_restart_child(RestartableChild::Source, &myself, state).await; + } else if is_recorder { + state.active.as_mut().unwrap().recorder = None; + tracing::warn!(?error, "recorder_failed"); + try_restart_child(RestartableChild::Recorder, &myself, state).await; + } + } + } +} + +async fn trigger_meltdown(state: &mut SessionActorState, reason: &str) { + let Some(active) = state.active.take() else { + return; + }; + + let span = session_span(&active.session_id); + let _guard = span.enter(); + + state.finalizing = false; + + if let Some(source) = &active.source { + source.stop(Some("meltdown".to_string())); + } + if let Some(recorder) = &active.recorder { + recorder.stop(Some("meltdown".to_string())); + } + if let Some(listener) = &active.listener { + listener.stop(Some("meltdown".to_string())); + } + + emit_session_ended(&state.app, &active.session_id, Some(reason.to_string())); +} diff --git a/plugins/listener/src/actors/source/mod.rs b/plugins/listener/src/actors/source/mod.rs index 065c8b6c19..e281064e35 100644 --- a/plugins/listener/src/actors/source/mod.rs +++ b/plugins/listener/src/actors/source/mod.rs @@ -12,8 +12,8 @@ use tokio_util::sync::CancellationToken; use tracing::Instrument; use crate::{ - SessionErrorEvent, SessionProgressEvent, - actors::root::session_span, + ActorStopReason, SessionProgressEvent, SourceStopReason, + actors::session::session_span, actors::{AudioChunk, ChannelMode}, }; use hypr_audio::AudioInput; @@ -78,7 +78,8 @@ impl DeviceChangeWatcher { match event_rx.recv() { Ok(DeviceSwitch::DefaultInputChanged) => { tracing::info!("default_input_changed_restarting_source"); - actor.stop(Some("device_change".to_string())); + let reason = ActorStopReason::Source(SourceStopReason::DeviceChanged); + actor.stop(serde_json::to_string(&reason).ok()); } Ok(_) => {} Err(_) => break, @@ -180,14 +181,9 @@ impl Actor for SourceActor { } SourceMsg::StreamFailed(reason) => { tracing::error!(%reason, "source_stream_failed_stopping"); - let _ = (SessionErrorEvent::AudioError { - session_id: st.session_id.clone(), - error: reason.clone(), - device: st.mic_device.clone(), - is_fatal: true, - }) - .emit(&st.app); - myself.stop(Some(reason)); + let stop_reason = + ActorStopReason::Source(SourceStopReason::StreamFailed { message: reason }); + myself.stop(serde_json::to_string(&stop_reason).ok()); } } diff --git a/plugins/listener/src/actors/source/pipeline.rs b/plugins/listener/src/actors/source/pipeline.rs index 58ea62515a..be3c11bfb2 100644 --- a/plugins/listener/src/actors/source/pipeline.rs +++ b/plugins/listener/src/actors/source/pipeline.rs @@ -116,16 +116,19 @@ impl Pipeline { let Some(cell) = registry::where_is(ListenerActor::name()) else { self.audio_buffer.push(processed_mic, processed_spk, mode); - tracing::debug!( - actor = ListenerActor::name(), - buffered = self.audio_buffer.len(), - "listener_unavailable_buffering" - ); + if !self.audio_buffer.overflow_while_never_seen { + tracing::debug!( + actor = ListenerActor::name(), + buffered = self.audio_buffer.len(), + "listener_unavailable_buffering" + ); + } return; }; let actor: ActorRef = cell.into(); + self.audio_buffer.mark_listener_seen(); self.flush_buffer_to_listener(&actor, mode); self.send_to_listener(&actor, &processed_mic, &processed_spk, mode); @@ -181,6 +184,8 @@ impl Pipeline { struct AudioBuffer { buffer: VecDeque, max_size: usize, + listener_ever_seen: bool, + overflow_while_never_seen: bool, } impl AudioBuffer { @@ -188,12 +193,29 @@ impl AudioBuffer { Self { buffer: VecDeque::new(), max_size, + listener_ever_seen: false, + overflow_while_never_seen: false, } } + fn mark_listener_seen(&mut self) { + self.listener_ever_seen = true; + self.overflow_while_never_seen = false; + } + fn push(&mut self, mic: Arc<[f32]>, spk: Arc<[f32]>, mode: ChannelMode) { + if self.overflow_while_never_seen { + return; + } + if self.buffer.len() >= self.max_size { self.buffer.pop_front(); + if !self.listener_ever_seen { + tracing::warn!("audio_buffer_overflow_listener_never_connected_disabling_buffer"); + self.overflow_while_never_seen = true; + self.buffer.clear(); + return; + } tracing::warn!("audio_buffer_overflow"); } self.buffer.push_back((mic, spk, mode)); diff --git a/plugins/listener/src/error.rs b/plugins/listener/src/error.rs index f7f1bc2f37..f39adeba00 100644 --- a/plugins/listener/src/error.rs +++ b/plugins/listener/src/error.rs @@ -1,4 +1,4 @@ -use serde::{Serialize, ser::Serializer}; +use serde::{Deserialize, Serialize, ser::Serializer}; pub type Result = std::result::Result; @@ -30,3 +30,42 @@ impl Serialize for Error { serializer.serialize_str(self.to_string().as_ref()) } } + +#[derive(Debug, Clone, Serialize, Deserialize, specta::Type)] +#[serde(tag = "type")] +pub enum DegradedError { + #[serde(rename = "authentication_failed")] + AuthenticationFailed { provider: String }, + #[serde(rename = "upstream_unavailable")] + UpstreamUnavailable { message: String }, + #[serde(rename = "connection_timeout")] + ConnectionTimeout, + #[serde(rename = "stream_error")] + StreamError { message: String }, + #[serde(rename = "channel_overflow")] + ChannelOverflow, +} + +#[derive(Debug, Clone, Serialize, Deserialize, specta::Type)] +pub struct CriticalError { + pub message: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "actor", content = "reason")] +pub enum ActorStopReason { + Source(SourceStopReason), + Recorder(RecorderStopReason), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum SourceStopReason { + StreamFailed { message: String }, + DeviceChanged, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum RecorderStopReason { + IoError { message: String }, + EncodingError { message: String }, +} diff --git a/plugins/listener/src/events.rs b/plugins/listener/src/events.rs index 876a379dc0..d994488ec7 100644 --- a/plugins/listener/src/events.rs +++ b/plugins/listener/src/events.rs @@ -1,5 +1,7 @@ use owhisper_interface::stream::StreamResponse; +use crate::error::{CriticalError, DegradedError}; + #[macro_export] macro_rules! common_event_derives { ($item:item) => { @@ -16,10 +18,15 @@ common_event_derives! { #[serde(rename = "inactive")] Inactive { session_id: String, - error: Option, + #[serde(default)] + error: Option, }, #[serde(rename = "active")] - Active { session_id: String }, + Active { + session_id: String, + #[serde(default)] + error: Option, + }, #[serde(rename = "finalizing")] Finalizing { session_id: String }, } @@ -42,24 +49,6 @@ common_event_derives! { } } -common_event_derives! { - #[serde(tag = "type")] - pub enum SessionErrorEvent { - #[serde(rename = "audio_error")] - AudioError { - session_id: String, - error: String, - device: Option, - is_fatal: bool, - }, - #[serde(rename = "connection_error")] - ConnectionError { - session_id: String, - error: String, - }, - } -} - common_event_derives! { #[serde(tag = "type")] pub enum SessionDataEvent { diff --git a/plugins/listener/src/ext.rs b/plugins/listener/src/ext.rs index 1bf3b1a03b..7444027bf5 100644 --- a/plugins/listener/src/ext.rs +++ b/plugins/listener/src/ext.rs @@ -1,6 +1,6 @@ use ractor::{ActorRef, call_t, registry}; -use crate::actors::{RootActor, RootMsg, SessionParams, SourceActor, SourceMsg}; +use crate::actors::{SessionActor, SessionMsg, SessionParams, SourceActor, SourceMsg}; pub struct Listener<'a, R: tauri::Runtime, M: tauri::Manager> { #[allow(unused)] @@ -29,9 +29,9 @@ impl<'a, R: tauri::Runtime, M: tauri::Manager> Listener<'a, R, M> { #[tracing::instrument(skip_all)] pub async fn get_state(&self) -> crate::fsm::State { - if let Some(cell) = registry::where_is(RootActor::name()) { - let actor: ActorRef = cell.into(); - match call_t!(actor, RootMsg::GetState, 100) { + if let Some(cell) = registry::where_is(SessionActor::name()) { + let actor: ActorRef = cell.into(); + match call_t!(actor, SessionMsg::GetState, 100) { Ok(fsm_state) => fsm_state, Err(_) => crate::fsm::State::Inactive, } @@ -60,17 +60,17 @@ impl<'a, R: tauri::Runtime, M: tauri::Manager> Listener<'a, R, M> { #[tracing::instrument(skip_all)] pub async fn start_session(&self, params: SessionParams) { - if let Some(cell) = registry::where_is(RootActor::name()) { - let actor: ActorRef = cell.into(); - let _ = ractor::call!(actor, RootMsg::StartSession, params); + if let Some(cell) = registry::where_is(SessionActor::name()) { + let actor: ActorRef = cell.into(); + let _ = ractor::call!(actor, SessionMsg::Start, params); } } #[tracing::instrument(skip_all)] pub async fn stop_session(&self) { - if let Some(cell) = registry::where_is(RootActor::name()) { - let actor: ActorRef = cell.into(); - let _ = ractor::call!(actor, RootMsg::StopSession); + if let Some(cell) = registry::where_is(SessionActor::name()) { + let actor: ActorRef = cell.into(); + let _ = ractor::call!(actor, SessionMsg::Stop); } } } diff --git a/plugins/listener/src/lib.rs b/plugins/listener/src/lib.rs index 5a9bd67fe6..5f2cec3213 100644 --- a/plugins/listener/src/lib.rs +++ b/plugins/listener/src/lib.rs @@ -12,7 +12,7 @@ pub use error::*; pub use events::*; pub use ext::*; -use actors::{RootActor, RootArgs, SourceActor}; +use actors::{SessionActor, SessionArgs, SourceActor}; const PLUGIN_NAME: &str = "listener"; @@ -34,7 +34,6 @@ fn make_specta_builder() -> tauri_specta::Builder { .events(tauri_specta::collect_events![ SessionLifecycleEvent, SessionProgressEvent, - SessionErrorEvent, SessionDataEvent ]) .error_handling(tauri_specta::ErrorHandlingMode::Result) @@ -52,17 +51,17 @@ pub fn init() -> tauri::plugin::TauriPlugin { tauri::async_runtime::spawn(async move { match Actor::spawn( - Some(RootActor::name()), - RootActor, - RootArgs { app: app_handle }, + Some(SessionActor::name()), + SessionActor, + SessionArgs { app: app_handle }, ) .await { Ok(_) => { - tracing::info!("root_actor_spawned"); + tracing::info!("session_actor_spawned"); } Err(e) => { - tracing::error!(?e, "failed_to_spawn_root_actor"); + tracing::error!(?e, "failed_to_spawn_session_actor"); } } }); @@ -78,7 +77,7 @@ pub fn init() -> tauri::plugin::TauriPlugin { }) .on_drop(|_app| { hypr_intercept::unregister_quit_handler(PLUGIN_NAME); - if let Some(cell) = ractor::registry::where_is(RootActor::name()) { + if let Some(cell) = ractor::registry::where_is(SessionActor::name()) { cell.stop(None); } })