Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 1 addition & 49 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ members = ["crates/*"]
umem_core = {path = "crates/umem_core"}
umem_memory_machine = {path = "crates/umem_memory_machine"}
umem_refine = {path = "crates/umem_refine"}
umem_rerank = {path = "crates/umem_rerank"}
umem_vector_store = {path = "crates/umem_vector_store"}
umem_config = {path = "crates/umem_config"}
umem_embeddings = {path = "crates/umem_embeddings"}
Expand Down Expand Up @@ -56,7 +55,6 @@ thiserror = { workspace = true }
typed-builder = { workspace = true }
umem_config = { workspace = true }
umem_controller = { workspace = true }
umem_rerank = { workspace = true }
umem_ai = { workspace = true }
umem_embeddings = { workspace = true }
umem_vector_store = { workspace = true }
Expand Down
50 changes: 50 additions & 0 deletions crates/umem_ai/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,69 @@ impl LanguageModel {
}
}

/// Error type declared for reranking-model operations.
///
/// Its only variant wraps an `AIProviderError`; the `#[from]` attribute
/// lets `?` convert such an error into this type automatically.
#[derive(Error, Debug)]
pub enum RerankingModelError {
/// The underlying AI provider reported a failure.
#[error("ai provider failed with : {0}")]
AIProviderError(#[from] AIProviderError),
}

/// A reranking model: a shared AI provider handle paired with the name of
/// the model to use.
#[derive(Debug, Clone)]
pub struct RerankingModel {
/// Shared handle to the AI provider backing this model.
pub provider: Arc<AIProvider>,
/// Model name; `get_model` fills this from `CONFIG.reranking_model.model`.
pub model_name: String,
}

/// Cell holding the shared `RerankingModel`; it starts empty and is filled
/// by `RerankingModel::get_model` through `get_or_try_init`.
static RERANKING_MODEL: OnceCell<Arc<RerankingModel>> = OnceCell::const_new();

impl RerankingModel {
    /// Pairs an AI provider handle with the name of the model to use.
    fn new(provider: Arc<AIProvider>, model_name: String) -> Self {
        Self { provider, model_name }
    }

