Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 0 additions & 16 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ umem_memory_machine = {path = "crates/umem_memory_machine"}
umem_refine = {path = "crates/umem_refine"}
umem_vector_store = {path = "crates/umem_vector_store"}
umem_config = {path = "crates/umem_config"}
umem_embeddings = {path = "crates/umem_embeddings"}
umem_annotations = {path = "crates/umem_annotations"}
umem_grpc_server = {path = "crates/umem_grpc_server"}
umem_web_scrapper = {path = "crates/umem_web_scrapper"}
Expand Down Expand Up @@ -56,7 +55,6 @@ typed-builder = { workspace = true }
umem_config = { workspace = true }
umem_controller = { workspace = true }
umem_ai = { workspace = true }
umem_embeddings = { workspace = true }
umem_vector_store = { workspace = true }
anyhow = { workspace = true }
tokio = { workspace = true }
Expand Down
61 changes: 59 additions & 2 deletions crates/umem_ai/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ use thiserror::Error;
use tokio::sync::OnceCell;
use umem_config::CONFIG;

use crate::response_generators::embed::{EmbeddingRequest, EmbeddingResponse};

pub type HashMap<K, V> = rustc_hash::FxHashMap<K, V>;

lazy_static! {
Expand Down Expand Up @@ -153,12 +151,71 @@ impl RerankingModel {
}
}

