Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
820 changes: 764 additions & 56 deletions Cargo.lock

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion crates/umem_ai/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "umem_ai"
version = "0.1.0"
edition = "2021"
edition = "2024"

[dependencies]
anyhow.workspace = true
Expand All @@ -20,3 +20,8 @@ tokio.workspace = true
tracing.workspace = true
typed-builder.workspace = true
umem_config = {workspace = true}
aws-config = { version = "1.1.7", features = ["behavior-version-latest"] }
aws-sdk-bedrockruntime = "1.120.0"
aws-smithy-types = {version = "1.3.5", features = ["serde-deserialize", "serde-serialize", "rt-tokio"]}
serde-saphyr = "0.0.14"
aws-sdk-bedrockagentruntime = "1.119.0"
77 changes: 71 additions & 6 deletions crates/umem_ai/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,16 @@
#![allow(dead_code)]
mod providers;
mod response_generators;
pub(crate) mod utils;

use std::sync::Arc;
mod utils;

use anyhow::Result;
use async_trait::async_trait;
use lazy_static::lazy_static;
use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Serialize};

pub use providers::*;
pub use response_generators::*;
use schemars::JsonSchema;
use serde::{Serialize, de::DeserializeOwned};
use std::sync::Arc;
use umem_config::CONFIG;

pub type HashMap<K, V> = rustc_hash::FxHashMap<K, V>;
Expand All @@ -40,13 +38,34 @@ impl LanguageModel {
}
}

/// A reranking model, described by the provider that serves it and the
/// name of the model to use with that provider.
#[derive(Debug, Clone)]
pub struct RerankingModel {
/// Shared, reference-counted handle to the provider behind this model.
pub provider: Arc<AIProvider>,
/// Name of the model. NOTE(review): presumably the provider-specific model
/// identifier sent with each request — confirm against the provider code.
pub model_name: String,
}

impl RerankingModel {
fn new(provider: Arc<AIProvider>, model_name: String) -> Self {
Self {
provider,
model_name,
}
}

pub fn get_model() -> Arc<LanguageModel> {
Arc::clone(&LANGUAGE_MODEL)
}
}

/// The AI backends this crate can wrap; each variant carries the matching
/// provider value.
///
/// Not every operation is available for every variant. In this file the
/// reranking entry points (`do_reranking`, `do_structured_reranking`) dispatch
/// only to `Cohere` and `AmazonBedrock`; every other variant reaches
/// `unimplemented!()` and panics.
#[derive(Debug)]
pub enum AIProvider {
/// Wraps an `OpenAIProvider`.
OpenAI(OpenAIProvider),
/// Wraps an `AzureOpenAIProvider`.
AzureOpenAI(AzureOpenAIProvider),
/// Wraps a `GoogleVertexAIProvider`.
GoogleVertexAI(GoogleVertexAIProvider),
/// Wraps an `AnthropicProvider`.
Anthropic(AnthropicProvider),
/// Wraps an `XAIProvider`.
XAI(XAIProvider),
/// Wraps an `AmazonBedrockProvider`.
AmazonBedrock(AmazonBedrockProvider),
/// Wraps a `CohereProvider`.
Cohere(CohereProvider),
}

lazy_static! {
Expand Down Expand Up @@ -77,6 +96,7 @@ impl AIProvider {
) -> Result<GenerateTextResponse, ResponseGeneratorError> {
match self {
AIProvider::OpenAI(provider) => provider.generate_text(request),
AIProvider::AmazonBedrock(provider) => provider.generate_text(request),
_ => unimplemented!(),
}
.await
Expand All @@ -90,10 +110,37 @@ impl AIProvider {
) -> Result<GenerateObjectResponse<T>, ResponseGeneratorError> {
match self {
AIProvider::OpenAI(provider) => provider.generate_object(request),
AIProvider::AmazonBedrock(provider) => provider.generate_object(request),
_ => unimplemented!(),
}
.await
}

/// Forwards a rerank request to the wrapped provider and awaits its result.
///
/// Only the `Cohere` and `AmazonBedrock` variants are dispatched; the
/// provider's `rerank` result is returned unchanged.
///
/// # Panics
///
/// Panics via `unimplemented!()` when called on any other variant.
pub(crate) async fn do_reranking(
    &self,
    request: RerankRequest,
) -> Result<RerankResponse, ResponseGeneratorError> {
    match self {
        AIProvider::Cohere(cohere) => cohere.rerank(request).await,
        AIProvider::AmazonBedrock(bedrock) => bedrock.rerank(request).await,
        _ => unimplemented!(),
    }
}

/// Forwards a structured rerank request over items of type `T` to the
/// wrapped provider and awaits its result.
///
/// Only the `Cohere` and `AmazonBedrock` variants are dispatched; the
/// provider's `rerank_structured` result is returned unchanged.
///
/// # Panics
///
/// Panics via `unimplemented!()` when called on any other variant.
pub(crate) async fn do_structured_reranking<T>(
    &self,
    request: StructuredRerankRequest<T>,
) -> Result<StructuredRerankResponse<T>, ResponseGeneratorError>
where
    T: Serialize + Clone + Send + Sync,
{
    match self {
        AIProvider::Cohere(cohere) => cohere.rerank_structured(request).await,
        AIProvider::AmazonBedrock(bedrock) => bedrock.rerank_structured(request).await,
        _ => unimplemented!(),
    }
}
}

#[async_trait]
Expand All @@ -112,6 +159,24 @@ pub trait GeneratesObject {
) -> Result<GenerateObjectResponse<T>, ResponseGeneratorError>;
}

/// Capability trait for providers that can serve a [`RerankRequest`].
///
/// NOTE(review): `AIProvider::do_reranking` calls `.rerank(request)` on the
/// Cohere and Amazon Bedrock providers, presumably through this trait —
/// confirm in the provider modules.
#[async_trait]
pub trait Reranks {
/// Asynchronously handles `request`, yielding a `RerankResponse` on success
/// or a `ResponseGeneratorError` on failure.
async fn rerank(
&self,
request: RerankRequest,
) -> Result<RerankResponse, ResponseGeneratorError>;
}

/// Capability trait for providers that can serve a
/// [`StructuredRerankRequest`] over caller-defined item types.
///
/// NOTE(review): `AIProvider::do_structured_reranking` calls
/// `.rerank_structured(request)` on the Cohere and Amazon Bedrock providers,
/// presumably through this trait — confirm in the provider modules.
#[async_trait]
pub trait ReranksStructuredData {
/// Asynchronously handles `request`, yielding a
/// `StructuredRerankResponse<T>` on success or a `ResponseGeneratorError`
/// on failure.
///
/// `T` must be serializable, cloneable, and safe to send and share across
/// threads, as required by the `where` clause below.
async fn rerank_structured<T>(
&self,
request: StructuredRerankRequest<T>,
) -> Result<StructuredRerankResponse<T>, ResponseGeneratorError>
where
T: Serialize + Clone + Send + Sync;
}

impl From<OpenAIProvider> for AIProvider {
fn from(config: OpenAIProvider) -> Self {
AIProvider::OpenAI(config)
Expand Down
Loading