Skip to content

Commit ff0f83f

Browse files
authored
feat: use multipart uploads API for large files automatically with optional config (#18)
* feat: use multipart uploads API for large files
* chore: use right type for multipart upload param
* chore: lint fixes
1 parent e95f5f1 commit ff0f83f

4 files changed

Lines changed: 375 additions & 0 deletions

File tree

src/mixedbread/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import typing as _t
44

55
from . import types
6+
from .lib import PartUploadEvent as PartUploadEvent, MultipartUploadOptions as MultipartUploadOptions
67
from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes, omit, not_given
78
from ._utils import file_from_path
89
from ._client import (
@@ -83,6 +84,8 @@
8384
"DefaultHttpxClient",
8485
"DefaultAsyncHttpxClient",
8586
"DefaultAioHttpClient",
87+
"MultipartUploadOptions",
88+
"PartUploadEvent",
8689
]
8790

8891
if not _t.TYPE_CHECKING:

src/mixedbread/lib/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .multipart_upload import PartUploadEvent as PartUploadEvent, MultipartUploadOptions as MultipartUploadOptions
Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
from __future__ import annotations
2+
3+
import os
4+
import math
5+
import asyncio
6+
import mimetypes
7+
from typing import TYPE_CHECKING, Any, List, Union, Callable, Optional
8+
from pathlib import Path
9+
from dataclasses import dataclass
10+
from concurrent.futures import ThreadPoolExecutor
11+
12+
import httpx
13+
14+
if TYPE_CHECKING:
15+
from .._types import FileTypes, FileContent
16+
from ..resources.files.uploads import UploadsResource, AsyncUploadsResource
17+
18+
from ..types.file_object import FileObject
19+
from ..types.files.multipart_upload_part_param import MultipartUploadPartParam
20+
21+
# File size (bytes) at which the multipart flow is chosen over a plain upload.
# NOTE(review): only consumed as the default of MultipartUploadOptions.threshold;
# the actual comparison happens in a caller outside this module — confirm there.
DEFAULT_THRESHOLD = 100 * 1024 * 1024  # 100 MB
# Default size (bytes) of each uploaded part; the final part may be smaller.
DEFAULT_PART_SIZE = 100 * 1024 * 1024  # 100 MB
# Default number of parts uploaded in parallel (thread-pool workers in the
# sync path, a semaphore limit in the async path).
DEFAULT_CONCURRENCY = 5
# Per-request httpx timeout, in seconds, applied to each part upload.
UPLOAD_TIMEOUT = 300  # 5 minutes
25+
26+
27+
@dataclass
class PartUploadEvent:
    """Event emitted after each part is uploaded."""

    # 1-based number of the part that just finished uploading.
    part_number: int
    # Total number of parts in this upload.
    total_parts: int
    # Actual byte size of this part (the final part may be short).
    part_size: int
    # Bytes covered through this part number — an upper-bound estimate,
    # since parts upload concurrently and may complete out of order.
    uploaded_bytes: int
    # Total size of the uploaded file, in bytes.
    total_bytes: int
36+
37+
38+
@dataclass
class MultipartUploadOptions:
    """Options for controlling multipart upload behavior."""

    # File size (bytes) at which multipart upload kicks in. Not referenced in
    # this module — presumably checked by the caller that routes uploads.
    threshold: int = DEFAULT_THRESHOLD
    # Size (bytes) of each part; the final part may be smaller.
    part_size: int = DEFAULT_PART_SIZE
    # Maximum number of parts uploaded in parallel.
    concurrency: int = DEFAULT_CONCURRENCY
    # Optional progress callback invoked after each part finishes uploading.
    # NOTE(review): in the sync path this is called from worker threads, so it
    # presumably needs to be thread-safe — confirm with callers.
    on_part_upload: Optional[Callable[[PartUploadEvent], None]] = None
46+
47+
48+
@dataclass
class _ResolvedFile:
    """Internal resolved file representation."""

    # Either the full file contents held in memory, or a filesystem path
    # from which parts are read on demand with independent handles.
    data: Union[bytes, Path]
    # Total number of bytes to upload.
    file_size: int
    # Name reported to the uploads API (falls back to "upload").
    filename: str
    # MIME type reported to the uploads API.
    mime_type: str
56+
57+
58+
def _get_file_size(file: FileTypes) -> int:
59+
"""Get file size without reading the entire file into memory.
60+
61+
Raises TypeError if the size cannot be determined.
62+
"""
63+
# Handle tuple forms: (filename, content, ...)
64+
if isinstance(file, tuple):
65+
file_content = file[1]
66+
else:
67+
file_content = file
68+
69+
if isinstance(file_content, bytes):
70+
return len(file_content)
71+
72+
if isinstance(file_content, os.PathLike):
73+
return os.stat(file_content).st_size
74+
75+
# IO[bytes] - try seek-based size detection
76+
if hasattr(file_content, "seek") and hasattr(file_content, "tell"):
77+
current = file_content.tell()
78+
file_content.seek(0, 2)
79+
size = file_content.tell()
80+
file_content.seek(current)
81+
return size
82+
83+
raise TypeError(f"Cannot determine file size for {type(file_content)}")
84+
85+
86+
def _resolve_file_input(file: FileTypes) -> _ResolvedFile:
    """Resolve a FileTypes input into a normalized representation.

    Accepts raw bytes, an os.PathLike, an IO[bytes] stream, or any of those
    wrapped in a ``(filename, content[, content_type, ...])`` tuple. Streams
    are read fully into memory; paths are kept as paths so that parts can
    later be read lazily with independent file handles.

    Raises:
        TypeError: If the content is not bytes, path-like, or readable.
    """
    filename: Optional[str] = None
    mime_type: Optional[str] = None
    file_content: FileContent

    if isinstance(file, tuple):
        filename = file[0]
        file_content = file[1]
        if len(file) >= 3:
            mime_type = file[2]  # type: ignore[misc]
    else:
        file_content = file

    # Resolve file content to bytes or Path
    data: Union[bytes, Path]
    if isinstance(file_content, bytes):
        data = file_content
        file_size = len(file_content)
        if filename is None:
            filename = "upload"
    elif isinstance(file_content, os.PathLike):
        path = Path(file_content)
        data = path
        file_size = os.stat(path).st_size
        if filename is None:
            filename = path.name
    elif hasattr(file_content, "read"):
        # IO[bytes] - read the remainder of the stream into memory.
        data = file_content.read()
        file_size = len(data)
        if filename is None:
            name = getattr(file_content, "name", None)
            # ``name`` is an int when the stream was opened from a raw file
            # descriptor; only derive a filename from non-empty string names.
            if isinstance(name, str) and name:
                filename = os.path.basename(name)
            else:
                filename = "upload"
    else:
        raise TypeError(f"Unsupported file type: {type(file_content)}")

    # Resolve mime type: guess from the filename, falling back to a generic
    # binary type when nothing was supplied or guessable.
    if not mime_type and filename:
        guessed, _ = mimetypes.guess_type(filename)
        mime_type = guessed or "application/octet-stream"
    elif not mime_type:
        mime_type = "application/octet-stream"

    return _ResolvedFile(
        data=data,
        file_size=file_size,
        filename=filename or "upload",
        mime_type=mime_type,
    )
139+
140+
141+
def _read_part(resolved: _ResolvedFile, part_number: int, part_size: int) -> bytes:
142+
"""Read a specific part from the resolved file data.
143+
144+
For bytes data, slices directly. For PathLike, opens its own file handle
145+
(thread-safe for concurrent uploads).
146+
"""
147+
offset = (part_number - 1) * part_size # parts are 1-based
148+
149+
if isinstance(resolved.data, bytes):
150+
return resolved.data[offset : offset + part_size]
151+
152+
# PathLike - each caller gets its own file handle
153+
with open(resolved.data, "rb") as f:
154+
f.seek(offset)
155+
return f.read(part_size)
156+
157+
158+
def _upload_single_part(
    url: str,
    data: bytes,
    http_client: httpx.Client,
) -> str:
    """PUT one part's bytes to its presigned URL and return the ETag header.

    Raises httpx.HTTPStatusError on a non-success response. An absent ETag
    header yields an empty string.
    """
    resp = http_client.put(url, content=data)
    resp.raise_for_status()
    return resp.headers.get("etag", "")
167+
168+
169+
async def _async_upload_single_part(
    url: str,
    data: bytes,
    http_client: httpx.AsyncClient,
) -> str:
    """Asynchronously PUT one part's bytes to its presigned URL; return the ETag.

    Raises httpx.HTTPStatusError on a non-success response. An absent ETag
    header yields an empty string.
    """
    resp = await http_client.put(url, content=data)
    resp.raise_for_status()
    return resp.headers.get("etag", "")
178+
179+
180+
def multipart_create_sync(
    uploads: UploadsResource,
    file: FileTypes,
    options: MultipartUploadOptions,
) -> FileObject:
    """Perform a multipart upload synchronously.

    Initiates an upload session, uploads every part to its presigned URL
    using a thread pool of ``options.concurrency`` workers, then completes
    the session. On any failure — including KeyboardInterrupt — the session
    is aborted best-effort before the exception is re-raised.

    Args:
        uploads: Resource providing create/complete/abort for upload sessions.
        file: File input (bytes, path-like, IO[bytes], or tuple form).
        options: Part size, concurrency, and optional progress callback.

    Returns:
        The FileObject returned by the completion call.
    """
    resolved = _resolve_file_input(file)
    # Ceiling division; a zero-byte file still produces one (empty) part.
    part_count = max(1, math.ceil(resolved.file_size / options.part_size))

    # Step 1: Initiate the multipart upload
    upload = uploads.create(
        filename=resolved.filename,
        file_size=resolved.file_size,
        mime_type=resolved.mime_type,
        part_count=part_count,
    )
    upload_id = upload.id

    try:
        # Step 2: Upload parts concurrently
        completed_parts: List[MultipartUploadPartParam] = []

        # One shared client for all parts; generous timeout because each
        # part can be up to options.part_size bytes.
        with httpx.Client(timeout=httpx.Timeout(UPLOAD_TIMEOUT)) as http_client:

            def _do_upload(part_url: Any) -> MultipartUploadPartParam:
                # Runs on a worker thread; _read_part opens its own file
                # handle per call, so concurrent reads are safe.
                part_data = _read_part(resolved, part_url.part_number, options.part_size)
                etag = _upload_single_part(part_url.url, part_data, http_client)

                if options.on_part_upload:
                    # Upper-bound progress: parts finish out of order, so this
                    # reports bytes covered *through* this part number.
                    uploaded_bytes = min(
                        part_url.part_number * options.part_size,
                        resolved.file_size,
                    )
                    # NOTE(review): invoked from worker threads — the callback
                    # is presumably expected to be thread-safe; confirm.
                    options.on_part_upload(
                        PartUploadEvent(
                            part_number=part_url.part_number,
                            total_parts=part_count,
                            part_size=len(part_data),
                            uploaded_bytes=uploaded_bytes,
                            total_bytes=resolved.file_size,
                        )
                    )

                return MultipartUploadPartParam(part_number=part_url.part_number, etag=etag)

            with ThreadPoolExecutor(max_workers=options.concurrency) as executor:
                futures = [executor.submit(_do_upload, pu) for pu in upload.part_urls]
                # Results are gathered in submission order; a failed part
                # raises here. NOTE(review): the executor's __exit__ still
                # waits for in-flight parts to finish before the abort below.
                for future in futures:
                    completed_parts.append(future.result())

        # Sort by part number
        completed_parts.sort(key=lambda p: p["part_number"])

        # Step 3: Complete the upload
        return uploads.complete(
            upload_id=upload_id,
            parts=completed_parts,
        )
    except BaseException:
        # Abort on any failure (including KeyboardInterrupt, CancelledError)
        try:
            uploads.abort(upload_id=upload_id)
        except Exception:
            pass  # Best effort abort
        raise
246+
247+
248+
async def multipart_create_async(
    uploads: AsyncUploadsResource,
    file: FileTypes,
    options: MultipartUploadOptions,
) -> FileObject:
    """Perform a multipart upload asynchronously.

    Initiates an upload session, uploads every part to its presigned URL
    with at most ``options.concurrency`` in-flight requests (limited by a
    semaphore), then completes the session. On any failure — including
    cancellation — the session is aborted best-effort before re-raising.

    Args:
        uploads: Async resource providing create/complete/abort calls.
        file: File input (bytes, path-like, IO[bytes], or tuple form).
        options: Part size, concurrency, and optional progress callback.

    Returns:
        The FileObject returned by the completion call.
    """
    resolved = _resolve_file_input(file)
    # Ceiling division; a zero-byte file still produces one (empty) part.
    part_count = max(1, math.ceil(resolved.file_size / options.part_size))

    # Step 1: Initiate the multipart upload
    upload = await uploads.create(
        filename=resolved.filename,
        file_size=resolved.file_size,
        mime_type=resolved.mime_type,
        part_count=part_count,
    )
    upload_id = upload.id

    try:
        # Step 2: Upload parts concurrently
        semaphore = asyncio.Semaphore(options.concurrency)

        async with httpx.AsyncClient(timeout=httpx.Timeout(UPLOAD_TIMEOUT)) as http_client:

            async def _do_upload(part_url: Any) -> MultipartUploadPartParam:
                async with semaphore:
                    # NOTE(review): _read_part does blocking file I/O on the
                    # event loop thread — acceptable for local reads, but
                    # worth confirming for slow filesystems.
                    part_data = _read_part(resolved, part_url.part_number, options.part_size)
                    etag = await _async_upload_single_part(part_url.url, part_data, http_client)

                    if options.on_part_upload:
                        # Upper-bound progress: parts finish out of order, so
                        # this reports bytes covered *through* this part number.
                        uploaded_bytes = min(
                            part_url.part_number * options.part_size,
                            resolved.file_size,
                        )
                        options.on_part_upload(
                            PartUploadEvent(
                                part_number=part_url.part_number,
                                total_parts=part_count,
                                part_size=len(part_data),
                                uploaded_bytes=uploaded_bytes,
                                total_bytes=resolved.file_size,
                            )
                        )

                    return MultipartUploadPartParam(part_number=part_url.part_number, etag=etag)

            # gather() without return_exceptions: the first failing part
            # raises here and control falls through to the abort handler.
            completed_parts: List[MultipartUploadPartParam] = list(
                await asyncio.gather(*[_do_upload(pu) for pu in upload.part_urls])
            )

        # Sort by part number
        completed_parts.sort(key=lambda p: p["part_number"])

        # Step 3: Complete the upload
        return await uploads.complete(
            upload_id=upload_id,
            parts=completed_parts,
        )
    except BaseException:
        # Abort on any failure (including KeyboardInterrupt, CancelledError)
        try:
            await uploads.abort(upload_id=upload_id)
        except Exception:
            pass  # Best effort abort
        raise

0 commit comments

Comments
 (0)