diff --git a/Cargo.lock b/Cargo.lock index e063c2f..a5d8cb0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5129,7 +5129,6 @@ dependencies = [ "umem_ai", "umem_config", "umem_controller", - "umem_embeddings", "umem_grpc_server", "umem_mcp", "umem_memory_machine", @@ -5208,7 +5207,6 @@ dependencies = [ "umem_ai", "umem_annotations", "umem_core", - "umem_embeddings", "umem_refine", "umem_vector_store", "uuid", @@ -5260,19 +5258,6 @@ dependencies = [ "yaml-rust2 0.11.0", ] -[[package]] -name = "umem_embeddings" -version = "0.1.0" -dependencies = [ - "async-trait", - "lazy_static", - "reqwest", - "serde", - "thiserror 2.0.17", - "tokio", - "umem_config", -] - [[package]] name = "umem_grpc_server" version = "0.1.0" @@ -5316,7 +5301,6 @@ dependencies = [ "umem_ai", "umem_config", "umem_controller", - "umem_embeddings", "umem_grpc_server", "umem_mcp", "umem_vector_store", diff --git a/Cargo.toml b/Cargo.toml index bb8b7d4..22dd04e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,6 @@ umem_memory_machine = {path = "crates/umem_memory_machine"} umem_refine = {path = "crates/umem_refine"} umem_vector_store = {path = "crates/umem_vector_store"} umem_config = {path = "crates/umem_config"} -umem_embeddings = {path = "crates/umem_embeddings"} umem_annotations = {path = "crates/umem_annotations"} umem_grpc_server = {path = "crates/umem_grpc_server"} umem_web_scrapper = {path = "crates/umem_web_scrapper"} @@ -56,7 +55,6 @@ typed-builder = { workspace = true } umem_config = { workspace = true } umem_controller = { workspace = true } umem_ai = { workspace = true } -umem_embeddings = { workspace = true } umem_vector_store = { workspace = true } anyhow = { workspace = true } tokio = { workspace = true } diff --git a/crates/umem_ai/src/lib.rs b/crates/umem_ai/src/lib.rs index f573c65..d82d464 100644 --- a/crates/umem_ai/src/lib.rs +++ b/crates/umem_ai/src/lib.rs @@ -16,8 +16,6 @@ use thiserror::Error; use tokio::sync::OnceCell; use umem_config::CONFIG; -use crate::response_generators::embed::{EmbeddingRequest, EmbeddingResponse}; - pub type HashMap = rustc_hash::FxHashMap; lazy_static! { @@ -153,12 +151,71 @@ impl RerankingModel { } } +#[derive(Error, Debug)] +pub enum EmbeddingModelError { + #[error("ai provider failed with : {0}")] + AIProviderError(#[from] AIProviderError), +} + #[derive(Debug, Clone)] pub struct EmbeddingModel { pub provider: Arc, pub model_name: String, } +static EMBEDDING_MODEL: OnceCell> = OnceCell::const_new(); + +impl EmbeddingModel { + fn new(provider: Arc, model_name: String) -> Self { + Self { + provider, + model_name, + } + } + + pub async fn get_model() -> Result, LanguageModelError> { + EMBEDDING_MODEL + .get_or_try_init(|| async { + match CONFIG.embedding_model.provider.clone() { + umem_config::Provider::OpenAI(open_ai) => { + let openai_provider = OpenAIProvider::builder() + .api_key(open_ai.api_key) + .base_url(open_ai.base_url) + .default_headers(open_ai.default_headers.unwrap_or_default()) + .project(open_ai.project) + .organization(open_ai.organization) + .build(); + + let provider = Arc::new(AIProvider::from(openai_provider)); + + Ok(Arc::new(EmbeddingModel { + provider, + model_name: CONFIG.embedding_model.model.clone(), + })) + } + umem_config::Provider::AmazonBedrock(config) => { + let provider = AmazonBedrockProviderBuilder::default() + .region(config.region) + .access_key_id(config.key_id) + .secret_access_key(config.access_key) + .build() + .await + .map_err(|e| AIProviderError::ProviderBuilderError(e.into()))?; + + let provider = Arc::new(AIProvider::from(provider)); + + Ok(Arc::new(EmbeddingModel { + provider, + model_name: CONFIG.embedding_model.model.clone(), + })) + } + } + }) + .await + .cloned() + } +} + #[derive(Debug)] pub enum AIProvider { OpenAI(OpenAIProvider), diff --git a/crates/umem_ai/src/providers/amazon_bedrock.rs b/crates/umem_ai/src/providers/amazon_bedrock.rs index ea8d46a..dbdf598 100644 --- a/crates/umem_ai/src/providers/amazon_bedrock.rs +++ b/crates/umem_ai/src/providers/amazon_bedrock.rs @@ -2,7 +2,7 @@ use crate::{ Embeds, GenerateObjectRequest, GenerateObjectResponse, GeneratesObject, GeneratesText, OpenAIProvider, Ranking, RerankRequest, RerankResponse, Reranks, ReranksStructuredData, SerializationMode, StructuredRanking, StructuredRerankRequest, StructuredRerankResponse, - embed::{EmbeddingRequest, EmbeddingResponse, embed as embedFn}, + embed::{EmbeddingRequest, EmbeddingResponse}, messages::{FilePart, UserModelMessage}, response_generators::{ self, GenerateTextRequest, GenerateTextResponse, ResponseGeneratorError, @@ -30,7 +30,7 @@ use futures::future::join_all; use schemars::JsonSchema; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde_json::Map; -use std::{error::Error, sync::Arc}; +use std::sync::Arc; use thiserror::Error; use tokio::sync::Semaphore; @@ -770,8 +770,8 @@ mod tests { use super::*; use crate::{ AIProvider, EmbeddingModel, GenerateObjectRequestBuilder, GenerateTextRequestBuilder, - LanguageModel, RerankingModel, SerializationFormat, embed, generate_object, generate_text, - rerank, structured_rerank, + LanguageModel, RerankingModel, SerializationFormat, generate_object, generate_text, rerank, + structured_rerank, }; use serde::Deserialize; use std::sync::Arc; @@ -998,6 +998,8 @@ mod tests { #[tokio::test] async fn embedding_test() { + use crate::embed::embed; + let provider = Arc::new(AIProvider::from( AmazonBedrockProviderBuilder::default() .region("REGION") @@ -1022,7 +1024,7 @@ mod tests { ]) .build(); - let embedding_response = embedFn(request).await.unwrap(); + let embedding_response = embed(request).await.unwrap(); dbg!(&embedding_response); } } diff --git a/crates/umem_ai/src/response_generators/embed.rs b/crates/umem_ai/src/response_generators/embed.rs index 872e7e8..ef944a3 100644 --- a/crates/umem_ai/src/response_generators/embed.rs +++ b/crates/umem_ai/src/response_generators/embed.rs @@ -6,9 +6,7 @@ use backon::{ExponentialBuilder, Retryable}; use reqwest::header::HeaderMap; use std::{sync::Arc, time::Duration}; -pub async fn embed( - mut request: EmbeddingRequest, -) -> Result { +pub async fn embed(request: EmbeddingRequest) -> Result { let per_request_timeout = request.timeout; let max_retries = request.max_retries; let total_delay = per_request_timeout.mul_f32(max_retries as f32 / 2.0); diff --git a/crates/umem_ai/src/response_generators/mod.rs b/crates/umem_ai/src/response_generators/mod.rs index a7c4b9a..f0919ca 100644 --- a/crates/umem_ai/src/response_generators/mod.rs +++ b/crates/umem_ai/src/response_generators/mod.rs @@ -5,6 +5,7 @@ pub mod messages; pub mod rerank; pub mod structured_rerank; +pub use embed::*; pub use generate_object::*; pub use generate_text::*; pub use messages::*; diff --git a/crates/umem_config/src/lib.rs b/crates/umem_config/src/lib.rs index 6ef9e4e..fd688bb 100644 --- a/crates/umem_config/src/lib.rs +++ b/crates/umem_config/src/lib.rs @@ -3,13 +3,6 @@ use lazy_static::lazy_static; use serde::Deserialize; use std::net::SocketAddr; -#[derive(Debug, Deserialize, Clone)] -pub struct Cloudflare { - pub account_id: String, - pub api_token: String, - pub model: String, -} - #[derive(Debug, Deserialize, Clone)] pub struct OpenAI { pub api_key: String, @@ -35,6 +28,12 @@ pub enum Provider { AmazonBedrock(AmazonBedrock), } +#[derive(Debug, Deserialize, Clone)] +pub struct EmbeddingModel { + pub provider: Provider, + pub model: String, +} + #[derive(Debug, Deserialize, Clone)] pub struct LanguageModel { pub provider: Provider, @@ -54,12 +53,6 @@ pub struct Mcp { pub work_os: WorkOs, } -#[derive(Debug, Deserialize, Clone)] -pub enum EmbeddingModel { - #[serde(rename = "cloudflare")] - Cloudflare(Cloudflare), -} - #[derive(Debug, Deserialize, Clone)] pub struct WorkOs { pub client_id: String, diff --git a/crates/umem_controller/Cargo.toml b/crates/umem_controller/Cargo.toml index cea6a41..7852a1f 100644 --- a/crates/umem_controller/Cargo.toml +++ b/crates/umem_controller/Cargo.toml @@ -7,7 +7,6 @@ edition = "2021" serde_json = { workspace = true } umem_vector_store = { workspace = true} umem_ai = { workspace = true} -umem_embeddings = { workspace = true} umem_refine = { workspace = true} umem_annotations = { workspace = true} umem_core = { workspace = true} diff --git a/crates/umem_controller/src/create_memory.rs b/crates/umem_controller/src/create_memory.rs index 83b57ce..11aa4e3 100644 --- a/crates/umem_controller/src/create_memory.rs +++ b/crates/umem_controller/src/create_memory.rs @@ -3,13 +3,15 @@ use chrono::Utc; use std::sync::Arc; use thiserror::Error; use typed_builder::TypedBuilder; -use umem_ai::{AIProviderError, LanguageModel}; +use umem_ai::{ + embed::{embed, EmbeddingRequest}, + AIProviderError, EmbeddingModel, LanguageModel, ResponseGeneratorError, +}; use umem_annotations::{Annotation, AnnotationError, LLMAnnotated}; use umem_core::{ LifecycleState, Memory, MemoryContentError, MemoryContext, MemoryContextError, MemoryError, TemporalMetadata, }; -use umem_embeddings::{EmbedderBase, EmbedderError}; use umem_vector_store::VectorStoreError; use uuid::Uuid; @@ -21,11 +23,11 @@ pub enum CreateMemoryError { #[error("vector store action failed with: {0}")] VectorStoreError(#[from] VectorStoreError), - #[error("embedder action failed with: {0}")] - EmbedderError(#[from] EmbedderError), - #[error("ai provider action failed with: {0}")] AIProviderError(#[from] AIProviderError), + + #[error("response generator action failed with: {0}")] + ResponseGeneratorError(#[from] ResponseGeneratorError), } #[derive(Debug, Error)] @@ -120,9 +122,9 @@ impl CreateMemoryRequest { #[derive(TypedBuilder, Default)] pub struct CreateMemoryOptions { #[builder(default = None)] - pub embedder: Option>, + pub embedding_model: Option>, #[builder(default = None)] - pub model: Option>, + pub language_model: Option>, } impl MemoryController { @@ -140,12 +142,27 @@ impl MemoryController { _options: Option, ) -> Result { let vector_store = Arc::clone(&self.vector_store); - let embedder = Arc::clone(&self.embedder); + let embedding_model = Arc::clone(&self.embedding_model); let language_model = Arc::clone(&self.language_model); let memory = request.build(language_model).await?; - let vector = embedder.generate_embedding(memory.get_summary()).await?; - vector_store.insert(&[&vector], &[&memory]).await?; + + let request = EmbeddingRequest::builder() + .model(embedding_model) + .input(vec![memory.get_summary().to_owned()]) + .build(); + + let embedding_response = embed(request).await?; + + //NOTE: change this later, just didin't want to fight with the drilled types + let slices: Vec<&[f32]> = embedding_response + .embeddings + .iter() + .map(|inner| inner.as_slice()) + .collect(); + let slice_of_slices: &[&[f32]] = &slices; + + vector_store.insert(slice_of_slices, &[&memory]).await?; Ok(memory) } } diff --git a/crates/umem_controller/src/lib.rs b/crates/umem_controller/src/lib.rs index 5e0ce54..6c07e2d 100644 --- a/crates/umem_controller/src/lib.rs +++ b/crates/umem_controller/src/lib.rs @@ -14,8 +14,7 @@ pub use delete_memory::*; pub use get_memory::*; pub use list_memory::*; pub use search_memory::*; -use umem_ai::{LanguageModel, RerankingModel}; -use umem_embeddings::EmbedderBase; +use umem_ai::{EmbeddingModel, LanguageModel, RerankingModel}; use umem_vector_store::VectorStoreBase; pub use update_memory::*; @@ -43,7 +42,7 @@ pub enum MemoryControllerError { #[derive(Clone)] pub struct MemoryController { pub vector_store: Arc, - pub embedder: Arc, + pub embedding_model: Arc, pub reranking_model: Arc, pub language_model: Arc, } diff --git a/crates/umem_controller/src/search_memory.rs b/crates/umem_controller/src/search_memory.rs index 2b998b7..e81c727 100644 --- a/crates/umem_controller/src/search_memory.rs +++ b/crates/umem_controller/src/search_memory.rs @@ -5,10 +5,11 @@ use tokio::{sync::AcquireError, task::JoinError}; use tracing::info; use typed_builder::TypedBuilder; use umem_ai::{ - rerank, RerankRequest, RerankRequestBuilderError, RerankingModelError, ResponseGeneratorError, + embed::{embed, EmbeddingRequest}, + rerank, EmbeddingModel, EmbeddingModelError, RerankRequest, RerankRequestBuilderError, + RerankingModelError, ResponseGeneratorError, }; use umem_core::{Memory, MemoryContext, MemoryContextError, Query}; -use umem_embeddings::{EmbedderBase, EmbedderError}; use umem_refine::{RefineError, Segmenter}; use umem_vector_store::VectorStoreError; @@ -20,8 +21,8 @@ pub enum SearchMemoryError { #[error("memory context action failed with: {0}")] MemoryContextError(#[from] MemoryContextError), - #[error("embedder action failed with: {0}")] - EmbedderError(#[from] EmbedderError), + #[error("embedding action failed with: {0}")] + EmbeddingModelError(#[from] EmbeddingModelError), #[error("umem refine action failed with: {0}")] RefineError(#[from] RefineError), @@ -45,7 +46,7 @@ pub enum SearchMemoryError { #[derive(TypedBuilder, Default)] pub struct SearchMemoryOptions { #[builder(default = None)] - pub embedder: Option>, + pub embedding_model: Option>, } impl MemoryController { @@ -64,17 +65,20 @@ impl MemoryController { query: String, _options: Option, ) -> Result, SearchMemoryError> { - let vector_store = Arc::clone(&self.vector_store); - let embedder = Arc::clone(&self.embedder); + let request = EmbeddingRequest::builder() + .model(self.embedding_model.clone()) + .input(vec![query]) + .build(); + + let embedding_response = embed(request).await?; - let vector = embedder.generate_embedding(query.as_str()).await?; let query = Query::builder() - .vector(vector) + .vector(embedding_response.embeddings[0].clone()) .context(MemoryContext::for_user(user_id)?) .limit(1000) .build(); - Ok(vector_store.search(query).await?) + Ok(self.vector_store.search(query).await?) } pub async fn search_with_context( @@ -94,9 +98,15 @@ impl MemoryController { query: String, _options: Option, ) -> Result, SearchMemoryError> { - let vector = self.embedder.generate_embedding(query.as_str()).await?; + let request = EmbeddingRequest::builder() + .model(self.embedding_model.clone()) + .input(vec![query.clone()]) + .build(); + + let embedding_response = embed(request).await?; + let vector_query = Query::builder() - .vector(vector) + .vector(embedding_response.embeddings[0].clone()) .context(context) .limit(20) .build(); @@ -144,7 +154,6 @@ impl MemoryController { use tokio::task::JoinHandle; let vector_store = Arc::clone(&self.vector_store); - let embedder = Arc::clone(&self.embedder); let semaphore = Arc::new(Semaphore::new(8)); // limit concurrency (tune this!) let mut tasks: FuturesUnordered, SearchMemoryError>>> = @@ -152,14 +161,21 @@ impl MemoryController { let mut sub_queries = Segmenter::process(&query)?; sub_queries.push(query.clone()); - let sub_query_slices: Vec<&str> = sub_queries.iter().map(|s| s.as_str()).collect(); + // let sub_query_slices: Vec<&str> = sub_queries.iter().map(|s| s.as_str()).collect(); + // let vectors = embedder.generate_embeddings(&sub_query_slices).await?; let start = Instant::now(); - let vectors = embedder.generate_embeddings(&sub_query_slices).await?; + + let request = EmbeddingRequest::builder() + .model(self.embedding_model.clone()) + .input(sub_queries) + .build(); + + let embedding_response = embed(request).await?; let duration = start.elapsed(); info!("Embedder time : {:?}", duration); - for vector in vectors { + for vector in embedding_response.embeddings { let permit = Arc::clone(&semaphore).acquire_owned().await?; let vector_store = Arc::clone(&vector_store); let context = context.clone(); diff --git a/crates/umem_controller/src/update_memory.rs b/crates/umem_controller/src/update_memory.rs index 7222a89..525cd56 100644 --- a/crates/umem_controller/src/update_memory.rs +++ b/crates/umem_controller/src/update_memory.rs @@ -3,8 +3,8 @@ use std::sync::Arc; use super::{MemoryController, MemoryControllerError}; use thiserror::Error; use typed_builder::TypedBuilder; +use umem_ai::EmbeddingModelError; use umem_core::Memory; -use umem_embeddings::EmbedderError; use umem_vector_store::VectorStoreError; #[derive(Debug, Error)] @@ -25,7 +25,7 @@ pub enum UpdateMemoryError { VectorStoreError(#[from] VectorStoreError), #[error("embedder action failed with: {0}")] - EmbedderError(#[from] EmbedderError), + EmbeddingModelError(#[from] EmbeddingModelError), } #[derive(TypedBuilder)] diff --git a/crates/umem_embeddings/Cargo.toml b/crates/umem_embeddings/Cargo.toml deleted file mode 100644 index 571ba4b..0000000 --- a/crates/umem_embeddings/Cargo.toml +++ /dev/null @@ -1,13 +0,0 @@ -[package] -name = "umem_embeddings" -version = "0.1.0" -edition = "2021" - -[dependencies] -reqwest = { workspace = true } -serde = { workspace = true, features = ["derive"] } -tokio = { workspace = true } -lazy_static = { workspace = true } -umem_config = { workspace = true } -async-trait = { workspace = true } -thiserror = { workspace = true } diff --git a/crates/umem_embeddings/src/cloudflare.rs b/crates/umem_embeddings/src/cloudflare.rs deleted file mode 100644 index 81aa7a0..0000000 --- a/crates/umem_embeddings/src/cloudflare.rs +++ /dev/null @@ -1,108 +0,0 @@ -use crate::{client, EmbedderBase}; -use async_trait::async_trait; -use serde::{Deserialize, Serialize}; - -use thiserror::Error; - -#[derive(Debug, Error)] -pub enum CloudflareError { - #[error("text to embed cannot be empty.")] - EmptyText, - - #[error("embedding failed: {0:?}")] - EmbeddingFailed(Vec), -} - -#[derive(Serialize)] -struct EmbeddingRequest<'em> { - text: &'em [&'em str], -} - -#[derive(Deserialize)] -struct EmbeddingResponse { - result: EmbeddingResult, - errors: Vec, - success: bool, -} - -#[derive(Deserialize)] -struct EmbeddingResult { - data: Vec>, -} - -pub struct Cloudflare { - pub account_id: String, - pub api_token: String, - pub model: String, -} - -impl Cloudflare { - pub fn new(config: umem_config::Cloudflare) -> Self { - Self { - account_id: config.account_id, - api_token: config.api_token, - model: config.model, - } - } -} - -#[async_trait] -impl EmbedderBase for Cloudflare { - async fn generate_embedding(&self, text: &str) -> crate::Result> { - if text.trim().is_empty() { - Err(CloudflareError::EmptyText)?; - } - let url = format!( - "https://api.cloudflare.com/client/v4/accounts/{}/ai/run/{}", - self.account_id, self.model - ); - let request_body = EmbeddingRequest { text: &[text] }; - let response = client - .post(&url) - .header("Authorization", format!("Bearer {}", self.api_token)) - .json(&request_body) - .send() - .await?; - let mut embedding_response: EmbeddingResponse = response.json().await?; - if !embedding_response.success { - Err(CloudflareError::EmbeddingFailed(embedding_response.errors))?; - } - if embedding_response.result.data.is_empty() { - Err(CloudflareError::EmbeddingFailed(vec!["No embedding data - returned" - .to_string()]))?; - } - - Ok(std::mem::take(&mut embedding_response.result.data[0])) - } - - async fn generate_embeddings(&self, texts: &[&str]) -> crate::Result>> { - if texts.is_empty() || texts.iter().any(|text| text.trim().is_empty()) { - Err(CloudflareError::EmptyText)?; - } - - let url = format!( - "https://api.cloudflare.com/client/v4/accounts/{}/ai/run/{}", - self.account_id, self.model - ); - let request_body = EmbeddingRequest { text: texts }; - let response = client - .post(&url) - .header("Authorization", format!("Bearer {}", self.api_token)) - .json(&request_body) - .send() - .await?; - let embedding_response: EmbeddingResponse = response.json().await?; - if !embedding_response.success { - Err(CloudflareError::EmbeddingFailed(embedding_response.errors))?; - } - - if embedding_response.result.data.is_empty() { - Err(CloudflareError::EmbeddingFailed(vec!["No embedding data - returned" - .to_string()]))?; - } - - Ok(embedding_response.result.data) - } -} diff --git a/crates/umem_embeddings/src/lib.rs b/crates/umem_embeddings/src/lib.rs deleted file mode 100644 index 6493b8d..0000000 --- a/crates/umem_embeddings/src/lib.rs +++ /dev/null @@ -1,52 +0,0 @@ -use async_trait::async_trait; -use lazy_static::lazy_static; -use reqwest::Client; -mod cloudflare; -use cloudflare::{Cloudflare, CloudflareError}; -use umem_config::CONFIG; - -use std::sync::Arc; -use thiserror::Error; -use tokio::sync::OnceCell; - -#[derive(Debug, Error)] -pub enum EmbedderError { - #[error("cloudflare embedder failed : {0}")] - CloudflareError(#[from] CloudflareError), - - #[error("reqwest failed : {0}")] - ReqwestError(#[from] reqwest::Error), -} - -lazy_static! { - static ref client: Client = Client::new(); -} - -static EMBEDDER: OnceCell> = OnceCell::const_new(); - -#[derive(Copy, Clone)] -pub struct Embedder; - -type Result = std::result::Result; - -impl Embedder { - pub async fn get_embedder() -> Result> { - EMBEDDER - .get_or_try_init(|| async { - match CONFIG.embedding_model.clone() { - umem_config::EmbeddingModel::Cloudflare(config) => { - let cloudflare = Cloudflare::new(config); - Ok(Arc::new(cloudflare) as Arc) - } - } - }) - .await - .cloned() - } -} - -#[async_trait] -pub trait EmbedderBase { - async fn generate_embedding(&self, text: &str) -> Result>; - async fn generate_embeddings(&self, text: &[&str]) -> Result>>; -} diff --git a/crates/umem_memory_machine/Cargo.toml b/crates/umem_memory_machine/Cargo.toml index ba212f2..f5148bc 100644 --- a/crates/umem_memory_machine/Cargo.toml +++ b/crates/umem_memory_machine/Cargo.toml @@ -4,7 +4,6 @@ version = "0.1.0" edition = "2021" [dependencies] -umem_embeddings = { workspace = true } umem_controller = { workspace = true } umem_vector_store = { workspace = true } umem_config = { workspace = true } diff --git a/crates/umem_memory_machine/src/lib.rs b/crates/umem_memory_machine/src/lib.rs index f8ba21f..16ec305 100644 --- a/crates/umem_memory_machine/src/lib.rs +++ b/crates/umem_memory_machine/src/lib.rs @@ -2,19 +2,18 @@ use std::sync::Arc; use thiserror::Error; use typed_builder::TypedBuilder; -use umem_ai::{LanguageModel, LanguageModelError, RerankingModel, RerankingModelError}; +use umem_ai::{ + EmbeddingModel, EmbeddingModelError, LanguageModel, LanguageModelError, RerankingModel, + RerankingModelError, +}; use umem_config::CONFIG; use umem_controller::MemoryController; -use umem_embeddings::{Embedder, EmbedderBase, EmbedderError}; use umem_grpc_server::MemoryServiceGrpc; use umem_mcp::MemoryServiceMcp; use umem_vector_store::{VectorStore, VectorStoreBase, VectorStoreError}; #[derive(Debug, Error)] pub enum MemoryMachineError { - #[error("memory machine embedder failed : {0}")] - EmbedderError(#[from] EmbedderError), - #[error("memory machine vector_store failed : {0}")] VectorStoreError(#[from] VectorStoreError), @@ -23,6 +22,9 @@ pub enum MemoryMachineError { #[error("memory machine llm failed : {0}")] LanguageModelError(#[from] LanguageModelError), + + #[error("memory machine embedding failed : {0}")] + EmbeddingModelError(#[from] EmbeddingModelError), } #[derive(TypedBuilder)] @@ -33,7 +35,7 @@ pub struct MemoryMachine { #[derive(TypedBuilder)] pub struct MemoryMachineOptions { vector_store: Option>, - embedder: Option>, + embedder: Option>, reranking_model: Option>, language_model: Option>, } @@ -42,7 +44,7 @@ impl MemoryMachine { pub async fn new() -> Result { Ok(Self { memory_controller: MemoryController { - embedder: Embedder::get_embedder().await?, + embedding_model: EmbeddingModel::get_model().await?, vector_store: VectorStore::get_store().await?, reranking_model: RerankingModel::get_model().await?, language_model: LanguageModel::get_model().await?, @@ -51,7 +53,9 @@ impl MemoryMachine { } pub async fn new_with(options: MemoryMachineOptions) -> Result { - let embedder = options.embedder.unwrap_or(Embedder::get_embedder().await?); + let embedding_model = options + .embedder + .unwrap_or(EmbeddingModel::get_model().await?); let vector_store = options .vector_store .unwrap_or(VectorStore::get_store().await?); @@ -64,7 +68,7 @@ impl MemoryMachine { Ok(Self { memory_controller: MemoryController { - embedder, + embedding_model, vector_store, reranking_model, language_model,