Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions docs/tutorial/field-types.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,46 @@
`DbmModel` (Embed Model)
> Field value can be an instance of another DbmModel subclass. See [Embed Models](embed-models.md) for details.

## Optional Types

Any supported type can be made optional using `Optional[T]` or `T | None` (Python 3.10+).
Optional fields default to `None` if no value is provided and accept both the inner type value and `None`.

```python
import typing

from pydbm import DbmModel


class User(DbmModel):
name: str
nickname: typing.Optional[str] # defaults to None, accepts str or None
age: typing.Optional[int] # defaults to None, accepts int or None
```

```python
user = User(name="hakan")
assert user.nickname is None
assert user.age is None

user = User(name="hakan", nickname="hako", age=30)
assert user.nickname == "hako"
assert user.age == 30
```

You can also provide a custom default value for optional fields:

```python
class User(DbmModel):
name: str
nickname: typing.Optional[str] = "anonymous"

user = User(name="hakan")
assert user.nickname == "anonymous"
```

## Example

```python
import datetime

Expand Down
32 changes: 28 additions & 4 deletions src/pydbm/database/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from pydbm import contstant as C
from pydbm.database.data_types import BaseDataType
from pydbm.inspect_extra import get_obj_annotations
from pydbm.inspect_extra import get_obj_annotations, is_optional_type, unwrap_optional
from pydbm.models.fields import AutoField

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -52,6 +52,7 @@ class DatabaseManager:
"db_path",
"db",
DATABASE_HEADER_NAME,
"__optional_fields__",
"_keys",
"__is_db_open",
)
Expand Down Expand Up @@ -113,8 +114,18 @@ def set_database_header(self):
ann = get_obj_annotations(obj=self.model)

resolved_ann = {}
optional_fields: set[str] = set()
for key, value in ann.items():
if value in DATABASE_HEADER_MAPPING:
if is_optional_type(value):
inner_type = unwrap_optional(value)
optional_fields.add(key)
if inner_type in DATABASE_HEADER_MAPPING:
resolved_ann[key] = inner_type
elif isinstance(inner_type, type) and hasattr(inner_type, "objects"):
resolved_ann[key] = dict
else:
resolved_ann[key] = inner_type
elif value in DATABASE_HEADER_MAPPING:
resolved_ann[key] = value
elif isinstance(value, type) and hasattr(value, "objects"):
resolved_ann[key] = dict
Expand All @@ -133,6 +144,7 @@ def set_database_header(self):
assert database_header == db_headers, f"Database headers are not equal: '{database_header}' != '{db_headers}'" # type: ignore[str-bytes-safe] # noqa: E501

setattr(self, DATABASE_HEADER_NAME, resolved_ann)
self.__optional_fields__ = optional_fields

def open(self):
if not self.__is_db_open:
Expand All @@ -148,6 +160,9 @@ def close(self) -> None:
def save(self, *, id: str, fields: dict[str, typing.Any]) -> None:
data: dict[str, typing.Any] = {}
for key, value in fields.items():
if value is None and key in self.__optional_fields__:
data[key] = None
continue
header_type = self.__database_headers__[key]
if header_type is dict and hasattr(value, "as_dict"):
embed_headers = value.objects.__database_headers__
Expand Down Expand Up @@ -194,7 +209,10 @@ def get(self, *, id: str | None = None, **unique_together) -> DbmModel:
to_python = ast.literal_eval(data_from_dbm.decode("utf-8")) # TODO: implement own parser
fields: dict[str, typing.Any] = {}
for key, value in to_python.items():
fields[key] = BaseDataType.get_data_type(self.__database_headers__[key]).get(value)
if value is None and key in self.__optional_fields__:
fields[key] = None
else:
fields[key] = BaseDataType.get_data_type(self.__database_headers__[key]).get(value)

return self.model(**fields)

Expand All @@ -212,6 +230,9 @@ def update(self, *, id: str, **updated_fields) -> None:

data = ast.literal_eval(data_from_dbm.decode("utf-8"))
for key, value in updated_fields.items():
if value is None and key in self.__optional_fields__:
data[key] = None
continue
header_type = self.__database_headers__[key]
if header_type is dict and hasattr(value, "as_dict"):
embed_headers = value.objects.__database_headers__
Expand Down Expand Up @@ -241,7 +262,10 @@ def all(self) -> typing.Iterable[DbmModel]:
to_python = ast.literal_eval(data_from_dbm.decode("utf-8"))
fields: dict[str, typing.Any] = {}
for key, value in to_python.items():
fields[key] = BaseDataType.get_data_type(self.__database_headers__[key]).get(value)
if value is None and key in self.__optional_fields__:
fields[key] = None
else:
fields[key] = BaseDataType.get_data_type(self.__database_headers__[key]).get(value)
yield self.model(**fields)

def filter(self, **kwargs) -> typing.Iterator[DbmModel]:
Expand Down
26 changes: 26 additions & 0 deletions src/pydbm/inspect_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,35 @@

__all__ = (
"get_obj_annotations",
"is_optional_type",
"unwrap_optional",
)


def is_optional_type(tp: typing.Any) -> bool:
"""Check if a type is Optional (Union with NoneType), e.g. Optional[str] or str | None."""
origin = typing.get_origin(tp)
if origin is typing.Union:
args = typing.get_args(tp)
return type(None) in args and len(args) == 2
if sys.version_info >= (3, 10):
import types as _types

if isinstance(tp, _types.UnionType):
args = typing.get_args(tp)
return type(None) in args and len(args) == 2
return False


def unwrap_optional(tp: typing.Any) -> typing.Any:
"""Extract the inner type from Optional[X] / X | None."""
args = typing.get_args(tp)
for arg in args:
if arg is not type(None):
return arg
return type(None)


def get_obj_annotations(*, obj: typing.Type[typing.Any]) -> dict[str, typing.Any]:
assert inspect.isclass(obj), f"{obj!r} must be a class"

Expand Down
21 changes: 19 additions & 2 deletions src/pydbm/models/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class BaseField:
"kwargs",
"_is_call_run",
"_is_embed_model",
"_is_optional",
)

def __init__(
Expand All @@ -65,6 +66,7 @@ def __init__(

self._is_call_run = False
self._is_embed_model = False
self._is_optional = False

def __set_name__(self, instance: Meta, name: str) -> None:
self.public_name = name
Expand All @@ -78,6 +80,12 @@ def __get__(self, instance: Meta, owner: DbmModel) -> typing.Any:

def __set__(self, instance: DbmModel, value: typing.Any) -> None:
if self._is_embed_model:
if value is None and self._is_optional:
setattr(instance, self.private_name, None)
if self.field_name != C.PRIMARY_KEY:
instance.fields[self.field_name] = None
return

from pydbm.database.data_types import BaseDataType

if isinstance(value, dict):
Expand All @@ -102,8 +110,9 @@ def __set__(self, instance: DbmModel, value: typing.Any) -> None:
if self.field_name != C.PRIMARY_KEY:
instance.fields[self.field_name] = eligible_value

def __call__(self: Self, field_name: str, field_type: SupportedClassT, *args, **kwargs) -> Self: # type: ignore[valid-type] # noqa: E501
def __call__(self: Self, field_name: str, field_type: SupportedClassT, *args, is_optional: bool = False, **kwargs) -> Self: # type: ignore[valid-type] # noqa: E501
self._is_call_run = True
self._is_optional = is_optional

self.field_name = field_name
self.field_type = field_type
Expand All @@ -122,7 +131,15 @@ def __call__(self: Self, field_name: str, field_type: SupportedClassT, *args, **
)
self.validators.append(validator_mapping[field_type])
else:
self.validators.append(validator_mapping[field_type])
inner_validator = validator_mapping[field_type]
if is_optional:
def optional_validator(value: typing.Any, v: ValidatorT = inner_validator) -> None:
if value is not None:
v(value)

self.validators.append(optional_validator)
else:
self.validators.append(inner_validator)

if field_type is int:
if self.min_value:
Expand Down
14 changes: 11 additions & 3 deletions src/pydbm/models/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydbm import typing_extra
from pydbm.database import DatabaseManager
from pydbm.exceptions import EmptyModelError, PydbmBaseException, ReadOnlyFieldError, UnnecessaryParamsError
from pydbm.inspect_extra import get_obj_annotations
from pydbm.inspect_extra import get_obj_annotations, is_optional_type, unwrap_optional
from pydbm.models.fields import AutoField, Field, Undefined

__all__ = (
Expand Down Expand Up @@ -130,9 +130,17 @@ def generate_fields(mcs, cls, cls_name: str, namespace: dict[str, typing.Any]) -
if field_name == C.PRIMARY_KEY:
continue

optional = is_optional_type(field_type)
actual_type = unwrap_optional(field_type) if optional else field_type

default_value: Field | typing.Any = namespace.get(field_name, Undefined)
field = default_value if isinstance(default_value, Field) else Field(default=default_value)
fields.update({field_name: field(field_name, field_type)})
if isinstance(default_value, Field):
field = default_value
elif optional and default_value is Undefined:
field = Field(default=None)
else:
field = Field(default=default_value)
fields.update({field_name: field(field_name, actual_type, is_optional=optional)})
return fields

@staticmethod
Expand Down
Loading