Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 3 additions & 223 deletions crates/umem_ai/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,239 +1,19 @@
// TODO: remove this allow once the module is fully implemented
#![allow(dead_code)]
mod model_impl;
pub mod models;
mod providers;
mod response_generators;
mod utils;

use anyhow::Result;
use async_trait::async_trait;
use lazy_static::lazy_static;

pub use model_impl::*;
pub use models::*;
pub use providers::*;
pub use response_generators::*;
use schemars::JsonSchema;
use serde::{Serialize, de::DeserializeOwned};
use std::sync::Arc;
use thiserror::Error;

pub type HashMap<K, V> = rustc_hash::FxHashMap<K, V>;

lazy_static! {
static ref reqwest_client: reqwest::Client = reqwest::Client::new();
}

#[derive(Error, Debug)]
pub enum LanguageModelError {
#[error("ai provider failed with : {0}")]
AIProviderError(#[from] AIProviderError),
}

pub struct LanguageModel {
pub provider: Arc<AIProvider>,
pub model_name: String,
}

impl LanguageModel {
fn new(provider: Arc<AIProvider>, model_name: String) -> Self {
Self {
provider,
model_name,
}
}
}

#[derive(Error, Debug)]
pub enum RerankingModelError {
#[error("ai provider failed with : {0}")]
AIProviderError(#[from] AIProviderError),
}

#[derive(Debug, Clone)]
pub struct RerankingModel {
pub provider: Arc<AIProvider>,
pub model_name: String,
}

impl RerankingModel {
fn new(provider: Arc<AIProvider>, model_name: String) -> Self {
Self {
provider,
model_name,
}
}
}

#[derive(Error, Debug)]
pub enum EmbeddingModelError {
#[error("ai provider failed with : {0}")]
AIProviderError(#[from] AIProviderError),
}

#[derive(Debug, Clone)]
pub struct EmbeddingModel {
pub provider: Arc<AIProvider>,
pub model_name: String,
}

impl EmbeddingModel {
fn new(provider: Arc<AIProvider>, model_name: String) -> Self {
Self {
provider,
model_name,
}
}
}

#[derive(Debug)]
pub enum AIProvider {
OpenAI(OpenAIProvider),
AzureOpenAI(AzureOpenAIProvider),
GoogleVertexAI(GoogleVertexAIProvider),
Anthropic(AnthropicProvider),
XAI(XAIProvider),
AmazonBedrock(AmazonBedrockProvider),
Cohere(CohereProvider),
}

impl AIProvider {
pub(crate) async fn do_generate_text(
&self,
request: GenerateTextRequest,
) -> Result<GenerateTextResponse, ResponseGeneratorError> {
match self {
AIProvider::OpenAI(provider) => provider.generate_text(request),
AIProvider::AmazonBedrock(provider) => provider.generate_text(request),
_ => unimplemented!(),
}
.await
}

pub(crate) async fn do_generate_object<
T: Clone + JsonSchema + Serialize + Send + Sync + DeserializeOwned,
>(
&self,
request: GenerateObjectRequest<T>,
) -> Result<GenerateObjectResponse<T>, ResponseGeneratorError> {
match self {
AIProvider::OpenAI(provider) => provider.generate_object(request),
AIProvider::AmazonBedrock(provider) => provider.generate_object(request),
_ => unimplemented!(),
}
.await
}

pub(crate) async fn do_reranking(
&self,
request: RerankRequest,
) -> Result<RerankResponse, ResponseGeneratorError> {
match self {
AIProvider::Cohere(provider) => provider.rerank(request),
AIProvider::AmazonBedrock(provider) => provider.rerank(request),
_ => unimplemented!(),
}
.await
}

pub(crate) async fn do_structured_reranking<T>(
&self,
request: StructuredRerankRequest<T>,
) -> Result<StructuredRerankResponse<T>, ResponseGeneratorError>
where
T: Serialize + Clone + Send + Sync,
{
match self {
AIProvider::Cohere(provider) => provider.rerank_structured(request).await,
AIProvider::AmazonBedrock(provider) => provider.rerank_structured(request).await,
_ => unimplemented!(),
}
}

pub(crate) async fn do_embed(
&self,
request: EmbeddingRequest,
) -> Result<EmbeddingResponse, ResponseGeneratorError> {
match self {
AIProvider::AmazonBedrock(provider) => provider.embed(request),
_ => unimplemented!(),
}
.await
}
}

#[async_trait]
pub trait GeneratesText {
async fn generate_text(
&self,
request: GenerateTextRequest,
) -> Result<GenerateTextResponse, ResponseGeneratorError>;
}

#[async_trait]
pub trait GeneratesObject {
async fn generate_object<T: Clone + JsonSchema + Serialize + Send + Sync + DeserializeOwned>(
&self,
request: GenerateObjectRequest<T>,
) -> Result<GenerateObjectResponse<T>, ResponseGeneratorError>;
}

#[async_trait]
pub trait Reranks {
async fn rerank(
&self,
request: RerankRequest,
) -> Result<RerankResponse, ResponseGeneratorError>;
}

#[async_trait]
pub trait ReranksStructuredData {
async fn rerank_structured<T>(
&self,
request: StructuredRerankRequest<T>,
) -> Result<StructuredRerankResponse<T>, ResponseGeneratorError>
where
T: Serialize + Clone + Send + Sync;
}

#[async_trait]
pub trait Embeds {
async fn embed(
&self,
request: EmbeddingRequest,
) -> Result<EmbeddingResponse, ResponseGeneratorError>;
}

impl From<OpenAIProvider> for AIProvider {
fn from(config: OpenAIProvider) -> Self {
AIProvider::OpenAI(config)
}
}

impl From<AzureOpenAIProvider> for AIProvider {
fn from(config: AzureOpenAIProvider) -> Self {
AIProvider::AzureOpenAI(config)
}
}

impl From<AnthropicProvider> for AIProvider {
fn from(config: AnthropicProvider) -> Self {
AIProvider::Anthropic(config)
}
}

impl From<XAIProvider> for AIProvider {
fn from(config: XAIProvider) -> Self {
AIProvider::XAI(config)
}
}

impl From<AmazonBedrockProvider> for AIProvider {
fn from(config: AmazonBedrockProvider) -> Self {
AIProvider::AmazonBedrock(config)
}
}

impl From<GoogleVertexAIProvider> for AIProvider {
fn from(config: GoogleVertexAIProvider) -> Self {
AIProvider::GoogleVertexAI(config)
}
}
4 changes: 2 additions & 2 deletions crates/umem_ai/src/model_impl/embedding_model.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
AIProvider, AIProviderError, AmazonBedrockProviderBuilder, EmbeddingModel, EmbeddingModelError,
OpenAIProvider,
AIProvider, AIProviderError, AmazonBedrockProviderBuilder, OpenAIProvider,
models::{EmbeddingModel, EmbeddingModelError},
};
use std::sync::Arc;
use tokio::sync::OnceCell;
Expand Down
4 changes: 2 additions & 2 deletions crates/umem_ai/src/model_impl/language_model.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
AIProvider, AIProviderError, AmazonBedrockProviderBuilder, LanguageModel, LanguageModelError,
OpenAIProvider,
AIProvider, AIProviderError, AmazonBedrockProviderBuilder, OpenAIProvider,
models::{LanguageModel, LanguageModelError},
};
use std::sync::Arc;
use tokio::sync::OnceCell;
Expand Down
4 changes: 2 additions & 2 deletions crates/umem_ai/src/model_impl/reranking_model.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
AIProvider, AIProviderError, AmazonBedrockProviderBuilder, OpenAIProvider, RerankingModel,
RerankingModelError,
AIProvider, AIProviderError, AmazonBedrockProviderBuilder, OpenAIProvider,
models::{RerankingModel, RerankingModelError},
};
use std::sync::Arc;
use tokio::sync::OnceCell;
Expand Down
24 changes: 24 additions & 0 deletions crates/umem_ai/src/models/embedding.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use crate::{AIProvider, AIProviderError};
use std::sync::Arc;
use thiserror::Error;

