Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
518 changes: 400 additions & 118 deletions Cargo.lock

Large diffs are not rendered by default.

16 changes: 16 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ members = ["crates/*"]

[workspace.dependencies]
umem_core = {path = "crates/umem_core"}
umem_memory_machine = {path = "crates/umem_memory_machine"}
umem_refine = {path = "crates/umem_refine"}
umem_rerank = {path = "crates/umem_rerank"}
umem_vector_store = {path = "crates/umem_vector_store"}
umem_config = {path = "crates/umem_config"}
umem_embeddings = {path = "crates/umem_embeddings"}
Expand Down Expand Up @@ -40,17 +43,30 @@ sqlx = { version = "0.8.6" , features = [ "postgres", "uuid", "runtime-tokio",
thiserror = "2.0.17"
typed-builder = "0.23.2"
schemars = "1.1.0"
futures = "0.3.31"
rayon = "1.11.0"
yaml-rust2 = "0.11.0"
serde-saphyr = "0.0.12"

[dependencies]
umem_grpc_server = { workspace = true }
# Inherit from [workspace.dependencies] like the other umem_* crates so the path is declared in one place.
umem_memory_machine = { workspace = true }
umem_mcp = { workspace = true }
thiserror = { workspace = true }
typed-builder = { workspace = true }
umem_config = { workspace = true }
umem_controller = { workspace = true }
umem_rerank = { workspace = true }
umem_ai = { workspace = true }
umem_embeddings = { workspace = true }
umem_vector_store = { workspace = true }
anyhow = { workspace = true }
tokio = { workspace = true }
tracing-subscriber = { workspace = true }
tracing = { workspace = true }
tracing-appender = { workspace = true }
dirs = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
rayon = { workspace = true }
dotenv = { version = "0.15.0" }
4 changes: 4 additions & 0 deletions TODO.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
### Errors

- Eradicate `anyhow`
- Fix the current error hierarchy
77 changes: 50 additions & 27 deletions crates/umem_ai/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ pub use response_generators::*;
use schemars::JsonSchema;
use serde::{Serialize, de::DeserializeOwned};
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::OnceCell;
use umem_config::CONFIG;

pub type HashMap<K, V> = rustc_hash::FxHashMap<K, V>;
Expand All @@ -20,11 +22,19 @@ lazy_static! {
static ref reqwest_client: reqwest::Client = reqwest::Client::new();
}

/// Error type returned by `LanguageModel::get_model`.
#[derive(Error, Debug)]
pub enum LanguageModelError {
/// Wraps an [`AIProviderError`]; `#[from]` lets `?` convert provider errors into this variant.
#[error("ai provider failed with : {0}")]
AIProviderError(#[from] AIProviderError),
}

/// A model name paired with the provider it is served by.
pub struct LanguageModel {
/// Shared handle to the provider backend.
pub provider: Arc<AIProvider>,
/// Name of the model; `get_model` fills this from `CONFIG.language_model.model`.
pub model_name: String,
}

/// Shared [`LanguageModel`] cell; starts empty and is filled by `LanguageModel::get_model` through `get_or_try_init`.
static LANGUAGE_MODEL: OnceCell<Arc<LanguageModel>> = OnceCell::const_new();

impl LanguageModel {
fn new(provider: Arc<AIProvider>, model_name: String) -> Self {
Self {
Expand All @@ -33,8 +43,46 @@ impl LanguageModel {
}
}

pub fn get_model() -> Arc<LanguageModel> {
Arc::clone(&LANGUAGE_MODEL)
/// Returns the shared [`LanguageModel`], building it from [`CONFIG`] if the
/// `LANGUAGE_MODEL` cell has not been initialised yet.
///
/// # Errors
///
/// Returns [`LanguageModelError`] when the Amazon Bedrock provider builder fails.
pub async fn get_model() -> Result<Arc<LanguageModel>, LanguageModelError> {
    LANGUAGE_MODEL
        .get_or_try_init(|| async {
            // The arms differ only in how the provider is built; the model itself
            // is assembled once, after the match.
            let backend = match CONFIG.language_model.provider.clone() {
                umem_config::Provider::OpenAI(settings) => AIProvider::from(
                    OpenAIProvider::builder()
                        .api_key(settings.api_key)
                        .base_url(settings.base_url)
                        .default_headers(settings.default_headers.unwrap_or_default())
                        .project(settings.project)
                        .organization(settings.organization)
                        .build(),
                ),
                umem_config::Provider::AmazonBedrock(settings) => {
                    let bedrock = AmazonBedrockProviderBuilder::default()
                        .region(settings.region)
                        .access_key_id(settings.key_id)
                        .secret_access_key(settings.access_key)
                        .build()
                        .await
                        .map_err(|err| AIProviderError::ProviderBuilderError(err.into()))?;

                    AIProvider::from(bedrock)
                }
            };

            Ok(Arc::new(LanguageModel {
                provider: Arc::new(backend),
                model_name: CONFIG.language_model.model.clone(),
            }))
        })
        .await
        .cloned()
}
}

Expand All @@ -51,10 +99,6 @@ impl RerankingModel {
model_name,
}
}

/// Returns a new `Arc` handle to the shared `LANGUAGE_MODEL`.
// NOTE(review): this appears to sit inside `impl RerankingModel` yet returns a `LanguageModel` — confirm that is intended.
pub fn get_model() -> Arc<LanguageModel> {
Arc::clone(&LANGUAGE_MODEL)
}
}

#[derive(Debug)]
Expand All @@ -68,27 +112,6 @@ pub enum AIProvider {
Cohere(CohereProvider),
}

// Global `LanguageModel` built from `CONFIG.language_model` via `lazy_static!`.
lazy_static! {
// Only the `OpenAI` provider variant is matched here.
static ref LANGUAGE_MODEL: Arc<LanguageModel> = match CONFIG.language_model.provider.clone() {
umem_config::Provider::OpenAI(open_ai) => {
// Missing `default_headers` fall back to the type's default value.
let openai_provider = OpenAIProvider::builder()
.api_key(open_ai.api_key)
.base_url(open_ai.base_url)
.default_headers(open_ai.default_headers.unwrap_or_default())
.project(open_ai.project)
.organization(open_ai.organization)
.build();

let provider = Arc::new(AIProvider::from(openai_provider));

Arc::new(LanguageModel {
provider,
model_name: CONFIG.language_model.model_name.clone(),
})
}
};
}

impl AIProvider {
pub(crate) async fn do_generate_text(
&self,
Expand Down
6 changes: 3 additions & 3 deletions crates/umem_ai/src/providers/amazon_bedrock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -674,9 +674,9 @@ mod tests {
async fn test_bedrock_generate_object() {
let provider = Arc::new(AIProvider::from(
AmazonBedrockProviderBuilder::default()
.region("REGION")
.access_key_id("ACESS_KEY_ID")
.secret_access_key("SECRET_ACCESS_KEY")
.region("us-west-2")
.access_key_id("test_access_key_id")
.secret_access_key("test_secret_access_key")
.build()
.await
.unwrap(),
Expand Down
4 changes: 3 additions & 1 deletion crates/umem_ai/src/providers/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ pub struct AnthropicProvider {
#[builder(default = "https://api.anthropic.com/v1".into())]
pub base_url: String,

#[builder(default = HeaderMap::default(), setter(transform = |value: Vec<(String, String)>| utils::build_header_map(value.as_slice()).unwrap_or_default()))]
#[builder(default = HeaderMap::default(), setter(transform = |value: Vec<(String, String)>|
utils::build_header_map(value.as_slice()).unwrap_or_default()
))]
pub headers: HeaderMap,
}

Expand Down
16 changes: 15 additions & 1 deletion crates/umem_ai/src/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,21 @@ mod google_vertex;
mod openai;
mod xai;

pub use amazon_bedrock::AmazonBedrockProvider;
use thiserror::Error;

/// Top-level error type for AI providers.
#[derive(Error, Debug)]
pub enum AIProviderError {
/// A provider could not be built; wraps the underlying [`ProviderBuilderError`] (convertible via `#[from]`).
#[error("provider build failed with : {0}")]
ProviderBuilderError(#[from] ProviderBuilderError),
}

/// Errors produced while building a provider.
#[derive(Error, Debug)]
pub enum ProviderBuilderError {
/// The Amazon Bedrock builder failed; wraps its `AmazonBedrockProviderBuilderError` (convertible via `#[from]`).
#[error("amazon bedrock provider build failed with : {0}")]
AmazonBedrockProviderBuilderError(#[from] AmazonBedrockProviderBuilderError),
}

pub use amazon_bedrock::*;
pub use anthropic::AnthropicProvider;
pub use azure_openai::AzureOpenAIProvider;
pub use cohere::CohereProvider;
Expand Down
4 changes: 3 additions & 1 deletion crates/umem_ai/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ pub struct OpenAIProvider {
#[builder(default = "https://api.openai.com/v1".into(), setter(transform = |value: impl Into<String>| value.into()))]
pub base_url: String,

#[builder(default, setter(transform = |value: Vec<(String, String)>| utils::build_header_map(value.as_slice()).unwrap_or_default()))]
#[builder(default, setter(transform = |value: Vec<(String, String)>|
utils::build_header_map(value.as_slice()).unwrap_or_default()
))]
pub default_headers: HeaderMap,

#[builder(default)]
Expand Down
6 changes: 6 additions & 0 deletions crates/umem_allocator/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[package]
name = "umem_allocator"
version = "0.1.0"
edition = "2021"

[dependencies]
Empty file.
1 change: 1 addition & 0 deletions crates/umem_annotations/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ umem_core = { workspace = true}
thiserror = { workspace = true }
typed-builder = { workspace = true }
schemars = { workspace = true }
tracing = { workspace = true }
serde = { workspace = true }
61 changes: 23 additions & 38 deletions crates/umem_annotations/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use thiserror::Error;
use umem_core::{MemoryContent, MemoryKind, MemorySignals, Provenance};
use umem_core::{
MemoryContent,
MemoryKind,
// MemorySignals, Provenance
};

use umem_ai::{
GenerateObjectRequestBuilder, GenerateObjectRequestBuilderError,
Expand All @@ -23,22 +27,28 @@ pub enum AnnotationError {
}

const ANNOTATION_PROMPT: &str = r#"
You are a memory annotation system. Your task is to analyze raw content and extract structured memory metadata that can be stored and retrieved efficiently.
Given the input content, produce a JSON object with the following structure:
## content
### summary
Extract the key points from the content as a concise, information-dense summary. Requirements:
You are a memory annotation system. Your task is to analyze a chat session between a user and an AI agent, then extract structured memory metadata that can be stored and retrieved efficiently.

## Input
A conversation transcript containing user messages and agent responses. Focus on extracting what the user learned, decided, asked about, or expressed preferences for—not the back-and-forth dialogue itself.

## Output

### content.summary
Extract the key points from the conversation as a concise, information-dense summary. Requirements:
- Preserve specific details: names, dates, numbers, URLs, technical terms, and concrete values
- Focus on actionable information, facts, preferences, and decisions
- Omit filler words, pleasantries, and redundant context
- Focus on actionable information, facts, preferences, and decisions made by the user
- Capture the resolution or answer, not the process of arriving at it
- Omit filler words, pleasantries, and redundant back-and-forth
- Use clear, direct language
- 1-4 sentences depending on content complexity
### tags

### content.tags
Extract 3-7 lowercase keywords that categorize and index this memory:
- Use singular forms (e.g., "project" not "projects")
- Include domain-specific terms, proper nouns (lowercased), and action verbs where relevant
- Prioritize terms useful for future retrieval
## kind

### kind
Classify the memory into exactly one type:
- **Semantic**: General knowledge, facts, concepts, definitions, explanations
- **Episodic**: Specific events, experiences, occurrences with temporal or spatial context
Expand All @@ -47,39 +57,14 @@ Classify the memory into exactly one type:
- **Relational**: Information about people, organizations, entities, and their relationships
- **Working**: Temporary context relevant only to an ongoing task or session
- **Prospective**: Future intentions, goals, plans, reminders, scheduled commitments
## signals
### certainty (0.0 - 1.0)
How confident you are in the accuracy and completeness of the extracted information:
- **0.9-1.0**: Explicitly stated facts, direct quotes, unambiguous instructions
- **0.6-0.8**: Reasonable inferences, implied information with strong context
- **0.3-0.5**: Uncertain interpretations, partial information, ambiguous statements
- **0.1-0.2**: Speculative, conflicting information, or low-quality source
### salience (0.0 - 1.0)
How important, actionable, or frequently needed this memory is likely to be:
- **0.9-1.0**: Critical preferences, key decisions, important instructions, identity-defining information
- **0.6-0.8**: Useful facts, notable events, relevant context for future tasks
- **0.3-0.5**: Background information, minor details, situational context
- **0.1-0.2**: Trivial details, transient information, low future utility
Note: certainty and salience cannot both be 0.0.
## provenance
### origin
Infer the source of the content:
- **User**: Content originates from a human user (input, message, document they provided)
- **Agent**: Content generated by an AI system, automated process, or tool output
### method
How this memory was derived:
- **Direct**: The content is being stored as-is or with minimal transformation
- **Extracted**: Information was extracted/summarized from larger content (include model name and system prompt)
- **Summarized**: Content was condensed from a longer source (include model name)
---
"#;

// NOTE(review): presumably the object the model is asked to emit for `ANNOTATION_PROMPT` — confirm against the caller.
/// Structured annotation that derives a JSON schema (`schemars::JsonSchema`) and serde (de)serialisation.
#[derive(Clone, schemars::JsonSchema, Serialize, Deserialize)]
pub struct LLMAnnotated {
/// Annotated memory content.
pub content: MemoryContent,
/// Memory kind classification.
pub kind: MemoryKind,
pub signals: MemorySignals,
pub provenance: Provenance,
// pub signals: MemorySignals,
// pub provenance: Provenance,
}

impl Annotation {
Expand Down
Loading