From 89f571ff253f0e4a956549ae15d06a9544cde00a Mon Sep 17 00:00:00 2001 From: Vidur Khanal Date: Mon, 29 Dec 2025 13:54:48 -0800 Subject: [PATCH 01/12] feat(ai/bedrock): add Amazon Bedrock provider support - Added `aws-config` and `aws-sdk-bedrockruntime` dependencies to enable Amazon Bedrock integration. - Implemented `AmazonBedrockProvider` with support for object generation via the Bedrock Converse API. - Updated `AIProvider` to route object generation requests to Bedrock when selected. - Introduced error handling for Bedrock-specific and deserialization errors. - Refactored provider builder for Bedrock, including region, credentials, and header management. - Added test scaffolding for Bedrock provider. - Updated error types and retry logic to handle new Bedrock error cases. --- Cargo.lock | 697 ++++++++++++++++-- crates/umem_ai/Cargo.toml | 2 + crates/umem_ai/src/lib.rs | 1 + .../umem_ai/src/providers/amazon_bedrock.rs | 299 +++++++- crates/umem_ai/src/providers/openai.rs | 14 +- crates/umem_ai/src/response_generators/mod.rs | 12 +- crates/umem_ai/src/utils.rs | 17 +- 7 files changed, 950 insertions(+), 92 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 562e124..522ff1c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -157,6 +157,402 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "aws-config" +version = "1.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96571e6996817bf3d58f6b569e4b9fd2e9d2fcf9f7424eed07b2ce9bb87535e5" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-sdk-sso", + "aws-sdk-ssooidc", + "aws-sdk-sts", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "hex", + "http 1.4.0", + "ring", + "time", + "tokio", + "tracing", + "url", + "zeroize", +] + +[[package]] +name = "aws-credential-types" +version = "1.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cd362783681b15d136480ad555a099e82ecd8e2d10a841e14dfd0078d67fee3" +dependencies = [ + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "zeroize", +] + +[[package]] +name = "aws-lc-rs" +version = "1.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a88aab2464f1f25453baa7a07c84c5b7684e274054ba06817f382357f77a288" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.35.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b45afffdee1e7c9126814751f88dddc747f41d91da16c9551a0f1e8a11e788a1" +dependencies = [ + "cc", + "cmake", + "dunce", + "fs_extra", +] + +[[package]] +name = "aws-runtime" +version = "1.5.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d81b5b2898f6798ad58f484856768bca817e3cd9de0974c24ae0f1113fe88f1b" +dependencies = [ + "aws-credential-types", + "aws-sigv4", + "aws-smithy-async", + "aws-smithy-eventstream", + "aws-smithy-http", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "http-body 0.4.6", + "percent-encoding", + "pin-project-lite", + "tracing", + "uuid", +] + +[[package]] +name = "aws-sdk-bedrockruntime" +version = "1.120.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15b8dcf42378ab2d5accac1652cdd059114fb071baf53250ceafb76fcdde347f" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-sigv4", + "aws-smithy-async", + "aws-smithy-eventstream", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "hyper 0.14.32", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sdk-sso" +version = "1.91.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ee6402a36f27b52fe67661c6732d684b2635152b676aa2babbfb5204f99115d" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sdk-ssooidc" +version = "1.93.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a45a7f750bbd170ee3677671ad782d90b894548f4e4ae168302c57ec9de5cb3e" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sdk-sts" +version = "1.95.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55542378e419558e6b1f398ca70adb0b2088077e79ad9f14eb09441f2f7b2164" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-query", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-smithy-xml", + "aws-types", + "fastrand", + "http 0.2.12", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sigv4" +version = "1.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69e523e1c4e8e7e8ff219d732988e22bfeae8a1cafdbe6d9eca1546fa080be7c" +dependencies = [ + "aws-credential-types", + "aws-smithy-eventstream", + "aws-smithy-http", + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "form_urlencoded", + "hex", + "hmac", + "http 0.2.12", + "http 1.4.0", + "percent-encoding", + "sha2", + "time", + "tracing", +] + +[[package]] +name = "aws-smithy-async" +version = "1.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ee19095c7c4dda59f1697d028ce704c24b2d33c6718790c7f1d5a3015b4107c" +dependencies = [ + "futures-util", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "aws-smithy-eventstream" +version = "0.60.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc12f8b310e38cad85cf3bef45ad236f470717393c613266ce0a89512286b650" +dependencies = [ + "aws-smithy-types", + "bytes", + "crc32fast", +] + +[[package]] +name = "aws-smithy-http" +version = "0.62.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "826141069295752372f8203c17f28e30c464d22899a43a0c9fd9c458d469c88b" +dependencies = [ + "aws-smithy-eventstream", + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "bytes-utils", + "futures-core", + "futures-util", + "http 0.2.12", + "http 1.4.0", + "http-body 0.4.6", + "percent-encoding", + "pin-project-lite", + "pin-utils", + "tracing", +] + +[[package]] +name = "aws-smithy-http-client" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59e62db736db19c488966c8d787f52e6270be565727236fd5579eaa301e7bc4a" +dependencies = [ + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "h2 0.3.27", + "h2 0.4.12", + "http 0.2.12", + "http 1.4.0", + "http-body 0.4.6", + "hyper 0.14.32", + "hyper 1.8.1", + "hyper-rustls 0.24.2", + "hyper-rustls 0.27.7", + "hyper-util", + "pin-project-lite", + "rustls 0.21.12", + "rustls 0.23.35", + "rustls-native-certs", + "rustls-pki-types", + "tokio", + "tokio-rustls 0.26.4", + "tower 0.5.2", + "tracing", +] + +[[package]] +name = "aws-smithy-json" +version = "0.61.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49fa1213db31ac95288d981476f78d05d9cbb0353d22cdf3472cc05bb02f6551" +dependencies = [ + "aws-smithy-types", +] + +[[package]] +name = "aws-smithy-observability" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17f616c3f2260612fe44cede278bafa18e73e6479c4e393e2c4518cf2a9a228a" +dependencies = [ + "aws-smithy-runtime-api", +] + +[[package]] +name = "aws-smithy-query" +version = "0.60.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae5d689cf437eae90460e944a58b5668530d433b4ff85789e69d2f2a556e057d" +dependencies = [ + "aws-smithy-types", + "urlencoding", +] + +[[package]] +name = "aws-smithy-runtime" +version = "1.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a392db6c583ea4a912538afb86b7be7c5d8887d91604f50eb55c262ee1b4a5f5" +dependencies = [ + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-http-client", + "aws-smithy-observability", + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "fastrand", + "http 0.2.12", + "http 1.4.0", + "http-body 0.4.6", + "http-body 1.0.1", + "pin-project-lite", + "pin-utils", + "tokio", + "tracing", +] + +[[package]] +name = "aws-smithy-runtime-api" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab0d43d899f9e508300e587bf582ba54c27a452dd0a9ea294690669138ae14a2" +dependencies = [ + "aws-smithy-async", + "aws-smithy-types", + "bytes", + "http 0.2.12", + "http 1.4.0", + "pin-project-lite", + "tokio", + "tracing", + "zeroize", +] + +[[package]] +name = "aws-smithy-types" +version = "1.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "905cb13a9895626d49cf2ced759b062d913834c7482c38e49557eac4e6193f01" +dependencies = [ + "base64-simd", + "bytes", + "bytes-utils", + "futures-core", + "http 0.2.12", + "http 1.4.0", + "http-body 0.4.6", + "http-body 1.0.1", + "http-body-util", + "itoa", + "num-integer", + "pin-project-lite", + "pin-utils", + "ryu", + "serde", + "time", + "tokio", + "tokio-util", +] + +[[package]] +name = "aws-smithy-xml" +version = "0.60.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11b2f670422ff42bf7065031e72b45bc52a3508bd089f743ea90731ca2b6ea57" +dependencies = [ + "xmlparser", +] + +[[package]] +name = "aws-types" +version = "1.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d980627d2dd7bfc32a3c025685a033eeab8d365cc840c631ef59d1b8f428164" +dependencies = [ + "aws-credential-types", + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "rustc_version", + "tracing", +] + [[package]] name = "axum" version = "0.7.9" @@ -167,8 +563,8 @@ dependencies = [ "axum-core 0.4.5", "bytes", "futures-util", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "http-body-util", "itoa", "matchit 0.7.3", @@ -195,10 +591,10 @@ dependencies = [ "bytes", "form_urlencoded", "futures-util", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "http-body-util", - "hyper", + "hyper 1.8.1", "hyper-util", "itoa", "matchit 0.8.4", @@ -227,8 +623,8 @@ dependencies = [ "async-trait", "bytes", "futures-util", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "http-body-util", "mime", "pin-project-lite", @@ -246,8 +642,8 @@ checksum = "59446ce19cd142f8833f856eb31f3eb097812d1479ab224f54d72428ca21ea22" dependencies = [ "bytes", "futures-core", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "http-body-util", "mime", "pin-project-lite", @@ -285,6 +681,16 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64-simd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "339abbe78e73178762e23bea9dfd08e697eb3f3301cd4be981c0f78ba5859195" +dependencies = [ + "outref", + "vsimd", +] + [[package]] name = "base64ct" version = "1.8.1" @@ -381,6 +787,16 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" +[[package]] +name = "bytes-utils" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dafe3a8757b027e2be6e4e5601ed563c55989fcf1546e933c66c8eb3a058d35" +dependencies = [ + "bytes", + "either", +] + [[package]] name = "cbc" version = "0.1.2" @@ -438,6 +854,15 @@ dependencies = [ "inout", ] +[[package]] +name = "cmake" +version = "0.1.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d" +dependencies = [ + "cc", +] + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -870,6 +1295,12 @@ version = "0.15.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + [[package]] name = "dyn-clone" version = "1.0.20" @@ -1038,6 +1469,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "futf" version = "0.1.5" @@ -1197,6 +1634,25 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "h2" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0beca50380b1fc32983fc1cb4587bfa4bb9e78fc259aad4a0032d2080309222d" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http 0.2.12", + "indexmap 2.12.1", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "h2" version = "0.4.12" @@ -1208,7 +1664,7 @@ dependencies = [ "fnv", "futures-core", "futures-sink", - "http", + "http 1.4.0", "indexmap 2.12.1", "slab", "tokio", @@ -1307,6 +1763,17 @@ dependencies = [ "syn", ] +[[package]] +name = "http" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http" version = "1.4.0" @@ -1317,6 +1784,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http-body" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" +dependencies = [ + "bytes", + "http 0.2.12", + "pin-project-lite", +] + [[package]] name = "http-body" version = "1.0.1" @@ -1324,7 +1802,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http", + "http 1.4.0", ] [[package]] @@ -1335,8 +1813,8 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", "futures-core", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "pin-project-lite", ] @@ -1352,6 +1830,30 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "hyper" +version = "0.14.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "h2 0.3.27", + "http 0.2.12", + "http-body 0.4.6", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "socket2 0.5.10", + "tokio", + "tower-service", + "tracing", + "want", +] + [[package]] name = "hyper" version = "1.8.1" @@ -1362,9 +1864,9 @@ dependencies = [ "bytes", "futures-channel", "futures-core", - "h2", - "http", - "http-body", + "h2 0.4.12", + "http 1.4.0", + "http-body 1.0.1", "httparse", "httpdate", "itoa", @@ -1375,19 +1877,35 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" +dependencies = [ + "futures-util", + "http 0.2.12", + "hyper 0.14.32", + "log", + "rustls 0.21.12", + "tokio", + "tokio-rustls 0.24.1", +] + [[package]] name = "hyper-rustls" version = "0.27.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" dependencies = [ - "http", - "hyper", + "http 1.4.0", + "hyper 1.8.1", "hyper-util", - "rustls", + "rustls 0.23.35", + "rustls-native-certs", "rustls-pki-types", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.4", "tower-service", "webpki-roots", ] @@ -1398,7 +1916,7 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0" dependencies = [ - "hyper", + "hyper 1.8.1", "hyper-util", "pin-project-lite", "tokio", @@ -1413,7 +1931,7 @@ checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" dependencies = [ "bytes", "http-body-util", - "hyper", + "hyper 1.8.1", "hyper-util", "native-tls", "tokio", @@ -1432,9 +1950,9 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "http", - "http-body", - "hyper", + "http 1.4.0", + "http-body 1.0.1", + "hyper 1.8.1", "ipnet", "libc", "percent-encoding", @@ -2166,7 +2684,7 @@ dependencies = [ "base64", "chrono", "getrandom 0.2.16", - "http", + "http 1.4.0", "rand 0.8.5", "reqwest", "serde", @@ -2243,6 +2761,12 @@ dependencies = [ "hashbrown 0.14.5", ] +[[package]] +name = "outref" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" + [[package]] name = "pandoc" version = "0.8.11" @@ -2636,7 +3160,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash 2.1.1", - "rustls", + "rustls 0.23.35", "socket2 0.6.1", "thiserror 2.0.17", "tokio", @@ -2656,7 +3180,7 @@ dependencies = [ "rand 0.9.2", "ring", "rustc-hash 2.1.1", - "rustls", + "rustls 0.23.35", "rustls-pki-types", "slab", "thiserror 2.0.17", @@ -2842,6 +3366,12 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "regex-lite" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d942b98df5e658f56f20d592c7f868833fe38115e65c33003d8cd224b0155da" + [[package]] name = "regex-syntax" version = "0.8.8" @@ -2861,12 +3391,12 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "h2", - "http", - "http-body", + "h2 0.4.12", + "http 1.4.0", + "http-body 1.0.1", "http-body-util", - "hyper", - "hyper-rustls", + "hyper 1.8.1", + "hyper-rustls 0.27.7", "hyper-tls", "hyper-util", "js-sys", @@ -2876,7 +3406,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls", + "rustls 0.23.35", "rustls-pki-types", "serde", "serde_json", @@ -2884,7 +3414,7 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-native-tls", - "tokio-rustls", + "tokio-rustls 0.26.4", "tokio-util", "tower 0.5.2", "tower-http", @@ -2922,8 +3452,8 @@ dependencies = [ "bytes", "chrono", "futures", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "http-body-util", "oauth2", "paste", @@ -3014,6 +3544,15 @@ version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "1.1.2" @@ -3027,17 +3566,30 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "rustls" +version = "0.21.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" +dependencies = [ + "log", + "ring", + "rustls-webpki 0.101.7", + "sct", +] + [[package]] name = "rustls" version = "0.23.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "533f54bc6a7d4f647e46ad909549eda97bf5afc1585190ef692b4286b198bd8f" dependencies = [ + "aws-lc-rs", "log", "once_cell", "ring", "rustls-pki-types", - "rustls-webpki", + "rustls-webpki 0.103.8", "subtle", "zeroize", ] @@ -3073,12 +3625,23 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-webpki" +version = "0.101.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "rustls-webpki" version = "0.103.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52" dependencies = [ + "aws-lc-rs", "ring", "rustls-pki-types", "untrusted", @@ -3137,6 +3700,16 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sct" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "security-framework" version = "2.11.1" @@ -3612,7 +4185,7 @@ checksum = "eb4dc4d33c68ec1f27d386b5610a351922656e1fdf5c05bbaad930cd1519479a" dependencies = [ "bytes", "futures-util", - "http-body", + "http-body 1.0.1", "http-body-util", "pin-project-lite", ] @@ -3953,13 +4526,23 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" +dependencies = [ + "rustls 0.21.12", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" dependencies = [ - "rustls", + "rustls 0.23.35", "tokio", ] @@ -4030,11 +4613,11 @@ dependencies = [ "base64", "bytes", "flate2", - "h2", - "http", - "http-body", + "h2 0.4.12", + "http 1.4.0", + "http-body 1.0.1", "http-body-util", - "hyper", + "hyper 1.8.1", "hyper-timeout", "hyper-util", "percent-encoding", @@ -4044,7 +4627,7 @@ dependencies = [ "rustls-pemfile", "socket2 0.5.10", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.4", "tokio-stream", "tower 0.4.13", "tower-layer", @@ -4111,8 +4694,8 @@ dependencies = [ "bitflags", "bytes", "futures-util", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "iri-string", "pin-project-lite", "tower 0.5.2", @@ -4282,6 +4865,8 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "aws-config", + "aws-sdk-bedrockruntime", "backon", "base64", "lazy_static", @@ -4536,6 +5121,12 @@ dependencies = [ "serde", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf-8" version = "0.7.6" @@ -4578,6 +5169,12 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "vsimd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" + [[package]] name = "want" version = "0.3.1" @@ -5057,6 +5654,12 @@ dependencies = [ "markup5ever", ] +[[package]] +name = "xmlparser" +version = "0.13.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" + [[package]] name = "yaml-rust2" version = "0.10.4" diff --git a/crates/umem_ai/Cargo.toml b/crates/umem_ai/Cargo.toml index 354be11..0bd8558 100644 --- a/crates/umem_ai/Cargo.toml +++ b/crates/umem_ai/Cargo.toml @@ -20,3 +20,5 @@ tokio.workspace = true tracing.workspace = true typed-builder.workspace = true umem_config = {workspace = true} +aws-config = { version = "1.1.7", features = ["behavior-version-latest"] } +aws-sdk-bedrockruntime = "1.120.0" diff --git a/crates/umem_ai/src/lib.rs b/crates/umem_ai/src/lib.rs index ffd6a10..5454e62 100644 --- a/crates/umem_ai/src/lib.rs +++ b/crates/umem_ai/src/lib.rs @@ -90,6 +90,7 @@ impl AIProvider { ) -> Result, ResponseGeneratorError> { match self { AIProvider::OpenAI(provider) => provider.generate_object(request), + AIProvider::AmazonBedrock(provider) => provider.generate_object(request), _ => unimplemented!(), } .await diff --git a/crates/umem_ai/src/providers/amazon_bedrock.rs b/crates/umem_ai/src/providers/amazon_bedrock.rs index 6f19701..51dc942 100644 --- a/crates/umem_ai/src/providers/amazon_bedrock.rs +++ b/crates/umem_ai/src/providers/amazon_bedrock.rs @@ -1,16 +1,26 @@ -use anyhow::{bail, Result}; -use async_trait::async_trait; - use crate::{ - response_generators::{GenerateTextRequest, GenerateTextResponse, ResponseGeneratorError}, - GeneratesText, + messages::{FilePart, UserModelMessage}, + response_generators::{ + self, GenerateTextRequest, GenerateTextResponse, ResponseGeneratorError, + }, + utils, GenerateObjectRequest, GenerateObjectResponse, GeneratesObject, GeneratesText, + OpenAIProvider, +}; +use async_trait::async_trait; +use aws_config::{BehaviorVersion, Region}; +use aws_sdk_bedrockruntime::{ + operation::converse::builders::ConverseFluentBuilder, + types::{ContentBlock, ImageBlock, Message}, }; +use reqwest::header::HeaderMap; +use schemars::JsonSchema; +use serde::{de::DeserializeOwned, Serialize}; +use std::env; +use thiserror::Error; pub struct AmazonBedrockProvider { - pub region: String, - pub access_key: String, - pub secret_key: String, - pub session_token: Option, + client: aws_sdk_bedrockruntime::Client, + default_headers: HeaderMap, } #[async_trait] @@ -23,53 +33,274 @@ impl GeneratesText for AmazonBedrockProvider { } } +#[async_trait] +impl GeneratesObject for AmazonBedrockProvider { + async fn generate_object( + &self, + request: GenerateObjectRequest, + ) -> Result, ResponseGeneratorError> { + let converse_request = self + .normalizer_generate_object_request(&request) + .map_err(ResponseGeneratorError::Transient)?; + + let converse_response = converse_request + .send() + .await + .map_err(|e| ResponseGeneratorError::BedrockConverseError(e.to_string()))?; + + let converse_output = match converse_response.output { + Some(output) => output, + None => { + return Err(ResponseGeneratorError::EmptyProviderResponse); + } + }; + + let output_message = converse_output.as_message().map_err(|_| { + ResponseGeneratorError::InvalidProviderResponse( + "was expecting the output to be a bedrock message".into(), + ) + })?; + + let output_text = output_message + .content + .first() + .ok_or(ResponseGeneratorError::EmptyProviderResponse)? + .as_text() + .map_err(|_| { + ResponseGeneratorError::InvalidProviderResponse( + "was expecting the output message content to be text".into(), + ) + })?; + + serde_json::from_str::(output_text) + .map(|output| GenerateObjectResponse { output }) + .map_err(ResponseGeneratorError::Deserialization) + } +} + +impl AmazonBedrockProvider { + fn normalizer_generate_object_request< + T: Clone + JsonSchema + Serialize + Send + Sync + DeserializeOwned, + >( + &self, + request: &GenerateObjectRequest, + ) -> anyhow::Result { + let mut system = OpenAIProvider::normalize_system_message(&request.messages); + system.push_str( + r#" + **Critical Instruction**: ALWAYS and ONLY respond with a JSON object that conforms EXACTLY to the provided JSON Schema (enclosed in ``). + "#, + ); + let mut user_messages = Self::normalize_user_messages(&request.messages).unwrap(); + user_messages.push(ContentBlock::Text(format!( + "\n{}\n", + serde_json::to_string_pretty(&request.output_schema,)? + ))); + + Ok(self + .client + .converse() + .model_id(request.model.model_name.clone()) + .system(aws_sdk_bedrockruntime::types::SystemContentBlock::Text( + system, + )) + .messages( + Message::builder() + .role("user".into()) + .set_content(Some(user_messages)) + .build() + .unwrap(), + )) + } + + fn normalize_user_messages( + messages: &[response_generators::messages::Message], + ) -> anyhow::Result> { + let user_messages: Vec<&response_generators::messages::UserModelMessage> = messages + .iter() + .filter_map(|msg| match msg { + response_generators::messages::Message::User(v) => Some(v), + _ => None, + }) + .collect(); + + let user_message_content_blocks: Vec = user_messages + .iter() + .flat_map(|um| match um { + UserModelMessage::Text(text) => vec![ContentBlock::Text(text.into())], + UserModelMessage::Parts(user_message_parts) => user_message_parts + .iter() + .map(|part| match part { + crate::messages::UserMessagePart::Text(text) => { + ContentBlock::Text(text.into()) + } + crate::messages::UserMessagePart::Image(image_part) => { + let image_block = match image_part { + FilePart::Url(_, _) => { + unimplemented!("AWS doesn't support URL images yet"); + } + FilePart::Base64(b64_string, media_type) => ImageBlock::builder() + .source(aws_sdk_bedrockruntime::types::ImageSource::Bytes( + b64_string.as_bytes().into(), + )) + .format( + media_type + .clone() + .unwrap_or(mime::IMAGE_PNG) + .to_string() + .as_str() + .into(), + ) + .build() + .unwrap(), + FilePart::Buffer(items, media_type) => ImageBlock::builder() + .source(aws_sdk_bedrockruntime::types::ImageSource::Bytes( + items.as_slice().into(), + )) + .format( + media_type + .clone() + .unwrap_or(mime::IMAGE_PNG) + .to_string() + .as_str() + .into(), + ) + .build() + .unwrap(), + }; + ContentBlock::Image(image_block) + } + crate::messages::UserMessagePart::File(file_part) => todo!(), + }) + .collect(), + }) + .collect(); + + Ok(user_message_content_blocks) + } +} + +#[derive(Default)] pub struct AmazonBedrockProviderBuilder { - pub region: Option, - pub access_key: Option, - pub secret_key: Option, - pub session_token: Option, + access_key_id: Option, + secret_access_key: Option, + region: Option, + default_headers: Vec<(String, String)>, +} + +#[derive(Error, Debug)] +pub enum AmazonBedrockProviderBuilderError { + #[error("Missing AWS Access Key ID")] + MissingAccessKeyId, + #[error("Missing AWS Secret Access Key")] + MissingSecretAccessKey, + #[error("Missing AWS Bedrock Region")] + MissingRegion, + #[error(transparent)] + BadHeaders(#[from] utils::BuildHeaderMapError), } impl AmazonBedrockProviderBuilder { pub fn new() -> Self { - Self { - region: None, - access_key: None, - secret_key: None, - session_token: None, - } + Self::default() } - pub fn region(mut self, region: String) -> Self { - self.region = Some(region); + pub fn access_key_id(mut self, access_key_id: impl Into) -> Self { + self.access_key_id = Some(access_key_id.into()); self } - pub fn access_key(mut self, access_key: String) -> Self { - self.access_key = Some(access_key); + pub fn secret_access_key(mut self, secret_access_key: impl Into) -> Self { + self.secret_access_key = Some(secret_access_key.into()); self } - pub fn secret_key(mut self, secret_key: String) -> Self { - self.secret_key = Some(secret_key); + pub fn region(mut self, region: impl Into) -> Self { + self.region = Some(Region::new(region.into())); self } - pub fn session_token(mut self, session_token: String) -> Self { - self.session_token = Some(session_token); + pub fn default_headers(mut self, headers: Vec<(String, String)>) -> Self { + self.default_headers = headers; self } - pub fn build(self) -> Result { - if self.region.is_none() || self.access_key.is_none() || self.secret_key.is_none() { - bail!("region, access_key, and secret_key are required"); + pub async fn build(self) -> Result { + if self.access_key_id.is_none() { + return Err(AmazonBedrockProviderBuilderError::MissingAccessKeyId); + } + + if self.secret_access_key.is_none() { + return Err(AmazonBedrockProviderBuilderError::MissingSecretAccessKey); } + if self.region.is_none() { + return Err(AmazonBedrockProviderBuilderError::MissingRegion); + } + + unsafe { + env::set_var("AWS_ACCESS_KEY_ID", self.access_key_id.clone().unwrap()); + env::set_var( + "AWS_SECRET_ACCESS_KEY", + self.secret_access_key.clone().unwrap(), + ); + } + + let sdk_config = aws_config::defaults(BehaviorVersion::latest()) + .region(self.region) + .load() + .await; + + let default_headers = utils::build_header_map(self.default_headers.as_slice())?; + Ok(AmazonBedrockProvider { - region: self.region.unwrap(), - access_key: self.access_key.unwrap(), - secret_key: self.secret_key.unwrap(), - session_token: self.session_token, + client: aws_sdk_bedrockruntime::Client::new(&sdk_config), + default_headers, }) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{generate_object, AIProvider, GenerateObjectRequestBuilder, LanguageModel}; + use serde::Deserialize; + use std::sync::Arc; + + #[tokio::test] + async fn test_bedrock_generate_object() { + let provider = Arc::new(AIProvider::from( + AmazonBedrockProviderBuilder::default() + .region("REGION") + .access_key_id("ACCESS_KEY_ID") + .secret_access_key("SECRET_ACCESS_KEY") + .build() + .await + .unwrap(), + )); + + #[derive(Clone, JsonSchema, Serialize, Deserialize, Debug)] + struct Holiday { + name: String, + traditions: String, + } + + let model = Arc::new(LanguageModel { + provider, + model_name: "deepseek.v3-v1:0".to_string(), + }); + + let request = GenerateObjectRequestBuilder::::new() + .model(model) + .system("You are a helpful assistant.".to_string()) + .prompt("Invent a new holiday and describe its traditions.".to_string()) + .max_output_tokens(2000) + .temperature(0.7) + .build() + .unwrap(); + + let generate_object_response = generate_object(request).await.unwrap(); + + dbg!(&generate_object_response); + } +} diff --git a/crates/umem_ai/src/providers/openai.rs b/crates/umem_ai/src/providers/openai.rs index 81e03a9..919b52d 100644 --- a/crates/umem_ai/src/providers/openai.rs +++ b/crates/umem_ai/src/providers/openai.rs @@ -42,8 +42,8 @@ impl OpenAIProvider { &self, request: &GenerateObjectRequest, ) -> String { - let system = self.normalize_system_message(&request.messages); - let normalized_user_messages = self.normalize_user_messages(&request.messages); + let system = Self::normalize_system_message(&request.messages); + let normalized_user_messages = Self::normalize_user_messages(&request.messages); let schema = request.output_schema.clone(); let name = std::any::type_name::() .split("::") @@ -78,8 +78,8 @@ impl OpenAIProvider { } pub fn normalize_generate_text_request(&self, request: &GenerateTextRequest) -> String { - let system = self.normalize_system_message(&request.messages); - let normalized_user_messages = self.normalize_user_messages(&request.messages); + let system = Self::normalize_system_message(&request.messages); + let normalized_user_messages = Self::normalize_user_messages(&request.messages); serde_json::json!({ "model": request.model.model_name, @@ -102,7 +102,7 @@ impl OpenAIProvider { .to_string() } - fn normalize_system_message(&self, messages: &[Message]) -> String { + pub(crate) fn normalize_system_message(messages: &[Message]) -> String { messages .iter() .find_map(|msg| match msg { @@ -113,7 +113,7 @@ impl OpenAIProvider { .into() } - fn normalize_user_messages(&self, messages: &[Message]) -> Vec { + pub(crate) fn normalize_user_messages(messages: &[Message]) -> Vec { let user_messages: Vec<&UserModelMessage> = messages .iter() .filter_map(|msg| match msg { @@ -265,7 +265,7 @@ impl GeneratesObject for OpenAIProvider { .unwrap_or_default(); let output: T = - serde_json::from_str(&output_text).map_err(ResponseGeneratorError::Serialization)?; + serde_json::from_str(&output_text).map_err(ResponseGeneratorError::Deserialization)?; Ok(GenerateObjectResponse { output }) } diff --git a/crates/umem_ai/src/response_generators/mod.rs b/crates/umem_ai/src/response_generators/mod.rs index 63b79d9..e244513 100644 --- a/crates/umem_ai/src/response_generators/mod.rs +++ b/crates/umem_ai/src/response_generators/mod.rs @@ -11,9 +11,15 @@ pub enum ResponseGeneratorError { #[error(transparent)] Http(#[from] reqwest::Error), #[error(transparent)] - Serialization(#[from] serde_json::Error), - #[error(transparent)] - Transient(#[from] anyhow::Error), + Deserialization(#[from] serde_json::Error), #[error(transparent)] TimeoutError(#[from] tokio::time::error::Elapsed), + #[error("Bedrock Converse API error, Details: {0}")] + BedrockConverseError(String), + #[error("empty response from AI provider")] + EmptyProviderResponse, + #[error("invalid response from AI provider, Details: {0}")] + InvalidProviderResponse(String), + #[error(transparent)] + Transient(#[from] anyhow::Error), } diff --git a/crates/umem_ai/src/utils.rs b/crates/umem_ai/src/utils.rs index f00485a..830b74f 100644 --- a/crates/umem_ai/src/utils.rs +++ b/crates/umem_ai/src/utils.rs @@ -15,7 +15,7 @@ pub fn is_retryable_error(e: &ResponseGeneratorError) -> bool { s.is_server_error() || s == reqwest::StatusCode::TOO_MANY_REQUESTS }) } - ResponseGeneratorError::Serialization(error) => { + ResponseGeneratorError::Deserialization(error) => { tracing::warn!( "Serialization error, AI Might have built a bad JSON output: {}", error @@ -36,6 +36,21 @@ pub fn is_retryable_error(e: &ResponseGeneratorError) -> bool { ); false } + ResponseGeneratorError::InvalidProviderResponse(e) => { + tracing::error!("Invalid response from AI provider: {}", e); + true + } + ResponseGeneratorError::EmptyProviderResponse => { + tracing::error!("Empty response from AI provider"); + true + } + ResponseGeneratorError::BedrockConverseError(sdk_error) => { + tracing::error!( + "AWS Bedrock Converse SDK error occurred when communicating with AI provider: {}", + sdk_error + ); + true + } } } From 313d487e3cea05ad7880bf1108ae9381b7ac5cf6 Mon Sep 17 00:00:00 2001 From: Vidur Khanal Date: Mon, 29 Dec 2025 14:15:16 -0800 Subject: [PATCH 02/12] feat(amazon-bedrock): implement text generation and add tests Implement the `generate_text` method for the `AmazonBedrockProvider`, enabling text generation via Amazon Bedrock. Add request normalization logic and integrate with the provider selection in `AIProvider`. Include unit tests for `generate_text` to verify functionality. --- crates/umem_ai/src/lib.rs | 1 + .../umem_ai/src/providers/amazon_bedrock.rs | 103 +++++++++++++++++- 2 files changed, 99 insertions(+), 5 deletions(-) diff --git a/crates/umem_ai/src/lib.rs b/crates/umem_ai/src/lib.rs index 5454e62..dce868c 100644 --- a/crates/umem_ai/src/lib.rs +++ b/crates/umem_ai/src/lib.rs @@ -77,6 +77,7 @@ impl AIProvider { ) -> Result { match self { AIProvider::OpenAI(provider) => provider.generate_text(request), + AIProvider::AmazonBedrock(provider) => provider.generate_text(request), _ => unimplemented!(), } .await diff --git a/crates/umem_ai/src/providers/amazon_bedrock.rs b/crates/umem_ai/src/providers/amazon_bedrock.rs index 51dc942..e5b48bc 100644 --- a/crates/umem_ai/src/providers/amazon_bedrock.rs +++ b/crates/umem_ai/src/providers/amazon_bedrock.rs @@ -27,9 +27,44 @@ pub struct AmazonBedrockProvider { impl GeneratesText for AmazonBedrockProvider { async fn generate_text( &self, - _request: GenerateTextRequest, + request: GenerateTextRequest, ) -> Result { - unimplemented!() + let converse_request = self + .normalize_generate_text_request(&request) + .map_err(ResponseGeneratorError::Transient)?; + + let converse_response = converse_request + .send() + .await + .map_err(|e| ResponseGeneratorError::BedrockConverseError(e.to_string()))?; + + let converse_output = match converse_response.output { + Some(output) => output, + None => { + return Err(ResponseGeneratorError::EmptyProviderResponse); + } + }; + + let output_message = converse_output.as_message().map_err(|_| { + ResponseGeneratorError::InvalidProviderResponse( + "was expecting the output to be a bedrock message".into(), + ) + })?; + + let output_text = output_message + .content + .first() + .ok_or(ResponseGeneratorError::EmptyProviderResponse)? + .as_text() + .map_err(|_| { + ResponseGeneratorError::InvalidProviderResponse( + "was expecting the output message content to be text".into(), + ) + })?; + + Ok(GenerateTextResponse { + text: output_text.to_string(), + }) } } @@ -113,6 +148,29 @@ impl AmazonBedrockProvider { )) } + fn normalize_generate_text_request( + &self, + request: &GenerateTextRequest, + ) -> anyhow::Result { + let system = OpenAIProvider::normalize_system_message(&request.messages); + let user_messages = Self::normalize_user_messages(&request.messages)?; + + Ok(self + .client + .converse() + .model_id(request.model.model_name.clone()) + .system(aws_sdk_bedrockruntime::types::SystemContentBlock::Text( + system, + )) + .messages( + Message::builder() + .role("user".into()) + .set_content(Some(user_messages)) + .build() + .unwrap(), + )) + } + fn normalize_user_messages( messages: &[response_generators::messages::Message], ) -> anyhow::Result> { @@ -262,17 +320,21 @@ impl AmazonBedrockProviderBuilder { #[cfg(test)] mod tests { - use super::*; - use crate::{generate_object, AIProvider, GenerateObjectRequestBuilder, LanguageModel}; + use crate::{ + generate_object, generate_text, AIProvider, GenerateObjectRequestBuilder, + GenerateTextRequestBuilder, LanguageModel, + }; use serde::Deserialize; use std::sync::Arc; + use super::*; + #[tokio::test] async fn test_bedrock_generate_object() { let provider = Arc::new(AIProvider::from( AmazonBedrockProviderBuilder::default() .region("REGION") - .access_key_id("ACCESS_KEY_ID") + .access_key_id("ACESS_KEY_ID") .secret_access_key("SECRET_ACCESS_KEY") .build() .await @@ -303,4 +365,35 @@ mod tests { dbg!(&generate_object_response); } + + #[tokio::test] + async fn test_bedrock_generate_text() { + let provider = Arc::new(AIProvider::from( + AmazonBedrockProviderBuilder::default() + .region("REGION") + .access_key_id("ACESS_KEY_ID") + .secret_access_key("SECRET_ACCESS_KEY") + .build() + .await + .unwrap(), + )); + + let model = Arc::new(LanguageModel { + provider, + model_name: "deepseek.v3-v1:0".to_string(), + }); + + let request = GenerateTextRequestBuilder::new() + .model(model) + .system("You are a helpful assistant.") + .prompt("Invent a new holiday and describe its traditions.") + .max_output_tokens(2000) + .temperature(0.7) + .build() + .unwrap(); + + let generate_text_response = generate_text(request).await.unwrap(); + + dbg!(&generate_text_response); + } } From f43b7ae181fb09538711600e2a3128a1a6de55ad Mon Sep 17 00:00:00 2001 From: Vidur Khanal Date: Mon, 29 Dec 2025 14:15:16 -0800 Subject: [PATCH 03/12] feat(amazon-bedrock): implement text generation and add tests Implement the `generate_text` method for the `AmazonBedrockProvider`, enabling text generation via Amazon Bedrock. Add request normalization logic and integrate with the provider selection in `AIProvider`. Include unit tests for `generate_text` to verify functionality. --- crates/umem_ai/src/providers/amazon_bedrock.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/umem_ai/src/providers/amazon_bedrock.rs b/crates/umem_ai/src/providers/amazon_bedrock.rs index e5b48bc..47650a4 100644 --- a/crates/umem_ai/src/providers/amazon_bedrock.rs +++ b/crates/umem_ai/src/providers/amazon_bedrock.rs @@ -114,7 +114,7 @@ impl GeneratesObject for AmazonBedrockProvider { } impl AmazonBedrockProvider { - fn normalizer_generate_object_request< + fn normalize_generate_object_request< T: Clone + JsonSchema + Serialize + Send + Sync + DeserializeOwned, >( &self, From 258e379ea2d89fd310cb3e1c1fc7fe852ab9c66a Mon Sep 17 00:00:00 2001 From: Vidur Khanal Date: Mon, 29 Dec 2025 14:15:16 -0800 Subject: [PATCH 04/12] feat(amazon-bedrock): implement text generation and add tests Implement the `generate_text` method for the `AmazonBedrockProvider`, enabling text generation via Amazon Bedrock. Add request normalization logic and integrate with the provider selection in `AIProvider`. Include unit tests for `generate_text` to verify functionality. --- crates/umem_ai/src/providers/amazon_bedrock.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/umem_ai/src/providers/amazon_bedrock.rs b/crates/umem_ai/src/providers/amazon_bedrock.rs index 47650a4..57d31bb 100644 --- a/crates/umem_ai/src/providers/amazon_bedrock.rs +++ b/crates/umem_ai/src/providers/amazon_bedrock.rs @@ -75,7 +75,7 @@ impl GeneratesObject for AmazonBedrockProvider { request: GenerateObjectRequest, ) -> Result, ResponseGeneratorError> { let converse_request = self - .normalizer_generate_object_request(&request) + .normalize_generate_object_request(&request) .map_err(ResponseGeneratorError::Transient)?; let converse_response = converse_request @@ -228,7 +228,7 @@ impl AmazonBedrockProvider { }; ContentBlock::Image(image_block) } - crate::messages::UserMessagePart::File(file_part) => todo!(), + crate::messages::UserMessagePart::File(_) => todo!(), }) .collect(), }) From 91b99918f54c715ab072d33ad7785ff270df9851 Mon Sep 17 00:00:00 2001 From: Vidur Khanal Date: Mon, 29 Dec 2025 15:09:11 -0800 Subject: [PATCH 05/12] feat(umem_ai): add inference config and custom headers to Bedrock provider - Add `aws-smithy-types` dependency to support structured model request fields. - Update `AmazonBedrockProvider` to set `InferenceConfiguration` (temperature, top_p, max_tokens) on Bedrock requests. - Pass custom headers as additional model request fields using `aws_smithy_types::Document`. - Refactor provider builder to use explicit credentials and provider name, removing reliance on environment variables and default headers. - Use `Arc` for Bedrock client to support potential sharing. --- Cargo.lock | 1 + crates/umem_ai/Cargo.toml | 1 + .../umem_ai/src/providers/amazon_bedrock.rs | 91 ++++++++++++------- 3 files changed, 61 insertions(+), 32 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 522ff1c..c84f3cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4867,6 +4867,7 @@ dependencies = [ "async-trait", "aws-config", "aws-sdk-bedrockruntime", + "aws-smithy-types", "backon", "base64", "lazy_static", diff --git a/crates/umem_ai/Cargo.toml b/crates/umem_ai/Cargo.toml index 0bd8558..e51bb07 100644 --- a/crates/umem_ai/Cargo.toml +++ b/crates/umem_ai/Cargo.toml @@ -22,3 +22,4 @@ typed-builder.workspace = true umem_config = {workspace = true} aws-config = { version = "1.1.7", features = ["behavior-version-latest"] } aws-sdk-bedrockruntime = "1.120.0" +aws-smithy-types = "1.3.5" diff --git a/crates/umem_ai/src/providers/amazon_bedrock.rs b/crates/umem_ai/src/providers/amazon_bedrock.rs index 57d31bb..4b69108 100644 --- a/crates/umem_ai/src/providers/amazon_bedrock.rs +++ b/crates/umem_ai/src/providers/amazon_bedrock.rs @@ -10,17 +10,15 @@ use async_trait::async_trait; use aws_config::{BehaviorVersion, Region}; use aws_sdk_bedrockruntime::{ operation::converse::builders::ConverseFluentBuilder, - types::{ContentBlock, ImageBlock, Message}, + types::{ContentBlock, ImageBlock, InferenceConfiguration, Message}, }; -use reqwest::header::HeaderMap; use schemars::JsonSchema; use serde::{de::DeserializeOwned, Serialize}; -use std::env; +use std::sync::Arc; use thiserror::Error; pub struct AmazonBedrockProvider { - client: aws_sdk_bedrockruntime::Client, - default_headers: HeaderMap, + client: Arc, } #[async_trait] @@ -34,6 +32,25 @@ impl GeneratesText for AmazonBedrockProvider { .map_err(ResponseGeneratorError::Transient)?; let converse_response = converse_request + .set_inference_config(Some( + InferenceConfiguration::builder() + .temperature(request.temperature.unwrap_or(0.0)) + .top_p(request.top_p.unwrap_or(1.0)) + .max_tokens(request.max_output_tokens.unwrap_or(0_usize) as i32) + .build(), + )) + .additional_model_request_fields(aws_smithy_types::Document::Object( + request + .headers + .iter() + .map(|(key, value)| { + ( + key.as_str().to_string(), + aws_smithy_types::Document::String(value.to_str().unwrap().to_string()), + ) + }) + .collect(), + )) .send() .await .map_err(|e| ResponseGeneratorError::BedrockConverseError(e.to_string()))?; @@ -79,6 +96,25 @@ impl GeneratesObject for AmazonBedrockProvider { .map_err(ResponseGeneratorError::Transient)?; let converse_response = converse_request + .set_inference_config(Some( + InferenceConfiguration::builder() + .temperature(request.temperature.unwrap_or(0.0)) + .top_p(request.top_p.unwrap_or(1.0)) + .max_tokens(request.max_output_tokens.unwrap_or(0_usize) as i32) + .build(), + )) + .additional_model_request_fields(aws_smithy_types::Document::Object( + request + .headers + .iter() + .map(|(key, value)| { + ( + key.as_str().to_string(), + aws_smithy_types::Document::String(value.to_str().unwrap().to_string()), + ) + }) + .collect(), + )) .send() .await .map_err(|e| ResponseGeneratorError::BedrockConverseError(e.to_string()))?; @@ -243,7 +279,7 @@ pub struct AmazonBedrockProviderBuilder { access_key_id: Option, secret_access_key: Option, region: Option, - default_headers: Vec<(String, String)>, + provider_name: Option, } #[derive(Error, Debug)] @@ -278,42 +314,33 @@ impl AmazonBedrockProviderBuilder { self } - pub fn default_headers(mut self, headers: Vec<(String, String)>) -> Self { - self.default_headers = headers; + pub fn provider_name(mut self, provider_name: impl Into) -> Self { + self.provider_name = Some(provider_name.into()); self } pub async fn build(self) -> Result { - if self.access_key_id.is_none() { - return Err(AmazonBedrockProviderBuilderError::MissingAccessKeyId); - } - - if self.secret_access_key.is_none() { - return Err(AmazonBedrockProviderBuilderError::MissingSecretAccessKey); - } - - if self.region.is_none() { - return Err(AmazonBedrockProviderBuilderError::MissingRegion); - } - - unsafe { - env::set_var("AWS_ACCESS_KEY_ID", self.access_key_id.clone().unwrap()); - env::set_var( - "AWS_SECRET_ACCESS_KEY", - self.secret_access_key.clone().unwrap(), - ); - } - let sdk_config = aws_config::defaults(BehaviorVersion::latest()) .region(self.region) + .credentials_provider( + aws_sdk_bedrockruntime::config::Credentials::builder() + .access_key_id( + self.access_key_id + .clone() + .ok_or(AmazonBedrockProviderBuilderError::MissingAccessKeyId)?, + ) + .secret_access_key( + self.secret_access_key + .ok_or(AmazonBedrockProviderBuilderError::MissingSecretAccessKey)?, + ) + .provider_name("umem-ai-bedrock-provider") + .build(), + ) .load() .await; - let default_headers = utils::build_header_map(self.default_headers.as_slice())?; - Ok(AmazonBedrockProvider { - client: aws_sdk_bedrockruntime::Client::new(&sdk_config), - default_headers, + client: Arc::new(aws_sdk_bedrockruntime::Client::new(&sdk_config)), }) } } From ce2e64aef11f38f90691fd77d47d1a5f84627da1 Mon Sep 17 00:00:00 2001 From: Vidur Khanal Date: Mon, 29 Dec 2025 15:25:18 -0800 Subject: [PATCH 06/12] fix(amazon_bedrock): properly decode base64 image data before sending to Bedrock Previously, base64-encoded image data was passed as raw bytes, which could result in invalid images being sent to AWS Bedrock. This change decodes the base64 string before constructing the ImageBlock, ensuring valid image data is used. Also improves error handling and clarifies unimplemented file handling. --- .../umem_ai/src/providers/amazon_bedrock.rs | 41 +++++++++++-------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/crates/umem_ai/src/providers/amazon_bedrock.rs b/crates/umem_ai/src/providers/amazon_bedrock.rs index 4b69108..5b911b8 100644 --- a/crates/umem_ai/src/providers/amazon_bedrock.rs +++ b/crates/umem_ai/src/providers/amazon_bedrock.rs @@ -12,6 +12,7 @@ use aws_sdk_bedrockruntime::{ operation::converse::builders::ConverseFluentBuilder, types::{ContentBlock, ImageBlock, InferenceConfiguration, Message}, }; +use base64::Engine; use schemars::JsonSchema; use serde::{de::DeserializeOwned, Serialize}; use std::sync::Arc; @@ -233,20 +234,26 @@ impl AmazonBedrockProvider { FilePart::Url(_, _) => { unimplemented!("AWS doesn't support URL images yet"); } - FilePart::Base64(b64_string, media_type) => ImageBlock::builder() - .source(aws_sdk_bedrockruntime::types::ImageSource::Bytes( - b64_string.as_bytes().into(), - )) - .format( - media_type - .clone() - .unwrap_or(mime::IMAGE_PNG) - .to_string() - .as_str() - .into(), - ) - .build() - .unwrap(), + FilePart::Base64(b64_string, media_type) => { + let decoded = base64::engine::general_purpose::STANDARD + .decode(b64_string) + .expect("not a valid base64 string"); + + ImageBlock::builder() + .source(aws_sdk_bedrockruntime::types::ImageSource::Bytes( + decoded.as_slice().into(), + )) + .format( + media_type + .clone() + .unwrap_or(mime::IMAGE_PNG) + .to_string() + .as_str() + .into(), + ) + .build() + .expect("failed to build image block") + } FilePart::Buffer(items, media_type) => ImageBlock::builder() .source(aws_sdk_bedrockruntime::types::ImageSource::Bytes( items.as_slice().into(), @@ -260,11 +267,13 @@ impl AmazonBedrockProvider { .into(), ) .build() - .unwrap(), + .expect("failed to build image block"), }; ContentBlock::Image(image_block) } - crate::messages::UserMessagePart::File(_) => todo!(), + crate::messages::UserMessagePart::File(_) => { + unimplemented!("file handling not yet supported for Bedrock") + } }) .collect(), }) From d27ac1d576fc0c5738202e621320a22772fffa44 Mon Sep 17 00:00:00 2001 From: Vidur Khanal Date: Wed, 31 Dec 2025 15:57:24 -0500 Subject: [PATCH 07/12] fix(error-handling): improve deserialization error reporting with response context - Update `ResponseGeneratorError::Deserialization` to include both the serde_json error and the original response string. - Enhance error messages and tracing logs to display both the error and the problematic response, aiding debugging. - Change Amazon Bedrock and OpenAI providers to pass the response text into deserialization errors. - Update error handling in utils to log both error and response content for deserialization failures. --- crates/umem_ai/src/providers/amazon_bedrock.rs | 4 ++-- crates/umem_ai/src/providers/openai.rs | 2 +- crates/umem_ai/src/response_generators/mod.rs | 4 ++-- crates/umem_ai/src/utils.rs | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/crates/umem_ai/src/providers/amazon_bedrock.rs b/crates/umem_ai/src/providers/amazon_bedrock.rs index 5b911b8..0810784 100644 --- a/crates/umem_ai/src/providers/amazon_bedrock.rs +++ b/crates/umem_ai/src/providers/amazon_bedrock.rs @@ -135,7 +135,7 @@ impl GeneratesObject for AmazonBedrockProvider { let output_text = output_message .content - .first() + .last() .ok_or(ResponseGeneratorError::EmptyProviderResponse)? .as_text() .map_err(|_| { @@ -146,7 +146,7 @@ impl GeneratesObject for AmazonBedrockProvider { serde_json::from_str::(output_text) .map(|output| GenerateObjectResponse { output }) - .map_err(ResponseGeneratorError::Deserialization) + .map_err(|e| ResponseGeneratorError::Deserialization(e, output_text.clone())) } } diff --git a/crates/umem_ai/src/providers/openai.rs b/crates/umem_ai/src/providers/openai.rs index 919b52d..6cf2fc2 100644 --- a/crates/umem_ai/src/providers/openai.rs +++ b/crates/umem_ai/src/providers/openai.rs @@ -265,7 +265,7 @@ impl GeneratesObject for OpenAIProvider { .unwrap_or_default(); let output: T = - serde_json::from_str(&output_text).map_err(ResponseGeneratorError::Deserialization)?; + serde_json::from_str(&output_text).map_err(|e| ResponseGeneratorError::Deserialization(e, output_text))?; Ok(GenerateObjectResponse { output }) } diff --git a/crates/umem_ai/src/response_generators/mod.rs b/crates/umem_ai/src/response_generators/mod.rs index e244513..9ff7f9c 100644 --- a/crates/umem_ai/src/response_generators/mod.rs +++ b/crates/umem_ai/src/response_generators/mod.rs @@ -10,8 +10,8 @@ use thiserror::Error; pub enum ResponseGeneratorError { #[error(transparent)] Http(#[from] reqwest::Error), - #[error(transparent)] - Deserialization(#[from] serde_json::Error), + #[error("deserialization error, Details: {1}, Response: {0}")] + Deserialization(serde_json::Error, String), #[error(transparent)] TimeoutError(#[from] tokio::time::error::Elapsed), #[error("Bedrock Converse API error, Details: {0}")] diff --git a/crates/umem_ai/src/utils.rs b/crates/umem_ai/src/utils.rs index 830b74f..cbc749e 100644 --- a/crates/umem_ai/src/utils.rs +++ b/crates/umem_ai/src/utils.rs @@ -15,10 +15,10 @@ pub fn is_retryable_error(e: &ResponseGeneratorError) -> bool { s.is_server_error() || s == reqwest::StatusCode::TOO_MANY_REQUESTS }) } - ResponseGeneratorError::Deserialization(error) => { + ResponseGeneratorError::Deserialization(error, response) => { tracing::warn!( - "Serialization error, AI Might have built a bad JSON output: {}", - error + "Serialization error, AI Might have built a bad JSON output: \n Error: {} \n Received Response: {}", + error,response ); true } From 39ccca0954577244930e15a0e9299c59a43c9ff0 Mon Sep 17 00:00:00 2001 From: Vidur Khanal Date: Wed, 31 Dec 2025 22:45:41 -0500 Subject: [PATCH 08/12] feat(bedrock): add tool use support for JSON output and schema enforcement - Enhance AmazonBedrockProvider to use Bedrock's tool use API for structured JSON output. - Add conversion utilities between serde_json::Value and aws_smithy_types::Document. - Update error handling and deserialization logic to support tool use responses. - Configure tool specification and input schema for JSON output in Bedrock requests. - Bump aws-smithy-types dependency to enable required features. --- crates/umem_ai/Cargo.toml | 2 +- .../umem_ai/src/providers/amazon_bedrock.rs | 75 ++++++++++++------- crates/umem_ai/src/utils.rs | 52 +++++++++++++ 3 files changed, 101 insertions(+), 28 deletions(-) diff --git a/crates/umem_ai/Cargo.toml b/crates/umem_ai/Cargo.toml index e51bb07..c40508f 100644 --- a/crates/umem_ai/Cargo.toml +++ b/crates/umem_ai/Cargo.toml @@ -22,4 +22,4 @@ typed-builder.workspace = true umem_config = {workspace = true} aws-config = { version = "1.1.7", features = ["behavior-version-latest"] } aws-sdk-bedrockruntime = "1.120.0" -aws-smithy-types = "1.3.5" +aws-smithy-types = {version = "1.3.5", features = ["serde-deserialize", "serde-serialize", "rt-tokio"]} diff --git a/crates/umem_ai/src/providers/amazon_bedrock.rs b/crates/umem_ai/src/providers/amazon_bedrock.rs index 0810784..dabd3dc 100644 --- a/crates/umem_ai/src/providers/amazon_bedrock.rs +++ b/crates/umem_ai/src/providers/amazon_bedrock.rs @@ -10,7 +10,10 @@ use async_trait::async_trait; use aws_config::{BehaviorVersion, Region}; use aws_sdk_bedrockruntime::{ operation::converse::builders::ConverseFluentBuilder, - types::{ContentBlock, ImageBlock, InferenceConfiguration, Message}, + types::{ + AnyToolChoice, ContentBlock, ConverseOutput, ImageBlock, InferenceConfiguration, Message, + Tool, ToolChoice, ToolConfiguration, ToolInputSchema, ToolSpecification, + }, }; use base64::Engine; use schemars::JsonSchema; @@ -54,7 +57,10 @@ impl GeneratesText for AmazonBedrockProvider { )) .send() .await - .map_err(|e| ResponseGeneratorError::BedrockConverseError(e.to_string()))?; + .map_err(|e| { + tracing::error!("{}", e); + ResponseGeneratorError::BedrockConverseError(format!("{:?}", e)) + })?; let converse_output = match converse_response.output { Some(output) => output, @@ -101,7 +107,7 @@ impl GeneratesObject for AmazonBedrockProvider { InferenceConfiguration::builder() .temperature(request.temperature.unwrap_or(0.0)) .top_p(request.top_p.unwrap_or(1.0)) - .max_tokens(request.max_output_tokens.unwrap_or(0_usize) as i32) + .max_tokens(request.max_output_tokens.unwrap_or(5140_usize) as i32) .build(), )) .additional_model_request_fields(aws_smithy_types::Document::Object( @@ -127,26 +133,35 @@ impl GeneratesObject for AmazonBedrockProvider { } }; - let output_message = converse_output.as_message().map_err(|_| { - ResponseGeneratorError::InvalidProviderResponse( - "was expecting the output to be a bedrock message".into(), - ) - })?; + let output_message = match converse_output { + ConverseOutput::Message(msg) => msg, + _ => { + return Err(ResponseGeneratorError::InvalidProviderResponse( + "was expecting the output to be a bedrock message".into(), + )) + } + }; - let output_text = output_message + let json_tool = output_message .content - .last() - .ok_or(ResponseGeneratorError::EmptyProviderResponse)? - .as_text() + .into_iter() + .rfind(|content_item| content_item.is_tool_use()) + .ok_or(ResponseGeneratorError::EmptyProviderResponse)?; + + let json_tool_input = json_tool + .as_tool_use() .map_err(|_| { ResponseGeneratorError::InvalidProviderResponse( - "was expecting the output message content to be text".into(), + "was expecting the model to call the tool use".into(), ) - })?; + })? + .input(); - serde_json::from_str::(output_text) + serde_json::from_value::(utils::aws_smithy_document_to_json(json_tool_input)) .map(|output| GenerateObjectResponse { output }) - .map_err(|e| ResponseGeneratorError::Deserialization(e, output_text.clone())) + .map_err(|e| { + ResponseGeneratorError::Deserialization(e, format!("{:?}", json_tool_input)) + }) } } @@ -157,17 +172,9 @@ impl AmazonBedrockProvider { &self, request: &GenerateObjectRequest, ) -> anyhow::Result { - let mut system = OpenAIProvider::normalize_system_message(&request.messages); - system.push_str( - r#" - **Critical Instruction**: ALWAYS and ONLY respond with a JSON object that conforms EXACTLY to the provided JSON Schema (enclosed in ``). - "#, - ); - let mut user_messages = Self::normalize_user_messages(&request.messages).unwrap(); - user_messages.push(ContentBlock::Text(format!( - "\n{}\n", - serde_json::to_string_pretty(&request.output_schema,)? - ))); + let system = OpenAIProvider::normalize_system_message(&request.messages); + let user_messages = Self::normalize_user_messages(&request.messages).unwrap(); + let output_schema_value = serde_json::to_value(&request.output_schema)?; Ok(self .client @@ -176,6 +183,20 @@ impl AmazonBedrockProvider { .system(aws_sdk_bedrockruntime::types::SystemContentBlock::Text( system, )) + .tool_config( + ToolConfiguration::builder() + .tools(Tool::ToolSpec( + ToolSpecification::builder() + .name("json_output") + .description("Return output as JSON.") + .input_schema(ToolInputSchema::Json( + utils::json_to_aws_smithy_document(output_schema_value), + )) + .build()?, + )) + .tool_choice(ToolChoice::Any(AnyToolChoice::builder().build())) + .build()?, + ) .messages( Message::builder() .role("user".into()) diff --git a/crates/umem_ai/src/utils.rs b/crates/umem_ai/src/utils.rs index cbc749e..814b915 100644 --- a/crates/umem_ai/src/utils.rs +++ b/crates/umem_ai/src/utils.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use crate::response_generators::ResponseGeneratorError; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use thiserror::Error; @@ -83,3 +85,53 @@ pub fn build_header_map(headers: &[(String, String)]) -> Result aws_smithy_types::Document { + match value { + serde_json::Value::Null => aws_smithy_types::Document::Null, + serde_json::Value::Bool(b) => aws_smithy_types::Document::Bool(b), + serde_json::Value::Number(n) => { + if let Some(i) = n.as_i64() { + aws_smithy_types::Document::from(i) + } else if let Some(u) = n.as_u64() { + aws_smithy_types::Document::from(u) + } else if let Some(f) = n.as_f64() { + aws_smithy_types::Document::from(f) + } else { + aws_smithy_types::Document::Null + } + } + serde_json::Value::String(s) => aws_smithy_types::Document::String(s), + serde_json::Value::Array(arr) => aws_smithy_types::Document::Array( + arr.into_iter().map(json_to_aws_smithy_document).collect(), + ), + serde_json::Value::Object(obj) => aws_smithy_types::Document::Object( + obj.into_iter() + .map(|(k, v)| (k, json_to_aws_smithy_document(v))) + .collect::>(), + ), + } +} + +pub fn aws_smithy_document_to_json(doc: &aws_smithy_types::Document) -> serde_json::Value { + match doc { + aws_smithy_types::Document::Null => serde_json::Value::Null, + aws_smithy_types::Document::Bool(b) => serde_json::Value::Bool(*b), + aws_smithy_types::Document::Number(n) => match n { + aws_smithy_types::Number::PosInt(u) => serde_json::Value::Number((*u).into()), + aws_smithy_types::Number::NegInt(i) => serde_json::Value::Number((*i).into()), + aws_smithy_types::Number::Float(f) => serde_json::Number::from_f64((*f).into()) + .map(serde_json::Value::Number) + .unwrap_or(serde_json::Value::Null), + }, + aws_smithy_types::Document::String(s) => serde_json::Value::String(s.clone()), + aws_smithy_types::Document::Array(arr) => { + serde_json::Value::Array(arr.into_iter().map(aws_smithy_document_to_json).collect()) + } + aws_smithy_types::Document::Object(obj) => serde_json::Value::Object( + obj.into_iter() + .map(|(k, v)| (k.clone(), aws_smithy_document_to_json(v))) + .collect(), + ), + } +} From 30a3b79cdbbb0577d8ea87ea3d7899c8ac1d5ff0 Mon Sep 17 00:00:00 2001 From: Vidur Khanal Date: Mon, 5 Jan 2026 21:34:42 -0800 Subject: [PATCH 09/12] feat(ai): add rerank and structured_rerank APIs with builder patterns - Introduce `rerank` and `structured_rerank` modules to support ranking and structured ranking of documents. - Add `RerankRequest`, `RerankResponse`, and builder with validation for reranking text documents. - Add `StructuredRerankRequest`, `StructuredRerankResponse`, and builder for reranking structured data with generic support. - Expose new modules in `response_generators::mod.rs`. - Make `generate_text` and `generate_object` requests constructible via builder methods. - Add `serde-saphyr` dependency for structured serialization. - Update `Cargo.lock` with new dependencies and versions. --- Cargo.lock | 92 ++++++++++++-- crates/umem_ai/Cargo.toml | 1 + .../response_generators/generate_object.rs | 9 ++ .../src/response_generators/generate_text.rs | 6 + crates/umem_ai/src/response_generators/mod.rs | 6 +- .../umem_ai/src/response_generators/rerank.rs | 83 +++++++++++++ .../response_generators/structured_rerank.rs | 113 ++++++++++++++++++ 7 files changed, 300 insertions(+), 10 deletions(-) create mode 100644 crates/umem_ai/src/response_generators/rerank.rs create mode 100644 crates/umem_ai/src/response_generators/structured_rerank.rs diff --git a/Cargo.lock b/Cargo.lock index c84f3cd..631ab99 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -67,6 +67,22 @@ dependencies = [ "libc", ] +[[package]] +name = "annotate-snippets" +version = "0.12.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15580ece6ea97cbf832d60ba19c021113469480852c6a2a6beb0db28f097bf1f" +dependencies = [ + "anstyle", + "unicode-width", +] + +[[package]] +name = "anstyle" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" + [[package]] name = "anyhow" version = "1.0.100" @@ -1334,6 +1350,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "encoding_rs_io" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cc3c5651fb62ab8aa3103998dade57efdd028544bd300516baa31840c252a83" +dependencies = [ + "encoding_rs", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -1872,7 +1897,7 @@ dependencies = [ "itoa", "pin-project-lite", "pin-utils", - "smallvec", + "smallvec 1.15.1", "tokio", "want", ] @@ -2047,7 +2072,7 @@ dependencies = [ "icu_normalizer_data", "icu_properties", "icu_provider", - "smallvec", + "smallvec 1.15.1", "zerovec", ] @@ -2129,7 +2154,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" dependencies = [ "idna_adapter", - "smallvec", + "smallvec 1.15.1", "utf8_iter", ] @@ -2584,6 +2609,12 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" +[[package]] +name = "nohash-hasher" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451" + [[package]] name = "nom" version = "8.0.0" @@ -2635,7 +2666,7 @@ dependencies = [ "num-iter", "num-traits", "rand 0.8.5", - "smallvec", + "smallvec 1.15.1", "zeroize", ] @@ -2801,7 +2832,7 @@ dependencies = [ "cfg-if", "libc", "redox_syscall", - "smallvec", + "smallvec 1.15.1", "windows-link", ] @@ -3659,6 +3690,16 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +[[package]] +name = "saphyr-parser" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb771b59f6b1985d1406325ec28f97cfb14256abcec4fdfb37b36a1766d6af7" +dependencies = [ + "arraydeque", + "hashlink", +] + [[package]] name = "schannel" version = "0.1.28" @@ -3762,6 +3803,25 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-saphyr" +version = "0.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be7f28bb35aab6f9cdc5f464f9bbda0b1e3908766e819557111cc13e58ce7915" +dependencies = [ + "ahash", + "annotate-snippets", + "base64", + "encoding_rs_io", + "nohash-hasher", + "num-traits", + "ryu", + "saphyr-parser", + "serde", + "serde_json", + "smallvec 2.0.0-alpha.12", +] + [[package]] name = "serde-untagged" version = "0.1.9" @@ -3811,6 +3871,7 @@ version = "1.0.145" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" dependencies = [ + "indexmap 2.12.1", "itoa", "memchr", "ryu", @@ -3945,6 +4006,12 @@ dependencies = [ "serde", ] +[[package]] +name = "smallvec" +version = "2.0.0-alpha.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef784004ca8777809dcdad6ac37629f0a97caee4c685fcea805278d81dd8b857" + [[package]] name = "socket2" version = "0.5.10" @@ -4024,7 +4091,7 @@ dependencies = [ "serde", "serde_json", "sha2", - "smallvec", + "smallvec 1.15.1", "thiserror 2.0.17", "tokio", "tokio-stream", @@ -4105,7 +4172,7 @@ dependencies = [ "serde", "sha1", "sha2", - "smallvec", + "smallvec 1.15.1", "sqlx-core", "stringprep", "thiserror 2.0.17", @@ -4143,7 +4210,7 @@ dependencies = [ "serde", "serde_json", "sha2", - "smallvec", + "smallvec 1.15.1", "sqlx-core", "stringprep", "thiserror 2.0.17", @@ -4784,7 +4851,7 @@ dependencies = [ "once_cell", "regex-automata", "sharded-slab", - "smallvec", + "smallvec 1.15.1", "thread_local", "tracing", "tracing-core", @@ -4876,6 +4943,7 @@ dependencies = [ "rustc-hash 2.1.1", "schemars", "serde", + "serde-saphyr", "serde_json", "thiserror 2.0.17", "tokio", @@ -5104,6 +5172,12 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + [[package]] name = "untrusted" version = "0.9.0" diff --git a/crates/umem_ai/Cargo.toml b/crates/umem_ai/Cargo.toml index c40508f..5abd232 100644 --- a/crates/umem_ai/Cargo.toml +++ b/crates/umem_ai/Cargo.toml @@ -23,3 +23,4 @@ umem_config = {workspace = true} aws-config = { version = "1.1.7", features = ["behavior-version-latest"] } aws-sdk-bedrockruntime = "1.120.0" aws-smithy-types = {version = "1.3.5", features = ["serde-deserialize", "serde-serialize", "rt-tokio"]} +serde-saphyr = "0.0.13" diff --git a/crates/umem_ai/src/response_generators/generate_object.rs b/crates/umem_ai/src/response_generators/generate_object.rs index b53a5db..f74454d 100644 --- a/crates/umem_ai/src/response_generators/generate_object.rs +++ b/crates/umem_ai/src/response_generators/generate_object.rs @@ -76,6 +76,15 @@ where pub timeout: Duration, } +impl GenerateObjectRequest +where + T: Clone + JsonSchema + Send + Sync + Serialize + DeserializeOwned, +{ + fn builder() -> GenerateObjectRequestBuilder { + GenerateObjectRequestBuilder::new() + } +} + #[derive(Debug)] pub struct GenerateObjectResponse { pub output: T, diff --git a/crates/umem_ai/src/response_generators/generate_text.rs b/crates/umem_ai/src/response_generators/generate_text.rs index 2e58bad..6ddd09b 100644 --- a/crates/umem_ai/src/response_generators/generate_text.rs +++ b/crates/umem_ai/src/response_generators/generate_text.rs @@ -65,6 +65,12 @@ pub struct GenerateTextRequest { pub timeout: Duration, } +impl GenerateTextRequest { + pub fn builder() -> GenerateTextRequestBuilder { + GenerateTextRequestBuilder::new() + } +} + #[derive(Debug, Error)] pub enum GenerateTextRequestBuilderError { #[error("either set the `system` field or provide a system message in `messages` array")] diff --git a/crates/umem_ai/src/response_generators/mod.rs b/crates/umem_ai/src/response_generators/mod.rs index 9ff7f9c..c00470b 100644 --- a/crates/umem_ai/src/response_generators/mod.rs +++ b/crates/umem_ai/src/response_generators/mod.rs @@ -1,9 +1,13 @@ pub mod generate_object; -mod generate_text; +pub mod generate_text; pub mod messages; +pub mod rerank; +pub mod structured_rerank; pub use generate_object::*; pub use generate_text::*; +pub use messages::*; +pub use rerank::*; use thiserror::Error; #[derive(Error, Debug)] diff --git a/crates/umem_ai/src/response_generators/rerank.rs b/crates/umem_ai/src/response_generators/rerank.rs new file mode 100644 index 0000000..eacdb81 --- /dev/null +++ b/crates/umem_ai/src/response_generators/rerank.rs @@ -0,0 +1,83 @@ +use crate::ResponseGeneratorError; + +pub async fn rerank(request: RerankRequest) -> Result { + unimplemented!() +} + +pub struct RerankRequest { + pub query: String, + pub documents: Vec, + pub top_n: usize, +} + +impl RerankRequest { + pub fn builder() -> RerankRequestBuilder { + RerankRequestBuilder { + top_n: 5, + ..Default::default() + } + } +} + +#[derive(Default)] +pub struct RerankRequestBuilder { + query: Option, + documents: Vec, + top_n: usize, +} + +#[derive(thiserror::Error, Debug)] +pub enum RerankRequestBuilderError { + #[error("missing query from rerank request")] + MissingQuery, + #[error("at least one document is required in rerank request")] + EmptyDocuments, +} + +impl RerankRequestBuilder { + pub fn query(mut self, query: impl Into) -> Self { + self.query = Some(query.into()); + self + } + + pub fn documents(mut self, documents: I) -> Self + where + I: IntoIterator, + { + self.documents.extend(documents); + self + } + + pub fn document(mut self, document: impl Into) -> Self { + self.documents.push(document.into()); + self + } + + pub fn top_k(mut self, top_n: usize) -> Self { + self.top_n = top_n; + self + } + + pub fn build(self) -> Result { + if self.documents.is_empty() { + return Err(RerankRequestBuilderError::EmptyDocuments); + } + + Ok(RerankRequest { + query: self.query.ok_or(RerankRequestBuilderError::MissingQuery)?, + documents: self.documents, + top_n: self.top_n, + }) + } +} + +pub struct RerankResponse { + pub rankings: Vec, + pub ranked_documents: Vec, +} + +pub struct Ranking { + pub original_index: usize, + pub score: f32, + pub document: String, +} diff --git a/crates/umem_ai/src/response_generators/structured_rerank.rs b/crates/umem_ai/src/response_generators/structured_rerank.rs new file mode 100644 index 0000000..dcf3b27 --- /dev/null +++ b/crates/umem_ai/src/response_generators/structured_rerank.rs @@ -0,0 +1,113 @@ +use crate::ResponseGeneratorError; +use serde::{de::DeserializeOwned, Serialize}; +use std::marker::PhantomData; + +pub async fn structured_rerank( + request: StructuredRerankRequest, +) -> Result, ResponseGeneratorError> +where + T: Serialize + DeserializeOwned, +{ + unimplemented!() +} + +#[derive(Clone)] +pub struct StructuredRerankRequest +where + T: Serialize + DeserializeOwned, +{ + pub query: String, + pub documents: Vec, + pub top_n: usize, + pub output_type: PhantomData, +} + +impl StructuredRerankRequest +where + T: Serialize + DeserializeOwned, +{ + pub fn builder() -> StructuredRerankRequestBuilder { + StructuredRerankRequestBuilder { + top_n: 5, + output_type: PhantomData, + query: None, + documents: vec![], + } + } +} + +pub struct StructuredRerankRequestBuilder +where + T: Serialize + DeserializeOwned, +{ + query: Option, + documents: Vec, + top_n: usize, + pub output_type: PhantomData, +} + +#[derive(thiserror::Error, Debug)] +pub enum RerankRequestBuilderError { + #[error("missing query from rerank request")] + MissingQuery, + #[error("at least one document is required in rerank request")] + EmptyDocuments, +} + +impl StructuredRerankRequestBuilder +where + T: Serialize + DeserializeOwned, +{ + pub fn query(mut self, query: impl Into) -> Self { + self.query = Some(query.into()); + self + } + + pub fn documents(mut self, documents: I) -> Self + where + I: IntoIterator, + { + self.documents.extend(documents); + self + } + + pub fn document(mut self, document: impl Into) -> Self { + self.documents.push(document.into()); + self + } + + pub fn top_k(mut self, top_n: usize) -> Self { + self.top_n = top_n; + self + } + + pub fn build(self) -> Result, RerankRequestBuilderError> { + if self.documents.is_empty() { + return Err(RerankRequestBuilderError::EmptyDocuments); + } + + Ok(StructuredRerankRequest { + query: self.query.ok_or(RerankRequestBuilderError::MissingQuery)?, + documents: self.documents, + top_n: self.top_n, + output_type: PhantomData, + }) + } +} + +pub struct StructuredRerankResponse +where + T: Serialize + DeserializeOwned, +{ + pub rankings: Vec>, + pub ranked_documents: Vec, +} + +pub struct StructuredRanking +where + T: Serialize + DeserializeOwned, +{ + original_index: usize, + score: f32, + document: T, +} From e6fd15e2e2be357d4588a5e3a9b7676f67edd43b Mon Sep 17 00:00:00 2001 From: Vidur Khanal Date: Tue, 13 Jan 2026 23:31:06 -0800 Subject: [PATCH 10/12] feat(rerank): add reranking support for Cohere and Amazon Bedrock providers --- Cargo.lock | 36 +++- crates/umem_ai/Cargo.toml | 5 +- crates/umem_ai/src/lib.rs | 72 +++++++- .../umem_ai/src/providers/amazon_bedrock.rs | 138 ++++++++++++-- crates/umem_ai/src/providers/anthropic.rs | 12 +- crates/umem_ai/src/providers/azure_openai.rs | 4 +- crates/umem_ai/src/providers/cohere.rs | 164 +++++++++++++++++ crates/umem_ai/src/providers/google_vertex.rs | 2 +- crates/umem_ai/src/providers/mod.rs | 4 +- crates/umem_ai/src/providers/openai.rs | 41 +++-- crates/umem_ai/src/providers/xai.rs | 2 +- .../response_generators/generate_object.rs | 6 +- .../src/response_generators/generate_text.rs | 4 +- crates/umem_ai/src/response_generators/mod.rs | 7 + .../umem_ai/src/response_generators/rerank.rs | 80 ++++++++- .../response_generators/structured_rerank.rs | 170 ++++++++++++++---- crates/umem_ai/src/utils.rs | 11 +- 17 files changed, 665 insertions(+), 93 deletions(-) create mode 100644 crates/umem_ai/src/providers/cohere.rs diff --git a/Cargo.lock b/Cargo.lock index 631ab99..f32204b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -262,6 +262,29 @@ dependencies = [ "uuid", ] +[[package]] +name = "aws-sdk-bedrockagentruntime" +version = "1.119.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d098831f9c542a92bb9458f21a9a8d11afa55a72e06d9bc117bb7353e58d9faa" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-eventstream", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "regex-lite", + "tracing", +] + [[package]] name = "aws-sdk-bedrockruntime" version = "1.120.0" @@ -3805,9 +3828,9 @@ dependencies = [ [[package]] name = "serde-saphyr" -version = "0.0.13" +version = "0.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be7f28bb35aab6f9cdc5f464f9bbda0b1e3908766e819557111cc13e58ce7915" +checksum = "45400dbf0a2c4c2af106c08eb1028b3e45bea3cd45517d3a0e8170b86597122f" dependencies = [ "ahash", "annotate-snippets", @@ -3815,11 +3838,11 @@ dependencies = [ "encoding_rs_io", "nohash-hasher", "num-traits", - "ryu", "saphyr-parser", "serde", "serde_json", "smallvec 2.0.0-alpha.12", + "zmij", ] [[package]] @@ -4933,6 +4956,7 @@ dependencies = [ "anyhow", "async-trait", "aws-config", + "aws-sdk-bedrockagentruntime", "aws-sdk-bedrockruntime", "aws-smithy-types", "backon", @@ -5849,3 +5873,9 @@ dependencies = [ "quote", "syn", ] + +[[package]] +name = "zmij" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd8f3f50b848df28f887acb68e41201b5aea6bc8a8dacc00fb40635ff9a72fea" diff --git a/crates/umem_ai/Cargo.toml b/crates/umem_ai/Cargo.toml index 5abd232..1392093 100644 --- a/crates/umem_ai/Cargo.toml +++ b/crates/umem_ai/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "umem_ai" version = "0.1.0" -edition = "2021" +edition = "2024" [dependencies] anyhow.workspace = true @@ -23,4 +23,5 @@ umem_config = {workspace = true} aws-config = { version = "1.1.7", features = ["behavior-version-latest"] } aws-sdk-bedrockruntime = "1.120.0" aws-smithy-types = {version = "1.3.5", features = ["serde-deserialize", "serde-serialize", "rt-tokio"]} -serde-saphyr = "0.0.13" +serde-saphyr = "0.0.14" +aws-sdk-bedrockagentruntime = "1.119.0" diff --git a/crates/umem_ai/src/lib.rs b/crates/umem_ai/src/lib.rs index dce868c..689b70c 100644 --- a/crates/umem_ai/src/lib.rs +++ b/crates/umem_ai/src/lib.rs @@ -2,18 +2,16 @@ #![allow(dead_code)] mod providers; mod response_generators; -pub(crate) mod utils; - -use std::sync::Arc; +mod utils; use anyhow::Result; use async_trait::async_trait; use lazy_static::lazy_static; -use schemars::JsonSchema; -use serde::{de::DeserializeOwned, Serialize}; - pub use providers::*; pub use response_generators::*; +use schemars::JsonSchema; +use serde::{Serialize, de::DeserializeOwned}; +use std::sync::Arc; use umem_config::CONFIG; pub type HashMap = rustc_hash::FxHashMap; @@ -40,6 +38,25 @@ impl LanguageModel { } } +#[derive(Clone)] +pub struct RerankingModel { + pub provider: Arc, + pub model_name: String, +} + +impl RerankingModel { + fn new(provider: Arc, model_name: String) -> Self { + Self { + provider, + model_name, + } + } + + pub fn get_model() -> Arc { + Arc::clone(&LANGUAGE_MODEL) + } +} + pub enum AIProvider { OpenAI(OpenAIProvider), AzureOpenAI(AzureOpenAIProvider), @@ -47,6 +64,7 @@ pub enum AIProvider { Anthropic(AnthropicProvider), XAI(XAIProvider), AmazonBedrock(AmazonBedrockProvider), + Cohere(CohereProvider), } lazy_static! { @@ -96,6 +114,30 @@ impl AIProvider { } .await } + + pub(crate) async fn do_reranking( + &self, + request: RerankRequest, + ) -> Result { + match self { + AIProvider::Cohere(provider) => provider.rerank(request), + _ => unimplemented!(), + } + .await + } + + pub(crate) async fn do_structured_reranking( + &self, + request: StructuredRerankRequest, + ) -> Result, ResponseGeneratorError> + where + T: Serialize + Clone + Send + Sync, + { + match self { + AIProvider::Cohere(provider) => provider.rerank_structured(request).await, + _ => unimplemented!(), + } + } } #[async_trait] @@ -114,6 +156,24 @@ pub trait GeneratesObject { ) -> Result, ResponseGeneratorError>; } +#[async_trait] +pub trait Reranks { + async fn rerank( + &self, + request: RerankRequest, + ) -> Result; +} + +#[async_trait] +pub trait ReranksStructuredData { + async fn rerank_structured( + &self, + request: StructuredRerankRequest, + ) -> Result, ResponseGeneratorError> + where + T: Serialize + Clone + Send + Sync; +} + impl From for AIProvider { fn from(config: OpenAIProvider) -> Self { AIProvider::OpenAI(config) diff --git a/crates/umem_ai/src/providers/amazon_bedrock.rs b/crates/umem_ai/src/providers/amazon_bedrock.rs index dabd3dc..86b2793 100644 --- a/crates/umem_ai/src/providers/amazon_bedrock.rs +++ b/crates/umem_ai/src/providers/amazon_bedrock.rs @@ -1,14 +1,22 @@ use crate::{ + GenerateObjectRequest, GenerateObjectResponse, GeneratesObject, GeneratesText, OpenAIProvider, + Ranking, RerankRequest, RerankResponse, Reranks, messages::{FilePart, UserModelMessage}, response_generators::{ self, GenerateTextRequest, GenerateTextResponse, ResponseGeneratorError, }, - utils, GenerateObjectRequest, GenerateObjectResponse, GeneratesObject, GeneratesText, - OpenAIProvider, + utils, }; +use anyhow::Result; use async_trait::async_trait; use aws_config::{BehaviorVersion, Region}; +use aws_sdk_bedrockagentruntime::types::{ + BedrockRerankingConfiguration, BedrockRerankingModelConfiguration, RerankDocument, + RerankDocumentType, RerankQuery, RerankQueryContentType, RerankSource, RerankSourceType, + RerankTextDocument, RerankingConfiguration, +}; use aws_sdk_bedrockruntime::{ + error::BuildError, operation::converse::builders::ConverseFluentBuilder, types::{ AnyToolChoice, ContentBlock, ConverseOutput, ImageBlock, InferenceConfiguration, Message, @@ -17,12 +25,14 @@ use aws_sdk_bedrockruntime::{ }; use base64::Engine; use schemars::JsonSchema; -use serde::{de::DeserializeOwned, Serialize}; +use serde::{Serialize, de::DeserializeOwned}; +use serde_json::Map; use std::sync::Arc; use thiserror::Error; pub struct AmazonBedrockProvider { - client: Arc, + bedrockruntime_client: Arc, + bedrockagentruntime_client: Arc, } #[async_trait] @@ -138,7 +148,7 @@ impl GeneratesObject for AmazonBedrockProvider { _ => { return Err(ResponseGeneratorError::InvalidProviderResponse( "was expecting the output to be a bedrock message".into(), - )) + )); } }; @@ -165,6 +175,107 @@ impl GeneratesObject for AmazonBedrockProvider { } } +#[async_trait] +impl Reranks for AmazonBedrockProvider { + async fn rerank( + &self, + request: RerankRequest, + ) -> Result { + let inline_sources: Vec = request + .documents + .iter() + .map(|document| { + RerankSource::builder() + .inline_document_source( + RerankDocument::builder() + .r#type(RerankDocumentType::Text) + .text_document(RerankTextDocument::builder().text(document).build()) + .build()?, + ) + .r#type(RerankSourceType::Inline) + .build() + }) + .collect::>() + .map_err(|e| { + ResponseGeneratorError::InvalidArgumentsProvided(format!( + "Failed to build Bedrock-compat Rerank Sources from provided documents, Error: {}", e + )) + })?; + + let response = self + .bedrockagentruntime_client + .rerank() + .queries( + RerankQuery::builder() + .r#type(RerankQueryContentType::Text) + .text_query(RerankTextDocument::builder().text(&request.query).build()) + .build() + .map_err(|e| ResponseGeneratorError::InvalidArgumentsProvided( + format!("Failed to build RerankQuery, Details: {}", e) + ))? + ) + .set_sources( + Some(inline_sources) + ) + .reranking_configuration( + RerankingConfiguration::builder() + .bedrock_reranking_configuration( + BedrockRerankingConfiguration::builder() + .model_configuration( + BedrockRerankingModelConfiguration::builder() + .model_arn(&request.model.model_name) + .build() + .map_err(|e| { + ResponseGeneratorError::InvalidArgumentsProvided( + format!("Failed to build BedrockRerankingModelConfiguration, Details: {}", e) + ) + })?, + ) + .number_of_results(request.top_n as i32) + .build(), + ) + .build() + .map_err(|e| { + ResponseGeneratorError::InvalidArgumentsProvided( + format!("Failed to build RerankingConfiguration, Details: {}", e) + ) + })?, + ) + .send() + .await + .map_err(|e| ResponseGeneratorError::BedrockAgentRerankCommandSendError(e.to_string()))?; + + let results = response.results(); + + let (rankings, ranked_documents) = results.iter().try_fold( + ( + Vec::with_capacity(results.len()), + Vec::with_capacity(results.len()), + ), + |(mut rankings, mut docs), result| { + let document = request.documents.get(result.index as usize).ok_or( + ResponseGeneratorError::InvalidProviderResponse( + "Bedrock Rerank API returned an invalid index".to_string(), + ), + )?; + docs.push(document.clone()); + rankings.push(Ranking { + original_index: result.index as usize, + score: result.relevance_score, + document: document.clone(), + }); + Ok::<_, ResponseGeneratorError>((rankings, docs)) + }, + )?; + + Ok(RerankResponse { + rankings, + ranked_documents, + raw_fields: Map::with_capacity(0_usize), + }) + } +} + impl AmazonBedrockProvider { fn normalize_generate_object_request< T: Clone + JsonSchema + Serialize + Send + Sync + DeserializeOwned, @@ -177,7 +288,7 @@ impl AmazonBedrockProvider { let output_schema_value = serde_json::to_value(&request.output_schema)?; Ok(self - .client + .bedrockruntime_client .converse() .model_id(request.model.model_name.clone()) .system(aws_sdk_bedrockruntime::types::SystemContentBlock::Text( @@ -214,7 +325,7 @@ impl AmazonBedrockProvider { let user_messages = Self::normalize_user_messages(&request.messages)?; Ok(self - .client + .bedrockruntime_client .converse() .model_id(request.model.model_name.clone()) .system(aws_sdk_bedrockruntime::types::SystemContentBlock::Text( @@ -302,6 +413,10 @@ impl AmazonBedrockProvider { Ok(user_message_content_blocks) } + + fn builder() -> AmazonBedrockProviderBuilder { + AmazonBedrockProviderBuilder::new() + } } #[derive(Default)] @@ -370,7 +485,10 @@ impl AmazonBedrockProviderBuilder { .await; Ok(AmazonBedrockProvider { - client: Arc::new(aws_sdk_bedrockruntime::Client::new(&sdk_config)), + bedrockruntime_client: Arc::new(aws_sdk_bedrockruntime::Client::new(&sdk_config)), + bedrockagentruntime_client: Arc::new(aws_sdk_bedrockagentruntime::Client::new( + &sdk_config, + )), }) } } @@ -378,8 +496,8 @@ impl AmazonBedrockProviderBuilder { #[cfg(test)] mod tests { use crate::{ - generate_object, generate_text, AIProvider, GenerateObjectRequestBuilder, - GenerateTextRequestBuilder, LanguageModel, + AIProvider, GenerateObjectRequestBuilder, GenerateTextRequestBuilder, LanguageModel, + generate_object, generate_text, }; use serde::Deserialize; use std::sync::Arc; diff --git a/crates/umem_ai/src/providers/anthropic.rs b/crates/umem_ai/src/providers/anthropic.rs index 7e694ab..2a20e85 100644 --- a/crates/umem_ai/src/providers/anthropic.rs +++ b/crates/umem_ai/src/providers/anthropic.rs @@ -1,8 +1,8 @@ +use crate::GeneratesText; +use crate::ResponseGeneratorError; use crate::response_generators::GenerateTextRequest; use crate::response_generators::GenerateTextResponse; use crate::utils; -use crate::GeneratesText; -use crate::ResponseGeneratorError; use anyhow::Result; use async_trait::async_trait; use reqwest::header::HeaderMap; @@ -19,9 +19,7 @@ pub struct AnthropicProvider { #[builder(default = "https://api.anthropic.com/v1".into())] pub base_url: String, - #[builder(default = HeaderMap::default(), setter(transform = |value: Vec<(String, String)>| - utils::build_header_map(value.as_slice()).unwrap_or_default() - ))] + #[builder(default = HeaderMap::default(), setter(transform = |value: Vec<(String, String)>| utils::build_header_map(value.as_slice()).unwrap_or_default()))] pub headers: HeaderMap, } @@ -41,7 +39,9 @@ mod tests { #[test] fn test_building_anthropic_provider() { - let provider = AnthropicProvider::builder().api_key("sk-some-api-key").build(); + let provider = AnthropicProvider::builder() + .api_key("sk-some-api-key") + .build(); dbg!("Anthropic Provider: {:?}", provider); } } diff --git a/crates/umem_ai/src/providers/azure_openai.rs b/crates/umem_ai/src/providers/azure_openai.rs index 62d841e..7f9ce28 100644 --- a/crates/umem_ai/src/providers/azure_openai.rs +++ b/crates/umem_ai/src/providers/azure_openai.rs @@ -1,8 +1,8 @@ use crate::{ - response_generators::{GenerateTextRequest, GenerateTextResponse, ResponseGeneratorError}, GeneratesText, + response_generators::{GenerateTextRequest, GenerateTextResponse, ResponseGeneratorError}, }; -use anyhow::{bail, Result}; +use anyhow::{Result, bail}; use async_trait::async_trait; pub struct AzureOpenAIProvider { pub resource_name: Option, diff --git a/crates/umem_ai/src/providers/cohere.rs b/crates/umem_ai/src/providers/cohere.rs new file mode 100644 index 0000000..de3ef04 --- /dev/null +++ b/crates/umem_ai/src/providers/cohere.rs @@ -0,0 +1,164 @@ +use crate::{ + Ranking, RerankRequest, RerankResponse, Reranks, ReranksStructuredData, ResponseGeneratorError, + SerializationFormat, SerializationMode, StructuredRanking, StructuredRerankRequest, + StructuredRerankResponse, reqwest_client, utils, +}; +use async_trait::async_trait; +use reqwest::header::HeaderMap; +use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value, json}; +use typed_builder::TypedBuilder; + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CohereRerankAPIV2Response { + pub results: Vec, + #[serde(flatten)] + pub raw_fields: Map, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CohereRerankResult { + pub index: usize, + pub relevance_score: f32, +} + +#[derive(TypedBuilder, Debug, Clone)] +pub struct CohereProvider { + #[builder(default = "https://api.cohere.com/v2".into(), setter(transform = |value: impl Into| value.into()))] + base_url: String, + + #[builder(setter(transform = |value: impl Into| value.into()))] + api_key: String, + + #[builder(default = HeaderMap::default(), setter(transform = |value: Vec<(String, String)>| + utils::build_header_map(value.as_slice()).unwrap_or_default() + ))] + headers: HeaderMap, +} + +#[async_trait] +impl Reranks for CohereProvider { + async fn rerank( + &self, + request: RerankRequest, + ) -> Result { + let response = reqwest_client + .post(format!("{}/rerank", self.base_url)) + .header("Content-Type", "application/json") + .header("Accept", "application/json") + .header("Authorization", &format!("bearer {}", &self.api_key)) + .json(&json!({ + "model": &request.model.model_name, + "query": &request.query, + "documents": &request.documents, + "top_n": request.top_n, + })) + .send() + .await? + .error_for_status()? + .json::() + .await?; + + let (rankings, ranked_documents) = response.results.iter().try_fold( + ( + Vec::with_capacity(response.results.len()), + Vec::with_capacity(response.results.len()), + ), + |(mut rankings, mut docs), result| { + let document = request.documents.get(result.index).ok_or( + ResponseGeneratorError::InvalidProviderResponse( + "Cohere returned an invalid index".to_string(), + ), + )?; + docs.push(document.clone()); + rankings.push(Ranking { + original_index: result.index, + score: result.relevance_score, + document: document.clone(), + }); + Ok::<_, ResponseGeneratorError>((rankings, docs)) + }, + )?; + + Ok(RerankResponse { + raw_fields: response.raw_fields, + rankings, + ranked_documents, + }) + } +} + +#[async_trait] +impl ReranksStructuredData for CohereProvider { + async fn rerank_structured( + &self, + request: StructuredRerankRequest, + ) -> Result, ResponseGeneratorError> + where + T: Serialize + Clone + Send + Sync, + { + let serialized_documents: Vec = request + .documents + .iter() + .map(|doc| { + match (request.serialization_mode, request.serialization_format) { + (SerializationMode::Json, SerializationFormat::Compact) => { + serde_json::to_string(doc).map_err(|e| e.to_string()) + } + (SerializationMode::Json, SerializationFormat::Pretty) => { + serde_json::to_string_pretty(doc).map_err(|e| e.to_string()) + } + (SerializationMode::Yaml, _) => { + serde_saphyr::to_string(doc).map_err(|e| e.to_string()) + } + } + .map_err(ResponseGeneratorError::StructuredRerankDocumentsSerializationError) + }) + .collect::, _>>()?; + + let response = reqwest_client + .post(format!("{}/rerank", self.base_url)) + .header("Content-Type", "application/json") + .header("Accept", "application/json") + .header("Authorization", &format!("bearer {}", &self.api_key)) + .headers(self.headers.clone()) + .json(&json!({ + "model": &request.model.model_name, + "query": &request.query, + "documents": &serialized_documents, + "top_n": request.top_n, + })) + .send() + .await? + .error_for_status()? + .json::() + .await?; + + let (rankings, ranked_documents) = response.results.iter().try_fold( + ( + Vec::with_capacity(response.results.len()), + Vec::with_capacity(response.results.len()), + ), + |(mut rankings, mut docs), result| { + let document = request.documents.get(result.index).ok_or( + ResponseGeneratorError::InvalidProviderResponse( + "Cohere returned an invalid index".to_string(), + ), + )?; + docs.push(document.clone()); + rankings.push(StructuredRanking { + original_index: result.index, + score: result.relevance_score, + document: document.clone(), + }); + Ok::<_, ResponseGeneratorError>((rankings, docs)) + }, + )?; + + Ok(StructuredRerankResponse { + rankings, + ranked_documents, + raw_fields: response.raw_fields, + }) + } +} diff --git a/crates/umem_ai/src/providers/google_vertex.rs b/crates/umem_ai/src/providers/google_vertex.rs index c117961..553496b 100644 --- a/crates/umem_ai/src/providers/google_vertex.rs +++ b/crates/umem_ai/src/providers/google_vertex.rs @@ -1,4 +1,4 @@ -use anyhow::{bail, Result}; +use anyhow::{Result, bail}; pub struct GoogleVertexAIProvider { pub project: String, diff --git a/crates/umem_ai/src/providers/mod.rs b/crates/umem_ai/src/providers/mod.rs index e996961..96567dc 100644 --- a/crates/umem_ai/src/providers/mod.rs +++ b/crates/umem_ai/src/providers/mod.rs @@ -1,6 +1,7 @@ mod amazon_bedrock; mod anthropic; mod azure_openai; +mod cohere; mod google_vertex; mod openai; mod xai; @@ -8,6 +9,7 @@ mod xai; pub use amazon_bedrock::AmazonBedrockProvider; pub use anthropic::AnthropicProvider; pub use azure_openai::AzureOpenAIProvider; +pub use cohere::CohereProvider; pub use google_vertex::GoogleVertexAIProvider; -pub use openai::*; +pub use openai::OpenAIProvider; pub use xai::XAIProvider; diff --git a/crates/umem_ai/src/providers/openai.rs b/crates/umem_ai/src/providers/openai.rs index 6cf2fc2..a59aa70 100644 --- a/crates/umem_ai/src/providers/openai.rs +++ b/crates/umem_ai/src/providers/openai.rs @@ -1,17 +1,17 @@ use crate::{ - reqwest_client, + GeneratesObject, GeneratesText, reqwest_client, response_generators::{ + GenerateTextRequest, GenerateTextResponse, ResponseGeneratorError, generate_object::{GenerateObjectRequest, GenerateObjectResponse}, messages::{FilePart, Message, UserMessagePart, UserModelMessage}, - GenerateTextRequest, GenerateTextResponse, ResponseGeneratorError, }, - utils, GeneratesObject, GeneratesText, + utils, }; use async_trait::async_trait; use base64::Engine; use reqwest::header::HeaderMap; use schemars::JsonSchema; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde_json::{Map, Value}; use typed_builder::TypedBuilder; @@ -23,9 +23,7 @@ pub struct OpenAIProvider { #[builder(default = "https://api.openai.com/v1".into(), setter(transform = |value: impl Into| value.into()))] pub base_url: String, - #[builder(default, setter(transform = |value: Vec<(String, String)>| - utils::build_header_map(value.as_slice()).unwrap_or_default() - ))] + #[builder(default, setter(transform = |value: Vec<(String, String)>| utils::build_header_map(value.as_slice()).unwrap_or_default()))] pub default_headers: HeaderMap, #[builder(default)] @@ -135,28 +133,28 @@ impl OpenAIProvider { serde_json::json!({"type":"input_text","text":input_text}) } UserMessagePart::Image(image_part) => match image_part { - FilePart::Url(ref image_url, _) => { + FilePart::Url(image_url, _) => { serde_json::json!({"type":"input_image","image_url":image_url}) } - FilePart::Base64(ref b64,ref media_type) => { + FilePart::Base64(b64, media_type) => { let media_type = media_type.clone().unwrap_or(mime::IMAGE_PNG); serde_json::json!({"type":"input_image","image_url":format!("data:{};base64,{}", media_type.to_string(), b64)}) } - FilePart::Buffer(buf,ref media_type) => { + FilePart::Buffer(buf, media_type) => { let buf_as_b64 = base64::engine::general_purpose::STANDARD.encode(buf); let media_type = media_type.clone().unwrap_or(mime::IMAGE_PNG); serde_json::json!({"type":"input_image","image_url":format!("data:{};base64,{}", media_type.to_string(), buf_as_b64)}) }, }, - UserMessagePart::File(file_part) => match file_part{ - FilePart::Url(ref image_url, _) => { + UserMessagePart::File(file_part) => match file_part { + FilePart::Url(image_url, _) => { serde_json::json!({"type":"input_file","file_url":image_url}) } - FilePart::Base64(ref b64,ref media_type) => { + FilePart::Base64(b64, media_type) => { let media_type = media_type.clone().unwrap_or(mime::IMAGE_PNG); serde_json::json!({"type":"input_file","file_url":format!("data:{};base64,{}", media_type.to_string(), b64)}) } - FilePart::Buffer(buf,ref media_type) => { + FilePart::Buffer(buf, media_type) => { let buf_as_b64 = base64::engine::general_purpose::STANDARD.encode(buf); let media_type = media_type.clone().unwrap_or(mime::IMAGE_PNG); serde_json::json!({"type":"input_file","file_url":format!("data:{};base64,{}", media_type.to_string(), buf_as_b64)}) @@ -264,8 +262,8 @@ impl GeneratesObject for OpenAIProvider { }) .unwrap_or_default(); - let output: T = - serde_json::from_str(&output_text).map_err(|e| ResponseGeneratorError::Deserialization(e, output_text))?; + let output: T = serde_json::from_str(&output_text) + .map_err(|e| ResponseGeneratorError::Deserialization(e, output_text))?; Ok(GenerateObjectResponse { output }) } @@ -347,11 +345,12 @@ mod tests { use super::*; use crate::{ + AIProvider, LanguageModel, response_generators::{ - generate_object::{generate_object, GenerateObjectRequestBuilder}, - generate_text, GenerateTextRequestBuilder, + GenerateTextRequestBuilder, + generate_object::{GenerateObjectRequestBuilder, generate_object}, + generate_text, }, - AIProvider, LanguageModel, }; use std::sync::Arc; @@ -361,7 +360,7 @@ mod tests { OpenAIProvider::builder() .api_key("") .base_url("https://openrouter.ai/api/v1") - .build() + .build(), )); #[derive(Clone, JsonSchema, Serialize, Deserialize, Debug)] @@ -395,7 +394,7 @@ mod tests { OpenAIProvider::builder() .api_key("") .base_url("https://openrouter.ai/api/v1") - .build() + .build(), )); let model = Arc::new(LanguageModel { diff --git a/crates/umem_ai/src/providers/xai.rs b/crates/umem_ai/src/providers/xai.rs index 3610a29..cc9d834 100644 --- a/crates/umem_ai/src/providers/xai.rs +++ b/crates/umem_ai/src/providers/xai.rs @@ -1,4 +1,4 @@ -use anyhow::{bail, Result}; +use anyhow::{Result, bail}; pub struct XAIProvider { pub api_key: String, diff --git a/crates/umem_ai/src/response_generators/generate_object.rs b/crates/umem_ai/src/response_generators/generate_object.rs index f74454d..33b6e5b 100644 --- a/crates/umem_ai/src/response_generators/generate_object.rs +++ b/crates/umem_ai/src/response_generators/generate_object.rs @@ -1,9 +1,9 @@ +use crate::{LanguageModel, ResponseGeneratorError, utils}; use crate::{response_generators::messages::Message, utils::is_retryable_error}; -use crate::{utils, LanguageModel, ResponseGeneratorError}; use backon::{ExponentialBuilder, Retryable}; use reqwest::header::HeaderMap; -use schemars::{schema_for, JsonSchema, Schema}; -use serde::{de::DeserializeOwned, Serialize}; +use schemars::{JsonSchema, Schema, schema_for}; +use serde::{Serialize, de::DeserializeOwned}; use std::time::Duration; use std::{marker::PhantomData, sync::Arc}; use thiserror::Error; diff --git a/crates/umem_ai/src/response_generators/generate_text.rs b/crates/umem_ai/src/response_generators/generate_text.rs index 6ddd09b..64446fa 100644 --- a/crates/umem_ai/src/response_generators/generate_text.rs +++ b/crates/umem_ai/src/response_generators/generate_text.rs @@ -1,8 +1,8 @@ +use crate::LanguageModel; +use crate::ResponseGeneratorError; use crate::response_generators::messages::Message; use crate::utils; use crate::utils::is_retryable_error; -use crate::LanguageModel; -use crate::ResponseGeneratorError; use backon::ExponentialBuilder; use backon::Retryable; use reqwest::header::HeaderMap; diff --git a/crates/umem_ai/src/response_generators/mod.rs b/crates/umem_ai/src/response_generators/mod.rs index c00470b..b967ef0 100644 --- a/crates/umem_ai/src/response_generators/mod.rs +++ b/crates/umem_ai/src/response_generators/mod.rs @@ -8,6 +8,7 @@ pub use generate_object::*; pub use generate_text::*; pub use messages::*; pub use rerank::*; +pub use structured_rerank::*; use thiserror::Error; #[derive(Error, Debug)] @@ -20,10 +21,16 @@ pub enum ResponseGeneratorError { TimeoutError(#[from] tokio::time::error::Elapsed), #[error("Bedrock Converse API error, Details: {0}")] BedrockConverseError(String), + #[error("BedrockAgentRuntime Rerank Command error, Details: {0}")] + BedrockAgentRerankCommandSendError(String), #[error("empty response from AI provider")] EmptyProviderResponse, #[error("invalid response from AI provider, Details: {0}")] InvalidProviderResponse(String), + #[error("invalid arguments provided: {0}")] + InvalidArgumentsProvided(String), #[error(transparent)] Transient(#[from] anyhow::Error), + #[error("yaml serialization error: {0}")] + StructuredRerankDocumentsSerializationError(String), } diff --git a/crates/umem_ai/src/response_generators/rerank.rs b/crates/umem_ai/src/response_generators/rerank.rs index eacdb81..b9bc260 100644 --- a/crates/umem_ai/src/response_generators/rerank.rs +++ b/crates/umem_ai/src/response_generators/rerank.rs @@ -1,13 +1,49 @@ -use crate::ResponseGeneratorError; +use backon::{ExponentialBuilder, Retryable}; +use serde_json::{Map, Value}; + +use crate::{RerankingModel, ResponseGeneratorError, utils::is_retryable_error}; +use std::{sync::Arc, time::Duration}; pub async fn rerank(request: RerankRequest) -> Result { - unimplemented!() + let per_request_timeout = request.timeout; + let max_retries = request.max_retries; + let total_delay = per_request_timeout.mul_f32(max_retries as f32 / 2.0); + + let reranking_request = || { + let model = Arc::clone(&request.model); + let provider = Arc::clone(&model.provider); + let request = request.clone(); + + async move { + tokio::time::timeout(per_request_timeout, provider.do_reranking(request)) + .await + .map_err(ResponseGeneratorError::TimeoutError) + .flatten() + } + }; + + reranking_request + .retry( + ExponentialBuilder::default() + .with_max_times(max_retries) + .with_total_delay(Some(total_delay)), + ) + .sleep(tokio::time::sleep) + .when(is_retryable_error) + .notify(|err, dur| { + tracing::debug!("retrying {:?} after {:?}", err, dur); + }) + .await } +#[derive(Clone)] pub struct RerankRequest { pub query: String, pub documents: Vec, pub top_n: usize, + pub timeout: Duration, + pub max_retries: usize, + pub model: Arc, } impl RerankRequest { @@ -19,19 +55,38 @@ impl RerankRequest { } } -#[derive(Default)] pub struct RerankRequestBuilder { query: Option, documents: Vec, top_n: usize, + timeout: Duration, + max_retries: usize, + model: Option>, +} + +impl Default for RerankRequestBuilder { + fn default() -> Self { + Self { + query: None, + documents: vec![], + top_n: 5, + timeout: Duration::from_secs(30), + max_retries: 3, + model: None, + } + } } #[derive(thiserror::Error, Debug)] pub enum RerankRequestBuilderError { #[error("missing query from rerank request")] MissingQuery, + #[error("at least one document is required in rerank request")] EmptyDocuments, + + #[error("missing model while sending reranking request")] + MissingModel, } impl RerankRequestBuilder { @@ -58,6 +113,21 @@ impl RerankRequestBuilder { self } + pub fn timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; + self + } + + pub fn max_retries(mut self, max_retries: usize) -> Self { + self.max_retries = max_retries; + self + } + + pub fn model(mut self, model: Arc) -> Self { + self.model = Some(model); + self + } + pub fn build(self) -> Result { if self.documents.is_empty() { return Err(RerankRequestBuilderError::EmptyDocuments); @@ -67,6 +137,9 @@ impl RerankRequestBuilder { query: self.query.ok_or(RerankRequestBuilderError::MissingQuery)?, documents: self.documents, top_n: self.top_n, + timeout: self.timeout, + max_retries: self.max_retries, + model: Arc::clone(&self.model.ok_or(RerankRequestBuilderError::MissingModel)?), }) } } @@ -74,6 +147,7 @@ impl RerankRequestBuilder { pub struct RerankResponse { pub rankings: Vec, pub ranked_documents: Vec, + pub raw_fields: Map, } pub struct Ranking { diff --git a/crates/umem_ai/src/response_generators/structured_rerank.rs b/crates/umem_ai/src/response_generators/structured_rerank.rs index dcf3b27..4f0df7a 100644 --- a/crates/umem_ai/src/response_generators/structured_rerank.rs +++ b/crates/umem_ai/src/response_generators/structured_rerank.rs @@ -1,62 +1,136 @@ -use crate::ResponseGeneratorError; -use serde::{de::DeserializeOwned, Serialize}; -use std::marker::PhantomData; +use std::{sync::Arc, time::Duration}; + +use backon::{ExponentialBuilder, Retryable}; +use serde::{Serialize, de::DeserializeOwned}; +use serde_json::{Map, Value}; + +use crate::{RerankingModel, ResponseGeneratorError, utils::is_retryable_error}; + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub enum SerializationFormat { + #[default] + Compact, + Pretty, +} + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub enum SerializationMode { + #[default] + Json, + Yaml, +} pub async fn structured_rerank( request: StructuredRerankRequest, ) -> Result, ResponseGeneratorError> where - T: Serialize + DeserializeOwned, + T: Serialize + Clone + Send + Sync, { - unimplemented!() + let per_request_timeout = request.timeout; + let max_retries = request.max_retries; + let total_delay = per_request_timeout.mul_f32(max_retries as f32 / 2.0); + + let reranking_request = || { + let model = Arc::clone(&request.model); + let provider = Arc::clone(&model.provider); + let request = request.clone(); + + async move { + tokio::time::timeout( + per_request_timeout, + provider.do_structured_reranking(request), + ) + .await + .map_err(ResponseGeneratorError::TimeoutError) + .flatten() + } + }; + + reranking_request + .retry( + ExponentialBuilder::default() + .with_max_times(max_retries) + .with_total_delay(Some(total_delay)), + ) + .sleep(tokio::time::sleep) + .when(is_retryable_error) + .notify(|err, dur| { + tracing::debug!("retrying {:?} after {:?}", err, dur); + }) + .await } #[derive(Clone)] pub struct StructuredRerankRequest where - T: Serialize + DeserializeOwned, + T: Serialize + Clone, { pub query: String, pub documents: Vec, pub top_n: usize, - pub output_type: PhantomData, + pub timeout: Duration, + pub max_retries: usize, + pub model: Arc, + pub serialization_format: SerializationFormat, + pub serialization_mode: SerializationMode, } impl StructuredRerankRequest where - T: Serialize + DeserializeOwned, + T: Serialize + Clone, { pub fn builder() -> StructuredRerankRequestBuilder { - StructuredRerankRequestBuilder { - top_n: 5, - output_type: PhantomData, - query: None, - documents: vec![], - } + StructuredRerankRequestBuilder::default() } } pub struct StructuredRerankRequestBuilder where - T: Serialize + DeserializeOwned, + T: Serialize + Clone, { query: Option, documents: Vec, top_n: usize, - pub output_type: PhantomData, + timeout: Duration, + max_retries: usize, + model: Option>, + serialization_format: SerializationFormat, + serialization_mode: SerializationMode, +} + +impl Default for StructuredRerankRequestBuilder +where + T: Serialize + Clone, +{ + fn default() -> Self { + Self { + query: None, + documents: vec![], + top_n: 5, + timeout: Duration::from_secs(30), + max_retries: 3, + model: None, + serialization_format: SerializationFormat::default(), + serialization_mode: SerializationMode::default(), + } + } } #[derive(thiserror::Error, Debug)] -pub enum RerankRequestBuilderError { - #[error("missing query from rerank request")] +pub enum StructuredRerankRequestBuilderError { + #[error("missing query from structured rerank request")] MissingQuery, - #[error("at least one document is required in rerank request")] + + #[error("at least one document is required in structured rerank request")] EmptyDocuments, + + #[error("missing model while sending structured reranking request")] + MissingModel, } impl StructuredRerankRequestBuilder where - T: Serialize + DeserializeOwned, + T: Serialize + DeserializeOwned + Clone, { pub fn query(mut self, query: impl Into) -> Self { self.query = Some(query.into()); @@ -71,8 +145,8 @@ where self } - pub fn document(mut self, document: impl Into) -> Self { - self.documents.push(document.into()); + pub fn document(mut self, document: T) -> Self { + self.documents.push(document); self } @@ -81,33 +155,67 @@ where self } - pub fn build(self) -> Result, RerankRequestBuilderError> { + pub fn timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; + self + } + + pub fn max_retries(mut self, max_retries: usize) -> Self { + self.max_retries = max_retries; + self + } + + pub fn model(mut self, model: Arc) -> Self { + self.model = Some(model); + self + } + + pub fn serialization_format(mut self, format: SerializationFormat) -> Self { + self.serialization_format = format; + self + } + + pub fn serialization_mode(mut self, mode: SerializationMode) -> Self { + self.serialization_mode = mode; + self + } + + pub fn build(self) -> Result, StructuredRerankRequestBuilderError> { if self.documents.is_empty() { - return Err(RerankRequestBuilderError::EmptyDocuments); + return Err(StructuredRerankRequestBuilderError::EmptyDocuments); } Ok(StructuredRerankRequest { - query: self.query.ok_or(RerankRequestBuilderError::MissingQuery)?, + query: self + .query + .ok_or(StructuredRerankRequestBuilderError::MissingQuery)?, documents: self.documents, top_n: self.top_n, - output_type: PhantomData, + timeout: self.timeout, + max_retries: self.max_retries, + model: self + .model + .ok_or(StructuredRerankRequestBuilderError::MissingModel)?, + serialization_format: self.serialization_format, + serialization_mode: self.serialization_mode, }) } } pub struct StructuredRerankResponse where - T: Serialize + DeserializeOwned, + T: Serialize + Clone, { pub rankings: Vec>, pub ranked_documents: Vec, + pub raw_fields: Map, } pub struct StructuredRanking where - T: Serialize + DeserializeOwned, + T: Serialize + Clone, { - original_index: usize, - score: f32, - document: T, + pub original_index: usize, + pub score: f32, + pub document: T, } diff --git a/crates/umem_ai/src/utils.rs b/crates/umem_ai/src/utils.rs index 814b915..a6161d8 100644 --- a/crates/umem_ai/src/utils.rs +++ b/crates/umem_ai/src/utils.rs @@ -20,7 +20,8 @@ pub fn is_retryable_error(e: &ResponseGeneratorError) -> bool { ResponseGeneratorError::Deserialization(error, response) => { tracing::warn!( "Serialization error, AI Might have built a bad JSON output: \n Error: {} \n Received Response: {}", - error,response + error, + response ); true } @@ -53,6 +54,14 @@ pub fn is_retryable_error(e: &ResponseGeneratorError) -> bool { ); true } + ResponseGeneratorError::YamlSerializationError(e) => { + tracing::error!("YAML serialization error: {}", e); + false + } + ResponseGeneratorError::InvalidArgumentsProvided(e) => { + tracing::error!("Invalid arguments provided: {}", e); + false + } } } From 608498cb219dee0714ff06fc765de41c10e37d47 Mon Sep 17 00:00:00 2001 From: Vidur Khanal Date: Tue, 13 Jan 2026 23:50:43 -0800 Subject: [PATCH 11/12] feat: support structured rerank on amazon bedrock --- crates/umem_ai/src/lib.rs | 2 + .../umem_ai/src/providers/amazon_bedrock.rs | 158 +++++++++++++++++- crates/umem_ai/src/utils.rs | 8 +- 3 files changed, 165 insertions(+), 3 deletions(-) diff --git a/crates/umem_ai/src/lib.rs b/crates/umem_ai/src/lib.rs index 689b70c..4a82a5d 100644 --- a/crates/umem_ai/src/lib.rs +++ b/crates/umem_ai/src/lib.rs @@ -121,6 +121,7 @@ impl AIProvider { ) -> Result { match self { AIProvider::Cohere(provider) => provider.rerank(request), + AIProvider::AmazonBedrock(provider) => provider.rerank(request), _ => unimplemented!(), } .await @@ -135,6 +136,7 @@ impl AIProvider { { match self { AIProvider::Cohere(provider) => provider.rerank_structured(request).await, + AIProvider::AmazonBedrock(provider) => provider.rerank_structured(request).await, _ => unimplemented!(), } } diff --git a/crates/umem_ai/src/providers/amazon_bedrock.rs b/crates/umem_ai/src/providers/amazon_bedrock.rs index 86b2793..1492eb9 100644 --- a/crates/umem_ai/src/providers/amazon_bedrock.rs +++ b/crates/umem_ai/src/providers/amazon_bedrock.rs @@ -1,6 +1,7 @@ use crate::{ GenerateObjectRequest, GenerateObjectResponse, GeneratesObject, GeneratesText, OpenAIProvider, - Ranking, RerankRequest, RerankResponse, Reranks, + Ranking, RerankRequest, RerankResponse, Reranks, ReranksStructuredData, SerializationMode, + StructuredRanking, StructuredRerankRequest, StructuredRerankResponse, messages::{FilePart, UserModelMessage}, response_generators::{ self, GenerateTextRequest, GenerateTextResponse, ResponseGeneratorError, @@ -276,6 +277,161 @@ impl Reranks for AmazonBedrockProvider { } } +#[async_trait] +impl ReranksStructuredData for AmazonBedrockProvider { + async fn rerank_structured( + &self, + request: StructuredRerankRequest, + ) -> Result, ResponseGeneratorError> + where + T: Serialize + Clone + Send + Sync, + { + let inline_sources: Vec = match request.serialization_mode { + SerializationMode::Json => { + let json_documents: Vec = request + .documents + .iter() + .map(|doc| { + serde_json::to_value(doc) + .map(utils::json_to_aws_smithy_document) + .map_err(|e| { + ResponseGeneratorError::StructuredRerankDocumentsSerializationError( + e.to_string(), + ) + }) + }) + .collect::, _>>()?; + + json_documents + .into_iter() + .map(|document| { + RerankSource::builder() + .inline_document_source( + RerankDocument::builder() + .r#type(RerankDocumentType::Json) + .json_document(document) + .build()?, + ) + .r#type(RerankSourceType::Inline) + .build() + }) + .collect::>() + } + SerializationMode::Yaml => { + let yaml_documents: Vec = request + .documents + .iter() + .map(|doc| { + serde_saphyr::to_string(doc).map_err(|e| { + ResponseGeneratorError::StructuredRerankDocumentsSerializationError( + e.to_string(), + ) + }) + }) + .collect::, _>>()?; + + yaml_documents + .iter() + .map(|document| { + RerankSource::builder() + .inline_document_source( + RerankDocument::builder() + .r#type(RerankDocumentType::Text) + .text_document( + RerankTextDocument::builder().text(document).build(), + ) + .build()?, + ) + .r#type(RerankSourceType::Inline) + .build() + }) + .collect::>() + } + } + .map_err(|e| { + ResponseGeneratorError::InvalidArgumentsProvided(format!( + "Failed to build Bedrock-compat Rerank Sources from provided documents, Error: {}", + e + )) + })?; + + let response = self + .bedrockagentruntime_client + .rerank() + .queries( + RerankQuery::builder() + .r#type(RerankQueryContentType::Text) + .text_query(RerankTextDocument::builder().text(&request.query).build()) + .build() + .map_err(|e| { + ResponseGeneratorError::InvalidArgumentsProvided(format!( + "Failed to build RerankQuery, Details: {}", + e + )) + })?, + ) + .set_sources(Some(inline_sources)) + .reranking_configuration( + RerankingConfiguration::builder() + .bedrock_reranking_configuration( + BedrockRerankingConfiguration::builder() + .model_configuration( + BedrockRerankingModelConfiguration::builder() + .model_arn(&request.model.model_name) + .build() + .map_err(|e| { + ResponseGeneratorError::InvalidArgumentsProvided(format!( + "Failed to build BedrockRerankingModelConfiguration, Details: {}", + e + )) + })?, + ) + .number_of_results(request.top_n as i32) + .build(), + ) + .build() + .map_err(|e| { + ResponseGeneratorError::InvalidArgumentsProvided(format!( + "Failed to build RerankingConfiguration, Details: {}", + e + )) + })?, + ) + .send() + .await + .map_err(|e| ResponseGeneratorError::BedrockAgentRerankCommandSendError(e.to_string()))?; + + let results = response.results(); + + let (rankings, ranked_documents) = results.iter().try_fold( + ( + Vec::with_capacity(results.len()), + Vec::with_capacity(results.len()), + ), + |(mut rankings, mut docs), result| { + let document = request.documents.get(result.index as usize).ok_or( + ResponseGeneratorError::InvalidProviderResponse( + "Bedrock Rerank API returned an invalid index".to_string(), + ), + )?; + docs.push(document.clone()); + rankings.push(StructuredRanking { + original_index: result.index as usize, + score: result.relevance_score, + document: document.clone(), + }); + Ok::<_, ResponseGeneratorError>((rankings, docs)) + }, + )?; + + Ok(StructuredRerankResponse { + rankings, + ranked_documents, + raw_fields: Map::with_capacity(0_usize), + }) + } +} + impl AmazonBedrockProvider { fn normalize_generate_object_request< T: Clone + JsonSchema + Serialize + Send + Sync + DeserializeOwned, diff --git a/crates/umem_ai/src/utils.rs b/crates/umem_ai/src/utils.rs index a6161d8..8edd6d1 100644 --- a/crates/umem_ai/src/utils.rs +++ b/crates/umem_ai/src/utils.rs @@ -54,14 +54,18 @@ pub fn is_retryable_error(e: &ResponseGeneratorError) -> bool { ); true } - ResponseGeneratorError::YamlSerializationError(e) => { - tracing::error!("YAML serialization error: {}", e); + ResponseGeneratorError::StructuredRerankDocumentsSerializationError(e) => { + tracing::error!("Structured rerank documents serialization error: {}", e); false } ResponseGeneratorError::InvalidArgumentsProvided(e) => { tracing::error!("Invalid arguments provided: {}", e); false } + ResponseGeneratorError::BedrockAgentRerankCommandSendError(e) => { + tracing::error!("Bedrock agent rerank command send error: {}", e); + true + } } } From b1606ab664f6b0ebc606484cb986471f369a6734 Mon Sep 17 00:00:00 2001 From: Vidur Khanal Date: Mon, 19 Jan 2026 20:25:24 -0800 Subject: [PATCH 12/12] feat: add Debug trait to AI provider and model structs, refactor rerank API to use top_k - Derive Debug for AIProvider, RerankingModel, and provider structs (AmazonBedrockProvider, AzureOpenAIProvider, GoogleVertexAIProvider, GoogleCredentials, XAIProvider, StructuredRerankRequest, StructuredRerankResponse, StructuredRanking, RerankResponse, Ranking) - Refactor rerank API to use `top_k` instead of `top_n` across request structs, builders, and provider implementations - Update Amazon Bedrock provider to require region, fix model ARN construction, and add new rerank/structured_rerank tests - Update Cohere provider to use `top_k` for rerank requests --- crates/umem_ai/src/lib.rs | 3 +- .../umem_ai/src/providers/amazon_bedrock.rs | 178 +++++++++++++++++- crates/umem_ai/src/providers/azure_openai.rs | 2 + crates/umem_ai/src/providers/cohere.rs | 2 +- crates/umem_ai/src/providers/google_vertex.rs | 2 + crates/umem_ai/src/providers/xai.rs | 1 + .../umem_ai/src/response_generators/rerank.rs | 16 +- .../response_generators/structured_rerank.rs | 4 +- 8 files changed, 190 insertions(+), 18 deletions(-) diff --git a/crates/umem_ai/src/lib.rs b/crates/umem_ai/src/lib.rs index 4a82a5d..8c72b2e 100644 --- a/crates/umem_ai/src/lib.rs +++ b/crates/umem_ai/src/lib.rs @@ -38,7 +38,7 @@ impl LanguageModel { } } -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct RerankingModel { pub provider: Arc, pub model_name: String, @@ -57,6 +57,7 @@ impl RerankingModel { } } +#[derive(Debug)] pub enum AIProvider { OpenAI(OpenAIProvider), AzureOpenAI(AzureOpenAIProvider), diff --git a/crates/umem_ai/src/providers/amazon_bedrock.rs b/crates/umem_ai/src/providers/amazon_bedrock.rs index 1492eb9..a15f618 100644 --- a/crates/umem_ai/src/providers/amazon_bedrock.rs +++ b/crates/umem_ai/src/providers/amazon_bedrock.rs @@ -14,7 +14,7 @@ use aws_config::{BehaviorVersion, Region}; use aws_sdk_bedrockagentruntime::types::{ BedrockRerankingConfiguration, BedrockRerankingModelConfiguration, RerankDocument, RerankDocumentType, RerankQuery, RerankQueryContentType, RerankSource, RerankSourceType, - RerankTextDocument, RerankingConfiguration, + RerankTextDocument, RerankingConfiguration, RerankingConfigurationType, }; use aws_sdk_bedrockruntime::{ error::BuildError, @@ -31,7 +31,9 @@ use serde_json::Map; use std::sync::Arc; use thiserror::Error; +#[derive(Clone, Debug)] pub struct AmazonBedrockProvider { + region: Region, bedrockruntime_client: Arc, bedrockagentruntime_client: Arc, } @@ -220,11 +222,12 @@ impl Reranks for AmazonBedrockProvider { ) .reranking_configuration( RerankingConfiguration::builder() + .r#type(RerankingConfigurationType::BedrockRerankingModel) .bedrock_reranking_configuration( BedrockRerankingConfiguration::builder() .model_configuration( BedrockRerankingModelConfiguration::builder() - .model_arn(&request.model.model_name) + .model_arn(format!("arn:aws:bedrock:{}::foundation-model/{}",&self.region, &request.model.model_name)) .build() .map_err(|e| { ResponseGeneratorError::InvalidArgumentsProvided( @@ -232,7 +235,7 @@ impl Reranks for AmazonBedrockProvider { ) })?, ) - .number_of_results(request.top_n as i32) + .number_of_results(request.top_k as i32) .build(), ) .build() @@ -373,11 +376,15 @@ impl ReranksStructuredData for AmazonBedrockProvider { .set_sources(Some(inline_sources)) .reranking_configuration( RerankingConfiguration::builder() + .r#type(RerankingConfigurationType::BedrockRerankingModel) .bedrock_reranking_configuration( BedrockRerankingConfiguration::builder() .model_configuration( BedrockRerankingModelConfiguration::builder() - .model_arn(&request.model.model_name) + .model_arn(format!( + "arn:aws:bedrock:{}::foundation-model/{}", + &self.region, &request.model.model_name + )) .build() .map_err(|e| { ResponseGeneratorError::InvalidArgumentsProvided(format!( @@ -622,7 +629,7 @@ impl AmazonBedrockProviderBuilder { pub async fn build(self) -> Result { let sdk_config = aws_config::defaults(BehaviorVersion::latest()) - .region(self.region) + .region(self.region.clone()) .credentials_provider( aws_sdk_bedrockruntime::config::Credentials::builder() .access_key_id( @@ -641,6 +648,9 @@ impl AmazonBedrockProviderBuilder { .await; Ok(AmazonBedrockProvider { + region: self + .region + .ok_or(AmazonBedrockProviderBuilderError::MissingRegion)?, bedrockruntime_client: Arc::new(aws_sdk_bedrockruntime::Client::new(&sdk_config)), bedrockagentruntime_client: Arc::new(aws_sdk_bedrockagentruntime::Client::new( &sdk_config, @@ -651,15 +661,15 @@ impl AmazonBedrockProviderBuilder { #[cfg(test)] mod tests { + use super::*; use crate::{ AIProvider, GenerateObjectRequestBuilder, GenerateTextRequestBuilder, LanguageModel, - generate_object, generate_text, + RerankingModel, SerializationFormat, generate_object, generate_text, rerank, + structured_rerank, }; use serde::Deserialize; use std::sync::Arc; - use super::*; - #[tokio::test] async fn test_bedrock_generate_object() { let provider = Arc::new(AIProvider::from( @@ -727,4 +737,156 @@ mod tests { dbg!(&generate_text_response); } + + #[tokio::test] + async fn test_rerank() { + let provider = Arc::new(AIProvider::from( + AmazonBedrockProviderBuilder::default() + .region("REGION") + .access_key_id("ACESS_KEY_ID") + .secret_access_key("SECRET_ACCESS_KEY") + .build() + .await + .unwrap(), + )); + + let model = Arc::new(RerankingModel { + provider, + model_name: "cohere.rerank-v3-5:0".to_string(), + }); + + let request = RerankRequest::builder() + .model(model) + .document("Stock markets reached record highs today as investors reacted positively to economic data.") + .document("The local sports team won their championship game in a thrilling overtime victory.") + .document("A new cafe opened downtown, offering a variety of artisanal coffees and pastries.") + .document("Scientists have discovered a new species of bird in the remote rainforests of the Amazon.") + .document("The city council has approved a new plan to improve public transportation and reduce traffic congestion.") + .document("Researchers develop more efficient solar panel technology.") + .query("environmental sustainability initiatives") + .top_k(6) + .build() + .unwrap(); + + let rerank_response = rerank(request).await.unwrap(); + dbg!(&rerank_response); + } + + #[tokio::test] + async fn test_structured_rerank() { + let provider = Arc::new(AIProvider::from( + AmazonBedrockProviderBuilder::default() + .region("REGION") + .access_key_id("ACESS_KEY_ID") + .secret_access_key("SECRET_ACCESS_KEY") + .build() + .await + .unwrap(), + )); + + let model = Arc::new(RerankingModel { + provider, + model_name: "cohere.rerank-v3-5:0".to_string(), + }); + + #[derive(Debug, Clone, Serialize, Deserialize)] + struct Document { + id: String, + content: String, + metadata: DocumentMetadata, + } + + #[derive(Debug, Clone, Serialize, Deserialize)] + struct DocumentMetadata { + category: String, + author: Option, + timestamp: Option, + } + + let request = StructuredRerankRequest::builder() + .model(model) + .serialization_format(SerializationFormat::Compact) + .serialization_mode(SerializationMode::Json) + .documents(vec![ + Document { + id: "doc_1".to_string(), + content: "Python is a high-level programming language known for its simplicity and readability.".to_string(), + metadata: DocumentMetadata { + category: "Programming Languages".to_string(), + author: Some("Tech Writer".to_string()), + timestamp: Some("2024-01-15".to_string()), + }, + }, + Document { + id: "doc_2".to_string(), + content: "JavaScript is primarily used for web development and runs in browsers.".to_string(), + metadata: DocumentMetadata { + category: "Programming Languages".to_string(), + author: Some("Tech Writer".to_string()), + timestamp: Some("2024-01-16".to_string()), + }, + }, + Document { + id: "doc_3".to_string(), + content: "Machine learning models require large datasets for training.".to_string(), + metadata: DocumentMetadata { + category: "Machine Learning".to_string(), + author: Some("Data Scientist".to_string()), + timestamp: Some("2024-01-17".to_string()), + }, + }, + Document { + id: "doc_4".to_string(), + content: "Python's pandas library is excellent for data manipulation and analysis.".to_string(), + metadata: DocumentMetadata { + category: "Data Analysis".to_string(), + author: Some("Data Analyst".to_string()), + timestamp: Some("2024-01-18".to_string()), + }, + }, + Document { + id: "doc_5".to_string(), + content: "The React framework is built on top of JavaScript for building user interfaces.".to_string(), + metadata: DocumentMetadata { + category: "Web Development".to_string(), + author: Some("Frontend Dev".to_string()), + timestamp: Some("2024-01-19".to_string()), + }, + }, + Document { + id: "doc_6".to_string(), + content: "Deep learning is a subset of machine learning that uses neural networks.".to_string(), + metadata: DocumentMetadata { + category: "Machine Learning".to_string(), + author: Some("ML Engineer".to_string()), + timestamp: Some("2024-01-20".to_string()), + }, + }, + Document { + id: "doc_7".to_string(), + content: "Python supports multiple programming paradigms including object-oriented and functional programming.".to_string(), + metadata: DocumentMetadata { + category: "Programming Languages".to_string(), + author: Some("Tech Writer".to_string()), + timestamp: Some("2024-01-21".to_string()), + }, + }, + Document { + id: "doc_8".to_string(), + content: "Node.js allows JavaScript to run on the server side.".to_string(), + metadata: DocumentMetadata { + category: "Backend Development".to_string(), + author: Some("Backend Dev".to_string()), + timestamp: Some("2024-01-22".to_string()), + }, + }, + ]) + .query("How to use Python for data analysis?") + .top_k(6) + .build() + .unwrap(); + + let rerank_response = structured_rerank(request).await.unwrap(); + dbg!(&rerank_response); + } } diff --git a/crates/umem_ai/src/providers/azure_openai.rs b/crates/umem_ai/src/providers/azure_openai.rs index 7f9ce28..cff8f41 100644 --- a/crates/umem_ai/src/providers/azure_openai.rs +++ b/crates/umem_ai/src/providers/azure_openai.rs @@ -4,6 +4,8 @@ use crate::{ }; use anyhow::{Result, bail}; use async_trait::async_trait; + +#[derive(Clone, Debug)] pub struct AzureOpenAIProvider { pub resource_name: Option, pub api_key: String, diff --git a/crates/umem_ai/src/providers/cohere.rs b/crates/umem_ai/src/providers/cohere.rs index de3ef04..bc96b80 100644 --- a/crates/umem_ai/src/providers/cohere.rs +++ b/crates/umem_ai/src/providers/cohere.rs @@ -51,7 +51,7 @@ impl Reranks for CohereProvider { "model": &request.model.model_name, "query": &request.query, "documents": &request.documents, - "top_n": request.top_n, + "top_n": request.top_k, })) .send() .await? diff --git a/crates/umem_ai/src/providers/google_vertex.rs b/crates/umem_ai/src/providers/google_vertex.rs index 553496b..b2c311d 100644 --- a/crates/umem_ai/src/providers/google_vertex.rs +++ b/crates/umem_ai/src/providers/google_vertex.rs @@ -1,5 +1,6 @@ use anyhow::{Result, bail}; +#[derive(Clone, Debug)] pub struct GoogleVertexAIProvider { pub project: String, pub location: String, @@ -7,6 +8,7 @@ pub struct GoogleVertexAIProvider { pub credentials: GoogleCredentials, } +#[derive(Clone, Debug)] pub struct GoogleCredentials { client_email: String, private_key: String, diff --git a/crates/umem_ai/src/providers/xai.rs b/crates/umem_ai/src/providers/xai.rs index cc9d834..da968e1 100644 --- a/crates/umem_ai/src/providers/xai.rs +++ b/crates/umem_ai/src/providers/xai.rs @@ -1,5 +1,6 @@ use anyhow::{Result, bail}; +#[derive(Clone, Debug)] pub struct XAIProvider { pub api_key: String, pub base_url: String, diff --git a/crates/umem_ai/src/response_generators/rerank.rs b/crates/umem_ai/src/response_generators/rerank.rs index b9bc260..bcfde7d 100644 --- a/crates/umem_ai/src/response_generators/rerank.rs +++ b/crates/umem_ai/src/response_generators/rerank.rs @@ -40,7 +40,7 @@ pub async fn rerank(request: RerankRequest) -> Result, - pub top_n: usize, + pub top_k: usize, pub timeout: Duration, pub max_retries: usize, pub model: Arc, @@ -49,7 +49,7 @@ pub struct RerankRequest { impl RerankRequest { pub fn builder() -> RerankRequestBuilder { RerankRequestBuilder { - top_n: 5, + top_k: 5, ..Default::default() } } @@ -58,7 +58,7 @@ impl RerankRequest { pub struct RerankRequestBuilder { query: Option, documents: Vec, - top_n: usize, + top_k: usize, timeout: Duration, max_retries: usize, model: Option>, @@ -69,7 +69,7 @@ impl Default for RerankRequestBuilder { Self { query: None, documents: vec![], - top_n: 5, + top_k: 5, timeout: Duration::from_secs(30), max_retries: 3, model: None, @@ -108,8 +108,8 @@ impl RerankRequestBuilder { self } - pub fn top_k(mut self, top_n: usize) -> Self { - self.top_n = top_n; + pub fn top_k(mut self, top_k: usize) -> Self { + self.top_k = top_k; self } @@ -136,7 +136,7 @@ impl RerankRequestBuilder { Ok(RerankRequest { query: self.query.ok_or(RerankRequestBuilderError::MissingQuery)?, documents: self.documents, - top_n: self.top_n, + top_k: self.top_k, timeout: self.timeout, max_retries: self.max_retries, model: Arc::clone(&self.model.ok_or(RerankRequestBuilderError::MissingModel)?), @@ -144,12 +144,14 @@ impl RerankRequestBuilder { } } +#[derive(Debug)] pub struct RerankResponse { pub rankings: Vec, pub ranked_documents: Vec, pub raw_fields: Map, } +#[derive(Debug)] pub struct Ranking { pub original_index: usize, pub score: f32, diff --git a/crates/umem_ai/src/response_generators/structured_rerank.rs b/crates/umem_ai/src/response_generators/structured_rerank.rs index 4f0df7a..64b7ff9 100644 --- a/crates/umem_ai/src/response_generators/structured_rerank.rs +++ b/crates/umem_ai/src/response_generators/structured_rerank.rs @@ -60,7 +60,7 @@ where .await } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct StructuredRerankRequest where T: Serialize + Clone, @@ -202,6 +202,7 @@ where } } +#[derive(Clone, Debug)] pub struct StructuredRerankResponse where T: Serialize + Clone, @@ -211,6 +212,7 @@ where pub raw_fields: Map, } +#[derive(Clone, Debug)] pub struct StructuredRanking where T: Serialize + Clone,