diff --git a/Cargo.lock b/Cargo.lock index 562e124..f32204b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -67,6 +67,22 @@ dependencies = [ "libc", ] +[[package]] +name = "annotate-snippets" +version = "0.12.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15580ece6ea97cbf832d60ba19c021113469480852c6a2a6beb0db28f097bf1f" +dependencies = [ + "anstyle", + "unicode-width", +] + +[[package]] +name = "anstyle" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" + [[package]] name = "anyhow" version = "1.0.100" @@ -157,6 +173,425 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "aws-config" +version = "1.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96571e6996817bf3d58f6b569e4b9fd2e9d2fcf9f7424eed07b2ce9bb87535e5" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-sdk-sso", + "aws-sdk-ssooidc", + "aws-sdk-sts", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "hex", + "http 1.4.0", + "ring", + "time", + "tokio", + "tracing", + "url", + "zeroize", +] + +[[package]] +name = "aws-credential-types" +version = "1.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cd362783681b15d136480ad555a099e82ecd8e2d10a841e14dfd0078d67fee3" +dependencies = [ + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "zeroize", +] + +[[package]] +name = "aws-lc-rs" +version = "1.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a88aab2464f1f25453baa7a07c84c5b7684e274054ba06817f382357f77a288" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.35.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b45afffdee1e7c9126814751f88dddc747f41d91da16c9551a0f1e8a11e788a1" +dependencies = [ + "cc", + "cmake", + "dunce", + "fs_extra", +] + +[[package]] +name = "aws-runtime" +version = "1.5.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d81b5b2898f6798ad58f484856768bca817e3cd9de0974c24ae0f1113fe88f1b" +dependencies = [ + "aws-credential-types", + "aws-sigv4", + "aws-smithy-async", + "aws-smithy-eventstream", + "aws-smithy-http", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "http-body 0.4.6", + "percent-encoding", + "pin-project-lite", + "tracing", + "uuid", +] + +[[package]] +name = "aws-sdk-bedrockagentruntime" +version = "1.119.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d098831f9c542a92bb9458f21a9a8d11afa55a72e06d9bc117bb7353e58d9faa" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-eventstream", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sdk-bedrockruntime" +version = "1.120.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15b8dcf42378ab2d5accac1652cdd059114fb071baf53250ceafb76fcdde347f" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-sigv4", + "aws-smithy-async", + "aws-smithy-eventstream", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "hyper 0.14.32", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sdk-sso" +version = "1.91.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ee6402a36f27b52fe67661c6732d684b2635152b676aa2babbfb5204f99115d" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sdk-ssooidc" +version = "1.93.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a45a7f750bbd170ee3677671ad782d90b894548f4e4ae168302c57ec9de5cb3e" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sdk-sts" +version = "1.95.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55542378e419558e6b1f398ca70adb0b2088077e79ad9f14eb09441f2f7b2164" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-query", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-smithy-xml", + "aws-types", + "fastrand", + "http 0.2.12", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sigv4" +version = "1.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69e523e1c4e8e7e8ff219d732988e22bfeae8a1cafdbe6d9eca1546fa080be7c" +dependencies = [ + "aws-credential-types", + "aws-smithy-eventstream", + "aws-smithy-http", + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "form_urlencoded", + "hex", + "hmac", + "http 0.2.12", + "http 1.4.0", + "percent-encoding", + "sha2", + "time", + "tracing", +] + +[[package]] +name = "aws-smithy-async" +version = "1.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ee19095c7c4dda59f1697d028ce704c24b2d33c6718790c7f1d5a3015b4107c" +dependencies = [ + "futures-util", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "aws-smithy-eventstream" +version = "0.60.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc12f8b310e38cad85cf3bef45ad236f470717393c613266ce0a89512286b650" +dependencies = [ + "aws-smithy-types", + "bytes", + "crc32fast", +] + +[[package]] +name = "aws-smithy-http" +version = "0.62.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "826141069295752372f8203c17f28e30c464d22899a43a0c9fd9c458d469c88b" +dependencies = [ + "aws-smithy-eventstream", + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "bytes-utils", + "futures-core", + "futures-util", + "http 0.2.12", + "http 1.4.0", + "http-body 0.4.6", + "percent-encoding", + "pin-project-lite", + "pin-utils", + "tracing", +] + +[[package]] +name = "aws-smithy-http-client" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59e62db736db19c488966c8d787f52e6270be565727236fd5579eaa301e7bc4a" +dependencies = [ + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "h2 0.3.27", + "h2 0.4.12", + "http 0.2.12", + "http 1.4.0", + "http-body 0.4.6", + "hyper 0.14.32", + "hyper 1.8.1", + "hyper-rustls 0.24.2", + "hyper-rustls 0.27.7", + "hyper-util", + "pin-project-lite", + "rustls 0.21.12", + "rustls 0.23.35", + "rustls-native-certs", + "rustls-pki-types", + "tokio", + "tokio-rustls 0.26.4", + "tower 0.5.2", + "tracing", +] + +[[package]] +name = "aws-smithy-json" +version = "0.61.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49fa1213db31ac95288d981476f78d05d9cbb0353d22cdf3472cc05bb02f6551" +dependencies = [ + "aws-smithy-types", +] + +[[package]] +name = "aws-smithy-observability" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17f616c3f2260612fe44cede278bafa18e73e6479c4e393e2c4518cf2a9a228a" +dependencies = [ + "aws-smithy-runtime-api", +] + +[[package]] +name = "aws-smithy-query" +version = "0.60.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae5d689cf437eae90460e944a58b5668530d433b4ff85789e69d2f2a556e057d" +dependencies = [ + "aws-smithy-types", + "urlencoding", +] + +[[package]] +name = "aws-smithy-runtime" +version = "1.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a392db6c583ea4a912538afb86b7be7c5d8887d91604f50eb55c262ee1b4a5f5" +dependencies = [ + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-http-client", + "aws-smithy-observability", + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "fastrand", + "http 0.2.12", + "http 1.4.0", + "http-body 0.4.6", + "http-body 1.0.1", + "pin-project-lite", + "pin-utils", + "tokio", + "tracing", +] + +[[package]] +name = "aws-smithy-runtime-api" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab0d43d899f9e508300e587bf582ba54c27a452dd0a9ea294690669138ae14a2" +dependencies = [ + "aws-smithy-async", + "aws-smithy-types", + "bytes", + "http 0.2.12", + "http 1.4.0", + "pin-project-lite", + "tokio", + "tracing", + "zeroize", +] + +[[package]] +name = "aws-smithy-types" +version = "1.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "905cb13a9895626d49cf2ced759b062d913834c7482c38e49557eac4e6193f01" +dependencies = [ + "base64-simd", + "bytes", + "bytes-utils", + "futures-core", + "http 0.2.12", + "http 1.4.0", + "http-body 0.4.6", + "http-body 1.0.1", + "http-body-util", + "itoa", + "num-integer", + "pin-project-lite", + "pin-utils", + "ryu", + "serde", + "time", + "tokio", + "tokio-util", +] + +[[package]] +name = "aws-smithy-xml" +version = "0.60.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11b2f670422ff42bf7065031e72b45bc52a3508bd089f743ea90731ca2b6ea57" +dependencies = [ + "xmlparser", +] + +[[package]] +name = "aws-types" +version = "1.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d980627d2dd7bfc32a3c025685a033eeab8d365cc840c631ef59d1b8f428164" +dependencies = [ + "aws-credential-types", + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "rustc_version", + "tracing", +] + [[package]] name = "axum" version = "0.7.9" @@ -167,8 +602,8 @@ dependencies = [ "axum-core 0.4.5", "bytes", "futures-util", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "http-body-util", "itoa", "matchit 0.7.3", @@ -195,10 +630,10 @@ dependencies = [ "bytes", "form_urlencoded", "futures-util", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "http-body-util", - "hyper", + "hyper 1.8.1", "hyper-util", "itoa", "matchit 0.8.4", @@ -227,8 +662,8 @@ dependencies = [ "async-trait", "bytes", "futures-util", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "http-body-util", "mime", "pin-project-lite", @@ -246,8 +681,8 @@ checksum = "59446ce19cd142f8833f856eb31f3eb097812d1479ab224f54d72428ca21ea22" dependencies = [ "bytes", "futures-core", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "http-body-util", "mime", "pin-project-lite", @@ -285,6 +720,16 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64-simd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "339abbe78e73178762e23bea9dfd08e697eb3f3301cd4be981c0f78ba5859195" +dependencies = [ + "outref", + "vsimd", +] + [[package]] name = "base64ct" version = "1.8.1" @@ -381,6 +826,16 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" +[[package]] +name = "bytes-utils" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dafe3a8757b027e2be6e4e5601ed563c55989fcf1546e933c66c8eb3a058d35" +dependencies = [ + "bytes", + "either", +] + [[package]] name = "cbc" version = "0.1.2" @@ -438,6 +893,15 @@ dependencies = [ "inout", ] +[[package]] +name = "cmake" +version = "0.1.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d" +dependencies = [ + "cc", +] + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -870,6 +1334,12 @@ version = "0.15.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + [[package]] name = "dyn-clone" version = "1.0.20" @@ -903,6 +1373,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "encoding_rs_io" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cc3c5651fb62ab8aa3103998dade57efdd028544bd300516baa31840c252a83" +dependencies = [ + "encoding_rs", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -1038,6 +1517,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "futf" version = "0.1.5" @@ -1197,6 +1682,25 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "h2" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0beca50380b1fc32983fc1cb4587bfa4bb9e78fc259aad4a0032d2080309222d" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http 0.2.12", + "indexmap 2.12.1", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "h2" version = "0.4.12" @@ -1208,7 +1712,7 @@ dependencies = [ "fnv", "futures-core", "futures-sink", - "http", + "http 1.4.0", "indexmap 2.12.1", "slab", "tokio", @@ -1307,6 +1811,17 @@ dependencies = [ "syn", ] +[[package]] +name = "http" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http" version = "1.4.0" @@ -1317,6 +1832,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http-body" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" +dependencies = [ + "bytes", + "http 0.2.12", + "pin-project-lite", +] + [[package]] name = "http-body" version = "1.0.1" @@ -1324,7 +1850,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http", + "http 1.4.0", ] [[package]] @@ -1335,8 +1861,8 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", "futures-core", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "pin-project-lite", ] @@ -1352,6 +1878,30 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "hyper" +version = "0.14.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "h2 0.3.27", + "http 0.2.12", + "http-body 0.4.6", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "socket2 0.5.10", + "tokio", + "tower-service", + "tracing", + "want", +] + [[package]] name = "hyper" version = "1.8.1" @@ -1362,32 +1912,48 @@ dependencies = [ "bytes", "futures-channel", "futures-core", - "h2", - "http", - "http-body", + "h2 0.4.12", + "http 1.4.0", + "http-body 1.0.1", "httparse", "httpdate", "itoa", "pin-project-lite", "pin-utils", - "smallvec", + "smallvec 1.15.1", "tokio", "want", ] +[[package]] +name = "hyper-rustls" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" +dependencies = [ + "futures-util", + "http 0.2.12", + "hyper 0.14.32", + "log", + "rustls 0.21.12", + "tokio", + "tokio-rustls 0.24.1", +] + [[package]] name = "hyper-rustls" version = "0.27.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" dependencies = [ - "http", - "hyper", + "http 1.4.0", + "hyper 1.8.1", "hyper-util", - "rustls", + "rustls 0.23.35", + "rustls-native-certs", "rustls-pki-types", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.4", "tower-service", "webpki-roots", ] @@ -1398,7 +1964,7 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0" dependencies = [ - "hyper", + "hyper 1.8.1", "hyper-util", "pin-project-lite", "tokio", @@ -1413,7 +1979,7 @@ checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" dependencies = [ "bytes", "http-body-util", - "hyper", + "hyper 1.8.1", "hyper-util", "native-tls", "tokio", @@ -1432,9 +1998,9 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "http", - "http-body", - "hyper", + "http 1.4.0", + "http-body 1.0.1", + "hyper 1.8.1", "ipnet", "libc", "percent-encoding", @@ -1529,7 +2095,7 @@ dependencies = [ "icu_normalizer_data", "icu_properties", "icu_provider", - "smallvec", + "smallvec 1.15.1", "zerovec", ] @@ -1611,7 +2177,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" dependencies = [ "idna_adapter", - "smallvec", + "smallvec 1.15.1", "utf8_iter", ] @@ -2066,6 +2632,12 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" +[[package]] +name = "nohash-hasher" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451" + [[package]] name = "nom" version = "8.0.0" @@ -2117,7 +2689,7 @@ dependencies = [ "num-iter", "num-traits", "rand 0.8.5", - "smallvec", + "smallvec 1.15.1", "zeroize", ] @@ -2166,7 +2738,7 @@ dependencies = [ "base64", "chrono", "getrandom 0.2.16", - "http", + "http 1.4.0", "rand 0.8.5", "reqwest", "serde", @@ -2243,6 +2815,12 @@ dependencies = [ "hashbrown 0.14.5", ] +[[package]] +name = "outref" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" + [[package]] name = "pandoc" version = "0.8.11" @@ -2277,7 +2855,7 @@ dependencies = [ "cfg-if", "libc", "redox_syscall", - "smallvec", + "smallvec 1.15.1", "windows-link", ] @@ -2636,7 +3214,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash 2.1.1", - "rustls", + "rustls 0.23.35", "socket2 0.6.1", "thiserror 2.0.17", "tokio", @@ -2656,7 +3234,7 @@ dependencies = [ "rand 0.9.2", "ring", "rustc-hash 2.1.1", - "rustls", + "rustls 0.23.35", "rustls-pki-types", "slab", "thiserror 2.0.17", @@ -2842,6 +3420,12 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "regex-lite" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d942b98df5e658f56f20d592c7f868833fe38115e65c33003d8cd224b0155da" + [[package]] name = "regex-syntax" version = "0.8.8" @@ -2861,12 +3445,12 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "h2", - "http", - "http-body", + "h2 0.4.12", + "http 1.4.0", + "http-body 1.0.1", "http-body-util", - "hyper", - "hyper-rustls", + "hyper 1.8.1", + "hyper-rustls 0.27.7", "hyper-tls", "hyper-util", "js-sys", @@ -2876,7 +3460,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls", + "rustls 0.23.35", "rustls-pki-types", "serde", "serde_json", @@ -2884,7 +3468,7 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-native-tls", - "tokio-rustls", + "tokio-rustls 0.26.4", "tokio-util", "tower 0.5.2", "tower-http", @@ -2922,8 +3506,8 @@ dependencies = [ "bytes", "chrono", "futures", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "http-body-util", "oauth2", "paste", @@ -3014,6 +3598,15 @@ version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "1.1.2" @@ -3027,17 +3620,30 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "rustls" +version = "0.21.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" +dependencies = [ + "log", + "ring", + "rustls-webpki 0.101.7", + "sct", +] + [[package]] name = "rustls" version = "0.23.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "533f54bc6a7d4f647e46ad909549eda97bf5afc1585190ef692b4286b198bd8f" dependencies = [ + "aws-lc-rs", "log", "once_cell", "ring", "rustls-pki-types", - "rustls-webpki", + "rustls-webpki 0.103.8", "subtle", "zeroize", ] @@ -3073,12 +3679,23 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-webpki" +version = "0.101.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "rustls-webpki" version = "0.103.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52" dependencies = [ + "aws-lc-rs", "ring", "rustls-pki-types", "untrusted", @@ -3096,6 +3713,16 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +[[package]] +name = "saphyr-parser" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb771b59f6b1985d1406325ec28f97cfb14256abcec4fdfb37b36a1766d6af7" +dependencies = [ + "arraydeque", + "hashlink", +] + [[package]] name = "schannel" version = "0.1.28" @@ -3137,6 +3764,16 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sct" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "security-framework" version = "2.11.1" @@ -3189,6 +3826,25 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-saphyr" +version = "0.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45400dbf0a2c4c2af106c08eb1028b3e45bea3cd45517d3a0e8170b86597122f" +dependencies = [ + "ahash", + "annotate-snippets", + "base64", + "encoding_rs_io", + "nohash-hasher", + "num-traits", + "saphyr-parser", + "serde", + "serde_json", + "smallvec 2.0.0-alpha.12", + "zmij", +] + [[package]] name = "serde-untagged" version = "0.1.9" @@ -3238,6 +3894,7 @@ version = "1.0.145" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" dependencies = [ + "indexmap 2.12.1", "itoa", "memchr", "ryu", @@ -3372,6 +4029,12 @@ dependencies = [ "serde", ] +[[package]] +name = "smallvec" +version = "2.0.0-alpha.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef784004ca8777809dcdad6ac37629f0a97caee4c685fcea805278d81dd8b857" + [[package]] name = "socket2" version = "0.5.10" @@ -3451,7 +4114,7 @@ dependencies = [ "serde", "serde_json", "sha2", - "smallvec", + "smallvec 1.15.1", "thiserror 2.0.17", "tokio", "tokio-stream", @@ -3532,7 +4195,7 @@ dependencies = [ "serde", "sha1", "sha2", - "smallvec", + "smallvec 1.15.1", "sqlx-core", "stringprep", "thiserror 2.0.17", @@ -3570,7 +4233,7 @@ dependencies = [ "serde", "serde_json", "sha2", - "smallvec", + "smallvec 1.15.1", "sqlx-core", "stringprep", "thiserror 2.0.17", @@ -3612,7 +4275,7 @@ checksum = "eb4dc4d33c68ec1f27d386b5610a351922656e1fdf5c05bbaad930cd1519479a" dependencies = [ "bytes", "futures-util", - "http-body", + "http-body 1.0.1", "http-body-util", "pin-project-lite", ] @@ -3953,13 +4616,23 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" +dependencies = [ + "rustls 0.21.12", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" dependencies = [ - "rustls", + "rustls 0.23.35", "tokio", ] @@ -4030,11 +4703,11 @@ dependencies = [ "base64", "bytes", "flate2", - "h2", - "http", - "http-body", + "h2 0.4.12", + "http 1.4.0", + "http-body 1.0.1", "http-body-util", - "hyper", + "hyper 1.8.1", "hyper-timeout", "hyper-util", "percent-encoding", @@ -4044,7 +4717,7 @@ dependencies = [ "rustls-pemfile", "socket2 0.5.10", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.4", "tokio-stream", "tower 0.4.13", "tower-layer", @@ -4111,8 +4784,8 @@ dependencies = [ "bitflags", "bytes", "futures-util", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "iri-string", "pin-project-lite", "tower 0.5.2", @@ -4201,7 +4874,7 @@ dependencies = [ "once_cell", "regex-automata", "sharded-slab", - "smallvec", + "smallvec 1.15.1", "thread_local", "tracing", "tracing-core", @@ -4282,6 +4955,10 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "aws-config", + "aws-sdk-bedrockagentruntime", + "aws-sdk-bedrockruntime", + "aws-smithy-types", "backon", "base64", "lazy_static", @@ -4290,6 +4967,7 @@ dependencies = [ "rustc-hash 2.1.1", "schemars", "serde", + "serde-saphyr", "serde_json", "thiserror 2.0.17", "tokio", @@ -4518,6 +5196,12 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + [[package]] name = "untrusted" version = "0.9.0" @@ -4536,6 +5220,12 @@ dependencies = [ "serde", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf-8" version = "0.7.6" @@ -4578,6 +5268,12 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "vsimd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" + [[package]] name = "want" version = "0.3.1" @@ -5057,6 +5753,12 @@ dependencies = [ "markup5ever", ] +[[package]] +name = "xmlparser" +version = "0.13.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" + [[package]] name = "yaml-rust2" version = "0.10.4" @@ -5171,3 +5873,9 @@ dependencies = [ "quote", "syn", ] + +[[package]] +name = "zmij" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd8f3f50b848df28f887acb68e41201b5aea6bc8a8dacc00fb40635ff9a72fea" diff --git a/crates/umem_ai/Cargo.toml b/crates/umem_ai/Cargo.toml index 354be11..1392093 100644 --- a/crates/umem_ai/Cargo.toml +++ b/crates/umem_ai/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "umem_ai" version = "0.1.0" -edition = "2021" +edition = "2024" [dependencies] anyhow.workspace = true @@ -20,3 +20,8 @@ tokio.workspace = true tracing.workspace = true typed-builder.workspace = true umem_config = {workspace = true} +aws-config = { version = "1.1.7", features = ["behavior-version-latest"] } +aws-sdk-bedrockruntime = "1.120.0" +aws-smithy-types = {version = "1.3.5", features = ["serde-deserialize", "serde-serialize", "rt-tokio"]} +serde-saphyr = "0.0.14" +aws-sdk-bedrockagentruntime = "1.119.0" diff --git a/crates/umem_ai/src/lib.rs b/crates/umem_ai/src/lib.rs index ffd6a10..8c72b2e 100644 --- a/crates/umem_ai/src/lib.rs +++ b/crates/umem_ai/src/lib.rs @@ -2,18 +2,16 @@ #![allow(dead_code)] mod providers; mod response_generators; -pub(crate) mod utils; - -use std::sync::Arc; +mod utils; use anyhow::Result; use async_trait::async_trait; use lazy_static::lazy_static; -use schemars::JsonSchema; -use serde::{de::DeserializeOwned, Serialize}; - pub use providers::*; pub use response_generators::*; +use schemars::JsonSchema; +use serde::{Serialize, de::DeserializeOwned}; +use std::sync::Arc; use umem_config::CONFIG; pub type HashMap = rustc_hash::FxHashMap; @@ -40,6 +38,26 @@ impl LanguageModel { } } +#[derive(Debug, Clone)] +pub struct RerankingModel { + pub provider: Arc, + pub model_name: String, +} + +impl RerankingModel { + fn new(provider: Arc, model_name: String) -> Self { + Self { + provider, + model_name, + } + } + + pub fn get_model() -> Arc { + Arc::clone(&LANGUAGE_MODEL) + } +} + +#[derive(Debug)] pub enum AIProvider { OpenAI(OpenAIProvider), AzureOpenAI(AzureOpenAIProvider), @@ -47,6 +65,7 @@ pub enum AIProvider { Anthropic(AnthropicProvider), XAI(XAIProvider), AmazonBedrock(AmazonBedrockProvider), + Cohere(CohereProvider), } lazy_static! { @@ -77,6 +96,7 @@ impl AIProvider { ) -> Result { match self { AIProvider::OpenAI(provider) => provider.generate_text(request), + AIProvider::AmazonBedrock(provider) => provider.generate_text(request), _ => unimplemented!(), } .await @@ -90,10 +110,37 @@ impl AIProvider { ) -> Result, ResponseGeneratorError> { match self { AIProvider::OpenAI(provider) => provider.generate_object(request), + AIProvider::AmazonBedrock(provider) => provider.generate_object(request), _ => unimplemented!(), } .await } + + pub(crate) async fn do_reranking( + &self, + request: RerankRequest, + ) -> Result { + match self { + AIProvider::Cohere(provider) => provider.rerank(request), + AIProvider::AmazonBedrock(provider) => provider.rerank(request), + _ => unimplemented!(), + } + .await + } + + pub(crate) async fn do_structured_reranking( + &self, + request: StructuredRerankRequest, + ) -> Result, ResponseGeneratorError> + where + T: Serialize + Clone + Send + Sync, + { + match self { + AIProvider::Cohere(provider) => provider.rerank_structured(request).await, + AIProvider::AmazonBedrock(provider) => provider.rerank_structured(request).await, + _ => unimplemented!(), + } + } } #[async_trait] @@ -112,6 +159,24 @@ pub trait GeneratesObject { ) -> Result, ResponseGeneratorError>; } +#[async_trait] +pub trait Reranks { + async fn rerank( + &self, + request: RerankRequest, + ) -> Result; +} + +#[async_trait] +pub trait ReranksStructuredData { + async fn rerank_structured( + &self, + request: StructuredRerankRequest, + ) -> Result, ResponseGeneratorError> + where + T: Serialize + Clone + Send + Sync; +} + impl From for AIProvider { fn from(config: OpenAIProvider) -> Self { AIProvider::OpenAI(config) diff --git a/crates/umem_ai/src/providers/amazon_bedrock.rs b/crates/umem_ai/src/providers/amazon_bedrock.rs index 6f19701..a15f618 100644 --- a/crates/umem_ai/src/providers/amazon_bedrock.rs +++ b/crates/umem_ai/src/providers/amazon_bedrock.rs @@ -1,75 +1,892 @@ -use anyhow::{bail, Result}; -use async_trait::async_trait; - use crate::{ - response_generators::{GenerateTextRequest, GenerateTextResponse, ResponseGeneratorError}, - GeneratesText, + GenerateObjectRequest, GenerateObjectResponse, GeneratesObject, GeneratesText, OpenAIProvider, + Ranking, RerankRequest, RerankResponse, Reranks, ReranksStructuredData, SerializationMode, + StructuredRanking, StructuredRerankRequest, StructuredRerankResponse, + messages::{FilePart, UserModelMessage}, + response_generators::{ + self, GenerateTextRequest, GenerateTextResponse, ResponseGeneratorError, + }, + utils, +}; +use anyhow::Result; +use async_trait::async_trait; +use aws_config::{BehaviorVersion, Region}; +use aws_sdk_bedrockagentruntime::types::{ + BedrockRerankingConfiguration, BedrockRerankingModelConfiguration, RerankDocument, + RerankDocumentType, RerankQuery, RerankQueryContentType, RerankSource, RerankSourceType, + RerankTextDocument, RerankingConfiguration, RerankingConfigurationType, +}; +use aws_sdk_bedrockruntime::{ + error::BuildError, + operation::converse::builders::ConverseFluentBuilder, + types::{ + AnyToolChoice, ContentBlock, ConverseOutput, ImageBlock, InferenceConfiguration, Message, + Tool, ToolChoice, ToolConfiguration, ToolInputSchema, ToolSpecification, + }, }; +use base64::Engine; +use schemars::JsonSchema; +use serde::{Serialize, de::DeserializeOwned}; +use serde_json::Map; +use std::sync::Arc; +use thiserror::Error; +#[derive(Clone, Debug)] pub struct AmazonBedrockProvider { - pub region: String, - pub access_key: String, - pub secret_key: String, - pub session_token: Option, + region: Region, + bedrockruntime_client: Arc, + bedrockagentruntime_client: Arc, } #[async_trait] impl GeneratesText for AmazonBedrockProvider { async fn generate_text( &self, - _request: GenerateTextRequest, + request: GenerateTextRequest, ) -> Result { - unimplemented!() + let converse_request = self + .normalize_generate_text_request(&request) + .map_err(ResponseGeneratorError::Transient)?; + + let converse_response = converse_request + .set_inference_config(Some( + InferenceConfiguration::builder() + .temperature(request.temperature.unwrap_or(0.0)) + .top_p(request.top_p.unwrap_or(1.0)) + .max_tokens(request.max_output_tokens.unwrap_or(0_usize) as i32) + .build(), + )) + .additional_model_request_fields(aws_smithy_types::Document::Object( + request + .headers + .iter() + .map(|(key, value)| { + ( + key.as_str().to_string(), + aws_smithy_types::Document::String(value.to_str().unwrap().to_string()), + ) + }) + .collect(), + )) + .send() + .await + .map_err(|e| { + tracing::error!("{}", e); + ResponseGeneratorError::BedrockConverseError(format!("{:?}", e)) + })?; + + let converse_output = match converse_response.output { + Some(output) => output, + None => { + return Err(ResponseGeneratorError::EmptyProviderResponse); + } + }; + + let output_message = converse_output.as_message().map_err(|_| { + ResponseGeneratorError::InvalidProviderResponse( + "was expecting the output to be a bedrock message".into(), + ) + })?; + + let output_text = output_message + .content + .first() + .ok_or(ResponseGeneratorError::EmptyProviderResponse)? + .as_text() + .map_err(|_| { + ResponseGeneratorError::InvalidProviderResponse( + "was expecting the output message content to be text".into(), + ) + })?; + + Ok(GenerateTextResponse { + text: output_text.to_string(), + }) + } +} + +#[async_trait] +impl GeneratesObject for AmazonBedrockProvider { + async fn generate_object( + &self, + request: GenerateObjectRequest, + ) -> Result, ResponseGeneratorError> { + let converse_request = self + .normalize_generate_object_request(&request) + .map_err(ResponseGeneratorError::Transient)?; + + let converse_response = converse_request + .set_inference_config(Some( + InferenceConfiguration::builder() + .temperature(request.temperature.unwrap_or(0.0)) + .top_p(request.top_p.unwrap_or(1.0)) + .max_tokens(request.max_output_tokens.unwrap_or(5140_usize) as i32) + .build(), + )) + .additional_model_request_fields(aws_smithy_types::Document::Object( + request + .headers + .iter() + .map(|(key, value)| { + ( + key.as_str().to_string(), + aws_smithy_types::Document::String(value.to_str().unwrap().to_string()), + ) + }) + .collect(), + )) + .send() + .await + .map_err(|e| ResponseGeneratorError::BedrockConverseError(e.to_string()))?; + + let converse_output = match converse_response.output { + Some(output) => output, + None => { + return Err(ResponseGeneratorError::EmptyProviderResponse); + } + }; + + let output_message = match converse_output { + ConverseOutput::Message(msg) => msg, + _ => { + return Err(ResponseGeneratorError::InvalidProviderResponse( + "was expecting the output to be a bedrock message".into(), + )); + } + }; + + let json_tool = output_message + .content + .into_iter() + .rfind(|content_item| content_item.is_tool_use()) + .ok_or(ResponseGeneratorError::EmptyProviderResponse)?; + + let json_tool_input = json_tool + .as_tool_use() + .map_err(|_| { + ResponseGeneratorError::InvalidProviderResponse( + "was expecting the model to call the tool use".into(), + ) + })? + .input(); + + serde_json::from_value::(utils::aws_smithy_document_to_json(json_tool_input)) + .map(|output| GenerateObjectResponse { output }) + .map_err(|e| { + ResponseGeneratorError::Deserialization(e, format!("{:?}", json_tool_input)) + }) + } +} + +#[async_trait] +impl Reranks for AmazonBedrockProvider { + async fn rerank( + &self, + request: RerankRequest, + ) -> Result { + let inline_sources: Vec = request + .documents + .iter() + .map(|document| { + RerankSource::builder() + .inline_document_source( + RerankDocument::builder() + .r#type(RerankDocumentType::Text) + .text_document(RerankTextDocument::builder().text(document).build()) + .build()?, + ) + .r#type(RerankSourceType::Inline) + .build() + }) + .collect::>() + .map_err(|e| { + ResponseGeneratorError::InvalidArgumentsProvided(format!( + "Failed to build Bedrock-compat Rerank Sources from provided documents, Error: {}", e + )) + })?; + + let response = self + .bedrockagentruntime_client + .rerank() + .queries( + RerankQuery::builder() + .r#type(RerankQueryContentType::Text) + .text_query(RerankTextDocument::builder().text(&request.query).build()) + .build() + .map_err(|e| ResponseGeneratorError::InvalidArgumentsProvided( + format!("Failed to build RerankQuery, Details: {}", e) + ))? + ) + .set_sources( + Some(inline_sources) + ) + .reranking_configuration( + RerankingConfiguration::builder() + .r#type(RerankingConfigurationType::BedrockRerankingModel) + .bedrock_reranking_configuration( + BedrockRerankingConfiguration::builder() + .model_configuration( + BedrockRerankingModelConfiguration::builder() + .model_arn(format!("arn:aws:bedrock:{}::foundation-model/{}",&self.region, &request.model.model_name)) + .build() + .map_err(|e| { + ResponseGeneratorError::InvalidArgumentsProvided( + format!("Failed to build BedrockRerankingModelConfiguration, Details: {}", e) + ) + })?, + ) + .number_of_results(request.top_k as i32) + .build(), + ) + .build() + .map_err(|e| { + ResponseGeneratorError::InvalidArgumentsProvided( + format!("Failed to build RerankingConfiguration, Details: {}", e) + ) + })?, + ) + .send() + .await + .map_err(|e| ResponseGeneratorError::BedrockAgentRerankCommandSendError(e.to_string()))?; + + let results = response.results(); + + let (rankings, ranked_documents) = results.iter().try_fold( + ( + Vec::with_capacity(results.len()), + Vec::with_capacity(results.len()), + ), + |(mut rankings, mut docs), result| { + let document = request.documents.get(result.index as usize).ok_or( + ResponseGeneratorError::InvalidProviderResponse( + "Bedrock Rerank API returned an invalid index".to_string(), + ), + )?; + docs.push(document.clone()); + rankings.push(Ranking { + original_index: result.index as usize, + score: result.relevance_score, + document: document.clone(), + }); + Ok::<_, ResponseGeneratorError>((rankings, docs)) + }, + )?; + + Ok(RerankResponse { + rankings, + ranked_documents, + raw_fields: Map::with_capacity(0_usize), + }) + } +} + +#[async_trait] +impl ReranksStructuredData for AmazonBedrockProvider { + async fn rerank_structured( + &self, + request: StructuredRerankRequest, + ) -> Result, ResponseGeneratorError> + where + T: Serialize + Clone + Send + Sync, + { + let inline_sources: Vec = match request.serialization_mode { + SerializationMode::Json => { + let json_documents: Vec = request + .documents + .iter() + .map(|doc| { + serde_json::to_value(doc) + .map(utils::json_to_aws_smithy_document) + .map_err(|e| { + ResponseGeneratorError::StructuredRerankDocumentsSerializationError( + e.to_string(), + ) + }) + }) + .collect::, _>>()?; + + json_documents + .into_iter() + .map(|document| { + RerankSource::builder() + .inline_document_source( + RerankDocument::builder() + .r#type(RerankDocumentType::Json) + .json_document(document) + .build()?, + ) + .r#type(RerankSourceType::Inline) + .build() + }) + .collect::>() + } + SerializationMode::Yaml => { + let yaml_documents: Vec = request + .documents + .iter() + .map(|doc| { + serde_saphyr::to_string(doc).map_err(|e| { + ResponseGeneratorError::StructuredRerankDocumentsSerializationError( + e.to_string(), + ) + }) + }) + .collect::, _>>()?; + + yaml_documents + .iter() + .map(|document| { + RerankSource::builder() + .inline_document_source( + RerankDocument::builder() + .r#type(RerankDocumentType::Text) + .text_document( + RerankTextDocument::builder().text(document).build(), + ) + .build()?, + ) + .r#type(RerankSourceType::Inline) + .build() + }) + .collect::>() + } + } + .map_err(|e| { + ResponseGeneratorError::InvalidArgumentsProvided(format!( + "Failed to build Bedrock-compat Rerank Sources from provided documents, Error: {}", + e + )) + })?; + + let response = self + .bedrockagentruntime_client + .rerank() + .queries( + RerankQuery::builder() + .r#type(RerankQueryContentType::Text) + .text_query(RerankTextDocument::builder().text(&request.query).build()) + .build() + .map_err(|e| { + ResponseGeneratorError::InvalidArgumentsProvided(format!( + "Failed to build RerankQuery, Details: {}", + e + )) + })?, + ) + .set_sources(Some(inline_sources)) + .reranking_configuration( + RerankingConfiguration::builder() + .r#type(RerankingConfigurationType::BedrockRerankingModel) + .bedrock_reranking_configuration( + BedrockRerankingConfiguration::builder() + .model_configuration( + BedrockRerankingModelConfiguration::builder() + .model_arn(format!( + "arn:aws:bedrock:{}::foundation-model/{}", + &self.region, &request.model.model_name + )) + .build() + .map_err(|e| { + ResponseGeneratorError::InvalidArgumentsProvided(format!( + "Failed to build BedrockRerankingModelConfiguration, Details: {}", + e + )) + })?, + ) + .number_of_results(request.top_n as i32) + .build(), + ) + .build() + .map_err(|e| { + ResponseGeneratorError::InvalidArgumentsProvided(format!( + "Failed to build RerankingConfiguration, Details: {}", + e + )) + })?, + ) + .send() + .await + .map_err(|e| ResponseGeneratorError::BedrockAgentRerankCommandSendError(e.to_string()))?; + + let results = response.results(); + + let (rankings, ranked_documents) = results.iter().try_fold( + ( + Vec::with_capacity(results.len()), + Vec::with_capacity(results.len()), + ), + |(mut rankings, mut docs), result| { + let document = request.documents.get(result.index as usize).ok_or( + ResponseGeneratorError::InvalidProviderResponse( + "Bedrock Rerank API returned an invalid index".to_string(), + ), + )?; + docs.push(document.clone()); + rankings.push(StructuredRanking { + original_index: result.index as usize, + score: result.relevance_score, + document: document.clone(), + }); + Ok::<_, ResponseGeneratorError>((rankings, docs)) + }, + )?; + + Ok(StructuredRerankResponse { + rankings, + ranked_documents, + raw_fields: Map::with_capacity(0_usize), + }) } } +impl AmazonBedrockProvider { + fn normalize_generate_object_request< + T: Clone + JsonSchema + Serialize + Send + Sync + DeserializeOwned, + >( + &self, + request: &GenerateObjectRequest, + ) -> anyhow::Result { + let system = OpenAIProvider::normalize_system_message(&request.messages); + let user_messages = Self::normalize_user_messages(&request.messages).unwrap(); + let output_schema_value = serde_json::to_value(&request.output_schema)?; + + Ok(self + .bedrockruntime_client + .converse() + .model_id(request.model.model_name.clone()) + .system(aws_sdk_bedrockruntime::types::SystemContentBlock::Text( + system, + )) + .tool_config( + ToolConfiguration::builder() + .tools(Tool::ToolSpec( + ToolSpecification::builder() + .name("json_output") + .description("Return output as JSON.") + .input_schema(ToolInputSchema::Json( + utils::json_to_aws_smithy_document(output_schema_value), + )) + .build()?, + )) + .tool_choice(ToolChoice::Any(AnyToolChoice::builder().build())) + .build()?, + ) + .messages( + Message::builder() + .role("user".into()) + .set_content(Some(user_messages)) + .build() + .unwrap(), + )) + } + + fn normalize_generate_text_request( + &self, + request: &GenerateTextRequest, + ) -> anyhow::Result { + let system = OpenAIProvider::normalize_system_message(&request.messages); + let user_messages = Self::normalize_user_messages(&request.messages)?; + + Ok(self + .bedrockruntime_client + .converse() + .model_id(request.model.model_name.clone()) + .system(aws_sdk_bedrockruntime::types::SystemContentBlock::Text( + system, + )) + .messages( + Message::builder() + .role("user".into()) + .set_content(Some(user_messages)) + .build() + .unwrap(), + )) + } + + fn normalize_user_messages( + messages: &[response_generators::messages::Message], + ) -> anyhow::Result> { + let user_messages: Vec<&response_generators::messages::UserModelMessage> = messages + .iter() + .filter_map(|msg| match msg { + response_generators::messages::Message::User(v) => Some(v), + _ => None, + }) + .collect(); + + let user_message_content_blocks: Vec = user_messages + .iter() + .flat_map(|um| match um { + UserModelMessage::Text(text) => vec![ContentBlock::Text(text.into())], + UserModelMessage::Parts(user_message_parts) => user_message_parts + .iter() + .map(|part| match part { + crate::messages::UserMessagePart::Text(text) => { + ContentBlock::Text(text.into()) + } + crate::messages::UserMessagePart::Image(image_part) => { + let image_block = match image_part { + FilePart::Url(_, _) => { + unimplemented!("AWS doesn't support URL images yet"); + } + FilePart::Base64(b64_string, media_type) => { + let decoded = base64::engine::general_purpose::STANDARD + .decode(b64_string) + .expect("not a valid base64 string"); + + ImageBlock::builder() + .source(aws_sdk_bedrockruntime::types::ImageSource::Bytes( + decoded.as_slice().into(), + )) + .format( + media_type + .clone() + .unwrap_or(mime::IMAGE_PNG) + .to_string() + .as_str() + .into(), + ) + .build() + .expect("failed to build image block") + } + FilePart::Buffer(items, media_type) => ImageBlock::builder() + .source(aws_sdk_bedrockruntime::types::ImageSource::Bytes( + items.as_slice().into(), + )) + .format( + media_type + .clone() + .unwrap_or(mime::IMAGE_PNG) + .to_string() + .as_str() + .into(), + ) + .build() + .expect("failed to build image block"), + }; + ContentBlock::Image(image_block) + } + crate::messages::UserMessagePart::File(_) => { + unimplemented!("file handling not yet supported for Bedrock") + } + }) + .collect(), + }) + .collect(); + + Ok(user_message_content_blocks) + } + + fn builder() -> AmazonBedrockProviderBuilder { + AmazonBedrockProviderBuilder::new() + } +} + +#[derive(Default)] pub struct AmazonBedrockProviderBuilder { - pub region: Option, - pub access_key: Option, - pub secret_key: Option, - pub session_token: Option, + access_key_id: Option, + secret_access_key: Option, + region: Option, + provider_name: Option, +} + +#[derive(Error, Debug)] +pub enum AmazonBedrockProviderBuilderError { + #[error("Missing AWS Access Key ID")] + MissingAccessKeyId, + #[error("Missing AWS Secret Access Key")] + MissingSecretAccessKey, + #[error("Missing AWS Bedrock Region")] + MissingRegion, + #[error(transparent)] + BadHeaders(#[from] utils::BuildHeaderMapError), } impl AmazonBedrockProviderBuilder { pub fn new() -> Self { - Self { - region: None, - access_key: None, - secret_key: None, - session_token: None, - } + Self::default() } - pub fn region(mut self, region: String) -> Self { - self.region = Some(region); + pub fn access_key_id(mut self, access_key_id: impl Into) -> Self { + self.access_key_id = Some(access_key_id.into()); self } - pub fn access_key(mut self, access_key: String) -> Self { - self.access_key = Some(access_key); + pub fn secret_access_key(mut self, secret_access_key: impl Into) -> Self { + self.secret_access_key = Some(secret_access_key.into()); self } - pub fn secret_key(mut self, secret_key: String) -> Self { - self.secret_key = Some(secret_key); + pub fn region(mut self, region: impl Into) -> Self { + self.region = Some(Region::new(region.into())); self } - pub fn session_token(mut self, session_token: String) -> Self { - self.session_token = Some(session_token); + pub fn provider_name(mut self, provider_name: impl Into) -> Self { + self.provider_name = Some(provider_name.into()); self } - pub fn build(self) -> Result { - if self.region.is_none() || self.access_key.is_none() || self.secret_key.is_none() { - bail!("region, access_key, and secret_key are required"); - } + pub async fn build(self) -> Result { + let sdk_config = aws_config::defaults(BehaviorVersion::latest()) + .region(self.region.clone()) + .credentials_provider( + aws_sdk_bedrockruntime::config::Credentials::builder() + .access_key_id( + self.access_key_id + .clone() + .ok_or(AmazonBedrockProviderBuilderError::MissingAccessKeyId)?, + ) + .secret_access_key( + self.secret_access_key + .ok_or(AmazonBedrockProviderBuilderError::MissingSecretAccessKey)?, + ) + .provider_name("umem-ai-bedrock-provider") + .build(), + ) + .load() + .await; Ok(AmazonBedrockProvider { - region: self.region.unwrap(), - access_key: self.access_key.unwrap(), - secret_key: self.secret_key.unwrap(), - session_token: self.session_token, + region: self + .region + .ok_or(AmazonBedrockProviderBuilderError::MissingRegion)?, + bedrockruntime_client: Arc::new(aws_sdk_bedrockruntime::Client::new(&sdk_config)), + bedrockagentruntime_client: Arc::new(aws_sdk_bedrockagentruntime::Client::new( + &sdk_config, + )), }) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + AIProvider, GenerateObjectRequestBuilder, GenerateTextRequestBuilder, LanguageModel, + RerankingModel, SerializationFormat, generate_object, generate_text, rerank, + structured_rerank, + }; + use serde::Deserialize; + use std::sync::Arc; + + #[tokio::test] + async fn test_bedrock_generate_object() { + let provider = Arc::new(AIProvider::from( + AmazonBedrockProviderBuilder::default() + .region("REGION") + .access_key_id("ACESS_KEY_ID") + .secret_access_key("SECRET_ACCESS_KEY") + .build() + .await + .unwrap(), + )); + + #[derive(Clone, JsonSchema, Serialize, Deserialize, Debug)] + struct Holiday { + name: String, + traditions: String, + } + + let model = Arc::new(LanguageModel { + provider, + model_name: "deepseek.v3-v1:0".to_string(), + }); + + let request = GenerateObjectRequestBuilder::::new() + .model(model) + .system("You are a helpful assistant.".to_string()) + .prompt("Invent a new holiday and describe its traditions.".to_string()) + .max_output_tokens(2000) + .temperature(0.7) + .build() + .unwrap(); + + let generate_object_response = generate_object(request).await.unwrap(); + + dbg!(&generate_object_response); + } + + #[tokio::test] + async fn test_bedrock_generate_text() { + let provider = Arc::new(AIProvider::from( + AmazonBedrockProviderBuilder::default() + .region("REGION") + .access_key_id("ACESS_KEY_ID") + .secret_access_key("SECRET_ACCESS_KEY") + .build() + .await + .unwrap(), + )); + + let model = Arc::new(LanguageModel { + provider, + model_name: "deepseek.v3-v1:0".to_string(), + }); + + let request = GenerateTextRequestBuilder::new() + .model(model) + .system("You are a helpful assistant.") + .prompt("Invent a new holiday and describe its traditions.") + .max_output_tokens(2000) + .temperature(0.7) + .build() + .unwrap(); + + let generate_text_response = generate_text(request).await.unwrap(); + + dbg!(&generate_text_response); + } + + #[tokio::test] + async fn test_rerank() { + let provider = Arc::new(AIProvider::from( + AmazonBedrockProviderBuilder::default() + .region("REGION") + .access_key_id("ACESS_KEY_ID") + .secret_access_key("SECRET_ACCESS_KEY") + .build() + .await + .unwrap(), + )); + + let model = Arc::new(RerankingModel { + provider, + model_name: "cohere.rerank-v3-5:0".to_string(), + }); + + let request = RerankRequest::builder() + .model(model) + .document("Stock markets reached record highs today as investors reacted positively to economic data.") + .document("The local sports team won their championship game in a thrilling overtime victory.") + .document("A new cafe opened downtown, offering a variety of artisanal coffees and pastries.") + .document("Scientists have discovered a new species of bird in the remote rainforests of the Amazon.") + .document("The city council has approved a new plan to improve public transportation and reduce traffic congestion.") + .document("Researchers develop more efficient solar panel technology.") + .query("environmental sustainability initiatives") + .top_k(6) + .build() + .unwrap(); + + let rerank_response = rerank(request).await.unwrap(); + dbg!(&rerank_response); + } + + #[tokio::test] + async fn test_structured_rerank() { + let provider = Arc::new(AIProvider::from( + AmazonBedrockProviderBuilder::default() + .region("REGION") + .access_key_id("ACESS_KEY_ID") + .secret_access_key("SECRET_ACCESS_KEY") + .build() + .await + .unwrap(), + )); + + let model = Arc::new(RerankingModel { + provider, + model_name: "cohere.rerank-v3-5:0".to_string(), + }); + + #[derive(Debug, Clone, Serialize, Deserialize)] + struct Document { + id: String, + content: String, + metadata: DocumentMetadata, + } + + #[derive(Debug, Clone, Serialize, Deserialize)] + struct DocumentMetadata { + category: String, + author: Option, + timestamp: Option, + } + + let request = StructuredRerankRequest::builder() + .model(model) + .serialization_format(SerializationFormat::Compact) + .serialization_mode(SerializationMode::Json) + .documents(vec![ + Document { + id: "doc_1".to_string(), + content: "Python is a high-level programming language known for its simplicity and readability.".to_string(), + metadata: DocumentMetadata { + category: "Programming Languages".to_string(), + author: Some("Tech Writer".to_string()), + timestamp: Some("2024-01-15".to_string()), + }, + }, + Document { + id: "doc_2".to_string(), + content: "JavaScript is primarily used for web development and runs in browsers.".to_string(), + metadata: DocumentMetadata { + category: "Programming Languages".to_string(), + author: Some("Tech Writer".to_string()), + timestamp: Some("2024-01-16".to_string()), + }, + }, + Document { + id: "doc_3".to_string(), + content: "Machine learning models require large datasets for training.".to_string(), + metadata: DocumentMetadata { + category: "Machine Learning".to_string(), + author: Some("Data Scientist".to_string()), + timestamp: Some("2024-01-17".to_string()), + }, + }, + Document { + id: "doc_4".to_string(), + content: "Python's pandas library is excellent for data manipulation and analysis.".to_string(), + metadata: DocumentMetadata { + category: "Data Analysis".to_string(), + author: Some("Data Analyst".to_string()), + timestamp: Some("2024-01-18".to_string()), + }, + }, + Document { + id: "doc_5".to_string(), + content: "The React framework is built on top of JavaScript for building user interfaces.".to_string(), + metadata: DocumentMetadata { + category: "Web Development".to_string(), + author: Some("Frontend Dev".to_string()), + timestamp: Some("2024-01-19".to_string()), + }, + }, + Document { + id: "doc_6".to_string(), + content: "Deep learning is a subset of machine learning that uses neural networks.".to_string(), + metadata: DocumentMetadata { + category: "Machine Learning".to_string(), + author: Some("ML Engineer".to_string()), + timestamp: Some("2024-01-20".to_string()), + }, + }, + Document { + id: "doc_7".to_string(), + content: "Python supports multiple programming paradigms including object-oriented and functional programming.".to_string(), + metadata: DocumentMetadata { + category: "Programming Languages".to_string(), + author: Some("Tech Writer".to_string()), + timestamp: Some("2024-01-21".to_string()), + }, + }, + Document { + id: "doc_8".to_string(), + content: "Node.js allows JavaScript to run on the server side.".to_string(), + metadata: DocumentMetadata { + category: "Backend Development".to_string(), + author: Some("Backend Dev".to_string()), + timestamp: Some("2024-01-22".to_string()), + }, + }, + ]) + .query("How to use Python for data analysis?") + .top_k(6) + .build() + .unwrap(); + + let rerank_response = structured_rerank(request).await.unwrap(); + dbg!(&rerank_response); + } +} diff --git a/crates/umem_ai/src/providers/anthropic.rs b/crates/umem_ai/src/providers/anthropic.rs index 7e694ab..2a20e85 100644 --- a/crates/umem_ai/src/providers/anthropic.rs +++ b/crates/umem_ai/src/providers/anthropic.rs @@ -1,8 +1,8 @@ +use crate::GeneratesText; +use crate::ResponseGeneratorError; use crate::response_generators::GenerateTextRequest; use crate::response_generators::GenerateTextResponse; use crate::utils; -use crate::GeneratesText; -use crate::ResponseGeneratorError; use anyhow::Result; use async_trait::async_trait; use reqwest::header::HeaderMap; @@ -19,9 +19,7 @@ pub struct AnthropicProvider { #[builder(default = "https://api.anthropic.com/v1".into())] pub base_url: String, - #[builder(default = HeaderMap::default(), setter(transform = |value: Vec<(String, String)>| - utils::build_header_map(value.as_slice()).unwrap_or_default() - ))] + #[builder(default = HeaderMap::default(), setter(transform = |value: Vec<(String, String)>| utils::build_header_map(value.as_slice()).unwrap_or_default()))] pub headers: HeaderMap, } @@ -41,7 +39,9 @@ mod tests { #[test] fn test_building_anthropic_provider() { - let provider = AnthropicProvider::builder().api_key("sk-some-api-key").build(); + let provider = AnthropicProvider::builder() + .api_key("sk-some-api-key") + .build(); dbg!("Anthropic Provider: {:?}", provider); } } diff --git a/crates/umem_ai/src/providers/azure_openai.rs b/crates/umem_ai/src/providers/azure_openai.rs index 62d841e..cff8f41 100644 --- a/crates/umem_ai/src/providers/azure_openai.rs +++ b/crates/umem_ai/src/providers/azure_openai.rs @@ -1,9 +1,11 @@ use crate::{ - response_generators::{GenerateTextRequest, GenerateTextResponse, ResponseGeneratorError}, GeneratesText, + response_generators::{GenerateTextRequest, GenerateTextResponse, ResponseGeneratorError}, }; -use anyhow::{bail, Result}; +use anyhow::{Result, bail}; use async_trait::async_trait; + +#[derive(Clone, Debug)] pub struct AzureOpenAIProvider { pub resource_name: Option, pub api_key: String, diff --git a/crates/umem_ai/src/providers/cohere.rs b/crates/umem_ai/src/providers/cohere.rs new file mode 100644 index 0000000..bc96b80 --- /dev/null +++ b/crates/umem_ai/src/providers/cohere.rs @@ -0,0 +1,164 @@ +use crate::{ + Ranking, RerankRequest, RerankResponse, Reranks, ReranksStructuredData, ResponseGeneratorError, + SerializationFormat, SerializationMode, StructuredRanking, StructuredRerankRequest, + StructuredRerankResponse, reqwest_client, utils, +}; +use async_trait::async_trait; +use reqwest::header::HeaderMap; +use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value, json}; +use typed_builder::TypedBuilder; + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CohereRerankAPIV2Response { + pub results: Vec, + #[serde(flatten)] + pub raw_fields: Map, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CohereRerankResult { + pub index: usize, + pub relevance_score: f32, +} + +#[derive(TypedBuilder, Debug, Clone)] +pub struct CohereProvider { + #[builder(default = "https://api.cohere.com/v2".into(), setter(transform = |value: impl Into| value.into()))] + base_url: String, + + #[builder(setter(transform = |value: impl Into| value.into()))] + api_key: String, + + #[builder(default = HeaderMap::default(), setter(transform = |value: Vec<(String, String)>| + utils::build_header_map(value.as_slice()).unwrap_or_default() + ))] + headers: HeaderMap, +} + +#[async_trait] +impl Reranks for CohereProvider { + async fn rerank( + &self, + request: RerankRequest, + ) -> Result { + let response = reqwest_client + .post(format!("{}/rerank", self.base_url)) + .header("Content-Type", "application/json") + .header("Accept", "application/json") + .header("Authorization", &format!("bearer {}", &self.api_key)) + .json(&json!({ + "model": &request.model.model_name, + "query": &request.query, + "documents": &request.documents, + "top_n": request.top_k, + })) + .send() + .await? + .error_for_status()? + .json::() + .await?; + + let (rankings, ranked_documents) = response.results.iter().try_fold( + ( + Vec::with_capacity(response.results.len()), + Vec::with_capacity(response.results.len()), + ), + |(mut rankings, mut docs), result| { + let document = request.documents.get(result.index).ok_or( + ResponseGeneratorError::InvalidProviderResponse( + "Cohere returned an invalid index".to_string(), + ), + )?; + docs.push(document.clone()); + rankings.push(Ranking { + original_index: result.index, + score: result.relevance_score, + document: document.clone(), + }); + Ok::<_, ResponseGeneratorError>((rankings, docs)) + }, + )?; + + Ok(RerankResponse { + raw_fields: response.raw_fields, + rankings, + ranked_documents, + }) + } +} + +#[async_trait] +impl ReranksStructuredData for CohereProvider { + async fn rerank_structured( + &self, + request: StructuredRerankRequest, + ) -> Result, ResponseGeneratorError> + where + T: Serialize + Clone + Send + Sync, + { + let serialized_documents: Vec = request + .documents + .iter() + .map(|doc| { + match (request.serialization_mode, request.serialization_format) { + (SerializationMode::Json, SerializationFormat::Compact) => { + serde_json::to_string(doc).map_err(|e| e.to_string()) + } + (SerializationMode::Json, SerializationFormat::Pretty) => { + serde_json::to_string_pretty(doc).map_err(|e| e.to_string()) + } + (SerializationMode::Yaml, _) => { + serde_saphyr::to_string(doc).map_err(|e| e.to_string()) + } + } + .map_err(ResponseGeneratorError::StructuredRerankDocumentsSerializationError) + }) + .collect::, _>>()?; + + let response = reqwest_client + .post(format!("{}/rerank", self.base_url)) + .header("Content-Type", "application/json") + .header("Accept", "application/json") + .header("Authorization", &format!("bearer {}", &self.api_key)) + .headers(self.headers.clone()) + .json(&json!({ + "model": &request.model.model_name, + "query": &request.query, + "documents": &serialized_documents, + "top_n": request.top_n, + })) + .send() + .await? + .error_for_status()? + .json::() + .await?; + + let (rankings, ranked_documents) = response.results.iter().try_fold( + ( + Vec::with_capacity(response.results.len()), + Vec::with_capacity(response.results.len()), + ), + |(mut rankings, mut docs), result| { + let document = request.documents.get(result.index).ok_or( + ResponseGeneratorError::InvalidProviderResponse( + "Cohere returned an invalid index".to_string(), + ), + )?; + docs.push(document.clone()); + rankings.push(StructuredRanking { + original_index: result.index, + score: result.relevance_score, + document: document.clone(), + }); + Ok::<_, ResponseGeneratorError>((rankings, docs)) + }, + )?; + + Ok(StructuredRerankResponse { + rankings, + ranked_documents, + raw_fields: response.raw_fields, + }) + } +} diff --git a/crates/umem_ai/src/providers/google_vertex.rs b/crates/umem_ai/src/providers/google_vertex.rs index c117961..b2c311d 100644 --- a/crates/umem_ai/src/providers/google_vertex.rs +++ b/crates/umem_ai/src/providers/google_vertex.rs @@ -1,5 +1,6 @@ -use anyhow::{bail, Result}; +use anyhow::{Result, bail}; +#[derive(Clone, Debug)] pub struct GoogleVertexAIProvider { pub project: String, pub location: String, @@ -7,6 +8,7 @@ pub struct GoogleVertexAIProvider { pub credentials: GoogleCredentials, } +#[derive(Clone, Debug)] pub struct GoogleCredentials { client_email: String, private_key: String, diff --git a/crates/umem_ai/src/providers/mod.rs b/crates/umem_ai/src/providers/mod.rs index e996961..96567dc 100644 --- a/crates/umem_ai/src/providers/mod.rs +++ b/crates/umem_ai/src/providers/mod.rs @@ -1,6 +1,7 @@ mod amazon_bedrock; mod anthropic; mod azure_openai; +mod cohere; mod google_vertex; mod openai; mod xai; @@ -8,6 +9,7 @@ mod xai; pub use amazon_bedrock::AmazonBedrockProvider; pub use anthropic::AnthropicProvider; pub use azure_openai::AzureOpenAIProvider; +pub use cohere::CohereProvider; pub use google_vertex::GoogleVertexAIProvider; -pub use openai::*; +pub use openai::OpenAIProvider; pub use xai::XAIProvider; diff --git a/crates/umem_ai/src/providers/openai.rs b/crates/umem_ai/src/providers/openai.rs index 81e03a9..a59aa70 100644 --- a/crates/umem_ai/src/providers/openai.rs +++ b/crates/umem_ai/src/providers/openai.rs @@ -1,17 +1,17 @@ use crate::{ - reqwest_client, + GeneratesObject, GeneratesText, reqwest_client, response_generators::{ + GenerateTextRequest, GenerateTextResponse, ResponseGeneratorError, generate_object::{GenerateObjectRequest, GenerateObjectResponse}, messages::{FilePart, Message, UserMessagePart, UserModelMessage}, - GenerateTextRequest, GenerateTextResponse, ResponseGeneratorError, }, - utils, GeneratesObject, GeneratesText, + utils, }; use async_trait::async_trait; use base64::Engine; use reqwest::header::HeaderMap; use schemars::JsonSchema; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde_json::{Map, Value}; use typed_builder::TypedBuilder; @@ -23,9 +23,7 @@ pub struct OpenAIProvider { #[builder(default = "https://api.openai.com/v1".into(), setter(transform = |value: impl Into| value.into()))] pub base_url: String, - #[builder(default, setter(transform = |value: Vec<(String, String)>| - utils::build_header_map(value.as_slice()).unwrap_or_default() - ))] + #[builder(default, setter(transform = |value: Vec<(String, String)>| utils::build_header_map(value.as_slice()).unwrap_or_default()))] pub default_headers: HeaderMap, #[builder(default)] @@ -42,8 +40,8 @@ impl OpenAIProvider { &self, request: &GenerateObjectRequest, ) -> String { - let system = self.normalize_system_message(&request.messages); - let normalized_user_messages = self.normalize_user_messages(&request.messages); + let system = Self::normalize_system_message(&request.messages); + let normalized_user_messages = Self::normalize_user_messages(&request.messages); let schema = request.output_schema.clone(); let name = std::any::type_name::() .split("::") @@ -78,8 +76,8 @@ impl OpenAIProvider { } pub fn normalize_generate_text_request(&self, request: &GenerateTextRequest) -> String { - let system = self.normalize_system_message(&request.messages); - let normalized_user_messages = self.normalize_user_messages(&request.messages); + let system = Self::normalize_system_message(&request.messages); + let normalized_user_messages = Self::normalize_user_messages(&request.messages); serde_json::json!({ "model": request.model.model_name, @@ -102,7 +100,7 @@ impl OpenAIProvider { .to_string() } - fn normalize_system_message(&self, messages: &[Message]) -> String { + pub(crate) fn normalize_system_message(messages: &[Message]) -> String { messages .iter() .find_map(|msg| match msg { @@ -113,7 +111,7 @@ impl OpenAIProvider { .into() } - fn normalize_user_messages(&self, messages: &[Message]) -> Vec { + pub(crate) fn normalize_user_messages(messages: &[Message]) -> Vec { let user_messages: Vec<&UserModelMessage> = messages .iter() .filter_map(|msg| match msg { @@ -135,28 +133,28 @@ impl OpenAIProvider { serde_json::json!({"type":"input_text","text":input_text}) } UserMessagePart::Image(image_part) => match image_part { - FilePart::Url(ref image_url, _) => { + FilePart::Url(image_url, _) => { serde_json::json!({"type":"input_image","image_url":image_url}) } - FilePart::Base64(ref b64,ref media_type) => { + FilePart::Base64(b64, media_type) => { let media_type = media_type.clone().unwrap_or(mime::IMAGE_PNG); serde_json::json!({"type":"input_image","image_url":format!("data:{};base64,{}", media_type.to_string(), b64)}) } - FilePart::Buffer(buf,ref media_type) => { + FilePart::Buffer(buf, media_type) => { let buf_as_b64 = base64::engine::general_purpose::STANDARD.encode(buf); let media_type = media_type.clone().unwrap_or(mime::IMAGE_PNG); serde_json::json!({"type":"input_image","image_url":format!("data:{};base64,{}", media_type.to_string(), buf_as_b64)}) }, }, - UserMessagePart::File(file_part) => match file_part{ - FilePart::Url(ref image_url, _) => { + UserMessagePart::File(file_part) => match file_part { + FilePart::Url(image_url, _) => { serde_json::json!({"type":"input_file","file_url":image_url}) } - FilePart::Base64(ref b64,ref media_type) => { + FilePart::Base64(b64, media_type) => { let media_type = media_type.clone().unwrap_or(mime::IMAGE_PNG); serde_json::json!({"type":"input_file","file_url":format!("data:{};base64,{}", media_type.to_string(), b64)}) } - FilePart::Buffer(buf,ref media_type) => { + FilePart::Buffer(buf, media_type) => { let buf_as_b64 = base64::engine::general_purpose::STANDARD.encode(buf); let media_type = media_type.clone().unwrap_or(mime::IMAGE_PNG); serde_json::json!({"type":"input_file","file_url":format!("data:{};base64,{}", media_type.to_string(), buf_as_b64)}) @@ -264,8 +262,8 @@ impl GeneratesObject for OpenAIProvider { }) .unwrap_or_default(); - let output: T = - serde_json::from_str(&output_text).map_err(ResponseGeneratorError::Serialization)?; + let output: T = serde_json::from_str(&output_text) + .map_err(|e| ResponseGeneratorError::Deserialization(e, output_text))?; Ok(GenerateObjectResponse { output }) } @@ -347,11 +345,12 @@ mod tests { use super::*; use crate::{ + AIProvider, LanguageModel, response_generators::{ - generate_object::{generate_object, GenerateObjectRequestBuilder}, - generate_text, GenerateTextRequestBuilder, + GenerateTextRequestBuilder, + generate_object::{GenerateObjectRequestBuilder, generate_object}, + generate_text, }, - AIProvider, LanguageModel, }; use std::sync::Arc; @@ -361,7 +360,7 @@ mod tests { OpenAIProvider::builder() .api_key("") .base_url("https://openrouter.ai/api/v1") - .build() + .build(), )); #[derive(Clone, JsonSchema, Serialize, Deserialize, Debug)] @@ -395,7 +394,7 @@ mod tests { OpenAIProvider::builder() .api_key("") .base_url("https://openrouter.ai/api/v1") - .build() + .build(), )); let model = Arc::new(LanguageModel { diff --git a/crates/umem_ai/src/providers/xai.rs b/crates/umem_ai/src/providers/xai.rs index 3610a29..da968e1 100644 --- a/crates/umem_ai/src/providers/xai.rs +++ b/crates/umem_ai/src/providers/xai.rs @@ -1,5 +1,6 @@ -use anyhow::{bail, Result}; +use anyhow::{Result, bail}; +#[derive(Clone, Debug)] pub struct XAIProvider { pub api_key: String, pub base_url: String, diff --git a/crates/umem_ai/src/response_generators/generate_object.rs b/crates/umem_ai/src/response_generators/generate_object.rs index b53a5db..33b6e5b 100644 --- a/crates/umem_ai/src/response_generators/generate_object.rs +++ b/crates/umem_ai/src/response_generators/generate_object.rs @@ -1,9 +1,9 @@ +use crate::{LanguageModel, ResponseGeneratorError, utils}; use crate::{response_generators::messages::Message, utils::is_retryable_error}; -use crate::{utils, LanguageModel, ResponseGeneratorError}; use backon::{ExponentialBuilder, Retryable}; use reqwest::header::HeaderMap; -use schemars::{schema_for, JsonSchema, Schema}; -use serde::{de::DeserializeOwned, Serialize}; +use schemars::{JsonSchema, Schema, schema_for}; +use serde::{Serialize, de::DeserializeOwned}; use std::time::Duration; use std::{marker::PhantomData, sync::Arc}; use thiserror::Error; @@ -76,6 +76,15 @@ where pub timeout: Duration, } +impl GenerateObjectRequest +where + T: Clone + JsonSchema + Send + Sync + Serialize + DeserializeOwned, +{ + fn builder() -> GenerateObjectRequestBuilder { + GenerateObjectRequestBuilder::new() + } +} + #[derive(Debug)] pub struct GenerateObjectResponse { pub output: T, diff --git a/crates/umem_ai/src/response_generators/generate_text.rs b/crates/umem_ai/src/response_generators/generate_text.rs index 2e58bad..64446fa 100644 --- a/crates/umem_ai/src/response_generators/generate_text.rs +++ b/crates/umem_ai/src/response_generators/generate_text.rs @@ -1,8 +1,8 @@ +use crate::LanguageModel; +use crate::ResponseGeneratorError; use crate::response_generators::messages::Message; use crate::utils; use crate::utils::is_retryable_error; -use crate::LanguageModel; -use crate::ResponseGeneratorError; use backon::ExponentialBuilder; use backon::Retryable; use reqwest::header::HeaderMap; @@ -65,6 +65,12 @@ pub struct GenerateTextRequest { pub timeout: Duration, } +impl GenerateTextRequest { + pub fn builder() -> GenerateTextRequestBuilder { + GenerateTextRequestBuilder::new() + } +} + #[derive(Debug, Error)] pub enum GenerateTextRequestBuilderError { #[error("either set the `system` field or provide a system message in `messages` array")] diff --git a/crates/umem_ai/src/response_generators/mod.rs b/crates/umem_ai/src/response_generators/mod.rs index 63b79d9..b967ef0 100644 --- a/crates/umem_ai/src/response_generators/mod.rs +++ b/crates/umem_ai/src/response_generators/mod.rs @@ -1,19 +1,36 @@ pub mod generate_object; -mod generate_text; +pub mod generate_text; pub mod messages; +pub mod rerank; +pub mod structured_rerank; pub use generate_object::*; pub use generate_text::*; +pub use messages::*; +pub use rerank::*; +pub use structured_rerank::*; use thiserror::Error; #[derive(Error, Debug)] pub enum ResponseGeneratorError { #[error(transparent)] Http(#[from] reqwest::Error), + #[error("deserialization error, Details: {1}, Response: {0}")] + Deserialization(serde_json::Error, String), #[error(transparent)] - Serialization(#[from] serde_json::Error), + TimeoutError(#[from] tokio::time::error::Elapsed), + #[error("Bedrock Converse API error, Details: {0}")] + BedrockConverseError(String), + #[error("BedrockAgentRuntime Rerank Command error, Details: {0}")] + BedrockAgentRerankCommandSendError(String), + #[error("empty response from AI provider")] + EmptyProviderResponse, + #[error("invalid response from AI provider, Details: {0}")] + InvalidProviderResponse(String), + #[error("invalid arguments provided: {0}")] + InvalidArgumentsProvided(String), #[error(transparent)] Transient(#[from] anyhow::Error), - #[error(transparent)] - TimeoutError(#[from] tokio::time::error::Elapsed), + #[error("yaml serialization error: {0}")] + StructuredRerankDocumentsSerializationError(String), } diff --git a/crates/umem_ai/src/response_generators/rerank.rs b/crates/umem_ai/src/response_generators/rerank.rs new file mode 100644 index 0000000..bcfde7d --- /dev/null +++ b/crates/umem_ai/src/response_generators/rerank.rs @@ -0,0 +1,159 @@ +use backon::{ExponentialBuilder, Retryable}; +use serde_json::{Map, Value}; + +use crate::{RerankingModel, ResponseGeneratorError, utils::is_retryable_error}; +use std::{sync::Arc, time::Duration}; + +pub async fn rerank(request: RerankRequest) -> Result { + let per_request_timeout = request.timeout; + let max_retries = request.max_retries; + let total_delay = per_request_timeout.mul_f32(max_retries as f32 / 2.0); + + let reranking_request = || { + let model = Arc::clone(&request.model); + let provider = Arc::clone(&model.provider); + let request = request.clone(); + + async move { + tokio::time::timeout(per_request_timeout, provider.do_reranking(request)) + .await + .map_err(ResponseGeneratorError::TimeoutError) + .flatten() + } + }; + + reranking_request + .retry( + ExponentialBuilder::default() + .with_max_times(max_retries) + .with_total_delay(Some(total_delay)), + ) + .sleep(tokio::time::sleep) + .when(is_retryable_error) + .notify(|err, dur| { + tracing::debug!("retrying {:?} after {:?}", err, dur); + }) + .await +} + +#[derive(Clone)] +pub struct RerankRequest { + pub query: String, + pub documents: Vec, + pub top_k: usize, + pub timeout: Duration, + pub max_retries: usize, + pub model: Arc, +} + +impl RerankRequest { + pub fn builder() -> RerankRequestBuilder { + RerankRequestBuilder { + top_k: 5, + ..Default::default() + } + } +} + +pub struct RerankRequestBuilder { + query: Option, + documents: Vec, + top_k: usize, + timeout: Duration, + max_retries: usize, + model: Option>, +} + +impl Default for RerankRequestBuilder { + fn default() -> Self { + Self { + query: None, + documents: vec![], + top_k: 5, + timeout: Duration::from_secs(30), + max_retries: 3, + model: None, + } + } +} + +#[derive(thiserror::Error, Debug)] +pub enum RerankRequestBuilderError { + #[error("missing query from rerank request")] + MissingQuery, + + #[error("at least one document is required in rerank request")] + EmptyDocuments, + + #[error("missing model while sending reranking request")] + MissingModel, +} + +impl RerankRequestBuilder { + pub fn query(mut self, query: impl Into) -> Self { + self.query = Some(query.into()); + self + } + + pub fn documents(mut self, documents: I) -> Self + where + I: IntoIterator, + { + self.documents.extend(documents); + self + } + + pub fn document(mut self, document: impl Into) -> Self { + self.documents.push(document.into()); + self + } + + pub fn top_k(mut self, top_k: usize) -> Self { + self.top_k = top_k; + self + } + + pub fn timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; + self + } + + pub fn max_retries(mut self, max_retries: usize) -> Self { + self.max_retries = max_retries; + self + } + + pub fn model(mut self, model: Arc) -> Self { + self.model = Some(model); + self + } + + pub fn build(self) -> Result { + if self.documents.is_empty() { + return Err(RerankRequestBuilderError::EmptyDocuments); + } + + Ok(RerankRequest { + query: self.query.ok_or(RerankRequestBuilderError::MissingQuery)?, + documents: self.documents, + top_k: self.top_k, + timeout: self.timeout, + max_retries: self.max_retries, + model: Arc::clone(&self.model.ok_or(RerankRequestBuilderError::MissingModel)?), + }) + } +} + +#[derive(Debug)] +pub struct RerankResponse { + pub rankings: Vec, + pub ranked_documents: Vec, + pub raw_fields: Map, +} + +#[derive(Debug)] +pub struct Ranking { + pub original_index: usize, + pub score: f32, + pub document: String, +} diff --git a/crates/umem_ai/src/response_generators/structured_rerank.rs b/crates/umem_ai/src/response_generators/structured_rerank.rs new file mode 100644 index 0000000..64b7ff9 --- /dev/null +++ b/crates/umem_ai/src/response_generators/structured_rerank.rs @@ -0,0 +1,223 @@ +use std::{sync::Arc, time::Duration}; + +use backon::{ExponentialBuilder, Retryable}; +use serde::{Serialize, de::DeserializeOwned}; +use serde_json::{Map, Value}; + +use crate::{RerankingModel, ResponseGeneratorError, utils::is_retryable_error}; + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub enum SerializationFormat { + #[default] + Compact, + Pretty, +} + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub enum SerializationMode { + #[default] + Json, + Yaml, +} + +pub async fn structured_rerank( + request: StructuredRerankRequest, +) -> Result, ResponseGeneratorError> +where + T: Serialize + Clone + Send + Sync, +{ + let per_request_timeout = request.timeout; + let max_retries = request.max_retries; + let total_delay = per_request_timeout.mul_f32(max_retries as f32 / 2.0); + + let reranking_request = || { + let model = Arc::clone(&request.model); + let provider = Arc::clone(&model.provider); + let request = request.clone(); + + async move { + tokio::time::timeout( + per_request_timeout, + provider.do_structured_reranking(request), + ) + .await + .map_err(ResponseGeneratorError::TimeoutError) + .flatten() + } + }; + + reranking_request + .retry( + ExponentialBuilder::default() + .with_max_times(max_retries) + .with_total_delay(Some(total_delay)), + ) + .sleep(tokio::time::sleep) + .when(is_retryable_error) + .notify(|err, dur| { + tracing::debug!("retrying {:?} after {:?}", err, dur); + }) + .await +} + +#[derive(Clone, Debug)] +pub struct StructuredRerankRequest +where + T: Serialize + Clone, +{ + pub query: String, + pub documents: Vec, + pub top_n: usize, + pub timeout: Duration, + pub max_retries: usize, + pub model: Arc, + pub serialization_format: SerializationFormat, + pub serialization_mode: SerializationMode, +} + +impl StructuredRerankRequest +where + T: Serialize + Clone, +{ + pub fn builder() -> StructuredRerankRequestBuilder { + StructuredRerankRequestBuilder::default() + } +} + +pub struct StructuredRerankRequestBuilder +where + T: Serialize + Clone, +{ + query: Option, + documents: Vec, + top_n: usize, + timeout: Duration, + max_retries: usize, + model: Option>, + serialization_format: SerializationFormat, + serialization_mode: SerializationMode, +} + +impl Default for StructuredRerankRequestBuilder +where + T: Serialize + Clone, +{ + fn default() -> Self { + Self { + query: None, + documents: vec![], + top_n: 5, + timeout: Duration::from_secs(30), + max_retries: 3, + model: None, + serialization_format: SerializationFormat::default(), + serialization_mode: SerializationMode::default(), + } + } +} + +#[derive(thiserror::Error, Debug)] +pub enum StructuredRerankRequestBuilderError { + #[error("missing query from structured rerank request")] + MissingQuery, + + #[error("at least one document is required in structured rerank request")] + EmptyDocuments, + + #[error("missing model while sending structured reranking request")] + MissingModel, +} + +impl StructuredRerankRequestBuilder +where + T: Serialize + DeserializeOwned + Clone, +{ + pub fn query(mut self, query: impl Into) -> Self { + self.query = Some(query.into()); + self + } + + pub fn documents(mut self, documents: I) -> Self + where + I: IntoIterator, + { + self.documents.extend(documents); + self + } + + pub fn document(mut self, document: T) -> Self { + self.documents.push(document); + self + } + + pub fn top_k(mut self, top_n: usize) -> Self { + self.top_n = top_n; + self + } + + pub fn timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; + self + } + + pub fn max_retries(mut self, max_retries: usize) -> Self { + self.max_retries = max_retries; + self + } + + pub fn model(mut self, model: Arc) -> Self { + self.model = Some(model); + self + } + + pub fn serialization_format(mut self, format: SerializationFormat) -> Self { + self.serialization_format = format; + self + } + + pub fn serialization_mode(mut self, mode: SerializationMode) -> Self { + self.serialization_mode = mode; + self + } + + pub fn build(self) -> Result, StructuredRerankRequestBuilderError> { + if self.documents.is_empty() { + return Err(StructuredRerankRequestBuilderError::EmptyDocuments); + } + + Ok(StructuredRerankRequest { + query: self + .query + .ok_or(StructuredRerankRequestBuilderError::MissingQuery)?, + documents: self.documents, + top_n: self.top_n, + timeout: self.timeout, + max_retries: self.max_retries, + model: self + .model + .ok_or(StructuredRerankRequestBuilderError::MissingModel)?, + serialization_format: self.serialization_format, + serialization_mode: self.serialization_mode, + }) + } +} + +#[derive(Clone, Debug)] +pub struct StructuredRerankResponse +where + T: Serialize + Clone, +{ + pub rankings: Vec>, + pub ranked_documents: Vec, + pub raw_fields: Map, +} + +#[derive(Clone, Debug)] +pub struct StructuredRanking +where + T: Serialize + Clone, +{ + pub original_index: usize, + pub score: f32, + pub document: T, +} diff --git a/crates/umem_ai/src/utils.rs b/crates/umem_ai/src/utils.rs index f00485a..8edd6d1 100644 --- a/crates/umem_ai/src/utils.rs +++ b/crates/umem_ai/src/utils.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use crate::response_generators::ResponseGeneratorError; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use thiserror::Error; @@ -15,10 +17,11 @@ pub fn is_retryable_error(e: &ResponseGeneratorError) -> bool { s.is_server_error() || s == reqwest::StatusCode::TOO_MANY_REQUESTS }) } - ResponseGeneratorError::Serialization(error) => { + ResponseGeneratorError::Deserialization(error, response) => { tracing::warn!( - "Serialization error, AI Might have built a bad JSON output: {}", - error + "Serialization error, AI Might have built a bad JSON output: \n Error: {} \n Received Response: {}", + error, + response ); true } @@ -36,6 +39,33 @@ pub fn is_retryable_error(e: &ResponseGeneratorError) -> bool { ); false } + ResponseGeneratorError::InvalidProviderResponse(e) => { + tracing::error!("Invalid response from AI provider: {}", e); + true + } + ResponseGeneratorError::EmptyProviderResponse => { + tracing::error!("Empty response from AI provider"); + true + } + ResponseGeneratorError::BedrockConverseError(sdk_error) => { + tracing::error!( + "AWS Bedrock Converse SDK error occurred when communicating with AI provider: {}", + sdk_error + ); + true + } + ResponseGeneratorError::StructuredRerankDocumentsSerializationError(e) => { + tracing::error!("Structured rerank documents serialization error: {}", e); + false + } + ResponseGeneratorError::InvalidArgumentsProvided(e) => { + tracing::error!("Invalid arguments provided: {}", e); + false + } + ResponseGeneratorError::BedrockAgentRerankCommandSendError(e) => { + tracing::error!("Bedrock agent rerank command send error: {}", e); + true + } } } @@ -68,3 +98,53 @@ pub fn build_header_map(headers: &[(String, String)]) -> Result aws_smithy_types::Document { + match value { + serde_json::Value::Null => aws_smithy_types::Document::Null, + serde_json::Value::Bool(b) => aws_smithy_types::Document::Bool(b), + serde_json::Value::Number(n) => { + if let Some(i) = n.as_i64() { + aws_smithy_types::Document::from(i) + } else if let Some(u) = n.as_u64() { + aws_smithy_types::Document::from(u) + } else if let Some(f) = n.as_f64() { + aws_smithy_types::Document::from(f) + } else { + aws_smithy_types::Document::Null + } + } + serde_json::Value::String(s) => aws_smithy_types::Document::String(s), + serde_json::Value::Array(arr) => aws_smithy_types::Document::Array( + arr.into_iter().map(json_to_aws_smithy_document).collect(), + ), + serde_json::Value::Object(obj) => aws_smithy_types::Document::Object( + obj.into_iter() + .map(|(k, v)| (k, json_to_aws_smithy_document(v))) + .collect::>(), + ), + } +} + +pub fn aws_smithy_document_to_json(doc: &aws_smithy_types::Document) -> serde_json::Value { + match doc { + aws_smithy_types::Document::Null => serde_json::Value::Null, + aws_smithy_types::Document::Bool(b) => serde_json::Value::Bool(*b), + aws_smithy_types::Document::Number(n) => match n { + aws_smithy_types::Number::PosInt(u) => serde_json::Value::Number((*u).into()), + aws_smithy_types::Number::NegInt(i) => serde_json::Value::Number((*i).into()), + aws_smithy_types::Number::Float(f) => serde_json::Number::from_f64((*f).into()) + .map(serde_json::Value::Number) + .unwrap_or(serde_json::Value::Null), + }, + aws_smithy_types::Document::String(s) => serde_json::Value::String(s.clone()), + aws_smithy_types::Document::Array(arr) => { + serde_json::Value::Array(arr.into_iter().map(aws_smithy_document_to_json).collect()) + } + aws_smithy_types::Document::Object(obj) => serde_json::Value::Object( + obj.into_iter() + .map(|(k, v)| (k.clone(), aws_smithy_document_to_json(v))) + .collect(), + ), + } +}