#[derive(Error, Debug)]
pub enum EmbeddingModelError {
#[error("ai provider failed with : {0}")]
AIProviderError(#[from] AIProviderError),
}

#[derive(Debug, Clone)]
pub struct EmbeddingModel {
pub provider: Arc<AIProvider>,
pub model_name: String,
}

impl EmbeddingModel {
fn new(provider: Arc<AIProvider>, model_name: String) -> Self {
Self {
provider,
model_name,
}
}
}
23 changes: 23 additions & 0 deletions crates/umem_ai/src/models/language.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use crate::{AIProvider, AIProviderError};
use std::sync::Arc;
use thiserror::Error;

#[derive(Error, Debug)]
pub enum LanguageModelError {
#[error("ai provider failed with : {0}")]
AIProviderError(#[from] AIProviderError),
}

pub struct LanguageModel {
pub provider: Arc<AIProvider>,
pub model_name: String,
}

impl LanguageModel {
fn new(provider: Arc<AIProvider>, model_name: String) -> Self {
Self {
provider,
model_name,
}
}
}
7 changes: 7 additions & 0 deletions crates/umem_ai/src/models/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pub mod embedding;
pub mod language;
pub mod reranking;

pub use embedding::*;
pub use language::*;
pub use reranking::*;
24 changes: 24 additions & 0 deletions crates/umem_ai/src/models/reranking.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use crate::{AIProvider, AIProviderError};
use std::sync::Arc;
use thiserror::Error;

#[derive(Error, Debug)]
pub enum RerankingModelError {
#[error("ai provider failed with : {0}")]
AIProviderError(#[from] AIProviderError),
}

#[derive(Debug, Clone)]
pub struct RerankingModel {
pub provider: Arc<AIProvider>,
pub model_name: String,
}

impl RerankingModel {
fn new(provider: Arc<AIProvider>, model_name: String) -> Self {
Self {
provider,
model_name,
}
}
}
7 changes: 4 additions & 3 deletions crates/umem_ai/src/providers/amazon_bedrock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -769,9 +769,10 @@ impl AmazonBedrockProviderBuilder {
mod tests {
use super::*;
use crate::{
AIProvider, EmbeddingModel, GenerateObjectRequestBuilder, GenerateTextRequestBuilder,
LanguageModel, RerankingModel, SerializationFormat, generate_object, generate_text, rerank,
structured_rerank,
AIProvider, GenerateObjectRequestBuilder, GenerateTextRequestBuilder, SerializationFormat,
generate_object, generate_text,
models::{EmbeddingModel, LanguageModel, RerankingModel},
rerank, structured_rerank,
};
use serde::Deserialize;
use std::sync::Arc;
Expand Down
Loading