Skip to content

Commit 7b359c9

Browse files
authored
Call models endpoint in models manager (#7616)
- Introduce `with_remote_overrides` and update `refresh_available_models` - Put `auth_manager` instead of `auth_mode` on `models_manager` - Remove `ShellType` and `ReasoningLevel` to use already existing structs
1 parent 6736d18 commit 7b359c9

File tree

13 files changed

+329
-88
lines changed

13 files changed

+329
-88
lines changed

codex-rs/codex-api/tests/models_integration.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ use codex_api::provider::RetryConfig;
55
use codex_api::provider::WireApi;
66
use codex_client::ReqwestTransport;
77
use codex_protocol::openai_models::ClientVersion;
8+
use codex_protocol::openai_models::ConfigShellToolType;
89
use codex_protocol::openai_models::ModelInfo;
910
use codex_protocol::openai_models::ModelVisibility;
1011
use codex_protocol::openai_models::ModelsResponse;
11-
use codex_protocol::openai_models::ReasoningLevel;
12-
use codex_protocol::openai_models::ShellType;
12+
use codex_protocol::openai_models::ReasoningEffort;
1313
use http::HeaderMap;
1414
use http::Method;
1515
use wiremock::Mock;
@@ -55,13 +55,13 @@ async fn models_client_hits_models_endpoint() {
5555
slug: "gpt-test".to_string(),
5656
display_name: "gpt-test".to_string(),
5757
description: Some("desc".to_string()),
58-
default_reasoning_level: ReasoningLevel::Medium,
58+
default_reasoning_level: ReasoningEffort::Medium,
5959
supported_reasoning_levels: vec![
60-
ReasoningLevel::Low,
61-
ReasoningLevel::Medium,
62-
ReasoningLevel::High,
60+
ReasoningEffort::Low,
61+
ReasoningEffort::Medium,
62+
ReasoningEffort::High,
6363
],
64-
shell_type: ShellType::ShellCommand,
64+
shell_type: ConfigShellToolType::ShellCommand,
6565
visibility: ModelVisibility::List,
6666
minimal_client_version: ClientVersion(0, 1, 0),
6767
supported_in_api: true,

codex-rs/core/src/codex.rs

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2475,6 +2475,7 @@ pub(crate) use tests::make_session_and_context_with_rx;
24752475
#[cfg(test)]
24762476
mod tests {
24772477
use super::*;
2478+
use crate::CodexAuth;
24782479
use crate::config::ConfigOverrides;
24792480
use crate::config::ConfigToml;
24802481
use crate::exec::ExecToolCallOutput;
@@ -2765,12 +2766,9 @@ mod tests {
27652766
.expect("load default test config");
27662767
let config = Arc::new(config);
27672768
let conversation_id = ConversationId::default();
2768-
let auth_manager = AuthManager::shared(
2769-
config.cwd.clone(),
2770-
false,
2771-
config.cli_auth_credentials_store_mode,
2772-
);
2773-
let models_manager = Arc::new(ModelsManager::new(auth_manager.get_auth_mode()));
2769+
let auth_manager =
2770+
AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key"));
2771+
let models_manager = Arc::new(ModelsManager::new(auth_manager.clone()));
27742772
let otel_event_manager =
27752773
otel_event_manager(conversation_id, config.as_ref(), &models_manager);
27762774

@@ -2801,7 +2799,7 @@ mod tests {
28012799
rollout: Mutex::new(None),
28022800
user_shell: default_user_shell(),
28032801
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
2804-
auth_manager: Arc::clone(&auth_manager),
2802+
auth_manager: auth_manager.clone(),
28052803
otel_event_manager: otel_event_manager.clone(),
28062804
models_manager: models_manager.clone(),
28072805
tool_approvals: Mutex::new(ApprovalStore::default()),
@@ -2847,12 +2845,9 @@ mod tests {
28472845
.expect("load default test config");
28482846
let config = Arc::new(config);
28492847
let conversation_id = ConversationId::default();
2850-
let auth_manager = AuthManager::shared(
2851-
config.cwd.clone(),
2852-
false,
2853-
config.cli_auth_credentials_store_mode,
2854-
);
2855-
let models_manager = Arc::new(ModelsManager::new(auth_manager.get_auth_mode()));
2848+
let auth_manager =
2849+
AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key"));
2850+
let models_manager = Arc::new(ModelsManager::new(auth_manager.clone()));
28562851
let otel_event_manager =
28572852
otel_event_manager(conversation_id, config.as_ref(), &models_manager);
28582853

codex-rs/core/src/conversation_manager.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ impl ConversationManager {
4747
conversations: Arc::new(RwLock::new(HashMap::new())),
4848
auth_manager: auth_manager.clone(),
4949
session_source,
50-
models_manager: Arc::new(ModelsManager::new(auth_manager.get_auth_mode())),
50+
models_manager: Arc::new(ModelsManager::new(auth_manager)),
5151
}
5252
}
5353

codex-rs/core/src/openai_models/model_family.rs

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
use codex_protocol::config_types::Verbosity;
2+
use codex_protocol::openai_models::ModelInfo;
23
use codex_protocol::openai_models::ReasoningEffort;
34

45
use crate::config::Config;
56
use crate::config::types::ReasoningSummaryFormat;
67
use crate::tools::handlers::apply_patch::ApplyPatchToolType;
7-
use crate::tools::spec::ConfigShellToolType;
88
use crate::truncate::TruncationPolicy;
9+
use codex_protocol::openai_models::ConfigShellToolType;
910

1011
/// The `instructions` field in the payload sent to a model should always start
1112
/// with this content.
@@ -83,6 +84,15 @@ impl ModelFamily {
8384
}
8485
self
8586
}
87+
pub fn with_remote_overrides(mut self, remote_models: Vec<ModelInfo>) -> Self {
88+
for model in remote_models {
89+
if model.slug == self.slug {
90+
self.default_reasoning_effort = Some(model.default_reasoning_level);
91+
self.shell_type = model.shell_type;
92+
}
93+
}
94+
self
95+
}
8696
}
8797

8898
macro_rules! model_family {
@@ -275,3 +285,71 @@ fn derive_default_model_family(model: &str) -> ModelFamily {
275285
truncation_policy: TruncationPolicy::Bytes(10_000),
276286
}
277287
}
288+
289+
#[cfg(test)]
290+
mod tests {
291+
use super::*;
292+
use codex_protocol::openai_models::ClientVersion;
293+
use codex_protocol::openai_models::ModelVisibility;
294+
295+
fn remote(slug: &str, effort: ReasoningEffort, shell: ConfigShellToolType) -> ModelInfo {
296+
ModelInfo {
297+
slug: slug.to_string(),
298+
display_name: slug.to_string(),
299+
description: Some(format!("{slug} desc")),
300+
default_reasoning_level: effort,
301+
supported_reasoning_levels: vec![effort],
302+
shell_type: shell,
303+
visibility: ModelVisibility::List,
304+
minimal_client_version: ClientVersion(0, 1, 0),
305+
supported_in_api: true,
306+
priority: 1,
307+
}
308+
}
309+
310+
#[test]
311+
fn remote_overrides_apply_when_slug_matches() {
312+
let family = model_family!("gpt-4o-mini", "gpt-4o-mini");
313+
assert_ne!(family.default_reasoning_effort, Some(ReasoningEffort::High));
314+
315+
let updated = family.with_remote_overrides(vec![
316+
remote(
317+
"gpt-4o-mini",
318+
ReasoningEffort::High,
319+
ConfigShellToolType::ShellCommand,
320+
),
321+
remote(
322+
"other-model",
323+
ReasoningEffort::Low,
324+
ConfigShellToolType::UnifiedExec,
325+
),
326+
]);
327+
328+
assert_eq!(
329+
updated.default_reasoning_effort,
330+
Some(ReasoningEffort::High)
331+
);
332+
assert_eq!(updated.shell_type, ConfigShellToolType::ShellCommand);
333+
}
334+
335+
#[test]
336+
fn remote_overrides_skip_non_matching_models() {
337+
let family = model_family!(
338+
"codex-mini-latest",
339+
"codex-mini-latest",
340+
shell_type: ConfigShellToolType::Local
341+
);
342+
343+
let updated = family.clone().with_remote_overrides(vec![remote(
344+
"other",
345+
ReasoningEffort::High,
346+
ConfigShellToolType::ShellCommand,
347+
)]);
348+
349+
assert_eq!(
350+
updated.default_reasoning_effort,
351+
family.default_reasoning_effort
352+
);
353+
assert_eq!(updated.shell_type, family.shell_type);
354+
}
355+
}
Lines changed: 146 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,172 @@
1-
use codex_app_server_protocol::AuthMode;
1+
use codex_api::ModelsClient;
2+
use codex_api::ReqwestTransport;
3+
use codex_protocol::openai_models::ModelInfo;
24
use codex_protocol::openai_models::ModelPreset;
5+
use http::HeaderMap;
6+
use std::sync::Arc;
37
use tokio::sync::RwLock;
48

9+
use crate::api_bridge::auth_provider_from_auth;
10+
use crate::api_bridge::map_api_error;
11+
use crate::auth::AuthManager;
512
use crate::config::Config;
13+
use crate::default_client::build_reqwest_client;
14+
use crate::error::Result as CoreResult;
15+
use crate::model_provider_info::ModelProviderInfo;
616
use crate::openai_models::model_family::ModelFamily;
717
use crate::openai_models::model_family::find_family_for_model;
818
use crate::openai_models::model_presets::builtin_model_presets;
919

1020
#[derive(Debug)]
1121
pub struct ModelsManager {
22+
// todo(aibrahim) merge available_models and model family creation into one struct
1223
pub available_models: RwLock<Vec<ModelPreset>>,
24+
pub remote_models: RwLock<Vec<ModelInfo>>,
1325
pub etag: String,
14-
pub auth_mode: Option<AuthMode>,
26+
pub auth_manager: Arc<AuthManager>,
1527
}
1628

1729
impl ModelsManager {
18-
pub fn new(auth_mode: Option<AuthMode>) -> Self {
30+
pub fn new(auth_manager: Arc<AuthManager>) -> Self {
1931
Self {
20-
available_models: RwLock::new(builtin_model_presets(auth_mode)),
32+
available_models: RwLock::new(builtin_model_presets(auth_manager.get_auth_mode())),
33+
remote_models: RwLock::new(Vec::new()),
2134
etag: String::new(),
22-
auth_mode,
35+
auth_manager,
2336
}
2437
}
2538

26-
pub async fn refresh_available_models(&self) {
27-
let models = builtin_model_presets(self.auth_mode);
28-
*self.available_models.write().await = models;
39+
// do not use this function yet. It's work in progress.
40+
pub async fn refresh_available_models(
41+
&self,
42+
provider: &ModelProviderInfo,
43+
) -> CoreResult<Vec<ModelInfo>> {
44+
let auth = self.auth_manager.auth();
45+
let api_provider = provider.to_api_provider(auth.as_ref().map(|auth| auth.mode))?;
46+
let api_auth = auth_provider_from_auth(auth.clone(), provider).await?;
47+
let transport = ReqwestTransport::new(build_reqwest_client());
48+
let client = ModelsClient::new(transport, api_provider, api_auth);
49+
50+
let response = client
51+
.list_models(env!("CARGO_PKG_VERSION"), HeaderMap::new())
52+
.await
53+
.map_err(map_api_error)?;
54+
55+
let models = response.models;
56+
*self.remote_models.write().await = models.clone();
57+
{
58+
let mut available_models_guard = self.available_models.write().await;
59+
*available_models_guard = self.build_available_models().await;
60+
}
61+
Ok(models)
2962
}
3063

3164
pub fn construct_model_family(&self, model: &str, config: &Config) -> ModelFamily {
3265
find_family_for_model(model).with_config_overrides(config)
3366
}
67+
68+
async fn build_available_models(&self) -> Vec<ModelPreset> {
69+
let mut available_models = self.remote_models.read().await.clone();
70+
available_models.sort_by(|a, b| b.priority.cmp(&a.priority));
71+
let mut model_presets: Vec<ModelPreset> =
72+
available_models.into_iter().map(Into::into).collect();
73+
if let Some(default) = model_presets.first_mut() {
74+
default.is_default = true;
75+
}
76+
model_presets
77+
}
78+
}
79+
80+
#[cfg(test)]
81+
mod tests {
82+
use super::*;
83+
use crate::CodexAuth;
84+
use crate::model_provider_info::WireApi;
85+
use codex_protocol::openai_models::ModelsResponse;
86+
use serde_json::json;
87+
use wiremock::Mock;
88+
use wiremock::MockServer;
89+
use wiremock::ResponseTemplate;
90+
use wiremock::matchers::method;
91+
use wiremock::matchers::path;
92+
93+
fn remote_model(slug: &str, display: &str, priority: i32) -> ModelInfo {
94+
serde_json::from_value(json!({
95+
"slug": slug,
96+
"display_name": display,
97+
"description": format!("{display} desc"),
98+
"default_reasoning_level": "medium",
99+
"supported_reasoning_levels": ["low", "medium"],
100+
"shell_type": "shell_command",
101+
"visibility": "list",
102+
"minimal_client_version": [0, 1, 0],
103+
"supported_in_api": true,
104+
"priority": priority
105+
}))
106+
.expect("valid model")
107+
}
108+
109+
fn provider_for(base_url: String) -> ModelProviderInfo {
110+
ModelProviderInfo {
111+
name: "mock".into(),
112+
base_url: Some(base_url),
113+
env_key: None,
114+
env_key_instructions: None,
115+
experimental_bearer_token: None,
116+
wire_api: WireApi::Responses,
117+
query_params: None,
118+
http_headers: None,
119+
env_http_headers: None,
120+
request_max_retries: Some(0),
121+
stream_max_retries: Some(0),
122+
stream_idle_timeout_ms: Some(5_000),
123+
requires_openai_auth: false,
124+
}
125+
}
126+
127+
#[tokio::test]
128+
async fn refresh_available_models_sorts_and_marks_default() {
129+
let server = MockServer::start().await;
130+
let remote_models = vec![
131+
remote_model("priority-low", "Low", 1),
132+
remote_model("priority-high", "High", 10),
133+
];
134+
let response = ModelsResponse {
135+
models: remote_models.clone(),
136+
};
137+
Mock::given(method("GET"))
138+
.and(path("/models"))
139+
.respond_with(
140+
ResponseTemplate::new(200)
141+
.insert_header("content-type", "application/json")
142+
.set_body_json(&response),
143+
)
144+
.expect(1)
145+
.mount(&server)
146+
.await;
147+
148+
let auth_manager =
149+
AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key"));
150+
let manager = ModelsManager::new(auth_manager);
151+
let provider = provider_for(server.uri());
152+
153+
let returned = manager
154+
.refresh_available_models(&provider)
155+
.await
156+
.expect("refresh succeeds");
157+
158+
assert_eq!(returned, remote_models);
159+
let cached_remote = manager.remote_models.read().await.clone();
160+
assert_eq!(cached_remote, remote_models);
161+
162+
let available = manager.available_models.read().await.clone();
163+
assert_eq!(available.len(), 2);
164+
assert_eq!(available[0].model, "priority-high");
165+
assert!(
166+
available[0].is_default,
167+
"highest priority should be default"
168+
);
169+
assert_eq!(available[1].model, "priority-low");
170+
assert!(!available[1].is_default);
171+
}
34172
}

0 commit comments

Comments
 (0)