Skip to content

Commit a05f7e6

Browse files
AlexPetuladhtruong
andauthored
fix: Respect init parameter for SQLAlchemy dataclasses (#793)
Co-authored-by: Andrew Truong <40660973+adhtruong@users.noreply.github.com>
1 parent ad86dc1 commit a05f7e6

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

polyfactory/factories/sqlalchemy_factory.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Collection, Mapping
4+
from dataclasses import is_dataclass
45
from datetime import date, datetime
56
from typing import (
67
TYPE_CHECKING,
@@ -174,8 +175,22 @@ def should_column_be_set(cls, column: Any) -> bool:
174175
if not cls.__set_primary_key__ and column.primary_key:
175176
return False
176177

178+
if not cls.should_dataclass_init_field(column.name):
179+
return False
180+
177181
return bool(cls.__set_foreign_keys__ or not column.foreign_keys)
178182

183+
@classmethod
184+
def should_dataclass_init_field(cls, field_name: str) -> bool:
185+
if not is_dataclass(cls.__model__):
186+
return True
187+
188+
dataclass_fields = cls.__model__.__dataclass_fields__
189+
try:
190+
return dataclass_fields[field_name].init
191+
except KeyError:
192+
return True
193+
179194
@classmethod
180195
def _get_type_from_type_engine(cls, type_engine: TypeEngine) -> type:
181196
if type(type_engine) in cls.get_sqlalchemy_types():
@@ -285,6 +300,9 @@ def get_model_fields(cls) -> list[FieldMeta]:
285300
)
286301
if cls.__set_relationships__:
287302
for name, relationship in table.relationships.items():
303+
if not cls.should_dataclass_init_field(name):
304+
continue
305+
288306
annotation = cls._get_relationship_type(relationship)
289307
fields_meta.append(
290308
FieldMeta.from_type(
@@ -295,6 +313,9 @@ def get_model_fields(cls) -> list[FieldMeta]:
295313
if cls.__set_association_proxy__:
296314
for name, attr in table.all_orm_descriptors.items():
297315
if isinstance(attr, AssociationProxy):
316+
if not cls.should_dataclass_init_field(name):
317+
continue
318+
298319
# Read-only proxies derive from the underlying relationship and shouldn't be set directly.
299320
if not getattr(attr, "creator", None):
300321
continue

tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
from uuid import UUID
55

66
import pytest
7-
from sqlalchemy import Text, __version__, orm, types
7+
from sqlalchemy import ForeignKey, Text, __version__, orm, types
88
from sqlalchemy.dialects.mssql import JSON as MSSQL_JSON
99
from sqlalchemy.dialects.mysql import JSON as MYSQL_JSON
1010
from sqlalchemy.dialects.postgresql import ARRAY, CIDR, HSTORE, INET, JSON, JSONB
1111
from sqlalchemy.dialects.sqlite import JSON as SQLITE_JSON
12+
from sqlalchemy.ext.associationproxy import AssociationProxy, association_proxy
1213
from sqlalchemy.ext.mutable import MutableDict, MutableList
1314

1415
from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory
@@ -132,3 +133,44 @@ class ModelFactory(SQLAlchemyFactory[Model]):
132133

133134
instance = ModelFactory.build()
134135
assert instance.overridden is not None
136+
137+
138+
def test_dataclass_mapped_do_not_init_field() -> None:
139+
class Base(orm.DeclarativeBase): ...
140+
141+
class Parent(orm.MappedAsDataclass, Base):
142+
__tablename__ = "tesT_model"
143+
144+
id: orm.Mapped[int] = orm.mapped_column(primary_key=True)
145+
name: orm.Mapped[str] = orm.mapped_column(init=False)
146+
children_no_init: orm.Mapped[list["Child"]] = orm.relationship(
147+
"Child",
148+
uselist=True,
149+
init=False,
150+
)
151+
children_init: orm.Mapped[list["Child"]] = orm.relationship(
152+
"Child",
153+
uselist=True,
154+
overlaps="children_no_init",
155+
)
156+
157+
child_ids: AssociationProxy[list[int]] = association_proxy(
158+
"children_init",
159+
"id",
160+
init=False,
161+
)
162+
163+
class Child(orm.MappedAsDataclass, Base):
164+
__tablename__ = "child_with_overridden_type"
165+
166+
id: orm.Mapped[int] = orm.mapped_column(primary_key=True)
167+
model_id: orm.Mapped[int] = orm.mapped_column(ForeignKey(Parent.id))
168+
169+
class ModelFactory(SQLAlchemyFactory[Parent]):
170+
__model__ = Parent
171+
172+
instance = ModelFactory.build()
173+
assert instance.name is None
174+
assert instance.children_no_init == [] # type: ignore[unreachable]
175+
assert len(instance.children_init) > 0
176+
assert instance.child_ids[0] == instance.children_init[0].id

0 commit comments

Comments
 (0)