diff --git a/migrations/env.py b/migrations/env.py index 3b98615..65e0bc2 100644 --- a/migrations/env.py +++ b/migrations/env.py @@ -3,7 +3,7 @@ from alembic import context from sqlalchemy import engine_from_config, pool -from print_service.models import Model +from print_service.models.base import Base from print_service.settings import get_settings @@ -20,7 +20,7 @@ # for 'autogenerate' support # from myapp import mymodel # target_metadata = mymodel.Base.metadata -target_metadata = Model.metadata +target_metadata = Base.metadata # other values from the config, defined by the needs of env.py, # can be acquired: diff --git a/migrations/versions/c29b6ffbfed4_add_is_deleted_field_to_unionmember.py b/migrations/versions/c29b6ffbfed4_add_is_deleted_field_to_unionmember.py new file mode 100644 index 0000000..204f4cf --- /dev/null +++ b/migrations/versions/c29b6ffbfed4_add_is_deleted_field_to_unionmember.py @@ -0,0 +1,28 @@ +"""Add is_deleted field to UnionMember + +Revision ID: c29b6ffbfed4 +Revises: a68c6bb2972c +Create Date: 2024-11-22 17:50:35.569723 + +""" + +import sqlalchemy as sa +from alembic import op + + +# revision identifiers, used by Alembic. +revision = 'c29b6ffbfed4' +down_revision = 'a68c6bb2972c' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column( + 'union_member', sa.Column('is_deleted', sa.Boolean(), nullable=False, server_default=sa.false()) + ) + + +def downgrade(): + op.drop_column('union_member', 'is_deleted') + op.alter_column('file', 'source', existing_type=sa.VARCHAR(), nullable=True) diff --git a/print_service/models/__init__.py b/print_service/models/__init__.py index d1ffcfe..f62fc63 100644 --- a/print_service/models/__init__.py +++ b/print_service/models/__init__.py @@ -9,26 +9,24 @@ from sqlalchemy.sql.schema import ForeignKey from sqlalchemy.sql.sqltypes import Boolean +from print_service.models.base import BaseDbModel -@as_declarative() -class Model: - pass - -class UnionMember(Model): - __tablename__ = 'union_member' +class UnionMember(BaseDbModel): + # __tablename__ = 'union_member' id: Mapped[int] = mapped_column(Integer, primary_key=True) surname: Mapped[str] = mapped_column(String, nullable=False) union_number: Mapped[str] = mapped_column(String, nullable=True) student_number: Mapped[str] = mapped_column(String, nullable=True) + is_deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) files: Mapped[list[File]] = relationship('File', back_populates='owner') print_facts: Mapped[list[PrintFact]] = relationship('PrintFact', back_populates='owner') -class File(Model): - __tablename__ = 'file' +class File(BaseDbModel): + # __tablename__ = 'file' id: Mapped[int] = Column(Integer, primary_key=True) pin: Mapped[str] = Column(String, nullable=False) @@ -44,7 +42,11 @@ class File(Model): number_of_pages: Mapped[int] = Column(Integer) source: Mapped[str] = Column(String, default='unknown', nullable=False) - owner: Mapped[UnionMember] = relationship('UnionMember', back_populates='files') + owner: Mapped[UnionMember] = relationship( + 'UnionMember', + primaryjoin="and_(File.owner_id==UnionMember.id, not_(UnionMember.is_deleted))", + back_populates='files', + ) print_facts: Mapped[list[PrintFact]] = relationship('PrintFact', back_populates='file') @property @@ -79,14 +81,18 @@ def sheets_count(self) -> int | None: return len(self.flatten_pages) * self.option_copies -class PrintFact(Model): - __tablename__ = 'print_fact' +class PrintFact(BaseDbModel): + # __tablename__ = 'print_fact' id: Mapped[int] = Column(Integer, primary_key=True) file_id: Mapped[int] = Column(Integer, ForeignKey('file.id'), nullable=False) owner_id: Mapped[int] = Column(Integer, ForeignKey('union_member.id'), nullable=False) created_at: Mapped[datetime] = Column(DateTime, nullable=False, default=datetime.utcnow) - owner: Mapped[UnionMember] = relationship('UnionMember', back_populates='print_facts') + owner: Mapped[UnionMember] = relationship( + 'UnionMember', + primaryjoin="and_(PrintFact.owner_id == UnionMember.id, not_(UnionMember.is_deleted))", + back_populates='print_facts', + ) file: Mapped[File] = relationship('File', back_populates='print_facts') sheets_used: Mapped[int] = Column(Integer) diff --git a/print_service/models/base.py b/print_service/models/base.py new file mode 100644 index 0000000..b307a67 --- /dev/null +++ b/print_service/models/base.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import re + +from sqlalchemy import not_ +from sqlalchemy.exc import NoResultFound +from sqlalchemy.orm import Query, Session, as_declarative, declared_attr + +from print_service.exceptions import ObjectNotFound + + +@as_declarative() +class Base: + """Base class for all database entities""" + + @declared_attr + def __tablename__(cls) -> str: # pylint: disable=no-self-argument + """Generate database table name automatically. + Convert CamelCase class name to snake_case db table name. + """ + return re.sub(r"(? BaseDbModel: + obj = cls(**kwargs) + session.add(obj) + session.flush() + return obj + + @classmethod + def query(cls, *, session: Session, with_deleted: bool = False) -> Query: + """Get all objects with soft deletes""" + objs = session.query(cls) + if not with_deleted and hasattr(cls, "is_deleted"): + objs = objs.filter(not_(cls.is_deleted)) + return objs + + @classmethod + def get(cls, id: int, *, with_deleted=False, session: Session) -> BaseDbModel: + """Get object with soft deletes""" + objs = session.query(cls) + if not with_deleted and hasattr(cls, "is_deleted"): + objs = objs.filter(not_(cls.is_deleted)) + try: + return objs.filter(cls.id == id).one() + except NoResultFound: + raise ObjectNotFound(cls, id) + + @classmethod + def update(cls, id: int, *, session: Session, **kwargs) -> BaseDbModel: + obj = cls.get(id, session=session) + for k, v in kwargs.items(): + setattr(obj, k, v) + session.flush() + return obj + + @classmethod + def delete(cls, id: int, *, session: Session) -> None: + """Soft delete object if possible, else hard delete""" + obj = cls.get(id, session=session) + if hasattr(obj, "is_deleted"): + obj.is_deleted = True + else: + session.delete(obj) + session.flush() diff --git a/print_service/routes/file.py b/print_service/routes/file.py index 5205e0f..d255efc 100644 --- a/print_service/routes/file.py +++ b/print_service/routes/file.py @@ -112,7 +112,7 @@ async def send(inp: SendInput, settings: Settings = Depends(get_settings)): Полученный пин-код можно использовать в методах POST и GET `/file/{pin}`. """ - user = db.session.query(UnionMember) + user = UnionMember.query(session=db.session) if not settings.ALLOW_STUDENT_NUMBER: user = user.filter(UnionMember.union_number != None) user = user.filter( @@ -122,6 +122,7 @@ async def send(inp: SendInput, settings: Settings = Depends(get_settings)): ), func.upper(UnionMember.surname) == inp.surname.upper(), ).one_or_none() + if not user: raise NotInUnion() try: @@ -129,14 +130,18 @@ async def send(inp: SendInput, settings: Settings = Depends(get_settings)): except RuntimeError: raise PINGenerateError() filename = generate_filename(inp.filename) - file_model = FileModel(pin=pin, file=filename, source=inp.source) - file_model.owner = user - file_model.option_copies = inp.options.copies - file_model.option_pages = inp.options.pages - file_model.option_two_sided = inp.options.two_sided - db.session.add(file_model) - db.session.commit() + file_model = FileModel.create( + session=db.session, + pin=pin, + file=filename, + source=inp.source, + owner=user, + option_copies=inp.options.copies, + option_pages=inp.options.pages, + option_two_sided=inp.options.two_sided, + ) + db.session.commit() return { 'pin': file_model.pin, 'options': { @@ -170,11 +175,12 @@ async def upload_file( if file == ...: raise FileIsNotReceived() file_model = ( - db.session.query(FileModel) + FileModel.query(session=db.session) .filter(func.upper(FileModel.pin) == pin.upper()) .order_by(FileModel.created_at.desc()) .one_or_none() ) + if not file_model: await file.close() raise PINNotFound(pin) @@ -237,11 +243,12 @@ async def update_file_options( можно бесконечное количество раз. Можно изменять настройки по одной.""" options = inp.options.model_dump(exclude_unset=True) file_model = ( - db.session.query(FileModel) + FileModel.query(session=db.session) .filter(func.upper(FileModel.pin) == pin.upper()) .order_by(FileModel.created_at.desc()) .one_or_none() ) + print(options) if not file_model: raise PINNotFound(pin) diff --git a/print_service/routes/user.py b/print_service/routes/user.py index 7cd7743..680ed82 100644 --- a/print_service/routes/user.py +++ b/print_service/routes/user.py @@ -22,7 +22,7 @@ # region schemas class UserCreate(BaseModel): - username: constr(strip_whitespace=True, to_upper=True, min_length=1) + surname: constr(strip_whitespace=True, to_upper=True, min_length=1) union_number: Optional[constr(strip_whitespace=True, to_upper=True, min_length=1)] student_number: Optional[constr(strip_whitespace=True, to_upper=True, min_length=1)] @@ -40,9 +40,7 @@ class UpdateUserList(BaseModel): @router.get( '/is_union_member', status_code=202, - responses={ - 404: {'detail': 'User not found'}, - }, + responses={404: {'detail': 'User not found'}}, ) async def check_union_member( surname: constr(strip_whitespace=True, to_upper=True, min_length=1), @@ -51,7 +49,7 @@ async def check_union_member( ): """Проверяет наличие пользователя в списке.""" surname = surname.upper() - user = db.session.query(UnionMember) + user = UnionMember.query(session=db.session) if not settings.ALLOW_STUDENT_NUMBER: user = user.filter(UnionMember.union_number != None) user: UnionMember = user.filter( @@ -94,7 +92,7 @@ def update_list( for user in input.users: db_user: UnionMember = ( - db.session.query(UnionMember) + UnionMember.query(session=db.session) .filter( or_( and_( @@ -111,19 +109,9 @@ def update_list( ) if db_user: - db_user.surname = user.username - db_user.union_number = user.union_number - db_user.student_number = user.student_number + UnionMember.update(session=db.session, id=db_user.id, **user.model_dump(exclude_unset=False)) else: - db.session.add( - UnionMember( - surname=user.username, - union_number=user.union_number, - student_number=user.student_number, - ) - ) - db.session.flush() - + UnionMember.create(session=db.session, **user.model_dump(exclude_unset=False)) db.session.commit() return {"status": "ok", "count": len(input.users)} diff --git a/print_service/utils/__init__.py b/print_service/utils/__init__.py index 1c44b33..daa6e38 100644 --- a/print_service/utils/__init__.py +++ b/print_service/utils/__init__.py @@ -31,7 +31,7 @@ def generate_pin(session: Session): for i in range(15): pin = ''.join(random.choice(settings.PIN_SYMBOLS) for _ in range(settings.PIN_LENGTH)) cnt = ( - session.query(File) + File.query(session=session) .filter( File.pin == pin, File.created_at + timedelta(hours=settings.STORAGE_TIME) >= datetime.utcnow(), @@ -57,11 +57,12 @@ def generate_filename(original_filename: str): def get_file(dbsession, pin: str or list[str]): pin = [pin.upper()] if isinstance(pin, str) else tuple(p.upper() for p in pin) files: list[FileModel] = ( - dbsession.query(FileModel) + FileModel.query(session=dbsession) .filter(func.upper(FileModel.pin).in_(pin)) .order_by(FileModel.created_at.desc()) .all() ) + if len(pin) != len(files): raise FileNotFound(len(pin) - len(files)) @@ -85,8 +86,9 @@ def get_file(dbsession, pin: str or list[str]): if f.flatten_pages: if number_of_pages > max(f.flatten_pages): raise InvalidPageRequest() - file_model = PrintFact(file_id=f.id, owner_id=f.owner_id, sheets_used=f.sheets_count) - dbsession.add(file_model) + PrintFact.create( + session=dbsession, file_id=f.id, owner_id=f.owner_id, sheets_used=f.sheets_count + ) dbsession.commit() return result diff --git a/tests/test_routes/conftest.py b/tests/test_routes/conftest.py index 8eb798b..7c4327a 100644 --- a/tests/test_routes/conftest.py +++ b/tests/test_routes/conftest.py @@ -16,10 +16,23 @@ def union_member_user(dbsession): dbsession.add(UnionMember(**union_member)) dbsession.commit() yield union_member - db_user = dbsession.query(UnionMember).filter(UnionMember.id == union_member['id']).one_or_none() + db_user = ( + UnionMember.query(session=dbsession, with_deleted=True) + .filter(UnionMember.id == union_member['id']) + .one_or_none() + ) assert db_user is not None - dbsession.query(PrintFact).filter(PrintFact.owner_id == union_member['id']).delete() - dbsession.query(UnionMember).filter(UnionMember.id == union_member['id']).delete() + PrintFact.query(session=dbsession).filter(PrintFact.owner_id == union_member['id']).delete() + UnionMember.query(session=dbsession, with_deleted=True).filter( + UnionMember.id == union_member['id'] + ).delete() + dbsession.commit() + + +@pytest.fixture(scope='function') +def add_is_deleted_flag(dbsession): + db_user = UnionMember.query(session=dbsession).filter(UnionMember.id == 42).one_or_none() + db_user.is_deleted = True dbsession.commit() @@ -32,12 +45,12 @@ def uploaded_file_db(dbsession, union_member_user, client): "options": {"pages": "", "copies": 1, "two_sided": False}, } res = client.post('/file', json=body) - db_file = dbsession.query(File).filter(File.pin == res.json()['pin']).one_or_none() + db_file = File.query(session=dbsession).filter(File.pin == res.json()['pin']).one_or_none() yield db_file - file = dbsession.query(File).filter(File.pin == res.json()['pin']).one_or_none() + file = File.query(session=dbsession).filter(File.pin == res.json()['pin']).one_or_none() assert file is not None - dbsession.query(PrintFact).filter(PrintFact.file_id == file.id).delete() - dbsession.query(File).filter(File.pin == res.json()['pin']).delete() + PrintFact.query(session=dbsession).filter(PrintFact.file_id == file.id).delete() + File.query(session=dbsession).filter(File.pin == res.json()['pin']).delete() dbsession.commit() @@ -60,8 +73,8 @@ def pin_pdf(dbsession, union_member_user, client): res = client.post('/file', json=body) pin = res.json()['pin'] yield pin - file = dbsession.query(File).filter(File.pin == res.json()['pin']).one_or_none() + file = File.query(session=dbsession).filter(File.pin == res.json()['pin']).one_or_none() assert file is not None - dbsession.query(PrintFact).filter(PrintFact.file_id == file.id).delete() - dbsession.query(File).filter(File.pin == res.json()['pin']).delete() + PrintFact.query(session=dbsession).filter(PrintFact.file_id == file.id).delete() + File.query(session=dbsession).filter(File.pin == res.json()['pin']).delete() dbsession.commit() diff --git a/tests/test_routes/test_file.py b/tests/test_routes/test_file.py index 864e245..d72f8b5 100644 --- a/tests/test_routes/test_file.py +++ b/tests/test_routes/test_file.py @@ -25,7 +25,7 @@ def test_post_success(union_member_user, client, dbsession): } res = client.post(url, data=json.dumps(body)) assert res.status_code == status.HTTP_200_OK - db_file = dbsession.query(File).filter(File.pin == res.json()['pin']).one_or_none() + db_file = File.query(session=dbsession).filter(File.pin == res.json()['pin']).one_or_none() assert db_file is not None assert db_file.source == 'webapp' body2 = { @@ -36,7 +36,7 @@ def test_post_success(union_member_user, client, dbsession): } res2 = client.post(url, data=json.dumps(body2)) assert res2.status_code == status.HTTP_200_OK - db_file2 = dbsession.query(File).filter(File.pin == res2.json()['pin']).one_or_none() + db_file2 = File.query(session=dbsession).filter(File.pin == res2.json()['pin']).one_or_none() assert db_file2 is not None assert db_file2.source == 'unknown' dbsession.delete(db_file) @@ -44,6 +44,18 @@ def test_post_success(union_member_user, client, dbsession): dbsession.commit() +def test_post_is_deleted(client, union_member_user, add_is_deleted_flag): + body = { + "surname": union_member_user['surname'], + "number": union_member_user['union_number'], + "filename": "filename.pdf", + "source": "webapp", + "options": {"pages": "", "copies": 1, "two_sided": False}, + } + res = client.post(url, data=json.dumps(body)) + assert res.status_code == status.HTTP_403_FORBIDDEN + + def test_post_unauthorized_user(client): body = { "surname": 'surname', diff --git a/tests/test_routes/test_user.py b/tests/test_routes/test_user.py index 97fdb44..4307429 100644 --- a/tests/test_routes/test_user.py +++ b/tests/test_routes/test_user.py @@ -1,6 +1,7 @@ import json import pytest +from sqlalchemy import and_, func from starlette import status from print_service.models import UnionMember @@ -29,22 +30,55 @@ def test_get_not_found(client): assert res.status_code == status.HTTP_404_NOT_FOUND +def test_get_is_deleted(client, union_member_user, add_is_deleted_flag): + params = { + 'surname': 'test', + 'number': '6666667', + } + res = client.get(url, params=params) + assert res.status_code == status.HTTP_404_NOT_FOUND + + def test_post_success(client, dbsession): body = { 'users': [ { - 'username': 'paul', + 'surname': 'paul', 'union_number': '1966', 'student_number': '1967', } - ], + ] + } + res = client.post(url, data=json.dumps(body)) + assert res.status_code == status.HTTP_200_OK + UnionMember.query(session=dbsession).filter( + and_( + UnionMember.surname == func.upper(body['users'][0]['surname']), + UnionMember.union_number == func.upper(body['users'][0]['union_number']), + UnionMember.student_number == func.upper(body['users'][0]['student_number']), + ) + ).delete() + dbsession.commit() + + +def test_post_is_deleted(client, dbsession, union_member_user, add_is_deleted_flag): + body = { + 'users': [ + { + 'surname': 'new_test', + 'union_number': '6666667', + 'student_number': '13033224', + } + ] } res = client.post(url, data=json.dumps(body)) assert res.status_code == status.HTTP_200_OK - dbsession.query(UnionMember).filter( - UnionMember.surname == body['users'][0]['username'], - UnionMember.union_number == body['users'][0]['union_number'], - UnionMember.student_number == body['users'][0]['student_number'], + UnionMember.query(session=dbsession).filter( + and_( + UnionMember.surname == func.upper(body['users'][0]['surname']), + UnionMember.union_number == func.upper(body['users'][0]['union_number']), + UnionMember.student_number == func.upper(body['users'][0]['student_number']), + ) ).delete() dbsession.commit() @@ -55,12 +89,12 @@ def test_post_success(client, dbsession): pytest.param( [ { - 'username': 'paul', + 'surname': 'paul', 'union_number': '404man', 'student_number': '30311', }, { - 'username': 'marty', + 'surname': 'marty', 'union_number': '404man', 'student_number': '303112', }, @@ -70,12 +104,12 @@ def test_post_success(client, dbsession): pytest.param( [ { - 'username': 'alice', + 'surname': 'alice', 'union_number': '500', 'student_number': '42', }, { - 'username': 'polly', + 'surname': 'polly', 'union_number': '503', 'student_number': '42', },