Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions rust/crates/api/src/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,6 @@ pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
None
}




#[must_use]
pub fn strip_provider_prefix(canonical_model: &str) -> String {
if let Some(pos) = canonical_model.find('/') {
Expand All @@ -308,8 +305,6 @@ pub fn strip_provider_prefix(canonical_model: &str) -> String {
}
}



#[must_use]
pub fn provider_diagnostics_for_model(model: &str) -> ProviderDiagnostics {
let resolved_model = resolve_model_alias(model);
Expand Down
143 changes: 42 additions & 101 deletions rust/crates/api/src/providers/openai_compat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ use crate::types::{
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
};

use super::{preflight_message_request, Provider, ProviderFuture, resolve_model_alias, strip_provider_prefix};

use super::{preflight_message_request, resolve_model_alias, Provider, ProviderFuture};

pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
Expand Down Expand Up @@ -213,80 +212,22 @@ impl OpenAiCompatClient {
}

pub async fn send_message(
&self,
request: &MessageRequest,
) -> Result<MessageResponse, ApiError> {
// 1. Keep track of what Claw originally asked for
let original_model = request.model.clone();
let canonical = resolve_model_alias(&request.model);

// 2. Clean the model string (e.g., "openai/deepseek-v4-flash" -> "deepseek-v4-flash")
let downstream_model = strip_provider_prefix(&canonical);

let mut request = MessageRequest {
stream: false,
..request.clone()
};
request.model = downstream_model; // Use the clean name for the API payload

preflight_message_request(&request)?;
let response = self.send_with_retry(&request).await?;
let request_id = request_id_from_headers(response.headers());
let body = response.text().await.map_err(ApiError::from)?;

// Some backends return {"error":{"message":"...","type":"...","code":...}}
// instead of a valid completion object. Check for this before attempting
// full deserialization so the user sees the actual error, not a cryptic.
if let Ok(raw) = serde_json::from_str::<serde_json::Value>(&body) {
if let Some(err_obj) = raw.get("error") {
let msg = err_obj
.get("message")
.and_then(|m| m.as_str())
.unwrap_or("provider returned an error")
.to_string();
let code = err_obj
.get("code")
.and_then(serde_json::Value::as_u64)
.map(|c| c as u16);
return Err(ApiError::Api {
status: reqwest::StatusCode::from_u16(code.unwrap_or(400))
.unwrap_or(reqwest::StatusCode::BAD_REQUEST),
error_type: err_obj
.get("type")
.and_then(|t| t.as_str())
.map(str::to_owned),
message: Some(msg),
request_id,
body,
retryable: false,
suggested_action: suggested_action_for_status(
reqwest::StatusCode::from_u16(code.unwrap_or(400))
.unwrap_or(reqwest::StatusCode::BAD_REQUEST),
),
retry_after: None,
});
}
}

// Pass original_model to the deserializer error context so debugging logs are accurate
let payload = serde_json::from_str::<ChatCompletionResponse>(&body).map_err(|error| {
ApiError::json_deserialize(self.config.provider_name, &original_model, &body, error)
})?;

let mut normalized = normalize_response(&request.model, payload)?;
if normalized.request_id.is_none() {
normalized.request_id = request_id;
}
&self,
request: &MessageRequest,
) -> Result<MessageResponse, ApiError> {
let original_model = request.model.clone();
let canonical = resolve_model_alias(&request.model);

// 3. CRITICAL: Put the original model string back so Claw's internal routing stays happy
normalized.model = original_model;
let mut request = MessageRequest {
stream: false,
..request.clone()
};
request.model = canonical;

Ok(normalized)
}
// Some backends return {"error":{"message":"...","type":"...","code":...}}
// instead of a valid completion object. Check for this before attempting
// full deserialization so the user sees the actual error, not a cryptic
// "missing field 'id'" parse failure.
preflight_message_request(&request)?;
let response = self.send_with_retry(&request).await?;
let request_id = request_id_from_headers(response.headers());
let body = response.text().await.map_err(ApiError::from)?;
if let Ok(raw) = serde_json::from_str::<serde_json::Value>(&body) {
if let Some(err_obj) = raw.get("error") {
let msg = err_obj
Expand Down Expand Up @@ -318,41 +259,41 @@ impl OpenAiCompatClient {
}
}
let payload = serde_json::from_str::<ChatCompletionResponse>(&body).map_err(|error| {
ApiError::json_deserialize(self.config.provider_name, &request.model, &body, error)
ApiError::json_deserialize(self.config.provider_name, &original_model, &body, error)
})?;
let mut normalized = normalize_response(&request.model, payload)?;
if normalized.request_id.is_none() {
normalized.request_id = request_id;
}
normalized.model = original_model;
Ok(normalized)
}

pub async fn stream_message(
&self,
request: &MessageRequest,
) -> Result<MessageStream, ApiError> {
// 1. Keep track of the original model name
let original_model = request.model.clone();
let canonical = resolve_model_alias(&request.model);

// 2. Clean it up for DeepSeek
let downstream_model = strip_provider_prefix(&canonical);

let mut streaming_request = request.clone().with_streaming();
streaming_request.model = downstream_model;

preflight_message_request(&streaming_request)?;
let response = self.send_with_retry(&streaming_request).await?;

Ok(MessageStream {
request_id: request_id_from_headers(response.headers()),
response,
parser: OpenAiSseParser::with_context(self.config.provider_name, original_model.clone()),
pending: VecDeque::new(),
done: false,
state: StreamState::new(original_model), // 3. Use the original name here
})
}
pub async fn stream_message(
&self,
request: &MessageRequest,
) -> Result<MessageStream, ApiError> {
let original_model = request.model.clone();
let canonical = resolve_model_alias(&request.model);

let mut streaming_request = request.clone().with_streaming();
streaming_request.model = canonical;

preflight_message_request(&streaming_request)?;
let response = self.send_with_retry(&streaming_request).await?;

Ok(MessageStream {
request_id: request_id_from_headers(response.headers()),
response,
parser: OpenAiSseParser::with_context(
self.config.provider_name,
original_model.clone(),
),
pending: VecDeque::new(),
done: false,
state: StreamState::new(original_model),
})
}

async fn send_with_retry(
&self,
Expand Down
5 changes: 4 additions & 1 deletion rust/crates/api/tests/openai_compat_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -548,12 +548,13 @@ async fn openai_compatible_client_honors_http_proxy_for_requests() {
.with_base_url("http://origin.invalid/v1");
let response = client
.send_message(&MessageRequest {
model: "gpt-4o".to_string(),
model: "openai/gpt-4.1-mini".to_string(),
..sample_request(false)
})
.await
.expect("proxy should return the OpenAI-compatible response");

assert_eq!(response.model, "openai/gpt-4.1-mini");
assert_eq!(response.total_tokens(), 7);
let captured = state.lock().await;
let request = captured.first().expect("proxy should capture request");
Expand All @@ -562,6 +563,8 @@ async fn openai_compatible_client_honors_http_proxy_for_requests() {
request.headers.get("authorization").map(String::as_str),
Some("Bearer openai-test-key")
);
let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body");
assert_eq!(body["model"], json!("openai/gpt-4.1-mini"));
}

#[allow(clippy::await_holding_lock)]
Expand Down
25 changes: 25 additions & 0 deletions rust/crates/runtime/src/session_control.rs
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,28 @@ mod tests {

static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0);

struct EnvVarGuard {
key: &'static str,
previous: Option<std::ffi::OsString>,
}

impl EnvVarGuard {
fn set(key: &'static str, value: &Path) -> Self {
let previous = std::env::var_os(key);
std::env::set_var(key, value);
Self { key, previous }
}
}

impl Drop for EnvVarGuard {
fn drop(&mut self) {
match &self.previous {
Some(value) => std::env::set_var(self.key, value),
None => std::env::remove_var(self.key),
}
}
}

fn temp_dir() -> PathBuf {
let nanos = SystemTime::now()
.duration_since(UNIX_EPOCH)
Expand Down Expand Up @@ -1290,8 +1312,11 @@ mod tests {
#[test]
fn latest_session_returns_all_empty_error_when_sessions_exist_but_have_no_messages() {
// given — create sessions with 0 messages (empty)
let _env_guard = crate::test_env_lock();
let base = temp_dir();
fs::create_dir_all(&base).expect("base dir should exist");
let isolated_config_home = base.join("config-home");
let _claw_config_home = EnvVarGuard::set("CLAW_CONFIG_HOME", &isolated_config_home);
let store = SessionStore::from_cwd(&base).expect("store should build");

let empty_handle = store.create_handle("empty-session");
Expand Down
15 changes: 6 additions & 9 deletions rust/crates/runtime/src/worker_boot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1644,16 +1644,13 @@ mod tests {

let tmp = tempfile::tempdir().expect("tempdir");
let worktree = tmp.path().join("worktree");
let git_dir = tmp.path().join("external-gitdir");
fs::create_dir_all(&worktree).expect("worktree dir");
fs::create_dir_all(git_dir.join("objects")).expect("objects dir");
fs::create_dir_all(git_dir.join("refs/heads")).expect("refs dir");
fs::write(git_dir.join("HEAD"), "ref: refs/heads/main\n").expect("HEAD");
fs::write(
worktree.join(".git"),
format!("gitdir: {}\n", git_dir.display()),
)
.expect(".git file");
Command::new("git")
.arg("init")
.current_dir(&worktree)
.output()
.expect("git init should run");
let git_dir = worktree.join(".git");

let original_permissions = fs::metadata(&git_dir)
.expect("gitdir metadata")
Expand Down
60 changes: 56 additions & 4 deletions rust/crates/rusty-claude-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13737,8 +13737,15 @@ fn push_output_block(
};
*pending_tool = Some((id, name, initial_input));
}
OutputContentBlock::Thinking { thinking, .. } => {
OutputContentBlock::Thinking {
thinking,
signature,
} => {
render_thinking_block_summary(out, Some(thinking.chars().count()), false)?;
events.push(AssistantEvent::Thinking {
thinking,
signature,
});
*block_has_thinking_summary = true;
}
OutputContentBlock::RedactedThinking { .. } => {
Expand Down Expand Up @@ -19073,6 +19080,13 @@ UU conflicted.rs",

assert!(matches!(
&events[0],
AssistantEvent::Thinking {
thinking,
signature
} if thinking == "step 1" && signature.as_deref() == Some("sig_123")
));
assert!(matches!(
&events[1],
AssistantEvent::TextDelta(text) if text == "Final answer"
));
let rendered = String::from_utf8(out).expect("utf8");
Expand Down Expand Up @@ -19649,6 +19663,41 @@ mod dump_manifests_tests {

#[cfg(test)]
mod alias_resolution_tests {
fn ollama_env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();
LOCK.get_or_init(|| std::sync::Mutex::new(()))
.lock()
.expect("ollama env lock poisoned")
}

struct EnvVarGuard {
key: &'static str,
previous: Option<String>,
}

impl EnvVarGuard {
fn unset(key: &'static str) -> Self {
let previous = std::env::var(key).ok();
std::env::remove_var(key);
Self { key, previous }
}

fn set(key: &'static str, value: &str) -> Self {
let previous = std::env::var(key).ok();
std::env::set_var(key, value);
Self { key, previous }
}
}

impl Drop for EnvVarGuard {
fn drop(&mut self) {
match &self.previous {
Some(value) => std::env::set_var(self.key, value),
None => std::env::remove_var(self.key),
}
}
}

use super::{resolve_model_alias_with_config, validate_model_syntax};

#[test]
Expand All @@ -19670,6 +19719,8 @@ mod alias_resolution_tests {

#[test]
fn test_alias_resolution_syntax_validation() {
let _guard = ollama_env_lock();
let _env = EnvVarGuard::unset("OLLAMA_HOST");
// Resolved aliases should pass syntax validation
let resolved = resolve_model_alias_with_config("opus");
assert!(validate_model_syntax(&resolved).is_ok());
Expand All @@ -19680,6 +19731,8 @@ mod alias_resolution_tests {

#[test]
fn test_unknown_alias_fails_validation() {
let _guard = ollama_env_lock();
let _env = EnvVarGuard::unset("OLLAMA_HOST");
// Unknown aliases resolve to themselves
let resolved = resolve_model_alias_with_config("unknown-alias");
assert_eq!(resolved, "unknown-alias");
Expand All @@ -19699,14 +19752,13 @@ mod alias_resolution_tests {
}
#[test]
fn test_ollama_host_bypasses_model_validation() {
// Safety: test sets and clears env var within the test.
std::env::set_var("OLLAMA_HOST", "http://127.0.0.1:11434");
let _guard = ollama_env_lock();
let _env = EnvVarGuard::set("OLLAMA_HOST", "http://127.0.0.1:11434");
// Ollama model names with colons pass
assert!(validate_model_syntax("qwen3:8b").is_ok());
assert!(validate_model_syntax("gemma4:e2b").is_ok());
assert!(validate_model_syntax("qwen3.6:27b-nvfp4").is_ok());
// Empty model still rejected
assert!(validate_model_syntax("").is_err());
std::env::remove_var("OLLAMA_HOST");
}
}
Loading