Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions polyfactory/factories/attrs_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import TYPE_CHECKING, Generic, TypeVar

from polyfactory.exceptions import MissingDependencyException
from polyfactory.factories.base import BaseFactory
from polyfactory.factories.base import BaseFactory, cache_model_fields
from polyfactory.field_meta import FieldMeta, Null

if TYPE_CHECKING:
Expand Down Expand Up @@ -35,13 +35,19 @@ def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
return isclass(value) and hasattr(value, "__attrs_attrs__")

@classmethod
def _init_model(cls) -> None:
"""Initialize the model and resolve type annotations."""
super()._init_model()
if hasattr(cls, "__model__"):
cls.resolve_types(cls.__model__)

@classmethod
@cache_model_fields
def get_model_fields(cls) -> list[FieldMeta]:
field_metas: list[FieldMeta] = []
none_type = type(None)

cls.resolve_types(cls.__model__)
fields = attrs.fields(cls.__model__)

for field in fields:
if not field.init:
continue
Expand Down
22 changes: 20 additions & 2 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import copy
import functools
import inspect
from abc import ABC, abstractmethod
from collections import Counter, abc, deque
Expand Down Expand Up @@ -97,6 +98,22 @@
F = TypeVar("F", bound="BaseFactory[Any]")


def cache_model_fields(func: Callable[[type[F]], list["FieldMeta"]]) -> Callable[[type[F]], list["FieldMeta"]]:
"""Decorator to cache the results of get_model_fields() to avoid repeated introspection.

:param func: The get_model_fields classmethod to wrap
:returns: Wrapped function with caching
"""

@functools.wraps(func)
def wrapper(cls: type[F]) -> list["FieldMeta"]:
if "_fields_metadata" not in cls.__dict__:
cls._fields_metadata = func(cls)
return cls._fields_metadata

return wrapper


class BuildContext(TypedDict):
seen_models: set[type]

Expand Down Expand Up @@ -124,12 +141,13 @@ class BaseFactory(ABC, Generic[T]):
"""A sync persistence handler. Can be a class or a class instance."""
__async_persistence__: type[AsyncPersistenceProtocol[T]] | AsyncPersistenceProtocol[T] | None = None
"""An async persistence handler. Can be a class or a class instance."""
__set_as_default_factory_for_type__ = False

__set_as_default_factory_for_type__: ClassVar[bool] = False
"""
Flag dictating whether to set as the default factory for the given type.
If 'True' the factory will be used instead of dynamically generating a factory for the type.
"""
__is_base_factory__: bool = False
__is_base_factory__: ClassVar[bool] = False
"""
Flag dictating whether the factory is a 'base' factory. Base factories are registered globally as handlers for types.
For example, the 'DataclassFactory', 'TypedDictFactory' and 'ModelFactory' are all base factories.
Expand Down
3 changes: 2 additions & 1 deletion polyfactory/factories/dataclass_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from typing_extensions import TypeGuard

from polyfactory.factories.base import BaseFactory, T
from polyfactory.factories.base import BaseFactory, T, cache_model_fields
from polyfactory.field_meta import FieldMeta, Null


Expand All @@ -24,6 +24,7 @@ def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
return bool(is_dataclass(value))

@classmethod
@cache_model_fields
def get_model_fields(cls) -> list["FieldMeta"]:
"""Retrieve a list of fields from the factory's model.

Expand Down
3 changes: 2 additions & 1 deletion polyfactory/factories/msgspec_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, get_type_hints

from polyfactory.exceptions import MissingDependencyException
from polyfactory.factories.base import BaseFactory
from polyfactory.factories.base import BaseFactory, cache_model_fields
from polyfactory.field_meta import FieldMeta, Null
from polyfactory.value_generators.constrained_numbers import handle_constrained_int
from polyfactory.value_generators.primitives import create_random_bytes
Expand Down Expand Up @@ -46,6 +46,7 @@ def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
return isclass(value) and hasattr(value, "__struct_fields__")

@classmethod
@cache_model_fields
def get_model_fields(cls) -> list[FieldMeta]:
fields_meta: list[FieldMeta] = []

Expand Down
42 changes: 20 additions & 22 deletions polyfactory/factories/pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing_extensions import Literal, get_args

