diff --git a/Cargo.lock b/Cargo.lock index 332266d..016faf4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3153,6 +3153,16 @@ dependencies = [ "objc2-foundation 0.2.2", ] +[[package]] +name = "objc2-local-authentication" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e48e0b8b339e0d9d2ed4416b7f93f9d4daadff7d4dd797f89867cde11aeac607" +dependencies = [ + "objc2 0.6.4", + "objc2-foundation 0.3.2", +] + [[package]] name = "objc2-metal" version = "0.2.2" @@ -5069,12 +5079,18 @@ name = "voxit-core" version = "0.1.0" dependencies = [ "base64", + "core-foundation 0.10.1", + "core-foundation-sys", "directories", "futures-util", + "hound", "http", "keyring", + "objc2-foundation 0.3.2", + "objc2-local-authentication", "rand 0.10.0", "reqwest", + "security-framework-sys", "serde", "serde_json", "sha2", diff --git a/apps/voxit/src/main.rs b/apps/voxit/src/main.rs index e94ef7f..314b7c6 100644 --- a/apps/voxit/src/main.rs +++ b/apps/voxit/src/main.rs @@ -29,7 +29,6 @@ use eframe::{ }; #[cfg(target_os = "macos")] use enigo::{Direction, Enigo, Key, Keyboard, Settings}; #[cfg(target_os = "macos")] use global_hotkey::GlobalHotKeyManager; -use realtime::{RealtimeEvent, RealtimeSession}; use tracing_appender::rolling::{RollingFileAppender, Rotation}; use tracing_subscriber::EnvFilter; #[cfg(target_os = "macos")] @@ -46,8 +45,8 @@ use voxit_audio::{InputDevice, Recorder}; use voxit_core::{ auth::{self, AuthRecord, AuthStatus}, config::Config, - openai::{self, RewriteState}, - realtime, + inference::{self, RewriteState}, + realtime::{self, RealtimeEvent, RealtimeSession}, transcript::TranscriptAssembler, }; use voxit_macos::{MicrophonePermissionState, PermissionSettingsPane, TargetApp}; @@ -119,8 +118,8 @@ struct VoxitApp { auth_event_tx: Sender, realtime_event_rx: Receiver, realtime_event_tx: Sender, - inference_tx: Sender, - inference_rx: Receiver, + inference_tx: Sender, + inference_rx: Receiver, is_recording: bool, is_window_visible: bool, state: String, @@ -166,8 +165,8 @@ impl VoxitApp { auth_event_tx: Sender, realtime_event_rx: Receiver, realtime_event_tx: Sender, - inference_tx: Sender, - inference_rx: Receiver, + inference_tx: Sender, + inference_rx: Receiver, hotkey_mode_u8: Arc, ) -> Self { let start_hidden = config.ui.start_hidden; @@ -608,7 +607,7 @@ impl VoxitApp { fn handle_inference_events(&mut self) { while let Ok(event) = self.inference_rx.try_recv() { match event { - openai::InferenceEvent::Pass2Completed { total_ms, raw_transcript } => { + inference::InferenceEvent::Pass2Completed { total_ms, raw_transcript } => { self.is_finalizing = false; self.transcription_result = raw_transcript; self.state = "FinalizingPass2".to_string(); @@ -627,7 +626,7 @@ impl VoxitApp { self.status = format!("Pass2 completed in {total_ms} ms. {paste_status}"); } }, - openai::InferenceEvent::RewriteCompleted { total_ms, result } => { + inference::InferenceEvent::RewriteCompleted { total_ms, result } => { self.is_rewriting = false; if self.ignore_rewrite_result { @@ -669,7 +668,7 @@ impl VoxitApp { }, } }, - openai::InferenceEvent::Failed(err) => { + inference::InferenceEvent::Failed(err) => { self.is_finalizing = false; self.is_rewriting = false; self.status = format!("OpenAI failed: {err}"); @@ -738,49 +737,25 @@ impl VoxitApp { ); } - match auth::access_token() { - Ok((api_key, account_id)) => { - let realtime_config = realtime::RealtimeSessionConfig { - model: self.config.openai.realtime_model.clone(), - sample_rate_hz: self.config.audio.realtime_target_rate_hz, - noise_reduction: self.config.openai.realtime.noise_reduction.clone(), - }; - - match realtime::start_realtime_session( - api_key, - account_id, - realtime_config, - chunk_rx, - self.realtime_event_tx.clone(), - ) { - Ok(session) => { - self.realtime_session = Some(session); - }, - Err(err) => { - let realtime_status = format!( - "Realtime unavailable ({err}). Recording continues with Pass2 finalize." - ); - - self.status = if used_fallback { - format!("{fallback_prefix} {realtime_status}") - } else { - realtime_status - }; - }, - } - }, - Err(err) => { - let auth_status = format!( - "Realtime unavailable because auth is missing ({err}). Recording continues with Pass2 finalize." - ); - - self.status = if used_fallback { - format!("{fallback_prefix} {auth_status}") - } else { - auth_status - }; + let realtime_config = realtime::RealtimeSessionConfig { + model: self.config.openai.realtime_model.clone(), + sample_rate_hz: self.config.audio.realtime_target_rate_hz, + noise_reduction: self.config.openai.realtime.noise_reduction.clone(), + }; + let realtime_status = match inference::start_realtime_session( + realtime_config, + chunk_rx, + self.realtime_event_tx.clone(), + ) { + Ok(session) => { + self.realtime_session = Some(session); + + None }, - } + Err(err) => Some(format!( + "Realtime unavailable ({err}). Recording continues with Pass2 finalize." + )), + }; self.is_recording = true; self.is_finalizing = false; @@ -794,6 +769,10 @@ impl VoxitApp { started_status = format!("{fallback_prefix} {started_status}"); } + if let Some(realtime_status) = realtime_status { + started_status = format!("{started_status} {realtime_status}"); + } + self.status = started_status; }, Err(err) => { @@ -867,12 +846,12 @@ impl VoxitApp { self.is_finalizing = true; self.status = "Finalizing Pass2 transcript...".to_string(); thread::spawn(move || { - let outcome = openai::transcribe_only(&wav, &model) - .map(|(raw_transcript, total_ms)| openai::InferenceEvent::Pass2Completed { + let outcome = inference::transcribe_only(&wav, &model) + .map(|(raw_transcript, total_ms)| inference::InferenceEvent::Pass2Completed { total_ms, raw_transcript, }) - .unwrap_or_else(openai::InferenceEvent::Failed); + .unwrap_or_else(inference::InferenceEvent::Failed); let _ = tx.send(outcome); }); } @@ -896,12 +875,12 @@ impl VoxitApp { let raw = self.transcription_result.clone(); thread::spawn(move || { - let outcome = openai::rewrite_only(&raw, &model) - .map(|(result, total_ms)| openai::InferenceEvent::RewriteCompleted { + let outcome = inference::rewrite_only(&raw, &model) + .map(|(result, total_ms)| inference::InferenceEvent::RewriteCompleted { total_ms, result, }) - .unwrap_or_else(openai::InferenceEvent::Failed); + .unwrap_or_else(inference::InferenceEvent::Failed); let _ = tx.send(outcome); }); } @@ -1232,8 +1211,8 @@ impl VoxitApp { auth_event_tx: Sender, realtime_event_rx: Receiver, realtime_event_tx: Sender, - inference_tx: Sender, - inference_rx: Receiver, + inference_tx: Sender, + inference_rx: Receiver, hotkey_mode_u8: Arc, _hotkey_manager: Option, _tray_icon: TrayIcon, @@ -1469,7 +1448,7 @@ fn run_ui() -> Result<()> { let (command_tx, command_rx) = mpsc::channel::(); let (auth_event_tx, auth_event_rx) = mpsc::channel::(); let (realtime_event_tx, realtime_event_rx) = mpsc::channel::(); - let (inference_tx, inference_rx) = mpsc::channel::(); + let (inference_tx, inference_rx) = mpsc::channel::(); let initial_hotkey = if app_config.hotkey.mode == "hold" { HotkeyMode::Hold } else { HotkeyMode::Toggle }; let hotkey_mode = Arc::new(AtomicU8::new(initial_hotkey.as_u8())); @@ -1541,7 +1520,7 @@ fn run_ui() -> Result<()> { let (_command_tx, command_rx) = mpsc::channel::(); let (auth_event_tx, auth_event_rx) = mpsc::channel::(); let (realtime_event_tx, realtime_event_rx) = mpsc::channel::(); - let (inference_tx, inference_rx) = mpsc::channel::(); + let (inference_tx, inference_rx) = mpsc::channel::(); let initial_hotkey = if app_config.hotkey.mode == "hold" { HotkeyMode::Hold } else { HotkeyMode::Toggle }; let hotkey_mode = Arc::new(AtomicU8::new(initial_hotkey.as_u8())); diff --git a/packages/voxit-core/Cargo.toml b/packages/voxit-core/Cargo.toml index cb4e0af..3d820a5 100644 --- a/packages/voxit-core/Cargo.toml +++ b/packages/voxit-core/Cargo.toml @@ -17,6 +17,7 @@ voxit-realtime = ["dep:futures-util", "dep:http", "dep:tokio", "dep:tokio-tungst base64 = { version = "0.22" } directories = { version = "6.0" } futures-util = { version = "0.3", optional = true } +hound = { version = "3.5" } http = { version = "1.3", optional = true } keyring = { version = "3.6", features = ["apple-native"] } rand = { version = "0.10" } @@ -31,3 +32,10 @@ tracing = { version = "0.1" } url = { version = "2.5" } voxit-audio = { path = "../voxit-audio" } webbrowser = { version = "1.1" } + +[target.'cfg(target_os = "macos")'.dependencies] +core-foundation = { version = "0.10" } +core-foundation-sys = { version = "0.8" } +objc2-foundation = { version = "0.3.2", default-features = false, features = ["NSObjCRuntime", "NSObject", "NSString", "alloc", "std"] } +objc2-local-authentication = { version = "0.3.2", default-features = false, features = ["LAContext", "alloc", "std"] } +security-framework-sys = { version = "2.16" } diff --git a/packages/voxit-core/src/audio_payload.rs b/packages/voxit-core/src/audio_payload.rs new file mode 100644 index 0000000..9b12a83 --- /dev/null +++ b/packages/voxit-core/src/audio_payload.rs @@ -0,0 +1,244 @@ +//! Audio payload preparation shared by speech providers. + +#[cfg(target_os = "macos")] use std::io::Cursor; + +#[cfg(target_os = "macos")] use hound::{SampleFormat, WavReader, WavSpec, WavWriter}; + +#[cfg(target_os = "macos")] +const MODEL_AUDIO_SAMPLE_RATE: u32 = 24_000; +#[cfg(target_os = "macos")] +const MODEL_AUDIO_CHANNELS: u16 = 1; + +/// Metadata extracted from a WAV payload. +#[cfg(target_os = "macos")] +#[derive(Clone, Copy, Debug)] +pub(crate) struct WavMetadata { + pub(crate) sample_rate_hz: u32, + pub(crate) channels: u16, + pub(crate) bits_per_sample: u16, + pub(crate) duration_ms: u64, +} + +/// Provider-ready transcription audio payload. +#[cfg(target_os = "macos")] +#[derive(Debug)] +pub(crate) struct PreparedTranscriptionAudio { + pub(crate) wav: Vec, + pub(crate) input: WavMetadata, + pub(crate) request: WavMetadata, +} + +/// Normalize captured WAV bytes into the ChatGPT transcription input shape. +#[cfg(target_os = "macos")] +pub(crate) fn prepare_chatgpt_transcription_wav( + wav: &[u8], +) -> Result { + let (input_spec, input_samples) = decode_wav_pcm16(wav)?; + let input = wav_metadata(&input_spec, input_samples.len()); + let request_wav = encode_wav_normalized( + &input_samples, + input_spec.sample_rate, + input_spec.channels, + MODEL_AUDIO_SAMPLE_RATE, + MODEL_AUDIO_CHANNELS, + )?; + let request = inspect_wav_metadata(&request_wav)?; + + Ok(PreparedTranscriptionAudio { wav: request_wav, input, request }) +} + +#[cfg(target_os = "macos")] +fn inspect_wav_metadata(wav: &[u8]) -> Result { + let reader = + WavReader::new(Cursor::new(wav)).map_err(|err| format!("failed to parse wav: {err}"))?; + let spec = reader.spec(); + let samples = reader.duration() as usize; + + Ok(wav_metadata(&spec, samples)) +} + +#[cfg(target_os = "macos")] +fn wav_metadata(spec: &WavSpec, sample_count: usize) -> WavMetadata { + let channels = spec.channels.max(1); + let frames = sample_count / usize::from(channels); + let duration_ms = if spec.sample_rate == 0 { + 0 + } else { + ((frames as u128) * 1_000 / (spec.sample_rate as u128)) as u64 + }; + + WavMetadata { + sample_rate_hz: spec.sample_rate, + channels: spec.channels, + bits_per_sample: spec.bits_per_sample, + duration_ms, + } +} + +#[cfg(target_os = "macos")] +fn decode_wav_pcm16(wav: &[u8]) -> Result<(WavSpec, Vec), String> { + let mut reader = + WavReader::new(Cursor::new(wav)).map_err(|err| format!("failed to parse wav: {err}"))?; + let spec = reader.spec(); + + if spec.channels == 0 { + return Err("wav has invalid channel count (0)".to_string()); + } + if spec.sample_rate == 0 { + return Err("wav has invalid sample rate (0)".to_string()); + } + + let samples = match spec.sample_format { + SampleFormat::Int if spec.bits_per_sample == 0 || spec.bits_per_sample > 32 => + return Err(format!("unsupported integer wav bit depth: {}", spec.bits_per_sample)), + SampleFormat::Int if spec.bits_per_sample <= 16 => reader + .samples::() + .map(|sample| sample.map_err(|err| err.to_string())) + .collect::, _>>()?, + SampleFormat::Int => { + let shift = spec.bits_per_sample.saturating_sub(16).min(16); + + reader + .samples::() + .map(|sample| { + sample.map_err(|err| err.to_string()).map(|value| { + (value >> shift).clamp(i16::MIN as i32, i16::MAX as i32) as i16 + }) + }) + .collect::, _>>()? + }, + SampleFormat::Float => reader + .samples::() + .map(|sample| { + sample.map_err(|err| err.to_string()).map(|value| { + (value.clamp(-1.0, 1.0) * (i16::MAX as f32)) + .round() + .clamp(i16::MIN as f32, i16::MAX as f32) as i16 + }) + }) + .collect::, _>>()?, + }; + + Ok((spec, samples)) +} + +#[cfg(target_os = "macos")] +fn encode_wav_normalized( + input: &[i16], + input_sample_rate: u32, + input_channels: u16, + output_sample_rate: u32, + output_channels: u16, +) -> Result, String> { + if input_sample_rate == 0 + || input_channels == 0 + || output_sample_rate == 0 + || output_channels == 0 + { + return Err("invalid audio format for normalization".to_string()); + } + + let converted = if input_sample_rate == output_sample_rate && input_channels == output_channels + { + input.to_vec() + } else { + convert_pcm16(input, input_sample_rate, input_channels, output_sample_rate, output_channels) + }; + let mut peak_abs = 0_i32; + + for sample in &converted { + peak_abs = peak_abs.max((*sample as i32).abs()); + } + + let target = (i16::MAX as f32) * 0.9; + let gain = if peak_abs > 0 { target / (peak_abs as f32) } else { 1.0 }; + let spec = WavSpec { + channels: output_channels, + sample_rate: output_sample_rate, + bits_per_sample: 16, + sample_format: SampleFormat::Int, + }; + let mut cursor = Cursor::new(Vec::new()); + let mut writer = WavWriter::new(&mut cursor, spec) + .map_err(|err| format!("failed to create wav writer: {err}"))?; + + for sample in converted { + let scaled = + ((sample as f32) * gain).round().clamp(i16::MIN as f32, i16::MAX as f32) as i16; + + writer.write_sample(scaled).map_err(|err| format!("failed writing wav sample: {err}"))?; + } + + writer.finalize().map_err(|err| format!("failed to finalize wav: {err}"))?; + + Ok(cursor.into_inner()) +} + +#[cfg(target_os = "macos")] +fn convert_pcm16( + input: &[i16], + input_sample_rate: u32, + input_channels: u16, + output_sample_rate: u32, + output_channels: u16, +) -> Vec { + if input.is_empty() || input_channels == 0 || output_channels == 0 { + return Vec::new(); + } + + let in_channels = input_channels as usize; + let out_channels = output_channels as usize; + let in_frames = input.len() / in_channels; + + if in_frames == 0 { + return Vec::new(); + } + + let out_frames = if input_sample_rate == output_sample_rate { + in_frames + } else { + (((in_frames as u64) * (output_sample_rate as u64)) / (input_sample_rate as u64)).max(1) + as usize + }; + let mut out = Vec::with_capacity(out_frames.saturating_mul(out_channels)); + + for out_idx in 0..out_frames { + let src_frame_idx = if output_sample_rate == input_sample_rate { + out_idx + } else { + ((out_idx as u64) * (input_sample_rate as u64) / (output_sample_rate as u64)) as usize + } + .min(in_frames - 1); + let src_start = src_frame_idx.saturating_mul(in_channels); + let src = &input[src_start..src_start + in_channels]; + + match (in_channels, out_channels) { + (1, 1) => out.push(src[0]), + (1, m) => { + let s = src[0]; + + for _ in 0..m { + out.push(s); + } + }, + (n, 1) => { + let sum: i32 = src.iter().map(|s| *s as i32).sum(); + + out.push((sum / (n as i32)) as i16); + }, + (n, m) if n == m => out.extend_from_slice(src), + (n, m) if n > m => out.extend_from_slice(&src[..m]), + (n, m) => { + out.extend_from_slice(src); + + let last = *src.last().unwrap_or(&0); + + for _ in n..m { + out.push(last); + } + }, + } + } + + out +} diff --git a/packages/voxit-core/src/auth.rs b/packages/voxit-core/src/auth.rs index fa8ddac..42a49bc 100644 --- a/packages/voxit-core/src/auth.rs +++ b/packages/voxit-core/src/auth.rs @@ -1,5 +1,280 @@ //! ChatGPT authentication using OAuth and device-code flow with secure token storage. +#[cfg(target_os = "macos")] +mod secitem_keychain { + use crate::auth::{AuthResult, ensure_keychain_user_interaction_allowed}; + use core_foundation::{ + base::{CFType, TCFType}, + boolean::CFBoolean, + data::CFData, + dictionary::CFDictionary, + string::CFString, + }; + use core_foundation_sys::{ + base::{CFTypeRef, kCFAllocatorDefault}, + data::CFDataRef, + dictionary::{CFDictionaryCreate, kCFTypeDictionaryKeyCallBacks}, + }; + use objc2_foundation::NSString; + use objc2_local_authentication::LAContext; + use security_framework_sys::{ + base::{SecCopyErrorMessageString, errSecDuplicateItem, errSecItemNotFound, errSecSuccess}, + item::{ + kSecAttrAccount, kSecAttrService, kSecClass, kSecClassGenericPassword, kSecReturnData, + kSecUseAuthenticationContext, kSecValueData, + }, + keychain_item::{SecItemAdd, SecItemCopyMatching, SecItemDelete, SecItemUpdate}, + }; + use std::{ops::Deref, ptr, time::Instant}; + + const OPERATION_PROMPT: &str = "Voxit needs Keychain access to continue sign in."; + + pub(super) struct SecItemQuery { + dict: CFDictionary, + _keepalive: Vec, + } + + pub(super) fn set_generic_password( + service: &str, + account: &str, + password: &[u8], + ) -> AuthResult<()> { + ensure_keychain_user_interaction_allowed(); + + tracing::info!(operation = "SecItemAdd", "starting secitem keychain operation"); + + let auth_context = make_authentication_context(); + let add_query = base_query(service, account, Some(&auth_context), Some(password), false)?; + let add_dict = add_query.dict; + let add_start = Instant::now(); + let add_status = unsafe { SecItemAdd(add_dict.as_concrete_TypeRef(), ptr::null_mut()) }; + + log_secitem_result("SecItemAdd", add_status, add_start); + + if add_status == errSecDuplicateItem { + tracing::info!(operation = "SecItemUpdate", "starting secitem keychain operation"); + + let update_query = base_query(service, account, Some(&auth_context), None, false)?; + let update_start = Instant::now(); + let query_dict = update_query.dict; + let update_dict = CFDictionary::from_CFType_pairs(&[( + unsafe { CFString::wrap_under_get_rule(kSecValueData) }, + CFData::from_buffer(password).into_CFType(), + )]); + let update_status = unsafe { + SecItemUpdate(query_dict.as_concrete_TypeRef(), update_dict.as_concrete_TypeRef()) + }; + + log_secitem_result("SecItemUpdate", update_status, update_start); + + if update_status != errSecSuccess { + return Err(format!( + "secitem keychain update failed: {}", + status_message(update_status) + )); + } + + return Ok(()); + } + if add_status != errSecSuccess { + return Err(format!("secitem keychain write failed: {}", status_message(add_status))); + } + + Ok(()) + } + + pub(super) fn get_generic_password( + service: &str, + account: &str, + ) -> AuthResult>> { + ensure_keychain_user_interaction_allowed(); + + tracing::info!(operation = "SecItemCopyMatching", "starting secitem keychain operation"); + + let auth_context = make_authentication_context(); + let query = base_query(service, account, Some(&auth_context), None, true)?; + let query_dict = query.dict; + let mut result: CFTypeRef = ptr::null_mut(); + let copy_start = Instant::now(); + let status = unsafe { SecItemCopyMatching(query_dict.as_concrete_TypeRef(), &mut result) }; + + log_secitem_result("SecItemCopyMatching", status, copy_start); + + if status == errSecItemNotFound { + return Ok(None); + } + if status != errSecSuccess { + return Err(format!("secitem keychain read failed: {}", status_message(status))); + } + if result.is_null() { + return Err("secitem keychain read returned empty data".to_string()); + } + + let data = unsafe { CFData::wrap_under_create_rule(result as CFDataRef) }; + + Ok(Some(data.bytes().to_vec())) + } + + pub(super) fn delete_generic_password(service: &str, account: &str) -> AuthResult<()> { + ensure_keychain_user_interaction_allowed(); + + tracing::info!(operation = "SecItemDelete", "starting secitem keychain operation"); + + let auth_context = make_authentication_context(); + let query = base_query(service, account, Some(&auth_context), None, false)?; + let query = query.dict; + let delete_start = Instant::now(); + let status = unsafe { SecItemDelete(query.as_concrete_TypeRef()) }; + + log_secitem_result("SecItemDelete", status, delete_start); + + if status == errSecItemNotFound || status == errSecSuccess { + return Ok(()); + } + + Err(format!("secitem keychain delete failed: {}", status_message(status))) + } + + fn make_authentication_context() -> impl Deref { + let context = unsafe { LAContext::new() }; + let localized_reason = NSString::from_str(OPERATION_PROMPT); + + unsafe { context.setLocalizedReason(&localized_reason) }; + + context + } + + fn base_query( + service: &str, + account: &str, + authentication_context: Option<&LAContext>, + secret: Option<&[u8]>, + request_return_data: bool, + ) -> AuthResult { + let mut keepalive = Vec::with_capacity(6 + usize::from(secret.is_some())); + let mut pairs = Vec::with_capacity(6 + usize::from(secret.is_some())); + + macro_rules! add_pair { + ($key:expr, $value:expr) => {{ + let key: CFType = $key; + let value: CFType = $value; + let key_ref = key.as_concrete_TypeRef() as CFTypeRef; + let value_ref = value.as_concrete_TypeRef() as CFTypeRef; + + keepalive.push(key); + keepalive.push(value); + pairs.push((key_ref, value_ref)); + }}; + } + + add_pair!(unsafe { CFString::wrap_under_get_rule(kSecClass).into_CFType() }, unsafe { + CFString::wrap_under_get_rule(kSecClassGenericPassword).into_CFType() + }); + add_pair!( + unsafe { CFString::wrap_under_get_rule(kSecAttrService).into_CFType() }, + CFString::from(service).into_CFType() + ); + add_pair!( + unsafe { CFString::wrap_under_get_rule(kSecAttrAccount).into_CFType() }, + CFString::from(account).into_CFType() + ); + + if request_return_data { + add_pair!( + unsafe { CFString::wrap_under_get_rule(kSecReturnData).into_CFType() }, + CFBoolean::from(true).into_CFType() + ); + } + + if let Some(secret) = secret { + add_pair!( + unsafe { CFString::wrap_under_get_rule(kSecValueData).into_CFType() }, + CFData::from_buffer(secret).into_CFType() + ); + } + if let Some(context) = authentication_context { + let context_key = unsafe { + CFString::wrap_under_get_rule(kSecUseAuthenticationContext).into_CFType() + }; + let context_key_ref = context_key.as_concrete_TypeRef() as CFTypeRef; + let context_ptr = (context as *const LAContext) as CFTypeRef; + + keepalive.push(context_key); + pairs.push((context_key_ref, context_ptr)); + } + + let query_dict = build_query_dictionary(&pairs)?; + + Ok(SecItemQuery { dict: query_dict, _keepalive: keepalive }) + } + + fn build_query_dictionary( + pairs: &[(CFTypeRef, CFTypeRef)], + ) -> AuthResult> { + let mut keys: Vec = Vec::with_capacity(pairs.len()); + let mut values: Vec = Vec::with_capacity(pairs.len()); + + for (key, value) in pairs { + keys.push(*key); + values.push(*value); + } + + let dictionary_ref = unsafe { + CFDictionaryCreate( + kCFAllocatorDefault, + keys.as_ptr(), + values.as_ptr(), + keys.len() as isize, + &kCFTypeDictionaryKeyCallBacks, + ptr::null(), + ) + }; + + if dictionary_ref.is_null() { + return Err("secitem query creation failed: null dictionary pointer".to_string()); + } + + let query = unsafe { CFDictionary::wrap_under_create_rule(dictionary_ref) }; + + Ok(query) + } + + fn log_secitem_result(operation: &str, status: i32, start_time: Instant) { + let status_message = status_message(status); + let elapsed_ms = start_time.elapsed().as_millis(); + + if status == errSecSuccess { + tracing::info!( + operation, + elapsed_ms, + os_status = status, + status_message = %status_message, + "secitem operation completed" + ); + + return; + } + + tracing::warn!( + operation, + elapsed_ms, + os_status = status, + status_message = %status_message, + "secitem operation failed" + ); + } + + fn status_message(status: i32) -> String { + let message_ref = unsafe { SecCopyErrorMessageString(status, ptr::null_mut()) }; + + if message_ref.is_null() { + return format!("OSStatus {status}"); + } + + unsafe { CFString::wrap_under_create_rule(message_ref).to_string() } + } +} + use std::{ collections::HashMap, env, @@ -8,7 +283,7 @@ use std::{ os::unix::fs::{OpenOptionsExt as _, PermissionsExt as _}, path::{Path, PathBuf}, string::{String, ToString}, - sync::{Condvar, Mutex, OnceLock, RwLock}, + sync::{Condvar, Mutex, OnceLock, RwLock, mpsc, mpsc::RecvTimeoutError}, thread, time::{Duration, Instant, SystemTime, UNIX_EPOCH}, }; @@ -29,20 +304,36 @@ type AuthResult = std::result::Result; const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann"; const DEFAULT_ISSUER: &str = "https://auth.openai.com"; const DEFAULT_PORT: u16 = 1_455; +const FALLBACK_PORT: u16 = 1_457; const REDIRECT_URI_PATH: &str = "/auth/callback"; const CODEX_OAUTH_ORIGINATOR: &str = "codex_cli_rs"; +const CODEX_OAUTH_SCOPE: &str = + "openid profile email offline_access api.connectors.read api.connectors.invoke"; +const REFRESH_TOKEN_URL: &str = "https://auth.openai.com/oauth/token"; +const TOKEN_REFRESH_SKEW_SECS: u64 = 60; const KEYRING_SERVICE: &str = "Voxit Auth"; const KEYRING_KEY_PREFIX: &str = "cli|"; const AUTH_FILE_NAME: &str = "auth.json"; const AUTH_FILE_FALLBACK_ENV: &str = "VOXIT_AUTH_FILE_FALLBACK"; const KEYRING_VERIFY_ENABLED_ENV: &str = "VOXIT_VERIFY_KEYRING"; +const KEYCHAIN_BACKEND_ENV: &str = "VOXIT_KEYCHAIN_BACKEND"; const KEYRING_VERIFY_ATTEMPTS: usize = 5; const KEYRING_VERIFY_DELAY_MS: u64 = 120; +const KEYCHAIN_OPERATION_TIMEOUT_SECS: u64 = 12; #[cfg(test)] const TEST_FORCE_KEYRING_ERROR_ENV: &str = "VOXIT_TEST_FORCE_KEYRING_ERROR"; static SESSION_TOKEN_CACHE: OnceLock>> = OnceLock::new(); static STORED_AUTH_CACHE: OnceLock<(Mutex, Condvar)> = OnceLock::new(); +static AUTH_STATUS_LOGGED: OnceLock<()> = OnceLock::new(); +static KEYCHAIN_BACKEND_LOGGED: OnceLock<()> = OnceLock::new(); + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum KeychainBackend { + Keyring, + #[cfg(target_os = "macos")] + SecItem, +} /// Authentication result returned to the UI after sign-in. #[derive(Clone, Debug)] @@ -60,6 +351,12 @@ pub struct AuthStatus { pub account_id: Option, } +#[derive(Clone, Debug)] +pub(crate) struct ChatGptAuthContext { + pub bearer_token: String, + pub account_id: Option, +} + #[derive(Clone, Debug, Default)] struct StoredAuthCacheState { loading: bool, @@ -104,8 +401,20 @@ struct DeviceLoginCode { code_verifier: String, } +#[derive(Debug, Deserialize)] +struct RefreshTokenResponse { + id_token: Option, + access_token: Option, + refresh_token: Option, + expires_in: Option, +} + /// Return stored authentication status without leaking token payload. pub fn status() -> AuthStatus { + if AUTH_STATUS_LOGGED.set(()).is_ok() { + tracing::info!("auth status() invoked"); + } + if let Some(tokens) = load_cached_tokens() { return AuthStatus { signed_in: true, account_id: tokens.account_id }; } @@ -171,40 +480,50 @@ where Ok(AuthRecord { account_id: tokens.account_id }) } -/// Returns `(access_token, account_id)` for API calls. -/// Falls back to `OPENAI_API_KEY` only when no stored OAuth token exists. +/// Returns `(access_token, account_id)` for ChatGPT-backed API calls. pub fn access_token() -> AuthResult<(String, Option)> { + let ctx = chatgpt_auth_context()?; + + Ok((ctx.bearer_token, ctx.account_id)) +} + +pub(crate) fn chatgpt_auth_context() -> AuthResult { if let Some(tokens) = load_cached_tokens() { - return Ok((tokens.access_token, tokens.account_id)); + return Ok(ChatGptAuthContext { + bearer_token: tokens.access_token, + account_id: tokens.account_id, + }); } if let Some(tokens) = load_stored_auth_tokens()? { cache_session_tokens(&tokens); - return Ok((tokens.access_token, tokens.account_id)); + return Ok(ChatGptAuthContext { + bearer_token: tokens.access_token, + account_id: tokens.account_id, + }); } - env::var("OPENAI_API_KEY") - .map(|value| (value, None)) - .map_err(|_| "not signed in and OPENAI_API_KEY is not set".to_string()) + Err("not signed in with ChatGPT".to_string()) } fn sign_in_with_chatgpt_browser() -> AuthResult { let pkce = generate_pkce(); let state = generate_state(); - let redirect_uri = browser_redirect_uri(); + let server = bind_callback_server()?; + let redirect_uri = browser_redirect_uri(server_port(&server)?); let authorize_url = build_authorize_url(&redirect_uri, &pkce.code_challenge, &state, DEFAULT_ISSUER); webbrowser::open(&authorize_url) .map_err(|_| "failed to open browser for ChatGPT login".to_string())?; - wait_for_callback(&state, &pkce, &redirect_uri, DEFAULT_ISSUER) + wait_for_callback(server, &state, &pkce, &redirect_uri, DEFAULT_ISSUER) } -fn browser_redirect_uri() -> String { +fn browser_redirect_uri(port: u16) -> String { // Codex OSS uses http://localhost:/auth/callback for browser OAuth redirect URI. // Aligning here avoids auth.openai.com rejecting 127.0.0.1 redirect URIs for this client id. - format!("http://localhost:{DEFAULT_PORT}{REDIRECT_URI_PATH}") + format!("http://localhost:{port}{REDIRECT_URI_PATH}") } fn valid_tokens_or_none(auth: Option) -> AuthResult> { @@ -217,8 +536,8 @@ fn valid_tokens_or_none(auth: Option) -> AuthResult return Ok(None), }; - if is_token_expired(tokens.created_at_unix, tokens.expires_in_seconds) { - return Ok(None); + if is_token_data_expired(&tokens) { + return refresh_stored_tokens(&tokens).map(Some); } Ok(Some(tokens)) @@ -232,7 +551,7 @@ fn load_cached_tokens() -> Option { }; let tokens = cached?; - if is_token_expired(tokens.created_at_unix, tokens.expires_in_seconds) { + if is_token_data_expired(&tokens) { clear_cached_session_tokens(); return None; @@ -460,6 +779,63 @@ fn env_flag_enabled(name: &str) -> bool { } } +fn keychain_backend() -> KeychainBackend { + #[cfg(target_os = "macos")] + { + let configured = + env::var(KEYCHAIN_BACKEND_ENV).unwrap_or_default().trim().to_ascii_lowercase(); + + if configured == "keyring" { + let backend = KeychainBackend::Keyring; + + if KEYCHAIN_BACKEND_LOGGED.set(()).is_ok() { + tracing::info!(?backend, env = %KEYCHAIN_BACKEND_ENV, "keychain backend selected"); + } + + return backend; + } + + let backend = KeychainBackend::SecItem; + + if KEYCHAIN_BACKEND_LOGGED.set(()).is_ok() { + tracing::info!(?backend, env = %KEYCHAIN_BACKEND_ENV, "keychain backend selected"); + } + + backend + } + #[cfg(not(target_os = "macos"))] + { + let backend = KeychainBackend::Keyring; + + if KEYCHAIN_BACKEND_LOGGED.set(()).is_ok() { + tracing::info!(?backend, env = %KEYCHAIN_BACKEND_ENV, "keychain backend selected"); + } + + backend + } +} + +fn run_with_timeout(operation: &str, timeout: Duration, operation_fn: F) -> AuthResult +where + T: Send + 'static, + F: FnOnce() -> AuthResult + Send + 'static, +{ + let (tx, rx) = mpsc::sync_channel(1); + let operation_name = operation.to_string(); + + thread::spawn(move || { + let _ = tx.send(operation_fn()); + }); + + match rx.recv_timeout(timeout) { + Ok(result) => result, + Err(RecvTimeoutError::Timeout) => + Err(format!("{operation_name} timed out after {}s", timeout.as_secs())), + Err(RecvTimeoutError::Disconnected) => + Err(format!("{operation_name} failed before completion")), + } +} + fn stored_auth_cache() -> &'static (Mutex, Condvar) { STORED_AUTH_CACHE.get_or_init(|| (Mutex::new(StoredAuthCacheState::default()), Condvar::new())) } @@ -472,7 +848,7 @@ fn load_stored_auth_tokens() -> AuthResult> { if let Some(cached) = state.result.clone() { match cached { Some(tokens) => { - if is_token_expired(tokens.created_at_unix, tokens.expires_in_seconds) { + if is_token_data_expired(&tokens) { clear_cached_session_tokens(); state.result = None; @@ -561,27 +937,208 @@ fn clear_cached_session_tokens() { *cache = None; } +#[cfg(target_os = "macos")] +fn ensure_keychain_user_interaction_allowed() { + // If keychain user interaction has been disabled in-process (or the OS decides it is), + // keychain reads can fail without presenting the expected password/permission prompt. + // Re-enable interaction before prompt-critical operations. + #[allow(non_snake_case)] + unsafe extern "C" { + fn SecKeychainSetUserInteractionAllowed(state: u8) -> i32; + } + + unsafe { + let _ = SecKeychainSetUserInteractionAllowed(1_u8); + } +} + +#[cfg(not(target_os = "macos"))] +fn ensure_keychain_user_interaction_allowed() {} + fn save_to_keyring(base: &Path, payload: &str) -> io::Result<()> { let key = auth_key(base).map_err(Error::other)?; + let payload = payload.to_string(); #[cfg(test)] if env_flag_enabled(TEST_FORCE_KEYRING_ERROR_ENV) { return Err(Error::other("forced test keyring error")); } - let entry = Entry::new(KEYRING_SERVICE, &key).map_err(Error::other)?; + let backend = keychain_backend(); + let op = "keychain write"; + let start = Instant::now(); + + tracing::info!(op = op, ?backend, "starting keychain operation"); + + let result = match backend { + #[cfg(target_os = "macos")] + KeychainBackend::SecItem => save_keychain_payload(&key, &payload).map_err(Error::other), + KeychainBackend::Keyring => + run_with_timeout(op, Duration::from_secs(KEYCHAIN_OPERATION_TIMEOUT_SECS), move || { + save_keychain_payload(&key, &payload) + }) + .map_err(Error::other), + }; + + match &result { + Ok(_) => tracing::info!( + op = op, + ?backend, + elapsed_ms = start.elapsed().as_millis(), + "completed keychain write operation" + ), + Err(err) => tracing::warn!( + op = op, + ?backend, + os_status = -1_i32, + status_message = %err.to_string(), + "failed keychain write operation" + ), + } - entry.set_password(payload).map_err(Error::other) + result } fn load_from_keyring(base: &Path) -> AuthResult> { let key = auth_key(base)?; - let entry = match Entry::new(KEYRING_SERVICE, &key) { + let backend = keychain_backend(); + let op = "keychain read"; + let start = Instant::now(); + + tracing::info!(op = op, ?backend, "starting keychain operation"); + + let value = match backend { + #[cfg(target_os = "macos")] + KeychainBackend::SecItem => load_keychain_payload(&key)?, + KeychainBackend::Keyring => + run_with_timeout(op, Duration::from_secs(KEYCHAIN_OPERATION_TIMEOUT_SECS), move || { + load_keychain_payload(&key) + })?, + }; + let value = match value { + Some(value) => value, + None => { + tracing::info!( + op = op, + ?backend, + elapsed_ms = start.elapsed().as_millis(), + "completed keychain read operation (not found)" + ); + + return Ok(None); + }, + }; + let parsed = serde_json::from_str(&value).map_err(|err| { + tracing::warn!( + op = op, + ?backend, + os_status = -1_i32, + error = %format!("decode keyring auth json failed: {err}"), + "failed keychain read operation" + ); + + format!("decode keyring auth json failed: {err}") + })?; + + tracing::info!( + op = op, + ?backend, + elapsed_ms = start.elapsed().as_millis(), + "completed keychain read operation" + ); + + Ok(Some(parsed)) +} + +fn clear_keyring_entry(base: &Path) -> AuthResult<()> { + let key = auth_key(base)?; + let backend = keychain_backend(); + let op = "keychain delete"; + let start = Instant::now(); + + tracing::info!(op = op, ?backend, "starting keychain operation"); + + let result = match backend { + #[cfg(target_os = "macos")] + KeychainBackend::SecItem => clear_keychain_payload(&key), + KeychainBackend::Keyring => + run_with_timeout(op, Duration::from_secs(KEYCHAIN_OPERATION_TIMEOUT_SECS), move || { + clear_keychain_payload(&key) + }), + }; + + match &result { + Ok(_) => tracing::info!( + op = op, + ?backend, + elapsed_ms = start.elapsed().as_millis(), + "completed keychain delete operation" + ), + Err(err) => tracing::warn!( + op = op, + ?backend, + os_status = -1_i32, + status_message = %err.to_string(), + "failed keychain delete operation" + ), + } + + result +} + +fn save_keychain_payload(key: &str, payload: &str) -> AuthResult<()> { + match keychain_backend() { + #[cfg(target_os = "macos")] + KeychainBackend::SecItem => + secitem_keychain::set_generic_password(KEYRING_SERVICE, key, payload.as_bytes()), + KeychainBackend::Keyring => save_keychain_payload_via_keyring(key, payload), + } +} + +fn load_keychain_payload(key: &str) -> AuthResult> { + match keychain_backend() { + #[cfg(target_os = "macos")] + KeychainBackend::SecItem => { + let bytes = secitem_keychain::get_generic_password(KEYRING_SERVICE, key)?; + + match bytes { + Some(bytes) => String::from_utf8(bytes) + .map(Some) + .map_err(|err| format!("decode keychain payload utf8 failed: {err}")), + None => Ok(None), + } + }, + KeychainBackend::Keyring => load_keychain_payload_via_keyring(key), + } +} + +fn clear_keychain_payload(key: &str) -> AuthResult<()> { + match keychain_backend() { + #[cfg(target_os = "macos")] + KeychainBackend::SecItem => secitem_keychain::delete_generic_password(KEYRING_SERVICE, key), + KeychainBackend::Keyring => clear_keychain_payload_via_keyring(key), + } +} + +fn save_keychain_payload_via_keyring(key: &str, payload: &str) -> AuthResult<()> { + ensure_keychain_user_interaction_allowed(); + + let entry = + Entry::new(KEYRING_SERVICE, key).map_err(|err| format!("keyring init failed: {err}"))?; + + entry.set_password(payload).map_err(|err| format!("keyring write failed: {err}")) +} + +fn load_keychain_payload_via_keyring(key: &str) -> AuthResult> { + ensure_keychain_user_interaction_allowed(); + + let entry = match Entry::new(KEYRING_SERVICE, key) { Ok(entry) => entry, Err(_) => return Ok(None), }; - let value = match entry.get_password() { - Ok(value) => value, + + match entry.get_password() { + Ok(value) => Ok(Some(value)), Err(err) => { let err_text = err.to_string(); @@ -589,19 +1146,15 @@ fn load_from_keyring(base: &Path) -> AuthResult> { return Ok(None); } - return Err(format!("keyring read failed: {err_text}")); + Err(format!("keyring read failed: {err_text}")) }, - }; - let parsed = serde_json::from_str(&value) - .map_err(|err| format!("decode keyring auth json failed: {err}"))?; - - Ok(Some(parsed)) + } } -fn clear_keyring_entry(base: &Path) -> AuthResult<()> { - let key = auth_key(base)?; +fn clear_keychain_payload_via_keyring(key: &str) -> AuthResult<()> { + ensure_keychain_user_interaction_allowed(); - match Entry::new(KEYRING_SERVICE, &key) { + match Entry::new(KEYRING_SERVICE, key) { Ok(entry) => { if let Err(err) = entry.delete_credential() { let message = err.to_string(); @@ -676,14 +1229,45 @@ fn clear_auth_file(base: &Path) -> AuthResult<()> { } } +fn bind_callback_server() -> AuthResult { + let primary = format!("127.0.0.1:{DEFAULT_PORT}"); + + match Server::http(&primary) { + Ok(server) => Ok(server), + Err(primary_err) => { + let fallback = format!("127.0.0.1:{FALLBACK_PORT}"); + + tracing::warn!( + error = %primary_err, + primary_port = DEFAULT_PORT, + fallback_port = FALLBACK_PORT, + "auth callback primary port unavailable; trying fallback" + ); + + Server::http(&fallback).map_err(|fallback_err| { + format!( + "failed to bind local callback server on {primary} or {fallback}: {fallback_err}" + ) + }) + }, + } +} + +fn server_port(server: &Server) -> AuthResult { + server + .server_addr() + .to_ip() + .map(|addr| addr.port()) + .ok_or_else(|| "failed to resolve local callback server port".to_string()) +} + fn wait_for_callback( + server: Server, expected_state: &str, pkce: &PkceCodes, redirect_uri: &str, issuer: &str, ) -> AuthResult { - let server = Server::http(format!("localhost:{DEFAULT_PORT}")) - .map_err(|err| format!("failed to bind local callback server: {err}"))?; let start = Instant::now(); let timeout = Duration::from_secs(180); @@ -933,7 +1517,7 @@ fn build_authorize_url( url.query_pairs_mut().append_pair("response_type", "code"); url.query_pairs_mut().append_pair("client_id", CLIENT_ID); url.query_pairs_mut().append_pair("redirect_uri", redirect_uri); - url.query_pairs_mut().append_pair("scope", "openid profile email offline_access"); + url.query_pairs_mut().append_pair("scope", CODEX_OAUTH_SCOPE); url.query_pairs_mut().append_pair("code_challenge", code_challenge); url.query_pairs_mut().append_pair("code_challenge_method", "S256"); url.query_pairs_mut().append_pair("id_token_add_organizations", "true"); @@ -1018,16 +1602,71 @@ fn parse_error_text(raw: &str) -> String { raw.to_string() } +fn refresh_stored_tokens(tokens: &TokenData) -> AuthResult { + let refresh_token = tokens.refresh_token.as_ref().ok_or_else(|| { + "stored ChatGPT OAuth token expired and has no refresh token; sign in again".to_string() + })?; + let payload = serde_json::json!({ + "client_id": CLIENT_ID, + "grant_type": "refresh_token", + "refresh_token": refresh_token, + }); + let response = post_json(REFRESH_TOKEN_URL, &payload.to_string()) + .map_err(|err| format!("ChatGPT OAuth refresh failed: {err}"))?; + let parsed: RefreshTokenResponse = serde_json::from_str(&response) + .map_err(|err| format!("invalid ChatGPT OAuth refresh response: {err}"))?; + let access_token = parsed + .access_token + .ok_or_else(|| "ChatGPT OAuth refresh response missing access_token".to_string())?; + let id_token = parsed.id_token.unwrap_or_else(|| tokens.id_token.clone()); + let account_id = extract_claims(&id_token).and_then(|claims| { + claims.get("chatgpt_account_id").and_then(|value| value.as_str()).map(str::to_string) + }); + let refreshed = TokenData { + id_token, + access_token, + refresh_token: parsed.refresh_token.or_else(|| tokens.refresh_token.clone()), + account_id: account_id.or_else(|| tokens.account_id.clone()), + created_at_unix: now_unix(), + expires_in_seconds: parsed.expires_in, + }; + + store_tokens(&refreshed)?; + + Ok(refreshed) +} + fn extract_claims(id_token: &str) -> Option> { - let mut parts = id_token.split('.'); + let value = decode_jwt_payload(id_token)?; + let claims = value.get("https://api.openai.com/auth")?.as_object()?; + + Some(claims.clone().into_iter().collect()) +} + +fn decode_jwt_payload(jwt: &str) -> Option { + let mut parts = jwt.split('.'); let _header = parts.next()?; let payload = parts.next()?; let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload.as_bytes()).ok()?; - let value: Value = serde_json::from_slice(&payload).ok()?; - let claims = value.get("https://api.openai.com/auth")?.as_object()?; - Some(claims.clone().into_iter().collect()) + serde_json::from_slice(&payload).ok() +} + +fn token_expires_at_unix(tokens: &TokenData) -> Option { + if let Some(expires_in) = tokens.expires_in_seconds { + return Some(tokens.created_at_unix.saturating_add(expires_in)); + } + + decode_jwt_payload(&tokens.access_token)?.get("exp")?.as_u64() +} + +fn is_token_data_expired(tokens: &TokenData) -> bool { + if let Some(expires_at) = token_expires_at_unix(tokens) { + return now_unix().saturating_add(TOKEN_REFRESH_SKEW_SECS) >= expires_at; + } + + is_token_expired(tokens.created_at_unix, tokens.expires_in_seconds) } fn is_token_expired(created_at_unix: u64, expires_in: Option) -> bool { @@ -1105,12 +1744,18 @@ fn html_escape(raw: &str) -> String { #[cfg(test)] mod tests { - use std::{env, fs, sync::Mutex}; + use std::{ + env, fs, + sync::Mutex, + thread, + time::{Duration, Instant}, + }; use crate::auth::{ - self, AUTH_FILE_FALLBACK_ENV, CLIENT_ID, CODEX_OAUTH_ORIGINATOR, DEFAULT_ISSUER, - DEFAULT_PORT, HashMap, KEYRING_VERIFY_ENABLED_ENV, REDIRECT_URI_PATH, StoredAuth, - TEST_FORCE_KEYRING_ERROR_ENV, TokenData, Url, + self, AUTH_FILE_FALLBACK_ENV, CLIENT_ID, CODEX_OAUTH_ORIGINATOR, CODEX_OAUTH_SCOPE, + DEFAULT_ISSUER, DEFAULT_PORT, HashMap, KEYCHAIN_BACKEND_ENV, KEYRING_VERIFY_ENABLED_ENV, + KeychainBackend, REDIRECT_URI_PATH, StoredAuth, TEST_FORCE_KEYRING_ERROR_ENV, TokenData, + Url, }; static TEST_MUTEX: Mutex<()> = Mutex::new(()); @@ -1138,7 +1783,7 @@ mod tests { #[test] fn browser_redirect_uri_matches_codex() { assert_eq!( - auth::browser_redirect_uri(), + auth::browser_redirect_uri(DEFAULT_PORT), format!("http://localhost:{DEFAULT_PORT}{REDIRECT_URI_PATH}") ); } @@ -1146,7 +1791,7 @@ mod tests { #[test] fn authorize_url_includes_expected_codex_params() { let url = auth::build_authorize_url( - &auth::browser_redirect_uri(), + &auth::browser_redirect_uri(DEFAULT_PORT), "challenge123", "state123", DEFAULT_ISSUER, @@ -1159,12 +1804,9 @@ mod tests { assert_eq!(params.get("client_id").map(String::as_str), Some(CLIENT_ID)); assert_eq!( params.get("redirect_uri").map(String::as_str), - Some(auth::browser_redirect_uri().as_str()) - ); - assert_eq!( - params.get("scope").map(String::as_str), - Some("openid profile email offline_access") + Some(auth::browser_redirect_uri(DEFAULT_PORT).as_str()) ); + assert_eq!(params.get("scope").map(String::as_str), Some(CODEX_OAUTH_SCOPE)); assert_eq!(params.get("code_challenge").map(String::as_str), Some("challenge123")); assert_eq!(params.get("code_challenge_method").map(String::as_str), Some("S256")); assert_eq!(params.get("id_token_add_organizations").map(String::as_str), Some("true")); @@ -1251,6 +1893,50 @@ mod tests { restore_env(KEYRING_VERIFY_ENABLED_ENV, previous); } + #[test] + fn keychain_backend_escape_hatch_selects_keyring() { + let _guard = TEST_MUTEX.lock().unwrap(); + let previous = set_env(KEYCHAIN_BACKEND_ENV, Some("keyring")); + + assert_eq!(auth::keychain_backend(), KeychainBackend::Keyring); + + restore_env(KEYCHAIN_BACKEND_ENV, previous); + } + + #[cfg(target_os = "macos")] + #[test] + fn keychain_backend_defaults_to_secitem_on_macos() { + let _guard = TEST_MUTEX.lock().unwrap(); + let previous = set_env(KEYCHAIN_BACKEND_ENV, None); + + assert_eq!(auth::keychain_backend(), KeychainBackend::SecItem); + + restore_env(KEYCHAIN_BACKEND_ENV, previous); + } + + #[test] + fn timeout_helper_returns_success_before_deadline() { + let value = + auth::run_with_timeout("test-op", Duration::from_millis(80), || Ok::(7)) + .expect("operation should complete"); + + assert_eq!(value, 7); + } + + #[test] + fn timeout_helper_stops_waiting_after_deadline() { + let start = Instant::now(); + let err = auth::run_with_timeout("test-timeout", Duration::from_millis(20), || { + thread::sleep(Duration::from_millis(120)); + + Ok::<(), String>(()) + }) + .expect_err("operation should time out"); + + assert!(err.contains("timed out")); + assert!(start.elapsed() < Duration::from_millis(90)); + } + #[test] fn fallback_to_auth_json_preserves_file_when_keyring_fails() { let _guard = TEST_MUTEX.lock().unwrap(); diff --git a/packages/voxit-core/src/inference.rs b/packages/voxit-core/src/inference.rs new file mode 100644 index 0000000..dff7be6 --- /dev/null +++ b/packages/voxit-core/src/inference.rs @@ -0,0 +1,325 @@ +//! Provider-routed transcription and rewrite pipeline. + +use std::sync::mpsc::{Receiver, Sender}; +#[cfg(target_os = "macos")] use std::{collections::BTreeMap, time::Instant}; + +#[cfg(target_os = "macos")] +use crate::providers::{self, InferenceProvider, RewriteRequest, TranscriptionRequest}; +use crate::{ + providers::chatgpt::ChatGptProvider, + realtime::{RealtimeError, RealtimeEvent, RealtimeSession, RealtimeSessionConfig}, +}; +use voxit_audio::AudioChunk; + +/// Rewrite outcome status. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum RewriteState { + /// Rewrite was intentionally skipped before request. + Skipped, + /// Rewrite succeeded and output passed all safety checks. + Applied, + /// Rewrite was returned but rejected due to protected token mismatch. + Rejected, +} + +/// Guarded rewrite result payload. +#[derive(Clone, Debug)] +pub struct RewriteResult { + /// Optional rewritten transcript when state is `Applied`. + pub rewritten_transcript: Option, + /// Rewrite decision. + pub state: RewriteState, + /// Optional reason for skipped or rejected rewrite. + pub reason: Option, +} + +/// Background event sent to the UI thread. +#[derive(Debug)] +pub enum InferenceEvent { + /// Pass2 transcription completed with raw transcript text. + Pass2Completed { + /// Pass2 duration in milliseconds. + total_ms: u64, + /// Raw transcript text. + raw_transcript: String, + }, + /// Pass3 rewrite completed (or rejected by guard). + RewriteCompleted { + /// Pass3 duration in milliseconds. + total_ms: u64, + /// Rewrite result. + result: RewriteResult, + }, + /// Pipeline failed with an error. + Failed(String), +} + +/// Start the configured realtime transcription provider. +#[cfg(target_os = "macos")] +pub fn start_realtime_session( + config: RealtimeSessionConfig, + chunk_rx: Receiver, + event_tx: Sender, +) -> Result { + let provider = default_provider().map_err(|err| RealtimeError::RuntimeError { + reason: format!("ChatGPT OAuth provider unavailable: {err}"), + })?; + + provider.start_realtime_session(config, chunk_rx, event_tx) +} + +/// Realtime inference is unavailable on non-macOS placeholder builds. +#[cfg(not(target_os = "macos"))] +pub fn start_realtime_session( + config: RealtimeSessionConfig, + chunk_rx: Receiver, + event_tx: Sender, +) -> Result { + let _ = config; + let _ = chunk_rx; + let _ = event_tx; + + Err(RealtimeError::DependencyUnavailable { + reason: "inference pipeline is only enabled on macOS builds".to_string(), + }) +} + +/// Transcribes WAV bytes using the configured Pass2 provider. +#[cfg(target_os = "macos")] +pub fn transcribe_only(wav: &[u8], model: &str) -> Result<(String, u64), String> { + let started = Instant::now(); + let provider = default_provider()?; + let raw = provider.transcribe(TranscriptionRequest { wav, model })?; + + Ok((raw, started.elapsed().as_millis() as u64)) +} + +/// Inference pipeline is unavailable on non-macOS placeholder builds. +#[cfg(not(target_os = "macos"))] +pub fn transcribe_only(_wav: &[u8], _model: &str) -> Result<(String, u64), String> { + Err("inference pipeline is only enabled on macOS builds.".to_string()) +} + +/// Rewrites transcript text with protected-token guard checks. +#[cfg(target_os = "macos")] +pub fn rewrite_only(text: &str, model: &str) -> Result<(RewriteResult, u64), String> { + if text.trim().is_empty() { + return Ok(( + RewriteResult { + rewritten_transcript: None, + state: RewriteState::Skipped, + reason: Some("empty transcript; rewrite skipped".to_string()), + }, + 0, + )); + } + + let started = Instant::now(); + let result = rewrite_with_guard(text, model)?; + + Ok((result, started.elapsed().as_millis() as u64)) +} + +/// Inference pipeline is unavailable on non-macOS placeholder builds. +#[cfg(not(target_os = "macos"))] +pub fn rewrite_only(_text: &str, _model: &str) -> Result<(RewriteResult, u64), String> { + Err("inference pipeline is only enabled on macOS builds.".to_string()) +} + +#[cfg(target_os = "macos")] +fn default_provider() -> Result { + providers::chatgpt_oauth_provider() +} + +#[cfg(target_os = "macos")] +fn rewrite_with_guard(text: &str, model: &str) -> Result { + let provider = default_provider()?; + let rewritten = provider.rewrite(RewriteRequest { text, model })?; + let baseline = protected_token_multiset(text); + let candidate = protected_token_multiset(&rewritten); + + if baseline != candidate { + return Ok(RewriteResult { + rewritten_transcript: None, + state: RewriteState::Rejected, + reason: Some( + "rewrite changed protected tokens (numbers, dates, or currency). Using ASR transcript for safety.".to_string(), + ), + }); + } + + Ok(RewriteResult { + rewritten_transcript: Some(rewritten), + state: RewriteState::Applied, + reason: None, + }) +} + +#[cfg(target_os = "macos")] +fn protected_token_multiset(text: &str) -> BTreeMap { + let mut items = BTreeMap::new(); + + for token in text.split_whitespace() { + let token = trim_token(token); + + if token.is_empty() { + continue; + } + + if let Some(normalized) = normalize_currency_token(token) { + *items.entry(normalized).or_default() += 1; + + continue; + } + if let Some(normalized) = normalize_date_token(token) { + *items.entry(normalized).or_default() += 1; + + continue; + } + if let Some(normalized) = normalize_numeric_token(token) { + *items.entry(normalized).or_default() += 1; + } + } + + items +} + +#[cfg(target_os = "macos")] +fn trim_token(raw: &str) -> &str { + raw.trim_matches(|ch: char| { + matches!( + ch, + '.' | ',' | ';' | ':' | '!' | '?' | '"' | '\'' | '(' | ')' | '[' | ']' | '{' | '}' + ) + }) +} + +#[cfg(target_os = "macos")] +fn normalize_currency_token(token: &str) -> Option { + if let Some(without_symbol) = token.strip_prefix('$') { + let value = normalize_numeric_token(without_symbol)?; + + return Some(format!("${value}")); + } + if let Some(without_symbol) = token.strip_prefix('€') { + let value = normalize_numeric_token(without_symbol)?; + + return Some(format!("€{value}")); + } + if let Some(without_symbol) = token.strip_prefix('£') { + let value = normalize_numeric_token(without_symbol)?; + + return Some(format!("£{value}")); + } + if let Some(without_symbol) = token.strip_prefix('¥') { + let value = normalize_numeric_token(without_symbol)?; + + return Some(format!("¥{value}")); + } + + None +} + +#[cfg(target_os = "macos")] +fn normalize_date_token(token: &str) -> Option { + let parts: Vec<&str> = token.split(['/', '-']).collect(); + + if parts.len() != 3 { + return None; + } + + let norm: Vec<_> = parts.iter().map(|part| part.trim()).collect(); + + if !norm.iter().all(|part| !part.is_empty() && part.chars().all(|c| c.is_ascii_digit())) { + return None; + } + + let year = if norm[0].len() == 4 { norm[0] } else { norm[2] }; + let month = norm[1]; + let day = norm[2]; + + Some(format!("date|{year}-{month}-{day}")) +} + +#[cfg(target_os = "macos")] +fn normalize_numeric_token(token: &str) -> Option { + if token.is_empty() { + return None; + } + + let trimmed = token.trim_matches(|ch: char| ch == '$' || ch == '£' || ch == '€' || ch == '¥'); + + if trimmed.is_empty() { + return None; + } + + let mut digits_seen = false; + let mut dot_seen = false; + let mut normalized = String::new(); + + for ch in trimmed.chars() { + if ch.is_ascii_digit() { + digits_seen = true; + + normalized.push(ch); + + continue; + } + if ch == '.' { + if dot_seen { + return None; + } + + dot_seen = true; + + normalized.push(ch); + + continue; + } + if ch != ',' { + return None; + } + } + + if digits_seen { Some(normalized) } else { None } +} + +#[cfg(test)] +#[cfg(target_os = "macos")] +mod tests { + use crate::inference::{self}; + + #[test] + fn normalize_numeric_token_extracts_stable_forms() { + assert_eq!(inference::normalize_numeric_token("12,345.60"), Some("12345.60".to_string())); + assert_eq!(inference::normalize_numeric_token("abc"), None); + } + + #[test] + fn normalize_currency_token_parses_common_markers() { + assert_eq!(inference::normalize_currency_token("$12.50"), Some("$12.50".to_string())); + assert_eq!(inference::normalize_currency_token("€1,200"), Some("€1200".to_string())); + assert_eq!(inference::normalize_currency_token("100"), None); + } + + #[test] + fn normalize_date_token_parses_common_patterns() { + assert_eq!( + inference::normalize_date_token("2026-02-28"), + Some("date|2026-02-28".to_string()) + ); + assert_eq!(inference::normalize_date_token("02/28/26"), Some("date|26-28-26".to_string())); + assert_eq!(inference::normalize_date_token("abc"), None); + } + + #[test] + fn rewrite_guard_flags_numeric_changes() { + let raw = + inference::protected_token_multiset("call me at 120 and send 25 dollars on 2026-02-28"); + let rewritten = inference::protected_token_multiset( + "call me at one twenty and send 26 dollars on 2026-02-28", + ); + + assert_ne!(raw, rewritten); + } +} diff --git a/packages/voxit-core/src/lib.rs b/packages/voxit-core/src/lib.rs index 8591c06..19cbfc7 100644 --- a/packages/voxit-core/src/lib.rs +++ b/packages/voxit-core/src/lib.rs @@ -2,20 +2,26 @@ pub mod auth; pub mod config; +pub mod inference; pub mod openai; pub mod realtime; pub mod transcript; +mod audio_payload; +mod providers; + pub use self::{ auth::{ AuthRecord, AuthStatus, access_token, sign_in_with_chatgpt, sign_in_with_device_code, sign_in_with_device_code_with_progress, sign_out, status, }, config::Config, - openai::{InferenceEvent, RewriteResult, RewriteState, rewrite_only, transcribe_only}, + inference::{ + InferenceEvent, RewriteResult, RewriteState, rewrite_only, start_realtime_session, + transcribe_only, + }, realtime::{ REALTIME_ENDPOINT, RealtimeError, RealtimeEvent, RealtimeSession, RealtimeSessionConfig, - start_realtime_session, }, transcript::{TranscriptAssembler, TranscriptEvent, TranscriptSnapshot}, }; diff --git a/packages/voxit-core/src/openai.rs b/packages/voxit-core/src/openai.rs index 15f2228..03f7173 100644 --- a/packages/voxit-core/src/openai.rs +++ b/packages/voxit-core/src/openai.rs @@ -1,401 +1,5 @@ -//! OpenAI transcription and optional rewrite client. +//! Backward-compatible inference re-exports. -#[cfg(target_os = "macos")] use std::collections::BTreeMap; -#[cfg(target_os = "macos")] use std::time::{Duration, Instant}; - -use reqwest::blocking::Response; -#[cfg(target_os = "macos")] use reqwest::blocking::{ - Client, - multipart::{Form, Part}, +pub use crate::inference::{ + InferenceEvent, RewriteResult, RewriteState, rewrite_only, transcribe_only, }; -#[cfg(target_os = "macos")] use serde_json::Value; - -use crate::auth; - -/// Rewrite outcome status. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub enum RewriteState { - /// Rewrite was intentionally skipped before request. - Skipped, - /// Rewrite succeeded and output passed all safety checks. - Applied, - /// Rewrite was returned but rejected due to protected token mismatch. - Rejected, -} - -/// Guarded rewrite result payload. -#[derive(Clone, Debug)] -pub struct RewriteResult { - /// Optional rewritten transcript when state is `Applied`. - pub rewritten_transcript: Option, - /// Rewrite decision. - pub state: RewriteState, - /// Optional reason for skipped or rejected rewrite. - pub reason: Option, -} - -/// Background event sent to the UI thread. -#[derive(Debug)] -pub enum InferenceEvent { - /// Pass2 transcription completed with raw transcript text. - Pass2Completed { - /// Pass2 duration in milliseconds. - total_ms: u64, - /// Raw transcript text. - raw_transcript: String, - }, - /// Pass3 rewrite completed (or rejected by guard). - RewriteCompleted { - /// Pass3 duration in milliseconds. - total_ms: u64, - /// Rewrite result. - result: RewriteResult, - }, - /// Pipeline failed with an error. - Failed(String), -} - -/// Transcribes WAV bytes using the configured Pass2 model. -#[cfg(target_os = "macos")] -pub fn transcribe_only(wav: &[u8], model: &str) -> Result<(String, u64), String> { - let started = Instant::now(); - let raw = transcribe(wav, model)?; - - Ok((raw, started.elapsed().as_millis() as u64)) -} - -/// OpenAI pipeline is unavailable on non-macOS placeholder builds. -#[cfg(not(target_os = "macos"))] -pub fn transcribe_only(_wav: &[u8], _model: &str) -> Result<(String, u64), String> { - Err("OpenAI pipeline is only enabled on macOS builds.".to_string()) -} - -/// Rewrites transcript text with protected-token guard checks. -#[cfg(target_os = "macos")] -pub fn rewrite_only(text: &str, model: &str) -> Result<(RewriteResult, u64), String> { - if text.trim().is_empty() { - return Ok(( - RewriteResult { - rewritten_transcript: None, - state: RewriteState::Skipped, - reason: Some("empty transcript; rewrite skipped".to_string()), - }, - 0, - )); - } - - let started = Instant::now(); - let result = rewrite_with_guard(text, model)?; - - Ok((result, started.elapsed().as_millis() as u64)) -} - -/// Rewrite pipeline is unavailable on non-macOS placeholder builds. -#[cfg(not(target_os = "macos"))] -pub fn rewrite_only(_text: &str, _model: &str) -> Result<(RewriteResult, u64), String> { - Err("OpenAI pipeline is only enabled on macOS builds.".to_string()) -} - -#[cfg(target_os = "macos")] -fn transcribe(wav: &[u8], model: &str) -> Result { - let body = - post_multipart("https://api.openai.com/v1/audio/transcriptions", wav, "audio.wav", model)?; - - extract_json_value(&body, &["/text", "/output_text"]) - .or_else(|| extract_json_output_array_value(&body)) - .ok_or_else(|| "transcription response has no usable text".to_string()) -} - -#[cfg(target_os = "macos")] -fn rewrite(text: &str, model: &str) -> Result { - let prompt = "Rewrite the transcript for punctuation and readability. Keep the meaning, numbers, and names intact."; - let body = serde_json::json!({ - "model": model, - "input": format!("Transcript: {text}"), - "instructions": prompt, - "temperature": 0.2, - }); - let body = post_json("https://api.openai.com/v1/responses", body)?; - - extract_json_value(&body, &["/output_text", "/output/0/content/0/text"]) - .or_else(|| extract_json_output_array_value(&body)) - .or_else(|| extract_json_value(&body, &["/text", "/choices/0/message/content"])) - .ok_or_else(|| "rewrite response has no usable text".to_string()) -} - -#[cfg(target_os = "macos")] -fn rewrite_with_guard(text: &str, model: &str) -> Result { - let rewritten = rewrite(text, model)?; - let baseline = protected_token_multiset(text); - let candidate = protected_token_multiset(&rewritten); - - if baseline != candidate { - return Ok(RewriteResult { - rewritten_transcript: None, - state: RewriteState::Rejected, - reason: Some( - "rewrite changed protected tokens (numbers, dates, or currency). Using ASR transcript for safety.".to_string(), - ), - }); - } - - Ok(RewriteResult { - rewritten_transcript: Some(rewritten), - state: RewriteState::Applied, - reason: None, - }) -} - -#[cfg(target_os = "macos")] -fn post_multipart( - url: &str, - file_bytes: &[u8], - file_name: &str, - model: &str, -) -> Result { - let (api_key, account_id) = auth_token()?; - let client = Client::builder() - .timeout(Duration::from_secs(120)) - .build() - .map_err(|err| format!("failed to build OpenAI HTTP client: {err}"))?; - let file_part = Part::bytes(file_bytes.to_vec()) - .file_name(file_name.to_string()) - .mime_str("audio/wav") - .map_err(|err| format!("invalid file mime: {err}"))?; - let form = Form::new().text("model", model.to_string()).part("file", file_part); - let mut request = client.post(url).bearer_auth(api_key).multipart(form); - - if let Some(account_id) = account_id { - request = request.header("ChatGPT-Account-ID", account_id); - } - - let response = request.send().map_err(|err| format!("transcription request failed: {err}"))?; - - check_status(response, "transcription") -} - -#[cfg(target_os = "macos")] -fn post_json(url: &str, body: Value) -> Result { - let (api_key, account_id) = auth_token()?; - let client = Client::builder() - .timeout(Duration::from_secs(120)) - .build() - .map_err(|err| format!("failed to build OpenAI HTTP client: {err}"))?; - let mut request = client.post(url).bearer_auth(api_key).json(&body); - - if let Some(account_id) = account_id { - request = request.header("ChatGPT-Account-ID", account_id); - } - - let response = request.send().map_err(|err| format!("rewrite request failed: {err}"))?; - - check_status(response, "rewrite") -} - -#[cfg(target_os = "macos")] -fn check_status(response: Response, step: &str) -> Result { - if !response.status().is_success() { - let status = response.status(); - let body = response.text().unwrap_or_else(|_| "".to_string()); - - return Err(format!("{step} failed with status {status}: {body}")); - } - - response.text().map_err(|err| format!("failed to read {step} response body: {err}")) -} - -#[cfg(target_os = "macos")] -fn extract_json_value(body: &str, pointers: &[&str]) -> Option { - let value = serde_json::from_str::(body).ok()?; - - pointers - .iter() - .find_map(|pointer| value.pointer(pointer).and_then(Value::as_str).map(str::to_string)) -} - -#[cfg(target_os = "macos")] -fn extract_json_output_array_value(body: &str) -> Option { - let value = serde_json::from_str::(body).ok()?; - let outputs = value.get("output")?.as_array()?; - - outputs.iter().find_map(|entry| { - entry.get("content").and_then(Value::as_array)?.iter().find_map(|chunk| { - chunk - .get("text") - .or_else(|| chunk.get("transcript")) - .and_then(Value::as_str) - .map(str::to_string) - }) - }) -} - -#[cfg(target_os = "macos")] -fn protected_token_multiset(text: &str) -> BTreeMap { - let mut items = BTreeMap::new(); - - for token in text.split_whitespace() { - let token = trim_token(token); - - if token.is_empty() { - continue; - } - - if let Some(normalized) = normalize_currency_token(token) { - *items.entry(normalized).or_default() += 1; - - continue; - } - if let Some(normalized) = normalize_date_token(token) { - *items.entry(normalized).or_default() += 1; - - continue; - } - if let Some(normalized) = normalize_numeric_token(token) { - *items.entry(normalized).or_default() += 1; - } - } - - items -} - -#[cfg(target_os = "macos")] -fn trim_token(raw: &str) -> &str { - raw.trim_matches(|ch: char| { - matches!( - ch, - '.' | ',' | ';' | ':' | '!' | '?' | '"' | '\'' | '(' | ')' | '[' | ']' | '{' | '}' - ) - }) -} - -#[cfg(target_os = "macos")] -fn normalize_currency_token(token: &str) -> Option { - if let Some(without_symbol) = token.strip_prefix('$') { - let value = normalize_numeric_token(without_symbol)?; - - return Some(format!("${value}")); - } - if let Some(without_symbol) = token.strip_prefix('€') { - let value = normalize_numeric_token(without_symbol)?; - - return Some(format!("€{value}")); - } - if let Some(without_symbol) = token.strip_prefix('£') { - let value = normalize_numeric_token(without_symbol)?; - - return Some(format!("£{value}")); - } - if let Some(without_symbol) = token.strip_prefix('¥') { - let value = normalize_numeric_token(without_symbol)?; - - return Some(format!("¥{value}")); - } - - None -} - -#[cfg(target_os = "macos")] -fn normalize_date_token(token: &str) -> Option { - let parts: Vec<&str> = token.split(['/', '-']).collect(); - - if parts.len() != 3 { - return None; - } - - let norm: Vec<_> = parts.iter().map(|part| part.trim()).collect(); - - if !norm.iter().all(|part| !part.is_empty() && part.chars().all(|c| c.is_ascii_digit())) { - return None; - } - - let year = if norm[0].len() == 4 { norm[0] } else { norm[2] }; - let month = norm[1]; - let day = norm[2]; - - Some(format!("date|{year}-{month}-{day}")) -} - -#[cfg(target_os = "macos")] -fn normalize_numeric_token(token: &str) -> Option { - if token.is_empty() { - return None; - } - - let trimmed = token.trim_matches(|ch: char| ch == '$' || ch == '£' || ch == '€' || ch == '¥'); - - if trimmed.is_empty() { - return None; - } - - let mut digits_seen = false; - let mut dot_seen = false; - let mut normalized = String::new(); - - for ch in trimmed.chars() { - if ch.is_ascii_digit() { - digits_seen = true; - - normalized.push(ch); - - continue; - } - if ch == '.' { - if dot_seen { - return None; - } - - dot_seen = true; - - normalized.push(ch); - - continue; - } - if ch != ',' { - return None; - } - } - - if digits_seen { Some(normalized) } else { None } -} - -#[cfg(target_os = "macos")] -fn auth_token() -> Result<(String, Option), String> { - auth::access_token().map_err(|err| format!("auth token not available: {err}")) -} - -#[cfg(test)] -#[cfg(target_os = "macos")] -mod tests { - use crate::openai::{self}; - - #[test] - fn normalize_numeric_token_extracts_stable_forms() { - assert_eq!(openai::normalize_numeric_token("12,345.60"), Some("12345.60".to_string())); - assert_eq!(openai::normalize_numeric_token("abc"), None); - } - - #[test] - fn normalize_currency_token_parses_common_markers() { - assert_eq!(openai::normalize_currency_token("$12.50"), Some("$12.50".to_string())); - assert_eq!(openai::normalize_currency_token("€1,200"), Some("€1200".to_string())); - assert_eq!(openai::normalize_currency_token("100"), None); - } - - #[test] - fn normalize_date_token_parses_common_patterns() { - assert_eq!(openai::normalize_date_token("2026-02-28"), Some("date|2026-02-28".to_string())); - assert_eq!(openai::normalize_date_token("02/28/26"), Some("date|26-28-26".to_string())); - assert_eq!(openai::normalize_date_token("abc"), None); - } - - #[test] - fn rewrite_guard_flags_numeric_changes() { - let raw = - openai::protected_token_multiset("call me at 120 and send 25 dollars on 2026-02-28"); - let rewritten = openai::protected_token_multiset( - "call me at one twenty and send 26 dollars on 2026-02-28", - ); - - assert_ne!(raw, rewritten); - } -} diff --git a/packages/voxit-core/src/providers.rs b/packages/voxit-core/src/providers.rs new file mode 100644 index 0000000..9c3ff84 --- /dev/null +++ b/packages/voxit-core/src/providers.rs @@ -0,0 +1,46 @@ +//! Speech and rewrite provider interfaces. + +#[cfg(target_os = "macos")] pub(crate) mod chatgpt; + +#[cfg(target_os = "macos")] use std::sync::mpsc::{Receiver, Sender}; + +#[cfg(target_os = "macos")] +use crate::realtime::{RealtimeError, RealtimeEvent, RealtimeSession, RealtimeSessionConfig}; +#[cfg(target_os = "macos")] use voxit_audio::AudioChunk; + +/// Provider boundary for voice inference backends. +#[cfg(target_os = "macos")] +pub(crate) trait InferenceProvider { + fn provider_id(&self) -> &'static str; + + fn start_realtime_session( + &self, + config: RealtimeSessionConfig, + chunk_rx: Receiver, + event_tx: Sender, + ) -> Result; + + fn transcribe(&self, request: TranscriptionRequest<'_>) -> Result; + + fn rewrite(&self, request: RewriteRequest<'_>) -> Result; +} + +/// Pass2 transcription request. +#[cfg(target_os = "macos")] +pub(crate) struct TranscriptionRequest<'a> { + pub(crate) wav: &'a [u8], + pub(crate) model: &'a str, +} + +/// Pass3 rewrite request. +#[cfg(target_os = "macos")] +pub(crate) struct RewriteRequest<'a> { + pub(crate) text: &'a str, + pub(crate) model: &'a str, +} + +/// Resolve the only provider enabled in the first provider-abstraction version. +#[cfg(target_os = "macos")] +pub(crate) fn chatgpt_oauth_provider() -> Result { + chatgpt::ChatGptProvider::from_stored_oauth() +} diff --git a/packages/voxit-core/src/providers/chatgpt.rs b/packages/voxit-core/src/providers/chatgpt.rs new file mode 100644 index 0000000..7d0048d --- /dev/null +++ b/packages/voxit-core/src/providers/chatgpt.rs @@ -0,0 +1,208 @@ +//! ChatGPT OAuth-backed inference provider. + +use std::{ + sync::mpsc::{Receiver, Sender}, + time::Duration, +}; + +use reqwest::blocking::{ + Client, RequestBuilder, Response, + multipart::{Form, Part}, +}; +use serde_json::Value; + +use crate::{ + audio_payload::{self, PreparedTranscriptionAudio}, + auth::{self, ChatGptAuthContext}, + providers::{InferenceProvider, RewriteRequest, TranscriptionRequest}, + realtime::{self, RealtimeError, RealtimeEvent, RealtimeSession, RealtimeSessionConfig}, +}; +use voxit_audio::AudioChunk; + +const CHATGPT_TRANSCRIBE_ENDPOINT: &str = "https://chatgpt.com/backend-api/transcribe"; +const OPENAI_RESPONSES_ENDPOINT: &str = "https://api.openai.com/v1/responses"; +const MIN_TRANSCRIBE_DURATION_MS: u64 = 1_000; +const VOXIT_USER_AGENT: &str = concat!("voxit/", env!("CARGO_PKG_VERSION")); + +/// ChatGPT OAuth-backed provider for v1 provider abstraction. +#[derive(Clone)] +pub(crate) struct ChatGptProvider { + auth: ChatGptAuthContext, + client: Client, +} +impl ChatGptProvider { + /// Build the provider from stored ChatGPT OAuth credentials. + pub(crate) fn from_stored_oauth() -> Result { + let auth = auth::chatgpt_auth_context()?; + let client = Client::builder() + .timeout(Duration::from_secs(120)) + .build() + .map_err(|err| format!("failed to build ChatGPT HTTP client: {err}"))?; + + Ok(Self { auth, client }) + } + + fn transcribe_chatgpt(&self, request: TranscriptionRequest<'_>) -> Result { + let prepared = audio_payload::prepare_chatgpt_transcription_wav(request.wav)?; + + self.log_prepared_audio(request.model, &prepared); + self.reject_too_short_audio(&prepared)?; + + let body = self.post_transcribe(&prepared.wav, "audio.wav")?; + + extract_json_value(&body, &["/text", "/output_text"]) + .or_else(|| extract_json_output_array_value(&body)) + .ok_or_else(|| "transcription response has no usable text".to_string()) + } + + fn rewrite_chatgpt(&self, request: RewriteRequest<'_>) -> Result { + let prompt = "Rewrite the transcript for punctuation and readability. Keep the meaning, numbers, and names intact."; + let body = serde_json::json!({ + "model": request.model, + "input": format!("Transcript: {}", request.text), + "instructions": prompt, + "temperature": 0.2, + }); + let body = self.post_json(OPENAI_RESPONSES_ENDPOINT, body)?; + + extract_json_value(&body, &["/output_text", "/output/0/content/0/text"]) + .or_else(|| extract_json_output_array_value(&body)) + .or_else(|| extract_json_value(&body, &["/text", "/choices/0/message/content"])) + .ok_or_else(|| "rewrite response has no usable text".to_string()) + } + + fn log_prepared_audio(&self, model: &str, prepared: &PreparedTranscriptionAudio) { + tracing::info!( + provider = self.provider_id(), + model, + input_sample_rate = prepared.input.sample_rate_hz, + input_channels = prepared.input.channels, + input_bits_per_sample = prepared.input.bits_per_sample, + input_duration_ms = prepared.input.duration_ms, + request_sample_rate = prepared.request.sample_rate_hz, + request_channels = prepared.request.channels, + request_bits_per_sample = prepared.request.bits_per_sample, + request_duration_ms = prepared.request.duration_ms, + request_wav_bytes = prepared.wav.len(), + "prepared transcription audio payload" + ); + } + + fn reject_too_short_audio(&self, prepared: &PreparedTranscriptionAudio) -> Result<(), String> { + if prepared.input.duration_ms >= MIN_TRANSCRIBE_DURATION_MS + || prepared.request.duration_ms >= MIN_TRANSCRIBE_DURATION_MS + { + return Ok(()); + } + + tracing::warn!( + provider = self.provider_id(), + input_duration_ms = prepared.input.duration_ms, + request_duration_ms = prepared.request.duration_ms, + min_required_ms = MIN_TRANSCRIBE_DURATION_MS, + request_wav_bytes = prepared.wav.len(), + "skipping transcription: audio clip too short" + ); + + Err("audio is too short for transcription; please record at least 1 second".to_string()) + } + + fn post_transcribe(&self, file_bytes: &[u8], file_name: &str) -> Result { + let file_part = Part::bytes(file_bytes.to_vec()) + .file_name(file_name.to_string()) + .mime_str("audio/wav") + .map_err(|err| format!("invalid file mime: {err}"))?; + let form = Form::new().part("file", file_part); + let response = self + .with_auth(self.client.post(CHATGPT_TRANSCRIBE_ENDPOINT)) + .multipart(form) + .send() + .map_err(|err| format!("transcription request failed: {err}"))?; + + check_status(response, "transcription") + } + + fn post_json(&self, url: &str, body: Value) -> Result { + let response = self + .with_auth(self.client.post(url)) + .json(&body) + .send() + .map_err(|err| format!("rewrite request failed: {err}"))?; + + check_status(response, "rewrite") + } + + fn with_auth(&self, request: RequestBuilder) -> RequestBuilder { + let request = + request.bearer_auth(&self.auth.bearer_token).header("User-Agent", VOXIT_USER_AGENT); + + if let Some(account_id) = self.auth.account_id.as_ref() { + request.header("ChatGPT-Account-Id", account_id) + } else { + request + } + } +} + +impl InferenceProvider for ChatGptProvider { + fn provider_id(&self) -> &'static str { + "chatgpt-oauth" + } + + fn start_realtime_session( + &self, + config: RealtimeSessionConfig, + chunk_rx: Receiver, + event_tx: Sender, + ) -> Result { + realtime::start_realtime_session( + self.auth.bearer_token.clone(), + self.auth.account_id.clone(), + config, + chunk_rx, + event_tx, + ) + } + + fn transcribe(&self, request: TranscriptionRequest<'_>) -> Result { + self.transcribe_chatgpt(request) + } + + fn rewrite(&self, request: RewriteRequest<'_>) -> Result { + self.rewrite_chatgpt(request) + } +} + +fn check_status(response: Response, step: &str) -> Result { + if !response.status().is_success() { + let status = response.status(); + let body = response.text().unwrap_or_else(|_| "".to_string()); + + return Err(format!("{step} failed with status {status}: {body}")); + } + + response.text().map_err(|err| format!("failed to read {step} response body: {err}")) +} + +fn extract_json_value(body: &str, pointers: &[&str]) -> Option { + let value = serde_json::from_str::(body).ok()?; + + pointers + .iter() + .find_map(|pointer| value.pointer(pointer).and_then(Value::as_str).map(str::to_string)) +} + +fn extract_json_output_array_value(body: &str) -> Option { + let value = serde_json::from_str::(body).ok()?; + let outputs = value.get("output")?.as_array()?; + + outputs.iter().find_map(|entry| { + entry.get("content").and_then(Value::as_array)?.iter().find_map(|chunk| { + chunk + .get("text") + .or_else(|| chunk.get("transcript")) + .and_then(Value::as_str) + .map(str::to_string) + }) + }) +} diff --git a/scripts/bundle-and-open-macos.sh b/scripts/bundle-and-open-macos.sh index 514bb37..44ef9dc 100755 --- a/scripts/bundle-and-open-macos.sh +++ b/scripts/bundle-and-open-macos.sh @@ -8,6 +8,16 @@ cargo bundle --release -p voxit app_path="target/release/bundle/osx/Voxit.app" +if [ "${VOXIT_ALLOW_MULTI_INSTANCE:-0}" != "1" ]; then + existing_pids="$(pgrep -x voxit || true)" + if [ -n "$existing_pids" ]; then + echo "Voxit is already running (pids: $(echo "$existing_pids" | tr '\n' ' '))." + echo "Quit it first to avoid launching multiple menu bar instances (this script uses: open -n)." + echo "To override, rerun with: VOXIT_ALLOW_MULTI_INSTANCE=1" + exit 2 + fi +fi + # If the bundle has Gatekeeper attributes, Launch Services may block launch via `open` even for local builds. # Only strip attributes when they are actually present. if xattr -p com.apple.quarantine "$app_path" >/dev/null 2>&1; then @@ -28,7 +38,9 @@ fi # Without a bundle signature, macOS Launch Services may refuse to launch it with: # "code has no resources but signature indicates they must be present". # Re-signing the bundle ad-hoc keeps local dev runs launchable via `open`. -codesign --force --deep --sign - "$app_path" +codesign_identity="${VOXIT_CODESIGN_IDENTITY:--}" +echo "Signing Voxit.app with identity: $codesign_identity" +codesign --force --deep --sign "$codesign_identity" "$app_path" # Some macOS builds attach `com.apple.provenance` during signing; strip it again to avoid # triggering remote Gatekeeper assessment that kills the process shortly after launch. @@ -40,7 +52,6 @@ if xattr -p com.apple.provenance "$app_path" >/dev/null 2>&1; then fi fi -pre_pids="$(pgrep -x voxit || true)" open -n "$app_path" new_pid=""