diff --git a/reflex/app.py b/reflex/app.py index 4ff412ef863..2a832905dee 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1911,16 +1911,47 @@ async def upload_file(request: Request): """ from reflex.utils.exceptions import UploadTypeError, UploadValueError + config = get_config() + upload_max_size = config.upload_max_size + upload_max_files = config.upload_max_files + + # Reject based on Content-Length before Starlette buffers the body. + # Content-Length covers the entire multipart request, so use + # (per-file limit * max files) as the upper bound. + if upload_max_size > 0 and upload_max_files > 0: + content_length_str = request.headers.get("content-length") + if content_length_str is not None: + try: + content_length = int(content_length_str) + except ValueError: + return JSONResponse( + status_code=400, + content={"detail": "Invalid Content-Length header."}, + ) + max_request_size = upload_max_size * upload_max_files + if content_length > max_request_size: + return JSONResponse( + status_code=413, + content={"detail": f"Upload exceeds the maximum allowed size of {max_request_size} bytes."}, + ) + # Get the files from the request. try: files = await request.form() except ClientDisconnect: return Response() # user cancelled - files = files.getlist("files") - if not files: + file_list = files.getlist("files") + if not file_list: msg = "No files were uploaded." raise UploadValueError(msg) + # Enforce max file count. + if upload_max_files > 0 and len(file_list) > upload_max_files: + return JSONResponse( + status_code=400, + content={"detail": f"Too many files uploaded ({len(file_list)}). Maximum allowed is {upload_max_files}."}, + ) + token = request.headers.get("reflex-client-token") handler = request.headers.get("reflex-event-handler") @@ -1966,19 +1997,54 @@ async def upload_file(request: Request): ) raise UploadValueError(msg) + async def _cleanup(upload_files, copied_files): + """Close all uploaded files and any already-copied BytesIO buffers. + + Args: + upload_files: The raw uploaded file list from the request. + copied_files: The list of UploadFile copies made so far. + """ + for f in upload_files: + if isinstance(f, StarletteUploadFile): + await f.close() + for f in copied_files: + f.file.close() + # Make a copy of the files as they are closed after the request. # This behaviour changed from fastapi 0.103.0 to 0.103.1 as the # AsyncExitStack was removed from the request scope and is now # part of the routing function which closes this before the # event is handled. file_copies = [] - for file in files: + for file in file_list: if not isinstance(file, StarletteUploadFile): + await _cleanup(file_list, file_copies) raise UploadValueError( "Uploaded file is not an UploadFile." + str(file) ) + # Enforce upload size limit: early rejection via file.size header + if upload_max_size > 0 and file.size is not None and file.size > upload_max_size: + await _cleanup(file_list, file_copies) + return JSONResponse( + status_code=413, + content={"detail": f"File exceeds the maximum upload size of {upload_max_size} bytes."}, + ) content_copy = io.BytesIO() - content_copy.write(await file.read()) + # Read in chunks to enforce limit even when file.size is not reported + bytes_read = 0 + while True: + chunk = await file.read(1024 * 1024) # 1 MB chunks + if not chunk: + break + bytes_read += len(chunk) + if upload_max_size > 0 and bytes_read > upload_max_size: + content_copy.close() + await _cleanup(file_list, file_copies) + return JSONResponse( + status_code=413, + content={"detail": f"File exceeds the maximum upload size of {upload_max_size} bytes."}, + ) + content_copy.write(chunk) content_copy.seek(0) file_copies.append( UploadFile( @@ -1989,12 +2055,10 @@ async def upload_file(request: Request): ) ) - for file in files: - if not isinstance(file, StarletteUploadFile): - raise UploadValueError( - "Uploaded file is not an UploadFile." + str(file) - ) - await file.close() + # Close the raw uploaded files (copies are kept for event processing). + for file in file_list: + if isinstance(file, StarletteUploadFile): + await file.close() event = Event( token=token, diff --git a/reflex/config.py b/reflex/config.py index 6977d745631..06f78da5d8b 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -260,6 +260,13 @@ class BaseConfig: # The transport method for client-server communication. transport: Literal["websocket", "polling"] = "websocket" + # Maximum file upload size in bytes. Files larger than this will be rejected with HTTP 413. + # Defaults to 10 MB. Set to 0 to disable the limit. + upload_max_size: int = 10 * 1024 * 1024 # 10 MB + + # Maximum number of files per upload request. Set to 0 to disable the limit. + upload_max_files: int = 10 + # Whether to skip plugin checks. _skip_plugins_checks: bool = dataclasses.field(default=False, repr=False) @@ -369,6 +376,13 @@ def _post_init(self, **kwargs): msg = f"{self._prefixes[0]}REDIS_URL is required when using the redis state manager." raise ConfigError(msg) + if self.upload_max_size < 0: + msg = "upload_max_size must be >= 0." + raise ConfigError(msg) + if self.upload_max_files < 0: + msg = "upload_max_files must be >= 0." + raise ConfigError(msg) + def _add_builtin_plugins(self): """Add the builtin plugins to the config.""" for plugin in _PLUGINS_ENABLED_BY_DEFAULT: diff --git a/tests/units/test_upload_size_limit.py b/tests/units/test_upload_size_limit.py new file mode 100644 index 00000000000..7a8b6dd84b6 --- /dev/null +++ b/tests/units/test_upload_size_limit.py @@ -0,0 +1,299 @@ +"""Tests for upload file size limit enforcement in the upload handler.""" + +import io +import unittest.mock +from unittest.mock import AsyncMock, patch + +import pytest +import reflex as rx +from starlette.datastructures import UploadFile as StarletteUploadFile +from starlette.responses import JSONResponse + +from reflex.app import upload + +# Use a small limit for tests to avoid allocating large buffers. +_TEST_MAX_SIZE = 500 # 500 bytes +_TEST_MAX_FILES = 2 + + +def _make_upload_file(filename: str, content: bytes, *, report_size: bool = True): + """Create a StarletteUploadFile with given content.""" + file = StarletteUploadFile(filename=filename, file=io.BytesIO(content)) + file.size = len(content) if report_size else None + return file + + +_SENTINEL = object() + + +def _make_request_mock(files, content_length=_SENTINEL): + """Create a mock Starlette Request with the given files. + + Args: + files: List of upload files. + content_length: If provided, set as the raw content-length header value. + Pass an int for valid values, or a raw string (e.g. "", " ", "abc") + to test malformed header handling. Omit to leave the header unset. + """ + request_mock = unittest.mock.Mock() + headers = { + "reflex-client-token": "test-token", + "reflex-event-handler": "fake_state.handle_upload", + } + if content_length is not _SENTINEL: + headers["content-length"] = str(content_length) if isinstance(content_length, int) else content_length + request_mock.headers = headers + + async def form(): # noqa: RUF029 + files_mock = unittest.mock.Mock() + files_mock.getlist = lambda key: files + return files_mock + + request_mock.form = form + return request_mock + + +def _mock_config(upload_max_size=_TEST_MAX_SIZE, upload_max_files=_TEST_MAX_FILES): + """Create a mock config with given upload limits.""" + config = unittest.mock.Mock() + config.upload_max_size = upload_max_size + config.upload_max_files = upload_max_files + return config + + +class _FakeState: + """A minimal fake state that has a handle_upload method with correct annotations.""" + + def handle_upload(self, files: list[rx.UploadFile]): + pass + + def get_substate(self, path): + return self + + +def _make_app_mock(): + """Create a mock app whose state manager returns a _FakeState.""" + app_mock = unittest.mock.Mock() + app_mock.state_manager.get_state = AsyncMock(return_value=_FakeState()) + return app_mock + + +# --- Config default tests --- + + +def test_default_upload_max_size(): + """Default upload_max_size should be 10 MB.""" + from reflex.config import get_config + + config = get_config() + assert config.upload_max_size == 10 * 1024 * 1024 + + +def test_default_upload_max_files(): + """Default upload_max_files should be 10.""" + from reflex.config import get_config + + config = get_config() + assert config.upload_max_files == 10 + + +def test_negative_upload_max_size_rejected(): + """Negative upload_max_size should raise ConfigError.""" + from reflex.config import Config + from reflex.utils.exceptions import ConfigError + + with pytest.raises(ConfigError, match="upload_max_size must be >= 0"): + Config(app_name="test", upload_max_size=-1, _skip_plugins_checks=True) + + +def test_negative_upload_max_files_rejected(): + """Negative upload_max_files should raise ConfigError.""" + from reflex.config import Config + from reflex.utils.exceptions import ConfigError + + with pytest.raises(ConfigError, match="upload_max_files must be >= 0"): + Config(app_name="test", upload_max_files=-1, _skip_plugins_checks=True) + + +# --- Content-Length pre-check tests --- + + +@pytest.mark.asyncio +async def test_content_length_over_limit_rejected(): + """Request with Content-Length exceeding total limit is rejected before form parsing.""" + app_mock = unittest.mock.Mock() + max_request_size = _TEST_MAX_SIZE * _TEST_MAX_FILES + + with patch("reflex.app.get_config", return_value=_mock_config()): + upload_fn = upload(app_mock) + request = _make_request_mock([], content_length=max_request_size + 1) + response = await upload_fn(request) + + assert isinstance(response, JSONResponse) + assert response.status_code == 413 + + +@pytest.mark.asyncio +async def test_content_length_under_limit_passes_precheck(): + """Request with Content-Length under the total limit passes the pre-check.""" + small_file = _make_upload_file("ok.txt", b"x" * 100) + app_mock = _make_app_mock() + + with patch("reflex.app.get_config", return_value=_mock_config()): + upload_fn = upload(app_mock) + request = _make_request_mock([small_file], content_length=100) + response = await upload_fn(request) + + # Content-Length is under the limit, file is small — should reach + # the streaming response (event processing), not a 413. + assert not isinstance(response, JSONResponse) or response.status_code != 413 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "bad_value", + ["not-a-number", "", " ", "12.5", "1e3"], + ids=["text", "empty", "whitespace", "float", "scientific"], +) +async def test_malformed_content_length_returns_400(bad_value): + """Non-integer Content-Length header returns 400 Bad Request.""" + app_mock = unittest.mock.Mock() + + with patch("reflex.app.get_config", return_value=_mock_config()): + upload_fn = upload(app_mock) + request = _make_request_mock([], content_length=bad_value) + response = await upload_fn(request) + + assert isinstance(response, JSONResponse) + assert response.status_code == 400 + + +# --- Per-file size enforcement tests --- + + +@pytest.mark.asyncio +async def test_file_over_limit_rejected_413(): + """Oversized file is rejected with HTTP 413 by the upload handler.""" + oversized = _make_upload_file("big.bin", b"x" * (_TEST_MAX_SIZE + 1)) + app_mock = _make_app_mock() + + with patch("reflex.app.get_config", return_value=_mock_config()): + upload_fn = upload(app_mock) + request = _make_request_mock([oversized]) + response = await upload_fn(request) + + assert isinstance(response, JSONResponse) + assert response.status_code == 413 + + +@pytest.mark.asyncio +async def test_file_over_limit_rejected_by_chunked_read(): + """File with size=None but content over limit is rejected during chunked read.""" + oversized_no_size = _make_upload_file( + "big.bin", b"x" * (_TEST_MAX_SIZE + 1), report_size=False + ) + app_mock = _make_app_mock() + + with patch("reflex.app.get_config", return_value=_mock_config()): + upload_fn = upload(app_mock) + request = _make_request_mock([oversized_no_size]) + response = await upload_fn(request) + + assert isinstance(response, JSONResponse) + assert response.status_code == 413 + + +@pytest.mark.asyncio +async def test_file_under_limit_not_rejected(): + """File under the limit passes the size enforcement and reaches event processing.""" + small = _make_upload_file("ok.txt", b"x" * 100) + app_mock = _make_app_mock() + + with patch("reflex.app.get_config", return_value=_mock_config()): + upload_fn = upload(app_mock) + request = _make_request_mock([small]) + response = await upload_fn(request) + + # Handler should proceed past size checks to event processing. + # It returns a StreamingResponse on success, never a 413 JSONResponse. + assert not isinstance(response, JSONResponse) or response.status_code != 413 + + +# --- Max files tests --- + + +@pytest.mark.asyncio +async def test_too_many_files_rejected_with_400(): + """Uploading more files than upload_max_files is rejected with 400.""" + files = [ + _make_upload_file(f"file{i}.txt", b"x" * 10) + for i in range(_TEST_MAX_FILES + 1) + ] + app_mock = unittest.mock.Mock() + + with patch("reflex.app.get_config", return_value=_mock_config()): + upload_fn = upload(app_mock) + request = _make_request_mock(files) + response = await upload_fn(request) + + assert isinstance(response, JSONResponse) + assert response.status_code == 400 + + +@pytest.mark.asyncio +async def test_files_within_limit_not_rejected_for_count(): + """Uploading files within upload_max_files passes the count check.""" + files = [ + _make_upload_file(f"file{i}.txt", b"x" * 10) for i in range(_TEST_MAX_FILES) + ] + app_mock = _make_app_mock() + + with patch("reflex.app.get_config", return_value=_mock_config()): + upload_fn = upload(app_mock) + request = _make_request_mock(files) + response = await upload_fn(request) + + # Should not be rejected for file count — reaches event processing. + assert not isinstance(response, JSONResponse) or response.status_code not in (400, 413) + + +# --- Limit disabled tests --- + + +@pytest.mark.asyncio +async def test_size_limit_disabled_allows_large_file(): + """When upload_max_size=0, large files are not rejected for size.""" + large = _make_upload_file("big.bin", b"x" * 1000) + app_mock = _make_app_mock() + + with patch( + "reflex.app.get_config", + return_value=_mock_config(upload_max_size=0, upload_max_files=0), + ): + upload_fn = upload(app_mock) + request = _make_request_mock([large]) + response = await upload_fn(request) + + # With limits disabled, should never get a 413. + assert not isinstance(response, JSONResponse) or response.status_code != 413 + + +# --- Filename sanitization test --- + + +@pytest.mark.asyncio +async def test_error_message_does_not_contain_filename(): + """413 error message should not echo back the attacker-controlled filename.""" + evil_name = ".bin" + oversized = _make_upload_file(evil_name, b"x" * (_TEST_MAX_SIZE + 1)) + app_mock = _make_app_mock() + + with patch("reflex.app.get_config", return_value=_mock_config()): + upload_fn = upload(app_mock) + request = _make_request_mock([oversized]) + response = await upload_fn(request) + + assert isinstance(response, JSONResponse) + assert response.status_code == 413 + body = response.body.decode() + assert evil_name not in body