Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions packages/voxit-core/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1746,8 +1746,7 @@ fn html_escape(raw: &str) -> String {
mod tests {
use std::{
env, fs,
sync::Mutex,
thread,
sync::{Mutex, mpsc},
time::{Duration, Instant},
};

Expand Down Expand Up @@ -1925,16 +1924,22 @@ mod tests {

/// Checks that `auth::run_with_timeout` reports a timeout instead of waiting for a slow
/// operation to finish.
///
/// The operation blocks on a channel that is only released after the helper has returned,
/// so the test does not depend on tight sleep timings.
#[test]
fn timeout_helper_stops_waiting_after_deadline() {
    let (release_tx, release_rx) = mpsc::channel();
    let start = Instant::now();
    let err = auth::run_with_timeout("test-timeout", Duration::from_millis(20), move || {
        // Park until the test releases us; the 5 s cap keeps the operation from blocking
        // forever if the release signal never arrives.
        let _ = release_rx.recv_timeout(Duration::from_secs(5));

        Ok::<(), String>(())
    })
    .expect_err("operation should time out");
    // Capture the elapsed time before releasing the operation so the measurement covers only
    // the helper call itself.
    let elapsed = start.elapsed();
    // Best-effort release: the send result is ignored because the receiver may already be gone.
    let _ = release_tx.send(());

    assert!(err.contains("timed out"));
    // Generous bound: far above the 20 ms deadline, far below the 5 s the operation may block.
    assert!(
        elapsed < Duration::from_secs(2),
        "timeout helper returned after {elapsed:?}, which suggests it waited for the operation"
    );
}

#[test]
Expand Down
225 changes: 225 additions & 0 deletions packages/voxit-core/src/providers/chatgpt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,228 @@ fn extract_json_output_array_value(body: &str) -> Option<String> {
})
})
}

