diff --git a/astra/lua/http.lua b/astra/lua/http.lua index e36dffd..e90ee88 100644 --- a/astra/lua/http.lua +++ b/astra/lua/http.lua @@ -77,7 +77,7 @@ local http = {} ---@field fields fun(): table Returns all multipart fields as an array ---@field get_field fun(name: string): HTTPMultipartField Returns a specific field by name ---@field file_name fun(): string|nil Returns the first filename found in the multipart data ----@field save_file fun(multipart: HTTPMultipart, file_path: string | nil): string | nil Saves the multipart into disk +---@field save_file fun(multipart: HTTPMultipart, file_path: string | table | nil): string | nil Saves the multipart into disk ---@class HTTPServerRequest ---@field method fun(self: HTTPServerRequest): string Returns the HTTP method (e.g., "GET", "POST"). diff --git a/astra/teal/http.tl b/astra/teal/http.tl index ee6e545..1bd2921 100644 --- a/astra/teal/http.tl +++ b/astra/teal/http.tl @@ -75,7 +75,7 @@ global interface HTTPMultipart is userdata fields: function(self: HTTPMultipart): {HTTPMultipartField} get_field: function(self: HTTPMultipart, name: string): HTTPMultipartField file_name: function(self: HTTPMultipart): string | nil - save_file: function(self: HTTPMultipart, file_path?: string) + save_file: function(self: HTTPMultipart, file_path?: string | {string: string}) end global interface HTTPServerRequest is userdata diff --git a/src/components/http/server/requests.rs b/src/components/http/server/requests.rs index 076a7a4..551be43 100644 --- a/src/components/http/server/requests.rs +++ b/src/components/http/server/requests.rs @@ -273,27 +273,41 @@ impl UserData for AstraMultipart { methods.add_async_method_mut( "save_file", - |_, this, file_path: Option| async move { - let mut file_path = if let Some(file_path) = file_path { - Some(tokio::fs::File::create(file_path).await?) - } else { - None - }; - - for field in &this.fields { - if file_path.is_none() - && let Some(filename) = &field.file_name - { - file_path = Some(tokio::fs::File::create(filename).await?); + |_lua, this, arg: mlua::Value| async move { + match arg { + mlua::Value::String(s) => { + let folder = s.to_str()?.to_string(); + let folder_path = std::path::PathBuf::from(&folder); + if !folder_path.exists() { + tokio::fs::create_dir_all(&folder_path).await?; + } + for field in &this.fields { + let filename = field.file_name.as_ref().unwrap_or(&field.name); + let file_path = folder_path.join(filename); + let mut file = tokio::fs::File::create(file_path).await?; + file.write_all(&field.data).await?; + } } - if let Some(ref mut file) = file_path - && let bytes = &field.data - && let Err(err) = file.write(bytes).await - { - return Err(err.into_lua_err()); + mlua::Value::Table(map) => { + for field in &this.fields { + let path: String = map.get(field.name.clone())?; + let mut file = tokio::fs::File::create(path).await?; + file.write_all(&field.data).await?; + } + } + mlua::Value::Nil => { + for field in &this.fields { + let filename = field.file_name.as_ref().unwrap_or(&field.name); + let mut file = tokio::fs::File::create(filename).await?; + file.write_all(&field.data).await?; + } + } + _ => { + return Err(mlua::Error::runtime( + "save_file expects string (folder), table (mapping), or nil", + )) } } - Ok(()) }, );