diff --git a/Makefile b/Makefile index 7ccd039..95a7cf8 100644 --- a/Makefile +++ b/Makefile @@ -6,9 +6,9 @@ sources = sqlalchemy_easy_softdelete lint: uv run pre-commit run --all-files -# Run type checking (mypy) +# Run type checking (mypy) on source code and tests typecheck: - uv run mypy $(sources) + uv run mypy $(sources) tests/ # Quick test with SQLite (no docker needed) test: diff --git a/README.md b/README.md index 8a451a3..036cdfd 100644 --- a/README.md +++ b/README.md @@ -27,18 +27,32 @@ pip install sqlalchemy-easy-softdelete ```py from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class from sqlalchemy_easy_softdelete.hook import IgnoredTable -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import declarative_base, Mapped from sqlalchemy import Column, Integer from datetime import datetime # Create a Class that inherits from our class builder -class SoftDeleteMixin(generate_soft_delete_mixin_class( - # This table will be ignored by the hook - # even if the table has the soft-delete column - ignored_tables=[IgnoredTable(table_schema="public", name="cars"),] -)): - # type hint for autocomplete IDE support - deleted_at: datetime +class SoftDeleteMixin( + generate_soft_delete_mixin_class( # type: ignore[misc] + # This table will be ignored by the hook + # even if the table has the soft-delete column + ignored_tables=[IgnoredTable(table_schema="public", name="cars"),] + ) +): + # type: ignore[misc] is required because the mixin is dynamically generated + + # Type hint for IDE autocomplete and type checker support. + # Using Mapped[T | None] ensures type checkers understand this is a + # SQLAlchemy column that supports query operations like .where() + deleted_at: Mapped[datetime | None] + + # Optional: Add method stubs for delete/undelete for type checker support. + # The actual implementations are provided by the generated mixin class. + def delete(self, v: datetime | None = None) -> None: + super().delete(v) # type: ignore[misc] + + def undelete(self) -> None: + super().undelete() # type: ignore[misc] # Apply the mixin to your Models Base = declarative_base() diff --git a/mypy.ini b/mypy.ini index 053bccf..cc989f1 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,12 +1,12 @@ [mypy] -packages = sqlalchemy_easy_softdelete +packages = sqlalchemy_easy_softdelete,tests python_version = 3.10 -# Lenient settings - this codebase wasn't originally typed -disallow_untyped_calls = false +# Strictness settings +disallow_untyped_calls = true disallow_untyped_defs = false -disallow_untyped_decorators = false +disallow_untyped_decorators = true check_untyped_defs = false -ignore_missing_imports = true +ignore_missing_imports = false allow_redefinition = true warn_unused_configs = true diff --git a/sqlalchemy_easy_softdelete/handler/sqlalchemy_easy_softdelete.py b/sqlalchemy_easy_softdelete/handler/sqlalchemy_easy_softdelete.py index a970b2f..ea5d22a 100644 --- a/sqlalchemy_easy_softdelete/handler/sqlalchemy_easy_softdelete.py +++ b/sqlalchemy_easy_softdelete/handler/sqlalchemy_easy_softdelete.py @@ -9,16 +9,15 @@ from sqlalchemy_easy_softdelete.handler.rewriter import SoftDeleteQueryRewriter from sqlalchemy_easy_softdelete.hook import IgnoredTable -global_rewriter: SoftDeleteQueryRewriter | None = None - def activate_soft_delete_hook( deleted_field_name: str, disable_soft_delete_option_name: str, ignored_tables: list[IgnoredTable] -): - """Activate an event hook to rewrite the queries.""" +) -> SoftDeleteQueryRewriter: + """Activate an event hook to rewrite the queries. - global global_rewriter - global_rewriter = SoftDeleteQueryRewriter( + Returns the SoftDeleteQueryRewriter instance for use by the mixin class. + """ + rewriter = SoftDeleteQueryRewriter( deleted_field_name=deleted_field_name, disable_soft_delete_option_name=disable_soft_delete_option_name, ignored_tables=ignored_tables, @@ -30,10 +29,12 @@ def soft_delete_execute(state: ORMExecuteState): if not state.is_select: return - # Rewrite the statement - adapted = global_rewriter.rewrite_statement(state.statement) + # Rewrite the statement (closure captures local `rewriter`) + adapted = rewriter.rewrite_statement(state.statement) # Replace the statement # Cast needed because Statement type includes LambdaElement which mypy # doesn't recognize as Executable (even though it is at runtime) state.statement = cast(Executable, adapted) + + return rewriter diff --git a/sqlalchemy_easy_softdelete/mixin.py b/sqlalchemy_easy_softdelete/mixin.py index 8546fdb..d34b606 100644 --- a/sqlalchemy_easy_softdelete/mixin.py +++ b/sqlalchemy_easy_softdelete/mixin.py @@ -45,7 +45,11 @@ def undelete_method(_self): class_attributes[undelete_method_name] = undelete_method - activate_soft_delete_hook(deleted_field_name, disable_soft_delete_filtering_option_name, ignored_tables) + # Activate the soft delete hook and get the rewriter instance + rewriter = activate_soft_delete_hook(deleted_field_name, disable_soft_delete_filtering_option_name, ignored_tables) + + # Store rewriter on the generated class for testing purposes + class_attributes["_sqlalchemy_easy_softdelete_rewriter"] = rewriter generated_class = type(class_name, tuple(), class_attributes) diff --git a/tests/conftest.py b/tests/conftest.py index 0a196cc..981afb3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,33 +1,28 @@ import os +from collections.abc import Generator import pytest from sqlalchemy import create_engine from sqlalchemy.engine import Connection, Engine -from sqlalchemy.orm import Session, sessionmaker - -from sqlalchemy_easy_softdelete.handler.rewriter import SoftDeleteQueryRewriter -from tests.model import TestModelBase -from tests.seed_data import generate_table_with_inheritance_obj -from tests.seed_data.parent_child_childchild import generate_parent_child_object_hierarchy env_connection_string = os.environ.get("TEST_CONNECTION_STRING", None) @pytest.fixture -def sqla2_warnings() -> Engine: +def sqla2_warnings() -> None: # Enable SQLAlchemy 2.0 Warnings mode to help with 2.0 support os.environ["SQLALCHEMY_WARN_20"] = "1" @pytest.fixture -def db_engine(sqla2_warnings) -> Engine: +def db_engine(sqla2_warnings: None) -> Engine: test_db_url = env_connection_string or "sqlite://" print(f"connection_string={test_db_url}") return create_engine(test_db_url, future=True) @pytest.fixture -def db_connection(db_engine) -> Connection: +def db_connection(db_engine: Engine) -> Generator[Connection, None, None]: connection = db_engine.connect() # start a transaction @@ -38,28 +33,3 @@ def db_connection(db_engine) -> Connection: finally: transaction.rollback() connection.close() - - -@pytest.fixture -def db_session(db_connection) -> Session: - TestModelBase.metadata.create_all(db_connection) - return sessionmaker(autocommit=False, autoflush=False, bind=db_connection)() - - -@pytest.fixture -def seeded_session(db_session) -> Session: - generate_parent_child_object_hierarchy(db_session, 1000) - generate_parent_child_object_hierarchy(db_session, 1001) - generate_parent_child_object_hierarchy(db_session, 1002, parent_deleted=True) - - generate_table_with_inheritance_obj(db_session, 1000, deleted=False) - generate_table_with_inheritance_obj(db_session, 1001, deleted=False) - generate_table_with_inheritance_obj(db_session, 1002, deleted=True) - return db_session - - -@pytest.fixture -def rewriter() -> SoftDeleteQueryRewriter: - from sqlalchemy_easy_softdelete.handler.sqlalchemy_easy_softdelete import global_rewriter - - return global_rewriter diff --git a/tests/custom_default_value/__init__.py b/tests/custom_default_value/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/custom_default_value/conftest.py b/tests/custom_default_value/conftest.py new file mode 100644 index 0000000..1f8121a --- /dev/null +++ b/tests/custom_default_value/conftest.py @@ -0,0 +1,19 @@ +from typing import cast + +import pytest +from sqlalchemy.engine import Connection +from sqlalchemy.orm import Session, sessionmaker + +from sqlalchemy_easy_softdelete.handler.rewriter import SoftDeleteQueryRewriter +from tests.custom_default_value.model import CDVModelBase, CDVSoftDeleteMixin + + +@pytest.fixture +def db_session(db_connection: Connection) -> Session: + CDVModelBase.metadata.create_all(db_connection) # type: ignore[attr-defined] + return sessionmaker(autocommit=False, autoflush=False, bind=db_connection)() + + +@pytest.fixture +def rewriter() -> SoftDeleteQueryRewriter: + return cast(SoftDeleteQueryRewriter, CDVSoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter) diff --git a/tests/custom_default_value/model.py b/tests/custom_default_value/model.py new file mode 100644 index 0000000..37811cb --- /dev/null +++ b/tests/custom_default_value/model.py @@ -0,0 +1,26 @@ +from datetime import datetime, timezone + +from sqlalchemy import Column, Integer +from sqlalchemy.orm import Mapped, as_declarative + +from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class + + +@as_declarative() +class CDVModelBase: + """CDV = Custom Default Value""" + + id = Column(Integer, primary_key=True, autoincrement=True) + + +class CDVSoftDeleteMixin( + generate_soft_delete_mixin_class( # type: ignore[misc] + delete_method_default_value=lambda: datetime(2000, 1, 1, tzinfo=timezone.utc), + ) +): + deleted_at: Mapped[datetime | None] + + +class CDVTable(CDVModelBase, CDVSoftDeleteMixin): + __tablename__ = "cdvtable" + value = Column(Integer) diff --git a/tests/custom_default_value/test_custom_default_value.py b/tests/custom_default_value/test_custom_default_value.py new file mode 100644 index 0000000..d1b1562 --- /dev/null +++ b/tests/custom_default_value/test_custom_default_value.py @@ -0,0 +1,31 @@ +"""Tests for custom default value option.""" + +from datetime import datetime, timezone + +from tests.custom_default_value.model import CDVTable + + +def test_delete_uses_custom_default_value(db_session): + """Verify delete() uses the custom default value function.""" + obj = CDVTable(value=1) + db_session.add(obj) + db_session.commit() + + obj.delete() + + # Should use our custom date (2000-01-01) + # SQLite doesn't preserve timezone, so compare without it + assert obj.deleted_at.replace(tzinfo=None) == datetime(2000, 1, 1) + + +def test_delete_with_explicit_value_overrides_default(db_session): + """Verify delete(value) uses the passed value instead of default.""" + obj = CDVTable(value=1) + db_session.add(obj) + db_session.commit() + + custom_date = datetime(2020, 6, 15, 12, 30, tzinfo=timezone.utc) + obj.delete(custom_date) + + # SQLite doesn't preserve timezone, so compare without it + assert obj.deleted_at.replace(tzinfo=None) == custom_date.replace(tzinfo=None) diff --git a/tests/custom_deleted_field_name/__init__.py b/tests/custom_deleted_field_name/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/custom_deleted_field_name/conftest.py b/tests/custom_deleted_field_name/conftest.py new file mode 100644 index 0000000..640d014 --- /dev/null +++ b/tests/custom_deleted_field_name/conftest.py @@ -0,0 +1,19 @@ +from typing import cast + +import pytest +from sqlalchemy.engine import Connection +from sqlalchemy.orm import Session, sessionmaker + +from sqlalchemy_easy_softdelete.handler.rewriter import SoftDeleteQueryRewriter +from tests.custom_deleted_field_name.model import CFNModelBase, CFNSoftDeleteMixin + + +@pytest.fixture +def db_session(db_connection: Connection) -> Session: + CFNModelBase.metadata.create_all(db_connection) # type: ignore[attr-defined] + return sessionmaker(autocommit=False, autoflush=False, bind=db_connection)() + + +@pytest.fixture +def rewriter() -> SoftDeleteQueryRewriter: + return cast(SoftDeleteQueryRewriter, CFNSoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter) diff --git a/tests/custom_deleted_field_name/model.py b/tests/custom_deleted_field_name/model.py new file mode 100644 index 0000000..e88d675 --- /dev/null +++ b/tests/custom_deleted_field_name/model.py @@ -0,0 +1,26 @@ +from datetime import datetime + +from sqlalchemy import Column, Integer +from sqlalchemy.orm import Mapped, as_declarative + +from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class + + +@as_declarative() +class CFNModelBase: + """CFN = Custom Field Name""" + + id = Column(Integer, primary_key=True, autoincrement=True) + + +class CFNSoftDeleteMixin( + generate_soft_delete_mixin_class( # type: ignore[misc] + deleted_field_name="removed_at", + ) +): + removed_at: Mapped[datetime | None] + + +class CFNTable(CFNModelBase, CFNSoftDeleteMixin): + __tablename__ = "cfntable" + value = Column(Integer) diff --git a/tests/custom_deleted_field_name/test_custom_field_name.py b/tests/custom_deleted_field_name/test_custom_field_name.py new file mode 100644 index 0000000..af53043 --- /dev/null +++ b/tests/custom_deleted_field_name/test_custom_field_name.py @@ -0,0 +1,57 @@ +"""Tests for custom deleted_field_name option.""" + +from datetime import datetime, timezone + +from tests.custom_deleted_field_name.model import CFNTable + + +def test_custom_field_name_column_exists(): + """Verify the column uses the custom field name.""" + assert "removed_at" in CFNTable.__table__.columns + assert "deleted_at" not in CFNTable.__table__.columns + + +def test_rewriter_has_correct_field_name(rewriter): + """Verify the rewriter is configured with the custom field name.""" + assert rewriter.deleted_field_name == "removed_at" + + +def test_delete_sets_custom_field(db_session): + """Verify delete() sets the custom field.""" + obj = CFNTable(value=1) + db_session.add(obj) + db_session.commit() + + assert obj.removed_at is None + obj.delete() + assert obj.removed_at is not None + + +def test_undelete_clears_custom_field(db_session): + """Verify undelete() clears the custom field.""" + obj = CFNTable(value=1) + db_session.add(obj) + db_session.commit() + + obj.delete() + assert obj.removed_at is not None + + obj.undelete() + assert obj.removed_at is None + + +def test_soft_delete_filtering_uses_custom_field(db_session): + """Verify soft-delete filtering works with custom field name.""" + active = CFNTable(value=1) + deleted = CFNTable(value=2) + deleted.removed_at = datetime.now(timezone.utc) + + db_session.add_all([active, deleted]) + db_session.commit() + + results = db_session.query(CFNTable).all() + assert len(results) == 1 + assert results[0].value == 1 + + all_results = db_session.query(CFNTable).execution_options(include_deleted=True).all() + assert len(all_results) == 2 diff --git a/tests/custom_method_names/__init__.py b/tests/custom_method_names/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/custom_method_names/conftest.py b/tests/custom_method_names/conftest.py new file mode 100644 index 0000000..785e8cb --- /dev/null +++ b/tests/custom_method_names/conftest.py @@ -0,0 +1,19 @@ +from typing import cast + +import pytest +from sqlalchemy.engine import Connection +from sqlalchemy.orm import Session, sessionmaker + +from sqlalchemy_easy_softdelete.handler.rewriter import SoftDeleteQueryRewriter +from tests.custom_method_names.model import CMNModelBase, CMNSoftDeleteMixin + + +@pytest.fixture +def db_session(db_connection: Connection) -> Session: + CMNModelBase.metadata.create_all(db_connection) # type: ignore[attr-defined] + return sessionmaker(autocommit=False, autoflush=False, bind=db_connection)() + + +@pytest.fixture +def rewriter() -> SoftDeleteQueryRewriter: + return cast(SoftDeleteQueryRewriter, CMNSoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter) diff --git a/tests/custom_method_names/model.py b/tests/custom_method_names/model.py new file mode 100644 index 0000000..2a57d36 --- /dev/null +++ b/tests/custom_method_names/model.py @@ -0,0 +1,33 @@ +from datetime import datetime + +from sqlalchemy import Column, Integer +from sqlalchemy.orm import Mapped, as_declarative + +from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class + + +@as_declarative() +class CMNModelBase: + """CMN = Custom Method Names""" + + id = Column(Integer, primary_key=True, autoincrement=True) + + +class CMNSoftDeleteMixin( + generate_soft_delete_mixin_class( # type: ignore[misc] + delete_method_name="soft_delete", + undelete_method_name="restore", + ) +): + deleted_at: Mapped[datetime | None] + + def soft_delete(self) -> None: + super().soft_delete() # type: ignore[misc] + + def restore(self) -> None: + super().restore() # type: ignore[misc] + + +class CMNTable(CMNModelBase, CMNSoftDeleteMixin): + __tablename__ = "cmntable" + value = Column(Integer) diff --git a/tests/custom_method_names/test_custom_method_names.py b/tests/custom_method_names/test_custom_method_names.py new file mode 100644 index 0000000..e1872eb --- /dev/null +++ b/tests/custom_method_names/test_custom_method_names.py @@ -0,0 +1,37 @@ +"""Tests for custom method names option.""" + +from tests.custom_method_names.model import CMNSoftDeleteMixin, CMNTable + + +def test_custom_method_names_exist(): + """Verify custom method names are used.""" + assert hasattr(CMNSoftDeleteMixin, "soft_delete") + assert hasattr(CMNSoftDeleteMixin, "restore") + # Original names should not exist on the generated parent class + generated_class = CMNSoftDeleteMixin.__bases__[0] + assert not hasattr(generated_class, "delete") + assert not hasattr(generated_class, "undelete") + + +def test_soft_delete_method_sets_deleted_at(db_session): + """Verify soft_delete() method works.""" + obj = CMNTable(value=1) + db_session.add(obj) + db_session.commit() + + assert obj.deleted_at is None + obj.soft_delete() + assert obj.deleted_at is not None + + +def test_restore_method_clears_deleted_at(db_session): + """Verify restore() method works.""" + obj = CMNTable(value=1) + db_session.add(obj) + db_session.commit() + + obj.soft_delete() + assert obj.deleted_at is not None + + obj.restore() + assert obj.deleted_at is None diff --git a/tests/default_config/__init__.py b/tests/default_config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/__snapshots__/test_queries.ambr b/tests/default_config/__snapshots__/test_queries.ambr similarity index 100% rename from tests/__snapshots__/test_queries.ambr rename to tests/default_config/__snapshots__/test_queries.ambr diff --git a/tests/__snapshots__/test_seed_data.ambr b/tests/default_config/__snapshots__/test_seed_data.ambr similarity index 100% rename from tests/__snapshots__/test_seed_data.ambr rename to tests/default_config/__snapshots__/test_seed_data.ambr diff --git a/tests/default_config/conftest.py b/tests/default_config/conftest.py new file mode 100644 index 0000000..87a0831 --- /dev/null +++ b/tests/default_config/conftest.py @@ -0,0 +1,33 @@ +from typing import cast + +import pytest +from sqlalchemy.engine import Connection +from sqlalchemy.orm import Session, sessionmaker + +from sqlalchemy_easy_softdelete.handler.rewriter import SoftDeleteQueryRewriter +from tests.default_config.model import SoftDeleteMixin, TestModelBase +from tests.default_config.seed_data import generate_table_with_inheritance_obj +from tests.default_config.seed_data.parent_child_childchild import generate_parent_child_object_hierarchy + + +@pytest.fixture +def db_session(db_connection: Connection) -> Session: + TestModelBase.metadata.create_all(db_connection) # type: ignore[attr-defined] + return sessionmaker(autocommit=False, autoflush=False, bind=db_connection)() + + +@pytest.fixture +def seeded_session(db_session: Session) -> Session: + generate_parent_child_object_hierarchy(db_session, 1000) + generate_parent_child_object_hierarchy(db_session, 1001) + generate_parent_child_object_hierarchy(db_session, 1002, parent_deleted=True) + + generate_table_with_inheritance_obj(db_session, 1000, deleted=False) + generate_table_with_inheritance_obj(db_session, 1001, deleted=False) + generate_table_with_inheritance_obj(db_session, 1002, deleted=True) + return db_session + + +@pytest.fixture +def rewriter() -> SoftDeleteQueryRewriter: + return cast(SoftDeleteQueryRewriter, SoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter) diff --git a/tests/default_config/model.py b/tests/default_config/model.py new file mode 100644 index 0000000..c36c6a5 --- /dev/null +++ b/tests/default_config/model.py @@ -0,0 +1,105 @@ +from datetime import datetime + +from sqlalchemy import Column, DateTime, ForeignKey, Integer, String +from sqlalchemy.orm import Mapped, as_declarative, relationship + +from sqlalchemy_easy_softdelete.hook import IgnoredTable +from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class + + +@as_declarative() +class TestModelBase: + id = Column(Integer, primary_key=True, autoincrement=True) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} id={self.id}>" + + +class SoftDeleteMixin( + generate_soft_delete_mixin_class( # type: ignore[misc] + ignored_tables=[ + IgnoredTable(table_schema=None, name="sdtablethatshouldnotbesoftdeleted"), + ], + ) +): + # Type hint for IDE autocomplete and type checker support. + # Using Mapped[T | None] ensures type checkers understand this is a + # SQLAlchemy column that supports query operations like .where() + deleted_at: Mapped[datetime | None] + + # Optional: Add method stubs for delete/undelete for type checker support. + # The actual implementations are provided by the generated mixin class. + def delete(self, v: datetime | None = None) -> None: + super().delete(v) # type: ignore[misc] + + def undelete(self) -> None: + super().undelete() # type: ignore[misc] + + +class SDSimpleTable(TestModelBase, SoftDeleteMixin): + __tablename__ = "sdsimpletable" + int_field = Column(Integer) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} id={self.id} deleted={bool(self.deleted_at)}>" + + +class SDParent(TestModelBase, SoftDeleteMixin): + __tablename__ = "sdparent" + children: Mapped[list["SDChild"]] = relationship("SDChild", back_populates="parent") # type: ignore[assignment] + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} id={self.id} deleted={bool(self.deleted_at)}>" + + +class SDChild(TestModelBase, SoftDeleteMixin): + __tablename__ = "sdchild" + parent_id = Column(Integer, ForeignKey("sdparent.id"), nullable=False) + parent: Mapped["SDParent"] = relationship("SDParent", back_populates="children") # type: ignore[assignment] + child_children: Mapped[list["SDChildChild"]] = relationship("SDChildChild", back_populates="child") # type: ignore[assignment] + + def __repr__(self) -> str: + pid = f"(parent_id={self.parent_id})" + left = f"{self.__class__.__name__} id={self.id} deleted={bool(self.deleted_at)}" + return f"<{left:30} {pid:>15}>" + + +class SDChildChild(TestModelBase, SoftDeleteMixin): + __tablename__ = "sdchildchild" + child_id = Column(Integer, ForeignKey("sdchild.id"), nullable=False) + child: Mapped["SDChild"] = relationship("SDChild", back_populates="child_children") # type: ignore[assignment] + + def __repr__(self) -> str: + pid = f"(child_id={self.child_id})" + left = f"{self.__class__.__name__} id={self.id} deleted={bool(self.deleted_at)}" + return f"<{left:30} {pid:>15}>" + + +class SDBaseRequest(TestModelBase, SoftDeleteMixin): + __tablename__ = "sdbaserequest" + request_type = Column(String(50)) + base_field = Column(Integer) + + __mapper_args__ = { + "polymorphic_identity": "sdbaserequest", + "polymorphic_on": "request_type", + } + + +class SDDerivedRequest(SDBaseRequest): + __tablename__ = "sdderivedrequest" + id = Column(Integer, ForeignKey("sdbaserequest.id"), primary_key=True) + derived_field = Column(Integer) + + __mapper_args__ = { + "polymorphic_identity": "sdderivedrequest", + } + + +class SDTableThatShouldNotBeSoftDeleted(TestModelBase): + __tablename__ = "sdtablethatshouldnotbesoftdeleted" + id = Column(Integer, primary_key=True) + deleted_at = Column(DateTime(timezone=True)) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} id={self.id}>" diff --git a/tests/seed_data/__init__.py b/tests/default_config/seed_data/__init__.py similarity index 85% rename from tests/seed_data/__init__.py rename to tests/default_config/seed_data/__init__.py index 4709082..50f2bf2 100644 --- a/tests/seed_data/__init__.py +++ b/tests/default_config/seed_data/__init__.py @@ -2,7 +2,7 @@ from sqlalchemy.orm import Session -from tests.model import SDDerivedRequest +from tests.default_config.model import SDDerivedRequest def generate_table_with_inheritance_obj(s: Session, obj_id: int, deleted: bool = False): diff --git a/tests/seed_data/parent_child_childchild.py b/tests/default_config/seed_data/parent_child_childchild.py similarity index 88% rename from tests/seed_data/parent_child_childchild.py rename to tests/default_config/seed_data/parent_child_childchild.py index f402710..247d727 100644 --- a/tests/seed_data/parent_child_childchild.py +++ b/tests/default_config/seed_data/parent_child_childchild.py @@ -1,20 +1,20 @@ -import datetime import random +from datetime import datetime, timedelta from sqlalchemy.orm import Session -from tests.model import SDChild, SDChildChild, SDParent +from tests.default_config.model import SDChild, SDChildChild, SDParent -TEST_EPOCH = datetime.datetime(year=1985, month=8, day=4) +TEST_EPOCH = datetime(year=1985, month=8, day=4) def pseudorandom_date(max_days: int = 3650) -> datetime: - return TEST_EPOCH + datetime.timedelta(days=random.randint(0, max_days)) + return TEST_EPOCH + timedelta(days=random.randint(0, max_days)) def generate_parent_child_object_hierarchy( s: Session, parent_id: int, min_children: int = 1, max_children: int = 5, parent_deleted: bool = False -): +) -> None: # Fix a seed in the RNG for deterministic outputs random.seed(parent_id) diff --git a/tests/test_queries.py b/tests/default_config/test_queries.py similarity index 99% rename from tests/test_queries.py rename to tests/default_config/test_queries.py index 5b622ad..ebc751b 100644 --- a/tests/test_queries.py +++ b/tests/default_config/test_queries.py @@ -7,7 +7,7 @@ from sqlalchemy.sql import CompoundSelect, Select from sqlalchemy.sql.lambdas import LambdaElement, LinkedLambdaElement, StatementLambdaElement -from tests.model import ( +from tests.default_config.model import ( SDBaseRequest, SDChild, SDChildChild, diff --git a/tests/test_seed_data.py b/tests/default_config/test_seed_data.py similarity index 89% rename from tests/test_seed_data.py rename to tests/default_config/test_seed_data.py index e79399a..14cbb64 100644 --- a/tests/test_seed_data.py +++ b/tests/default_config/test_seed_data.py @@ -1,6 +1,6 @@ """Tests for `sqlalchemy_easy_softdelete` package.""" -from tests.model import SDChild, SDChildChild, SDParent +from tests.default_config.model import SDChild, SDChildChild, SDParent def test_ensure_stable_seed_data(snapshot, seeded_session): diff --git a/tests/default_config/test_type_hints.py b/tests/default_config/test_type_hints.py new file mode 100644 index 0000000..4c33ef8 --- /dev/null +++ b/tests/default_config/test_type_hints.py @@ -0,0 +1,179 @@ +"""Tests for type hint compatibility with SQLAlchemy operations. + +These tests verify that the Mapped[T | None] type hint recommendation works +correctly with SQLAlchemy query operations. See GitHub issue #31. + +The key insight is that using `deleted_at: datetime` as a type hint causes +type checkers to treat the attribute as a plain datetime, which breaks +type checking for expressions like `.where(Model.deleted_at < value)`. + +Using `deleted_at: Mapped[datetime | None]` tells the type checker this is +a SQLAlchemy mapped column that supports comparison operations. +""" + +from datetime import datetime, timezone + +from sqlalchemy import select +from sqlalchemy.sql.elements import BinaryExpression + +from sqlalchemy_easy_softdelete.handler.rewriter import SoftDeleteQueryRewriter +from tests.default_config.model import SDChild, SDParent, SoftDeleteMixin + + +def test_deleted_at_column_supports_comparison_operators(): + """Verify that deleted_at can be used with comparison operators. + + This tests that our type hints allow using the column in expressions. + With the old `deleted_at: datetime` hint, type checkers would complain + that datetime doesn't support these operations in a SQLAlchemy context. + """ + # These should all work without type errors when using Mapped[datetime | None] + now = datetime.now(timezone.utc) + + # Less than + expr_lt = SDChild.deleted_at < now + assert isinstance(expr_lt, BinaryExpression) + + # Greater than + expr_gt = SDChild.deleted_at > now + assert isinstance(expr_gt, BinaryExpression) + + # Equals + expr_eq = SDChild.deleted_at == now + assert isinstance(expr_eq, BinaryExpression) + + # Not equals + expr_ne = SDChild.deleted_at != now + assert isinstance(expr_ne, BinaryExpression) + + # IS NULL + expr_is_none = SDChild.deleted_at.is_(None) + assert isinstance(expr_is_none, BinaryExpression) + + # IS NOT NULL + expr_is_not_none = SDChild.deleted_at.isnot(None) + assert isinstance(expr_is_not_none, BinaryExpression) + + +def test_deleted_at_column_works_in_where_clause(): + """Verify that deleted_at can be used in .where() clauses. + + This is the primary use case from issue #31 - the user was getting + type errors when using .where(Model.deleted_at < value). + """ + now = datetime.now(timezone.utc) + + # Build a select statement with a where clause using deleted_at + stmt = select(SDChild).where(SDChild.deleted_at < now) + + # The statement should compile without errors + assert stmt is not None + assert "deleted_at" in str(stmt) + + +def test_deleted_at_column_works_in_filter(): + """Verify that deleted_at works with the legacy .filter() method.""" + now = datetime.now(timezone.utc) + + # Using filter (ORM Query style) + stmt = select(SDParent).filter(SDParent.deleted_at > now) + + assert stmt is not None + assert "deleted_at" in str(stmt) + + +def test_deleted_at_column_works_with_between(): + """Verify that deleted_at works with .between().""" + now = datetime.now(timezone.utc) + earlier = datetime(2020, 1, 1, tzinfo=timezone.utc) + + expr = SDChild.deleted_at.between(earlier, now) + assert isinstance(expr, BinaryExpression) + + +def test_delete_method_is_callable(seeded_session): + """Verify that the delete() method stub works correctly. + + The SoftDeleteMixin provides method stubs for delete() and undelete() + so that type checkers know these methods exist. + """ + # Get an instance + child = seeded_session.query(SDChild).first() + assert child is not None + + # The delete method should be callable + assert hasattr(child, "delete") + assert callable(child.delete) + + # Call delete and verify it sets deleted_at + child.delete() + assert child.deleted_at is not None + + +def test_delete_without_value_uses_default(seeded_session): + """Verify delete() without value uses the default function (current time).""" + child = seeded_session.query(SDChild).first() + assert child is not None + + before = datetime.now(timezone.utc) + child.delete() + after = datetime.now(timezone.utc) + + # Should be between before and after (current time) + # SQLite doesn't preserve timezone, so compare without it + deleted_at = child.deleted_at.replace(tzinfo=timezone.utc) if child.deleted_at.tzinfo is None else child.deleted_at + assert before <= deleted_at <= after + + +def test_delete_with_custom_value(seeded_session): + """Verify delete(value) uses the passed value instead of default.""" + child = seeded_session.query(SDChild).first() + assert child is not None + + custom_date = datetime(2020, 6, 15, 12, 30, tzinfo=timezone.utc) + child.delete(custom_date) + + # SQLite doesn't preserve timezone, so compare without it + assert child.deleted_at.replace(tzinfo=None) == custom_date.replace(tzinfo=None) + + +def test_undelete_method_is_callable(seeded_session): + """Verify that the undelete() method stub works correctly.""" + # Get a deleted instance + child = ( + seeded_session + .query(SDChild) + .execution_options(include_deleted=True) + .filter(SDChild.deleted_at.isnot(None)) + .first() + ) + assert child is not None + assert child.deleted_at is not None + + # The undelete method should be callable + assert hasattr(child, "undelete") + assert callable(child.undelete) + + # Call undelete and verify it clears deleted_at + child.undelete() + assert child.deleted_at is None + + +def test_mixin_class_has_correct_type_annotations(): + """Verify that the SoftDeleteMixin has the expected type annotations.""" + annotations = SoftDeleteMixin.__annotations__ + + # Should have deleted_at annotation + assert "deleted_at" in annotations + + # The annotation should include Mapped + annotation_str = str(annotations["deleted_at"]) + assert "Mapped" in annotation_str or "datetime" in annotation_str + + +def test_rewriter_is_attached_to_mixin(): + """Verify the rewriter is attached to the mixin class.""" + assert hasattr(SoftDeleteMixin, "_sqlalchemy_easy_softdelete_rewriter") + rewriter = SoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter + assert isinstance(rewriter, SoftDeleteQueryRewriter) + assert rewriter.deleted_field_name == "deleted_at" diff --git a/tests/disabled_methods/__init__.py b/tests/disabled_methods/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/disabled_methods/conftest.py b/tests/disabled_methods/conftest.py new file mode 100644 index 0000000..a59d520 --- /dev/null +++ b/tests/disabled_methods/conftest.py @@ -0,0 +1,19 @@ +from typing import cast + +import pytest +from sqlalchemy.engine import Connection +from sqlalchemy.orm import Session, sessionmaker + +from sqlalchemy_easy_softdelete.handler.rewriter import SoftDeleteQueryRewriter +from tests.disabled_methods.model import DMModelBase, DMSoftDeleteMixin + + +@pytest.fixture +def db_session(db_connection: Connection) -> Session: + DMModelBase.metadata.create_all(db_connection) # type: ignore[attr-defined] + return sessionmaker(autocommit=False, autoflush=False, bind=db_connection)() + + +@pytest.fixture +def rewriter() -> SoftDeleteQueryRewriter: + return cast(SoftDeleteQueryRewriter, DMSoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter) diff --git a/tests/disabled_methods/model.py b/tests/disabled_methods/model.py new file mode 100644 index 0000000..92ee1c8 --- /dev/null +++ b/tests/disabled_methods/model.py @@ -0,0 +1,27 @@ +from datetime import datetime + +from sqlalchemy import Column, Integer +from sqlalchemy.orm import Mapped, as_declarative + +from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class + + +@as_declarative() +class DMModelBase: + """DM = Disabled Methods""" + + id = Column(Integer, primary_key=True, autoincrement=True) + + +class DMSoftDeleteMixin( + generate_soft_delete_mixin_class( # type: ignore[misc] + generate_delete_method=False, + generate_undelete_method=False, + ) +): + deleted_at: Mapped[datetime | None] + + +class DMTable(DMModelBase, DMSoftDeleteMixin): + __tablename__ = "dmtable" + value = Column(Integer) diff --git a/tests/disabled_methods/test_disabled_methods.py b/tests/disabled_methods/test_disabled_methods.py new file mode 100644 index 0000000..81c9fe0 --- /dev/null +++ b/tests/disabled_methods/test_disabled_methods.py @@ -0,0 +1,45 @@ +"""Tests for disabled methods option.""" + +from datetime import datetime, timezone + +from tests.disabled_methods.model import DMSoftDeleteMixin, DMTable + + +def test_no_delete_method_on_generated_class(): + """Verify delete/undelete methods are not generated.""" + generated_class = DMSoftDeleteMixin.__bases__[0] + assert not hasattr(generated_class, "delete") + assert not hasattr(generated_class, "undelete") + + +def test_can_manually_set_deleted_at(db_session): + """Verify we can still manually set deleted_at.""" + obj = DMTable(value=1) + db_session.add(obj) + db_session.commit() + + obj_id = obj.id + assert obj.deleted_at is None + + deleted_time = datetime.now(timezone.utc) + obj.deleted_at = deleted_time + db_session.commit() + + # After commit, the object is soft-deleted so query with include_deleted + result = db_session.query(DMTable).execution_options(include_deleted=True).filter(DMTable.id == obj_id).first() + assert result is not None + assert result.deleted_at is not None + + +def test_soft_delete_filtering_still_works(db_session): + """Verify soft-delete filtering works even without methods.""" + active = DMTable(value=1) + deleted = DMTable(value=2) + deleted.deleted_at = datetime.now(timezone.utc) + + db_session.add_all([active, deleted]) + db_session.commit() + + results = db_session.query(DMTable).all() + assert len(results) == 1 + assert results[0].value == 1 diff --git a/tests/integer_field_type/__init__.py b/tests/integer_field_type/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integer_field_type/conftest.py b/tests/integer_field_type/conftest.py new file mode 100644 index 0000000..289b349 --- /dev/null +++ b/tests/integer_field_type/conftest.py @@ -0,0 +1,19 @@ +from typing import cast + +import pytest +from sqlalchemy.engine import Connection +from sqlalchemy.orm import Session, sessionmaker + +from sqlalchemy_easy_softdelete.handler.rewriter import SoftDeleteQueryRewriter +from tests.integer_field_type.model import IFTModelBase, IFTSoftDeleteMixin + + +@pytest.fixture +def db_session(db_connection: Connection) -> Session: + IFTModelBase.metadata.create_all(db_connection) # type: ignore[attr-defined] + return sessionmaker(autocommit=False, autoflush=False, bind=db_connection)() + + +@pytest.fixture +def rewriter() -> SoftDeleteQueryRewriter: + return cast(SoftDeleteQueryRewriter, IFTSoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter) diff --git a/tests/integer_field_type/model.py b/tests/integer_field_type/model.py new file mode 100644 index 0000000..cec1b74 --- /dev/null +++ b/tests/integer_field_type/model.py @@ -0,0 +1,27 @@ +from datetime import datetime, timezone + +from sqlalchemy import Column, Integer +from sqlalchemy.orm import Mapped, as_declarative + +from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class + + +@as_declarative() +class IFTModelBase: + """IFT = Integer Field Type""" + + id = Column(Integer, primary_key=True, autoincrement=True) + + +class IFTSoftDeleteMixin( + generate_soft_delete_mixin_class( # type: ignore[misc] + deleted_field_type=Integer(), + delete_method_default_value=lambda: int(datetime.now(timezone.utc).timestamp()), + ) +): + deleted_at: Mapped[int | None] + + +class IFTTable(IFTModelBase, IFTSoftDeleteMixin): + __tablename__ = "ifttable" + value = Column(Integer) diff --git a/tests/integer_field_type/test_integer_field_type.py b/tests/integer_field_type/test_integer_field_type.py new file mode 100644 index 0000000..cfe3ea9 --- /dev/null +++ b/tests/integer_field_type/test_integer_field_type.py @@ -0,0 +1,54 @@ +"""Tests for integer field type option.""" + +from datetime import datetime, timezone + +from sqlalchemy import Integer + +from tests.integer_field_type.model import IFTTable + + +def test_field_is_integer_type(): + """Verify the field uses Integer type.""" + column = IFTTable.__table__.columns["deleted_at"] + assert isinstance(column.type, Integer) + + +def test_delete_sets_integer_timestamp(db_session): + """Verify delete() sets an integer timestamp.""" + obj = IFTTable(value=1) + db_session.add(obj) + db_session.commit() + + before = int(datetime.now(timezone.utc).timestamp()) + obj.delete() + after = int(datetime.now(timezone.utc).timestamp()) + + assert isinstance(obj.deleted_at, int) + assert before <= obj.deleted_at <= after + + +def test_undelete_clears_integer_field(db_session): + """Verify undelete() clears the integer field.""" + obj = IFTTable(value=1) + db_session.add(obj) + db_session.commit() + + obj.delete() + assert obj.deleted_at is not None + + obj.undelete() + assert obj.deleted_at is None + + +def test_integer_field_soft_delete_filtering(db_session): + """Verify soft-delete filtering works with integer field.""" + active = IFTTable(value=1) + deleted = IFTTable(value=2) + deleted.deleted_at = int(datetime.now(timezone.utc).timestamp()) + + db_session.add_all([active, deleted]) + db_session.commit() + + results = db_session.query(IFTTable).all() + assert len(results) == 1 + assert results[0].value == 1 diff --git a/tests/model.py b/tests/model.py deleted file mode 100644 index bd7a951..0000000 --- a/tests/model.py +++ /dev/null @@ -1,102 +0,0 @@ -from datetime import datetime - -from sqlalchemy import Column, DateTime, ForeignKey, Integer, String -from sqlalchemy.orm import as_declarative, declared_attr, relationship - -from sqlalchemy_easy_softdelete.hook import IgnoredTable -from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class - - -@as_declarative() -class TestModelBase: - @declared_attr - def __tablename__(cls) -> str: - return cls.__name__.lower() - - id = Column(Integer, primary_key=True, autoincrement=True) - - def __repr__(self): - return f"<{self.__class__.__name__} id={self.id}>" - - -class SoftDeleteMixin( - generate_soft_delete_mixin_class( - ignored_tables=[ - IgnoredTable(table_schema=None, name="sdtablethatshouldnotbesoftdeleted"), - ], - ) -): - # for autocomplete - deleted_at: datetime - - -class SDSimpleTable(TestModelBase, SoftDeleteMixin): - int_field = Column(Integer) - - def __repr__(self): - return f"<{self.__class__.__name__} id={self.id} deleted={bool(self.deleted_at)}>" - - -class SDParent(TestModelBase, SoftDeleteMixin): - __allow_unmapped__ = True - children = relationship("SDChild", back_populates="parent") - - def __repr__(self): - return f"<{self.__class__.__name__} id={self.id} deleted={bool(self.deleted_at)}>" - - -class SDChild(TestModelBase, SoftDeleteMixin): - __allow_unmapped__ = True - parent_id = Column(Integer, ForeignKey(f"{SDParent.__tablename__}.id"), nullable=False) - parent = relationship("SDParent", back_populates="children") - - child_children = relationship("SDChildChild", back_populates="child") - - def __repr__(self): - pid = f"(parent_id={self.parent_id})" - left = f"{self.__class__.__name__} id={self.id} deleted={bool(self.deleted_at)}" - return f"<{left:30} {pid:>15}>" - - -class SDChildChild(TestModelBase, SoftDeleteMixin): - __allow_unmapped__ = True - - child_id = Column(Integer, ForeignKey(f"{SDChild.__tablename__}.id"), nullable=False) - child: SDChild = relationship("SDChild", back_populates="child_children") - - def __repr__(self): - pid = f"(child_id={self.child_id})" - left = f"{self.__class__.__name__} id={self.id} deleted={bool(self.deleted_at)}" - return f"<{left:30} {pid:>15}>" - - -class SDBaseRequest( - TestModelBase, - SoftDeleteMixin, -): - request_type = Column(String(50)) - - base_field = Column(Integer) - - __mapper_args__ = { - "polymorphic_identity": "sdbaserequest", - "polymorphic_on": request_type, - } - - -class SDDerivedRequest(SDBaseRequest): - id: Integer = Column(Integer, ForeignKey("sdbaserequest.id"), primary_key=True) - - derived_field = Column(Integer) - - __mapper_args__ = { - "polymorphic_identity": "sdderivedrequest", - } - - -class SDTableThatShouldNotBeSoftDeleted(TestModelBase): - id: Integer = Column(Integer, primary_key=True) - deleted_at: datetime = Column(DateTime(timezone=True)) - - def __repr__(self): - return f"<{self.__class__.__name__} id={self.id} name={self.name}>" diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index de13e63..5b18f73 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -1,11 +1,13 @@ -from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList, Null +from typing import Any + +from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList, ColumnElement, Null from sqlalchemy.sql.schema import Table from sqlalchemy.sql.selectable import Select from tests.utils.simple_select_extractor import extract_simple_selects -def extract_binary_expressions_from_where(whereclause) -> tuple[BinaryExpression]: +def extract_binary_expressions_from_where(whereclause: Any) -> tuple[ColumnElement[Any], ...]: if isinstance(whereclause, BinaryExpression): return (whereclause,) @@ -14,16 +16,16 @@ def extract_binary_expressions_from_where(whereclause) -> tuple[BinaryExpression # Make sure we only have BinaryExpressions assert all(isinstance(c, BinaryExpression) for c in clauses) - return tuple(whereclause.clauses) + return clauses raise NotImplementedError(f'Unsupported whereclause type "{(type(whereclause))}"!') -def is_soft_delete_filter(b: BinaryExpression, tables: list[Table], deleted_field: str): +def is_soft_delete_filter(b: BinaryExpression[Any], tables: set[Table], deleted_field: str) -> bool: return b.left.table in tables and b.left.name == deleted_field and isinstance(b.right, Null) -def is_simple_select_doing_soft_delete_filtering(stmt: Select, tables: set[Table], deleted_field: str) -> bool: +def is_simple_select_doing_soft_delete_filtering(stmt: Select[Any], tables: set[Table], deleted_field: str) -> bool: # Check if query is disabled for soft-deletion opts = stmt.get_execution_options() if opts and opts.get("include_deleted"): @@ -37,9 +39,11 @@ def is_simple_select_doing_soft_delete_filtering(stmt: Select, tables: set[Table binary_expressions = extract_binary_expressions_from_where(stmt.whereclause) - found_tables = set() + found_tables: set[Table] = set() for binary_expression in binary_expressions: - if is_soft_delete_filter(binary_expression, tables, deleted_field): + if isinstance(binary_expression, BinaryExpression) and is_soft_delete_filter( + binary_expression, tables, deleted_field + ): found_tables.add(binary_expression.left.table) if found_tables == tables: @@ -48,7 +52,7 @@ def is_simple_select_doing_soft_delete_filtering(stmt: Select, tables: set[Table return False -def is_filtering_for_softdeleted(statement: Select, tables: set[Table], deleted_field: str = "deleted_at") -> bool: +def is_filtering_for_softdeleted(statement: Select[Any], tables: set[Table], deleted_field: str = "deleted_at") -> bool: selects = extract_simple_selects(statement) # Make sure all extracted selects are doing soft-delete filtering diff --git a/tests/utils/simple_select_extractor.py b/tests/utils/simple_select_extractor.py index a397f5a..2663f4a 100644 --- a/tests/utils/simple_select_extractor.py +++ b/tests/utils/simple_select_extractor.py @@ -6,9 +6,11 @@ from __future__ import annotations +from typing import Any + from sqlalchemy.orm.util import _ORMJoin from sqlalchemy.sql.schema import Table -from sqlalchemy.sql.selectable import CompoundSelect, Join, Select, SelectBase, Subquery +from sqlalchemy.sql.selectable import CompoundSelect, Join, Select, Subquery def is_simple_join(j: Join | _ORMJoin) -> bool: @@ -27,7 +29,7 @@ def is_simple_join(j: Join | _ORMJoin) -> bool: return left_simple and right_simple -def is_simple_select(s: Select | Subquery | CompoundSelect) -> bool: +def is_simple_select(s: Select[Any] | Subquery | CompoundSelect[Any]) -> bool: if isinstance(s, CompoundSelect): return False @@ -53,20 +55,20 @@ def is_simple_select(s: Select | Subquery | CompoundSelect) -> bool: return True -def extract_simple_selects(statement: Select | CompoundSelect | SelectBase) -> list[SelectBase]: +def extract_simple_selects(statement: Select[Any] | CompoundSelect[Any]) -> list[Select[Any]]: if is_simple_select(statement): - return [statement] + return [statement] # type: ignore[list-item] # We know it's a Select here if isinstance(statement, CompoundSelect): - extracted_elements = [] + extracted_elements: list[Select[Any]] = [] for select in statement.selects: - extracted_elements.extend(extract_simple_selects(select)) + extracted_elements.extend(extract_simple_selects(select)) # type: ignore[arg-type] return extracted_elements for from_obj in statement.get_final_froms(): if isinstance(from_obj, Table): continue elif isinstance(from_obj, Subquery): - return extract_simple_selects(from_obj.element) + return extract_simple_selects(from_obj.element) # type: ignore[arg-type] # element is Select raise NotImplementedError(f'Should not reach this point! statement.froms -> "{statement.froms}"!')