diff --git a/rating_api/models/db.py b/rating_api/models/db.py index 4724411..28f5a4f 100644 --- a/rating_api/models/db.py +++ b/rating_api/models/db.py @@ -247,6 +247,36 @@ def order_by_like_diff(cls, asc_order: bool = False): else: return cls.like_dislike_diff.desc() + @hybrid_method + def has_reaction(self, user_id: int, react: Reaction) -> bool: + return any(reaction.user_id == user_id and reaction.reaction == react for reaction in self.reactions) + + @has_reaction.expression + def has_reaction(cls, user_id: int, react: Reaction): + return ( + select([true()]) + .where( + and_( + CommentReaction.comment_uuid == cls.uuid, + CommentReaction.user_id == user_id, + CommentReaction.reaction == react, + ) + ) + .exists() + ) + + @classmethod + def reactions_for_comments(cls, user_id: int, session, comments): + if not user_id or not comments: + return {} + comments_uuid = [c.uuid for c in comments] + reactions = ( + session.query(CommentReaction) + .filter(CommentReaction.user_id == user_id, CommentReaction.comment_uuid.in_(comments_uuid)) + .all() + ) + return {r.comment_uuid: r.reaction for r in reactions} + class LecturerUserComment(BaseDbModel): id: Mapped[int] = mapped_column(Integer, primary_key=True) diff --git a/rating_api/routes/comment.py b/rating_api/routes/comment.py index 9011277..5012afc 100644 --- a/rating_api/routes/comment.py +++ b/rating_api/routes/comment.py @@ -22,8 +22,10 @@ CommentGet, CommentGetAll, CommentGetAllWithAllInfo, + CommentGetAllWithLike, CommentGetAllWithStatus, CommentGetWithAllInfo, + CommentGetWithLike, CommentGetWithStatus, CommentImportAll, CommentPost, @@ -155,18 +157,25 @@ async def import_comments( return result -@comment.get("/{uuid}", response_model=CommentGet) -async def get_comment(uuid: UUID) -> CommentGet: +@comment.get("/{uuid}", response_model=CommentGetWithLike) +async def get_comment(uuid: UUID, user=Depends(UnionAuth())) -> CommentGetWithLike: """ Возвращает комментарий по его UUID в базе данных RatingAPI """ comment: Comment = Comment.query(session=db.session).filter(Comment.uuid == uuid).one_or_none() if comment is None: raise ObjectNotFound(Comment, uuid) - return CommentGet.model_validate(comment) + base_data = CommentGet.model_validate(comment) + return CommentGetWithLike( + **base_data.model_dump(), + is_liked=comment.has_reaction(user.get("id"), Reaction.LIKE), + is_disliked=comment.has_reaction(user.get("id"), Reaction.DISLIKE), + ) -@comment.get("", response_model=Union[CommentGetAll, CommentGetAllWithAllInfo, CommentGetAllWithStatus]) +@comment.get( + "", response_model=Union[CommentGetAll, CommentGetAllWithLike, CommentGetAllWithAllInfo, CommentGetAllWithStatus] +) async def get_comments( limit: int = 10, offset: int = 0, @@ -180,7 +189,7 @@ async def get_comments( unreviewed: bool = False, asc_order: bool = False, user=Depends(UnionAuth(scopes=["rating.comment.review"], auto_error=False, allow_none=False)), -) -> CommentGetAll: +) -> Union[CommentGetAll, CommentGetAllWithLike, CommentGetAllWithAllInfo, CommentGetAllWithStatus]: """ Scopes: `["rating.comment.review"]` @@ -203,6 +212,7 @@ async def get_comments( `asc_order` -Если передано true, сортировать в порядке возрастания. Иначе - в порядке убывания """ + comments_query = ( Comment.query(session=db.session) .filter(Comment.search_by_lectorer_id(lecturer_id)) @@ -219,6 +229,7 @@ async def get_comments( ) ) comments = comments_query.limit(limit).offset(offset).all() + like = False if not comments: raise ObjectNotFound(Comment, 'all') if user and "rating.comment.review" in [scope['name'] for scope in user.get('session_scopes')]: @@ -228,8 +239,13 @@ async def get_comments( result = CommentGetAllWithStatus(limit=limit, offset=offset, total=len(comments)) comment_validator = CommentGetWithStatus else: - result = CommentGetAll(limit=limit, offset=offset, total=len(comments)) + result = ( + CommentGetAllWithLike(limit=limit, offset=offset, total=len(comments)) + if user + else CommentGetAll(limit=limit, offset=offset, total=len(comments)) + ) comment_validator = CommentGet + like = True result.comments = comments @@ -244,8 +260,26 @@ async def get_comments( result.comments = [comment for comment in result.comments if comment.review_status is ReviewStatus.APPROVED] result.total = len(result.comments) - result.comments = [comment_validator.model_validate(comment) for comment in result.comments] + comments_with_like = [] + current_user_id = user.get("id") if user else None + if current_user_id and result.comments: + user_reactions = Comment.reactions_for_comments(current_user_id, db.session, result.comments) + else: + user_reactions = {} + + for comment in result.comments: + base_data = comment_validator.model_validate(comment) + + if current_user_id: + reaction = user_reactions.get(comment.uuid) + comment_with_reactions = CommentGetWithLike( + **base_data.model_dump(), is_liked=reaction == Reaction.LIKE, is_disliked=reaction == Reaction.DISLIKE + ) + comments_with_like.append(comment_with_reactions) + else: + comments_with_like.append(base_data) + result.comments = comments_with_like return result diff --git a/rating_api/schemas/models.py b/rating_api/schemas/models.py index ad1b0a0..38d61e0 100644 --- a/rating_api/schemas/models.py +++ b/rating_api/schemas/models.py @@ -27,15 +27,33 @@ class CommentGet(Base): dislike_count: int +class CommentGetWithLike(CommentGet): + is_liked: bool + is_disliked: bool + + class CommentGetWithStatus(CommentGet): review_status: ReviewStatus +""" +class CommentGetWithLikeAndStatus(CommentGetWithLike): + review_status: ReviewStatus +""" + + class CommentGetWithAllInfo(CommentGet): review_status: ReviewStatus approved_by: int | None = None +""" +class CommentGetWithAllInfoAndLike(CommentGetWithLike): + review_status: ReviewStatus + approved_by: int | None = None +""" + + class CommentUpdate(Base): subject: str = None text: str = None @@ -74,6 +92,13 @@ class CommentGetAll(Base): total: int +class CommentGetAllWithLike(Base): + comments: list[CommentGetWithLike] = [] + limit: int + offset: int + total: int + + class CommentGetAllWithStatus(Base): comments: list[CommentGetWithStatus] = [] limit: int @@ -81,6 +106,15 @@ class CommentGetAllWithStatus(Base): total: int +""" +class CommentGetAllWithStatusAndLike(Base): + comments: list[CommentGetWithLikeAndStatus] = [] + limit: int + offset: int + total: int +""" + + class CommentGetAllWithAllInfo(Base): comments: list[CommentGetWithAllInfo] = [] limit: int @@ -88,6 +122,15 @@ class CommentGetAllWithAllInfo(Base): total: int +""" +class CommentGetAllWithAllInfoAndLike(Base): + comments: list[CommentGetWithAllInfoAndLike] = [] + limit: int + offset: int + total: int +""" + + class LecturerUserCommentPost(Base): lecturer_id: int user_id: int diff --git a/tests/test_routes/test_comment.py b/tests/test_routes/test_comment.py index 5f62bba..065fa5d 100644 --- a/tests/test_routes/test_comment.py +++ b/tests/test_routes/test_comment.py @@ -1,6 +1,5 @@ import datetime import logging -import uuid import pytest from starlette import status @@ -196,13 +195,42 @@ def test_create_comment(client, dbsession, lecturers, body, lecturer_n, response assert user_comment is not None -def test_get_comment(client, comment): +@pytest.mark.parametrize( + "reaction_data, expected_reaction, comment_user_id", + [ + (None, None, 0), + ((0, Reaction.LIKE), "is_liked", 0), # my like on my comment + ((0, Reaction.DISLIKE), "is_disliked", 0), + ((999, Reaction.LIKE), None, 0), # someone else's like on my comment + ((999, Reaction.DISLIKE), None, 0), + ((0, Reaction.LIKE), "is_liked", 999), # my like on someone else's comment + ((0, Reaction.DISLIKE), "is_disliked", 999), + ((333, Reaction.LIKE), None, 999), # someone else's like on another person's comment + ((333, Reaction.DISLIKE), None, 999), + (None, None, None), # anonymous + ], +) +def test_get_comment_with_reaction(client, dbsession, comment, reaction_data, expected_reaction, comment_user_id): + comment.user_id = comment_user_id + + if reaction_data: + user_id, reaction_type = reaction_data + reaction = CommentReaction(user_id=user_id, comment_uuid=comment.uuid, reaction=reaction_type) + dbsession.add(reaction) + + dbsession.commit() + response_comment = client.get(f'{url}/{comment.uuid}') - print("1") - assert response_comment.status_code == status.HTTP_200_OK - random_uuid = uuid.uuid4() - response = client.get(f'{url}/{random_uuid}') - assert response.status_code == status.HTTP_404_NOT_FOUND + + if response_comment: + data = response_comment.json() + if expected_reaction: + assert data[expected_reaction] + else: + assert data["is_liked"] == False + assert data["is_disliked"] == False + else: + assert response_comment.status_code == status.HTTP_404_NOT_FOUND @pytest.fixture