    /// Returns the shared reranking model, initialising `RERANKING_MODEL`
    /// from `CONFIG.reranking_model` through `get_or_try_init`.
    ///
    /// The provider variant in the configuration selects which backend is
    /// built (OpenAI or Amazon Bedrock); the model name is taken from
    /// `CONFIG.reranking_model.model` in both cases.
    ///
    /// # Errors
    ///
    /// A failure to build the Amazon Bedrock provider is wrapped in
    /// `AIProviderError::ProviderBuilderError` and propagated with `?`.
    ///
    /// NOTE(review): the error type here is `LanguageModelError` even though
    /// this file also declares `RerankingModelError` — confirm this is intended.
    pub async fn get_model() -> Result<Arc<RerankingModel>, LanguageModelError> {
        RERANKING_MODEL
            .get_or_try_init(|| async {
                // Each arm only differs in how the provider is built; the
                // model itself is assembled once, after the match.
                let provider = match CONFIG.reranking_model.provider.clone() {
                    umem_config::Provider::OpenAI(open_ai) => AIProvider::from(
                        OpenAIProvider::builder()
                            .api_key(open_ai.api_key)
                            .base_url(open_ai.base_url)
                            .default_headers(open_ai.default_headers.unwrap_or_default())
                            .project(open_ai.project)
                            .organization(open_ai.organization)
                            .build(),
                    ),
                    umem_config::Provider::AmazonBedrock(config) => {
                        let bedrock = AmazonBedrockProviderBuilder::default()
                            .region(config.region)
                            .access_key_id(config.key_id)
                            .secret_access_key(config.access_key)
                            .build()
                            .await
                            .map_err(|e| AIProviderError::ProviderBuilderError(e.into()))?;

                        AIProvider::from(bedrock)
                    }
                };

                Ok(Arc::new(Self::new(
                    Arc::new(provider),
                    CONFIG.reranking_model.model.clone(),
                )))
            })
            .await
            .cloned()
    }
}

#[derive(Debug)]
Expand Down
25 changes: 5 additions & 20 deletions crates/umem_config/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ pub struct Mcp {
}

#[derive(Debug, Deserialize, Clone)]
pub enum Embedder {
pub enum EmbeddingModel {
#[serde(rename = "cloudflare")]
Cloudflare(Cloudflare),
}
Expand Down Expand Up @@ -92,32 +92,17 @@ pub enum VectorStore {
}

#[derive(Debug, Deserialize, Clone)]
pub struct Pinecone {
pub api_key: String,
pub model: String,
}

#[derive(Debug, Deserialize, Clone)]
pub struct Cohere {
pub api_key: String,
pub struct RerankingModel {
pub provider: Provider,
pub model: String,
}

#[derive(Debug, Deserialize, Clone)]
pub enum Reranker {
#[serde(rename = "pinecone")]
Pinecone(Pinecone),

#[serde(rename = "cohere")]
Cohere(Cohere),
}

#[derive(Debug, Deserialize, Clone)]
pub struct AppConfig {
pub vector_store: VectorStore,
pub embedder: Embedder,
pub embedding_model: EmbeddingModel,
pub language_model: LanguageModel,
pub reranker: Reranker,
pub reranking_model: RerankingModel,
pub mcp: Mcp,
pub grpc: Grpc,
}
Expand Down
1 change: 0 additions & 1 deletion crates/umem_controller/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ umem_vector_store = { workspace = true}
umem_ai = { workspace = true}
umem_embeddings = { workspace = true}
umem_refine = { workspace = true}
umem_rerank = { workspace = true}
umem_annotations = { workspace = true}
umem_core = { workspace = true}
anyhow = { workspace = true }
Expand Down
5 changes: 2 additions & 3 deletions crates/umem_controller/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@ pub use delete_memory::*;
pub use get_memory::*;
pub use list_memory::*;
pub use search_memory::*;
use umem_ai::LanguageModel;
use umem_ai::{LanguageModel, RerankingModel};
use umem_embeddings::EmbedderBase;
use umem_rerank::RerankerBase;
use umem_vector_store::VectorStoreBase;
pub use update_memory::*;

Expand Down Expand Up @@ -45,6 +44,6 @@ pub enum MemoryControllerError {
pub struct MemoryController {
pub vector_store: Arc<dyn VectorStoreBase + Send + Sync>,
pub embedder: Arc<dyn EmbedderBase + Send + Sync>,
pub reranker: Arc<dyn RerankerBase + Send + Sync>,
pub reranking_model: Arc<RerankingModel>,
pub language_model: Arc<LanguageModel>,
}
56 changes: 40 additions & 16 deletions crates/umem_controller/src/search_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ use thiserror::Error;
use tokio::{sync::AcquireError, task::JoinError};
use tracing::info;
use typed_builder::TypedBuilder;
use umem_ai::{
rerank, RerankRequest, RerankRequestBuilderError, RerankingModelError, ResponseGeneratorError,
};
use umem_core::{Memory, MemoryContext, MemoryContextError, Query};
use umem_embeddings::{EmbedderBase, EmbedderError};
use umem_refine::{RefineError, Segmenter};
use umem_rerank::RerankError;
use umem_vector_store::VectorStoreError;

#[derive(Debug, Error)]
Expand All @@ -31,7 +33,13 @@ pub enum SearchMemoryError {
JoinError(#[from] JoinError),

#[error("rerank action failed with: {0}")]
RerankError(#[from] RerankError),
RerankingModelError(#[from] RerankingModelError),

#[error("rerank builder action failed with: {0}")]
RerankingModelBuilderError(#[from] RerankRequestBuilderError),

#[error("rerank response action failed with: {0}")]
ResponseGeneratorError(#[from] ResponseGeneratorError),
}

#[derive(TypedBuilder, Default)]
Expand Down Expand Up @@ -86,24 +94,29 @@ impl MemoryController {
query: String,
_options: Option<SearchMemoryOptions>,
) -> Result<Vec<Memory>, SearchMemoryError> {
let vector_store = Arc::clone(&self.vector_store);
let embedder = Arc::clone(&self.embedder);
let reranker = Arc::clone(&self.reranker);

let vector = embedder.generate_embedding(query.as_str()).await?;
let vector = self.embedder.generate_embedding(query.as_str()).await?;
let vector_query = Query::builder()
.vector(vector)
.context(context)
.limit(20)
.build();

let mut memories = vector_store.search(vector_query).await?;
let mut memories = self.vector_store.search(vector_query).await?;

let documents: Vec<String> = memories.iter().map(|m| m.get_summary().clone()).collect();
let request = RerankRequest::builder()
.model(Arc::clone(&self.reranking_model))
.documents(documents)
.query(query)
.top_k(6)
.build()?;

let data = reranker.rank(query, &memories, None).await?.data;
let rerank_response = rerank(request).await?;

let memories: Vec<Memory> = data
let memories: Vec<Memory> = rerank_response
.rankings
.iter()
.map(|row| std::mem::take(&mut memories[row.index]))
.map(|row| std::mem::take(&mut memories[row.original_index]))
.collect();

Ok(memories)
Expand Down Expand Up @@ -132,7 +145,6 @@ impl MemoryController {

let vector_store = Arc::clone(&self.vector_store);
let embedder = Arc::clone(&self.embedder);
let reranker = Arc::clone(&self.reranker);

let semaphore = Arc::new(Semaphore::new(8)); // limit concurrency (tune this!)
let mut tasks: FuturesUnordered<JoinHandle<Result<Vec<Memory>, SearchMemoryError>>> =
Expand Down Expand Up @@ -174,15 +186,27 @@ impl MemoryController {
let duration = start.elapsed();
info!("Searching time : {:?}", duration);

let documents: Vec<String> = all_memories
.iter()
.map(|m| m.get_summary().clone())
.collect();
let request = RerankRequest::builder()
.model(Arc::clone(&self.reranking_model))
.documents(documents)
.query(query)
.top_k(6)
.build()?;

let start = Instant::now();
let data = reranker.rank(query, &all_memories, None).await?.data;
let rerank_response = rerank(request).await?;
let duration = start.elapsed();
info!("Ranking time : {:?}", duration);
info!("Reranking time : {:?}", duration);

// TODO: hybrid search with "row.score" + other metrics
let memories: Vec<Memory> = data
let memories: Vec<Memory> = rerank_response
.rankings
.iter()
.map(|row| std::mem::take(&mut all_memories[row.index]))
.map(|row| std::mem::take(&mut all_memories[row.original_index]))
.collect();

Ok(memories)
Expand Down
4 changes: 2 additions & 2 deletions crates/umem_embeddings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ impl Embedder {
pub async fn get_embedder() -> Result<Arc<dyn EmbedderBase + Send + Sync>> {
EMBEDDER
.get_or_try_init(|| async {
match CONFIG.embedder.clone() {
umem_config::Embedder::Cloudflare(config) => {
match CONFIG.embedding_model.clone() {
umem_config::EmbeddingModel::Cloudflare(config) => {
let cloudflare = Cloudflare::new(config);
Ok(Arc::new(cloudflare) as Arc<dyn EmbedderBase + Send + Sync>)
}
Expand Down
1 change: 0 additions & 1 deletion crates/umem_memory_machine/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ edition = "2021"
umem_embeddings = { workspace = true }
umem_controller = { workspace = true }
umem_vector_store = { workspace = true }
umem_rerank = { workspace = true }
umem_config = { workspace = true }
umem_grpc_server = { workspace = true }
umem_ai = { workspace = true }
Expand Down
Loading