diff --git a/forum/api/threads.py b/forum/api/threads.py index b5b036c5..f04bcc80 100644 --- a/forum/api/threads.py +++ b/forum/api/threads.py @@ -362,6 +362,7 @@ def get_user_threads( user_id: Optional[str] = None, group_id: Optional[int] = None, group_ids: Optional[int] = None, + context: Optional[str] = None, **kwargs: Any, ) -> dict[str, Any]: """ @@ -385,6 +386,7 @@ def get_user_threads( "user_id": user_id, "group_id": group_id, "group_ids": group_ids, + "context": context, } params = {k: v for k, v in params.items() if v is not None} backend.validate_params(params) diff --git a/forum/backends/mongodb/api.py b/forum/backends/mongodb/api.py index b279ac8e..4c3dd78b 100644 --- a/forum/backends/mongodb/api.py +++ b/forum/backends/mongodb/api.py @@ -948,6 +948,7 @@ def validate_params(params: dict[str, Any], user_id: Optional[str] = None) -> No "commentable_ids", "group_id", "group_ids", + "context", ] if not user_id: valid_params.append("user_id") @@ -992,6 +993,7 @@ def get_threads( int(params.get("per_page", 100)), commentable_ids=params.get("commentable_ids", []), is_moderator=params.get("is_moderator", False), + context=params.get("context", "course"), ) context: dict[str, Any] = { "count_flagged": count_flagged, diff --git a/forum/backends/mysql/api.py b/forum/backends/mysql/api.py index c383d31a..9120ac15 100644 --- a/forum/backends/mysql/api.py +++ b/forum/backends/mysql/api.py @@ -1105,6 +1105,7 @@ def validate_params( "commentable_ids", "group_id", "group_ids", + "context", ] if not user_id: valid_params.append("user_id") @@ -1158,6 +1159,7 @@ def get_threads( params.get("sort_key", ""), int(params.get("page", 1)), int(params.get("per_page", 100)), + context=params.get("context", "course"), commentable_ids=params.get("commentable_ids", []), is_moderator=params.get("is_moderator", False), ) diff --git a/tests/test_views/test_threads.py b/tests/test_views/test_threads.py index ca28d864..46892158 100644 --- a/tests/test_views/test_threads.py +++ b/tests/test_views/test_threads.py @@ -447,6 +447,43 @@ def test_unresponded_filter(api_client: APIClient, patched_get_backend: Any) -> assert len(thread) == 1 +def test_get_user_threads_context( + api_client: APIClient, patched_get_backend: Any +) -> None: + """Test get_user_threads filters threads by context.""" + backend = patched_get_backend + user_id, course_thread_id = setup_models(backend=backend) + standalone_thread_id = backend.create_thread( + { + "title": "Standalone Thread", + "body": "Standalone Thread", + "course_id": "course1", + "commentable_id": "CommentThread", + "author_id": user_id, + "author_username": "user1", + "abuse_flaggers": [], + "historical_abuse_flaggers": [], + "context": "standalone", + } + ) + + # Default (course) context: only the course thread is returned + response = api_client.get_json("/api/v2/threads", {"course_id": "course1"}) + assert response.status_code == 200 + ids = [t["id"] for t in response.json()["collection"]] + assert course_thread_id in ids + assert standalone_thread_id not in ids + + # Explicit standalone context: only the standalone thread is returned + response = api_client.get_json( + "/api/v2/threads", {"course_id": "course1", "context": "standalone"} + ) + assert response.status_code == 200 + ids = [t["id"] for t in response.json()["collection"]] + assert standalone_thread_id in ids + assert course_thread_id not in ids + + def test_filter_by_post_type(api_client: APIClient, patched_get_backend: Any) -> None: """Test filter threads by thread_type through get thread API.""" backend = patched_get_backend