From e00cb1dcedec2d82fc5d4f2a2657078de8e28e5f Mon Sep 17 00:00:00 2001 From: Sangit Manandhar Date: Wed, 21 Jan 2026 00:38:22 -0500 Subject: [PATCH 1/2] chore: switch to rerank in umem_ai --- Cargo.lock | 50 +----- Cargo.toml | 2 - crates/umem_ai/src/lib.rs | 50 ++++++ crates/umem_config/src/lib.rs | 15 +- crates/umem_controller/Cargo.toml | 1 - crates/umem_controller/src/lib.rs | 5 +- crates/umem_controller/src/search_memory.rs | 56 +++++-- crates/umem_embeddings/src/lib.rs | 4 +- crates/umem_memory_machine/Cargo.toml | 1 - crates/umem_memory_machine/src/lib.rs | 15 +- crates/umem_rerank/Cargo.toml | 17 --- crates/umem_rerank/src/cohere.rs | 143 ------------------ crates/umem_rerank/src/lib.rs | 88 ----------- crates/umem_rerank/src/pinecone.rs | 137 ----------------- .../umem_rerank/umem_rerank_derive/Cargo.toml | 12 -- .../umem_rerank/umem_rerank_derive/src/lib.rs | 85 ----------- src/bin/seed_db.rs | 103 ++++++------- 17 files changed, 161 insertions(+), 623 deletions(-) delete mode 100644 crates/umem_rerank/Cargo.toml delete mode 100644 crates/umem_rerank/src/cohere.rs delete mode 100644 crates/umem_rerank/src/lib.rs delete mode 100644 crates/umem_rerank/src/pinecone.rs delete mode 100644 crates/umem_rerank/umem_rerank_derive/Cargo.toml delete mode 100644 crates/umem_rerank/umem_rerank_derive/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index ab4b2d3..0c11365 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3923,25 +3923,6 @@ dependencies = [ "serde_derive", ] -[[package]] -name = "serde-saphyr" -version = "0.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef7e021087b4b525edcbb93045ca3f1dafeff7800b3c7960a90a75f197e257da" -dependencies = [ - "ahash", - "annotate-snippets", - "base64", - "encoding_rs_io", - "nohash-hasher", - "num-traits", - "ryu", - "saphyr-parser", - "serde", - "serde_json", - "smallvec 2.0.0-alpha.12", -] - [[package]] name = "serde-saphyr" version = "0.0.14" @@ -5152,7 +5133,6 @@ dependencies = [ "umem_grpc_server", "umem_mcp", "umem_memory_machine", - "umem_rerank", "umem_vector_store", ] @@ -5174,7 +5154,7 @@ dependencies = [ "rustc-hash 2.1.1", "schemars", "serde", - "serde-saphyr 0.0.14", + "serde-saphyr", "serde_json", "thiserror 2.0.17", "tokio", @@ -5229,7 +5209,6 @@ dependencies = [ "umem_core", "umem_embeddings", "umem_refine", - "umem_rerank", "umem_vector_store", "uuid", ] @@ -5339,7 +5318,6 @@ dependencies = [ "umem_embeddings", "umem_grpc_server", "umem_mcp", - "umem_rerank", "umem_vector_store", ] @@ -5365,32 +5343,6 @@ dependencies = [ "umem_ai", ] -[[package]] -name = "umem_rerank" -version = "0.1.0" -dependencies = [ - "async-trait", - "reqwest", - "serde", - "serde-saphyr 0.0.12", - "serde_json", - "thiserror 2.0.17", - "tokio", - "typed-builder", - "umem_config", - "umem_core", - "umem_rerank_derive", -] - -[[package]] -name = "umem_rerank_derive" -version = "0.1.0" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "umem_summarizer" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 1d915b1..bb8b7d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,6 @@ members = ["crates/*"] umem_core = {path = "crates/umem_core"} umem_memory_machine = {path = "crates/umem_memory_machine"} umem_refine = {path = "crates/umem_refine"} -umem_rerank = {path = "crates/umem_rerank"} umem_vector_store = {path = "crates/umem_vector_store"} umem_config = {path = "crates/umem_config"} umem_embeddings = {path = "crates/umem_embeddings"} @@ -56,7 +55,6 @@ thiserror = { workspace = true } typed-builder = { workspace = true } umem_config = { workspace = true } umem_controller = { workspace = true } -umem_rerank = { workspace = true } umem_ai = { workspace = true } umem_embeddings = { workspace = true } umem_vector_store = { workspace = true } diff --git a/crates/umem_ai/src/lib.rs b/crates/umem_ai/src/lib.rs index 1393ead..00668d6 100644 --- a/crates/umem_ai/src/lib.rs +++ b/crates/umem_ai/src/lib.rs @@ -86,12 +86,20 @@ impl LanguageModel { } } +#[derive(Error, Debug)] +pub enum RerankingModelError { + #[error("ai provider failed with : {0}")] + AIProviderError(#[from] AIProviderError), +} + #[derive(Debug, Clone)] pub struct RerankingModel { pub provider: Arc, pub model_name: String, } +static RERANKING_MODEL: OnceCell> = OnceCell::const_new(); + impl RerankingModel { fn new(provider: Arc, model_name: String) -> Self { Self { @@ -99,6 +107,48 @@ impl RerankingModel { model_name, } } + + pub async fn get_model() -> Result, LanguageModelError> { + RERANKING_MODEL + .get_or_try_init(|| async { + match CONFIG.reranking_model.provider.clone() { + umem_config::Provider::OpenAI(open_ai) => { + let openai_provider = OpenAIProvider::builder() + .api_key(open_ai.api_key) + .base_url(open_ai.base_url) + .default_headers(open_ai.default_headers.unwrap_or_default()) + .project(open_ai.project) + .organization(open_ai.organization) + .build(); + + let provider = Arc::new(AIProvider::from(openai_provider)); + + Ok(Arc::new(RerankingModel { + provider, + model_name: CONFIG.reranking_model.model.clone(), + })) + } + umem_config::Provider::AmazonBedrock(config) => { + let provider = AmazonBedrockProviderBuilder::default() + .region(config.region) + .access_key_id(config.key_id) + .secret_access_key(config.access_key) + .build() + .await + .map_err(|e| AIProviderError::ProviderBuilderError(e.into()))?; + + let provider = Arc::new(AIProvider::from(provider)); + + Ok(Arc::new(RerankingModel { + provider, + model_name: CONFIG.reranking_model.model.clone(), + })) + } + } + }) + .await + .cloned() + } } #[derive(Debug)] diff --git a/crates/umem_config/src/lib.rs b/crates/umem_config/src/lib.rs index c123eb4..e5058dc 100644 --- a/crates/umem_config/src/lib.rs +++ b/crates/umem_config/src/lib.rs @@ -55,7 +55,7 @@ pub struct Mcp { } #[derive(Debug, Deserialize, Clone)] -pub enum Embedder { +pub enum EmbeddingModel { #[serde(rename = "cloudflare")] Cloudflare(Cloudflare), } @@ -104,20 +104,17 @@ pub struct Cohere { } #[derive(Debug, Deserialize, Clone)] -pub enum Reranker { - #[serde(rename = "pinecone")] - Pinecone(Pinecone), - - #[serde(rename = "cohere")] - Cohere(Cohere), +pub struct RerankingModel { + pub provider: Provider, + pub model: String, } #[derive(Debug, Deserialize, Clone)] pub struct AppConfig { pub vector_store: VectorStore, - pub embedder: Embedder, + pub embedding_model: EmbeddingModel, pub language_model: LanguageModel, - pub reranker: Reranker, + pub reranking_model: RerankingModel, pub mcp: Mcp, pub grpc: Grpc, } diff --git a/crates/umem_controller/Cargo.toml b/crates/umem_controller/Cargo.toml index f2dbf13..cea6a41 100644 --- a/crates/umem_controller/Cargo.toml +++ b/crates/umem_controller/Cargo.toml @@ -9,7 +9,6 @@ umem_vector_store = { workspace = true} umem_ai = { workspace = true} umem_embeddings = { workspace = true} umem_refine = { workspace = true} -umem_rerank = { workspace = true} umem_annotations = { workspace = true} umem_core = { workspace = true} anyhow = { workspace = true } diff --git a/crates/umem_controller/src/lib.rs b/crates/umem_controller/src/lib.rs index 53f6423..5e0ce54 100644 --- a/crates/umem_controller/src/lib.rs +++ b/crates/umem_controller/src/lib.rs @@ -14,9 +14,8 @@ pub use delete_memory::*; pub use get_memory::*; pub use list_memory::*; pub use search_memory::*; -use umem_ai::LanguageModel; +use umem_ai::{LanguageModel, RerankingModel}; use umem_embeddings::EmbedderBase; -use umem_rerank::RerankerBase; use umem_vector_store::VectorStoreBase; pub use update_memory::*; @@ -45,6 +44,6 @@ pub enum MemoryControllerError { pub struct MemoryController { pub vector_store: Arc, pub embedder: Arc, - pub reranker: Arc, + pub reranking_model: Arc, pub language_model: Arc, } diff --git a/crates/umem_controller/src/search_memory.rs b/crates/umem_controller/src/search_memory.rs index ba38901..2b998b7 100644 --- a/crates/umem_controller/src/search_memory.rs +++ b/crates/umem_controller/src/search_memory.rs @@ -4,10 +4,12 @@ use thiserror::Error; use tokio::{sync::AcquireError, task::JoinError}; use tracing::info; use typed_builder::TypedBuilder; +use umem_ai::{ + rerank, RerankRequest, RerankRequestBuilderError, RerankingModelError, ResponseGeneratorError, +}; use umem_core::{Memory, MemoryContext, MemoryContextError, Query}; use umem_embeddings::{EmbedderBase, EmbedderError}; use umem_refine::{RefineError, Segmenter}; -use umem_rerank::RerankError; use umem_vector_store::VectorStoreError; #[derive(Debug, Error)] @@ -31,7 +33,13 @@ pub enum SearchMemoryError { JoinError(#[from] JoinError), #[error("rerank action failed with: {0}")] - RerankError(#[from] RerankError), + RerankingModelError(#[from] RerankingModelError), + + #[error("rerank builder action failed with: {0}")] + RerankingModelBuilderError(#[from] RerankRequestBuilderError), + + #[error("rerank response action failed with: {0}")] + ResponseGeneratorError(#[from] ResponseGeneratorError), } #[derive(TypedBuilder, Default)] @@ -86,24 +94,29 @@ impl MemoryController { query: String, _options: Option, ) -> Result, SearchMemoryError> { - let vector_store = Arc::clone(&self.vector_store); - let embedder = Arc::clone(&self.embedder); - let reranker = Arc::clone(&self.reranker); - - let vector = embedder.generate_embedding(query.as_str()).await?; + let vector = self.embedder.generate_embedding(query.as_str()).await?; let vector_query = Query::builder() .vector(vector) .context(context) .limit(20) .build(); - let mut memories = vector_store.search(vector_query).await?; + let mut memories = self.vector_store.search(vector_query).await?; + + let documents: Vec = memories.iter().map(|m| m.get_summary().clone()).collect(); + let request = RerankRequest::builder() + .model(Arc::clone(&self.reranking_model)) + .documents(documents) + .query(query) + .top_k(6) + .build()?; - let data = reranker.rank(query, &memories, None).await?.data; + let rerank_response = rerank(request).await?; - let memories: Vec = data + let memories: Vec = rerank_response + .rankings .iter() - .map(|row| std::mem::take(&mut memories[row.index])) + .map(|row| std::mem::take(&mut memories[row.original_index])) .collect(); Ok(memories) @@ -132,7 +145,6 @@ impl MemoryController { let vector_store = Arc::clone(&self.vector_store); let embedder = Arc::clone(&self.embedder); - let reranker = Arc::clone(&self.reranker); let semaphore = Arc::new(Semaphore::new(8)); // limit concurrency (tune this!) let mut tasks: FuturesUnordered, SearchMemoryError>>> = @@ -174,15 +186,27 @@ impl MemoryController { let duration = start.elapsed(); info!("Searching time : {:?}", duration); + let documents: Vec = all_memories + .iter() + .map(|m| m.get_summary().clone()) + .collect(); + let request = RerankRequest::builder() + .model(Arc::clone(&self.reranking_model)) + .documents(documents) + .query(query) + .top_k(6) + .build()?; + let start = Instant::now(); - let data = reranker.rank(query, &all_memories, None).await?.data; + let rerank_response = rerank(request).await?; let duration = start.elapsed(); - info!("Ranking time : {:?}", duration); + info!("Reranking time : {:?}", duration); // TODO: hybrid search with "row.score" + other metrics - let memories: Vec = data + let memories: Vec = rerank_response + .rankings .iter() - .map(|row| std::mem::take(&mut all_memories[row.index])) + .map(|row| std::mem::take(&mut all_memories[row.original_index])) .collect(); Ok(memories) diff --git a/crates/umem_embeddings/src/lib.rs b/crates/umem_embeddings/src/lib.rs index 4c9ed74..6493b8d 100644 --- a/crates/umem_embeddings/src/lib.rs +++ b/crates/umem_embeddings/src/lib.rs @@ -33,8 +33,8 @@ impl Embedder { pub async fn get_embedder() -> Result> { EMBEDDER .get_or_try_init(|| async { - match CONFIG.embedder.clone() { - umem_config::Embedder::Cloudflare(config) => { + match CONFIG.embedding_model.clone() { + umem_config::EmbeddingModel::Cloudflare(config) => { let cloudflare = Cloudflare::new(config); Ok(Arc::new(cloudflare) as Arc) } diff --git a/crates/umem_memory_machine/Cargo.toml b/crates/umem_memory_machine/Cargo.toml index f2aced1..ba212f2 100644 --- a/crates/umem_memory_machine/Cargo.toml +++ b/crates/umem_memory_machine/Cargo.toml @@ -7,7 +7,6 @@ edition = "2021" umem_embeddings = { workspace = true } umem_controller = { workspace = true } umem_vector_store = { workspace = true } -umem_rerank = { workspace = true } umem_config = { workspace = true } umem_grpc_server = { workspace = true } umem_ai = { workspace = true } diff --git a/crates/umem_memory_machine/src/lib.rs b/crates/umem_memory_machine/src/lib.rs index 4358c8a..f8ba21f 100644 --- a/crates/umem_memory_machine/src/lib.rs +++ b/crates/umem_memory_machine/src/lib.rs @@ -2,13 +2,12 @@ use std::sync::Arc; use thiserror::Error; use typed_builder::TypedBuilder; -use umem_ai::{LanguageModel, LanguageModelError}; +use umem_ai::{LanguageModel, LanguageModelError, RerankingModel, RerankingModelError}; use umem_config::CONFIG; use umem_controller::MemoryController; use umem_embeddings::{Embedder, EmbedderBase, EmbedderError}; use umem_grpc_server::MemoryServiceGrpc; use umem_mcp::MemoryServiceMcp; -use umem_rerank::{RerankError, Reranker, RerankerBase}; use umem_vector_store::{VectorStore, VectorStoreBase, VectorStoreError}; #[derive(Debug, Error)] @@ -20,7 +19,7 @@ pub enum MemoryMachineError { VectorStoreError(#[from] VectorStoreError), #[error("memory machine rerank failed : {0}")] - RerankError(#[from] RerankError), + RerankingModelError(#[from] RerankingModelError), #[error("memory machine llm failed : {0}")] LanguageModelError(#[from] LanguageModelError), @@ -35,7 +34,7 @@ pub struct MemoryMachine { pub struct MemoryMachineOptions { vector_store: Option>, embedder: Option>, - reranker: Option>, + reranking_model: Option>, language_model: Option>, } @@ -45,7 +44,7 @@ impl MemoryMachine { memory_controller: MemoryController { embedder: Embedder::get_embedder().await?, vector_store: VectorStore::get_store().await?, - reranker: Reranker::get_reranker().await?, + reranking_model: RerankingModel::get_model().await?, language_model: LanguageModel::get_model().await?, }, }) @@ -56,7 +55,9 @@ impl MemoryMachine { let vector_store = options .vector_store .unwrap_or(VectorStore::get_store().await?); - let reranker = options.reranker.unwrap_or(Reranker::get_reranker().await?); + let reranking_model = options + .reranking_model + .unwrap_or(RerankingModel::get_model().await?); let language_model = options .language_model .unwrap_or(LanguageModel::get_model().await?); @@ -65,7 +66,7 @@ impl MemoryMachine { memory_controller: MemoryController { embedder, vector_store, - reranker, + reranking_model, language_model, }, }) diff --git a/crates/umem_rerank/Cargo.toml b/crates/umem_rerank/Cargo.toml deleted file mode 100644 index f642a84..0000000 --- a/crates/umem_rerank/Cargo.toml +++ /dev/null @@ -1,17 +0,0 @@ -[package] -name = "umem_rerank" -version = "0.1.0" -edition = "2021" - -[dependencies] -reqwest = { workspace = true } -serde = { workspace = true } -serde_json = { workspace = true } -async-trait = { workspace = true } -thiserror = { workspace = true } -typed-builder = { workspace = true } -umem_core = { workspace = true } -umem_config = { workspace = true } -tokio = { workspace = true } -umem_rerank_derive = { path = "./umem_rerank_derive" } -serde-saphyr = { workspace = true } diff --git a/crates/umem_rerank/src/cohere.rs b/crates/umem_rerank/src/cohere.rs deleted file mode 100644 index b45282b..0000000 --- a/crates/umem_rerank/src/cohere.rs +++ /dev/null @@ -1,143 +0,0 @@ -use async_trait::async_trait; -use reqwest::Client; -use serde::{Deserialize, Serialize}; -use serde_json::json; -use thiserror::Error; -use umem_core::{Memory, MemoryContent, MemoryKind}; - -use super::{ - DataRow, RerankError, RerankOptions, RerankOptionsError, RerankResponse, RerankerBase, -}; - -#[derive(Debug, Deserialize)] -pub struct CohereResponse { - pub results: Vec, -} - -#[derive(Debug, Deserialize)] -pub struct CohereDataRow { - pub index: usize, - pub relevance_score: f32, -} - -impl From for DataRow { - fn from(value: CohereDataRow) -> Self { - Self { - index: value.index, - score: value.relevance_score, - } - } -} - -#[derive(Error, Debug)] -pub enum CohereError { - #[error("pinecone api call failed with: {0}")] - ReqwestError(#[from] reqwest::Error), - - #[error("rerank option action failed with: {0}")] - RerankOptionsError(#[from] RerankOptionsError), -} - -pub struct Cohere { - api_key: String, - model: String, -} - -impl Cohere { - pub fn new(config: umem_config::Cohere) -> Self { - Self { - api_key: config.api_key, - model: config.model, - } - } - - fn validate( - &self, - options: Option, - doc_len: usize, - ) -> Result<(), RerankOptionsError> { - if let Some(options) = options { - if let Some(top_n) = options.top_n { - if doc_len < top_n { - return Err(RerankOptionsError::InvalidTopN(top_n, doc_len)); - } - } - } - - Ok(()) - } -} - -#[derive(Debug, Deserialize, Serialize)] -struct CohereMemory { - kind: MemoryKind, - content: MemoryContent, -} - -impl From<&Memory> for CohereMemory { - fn from(value: &Memory) -> Self { - Self { - kind: *value.kind(), - content: value.content().clone(), - } - } -} - -impl Cohere { - async fn rank_impl( - &self, - query: String, - documents: &[Memory], - options: Option, - ) -> Result { - let client = Client::new(); - - self.validate(options, documents.len())?; - - let documents: Vec = documents.iter().map(|m| m.into()).collect(); - let documents: Vec = documents - .iter() - .map(|m| serde_saphyr::to_string(m).unwrap()) - .collect(); - - let response = client - .post("https://api.cohere.com/v2/rerank") - .header("Content-Type", "application/json") - .header("Accept", "application/json") - .header("Authorization", &format!("bearer {}", &self.api_key)) - .json(&json!({ - "model": &self.model, - "query": query, - "documents": documents, - "top_n": 2, - })); - - let response = response.send().await?; - - match response.error_for_status() { - Ok(res) => { - let response: CohereResponse = res.json().await?; - Ok(response) - } - Err(err) => Err(err.into()), - } - } -} -impl From for RerankResponse { - fn from(value: CohereResponse) -> Self { - let data = value.results.into_iter().map(|d| d.into()).collect(); - Self { data } - } -} - -#[async_trait] -impl RerankerBase for Cohere { - async fn rank( - &self, - query: String, - documents: &[Memory], - options: Option, - ) -> Result { - Ok(self.rank_impl(query, documents, options).await?.into()) - } -} diff --git a/crates/umem_rerank/src/lib.rs b/crates/umem_rerank/src/lib.rs deleted file mode 100644 index 24ea0d6..0000000 --- a/crates/umem_rerank/src/lib.rs +++ /dev/null @@ -1,88 +0,0 @@ -use std::sync::Arc; - -use async_trait::async_trait; -use serde::Deserialize; -use thiserror::Error; -use tokio::sync::OnceCell; -use typed_builder::TypedBuilder; -use umem_config::CONFIG; -use umem_core::Memory; - -mod cohere; -mod pinecone; - -pub use cohere::*; -pub use pinecone::*; - -static RERANKER: OnceCell> = OnceCell::const_new(); - -#[derive(Copy, Clone)] -pub struct Reranker; - -impl Reranker { - pub async fn get_reranker() -> Result, RerankError> { - RERANKER - .get_or_try_init(|| async { - match CONFIG.reranker.clone() { - umem_config::Reranker::Pinecone(config) => { - let pinecone = Pinecone::new(config); - Ok(Arc::new(pinecone) as Arc) - } - umem_config::Reranker::Cohere(config) => { - let cohere = Cohere::new(config); - Ok(Arc::new(cohere) as Arc) - } - } - }) - .await - .cloned() - } -} - -#[derive(TypedBuilder, Debug, Default, Deserialize)] -pub struct RerankOptions { - #[builder(default = None)] - pub top_n: Option, - - #[builder(default = None)] - pub structured_input: Option, -} - -#[derive(Error, Debug)] -pub enum RerankOptionsError { - #[error("top_n : {0} cannot be more than doc_len: {1}")] - InvalidTopN(usize, usize), - - #[error("field '{0}' cannot be used for '{1}'")] - InvalidField(String, String), -} - -#[derive(Error, Debug)] -pub enum RerankError { - #[error("pinecone failed with : {0}")] - PineconeError(#[from] PineconeError), - - #[error("cohere failed with : {0}")] - CohereError(#[from] CohereError), -} - -#[derive(Debug, Deserialize)] -pub struct RerankResponse { - pub data: Vec, -} - -#[derive(Debug, Deserialize)] -pub struct DataRow { - pub index: usize, - pub score: f32, -} - -#[async_trait] -pub trait RerankerBase { - async fn rank( - &self, - query: String, - documents: &[Memory], - options: Option, - ) -> Result; -} diff --git a/crates/umem_rerank/src/pinecone.rs b/crates/umem_rerank/src/pinecone.rs deleted file mode 100644 index 23be61b..0000000 --- a/crates/umem_rerank/src/pinecone.rs +++ /dev/null @@ -1,137 +0,0 @@ -use super::{ - DataRow, RerankError, RerankOptions, RerankOptionsError, RerankResponse, RerankerBase, -}; -use async_trait::async_trait; -use reqwest::Client; -use serde::Deserialize; -use serde_json::json; -use thiserror::Error; -use umem_core::Memory; - -#[derive(Debug, Deserialize)] -pub struct PineconeResponse { - pub data: Vec, -} - -#[derive(Debug, Deserialize)] -pub struct PineconeDataRow { - pub index: usize, - pub score: f32, -} - -impl From for DataRow { - fn from(value: PineconeDataRow) -> Self { - Self { - index: value.index, - score: value.score, - } - } -} - -#[derive(Error, Debug)] -pub enum PineconeError { - #[error("pinecone api call failed with: {0}")] - ReqwestError(#[from] reqwest::Error), - - #[error("rerank option action failed with: {0}")] - RerankOptionsError(#[from] RerankOptionsError), -} - -pub struct Pinecone { - api_key: String, - model: String, -} - -impl Pinecone { - pub fn new(config: umem_config::Pinecone) -> Self { - Self { - api_key: config.api_key, - model: config.model, - } - } - - fn validate( - &self, - options: Option, - doc_len: usize, - ) -> Result<(), RerankOptionsError> { - if let Some(options) = options { - if let Some(top_n) = options.top_n { - if doc_len < top_n { - return Err(RerankOptionsError::InvalidTopN(top_n, doc_len)); - } - } - if options.structured_input.is_some() { - return Err(RerankOptionsError::InvalidField( - "structured_input".to_string(), - "Pinecone".to_string(), - )); - } - } - - Ok(()) - } -} - -impl Pinecone { - async fn rank_impl( - &self, - query: String, - documents: &[Memory], - options: Option, - ) -> Result { - let client = Client::new(); - - self.validate(options, documents.len())?; - - let documents: Vec = documents - .iter() - .map(|m| json!({"text": m.content().summary()})) - .collect(); - - let response = client - .post("https://api.pinecone.io/rerank") - .header("Content-Type", "application/json") - .header("Accept", "application/json") - .header("X-Pinecone-Api-Version", "2025-10") - .header("Api-Key", &self.api_key) - .json(&json!({ - "model": &self.model, - "query": query, - "documents": documents, - "top_n": 2, - "return_documents": false, - "parameters": { - "truncate": "END" - } - })); - - let response = response.send().await?; - - match response.error_for_status() { - Ok(res) => { - let response: PineconeResponse = res.json().await?; - Ok(response) - } - Err(err) => Err(err.into()), - } - } -} -impl From for RerankResponse { - fn from(value: PineconeResponse) -> Self { - let data = value.data.into_iter().map(|d| d.into()).collect(); - Self { data } - } -} - -#[async_trait] -impl RerankerBase for Pinecone { - async fn rank( - &self, - query: String, - documents: &[Memory], - options: Option, - ) -> Result { - Ok(self.rank_impl(query, documents, options).await?.into()) - } -} diff --git a/crates/umem_rerank/umem_rerank_derive/Cargo.toml b/crates/umem_rerank/umem_rerank_derive/Cargo.toml deleted file mode 100644 index edc7cc8..0000000 --- a/crates/umem_rerank/umem_rerank_derive/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "umem_rerank_derive" -version = "0.1.0" -edition = "2021" - -[lib] -proc-macro = true - -[dependencies] -syn = { version = "2.0", features = ["full"] } -quote = "1.0" -proc-macro2 = "1.0" diff --git a/crates/umem_rerank/umem_rerank_derive/src/lib.rs b/crates/umem_rerank/umem_rerank_derive/src/lib.rs deleted file mode 100644 index d814631..0000000 --- a/crates/umem_rerank/umem_rerank_derive/src/lib.rs +++ /dev/null @@ -1,85 +0,0 @@ -use proc_macro::TokenStream; -use quote::quote; -use syn::{parse_macro_input, Data, DeriveInput, Fields, Type}; - -#[proc_macro_derive(Rerankable, attributes(rerank_field))] -pub fn derive_rerankable(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as DeriveInput); - let name = &input.ident; - - let fields = match &input.data { - Data::Struct(data) => match &data.fields { - Fields::Named(fields) => &fields.named, - _ => { - return syn::Error::new_spanned( - &input, - "Rerankable can only be derived for structs with named fields", - ) - .to_compile_error() - .into(); - } - }, - _ => { - return syn::Error::new_spanned(&input, "Rerankable can only be derived for structs") - .to_compile_error() - .into(); - } - }; - - let marked_fields: Vec<_> = fields - .iter() - .filter(|f| { - f.attrs - .iter() - .any(|attr| attr.path().is_ident("rerank_field")) - }) - .collect(); - - if marked_fields.is_empty() { - return syn::Error::new_spanned( - &input, - "exactly one field must be marked with #[rerank_field]", - ) - .to_compile_error() - .into(); - } - - if marked_fields.len() > 1 { - return syn::Error::new_spanned( - &marked_fields[1].ident, - "only one field can be marked with #[rerank_field]", - ) - .to_compile_error() - .into(); - } - - let field = marked_fields[0]; - let field_name = field.ident.as_ref().unwrap().to_string(); - - let is_string = match &field.ty { - Type::Path(type_path) => type_path - .path - .segments - .last() - .map(|seg| seg.ident == "String") - .unwrap_or(false), - _ => false, - }; - - if !is_string { - return syn::Error::new_spanned( - &field.ty, - "#[rerank_field] can only be applied to a field of type String", - ) - .to_compile_error() - .into(); - } - - let expanded = quote! { - impl umem_rerank::Rerankable for #name { - const RANK_FIELD: &'static str = #field_name; - } - }; - - TokenStream::from(expanded) -} diff --git a/src/bin/seed_db.rs b/src/bin/seed_db.rs index ea11095..5607725 100644 --- a/src/bin/seed_db.rs +++ b/src/bin/seed_db.rs @@ -1,66 +1,67 @@ -use std::{ - fs::File, - io::{BufRead, BufReader}, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, -}; +fn main() {} +// use std::{ +// fs::File, +// io::{BufRead, BufReader}, +// sync::{ +// atomic::{AtomicUsize, Ordering}, +// Arc, +// }, +// }; -use anyhow::Result; -use dotenv::dotenv; -use tracing::info; -use umem::tracing_conf; -use umem_controller::{CreateMemoryRequest, MemoryController}; +// use anyhow::Result; +// use dotenv::dotenv; +// use tracing::info; +// use umem::tracing_conf; +// use umem_controller::{CreateMemoryRequest, MemoryController}; -const CONCURRENCY_LIMIT: usize = 10; +// const CONCURRENCY_LIMIT: usize = 10; -#[tokio::main] -async fn main() -> Result<()> { - // dotenv().ok(); - // let _guard = tracing_conf::init_tracing()?; +// #[tokio::main] +// async fn main() -> Result<()> { +// // dotenv().ok(); +// // let _guard = tracing_conf::init_tracing()?; - // let file = File::open("data_chatgpt/conversations_general.jsonl")?; - // let reader = BufReader::new(file); - // let lines: Vec = reader.lines().collect::>>()?; +// // let file = File::open("data_chatgpt/conversations_general.jsonl")?; +// // let reader = BufReader::new(file); +// // let lines: Vec = reader.lines().collect::>>()?; - // let total = lines.len(); - // info!("loaded {total} chat records, processing with {CONCURRENCY_LIMIT} concurrent tasks"); +// // let total = lines.len(); +// // info!("loaded {total} chat records, processing with {CONCURRENCY_LIMIT} concurrent tasks"); - // let semaphore = Arc::new(tokio::sync::Semaphore::new(CONCURRENCY_LIMIT)); - // let completed = Arc::new(AtomicUsize::new(0)); +// // let semaphore = Arc::new(tokio::sync::Semaphore::new(CONCURRENCY_LIMIT)); +// // let completed = Arc::new(AtomicUsize::new(0)); - // let mut tasks = Vec::new(); +// // let mut tasks = Vec::new(); - // for raw_chats in lines { - // let sem = semaphore.clone(); - // let completed = completed.clone(); +// // for raw_chats in lines { +// // let sem = semaphore.clone(); +// // let completed = completed.clone(); - // let task = tokio::spawn(async move { - // let _permit = sem.acquire().await?; +// // let task = tokio::spawn(async move { +// // let _permit = sem.acquire().await?; - // let _ = MemoryController::create( - // CreateMemoryRequest::builder() - // .user_id(Some("harry".to_string())) - // .raw_content(raw_chats) - // .build(), - // None, - // ) - // .await; +// // let _ = MemoryController::create( +// // CreateMemoryRequest::builder() +// // .user_id(Some("harry".to_string())) +// // .raw_content(raw_chats) +// // .build(), +// // None, +// // ) +// // .await; - // let count = completed.fetch_add(1, Ordering::SeqCst) + 1; - // info!("Progress: {count}/{total} completed"); +// // let count = completed.fetch_add(1, Ordering::SeqCst) + 1; +// // info!("Progress: {count}/{total} completed"); - // Ok::<(), anyhow::Error>(()) - // }); +// // Ok::<(), anyhow::Error>(()) +// // }); - // tasks.push(task); - // } +// // tasks.push(task); +// // } - // for task in tasks { - // task.await??; - // } +// // for task in tasks { +// // task.await??; +// // } - // info!("all {total} chat records processed successfully"); - Ok(()) -} +// // info!("all {total} chat records processed successfully"); +// Ok(()) +// } From b8f11defbac8c9b5561c219856d1dd656418bc67 Mon Sep 17 00:00:00 2001 From: Sangit Manandhar Date: Wed, 21 Jan 2026 00:46:59 -0500 Subject: [PATCH 2/2] chore: rm unused structs --- crates/umem_config/src/lib.rs | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/crates/umem_config/src/lib.rs b/crates/umem_config/src/lib.rs index e5058dc..6ef9e4e 100644 --- a/crates/umem_config/src/lib.rs +++ b/crates/umem_config/src/lib.rs @@ -91,18 +91,6 @@ pub enum VectorStore { PgVector(PgVector), } -#[derive(Debug, Deserialize, Clone)] -pub struct Pinecone { - pub api_key: String, - pub model: String, -} - -#[derive(Debug, Deserialize, Clone)] -pub struct Cohere { - pub api_key: String, - pub model: String, -} - #[derive(Debug, Deserialize, Clone)] pub struct RerankingModel { pub provider: Provider,