From d3a553c08a5871972200ef594d9f08b1257f8d8e Mon Sep 17 00:00:00 2001 From: Vidur Khanal Date: Wed, 21 Jan 2026 21:49:25 -0800 Subject: [PATCH] refactor(models): move model structs and errors into dedicated module - Extracted `LanguageModel`, `RerankingModel`, and `EmbeddingModel` structs and their error types from `lib.rs` into new files under `models/`. - Updated imports and usage throughout the codebase to reference the new `models` module. - Moved related trait definitions (`GeneratesText`, `GeneratesObject`, `Reranks`, `ReranksStructuredData`, `Embeds`) to their respective response generator modules. - Centralized `AIProvider` enum and its implementations in `providers/mod.rs` for better separation of concerns. - No functional changes; this is an internal code organization and modularization improvement. --- crates/umem_ai/src/lib.rs | 226 +----------------- .../umem_ai/src/model_impl/embedding_model.rs | 4 +- .../umem_ai/src/model_impl/language_model.rs | 4 +- .../umem_ai/src/model_impl/reranking_model.rs | 4 +- crates/umem_ai/src/models/embedding.rs | 24 ++ crates/umem_ai/src/models/language.rs | 23 ++ crates/umem_ai/src/models/mod.rs | 7 + crates/umem_ai/src/models/reranking.rs | 24 ++ .../umem_ai/src/providers/amazon_bedrock.rs | 7 +- crates/umem_ai/src/providers/mod.rs | 135 ++++++++++- crates/umem_ai/src/providers/openai.rs | 3 +- .../umem_ai/src/response_generators/embed.rs | 12 +- .../response_generators/generate_object.rs | 13 +- .../src/response_generators/generate_text.rs | 12 +- .../umem_ai/src/response_generators/rerank.rs | 12 +- .../response_generators/structured_rerank.rs | 15 +- 16 files changed, 275 insertions(+), 250 deletions(-) create mode 100644 crates/umem_ai/src/models/embedding.rs create mode 100644 crates/umem_ai/src/models/language.rs create mode 100644 crates/umem_ai/src/models/mod.rs create mode 100644 crates/umem_ai/src/models/reranking.rs diff --git a/crates/umem_ai/src/lib.rs b/crates/umem_ai/src/lib.rs index 2c8cb28..f013a97 100644 --- a/crates/umem_ai/src/lib.rs +++ b/crates/umem_ai/src/lib.rs @@ -1,239 +1,19 @@ // TODO: remove this allow once the module is fully implemented #![allow(dead_code)] mod model_impl; +pub mod models; mod providers; mod response_generators; mod utils; - -use anyhow::Result; -use async_trait::async_trait; use lazy_static::lazy_static; + pub use model_impl::*; +pub use models::*; pub use providers::*; pub use response_generators::*; -use schemars::JsonSchema; -use serde::{Serialize, de::DeserializeOwned}; -use std::sync::Arc; -use thiserror::Error; pub type HashMap = rustc_hash::FxHashMap; lazy_static! { static ref reqwest_client: reqwest::Client = reqwest::Client::new(); } - -#[derive(Error, Debug)] -pub enum LanguageModelError { - #[error("ai provider failed with : {0}")] - AIProviderError(#[from] AIProviderError), -} - -pub struct LanguageModel { - pub provider: Arc, - pub model_name: String, -} - -impl LanguageModel { - fn new(provider: Arc, model_name: String) -> Self { - Self { - provider, - model_name, - } - } -} - -#[derive(Error, Debug)] -pub enum RerankingModelError { - #[error("ai provider failed with : {0}")] - AIProviderError(#[from] AIProviderError), -} - -#[derive(Debug, Clone)] -pub struct RerankingModel { - pub provider: Arc, - pub model_name: String, -} - -impl RerankingModel { - fn new(provider: Arc, model_name: String) -> Self { - Self { - provider, - model_name, - } - } -} - -#[derive(Error, Debug)] -pub enum EmbeddingModelError { - #[error("ai provider failed with : {0}")] - AIProviderError(#[from] AIProviderError), -} - -#[derive(Debug, Clone)] -pub struct EmbeddingModel { - pub provider: Arc, - pub model_name: String, -} - -impl EmbeddingModel { - fn new(provider: Arc, model_name: String) -> Self { - Self { - provider, - model_name, - } - } -} - -#[derive(Debug)] -pub enum AIProvider { - OpenAI(OpenAIProvider), - AzureOpenAI(AzureOpenAIProvider), - GoogleVertexAI(GoogleVertexAIProvider), - Anthropic(AnthropicProvider), - XAI(XAIProvider), - AmazonBedrock(AmazonBedrockProvider), - Cohere(CohereProvider), -} - -impl AIProvider { - pub(crate) async fn do_generate_text( - &self, - request: GenerateTextRequest, - ) -> Result { - match self { - AIProvider::OpenAI(provider) => provider.generate_text(request), - AIProvider::AmazonBedrock(provider) => provider.generate_text(request), - _ => unimplemented!(), - } - .await - } - - pub(crate) async fn do_generate_object< - T: Clone + JsonSchema + Serialize + Send + Sync + DeserializeOwned, - >( - &self, - request: GenerateObjectRequest, - ) -> Result, ResponseGeneratorError> { - match self { - AIProvider::OpenAI(provider) => provider.generate_object(request), - AIProvider::AmazonBedrock(provider) => provider.generate_object(request), - _ => unimplemented!(), - } - .await - } - - pub(crate) async fn do_reranking( - &self, - request: RerankRequest, - ) -> Result { - match self { - AIProvider::Cohere(provider) => provider.rerank(request), - AIProvider::AmazonBedrock(provider) => provider.rerank(request), - _ => unimplemented!(), - } - .await - } - - pub(crate) async fn do_structured_reranking( - &self, - request: StructuredRerankRequest, - ) -> Result, ResponseGeneratorError> - where - T: Serialize + Clone + Send + Sync, - { - match self { - AIProvider::Cohere(provider) => provider.rerank_structured(request).await, - AIProvider::AmazonBedrock(provider) => provider.rerank_structured(request).await, - _ => unimplemented!(), - } - } - - pub(crate) async fn do_embed( - &self, - request: EmbeddingRequest, - ) -> Result { - match self { - AIProvider::AmazonBedrock(provider) => provider.embed(request), - _ => unimplemented!(), - } - .await - } -} - -#[async_trait] -pub trait GeneratesText { - async fn generate_text( - &self, - request: GenerateTextRequest, - ) -> Result; -} - -#[async_trait] -pub trait GeneratesObject { - async fn generate_object( - &self, - request: GenerateObjectRequest, - ) -> Result, ResponseGeneratorError>; -} - -#[async_trait] -pub trait Reranks { - async fn rerank( - &self, - request: RerankRequest, - ) -> Result; -} - -#[async_trait] -pub trait ReranksStructuredData { - async fn rerank_structured( - &self, - request: StructuredRerankRequest, - ) -> Result, ResponseGeneratorError> - where - T: Serialize + Clone + Send + Sync; -} - -#[async_trait] -pub trait Embeds { - async fn embed( - &self, - request: EmbeddingRequest, - ) -> Result; -} - -impl From for AIProvider { - fn from(config: OpenAIProvider) -> Self { - AIProvider::OpenAI(config) - } -} - -impl From for AIProvider { - fn from(config: AzureOpenAIProvider) -> Self { - AIProvider::AzureOpenAI(config) - } -} - -impl From for AIProvider { - fn from(config: AnthropicProvider) -> Self { - AIProvider::Anthropic(config) - } -} - -impl From for AIProvider { - fn from(config: XAIProvider) -> Self { - AIProvider::XAI(config) - } -} - -impl From for AIProvider { - fn from(config: AmazonBedrockProvider) -> Self { - AIProvider::AmazonBedrock(config) - } -} - -impl From for AIProvider { - fn from(config: GoogleVertexAIProvider) -> Self { - AIProvider::GoogleVertexAI(config) - } -} diff --git a/crates/umem_ai/src/model_impl/embedding_model.rs b/crates/umem_ai/src/model_impl/embedding_model.rs index a6a7aef..c15bf6b 100644 --- a/crates/umem_ai/src/model_impl/embedding_model.rs +++ b/crates/umem_ai/src/model_impl/embedding_model.rs @@ -1,6 +1,6 @@ use crate::{ - AIProvider, AIProviderError, AmazonBedrockProviderBuilder, EmbeddingModel, EmbeddingModelError, - OpenAIProvider, + AIProvider, AIProviderError, AmazonBedrockProviderBuilder, OpenAIProvider, + models::{EmbeddingModel, EmbeddingModelError}, }; use std::sync::Arc; use tokio::sync::OnceCell; diff --git a/crates/umem_ai/src/model_impl/language_model.rs b/crates/umem_ai/src/model_impl/language_model.rs index 48a4e0f..a1de8bd 100644 --- a/crates/umem_ai/src/model_impl/language_model.rs +++ b/crates/umem_ai/src/model_impl/language_model.rs @@ -1,6 +1,6 @@ use crate::{ - AIProvider, AIProviderError, AmazonBedrockProviderBuilder, LanguageModel, LanguageModelError, - OpenAIProvider, + AIProvider, AIProviderError, AmazonBedrockProviderBuilder, OpenAIProvider, + models::{LanguageModel, LanguageModelError}, }; use std::sync::Arc; use tokio::sync::OnceCell; diff --git a/crates/umem_ai/src/model_impl/reranking_model.rs b/crates/umem_ai/src/model_impl/reranking_model.rs index 81d4caf..d8544e0 100644 --- a/crates/umem_ai/src/model_impl/reranking_model.rs +++ b/crates/umem_ai/src/model_impl/reranking_model.rs @@ -1,6 +1,6 @@ use crate::{ - AIProvider, AIProviderError, AmazonBedrockProviderBuilder, OpenAIProvider, RerankingModel, - RerankingModelError, + AIProvider, AIProviderError, AmazonBedrockProviderBuilder, OpenAIProvider, + models::{RerankingModel, RerankingModelError}, }; use std::sync::Arc; use tokio::sync::OnceCell; diff --git a/crates/umem_ai/src/models/embedding.rs b/crates/umem_ai/src/models/embedding.rs new file mode 100644 index 0000000..212296b --- /dev/null +++ b/crates/umem_ai/src/models/embedding.rs @@ -0,0 +1,24 @@ +use crate::{AIProvider, AIProviderError}; +use std::sync::Arc; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum EmbeddingModelError { + #[error("ai provider failed with : {0}")] + AIProviderError(#[from] AIProviderError), +} + +#[derive(Debug, Clone)] +pub struct EmbeddingModel { + pub provider: Arc, + pub model_name: String, +} + +impl EmbeddingModel { + fn new(provider: Arc, model_name: String) -> Self { + Self { + provider, + model_name, + } + } +} diff --git a/crates/umem_ai/src/models/language.rs b/crates/umem_ai/src/models/language.rs new file mode 100644 index 0000000..ce17d3a --- /dev/null +++ b/crates/umem_ai/src/models/language.rs @@ -0,0 +1,23 @@ +use crate::{AIProvider, AIProviderError}; +use std::sync::Arc; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum LanguageModelError { + #[error("ai provider failed with : {0}")] + AIProviderError(#[from] AIProviderError), +} + +pub struct LanguageModel { + pub provider: Arc, + pub model_name: String, +} + +impl LanguageModel { + fn new(provider: Arc, model_name: String) -> Self { + Self { + provider, + model_name, + } + } +} diff --git a/crates/umem_ai/src/models/mod.rs b/crates/umem_ai/src/models/mod.rs new file mode 100644 index 0000000..a660cef --- /dev/null +++ b/crates/umem_ai/src/models/mod.rs @@ -0,0 +1,7 @@ +pub mod embedding; +pub mod language; +pub mod reranking; + +pub use embedding::*; +pub use language::*; +pub use reranking::*; diff --git a/crates/umem_ai/src/models/reranking.rs b/crates/umem_ai/src/models/reranking.rs new file mode 100644 index 0000000..ed1bbcc --- /dev/null +++ b/crates/umem_ai/src/models/reranking.rs @@ -0,0 +1,24 @@ +use crate::{AIProvider, AIProviderError}; +use std::sync::Arc; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum RerankingModelError { + #[error("ai provider failed with : {0}")] + AIProviderError(#[from] AIProviderError), +} + +#[derive(Debug, Clone)] +pub struct RerankingModel { + pub provider: Arc, + pub model_name: String, +} + +impl RerankingModel { + fn new(provider: Arc, model_name: String) -> Self { + Self { + provider, + model_name, + } + } +} diff --git a/crates/umem_ai/src/providers/amazon_bedrock.rs b/crates/umem_ai/src/providers/amazon_bedrock.rs index dbdf598..4c46b6e 100644 --- a/crates/umem_ai/src/providers/amazon_bedrock.rs +++ b/crates/umem_ai/src/providers/amazon_bedrock.rs @@ -769,9 +769,10 @@ impl AmazonBedrockProviderBuilder { mod tests { use super::*; use crate::{ - AIProvider, EmbeddingModel, GenerateObjectRequestBuilder, GenerateTextRequestBuilder, - LanguageModel, RerankingModel, SerializationFormat, generate_object, generate_text, rerank, - structured_rerank, + AIProvider, GenerateObjectRequestBuilder, GenerateTextRequestBuilder, SerializationFormat, + generate_object, generate_text, + models::{EmbeddingModel, LanguageModel, RerankingModel}, + rerank, structured_rerank, }; use serde::Deserialize; use std::sync::Arc; diff --git a/crates/umem_ai/src/providers/mod.rs b/crates/umem_ai/src/providers/mod.rs index 1342298..cf8f320 100644 --- a/crates/umem_ai/src/providers/mod.rs +++ b/crates/umem_ai/src/providers/mod.rs @@ -5,8 +5,23 @@ mod cohere; mod google_vertex; mod openai; mod xai; - +use crate::{ + Embeds, GenerateObjectRequest, GenerateObjectResponse, GenerateTextRequest, + GenerateTextResponse, GeneratesObject, GeneratesText, RerankRequest, RerankResponse, Reranks, + ReranksStructuredData, ResponseGeneratorError, StructuredRerankRequest, + StructuredRerankResponse, + embed::{EmbeddingRequest, EmbeddingResponse}, +}; +pub use amazon_bedrock::*; +pub use anthropic::AnthropicProvider; +pub use azure_openai::AzureOpenAIProvider; +pub use cohere::CohereProvider; +pub use google_vertex::GoogleVertexAIProvider; +pub use openai::OpenAIProvider; +use schemars::JsonSchema; +use serde::{Serialize, de::DeserializeOwned}; use thiserror::Error; +pub use xai::XAIProvider; #[derive(Error, Debug)] pub enum AIProviderError { @@ -20,10 +35,114 @@ pub enum ProviderBuilderError { AmazonBedrockProviderBuilderError(#[from] AmazonBedrockProviderBuilderError), } -pub use amazon_bedrock::*; -pub use anthropic::AnthropicProvider; -pub use azure_openai::AzureOpenAIProvider; -pub use cohere::CohereProvider; -pub use google_vertex::GoogleVertexAIProvider; -pub use openai::OpenAIProvider; -pub use xai::XAIProvider; +#[derive(Debug)] +pub enum AIProvider { + OpenAI(OpenAIProvider), + AzureOpenAI(AzureOpenAIProvider), + GoogleVertexAI(GoogleVertexAIProvider), + Anthropic(AnthropicProvider), + XAI(XAIProvider), + AmazonBedrock(AmazonBedrockProvider), + Cohere(CohereProvider), +} + +impl AIProvider { + pub(crate) async fn do_generate_text( + &self, + request: GenerateTextRequest, + ) -> Result { + match self { + AIProvider::OpenAI(provider) => provider.generate_text(request), + AIProvider::AmazonBedrock(provider) => provider.generate_text(request), + _ => unimplemented!(), + } + .await + } + + pub(crate) async fn do_generate_object< + T: Clone + JsonSchema + Serialize + Send + Sync + DeserializeOwned, + >( + &self, + request: GenerateObjectRequest, + ) -> Result, ResponseGeneratorError> { + match self { + AIProvider::OpenAI(provider) => provider.generate_object(request), + AIProvider::AmazonBedrock(provider) => provider.generate_object(request), + _ => unimplemented!(), + } + .await + } + + pub(crate) async fn do_reranking( + &self, + request: RerankRequest, + ) -> Result { + match self { + AIProvider::Cohere(provider) => provider.rerank(request), + AIProvider::AmazonBedrock(provider) => provider.rerank(request), + _ => unimplemented!(), + } + .await + } + + pub(crate) async fn do_structured_reranking( + &self, + request: StructuredRerankRequest, + ) -> Result, ResponseGeneratorError> + where + T: Serialize + Clone + Send + Sync, + { + match self { + AIProvider::Cohere(provider) => provider.rerank_structured(request).await, + AIProvider::AmazonBedrock(provider) => provider.rerank_structured(request).await, + _ => unimplemented!(), + } + } + + pub(crate) async fn do_embed( + &self, + request: EmbeddingRequest, + ) -> Result { + match self { + AIProvider::AmazonBedrock(provider) => provider.embed(request), + _ => unimplemented!(), + } + .await + } +} + +impl From for AIProvider { + fn from(config: OpenAIProvider) -> Self { + AIProvider::OpenAI(config) + } +} + +impl From for AIProvider { + fn from(config: AzureOpenAIProvider) -> Self { + AIProvider::AzureOpenAI(config) + } +} + +impl From for AIProvider { + fn from(config: AnthropicProvider) -> Self { + AIProvider::Anthropic(config) + } +} + +impl From for AIProvider { + fn from(config: XAIProvider) -> Self { + AIProvider::XAI(config) + } +} + +impl From for AIProvider { + fn from(config: AmazonBedrockProvider) -> Self { + AIProvider::AmazonBedrock(config) + } +} + +impl From for AIProvider { + fn from(config: GoogleVertexAIProvider) -> Self { + AIProvider::GoogleVertexAI(config) + } +} diff --git a/crates/umem_ai/src/providers/openai.rs b/crates/umem_ai/src/providers/openai.rs index dcd7d9a..5e68e87 100644 --- a/crates/umem_ai/src/providers/openai.rs +++ b/crates/umem_ai/src/providers/openai.rs @@ -359,7 +359,8 @@ mod tests { use super::*; use crate::{ - AIProvider, LanguageModel, + AIProvider, + models::LanguageModel, response_generators::{ GenerateTextRequestBuilder, generate_object::{GenerateObjectRequestBuilder, generate_object}, diff --git a/crates/umem_ai/src/response_generators/embed.rs b/crates/umem_ai/src/response_generators/embed.rs index ef944a3..80cd99e 100644 --- a/crates/umem_ai/src/response_generators/embed.rs +++ b/crates/umem_ai/src/response_generators/embed.rs @@ -1,11 +1,21 @@ use crate::{ - EmbeddingModel, ResponseGeneratorError, + ResponseGeneratorError, + models::EmbeddingModel, utils::{self, is_retryable_error}, }; +use async_trait::async_trait; use backon::{ExponentialBuilder, Retryable}; use reqwest::header::HeaderMap; use std::{sync::Arc, time::Duration}; +#[async_trait] +pub trait Embeds { + async fn embed( + &self, + request: EmbeddingRequest, + ) -> Result; +} + pub async fn embed(request: EmbeddingRequest) -> Result { let per_request_timeout = request.timeout; let max_retries = request.max_retries; diff --git a/crates/umem_ai/src/response_generators/generate_object.rs b/crates/umem_ai/src/response_generators/generate_object.rs index 33b6e5b..f33f469 100644 --- a/crates/umem_ai/src/response_generators/generate_object.rs +++ b/crates/umem_ai/src/response_generators/generate_object.rs @@ -1,5 +1,7 @@ -use crate::{LanguageModel, ResponseGeneratorError, utils}; +use crate::models::LanguageModel; +use crate::{ResponseGeneratorError, utils}; use crate::{response_generators::messages::Message, utils::is_retryable_error}; +use async_trait::async_trait; use backon::{ExponentialBuilder, Retryable}; use reqwest::header::HeaderMap; use schemars::{JsonSchema, Schema, schema_for}; @@ -8,6 +10,14 @@ use std::time::Duration; use std::{marker::PhantomData, sync::Arc}; use thiserror::Error; +#[async_trait] +pub trait GeneratesObject { + async fn generate_object( + &self, + request: GenerateObjectRequest, + ) -> Result, ResponseGeneratorError>; +} + pub async fn generate_object( request: GenerateObjectRequest, ) -> Result, ResponseGeneratorError> @@ -21,6 +31,7 @@ where let generation = || { let model = Arc::clone(&request.model); let provider = Arc::clone(&model.provider); + let request = request.clone(); async move { tokio::time::timeout(per_request_timeout, provider.do_generate_object(request)) diff --git a/crates/umem_ai/src/response_generators/generate_text.rs b/crates/umem_ai/src/response_generators/generate_text.rs index 64446fa..46d4d8b 100644 --- a/crates/umem_ai/src/response_generators/generate_text.rs +++ b/crates/umem_ai/src/response_generators/generate_text.rs @@ -1,8 +1,9 @@ -use crate::LanguageModel; use crate::ResponseGeneratorError; +use crate::models::LanguageModel; use crate::response_generators::messages::Message; use crate::utils; use crate::utils::is_retryable_error; +use async_trait::async_trait; use backon::ExponentialBuilder; use backon::Retryable; use reqwest::header::HeaderMap; @@ -10,7 +11,14 @@ use std::sync::Arc; use std::time::Duration; use thiserror::Error; -// TODO: Wrap me with observers for logging, metrics, tracing, etc. +#[async_trait] +pub trait GeneratesText { + async fn generate_text( + &self, + request: GenerateTextRequest, + ) -> Result; +} + pub async fn generate_text( request: GenerateTextRequest, ) -> Result { diff --git a/crates/umem_ai/src/response_generators/rerank.rs b/crates/umem_ai/src/response_generators/rerank.rs index bcfde7d..75479ff 100644 --- a/crates/umem_ai/src/response_generators/rerank.rs +++ b/crates/umem_ai/src/response_generators/rerank.rs @@ -1,9 +1,17 @@ +use crate::{ResponseGeneratorError, models::RerankingModel, utils::is_retryable_error}; +use async_trait::async_trait; use backon::{ExponentialBuilder, Retryable}; use serde_json::{Map, Value}; - -use crate::{RerankingModel, ResponseGeneratorError, utils::is_retryable_error}; use std::{sync::Arc, time::Duration}; +#[async_trait] +pub trait Reranks { + async fn rerank( + &self, + request: RerankRequest, + ) -> Result; +} + pub async fn rerank(request: RerankRequest) -> Result { let per_request_timeout = request.timeout; let max_retries = request.max_retries; diff --git a/crates/umem_ai/src/response_generators/structured_rerank.rs b/crates/umem_ai/src/response_generators/structured_rerank.rs index 64b7ff9..c9c4efc 100644 --- a/crates/umem_ai/src/response_generators/structured_rerank.rs +++ b/crates/umem_ai/src/response_generators/structured_rerank.rs @@ -1,10 +1,19 @@ -use std::{sync::Arc, time::Duration}; - +use crate::{ResponseGeneratorError, models::RerankingModel, utils::is_retryable_error}; +use async_trait::async_trait; use backon::{ExponentialBuilder, Retryable}; use serde::{Serialize, de::DeserializeOwned}; use serde_json::{Map, Value}; +use std::{sync::Arc, time::Duration}; -use crate::{RerankingModel, ResponseGeneratorError, utils::is_retryable_error}; +#[async_trait] +pub trait ReranksStructuredData { + async fn rerank_structured( + &self, + request: StructuredRerankRequest, + ) -> Result, ResponseGeneratorError> + where + T: Serialize + Clone + Send + Sync; +} #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] pub enum SerializationFormat {