diff --git a/Cargo.lock b/Cargo.lock index 9576d0ae..742e0a18 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -97,6 +97,12 @@ dependencies = [ "wait-timeout", ] +[[package]] +name = "assert_matches" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b34d609dfbaf33d6889b2b7106d3ca345eacad44200913df5ba02bfd31d2ba9" + [[package]] name = "async-stream" version = "0.3.6" @@ -1261,6 +1267,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minicov" version = "0.3.7" @@ -1588,6 +1604,7 @@ dependencies = [ name = "progenitor-client" version = "0.11.2" dependencies = [ + "assert_matches", "bytes", "futures-core", "percent-encoding", @@ -1759,6 +1776,7 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "native-tls", "percent-encoding", "pin-project-lite", @@ -2721,6 +2739,12 @@ dependencies = [ "typify-impl", ] +[[package]] +name = "unicase" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" + [[package]] name = "unicode-ident" version = "1.0.22" diff --git a/Cargo.toml b/Cargo.toml index 65f2cdd1..0ca0e02c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,7 +41,7 @@ quote = "1.0.42" rand = "0.9.2" regex = "1.11.2" regress = "0.10.5" -reqwest = { version = "0.12.28", default-features = false, features = ["json", "stream"] } +reqwest = { version = "0.12.28", default-features = false, features = ["json", "stream", "multipart"] } rustfmt-wrapper = "0.2.1" schemars = { version = "0.8.22", features = ["chrono", "uuid1"] } semver = "1.0.27" @@ -58,6 +58,7 @@ typify = { version = "0.5.0" } url = "2.5.7" unicode-ident = "1.0.22" uuid = { version = "1.19.0", features = ["serde", "v4"] } +assert_matches = "1.5.0" #[patch."https://github.com/oxidecomputer/typify"] #typify = { path = "../typify/typify" } diff --git a/progenitor-client/Cargo.toml b/progenitor-client/Cargo.toml index b54cc20a..2562fb54 100644 --- a/progenitor-client/Cargo.toml +++ b/progenitor-client/Cargo.toml @@ -16,5 +16,6 @@ serde_json = { workspace = true } serde_urlencoded = { workspace = true } [dev-dependencies] +assert_matches = { workspace = true } url = { workspace = true } uuid = { workspace = true } diff --git a/progenitor-client/src/progenitor_client.rs b/progenitor-client/src/progenitor_client.rs index 6f8dcca0..a0db1b2f 100644 --- a/progenitor-client/src/progenitor_client.rs +++ b/progenitor-client/src/progenitor_client.rs @@ -8,7 +8,7 @@ use std::ops::{Deref, DerefMut}; use bytes::Bytes; use futures_core::Stream; -use reqwest::RequestBuilder; +use reqwest::{multipart::Part, RequestBuilder}; use serde::{de::DeserializeOwned, ser::SerializeStruct, Serialize}; #[cfg(not(target_arch = "wasm32"))] @@ -17,6 +17,404 @@ type InnerByteStream = std::pin::Pin>>>; +/// A validated filename for form part uploads. +/// +/// Filenames are validated to not contain path separators or null bytes, +/// preventing path traversal attacks. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Filename(String); + +impl Filename { + /// Create a new filename, validating that it doesn't contain path separators. + /// + /// # Errors + /// Returns an error if the filename contains `/`, `\`, or null bytes. + pub fn new(name: impl Into) -> Result { + let name = name.into(); + if name.contains('/') || name.contains('\\') || name.contains('\0') { + Err(FilenameError::InvalidCharacter) + } else if name.is_empty() { + Err(FilenameError::Empty) + } else { + Ok(Self(name)) + } + } + + /// Create a filename without validation. + /// + /// # Safety + /// The caller must ensure the filename doesn't contain path separators. + /// This is useful when the filename comes from a trusted source. + pub fn new_unchecked(name: impl Into) -> Self { + Self(name.into()) + } + + /// Get the filename as a string slice. + pub fn as_str(&self) -> &str { + &self.0 + } + + /// Consume the Filename and return the inner String. + pub fn into_string(self) -> String { + self.0 + } +} + +impl AsRef for Filename { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl std::fmt::Display for Filename { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +/// Error type for invalid filenames. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum FilenameError { + /// Filename contains path separators (`/`, `\`) or null bytes. + InvalidCharacter, + /// Filename is empty. + Empty, +} + +impl std::fmt::Display for FilenameError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + FilenameError::InvalidCharacter => { + write!(f, "filename contains invalid characters (/, \\, or null)") + } + FilenameError::Empty => write!(f, "filename cannot be empty"), + } + } +} + +impl std::error::Error for FilenameError {} + +/// A validated MIME content-type for form parts. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ContentType(String); + +impl ContentType { + /// Create a new content-type, validating the format. + /// + /// # Errors + /// Returns an error if the content-type doesn't follow `type/subtype` format. + pub fn new(content_type: impl Into) -> Result { + let content_type = content_type.into(); + let Some((type_part, rest)) = content_type.split_once('/') else { + return Err(ContentTypeError::InvalidFormat); + }; + let subtype_part = rest.split_once(';').map_or(rest, |(s, _)| s); + if type_part.trim().is_empty() + || subtype_part.trim().is_empty() + || subtype_part.contains('/') + { + return Err(ContentTypeError::InvalidFormat); + } + Ok(Self(content_type)) + } + + /// Create a content-type without validation. + /// + /// # Safety + /// The caller must ensure the content-type is in valid `type/subtype` format. + pub fn new_unchecked(content_type: impl Into) -> Self { + Self(content_type.into()) + } + + /// Get the content-type as a string slice. + pub fn as_str(&self) -> &str { + &self.0 + } + + /// Consume the ContentType and return the inner String. + pub fn into_string(self) -> String { + self.0 + } + + /// Create `application/json` content type. + pub fn json() -> Self { + Self("application/json".to_string()) + } + + /// Create `application/octet-stream` content type. + pub fn octet_stream() -> Self { + Self("application/octet-stream".to_string()) + } + + /// Create an `application/{subtype}` content type. + pub fn application(subtype: &str) -> Self { + Self(format!("application/{}", subtype)) + } + + /// Create an `image/{subtype}` content type. + pub fn image(subtype: &str) -> Self { + Self(format!("image/{}", subtype)) + } +} + +impl AsRef for ContentType { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl std::fmt::Display for ContentType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +/// Error type for invalid content-types. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ContentTypeError { + /// Content-type doesn't follow the `type/subtype` format. + InvalidFormat, +} + +impl std::fmt::Display for ContentTypeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "content-type must be in 'type/subtype' format") + } +} + +impl std::error::Error for ContentTypeError {} + +/// Binary form part data with optional filename and content-type. +#[derive(Debug, Clone)] +pub struct BinaryFormPart { + /// The binary data + pub data: Bytes, + /// Optional filename for the part + pub filename: Option, + /// Optional content-type override + pub content_type: Option, +} + +impl BinaryFormPart { + /// Create a new binary form part from bytes. + pub fn new(data: impl Into) -> Self { + Self { + data: data.into(), + filename: None, + content_type: None, + } + } + + /// Create a binary form part with a filename. + pub fn with_filename(data: impl Into, filename: Filename) -> Self { + Self { + data: data.into(), + filename: Some(filename), + content_type: None, + } + } + + /// Create a binary form part with filename and content-type. + pub fn with_metadata( + data: impl Into, + filename: Option, + content_type: Option, + ) -> Self { + Self { + data: data.into(), + filename, + content_type, + } + } + + /// Create a builder for a binary form part. + pub fn builder(data: impl Into) -> BinaryFormPartBuilder { + BinaryFormPartBuilder { + data: data.into(), + filename: None, + content_type: None, + } + } +} + +impl From for FormPart { + fn from(part: BinaryFormPart) -> Self { + FormPart::Binary(part) + } +} + +/// Text form part data with optional content-type. +#[derive(Debug, Clone)] +pub struct TextFormPart { + /// The text value + pub value: String, + /// Optional content-type override + pub content_type: Option, +} + +impl TextFormPart { + /// Create a new text form part. + pub fn new(value: impl Into) -> Self { + Self { + value: value.into(), + content_type: None, + } + } + + /// Create a text form part with a specific content-type. + pub fn with_content_type(value: impl Into, content_type: ContentType) -> Self { + Self { + value: value.into(), + content_type: Some(content_type), + } + } + + /// Create a JSON text form part. + pub fn json(value: &T) -> Result { + Ok(Self { + value: serde_json::to_string(value)?, + content_type: Some(ContentType::json()), + }) + } + + /// Create a builder for a text form part. + pub fn builder(value: impl Into) -> TextFormPartBuilder { + TextFormPartBuilder { + value: value.into(), + content_type: None, + } + } +} + +impl From for FormPart { + fn from(part: TextFormPart) -> Self { + FormPart::Text(part) + } +} + +/// A part of a multipart form, either binary data or text. +#[derive(Debug, Clone)] +pub enum FormPart { + /// Binary data (e.g., file contents) with optional filename and content-type + Binary(BinaryFormPart), + /// Text data (will be sent as a text field) with optional content-type + Text(TextFormPart), +} + +impl FormPart { + /// Create a binary form part from bytes + pub fn binary(data: impl Into) -> Self { + Self::Binary(BinaryFormPart::new(data)) + } + + /// Create a binary form part with a filename + pub fn binary_with_filename(data: impl Into, filename: Filename) -> Self { + Self::Binary(BinaryFormPart::with_filename(data, filename)) + } + + /// Create a binary form part with filename and content-type + pub fn binary_with_metadata( + data: impl Into, + filename: Option, + content_type: Option, + ) -> Self { + Self::Binary(BinaryFormPart::with_metadata(data, filename, content_type)) + } + + /// Create a text form part from a string + pub fn text(data: impl Into) -> Self { + Self::Text(TextFormPart::new(data)) + } + + /// Create a text form part with a specific content-type + pub fn text_with_content_type(data: impl Into, content_type: ContentType) -> Self { + Self::Text(TextFormPart::with_content_type(data, content_type)) + } + + /// Create a JSON text form part + pub fn json(value: &T) -> Result { + Ok(Self::Text(TextFormPart::json(value)?)) + } + + /// Create a builder for a binary form part + pub fn binary_builder(data: impl Into) -> BinaryFormPartBuilder { + BinaryFormPart::builder(data) + } + + /// Create a builder for a text form part + pub fn text_builder(value: impl Into) -> TextFormPartBuilder { + TextFormPart::builder(value) + } +} + +/// Builder for binary form parts. +/// +/// Created via [`FormPart::binary_builder`] or [`BinaryFormPart::builder`]. +#[derive(Debug, Clone)] +pub struct BinaryFormPartBuilder { + data: Bytes, + filename: Option, + content_type: Option, +} + +impl BinaryFormPartBuilder { + /// Set the filename for this part. + pub fn filename(mut self, filename: Filename) -> Self { + self.filename = Some(filename); + self + } + + /// Set the content-type for this part. + pub fn content_type(mut self, content_type: ContentType) -> Self { + self.content_type = Some(content_type); + self + } + + /// Build the BinaryFormPart. + pub fn build(self) -> BinaryFormPart { + BinaryFormPart { + data: self.data, + filename: self.filename, + content_type: self.content_type, + } + } + + /// Build directly into a FormPart. + pub fn into_form_part(self) -> FormPart { + FormPart::Binary(self.build()) + } +} + +/// Builder for text form parts. +/// +/// Created via [`FormPart::text_builder`] or [`TextFormPart::builder`]. +#[derive(Debug, Clone)] +pub struct TextFormPartBuilder { + value: String, + content_type: Option, +} + +impl TextFormPartBuilder { + /// Set the content-type for this part. + pub fn content_type(mut self, content_type: ContentType) -> Self { + self.content_type = Some(content_type); + self + } + + /// Build the TextFormPart. + pub fn build(self) -> TextFormPart { + TextFormPart { + value: self.value, + content_type: self.content_type, + } + } + + /// Build directly into a FormPart. + pub fn into_form_part(self) -> FormPart { + FormPart::Text(self.build()) + } +} + /// Untyped byte stream used for both success and error responses. pub struct ByteStream(InnerByteStream); @@ -527,8 +925,21 @@ pub fn encode_path(pc: &str) -> String { } #[doc(hidden)] -pub trait RequestBuilderExt { +pub trait RequestBuilderExt +where + Self: Sized, +{ fn form_urlencoded(self, body: &T) -> Result>; + + fn form_from_raw, T: AsRef<[u8]>, I: Sized + IntoIterator>( + self, + iter: I, + ) -> Result>; + + fn form_from_parts, I: Sized + IntoIterator>( + self, + iter: I, + ) -> Result>; } impl RequestBuilderExt for RequestBuilder { @@ -543,6 +954,70 @@ impl RequestBuilderExt for RequestBuilder { .map_err(|_| Error::InvalidRequest("failed to serialize body".to_string()))?, )) } + + fn form_from_raw, T: AsRef<[u8]>, I: Sized + IntoIterator>( + self, + iter: I, + ) -> Result> { + use reqwest::multipart::Form; + + let mut form = Form::new(); + for (name, value) in iter { + form = form.part( + name.as_ref().to_owned(), + Part::stream(Vec::from(value.as_ref())), + ); + } + // Note: reqwest's .multipart() automatically sets the Content-Type header + // with the correct boundary, so we don't set it manually here. + Ok(self.multipart(form)) + } + + fn form_from_parts, I: Sized + IntoIterator>( + self, + iter: I, + ) -> Result> { + use reqwest::multipart::Form; + + let mut form = Form::new(); + for (name, part) in iter { + let name = name.as_ref().to_owned(); + form = match part { + FormPart::Binary(BinaryFormPart { + data, + filename, + content_type, + }) => { + let mut p = Part::stream(data.to_vec()); + if let Some(fname) = filename { + p = p.file_name(fname.into_string()); + } + if let Some(ct) = content_type { + p = p.mime_str(ct.as_str()).map_err(|e| { + Error::InvalidRequest(format!("invalid content-type: {}", e)) + })?; + } + form.part(name, p) + } + FormPart::Text(TextFormPart { + value, + content_type, + }) => { + if let Some(ct) = content_type { + let p = Part::text(value).mime_str(ct.as_str()).map_err(|e| { + Error::InvalidRequest(format!("invalid content-type: {}", e)) + })?; + form.part(name, p) + } else { + form.text(name, value) + } + } + }; + } + // Note: reqwest's .multipart() automatically sets the Content-Type header + // with the correct boundary, so we don't set it manually here. + Ok(self.multipart(form)) + } } #[doc(hidden)] diff --git a/progenitor-client/tests/client_test.rs b/progenitor-client/tests/client_test.rs index 6cd6a0d9..6c300b62 100644 --- a/progenitor-client/tests/client_test.rs +++ b/progenitor-client/tests/client_test.rs @@ -219,3 +219,366 @@ fn test_query_option() { let result = encode_query_param("paramName", &value).unwrap(); assert_eq!(result, "paramName=42"); } + +mod form_part_tests { + use assert_matches::assert_matches; + use bytes::Bytes; + use progenitor_client::{BinaryFormPart, ContentType, Filename, FormPart, TextFormPart}; + use serde::Serialize; + + #[test] + fn test_form_part_binary() { + let data = vec![1u8, 2, 3, 4]; + let part = FormPart::binary(data.clone()); + + assert_matches!(part, FormPart::Binary(BinaryFormPart { data: d, filename: None, content_type: None }) => { + assert_eq!(d.as_ref(), &data[..]); + }); + } + + #[test] + fn test_form_part_binary_from_bytes() { + let data = Bytes::from_static(b"hello world"); + let part = FormPart::binary(data.clone()); + + assert_matches!(part, FormPart::Binary(BinaryFormPart { data: d, .. }) => { + assert_eq!(d, data); + }); + } + + #[test] + fn test_form_part_binary_with_filename() { + let part = + FormPart::binary_with_filename(vec![1u8, 2, 3], Filename::new("test.bin").unwrap()); + + assert_matches!(part, FormPart::Binary(BinaryFormPart { filename: Some(f), content_type: None, .. }) => { + assert_eq!(f.as_str(), "test.bin"); + }); + } + + #[test] + fn test_form_part_binary_with_metadata() { + let part = FormPart::binary_with_metadata( + vec![1u8, 2, 3], + Some(Filename::new("document.pdf").unwrap()), + Some(ContentType::application("pdf")), + ); + + assert_matches!(part, FormPart::Binary(BinaryFormPart { filename: Some(f), content_type: Some(ct), .. }) => { + assert_eq!(f.as_str(), "document.pdf"); + assert_eq!(ct.as_str(), "application/pdf"); + }); + } + + #[test] + fn test_form_part_binary_with_metadata_none() { + let part = FormPart::binary_with_metadata(vec![1u8, 2, 3], None, None); + assert_matches!( + part, + FormPart::Binary(BinaryFormPart { + filename: None, + content_type: None, + .. + }) + ); + } + + #[test] + fn test_form_part_text() { + let part = FormPart::text("hello world"); + + assert_matches!(part, FormPart::Text(TextFormPart { value, content_type: None }) => { + assert_eq!(value, "hello world"); + }); + } + + #[test] + fn test_form_part_text_from_string() { + let part = FormPart::text(String::from("hello world")); + + assert_matches!(part, FormPart::Text(TextFormPart { value, .. }) => { + assert_eq!(value, "hello world"); + }); + } + + #[test] + fn test_form_part_text_with_content_type() { + let part = FormPart::text_with_content_type("{\"key\": \"value\"}", ContentType::json()); + + assert_matches!(part, FormPart::Text(TextFormPart { value, content_type: Some(ct) }) => { + assert_eq!(value, "{\"key\": \"value\"}"); + assert_eq!(ct.as_str(), "application/json"); + }); + } + + #[test] + fn test_form_part_json() { + #[derive(Serialize)] + struct TestData { + name: String, + count: i32, + } + + let data = TestData { + name: "test".to_string(), + count: 42, + }; + let part = FormPart::json(&data).unwrap(); + + assert_matches!(part, FormPart::Text(TextFormPart { value, content_type: Some(ct) }) => { + assert_eq!(value, r#"{"name":"test","count":42}"#); + assert_eq!(ct.as_str(), "application/json"); + }); + } + + #[test] + fn test_form_part_json_array() { + let part = FormPart::json(&vec![1, 2, 3, 4, 5]).unwrap(); + + assert_matches!(part, FormPart::Text(TextFormPart { value, content_type: Some(ct) }) => { + assert_eq!(value, "[1,2,3,4,5]"); + assert_eq!(ct.as_str(), "application/json"); + }); + } + + #[test] + fn test_form_part_json_nested() { + #[derive(Serialize)] + struct Inner { + value: String, + } + + #[derive(Serialize)] + struct Outer { + inner: Inner, + tags: Vec, + } + + let data = Outer { + inner: Inner { + value: "nested".to_string(), + }, + tags: vec!["a".to_string(), "b".to_string()], + }; + let part = FormPart::json(&data).unwrap(); + + assert_matches!(part, FormPart::Text(TextFormPart { value, .. }) => { + assert!(value.contains("\"inner\"")); + assert!(value.contains("\"nested\"")); + assert!(value.contains("\"tags\"")); + }); + } + + #[test] + fn test_form_part_clone() { + let part1 = + FormPart::binary_with_filename(vec![1, 2, 3], Filename::new("test.bin").unwrap()); + let part2 = part1.clone(); + + assert_matches!( + (part1, part2), + (FormPart::Binary(BinaryFormPart { data: d1, filename: f1, .. }), FormPart::Binary(BinaryFormPart { data: d2, filename: f2, .. })) => { + assert_eq!(d1, d2); + assert_eq!(f1, f2); + } + ); + } + + #[test] + fn test_form_part_debug() { + let part = FormPart::text("test"); + let debug_str = format!("{:?}", part); + assert!(debug_str.contains("Text")); + assert!(debug_str.contains("test")); + } + + // Tests for the new type-safe newtypes + #[test] + fn test_filename_validation() { + // Valid filenames + assert!(Filename::new("test.bin").is_ok()); + assert!(Filename::new("my-file_123.txt").is_ok()); + assert!(Filename::new("file with spaces.pdf").is_ok()); + + // Invalid filenames (path separators) + assert!(Filename::new("path/to/file.txt").is_err()); + assert!(Filename::new("path\\to\\file.txt").is_err()); + assert!(Filename::new("file\0name.txt").is_err()); + assert!(Filename::new("").is_err()); + } + + #[test] + fn test_filename_unchecked() { + // new_unchecked bypasses validation (useful for trusted sources) + let f = Filename::new_unchecked("any/path/works"); + assert_eq!(f.as_str(), "any/path/works"); + } + + #[test] + fn test_content_type_json() { + assert_eq!(ContentType::json().as_str(), "application/json"); + } + + #[test] + fn test_content_type_new_unchecked() { + assert_eq!( + ContentType::new_unchecked("text/html").as_str(), + "text/html" + ); + } + + #[test] + fn test_content_type_application() { + assert_eq!(ContentType::application("pdf").as_str(), "application/pdf"); + assert_eq!(ContentType::application("xml").as_str(), "application/xml"); + } + + #[test] + fn test_content_type_image() { + assert_eq!(ContentType::image("png").as_str(), "image/png"); + assert_eq!(ContentType::image("jpeg").as_str(), "image/jpeg"); + } + + #[test] + fn test_content_type_octet_stream() { + assert_eq!( + ContentType::octet_stream().as_str(), + "application/octet-stream" + ); + } + + #[test] + fn test_filename_display() { + let f = Filename::new("test.bin").unwrap(); + assert_eq!(format!("{}", f), "test.bin"); + } + + #[test] + fn test_content_type_display() { + let ct = ContentType::json(); + assert_eq!(format!("{}", ct), "application/json"); + } + + #[test] + fn test_filename_as_ref() { + let f = Filename::new("test.bin").unwrap(); + let s: &str = f.as_ref(); + assert_eq!(s, "test.bin"); + } + + #[test] + fn test_content_type_as_ref() { + let ct = ContentType::json(); + let s: &str = ct.as_ref(); + assert_eq!(s, "application/json"); + } + + // Builder pattern tests + #[test] + fn test_binary_builder_minimal() { + let part = FormPart::binary_builder(vec![1u8, 2, 3]).build(); + assert_matches!( + part, + BinaryFormPart { + filename: None, + content_type: None, + .. + } + ); + } + + #[test] + fn test_binary_builder_with_filename() { + let part = FormPart::binary_builder(vec![1u8, 2, 3]) + .filename(Filename::new("test.bin").unwrap()) + .build(); + + assert_matches!(part, BinaryFormPart { filename: Some(f), content_type: None, .. } => { + assert_eq!(f.as_str(), "test.bin"); + }); + } + + #[test] + fn test_binary_builder_with_content_type() { + let part = FormPart::binary_builder(vec![1u8, 2, 3]) + .content_type(ContentType::image("png")) + .build(); + + assert_matches!(part, BinaryFormPart { filename: None, content_type: Some(ct), .. } => { + assert_eq!(ct.as_str(), "image/png"); + }); + } + + #[test] + fn test_binary_builder_full() { + let part = FormPart::binary_builder(vec![1u8, 2, 3]) + .filename(Filename::new("document.pdf").unwrap()) + .content_type(ContentType::application("pdf")) + .build(); + + assert_matches!(part, BinaryFormPart { data, filename: Some(f), content_type: Some(ct) } => { + assert_eq!(data.as_ref(), &[1u8, 2, 3]); + assert_eq!(f.as_str(), "document.pdf"); + assert_eq!(ct.as_str(), "application/pdf"); + }); + } + + #[test] + fn test_text_builder_minimal() { + let part = FormPart::text_builder("hello").build(); + + assert_matches!(part, TextFormPart { value, content_type: None } => { + assert_eq!(value, "hello"); + }); + } + + #[test] + fn test_text_builder_with_content_type() { + let part = FormPart::text_builder("{}") + .content_type(ContentType::json()) + .build(); + + assert_matches!(part, TextFormPart { value, content_type: Some(ct) } => { + assert_eq!(value, "{}"); + assert_eq!(ct.as_str(), "application/json"); + }); + } + + #[test] + fn test_builder_into_form_part() { + // Test that builders can be converted directly to FormPart + let part: FormPart = FormPart::binary_builder(vec![1u8, 2, 3]) + .filename(Filename::new("test.bin").unwrap()) + .into_form_part(); + assert_matches!( + part, + FormPart::Binary(BinaryFormPart { + filename: Some(_), + .. + }) + ); + + let part: FormPart = FormPart::text_builder("hello") + .content_type(ContentType::json()) + .into_form_part(); + assert_matches!( + part, + FormPart::Text(TextFormPart { + content_type: Some(_), + .. + }) + ); + } + + #[test] + fn test_inner_types_into_form_part() { + // Test that BinaryFormPart and TextFormPart can be converted to FormPart via From + let binary = BinaryFormPart::new(vec![1u8, 2, 3]); + let part: FormPart = binary.into(); + assert_matches!(part, FormPart::Binary(_)); + + let text = TextFormPart::new("hello"); + let part: FormPart = text.into(); + assert_matches!(part, FormPart::Text(_)); + } +} diff --git a/progenitor-impl/src/cli.rs b/progenitor-impl/src/cli.rs index e1bdf30c..8d3a349e 100644 --- a/progenitor-impl/src/cli.rs +++ b/progenitor-impl/src/cli.rs @@ -441,7 +441,8 @@ impl Generator { // are currently... OperationParameterType::RawBody => None, - OperationParameterType::Type(body_type_id) => Some(body_type_id), + OperationParameterType::Type(body_type_id) + | OperationParameterType::Form(body_type_id) => Some(body_type_id), }); if let Some(body_type_id) = maybe_body_type_id { diff --git a/progenitor-impl/src/httpmock.rs b/progenitor-impl/src/httpmock.rs index ed56fbf6..5c1bbf12 100644 --- a/progenitor-impl/src/httpmock.rs +++ b/progenitor-impl/src/httpmock.rs @@ -156,7 +156,8 @@ impl Generator { description: _, }| { let arg_type_name = match typ { - OperationParameterType::Type(arg_type_id) => self + OperationParameterType::Type(arg_type_id) + | OperationParameterType::Form(arg_type_id) => self .type_space .get_type(arg_type_id) .unwrap() @@ -226,7 +227,7 @@ impl Generator { }, ), OperationParameterKind::Body(body_content_type) => match typ { - OperationParameterType::Type(_) => ( + OperationParameterType::Type(_) | OperationParameterType::Form(_) => ( true, quote! { Self(self.0.json_body_obj(value)) diff --git a/progenitor-impl/src/lib.rs b/progenitor-impl/src/lib.rs index 68454ece..b48bbbde 100644 --- a/progenitor-impl/src/lib.rs +++ b/progenitor-impl/src/lib.rs @@ -6,12 +6,13 @@ use std::collections::{BTreeMap, HashMap, HashSet}; +use indexmap::IndexMap; use openapiv3::OpenAPI; use proc_macro2::TokenStream; use quote::quote; use serde::Deserialize; use thiserror::Error; -use typify::{TypeSpace, TypeSpaceSettings}; +use typify::{TypeDetails, TypeId, TypeSpace, TypeSpaceSettings}; use crate::to_schema::ToSchema; @@ -50,11 +51,33 @@ pub type Result = std::result::Result; /// OpenAPI generator. pub struct Generator { type_space: TypeSpace, + /// Maps form type IDs to their field metadata (field name -> is_binary) + forms: IndexMap, settings: GenerationSettings, uses_futures: bool, uses_websockets: bool, } +/// Information about form fields for multipart/form-data generation +#[derive(Debug, Clone, Default)] +pub(crate) struct FormFieldsInfo { + /// Maps field names to their metadata + pub fields: IndexMap, +} + +/// Metadata about a single form field +#[derive(Debug, Clone, Default)] +pub(crate) struct FormFieldMeta { + /// The original API name for the field (used in form field name) + pub api_name: String, + /// Whether the field is binary (format: binary) + pub is_binary: bool, + /// Whether the field is a complex type (array or object) requiring JSON serialization + pub needs_json: bool, + /// Content-type override from the encoding object + pub content_type: Option, +} + /// Settings for [Generator]. #[derive(Default, Clone)] pub struct GenerationSettings { @@ -261,6 +284,7 @@ impl Default for Generator { fn default() -> Self { Self { type_space: TypeSpace::new(TypeSpaceSettings::default().with_type_mod("types")), + forms: Default::default(), settings: Default::default(), uses_futures: Default::default(), uses_websockets: Default::default(), @@ -312,6 +336,7 @@ impl Generator { Self { type_space: TypeSpace::new(&type_settings), + forms: Default::default(), settings: settings.clone(), uses_futures: false, uses_websockets: false, @@ -374,6 +399,90 @@ impl Generator { let types = self.type_space.to_stream(); + // Generate as_form() implementations for form data types. + // Each form type gets a method that returns an iterator of (name, FormPart) pairs. + let extra_impl = TokenStream::from_iter(self.forms.iter().map(|(type_id, field_info)| { + let typ = self.get_type_space().get_type(type_id).unwrap(); + let td = typ.details(); + let TypeDetails::Struct(tstru) = td else { + unreachable!() + }; + + // Generate field accessors that convert each field to the appropriate FormPart + let field_conversions = tstru.properties().filter_map(|(prop_name, _prop_id)| { + let meta = field_info.fields.get(prop_name)?; + let ident = quote::format_ident!("{}", prop_name); + // Use the original API name for the form field name + let api_name = &meta.api_name; + let content_type = meta + .content_type + .as_ref() + .map(|ct| { + quote! { Some(progenitor_client::ContentType::new_unchecked(#ct)) } + }) + .unwrap_or_else(|| quote! { None }); + + if meta.is_binary { + // Binary fields: Option -> FormPart::Binary + // Use .into() to support custom binary types via with_conversion + Some(quote! { + if let Some(ref val) = self.#ident { + parts.push((#api_name, progenitor_client::FormPart::Binary( + progenitor_client::BinaryFormPart { + data: val.clone().into(), + filename: None, + content_type: #content_type, + } + ))); + } + }) + } else if meta.needs_json { + // Complex types (array/object): JSON serialize + let json_content_type = if meta.content_type.is_some() { + content_type.clone() + } else { + quote! { Some(progenitor_client::ContentType::json()) } + }; + Some(quote! { + if let Some(ref val) = self.#ident { + parts.push((#api_name, progenitor_client::FormPart::Text( + progenitor_client::TextFormPart { + value: serde_json::to_string(val).unwrap_or_default(), + content_type: #json_content_type, + } + ))); + } + }) + } else { + // Simple text fields: serialize to string + Some(quote! { + if let Some(ref val) = self.#ident { + parts.push((#api_name, progenitor_client::FormPart::Text( + progenitor_client::TextFormPart { + value: val.to_string(), + content_type: #content_type, + } + ))); + } + }) + } + }); + + let form_name = quote::format_ident!("{}", typ.name()); + + quote! { + impl #form_name { + /// Convert this form into an iterator of (field_name, field_value) pairs + /// suitable for multipart/form-data encoding. + pub fn as_form(&self) -> Vec<(&'static str, progenitor_client::FormPart)> { + let mut parts = Vec::new(); + #(#field_conversions)* + parts + } + } + } + })); + let (inner_type, inner_fn_value) = match self.settings.inner_type.as_ref() { Some(inner_type) => (inner_type.clone(), quote! { &self.inner }), None => (quote! { () }, quote! { &() }), @@ -440,6 +549,8 @@ impl Generator { #[allow(clippy::all)] pub mod types { #types + + #extra_impl } #[derive(Clone, Debug)] diff --git a/progenitor-impl/src/method.rs b/progenitor-impl/src/method.rs index 8bf1d0fe..a263f44e 100644 --- a/progenitor-impl/src/method.rs +++ b/progenitor-impl/src/method.rs @@ -6,6 +6,8 @@ use std::{ str::FromStr, }; +use indexmap::IndexMap; + use openapiv3::{Components, Parameter, ReferenceOr, Response, StatusCode}; use proc_macro2::TokenStream; use quote::{format_ident, quote, ToTokens}; @@ -102,9 +104,10 @@ pub struct OperationParameter { pub kind: OperationParameterKind, } -#[derive(Eq, PartialEq)] +#[derive(Debug, Eq, PartialEq)] pub enum OperationParameterType { Type(TypeId), + Form(TypeId), RawBody, } @@ -137,6 +140,7 @@ pub enum BodyContentType { OctetStream, Json, FormUrlencoded, + FormData, Text(String), } @@ -149,6 +153,7 @@ impl FromStr for BodyContentType { "application/octet-stream" => Ok(Self::OctetStream), "application/json" => Ok(Self::Json), "application/x-www-form-urlencoded" => Ok(Self::FormUrlencoded), + "form-data" | "multipart/form-data" => Ok(Self::FormData), "text/plain" | "text/x-markdown" => Ok(Self::Text(String::from(&s[..offset]))), _ => Err(Error::UnexpectedFormat(format!( "unexpected content type: {}", @@ -164,6 +169,7 @@ impl std::fmt::Display for BodyContentType { Self::OctetStream => "application/octet-stream", Self::Json => "application/json", Self::FormUrlencoded => "application/x-www-form-urlencoded", + Self::FormData => "multipart/form-data", Self::Text(typ) => typ, }) } @@ -562,12 +568,20 @@ impl Generator { .map(|param| { let name = format_ident!("{}", param.name); let typ = match (¶m.typ, param.kind.is_optional()) { - (OperationParameterType::Type(type_id), false) => self + ( + OperationParameterType::Type(type_id) + | OperationParameterType::Form(type_id), + false, + ) => self .type_space .get_type(type_id) .unwrap() .parameter_ident_with_lifetime("a"), - (OperationParameterType::Type(type_id), true) => { + ( + OperationParameterType::Type(type_id) + | OperationParameterType::Form(type_id), + true, + ) => { let t = self .type_space .get_type(type_id) @@ -918,6 +932,16 @@ impl Generator { // returns an error in the case of a serialization failure. .form_urlencoded(&body)? }), + ( + OperationParameterKind::Body(BodyContentType::FormData), + OperationParameterType::Form(_), + ) => { + Some(quote! { + // This uses `progenitor_client::RequestBuilderExt` which + // builds a multipart form from the form parts + .form_from_parts(body.as_form())? + }) + } (OperationParameterKind::Body(_), _) => { unreachable!("invalid body kind/type combination") } @@ -1436,7 +1460,7 @@ impl Generator { .params .iter() .map(|param| match ¶m.typ { - OperationParameterType::Type(type_id) => { + OperationParameterType::Type(type_id) | OperationParameterType::Form(type_id) => { let ty = self.type_space.get_type(type_id)?; // For body parameters only, if there's a builder we'll @@ -1469,7 +1493,7 @@ impl Generator { .params .iter() .map(|param| match ¶m.typ { - OperationParameterType::Type(type_id) => { + OperationParameterType::Type(type_id) | OperationParameterType::Form(type_id) => { let ty = self.type_space.get_type(type_id)?; // Fill in the appropriate initial value for the @@ -1499,7 +1523,7 @@ impl Generator { .params .iter() .map(|param| match ¶m.typ { - OperationParameterType::Type(type_id) => { + OperationParameterType::Type(type_id) | OperationParameterType::Form(type_id) => { let ty = self.type_space.get_type(type_id)?; if ty.builder().is_some() { let type_name = ty.ident(); @@ -1523,7 +1547,8 @@ impl Generator { .map(|param| { let param_name = format_ident!("{}", param.name); match ¶m.typ { - OperationParameterType::Type(type_id) => { + OperationParameterType::Type(type_id) + | OperationParameterType::Form(type_id) => { let ty = self.type_space.get_type(type_id)?; match (ty.builder(), param.kind.is_optional()) { // TODO right now optional body parameters are not @@ -2131,6 +2156,95 @@ impl Generator { }?; OperationParameterType::RawBody } + BodyContentType::FormData => { + // For multipart/form-data, we accept an object schema with various property types: + // - type: string, format: binary -> Binary file data (bytes::Bytes) + // - type: string -> Text field + // - type: integer/number/boolean -> Text field (serialized) + // - type: object -> JSON-serialized text field + // - type: array -> JSON-serialized text field + // + // We track field metadata to generate appropriate as_form() implementations. + // The encoding object can override content-types for individual fields. + + let field_info = match schema.item(components)? { + openapiv3::Schema { + schema_kind: + openapiv3::SchemaKind::Type(openapiv3::Type::Object( + openapiv3::ObjectType { properties, .. }, + )), + .. + } => { + let mut fields = IndexMap::new(); + for (name, property) in properties { + let (is_binary, needs_json) = match property { + ReferenceOr::Item(prop) => { + let is_binary = matches!( + &prop.schema_kind, + openapiv3::SchemaKind::Type(openapiv3::Type::String( + openapiv3::StringType { + format: openapiv3::VariantOrUnknownOrEmpty::Item( + openapiv3::StringFormat::Binary, + ), + .. + }, + )) + ); + let needs_json = matches!( + &prop.schema_kind, + openapiv3::SchemaKind::Type(openapiv3::Type::Array(_)) + | openapiv3::SchemaKind::Type(openapiv3::Type::Object( + _ + )) + ); + (is_binary, needs_json) + } + ReferenceOr::Reference { .. } => { + // References are assumed to be complex types requiring JSON + (false, true) + } + }; + + // Check for encoding overrides (content-type per field) + let content_type = media_type + .encoding + .get(name) + .and_then(|enc| enc.content_type.clone()); + + // Use the Rust-ized (snake_case) name as the key since + // typify's properties() returns Rust names, but store the + // original API name for use in the form field name. + let rust_name = sanitize(name, Case::Snake); + fields.insert( + rust_name, + crate::FormFieldMeta { + api_name: name.clone(), + is_binary, + needs_json, + content_type, + }, + ); + } + crate::FormFieldsInfo { fields } + } + _ => { + return Err(Error::UnexpectedFormat(format!( + "multipart/form-data requires an object schema, got: {:?}", + schema + ))); + } + }; + + let form_name = sanitize( + &format!("{}-form", operation.operation_id.as_ref().unwrap(),), + Case::Pascal, + ); + let type_id = self + .type_space + .add_type_with_name(&schema.to_schema(), Some(form_name))?; + self.forms.insert(type_id.clone(), field_info); + OperationParameterType::Form(type_id) + } BodyContentType::Json | BodyContentType::FormUrlencoded => { // TODO it would be legal to have the encoding field set for // application/x-www-form-urlencoded content, but I'm not sure diff --git a/progenitor-impl/tests/output/src/example_multipart_builder.rs b/progenitor-impl/tests/output/src/example_multipart_builder.rs new file mode 100644 index 00000000..5e2d0ea7 --- /dev/null +++ b/progenitor-impl/tests/output/src/example_multipart_builder.rs @@ -0,0 +1,331 @@ +#[allow(unused_imports)] +use progenitor_client::{encode_path, ClientHooks, OperationInfo, RequestBuilderExt}; +#[allow(unused_imports)] +pub use progenitor_client::{ByteStream, ClientInfo, Error, ResponseValue}; +/// Types used as operation parameters and responses. +#[allow(clippy::all)] +pub mod types { + /// Error types. + pub mod error { + /// Error from a `TryFrom` or `FromStr` implementation. + pub struct ConversionError(::std::borrow::Cow<'static, str>); + impl ::std::error::Error for ConversionError {} + impl ::std::fmt::Display for ConversionError { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> Result<(), ::std::fmt::Error> { + ::std::fmt::Display::fmt(&self.0, f) + } + } + + impl ::std::fmt::Debug for ConversionError { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> Result<(), ::std::fmt::Error> { + ::std::fmt::Debug::fmt(&self.0, f) + } + } + + impl From<&'static str> for ConversionError { + fn from(value: &'static str) -> Self { + Self(value.into()) + } + } + + impl From for ConversionError { + fn from(value: String) -> Self { + Self(value.into()) + } + } + } + + ///`UploadForm` + /// + ///
JSON schema + /// + /// ```json + ///{ + /// "type": "object", + /// "required": [ + /// "file", + /// "name" + /// ], + /// "properties": { + /// "file": { + /// "type": "string", + /// "format": "binary" + /// }, + /// "name": { + /// "type": "string" + /// } + /// } + ///} + /// ``` + ///
+ #[derive( + :: serde :: Deserialize, :: serde :: Serialize, Clone, Debug, schemars :: JsonSchema, + )] + pub struct UploadForm { + pub file: ::std::string::String, + pub name: ::std::string::String, + } + + impl ::std::convert::From<&UploadForm> for UploadForm { + fn from(value: &UploadForm) -> Self { + value.clone() + } + } + + impl UploadForm { + pub fn builder() -> builder::UploadForm { + Default::default() + } + } + + /// Types for composing complex structures. + pub mod builder { + #[derive(Clone, Debug)] + pub struct UploadForm { + file: ::std::result::Result<::std::string::String, ::std::string::String>, + name: ::std::result::Result<::std::string::String, ::std::string::String>, + } + + impl ::std::default::Default for UploadForm { + fn default() -> Self { + Self { + file: Err("no value supplied for file".to_string()), + name: Err("no value supplied for name".to_string()), + } + } + } + + impl UploadForm { + pub fn file(mut self, value: T) -> Self + where + T: ::std::convert::TryInto<::std::string::String>, + T::Error: ::std::fmt::Display, + { + self.file = value + .try_into() + .map_err(|e| format!("error converting supplied value for file: {}", e)); + self + } + pub fn name(mut self, value: T) -> Self + where + T: ::std::convert::TryInto<::std::string::String>, + T::Error: ::std::fmt::Display, + { + self.name = value + .try_into() + .map_err(|e| format!("error converting supplied value for name: {}", e)); + self + } + } + + impl ::std::convert::TryFrom for super::UploadForm { + type Error = super::error::ConversionError; + fn try_from( + value: UploadForm, + ) -> ::std::result::Result { + Ok(Self { + file: value.file?, + name: value.name?, + }) + } + } + + impl ::std::convert::From for UploadForm { + fn from(value: super::UploadForm) -> Self { + Self { + file: Ok(value.file), + name: Ok(value.name), + } + } + } + } + + impl UploadForm { + /// Convert this form into an iterator of (field_name, field_value) + /// pairs + /// suitable for multipart/form-data encoding. + pub fn as_form(&self) -> Vec<(&'static str, progenitor_client::FormPart)> { + let mut parts = Vec::new(); + if let Some(ref val) = self.file { + parts.push(( + "file", + progenitor_client::FormPart::Binary(progenitor_client::BinaryFormPart { + data: val.clone().into(), + filename: None, + content_type: None, + }), + )); + } + if let Some(ref val) = self.name { + parts.push(( + "name", + progenitor_client::FormPart::Text(progenitor_client::TextFormPart { + value: val.to_string(), + content_type: None, + }), + )); + } + parts + } + } +} + +#[derive(Clone, Debug)] +///Client for Multipart Example +/// +///Version: 1.0.0 +pub struct Client { + pub(crate) baseurl: String, + pub(crate) client: reqwest::Client, +} + +impl Client { + /// Create a new client. + /// + /// `baseurl` is the base URL provided to the internal + /// `reqwest::Client`, and should include a scheme and hostname, + /// as well as port and a path stem if applicable. + pub fn new(baseurl: &str) -> Self { + #[cfg(not(target_arch = "wasm32"))] + let client = { + let dur = ::std::time::Duration::from_secs(15u64); + reqwest::ClientBuilder::new() + .connect_timeout(dur) + .timeout(dur) + }; + #[cfg(target_arch = "wasm32")] + let client = reqwest::ClientBuilder::new(); + Self::new_with_client(baseurl, client.build().unwrap()) + } + + /// Construct a new client with an existing `reqwest::Client`, + /// allowing more control over its configuration. + /// + /// `baseurl` is the base URL provided to the internal + /// `reqwest::Client`, and should include a scheme and hostname, + /// as well as port and a path stem if applicable. + pub fn new_with_client(baseurl: &str, client: reqwest::Client) -> Self { + Self { + baseurl: baseurl.to_string(), + client, + } + } +} + +impl ClientInfo<()> for Client { + fn api_version() -> &'static str { + "1.0.0" + } + + fn baseurl(&self) -> &str { + self.baseurl.as_str() + } + + fn client(&self) -> &reqwest::Client { + &self.client + } + + fn inner(&self) -> &() { + &() + } +} + +impl ClientHooks<()> for &Client {} +impl Client { + ///Sends a `POST` request to `/upload` + /// + ///```ignore + /// let response = client.upload() + /// .body(body) + /// .send() + /// .await; + /// ``` + pub fn upload(&self) -> builder::Upload<'_> { + builder::Upload::new(self) + } +} + +/// Types for composing operation parameters. +#[allow(clippy::all)] +pub mod builder { + use super::types; + #[allow(unused_imports)] + use super::{ + encode_path, ByteStream, ClientHooks, ClientInfo, Error, OperationInfo, RequestBuilderExt, + ResponseValue, + }; + ///Builder for [`Client::upload`] + /// + ///[`Client::upload`]: super::Client::upload + #[derive(Debug, Clone)] + pub struct Upload<'a> { + client: &'a super::Client, + body: Result, + } + + impl<'a> Upload<'a> { + pub fn new(client: &'a super::Client) -> Self { + Self { + client: client, + body: Ok(::std::default::Default::default()), + } + } + + pub fn body(mut self, value: V) -> Self + where + V: std::convert::TryInto, + >::Error: std::fmt::Display, + { + self.body = value + .try_into() + .map(From::from) + .map_err(|s| format!("conversion to `UploadForm` for body failed: {}", s)); + self + } + + pub fn body_map(mut self, f: F) -> Self + where + F: std::ops::FnOnce(types::builder::UploadForm) -> types::builder::UploadForm, + { + self.body = self.body.map(f); + self + } + + ///Sends a `POST` request to `/upload` + pub async fn send(self) -> Result, Error<()>> { + let Self { client, body } = self; + let body = body + .and_then(|v| types::UploadForm::try_from(v).map_err(|e| e.to_string())) + .map_err(Error::InvalidRequest)?; + let url = format!("{}/upload", client.baseurl,); + let mut header_map = ::reqwest::header::HeaderMap::with_capacity(1usize); + header_map.append( + ::reqwest::header::HeaderName::from_static("api-version"), + ::reqwest::header::HeaderValue::from_static(super::Client::api_version()), + ); + #[allow(unused_mut)] + let mut request = client + .client + .post(url) + .form_from_parts(body.as_form())? + .headers(header_map) + .build()?; + let info = OperationInfo { + operation_id: "upload", + }; + client.pre(&mut request, &info).await?; + let result = client.exec(request, &info).await; + client.post(&result, &info).await?; + let response = result?; + match response.status().as_u16() { + 200u16 => Ok(ResponseValue::empty(response)), + _ => Err(Error::UnexpectedResponse(response)), + } + } + } +} + +/// Items consumers will typically use such as the Client. +pub mod prelude { + pub use self::super::Client; +} diff --git a/progenitor-impl/tests/output/src/example_multipart_builder_tagged.rs b/progenitor-impl/tests/output/src/example_multipart_builder_tagged.rs new file mode 100644 index 00000000..e486c003 --- /dev/null +++ b/progenitor-impl/tests/output/src/example_multipart_builder_tagged.rs @@ -0,0 +1,331 @@ +#[allow(unused_imports)] +use progenitor_client::{encode_path, ClientHooks, OperationInfo, RequestBuilderExt}; +#[allow(unused_imports)] +pub use progenitor_client::{ByteStream, ClientInfo, Error, ResponseValue}; +/// Types used as operation parameters and responses. +#[allow(clippy::all)] +pub mod types { + /// Error types. + pub mod error { + /// Error from a `TryFrom` or `FromStr` implementation. + pub struct ConversionError(::std::borrow::Cow<'static, str>); + impl ::std::error::Error for ConversionError {} + impl ::std::fmt::Display for ConversionError { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> Result<(), ::std::fmt::Error> { + ::std::fmt::Display::fmt(&self.0, f) + } + } + + impl ::std::fmt::Debug for ConversionError { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> Result<(), ::std::fmt::Error> { + ::std::fmt::Debug::fmt(&self.0, f) + } + } + + impl From<&'static str> for ConversionError { + fn from(value: &'static str) -> Self { + Self(value.into()) + } + } + + impl From for ConversionError { + fn from(value: String) -> Self { + Self(value.into()) + } + } + } + + ///`UploadForm` + /// + ///
JSON schema + /// + /// ```json + ///{ + /// "type": "object", + /// "required": [ + /// "file", + /// "name" + /// ], + /// "properties": { + /// "file": { + /// "type": "string", + /// "format": "binary" + /// }, + /// "name": { + /// "type": "string" + /// } + /// } + ///} + /// ``` + ///
+ #[derive(:: serde :: Deserialize, :: serde :: Serialize, Clone, Debug)] + pub struct UploadForm { + pub file: ::std::string::String, + pub name: ::std::string::String, + } + + impl ::std::convert::From<&UploadForm> for UploadForm { + fn from(value: &UploadForm) -> Self { + value.clone() + } + } + + impl UploadForm { + pub fn builder() -> builder::UploadForm { + Default::default() + } + } + + /// Types for composing complex structures. + pub mod builder { + #[derive(Clone, Debug)] + pub struct UploadForm { + file: ::std::result::Result<::std::string::String, ::std::string::String>, + name: ::std::result::Result<::std::string::String, ::std::string::String>, + } + + impl ::std::default::Default for UploadForm { + fn default() -> Self { + Self { + file: Err("no value supplied for file".to_string()), + name: Err("no value supplied for name".to_string()), + } + } + } + + impl UploadForm { + pub fn file(mut self, value: T) -> Self + where + T: ::std::convert::TryInto<::std::string::String>, + T::Error: ::std::fmt::Display, + { + self.file = value + .try_into() + .map_err(|e| format!("error converting supplied value for file: {}", e)); + self + } + pub fn name(mut self, value: T) -> Self + where + T: ::std::convert::TryInto<::std::string::String>, + T::Error: ::std::fmt::Display, + { + self.name = value + .try_into() + .map_err(|e| format!("error converting supplied value for name: {}", e)); + self + } + } + + impl ::std::convert::TryFrom for super::UploadForm { + type Error = super::error::ConversionError; + fn try_from( + value: UploadForm, + ) -> ::std::result::Result { + Ok(Self { + file: value.file?, + name: value.name?, + }) + } + } + + impl ::std::convert::From for UploadForm { + fn from(value: super::UploadForm) -> Self { + Self { + file: Ok(value.file), + name: Ok(value.name), + } + } + } + } + + impl UploadForm { + /// Convert this form into an iterator of (field_name, field_value) + /// pairs + /// suitable for multipart/form-data encoding. + pub fn as_form(&self) -> Vec<(&'static str, progenitor_client::FormPart)> { + let mut parts = Vec::new(); + if let Some(ref val) = self.file { + parts.push(( + "file", + progenitor_client::FormPart::Binary(progenitor_client::BinaryFormPart { + data: val.clone().into(), + filename: None, + content_type: None, + }), + )); + } + if let Some(ref val) = self.name { + parts.push(( + "name", + progenitor_client::FormPart::Text(progenitor_client::TextFormPart { + value: val.to_string(), + content_type: None, + }), + )); + } + parts + } + } +} + +#[derive(Clone, Debug)] +///Client for Multipart Example +/// +///Version: 1.0.0 +pub struct Client { + pub(crate) baseurl: String, + pub(crate) client: reqwest::Client, +} + +impl Client { + /// Create a new client. + /// + /// `baseurl` is the base URL provided to the internal + /// `reqwest::Client`, and should include a scheme and hostname, + /// as well as port and a path stem if applicable. + pub fn new(baseurl: &str) -> Self { + #[cfg(not(target_arch = "wasm32"))] + let client = { + let dur = ::std::time::Duration::from_secs(15u64); + reqwest::ClientBuilder::new() + .connect_timeout(dur) + .timeout(dur) + }; + #[cfg(target_arch = "wasm32")] + let client = reqwest::ClientBuilder::new(); + Self::new_with_client(baseurl, client.build().unwrap()) + } + + /// Construct a new client with an existing `reqwest::Client`, + /// allowing more control over its configuration. + /// + /// `baseurl` is the base URL provided to the internal + /// `reqwest::Client`, and should include a scheme and hostname, + /// as well as port and a path stem if applicable. + pub fn new_with_client(baseurl: &str, client: reqwest::Client) -> Self { + Self { + baseurl: baseurl.to_string(), + client, + } + } +} + +impl ClientInfo<()> for Client { + fn api_version() -> &'static str { + "1.0.0" + } + + fn baseurl(&self) -> &str { + self.baseurl.as_str() + } + + fn client(&self) -> &reqwest::Client { + &self.client + } + + fn inner(&self) -> &() { + &() + } +} + +impl ClientHooks<()> for &Client {} +impl Client { + ///Sends a `POST` request to `/upload` + /// + ///```ignore + /// let response = client.upload() + /// .body(body) + /// .send() + /// .await; + /// ``` + pub fn upload(&self) -> builder::Upload<'_> { + builder::Upload::new(self) + } +} + +/// Types for composing operation parameters. +#[allow(clippy::all)] +pub mod builder { + use super::types; + #[allow(unused_imports)] + use super::{ + encode_path, ByteStream, ClientHooks, ClientInfo, Error, OperationInfo, RequestBuilderExt, + ResponseValue, + }; + ///Builder for [`Client::upload`] + /// + ///[`Client::upload`]: super::Client::upload + #[derive(Debug, Clone)] + pub struct Upload<'a> { + client: &'a super::Client, + body: Result, + } + + impl<'a> Upload<'a> { + pub fn new(client: &'a super::Client) -> Self { + Self { + client: client, + body: Ok(::std::default::Default::default()), + } + } + + pub fn body(mut self, value: V) -> Self + where + V: std::convert::TryInto, + >::Error: std::fmt::Display, + { + self.body = value + .try_into() + .map(From::from) + .map_err(|s| format!("conversion to `UploadForm` for body failed: {}", s)); + self + } + + pub fn body_map(mut self, f: F) -> Self + where + F: std::ops::FnOnce(types::builder::UploadForm) -> types::builder::UploadForm, + { + self.body = self.body.map(f); + self + } + + ///Sends a `POST` request to `/upload` + pub async fn send(self) -> Result, Error<()>> { + let Self { client, body } = self; + let body = body + .and_then(|v| types::UploadForm::try_from(v).map_err(|e| e.to_string())) + .map_err(Error::InvalidRequest)?; + let url = format!("{}/upload", client.baseurl,); + let mut header_map = ::reqwest::header::HeaderMap::with_capacity(1usize); + header_map.append( + ::reqwest::header::HeaderName::from_static("api-version"), + ::reqwest::header::HeaderValue::from_static(super::Client::api_version()), + ); + #[allow(unused_mut)] + let mut request = client + .client + .post(url) + .form_from_parts(body.as_form())? + .headers(header_map) + .build()?; + let info = OperationInfo { + operation_id: "upload", + }; + client.pre(&mut request, &info).await?; + let result = client.exec(request, &info).await; + client.post(&result, &info).await?; + let response = result?; + match response.status().as_u16() { + 200u16 => Ok(ResponseValue::empty(response)), + _ => Err(Error::UnexpectedResponse(response)), + } + } + } +} + +/// Items consumers will typically use such as the Client and +/// extension traits. +pub mod prelude { + #[allow(unused_imports)] + pub use super::Client; +} diff --git a/progenitor-impl/tests/output/src/example_multipart_cli.rs b/progenitor-impl/tests/output/src/example_multipart_cli.rs new file mode 100644 index 00000000..2a883c46 --- /dev/null +++ b/progenitor-impl/tests/output/src/example_multipart_cli.rs @@ -0,0 +1,130 @@ +use crate::example_multipart_builder::*; +use anyhow::Context as _; +pub struct Cli { + client: Client, + config: T, +} + +impl Cli { + pub fn new(client: Client, config: T) -> Self { + Self { client, config } + } + + pub fn get_command(cmd: CliCommand) -> ::clap::Command { + match cmd { + CliCommand::Upload => Self::cli_upload(), + } + } + + pub fn cli_upload() -> ::clap::Command { + ::clap::Command::new("") + .arg( + ::clap::Arg::new("file") + .long("file") + .value_parser(::clap::value_parser!(::std::string::String)) + .required_unless_present("json-body"), + ) + .arg( + ::clap::Arg::new("name") + .long("name") + .value_parser(::clap::value_parser!(::std::string::String)) + .required_unless_present("json-body"), + ) + .arg( + ::clap::Arg::new("json-body") + .long("json-body") + .value_name("JSON-FILE") + .required(false) + .value_parser(::clap::value_parser!(std::path::PathBuf)) + .help("Path to a file that contains the full json body."), + ) + .arg( + ::clap::Arg::new("json-body-template") + .long("json-body-template") + .action(::clap::ArgAction::SetTrue) + .help("XXX"), + ) + } + + pub async fn execute( + &self, + cmd: CliCommand, + matches: &::clap::ArgMatches, + ) -> anyhow::Result<()> { + match cmd { + CliCommand::Upload => self.execute_upload(matches).await, + } + } + + pub async fn execute_upload(&self, matches: &::clap::ArgMatches) -> anyhow::Result<()> { + let mut request = self.client.upload(); + if let Some(value) = matches.get_one::<::std::string::String>("file") { + request = request.body_map(|body| body.file(value.clone())) + } + + if let Some(value) = matches.get_one::<::std::string::String>("name") { + request = request.body_map(|body| body.name(value.clone())) + } + + if let Some(value) = matches.get_one::("json-body") { + let body_txt = std::fs::read_to_string(value) + .with_context(|| format!("failed to read {}", value.display()))?; + let body_value = serde_json::from_str::(&body_txt) + .with_context(|| format!("failed to parse {}", value.display()))?; + request = request.body(body_value); + } + + self.config.execute_upload(matches, &mut request)?; + let result = request.send().await; + match result { + Ok(r) => { + self.config.success_no_item(&r); + Ok(()) + } + Err(r) => { + self.config.error(&r); + Err(anyhow::Error::new(r)) + } + } + } +} + +pub trait CliConfig { + fn success_item(&self, value: &ResponseValue) + where + T: std::clone::Clone + schemars::JsonSchema + serde::Serialize + std::fmt::Debug; + fn success_no_item(&self, value: &ResponseValue<()>); + fn error(&self, value: &Error) + where + T: std::clone::Clone + schemars::JsonSchema + serde::Serialize + std::fmt::Debug; + fn list_start(&self) + where + T: std::clone::Clone + schemars::JsonSchema + serde::Serialize + std::fmt::Debug; + fn list_item(&self, value: &T) + where + T: std::clone::Clone + schemars::JsonSchema + serde::Serialize + std::fmt::Debug; + fn list_end_success(&self) + where + T: std::clone::Clone + schemars::JsonSchema + serde::Serialize + std::fmt::Debug; + fn list_end_error(&self, value: &Error) + where + T: std::clone::Clone + schemars::JsonSchema + serde::Serialize + std::fmt::Debug; + fn execute_upload( + &self, + matches: &::clap::ArgMatches, + request: &mut builder::Upload, + ) -> anyhow::Result<()> { + Ok(()) + } +} + +#[derive(Copy, Clone, Debug)] +pub enum CliCommand { + Upload, +} + +impl CliCommand { + pub fn iter() -> impl Iterator { + vec![CliCommand::Upload].into_iter() + } +} diff --git a/progenitor-impl/tests/output/src/example_multipart_httpmock.rs b/progenitor-impl/tests/output/src/example_multipart_httpmock.rs new file mode 100644 index 00000000..4446f766 --- /dev/null +++ b/progenitor-impl/tests/output/src/example_multipart_httpmock.rs @@ -0,0 +1,63 @@ +pub mod operations { + #![doc = r" [`When`](::httpmock::When) and [`Then`](::httpmock::Then)"] + #![doc = r" wrappers for each operation. Each can be converted to"] + #![doc = r" its inner type with a call to `into_inner()`. This can"] + #![doc = r" be used to explicitly deviate from permitted values."] + use crate::example_multipart_builder::*; + pub struct UploadWhen(::httpmock::When); + impl UploadWhen { + pub fn new(inner: ::httpmock::When) -> Self { + Self( + inner + .method(::httpmock::Method::POST) + .path_matches(regex::Regex::new("^/upload$").unwrap()), + ) + } + + pub fn into_inner(self) -> ::httpmock::When { + self.0 + } + + pub fn body(self, value: &types::UploadForm) -> Self { + Self(self.0.json_body_obj(value)) + } + } + + pub struct UploadThen(::httpmock::Then); + impl UploadThen { + pub fn new(inner: ::httpmock::Then) -> Self { + Self(inner) + } + + pub fn into_inner(self) -> ::httpmock::Then { + self.0 + } + + pub fn ok(self) -> Self { + Self(self.0.status(200u16)) + } + } +} + +#[doc = r" An extension trait for [`MockServer`](::httpmock::MockServer) that"] +#[doc = r" adds a method for each operation. These are the equivalent of"] +#[doc = r" type-checked [`mock()`](::httpmock::MockServer::mock) calls."] +pub trait MockServerExt { + fn upload(&self, config_fn: F) -> ::httpmock::Mock<'_> + where + F: FnOnce(operations::UploadWhen, operations::UploadThen); +} + +impl MockServerExt for ::httpmock::MockServer { + fn upload(&self, config_fn: F) -> ::httpmock::Mock<'_> + where + F: FnOnce(operations::UploadWhen, operations::UploadThen), + { + self.mock(|when, then| { + config_fn( + operations::UploadWhen::new(when), + operations::UploadThen::new(then), + ) + }) + } +} diff --git a/progenitor-impl/tests/output/src/example_multipart_positional.rs b/progenitor-impl/tests/output/src/example_multipart_positional.rs new file mode 100644 index 00000000..e290da4a --- /dev/null +++ b/progenitor-impl/tests/output/src/example_multipart_positional.rs @@ -0,0 +1,202 @@ +#[allow(unused_imports)] +use progenitor_client::{encode_path, ClientHooks, OperationInfo, RequestBuilderExt}; +#[allow(unused_imports)] +pub use progenitor_client::{ByteStream, ClientInfo, Error, ResponseValue}; +/// Types used as operation parameters and responses. +#[allow(clippy::all)] +pub mod types { + /// Error types. + pub mod error { + /// Error from a `TryFrom` or `FromStr` implementation. + pub struct ConversionError(::std::borrow::Cow<'static, str>); + impl ::std::error::Error for ConversionError {} + impl ::std::fmt::Display for ConversionError { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> Result<(), ::std::fmt::Error> { + ::std::fmt::Display::fmt(&self.0, f) + } + } + + impl ::std::fmt::Debug for ConversionError { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> Result<(), ::std::fmt::Error> { + ::std::fmt::Debug::fmt(&self.0, f) + } + } + + impl From<&'static str> for ConversionError { + fn from(value: &'static str) -> Self { + Self(value.into()) + } + } + + impl From for ConversionError { + fn from(value: String) -> Self { + Self(value.into()) + } + } + } + + ///`UploadForm` + /// + ///
JSON schema + /// + /// ```json + ///{ + /// "type": "object", + /// "required": [ + /// "file", + /// "name" + /// ], + /// "properties": { + /// "file": { + /// "type": "string", + /// "format": "binary" + /// }, + /// "name": { + /// "type": "string" + /// } + /// } + ///} + /// ``` + ///
+ #[derive(:: serde :: Deserialize, :: serde :: Serialize, Clone, Debug)] + pub struct UploadForm { + pub file: ::std::string::String, + pub name: ::std::string::String, + } + + impl ::std::convert::From<&UploadForm> for UploadForm { + fn from(value: &UploadForm) -> Self { + value.clone() + } + } + + impl UploadForm { + /// Convert this form into an iterator of (field_name, field_value) + /// pairs + /// suitable for multipart/form-data encoding. + pub fn as_form(&self) -> Vec<(&'static str, progenitor_client::FormPart)> { + let mut parts = Vec::new(); + if let Some(ref val) = self.file { + parts.push(( + "file", + progenitor_client::FormPart::Binary(progenitor_client::BinaryFormPart { + data: val.clone().into(), + filename: None, + content_type: None, + }), + )); + } + if let Some(ref val) = self.name { + parts.push(( + "name", + progenitor_client::FormPart::Text(progenitor_client::TextFormPart { + value: val.to_string(), + content_type: None, + }), + )); + } + parts + } + } +} + +#[derive(Clone, Debug)] +///Client for Multipart Example +/// +///Version: 1.0.0 +pub struct Client { + pub(crate) baseurl: String, + pub(crate) client: reqwest::Client, +} + +impl Client { + /// Create a new client. + /// + /// `baseurl` is the base URL provided to the internal + /// `reqwest::Client`, and should include a scheme and hostname, + /// as well as port and a path stem if applicable. + pub fn new(baseurl: &str) -> Self { + #[cfg(not(target_arch = "wasm32"))] + let client = { + let dur = ::std::time::Duration::from_secs(15u64); + reqwest::ClientBuilder::new() + .connect_timeout(dur) + .timeout(dur) + }; + #[cfg(target_arch = "wasm32")] + let client = reqwest::ClientBuilder::new(); + Self::new_with_client(baseurl, client.build().unwrap()) + } + + /// Construct a new client with an existing `reqwest::Client`, + /// allowing more control over its configuration. + /// + /// `baseurl` is the base URL provided to the internal + /// `reqwest::Client`, and should include a scheme and hostname, + /// as well as port and a path stem if applicable. + pub fn new_with_client(baseurl: &str, client: reqwest::Client) -> Self { + Self { + baseurl: baseurl.to_string(), + client, + } + } +} + +impl ClientInfo<()> for Client { + fn api_version() -> &'static str { + "1.0.0" + } + + fn baseurl(&self) -> &str { + self.baseurl.as_str() + } + + fn client(&self) -> &reqwest::Client { + &self.client + } + + fn inner(&self) -> &() { + &() + } +} + +impl ClientHooks<()> for &Client {} +#[allow(clippy::all)] +impl Client { + ///Sends a `POST` request to `/upload` + pub async fn upload<'a>( + &'a self, + body: &'a types::UploadForm, + ) -> Result, Error<()>> { + let url = format!("{}/upload", self.baseurl,); + let mut header_map = ::reqwest::header::HeaderMap::with_capacity(1usize); + header_map.append( + ::reqwest::header::HeaderName::from_static("api-version"), + ::reqwest::header::HeaderValue::from_static(Self::api_version()), + ); + #[allow(unused_mut)] + let mut request = self + .client + .post(url) + .form_from_parts(body.as_form())? + .headers(header_map) + .build()?; + let info = OperationInfo { + operation_id: "upload", + }; + self.pre(&mut request, &info).await?; + let result = self.exec(request, &info).await; + self.post(&result, &info).await?; + let response = result?; + match response.status().as_u16() { + 200u16 => Ok(ResponseValue::empty(response)), + _ => Err(Error::UnexpectedResponse(response)), + } + } +} + +/// Items consumers will typically use such as the Client. +pub mod prelude { + #[allow(unused_imports)] + pub use super::Client; +} diff --git a/progenitor-impl/tests/test_output.rs b/progenitor-impl/tests/test_output.rs index 010315d9..5bbce076 100644 --- a/progenitor-impl/tests/test_output.rs +++ b/progenitor-impl/tests/test_output.rs @@ -163,6 +163,11 @@ fn test_cli_gen() { verify_apis("cli-gen.json"); } +#[test] +fn test_example_multipart() { + verify_apis("example_multipart.json"); +} + #[test] fn test_nexus_with_different_timeout() { const OPENAPI_FILE: &'static str = "nexus.json"; diff --git a/progenitor-impl/tests/test_specific.rs b/progenitor-impl/tests/test_specific.rs index ba29aa6c..681a34b9 100644 --- a/progenitor-impl/tests/test_specific.rs +++ b/progenitor-impl/tests/test_specific.rs @@ -355,3 +355,497 @@ async fn test_stream_pagination() { server.close().await.expect("failed to close server"); } + +mod multipart_tests { + use openapiv3::OpenAPI; + use progenitor_impl::{GenerationSettings, Generator, InterfaceStyle}; + + fn make_multipart_spec(operation_id: &str, properties: serde_json::Value) -> OpenAPI { + serde_json::from_value(serde_json::json!({ + "openapi": "3.0.0", + "info": { "title": "Test", "version": "1.0.0" }, + "paths": { + "/upload": { + "post": { + "operationId": operation_id, + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "type": "object", + "properties": properties + } + } + } + }, + "responses": { + "200": { "description": "OK" } + } + } + } + } + })) + .unwrap() + } + + fn make_multipart_spec_with_encoding( + operation_id: &str, + properties: serde_json::Value, + encoding: serde_json::Value, + ) -> OpenAPI { + serde_json::from_value(serde_json::json!({ + "openapi": "3.0.0", + "info": { "title": "Test", "version": "1.0.0" }, + "paths": { + "/upload": { + "post": { + "operationId": operation_id, + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "type": "object", + "properties": properties + }, + "encoding": encoding + } + } + }, + "responses": { + "200": { "description": "OK" } + } + } + } + } + })) + .unwrap() + } + + fn make_builder_generator() -> Generator { + Generator::new(GenerationSettings::default().with_interface(InterfaceStyle::Builder)) + } + + fn make_positional_generator() -> Generator { + Generator::new(GenerationSettings::default().with_interface(InterfaceStyle::Positional)) + } + + fn assert_generates_ok(spec: &OpenAPI, msg: &str) -> String { + let mut generator = make_builder_generator(); + let result = generator.generate_tokens(spec); + assert!(result.is_ok(), "{}: {:?}", msg, result.err()); + result.unwrap().to_string() + } + + fn assert_generates_err(spec: &OpenAPI, expected_err: &str) { + let mut generator = make_builder_generator(); + let result = generator.generate_tokens(spec); + assert!(result.is_err(), "Expected generation to fail"); + let err = result.unwrap_err().to_string(); + assert!( + err.contains(expected_err), + "Error should contain '{}': {}", + expected_err, + err + ); + } + + fn binary_file_prop() -> serde_json::Value { + serde_json::json!({ "type": "string", "format": "binary" }) + } + + fn string_prop() -> serde_json::Value { + serde_json::json!({ "type": "string" }) + } + + fn integer_prop() -> serde_json::Value { + serde_json::json!({ "type": "integer" }) + } + + fn boolean_prop() -> serde_json::Value { + serde_json::json!({ "type": "boolean" }) + } + + fn number_prop() -> serde_json::Value { + serde_json::json!({ "type": "number", "format": "double" }) + } + + fn string_array_prop() -> serde_json::Value { + serde_json::json!({ "type": "array", "items": { "type": "string" } }) + } + + fn integer_array_prop() -> serde_json::Value { + serde_json::json!({ "type": "array", "items": { "type": "integer" } }) + } + + fn simple_object_prop() -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "name": { "type": "string" } + } + }) + } + + #[test] + fn test_multipart_binary_only_succeeds() { + let spec = make_multipart_spec( + "upload_file", + serde_json::json!({ + "file": binary_file_prop() + }), + ); + assert_generates_ok(&spec, "Binary-only multipart should be supported"); + } + + #[test] + fn test_multipart_mixed_types_supported() { + let spec = make_multipart_spec( + "upload_with_name", + serde_json::json!({ + "file": binary_file_prop(), + "name": string_prop() + }), + ); + let code = assert_generates_ok(&spec, "Mixed types in multipart should be supported"); + assert!(code.contains("UploadWithNameForm")); + } + + #[test] + fn test_multipart_nested_objects_supported() { + let spec = make_multipart_spec( + "upload_with_metadata", + serde_json::json!({ + "file": binary_file_prop(), + "metadata": { + "type": "object", + "properties": { + "author": { "type": "string" } + } + } + }), + ); + assert_generates_ok(&spec, "Nested objects in multipart should be supported"); + } + + #[test] + fn test_multipart_integer_fields_supported() { + let spec = make_multipart_spec( + "upload_with_count", + serde_json::json!({ + "file": binary_file_prop(), + "count": integer_prop() + }), + ); + assert_generates_ok(&spec, "Integer fields in multipart should be supported"); + } + + #[test] + fn test_multipart_array_fields_supported() { + let spec = make_multipart_spec( + "upload_with_tags", + serde_json::json!({ + "file": binary_file_prop(), + "tags": string_array_prop() + }), + ); + assert_generates_ok(&spec, "Array fields in multipart should be supported"); + } + + #[test] + fn test_multipart_multiple_binary_files_succeeds() { + let spec = make_multipart_spec( + "upload_multiple", + serde_json::json!({ + "primary": binary_file_prop(), + "secondary": binary_file_prop(), + "thumbnail": binary_file_prop() + }), + ); + assert_generates_ok(&spec, "Multiple binary files should be supported"); + } + + #[test] + fn test_multipart_encoding_content_type() { + let spec = make_multipart_spec_with_encoding( + "upload_with_encoding", + serde_json::json!({ + "file": binary_file_prop(), + "metadata": { + "type": "object", + "properties": { + "author": string_prop(), + "tags": string_array_prop() + } + } + }), + serde_json::json!({ + "file": { "contentType": "application/octet-stream" }, + "metadata": { "contentType": "application/json" } + }), + ); + let code = assert_generates_ok(&spec, "Encoding object should be supported"); + assert!(code.contains("UploadWithEncodingForm")); + } + + #[test] + fn test_multipart_array_json_serialization() { + let spec = make_multipart_spec( + "upload_with_array", + serde_json::json!({ + "tags": string_array_prop() + }), + ); + let code = assert_generates_ok(&spec, "Array fields should be supported"); + // TokenStream::to_string() uses spaces between tokens + assert!(code.contains("serde_json :: to_string")); + } + + #[test] + fn test_multipart_object_json_serialization() { + let spec = make_multipart_spec( + "upload_with_object", + serde_json::json!({ + "metadata": simple_object_prop() + }), + ); + let code = assert_generates_ok(&spec, "Object fields should be supported"); + assert!(code.contains("serde_json :: to_string")); + } + + #[test] + fn test_multipart_boolean_fields_supported() { + let spec = make_multipart_spec( + "upload_with_flag", + serde_json::json!({ + "file": binary_file_prop(), + "is_public": boolean_prop() + }), + ); + let code = assert_generates_ok(&spec, "Boolean fields in multipart should be supported"); + assert!(code.contains("UploadWithFlagForm")); + } + + #[test] + fn test_multipart_number_fields_supported() { + let spec = make_multipart_spec( + "upload_with_score", + serde_json::json!({ + "file": binary_file_prop(), + "score": number_prop() + }), + ); + assert_generates_ok(&spec, "Number fields in multipart should be supported"); + } + + #[test] + fn test_multipart_required_fields() { + let spec: OpenAPI = serde_json::from_value(serde_json::json!({ + "openapi": "3.0.0", + "info": { "title": "Test", "version": "1.0.0" }, + "paths": { + "/upload": { + "post": { + "operationId": "upload_required", + "requestBody": { + "required": true, + "content": { + "multipart/form-data": { + "schema": { + "type": "object", + "required": ["file", "name"], + "properties": { + "file": binary_file_prop(), + "name": string_prop(), + "description": string_prop() + } + } + } + } + }, + "responses": { "200": { "description": "OK" } } + } + } + } + })) + .unwrap(); + + let code = assert_generates_ok(&spec, "Required fields in multipart should be supported"); + assert!(code.contains("UploadRequiredForm")); + } + + #[test] + fn test_multipart_with_schema_reference() { + let spec: OpenAPI = serde_json::from_value(serde_json::json!({ + "openapi": "3.0.0", + "info": { "title": "Test", "version": "1.0.0" }, + "paths": { + "/upload": { + "post": { + "operationId": "upload_with_ref", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "type": "object", + "properties": { + "file": binary_file_prop(), + "metadata": { "$ref": "#/components/schemas/Metadata" } + } + } + } + } + }, + "responses": { "200": { "description": "OK" } } + } + } + }, + "components": { + "schemas": { + "Metadata": { + "type": "object", + "properties": { + "author": { "type": "string" }, + "version": { "type": "integer" } + } + } + } + } + })) + .unwrap(); + + let code = assert_generates_ok(&spec, "Schema references in multipart should be supported"); + assert!(code.contains("serde_json :: to_string")); + } + + #[test] + fn test_multipart_non_object_schema_fails() { + let spec: OpenAPI = serde_json::from_value(serde_json::json!({ + "openapi": "3.0.0", + "info": { "title": "Test", "version": "1.0.0" }, + "paths": { + "/upload": { + "post": { + "operationId": "upload_bad", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "type": "string", + "format": "binary" + } + } + } + }, + "responses": { "200": { "description": "OK" } } + } + } + } + })) + .unwrap(); + + assert_generates_err(&spec, "object schema"); + } + + #[test] + #[should_panic(expected = "unreachable")] + fn test_multipart_empty_properties_panics() { + let spec = make_multipart_spec("upload_empty", serde_json::json!({})); + let mut generator = make_builder_generator(); + let _ = generator.generate_tokens(&spec); + } + + #[test] + fn test_multipart_array_of_integers() { + let spec = make_multipart_spec( + "upload_ids", + serde_json::json!({ + "ids": integer_array_prop() + }), + ); + let code = assert_generates_ok(&spec, "Array of integers in multipart should be supported"); + assert!(code.contains("serde_json :: to_string")); + } + + #[test] + fn test_multipart_deeply_nested_object() { + let spec = make_multipart_spec( + "upload_nested", + serde_json::json!({ + "config": { + "type": "object", + "properties": { + "settings": { + "type": "object", + "properties": { + "level1": { + "type": "object", + "properties": { + "value": { "type": "string" } + } + } + } + } + } + } + }), + ); + assert_generates_ok( + &spec, + "Deeply nested objects in multipart should be supported", + ); + } + + #[test] + fn test_multipart_multiple_encoding_entries() { + let spec = make_multipart_spec_with_encoding( + "upload_multi_encoding", + serde_json::json!({ + "image": binary_file_prop(), + "document": binary_file_prop(), + "data": simple_object_prop() + }), + serde_json::json!({ + "image": { "contentType": "image/png" }, + "document": { "contentType": "application/pdf" }, + "data": { "contentType": "application/xml" } + }), + ); + let code = assert_generates_ok(&spec, "Multiple encoding entries should be supported"); + assert!(code.contains("image/png") || code.contains("image / png")); + } + + #[test] + fn test_multipart_positional_interface() { + let spec = make_multipart_spec( + "upload_file", + serde_json::json!({ + "file": binary_file_prop(), + "name": string_prop() + }), + ); + + let mut generator = make_positional_generator(); + let result = generator.generate_tokens(&spec); + assert!( + result.is_ok(), + "Multipart should work with positional interface: {:?}", + result.err() + ); + assert!(result.unwrap().to_string().contains("UploadFileForm")); + } + + #[test] + fn test_multipart_enum_string_field() { + let spec = make_multipart_spec( + "upload_with_category", + serde_json::json!({ + "file": binary_file_prop(), + "category": { + "type": "string", + "enum": ["image", "document", "video"] + } + }), + ); + assert_generates_ok(&spec, "Enum string fields in multipart should be supported"); + } +} diff --git a/sample_openapi/example_multipart.json b/sample_openapi/example_multipart.json new file mode 100644 index 00000000..0bcae8a5 --- /dev/null +++ b/sample_openapi/example_multipart.json @@ -0,0 +1,39 @@ +{ + "openapi": "3.0.0", + "info": { + "title": "Multipart Example", + "version": "1.0.0" + }, + "paths": { + "/upload": { + "post": { + "operationId": "upload", + "requestBody": { + "required": true, + "content": { + "multipart/form-data": { + "schema": { + "type": "object", + "required": ["file", "name"], + "properties": { + "file": { + "type": "string", + "format": "binary" + }, + "name": { + "type": "string" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Success" + } + } + } + } + } +}