from polyfactory.exceptions import MissingDependencyException
from polyfactory.factories.base import BaseFactory, BuildContext
from polyfactory.factories.base import BaseFactory, BuildContext, cache_model_fields
from polyfactory.factories.base import BuildContext as BaseBuildContext
from polyfactory.field_meta import Constraints, FieldMeta, Null
from polyfactory.utils.helpers import unwrap_new_type, unwrap_optional
Expand Down Expand Up @@ -411,35 +411,33 @@ def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
return _is_pydantic_v1_model(value) or _is_pydantic_v2_model(value)

@classmethod
@cache_model_fields
def get_model_fields(cls) -> list["FieldMeta"]:
"""Retrieve a list of fields from the factory's model.


:returns: A list of field MetaData instances.

"""
if "_fields_metadata" not in cls.__dict__:
if _is_pydantic_v1_model(cls.__model__):
cls._fields_metadata = [
PydanticFieldMeta.from_model_field(
field,
use_alias=not cls.__model__.__config__.allow_population_by_field_name, # type: ignore[attr-defined]
)
for field in cls.__model__.__fields__.values()
]
else:
use_alias = cls.__model__.model_config.get("validate_by_name", False) or cls.__model__.model_config.get(
"populate_by_name", False
if _is_pydantic_v1_model(cls.__model__):
return [
PydanticFieldMeta.from_model_field(
field,
use_alias=not cls.__model__.__config__.allow_population_by_field_name, # type: ignore[attr-defined]
)
cls._fields_metadata = [
PydanticFieldMeta.from_field_info(
field_info=field_info,
field_name=field_name,
use_alias=not use_alias,
)
for field_name, field_info in cls.__model__.model_fields.items() # pyright: ignore[reportGeneralTypeIssues]
]
return cls._fields_metadata
for field in cls.__model__.__fields__.values()
]
use_alias = cls.__model__.model_config.get("validate_by_name", False) or cls.__model__.model_config.get(
"populate_by_name", False
)
return [
PydanticFieldMeta.from_field_info(
field_info=field_info,
field_name=field_name,
use_alias=not use_alias,
)
for field_name, field_info in cls.__model__.model_fields.items() # pyright: ignore[reportGeneralTypeIssues]
]

@classmethod
def get_constrained_field_value(
Expand Down
3 changes: 3 additions & 0 deletions polyfactory/factories/sqlalchemy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

from polyfactory.exceptions import ConfigurationException, MissingDependencyException, ParameterException
from polyfactory.factories.base import BaseFactory
from polyfactory.exceptions import MissingDependencyException, ParameterException
from polyfactory.factories.base import BaseFactory, cache_model_fields
from polyfactory.field_meta import Constraints, FieldMeta
from polyfactory.persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol
from polyfactory.utils.types import Frozendict
Expand Down Expand Up @@ -241,6 +243,7 @@ def get_type_from_collection_class(
return annotation

@classmethod
@cache_model_fields
def get_model_fields(cls) -> list[FieldMeta]:
fields_meta: list[FieldMeta] = []

Expand Down
3 changes: 2 additions & 1 deletion polyfactory/factories/typed_dict_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
is_typeddict,
)

from polyfactory.factories.base import BaseFactory
from polyfactory.factories.base import BaseFactory, cache_model_fields
from polyfactory.field_meta import FieldMeta, Null

TypedDictT = TypeVar("TypedDictT", bound=_TypedDictMeta)
Expand All @@ -32,6 +32,7 @@ def is_supported_type(cls, value: Any) -> TypeGuard[type[TypedDictT]]:
return is_typeddict(value)

@classmethod
@cache_model_fields
def get_model_fields(cls) -> list["FieldMeta"]:
"""Retrieve a list of fields from the factory's model.

Expand Down
4 changes: 1 addition & 3 deletions polyfactory/field_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import datetime
from collections.abc import Sequence
from decimal import Decimal
from random import Random
from re import Pattern

from typing_extensions import NotRequired, Self
Expand Down Expand Up @@ -68,10 +67,9 @@ class Constraints(TypedDict):
class FieldMeta:
"""Factory field metadata container. This class is used to store the data about a field of a factory's model."""

__slots__ = ("__dict__", "annotation", "children", "constraints", "default", "name", "random")
__slots__ = ("__dict__", "annotation", "children", "constraints", "default", "name")

annotation: Any
random: Random
children: list[FieldMeta] | None
default: Any
name: str
Expand Down
Loading