#[derive(Error, Debug)]
pub enum EmbeddingModelError {
#[error("ai provider failed with : {0}")]
AIProviderError(#[from] AIProviderError),
}

#[derive(Debug, Clone)]
pub struct EmbeddingModel {
pub provider: Arc<AIProvider>,
pub model_name: String,
}

static EMBEDDING_MODEL: OnceCell<Arc<EmbeddingModel>> = OnceCell::const_new();

impl EmbeddingModel {
fn new(provider: Arc<AIProvider>, model_name: String) -> Self {
Self {
provider,
model_name,
}
}

pub async fn get_model() -> Result<Arc<EmbeddingModel>, LanguageModelError> {
EMBEDDING_MODEL
.get_or_try_init(|| async {
match CONFIG.embedding_model.provider.clone() {
umem_config::Provider::OpenAI(open_ai) => {
let openai_provider = OpenAIProvider::builder()
.api_key(open_ai.api_key)
.base_url(open_ai.base_url)
.default_headers(open_ai.default_headers.unwrap_or_default())
.project(open_ai.project)
.organization(open_ai.organization)
.build();

let provider = Arc::new(AIProvider::from(openai_provider));

Ok(Arc::new(EmbeddingModel {
provider,
model_name: CONFIG.embedding_model.model.clone(),
}))
}
umem_config::Provider::AmazonBedrock(config) => {
let provider = AmazonBedrockProviderBuilder::default()
.region(config.region)
.access_key_id(config.key_id)
.secret_access_key(config.access_key)
.build()
.await
.map_err(|e| AIProviderError::ProviderBuilderError(e.into()))?;

let provider = Arc::new(AIProvider::from(provider));

Ok(Arc::new(EmbeddingModel {
provider,
model_name: CONFIG.embedding_model.model.clone(),
}))
}
}
})
.await
.cloned()
}
Comment thread
vidurkhanal marked this conversation as resolved.
}

#[derive(Debug)]
pub enum AIProvider {
OpenAI(OpenAIProvider),
Expand Down
12 changes: 7 additions & 5 deletions crates/umem_ai/src/providers/amazon_bedrock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
Embeds, GenerateObjectRequest, GenerateObjectResponse, GeneratesObject, GeneratesText,
OpenAIProvider, Ranking, RerankRequest, RerankResponse, Reranks, ReranksStructuredData,
SerializationMode, StructuredRanking, StructuredRerankRequest, StructuredRerankResponse,
embed::{EmbeddingRequest, EmbeddingResponse, embed as embedFn},
embed::{EmbeddingRequest, EmbeddingResponse},
messages::{FilePart, UserModelMessage},
response_generators::{
self, GenerateTextRequest, GenerateTextResponse, ResponseGeneratorError,
Expand Down Expand Up @@ -30,7 +30,7 @@ use futures::future::join_all;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use serde_json::Map;
use std::{error::Error, sync::Arc};
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::Semaphore;

Expand Down Expand Up @@ -770,8 +770,8 @@ mod tests {
use super::*;
use crate::{
AIProvider, EmbeddingModel, GenerateObjectRequestBuilder, GenerateTextRequestBuilder,
LanguageModel, RerankingModel, SerializationFormat, embed, generate_object, generate_text,
rerank, structured_rerank,
LanguageModel, RerankingModel, SerializationFormat, generate_object, generate_text, rerank,
structured_rerank,
};
use serde::Deserialize;
use std::sync::Arc;
Expand Down Expand Up @@ -998,6 +998,8 @@ mod tests {

#[tokio::test]
async fn embedding_test() {
use crate::embed::embed;

let provider = Arc::new(AIProvider::from(
AmazonBedrockProviderBuilder::default()
.region("REGION")
Expand All @@ -1022,7 +1024,7 @@ mod tests {
])
.build();

let embedding_response = embedFn(request).await.unwrap();
let embedding_response = embed(request).await.unwrap();
dbg!(&embedding_response);
}
}
4 changes: 1 addition & 3 deletions crates/umem_ai/src/response_generators/embed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ use backon::{ExponentialBuilder, Retryable};
use reqwest::header::HeaderMap;
use std::{sync::Arc, time::Duration};

pub async fn embed(
mut request: EmbeddingRequest,
) -> Result<EmbeddingResponse, ResponseGeneratorError> {
pub async fn embed(request: EmbeddingRequest) -> Result<EmbeddingResponse, ResponseGeneratorError> {
let per_request_timeout = request.timeout;
let max_retries = request.max_retries;
let total_delay = per_request_timeout.mul_f32(max_retries as f32 / 2.0);
Expand Down
1 change: 1 addition & 0 deletions crates/umem_ai/src/response_generators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub mod messages;
pub mod rerank;
pub mod structured_rerank;

pub use embed::*;
pub use generate_object::*;
pub use generate_text::*;
pub use messages::*;
Expand Down
19 changes: 6 additions & 13 deletions crates/umem_config/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,6 @@ use lazy_static::lazy_static;
use serde::Deserialize;
use std::net::SocketAddr;

#[derive(Debug, Deserialize, Clone)]
pub struct Cloudflare {
pub account_id: String,
pub api_token: String,
pub model: String,
}

#[derive(Debug, Deserialize, Clone)]
pub struct OpenAI {
pub api_key: String,
Expand All @@ -35,6 +28,12 @@ pub enum Provider {
AmazonBedrock(AmazonBedrock),
}

#[derive(Debug, Deserialize, Clone)]
pub struct EmbeddingModel {
pub provider: Provider,
pub model: String,
}

#[derive(Debug, Deserialize, Clone)]
pub struct LanguageModel {
pub provider: Provider,
Expand All @@ -54,12 +53,6 @@ pub struct Mcp {
pub work_os: WorkOs,
}

#[derive(Debug, Deserialize, Clone)]
pub enum EmbeddingModel {
#[serde(rename = "cloudflare")]
Cloudflare(Cloudflare),
}

#[derive(Debug, Deserialize, Clone)]
pub struct WorkOs {
pub client_id: String,
Expand Down
1 change: 0 additions & 1 deletion crates/umem_controller/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ edition = "2021"
serde_json = { workspace = true }
umem_vector_store = { workspace = true}
umem_ai = { workspace = true}
umem_embeddings = { workspace = true}
umem_refine = { workspace = true}
umem_annotations = { workspace = true}
umem_core = { workspace = true}
Expand Down
37 changes: 27 additions & 10 deletions crates/umem_controller/src/create_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ use chrono::Utc;
use std::sync::Arc;
use thiserror::Error;
use typed_builder::TypedBuilder;
use umem_ai::{AIProviderError, LanguageModel};
use umem_ai::{
embed::{embed, EmbeddingRequest},
AIProviderError, EmbeddingModel, LanguageModel, ResponseGeneratorError,
};
use umem_annotations::{Annotation, AnnotationError, LLMAnnotated};
use umem_core::{
LifecycleState, Memory, MemoryContentError, MemoryContext, MemoryContextError, MemoryError,
TemporalMetadata,
};
use umem_embeddings::{EmbedderBase, EmbedderError};
use umem_vector_store::VectorStoreError;
use uuid::Uuid;

Expand All @@ -21,11 +23,11 @@ pub enum CreateMemoryError {
#[error("vector store action failed with: {0}")]
VectorStoreError(#[from] VectorStoreError),

#[error("embedder action failed with: {0}")]
EmbedderError(#[from] EmbedderError),

#[error("ai provider action failed with: {0}")]
AIProviderError(#[from] AIProviderError),

#[error("response generator action failed with: {0}")]
ResponseGeneratorError(#[from] ResponseGeneratorError),
}

#[derive(Debug, Error)]
Expand Down Expand Up @@ -120,9 +122,9 @@ impl CreateMemoryRequest {
#[derive(TypedBuilder, Default)]
pub struct CreateMemoryOptions {
#[builder(default = None)]
pub embedder: Option<Arc<dyn EmbedderBase + Send + Sync>>,
pub embedding_model: Option<Arc<EmbeddingModel>>,
#[builder(default = None)]
pub model: Option<Arc<LanguageModel>>,
pub language_model: Option<Arc<LanguageModel>>,
}

impl MemoryController {
Expand All @@ -140,12 +142,27 @@ impl MemoryController {
_options: Option<CreateMemoryOptions>,
) -> Result<Memory, CreateMemoryError> {
let vector_store = Arc::clone(&self.vector_store);
let embedder = Arc::clone(&self.embedder);
let embedding_model = Arc::clone(&self.embedding_model);
let language_model = Arc::clone(&self.language_model);

let memory = request.build(language_model).await?;
let vector = embedder.generate_embedding(memory.get_summary()).await?;
vector_store.insert(&[&vector], &[&memory]).await?;

let request = EmbeddingRequest::builder()
.model(embedding_model)
.input(vec![memory.get_summary().to_owned()])
.build();

let embedding_response = embed(request).await?;

//NOTE: change this later, just didin't want to fight with the drilled types
let slices: Vec<&[f32]> = embedding_response
.embeddings
.iter()
.map(|inner| inner.as_slice())
.collect();
let slice_of_slices: &[&[f32]] = &slices;
Comment thread
vidurkhanal marked this conversation as resolved.

vector_store.insert(slice_of_slices, &[&memory]).await?;
Ok(memory)
}
}
5 changes: 2 additions & 3 deletions crates/umem_controller/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ pub use delete_memory::*;
pub use get_memory::*;
pub use list_memory::*;
pub use search_memory::*;
use umem_ai::{LanguageModel, RerankingModel};
use umem_embeddings::EmbedderBase;
use umem_ai::{EmbeddingModel, LanguageModel, RerankingModel};
use umem_vector_store::VectorStoreBase;
pub use update_memory::*;

Expand Down Expand Up @@ -43,7 +42,7 @@ pub enum MemoryControllerError {
#[derive(Clone)]
pub struct MemoryController {
pub vector_store: Arc<dyn VectorStoreBase + Send + Sync>,
pub embedder: Arc<dyn EmbedderBase + Send + Sync>,
pub embedding_model: Arc<EmbeddingModel>,
pub reranking_model: Arc<RerankingModel>,
pub language_model: Arc<LanguageModel>,
}
Loading