Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions rust/crates/api/src/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,20 @@ pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
None
}




#[must_use]
pub fn strip_provider_prefix(canonical_model: &str) -> String {
if let Some(pos) = canonical_model.find('/') {
canonical_model[pos + 1..].to_string()
} else {
canonical_model.to_string()
}
}



#[must_use]
pub fn provider_diagnostics_for_model(model: &str) -> ProviderDiagnostics {
let resolved_model = resolve_model_alias(model);
Expand Down
127 changes: 98 additions & 29 deletions rust/crates/api/src/providers/openai_compat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ use crate::types::{
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
};

use super::{preflight_message_request, Provider, ProviderFuture};
use super::{preflight_message_request, Provider, ProviderFuture, resolve_model_alias, strip_provider_prefix};


pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
Expand Down Expand Up @@ -212,17 +213,76 @@ impl OpenAiCompatClient {
}

pub async fn send_message(
&self,
request: &MessageRequest,
) -> Result<MessageResponse, ApiError> {
let request = MessageRequest {
stream: false,
..request.clone()
};
preflight_message_request(&request)?;
let response = self.send_with_retry(&request).await?;
let request_id = request_id_from_headers(response.headers());
let body = response.text().await.map_err(ApiError::from)?;
&self,
request: &MessageRequest,
) -> Result<MessageResponse, ApiError> {
// 1. Keep track of what Claw originally asked for
let original_model = request.model.clone();
let canonical = resolve_model_alias(&request.model);

// 2. Clean the model string (e.g., "openai/deepseek-v4-flash" -> "deepseek-v4-flash")
let downstream_model = strip_provider_prefix(&canonical);

let mut request = MessageRequest {
stream: false,
..request.clone()
};
request.model = downstream_model; // Use the clean name for the API payload

preflight_message_request(&request)?;
let response = self.send_with_retry(&request).await?;
let request_id = request_id_from_headers(response.headers());
let body = response.text().await.map_err(ApiError::from)?;

// Some backends return {"error":{"message":"...","type":"...","code":...}}
// instead of a valid completion object. Check for this before attempting
// full deserialization so the user sees the actual error, not a cryptic.
if let Ok(raw) = serde_json::from_str::<serde_json::Value>(&body) {
if let Some(err_obj) = raw.get("error") {
let msg = err_obj
.get("message")
.and_then(|m| m.as_str())
.unwrap_or("provider returned an error")
.to_string();
let code = err_obj
.get("code")
.and_then(serde_json::Value::as_u64)
.map(|c| c as u16);
return Err(ApiError::Api {
status: reqwest::StatusCode::from_u16(code.unwrap_or(400))
.unwrap_or(reqwest::StatusCode::BAD_REQUEST),
error_type: err_obj
.get("type")
.and_then(|t| t.as_str())
.map(str::to_owned),
message: Some(msg),
request_id,
body,
retryable: false,
suggested_action: suggested_action_for_status(
reqwest::StatusCode::from_u16(code.unwrap_or(400))
.unwrap_or(reqwest::StatusCode::BAD_REQUEST),
),
retry_after: None,
});
}
}

// Pass original_model to the deserializer error context so debugging logs are accurate
let payload = serde_json::from_str::<ChatCompletionResponse>(&body).map_err(|error| {
ApiError::json_deserialize(self.config.provider_name, &original_model, &body, error)
})?;

let mut normalized = normalize_response(&request.model, payload)?;
if normalized.request_id.is_none() {
normalized.request_id = request_id;
}

// 3. CRITICAL: Put the original model string back so Claw's internal routing stays happy
normalized.model = original_model;

Ok(normalized)
}
// Some backends return {"error":{"message":"...","type":"...","code":...}}
// instead of a valid completion object. Check for this before attempting
// full deserialization so the user sees the actual error, not a cryptic
Expand Down Expand Up @@ -267,23 +327,32 @@ impl OpenAiCompatClient {
Ok(normalized)
}

pub async fn stream_message(
&self,
request: &MessageRequest,
) -> Result<MessageStream, ApiError> {
preflight_message_request(request)?;
let response = self
.send_with_retry(&request.clone().with_streaming())
.await?;
Ok(MessageStream {
request_id: request_id_from_headers(response.headers()),
response,
parser: OpenAiSseParser::with_context(self.config.provider_name, request.model.clone()),
pending: VecDeque::new(),
done: false,
state: StreamState::new(request.model.clone()),
})
}
pub async fn stream_message(
&self,
request: &MessageRequest,
) -> Result<MessageStream, ApiError> {
// 1. Keep track of the original model name
let original_model = request.model.clone();
let canonical = resolve_model_alias(&request.model);

// 2. Clean it up for DeepSeek
let downstream_model = strip_provider_prefix(&canonical);

let mut streaming_request = request.clone().with_streaming();
streaming_request.model = downstream_model;

preflight_message_request(&streaming_request)?;
let response = self.send_with_retry(&streaming_request).await?;

Ok(MessageStream {
request_id: request_id_from_headers(response.headers()),
response,
parser: OpenAiSseParser::with_context(self.config.provider_name, original_model.clone()),
pending: VecDeque::new(),
done: false,
state: StreamState::new(original_model), // 3. Use the original name here
})
}

async fn send_with_retry(
&self,
Expand Down