Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion astra/lua/http.lua
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ local http = {}
---@field fields fun(): table Returns all multipart fields as an array
---@field get_field fun(name: string): HTTPMultipartField Returns a specific field by name
---@field file_name fun(): string|nil Returns the first filename found in the multipart data
---@field save_file fun(multipart: HTTPMultipart, file_path: string | nil): string | nil Saves the multipart into disk
---@field save_file fun(multipart: HTTPMultipart, file_path: string | table<string, string> | nil): string | nil Saves the multipart into disk

---@class HTTPServerRequest
---@field method fun(self: HTTPServerRequest): string Returns the HTTP method (e.g., "GET", "POST").
Expand Down
2 changes: 1 addition & 1 deletion astra/teal/http.tl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ global interface HTTPMultipart is userdata
fields: function(self: HTTPMultipart): {HTTPMultipartField}
get_field: function(self: HTTPMultipart, name: string): HTTPMultipartField
file_name: function(self: HTTPMultipart): string | nil
save_file: function(self: HTTPMultipart, file_path?: string)
save_file: function(self: HTTPMultipart, file_path?: string | {string: string})
end

global interface HTTPServerRequest is userdata
Expand Down
50 changes: 32 additions & 18 deletions src/components/http/server/requests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,27 +273,41 @@ impl UserData for AstraMultipart {

methods.add_async_method_mut(
"save_file",
|_, this, file_path: Option<String>| async move {
let mut file_path = if let Some(file_path) = file_path {
Some(tokio::fs::File::create(file_path).await?)
} else {
None
};

for field in &this.fields {
if file_path.is_none()
&& let Some(filename) = &field.file_name
{
file_path = Some(tokio::fs::File::create(filename).await?);
|_lua, this, arg: mlua::Value| async move {
match arg {
mlua::Value::String(s) => {
let folder = s.to_str()?.to_string();
let folder_path = std::path::PathBuf::from(&folder);
if !folder_path.exists() {
tokio::fs::create_dir_all(&folder_path).await?;
}
for field in &this.fields {
let filename = field.file_name.as_ref().unwrap_or(&field.name);
let file_path = folder_path.join(filename);
let mut file = tokio::fs::File::create(file_path).await?;
file.write_all(&field.data).await?;
}
}
if let Some(ref mut file) = file_path
&& let bytes = &field.data
&& let Err(err) = file.write(bytes).await
{
return Err(err.into_lua_err());
mlua::Value::Table(map) => {
for field in &this.fields {
let path: String = map.get(field.name.clone())?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're not doing any checks for the directory for these, and it automatically errors out if a field is not present, it should rather just continue in the iteration instead.

let mut file = tokio::fs::File::create(path).await?;
file.write_all(&field.data).await?;
}
}
mlua::Value::Nil => {
for field in &this.fields {
let filename = field.file_name.as_ref().unwrap_or(&field.name);
let mut file = tokio::fs::File::create(filename).await?;
file.write_all(&field.data).await?;
}
}
_ => {
return Err(mlua::Error::runtime(
"save_file expects string (folder), table (mapping), or nil",
))
}
}

Ok(())
},
);
Expand Down