Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions fastapi/dependencies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,11 @@ def _is_json_field(field: ModelField) -> bool:


def _get_multidict_value(
field: ModelField, values: Mapping[str, Any], alias: str | None = None
field: ModelField,
values: Mapping[str, Any],
alias: str | None = None,
*,
use_default_when_missing: bool = True,
) -> Any:
alias = alias or get_validation_alias(field)
if (
Expand All @@ -776,8 +780,9 @@ def _get_multidict_value(
):
if field.field_info.is_required():
return
else:
if use_default_when_missing:
return deepcopy(field.default)
return
return value


Expand All @@ -795,11 +800,13 @@ def request_params_to_args(
fields_to_extract = fields
single_not_embedded_field = False
default_convert_underscores = True
is_model_param = False
if len(fields) == 1 and lenient_issubclass(
first_field.field_info.annotation, BaseModel
):
fields_to_extract = get_cached_model_fields(first_field.field_info.annotation)
single_not_embedded_field = True
is_model_param = True
# If headers are in a Pydantic model, the way to disable convert_underscores
# would be with Header(convert_underscores=False) at the Pydantic model level
default_convert_underscores = getattr(
Expand All @@ -822,7 +829,12 @@ def request_params_to_args(
alias = get_validation_alias(field)
if alias == field.name:
alias = alias.replace("_", "-")
value = _get_multidict_value(field, received_params, alias=alias)
value = _get_multidict_value(
field,
received_params,
alias=alias,
use_default_when_missing=not is_model_param,
)
if value is not None:
params_to_process[get_validation_alias(field)] = value
processed_keys.add(alias or get_validation_alias(field))
Expand Down Expand Up @@ -912,11 +924,15 @@ def _should_embed_body_fields(fields: list[ModelField]) -> bool:
async def _extract_form_body(
body_fields: list[ModelField],
received_body: FormData,
*,
use_default_when_missing: bool = True,
) -> dict[str, Any]:
values = {}

for field in body_fields:
value = _get_multidict_value(field, received_body)
value = _get_multidict_value(
field, received_body, use_default_when_missing=use_default_when_missing
)
field_info = field.field_info
if (
isinstance(field_info, params.File)
Expand Down Expand Up @@ -970,7 +986,16 @@ async def request_body_to_args(
fields_to_extract = get_cached_model_fields(first_field.field_info.annotation)

if isinstance(received_body, FormData):
body_to_process = await _extract_form_body(fields_to_extract, received_body)
body_to_process = await _extract_form_body(
fields_to_extract,
received_body,
# Keep omitted fields absent so Pydantic can apply defaults without
# marking them as explicitly provided on the resulting model.
use_default_when_missing=not (
single_not_embedded_field
and lenient_issubclass(first_field.field_info.annotation, BaseModel)
),
)

if single_not_embedded_field:
loc: tuple[str, ...] = ("body",)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_forms_single_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,13 @@ def test_no_data():
"type": "missing",
"loc": ["body", "username"],
"msg": "Field required",
"input": {"tags": ["foo", "bar"], "with": "nothing"},
"input": {},
},
{
"type": "missing",
"loc": ["body", "lastname"],
"msg": "Field required",
"input": {"tags": ["foo", "bar"], "with": "nothing"},
"input": {},
},
]
}
Expand Down
96 changes: 96 additions & 0 deletions tests/test_model_param_defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from typing import Annotated

import pytest
from fastapi import Cookie, FastAPI, Form, Header, Query
from fastapi.testclient import TestClient
from pydantic import BaseModel, Field

app = FastAPI()


class DefaultModel(BaseModel):
field_1: bool = True


class InvalidDefaultModel(BaseModel):
field_1: Annotated[str, Field(default=0)]


@app.get("/query")
def read_query(model: Annotated[DefaultModel, Query()]) -> dict[str, object]:
return {"fields_set": sorted(model.model_fields_set), "model": model.model_dump()}


@app.get("/header")
def read_header(model: Annotated[DefaultModel, Header()]) -> dict[str, object]:
return {"fields_set": sorted(model.model_fields_set), "model": model.model_dump()}


@app.get("/cookie")
def read_cookie(model: Annotated[DefaultModel, Cookie()]) -> dict[str, object]:
return {"fields_set": sorted(model.model_fields_set), "model": model.model_dump()}


@app.post("/form")
def read_form(model: Annotated[DefaultModel, Form()]) -> dict[str, object]:
return {"fields_set": sorted(model.model_fields_set), "model": model.model_dump()}


@app.post("/body-invalid-default")
def read_body_invalid_default(model: InvalidDefaultModel) -> dict[str, list[str]]:
return {"fields_set": sorted(model.model_fields_set)}


@app.post("/form-invalid-default")
def read_form_invalid_default(
model: Annotated[InvalidDefaultModel, Form()],
) -> dict[str, list[str]]:
return {"fields_set": sorted(model.model_fields_set)}


client = TestClient(app)


@pytest.mark.parametrize(
("method", "path", "kwargs"),
[
("get", "/query", {}),
("get", "/header", {}),
("get", "/cookie", {}),
("post", "/form", {"data": {}}),
],
)
def test_missing_model_defaults_not_marked_as_set(
method: str, path: str, kwargs: dict[str, object]
) -> None:
response = getattr(client, method)(path, **kwargs)

assert response.status_code == 200, response.text
assert response.json() == {
"fields_set": [],
"model": {"field_1": True},
}


def test_explicit_form_model_value_is_still_marked_as_set() -> None:
response = client.post("/form", data={"field_1": "false"})

assert response.status_code == 200, response.text
assert response.json() == {
"fields_set": ["field_1"],
"model": {"field_1": False},
}


@pytest.mark.parametrize(
"path",
["/body-invalid-default", "/form-invalid-default"],
)
def test_omitted_invalid_defaults_do_not_trigger_validation(path: str) -> None:
if path == "/body-invalid-default":
response = client.post(path, json={})
else:
response = client.post(path, data={})

assert response.status_code == 200, response.text
assert response.json() == {"fields_set": []}
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def test_header_param_model_invalid(client: TestClient):
"loc": ["header", "save_data"],
"msg": "Field required",
"input": {
"x_tag": [],
"host": "testserver",
"accept": "*/*",
"accept-encoding": "gzip, deflate",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_header_param_model_invalid(client: TestClient):
"type": "missing",
"loc": ["header", "save_data"],
"msg": "Field required",
"input": {"x_tag": [], "host": "testserver"},
"input": {"host": "testserver"},
}
]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def test_header_param_model_no_underscore(client: TestClient):
"input": {
"host": "testserver",
"traceparent": "123",
"x_tag": [],
"accept": "*/*",
"accept-encoding": "gzip, deflate",
"connection": "keep-alive",
Expand Down Expand Up @@ -102,7 +101,6 @@ def test_header_param_model_invalid(client: TestClient):
"loc": ["header", "save_data"],
"msg": "Field required",
"input": {
"x_tag": [],
"host": "testserver",
"accept": "*/*",
"accept-encoding": "gzip, deflate",
Expand Down