#[cfg(test)]
mod tests {
use std::{env, fs, path::PathBuf, time::Duration};

use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use reqwest::blocking::Client;
use serde::Deserialize;
use serde_json::Value;

use crate::{
auth::ChatGptAuthContext,
providers::{
InferenceProvider, TranscriptionRequest,
chatgpt::{ChatGptProvider, VOXIT_USER_AGENT},
},
};

/// Environment variable that, when set, overrides the auth file location in `codex_auth_path`.
const CODEX_AUTH_PATH_ENV: &str = "VOXIT_CODEX_AUTH_JSON";
/// Environment variable that `require_live_asr_opt_in` checks before the live ASR test runs.
const LIVE_ASR_ENV: &str = "VOXIT_RUN_CHATGPT_ASR_LIVE";
/// Public WAV file that the live test downloads as its audio fixture.
const OSR_SAMPLE_URL: &str =
"https://www.voiptroubleshooter.com/open_speech/american/OSR_us_000_0010_8k.wav";
/// Model name passed in the live test's `TranscriptionRequest`.
const TEST_TRANSCRIPTION_MODEL: &str = "gpt-4o-mini-transcribe";
/// Phrases used by `assert_expected_transcript`: a transcript passes when, after
/// normalization, it contains at least one of them.
// NOTE(review): presumably sentences spoken in the sample at `OSR_SAMPLE_URL` — confirm.
const EXPECTED_PHRASES: &[&str] = &[
"the birch canoe slid on the smooth planks",
"glue the sheet to the dark blue background",
];

/// Top level of the Codex auth JSON, limited to the part `parse_codex_auth_context` reads.
#[derive(Debug, Deserialize)]
struct CodexAuthFile {
// Optional so that `parse_codex_auth_context` can report a missing `tokens` object with its
// own error message.
tokens: Option<CodexAuthTokens>,
}

/// The `tokens` object of the Codex auth JSON, as consumed by `parse_codex_auth_context`.
#[derive(Debug, Deserialize)]
struct CodexAuthTokens {
// Becomes the bearer token; parsing fails when it is absent or blank.
access_token: Option<String>,
// Preferred source of the account id when present and non-blank.
account_id: Option<String>,
// JWT whose payload is searched for an account id when `account_id` is not usable.
id_token: Option<String>,
}

/// Parsing must take the bearer token from `tokens.access_token` and, when no explicit
/// `tokens.account_id` is given, fall back to the account id carried in the id-token claims.
#[test]
fn codex_auth_parser_uses_access_token_and_claim_account_id() -> Result<(), String> {
    // Claims payload carrying the account id under the OpenAI auth namespace.
    let claims = serde_json::json!({
        "https://api.openai.com/auth": {
            "chatgpt_account_id": "account-from-claim"
        }
    });
    let id_token = test_jwt(claims)?;

    // Auth file without `tokens.account_id`, so the claim is the only source of the id.
    let auth_file = serde_json::json!({
        "auth_mode": "chatgpt",
        "tokens": {
            "access_token": "access-token",
            "id_token": id_token
        }
    });
    let parsed = parse_codex_auth_context(&auth_file.to_string())?;

    assert_eq!(parsed.bearer_token, "access-token");
    assert_eq!(parsed.account_id.as_deref(), Some("account-from-claim"));

    Ok(())
}

/// A phrase that differs from an expected one only in capitalization and trailing
/// punctuation must still be accepted by the matcher.
#[test]
fn transcript_matcher_accepts_normalized_osr_phrase() -> Result<(), String> {
    let spoken = "The birch canoe slid on the smooth planks.";

    assert_expected_transcript(spoken)
}

/// End-to-end check: loads Codex credentials, downloads the public sample, transcribes it
/// through the provider, and verifies the transcript contains an expected phrase.
///
/// Ignored by default and additionally gated by `require_live_asr_opt_in`.
#[test]
#[ignore = "requires network access plus ChatGPT OAuth credentials in ~/.codex/auth.json"]
fn live_chatgpt_oauth_asr_transcribes_open_speech_sample() -> Result<(), String> {
    require_live_asr_opt_in()?;

    // Credentials are resolved before any network traffic, so a missing auth file fails fast.
    let provider = live_test_provider(load_codex_auth_context()?)?;
    let audio = download_public_wav(OSR_SAMPLE_URL)?;
    let request = TranscriptionRequest { wav: &audio, model: TEST_TRANSCRIPTION_MODEL };
    let transcript = provider.transcribe(request)?;

    // Printed so a failing run shows what the service actually returned.
    eprintln!("live ChatGPT OAuth ASR transcript: {transcript}");

    assert_expected_transcript(&transcript)
}

/// Gates the live ASR test behind an explicit opt-in environment variable.
///
/// Returns `Ok(())` when `LIVE_ASR_ENV` is set to an enabling value. An unset variable, a
/// non-Unicode value, or an explicit "off" spelling (blank, `0`, or `false` in any case,
/// ignoring surrounding whitespace) yields an error telling the user how to opt in.
fn require_live_asr_opt_in() -> Result<(), String> {
    let opted_in = match env::var(LIVE_ASR_ENV) {
        Ok(value) => {
            let value = value.trim();

            // Merely being set is not enough: `VAR=0` or `VAR=` must not trigger a live run.
            !(value.is_empty() || value == "0" || value.eq_ignore_ascii_case("false"))
        }
        Err(_) => false,
    };

    if opted_in {
        Ok(())
    } else {
        Err(format!("set {LIVE_ASR_ENV}=1 to run the live ChatGPT OAuth ASR test"))
    }
}

/// Wraps the given auth context in a `ChatGptProvider` whose HTTP client uses a 120-second
/// request timeout.
///
/// Returns an error string when the HTTP client cannot be built.
fn live_test_provider(auth: ChatGptAuthContext) -> Result<ChatGptProvider, String> {
    let built = Client::builder().timeout(Duration::from_secs(120)).build();

    match built {
        Ok(client) => Ok(ChatGptProvider { auth, client }),
        Err(err) => Err(format!("failed to build live ASR HTTP client: {err}")),
    }
}

/// Reads the Codex auth file from the path chosen by `codex_auth_path` and parses it into a
/// `ChatGptAuthContext`.
///
/// Returns an error string when the path cannot be determined, the file cannot be read, or
/// its contents fail to parse.
fn load_codex_auth_context() -> Result<ChatGptAuthContext, String> {
    let auth_file = codex_auth_path()?;

    match fs::read_to_string(&auth_file) {
        Ok(contents) => parse_codex_auth_context(&contents),
        Err(err) => Err(format!("failed to read {}: {err}", auth_file.display())),
    }
}

/// Determines where the Codex auth file lives.
///
/// The `CODEX_AUTH_PATH_ENV` variable wins when set; otherwise the path is
/// `$HOME/.codex/auth.json`. Fails when neither variable is available.
fn codex_auth_path() -> Result<PathBuf, String> {
    match env::var_os(CODEX_AUTH_PATH_ENV) {
        // Explicit override: used verbatim.
        Some(explicit) => Ok(PathBuf::from(explicit)),
        None => env::var_os("HOME")
            .map(|home| PathBuf::from(home).join(".codex").join("auth.json"))
            .ok_or_else(|| "HOME is not set; cannot find Codex auth".to_string()),
    }
}

fn parse_codex_auth_context(raw: &str) -> Result<ChatGptAuthContext, String> {
let auth: CodexAuthFile =
serde_json::from_str(raw).map_err(|err| format!("invalid Codex auth JSON: {err}"))?;
let tokens =
auth.tokens.ok_or_else(|| "Codex auth JSON has no tokens object".to_string())?;
let bearer_token = non_empty(tokens.access_token)
.ok_or_else(|| "Codex auth JSON tokens.access_token is missing".to_string())?;
let account_id = non_empty(tokens.account_id)
.or_else(|| tokens.id_token.as_deref().and_then(id_account_id));

Ok(ChatGptAuthContext { bearer_token, account_id })
}

/// Keeps the string only when it contains something other than whitespace.
///
/// A surviving value is returned unchanged (it is not trimmed); `None` and blank strings
/// both come back as `None`.
fn non_empty(value: Option<String>) -> Option<String> {
    value.filter(|text| !text.trim().is_empty())
}

/// Pulls the account id out of an id token's claims.
///
/// Looks up `chatgpt_account_id` inside the `https://api.openai.com/auth` claim and returns
/// it when it is a string. Any missing step, or an undecodable token, yields `None`.
fn id_account_id(id_token: &str) -> Option<String> {
    let claims = decode_jwt_payload(id_token)?;
    let auth_claim = claims.get("https://api.openai.com/auth")?;
    let account = auth_claim.get("chatgpt_account_id")?.as_str()?;

    Some(account.to_string())
}

/// Decodes the payload (second dot-separated segment) of a JWT into a JSON value.
///
/// The signature is not inspected or verified. Returns `None` when the token has fewer than
/// two segments, the payload is not valid unpadded URL-safe base64, or the decoded bytes are
/// not JSON.
fn decode_jwt_payload(jwt: &str) -> Option<Value> {
    // Segment 0 is the header; segment 1 is the payload.
    let encoded_payload = jwt.split('.').nth(1)?;

    match URL_SAFE_NO_PAD.decode(encoded_payload.as_bytes()) {
        Ok(decoded) => serde_json::from_slice(&decoded).ok(),
        Err(_) => None,
    }
}

/// Downloads the audio fixture at `url` and returns its bytes.
///
/// Uses a client with a 120-second timeout and sends `VOXIT_USER_AGENT` as the user agent.
/// Fails on client construction errors, transport errors, a non-success HTTP status, or a
/// body that does not begin with the `RIFF` magic bytes.
fn download_public_wav(url: &str) -> Result<Vec<u8>, String> {
    let http = Client::builder()
        .timeout(Duration::from_secs(120))
        .build()
        .map_err(|err| format!("failed to build audio download client: {err}"))?;

    let request = http.get(url).header("User-Agent", VOXIT_USER_AGENT);
    let response =
        request.send().map_err(|err| format!("failed to download test wav: {err}"))?;

    let status = response.status();
    if !status.is_success() {
        return Err(format!("test wav download failed with status {}", status));
    }

    let body =
        response.bytes().map_err(|err| format!("failed to read test wav bytes: {err}"))?;

    // Cheap sanity check so an HTML error page is not handed to the transcriber as audio.
    if body.starts_with(b"RIFF") {
        Ok(body.to_vec())
    } else {
        Err("downloaded test fixture is not a RIFF WAV file".to_string())
    }
}

/// Succeeds when the transcript, once normalized, contains at least one normalized entry of
/// `EXPECTED_PHRASES`.
///
/// On failure the error message includes the normalized transcript.
fn assert_expected_transcript(transcript: &str) -> Result<(), String> {
    let transcript = normalize(transcript);

    // Both sides go through `normalize`, so case and punctuation cannot cause a mismatch.
    for phrase in EXPECTED_PHRASES {
        if transcript.contains(normalize(phrase).as_str()) {
            return Ok(());
        }
    }

    Err(format!(
        "expected ASR transcript to contain one Open Speech Repository phrase; got: {transcript}"
    ))
}

/// Reduces text to lowercase ASCII alphanumeric words separated by single spaces.
///
/// Every character that is not an ASCII letter or digit (punctuation, whitespace, non-ASCII
/// characters) acts as a word separator. Runs of separators collapse to one space and the
/// result has no leading or trailing space.
fn normalize(input: &str) -> String {
    input
        .split(|character: char| !character.is_ascii_alphanumeric())
        // Adjacent separators produce empty pieces; dropping them collapses the run.
        .filter(|word| !word.is_empty())
        .map(str::to_ascii_lowercase)
        .collect::<Vec<_>>()
        .join(" ")
}

/// Builds an unsigned JWT-shaped string (`header.payload.` with an empty signature segment)
/// around the given payload, for feeding the parser in tests.
///
/// Both segments are JSON encoded as unpadded URL-safe base64. Returns an error string when
/// either segment fails to serialize.
fn test_jwt(payload: Value) -> Result<String, String> {
    let header_json = serde_json::json!({"alg": "none", "typ": "JWT"});
    let header_segment = serde_json::to_vec(&header_json)
        .map(|bytes| URL_SAFE_NO_PAD.encode(bytes))
        .map_err(|err| format!("failed to serialize test JWT header: {err}"))?;
    let payload_segment = serde_json::to_vec(&payload)
        .map(|bytes| URL_SAFE_NO_PAD.encode(bytes))
        .map_err(|err| format!("failed to serialize test JWT payload: {err}"))?;

    // The trailing dot leaves the signature segment empty.
    Ok(format!("{}.{}.", header_segment, payload_segment))
}
}
Loading