diff --git a/auth_backend/auth_method/method_mixins.py b/auth_backend/auth_method/method_mixins.py index 72a2670f..7b026c33 100644 --- a/auth_backend/auth_method/method_mixins.py +++ b/auth_backend/auth_method/method_mixins.py @@ -51,7 +51,13 @@ async def _login(*args, **kwargs) -> Session: @staticmethod async def _create_session( - user: User, scopes_list_names: list[TypeScope] | None, session_name: str | None = None, *, db_session: DbSession + user: User, + scopes_list_names: list[TypeScope] | None, + session_name: str | None = None, + *, + db_session: DbSession, ) -> Session: """Создает сессию пользователя""" - return await create_session(user, scopes_list_names, db_session=db_session, session_name=session_name) + return await create_session( + user, scopes_list_names, db_session=db_session, session_name=session_name, is_unbounded=True + ) diff --git a/auth_backend/auth_plugins/email.py b/auth_backend/auth_plugins/email.py index 68373aa1..4de9902b 100644 --- a/auth_backend/auth_plugins/email.py +++ b/auth_backend/auth_plugins/email.py @@ -166,7 +166,10 @@ async def _login(cls, user_inp: EmailLogin, background_tasks: BackgroundTasks) - userdata, ) return await cls._create_session( - query.user, user_inp.scopes, db_session=db.session, session_name=user_inp.session_name + query.user, + user_inp.scopes, + db_session=db.session, + session_name=user_inp.session_name, ) @staticmethod diff --git a/auth_backend/auth_plugins/github.py b/auth_backend/auth_plugins/github.py index 1ef892a0..126168a5 100644 --- a/auth_backend/auth_plugins/github.py +++ b/auth_backend/auth_plugins/github.py @@ -114,7 +114,10 @@ async def _register( ) await AuthPluginMeta.user_updated(new_user, old_user) return await cls._create_session( - user, user_inp.scopes, db_session=db.session, session_name=user_inp.session_name + user, + user_inp.scopes, + db_session=db.session, + session_name=user_inp.session_name, ) @classmethod @@ -169,7 +172,10 @@ async def _login(cls, user_inp: OauthResponseSchema, background_tasks: Backgroun userdata, ) return await cls._create_session( - user, user_inp.scopes, db_session=db.session, session_name=user_inp.session_name + user, + user_inp.scopes, + db_session=db.session, + session_name=user_inp.session_name, ) @classmethod diff --git a/auth_backend/auth_plugins/google.py b/auth_backend/auth_plugins/google.py index 90afb810..686eaa3d 100644 --- a/auth_backend/auth_plugins/google.py +++ b/auth_backend/auth_plugins/google.py @@ -122,7 +122,10 @@ async def _register( ) await AuthPluginMeta.user_updated(new_user, old_user) return await cls._create_session( - user, user_inp.scopes, db_session=db.session, session_name=user_inp.session_name + user, + user_inp.scopes, + db_session=db.session, + session_name=user_inp.session_name, ) @classmethod @@ -161,7 +164,10 @@ async def _login(cls, user_inp: OauthResponseSchema, background_tasks: Backgroun userdata, ) return await cls._create_session( - user, user_inp.scopes, db_session=db.session, session_name=user_inp.session_name + user, + user_inp.scopes, + db_session=db.session, + session_name=user_inp.session_name, ) @classmethod diff --git a/auth_backend/auth_plugins/keycloak.py b/auth_backend/auth_plugins/keycloak.py index f86af864..6759f8c9 100644 --- a/auth_backend/auth_plugins/keycloak.py +++ b/auth_backend/auth_plugins/keycloak.py @@ -113,7 +113,10 @@ async def _register( ) await AuthPluginMeta.user_updated(new_user, old_user) return await cls._create_session( - user, user_inp.scopes, db_session=db.session, session_name=user_inp.session_name + user, + user_inp.scopes, + db_session=db.session, + session_name=user_inp.session_name, ) @classmethod @@ -170,7 +173,10 @@ async def _login(cls, user_inp: OauthResponseSchema, background_tasks: Backgroun userdata, ) return await cls._create_session( - user, user_inp.scopes, db_session=db.session, session_name=user_inp.session_name + user, + user_inp.scopes, + db_session=db.session, + session_name=user_inp.session_name, ) @classmethod diff --git a/auth_backend/auth_plugins/lkmsu.py b/auth_backend/auth_plugins/lkmsu.py index 339fcc94..4b170daa 100644 --- a/auth_backend/auth_plugins/lkmsu.py +++ b/auth_backend/auth_plugins/lkmsu.py @@ -111,7 +111,10 @@ async def _register( ) await AuthPluginMeta.user_updated(new_user, old_user) return await cls._create_session( - user, user_inp.scopes, db_session=db.session, session_name=user_inp.session_name + user, + user_inp.scopes, + db_session=db.session, + session_name=user_inp.session_name, ) @classmethod @@ -164,7 +167,10 @@ async def _login( userdata, ) return await cls._create_session( - user, user_inp.scopes, db_session=db.session, session_name=user_inp.session_name + user, + user_inp.scopes, + db_session=db.session, + session_name=user_inp.session_name, ) @classmethod diff --git a/auth_backend/auth_plugins/telegram.py b/auth_backend/auth_plugins/telegram.py index a1ffebc7..79a691c5 100644 --- a/auth_backend/auth_plugins/telegram.py +++ b/auth_backend/auth_plugins/telegram.py @@ -88,7 +88,10 @@ async def _register( ) await AuthPluginMeta.user_updated(new_user, old_user) return await cls._create_session( - user, user_inp.scopes, db_session=db.session, session_name=user_inp.session_name + user, + user_inp.scopes, + db_session=db.session, + session_name=user_inp.session_name, ) @classmethod @@ -118,7 +121,10 @@ async def _login(cls, user_inp: OauthResponseSchema, background_tasks: Backgroun userdata, ) return await cls._create_session( - user, user_inp.scopes, db_session=db.session, session_name=user_inp.session_name + user, + user_inp.scopes, + db_session=db.session, + session_name=user_inp.session_name, ) @classmethod diff --git a/auth_backend/auth_plugins/vk.py b/auth_backend/auth_plugins/vk.py index 41173751..7bb2bdd8 100644 --- a/auth_backend/auth_plugins/vk.py +++ b/auth_backend/auth_plugins/vk.py @@ -121,7 +121,10 @@ async def _register( ) await AuthPluginMeta.user_updated(new_user, old_user) return await cls._create_session( - user, user_inp.scopes, db_session=db.session, session_name=user_inp.session_name + user, + user_inp.scopes, + db_session=db.session, + session_name=user_inp.session_name, ) @classmethod @@ -170,7 +173,10 @@ async def _login(cls, user_inp: OauthResponseSchema, background_tasks: Backgroun userdata, ) return await cls._create_session( - user, user_inp.scopes, db_session=db.session, session_name=user_inp.session_name + user, + user_inp.scopes, + db_session=db.session, + session_name=user_inp.session_name, ) @classmethod diff --git a/auth_backend/auth_plugins/yandex.py b/auth_backend/auth_plugins/yandex.py index f18a3655..96f1239f 100644 --- a/auth_backend/auth_plugins/yandex.py +++ b/auth_backend/auth_plugins/yandex.py @@ -126,7 +126,10 @@ async def _register( ) await AuthPluginMeta.user_updated(new_user, old_user) return await cls._create_session( - user, user_inp.scopes, db_session=db.session, session_name=user_inp.session_name + user, + user_inp.scopes, + db_session=db.session, + session_name=user_inp.session_name, ) @classmethod @@ -174,7 +177,10 @@ async def _login(cls, user_inp: OauthResponseSchema, background_tasks: Backgroun userdata, ) return await cls._create_session( - user, user_inp.scopes, db_session=db.session, session_name=user_inp.session_name + user, + user_inp.scopes, + db_session=db.session, + session_name=user_inp.session_name, ) @classmethod diff --git a/auth_backend/models/db.py b/auth_backend/models/db.py index 8c31cee9..05693772 100644 --- a/auth_backend/models/db.py +++ b/auth_backend/models/db.py @@ -158,6 +158,7 @@ class UserSession(BaseDbModel): user_id: Mapped[int] = mapped_column(Integer, sqlalchemy.ForeignKey("user.id")) expires: Mapped[datetime.datetime] = mapped_column(DateTime, default=session_expires_date) token: Mapped[str] = mapped_column(String, unique=True) + is_unbounded: Mapped[bool] = mapped_column(Boolean, default=False) last_activity: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow) create_ts: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow) user: Mapped[User] = relationship( diff --git a/auth_backend/routes/user_session.py b/auth_backend/routes/user_session.py index 50e0a69d..37927f73 100644 --- a/auth_backend/routes/user_session.py +++ b/auth_backend/routes/user_session.py @@ -69,7 +69,11 @@ async def me( | UserIndirectGroups(indirect_groups=[group.id for group in session.user.indirect_groups]).model_dump() ) if "session_scopes" in info: - result = result | SessionScopes(session_scopes=session.scopes).model_dump() + result = result | ( + SessionScopes(session_scopes=session.user.scopes).model_dump() + if session.is_unbounded + else SessionScopes(session_scopes=session.scopes).model_dump() + ) if "user_scopes" in info: result = result | UserScopes(user_scopes=session.user.scopes).model_dump() if "auth_methods" in info: @@ -98,6 +102,7 @@ async def create_session( new_session.expires, db_session=db.session, session_name=new_session.session_name, + is_unbounded=new_session.is_unbounded, ) @@ -146,9 +151,12 @@ async def get_sessions( id=session.id, last_activity=session.last_activity, session_name=session.session_name, + is_unbounded=session.is_unbounded, ) if "session_scopes" in info: - result['session_scopes'] = [_scope.name for _scope in session.scopes] + result['session_scopes'] = [ + _scope.name for _scope in (session.user.scopes if session.is_unbounded else session.scopes) + ] if "token" in info: result['token'] = session.token[-4:] if "expires" in info: diff --git a/auth_backend/schemas/models.py b/auth_backend/schemas/models.py index 5621cd89..8fd7bd24 100644 --- a/auth_backend/schemas/models.py +++ b/auth_backend/schemas/models.py @@ -132,6 +132,7 @@ class Session(Base): expires: datetime | None = None id: int user_id: int + is_unbounded: bool | None = None session_scopes: list[Scope] | None = None last_activity: datetime @@ -140,6 +141,7 @@ class SessionPost(Base): session_name: str | None = None scopes: list[Scope] = [] expires: datetime | None = None + is_unbounded: bool | None = None @classmethod @field_validator("expires") diff --git a/auth_backend/utils/security.py b/auth_backend/utils/security.py index 4588a4b1..1119f33b 100644 --- a/auth_backend/utils/security.py +++ b/auth_backend/utils/security.py @@ -53,7 +53,12 @@ async def __call__( if user_session.expired: self._except() - session_scopes = set([scope.name.lower() for scope in user_session.scopes]) + session_scopes = set( + [ + scope.name.lower() + for scope in (user_session.user.scopes if user_session.is_unbounded else user_session.scopes) + ] + ) if self._SESSION_UPDATE_SCOPE in session_scopes: user_session.expires = session_expires_date() db.session.commit() diff --git a/auth_backend/utils/user_session_control.py b/auth_backend/utils/user_session_control.py index e45f5f18..11389315 100644 --- a/auth_backend/utils/user_session_control.py +++ b/auth_backend/utils/user_session_control.py @@ -20,6 +20,7 @@ async def create_session( scopes_list_names: list[TypeScope] | None, expires: datetime | None = None, session_name: str | None = None, + is_unbounded: bool = False, *, db_session: DbSession, ) -> Session: @@ -33,10 +34,12 @@ async def create_session( user_id=user.id, token=random_string(length=settings.TOKEN_LENGTH), session_name=session_name ) user_session.expires = expires or user_session.expires + user_session.is_unbounded = is_unbounded db_session.add(user_session) db_session.flush() - for scope in scopes: - db_session.add(UserSessionScope(scope_id=scope.id, user_session_id=user_session.id)) + if not user_session.is_unbounded: + for scope in scopes: + db_session.add(UserSessionScope(scope_id=scope.id, user_session_id=user_session.id)) db_session.commit() return Session( session_name=session_name, @@ -44,6 +47,7 @@ async def create_session( token=user_session.token, id=user_session.id, expires=user_session.expires, + is_unbounded=user_session.is_unbounded, session_scopes=[_scope.name for _scope in user_session.scopes], last_activity=user_session.last_activity, ) diff --git a/migrations/versions/6dffd8e42152_193_add_unbounded_sessions.py b/migrations/versions/6dffd8e42152_193_add_unbounded_sessions.py new file mode 100644 index 00000000..d2940e86 --- /dev/null +++ b/migrations/versions/6dffd8e42152_193_add_unbounded_sessions.py @@ -0,0 +1,27 @@ +"""193 Add unbounded sessions + +Revision ID: 6dffd8e42152 +Revises: 2d29fc132e89 +Create Date: 2024-08-19 19:27:25.867548 + +""" + +import sqlalchemy as sa +from alembic import op + + +# revision identifiers, used by Alembic. +revision = '6dffd8e42152' +down_revision = '2d29fc132e89' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column('user_session', sa.Column('is_unbounded', sa.Boolean(), nullable=True)) + op.execute("UPDATE user_session SET is_unbounded='false'") + op.alter_column('user_session', 'is_unbounded', nullable=False) + + +def downgrade(): + op.drop_column('user_session', 'is_unbounded') diff --git a/migrations/versions/dcb89e72d446_session_security_scopes.py b/migrations/versions/dcb89e72d446_session_security_scopes.py index 01c585f4..033b8bea 100644 --- a/migrations/versions/dcb89e72d446_session_security_scopes.py +++ b/migrations/versions/dcb89e72d446_session_security_scopes.py @@ -7,9 +7,7 @@ """ from alembic import op -from sqlalchemy.orm import Session - -from auth_backend.models.db import DynamicOption, Group, Scope, User, UserSession +from sqlalchemy.sql import text # revision identifiers, used by Alembic. @@ -21,31 +19,46 @@ def upgrade(): conn = op.get_bind() - session = Session(conn) - - root_group_id: DynamicOption = session.query(DynamicOption).filter(DynamicOption.name == "root_group_id").one() - users_group_id: DynamicOption = session.query(DynamicOption).filter(DynamicOption.name == "users_group_id").one() - - root_group: Group = Group.get(root_group_id.value_integer, session=session) - user_group: Group = Group.get(users_group_id.value_integer, session=session) - try: - user = root_group.users[0] - except IndexError: - user = User.create(session=session) - user.groups.append(root_group) - - scope1 = Scope(creator_id=user.id, name="auth.session.create", comment="Create user session") - scope2 = Scope(creator_id=user.id, name="auth.session.update", comment="Update user session") - session.add_all((scope1, scope2)) - session.flush() - root_group.scopes.update([scope1, scope2]) - user_group.scopes.update([scope1, scope2]) - session.flush() - user_sessions = UserSession.query(session=session).all() - for user_session in user_sessions: - user_session.scopes.extend((scope1, scope2)) - session.flush() - session.commit() + + query: str = 'SELECT value_integer FROM dynamic_option WHERE name=:option_name' + root_group_id: int = conn.execute(text(query).bindparams(option_name="root_group_id")).scalar() + users_group_id: int = conn.execute(text(query).bindparams(option_name="users_group_id")).scalar() + + query = 'SELECT user_id FROM user_group WHERE group_id=:group_id' + root_user_id = conn.execute(text(query).bindparams(group_id=root_group_id)).scalar() + if root_user_id is None: + query = ( + 'INSERT INTO "user" (is_deleted, create_ts, update_ts) VALUES (false, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)' + ) + conn.execute(text(query)) + + query = 'INSERT INTO "user_group" VALUES (:user_id, :group_id, false)' + root_user_id = conn.execute(text('SELECT id FROM "user" ORDER BY id DESC')).scalar() + conn.execute(text(query).bindparams(user_id=root_user_id, group_id=root_group_id)) + + query = 'INSERT INTO "scope" VALUES (:creator_id, :name, :comment, false)' + conn.execute( + text(query).bindparams(creator_id=root_user_id, name="auth.session.create", comment="Create user session") + ) + conn.execute( + text(query).bindparams(creator_id=root_user_id, name="auth.session.update", comment="Update user session") + ) + + query = 'SELECT id FROM scope WHERE name=:name' + scope1_id = conn.execute(text(query).bindparams(name="auth.session.create")).scalar() + scope2_id = conn.execute(text(query).bindparams(name="auth.session.update")).scalar() + + query = 'INSERT INTO "group_scope" VALUES (:group_id, :scope_id, false)' + conn.execute(text(query).bindparams(group_id=root_group_id, scope_id=scope1_id)) + conn.execute(text(query).bindparams(group_id=root_group_id, scope_id=scope2_id)) + conn.execute(text(query).bindparams(group_id=users_group_id, scope_id=scope1_id)) + conn.execute(text(query).bindparams(group_id=users_group_id, scope_id=scope2_id)) + + session_ids = conn.execute(text('SELECT id FROM user_session')).all() + query = 'INSERT INTO "user_session_scope" VALUES (:user_session_id, :scope_id, false)' + for session_id in session_ids: + conn.execute(text(query).bindparams(user_session_id=session_id[0], scope_id=scope1_id)) + conn.execute(text(query).bindparams(user_session_id=session_id[0], scope_id=scope2_id)) def downgrade(): diff --git a/migrations/versions/ed1a7f2276d4_merge_unbounded_and_verified.py b/migrations/versions/ed1a7f2276d4_merge_unbounded_and_verified.py new file mode 100644 index 00000000..aafb82ea --- /dev/null +++ b/migrations/versions/ed1a7f2276d4_merge_unbounded_and_verified.py @@ -0,0 +1,21 @@ +"""merge unbounded and verified + +Revision ID: ed1a7f2276d4 +Revises: 5d71a2a2405d, 6dffd8e42152 +Create Date: 2024-12-07 12:58:57.981808 + +""" + +# revision identifiers, used by Alembic. +revision = 'ed1a7f2276d4' +down_revision = ('5d71a2a2405d', '6dffd8e42152') +branch_labels = None +depends_on = None + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/tests/conftest.py b/tests/conftest.py index 68319eba..1f42b388 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,13 +4,13 @@ import pytest from fastapi.testclient import TestClient -from sqlalchemy import create_engine +from sqlalchemy import create_engine, func from sqlalchemy.orm import sessionmaker from starlette import status from auth_backend.auth_plugins import YandexAuth from auth_backend.models import AuthMethod, User -from auth_backend.models.db import AuthMethod, Group, GroupScope, Scope, User, UserGroup, UserSession, UserSessionScope +from auth_backend.models.db import AuthMethod, Group, Scope, User, UserGroup, UserSession, UserSessionScope from auth_backend.routes.base import app from auth_backend.settings import get_settings from auth_backend.utils.string import random_string @@ -103,6 +103,8 @@ def user(client_auth: TestClient, dbsession): for row in session: dbsession.delete(row) dbsession.commit() + for row in dbsession.query(UserGroup).filter(UserGroup.user_id == db_user.user_id).all(): + dbsession.delete(row) for row in dbsession.query(AuthMethod).filter(AuthMethod.user_id == db_user.user_id).all(): dbsession.delete(row) dbsession.delete(dbsession.query(User).filter(User.id == db_user.user_id).one()) @@ -115,7 +117,9 @@ def parent_id(client, dbsession): body = {"name": f"group{time}", "parent_id": None, "scopes": []} response = client.post(url="/group", json=body) yield response.json()["id"] - dbsession.query(Group).get(response.json()["id"]) + group: Group = Group.get(response.json()["id"], session=dbsession) + group.users.clear() + group.delete(id=response.json()["id"], session=dbsession) dbsession.commit() @@ -135,7 +139,7 @@ def _group(client: TestClient): for row in _ids: group: Group = Group.get(row, session=dbsession) group.users.clear() - group.delete(session=dbsession) + group.delete(id=row, session=dbsession) dbsession.commit() @@ -153,26 +157,8 @@ def _user(client): yield _user - for row in dbsession.query(UserGroup).all(): - dbsession.delete(row) - dbsession.flush() - - dbsession.query(GroupScope).delete() - dbsession.flush() - - dbsession.query(Scope).delete() - dbsession.flush() - - dbsession.query(Group).delete() - dbsession.flush() - - dbsession.query(AuthMethod).delete() - dbsession.flush() - - dbsession.query(UserSession).delete() - dbsession.flush() - - dbsession.query(User).delete() + for user in _users: + dbsession.delete(user) dbsession.commit() @@ -195,11 +181,15 @@ def user_scopes(dbsession, user): "auth.session.update", ] scopes = [] + created_scopes = [] for i in scopes_names: - dbsession.add(scope1 := Scope(name=i, creator_id=user_id)) + scope1 = Scope.query(session=dbsession).filter(func.lower(Scope.name) == i.lower()).one_or_none() + if scope1 is None: + dbsession.add(scope1 := Scope(name=i, creator_id=user_id)) + created_scopes.append(scope1) scopes.append(scope1) token_ = random_string() - dbsession.add(user_session := UserSession(user_id=user_id, token=token_)) + dbsession.add(user_session := UserSession(user_id=user_id, token=token_, is_unbounded=False)) dbsession.commit() user_scopes = [] for i in scopes: @@ -211,9 +201,9 @@ def user_scopes(dbsession, user): for i in user_scopes: dbsession.delete(i) dbsession.flush() - for i in scopes: + for i in created_scopes: dbsession.delete(i) - dbsession.delete(user_session) + dbsession.flush() dbsession.commit() diff --git a/tests/test_routes/test_email_message_delay.py b/tests/test_routes/test_email_message_delay.py index 64d84920..ebc9b9de 100644 --- a/tests/test_routes/test_email_message_delay.py +++ b/tests/test_routes/test_email_message_delay.py @@ -2,6 +2,7 @@ from sqlalchemy.orm import Session from starlette import status +from auth_backend.models.db import AuthMethod, User from auth_backend.settings import get_settings @@ -25,3 +26,13 @@ def test_message_delay(client_auth_email_delay: TestClient, dbsession: Session): assert delay_response.status_code == status.HTTP_429_TOO_MANY_REQUESTS settings_.IP_DELAY_TIME_IN_MINUTES = ip_delay settings_.EMAIL_DELAY_TIME_IN_MINUTES = email_delay + auth_method = ( + dbsession.query(AuthMethod) + .filter(AuthMethod.param == "email", AuthMethod.value == "test-user@profcomff.com") + .one() + ) + for row in dbsession.query(AuthMethod).filter(AuthMethod.user_id == auth_method.user_id).all(): + dbsession.delete(row) + dbsession.flush() + dbsession.delete(dbsession.query(User).filter(User.id == auth_method.user_id).one()) + dbsession.commit() diff --git a/tests/test_routes/test_group_scopes.py b/tests/test_routes/test_group_scopes.py index 1ce55b81..d79d3841 100644 --- a/tests/test_routes/test_group_scopes.py +++ b/tests/test_routes/test_group_scopes.py @@ -85,28 +85,19 @@ def test_scopes_user_session(client_auth, dbsession, user_scopes): assert response.status_code == 200 response = client_auth.patch(f"/user/{user_id}", json={"groups": [_group3]}, headers=headers) assert response.status_code == 200 - response = client_auth.post("/email/login", json=body_user | {"scopes": [scope1.name]}) + response = client_auth.post("/email/login", json=body_user) assert response.status_code == 200 token = response.json()["token"] - response = client_auth.post("/email/login", json=body_user | {"scopes": [scope2.name + "s"]}) - assert response.status_code == 404 response = client_auth.get("/me", headers={"Authorization": token}, params={"info": ["session_scopes"]}) assert response.status_code == 200 assert scope1.id in [row["id"] for row in response.json()["session_scopes"]] - response = client_auth.get("/me", headers={"Authorization": login["token"]}, params={"info": ["session_scopes"]}) - assert response.status_code == 200 - assert scope2.id not in [row["id"] for row in response.json()["session_scopes"]] response = client_auth.patch(f"/group/{_group3}", json={"scopes": [scope1.id, scope2.id]}, headers=headers) assert response.status_code == 200 - response = client_auth.post("/email/login", json=body_user | {"scopes": [scope1.name, scope2.name]}) + response = client_auth.post("/email/login", json=body_user) assert response.status_code == 200 token1 = response.json()["token"] - response = client_auth.post("/email/login", json=body_user | {"scopes": [scope2.name]}) - assert response.status_code == 200 - token2 = response.json()["token"] - response = client_auth.post("/email/login", json=body_user | {"scopes": [scope1.name]}) + response = client_auth.post("/email/login", json=body_user) assert response.status_code == 200 - token3 = response.json()["token"] response = client_auth.get( "/me", headers={"Authorization": token1}, params={"info": ["session_scopes", "user_scopes"]} ) @@ -116,27 +107,11 @@ def test_scopes_user_session(client_auth, dbsession, user_scopes): assert scope2.id in [row["id"] for row in response.json()["user_scopes"]] assert scope1.id in [row["id"] for row in response.json()["user_scopes"]] response = client_auth.get( - "/me", headers={"Authorization": token2}, params={"info": ["session_scopes", "user_scopes"]} + "/me", headers={"Authorization": login["token"]}, params={"info": ["session_scopes", "user_scopes"]} ) assert response.status_code == 200 assert scope2.id in [row["id"] for row in response.json()["session_scopes"]] - assert scope1.id not in [row["id"] for row in response.json()["session_scopes"]] - assert scope2.id in [row["id"] for row in response.json()["user_scopes"]] - assert scope1.id in [row["id"] for row in response.json()["user_scopes"]] - response = client_auth.get( - "/me", headers={"Authorization": token3}, params={"info": ["session_scopes", "user_scopes"]} - ) - assert response.status_code == 200 assert scope1.id in [row["id"] for row in response.json()["session_scopes"]] - assert scope2.id not in [row["id"] for row in response.json()["session_scopes"]] - assert scope2.id in [row["id"] for row in response.json()["user_scopes"]] - assert scope1.id in [row["id"] for row in response.json()["user_scopes"]] - response = client_auth.get( - "/me", headers={"Authorization": login["token"]}, params={"info": ["session_scopes", "user_scopes"]} - ) - assert response.status_code == 200 - assert scope2.id not in [row["id"] for row in response.json()["session_scopes"]] - assert scope1.id not in [row["id"] for row in response.json()["session_scopes"]] assert scope2.id in [row["id"] for row in response.json()["user_scopes"]] assert scope1.id in [row["id"] for row in response.json()["user_scopes"]] dbsession.query(GroupScope).filter(GroupScope.group_id == _group1).delete() diff --git a/tests/test_routes/test_groups.py b/tests/test_routes/test_groups.py index 9263bf62..c5f23fb6 100644 --- a/tests/test_routes/test_groups.py +++ b/tests/test_routes/test_groups.py @@ -35,8 +35,8 @@ def test_create(client, dbsession): assert parent.parent_id == response_parent.json()["parent_id"] assert parent.name == response_parent.json()["name"] - Group.delete(response.json()["id"], session=dbsession) - Group.delete(response_parent.json()["id"], session=dbsession) + for row in dbsession.query(Group).get(group.id), dbsession.query(Group).get(parent.id): + dbsession.delete(row) dbsession.commit() diff --git a/tests/test_routes/test_login.py b/tests/test_routes/test_login.py index 0599c6db..111e8242 100644 --- a/tests/test_routes/test_login.py +++ b/tests/test_routes/test_login.py @@ -4,7 +4,7 @@ from sqlalchemy.orm import Session from starlette import status -from auth_backend.models.db import AuthMethod, Group, User, UserGroup, UserSession +from auth_backend.models.db import AuthMethod, Group, GroupScope, Scope, User, UserGroup, UserSession url = "/email/login" @@ -134,3 +134,47 @@ def test_check_me_groups(client_auth: TestClient, user_scopes, dbsession): dbsession.query(Group).filter(Group.id == _group2).delete() dbsession.query(Group).filter(Group.id == _group1).delete() dbsession.commit() + + +def test_check_unbounded_session(client_auth: TestClient, user_scopes, dbsession): + token_, user = user_scopes + body_user = user["body"] + body_user["is_unbounded"] = True + scope1 = dbsession.query(Scope).filter(Scope.name == "auth.group.create").one() + time1 = datetime.datetime.utcnow() + body = {"name": f"group{time1}", "parent_id": None, "scopes": []} + headers = {"Authorization": token_} + _group1 = client_auth.post(url="/group", json=body, headers=headers).json()["id"] + client_auth.patch(f"/user/{user['user_id']}", json={"groups": [_group1]}, headers={"Authorization": token_}) + response = client_auth.post("/email/login", json=body_user) + assert response.status_code == status.HTTP_200_OK + token = response.json()["token"] + response = client_auth.get( + "/me", headers={"Authorization": token}, params={"info": ["session_scopes", "user_scopes"]} + ) + assert response.json()["session_scopes"] == response.json()["user_scopes"] + assert scope1.id not in [row["id"] for row in response.json()["session_scopes"]] + client_auth.patch(f"/group/{_group1}", json={"scopes": [scope1.id]}, headers=headers) + response = client_auth.get("/me", headers={"Authorization": token}, params={"info": ["session_scopes"]}) + assert scope1.id in [row["id"] for row in response.json()["session_scopes"]] + dbsession.query(GroupScope).filter(GroupScope.group_id == _group1).delete() + dbsession.query(UserGroup).filter(UserGroup.group_id == _group1).delete() + dbsession.query(Group).filter(Group.id == _group1).delete() + dbsession.commit() + + +def test_check_unbounded_session_scopes(client_auth: TestClient, user_scopes, dbsession): + token_, user = user_scopes + body_user = user["body"] + body_user["is_unbounded"] = True + scope1 = dbsession.query(Scope).filter(Scope.name == "auth.session.create").one() + scope2 = dbsession.query(Scope).filter(Scope.name == "auth.session.update").one() + response = client_auth.post("/email/login", json=body_user | {"scopes": [scope1.name]}) + assert response.status_code == status.HTTP_200_OK + token = response.json()["token"] + response = client_auth.get( + "/me", headers={"Authorization": token}, params={"info": ["session_scopes", "user_scopes"]} + ) + assert response.json()["session_scopes"] == response.json()["user_scopes"] + assert scope1.id in [row["id"] for row in response.json()["session_scopes"]] + assert scope2.id in [row["id"] for row in response.json()["session_scopes"]] diff --git a/tests/test_routes/test_registration.py b/tests/test_routes/test_registration.py index b7465464..6326c366 100644 --- a/tests/test_routes/test_registration.py +++ b/tests/test_routes/test_registration.py @@ -12,7 +12,7 @@ url = "/email/registration" -def test_invalid_email(client_auth: TestClient): +def test_invalid_email(client_auth: TestClient, dbsession: Session): body1 = {"email": f"notEmailForSure", "password": "string"} body2 = {"email": f"EmailForSure{datetime.datetime.utcnow()}@mail.gtg", "password": ""} body3 = { @@ -38,6 +38,17 @@ def test_invalid_email(client_auth: TestClient): response = client_auth.post(url, json=body6) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + ids = [] + for email in [body3["email"], body4["email"], body5["email"]]: + ids.append( + dbsession.query(AuthMethod).filter(AuthMethod.param == "email", AuthMethod.value == email).one().user_id + ) + for user_id in ids: + for row in dbsession.query(AuthMethod).filter(AuthMethod.user_id == user_id).all(): + dbsession.delete(row) + dbsession.delete(dbsession.query(User).filter(User.id == user_id).one()) + dbsession.commit() + def test_main_scenario(client_auth: TestClient, dbsession: Session): time = datetime.datetime.utcnow() diff --git a/tests/test_routes/test_user.py b/tests/test_routes/test_user.py index bbb59a7c..68378e9e 100644 --- a/tests/test_routes/test_user.py +++ b/tests/test_routes/test_user.py @@ -19,7 +19,10 @@ def test_user_email(client: TestClient, dbsession: Session, user_factory): resp = client.patch(f"/user/{user1}", json={"groups": [group]}) assert resp.status_code == 200 assert "email" not in resp.json().keys() + dbsession.delete(email_user) + for row in dbsession.query(UserGroup).filter(UserGroup.user_id == user1).all(): + dbsession.delete(row) gr = Group.get(group, session=dbsession) dbsession.delete(gr) dbsession.commit() @@ -40,8 +43,9 @@ def test_delete_user(client_auth: TestClient, dbsession: Session, user_factory, assert resp.status_code == 200 user = dbsession.query(User).filter(User.id == user1).one_or_none() assert user.is_deleted - dbsession.delete(email_user) dbsession.query(GroupScope).filter(GroupScope.group_id == group).delete() - dbsession.query(UserGroup).filter(UserGroup.group_id == group).delete() + for row in dbsession.query(UserGroup).filter(UserGroup.user_id == user1).all(): + dbsession.delete(row) dbsession.query(Group).filter(Group.id == group).delete() + dbsession.delete(email_user) dbsession.commit() diff --git a/tests/test_routes/test_user_groups.py b/tests/test_routes/test_user_groups.py index db7f87ed..249b885b 100644 --- a/tests/test_routes/test_user_groups.py +++ b/tests/test_routes/test_user_groups.py @@ -25,6 +25,9 @@ def test_add_user(client: TestClient, dbsession: Session, user_factory): user = User.get(usergroup.user_id, session=dbsession) assert user in gr.users assert gr in user.groups + + for row in dbsession.query(UserGroup).filter(UserGroup.user_id == user1).all(): + dbsession.delete(row) dbsession.delete(gr) dbsession.commit() @@ -66,6 +69,13 @@ def test_get_user_list(client, dbsession, user_factory): assert us2 in gr.users assert us3 in gr.users + for user_id in [user1, user2, user3]: + for row in dbsession.query(UserGroup).filter(UserGroup.user_id == user_id).all(): + dbsession.delete(row) + dbsession.commit() + dbsession.delete(gr) + dbsession.commit() + def test_del_user_from_group(client, dbsession, user_factory): time1 = datetime.utcnow() @@ -99,3 +109,12 @@ def test_del_user_from_group(client, dbsession, user_factory): assert us1 in gr.users assert us2 not in gr.users assert us3 in gr.users + gr.users.clear() + gr.delete(id=group, session=dbsession) + dbsession.commit() + + for user_id in [user1, user2, user3]: + for row in dbsession.query(UserGroup).filter(UserGroup.user_id == user_id).all(): + dbsession.delete(row) + dbsession.delete(gr) + dbsession.commit() diff --git a/tests/test_routes/test_user_sessions.py b/tests/test_routes/test_user_sessions.py index 049dd9bc..8fc934cb 100644 --- a/tests/test_routes/test_user_sessions.py +++ b/tests/test_routes/test_user_sessions.py @@ -177,7 +177,7 @@ def test_patch_session(client_auth: TestClient, dbsession: Session, user_scopes) token = user_scopes[0] header = {"Authorization": token} params = {"info": ["session_scopes", "token", "expires"]} - payload = {"session_name": "test_session"} + payload = {"session_name": "test_session", "is_unbounded": False} new_session1 = client_auth.post("/session", headers=header, json=payload) assert new_session1.status_code == status.HTTP_200_OK assert new_session1.json()['session_name'] == payload['session_name'] @@ -193,3 +193,28 @@ def test_patch_session(client_auth: TestClient, dbsession: Session, user_scopes) for session in get_patch_session2.json(): if session['id'] == new_session1.json()['id']: assert session["session_scopes"] == [] + + +def test_create_unbounded_session(client_auth: TestClient, user_scopes, dbsession): + token_, user = user_scopes + scope1 = dbsession.query(Scope).filter(Scope.name == "auth.group.create").one() + time1 = datetime.utcnow() + body = {"name": f"group{time1}", "parent_id": None, "scopes": []} + headers = {"Authorization": token_} + _group1 = client_auth.post(url="/group", json=body, headers=headers).json()["id"] + client_auth.patch(f"/user/{user['user_id']}", json={"groups": [_group1]}, headers={"Authorization": token_}) + response = client_auth.post("/session", json={"is_unbounded": True}, headers=headers) + assert response.status_code == status.HTTP_200_OK + token = response.json()["token"] + response = client_auth.get( + "/me", headers={"Authorization": token}, params={"info": ["session_scopes", "user_scopes"]} + ) + assert response.json()["session_scopes"] == response.json()["user_scopes"] + assert scope1.id not in [row["id"] for row in response.json()["session_scopes"]] + client_auth.patch(f"/group/{_group1}", json={"scopes": [scope1.id]}, headers=headers) + response = client_auth.get("/me", headers={"Authorization": token}, params={"info": ["session_scopes"]}) + assert scope1.id in [row["id"] for row in response.json()["session_scopes"]] + dbsession.query(GroupScope).filter(GroupScope.group_id == _group1).delete() + dbsession.query(UserGroup).filter(UserGroup.group_id == _group1).delete() + dbsession.query(Group).filter(Group.id == _group1).delete() + dbsession.commit()