From fce3addde656c06064a47e81974451c9e6a64fec Mon Sep 17 00:00:00 2001
From: Quitterie Lucas
Date: Tue, 5 Sep 2023 11:29:40 +0200
Subject: [PATCH 01/65] =?UTF-8?q?=F0=9F=8E=A8(tests)=20rename=20converter?=
 =?UTF-8?q?=20tests=20with=20path=20pattern?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Some tests in the edx to xAPI converters were named without a pattern.
The path pattern used in other tests is now applied to these tests.
---
 tests/models/edx/converters/xapi/test_base.py |  6 ++--
 .../edx/converters/xapi/test_navigational.py  |  2 +-
 .../models/edx/converters/xapi/test_server.py |  6 ++--
 .../models/edx/converters/xapi/test_video.py  | 30 ++++++++++++-------
 4 files changed, 27 insertions(+), 17 deletions(-)

diff --git a/tests/models/edx/converters/xapi/test_base.py b/tests/models/edx/converters/xapi/test_base.py
index 6ef84d66c..d65fa2ef2 100644
--- a/tests/models/edx/converters/xapi/test_base.py
+++ b/tests/models/edx/converters/xapi/test_base.py
@@ -9,7 +9,7 @@


 @pytest.mark.parametrize("uuid_namespace", ["ee241f8b-174f-5bdb-bae9-c09de5fe017f"])
-def test_base_xapi_converter_successful_initialization(
+def test_models_edx_converters_xapi_base_xapi_converter_successful_initialization(
     uuid_namespace,
 ):
     """Test BaseXapiConverter initialization."""
@@ -26,8 +26,8 @@ def _get_conversion_items(self):  # pylint: disable=no-self-use
     assert converter.uuid_namespace == UUID(uuid_namespace)


-def test_base_xapi_converter_unsuccessful_initialization():
-    """Test BaseXapiConverter failed initialization."""
+def test_models_edx_converters_xapi_base_xapi_converter_unsuccessful_initialization():
+    """Tests BaseXapiConverter failed initialization."""

     class DummyBaseXapiConverter(BaseXapiConverter):
         """Dummy implementation of abstract BaseXapiConverter."""
diff --git a/tests/models/edx/converters/xapi/test_navigational.py b/tests/models/edx/converters/xapi/test_navigational.py
index b49565303..851ee8bd6 100644
--- a/tests/models/edx/converters/xapi/test_navigational.py
+++ b/tests/models/edx/converters/xapi/test_navigational.py
@@ -15,7 +15,7 @@

 @custom_given(UIPageClose, provisional.urls())
 @pytest.mark.parametrize("uuid_namespace", ["ee241f8b-174f-5bdb-bae9-c09de5fe017f"])
-def test_navigational_ui_page_close_to_page_terminated(
+def test_models_edx_converters_xapi_navigational_ui_page_close_to_page_terminated(
     uuid_namespace, event, platform_url
 ):
     """Test that converting with UIPageCloseToPageTerminated returns the expected xAPI
diff --git a/tests/models/edx/converters/xapi/test_server.py b/tests/models/edx/converters/xapi/test_server.py
index df787b503..bc3e68f6a 100644
--- a/tests/models/edx/converters/xapi/test_server.py
+++ b/tests/models/edx/converters/xapi/test_server.py
@@ -15,7 +15,7 @@

 @custom_given(Server, provisional.urls())
 @pytest.mark.parametrize("uuid_namespace", ["ee241f8b-174f-5bdb-bae9-c09de5fe017f"])
-def test_models_edx_converters_xapi_server_server_event_to_xapi_convert_constant_uuid(
+def test_models_edx_converters_xapi_server_server_event_to_page_viewed_constant_uuid(
     uuid_namespace, event, platform_url
 ):
     """Test that `ServerEventToPageViewed.convert` returns a JSON string with a
@@ -35,7 +35,7 @@

 # pylint: disable=line-too-long
 @custom_given(Server, provisional.urls())
 @pytest.mark.parametrize("uuid_namespace", ["ee241f8b-174f-5bdb-bae9-c09de5fe017f"])
-def test_models_edx_converters_xapi_server_server_event_to_xapi_convert_with_valid_event(  # noqa
+def 
test_models_edx_converters_xapi_server_server_event_to_page_viewed( uuid_namespace, event, platform_url ): """Test that converting with `ServerEventToPageViewed` returns the expected xAPI @@ -74,7 +74,7 @@ def test_models_edx_converters_xapi_server_server_event_to_xapi_convert_with_val @settings(deadline=None) @custom_given(Server, provisional.urls()) @pytest.mark.parametrize("uuid_namespace", ["ee241f8b-174f-5bdb-bae9-c09de5fe017f"]) -def test_models_edx_converters_xapi_server_server_event_to_xapi_convert_with_anonymous_user( # noqa +def test_models_edx_converters_xapi_server_server_event_to_page_viewed_with_anonymous_user( # noqa: E501, pylint:disable=line-too-long uuid_namespace, event, platform_url ): """Test that anonymous usernames are replaced with `anonymous`.""" diff --git a/tests/models/edx/converters/xapi/test_video.py b/tests/models/edx/converters/xapi/test_video.py index 4abcb8234..0c4367433 100644 --- a/tests/models/edx/converters/xapi/test_video.py +++ b/tests/models/edx/converters/xapi/test_video.py @@ -27,8 +27,10 @@ @custom_given(UILoadVideo, provisional.urls()) @pytest.mark.parametrize("uuid_namespace", ["ee241f8b-174f-5bdb-bae9-c09de5fe017f"]) -def test_ui_load_video_to_video_initialized(uuid_namespace, event, platform_url): - """Test that converting with `UILoadVideoToVideoInitialized` returns the +def test_models_edx_converters_xapi_video_ui_load_video_to_video_initialized( + uuid_namespace, event, platform_url +): + """Tests that converting with `UILoadVideoToVideoInitialized` returns the expected xAPI statement. """ @@ -83,8 +85,10 @@ def test_ui_load_video_to_video_initialized(uuid_namespace, event, platform_url) @custom_given(UIPlayVideo, provisional.urls()) @pytest.mark.parametrize("uuid_namespace", ["ee241f8b-174f-5bdb-bae9-c09de5fe017f"]) -def test_ui_play_video_to_video_played(uuid_namespace, event, platform_url): - """Test that converting with `UIPlayVideoToVideoPlayed` returns the expected +def test_models_edx_converters_xapi_video_ui_play_video_to_video_played( + uuid_namespace, event, platform_url +): + """Tests that converting with `UIPlayVideoToVideoPlayed` returns the expected xAPI statement. """ @@ -143,8 +147,10 @@ def test_ui_play_video_to_video_played(uuid_namespace, event, platform_url): @custom_given(UIPauseVideo, provisional.urls()) @pytest.mark.parametrize("uuid_namespace", ["ee241f8b-174f-5bdb-bae9-c09de5fe017f"]) -def test_ui_pause_video_to_video_paused(uuid_namespace, event, platform_url): - """Test that converting with `UIPauseVideoToVideoPaused` returns the expected xAPI +def test_models_edx_converters_xapi_video_ui_pause_video_to_video_paused( + uuid_namespace, event, platform_url +): + """Tests that converting with `UIPauseVideoToVideoPaused` returns the expected xAPI statement. """ @@ -204,8 +210,10 @@ def test_ui_pause_video_to_video_paused(uuid_namespace, event, platform_url): @custom_given(UIStopVideo, provisional.urls()) @pytest.mark.parametrize("uuid_namespace", ["ee241f8b-174f-5bdb-bae9-c09de5fe017f"]) -def test_ui_stop_video_to_video_terminated(uuid_namespace, event, platform_url): - """Test that converting with `UIStopVideoToVideoTerminated` returns the expected +def test_models_edx_converters_xapi_video_ui_stop_video_to_video_terminated( + uuid_namespace, event, platform_url +): + """Tests that converting with `UIStopVideoToVideoTerminated` returns the expected xAPI statement. 
""" @@ -266,8 +274,10 @@ def test_ui_stop_video_to_video_terminated(uuid_namespace, event, platform_url): @custom_given(UISeekVideo, provisional.urls()) @pytest.mark.parametrize("uuid_namespace", ["ee241f8b-174f-5bdb-bae9-c09de5fe017f"]) -def test_ui_seek_video_to_video_seeked(uuid_namespace, event, platform_url): - """Test that converting with `UISeekVideoToVideoSeeked` returns the expected +def test_models_edx_converters_xapi_video_ui_seek_video_to_video_seeked( + uuid_namespace, event, platform_url +): + """Tests that converting with `UISeekVideoToVideoSeeked` returns the expected xAPI statement. """ From 773a14cbd995d4c2f33c57e6555810ba19d7507a Mon Sep 17 00:00:00 2001 From: Quitterie Lucas Date: Tue, 5 Sep 2023 11:32:46 +0200 Subject: [PATCH 02/65] =?UTF-8?q?=F0=9F=94=A5(models)=20remove=20context?= =?UTF-8?q?=20extensions=20for=20base=20edx=20to=20xapi=20converter?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `school`, `course` and `module` context extensions used in base edx-to-xapi converter can no longer be used. The tracking of the course information depends on the target statements templates used in xAPI profiles. --- CHANGELOG.md | 1 + src/ralph/models/edx/converters/xapi/base.py | 33 ---------- tests/models/edx/converters/xapi/test_base.py | 16 ----- .../edx/converters/xapi/test_enrollment.py | 2 - .../edx/converters/xapi/test_navigational.py | 2 - .../models/edx/converters/xapi/test_server.py | 2 - .../models/edx/converters/xapi/test_video.py | 15 ----- tests/models/test_converter.py | 64 ------------------- 8 files changed, 1 insertion(+), 134 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bee4bf1af..6a75a1a41 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,7 @@ have an authority field matching that of the user ### Removed +- `school`, `course`, `module` context extensions in Edx to xAPI base converter - `name` field in `VideoActivity` xAPI model mistakenly used in `video` profile ## [3.9.0] - 2023-07-21 diff --git a/src/ralph/models/edx/converters/xapi/base.py b/src/ralph/models/edx/converters/xapi/base.py index 1e8c43729..da7a413e8 100644 --- a/src/ralph/models/edx/converters/xapi/base.py +++ b/src/ralph/models/edx/converters/xapi/base.py @@ -1,17 +1,9 @@ """Base xAPI Converter.""" -import re from uuid import UUID, uuid5 from ralph.exceptions import ConfigurationException from ralph.models.converter import BaseConversionSet, ConversionItem -from ralph.models.xapi.concepts.constants.acrossx_profile import ( - CONTEXT_EXTENSION_SCHOOL_ID, -) -from ralph.models.xapi.concepts.constants.scorm_profile import ( - CONTEXT_EXTENSION_COURSE_ID, - CONTEXT_EXTENSION_MODULE_ID, -) class BaseXapiConverter(BaseConversionSet): @@ -52,30 +44,5 @@ def _get_conversion_items(self): "context__user_id", lambda user_id: str(user_id) if user_id else "anonymous", ), - ConversionItem( - "object__definition__extensions__" + CONTEXT_EXTENSION_SCHOOL_ID, - "context__org_id", - ), - ConversionItem( - "object__definition__extensions__" + CONTEXT_EXTENSION_COURSE_ID, - "context__course_id", - (self.parse_course_id, lambda x: x["course"]), - ), - ConversionItem( - "object__definition__extensions__" + CONTEXT_EXTENSION_MODULE_ID, - "context__course_id", - (self.parse_course_id, lambda x: x["module"]), - ), ConversionItem("timestamp", "time"), } - - @staticmethod - def parse_course_id(course_id: str): - """Parse edX event's `context`.`course_id`. - - Return a dictionary with `course` and `module`. 
- """ - match = re.match(r"^course-v1:.+\+(.+)\+(.+)$", course_id) - if not match: - return {"course": None, "module": None} - return {"course": match.group(1), "module": match.group(2)} diff --git a/tests/models/edx/converters/xapi/test_base.py b/tests/models/edx/converters/xapi/test_base.py index d65fa2ef2..f7babb58f 100644 --- a/tests/models/edx/converters/xapi/test_base.py +++ b/tests/models/edx/converters/xapi/test_base.py @@ -38,19 +38,3 @@ def _get_conversion_items(self): # pylint: disable=no-self-use with pytest.raises(ConfigurationException, match="Invalid UUID namespace"): DummyBaseXapiConverter(None, "https://fun-mooc.fr") - - -@pytest.mark.parametrize( - "course_id,expected", - [ - ("", {"course": None, "module": None}), - ("course-v1:+course+not_empty", {"course": None, "module": None}), - ("course-v1:org", {"course": None, "module": None}), - ("course-v1:org+course", {"course": None, "module": None}), - ("course-v1:org+course+", {"course": None, "module": None}), - ("course-v1:org+course+module", {"course": "course", "module": "module"}), - ], -) -def test_base_xapi_converter_parse_course_id(course_id, expected): - """Test that the parse_course_id method returns the expected value.""" - assert BaseXapiConverter.parse_course_id(course_id) == expected diff --git a/tests/models/edx/converters/xapi/test_enrollment.py b/tests/models/edx/converters/xapi/test_enrollment.py index 8cf28935f..6fb975827 100644 --- a/tests/models/edx/converters/xapi/test_enrollment.py +++ b/tests/models/edx/converters/xapi/test_enrollment.py @@ -29,7 +29,6 @@ def test_models_edx_converters_xapi_enrollment_edx_course_enrollment_activated_t """ event.event.course_id = "edX/DemoX/Demo_Course" - event.context.org_id = "" event.context.user_id = "1" event_str = event.json() event = json.loads(event_str) @@ -78,7 +77,6 @@ def test_models_edx_converters_xapi_enrollment_edx_course_enrollment_deactivated """ event.event.course_id = "edX/DemoX/Demo_Course" - event.context.org_id = "" event.context.user_id = "1" event_str = event.json() event = json.loads(event_str) diff --git a/tests/models/edx/converters/xapi/test_navigational.py b/tests/models/edx/converters/xapi/test_navigational.py index 851ee8bd6..011d1c622 100644 --- a/tests/models/edx/converters/xapi/test_navigational.py +++ b/tests/models/edx/converters/xapi/test_navigational.py @@ -21,8 +21,6 @@ def test_models_edx_converters_xapi_navigational_ui_page_close_to_page_terminate """Test that converting with UIPageCloseToPageTerminated returns the expected xAPI statement. """ - event.context.course_id = "" - event.context.org_id = "" event.context.user_id = "1" event_str = event.json() event = json.loads(event_str) diff --git a/tests/models/edx/converters/xapi/test_server.py b/tests/models/edx/converters/xapi/test_server.py index bc3e68f6a..bd27a18de 100644 --- a/tests/models/edx/converters/xapi/test_server.py +++ b/tests/models/edx/converters/xapi/test_server.py @@ -42,8 +42,6 @@ def test_models_edx_converters_xapi_server_server_event_to_page_viewed( statement. 
""" event.event_type = "/main/blog" - event.context.course_id = "" - event.context.org_id = "" event.context.user_id = "1" event_str = event.json() event = json.loads(event_str) diff --git a/tests/models/edx/converters/xapi/test_video.py b/tests/models/edx/converters/xapi/test_video.py index 0c4367433..ebafc0674 100644 --- a/tests/models/edx/converters/xapi/test_video.py +++ b/tests/models/edx/converters/xapi/test_video.py @@ -33,9 +33,6 @@ def test_models_edx_converters_xapi_video_ui_load_video_to_video_initialized( """Tests that converting with `UILoadVideoToVideoInitialized` returns the expected xAPI statement. """ - - event.context.course_id = "" - event.context.org_id = "" event.context.user_id = "1" event.session = "af45a0e650c4a4fdb0bcde75a1e4b694" session_uuid = "af45a0e6-50c4-a4fd-b0bc-de75a1e4b694" @@ -91,9 +88,6 @@ def test_models_edx_converters_xapi_video_ui_play_video_to_video_played( """Tests that converting with `UIPlayVideoToVideoPlayed` returns the expected xAPI statement. """ - - event.context.course_id = "" - event.context.org_id = "" event.context.user_id = "1" event.session = "af45a0e650c4a4fdb0bcde75a1e4b694" session_uuid = "af45a0e6-50c4-a4fd-b0bc-de75a1e4b694" @@ -153,9 +147,6 @@ def test_models_edx_converters_xapi_video_ui_pause_video_to_video_paused( """Tests that converting with `UIPauseVideoToVideoPaused` returns the expected xAPI statement. """ - - event.context.course_id = "" - event.context.org_id = "" event.context.user_id = "1" event.session = "af45a0e650c4a4fdb0bcde75a1e4b694" session_uuid = "af45a0e6-50c4-a4fd-b0bc-de75a1e4b694" @@ -216,9 +207,6 @@ def test_models_edx_converters_xapi_video_ui_stop_video_to_video_terminated( """Tests that converting with `UIStopVideoToVideoTerminated` returns the expected xAPI statement. """ - - event.context.course_id = "" - event.context.org_id = "" event.context.user_id = "1" event.session = "af45a0e650c4a4fdb0bcde75a1e4b694" session_uuid = "af45a0e6-50c4-a4fd-b0bc-de75a1e4b694" @@ -280,9 +268,6 @@ def test_models_edx_converters_xapi_video_ui_seek_video_to_video_seeked( """Tests that converting with `UISeekVideoToVideoSeeked` returns the expected xAPI statement. """ - - event.context.course_id = "" - event.context.org_id = "" event.context.user_id = "1" event.session = "af45a0e650c4a4fdb0bcde75a1e4b694" session_uuid = "af45a0e6-50c4-a4fd-b0bc-de75a1e4b694" diff --git a/tests/models/test_converter.py b/tests/models/test_converter.py index 2592c73a4..74d678c31 100644 --- a/tests/models/test_converter.py +++ b/tests/models/test_converter.py @@ -329,70 +329,6 @@ def test_converter_convert_with_an_event_missing_a_conversion_set_raises_an_exce list(result) -# pylint: disable=line-too-long -@pytest.mark.parametrize( - "event", - [json.dumps({"event_source": "browser", "event_type": "page_close"})], -) -@pytest.mark.parametrize("valid_uuid", ["ee241f8b-174f-5bdb-bae9-c09de5fe017f"]) -def test_converter_convert_with_an_invalid_page_close_event_writes_an_error_message( # noqa - event, - valid_uuid, - caplog, -): - """Test given an event that matches a pydantic model but fails at the conversion - step, the convert method should write an error message. 
- """ - result = Converter(platform_url="", uuid_namespace=valid_uuid).convert( - [event], ignore_errors=True, fail_on_unknown=True - ) - with caplog.at_level(logging.ERROR): - assert not list(result) - errors = ["Failed to get the transformed value for field: ('context', 'course_id')"] - assert errors == [message for _, _, message in caplog.record_tuples] - - -@pytest.mark.parametrize( - "event", - [json.dumps({"event_source": "browser", "event_type": "page_close"})], -) -@pytest.mark.parametrize("valid_uuid", ["ee241f8b-174f-5bdb-bae9-c09de5fe017f"]) -def test_converter_convert_with_invalid_page_close_event_raises_an_exception( - event, valid_uuid, caplog -): - """Test given an event that matches a pydantic model but fails at the conversion - step, the convert method should raise a ConversionException. - """ - result = Converter(platform_url="", uuid_namespace=valid_uuid).convert( - [event], ignore_errors=False, fail_on_unknown=True - ) - with caplog.at_level(logging.ERROR): - with pytest.raises(ConversionException): - list(result) - - -@settings(deadline=None, suppress_health_check=(HealthCheck.function_scoped_fixture,)) -@pytest.mark.parametrize("valid_uuid", ["ee241f8b-174f-5bdb-bae9-c09de5fe017f"]) -@pytest.mark.parametrize("invalid_platform_url", ["", "not an URL"]) -@custom_given(UIPageClose) -def test_converter_convert_with_invalid_arguments_writes_an_error_message( - valid_uuid, invalid_platform_url, caplog, event -): - """Test given invalid arguments causing the conversion to fail at the validation - step, the convert method should write an error message. - """ - event_str = event.json() - result = Converter( - platform_url=invalid_platform_url, uuid_namespace=valid_uuid - ).convert([event_str], ignore_errors=True, fail_on_unknown=True) - with caplog.at_level(logging.ERROR): - assert not list(result) - model_name = "" - errors = f"Converted event is not a valid ({model_name}) model" - for _, _, message in caplog.record_tuples: - assert errors == message - - @settings(suppress_health_check=(HealthCheck.function_scoped_fixture,)) @pytest.mark.parametrize("valid_uuid", ["ee241f8b-174f-5bdb-bae9-c09de5fe017f"]) @pytest.mark.parametrize("invalid_platform_url", ["", "not an URL"]) From 6447fc54dabe6a2dbea135f61020977e0aa6fd98 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Mon, 9 Oct 2023 00:58:04 +0000 Subject: [PATCH 03/65] =?UTF-8?q?=E2=AC=86=EF=B8=8F(project)=20upgrade=20p?= =?UTF-8?q?ython=20dependencies?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit | datasource | package | from | to | | ---------- | --------------- | ------ | ------ | | pypi | hypothesis | 6.87.1 | 6.87.3 | | pypi | mkdocs-material | 9.4.2 | 9.4.4 | | pypi | pylint | 2.17.7 | 3.0.1 | --- setup.cfg | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index 47789938b..e92a11ce7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -73,17 +73,17 @@ dev = cryptography==41.0.4 factory-boy==3.3.0 flake8==6.1.0 - hypothesis==6.87.1 + hypothesis==6.87.3 isort==5.12.0 logging-gelf==0.0.31 mkdocs==1.5.3 mkdocs-click==0.8.1 - mkdocs-material==9.4.2 + mkdocs-material==9.4.4 mkdocstrings[python-legacy]==0.23.0 moto==4.2.5 pydocstyle==6.3.0 pyfakefs==5.2.4 - pylint==2.17.7 + pylint==3.0.1 pytest==7.4.2 pytest-asyncio==0.21.1 pytest-cov==4.1.0 From 8662eeca1563cecd3ce08191b52a7df273e56a1b Mon Sep 17 00:00:00 2001 From: SergioSim Date: Tue, 25 Apr 2023 11:03:57 +0200 Subject: [PATCH 04/65] 
=?UTF-8?q?=E2=9C=A8(backends)=20add=20data=20and=20?= =?UTF-8?q?lrs=20backend=20interfaces?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We intend to unify the storage and database backend interfaces into a single data backend interface. We also want to separate LRS-specific methods into a dedicated lrs backend interface that extends the data interface. --- .gitignore | 3 - src/ralph/backends/data/__init__.py | 1 + src/ralph/backends/data/base.py | 225 ++++++++++++++++++++++++++++ src/ralph/backends/lrs/__init__.py | 1 + src/ralph/backends/lrs/base.py | 55 +++++++ tests/backends/data/__init__.py | 0 tests/backends/data/test_base.py | 71 +++++++++ 7 files changed, 353 insertions(+), 3 deletions(-) create mode 100644 src/ralph/backends/data/__init__.py create mode 100644 src/ralph/backends/data/base.py create mode 100644 src/ralph/backends/lrs/__init__.py create mode 100644 src/ralph/backends/lrs/base.py create mode 100644 tests/backends/data/__init__.py create mode 100644 tests/backends/data/test_base.py diff --git a/.gitignore b/.gitignore index cab7bfd3a..66d638f84 100644 --- a/.gitignore +++ b/.gitignore @@ -51,9 +51,6 @@ venv.bak/ .pylint.d .pytest_cache -# Test fixtures -data/ - # Documentation site site/ diff --git a/src/ralph/backends/data/__init__.py b/src/ralph/backends/data/__init__.py new file mode 100644 index 000000000..6e031999e --- /dev/null +++ b/src/ralph/backends/data/__init__.py @@ -0,0 +1 @@ +# noqa: D104 diff --git a/src/ralph/backends/data/base.py b/src/ralph/backends/data/base.py new file mode 100644 index 000000000..87e6638c6 --- /dev/null +++ b/src/ralph/backends/data/base.py @@ -0,0 +1,225 @@ +"""Base data backend for Ralph.""" + +import functools +import logging +from abc import ABC, abstractmethod +from enum import Enum, unique +from io import IOBase +from typing import Iterable, Iterator, Optional, Union + +from pydantic import BaseModel, BaseSettings, ValidationError + +from ralph.conf import BaseSettingsConfig, core_settings +from ralph.exceptions import BackendParameterException + +logger = logging.getLogger(__name__) + + +class BaseDataBackendSettings(BaseSettings): + """Represents the data backend default configuration.""" + + class Config(BaseSettingsConfig): + """Pydantic Configuration.""" + + env_prefix = "RALPH_BACKENDS__DATA__" + env_file = ".env" + env_file_encoding = core_settings.LOCALE_ENCODING + + +class BaseQuery(BaseModel): + """Base query model.""" + + class Config: + """Base query model configuration.""" + + extra = "forbid" + + query_string: Optional[str] + + +@unique +class BaseOperationType(Enum): + """Base data backend operation types. + + Attributes: + INDEX (str): creates a new record with a specific ID. + CREATE (str): creates a new record without a specific ID. + DELETE (str): deletes an existing record. + UPDATE (str): updates or overwrites an existing record. + APPEND (str): creates or appends data to an existing record. 
+ """ + + INDEX = "index" + CREATE = "create" + DELETE = "delete" + UPDATE = "update" + APPEND = "append" + + +@unique +class DataBackendStatus(Enum): + """Data backend statuses.""" + + OK = "ok" + AWAY = "away" + ERROR = "error" + + +def enforce_query_checks(method): + """Enforces query argument type checking for methods using it.""" + + @functools.wraps(method) + def wrapper(*args, **kwargs): + """Wrap method execution.""" + query = kwargs.pop("query", None) + self_ = args[0] + + return method(*args, query=self_.validate_query(query), **kwargs) + + return wrapper + + +class BaseDataBackend(ABC): + """Base data backend interface.""" + + name = "base" + query_model = BaseQuery + default_operation_type = BaseOperationType.INDEX + settings_class = BaseDataBackendSettings + + @abstractmethod + def __init__(self, settings: settings_class = None): + """Instantiates the data backend. + + Args: + settings (BaseDataBackendSettings or None): The backend settings. + If `settings` is `None`, a default settings instance is used instead. + """ + + def validate_query(self, query: Union[str, dict, BaseQuery] = None) -> BaseQuery: + """Validates and transforms the query.""" + if query is None: + query = self.query_model() + + if isinstance(query, str): + query = self.query_model(query_string=query) + + if isinstance(query, dict): + try: + query = self.query_model(**query) + except ValidationError as err: + raise BackendParameterException( + "The 'query' argument is expected to be a " + f"{self.query_model.__name__} instance. {err.errors()}" + ) from err + + if not isinstance(query, self.query_model): + raise BackendParameterException( + "The 'query' argument is expected to be a " + f"{self.query_model.__name__} instance." + ) + + logger.debug("Query: %s", str(query)) + + return query + + @abstractmethod + def status(self) -> DataBackendStatus: + """Implements data backend checks (e.g. connection, cluster status). + + Returns: + DataBackendStatus: The status of the data backend. + """ + + @abstractmethod + def list( + self, target: str = None, details: bool = False, new: bool = False + ) -> Iterator[Union[str, dict]]: + """Lists containers in the data backend. E.g., collections, files, indexes. + + Args: + target (str or None): The target container name. + If `target` is `None`, a default value is used instead. + details (bool): Get detailed container information instead of just names. + new (bool): Given the history, list only not already read containers. + + Yields: + str: If `details` is False. + dict: If `details` is True. + + Raises: + BackendException: If a failure occurs. + BackendParameterException: If a backend argument value is not valid. + """ + + @abstractmethod + @enforce_query_checks + def read( + self, + *, + query: Union[str, BaseQuery] = None, + target: str = None, + chunk_size: Union[None, int] = None, + raw_output: bool = False, + ignore_errors: bool = False, + ) -> Iterator[Union[bytes, dict]]: + """Reads records matching the `query` in the `target` container and yields them. + + Args: + query: (str or BaseQuery): The query to select records to read. + target (str or None): The target container name. + If `target` is `None`, a default value is used instead. + chunk_size (int or None): The number of records or bytes to read in one + batch, depending on whether the records are dictionaries or bytes. + raw_output (bool): Controls whether to yield bytes or dictionaries. + If the records are dictionaries and `raw_output` is set to `True`, they + are encoded as JSON. 
If the records are bytes and `raw_output` is set to `False`, they are
+                decoded as JSON by line.
+            ignore_errors (bool): If `True`, errors during the read operation
+                are ignored and logged. If `False` (default), a `BackendException`
+                is raised if an error occurs.
+
+        Yields:
+            dict: If `raw_output` is False.
+            bytes: If `raw_output` is True.
+
+        Raises:
+            BackendException: If a failure during the read operation occurs and
+                `ignore_errors` is set to `False`.
+            BackendParameterException: If a backend argument value is not valid.
+        """
+
+    @abstractmethod
+    def write(  # pylint: disable=too-many-arguments
+        self,
+        data: Union[IOBase, Iterable[bytes], Iterable[dict]],
+        target: Union[None, str] = None,
+        chunk_size: Union[None, int] = None,
+        ignore_errors: bool = False,
+        operation_type: Union[None, BaseOperationType] = None,
+    ) -> int:
+        """Writes `data` records to the `target` container and returns their count.
+
+        Args:
+            data: (Iterable or IOBase): The data to write.
+            target (str or None): The target container name.
+                If `target` is `None`, a default value is used instead.
+            chunk_size (int or None): The number of records or bytes to write in one
+                batch, depending on whether `data` contains dictionaries or bytes.
+                If `chunk_size` is `None`, a default value is used instead.
+            ignore_errors (bool): If `True`, errors during the write operation
+                are ignored and logged. If `False` (default), a `BackendException`
+                is raised if an error occurs.
+            operation_type (BaseOperationType or None): The mode of the write operation.
+                If `operation_type` is `None`, the `default_operation_type` is used
+                instead. See `BaseOperationType`.
+
+        Returns:
+            int: The number of written records.
+
+        Raises:
+            BackendException: If a failure during the write operation occurs and
+                `ignore_errors` is set to `False`.
+            BackendParameterException: If a backend argument value is not valid.
+        """
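[Editor's note: as a rough illustration of the contract above, here is a
minimal sketch of a concrete implementation. The `DummyDataBackend` class
and its in-memory records are hypothetical and for illustration only;
`BaseDataBackend`, `DataBackendStatus` and `enforce_query_checks` are the
objects introduced in this patch.]

    from ralph.backends.data.base import (
        BaseDataBackend,
        DataBackendStatus,
        enforce_query_checks,
    )

    class DummyDataBackend(BaseDataBackend):
        """A hypothetical in-memory data backend."""

        name = "dummy"

        def __init__(self, settings=None):
            # No settings are needed for this in-memory sketch.
            self.records = [{"id": 1}, {"id": 2}]

        def status(self) -> DataBackendStatus:
            return DataBackendStatus.OK

        def list(self, target=None, details=False, new=False):
            yield "memory"

        @enforce_query_checks
        def read(self, *, query=None, target=None, chunk_size=None,
                 raw_output=False, ignore_errors=False):
            # Thanks to the decorator, `query` is already a valid BaseQuery here.
            yield from self.records

        def write(self, data, target=None, chunk_size=None,
                  ignore_errors=False, operation_type=None):
            records = list(data)
            self.records.extend(records)
            return len(records)

    # A string query is coerced to BaseQuery(query_string="*") by validate_query.
    assert list(DummyDataBackend().read(query="*")) == [{"id": 1}, {"id": 2}]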
diff --git a/src/ralph/backends/lrs/__init__.py b/src/ralph/backends/lrs/__init__.py
new file mode 100644
index 000000000..6e031999e
--- /dev/null
+++ b/src/ralph/backends/lrs/__init__.py
@@ -0,0 +1 @@
+# noqa: D104
diff --git a/src/ralph/backends/lrs/base.py b/src/ralph/backends/lrs/base.py
new file mode 100644
index 000000000..aa574544e
--- /dev/null
+++ b/src/ralph/backends/lrs/base.py
@@ -0,0 +1,55 @@
+"""Base LRS backend for Ralph."""
+
+from abc import abstractmethod
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Literal, Optional
+from uuid import UUID
+
+from pydantic import BaseModel
+
+from ralph.backends.data.base import BaseDataBackend
+
+
+@dataclass
+class StatementQueryResult:
+    """Represents a common interface for results of an LRS statements query."""
+
+    statements: list[dict]
+    pit_id: str
+    search_after: str
+
+
+class StatementParameters(BaseModel):
+    """Represents a dictionary of possible LRS query parameters."""
+
+    # pylint: disable=too-many-instance-attributes
+
+    statementId: Optional[str] = None  # pylint: disable=invalid-name
+    voidedStatementId: Optional[str] = None  # pylint: disable=invalid-name
+    agent: Optional[str] = None
+    verb: Optional[str] = None
+    activity: Optional[str] = None
+    registration: Optional[UUID] = None
+    related_activities: Optional[bool] = False
+    related_agents: Optional[bool] = False
+    since: Optional[datetime] = None
+    until: Optional[datetime] = None
+    limit: Optional[int] = None
+    format: Optional[Literal["ids", "exact", "canonical"]] = "exact"
+    attachments: Optional[bool] = False
+    ascending: Optional[bool] = False
+    search_after: Optional[str] = None
+    pit_id: Optional[str] = None
+
+
+class BaseLRSBackend(BaseDataBackend):
+    """Base LRS backend interface."""
+
+    @abstractmethod
+    def query_statements(self, params: StatementParameters) -> StatementQueryResult:
+        """Returns the statements query payload using xAPI parameters."""
+
+    @abstractmethod
+    def query_statements_by_ids(self, ids: list[str]) -> list:
+        """Returns the list of matching statement IDs from the database."""
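[Editor's note: for illustration, an LRS backend implements these two extra
methods on top of the data backend interface. A rough sketch follows; the
`DummyLRSBackend` class and its canned statement are hypothetical, and the
inherited abstract data backend methods are omitted, so this class is not
instantiable as-is.]

    from ralph.backends.lrs.base import (
        BaseLRSBackend,
        StatementParameters,
        StatementQueryResult,
    )

    class DummyLRSBackend(BaseLRSBackend):
        """A hypothetical LRS backend (data backend methods omitted)."""

        def query_statements(self, params: StatementParameters) -> StatementQueryResult:
            # A real implementation would translate `params` into a backend query.
            statements = [{"id": "1", "verb": {"id": params.verb}}]
            return StatementQueryResult(
                statements=statements[: params.limit], pit_id="", search_after=""
            )

        def query_statements_by_ids(self, ids: list[str]) -> list:
            return [{"id": statement_id} for statement_id in ids]

    params = StatementParameters(
        verb="http://adlnet.gov/expapi/verbs/terminated", limit=10
    )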
diff --git a/tests/backends/data/__init__.py b/tests/backends/data/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/backends/data/test_base.py b/tests/backends/data/test_base.py
new file mode 100644
index 000000000..3c12183f7
--- /dev/null
+++ b/tests/backends/data/test_base.py
@@ -0,0 +1,71 @@
+"""Tests for the base data backend."""
+
+import pytest
+
+from ralph.backends.data.base import BaseDataBackend, BaseQuery, enforce_query_checks
+from ralph.exceptions import BackendParameterException
+
+
+@pytest.mark.parametrize(
+    "value,expected",
+    [
+        (None, BaseQuery()),
+        ("foo", BaseQuery(query_string="foo")),
+        (BaseQuery(query_string="foo"), BaseQuery(query_string="foo")),
+    ],
+)
+def test_backends_data_base_enforce_query_checks_with_valid_input(value, expected):
+    """Tests the enforce_query_checks function given valid input."""
+
+    class MockBaseDataBackend(BaseDataBackend):
+        """A class mocking the base data backend class."""
+
+        def __init__(self, settings=None):
+            """Instantiates the Mock data backend."""
+
+        @enforce_query_checks
+        def read(self, query=None):  # pylint: disable=no-self-use,arguments-differ
+            """Mock the base data backend read method."""
+
+            assert query == expected
+
+        def status(self):  # pylint: disable=arguments-differ,missing-function-docstring
+            pass
+
+        def list(self):  # pylint: disable=arguments-differ,missing-function-docstring
+            pass
+
+        def write(self):  # pylint: disable=arguments-differ,missing-function-docstring
+            pass
+
+    MockBaseDataBackend().read(query=value)
+
+
+@pytest.mark.parametrize("value", [[], {"foo": "bar"}])
+def test_backends_data_base_enforce_query_checks_with_invalid_input(value):
+    """Tests the enforce_query_checks function given invalid input."""
+
+    class MockBaseDataBackend(BaseDataBackend):
+        """A class mocking the base data backend class."""
+
+        def __init__(self, settings=None):
+            """Instantiates the Mock data backend."""
+
+        @enforce_query_checks
+        def read(self, query=None):  # pylint: disable=no-self-use,arguments-differ
+            """Mock the base data backend read method."""
+
+            return None
+
+        def status(self):  # pylint: disable=arguments-differ,missing-function-docstring
+            pass
+
+        def list(self):  # pylint: disable=arguments-differ,missing-function-docstring
+            pass
+
+        def write(self):  # pylint: disable=arguments-differ,missing-function-docstring
+            pass
+
+    error = "The 'query' argument is expected to be a BaseQuery instance."
+    with pytest.raises(BackendParameterException, match=error):
+        MockBaseDataBackend().read(query=value)

From a08adf71d7f79bd9d9a2311c212a11dcc0007af6 Mon Sep 17 00:00:00 2001
From: SergioSim
Date: Tue, 25 Apr 2023 11:04:43 +0200
Subject: [PATCH 05/65] =?UTF-8?q?=E2=9C=A8(backends)=20add=20FSDataBackend?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

We add the FileSystem data backend implementation, which is mostly
adapted from the existing FSStorage backend.
---
 src/ralph/backends/data/fs.py  | 331 +++++++++++
 src/ralph/backends/mixins.py   |  15 +-
 tests/backends/data/test_fs.py | 988 +++++++++++++++++++++++++++++++++
 tests/conftest.py              |   1 +
 tests/fixtures/backends.py     |  24 +-
 5 files changed, 1351 insertions(+), 8 deletions(-)
 create mode 100644 src/ralph/backends/data/fs.py
 create mode 100644 tests/backends/data/test_fs.py

diff --git a/src/ralph/backends/data/fs.py b/src/ralph/backends/data/fs.py
new file mode 100644
index 000000000..1cf89d2ea
--- /dev/null
+++ b/src/ralph/backends/data/fs.py
@@ -0,0 +1,331 @@
+"""FileSystem data backend for Ralph."""
+
+import json
+import logging
+import os
+from datetime import datetime, timezone
+from io import IOBase
+from itertools import chain
+from pathlib import Path
+from typing import IO, Iterable, Iterator, Union
+from uuid import uuid4
+
+from ralph.backends.data.base import (
+    BaseDataBackend,
+    BaseDataBackendSettings,
+    BaseOperationType,
+    BaseQuery,
+    DataBackendStatus,
+    enforce_query_checks,
+)
+from ralph.backends.mixins import HistoryMixin
+from ralph.conf import BaseSettingsConfig
+from ralph.exceptions import BackendException, BackendParameterException
+from ralph.utils import now
+
+logger = logging.getLogger(__name__)
+
+
+class FSDataBackendSettings(BaseDataBackendSettings):
+    """Represents the FileSystem data backend default configuration.
+
+    Attributes:
+        DEFAULT_CHUNK_SIZE (int): The default chunk size for reading files.
+        DEFAULT_DIRECTORY_PATH (str or Path): The default target directory path where to
+            perform list, read and write operations.
+        DEFAULT_QUERY_STRING (str): The default query string to match files for the read
+            operation.
+        LOCALE_ENCODING (str): The encoding used for writing dictionaries to files.
+    """
+
+    class Config(BaseSettingsConfig):
+        """Pydantic Configuration."""
+
+        env_prefix = "RALPH_BACKENDS__DATA__FS__"
+
+    DEFAULT_CHUNK_SIZE: int = 4096
+    DEFAULT_DIRECTORY_PATH: Path = Path(".")
+    DEFAULT_QUERY_STRING: str = "*"
+    LOCALE_ENCODING: str = "utf8"
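[Editor's note: a small sketch of how these settings could be supplied,
either programmatically or through the matching environment variables; the
directory path and query string below are arbitrary example values.]

    from pathlib import Path

    from ralph.backends.data.fs import FSDataBackend, FSDataBackendSettings

    # Explicit settings object...
    settings = FSDataBackendSettings(
        DEFAULT_DIRECTORY_PATH=Path("/tmp/ralph/archives"),
        DEFAULT_QUERY_STRING="*.gz",
        DEFAULT_CHUNK_SIZE=1024,
        LOCALE_ENCODING="utf8",
    )
    backend = FSDataBackend(settings)

    # ...or the equivalent environment variables, read by pydantic:
    # RALPH_BACKENDS__DATA__FS__DEFAULT_DIRECTORY_PATH=/tmp/ralph/archives
    # RALPH_BACKENDS__DATA__FS__DEFAULT_QUERY_STRING=*.gz
    # RALPH_BACKENDS__DATA__FS__DEFAULT_CHUNK_SIZE=1024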
+
+
+class FSDataBackend(HistoryMixin, BaseDataBackend):
+    """FileSystem data backend."""
+
+    name = "fs"
+    default_operation_type = BaseOperationType.CREATE
+    settings_class = FSDataBackendSettings
+
+    def __init__(self, settings: settings_class = None):
+        """Creates the default target directory if it does not exist."""
+        settings = settings if settings else self.settings_class()
+        self.default_chunk_size = settings.DEFAULT_CHUNK_SIZE
+        self.default_directory = settings.DEFAULT_DIRECTORY_PATH
+        self.default_query_string = settings.DEFAULT_QUERY_STRING
+        self.locale_encoding = settings.LOCALE_ENCODING
+
+        if not self.default_directory.is_dir():
+            msg = "Default directory doesn't exist, creating: %s"
+            logger.info(msg, self.default_directory)
+            self.default_directory.mkdir(parents=True)
+
+        logger.debug("Default directory: %s", self.default_directory)
+
+    def status(self) -> DataBackendStatus:
+        """Checks whether the default directory has appropriate permissions."""
+        for mode in [os.R_OK, os.W_OK, os.X_OK]:
+            if not os.access(self.default_directory, mode):
+                logger.error(
+                    "Invalid permissions for the default directory at %s. "
+                    "The directory should have read, write and execute permissions.",
+                    str(self.default_directory.absolute()),
+                )
+                return DataBackendStatus.ERROR
+
+        return DataBackendStatus.OK
+
+    def list(
+        self, target: str = None, details: bool = False, new: bool = False
+    ) -> Iterator[Union[str, dict]]:
+        """Lists files and directories in the target directory.
+
+        Args:
+            target (str or None): The directory path where to list the files and
+                directories.
+                If target is `None`, the `default_directory` is used instead.
+                If target is a relative path, it is considered to be relative to the
+                `default_directory_path`.
+            details (bool): Get detailed file information instead of just file paths.
+            new (bool): Given the history, list only not already read files.
+
+        Yields:
+            str: The next file path. (If details is False).
+            dict: The next file details. (If details is True).
+
+        Raises:
+            BackendParameterException: If the `target` argument is not a directory path.
+        """
+        target = Path(target) if target else self.default_directory
+        if not target.is_absolute() and target != self.default_directory:
+            target = self.default_directory / target
+        try:
+            paths = set(target.absolute().iterdir())
+        except OSError as error:
+            msg = "Invalid target argument"
+            logger.error("%s. 
%s", msg, error) + raise BackendParameterException(msg, error.strerror) from error + + logger.debug("Found %d files", len(paths)) + + if new: + paths -= set(map(Path, self.get_command_history(self.name, "read"))) + logger.debug("New files: %d", len(paths)) + + if not details: + for path in paths: + yield str(path) + + return + + for path in paths: + stats = path.stat() + modified_at = datetime.fromtimestamp(int(stats.st_mtime), tz=timezone.utc) + yield { + "path": str(path), + "size": stats.st_size, + "modified_at": modified_at.isoformat(), + } + + @enforce_query_checks + def read( + self, + *, + query: Union[str, BaseQuery] = None, + target: str = None, + chunk_size: Union[None, int] = None, + raw_output: bool = False, + ignore_errors: bool = False, + ) -> Iterator[Union[bytes, dict]]: + """Reads files matching the query in the target folder and yields them. + + Args: + query: (str or BaseQuery): The relative pattern for the files to read. + target (str or None): The target directory path containing the files. + If target is `None`, the `default_directory_path` is used instead. + If target is a relative path, it is considered to be relative to the + `default_directory_path`. + chunk_size (int or None): The chunk size for reading files. Ignored if + `raw_output` is set to False. + raw_output (bool): Controls whether to yield bytes or dictionaries. + ignore_errors (bool): If `True`, errors during the read operation + will be ignored and logged. If `False` (default), a `BackendException` + will be raised if an error occurs. + + Yields: + bytes: The next chunk of the read files if `raw_output` is True. + dict: The next JSON parsed line of the read files if `raw_output` is False. + + Raises: + BackendException: If a failure during the read operation occurs and + `ignore_errors` is set to `False`. + """ + if not query.query_string: + query.query_string = self.default_query_string + + if not chunk_size: + chunk_size = self.default_chunk_size + + target = Path(target) if target else self.default_directory + if not target.is_absolute() and target != self.default_directory: + target = self.default_directory / target + paths = list( + filter(lambda path: path.is_file(), target.glob(query.query_string)) + ) + + if not paths: + logger.info("No file found for query: %s", target / query.query_string) + return + + logger.debug("Reading matching files: %s", paths) + + for path in paths: + with path.open("rb") as file: + reader = self._read_raw if raw_output else self._read_dict + for chunk in reader(file, chunk_size, ignore_errors): + yield chunk + + # The file has been read, add a new entry to the history. + self.append_to_history( + { + "backend": self.name, + "action": "read", + # WARNING: previously only the file name was used as the ID + # By changing this to the absolute file path, previously fetched + # files will not be marked as read anymore. + "id": str(path.absolute()), + "filename": path.name, + "size": path.stat().st_size, + "timestamp": now(), + } + ) + + def write( # pylint: disable=too-many-arguments + self, + data: Union[IOBase, Iterable[bytes], Iterable[dict]], + target: Union[None, str] = None, + chunk_size: Union[None, int] = None, + ignore_errors: bool = False, + operation_type: Union[None, BaseOperationType] = None, + ) -> int: + """Writes data records to the target file and return their count. + + Args: + data: (Iterable or IOBase): The data to write. + target (str or None): The target file path. 
+                If target is a relative path, it is considered to be relative to the
+                `default_directory_path`.
+                If target is `None`, a random (uuid4) file is created in the
+                `default_directory_path` and used as the target instead.
+            chunk_size (int or None): Ignored.
+            ignore_errors (bool): Ignored.
+            operation_type (BaseOperationType or None): The mode of the write operation.
+                If operation_type is `CREATE` or `INDEX`, the target file is expected to
+                be absent. If the target file exists a `BackendException` is raised.
+                If operation_type is `UPDATE`, the target file is overwritten.
+                If operation_type is `APPEND`, the data is appended to the
+                end of the target file.
+
+        Returns:
+            int: The number of written files.
+
+        Raises:
+            BackendException: If the `operation_type` is `CREATE` or `INDEX` and the
+                target file already exists.
+            BackendParameterException: If the `operation_type` is `DELETE` as it is not
+                supported.
+        """
+        data = iter(data)
+        try:
+            first_record = next(data)
+        except StopIteration:
+            logger.info("Data Iterator is empty; skipping write to target.")
+            return 0
+        if not operation_type:
+            operation_type = self.default_operation_type
+
+        if operation_type == BaseOperationType.DELETE:
+            msg = "Delete operation_type is not allowed."
+            logger.error(msg)
+            raise BackendParameterException(msg)
+
+        if not target:
+            target = f"{now()}-{uuid4()}"
+            logger.info("Target file not specified; using random file name: %s", target)
+
+        target = Path(target)
+        path = target if target.is_absolute() else self.default_directory / target
+
+        if operation_type in [BaseOperationType.CREATE, BaseOperationType.INDEX]:
+            if path.is_file():
+                msg = (
+                    "%s already exists and overwrite is not allowed with operation_type"
+                    " create or index."
+                )
+                logger.error(msg, path)
+                raise BackendException(msg % path)
+
+            logger.debug("Creating file: %s", path)
+
+        mode = "wb"
+        if operation_type == BaseOperationType.APPEND:
+            mode = "ab"
+            logger.debug("Appending to file: %s", path)
+
+        with path.open(mode) as file:
+            is_dict = isinstance(first_record, dict)
+            writer = self._write_dict if is_dict else self._write_raw
+            for chunk in chain((first_record,), data):
+                writer(file, chunk)
+
+        # The file has been created, add a new entry to the history.
+        self.append_to_history(
+            {
+                "backend": self.name,
+                "action": "write",
+                # WARNING: previously only the file name was used as the ID
+                # By changing this to the absolute file path, previously written
+                # files will not be marked as written anymore.
+ "id": str(path.absolute()), + "filename": path.name, + "size": path.stat().st_size, + "timestamp": now(), + } + ) + return 1 + + @staticmethod + def _read_raw(file: IO, chunk_size: int, _ignore_errors: bool) -> Iterator[bytes]: + """Reads the `file` in chunks of size `chunk_size` and yields them.""" + while chunk := file.read(chunk_size): + yield chunk + + @staticmethod + def _read_dict(file: IO, _chunk_size: int, ignore_errors: bool) -> Iterator[dict]: + """Reads the `file` by line and yields JSON parsed dictionaries.""" + for i, line in enumerate(file): + try: + yield json.loads(line) + except (TypeError, json.JSONDecodeError) as err: + msg = "Raised error: %s, in file %s at line %s" + logger.error(msg, err, file, i) + if not ignore_errors: + raise BackendException(msg % (err, file, i)) from err + + @staticmethod + def _write_raw(file: IO, chunk: bytes) -> None: + """Writes the `chunk` bytes to the file.""" + file.write(chunk) + + def _write_dict(self, file: IO, chunk: dict) -> None: + """Writes the `chunk` dictionary to the file.""" + file.write(bytes(f"{json.dumps(chunk)}\n", encoding=self.locale_encoding)) diff --git a/src/ralph/backends/mixins.py b/src/ralph/backends/mixins.py index 08bfde136..ae6e53ac4 100644 --- a/src/ralph/backends/mixins.py +++ b/src/ralph/backends/mixins.py @@ -59,11 +59,14 @@ def append_to_history(self, event): self.write_history(self.history + [event]) def get_command_history(self, backend_name, command): - """Extract entry ids from the history for a given command and backend_name.""" - return [ - entry["id"] - for entry in filter( - lambda e: e["backend"] == backend_name and e["command"] == command, - self.history, + """Extracts entry ids from the history for a given command and backend_name.""" + + def filter_by_name_and_command(entry): + """Checks whether the history entry matches the backend_name and command.""" + return entry.get("backend") == backend_name and ( + command in [entry.get("command"), entry.get("action")] ) + + return [ + entry["id"] for entry in filter(filter_by_name_and_command, self.history) ] diff --git a/tests/backends/data/test_fs.py b/tests/backends/data/test_fs.py new file mode 100644 index 000000000..8845710f6 --- /dev/null +++ b/tests/backends/data/test_fs.py @@ -0,0 +1,988 @@ +"""Tests for Ralph fs data backend""" +import json +import logging +import os +from collections.abc import Iterable +from operator import itemgetter +from uuid import uuid4 + +import pytest + +from ralph.backends.data.base import BaseOperationType, BaseQuery, DataBackendStatus +from ralph.backends.data.fs import FSDataBackend, FSDataBackendSettings +from ralph.exceptions import BackendException, BackendParameterException +from ralph.utils import now + + +def test_backends_data_fs_data_backend_default_instantiation(monkeypatch, fs): + """Tests the `FSDataBackend` default instantiation.""" + # pylint: disable=invalid-name + fs.create_file(".env") + backend_settings_names = [ + "DEFAULT_CHUNK_SIZE", + "DEFAULT_DIRECTORY_PATH", + "DEFAULT_QUERY_STRING", + "LOCALE_ENCODING", + ] + for name in backend_settings_names: + monkeypatch.delenv(f"RALPH_BACKENDS__DATA__FS__{name}", raising=False) + + assert FSDataBackend.name == "fs" + assert FSDataBackend.query_model == BaseQuery + assert FSDataBackend.default_operation_type == BaseOperationType.CREATE + assert FSDataBackend.settings_class == FSDataBackendSettings + backend = FSDataBackend() + assert str(backend.default_directory) == "." 
+ assert backend.default_query_string == "*" + assert backend.default_chunk_size == 4096 + assert backend.locale_encoding == "utf8" + + +def test_backends_data_fs_data_backend_instantiation_with_settings(fs): + """Tests the `FSDataBackend` instantiation with settings.""" + # pylint: disable=invalid-name,unused-argument + deep_path = "deep/directories/path" + assert not os.path.exists(deep_path) + settings = FSDataBackend.settings_class( + DEFAULT_DIRECTORY_PATH=deep_path, + DEFAULT_QUERY_STRING="foo.txt", + DEFAULT_CHUNK_SIZE=1, + LOCALE_ENCODING="utf-16", + ) + backend = FSDataBackend(settings) + assert os.path.exists(deep_path) + assert str(backend.default_directory) == deep_path + assert backend.default_directory.is_dir() + assert backend.default_query_string == "foo.txt" + assert backend.default_chunk_size == 1 + assert backend.locale_encoding == "utf-16" + + try: + FSDataBackend(settings) + except Exception as err: # pylint:disable=broad-except + pytest.fail(f"FSDataBackend should not raise exceptions: {err}") + + +@pytest.mark.parametrize( + "mode", + [0o007, 0o100, 0o200, 0o300, 0o400, 0o500, 0o600], +) +def test_backends_data_fs_data_backend_status_method_with_error_status( + mode, fs_backend, caplog +): + """Tests the `FSDataBackend.status` method, given a directory with wrong + permissions, should return `DataBackendStatus.ERROR`. + """ + os.mkdir("directory", mode) + with caplog.at_level(logging.ERROR): + assert fs_backend(path="directory").status() == DataBackendStatus.ERROR + + assert ( + "ralph.backends.data.fs", + logging.ERROR, + "Invalid permissions for the default directory at /directory. " + "The directory should have read, write and execute permissions.", + ) in caplog.record_tuples + + +@pytest.mark.parametrize("mode", [0o700]) +def test_backends_data_fs_data_backend_status_method_with_ok_status(mode, fs_backend): + """Tests the `FSDataBackend.status` method, given a directory with right + permissions, should return `DataBackendStatus.OK`. + """ + os.mkdir("directory", mode) + assert fs_backend(path="directory").status() == DataBackendStatus.OK + + +@pytest.mark.parametrize( + "files,target,error", + [ + # Given a `target` that is a file, the `list` method should raise a + # `BackendParameterException`. + (["foo/file_1"], "file_1", "Invalid target argument', 'Not a directory"), + # Given a `target` that does not exists, the `list` method should raise a + # `BackendParameterException`. + (["foo/file_1"], "bar", "Invalid target argument', 'No such file or directory"), + ], +) +def test_backends_data_fs_data_backend_list_method_with_invalid_target( + files, target, error, fs_backend, fs +): + """Tests the `FSDataBackend.list` method given an invalid `target` argument should + raise a `BackendParameterException`. + """ + # pylint: disable=invalid-name + for file in files: + fs.create_file(file) + + backend = fs_backend() + with pytest.raises(BackendParameterException, match=error): + list(backend.list(target)) + + +@pytest.mark.parametrize( + "files,target,expected", + [ + # Given an empty default directory, the `list` method should yield nothing. + ([], None, []), + # Given a default directory containing one file, the `list` method should yield + # the absolute path of the file. + (["foo/file_1"], None, ["/foo/file_1"]), + # Given a relative `target` directory containing one file, the `list` method + # should yield the absolute path of the file. 
+ (["/foo/bar/file_1"], "bar", ["/foo/bar/file_1"]), + # Given a default directory containing two files, the `list` method should yield + # the absolute paths of the files. + (["foo/file_1", "foo/file_2"], None, ["/foo/file_1", "/foo/file_2"]), + # Given a `target` directory containing two files, the `list` method should + # yield the absolute paths of the files. + (["bar/file_1", "bar/file_2"], "/bar", ["/bar/file_1", "/bar/file_2"]), + ], +) +def test_backends_data_fs_data_backend_list_method_without_history( + files, target, expected, fs_backend, fs +): + """Tests the `FSDataBackend.list` method without history.""" + # pylint: disable=invalid-name + for file in files: + fs.create_file(file) + + backend = fs_backend() + result = backend.list(target) + assert isinstance(result, Iterable) + assert sorted(result) == expected + + +@pytest.mark.parametrize( + "files,target,expected", + [ + # Given an empty default directory, the `list` method should yield nothing. + ([], None, []), + # Given a default directory containing one file, the `list` method should yield + # a dictionary containing the absolute path of the file. + (["foo/file_1"], None, ["/foo/file_1"]), + # Given a relative `target` directory containing one file, the `list` method + # should yield a dictionary containing the absolute path of the file. + (["/foo/bar/file_1"], "bar", ["/foo/bar/file_1"]), + # Given a default directory containing two files, the `list` method should yield + # dictionaries containing the absolute paths of the files. + (["foo/file_1", "foo/file_2"], None, ["/foo/file_1", "/foo/file_2"]), + # Given a `target` directory containing two files, the `list` method should + # yield dictionaries containing the absolute paths of the files. + (["bar/file_1", "bar/file_2"], "/bar", ["/bar/file_1", "/bar/file_2"]), + ], +) +def test_backends_data_fs_data_backend_list_method_with_details( + files, target, expected, fs_backend, fs +): + """Tests the `FSDataBackend.list` method with `details` set to `True`.""" + # pylint: disable=invalid-name,too-many-arguments + for file in files: + fs.create_file(file) + os.utime(file, (1, 1)) + + backend = fs_backend() + result = backend.list(target, details=True) + assert isinstance(result, Iterable) + assert sorted(result, key=itemgetter("path")) == [ + {"path": file, "size": 0, "modified_at": "1970-01-01T00:00:01+00:00"} + for file in expected + ] + + +def test_backends_data_fs_data_backend_list_method_with_history(fs_backend, fs): + """Tests the `FSDataBackend.list` method with history.""" + # pylint: disable=invalid-name + + # Create 3 files in the default directory. + fs.create_file("foo/file_1") + fs.create_file("foo/file_2") + fs.create_file("foo/file_3") + + backend = fs_backend() + + # Given an empty history and `new` set to `True`, the `list` method should yield all + # files in the directory. + expected = ["/foo/file_1", "/foo/file_2", "/foo/file_3"] + result = backend.list(new=True) + assert isinstance(result, Iterable) + assert sorted(result) == expected + + # Add file_1 to history + backend.history.append( + { + "backend": "fs", + "action": "read", + "id": "/foo/file_1", + "filename": "file_1", + "size": 0, + "timestamp": "2020-10-07T16:37:25.887664+00:00", + } + ) + + # Given a history containing one matching file and `new` set to `True`, the + # `list` method should yield all files in the directory except the matching file. 
+ expected = ["/foo/file_2", "/foo/file_3"] + result = backend.list(new=True) + assert isinstance(result, Iterable) + assert sorted(result) == expected + + # Add file_2 to history + backend.history.append( + { + "backend": "fs", + "action": "read", + "id": "/foo/file_2", + "filename": "file_2", + "size": 0, + "timestamp": "2020-10-07T16:37:25.887664+00:00", + } + ) + + # Given a history containing two matching files and `new` set to `True`, the + # `list` method should yield all files in the directory except the matching files. + expected = ["/foo/file_3"] + result = backend.list(new=True) + assert isinstance(result, Iterable) + assert sorted(result) == expected + + # Add file_3 to history + backend.history.append( + { + "backend": "fs", + "action": "read", + "id": "/foo/file_3", + "filename": "file_3", + "size": 0, + "timestamp": "2020-10-07T16:37:25.887664+00:00", + } + ) + + # Given a history containing all matching files and `new` set to `True`, the `list` + # method should yield nothing. + expected = [] + result = backend.list(new=True) + assert isinstance(result, Iterable) + assert sorted(result) == expected + + +def test_backends_data_fs_data_backend_list_method_with_history_and_details( + fs_backend, fs +): + """Tests the `FSDataBackend.list` method with an history and detailed output.""" + # pylint: disable=invalid-name + + # Create 3 files in the default directory. + fs.create_file("foo/file_1") + os.utime("foo/file_1", (1, 1)) + fs.create_file("foo/file_2") + os.utime("foo/file_2", (1, 1)) + fs.create_file("foo/file_3") + os.utime("foo/file_3", (1, 1)) + + backend = fs_backend() + + # Given an empty history and `new` and `details` set to `True`, the `list` method + # should yield all files in the directory with additional details. + expected = [ + {"path": file, "size": 0, "modified_at": "1970-01-01T00:00:01+00:00"} + for file in ["/foo/file_1", "/foo/file_2", "/foo/file_3"] + ] + result = backend.list(details=True, new=True) + assert isinstance(result, Iterable) + assert sorted(result, key=itemgetter("path")) == expected + + # Add file_1 to history + backend.history.append( + { + "backend": "fs", + "action": "read", + "id": "/foo/file_1", + "filename": "file_1", + "size": 0, + "timestamp": "1970-01-01T00:00:01+00:00", + } + ) + + # Given a history containing one matching file and `new` and `details` set to + # `True`, the `list` method should yield all files in the directory with additional + # details, except for the matching file. + expected = [ + {"path": file, "size": 0, "modified_at": "1970-01-01T00:00:01+00:00"} + for file in ["/foo/file_2", "/foo/file_3"] + ] + result = backend.list(details=True, new=True) + assert isinstance(result, Iterable) + assert sorted(result, key=itemgetter("path")) == expected + + # Add file_2 to history + backend.history.append( + { + "backend": "fs", + "action": "read", + "id": "/foo/file_2", + "filename": "file_2", + "size": 0, + "timestamp": "1970-01-01T00:00:01+00:00", + } + ) + + # Given a history containing two matching files and `new` and `details` set to + # `True`, the `list` method should yield all files in the directory with additional + # details, except for the matching files. 
+    expected = [
+        {"path": file, "size": 0, "modified_at": "1970-01-01T00:00:01+00:00"}
+        for file in ["/foo/file_3"]
+    ]
+    result = backend.list(details=True, new=True)
+    assert isinstance(result, Iterable)
+    assert sorted(result, key=itemgetter("path")) == expected
+
+    # Add file_3 to history
+    backend.history.append(
+        {
+            "backend": "fs",
+            "action": "read",
+            "id": "/foo/file_3",
+            "filename": "file_3",
+            "size": 0,
+            "timestamp": "1970-01-01T00:00:01+00:00",
+        }
+    )
+
+    # Given a history containing all matching files and `new` and `details` set to
+    # `True`, the `list` method should yield nothing.
+    expected = []
+    result = backend.list(details=True, new=True)
+    assert isinstance(result, Iterable)
+    assert sorted(result, key=itemgetter("path")) == expected
+
+
+def test_backends_data_fs_data_backend_read_method_with_raw_output(
+    fs_backend, fs, monkeypatch
+):
+    """Tests the `FSDataBackend.read` method with `raw_output` set to `True`."""
+    # pylint: disable=invalid-name
+
+    # Create files in absolute path directory.
+    absolute_path = "/tmp/test_fs/"
+    fs.create_file(absolute_path + "file_1.txt", contents="foo")
+    fs.create_file(absolute_path + "file_2.txt", contents="bar")
+
+    # Create files in default directory.
+    fs.create_file("foo/file_3.txt", contents="baz")
+    fs.create_file("foo/bar/file_4.txt", contents="qux")
+
+    # Freeze the ralph.utils.now() value.
+    frozen_now = now()
+    monkeypatch.setattr("ralph.backends.data.fs.now", lambda: frozen_now)
+
+    backend = fs_backend()
+
+    # Given no `target`, the `read` method should read all files in the default
+    # directory and yield bytes.
+    result = backend.read(raw_output=True)
+    assert isinstance(result, Iterable)
+    assert list(result) == [b"baz"]
+
+    # When the `read` method is called successfully, then a new entry should be added to
+    # the history.
+    assert backend.history == [
+        {
+            "backend": "fs",
+            "action": "read",
+            "id": "/foo/file_3.txt",
+            "filename": "file_3.txt",
+            "size": 3,
+            "timestamp": frozen_now,
+        }
+    ]
+
+    # Given an absolute `target` path, the `read` method should read all files in the
+    # target directory and yield bytes.
+    result = backend.read(raw_output=True, target=absolute_path)
+    assert isinstance(result, Iterable)
+    assert list(result) == [b"foo", b"bar"]
+
+    # When the `read` method is called successfully, then a new entry should be added to
+    # the history.
+    assert backend.history[-2:] == [
+        {
+            "backend": "fs",
+            "action": "read",
+            "id": "/tmp/test_fs/file_1.txt",
+            "filename": "file_1.txt",
+            "size": 3,
+            "timestamp": frozen_now,
+        },
+        {
+            "backend": "fs",
+            "action": "read",
+            "id": "/tmp/test_fs/file_2.txt",
+            "filename": "file_2.txt",
+            "size": 3,
+            "timestamp": frozen_now,
+        },
+    ]
+
+    # Given a relative `target` path, the `read` method should read all files in the
+    # target directory relative to the default directory and yield bytes.
+    result = backend.read(raw_output=True, target="./bar")
+    assert isinstance(result, Iterable)
+    assert list(result) == [b"qux"]
+
+    # When the `read` method is called successfully, then a new entry should be added to
+    # the history.
+    assert backend.history[-1:] == [
+        {
+            "backend": "fs",
+            "action": "read",
+            "id": "/foo/bar/file_4.txt",
+            "filename": "file_4.txt",
+            "size": 3,
+            "timestamp": frozen_now,
+        },
+    ]
+
+    # Given a `chunk_size` and an absolute `target` path,
+    # the `read` method should yield the output bytes in chunks of the specified
+    # `chunk_size`.
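+    # For instance, reading b"foo" with `chunk_size=2` is expected to yield b"fo"
+    # and then b"o", as asserted below.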
+ result = backend.read(raw_output=True, target=absolute_path, chunk_size=2) + assert isinstance(result, Iterable) + assert list(result) == [b"fo", b"o", b"ba", b"r"] + + # When the `read` method is called successfully, then a new entry should be added to + # the history. + assert backend.history[-2:] == [ + { + "backend": "fs", + "action": "read", + "id": "/tmp/test_fs/file_1.txt", + "filename": "file_1.txt", + "size": 3, + "timestamp": frozen_now, + }, + { + "backend": "fs", + "action": "read", + "id": "/tmp/test_fs/file_2.txt", + "filename": "file_2.txt", + "size": 3, + "timestamp": frozen_now, + }, + ] + + +def test_backends_data_fs_data_backend_read_method_without_raw_output( + fs_backend, fs, monkeypatch +): + """Tests the `FSDataBackend.read` method with `raw_output` set to `False`.""" + # pylint: disable=invalid-name + + # File contents. + valid_dictionary = {"foo": "bar"} + valid_json = json.dumps(valid_dictionary) + + # Create files in absolute path directory. + absolute_path = "/tmp/test_fs/" + fs.create_file(absolute_path + "file_1.txt", contents=valid_json) + + # Create files in default directory. + fs.create_file("foo/file_2.txt", contents=f"{valid_json}\n{valid_json}") + fs.create_file( + "foo/bar/file_3.txt", contents=f"{valid_json}\n{valid_json}\n{valid_json}" + ) + + # Freeze the ralph.utils.now() value. + frozen_now = now() + monkeypatch.setattr("ralph.backends.data.fs.now", lambda: frozen_now) + + backend = fs_backend() + + # Given no `target`, the `read` method should read all files in the default + # directory and yield dictionaries. + result = backend.read(raw_output=False) + assert isinstance(result, Iterable) + assert list(result) == [valid_dictionary, valid_dictionary] + + # When the `read` method is called successfully, then a new entry should be added to + # the history. + assert backend.history == [ + { + "backend": "fs", + "action": "read", + "id": "/foo/file_2.txt", + "filename": "file_2.txt", + "size": 29, + "timestamp": frozen_now, + } + ] + + # Given an absolute `target` path, the `read` method should read all files in the + # target directory and yield dictionaries. + result = backend.read(raw_output=False, target=absolute_path) + assert isinstance(result, Iterable) + assert list(result) == [valid_dictionary] + + # When the `read` method is called successfully, then a new entry should be added to + # the history. + assert backend.history[-1:] == [ + { + "backend": "fs", + "action": "read", + "id": "/tmp/test_fs/file_1.txt", + "filename": "file_1.txt", + "size": 14, + "timestamp": frozen_now, + } + ] + + # Given a relative `target` path, the `read` method should read all files in the + # target directory relative to the default directory and yield dictionaries. + result = backend.read(raw_output=False, target="bar") + assert isinstance(result, Iterable) + assert list(result) == [valid_dictionary, valid_dictionary, valid_dictionary] + + # When the `read` method is called successfully, then a new entry should be added to + # the history. + assert backend.history[-1:] == [ + { + "backend": "fs", + "action": "read", + "id": "/foo/bar/file_3.txt", + "filename": "file_3.txt", + "size": 44, + "timestamp": frozen_now, + } + ] + + +def test_backends_data_fs_data_backend_read_method_with_ignore_errors(fs_backend, fs): + """Tests the `FSDataBackend.read` method with `ignore_errors` set to `True`, given + a file containing invalid JSON lines, should skip the invalid lines. + """ + # pylint: disable=invalid-name + + # File contents. 
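+    # Files are assumed to be parsed as JSON Lines (one JSON document per line), so
+    # a bare string such as "baz" below is an invalid line that `ignore_errors`
+    # should skip.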
+    valid_dictionary = {"foo": "bar"}
+    valid_json = json.dumps(valid_dictionary)
+    invalid_json = "baz"
+    valid_invalid_json = f"{valid_json}\n{invalid_json}\n{valid_json}"
+    invalid_valid_json = f"{invalid_json}\n{valid_json}\n{invalid_json}"
+
+    # Create files in absolute path directory.
+    absolute_path = "/tmp/test_fs/"
+    fs.create_file(absolute_path + "file_1.txt", contents=valid_json)
+    fs.create_file(absolute_path + "file_2.txt", contents=invalid_json)
+
+    # Create files in default directory.
+    fs.create_file("foo/file_3.txt", contents=valid_invalid_json)
+    fs.create_file("foo/bar/file_4.txt", contents=invalid_valid_json)
+
+    backend = fs_backend()
+
+    # Given no `target`, the `read` method should read all files in the default
+    # directory and yield dictionaries.
+    result = backend.read(ignore_errors=True)
+    assert isinstance(result, Iterable)
+    assert list(result) == [valid_dictionary, valid_dictionary]
+
+    # Given an absolute `target` path, the `read` method should read all files in the
+    # target directory and yield dictionaries.
+    result = backend.read(ignore_errors=True, target=absolute_path)
+    assert isinstance(result, Iterable)
+    assert list(result) == [valid_dictionary]
+
+    # Given a relative `target` path, the `read` method should read all files in the
+    # target directory relative to the default directory and yield dictionaries.
+    result = backend.read(ignore_errors=True, target="bar")
+    assert isinstance(result, Iterable)
+    assert list(result) == [valid_dictionary]
+
+
+def test_backends_data_fs_data_backend_read_method_without_ignore_errors(
+    fs_backend, fs, monkeypatch
+):
+    """Tests the `FSDataBackend.read` method with `ignore_errors` set to `False`, given
+    a file containing invalid JSON lines, should raise a `BackendException`.
+    """
+    # pylint: disable=invalid-name
+
+    # File contents.
+    valid_dictionary = {"foo": "bar"}
+    valid_json = json.dumps(valid_dictionary)
+    invalid_json = "baz"
+    valid_invalid_json = f"{valid_json}\n{invalid_json}\n{valid_json}"
+    invalid_valid_json = f"{invalid_json}\n{valid_json}\n{invalid_json}"
+
+    # Create files in absolute path directory.
+    absolute_path = "/tmp/test_fs/"
+    fs.create_file(absolute_path + "file_1.txt", contents=valid_json)
+    fs.create_file(absolute_path + "file_2.txt", contents=invalid_json)
+
+    # Create files in default directory.
+    fs.create_file("foo/file_3.txt", contents=valid_invalid_json)
+    fs.create_file("foo/bar/file_4.txt", contents=invalid_valid_json)
+
+    # Freeze the ralph.utils.now() value.
+    frozen_now = now()
+    monkeypatch.setattr("ralph.backends.data.fs.now", lambda: frozen_now)
+
+    backend = fs_backend()
+
+    # Given no `target`, the `read` method should read all files in the default
+    # directory.
+    # Given one file in the default directory with an invalid json at the second line,
+    # the `read` method should yield the first valid line and raise a `BackendException`
+    # at the second line.
+    result = backend.read(ignore_errors=False)
+    assert isinstance(result, Iterable)
+    assert next(result) == valid_dictionary
+    with pytest.raises(BackendException, match="Raised error:"):
+        next(result)
+
+    # When the `read` method fails to read a file entirely, then no entry should be
+    # added to the history.
+    assert not backend.history
+
+    # Given an absolute `target` path, the `read` method should read all files in the
+    # target directory.
+    # Given two files in the target directory, the first containing valid json and the
+    # second containing invalid json, the `read` method should yield the content of the
+    # first valid file and raise a `BackendException` when reading the invalid file.
+    result = backend.read(ignore_errors=False, target=absolute_path)
+    assert isinstance(result, Iterable)
+    assert next(result) == valid_dictionary
+    with pytest.raises(BackendException, match="Raised error:"):
+        next(result)
+
+    # When the `read` method succeeds in reading one file entirely, and fails to read
+    # another file, then a new entry for the successfully read file should be added to
+    # the history.
+    assert backend.history == [
+        {
+            "backend": "fs",
+            "action": "read",
+            "id": "/tmp/test_fs/file_1.txt",
+            "filename": "file_1.txt",
+            "size": 14,
+            "timestamp": frozen_now,
+        }
+    ]
+
+    # Given a relative `target` path, the `read` method should read all files in the
+    # target directory relative to the default directory.
+    # Given one file in the relative target directory with an invalid json at the first
+    # line, the `read` method should raise a `BackendException`.
+    result = backend.read(ignore_errors=False, target="bar")
+    assert isinstance(result, Iterable)
+    with pytest.raises(BackendException, match="Raised error:"):
+        next(result)
+
+    # When the `read` method fails to read a file entirely, then no new entry should be
+    # added to the history.
+    assert len(backend.history) == 1
+
+
+def test_backends_data_fs_data_backend_read_method_with_query(fs_backend, fs):
+    """Tests the `FSDataBackend.read` method, given a query argument."""
+    # pylint: disable=invalid-name
+
+    # File contents.
+    valid_dictionary = {"foo": "bar"}
+    valid_json = json.dumps(valid_dictionary)
+    invalid_json = "invalid JSON"
+
+    # Create files in absolute path directory.
+    absolute_path = "/tmp/test_fs/"
+    fs.create_file(absolute_path + "file_1.txt", contents=invalid_json)
+    fs.create_file(absolute_path + "file_2.txt", contents=valid_json)
+
+    # Create files in default directory.
+    default_path = "foo/"
+    fs.create_file(default_path + "file_3.txt", contents=valid_json)
+    fs.create_file(default_path + "file_4.txt", contents=valid_json)
+    fs.create_file(default_path + "/bar/file_5.txt", contents=invalid_json)
+
+    backend = fs_backend()
+
+    # Given a `query` and no `target`, the `read` method should only read the files that
+    # match the query in the default directory and yield dictionaries.
+    result = backend.read(query="file_*")
+    assert isinstance(result, Iterable)
+    assert list(result) == [valid_dictionary, valid_dictionary]
+
+    # Given a `query` and an absolute `target`, the `read` method should only read the
+    # files that match the query and yield dictionaries.
+    result = backend.read(query="file_2*", target=absolute_path)
+    assert isinstance(result, Iterable)
+    assert list(result) == [valid_dictionary]
+
+    # Given a `query`, no `target` and `raw_output` set to `True`, the `read` method
+    # should only read the files that match the query in the default directory and yield
+    # bytes.
+    result = backend.read(query="*file*", raw_output=True)
+    assert isinstance(result, Iterable)
+    assert list(result) == [valid_json.encode(), valid_json.encode()]
+    # A relative query should behave in the same way.
+    result = backend.read(query="bar/file_*", raw_output=True)
+    assert isinstance(result, Iterable)
+    assert list(result) == [invalid_json.encode()]
+
+    # Given a `query` that does not match any file, the `read` method should not yield
+    # anything.
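+    # The query is assumed to be a glob-style pattern matched against file names
+    # relative to the default directory, so an exact name with no match should
+    # yield an empty result.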
+ result = backend.read(query="file_not_found") + assert isinstance(result, Iterable) + assert not list(result) + + +@pytest.mark.parametrize( + "operation_type", [None, BaseOperationType.CREATE, BaseOperationType.INDEX] +) +def test_backends_data_fs_data_backend_write_method_with_file_exists_error( + operation_type, fs_backend, fs +): + """Tests the `FSDataBackend.write` method, given a target matching an + existing file and a `CREATE` or `INDEX` `operation_type`, should raise a + `BackendException`. + """ + # pylint: disable=invalid-name + + # Create files in default directory. + fs.create_file("foo/foo.txt", contents="content") + + backend = fs_backend() + + msg = ( + "foo.txt already exists and overwrite is not allowed with operation_type create" + " or index." + ) + with pytest.raises(BackendException, match=msg): + backend.write(target="foo.txt", data=[b"foo"], operation_type=operation_type) + + # When the `write` method fails, then no entry should be added to the history. + assert not sorted(backend.history, key=itemgetter("id")) + + +def test_backends_data_fs_data_backend_write_method_with_delete_operation( + fs_backend, +): + """Tests the `FSDataBackend.write` method, given a `DELETE` `operation_type`, should + raise a `BackendParameterException`. + """ + # pylint: disable=invalid-name + backend = fs_backend() + + msg = "Delete operation_type is not allowed." + with pytest.raises(BackendParameterException, match=msg): + backend.write(data=[b"foo"], operation_type=BaseOperationType.DELETE) + + # When the `write` method fails, then no entry should be added to the history. + assert not sorted(backend.history, key=itemgetter("id")) + + +def test_backends_data_fs_data_backend_write_method_with_update_operation( + fs_backend, fs, monkeypatch +): + """Tests the `FSDataBackend.write` method, given an `UPDATE` `operation_type`, + should overwrite the target file content with the provided data. + """ + # pylint: disable=invalid-name + + # Create files in default directory. + fs.create_file("foo/foo.txt", contents="content") + + # Freeze the ralph.utils.now() value. + frozen_now = now() + monkeypatch.setattr("ralph.backends.data.fs.now", lambda: frozen_now) + + backend = fs_backend() + kwargs = {"operation_type": BaseOperationType.UPDATE} + + # Overwriting foo.txt. + assert list(backend.read(query="foo.txt", raw_output=True)) == [b"content"] + assert backend.write(data=[b"bar"], target="foo.txt", **kwargs) == 1 + + # When the `write` method is called successfully, then a new entry should be added + # to the history. + assert backend.history == [ + { + "backend": "fs", + "action": "read", + "id": "/foo/foo.txt", + "filename": "foo.txt", + "size": 7, + "timestamp": frozen_now, + }, + { + "backend": "fs", + "action": "write", + "id": "/foo/foo.txt", + "filename": "foo.txt", + "size": 3, + "timestamp": frozen_now, + }, + ] + assert list(backend.read(query="foo.txt", raw_output=True)) == [b"bar"] + + # Clearing foo.txt. + assert backend.write(data=[b""], target="foo.txt", **kwargs) == 1 + assert not list(backend.read(query="foo.txt", raw_output=True)) + + # When the `write` method is called successfully, then a new entry should be added + # to the history. + assert backend.history[-2:] == [ + { + "backend": "fs", + "action": "write", + "id": "/foo/foo.txt", + "filename": "foo.txt", + "size": 0, + "timestamp": frozen_now, + }, + { + "backend": "fs", + "action": "read", + "id": "/foo/foo.txt", + "filename": "foo.txt", + "size": 0, + "timestamp": frozen_now, + }, + ] + + # Creating bar.txt. 
+ assert backend.write(data=[b"baz"], target="bar.txt", **kwargs) == 1 + assert list(backend.read(query="bar.txt", raw_output=True)) == [b"baz"] + + # When the `write` method is called successfully, then a new entry should be added + # to the history. + assert backend.history[-2:] == [ + { + "backend": "fs", + "action": "write", + "id": "/foo/bar.txt", + "filename": "bar.txt", + "size": 3, + "timestamp": frozen_now, + }, + { + "backend": "fs", + "action": "read", + "id": "/foo/bar.txt", + "filename": "bar.txt", + "size": 3, + "timestamp": frozen_now, + }, + ] + + +@pytest.mark.parametrize( + "data,expected", + [ + ([b"bar"], [b"foobar"]), + ([b"bar", b"baz"], [b"foobarbaz"]), + ((b"bar" for _ in range(1)), [b"foobar"]), + ((b"bar" for _ in range(3)), [b"foobarbarbar"]), + ( + [{}, {"foo": [1, 2, 4], "bar": {"baz": None}}], + [b'foo{}\n{"foo": [1, 2, 4], "bar": {"baz": null}}\n'], + ), + ], +) +def test_backends_data_fs_data_backend_write_method_with_append_operation( + data, expected, fs_backend, fs, monkeypatch +): + """Tests the `FSDataBackend.write` method, given an `APPEND` `operation_type`, + should append the provided data to the end of the target file. + """ + # pylint: disable=invalid-name + + # Create files in default directory. + fs.create_file("foo/foo.txt", contents="foo") + + # Freeze the ralph.utils.now() value. + frozen_now = now() + monkeypatch.setattr("ralph.backends.data.fs.now", lambda: frozen_now) + + backend = fs_backend() + kwargs = {"operation_type": BaseOperationType.APPEND} + + # Overwriting foo.txt. + assert list(backend.read(query="foo.txt", raw_output=True)) == [b"foo"] + assert backend.write(data=data, target="foo.txt", **kwargs) == 1 + assert list(backend.read(query="foo.txt", raw_output=True)) == expected + + # When the `write` method is called successfully, then a new entry should be added + # to the history. + assert backend.history == [ + { + "backend": "fs", + "action": "read", + "id": "/foo/foo.txt", + "filename": "foo.txt", + "size": 3, + "timestamp": frozen_now, + }, + { + "backend": "fs", + "action": "write", + "id": "/foo/foo.txt", + "filename": "foo.txt", + "size": len(expected[0]), + "timestamp": frozen_now, + }, + { + "backend": "fs", + "action": "read", + "id": "/foo/foo.txt", + "filename": "foo.txt", + "size": len(expected[0]), + "timestamp": frozen_now, + }, + ] + + +def test_backends_data_fs_data_backend_write_method_without_target( + fs_backend, monkeypatch +): + """Tests the `FSDataBackend.write` method, given no `target` argument, + should create a new random file and write the provided data into it. + """ + # pylint: disable=invalid-name + + # Freeze the ralph.utils.now() value. + frozen_now = now() + monkeypatch.setattr("ralph.backends.data.fs.now", lambda: frozen_now) + + # Freeze the uuid4() value. + frozen_uuid4 = uuid4() + monkeypatch.setattr("ralph.backends.data.fs.uuid4", lambda: frozen_uuid4) + + backend = fs_backend(path=".") + + expected_filename = f"{frozen_now}-{frozen_uuid4}" + assert not os.path.exists(expected_filename) + assert backend.write(data=[b"foo", b"bar"]) == 1 + assert os.path.exists(expected_filename) + assert list(backend.read(query=expected_filename, raw_output=True)) == [b"foobar"] + + # When the `write` method is called successfully, then a new entry should be added + # to the history. 
+    assert backend.history == [
+        {
+            "backend": "fs",
+            "action": "write",
+            "id": f"/{expected_filename}",
+            "filename": expected_filename,
+            "size": 6,
+            "timestamp": frozen_now,
+        },
+        {
+            "backend": "fs",
+            "action": "read",
+            "id": f"/{expected_filename}",
+            "filename": expected_filename,
+            "size": 6,
+            "timestamp": frozen_now,
+        },
+    ]
diff --git a/tests/conftest.py b/tests/conftest.py
index 3e1754b31..52a1cbef4 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -19,6 +19,7 @@
     es_data_stream,
     es_forwarding,
     events,
+    fs_backend,
     lrs,
     mongo,
     mongo_forwarding,
diff --git a/tests/fixtures/backends.py b/tests/fixtures/backends.py
index 5a21f2f42..301f51dea 100644
--- a/tests/fixtures/backends.py
+++ b/tests/fixtures/backends.py
@@ -22,12 +22,13 @@
 from pymongo import MongoClient
 from pymongo.errors import CollectionInvalid
 
+from ralph.backends.data.fs import FSDataBackend, FSDataBackendSettings
 from ralph.backends.database.clickhouse import ClickHouseDatabase
 from ralph.backends.database.es import ESDatabase
 from ralph.backends.database.mongo import MongoDatabase
 from ralph.backends.storage.s3 import S3Storage
 from ralph.backends.storage.swift import SwiftStorage
-from ralph.conf import ClickhouseClientOptions, Settings, settings
+from ralph.conf import ClickhouseClientOptions, Settings, core_settings
 
 # ClickHouse backend defaults
 CLICKHOUSE_TEST_DATABASE = os.environ.get(
@@ -159,6 +160,25 @@ def es_forwarding():
     yield es_client
 
 
+@pytest.fixture
+def fs_backend(fs, settings_fs):
+    """Returns the `get_fs_data_backend` function."""
+    # pylint: disable=invalid-name,redefined-outer-name,unused-argument
+    fs.create_dir("foo")
+
+    def get_fs_data_backend(path: str = "foo"):
+        """Returns an instance of FSDataBackend."""
+        settings = FSDataBackendSettings(
+            DEFAULT_CHUNK_SIZE=1024,
+            DEFAULT_DIRECTORY_PATH=path,
+            DEFAULT_QUERY_STRING="*",
+            LOCALE_ENCODING="utf8",
+        )
+        return FSDataBackend(settings)
+
+    return get_fs_data_backend
+
+
 def get_mongo_fixture(
     connection_uri=MONGO_TEST_CONNECTION_URI,
     database=MONGO_TEST_DATABASE,
@@ -311,7 +331,7 @@ def settings_fs(fs, monkeypatch):
 
     monkeypatch.setattr(
         "ralph.backends.mixins.settings",
-        Settings(HISTORY_FILE=Path(settings.APP_DIR / "history.json")),
+        Settings(HISTORY_FILE=Path(core_settings.APP_DIR / "history.json")),
    )

From 05f502538d47aba0712114b91cd0fcf2be354dda Mon Sep 17 00:00:00 2001
From: Arnaud Henric
Date: Fri, 12 May 2023 10:15:43 +0200
Subject: [PATCH 06/65] =?UTF-8?q?=E2=9C=A8(backends)=20fix=20base=20List?=
 =?UTF-8?q?=20type=20using=20pydantic?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Update the List type annotation to use typing.List to support Python 3.8.
---
 src/ralph/backends/lrs/base.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/ralph/backends/lrs/base.py b/src/ralph/backends/lrs/base.py
index aa574544e..4857bcbac 100644
--- a/src/ralph/backends/lrs/base.py
+++ b/src/ralph/backends/lrs/base.py
@@ -3,7 +3,7 @@
 from abc import abstractmethod
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Literal, Optional
+from typing import List, Literal, Optional
 from uuid import UUID
 
 from pydantic import BaseModel
@@ -15,7 +15,7 @@
 class StatementQueryResult:
     """Represents a common interface for results of an LRS statements query."""
 
-    statements: list[dict]
+    statements: List[dict]
     pit_id: str
     search_after: str
 
@@ -51,5 +51,5 @@ def query_statements(self, params: StatementParameters) -> StatementQueryResult:
         """Returns the
statements query payload using xAPI parameters.""" @abstractmethod - def query_statements_by_ids(self, ids: list[str]) -> list: + def query_statements_by_ids(self, ids: List[str]) -> list: """Returns the list of matching statement IDs from the database.""" From de601dd7310138463659b0e645bff99741a61199 Mon Sep 17 00:00:00 2001 From: Arnaud Henric Date: Fri, 12 May 2023 11:42:28 +0200 Subject: [PATCH 07/65] =?UTF-8?q?=E2=9C=A8(backends)=20add=20mongo=20unifi?= =?UTF-8?q?ed=20interface?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add Mongo backend under the new common 'data' interface. With Mongo under the new data interface, tests are updated as well. Storage and Database backends had similar interfaces and usage, so a new Data Backend interface has been created. --- src/ralph/backends/data/mongo.py | 475 ++++++++++++ tests/backends/data/test_mongo.py | 1116 +++++++++++++++++++++++++++++ 2 files changed, 1591 insertions(+) create mode 100644 src/ralph/backends/data/mongo.py create mode 100644 tests/backends/data/test_mongo.py diff --git a/src/ralph/backends/data/mongo.py b/src/ralph/backends/data/mongo.py new file mode 100644 index 000000000..befbb94c4 --- /dev/null +++ b/src/ralph/backends/data/mongo.py @@ -0,0 +1,475 @@ +"""MongoDB data backend for Ralph.""" + +import hashlib +import json +import logging +import struct +from io import IOBase +from itertools import chain +from typing import ( + Any, + Dict, + Generator, + Iterable, + Iterator, + List, + Literal, + Optional, + Union, +) +from uuid import uuid4 + +from bson.objectid import ObjectId +from dateutil.parser import isoparse +from pydantic import Json +from pymongo import ASCENDING, DESCENDING, MongoClient, ReplaceOne +from pymongo.errors import BulkWriteError, ConnectionFailure, PyMongoError + +from ralph.backends.data.base import ( + BaseDataBackend, + BaseDataBackendSettings, + BaseOperationType, + BaseQuery, + DataBackendStatus, + enforce_query_checks, +) +from ralph.backends.lrs.base import ( + BaseLRSBackend, + StatementParameters, + StatementQueryResult, +) +from ralph.conf import BaseSettingsConfig, MongoClientOptions +from ralph.exceptions import ( + BackendException, + BackendParameterException, + BadFormatException, +) + +logger = logging.getLogger(__name__) + + +class MongoDataBackendSettings(BaseDataBackendSettings): + """Represents the Mongo data backend default configuration. + + Attributes: + CONNECTION_URI (str): The MongoDB connection URI. + DATABASE (str): The MongoDB database to connect to. + DEFAULT_COLLECTION (str): The MongoDB database collection to get objects from. + CLIENT_OPTIONS (MongoClientOptions): A dictionary of valid options + DEFAULT_QUERY_STRING (str): The default query string to use. + DEFAULT_CHUNK_SIZE (int): The default chunk size to use when none is provided. + LOCALE_ENCODING (str): The locale encoding to use when none is provided. 
+ """ + + class Config(BaseSettingsConfig): + """Pydantic Configuration.""" + + env_prefix = "RALPH_BACKENDS__DATA__MONGO__" + + CONNECTION_URI: str = None + DATABASE: str = None + DEFAULT_COLLECTION: str = None + CLIENT_OPTIONS: MongoClientOptions = MongoClientOptions() + DEFAULT_QUERY_STRING: str = "*" + DEFAULT_CHUNK_SIZE: int = 500 + LOCALE_ENCODING: str = "utf8" + + +class MongoQuery(BaseQuery): + """Mongo query model.""" + + # pylint: disable=unsubscriptable-object + query_string: Optional[ + Json[ + Dict[ + Literal["filter", "projection"], + dict, + ] + ] + ] + filter: Optional[dict] + projection: Optional[dict] + + +class MongoDataBackend(BaseDataBackend): + """Mongo database backend.""" + + name = "mongo" + query_model = MongoQuery + default_operation_type = BaseOperationType.CREATE + settings_class = MongoDataBackendSettings + + def __init__(self, settings: settings_class = None): + """Instantiates the Mongo client. + + Args: + settings (MongoDataBackendSettings): The Mongo data backend settings. + CONNECTION_URI (str): The MongoDB connection URI. + DATABASE (str): The MongoDB database to connect to. + DEFAULT_COLLECTION (str): The MongoDB database collection. + CLIENT_OPTIONS (MongoClientOptions): A dictionary of valid options + DEFAULT_QUERY_STRING (str): The default query string to use. + DEFAULT_CHUNK_SIZE (int): The default chunk size to use. + LOCALE_ENCODING (str): The locale encoding to use when none is provided. + """ + self.client = MongoClient( + settings.CONNECTION_URI, **settings.CLIENT_OPTIONS.dict() + ) + self.database = getattr(self.client, settings.DATABASE) + self.collection = getattr(self.database, settings.DEFAULT_COLLECTION) + self.default_chunk_size = settings.DEFAULT_CHUNK_SIZE + self.locale_encoding = settings.LOCALE_ENCODING + + def status(self) -> DataBackendStatus: + """Checks MongoDB cluster connection status.""" + # Check Mongo cluster connection + try: + self.client.admin.command("ping") + except ConnectionFailure: + return DataBackendStatus.AWAY + + # Check cluster status + if self.client.admin.command("serverStatus").get("ok", 0.0) < 1.0: + return DataBackendStatus.ERROR + + return DataBackendStatus.OK + + def list( + self, target: str = None, details: bool = False, new: bool = False + ) -> Iterator[Union[str, dict]]: + """Lists collections for a given database. + + Args: + target (str): The database to list collections from. + details (bool): Get detailed archive information instead of just ids. + new (bool): Given the history, list only not already fetched collections. + """ + database = self.database if not target else getattr(self.client, target) + for col in database.list_collections(): + if details: + yield col + else: + yield str(col.get("name")) + + @enforce_query_checks + def read( + self, + *, + query: Union[str, MongoQuery] = None, + target: str = None, + chunk_size: Union[None, int] = None, + raw_output: bool = False, + ignore_errors: bool = False, + ) -> Iterator[Union[bytes, dict]]: + """Gets collection documents and yields them. + + Args: + query (Union[str, MongoQuery]): The query to use when fetching documents. + target (str): The collection to get documents from. + chunk_size (Union[None, int]): The chunk size to use when fetching docs. + raw_output (bool): Whether to return raw bytes or deserialized documents. + ignore_errors (bool): Whether to ignore errors when reading documents. 
+ """ + reader = self._read_raw if raw_output else self._read_dict + if not chunk_size: + chunk_size = self.default_chunk_size + find_kwargs = {} + if query.query_string: + find_kwargs = query.query_string + else: + find_kwargs = {"filter": query.filter, "projection": query.projection} + + # deserialize query_string if exists + for document in self._find(target=target, batch_size=chunk_size, **find_kwargs): + document.update({"_id": str(document.get("_id"))}) + yield reader(document) + + @staticmethod + def to_documents( + data: Iterable[dict], + ignore_errors: bool = False, + operation_type: Union[None, BaseOperationType] = default_operation_type, + ) -> Generator[dict, None, None]: + """Converts `stream` lines (one statement per line) to Mongo documents. + + We expect statements to have at least an `id` and a `timestamp` field that will + be used to compute a unique MongoDB Object ID. This ensures that we will not + duplicate statements in our database and allows us to support pagination. + """ + for statement in data: + if "id" not in statement and operation_type == BaseOperationType.CREATE: + msg = f"statement {statement} has no 'id' field" + if ignore_errors: + logger.warning(msg) + continue + raise BadFormatException(msg) + if "timestamp" not in statement: + msg = f"statement {statement} has no 'timestamp' field" + if ignore_errors: + logger.warning(msg) + continue + raise BadFormatException(msg) + try: + timestamp = int(isoparse(statement["timestamp"]).timestamp()) + except ValueError as err: + msg = f"statement {statement} has an invalid 'timestamp' field" + if ignore_errors: + logger.warning(msg) + continue + raise BadFormatException(msg) from err + document = { + "_id": ObjectId( + # This might become a problem in February 2106. + # Meanwhile, we use the timestamp in the _id field for pagination. 
+                    struct.pack(">I", timestamp)
+                    + bytes.fromhex(
+                        hashlib.sha256(
+                            bytes(statement.get("id", str(uuid4())), "utf-8")
+                        ).hexdigest()[:16]
+                    )
+                ),
+                "_source": statement,
+            }
+
+            yield document
+
+    def bulk_import(self, batch: list, ignore_errors: bool = False, collection=None):
+        """Inserts a batch of documents into the selected database collection."""
+        try:
+            collection = self.get_collection(collection)
+            new_documents = collection.insert_many(batch)
+        except BulkWriteError as error:
+            if not ignore_errors:
+                raise BackendException(
+                    *error.args, f"{error.details['nInserted']} successful writes"
+                ) from error
+            logger.warning(
+                "Bulk import failed for the current documents chunk but you chose "
+                "to ignore it.",
+            )
+            return error.details["nInserted"]
+
+        inserted_count = len(new_documents.inserted_ids)
+        logger.debug("Inserted %d documents chunk with success", inserted_count)
+
+        return inserted_count
+
+    def bulk_delete(self, batch: list, collection=None):
+        """Deletes a batch of documents from the selected database collection."""
+        collection = self.get_collection(collection)
+        new_documents = collection.delete_many({"_source.id": {"$in": batch}})
+        deleted_count = new_documents.deleted_count
+        logger.debug("Deleted %d documents chunk with success", deleted_count)
+
+        return deleted_count
+
+    def bulk_update(self, batch: list, collection=None):
+        """Updates a batch of documents in the selected database collection."""
+        collection = self.get_collection(collection)
+        new_documents = collection.bulk_write(batch)
+        modified_count = new_documents.modified_count
+        logger.debug("Updated %d documents chunk with success", modified_count)
+        return modified_count
+
+    def get_collection(self, collection=None):
+        """Returns the collection to use for the current operation."""
+        if collection is None:
+            collection = self.collection
+        elif isinstance(collection, str):
+            collection = getattr(self.database, collection)
+        return collection
+
+    def write(  # pylint: disable=too-many-arguments,too-many-branches
+        self,
+        data: Union[IOBase, Iterable[bytes], Iterable[dict]],
+        target: Union[None, str] = None,
+        chunk_size: Union[None, int] = None,
+        ignore_errors: bool = False,
+        operation_type: Union[None, BaseOperationType] = None,
+    ) -> int:
+        """Writes documents from `data` to the instance collection.
+
+        Args:
+            data: The data to write to the database.
+            target: The target collection to write to.
+            chunk_size: The number of documents to write at once.
+            ignore_errors: Whether to ignore errors or not.
+            operation_type: The operation type to use for the write operation.
+ """ + if not operation_type: + operation_type = self.default_operation_type + + if not chunk_size: + chunk_size = self.default_chunk_size + + collection = self.get_collection(target) + logger.debug( + "Start writing to the %s collection of the %s database (chunk size: %d)", + collection, + self.database, + chunk_size, + ) + + data = iter(data) + try: + first_record = next(data) + data = chain([first_record], data) + if isinstance(first_record, bytes): + data = self._parse_bytes_to_dict(data, ignore_errors) + except StopIteration: + logger.info("Data Iterator is empty; skipping write to target.") + return 0 + + success = 0 + batch = [] + if operation_type == BaseOperationType.UPDATE: + for document in data: + document_id = document.get("id") + batch.append( + ReplaceOne( + {"_source.id": {"$eq": document_id}}, + {"_source": document}, + ) + ) + if len(batch) >= chunk_size: + success += self.bulk_update(batch, collection=collection) + batch = [] + + if len(batch) > 0: + success += self.bulk_update(batch, collection=collection) + + logger.debug("Updated %d documents chunk with success", success) + elif operation_type == BaseOperationType.DELETE: + for document in data: + document_id = document.get("id") + batch.append(document_id) + if len(batch) >= chunk_size: + success += self.bulk_delete(batch, collection=collection) + batch = [] + + if len(batch) > 0: + success += self.bulk_delete(batch, collection=collection) + + logger.debug("Deleted %d documents chunk with success", success) + elif operation_type in [BaseOperationType.INDEX, BaseOperationType.CREATE]: + for document in self.to_documents( + data, ignore_errors=ignore_errors, operation_type=operation_type + ): + batch.append(document) + if len(batch) >= chunk_size: + success += self.bulk_import( + batch, ignore_errors=ignore_errors, collection=collection + ) + batch = [] + + # Edge case: if the total number of documents is lower than the chunk size + if len(batch) > 0: + success += self.bulk_import( + batch, ignore_errors=ignore_errors, collection=collection + ) + + logger.debug("Inserted %d documents with success", success) + else: + msg = "%s operation_type is not allowed." + logger.error(msg, operation_type.name) + raise BackendParameterException(msg % operation_type.name) + return success + + def _find(self, target: Union[None, str] = None, **kwargs): + """Wraps the MongoClient.collection.find method. + + Raises: + BackendException: raised for any failure. + """ + try: + collection = self.get_collection(target) + return list(collection.find(**kwargs)) + except (PyMongoError, IndexError, TypeError, ValueError) as error: + msg = "Failed to execute MongoDB query" + logger.error("%s. 
%s", msg, error) + raise BackendException(msg, *error.args) from error + + @staticmethod + def _parse_bytes_to_dict( + raw_documents: Iterable[bytes], ignore_errors: bool + ) -> Iterator[dict]: + """Reads the `raw_documents` Iterable and yields dictionaries.""" + for raw_document in raw_documents: + try: + decoded_item = raw_document.decode("utf-8") + json_data = json.loads(decoded_item) + yield json_data + except (TypeError, json.JSONDecodeError) as err: + logger.error("Raised error: %s, for document %s", err, raw_document) + if ignore_errors: + continue + raise err + + def _read_raw(self, document: Dict[str, Any]) -> bytes: + """Reads the `documents` Iterable and yields bytes.""" + return json.dumps(document).encode(self.locale_encoding) + + @staticmethod + def _read_dict(document: Dict[str, Any]) -> dict: + """Reads the `documents` Iterable and yields dictionaries.""" + return document + + +class MongoLRSBackend(BaseLRSBackend, MongoDataBackend): + """MongoDB LRS backend implementation.""" + + def query_statements(self, params: StatementParameters) -> StatementQueryResult: + """Returns the results of a statements query using xAPI parameters.""" + mongo_query_filters = {} + + if params.statementId: + mongo_query_filters.update({"_source.id": params.statementId}) + + if params.agent: + mongo_query_filters.update({"_source.actor.account.name": params.agent}) + + if params.verb: + mongo_query_filters.update({"_source.verb.id": params.verb}) + + if params.activity: + mongo_query_filters.update( + { + "_source.object.objectType": "Activity", + "_source.object.id": params.activity, + }, + ) + + if params.since: + mongo_query_filters.update({"_source.timestamp": {"$gt": params.since}}) + + if params.until: + mongo_query_filters.update({"_source.timestamp": {"$lte": params.until}}) + + if params.search_after: + search_order = "$gt" if params.ascending else "$lt" + mongo_query_filters.update( + {"_id": {search_order: ObjectId(params.search_after)}} + ) + + mongo_sort_order = ASCENDING if params.ascending else DESCENDING + mongo_query_sort = [ + ("_source.timestamp", mongo_sort_order), + ("_id", mongo_sort_order), + ] + + mongo_response = self._find( + filter=mongo_query_filters, limit=params.limit, sort=mongo_query_sort + ) + search_after = None + if mongo_response: + search_after = mongo_response[-1]["_id"] + + return StatementQueryResult( + statements=[document["_source"] for document in mongo_response], + pit_id=None, + search_after=search_after, + ) + + def query_statements_by_ids(self, ids: List[str]) -> List: + """Returns the list of matching statement IDs from the database.""" + return self._find(filter={"_source.id": {"$in": ids}}) diff --git a/tests/backends/data/test_mongo.py b/tests/backends/data/test_mongo.py new file mode 100644 index 000000000..5fb7425c8 --- /dev/null +++ b/tests/backends/data/test_mongo.py @@ -0,0 +1,1116 @@ +# pylint: disable=too-many-lines +"""Tests for Ralph mongo data backend.""" + +import json +import logging +from datetime import datetime + +import pytest +from bson.objectid import ObjectId +from pymongo import MongoClient +from pymongo.errors import PyMongoError + +from ralph.backends.data.base import BaseOperationType, DataBackendStatus +from ralph.backends.data.mongo import MongoDataBackend, MongoLRSBackend, MongoQuery +from ralph.backends.lrs.base import StatementParameters +from ralph.exceptions import ( + BackendException, + BackendParameterException, + BadFormatException, +) + +from tests.fixtures.backends import ( + MONGO_TEST_COLLECTION, + 
MONGO_TEST_CONNECTION_URI, + MONGO_TEST_DATABASE, + MONGO_TEST_FORWARDING_COLLECTION, +) + + +def test_backends_data_mongo_data_backend_instantiation_with_settings(): + """Test the Mongo backend instantiation.""" + assert MongoDataBackend.name == "mongo" + settings = MongoDataBackend.settings_class( + CONNECTION_URI=MONGO_TEST_CONNECTION_URI, + DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, + ) + backend = MongoDataBackend(settings) + + assert isinstance(backend.client, MongoClient) + assert hasattr(backend.client, MONGO_TEST_DATABASE) + database = getattr(backend.client, MONGO_TEST_DATABASE) + assert hasattr(database, MONGO_TEST_COLLECTION) + + +def test_backends_data_mongo_data_backend_read_method_without_raw_output(mongo): + """Test the mongo backend get method.""" + # Create records + timestamp = {"timestamp": "2022-06-27T15:36:50"} + documents = MongoDataBackend.to_documents( + [ + {"id": "foo", **timestamp}, + {"id": "bar", **timestamp}, + ] + ) + database = getattr(mongo, MONGO_TEST_DATABASE) + collection = getattr(database, MONGO_TEST_COLLECTION) + collection.insert_many(documents) + + # Get backend + settings = MongoDataBackend.settings_class( + CONNECTION_URI=MONGO_TEST_CONNECTION_URI, + DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, + ) + backend = MongoDataBackend(settings) + expected = [ + {"_id": "62b9ce922c26b46b68ffc68f", "_source": {"id": "foo", **timestamp}}, + {"_id": "62b9ce92fcde2b2edba56bf4", "_source": {"id": "bar", **timestamp}}, + ] + assert list(backend.read()) == expected + assert list(backend.read(chunk_size=2)) == expected + assert list(backend.read(chunk_size=1000)) == expected + + +def test_backends_data_mongo_data_backend_read_method_with_query_string(mongo): + """Test the mongo backend get method with query string.""" + # Create records + timestamp = {"timestamp": "2022-06-27T15:36:50"} + documents = MongoDataBackend.to_documents( + [ + {"id": "foo", **timestamp}, + {"id": "bar", **timestamp}, + ] + ) + database = getattr(mongo, MONGO_TEST_DATABASE) + collection = getattr(database, MONGO_TEST_COLLECTION) + collection.insert_many(documents) + + # Get backend + settings = MongoDataBackend.settings_class( + CONNECTION_URI=MONGO_TEST_CONNECTION_URI, + DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, + ) + backend = MongoDataBackend(settings) + expected = [ + {"_id": "62b9ce922c26b46b68ffc68f", "_source": {"id": "foo", **timestamp}}, + ] + query = MongoQuery( + query_string=json.dumps({"filter": {"_source.id": {"$eq": "foo"}}}) + ) + assert list(backend.read(query=query)) == expected + assert list(backend.read(query=query, chunk_size=2)) == expected + assert list(backend.read(query=query, chunk_size=1000)) == expected + + +def test_backends_data_mongo_data_backend_list_method(mongo): + """Test the mongo backend list method.""" + # Create records + timestamp = {"timestamp": "2022-06-27T15:36:50"} + documents = MongoDataBackend.to_documents( + [ + {"id": "foo", **timestamp}, + {"id": "bar", **timestamp}, + ] + ) + database = getattr(mongo, MONGO_TEST_DATABASE) + collection = getattr(database, MONGO_TEST_COLLECTION) + collection.insert_many(documents) + + # Get backend + settings = MongoDataBackend.settings_class( + CONNECTION_URI=MONGO_TEST_CONNECTION_URI, + DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, + ) + backend = MongoDataBackend(settings) + assert list(backend.list(details=True))[0]["name"] == MONGO_TEST_COLLECTION + assert 
list(backend.list(details=False)) == [MONGO_TEST_COLLECTION]
+
+
+def test_backends_data_mongo_data_backend_list_method_with_details(mongo):
+    """Test the mongo backend list method with details."""
+    # Create records
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    documents = MongoDataBackend.to_documents(
+        [
+            {"id": "foo", **timestamp},
+            {"id": "bar", **timestamp},
+        ]
+    )
+    database = getattr(mongo, MONGO_TEST_DATABASE)
+    collection = getattr(database, MONGO_TEST_COLLECTION)
+    collection.insert_many(documents)
+
+    # Get backend
+    settings = MongoDataBackend.settings_class(
+        CONNECTION_URI=MONGO_TEST_CONNECTION_URI,
+        DATABASE=MONGO_TEST_DATABASE,
+        DEFAULT_COLLECTION=MONGO_TEST_COLLECTION,
+    )
+    backend = MongoDataBackend(settings)
+    assert [elt["_id"] for elt in list(backend.read())] == [
+        "62b9ce922c26b46b68ffc68f",
+        "62b9ce92fcde2b2edba56bf4",
+    ]
+
+
+def test_backends_data_mongo_data_backend_list_method_with_target(mongo):
+    """Test the mongo backend list method with a target."""
+    # Create records
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    documents = MongoDataBackend.to_documents(
+        [
+            {"id": "foo", **timestamp},
+            {"id": "bar", **timestamp},
+        ]
+    )
+    database = getattr(mongo, MONGO_TEST_DATABASE)
+    collection = getattr(database, MONGO_TEST_COLLECTION)
+    collection.insert_many(documents)
+
+    # Get backend
+    settings = MongoDataBackend.settings_class(
+        CONNECTION_URI=MONGO_TEST_CONNECTION_URI,
+        DATABASE=MONGO_TEST_DATABASE,
+        DEFAULT_COLLECTION=MONGO_TEST_COLLECTION,
+    )
+    backend = MongoDataBackend(settings)
+    assert [elt["_id"] for elt in list(backend.read(target=MONGO_TEST_COLLECTION))] == [
+        "62b9ce922c26b46b68ffc68f",
+        "62b9ce92fcde2b2edba56bf4",
+    ]
+
+
+def test_backends_database_mongo_get_method_with_raw_output(mongo):
+    """Test the mongo backend get method with raw output."""
+    # Create records
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    documents = MongoDataBackend.to_documents(
+        [
+            {"id": "foo", **timestamp},
+            {"id": "bar", **timestamp},
+        ]
+    )
+    database = getattr(mongo, MONGO_TEST_DATABASE)
+    collection = getattr(database, MONGO_TEST_COLLECTION)
+    collection.insert_many(documents)
+
+    # Get backend
+    settings = MongoDataBackend.settings_class(
+        CONNECTION_URI=MONGO_TEST_CONNECTION_URI,
+        DATABASE=MONGO_TEST_DATABASE,
+        DEFAULT_COLLECTION=MONGO_TEST_COLLECTION,
+    )
+    backend = MongoDataBackend(settings)
+    expected = [
+        {"_id": "62b9ce922c26b46b68ffc68f", "id": "foo", **timestamp},
+        {"_id": "62b9ce92fcde2b2edba56bf4", "id": "bar", **timestamp},
+    ]
+    results = list(backend.read(raw_output=True))
+    assert len(results) == 2
+    assert isinstance(results[0], bytes)
+    assert json.loads(results[0])["_source"]["id"] == expected[0]["id"]
+
+
+def test_backends_database_mongo_get_method_with_target(mongo):
+    """Test the mongo backend get method with raw output and a target."""
+    # Create records
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    documents = MongoDataBackend.to_documents(
+        [
+            {"id": "foo", **timestamp},
+            {"id": "bar", **timestamp},
+        ]
+    )
+    database = getattr(mongo, MONGO_TEST_DATABASE)
+    collection = getattr(database, MONGO_TEST_COLLECTION)
+    collection.insert_many(documents)
+
+    # Get backend
+    settings = MongoDataBackend.settings_class(
+        CONNECTION_URI=MONGO_TEST_CONNECTION_URI,
+        DATABASE=MONGO_TEST_DATABASE,
+        DEFAULT_COLLECTION=MONGO_TEST_COLLECTION,
+    )
+    backend = MongoDataBackend(settings)
+    expected = [
+        {"_id": "62b9ce922c26b46b68ffc68f", "id": "foo", **timestamp},
+        {"_id": "62b9ce92fcde2b2edba56bf4", "id": "bar", **timestamp},
+    ]
+    results =
list(backend.read(raw_output=True, target=MONGO_TEST_COLLECTION)) + assert len(results) == 2 + assert isinstance(results[0], bytes) + assert json.loads(results[0])["_source"]["id"] == expected[0]["id"] + + +def test_backends_data_mongo_data_backend_read_method_with_query(mongo): + """Test the mongo backend get method with a custom query.""" + # Create records + timestamp = {"timestamp": datetime.now().isoformat()} + documents = MongoDataBackend.to_documents( + [ + {"id": "foo", "bool": 1, **timestamp}, + {"id": "bar", "bool": 0, **timestamp}, + {"id": "lol", "bool": 1, **timestamp}, + ] + ) + database = getattr(mongo, MONGO_TEST_DATABASE) + collection = getattr(database, MONGO_TEST_COLLECTION) + collection.insert_many(documents) + + # Get backend + settings = MongoDataBackend.settings_class( + CONNECTION_URI=MONGO_TEST_CONNECTION_URI, + DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, + ) + backend = MongoDataBackend(settings) + + # Test filtering + query = MongoQuery(filter={"_source.bool": {"$eq": 1}}) + results = list(backend.read(query=query)) + assert len(results) == 2 + assert results[0]["_source"]["id"] == "foo" + assert results[1]["_source"]["id"] == "lol" + + # Test projection + query = MongoQuery(projection={"_source.bool": 1}) + results = list(backend.read(query=query)) + assert len(results) == 3 + assert list(results[0]["_source"].keys()) == ["bool"] + assert list(results[1]["_source"].keys()) == ["bool"] + assert list(results[2]["_source"].keys()) == ["bool"] + + # Test filtering and projection + query = MongoQuery( + filter={"_source.bool": {"$eq": 0}}, projection={"_source.id": 1} + ) + results = list(backend.read(query=query)) + assert len(results) == 1 + assert results[0]["_source"]["id"] == "bar" + assert list(results[0]["_source"].keys()) == ["id"] + + +def test_backends_database_mongo_to_documents_method(): + """Test the mongo backend to_documents method.""" + timestamp = {"timestamp": "2022-06-27T15:36:50"} + statements = [ + {"id": "foo", **timestamp}, + {"id": "bar", **timestamp}, + {"id": "bar", **timestamp}, + ] + documents = MongoDataBackend.to_documents(statements) + + assert next(documents) == { + "_id": ObjectId("62b9ce922c26b46b68ffc68f"), + "_source": {"id": "foo", **timestamp}, + } + assert next(documents) == { + "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), + "_source": {"id": "bar", **timestamp}, + } + # Identical statement ID produces the same ObjectId + assert next(documents) == { + "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), + "_source": {"id": "bar", **timestamp}, + } + + +def test_backends_database_mongo_to_documents_method_when_statement_has_no_id(caplog): + """Test the mongo backend to_documents method when a statement has no id field.""" + timestamp = {"timestamp": "2022-06-27T15:36:50"} + statements = [{"id": "foo", **timestamp}, timestamp, {"id": "bar", **timestamp}] + + documents = MongoDataBackend.to_documents(statements, ignore_errors=False) + assert next(documents) == { + "_id": ObjectId("62b9ce922c26b46b68ffc68f"), + "_source": {"id": "foo", **timestamp}, + } + with pytest.raises( + BadFormatException, match=f"statement {timestamp} has no 'id' field" + ): + next(documents) + + documents = MongoDataBackend.to_documents(statements, ignore_errors=True) + assert next(documents) == { + "_id": ObjectId("62b9ce922c26b46b68ffc68f"), + "_source": {"id": "foo", **timestamp}, + } + assert next(documents) == { + "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), + "_source": {"id": "bar", **timestamp}, + } + assert len(caplog.records) 
== 1 + assert caplog.records[0].levelname == "WARNING" + assert caplog.records[0].message == f"statement {timestamp} has no 'id' field" + + +def test_backends_database_mongo_to_documents_method_when_statement_has_no_timestamp( + caplog, +): + """Tests the mongo backend to_documents method when a statement has no timestamp.""" + timestamp = {"timestamp": "2022-06-27T15:36:50"} + statements = [{"id": "foo", **timestamp}, {"id": "bar"}, {"id": "baz", **timestamp}] + + documents = MongoDataBackend.to_documents(statements, ignore_errors=False) + assert next(documents) == { + "_id": ObjectId("62b9ce922c26b46b68ffc68f"), + "_source": {"id": "foo", **timestamp}, + } + + with pytest.raises( + BadFormatException, match="statement {'id': 'bar'} has no 'timestamp' field" + ): + next(documents) + + documents = MongoDataBackend.to_documents(statements, ignore_errors=True) + assert next(documents) == { + "_id": ObjectId("62b9ce922c26b46b68ffc68f"), + "_source": {"id": "foo", **timestamp}, + } + assert next(documents) == { + "_id": ObjectId("62b9ce92baa5a0964d3320fb"), + "_source": {"id": "baz", **timestamp}, + } + assert len(caplog.records) == 1 + assert caplog.records[0].levelname == "WARNING" + assert caplog.records[0].message == ( + "statement {'id': 'bar'} has no 'timestamp' field" + ) + + +def test_backends_database_mongo_to_documents_method_with_invalid_timestamp(caplog): + """Tests the mongo backend to_documents method given a statement with an invalid + timestamp. + """ + valid_timestamp = {"timestamp": "2022-06-27T15:36:50"} + invalid_timestamp = {"timestamp": "This is not a valid timestamp!"} + invalid_statement = {"id": "bar", **invalid_timestamp} + statements = [ + {"id": "foo", **valid_timestamp}, + invalid_statement, + {"id": "baz", **valid_timestamp}, + ] + + documents = MongoDataBackend.to_documents(statements, ignore_errors=False) + assert next(documents) == { + "_id": ObjectId("62b9ce922c26b46b68ffc68f"), + "_source": {"id": "foo", **valid_timestamp}, + } + + with pytest.raises( + BadFormatException, + match=f"statement {invalid_statement} has an invalid 'timestamp' field", + ): + next(documents) + + documents = MongoDataBackend.to_documents(statements, ignore_errors=True) + assert next(documents) == { + "_id": ObjectId("62b9ce922c26b46b68ffc68f"), + "_source": {"id": "foo", **valid_timestamp}, + } + assert next(documents) == { + "_id": ObjectId("62b9ce92baa5a0964d3320fb"), + "_source": {"id": "baz", **valid_timestamp}, + } + assert len(caplog.records) == 1 + assert caplog.records[0].levelname == "WARNING" + assert caplog.records[0].message == ( + f"statement {invalid_statement} has an invalid 'timestamp' field" + ) + + +def test_backends_database_mongo_bulk_import_method(mongo): + """Test the mongo backend bulk_import method.""" + # pylint: disable=unused-argument + + settings = MongoDataBackend.settings_class( + CONNECTION_URI=MONGO_TEST_CONNECTION_URI, + DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, + ) + backend = MongoDataBackend(settings) + timestamp = {"timestamp": "2022-06-27T15:36:50"} + statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}] + backend.bulk_import(MongoDataBackend.to_documents(statements)) + + results = backend.collection.find() + assert next(results) == { + "_id": ObjectId("62b9ce922c26b46b68ffc68f"), + "_source": {"id": "foo", **timestamp}, + } + assert next(results) == { + "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), + "_source": {"id": "bar", **timestamp}, + } + + +def 
test_backends_database_mongo_bulk_delete_method(mongo):
+    """Test the mongo backend bulk_delete method."""
+    # pylint: disable=unused-argument
+
+    settings = MongoDataBackend.settings_class(
+        CONNECTION_URI=MONGO_TEST_CONNECTION_URI,
+        DATABASE=MONGO_TEST_DATABASE,
+        DEFAULT_COLLECTION=MONGO_TEST_COLLECTION,
+    )
+    backend = MongoDataBackend(settings)
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}]
+    backend.bulk_import(MongoDataBackend.to_documents(statements))
+    documents = [st["id"] for st in statements]
+    backend.bulk_delete(batch=documents)
+
+    results = backend.collection.find()
+    assert next(results, None) is None
+
+
+def test_backends_database_mongo_bulk_update_method(mongo):
+    """Test the mongo backend bulk_update method."""
+    # pylint: disable=unused-argument
+
+    settings = MongoDataBackend.settings_class(
+        CONNECTION_URI=MONGO_TEST_CONNECTION_URI,
+        DATABASE=MONGO_TEST_DATABASE,
+        DEFAULT_COLLECTION=MONGO_TEST_COLLECTION,
+    )
+    backend = MongoDataBackend(settings)
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}]
+    backend.bulk_import(MongoDataBackend.to_documents(statements))
+    statements = [
+        {"id": "foo", "text": "foo", **timestamp},
+        {"id": "bar", "text": "bar", **timestamp},
+    ]
+    success = backend.write(data=statements, operation_type=BaseOperationType.UPDATE)
+    assert success == 2
+
+    results = backend.collection.find()
+    assert next(results) == {
+        "_id": ObjectId("62b9ce922c26b46b68ffc68f"),
+        "_source": {"id": "foo", "text": "foo", **timestamp},
+    }
+    assert next(results) == {
+        "_id": ObjectId("62b9ce92fcde2b2edba56bf4"),
+        "_source": {"id": "bar", "text": "bar", **timestamp},
+    }
+
+
+def test_backends_database_mongo_bulk_update_method_iterable(mongo):
+    """Test the mongo backend bulk_update method with an iterable."""
+    # pylint: disable=unused-argument
+
+    settings = MongoDataBackend.settings_class(
+        CONNECTION_URI=MONGO_TEST_CONNECTION_URI,
+        DATABASE=MONGO_TEST_DATABASE,
+        DEFAULT_COLLECTION=MONGO_TEST_COLLECTION,
+    )
+    backend = MongoDataBackend(settings)
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}]
+    backend.bulk_import(MongoDataBackend.to_documents(statements))
+    statements = [
+        {"id": "foo", "text": "foo", **timestamp},
+        {"id": "bar", "text": "bar", **timestamp},
+    ]
+    statements = iter(statements)
+    success = backend.write(data=statements, operation_type=BaseOperationType.UPDATE)
+    assert success == 2
+    results = backend.collection.find()
+    assert next(results) == {
+        "_id": ObjectId("62b9ce922c26b46b68ffc68f"),
+        "_source": {"id": "foo", "text": "foo", **timestamp},
+    }
+    assert next(results) == {
+        "_id": ObjectId("62b9ce92fcde2b2edba56bf4"),
+        "_source": {"id": "bar", "text": "bar", **timestamp},
+    }
+
+
+def test_backends_database_mongo_bulk_wrong_operation_type(mongo):
+    """Test the mongo backend write method with a wrong operation type."""
+    # pylint: disable=unused-argument
+
+    settings = MongoDataBackend.settings_class(
+        CONNECTION_URI=MONGO_TEST_CONNECTION_URI,
+        DATABASE=MONGO_TEST_DATABASE,
+        DEFAULT_COLLECTION=MONGO_TEST_COLLECTION,
+    )
+    backend = MongoDataBackend(settings)
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}]
+    backend.bulk_import(MongoDataBackend.to_documents(statements))
+    statements = [
+        {"id": "foo", "text": "foo", **timestamp},
+        {"id": "bar", "text": "bar",
+    ]
+
+    with pytest.raises(
+        BackendParameterException,
+        match=f"{BaseOperationType.APPEND.name} operation_type is not allowed.",
+    ):
+        backend.write(data=statements, operation_type=BaseOperationType.APPEND)
+
+
+def test_backends_database_mongo_bulk_no_data(mongo):
+    """Test the mongo backend write method given an empty data iterable."""
+    # pylint: disable=unused-argument
+
+    settings = MongoDataBackend.settings_class(
+        CONNECTION_URI=MONGO_TEST_CONNECTION_URI,
+        DATABASE=MONGO_TEST_DATABASE,
+        DEFAULT_COLLECTION=MONGO_TEST_COLLECTION,
+    )
+    backend = MongoDataBackend(settings)
+
+    success = backend.write(data=[], operation_type=BaseOperationType.CREATE)
+
+    assert success == 0
+
+
+def test_backends_database_mongo_bulk_import_method_with_duplicated_key(mongo):
+    """Test the mongo backend bulk_import method with a duplicated key conflict."""
+    # pylint: disable=unused-argument
+
+    settings = MongoDataBackend.settings_class(
+        CONNECTION_URI=MONGO_TEST_CONNECTION_URI,
+        DATABASE=MONGO_TEST_DATABASE,
+        DEFAULT_COLLECTION=MONGO_TEST_COLLECTION,
+    )
+    backend = MongoDataBackend(settings)
+
+    # Identical statement ID produces the same ObjectId, leading to a
+    # duplicated key write error while trying to bulk import this batch
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    statements = [
+        {"id": "foo", **timestamp},
+        {"id": "bar", **timestamp},
+        {"id": "bar", **timestamp},
+    ]
+    documents = list(MongoDataBackend.to_documents(statements))
+    with pytest.raises(BackendException, match="E11000 duplicate key error collection"):
+        backend.bulk_import(documents)
+
+    success = backend.bulk_import(documents, ignore_errors=True)
+    assert success == 0
+
+
+def test_backends_database_mongo_bulk_import_method_import_partial_chunks_on_error(
+    mongo,
+):
+    """Test that the mongo backend bulk_import method imports partial chunks while
+    raising a BulkWriteError and ignoring errors.
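+
+    Assuming an ordered bulk insert, the documents preceding the first
+    duplicated key in the batch are still inserted, hence the expected
+    count of 3 below.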
+ """ + # pylint: disable=unused-argument + + settings = MongoDataBackend.settings_class( + CONNECTION_URI=MONGO_TEST_CONNECTION_URI, + DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, + ) + backend = MongoDataBackend(settings) + + # Identical statement ID produces the same ObjectId, leading to a + # duplicated key write error while trying to bulk import this batch + timestamp = {"timestamp": "2022-06-27T15:36:50"} + statements = [ + {"id": "foo", **timestamp}, + {"id": "bar", **timestamp}, + {"id": "baz", **timestamp}, + {"id": "bar", **timestamp}, + {"id": "lol", **timestamp}, + ] + documents = list(MongoDataBackend.to_documents(statements)) + assert backend.bulk_import(documents, ignore_errors=True) == 3 + + +def test_backends_database_mongo_put_method(mongo): + """Test the mongo backend put method.""" + database = getattr(mongo, MONGO_TEST_DATABASE) + collection = getattr(database, MONGO_TEST_COLLECTION) + assert collection.estimated_document_count() == 0 + + timestamp = {"timestamp": "2022-06-27T15:36:50"} + statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}] + settings = MongoDataBackend.settings_class( + CONNECTION_URI=MONGO_TEST_CONNECTION_URI, + DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, + ) + backend = MongoDataBackend(settings) + + success = backend.write(statements) + assert success == 2 + assert collection.estimated_document_count() == 2 + + results = collection.find() + assert next(results) == { + "_id": ObjectId("62b9ce922c26b46b68ffc68f"), + "_source": {"id": "foo", **timestamp}, + } + assert next(results) == { + "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), + "_source": {"id": "bar", **timestamp}, + } + + +def test_backends_database_mongo_put_method_bytes(mongo): + """Test the mongo backend put method with bytes.""" + database = getattr(mongo, MONGO_TEST_DATABASE) + collection = getattr(database, MONGO_TEST_COLLECTION) + assert collection.estimated_document_count() == 0 + + timestamp = {"timestamp": "2022-06-27T15:36:50"} + statements = [ + {"id": "foo", "text": "foo", **timestamp}, + {"id": "bar", "text": "bar", **timestamp}, + ] + settings = MongoDataBackend.settings_class( + CONNECTION_URI=MONGO_TEST_CONNECTION_URI, + DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, + ) + backend = MongoDataBackend(settings) + byte_data = [] + for item in statements: + json_str = json.dumps(item, separators=(",", ":"), ensure_ascii=False) + byte_data.append(json_str.encode("utf-8")) + success = backend.write(byte_data) + assert success == 2 + assert collection.estimated_document_count() == 2 + + results = collection.find() + assert next(results) == { + "_id": ObjectId("62b9ce922c26b46b68ffc68f"), + "_source": {"id": "foo", "text": "foo", **timestamp}, + } + assert next(results) == { + "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), + "_source": {"id": "bar", "text": "bar", **timestamp}, + } + + +def test_backends_database_mongo_put_method_bytes_failed(mongo): + """Test the mongo backend put method with bytes.""" + database = getattr(mongo, MONGO_TEST_DATABASE) + collection = getattr(database, MONGO_TEST_COLLECTION) + assert collection.estimated_document_count() == 0 + + settings = MongoDataBackend.settings_class( + CONNECTION_URI=MONGO_TEST_CONNECTION_URI, + DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, + ) + backend = MongoDataBackend(settings) + byte_data = [] + json_str = "failed_json_str" + byte_data.append(json_str.encode("utf-8")) + + with 
pytest.raises(json.JSONDecodeError): + success = backend.write(byte_data) + assert collection.estimated_document_count() == 0 + + success = backend.write(byte_data, ignore_errors=True) + assert success == 0 + assert collection.estimated_document_count() == 0 + + +def test_backends_database_mongo_put_method_with_target(mongo): + """Test the mongo backend put method.""" + database = getattr(mongo, MONGO_TEST_DATABASE) + collection = getattr(database, MONGO_TEST_COLLECTION) + assert collection.estimated_document_count() == 0 + + timestamp = {"timestamp": "2022-06-27T15:36:50"} + statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}] + settings = MongoDataBackend.settings_class( + CONNECTION_URI=MONGO_TEST_CONNECTION_URI, + DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, + ) + backend = MongoDataBackend(settings) + + success = backend.write(statements, target=MONGO_TEST_COLLECTION) + assert success == 2 + assert collection.estimated_document_count() == 2 + + results = collection.find() + assert next(results) == { + "_id": ObjectId("62b9ce922c26b46b68ffc68f"), + "_source": {"id": "foo", **timestamp}, + } + assert next(results) == { + "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), + "_source": {"id": "bar", **timestamp}, + } + + +def test_backends_database_mongo_put_method_with_no_ids(mongo): + """Test the mongo backend put method with no IDs.""" + database = getattr(mongo, MONGO_TEST_DATABASE) + collection = getattr(database, MONGO_TEST_COLLECTION) + assert collection.estimated_document_count() == 0 + + timestamp = {"timestamp": "2022-06-27T15:36:50"} + statements = [{**timestamp}, {**timestamp}] + settings = MongoDataBackend.settings_class( + CONNECTION_URI=MONGO_TEST_CONNECTION_URI, + DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, + ) + backend = MongoDataBackend(settings) + + success = backend.write(statements, operation_type=BaseOperationType.INDEX) + assert success == 2 + assert collection.estimated_document_count() == 2 + + +def test_backends_database_mongo_put_method_with_custom_chunk_size(mongo): + """Test the mongo backend put method with a custom chunk_size.""" + database = getattr(mongo, MONGO_TEST_DATABASE) + collection = getattr(database, MONGO_TEST_COLLECTION) + assert collection.estimated_document_count() == 0 + + timestamp = {"timestamp": "2022-06-27T15:36:50"} + statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}] + + settings = MongoDataBackend.settings_class( + CONNECTION_URI=MONGO_TEST_CONNECTION_URI, + DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, + ) + backend = MongoDataBackend(settings) + success = backend.write(statements, chunk_size=2) + assert success == 2 + assert collection.estimated_document_count() == 2 + + results = collection.find() + assert next(results) == { + "_id": ObjectId("62b9ce922c26b46b68ffc68f"), + "_source": {"id": "foo", **timestamp}, + } + assert next(results) == { + "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), + "_source": {"id": "bar", **timestamp}, + } + + +def test_backends_database_mongo_put_method_with_duplicated_key(mongo): + """Test the mongo backend put method with a duplicated key conflict.""" + # pylint: disable=unused-argument + + settings = MongoDataBackend.settings_class( + CONNECTION_URI=MONGO_TEST_CONNECTION_URI, + DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, + ) + backend = MongoDataBackend(settings) + + # Identical statement ID produces the same ObjectId, leading to a + # duplicated key write 
error while trying to bulk import this batch
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    statements = [
+        {"id": "foo", **timestamp},
+        {"id": "bar", **timestamp},
+        {"id": "bar", **timestamp},
+    ]
+    with pytest.raises(BackendException, match="E11000 duplicate key error collection"):
+        backend.write(statements)
+
+    success = backend.write(statements, ignore_errors=True)
+    assert success == 0
+
+
+def test_backends_data_mongo_data_backend_write_method_with_update_operation(
+    mongo,
+):
+    """Test the mongo backend write method with an update operation."""
+    database = getattr(mongo, MONGO_TEST_DATABASE)
+    collection = getattr(database, MONGO_TEST_COLLECTION)
+    assert collection.estimated_document_count() == 0
+
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}]
+    settings = MongoDataBackend.settings_class(
+        CONNECTION_URI=MONGO_TEST_CONNECTION_URI,
+        DATABASE=MONGO_TEST_DATABASE,
+        DEFAULT_COLLECTION=MONGO_TEST_COLLECTION,
+    )
+    backend = MongoDataBackend(settings)
+
+    success = backend.write(statements)
+    assert success == 2
+    assert collection.estimated_document_count() == 2
+
+    results = collection.find()
+    assert next(results) == {
+        "_id": ObjectId("62b9ce922c26b46b68ffc68f"),
+        "_source": {"id": "foo", **timestamp},
+    }
+    assert next(results) == {
+        "_id": ObjectId("62b9ce92fcde2b2edba56bf4"),
+        "_source": {"id": "bar", **timestamp},
+    }
+
+    timestamp = {"timestamp": "2022-06-27T16:36:50"}
+    statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}]
+    success = backend.write(
+        statements, chunk_size=2, operation_type=BaseOperationType.UPDATE
+    )
+    assert success == 2
+    assert collection.estimated_document_count() == 2
+
+    results = collection.find()
+    assert next(results) == {
+        "_id": ObjectId("62b9ce922c26b46b68ffc68f"),
+        "_source": {"id": "foo", **timestamp},
+    }
+    assert next(results) == {
+        "_id": ObjectId("62b9ce92fcde2b2edba56bf4"),
+        "_source": {"id": "bar", **timestamp},
+    }
+
+
+def test_backends_data_mongo_data_backend_write_method_with_delete_operation(
+    mongo,
+):
+    """Test the mongo backend write method with a delete operation."""
+    database = getattr(mongo, MONGO_TEST_DATABASE)
+    collection = getattr(database, MONGO_TEST_COLLECTION)
+    assert collection.estimated_document_count() == 0
+
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    statements = [
+        {"id": "foo", **timestamp},
+        {"id": "bar", **timestamp},
+        {"id": "baz", **timestamp},
+    ]
+    settings = MongoDataBackend.settings_class(
+        CONNECTION_URI=MONGO_TEST_CONNECTION_URI,
+        DATABASE=MONGO_TEST_DATABASE,
+        DEFAULT_COLLECTION=MONGO_TEST_COLLECTION,
+    )
+    backend = MongoDataBackend(settings)
+
+    success = backend.write(statements, chunk_size=2)
+    assert success == 3
+    assert collection.estimated_document_count() == 3
+
+    results = collection.find()
+    assert next(results) == {
+        "_id": ObjectId("62b9ce922c26b46b68ffc68f"),
+        "_source": {"id": "foo", **timestamp},
+    }
+    assert next(results) == {
+        "_id": ObjectId("62b9ce92fcde2b2edba56bf4"),
+        "_source": {"id": "bar", **timestamp},
+    }
+
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    statements = [
+        {"id": "foo", **timestamp},
+        {"id": "bar", **timestamp},
+        {"id": "baz", **timestamp},
+    ]
+    success = backend.write(
+        statements, chunk_size=2, operation_type=BaseOperationType.DELETE
+    )
+    assert success == 3
+
+    assert not list(backend.read())
+
+    assert collection.estimated_document_count() == 0
+
+
+def test_backends_database_mongo_query_statements(monkeypatch, caplog, mongo):
+    """Tests the mongo backend query_statements method, given a valid search
+    query, should return the expected statements.
+    """
+    # pylint: disable=unused-argument,use-implicit-booleaness-not-comparison
+
+    # Instantiate Mongo Databases
+
+    settings = MongoDataBackend.settings_class(
+        CONNECTION_URI=MONGO_TEST_CONNECTION_URI,
+        DATABASE=MONGO_TEST_DATABASE,
+        DEFAULT_COLLECTION=MONGO_TEST_COLLECTION,
+    )
+    backend = MongoLRSBackend(settings)
+
+    # Insert documents
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    meta = {
+        "actor": {"account": {"name": "test_name"}},
+        "verb": {"id": "verb_id"},
+        "object": {"id": "http://example.com", "objectType": "Activity"},
+    }
+    collection_document = list(
+        MongoDataBackend.to_documents(
+            [
+                {"id": "62b9ce922c26b46b68ffc68f", **timestamp, **meta},
+                {"id": "62b9ce92fcde2b2edba56bf4", **timestamp, **meta},
+            ]
+        )
+    )
+    backend.bulk_import(collection_document)
+
+    statement_parameters = StatementParameters()
+    statement_parameters.activity = "http://example.com"
+    statement_parameters.registration = ObjectId("62b9ce922c26b46b68ffc68f")
+    statement_parameters.since = "2020-01-01T00:00:00.000000+00:00"
+    statement_parameters.until = "2022-12-01T15:36:50"
+    statement_parameters.search_after = ObjectId("62b9ce922c26b46b68ffc68f")
+    statement_parameters.limit = 25
+    statement_parameters.ascending = True
+    statement_parameters.related_activities = True
+    statement_parameters.related_agents = True
+    statement_parameters.format = "ids"
+    statement_parameters.agent = "test_name"
+    statement_parameters.verb = "verb_id"
+    statement_parameters.attachments = False
+    statement_parameters.statementId = "62b9ce922c26b46b68ffc68f"
+    statement_query_result = backend.query_statements(statement_parameters)
+
+    assert len(statement_query_result.statements) > 0
+
+
+def test_backends_database_mongo_query_statements_with_search_query_failure(
+    monkeypatch, caplog, mongo
+):
+    """Tests the mongo backend query_statements method, given a search query failure,
+    should raise a BackendException and log the error.
+    """
+    # pylint: disable=unused-argument
+
+    def mock_find(**_):
+        """Mocks the MongoClient.collection.find method."""
+        raise PyMongoError("Something is wrong")
+
+    settings = MongoDataBackend.settings_class(
+        CONNECTION_URI=MONGO_TEST_CONNECTION_URI,
+        DATABASE=MONGO_TEST_DATABASE,
+        DEFAULT_COLLECTION=MONGO_TEST_COLLECTION,
+    )
+    backend = MongoLRSBackend(settings)
+    monkeypatch.setattr(backend.collection, "find", mock_find)
+
+    caplog.set_level(logging.ERROR)
+
+    msg = "'Failed to execute MongoDB query', 'Something is wrong'"
+    with pytest.raises(BackendException, match=msg):
+        backend.query_statements(StatementParameters())
+
+    logger_name = "ralph.backends.data.mongo"
+    msg = "Failed to execute MongoDB query. Something is wrong"
+    assert caplog.record_tuples == [(logger_name, logging.ERROR, msg)]
+
+
+def test_backends_database_mongo_query_statements_by_ids_with_search_query_failure(
+    monkeypatch, caplog, mongo
+):
+    """Tests the mongo backend query_statements_by_ids method, given a search query
+    failure, should raise a BackendException and log the error.
+ """ + # pylint: disable=unused-argument + + def mock_find(**_): + """Mocks the MongoClient.collection.find method.""" + raise ValueError("Something is wrong") + + settings = MongoDataBackend.settings_class( + CONNECTION_URI=MONGO_TEST_CONNECTION_URI, + DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, + ) + backend = MongoLRSBackend(settings) + monkeypatch.setattr(backend.collection, "find", mock_find) + caplog.set_level(logging.ERROR) + + msg = "'Failed to execute MongoDB query', 'Something is wrong'" + with pytest.raises(BackendException, match=msg): + backend.query_statements_by_ids(StatementParameters()) + + logger_name = "ralph.backends.data.mongo" + msg = "Failed to execute MongoDB query. Something is wrong" + assert caplog.record_tuples == [(logger_name, logging.ERROR, msg)] + + +def test_backends_database_mongo_query_statements_by_ids_with_multiple_collections( + mongo, mongo_forwarding +): + """Tests the mongo backend query_statements_by_ids method, given a valid search + query, should execute the query uniquely on the specified collection and return the + expected results. + """ + # pylint: disable=unused-argument,use-implicit-booleaness-not-comparison + + # Instantiate Mongo Databases + + settings_1 = MongoDataBackend.settings_class( + CONNECTION_URI=MONGO_TEST_CONNECTION_URI, + DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, + ) + backend_1 = MongoLRSBackend(settings_1) + + settings_2 = MongoDataBackend.settings_class( + CONNECTION_URI=MONGO_TEST_CONNECTION_URI, + DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=MONGO_TEST_FORWARDING_COLLECTION, + ) + backend_2 = MongoLRSBackend(settings_2) + + # Insert documents + timestamp = {"timestamp": "2022-06-27T15:36:50"} + collection_1_document = list( + MongoDataBackend.to_documents([{"id": "1", **timestamp}]) + ) + collection_2_document = list( + MongoDataBackend.to_documents([{"id": "2", **timestamp}]) + ) + backend_1.bulk_import(collection_1_document) + backend_2.bulk_import(collection_2_document) + + # Check the expected search query results + assert backend_1.query_statements_by_ids(["1"]) == collection_1_document + assert backend_1.query_statements_by_ids(["2"]) == [] + assert backend_2.query_statements_by_ids(["1"]) == [] + assert backend_2.query_statements_by_ids(["2"]) == collection_2_document + + +def test_backends_database_mongo_status(mongo): + """Test the Mongo status method. + + As pymongo is monkeypatching the MongoDB client to add admin object, it's + barely untestable. 😢 + """ + # pylint: disable=unused-argument + + settings = MongoDataBackend.settings_class( + CONNECTION_URI=MONGO_TEST_CONNECTION_URI, + DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, + ) + backend = MongoDataBackend(settings) + assert backend.status() == DataBackendStatus.OK + + +def test_backends_database_mongo_status_connection_failed(mongo): + """Test the Mongo status method. + + As pymongo is monkeypatching the MongoDB client to add admin object, it's + barely untestable. 
😢
+    """
+    # pylint: disable=unused-argument
+
+    settings = MongoDataBackend.settings_class(
+        CONNECTION_URI="mongodb://localhost:27018",
+        DATABASE=MONGO_TEST_DATABASE,
+        DEFAULT_COLLECTION=MONGO_TEST_COLLECTION,
+    )
+    backend = MongoDataBackend(settings)
+    assert backend.status() == DataBackendStatus.AWAY

From 23e6aa3b0df126025e396d22a08a05c56ec40fb0 Mon Sep 17 00:00:00 2001
From: Wilfried BARADAT
Date: Mon, 15 May 2023 17:24:33 +0200
Subject: [PATCH 08/65] =?UTF-8?q?=F0=9F=8F=97=EF=B8=8F(backends)=20add=20S?=
 =?UTF-8?q?wiftDataBackend?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

We add the OpenStack Swift data backend implementation.

With the `data` parameter changed to an Iterable, we can no longer use
the high-level SwiftService API to upload files (it needs a file object
source, not an iterable). We switch to the lower-level Connection API,
which is more flexible.
---
 src/ralph/backends/data/swift.py  | 384 ++++++++++++++++++
 tests/backends/data/test_swift.py | 640 ++++++++++++++++++++++++++++++
 tests/conftest.py                 |   1 +
 tests/fixtures/backends.py        |  26 ++
 4 files changed, 1051 insertions(+)
 create mode 100644 src/ralph/backends/data/swift.py
 create mode 100644 tests/backends/data/test_swift.py

diff --git a/src/ralph/backends/data/swift.py b/src/ralph/backends/data/swift.py
new file mode 100644
index 000000000..c3163603a
--- /dev/null
+++ b/src/ralph/backends/data/swift.py
@@ -0,0 +1,384 @@
+"""OVH Swift data backend for Ralph."""
+
+import json
+import logging
+from functools import cached_property
+from io import IOBase
+from typing import Iterable, Iterator, Union
+from uuid import uuid4
+
+from swiftclient.service import ClientException, Connection
+
+from ralph.backends.data.base import (
+    BaseDataBackend,
+    BaseDataBackendSettings,
+    BaseOperationType,
+    BaseQuery,
+    DataBackendStatus,
+    enforce_query_checks,
+)
+from ralph.backends.mixins import HistoryMixin
+from ralph.conf import BaseSettingsConfig
+from ralph.exceptions import BackendException, BackendParameterException
+from ralph.utils import now
+
+logger = logging.getLogger(__name__)
+
+
+class SwiftDataBackendSettings(BaseDataBackendSettings):
+    """Represent the Swift data backend default configuration.
+
+    Attributes:
+        AUTH_URL (str): The authentication URL.
+        USERNAME (str): The name of the openstack swift user.
+        PASSWORD (str): The password of the openstack swift user.
+        IDENTITY_API_VERSION (str): The keystone API version to authenticate to.
+        TENANT_ID (str): The identifier of the tenant of the container.
+        TENANT_NAME (str): The name of the tenant of the container.
+        PROJECT_DOMAIN_NAME (str): The project domain name.
+        REGION_NAME (str): The region where the container is.
+        OBJECT_STORAGE_URL (str): The default storage URL.
+        USER_DOMAIN_NAME (str): The user domain name.
+        DEFAULT_CONTAINER (str): The default target container.
+        LOCALE_ENCODING (str): The encoding used for reading/writing documents.
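+
+    A hypothetical environment override using the backend's
+    `RALPH_BACKENDS__DATA__SWIFT__` prefix (values are illustrative only,
+    not defaults):
+
+        RALPH_BACKENDS__DATA__SWIFT__USERNAME=johndoe
+        RALPH_BACKENDS__DATA__SWIFT__DEFAULT_CONTAINER=statements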
+ """ + + class Config(BaseSettingsConfig): + """Pydantic Configuration.""" + + env_prefix = "RALPH_BACKENDS__DATA__SWIFT__" + + AUTH_URL: str = "https://auth.cloud.ovh.net/" + USERNAME: str = None + PASSWORD: str = None + IDENTITY_API_VERSION: str = "3" + TENANT_ID: str = None + TENANT_NAME: str = None + PROJECT_DOMAIN_NAME: str = "Default" + REGION_NAME: str = None + OBJECT_STORAGE_URL: str = None + USER_DOMAIN_NAME: str = "Default" + DEFAULT_CONTAINER: str = None + LOCALE_ENCODING: str = "utf8" + + +class SwiftDataBackend(HistoryMixin, BaseDataBackend): + """SWIFT data backend.""" + + # pylint: disable=too-many-instance-attributes + + name = "swift" + default_operation_type = BaseOperationType.CREATE + settings_class = SwiftDataBackendSettings + + def __init__(self, settings: settings_class = None): + """Prepares the options for the SwiftService.""" + self.settings = settings if settings else self.settings_class() + + self.default_container = self.settings.DEFAULT_CONTAINER + self.locale_encoding = self.settings.LOCALE_ENCODING + self._connection = None + + @cached_property + def options(self) -> dict: + """Return the required options for the Swift Connection.""" + return { + "tenant_id": self.settings.TENANT_ID, + "tenant_name": self.settings.TENANT_NAME, + "project_domain_name": self.settings.PROJECT_DOMAIN_NAME, + "region_name": self.settings.REGION_NAME, + "object_storage_url": self.settings.OBJECT_STORAGE_URL, + "user_domain_name": self.settings.USER_DOMAIN_NAME, + } + + @property + def connection(self): + """Create a Swift Connection if it doesn't exist.""" + if not self._connection: + self._connection = Connection( + authurl=self.settings.AUTH_URL, + user=self.settings.USERNAME, + key=self.settings.PASSWORD, + os_options=self.options, + auth_version=self.settings.IDENTITY_API_VERSION, + ) + return self._connection + + def status(self) -> DataBackendStatus: + """Implement data backend checks (e.g. connection, cluster status). + + Returns: + DataBackendStatus: The status of the data backend. + """ + try: + self.connection.head_account() + except ClientException as err: + msg = "Unable to connect to the Swift account: %s" + logger.error(msg, err.msg) + return DataBackendStatus.ERROR + + return DataBackendStatus.OK + + def list( + self, target: str = None, details: bool = False, new: bool = False + ) -> Iterator[Union[str, dict]]: + """List files for the target container. + + Args: + target (str or None): The target container to list from. + If `target` is `None`, the `default_container` will be used. + details (bool): Get detailed object information instead of just names. + new (bool): Given the history, list only not already read objects. + + Yields: + str: The next object path. (If details is False) + dict: The next object details. (If `details` is True.) + + Raises: + BackendException: If a failure occurs. 
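+
+        Example (a sketch, assuming a configured backend):
+            >>> backend = SwiftDataBackend()  # doctest: +SKIP
+            >>> names = list(backend.list())  # doctest: +SKIP
+            >>> details = list(backend.list(details=True, new=True))  # doctest: +SKIP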
+        """
+        if target is None:
+            target = self.default_container
+
+        archives_to_skip = set()
+        if new:
+            archives_to_skip = set(self.get_command_history(self.name, "read"))
+
+        try:
+            _, objects = self.connection.get_container(
+                container=target, full_listing=True
+            )
+        except ClientException as err:
+            msg = "Failed to list container %s: %s"
+            logger.error(msg, target, err.msg)
+            raise BackendException(msg % (target, err.msg)) from err
+
+        for obj in objects:
+            if new and obj in archives_to_skip:
+                continue
+            yield self._details(target, obj) if details else obj
+
+    @enforce_query_checks
+    def read(
+        self,
+        *,
+        query: Union[str, BaseQuery] = None,
+        target: str = None,
+        chunk_size: Union[None, int] = 500,
+        raw_output: bool = False,
+        ignore_errors: bool = False,
+    ) -> Iterator[Union[bytes, dict]]:
+        """Read objects matching the `query` in the `target` container and yield them.
+
+        Args:
+            query (str or BaseQuery): The query to select objects to read.
+            target (str or None): The target container name.
+                If `target` is `None`, a default value is used instead.
+            chunk_size (int or None): The number of records or bytes to read in one
+                batch, depending on whether the records are dictionaries or bytes.
+            raw_output (bool): Controls whether to yield bytes or dictionaries.
+                If the objects are dictionaries and `raw_output` is set to `True`, they
+                are encoded as JSON.
+                If the objects are bytes and `raw_output` is set to `False`, they are
+                decoded as JSON by line.
+            ignore_errors (bool): If `True`, errors during the read operation
+                are ignored and logged. If `False` (default), a `BackendException`
+                is raised if an error occurs.
+
+        Yields:
+            dict: If `raw_output` is False.
+            bytes: If `raw_output` is True.
+
+        Raises:
+            BackendException: If a failure during the read operation occurs and
+                `ignore_errors` is set to `False`.
+            BackendParameterException: If a backend argument value is not valid.
+        """
+        if query.query_string is None:
+            msg = "Invalid query. The query should be a valid archive name."
+            logger.error(msg)
+            if not ignore_errors:
+                raise BackendParameterException(msg)
+            # Nothing can be read from an invalid query when the error is ignored.
+            return
+
+        target = target if target else self.default_container
+
+        logger.info(
+            "Getting object from container: %s (query_string: %s)",
+            target,
+            query.query_string,
+        )
+
+        try:
+            resp_headers, content = self.connection.get_object(
+                container=target,
+                obj=query.query_string,
+                resp_chunk_size=chunk_size,
+            )
+        except ClientException as err:
+            msg = "Failed to read %s: %s"
+            error = err.msg
+            logger.error(msg, query.query_string, error)
+            if not ignore_errors:
+                raise BackendException(msg % (query.query_string, error)) from err
+            # Nothing was fetched; stop here when the error is ignored.
+            return
+
+        reader = self._read_raw if raw_output else self._read_dict
+
+        for chunk in reader(content, chunk_size, ignore_errors):
+            yield chunk
+
+        # Archive read, add a new entry to the history
+        self.append_to_history(
+            {
+                "backend": self.name,
+                "action": "read",
+                "id": f"{target}/{query.query_string}",
+                "size": resp_headers["Content-Length"],
+                "timestamp": now(),
+            }
+        )
+
+    def write(  # pylint: disable=too-many-arguments, disable=too-many-branches
+        self,
+        data: Union[IOBase, Iterable[bytes], Iterable[dict]],
+        target: Union[None, str] = None,
+        chunk_size: Union[None, int] = None,
+        ignore_errors: bool = False,
+        operation_type: Union[None, BaseOperationType] = None,
+    ) -> int:
+        """Write `data` records to the `target` container and return their count.
+
+        Args:
+            data (Iterable or IOBase): The data to write.
+            target (str or None): The target container name.
If `target` is `None`, a default value is used instead.
+            chunk_size (int or None): Ignored.
+            ignore_errors (bool): If `True`, errors during the write operation
+                are ignored and logged. If `False` (default), a `BackendException`
+                is raised if an error occurs.
+            operation_type (BaseOperationType or None): The mode of the write operation.
+                If `operation_type` is `None`, the `default_operation_type` is used
+                instead. See `BaseOperationType`.
+
+        Returns:
+            int: The number of written records.
+
+        Raises:
+            BackendException: If a failure during the write operation occurs and
+                `ignore_errors` is set to `False`.
+            BackendParameterException: If a backend argument value is not valid.
+        """
+        data = iter(data)
+        try:
+            first_record = next(data)
+        except StopIteration:
+            logger.info("Data Iterator is empty; skipping write to target.")
+            return 0
+        # `next` consumed the first record from the iterator; materialize the
+        # records to put it back, as `data` is iterated again below.
+        data = [first_record, *data]
+        if not operation_type:
+            operation_type = self.default_operation_type
+
+        if not target:
+            target = f"{self.default_container}/{now()}-{uuid4()}"
+            logger.info(
+                (
+                    "Target not specified; using default container "
+                    "with random object name: %s"
+                ),
+                target,
+            )
+        elif "/" not in target:
+            target = f"{self.default_container}/{target}"
+            logger.info(
+                "Container not specified; using default container: %s",
+                self.default_container,
+            )
+
+        target_container, target_object = target.split("/", 1)
+
+        if operation_type in [
+            BaseOperationType.APPEND,
+            BaseOperationType.DELETE,
+            BaseOperationType.UPDATE,
+        ]:
+            msg = "%s operation_type is not allowed."
+            logger.error(msg, operation_type.name)
+            if not ignore_errors:
+                raise BackendParameterException(msg % operation_type.name)
+
+        if operation_type in [BaseOperationType.CREATE, BaseOperationType.INDEX]:
+            if target_object in list(self.list(target=target_container)):
+                msg = "%s already exists and overwrite is not allowed for operation %s"
+                logger.error(msg, target_object, operation_type)
+                if not ignore_errors:
+                    raise BackendException(msg % (target_object, operation_type))
+
+        if isinstance(first_record, dict):
+            data = [
+                json.dumps(statement).encode(self.locale_encoding)
+                for statement in data
+            ]
+
+        try:
+            self.connection.put_object(
+                container=target_container, obj=target_object, contents=data
+            )
+            resp = self.connection.head_object(
+                container=target_container, obj=target_object
+            )
+        except ClientException as err:
+            msg = "Failed to write to object %s: %s"
+            error = err.msg
+            logger.error(msg, target_object, error)
+            if not ignore_errors:
+                raise BackendException(msg % (target_object, error)) from err
+            # The write failed; no records were persisted.
+            return 0
+
+        count = len(data)
+        logger.info("Successfully written %s statements to %s", count, target)
+
+        # Archive written, add a new entry to the history
+        self.append_to_history(
+            {
+                "backend": self.name,
+                "action": "write",
+                "operation_type": operation_type.value,
+                "id": target,
+                "size": resp["Content-Length"],
+                "timestamp": now(),
+            }
+        )
+        return count
+
+    def _details(self, container: str, name: str):
+        """Return `name` object details from `container`."""
+        try:
+            resp = self.connection.head_object(container=container, obj=name)
+        except ClientException as err:
+            msg = "Unable to retrieve details for object %s: %s"
+            logger.error(msg, name, err.msg)
+            raise BackendException(msg % (name, err.msg)) from err
+
+        return {
+            "name": name,
+            "lastModified": resp["Last-Modified"],
+            "size": resp["Content-Length"],
+        }
+
+    @staticmethod
+    def _read_dict(
+        obj: Iterable, _chunk_size: int, ignore_errors: bool
+    ) -> Iterator[dict]:
+        """Read the `object` by line and yield JSON parsed 
dictionaries.""" + for i, line in enumerate(obj): + try: + yield json.loads(line) + except (TypeError, json.JSONDecodeError) as err: + msg = "Raised error: %s, at line %s" + logger.error(msg, err, i) + if not ignore_errors: + raise BackendException(msg % (err, i)) from err + + @staticmethod + def _read_raw( + obj: Iterable, chunk_size: int, _ignore_errors: bool + ) -> Iterator[bytes]: + """Read the `object` by line and yield bytes.""" + while chunk := obj.read(chunk_size): + yield chunk diff --git a/tests/backends/data/test_swift.py b/tests/backends/data/test_swift.py new file mode 100644 index 000000000..117363a72 --- /dev/null +++ b/tests/backends/data/test_swift.py @@ -0,0 +1,640 @@ +"""Tests for Ralph swift data backend.""" + +import json +import logging +from io import BytesIO +from operator import itemgetter +from typing import Iterable +from uuid import uuid4 + +import pytest +from swiftclient.service import ClientException + +from ralph.backends.data.base import BaseOperationType, BaseQuery, DataBackendStatus +from ralph.backends.data.swift import SwiftDataBackend, SwiftDataBackendSettings +from ralph.conf import settings +from ralph.exceptions import BackendException, BackendParameterException +from ralph.utils import now + + +def test_backends_data_swift_data_backend_default_instantiation(monkeypatch, fs): + """Test the `SwiftDataBackend` default instantiation.""" + # pylint: disable=invalid-name + fs.create_file(".env") + backend_settings_names = [ + "AUTH_URL", + "USERNAME", + "PASSWORD", + "IDENTITY_API_VERSION", + "TENANT_ID", + "TENANT_NAME", + "PROJECT_DOMAIN_NAME", + "REGION_NAME", + "OBJECT_STORAGE_URL", + "USER_DOMAIN_NAME", + "DEFAULT_CONTAINER", + "LOCALE_ENCODING", + ] + for name in backend_settings_names: + monkeypatch.delenv(f"RALPH_BACKENDS__DATA__SWIFT__{name}", raising=False) + + assert SwiftDataBackend.name == "swift" + assert SwiftDataBackend.query_model == BaseQuery + assert SwiftDataBackend.default_operation_type == BaseOperationType.CREATE + assert SwiftDataBackend.settings_class == SwiftDataBackendSettings + backend = SwiftDataBackend() + assert backend.options["tenant_id"] is None + assert backend.options["tenant_name"] is None + assert backend.options["project_domain_name"] == "Default" + assert backend.options["region_name"] is None + assert backend.options["object_storage_url"] is None + assert backend.options["user_domain_name"] == "Default" + assert backend.default_container is None + assert backend.locale_encoding == "utf8" + + +def test_backends_data_swift_data_backend_instantiation_with_settings(fs): + """Test the `SwiftDataBackend` instantiation with settings.""" + # pylint: disable=invalid-name + fs.create_file(".env") + settings_ = SwiftDataBackend.settings_class( + AUTH_URL="https://toto.net/", + USERNAME="username", + PASSWORD="password", + IDENTITY_API_VERSION="2", + TENANT_ID="tenant_id", + TENANT_NAME="tenant_name", + PROJECT_DOMAIN_NAME="project_domain_name", + REGION_NAME="region_name", + OBJECT_STORAGE_URL="object_storage_url", + USER_DOMAIN_NAME="user_domain_name", + DEFAULT_CONTAINER="default_container", + LOCALE_ENCODING="utf-16", + ) + backend = SwiftDataBackend(settings_) + assert backend.options["tenant_id"] == "tenant_id" + assert backend.options["tenant_name"] == "tenant_name" + assert backend.options["project_domain_name"] == "project_domain_name" + assert backend.options["region_name"] == "region_name" + assert backend.options["object_storage_url"] == "object_storage_url" + assert backend.options["user_domain_name"] == 
"user_domain_name" + assert backend.default_container == "default_container" + assert backend.locale_encoding == "utf-16" + + try: + SwiftDataBackend(settings_) + except Exception as err: # pylint:disable=broad-except + pytest.fail(f"SwiftDataBackend should not raise exceptions: {err}") + + +def test_backends_data_swift_data_backend_status_method_with_error_status( + monkeypatch, swift_backend, caplog +): + """Test the `SwiftDataBackend.status` method, given a failed connection, + should return `DataBackendStatus.ERROR`.""" + error = ( + "Unauthorized. Check username/id, password, tenant name/id and" + " user/tenant domain name/id." + ) + + def mock_failed_head_account(*args, **kwargs): + # pylint:disable=unused-argument + raise ClientException(error) + + swift = swift_backend() + monkeypatch.setattr(swift.connection, "head_account", mock_failed_head_account) + + with caplog.at_level(logging.ERROR): + assert swift.status() == DataBackendStatus.ERROR + + assert ( + "ralph.backends.data.swift", + logging.ERROR, + f"Unable to connect to the Swift account: {error}", + ) in caplog.record_tuples + + +def test_backends_data_swift_data_backend_status_method_with_ok_status( + monkeypatch, swift_backend, caplog +): + """Test the `SwiftDataBackend.status` method, given a directory with wrong + permissions, should return `DataBackendStatus.OK`. + """ + + def mock_successful_head_account(*args, **kwargs): # pylint:disable=unused-argument + return 1 + + swift = swift_backend() + monkeypatch.setattr(swift.connection, "head_account", mock_successful_head_account) + + with caplog.at_level(logging.ERROR): + assert swift.status() == DataBackendStatus.OK + + assert caplog.record_tuples == [] + + +def test_backends_data_swift_data_backend_list_method( + swift_backend, monkeypatch, fs, settings_fs +): # pylint:disable=invalid-name,unused-argument + """Test that the `SwiftDataBackend.list` method argument should list + the default container. + """ + frozen_now = now() + listing = [ + { + "name": "2020-04-29.gz", + "lastModified": frozen_now, + "size": 12, + }, + { + "name": "2020-04-30.gz", + "lastModified": frozen_now, + "size": 25, + }, + { + "name": "2020-05-01.gz", + "lastModified": frozen_now, + "size": 42, + }, + ] + history = [ + { + "backend": "swift", + "action": "read", + "id": "2020-04-29.gz", + }, + { + "backend": "swift", + "action": "read", + "id": "2020-04-30.gz", + }, + ] + + def mock_get_container(*args, **kwargs): # pylint:disable=unused-argument + return (None, [x["name"] for x in listing]) + + def mock_head_object(container, obj): # pylint:disable=unused-argument + resp = next((x for x in listing if x["name"] == obj), None) + return { + "Last-Modified": resp["lastModified"], + "Content-Length": resp["size"], + } + + backend = swift_backend() + monkeypatch.setattr(backend.connection, "get_container", mock_get_container) + monkeypatch.setattr(backend.connection, "head_object", mock_head_object) + fs.create_file(settings.HISTORY_FILE, contents=json.dumps(history)) + + assert list(backend.list()) == [x["name"] for x in listing] + assert list(backend.list(new=True)) == ["2020-05-01.gz"] + assert list(backend.list(details=True)) == listing + + +def test_backends_data_swift_data_backend_list_with_failed_details( + swift_backend, monkeypatch, fs, caplog, settings_fs +): # pylint:disable=invalid-name,unused-argument,too-many-arguments + """Test that the `SwiftDataBackend.list` method with a failed connection + when retrieving details, should log the error and raise a BackendException. 
+    """
+    error = "Test client exception"
+
+    frozen_now = now()
+    listing = [
+        {
+            "name": "2020-04-29.gz",
+            "lastModified": frozen_now,
+            "size": 12,
+        },
+    ]
+
+    def mock_get_container(*args, **kwargs):  # pylint:disable=unused-argument
+        return (None, [x["name"] for x in listing])
+
+    def mock_head_object(*args, **kwargs):  # pylint:disable=unused-argument
+        raise ClientException(error)
+
+    backend = swift_backend()
+    monkeypatch.setattr(backend.connection, "get_container", mock_get_container)
+    monkeypatch.setattr(backend.connection, "head_object", mock_head_object)
+    fs.create_file(settings.HISTORY_FILE, contents=json.dumps([]))
+
+    msg = f"Unable to retrieve details for object {listing[0]['name']}: {error}"
+
+    with caplog.at_level(logging.ERROR):
+        with pytest.raises(BackendException, match=msg):
+            next(backend.list(details=True))
+
+    assert ("ralph.backends.data.swift", logging.ERROR, msg) in caplog.record_tuples
+
+
+def test_backends_data_swift_data_backend_list_with_failed_connection(
+    swift_backend, monkeypatch, fs, caplog, settings_fs
+):  # pylint:disable=invalid-name,unused-argument,too-many-arguments
+    """Test that the `SwiftDataBackend.list` method with a failed connection
+    should log the error and raise a BackendException.
+    """
+    error = "Container not found"
+
+    def mock_get_container(*args, **kwargs):  # pylint:disable=unused-argument
+        raise ClientException(error)
+
+    backend = swift_backend()
+    monkeypatch.setattr(backend.connection, "get_container", mock_get_container)
+    fs.create_file(settings.HISTORY_FILE, contents=json.dumps([]))
+
+    msg = "Failed to list container container_name: Container not found"
+
+    with caplog.at_level(logging.ERROR):
+        with pytest.raises(BackendException, match=msg):
+            next(backend.list())
+        with pytest.raises(BackendException, match=msg):
+            next(backend.list(new=True))
+        with pytest.raises(BackendException, match=msg):
+            next(backend.list(details=True))
+
+    assert ("ralph.backends.data.swift", logging.ERROR, msg) in caplog.record_tuples
+
+
+def test_backends_data_swift_data_backend_read_method_with_raw_output(
+    swift_backend, monkeypatch, fs, settings_fs
+):  # pylint:disable=invalid-name, unused-argument
+    """Test the `SwiftDataBackend.read` method with `raw_output` set to `True`."""
+
+    # Object contents.
+    content = b'{"foo": "bar"}'
+
+    # Freeze the ralph.utils.now() value.
+    frozen_now = now()
+
+    backend = swift_backend()
+
+    def mock_get_object(*args, **kwargs):  # pylint:disable=unused-argument
+        resp_headers = {"Content-Length": 14}
+        return (resp_headers, BytesIO(content))
+
+    monkeypatch.setattr(backend.connection, "get_object", mock_get_object)
+    monkeypatch.setattr("ralph.backends.data.swift.now", lambda: frozen_now)
+    fs.create_file(settings.HISTORY_FILE, contents=json.dumps([]))
+
+    # The `read` method should read the object and yield bytes.
+    result = backend.read(raw_output=True, query="2020-04-29.gz")
+    assert isinstance(result, Iterable)
+    assert list(result) == [content]
+
+    assert backend.history == [
+        {
+            "backend": "swift",
+            "action": "read",
+            "id": "container_name/2020-04-29.gz",
+            "size": 14,
+            "timestamp": frozen_now,
+        }
+    ]
+
+    # Given a `chunk_size`, the `read` method should yield the output bytes
+    # in chunks of the specified `chunk_size`.
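+    # Note: the 14-byte payload b'{"foo": "bar"}' split into two-byte chunks
+    # yields the seven chunks asserted below.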
+    result = backend.read(raw_output=True, query="2020-05-30.gz", chunk_size=2)
+    assert isinstance(result, Iterable)
+    assert list(result) == [b'{"', b"fo", b'o"', b": ", b'"b', b"ar", b'"}']
+
+    assert backend.history == [
+        {
+            "backend": "swift",
+            "action": "read",
+            "id": "container_name/2020-04-29.gz",
+            "size": 14,
+            "timestamp": frozen_now,
+        },
+        {
+            "backend": "swift",
+            "action": "read",
+            "id": "container_name/2020-05-30.gz",
+            "size": 14,
+            "timestamp": frozen_now,
+        },
+    ]
+
+
+def test_backends_data_swift_data_backend_read_method_without_raw_output(
+    swift_backend, monkeypatch, fs, settings_fs
+):  # pylint:disable=invalid-name, unused-argument
+    """Test the `SwiftDataBackend.read` method with `raw_output` set to `False`."""
+
+    # Object contents.
+    content_dict = {"foo": "bar"}
+    content_bytes = b'{"foo": "bar"}'
+
+    # Freeze the ralph.utils.now() value.
+    frozen_now = now()
+
+    backend = swift_backend()
+
+    def mock_get_object(*args, **kwargs):  # pylint:disable=unused-argument
+        resp_headers = {"Content-Length": 14}
+        return (resp_headers, BytesIO(content_bytes))
+
+    monkeypatch.setattr(backend.connection, "get_object", mock_get_object)
+    monkeypatch.setattr("ralph.backends.data.swift.now", lambda: frozen_now)
+    fs.create_file(settings.HISTORY_FILE, contents=json.dumps([]))
+
+    # The `read` method should read the object and yield dictionaries.
+    result = backend.read(raw_output=False, query="2020-04-29.gz")
+    assert isinstance(result, Iterable)
+    assert list(result) == [content_dict]
+
+    assert backend.history == [
+        {
+            "backend": "swift",
+            "action": "read",
+            "id": "container_name/2020-04-29.gz",
+            "size": 14,
+            "timestamp": frozen_now,
+        }
+    ]
+
+
+def test_backends_data_swift_data_backend_read_method_with_invalid_query(swift_backend):
+    """Test the `SwiftDataBackend.read` method given an invalid `query` argument should
+    raise a `BackendParameterException`.
+    """
+    backend = swift_backend()
+    # Given no `query`, the `read` method should raise a `BackendParameterException`.
+    error = "Invalid query. The query should be a valid archive name"
+    with pytest.raises(BackendParameterException, match=error):
+        list(backend.read())
+
+
+def test_backends_data_swift_data_backend_read_method_with_ignore_errors(
+    monkeypatch, swift_backend, fs, settings_fs
+):
+    """Test the `SwiftDataBackend.read` method with `ignore_errors` set to `True`,
+    given an archive containing invalid JSON lines, should skip the invalid lines.
+    """
+    # pylint: disable=invalid-name, unused-argument
+
+    # File contents.
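+    # Two payloads are built below: one with the invalid line in the middle
+    # and one with invalid lines around a single valid line, exercising both
+    # skip paths of `_read_dict`.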
+    valid_dictionary = {"foo": "bar"}
+    valid_json = json.dumps(valid_dictionary)
+    invalid_json = "baz"
+    valid_invalid_json = bytes(
+        f"{valid_json}\n{invalid_json}\n{valid_json}",
+        encoding="utf8",
+    )
+    invalid_valid_json = bytes(
+        f"{invalid_json}\n{valid_json}\n{invalid_json}",
+        encoding="utf8",
+    )
+
+    backend = swift_backend()
+
+    def mock_get_object_1(*args, **kwargs):  # pylint:disable=unused-argument
+        resp_headers = {"Content-Length": 14}
+        return (resp_headers, BytesIO(valid_invalid_json))
+
+    monkeypatch.setattr(backend.connection, "get_object", mock_get_object_1)
+
+    # The `read` method should read all valid statements and yield dictionaries
+    result = backend.read(ignore_errors=True, query="2020-06-02.gz")
+    assert isinstance(result, Iterable)
+    assert list(result) == [valid_dictionary, valid_dictionary]
+
+    def mock_get_object_2(*args, **kwargs):  # pylint:disable=unused-argument
+        resp_headers = {"Content-Length": 14}
+        return (resp_headers, BytesIO(invalid_valid_json))
+
+    monkeypatch.setattr(backend.connection, "get_object", mock_get_object_2)
+
+    # The `read` method should read all valid statements and yield dictionaries
+    result = backend.read(ignore_errors=True, query="2020-06-02.gz")
+    assert isinstance(result, Iterable)
+    assert list(result) == [valid_dictionary]
+
+
+def test_backends_data_swift_data_backend_read_method_without_ignore_errors(
+    monkeypatch, swift_backend, fs, settings_fs
+):
+    """Test the `SwiftDataBackend.read` method with `ignore_errors` set to `False`,
+    given a file containing invalid JSON lines, should raise a `BackendException`.
+    """
+    # pylint: disable=invalid-name, unused-argument
+
+    # File contents.
+    valid_dictionary = {"foo": "bar"}
+    valid_json = json.dumps(valid_dictionary)
+    invalid_json = "baz"
+    valid_invalid_json = bytes(
+        f"{valid_json}\n{invalid_json}\n{valid_json}",
+        encoding="utf8",
+    )
+    invalid_valid_json = bytes(
+        f"{invalid_json}\n{valid_json}\n{invalid_json}",
+        encoding="utf8",
+    )
+
+    backend = swift_backend()
+
+    def mock_get_object_1(*args, **kwargs):  # pylint:disable=unused-argument
+        resp_headers = {"Content-Length": 14}
+        return (resp_headers, BytesIO(valid_invalid_json))
+
+    monkeypatch.setattr(backend.connection, "get_object", mock_get_object_1)
+
+    # Given one object with an invalid json at the second line, the `read`
+    # method should yield the first valid line and raise a `BackendException`
+    # at the second line.
+    result = backend.read(ignore_errors=False, query="2020-06-02.gz")
+    assert isinstance(result, Iterable)
+    assert next(result) == valid_dictionary
+    with pytest.raises(BackendException, match="Raised error:"):
+        next(result)
+
+    # When the `read` method fails to read a file entirely, then no entry should be
+    # added to the history.
+    assert not backend.history
+
+    def mock_get_object_2(*args, **kwargs):  # pylint:disable=unused-argument
+        resp_headers = {"Content-Length": 14}
+        return (resp_headers, BytesIO(invalid_valid_json))
+
+    monkeypatch.setattr(backend.connection, "get_object", mock_get_object_2)
+
+    # Given one object with an invalid json at the first and third lines, the `read`
+    # method should raise a `BackendException` at the first line.
+ result = backend.read(ignore_errors=False, query="2020-06-03.gz") + assert isinstance(result, Iterable) + with pytest.raises(BackendException, match="Raised error:"): + next(result) + + +def test_backends_data_swift_data_backend_read_method_with_failed_connection( + caplog, monkeypatch, swift_backend +): + """Test the `SwiftDataBackend.read` method, given a `ClientException` raised by + method `get_object`, should raise a `BackendException`.""" + + error = "Failed to get object." + + def mock_failed_get_object(*args, **kwargs): # pylint:disable=unused-argument + raise ClientException(error) + + backend = swift_backend() + monkeypatch.setattr(backend.connection, "get_object", mock_failed_get_object) + + msg = f"Failed to read object.gz: {error}" + with caplog.at_level(logging.ERROR): + result = backend.read(query="object.gz") + with pytest.raises(BackendException, match=msg): + next(result) + + assert ("ralph.backends.data.swift", logging.ERROR, msg) in caplog.record_tuples + + +@pytest.mark.parametrize( + "operation_type", [None, BaseOperationType.CREATE, BaseOperationType.INDEX] +) +def test_backends_data_swift_data_backend_write_method_with_file_exists_error( + operation_type, swift_backend, monkeypatch, fs, settings_fs +): + """Test the `SwiftDataBackend.write` method, given a target matching an + existing file and a `CREATE` or `INDEX` `operation_type`, should raise a + `BackendException`. + """ + # pylint: disable=invalid-name, unused-argument + listing = [{"name": "2020-04-29.gz"}, {"name": "object.gz"}] + + def mock_get_container(*args, **kwargs): # pylint:disable=unused-argument + return (None, [x["name"] for x in listing]) + + backend = swift_backend() + monkeypatch.setattr(backend.connection, "get_container", mock_get_container) + + msg = ( + f"object.gz already exists and overwrite is not allowed for operation" + f" {operation_type if operation_type is not None else BaseOperationType.CREATE}" + ) + + with pytest.raises(BackendException, match=msg): + backend.write( + target="object.gz", data=[b"foo", b"test"], operation_type=operation_type + ) + + # When the `write` method fails, then no entry should be added to the history. + assert not sorted(backend.history, key=itemgetter("id")) + + +def test_backends_data_swift_data_backend_write_method_with_failed_connection( + monkeypatch, swift_backend, fs, settings_fs +): + """Test the `SwiftDataBackend.write` method, given a failed connection, should + raise a `BackendException`.""" + # pylint: disable=invalid-name, unused-argument + + backend = swift_backend() + + error = "Client Exception error." + msg = f"Failed to write to object object.gz: {error}" + + def mock_get_container(*args, **kwargs): # pylint:disable=unused-argument + return (None, []) + + def mock_put_object(*args, **kwargs): # pylint:disable=unused-argument + return 1 + + def mock_head_object(*args, **kwargs): # pylint:disable=unused-argument + raise ClientException(error) + + monkeypatch.setattr(backend.connection, "get_container", mock_get_container) + monkeypatch.setattr(backend.connection, "put_object", mock_put_object) + monkeypatch.setattr(backend.connection, "head_object", mock_head_object) + + with pytest.raises(BackendException, match=msg): + backend.write(target="object.gz", data=[b"foo"]) + + # When the `write` method fails, then no entry should be added to the history. 
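+    # Here the failure is raised by `head_object` (mocked above), i.e. after
+    # the upload call, so the backend must still surface the write as failed.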
+ assert not sorted(backend.history, key=itemgetter("id")) + + +@pytest.mark.parametrize( + "operation_type", + [ + BaseOperationType.APPEND, + BaseOperationType.DELETE, + BaseOperationType.UPDATE, + ], +) +def test_backends_data_swift_data_backend_write_method_with_invalid_operation( + # pylint: disable=line-too-long + operation_type, + swift_backend, + fs, + settings_fs, +): + """Test the `SwiftDataBackend.write` method, given an unsupported `operation_type`, + should raise a `BackendParameterException`.""" + # pylint: disable=invalid-name, unused-argument + + backend = swift_backend() + + msg = f"{operation_type.name} operation_type is not allowed." + with pytest.raises(BackendParameterException, match=msg): + backend.write(data=[b"foo"], operation_type=operation_type) + + # When the `write` method fails, then no entry should be added to the history. + assert not sorted(backend.history, key=itemgetter("id")) + + +def test_backends_data_swift_data_backend_write_method_without_target( + swift_backend, monkeypatch, fs, settings_fs +): + """Test the `SwiftDataBackend.write` method, given no target, should write + to the default container to a random object with the provided data. + """ + # pylint: disable=invalid-name, unused-argument + + # Freeze the ralph.utils.now() value. + frozen_now = now() + monkeypatch.setattr("ralph.backends.data.swift.now", lambda: frozen_now) + + # Freeze the uuid4() value. + frozen_uuid4 = uuid4() + monkeypatch.setattr("ralph.backends.data.swift.uuid4", lambda: frozen_uuid4) + + backend = swift_backend() + + # With empty data, `write` method is skipped + count = backend.write(data=()) + + assert backend.history == [] + assert count == 0 + + listing = [{"name": "2020-04-29.gz"}, {"name": "object.gz"}] + + def mock_get_container(*args, **kwargs): # pylint:disable=unused-argument + return (None, [x["name"] for x in listing]) + + def mock_put_object(*args, **kwargs): # pylint:disable=unused-argument + return 1 + + def mock_head_object(*args, **kwargs): # pylint:disable=unused-argument + return {"Content-Length": 3} + + expected_filename = f"{frozen_now}-{frozen_uuid4}" + monkeypatch.setattr(backend.connection, "get_container", mock_get_container) + monkeypatch.setattr(backend.connection, "put_object", mock_put_object) + monkeypatch.setattr(backend.connection, "head_object", mock_head_object) + monkeypatch.setattr("ralph.backends.data.swift.now", lambda: frozen_now) + + count = backend.write(data=[{"foo": "bar"}, {"test": "toto"}]) + + assert count == 2 + assert backend.history == [ + { + "backend": "swift", + "action": "write", + "operation_type": BaseOperationType.CREATE.value, + "id": f"container_name/{expected_filename}", + "size": mock_head_object()["Content-Length"], + "timestamp": frozen_now, + } + ] diff --git a/tests/conftest.py b/tests/conftest.py index 52a1cbef4..af13a0707 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,6 +27,7 @@ s3, settings_fs, swift, + swift_backend, ws, ) from .fixtures.logs import gelf_logger # noqa: F401 diff --git a/tests/fixtures/backends.py b/tests/fixtures/backends.py index 301f51dea..582df263e 100644 --- a/tests/fixtures/backends.py +++ b/tests/fixtures/backends.py @@ -23,6 +23,7 @@ from pymongo.errors import CollectionInvalid from ralph.backends.data.fs import FSDataBackend, FSDataBackendSettings +from ralph.backends.data.swift import SwiftDataBackend, SwiftDataBackendSettings from ralph.backends.database.clickhouse import ClickHouseDatabase from ralph.backends.database.es import ESDatabase from 
ralph.backends.database.mongo import MongoDatabase @@ -353,6 +354,31 @@ def get_swift_storage(): return get_swift_storage +@pytest.fixture +def swift_backend(): + """Returns get_swift_data_backend function.""" + + def get_swift_data_backend(): + """Returns an instance of SwiftDataBackend.""" + settings = SwiftDataBackendSettings( + AUTH_URL="https://auth.cloud.ovh.net/", + USERNAME="os_username", + PASSWORD="os_password", + IDENTITY_API_VERSION="3", + TENANT_ID="os_tenant_id", + TENANT_NAME="os_tenant_name", + PROJECT_DOMAIN_NAME="Default", + REGION_NAME="os_region_name", + OBJECT_STORAGE_URL="os_storage_url/ralph_logs_container", + USER_DOMAIN_NAME="Default", + DEFAULT_CONTAINER="container_name", + LOCALE_ENCODING="utf8", + ) + return SwiftDataBackend(settings) + + return get_swift_data_backend + + @pytest.fixture() def moto_fs(fs): """Fix the incompatibility between moto and pyfakefs""" From 684d6ef95249800c90ad025d9c9a7a6189ac51b0 Mon Sep 17 00:00:00 2001 From: SergioSim Date: Mon, 22 May 2023 10:06:06 +0200 Subject: [PATCH 09/65] =?UTF-8?q?=E2=9C=A8(backends)=20add=20LDPDataBacken?= =?UTF-8?q?d?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We add the LDP data backend implementation that is mostly taken from the existing LDPStorage backend. --- src/ralph/backends/data/ldp.py | 264 +++++++++++ tests/backends/data/test_ldp.py | 769 ++++++++++++++++++++++++++++++++ tests/conftest.py | 1 + tests/fixtures/backends.py | 42 +- 4 files changed, 1056 insertions(+), 20 deletions(-) create mode 100644 src/ralph/backends/data/ldp.py create mode 100644 tests/backends/data/test_ldp.py diff --git a/src/ralph/backends/data/ldp.py b/src/ralph/backends/data/ldp.py new file mode 100644 index 000000000..b6fe84944 --- /dev/null +++ b/src/ralph/backends/data/ldp.py @@ -0,0 +1,264 @@ +"""OVH's LDP data backend for Ralph.""" + +import logging +from typing import Iterable, Iterator, Literal, Union + +import ovh +import requests + +from ralph.backends.data.base import ( + BaseDataBackend, + BaseDataBackendSettings, + BaseOperationType, + BaseQuery, + DataBackendStatus, + enforce_query_checks, +) +from ralph.backends.mixins import HistoryMixin +from ralph.conf import BaseSettingsConfig +from ralph.exceptions import BackendException, BackendParameterException +from ralph.utils import now + +logger = logging.getLogger(__name__) + + +class LDPDataBackendSettings(BaseDataBackendSettings): + """OVH LDP (Log Data Platform) data backend default configuration. + + Attributes: + APPLICATION_KEY (str): The OVH API application key (AK). + APPLICATION_SECRET (str): The OVH API application secret (AS). + CONSUMER_KEY (str): The OVH API consumer key (CK). + DEFAULT_STREAM_ID (str): The default stream identifier to query. + ENDPOINT (str): The OVH API endpoint. + REQUEST_TIMEOUT (int): HTTP request timeout in seconds. + SERVICE_NAME (str): The default LDP account name. 
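+
+    A hypothetical environment override using the backend's
+    `RALPH_BACKENDS__DATA__LDP__` prefix (values are illustrative only):
+
+        RALPH_BACKENDS__DATA__LDP__ENDPOINT=ovh-eu
+        RALPH_BACKENDS__DATA__LDP__DEFAULT_STREAM_ID=my-stream-id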
+ """ + + class Config(BaseSettingsConfig): + """Pydantic Configuration.""" + + env_prefix = "RALPH_BACKENDS__DATA__LDP__" + + APPLICATION_KEY: str = None + APPLICATION_SECRET: str = None + CONSUMER_KEY: str = None + DEFAULT_STREAM_ID: str = None + ENDPOINT: Literal[ + "ovh-eu", + "ovh-us", + "ovh-ca", + "kimsufi-eu", + "kimsufi-ca", + "soyoustart-eu", + "soyoustart-ca", + ] = "ovh-eu" + REQUEST_TIMEOUT: int = None + SERVICE_NAME: str = None + + +class LDPDataBackend(HistoryMixin, BaseDataBackend): + """OVH LDP (Log Data Platform) data backend.""" + + name = "ldp" + settings_class = LDPDataBackendSettings + + def __init__(self, settings: settings_class = None): + """Instantiate the OVH LDP client. + + Args: + settings (LDPDataBackendSettings or None): The data backend settings. + If `settings` is `None`, a default settings instance is used instead. + """ + self.settings = settings if settings else self.settings_class() + self.service_name = self.settings.SERVICE_NAME + self.stream_id = self.settings.DEFAULT_STREAM_ID + self.timeout = self.settings.REQUEST_TIMEOUT + self._client = None + + @property + def client(self): + """Create an ovh.Client if it doesn't exist.""" + if not self._client: + self._client = ovh.Client( + endpoint=self.settings.ENDPOINT, + application_key=self.settings.APPLICATION_KEY, + application_secret=self.settings.APPLICATION_SECRET, + consumer_key=self.settings.CONSUMER_KEY, + ) + return self._client + + def status(self) -> DataBackendStatus: + """Check whether the default service_name is accessible.""" + try: + self.client.get(self._get_archive_endpoint()) + except ovh.exceptions.APIError as error: + logger.error("Failed to connect to the LDP: %s", error) + return DataBackendStatus.ERROR + except BackendParameterException: + return DataBackendStatus.ERROR + + return DataBackendStatus.OK + + def list( + self, target: str = None, details: bool = False, new: bool = False + ) -> Iterator[Union[str, dict]]: + """List archives for a given target stream_id. + + Args: + target (str or None): The target stream_id where to list the archives. + If target is `None`, the `DEFAULT_STREAM_ID` is used instead. + details (bool): Get detailed archive information in addition to archive IDs. + new (bool): Given the history, list only not already read archives. + + Yields: + str: If `details` is False. + dict: If `details` is True. + + Raises: + BackendParameterException: If the `target` is `None` and no + `DEFAULT_STREAM_ID` is given. + BackendException: If a failure during retrieval of archives list occurs. 
+ """ + list_archives_endpoint = self._get_archive_endpoint(stream_id=target) + logger.info("List archives endpoint: %s", list_archives_endpoint) + logger.info("List archives details: %s", str(details)) + + try: + archives = self.client.get(list_archives_endpoint) + except ovh.exceptions.APIError as error: + msg = "Failed to get archives list: %s" + logger.error(msg, error) + raise BackendException(msg % error) from error + + logger.info("Found %d archives", len(archives)) + + if new: + archives = set(archives) - set(self.get_command_history(self.name, "read")) + logger.debug("New archives: %d", len(archives)) + + if not details: + for archive in archives: + yield archive + + return + + for archive in archives: + yield self._details(target, archive) + + @enforce_query_checks + def read( + self, + *, + query: Union[str, BaseQuery] = None, + target: str = None, + chunk_size: Union[None, int] = 4096, + raw_output: bool = True, + ignore_errors: bool = False, + ) -> Iterator[Union[bytes, dict]]: + """Read an archive matching the query in the target stream_id and yield it. + + Args: + query (str or BaseQuery): The ID of the archive to read. + target (str or None): The target stream_id containing the archives. + If target is `None`, the `DEFAULT_STREAM_ID` is used instead. + chunk_size (int or None): The chunk size for reading archives. + raw_output (bool): Ignored. Always set to `True`. + ignore_errors (bool): Ignored. + + Yields: + bytes: The content of the archive matching the query. + + Raises: + BackendException: If a failure during the read operation occurs. + BackendParameterException: If the `query` argument is not an archive name. + """ + if query.query_string is None: + msg = "Invalid query. The query should be a valid archive name" + raise BackendParameterException(msg) + + if not raw_output or not ignore_errors: + logger.warning("The `raw_output` and `ignore_errors` arguments are ignored") + + target = target if target else self.stream_id + logger.debug("Getting archive: %s from stream: %s", query.query_string, target) + + # Stream response (archive content) + url = self._url(query.query_string) + try: + with requests.get(url, stream=True, timeout=self.timeout) as result: + result.raise_for_status() + for chunk in result.iter_content(chunk_size=chunk_size): + yield chunk + except requests.exceptions.HTTPError as error: + msg = "Failed to read archive %s: %s" + logger.error(msg, query.query_string, error) + raise BackendException(msg % (query.query_string, error)) from error + + # Get detailed information about the archive to fetch + details = self._details(target, query.query_string) + # Archive is supposed to have been fully read, add a new entry to + # the history. + self.append_to_history( + { + "backend": self.name, + "command": "read", + # WARNING: previously only the filename was used as the ID + # By changing this and prepending the `target` stream_id previously + # fetched archives will not be marked as read anymore. 
+ "id": f"{target}/{query.query_string}", + "filename": details.get("filename"), + "size": details.get("size"), + "timestamp": now(), + } + ) + + def write( # pylint: disable=too-many-arguments + self, + data: Iterable[Union[bytes, dict]], + target: Union[None, str] = None, + chunk_size: Union[None, int] = None, + ignore_errors: bool = False, + operation_type: Union[None, BaseOperationType] = None, + ) -> int: + """LDP data backend is read-only, calling this method will raise an error.""" + msg = "LDP data backend is read-only, cannot write to %s" + logger.error(msg, target) + raise NotImplementedError(msg % target) + + def _get_archive_endpoint(self, stream_id: Union[None, str] = None) -> str: + """Return OVH's archive endpoint.""" + stream_id = stream_id if stream_id else self.stream_id + if None in (self.service_name, stream_id): + msg = "LDPDataBackend requires to set both service_name and stream_id" + logger.error(msg) + raise BackendParameterException(msg) + return ( + f"/dbaas/logs/{self.service_name}/output/graylog/stream/{stream_id}/archive" + ) + + def _url(self, name: str) -> str: + """Get archive absolute URL.""" + download_url_endpoint = f"{self._get_archive_endpoint()}/{name}/url" + response = self.client.post(download_url_endpoint) + download_url = response.get("url") + logger.debug("Temporary URL: %s", download_url) + return download_url + + def _details(self, stream_id: str, name: str) -> dict: + """Return `name` archive details. + + Expected JSON response looks like: + + { + "archiveId": "5d49d1b3-a3eb-498c-9039-6a482166f888", + "createdAt": "2020-06-18T04:38:59.436634+02:00", + "filename": "2020-06-16.gz", + "md5": "01585b394be0495e38dbb60b20cb40a9", + "retrievalDelay": 0, + "retrievalState": "sealed", + "sha256": "645d8e21e6fdb8aa7ffc5c[...]9ce612d06df8dcf67cb29a45ca", + "size": 67906662, + } + """ + return self.client.get(f"{self._get_archive_endpoint(stream_id)}/{name}") diff --git a/tests/backends/data/test_ldp.py b/tests/backends/data/test_ldp.py new file mode 100644 index 000000000..f1ef187a2 --- /dev/null +++ b/tests/backends/data/test_ldp.py @@ -0,0 +1,769 @@ +"""Tests for Ralph ldp data backend.""" + +import gzip +import json +import logging +import os.path +from collections.abc import Iterable +from operator import itemgetter +from pathlib import Path +from xmlrpc.client import gzip_decode + +import ovh +import pytest +import requests + +from ralph.backends.data.base import BaseOperationType, BaseQuery, DataBackendStatus +from ralph.backends.data.ldp import LDPDataBackend +from ralph.conf import settings +from ralph.exceptions import BackendException, BackendParameterException +from ralph.utils import now + + +def test_backends_data_ldp_data_backend_default_instantiation(monkeypatch, fs): + """Test the `LDPDataBackend` default instantiation.""" + # pylint: disable=invalid-name + fs.create_file(".env") + backend_settings_names = [ + "APPLICATION_KEY", + "APPLICATION_SECRET", + "CONSUMER_KEY", + "DEFAULT_STREAM_ID", + "ENDPOINT", + "SERVICE_NAME", + "REQUEST_TIMEOUT", + ] + for name in backend_settings_names: + monkeypatch.delenv(f"RALPH_BACKENDS__DATA__LDP__{name}", raising=False) + + assert LDPDataBackend.name == "ldp" + assert LDPDataBackend.query_model == BaseQuery + assert LDPDataBackend.default_operation_type == BaseOperationType.INDEX + backend = LDPDataBackend() + assert isinstance(backend.client, ovh.Client) + assert backend.service_name is None + assert backend.stream_id is None + assert backend.timeout is None + + +def 
test_backends_data_ldp_data_backend_instantiation_with_settings(ldp_backend):
+    """Test the `LDPDataBackend` instantiation with settings."""
+    backend = ldp_backend()
+    assert isinstance(backend.client, ovh.Client)
+    assert backend.service_name == "foo"
+    assert backend.stream_id == "bar"
+
+    try:
+        ldp_backend(service_name="bar")
+    except Exception as err:  # pylint:disable=broad-except
+        pytest.fail(f"LDPDataBackend should not raise exceptions: {err}")
+
+
+@pytest.mark.parametrize(
+    "exception_class",
+    [ovh.exceptions.HTTPError, ovh.exceptions.InvalidResponse],
+)
+def test_backends_data_ldp_data_backend_status_method_with_error_status(
+    exception_class, ldp_backend, monkeypatch
+):
+    """Test the `LDPDataBackend.status` method, given a failed request to OVH's archive
+    endpoint, should return `DataBackendStatus.ERROR`.
+    """
+
+    def mock_get(_):
+        """Mock the ovh.Client get method always raising an exception."""
+        raise exception_class()
+
+    def mock_get_archive_endpoint():
+        """Mock the `get_archive_endpoint` method always raising an exception."""
+        raise BackendParameterException()
+
+    backend = ldp_backend()
+    monkeypatch.setattr(backend.client, "get", mock_get)
+    assert backend.status() == DataBackendStatus.ERROR
+    monkeypatch.setattr(backend, "_get_archive_endpoint", mock_get_archive_endpoint)
+    assert backend.status() == DataBackendStatus.ERROR
+
+
+def test_backends_data_ldp_data_backend_status_method_with_ok_status(
+    ldp_backend, monkeypatch
+):
+    """Test the `LDPDataBackend.status` method, given a successful request to OVH's
+    archive endpoint, the `status` method should return `DataBackendStatus.OK`.
+    """
+
+    def mock_get(_):
+        """Mock the ovh.Client get method always returning an empty list."""
+        return []
+
+    backend = ldp_backend()
+    monkeypatch.setattr(backend.client, "get", mock_get)
+    assert backend.status() == DataBackendStatus.OK
+
+
+def test_backends_data_ldp_data_backend_list_method_with_invalid_target(ldp_backend):
+    """Test the `LDPDataBackend.list` method, given no default `stream_id` and no
+    `target` argument, should raise a `BackendParameterException`.
+    """
+
+    backend = ldp_backend(stream_id=None)
+    error = "LDPDataBackend requires to set both service_name and stream_id"
+    with pytest.raises(BackendParameterException, match=error):
+        list(backend.list())
+
+
+@pytest.mark.parametrize(
+    "exception_class",
+    [ovh.exceptions.HTTPError, ovh.exceptions.InvalidResponse],
+)
+def test_backends_data_ldp_data_backend_list_method_failure(
+    exception_class, ldp_backend, monkeypatch
+):
+    """Test the `LDPDataBackend.list` method, given a failed OVH API request, should
+    raise a `BackendException`.
+    """
+
+    def mock_get(_):
+        """Mock the ovh.Client get method always raising an exception."""
+        raise exception_class("OVH Error")
+
+    backend = ldp_backend()
+    monkeypatch.setattr(backend.client, "get", mock_get)
+    msg = r"Failed to get archives list: OVH Error"
+    with pytest.raises(BackendException, match=msg):
+        list(backend.list())
+
+
+@pytest.mark.parametrize(
+    "archives,target,expected_stream_id",
+    [
+        # Given no archives at OVH's archive endpoint and no `target`,
+        # the `list` method should use the default `stream_id` target and yield
+        # nothing.
+        ([], None, "bar"),
+        # Given one archive at OVH's archive endpoint and no `target`, the `list`
+        # method should use the default `stream_id` target and yield the archive.
+ (["achive_1"], None, "bar"), + # Given one archive at the OVH's archive endpoint and a `target`, the `list` + # method should use the provided `stream_id` target yield the archive. + (["achive_1"], "foo", "foo"), + # Given some archives at the OVH's archive endpoint and no `target`, the `list` + # method should use the default `stream_id` target yield the archives. + (["achive_1", "achive_2"], None, "bar"), + ], +) +def test_backends_data_ldp_data_backend_list_method_without_history( + archives, target, expected_stream_id, ldp_backend, monkeypatch +): + """Test the `LDPDataBackend.list` method without history.""" + + def mock_get(url): + """Mock the OVH client get request.""" + assert expected_stream_id in url + return archives + + backend = ldp_backend() + monkeypatch.setattr(backend.client, "get", mock_get) + result = backend.list(target) + assert isinstance(result, Iterable) + assert list(result) == archives + + +@pytest.mark.parametrize( + "archives,target,expected_stream_id", + [ + # Given no archives at the OVH's archive endpoint and no `target`, + # the `list` method should use the default `stream_id` target and yield nothing. + ([], None, "bar"), + # Given one archive at the OVH's archive endpoint and no `target`, the `list` + # method should use the default `stream_id` target yield the archive. + (["achive_1"], None, "bar"), + # Given one archive at the OVH's archive endpoint and a `target`, the `list` + # method should use the provided `stream_id` target yield the archive. + (["achive_1"], "foo", "foo"), + # Given some archives at the OVH's archive endpoint and no `target`, the `list` + # method should use the default `stream_id` target yield the archives. + (["achive_1", "achive_2"], None, "bar"), + ], +) +def test_backends_data_ldp_data_backend_list_method_with_details( + archives, target, expected_stream_id, ldp_backend, monkeypatch +): + """Test the `LDPDataBackend.list` method with `details` set to `True`.""" + details_responses = [ + { + "archiveId": archive, + "createdAt": "2020-06-18T04:38:59.436634+02:00", + "filename": "2020-06-18.gz", + "md5": "01585b394be0495e38dbb60b20cb40a9", + "retrievalDelay": 0, + "retrievalState": "sealed", + "sha256": "645d8e21e6fdb8aa7ffc5c[...]9ce612d06df8dcf67cb29a45ca", + "size": 67906662, + } + for archive in archives + ] + + get_details_response = (response for response in details_responses) + + def mock_get(url): + """Mock the OVH client get request.""" + assert expected_stream_id in url + # list request + if url.endswith("archive"): + return archives + # details request + return next(get_details_response) + + backend = ldp_backend() + monkeypatch.setattr(backend.client, "get", mock_get) + + result = backend.list(target, details=True) + assert isinstance(result, Iterable) + assert list(result) == details_responses + + +@pytest.mark.parametrize("target,expected_stream_id", [(None, "bar"), ("baz", "baz")]) +def test_backends_data_ldp_data_backend_list_method_with_history( + target, expected_stream_id, ldp_backend, monkeypatch, settings_fs +): + """Test the `LDPDataBackend.list` method with history.""" + # pylint: disable=unused-argument + + def mock_get(url): + """Mock the OVH client get request.""" + assert expected_stream_id in url + return ["archive_1", "archive_2", "archive_3"] + + backend = ldp_backend() + monkeypatch.setattr(backend.client, "get", mock_get) + + # Given an empty history and `new` set to `True`, the `list` method should yield all + # archives. 
+ expected = ["archive_1", "archive_2", "archive_3"] + result = backend.list(target, new=True) + assert isinstance(result, Iterable) + assert sorted(result) == expected + + # Add archive_1 to history + backend.history.append( + { + "backend": "ldp", + "action": "read", + "id": "archive_1", + "filename": "2020-10-07.gz", + "size": 23424233, + "timestamp": "2020-10-07T16:37:25.887664+00:00", + } + ) + + # Given a history containing one matching archive and `new` set to `True`, the + # `list` method should yield all archives except the matching one. + expected = ["archive_2", "archive_3"] + result = backend.list(target, new=True) + assert isinstance(result, Iterable) + assert sorted(result) == expected + + # Add archive_2 to history + backend.history.append( + { + "backend": "ldp", + "action": "read", + "id": "archive_2", + "filename": "2020-10-07.gz", + "size": 23424233, + "timestamp": "2020-10-07T16:37:25.887664+00:00", + } + ) + + # Given a history containing two matching archives and `new` set to `True`, the + # `list` method should yield all archives except the matching ones. + expected = ["archive_3"] + result = backend.list(target, new=True) + assert isinstance(result, Iterable) + assert sorted(result) == expected + + # Add archive_3 to history + backend.history.append( + { + "backend": "ldp", + "action": "read", + "id": "archive_3", + "filename": "2020-10-07.gz", + "size": 23424233, + "timestamp": "2020-10-07T16:37:25.887664+00:00", + } + ) + + # Given a history containing all matching archives and `new` set to `True`, the + # `list` method should yield nothing. + expected = [] + result = backend.list(target, new=True) + assert isinstance(result, Iterable) + assert sorted(result) == expected + + +@pytest.mark.parametrize("target,expected_stream_id", [(None, "bar"), ("baz", "baz")]) +def test_backends_data_ldp_data_backend_list_method_with_history_and_details( + target, expected_stream_id, ldp_backend, monkeypatch, settings_fs +): + """Test the `LDPDataBackend.list` method with a history and detailed output.""" + # pylint: disable=unused-argument + details_responses = [ + { + "archiveId": "archive_1", + "createdAt": "2020-06-18T04:38:59.436634+02:00", + "filename": "2020-06-16.gz", + "md5": "01585b394be0495e38dbb60b20cb40a9", + "retrievalDelay": 0, + "retrievalState": "sealed", + "sha256": "645d8e21e6fdb8aa7ffc5c[...]9ce612d06df8dcf67cb29a45ca", + "size": 67906662, + }, + { + "archiveId": "archive_2", + "createdAt": "2020-06-18T04:38:59.436634+02:00", + "filename": "2020-06-18.gz", + "md5": "01585b394be0495e38dbb60b20cb40a9", + "retrievalDelay": 0, + "retrievalState": "sealed", + "sha256": "645d8e21e6fdb8aa7ffc5c[...]9ce612d06df8dcf67cb29a45ca", + "size": 67906662, + }, + { + "archiveId": "archive_3", + "createdAt": "2020-06-19T04:38:59.436634+02:00", + "filename": "2020-06-19.gz", + "md5": "01585b394be0495e38dbb60b20cb40a9", + "retrievalDelay": 0, + "retrievalState": "sealed", + "sha256": "645d8e21e6fdb8aa7ffc5c[...]9ce612d06df8dcf67cb29a45ca", + "size": 67906662, + }, + ] + + get_details_response = (response for response in details_responses) + + def mock_get(url): + """Mock the OVH client get request.""" + assert expected_stream_id in url + # list request + if url.endswith("archive"): + return ["archive_1", "archive_2", "archive_3"] + # details request + return next(get_details_response) + + backend = ldp_backend() + monkeypatch.setattr(backend.client, "get", mock_get) + + # Given an empty history and `new` and `details` set to `True`, the `list` method + # should yield all archives 
with additional details.
+    expected = details_responses
+    result = backend.list(target, details=True, new=True)
+    assert isinstance(result, Iterable)
+    assert sorted(result, key=itemgetter("archiveId")) == expected
+
+    # Add archive_1 to history
+    backend.history.append(
+        {
+            "backend": "ldp",
+            "action": "read",
+            "id": "archive_1",
+            "filename": "2020-06-16.gz",
+            "size": 23424233,
+            "timestamp": "2020-10-07T16:37:25.887664+00:00",
+        }
+    )
+
+    # We expect two requests to retrieve details for archives 2 and 3.
+    get_details_response = (response for response in details_responses[1:])
+
+    # Given a history containing one matching archive and `new` and `details` set to
+    # `True`, the `list` method should yield all archives with additional details,
+    # except the matching one.
+    expected = [details_responses[1], details_responses[2]]
+    result = backend.list(target, details=True, new=True)
+    assert isinstance(result, Iterable)
+    assert sorted(result, key=itemgetter("archiveId")) == expected
+
+    # Add archive_2 to history
+    backend.history.append(
+        {
+            "backend": "ldp",
+            "action": "read",
+            "id": "archive_2",
+            "filename": "2020-06-18.gz",
+            "size": 23424233,
+            "timestamp": "2020-10-07T16:37:25.887664+00:00",
+        }
+    )
+
+    # We expect one request to retrieve details for archive 3.
+    get_details_response = (response for response in details_responses[2:])
+
+    # Given a history containing two matching archives and `new` and `details` set to
+    # `True`, the `list` method should yield all archives with additional details,
+    # except the matching ones.
+    expected = [details_responses[2]]
+    result = backend.list(target, details=True, new=True)
+    assert isinstance(result, Iterable)
+    assert sorted(result, key=itemgetter("archiveId")) == expected
+
+    # Add archive_3 to history
+    backend.history.append(
+        {
+            "backend": "ldp",
+            "action": "read",
+            "id": "archive_3",
+            "filename": "2020-06-19.gz",
+            "size": 23424233,
+            "timestamp": "2020-10-07T16:37:25.887664+00:00",
+        }
+    )
+
+    # Given a history containing all matching archives and `new` and `details` set to
+    # `True`, the `list` method should yield nothing.
+    expected = []
+    result = backend.list(target, details=True, new=True)
+    assert isinstance(result, Iterable)
+    assert list(result) == expected
+
+
+def test_backends_data_ldp_data_backend_read_method_without_raw_output(
+    ldp_backend, caplog, monkeypatch
+):
+    """Test the `LDPDataBackend.read` method, given `raw_output` set to `False`, should
+    log a warning message.
+ """ + + class MockResponse: + """Mock the requests response.""" + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + def raise_for_status(self): + """Ignored.""" + + def iter_content(self, chunk_size): + """Fake content file iteration.""" + # pylint: disable=no-self-use,unused-argument + yield + + def mock_requests_get(url, stream=True, timeout=None): + """Mock the request get method.""" + # pylint: disable=unused-argument + return MockResponse() + + def mock_get(url): + """Mock the OVH client get request.""" + # pylint: disable=unused-argument + return {"filename": "archive_name", "size": 10} + + backend = ldp_backend() + + monkeypatch.setattr(requests, "get", mock_requests_get) + monkeypatch.setattr(backend, "_url", lambda *_: "/") + monkeypatch.setattr(backend.client, "get", mock_get) + + with caplog.at_level(logging.WARNING): + list(backend.read(query="archiveID", raw_output=False)) + + assert ( + "ralph.backends.data.ldp", + logging.WARNING, + "The `raw_output` and `ignore_errors` arguments are ignored", + ) in caplog.record_tuples + + +def test_backends_data_ldp_data_backend_read_method_without_ignore_errors( + ldp_backend, caplog, monkeypatch +): + """Test the `LDPDataBackend.read method, given `ignore_errors` set to `False`, + should log a warning message. + """ + + class MockResponse: + """Mock the requests response.""" + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + def raise_for_status(self): + """Ignored.""" + + def iter_content(self, chunk_size): + """Fake content file iteration.""" + # pylint: disable=no-self-use,unused-argument + yield + + def mock_requests_get(url, stream=True, timeout=None): + """Mock the request get method.""" + # pylint: disable=unused-argument + return MockResponse() + + def mock_get(url): + """Mock the OVH client get request.""" + # pylint: disable=unused-argument + return {"filename": "archive_name", "size": 10} + + backend = ldp_backend() + + monkeypatch.setattr(requests, "get", mock_requests_get) + backend = ldp_backend() + monkeypatch.setattr(backend, "_url", lambda *_: "/") + monkeypatch.setattr(backend.client, "get", mock_get) + + with caplog.at_level(logging.WARNING): + list(backend.read(query="archiveID", ignore_errors=False)) + + assert ( + "ralph.backends.data.ldp", + logging.WARNING, + "The `raw_output` and `ignore_errors` arguments are ignored", + ) in caplog.record_tuples + + +def test_backends_data_ldp_data_backend_read_method_with_invalid_query(ldp_backend): + """Test the `LDPDataBackend.read` method given an invalid `query` argument should + raise a `BackendParameterException`. + """ + backend = ldp_backend() + # Given no `query`, the `read` method should raise a `BackendParameterException`. + error = "Invalid query. The query should be a valid archive name" + with pytest.raises(BackendParameterException, match=error): + list(backend.read()) + + +def test_backends_data_ldp_data_backend_read_method_with_failure( + ldp_backend, monkeypatch +): + """Test the `LDPDataBackend.read` method, given a request failure, should raise a + `BackendException`. + """ + + def mock_ovh_post(url): + """Mock the OVH Client post request.""" + # pylint: disable=unused-argument + + return { + "expirationDate": "2020-10-13T12:59:37.326131+00:00", + "url": ( + "https://storage.gra.cloud.ovh.net/v1/" + "AUTH_-c3b123f595c46e789acdd1227eefc13/" + "gra2-pcs/5eba98fb4fcb481001180e4b/" + "2020-06-01.gz?" 
+ "temp_url_sig=e1b3ab10a9149a4ff5dcb95f40f21063780d26f7&" + "temp_url_expires=1602593977" + ), + } + + class MockUnsuccessfulResponse: + """Mock the requests response.""" + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + def raise_for_status(self): + """Raise an `HttpError`.""" + # pylint: disable=no-self-use + raise requests.HTTPError("Failure during request") + + def mock_requests_get(url, stream=True, timeout=None): + """Mock the request get method.""" + # pylint: disable=unused-argument + + return MockUnsuccessfulResponse() + + # Freeze the ralph.utils.now() value. + frozen_now = now() + monkeypatch.setattr("ralph.backends.data.ldp.now", lambda: frozen_now) + + backend = ldp_backend() + monkeypatch.setattr(backend.client, "post", mock_ovh_post) + monkeypatch.setattr(requests, "get", mock_requests_get) + + error = r"Failed to read archive foo: Failure during request" + with pytest.raises(BackendException, match=error): + next(backend.read(query="foo")) + + +def test_backends_data_ldp_data_backend_read_method_with_query( + ldp_backend, monkeypatch, fs +): + """Test the `LDPDataBackend.read` method, given a query argument.""" + # pylint: disable=invalid-name + + # Create fake archive to stream. + archive_path = Path("/tmp/2020-06-16.gz") + archive_content = {"foo": "bar"} + with gzip.open(archive_path, "wb") as archive_file: + archive_file.write(bytes(json.dumps(archive_content), encoding="utf-8")) + + def mock_ovh_post(url): + """Mock the OVH Client post request.""" + # pylint: disable=unused-argument + + return { + "expirationDate": "2020-10-13T12:59:37.326131+00:00", + "url": ( + "https://storage.gra.cloud.ovh.net/v1/" + "AUTH_-c3b123f595c46e789acdd1227eefc13/" + "gra2-pcs/5eba98fb4fcb481001180e4b/" + "2020-06-01.gz?" + "temp_url_sig=e1b3ab10a9149a4ff5dcb95f40f21063780d26f7&" + "temp_url_expires=1602593977" + ), + } + + def mock_ovh_get(url): + """Mock the OVH client get request.""" + # pylint: disable=unused-argument + + return { + "archiveId": "5d5c4c93-04a4-42c5-9860-f51fa4044aa1", + "createdAt": "2020-06-18T04:38:59.436634+02:00", + "filename": "2020-06-16.gz", + "md5": "01585b394be0495e38dbb60b20cb40a9", + "retrievalDelay": 0, + "retrievalState": "sealed", + "sha256": "645d8e21e6fdb8aa7ffc5c[...]9ce612d06df8dcf67cb29a45ca", + "size": 67906662, + } + + class MockRequestsResponse: + """Mock the requests response.""" + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + def iter_content(self, chunk_size): + """Fake content file iteration.""" + # pylint: disable=no-self-use + + with archive_path.open("rb") as archive: + while chunk := archive.read(chunk_size): + yield chunk + + def raise_for_status(self): + """Ignored.""" + + def mock_requests_get(url, stream=True, timeout=None): + """Mock the request get method.""" + # pylint: disable=unused-argument + + return MockRequestsResponse() + + # Freeze the ralph.utils.now() value. 
+ frozen_now = now() + monkeypatch.setattr("ralph.backends.data.ldp.now", lambda: frozen_now) + + backend = ldp_backend() + monkeypatch.setattr(backend.client, "post", mock_ovh_post) + monkeypatch.setattr(backend.client, "get", mock_ovh_get) + monkeypatch.setattr(requests, "get", mock_requests_get) + + fs.create_dir(settings.APP_DIR) + assert not os.path.exists(settings.HISTORY_FILE) + + result = b"".join(backend.read(query="5d5c4c93-04a4-42c5-9860-f51fa4044aa1")) + + assert os.path.exists(settings.HISTORY_FILE) + assert backend.history == [ + { + "backend": "ldp", + "command": "read", + "id": "bar/5d5c4c93-04a4-42c5-9860-f51fa4044aa1", + "filename": "2020-06-16.gz", + "size": 67906662, + "timestamp": frozen_now, + } + ] + + assert json.loads(gzip_decode(result)) == archive_content + + +def test_backends_data_ldp_data_backend_write_method(ldp_backend): + """Test the `LDPDataBackend.write` method.""" + backend = ldp_backend() + msg = "LDP data backend is read-only, cannot write to fake" + with pytest.raises(NotImplementedError, match=msg): + backend.write("truly", "fake", "content") + + +@pytest.mark.parametrize( + "args,expected", + [ + ([], "/dbaas/logs/foo/output/graylog/stream/bar/archive"), + (["baz"], "/dbaas/logs/foo/output/graylog/stream/baz/archive"), + ], +) +def test_backends_data_ldp_data_backend_get_archive_endpoint_method_with_valid_input( + ldp_backend, args, expected +): + """Test the `LDPDataBackend.get_archive_endpoint` method, given valid input, should + return the expected url. + """ + # pylint: disable=protected-access + assert ldp_backend()._get_archive_endpoint(*args) == expected + + +@pytest.mark.parametrize( + "service_name,stream_id", [(None, "bar"), ("foo", None), (None, None)] +) +def test_backends_data_ldp_data_backend_get_archive_endpoint_method_with_invalid_input( + ldp_backend, service_name, stream_id +): + """Test the `LDPDataBackend.get_archive_endpoint` method, given invalid input + parameters, should raise a BackendParameterException. + """ + # pylint: disable=protected-access + with pytest.raises( + BackendParameterException, + match="LDPDataBackend requires to set both service_name and stream_id", + ): + ldp_backend( + service_name=service_name, stream_id=stream_id + )._get_archive_endpoint() + + with pytest.raises( + BackendParameterException, + match="LDPDataBackend requires to set both service_name and stream_id", + ): + ldp_backend(service_name=service_name, stream_id=None)._get_archive_endpoint( + stream_id + ) + + +def test_backends_data_ldp_data_backend_url_method(monkeypatch, ldp_backend): + """Test the `LDPDataBackend.url` method.""" + # pylint: disable=protected-access + archive_name = "5d49d1b3-a3eb-498c-9039-6a482166f888" + archive_url = ( + "https://storage.gra.cloud.ovh.net/v1/" + "AUTH_-c3b123f595c46e789acdd1227eefc13/" + "gra2-pcs/5eba98fb4fcb481001180e4b/" + "2020-06-01.gz?" 
+ "temp_url_sig=e1b3ab10a9149a4ff5dcb95f40f21063780d26f7&" + "temp_url_expires=1602593977" + ) + + def mock_post(url): + """Mock the OVH Client post request.""" + assert url.endswith(f"{archive_name}/url") + return {"expirationDate": "2020-10-13T12:59:37.326131", "url": archive_url} + + backend = ldp_backend() + monkeypatch.setattr(backend.client, "post", mock_post) + assert backend._url(archive_name) == archive_url diff --git a/tests/conftest.py b/tests/conftest.py index af13a0707..8baaa65b0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,6 +20,7 @@ es_forwarding, events, fs_backend, + ldp_backend, lrs, mongo, mongo_forwarding, diff --git a/tests/fixtures/backends.py b/tests/fixtures/backends.py index 582df263e..596c564ec 100644 --- a/tests/fixtures/backends.py +++ b/tests/fixtures/backends.py @@ -6,7 +6,6 @@ import random import time from contextlib import asynccontextmanager -from enum import Enum from functools import lru_cache from multiprocessing import Process from pathlib import Path @@ -23,6 +22,7 @@ from pymongo.errors import CollectionInvalid from ralph.backends.data.fs import FSDataBackend, FSDataBackendSettings +from ralph.backends.data.ldp import LDPDataBackend from ralph.backends.data.swift import SwiftDataBackend, SwiftDataBackendSettings from ralph.backends.database.clickhouse import ClickHouseDatabase from ralph.backends.database.es import ESDatabase @@ -111,25 +111,6 @@ def get_mongo_test_backend(): ) -class NamedClassA: - """An example named class.""" - - name = "A" - - -class NamedClassB: - """A second example named class.""" - - name = "B" - - -class NamedClassEnum(Enum): - """A named test classes Enum.""" - - A = "tests.fixtures.backends.NamedClassA" - B = "tests.fixtures.backends.NamedClassB" - - def get_es_fixture(host=ES_TEST_HOSTS, index=ES_TEST_INDEX): """Create / delete an ElasticSearch test index and yields an instantiated client. 
@@ -336,6 +317,27 @@ def settings_fs(fs, monkeypatch): ) +@pytest.fixture +def ldp_backend(settings_fs): + """Returns the `get_ldp_data_backend` function.""" + # pylint: disable=invalid-name,redefined-outer-name,unused-argument + + def get_ldp_data_backend(service_name: str = "foo", stream_id: str = "bar"): + """Returns an instance of LDPDataBackend.""" + settings = LDPDataBackend.settings_class( + APPLICATION_KEY="fake_key", + APPLICATION_SECRET="fake_secret", + CONSUMER_KEY="another_fake_key", + DEFAULT_STREAM_ID=stream_id, + ENDPOINT="ovh-eu", + SERVICE_NAME=service_name, + REQUEST_TIMEOUT=None, + ) + return LDPDataBackend(settings) + + return get_ldp_data_backend + + @pytest.fixture def swift(): """Return get_swift_storage function.""" From cf9b825d08b5283b024dc19f18c7cc8947f7fdb1 Mon Sep 17 00:00:00 2001 From: Wilfried BARADAT Date: Tue, 4 Apr 2023 19:32:19 +0200 Subject: [PATCH 10/65] =?UTF-8?q?=E2=9C=A8(backends)=20add=20s3=20backend?= =?UTF-8?q?=20with=20unified=20interface?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add S3 backend under the new common `data` interface --- .circleci/config.yml | 2 +- Dockerfile | 2 +- setup.cfg | 1 + src/ralph/backends/data/s3.py | 388 +++++++++++++++++++ tests/backends/data/test_s3.py | 658 +++++++++++++++++++++++++++++++++ tests/conftest.py | 1 + tests/fixtures/backends.py | 22 ++ 7 files changed, 1072 insertions(+), 2 deletions(-) create mode 100644 src/ralph/backends/data/s3.py create mode 100644 tests/backends/data/test_s3.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 918022fec..436e9e2eb 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -158,7 +158,7 @@ jobs: - v1-dependencies-<< parameters.python-image >>-{{ .Revision }} - run: name: Install development dependencies - command: pip install --user .[backend-clickhouse,backend-es,backend-ldp,backend-lrs,backend-mongo,backend-swift,backend-ws,cli,dev,lrs] + command: pip install --user .[backend-clickhouse,backend-es,backend-ldp,backend-lrs,backend-mongo,backend-s3,backend-swift,backend-ws,cli,dev,lrs] - save_cache: paths: - ~/.local diff --git a/Dockerfile b/Dockerfile index 5c84878c1..044edc092 100644 --- a/Dockerfile +++ b/Dockerfile @@ -25,7 +25,7 @@ RUN apt-get update && \ libffi-dev && \ rm -rf /var/lib/apt/lists/* -RUN pip install .[backend-clickhouse,backend-es,backend-ldp,backend-lrs,backend-mongo,backend-swift,backend-ws,cli,lrs] +RUN pip install .[backend-clickhouse,backend-es,backend-ldp,backend-lrs,backend-mongo,backend-s3,backend-swift,backend-ws,cli,lrs] # -- Core -- diff --git a/setup.cfg b/setup.cfg index e92a11ce7..500904095 100644 --- a/setup.cfg +++ b/setup.cfg @@ -56,6 +56,7 @@ backend-mongo = backend-s3 = boto3>=1.24.70 botocore>=1.27.71 + requests-toolbelt>=1.0.0 backend-swift = python-keystoneclient>=5.0.0 python-swiftclient>=4.0.0 diff --git a/src/ralph/backends/data/s3.py b/src/ralph/backends/data/s3.py new file mode 100644 index 000000000..3222b41ad --- /dev/null +++ b/src/ralph/backends/data/s3.py @@ -0,0 +1,388 @@ +"""S3 data backend for Ralph.""" + +import json +import logging +from io import IOBase +from itertools import chain +from typing import Iterable, Iterator, Union +from uuid import uuid4 + +import boto3 +from boto3.s3.transfer import TransferConfig +from botocore.exceptions import ( + ClientError, + ParamValidationError, + ReadTimeoutError, + ResponseStreamingError, +) +from botocore.response import StreamingBody +from requests_toolbelt import StreamingIterator + +from 
ralph.backends.data.base import ( + BaseDataBackend, + BaseDataBackendSettings, + BaseOperationType, + BaseQuery, + DataBackendStatus, + enforce_query_checks, +) +from ralph.backends.mixins import HistoryMixin +from ralph.conf import BaseSettingsConfig +from ralph.exceptions import BackendException, BackendParameterException +from ralph.utils import now + +logger = logging.getLogger(__name__) + + +class S3DataBackendSettings(BaseDataBackendSettings): + """S3 data backend default configuration. + + Attributes: + ACCESS_KEY_ID (str): The access key id for the S3 account. + SECRET_ACCESS_KEY (str): The secret key for the S3 account. + SESSION_TOKEN (str): The session token for the S3 account. + ENDPOINT_URL (str): The endpoint URL of the S3. + DEFAULT_REGION (str): The default region used in instantiating the client. + DEFAULT_BUCKET_NAME (str): The default bucket name targeted. + DEFAULT_CHUNK_SIZE (str): The default chunk size for reading and writing + objects. + LOCALE_ENCODING (str): The encoding used for writing dictionaries to objects. + """ + + class Config(BaseSettingsConfig): + """Pydantic Configuration.""" + + env_prefix = "RALPH_BACKENDS__DATA__S3__" + + ACCESS_KEY_ID: str = None + SECRET_ACCESS_KEY: str = None + SESSION_TOKEN: str = None + ENDPOINT_URL: str = None + DEFAULT_REGION: str = None + DEFAULT_BUCKET_NAME: str = None + DEFAULT_CHUNK_SIZE: int = 4096 + LOCALE_ENCODING: str = "utf8" + + +class S3DataBackend(HistoryMixin, BaseDataBackend): + """S3 data backend.""" + + name = "s3" + default_operation_type = BaseOperationType.CREATE + settings_class = S3DataBackendSettings + + def __init__(self, settings: settings_class = None): + """Instantiate the AWS S3 client.""" + self.settings = settings if settings else self.settings_class() + + self.default_bucket_name = self.settings.DEFAULT_BUCKET_NAME + self.default_chunk_size = self.settings.DEFAULT_CHUNK_SIZE + self.locale_encoding = self.settings.LOCALE_ENCODING + self._client = None + + @property + def client(self): + """Create a boto3 client if it doesn't exist.""" + if not self._client: + self._client = boto3.client( + "s3", + aws_access_key_id=self.settings.ACCESS_KEY_ID, + aws_secret_access_key=self.settings.SECRET_ACCESS_KEY, + aws_session_token=self.settings.SESSION_TOKEN, + region_name=self.settings.DEFAULT_REGION, + endpoint_url=self.settings.ENDPOINT_URL, + ) + return self._client + + def status(self) -> DataBackendStatus: + """Implement data backend checks (e.g. connection, cluster status). + + Return: + DataBackendStatus: The status of the data backend. + """ + try: + self.client.head_bucket(Bucket=self.default_bucket_name) + except ClientError: + return DataBackendStatus.ERROR + + return DataBackendStatus.OK + + def list( + self, target: str = None, details: bool = False, new: bool = False + ) -> Iterator[Union[str, dict]]: + """List objects for the target bucket. + + Args: + target (str or None): The target bucket to list from. + If target is `None`, the `default_bucket_name` is used instead. + details (bool): Get detailed object information instead of just object name. + new (bool): Given the history, list only unread files. + + Yields: + str: The next object name. (If details is False). + dict: The next object details. (If details is True). + + Raises: + BackendException: If a failure occurs. 
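+
+        Example:
+            A minimal usage sketch, assuming a reachable default bucket:
+
+                backend = S3DataBackend()
+                object_names = list(backend.list())
+                object_details = list(backend.list(details=True))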
+ """ + if target is None: + target = self.default_bucket_name + + objects_to_skip = set() + if new: + objects_to_skip = set(self.get_command_history(self.name, "read")) + + try: + paginator = self.client.get_paginator("list_objects_v2") + page_iterator = paginator.paginate(Bucket=target) + for objects in page_iterator: + if "Contents" not in objects: + continue + for obj in objects["Contents"]: + if new and f"{target}/{obj['Key']}" in objects_to_skip: + continue + if details: + obj["LastModified"] = obj["LastModified"].isoformat() + yield obj + else: + yield obj["Key"] + except ClientError as err: + error_msg = err.response["Error"]["Message"] + msg = "Failed to list the bucket %s: %s" + logger.error(msg, target, error_msg) + raise BackendException(msg % (target, error_msg)) from err + + @enforce_query_checks + def read( + self, + *, + query: Union[str, BaseQuery] = None, + target: str = None, + chunk_size: Union[None, int] = None, + raw_output: bool = False, + ignore_errors: bool = False, + ) -> Iterator[Union[bytes, dict]]: + """Read an object matching the `query` in the `target` bucket and yields it. + + Args: + query: (str or BaseQuery): The ID of the object to read. + target (str or None): The target bucket containing the objects. + If target is `None`, the `default_bucket` is used instead. + chunk_size (int or None): The chunk size for reading objects. + raw_output (bool): Controls whether to yield bytes or dictionaries. + ignore_errors (bool): If `True`, errors during the read operation + will be ignored and logged. If `False` (default), a `BackendException` + will be raised if an error occurs. + + Yields: + dict: If `raw_output` is False. + bytes: If `raw_output` is True. + + Raises: + BackendException: If a failure during the read operation occurs and + `ignore_errors` is set to `False`. + BackendParameterException: If a backend argument value is not valid and + `ignore_errors` is set to `False`. + """ + if query.query_string is None: + msg = "Invalid query. The query should be a valid object name." + logger.error(msg) + raise BackendParameterException(msg) + + if not chunk_size: + chunk_size = self.default_chunk_size + + if target is None: + target = self.default_bucket_name + + try: + response = self.client.get_object(Bucket=target, Key=query.query_string) + except ClientError as err: + error_msg = err.response["Error"]["Message"] + msg = "Failed to download %s: %s" + logger.error(msg, query.query_string, error_msg) + if not ignore_errors: + raise BackendException(msg % (query.query_string, error_msg)) from err + + reader = self._read_raw if raw_output else self._read_dict + try: + for chunk in reader(response["Body"], chunk_size, ignore_errors): + yield chunk + except (ReadTimeoutError, ResponseStreamingError) as err: + msg = "Failed to read chunk from object %s" + logger.error(msg, query.query_string) + if not ignore_errors: + raise BackendException(msg % (query.query_string)) from err + + # Archive fetched, add a new entry to the history. + self.append_to_history( + { + "backend": self.name, + "action": "read", + "id": target + "/" + query.query_string, + "size": response["ContentLength"], + "timestamp": now(), + } + ) + + def write( # pylint: disable=too-many-arguments + self, + data: Union[IOBase, Iterable[bytes], Iterable[dict]], + target: Union[None, str] = None, + chunk_size: Union[None, int] = None, + ignore_errors: bool = False, + operation_type: Union[None, BaseOperationType] = None, + ) -> int: + """Write `data` records to the `target` bucket and return their count. 
+ + Args: + data: (Iterable or IOBase): The data to write. + target (str or None): The target bucket and the target object + separated by a `/`. + If target is `None`, the default bucket is used and a random + (uuid4) object is created. + If target does not contain a `/`, it is assumed to be the + target object and the default bucket is used. + chunk_size (int or None): Ignored. + ignore_errors (bool): If `True`, errors during the write operation + are ignored and logged. If `False` (default), a `BackendException` + is raised if an error occurs. + operation_type (BaseOperationType or None): The mode of the write + operation. + If operation_type is `CREATE` or `INDEX`, the target object is + expected to be absent. If the target object exists a + `BackendException` is raised. + + Return: + int: The number of written objects. + + Raise: + BackendException: If a failure during the write operation occurs. + BackendParameterException: If a backend argument value is not valid. + """ + data = iter(data) + try: + first_record = next(data) + except StopIteration: + logger.info("Data Iterator is empty; skipping write to target.") + return 0 + + if not operation_type: + operation_type = self.default_operation_type + + if not target: + target = f"{self.default_bucket_name}/{now()}-{uuid4()}" + logger.info( + "Target not specified; using default bucket with random file name: %s", + target, + ) + + elif "/" not in target: + target = f"{self.default_bucket_name}/{target}" + logger.info( + "Target not specified; using default bucket: %s", + target, + ) + + target_bucket, target_object = target.split("/", 1) + + if operation_type in [ + BaseOperationType.APPEND, + BaseOperationType.DELETE, + BaseOperationType.UPDATE, + ]: + msg = "%s operation_type is not allowed." + logger.error(msg, operation_type.name) + raise BackendParameterException(msg % operation_type.name) + + if target_object in list(self.list(target=target_bucket)): + msg = "%s already exists and overwrite is not allowed for operation %s" + logger.error(msg, target_object, operation_type) + raise BackendException(msg % (target_object, operation_type)) + + logger.info("Creating archive: %s", target_object) + + data = chain((first_record,), data) + if isinstance(first_record, dict): + data = self._parse_dict_to_bytes(data, ignore_errors) + + counter = {"count": 0} + data = self._count(data, counter) + + # Using StreamingIterator from requests-toolbelt but without specifying a size + # as we will not use it. It implements the `read` method for iterators. 
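+        # The declared size of 0 is never consumed: `upload_fileobj` only calls
+        # `read` on the file-like object.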
+ data = StreamingIterator(0, data) + + try: + self.client.upload_fileobj( + Bucket=target_bucket, + Key=target_object, + Fileobj=data, + Config=TransferConfig(multipart_chunksize=chunk_size), + ) + response = self.client.head_object(Bucket=target_bucket, Key=target_object) + except (ClientError, ParamValidationError) as exc: + msg = "Failed to upload %s" + logger.error(msg, target) + raise BackendException(msg % target) from exc + + # Archive written, add a new entry to the history + self.append_to_history( + { + "backend": self.name, + "action": "write", + "operation_type": operation_type.value, + "id": target, + "size": response["ContentLength"], + "timestamp": now(), + } + ) + + return counter["count"] + + @staticmethod + def _read_raw( + obj: StreamingBody, chunk_size: int, _ignore_errors: bool + ) -> Iterator[bytes]: + """Read the `object` in chunks of size `chunk_size` and yield them.""" + for chunk in obj.iter_chunks(chunk_size): + yield chunk + + @staticmethod + def _read_dict( + obj: StreamingBody, chunk_size: int, ignore_errors: bool + ) -> Iterator[dict]: + """Read the `object` by line and yield JSON parsed dictionaries.""" + for line in obj.iter_lines(chunk_size): + try: + yield json.loads(line) + except (TypeError, json.JSONDecodeError) as err: + msg = "Raised error: %s" + logger.error(msg, err) + if not ignore_errors: + raise BackendException(msg % err) from err + + @staticmethod + def _parse_dict_to_bytes( + statements: Iterable[dict], ignore_errors: bool + ) -> Iterator[bytes]: + """Read the `statements` Iterable and yield bytes.""" + for statement in statements: + try: + yield bytes(f"{json.dumps(statement)}\n", encoding="utf-8") + except TypeError as error: + msg = "Failed to encode JSON: %s, for document %s" + logger.error(msg, error, statement) + if ignore_errors: + continue + raise BackendException(msg % (error, statement)) from error + + @staticmethod + def _count( + statements: Union[Iterable[bytes], Iterable[dict]], + counter: dict, + ) -> Iterator: + """Count the elements in the `statements` Iterable and yield element.""" + for statement in statements: + counter["count"] += 1 + yield statement diff --git a/tests/backends/data/test_s3.py b/tests/backends/data/test_s3.py new file mode 100644 index 000000000..d8bfc4a3a --- /dev/null +++ b/tests/backends/data/test_s3.py @@ -0,0 +1,658 @@ +"""Tests for Ralph S3 data backend.""" + +import datetime +import json +import logging + +import boto3 +import pytest +from botocore.exceptions import ClientError, ResponseStreamingError +from moto import mock_s3 + +from ralph.backends.data.base import BaseOperationType, BaseQuery, DataBackendStatus +from ralph.backends.data.s3 import S3DataBackend, S3DataBackendSettings +from ralph.exceptions import BackendException, BackendParameterException + + +def test_backends_data_s3_backend_default_instantiation( + monkeypatch, fs +): # pylint: disable=invalid-name + """Test the `S3DataBackend` default instantiation.""" + fs.create_file(".env") + backend_settings_names = [ + "ACCESS_KEY_ID", + "SECRET_ACCESS_KEY", + "SESSION_TOKEN", + "ENDPOINT_URL", + "DEFAULT_REGION", + "DEFAULT_BUCKET_NAME", + "DEFAULT_CHUNK_SIZE", + "LOCALE_ENCODING", + ] + for name in backend_settings_names: + monkeypatch.delenv(f"RALPH_BACKENDS__DATA__S3__{name}", raising=False) + + assert S3DataBackend.name == "s3" + assert S3DataBackend.query_model == BaseQuery + assert S3DataBackend.default_operation_type == BaseOperationType.CREATE + assert S3DataBackend.settings_class == S3DataBackendSettings + backend = 
S3DataBackend() + assert backend.default_bucket_name is None + assert backend.default_chunk_size == 4096 + assert backend.locale_encoding == "utf8" + + +def test_backends_data_s3_data_backend_instantiation_with_settings(): + """Test the `S3DataBackend` instantiation with settings.""" + settings_ = S3DataBackend.settings_class( + ACCESS_KEY_ID="access_key", + SECRET_ACCESS_KEY="secret", + SESSION_TOKEN="session_token", + ENDPOINT_URL="http://endpoint/url", + DEFAULT_REGION="us-west-2", + DEFAULT_BUCKET_NAME="bucket", + DEFAULT_CHUNK_SIZE=1000, + LOCALE_ENCODING="utf-16", + ) + backend = S3DataBackend(settings_) + assert backend.default_bucket_name == "bucket" + assert backend.default_chunk_size == 1000 + assert backend.locale_encoding == "utf-16" + + try: + S3DataBackend(settings_) + except Exception as err: # pylint:disable=broad-except + pytest.fail(f"S3DataBackend should not raise exceptions: {err}") + + +@mock_s3 +def test_backends_data_s3_data_backend_status_method(s3_backend): + """Test the `S3DataBackend.status` method.""" + + # Regions outside of us-east-1 require the appropriate LocationConstraint + s3_client = boto3.client("s3", region_name="us-east-1") + + assert s3_backend().status() == DataBackendStatus.ERROR + + bucket_name = "bucket_name" + s3_client.create_bucket(Bucket=bucket_name) + + assert s3_backend().status() == DataBackendStatus.OK + + +@mock_s3 +def test_backends_data_s3_data_backend_list_should_yield_archive_names( + s3_backend, +): # pylint: disable=invalid-name + """Test that given `S3DataBackend.list` method successfully connects to the S3 + data, the S3 backend list method should yield the archives. + """ + # Regions outside of us-east-1 require the appropriate LocationConstraint + s3_client = boto3.client("s3", region_name="us-east-1") + # Create a valid bucket + bucket_name = "bucket_name" + s3_client.create_bucket(Bucket=bucket_name) + + s3_client.put_object( + Bucket=bucket_name, + Key="2022-04-29.gz", + Body=json.dumps({"id": "1", "foo": "bar"}), + ) + + s3_client.put_object( + Bucket=bucket_name, + Key="2022-04-30.gz", + Body=json.dumps({"id": "2", "some": "data"}), + ) + + s3_client.put_object( + Bucket=bucket_name, + Key="2022-10-01.gz", + Body=json.dumps({"id": "3", "other": "info"}), + ) + + listing = [ + {"name": "2022-04-29.gz"}, + {"name": "2022-04-30.gz"}, + {"name": "2022-10-01.gz"}, + ] + + s3 = s3_backend() + + s3.history.extend( + [ + {"id": "bucket_name/2022-04-29.gz", "backend": "s3", "command": "read"}, + {"id": "bucket_name/2022-04-30.gz", "backend": "s3", "command": "read"}, + ] + ) + + try: + response_list = s3.list() + response_list_new = s3.list(new=True) + response_list_details = s3.list(details=True) + except Exception: # pylint:disable=broad-except + pytest.fail("S3 backend should not raise exception on successful list") + + assert list(response_list) == [x["name"] for x in listing] + assert list(response_list_new) == ["2022-10-01.gz"] + assert [x["Key"] for x in response_list_details] == [x["name"] for x in listing] + + +@mock_s3 +def test_backends_data_s3_list_on_empty_bucket_should_do_nothing( + s3_backend, +): # pylint: disable=invalid-name + """Test that given `S3DataBackend.list` method successfully connects to the S3 + data, the S3 backend list method on an empty bucket should do nothing. 
+ """ + # Regions outside of us-east-1 require the appropriate LocationConstraint + s3_client = boto3.client("s3", region_name="us-east-1") + # Create a valid bucket + bucket_name = "bucket_name" + s3_client.create_bucket(Bucket=bucket_name) + + listing = [] + + s3 = s3_backend() + + s3.clean_history(lambda *_: True) + try: + response_list = s3.list() + except Exception: # pylint:disable=broad-except + pytest.fail("S3 backend should not raise exception on successful list") + + assert list(response_list) == [x["name"] for x in listing] + + +@mock_s3 +def test_backends_data_s3_list_with_failed_connection_should_log_the_error( + s3_backend, caplog +): # pylint: disable=invalid-name + """Test that given `S3DataBackend.list` method fails to retrieve the list of + archives, the S3 backend list method should log the error and raise a + BackendException. + """ + # Regions outside of us-east-1 require the appropriate LocationConstraint + s3_client = boto3.client("s3", region_name="us-east-1") + # Create a valid bucket in Moto's 'virtual' AWS account + bucket_name = "bucket_name" + s3_client.create_bucket(Bucket=bucket_name) + + s3_client.put_object( + Bucket=bucket_name, + Key="2022-04-29.gz", + Body=json.dumps({"id": "1", "foo": "bar"}), + ) + + s3 = s3_backend() + + s3.clean_history(lambda *_: True) + + msg = "Failed to list the bucket wrong_name: The specified bucket does not exist" + + with caplog.at_level(logging.ERROR): + with pytest.raises(BackendException, match=msg): + next(s3.list(target="wrong_name")) + with pytest.raises(BackendException, match=msg): + next(s3.list(target="wrong_name", new=True)) + with pytest.raises(BackendException, match=msg): + next(s3.list(target="wrong_name", details=True)) + + assert ( + list( + filter( + lambda record: record[1] == logging.ERROR, + caplog.record_tuples, + ) + ) + == [("ralph.backends.data.s3", logging.ERROR, msg)] * 3 + ) + + +@mock_s3 +def test_backends_data_s3_read_with_valid_name_should_write_to_history( + s3_backend, + monkeypatch, +): # pylint: disable=invalid-name + """Test that given `S3DataBackend.list` method successfully retrieves from the + S3 data the object with the provided name (the object exists), + the S3 backend read method should write the entry to the history. 
+    """
+    # Regions outside of us-east-1 require the appropriate LocationConstraint
+    s3_client = boto3.client("s3", region_name="us-east-1")
+    # Create a valid bucket in Moto's 'virtual' AWS account
+    bucket_name = "bucket_name"
+    s3_client.create_bucket(Bucket=bucket_name)
+
+    raw_body = b"some contents in the body"
+    json_body = '{"id":"foo"}'
+
+    s3_client.put_object(
+        Bucket=bucket_name,
+        Key="2022-09-29.gz",
+        Body=raw_body,
+    )
+
+    s3_client.put_object(
+        Bucket=bucket_name,
+        Key="2022-09-30.gz",
+        Body=json_body,
+    )
+
+    freezed_now = datetime.datetime.now(tz=datetime.timezone.utc).isoformat()
+    monkeypatch.setattr("ralph.backends.data.s3.now", lambda: freezed_now)
+
+    s3 = s3_backend()
+    s3.clean_history(lambda *_: True)
+
+    list(
+        s3.read(
+            query="2022-09-29.gz",
+            target=bucket_name,
+            chunk_size=1000,
+            raw_output=True,
+        )
+    )
+
+    assert {
+        "backend": "s3",
+        "action": "read",
+        "id": f"{bucket_name}/2022-09-29.gz",
+        "size": len(raw_body),
+        "timestamp": freezed_now,
+    } in s3.history
+
+    list(
+        s3.read(
+            query="2022-09-30.gz",
+            raw_output=False,
+        )
+    )
+
+    assert {
+        "backend": "s3",
+        "action": "read",
+        "id": f"{bucket_name}/2022-09-30.gz",
+        "size": len(json_body),
+        "timestamp": freezed_now,
+    } in s3.history
+
+
+@mock_s3
+def test_backends_data_s3_read_with_invalid_output_should_log_the_error(
+    s3_backend, caplog
+):  # pylint: disable=invalid-name
+    """Test that when the `S3DataBackend.read` method fails to parse the object
+    as JSON, it logs the error, does not write to the history and raises a
+    BackendException.
+    """
+    # Regions outside of us-east-1 require the appropriate LocationConstraint
+    s3_client = boto3.client("s3", region_name="us-east-1")
+    # Create a valid bucket in Moto's 'virtual' AWS account
+    bucket_name = "bucket_name"
+    s3_client.create_bucket(Bucket=bucket_name)
+
+    body = b"some contents in the body"
+
+    s3_client.put_object(
+        Bucket=bucket_name,
+        Key="2022-09-29.gz",
+        Body=body,
+    )
+
+    with caplog.at_level(logging.ERROR):
+        with pytest.raises(BackendException):
+            s3 = s3_backend()
+            list(s3.read(query="2022-09-29.gz", raw_output=False))
+
+    assert (
+        "ralph.backends.data.s3",
+        logging.ERROR,
+        "Raised error: Expecting value: line 1 column 1 (char 0)",
+    ) in caplog.record_tuples
+
+    s3.clean_history(lambda *_: True)
+
+
+@mock_s3
+def test_backends_data_s3_read_with_invalid_name_should_log_the_error(
+    s3_backend, caplog
+):  # pylint: disable=invalid-name
+    """Test that when the `S3DataBackend.read` method is called without a valid
+    object name (the query is `None`), it logs the error, does not write to the
+    history and raises a BackendParameterException.
+    """
+    # Regions outside of us-east-1 require the appropriate LocationConstraint
+    s3_client = boto3.client("s3", region_name="us-east-1")
+    # Create a valid bucket in Moto's 'virtual' AWS account
+    bucket_name = "bucket_name"
+    s3_client.create_bucket(Bucket=bucket_name)
+
+    body = b"some contents in the body"
+
+    s3_client.put_object(
+        Bucket=bucket_name,
+        Key="2022-09-29.gz",
+        Body=body,
+    )
+
+    with caplog.at_level(logging.ERROR):
+        with pytest.raises(BackendParameterException):
+            s3 = s3_backend()
+            list(s3.read(query=None, target=bucket_name))
+
+    assert (
+        "ralph.backends.data.s3",
+        logging.ERROR,
+        "Invalid query. The query should be a valid object name.",
+    ) in caplog.record_tuples
+
+    s3.clean_history(lambda *_: True)
+
+
+@mock_s3
+def test_backends_data_s3_read_with_wrong_name_should_log_the_error(
+    s3_backend, caplog
+):  # pylint: disable=invalid-name
+    """Test that when the `S3DataBackend.read` method fails to retrieve the
+    object with the provided name from S3 (the object does not exist on S3),
+    it logs the error, does not write to the history and raises a
+    BackendException.
+    """
+    # Regions outside of us-east-1 require the appropriate LocationConstraint
+    s3_client = boto3.client("s3", region_name="us-east-1")
+    # Create a valid bucket in Moto's 'virtual' AWS account
+    bucket_name = "bucket_name"
+    s3_client.create_bucket(Bucket=bucket_name)
+
+    body = b"some contents in the body"
+
+    s3_client.put_object(
+        Bucket=bucket_name,
+        Key="2022-09-29.gz",
+        Body=body,
+    )
+
+    with caplog.at_level(logging.ERROR):
+        with pytest.raises(BackendException):
+            s3 = s3_backend()
+            s3.clean_history(lambda *_: True)
+            list(s3.read(query="invalid_name.gz", target=bucket_name))
+
+    assert (
+        "ralph.backends.data.s3",
+        logging.ERROR,
+        "Failed to download invalid_name.gz: The specified key does not exist.",
+    ) in caplog.record_tuples
+
+    assert s3.history == []
+
+
+@mock_s3
+def test_backends_data_s3_read_with_iter_error_should_log_the_error(
+    s3_backend, caplog, monkeypatch
+):  # pylint: disable=invalid-name
+    """Test that when the `S3DataBackend.read` method fails to iterate through
+    the content of the S3 object, it logs the error, does not write to the
+    history and raises a BackendException.
+    """
+    # Regions outside of us-east-1 require the appropriate LocationConstraint
+    s3_client = boto3.client("s3", region_name="us-east-1")
+    # Create a valid bucket in Moto's 'virtual' AWS account
+    bucket_name = "bucket_name"
+    s3_client.create_bucket(Bucket=bucket_name)
+
+    body = b"some contents in the body"
+
+    object_name = "2022-09-29.gz"
+
+    s3_client.put_object(
+        Bucket=bucket_name,
+        Key=object_name,
+        Body=body,
+    )
+
+    def mock_read_raw(*args, **kwargs):
+        raise ResponseStreamingError(error="error")
+
+    with caplog.at_level(logging.ERROR):
+        with pytest.raises(BackendException):
+            s3 = s3_backend()
+            monkeypatch.setattr(s3, "_read_raw", mock_read_raw)
+            s3.clean_history(lambda *_: True)
+            list(s3.read(query=object_name, target=bucket_name, raw_output=True))
+
+    assert (
+        "ralph.backends.data.s3",
+        logging.ERROR,
+        f"Failed to read chunk from object {object_name}",
+    ) in caplog.record_tuples
+    assert s3.history == []
+
+
+@pytest.mark.parametrize(
+    "operation_type",
+    [None, BaseOperationType.CREATE, BaseOperationType.INDEX],
+)
+@mock_s3
+def test_backends_data_s3_write_method_with_parameter_error(
+    operation_type, s3_backend, caplog
+):  # pylint: disable=invalid-name
+    """Test the `S3DataBackend.write` method, given a target matching an
+    existing object and a `CREATE` or `INDEX` `operation_type`, should log the
+    error and raise a `BackendException`.
+    """
+    # Regions outside of us-east-1 require the appropriate LocationConstraint
+    s3_client = boto3.client("s3", region_name="us-east-1")
+    # Create a valid bucket in Moto's 'virtual' AWS account
+    bucket_name = "bucket_name"
+    s3_client.create_bucket(Bucket=bucket_name)
+
+    body = b"some contents in the body"
+
+    s3_client.put_object(
+        Bucket=bucket_name,
+        Key="2022-09-29.gz",
+        Body=body,
+    )
+
+    object_name = "2022-09-29.gz"
+    some_content = b"some contents in the stream file to upload"
+
+    with caplog.at_level(logging.ERROR):
+        with pytest.raises(BackendException):
+            s3 = s3_backend()
+            s3.clean_history(lambda *_: True)
+            s3.write(
+                data=some_content, target=object_name, operation_type=operation_type
+            )
+
+    msg = (
+        f"{object_name} already exists and overwrite is not allowed for operation"
+        f" {operation_type if operation_type is not None else BaseOperationType.CREATE}"
+    )
+
+    assert ("ralph.backends.data.s3", logging.ERROR, msg) in caplog.record_tuples
+    assert s3.history == []
+
+
+@pytest.mark.parametrize(
+    "operation_type",
+    [BaseOperationType.APPEND, BaseOperationType.DELETE],
+)
+def test_backends_data_s3_data_backend_write_method_with_append_or_delete_operation(
+    s3_backend, operation_type
+):
+    """Test the `S3DataBackend.write` method, given an `APPEND` or `DELETE`
+    `operation_type`, should raise a `BackendParameterException`.
+    """
+    # pylint: disable=invalid-name
+    backend = s3_backend()
+    with pytest.raises(
+        BackendParameterException,
+        match=f"{operation_type.name} operation_type is not allowed.",
+    ):
+        backend.write(data=[b"foo"], operation_type=operation_type)
+
+
+@pytest.mark.parametrize(
+    "operation_type",
+    [BaseOperationType.CREATE, BaseOperationType.INDEX],
+)
+@mock_s3
+def test_backends_data_s3_write_method_with_create_index_operation(
+    operation_type, s3_backend, monkeypatch, caplog
+):  # pylint: disable=invalid-name
+    """Test the `S3DataBackend.write` method, given a `CREATE` or `INDEX`
+    `operation_type` and a target that does not exist yet, should write the
+    data and add an entry to the history.
+ """ + # Regions outside of us-east-1 require the appropriate LocationConstraint + s3_client = boto3.client("s3", region_name="us-east-1") + # Create a valid bucket in Moto's 'virtual' AWS account + bucket_name = "bucket_name" + s3_client.create_bucket(Bucket=bucket_name) + + freezed_now = datetime.datetime.now(tz=datetime.timezone.utc).isoformat() + monkeypatch.setattr("ralph.backends.data.s3.now", lambda: freezed_now) + + object_name = "new-archive.gz" + some_content = b"some contents in the stream file to upload" + data = [some_content, some_content, some_content] + s3 = s3_backend() + s3.clean_history(lambda *_: True) + + response = s3.write( + data=data, + target=object_name, + operation_type=operation_type, + ) + + assert response == 3 + assert { + "backend": "s3", + "action": "write", + "operation_type": operation_type.value, + "id": f"{bucket_name}/{object_name}", + "size": len(some_content) * 3, + "timestamp": freezed_now, + } in s3.history + + object_name = "new-archive2.gz" + other_content = {"some": "content"} + + data = [other_content, other_content] + response = s3.write( + data=data, + target=object_name, + operation_type=operation_type, + ) + + assert response == 2 + assert { + "backend": "s3", + "action": "write", + "operation_type": operation_type.value, + "id": f"{bucket_name}/{object_name}", + "size": len(bytes(f"{json.dumps(other_content)}\n", encoding="utf8")) * 2, + "timestamp": freezed_now, + } in s3.history + + assert list(s3.read(query=object_name, raw_output=False)) == data + + object_name = "new-archive3.gz" + date = datetime.datetime(2023, 6, 30, 8, 42, 15, 554892) + + data = [{"some": "content", "datetime": date}] + + error = "Object of type datetime is not JSON serializable" + + with caplog.at_level(logging.ERROR): + # Without ignoring error + with pytest.raises(BackendException, match=error): + response = s3.write( + data=data, + target=object_name, + operation_type=operation_type, + ignore_errors=False, + ) + + # Ignoring error + response = s3.write( + data=data, + target=object_name, + operation_type=operation_type, + ignore_errors=True, + ) + + assert list( + filter( + lambda record: record[1] == logging.ERROR, + caplog.record_tuples, + ) + ) == ( + [ + ( + "ralph.backends.data.s3", + logging.ERROR, + f"Failed to encode JSON: {error}, for document {data[0]}", + ) + ] + * 2 + ) + + +@mock_s3 +def test_backends_data_s3_write_method_with_no_data_should_skip( + s3_backend, +): # pylint: disable=invalid-name + """Test the `S3DataBackend.write` method, given no data to write, + should skip and return 0. + """ + # Regions outside of us-east-1 require the appropriate LocationConstraint + s3_client = boto3.client("s3", region_name="us-east-1") + # Create a valid bucket in Moto's 'virtual' AWS account + bucket_name = "bucket_name" + s3_client.create_bucket(Bucket=bucket_name) + + object_name = "new-archive.gz" + + s3 = s3_backend() + response = s3.write( + data=[], + target=object_name, + operation_type=BaseOperationType.CREATE, + ) + assert response == 0 + + +@mock_s3 +def test_backends_data_s3_write_method_with_failure_should_log_the_error( + s3_backend, +): # pylint: disable=invalid-name + """Test the `S3DataBackend.write` method, given a connection failure, + should raise a `BackendException`. 
+ """ + # Regions outside of us-east-1 require the appropriate LocationConstraint + s3_client = boto3.client("s3", region_name="us-east-1") + # Create a valid bucket in Moto's 'virtual' AWS account + bucket_name = "bucket_name" + s3_client.create_bucket(Bucket=bucket_name) + + object_name = "new-archive.gz" + body = b"some contents in the body" + error = "Failed to upload" + + def raise_client_error(*args, **kwargs): + raise ClientError({"Error": {}}, "error") + + s3 = s3_backend() + s3.client.put_object = raise_client_error + + with pytest.raises(BackendException, match=error): + s3.write( + data=[body], + target=object_name, + operation_type=BaseOperationType.CREATE, + ) diff --git a/tests/conftest.py b/tests/conftest.py index 8baaa65b0..f73e865cd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,6 +26,7 @@ mongo_forwarding, moto_fs, s3, + s3_backend, settings_fs, swift, swift_backend, diff --git a/tests/fixtures/backends.py b/tests/fixtures/backends.py index 596c564ec..99e445316 100644 --- a/tests/fixtures/backends.py +++ b/tests/fixtures/backends.py @@ -23,6 +23,7 @@ from ralph.backends.data.fs import FSDataBackend, FSDataBackendSettings from ralph.backends.data.ldp import LDPDataBackend +from ralph.backends.data.s3 import S3DataBackend, S3DataBackendSettings from ralph.backends.data.swift import SwiftDataBackend, SwiftDataBackendSettings from ralph.backends.database.clickhouse import ClickHouseDatabase from ralph.backends.database.es import ESDatabase @@ -411,6 +412,27 @@ def get_s3_storage(): return get_s3_storage +@pytest.fixture +def s3_backend(): + """Return the `get_s3_data_backend` function.""" + + def get_s3_data_backend(): + """Return an instance of S3DataBackend.""" + settings = S3DataBackendSettings( + ACCESS_KEY_ID="access_key_id", + SECRET_ACCESS_KEY="secret_access_key", + SESSION_TOKEN="session_token", + ENDPOINT_URL=None, + DEFAULT_REGION="default-region", + DEFAULT_BUCKET_NAME="bucket_name", + DEFAULT_CHUNK_SIZE=4096, + LOCALE_ENCODING="utf8", + ) + return S3DataBackend(settings) + + return get_s3_data_backend + + @pytest.fixture def events(): """Return test events fixture.""" From eb1b8b645458435a6640ea0e98a57781ce37a01d Mon Sep 17 00:00:00 2001 From: Arnaud Henric Date: Tue, 9 May 2023 12:28:30 +0200 Subject: [PATCH 11/65] =?UTF-8?q?=E2=9C=A8(backends)=20add=20clickhouse=20?= =?UTF-8?q?unified=20interface?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add ClickHouse backend under the new common 'data' interface. With ClickHouse under the new data interface, tests are updated as well. Storage and Database backends had similar interfaces and usage, so a new Data Backend interface has been created. 
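
As a rough usage sketch of the unified interface (the values below are
illustrative):

    from ralph.backends.data.base import BaseOperationType
    from ralph.backends.data.clickhouse import (
        ClickHouseDataBackend,
        ClickHouseDataBackendSettings,
    )

    settings = ClickHouseDataBackendSettings(HOST="localhost", PORT=8123)
    backend = ClickHouseDataBackend(settings)
    backend.write(
        data=[{"id": "00000000-0000-0000-0000-000000000000",
               "timestamp": "2023-05-09T12:28:30"}],
        operation_type=BaseOperationType.CREATE,
    )
    statements = list(backend.read(chunk_size=10))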
--- src/ralph/backends/data/clickhouse.py | 473 +++++++++++++++++++ src/ralph/backends/lrs/base.py | 64 ++- src/ralph/backends/lrs/clickhouse.py | 177 +++++++ tests/backends/data/test_clickhouse.py | 628 +++++++++++++++++++++++++ tests/backends/lrs/__init__.py | 0 tests/backends/lrs/test_clickhouse.py | 368 +++++++++++++++ tests/conftest.py | 2 + tests/fixtures/backends.py | 55 +++ 8 files changed, 1744 insertions(+), 23 deletions(-) create mode 100755 src/ralph/backends/data/clickhouse.py create mode 100644 src/ralph/backends/lrs/clickhouse.py create mode 100644 tests/backends/data/test_clickhouse.py create mode 100644 tests/backends/lrs/__init__.py create mode 100644 tests/backends/lrs/test_clickhouse.py diff --git a/src/ralph/backends/data/clickhouse.py b/src/ralph/backends/data/clickhouse.py new file mode 100755 index 000000000..1010d5756 --- /dev/null +++ b/src/ralph/backends/data/clickhouse.py @@ -0,0 +1,473 @@ +"""ClickHouse data backend for Ralph.""" + +import json +import logging +from datetime import datetime +from io import IOBase +from itertools import chain +from typing import ( + Any, + Dict, + Generator, + Iterable, + Iterator, + List, + Literal, + NamedTuple, + Optional, + Union, +) +from uuid import UUID, uuid4 + +import clickhouse_connect +from clickhouse_connect.driver.exceptions import ClickHouseError +from pydantic import BaseModel, Json, ValidationError + +from ralph.backends.data.base import ( + BaseDataBackend, + BaseDataBackendSettings, + BaseOperationType, + BaseQuery, + DataBackendStatus, + enforce_query_checks, +) +from ralph.conf import BaseSettingsConfig, ClientOptions +from ralph.exceptions import BackendException, BackendParameterException + +logger = logging.getLogger(__name__) + + +class ClickHouseInsert(BaseModel): + """Model to validate required fields for ClickHouse insertion.""" + + event_id: UUID + emission_time: datetime + + +class ClickHouseClientOptions(ClientOptions): + """Pydantic model for `clickhouse` client options.""" + + date_time_input_format: str = "best_effort" + allow_experimental_object_type: Literal[0, 1] = 1 + + +class InsertTuple(NamedTuple): + """Named tuple for ClickHouse insertion.""" + + event_id: UUID + emission_time: datetime + event: dict + event_str: str + + +class ClickHouseDataBackendSettings(BaseDataBackendSettings): + """Represent the ClickHouse data backend default configuration. + + Attributes: + HOST (str): ClickHouse server host to connect to. + PORT (int): ClickHouse server port to connect to. + DATABASE (str): ClickHouse database to connect to. + EVENT_TABLE_NAME (str): Table where events live. + USERNAME (str): ClickHouse username to connect as (optional). + PASSWORD (str): Password for the given ClickHouse username (optional). + CLIENT_OPTIONS (ClickHouseClientOptions): A dictionary of valid options for the + ClickHouse client connection. + DEFAULT_CHUNK_SIZE (int): The default chunk size for reading/writing. + LOCALE_ENCODING (str): The locale encoding to use when none is provided. 
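+
+        Each of these attributes can also be set through the environment,
+        using the `RALPH_BACKENDS__DATA__CLICKHOUSE__` prefix defined in
+        `Config.env_prefix` (e.g. `RALPH_BACKENDS__DATA__CLICKHOUSE__HOST`).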
+    """
+
+    class Config(BaseSettingsConfig):
+        """Pydantic Configuration."""
+
+        env_prefix = "RALPH_BACKENDS__DATA__CLICKHOUSE__"
+
+    HOST: str = "localhost"
+    PORT: int = 8123
+    DATABASE: str = "xapi"
+    EVENT_TABLE_NAME: str = "xapi_events_all"
+    USERNAME: str = None
+    PASSWORD: str = None
+    CLIENT_OPTIONS: ClickHouseClientOptions = ClickHouseClientOptions()
+    DEFAULT_CHUNK_SIZE: int = 500
+    LOCALE_ENCODING: str = "utf8"
+
+
+class BaseClickHouseQuery(BaseQuery):
+    """Base ClickHouse query model."""
+
+    select: Union[str, List[str]] = "event"
+    where: Optional[Union[str, List[str]]]
+    parameters: Optional[Dict]
+    limit: Optional[int]
+    sort: Optional[str]
+    column_oriented: Optional[bool] = False
+
+
+class ClickHouseQuery(BaseClickHouseQuery):
+    """ClickHouse query model."""
+
+    # pylint: disable=unsubscriptable-object
+    query_string: Optional[Json[BaseClickHouseQuery]]
+
+
+class ClickHouseDataBackend(BaseDataBackend):
+    """ClickHouse database backend."""
+
+    name = "clickhouse"
+    query_model = ClickHouseQuery
+    default_operation_type = BaseOperationType.CREATE
+    settings_class = ClickHouseDataBackendSettings
+
+    def __init__(self, settings: settings_class = None):
+        """Instantiate the ClickHouse configuration.
+
+        Args:
+            settings (ClickHouseDataBackendSettings or None): The ClickHouse
+                data backend settings.
+        """
+        self.settings = settings if settings else self.settings_class()
+        self.database = self.settings.DATABASE
+        self.event_table_name = self.settings.EVENT_TABLE_NAME
+        self.default_chunk_size = self.settings.DEFAULT_CHUNK_SIZE
+        self.locale_encoding = self.settings.LOCALE_ENCODING
+        self._client = None
+
+    @property
+    def client(self):
+        """Create a ClickHouse client if it doesn't exist.
+
+        We do this here so that we don't interrupt initialization in the case
+        where ClickHouse is not running when Ralph starts up, which would cause
+        Ralph to hang. This client is HTTP, so not actually stateful. Ralph
+        should be able to gracefully deal with ClickHouse outages at all other
+        times.
+        """
+        if not self._client:
+            self._client = clickhouse_connect.get_client(
+                host=self.settings.HOST,
+                port=self.settings.PORT,
+                database=self.database,
+                username=self.settings.USERNAME,
+                password=self.settings.PASSWORD,
+                settings=self.settings.CLIENT_OPTIONS.dict(),
+            )
+        return self._client
+
+    def status(self) -> DataBackendStatus:
+        """Check ClickHouse connection status.
+
+        Return:
+            DataBackendStatus: The status of the data backend.
+        """
+        try:
+            self.client.query("SELECT 1")
+        except ClickHouseError:
+            return DataBackendStatus.AWAY
+
+        return DataBackendStatus.OK
+
+    def list(
+        self, target: str = None, details: bool = False, new: bool = False
+    ) -> Iterator[Union[str, dict]]:
+        """List tables for a given database.
+
+        Args:
+            target (str): The database name to list tables from.
+            details (bool): Get detailed table information instead of just names.
+            new (bool): Given the history, list only tables that have not been
+                fetched yet.
+
+        Yield:
+            str: The next table name. (If `details` is False).
+            dict: The next table details. (If `details` is True).
+
+        Raise:
+            BackendException: If a failure during table names retrieval occurs.
+ """ + sql = f"SHOW TABLES FROM {target if target else self.database}" + + try: + tables = self.client.query(sql).named_results() + except (ClickHouseError, IndexError, TypeError, ValueError) as error: + msg = "Failed to read tables: %s" + logger.error(msg, error) + raise BackendException(msg % error) from error + + for table in tables: + if details: + yield table + else: + yield str(table.get("name")) + + @enforce_query_checks + def read( + self, + *, + query: Union[str, ClickHouseQuery] = None, + target: str = None, + chunk_size: Union[None, int] = None, + raw_output: bool = False, + ignore_errors: bool = False, + ) -> Iterator[Union[bytes, dict]]: + """Read documents matching the query in the target table and yield them. + + Args: + query (str or ClickHouseQuery): The query to use when fetching documents. + target (str or None): The target table name to query. + If target is `None`, the `event_table_name` is used instead. + chunk_size (int or None): The chunk size for reading batches of documents. + If chunk_size is `None` it defaults to `default_chunk_size`. + raw_output (bool): Controls whether to yield dictionaries or bytes. + ignore_errors (bool): If `True`, errors during the encoding operation + will be ignored and logged. If `False` (default), a `BackendException` + will be raised if an error occurs. + + Yield: + bytes: The next raw document if `raw_output` is True. + dict: The next JSON parsed document if `raw_output` is False. + + Raise: + BackendException: If a failure occurs during ClickHouse connection. + """ + if target is None: + target = self.event_table_name + + if chunk_size is None: + chunk_size = self.default_chunk_size + + query = ( + BaseClickHouseQuery(query.query_string) + if query.query_string + else query.copy(exclude={"query_string"}) + ) + + if isinstance(query.select, str): + query.select = [query.select] + select = ",".join(query.select) + sql = f"SELECT {select} FROM {target}" # nosec + + if query.where: + if isinstance(query.where, str): + query.where = [query.where] + filter_str = "\nWHERE 1=1 AND " + filter_str += """ + AND + """.join( + query.where + ) + sql += filter_str + + if query.sort: + sql += f"\nORDER BY {query.sort}" + + if query.limit: + sql += f"\nLIMIT {query.limit}" + + reader = self._read_raw if raw_output else lambda _: _ + + logger.debug( + "Start reading the %s table of the %s database (chunk size: %d)", + target, + self.database, + chunk_size, + ) + try: + result = self.client.query( + sql, + parameters=query.parameters, + settings={"buffer_size": chunk_size}, + column_oriented=query.column_oriented, + ).named_results() + for statement in result: + try: + yield reader(statement) + except (TypeError, ValueError) as error: + msg = "Failed to encode document %s: %s" + if ignore_errors: + logger.warning(msg, statement, error) + continue + logger.error(msg, statement, error) + raise BackendException(msg % (statement, error)) from error + except (ClickHouseError, IndexError, TypeError, ValueError) as error: + msg = "Failed to read documents: %s" + logger.error(msg, error) + raise BackendException(msg % error) from error + + def write( # pylint: disable=too-many-arguments + self, + data: Union[IOBase, Iterable[bytes], Iterable[dict]], + target: Union[None, str] = None, + chunk_size: Union[None, int] = None, + ignore_errors: bool = False, + operation_type: Union[None, BaseOperationType] = None, + ) -> int: + """Write `data` documents to the `target` table and return their count. 
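+
+        Documents are buffered and inserted in chunks of `chunk_size` items,
+        each chunk going through a single ClickHouse insert call.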
+
+        Args:
+            data (Iterable or IOBase): The data containing documents to write.
+            target (str or None): The target table name.
+                If target is `None`, the `event_table_name` is used instead.
+            chunk_size (int or None): The number of documents to write in one batch.
+                If `chunk_size` is `None` it defaults to `default_chunk_size`.
+            ignore_errors (bool): If `True`, errors during the write operation
+                will be ignored and logged. If `False` (default), a `BackendException`
+                will be raised if an error occurs.
+            operation_type (BaseOperationType or None): The mode of the write operation.
+                If `operation_type` is `None`, the `default_operation_type` is used
+                instead. See `BaseOperationType`.
+
+        Return:
+            int: The number of written documents.
+
+        Raise:
+            BackendException: If a failure occurs while writing to ClickHouse or
+                during document decoding and `ignore_errors` is set to `False`.
+            BackendParameterException: If the `operation_type` is `APPEND`, `UPDATE`
+                or `DELETE` as it is not supported.
+        """
+        target = target if target else self.event_table_name
+        if not operation_type:
+            operation_type = self.default_operation_type
+        if not chunk_size:
+            chunk_size = self.default_chunk_size
+        logger.debug(
+            "Start writing to the %s table of the %s database (chunk size: %d)",
+            target,
+            self.database,
+            chunk_size,
+        )
+
+        data = iter(data)
+        try:
+            first_record = next(data)
+        except StopIteration:
+            logger.info("Data Iterator is empty; skipping write to target.")
+            return 0
+
+        data = chain([first_record], data)
+        if isinstance(first_record, bytes):
+            data = self._parse_bytes_to_dict(data, ignore_errors)
+
+        if operation_type not in [BaseOperationType.CREATE, BaseOperationType.INDEX]:
+            msg = "%s operation_type is not allowed."
+            logger.error(msg, operation_type.name)
+            raise BackendParameterException(msg % operation_type.name)
+
+        # operation_type is either CREATE or INDEX
+        count = 0
+        batch = []
+
+        for insert_tuple in self._to_insert_tuples(
+            data,
+            ignore_errors=ignore_errors,
+        ):
+            batch.append(insert_tuple)
+            if len(batch) < chunk_size:
+                continue
+
+            count += self._bulk_import(
+                batch,
+                ignore_errors=ignore_errors,
+                event_table_name=target,
+            )
+            batch = []
+
+        # Flush any remaining documents that did not fill a complete chunk
+        if len(batch) > 0:
+            count += self._bulk_import(
+                batch,
+                ignore_errors=ignore_errors,
+                event_table_name=target,
+            )
+
+        logger.info("Inserted a total of %d documents with success", count)
+
+        return count
+
+    @staticmethod
+    def _to_insert_tuples(
+        data: Iterable[dict],
+        ignore_errors: bool = False,
+    ) -> Generator[InsertTuple, None, None]:
+        """Convert `data` dictionaries to insert tuples."""
+        for statement in data:
+            try:
+                insert = ClickHouseInsert(
+                    event_id=statement.get("id", str(uuid4())),
+                    emission_time=statement["timestamp"],
+                )
+            except (KeyError, ValidationError) as error:
+                msg = "Statement %s has an invalid 'id' or 'timestamp' field"
+                if ignore_errors:
+                    logger.warning(msg, statement)
+                    continue
+                logger.error(msg, statement)
+                raise BackendException(msg % statement) from error
+
+            insert_tuple = InsertTuple(
+                insert.event_id,
+                insert.emission_time,
+                statement,
+                json.dumps(statement),
+            )
+
+            yield insert_tuple
+
+    def _bulk_import(
+        self, batch: list, ignore_errors: bool = False, event_table_name: str = None
+    ):
+        """Insert a batch of documents into the selected database table."""
+        try:
+            found_ids = {document.event_id for document in batch}
+
+            if len(found_ids) != len(batch):
+                raise BackendException("Duplicate IDs
found in batch")
+
+            self.client.insert(
+                event_table_name,
+                batch,
+                column_names=[
+                    "event_id",
+                    "emission_time",
+                    "event",
+                    "event_str",
+                ],
+                # Allow ClickHouse to buffer the insert, and wait for the
+                # buffer to flush. Should be configurable, but I think these are
+                # reasonable defaults.
+                settings={"async_insert": 1, "wait_for_async_insert": 1},
+            )
+        except (ClickHouseError, BackendException) as error:
+            if not ignore_errors:
+                raise BackendException(*error.args) from error
+            logger.warning(
+                "Bulk import failed for current chunk but you chose to ignore it.",
+            )
+            # There is no current way of knowing how many rows from the batch
+            # succeeded, we assume 0 here.
+            return 0
+
+        inserted_count = len(batch)
+        logger.debug("Inserted %d documents chunk with success", inserted_count)
+
+        return inserted_count
+
+    @staticmethod
+    def _parse_bytes_to_dict(
+        raw_documents: Iterable[bytes], ignore_errors: bool
+    ) -> Iterator[dict]:
+        """Read the `raw_documents` Iterable and yield dictionaries."""
+        for raw_document in raw_documents:
+            try:
+                yield json.loads(raw_document)
+            except (TypeError, json.JSONDecodeError) as error:
+                if ignore_errors:
+                    logger.warning(
+                        "Raised error: %s, for document %s", error, raw_document
+                    )
+                    continue
+                logger.error("Raised error: %s, for document %s", error, raw_document)
+                raise error
+
+    def _read_raw(self, document: Dict[str, Any]) -> bytes:
+        """Serialize the `document` dictionary and return it as bytes."""
+        return json.dumps(document).encode(self.locale_encoding)
diff --git a/src/ralph/backends/lrs/base.py b/src/ralph/backends/lrs/base.py
index 4857bcbac..d7a3309f5 100644
--- a/src/ralph/backends/lrs/base.py
+++ b/src/ralph/backends/lrs/base.py
@@ -1,46 +1,64 @@
-"""Base data backend for Ralph."""
+"""Base LRS backend for Ralph."""
 
 from abc import abstractmethod
 from dataclasses import dataclass
 from datetime import datetime
-from typing import List, Literal, Optional
+from typing import Iterator, List, Literal, Optional
 from uuid import UUID
 
 from pydantic import BaseModel
 
-from ralph.backends.data.base import BaseDataBackend
+from ralph.backends.data.base import BaseDataBackend, BaseDataBackendSettings
+
+
+class BaseLRSBackendSettings(BaseDataBackendSettings):
+    """LRS backend default configuration."""
 
 
 @dataclass
 class StatementQueryResult:
-    """Represents a common interface for results of an LRS statements query."""
+    """Result of an LRS statements query."""
 
     statements: List[dict]
     pit_id: str
    search_after: str
 
 
+class AgentParameters(BaseModel):
+    """LRS query parameters for queries on the Agent type.
+
+    NB: Agent refers to the data structure, NOT to the LRS query parameter.
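+
+    The `account__name` and `account__home_page` fields flatten the xAPI
+    `account` object, which pairs an account `name` with its `homePage`.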
+    """
+
+    mbox: Optional[str]
+    mbox_sha1sum: Optional[str]
+    openid: Optional[str]
+    account__name: Optional[str]
+    account__home_page: Optional[str]
+
+
 class StatementParameters(BaseModel):
-    """Represents a dictionary of possible LRS query parameters."""
+    """LRS statements query parameters."""
 
     # pylint: disable=too-many-instance-attributes
 
-    statementId: Optional[str] = None  # pylint: disable=invalid-name
-    voidedStatementId: Optional[str] = None  # pylint: disable=invalid-name
-    agent: Optional[str] = None
-    verb: Optional[str] = None
-    activity: Optional[str] = None
-    registration: Optional[UUID] = None
-    related_activities: Optional[bool] = False
-    related_agents: Optional[bool] = False
-    since: Optional[datetime] = None
-    until: Optional[datetime] = None
-    limit: Optional[int] = None
+    statementId: Optional[str]  # pylint: disable=invalid-name
+    voidedStatementId: Optional[str]  # pylint: disable=invalid-name
+    agent: Optional[AgentParameters]
+    verb: Optional[str]
+    activity: Optional[str]
+    registration: Optional[UUID]
+    related_activities: Optional[bool]
+    related_agents: Optional[bool]
+    since: Optional[datetime]
+    until: Optional[datetime]
+    limit: Optional[int]
     format: Optional[Literal["ids", "exact", "canonical"]] = "exact"
-    attachments: Optional[bool] = False
-    ascending: Optional[bool] = False
-    search_after: Optional[str] = None
-    pit_id: Optional[str] = None
+    attachments: Optional[bool]
+    ascending: Optional[bool]
+    search_after: Optional[str]
+    pit_id: Optional[str]
+    authority: Optional[AgentParameters]
 
 
 class BaseLRSBackend(BaseDataBackend):
@@ -48,8 +66,8 @@ class BaseLRSBackend(BaseDataBackend):
 
     @abstractmethod
     def query_statements(self, params: StatementParameters) -> StatementQueryResult:
-        """Returns the statements query payload using xAPI parameters."""
+        """Return the statements query payload using xAPI parameters."""
 
     @abstractmethod
-    def query_statements_by_ids(self, ids: List[str]) -> list:
-        """Returns the list of matching statement IDs from the database."""
+    def query_statements_by_ids(self, ids: List[str]) -> Iterator[dict]:
+        """Yield statements with matching ids from the backend."""
diff --git a/src/ralph/backends/lrs/clickhouse.py b/src/ralph/backends/lrs/clickhouse.py
new file mode 100644
index 000000000..423318b35
--- /dev/null
+++ b/src/ralph/backends/lrs/clickhouse.py
@@ -0,0 +1,177 @@
+"""ClickHouse LRS backend for Ralph."""
+
+import logging
+from typing import Iterator, List
+
+from ralph.backends.data.clickhouse import (
+    ClickHouseDataBackend,
+    ClickHouseDataBackendSettings,
+)
+from ralph.backends.lrs.base import (
+    AgentParameters,
+    BaseLRSBackend,
+    BaseLRSBackendSettings,
+    StatementParameters,
+    StatementQueryResult,
+)
+from ralph.exceptions import BackendException, BackendParameterException
+
+logger = logging.getLogger(__name__)
+
+
+class ClickHouseLRSBackendSettings(
+    BaseLRSBackendSettings, ClickHouseDataBackendSettings
+):
+    """Represent the ClickHouse LRS backend default configuration.
+
+    Attributes:
+        IDS_CHUNK_SIZE (int): The chunk size for querying by ids.
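+
+    All other attributes are inherited from `BaseLRSBackendSettings` and
+    `ClickHouseDataBackendSettings`.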
+ """ + + IDS_CHUNK_SIZE: int = 10000 + + +class ClickHouseLRSBackend(BaseLRSBackend, ClickHouseDataBackend): + """ClickHouse LRS backend implementation.""" + + settings_class = ClickHouseLRSBackendSettings + + def query_statements(self, params: StatementParameters) -> StatementQueryResult: + """Return the statements query payload using xAPI parameters.""" + ch_params = params.dict(exclude_none=True) + where = [] + + if params.statementId: + where.append("event_id = {statementId:UUID}") + + self._add_agent_filters(ch_params, where, params.agent, "actor") + ch_params.pop("agent", None) + + self._add_agent_filters(ch_params, where, params.authority, "authority") + ch_params.pop("authority", None) + + if params.verb: + where.append("event.verb.id = {verb:String}") + + if params.activity: + where.append("event.object.objectType = 'Activity'") + where.append("event.object.id = {activity:String}") + + if params.since: + where.append("emission_time > {since:DateTime64(6)}") + + if params.until: + where.append("emission_time <= {until:DateTime64(6)}") + + if params.search_after: + search_order = ">" if params.ascending else "<" + + where.append( + f"(emission_time {search_order} " + "{search_after:DateTime64(6)}" + " OR " + "(emission_time = {search_after:DateTime64(6)}" + " AND " + f"event_id {search_order} " + "{pit_id:UUID}" + "))" + ) + + sort_order = "ASCENDING" if params.ascending else "DESCENDING" + order_by = f"emission_time {sort_order}, event_id {sort_order}" + + query = { + "select": ["event_id", "emission_time", "event"], + "where": where, + "parameters": ch_params, + "limit": params.limit, + "sort": order_by, + } + try: + clickhouse_response = list( + self.read( + query=query, + target=self.event_table_name, + ignore_errors=True, + ) + ) + except (BackendException, BackendParameterException) as error: + logger.error("Failed to read from ClickHouse") + raise error + + new_search_after = None + new_pit_id = None + + if clickhouse_response: + # Our search after string is a combination of event timestamp and + # event id, so that we can avoid losing events when they have the + # same timestamp, and also avoid sending the same event twice. 
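+            # Together they form a keyset pagination cursor: the next page
+            # resumes strictly after this (emission_time, event_id) pair.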
+ new_search_after = clickhouse_response[-1]["emission_time"].isoformat() + new_pit_id = str(clickhouse_response[-1]["event_id"]) + + return StatementQueryResult( + statements=[document["event"] for document in clickhouse_response], + search_after=new_search_after, + pit_id=new_pit_id, + ) + + def query_statements_by_ids(self, ids: List[str]) -> Iterator[dict]: + """Yield statements with matching ids from the backend.""" + + def chunk_id_list(chunk_size=self.settings.IDS_CHUNK_SIZE): + for i in range(0, len(ids), chunk_size): + yield ids[i : i + chunk_size] + + query = { + "select": "event", + "where": "event_id IN ({ids:Array(String)})", + "parameters": {"ids": ["1"]}, + "column_oriented": True, + } + try: + for chunk_ids in chunk_id_list(): + query["parameters"]["ids"] = chunk_ids + yield from self.read( + query=query, + target=self.event_table_name, + ignore_errors=True, + ) + except (BackendException, BackendParameterException) as error: + msg = "Failed to read from ClickHouse" + logger.error(msg) + raise error + + @staticmethod + def _add_agent_filters( + ch_params: dict, + where: list, + agent_params: AgentParameters, + target_field: str, + ): + """Add filters relative to agents to `where`.""" + if not agent_params: + return + if agent_params.mbox: + ch_params[f"{target_field}__mbox"] = agent_params.mbox + where.append(f"event.{target_field}.mbox = {{{target_field}__mbox:String}}") + elif agent_params.mbox_sha1sum: + ch_params[f"{target_field}__mbox_sha1sum"] = agent_params.mbox_sha1sum + where.append( + f"event.{target_field}.mbox_sha1sum = {{{target_field}__mbox_sha1sum:String}}" # noqa: E501 # pylint: disable=line-too-long + ) + elif agent_params.openid: + ch_params[f"{target_field}__openid"] = agent_params.openid + where.append( + f"event.{target_field}.openid = {{{target_field}__openid:String}}" + ) + elif agent_params.account__name: + ch_params[f"{target_field}__account_name"] = agent_params.account__name + where.append( + f"event.{target_field}.account_name = {{{target_field}__account_name:String}}" # noqa: E501 # pylint: disable=line-too-long + ) + ch_params[ + f"{target_field}__account_homepage" + ] = agent_params.account__home_page + where.append( + f"event.{target_field}.account_homepage = {{{target_field}__account_homepage:String}}" # noqa: E501 # pylint: disable=line-too-long + ) diff --git a/tests/backends/data/test_clickhouse.py b/tests/backends/data/test_clickhouse.py new file mode 100644 index 000000000..c0876ce37 --- /dev/null +++ b/tests/backends/data/test_clickhouse.py @@ -0,0 +1,628 @@ +"""Tests for Ralph clickhouse data backend.""" + +import json +import logging +import uuid +from datetime import datetime, timedelta + +import pytest +from clickhouse_connect.driver.exceptions import ClickHouseError +from clickhouse_connect.driver.httpclient import HttpClient + +from ralph.backends.data.base import BaseOperationType, DataBackendStatus +from ralph.backends.data.clickhouse import ( + ClickHouseDataBackend, + ClickHouseDataBackendSettings, + ClickHouseQuery, +) +from ralph.exceptions import BackendException, BackendParameterException + +from tests.fixtures.backends import ( + CLICKHOUSE_TEST_DATABASE, + CLICKHOUSE_TEST_HOST, + CLICKHOUSE_TEST_PORT, + CLICKHOUSE_TEST_TABLE_NAME, +) + + +def test_backends_data_clickhouse_data_backend_default_instantiation(monkeypatch, fs): + # pylint: disable=invalid-name + """Test the `ClickHouseDataBackend` default instantiation.""" + fs.create_file(".env") + backend_settings_names = [ + "HOST", + "PORT", + "DATABASE", + 
"EVENT_TABLE_NAME", + "USERNAME", + "PASSWORD", + "CLIENT_OPTIONS", + "DEFAULT_CHUNK_SIZE", + "LOCALE_ENCODING", + ] + for name in backend_settings_names: + monkeypatch.delenv(f"RALPH_BACKENDS__DATA__CLICKHOUSE__{name}", raising=False) + + assert ClickHouseDataBackend.name == "clickhouse" + assert ClickHouseDataBackend.query_model == ClickHouseQuery + assert ClickHouseDataBackend.default_operation_type == BaseOperationType.CREATE + assert ClickHouseDataBackend.settings_class == ClickHouseDataBackendSettings + backend = ClickHouseDataBackend() + assert backend.event_table_name == "xapi_events_all" + assert backend.default_chunk_size == 500 + assert backend.locale_encoding == "utf8" + + +def test_backends_data_clickhouse_data_backend_instantiation_with_settings(): + """Test the `ClickHouseDataBackend` instantiation.""" + settings = ClickHouseDataBackendSettings( + HOST=CLICKHOUSE_TEST_HOST, + PORT=CLICKHOUSE_TEST_PORT, + DATABASE=CLICKHOUSE_TEST_DATABASE, + EVENT_TABLE_NAME=CLICKHOUSE_TEST_TABLE_NAME, + USERNAME="default", + PASSWORD="", + CLIENT_OPTIONS={ + "date_time_input_format": "test_format", + "allow_experimental_object_type": 0, + }, + DEFAULT_CHUNK_SIZE=1000, + LOCALE_ENCODING="utf-16", + ) + backend = ClickHouseDataBackend(settings) + + assert isinstance(backend.client, HttpClient) + assert backend.event_table_name == CLICKHOUSE_TEST_TABLE_NAME + assert backend.default_chunk_size == 1000 + assert backend.locale_encoding == "utf-16" + + +def test_backends_data_clickhouse_data_backend_status( + clickhouse, clickhouse_backend, monkeypatch +): + """Test the `ClickHouseDataBackend.status` method.""" + # pylint: disable=unused-argument + + backend = clickhouse_backend() + + assert backend.status() == DataBackendStatus.OK + + def mock_query(*_, **__): + """Mock the ClickHouseClient.query method.""" + raise ClickHouseError("Something is wrong") + + monkeypatch.setattr(backend.client, "query", mock_query) + assert backend.status() == DataBackendStatus.AWAY + + +def test_backends_data_clickhouse_data_backend_read_method_with_raw_output( + clickhouse, clickhouse_backend +): + """Test the `ClickHouseDataBackend.read` method.""" + # pylint: disable=unused-argument, protected-access + # Create records + date_1 = (datetime.now() - timedelta(seconds=3)).isoformat() + date_2 = (datetime.now() - timedelta(seconds=2)).isoformat() + date_3 = (datetime.now() - timedelta(seconds=1)).isoformat() + + statements = [ + {"id": str(uuid.uuid4()), "bool": 1, "timestamp": date_1}, + {"id": str(uuid.uuid4()), "bool": 0, "timestamp": date_2}, + {"id": str(uuid.uuid4()), "bool": 1, "timestamp": date_3}, + ] + + backend = clickhouse_backend() + backend.write(statements) + + results = list(backend.read()) + assert len(results) == 3 + assert results[0]["event"] == statements[0] + assert results[1]["event"] == statements[1] + assert results[2]["event"] == statements[2] + + results = list(backend.read(chunk_size=10)) + assert len(results) == 3 + assert results[0]["event"] == statements[0] + assert results[1]["event"] == statements[1] + assert results[2]["event"] == statements[2] + + results = list(backend.read(raw_output=True)) + assert len(results) == 3 + assert isinstance(results[0], bytes) + assert json.loads(results[0])["event"] == statements[0] + + +# pylint: disable=unused-argument +def test_backends_data_clickhouse_data_backend_read_method_with_a_custom_query( + clickhouse, clickhouse_backend +): + """Test the `ClickHouseDataBackend.read` method with a custom query.""" + date_1 = (datetime.now() - 
timedelta(seconds=3)).isoformat() + date_2 = (datetime.now() - timedelta(seconds=2)).isoformat() + date_3 = (datetime.now() - timedelta(seconds=1)).isoformat() + + statements = [ + {"id": str(uuid.uuid4()), "bool": 1, "timestamp": date_1}, + {"id": str(uuid.uuid4()), "bool": 0, "timestamp": date_2}, + {"id": str(uuid.uuid4()), "bool": 1, "timestamp": date_3}, + ] + + backend = clickhouse_backend() + documents = list( + backend._to_insert_tuples(statements) # pylint: disable=protected-access + ) + + backend.write(statements) + + # Test filtering + query = ClickHouseQuery(where="event.bool = 1") + results = list(backend.read(query=query, chunk_size=None)) + assert len(results) == 2 + assert results[0]["event"] == statements[0] + assert results[1]["event"] == statements[2] + + # Test select fields + query = ClickHouseQuery(select=["event_id", "event.bool"]) + results = list(backend.read(query=query)) + assert len(results) == 3 + assert len(results[0]) == 2 + assert results[0]["event_id"] == documents[0][0] + assert results[0]["event.bool"] == statements[0]["bool"] + assert results[1]["event_id"] == documents[1][0] + assert results[1]["event.bool"] == statements[1]["bool"] + assert results[2]["event_id"] == documents[2][0] + assert results[2]["event.bool"] == statements[2]["bool"] + + # Test both + query = ClickHouseQuery(where="event.bool = 0", select=["event_id", "event.bool"]) + results = list(backend.read(query=query)) + assert len(results) == 1 + assert len(results[0]) == 2 + assert results[0]["event_id"] == documents[1][0] + assert results[0]["event.bool"] == statements[1]["bool"] + + # Test sort + query = ClickHouseQuery(sort="emission_time DESCENDING") + results = list(backend.read(query=query)) + assert len(results) == 3 + assert results[0]["event"] == statements[2] + assert results[1]["event"] == statements[1] + assert results[2]["event"] == statements[0] + + # Test limit + query = ClickHouseQuery(limit=1) + results = list(backend.read(query=query)) + assert len(results) == 1 + assert results[0]["event"] == statements[0] + + # Test parameters + query = ClickHouseQuery( + where="event.bool = {event_bool:Bool}", + parameters={"event_bool": 0, "format": "exact"}, + ) + results = list(backend.read(query=query)) + assert len(results) == 1 + assert results[0]["event"] == statements[1] + + +def test_backends_data_clickhouse_data_backend_read_method_with_failures( + monkeypatch, caplog, clickhouse, clickhouse_backend +): # pylint: disable=unused-argument + """Test the `ClickHouseDataBackend.read` method with failures.""" + backend = clickhouse_backend() + + statement = {"id": str(uuid.uuid4()), "timestamp": str(datetime.utcnow())} + document = {"event": statement} + backend.write([statement]) + + # JSON encoding error + def mock_read_raw(*args, **kwargs): + """Mock the `ClickHouseDataBackend._read_raw` method.""" + raise TypeError("Error") + + monkeypatch.setattr(backend, "_read_raw", mock_read_raw) + + msg = f"Failed to encode document {document}: Error" + + # Not ignoring errors + with caplog.at_level(logging.ERROR): + with pytest.raises( + BackendException, + match=msg, + ): + list(backend.read(raw_output=True, ignore_errors=False)) + + assert ( + "ralph.backends.data.clickhouse", + logging.ERROR, + msg, + ) in caplog.record_tuples + + caplog.clear() + + # Ignoring errors + with caplog.at_level(logging.WARNING): + list(backend.read(raw_output=True, ignore_errors=True)) + + assert ( + "ralph.backends.data.clickhouse", + logging.WARNING, + msg, + ) in caplog.record_tuples + + assert ( + 
"ralph.backends.data.clickhouse", + logging.ERROR, + msg, + ) not in caplog.record_tuples + + # ClickHouse error during query should raise even when ignoring errors + def mock_query(*_, **__): + """Mock the ClickHouseClient.query method.""" + raise ClickHouseError("Something is wrong") + + monkeypatch.setattr(backend.client, "query", mock_query) + + msg = "Failed to read documents: Something is wrong" + with caplog.at_level(logging.ERROR): + with pytest.raises( + BackendException, + match=msg, + ): + list(backend.read(ignore_errors=True)) + + assert ( + "ralph.backends.data.clickhouse", + logging.ERROR, + msg, + ) in caplog.record_tuples + + +def test_backends_data_clickhouse_data_backend_list_method( + clickhouse, clickhouse_backend +): + """Test the `ClickHouseDataBackend.list` method.""" + + backend = clickhouse_backend() + + assert list(backend.list(details=True)) == [{"name": CLICKHOUSE_TEST_TABLE_NAME}] + assert list(backend.list(details=False)) == [CLICKHOUSE_TEST_TABLE_NAME] + + +def test_backends_data_clickhouse_data_backend_list_method_with_failure( + monkeypatch, caplog, clickhouse, clickhouse_backend +): + """Test the `ClickHouseDataBackend.list` method with a failure.""" + # pylint: disable=unused-argument + backend = clickhouse_backend() + + def mock_query(*_, **__): + """Mock the ClickHouseClient.query method.""" + raise ClickHouseError("Something is wrong") + + monkeypatch.setattr(backend.client, "query", mock_query) + + with caplog.at_level(logging.ERROR): + msg = "Failed to read tables: Something is wrong" + with pytest.raises( + BackendException, + match=msg, + ): + list(backend.list()) + + assert ( + "ralph.backends.data.clickhouse", + logging.ERROR, + msg, + ) in caplog.record_tuples + + +def test_backends_data_clickhouse_data_backend_write_method_with_invalid_timestamp( + clickhouse, clickhouse_backend +): + """Test the `ClickHouseDataBackend.write` method with an invalid timestamp.""" + # pylint: disable=unused-argument + valid_timestamp = (datetime.now() - timedelta(seconds=3)).isoformat() + invalid_timestamp = "This is not a valid timestamp!" + invalid_statement = { + "id": str(uuid.uuid4()), + "bool": 0, + "timestamp": invalid_timestamp, + } + + statements = [ + {"id": str(uuid.uuid4()), "bool": 1, "timestamp": valid_timestamp}, + invalid_statement, + ] + + backend = clickhouse_backend() + + msg = f"Statement {invalid_statement} has an invalid 'id' or 'timestamp' field" + with pytest.raises( + BackendException, + match=msg, + ): + backend.write(statements, ignore_errors=False) + + +def test_backends_data_clickhouse_data_backend_write_method_no_timestamp( + caplog, clickhouse_backend +): + """Test the `ClickHouseDataBackend.write` method when a statement has no + timestamp. 
+    """
+    statement = {"id": str(uuid.uuid4())}
+
+    backend = clickhouse_backend()
+
+    msg = f"Statement {statement} has an invalid 'id' or 'timestamp' field"
+
+    # Without ignoring errors
+    with caplog.at_level(logging.ERROR):
+        with pytest.raises(
+            BackendException,
+            match=msg,
+        ):
+            backend.write([statement], ignore_errors=False)
+
+    assert (
+        "ralph.backends.data.clickhouse",
+        logging.ERROR,
+        f"Statement {statement} has an invalid 'id' or 'timestamp' field",
+    ) in caplog.record_tuples
+
+    caplog.clear()
+
+    # Ignoring errors
+    with caplog.at_level(logging.WARNING):
+        backend.write([statement], ignore_errors=True)
+
+    assert (
+        "ralph.backends.data.clickhouse",
+        logging.WARNING,
+        f"Statement {statement} has an invalid 'id' or 'timestamp' field",
+    ) in caplog.record_tuples
+
+    assert (
+        "ralph.backends.data.clickhouse",
+        logging.ERROR,
+        f"Statement {statement} has an invalid 'id' or 'timestamp' field",
+    ) not in caplog.record_tuples
+
+
+def test_backends_data_clickhouse_data_backend_write_method_with_duplicated_key(
+    clickhouse, clickhouse_backend
+):
+    """Test the `ClickHouseDataBackend.write` method with duplicated key
+    conflict.
+    """
+    # pylint: disable=unused-argument
+    backend = clickhouse_backend()
+
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    dupe_id = str(uuid.uuid4())
+    statements = [
+        {"id": str(uuid.uuid4()), **timestamp},
+        {"id": dupe_id, **timestamp},
+        {"id": dupe_id, **timestamp},
+    ]
+
+    # No way of knowing how many writes succeeded when there is an error
+    assert backend.write(statements, ignore_errors=True) == 0
+
+    with pytest.raises(BackendException, match="Duplicate IDs found in batch"):
+        backend.write(statements, ignore_errors=False)
+
+
+def test_backends_data_clickhouse_data_backend_write_method_chunks_on_error(
+    clickhouse, clickhouse_backend
+):
+    """Test the `ClickHouseDataBackend.write` method, given chunks containing
+    duplicated ids and `ignore_errors` set to `True`, should skip the failing
+    chunks instead of raising an error.
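+
+    Since there is no way to know how many rows of a failed chunk were
+    actually inserted, the returned count assumes zero for that chunk.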
+    """
+    # pylint: disable=unused-argument
+    backend = clickhouse_backend()
+
+    # Duplicated statement ids lead to a duplicate-id error while trying
+    # to bulk import this batch
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    dupe_id = str(uuid.uuid4())
+    statements = [
+        {"id": str(uuid.uuid4()), **timestamp},
+        {"id": dupe_id, **timestamp},
+        {"id": str(uuid.uuid4()), **timestamp},
+        {"id": str(uuid.uuid4()), **timestamp},
+        {"id": dupe_id, **timestamp},
+    ]
+    assert backend.write(statements, ignore_errors=True) == 0
+
+
+def test_backends_data_clickhouse_data_backend_write_method(
+    clickhouse, clickhouse_backend
+):
+    """Test the `ClickHouseDataBackend.write` method."""
+
+    sql = f"""SELECT count(*) FROM {CLICKHOUSE_TEST_TABLE_NAME}"""
+    result = clickhouse.query(sql).result_set
+    assert result[0][0] == 0
+
+    native_statements = [
+        {"id": uuid.uuid4(), "timestamp": datetime.utcnow() - timedelta(seconds=1)},
+        {"id": uuid.uuid4(), "timestamp": datetime.utcnow()},
+    ]
+    statements = [
+        {"id": str(x["id"]), "timestamp": x["timestamp"].isoformat()}
+        for x in native_statements
+    ]
+    backend = clickhouse_backend()
+    count = backend.write(statements, target=CLICKHOUSE_TEST_TABLE_NAME)
+
+    assert count == 2
+
+    result = clickhouse.query(sql).result_set
+    assert result[0][0] == 2
+
+    sql = f"""SELECT * FROM {CLICKHOUSE_TEST_TABLE_NAME} ORDER BY event.timestamp"""
+    result = list(clickhouse.query(sql).named_results())
+
+    assert result[0]["event_id"] == native_statements[0]["id"]
+    assert result[0]["emission_time"] == native_statements[0]["timestamp"]
+    assert result[0]["event"] == statements[0]
+
+    assert result[1]["event_id"] == native_statements[1]["id"]
+    assert result[1]["emission_time"] == native_statements[1]["timestamp"]
+    assert result[1]["event"] == statements[1]
+
+
+def test_backends_data_clickhouse_data_backend_write_method_bytes(
+    clickhouse, clickhouse_backend
+):
+    """Test the `ClickHouseDataBackend.write` method with bytes data."""
+
+    sql = f"""SELECT count(*) FROM {CLICKHOUSE_TEST_TABLE_NAME}"""
+    result = clickhouse.query(sql).result_set
+    assert result[0][0] == 0
+
+    native_statements = [
+        {"id": uuid.uuid4(), "timestamp": datetime.utcnow() - timedelta(seconds=1)},
+        {"id": uuid.uuid4(), "timestamp": datetime.utcnow()},
+    ]
+    statements = [
+        {"id": str(x["id"]), "timestamp": x["timestamp"].isoformat()}
+        for x in native_statements
+    ]
+
+    backend = clickhouse_backend()
+    byte_data = []
+    for item in statements:
+        json_str = json.dumps(item, separators=(",", ":"), ensure_ascii=False)
+        byte_data.append(json_str.encode("utf-8"))
+    count = backend.write(byte_data, target=CLICKHOUSE_TEST_TABLE_NAME)
+
+    assert count == 2
+
+    result = clickhouse.query(sql).result_set
+    assert result[0][0] == 2
+
+    sql = f"""SELECT * FROM {CLICKHOUSE_TEST_TABLE_NAME} ORDER BY event.timestamp"""
+    result = list(clickhouse.query(sql).named_results())
+
+    assert result[0]["event_id"] == native_statements[0]["id"]
+    assert result[0]["emission_time"] == native_statements[0]["timestamp"]
+    assert result[0]["event"] == statements[0]
+
+    assert result[1]["event_id"] == native_statements[1]["id"]
+    assert result[1]["emission_time"] == native_statements[1]["timestamp"]
+    assert result[1]["event"] == statements[1]
+
+
+def test_backends_data_clickhouse_data_backend_write_method_bytes_failed(
+    clickhouse, clickhouse_backend
+):
+    """Test the `ClickHouseDataBackend.write` method with bytes data that fail
+    to parse as JSON.
+    """
+
+    sql = f"""SELECT count(*) FROM {CLICKHOUSE_TEST_TABLE_NAME}"""
+    result = clickhouse.query(sql).result_set
+    assert result[0][0] == 0
+
+    backend = clickhouse_backend()
+
+    byte_data = []
+    json_str = "failed_json_str"
+    byte_data.append(json_str.encode("utf-8"))
+
+    count = 0
+    with pytest.raises(json.JSONDecodeError):
+        count = backend.write(byte_data)
+
+    assert count == 0
+
+    result = clickhouse.query(sql).result_set
+    assert result[0][0] == 0
+
+    count = backend.write(byte_data, ignore_errors=True)
+    assert count == 0
+
+    result = clickhouse.query(sql).result_set
+    assert result[0][0] == 0
+
+
+def test_backends_data_clickhouse_data_backend_write_method_empty(
+    clickhouse, clickhouse_backend
+):
+    """Test the `ClickHouseDataBackend.write` method with an empty data
+    iterator.
+    """
+
+    sql = f"""SELECT count(*) FROM {CLICKHOUSE_TEST_TABLE_NAME}"""
+    result = clickhouse.query(sql).result_set
+    assert result[0][0] == 0
+
+    backend = clickhouse_backend()
+    count = backend.write([], target=CLICKHOUSE_TEST_TABLE_NAME)
+
+    assert count == 0
+
+    result = clickhouse.query(sql).result_set
+    assert result[0][0] == 0
+
+
+def test_backends_data_clickhouse_data_backend_write_method_wrong_operation_type(
+    clickhouse, clickhouse_backend
+):
+    """Test the `ClickHouseDataBackend.write` method with an invalid
+    `operation_type`.
+    """
+
+    sql = f"""SELECT count(*) FROM {CLICKHOUSE_TEST_TABLE_NAME}"""
+    result = clickhouse.query(sql).result_set
+    assert result[0][0] == 0
+
+    native_statements = [
+        {"id": uuid.uuid4(), "timestamp": datetime.utcnow() - timedelta(seconds=1)},
+        {"id": uuid.uuid4(), "timestamp": datetime.utcnow()},
+    ]
+    statements = [
+        {"id": str(x["id"]), "timestamp": x["timestamp"].isoformat()}
+        for x in native_statements
+    ]
+
+    backend = clickhouse_backend()
+    with pytest.raises(
+        BackendParameterException,
+        match=f"{BaseOperationType.APPEND.name} operation_type is not allowed.",
+    ):
+        backend.write(data=statements, operation_type=BaseOperationType.APPEND)
+
+
+def test_backends_data_clickhouse_data_backend_write_method_with_custom_chunk_size(
+    clickhouse, clickhouse_backend
+):
+    """Test the `ClickHouseDataBackend.write` method with a custom chunk_size."""
+
+    sql = f"""SELECT count(*) FROM {CLICKHOUSE_TEST_TABLE_NAME}"""
+    result = clickhouse.query(sql).result_set
+    assert result[0][0] == 0
+
+    native_statements = [
+        {"id": uuid.uuid4(), "timestamp": datetime.utcnow() - timedelta(seconds=1)},
+        {"id": uuid.uuid4(), "timestamp": datetime.utcnow()},
+    ]
+    statements = [
+        {"id": str(x["id"]), "timestamp": x["timestamp"].isoformat()}
+        for x in native_statements
+    ]
+
+    backend = clickhouse_backend()
+    count = backend.write(statements, chunk_size=1)
+    assert count == 2
+
+    result = clickhouse.query(sql).result_set
+    assert result[0][0] == 2
+
+    sql = f"""SELECT * FROM {CLICKHOUSE_TEST_TABLE_NAME} ORDER BY event.timestamp"""
+    result = list(clickhouse.query(sql).named_results())
+
+    assert result[0]["event_id"] == native_statements[0]["id"]
+    assert result[0]["emission_time"] == native_statements[0]["timestamp"]
+    assert result[0]["event"] == statements[0]
+
+    assert result[1]["event_id"] == native_statements[1]["id"]
+    assert result[1]["emission_time"] == native_statements[1]["timestamp"]
+    assert result[1]["event"] == statements[1]
diff --git a/tests/backends/lrs/__init__.py b/tests/backends/lrs/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/backends/lrs/test_clickhouse.py b/tests/backends/lrs/test_clickhouse.py
new file mode 100644
index 000000000..d5bd79e9f
--- /dev/null
+++ b/tests/backends/lrs/test_clickhouse.py
@@ -0,0 +1,368 @@
+"""Tests for Ralph clickhouse LRS backend."""
+ +import logging +import uuid +from datetime import datetime, timezone + +import pytest +from clickhouse_connect.driver.exceptions import ClickHouseError + +from ralph.backends.lrs.base import StatementParameters +from ralph.exceptions import BackendException + + +@pytest.mark.parametrize( + "params,expected_params", + [ + # 0. Default query. + ( + {}, + { + "where": [], + "params": {"format": "exact"}, + "limit": None, + "sort": "emission_time DESCENDING, event_id DESCENDING", + }, + ), + # 1. Query by statementId. + ( + {"statementId": "test_id"}, + { + "where": ["event_id = {statementId:UUID}"], + "params": {"statementId": "test_id", "format": "exact"}, + "limit": None, + "sort": "emission_time DESCENDING, event_id DESCENDING", + }, + ), + # # 2. Query by statementId and agent with mbox IFI. + ( + {"statementId": "test_id", "agent": {"mbox": "mailto:foo@bar.baz"}}, + { + "where": [ + "event_id = {statementId:UUID}", + "event.actor.mbox = {actor__mbox:String}", + ], + "params": { + "statementId": "test_id", + "actor__mbox": "mailto:foo@bar.baz", + "format": "exact", + }, + "limit": None, + "sort": "emission_time DESCENDING, event_id DESCENDING", + }, + ), + # # 3. Query by statementId and agent with mbox_sha1sum IFI. + ( + { + "statementId": "test_id", + "agent": {"mbox_sha1sum": "a7a5b7462b862c8c8767d43d43e865ffff754a64"}, + }, + { + "where": [ + "event_id = {statementId:UUID}", + "event.actor.mbox_sha1sum = {actor__mbox_sha1sum:String}", + ], + "params": { + "statementId": "test_id", + "actor__mbox_sha1sum": "a7a5b7462b862c8c8767d43d43e865ffff754a64", + "format": "exact", + }, + "limit": None, + "sort": "emission_time DESCENDING, event_id DESCENDING", + }, + ), + # 4. Query by statementId and agent with openid IFI. + ( + { + "statementId": "test_id", + "agent": {"openid": "http://toby.openid.example.org/"}, + }, + { + "where": [ + "event_id = {statementId:UUID}", + "event.actor.openid = {actor__openid:String}", + ], + "params": { + "statementId": "test_id", + "actor__openid": "http://toby.openid.example.org/", + "format": "exact", + }, + "limit": None, + "sort": "emission_time DESCENDING, event_id DESCENDING", + }, + ), + # 5. Query by statementId and agent with account IFI. + ( + { + "statementId": "test_id", + "agent": { + "account__home_page": "http://www.example.com", + "account__name": "13936749", + }, + "ascending": True, + }, + { + "where": [ + "event_id = {statementId:UUID}", + "event.actor.account_name = {actor__account_name:String}", + "event.actor.account_homepage = {actor__account_homepage:String}", + ], + "params": { + "statementId": "test_id", + "actor__account_name": "13936749", + "actor__account_homepage": "http://www.example.com", + "ascending": True, + "format": "exact", + }, + "limit": None, + "sort": "emission_time ASCENDING, event_id ASCENDING", + }, + ), + # 6. Query by verb and activity with limit. + ( + { + "verb": "http://adlnet.gov/expapi/verbs/attended", + "activity": "http://www.example.com/meetings/34534", + "limit": 100, + }, + { + "where": [ + "event.verb.id = {verb:String}", + "event.object.objectType = 'Activity'", + "event.object.id = {activity:String}", + ], + "params": { + "verb": "http://adlnet.gov/expapi/verbs/attended", + "activity": "http://www.example.com/meetings/34534", + "limit": 100, + "format": "exact", + }, + "limit": 100, + "sort": "emission_time DESCENDING, event_id DESCENDING", + }, + ), + # 7. Query by timerange (with since/until). 
+        (
+            {
+                "since": "2021-06-24T00:00:20.194929+00:00",
+                "until": "2023-06-24T00:00:20.194929+00:00",
+            },
+            {
+                "where": [
+                    "emission_time > {since:DateTime64(6)}",
+                    "emission_time <= {until:DateTime64(6)}",
+                ],
+                "params": {
+                    "since": datetime(
+                        2021, 6, 24, 0, 0, 20, 194929, tzinfo=timezone.utc
+                    ),
+                    "until": datetime(
+                        2023, 6, 24, 0, 0, 20, 194929, tzinfo=timezone.utc
+                    ),
+                    "format": "exact",
+                },
+                "limit": None,
+                "sort": "emission_time DESCENDING, event_id DESCENDING",
+            },
+        ),
+        # 8. Query with pagination and pit_id.
+        (
+            {"search_after": "1686557542970|0", "pit_id": "46ToAwMDaWR5BXV1a"},
+            {
+                "where": [
+                    (
+                        "(emission_time < {search_after:DateTime64(6)}"
+                        " OR "
+                        "(emission_time = {search_after:DateTime64(6)}"
+                        " AND "
+                        "event_id < {pit_id:UUID}))"
+                    ),
+                ],
+                "params": {
+                    "search_after": "1686557542970|0",
+                    "pit_id": "46ToAwMDaWR5BXV1a",
+                    "format": "exact",
+                },
+                "limit": None,
+                "sort": "emission_time DESCENDING, event_id DESCENDING",
+            },
+        ),
+    ],
+)
+def test_backends_lrs_clickhouse_lrs_backend_query_statements_query(
+    params,
+    expected_params,
+    monkeypatch,
+    clickhouse,
+    clickhouse_lrs_backend,
+):
+    """Test the `ClickHouseLRSBackend.query_statements` method, given valid
+    statement parameters, should produce the expected ClickHouse query.
+    """
+    # pylint: disable=unused-argument
+
+    def mock_read(query, target, ignore_errors):
+        """Mock the `ClickHouseDataBackend.read` method."""
+
+        assert query == {
+            "select": ["event_id", "emission_time", "event"],
+            "where": expected_params["where"],
+            "parameters": expected_params["params"],
+            "limit": expected_params["limit"],
+            "sort": expected_params["sort"],
+        }
+
+        return {}
+
+    backend = clickhouse_lrs_backend()
+    monkeypatch.setattr(backend, "read", mock_read)
+
+    backend.query_statements(StatementParameters(**params))
+
+
+def test_backends_lrs_clickhouse_lrs_backend_query_statements(
+    clickhouse, clickhouse_lrs_backend
+):
+    """Test the `ClickHouseLRSBackend.query_statements` method, given a query,
+    should return matching statements.
+    """
+    # pylint: disable=unused-argument, invalid-name
+    backend = clickhouse_lrs_backend()

+    # Insert documents
+    date_str = "09-19-2022"
+    datetime_object = datetime.strptime(date_str, "%m-%d-%Y")
+    test_id = str(uuid.uuid4())
+    statements = [
+        {
+            "id": test_id,
+            "timestamp": datetime_object.isoformat(),
+            "actor": {"account": {"name": "test_name"}},
+            "verb": {"id": "verb_id"},
+            "object": {"id": "http://example.com", "objectType": "Activity"},
+        },
+    ]
+
+    success = backend.write(statements, chunk_size=1)
+    assert success == 1
+
+    # Check the expected search query results.
+    result = backend.query_statements(
+        StatementParameters(statementId=test_id, limit=10)
+    )
+    assert result.statements == statements
+
+
+def test_backends_lrs_clickhouse_lrs_backend__find(clickhouse, clickhouse_lrs_backend):
+    """Test the `ClickHouseLRSBackend._find` method, given a query,
+    should return matching statements.
+    """
+    # pylint: disable=unused-argument, invalid-name
+    backend = clickhouse_lrs_backend()
+
+    # Insert documents
+    date_str = "09-19-2022"
+    datetime_object = datetime.strptime(date_str, "%m-%d-%Y")
+    statements = [
+        {
+            "id": str(uuid.uuid4()),
+            "timestamp": datetime_object.isoformat(),
+            "actor": {"account": {"name": "test_name"}},
+            "verb": {"id": "verb_id"},
+            "object": {"id": "http://example.com", "objectType": "Activity"},
+        },
+    ]
+
+    success = backend.write(statements, chunk_size=1)
+    assert success == 1
+
+    # Check the expected search query results.
+    result = backend.query_statements(StatementParameters())
+    assert result.statements == statements
+
+
+def test_backends_lrs_clickhouse_lrs_backend_query_statements_by_ids(
+    clickhouse, clickhouse_lrs_backend
+):
+    """Test the `ClickHouseLRSBackend.query_statements_by_ids` method, given
+    a list of ids, should return matching statements.
+    """
+    # pylint: disable=unused-argument
+    backend = clickhouse_lrs_backend()
+
+    # Insert documents
+    date_str = "09-19-2022"
+    datetime_object = datetime.strptime(date_str, "%m-%d-%Y")
+    test_id = str(uuid.uuid4())
+    statements = [
+        {
+            "id": test_id,
+            "timestamp": datetime_object.isoformat(),
+            "actor": {"account": {"name": "test_name"}},
+            "verb": {"id": "verb_id"},
+            "object": {"id": "http://example.com", "objectType": "Activity"},
+        },
+    ]
+
+    count = backend.write(statements, chunk_size=1)
+    assert count == 1
+
+    # Check the expected search query results.
+    result = list(backend.query_statements_by_ids([test_id]))
+    assert result[0]["event"] == statements[0]
+
+
+def test_backends_lrs_clickhouse_lrs_backend_query_statements_client_failure(
+    clickhouse, clickhouse_lrs_backend, monkeypatch, caplog
+):
+    """Test the `ClickHouseLRSBackend.query_statements`, given a client query
+    failure, should raise a `BackendException` and log the error.
+    """
+    # pylint: disable=invalid-name,unused-argument
+
+    def mock_query(*args, **kwargs):
+        """Mock the clickhouse_connect client `query` method."""
+        raise ClickHouseError("Query error")
+
+    backend = clickhouse_lrs_backend()
+    monkeypatch.setattr(backend.client, "query", mock_query)
+
+    caplog.set_level(logging.ERROR)
+
+    msg = "Failed to read documents: Query error"
+    with pytest.raises(BackendException, match=msg):
+        next(backend.query_statements(StatementParameters()))
+
+    assert (
+        "ralph.backends.lrs.clickhouse",
+        logging.ERROR,
+        "Failed to read from ClickHouse",
+    ) in caplog.record_tuples
+
+
+def test_backends_lrs_clickhouse_lrs_backend_query_statements_by_ids_client_failure(
+    clickhouse, clickhouse_lrs_backend, monkeypatch, caplog
+):
+    """Test the `ClickHouseLRSBackend.query_statements_by_ids`, given a client
+    query failure, should raise a `BackendException` and log the error.
+ """ + # pylint: disable=invalid-name,unused-argument + + def mock_query(*args, **kwargs): + """Mock the clickhouse_connect.client.search method.""" + raise ClickHouseError("Query error") + + backend = clickhouse_lrs_backend() + monkeypatch.setattr(backend.client, "query", mock_query) + + caplog.set_level(logging.ERROR) + + msg = "Failed to read documents: Query error" + with pytest.raises(BackendException, match=msg): + next(backend.query_statements_by_ids(["test_id"])) + + assert ( + "ralph.backends.lrs.clickhouse", + logging.ERROR, + "Failed to read from ClickHouse", + ) in caplog.record_tuples diff --git a/tests/conftest.py b/tests/conftest.py index f73e865cd..b6ca603fb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,8 @@ from .fixtures.backends import ( # noqa: F401 anyio_backend, clickhouse, + clickhouse_backend, + clickhouse_lrs_backend, es, es_data_stream, es_forwarding, diff --git a/tests/fixtures/backends.py b/tests/fixtures/backends.py index 99e445316..88fd315fa 100644 --- a/tests/fixtures/backends.py +++ b/tests/fixtures/backends.py @@ -21,6 +21,7 @@ from pymongo import MongoClient from pymongo.errors import CollectionInvalid +from ralph.backends.data.clickhouse import ClickHouseDataBackend from ralph.backends.data.fs import FSDataBackend, FSDataBackendSettings from ralph.backends.data.ldp import LDPDataBackend from ralph.backends.data.s3 import S3DataBackend, S3DataBackendSettings @@ -28,6 +29,7 @@ from ralph.backends.database.clickhouse import ClickHouseDatabase from ralph.backends.database.es import ESDatabase from ralph.backends.database.mongo import MongoDatabase +from ralph.backends.lrs.clickhouse import ClickHouseLRSBackend from ralph.backends.storage.s3 import S3Storage from ralph.backends.storage.swift import SwiftStorage from ralph.conf import ClickhouseClientOptions, Settings, core_settings @@ -339,6 +341,59 @@ def get_ldp_data_backend(service_name: str = "foo", stream_id: str = "bar"): return get_ldp_data_backend +@pytest.fixture +def clickhouse_backend(): + """Return the `get_clickhouse_data_backend` function.""" + # pylint: disable=invalid-name,redefined-outer-name + + def get_clickhouse_data_backend(): + """Return an instance of ClickHouseDataBackend.""" + settings = ClickHouseDataBackend.settings_class( + HOST=CLICKHOUSE_TEST_HOST, + PORT=CLICKHOUSE_TEST_PORT, + DATABASE=CLICKHOUSE_TEST_DATABASE, + EVENT_TABLE_NAME=CLICKHOUSE_TEST_TABLE_NAME, + USERNAME="default", + PASSWORD="", + CLIENT_OPTIONS={ + "date_time_input_format": "best_effort", + "allow_experimental_object_type": 1, + }, + DEFAULT_CHUNK_SIZE=500, + LOCALE_ENCODING="utf8", + ) + return ClickHouseDataBackend(settings) + + return get_clickhouse_data_backend + + +@pytest.fixture +def clickhouse_lrs_backend(): + """Return the `get_clickhouse_lrs_backend` function.""" + # pylint: disable=invalid-name,redefined-outer-name + + def get_clickhouse_lrs_backend(): + """Return an instance of ClickHouseLRSBackend.""" + settings = ClickHouseLRSBackend.settings_class( + HOST=CLICKHOUSE_TEST_HOST, + PORT=CLICKHOUSE_TEST_PORT, + DATABASE=CLICKHOUSE_TEST_DATABASE, + EVENT_TABLE_NAME=CLICKHOUSE_TEST_TABLE_NAME, + USERNAME="default", + PASSWORD="", + CLIENT_OPTIONS={ + "date_time_input_format": "best_effort", + "allow_experimental_object_type": 1, + }, + DEFAULT_CHUNK_SIZE=500, + LOCALE_ENCODING="utf8", + IDS_CHUNK_SIZE=10000, + ) + return ClickHouseLRSBackend(settings) + + return get_clickhouse_lrs_backend + + @pytest.fixture def swift(): """Return get_swift_storage function.""" From 
68cf8527619ef46440c5b78c9b7e6a882d3d10c2 Mon Sep 17 00:00:00 2001
From: SergioSim 
Date: Fri, 7 Jul 2023 18:19:54 +0200
Subject: [PATCH 12/65] =?UTF-8?q?=E2=9C=85(tests)=20use=20requests=5Fmock?=
 =?UTF-8?q?=20fixture=20in=20LDPDataBackend=20tests?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

We want to simplify our tests that mock the `requests` package.
Therefore we choose to use the `requests-mock` library.
---
 setup.cfg                       |  1 +
 tests/backends/data/test_ldp.py | 99 +++++----------------------------
 2 files changed, 16 insertions(+), 84 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index 500904095..80c7eb298 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -89,6 +89,7 @@ dev =
     pytest-asyncio==0.21.1
     pytest-cov==4.1.0
     pytest-httpx<0.23.0 # pin as Python 3.7 and 3.8 is no longer supported from release 0.23.0
+    requests-mock==1.11.0
     responses<0.23.2 # pin until boto3 supports urllib3>=2
 ci =
     twine==4.0.2
diff --git a/tests/backends/data/test_ldp.py b/tests/backends/data/test_ldp.py
index f1ef187a2..20740980c 100644
--- a/tests/backends/data/test_ldp.py
+++ b/tests/backends/data/test_ldp.py
@@ -6,12 +6,12 @@ import os.path
 from collections.abc import Iterable
 from operator import itemgetter
-from pathlib import Path
 from xmlrpc.client import gzip_decode
 
 import ovh
 import pytest
 import requests
+import requests_mock
 
 from ralph.backends.data.base import BaseOperationType, BaseQuery, DataBackendStatus
 from ralph.backends.data.ldp import LDPDataBackend
@@ -435,41 +435,19 @@ def test_backends_data_ldp_data_backend_read_method_without_raw_ouput(
     log a warning message.
     """
 
-    class MockResponse:
-        """Mock the requests response."""
-
-        def __enter__(self):
-            return self
-
-        def __exit__(self, *args):
-            pass
-
-        def raise_for_status(self):
-            """Ignored."""
-
-        def iter_content(self, chunk_size):
-            """Fake content file iteration."""
-            # pylint: disable=no-self-use,unused-argument
-            yield
-
-    def mock_requests_get(url, stream=True, timeout=None):
-        """Mock the request get method."""
-        # pylint: disable=unused-argument
-        return MockResponse()
-
     def mock_get(url):
         """Mock the OVH client get request."""
         # pylint: disable=unused-argument
         return {"filename": "archive_name", "size": 10}
 
     backend = ldp_backend()
-
-    monkeypatch.setattr(requests, "get", mock_requests_get)
-    monkeypatch.setattr(backend, "_url", lambda *_: "/")
+    monkeypatch.setattr(backend, "_url", lambda *_: "http://example.com")
     monkeypatch.setattr(backend.client, "get", mock_get)
 
     with caplog.at_level(logging.WARNING):
-        list(backend.read(query="archiveID", raw_output=False))
+        with requests_mock.Mocker() as request_mocker:
+            request_mocker.get("http://example.com")
+            assert not list(backend.read(query="archiveID", raw_output=False))
 
     assert (
         "ralph.backends.data.ldp",
@@ -481,32 +459,10 @@ def test_backends_data_ldp_data_backend_read_method_without_ignore_errors(
     ldp_backend, caplog, monkeypatch
 ):
-    """Test the `LDPDataBackend.read method, given `ignore_errors` set to `False`,
+    """Test the `LDPDataBackend.read` method, given `ignore_errors` set to `False`,
     should log a warning message.
""" - class MockResponse: - """Mock the requests response.""" - - def __enter__(self): - return self - - def __exit__(self, *args): - pass - - def raise_for_status(self): - """Ignored.""" - - def iter_content(self, chunk_size): - """Fake content file iteration.""" - # pylint: disable=no-self-use,unused-argument - yield - - def mock_requests_get(url, stream=True, timeout=None): - """Mock the request get method.""" - # pylint: disable=unused-argument - return MockResponse() - def mock_get(url): """Mock the OVH client get request.""" # pylint: disable=unused-argument @@ -514,13 +470,14 @@ def mock_get(url): backend = ldp_backend() - monkeypatch.setattr(requests, "get", mock_requests_get) backend = ldp_backend() - monkeypatch.setattr(backend, "_url", lambda *_: "/") + monkeypatch.setattr(backend, "_url", lambda *_: "http://example.com") monkeypatch.setattr(backend.client, "get", mock_get) with caplog.at_level(logging.WARNING): - list(backend.read(query="archiveID", ignore_errors=False)) + with requests_mock.Mocker() as request_mocker: + request_mocker.get("http://example.com") + assert not list(backend.read(query="archiveID", ignore_errors=False)) assert ( "ralph.backends.data.ldp", @@ -603,10 +560,8 @@ def test_backends_data_ldp_data_backend_read_method_with_query( # pylint: disable=invalid-name # Create fake archive to stream. - archive_path = Path("/tmp/2020-06-16.gz") archive_content = {"foo": "bar"} - with gzip.open(archive_path, "wb") as archive_file: - archive_file.write(bytes(json.dumps(archive_content), encoding="utf-8")) + archive = gzip.compress(bytes(json.dumps(archive_content), encoding="utf-8")) def mock_ovh_post(url): """Mock the OVH Client post request.""" @@ -639,32 +594,6 @@ def mock_ovh_get(url): "size": 67906662, } - class MockRequestsResponse: - """Mock the requests response.""" - - def __enter__(self): - return self - - def __exit__(self, *args): - pass - - def iter_content(self, chunk_size): - """Fake content file iteration.""" - # pylint: disable=no-self-use - - with archive_path.open("rb") as archive: - while chunk := archive.read(chunk_size): - yield chunk - - def raise_for_status(self): - """Ignored.""" - - def mock_requests_get(url, stream=True, timeout=None): - """Mock the request get method.""" - # pylint: disable=unused-argument - - return MockRequestsResponse() - # Freeze the ralph.utils.now() value. 
frozen_now = now() monkeypatch.setattr("ralph.backends.data.ldp.now", lambda: frozen_now) @@ -672,12 +601,14 @@ def mock_requests_get(url, stream=True, timeout=None): backend = ldp_backend() monkeypatch.setattr(backend.client, "post", mock_ovh_post) monkeypatch.setattr(backend.client, "get", mock_ovh_get) - monkeypatch.setattr(requests, "get", mock_requests_get) + monkeypatch.setattr(backend, "_url", lambda *_: "http://example.com") fs.create_dir(settings.APP_DIR) assert not os.path.exists(settings.HISTORY_FILE) - result = b"".join(backend.read(query="5d5c4c93-04a4-42c5-9860-f51fa4044aa1")) + with requests_mock.Mocker() as request_mocker: + request_mocker.get("http://example.com", content=archive) + result = b"".join(backend.read(query="5d5c4c93-04a4-42c5-9860-f51fa4044aa1")) assert os.path.exists(settings.HISTORY_FILE) assert backend.history == [ From 24657bf3f0bb6eb7096961476621276733cba5af Mon Sep 17 00:00:00 2001 From: SergioSim Date: Mon, 22 May 2023 10:14:09 +0200 Subject: [PATCH 13/65] =?UTF-8?q?=E2=9C=A8(backends)=20add=20ESDataBackend?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We add the ES data backend implementation that is mostly taken from the existing ESDatabase backend. --- docker-compose.yml | 11 + src/ralph/backends/data/base.py | 17 +- src/ralph/backends/data/es.py | 395 ++++++++++++++++++ src/ralph/backends/lrs/base.py | 3 + src/ralph/backends/lrs/es.py | 106 +++++ src/ralph/conf.py | 10 +- tests/backends/data/test_base.py | 25 +- tests/backends/data/test_es.py | 695 +++++++++++++++++++++++++++++++ tests/backends/lrs/test_es.py | 389 +++++++++++++++++ tests/conftest.py | 2 + tests/fixtures/backends.py | 56 ++- tests/test_conf.py | 9 +- 12 files changed, 1691 insertions(+), 27 deletions(-) create mode 100644 src/ralph/backends/data/es.py create mode 100644 src/ralph/backends/lrs/es.py create mode 100644 tests/backends/data/test_es.py create mode 100644 tests/backends/lrs/test_es.py diff --git a/docker-compose.yml b/docker-compose.yml index 8038dc53c..0209de9d2 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -29,11 +29,18 @@ services: elasticsearch: image: elasticsearch:8.1.0 environment: + bootstrap.memory_lock: true discovery.type: single-node xpack.security.enabled: "false" ports: - "9200:9200" + volumes: + - esdata:/usr/share/elasticsearch/data mem_limit: 2g + ulimits: + memlock: + soft: -1 + hard: -1 mongo: image: mongo:5.0.9 @@ -67,3 +74,7 @@ services: # -- tools dockerize: image: jwilder/dockerize + +volumes: + esdata: + driver: local diff --git a/src/ralph/backends/data/base.py b/src/ralph/backends/data/base.py index 87e6638c6..9e70cc2c7 100644 --- a/src/ralph/backends/data/base.py +++ b/src/ralph/backends/data/base.py @@ -107,17 +107,18 @@ def validate_query(self, query: Union[str, dict, BaseQuery] = None) -> BaseQuery if isinstance(query, dict): try: query = self.query_model(**query) - except ValidationError as err: + except ValidationError as error: + msg = "The 'query' argument is expected to be a %s instance. %s" + errors = error.errors() + logger.error(msg, self.query_model.__name__, errors) raise BackendParameterException( - "The 'query' argument is expected to be a " - f"{self.query_model.__name__} instance. {err.errors()}" - ) from err + msg % (self.query_model.__name__, errors) + ) from error if not isinstance(query, self.query_model): - raise BackendParameterException( - "The 'query' argument is expected to be a " - f"{self.query_model.__name__} instance." 
- ) + msg = "The 'query' argument is expected to be a %s instance." + logger.error(msg, self.query_model.__name__) + raise BackendParameterException(msg % (self.query_model.__name__,)) logger.debug("Query: %s", str(query)) diff --git a/src/ralph/backends/data/es.py b/src/ralph/backends/data/es.py new file mode 100644 index 000000000..09f4a2b38 --- /dev/null +++ b/src/ralph/backends/data/es.py @@ -0,0 +1,395 @@ +"""Elasticsearch data backend for Ralph.""" + +import json +import logging +from io import IOBase +from itertools import chain +from pathlib import Path +from typing import Any, Dict, Iterable, Iterator, List, Literal, Optional, Union + +from elasticsearch import ApiError, Elasticsearch, TransportError +from elasticsearch.helpers import BulkIndexError, streaming_bulk +from pydantic import BaseModel + +from ralph.backends.data.base import ( + BaseDataBackend, + BaseDataBackendSettings, + BaseOperationType, + BaseQuery, + DataBackendStatus, + enforce_query_checks, +) +from ralph.conf import BaseSettingsConfig, CommaSeparatedTuple +from ralph.exceptions import BackendException, BackendParameterException + +logger = logging.getLogger(__name__) + + +class ESClientOptions(BaseModel): + """Elasticsearch additional client options.""" + + ca_certs: Path = None + verify_certs: bool = None + + +class ESDataBackendSettings(BaseDataBackendSettings): + """Elasticsearch data backend default configuration. + + Attributes: + ALLOW_YELLOW_STATUS (bool): Whether to consider Elasticsearch yellow health + status to be ok. + CLIENT_OPTIONS (dict): A dictionary of valid options for the Elasticsearch class + initialization. + DEFAULT_CHUNK_SIZE (int): The default chunk size for reading batches of + documents. + DEFAULT_INDEX (str): The default index to use for querying Elasticsearch. + HOSTS (str or tuple): The comma separated list of Elasticsearch nodes to + connect to. + LOCALE_ENCODING (str): The encoding used for reading/writing documents. + POINT_IN_TIME_KEEP_ALIVE (str): The duration for which Elasticsearch should + keep a point in time alive. + REFRESH_AFTER_WRITE (str or bool): Whether the Elasticsearch index should be + refreshed after the write operation. + """ + + class Config(BaseSettingsConfig): + """Pydantic Configuration.""" + + env_prefix = "RALPH_BACKENDS__DATA__ES__" + + ALLOW_YELLOW_STATUS: bool = False + CLIENT_OPTIONS: ESClientOptions = ESClientOptions() + DEFAULT_CHUNK_SIZE: int = 500 + DEFAULT_INDEX: str = "statements" + HOSTS: CommaSeparatedTuple = ("http://localhost:9200",) + LOCALE_ENCODING: str = "utf8" + POINT_IN_TIME_KEEP_ALIVE: str = "1m" + REFRESH_AFTER_WRITE: Union[Literal["false", "true", "wait_for"], bool, str, None] + + +class ESQueryPit(BaseModel): + """Elasticsearch point in time (pit) query configuration. + + Attributes: + id (str): Context identifier of the Elasticsearch point in time. + keep_alive (str): The duration for which Elasticsearch should keep the point in + time alive. + """ + + id: Optional[str] + keep_alive: Optional[str] + + +class ESQuery(BaseQuery): + """Elasticsearch query model. + + Attributes: + query (dict): A search query definition using the Elasticsearch Query DSL. + See Elasticsearch search reference for query DSL syntax: + https://www.elastic.co/guide/en/elasticsearch/reference/8.9/search-search.html#request-body-search-query + query_string (str): The Elastisearch query in the Lucene query string syntax. 
+ See Elasticsearch search reference for Lucene query syntax: + https://www.elastic.co/guide/en/elasticsearch/reference/8.9/search-search.html#search-api-query-params-q + pit (dict): Limit the search to a point in time (PIT). See ESQueryPit. + size (int): The maximum number of documents to yield. + sort (str or list): Specify how to sort search results. Set to `_doc` or + `_shard_doc` if order doesn't matter. + See https://www.elastic.co/guide/en/elasticsearch/reference/8.9/sort-search-results.html + search_after (list): Limit search query results to values after a document + matching the set of sort values in `search_after`. Used for pagination. + track_total_hits (bool): Number of hits matching the query to count accurately. + Not used. Always set to `False`. + """ # pylint: disable=line-too-long # noqa: E501 + + query: dict = {"match_all": {}} + pit: ESQueryPit = ESQueryPit() + size: Optional[int] + sort: Union[str, List[dict]] = "_shard_doc" + search_after: Optional[list] + track_total_hits: Literal[False] = False + + +class ESDataBackend(BaseDataBackend): + """Elasticsearch data backend.""" + + name = "es" + query_model = ESQuery + settings_class = ESDataBackendSettings + + def __init__(self, settings: settings_class = None): + """Instantiate the Elasticsearch data backend. + + Args: + settings (ESDataBackendSettings or None): The data backend settings. + If `settings` is `None`, a default settings instance is used instead. + """ + self.settings = settings if settings else self.settings_class() + self._client = None + + @property + def client(self): + """Create an Elasticsearch client if it doesn't exist.""" + if not self._client: + self._client = Elasticsearch( + self.settings.HOSTS, **self.settings.CLIENT_OPTIONS.dict() + ) + return self._client + + def status(self) -> DataBackendStatus: + """Check Elasticsearch cluster connection and status.""" + try: + self.client.info() + cluster_status = self.client.cat.health() + except TransportError as error: + logger.error("Failed to connect to Elasticsearch: %s", error) + return DataBackendStatus.AWAY + + if "green" in cluster_status: + return DataBackendStatus.OK + + if "yellow" in cluster_status and self.settings.ALLOW_YELLOW_STATUS: + logger.info("Cluster status is yellow.") + return DataBackendStatus.OK + + logger.error("Cluster status is not green: %s", cluster_status) + + return DataBackendStatus.ERROR + + def list( + self, target: str = None, details: bool = False, new: bool = False + ) -> Iterator[Union[str, dict]]: + """List available Elasticsearch indices, data streams and aliases. + + Args: + target (str or None): The comma-separated list of data streams, indices, + and aliases to limit the request. Supports wildcards (*). + If target is `None`, lists all available indices, data streams and + aliases. Equivalent to (`target` = "*"). + details (bool): Get detailed informations instead of just names. + new (bool): Ignored. + + Yield: + str: The next index, data stream or alias name. (If `details` is False). + dict: The next index, data stream or alias details. (If `details` is True). + + Raise: + BackendException: If a failure during indices retrieval occurs. 
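+
+        Example (an illustrative sketch, assuming a reachable cluster that
+        exposes a `statements` index):
+            >>> backend = ESDataBackend()
+            >>> next(backend.list())  # doctest: +SKIP
+            'statements'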
+ """ + target = target if target else "*" + try: + indices = self.client.indices.get(index=target) + except (ApiError, TransportError) as error: + msg = "Failed to read indices: %s" + logger.error(msg, error) + raise BackendException(msg % error) from error + + if new: + logger.warning("The `new` argument is ignored") + + if details: + for index, value in indices.items(): + yield {index: value} + + return + + for index in indices: + yield index + + @enforce_query_checks + def read( + self, + *, + query: Union[str, ESQuery] = None, + target: str = None, + chunk_size: Union[None, int] = None, + raw_output: bool = False, + ignore_errors: bool = False, + ) -> Iterator[Union[bytes, dict]]: + """Read documents matching the query in the target index and yield them. + + Args: + query (str or ESQuery): A query in the Lucene query string syntax or a + dictionary defining a search definition using the Elasticsearch Query + DSL. The Lucene query overrides the query DSL if present. See ESQuery. + target (str or None): The target Elasticsearch index name to query. + If target is `None`, the `DEFAULT_INDEX` is used instead. + chunk_size (int or None): The chunk size for reading batches of documents. + If chunk_size is `None` it defaults to `DEFAULT_CHUNK_SIZE`. + raw_output (bool): Controls whether to yield dictionaries or bytes. + ignore_errors (bool): Ignored. + + Yield: + bytes: The next raw document if `raw_output` is True. + dict: The next JSON parsed document if `raw_output` is False. + + Raise: + BackendException: If a failure occurs during Elasticsearch connection. + """ + target = target if target else self.settings.DEFAULT_INDEX + chunk_size = chunk_size if chunk_size else self.settings.DEFAULT_CHUNK_SIZE + if ignore_errors: + logger.warning("The `ignore_errors` argument is ignored") + + if not query.pit.keep_alive: + query.pit.keep_alive = self.settings.POINT_IN_TIME_KEEP_ALIVE + if not query.pit.id: + try: + query.pit.id = self.client.open_point_in_time( + index=target, keep_alive=query.pit.keep_alive + )["id"] + except (ApiError, TransportError, ValueError) as error: + msg = "Failed to open Elasticsearch point in time: %s" + logger.error(msg, error) + raise BackendException(msg % error) from error + + limit = query.size + kwargs = query.dict(exclude={"query_string", "size"}) + if query.query_string: + kwargs["q"] = query.query_string + + count = chunk_size + while limit or chunk_size == count: + kwargs["size"] = limit if limit and limit < chunk_size else chunk_size + try: + documents = self.client.search(**kwargs)["hits"]["hits"] + except (ApiError, TransportError, TypeError) as error: + msg = "Failed to execute Elasticsearch query: %s" + logger.error(msg, error) + raise BackendException(msg % error) from error + count = len(documents) + if limit: + limit -= count if chunk_size == count else limit + query.search_after = None + if count: + query.search_after = [str(part) for part in documents[-1]["sort"]] + kwargs["search_after"] = query.search_after + if raw_output: + documents = self._read_raw(documents) + for document in documents: + yield document + + def write( # pylint: disable=too-many-arguments + self, + data: Union[IOBase, Iterable[bytes], Iterable[dict]], + target: Union[None, str] = None, + chunk_size: Union[None, int] = None, + ignore_errors: bool = False, + operation_type: Union[None, BaseOperationType] = None, + ) -> int: + """Write data documents to the target index and return their count. + + Args: + data: (Iterable or IOBase): The data containing documents to write. 
+ target (str or None): The target Elasticsearch index name. + If target is `None`, the `DEFAULT_INDEX` is used instead. + chunk_size (int or None): The number of documents to write in one batch. + If chunk_size is `None` it defaults to `DEFAULT_CHUNK_SIZE`. + ignore_errors (bool): If `True`, errors during the write operation + will be ignored and logged. If `False` (default), a `BackendException` + will be raised if an error occurs. + operation_type (BaseOperationType or None): The mode of the write operation. + If `operation_type` is `None`, the `default_operation_type` is used + instead. See `BaseOperationType`. + + Return: + int: The number of written documents. + + Raise: + BackendException: If a failure occurs while writing to Elasticsearch or + during document decoding and `ignore_errors` is set to `False`. + BackendParameterException: If the `operation_type` is `APPEND` as it is not + supported. + """ + count = 0 + data = iter(data) + try: + first_record = next(data) + except StopIteration: + logger.info("Data Iterator is empty; skipping write to target.") + return count + if not operation_type: + operation_type = self.default_operation_type + target = target if target else self.settings.DEFAULT_INDEX + chunk_size = chunk_size if chunk_size else self.settings.DEFAULT_CHUNK_SIZE + if operation_type == BaseOperationType.APPEND: + msg = "Append operation_type is not supported." + logger.error(msg) + raise BackendParameterException(msg) + + data = chain((first_record,), data) + if isinstance(first_record, bytes): + data = self._parse_bytes_to_dict(data, ignore_errors) + + logger.debug( + "Start writing to the %s index (chunk size: %d)", target, chunk_size + ) + try: + for success, action in streaming_bulk( + client=self.client, + actions=self._to_documents(data, target, operation_type), + chunk_size=chunk_size, + raise_on_error=(not ignore_errors), + refresh=self.settings.REFRESH_AFTER_WRITE, + ): + count += success + logger.debug("Wrote %d document [action: %s]", success, action) + + logger.info("Finished writing %d documents with success", count) + except (BulkIndexError, ApiError, TransportError) as error: + msg = "%s %s Total succeeded writes: %s" + details = getattr(error, "errors", "") + logger.error(msg, error, details, count) + raise BackendException(msg % (error, details, count)) from error + return count + + @staticmethod + def _to_documents( + data: Iterable[dict], + target: str, + operation_type: BaseOperationType, + ) -> Iterator[dict]: + """Convert dictionaries from `data` to ES documents and yield them.""" + if operation_type == BaseOperationType.UPDATE: + for item in data: + yield { + "_index": target, + "_id": item.get("id", None), + "_op_type": operation_type.value, + "doc": item, + } + elif operation_type in (BaseOperationType.CREATE, BaseOperationType.INDEX): + for item in data: + yield { + "_index": target, + "_id": item.get("id", None), + "_op_type": operation_type.value, + "_source": item, + } + else: + # operation_type == BaseOperationType.DELETE (by exclusion) + for item in data: + yield { + "_index": target, + "_id": item.get("id", None), + "_op_type": operation_type.value, + } + + def _read_raw(self, documents: Iterable[Dict[str, Any]]) -> Iterator[bytes]: + """Read the `documents` Iterable and yield bytes.""" + for document in documents: + yield json.dumps(document).encode(self.settings.LOCALE_ENCODING) + + @staticmethod + def _parse_bytes_to_dict( + raw_documents: Iterable[bytes], ignore_errors: bool + ) -> Iterator[dict]: + """Read the `raw_documents` 
Iterable and yield dictionaries.""" + for raw_document in raw_documents: + try: + yield json.loads(raw_document) + except (TypeError, json.JSONDecodeError) as error: + msg = "Failed to decode JSON: %s, for document: %s" + logger.error(msg, error, raw_document) + if ignore_errors: + continue + raise BackendException(msg % (error, raw_document)) from error diff --git a/src/ralph/backends/lrs/base.py b/src/ralph/backends/lrs/base.py index d7a3309f5..957ccabc4 100644 --- a/src/ralph/backends/lrs/base.py +++ b/src/ralph/backends/lrs/base.py @@ -59,11 +59,14 @@ class StatementParameters(BaseModel): search_after: Optional[str] pit_id: Optional[str] authority: Optional[AgentParameters] + ignore_order: Optional[bool] class BaseLRSBackend(BaseDataBackend): """Base LRS backend interface.""" + settings_class = BaseLRSBackendSettings + @abstractmethod def query_statements(self, params: StatementParameters) -> StatementQueryResult: """Return the statements query payload using xAPI parameters.""" diff --git a/src/ralph/backends/lrs/es.py b/src/ralph/backends/lrs/es.py new file mode 100644 index 000000000..8f3354f9a --- /dev/null +++ b/src/ralph/backends/lrs/es.py @@ -0,0 +1,106 @@ +"""Elasticsearch LRS backend for Ralph.""" + +import logging +from typing import Iterator, List + +from ralph.backends.data.es import ESDataBackend, ESQuery, ESQueryPit +from ralph.backends.lrs.base import ( + AgentParameters, + BaseLRSBackend, + StatementParameters, + StatementQueryResult, +) +from ralph.exceptions import BackendException, BackendParameterException + +logger = logging.getLogger(__name__) + + +class ESLRSBackend(BaseLRSBackend, ESDataBackend): + """Elasticsearch LRS backend implementation.""" + + settings_class = ESDataBackend.settings_class + + def query_statements(self, params: StatementParameters) -> StatementQueryResult: + """Return the statements query payload using xAPI parameters.""" + es_query_filters = [] + + if params.statementId: + es_query_filters += [{"term": {"_id": params.statementId}}] + + self._add_agent_filters(es_query_filters, params.agent, "actor") + self._add_agent_filters(es_query_filters, params.authority, "authority") + + if params.verb: + es_query_filters += [{"term": {"verb.id.keyword": params.verb}}] + + if params.activity: + es_query_filters += [ + {"term": {"object.objectType.keyword": "Activity"}}, + {"term": {"object.id.keyword": params.activity}}, + ] + + if params.since: + es_query_filters += [{"range": {"timestamp": {"gt": params.since}}}] + + if params.until: + es_query_filters += [{"range": {"timestamp": {"lte": params.until}}}] + + es_query = { + "pit": ESQueryPit.construct(id=params.pit_id), + "size": params.limit, + "sort": [{"timestamp": {"order": "asc" if params.ascending else "desc"}}], + } + if len(es_query_filters) > 0: + es_query["query"] = {"bool": {"filter": es_query_filters}} + + if params.ignore_order: + es_query["sort"] = "_shard_doc" + + if params.search_after: + es_query["search_after"] = params.search_after.split("|") + + # Note: `params` fields are validated thus we skip their validation in ESQuery. 
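+        # As an illustrative sketch, a `statementId`-only request builds:
+        #   {"pit": ESQueryPit.construct(id=None), "size": None,
+        #    "sort": [{"timestamp": {"order": "desc"}}],
+        #    "query": {"bool": {"filter": [{"term": {"_id": "<statementId>"}}]}}}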
+ query = ESQuery.construct(**es_query) + try: + es_documents = self.read(query=query, chunk_size=params.limit) + statements = [document["_source"] for document in es_documents] + except (BackendException, BackendParameterException) as error: + logger.error("Failed to read from Elasticsearch") + raise error + + return StatementQueryResult( + statements=statements, + pit_id=query.pit.id, + search_after="|".join(query.search_after) if query.search_after else "", + ) + + def query_statements_by_ids(self, ids: List[str]) -> Iterator[dict]: + """Yield statements with matching ids from the backend.""" + try: + es_response = self.read(query={"query": {"terms": {"_id": ids}}}) + yield from (document["_source"] for document in es_response) + except (BackendException, BackendParameterException) as error: + logger.error("Failed to read from Elasticsearch") + raise error + + @staticmethod + def _add_agent_filters( + es_query_filters: list, agent_params: AgentParameters, target_field: str + ): + """Add filters relative to agents to `es_query_filters`.""" + if not agent_params: + return + if agent_params.mbox: + field = f"{target_field}.mbox.keyword" + es_query_filters += [{"term": {field: agent_params.mbox}}] + elif agent_params.mbox_sha1sum: + field = f"{target_field}.mbox_sha1sum.keyword" + es_query_filters += [{"term": {field: agent_params.mbox_sha1sum}}] + elif agent_params.openid: + field = f"{target_field}.openid.keyword" + es_query_filters += [{"term": {field: agent_params.openid}}] + elif agent_params.account__name: + field = f"{target_field}.account.name.keyword" + es_query_filters += [{"term": {field: agent_params.account__name}}] + field = f"{target_field}.account.homePage.keyword" + es_query_filters += [{"term": {field: agent_params.account__home_page}}] diff --git a/src/ralph/conf.py b/src/ralph/conf.py index ee6eb94fc..f10c1ca0e 100644 --- a/src/ralph/conf.py +++ b/src/ralph/conf.py @@ -48,14 +48,14 @@ class Config(BaseSettingsConfig): class CommaSeparatedTuple(str): - """Pydantic field type validating comma separated strings or tuples.""" + """Pydantic field type validating comma separated strings or lists/tuples.""" @classmethod def __get_validators__(cls): # noqa: D105 - def validate(value: Union[str, Tuple[str]]) -> Tuple[str]: - """Checks whether the value is a comma separated string or a tuple.""" - if isinstance(value, tuple): - return value + def validate(value: Union[str, Tuple[str], List[str]]) -> Tuple[str]: + """Check whether the value is a comma separated string or a list/tuple.""" + if isinstance(value, (tuple, list)): + return tuple(value) if isinstance(value, str): return tuple(value.split(",")) diff --git a/tests/backends/data/test_base.py b/tests/backends/data/test_base.py index 3c12183f7..e86e72e08 100644 --- a/tests/backends/data/test_base.py +++ b/tests/backends/data/test_base.py @@ -1,4 +1,5 @@ """Tests for the base data backend""" +import logging import pytest @@ -41,8 +42,21 @@ def write(self): # pylint: disable=arguments-differ,missing-function-docstring MockBaseDataBackend().read(query=value) -@pytest.mark.parametrize("value", [[], {"foo": "bar"}]) -def test_backends_data_base_enforce_query_checks_with_invalid_input(value): +@pytest.mark.parametrize( + "value,error", + [ + ([], r"The 'query' argument is expected to be a BaseQuery instance."), + ( + {"foo": "bar"}, + r"The 'query' argument is expected to be a BaseQuery instance. 
" + r"\[\{'loc': \('foo',\), 'msg': 'extra fields not permitted', " + r"'type': 'value_error.extra'\}\]", + ), + ], +) +def test_backends_data_base_enforce_query_checks_with_invalid_input( + value, error, caplog +): """Tests the enforce_query_checks function given invalid input.""" class MockBaseDataBackend(BaseDataBackend): @@ -66,6 +80,9 @@ def list(self): # pylint: disable=arguments-differ,missing-function-docstring def write(self): # pylint: disable=arguments-differ,missing-function-docstring pass - error = "The 'query' argument is expected to be a BaseQuery instance." with pytest.raises(BackendParameterException, match=error): - MockBaseDataBackend().read(query=value) + with caplog.at_level(logging.ERROR): + MockBaseDataBackend().read(query=value) + + error = error.replace("\\", "") + assert ("ralph.backends.data.base", logging.ERROR, error) in caplog.record_tuples diff --git a/tests/backends/data/test_es.py b/tests/backends/data/test_es.py new file mode 100644 index 000000000..ed0b116e6 --- /dev/null +++ b/tests/backends/data/test_es.py @@ -0,0 +1,695 @@ +"""Tests for Ralph Elasticsearch data backend.""" + +import json +import logging +import random +import re +from collections.abc import Iterable +from datetime import datetime +from io import BytesIO +from pathlib import Path + +import pytest +from elastic_transport import ApiResponseMeta +from elasticsearch import ApiError +from elasticsearch import ConnectionError as ESConnectionError +from elasticsearch import Elasticsearch + +from ralph.backends.data.base import BaseOperationType, DataBackendStatus +from ralph.backends.data.es import ( + ESClientOptions, + ESDataBackend, + ESDataBackendSettings, + ESQuery, +) +from ralph.exceptions import BackendException, BackendParameterException +from ralph.utils import now + +from tests.fixtures.backends import ( + ES_TEST_FORWARDING_INDEX, + ES_TEST_INDEX, + get_es_fixture, +) + + +def test_backends_data_es_data_backend_default_instantiation(monkeypatch, fs): + """Test the `ESDataBackend` default instantiation.""" + # pylint: disable=invalid-name + fs.create_file(".env") + backend_settings_names = [ + "ALLOW_YELLOW_STATUS", + "CLIENT_OPTIONS", + "CLIENT_OPTIONS__ca_certs", + "CLIENT_OPTIONS__verify_certs", + "DEFAULT_CHUNK_SIZE", + "DEFAULT_INDEX", + "HOSTS", + "LOCALE_ENCODING", + "POINT_IN_TIME_KEEP_ALIVE", + "REFRESH_AFTER_WRITE", + ] + for name in backend_settings_names: + monkeypatch.delenv(f"RALPH_BACKENDS__DATA__ES__{name}", raising=False) + + assert ESDataBackend.name == "es" + assert ESDataBackend.query_model == ESQuery + assert ESDataBackend.default_operation_type == BaseOperationType.INDEX + assert ESDataBackend.settings_class == ESDataBackendSettings + backend = ESDataBackend() + assert not backend.settings.ALLOW_YELLOW_STATUS + assert backend.settings.CLIENT_OPTIONS == ESClientOptions() + assert backend.settings.DEFAULT_CHUNK_SIZE == 500 + assert backend.settings.DEFAULT_INDEX == "statements" + assert backend.settings.HOSTS == ("http://localhost:9200",) + assert backend.settings.LOCALE_ENCODING == "utf8" + assert backend.settings.POINT_IN_TIME_KEEP_ALIVE == "1m" + assert not backend.settings.REFRESH_AFTER_WRITE + assert isinstance(backend.client, Elasticsearch) + elasticsearch_node = backend.client.transport.node_pool.get() + assert elasticsearch_node.config.ca_certs is None + assert elasticsearch_node.config.verify_certs is None + assert elasticsearch_node.host == "localhost" + assert elasticsearch_node.port == 9200 + + +def 
test_backends_data_es_data_backend_instantiation_with_settings():
+    """Test the `ESDataBackend` instantiation with settings."""
+    settings = ESDataBackendSettings(
+        ALLOW_YELLOW_STATUS=True,
+        CLIENT_OPTIONS={"verify_certs": True, "ca_certs": "/path/to/ca/bundle"},
+        DEFAULT_CHUNK_SIZE=5000,
+        DEFAULT_INDEX=ES_TEST_INDEX,
+        HOSTS=["https://elasticsearch_hostname:9200"],
+        LOCALE_ENCODING="utf-16",
+        POINT_IN_TIME_KEEP_ALIVE="5m",
+        REFRESH_AFTER_WRITE=True,
+    )
+    backend = ESDataBackend(settings)
+    assert backend.settings.ALLOW_YELLOW_STATUS
+    assert backend.settings.CLIENT_OPTIONS == ESClientOptions(
+        verify_certs=True, ca_certs="/path/to/ca/bundle"
+    )
+    assert backend.settings.DEFAULT_CHUNK_SIZE == 5000
+    assert backend.settings.DEFAULT_INDEX == ES_TEST_INDEX
+    assert backend.settings.HOSTS == ("https://elasticsearch_hostname:9200",)
+    assert backend.settings.LOCALE_ENCODING == "utf-16"
+    assert backend.settings.POINT_IN_TIME_KEEP_ALIVE == "5m"
+    assert backend.settings.REFRESH_AFTER_WRITE
+    assert isinstance(backend.client, Elasticsearch)
+    elasticsearch_node = backend.client.transport.node_pool.get()
+    assert elasticsearch_node.config.ca_certs == Path("/path/to/ca/bundle")
+    assert elasticsearch_node.config.verify_certs is True
+    assert elasticsearch_node.host == "elasticsearch_hostname"
+    assert elasticsearch_node.port == 9200
+
+    try:
+        ESDataBackend(settings)
+    except Exception as err:  # pylint:disable=broad-except
+        pytest.fail(f"Instantiating a second ESDataBackend should not raise: {err}")
+
+
+def test_backends_data_es_data_backend_status_method(monkeypatch, es_backend, caplog):
+    """Test the `ESDataBackend.status` method."""
+    backend = es_backend()
+    with monkeypatch.context() as elasticsearch_patch:
+        # Given green status, the `status` method should return `DataBackendStatus.OK`.
+        es_status = "1664532320 10:05:20 docker-cluster green 1 1 2 2 0 0 1 0 - 66.7%"
+        elasticsearch_patch.setattr(backend.client, "info", lambda: None)
+        elasticsearch_patch.setattr(backend.client.cat, "health", lambda: es_status)
+        assert backend.status() == DataBackendStatus.OK
+
+    with monkeypatch.context() as elasticsearch_patch:
+        # Given yellow status, the `status` method should return
+        # `DataBackendStatus.ERROR`.
+        es_status = "1664532320 10:05:20 docker-cluster yellow 1 1 2 2 0 0 1 0 - 66.7%"
+        elasticsearch_patch.setattr(backend.client, "info", lambda: None)
+        elasticsearch_patch.setattr(backend.client.cat, "health", lambda: es_status)
+        assert backend.status() == DataBackendStatus.ERROR
+        # Given yellow status, and `settings.ALLOW_YELLOW_STATUS` set to `True`,
+        # the `status` method should return `DataBackendStatus.OK`.
+        backend.settings.ALLOW_YELLOW_STATUS = True
+        with caplog.at_level(logging.INFO):
+            assert backend.status() == DataBackendStatus.OK
+
+    assert (
+        "ralph.backends.data.es",
+        logging.INFO,
+        "Cluster status is yellow.",
+    ) in caplog.record_tuples
+
+    # Given a connection exception, the `status` method should return
+    # `DataBackendStatus.AWAY`.
+ with monkeypatch.context() as elasticsearch_patch: + + def mock_connection_error(): + """ES client info mock that raises a connection error.""" + raise ESConnectionError("", (Exception("Mocked connection error"),)) + + elasticsearch_patch.setattr(backend.client, "info", mock_connection_error) + with caplog.at_level(logging.ERROR): + assert backend.status() == DataBackendStatus.AWAY + + assert ( + "ralph.backends.data.es", + logging.ERROR, + "Failed to connect to Elasticsearch: Connection error caused by: " + "Exception(Mocked connection error)", + ) in caplog.record_tuples + + +@pytest.mark.parametrize( + "exception, error", + [ + (ApiError("", ApiResponseMeta(*([None] * 5)), None), "ApiError(None, '')"), + (ESConnectionError(""), "Connection error"), + ], +) +def test_backends_data_es_data_backend_list_method_with_failure( + exception, error, caplog, monkeypatch, es_backend +): + """Test the `ESDataBackend.list` method given an failed Elasticsearch connection + should raise a `BackendException` and log an error message. + """ + + def mock_get(index): + """Mock the ES.client.indices.get method raising an exception.""" + assert index == "*" + raise exception + + backend = es_backend() + monkeypatch.setattr(backend.client.indices, "get", mock_get) + with caplog.at_level(logging.ERROR): + with pytest.raises(BackendException): + next(backend.list()) + + assert ( + "ralph.backends.data.es", + logging.ERROR, + f"Failed to read indices: {error}", + ) in caplog.record_tuples + + +def test_backends_data_es_data_backend_list_method_without_history( + es_backend, monkeypatch +): + """Test the `ESDataBackend.list` method without history.""" + + indices = {"index_1": {"info_1": "foo"}, "index_2": {"info_2": "baz"}} + + def mock_get(index): + """Mock the ES.client.indices.get method returning a dictionary.""" + assert index == "target_index*" + return indices + + backend = es_backend() + monkeypatch.setattr(backend.client.indices, "get", mock_get) + result = backend.list("target_index*") + assert isinstance(result, Iterable) + assert list(result) == list(indices.keys()) + + +def test_backends_data_es_data_backend_list_method_with_details( + es_backend, monkeypatch +): + """Test the `ESDataBackend.list` method with `details` set to `True`.""" + indices = {"index_1": {"info_1": "foo"}, "index_2": {"info_2": "baz"}} + + def mock_get(index): + """Mock the ES.client.indices.get method returning a dictionary.""" + assert index == "target_index*" + return indices + + backend = es_backend() + monkeypatch.setattr(backend.client.indices, "get", mock_get) + result = backend.list("target_index*", details=True) + assert isinstance(result, Iterable) + assert list(result) == [ + {"index_1": {"info_1": "foo"}}, + {"index_2": {"info_2": "baz"}}, + ] + + +def test_backends_data_es_data_backend_list_method_with_history( + es_backend, caplog, monkeypatch +): + """Test the `ESDataBackend.list` method given `new` argument set to True, should log + a warning message. 
+ """ + backend = es_backend() + monkeypatch.setattr(backend.client.indices, "get", lambda index: {}) + with caplog.at_level(logging.WARNING): + assert not list(backend.list(new=True)) + + assert ( + "ralph.backends.data.es", + logging.WARNING, + "The `new` argument is ignored", + ) in caplog.record_tuples + + +@pytest.mark.parametrize( + "exception, error", + [ + (ApiError("", ApiResponseMeta(*([None] * 5)), None), r"ApiError\(None, ''\)"), + (ESConnectionError(""), "Connection error"), + ], +) +def test_backends_data_es_data_backend_read_method_with_failure( + exception, error, es, es_backend, caplog, monkeypatch +): + """Test the `ESDataBackend.read` method, given a request failure, should raise a + `BackendException`. + """ + # pylint: disable=invalid-name,unused-argument,too-many-arguments + + def mock_es_search_open_pit(**kwargs): + """Mock the ES.client.search and open_point_in_time methods always raising an + exception. + """ + raise exception + + backend = es_backend() + + # Search failure. + monkeypatch.setattr(backend.client, "search", mock_es_search_open_pit) + with pytest.raises( + BackendException, match=f"Failed to execute Elasticsearch query: {error}" + ): + with caplog.at_level(logging.ERROR): + next(backend.read()) + + assert ( + "ralph.backends.data.es", + logging.ERROR, + "Failed to execute Elasticsearch query: %s" % error.replace("\\", ""), + ) in caplog.record_tuples + + # Open point in time failure. + monkeypatch.setattr(backend.client, "open_point_in_time", mock_es_search_open_pit) + with pytest.raises( + BackendException, match=f"Failed to open Elasticsearch point in time: {error}" + ): + with caplog.at_level(logging.ERROR): + next(backend.read()) + + error = error.replace("\\", "") + assert ( + "ralph.backends.data.es", + logging.ERROR, + "Failed to open Elasticsearch point in time: %s" % error.replace("\\", ""), + ) in caplog.record_tuples + + +def test_backends_data_es_data_backend_read_method_with_ignore_errors( + es, es_backend, monkeypatch, caplog +): + """Test the `ESDataBackend.read` method, given `ignore_errors` set to `True`, + should log a warning message. 
+ """ + # pylint: disable=invalid-name,unused-argument + backend = es_backend() + monkeypatch.setattr(backend.client, "search", lambda **_: {"hits": {"hits": []}}) + with caplog.at_level(logging.WARNING): + list(backend.read(ignore_errors=True)) + + assert ( + "ralph.backends.data.es", + logging.WARNING, + "The `ignore_errors` argument is ignored", + ) in caplog.record_tuples + + +def test_backends_data_es_data_backend_read_method_with_raw_ouput(es, es_backend): + """Test the `ESDataBackend.read` method with `raw_output` set to `True`.""" + # pylint: disable=invalid-name,unused-argument + backend = es_backend() + documents = [{"id": idx, "timestamp": now()} for idx in range(10)] + assert backend.write(documents) == 10 + hits = list(backend.read(raw_output=True)) + for i, hit in enumerate(hits): + assert isinstance(hit, bytes) + assert json.loads(hit).get("_source") == documents[i] + + +def test_backends_data_es_data_backend_read_method_without_raw_ouput(es, es_backend): + """Test the `ESDataBackend.read` method with `raw_output` set to `False`.""" + # pylint: disable=invalid-name,unused-argument + backend = es_backend() + documents = [{"id": idx, "timestamp": now()} for idx in range(10)] + assert backend.write(documents) == 10 + hits = backend.read() + for i, hit in enumerate(hits): + assert isinstance(hit, dict) + assert hit.get("_source") == documents[i] + + +def test_backends_data_es_data_backend_read_method_with_query(es, es_backend, caplog): + """Test the `ESDataBackend.read` method with a query.""" + # pylint: disable=invalid-name,unused-argument + backend = es_backend() + documents = [{"id": idx, "timestamp": now(), "modulo": idx % 2} for idx in range(5)] + assert backend.write(documents) == 5 + # Find every even item. + query = ESQuery(query={"term": {"modulo": 0}}) + results = list(backend.read(query=query)) + assert len(results) == 3 + assert results[0]["_source"]["id"] == 0 + assert results[1]["_source"]["id"] == 2 + assert results[2]["_source"]["id"] == 4 + + # Find the first two even items. + query = ESQuery(query={"term": {"modulo": 0}}, size=2) + results = list(backend.read(query=query)) + assert len(results) == 2 + assert results[0]["_source"]["id"] == 0 + assert results[1]["_source"]["id"] == 2 + + # Find the first ten even items although there are only three available. + query = ESQuery(query={"term": {"modulo": 0}}, size=10) + results = list(backend.read(query=query)) + assert len(results) == 3 + assert results[0]["_source"]["id"] == 0 + assert results[1]["_source"]["id"] == 2 + assert results[2]["_source"]["id"] == 4 + + # Find every odd item. + query = {"query": {"term": {"modulo": 1}}} + results = list(backend.read(query=query)) + assert len(results) == 2 + assert results[0]["_source"]["id"] == 1 + assert results[1]["_source"]["id"] == 3 + + # Find documents with ID equal to one or five. + query = "id:(1 OR 5)" + results = list(backend.read(query=query)) + assert len(results) == 1 + assert results[0]["_source"]["id"] == 1 + + # Check query argument type + with pytest.raises( + BackendParameterException, + match="'query' argument is expected to be a ESQuery instance.", + ): + with caplog.at_level(logging.ERROR): + list(backend.read(query={"not_query": "foo"})) + + assert ( + "ralph.backends.data.base", + logging.ERROR, + "The 'query' argument is expected to be a ESQuery instance. 
" + "[{'loc': ('not_query',), 'msg': 'extra fields not permitted', " + "'type': 'value_error.extra'}]", + ) in caplog.record_tuples + + +def test_backends_data_es_data_backend_write_method_with_create_operation( + es, es_backend, caplog +): + """Test the `ESDataBackend.write` method, given an `CREATE` `operation_type`, + should insert the target documents with the provided data. + """ + # pylint: disable=invalid-name,unused-argument + + backend = es_backend() + assert len(list(backend.read())) == 0 + + # Given an empty data iterator, the write method should return 0 and log a message. + data = [] + with caplog.at_level(logging.INFO): + assert backend.write(data, operation_type=BaseOperationType.CREATE) == 0 + + assert ( + "ralph.backends.data.es", + logging.INFO, + "Data Iterator is empty; skipping write to target.", + ) in caplog.record_tuples + + # Given an iterator with multiple documents, the write method should write the + # documents to the default target index. + data = ({"value": str(idx)} for idx in range(9)) + with caplog.at_level(logging.DEBUG): + assert ( + backend.write(data, chunk_size=5, operation_type=BaseOperationType.CREATE) + == 9 + ) + + write_records = 0 + for record in caplog.record_tuples: + if re.match(r"^Wrote 1 document \[action: \{.*\}\]$", record[2]): + write_records += 1 + assert write_records == 9 + + assert ( + "ralph.backends.data.es", + logging.INFO, + "Finished writing 9 documents with success", + ) in caplog.record_tuples + + hits = list(backend.read()) + assert [hit["_source"] for hit in hits] == [{"value": str(idx)} for idx in range(9)] + + +def test_backends_data_es_data_backend_write_method_with_delete_operation( + es, + es_backend, +): + """Test the `ESDataBackend.write` method, given a `DELETE` `operation_type`, should + remove the target documents. + """ + # pylint: disable=invalid-name,unused-argument + + backend = es_backend() + data = [{"id": idx, "value": str(idx)} for idx in range(10)] + + assert len(list(backend.read())) == 0 + assert backend.write(data, chunk_size=5) == 10 + + data = [{"id": idx} for idx in range(3)] + assert ( + backend.write(data, chunk_size=5, operation_type=BaseOperationType.DELETE) == 3 + ) + + hits = list(backend.read()) + assert len(hits) == 7 + assert sorted([hit["_source"]["id"] for hit in hits]) == list(range(3, 10)) + + +def test_backends_data_es_data_backend_write_method_with_update_operation( + es, + es_backend, +): + """Test the `ESDataBackend.write` method, given an `UPDATE` `operation_type`, should + overwrite the target documents with the provided data. 
+ """ + # pylint: disable=invalid-name,unused-argument + + backend = es_backend() + data = BytesIO( + "\n".join( + [json.dumps({"id": idx, "value": str(idx)}) for idx in range(10)] + ).encode("utf8") + ) + + assert len(list(backend.read())) == 0 + assert backend.write(data, chunk_size=5) == 10 + + hits = list(backend.read()) + assert len(hits) == 10 + assert sorted([hit["_source"]["id"] for hit in hits]) == list(range(10)) + assert sorted([hit["_source"]["value"] for hit in hits]) == list( + map(str, range(10)) + ) + + data = BytesIO( + "\n".join( + [json.dumps({"id": idx, "value": str(10 + idx)}) for idx in range(10)] + ).encode("utf8") + ) + + assert ( + backend.write(data, chunk_size=5, operation_type=BaseOperationType.UPDATE) == 10 + ) + + hits = list(backend.read()) + assert len(hits) == 10 + assert sorted([hit["_source"]["id"] for hit in hits]) == list(range(10)) + assert sorted([hit["_source"]["value"] for hit in hits]) == list( + map(lambda x: str(x + 10), range(10)) + ) + + +def test_backends_data_es_data_backend_write_method_with_append_operation( + es_backend, caplog +): + """Test the `ESDataBackend.write` method, given an `APPEND` `operation_type`, + should raise a `BackendParameterException`. + """ + backend = es_backend() + msg = "Append operation_type is not supported." + with pytest.raises(BackendParameterException, match=msg): + with caplog.at_level(logging.ERROR): + backend.write(data=[{}], operation_type=BaseOperationType.APPEND) + + assert ( + "ralph.backends.data.es", + logging.ERROR, + "Append operation_type is not supported.", + ) in caplog.record_tuples + + +def test_backends_data_es_data_backend_write_method_with_target(es, es_backend): + """Test the `ESDataBackend.write` method, given a target index, should insert + documents to the corresponding index. + """ + # pylint: disable=invalid-name,unused-argument + + backend = es_backend() + + def get_data(): + """Yield data.""" + yield {"value": "1"} + yield {"value": "2"} + + # Create second Elasticsearch index. + for _ in get_es_fixture(index=ES_TEST_FORWARDING_INDEX): + # Both indexes should be empty. + assert len(list(backend.read())) == 0 + assert len(list(backend.read(target=ES_TEST_FORWARDING_INDEX))) == 0 + + # Write to forwarding index. + assert backend.write(get_data(), target=ES_TEST_FORWARDING_INDEX) == 2 + + hits = list(backend.read()) + hits_with_target = list(backend.read(target=ES_TEST_FORWARDING_INDEX)) + # No documents should be inserted into the default index. + assert not hits + # Documents should be inserted into the target index. + assert [hit["_source"] for hit in hits_with_target] == [ + {"value": "1"}, + {"value": "2"}, + ] + + +def test_backends_data_es_data_backend_write_method_without_ignore_errors( + es, es_backend, caplog +): + """Test the `ESDataBackend.write` method with `ignore_errors` set to `False`, given + badly formatted data, should raise a `BackendException`. + """ + # pylint: disable=invalid-name,unused-argument + + data = [{"id": idx, "count": random.randint(0, 100)} for idx in range(10)] + # Patch a record with a non-expected type for the count field (should be + # assigned as long) + data[4].update({"count": "wrong"}) + + backend = es_backend() + assert len(list(backend.read())) == 0 + + # By default, we should raise an error and stop the importation. + msg = ( + r"1 document\(s\) failed to index. 
" + r"\[\{'index': \{'_index': 'test-index-foo', '_id': '4', 'status': 400, 'error'" + r": \{'type': 'mapper_parsing_exception', 'reason': \"failed to parse field " + r"\[count\] of type \[long\] in document with id '4'. Preview of field's value:" + r" 'wrong'\", 'caused_by': \{'type': 'illegal_argument_exception', 'reason': " + r"'For input string: \"wrong\"'\}\}, 'data': \{'id': 4, 'count': 'wrong'\}\}\}" + r"\] Total succeeded writes: 5" + ) + with pytest.raises(BackendException, match=msg): + with caplog.at_level(logging.ERROR): + backend.write(data, chunk_size=2) + + assert ( + "ralph.backends.data.es", + logging.ERROR, + msg.replace("\\", ""), + ) in caplog.record_tuples + + es.indices.refresh(index=ES_TEST_INDEX) + hits = list(backend.read()) + assert len(hits) == 5 + assert sorted([hit["_source"]["id"] for hit in hits]) == [0, 1, 2, 3, 5] + + # Given an unparsable binary JSON document, the write method should raise a + # `BackendException`. + data = [ + json.dumps({"foo": "bar"}).encode("utf-8"), + "This is invalid JSON".encode("utf-8"), + json.dumps({"foo": "baz"}).encode("utf-8"), + ] + + # By default, we should raise an error and stop the importation. + msg = ( + r"Failed to decode JSON: Expecting value: line 1 column 1 \(char 0\), " + r"for document: b'This is invalid JSON'" + ) + with pytest.raises(BackendException, match=msg): + with caplog.at_level(logging.ERROR): + backend.write(data, chunk_size=2) + + assert ( + "ralph.backends.data.es", + logging.ERROR, + msg.replace("\\", ""), + ) in caplog.record_tuples + + es.indices.refresh(index=ES_TEST_INDEX) + hits = list(backend.read()) + assert len(hits) == 5 + + +def test_backends_data_es_data_backend_write_method_with_ignore_errors(es, es_backend): + """Test the `ESDataBackend.write` method with `ignore_errors` set to `True`, given + badly formatted data, should should skip the invalid data. + """ + # pylint: disable=invalid-name,unused-argument + + records = [{"id": idx, "count": random.randint(0, 100)} for idx in range(10)] + # Patch a record with a non-expected type for the count field (should be + # assigned as long) + records[2].update({"count": "wrong"}) + + backend = es_backend() + assert len(list(backend.read())) == 0 + + assert backend.write(records, chunk_size=2, ignore_errors=True) == 9 + + es.indices.refresh(index=ES_TEST_INDEX) + hits = list(backend.read()) + assert len(hits) == 9 + assert sorted([hit["_source"]["id"] for hit in hits]) == [ + i for i in range(10) if i != 2 + ] + + # Given an unparsable binary JSON document, the write method should skip it. 
+ data = [ + json.dumps({"foo": "bar"}).encode("utf-8"), + "This is invalid JSON".encode("utf-8"), + json.dumps({"foo": "baz"}).encode("utf-8"), + ] + assert backend.write(data, chunk_size=2, ignore_errors=True) == 2 + + es.indices.refresh(index=ES_TEST_INDEX) + hits = list(backend.read()) + assert len(hits) == 11 + assert [hit["_source"] for hit in hits[9:]] == [{"foo": "bar"}, {"foo": "baz"}] + + +def test_backends_data_es_data_backend_write_method_with_datastream( + es_data_stream, es_backend +): + """Test the `ESDataBackend.write` method using a configured data stream.""" + # pylint: disable=invalid-name,unused-argument + + data = [{"id": idx, "@timestamp": datetime.now().isoformat()} for idx in range(10)] + backend = es_backend() + assert len(list(backend.read())) == 0 + assert ( + backend.write(data, chunk_size=5, operation_type=BaseOperationType.CREATE) == 10 + ) + + hits = list(backend.read()) + assert len(hits) == 10 + assert sorted([hit["_source"]["id"] for hit in hits]) == list(range(10)) diff --git a/tests/backends/lrs/test_es.py b/tests/backends/lrs/test_es.py new file mode 100644 index 000000000..802231fbb --- /dev/null +++ b/tests/backends/lrs/test_es.py @@ -0,0 +1,389 @@ +"""Tests for Ralph Elasticsearch LRS backend.""" + +import logging +import re +from datetime import datetime + +import pytest +from elastic_transport import ApiResponseMeta +from elasticsearch import ApiError +from elasticsearch.helpers import bulk + +from ralph.backends.lrs.base import StatementParameters +from ralph.exceptions import BackendException + +from tests.fixtures.backends import ES_TEST_FORWARDING_INDEX, ES_TEST_INDEX + + +@pytest.mark.parametrize( + "params,expected_query", + [ + # 0. Default query. + ( + {}, + { + "pit": {"id": None, "keep_alive": None}, + "query": {"match_all": {}}, + "query_string": None, + "search_after": None, + "size": None, + "sort": [{"timestamp": {"order": "desc"}}], + "track_total_hits": False, + }, + ), + # 1. Query by statementId. + ( + {"statementId": "statementId"}, + { + "pit": {"id": None, "keep_alive": None}, + "query": {"bool": {"filter": [{"term": {"_id": "statementId"}}]}}, + "query_string": None, + "search_after": None, + "size": None, + "sort": [{"timestamp": {"order": "desc"}}], + "track_total_hits": False, + }, + ), + # 2. Query by statementId and agent with mbox IFI. + ( + {"statementId": "statementId", "agent": {"mbox": "mailto:foo@bar.baz"}}, + { + "pit": {"id": None, "keep_alive": None}, + "query": { + "bool": { + "filter": [ + {"term": {"_id": "statementId"}}, + {"term": {"actor.mbox.keyword": "mailto:foo@bar.baz"}}, + ] + } + }, + "query_string": None, + "search_after": None, + "size": None, + "sort": [{"timestamp": {"order": "desc"}}], + "track_total_hits": False, + }, + ), + # 3. Query by statementId and agent with mbox_sha1sum IFI. + ( + { + "statementId": "statementId", + "agent": {"mbox_sha1sum": "a7a5b7462b862c8c8767d43d43e865ffff754a64"}, + }, + { + "pit": {"id": None, "keep_alive": None}, + "query": { + "bool": { + "filter": [ + {"term": {"_id": "statementId"}}, + { + "term": { + "actor.mbox_sha1sum.keyword": ( + "a7a5b7462b862c8c8767d43d43e865ffff754a64" + ) + } + }, + ] + } + }, + "query_string": None, + "search_after": None, + "size": None, + "sort": [{"timestamp": {"order": "desc"}}], + "track_total_hits": False, + }, + ), + # 4. Query by statementId and agent with openid IFI. 
+ ( + { + "statementId": "statementId", + "agent": {"openid": "http://toby.openid.example.org/"}, + }, + { + "pit": {"id": None, "keep_alive": None}, + "query": { + "bool": { + "filter": [ + {"term": {"_id": "statementId"}}, + { + "term": { + "actor.openid.keyword": ( + "http://toby.openid.example.org/" + ) + } + }, + ] + } + }, + "query_string": None, + "search_after": None, + "size": None, + "sort": [{"timestamp": {"order": "desc"}}], + "track_total_hits": False, + }, + ), + # 5. Query by statementId and agent with account IFI. + ( + { + "statementId": "statementId", + "agent": { + "account__home_page": "http://www.example.com", + "account__name": "13936749", + }, + }, + { + "pit": {"id": None, "keep_alive": None}, + "query": { + "bool": { + "filter": [ + {"term": {"_id": "statementId"}}, + {"term": {"actor.account.name.keyword": ("13936749")}}, + { + "term": { + "actor.account.homePage.keyword": ( + "http://www.example.com" + ) + } + }, + ] + } + }, + "query_string": None, + "search_after": None, + "size": None, + "sort": [{"timestamp": {"order": "desc"}}], + "track_total_hits": False, + }, + ), + # 6. Query by verb and activity. + ( + { + "verb": "http://adlnet.gov/expapi/verbs/attended", + "activity": "http://www.example.com/meetings/34534", + }, + { + "pit": {"id": None, "keep_alive": None}, + "query": { + "bool": { + "filter": [ + { + "term": { + "verb.id.keyword": ( + "http://adlnet.gov/expapi/verbs/attended" + ) + } + }, + {"term": {"object.objectType.keyword": "Activity"}}, + { + "term": { + "object.id.keyword": ( + "http://www.example.com/meetings/34534" + ) + } + }, + ] + } + }, + "query_string": None, + "search_after": None, + "size": None, + "sort": [{"timestamp": {"order": "desc"}}], + "track_total_hits": False, + }, + ), + # 7. Query by timerange (with since/until). + ( + { + "since": "2021-06-24T00:00:20.194929+00:00", + "until": "2023-06-24T00:00:20.194929+00:00", + }, + { + "pit": {"id": None, "keep_alive": None}, + "query": { + "bool": { + "filter": [ + { + "range": { + "timestamp": { + "gt": datetime.fromisoformat( + "2021-06-24T00:00:20.194929+00:00" + ) + } + } + }, + { + "range": { + "timestamp": { + "lte": datetime.fromisoformat( + "2023-06-24T00:00:20.194929+00:00" + ) + } + } + }, + ] + } + }, + "query_string": None, + "search_after": None, + "size": None, + "sort": [{"timestamp": {"order": "desc"}}], + "track_total_hits": False, + }, + ), + # 8. Query with pagination and pit_id. + ( + {"search_after": "1686557542970|0", "pit_id": "46ToAwMDaWR5BXV1a"}, + { + "pit": {"id": "46ToAwMDaWR5BXV1a", "keep_alive": None}, + "query": {"match_all": {}}, + "query_string": None, + "search_after": ["1686557542970", "0"], + "size": None, + "sort": [{"timestamp": {"order": "desc"}}], + "track_total_hits": False, + }, + ), + # 9. Query ignoring statement sort order. + ( + {"ignore_order": True}, + { + "pit": {"id": None, "keep_alive": None}, + "query": {"match_all": {}}, + "query_string": None, + "search_after": None, + "size": None, + "sort": "_shard_doc", + "track_total_hits": False, + }, + ), + ], +) +def test_backends_lrs_es_lrs_backend_query_statements_query( + params, expected_query, es_lrs_backend, monkeypatch +): + """Test the `ESLRSBackend.query_statements` method, given valid statement + parameters, should produce the expected Elasticsearch query. 
+    """
+
+    def mock_read(query, chunk_size):
+        """Mock the `ESLRSBackend.read` method."""
+        assert query.dict() == expected_query
+        assert chunk_size == expected_query.get("size")
+        query.pit.id = "foo_pit_id"
+        query.search_after = ["bar_search_after", "baz_search_after"]
+        return []
+
+    backend = es_lrs_backend()
+    monkeypatch.setattr(backend, "read", mock_read)
+    result = backend.query_statements(StatementParameters(**params))
+    assert not result.statements
+    assert result.pit_id == "foo_pit_id"
+    assert result.search_after == "bar_search_after|baz_search_after"
+
+
+def test_backends_lrs_es_lrs_backend_query_statements(es, es_lrs_backend):
+    """Test the `ESLRSBackend.query_statements` method, given a query,
+    should return matching statements.
+    """
+    # pylint: disable=invalid-name,unused-argument
+    # Instantiate ESLRSBackend.
+    backend = es_lrs_backend()
+    # Insert documents.
+    documents = [{"id": "2", "timestamp": "2023-06-24T00:00:20.194929+00:00"}]
+    assert backend.write(documents) == 1
+
+    # Check the expected search query results.
+    result = backend.query_statements(StatementParameters(limit=10))
+    assert result.statements == documents
+    assert re.match(r"[0-9]+\|0", result.search_after)
+
+
+def test_backends_lrs_es_lrs_backend_query_statements_with_search_query_failure(
+    es, es_lrs_backend, monkeypatch, caplog
+):
+    """Test the `ESLRSBackend.query_statements`, given a search query failure, should
+    raise a `BackendException` and log the error.
+    """
+    # pylint: disable=invalid-name,unused-argument
+
+    def mock_read(**_):
+        """Mock the `ESLRSBackend.read` method."""
+        raise BackendException("Query error")
+
+    backend = es_lrs_backend()
+    monkeypatch.setattr(backend, "read", mock_read)
+
+    msg = "Query error"
+    with pytest.raises(BackendException, match=msg):
+        with caplog.at_level(logging.ERROR):
+            backend.query_statements(StatementParameters())
+
+    assert (
+        "ralph.backends.lrs.es",
+        logging.ERROR,
+        "Failed to read from Elasticsearch",
+    ) in caplog.record_tuples
+
+
+def test_backends_lrs_es_lrs_backend_query_statements_by_ids_with_search_query_failure(
+    es, es_lrs_backend, monkeypatch, caplog
+):
+    """Test the `ESLRSBackend.query_statements_by_ids` method, given a search query
+    failure, should raise a `BackendException` and log the error.
+    """
+    # pylint: disable=invalid-name,unused-argument
+
+    def mock_search(**_):
+        """Mock the Elasticsearch.search method."""
+        raise ApiError("Query error", ApiResponseMeta(*([None] * 5)), None)
+
+    backend = es_lrs_backend()
+    monkeypatch.setattr(backend.client, "search", mock_search)
+
+    msg = r"Failed to execute Elasticsearch query: ApiError\(None, 'Query error'\)"
+    with pytest.raises(BackendException, match=msg):
+        with caplog.at_level(logging.ERROR):
+            list(backend.query_statements_by_ids(StatementParameters()))
+
+    assert (
+        "ralph.backends.lrs.es",
+        logging.ERROR,
+        "Failed to read from Elasticsearch",
+    ) in caplog.record_tuples
+
+
+def test_backends_lrs_es_lrs_backend_query_statements_by_ids_with_multiple_indexes(
+    es, es_forwarding, es_lrs_backend
+):
+    """Test the `ESLRSBackend.query_statements_by_ids` method, given a valid search
+    query, should execute the query only on the specified index and return the
+    expected results.
+    """
+    # pylint: disable=invalid-name
+
+    # Insert documents.
+ index_1_document = {"_index": ES_TEST_INDEX, "_id": "1", "_source": {"id": "1"}} + index_2_document = { + "_index": ES_TEST_FORWARDING_INDEX, + "_id": "2", + "_source": {"id": "2"}, + } + bulk(es, [index_1_document]) + bulk(es_forwarding, [index_2_document]) + + # As we bulk insert documents, the index needs to be refreshed before making + # queries. + es.indices.refresh(index=ES_TEST_INDEX) + es_forwarding.indices.refresh(index=ES_TEST_FORWARDING_INDEX) + + # Instantiate ESLRSBackends. + backend_1 = es_lrs_backend(index=ES_TEST_INDEX) + backend_2 = es_lrs_backend(index=ES_TEST_FORWARDING_INDEX) + + # Check the expected search query results. + index_1_document = {"id": "1"} + index_2_document = {"id": "2"} + assert list(backend_1.query_statements_by_ids(["1"])) == [index_1_document] + assert not list(backend_1.query_statements_by_ids(["2"])) + assert not list(backend_2.query_statements_by_ids(["1"])) + assert list(backend_2.query_statements_by_ids(["2"])) == [index_2_document] diff --git a/tests/conftest.py b/tests/conftest.py index b6ca603fb..9d1f3bd78 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,8 +18,10 @@ clickhouse_backend, clickhouse_lrs_backend, es, + es_backend, es_data_stream, es_forwarding, + es_lrs_backend, events, fs_backend, ldp_backend, diff --git a/tests/fixtures/backends.py b/tests/fixtures/backends.py index 88fd315fa..13e0b2f22 100644 --- a/tests/fixtures/backends.py +++ b/tests/fixtures/backends.py @@ -22,6 +22,7 @@ from pymongo.errors import CollectionInvalid from ralph.backends.data.clickhouse import ClickHouseDataBackend +from ralph.backends.data.es import ESDataBackend from ralph.backends.data.fs import FSDataBackend, FSDataBackendSettings from ralph.backends.data.ldp import LDPDataBackend from ralph.backends.data.s3 import S3DataBackend, S3DataBackendSettings @@ -30,6 +31,7 @@ from ralph.backends.database.es import ESDatabase from ralph.backends.database.mongo import MongoDatabase from ralph.backends.lrs.clickhouse import ClickHouseLRSBackend +from ralph.backends.lrs.es import ESLRSBackend from ralph.backends.storage.s3 import S3Storage from ralph.backends.storage.swift import SwiftStorage from ralph.conf import ClickhouseClientOptions, Settings, core_settings @@ -115,9 +117,7 @@ def get_mongo_test_backend(): def get_es_fixture(host=ES_TEST_HOSTS, index=ES_TEST_INDEX): - """Create / delete an ElasticSearch test index and yields an instantiated - client. - """ + """Create / delete an Elasticsearch test index and yield an instantiated client.""" client = Elasticsearch(host) try: client.indices.create(index=index) @@ -131,16 +131,15 @@ def get_es_fixture(host=ES_TEST_HOSTS, index=ES_TEST_INDEX): @pytest.fixture def es(): - """Yield an ElasticSearch test client. See get_es_fixture above.""" + """Yield an Elasticsearch test client. See get_es_fixture above.""" # pylint: disable=invalid-name - for es_client in get_es_fixture(): yield es_client @pytest.fixture def es_forwarding(): - """Yield a second ElasticSearch test client. See get_es_fixture above.""" + """Yield a second Elasticsearch test client. See get_es_fixture above.""" for es_client in get_es_fixture(index=ES_TEST_FORWARDING_INDEX): yield es_client @@ -258,7 +257,7 @@ def clickhouse(): @pytest.fixture def es_data_stream(): - """Create / delete an ElasticSearch test datastream and yields an instantiated + """Create / delete an Elasticsearch test datastream and yield an instantiated client. 
""" client = Elasticsearch(ES_TEST_HOSTS) @@ -277,7 +276,7 @@ def es_data_stream(): "date_detection": True, "numeric_detection": True, # Note: We define an explicit mapping of the `timestamp` field to allow the - # ElasticSearch database to be queried even if no document has been inserted + # Elasticsearch database to be queried even if no document has been inserted # before. "properties": { "timestamp": { @@ -394,6 +393,47 @@ def get_clickhouse_lrs_backend(): return get_clickhouse_lrs_backend +@pytest.fixture +def es_backend(): + """Return the `get_es_data_backend` function.""" + + def get_es_data_backend(): + """Return an instance of ESDataBackend.""" + settings = ESDataBackend.settings_class( + ALLOW_YELLOW_STATUS=False, + CLIENT_OPTIONS={"ca_certs": None, "verify_certs": None}, + DEFAULT_CHUNK_SIZE=500, + DEFAULT_INDEX=ES_TEST_INDEX, + HOSTS=ES_TEST_HOSTS, + LOCALE_ENCODING="utf8", + REFRESH_AFTER_WRITE=True, + ) + return ESDataBackend(settings) + + return get_es_data_backend + + +@pytest.fixture +def es_lrs_backend(): + """Return the `get_es_lrs_backend` function.""" + + def get_es_lrs_backend(index: str = ES_TEST_INDEX): + """Return an instance of ESLRSBackend.""" + settings = ESLRSBackend.settings_class( + ALLOW_YELLOW_STATUS=False, + CLIENT_OPTIONS={"ca_certs": None, "verify_certs": None}, + DEFAULT_CHUNK_SIZE=500, + DEFAULT_INDEX=index, + HOSTS=ES_TEST_HOSTS, + LOCALE_ENCODING="utf8", + POINT_IN_TIME_KEEP_ALIVE="1m", + REFRESH_AFTER_WRITE=True, + ) + return ESLRSBackend(settings) + + return get_es_lrs_backend + + @pytest.fixture def swift(): """Return get_swift_storage function.""" diff --git a/tests/test_conf.py b/tests/test_conf.py index 846288a14..346fb1564 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -40,7 +40,12 @@ def test_conf_settings_field_value_priority(fs, monkeypatch): @pytest.mark.parametrize( "value,expected", - [("foo", ("foo",)), (("foo",), ("foo",)), ("foo,bar,baz", ("foo", "bar", "baz"))], + [ + ("foo", ("foo",)), + (("foo",), ("foo",)), + (["foo"], ("foo",)), + ("foo,bar,baz", ("foo", "bar", "baz")), + ], ) def test_conf_comma_separated_list_with_valid_values(value, expected, monkeypatch): """Test the CommaSeparatedTuple pydantic data type with valid values.""" @@ -49,7 +54,7 @@ def test_conf_comma_separated_list_with_valid_values(value, expected, monkeypatc assert Settings().BACKENDS.DATABASE.ES.HOSTS == expected -@pytest.mark.parametrize("value", [{}, [], None]) +@pytest.mark.parametrize("value", [{}, None]) def test_conf_comma_separated_list_with_invalid_values(value): """Test the CommaSeparatedTuple pydantic data type with invalid values.""" with pytest.raises(TypeError, match="Invalid comma separated list"): From 84b72b4b2b6d433b8582e50c5509473600b2f544 Mon Sep 17 00:00:00 2001 From: SergioSim Date: Tue, 20 Jun 2023 11:25:45 +0200 Subject: [PATCH 14/65] =?UTF-8?q?=E2=9C=A8(backends)=20add=20FSLRSBackend?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We want to provide an LRS backend that could be added to Ralph's LRS without any additional dependencies. 
--- src/ralph/backends/data/fs.py | 10 +- src/ralph/backends/lrs/base.py | 4 +- src/ralph/backends/lrs/fs.py | 378 +++++++++++++++++++++++++++++++ tests/backends/data/test_base.py | 8 +- tests/backends/lrs/test_fs.py | 287 +++++++++++++++++++++++ tests/conftest.py | 1 + tests/fixtures/backends.py | 20 ++ 7 files changed, 697 insertions(+), 11 deletions(-) create mode 100644 src/ralph/backends/lrs/fs.py create mode 100644 tests/backends/lrs/test_fs.py diff --git a/src/ralph/backends/data/fs.py b/src/ralph/backends/data/fs.py index 1cf89d2ea..54ba86b28 100644 --- a/src/ralph/backends/data/fs.py +++ b/src/ralph/backends/data/fs.py @@ -58,11 +58,11 @@ class FSDataBackend(HistoryMixin, BaseDataBackend): def __init__(self, settings: settings_class = None): """Creates the default target directory if it does not exist.""" - settings = settings if settings else self.settings_class() - self.default_chunk_size = settings.DEFAULT_CHUNK_SIZE - self.default_directory = settings.DEFAULT_DIRECTORY_PATH - self.default_query_string = settings.DEFAULT_QUERY_STRING - self.locale_encoding = settings.LOCALE_ENCODING + self.settings = settings if settings else self.settings_class() + self.default_chunk_size = self.settings.DEFAULT_CHUNK_SIZE + self.default_directory = self.settings.DEFAULT_DIRECTORY_PATH + self.default_query_string = self.settings.DEFAULT_QUERY_STRING + self.locale_encoding = self.settings.LOCALE_ENCODING if not self.default_directory.is_dir(): msg = "Default directory doesn't exist, creating: %s" diff --git a/src/ralph/backends/lrs/base.py b/src/ralph/backends/lrs/base.py index 957ccabc4..4beb06680 100644 --- a/src/ralph/backends/lrs/base.py +++ b/src/ralph/backends/lrs/base.py @@ -20,8 +20,8 @@ class StatementQueryResult: """Result of an LRS statements query.""" statements: List[dict] - pit_id: str - search_after: str + pit_id: Optional[str] + search_after: Optional[str] class AgentParameters(BaseModel): diff --git a/src/ralph/backends/lrs/fs.py b/src/ralph/backends/lrs/fs.py new file mode 100644 index 000000000..648719150 --- /dev/null +++ b/src/ralph/backends/lrs/fs.py @@ -0,0 +1,378 @@ +"""FileSystem LRS backend for Ralph.""" + +import logging +from datetime import datetime +from io import IOBase +from typing import Iterable, List, Literal, Union +from uuid import UUID + +from ralph.backends.data.base import BaseOperationType +from ralph.backends.data.fs import FSDataBackend, FSDataBackendSettings +from ralph.backends.lrs.base import ( + AgentParameters, + BaseLRSBackend, + BaseLRSBackendSettings, + StatementParameters, + StatementQueryResult, +) + +logger = logging.getLogger(__name__) + + +class FSLRSBackendSettings(BaseLRSBackendSettings, FSDataBackendSettings): + """FileSystem LRS backend default configuration. + + Attributes: + DEFAULT_LRS_FILE (str): The default LRS filename to store statements. + """ + + DEFAULT_LRS_FILE: str = "fs_lrs.jsonl" + + +class FSLRSBackend(BaseLRSBackend, FSDataBackend): + """FileSystem LRS Backend.""" + + settings_class = FSLRSBackendSettings + + def write( # pylint: disable=too-many-arguments + self, + data: Union[IOBase, Iterable[bytes], Iterable[dict]], + target: Union[None, str] = None, + chunk_size: Union[None, int] = None, + ignore_errors: bool = False, + operation_type: Union[None, BaseOperationType] = None, + ) -> int: + """Write data records to the target file and return their count. + + See `FSDataBackend.write`. 
+        """
+        target = target if target else self.settings.DEFAULT_LRS_FILE
+        return super().write(data, target, chunk_size, ignore_errors, operation_type)
+
+    def query_statements(self, params: StatementParameters) -> StatementQueryResult:
+        """Return the results of a statements query using xAPI parameters."""
+        filters = []
+        self._add_filter_by_id(filters, params.statementId)
+        self._add_filter_by_agent(filters, params.agent, params.related_agents)
+        self._add_filter_by_authority(filters, params.authority)
+        self._add_filter_by_verb(filters, params.verb)
+        self._add_filter_by_activity(
+            filters, params.activity, params.related_activities
+        )
+        self._add_filter_by_registration(filters, params.registration)
+        self._add_filter_by_timestamp_since(filters, params.since)
+        self._add_filter_by_timestamp_until(filters, params.until)
+        self._add_filter_by_search_after(filters, params.search_after)
+
+        limit = params.limit
+        statements_count = 0
+        search_after = None
+        statements = []
+        for statement in self.read(query=self.settings.DEFAULT_LRS_FILE):
+            for query_filter in filters:
+                if not query_filter(statement):
+                    break
+            else:
+                statements.append(statement)
+                statements_count += 1
+                if limit and statements_count == limit:
+                    search_after = statements[-1].get("id")
+                    break
+
+        if params.ascending:
+            statements.reverse()
+        return StatementQueryResult(
+            statements=statements,
+            pit_id=None,
+            search_after=search_after,
+        )
+
+    def query_statements_by_ids(self, ids: List[str]) -> List:
+        """Return the list of statements matching the given ids from the database."""
+        statement_ids = set(ids)
+        statements = []
+        for statement in self.read(query=self.settings.DEFAULT_LRS_FILE):
+            if statement.get("id") in statement_ids:
+                statements.append(statement)
+
+        return statements
+
+    @staticmethod
+    def _add_filter_by_agent(
+        filters: list, agent: Union[AgentParameters, None], related: Union[bool, None]
+    ) -> None:
+        """Add agent filters to `filters` if `agent` is set."""
+        if not agent:
+            return
+
+        FSLRSBackend._add_filter_by_mbox(filters, agent.mbox, related)
+        FSLRSBackend._add_filter_by_sha1sum(filters, agent.mbox_sha1sum, related)
+        FSLRSBackend._add_filter_by_openid(filters, agent.openid, related)
+        FSLRSBackend._add_filter_by_account(
+            filters, agent.account__name, agent.account__home_page, related
+        )
+
+    @staticmethod
+    def _add_filter_by_authority(
+        filters: list,
+        authority: Union[AgentParameters, None],
+    ) -> None:
+        """Add authority filters to `filters` if `authority` is set."""
+        if not authority:
+            return
+
+        FSLRSBackend._add_filter_by_mbox(filters, authority.mbox, field="authority")
+        FSLRSBackend._add_filter_by_sha1sum(
+            filters, authority.mbox_sha1sum, field="authority"
+        )
+        FSLRSBackend._add_filter_by_openid(filters, authority.openid, field="authority")
+        FSLRSBackend._add_filter_by_account(
+            filters,
+            authority.account__name,
+            authority.account__home_page,
+            field="authority",
+        )
+
+    @staticmethod
+    def _add_filter_by_id(filters: list, statement_id: Union[str, None]) -> None:
+        """Add the `match_statement_id` filter if `statement_id` is set."""
+
+        def match_statement_id(statement: dict) -> bool:
+            """Return `True` if the statement has the given `statement_id`."""
+            return statement.get("id") == statement_id
+
+        if statement_id:
+            filters.append(match_statement_id)
+
+    @staticmethod
+    def _get_related_agents(statement: dict) -> Iterable[dict]:
+        """Yield the statement's actor, object, authority, instructor and team."""
+        yield statement.get("actor", {})
+        yield statement.get("object", {})
+        yield statement.get("authority", {})
+        context = statement.get("context", {})
+
yield context.get("instructor", {}) + yield context.get("team", {}) + + @staticmethod + def _add_filter_by_mbox( + filters: list, + mbox: Union[str, None], + related: Union[bool, None] = False, + field: Literal["actor", "authority"] = "actor", + ) -> None: + """Add the `match_mbox` filter if `mbox` is set.""" + + def match_mbox(statement: dict) -> bool: + """Return `True` if the statement has the given `actor.mbox`.""" + return statement.get(field, {}).get("mbox") == mbox + + def match_related_mbox(statement: dict) -> bool: + """Return `True` if the statement has any agent matching `mbox`.""" + for agent in FSLRSBackend._get_related_agents(statement): + if agent.get("mbox") == mbox: + return True + + statement_object = statement.get("object", {}) + if statement_object.get("objectType") == "SubStatement": + return match_related_mbox(statement_object) + return False + + if mbox: + filters.append(match_related_mbox if related else match_mbox) + + @staticmethod + def _add_filter_by_sha1sum( + filters: list, + sha1sum: Union[str, None], + related: Union[bool, None] = False, + field: Literal["actor", "authority"] = "actor", + ) -> None: + """Add the `match_sha1sum` filter if `sha1sum` is set.""" + + def match_sha1sum(statement: dict) -> bool: + """Return `True` if the statement has the given `actor.sha1sum`.""" + return statement.get(field, {}).get("mbox_sha1sum") == sha1sum + + def match_related_sha1sum(statement: dict) -> bool: + """Return `True` if the statement has any agent matching `sha1sum`.""" + for agent in FSLRSBackend._get_related_agents(statement): + if agent.get("mbox_sha1sum") == sha1sum: + return True + + statement_object = statement.get("object", {}) + if statement_object.get("objectType") == "SubStatement": + return match_related_sha1sum(statement_object) + return False + + if sha1sum: + filters.append(match_related_sha1sum if related else match_sha1sum) + + @staticmethod + def _add_filter_by_openid( + filters: list, + openid: Union[str, None], + related: Union[bool, None] = False, + field: Literal["actor", "authority"] = "actor", + ) -> None: + """Add the `match_openid` filter if `openid` is set.""" + + def match_openid(statement: dict) -> bool: + """Return `True` if the statement has the given `actor.openid`.""" + return statement.get(field, {}).get("openid") == openid + + def match_related_openid(statement: dict) -> bool: + """Return `True` if the statement has any agent matching `openid`.""" + for agent in FSLRSBackend._get_related_agents(statement): + if agent.get("openid") == openid: + return True + + statement_object = statement.get("object", {}) + if statement_object.get("objectType") == "SubStatement": + return match_related_openid(statement_object) + return False + + if openid: + filters.append(match_related_openid if related else match_openid) + + @staticmethod + def _add_filter_by_account( + filters: list, + name: Union[str, None], + home_page: Union[str, None], + related: Union[bool, None] = False, + field: Literal["actor", "authority"] = "actor", + ) -> None: + """Add the `match_account` filter if `name` or `home_page` is set.""" + + def match_account(statement: dict) -> bool: + """Return `True` if the statement has the given `actor.account`.""" + account = statement.get(field, {}).get("account", {}) + return account.get("name") == name and account.get("homePage") == home_page + + def match_related_account(statement: dict) -> bool: + """Return `True` if the statement has any agent matching the account.""" + for agent in FSLRSBackend._get_related_agents(statement): 
+ account = agent.get("account", {}) + if account.get("name") == name and account.get("homePage") == home_page: + return True + + statement_object = statement.get("object", {}) + if statement_object.get("objectType") == "SubStatement": + return match_related_account(statement_object) + return False + + if name and home_page: + filters.append(match_related_account if related else match_account) + + @staticmethod + def _add_filter_by_verb(filters: list, verb_id: Union[str, None]) -> None: + """Add the `match_verb_id` filter if `verb_id` is set.""" + + def match_verb_id(statement: dict) -> bool: + """Return `True` if the statement has the given `verb.id`.""" + return statement.get("verb", {}).get("id") == verb_id + + if verb_id: + filters.append(match_verb_id) + + @staticmethod + def _add_filter_by_activity( + filters: list, object_id: Union[str, None], related: Union[bool, None] + ) -> None: + """Add the `match_object_id` filter if `object_id` is set.""" + + def match_object_id(statement: dict) -> bool: + """Return `True` if the statement has the given `object.id`.""" + return statement.get("object", {}).get("id") == object_id + + def match_related_object_id(statement: dict) -> bool: + """Return `True` if the statement has any object.id matching `object_id`.""" + statement_object = statement.get("object", {}) + if statement_object.get("id") == object_id: + return True + activities = statement.get("context", {}).get("contextActivities", {}) + for activity in activities.values(): + if isinstance(activity, dict): + if activity.get("id") == object_id: + return True + else: + for sub_activity in activity: + if sub_activity.get("id") == object_id: + return True + if statement_object.get("objectType") == "SubStatement": + return match_related_object_id(statement_object) + + return False + + if object_id: + filters.append(match_related_object_id if related else match_object_id) + + @staticmethod + def _add_filter_by_timestamp_since( + filters: list, timestamp: Union[datetime, None] + ) -> None: + """Add the `match_since` filter if `timestamp` is set.""" + + def match_since(statement: dict) -> bool: + """Return `True` if the statement was created after `timestamp`.""" + try: + statement_timestamp = datetime.fromisoformat(statement.get("timestamp")) + except (TypeError, ValueError) as error: + msg = "Statement with id=%s contains unparsable timestamp=%s" + logger.debug(msg, statement.get("id"), error) + return False + return statement_timestamp > timestamp + + if timestamp: + filters.append(match_since) + + @staticmethod + def _add_filter_by_timestamp_until( + filters: list, timestamp: Union[datetime, None] + ) -> None: + """Add the `match_until` function if `timestamp` is set.""" + + def match_until(statement: dict) -> bool: + """Return `True` if the statement was created before `timestamp`.""" + try: + statement_timestamp = datetime.fromisoformat(statement.get("timestamp")) + except (TypeError, ValueError) as error: + msg = "Statement with id=%s contains unparsable timestamp=%s" + logger.debug(msg, statement.get("id"), error) + return False + return statement_timestamp <= timestamp + + if timestamp: + filters.append(match_until) + + @staticmethod + def _add_filter_by_search_after( + filters: list, search_after: Union[str, None] + ) -> None: + """Add the `match_search_after` filter if `search_after` is set.""" + search_after_state = {"state": False} + + def match_search_after(statement: dict) -> bool: + """Return `True` if the statement was created after `search_after`.""" + if 
search_after_state["state"]: + return True + if statement.get("id") == search_after: + search_after_state["state"] = True + return False + + if search_after: + filters.append(match_search_after) + + @staticmethod + def _add_filter_by_registration( + filters: list, registration: Union[UUID, None] + ) -> None: + """Add the `match_registration` filter if `registration` is set.""" + registration_str = str(registration) + + def match_registration(statement: dict) -> bool: + """Return `True` if the statement has the given `context.registration`.""" + return statement.get("context", {}).get("registration") == registration_str + + if registration: + filters.append(match_registration) diff --git a/tests/backends/data/test_base.py b/tests/backends/data/test_base.py index e86e72e08..deacdddfd 100644 --- a/tests/backends/data/test_base.py +++ b/tests/backends/data/test_base.py @@ -16,13 +16,13 @@ ], ) def test_backends_data_base_enforce_query_checks_with_valid_input(value, expected): - """Tests the enforce_query_checks function given valid input.""" + """Test the enforce_query_checks function given valid input.""" class MockBaseDataBackend(BaseDataBackend): """A class mocking the base database class.""" def __init__(self, settings=None): - """Instantiates the Mock data backend.""" + """Instantiate the Mock data backend.""" @enforce_query_checks def read(self, query=None): # pylint: disable=no-self-use,arguments-differ @@ -57,13 +57,13 @@ def write(self): # pylint: disable=arguments-differ,missing-function-docstring def test_backends_data_base_enforce_query_checks_with_invalid_input( value, error, caplog ): - """Tests the enforce_query_checks function given invalid input.""" + """Test the enforce_query_checks function given invalid input.""" class MockBaseDataBackend(BaseDataBackend): """A class mocking the base database class.""" def __init__(self, settings=None): - """Instantiates the Mock data backend.""" + """Instantiate the Mock data backend.""" @enforce_query_checks def read(self, query=None): # pylint: disable=no-self-use,arguments-differ diff --git a/tests/backends/lrs/test_fs.py b/tests/backends/lrs/test_fs.py new file mode 100644 index 000000000..2a3968719 --- /dev/null +++ b/tests/backends/lrs/test_fs.py @@ -0,0 +1,287 @@ +"""Tests for Ralph FileSystem LRS backend.""" + +import pytest + +from ralph.backends.lrs.base import StatementParameters + + +@pytest.mark.parametrize( + "params,expected_statement_ids", + [ + # 0. Default query. + ({}, ["0", "1", "2", "3", "4", "5", "6", "7", "8"]), + # 1. Query by statementId. + ({"statementId": "1"}, ["1"]), + # 2. Query by statementId and agent with mbox IFI. + ({"statementId": "1", "agent": {"mbox": "mailto:foo@bar.baz"}}, ["1"]), + # 3. Query by statementId and agent with mbox IFI (no match). + ({"statementId": "1", "agent": {"mbox": "mailto:bar@bar.baz"}}, []), + # 4. Query by statementId and agent with mbox_sha1sum IFI. + ({"statementId": "0", "agent": {"mbox_sha1sum": "foo_sha1sum"}}, ["0"]), + # 5. Query by agent with mbox_sha1sum IFI (no match). + ({"statementId": "0", "agent": {"mbox_sha1sum": "bar_sha1sum"}}, []), + # 6. Query by statementId and agent with openid IFI. + ({"statementId": "2", "agent": {"openid": "foo_openid"}}, ["2"]), + # 7. Query by statementId and agent with openid IFI (no match). + ({"statementId": "2", "agent": {"openid": "bar_openid"}}, []), + # 8. Query by statementId and agent with account IFI. 
+ ( + { + "statementId": "3", + "agent": { + "account__home_page": "foo_home", + "account__name": "foo_name", + }, + }, + ["3"], + ), + # 9. Query by statementId and agent with account IFI (no match). + ( + { + "statementId": "3", + "agent": { + "account__home_page": "foo_home", + "account__name": "bar_name", + }, + }, + [], + ), + # 10. Query by verb and activity. + ({"verb": "foo_verb", "activity": "foo_object"}, ["1", "2"]), + # 11. Query by timerange (with since/until). + ( + { + "since": "2021-06-24T00:00:20.194929+00:00", + "until": "2023-06-24T00:00:20.194929+00:00", + }, + ["1", "3"], + ), + # 12. Query by timerange (with until). + ( + { + "until": "2023-06-24T00:00:20.194929+00:00", + }, + ["0", "1", "3"], + ), + # 13. Query with pagination. + ({"search_after": "1"}, ["2", "3", "4", "5", "6", "7", "8"]), + # 14. Query with pagination and limit. + ({"search_after": "1", "limit": 2}, ["2", "3"]), + # 15. Query with pagination and limit. + ({"search_after": "3", "limit": 5}, ["4", "5", "6", "7", "8"]), + # 16. Query in ascending order. + ({"ascending": True}, ["8", "7", "6", "5", "4", "3", "2", "1", "0"]), + # 17. Query by registration. + ({"registration": "b0d0e57d-9fbf-42e3-ba60-85e0be6f709d"}, ["2", "4"]), + # 18. Query by activity without related activities. + ({"activity": "bar_object", "related_activities": False}, ["0"]), + # 19. Query by activity with related activities. + ( + {"activity": "bar_object", "related_activities": True}, + ["0", "1", "2", "4", "5"], + ), + # 20. Query by related agent with mbox IFI. + ( + {"agent": {"mbox": "mailto:foo@bar.baz"}, "related_agents": True}, + ["1", "3", "4", "5", "6", "7"], + ), + # 21. Query by related agent with mbox_sha1sum IFI. + ( + {"agent": {"mbox_sha1sum": "foo_sha1sum"}, "related_agents": True}, + ["0", "1", "2", "5", "6", "7", "8"], + ), + # 22. Query by related agent with openid IFI. + ( + {"agent": {"openid": "foo_openid"}, "related_agents": True}, + ["0", "2", "4", "5", "6", "7"], + ), + # 23. Query by related agent with account IFI. + ( + { + "agent": { + "account__home_page": "foo_home", + "account__name": "foo_name", + }, + "related_agents": True, + }, + ["1", "2", "3", "4", "5", "7"], + ), + # 24. Query by authority with mbox IFI. + ({"authority": {"mbox": "mailto:foo@bar.baz"}}, ["4"]), + # 25. Query by authority with mbox IFI (no match). + ({"authority": {"mbox": "mailto:bar@bar.baz"}}, []), + # 26. Query by authority with mbox_sha1sum IFI. + ({"authority": {"mbox_sha1sum": "foo_sha1sum"}}, ["7"]), + # 27. Query by authority with mbox_sha1sum IFI (no match). + ({"authority": {"mbox_sha1sum": "bar_sha1sum"}}, []), + # 28. Query by authority with openid IFI. + ({"authority": {"openid": "foo_openid"}}, ["6"]), + # 29. Query by authority with openid IFI (no match). + ({"authority": {"openid": "bar_openid"}}, []), + # 30. Query by authority with account IFI. + ( + { + "authority": { + "account__home_page": "foo_home", + "account__name": "foo_name", + }, + }, + ["2"], + ), + # 31. Query by authority with account IFI (no match). + ( + { + "authority": { + "account__home_page": "foo_home", + "account__name": "bar_name", + }, + }, + [], + ), + ], +) +def test_backends_lrs_fs_lrs_backend_query_statements_query( + params, expected_statement_ids, fs_lrs_backend +): + """Test the `FSLRSBackend.query_statements` method, given valid statement + parameters, should return the expected statements. 
+ """ + statements = [ + { + "id": "0", + "actor": {"mbox_sha1sum": "foo_sha1sum"}, + "verb": {"id": "foo_verb"}, + "object": {"id": "bar_object", "objectType": "Activity"}, + "context": { + "registration": "de867099-77ee-453b-949e-2c1933734436", + "instructor": {"mbox": "mailto:bar@bar.baz"}, + "team": {"openid": "foo_openid"}, + }, + "timestamp": "2021-06-24T00:00:20.194929+00:00", + }, + { + "id": "1", + "actor": {"mbox": "mailto:foo@bar.baz"}, + "verb": {"id": "foo_verb"}, + "object": { + "id": "foo_object", + "account": {"name": "foo_name", "homePage": "foo_home"}, + }, + "context": { + "instructor": {"mbox_sha1sum": "foo_sha1sum"}, + "contextActivities": {"parent": {"id": "bar_object"}}, + }, + "timestamp": "2021-06-24T00:00:20.194930+00:00", + }, + { + "id": "2", + "actor": {"openid": "foo_openid"}, + "verb": {"id": "foo_verb"}, + "object": {"id": "foo_object", "objectType": "Activity"}, + "context": { + "registration": "b0d0e57d-9fbf-42e3-ba60-85e0be6f709d", + "contextActivities": {"grouping": [{"id": "bar_object"}]}, + "team": {"mbox_sha1sum": "foo_sha1sum"}, + }, + "timestamp": "UNPARSABLE-2022-06-24T00:00:20.194929+00:00", + "authority": {"account": {"name": "foo_name", "homePage": "foo_home"}}, + }, + { + "id": "3", + "actor": {"account": {"name": "foo_name", "homePage": "foo_home"}}, + "verb": {"id": "bar_verb"}, + "object": {"objectType": "Agent", "mbox": "mailto:foo@bar.baz"}, + "timestamp": "2023-06-24T00:00:20.194929+00:00", + }, + { + "id": "4", + "verb": {"id": "bar_verb"}, + "object": {"id": "foo_object"}, + "context": { + "registration": "b0d0e57d-9fbf-42e3-ba60-85e0be6f709d", + "contextActivities": { + "category": [{"id": "foo_object"}, {"id": "baz_object"}], + "other": [{"id": "bar_object"}, {"id": "baz_object"}], + }, + "instructor": {"openid": "foo_openid"}, + "team": {"account": {"name": "foo_name", "homePage": "foo_home"}}, + }, + "timestamp": "2024-06-24T00:00:20.194929+00:00", + "authority": {"mbox": "mailto:foo@bar.baz"}, + }, + { + "id": "5", + "actor": { + "mbox_sha1sum": "foo_sha1sum", + }, + "verb": {"id": "qux_verb"}, + "object": { + "objectType": "SubStatement", + "actor": {"openid": "foo_openid"}, + "verb": {"id": "bar_verb"}, + "object": {"id": "bar_object", "objectType": "Activity"}, + "context": { + "instructor": { + "account": {"name": "foo_name", "homePage": "foo_home"} + }, + "team": { + "mbox": "mailto:foo@bar.baz", + }, + }, + }, + }, + { + "id": "6", + "object": { + "objectType": "Agent", + "mbox_sha1sum": "foo_sha1sum", + }, + "context": {"instructor": {"mbox": "mailto:foo@bar.baz"}}, + "authority": {"openid": "foo_openid"}, + }, + { + "id": "7", + "object": {"objectType": "Agent", "openid": "foo_openid"}, + "context": { + "instructor": {"account": {"name": "foo_name", "homePage": "foo_home"}}, + "team": { + "mbox": "mailto:foo@bar.baz", + }, + }, + "authority": {"mbox_sha1sum": "foo_sha1sum"}, + }, + { + "id": "8", + "object": { + "objectType": "SubStatement", + "actor": {"mbox_sha1sum": "foo_sha1sum"}, + }, + }, + ] + backend = fs_lrs_backend() + backend.write(statements) + result = backend.query_statements(StatementParameters(**params)) + ids = [statement.get("id") for statement in result.statements] + assert ids == expected_statement_ids + + +def test_backends_lrs_fs_lrs_backend_query_statements_by_ids(fs_lrs_backend): + """Test the `FSLRSBackend.query_statements_by_ids` method, given a valid search + query, should return the expected results. 
+    """
+    backend = fs_lrs_backend()
+    assert not backend.query_statements_by_ids(["foo"])
+    backend.write(
+        [
+            {"id": "foo"},
+            {"id": "bar"},
+            {"id": "baz"},
+        ]
+    )
+    assert not backend.query_statements_by_ids([])
+    assert not backend.query_statements_by_ids(["qux", "foobar"])
+    assert backend.query_statements_by_ids(["foo"]) == [{"id": "foo"}]
+    assert backend.query_statements_by_ids(["bar", "baz"]) == [
+        {"id": "bar"},
+        {"id": "baz"},
+    ]
diff --git a/tests/conftest.py b/tests/conftest.py
index 9d1f3bd78..caa27c0d3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -24,6 +24,7 @@
     es_lrs_backend,
     events,
     fs_backend,
+    fs_lrs_backend,
     ldp_backend,
     lrs,
     mongo,
diff --git a/tests/fixtures/backends.py b/tests/fixtures/backends.py
index 13e0b2f22..60efcf7a7 100644
--- a/tests/fixtures/backends.py
+++ b/tests/fixtures/backends.py
@@ -32,6 +32,7 @@
 from ralph.backends.database.mongo import MongoDatabase
 from ralph.backends.lrs.clickhouse import ClickHouseLRSBackend
 from ralph.backends.lrs.es import ESLRSBackend
+from ralph.backends.lrs.fs import FSLRSBackend
 from ralph.backends.storage.s3 import S3Storage
 from ralph.backends.storage.swift import SwiftStorage
 from ralph.conf import ClickhouseClientOptions, Settings, core_settings
@@ -163,6 +164,25 @@ def get_fs_data_backend(path: str = "foo"):
     return get_fs_data_backend
 
 
+@pytest.fixture
+def fs_lrs_backend(fs, settings_fs):
+    """Return the `get_fs_lrs_backend` function."""
+    # pylint: disable=invalid-name,redefined-outer-name,unused-argument
+    fs.create_dir("foo")
+
+    def get_fs_lrs_backend(path: str = "foo"):
+        """Return an instance of FSLRSBackend."""
+        settings = FSLRSBackend.settings_class(
+            DEFAULT_CHUNK_SIZE=1024,
+            DEFAULT_DIRECTORY_PATH=path,
+            DEFAULT_QUERY_STRING="*",
+            LOCALE_ENCODING="utf8",
+        )
+        return FSLRSBackend(settings)
+
+    return get_fs_lrs_backend
+
+
 def get_mongo_fixture(
     connection_uri=MONGO_TEST_CONNECTION_URI,
     database=MONGO_TEST_DATABASE,

From 4b8fb5349a46ae96523691de25e2900719800d37 Mon Sep 17 00:00:00 2001
From: Wilfried BARADAT
Date: Mon, 7 Aug 2023 15:56:21 +0200
Subject: [PATCH 15/65] =?UTF-8?q?=E2=99=BB=EF=B8=8F(backends)=20move=20bac?=
 =?UTF-8?q?kends=20utils=20method=20to=20`utils.py`?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Methods `read_raw` and `parse_bytes_to_dict` are generic and used by
multiple backends. Moving them to file `utils.py`.
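A quick sketch of the two helpers after the move (illustrative, not part
of the patch): for well-formed documents they are inverses of each other.

    import logging

    from ralph.utils import parse_bytes_to_dict, read_raw

    logger = logging.getLogger(__name__)
    documents = [{"id": 1}, {"id": 2}]

    # Encode dictionaries to JSON bytes with the given encoding...
    raw = list(read_raw(documents, encoding="utf8"))

    # ...and decode them back. With ignore_errors=True, invalid entries are
    # logged and skipped; with ignore_errors=False, a BackendException is
    # raised on the first decoding failure.
    decoded = list(parse_bytes_to_dict(raw, True, logger))
    assert decoded == documents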
--- src/ralph/utils.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/src/ralph/utils.py b/src/ralph/utils.py index c97b9f8c2..ff6bb5b23 100644 --- a/src/ralph/utils.py +++ b/src/ralph/utils.py @@ -2,14 +2,17 @@ import asyncio import datetime +import json import logging import operator from functools import reduce from importlib import import_module -from typing import List, Union +from typing import Any, Dict, Iterable, Iterator, List, Union from pydantic import BaseModel +from ralph.exceptions import BackendException + # Taken from Django utilities # https://docs.djangoproject.com/en/3.1/_modules/django/utils/module_loading/#import_string @@ -142,3 +145,24 @@ def statements_are_equivalent(statement_1: dict, statement_2: dict): if any(statement_1.get(field) != statement_2.get(field) for field in fields): return False return True + + +def parse_bytes_to_dict( + raw_documents: Iterable[bytes], ignore_errors: bool, logger_class: logging.Logger +) -> Iterator[dict]: + """Read the `raw_documents` Iterable and yield dictionaries.""" + for raw_document in raw_documents: + try: + yield json.loads(raw_document) + except (TypeError, json.JSONDecodeError) as error: + msg = "Failed to decode JSON: %s, for document: %s" + logger_class.error(msg, error, raw_document) + if ignore_errors: + continue + raise BackendException(msg % (error, raw_document)) from error + + +def read_raw(documents: Iterable[Dict[str, Any]], encoding: str) -> Iterator[bytes]: + """Read the `documents` Iterable with the `encoding` and yield bytes.""" + for document in documents: + yield json.dumps(document).encode(encoding) From a144698f3da2eca42dbc5189095ad4d63ead7131 Mon Sep 17 00:00:00 2001 From: Wilfried BARADAT Date: Wed, 21 Jun 2023 09:50:35 +0200 Subject: [PATCH 16/65] =?UTF-8?q?=E2=9C=A8(backends)=20add=20base=20for=20?= =?UTF-8?q?async=20backends?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add asynchronous base interface for async backends such as async_es or async_mongo --- src/ralph/backends/data/base.py | 182 ++++++++++++++++++++++++++++++-- 1 file changed, 175 insertions(+), 7 deletions(-) diff --git a/src/ralph/backends/data/base.py b/src/ralph/backends/data/base.py index 9e70cc2c7..107fdbe2b 100644 --- a/src/ralph/backends/data/base.py +++ b/src/ralph/backends/data/base.py @@ -128,7 +128,7 @@ def validate_query(self, query: Union[str, dict, BaseQuery] = None) -> BaseQuery def status(self) -> DataBackendStatus: """Implements data backend checks (e.g. connection, cluster status). - Returns: + Return: DataBackendStatus: The status of the data backend. """ @@ -144,11 +144,11 @@ def list( details (bool): Get detailed container information instead of just names. new (bool): Given the history, list only not already read containers. - Yields: + Yield: str: If `details` is False. dict: If `details` is True. - Raises: + Raise: BackendException: If a failure occurs. BackendParameterException: If a backend argument value is not valid. """ @@ -181,11 +181,11 @@ def read( are be ignored and logged. If `False` (default), a `BackendException` is raised if an error occurs. - Yields: + Yield: dict: If `raw_output` is False. bytes: If `raw_output` is True. - Raises: + Raise: BackendException: If a failure during the read operation occurs and `ignore_errors` is set to `False`. BackendParameterException: If a backend argument value is not valid. 
@@ -216,11 +216,179 @@ def write( # pylint: disable=too-many-arguments If `operation_type` is `None`, the `default_operation_type` is used instead. See `BaseOperationType`. - Returns: + Return: int: The number of written records. - Raises: + Raise: BackendException: If a failure during the write operation occurs and `ignore_errors` is set to `False`. BackendParameterException: If a backend argument value is not valid. """ + + +def async_enforce_query_checks(method): + """Enforces query argument type checking for methods using it.""" + + @functools.wraps(method) + async def wrapper(*args, **kwargs): + """Wrap method execution.""" + query = kwargs.pop("query", None) + self_ = args[0] + + return method(*args, query=self_.validate_query(query), **kwargs) + + return wrapper + + +class BaseAsyncDataBackend(ABC): + """Base data backend interface.""" + + name = "base" + query_model = BaseQuery + default_operation_type = BaseOperationType.INDEX + settings_class = BaseDataBackendSettings + + @abstractmethod + def __init__(self, settings: settings_class = None): + """Instantiates the data backend. + + Args: + settings (BaseDataBackendSettings or None): The backend settings. + If `settings` is `None`, a default settings instance is used instead. + """ + + def validate_query(self, query: Union[str, dict, BaseQuery] = None) -> BaseQuery: + """Validates and transforms the query.""" + if query is None: + query = self.query_model() + + if isinstance(query, str): + query = self.query_model(query_string=query) + + if isinstance(query, dict): + try: + query = self.query_model(**query) + except ValidationError as err: + raise BackendParameterException( + "The 'query' argument is expected to be a " + f"{self.query_model.__name__} instance. {err.errors()}" + ) from err + + if not isinstance(query, self.query_model): + raise BackendParameterException( + "The 'query' argument is expected to be a " + f"{self.query_model.__name__} instance." + ) + + logger.debug("Query: %s", str(query)) + + return query + + @abstractmethod + async def status(self) -> DataBackendStatus: + """Implements data backend checks (e.g. connection, cluster status). + + Return: + DataBackendStatus: The status of the data backend. + """ + + @abstractmethod + async def list( + self, target: str = None, details: bool = False, new: bool = False + ) -> Iterator[Union[str, dict]]: + """Lists containers in the data backend. E.g., collections, files, indexes. + + Args: + target (str or None): The target container name. + If `target` is `None`, a default value is used instead. + details (bool): Get detailed container information instead of just names. + new (bool): Given the history, list only not already read containers. + + Yield: + str: If `details` is False. + dict: If `details` is True. + + Raise: + BackendException: If a failure occurs. + BackendParameterException: If a backend argument value is not valid. + """ + + @abstractmethod + @async_enforce_query_checks + async def read( + self, + *, + query: Union[str, BaseQuery] = None, + target: str = None, + chunk_size: Union[None, int] = None, + raw_output: bool = False, + ignore_errors: bool = False, + ) -> Iterator[Union[bytes, dict]]: + """Reads records matching the `query` in the `target` container and yields them. + + Args: + query: (str or BaseQuery): The query to select records to read. + target (str or None): The target container name. + If `target` is `None`, a default value is used instead. 
+            chunk_size (int or None): The number of records or bytes to read in one
+                batch, depending on whether the records are dictionaries or bytes.
+            raw_output (bool): Controls whether to yield bytes or dictionaries.
+                If the records are dictionaries and `raw_output` is set to `True`, they
+                are encoded as JSON.
+                If the records are bytes and `raw_output` is set to `False`, they are
+                decoded as JSON by line.
+            ignore_errors (bool): If `True`, errors during the read operation
+                are ignored and logged. If `False` (default), a `BackendException`
+                is raised if an error occurs.
+
+        Yield:
+            dict: If `raw_output` is False.
+            bytes: If `raw_output` is True.
+
+        Raise:
+            BackendException: If a failure during the read operation occurs and
+                `ignore_errors` is set to `False`.
+            BackendParameterException: If a backend argument value is not valid.
+        """
+
+    @abstractmethod
+    async def write(  # pylint: disable=too-many-arguments
+        self,
+        data: Union[IOBase, Iterable[bytes], Iterable[dict]],
+        target: Union[None, str] = None,
+        chunk_size: Union[None, int] = None,
+        ignore_errors: bool = False,
+        operation_type: Union[None, BaseOperationType] = None,
+    ) -> int:
+        """Writes `data` records to the `target` container and returns their count.
+
+        Args:
+            data: (Iterable or IOBase): The data to write.
+            target (str or None): The target container name.
+                If `target` is `None`, a default value is used instead.
+            chunk_size (int or None): The number of records or bytes to write in one
+                batch, depending on whether `data` contains dictionaries or bytes.
+                If `chunk_size` is `None`, a default value is used instead.
+            ignore_errors (bool): If `True`, errors during the write operation
+                are ignored and logged. If `False` (default), a `BackendException`
+                is raised if an error occurs.
+            operation_type (BaseOperationType or None): The mode of the write operation.
+                If `operation_type` is `None`, the `default_operation_type` is used
+                instead. See `BaseOperationType`.
+
+        Return:
+            int: The number of written records.
+
+        Raise:
+            BackendException: If a failure during the write operation occurs and
+                `ignore_errors` is set to `False`.
+            BackendParameterException: If a backend argument value is not valid.
+        """
+
+    @abstractmethod
+    async def close(self) -> None:
+        """Close the data backend client.
+
+        Raise:
+            BackendException: If a failure during the close operation occurs.
+        """

From 162a8fa414c59f90fd65fa23c6446cbd472620bd Mon Sep 17 00:00:00 2001
From: Wilfried BARADAT
Date: Sun, 18 Jun 2023 19:11:32 +0200
Subject: [PATCH 17/65] =?UTF-8?q?=E2=9C=A8(backends)=20add=20async=20elast?=
 =?UTF-8?q?icsearch=20data=20backend?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

We add the async Elasticsearch data backend, mostly taken from the sync
backend, using the async Elasticsearch client methods.
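A minimal read sketch (illustrative, not part of the patch): the host and
index name are placeholders, and `read` is an asynchronous generator, so
it is consumed with `async for`.

    import asyncio

    from ralph.backends.data.async_es import AsyncESDataBackend

    async def main():
        # Placeholder HOSTS/DEFAULT_INDEX values; the settings class is
        # shared with the synchronous backend.
        settings = AsyncESDataBackend.settings_class(
            HOSTS=["http://localhost:9200"], DEFAULT_INDEX="statements"
        )
        backend = AsyncESDataBackend(settings)
        # Iterate over documents from the default index.
        async for document in backend.read(chunk_size=10):
            print(document["_source"])
        # `close` is declared by the new asynchronous base interface.
        await backend.close()

    asyncio.run(main())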
--- setup.cfg | 2 +- src/ralph/backends/data/async_es.py | 275 +++++++++ src/ralph/backends/data/base.py | 65 +-- src/ralph/backends/data/es.py | 32 +- src/ralph/backends/lrs/async_es.py | 50 ++ src/ralph/backends/lrs/base.py | 22 +- src/ralph/backends/lrs/es.py | 125 ++-- tests/backends/data/test_async_es.py | 831 +++++++++++++++++++++++++++ tests/backends/lrs/test_async_es.py | 421 ++++++++++++++ tests/backends/lrs/test_es.py | 2 +- tests/conftest.py | 2 + tests/fixtures/backends.py | 52 +- 12 files changed, 1747 insertions(+), 132 deletions(-) create mode 100644 src/ralph/backends/data/async_es.py create mode 100644 src/ralph/backends/lrs/async_es.py create mode 100644 tests/backends/data/test_async_es.py create mode 100644 tests/backends/lrs/test_async_es.py diff --git a/setup.cfg b/setup.cfg index 80c7eb298..98a8d520d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,7 +43,7 @@ backend-clickhouse = clickhouse-connect[numpy,pandas]<0.6 python-dateutil>=2.8.2 backend-es = - elasticsearch>=8.0.0 + elasticsearch[async]>=8.0.0 backend-ldp = ovh>=1.0.0 requests>=2.0.0 diff --git a/src/ralph/backends/data/async_es.py b/src/ralph/backends/data/async_es.py new file mode 100644 index 000000000..76ede8803 --- /dev/null +++ b/src/ralph/backends/data/async_es.py @@ -0,0 +1,275 @@ +"""Asynchronous Elasticsearch data backend for Ralph.""" + +import logging +from io import IOBase +from itertools import chain +from typing import Iterable, Iterator, Union + +from elasticsearch import ApiError, AsyncElasticsearch, TransportError +from elasticsearch.helpers import BulkIndexError, async_streaming_bulk + +from ralph.backends.data.base import ( + BaseAsyncDataBackend, + BaseOperationType, + DataBackendStatus, + enforce_query_checks, +) +from ralph.backends.data.es import ESDataBackend, ESDataBackendSettings, ESQuery +from ralph.exceptions import BackendException, BackendParameterException +from ralph.utils import parse_bytes_to_dict, read_raw + +# pylint: disable=duplicate-code + +logger = logging.getLogger(__name__) + + +class AsyncESDataBackend(BaseAsyncDataBackend): + """Asynchronous Elasticsearch data backend.""" + + name = "async_es" + query_model = ESQuery + settings_class = ESDataBackendSettings + + def __init__(self, settings: settings_class = None): + """Instantiate the asynchronous Elasticsearch client. + + Args: + settings (ESDataBackendSettings or None): The data backend settings. + If `settings` is `None`, a default settings instance is used instead. 
+ """ + self.settings = settings if settings else self.settings_class() + self._client = None + + @property + def client(self): + """Create an AsyncElasticsearch client if it doesn't exist.""" + if not self._client: + self._client = AsyncElasticsearch( + self.settings.HOSTS, **self.settings.CLIENT_OPTIONS.dict() + ) + return self._client + + async def status(self) -> DataBackendStatus: + """Check Elasticsearch cluster connection and status.""" + try: + await self.client.info() + cluster_status = await self.client.cat.health() + except TransportError as error: + logger.error("Failed to connect to Elasticsearch: %s", error) + return DataBackendStatus.AWAY + + if "green" in cluster_status: + return DataBackendStatus.OK + + if "yellow" in cluster_status and self.settings.ALLOW_YELLOW_STATUS: + logger.info("Cluster status is yellow.") + return DataBackendStatus.OK + + logger.error("Cluster status is not green: %s", cluster_status) + + return DataBackendStatus.ERROR + + async def list( + self, target: str = None, details: bool = False, new: bool = False + ) -> Iterator[Union[str, dict]]: + """List available Elasticsearch indices, data streams and aliases. + + Args: + target (str or None): The comma-separated list of data streams, indices, + and aliases to limit the request. Supports wildcards (*). + If target is `None`, lists all available indices, data streams and + aliases. Equivalent to (`target` = "*"). + details (bool): Get detailed informations instead of just names. + new (bool): Ignored. + + Yield: + str: The next index, data stream or alias name. (If `details` is False). + dict: The next index, data stream or alias details. (If `details` is True). + + Raise: + BackendException: If a failure during indices retrieval occurs. + """ + target = target if target else "*" + try: + indices = await self.client.indices.get(index=target) + except (ApiError, TransportError) as error: + msg = "Failed to read indices: %s" + logger.error(msg, error) + raise BackendException(msg % error) from error + + if new: + logger.warning("The `new` argument is ignored") + + if details: + for index, value in indices.items(): + yield {index: value} + + return + + for index in indices: + yield index + + @enforce_query_checks + async def read( + self, + *, + query: Union[str, ESQuery] = None, + target: str = None, + chunk_size: Union[None, int] = None, + raw_output: bool = False, + ignore_errors: bool = False, + ) -> Iterator[Union[bytes, dict]]: + """Read documents matching the query in the target index and yield them. + + Args: + query (str or ESQuery): A query in the Lucene query string syntax or a + dictionary defining a search definition using the Elasticsearch Query + DSL. The Lucene query overrides the query DSL if present. See ESQuery. + target (str or None): The target Elasticsearch index name to query. + If target is `None`, the `DEFAULT_INDEX` is used instead. + chunk_size (int or None): The chunk size for reading batches of documents. + If chunk_size is `None` it defaults to `DEFAULT_CHUNK_SIZE`. + raw_output (bool): Controls whether to yield dictionaries or bytes. + ignore_errors (bool): Ignored. + + Yield: + bytes: The next raw document if `raw_output` is True. + dict: The next JSON parsed document if `raw_output` is False. + + Raise: + BackendException: If a failure occurs during Elasticsearch connection. 
+ """ + target = target if target else self.settings.DEFAULT_INDEX + chunk_size = chunk_size if chunk_size else self.settings.DEFAULT_CHUNK_SIZE + if ignore_errors: + logger.warning("The `ignore_errors` argument is ignored") + + if not query.pit.keep_alive: + query.pit.keep_alive = self.settings.POINT_IN_TIME_KEEP_ALIVE + if not query.pit.id: + try: + query.pit.id = ( + await self.client.open_point_in_time( + index=target, keep_alive=query.pit.keep_alive + ) + )["id"] + except (ApiError, TransportError, ValueError) as error: + msg = "Failed to open Elasticsearch point in time: %s" + logger.error(msg, error) + raise BackendException(msg % error) from error + + limit = query.size + kwargs = query.dict(exclude={"query_string", "size"}) + if query.query_string: + kwargs["q"] = query.query_string + + count = chunk_size + while limit or chunk_size == count: + kwargs["size"] = limit if limit and limit < chunk_size else chunk_size + try: + documents = (await self.client.search(**kwargs))["hits"]["hits"] + except (ApiError, TransportError, TypeError) as error: + msg = "Failed to execute Elasticsearch query: %s" + logger.error(msg, error) + raise BackendException(msg % error) from error + count = len(documents) + if limit: + limit -= count if chunk_size == count else limit + query.search_after = None + if count: + query.search_after = [str(part) for part in documents[-1]["sort"]] + kwargs["search_after"] = query.search_after + if raw_output: + documents = read_raw(documents, self.settings.LOCALE_ENCODING) + for document in documents: + yield document + + async def write( # pylint: disable=too-many-arguments + self, + data: Union[IOBase, Iterable[bytes], Iterable[dict]], + target: Union[None, str] = None, + chunk_size: Union[None, int] = None, + ignore_errors: bool = False, + operation_type: Union[None, BaseOperationType] = None, + ) -> int: + """Write data documents to the target index and return their count. + + Args: + data: (Iterable or IOBase): The data containing documents to write. + target (str or None): The target Elasticsearch index name. + If target is `None`, the `DEFAULT_INDEX` is used instead. + chunk_size (int or None): The number of documents to write in one batch. + If chunk_size is `None` it defaults to `DEFAULT_CHUNK_SIZE`. + ignore_errors (bool): If `True`, errors during the write operation + will be ignored and logged. If `False` (default), a `BackendException` + will be raised if an error occurs. + operation_type (BaseOperationType or None): The mode of the write operation. + If `operation_type` is `None`, the `default_operation_type` is used + instead. See `BaseOperationType`. + + Return: + int: The number of written documents. + + Raise: + BackendException: If a failure occurs while writing to Elasticsearch or + during document decoding and `ignore_errors` is set to `False`. + BackendParameterException: If the `operation_type` is `APPEND` as it is not + supported. + """ + count = 0 + data = iter(data) + try: + first_record = next(data) + except StopIteration: + logger.info("Data Iterator is empty; skipping write to target.") + return count + if not operation_type: + operation_type = self.default_operation_type + target = target if target else self.settings.DEFAULT_INDEX + chunk_size = chunk_size if chunk_size else self.settings.DEFAULT_CHUNK_SIZE + if operation_type == BaseOperationType.APPEND: + msg = "Append operation_type is not supported." 
+ logger.error(msg) + raise BackendParameterException(msg) + + data = chain((first_record,), data) + if isinstance(first_record, bytes): + data = parse_bytes_to_dict(data, ignore_errors, logger) + + logger.debug( + "Start writing to the %s index (chunk size: %d)", target, chunk_size + ) + try: + async for success, action in async_streaming_bulk( + client=self.client, + actions=ESDataBackend.to_documents(data, target, operation_type), + chunk_size=chunk_size, + raise_on_error=(not ignore_errors), + refresh=self.settings.REFRESH_AFTER_WRITE, + ): + count += success + logger.debug("Wrote %d document [action: %s]", success, action) + + logger.info("Finished writing %d documents with success", count) + except (BulkIndexError, ApiError, TransportError) as error: + msg = "%s %s Total succeeded writes: %s" + details = getattr(error, "errors", "") + logger.error(msg, error, details, count) + raise BackendException(msg % (error, details, count)) from error + return count + + async def close(self) -> None: + """Close the AsyncElasticsearch client. + + Raise: + BackendException: If a failure during the close operation occurs. + """ + if not self._client: + return + + try: + await self.client.close() + except TransportError as error: + msg = "Failed to close Elasticsearch client: %s" + logger.error(msg, error) + raise BackendException(msg % error) from error diff --git a/src/ralph/backends/data/base.py b/src/ralph/backends/data/base.py index 107fdbe2b..84524901e 100644 --- a/src/ralph/backends/data/base.py +++ b/src/ralph/backends/data/base.py @@ -16,7 +16,7 @@ class BaseDataBackendSettings(BaseSettings): - """Represents the data backend default configuration.""" + """Data backend default configuration.""" class Config(BaseSettingsConfig): """Pydantic Configuration.""" @@ -66,7 +66,7 @@ class DataBackendStatus(Enum): def enforce_query_checks(method): - """Enforces query argument type checking for methods using it.""" + """Enforce query argument type checking for methods using it.""" @functools.wraps(method) def wrapper(*args, **kwargs): @@ -89,15 +89,15 @@ class BaseDataBackend(ABC): @abstractmethod def __init__(self, settings: settings_class = None): - """Instantiates the data backend. + """Instantiate the data backend. Args: - settings (BaseDataBackendSettings or None): The backend settings. + settings (BaseDataBackendSettings or None): The data backend settings. If `settings` is `None`, a default settings instance is used instead. """ def validate_query(self, query: Union[str, dict, BaseQuery] = None) -> BaseQuery: - """Validates and transforms the query.""" + """Validate and transform the query.""" if query is None: query = self.query_model() @@ -126,7 +126,7 @@ def validate_query(self, query: Union[str, dict, BaseQuery] = None) -> BaseQuery @abstractmethod def status(self) -> DataBackendStatus: - """Implements data backend checks (e.g. connection, cluster status). + """Implement data backend checks (e.g. connection, cluster status). Return: DataBackendStatus: The status of the data backend. @@ -136,7 +136,7 @@ def status(self) -> DataBackendStatus: def list( self, target: str = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: - """Lists containers in the data backend. E.g., collections, files, indexes. + """List containers in the data backend. E.g., collections, files, indexes. Args: target (str or None): The target container name. 
@@ -164,7 +164,7 @@ def read( raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: - """Reads records matching the `query` in the `target` container and yields them. + """Read records matching the `query` in the `target` container and yield them. Args: query: (str or BaseQuery): The query to select records to read. @@ -200,7 +200,7 @@ def write( # pylint: disable=too-many-arguments ignore_errors: bool = False, operation_type: Union[None, BaseOperationType] = None, ) -> int: - """Writes `data` records to the `target` container and returns their count. + """Write `data` records to the `target` container and return their count. Args: data: (Iterable or IOBase): The data to write. @@ -226,22 +226,8 @@ def write( # pylint: disable=too-many-arguments """ -def async_enforce_query_checks(method): - """Enforces query argument type checking for methods using it.""" - - @functools.wraps(method) - async def wrapper(*args, **kwargs): - """Wrap method execution.""" - query = kwargs.pop("query", None) - self_ = args[0] - - return method(*args, query=self_.validate_query(query), **kwargs) - - return wrapper - - class BaseAsyncDataBackend(ABC): - """Base data backend interface.""" + """Base async data backend interface.""" name = "base" query_model = BaseQuery @@ -250,7 +236,7 @@ class BaseAsyncDataBackend(ABC): @abstractmethod def __init__(self, settings: settings_class = None): - """Instantiates the data backend. + """Instantiate the data backend. Args: settings (BaseDataBackendSettings or None): The backend settings. @@ -258,7 +244,7 @@ def __init__(self, settings: settings_class = None): """ def validate_query(self, query: Union[str, dict, BaseQuery] = None) -> BaseQuery: - """Validates and transforms the query.""" + """Validate and transform the query.""" if query is None: query = self.query_model() @@ -268,17 +254,18 @@ def validate_query(self, query: Union[str, dict, BaseQuery] = None) -> BaseQuery if isinstance(query, dict): try: query = self.query_model(**query) - except ValidationError as err: + except ValidationError as error: + msg = "The 'query' argument is expected to be a %s instance. %s" + errors = error.errors() + logger.error(msg, self.query_model.__name__, errors) raise BackendParameterException( - "The 'query' argument is expected to be a " - f"{self.query_model.__name__} instance. {err.errors()}" - ) from err + msg % (self.query_model.__name__, errors) + ) from error if not isinstance(query, self.query_model): - raise BackendParameterException( - "The 'query' argument is expected to be a " - f"{self.query_model.__name__} instance." - ) + msg = "The 'query' argument is expected to be a %s instance." + logger.error(msg, self.query_model.__name__) + raise BackendParameterException(msg % (self.query_model.__name__,)) logger.debug("Query: %s", str(query)) @@ -286,7 +273,7 @@ def validate_query(self, query: Union[str, dict, BaseQuery] = None) -> BaseQuery @abstractmethod async def status(self) -> DataBackendStatus: - """Implements data backend checks (e.g. connection, cluster status). + """Implement data backend checks (e.g. connection, cluster status). Return: DataBackendStatus: The status of the data backend. @@ -296,7 +283,7 @@ async def status(self) -> DataBackendStatus: async def list( self, target: str = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: - """Lists containers in the data backend. E.g., collections, files, indexes. + """List containers in the data backend. E.g., collections, files, indexes. 
Args: target (str or None): The target container name. @@ -314,7 +301,7 @@ async def list( """ @abstractmethod - @async_enforce_query_checks + @enforce_query_checks async def read( self, *, @@ -324,7 +311,7 @@ async def read( raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: - """Reads records matching the `query` in the `target` container and yields them. + """Read records matching the `query` in the `target` container and yield them. Args: query: (str or BaseQuery): The query to select records to read. @@ -360,7 +347,7 @@ async def write( # pylint: disable=too-many-arguments ignore_errors: bool = False, operation_type: Union[None, BaseOperationType] = None, ) -> int: - """Writes `data` records to the `target` container and returns their count. + """Write `data` records to the `target` container and return their count. Args: data: (Iterable or IOBase): The data to write. diff --git a/src/ralph/backends/data/es.py b/src/ralph/backends/data/es.py index 09f4a2b38..b080e85cb 100644 --- a/src/ralph/backends/data/es.py +++ b/src/ralph/backends/data/es.py @@ -1,11 +1,10 @@ """Elasticsearch data backend for Ralph.""" -import json import logging from io import IOBase from itertools import chain from pathlib import Path -from typing import Any, Dict, Iterable, Iterator, List, Literal, Optional, Union +from typing import Iterable, Iterator, List, Literal, Optional, Union from elasticsearch import ApiError, Elasticsearch, TransportError from elasticsearch.helpers import BulkIndexError, streaming_bulk @@ -21,6 +20,7 @@ ) from ralph.conf import BaseSettingsConfig, CommaSeparatedTuple from ralph.exceptions import BackendException, BackendParameterException +from ralph.utils import parse_bytes_to_dict, read_raw logger = logging.getLogger(__name__) @@ -264,7 +264,7 @@ def read( query.search_after = [str(part) for part in documents[-1]["sort"]] kwargs["search_after"] = query.search_after if raw_output: - documents = self._read_raw(documents) + documents = read_raw(documents, self.settings.LOCALE_ENCODING) for document in documents: yield document @@ -318,7 +318,7 @@ def write( # pylint: disable=too-many-arguments data = chain((first_record,), data) if isinstance(first_record, bytes): - data = self._parse_bytes_to_dict(data, ignore_errors) + data = parse_bytes_to_dict(data, ignore_errors, logger) logger.debug( "Start writing to the %s index (chunk size: %d)", target, chunk_size @@ -326,7 +326,7 @@ def write( # pylint: disable=too-many-arguments try: for success, action in streaming_bulk( client=self.client, - actions=self._to_documents(data, target, operation_type), + actions=ESDataBackend.to_documents(data, target, operation_type), chunk_size=chunk_size, raise_on_error=(not ignore_errors), refresh=self.settings.REFRESH_AFTER_WRITE, @@ -343,7 +343,7 @@ def write( # pylint: disable=too-many-arguments return count @staticmethod - def _to_documents( + def to_documents( data: Iterable[dict], target: str, operation_type: BaseOperationType, @@ -373,23 +373,3 @@ def _to_documents( "_id": item.get("id", None), "_op_type": operation_type.value, } - - def _read_raw(self, documents: Iterable[Dict[str, Any]]) -> Iterator[bytes]: - """Read the `documents` Iterable and yield bytes.""" - for document in documents: - yield json.dumps(document).encode(self.settings.LOCALE_ENCODING) - - @staticmethod - def _parse_bytes_to_dict( - raw_documents: Iterable[bytes], ignore_errors: bool - ) -> Iterator[dict]: - """Read the `raw_documents` Iterable and yield dictionaries.""" - for 
raw_document in raw_documents:
-            try:
-                yield json.loads(raw_document)
-            except (TypeError, json.JSONDecodeError) as error:
-                msg = "Failed to decode JSON: %s, for document: %s"
-                logger.error(msg, error, raw_document)
-                if ignore_errors:
-                    continue
-                raise BackendException(msg % (error, raw_document)) from error
diff --git a/src/ralph/backends/lrs/async_es.py b/src/ralph/backends/lrs/async_es.py
new file mode 100644
index 000000000..1842b299f
--- /dev/null
+++ b/src/ralph/backends/lrs/async_es.py
@@ -0,0 +1,50 @@
+"""Asynchronous Elasticsearch LRS backend for Ralph."""
+
+import logging
+from typing import Iterator, List
+
+from ralph.backends.data.async_es import AsyncESDataBackend
+from ralph.backends.lrs.base import (
+    BaseAsyncLRSBackend,
+    StatementParameters,
+    StatementQueryResult,
+)
+from ralph.backends.lrs.es import get_query
+from ralph.exceptions import BackendException, BackendParameterException
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncESLRSBackend(BaseAsyncLRSBackend, AsyncESDataBackend):
+    """Asynchronous Elasticsearch LRS backend implementation."""
+
+    settings_class = AsyncESDataBackend.settings_class
+
+    async def query_statements(
+        self, params: StatementParameters
+    ) -> StatementQueryResult:
+        """Return the statements query payload using xAPI parameters."""
+        query = get_query(params=params)
+        try:
+            statements = [
+                document["_source"]
+                async for document in self.read(query=query, chunk_size=params.limit)
+            ]
+        except (BackendException, BackendParameterException) as error:
+            logger.error("Failed to read from Elasticsearch")
+            raise error
+
+        return StatementQueryResult(
+            statements=statements,
+            pit_id=query.pit.id,
+            search_after="|".join(query.search_after) if query.search_after else "",
+        )
+
+    async def query_statements_by_ids(self, ids: List[str]) -> Iterator[dict]:
+        """Yield statements with matching ids from the backend."""
+        try:
+            async for document in self.read(query={"query": {"terms": {"_id": ids}}}):
+                yield document["_source"]
+        except (BackendException, BackendParameterException) as error:
+            logger.error("Failed to read from Elasticsearch")
+            raise error
diff --git a/src/ralph/backends/lrs/base.py b/src/ralph/backends/lrs/base.py
index 4beb06680..008d60dfe 100644
--- a/src/ralph/backends/lrs/base.py
+++ b/src/ralph/backends/lrs/base.py
@@ -8,7 +8,11 @@
 
 from pydantic import BaseModel
 
-from ralph.backends.data.base import BaseDataBackend, BaseDataBackendSettings
+from ralph.backends.data.base import (
+    BaseAsyncDataBackend,
+    BaseDataBackend,
+    BaseDataBackendSettings,
+)
 
 
 class BaseLRSBackendSettings(BaseDataBackendSettings):
@@ -74,3 +78,19 @@ def query_statements(self, params: StatementParameters) -> StatementQueryResult:
 
     @abstractmethod
     def query_statements_by_ids(self, ids: List[str]) -> Iterator[dict]:
         """Yield statements with matching ids from the backend."""
+
+
+class BaseAsyncLRSBackend(BaseAsyncDataBackend):
+    """Base async LRS backend interface."""
+
+    settings_class = BaseLRSBackendSettings
+
+    @abstractmethod
+    async def query_statements(
+        self, params: StatementParameters
+    ) -> StatementQueryResult:
+        """Return the statements query payload using xAPI parameters."""
+
+    @abstractmethod
+    async def query_statements_by_ids(self, ids: List[str]) -> Iterator[dict]:
+        """Yield statements with matching ids from the backend."""
diff --git a/src/ralph/backends/lrs/es.py b/src/ralph/backends/lrs/es.py
index 8f3354f9a..66e00b8f7 100644
--- a/src/ralph/backends/lrs/es.py
+++ b/src/ralph/backends/lrs/es.py
@@ -22,45 -7 
@@ class ESLRSBackend(BaseLRSBackend, ESDataBackend): def query_statements(self, params: StatementParameters) -> StatementQueryResult: """Return the statements query payload using xAPI parameters.""" - es_query_filters = [] - - if params.statementId: - es_query_filters += [{"term": {"_id": params.statementId}}] - - self._add_agent_filters(es_query_filters, params.agent, "actor") - self._add_agent_filters(es_query_filters, params.authority, "authority") - - if params.verb: - es_query_filters += [{"term": {"verb.id.keyword": params.verb}}] - - if params.activity: - es_query_filters += [ - {"term": {"object.objectType.keyword": "Activity"}}, - {"term": {"object.id.keyword": params.activity}}, - ] - - if params.since: - es_query_filters += [{"range": {"timestamp": {"gt": params.since}}}] - - if params.until: - es_query_filters += [{"range": {"timestamp": {"lte": params.until}}}] - - es_query = { - "pit": ESQueryPit.construct(id=params.pit_id), - "size": params.limit, - "sort": [{"timestamp": {"order": "asc" if params.ascending else "desc"}}], - } - if len(es_query_filters) > 0: - es_query["query"] = {"bool": {"filter": es_query_filters}} - - if params.ignore_order: - es_query["sort"] = "_shard_doc" - - if params.search_after: - es_query["search_after"] = params.search_after.split("|") - - # Note: `params` fields are validated thus we skip their validation in ESQuery. - query = ESQuery.construct(**es_query) + query = get_query(params=params) try: es_documents = self.read(query=query, chunk_size=params.limit) statements = [document["_source"] for document in es_documents] @@ -83,24 +45,67 @@ def query_statements_by_ids(self, ids: List[str]) -> Iterator[dict]: logger.error("Failed to read from Elasticsearch") raise error - @staticmethod - def _add_agent_filters( - es_query_filters: list, agent_params: AgentParameters, target_field: str - ): - """Add filters relative to agents to `es_query_filters`.""" - if not agent_params: - return - if agent_params.mbox: - field = f"{target_field}.mbox.keyword" - es_query_filters += [{"term": {field: agent_params.mbox}}] - elif agent_params.mbox_sha1sum: - field = f"{target_field}.mbox_sha1sum.keyword" - es_query_filters += [{"term": {field: agent_params.mbox_sha1sum}}] - elif agent_params.openid: - field = f"{target_field}.openid.keyword" - es_query_filters += [{"term": {field: agent_params.openid}}] - elif agent_params.account__name: - field = f"{target_field}.account.name.keyword" - es_query_filters += [{"term": {field: agent_params.account__name}}] - field = f"{target_field}.account.homePage.keyword" - es_query_filters += [{"term": {field: agent_params.account__home_page}}] + +def get_query(params: StatementParameters) -> ESQuery: + """Construct query from statement parameters.""" + es_query_filters = [] + + if params.statementId: + es_query_filters += [{"term": {"_id": params.statementId}}] + + add_agent_filters(es_query_filters, params.agent, "actor") + add_agent_filters(es_query_filters, params.authority, "authority") + + if params.verb: + es_query_filters += [{"term": {"verb.id.keyword": params.verb}}] + + if params.activity: + es_query_filters += [ + {"term": {"object.objectType.keyword": "Activity"}}, + {"term": {"object.id.keyword": params.activity}}, + ] + + if params.since: + es_query_filters += [{"range": {"timestamp": {"gt": params.since}}}] + + if params.until: + es_query_filters += [{"range": {"timestamp": {"lte": params.until}}}] + + es_query = { + "pit": ESQueryPit.construct(id=params.pit_id), + "size": params.limit, + "sort": [{"timestamp": 
{"order": "asc" if params.ascending else "desc"}}], + } + if len(es_query_filters) > 0: + es_query["query"] = {"bool": {"filter": es_query_filters}} + + if params.ignore_order: + es_query["sort"] = "_shard_doc" + + if params.search_after: + es_query["search_after"] = params.search_after.split("|") + + # Note: `params` fields are validated thus we skip their validation in ESQuery. + return ESQuery.construct(**es_query) + + +def add_agent_filters( + es_query_filters: list, agent_params: AgentParameters, target_field: str +): + """Add filters relative to agents to `es_query_filters`.""" + if not agent_params: + return + if agent_params.mbox: + field = f"{target_field}.mbox.keyword" + es_query_filters += [{"term": {field: agent_params.mbox}}] + elif agent_params.mbox_sha1sum: + field = f"{target_field}.mbox_sha1sum.keyword" + es_query_filters += [{"term": {field: agent_params.mbox_sha1sum}}] + elif agent_params.openid: + field = f"{target_field}.openid.keyword" + es_query_filters += [{"term": {field: agent_params.openid}}] + elif agent_params.account__name: + field = f"{target_field}.account.name.keyword" + es_query_filters += [{"term": {field: agent_params.account__name}}] + field = f"{target_field}.account.homePage.keyword" + es_query_filters += [{"term": {field: agent_params.account__home_page}}] diff --git a/tests/backends/data/test_async_es.py b/tests/backends/data/test_async_es.py new file mode 100644 index 000000000..1d3b8976d --- /dev/null +++ b/tests/backends/data/test_async_es.py @@ -0,0 +1,831 @@ +"""Tests for Ralph Async Elasticsearch data backend.""" + +import json +import logging +import random +import re +from collections.abc import Iterable +from datetime import datetime +from io import BytesIO + +import pytest +from elastic_transport import ApiResponseMeta +from elasticsearch import ApiError, AsyncElasticsearch +from elasticsearch import ConnectionError as ESConnectionError + +from ralph.backends.data.async_es import ( + AsyncESDataBackend, + ESDataBackendSettings, + ESQuery, +) +from ralph.backends.data.base import BaseOperationType, DataBackendStatus +from ralph.backends.data.es import ESClientOptions +from ralph.exceptions import BackendException, BackendParameterException +from ralph.utils import now + +from tests.fixtures.backends import ( + ES_TEST_FORWARDING_INDEX, + ES_TEST_INDEX, + get_es_fixture, +) + + +@pytest.mark.anyio +async def test_backends_data_async_es_data_backend_default_instantiation( + monkeypatch, fs +): + """Test the `AsyncESDataBackend` default instantiation.""" + # pylint: disable=invalid-name + fs.create_file(".env") + backend_settings_names = [ + "ALLOW_YELLOW_STATUS", + "CLIENT_OPTIONS", + "CLIENT_OPTIONS__ca_certs", + "CLIENT_OPTIONS__verify_certs", + "DEFAULT_CHUNK_SIZE", + "DEFAULT_INDEX", + "HOSTS", + "LOCALE_ENCODING", + "POINT_IN_TIME_KEEP_ALIVE", + "REFRESH_AFTER_WRITE", + ] + for name in backend_settings_names: + monkeypatch.delenv(f"RALPH_BACKENDS__DATA__ES__{name}", raising=False) + + assert AsyncESDataBackend.name == "async_es" + assert AsyncESDataBackend.query_model == ESQuery + assert AsyncESDataBackend.default_operation_type == BaseOperationType.INDEX + assert AsyncESDataBackend.settings_class == ESDataBackendSettings + backend = AsyncESDataBackend() + assert not backend.settings.ALLOW_YELLOW_STATUS + assert backend.settings.CLIENT_OPTIONS == ESClientOptions() + assert backend.settings.DEFAULT_CHUNK_SIZE == 500 + assert backend.settings.DEFAULT_INDEX == "statements" + assert backend.settings.HOSTS == ("http://localhost:9200",) + 
assert backend.settings.LOCALE_ENCODING == "utf8" + assert backend.settings.POINT_IN_TIME_KEEP_ALIVE == "1m" + assert not backend.settings.REFRESH_AFTER_WRITE + assert isinstance(backend.client, AsyncElasticsearch) + elasticsearch_node = backend.client.transport.node_pool.get() + assert elasticsearch_node.config.ca_certs is None + assert elasticsearch_node.config.verify_certs is None + assert elasticsearch_node.host == "localhost" + assert elasticsearch_node.port == 9200 + + +@pytest.mark.anyio +async def test_backends_data_async_es_data_backend_instantiation_with_settings(): + """Test the `AsyncESDataBackend` instantiation with settings.""" + # Not testing `ca_certs` and `verify_certs` as elasticsearch aiohttp + # node transport checks that file exists + settings = ESDataBackendSettings( + ALLOW_YELLOW_STATUS=True, + CLIENT_OPTIONS={"verify_certs": False, "ca_certs": None}, + DEFAULT_CHUNK_SIZE=5000, + DEFAULT_INDEX=ES_TEST_INDEX, + HOSTS=["https://elasticsearch_hostname:9200"], + LOCALE_ENCODING="utf-16", + POINT_IN_TIME_KEEP_ALIVE="5m", + REFRESH_AFTER_WRITE=True, + ) + backend = AsyncESDataBackend(settings) + assert backend.settings.ALLOW_YELLOW_STATUS + assert backend.settings.CLIENT_OPTIONS == ESClientOptions( + verify_certs=False, ca_certs=None + ) + assert backend.settings.DEFAULT_CHUNK_SIZE == 5000 + assert backend.settings.DEFAULT_INDEX == ES_TEST_INDEX + assert backend.settings.HOSTS == ("https://elasticsearch_hostname:9200",) + assert backend.settings.LOCALE_ENCODING == "utf-16" + assert backend.settings.POINT_IN_TIME_KEEP_ALIVE == "5m" + assert backend.settings.REFRESH_AFTER_WRITE + assert isinstance(backend.client, AsyncElasticsearch) + elasticsearch_node = backend.client.transport.node_pool.get() + assert elasticsearch_node.host == "elasticsearch_hostname" + assert elasticsearch_node.port == 9200 + assert backend.settings.POINT_IN_TIME_KEEP_ALIVE == "5m" + + try: + AsyncESDataBackend(settings) + except Exception as err: # pylint:disable=broad-except + pytest.fail(f"Two AsyncESDataBackends should not raise exceptions: {err}") + + +@pytest.mark.anyio +async def test_backends_data_async_es_data_backend_status_method( + monkeypatch, async_es_backend, caplog +): + """Test the `AsyncESDataBackend.status` method.""" + + async def mock_info(): + return None + + def mock_health_status(es_status): + async def mock_health(): + return es_status + + return mock_health + + backend = async_es_backend() + + # Given green status, the `status` method should return `DataBackendStatus.OK`. + with monkeypatch.context() as elasticsearch_patch: + es_status = "1664532320 10:05:20 docker-cluster green 1 1 2 2 0 0 1 0 - 66.7%" + elasticsearch_patch.setattr(backend.client, "info", mock_info) + elasticsearch_patch.setattr( + backend.client.cat, "health", mock_health_status(es_status) + ) + assert await backend.status() == DataBackendStatus.OK + + with monkeypatch.context() as elasticsearch_patch: + # Given yellow status, the `status` method should return + # `DataBackendStatus.ERROR`. + es_status = "1664532320 10:05:20 docker-cluster yellow 1 1 2 2 0 0 1 0 - 66.7%" + elasticsearch_patch.setattr(backend.client, "info", mock_info) + elasticsearch_patch.setattr( + backend.client.cat, "health", mock_health_status(es_status) + ) + assert await backend.status() == DataBackendStatus.ERROR + # Given yellow status, and `settings.ALLOW_YELLOW_STATUS` set to `True`, + # the `status` method should return `DataBackendStatus.OK`. 
+        elasticsearch_patch.setattr(backend.settings, "ALLOW_YELLOW_STATUS", True)
+        with caplog.at_level(logging.INFO):
+            assert await backend.status() == DataBackendStatus.OK
+
+    assert (
+        "ralph.backends.data.async_es",
+        logging.INFO,
+        "Cluster status is yellow.",
+    ) in caplog.record_tuples
+
+    # Given a connection exception, the `status` method should return
+    # `DataBackendStatus.ERROR`.
+    with monkeypatch.context() as elasticsearch_patch:
+
+        async def mock_connection_error():
+            """ES client info mock that raises a connection error."""
+            raise ESConnectionError("", (Exception("Mocked connection error"),))
+
+        elasticsearch_patch.setattr(backend.client, "info", mock_connection_error)
+        with caplog.at_level(logging.ERROR):
+            assert await backend.status() == DataBackendStatus.AWAY
+
+    assert (
+        "ralph.backends.data.async_es",
+        logging.ERROR,
+        "Failed to connect to Elasticsearch: Connection error caused by: "
+        "Exception(Mocked connection error)",
+    ) in caplog.record_tuples
+
+    await backend.close()
+
+
+@pytest.mark.parametrize(
+    "exception, error",
+    [
+        (ApiError("", ApiResponseMeta(*([None] * 5)), None), "ApiError(None, '')"),
+        (ESConnectionError(""), "Connection error"),
+    ],
+)
+@pytest.mark.anyio
+async def test_backends_data_async_es_data_backend_list_method_with_failure(
+    exception, error, caplog, monkeypatch, async_es_backend
+):
+    """Test the `AsyncESDataBackend.list` method, given a failed Elasticsearch
+    connection, should raise a `BackendException` and log an error message.
+    """
+
+    async def mock_get(index):
+        """Mock the AsyncES.client.indices.get method always raising an exception."""
+        assert index == "*"
+        raise exception
+
+    backend = async_es_backend()
+    monkeypatch.setattr(backend.client.indices, "get", mock_get)
+    with caplog.at_level(logging.ERROR):
+        with pytest.raises(BackendException):
+            async for result in backend.list():
+                next(result)
+
+    assert (
+        "ralph.backends.data.async_es",
+        logging.ERROR,
+        f"Failed to read indices: {error}",
+    ) in caplog.record_tuples
+
+    await backend.close()
+
+
+@pytest.mark.anyio
+async def test_backends_data_async_es_data_backend_list_method_without_history(
+    async_es_backend, monkeypatch
+):
+    """Test the `AsyncESDataBackend.list` method without history."""
+
+    indices = {"index_1": {"info_1": "foo"}, "index_2": {"info_2": "baz"}}
+
+    async def mock_get(index):
+        """Mock the AsyncES.client.indices.get method returning a dictionary."""
+        assert index == "target_index*"
+        return indices
+
+    backend = async_es_backend()
+    monkeypatch.setattr(backend.client.indices, "get", mock_get)
+    result = [statement async for statement in backend.list("target_index*")]
+    assert isinstance(result, Iterable)
+    assert list(result) == list(indices.keys())
+
+    await backend.close()
+
+
+@pytest.mark.anyio
+async def test_backends_data_async_es_data_backend_list_method_with_details(
+    async_es_backend, monkeypatch
+):
+    """Test the `AsyncESDataBackend.list` method with `details` set to `True`."""
+    indices = {"index_1": {"info_1": "foo"}, "index_2": {"info_2": "baz"}}
+
+    async def mock_get(index):
+        """Mock the AsyncES.client.indices.get method returning a dictionary."""
+        assert index == "target_index*"
+        return indices
+
+    backend = async_es_backend()
+    monkeypatch.setattr(backend.client.indices, "get", mock_get)
+    result = [
+        statement async for statement in backend.list("target_index*", details=True)
+    ]
+    assert isinstance(result, Iterable)
+    assert list(result) == [
+        {"index_1": {"info_1": "foo"}},
+        {"index_2": {"info_2": 
"baz"}}, + ] + + await backend.close() + + +@pytest.mark.anyio +async def test_backends_data_async_es_data_backend_list_method_with_history( + async_es_backend, caplog, monkeypatch +): + """Test the `AsyncESDataBackend.list` method given `new` argument set to True, + should log a warning message. + """ + backend = async_es_backend() + + async def mock_get(*args, **kwargs): # pylint: disable=unused-argument + return {} + + monkeypatch.setattr(backend.client.indices, "get", mock_get) + with caplog.at_level(logging.WARNING): + result = [statement async for statement in backend.list(new=True)] + assert not list(result) + + assert ( + "ralph.backends.data.async_es", + logging.WARNING, + "The `new` argument is ignored", + ) in caplog.record_tuples + + await backend.close() + + +@pytest.mark.parametrize( + "exception, error", + [ + (ApiError("", ApiResponseMeta(*([None] * 5)), None), r"ApiError\(None, ''\)"), + (ESConnectionError(""), "Connection error"), + ], +) +@pytest.mark.anyio +async def test_backends_data_async_es_data_backend_read_method_with_failure( + exception, error, es, async_es_backend, caplog, monkeypatch +): + """Test the `AsyncESDataBackend.read` method, given a request failure, should + raise a `BackendException`. + """ + # pylint: disable=invalid-name,unused-argument,too-many-arguments + + def mock_async_es_search_open_pit(**kwargs): + """Mock the AsyncES.client.search and open_point_in_time methods always raising + an exception. + """ + raise exception + + backend = async_es_backend() + + # Search failure. + monkeypatch.setattr(backend.client, "search", mock_async_es_search_open_pit) + with pytest.raises( + BackendException, match=f"Failed to execute Elasticsearch query: {error}" + ): + with caplog.at_level(logging.ERROR): + result = [statement async for statement in backend.read()] + next(iter(result)) + + assert ( + "ralph.backends.data.async_es", + logging.ERROR, + "Failed to execute Elasticsearch query: %s" % error.replace("\\", ""), + ) in caplog.record_tuples + + # Open point in time failure. + monkeypatch.setattr( + backend.client, "open_point_in_time", mock_async_es_search_open_pit + ) + with pytest.raises( + BackendException, match=f"Failed to open Elasticsearch point in time: {error}" + ): + with caplog.at_level(logging.ERROR): + result = [statement async for statement in backend.read()] + next(iter(result)) + + error = error.replace("\\", "") + assert ( + "ralph.backends.data.async_es", + logging.ERROR, + "Failed to open Elasticsearch point in time: %s" % error.replace("\\", ""), + ) in caplog.record_tuples + + await backend.close() + + +@pytest.mark.anyio +async def test_backends_data_async_es_data_backend_read_method_with_ignore_errors( + es, async_es_backend, monkeypatch, caplog +): + """Test the `AsyncESDataBackend.read` method, given `ignore_errors` set to `True`, + should log a warning message. 
+ """ + # pylint: disable=invalid-name, unused-argument + backend = async_es_backend() + + async def mock_async_es_search(**kwargs): # pylint: disable=unused-argument + return {"hits": {"hits": []}} + + monkeypatch.setattr(backend.client, "search", mock_async_es_search) + with caplog.at_level(logging.WARNING): + _ = [statement async for statement in backend.read(ignore_errors=True)] + + assert ( + "ralph.backends.data.async_es", + logging.WARNING, + "The `ignore_errors` argument is ignored", + ) in caplog.record_tuples + + await backend.close() + + +@pytest.mark.anyio +async def test_backends_data_async_es_data_backend_read_method_with_raw_ouput( + es, async_es_backend +): + """Test the `AsyncESDataBackend.read` method with `raw_output` set to `True`.""" + # pylint: disable=invalid-name,unused-argument + backend = async_es_backend() + documents = [{"id": idx, "timestamp": now()} for idx in range(10)] + assert await backend.write(documents) == 10 + hits = [statement async for statement in backend.read(raw_output=True)] + for i, hit in enumerate(hits): + assert isinstance(hit, bytes) + assert json.loads(hit).get("_source") == documents[i] + + await backend.close() + + +@pytest.mark.anyio +async def test_backends_data_async_es_data_backend_read_method_without_raw_ouput( + es, async_es_backend +): + """Test the `AsyncESDataBackend.read` method with `raw_output` set to `False`.""" + # pylint: disable=invalid-name,unused-argument + backend = async_es_backend() + documents = [{"id": idx, "timestamp": now()} for idx in range(10)] + assert await backend.write(documents) == 10 + hits = [statement async for statement in backend.read()] + for i, hit in enumerate(hits): + assert isinstance(hit, dict) + assert hit.get("_source") == documents[i] + + await backend.close() + + +@pytest.mark.anyio +async def test_backends_data_async_es_data_backend_read_method_with_query( + es, async_es_backend, caplog +): + """Test the `AsyncESDataBackend.read` method with a query.""" + # pylint: disable=invalid-name,unused-argument + backend = async_es_backend() + documents = [{"id": idx, "timestamp": now(), "modulo": idx % 2} for idx in range(5)] + assert await backend.write(documents) == 5 + # Find every even item. + query = ESQuery(query={"term": {"modulo": 0}}) + results = [statement async for statement in backend.read(query=query)] + assert len(results) == 3 + assert results[0]["_source"]["id"] == 0 + assert results[1]["_source"]["id"] == 2 + assert results[2]["_source"]["id"] == 4 + + # Find the first two even items. + query = ESQuery(query={"term": {"modulo": 0}}, size=2) + results = [statement async for statement in backend.read(query=query)] + assert len(results) == 2 + assert results[0]["_source"]["id"] == 0 + assert results[1]["_source"]["id"] == 2 + + # Find the first ten even items although there are only three available. + query = ESQuery(query={"term": {"modulo": 0}}, size=10) + results = [statement async for statement in backend.read(query=query)] + assert len(results) == 3 + assert results[0]["_source"]["id"] == 0 + assert results[1]["_source"]["id"] == 2 + assert results[2]["_source"]["id"] == 4 + # Find every odd item. + query = {"query": {"term": {"modulo": 1}}} + results = [statement async for statement in backend.read(query=query)] + assert len(results) == 2 + assert results[0]["_source"]["id"] == 1 + assert results[1]["_source"]["id"] == 3 + + # Find documents with ID equal to one or five. 
+ query = "id:(1 OR 5)" + results = [statement async for statement in backend.read(query=query)] + assert len(results) == 1 + assert results[0]["_source"]["id"] == 1 + + # Check query argument type + with pytest.raises( + BackendParameterException, + match="'query' argument is expected to be a ESQuery instance.", + ): + with caplog.at_level(logging.ERROR): + _ = [ + statement + async for statement in backend.read(query={"not_query": "foo"}) + ] + + assert ( + "ralph.backends.data.base", + logging.ERROR, + "The 'query' argument is expected to be a ESQuery instance. " + "[{'loc': ('not_query',), 'msg': 'extra fields not permitted', " + "'type': 'value_error.extra'}]", + ) in caplog.record_tuples + + await backend.close() + + +@pytest.mark.anyio +async def test_backends_data_async_es_data_backend_write_method_with_create_operation( + es, async_es_backend, caplog +): + """Test the `AsyncESDataBackend.write` method, given an `CREATE` `operation_type`, + should insert the target documents with the provided data. + """ + # pylint: disable=invalid-name,unused-argument + + backend = async_es_backend() + assert len([statement async for statement in backend.read()]) == 0 + + # Given an empty data iterator, the write method should return 0 and log a message. + data = [] + with caplog.at_level(logging.INFO): + assert await backend.write(data, operation_type=BaseOperationType.CREATE) == 0 + + assert ( + "ralph.backends.data.async_es", + logging.INFO, + "Data Iterator is empty; skipping write to target.", + ) in caplog.record_tuples + + # Given an iterator with multiple documents, the write method should write the + # documents to the default target index. + data = ({"value": str(idx)} for idx in range(9)) + with caplog.at_level(logging.DEBUG): + assert ( + await backend.write( + data, chunk_size=5, operation_type=BaseOperationType.CREATE + ) + == 9 + ) + + write_records = 0 + for record in caplog.record_tuples: + if re.match(r"^Wrote 1 document \[action: \{.*\}\]$", record[2]): + write_records += 1 + assert write_records == 9 + + assert ( + "ralph.backends.data.async_es", + logging.INFO, + "Finished writing 9 documents with success", + ) in caplog.record_tuples + + hits = [statement async for statement in backend.read()] + assert [hit["_source"] for hit in hits] == [{"value": str(idx)} for idx in range(9)] + + await backend.close() + + +@pytest.mark.anyio +async def test_backends_data_async_es_data_backend_write_method_with_delete_operation( + es, + async_es_backend, +): + """Test the `AsyncESDataBackend.write` method, given a `DELETE` `operation_type`, + should remove the target documents. + """ + # pylint: disable=invalid-name,unused-argument + + backend = async_es_backend() + data = [{"id": idx, "value": str(idx)} for idx in range(10)] + + assert len([statement async for statement in backend.read()]) == 0 + assert await backend.write(data, chunk_size=5) == 10 + + data = [{"id": idx} for idx in range(3)] + + assert ( + await backend.write(data, chunk_size=5, operation_type=BaseOperationType.DELETE) + == 3 + ) + + hits = [statement async for statement in backend.read()] + assert len(hits) == 7 + assert sorted([hit["_source"]["id"] for hit in hits]) == list(range(3, 10)) + + await backend.close() + + +@pytest.mark.anyio +async def test_backends_data_async_es_data_backend_write_method_with_update_operation( + es, + async_es_backend, +): + """Test the `AsyncESDataBackend.write` method, given an `UPDATE` + `operation_type`, should overwrite the target documents with the provided data. 
+ """ + # pylint: disable=invalid-name,unused-argument + + backend = async_es_backend() + data = BytesIO( + "\n".join( + [json.dumps({"id": idx, "value": str(idx)}) for idx in range(10)] + ).encode("utf8") + ) + + assert len([statement async for statement in backend.read()]) == 0 + assert await backend.write(data, chunk_size=5) == 10 + + hits = [statement async for statement in backend.read()] + assert len(hits) == 10 + assert sorted([hit["_source"]["id"] for hit in hits]) == list(range(10)) + assert sorted([hit["_source"]["value"] for hit in hits]) == list( + map(str, range(10)) + ) + + data = BytesIO( + "\n".join( + [json.dumps({"id": idx, "value": str(10 + idx)}) for idx in range(10)] + ).encode("utf8") + ) + + assert ( + await backend.write(data, chunk_size=5, operation_type=BaseOperationType.UPDATE) + == 10 + ) + + hits = [statement async for statement in backend.read()] + assert len(hits) == 10 + assert sorted([hit["_source"]["id"] for hit in hits]) == list(range(10)) + assert sorted([hit["_source"]["value"] for hit in hits]) == list( + map(lambda x: str(x + 10), range(10)) + ) + + await backend.close() + + +@pytest.mark.anyio +async def test_backends_data_async_es_data_backend_write_method_with_append_operation( + async_es_backend, caplog +): + """Test the `AsyncESDataBackend.write` method, given an `APPEND` `operation_type`, + should raise a `BackendParameterException`. + """ + backend = async_es_backend() + msg = "Append operation_type is not supported." + with pytest.raises(BackendParameterException, match=msg): + with caplog.at_level(logging.ERROR): + await backend.write(data=[{}], operation_type=BaseOperationType.APPEND) + + assert ( + "ralph.backends.data.async_es", + logging.ERROR, + "Append operation_type is not supported.", + ) in caplog.record_tuples + + await backend.close() + + +@pytest.mark.anyio +async def test_backends_data_async_es_data_backend_write_method_with_target( + es, async_es_backend +): + """Test the `AsyncESDataBackend.write` method, given a target index, should insert + documents to the corresponding index. + """ + # pylint: disable=invalid-name,unused-argument + + backend = async_es_backend() + + def get_data(): + """Yield data.""" + yield {"value": "1"} + yield {"value": "2"} + + # Create second Elasticsearch index. + for _ in get_es_fixture(index=ES_TEST_FORWARDING_INDEX): + # Both indexes should be empty. + assert len([statement async for statement in backend.read()]) == 0 + assert ( + len( + [ + statement + async for statement in backend.read(target=ES_TEST_FORWARDING_INDEX) + ] + ) + == 0 + ) + + # Write to forwarding index. + assert await backend.write(get_data(), target=ES_TEST_FORWARDING_INDEX) == 2 + + hits = [statement async for statement in backend.read()] + hits_with_target = [ + statement + async for statement in backend.read(target=ES_TEST_FORWARDING_INDEX) + ] + # No documents should be inserted into the default index. + assert not hits + # Documents should be inserted into the target index. + assert [hit["_source"] for hit in hits_with_target] == [ + {"value": "1"}, + {"value": "2"}, + ] + + await backend.close() + + +@pytest.mark.anyio +async def test_backends_data_async_es_data_backend_write_method_without_ignore_errors( + es, async_es_backend, caplog +): + """Test the `AsyncESDataBackend.write` method with `ignore_errors` set to `False`, + given badly formatted data, should raise a `BackendException`. 
+ """ + # pylint: disable=invalid-name,unused-argument + + data = [{"id": idx, "count": random.randint(0, 100)} for idx in range(10)] + # Patch a record with a non-expected type for the count field (should be + # assigned as long) + data[4].update({"count": "wrong"}) + + backend = async_es_backend() + assert len([statement async for statement in backend.read()]) == 0 + + # By default, we should raise an error and stop the importation. + msg = ( + r"1 document\(s\) failed to index. " + r"\[\{'index': \{'_index': 'test-index-foo', '_id': '4', 'status': 400, 'error'" + r": \{'type': 'mapper_parsing_exception', 'reason': \"failed to parse field " + r"\[count\] of type \[long\] in document with id '4'. Preview of field's value:" + r" 'wrong'\", 'caused_by': \{'type': 'illegal_argument_exception', 'reason': " + r"'For input string: \"wrong\"'\}\}, 'data': \{'id': 4, 'count': 'wrong'\}\}\}" + r"\] Total succeeded writes: 5" + ) + with pytest.raises(BackendException, match=msg): + with caplog.at_level(logging.ERROR): + await backend.write(data, chunk_size=2) + + assert ( + "ralph.backends.data.async_es", + logging.ERROR, + msg.replace("\\", ""), + ) in caplog.record_tuples + + es.indices.refresh(index=ES_TEST_INDEX) + hits = [statement async for statement in backend.read()] + assert len(hits) == 5 + assert sorted([hit["_source"]["id"] for hit in hits]) == [0, 1, 2, 3, 5] + + # Given an unparsable binary JSON document, the write method should raise a + # `BackendException`. + data = [ + json.dumps({"foo": "bar"}).encode("utf-8"), + "This is invalid JSON".encode("utf-8"), + json.dumps({"foo": "baz"}).encode("utf-8"), + ] + + # By default, we should raise an error and stop the importation. + msg = ( + r"Failed to decode JSON: Expecting value: line 1 column 1 \(char 0\), " + r"for document: b'This is invalid JSON'" + ) + with pytest.raises(BackendException, match=msg): + with caplog.at_level(logging.ERROR): + await backend.write(data, chunk_size=2) + + assert ( + "ralph.backends.data.async_es", + logging.ERROR, + msg.replace("\\", ""), + ) in caplog.record_tuples + + es.indices.refresh(index=ES_TEST_INDEX) + hits = [statement async for statement in backend.read()] + assert len(hits) == 5 + + await backend.close() + + +@pytest.mark.anyio +async def test_backends_data_async_es_data_backend_write_method_with_ignore_errors( + es, async_es_backend +): + """Test the `AsyncESDataBackend.write` method with `ignore_errors` set to `True`, + given badly formatted data, should should skip the invalid data. + """ + # pylint: disable=invalid-name,unused-argument + + records = [{"id": idx, "count": random.randint(0, 100)} for idx in range(10)] + # Patch a record with a non-expected type for the count field (should be + # assigned as long) + records[2].update({"count": "wrong"}) + + backend = async_es_backend() + assert len([statement async for statement in backend.read()]) == 0 + + assert await backend.write(records, chunk_size=2, ignore_errors=True) == 9 + + es.indices.refresh(index=ES_TEST_INDEX) + hits = [statement async for statement in backend.read()] + assert len(hits) == 9 + assert sorted([hit["_source"]["id"] for hit in hits]) == [ + i for i in range(10) if i != 2 + ] + + # Given an unparsable binary JSON document, the write method should skip it. 
+    data = [
+        json.dumps({"foo": "bar"}).encode("utf-8"),
+        "This is invalid JSON".encode("utf-8"),
+        json.dumps({"foo": "baz"}).encode("utf-8"),
+    ]
+    assert await backend.write(data, chunk_size=2, ignore_errors=True) == 2
+
+    es.indices.refresh(index=ES_TEST_INDEX)
+    hits = [statement async for statement in backend.read()]
+    assert len(hits) == 11
+    assert [hit["_source"] for hit in hits[9:]] == [{"foo": "bar"}, {"foo": "baz"}]
+
+    await backend.close()
+
+
+@pytest.mark.anyio
+async def test_backends_data_async_es_data_backend_write_method_with_datastream(
+    es_data_stream, async_es_backend
+):
+    """Test the `AsyncESDataBackend.write` method using a configured data stream."""
+    # pylint: disable=invalid-name,unused-argument
+
+    data = [{"id": idx, "@timestamp": datetime.now().isoformat()} for idx in range(10)]
+    backend = async_es_backend()
+    assert len([statement async for statement in backend.read()]) == 0
+    assert (
+        await backend.write(data, chunk_size=5, operation_type=BaseOperationType.CREATE)
+        == 10
+    )
+
+    hits = [statement async for statement in backend.read()]
+    assert len(hits) == 10
+    assert sorted([hit["_source"]["id"] for hit in hits]) == list(range(10))
+
+    await backend.close()
+
+
+@pytest.mark.anyio
+async def test_backends_data_async_es_data_backend_close_method(
+    async_es_backend, monkeypatch
+):
+    """Test the `AsyncESDataBackend.close` method."""
+
+    backend = async_es_backend()
+
+    async def mock_connection_error():
+        """ES client close mock that raises a connection error."""
+        raise ESConnectionError("", (Exception("Mocked connection error"),))
+
+    monkeypatch.setattr(backend.client, "close", mock_connection_error)
+
+    with pytest.raises(BackendException, match="Failed to close Elasticsearch client"):
+        await backend.close()
diff --git a/tests/backends/lrs/test_async_es.py b/tests/backends/lrs/test_async_es.py
new file mode 100644
index 000000000..9dd9e7466
--- /dev/null
+++ b/tests/backends/lrs/test_async_es.py
@@ -0,0 +1,421 @@
+"""Tests for Ralph Async Elasticsearch LRS backend."""
+
+import logging
+import re
+from datetime import datetime
+
+import pytest
+from elastic_transport import ApiResponseMeta
+from elasticsearch import ApiError
+from elasticsearch.helpers import bulk
+
+from ralph.backends.lrs.base import StatementParameters
+from ralph.exceptions import BackendException
+
+from tests.fixtures.backends import ES_TEST_FORWARDING_INDEX, ES_TEST_INDEX
+
+
+@pytest.mark.parametrize(
+    "params,expected_query",
+    [
+        # 0. Default query.
+        (
+            {},
+            {
+                "pit": {"id": None, "keep_alive": None},
+                "query": {"match_all": {}},
+                "query_string": None,
+                "search_after": None,
+                "size": None,
+                "sort": [{"timestamp": {"order": "desc"}}],
+                "track_total_hits": False,
+            },
+        ),
+        # 1. Query by statementId.
+        (
+            {"statementId": "statementId"},
+            {
+                "pit": {"id": None, "keep_alive": None},
+                "query": {"bool": {"filter": [{"term": {"_id": "statementId"}}]}},
+                "query_string": None,
+                "search_after": None,
+                "size": None,
+                "sort": [{"timestamp": {"order": "desc"}}],
+                "track_total_hits": False,
+            },
+        ),
+        # 2. Query by statementId and agent with mbox IFI.
+        (
+            {"statementId": "statementId", "agent": {"mbox": "mailto:foo@bar.baz"}},
+            {
+                "pit": {"id": None, "keep_alive": None},
+                "query": {
+                    "bool": {
+                        "filter": [
+                            {"term": {"_id": "statementId"}},
+                            {"term": {"actor.mbox.keyword": "mailto:foo@bar.baz"}},
+                        ]
+                    }
+                },
+                "query_string": None,
+                "search_after": None,
+                "size": None,
+                "sort": [{"timestamp": {"order": "desc"}}],
+                "track_total_hits": False,
+            },
+        ),
+        # 3. 
Query by statementId and agent with mbox_sha1sum IFI. + ( + { + "statementId": "statementId", + "agent": {"mbox_sha1sum": "a7a5b7462b862c8c8767d43d43e865ffff754a64"}, + }, + { + "pit": {"id": None, "keep_alive": None}, + "query": { + "bool": { + "filter": [ + {"term": {"_id": "statementId"}}, + { + "term": { + "actor.mbox_sha1sum.keyword": ( + "a7a5b7462b862c8c8767d43d43e865ffff754a64" + ) + } + }, + ] + } + }, + "query_string": None, + "search_after": None, + "size": None, + "sort": [{"timestamp": {"order": "desc"}}], + "track_total_hits": False, + }, + ), + # 4. Query by statementId and agent with openid IFI. + ( + { + "statementId": "statementId", + "agent": {"openid": "http://toby.openid.example.org/"}, + }, + { + "pit": {"id": None, "keep_alive": None}, + "query": { + "bool": { + "filter": [ + {"term": {"_id": "statementId"}}, + { + "term": { + "actor.openid.keyword": ( + "http://toby.openid.example.org/" + ) + } + }, + ] + } + }, + "query_string": None, + "search_after": None, + "size": None, + "sort": [{"timestamp": {"order": "desc"}}], + "track_total_hits": False, + }, + ), + # 5. Query by statementId and agent with account IFI. + ( + { + "statementId": "statementId", + "agent": { + "account__home_page": "http://www.example.com", + "account__name": "13936749", + }, + }, + { + "pit": {"id": None, "keep_alive": None}, + "query": { + "bool": { + "filter": [ + {"term": {"_id": "statementId"}}, + {"term": {"actor.account.name.keyword": ("13936749")}}, + { + "term": { + "actor.account.homePage.keyword": ( + "http://www.example.com" + ) + } + }, + ] + } + }, + "query_string": None, + "search_after": None, + "size": None, + "sort": [{"timestamp": {"order": "desc"}}], + "track_total_hits": False, + }, + ), + # 6. Query by verb and activity. + ( + { + "verb": "http://adlnet.gov/expapi/verbs/attended", + "activity": "http://www.example.com/meetings/34534", + }, + { + "pit": {"id": None, "keep_alive": None}, + "query": { + "bool": { + "filter": [ + { + "term": { + "verb.id.keyword": ( + "http://adlnet.gov/expapi/verbs/attended" + ) + } + }, + {"term": {"object.objectType.keyword": "Activity"}}, + { + "term": { + "object.id.keyword": ( + "http://www.example.com/meetings/34534" + ) + } + }, + ] + } + }, + "query_string": None, + "search_after": None, + "size": None, + "sort": [{"timestamp": {"order": "desc"}}], + "track_total_hits": False, + }, + ), + # 7. Query by timerange (with since/until). + ( + { + "since": "2021-06-24T00:00:20.194929+00:00", + "until": "2023-06-24T00:00:20.194929+00:00", + }, + { + "pit": {"id": None, "keep_alive": None}, + "query": { + "bool": { + "filter": [ + { + "range": { + "timestamp": { + "gt": datetime.fromisoformat( + "2021-06-24T00:00:20.194929+00:00" + ) + } + } + }, + { + "range": { + "timestamp": { + "lte": datetime.fromisoformat( + "2023-06-24T00:00:20.194929+00:00" + ) + } + } + }, + ] + } + }, + "query_string": None, + "search_after": None, + "size": None, + "sort": [{"timestamp": {"order": "desc"}}], + "track_total_hits": False, + }, + ), + # 8. Query with pagination and pit_id. + ( + {"search_after": "1686557542970|0", "pit_id": "46ToAwMDaWR5BXV1a"}, + { + "pit": {"id": "46ToAwMDaWR5BXV1a", "keep_alive": None}, + "query": {"match_all": {}}, + "query_string": None, + "search_after": ["1686557542970", "0"], + "size": None, + "sort": [{"timestamp": {"order": "desc"}}], + "track_total_hits": False, + }, + ), + # 9. Query ignoring statement sort order. 
+        (
+            {"ignore_order": True},
+            {
+                "pit": {"id": None, "keep_alive": None},
+                "query": {"match_all": {}},
+                "query_string": None,
+                "search_after": None,
+                "size": None,
+                "sort": "_shard_doc",
+                "track_total_hits": False,
+            },
+        ),
+    ],
+)
+@pytest.mark.anyio
+async def test_backends_lrs_async_es_lrs_backend_query_statements_query(
+    params, expected_query, async_es_lrs_backend, monkeypatch
+):
+    """Test the `AsyncESLRSBackend.query_statements` method, given valid statement
+    parameters, should produce the expected Elasticsearch query.
+    """
+
+    async def mock_read(query, chunk_size):
+        """Mock the `AsyncESLRSBackend.read` method."""
+        assert query.dict() == expected_query
+        assert chunk_size == expected_query.get("size")
+        query.pit.id = "foo_pit_id"
+        query.search_after = ["bar_search_after", "baz_search_after"]
+        yield {"_source": {}}
+
+    backend = async_es_lrs_backend()
+    monkeypatch.setattr(backend, "read", mock_read)
+    result = await backend.query_statements(StatementParameters(**params))
+    assert result.statements == [{}]
+    assert result.pit_id == "foo_pit_id"
+    assert result.search_after == "bar_search_after|baz_search_after"
+
+    await backend.close()
+
+
+@pytest.mark.anyio
+async def test_backends_lrs_async_es_lrs_backend_query_statements(
+    es, async_es_lrs_backend
+):
+    """Test the `AsyncESLRSBackend.query_statements` method, given a query,
+    should return matching statements.
+    """
+    # pylint: disable=invalid-name, unused-argument
+    # Instantiate AsyncESLRSBackend.
+    backend = async_es_lrs_backend()
+    # Insert documents.
+    documents = [{"id": "2", "timestamp": "2023-06-24T00:00:20.194929+00:00"}]
+    assert await backend.write(documents) == 1
+
+    # Check the expected search query results.
+    result = await backend.query_statements(StatementParameters(limit=10))
+    assert result.statements == documents
+    assert re.match(r"[0-9]+\|0", result.search_after)
+
+    await backend.close()
+
+
+@pytest.mark.anyio
+async def test_backends_lrs_async_es_lrs_backend_query_statements_pit_query_failure(
+    es, async_es_lrs_backend, monkeypatch, caplog
+):
+    """Test the `AsyncESLRSBackend.query_statements` method, given a point in time
+    query failure, should raise a `BackendException` and log the error.
+    """
+    # pylint: disable=invalid-name,unused-argument
+
+    async def mock_read(**_):
+        """Mock the `AsyncESLRSBackend.read` method."""
+        yield {"_source": {}}
+        raise BackendException("Query error")
+
+    backend = async_es_lrs_backend()
+    monkeypatch.setattr(backend, "read", mock_read)
+
+    msg = "Query error"
+    with pytest.raises(BackendException, match=msg):
+        with caplog.at_level(logging.ERROR):
+            await backend.query_statements(StatementParameters())
+
+    await backend.close()
+
+    assert (
+        "ralph.backends.lrs.async_es",
+        logging.ERROR,
+        "Failed to read from Elasticsearch",
+    ) in caplog.record_tuples
+
+
+@pytest.mark.anyio
+async def test_backends_lrs_async_es_lrs_backend_query_statements_by_ids_search_query_failure(  # noqa: E501, pylint:disable=line-too-long
+    es, async_es_lrs_backend, monkeypatch, caplog
+):
+    """Test the `AsyncESLRSBackend.query_statements_by_ids` method, given a search
+    query failure, should raise a `BackendException` and log the error.
+    """
+    # pylint: disable=invalid-name,unused-argument
+
+    def mock_search(**_):
+        """Mock the Elasticsearch.search method."""
+        raise ApiError("Query error", ApiResponseMeta(*([None] * 5)), None)
+
+    backend = async_es_lrs_backend()
+    monkeypatch.setattr(backend.client, "search", mock_search)
+
+    msg = r"Failed to execute Elasticsearch query: ApiError\(None, 'Query error'\)"
+    with pytest.raises(BackendException, match=msg):
+        with caplog.at_level(logging.ERROR):
+            _ = [
+                statement
+                async for statement in backend.query_statements_by_ids(
+                    StatementParameters()
+                )
+            ]
+
+    await backend.close()
+
+    assert (
+        "ralph.backends.lrs.async_es",
+        logging.ERROR,
+        "Failed to read from Elasticsearch",
+    ) in caplog.record_tuples
+
+
+@pytest.mark.anyio
+async def test_backends_lrs_async_es_lrs_backend_query_statements_by_ids_many_indexes(
+    es, es_forwarding, async_es_lrs_backend
+):
+    """Test the `AsyncESLRSBackend.query_statements_by_ids` method, given a valid
+    search query, should execute the query only on the specified index and return
+    the expected results.
+    """
+    # pylint: disable=invalid-name
+
+    # Insert documents.
+    index_1_document = {"_index": ES_TEST_INDEX, "_id": "1", "_source": {"id": "1"}}
+    index_2_document = {
+        "_index": ES_TEST_FORWARDING_INDEX,
+        "_id": "2",
+        "_source": {"id": "2"},
+    }
+    bulk(es, [index_1_document])
+    bulk(es_forwarding, [index_2_document])
+
+    # As we bulk insert documents, the index needs to be refreshed before making
+    # queries.
+    es.indices.refresh(index=ES_TEST_INDEX)
+    es_forwarding.indices.refresh(index=ES_TEST_FORWARDING_INDEX)
+
+    # Instantiate AsyncESLRSBackends.
+    backend_1 = async_es_lrs_backend(index=ES_TEST_INDEX)
+    backend_2 = async_es_lrs_backend(index=ES_TEST_FORWARDING_INDEX)
+
+    # Check the expected search query results. 
+ index_1_document = {"id": "1"} + index_2_document = {"id": "2"} + assert [ + statement async for statement in backend_1.query_statements_by_ids(["1"]) + ] == [index_1_document] + assert not [ + statement async for statement in backend_1.query_statements_by_ids(["2"]) + ] + assert not [ + statement async for statement in backend_2.query_statements_by_ids(["1"]) + ] + assert [ + statement async for statement in backend_2.query_statements_by_ids(["2"]) + ] == [index_2_document] + + await backend_1.close() + await backend_2.close() diff --git a/tests/backends/lrs/test_es.py b/tests/backends/lrs/test_es.py index 802231fbb..89dbf5b45 100644 --- a/tests/backends/lrs/test_es.py +++ b/tests/backends/lrs/test_es.py @@ -307,7 +307,7 @@ def test_backends_lrs_es_lrs_backend_query_statements_with_search_query_failure( # pylint: disable=invalid-name,unused-argument def mock_read(**_): - """Mock the Elasticsearch.search method.""" + """Mock the Elasticsearch.read method.""" raise BackendException("Query error") backend = es_lrs_backend() diff --git a/tests/conftest.py b/tests/conftest.py index caa27c0d3..b6ec8e2cc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,6 +14,8 @@ ) from .fixtures.backends import ( # noqa: F401 anyio_backend, + async_es_backend, + async_es_lrs_backend, clickhouse, clickhouse_backend, clickhouse_lrs_backend, diff --git a/tests/fixtures/backends.py b/tests/fixtures/backends.py index 60efcf7a7..ba83ebb54 100644 --- a/tests/fixtures/backends.py +++ b/tests/fixtures/backends.py @@ -21,6 +21,7 @@ from pymongo import MongoClient from pymongo.errors import CollectionInvalid +from ralph.backends.data.async_es import AsyncESDataBackend from ralph.backends.data.clickhouse import ClickHouseDataBackend from ralph.backends.data.es import ESDataBackend from ralph.backends.data.fs import FSDataBackend, FSDataBackendSettings @@ -30,6 +31,7 @@ from ralph.backends.database.clickhouse import ClickHouseDatabase from ralph.backends.database.es import ESDatabase from ralph.backends.database.mongo import MongoDatabase +from ralph.backends.lrs.async_es import AsyncESLRSBackend from ralph.backends.lrs.clickhouse import ClickHouseLRSBackend from ralph.backends.lrs.es import ESLRSBackend from ralph.backends.lrs.fs import FSLRSBackend @@ -281,7 +283,6 @@ def es_data_stream(): client. """ client = Elasticsearch(ES_TEST_HOSTS) - # Create statements index template with enabled data stream index_patterns = [ES_TEST_INDEX_PATTERN] data_stream = {} @@ -295,9 +296,9 @@ def es_data_stream(): "dynamic_templates": [], "date_detection": True, "numeric_detection": True, - # Note: We define an explicit mapping of the `timestamp` field to allow the - # Elasticsearch database to be queried even if no document has been inserted - # before. + # Note: We define an explicit mapping of the `timestamp` field to allow + # the Elasticsearch database to be queried even if no document has + # been inserted before. 
"properties": { "timestamp": { "type": "date", @@ -360,6 +361,49 @@ def get_ldp_data_backend(service_name: str = "foo", stream_id: str = "bar"): return get_ldp_data_backend +@pytest.fixture +def async_es_backend(): + """Return the `get_async_es_data_backend` function.""" + # pylint: disable=invalid-name,redefined-outer-name,unused-argument + + def get_async_es_data_backend(): + """Return an instance of AsyncESDataBackend.""" + settings = AsyncESDataBackend.settings_class( + ALLOW_YELLOW_STATUS=False, + CLIENT_OPTIONS={"ca_certs": None, "verify_certs": None}, + DEFAULT_CHUNK_SIZE=500, + DEFAULT_INDEX=ES_TEST_INDEX, + HOSTS=ES_TEST_HOSTS, + LOCALE_ENCODING="utf8", + REFRESH_AFTER_WRITE=True, + ) + return AsyncESDataBackend(settings) + + return get_async_es_data_backend + + +@pytest.fixture +def async_es_lrs_backend(): + """Return the `get_async_es_lrs_backend` function.""" + # pylint: disable=invalid-name,redefined-outer-name,unused-argument + + def get_async_es_lrs_backend(index: str = ES_TEST_INDEX): + """Return an instance of AsyncESLRSBackend.""" + settings = AsyncESLRSBackend.settings_class( + ALLOW_YELLOW_STATUS=False, + CLIENT_OPTIONS={"ca_certs": None, "verify_certs": None}, + DEFAULT_CHUNK_SIZE=500, + DEFAULT_INDEX=index, + HOSTS=ES_TEST_HOSTS, + LOCALE_ENCODING="utf8", + POINT_IN_TIME_KEEP_ALIVE="1m", + REFRESH_AFTER_WRITE=True, + ) + return AsyncESLRSBackend(settings) + + return get_async_es_lrs_backend + + @pytest.fixture def clickhouse_backend(): """Return the `get_clickhouse_data_backend` function.""" From 12620e91f0fde38348e7ccfbaf82fb91d1b6903a Mon Sep 17 00:00:00 2001 From: SergioSim Date: Thu, 22 Jun 2023 14:53:53 +0200 Subject: [PATCH 18/65] =?UTF-8?q?=E2=9C=85(backends)=20update=20MongoDataB?= =?UTF-8?q?ackend=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We want to improve the current mongo data backend implementation to align it more with other unified data backends. 
--- src/ralph/backends/data/async_es.py | 4 +- src/ralph/backends/data/es.py | 4 +- src/ralph/backends/data/fs.py | 27 +- src/ralph/backends/data/mongo.py | 610 +++++----- src/ralph/backends/lrs/mongo.py | 137 +++ src/ralph/utils.py | 20 +- tests/backends/data/test_fs.py | 50 +- tests/backends/data/test_mongo.py | 1654 ++++++++++++--------------- tests/backends/data/test_swift.py | 2 +- tests/backends/lrs/test_mongo.py | 370 ++++++ tests/conftest.py | 2 + tests/fixtures/backends.py | 60 +- 12 files changed, 1618 insertions(+), 1322 deletions(-) create mode 100644 src/ralph/backends/lrs/mongo.py create mode 100644 tests/backends/lrs/test_mongo.py diff --git a/src/ralph/backends/data/async_es.py b/src/ralph/backends/data/async_es.py index 76ede8803..7f987b197 100644 --- a/src/ralph/backends/data/async_es.py +++ b/src/ralph/backends/data/async_es.py @@ -180,7 +180,9 @@ async def read( query.search_after = [str(part) for part in documents[-1]["sort"]] kwargs["search_after"] = query.search_after if raw_output: - documents = read_raw(documents, self.settings.LOCALE_ENCODING) + documents = read_raw( + documents, self.settings.LOCALE_ENCODING, ignore_errors, logger + ) for document in documents: yield document diff --git a/src/ralph/backends/data/es.py b/src/ralph/backends/data/es.py index b080e85cb..f131466f2 100644 --- a/src/ralph/backends/data/es.py +++ b/src/ralph/backends/data/es.py @@ -264,7 +264,9 @@ def read( query.search_after = [str(part) for part in documents[-1]["sort"]] kwargs["search_after"] = query.search_after if raw_output: - documents = read_raw(documents, self.settings.LOCALE_ENCODING) + documents = read_raw( + documents, self.settings.LOCALE_ENCODING, ignore_errors, logger + ) for document in documents: yield document diff --git a/src/ralph/backends/data/fs.py b/src/ralph/backends/data/fs.py index 54ba86b28..454c7209d 100644 --- a/src/ralph/backends/data/fs.py +++ b/src/ralph/backends/data/fs.py @@ -27,7 +27,7 @@ class FSDataBackendSettings(BaseDataBackendSettings): - """Represents the FileSystem data backend default configuration. + """FileSystem data backend default configuration. Attributes: DEFAULT_CHUNK_SIZE (int): The default chunk size for reading files. @@ -57,7 +57,12 @@ class FSDataBackend(HistoryMixin, BaseDataBackend): settings_class = FSDataBackendSettings def __init__(self, settings: settings_class = None): - """Creates the default target directory if it does not exist.""" + """Create the default target directory if it does not exist. + + Args: + settings (FSDataBackendSettings or None): The data backend settings. + If `settings` is `None`, a default settings instance is used instead. + """ self.settings = settings if settings else self.settings_class() self.default_chunk_size = self.settings.DEFAULT_CHUNK_SIZE self.default_directory = self.settings.DEFAULT_DIRECTORY_PATH @@ -72,7 +77,7 @@ def __init__(self, settings: settings_class = None): logger.debug("Default directory: %s", self.default_directory) def status(self) -> DataBackendStatus: - """Checks whether the default directory has appropriate permissions.""" + """Check whether the default directory has appropriate permissions.""" for mode in [os.R_OK, os.W_OK, os.X_OK]: if not os.access(self.default_directory, mode): logger.error( @@ -87,7 +92,7 @@ def status(self) -> DataBackendStatus: def list( self, target: str = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: - """Lists files and directories in the target directory. + """List files and directories in the target directory. 
Args: target (str or None): The directory path where to list the files and @@ -146,7 +151,7 @@ def read( raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: - """Reads files matching the query in the target folder and yields them. + """Read files matching the query in the target folder and yield them. Args: query: (str or BaseQuery): The relative pattern for the files to read. @@ -217,7 +222,7 @@ def write( # pylint: disable=too-many-arguments ignore_errors: bool = False, operation_type: Union[None, BaseOperationType] = None, ) -> int: - """Writes data records to the target file and return their count. + """Write data records to the target file and return their count. Args: data: (Iterable or IOBase): The data to write. @@ -241,7 +246,7 @@ def write( # pylint: disable=too-many-arguments Raises: BackendException: If the `operation_type` is `CREATE` or `INDEX` and the target file already exists. - BackendParameterException: If the `operation_type` is `DELETED` as it is not + BackendParameterException: If the `operation_type` is `DELETE` as it is not supported. """ data = iter(data) @@ -305,13 +310,13 @@ def write( # pylint: disable=too-many-arguments @staticmethod def _read_raw(file: IO, chunk_size: int, _ignore_errors: bool) -> Iterator[bytes]: - """Reads the `file` in chunks of size `chunk_size` and yields them.""" + """Read the `file` in chunks of size `chunk_size` and yield them.""" while chunk := file.read(chunk_size): yield chunk @staticmethod def _read_dict(file: IO, _chunk_size: int, ignore_errors: bool) -> Iterator[dict]: - """Reads the `file` by line and yields JSON parsed dictionaries.""" + """Read the `file` by line and yield JSON parsed dictionaries.""" for i, line in enumerate(file): try: yield json.loads(line) @@ -323,9 +328,9 @@ def _read_dict(file: IO, _chunk_size: int, ignore_errors: bool) -> Iterator[dict @staticmethod def _write_raw(file: IO, chunk: bytes) -> None: - """Writes the `chunk` bytes to the file.""" + """Write the `chunk` bytes to the file.""" file.write(chunk) def _write_dict(self, file: IO, chunk: dict) -> None: - """Writes the `chunk` dictionary to the file.""" + """Write the `chunk` dictionary to the file.""" file.write(bytes(f"{json.dumps(chunk)}\n", encoding=self.locale_encoding)) diff --git a/src/ralph/backends/data/mongo.py b/src/ralph/backends/data/mongo.py index befbb94c4..c7dcd1ae4 100644 --- a/src/ralph/backends/data/mongo.py +++ b/src/ralph/backends/data/mongo.py @@ -1,29 +1,22 @@ """MongoDB data backend for Ralph.""" +from __future__ import annotations + import hashlib -import json import logging import struct from io import IOBase from itertools import chain -from typing import ( - Any, - Dict, - Generator, - Iterable, - Iterator, - List, - Literal, - Optional, - Union, -) +from typing import Generator, Iterable, Iterator, List, Optional, Tuple, Union from uuid import uuid4 +from bson.errors import BSONError from bson.objectid import ObjectId from dateutil.parser import isoparse -from pydantic import Json -from pymongo import ASCENDING, DESCENDING, MongoClient, ReplaceOne -from pymongo.errors import BulkWriteError, ConnectionFailure, PyMongoError +from pydantic import Json, MongoDsn, constr +from pymongo import MongoClient, ReplaceOne +from pymongo.collection import Collection +from pymongo.errors import BulkWriteError, ConnectionFailure, InvalidName, PyMongoError from ralph.backends.data.base import ( BaseDataBackend, @@ -33,30 +26,28 @@ DataBackendStatus, enforce_query_checks, ) -from 
ralph.backends.lrs.base import ( - BaseLRSBackend, - StatementParameters, - StatementQueryResult, -) -from ralph.conf import BaseSettingsConfig, MongoClientOptions -from ralph.exceptions import ( - BackendException, - BackendParameterException, - BadFormatException, -) +from ralph.conf import BaseSettingsConfig, ClientOptions +from ralph.exceptions import BackendException, BackendParameterException +from ralph.utils import parse_bytes_to_dict, read_raw logger = logging.getLogger(__name__) +class MongoClientOptions(ClientOptions): + """MongoDB additional client options.""" + + document_class: str = None + tz_aware: bool = None + + class MongoDataBackendSettings(BaseDataBackendSettings): - """Represents the Mongo data backend default configuration. + """MongoDB data backend default configuration. Attributes: CONNECTION_URI (str): The MongoDB connection URI. - DATABASE (str): The MongoDB database to connect to. + DEFAULT_DATABASE (str): The MongoDB database to connect to. DEFAULT_COLLECTION (str): The MongoDB database collection to get objects from. - CLIENT_OPTIONS (MongoClientOptions): A dictionary of valid options - DEFAULT_QUERY_STRING (str): The default query string to use. + CLIENT_OPTIONS (MongoClientOptions): A dictionary of MongoDB client options. DEFAULT_CHUNK_SIZE (int): The default chunk size to use when none is provided. LOCALE_ENCODING (str): The locale encoding to use when none is provided. """ @@ -66,70 +57,73 @@ class Config(BaseSettingsConfig): env_prefix = "RALPH_BACKENDS__DATA__MONGO__" - CONNECTION_URI: str = None - DATABASE: str = None - DEFAULT_COLLECTION: str = None + CONNECTION_URI: MongoDsn = MongoDsn("mongodb://localhost:27017/", scheme="mongodb") + DEFAULT_DATABASE: constr(regex=r"^[^\s.$/\\\"\x00]+$") = "statements" # noqa : F722 + DEFAULT_COLLECTION: constr( + regex=r"^(?!.*\.\.)[^.$\x00]+(?:\.[^.$\x00]+)*$" # noqa : F722 + ) = "marsha" CLIENT_OPTIONS: MongoClientOptions = MongoClientOptions() - DEFAULT_QUERY_STRING: str = "*" DEFAULT_CHUNK_SIZE: int = 500 LOCALE_ENCODING: str = "utf8" -class MongoQuery(BaseQuery): - """Mongo query model.""" +class BaseMongoQuery(BaseQuery): + """Base MongoDB query model.""" - # pylint: disable=unsubscriptable-object - query_string: Optional[ - Json[ - Dict[ - Literal["filter", "projection"], - dict, - ] - ] - ] filter: Optional[dict] + limit: Optional[int] projection: Optional[dict] + sort: Optional[List[Tuple]] + + +class MongoQuery(BaseMongoQuery): + """MongoDB query model.""" + + # pylint: disable=unsubscriptable-object + query_string: Optional[Json[BaseMongoQuery]] class MongoDataBackend(BaseDataBackend): - """Mongo database backend.""" + """MongoDB data backend.""" name = "mongo" query_model = MongoQuery - default_operation_type = BaseOperationType.CREATE settings_class = MongoDataBackendSettings def __init__(self, settings: settings_class = None): - """Instantiates the Mongo client. + """Instantiate the MongoDB client. Args: - settings (MongoDataBackendSettings): The Mongo data backend settings. - CONNECTION_URI (str): The MongoDB connection URI. - DATABASE (str): The MongoDB database to connect to. - DEFAULT_COLLECTION (str): The MongoDB database collection. - CLIENT_OPTIONS (MongoClientOptions): A dictionary of valid options - DEFAULT_QUERY_STRING (str): The default query string to use. - DEFAULT_CHUNK_SIZE (int): The default chunk size to use. - LOCALE_ENCODING (str): The locale encoding to use when none is provided. + settings (MongoDataBackendSettings or None): The data backend settings. 
+ If `settings` is `None`, a default settings instance is used instead. """ + self.settings = settings if settings else self.settings_class() self.client = MongoClient( - settings.CONNECTION_URI, **settings.CLIENT_OPTIONS.dict() + self.settings.CONNECTION_URI, **self.settings.CLIENT_OPTIONS.dict() ) - self.database = getattr(self.client, settings.DATABASE) - self.collection = getattr(self.database, settings.DEFAULT_COLLECTION) - self.default_chunk_size = settings.DEFAULT_CHUNK_SIZE - self.locale_encoding = settings.LOCALE_ENCODING + self.database = self.client[self.settings.DEFAULT_DATABASE] + self.collection = self.database[self.settings.DEFAULT_COLLECTION] def status(self) -> DataBackendStatus: - """Checks MongoDB cluster connection status.""" - # Check Mongo cluster connection + """Check the MongoDB connection status. + + Returns: + DataBackendStatus: The status of the data backend. + """ + # Check MongoDB connection. try: self.client.admin.command("ping") - except ConnectionFailure: + except ConnectionFailure as error: + logger.error("Failed to connect to MongoDB: %s", error) return DataBackendStatus.AWAY - # Check cluster status - if self.client.admin.command("serverStatus").get("ok", 0.0) < 1.0: + # Check MongoDB server status. + try: + if self.client.admin.command("serverStatus").get("ok") != 1.0: + logger.error("MongoDB `serverStatus` command did not return 1.0") + return DataBackendStatus.ERROR + except PyMongoError as error: + logger.error("Failed to get MongoDB server status: %s", error) return DataBackendStatus.ERROR return DataBackendStatus.OK @@ -137,19 +131,38 @@ def status(self) -> DataBackendStatus: def list( self, target: str = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: - """Lists collections for a given database. + """List collections in the `target` database. Args: - target (str): The database to list collections from. - details (bool): Get detailed archive information instead of just ids. - new (bool): Given the history, list only not already fetched collections. + target (str or None): The MongoDB database name to list collections from. + If target is `None`, the `DEFAULT_DATABASE` is used instead. + details (bool): Get detailed collection information instead of just IDs. + new (bool): Ignored. + + Raises: + BackendException: If a failure during the list operation occurs. + BackendParameterException: If the `target` is not a valid database name. """ - database = self.database if not target else getattr(self.client, target) - for col in database.list_collections(): - if details: - yield col - else: - yield str(col.get("name")) + if new: + logger.warning("The `new` argument is ignored") + + try: + database = self.client[target] if target else self.database + except InvalidName as error: + msg = "The target=`%s` is not a valid database name: %s" + logger.error(msg, target, error) + raise BackendParameterException(msg % (target, error)) from error + + try: + for collection_info in database.list_collections(): + if details: + yield collection_info + else: + yield collection_info.get("name") + except PyMongoError as error: + msg = "Failed to list MongoDB collections: %s" + logger.error(msg, error) + raise BackendException(msg % error) from error @enforce_query_checks def read( @@ -161,125 +174,54 @@ def read( raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: - """Gets collection documents and yields them. + """Read documents matching the `query` from `target` collection and yield them. 
Args: - query (Union[str, MongoQuery]): The query to use when fetching documents. - target (str): The collection to get documents from. - chunk_size (Union[None, int]): The chunk size to use when fetching docs. - raw_output (bool): Whether to return raw bytes or deserialized documents. + query (str or MongoQuery): The MongoDB query to use when reading documents. + target (str or None): The MongoDB collection name to query. + If target is `None`, the `DEFAULT_COLLECTION` is used instead. + chunk_size (int or None): The chunk size for reading batches of documents. + If chunk_size is `None` the `DEFAULT_CHUNK_SIZE` is used instead. + raw_output (bool): Whether to yield dictionaries or bytes. ignore_errors (bool): Whether to ignore errors when reading documents. - """ - reader = self._read_raw if raw_output else self._read_dict - if not chunk_size: - chunk_size = self.default_chunk_size - find_kwargs = {} - if query.query_string: - find_kwargs = query.query_string - else: - find_kwargs = {"filter": query.filter, "projection": query.projection} - - # deserialize query_string if exists - for document in self._find(target=target, batch_size=chunk_size, **find_kwargs): - document.update({"_id": str(document.get("_id"))}) - yield reader(document) - @staticmethod - def to_documents( - data: Iterable[dict], - ignore_errors: bool = False, - operation_type: Union[None, BaseOperationType] = default_operation_type, - ) -> Generator[dict, None, None]: - """Converts `stream` lines (one statement per line) to Mongo documents. + Yields: + dict: If `raw_output` is False. + bytes: If `raw_output` is True. - We expect statements to have at least an `id` and a `timestamp` field that will - be used to compute a unique MongoDB Object ID. This ensures that we will not - duplicate statements in our database and allows us to support pagination. + Raises: + BackendException: If a failure during the read operation occurs. + BackendParameterException: If the `target` is not a valid collection name. """ - for statement in data: - if "id" not in statement and operation_type == BaseOperationType.CREATE: - msg = f"statement {statement} has no 'id' field" - if ignore_errors: - logger.warning(msg) - continue - raise BadFormatException(msg) - if "timestamp" not in statement: - msg = f"statement {statement} has no 'timestamp' field" - if ignore_errors: - logger.warning(msg) - continue - raise BadFormatException(msg) - try: - timestamp = int(isoparse(statement["timestamp"]).timestamp()) - except ValueError as err: - msg = f"statement {statement} has an invalid 'timestamp' field" - if ignore_errors: - logger.warning(msg) - continue - raise BadFormatException(msg) from err - document = { - "_id": ObjectId( - # This might become a problem in February 2106. - # Meanwhile, we use the timestamp in the _id field for pagination. 
- struct.pack(">I", timestamp) - + bytes.fromhex( - hashlib.sha256( - bytes(statement.get("id", str(uuid4())), "utf-8") - ).hexdigest()[:16] - ) - ), - "_source": statement, - } + if not chunk_size: + chunk_size = self.settings.DEFAULT_CHUNK_SIZE - yield document + query = (query.query_string if query.query_string else query).dict( + exclude={"query_string"}, exclude_unset=True + ) - def bulk_import(self, batch: list, ignore_errors: bool = False, collection=None): - """Inserts a batch of documents into the selected database collection.""" try: - collection = self.get_collection(collection) - new_documents = collection.insert_many(batch) - except BulkWriteError as error: - if not ignore_errors: - raise BackendException( - *error.args, f"{error.details['nInserted']} succeeded writes" - ) from error - logger.warning( - "Bulk importation failed for current documents chunk but you choose " - "to ignore it.", - ) - return error.details["nInserted"] - - inserted_count = len(new_documents.inserted_ids) - logger.debug("Inserted %d documents chunk with success", inserted_count) - - return inserted_count - - def bulk_delete(self, batch: list, collection=None): - """Deletes a batch of documents from the selected database collection.""" - collection = self.get_collection(collection) - new_documents = collection.delete_many({"_source.id": {"$in": batch}}) - deleted_count = new_documents.deleted_count - logger.debug("Deleted %d documents chunk with success", deleted_count) + collection = self.database[target] if target else self.collection + except InvalidName as error: + msg = "The target=`%s` is not a valid collection name: %s" + logger.error(msg, target, error) + raise BackendParameterException(msg % (target, error)) from error - return deleted_count - - def bulk_update(self, batch: list, collection=None): - """Update a batch of documents into the selected database collection.""" - collection = self.get_collection(collection) - new_documents = collection.bulk_write(batch) - modified_count = new_documents.modified_count - logger.debug("Updated %d documents chunk with success", modified_count) - return modified_count - - def get_collection(self, collection=None): - """Returns the collection to use for the current operation.""" - if collection is None: - collection = self.collection - elif isinstance(collection, str): - collection = getattr(self.database, collection) - return collection + try: + documents = collection.find(batch_size=chunk_size, **query) + documents = (d.update({"_id": str(d.get("_id"))}) or d for d in documents) + if raw_output: + documents = read_raw( + documents, self.settings.LOCALE_ENCODING, ignore_errors, logger + ) + for document in documents: + yield document + except (PyMongoError, IndexError, TypeError, ValueError) as error: + msg = "Failed to execute MongoDB query: %s" + logger.error(msg, error) + raise BackendException(msg % error) from error - def write( # pylint: disable=too-many-arguments disable=too-many-branches + def write( # pylint: disable=too-many-arguments self, data: Union[IOBase, Iterable[bytes], Iterable[dict]], target: Union[None, str] = None, @@ -287,22 +229,39 @@ def write( # pylint: disable=too-many-arguments disable=too-many-branches ignore_errors: bool = False, operation_type: Union[None, BaseOperationType] = None, ) -> int: - """Writes documents from the `stream` to the instance collection. + """Write `data` documents to the `target` collection and return their count. Args: - data: The data to write to the database. 
- target: The target collection to write to. - chunk_size: The number of documents to write at once. - ignore_errors: Whether to ignore errors or not. - operation_type: The operation type to use for the write operation. + data (Iterable or IOBase): The data containing documents to write. + target (str or None): The target MongoDB collection name. + chunk_size (int or None): The number of documents to write in one batch. + If chunk_size is `None` the `DEFAULT_CHUNK_SIZE` is used instead. + ignore_errors (bool): Whether to ignore errors or not. + operation_type (BaseOperationType or None): The mode of the write operation. + If `operation_type` is `None`, the `default_operation_type` is used + instead. See `BaseOperationType`. + + Returns: + int: The number of written documents. + + Raises: + BackendException: If a failure occurs while writing to MongoDB or + during document decoding and `ignore_errors` is set to `False`. + BackendParameterException: If the `operation_type` is `APPEND` as it is not + supported. """ if not operation_type: operation_type = self.default_operation_type + if operation_type == BaseOperationType.APPEND: + msg = "Append operation_type is not allowed." + logger.error(msg) + raise BackendParameterException(msg) + if not chunk_size: - chunk_size = self.default_chunk_size + chunk_size = self.settings.DEFAULT_CHUNK_SIZE - collection = self.get_collection(target) + collection = self.database[target] if target else self.collection logger.debug( "Start writing to the %s collection of the %s database (chunk size: %d)", collection, @@ -310,166 +269,157 @@ def write( # pylint: disable=too-many-arguments disable=too-many-branches chunk_size, ) + count = 0 data = iter(data) try: first_record = next(data) - data = chain([first_record], data) - if isinstance(first_record, bytes): - data = self._parse_bytes_to_dict(data, ignore_errors) except StopIteration: logger.info("Data Iterator is empty; skipping write to target.") - return 0 + return count + data = chain([first_record], data) + if isinstance(first_record, bytes): + data = parse_bytes_to_dict(data, ignore_errors, logger) - success = 0 - batch = [] if operation_type == BaseOperationType.UPDATE: - for document in data: - document_id = document.get("id") - batch.append( - ReplaceOne( - {"_source.id": {"$eq": document_id}}, - {"_source": document}, - ) - ) - if len(batch) >= chunk_size: - success += self.bulk_update(batch, collection=collection) - batch = [] - - if len(batch) > 0: - success += self.bulk_update(batch, collection=collection) - - logger.debug("Updated %d documents chunk with success", success) + for batch in self.iter_by_batch(self.to_replace_one(data), chunk_size): + count += self._bulk_update(batch, ignore_errors, collection) + logger.info("Updated %d documents with success", count) elif operation_type == BaseOperationType.DELETE: - for document in data: - document_id = document.get("id") - batch.append(document_id) - if len(batch) >= chunk_size: - success += self.bulk_delete(batch, collection=collection) - batch = [] - - if len(batch) > 0: - success += self.bulk_delete(batch, collection=collection) - - logger.debug("Deleted %d documents chunk with success", success) - elif operation_type in [BaseOperationType.INDEX, BaseOperationType.CREATE]: - for document in self.to_documents( - data, ignore_errors=ignore_errors, operation_type=operation_type - ): - batch.append(document) - if len(batch) >= chunk_size: - success += self.bulk_import( - batch, ignore_errors=ignore_errors, collection=collection - ) - batch = [] + 
for batch in self.iter_by_batch(self.to_ids(data), chunk_size): + count += self._bulk_delete(batch, ignore_errors, collection) + logger.info("Deleted %d documents with success", count) + else: + data = self.to_documents(data, ignore_errors, operation_type) + for batch in self.iter_by_batch(data, chunk_size): + count += self._bulk_import(batch, ignore_errors, collection) + logger.info("Inserted %d documents with success", count) - # Edge case: if the total number of documents is lower than the chunk size - if len(batch) > 0: - success += self.bulk_import( - batch, ignore_errors=ignore_errors, collection=collection - ) + return count - logger.debug("Inserted %d documents with success", success) - else: - msg = "%s operation_type is not allowed." - logger.error(msg, operation_type.name) - raise BackendParameterException(msg % operation_type.name) - return success + @staticmethod + def iter_by_batch(data: Iterable[dict], chunk_size: int): + """Iterate over `data` Iterable and yield batches of size `chunk_size`.""" + batch = [] + for document in data: + batch.append(document) + if len(batch) >= chunk_size: + yield batch + batch = [] + if batch: + yield batch - def _find(self, target: Union[None, str] = None, **kwargs): - """Wraps the MongoClient.collection.find method. + @staticmethod + def to_ids(data: Iterable[dict]) -> Iterable[str]: + """Convert `data` statements to ids.""" + for statement in data: + yield statement.get("id") - Raises: - BackendException: raised for any failure. - """ - try: - collection = self.get_collection(target) - return list(collection.find(**kwargs)) - except (PyMongoError, IndexError, TypeError, ValueError) as error: - msg = "Failed to execute MongoDB query" - logger.error("%s. %s", msg, error) - raise BackendException(msg, *error.args) from error + @staticmethod + def to_replace_one(data: Iterable[dict]) -> Iterable[ReplaceOne]: + """Convert `data` statements to Mongo `ReplaceOne` objects.""" + for statement in data: + yield ReplaceOne( + {"_source.id": {"$eq": statement.get("id")}}, + {"_source": statement}, + ) @staticmethod - def _parse_bytes_to_dict( - raw_documents: Iterable[bytes], ignore_errors: bool - ) -> Iterator[dict]: - """Reads the `raw_documents` Iterable and yields dictionaries.""" - for raw_document in raw_documents: + def to_documents( + data: Iterable[dict], ignore_errors: bool, operation_type: BaseOperationType + ) -> Generator[dict, None, None]: + """Convert `data` statements to MongoDB documents. + + We expect statements to have at least an `id` and a `timestamp` field that will + be used to compute a unique MongoDB Object ID. This ensures that we will not + duplicate statements in our database and allows us to support pagination. 
+ """ + for statement in data: + if "id" not in statement and operation_type == BaseOperationType.INDEX: + msg = "statement %s has no 'id' field" + if ignore_errors: + logger.warning("statement %s has no 'id' field", statement) + continue + logger.error(msg, statement) + raise BackendException(msg % statement) + if "timestamp" not in statement: + msg = "statement %s has no 'timestamp' field" + if ignore_errors: + logger.warning(msg, statement) + continue + logger.error(msg, statement) + raise BackendException(msg % statement) try: - decoded_item = raw_document.decode("utf-8") - json_data = json.loads(decoded_item) - yield json_data - except (TypeError, json.JSONDecodeError) as err: - logger.error("Raised error: %s, for document %s", err, raw_document) + timestamp = int(isoparse(statement["timestamp"]).timestamp()) + except ValueError as err: + msg = "statement %s has an invalid 'timestamp' field" if ignore_errors: + logger.warning(msg, statement) continue - raise err + logger.error(msg, statement) + raise BackendException(msg % statement) from err + document = { + "_id": ObjectId( + # This might become a problem in February 2106. + # Meanwhile, we use the timestamp in the _id field for pagination. + struct.pack(">I", timestamp) + + bytes.fromhex( + hashlib.sha256( + bytes(statement.get("id", str(uuid4())), "utf-8") + ).hexdigest()[:16] + ) + ), + "_source": statement, + } - def _read_raw(self, document: Dict[str, Any]) -> bytes: - """Reads the `documents` Iterable and yields bytes.""" - return json.dumps(document).encode(self.locale_encoding) + yield document @staticmethod - def _read_dict(document: Dict[str, Any]) -> dict: - """Reads the `documents` Iterable and yields dictionaries.""" - return document - - -class MongoLRSBackend(BaseLRSBackend, MongoDataBackend): - """MongoDB LRS backend implementation.""" - - def query_statements(self, params: StatementParameters) -> StatementQueryResult: - """Returns the results of a statements query using xAPI parameters.""" - mongo_query_filters = {} - - if params.statementId: - mongo_query_filters.update({"_source.id": params.statementId}) - - if params.agent: - mongo_query_filters.update({"_source.actor.account.name": params.agent}) - - if params.verb: - mongo_query_filters.update({"_source.verb.id": params.verb}) - - if params.activity: - mongo_query_filters.update( - { - "_source.object.objectType": "Activity", - "_source.object.id": params.activity, - }, - ) - - if params.since: - mongo_query_filters.update({"_source.timestamp": {"$gt": params.since}}) + def _bulk_import(batch: list, ignore_errors: bool, collection: Collection): + """Insert a `batch` of documents into the MongoDB `collection`.""" + try: + new_documents = collection.insert_many(batch) + except (BulkWriteError, PyMongoError, BSONError, ValueError) as error: + msg = "Failed to insert document chunk: %s" + if ignore_errors: + logger.warning(msg, error) + return getattr(error, "details", {}).get("nInserted", 0) + logger.error(msg, error) + raise BackendException(msg % error) from error - if params.until: - mongo_query_filters.update({"_source.timestamp": {"$lte": params.until}}) + inserted_count = len(new_documents.inserted_ids) + logger.debug("Inserted %d documents chunk with success", inserted_count) + return inserted_count - if params.search_after: - search_order = "$gt" if params.ascending else "$lt" - mongo_query_filters.update( - {"_id": {search_order: ObjectId(params.search_after)}} - ) + @staticmethod + def _bulk_delete(batch: list, ignore_errors: bool, collection: 
Collection): + """Delete a `batch` of documents from the MongoDB `collection`.""" + try: + new_documents = collection.delete_many({"_source.id": {"$in": batch}}) + except (BulkWriteError, PyMongoError, BSONError, ValueError) as error: + msg = "Failed to delete document chunk: %s" + if ignore_errors: + logger.warning(msg, error) + return getattr(error, "details", {}).get("nRemoved", 0) + logger.error(msg, error) + raise BackendException(msg % error) from error - mongo_sort_order = ASCENDING if params.ascending else DESCENDING - mongo_query_sort = [ - ("_source.timestamp", mongo_sort_order), - ("_id", mongo_sort_order), - ] + deleted_count = new_documents.deleted_count + logger.debug("Deleted %d documents chunk with success", deleted_count) + return deleted_count - mongo_response = self._find( - filter=mongo_query_filters, limit=params.limit, sort=mongo_query_sort - ) - search_after = None - if mongo_response: - search_after = mongo_response[-1]["_id"] - - return StatementQueryResult( - statements=[document["_source"] for document in mongo_response], - pit_id=None, - search_after=search_after, - ) + @staticmethod + def _bulk_update(batch: list, ignore_errors: bool, collection: Collection): + """Update a `batch` of documents into the MongoDB `collection`.""" + try: + new_documents = collection.bulk_write(batch) + except (BulkWriteError, PyMongoError, BSONError, ValueError) as error: + msg = "Failed to update document chunk: %s" + if ignore_errors: + logger.warning(msg, error) + return getattr(error, "details", {}).get("nModified", 0) + logger.error(msg, error) + raise BackendException(msg % error) from error - def query_statements_by_ids(self, ids: List[str]) -> List: - """Returns the list of matching statement IDs from the database.""" - return self._find(filter={"_source.id": {"$in": ids}}) + modified_count = new_documents.modified_count + logger.debug("Updated %d documents chunk with success", modified_count) + return modified_count diff --git a/src/ralph/backends/lrs/mongo.py b/src/ralph/backends/lrs/mongo.py new file mode 100644 index 000000000..fdbe83315 --- /dev/null +++ b/src/ralph/backends/lrs/mongo.py @@ -0,0 +1,137 @@ +"""MongoDB LRS backend for Ralph.""" + +import logging +from typing import Iterator, List + +from bson.objectid import ObjectId +from pymongo import ASCENDING, DESCENDING + +from ralph.backends.data.mongo import MongoDataBackend, MongoQuery +from ralph.backends.lrs.base import ( + AgentParameters, + BaseLRSBackend, + StatementParameters, + StatementQueryResult, +) +from ralph.exceptions import BackendException, BackendParameterException + +logger = logging.getLogger(__name__) + + +class MongoLRSBackend(BaseLRSBackend, MongoDataBackend): + """MongoDB LRS backend.""" + + settings_class = MongoDataBackend.settings_class + + def query_statements(self, params: StatementParameters) -> StatementQueryResult: + """Return the results of a statements query using xAPI parameters.""" + query = self.get_query(params) + try: + mongo_response = list(self.read(query=query, chunk_size=params.limit)) + except (BackendException, BackendParameterException) as error: + logger.error("Failed to read from MongoDB") + raise error + + search_after = None + if mongo_response: + search_after = mongo_response[-1]["_id"] + + return StatementQueryResult( + statements=[document["_source"] for document in mongo_response], + pit_id=None, + search_after=search_after, + ) + + def query_statements_by_ids(self, ids: List[str]) -> Iterator[dict]: + """Yield statements with matching ids from the backend.""" 
+ try: + mongo_response = self.read(query={"filter": {"_source.id": {"$in": ids}}}) + yield from (document["_source"] for document in mongo_response) + except (BackendException, BackendParameterException) as error: + logger.error("Failed to read from MongoDB") + raise error + + @staticmethod + def get_query(params: StatementParameters) -> MongoQuery: + """Construct query from statement parameters.""" + mongo_query_filters = {} + + if params.statementId: + mongo_query_filters.update({"_source.id": params.statementId}) + + MongoLRSBackend._add_agent_filters(mongo_query_filters, params.agent, "actor") + MongoLRSBackend._add_agent_filters( + mongo_query_filters, params.authority, "authority" + ) + + if params.verb: + mongo_query_filters.update({"_source.verb.id": params.verb}) + + if params.activity: + mongo_query_filters.update( + { + "_source.object.objectType": "Activity", + "_source.object.id": params.activity, + }, + ) + + if params.since: + mongo_query_filters.update( + {"_source.timestamp": {"$gt": params.since.isoformat()}} + ) + + if params.until: + if not params.since: + mongo_query_filters["_source.timestamp"] = {} + mongo_query_filters["_source.timestamp"].update( + {"$lte": params.until.isoformat()} + ) + + if params.search_after: + search_order = "$gt" if params.ascending else "$lt" + mongo_query_filters.update( + {"_id": {search_order: ObjectId(params.search_after)}} + ) + + mongo_sort_order = ASCENDING if params.ascending else DESCENDING + mongo_query_sort = [ + ("_source.timestamp", mongo_sort_order), + ("_id", mongo_sort_order), + ] + + # Note: `params` fields are validated thus we skip MongoQuery validation. + return MongoQuery.construct( + filter=mongo_query_filters, limit=params.limit, sort=mongo_query_sort + ) + + @staticmethod + def _add_agent_filters( + mongo_query_filters: dict, agent_params: AgentParameters, target_field: str + ): + """Add filters relative to agents to mongo_query_filters. + + Args: + mongo_query_filters (dict): Filters passed to MongoDB query. + agent_params (AgentParameters): Agent query parameters to search for. + target_field (str): The target agent field name to perform the search. 
+ """ + if not agent_params: + return + + if agent_params.mbox: + key = f"_source.{target_field}.mbox" + mongo_query_filters.update({key: agent_params.mbox}) + + if agent_params.mbox_sha1sum: + key = f"_source.{target_field}.mbox_sha1sum" + mongo_query_filters.update({key: agent_params.mbox_sha1sum}) + + if agent_params.openid: + key = f"_source.{target_field}.openid" + mongo_query_filters.update({key: agent_params.openid}) + + if agent_params.account__name: + key = f"_source.{target_field}.account.name" + mongo_query_filters.update({key: agent_params.account__name}) + key = f"_source.{target_field}.account.homePage" + mongo_query_filters.update({key: agent_params.account__home_page}) diff --git a/src/ralph/utils.py b/src/ralph/utils.py index ff6bb5b23..090085534 100644 --- a/src/ralph/utils.py +++ b/src/ralph/utils.py @@ -156,13 +156,27 @@ def parse_bytes_to_dict( yield json.loads(raw_document) except (TypeError, json.JSONDecodeError) as error: msg = "Failed to decode JSON: %s, for document: %s" - logger_class.error(msg, error, raw_document) if ignore_errors: + logger_class.warning(msg, error, raw_document) continue + logger_class.error(msg, error, raw_document) raise BackendException(msg % (error, raw_document)) from error -def read_raw(documents: Iterable[Dict[str, Any]], encoding: str) -> Iterator[bytes]: +def read_raw( + documents: Iterable[Dict[str, Any]], + encoding: str, + ignore_errors: bool, + logger_class: logging.Logger, +) -> Iterator[bytes]: """Read the `documents` Iterable with the `encoding` and yield bytes.""" for document in documents: - yield json.dumps(document).encode(encoding) + try: + yield json.dumps(document).encode(encoding) + except (TypeError, ValueError) as error: + msg = "Failed to convert document to bytes: %s" + if ignore_errors: + logger_class.warning(msg, error) + continue + logger_class.error(msg, error) + raise BackendException(msg % error) from error diff --git a/tests/backends/data/test_fs.py b/tests/backends/data/test_fs.py index 8845710f6..51779a34f 100644 --- a/tests/backends/data/test_fs.py +++ b/tests/backends/data/test_fs.py @@ -15,7 +15,7 @@ def test_backends_data_fs_data_backend_default_instantiation(monkeypatch, fs): - """Tests the `FSDataBackend` default instantiation.""" + """Test the `FSDataBackend` default instantiation.""" # pylint: disable=invalid-name fs.create_file(".env") backend_settings_names = [ @@ -39,7 +39,7 @@ def test_backends_data_fs_data_backend_default_instantiation(monkeypatch, fs): def test_backends_data_fs_data_backend_instantiation_with_settings(fs): - """Tests the `FSDataBackend` instantiation with settings.""" + """Test the `FSDataBackend` instantiation with settings.""" # pylint: disable=invalid-name,unused-argument deep_path = "deep/directories/path" assert not os.path.exists(deep_path) @@ -60,7 +60,7 @@ def test_backends_data_fs_data_backend_instantiation_with_settings(fs): try: FSDataBackend(settings) except Exception as err: # pylint:disable=broad-except - pytest.fail(f"FSDataBackend should not raise exceptions: {err}") + pytest.fail(f"Two FSDataBackends should not raise exceptions: {err}") @pytest.mark.parametrize( @@ -70,7 +70,7 @@ def test_backends_data_fs_data_backend_instantiation_with_settings(fs): def test_backends_data_fs_data_backend_status_method_with_error_status( mode, fs_backend, caplog ): - """Tests the `FSDataBackend.status` method, given a directory with wrong + """Test the `FSDataBackend.status` method, given a directory with wrong permissions, should return `DataBackendStatus.ERROR`. 
""" os.mkdir("directory", mode) @@ -87,7 +87,7 @@ def test_backends_data_fs_data_backend_status_method_with_error_status( @pytest.mark.parametrize("mode", [0o700]) def test_backends_data_fs_data_backend_status_method_with_ok_status(mode, fs_backend): - """Tests the `FSDataBackend.status` method, given a directory with right + """Test the `FSDataBackend.status` method, given a directory with right permissions, should return `DataBackendStatus.OK`. """ os.mkdir("directory", mode) @@ -108,7 +108,7 @@ def test_backends_data_fs_data_backend_status_method_with_ok_status(mode, fs_bac def test_backends_data_fs_data_backend_list_method_with_invalid_target( files, target, error, fs_backend, fs ): - """Tests the `FSDataBackend.list` method given an invalid `target` argument should + """Test the `FSDataBackend.list` method given an invalid `target` argument should raise a `BackendParameterException`. """ # pylint: disable=invalid-name @@ -142,7 +142,7 @@ def test_backends_data_fs_data_backend_list_method_with_invalid_target( def test_backends_data_fs_data_backend_list_method_without_history( files, target, expected, fs_backend, fs ): - """Tests the `FSDataBackend.list` method without history.""" + """Test the `FSDataBackend.list` method without history.""" # pylint: disable=invalid-name for file in files: fs.create_file(file) @@ -175,7 +175,7 @@ def test_backends_data_fs_data_backend_list_method_without_history( def test_backends_data_fs_data_backend_list_method_with_details( files, target, expected, fs_backend, fs ): - """Tests the `FSDataBackend.list` method with `details` set to `True`.""" + """Test the `FSDataBackend.list` method with `details` set to `True`.""" # pylint: disable=invalid-name,too-many-arguments for file in files: fs.create_file(file) @@ -191,7 +191,7 @@ def test_backends_data_fs_data_backend_list_method_with_details( def test_backends_data_fs_data_backend_list_method_with_history(fs_backend, fs): - """Tests the `FSDataBackend.list` method with history.""" + """Test the `FSDataBackend.list` method with history.""" # pylint: disable=invalid-name # Create 3 files in the default directory. @@ -269,7 +269,7 @@ def test_backends_data_fs_data_backend_list_method_with_history(fs_backend, fs): def test_backends_data_fs_data_backend_list_method_with_history_and_details( fs_backend, fs ): - """Tests the `FSDataBackend.list` method with an history and detailed output.""" + """Test the `FSDataBackend.list` method with an history and detailed output.""" # pylint: disable=invalid-name # Create 3 files in the default directory. @@ -361,7 +361,7 @@ def test_backends_data_fs_data_backend_list_method_with_history_and_details( def test_backends_data_fs_data_backend_read_method_with_raw_ouput( fs_backend, fs, monkeypatch ): - """Tests the `FSDataBackend.read` method with `raw_output` set to `True`.""" + """Test the `FSDataBackend.read` method with `raw_output` set to `True`.""" # pylint: disable=invalid-name # Create files in absolute path directory. @@ -476,7 +476,7 @@ def test_backends_data_fs_data_backend_read_method_with_raw_ouput( def test_backends_data_fs_data_backend_read_method_without_raw_output( fs_backend, fs, monkeypatch ): - """Tests the `FSDataBackend.read` method with `raw_output` set to `False`.""" + """Test the `FSDataBackend.read` method with `raw_output` set to `False`.""" # pylint: disable=invalid-name # File contents. 
@@ -558,7 +558,7 @@ def test_backends_data_fs_data_backend_read_method_without_raw_output( def test_backends_data_fs_data_backend_read_method_with_ignore_errors(fs_backend, fs): - """Tests the `FSDataBackend.read` method with `ignore_errors` set to `True`, given + """Test the `FSDataBackend.read` method with `ignore_errors` set to `True`, given a file containing invalid JSON lines, should skip the invalid lines. """ # pylint: disable=invalid-name @@ -603,7 +603,7 @@ def test_backends_data_fs_data_backend_read_method_with_ignore_errors(fs_backend def test_backends_data_fs_data_backend_read_method_without_ignore_errors( fs_backend, fs, monkeypatch ): - """Tests the `FSDataBackend.read` method with `ignore_errors` set to `False`, given + """Test the `FSDataBackend.read` method with `ignore_errors` set to `False`, given a file containing invalid JSON lines, should raise a `BackendException`. """ # pylint: disable=invalid-name @@ -685,7 +685,7 @@ def test_backends_data_fs_data_backend_read_method_without_ignore_errors( def test_backends_data_fs_data_backend_read_method_with_query(fs_backend, fs): - """Tests the `FSDataBackend.read` method, given a query argument.""" + """Test the `FSDataBackend.read` method, given a query argument.""" # pylint: disable=invalid-name # File contents. @@ -742,7 +742,7 @@ def test_backends_data_fs_data_backend_read_method_with_query(fs_backend, fs): def test_backends_data_fs_data_backend_write_method_with_file_exists_error( operation_type, fs_backend, fs ): - """Tests the `FSDataBackend.write` method, given a target matching an + """Test the `FSDataBackend.write` method, given a target matching an existing file and a `CREATE` or `INDEX` `operation_type`, should raise a `BackendException`. """ @@ -767,7 +767,7 @@ def test_backends_data_fs_data_backend_write_method_with_file_exists_error( def test_backends_data_fs_data_backend_write_method_with_delete_operation( fs_backend, ): - """Tests the `FSDataBackend.write` method, given a `DELETE` `operation_type`, should + """Test the `FSDataBackend.write` method, given a `DELETE` `operation_type`, should raise a `BackendParameterException`. """ # pylint: disable=invalid-name @@ -784,7 +784,7 @@ def test_backends_data_fs_data_backend_write_method_with_delete_operation( def test_backends_data_fs_data_backend_write_method_with_update_operation( fs_backend, fs, monkeypatch ): - """Tests the `FSDataBackend.write` method, given an `UPDATE` `operation_type`, + """Test the `FSDataBackend.write` method, given an `UPDATE` `operation_type`, should overwrite the target file content with the provided data. """ # pylint: disable=invalid-name @@ -892,7 +892,7 @@ def test_backends_data_fs_data_backend_write_method_with_update_operation( def test_backends_data_fs_data_backend_write_method_with_append_operation( data, expected, fs_backend, fs, monkeypatch ): - """Tests the `FSDataBackend.write` method, given an `APPEND` `operation_type`, + """Test the `FSDataBackend.write` method, given an `APPEND` `operation_type`, should append the provided data to the end of the target file. """ # pylint: disable=invalid-name @@ -942,10 +942,20 @@ def test_backends_data_fs_data_backend_write_method_with_append_operation( ] +def test_backends_data_fs_data_backend_write_method_with_no_data(fs_backend, caplog): + """Test the `FSDataBackend.write` method, given no data, should return 0.""" + backend = fs_backend() + with caplog.at_level(logging.INFO): + assert backend.write(data=[]) == 0 + + msg = "Data Iterator is empty; skipping write to target." 
+ assert ("ralph.backends.data.fs", logging.INFO, msg) in caplog.record_tuples + + def test_backends_data_fs_data_backend_write_method_without_target( fs_backend, monkeypatch ): - """Tests the `FSDataBackend.write` method, given no `target` argument, + """Test the `FSDataBackend.write` method, given no `target` argument, should create a new random file and write the provided data into it. """ # pylint: disable=invalid-name diff --git a/tests/backends/data/test_mongo.py b/tests/backends/data/test_mongo.py index 5fb7425c8..25d5d6049 100644 --- a/tests/backends/data/test_mongo.py +++ b/tests/backends/data/test_mongo.py @@ -1,1116 +1,872 @@ -# pylint: disable=too-many-lines -"""Tests for Ralph mongo data backend.""" +"""Tests for Ralph MongoDB data backend.""" import json import logging -from datetime import datetime import pytest from bson.objectid import ObjectId from pymongo import MongoClient -from pymongo.errors import PyMongoError +from pymongo.errors import ConnectionFailure, PyMongoError from ralph.backends.data.base import BaseOperationType, DataBackendStatus -from ralph.backends.data.mongo import MongoDataBackend, MongoLRSBackend, MongoQuery -from ralph.backends.lrs.base import StatementParameters -from ralph.exceptions import ( - BackendException, - BackendParameterException, - BadFormatException, +from ralph.backends.data.mongo import ( + MongoClientOptions, + MongoDataBackend, + MongoDataBackendSettings, + MongoQuery, ) +from ralph.exceptions import BackendException, BackendParameterException from tests.fixtures.backends import ( MONGO_TEST_COLLECTION, MONGO_TEST_CONNECTION_URI, MONGO_TEST_DATABASE, - MONGO_TEST_FORWARDING_COLLECTION, ) -def test_backends_data_mongo_data_backend_instantiation_with_settings(): - """Test the Mongo backend instantiation.""" - assert MongoDataBackend.name == "mongo" - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend = MongoDataBackend(settings) - - assert isinstance(backend.client, MongoClient) - assert hasattr(backend.client, MONGO_TEST_DATABASE) - database = getattr(backend.client, MONGO_TEST_DATABASE) - assert hasattr(database, MONGO_TEST_COLLECTION) - - -def test_backends_data_mongo_data_backend_read_method_without_raw_output(mongo): - """Test the mongo backend get method.""" - # Create records - timestamp = {"timestamp": "2022-06-27T15:36:50"} - documents = MongoDataBackend.to_documents( - [ - {"id": "foo", **timestamp}, - {"id": "bar", **timestamp}, - ] - ) - database = getattr(mongo, MONGO_TEST_DATABASE) - collection = getattr(database, MONGO_TEST_COLLECTION) - collection.insert_many(documents) - - # Get backend - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend = MongoDataBackend(settings) - expected = [ - {"_id": "62b9ce922c26b46b68ffc68f", "_source": {"id": "foo", **timestamp}}, - {"_id": "62b9ce92fcde2b2edba56bf4", "_source": {"id": "bar", **timestamp}}, +def test_backends_data_mongo_data_backend_default_instantiation(monkeypatch, fs): + """Test the `MongoDataBackend` default instantiation.""" + # pylint: disable=invalid-name + fs.create_file(".env") + backend_settings_names = [ + "CONNECTION_URI", + "DEFAULT_DATABASE", + "DEFAULT_COLLECTION", + "CLIENT_OPTIONS", + "DEFAULT_CHUNK_SIZE", + "LOCALE_ENCODING", ] - assert list(backend.read()) == expected - assert 
list(backend.read(chunk_size=2)) == expected - assert list(backend.read(chunk_size=1000)) == expected - - -def test_backends_data_mongo_data_backend_read_method_with_query_string(mongo): - """Test the mongo backend get method with query string.""" - # Create records - timestamp = {"timestamp": "2022-06-27T15:36:50"} - documents = MongoDataBackend.to_documents( - [ - {"id": "foo", **timestamp}, - {"id": "bar", **timestamp}, - ] - ) - database = getattr(mongo, MONGO_TEST_DATABASE) - collection = getattr(database, MONGO_TEST_COLLECTION) - collection.insert_many(documents) - - # Get backend - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend = MongoDataBackend(settings) - expected = [ - {"_id": "62b9ce922c26b46b68ffc68f", "_source": {"id": "foo", **timestamp}}, - ] - query = MongoQuery( - query_string=json.dumps({"filter": {"_source.id": {"$eq": "foo"}}}) - ) - assert list(backend.read(query=query)) == expected - assert list(backend.read(query=query, chunk_size=2)) == expected - assert list(backend.read(query=query, chunk_size=1000)) == expected + for name in backend_settings_names: + monkeypatch.delenv(f"RALPH_BACKENDS__DATA__MONGO__{name}", raising=False) + assert MongoDataBackend.name == "mongo" + assert MongoDataBackend.query_model == MongoQuery + assert MongoDataBackend.default_operation_type == BaseOperationType.INDEX + assert MongoDataBackend.settings_class == MongoDataBackendSettings + backend = MongoDataBackend() + assert isinstance(backend.client, MongoClient) + assert backend.database.name == "statements" + assert backend.collection.name == "marsha" + assert backend.settings.CONNECTION_URI == "mongodb://localhost:27017/" + assert backend.settings.CLIENT_OPTIONS == MongoClientOptions() + assert backend.settings.DEFAULT_CHUNK_SIZE == 500 + assert backend.settings.LOCALE_ENCODING == "utf8" -def test_backends_data_mongo_data_backend_list_method(mongo): - """Test the mongo backend list method.""" - # Create records - timestamp = {"timestamp": "2022-06-27T15:36:50"} - documents = MongoDataBackend.to_documents( - [ - {"id": "foo", **timestamp}, - {"id": "bar", **timestamp}, - ] - ) - database = getattr(mongo, MONGO_TEST_DATABASE) - collection = getattr(database, MONGO_TEST_COLLECTION) - collection.insert_many(documents) - # Get backend +def test_backends_data_mongo_data_backend_instantiation_with_settings(): + """Test the `MongoDataBackend` instantiation with settings.""" settings = MongoDataBackend.settings_class( CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, + DEFAULT_DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION="foo", + CLIENT_OPTIONS={"tz_aware": "True"}, + DEFAULT_CHUNK_SIZE=1000, + LOCALE_ENCODING="utf8", ) backend = MongoDataBackend(settings) - assert list(backend.list(details=True))[0]["name"] == MONGO_TEST_COLLECTION - assert list(backend.list(details=False)) == [MONGO_TEST_COLLECTION] + assert backend.database.name == MONGO_TEST_DATABASE + assert backend.collection.name == "foo" + assert backend.settings.CONNECTION_URI == MONGO_TEST_CONNECTION_URI + assert backend.settings.CLIENT_OPTIONS == MongoClientOptions(tz_aware=True) + assert backend.settings.DEFAULT_CHUNK_SIZE == 1000 + assert backend.settings.LOCALE_ENCODING == "utf8" - -def test_backends_data_mongo_data_backend_list_method_with_details(mongo): - """Test the mongo backend list method.""" - # Create records - timestamp 
= {"timestamp": "2022-06-27T15:36:50"} - documents = MongoDataBackend.to_documents( - [ - {"id": "foo", **timestamp}, - {"id": "bar", **timestamp}, - ] - ) - database = getattr(mongo, MONGO_TEST_DATABASE) - collection = getattr(database, MONGO_TEST_COLLECTION) - collection.insert_many(documents) - - # Get backend - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend = MongoDataBackend(settings) - assert [elt["_id"] for elt in list(backend.read())] == [ - "62b9ce922c26b46b68ffc68f", - "62b9ce92fcde2b2edba56bf4", - ] + try: + MongoDataBackend(settings) + except Exception as err: # pylint:disable=broad-except + pytest.fail(f"Two MongoDataBackends should not raise exceptions: {err}") -def test_backends_data_mongo_data_backend_list_method_with_target(mongo): - """Test the mongo backend list method.""" - # Create records - timestamp = {"timestamp": "2022-06-27T15:36:50"} - documents = MongoDataBackend.to_documents( - [ - {"id": "foo", **timestamp}, - {"id": "bar", **timestamp}, - ] - ) - database = getattr(mongo, MONGO_TEST_DATABASE) - collection = getattr(database, MONGO_TEST_COLLECTION) - collection.insert_many(documents) +def test_backends_data_mongo_data_backend_status_with_connection_failure( + mongo_backend, monkeypatch, caplog +): + """Test the `MongoDataBackend.status` method, given a connection failure, should + return `DataBackendStatus.AWAY`. + """ - # Get backend - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend = MongoDataBackend(settings) - assert [elt["_id"] for elt in list(backend.read(target=MONGO_TEST_COLLECTION))] == [ - "62b9ce922c26b46b68ffc68f", - "62b9ce92fcde2b2edba56bf4", - ] + class MockMongoClientAdmin: + """Mock the `MongoClient.admin` property.""" + @staticmethod + def command(command: str): + """Mock the `command` method always raising a `ConnectionFailure`.""" + assert command == "ping" + raise ConnectionFailure("Connection failure") -def test_backends_database_mongo_get_method_with_raw_ouput(mongo): - """Test the mongo backend get method with raw output.""" - # Create records - timestamp = {"timestamp": "2022-06-27T15:36:50"} - documents = MongoDataBackend.to_documents( - [ - {"id": "foo", **timestamp}, - {"id": "bar", **timestamp}, - ] - ) - database = getattr(mongo, MONGO_TEST_DATABASE) - collection = getattr(database, MONGO_TEST_COLLECTION) - collection.insert_many(documents) + class MockMongoClient: + """Mock the `pymongo.MongoClient`.""" - # Get backend - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend = MongoDataBackend(settings) - expected = [ - {"_id": "62b9ce922c26b46b68ffc68f", "id": "foo", **timestamp}, - {"_id": "62b9ce92fcde2b2edba56bf4", "id": "bar", **timestamp}, - ] - results = list(backend.read(raw_output=True)) - assert len(results) == 2 - assert isinstance(results[0], bytes) - assert json.loads(results[0])["_source"]["id"] == expected[0]["id"] + admin = MockMongoClientAdmin + backend = mongo_backend() + monkeypatch.setattr(backend, "client", MockMongoClient) + with caplog.at_level(logging.ERROR): + assert backend.status() == DataBackendStatus.AWAY -def test_backends_database_mongo_get_method_with_target(mongo): - """Test the mongo backend get method with raw 
output.""" - # Create records - timestamp = {"timestamp": "2022-06-27T15:36:50"} - documents = MongoDataBackend.to_documents( - [ - {"id": "foo", **timestamp}, - {"id": "bar", **timestamp}, - ] - ) - database = getattr(mongo, MONGO_TEST_DATABASE) - collection = getattr(database, MONGO_TEST_COLLECTION) - collection.insert_many(documents) + assert ( + "ralph.backends.data.mongo", + logging.ERROR, + "Failed to connect to MongoDB: Connection failure", + ) in caplog.record_tuples - # Get backend - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend = MongoDataBackend(settings) - expected = [ - {"_id": "62b9ce922c26b46b68ffc68f", "id": "foo", **timestamp}, - {"_id": "62b9ce92fcde2b2edba56bf4", "id": "bar", **timestamp}, - ] - results = list(backend.read(raw_output=True, target=MONGO_TEST_COLLECTION)) - assert len(results) == 2 - assert isinstance(results[0], bytes) - assert json.loads(results[0])["_source"]["id"] == expected[0]["id"] +def test_backends_data_mongo_data_backend_status_with_error_status( + mongo_backend, monkeypatch, caplog +): + """Test the `MongoDataBackend.status` method, given a failed serverStatus command, + should return `DataBackendStatus.ERROR`. + """ -def test_backends_data_mongo_data_backend_read_method_with_query(mongo): - """Test the mongo backend get method with a custom query.""" - # Create records - timestamp = {"timestamp": datetime.now().isoformat()} - documents = MongoDataBackend.to_documents( - [ - {"id": "foo", "bool": 1, **timestamp}, - {"id": "bar", "bool": 0, **timestamp}, - {"id": "lol", "bool": 1, **timestamp}, - ] - ) - database = getattr(mongo, MONGO_TEST_DATABASE) - collection = getattr(database, MONGO_TEST_COLLECTION) - collection.insert_many(documents) + class MockMongoClientAdmin: + """Mock the `MongoClient.admin` property.""" + + @staticmethod + def command(command: str): + """Mock the `command` method always raising a `ConnectionFailure`.""" + if command == "ping": + return + assert command == "serverStatus" + raise PyMongoError("Server status failure") + + class MockMongoClient: + """Mock the `pymongo.MongoClient`.""" + + admin = MockMongoClientAdmin + + backend = mongo_backend() + monkeypatch.setattr(backend, "client", MockMongoClient) + with caplog.at_level(logging.ERROR): + assert backend.status() == DataBackendStatus.ERROR + + assert ( + "ralph.backends.data.mongo", + logging.ERROR, + "Failed to get MongoDB server status: Server status failure", + ) in caplog.record_tuples + + # Given a MongoDB serverStatus query returning an ok status different from 1, + # the `status` method should return `DataBackendStatus.ERROR`. + monkeypatch.setattr(MockMongoClientAdmin, "command", lambda x: {"ok": 0}) + with caplog.at_level(logging.ERROR): + assert backend.status() == DataBackendStatus.ERROR + + assert ( + "ralph.backends.data.mongo", + logging.ERROR, + "MongoDB `serverStatus` command did not return 1.0", + ) in caplog.record_tuples + + +def test_backends_data_mongo_data_backend_status_with_ok_status(mongo_backend): + """Test the `MongoDataBackend.status` method, given a successful connection and + serverStatus command, should return `DataBackendStatus.OK`. 
+ """ + backend = mongo_backend() + assert backend.status() == DataBackendStatus.OK - # Get backend - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend = MongoDataBackend(settings) - # Test filtering - query = MongoQuery(filter={"_source.bool": {"$eq": 1}}) - results = list(backend.read(query=query)) - assert len(results) == 2 - assert results[0]["_source"]["id"] == "foo" - assert results[1]["_source"]["id"] == "lol" - - # Test projection - query = MongoQuery(projection={"_source.bool": 1}) - results = list(backend.read(query=query)) - assert len(results) == 3 - assert list(results[0]["_source"].keys()) == ["bool"] - assert list(results[1]["_source"].keys()) == ["bool"] - assert list(results[2]["_source"].keys()) == ["bool"] - - # Test filtering and projection - query = MongoQuery( - filter={"_source.bool": {"$eq": 0}}, projection={"_source.id": 1} +@pytest.mark.parametrize("invalid_character", [" ", ".", "/", '"']) +def test_backends_data_mongo_data_backend_list_method_with_invalid_target( + invalid_character, mongo_backend, caplog +): + """Test the `MongoDataBackend.list` method given an invalid `target` argument, + should raise a `BackendParameterException`. + """ + backend = mongo_backend() + msg = ( + f"The target=`foo{invalid_character}bar` is not a valid database name: " + f"database names cannot contain the character '{invalid_character}'" ) - results = list(backend.read(query=query)) - assert len(results) == 1 - assert results[0]["_source"]["id"] == "bar" - assert list(results[0]["_source"].keys()) == ["id"] - - -def test_backends_database_mongo_to_documents_method(): - """Test the mongo backend to_documents method.""" - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [ - {"id": "foo", **timestamp}, - {"id": "bar", **timestamp}, - {"id": "bar", **timestamp}, - ] - documents = MongoDataBackend.to_documents(statements) - - assert next(documents) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", **timestamp}, - } - assert next(documents) == { - "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), - "_source": {"id": "bar", **timestamp}, - } - # Identical statement ID produces the same ObjectId - assert next(documents) == { - "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), - "_source": {"id": "bar", **timestamp}, - } + with caplog.at_level(logging.ERROR): + with pytest.raises(BackendParameterException, match=msg): + list(backend.list(f"foo{invalid_character}bar")) + assert ("ralph.backends.data.mongo", logging.ERROR, msg) in caplog.record_tuples -def test_backends_database_mongo_to_documents_method_when_statement_has_no_id(caplog): - """Test the mongo backend to_documents method when a statement has no id field.""" - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [{"id": "foo", **timestamp}, timestamp, {"id": "bar", **timestamp}] - documents = MongoDataBackend.to_documents(statements, ignore_errors=False) - assert next(documents) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", **timestamp}, - } - with pytest.raises( - BadFormatException, match=f"statement {timestamp} has no 'id' field" - ): - next(documents) - - documents = MongoDataBackend.to_documents(statements, ignore_errors=True) - assert next(documents) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", **timestamp}, - } - assert next(documents) == { - "_id": 
ObjectId("62b9ce92fcde2b2edba56bf4"), - "_source": {"id": "bar", **timestamp}, - } - assert len(caplog.records) == 1 - assert caplog.records[0].levelname == "WARNING" - assert caplog.records[0].message == f"statement {timestamp} has no 'id' field" - - -def test_backends_database_mongo_to_documents_method_when_statement_has_no_timestamp( - caplog, +def test_backends_data_mongo_data_backend_list_method_with_failure( + mongo_backend, monkeypatch, caplog ): - """Tests the mongo backend to_documents method when a statement has no timestamp.""" - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [{"id": "foo", **timestamp}, {"id": "bar"}, {"id": "baz", **timestamp}] - - documents = MongoDataBackend.to_documents(statements, ignore_errors=False) - assert next(documents) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", **timestamp}, - } - - with pytest.raises( - BadFormatException, match="statement {'id': 'bar'} has no 'timestamp' field" - ): - next(documents) - - documents = MongoDataBackend.to_documents(statements, ignore_errors=True) - assert next(documents) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", **timestamp}, - } - assert next(documents) == { - "_id": ObjectId("62b9ce92baa5a0964d3320fb"), - "_source": {"id": "baz", **timestamp}, - } - assert len(caplog.records) == 1 - assert caplog.records[0].levelname == "WARNING" - assert caplog.records[0].message == ( - "statement {'id': 'bar'} has no 'timestamp' field" - ) - - -def test_backends_database_mongo_to_documents_method_with_invalid_timestamp(caplog): - """Tests the mongo backend to_documents method given a statement with an invalid - timestamp. + """Test the `MongoDataBackend.list` method given a failure while retrieving MongoDB + collections, should raise a `BackendException`. 
""" - valid_timestamp = {"timestamp": "2022-06-27T15:36:50"} - invalid_timestamp = {"timestamp": "This is not a valid timestamp!"} - invalid_statement = {"id": "bar", **invalid_timestamp} - statements = [ - {"id": "foo", **valid_timestamp}, - invalid_statement, - {"id": "baz", **valid_timestamp}, - ] - documents = MongoDataBackend.to_documents(statements, ignore_errors=False) - assert next(documents) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", **valid_timestamp}, - } + def list_collections(): + """Mock the `list_collections` method always raising an exception.""" + raise PyMongoError("Connection error") - with pytest.raises( - BadFormatException, - match=f"statement {invalid_statement} has an invalid 'timestamp' field", - ): - next(documents) + backend = mongo_backend() + monkeypatch.setattr(backend.database, "list_collections", list_collections) + msg = "Failed to list MongoDB collections: Connection error" + with caplog.at_level(logging.ERROR): + with pytest.raises(BackendException, match=msg): + list(backend.list()) - documents = MongoDataBackend.to_documents(statements, ignore_errors=True) - assert next(documents) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", **valid_timestamp}, - } - assert next(documents) == { - "_id": ObjectId("62b9ce92baa5a0964d3320fb"), - "_source": {"id": "baz", **valid_timestamp}, - } - assert len(caplog.records) == 1 - assert caplog.records[0].levelname == "WARNING" - assert caplog.records[0].message == ( - f"statement {invalid_statement} has an invalid 'timestamp' field" - ) + assert ("ralph.backends.data.mongo", logging.ERROR, msg) in caplog.record_tuples -def test_backends_database_mongo_bulk_import_method(mongo): - """Test the mongo backend bulk_import method.""" +def test_backends_data_mongo_data_backend_list_method_without_history( + mongo, mongo_backend +): + """Test the `MongoDataBackend.list` method without history.""" # pylint: disable=unused-argument - - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, + backend = mongo_backend() + assert list(backend.list()) == [MONGO_TEST_COLLECTION] + assert list(backend.list(MONGO_TEST_DATABASE)) == [MONGO_TEST_COLLECTION] + assert list(backend.list(details=True))[0]["name"] == MONGO_TEST_COLLECTION + backend.database.create_collection("bar") + backend.database.create_collection("baz") + assert sorted(backend.list()) == sorted([MONGO_TEST_COLLECTION, "bar", "baz"]) + assert sorted(collection["name"] for collection in backend.list(details=True)) == ( + sorted([MONGO_TEST_COLLECTION, "bar", "baz"]) ) - backend = MongoDataBackend(settings) - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}] - backend.bulk_import(MongoDataBackend.to_documents(statements)) - - results = backend.collection.find() - assert next(results) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", **timestamp}, - } - assert next(results) == { - "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), - "_source": {"id": "bar", **timestamp}, - } + assert not list(backend.list("non_existent_database")) -def test_backends_database_mongo_bulk_delete_method(mongo): - """Test the mongo backend bulk_delete method.""" - # pylint: disable=unused-argument - - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - 
DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend = MongoDataBackend(settings) - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}] - backend.bulk_import(MongoDataBackend.to_documents(statements)) - documents = [st["id"] for st in statements] - backend.bulk_delete(batch=documents) +def test_backends_data_mongo_data_backend_list_method_with_history( + mongo_backend, caplog +): + """Test the `MongoDataBackend.list` method given `new` argument set to `True`, + should log a warning message. + """ + backend = mongo_backend() + with caplog.at_level(logging.WARNING): + assert not list(backend.list("non_existent_database", new=True)) - results = backend.collection.find() - assert next(results, None) is None + assert ( + "ralph.backends.data.mongo", + logging.WARNING, + "The `new` argument is ignored", + ) in caplog.record_tuples -def test_backends_database_mongo_bulk_update_method(mongo): - """Test the mongo backend bulk_update method.""" +def test_backends_data_mongo_data_backend_read_method_with_raw_output( + mongo, mongo_backend +): + """Test the `MongoDataBackend.read` method with `raw_output` set to `True`.""" # pylint: disable=unused-argument - - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend = MongoDataBackend(settings) - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}] - backend.bulk_import(MongoDataBackend.to_documents(statements)) - statements = [ - {"id": "foo", "text": "foo", **timestamp}, - {"id": "bar", "text": "bar", **timestamp}, + backend = mongo_backend() + documents = [ + {"_id": ObjectId("64945e53a4ee2699573e0d6f"), "id": "foo"}, + {"_id": ObjectId("64945e530468d817b1f756da"), "id": "bar"}, + {"_id": ObjectId("64945e530468d817b1f756db"), "id": "baz"}, ] - success = backend.write(data=statements, operation_type=BaseOperationType.UPDATE) - assert success == 2 - - results = backend.collection.find() - assert next(results) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", "text": "foo", **timestamp}, - } - assert next(results) == { - "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), - "_source": {"id": "bar", "text": "bar", **timestamp}, - } + expected = [ + b'{"_id": "64945e53a4ee2699573e0d6f", "id": "foo"}', + b'{"_id": "64945e530468d817b1f756da", "id": "bar"}', + b'{"_id": "64945e530468d817b1f756db", "id": "baz"}', + ] + backend.collection.insert_many(documents) + backend.database.foobar.insert_many(documents[:2]) + assert list(backend.read(raw_output=True)) == expected + assert list(backend.read(raw_output=True, target="foobar")) == expected[:2] + assert list(backend.read(raw_output=True, chunk_size=2)) == expected + assert list(backend.read(raw_output=True, chunk_size=1000)) == expected -def test_backends_database_mongo_bulk_update_method_iterable(mongo): - """Test the mongo backend bulk_update method.""" +def test_backends_data_mongo_data_backend_read_method_without_raw_output( + mongo, mongo_backend +): + """Test the `MongoDataBackend.read` method with `raw_output` set to `False`.""" # pylint: disable=unused-argument - - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend = MongoDataBackend(settings) - timestamp = {"timestamp": "2022-06-27T15:36:50"} - 
statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}] - backend.bulk_import(MongoDataBackend.to_documents(statements)) - statements = [ - {"id": "foo", "text": "foo", **timestamp}, - {"id": "bar", "text": "bar", **timestamp}, + backend = mongo_backend() + documents = [ + {"_id": ObjectId("64945e53a4ee2699573e0d6f"), "id": "foo"}, + {"_id": ObjectId("64945e530468d817b1f756da"), "id": "bar"}, + {"_id": ObjectId("64945e530468d817b1f756db"), "id": "baz"}, ] - statements = iter(statements) - success = backend.write(data=statements, operation_type=BaseOperationType.UPDATE) - assert success == 2 - results = backend.collection.find() - assert next(results) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", "text": "foo", **timestamp}, - } - assert next(results) == { - "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), - "_source": {"id": "bar", "text": "bar", **timestamp}, - } - + expected = [ + {"_id": "64945e53a4ee2699573e0d6f", "id": "foo"}, + {"_id": "64945e530468d817b1f756da", "id": "bar"}, + {"_id": "64945e530468d817b1f756db", "id": "baz"}, + ] + backend.collection.insert_many(documents) + backend.database.foobar.insert_many(documents[:2]) + assert list(backend.read()) == expected + assert list(backend.read(target="foobar")) == expected[:2] + assert list(backend.read(chunk_size=2)) == expected + assert list(backend.read(chunk_size=1000)) == expected -def test_backends_database_mongo_bulk_wrong_operation_type(mongo): - """Test the mongo backend bulk_update method.""" - # pylint: disable=unused-argument - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, +@pytest.mark.parametrize( + "invalid_target,error", + [ + (".foo", "must not start or end with '.': '.foo'"), + ("foo.", "must not start or end with '.': 'foo.'"), + ("foo$bar", "must not contain '$': 'foo$bar'"), + ("foo..bar", "cannot be empty"), + ], +) +def test_backends_data_mongo_data_backend_read_method_with_invalid_target( + invalid_target, error, mongo_backend, caplog +): + """Test the `MongoDataBackend.read` method given an invalid `target` argument, + should raise a `BackendParameterException`. + """ + backend = mongo_backend() + msg = ( + f"The target=`{invalid_target}` is not a valid collection name: " + f"collection names {error}" ) - backend = MongoDataBackend(settings) - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}] - backend.bulk_import(MongoDataBackend.to_documents(statements)) - statements = [ - {"id": "foo", "text": "foo", **timestamp}, - {"id": "bar", "text": "bar", **timestamp}, - ] + with caplog.at_level(logging.ERROR): + with pytest.raises(BackendParameterException, match=msg.replace("$", r"\$")): + list(backend.read(target=invalid_target)) - with pytest.raises( - BackendParameterException, - match=f"{BaseOperationType.APPEND.name} operation_type is not allowed.", - ): - backend.write(data=statements, operation_type=BaseOperationType.APPEND) + assert ("ralph.backends.data.mongo", logging.ERROR, msg) in caplog.record_tuples -def test_backends_database_mongo_bulk_no_data(mongo): - """Test the mongo backend bulk_update method.""" - # pylint: disable=unused-argument +def test_backends_data_mongo_data_backend_read_method_with_failure( + mongo_backend, monkeypatch, caplog +): + """Test the `MongoDataBackend.read` method given a MongoClient failure, + should raise a `BackendException`. 
+ """ - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend = MongoDataBackend(settings) + def mock_find(batch_size, query=None): + """Mock the `MongoClient.collection.find` method always raising an Exception.""" + assert batch_size == 500 + assert not query + raise PyMongoError("MongoDB internal failure") - success = backend.write(data=[], operation_type=BaseOperationType.CREATE) + backend = mongo_backend() + monkeypatch.setattr(backend.collection, "find", mock_find) + msg = "Failed to execute MongoDB query: MongoDB internal failure" + with caplog.at_level(logging.ERROR): + with pytest.raises(BackendException, match=msg): + list(backend.read()) - assert success == 0 + assert ("ralph.backends.data.mongo", logging.ERROR, msg) in caplog.record_tuples -def test_backends_database_mongo_bulk_import_method_with_duplicated_key(mongo): - """Test the mongo backend bulk_import method with a duplicated key conflict.""" +def test_backends_data_mongo_data_backend_read_method_with_ignore_errors( + mongo, mongo_backend, caplog +): + """Test the `MongoDataBackend.read` method with `ignore_errors` set to `True`, given + a collection containing unparsable documents, should skip the invalid documents. + """ # pylint: disable=unused-argument - - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend = MongoDataBackend(settings) - - # Identical statement ID produces the same ObjectId, leading to a - # duplicated key write error while trying to bulk import this batch - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [ - {"id": "foo", **timestamp}, - {"id": "bar", **timestamp}, - {"id": "bar", **timestamp}, + backend = mongo_backend() + documents = [ + {"_id": ObjectId("64945e53a4ee2699573e0d6f"), "id": "foo"}, + {"_id": ObjectId("64945e530468d817b1f756da"), "id": ObjectId()}, + {"_id": ObjectId("64945e530468d817b1f756db"), "id": "baz"}, ] - documents = list(MongoDataBackend.to_documents(statements)) - with pytest.raises(BackendException, match="E11000 duplicate key error collection"): - backend.bulk_import(documents) - - success = backend.bulk_import(documents, ignore_errors=True) - assert success == 0 - - -def test_backends_database_mongo_bulk_import_method_import_partial_chunks_on_error( - mongo, + expected = [ + b'{"_id": "64945e53a4ee2699573e0d6f", "id": "foo"}', + b'{"_id": "64945e530468d817b1f756db", "id": "baz"}', + ] + backend.collection.insert_many(documents) + backend.database.foobar.insert_many(documents[:2]) + kwargs = {"raw_output": True, "ignore_errors": True} + with caplog.at_level(logging.WARNING): + assert list(backend.read(**kwargs)) == expected + assert list(backend.read(**kwargs, target="foobar")) == expected[:1] + assert list(backend.read(**kwargs, chunk_size=2)) == expected + assert list(backend.read(**kwargs, chunk_size=1000)) == expected + + assert ( + "ralph.backends.data.mongo", + logging.WARNING, + "Failed to convert document to bytes: " + "Object of type ObjectId is not JSON serializable", + ) in caplog.record_tuples + + +def test_backends_data_mongo_data_backend_read_method_without_ignore_errors( + mongo, mongo_backend, caplog ): - """Test the mongo backend bulk_import method imports partial chunks while raising a - BulkWriteError and ignoring errors. 
+ """Test the `MongoDataBackend.read` method with `ignore_errors` set to `False`, + given a collection containing unparsable documents, should raise a + `BackendException`. """ # pylint: disable=unused-argument - - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend = MongoDataBackend(settings) - - # Identical statement ID produces the same ObjectId, leading to a - # duplicated key write error while trying to bulk import this batch - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [ - {"id": "foo", **timestamp}, - {"id": "bar", **timestamp}, - {"id": "baz", **timestamp}, - {"id": "bar", **timestamp}, - {"id": "lol", **timestamp}, + backend = mongo_backend() + documents = [ + {"_id": ObjectId("64945e53a4ee2699573e0d6f"), "id": "foo"}, + {"_id": ObjectId("64945e530468d817b1f756da"), "id": ObjectId()}, + {"_id": ObjectId("64945e530468d817b1f756db"), "id": "baz"}, ] - documents = list(MongoDataBackend.to_documents(statements)) - assert backend.bulk_import(documents, ignore_errors=True) == 3 - + expected = b'{"_id": "64945e53a4ee2699573e0d6f", "id": "foo"}' + backend.collection.insert_many(documents) + backend.database.foobar.insert_many(documents[:2]) + kwargs = {"raw_output": True, "ignore_errors": False} + msg = ( + "Failed to convert document to bytes: " + "Object of type ObjectId is not JSON serializable" + ) + with caplog.at_level(logging.ERROR): + with pytest.raises(BackendException, match=msg): + result = backend.read(**kwargs) + assert next(result) == expected + next(result) + with pytest.raises(BackendException, match=msg): + result = backend.read(**kwargs, target="foobar") + assert next(result) == expected + next(result) + with pytest.raises(BackendException, match=msg): + result = backend.read(**kwargs, chunk_size=2) + assert next(result) == expected + next(result) + with pytest.raises(BackendException, match=msg): + result = backend.read(**kwargs, chunk_size=1000) + assert next(result) == expected + next(result) + + error_log = ("ralph.backends.data.mongo", logging.ERROR, msg) + assert len(list(filter(lambda x: x == error_log, caplog.record_tuples))) == 4 + + +@pytest.mark.parametrize( + "query", + [ + '{"filter": {"id": {"$eq": "bar"}}, "projection": {"id": 1}}', + {"filter": {"id": {"$eq": "bar"}}, "projection": {"id": 1}}, + MongoQuery( + query_string='{"filter": {"id": {"$eq": "bar"}}, "projection": {"id": 1}}' + ), + # Given both `query_string` and other query arguments, only the `query_string` + # should be applied. 
+ MongoQuery( + query_string='{"filter": {"id": {"$eq": "bar"}}, "projection": {"id": 1}}', + filter={"id": {"$eq": "foo"}}, + projection={"id": 0}, + ), + MongoQuery(filter={"id": {"$eq": "bar"}}, projection={"id": 1}), + ], +) +def test_backends_data_mongo_data_backend_read_method_with_query( + query, mongo, mongo_backend +): + """Test the `MongoDataBackend.read` method given a query argument.""" + # pylint: disable=unused-argument + # Create records + backend = mongo_backend() + documents = [ + {"_id": ObjectId("64945e53a4ee2699573e0d6f"), "id": "foo", "qux": "foo"}, + {"_id": ObjectId("64945e530468d817b1f756da"), "id": "bar", "qux": "foo"}, + {"_id": ObjectId("64945e530468d817b1f756db"), "id": "bar", "qux": "foo"}, + ] + expected = [ + {"_id": "64945e530468d817b1f756da", "id": "bar"}, + {"_id": "64945e530468d817b1f756db", "id": "bar"}, + ] + backend.collection.insert_many(documents) + assert list(backend.read(query=query)) == expected + assert list(backend.read(query=query, chunk_size=1)) == expected + assert list(backend.read(query=query, chunk_size=1000)) == expected -def test_backends_database_mongo_put_method(mongo): - """Test the mongo backend put method.""" - database = getattr(mongo, MONGO_TEST_DATABASE) - collection = getattr(database, MONGO_TEST_COLLECTION) - assert collection.estimated_document_count() == 0 +def test_backends_data_mongo_data_backend_write_method_with_target( + mongo, mongo_backend +): + """Test the `MongoDataBackend.write` method, given a valid `target` argument, should + write documents to the target collection. + """ + # pylint: disable=unused-argument + backend = mongo_backend() timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}] - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend = MongoDataBackend(settings) + documents = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}] + assert backend.write(documents, target="foo_target_collection") == 2 - success = backend.write(statements) - assert success == 2 - assert collection.estimated_document_count() == 2 + # The documents should not be written to the default collection. 
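+    # (`read` without a `target` argument queries the default collection.)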
+    assert not list(backend.read())

-    results = collection.find()
+    results = backend.read(target="foo_target_collection")
     assert next(results) == {
-        "_id": ObjectId("62b9ce922c26b46b68ffc68f"),
+        "_id": "62b9ce922c26b46b68ffc68f",
         "_source": {"id": "foo", **timestamp},
     }
     assert next(results) == {
-        "_id": ObjectId("62b9ce92fcde2b2edba56bf4"),
+        "_id": "62b9ce92fcde2b2edba56bf4",
         "_source": {"id": "bar", **timestamp},
     }


-def test_backends_database_mongo_put_method_bytes(mongo):
-    """Test the mongo backend put method with bytes."""
-    database = getattr(mongo, MONGO_TEST_DATABASE)
-    collection = getattr(database, MONGO_TEST_COLLECTION)
-    assert collection.estimated_document_count() == 0
-
-    timestamp = {"timestamp": "2022-06-27T15:36:50"}
-    statements = [
-        {"id": "foo", "text": "foo", **timestamp},
-        {"id": "bar", "text": "bar", **timestamp},
-    ]
-    settings = MongoDataBackend.settings_class(
-        CONNECTION_URI=MONGO_TEST_CONNECTION_URI,
-        DATABASE=MONGO_TEST_DATABASE,
-        DEFAULT_COLLECTION=MONGO_TEST_COLLECTION,
-    )
-    backend = MongoDataBackend(settings)
-    byte_data = []
-    for item in statements:
-        json_str = json.dumps(item, separators=(",", ":"), ensure_ascii=False)
-        byte_data.append(json_str.encode("utf-8"))
-    success = backend.write(byte_data)
-    assert success == 2
-    assert collection.estimated_document_count() == 2
-
-    results = collection.find()
-    assert next(results) == {
-        "_id": ObjectId("62b9ce922c26b46b68ffc68f"),
-        "_source": {"id": "foo", "text": "foo", **timestamp},
-    }
-    assert next(results) == {
-        "_id": ObjectId("62b9ce92fcde2b2edba56bf4"),
-        "_source": {"id": "bar", "text": "bar", **timestamp},
-    }
-
-
-def test_backends_database_mongo_put_method_bytes_failed(mongo):
-    """Test the mongo backend put method with bytes."""
-    database = getattr(mongo, MONGO_TEST_DATABASE)
-    collection = getattr(database, MONGO_TEST_COLLECTION)
-    assert collection.estimated_document_count() == 0
-
-    settings = MongoDataBackend.settings_class(
-        CONNECTION_URI=MONGO_TEST_CONNECTION_URI,
-        DATABASE=MONGO_TEST_DATABASE,
-        DEFAULT_COLLECTION=MONGO_TEST_COLLECTION,
-    )
-    backend = MongoDataBackend(settings)
-    byte_data = []
-    json_str = "failed_json_str"
-    byte_data.append(json_str.encode("utf-8"))
-
-    with pytest.raises(json.JSONDecodeError):
-        success = backend.write(byte_data)
-    assert collection.estimated_document_count() == 0
-
-    success = backend.write(byte_data, ignore_errors=True)
-    assert success == 0
-    assert collection.estimated_document_count() == 0
-
-
-def test_backends_database_mongo_put_method_with_target(mongo):
-    """Test the mongo backend put method."""
-    database = getattr(mongo, MONGO_TEST_DATABASE)
-    collection = getattr(database, MONGO_TEST_COLLECTION)
-    assert collection.estimated_document_count() == 0
-
+def test_backends_data_mongo_data_backend_write_method_without_target(
+    mongo, mongo_backend
+):
+    """Test the `MongoDataBackend.write` method, given no `target` argument, should
+    write documents to the default collection.
+ """ + # pylint: disable=unused-argument + backend = mongo_backend() timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}] - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend = MongoDataBackend(settings) - - success = backend.write(statements, target=MONGO_TEST_COLLECTION) - assert success == 2 - assert collection.estimated_document_count() == 2 - - results = collection.find() + documents = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}] + assert backend.write(documents) == 2 + results = backend.read() assert next(results) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), + "_id": "62b9ce922c26b46b68ffc68f", "_source": {"id": "foo", **timestamp}, } assert next(results) == { - "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), + "_id": "62b9ce92fcde2b2edba56bf4", "_source": {"id": "bar", **timestamp}, } -def test_backends_database_mongo_put_method_with_no_ids(mongo): - """Test the mongo backend put method with no IDs.""" - database = getattr(mongo, MONGO_TEST_DATABASE) - collection = getattr(database, MONGO_TEST_COLLECTION) - assert collection.estimated_document_count() == 0 - - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [{**timestamp}, {**timestamp}] - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend = MongoDataBackend(settings) - - success = backend.write(statements, operation_type=BaseOperationType.INDEX) - assert success == 2 - assert collection.estimated_document_count() == 2 - - -def test_backends_database_mongo_put_method_with_custom_chunk_size(mongo): - """Test the mongo backend put method with a custom chunk_size.""" - database = getattr(mongo, MONGO_TEST_DATABASE) - collection = getattr(database, MONGO_TEST_COLLECTION) - assert collection.estimated_document_count() == 0 - +def test_backends_data_mongo_data_backend_write_method_with_duplicated_key_error( + mongo, mongo_backend, caplog +): + """Test the `MongoDataBackend.write` method, given documents with duplicated ids, + should write the documents until it encounters a duplicated id and then raise a + `BackendException`. + """ + # pylint: disable=unused-argument + backend = mongo_backend() + # Identical statement IDs produce the same ObjectIds, leading to a + # duplicated key write error while trying to bulk import this batch. timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}] + documents = [ + {"id": "foo", **timestamp}, + {"id": "bar", **timestamp}, + {"id": "bar", **timestamp}, + {"id": "baz", **timestamp}, + ] - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, + # Given `ignore_errors` argument set to `True`, the `write` method should not raise + # an exception. 
+ assert backend.write(documents, ignore_errors=True) == 2 + assert ( + backend.write( + documents, operation_type=BaseOperationType.CREATE, ignore_errors=True + ) + == 0 ) - backend = MongoDataBackend(settings) - success = backend.write(statements, chunk_size=2) - assert success == 2 - assert collection.estimated_document_count() == 2 - - results = collection.find() - assert next(results) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", **timestamp}, - } - assert next(results) == { - "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), - "_source": {"id": "bar", **timestamp}, - } + assert list(backend.read()) == [ + {"_id": "62b9ce922c26b46b68ffc68f", "_source": {"id": "foo", **timestamp}}, + {"_id": "62b9ce92fcde2b2edba56bf4", "_source": {"id": "bar", **timestamp}}, + ] + # Given `ignore_errors` argument set to `False`, the `write` method should raise + # a `BackendException`. + msg = "E11000 duplicate key error collection" + with caplog.at_level(logging.ERROR): + with pytest.raises(BackendException, match=msg): + backend.write(documents) + with pytest.raises(BackendException, match=msg) as exception_info: + backend.write(documents, operation_type=BaseOperationType.CREATE) + assert list(backend.read()) == [ + {"_id": "62b9ce922c26b46b68ffc68f", "_source": {"id": "foo", **timestamp}}, + {"_id": "62b9ce92fcde2b2edba56bf4", "_source": {"id": "bar", **timestamp}}, + ] -def test_backends_database_mongo_put_method_with_duplicated_key(mongo): - """Test the mongo backend put method with a duplicated key conflict.""" - # pylint: disable=unused-argument + assert ( + "ralph.backends.data.mongo", + logging.ERROR, + exception_info.value.args[0], + ) in caplog.record_tuples - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend = MongoDataBackend(settings) - # Identical statement ID produces the same ObjectId, leading to a - # duplicated key write error while trying to bulk import this batch +def test_backends_data_mongo_data_backend_write_method_with_delete_operation( + mongo, mongo_backend +): + """Test the `MongoDataBackend.write` method, given a `DELETE` `operation_type`, + should delete the provided documents from the MongoDB collection. + """ + # pylint: disable=unused-argument + backend = mongo_backend() timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [ + documents = [ {"id": "foo", **timestamp}, {"id": "bar", **timestamp}, - {"id": "bar", **timestamp}, + {"id": "baz", **timestamp}, + ] + assert backend.write(documents) == 3 + assert len(list(backend.read())) == 3 + assert backend.write(documents[:2], operation_type=BaseOperationType.DELETE) == 2 + assert list(backend.read()) == [ + {"_id": "62b9ce92baa5a0964d3320fb", "_source": documents[2]} ] - with pytest.raises(BackendException, match="E11000 duplicate key error collection"): - backend.write(statements) - success = backend.write(statements, ignore_errors=True) - assert success == 0 + # Given binary data, the `write` method should have the same behaviour. 
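+    # Deleting the last remaining document through its JSON-encoded bytes form
+    # should leave the collection empty.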
+    binary_documents = [json.dumps(documents[2]).encode("utf8")]
+    assert backend.write(binary_documents, operation_type=BaseOperationType.DELETE) == 1
+    assert not list(backend.read())


-def test_backends_data_mongo_data_backend_write_method_with_update_operation(
-    mongo,
+def test_backends_data_mongo_data_backend_write_method_with_delete_operation_failure(
+    mongo, mongo_backend, caplog
 ):
-    """Test the mongo backend write method with a update operation."""
-    database = getattr(mongo, MONGO_TEST_DATABASE)
-    collection = getattr(database, MONGO_TEST_COLLECTION)
-    assert collection.estimated_document_count() == 0
+    """Test the `MongoDataBackend.write` method with the `DELETE` `operation_type`,
+    given a MongoClient failure, should raise a `BackendException`.
+    """
+    # pylint: disable=unused-argument
+    backend = mongo_backend()
+    msg = (
+        "Failed to delete document chunk: cannot encode object: <class 'object'>, "
+        "of type: <class 'type'>"
+    )
+    with caplog.at_level(logging.ERROR):
+        with pytest.raises(BackendException, match=msg):
+            backend.write([{"id": object}], operation_type=BaseOperationType.DELETE)
+
+    assert ("ralph.backends.data.mongo", logging.ERROR, msg) in caplog.record_tuples
+
+    # Given `ignore_errors` argument set to `True`, the `write` method should not raise
+    # an exception.
+    with caplog.at_level(logging.WARNING):
+        assert (
+            backend.write(
+                [{"id": object}],
+                operation_type=BaseOperationType.DELETE,
+                ignore_errors=True,
+            )
+            == 0
+        )
+
+    assert ("ralph.backends.data.mongo", logging.WARNING, msg) in caplog.record_tuples
+
+
+def test_backends_data_mongo_data_backend_write_method_with_update_operation(
+    mongo, mongo_backend
+):
+    """Test the `MongoDataBackend.write` method, given an `UPDATE` `operation_type`,
+    should update the provided documents in the MongoDB collection.
+    """
+    # pylint: disable=unused-argument
+    backend = mongo_backend()
     timestamp = {"timestamp": "2022-06-27T15:36:50"}
-    statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}]
-    settings = MongoDataBackend.settings_class(
-        CONNECTION_URI=MONGO_TEST_CONNECTION_URI,
-        DATABASE=MONGO_TEST_DATABASE,
-        DEFAULT_COLLECTION=MONGO_TEST_COLLECTION,
-    )
-    backend = MongoDataBackend(settings)
+    documents = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}]

-    success = backend.write(statements)
-    assert success == 2
-    assert collection.estimated_document_count() == 2
+    assert backend.write(documents) == 2
+    new_timestamp = {"timestamp": "2022-06-27T16:36:50"}
+    documents = [{"id": "foo", **new_timestamp}, {"id": "bar", **new_timestamp}]
+    assert backend.write(documents, operation_type=BaseOperationType.UPDATE) == 2

-    results = collection.find()
+    results = backend.read()
     assert next(results) == {
-        "_id": ObjectId("62b9ce922c26b46b68ffc68f"),
-        "_source": {"id": "foo", **timestamp},
+        "_id": "62b9ce922c26b46b68ffc68f",
+        "_source": {"id": "foo", **new_timestamp},
     }
     assert next(results) == {
-        "_id": ObjectId("62b9ce92fcde2b2edba56bf4"),
-        "_source": {"id": "bar", **timestamp},
+        "_id": "62b9ce92fcde2b2edba56bf4",
+        "_source": {"id": "bar", **new_timestamp},
     }

-    timestamp = {"timestamp": "2022-06-27T16:36:50"}
-    statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}]
-    success = backend.write(
-        statements, chunk_size=2, operation_type=BaseOperationType.UPDATE
-    )
-    assert success == 2
-    assert collection.estimated_document_count() == 2
-
-    results = collection.find()
+    # Given binary data, the `write` method should have the same behaviour.
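+    # Note that an update replaces the whole `_source` document: the previous
+    # `timestamp` field should be dropped in favour of `new_field`.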
+ binary_documents = [json.dumps({"id": "foo", "new_field": "bar"}).encode("utf8")] + assert backend.write(binary_documents, operation_type=BaseOperationType.UPDATE) == 1 + results = backend.read() assert next(results) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", **timestamp}, - } - assert next(results) == { - "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), - "_source": {"id": "bar", **timestamp}, + "_id": "62b9ce922c26b46b68ffc68f", + "_source": {"id": "foo", "new_field": "bar"}, } -def test_backends_data_mongo_data_backend_write_method_with_delete_operation( - mongo, +def test_backends_data_mongo_data_backend_write_method_with_update_operation_failure( + mongo, mongo_backend, caplog ): - """Test the mongo backend write method with a delete operation.""" - database = getattr(mongo, MONGO_TEST_DATABASE) - collection = getattr(database, MONGO_TEST_COLLECTION) - assert collection.estimated_document_count() == 0 - - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [ - {"id": "foo", **timestamp}, - {"id": "bar", **timestamp}, - {"id": "baz", **timestamp}, - ] - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, + """Test the `MongoDataBackend.write` method with the `UPDATE` `operation_type`, + given a MongoClient failure, should raise a `BackendException`. + """ + # pylint: disable=unused-argument + backend = mongo_backend() + schema = { + "$jsonSchema": { + "bsonType": "object", + "required": ["_source"], + "properties": { + "_source": { + "bsonType": "object", + "required": ["timestamp"], + "description": "must be an object", + "properties": { + "timestamp": { + "bsonType": "string", + "description": "must be a string and is required", + } + }, + } + }, + } + } + backend.database.command( + "collMod", backend.collection.name, validator=schema, validationLevel="moderate" ) - backend = MongoDataBackend(settings) - - success = backend.write(statements, chunk_size=2) - assert success == 3 - assert collection.estimated_document_count() == 3 - - results = collection.find() - assert next(results) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", **timestamp}, - } - assert next(results) == { - "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), - "_source": {"id": "bar", **timestamp}, - } - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [ - {"id": "foo", **timestamp}, - {"id": "bar", **timestamp}, - {"id": "baz", **timestamp}, - ] - success = backend.write( - statements, chunk_size=2, operation_type=BaseOperationType.DELETE + documents = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}] + assert backend.write(documents) == 2 + documents = [{"id": "foo", "new": "field", **timestamp}, {"id": "bar"}] + + # Given `ignore_errors` argument set to `True`, the `write` method should not raise + # an exception. 
+ assert ( + backend.write( + documents, operation_type=BaseOperationType.UPDATE, ignore_errors=True + ) + == 1 ) - assert success == 3 + assert next(backend.read())["_source"]["new"] == "field" - assert not list(backend.read()) + msg = "Failed to update document chunk: batch op errors occurred" + with caplog.at_level(logging.ERROR): + with pytest.raises(BackendException, match=msg) as exception_info: + backend.write(documents, operation_type=BaseOperationType.UPDATE) - assert collection.estimated_document_count() == 0 + assert ( + "ralph.backends.data.mongo", + logging.ERROR, + exception_info.value.args[0], + ) in caplog.record_tuples -def test_backends_database_mongo_query_statements(monkeypatch, caplog, mongo): - """Tests the mongo backend query_statements method, given a search query failure, - should raise a BackendException and log the error. +def test_backends_data_mongo_data_backend_write_method_with_append_operation( + mongo_backend, caplog +): + """Test the `MongoDataBackend.write` method, given an `APPEND` `operation_type`, + should raise a `BackendParameterException`. """ - # pylint: disable=unused-argument,use-implicit-booleaness-not-comparison + backend = mongo_backend() + msg = "Append operation_type is not allowed." + with caplog.at_level(logging.ERROR): + with pytest.raises(BackendParameterException, match=msg): + backend.write(data=[], operation_type=BaseOperationType.APPEND) - # Instantiate Mongo Databases + assert ("ralph.backends.data.mongo", logging.ERROR, msg) in caplog.record_tuples - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend = MongoLRSBackend(settings) - # Insert documents - timestamp = {"timestamp": "2022-06-27T15:36:50"} - meta = { - "actor": {"account": {"name": "test_name"}}, - "verb": {"id": "verb_id"}, - "object": {"id": "http://example.com", "objectType": "Activity"}, - } - collection_document = list( - MongoDataBackend.to_documents( - [ - {"id": "62b9ce922c26b46b68ffc68f", **timestamp, **meta}, - {"id": "62b9ce92fcde2b2edba56bf4", **timestamp, **meta}, - ] - ) - ) - backend.bulk_import(collection_document) - - statement_parameters = StatementParameters() - statement_parameters.activity = "http://example.com" - statement_parameters.registration = ObjectId("62b9ce922c26b46b68ffc68f") - statement_parameters.since = "2020-01-01T00:00:00.000000+00:00" - statement_parameters.until = "2022-12-01T15:36:50" - statement_parameters.search_after = ObjectId("62b9ce922c26b46b68ffc68f") - statement_parameters.limit = 25 - statement_parameters.ascending = True - statement_parameters.related_activities = True - statement_parameters.related_agents = True - statement_parameters.format = "ids" - statement_parameters.agent = "test_name" - statement_parameters.verb = "verb_id" - statement_parameters.attachments = False - statement_parameters.search_after = ObjectId("62b9ce922c26b46b68ffc68f") - statement_parameters.statementId = "62b9ce922c26b46b68ffc68f" - statement_query_result = backend.query_statements(statement_parameters) - - assert len(statement_query_result.statements) > 0 - - -def test_backends_database_mongo_query_statements_with_search_query_failure( - monkeypatch, caplog, mongo +def test_backends_data_mongo_data_backend_write_method_with_create_operation( + mongo, mongo_backend ): - """Tests the mongo backend query_statements method, given a search query failure, - should raise a BackendException and log the error. 
+ """Test the `MongoDataBackend.write` method, given an `CREATE` `operation_type`, + should insert the provided documents to the MongoDB collection. """ # pylint: disable=unused-argument - - def mock_find(**_): - """Mocks the MongoClient.collection.find method.""" - raise PyMongoError("Something is wrong") - - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend = MongoLRSBackend(settings) - monkeypatch.setattr(backend.collection, "find", mock_find) - - caplog.set_level(logging.ERROR) - - msg = "'Failed to execute MongoDB query', 'Something is wrong'" - with pytest.raises(BackendException, match=msg): - backend.query_statements(StatementParameters()) - - logger_name = "ralph.backends.data.mongo" - msg = "Failed to execute MongoDB query. Something is wrong" - assert caplog.record_tuples == [(logger_name, logging.ERROR, msg)] - - -def test_backends_database_mongo_query_statements_by_ids_with_search_query_failure( - monkeypatch, caplog, mongo + backend = mongo_backend() + documents = [ + {"timestamp": "2022-06-27T15:36:50"}, + {"timestamp": "2023-06-27T15:36:50"}, + ] + assert backend.write(documents, operation_type=BaseOperationType.CREATE) == 2 + results = backend.read() + assert next(results)["_source"]["timestamp"] == documents[0]["timestamp"] + assert next(results)["_source"]["timestamp"] == documents[1]["timestamp"] + + +@pytest.mark.parametrize( + "document,error", + [ + ({}, "statement {} has no 'id' field"), + ({"id": "1"}, "statement {'id': '1'} has no 'timestamp' field"), + ( + {"id": "1", "timestamp": ""}, + "statement {'id': '1', 'timestamp': ''} has an invalid 'timestamp' field", + ), + ], +) +def test_backends_data_mongo_data_backend_write_method_with_invalid_documents( + document, error, mongo, mongo_backend, caplog ): - """Tests the mongo backend query_statements_by_ids method, given a search query - failure, should raise a BackendException and log the error. + """Test the `MongoDataBackend.write` method, given invalid documents, should raise a + `BackendException`. """ # pylint: disable=unused-argument + backend = mongo_backend() + with caplog.at_level(logging.ERROR): + with pytest.raises(BackendException, match=error): + backend.write([document]) - def mock_find(**_): - """Mocks the MongoClient.collection.find method.""" - raise ValueError("Something is wrong") - - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend = MongoLRSBackend(settings) - monkeypatch.setattr(backend.collection, "find", mock_find) - caplog.set_level(logging.ERROR) + # Given binary data, the `write` method should have the same behaviour. + with caplog.at_level(logging.ERROR): + with pytest.raises(BackendException, match=error): + backend.write([json.dumps(document).encode("utf8")]) - msg = "'Failed to execute MongoDB query', 'Something is wrong'" - with pytest.raises(BackendException, match=msg): - backend.query_statements_by_ids(StatementParameters()) + # Given `ignore_errors` argument set to `True`, the `write` method should not raise + # an exception. + with caplog.at_level(logging.WARNING): + assert backend.write([document], ignore_errors=True) == 0 - logger_name = "ralph.backends.data.mongo" - msg = "Failed to execute MongoDB query. 
Something is wrong" - assert caplog.record_tuples == [(logger_name, logging.ERROR, msg)] + assert ("ralph.backends.data.mongo", logging.WARNING, error) in caplog.record_tuples -def test_backends_database_mongo_query_statements_by_ids_with_multiple_collections( - mongo, mongo_forwarding +def test_backends_data_mongo_data_backend_write_method_with_unparsable_documents( + mongo_backend, caplog ): - """Tests the mongo backend query_statements_by_ids method, given a valid search - query, should execute the query uniquely on the specified collection and return the - expected results. + """Test the `MongoDataBackend.write` method, given unparsable raw documents, should + raise a `BackendException`. """ - # pylint: disable=unused-argument,use-implicit-booleaness-not-comparison - - # Instantiate Mongo Databases - - settings_1 = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend_1 = MongoLRSBackend(settings_1) - - settings_2 = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_FORWARDING_COLLECTION, + backend = mongo_backend() + msg = ( + "Failed to decode JSON: Expecting value: line 1 column 1 (char 0), " + "for document: b'not valid JSON!'" ) - backend_2 = MongoLRSBackend(settings_2) + msg_regex = msg.replace("(", r"\(").replace(")", r"\)") + with caplog.at_level(logging.ERROR): + with pytest.raises(BackendException, match=msg_regex): + backend.write([b"not valid JSON!"]) - # Insert documents - timestamp = {"timestamp": "2022-06-27T15:36:50"} - collection_1_document = list( - MongoDataBackend.to_documents([{"id": "1", **timestamp}]) - ) - collection_2_document = list( - MongoDataBackend.to_documents([{"id": "2", **timestamp}]) - ) - backend_1.bulk_import(collection_1_document) - backend_2.bulk_import(collection_2_document) + # Given `ignore_errors` argument set to `True`, the `write` method should not raise + # an exception. + with caplog.at_level(logging.WARNING): + assert backend.write([b"not valid JSON!"], ignore_errors=True) == 0 - # Check the expected search query results - assert backend_1.query_statements_by_ids(["1"]) == collection_1_document - assert backend_1.query_statements_by_ids(["2"]) == [] - assert backend_2.query_statements_by_ids(["1"]) == [] - assert backend_2.query_statements_by_ids(["2"]) == collection_2_document + assert ("ralph.backends.data.mongo", logging.WARNING, msg) in caplog.record_tuples -def test_backends_database_mongo_status(mongo): - """Test the Mongo status method. - - As pymongo is monkeypatching the MongoDB client to add admin object, it's - barely untestable. 😢 - """ - # pylint: disable=unused-argument - - settings = MongoDataBackend.settings_class( - CONNECTION_URI=MONGO_TEST_CONNECTION_URI, - DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, - ) - backend = MongoDataBackend(settings) - assert backend.status() == DataBackendStatus.OK +def test_backends_data_mongo_data_backend_write_method_with_no_data( + mongo_backend, caplog +): + """Test the `MongoDataBackend.write` method, given no documents, should return 0.""" + backend = mongo_backend() + with caplog.at_level(logging.INFO): + assert backend.write(data=[]) == 0 + msg = "Data Iterator is empty; skipping write to target." 
+    assert ("ralph.backends.data.mongo", logging.INFO, msg) in caplog.record_tuples
 
 
-def test_backends_database_mongo_status_connection_failed(mongo):
-    """Test the Mongo status method.
 
-    As pymongo is monkeypatching the MongoDB client to add admin object, it's
-    barely untestable. 😢
+def test_backends_data_mongo_data_backend_write_method_with_custom_chunk_size(
+    mongo, mongo_backend
+):
+    """Test the `MongoDataBackend.write` method, given a custom chunk_size, should
+    insert the provided documents into the target collection in batches of size
+    `chunk_size`.
     """
     # pylint: disable=unused-argument
-
-    settings = MongoDataBackend.settings_class(
-        CONNECTION_URI="mongodb://localhost:27018",
-        DATABASE=MONGO_TEST_DATABASE,
-        DEFAULT_COLLECTION=MONGO_TEST_COLLECTION,
+    backend = mongo_backend()
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    new_timestamp = {"timestamp": "2023-06-27T15:36:50"}
+    documents = [
+        {"id": "foo", **timestamp},
+        {"id": "bar", **timestamp},
+        {"id": "baz", **timestamp},
+    ]
+    new_documents = [
+        {"id": "foo", **new_timestamp},
+        {"id": "bar", **new_timestamp},
+        {"id": "baz", **new_timestamp},
+    ]
+    # Index operation type.
+    assert backend.write(documents, chunk_size=2) == 3
+    assert list(backend.read()) == [
+        {"_id": "62b9ce922c26b46b68ffc68f", "_source": {"id": "foo", **timestamp}},
+        {"_id": "62b9ce92fcde2b2edba56bf4", "_source": {"id": "bar", **timestamp}},
+        {"_id": "62b9ce92baa5a0964d3320fb", "_source": {"id": "baz", **timestamp}},
+    ]
+    # Delete operation type.
+    assert (
+        backend.write(documents, chunk_size=1, operation_type=BaseOperationType.DELETE)
+        == 3
     )
-    backend = MongoDataBackend(settings)
-    assert backend.status() == DataBackendStatus.AWAY
+    assert not list(backend.read())
+    # Create operation type.
+    assert (
+        backend.write(documents, chunk_size=1, operation_type=BaseOperationType.CREATE)
+        == 3
+    )
+    assert list(backend.read()) == [
+        {"_id": "62b9ce922c26b46b68ffc68f", "_source": {"id": "foo", **timestamp}},
+        {"_id": "62b9ce92fcde2b2edba56bf4", "_source": {"id": "bar", **timestamp}},
+        {"_id": "62b9ce92baa5a0964d3320fb", "_source": {"id": "baz", **timestamp}},
+    ]
+    # Update operation type.
+ assert ( + backend.write( + new_documents, chunk_size=3, operation_type=BaseOperationType.UPDATE + ) + == 3 + ) + assert list(backend.read()) == [ + {"_id": "62b9ce922c26b46b68ffc68f", "_source": {"id": "foo", **new_timestamp}}, + {"_id": "62b9ce92fcde2b2edba56bf4", "_source": {"id": "bar", **new_timestamp}}, + {"_id": "62b9ce92baa5a0964d3320fb", "_source": {"id": "baz", **new_timestamp}}, + ] diff --git a/tests/backends/data/test_swift.py b/tests/backends/data/test_swift.py index 117363a72..c37fb6045 100644 --- a/tests/backends/data/test_swift.py +++ b/tests/backends/data/test_swift.py @@ -84,7 +84,7 @@ def test_backends_data_swift_data_backend_instantiation_with_settings(fs): try: SwiftDataBackend(settings_) except Exception as err: # pylint:disable=broad-except - pytest.fail(f"SwiftDataBackend should not raise exceptions: {err}") + pytest.fail(f"Two SwiftDataBackends should not raise exceptions: {err}") def test_backends_data_swift_data_backend_status_method_with_error_status( diff --git a/tests/backends/lrs/test_mongo.py b/tests/backends/lrs/test_mongo.py new file mode 100644 index 000000000..2effe53d2 --- /dev/null +++ b/tests/backends/lrs/test_mongo.py @@ -0,0 +1,370 @@ +"""Tests for Ralph MongoDB LRS backend.""" + +import logging + +import pytest +from bson.objectid import ObjectId +from pymongo import ASCENDING, DESCENDING + +from ralph.backends.lrs.base import StatementParameters +from ralph.exceptions import BackendException + +from tests.fixtures.backends import MONGO_TEST_FORWARDING_COLLECTION + + +@pytest.mark.parametrize( + "params,expected_query", + [ + # 0. Default query. + ( + {}, + { + "filter": {}, + "limit": None, + "projection": None, + "sort": [ + ("_source.timestamp", DESCENDING), + ("_id", DESCENDING), + ], + "query_string": None, + }, + ), + # 1. Query by statementId. + ( + {"statementId": "statementId"}, + { + "filter": {"_source.id": "statementId"}, + "limit": None, + "projection": None, + "sort": [ + ("_source.timestamp", DESCENDING), + ("_id", DESCENDING), + ], + "query_string": None, + }, + ), + # 2. Query by statementId and agent with mbox IFI. + ( + {"statementId": "statementId", "agent": {"mbox": "mailto:foo@bar.baz"}}, + { + "filter": { + "_source.id": "statementId", + "_source.actor.mbox": "mailto:foo@bar.baz", + }, + "limit": None, + "projection": None, + "sort": [ + ("_source.timestamp", DESCENDING), + ("_id", DESCENDING), + ], + "query_string": None, + }, + ), + # 3. Query by statementId and agent with mbox_sha1sum IFI. + ( + { + "statementId": "statementId", + "agent": {"mbox_sha1sum": "a7a5b7462b862c8c8767d43d43e865ffff754a64"}, + }, + { + "filter": { + "_source.id": "statementId", + "_source.actor.mbox_sha1sum": ( + "a7a5b7462b862c8c8767d43d43e865ffff754a64" + ), + }, + "limit": None, + "projection": None, + "sort": [ + ("_source.timestamp", DESCENDING), + ("_id", DESCENDING), + ], + "query_string": None, + }, + ), + # 4. Query by statementId and agent with openid IFI. + ( + { + "statementId": "statementId", + "agent": {"openid": "http://toby.openid.example.org/"}, + }, + { + "filter": { + "_source.id": "statementId", + "_source.actor.openid": "http://toby.openid.example.org/", + }, + "limit": None, + "projection": None, + "sort": [ + ("_source.timestamp", DESCENDING), + ("_id", DESCENDING), + ], + "query_string": None, + }, + ), + # 5. Query by statementId and agent with account IFI. 
+ ( + { + "statementId": "statementId", + "agent": { + "account__name": "13936749", + "account__home_page": "http://www.example.com", + }, + }, + { + "filter": { + "_source.id": "statementId", + "_source.actor.account.name": "13936749", + "_source.actor.account.homePage": "http://www.example.com", + }, + "limit": None, + "projection": None, + "sort": [ + ("_source.timestamp", DESCENDING), + ("_id", DESCENDING), + ], + "query_string": None, + }, + ), + # 6. Query by verb and activity. + ( + { + "verb": "http://adlnet.gov/expapi/verbs/attended", + "activity": "http://www.example.com/meetings/34534", + }, + { + "filter": { + "_source.verb.id": "http://adlnet.gov/expapi/verbs/attended", + "_source.object.id": "http://www.example.com/meetings/34534", + "_source.object.objectType": "Activity", + }, + "limit": None, + "projection": None, + "sort": [ + ("_source.timestamp", DESCENDING), + ("_id", DESCENDING), + ], + "query_string": None, + }, + ), + # 7. Query by timerange (with since/until). + ( + { + "since": "2021-06-24T00:00:20.194929+00:00", + "until": "2023-06-24T00:00:20.194929+00:00", + }, + { + "filter": { + "_source.timestamp": { + "$gt": "2021-06-24T00:00:20.194929+00:00", + "$lte": "2023-06-24T00:00:20.194929+00:00", + }, + }, + "limit": None, + "projection": None, + "sort": [ + ("_source.timestamp", DESCENDING), + ("_id", DESCENDING), + ], + "query_string": None, + }, + ), + # 8. Query by timerange (with only until). + ( + { + "until": "2023-06-24T00:00:20.194929+00:00", + }, + { + "filter": { + "_source.timestamp": { + "$lte": "2023-06-24T00:00:20.194929+00:00", + }, + }, + "limit": None, + "projection": None, + "sort": [ + ("_source.timestamp", DESCENDING), + ("_id", DESCENDING), + ], + "query_string": None, + }, + ), + # 9. Query with pagination. + ( + {"search_after": "666f6f2d6261722d71757578", "pit_id": None}, + { + "filter": { + "_id": {"$lt": ObjectId("666f6f2d6261722d71757578")}, + }, + "limit": None, + "projection": None, + "sort": [ + ("_source.timestamp", DESCENDING), + ("_id", DESCENDING), + ], + "query_string": None, + }, + ), + # 10. Query with pagination in ascending order. + ( + {"search_after": "666f6f2d6261722d71757578", "ascending": True}, + { + "filter": { + "_id": {"$gt": ObjectId("666f6f2d6261722d71757578")}, + }, + "limit": None, + "projection": None, + "sort": [ + ("_source.timestamp", ASCENDING), + ("_id", ASCENDING), + ], + "query_string": None, + }, + ), + ], +) +def test_backends_lrs_mongo_lrs_backend_query_statements_query( + params, expected_query, mongo_lrs_backend, monkeypatch +): + """Test the `MongoLRSBackend.query_statements` method, given valid statement + parameters, should produce the expected MongoDB query. + """ + + def mock_read(query, chunk_size): + """Mock the `MongoLRSBackend.read` method.""" + assert query.dict() == expected_query + assert chunk_size == expected_query.get("limit") + return [{"_id": "search_after_id", "_source": {}}] + + backend = mongo_lrs_backend() + monkeypatch.setattr(backend, "read", mock_read) + result = backend.query_statements(StatementParameters(**params)) + assert result.statements == [{}] + assert not result.pit_id + assert result.search_after == "search_after_id" + + +def test_backends_lrs_mongo_lrs_backend_query_statements_with_success( + mongo, mongo_lrs_backend +): + """Test the `MongoLRSBackend.query_statements` method, given a valid search query, + should return the expected statements. 
+ """ + # pylint: disable=unused-argument + backend = mongo_lrs_backend() + + # Insert documents + timestamp = {"timestamp": "2022-06-27T15:36:50"} + meta = { + "actor": {"account": {"name": "test_name", "homePage": "http://example.com"}}, + "verb": {"id": "verb_id"}, + "object": {"id": "http://example.com", "objectType": "Activity"}, + } + documents = [ + {"id": "62b9ce922c26b46b68ffc68f", **timestamp, **meta}, + {"id": "62b9ce92fcde2b2edba56bf4", **timestamp, **meta}, + ] + assert backend.write(documents) == 2 + + statement_parameters = StatementParameters( + statementId="62b9ce922c26b46b68ffc68f", + agent={ + "account__name": "test_name", + "account__home_page": "http://example.com", + }, + verb="verb_id", + activity="http://example.com", + since="2020-01-01T00:00:00.000000+00:00", + until="2022-12-01T15:36:50", + search_after="62b9ce922c26b46b68ffc68f", + ascending=True, + limit=25, + ) + statement_query_result = backend.query_statements(statement_parameters) + + assert statement_query_result.statements == [ + {"id": "62b9ce922c26b46b68ffc68f", **timestamp, **meta} + ] + + +def test_backends_lrs_mongo_lrs_backend_query_statements_with_query_failure( + mongo_lrs_backend, monkeypatch, caplog +): + """Test the `MongoLRSBackend.query_statements` method, given a search query failure, + should raise a BackendException and log the error. + """ + # pylint: disable=unused-argument + + msg = "Failed to execute MongoDB query: Something is wrong" + + def mock_read(**_): + """Mock the `MongoDataBackend.read` method always raising an Exception.""" + yield {"_source": {}} + raise BackendException(msg) + + backend = mongo_lrs_backend() + monkeypatch.setattr(backend, "read", mock_read) + + with caplog.at_level(logging.ERROR): + with pytest.raises(BackendException, match=msg): + backend.query_statements(StatementParameters()) + + assert ( + "ralph.backends.lrs.mongo", + logging.ERROR, + "Failed to read from MongoDB", + ) in caplog.record_tuples + + +def test_backends_lrs_mongo_lrs_backend_query_statements_by_ids_with_query_failure( + mongo_lrs_backend, monkeypatch, caplog +): + """Test the `MongoLRSBackend.query_statements_by_ids` method, given a search query + failure, should raise a BackendException and log the error. + """ + # pylint: disable=unused-argument + + msg = "Failed to execute MongoDB query: Something is wrong" + + def mock_read(**_): + """Mock the `MongoDataBackend.read` method always raising an Exception.""" + yield {"_source": {}} + raise BackendException(msg) + + backend = mongo_lrs_backend() + monkeypatch.setattr(backend, "read", mock_read) + + with caplog.at_level(logging.ERROR): + with pytest.raises(BackendException, match=msg): + list(backend.query_statements_by_ids(StatementParameters())) + + assert ( + "ralph.backends.lrs.mongo", + logging.ERROR, + "Failed to read from MongoDB", + ) in caplog.record_tuples + + +def test_backends_lrs_mongo_lrs_backend_query_statements_by_ids_with_two_collections( + mongo, mongo_forwarding, mongo_lrs_backend +): + """Tests the `MongoLRSBackend.query_statements_by_ids` method, given a valid search + query, should execute the query only on the specified collection and return the + expected results. 
+ """ + # pylint: disable=unused-argument + + # Instantiate Mongo Databases + backend_1 = mongo_lrs_backend() + backend_2 = mongo_lrs_backend(default_collection=MONGO_TEST_FORWARDING_COLLECTION) + + # Insert documents + timestamp = {"timestamp": "2022-06-27T15:36:50"} + assert backend_1.write([{"id": "1", **timestamp}]) == 1 + assert backend_2.write([{"id": "2", **timestamp}]) == 1 + + # Check the expected search query results + assert list(backend_1.query_statements_by_ids(["1"])) == [{"id": "1", **timestamp}] + assert not list(backend_1.query_statements_by_ids(["2"])) + assert not list(backend_2.query_statements_by_ids(["1"])) + assert list(backend_2.query_statements_by_ids(["2"])) == [{"id": "2", **timestamp}] diff --git a/tests/conftest.py b/tests/conftest.py index b6ec8e2cc..35243e071 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,7 +30,9 @@ ldp_backend, lrs, mongo, + mongo_backend, mongo_forwarding, + mongo_lrs_backend, moto_fs, s3, s3_backend, diff --git a/tests/fixtures/backends.py b/tests/fixtures/backends.py index ba83ebb54..910d026ae 100644 --- a/tests/fixtures/backends.py +++ b/tests/fixtures/backends.py @@ -26,6 +26,7 @@ from ralph.backends.data.es import ESDataBackend from ralph.backends.data.fs import FSDataBackend, FSDataBackendSettings from ralph.backends.data.ldp import LDPDataBackend +from ralph.backends.data.mongo import MongoDataBackend from ralph.backends.data.s3 import S3DataBackend, S3DataBackendSettings from ralph.backends.data.swift import SwiftDataBackend, SwiftDataBackendSettings from ralph.backends.database.clickhouse import ClickHouseDatabase @@ -35,6 +36,7 @@ from ralph.backends.lrs.clickhouse import ClickHouseLRSBackend from ralph.backends.lrs.es import ESLRSBackend from ralph.backends.lrs.fs import FSLRSBackend +from ralph.backends.lrs.mongo import MongoLRSBackend from ralph.backends.storage.s3 import S3Storage from ralph.backends.storage.swift import SwiftStorage from ralph.conf import ClickhouseClientOptions, Settings, core_settings @@ -149,12 +151,12 @@ def es_forwarding(): @pytest.fixture def fs_backend(fs, settings_fs): - """Returns the `get_fs_data_backend` function.""" + """Return the `get_fs_data_backend` function.""" # pylint: disable=invalid-name,redefined-outer-name,unused-argument fs.create_dir("foo") def get_fs_data_backend(path: str = "foo"): - """Returns an instance of FSDataBackend.""" + """Return an instance of `FSDataBackend`.""" settings = FSDataBackendSettings( DEFAULT_CHUNK_SIZE=1024, DEFAULT_DIRECTORY_PATH=path, @@ -213,6 +215,52 @@ def mongo(): yield mongo_client +@pytest.fixture +def mongo_backend(): + """Return the `get_mongo_data_backend` function.""" + + def get_mongo_data_backend( + connection_uri: str = MONGO_TEST_CONNECTION_URI, + default_collection: str = MONGO_TEST_COLLECTION, + client_options: dict = None, + ): + """Return an instance of `MongoDataBackend`.""" + settings = MongoDataBackend.settings_class( + CONNECTION_URI=connection_uri, + DEFAULT_DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=default_collection, + CLIENT_OPTIONS=client_options if client_options else {}, + DEFAULT_CHUNK_SIZE=500, + LOCALE_ENCODING="utf8", + ) + return MongoDataBackend(settings) + + return get_mongo_data_backend + + +@pytest.fixture +def mongo_lrs_backend(): + """Return the `get_mongo_lrs_backend` function.""" + + def get_mongo_lrs_backend( + connection_uri: str = MONGO_TEST_CONNECTION_URI, + default_collection: str = MONGO_TEST_COLLECTION, + client_options: dict = None, + ): + """Return an instance of MongoLRSBackend.""" + 
settings = MongoLRSBackend.settings_class( + CONNECTION_URI=connection_uri, + DEFAULT_DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=default_collection, + CLIENT_OPTIONS=client_options if client_options else {}, + DEFAULT_CHUNK_SIZE=500, + LOCALE_ENCODING="utf8", + ) + return MongoLRSBackend(settings) + + return get_mongo_lrs_backend + + @pytest.fixture def mongo_forwarding(): """Yield a second Mongo test client. See get_mongo_fixture above.""" @@ -342,11 +390,11 @@ def settings_fs(fs, monkeypatch): @pytest.fixture def ldp_backend(settings_fs): - """Returns the `get_ldp_data_backend` function.""" + """Return the `get_ldp_data_backend` function.""" # pylint: disable=invalid-name,redefined-outer-name,unused-argument def get_ldp_data_backend(service_name: str = "foo", stream_id: str = "bar"): - """Returns an instance of LDPDataBackend.""" + """Return an instance of LDPDataBackend.""" settings = LDPDataBackend.settings_class( APPLICATION_KEY="fake_key", APPLICATION_SECRET="fake_secret", @@ -518,10 +566,10 @@ def get_swift_storage(): @pytest.fixture def swift_backend(): - """Returns get_swift_data_backend function.""" + """Return get_swift_data_backend function.""" def get_swift_data_backend(): - """Returns an instance of SwiftDataBackend.""" + """Return an instance of SwiftDataBackend.""" settings = SwiftDataBackendSettings( AUTH_URL="https://auth.cloud.ovh.net/", USERNAME="os_username", From 6e42e64ef11a1f00039622b9e063a7677c53a3a2 Mon Sep 17 00:00:00 2001 From: SergioSim Date: Thu, 9 Mar 2023 13:11:50 +0100 Subject: [PATCH 19/65] =?UTF-8?q?=E2=9C=A8(backends)=20add=20AsyncMongoDat?= =?UTF-8?q?abase=20backend?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We want to provide an async version of our MongoDatabase backend. 
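
For reviewers, a minimal usage sketch of the new backend follows (the
connection URI, database and collection names below are illustrative, not
project defaults):

    import asyncio

    from ralph.backends.data.async_mongo import AsyncMongoDataBackend

    async def main():
        # The async backend reuses the synchronous MongoDataBackend settings class.
        backend = AsyncMongoDataBackend(
            AsyncMongoDataBackend.settings_class(
                CONNECTION_URI="mongodb://localhost:27017/",
                DEFAULT_DATABASE="statements",
                DEFAULT_COLLECTION="marsha",
            )
        )
        # Index two statements, then read them back without blocking the event loop.
        await backend.write(
            [
                {"id": "foo", "timestamp": "2022-06-27T15:36:50"},
                {"id": "bar", "timestamp": "2022-06-27T15:36:51"},
            ]
        )
        async for document in backend.read():
            print(document["_source"]["id"])
        await backend.close()

    asyncio.run(main())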
--- setup.cfg | 1 + src/ralph/backends/data/async_mongo.py | 321 +++++++ src/ralph/backends/data/mongo.py | 19 +- src/ralph/backends/lrs/async_mongo.py | 57 ++ tests/backends/data/test_async_mongo.py | 1086 +++++++++++++++++++++++ tests/backends/lrs/__init__.py | 1 + tests/backends/lrs/test_async_mongo.py | 392 ++++++++ tests/conftest.py | 2 + tests/fixtures/backends.py | 60 +- 9 files changed, 1925 insertions(+), 14 deletions(-) create mode 100644 src/ralph/backends/data/async_mongo.py create mode 100644 src/ralph/backends/lrs/async_mongo.py create mode 100644 tests/backends/data/test_async_mongo.py create mode 100644 tests/backends/lrs/test_async_mongo.py diff --git a/setup.cfg b/setup.cfg index 98a8d520d..bef3badc5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,6 +51,7 @@ backend-lrs = httpx<0.25.0 # pin as Python 3.7 is no longer supported from release 0.25.0 more-itertools==10.1.0 backend-mongo = + motor[srv]>=3.1.1 pymongo[srv]>=4.0.0 python-dateutil>=2.8.2 backend-s3 = diff --git a/src/ralph/backends/data/async_mongo.py b/src/ralph/backends/data/async_mongo.py new file mode 100644 index 000000000..8e1eb6738 --- /dev/null +++ b/src/ralph/backends/data/async_mongo.py @@ -0,0 +1,321 @@ +"""Async MongoDB data backend for Ralph.""" + +import json +import logging +from io import IOBase +from itertools import chain +from typing import Any, Dict, Iterable, Iterator, Union + +from bson.errors import BSONError +from motor.motor_asyncio import AsyncIOMotorClient +from pymongo.collection import Collection +from pymongo.errors import BulkWriteError, ConnectionFailure, InvalidName, PyMongoError + +from ralph.backends.data.base import BaseOperationType +from ralph.backends.data.mongo import ( + MongoDataBackend, + MongoDataBackendSettings, + MongoQuery, +) +from ralph.exceptions import BackendException, BackendParameterException +from ralph.utils import parse_bytes_to_dict + +from ..data.base import BaseAsyncDataBackend, DataBackendStatus, enforce_query_checks + +logger = logging.getLogger(__name__) + + +class AsyncMongoDataBackend(BaseAsyncDataBackend): + """Async MongoDB data backend.""" + + name = "async_mongo" + query_model = MongoQuery + settings_class = MongoDataBackendSettings + + def __init__(self, settings: Union[settings_class, None] = None): + """Instantiate the asynchronous MongoDB client. + + Args: + settings (MongoDataBackendSettings or None): The data backend settings. + """ + self.settings = settings if settings else self.settings_class() + self.client = AsyncIOMotorClient( + self.settings.CONNECTION_URI, **self.settings.CLIENT_OPTIONS.dict() + ) + self.database = self.client[self.settings.DEFAULT_DATABASE] + self.collection = self.database[self.settings.DEFAULT_COLLECTION] + + async def status(self) -> DataBackendStatus: + """Check the MongoDB connection status. + + Return: + DataBackendStatus: The status of the data backend. + """ + # Check MongoDB connection. + try: + await self.client.admin.command("ping") + except (ConnectionFailure, PyMongoError) as error: + logger.error("Failed to connect to MongoDB: %s", error) + return DataBackendStatus.AWAY + + # Check MongoDB server status. 
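+        # A healthy server answers the `serverStatus` command with `{"ok": 1.0}`.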
+        try:
+            if (await self.client.admin.command("serverStatus")).get("ok") != 1.0:
+                logger.error("MongoDB `serverStatus` command did not return 1.0")
+                return DataBackendStatus.ERROR
+        except PyMongoError as error:
+            logger.error("Failed to get MongoDB server status: %s", error)
+            return DataBackendStatus.ERROR
+
+        return DataBackendStatus.OK
+
+    async def list(
+        self, target: Union[str, None] = None, details: bool = False, new: bool = False
+    ) -> Iterator[Union[str, dict]]:
+        """List collections in the target database.
+
+        Args:
+            target (str or None): The MongoDB database name to list collections from.
+                If target is `None`, the `DEFAULT_DATABASE` is used instead.
+            details (bool): Get detailed collection information instead of just IDs.
+            new (bool): Ignored.
+
+        Yield:
+            str: The next collection. (If `details` is False).
+            dict: The next collection details. (If `details` is True).
+
+        Raise:
+            BackendException: If a failure during the list operation occurs.
+            BackendParameterException: If the `target` is not a valid database name.
+        """
+        if new:
+            logger.warning("The `new` argument is ignored")
+
+        try:
+            database = self.client[target] if target else self.database
+        except InvalidName as error:
+            msg = "The target=`%s` is not a valid database name: %s"
+            logger.error(msg, target, error)
+            raise BackendParameterException(msg % (target, error)) from error
+
+        try:
+            collections = await database.list_collections()
+            async for collection_info in collections:
+                if details:
+                    yield collection_info
+                else:
+                    yield collection_info.get("name")
+        except PyMongoError as error:
+            msg = "Failed to list MongoDB collections: %s"
+            logger.error(msg, error)
+            raise BackendException(msg % error) from error
+
+    @enforce_query_checks
+    async def read(
+        self,
+        *,
+        query: Union[str, MongoQuery] = None,
+        target: Union[str, None] = None,
+        chunk_size: Union[int, None] = None,
+        raw_output: bool = False,
+        ignore_errors: bool = False,
+    ) -> Iterator[Union[bytes, dict]]:
+        """Read documents matching the `query` from `target` collection and yield them.
+
+        Args:
+            query (str or MongoQuery): The MongoDB query to use when fetching documents.
+            target (str or None): The MongoDB collection name to query.
+                If target is `None`, the `DEFAULT_COLLECTION` is used instead.
+            chunk_size (int or None): The chunk size when reading documents by batches.
+                If chunk_size is `None` the `DEFAULT_CHUNK_SIZE` is used instead.
+            raw_output (bool): Whether to yield dictionaries or bytes.
+            ignore_errors (bool): Whether to ignore errors when reading documents.
+
+        Yield:
+            bytes: The next raw document if `raw_output` is True.
+            dict: The next JSON parsed document if `raw_output` is False.
+
+        Raise:
+            BackendException: If a failure occurs during MongoDB connection.
+            BackendParameterException: If a failure occurs with MongoDB collection. 
+        """
+        if not chunk_size:
+            chunk_size = self.settings.DEFAULT_CHUNK_SIZE
+
+        query = (
+            MongoQuery.parse_raw(query.query_string) if query.query_string else query
+        ).dict(exclude={"query_string"}, exclude_unset=True)
+        try:
+            collection = self.database[target] if target else self.collection
+        except InvalidName as error:
+            msg = "The target=`%s` is not a valid collection name: %s"
+            logger.error(msg, target, error)
+            raise BackendParameterException(msg % (target, error)) from error
+
+        reader = self._read_raw if raw_output else lambda _: _
+        try:
+            async for document in collection.find(batch_size=chunk_size, **query):
+                document.update({"_id": str(document.get("_id"))})
+                try:
+                    yield reader(document)
+                except (TypeError, ValueError) as error:
+                    msg = "Failed to encode MongoDB document with ID %s: %s"
+                    document_id = document.get("_id")
+                    if ignore_errors:
+                        logger.warning(msg, document_id, error)
+                        continue
+                    logger.error(msg, document_id, error)
+                    raise BackendException(msg % (document_id, error)) from error
+        except (PyMongoError, IndexError, TypeError, ValueError) as error:
+            msg = "Failed to execute MongoDB query: %s"
+            logger.error(msg, error)
+            raise BackendException(msg % error) from error
+
+    async def write(  # pylint: disable=too-many-arguments
+        self,
+        data: Union[IOBase, Iterable[bytes], Iterable[dict]],
+        target: Union[str, None] = None,
+        chunk_size: Union[int, None] = None,
+        ignore_errors: bool = False,
+        operation_type: Union[BaseOperationType, None] = None,
+    ) -> int:
+        """Write data documents to the target collection and return their count.
+
+        Args:
+            data (Iterable or IOBase): The data containing documents to write.
+            target (str or None): The target MongoDB collection name.
+            chunk_size (int or None): The number of documents to write in one batch.
+                If chunk_size is `None` the `DEFAULT_CHUNK_SIZE` is used instead.
+            ignore_errors (bool): Whether to ignore errors or not.
+            operation_type (BaseOperationType or None): The mode of the write operation.
+                If `operation_type` is `None`, the `default_operation_type` is used
+                instead. See `BaseOperationType`.
+
+        Return:
+            int: The number of documents written.
+
+        Raise:
+            BackendException: If a failure occurs while writing to MongoDB or
+                during document decoding and `ignore_errors` is set to `False`.
+            BackendParameterException: If the `operation_type` is `APPEND` as it is not
+                supported.
+        """
+        if not operation_type:
+            operation_type = self.default_operation_type
+
+        if operation_type == BaseOperationType.APPEND:
+            msg = "Append operation_type is not allowed."
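+            # There is no MongoDB equivalent of appending to a stored document.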
+ logger.error(msg) + raise BackendParameterException(msg) + + if not chunk_size: + chunk_size = self.settings.DEFAULT_CHUNK_SIZE + + collection = self.database[target] if target else self.collection + logger.debug( + "Start writing to the %s collection of the %s database (chunk size: %d)", + collection, + self.database, + chunk_size, + ) + + count = 0 + data = iter(data) + try: + first_record = next(data) + except StopIteration: + logger.warning("Data Iterator is empty; skipping write to target.") + return count + data = chain([first_record], data) + if isinstance(first_record, bytes): + data = parse_bytes_to_dict(data, ignore_errors, logger) + + if operation_type == BaseOperationType.UPDATE: + for batch in MongoDataBackend.iter_by_batch( + MongoDataBackend.to_replace_one(data), chunk_size + ): + count += await self._bulk_update(batch, ignore_errors, collection) + logger.info("Updated %d documents with success", count) + elif operation_type == BaseOperationType.DELETE: + for batch in MongoDataBackend.iter_by_batch( + MongoDataBackend.to_ids(data), chunk_size + ): + count += await self._bulk_delete(batch, ignore_errors, collection) + logger.info("Deleted %d documents with success", count) + else: + data = MongoDataBackend.to_documents( + data, ignore_errors, operation_type, logger + ) + for batch in MongoDataBackend.iter_by_batch(data, chunk_size): + count += await self._bulk_import(batch, ignore_errors, collection) + logger.info("Inserted %d documents with success", count) + + return count + + async def close(self) -> None: + """Close the AsyncIOMotorClient client. + + Raise: + BackendException: If a failure during the close operation occurs. + """ + try: + self.client.close() + except PyMongoError as error: + msg = "Failed to close AsyncIOMotorClient: %s" + logger.error(msg, error) + raise BackendException(msg % error) from error + + @staticmethod + async def _bulk_import(batch: list, ignore_errors: bool, collection: Collection): + """Insert a batch of documents into the selected database collection.""" + try: + new_documents = await collection.insert_many(batch) + except (BulkWriteError, PyMongoError, BSONError, ValueError) as error: + msg = "Failed to insert document chunk: %s" + if ignore_errors: + logger.warning(msg, error) + return getattr(error, "details", {}).get("nInserted", 0) + raise BackendException(msg % error) from error + + inserted_count = len(new_documents.inserted_ids) + logger.debug("Inserted %d documents chunk with success", inserted_count) + return inserted_count + + @staticmethod + async def _bulk_delete(batch: list, ignore_errors: bool, collection: Collection): + """Delete a batch of documents from the selected database collection.""" + try: + deleted_documents = await collection.delete_many( + {"_source.id": {"$in": batch}} + ) + except (BulkWriteError, PyMongoError, BSONError, ValueError) as error: + msg = "Failed to delete document chunk: %s" + if ignore_errors: + logger.warning(msg, error) + return getattr(error, "details", {}).get("nRemoved", 0) + raise BackendException(msg % error) from error + + deleted_count = deleted_documents.deleted_count + logger.debug("Deleted %d documents chunk with success", deleted_count) + return deleted_count + + @staticmethod + async def _bulk_update(batch: list, ignore_errors: bool, collection: Collection): + """Update a batch of documents into the selected database collection.""" + try: + updated_documents = await collection.bulk_write(batch) + except (BulkWriteError, PyMongoError, BSONError, ValueError) as error: + msg = 
"Failed to update document chunk: %s" + if ignore_errors: + logger.warning(msg, error) + return getattr(error, "details", {}).get("nModified", 0) + logger.error(msg, error) + raise BackendException(msg % error) from error + + modified_count = updated_documents.modified_count + logger.debug("Updated %d documents chunk with success", modified_count) + return modified_count + + def _read_raw(self, document: Dict[str, Any]) -> bytes: + """Read the `document` dictionary and return bytes.""" + return json.dumps(document).encode(self.settings.LOCALE_ENCODING) diff --git a/src/ralph/backends/data/mongo.py b/src/ralph/backends/data/mongo.py index c7dcd1ae4..fcab760ff 100644 --- a/src/ralph/backends/data/mongo.py +++ b/src/ralph/backends/data/mongo.py @@ -289,7 +289,7 @@ def write( # pylint: disable=too-many-arguments count += self._bulk_delete(batch, ignore_errors, collection) logger.info("Deleted %d documents with success", count) else: - data = self.to_documents(data, ignore_errors, operation_type) + data = self.to_documents(data, ignore_errors, operation_type, logger) for batch in self.iter_by_batch(data, chunk_size): count += self._bulk_import(batch, ignore_errors, collection) logger.info("Inserted %d documents with success", count) @@ -325,7 +325,10 @@ def to_replace_one(data: Iterable[dict]) -> Iterable[ReplaceOne]: @staticmethod def to_documents( - data: Iterable[dict], ignore_errors: bool, operation_type: BaseOperationType + data: Iterable[dict], + ignore_errors: bool, + operation_type: BaseOperationType, + logger_class: logging.Logger, ) -> Generator[dict, None, None]: """Convert `data` statements to MongoDB documents. @@ -337,25 +340,25 @@ def to_documents( if "id" not in statement and operation_type == BaseOperationType.INDEX: msg = "statement %s has no 'id' field" if ignore_errors: - logger.warning("statement %s has no 'id' field", statement) + logger_class.warning("statement %s has no 'id' field", statement) continue - logger.error(msg, statement) + logger_class.error(msg, statement) raise BackendException(msg % statement) if "timestamp" not in statement: msg = "statement %s has no 'timestamp' field" if ignore_errors: - logger.warning(msg, statement) + logger_class.warning(msg, statement) continue - logger.error(msg, statement) + logger_class.error(msg, statement) raise BackendException(msg % statement) try: timestamp = int(isoparse(statement["timestamp"]).timestamp()) except ValueError as err: msg = "statement %s has an invalid 'timestamp' field" if ignore_errors: - logger.warning(msg, statement) + logger_class.warning(msg, statement) continue - logger.error(msg, statement) + logger_class.error(msg, statement) raise BackendException(msg % statement) from err document = { "_id": ObjectId( diff --git a/src/ralph/backends/lrs/async_mongo.py b/src/ralph/backends/lrs/async_mongo.py new file mode 100644 index 000000000..3b26c0f78 --- /dev/null +++ b/src/ralph/backends/lrs/async_mongo.py @@ -0,0 +1,57 @@ +"""Async MongoDB LRS backend for Ralph.""" + + +import logging +from typing import Iterator, List + +from ralph.backends.data.async_mongo import AsyncMongoDataBackend +from ralph.backends.lrs.base import ( + BaseAsyncLRSBackend, + StatementParameters, + StatementQueryResult, +) +from ralph.backends.lrs.mongo import MongoLRSBackend +from ralph.exceptions import BackendException, BackendParameterException + +logger = logging.getLogger(__name__) + + +class AsyncMongoLRSBackend(BaseAsyncLRSBackend, AsyncMongoDataBackend): + """Async MongoDB LRS backend implementation.""" + + settings_class = 
AsyncMongoDataBackend.settings_class + + async def query_statements( + self, params: StatementParameters + ) -> StatementQueryResult: + """Return the statements query payload using xAPI parameters.""" + query = MongoLRSBackend.get_query(params) + try: + mongo_response = [ + document + async for document in self.read(query=query, chunk_size=params.limit) + ] + except (BackendException, BackendParameterException) as error: + logger.error("Failed to read from async MongoDB") + raise error + + search_after = None + if mongo_response: + search_after = mongo_response[-1]["_id"] + + return StatementQueryResult( + statements=[document["_source"] for document in mongo_response], + pit_id=None, + search_after=search_after, + ) + + async def query_statements_by_ids(self, ids: List[str]) -> Iterator[dict]: + """Yield statements with matching ids from the backend.""" + try: + async for document in self.read( + query={"filter": {"_source.id": {"$in": ids}}} + ): + yield document["_source"] + except (BackendException, BackendParameterException) as error: + logger.error("Failed to read from MongoDB") + raise error diff --git a/tests/backends/data/test_async_mongo.py b/tests/backends/data/test_async_mongo.py new file mode 100644 index 000000000..12782d984 --- /dev/null +++ b/tests/backends/data/test_async_mongo.py @@ -0,0 +1,1086 @@ +"""Tests for Ralph's async mongo data backend.""" # pylint: disable = too-many-lines + +import json +import logging + +import pytest +from bson.objectid import ObjectId +from motor.motor_asyncio import AsyncIOMotorClient +from pymongo.errors import ConnectionFailure, PyMongoError + +from ralph.backends.data.async_mongo import ( + AsyncMongoDataBackend, + MongoDataBackendSettings, + MongoQuery, +) +from ralph.backends.data.base import BaseOperationType, DataBackendStatus +from ralph.backends.data.mongo import MongoClientOptions +from ralph.exceptions import BackendException, BackendParameterException + +from tests.fixtures.backends import ( + MONGO_TEST_COLLECTION, + MONGO_TEST_CONNECTION_URI, + MONGO_TEST_DATABASE, +) + + +@pytest.mark.anyio +async def test_backends_data_async_mongo_data_backend_default_instantiation( + monkeypatch, fs +): + """Test the `AsyncMongoDataBackend` default instantiation.""" + # pylint: disable=invalid-name + fs.create_file(".env") + backend_settings_names = [ + "CONNECTION_URI", + "DEFAULT_DATABASE", + "DEFAULT_COLLECTION", + "CLIENT_OPTIONS", + "DEFAULT_CHUNK_SIZE", + "LOCALE_ENCODING", + ] + for name in backend_settings_names: + monkeypatch.delenv(f"RALPH_BACKENDS__DATA__MONGO__{name}", raising=False) + + assert AsyncMongoDataBackend.name == "async_mongo" + assert AsyncMongoDataBackend.query_model == MongoQuery + assert AsyncMongoDataBackend.default_operation_type == BaseOperationType.INDEX + assert AsyncMongoDataBackend.settings_class == MongoDataBackendSettings + backend = AsyncMongoDataBackend() + assert isinstance(backend.client, AsyncIOMotorClient) + assert backend.database.name == "statements" + assert backend.collection.name == "marsha" + assert backend.settings.CONNECTION_URI == "mongodb://localhost:27017/" + assert backend.settings.CLIENT_OPTIONS == MongoClientOptions() + assert backend.settings.DEFAULT_CHUNK_SIZE == 500 + assert backend.settings.LOCALE_ENCODING == "utf8" + + +@pytest.mark.anyio +async def test_backends_data_async_mongo_data_backend_instantiation_with_settings( + async_mongo_backend, +): + """Test the `AsyncMongoDataBackend` instantiation with settings.""" + backend = async_mongo_backend(default_collection="foo") + 
assert backend.database.name == MONGO_TEST_DATABASE
+    assert backend.collection.name == "foo"
+    assert backend.settings.CONNECTION_URI == MONGO_TEST_CONNECTION_URI
+    assert backend.settings.CLIENT_OPTIONS == MongoClientOptions()
+    assert backend.settings.DEFAULT_CHUNK_SIZE == 500
+    assert backend.settings.LOCALE_ENCODING == "utf8"
+
+
+@pytest.mark.anyio
+async def test_backends_data_async_mongo_data_backend_status_with_connection_failure(
+    async_mongo_backend, monkeypatch, caplog
+):
+    """Test the `AsyncMongoDataBackend.status` method, given a connection failure,
+    should return `DataBackendStatus.AWAY`.
+    """
+
+    class MockAsyncIOMotorClientAdmin:
+        """Mock the `AsyncIOMotorClient.admin` property."""
+
+        @staticmethod
+        async def command(command: str):
+            """Mock the `command` method always raising a `ConnectionFailure`."""
+            assert command == "ping"
+            raise ConnectionFailure("Connection failure")
+
+    class MockAsyncIOMotorClient:
+        """Mock the `motor.motor_asyncio.AsyncIOMotorClient`."""
+
+        admin = MockAsyncIOMotorClientAdmin
+
+    backend = async_mongo_backend()
+    monkeypatch.setattr(backend, "client", MockAsyncIOMotorClient)
+    with caplog.at_level(logging.ERROR):
+        assert await backend.status() == DataBackendStatus.AWAY
+
+    assert (
+        "ralph.backends.data.async_mongo",
+        logging.ERROR,
+        "Failed to connect to MongoDB: Connection failure",
+    ) in caplog.record_tuples
+
+
+@pytest.mark.anyio
+async def test_backends_data_async_mongo_data_backend_status_with_error_status(
+    async_mongo_backend, monkeypatch, caplog
+):
+    """Test the `AsyncMongoDataBackend.status` method, given a failed serverStatus
+    command, should return `DataBackendStatus.ERROR`.
+    """
+
+    class MockAsyncIOMotorClientAdmin:
+        """Mock the `AsyncIOMotorClient.admin` property."""
+
+        @staticmethod
+        async def command(command: str):
+            """Mock the `command` method raising a `PyMongoError` for `serverStatus`."""
+            if command == "ping":
+                return
+            assert command == "serverStatus"
+            raise PyMongoError("Server status failure")
+
+    class MockAsyncIOMotorClient:
+        """Mock the `motor.motor_asyncio.AsyncIOMotorClient`."""
+
+        admin = MockAsyncIOMotorClientAdmin
+
+    backend = async_mongo_backend()
+    monkeypatch.setattr(backend, "client", MockAsyncIOMotorClient)
+    with caplog.at_level(logging.ERROR):
+        assert await backend.status() == DataBackendStatus.ERROR
+
+    assert (
+        "ralph.backends.data.async_mongo",
+        logging.ERROR,
+        "Failed to get MongoDB server status: Server status failure",
+    ) in caplog.record_tuples
+
+    # Given a MongoDB serverStatus query returning an ok status different from 1,
+    # the `status` method should return `DataBackendStatus.ERROR`. 
+
+    class MockAsyncIOMotorClientAdmin:  # pylint: disable = function-redefined
+        """Mock the `AsyncIOMotorClient.admin` property."""
+
+        @staticmethod
+        async def command(*_, **__):
+            """Mock the `command` method returning a not-ok `serverStatus`."""
+            return {"ok": 0}
+
+    class MockAsyncIOMotorClient:  # pylint: disable = function-redefined
+        """Mock the `motor.motor_asyncio.AsyncIOMotorClient`."""
+
+        admin = MockAsyncIOMotorClientAdmin
+
+    monkeypatch.setattr(backend, "client", MockAsyncIOMotorClient)
+
+    with caplog.at_level(logging.ERROR):
+        assert await backend.status() == DataBackendStatus.ERROR
+
+    assert (
+        "ralph.backends.data.async_mongo",
+        logging.ERROR,
+        "MongoDB `serverStatus` command did not return 1.0",
+    ) in caplog.record_tuples
+
+
+@pytest.mark.anyio
+async def test_backends_data_async_mongo_data_backend_status_with_ok_status(
+    async_mongo_backend, monkeypatch
+):
+    """Test the `AsyncMongoDataBackend.status` method, given a successful connection
+    and serverStatus command, should return `DataBackendStatus.OK`.
+    """
+
+    class MockAsyncIOMotorClientAdmin:
+        """Mock the `AsyncIOMotorClient.admin` property."""
+
+        @staticmethod
+        async def command(command: str):  # pylint: disable = unused-argument
+            """Mock the `command` method always ensuring the server is up."""
+            return {"ok": 1.0}
+
+    class MockAsyncIOMotorClient:
+        """Mock the `motor.motor_asyncio.AsyncIOMotorClient`."""
+
+        admin = MockAsyncIOMotorClientAdmin
+
+    backend = async_mongo_backend()
+    monkeypatch.setattr(backend, "client", MockAsyncIOMotorClient)
+
+    assert await backend.status() == DataBackendStatus.OK
+
+
+@pytest.mark.parametrize("invalid_character", [" ", ".", "/", '"'])
+@pytest.mark.anyio
+async def test_backends_data_async_mongo_data_backend_list_method_with_invalid_target(
+    invalid_character, async_mongo_backend, caplog
+):
+    """Test the `AsyncMongoDataBackend.list` method, given an invalid `target`
+    argument, should raise a `BackendParameterException`.
+    """
+    backend = async_mongo_backend()
+    msg = (
+        f"The target=`foo{invalid_character}bar` is not a valid database name: "
+        f"database names cannot contain the character '{invalid_character}'"
+    )
+
+    with pytest.raises(BackendParameterException, match=msg):
+        with caplog.at_level(logging.ERROR):
+            async for result in backend.list(f"foo{invalid_character}bar"):
+                next(result)
+
+    assert (
+        "ralph.backends.data.async_mongo",
+        logging.ERROR,
+        msg,
+    ) in caplog.record_tuples
+
+
+@pytest.mark.anyio
+async def test_backends_data_async_mongo_data_backend_list_method_with_failure(
+    async_mongo_backend, monkeypatch, caplog
+):
+    """Test the `AsyncMongoDataBackend.list` method, given a failure while retrieving
+    MongoDB collections, should raise a `BackendException`. 
+ """ + + def mock_list_collections(): + """Mock the `list_collections` method always raising an exception.""" + raise PyMongoError("Connection error") + + backend = async_mongo_backend() + monkeypatch.setattr(backend.database, "list_collections", mock_list_collections) + msg = "Failed to list MongoDB collections: Connection error" + with pytest.raises(BackendException, match=msg): + with caplog.at_level(logging.ERROR): + async for result in backend.list(): + next(result) + + assert ( + "ralph.backends.data.async_mongo", + logging.ERROR, + msg, + ) in caplog.record_tuples + + +@pytest.mark.anyio +async def test_backends_data_async_mongo_data_backend_list_method_without_history( + mongo, async_mongo_backend, monkeypatch +): + """Test the `AsyncMongoDataBackend.list` method without history.""" + # pylint: disable=unused-argument + + backend = async_mongo_backend() + + # Test `list` method with default parameters + result = [collection async for collection in backend.list()] + assert result == [MONGO_TEST_COLLECTION] + + # Test `list` method with a given target (database for MongoDB) + result = [ + collection async for collection in backend.list(target=MONGO_TEST_DATABASE) + ] + assert result == [MONGO_TEST_COLLECTION] + + # Test `list` method with detailed information about collections + result = [collection async for collection in backend.list(details=True)] + assert result[0]["name"] == MONGO_TEST_COLLECTION + + # Test `list` method with several collections + await backend.database.create_collection("bar") + await backend.database.create_collection("baz") + + result = [collection async for collection in backend.list()] + assert sorted(result) == sorted([MONGO_TEST_COLLECTION, "bar", "baz"]) + + result = [collection["name"] async for collection in backend.list(details=True)] + assert sorted(result) == (sorted([MONGO_TEST_COLLECTION, "bar", "baz"])) + + result = [collection async for collection in backend.list("non_existent_database")] + assert not result + + +@pytest.mark.anyio +async def test_backends_data_async_mongo_data_backend_list_method_with_history( + mongo, async_mongo_backend, caplog # pylint: disable=unused-argument +): + """Test the `AsyncMongoDataBackend.list` method given `new` argument set to + `True`, should log a warning message. 
+ """ + backend = async_mongo_backend() + with caplog.at_level(logging.WARNING): + result = [ + collection + async for collection in backend.list("non_existent_database", new=True) + ] + assert not list(result) + + assert ( + "ralph.backends.data.async_mongo", + logging.WARNING, + "The `new` argument is ignored", + ) in caplog.record_tuples + + +@pytest.mark.anyio +async def test_backends_data_async_mongo_data_backend_read_method_with_raw_output( + mongo, + async_mongo_backend, +): + """Test the `AsyncMongoDataBackend.read` method with `raw_output` set to `True`.""" + # pylint: disable=unused-argument + backend = async_mongo_backend() + documents = [ + {"_id": ObjectId("64945e53a4ee2699573e0d6f"), "id": "foo"}, + {"_id": ObjectId("64945e530468d817b1f756da"), "id": "bar"}, + {"_id": ObjectId("64945e530468d817b1f756db"), "id": "baz"}, + ] + expected = [ + b'{"_id": "64945e53a4ee2699573e0d6f", "id": "foo"}', + b'{"_id": "64945e530468d817b1f756da", "id": "bar"}', + b'{"_id": "64945e530468d817b1f756db", "id": "baz"}', + ] + await backend.collection.insert_many(documents) + await backend.database.foobar.insert_many(documents[:2]) + + result = [statement async for statement in backend.read(raw_output=True)] + assert result == expected + result = [ + statement async for statement in backend.read(raw_output=True, target="foobar") + ] + assert result == expected[:2] + result = [ + statement async for statement in backend.read(raw_output=True, chunk_size=2) + ] + assert result == expected + result = [ + statement async for statement in backend.read(raw_output=True, chunk_size=1000) + ] + assert result == expected + + +@pytest.mark.anyio +async def test_backends_data_async_mongo_data_backend_read_method_without_raw_output( + mongo, async_mongo_backend +): + """Test the `AsyncMongoDataBackend.read` method with `raw_output` set to + `False`. + """ + # pylint: disable=unused-argument + backend = async_mongo_backend() + documents = [ + {"_id": ObjectId("64945e53a4ee2699573e0d6f"), "id": "foo"}, + {"_id": ObjectId("64945e530468d817b1f756da"), "id": "bar"}, + {"_id": ObjectId("64945e530468d817b1f756db"), "id": "baz"}, + ] + expected = [ + {"_id": "64945e53a4ee2699573e0d6f", "id": "foo"}, + {"_id": "64945e530468d817b1f756da", "id": "bar"}, + {"_id": "64945e530468d817b1f756db", "id": "baz"}, + ] + await backend.collection.insert_many(documents) + await backend.database.foobar.insert_many(documents[:2]) + + assert [statement async for statement in backend.read()] == expected + assert [statement async for statement in backend.read(target="foobar")] == expected[ + :2 + ] + assert [statement async for statement in backend.read(chunk_size=2)] == expected + assert [statement async for statement in backend.read(chunk_size=1000)] == expected + + +@pytest.mark.parametrize( + "invalid_target,error", + [ + (".foo", "must not start or end with '.': '.foo'"), + ("foo.", "must not start or end with '.': 'foo.'"), + ("foo$bar", "must not contain '$': 'foo$bar'"), + ("foo..bar", "cannot be empty"), + ], +) +@pytest.mark.anyio +async def test_backends_data_async_mongo_data_backend_read_method_with_invalid_target( + invalid_target, + error, + async_mongo_backend, + caplog, +): + """Test the `AsyncMongoDataBackend.read` method given an invalid `target` argument, + should raise a `BackendParameterException`. 
+ """ + backend = async_mongo_backend() + msg = ( + f"The target=`{invalid_target}` is not a valid collection name: " + f"collection names {error}" + ) + with pytest.raises(BackendParameterException, match=msg.replace("$", r"\$")): + with caplog.at_level(logging.ERROR): + async for statement in backend.read(target=invalid_target): + next(statement) + + assert ( + "ralph.backends.data.async_mongo", + logging.ERROR, + msg, + ) in caplog.record_tuples + + +@pytest.mark.anyio +async def test_backends_data_async_mongo_data_backend_read_method_with_failure( + async_mongo_backend, monkeypatch, caplog +): + """Test the `AsyncMongoDataBackend.read` method given an AsyncIOMotorClient failure, + should raise a `BackendException`. + """ + + def mock_find(*_, **__): + """Mock the `motor.motor_asyncio.AsyncIOMotorClient.collection.find` + method returning a failing Cursor. + """ + raise PyMongoError("MongoDB internal failure") + + backend = async_mongo_backend() + monkeypatch.setattr(backend.collection, "find", mock_find) + msg = "Failed to execute MongoDB query: MongoDB internal failure" + with pytest.raises(BackendException, match=msg): + with caplog.at_level(logging.ERROR): + result = [statement async for statement in backend.read()] + next(result) + + assert ( + "ralph.backends.data.async_mongo", + logging.ERROR, + msg, + ) in caplog.record_tuples + + +@pytest.mark.anyio +async def test_backends_data_async_mongo_data_backend_read_method_with_ignore_errors( + mongo, async_mongo_backend, caplog +): + """Test the `AsyncMongoDataBackend.read` method with `ignore_errors` set to `True`, + given a collection containing unparsable documents, should skip the invalid + documents. + """ + # pylint: disable=unused-argument + backend = async_mongo_backend() + documents = [ + {"_id": ObjectId("64945e53a4ee2699573e0d6f"), "id": "foo"}, + {"_id": ObjectId("64945e530468d817b1f756da"), "id": ObjectId()}, + {"_id": ObjectId("64945e530468d817b1f756db"), "id": "baz"}, + ] + expected = [ + b'{"_id": "64945e53a4ee2699573e0d6f", "id": "foo"}', + b'{"_id": "64945e530468d817b1f756db", "id": "baz"}', + ] + await backend.collection.insert_many(documents) + await backend.database.foobar.insert_many(documents[:2]) + kwargs = {"raw_output": True, "ignore_errors": True} + with caplog.at_level(logging.WARNING): + assert [statement async for statement in backend.read(**kwargs)] == expected + assert [ + statement async for statement in backend.read(**kwargs, target="foobar") + ] == expected[:1] + assert [ + statement async for statement in backend.read(**kwargs, chunk_size=2) + ] == expected + assert [ + statement async for statement in backend.read(**kwargs, chunk_size=1000) + ] == expected + + assert ( + "ralph.backends.data.async_mongo", + logging.WARNING, + "Failed to encode MongoDB document with ID 64945e530468d817b1f756da: " + "Object of type ObjectId is not JSON serializable", + ) in caplog.record_tuples + + +@pytest.mark.anyio +async def test_backends_data_async_mongo_data_backend_read_method_without_ignore_errors( + mongo, async_mongo_backend, caplog +): + """Test the `AsyncMongoDataBackend.read` method with `ignore_errors` set to `False`, + given a collection containing unparsable documents, should raise a + `BackendException`. 
+ """ + # pylint: disable=unused-argument + backend = async_mongo_backend() + documents = [ + {"_id": ObjectId("64945e53a4ee2699573e0d6f"), "id": "foo"}, + {"_id": ObjectId("64945e530468d817b1f756da"), "id": ObjectId()}, + {"_id": ObjectId("64945e530468d817b1f756db"), "id": "baz"}, + ] + expected = b'{"_id": "64945e53a4ee2699573e0d6f", "id": "foo"}' + await backend.collection.insert_many(documents) + await backend.database.foobar.insert_many(documents[:2]) + kwargs = {"raw_output": True, "ignore_errors": False} + msg = ( + "Failed to encode MongoDB document with ID 64945e530468d817b1f756da: " + "Object of type ObjectId is not JSON serializable" + ) + with caplog.at_level(logging.ERROR): + with pytest.raises(BackendException, match=msg): + result = [statement async for statement in backend.read(**kwargs)] + assert next(result) == expected + next(result) + with pytest.raises(BackendException, match=msg): + result = [ + statement async for statement in backend.read(**kwargs, target="foobar") + ] + assert next(result) == expected + next(result) + with pytest.raises(BackendException, match=msg): + result = [ + statement async for statement in backend.read(**kwargs, chunk_size=2) + ] + assert next(result) == expected + next(result) + with pytest.raises(BackendException, match=msg): + result = [ + statement async for statement in backend.read(**kwargs, chunk_size=1000) + ] + assert next(result) == expected + next(result) + + assert ( + "ralph.backends.data.async_mongo", + logging.ERROR, + msg, + ) in caplog.record_tuples + + +@pytest.mark.parametrize( + "query", + [ + '{"filter": {"id": {"$eq": "bar"}}, "projection": {"id": 1}}', + {"filter": {"id": {"$eq": "bar"}}, "projection": {"id": 1}}, + MongoQuery( + query_string='{"filter": {"id": {"$eq": "bar"}}, "projection": {"id": 1}}' + ), + # Given both `query_string` and other query arguments, only the `query_string` + # should be applied. + MongoQuery( + query_string='{"filter": {"id": {"$eq": "bar"}}, "projection": {"id": 1}}', + filter={"id": {"$eq": "foo"}}, + projection={"id": 0}, + ), + MongoQuery(filter={"id": {"$eq": "bar"}}, projection={"id": 1}), + ], +) +@pytest.mark.anyio +async def test_backends_data_async_mongo_data_backend_read_method_with_query( + query, mongo, async_mongo_backend +): + """Test the `AsyncMongoDataBackend.read` method given a query argument.""" + # pylint: disable=unused-argument + # Create records + backend = async_mongo_backend() + documents = [ + {"_id": ObjectId("64945e53a4ee2699573e0d6f"), "id": "foo", "qux": "foo"}, + {"_id": ObjectId("64945e530468d817b1f756da"), "id": "bar", "qux": "foo"}, + {"_id": ObjectId("64945e530468d817b1f756db"), "id": "bar", "qux": "foo"}, + ] + expected = [ + {"_id": "64945e530468d817b1f756da", "id": "bar"}, + {"_id": "64945e530468d817b1f756db", "id": "bar"}, + ] + await backend.collection.insert_many(documents) + + assert [statement async for statement in backend.read(query=query)] == expected + assert [ + statement async for statement in backend.read(query=query, chunk_size=1) + ] == expected + assert [ + statement async for statement in backend.read(query=query, chunk_size=1000) + ] == expected + + +@pytest.mark.anyio +async def test_backends_data_async_mongo_data_backend_write_method_with_target( + mongo, + async_mongo_backend, +): + """Test the `AsyncMongoDataBackend.write` method, given a valid `target` argument, + should write documents to the target collection. 
+    """
+    # pylint: disable=unused-argument
+    backend = async_mongo_backend()
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    documents = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}]
+    assert await backend.write(documents, target="foo_target_collection") == 2
+
+    # The documents should not be written to the default collection.
+    assert not [statement async for statement in backend.read()]
+
+    result = [
+        statement async for statement in backend.read(target="foo_target_collection")
+    ]
+    assert result[0] == {
+        "_id": "62b9ce922c26b46b68ffc68f",
+        "_source": {"id": "foo", **timestamp},
+    }
+    assert result[1] == {
+        "_id": "62b9ce92fcde2b2edba56bf4",
+        "_source": {"id": "bar", **timestamp},
+    }
+
+
+@pytest.mark.anyio
+async def test_backends_data_async_mongo_data_backend_write_method_without_target(
+    mongo,
+    async_mongo_backend,
+):
+    """Test the `AsyncMongoDataBackend.write` method, given no `target` argument,
+    should write documents to the default collection.
+    """
+    # pylint: disable=unused-argument
+    backend = async_mongo_backend()
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    documents = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}]
+    assert await backend.write(documents) == 2
+    result = [statement async for statement in backend.read()]
+    assert result[0] == {
+        "_id": "62b9ce922c26b46b68ffc68f",
+        "_source": {"id": "foo", **timestamp},
+    }
+    assert result[1] == {
+        "_id": "62b9ce92fcde2b2edba56bf4",
+        "_source": {"id": "bar", **timestamp},
+    }
+
+
+# pylint: disable=line-too-long
+@pytest.mark.anyio
+async def test_backends_data_async_mongo_data_backend_write_method_with_duplicated_key_error(  # noqa: E501
+    mongo, async_mongo_backend
+):
+    """Test the `AsyncMongoDataBackend.write` method, given documents with duplicated
+    ids, should write the documents until it encounters a duplicated id and then raise
+    a `BackendException`.
+    """
+    # pylint: disable=unused-argument
+    backend = async_mongo_backend()
+    # Identical statement IDs produce the same ObjectIds, leading to a
+    # duplicated key write error while trying to bulk import this batch.
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    documents = [
+        {"id": "foo", **timestamp},
+        {"id": "bar", **timestamp},
+        {"id": "bar", **timestamp},
+        {"id": "baz", **timestamp},
+    ]
+
+    # Given `ignore_errors` argument set to `True`, the `write` method should not raise
+    # an exception.
+    assert await backend.write(documents, ignore_errors=True) == 2
+    assert (
+        await backend.write(
+            documents, operation_type=BaseOperationType.CREATE, ignore_errors=True
+        )
+        == 0
+    )
+    assert [statement async for statement in backend.read()] == [
+        {"_id": "62b9ce922c26b46b68ffc68f", "_source": {"id": "foo", **timestamp}},
+        {"_id": "62b9ce92fcde2b2edba56bf4", "_source": {"id": "bar", **timestamp}},
+    ]
+
+    # Given `ignore_errors` argument set to `False`, the `write` method should raise
+    # a `BackendException`.
+    with pytest.raises(BackendException, match="E11000 duplicate key error collection"):
+        await backend.write(documents)
+    with pytest.raises(BackendException, match="E11000 duplicate key error collection"):
+        await backend.write(documents, operation_type=BaseOperationType.CREATE)
+    assert [statement async for statement in backend.read()] == [
+        {"_id": "62b9ce922c26b46b68ffc68f", "_source": {"id": "foo", **timestamp}},
+        {"_id": "62b9ce92fcde2b2edba56bf4", "_source": {"id": "bar", **timestamp}},
+    ]
+
+
+# pylint: disable=line-too-long
+@pytest.mark.anyio
+async def test_backends_data_async_mongo_data_backend_write_method_with_delete_operation(  # noqa: E501
+    mongo, async_mongo_backend
+):
+    """Test the `AsyncMongoDataBackend.write` method, given a `DELETE` `operation_type`,
+    should delete the provided documents from the MongoDB collection.
+    """
+    # pylint: disable=unused-argument
+    backend = async_mongo_backend()
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    documents = [
+        {"id": "foo", **timestamp},
+        {"id": "bar", **timestamp},
+        {"id": "baz", **timestamp},
+    ]
+    assert await backend.write(documents) == 3
+    assert len([statement async for statement in backend.read()]) == 3
+    assert (
+        await backend.write(documents[:2], operation_type=BaseOperationType.DELETE) == 2
+    )
+    assert [statement async for statement in backend.read()] == [
+        {"_id": "62b9ce92baa5a0964d3320fb", "_source": documents[2]}
+    ]
+
+    # Given binary data, the `write` method should have the same behaviour.
+    binary_documents = [json.dumps(documents[2]).encode("utf8")]
+    assert (
+        await backend.write(binary_documents, operation_type=BaseOperationType.DELETE)
+        == 1
+    )
+    assert not [statement async for statement in backend.read()]
+
+
+# pylint: disable=line-too-long
+@pytest.mark.anyio
+async def test_backends_data_async_mongo_data_backend_write_method_with_delete_operation_failure(  # noqa: E501
+    mongo, async_mongo_backend, caplog
+):
+    """Test the `AsyncMongoDataBackend.write` method with the `DELETE` `operation_type`,
+    given an AsyncIOMotorClient failure, should raise a `BackendException`.
+    """
+    # pylint: disable=unused-argument
+    backend = async_mongo_backend()
+    msg = (
+        "Failed to delete document chunk: cannot encode object: <class 'object'>, "
+        "of type: <class 'type'>"
+    )
+    with pytest.raises(BackendException, match=msg):
+        await backend.write([{"id": object}], operation_type=BaseOperationType.DELETE)
+
+    # Given `ignore_errors` argument set to `True`, the `write` method should not raise
+    # an exception.
+    with caplog.at_level(logging.WARNING):
+        assert (
+            await backend.write(
+                [{"id": object}],
+                operation_type=BaseOperationType.DELETE,
+                ignore_errors=True,
+            )
+            == 0
+        )
+
+    assert (
+        "ralph.backends.data.async_mongo",
+        logging.WARNING,
+        msg,
+    ) in caplog.record_tuples
+
+
+# pylint: disable=line-too-long
+@pytest.mark.anyio
+async def test_backends_data_async_mongo_data_backend_write_method_with_update_operation(  # noqa: E501
+    mongo, async_mongo_backend
+):
+    """Test the `AsyncMongoDataBackend.write` method, given an `UPDATE`
+    `operation_type`, should update the provided documents in the MongoDB collection.
+ """ + # pylint: disable=unused-argument + backend = async_mongo_backend() + timestamp = {"timestamp": "2022-06-27T15:36:50"} + documents = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}] + + assert await backend.write(documents) == 2 + new_timestamp = {"timestamp": "2022-06-27T16:36:50"} + documents = [{"id": "foo", **new_timestamp}, {"id": "bar", **new_timestamp}] + assert await backend.write(documents, operation_type=BaseOperationType.UPDATE) == 2 + + results = [statement async for statement in backend.read()] + assert results[0] == { + "_id": "62b9ce922c26b46b68ffc68f", + "_source": {"id": "foo", **new_timestamp}, + } + assert results[1] == { + "_id": "62b9ce92fcde2b2edba56bf4", + "_source": {"id": "bar", **new_timestamp}, + } + + # Given binary data, the `write` method should have the same behaviour. + binary_documents = [json.dumps({"id": "foo", "new_field": "bar"}).encode("utf8")] + assert ( + await backend.write(binary_documents, operation_type=BaseOperationType.UPDATE) + == 1 + ) + results = [statement async for statement in backend.read()] + assert results[0] == { + "_id": "62b9ce922c26b46b68ffc68f", + "_source": {"id": "foo", "new_field": "bar"}, + } + + +# pylint: disable=line-too-long +@pytest.mark.anyio +async def test_backends_data_async_mongo_data_backend_write_method_with_update_operation_failure( # noqa: E501 + mongo, async_mongo_backend +): + """Test the `AsyncMongoDataBackend.write` method with the `UPDATE` `operation_type`, + given an AsyncIOMotorClient failure, should raise a `BackendException`. + """ + # pylint: disable=unused-argument + backend = async_mongo_backend() + schema = { + "$jsonSchema": { + "bsonType": "object", + "required": ["_source"], + "properties": { + "_source": { + "bsonType": "object", + "required": ["timestamp"], + "description": "must be an object", + "properties": { + "timestamp": { + "bsonType": "string", + "description": "must be a string and is required", + } + }, + } + }, + } + } + await backend.database.command( + "collMod", backend.collection.name, validator=schema, validationLevel="moderate" + ) + timestamp = {"timestamp": "2022-06-27T15:36:50"} + documents = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}] + assert await backend.write(documents) == 2 + documents = [{"id": "foo", "new": "field", **timestamp}, {"id": "bar"}] + + # Given `ignore_errors` argument set to `True`, the `write` method should not raise + # an exception. + assert ( + await backend.write( + documents, operation_type=BaseOperationType.UPDATE, ignore_errors=True + ) + == 1 + ) + assert [statement async for statement in backend.read()][0]["_source"][ + "new" + ] == "field" + + msg = "Failed to update document chunk: batch op errors occurred" + with pytest.raises(BackendException, match=msg): + await backend.write( + documents, + operation_type=BaseOperationType.UPDATE, + ignore_errors=False, + ) + + +# pylint: disable=line-too-long +@pytest.mark.anyio +async def test_backends_data_async_mongo_data_backend_write_method_with_append_operation( # noqa: E501 + async_mongo_backend, caplog +): + """Test the `AsyncMongoDataBackend.write` method, given an `APPEND` + `operation_type`, should raise a `BackendParameterException`. + """ + backend = async_mongo_backend() + msg = "Append operation_type is not allowed." 
+    with pytest.raises(BackendParameterException, match=msg):
+        with caplog.at_level(logging.ERROR):
+            await backend.write(data=[], operation_type=BaseOperationType.APPEND)
+
+    assert (
+        "ralph.backends.data.async_mongo",
+        logging.ERROR,
+        msg,
+    ) in caplog.record_tuples
+
+
+@pytest.mark.anyio
+async def test_backends_data_async_mongo_data_backend_write_method_with_create_operation(  # noqa: E501
+    mongo, async_mongo_backend
+):
+    """Test the `AsyncMongoDataBackend.write` method, given a `CREATE`
+    `operation_type`, should insert the provided documents into the MongoDB collection.
+    """
+    # pylint: disable=unused-argument
+    backend = async_mongo_backend()
+    documents = [
+        {"timestamp": "2022-06-27T15:36:50"},
+        {"timestamp": "2023-06-27T15:36:50"},
+    ]
+    assert await backend.write(documents, operation_type=BaseOperationType.CREATE) == 2
+    results = [statement async for statement in backend.read()]
+    assert results[0]["_source"]["timestamp"] == documents[0]["timestamp"]
+    assert results[1]["_source"]["timestamp"] == documents[1]["timestamp"]
+
+
+# pylint: disable=line-too-long
+@pytest.mark.parametrize(
+    "document,error",
+    [
+        ({}, "statement {} has no 'id' field"),
+        ({"id": "1"}, "statement {'id': '1'} has no 'timestamp' field"),
+        (
+            {"id": "1", "timestamp": ""},
+            "statement {'id': '1', 'timestamp': ''} has an invalid 'timestamp' field",
+        ),
+    ],
+)
+@pytest.mark.anyio
+async def test_backends_data_async_mongo_data_backend_write_method_with_invalid_documents(  # noqa: E501
+    document, error, mongo, async_mongo_backend, caplog
+):
+    """Test the `AsyncMongoDataBackend.write` method, given invalid documents, should
+    raise a `BackendException`.
+    """
+    # pylint: disable=unused-argument
+    backend = async_mongo_backend()
+    with pytest.raises(BackendException, match=error):
+        await backend.write([document])
+
+    # Given binary data, the `write` method should have the same behaviour.
+    with pytest.raises(BackendException, match=error):
+        await backend.write([json.dumps(document).encode("utf8")])
+
+    # Given `ignore_errors` argument set to `True`, the `write` method should not raise
+    # an exception.
+    with caplog.at_level(logging.WARNING):
+        assert await backend.write([document], ignore_errors=True) == 0
+
+    assert (
+        "ralph.backends.data.async_mongo",
+        logging.WARNING,
+        error,
+    ) in caplog.record_tuples
+
+
+# pylint: disable=line-too-long
+@pytest.mark.anyio
+async def test_backends_data_async_mongo_data_backend_write_method_with_unparsable_documents(  # noqa: E501
+    async_mongo_backend, caplog
+):
+    """Test the `AsyncMongoDataBackend.write` method, given unparsable raw documents,
+    should raise a `BackendException`.
+    """
+    backend = async_mongo_backend()
+    msg = (
+        "Failed to decode JSON: Expecting value: line 1 column 1 (char 0), "
+        "for document: b'not valid JSON!'"
+    )
+    msg_regex = msg.replace("(", r"\(").replace(")", r"\)")
+    with pytest.raises(BackendException, match=msg_regex):
+        await backend.write([b"not valid JSON!"])
+
+    # Given `ignore_errors` argument set to `True`, the `write` method should not raise
+    # an exception.
+    with caplog.at_level(logging.WARNING):
+        assert await backend.write([b"not valid JSON!"], ignore_errors=True) == 0
+
+    assert (
+        "ralph.backends.data.async_mongo",
+        logging.WARNING,
+        msg,
+    ) in caplog.record_tuples
+
+
+@pytest.mark.anyio
+async def test_backends_data_async_mongo_data_backend_write_method_with_no_data(
+    async_mongo_backend, caplog
+):
+    """Test the `AsyncMongoDataBackend.write` method, given no documents, should return
+    0.
+    """
+    backend = async_mongo_backend()
+    with caplog.at_level(logging.WARNING):
+        assert await backend.write(data=[]) == 0
+
+    msg = "Data Iterator is empty; skipping write to target."
+    assert (
+        "ralph.backends.data.async_mongo",
+        logging.WARNING,
+        msg,
+    ) in caplog.record_tuples
+
+
+# pylint: disable=line-too-long
+@pytest.mark.anyio
+async def test_backends_data_async_mongo_data_backend_write_method_with_custom_chunk_size(  # noqa: E501
+    mongo, async_mongo_backend, caplog
+):
+    """Test the `AsyncMongoDataBackend.write` method, given a custom chunk_size, should
+    insert the provided documents into the target collection in batches of size
+    `chunk_size`.
+    """
+    # pylint: disable=unused-argument
+    backend = async_mongo_backend()
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    new_timestamp = {"timestamp": "2023-06-27T15:36:50"}
+    documents = [
+        {"id": "foo", **timestamp},
+        {"id": "bar", **timestamp},
+        {"id": "baz", **timestamp},
+    ]
+    new_documents = [
+        {"id": "foo", **new_timestamp},
+        {"id": "bar", **new_timestamp},
+        {"id": "baz", **new_timestamp},
+    ]
+    # Index operation type.
+    with caplog.at_level(logging.DEBUG):
+        assert await backend.write(documents, chunk_size=2) == 3
+
+    assert (
+        "ralph.backends.data.async_mongo",
+        logging.INFO,
+        f"Inserted {len(documents)} documents with success",
+    ) in caplog.record_tuples
+
+    assert [statement async for statement in backend.read()] == [
+        {"_id": "62b9ce922c26b46b68ffc68f", "_source": {"id": "foo", **timestamp}},
+        {"_id": "62b9ce92fcde2b2edba56bf4", "_source": {"id": "bar", **timestamp}},
+        {"_id": "62b9ce92baa5a0964d3320fb", "_source": {"id": "baz", **timestamp}},
+    ]
+    # Delete operation type.
+    assert (
+        await backend.write(
+            documents, chunk_size=1, operation_type=BaseOperationType.DELETE
+        )
+        == 3
+    )
+    assert not [statement async for statement in backend.read()]
+    # Create operation type.
+    assert (
+        await backend.write(
+            documents, chunk_size=1, operation_type=BaseOperationType.CREATE
+        )
+        == 3
+    )
+    assert [statement async for statement in backend.read()] == [
+        {"_id": "62b9ce922c26b46b68ffc68f", "_source": {"id": "foo", **timestamp}},
+        {"_id": "62b9ce92fcde2b2edba56bf4", "_source": {"id": "bar", **timestamp}},
+        {"_id": "62b9ce92baa5a0964d3320fb", "_source": {"id": "baz", **timestamp}},
+    ]
+    # Update operation type.
+    assert (
+        await backend.write(
+            new_documents, chunk_size=3, operation_type=BaseOperationType.UPDATE
+        )
+        == 3
+    )
+    assert [statement async for statement in backend.read()] == [
+        {"_id": "62b9ce922c26b46b68ffc68f", "_source": {"id": "foo", **new_timestamp}},
+        {"_id": "62b9ce92fcde2b2edba56bf4", "_source": {"id": "bar", **new_timestamp}},
+        {"_id": "62b9ce92baa5a0964d3320fb", "_source": {"id": "baz", **new_timestamp}},
+    ]
+
+
+@pytest.mark.anyio
+async def test_backends_data_async_mongo_data_backend_close(
+    async_mongo_backend, monkeypatch, caplog
+):
+    """Test the `AsyncMongoDataBackend.close` method, given a failed close,
+    should raise a BackendException.
+ """ + + class MockAsyncIOMotorClient: + """Mock the `motor.motor_asyncio.AsyncIOMotorClient`.""" + + @staticmethod + def close(): + """Mock the `close` method always raising a `PyMongoError`.""" + raise PyMongoError("Close failure") + + backend = async_mongo_backend() + monkeypatch.setattr(backend, "client", MockAsyncIOMotorClient) + + msg = "Failed to close AsyncIOMotorClient: Close failure" + with pytest.raises(BackendException, match=msg): + with caplog.at_level(logging.ERROR): + await backend.close() + + assert ( + "ralph.backends.data.async_mongo", + logging.ERROR, + "Failed to close AsyncIOMotorClient: Close failure", + ) in caplog.record_tuples diff --git a/tests/backends/lrs/__init__.py b/tests/backends/lrs/__init__.py index e69de29bb..6e031999e 100644 --- a/tests/backends/lrs/__init__.py +++ b/tests/backends/lrs/__init__.py @@ -0,0 +1 @@ +# noqa: D104 diff --git a/tests/backends/lrs/test_async_mongo.py b/tests/backends/lrs/test_async_mongo.py new file mode 100644 index 000000000..75ee8cd56 --- /dev/null +++ b/tests/backends/lrs/test_async_mongo.py @@ -0,0 +1,392 @@ +"""Tests for Ralph MongoDB LRS backend.""" + +import logging + +import pytest +from bson.objectid import ObjectId +from pymongo import ASCENDING, DESCENDING + +from ralph.backends.lrs.base import StatementParameters +from ralph.exceptions import BackendException + +from tests.fixtures.backends import MONGO_TEST_FORWARDING_COLLECTION + + +@pytest.mark.parametrize( + "params,expected_query", + [ + # 0. Default query. + ( + {}, + { + "filter": {}, + "limit": None, + "projection": None, + "sort": [ + ("_source.timestamp", DESCENDING), + ("_id", DESCENDING), + ], + "query_string": None, + }, + ), + # 1. Query by statementId. + ( + {"statementId": "statementId"}, + { + "filter": {"_source.id": "statementId"}, + "limit": None, + "projection": None, + "sort": [ + ("_source.timestamp", DESCENDING), + ("_id", DESCENDING), + ], + "query_string": None, + }, + ), + # 2. Query by statementId and agent with mbox IFI. + ( + {"statementId": "statementId", "agent": {"mbox": "mailto:foo@bar.baz"}}, + { + "filter": { + "_source.id": "statementId", + "_source.actor.mbox": "mailto:foo@bar.baz", + }, + "limit": None, + "projection": None, + "sort": [ + ("_source.timestamp", DESCENDING), + ("_id", DESCENDING), + ], + "query_string": None, + }, + ), + # 3. Query by statementId and agent with mbox_sha1sum IFI. + ( + { + "statementId": "statementId", + "agent": {"mbox_sha1sum": "a7a5b7462b862c8c8767d43d43e865ffff754a64"}, + }, + { + "filter": { + "_source.id": "statementId", + "_source.actor.mbox_sha1sum": ( + "a7a5b7462b862c8c8767d43d43e865ffff754a64" + ), + }, + "limit": None, + "projection": None, + "sort": [ + ("_source.timestamp", DESCENDING), + ("_id", DESCENDING), + ], + "query_string": None, + }, + ), + # 4. Query by statementId and agent with openid IFI. + ( + { + "statementId": "statementId", + "agent": {"openid": "http://toby.openid.example.org/"}, + }, + { + "filter": { + "_source.id": "statementId", + "_source.actor.openid": "http://toby.openid.example.org/", + }, + "limit": None, + "projection": None, + "sort": [ + ("_source.timestamp", DESCENDING), + ("_id", DESCENDING), + ], + "query_string": None, + }, + ), + # 5. Query by statementId and agent with account IFI. 
+ ( + { + "statementId": "statementId", + "agent": { + "account__name": "13936749", + "account__home_page": "http://www.example.com", + }, + }, + { + "filter": { + "_source.id": "statementId", + "_source.actor.account.name": "13936749", + "_source.actor.account.homePage": "http://www.example.com", + }, + "limit": None, + "projection": None, + "sort": [ + ("_source.timestamp", DESCENDING), + ("_id", DESCENDING), + ], + "query_string": None, + }, + ), + # 6. Query by verb and activity. + ( + { + "verb": "http://adlnet.gov/expapi/verbs/attended", + "activity": "http://www.example.com/meetings/34534", + }, + { + "filter": { + "_source.verb.id": "http://adlnet.gov/expapi/verbs/attended", + "_source.object.id": "http://www.example.com/meetings/34534", + "_source.object.objectType": "Activity", + }, + "limit": None, + "projection": None, + "sort": [ + ("_source.timestamp", DESCENDING), + ("_id", DESCENDING), + ], + "query_string": None, + }, + ), + # 7. Query by timerange (with since/until). + ( + { + "since": "2021-06-24T00:00:20.194929+00:00", + "until": "2023-06-24T00:00:20.194929+00:00", + }, + { + "filter": { + "_source.timestamp": { + "$gt": "2021-06-24T00:00:20.194929+00:00", + "$lte": "2023-06-24T00:00:20.194929+00:00", + }, + }, + "limit": None, + "projection": None, + "sort": [ + ("_source.timestamp", DESCENDING), + ("_id", DESCENDING), + ], + "query_string": None, + }, + ), + # 8. Query by timerange (with only until). + ( + { + "until": "2023-06-24T00:00:20.194929+00:00", + }, + { + "filter": { + "_source.timestamp": { + "$lte": "2023-06-24T00:00:20.194929+00:00", + }, + }, + "limit": None, + "projection": None, + "sort": [ + ("_source.timestamp", DESCENDING), + ("_id", DESCENDING), + ], + "query_string": None, + }, + ), + # 9. Query with pagination. + ( + {"search_after": "666f6f2d6261722d71757578", "pit_id": None}, + { + "filter": { + "_id": {"$lt": ObjectId("666f6f2d6261722d71757578")}, + }, + "limit": None, + "projection": None, + "sort": [ + ("_source.timestamp", DESCENDING), + ("_id", DESCENDING), + ], + "query_string": None, + }, + ), + # 10. Query with pagination in ascending order. + ( + {"search_after": "666f6f2d6261722d71757578", "ascending": True}, + { + "filter": { + "_id": {"$gt": ObjectId("666f6f2d6261722d71757578")}, + }, + "limit": None, + "projection": None, + "sort": [ + ("_source.timestamp", ASCENDING), + ("_id", ASCENDING), + ], + "query_string": None, + }, + ), + ], +) +@pytest.mark.anyio +async def test_backends_lrs_async_mongo_lrs_backend_query_statements_query( + params, expected_query, async_mongo_lrs_backend, monkeypatch +): + """Test the `AsyncMongoLRSBackend.query_statements` method, given valid statement + parameters, should produce the expected MongoDB query. 
+    """
+
+    async def mock_read(query, chunk_size):
+        """Mock the `AsyncMongoLRSBackend.read` method."""
+        assert query.dict() == expected_query
+        assert chunk_size == expected_query.get("limit")
+        yield {"_id": "search_after_id", "_source": {}}
+
+    backend = async_mongo_lrs_backend()
+    monkeypatch.setattr(backend, "read", mock_read)
+    result = await backend.query_statements(StatementParameters(**params))
+    assert result.statements == [{}]
+    assert not result.pit_id
+    assert result.search_after == "search_after_id"
+
+    await backend.close()
+
+
+@pytest.mark.anyio
+async def test_backends_lrs_async_mongo_lrs_backend_query_statements_with_success(
+    mongo, async_mongo_lrs_backend
+):
+    """Test the `AsyncMongoLRSBackend.query_statements` method, given a valid search
+    query, should return the expected statements.
+    """
+    # pylint: disable=unused-argument
+    backend = async_mongo_lrs_backend()
+
+    # Insert documents
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    meta = {
+        "actor": {"account": {"name": "test_name", "homePage": "http://example.com"}},
+        "verb": {"id": "verb_id"},
+        "object": {"id": "http://example.com", "objectType": "Activity"},
+    }
+    documents = [
+        {"id": "62b9ce922c26b46b68ffc68f", **timestamp, **meta},
+        {"id": "62b9ce92fcde2b2edba56bf4", **timestamp, **meta},
+    ]
+    assert await backend.write(documents) == 2
+
+    statement_parameters = StatementParameters(
+        statementId="62b9ce922c26b46b68ffc68f",
+        agent={
+            "account__name": "test_name",
+            "account__home_page": "http://example.com",
+        },
+        verb="verb_id",
+        activity="http://example.com",
+        since="2020-01-01T00:00:00.000000+00:00",
+        until="2022-12-01T15:36:50",
+        search_after="62b9ce922c26b46b68ffc68f",
+        ascending=True,
+        limit=25,
+    )
+    statement_query_result = await backend.query_statements(statement_parameters)
+
+    assert statement_query_result.statements == [
+        {"id": "62b9ce922c26b46b68ffc68f", **timestamp, **meta}
+    ]
+
+
+@pytest.mark.anyio
+async def test_backends_lrs_async_mongo_lrs_backend_query_statements_with_query_failure(
+    async_mongo_lrs_backend, monkeypatch, caplog
+):
+    """Test the `AsyncMongoLRSBackend.query_statements` method, given a search query
+    failure, should raise a BackendException and log the error.
+    """
+    # pylint: disable=unused-argument
+
+    msg = "Failed to execute MongoDB query: Something is wrong"
+
+    async def mock_read(**_):
+        """Mock the `AsyncMongoLRSBackend.read` method always raising an exception."""
+        yield {"_source": {}}
+        raise BackendException(msg)
+
+    backend = async_mongo_lrs_backend()
+    monkeypatch.setattr(backend, "read", mock_read)
+
+    with caplog.at_level(logging.ERROR):
+        with pytest.raises(BackendException, match=msg):
+            await backend.query_statements(StatementParameters())
+
+    assert (
+        "ralph.backends.lrs.async_mongo",
+        logging.ERROR,
+        "Failed to read from async MongoDB",
+    ) in caplog.record_tuples
+
+
+@pytest.mark.anyio
+async def test_backends_lrs_async_mongo_lrs_backend_query_statements_by_ids_query_failure(  # noqa: E501
+    async_mongo_lrs_backend, monkeypatch, caplog
+):
+    """Test the `AsyncMongoLRSBackend.query_statements_by_ids` method, given a search
+    query failure, should raise a BackendException and log the error.
+    """
+    # pylint: disable=unused-argument
+
+    msg = "Failed to execute MongoDB query: Something is wrong"
+
+    async def mock_read(**_):
+        """Mock the `AsyncMongoLRSBackend.read` method always raising an exception."""
+        yield {"_source": {}}
+        raise BackendException(msg)
+
+    backend = async_mongo_lrs_backend()
+    monkeypatch.setattr(backend, "read", mock_read)
+
+    with caplog.at_level(logging.ERROR):
+        with pytest.raises(BackendException, match=msg):
+            _ = [
+                statement
+                async for statement in backend.query_statements_by_ids(["abc"])
+            ]
+
+    assert (
+        "ralph.backends.lrs.async_mongo",
+        logging.ERROR,
+        "Failed to read from MongoDB",
+    ) in caplog.record_tuples
+
+
+@pytest.mark.anyio
+async def test_backends_lrs_async_mongo_lrs_backend_query_statements_by_ids_two_collections(  # noqa: E501
+    mongo, mongo_forwarding, async_mongo_lrs_backend
+):
+    """Test the `AsyncMongoLRSBackend.query_statements_by_ids` method, given a valid
+    search query, should execute the query only on the specified collection and return
+    the expected results.
+    """
+    # pylint: disable=unused-argument
+
+    # Instantiate two LRS backends using different default collections
+    backend_1 = async_mongo_lrs_backend()
+    backend_2 = async_mongo_lrs_backend(
+        default_collection=MONGO_TEST_FORWARDING_COLLECTION
+    )
+
+    # Insert documents
+    timestamp = {"timestamp": "2022-06-27T15:36:50"}
+    assert await backend_1.write([{"id": "1", **timestamp}]) == 1
+    assert await backend_2.write([{"id": "2", **timestamp}]) == 1
+
+    # Check the expected search query results
+    assert [
+        statement async for statement in backend_1.query_statements_by_ids(["1"])
+    ] == [{"id": "1", **timestamp}]
+    assert not [
+        statement async for statement in backend_1.query_statements_by_ids(["2"])
+    ]
+    assert not [
+        statement async for statement in backend_2.query_statements_by_ids(["1"])
+    ]
+    assert [
+        statement async for statement in backend_2.query_statements_by_ids(["2"])
+    ] == [{"id": "2", **timestamp}]
diff --git a/tests/conftest.py b/tests/conftest.py
index 35243e071..77a5a6d25 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -16,6 +16,8 @@
     anyio_backend,
     async_es_backend,
     async_es_lrs_backend,
+    async_mongo_backend,
+    async_mongo_lrs_backend,
     clickhouse,
     clickhouse_backend,
     clickhouse_lrs_backend,
diff --git a/tests/fixtures/backends.py b/tests/fixtures/backends.py
index 910d026ae..d99a897f0 100644
--- a/tests/fixtures/backends.py
+++ b/tests/fixtures/backends.py
@@ -22,6 +22,7 @@
 from pymongo.errors import CollectionInvalid
 
 from ralph.backends.data.async_es import AsyncESDataBackend
+from ralph.backends.data.async_mongo import AsyncMongoDataBackend
 from ralph.backends.data.clickhouse import ClickHouseDataBackend
 from ralph.backends.data.es import ESDataBackend
 from ralph.backends.data.fs import FSDataBackend, FSDataBackendSettings
@@ -33,6 +34,7 @@
 from ralph.backends.database.es import ESDatabase
 from ralph.backends.database.mongo import MongoDatabase
 from ralph.backends.lrs.async_es import AsyncESLRSBackend
+from ralph.backends.lrs.async_mongo import AsyncMongoLRSBackend
 from ralph.backends.lrs.clickhouse import ClickHouseLRSBackend
 from ralph.backends.lrs.es import ESLRSBackend
 from ralph.backends.lrs.fs import FSLRSBackend
@@ -187,6 +189,58 @@ def get_fs_lrs_backend(path: str = "foo"):
     return get_fs_lrs_backend
 
 
+@pytest.fixture
+def anyio_backend():
+    """Select asyncio backend for pytest anyio."""
+    return "asyncio"
+
+
+@pytest.fixture
+def async_mongo_backend():
+    """Return the `get_async_mongo_data_backend` function."""
+
+    def get_async_mongo_data_backend(
+        connection_uri: str = MONGO_TEST_CONNECTION_URI,
+        default_collection: str = MONGO_TEST_COLLECTION,
+        client_options: dict = None,
+    ):
+        """Return an instance of `AsyncMongoDataBackend`."""
+        settings = AsyncMongoDataBackend.settings_class(
+            CONNECTION_URI=connection_uri,
+            DEFAULT_DATABASE=MONGO_TEST_DATABASE,
+            DEFAULT_COLLECTION=default_collection,
+            CLIENT_OPTIONS=client_options if client_options else {},
+            DEFAULT_CHUNK_SIZE=500,
+            LOCALE_ENCODING="utf8",
+        )
+        return AsyncMongoDataBackend(settings)
+
+    return get_async_mongo_data_backend
+
+
+@pytest.fixture
+def async_mongo_lrs_backend():
+    """Return the `async_get_mongo_lrs_backend` function."""
+
+    def async_get_mongo_lrs_backend(
+        connection_uri: str = MONGO_TEST_CONNECTION_URI,
+        default_collection: str = MONGO_TEST_COLLECTION,
+        client_options: dict = None,
+    ):
+        """Return an instance of `AsyncMongoLRSBackend`."""
+        settings = AsyncMongoLRSBackend.settings_class(
+            CONNECTION_URI=connection_uri,
+            DEFAULT_DATABASE=MONGO_TEST_DATABASE,
+            DEFAULT_COLLECTION=default_collection,
+            CLIENT_OPTIONS=client_options if client_options else {},
+            DEFAULT_CHUNK_SIZE=500,
+            LOCALE_ENCODING="utf8",
+        )
+        return AsyncMongoLRSBackend(settings)
+
+    return async_get_mongo_lrs_backend
+
+
 def get_mongo_fixture(
     connection_uri=MONGO_TEST_CONNECTION_URI,
     database=MONGO_TEST_DATABASE,
@@ -701,9 +755,3 @@ async def runserver(app, host=RUNSERVER_TEST_HOST, port=RUNSERVER_TEST_PORT):
             process.terminate()
 
     return runserver
-
-
-@pytest.fixture
-def anyio_backend():
-    """Select asyncio backend for pytest anyio."""
-    return "asyncio"

From 4bfe594a6ee26be5f6abef8bdb39aa6bdbeeb139 Mon Sep 17 00:00:00 2001
From: Wilfried BARADAT
Date: Fri, 18 Aug 2023 17:21:36 +0200
Subject: [PATCH 20/65] =?UTF-8?q?=F0=9F=8E=A8(backends)=20improve=20severa?=
 =?UTF-8?q?l=20points=20for=20all=20unified=20backends?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- The `get_query` function for Elasticsearch is better namespaced under
the ESLRSBackend. Changing it to a static method instead of a global
function.
- At initialization, data backends can either take settings or None.
Typing the `settings` argument as optional to anticipate mypy warnings
once mypy is added.
- The `x | None` pipe syntax is preferred since Python 3.10. Changing
from `Optional[x]` to `Union[x, None]` for backends, as it makes the
later switch to pipes easier.
- Changes to backend methods docstrings - Rename variable `new_documents` to be more explicit --- src/ralph/backends/data/async_es.py | 6 +- src/ralph/backends/data/base.py | 40 ++++---- src/ralph/backends/data/clickhouse.py | 31 +++--- src/ralph/backends/data/es.py | 28 +++--- src/ralph/backends/data/fs.py | 18 ++-- src/ralph/backends/data/ldp.py | 16 ++-- src/ralph/backends/data/mongo.py | 38 ++++---- src/ralph/backends/data/s3.py | 16 ++-- src/ralph/backends/data/swift.py | 14 +-- src/ralph/backends/lrs/async_es.py | 4 +- src/ralph/backends/lrs/es.py | 130 +++++++++++++------------- tests/backends/data/test_async_es.py | 15 ++- 12 files changed, 185 insertions(+), 171 deletions(-) diff --git a/src/ralph/backends/data/async_es.py b/src/ralph/backends/data/async_es.py index 7f987b197..b53b0a256 100644 --- a/src/ralph/backends/data/async_es.py +++ b/src/ralph/backends/data/async_es.py @@ -30,7 +30,7 @@ class AsyncESDataBackend(BaseAsyncDataBackend): query_model = ESQuery settings_class = ESDataBackendSettings - def __init__(self, settings: settings_class = None): + def __init__(self, settings: Union[settings_class, None] = None): """Instantiate the asynchronous Elasticsearch client. Args: @@ -127,7 +127,7 @@ async def read( DSL. The Lucene query overrides the query DSL if present. See ESQuery. target (str or None): The target Elasticsearch index name to query. If target is `None`, the `DEFAULT_INDEX` is used instead. - chunk_size (int or None): The chunk size for reading batches of documents. + chunk_size (int or None): The chunk size when reading documents by batches. If chunk_size is `None` it defaults to `DEFAULT_CHUNK_SIZE`. raw_output (bool): Controls whether to yield dictionaries or bytes. ignore_errors (bool): Ignored. @@ -210,7 +210,7 @@ async def write( # pylint: disable=too-many-arguments instead. See `BaseOperationType`. Return: - int: The number of written documents. + int: The number of documents written. Raise: BackendException: If a failure occurs while writing to Elasticsearch or diff --git a/src/ralph/backends/data/base.py b/src/ralph/backends/data/base.py index 84524901e..4d6ac2b0a 100644 --- a/src/ralph/backends/data/base.py +++ b/src/ralph/backends/data/base.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from enum import Enum, unique from io import IOBase -from typing import Iterable, Iterator, Optional, Union +from typing import Iterable, Iterator, Union from pydantic import BaseModel, BaseSettings, ValidationError @@ -34,7 +34,7 @@ class Config: extra = "forbid" - query_string: Optional[str] + query_string: Union[str, None] @unique @@ -88,7 +88,7 @@ class BaseDataBackend(ABC): settings_class = BaseDataBackendSettings @abstractmethod - def __init__(self, settings: settings_class = None): + def __init__(self, settings: Union[settings_class, None] = None): """Instantiate the data backend. Args: @@ -96,7 +96,9 @@ def __init__(self, settings: settings_class = None): If `settings` is `None`, a default settings instance is used instead. 
""" - def validate_query(self, query: Union[str, dict, BaseQuery] = None) -> BaseQuery: + def validate_query( + self, query: Union[str, dict, BaseQuery, None] = None + ) -> BaseQuery: """Validate and transform the query.""" if query is None: query = self.query_model() @@ -134,7 +136,7 @@ def status(self) -> DataBackendStatus: @abstractmethod def list( - self, target: str = None, details: bool = False, new: bool = False + self, target: Union[str, None] = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: """List containers in the data backend. E.g., collections, files, indexes. @@ -159,8 +161,8 @@ def read( self, *, query: Union[str, BaseQuery] = None, - target: str = None, - chunk_size: Union[None, int] = None, + target: Union[str, None] = None, + chunk_size: Union[int, None] = None, raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: @@ -195,10 +197,10 @@ def read( def write( # pylint: disable=too-many-arguments self, data: Union[IOBase, Iterable[bytes], Iterable[dict]], - target: Union[None, str] = None, - chunk_size: Union[None, int] = None, + target: Union[str, None] = None, + chunk_size: Union[int, None] = None, ignore_errors: bool = False, - operation_type: Union[None, BaseOperationType] = None, + operation_type: Union[BaseOperationType, None] = None, ) -> int: """Write `data` records to the `target` container and return their count. @@ -235,7 +237,7 @@ class BaseAsyncDataBackend(ABC): settings_class = BaseDataBackendSettings @abstractmethod - def __init__(self, settings: settings_class = None): + def __init__(self, settings: Union[settings_class, None] = None): """Instantiate the data backend. Args: @@ -243,7 +245,9 @@ def __init__(self, settings: settings_class = None): If `settings` is `None`, a default settings instance is used instead. """ - def validate_query(self, query: Union[str, dict, BaseQuery] = None) -> BaseQuery: + def validate_query( + self, query: Union[str, dict, BaseQuery, None] = None + ) -> BaseQuery: """Validate and transform the query.""" if query is None: query = self.query_model() @@ -281,7 +285,7 @@ async def status(self) -> DataBackendStatus: @abstractmethod async def list( - self, target: str = None, details: bool = False, new: bool = False + self, target: Union[str, None] = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: """List containers in the data backend. E.g., collections, files, indexes. @@ -306,8 +310,8 @@ async def read( self, *, query: Union[str, BaseQuery] = None, - target: str = None, - chunk_size: Union[None, int] = None, + target: Union[str, None] = None, + chunk_size: Union[int, None] = None, raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: @@ -342,10 +346,10 @@ async def read( async def write( # pylint: disable=too-many-arguments self, data: Union[IOBase, Iterable[bytes], Iterable[dict]], - target: Union[None, str] = None, - chunk_size: Union[None, int] = None, + target: Union[str, None] = None, + chunk_size: Union[int, None] = None, ignore_errors: bool = False, - operation_type: Union[None, BaseOperationType] = None, + operation_type: Union[BaseOperationType, None] = None, ) -> int: """Write `data` records to the `target` container and return their count. 
diff --git a/src/ralph/backends/data/clickhouse.py b/src/ralph/backends/data/clickhouse.py index 1010d5756..b417b18cf 100755 --- a/src/ralph/backends/data/clickhouse.py +++ b/src/ralph/backends/data/clickhouse.py @@ -14,7 +14,6 @@ List, Literal, NamedTuple, - Optional, Union, ) from uuid import UUID, uuid4 @@ -96,18 +95,18 @@ class BaseClickHouseQuery(BaseQuery): """Base ClickHouse query model.""" select: Union[str, List[str]] = "event" - where: Optional[Union[str, List[str]]] - parameters: Optional[Dict] - limit: Optional[int] - sort: Optional[str] - column_oriented: Optional[bool] = False + where: Union[str, List[str], None] + parameters: Union[Dict, None] + limit: Union[int, None] + sort: Union[str, None] + column_oriented: Union[bool, None] = False class ClickHouseQuery(BaseClickHouseQuery): """ClickHouse query model.""" # pylint: disable=unsubscriptable-object - query_string: Optional[Json[BaseClickHouseQuery]] + query_string: Union[Json[BaseClickHouseQuery], None] class ClickHouseDataBackend(BaseDataBackend): @@ -118,7 +117,7 @@ class ClickHouseDataBackend(BaseDataBackend): default_operation_type = BaseOperationType.CREATE settings_class = ClickHouseDataBackendSettings - def __init__(self, settings: settings_class = None): + def __init__(self, settings: Union[settings_class, None] = None): """Instantiate the ClickHouse configuration. Args: @@ -167,7 +166,7 @@ def status(self) -> DataBackendStatus: return DataBackendStatus.OK def list( - self, target: str = None, details: bool = False, new: bool = False + self, target: Union[str, None] = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: """List tables for a given database. @@ -203,8 +202,8 @@ def read( self, *, query: Union[str, ClickHouseQuery] = None, - target: str = None, - chunk_size: Union[None, int] = None, + target: Union[str, None] = None, + chunk_size: Union[int, None] = None, raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: @@ -214,7 +213,7 @@ def read( query (str or ClickHouseQuery): The query to use when fetching documents. target (str or None): The target table name to query. If target is `None`, the `event_table_name` is used instead. - chunk_size (int or None): The chunk size for reading batches of documents. + chunk_size (int or None): The chunk size when reading documents by batches. If chunk_size is `None` it defaults to `default_chunk_size`. raw_output (bool): Controls whether to yield dictionaries or bytes. ignore_errors (bool): If `True`, errors during the encoding operation @@ -295,10 +294,10 @@ def read( def write( # pylint: disable=too-many-arguments self, data: Union[IOBase, Iterable[bytes], Iterable[dict]], - target: Union[None, str] = None, - chunk_size: Union[None, int] = None, + target: Union[str, None] = None, + chunk_size: Union[int, None] = None, ignore_errors: bool = False, - operation_type: Union[None, BaseOperationType] = None, + operation_type: Union[BaseOperationType, None] = None, ) -> int: """Write `data` documents to the `target` table and return their count. @@ -316,7 +315,7 @@ def write( # pylint: disable=too-many-arguments instead. See `BaseOperationType`. Return: - int: The number of written documents. + int: The number of documents written. 
Raise: BackendException: If a failure occurs while writing to ClickHouse or diff --git a/src/ralph/backends/data/es.py b/src/ralph/backends/data/es.py index f131466f2..d9b1174d4 100644 --- a/src/ralph/backends/data/es.py +++ b/src/ralph/backends/data/es.py @@ -4,7 +4,7 @@ from io import IOBase from itertools import chain from pathlib import Path -from typing import Iterable, Iterator, List, Literal, Optional, Union +from typing import Iterable, Iterator, List, Literal, Union from elasticsearch import ApiError, Elasticsearch, TransportError from elasticsearch.helpers import BulkIndexError, streaming_bulk @@ -76,8 +76,8 @@ class ESQueryPit(BaseModel): time alive. """ - id: Optional[str] - keep_alive: Optional[str] + id: Union[str, None] + keep_alive: Union[str, None] class ESQuery(BaseQuery): @@ -103,9 +103,9 @@ class ESQuery(BaseQuery): query: dict = {"match_all": {}} pit: ESQueryPit = ESQueryPit() - size: Optional[int] + size: Union[int, None] sort: Union[str, List[dict]] = "_shard_doc" - search_after: Optional[list] + search_after: Union[list, None] track_total_hits: Literal[False] = False @@ -116,7 +116,7 @@ class ESDataBackend(BaseDataBackend): query_model = ESQuery settings_class = ESDataBackendSettings - def __init__(self, settings: settings_class = None): + def __init__(self, settings: Union[settings_class, None] = None): """Instantiate the Elasticsearch data backend. Args: @@ -156,7 +156,7 @@ def status(self) -> DataBackendStatus: return DataBackendStatus.ERROR def list( - self, target: str = None, details: bool = False, new: bool = False + self, target: Union[str, None] = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: """List available Elasticsearch indices, data streams and aliases. @@ -200,8 +200,8 @@ def read( self, *, query: Union[str, ESQuery] = None, - target: str = None, - chunk_size: Union[None, int] = None, + target: Union[str, None] = None, + chunk_size: Union[int, None] = None, raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: @@ -213,7 +213,7 @@ def read( DSL. The Lucene query overrides the query DSL if present. See ESQuery. target (str or None): The target Elasticsearch index name to query. If target is `None`, the `DEFAULT_INDEX` is used instead. - chunk_size (int or None): The chunk size for reading batches of documents. + chunk_size (int or None): The chunk size when reading documents by batches. If chunk_size is `None` it defaults to `DEFAULT_CHUNK_SIZE`. raw_output (bool): Controls whether to yield dictionaries or bytes. ignore_errors (bool): Ignored. @@ -273,10 +273,10 @@ def read( def write( # pylint: disable=too-many-arguments self, data: Union[IOBase, Iterable[bytes], Iterable[dict]], - target: Union[None, str] = None, - chunk_size: Union[None, int] = None, + target: Union[str, None] = None, + chunk_size: Union[int, None] = None, ignore_errors: bool = False, - operation_type: Union[None, BaseOperationType] = None, + operation_type: Union[BaseOperationType, None] = None, ) -> int: """Write data documents to the target index and return their count. @@ -294,7 +294,7 @@ def write( # pylint: disable=too-many-arguments instead. See `BaseOperationType`. Return: - int: The number of written documents. + int: The number of documents written. 
Raise: BackendException: If a failure occurs while writing to Elasticsearch or diff --git a/src/ralph/backends/data/fs.py b/src/ralph/backends/data/fs.py index 454c7209d..1e3479ce8 100644 --- a/src/ralph/backends/data/fs.py +++ b/src/ralph/backends/data/fs.py @@ -56,7 +56,7 @@ class FSDataBackend(HistoryMixin, BaseDataBackend): default_operation_type = BaseOperationType.CREATE settings_class = FSDataBackendSettings - def __init__(self, settings: settings_class = None): + def __init__(self, settings: Union[settings_class, None] = None): """Create the default target directory if it does not exist. Args: @@ -90,7 +90,7 @@ def status(self) -> DataBackendStatus: return DataBackendStatus.OK def list( - self, target: str = None, details: bool = False, new: bool = False + self, target: Union[str, None] = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: """List files and directories in the target directory. @@ -146,8 +146,8 @@ def read( self, *, query: Union[str, BaseQuery] = None, - target: str = None, - chunk_size: Union[None, int] = None, + target: Union[str, None] = None, + chunk_size: Union[int, None] = None, raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: @@ -159,8 +159,8 @@ def read( If target is `None`, the `default_directory_path` is used instead. If target is a relative path, it is considered to be relative to the `default_directory_path`. - chunk_size (int or None): The chunk size for reading files. Ignored if - `raw_output` is set to False. + chunk_size (int or None): The chunk size when reading documents by batches. + Ignored if `raw_output` is set to False. raw_output (bool): Controls whether to yield bytes or dictionaries. ignore_errors (bool): If `True`, errors during the read operation will be ignored and logged. If `False` (default), a `BackendException` @@ -217,10 +217,10 @@ def read( def write( # pylint: disable=too-many-arguments self, data: Union[IOBase, Iterable[bytes], Iterable[dict]], - target: Union[None, str] = None, - chunk_size: Union[None, int] = None, + target: Union[str, None] = None, + chunk_size: Union[int, None] = None, ignore_errors: bool = False, - operation_type: Union[None, BaseOperationType] = None, + operation_type: Union[BaseOperationType, None] = None, ) -> int: """Write data records to the target file and return their count. diff --git a/src/ralph/backends/data/ldp.py b/src/ralph/backends/data/ldp.py index b6fe84944..6f39d0b19 100644 --- a/src/ralph/backends/data/ldp.py +++ b/src/ralph/backends/data/ldp.py @@ -63,7 +63,7 @@ class LDPDataBackend(HistoryMixin, BaseDataBackend): name = "ldp" settings_class = LDPDataBackendSettings - def __init__(self, settings: settings_class = None): + def __init__(self, settings: Union[settings_class, None] = None): """Instantiate the OVH LDP client. Args: @@ -101,7 +101,7 @@ def status(self) -> DataBackendStatus: return DataBackendStatus.OK def list( - self, target: str = None, details: bool = False, new: bool = False + self, target: Union[str, None] = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: """List archives for a given target stream_id. 
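Since every `read` signature in this patch declares its options after a bare
`*`, callers must pass them by keyword. A usage sketch for the LDP `read` hunk
that follows; the archive and stream identifiers below are made up for
illustration, and a configured OVH environment is assumed:

    from ralph.backends.data.ldp import LDPDataBackend

    backend = LDPDataBackend()  # settings=None falls back to the defaults
    size = 0
    for chunk in backend.read(
        query="hypothetical-archive-id",  # the archive to download
        target="hypothetical-stream-id",  # DEFAULT_STREAM_ID when None
        chunk_size=4096,  # bytes per yielded chunk
    ):
        size += len(chunk)  # LDP always yields raw bytes (raw_output=True)
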
@@ -151,8 +151,8 @@ def read( self, *, query: Union[str, BaseQuery] = None, - target: str = None, - chunk_size: Union[None, int] = 4096, + target: Union[str, None] = None, + chunk_size: Union[int, None] = 4096, raw_output: bool = True, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: @@ -162,7 +162,7 @@ def read( query (str or BaseQuery): The ID of the archive to read. target (str or None): The target stream_id containing the archives. If target is `None`, the `DEFAULT_STREAM_ID` is used instead. - chunk_size (int or None): The chunk size for reading archives. + chunk_size (int or None): The chunk size when reading archives by batch. raw_output (bool): Ignored. Always set to `True`. ignore_errors (bool): Ignored. @@ -216,10 +216,10 @@ def read( def write( # pylint: disable=too-many-arguments self, data: Iterable[Union[bytes, dict]], - target: Union[None, str] = None, - chunk_size: Union[None, int] = None, + target: Union[str, None] = None, + chunk_size: Union[int, None] = None, ignore_errors: bool = False, - operation_type: Union[None, BaseOperationType] = None, + operation_type: Union[BaseOperationType, None] = None, ) -> int: """LDP data backend is read-only, calling this method will raise an error.""" msg = "LDP data backend is read-only, cannot write to %s" diff --git a/src/ralph/backends/data/mongo.py b/src/ralph/backends/data/mongo.py index fcab760ff..96f8704ee 100644 --- a/src/ralph/backends/data/mongo.py +++ b/src/ralph/backends/data/mongo.py @@ -7,7 +7,7 @@ import struct from io import IOBase from itertools import chain -from typing import Generator, Iterable, Iterator, List, Optional, Tuple, Union +from typing import Generator, Iterable, Iterator, List, Tuple, Union from uuid import uuid4 from bson.errors import BSONError @@ -70,17 +70,17 @@ class Config(BaseSettingsConfig): class BaseMongoQuery(BaseQuery): """Base MongoDB query model.""" - filter: Optional[dict] - limit: Optional[int] - projection: Optional[dict] - sort: Optional[List[Tuple]] + filter: Union[dict, None] + limit: Union[int, None] + projection: Union[dict, None] + sort: Union[List[Tuple], None] class MongoQuery(BaseMongoQuery): """MongoDB query model.""" # pylint: disable=unsubscriptable-object - query_string: Optional[Json[BaseMongoQuery]] + query_string: Union[Json[BaseMongoQuery], None] class MongoDataBackend(BaseDataBackend): @@ -90,7 +90,7 @@ class MongoDataBackend(BaseDataBackend): query_model = MongoQuery settings_class = MongoDataBackendSettings - def __init__(self, settings: settings_class = None): + def __init__(self, settings: Union[settings_class, None] = None): """Instantiate the MongoDB client. Args: @@ -129,7 +129,7 @@ def status(self) -> DataBackendStatus: return DataBackendStatus.OK def list( - self, target: str = None, details: bool = False, new: bool = False + self, target: Union[str, None] = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: """List collections in the `target` database. @@ -169,8 +169,8 @@ def read( self, *, query: Union[str, MongoQuery] = None, - target: str = None, - chunk_size: Union[None, int] = None, + target: Union[str, None] = None, + chunk_size: Union[int, None] = None, raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: @@ -180,7 +180,7 @@ def read( query (str or MongoQuery): The MongoDB query to use when reading documents. target (str or None): The MongoDB collection name to query. If target is `None`, the `DEFAULT_COLLECTION` is used instead. 
-        chunk_size (int or None): The chunk size for reading batches of documents.
+        chunk_size (int or None): The chunk size when reading documents by batches.
             If chunk_size is `None` the `DEFAULT_CHUNK_SIZE` is used instead.
         raw_output (bool): Whether to yield dictionaries or bytes.
         ignore_errors (bool): Whether to ignore errors when reading documents.
@@ -224,10 +224,10 @@ def read(
     def write(  # pylint: disable=too-many-arguments
         self,
         data: Union[IOBase, Iterable[bytes], Iterable[dict]],
-        target: Union[None, str] = None,
-        chunk_size: Union[None, int] = None,
+        target: Union[str, None] = None,
+        chunk_size: Union[int, None] = None,
         ignore_errors: bool = False,
-        operation_type: Union[None, BaseOperationType] = None,
+        operation_type: Union[BaseOperationType, None] = None,
     ) -> int:
         """Write `data` documents to the `target` collection and return their count.
 
@@ -242,7 +242,7 @@ def write(  # pylint: disable=too-many-arguments
                 instead. See `BaseOperationType`.
 
         Returns:
-            int: The number of written documents.
+            int: The number of documents written.
 
         Raises:
             BackendException: If a failure occurs while writing to MongoDB or
@@ -397,7 +397,7 @@ def _bulk_import(batch: list, ignore_errors: bool, collection: Collection):
 def _bulk_delete(batch: list, ignore_errors: bool, collection: Collection):
     """Delete a `batch` of documents from the MongoDB `collection`."""
     try:
-        new_documents = collection.delete_many({"_source.id": {"$in": batch}})
+        deleted_documents = collection.delete_many({"_source.id": {"$in": batch}})
     except (BulkWriteError, PyMongoError, BSONError, ValueError) as error:
         msg = "Failed to delete document chunk: %s"
         if ignore_errors:
@@ -406,7 +406,7 @@ def _bulk_delete(batch: list, ignore_errors: bool, collection: Collection):
         logger.error(msg, error)
         raise BackendException(msg % error) from error
 
-    deleted_count = new_documents.deleted_count
+    deleted_count = deleted_documents.deleted_count
     logger.debug("Deleted %d documents chunk with success", deleted_count)
     return deleted_count
 
@@ -414,7 +414,7 @@ def _bulk_delete(batch: list, ignore_errors: bool, collection: Collection):
 def _bulk_update(batch: list, ignore_errors: bool, collection: Collection):
     """Update a `batch` of documents into the MongoDB `collection`."""
     try:
-        new_documents = collection.bulk_write(batch)
+        updated_documents = collection.bulk_write(batch)
     except (BulkWriteError, PyMongoError, BSONError, ValueError) as error:
         msg = "Failed to update document chunk: %s"
         if ignore_errors:
@@ -423,7 +423,7 @@ def _bulk_update(batch: list, ignore_errors: bool, collection: Collection):
         logger.error(msg, error)
         raise BackendException(msg % error) from error
 
-    modified_count = new_documents.modified_count
+    modified_count = updated_documents.modified_count
     logger.debug("Updated %d documents chunk with success", modified_count)
     return modified_count
diff --git a/src/ralph/backends/data/s3.py b/src/ralph/backends/data/s3.py
index 3222b41ad..33c57f175 100644
--- a/src/ralph/backends/data/s3.py
+++ b/src/ralph/backends/data/s3.py
@@ -71,7 +71,7 @@ class S3DataBackend(HistoryMixin, BaseDataBackend):
     default_operation_type = BaseOperationType.CREATE
     settings_class = S3DataBackendSettings
 
-    def __init__(self, settings: settings_class = None):
+    def __init__(self, settings: Union[settings_class, None] = None):
         """Instantiate the AWS S3 client."""
         self.settings = settings if settings else self.settings_class()
 
@@ -108,7 +108,7 @@ def status(self) -> DataBackendStatus:
         return DataBackendStatus.OK
 
     def list(
-        self, target: str = None, details: bool = False, new: bool = False
+        self, target: Union[str, None] = None, details: bool = False, new: bool = False
     ) -> Iterator[Union[str, dict]]:
         """List objects for the target bucket.
 
@@ -157,8 +157,8 @@ def read(
         self,
         *,
         query: Union[str, BaseQuery] = None,
-        target: str = None,
-        chunk_size: Union[None, int] = None,
+        target: Union[str, None] = None,
+        chunk_size: Union[int, None] = None,
         raw_output: bool = False,
         ignore_errors: bool = False,
     ) -> Iterator[Union[bytes, dict]]:
@@ -168,7 +168,7 @@ def read(
             query: (str or BaseQuery): The ID of the object to read.
             target (str or None): The target bucket containing the objects. If target
                 is `None`, the `default_bucket` is used instead.
-            chunk_size (int or None): The chunk size for reading objects.
+            chunk_size (int or None): The chunk size when reading objects by batch.
             raw_output (bool): Controls whether to yield bytes or dictionaries.
             ignore_errors (bool): If `True`, errors during the read operation
                 will be ignored and logged. If `False` (default), a `BackendException`
@@ -228,10 +228,10 @@ def read(
     def write(  # pylint: disable=too-many-arguments
         self,
         data: Union[IOBase, Iterable[bytes], Iterable[dict]],
-        target: Union[None, str] = None,
-        chunk_size: Union[None, int] = None,
+        target: Union[str, None] = None,
+        chunk_size: Union[int, None] = None,
         ignore_errors: bool = False,
-        operation_type: Union[None, BaseOperationType] = None,
+        operation_type: Union[BaseOperationType, None] = None,
     ) -> int:
         """Write `data` records to the `target` bucket and return their count.
 
diff --git a/src/ralph/backends/data/swift.py b/src/ralph/backends/data/swift.py
index c3163603a..12764c6e4 100644
--- a/src/ralph/backends/data/swift.py
+++ b/src/ralph/backends/data/swift.py
@@ -71,7 +71,7 @@ class SwiftDataBackend(HistoryMixin, BaseDataBackend):
     default_operation_type = BaseOperationType.CREATE
     settings_class = SwiftDataBackendSettings
 
-    def __init__(self, settings: settings_class = None):
+    def __init__(self, settings: Union[settings_class, None] = None):
         """Prepares the options for the SwiftService."""
         self.settings = settings if settings else self.settings_class()
 
@@ -120,7 +120,7 @@ def status(self) -> DataBackendStatus:
         return DataBackendStatus.OK
 
     def list(
-        self, target: str = None, details: bool = False, new: bool = False
+        self, target: Union[str, None] = None, details: bool = False, new: bool = False
     ) -> Iterator[Union[str, dict]]:
         """List files for the target container.
 
@@ -163,8 +163,8 @@ def read(
         self,
         *,
         query: Union[str, BaseQuery] = None,
-        target: str = None,
-        chunk_size: Union[None, int] = 500,
+        target: Union[str, None] = None,
+        chunk_size: Union[int, None] = 500,
         raw_output: bool = False,
         ignore_errors: bool = False,
     ) -> Iterator[Union[bytes, dict]]:
@@ -240,10 +240,10 @@ def read(
     def write(  # pylint: disable=too-many-arguments, disable=too-many-branches
         self,
         data: Union[IOBase, Iterable[bytes], Iterable[dict]],
-        target: Union[None, str] = None,
-        chunk_size: Union[None, int] = None,
+        target: Union[str, None] = None,
+        chunk_size: Union[int, None] = None,
         ignore_errors: bool = False,
-        operation_type: Union[None, BaseOperationType] = None,
+        operation_type: Union[BaseOperationType, None] = None,
     ) -> int:
         """Write `data` records to the `target` container and returns their count.
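The remaining hunks implement the first bullet of the commit message:
`get_query` moves from a module-level function in `ralph.backends.lrs.es` onto
`ESLRSBackend` as a static method, and the async backend calls it through the
class namespace. A stripped-down sketch of the pattern (toy bodies, not the
real query builder):

    class ESLRSBackend:
        @staticmethod
        def get_query(params: dict) -> dict:
            # No instance state is needed to build the query, which is what
            # makes a @staticmethod the natural namespacing tool here.
            return {"size": params.get("limit"), "sort": "_shard_doc"}


    class AsyncESLRSBackend:
        async def query_statements(self, params: dict) -> dict:
            # Reuse the sync backend's builder through the class namespace
            # instead of importing a module-level helper.
            return ESLRSBackend.get_query(params=params)
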
diff --git a/src/ralph/backends/lrs/async_es.py b/src/ralph/backends/lrs/async_es.py index 1842b299f..c9dae7da5 100644 --- a/src/ralph/backends/lrs/async_es.py +++ b/src/ralph/backends/lrs/async_es.py @@ -9,7 +9,7 @@ StatementParameters, StatementQueryResult, ) -from ralph.backends.lrs.es import get_query +from ralph.backends.lrs.es import ESLRSBackend from ralph.exceptions import BackendException, BackendParameterException logger = logging.getLogger(__name__) @@ -24,7 +24,7 @@ async def query_statements( self, params: StatementParameters ) -> StatementQueryResult: """Return the statements query payload using xAPI parameters.""" - query = get_query(params=params) + query = ESLRSBackend.get_query(params=params) try: statements = [ document["_source"] diff --git a/src/ralph/backends/lrs/es.py b/src/ralph/backends/lrs/es.py index 66e00b8f7..5bf7d749e 100644 --- a/src/ralph/backends/lrs/es.py +++ b/src/ralph/backends/lrs/es.py @@ -22,7 +22,7 @@ class ESLRSBackend(BaseLRSBackend, ESDataBackend): def query_statements(self, params: StatementParameters) -> StatementQueryResult: """Return the statements query payload using xAPI parameters.""" - query = get_query(params=params) + query = self.get_query(params=params) try: es_documents = self.read(query=query, chunk_size=params.limit) statements = [document["_source"] for document in es_documents] @@ -45,67 +45,67 @@ def query_statements_by_ids(self, ids: List[str]) -> Iterator[dict]: logger.error("Failed to read from Elasticsearch") raise error - -def get_query(params: StatementParameters) -> ESQuery: - """Construct query from statement parameters.""" - es_query_filters = [] - - if params.statementId: - es_query_filters += [{"term": {"_id": params.statementId}}] - - add_agent_filters(es_query_filters, params.agent, "actor") - add_agent_filters(es_query_filters, params.authority, "authority") - - if params.verb: - es_query_filters += [{"term": {"verb.id.keyword": params.verb}}] - - if params.activity: - es_query_filters += [ - {"term": {"object.objectType.keyword": "Activity"}}, - {"term": {"object.id.keyword": params.activity}}, - ] - - if params.since: - es_query_filters += [{"range": {"timestamp": {"gt": params.since}}}] - - if params.until: - es_query_filters += [{"range": {"timestamp": {"lte": params.until}}}] - - es_query = { - "pit": ESQueryPit.construct(id=params.pit_id), - "size": params.limit, - "sort": [{"timestamp": {"order": "asc" if params.ascending else "desc"}}], - } - if len(es_query_filters) > 0: - es_query["query"] = {"bool": {"filter": es_query_filters}} - - if params.ignore_order: - es_query["sort"] = "_shard_doc" - - if params.search_after: - es_query["search_after"] = params.search_after.split("|") - - # Note: `params` fields are validated thus we skip their validation in ESQuery. 
- return ESQuery.construct(**es_query) - - -def add_agent_filters( - es_query_filters: list, agent_params: AgentParameters, target_field: str -): - """Add filters relative to agents to `es_query_filters`.""" - if not agent_params: - return - if agent_params.mbox: - field = f"{target_field}.mbox.keyword" - es_query_filters += [{"term": {field: agent_params.mbox}}] - elif agent_params.mbox_sha1sum: - field = f"{target_field}.mbox_sha1sum.keyword" - es_query_filters += [{"term": {field: agent_params.mbox_sha1sum}}] - elif agent_params.openid: - field = f"{target_field}.openid.keyword" - es_query_filters += [{"term": {field: agent_params.openid}}] - elif agent_params.account__name: - field = f"{target_field}.account.name.keyword" - es_query_filters += [{"term": {field: agent_params.account__name}}] - field = f"{target_field}.account.homePage.keyword" - es_query_filters += [{"term": {field: agent_params.account__home_page}}] + @staticmethod + def get_query(params: StatementParameters) -> ESQuery: + """Construct query from statement parameters.""" + es_query_filters = [] + + if params.statementId: + es_query_filters += [{"term": {"_id": params.statementId}}] + + ESLRSBackend._add_agent_filters(es_query_filters, params.agent, "actor") + ESLRSBackend._add_agent_filters(es_query_filters, params.authority, "authority") + + if params.verb: + es_query_filters += [{"term": {"verb.id.keyword": params.verb}}] + + if params.activity: + es_query_filters += [ + {"term": {"object.objectType.keyword": "Activity"}}, + {"term": {"object.id.keyword": params.activity}}, + ] + + if params.since: + es_query_filters += [{"range": {"timestamp": {"gt": params.since}}}] + + if params.until: + es_query_filters += [{"range": {"timestamp": {"lte": params.until}}}] + + es_query = { + "pit": ESQueryPit.construct(id=params.pit_id), + "size": params.limit, + "sort": [{"timestamp": {"order": "asc" if params.ascending else "desc"}}], + } + if len(es_query_filters) > 0: + es_query["query"] = {"bool": {"filter": es_query_filters}} + + if params.ignore_order: + es_query["sort"] = "_shard_doc" + + if params.search_after: + es_query["search_after"] = params.search_after.split("|") + + # Note: `params` fields are validated thus we skip their validation in ESQuery. 
+ return ESQuery.construct(**es_query) + + @staticmethod + def _add_agent_filters( + es_query_filters: list, agent_params: AgentParameters, target_field: str + ): + """Add filters relative to agents to `es_query_filters`.""" + if not agent_params: + return + if agent_params.mbox: + field = f"{target_field}.mbox.keyword" + es_query_filters += [{"term": {field: agent_params.mbox}}] + elif agent_params.mbox_sha1sum: + field = f"{target_field}.mbox_sha1sum.keyword" + es_query_filters += [{"term": {field: agent_params.mbox_sha1sum}}] + elif agent_params.openid: + field = f"{target_field}.openid.keyword" + es_query_filters += [{"term": {field: agent_params.openid}}] + elif agent_params.account__name: + field = f"{target_field}.account.name.keyword" + es_query_filters += [{"term": {field: agent_params.account__name}}] + field = f"{target_field}.account.homePage.keyword" + es_query_filters += [{"term": {field: agent_params.account__home_page}}] diff --git a/tests/backends/data/test_async_es.py b/tests/backends/data/test_async_es.py index 1d3b8976d..3b121cf1b 100644 --- a/tests/backends/data/test_async_es.py +++ b/tests/backends/data/test_async_es.py @@ -751,13 +751,17 @@ async def test_backends_data_async_es_data_backend_write_method_without_ignore_e @pytest.mark.anyio async def test_backends_data_async_es_data_backend_write_method_with_ignore_errors( - es, async_es_backend + es, async_es_backend, caplog ): """Test the `AsyncESDataBackend.write` method with `ignore_errors` set to `True`, given badly formatted data, should should skip the invalid data. """ # pylint: disable=invalid-name,unused-argument + msg = ( + "Failed to decode JSON: Expecting value: line 1 column 1 (char 0), " + "for document: b'This is invalid JSON'" + ) records = [{"id": idx, "count": random.randint(0, 100)} for idx in range(10)] # Patch a record with a non-expected type for the count field (should be # assigned as long) @@ -781,13 +785,20 @@ async def test_backends_data_async_es_data_backend_write_method_with_ignore_erro "This is invalid JSON".encode("utf-8"), json.dumps({"foo": "baz"}).encode("utf-8"), ] - assert await backend.write(data, chunk_size=2, ignore_errors=True) == 2 + with caplog.at_level(logging.WARNING): + assert await backend.write(data, chunk_size=2, ignore_errors=True) == 2 es.indices.refresh(index=ES_TEST_INDEX) hits = [statement async for statement in backend.read()] assert len(hits) == 11 assert [hit["_source"] for hit in hits[9:]] == [{"foo": "bar"}, {"foo": "baz"}] + assert ( + "ralph.backends.data.async_es", + logging.WARNING, + msg, + ) in caplog.record_tuples + await backend.close() From 7b9948b2eaa8ac9375f0df28d82bee613ebadb60 Mon Sep 17 00:00:00 2001 From: Wilfried BARADAT Date: Mon, 7 Aug 2023 10:35:09 +0200 Subject: [PATCH 21/65] =?UTF-8?q?=F0=9F=8F=97=EF=B8=8F(backends)=20add=20c?= =?UTF-8?q?lose=20method=20to=20sync=20backends?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Synchronous backends such as Elasticsearch or Mongo need their connection to be closed when finished. This commits adds an abstract method close to the BaseDataBackend interface, and implements it for backends that need it. 
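
In practice, the new contract lets callers release connections deterministically once they are done reading or writing. A condensed, self-contained sketch of the pattern (the `DummyDataBackend` below is illustrative only; the real implementations follow in the diffs):

    from abc import ABC, abstractmethod
    from typing import Iterator

    class BaseDataBackend(ABC):
        """Condensed stand-in for Ralph's base interface (other methods omitted)."""

        @abstractmethod
        def close(self) -> None:
            """Close the data backend client."""

    class DummyDataBackend(BaseDataBackend):
        """Illustrative backend holding a fake connection object."""

        def __init__(self) -> None:
            self._client = object()  # Stands in for a real client connection.

        def read(self) -> Iterator[dict]:
            yield {"id": "foo"}

        def close(self) -> None:
            if self._client is None:
                return  # Real backends log "No backend client to close." here.
            self._client = None  # Real backends call `self.client.close()`.

    backend = DummyDataBackend()
    try:
        for document in backend.read():
            print(document)
    finally:
        backend.close()  # Connection released even if reading fails.

The try/finally shape mirrors how callers can now guarantee teardown for the Elasticsearch, MongoDB, ClickHouse, S3 and Swift clients, while the FS and LDP backends raise `NotImplementedError` since they hold no persistent connection.
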
--- src/ralph/backends/data/async_es.py | 3 +- src/ralph/backends/data/async_mongo.py | 2 +- src/ralph/backends/data/base.py | 10 +- src/ralph/backends/data/clickhouse.py | 17 +++ src/ralph/backends/data/es.py | 17 +++ src/ralph/backends/data/fs.py | 6 + src/ralph/backends/data/ldp.py | 6 + src/ralph/backends/data/mongo.py | 23 +++- src/ralph/backends/data/s3.py | 24 +++- src/ralph/backends/data/swift.py | 17 +++ tests/backends/data/test_async_es.py | 20 ++- tests/backends/data/test_async_mongo.py | 13 +- tests/backends/data/test_base.py | 6 + tests/backends/data/test_clickhouse.py | 57 +++++++++ tests/backends/data/test_es.py | 78 ++++++++++++ tests/backends/data/test_fs.py | 12 +- tests/backends/data/test_ldp.py | 10 ++ tests/backends/data/test_mongo.py | 55 ++++++++ tests/backends/data/test_s3.py | 159 ++++++++++++++++-------- tests/backends/data/test_swift.py | 70 ++++++++++- tests/backends/lrs/test_clickhouse.py | 6 + tests/backends/lrs/test_es.py | 11 ++ tests/backends/lrs/test_mongo.py | 6 + 23 files changed, 557 insertions(+), 71 deletions(-) diff --git a/src/ralph/backends/data/async_es.py b/src/ralph/backends/data/async_es.py index b53b0a256..3b39d7fcd 100644 --- a/src/ralph/backends/data/async_es.py +++ b/src/ralph/backends/data/async_es.py @@ -264,9 +264,10 @@ async def close(self) -> None: """Close the AsyncElasticsearch client. Raise: - BackendException: If a failure during the close operation occurs. + BackendException: If a failure occurs during the close operation. """ if not self._client: + logger.warning("No backend client to close.") return try: diff --git a/src/ralph/backends/data/async_mongo.py b/src/ralph/backends/data/async_mongo.py index 8e1eb6738..a13d54324 100644 --- a/src/ralph/backends/data/async_mongo.py +++ b/src/ralph/backends/data/async_mongo.py @@ -256,7 +256,7 @@ async def close(self) -> None: """Close the AsyncIOMotorClient client. Raise: - BackendException: If a failure during the close operation occurs. + BackendException: If a failure occurs during the close operation. """ try: self.client.close() diff --git a/src/ralph/backends/data/base.py b/src/ralph/backends/data/base.py index 4d6ac2b0a..da20328f3 100644 --- a/src/ralph/backends/data/base.py +++ b/src/ralph/backends/data/base.py @@ -227,6 +227,14 @@ def write( # pylint: disable=too-many-arguments BackendParameterException: If a backend argument value is not valid. """ + @abstractmethod + def close(self) -> None: + """Close the data backend client. + + Raise: + BackendException: If a failure occurs during the close operation. + """ + class BaseAsyncDataBackend(ABC): """Base async data backend interface.""" @@ -381,5 +389,5 @@ async def close(self) -> None: """Close the data backend client. Raise: - BackendException: If a failure during the close operation occurs. + BackendException: If a failure occurs during the close operation. """ diff --git a/src/ralph/backends/data/clickhouse.py b/src/ralph/backends/data/clickhouse.py index b417b18cf..5b5c09a4c 100755 --- a/src/ralph/backends/data/clickhouse.py +++ b/src/ralph/backends/data/clickhouse.py @@ -382,6 +382,23 @@ def write( # pylint: disable=too-many-arguments return count + def close(self) -> None: + """Close the ClickHouse backend client. + + Raise: + BackendException: If a failure occurs during the close operation. 
+ """ + if not self._client: + logger.warning("No backend client to close.") + return + + try: + self.client.close() + except ClickHouseError as error: + msg = "Failed to close ClickHouse client: %s" + logger.error(msg, error) + raise BackendException(msg % error) from error + @staticmethod def _to_insert_tuples( data: Iterable[dict], diff --git a/src/ralph/backends/data/es.py b/src/ralph/backends/data/es.py index d9b1174d4..b7a7a9662 100644 --- a/src/ralph/backends/data/es.py +++ b/src/ralph/backends/data/es.py @@ -344,6 +344,23 @@ def write( # pylint: disable=too-many-arguments raise BackendException(msg % (error, details, count)) from error return count + def close(self) -> None: + """Close the Elasticsearch backend client. + + Raise: + BackendException: If a failure occurs during the close operation. + """ + if not self._client: + logger.warning("No backend client to close.") + return + + try: + self.client.close() + except TransportError as error: + msg = "Failed to close Elasticsearch client: %s" + logger.error(msg, error) + raise BackendException(msg % error) from error + @staticmethod def to_documents( data: Iterable[dict], diff --git a/src/ralph/backends/data/fs.py b/src/ralph/backends/data/fs.py index 1e3479ce8..8bba06374 100644 --- a/src/ralph/backends/data/fs.py +++ b/src/ralph/backends/data/fs.py @@ -308,6 +308,12 @@ def write( # pylint: disable=too-many-arguments ) return 1 + def close(self) -> None: + """FS backend has nothing to close, this method is not implemented.""" + msg = "FS data backend does not support `close` method" + logger.error(msg) + raise NotImplementedError(msg) + @staticmethod def _read_raw(file: IO, chunk_size: int, _ignore_errors: bool) -> Iterator[bytes]: """Read the `file` in chunks of size `chunk_size` and yield them.""" diff --git a/src/ralph/backends/data/ldp.py b/src/ralph/backends/data/ldp.py index 6f39d0b19..cfa8cf18d 100644 --- a/src/ralph/backends/data/ldp.py +++ b/src/ralph/backends/data/ldp.py @@ -226,6 +226,12 @@ def write( # pylint: disable=too-many-arguments logger.error(msg, target) raise NotImplementedError(msg % target) + def close(self) -> None: + """LDP client does not support close, this method is not implemented.""" + msg = "LDP data backend does not support `close` method" + logger.error(msg) + raise NotImplementedError(msg) + def _get_archive_endpoint(self, stream_id: Union[None, str] = None) -> str: """Return OVH's archive endpoint.""" stream_id = stream_id if stream_id else self.stream_id diff --git a/src/ralph/backends/data/mongo.py b/src/ralph/backends/data/mongo.py index 96f8704ee..432952678 100644 --- a/src/ralph/backends/data/mongo.py +++ b/src/ralph/backends/data/mongo.py @@ -16,7 +16,13 @@ from pydantic import Json, MongoDsn, constr from pymongo import MongoClient, ReplaceOne from pymongo.collection import Collection -from pymongo.errors import BulkWriteError, ConnectionFailure, InvalidName, PyMongoError +from pymongo.errors import ( + BulkWriteError, + ConnectionFailure, + InvalidName, + InvalidOperation, + PyMongoError, +) from ralph.backends.data.base import ( BaseDataBackend, @@ -113,7 +119,7 @@ def status(self) -> DataBackendStatus: # Check MongoDB connection. 
try: self.client.admin.command("ping") - except ConnectionFailure as error: + except (ConnectionFailure, InvalidOperation) as error: logger.error("Failed to connect to MongoDB: %s", error) return DataBackendStatus.AWAY @@ -296,6 +302,19 @@ def write( # pylint: disable=too-many-arguments return count + def close(self) -> None: + """Close the MongoDB backend client. + + Raise: + BackendException: If a failure occurs during the close operation. + """ + try: + self.client.close() + except PyMongoError as error: + msg = "Failed to close MongoDB client: %s" + logger.error(msg, error) + raise BackendException(msg % error) from error + @staticmethod def iter_by_batch(data: Iterable[dict], chunk_size: int): """Iterate over `data` Iterable and yield batches of size `chunk_size`.""" diff --git a/src/ralph/backends/data/s3.py b/src/ralph/backends/data/s3.py index 33c57f175..c20521d80 100644 --- a/src/ralph/backends/data/s3.py +++ b/src/ralph/backends/data/s3.py @@ -11,6 +11,7 @@ from boto3.s3.transfer import TransferConfig from botocore.exceptions import ( ClientError, + EndpointConnectionError, ParamValidationError, ReadTimeoutError, ResponseStreamingError, @@ -102,7 +103,7 @@ def status(self) -> DataBackendStatus: """ try: self.client.head_bucket(Bucket=self.default_bucket_name) - except ClientError: + except (ClientError, EndpointConnectionError): return DataBackendStatus.ERROR return DataBackendStatus.OK @@ -197,7 +198,7 @@ def read( try: response = self.client.get_object(Bucket=target, Key=query.query_string) - except ClientError as err: + except (ClientError, EndpointConnectionError) as err: error_msg = err.response["Error"]["Message"] msg = "Failed to download %s: %s" logger.error(msg, query.query_string, error_msg) @@ -321,7 +322,7 @@ def write( # pylint: disable=too-many-arguments Config=TransferConfig(multipart_chunksize=chunk_size), ) response = self.client.head_object(Bucket=target_bucket, Key=target_object) - except (ClientError, ParamValidationError) as exc: + except (ClientError, ParamValidationError, EndpointConnectionError) as exc: msg = "Failed to upload %s" logger.error(msg, target) raise BackendException(msg % target) from exc @@ -340,6 +341,23 @@ def write( # pylint: disable=too-many-arguments return counter["count"] + def close(self) -> None: + """Close the S3 backend client. + + Raise: + BackendException: If a failure occurs during the close operation. + """ + if not self._client: + logger.warning("No backend client to close.") + return + + try: + self.client.close() + except ClientError as error: + msg = "Failed to close S3 backend client: %s" + logger.error(msg, error) + raise BackendException(msg % error) from error + @staticmethod def _read_raw( obj: StreamingBody, chunk_size: int, _ignore_errors: bool diff --git a/src/ralph/backends/data/swift.py b/src/ralph/backends/data/swift.py index 12764c6e4..18516d570 100644 --- a/src/ralph/backends/data/swift.py +++ b/src/ralph/backends/data/swift.py @@ -346,6 +346,23 @@ def write( # pylint: disable=too-many-arguments, disable=too-many-branches ) return count + def close(self) -> None: + """Close the Swift backend client. + + Raise: + BackendException: If a failure occurs during the close operation. 
+ """ + if not self._connection: + logger.warning("No backend client to close.") + return + + try: + self.connection.close() + except ClientException as error: + msg = "Failed to close Swift backend client: %s" + logger.error(msg, error) + raise BackendException(msg % error) from error + def _details(self, container: str, name: str): """Return `name` object details from `container`.""" try: diff --git a/tests/backends/data/test_async_es.py b/tests/backends/data/test_async_es.py index 3b121cf1b..a18745897 100644 --- a/tests/backends/data/test_async_es.py +++ b/tests/backends/data/test_async_es.py @@ -825,7 +825,7 @@ async def test_backends_data_async_es_data_backend_write_method_with_datastream( @pytest.mark.anyio -async def test_backends_data_es_data_backend_close_method( +async def test_backends_data_es_data_backend_close_method_with_failure( async_es_backend, monkeypatch ): """Test the `AsyncESDataBackend.close` method.""" @@ -840,3 +840,21 @@ async def mock_connection_error(): with pytest.raises(BackendException, match="Failed to close Elasticsearch client"): await backend.close() + + +@pytest.mark.anyio +async def test_backends_data_es_data_backend_close_method(async_es_backend, caplog): + """Test the `AsyncESDataBackend.close` method.""" + + # No client instantiated + backend = async_es_backend() + await backend.close() + backend._client = None # pylint: disable=protected-access + with caplog.at_level(logging.WARNING): + await backend.close() + + assert ( + "ralph.backends.data.async_es", + logging.WARNING, + "No backend client to close.", + ) in caplog.record_tuples diff --git a/tests/backends/data/test_async_mongo.py b/tests/backends/data/test_async_mongo.py index 12782d984..3f916b449 100644 --- a/tests/backends/data/test_async_mongo.py +++ b/tests/backends/data/test_async_mongo.py @@ -1056,7 +1056,7 @@ async def test_backends_data_async_mongo_data_backend_write_method_with_custom_c @pytest.mark.anyio -async def test_backends_data_async_mongo_data_backend_close( +async def test_backends_data_async_mongo_data_backend_close_method_with_failure( async_mongo_backend, monkeypatch, caplog ): """Test the `AsyncMongoDataBackend.close` method, given a failed close, @@ -1084,3 +1084,14 @@ def close(): logging.ERROR, "Failed to close AsyncIOMotorClient: Close failure", ) in caplog.record_tuples + + +@pytest.mark.anyio +async def test_backends_data_async_mongo_data_backend_close_method(async_mongo_backend): + """Test the `AsyncMongoDataBackend.close` method.""" + + backend = async_mongo_backend() + + # Not possible to connect to client after closing it + await backend.close() + assert await backend.status() == DataBackendStatus.AWAY diff --git a/tests/backends/data/test_base.py b/tests/backends/data/test_base.py index deacdddfd..63ca68df2 100644 --- a/tests/backends/data/test_base.py +++ b/tests/backends/data/test_base.py @@ -39,6 +39,9 @@ def list(self): # pylint: disable=arguments-differ,missing-function-docstring def write(self): # pylint: disable=arguments-differ,missing-function-docstring pass + def close(self): # pylint: disable=arguments-differ,missing-function-docstring + pass + MockBaseDataBackend().read(query=value) @@ -80,6 +83,9 @@ def list(self): # pylint: disable=arguments-differ,missing-function-docstring def write(self): # pylint: disable=arguments-differ,missing-function-docstring pass + def close(self): # pylint: disable=arguments-differ,missing-function-docstring + pass + with pytest.raises(BackendParameterException, match=error): with caplog.at_level(logging.ERROR): 
MockBaseDataBackend().read(query=value) diff --git a/tests/backends/data/test_clickhouse.py b/tests/backends/data/test_clickhouse.py index c0876ce37..6d5bc4f40 100644 --- a/tests/backends/data/test_clickhouse.py +++ b/tests/backends/data/test_clickhouse.py @@ -51,6 +51,7 @@ def test_backends_data_clickhouse_data_backend_default_instantiation(monkeypatch assert backend.event_table_name == "xapi_events_all" assert backend.default_chunk_size == 500 assert backend.locale_encoding == "utf8" + backend.close() def test_backends_data_clickhouse_data_backend_instantiation_with_settings(): @@ -75,6 +76,7 @@ def test_backends_data_clickhouse_data_backend_instantiation_with_settings(): assert backend.event_table_name == CLICKHOUSE_TEST_TABLE_NAME assert backend.default_chunk_size == 1000 assert backend.locale_encoding == "utf-16" + backend.close() def test_backends_data_clickhouse_data_backend_status( @@ -93,6 +95,7 @@ def mock_query(*_, **__): monkeypatch.setattr(backend.client, "query", mock_query) assert backend.status() == DataBackendStatus.AWAY + backend.close() def test_backends_data_clickhouse_data_backend_read_method_with_raw_output( @@ -130,6 +133,7 @@ def test_backends_data_clickhouse_data_backend_read_method_with_raw_output( assert len(results) == 3 assert isinstance(results[0], bytes) assert json.loads(results[0])["event"] == statements[0] + backend.close() # pylint: disable=unused-argument @@ -203,6 +207,7 @@ def test_backends_data_clickhouse_data_backend_read_method_with_a_custom_query( results = list(backend.read(query=query)) assert len(results) == 1 assert results[0]["event"] == statements[1] + backend.close() def test_backends_data_clickhouse_data_backend_read_method_with_failures( @@ -276,6 +281,7 @@ def mock_query(*_, **__): logging.ERROR, msg, ) in caplog.record_tuples + backend.close() def test_backends_data_clickhouse_data_backend_list_method( @@ -287,6 +293,7 @@ def test_backends_data_clickhouse_data_backend_list_method( assert list(backend.list(details=True)) == [{"name": CLICKHOUSE_TEST_TABLE_NAME}] assert list(backend.list(details=False)) == [CLICKHOUSE_TEST_TABLE_NAME] + backend.close() def test_backends_data_clickhouse_data_backend_list_method_with_failure( @@ -315,6 +322,7 @@ def mock_query(*_, **__): logging.ERROR, msg, ) in caplog.record_tuples + backend.close() def test_backends_data_clickhouse_data_backend_write_method_with_invalid_timestamp( @@ -343,6 +351,7 @@ def test_backends_data_clickhouse_data_backend_write_method_with_invalid_timesta match=msg, ): backend.write(statements, ignore_errors=False) + backend.close() def test_backends_data_clickhouse_data_backend_write_method_no_timestamp( @@ -388,6 +397,7 @@ def test_backends_data_clickhouse_data_backend_write_method_no_timestamp( logging.ERROR, f"Statement {statement} has an invalid 'id' or 'timestamp' field", ) not in caplog.record_tuples + backend.close() def test_backends_data_clickhouse_data_backend_write_method_with_duplicated_key( @@ -412,6 +422,7 @@ def test_backends_data_clickhouse_data_backend_write_method_with_duplicated_key( with pytest.raises(BackendException, match="Duplicate IDs found in batch"): backend.write(statements, ignore_errors=False) + backend.close() def test_backends_data_clickhouse_data_backend_write_method_chunks_on_error( @@ -435,6 +446,7 @@ def test_backends_data_clickhouse_data_backend_write_method_chunks_on_error( {"id": dupe_id, **timestamp}, ] assert backend.write(statements, ignore_errors=True) == 0 + backend.close() def test_backends_data_clickhouse_data_backend_write_method( @@ 
-472,6 +484,7 @@ def test_backends_data_clickhouse_data_backend_write_method( assert result[1]["event_id"] == native_statements[1]["id"] assert result[1]["emission_time"] == native_statements[1]["timestamp"] assert result[1]["event"] == statements[1] + backend.close() def test_backends_data_clickhouse_data_backend_write_method_bytes( @@ -514,6 +527,7 @@ def test_backends_data_clickhouse_data_backend_write_method_bytes( assert result[1]["event_id"] == native_statements[1]["id"] assert result[1]["emission_time"] == native_statements[1]["timestamp"] assert result[1]["event"] == statements[1] + backend.close() def test_backends_data_clickhouse_data_backend_write_method_bytes_failed( @@ -545,6 +559,7 @@ def test_backends_data_clickhouse_data_backend_write_method_bytes_failed( result = clickhouse.query(sql).result_set assert result[0][0] == 0 + backend.close() def test_backends_data_clickhouse_data_backend_write_method_empty( @@ -563,6 +578,7 @@ def test_backends_data_clickhouse_data_backend_write_method_empty( result = clickhouse.query(sql).result_set assert result[0][0] == 0 + backend.close() def test_backends_data_clickhouse_data_backend_write_method_wrong_operation_type( @@ -589,6 +605,7 @@ def test_backends_data_clickhouse_data_backend_write_method_wrong_operation_type match=f"{BaseOperationType.APPEND.name} operation_type is not allowed.", ): backend.write(data=statements, operation_type=BaseOperationType.APPEND) + backend.close() def test_backends_data_clickhouse_data_backend_write_method_with_custom_chunk_size( @@ -626,3 +643,43 @@ def test_backends_data_clickhouse_data_backend_write_method_with_custom_chunk_si assert result[1]["event_id"] == native_statements[1]["id"] assert result[1]["emission_time"] == native_statements[1]["timestamp"] assert result[1]["event"] == statements[1] + backend.close() + + +def test_backends_data_clickhouse_data_backend_close_method_with_failure( + clickhouse_backend, monkeypatch +): + """Test the `ClickHouseDataBackend.close` method with failure.""" + + backend = clickhouse_backend() + + def mock_connection_error(): + """ClickHouse client close mock that raises a connection error.""" + raise ClickHouseError("", (Exception("Mocked connection error"),)) + + monkeypatch.setattr(backend.client, "close", mock_connection_error) + + with pytest.raises(BackendException, match="Failed to close ClickHouse client"): + backend.close() + + +def test_backends_data_clickhouse_data_backend_close_method(clickhouse_backend, caplog): + """Test the `ClickHouseDataBackend.close` method.""" + + backend = clickhouse_backend() + + # Not possible to connect to client after closing it + backend.close() + assert backend.status() == DataBackendStatus.AWAY + + # No client instantiated + backend = clickhouse_backend() + backend._client = None # pylint: disable=protected-access + with caplog.at_level(logging.WARNING): + backend.close() + + assert ( + "ralph.backends.data.clickhouse", + logging.WARNING, + "No backend client to close.", + ) in caplog.record_tuples diff --git a/tests/backends/data/test_es.py b/tests/backends/data/test_es.py index ed0b116e6..0e3f06072 100644 --- a/tests/backends/data/test_es.py +++ b/tests/backends/data/test_es.py @@ -108,6 +108,8 @@ def test_backends_data_es_data_backend_instantiation_with_settings(): except Exception as err: # pylint:disable=broad-except pytest.fail(f"Two ESDataBackends should not raise exceptions: {err}") + backend.close() + def test_backends_data_es_data_backend_status_method(monkeypatch, es_backend, caplog): """Test the 
`ESDataBackend.status` method.""" @@ -157,6 +159,8 @@ def mock_connection_error(): "Exception(Mocked connection error)", ) in caplog.record_tuples + backend.close() + @pytest.mark.parametrize( "exception, error", @@ -189,6 +193,8 @@ def mock_get(index): f"Failed to read indices: {error}", ) in caplog.record_tuples + backend.close() + def test_backends_data_es_data_backend_list_method_without_history( es_backend, monkeypatch @@ -208,6 +214,8 @@ def mock_get(index): assert isinstance(result, Iterable) assert list(result) == list(indices.keys()) + backend.close() + def test_backends_data_es_data_backend_list_method_with_details( es_backend, monkeypatch @@ -229,6 +237,8 @@ def mock_get(index): {"index_2": {"info_2": "baz"}}, ] + backend.close() + def test_backends_data_es_data_backend_list_method_with_history( es_backend, caplog, monkeypatch @@ -247,6 +257,8 @@ def test_backends_data_es_data_backend_list_method_with_history( "The `new` argument is ignored", ) in caplog.record_tuples + backend.close() + @pytest.mark.parametrize( "exception, error", @@ -300,6 +312,8 @@ def mock_es_search_open_pit(**kwargs): "Failed to open Elasticsearch point in time: %s" % error.replace("\\", ""), ) in caplog.record_tuples + backend.close() + def test_backends_data_es_data_backend_read_method_with_ignore_errors( es, es_backend, monkeypatch, caplog @@ -319,6 +333,8 @@ def test_backends_data_es_data_backend_read_method_with_ignore_errors( "The `ignore_errors` argument is ignored", ) in caplog.record_tuples + backend.close() + def test_backends_data_es_data_backend_read_method_with_raw_ouput(es, es_backend): """Test the `ESDataBackend.read` method with `raw_output` set to `True`.""" @@ -331,6 +347,8 @@ def test_backends_data_es_data_backend_read_method_with_raw_ouput(es, es_backend assert isinstance(hit, bytes) assert json.loads(hit).get("_source") == documents[i] + backend.close() + def test_backends_data_es_data_backend_read_method_without_raw_ouput(es, es_backend): """Test the `ESDataBackend.read` method with `raw_output` set to `False`.""" @@ -343,6 +361,8 @@ def test_backends_data_es_data_backend_read_method_without_raw_ouput(es, es_back assert isinstance(hit, dict) assert hit.get("_source") == documents[i] + backend.close() + def test_backends_data_es_data_backend_read_method_with_query(es, es_backend, caplog): """Test the `ESDataBackend.read` method with a query.""" @@ -402,6 +422,8 @@ def test_backends_data_es_data_backend_read_method_with_query(es, es_backend, ca "'type': 'value_error.extra'}]", ) in caplog.record_tuples + backend.close() + def test_backends_data_es_data_backend_write_method_with_create_operation( es, es_backend, caplog @@ -449,6 +471,8 @@ def test_backends_data_es_data_backend_write_method_with_create_operation( hits = list(backend.read()) assert [hit["_source"] for hit in hits] == [{"value": str(idx)} for idx in range(9)] + backend.close() + def test_backends_data_es_data_backend_write_method_with_delete_operation( es, @@ -474,6 +498,8 @@ def test_backends_data_es_data_backend_write_method_with_delete_operation( assert len(hits) == 7 assert sorted([hit["_source"]["id"] for hit in hits]) == list(range(3, 10)) + backend.close() + def test_backends_data_es_data_backend_write_method_with_update_operation( es, @@ -518,6 +544,8 @@ def test_backends_data_es_data_backend_write_method_with_update_operation( map(lambda x: str(x + 10), range(10)) ) + backend.close() + def test_backends_data_es_data_backend_write_method_with_append_operation( es_backend, caplog @@ -537,6 +565,8 @@ def 
test_backends_data_es_data_backend_write_method_with_append_operation( "Append operation_type is not supported.", ) in caplog.record_tuples + backend.close() + def test_backends_data_es_data_backend_write_method_with_target(es, es_backend): """Test the `ESDataBackend.write` method, given a target index, should insert @@ -570,6 +600,8 @@ def get_data(): {"value": "2"}, ] + backend.close() + def test_backends_data_es_data_backend_write_method_without_ignore_errors( es, es_backend, caplog @@ -639,6 +671,8 @@ def test_backends_data_es_data_backend_write_method_without_ignore_errors( hits = list(backend.read()) assert len(hits) == 5 + backend.close() + def test_backends_data_es_data_backend_write_method_with_ignore_errors(es, es_backend): """Test the `ESDataBackend.write` method with `ignore_errors` set to `True`, given @@ -676,6 +710,8 @@ def test_backends_data_es_data_backend_write_method_with_ignore_errors(es, es_ba assert len(hits) == 11 assert [hit["_source"] for hit in hits[9:]] == [{"foo": "bar"}, {"foo": "baz"}] + backend.close() + def test_backends_data_es_data_backend_write_method_with_datastream( es_data_stream, es_backend @@ -693,3 +729,45 @@ def test_backends_data_es_data_backend_write_method_with_datastream( hits = list(backend.read()) assert len(hits) == 10 assert sorted([hit["_source"]["id"] for hit in hits]) == list(range(10)) + + backend.close() + + +def test_backends_data_es_data_backend_close_method_with_failure( + es_backend, monkeypatch +): + """Test the `ESDataBackend.close` method.""" + + backend = es_backend() + + def mock_connection_error(): + """ES client close mock that raises a connection error.""" + raise ESConnectionError("", (Exception("Mocked connection error"),)) + + monkeypatch.setattr(backend.client, "close", mock_connection_error) + + with pytest.raises(BackendException, match="Failed to close Elasticsearch client"): + backend.close() + + +def test_backends_data_es_data_backend_close_method(es_backend, caplog): + """Test the `ESDataBackend.close` method.""" + + backend = es_backend() + backend.status() + + # Not possible to connect to client after closing it + backend.close() + assert backend.status() == DataBackendStatus.AWAY + + # No client instantiated + backend = es_backend() + backend._client = None # pylint: disable=protected-access + with caplog.at_level(logging.WARNING): + backend.close() + + assert ( + "ralph.backends.data.es", + logging.WARNING, + "No backend client to close.", + ) in caplog.record_tuples diff --git a/tests/backends/data/test_fs.py b/tests/backends/data/test_fs.py index 51779a34f..9d6133e72 100644 --- a/tests/backends/data/test_fs.py +++ b/tests/backends/data/test_fs.py @@ -1,4 +1,4 @@ -"""Tests for Ralph fs data backend""" +"""Tests for Ralph fs data backend""" # pylint: disable = too-many-lines import json import logging import os @@ -996,3 +996,13 @@ def test_backends_data_fs_data_backend_write_method_without_target( "timestamp": frozen_now, }, ] + + +def test_backends_data_fs_data_backend_close_method(fs_backend): + """Test that the `FSDataBackend.close` method raise an error.""" + + backend = fs_backend() + + error = "FS data backend does not support `close` method" + with pytest.raises(NotImplementedError, match=error): + backend.close() diff --git a/tests/backends/data/test_ldp.py b/tests/backends/data/test_ldp.py index 20740980c..a80e7a8b8 100644 --- a/tests/backends/data/test_ldp.py +++ b/tests/backends/data/test_ldp.py @@ -698,3 +698,13 @@ def mock_post(url): backend = ldp_backend() monkeypatch.setattr(backend.client, 
"post", mock_post) assert backend._url(archive_name) == archive_url + + +def test_backends_data_ldp_data_backend_close_method(ldp_backend): + """Test that the `LDPDataBackend.close` method raise an error.""" + + backend = ldp_backend() + + error = "LDP data backend does not support `close` method" + with pytest.raises(NotImplementedError, match=error): + backend.close() diff --git a/tests/backends/data/test_mongo.py b/tests/backends/data/test_mongo.py index 25d5d6049..2c19b2220 100644 --- a/tests/backends/data/test_mongo.py +++ b/tests/backends/data/test_mongo.py @@ -51,6 +51,7 @@ def test_backends_data_mongo_data_backend_default_instantiation(monkeypatch, fs) assert backend.settings.CLIENT_OPTIONS == MongoClientOptions() assert backend.settings.DEFAULT_CHUNK_SIZE == 500 assert backend.settings.LOCALE_ENCODING == "utf8" + backend.close() def test_backends_data_mongo_data_backend_instantiation_with_settings(): @@ -75,6 +76,7 @@ def test_backends_data_mongo_data_backend_instantiation_with_settings(): MongoDataBackend(settings) except Exception as err: # pylint:disable=broad-except pytest.fail(f"Two MongoDataBackends should not raise exceptions: {err}") + backend.close() def test_backends_data_mongo_data_backend_status_with_connection_failure( @@ -163,6 +165,7 @@ def test_backends_data_mongo_data_backend_status_with_ok_status(mongo_backend): """ backend = mongo_backend() assert backend.status() == DataBackendStatus.OK + backend.close() @pytest.mark.parametrize("invalid_character", [" ", ".", "/", '"']) @@ -182,6 +185,7 @@ def test_backends_data_mongo_data_backend_list_method_with_invalid_target( list(backend.list(f"foo{invalid_character}bar")) assert ("ralph.backends.data.mongo", logging.ERROR, msg) in caplog.record_tuples + backend.close() def test_backends_data_mongo_data_backend_list_method_with_failure( @@ -203,6 +207,7 @@ def list_collections(): list(backend.list()) assert ("ralph.backends.data.mongo", logging.ERROR, msg) in caplog.record_tuples + backend.close() def test_backends_data_mongo_data_backend_list_method_without_history( @@ -221,6 +226,7 @@ def test_backends_data_mongo_data_backend_list_method_without_history( sorted([MONGO_TEST_COLLECTION, "bar", "baz"]) ) assert not list(backend.list("non_existent_database")) + backend.close() def test_backends_data_mongo_data_backend_list_method_with_history( @@ -238,6 +244,7 @@ def test_backends_data_mongo_data_backend_list_method_with_history( logging.WARNING, "The `new` argument is ignored", ) in caplog.record_tuples + backend.close() def test_backends_data_mongo_data_backend_read_method_with_raw_output( @@ -262,6 +269,7 @@ def test_backends_data_mongo_data_backend_read_method_with_raw_output( assert list(backend.read(raw_output=True, target="foobar")) == expected[:2] assert list(backend.read(raw_output=True, chunk_size=2)) == expected assert list(backend.read(raw_output=True, chunk_size=1000)) == expected + backend.close() def test_backends_data_mongo_data_backend_read_method_without_raw_output( @@ -286,6 +294,7 @@ def test_backends_data_mongo_data_backend_read_method_without_raw_output( assert list(backend.read(target="foobar")) == expected[:2] assert list(backend.read(chunk_size=2)) == expected assert list(backend.read(chunk_size=1000)) == expected + backend.close() @pytest.mark.parametrize( @@ -313,6 +322,7 @@ def test_backends_data_mongo_data_backend_read_method_with_invalid_target( list(backend.read(target=invalid_target)) assert ("ralph.backends.data.mongo", logging.ERROR, msg) in caplog.record_tuples + backend.close() def 
test_backends_data_mongo_data_backend_read_method_with_failure( @@ -336,6 +346,7 @@ def mock_find(batch_size, query=None): list(backend.read()) assert ("ralph.backends.data.mongo", logging.ERROR, msg) in caplog.record_tuples + backend.close() def test_backends_data_mongo_data_backend_read_method_with_ignore_errors( @@ -370,6 +381,7 @@ def test_backends_data_mongo_data_backend_read_method_with_ignore_errors( "Failed to convert document to bytes: " "Object of type ObjectId is not JSON serializable", ) in caplog.record_tuples + backend.close() def test_backends_data_mongo_data_backend_read_method_without_ignore_errors( @@ -414,6 +426,7 @@ def test_backends_data_mongo_data_backend_read_method_without_ignore_errors( error_log = ("ralph.backends.data.mongo", logging.ERROR, msg) assert len(list(filter(lambda x: x == error_log, caplog.record_tuples))) == 4 + backend.close() @pytest.mark.parametrize( @@ -454,6 +467,7 @@ def test_backends_data_mongo_data_backend_read_method_with_query( assert list(backend.read(query=query)) == expected assert list(backend.read(query=query, chunk_size=1)) == expected assert list(backend.read(query=query, chunk_size=1000)) == expected + backend.close() def test_backends_data_mongo_data_backend_write_method_with_target( @@ -480,6 +494,7 @@ def test_backends_data_mongo_data_backend_write_method_with_target( "_id": "62b9ce92fcde2b2edba56bf4", "_source": {"id": "bar", **timestamp}, } + backend.close() def test_backends_data_mongo_data_backend_write_method_without_target( @@ -502,6 +517,7 @@ def test_backends_data_mongo_data_backend_write_method_without_target( "_id": "62b9ce92fcde2b2edba56bf4", "_source": {"id": "bar", **timestamp}, } + backend.close() def test_backends_data_mongo_data_backend_write_method_with_duplicated_key_error( @@ -555,6 +571,7 @@ def test_backends_data_mongo_data_backend_write_method_with_duplicated_key_error logging.ERROR, exception_info.value.args[0], ) in caplog.record_tuples + backend.close() def test_backends_data_mongo_data_backend_write_method_with_delete_operation( @@ -582,6 +599,7 @@ def test_backends_data_mongo_data_backend_write_method_with_delete_operation( binary_documents = [json.dumps(documents[2]).encode("utf8")] assert backend.write(binary_documents, operation_type=BaseOperationType.DELETE) == 1 assert not list(backend.read()) + backend.close() def test_backends_data_mongo_data_backend_write_method_with_delete_operation_failure( @@ -615,6 +633,7 @@ def test_backends_data_mongo_data_backend_write_method_with_delete_operation_fai ) assert ("ralph.backends.data.mongo", logging.WARNING, msg) in caplog.record_tuples + backend.close() def test_backends_data_mongo_data_backend_write_method_with_update_operation( @@ -651,6 +670,7 @@ def test_backends_data_mongo_data_backend_write_method_with_update_operation( "_id": "62b9ce922c26b46b68ffc68f", "_source": {"id": "foo", "new_field": "bar"}, } + backend.close() def test_backends_data_mongo_data_backend_write_method_with_update_operation_failure( @@ -708,6 +728,7 @@ def test_backends_data_mongo_data_backend_write_method_with_update_operation_fai logging.ERROR, exception_info.value.args[0], ) in caplog.record_tuples + backend.close() def test_backends_data_mongo_data_backend_write_method_with_append_operation( @@ -723,6 +744,7 @@ def test_backends_data_mongo_data_backend_write_method_with_append_operation( backend.write(data=[], operation_type=BaseOperationType.APPEND) assert ("ralph.backends.data.mongo", logging.ERROR, msg) in caplog.record_tuples + backend.close() def 
test_backends_data_mongo_data_backend_write_method_with_create_operation( @@ -741,6 +763,7 @@ def test_backends_data_mongo_data_backend_write_method_with_create_operation( results = backend.read() assert next(results)["_source"]["timestamp"] == documents[0]["timestamp"] assert next(results)["_source"]["timestamp"] == documents[1]["timestamp"] + backend.close() @pytest.mark.parametrize( @@ -777,6 +800,7 @@ def test_backends_data_mongo_data_backend_write_method_with_invalid_documents( assert backend.write([document], ignore_errors=True) == 0 assert ("ralph.backends.data.mongo", logging.WARNING, error) in caplog.record_tuples + backend.close() def test_backends_data_mongo_data_backend_write_method_with_unparsable_documents( @@ -801,6 +825,7 @@ def test_backends_data_mongo_data_backend_write_method_with_unparsable_documents assert backend.write([b"not valid JSON!"], ignore_errors=True) == 0 assert ("ralph.backends.data.mongo", logging.WARNING, msg) in caplog.record_tuples + backend.close() def test_backends_data_mongo_data_backend_write_method_with_no_data( @@ -813,6 +838,7 @@ def test_backends_data_mongo_data_backend_write_method_with_no_data( msg = "Data Iterator is empty; skipping write to target." assert ("ralph.backends.data.mongo", logging.INFO, msg) in caplog.record_tuples + backend.close() def test_backends_data_mongo_data_backend_write_method_with_custom_chunk_size( @@ -870,3 +896,32 @@ def test_backends_data_mongo_data_backend_write_method_with_custom_chunk_size( {"_id": "62b9ce92fcde2b2edba56bf4", "_source": {"id": "bar", **new_timestamp}}, {"_id": "62b9ce92baa5a0964d3320fb", "_source": {"id": "baz", **new_timestamp}}, ] + backend.close() + + +def test_backends_data_mongo_data_backend_close_method_with_failure( + mongo_backend, monkeypatch +): + """Test the `MongoDataBackend.close` method.""" + + backend = mongo_backend() + + def mock_connection_error(): + """Mongo client close mock that raises a connection error.""" + raise PyMongoError("", (Exception("Mocked connection error"),)) + + monkeypatch.setattr(backend.client, "close", mock_connection_error) + + with pytest.raises(BackendException, match="Failed to close MongoDB client"): + backend.close() + + +def test_backends_data_mongo_data_backend_close_method(mongo_backend): + """Test the `MongoDataBackend.close` method.""" + + backend = mongo_backend() + + # Still possible to connect to client after closing it, as it creates + # a new connection + backend.close() + assert backend.status() == DataBackendStatus.AWAY diff --git a/tests/backends/data/test_s3.py b/tests/backends/data/test_s3.py index d8bfc4a3a..67ac83953 100644 --- a/tests/backends/data/test_s3.py +++ b/tests/backends/data/test_s3.py @@ -72,12 +72,16 @@ def test_backends_data_s3_data_backend_status_method(s3_backend): # Regions outside of us-east-1 require the appropriate LocationConstraint s3_client = boto3.client("s3", region_name="us-east-1") - assert s3_backend().status() == DataBackendStatus.ERROR + backend = s3_backend() + assert backend.status() == DataBackendStatus.ERROR + backend.close() bucket_name = "bucket_name" s3_client.create_bucket(Bucket=bucket_name) - assert s3_backend().status() == DataBackendStatus.OK + backend = s3_backend() + assert backend.status() == DataBackendStatus.OK + backend.close() @mock_s3 @@ -117,9 +121,9 @@ def test_backends_data_s3_data_backend_list_should_yield_archive_names( {"name": "2022-10-01.gz"}, ] - s3 = s3_backend() + backend = s3_backend() - s3.history.extend( + backend.history.extend( [ {"id": "bucket_name/2022-04-29.gz", 
"backend": "s3", "command": "read"}, {"id": "bucket_name/2022-04-30.gz", "backend": "s3", "command": "read"}, @@ -127,15 +131,16 @@ def test_backends_data_s3_data_backend_list_should_yield_archive_names( ) try: - response_list = s3.list() - response_list_new = s3.list(new=True) - response_list_details = s3.list(details=True) + response_list = backend.list() + response_list_new = backend.list(new=True) + response_list_details = backend.list(details=True) except Exception: # pylint:disable=broad-except pytest.fail("S3 backend should not raise exception on successful list") assert list(response_list) == [x["name"] for x in listing] assert list(response_list_new) == ["2022-10-01.gz"] assert [x["Key"] for x in response_list_details] == [x["name"] for x in listing] + backend.close() @mock_s3 @@ -153,15 +158,16 @@ def test_backends_data_s3_list_on_empty_bucket_should_do_nothing( listing = [] - s3 = s3_backend() + backend = s3_backend() - s3.clean_history(lambda *_: True) + backend.clean_history(lambda *_: True) try: - response_list = s3.list() + response_list = backend.list() except Exception: # pylint:disable=broad-except pytest.fail("S3 backend should not raise exception on successful list") assert list(response_list) == [x["name"] for x in listing] + backend.close() @mock_s3 @@ -184,19 +190,19 @@ def test_backends_data_s3_list_with_failed_connection_should_log_the_error( Body=json.dumps({"id": "1", "foo": "bar"}), ) - s3 = s3_backend() + backend = s3_backend() - s3.clean_history(lambda *_: True) + backend.clean_history(lambda *_: True) msg = "Failed to list the bucket wrong_name: The specified bucket does not exist" with caplog.at_level(logging.ERROR): with pytest.raises(BackendException, match=msg): - next(s3.list(target="wrong_name")) + next(backend.list(target="wrong_name")) with pytest.raises(BackendException, match=msg): - next(s3.list(target="wrong_name", new=True)) + next(backend.list(target="wrong_name", new=True)) with pytest.raises(BackendException, match=msg): - next(s3.list(target="wrong_name", details=True)) + next(backend.list(target="wrong_name", details=True)) assert ( list( @@ -207,6 +213,7 @@ def test_backends_data_s3_list_with_failed_connection_should_log_the_error( ) == [("ralph.backends.data.s3", logging.ERROR, msg)] * 3 ) + backend.close() @mock_s3 @@ -242,11 +249,11 @@ def test_backends_data_s3_read_with_valid_name_should_write_to_history( freezed_now = datetime.datetime.now(tz=datetime.timezone.utc).isoformat() monkeypatch.setattr("ralph.backends.data.s3.now", lambda: freezed_now) - s3 = s3_backend() - s3.clean_history(lambda *_: True) + backend = s3_backend() + backend.clean_history(lambda *_: True) list( - s3.read( + backend.read( query="2022-09-29.gz", target=bucket_name, chunk_size=1000, @@ -260,10 +267,10 @@ def test_backends_data_s3_read_with_valid_name_should_write_to_history( "id": f"{bucket_name}/2022-09-29.gz", "size": len(raw_body), "timestamp": freezed_now, - } in s3.history + } in backend.history list( - s3.read( + backend.read( query="2022-09-30.gz", raw_output=False, ) @@ -275,7 +282,8 @@ def test_backends_data_s3_read_with_valid_name_should_write_to_history( "id": f"{bucket_name}/2022-09-30.gz", "size": len(json_body), "timestamp": freezed_now, - } in s3.history + } in backend.history + backend.close() @mock_s3 @@ -302,8 +310,8 @@ def test_backends_data_s3_read_with_invalid_output_should_log_the_error( with caplog.at_level(logging.ERROR): with pytest.raises(BackendException): - s3 = s3_backend() - list(s3.read(query="2022-09-29.gz", raw_output=False)) + 
backend = s3_backend() + list(backend.read(query="2022-09-29.gz", raw_output=False)) assert ( "ralph.backends.data.s3", @@ -311,7 +319,8 @@ def test_backends_data_s3_read_with_invalid_output_should_log_the_error( "Raised error: Expecting value: line 1 column 1 (char 0)", ) in caplog.record_tuples - s3.clean_history(lambda *_: True) + backend.clean_history(lambda *_: True) + backend.close() @mock_s3 @@ -339,8 +348,8 @@ def test_backends_data_s3_read_with_invalid_name_should_log_the_error( with caplog.at_level(logging.ERROR): with pytest.raises(BackendParameterException): - s3 = s3_backend() - list(s3.read(query=None, target=bucket_name)) + backend = s3_backend() + list(backend.read(query=None, target=bucket_name)) assert ( "ralph.backends.data.s3", @@ -348,7 +357,8 @@ def test_backends_data_s3_read_with_invalid_name_should_log_the_error( "Invalid query. The query should be a valid object name.", ) in caplog.record_tuples - s3.clean_history(lambda *_: True) + backend.clean_history(lambda *_: True) + backend.close() @mock_s3 @@ -376,9 +386,9 @@ def test_backends_data_s3_read_with_wrong_name_should_log_the_error( with caplog.at_level(logging.ERROR): with pytest.raises(BackendException): - s3 = s3_backend() - s3.clean_history(lambda *_: True) - list(s3.read(query="invalid_name.gz", target=bucket_name)) + backend = s3_backend() + backend.clean_history(lambda *_: True) + list(backend.read(query="invalid_name.gz", target=bucket_name)) assert ( "ralph.backends.data.s3", @@ -386,7 +396,8 @@ def test_backends_data_s3_read_with_wrong_name_should_log_the_error( "Failed to download invalid_name.gz: The specified key does not exist.", ) in caplog.record_tuples - assert s3.history == [] + assert backend.history == [] + backend.close() @mock_s3 @@ -418,17 +429,18 @@ def mock_read_raw(*args, **kwargs): with caplog.at_level(logging.ERROR): with pytest.raises(BackendException): - s3 = s3_backend() - monkeypatch.setattr(s3, "_read_raw", mock_read_raw) - s3.clean_history(lambda *_: True) - list(s3.read(query=object_name, target=bucket_name, raw_output=True)) + backend = s3_backend() + monkeypatch.setattr(backend, "_read_raw", mock_read_raw) + backend.clean_history(lambda *_: True) + list(backend.read(query=object_name, target=bucket_name, raw_output=True)) assert ( "ralph.backends.data.s3", logging.ERROR, f"Failed to read chunk from object {object_name}", ) in caplog.record_tuples - assert s3.history == [] + assert backend.history == [] + backend.close() @pytest.mark.parametrize( @@ -462,9 +474,9 @@ def test_backends_data_s3_write_method_with_parameter_error( with caplog.at_level(logging.ERROR): with pytest.raises(BackendException): - s3 = s3_backend() - s3.clean_history(lambda *_: True) - s3.write( + backend = s3_backend() + backend.clean_history(lambda *_: True) + backend.write( data=some_content, target=object_name, operation_type=operation_type ) @@ -474,7 +486,8 @@ def test_backends_data_s3_write_method_with_parameter_error( ) assert ("ralph.backends.data.s3", logging.ERROR, msg) in caplog.record_tuples - assert s3.history == [] + assert backend.history == [] + backend.close() @pytest.mark.parametrize( @@ -494,6 +507,7 @@ def test_backends_data_s3_data_backend_write_method_with_append_or_delete_operat match=f"{operation_type.name} operation_type is not allowed.", ): backend.write(data=[b"foo"], operation_type=operation_type) + backend.close() @pytest.mark.parametrize( @@ -520,10 +534,10 @@ def test_backends_data_s3_write_method_with_create_index_operation( object_name = "new-archive.gz" some_content = 
b"some contents in the stream file to upload" data = [some_content, some_content, some_content] - s3 = s3_backend() - s3.clean_history(lambda *_: True) + backend = s3_backend() + backend.clean_history(lambda *_: True) - response = s3.write( + response = backend.write( data=data, target=object_name, operation_type=operation_type, @@ -537,13 +551,13 @@ def test_backends_data_s3_write_method_with_create_index_operation( "id": f"{bucket_name}/{object_name}", "size": len(some_content) * 3, "timestamp": freezed_now, - } in s3.history + } in backend.history object_name = "new-archive2.gz" other_content = {"some": "content"} data = [other_content, other_content] - response = s3.write( + response = backend.write( data=data, target=object_name, operation_type=operation_type, @@ -557,9 +571,9 @@ def test_backends_data_s3_write_method_with_create_index_operation( "id": f"{bucket_name}/{object_name}", "size": len(bytes(f"{json.dumps(other_content)}\n", encoding="utf8")) * 2, "timestamp": freezed_now, - } in s3.history + } in backend.history - assert list(s3.read(query=object_name, raw_output=False)) == data + assert list(backend.read(query=object_name, raw_output=False)) == data object_name = "new-archive3.gz" date = datetime.datetime(2023, 6, 30, 8, 42, 15, 554892) @@ -571,7 +585,7 @@ def test_backends_data_s3_write_method_with_create_index_operation( with caplog.at_level(logging.ERROR): # Without ignoring error with pytest.raises(BackendException, match=error): - response = s3.write( + response = backend.write( data=data, target=object_name, operation_type=operation_type, @@ -579,7 +593,7 @@ def test_backends_data_s3_write_method_with_create_index_operation( ) # Ignoring error - response = s3.write( + response = backend.write( data=data, target=object_name, operation_type=operation_type, @@ -601,6 +615,7 @@ def test_backends_data_s3_write_method_with_create_index_operation( ] * 2 ) + backend.close() @mock_s3 @@ -618,13 +633,14 @@ def test_backends_data_s3_write_method_with_no_data_should_skip( object_name = "new-archive.gz" - s3 = s3_backend() - response = s3.write( + backend = s3_backend() + response = backend.write( data=[], target=object_name, operation_type=BaseOperationType.CREATE, ) assert response == 0 + backend.close() @mock_s3 @@ -647,12 +663,47 @@ def test_backends_data_s3_write_method_with_failure_should_log_the_error( def raise_client_error(*args, **kwargs): raise ClientError({"Error": {}}, "error") - s3 = s3_backend() - s3.client.put_object = raise_client_error + backend = s3_backend() + backend.client.put_object = raise_client_error with pytest.raises(BackendException, match=error): - s3.write( + backend.write( data=[body], target=object_name, operation_type=BaseOperationType.CREATE, ) + backend.close() + + +def test_backends_data_s3_data_backend_close_method_with_failure( + s3_backend, monkeypatch +): + """Test the `S3DataBackend.close` method.""" + + backend = s3_backend() + + def mock_connection_error(): + """S3 backend client close mock that raises a connection error.""" + raise ClientError({"Error": {}}, "error") + + monkeypatch.setattr(backend.client, "close", mock_connection_error) + + with pytest.raises(BackendException, match="Failed to close S3 backend client"): + backend.close() + + +@mock_s3 +def test_backends_data_s3_data_backend_close_method(s3_backend, caplog): + """Test the `S3DataBackend.close` method.""" + + # No client instantiated + backend = s3_backend() + backend._client = None # pylint: disable=protected-access + with caplog.at_level(logging.WARNING): + 
backend.close() + + assert ( + "ralph.backends.data.s3", + logging.WARNING, + "No backend client to close.", + ) in caplog.record_tuples diff --git a/tests/backends/data/test_swift.py b/tests/backends/data/test_swift.py index c37fb6045..f0f8fa67b 100644 --- a/tests/backends/data/test_swift.py +++ b/tests/backends/data/test_swift.py @@ -51,6 +51,7 @@ def test_backends_data_swift_data_backend_default_instantiation(monkeypatch, fs) assert backend.options["user_domain_name"] == "Default" assert backend.default_container is None assert backend.locale_encoding == "utf8" + backend.close() def test_backends_data_swift_data_backend_instantiation_with_settings(fs): @@ -85,6 +86,7 @@ def test_backends_data_swift_data_backend_instantiation_with_settings(fs): SwiftDataBackend(settings_) except Exception as err: # pylint:disable=broad-except pytest.fail(f"Two SwiftDataBackends should not raise exceptions: {err}") + backend.close() def test_backends_data_swift_data_backend_status_method_with_error_status( @@ -101,17 +103,18 @@ def mock_failed_head_account(*args, **kwargs): # pylint:disable=unused-argument raise ClientException(error) - swift = swift_backend() - monkeypatch.setattr(swift.connection, "head_account", mock_failed_head_account) + backend = swift_backend() + monkeypatch.setattr(backend.connection, "head_account", mock_failed_head_account) with caplog.at_level(logging.ERROR): - assert swift.status() == DataBackendStatus.ERROR + assert backend.status() == DataBackendStatus.ERROR assert ( "ralph.backends.data.swift", logging.ERROR, f"Unable to connect to the Swift account: {error}", ) in caplog.record_tuples + backend.close() def test_backends_data_swift_data_backend_status_method_with_ok_status( @@ -124,13 +127,16 @@ def test_backends_data_swift_data_backend_status_method_with_ok_status( def mock_successful_head_account(*args, **kwargs): # pylint:disable=unused-argument return 1 - swift = swift_backend() - monkeypatch.setattr(swift.connection, "head_account", mock_successful_head_account) + backend = swift_backend() + monkeypatch.setattr( + backend.connection, "head_account", mock_successful_head_account + ) with caplog.at_level(logging.ERROR): - assert swift.status() == DataBackendStatus.OK + assert backend.status() == DataBackendStatus.OK assert caplog.record_tuples == [] + backend.close() def test_backends_data_swift_data_backend_list_method( @@ -188,6 +194,7 @@ def mock_head_object(container, obj): # pylint:disable=unused-argument assert list(backend.list()) == [x["name"] for x in listing] assert list(backend.list(new=True)) == ["2020-05-01.gz"] assert list(backend.list(details=True)) == listing + backend.close() def test_backends_data_swift_data_backend_list_with_failed_details( @@ -226,6 +233,7 @@ def mock_head_object(*args, **kwargs): # pylint:disable=unused-argument next(backend.list(details=True)) assert ("ralph.backends.data.swift", logging.ERROR, msg) in caplog.record_tuples + backend.close() def test_backends_data_swift_data_backend_list_with_failed_connection( @@ -254,6 +262,7 @@ def mock_get_container(*args, **kwargs): # pylint:disable=unused-argument next(backend.list(details=True)) assert ("ralph.backends.data.swift", logging.ERROR, msg) in caplog.record_tuples + backend.close() def test_backends_data_swift_data_backend_read_method_with_raw_output( @@ -314,6 +323,7 @@ def mock_get_object(*args, **kwargs): # pylint:disable=unused-argument "timestamp": frozen_now, }, ] + backend.close() def test_backends_data_swift_data_backend_read_method_without_raw_output( @@ -352,6 +362,7 @@ 
def mock_get_object(*args, **kwargs): # pylint:disable=unused-argument "timestamp": frozen_now, } ] + backend.close() def test_backends_data_swift_data_backend_read_method_with_invalid_query(swift_backend): @@ -363,6 +374,7 @@ def test_backends_data_swift_data_backend_read_method_with_invalid_query(swift_b error = "Invalid query. The query should be a valid archive name" with pytest.raises(BackendParameterException, match=error): list(backend.read()) + backend.close() def test_backends_data_swift_data_backend_read_method_with_ignore_errors( @@ -409,6 +421,7 @@ def mock_get_object_2(*args, **kwargs): # pylint:disable=unused-argument result = backend.read(ignore_errors=True, query="2020-06-02.gz") assert isinstance(result, Iterable) assert list(result) == [valid_dictionary] + backend.close() def test_backends_data_swift_data_backend_read_method_without_ignore_errors( @@ -465,6 +478,7 @@ def mock_get_object_2(*args, **kwargs): # pylint:disable=unused-argument assert isinstance(result, Iterable) with pytest.raises(BackendException, match="Raised error:"): next(result) + backend.close() def test_backends_data_swift_data_backend_read_method_with_failed_connection( @@ -488,6 +502,7 @@ def mock_failed_get_object(*args, **kwargs): # pylint:disable=unused-argument next(result) assert ("ralph.backends.data.swift", logging.ERROR, msg) in caplog.record_tuples + backend.close() @pytest.mark.parametrize( @@ -521,6 +536,7 @@ def mock_get_container(*args, **kwargs): # pylint:disable=unused-argument # When the `write` method fails, then no entry should be added to the history. assert not sorted(backend.history, key=itemgetter("id")) + backend.close() def test_backends_data_swift_data_backend_write_method_with_failed_connection( @@ -553,6 +569,7 @@ def mock_head_object(*args, **kwargs): # pylint:disable=unused-argument # When the `write` method fails, then no entry should be added to the history. assert not sorted(backend.history, key=itemgetter("id")) + backend.close() @pytest.mark.parametrize( @@ -582,6 +599,7 @@ def test_backends_data_swift_data_backend_write_method_with_invalid_operation( # When the `write` method fails, then no entry should be added to the history. 
assert not sorted(backend.history, key=itemgetter("id")) + backend.close() def test_backends_data_swift_data_backend_write_method_without_target( @@ -638,3 +656,43 @@ def mock_head_object(*args, **kwargs): # pylint:disable=unused-argument "timestamp": frozen_now, } ] + backend.close() + + +def test_backends_data_swift_data_backend_close_method_with_failure( + swift_backend, monkeypatch +): + """Test the `SwiftDataBackend.close` method.""" + + backend = swift_backend() + + def mock_connection_error(): + """Swift backend connection close mock that raises a connection error.""" + raise ClientException({"Error": {}}, "error") + + monkeypatch.setattr(backend.connection, "close", mock_connection_error) + + with pytest.raises(BackendException, match="Failed to close Swift backend client"): + backend.close() + + +def test_backends_data_swift_data_backend_close_method(swift_backend, caplog): + """Test the `SwiftDataBackend.close` method.""" + + backend = swift_backend() + + # Not possible to connect to client after closing it + backend.close() + assert backend.status() == DataBackendStatus.ERROR + + # No client instantiated + backend = swift_backend() + backend._connection = None # pylint: disable=protected-access + with caplog.at_level(logging.WARNING): + backend.close() + + assert ( + "ralph.backends.data.swift", + logging.WARNING, + "No backend client to close.", + ) in caplog.record_tuples diff --git a/tests/backends/lrs/test_clickhouse.py b/tests/backends/lrs/test_clickhouse.py index d5bd79e9f..7b44246ac 100644 --- a/tests/backends/lrs/test_clickhouse.py +++ b/tests/backends/lrs/test_clickhouse.py @@ -218,6 +218,7 @@ def mock_read(query, target, ignore_errors): monkeypatch.setattr(backend, "read", mock_read) backend.query_statements(StatementParameters(**params)) + backend.close() def test_backends_lrs_clickhouse_lrs_backend_query_statements( @@ -251,6 +252,7 @@ def test_backends_lrs_clickhouse_lrs_backend_query_statements( StatementParameters(statementId=test_id, limit=10) ) assert result.statements == statements + backend.close() def test_backends_lrs_clickhouse_lrs_backend__find(clickhouse, clickhouse_lrs_backend): @@ -279,6 +281,7 @@ def test_backends_lrs_clickhouse_lrs_backend__find(clickhouse, clickhouse_lrs_ba # Check the expected search query results. result = backend.query_statements(StatementParameters()) assert result.statements == statements + backend.close() def test_backends_lrs_clickhouse_lrs_backend_query_statements_by_ids( @@ -310,6 +313,7 @@ def test_backends_lrs_clickhouse_lrs_backend_query_statements_by_ids( # Check the expected search query results. 
result = list(backend.query_statements_by_ids([test_id])) assert result[0]["event"] == statements[0] + backend.close() def test_backends_lrs_clickhouse_lrs_backend_query_statements_client_failure( @@ -338,6 +342,7 @@ def mock_query(*args, **kwargs): logging.ERROR, "Failed to read from ClickHouse", ) in caplog.record_tuples + backend.close() def test_backends_lrs_clickhouse_lrs_backend_query_statements_by_ids_client_failure( @@ -366,3 +371,4 @@ def mock_query(*args, **kwargs): logging.ERROR, "Failed to read from ClickHouse", ) in caplog.record_tuples + backend.close() diff --git a/tests/backends/lrs/test_es.py b/tests/backends/lrs/test_es.py index 89dbf5b45..91bb56f31 100644 --- a/tests/backends/lrs/test_es.py +++ b/tests/backends/lrs/test_es.py @@ -280,6 +280,8 @@ def mock_read(query, chunk_size): assert result.pit_id == "foo_pit_id" assert result.search_after == "bar_search_after|baz_search_after" + backend.close() + def test_backends_lrs_es_lrs_backend_query_statements(es, es_lrs_backend): """Test the `ESLRSBackend.query_statements` method, given a query, @@ -297,6 +299,8 @@ def test_backends_lrs_es_lrs_backend_query_statements(es, es_lrs_backend): assert result.statements == documents assert re.match(r"[0-9]+\|0", result.search_after) + backend.close() + def test_backends_lrs_es_lrs_backend_query_statements_with_search_query_failure( es, es_lrs_backend, monkeypatch, caplog @@ -324,6 +328,8 @@ def mock_read(**_): "Failed to read from Elasticsearch", ) in caplog.record_tuples + backend.close() + def test_backends_lrs_es_lrs_backend_query_statements_by_ids_with_search_query_failure( es, es_lrs_backend, monkeypatch, caplog @@ -351,6 +357,8 @@ def mock_search(**_): "Failed to read from Elasticsearch", ) in caplog.record_tuples + backend.close() + def test_backends_lrs_es_lrs_backend_query_statements_by_ids_with_multiple_indexes( es, es_forwarding, es_lrs_backend @@ -387,3 +395,6 @@ def test_backends_lrs_es_lrs_backend_query_statements_by_ids_with_multiple_index assert not list(backend_1.query_statements_by_ids(["2"])) assert not list(backend_2.query_statements_by_ids(["1"])) assert list(backend_2.query_statements_by_ids(["2"])) == [index_2_document] + + backend_1.close() + backend_2.close() diff --git a/tests/backends/lrs/test_mongo.py b/tests/backends/lrs/test_mongo.py index 2effe53d2..85edc3f0d 100644 --- a/tests/backends/lrs/test_mongo.py +++ b/tests/backends/lrs/test_mongo.py @@ -242,6 +242,7 @@ def mock_read(query, chunk_size): assert result.statements == [{}] assert not result.pit_id assert result.search_after == "search_after_id" + backend.close() def test_backends_lrs_mongo_lrs_backend_query_statements_with_success( @@ -285,6 +286,7 @@ def test_backends_lrs_mongo_lrs_backend_query_statements_with_success( assert statement_query_result.statements == [ {"id": "62b9ce922c26b46b68ffc68f", **timestamp, **meta} ] + backend.close() def test_backends_lrs_mongo_lrs_backend_query_statements_with_query_failure( @@ -314,6 +316,7 @@ def mock_read(**_): logging.ERROR, "Failed to read from MongoDB", ) in caplog.record_tuples + backend.close() def test_backends_lrs_mongo_lrs_backend_query_statements_by_ids_with_query_failure( @@ -343,6 +346,7 @@ def mock_read(**_): logging.ERROR, "Failed to read from MongoDB", ) in caplog.record_tuples + backend.close() def test_backends_lrs_mongo_lrs_backend_query_statements_by_ids_with_two_collections( @@ -368,3 +372,5 @@ def test_backends_lrs_mongo_lrs_backend_query_statements_by_ids_with_two_collect assert not list(backend_1.query_statements_by_ids(["2"])) 
assert not list(backend_2.query_statements_by_ids(["1"])) assert list(backend_2.query_statements_by_ids(["2"])) == [{"id": "2", **timestamp}] + backend_1.close() + backend_2.close() From 3ba6ddc99c0e45489ab2c004ef2db02fb48977af Mon Sep 17 00:00:00 2001 From: Wilfried BARADAT Date: Mon, 28 Aug 2023 09:58:09 +0200 Subject: [PATCH 22/65] =?UTF-8?q?=F0=9F=90=9B(backends)=20fix=20bug=20iter?= =?UTF-8?q?ating=20over=20async=20mongo=20collections?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The update to a recent version of `motor` highlighted a bug on our side when listing collections. Now asynchronously iterate over collections list. --- setup.cfg | 2 +- src/ralph/backends/data/async_mongo.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index bef3badc5..7885a82e4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,7 +51,7 @@ backend-lrs = httpx<0.25.0 # pin as Python 3.7 is no longer supported from release 0.25.0 more-itertools==10.1.0 backend-mongo = - motor[srv]>=3.1.1 + motor[srv]>=3.3.0 pymongo[srv]>=4.0.0 python-dateutil>=2.8.2 backend-s3 = diff --git a/src/ralph/backends/data/async_mongo.py b/src/ralph/backends/data/async_mongo.py index a13d54324..7b332f311 100644 --- a/src/ralph/backends/data/async_mongo.py +++ b/src/ralph/backends/data/async_mongo.py @@ -100,7 +100,7 @@ async def list( try: collections = await database.list_collections() - for collection_info in collections: + async for collection_info in collections: if details: yield collection_info else: From 81913183ee525c31825cfbc0a540e67f8f8a988f Mon Sep 17 00:00:00 2001 From: Wilfried BARADAT Date: Thu, 17 Aug 2023 09:55:06 +0200 Subject: [PATCH 23/65] =?UTF-8?q?=F0=9F=8F=97=EF=B8=8F(backends)=20unify?= =?UTF-8?q?=20stream=20backend=20settings?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit With the new data backend interface, settings are now close to each backend and not under general conf.py. Unifying stream backend WS to have the same architecture as data backends. 
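As an illustration of the unified pattern, the WS backend's settings
are now resolved from `RALPH_BACKENDS__STREAM__WS__`-prefixed
environment variables, or passed explicitly as a settings instance. A
minimal usage sketch, assuming the `WSStreamBackend` and
`WSStreamBackendSettings` classes introduced below:

    import os

    from ralph.backends.stream.ws import WSStreamBackend, WSStreamBackendSettings

    # Settings default to values resolved from the backend's environment
    # variable prefix at instantiation time.
    os.environ["RALPH_BACKENDS__STREAM__WS__URI"] = "ws://localhost:8765"
    backend = WSStreamBackend()
    assert backend.settings.URI == "ws://localhost:8765"

    # An explicit settings instance takes precedence over the environment.
    backend = WSStreamBackend(WSStreamBackendSettings(URI="ws://example.com/ws"))
    assert backend.settings.URI == "ws://example.com/ws"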
--- src/ralph/backends/stream/__init__.py | 5 +--- src/ralph/backends/stream/base.py | 18 ++++++++++++- src/ralph/backends/stream/ws.py | 33 ++++++++++++++++++------ tests/backends/stream/test_ws.py | 37 +++++++++++++++------------ 4 files changed, 64 insertions(+), 29 deletions(-) diff --git a/src/ralph/backends/stream/__init__.py b/src/ralph/backends/stream/__init__.py index e707dbdfc..6e031999e 100644 --- a/src/ralph/backends/stream/__init__.py +++ b/src/ralph/backends/stream/__init__.py @@ -1,4 +1 @@ -"""Stream backends for Ralph.""" - -from .base import BaseStream # noqa: F401 -from .ws import WSStream # noqa: F401 +# noqa: D104 diff --git a/src/ralph/backends/stream/base.py b/src/ralph/backends/stream/base.py index d063d01f2..6d3a3addc 100644 --- a/src/ralph/backends/stream/base.py +++ b/src/ralph/backends/stream/base.py @@ -3,12 +3,28 @@ from abc import ABC, abstractmethod from typing import BinaryIO +from pydantic import BaseSettings + +from ralph.conf import BaseSettingsConfig, core_settings + + +class BaseStreamBackendSettings(BaseSettings): + """Data backend default configuration.""" + + class Config(BaseSettingsConfig): + """Pydantic Configuration.""" + + env_prefix = "RALPH_BACKENDS__STREAM__" + env_file = ".env" + env_file_encoding = core_settings.LOCALE_ENCODING + class BaseStream(ABC): """Base stream backend interface.""" name = "base" + settings_class = BaseStreamBackendSettings @abstractmethod def stream(self, target: BinaryIO): - """Read records and streams them to target.""" + """Read records and stream them to target.""" diff --git a/src/ralph/backends/stream/ws.py b/src/ralph/backends/stream/ws.py index 6893a1f97..0cad5b029 100644 --- a/src/ralph/backends/stream/ws.py +++ b/src/ralph/backends/stream/ws.py @@ -6,34 +6,51 @@ import websockets -from ralph.conf import settings +from ralph.conf import BaseSettingsConfig -from .base import BaseStream +from .base import BaseStreamBackend, BaseStreamBackendSettings logger = logging.getLogger(__name__) -class WSStream(BaseStream): +class WSStreamBackendSettings(BaseStreamBackendSettings): + """Websocket stream backend default configuration. + + Attributes: + URI (str): The URI to connect to. + """ + + class Config(BaseSettingsConfig): + """Pydantic Configuration.""" + + env_prefix = "RALPH_BACKENDS__STREAM__WS__" + + URI: str = None + + +class WSStreamBackend(BaseStreamBackend): """Websocket stream backend.""" name = "ws" + settings_class = WSStreamBackendSettings - def __init__(self, uri: str = settings.BACKENDS.STREAM.WS.URI): + def __init__(self, settings: settings_class = None): """Instantiate the websocket client. Args: - uri (str): The URI to connect to. + settings (WSStreamBackendSettings or None): The stream backend settings. + If `settings` is `None`, a default settings instance is used instead. 
""" - self.uri = uri + self.settings = settings if settings else self.settings_class() def stream(self, target: BinaryIO): """Stream websocket content to target.""" # pylint: disable=no-member - logger.debug("Streaming from websocket uri: %s", self.uri) + logger.debug("Streaming from websocket uri: %s", self.settings.URI) async def _stream(): - async with websockets.connect(self.uri) as websocket: + async with websockets.connect(self.settings.URI) as websocket: while event := await websocket.recv(): target.write(bytes(f"{event}" + "\n", encoding="utf-8")) diff --git a/tests/backends/stream/test_ws.py b/tests/backends/stream/test_ws.py index dc4b5c462..ebb143b77 100644 --- a/tests/backends/stream/test_ws.py +++ b/tests/backends/stream/test_ws.py @@ -5,30 +5,35 @@ import websockets -from ralph.backends.stream.ws import WSStream -from ralph.conf import settings +from ralph.backends.stream.ws import WSStreamBackend, WSStreamBackendSettings from tests.fixtures.backends import WS_TEST_HOST, WS_TEST_PORT -def test_backends_stream_ws_stream_instantiation(ws): - """Test the WSStream backend instantiation.""" +def test_backends_stream_ws_stream_default_instantiation(monkeypatch, fs): + """Test the `WSStreamBackend` instantiation.""" # pylint: disable=invalid-name,unused-argument + fs.create_file(".env") + backend_settings_names = ["URI"] + for name in backend_settings_names: + monkeypatch.delenv(f"RALPH_BACKENDS__STREAM__WS__{name}", raising=False) - assert WSStream.name == "ws" - - assert WSStream().uri == settings.BACKENDS.STREAM.WS.URI + assert WSStreamBackend.name == "ws" + assert WSStreamBackend.settings_class == WSStreamBackendSettings + backend = WSStreamBackend() + assert not backend.settings.URI uri = f"ws://{WS_TEST_HOST}:{WS_TEST_PORT}" - client = WSStream(uri) - assert client.uri == uri + backend = WSStreamBackend(WSStreamBackendSettings(URI=uri)) + assert backend.settings.URI == uri def test_backends_stream_ws_stream_stream(ws, monkeypatch, events): - """Test the WSStream backend stream method.""" + """Test the `WSStreamBackend` stream method.""" # pylint: disable=invalid-name,unused-argument + settings = WSStreamBackendSettings(URI=f"ws://{WS_TEST_HOST}:{WS_TEST_PORT}") - client = WSStream(f"ws://{WS_TEST_HOST}:{WS_TEST_PORT}") + backend = WSStreamBackend(settings) # Mock stdout stream class MockStdout: @@ -39,7 +44,7 @@ class MockStdout: mock_stdout = MockStdout() try: - client.stream(mock_stdout.buffer) + backend.stream(mock_stdout.buffer) except websockets.exceptions.ConnectionClosedOK: pass @@ -49,10 +54,10 @@ class MockStdout: def test_backends_stream_ws_stream_stream_when_server_stops(ws, monkeypatch, events): - """Test the WSStream backend stream method when the websocket server stops.""" + """Test the WSStreamBackend stream method when the websocket server stops.""" # pylint: disable=invalid-name,unused-argument - - client = WSStream(f"ws://{WS_TEST_HOST}:{WS_TEST_PORT}") + settings = WSStreamBackendSettings(URI=f"ws://{WS_TEST_HOST}:{WS_TEST_PORT}") + backend = WSStreamBackend(settings) # Mock stdout stream class MockStdout: @@ -63,7 +68,7 @@ class MockStdout: mock_stdout = MockStdout() try: - client.stream(mock_stdout.buffer) + backend.stream(mock_stdout.buffer) except websockets.exceptions.ConnectionClosedOK: pass From 8f5d686b17c91c5a9d4ca8975adc010da77176cc Mon Sep 17 00:00:00 2001 From: Wilfried BARADAT Date: Thu, 17 Aug 2023 11:28:03 +0200 Subject: [PATCH 24/65] =?UTF-8?q?=F0=9F=8F=97=EF=B8=8F(backends)=20unify?= =?UTF-8?q?=20http=20backend=20settings?= MIME-Version: 1.0 
Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit With the new data backend interface, settings are now close to each backend and not under general conf.py. Unifying HTTP backends to have the same architecture as data backends. --- .env.dist | 146 ++++++++++------- src/ralph/backends/http/__init__.py | 6 +- src/ralph/backends/http/async_lrs.py | 78 ++++++---- src/ralph/backends/http/base.py | 16 +- src/ralph/backends/http/lrs.py | 6 +- tests/backends/http/test_async_lrs.py | 216 +++++++++++++++++++------- tests/backends/http/test_base.py | 4 +- tests/backends/http/test_lrs.py | 97 +++++++++--- 8 files changed, 395 insertions(+), 174 deletions(-) diff --git a/.env.dist b/.env.dist index e2be3463c..72911d339 100644 --- a/.env.dist +++ b/.env.dist @@ -13,63 +13,95 @@ RALPH_APP_DIR=/app/.ralph # define them for convenience purpose during development, but they can be # passed as CLI options. -# RALPH_BACKENDS__STORAGE__LDP__ENDPOINT= -# RALPH_BACKENDS__STORAGE__LDP__APPLICATION_KEY= -# RALPH_BACKENDS__STORAGE__LDP__APPLICATION_SECRET= -# RALPH_BACKENDS__STORAGE__LDP__CONSUMER_KEY= -# RALPH_BACKENDS__STORAGE__LDP__SERVICE_NAME= -# RALPH_BACKENDS__STORAGE__LDP__STREAM_ID= - -# Swift storage backend - -# RALPH_BACKENDS__STORAGE__SWIFT__OS_AUTH_URL=http://swift:35357/v3/ -# RALPH_BACKENDS__STORAGE__SWIFT__OS_IDENTITY_API_VERSION=3 -# RALPH_BACKENDS__STORAGE__SWIFT__OS_USER_DOMAIN_NAME=Default -# RALPH_BACKENDS__STORAGE__SWIFT__OS_PROJECT_DOMAIN_NAME=Default -# RALPH_BACKENDS__STORAGE__SWIFT__OS_TENANT_ID=cd238e84310a46e58af7f1d515887d88 -# RALPH_BACKENDS__STORAGE__SWIFT__OS_TENANT_NAME=RegionOne -# RALPH_BACKENDS__STORAGE__SWIFT__OS_USERNAME=demo -# RALPH_BACKENDS__STORAGE__SWIFT__OS_PASSWORD=demo -# RALPH_BACKENDS__STORAGE__SWIFT__OS_REGION_NAME=RegionOne -# RALPH_BACKENDS__STORAGE__SWIFT__OS_STORAGE_URL=http://swift:8080/v1/KEY_cd238e84310a46e58af7f1d515887d88/test_container - -# S3 storage backend - -# RALPH_BACKENDS__STORAGE__S3__ACCESS_KEY_ID= -# RALPH_BACKENDS__STORAGE__S3__SECRET_ACCESS_KEY= -# RALPH_BACKENDS__STORAGE__S3__SESSION_TOKEN= -# RALPH_BACKENDS__STORAGE__S3__DEFAULT_REGION= -# RALPH_BACKENDS__STORAGE__S3__BUCKET_NAME= -# RALPH_BACKENDS__STORAGE__S3__ENDPOINT_URL= - -# ES database backend - -RALPH_BACKENDS__DATABASE__ES__HOSTS=http://elasticsearch:9200 -RALPH_BACKENDS__DATABASE__ES__INDEX=statements -RALPH_BACKENDS__DATABASE__ES__TEST_HOSTS=http://elasticsearch:9200 -RALPH_BACKENDS__DATABASE__ES__TEST_INDEX=test-index-foo -RALPH_BACKENDS__DATABASE__ES__TEST_FORWARDING_INDEX=test-index-foo-2 - -# MONGO database backend - -RALPH_BACKENDS__DATABASE__MONGO__COLLECTION=foo -RALPH_BACKENDS__DATABASE__MONGO__DATABASE=statements -RALPH_BACKENDS__DATABASE__MONGO__CONNECTION_URI=mongodb://mongo:27017/ -RALPH_BACKENDS__DATABASE__MONGO__TEST_COLLECTION=foo -RALPH_BACKENDS__DATABASE__MONGO__TEST_FORWARDING_COLLECTION=foo-2 -RALPH_BACKENDS__DATABASE__MONGO__TEST_DATABASE=statements -RALPH_BACKENDS__DATABASE__MONGO__TEST_CONNECTION_URI=mongodb://mongo:27017/ - -# ClickHouse database backend - -RALPH_BACKENDS__DATABASE__CLICKHOUSE__HOST=clickhouse -RALPH_BACKENDS__DATABASE__CLICKHOUSE__PORT=8123 -RALPH_BACKENDS__DATABASE__CLICKHOUSE__XAPI_DATABASE=xapi -RALPH_BACKENDS__DATABASE__CLICKHOUSE__EVENT_TABLE_NAME=xapi_events_all -RALPH_BACKENDS__DATABASE__CLICKHOUSE__TEST_DATABASE=test_statements -RALPH_BACKENDS__DATABASE__CLICKHOUSE__TEST_HOST=clickhouse -RALPH_BACKENDS__DATABASE__CLICKHOUSE__TEST_PORT=8123 
-RALPH_BACKENDS__DATABASE__CLICKHOUSE__TEST_TABLE_NAME=test_xapi_events_all +# RALPH_BACKENDS__DATA__LDP__APPLICATION_KEY= +# RALPH_BACKENDS__DATA__LDP__APPLICATION_SECRET= +# RALPH_BACKENDS__DATA__LDP__CONSUMER_KEY= +# RALPH_BACKENDS__DATA__LDP__DEFAULT_STREAM_ID= +# RALPH_BACKENDS__DATA__LDP__ENDPOINT= +# RALPH_BACKENDS__DATA__LDP__REQUEST_TIMEOUT= +# RALPH_BACKENDS__DATA__LDP__SERVICE_NAME= + +# Swift data backend + +# RALPH_BACKENDS__DATA__SWIFT__AUTH_URL=http://swift:35357/v3/ +# RALPH_BACKENDS__DATA__SWIFT__USERNAME=demo +# RALPH_BACKENDS__DATA__SWIFT__PASSWORD=demo +# RALPH_BACKENDS__DATA__SWIFT__IDENTITY_API_VERSION=3 +# RALPH_BACKENDS__DATA__SWIFT__TENANT_ID=cd238e84310a46e58af7f1d515887d88 +# RALPH_BACKENDS__DATA__SWIFT__TENANT_NAME=RegionOne +# RALPH_BACKENDS__DATA__SWIFT__PROJECT_DOMAIN_NAME=Default +# RALPH_BACKENDS__DATA__SWIFT__REGION_NAME=RegionOne +# RALPH_BACKENDS__DATA__SWIFT__OBJECT_STORAGE_URL=http://swift:8080/v1/KEY_cd238e84310a46e58af7f1d515887d88/test_container +# RALPH_BACKENDS__DATA__SWIFT__USER_DOMAIN_NAME=Default +# RALPH_BACKENDS__DATA__SWIFT__DEFAULT_CONTAINER= +# RALPH_BACKENDS__DATA__SWIFT__LOCALE_ENCODING=Default + +# S3 data backend + +# RALPH_BACKENDS__DATA__S3__ACCESS_KEY_ID= +# RALPH_BACKENDS__DATA__S3__SECRET_ACCESS_KEY= +# RALPH_BACKENDS__DATA__S3__SESSION_TOKEN= +# RALPH_BACKENDS__DATA__S3__ENDPOINT_URL= +# RALPH_BACKENDS__DATA__S3__DEFAULT_REGION= +# RALPH_BACKENDS__DATA__S3__DEFAULT_BUCKET_NAME= +# RALPH_BACKENDS__DATA__S3__DEFAULT_CHUNK_SIZE= +# RALPH_BACKENDS__DATA__S3__LOCALE_ENCODING= + +# ES data backend + +RALPH_BACKENDS__DATA__ES__HOSTS=http://elasticsearch:9200 +RALPH_BACKENDS__DATA__ES__DEFAULT_INDEX=statements +# RALPH_BACKENDS__DATA__ES__ALLOW_YELLOW_STATUS=False +# RALPH_BACKENDS__DATA__ES__CLIENT_OPTIONS__ca_certs=False +# RALPH_BACKENDS__DATA__ES__CLIENT_OPTIONS__verify_certs=False +# RALPH_BACKENDS__DATA__ES__DEFAULT_CHUNK_SIZE=500 +# RALPH_BACKENDS__DATA__ES__LOCALE_ENCODING=utf8 +# RALPH_BACKENDS__DATA__ES__POINT_IN_TIME_KEEP_ALIVE=1m +# RALPH_BACKENDS__DATA__ES__REFRESH_AFTER_WRITE=False +RALPH_BACKENDS__DATA__ES__TEST_HOSTS=http://elasticsearch:9200 +RALPH_BACKENDS__DATA__ES__TEST_INDEX=test-index-foo +RALPH_BACKENDS__DATA__ES__TEST_FORWARDING_INDEX=test-index-foo-2 + +# MONGO data backend + +RALPH_BACKENDS__DATA__MONGO__CONNECTION_URI=mongodb://mongo:27017/ +RALPH_BACKENDS__DATA__MONGO__DEFAULT_COLLECTION=foo +RALPH_BACKENDS__DATA__MONGO__DEFAULT_DATABASE=statements +# RALPH_BACKENDS__DATA__MONGO__CLIENT_OPTIONS__document_class= +# RALPH_BACKENDS__DATA__MONGO__CLIENT_OPTIONS__tz_aware=False +# RALPH_BACKENDS__DATA__MONGO__DEFAULT_CHUNK_SIZE=500 +# RALPH_BACKENDS__DATA__MONGO__LOCALE_ENCODING=utf8 +RALPH_BACKENDS__DATA__MONGO__TEST_COLLECTION=foo +RALPH_BACKENDS__DATA__MONGO__TEST_FORWARDING_COLLECTION=foo-2 +RALPH_BACKENDS__DATA__MONGO__TEST_DATABASE=statements +RALPH_BACKENDS__DATA__MONGO__TEST_CONNECTION_URI=mongodb://mongo:27017/ + +# ClickHouse data backend + +RALPH_BACKENDS__DATA__CLICKHOUSE__HOST=clickhouse +RALPH_BACKENDS__DATA__CLICKHOUSE__PORT=8123 +RALPH_BACKENDS__DATA__CLICKHOUSE__DATABASE=xapi +RALPH_BACKENDS__DATA__CLICKHOUSE__EVENT_TABLE_NAME=xapi_events_all +# RALPH_BACKENDS__DATA__CLICKHOUSE__USERNAME= +# RALPH_BACKENDS__DATA__CLICKHOUSE__PASSWORD= +# RALPH_BACKENDS__DATA__CLICKHOUSE__CLIENT_OPTIONS__date_time_input_format= +# RALPH_BACKENDS__DATA__CLICKHOUSE__CLIENT_OPTIONS__allow_experimental_object_type= +# RALPH_BACKENDS__DATA__CLICKHOUSE__DEFAULT_CHUNK_SIZE=500 +# 
RALPH_BACKENDS__DATA__CLICKHOUSE__LOCALE_ENCODING=utf8 +RALPH_BACKENDS__DATA__CLICKHOUSE__TEST_DATABASE=test_statements +RALPH_BACKENDS__DATA__CLICKHOUSE__TEST_HOST=clickhouse +RALPH_BACKENDS__DATA__CLICKHOUSE__TEST_PORT=8123 +RALPH_BACKENDS__DATA__CLICKHOUSE__TEST_TABLE_NAME=test_xapi_events_all + + +# LRS HTTP backend + +RALPH_BACKENDS__HTTP__LRS__BASE_URL=http://ralph:secret@0.0.0.0:8100/ +RALPH_BACKENDS__HTTP__LRS__USERNAME=ralph +RALPH_BACKENDS__HTTP__LRS__PASSWORD=secret +RALPH_BACKENDS__HTTP__LRS__HEADERS={"X-Experience-API-Version": "1.0.3", "content-type": "application/json"} +RALPH_BACKENDS__HTTP__LRS__STATUS_ENDPOINT=/__heartbeat__ +RALPH_BACKENDS__HTTP__LRS__STATEMENTS_ENDPOINT=/xAPI/statements # Sentry diff --git a/src/ralph/backends/http/__init__.py b/src/ralph/backends/http/__init__.py index 59dd7a6c7..6e031999e 100644 --- a/src/ralph/backends/http/__init__.py +++ b/src/ralph/backends/http/__init__.py @@ -1,5 +1 @@ -"""HTTP backends for Ralph.""" - -from .async_lrs import AsyncLRSHTTP # noqa: F401 -from .base import BaseHTTP # noqa: F401 -from .lrs import LRSHTTP # noqa: F401 +# noqa: D104 diff --git a/src/ralph/backends/http/async_lrs.py b/src/ralph/backends/http/async_lrs.py index e3d3d29c1..09c50ac3f 100644 --- a/src/ralph/backends/http/async_lrs.py +++ b/src/ralph/backends/http/async_lrs.py @@ -14,7 +14,7 @@ from pydantic import AnyHttpUrl, BaseModel, Field, NonNegativeInt, parse_obj_as from pydantic.types import PositiveInt -from ralph.conf import LRSHeaders, settings +from ralph.conf import BaseSettingsConfig, HeadersParameters from ralph.exceptions import BackendException, BackendParameterException from ralph.models.xapi.base.agents import BaseXapiAgent from ralph.models.xapi.base.common import IRI @@ -22,17 +22,49 @@ from ralph.utils import gather_with_limited_concurrency from .base import ( - BaseHTTP, + BaseHTTPBackend, + BaseHTTPBackendSettings, BaseQuery, HTTPBackendStatus, OperationType, enforce_query_checks, ) -lrs_settings = settings.BACKENDS.HTTP.LRS logger = logging.getLogger(__name__) +class LRSHeaders(HeadersParameters): + """Pydantic model for LRS headers.""" + + X_EXPERIENCE_API_VERSION: str = Field("1.0.3", alias="X-Experience-API-Version") + CONTENT_TYPE: str = Field("application/json", alias="content-type") + + +class LRSHTTPBackendSettings(BaseHTTPBackendSettings): + """LRS HTTP backend default configuration. + + Attributes: + BASE_URL (AnyHttpUrl): LRS server URL. + USERNAME (str): Basic auth username for LRS authentication. + PASSWORD (str): Basic auth password for LRS authentication. + HEADERS (dict): Headers defined for the LRS server connection. + STATUS_ENDPOINT (str): Endpoint used to check server status. + STATEMENTS_ENDPOINT (str): Default endpoint for LRS statements resource. 
+ """ + + class Config(BaseSettingsConfig): + """Pydantic Configuration.""" + + env_prefix = "RALPH_BACKENDS__HTTP__LRS__" + + BASE_URL: AnyHttpUrl = Field("http://0.0.0.0:8100") + USERNAME: str = "ralph" + PASSWORD: str = "secret" + HEADERS: LRSHeaders = LRSHeaders() + STATUS_ENDPOINT: str = "/__heartbeat__" + STATEMENTS_ENDPOINT: str = "/xAPI/statements" + + class StatementResponse(BaseModel): """Pydantic model for `get` statements response.""" @@ -65,41 +97,31 @@ class LRSStatementsQuery(BaseQuery): ascending: Optional[bool] = False -class AsyncLRSHTTP(BaseHTTP): +class AsyncLRSHTTPBackend(BaseHTTPBackend): """Asynchronous LRS HTTP backend.""" name = "async_lrs" query = LRSStatementsQuery default_operation_type = OperationType.CREATE + settings_class = LRSHTTPBackendSettings def __init__( # pylint: disable=too-many-arguments - self, - base_url: str = lrs_settings.BASE_URL, - username: str = lrs_settings.USERNAME, - password: str = lrs_settings.PASSWORD, - headers: LRSHeaders = lrs_settings.HEADERS, - status_endpoint: str = lrs_settings.STATUS_ENDPOINT, - statements_endpoint: str = lrs_settings.STATEMENTS_ENDPOINT, + self, settings: settings_class = None ): - """Instantiate the LRS client. + """Instantiate the LRS HTTP (basic auth) backend client. Args: - base_url (AnyHttpUrl): LRS server URL. - username (str): Basic auth username for LRS authentication. - password (str): Basic auth password for LRS authentication. - headers (dict): Headers defined for the LRS server connection. - status_endpoint (str): Endpoint used to check server status. - statements_endpoint (str): Default endpoint for LRS statements resource. + settings (LRSHTTPBackendSettings or None): The LRS HTTP backend settings. + If `settings` is `None`, a default settings instance is used instead. """ - self.base_url = parse_obj_as(AnyHttpUrl, base_url) - self.auth = (username, password) - self.headers = headers - self.status_endpoint = status_endpoint - self.statements_endpoint = statements_endpoint + self.settings = settings if settings else self.settings_class() + + self.base_url = parse_obj_as(AnyHttpUrl, self.settings.BASE_URL) + self.auth = (self.settings.USERNAME, self.settings.PASSWORD) async def status(self): """HTTP backend check for server status.""" - status_url = urljoin(self.base_url, self.status_endpoint) + status_url = urljoin(self.base_url, self.settings.STATUS_ENDPOINT) try: async with AsyncClient() as client: @@ -163,7 +185,7 @@ async def read( # pylint: disable=too-many-arguments max_statements: The maximum number of statements to yield. """ if not target: - target = self.statements_endpoint + target = self.settings.STATEMENTS_ENDPOINT if query and query.limit: logger.warning( @@ -262,7 +284,7 @@ async def write( # pylint: disable=too-many-arguments raise BackendParameterException(msg) if not target: - target = self.statements_endpoint + target = self.settings.STATEMENTS_ENDPOINT target = ParseResult( scheme=urlparse(self.base_url).scheme, @@ -308,7 +330,7 @@ async def write( # pylint: disable=too-many-arguments async def _fetch_statements(self, target, raw_output, query_params: dict): """Fetch statements from a LRS. Used in `read`.""" async with AsyncClient( - auth=self.auth, headers=self.headers.dict(by_alias=True) + auth=self.auth, headers=self.settings.HEADERS.dict(by_alias=True) ) as client: while True: response = await client.get(target, params=query_params) @@ -374,7 +396,7 @@ async def _post_and_raise_for_status(self, target, chunk, ignore_errors): For use in `write`. 
""" - async with AsyncClient(auth=self.auth, headers=self.headers) as client: + async with AsyncClient(auth=self.auth, headers=self.settings.HEADERS) as client: try: request = await client.post( # Encode data to allow async post diff --git a/src/ralph/backends/http/base.py b/src/ralph/backends/http/base.py index 1494d1d9a..8418fd7c4 100644 --- a/src/ralph/backends/http/base.py +++ b/src/ralph/backends/http/base.py @@ -6,14 +6,26 @@ from enum import Enum, unique from typing import Iterator, List, Optional, Union -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel, BaseSettings, ValidationError from pydantic.types import PositiveInt +from ralph.conf import BaseSettingsConfig, core_settings from ralph.exceptions import BackendParameterException logger = logging.getLogger(__name__) +class BaseHTTPBackendSettings(BaseSettings): + """Data backend default configuration.""" + + class Config(BaseSettingsConfig): + """Pydantic Configuration.""" + + env_prefix = "RALPH_BACKENDS__HTTP__" + env_file = ".env" + env_file_encoding = core_settings.LOCALE_ENCODING + + @unique class HTTPBackendStatus(Enum): """HTTP backend statuses.""" @@ -66,7 +78,7 @@ class Config: query_string: Optional[str] -class BaseHTTP(ABC): +class BaseHTTPBackend(ABC): """Base HTTP backend interface.""" name = "base" diff --git a/src/ralph/backends/http/lrs.py b/src/ralph/backends/http/lrs.py index 0f00f1266..3daf87c07 100644 --- a/src/ralph/backends/http/lrs.py +++ b/src/ralph/backends/http/lrs.py @@ -1,7 +1,7 @@ """LRS HTTP backend for Ralph.""" import asyncio -from ralph.backends.http.async_lrs import AsyncLRSHTTP +from ralph.backends.http.async_lrs import AsyncLRSHTTPBackend def _ensure_running_loop_uniqueness(func): @@ -16,7 +16,7 @@ def wrap(*args, **kwargs): if loop.is_running(): raise RuntimeError( f"This event loop is already running. You must use " - f"`AsyncLRSHTTP.{func.__name__}` (instead of `LRSHTTP." + f"`AsyncLRSHTTPBackend.{func.__name__}` (instead of `LRSHTTPBackend." f"{func.__name__}`), or run this code outside the current" " event loop." 
) @@ -25,7 +25,7 @@ def wrap(*args, **kwargs): return wrap -class LRSHTTP(AsyncLRSHTTP): +class LRSHTTPBackend(AsyncLRSHTTPBackend): """LRS HTTP backend.""" # pylint: disable=invalid-overridden-method diff --git a/tests/backends/http/test_async_lrs.py b/tests/backends/http/test_async_lrs.py index a8706975a..44d4a1ff0 100644 --- a/tests/backends/http/test_async_lrs.py +++ b/tests/backends/http/test_async_lrs.py @@ -13,16 +13,20 @@ import httpx import pytest from httpx import HTTPStatusError, RequestError -from pydantic import AnyHttpUrl +from pydantic import AnyHttpUrl, parse_obj_as from pytest_httpx import HTTPXMock -from ralph.backends.http.async_lrs import LRSStatementsQuery, OperationType +from ralph.backends.http.async_lrs import ( + AsyncLRSHTTPBackend, + LRSHeaders, + LRSHTTPBackendSettings, + LRSStatementsQuery, + OperationType, +) from ralph.backends.http.base import HTTPBackendStatus -from ralph.backends.http.lrs import AsyncLRSHTTP -from ralph.conf import LRSHeaders, settings from ralph.exceptions import BackendException, BackendParameterException -lrs_settings = settings.BACKENDS.HTTP.LRS +# pylint: disable=too-many-lines async def _unpack_async_generator(async_gen): @@ -53,29 +57,58 @@ def _gen_statement(id_=None, verb=None, timestamp=None): return {"id": id_, "verb": verb, "timestamp": timestamp} +def test_backend_http_lrs_default_instantiation( + monkeypatch, fs +): # pylint:disable = invalid-name + """Test the `LRSHTTPBackend` default instantiation.""" + fs.create_file(".env") + backend_settings_names = [ + "BASE_URL", + "USERNAME", + "PASSWORD", + "HEADERS", + "STATUS_ENDPOINT", + "STATEMENTS_ENDPOINT", + ] + for name in backend_settings_names: + monkeypatch.delenv(f"RALPH_BACKENDS__HTTP__LRS__{name}", raising=False) + + assert AsyncLRSHTTPBackend.name == "async_lrs" + assert AsyncLRSHTTPBackend.settings_class == LRSHTTPBackendSettings + backend = AsyncLRSHTTPBackend() + assert backend.query == LRSStatementsQuery + assert backend.base_url == parse_obj_as(AnyHttpUrl, "http://0.0.0.0:8100") + assert backend.auth == ("ralph", "secret") + assert backend.settings.HEADERS == LRSHeaders() + assert backend.settings.STATUS_ENDPOINT == "/__heartbeat__" + assert backend.settings.STATEMENTS_ENDPOINT == "/xAPI/statements" + + def test_backends_http_lrs_http_instantiation(): - """Test the LRS backend instantiation.""" - assert AsyncLRSHTTP.name == "async_lrs" - assert AsyncLRSHTTP.query == LRSStatementsQuery + """Test the LRS backend default instantiation.""" headers = LRSHeaders( X_EXPERIENCE_API_VERSION="1.0.3", CONTENT_TYPE="application/json" ) - backend = AsyncLRSHTTP( - base_url="http://fake-lrs.com", - username="user", - password="pass", - headers=headers, - status_endpoint="/fake-status-endpoint", - statements_endpoint="/xAPI/statements", + settings = LRSHTTPBackendSettings( + BASE_URL="http://fake-lrs.com", + USERNAME="user", + PASSWORD="pass", + HEADERS=headers, + STATUS_ENDPOINT="/fake-status-endpoint", + STATEMENTS_ENDPOINT="/xAPI/statements", ) + assert AsyncLRSHTTPBackend.name == "async_lrs" + assert AsyncLRSHTTPBackend.settings_class == LRSHTTPBackendSettings + backend = AsyncLRSHTTPBackend(settings) + assert backend.query == LRSStatementsQuery assert isinstance(backend.base_url, AnyHttpUrl) assert backend.auth == ("user", "pass") - assert backend.headers.CONTENT_TYPE == "application/json" - assert backend.headers.X_EXPERIENCE_API_VERSION == "1.0.3" - assert backend.status_endpoint == "/fake-status-endpoint" - assert backend.statements_endpoint == "/xAPI/statements" + 
assert backend.settings.HEADERS.CONTENT_TYPE == "application/json" + assert backend.settings.HEADERS.X_EXPERIENCE_API_VERSION == "1.0.3" + assert backend.settings.STATUS_ENDPOINT == "/fake-status-endpoint" + assert backend.settings.STATEMENTS_ENDPOINT == "/xAPI/statements" @pytest.mark.anyio @@ -88,12 +121,13 @@ async def test_backends_http_lrs_status_with_successful_request( base_url = "http://fake-lrs.com" status_endpoint = "/__heartbeat__" - backend = AsyncLRSHTTP( - base_url=base_url, - username="user", - password="pass", - status_endpoint=status_endpoint, + settings = LRSHTTPBackendSettings( + BASE_URL=base_url, + USERNAME="user", + PASSWORD="pass", + STATUS_ENDPOINT=status_endpoint, ) + backend = AsyncLRSHTTPBackend(settings) # Mock GET response of HTTPX httpx_mock.add_response( @@ -115,12 +149,13 @@ async def test_backends_http_lrs_status_with_request_error( base_url = "http://fake-lrs.com" status_endpoint = "/__heartbeat__" - backend = AsyncLRSHTTP( - base_url=base_url, - username="user", - password="pass", - status_endpoint=status_endpoint, + settings = LRSHTTPBackendSettings( + BASE_URL=base_url, + USERNAME="user", + PASSWORD="pass", + STATUS_ENDPOINT=status_endpoint, ) + backend = AsyncLRSHTTPBackend(settings) httpx_mock.add_exception(RequestError("Test Request Error")) @@ -146,12 +181,13 @@ async def test_backends_http_lrs_status_with_http_status_error( base_url = "http://fake-lrs.com" status_endpoint = "/__heartbeat__" - backend = AsyncLRSHTTP( - base_url=base_url, - username="user", - password="pass", - status_endpoint=status_endpoint, + settings = LRSHTTPBackendSettings( + BASE_URL=base_url, + USERNAME="user", + PASSWORD="pass", + STATUS_ENDPOINT=status_endpoint, ) + backend = AsyncLRSHTTPBackend(settings) httpx_mock.add_exception( HTTPStatusError("Test HTTP Status Error", request=None, response=None) @@ -175,7 +211,12 @@ async def test_backends_http_lrs_list(caplog): base_url = "http://fake-lrs.com" target = "/xAPI/statements/" - backend = AsyncLRSHTTP(base_url=base_url, username="user", password="pass") + settings = LRSHTTPBackendSettings( + BASE_URL=base_url, + USERNAME="user", + PASSWORD="pass", + ) + backend = AsyncLRSHTTPBackend(settings) msg = ( "LRS HTTP backend does not support `list` method, " @@ -248,7 +289,10 @@ async def test_backends_http_lrs_read_max_statements( json=more_statements, ) - backend = AsyncLRSHTTP(base_url=base_url, username="user", password="pass") + settings = AsyncLRSHTTPBackend.settings_class( + BASE_URL=base_url, USERNAME="user", PASSWORD="pass" + ) + backend = AsyncLRSHTTPBackend(settings) # Return an iterable of dict result = await _unpack_async_generator( @@ -277,7 +321,12 @@ async def test_backends_http_lrs_read_without_target( base_url = "http://fake-lrs.com" - backend = AsyncLRSHTTP(base_url=base_url, username="user", password="pass") + settings = LRSHTTPBackendSettings( + BASE_URL=base_url, + USERNAME="user", + PASSWORD="pass", + ) + backend = AsyncLRSHTTPBackend(settings) statements = {"statements": [_gen_statement() for _ in range(3)]} @@ -289,7 +338,7 @@ async def test_backends_http_lrs_read_without_target( url=ParseResult( scheme=urlparse(base_url).scheme, netloc=urlparse(base_url).netloc, - path=backend.statements_endpoint, + path=backend.settings.STATEMENTS_ENDPOINT, query=urlencode(default_params).lower(), params="", fragment="", @@ -315,7 +364,12 @@ async def test_backends_http_lrs_read_backend_error( base_url = "http://fake-lrs.com" target = "/xAPI/statements/" - backend = AsyncLRSHTTP(base_url=base_url, username="user", 
password="pass") + settings = LRSHTTPBackendSettings( + BASE_URL=base_url, + USERNAME="user", + PASSWORD="pass", + ) + backend = AsyncLRSHTTPBackend(settings) # Mock GET response of HTTPX default_params = LRSStatementsQuery(limit=500).dict( @@ -356,7 +410,12 @@ async def test_backends_http_lrs_read_without_pagination( base_url = "http://fake-lrs.com" target = "/xAPI/statements/" - backend = AsyncLRSHTTP(base_url=base_url, username="user", password="pass") + settings = LRSHTTPBackendSettings( + BASE_URL=base_url, + USERNAME="user", + PASSWORD="pass", + ) + backend = AsyncLRSHTTPBackend(settings) statements = { "statements": [ @@ -448,7 +507,12 @@ async def test_backends_http_lrs_read_with_pagination(httpx_mock: HTTPXMock): base_url = "http://fake-lrs.com" target = "/xAPI/statements/" - backend = AsyncLRSHTTP(base_url=base_url, username="user", password="pass") + settings = LRSHTTPBackendSettings( + BASE_URL=base_url, + USERNAME="user", + PASSWORD="pass", + ) + backend = AsyncLRSHTTPBackend(settings) more_target = "/xAPI/statements/?pit_id=fake-pit-id" statements = { @@ -610,7 +674,12 @@ async def test_backends_http_lrs_write_without_operation( data = [_gen_statement() for _ in range(6)] - backend = AsyncLRSHTTP(base_url=base_url, username="user", password="pass") + settings = LRSHTTPBackendSettings( + BASE_URL=base_url, + USERNAME="user", + PASSWORD="pass", + ) + backend = AsyncLRSHTTPBackend(settings) # Mock HTTPX POST httpx_mock.add_response(url=urljoin(base_url, target), method="POST", json=data) @@ -649,7 +718,12 @@ async def test_backends_http_lrs_write_without_data(caplog): base_url = "http://fake-lrs.com" target = "/xAPI/statements/" - backend = AsyncLRSHTTP(base_url=base_url, username="user", password="pass") + settings = LRSHTTPBackendSettings( + BASE_URL=base_url, + USERNAME="user", + PASSWORD="pass", + ) + backend = AsyncLRSHTTPBackend(settings) with caplog.at_level(logging.INFO): result = await backend.write(target=target, data=[]) @@ -682,7 +756,12 @@ async def test_backends_http_lrs_write_with_unsupported_operation( base_url = "http://fake-lrs.com" target = "/xAPI/statements/" - backend = AsyncLRSHTTP(base_url=base_url, username="user", password="pass") + settings = LRSHTTPBackendSettings( + BASE_URL=base_url, + USERNAME="user", + PASSWORD="pass", + ) + backend = AsyncLRSHTTPBackend(settings) with pytest.raises(BackendParameterException, match=error_msg): with caplog.at_level(logging.ERROR): @@ -715,7 +794,12 @@ async def test_backends_http_lrs_write_with_invalid_parameters( base_url = "http://fake-lrs.com" target = "/xAPI/statements/" - backend = AsyncLRSHTTP(base_url=base_url, username="user", password="pass") + settings = LRSHTTPBackendSettings( + BASE_URL=base_url, + USERNAME="user", + PASSWORD="pass", + ) + backend = AsyncLRSHTTPBackend(settings) with pytest.raises(BackendParameterException, match=error_msg): with caplog.at_level(logging.ERROR): @@ -740,13 +824,20 @@ async def test_backends_http_lrs_write_without_target(httpx_mock: HTTPXMock, cap base_url = "http://fake-lrs.com" - backend = AsyncLRSHTTP(base_url=base_url, username="user", password="pass") + settings = LRSHTTPBackendSettings( + BASE_URL=base_url, + USERNAME="user", + PASSWORD="pass", + ) + backend = AsyncLRSHTTPBackend(settings) data = [_gen_statement() for _ in range(3)] # Mock HTTPX POST httpx_mock.add_response( - url=urljoin(base_url, backend.statements_endpoint), method="POST", json=data + url=urljoin(base_url, backend.settings.STATEMENTS_ENDPOINT), + method="POST", + json=data, ) with 
caplog.at_level(logging.DEBUG): @@ -754,7 +845,8 @@ async def test_backends_http_lrs_write_without_target(httpx_mock: HTTPXMock, cap assert ( "ralph.backends.http.async_lrs", logging.DEBUG, - f"Start writing to the {base_url}{lrs_settings.STATEMENTS_ENDPOINT} " + "Start writing to the " + f"{base_url}{LRSHTTPBackendSettings().STATEMENTS_ENDPOINT} " "endpoint (chunk size: 500)", ) in caplog.record_tuples @@ -772,7 +864,12 @@ async def test_backends_http_lrs_write_with_create_or_index_operation( base_url = "http://fake-lrs.com" target = "/xAPI/statements" - backend = AsyncLRSHTTP(base_url=base_url, username="user", password="pass") + settings = LRSHTTPBackendSettings( + BASE_URL=base_url, + USERNAME="user", + PASSWORD="pass", + ) + backend = AsyncLRSHTTPBackend(settings) data = [_gen_statement() for _ in range(3)] @@ -801,7 +898,12 @@ async def test_backends_http_lrs_write_backend_exception( base_url = "http://fake-lrs.com" target = "/xAPI/statements" - backend = AsyncLRSHTTP(base_url=base_url, username="user", password="pass") + settings = LRSHTTPBackendSettings( + BASE_URL=base_url, + USERNAME="user", + PASSWORD="pass", + ) + backend = AsyncLRSHTTPBackend(settings) data = [_gen_statement()] @@ -871,7 +973,12 @@ async def _simulate_slow_processing(): if index < num_pages - 1: all_statements[index]["more"] = targets[index + 1] - backend = AsyncLRSHTTP(base_url=base_url, username="user", password="pass") + settings = LRSHTTPBackendSettings( + BASE_URL=base_url, + USERNAME="user", + PASSWORD="pass", + ) + backend = AsyncLRSHTTPBackend(settings) # Mock HTTPX GET params = {"limit": chunk_size} @@ -930,7 +1037,12 @@ async def test_backends_http_lrs_write_concurrency( # Changing data length might break tests assert len(data) == 6 - backend = AsyncLRSHTTP(base_url=base_url, username="user", password="pass") + settings = LRSHTTPBackendSettings( + BASE_URL=base_url, + USERNAME="user", + PASSWORD="pass", + ) + backend = AsyncLRSHTTPBackend(settings) # Mock HTTPX POST async def simulate_network_latency(request: httpx.Request): diff --git a/tests/backends/http/test_base.py b/tests/backends/http/test_base.py index 253e3c75a..0d419e59e 100644 --- a/tests/backends/http/test_base.py +++ b/tests/backends/http/test_base.py @@ -2,13 +2,13 @@ from typing import Iterator, Union -from ralph.backends.http.base import BaseHTTP, BaseQuery +from ralph.backends.http.base import BaseHTTPBackend, BaseQuery def test_backends_http_base_abstract_interface_with_implemented_abstract_method(): """Test the interface mechanism with properly implemented abstract methods.""" - class GoodStorage(BaseHTTP): + class GoodStorage(BaseHTTPBackend): """Correct implementation with required abstract methods.""" name = "good" diff --git a/tests/backends/http/test_lrs.py b/tests/backends/http/test_lrs.py index 6162461ed..b5799d6e4 100644 --- a/tests/backends/http/test_lrs.py +++ b/tests/backends/http/test_lrs.py @@ -7,11 +7,14 @@ import pytest from pydantic import AnyHttpUrl, parse_obj_as -from ralph.backends.http.async_lrs import AsyncLRSHTTP, HTTPBackendStatus -from ralph.backends.http.lrs import LRSHTTP -from ralph.conf import settings - -lrs_settings = settings.BACKENDS.HTTP.LRS +from ralph.backends.http.async_lrs import ( + AsyncLRSHTTPBackend, + HTTPBackendStatus, + LRSHeaders, + LRSHTTPBackendSettings, + LRSStatementsQuery, +) +from ralph.backends.http.lrs import LRSHTTPBackend @pytest.mark.anyio @@ -32,11 +35,11 @@ async def response_mock(*args, **kwargs): else: response_mock = AsyncMock(return_value=HTTPBackendStatus.OK) - 
monkeypatch.setattr(AsyncLRSHTTP, method, response_mock) + monkeypatch.setattr(AsyncLRSHTTPBackend, method, response_mock) async def async_function(): """Encapsulate the synchronous method in an asynchronous function.""" - lrs = LRSHTTP() + lrs = LRSHTTPBackend() if method == "read": list(getattr(lrs, method)()) else: @@ -48,7 +51,7 @@ async def async_function(): match=re.escape( ( f"This event loop is already running. You must use " - f"`AsyncLRSHTTP.{method}` (instead of `LRSHTTP.{method}`)" + f"`AsyncLRSHTTPBackend.{method}` (instead of `LRSHTTPBackend.{method}`)" ", or run this code outside the current event loop." ) ), @@ -56,39 +59,83 @@ async def async_function(): await async_function() -def test_backend_http_lrs_default_properties(): - """Test default LRS properties.""" - lrs = LRSHTTP() - assert lrs.name == "lrs" - assert lrs.base_url == parse_obj_as(AnyHttpUrl, lrs_settings.BASE_URL) - assert lrs.auth == (lrs_settings.USERNAME, lrs_settings.PASSWORD) - assert lrs.headers == lrs_settings.HEADERS - assert lrs.status_endpoint == lrs_settings.STATUS_ENDPOINT - assert lrs.statements_endpoint == lrs_settings.STATEMENTS_ENDPOINT +@pytest.mark.anyio +def test_backend_http_lrs_default_instantiation( + monkeypatch, fs +): # pylint:disable = invalid-name + """Test the `LRSHTTPBackend` default instantiation.""" + fs.create_file(".env") + backend_settings_names = [ + "BASE_URL", + "USERNAME", + "PASSWORD", + "HEADERS", + "STATUS_ENDPOINT", + "STATEMENTS_ENDPOINT", + ] + for name in backend_settings_names: + monkeypatch.delenv(f"RALPH_BACKENDS__HTTP__LRS__{name}", raising=False) + + assert LRSHTTPBackend.name == "lrs" + assert LRSHTTPBackend.settings_class == LRSHTTPBackendSettings + backend = LRSHTTPBackend() + assert backend.query == LRSStatementsQuery + assert backend.base_url == parse_obj_as(AnyHttpUrl, "http://0.0.0.0:8100") + assert backend.auth == ("ralph", "secret") + assert backend.settings.HEADERS == LRSHeaders() + assert backend.settings.STATUS_ENDPOINT == "/__heartbeat__" + assert backend.settings.STATEMENTS_ENDPOINT == "/xAPI/statements" + + +def test_backends_http_lrs_http_instantiation(): + """Test the LRS backend default instantiation.""" + + headers = LRSHeaders( + X_EXPERIENCE_API_VERSION="1.0.3", CONTENT_TYPE="application/json" + ) + settings = LRSHTTPBackendSettings( + BASE_URL="http://fake-lrs.com", + USERNAME="user", + PASSWORD="pass", + HEADERS=headers, + STATUS_ENDPOINT="/fake-status-endpoint", + STATEMENTS_ENDPOINT="/xAPI/statements", + ) + + assert LRSHTTPBackend.name == "lrs" + assert LRSHTTPBackend.settings_class == LRSHTTPBackendSettings + backend = LRSHTTPBackend(settings) + assert backend.query == LRSStatementsQuery + assert isinstance(backend.base_url, AnyHttpUrl) + assert backend.auth == ("user", "pass") + assert backend.settings.HEADERS.CONTENT_TYPE == "application/json" + assert backend.settings.HEADERS.X_EXPERIENCE_API_VERSION == "1.0.3" + assert backend.settings.STATUS_ENDPOINT == "/fake-status-endpoint" + assert backend.settings.STATEMENTS_ENDPOINT == "/xAPI/statements" def test_backends_http_lrs_inheritence(monkeypatch): - """Test that LRSHTTP properly inherits from AsyncLRSHTTP.""" - lrs = LRSHTTP() + """Test that `LRSHTTPBackend` properly inherits from `AsyncLRSHTTPBackend`.""" + lrs = LRSHTTPBackend() # Necessary when using anyio loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - # Test class inheritance - assert issubclass(lrs.__class__, AsyncLRSHTTP) + # Test class inheritence + assert issubclass(lrs.__class__, AsyncLRSHTTPBackend) # Test 
"status" status_mock_response = HTTPBackendStatus.OK status_mock = AsyncMock(return_value=status_mock_response) - monkeypatch.setattr(AsyncLRSHTTP, "status", status_mock) + monkeypatch.setattr(AsyncLRSHTTPBackend, "status", status_mock) assert lrs.status() == status_mock_response status_mock.assert_awaited() # Test "list" list_exception = NotImplementedError list_mock = AsyncMock(side_effect=list_exception) - monkeypatch.setattr(AsyncLRSHTTP, "list", list_mock) + monkeypatch.setattr(AsyncLRSHTTPBackend, "list", list_mock) with pytest.raises(list_exception): lrs.list() @@ -106,14 +153,14 @@ async def read_mock(*args, **kwargs): for statement in read_mock_response: yield statement - monkeypatch.setattr(AsyncLRSHTTP, "read", read_mock) + monkeypatch.setattr(AsyncLRSHTTPBackend, "read", read_mock) assert list(lrs.read(chunk_size=read_chunk_size)) == read_mock_response # Test "write" write_mock_response = 118218 chunk_size = 17 write_mock = AsyncMock(return_value=write_mock_response) - monkeypatch.setattr(AsyncLRSHTTP, "write", write_mock) + monkeypatch.setattr(AsyncLRSHTTPBackend, "write", write_mock) assert lrs.write(chunk_size=chunk_size) == write_mock_response write_mock.assert_called_with(chunk_size=chunk_size) From e75afd5ed16dbbe4ef2b9f080db15a330adfa632 Mon Sep 17 00:00:00 2001 From: Wilfried BARADAT Date: Fri, 18 Aug 2023 15:47:19 +0200 Subject: [PATCH 25/65] =?UTF-8?q?=F0=9F=8F=97=EF=B8=8F(backends)=20integra?= =?UTF-8?q?te=20unified=20backends=20in=20the=20CLI?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After unifying database and storage backends under a common interface, backends settings are now handled directly alongside the backends classes. Modifying the CLI to support new settings and new backends interfaces. 
--- .env.dist | 2 +- src/ralph/backends/conf.py | 91 ++++ src/ralph/backends/data/es.py | 4 +- src/ralph/backends/database/__init__.py | 4 - src/ralph/backends/database/base.py | 162 ------ src/ralph/backends/database/clickhouse.py | 441 ----------------- src/ralph/backends/database/es.py | 297 ----------- src/ralph/backends/database/mongo.py | 300 ------------ src/ralph/backends/lrs/clickhouse.py | 8 +- src/ralph/backends/storage/__init__.py | 1 - src/ralph/backends/storage/base.py | 26 - src/ralph/backends/storage/fs.py | 123 ----- src/ralph/backends/storage/ldp.py | 145 ------ src/ralph/backends/storage/s3.py | 148 ------ src/ralph/backends/storage/swift.py | 160 ------ src/ralph/backends/stream/base.py | 2 +- src/ralph/cli.py | 148 ++++-- src/ralph/conf.py | 193 +------- src/ralph/utils.py | 69 ++- tests/backends/database/__init__.py | 0 tests/backends/database/test_clickhouse.py | 533 -------------------- tests/backends/database/test_es.py | 545 --------------------- tests/backends/database/test_mongo.py | 502 ------------------- tests/backends/lrs/test_clickhouse.py | 8 +- tests/backends/storage/__init__.py | 0 tests/backends/storage/test_base.py | 28 -- tests/backends/storage/test_fs.py | 110 ----- tests/backends/storage/test_ldp.py | 459 ----------------- tests/backends/storage/test_s3.py | 398 --------------- tests/backends/storage/test_swift.py | 293 ----------- tests/backends/stream/test_base.py | 4 +- tests/backends/test_conf.py | 144 ++++++ tests/conftest.py | 2 - tests/fixtures/backends.py | 92 +--- tests/test_cli.py | 137 +++--- tests/test_cli_usage.py | 382 ++++++++++----- tests/test_conf.py | 120 +---- tests/test_logger.py | 4 +- tests/test_utils.py | 65 ++- 39 files changed, 837 insertions(+), 5313 deletions(-) create mode 100644 src/ralph/backends/conf.py delete mode 100644 src/ralph/backends/database/__init__.py delete mode 100644 src/ralph/backends/database/base.py delete mode 100755 src/ralph/backends/database/clickhouse.py delete mode 100644 src/ralph/backends/database/es.py delete mode 100644 src/ralph/backends/database/mongo.py delete mode 100644 src/ralph/backends/storage/__init__.py delete mode 100644 src/ralph/backends/storage/base.py delete mode 100644 src/ralph/backends/storage/fs.py delete mode 100644 src/ralph/backends/storage/ldp.py delete mode 100644 src/ralph/backends/storage/s3.py delete mode 100644 src/ralph/backends/storage/swift.py delete mode 100644 tests/backends/database/__init__.py delete mode 100644 tests/backends/database/test_clickhouse.py delete mode 100644 tests/backends/database/test_es.py delete mode 100644 tests/backends/database/test_mongo.py delete mode 100644 tests/backends/storage/__init__.py delete mode 100644 tests/backends/storage/test_base.py delete mode 100644 tests/backends/storage/test_fs.py delete mode 100644 tests/backends/storage/test_ldp.py delete mode 100644 tests/backends/storage/test_s3.py delete mode 100644 tests/backends/storage/test_swift.py create mode 100644 tests/backends/test_conf.py diff --git a/.env.dist b/.env.dist index 72911d339..55a38fd4f 100644 --- a/.env.dist +++ b/.env.dist @@ -4,7 +4,7 @@ RALPH_APP_DIR=/app/.ralph # Uncomment lines (by removing # characters at the beginning of target lines) # to define environment variables associated to the backend(s) you need. -# LDP storage backend +# LDP data backend # # You need to generate an API token for your OVH's account and fill the service # name and stream id you are targeting. 
diff --git a/src/ralph/backends/conf.py b/src/ralph/backends/conf.py
new file mode 100644
index 000000000..acdb3de87
--- /dev/null
+++ b/src/ralph/backends/conf.py
@@ -0,0 +1,91 @@
+"""Configurations for Ralph backends."""
+
+from pydantic import BaseModel, BaseSettings
+
+from ralph.backends.data.clickhouse import ClickHouseDataBackendSettings
+from ralph.backends.data.es import ESDataBackendSettings
+from ralph.backends.data.fs import FSDataBackendSettings
+from ralph.backends.data.ldp import LDPDataBackendSettings
+from ralph.backends.data.mongo import MongoDataBackendSettings
+from ralph.backends.data.s3 import S3DataBackendSettings
+from ralph.backends.data.swift import SwiftDataBackendSettings
+from ralph.backends.http.async_lrs import LRSHTTPBackendSettings
+from ralph.backends.lrs.clickhouse import ClickHouseLRSBackendSettings
+from ralph.backends.lrs.fs import FSLRSBackendSettings
+from ralph.backends.stream.ws import WSStreamBackendSettings
+from ralph.conf import BaseSettingsConfig, core_settings
+
+# Active Data backend Settings.
+
+
+class DataBackendSettings(BaseModel):
+    """Pydantic model for data backend configuration settings."""
+
+    ASYNC_ES: ESDataBackendSettings = ESDataBackendSettings()
+    ASYNC_MONGO: MongoDataBackendSettings = MongoDataBackendSettings()
+    CLICKHOUSE: ClickHouseDataBackendSettings = ClickHouseDataBackendSettings()
+    ES: ESDataBackendSettings = ESDataBackendSettings()
+    FS: FSDataBackendSettings = FSDataBackendSettings()
+    LDP: LDPDataBackendSettings = LDPDataBackendSettings()
+    MONGO: MongoDataBackendSettings = MongoDataBackendSettings()
+    SWIFT: SwiftDataBackendSettings = SwiftDataBackendSettings()
+    S3: S3DataBackendSettings = S3DataBackendSettings()
+
+
+# Active HTTP backend Settings.
+
+
+class HTTPBackendSettings(BaseModel):
+    """Pydantic model for HTTP backend configuration settings."""
+
+    LRS: LRSHTTPBackendSettings = LRSHTTPBackendSettings()
+
+
+# Active LRS backend Settings.
+
+
+class LRSBackendSettings(BaseModel):
+    """Pydantic model for LRS compatible backend configuration settings."""
+
+    ASYNC_ES: ESDataBackendSettings = ESDataBackendSettings()
+    ASYNC_MONGO: MongoDataBackendSettings = MongoDataBackendSettings()
+    CLICKHOUSE: ClickHouseLRSBackendSettings = ClickHouseLRSBackendSettings()
+    ES: ESDataBackendSettings = ESDataBackendSettings()
+    FS: FSLRSBackendSettings = FSLRSBackendSettings()
+    MONGO: MongoDataBackendSettings = MongoDataBackendSettings()
+
+
+# Active Stream backend Settings.
+
+
+class StreamBackendSettings(BaseModel):
+    """Pydantic model for stream backend configuration settings."""
+
+    WS: WSStreamBackendSettings = WSStreamBackendSettings()
+
+
+# Active backend Settings.
+
+
+class Backends(BaseModel):
+    """Pydantic model for backends configuration settings."""
+
+    DATA: DataBackendSettings = DataBackendSettings()
+    HTTP: HTTPBackendSettings = HTTPBackendSettings()
+    LRS: LRSBackendSettings = LRSBackendSettings()
+    STREAM: StreamBackendSettings = StreamBackendSettings()
+
+
+class BackendSettings(BaseSettings):
+    """Pydantic model for Ralph's backends environment & configuration settings."""
+
+    class Config(BaseSettingsConfig):
+        """Pydantic Configuration."""
+
+        env_file = ".env"
+        env_file_encoding = core_settings.LOCALE_ENCODING
+
+    BACKENDS: Backends = Backends()
+
+
+backends_settings = BackendSettings()
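The next hunk rebases `ESClientOptions` onto Ralph's `ClientOptions` model instead of a bare pydantic `BaseModel`. A likely motivation, visible in the `cli.py` changes at the end of this patch, is that option generation special-cases `ClientOptions` (and `HeadersParameters`) fields so they can be parsed from the command line. A self-contained sketch of that kind of type-based dispatch, with made-up names rather than Ralph's actual API:

```python
# Self-contained sketch of type-based CLI option dispatch (names made up;
# Ralph's real loop lives in backends_options() in src/ralph/cli.py below).
from inspect import isclass
from typing import Optional

from pydantic import BaseModel


class ClientOptions(BaseModel):
    """Marker base model for nested 'key=value' client options."""


class ESClientOptions(ClientOptions):
    """Hypothetical nested options, mimicking the es.py change below."""

    ca_certs: Optional[str] = None
    verify_certs: Optional[bool] = None


def option_style(field_type) -> str:
    """Describe how a CLI could expose a settings field of this type."""
    if field_type is bool:
        return "flag"
    if isclass(field_type) and issubclass(field_type, ClientOptions):
        return "comma-separated key=value pairs"
    return "plain string"


print(option_style(ESClientOptions))  # -> "comma-separated key=value pairs"
print(option_style(bool))             # -> "flag"
```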
diff --git a/src/ralph/backends/data/es.py b/src/ralph/backends/data/es.py
index b7a7a9662..db9ea11d7 100644
--- a/src/ralph/backends/data/es.py
+++ b/src/ralph/backends/data/es.py
@@ -18,14 +18,14 @@
     DataBackendStatus,
     enforce_query_checks,
 )
-from ralph.conf import BaseSettingsConfig, CommaSeparatedTuple
+from ralph.conf import BaseSettingsConfig, ClientOptions, CommaSeparatedTuple
 from ralph.exceptions import BackendException, BackendParameterException
 from ralph.utils import parse_bytes_to_dict, read_raw
 
 logger = logging.getLogger(__name__)
 
 
-class ESClientOptions(BaseModel):
+class ESClientOptions(ClientOptions):
     """Elasticsearch additional client options."""
 
     ca_certs: Path = None
diff --git a/src/ralph/backends/database/__init__.py b/src/ralph/backends/database/__init__.py
deleted file mode 100644
index 9c9b37b79..000000000
--- a/src/ralph/backends/database/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-"""Database backends for Ralph."""
-
-from .base import BaseDatabase  # noqa: F401
-from .es import ESDatabase  # noqa: F401
diff --git a/src/ralph/backends/database/base.py b/src/ralph/backends/database/base.py
deleted file mode 100644
index 851fc4199..000000000
--- a/src/ralph/backends/database/base.py
+++ /dev/null
@@ -1,162 +0,0 @@
-"""Base database backend for Ralph."""
-
-import functools
-import logging
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
-from enum import Enum, unique
-from typing import BinaryIO, List, Optional, TextIO, Union
-
-from pydantic import BaseModel
-
-from ralph.backends.http.async_lrs import LRSStatementsQuery
-from ralph.exceptions import BackendParameterException
-
-logger = logging.getLogger(__name__)
-
-
-class BaseQuery(BaseModel):
-    """Base query model."""
-
-    class Config:
-        """Base query model configuration."""
-
-        extra = "forbid"
-
-
-@dataclass
-class StatementQueryResult:
-    """Represent a common interface for results of an LRS statements query."""
-
-    statements: List[dict]
-    pit_id: str
-    search_after: str
-
-
-@unique
-class DatabaseStatus(Enum):
-    """Database statuses."""
-
-    OK = "ok"
-    AWAY = "away"
-    ERROR = "error"
-
-
-class AgentParameters(BaseModel):
-    """Dictionary of possible LRS query parameters for query on type Agent.
-
-    NB: Agent refers to the data structure, NOT to the LRS query parameter.
- """ - - mbox: Optional[str] - mbox_sha1sum: Optional[str] - openid: Optional[str] - account__name: Optional[str] - account__home_page: Optional[str] - - -class RalphStatementsQuery(LRSStatementsQuery): - """Represent a dictionary of possible LRS query parameters.""" - - # pylint: disable=too-many-instance-attributes - - agent: Optional[AgentParameters] = AgentParameters.construct() - search_after: Optional[str] - pit_id: Optional[str] - authority: Optional[AgentParameters] = AgentParameters.construct() - - def __post_init__(self): - """Perform additional conformity verifications on parameters.""" - # Initiate agent parameters for queries "agent" and "authority" - for query_param in ["agent", "authority"]: - # Check that both `homePage` and `name` are provided if any are - if (self.__dict__[query_param].account__name is not None) != ( - self.__dict__[query_param].account__home_page is not None - ): - raise BackendParameterException( - f"Invalid {query_param} parameters: homePage and name are " - "both required" - ) - - # Check that one or less Inverse Functional Identifier is provided - if ( - sum( - x is not None - for x in [ - self.__dict__[query_param].mbox, - self.__dict__[query_param].mbox_sha1sum, - self.__dict__[query_param].openid, - self.__dict__[query_param].account__name, - ] - ) - > 1 - ): - raise BackendParameterException( - f"Invalid {query_param} parameters: Only one identifier can be used" - ) - - -def enforce_query_checks(method): - """Enforce query argument type checking for methods using it.""" - - @functools.wraps(method) - def wrapper(*args, **kwargs): - """Wrap method execution.""" - query = kwargs.pop("query", None) - self_ = args[0] - - return method(*args, query=self_.validate_query(query), **kwargs) - - return wrapper - - -class BaseDatabase(ABC): - """Base database backend interface.""" - - name = "base" - query_model = BaseQuery - - def validate_query(self, query: BaseQuery = None): - """Validate database query.""" - if query is None: - query = self.query_model() - - if not isinstance(query, self.query_model): - raise BackendParameterException( - "'query' argument is expected to be a " - f"{self.query_model().__class__.__name__} instance." - ) - - logger.debug("Query: %s", str(query)) - - return query - - @abstractmethod - def status(self) -> DatabaseStatus: - """Implement database checks (e.g. connection, cluster status).""" - - @abstractmethod - @enforce_query_checks - def get(self, query: BaseQuery = None, chunk_size: int = 10): - """Yield `chunk_size` records read from the database query results.""" - - @abstractmethod - def put( - self, - stream: Union[BinaryIO, TextIO], - chunk_size: int = 10, - ignore_errors: bool = False, - ) -> int: - """Write `chunk_size` records from the `stream` to the database. - - Returns: - int: The count of successfully written records. 
- """ - - @abstractmethod - def query_statements(self, params: RalphStatementsQuery) -> StatementQueryResult: - """Return the statements query payload using xAPI parameters.""" - - @abstractmethod - def query_statements_by_ids(self, ids: List[str]) -> List: - """Return the list of matching statement IDs from the database.""" diff --git a/src/ralph/backends/database/clickhouse.py b/src/ralph/backends/database/clickhouse.py deleted file mode 100755 index a0d85a8dd..000000000 --- a/src/ralph/backends/database/clickhouse.py +++ /dev/null @@ -1,441 +0,0 @@ -"""ClickHouse database backend for Ralph.""" - -import datetime -import json -import logging -import uuid -from typing import Generator, List, Optional, TextIO, Union - -import clickhouse_connect -from clickhouse_connect.driver.exceptions import ClickHouseError -from pydantic import BaseModel, ValidationError - -from ralph.conf import ClickhouseClientOptions, settings -from ralph.exceptions import BackendException, BadFormatException - -from .base import ( - BaseDatabase, - BaseQuery, - DatabaseStatus, - RalphStatementsQuery, - StatementQueryResult, - enforce_query_checks, -) - -clickhouse_settings = settings.BACKENDS.DATABASE.CLICKHOUSE -logger = logging.getLogger(__name__) - - -class ClickHouseInsert(BaseModel): - """Model to validate required fields for ClickHouse insertion.""" - - event_id: uuid.UUID - emission_time: datetime.datetime - - -class ClickHouseQuery(BaseQuery): - """ClickHouse query model.""" - - where_clause: Optional[str] - return_fields: Optional[List[str]] - - -class ClickHouseDatabase(BaseDatabase): # pylint: disable=too-many-instance-attributes - """ClickHouse database backend.""" - - name = "clickhouse" - query_model = ClickHouseQuery - - def __init__( # pylint: disable=too-many-arguments - self, - host: str = clickhouse_settings.HOST, - port: int = clickhouse_settings.PORT, - database: str = clickhouse_settings.DATABASE, - event_table_name: str = clickhouse_settings.EVENT_TABLE_NAME, - username: str = clickhouse_settings.USERNAME, - password: str = clickhouse_settings.PASSWORD, - client_options: ClickhouseClientOptions = clickhouse_settings.CLIENT_OPTIONS, - ): - """Instantiates the ClickHouse configuration. - - Args: - host (str): ClickHouse server host to connect to. - port (int): ClickHouse server port to connect to. - database (str): ClickHouse database to connect to. - event_table_name (str): Table where events live. - username (str): ClickHouse username to connect as (optional). - password (str): Password for the given ClickHouse username (optional). - client_options (dict): A dictionary of valid options for the ClickHouse - client connection. - - If username and password are None, we will try to connect as the ClickHouse - user "default". - """ - if client_options is None: - client_options = { - "date_time_input_format": "best_effort", # Allows RFC dates - "allow_experimental_object_type": 1, # Allows JSON data type - } - else: - client_options = client_options.dict() - - self.host = host - self.port = port - self.database = database - self.event_table_name = event_table_name - self.username = username - self.password = password - self.client_options = client_options - self._client = None - - @property - def client(self): - """Create a ClickHouse client if it doesn't exist. - - We do this here so that we don't interrupt initialization in the case - where ClickHouse is not running when Ralph starts up, which will cause - Ralph to hang. This client is HTTP, so not actually stateful. 
Ralph - should be able to gracefully deal with ClickHouse outages at all other - times. - """ - if not self._client: - self._client = clickhouse_connect.get_client( - host=self.host, - port=self.port, - database=self.database, - username=self.username, - password=self.password, - settings=self.client_options, - ) - return self._client - - def status(self) -> DatabaseStatus: - """Check ClickHouse connection status.""" - try: - self.client.query("SELECT 1") - except ClickHouseError: - return DatabaseStatus.AWAY - - return DatabaseStatus.OK - - @enforce_query_checks - def get(self, query: ClickHouseQuery = None, chunk_size: int = 500): - """Get table rows and yields them.""" - fields = ",".join(query.return_fields) if query.return_fields else "event" - - sql = f"SELECT {fields} FROM {self.event_table_name}" # nosec - - if query.where_clause: - sql += f" WHERE {query.where_clause}" - - result = self.client.query(sql).named_results() - - for statement in result: - yield statement - - @staticmethod - def to_documents( - stream: Union[TextIO, List], ignore_errors: bool = False - ) -> Generator[dict, None, None]: - """Convert `stream` lines (one statement per line) to insert tuples.""" - for line in stream: - statement = json.loads(line) if isinstance(line, str) else line - - try: - insert = ClickHouseInsert( - event_id=statement["id"], emission_time=statement["timestamp"] - ) - except (KeyError, ValidationError) as exc: - err = ( - "Statement has an invalid or missing id or " - f"timestamp field: {statement}" - ) - if ignore_errors: - logger.warning(err) - continue - raise BadFormatException(err) from exc - - document = ( - insert.event_id, - insert.emission_time, - statement, - json.dumps(statement), - ) - - yield document - - def bulk_import(self, batch: List, ignore_errors: bool = False) -> int: - """Insert a batch of documents into the selected database table.""" - try: - # ClickHouse does not do unique keys. This is a "best effort" to - # at least check for duplicates in each batch. Overall ID checking - # against the database happens upstream in the POST / PUT methods. - # - # As opposed to Mongo, the entire batch is guaranteed to fail here - # if any dupes are found. - found_ids = {x[0] for x in batch} - - if len(found_ids) != len(batch): - raise BackendException("Duplicate IDs found in batch") - - self.client.insert( - self.event_table_name, - batch, - column_names=[ - "event_id", - "emission_time", - "event", - "event_str", - ], - # Allow ClickHouse to buffer the insert, and wait for the - # buffer to flush. Should be configurable, but I think these are - # reasonable defaults. - settings={"async_insert": 1, "wait_for_async_insert": 1}, - ) - except (ClickHouseError, BackendException) as error: - if not ignore_errors: - raise BackendException(*error.args) from error - logger.warning( - "Bulk import failed for current chunk but you choose to ignore it.", - ) - # There is no current way of knowing how many rows from the batch - # succeeded, we assume 0 here. 
- return 0 - - logger.debug("Inserted %s documents chunk with success", len(batch)) - - return len(batch) - - def put( - self, - stream: Union[TextIO, List], - chunk_size: int = 500, - ignore_errors: bool = False, - ) -> int: - """Write documents from the `stream` to the instance table.""" - logger.debug( - "Start writing to the %s table of the %s database (chunk size: %d)", - self.event_table_name, - self.database, - chunk_size, - ) - - rows_inserted = 0 - batch = [] - for document in self.to_documents(stream, ignore_errors=ignore_errors): - batch.append(document) - if len(batch) < chunk_size: - continue - - rows_inserted += self.bulk_import(batch, ignore_errors=ignore_errors) - batch = [] - - # Catch any remaining documents when the last batch is smaller than chunk_size - if len(batch) > 0: - rows_inserted += self.bulk_import(batch, ignore_errors=ignore_errors) - - logger.debug("Inserted a total of %s documents with success", rows_inserted) - - return rows_inserted - - def query_statements_by_ids(self, ids: List[str]) -> List[dict]: - """Return the list of matching statements from the database.""" - - def chunk_id_list(chunk_size=10000): - for i in range(0, len(ids), chunk_size): - yield ids[i : i + chunk_size] - - sql = """ - SELECT event_id, event_str - FROM {table_name:Identifier} - WHERE event_id IN ({ids:Array(String)}) - """ - - query_context = self.client.create_query_context( - query=sql, - parameters={"ids": ["1"], "table_name": self.event_table_name}, - column_oriented=True, - ) - - found_statements = [] - - try: - for chunk_ids in chunk_id_list(): - query_context.set_parameter("ids", chunk_ids) - result = self.client.query(context=query_context).named_results() - for row in result: - # This is the format to match the other backends - found_statements.append( - { - "_id": str(row["event_id"]), - "_source": json.loads(row["event_str"]), - } - ) - - return found_statements - except (ClickHouseError, IndexError, TypeError, ValueError) as error: - msg = "Failed to execute ClickHouse query" - logger.error("%s. %s", msg, error) - raise BackendException(msg, *error.args) from error - - @staticmethod - def _add_agent_filters( - clickhouse_params, where_clauses, agent_params, target_field - ): - """Add filters relative to agents to `clickhouse_params` and `where_clauses`. 
- - Args: - clickhouse_params: values to be used in `where_clauses` - where_clauses: filters to be passed to clickhouse - agent_params: query parameters that represent the agent to search for - target_field: the field in the database in which to perform the search - """ - if agent_params.mbox: - clickhouse_params[f"{target_field}__mbox"] = agent_params.mbox - where_clauses.append( - f"event.{target_field}.mbox = {{{target_field}__mbox:String}}" - ) - - if agent_params.mbox_sha1sum: - clickhouse_params[ - f"{target_field}__mbox_sha1sum" - ] = agent_params.mbox_sha1sum - where_clauses.append( - f"event.{target_field}.mbox_sha1sum = {{{target_field}__mbox_sha1sum:String}}" # noqa: E501 # pylint: disable=line-too-long - ) - - if agent_params.openid: - clickhouse_params[f"{target_field}__openid"] = agent_params.openid - where_clauses.append( - f"event.{target_field}.openid = {{{target_field}__openid:String}}" - ) - - if agent_params.account__name: - clickhouse_params[ - f"{target_field}__account__name" - ] = agent_params.account__name - clickhouse_params[ - f"{target_field}__account__home_page" - ] = agent_params.account__home_page - where_clauses.append( - f"event.{target_field}.account.name = {{{target_field}__account__name:String}}" # noqa: E501 # pylint: disable=line-too-long - ) - where_clauses.append( - f"event.{target_field}.account.homePage = {{{target_field}__account__home_page:String}}" # noqa: E501 # pylint: disable=line-too-long - ) - - def query_statements(self, params: RalphStatementsQuery) -> StatementQueryResult: - """Return the results of a statements query using xAPI parameters.""" - # pylint: disable=too-many-branches - # pylint: disable=invalid-name - - clickhouse_params = params.dict(exclude_none=True) - where_clauses = [] - - if params.statement_id: - where_clauses.append("event_id = {statementId:UUID}") - - self._add_agent_filters( - clickhouse_params, - where_clauses, - params.agent, - target_field="actor", - ) - clickhouse_params.pop("agent") - - self._add_agent_filters( - clickhouse_params, - where_clauses, - params.authority, - target_field="authority", - ) - clickhouse_params.pop("authority") - - if params.verb: - where_clauses.append("event.verb.id = {verb:String}") - - if params.activity: - where_clauses.append("event.object.objectType = 'Activity'") - where_clauses.append("event.object.id = {activity:String}") - - if params.since: - where_clauses.append("emission_time > {since:DateTime64(6)}") - - if params.until: - where_clauses.append("emission_time <= {until:DateTime64(6)}") - - if params.search_after: - search_order = ">" if params.ascending else "<" - - where_clauses.append( - f"(emission_time {search_order} " - "{search_after:DateTime64(6)}" - " OR " - "(emission_time = {search_after:DateTime64(6)}" - " AND " - f"event_id {search_order} " - "{pit_id:UUID}" - "))" - ) - - sort_order = "ASCENDING" if params.ascending else "DESCENDING" - order_by = f"emission_time {sort_order}, event_id {sort_order}" - - response = self._find( - where=where_clauses, - parameters=clickhouse_params, - limit=params.limit, - sort=order_by, - ) - response = list(response) - - new_search_after = None - new_pit_id = None - - if response: - # Our search after string is a combination of event timestamp and - # event id, so that we can avoid losing events when they have the - # same timestamp, and also avoid sending the same event twice. 
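# Illustration of the cursor pair built just below (hypothetical values):
# for a last row with emission_time 2023-08-18T15:47:19.000001 and a UUID
# event_id, the next page would receive
#   new_search_after = "2023-08-18T15:47:19.000001"
#   new_pit_id = "cc9b8ef6-32aa-49ef-9e19-3b8a7c0ba296"
# and both values are fed back as `search_after` / `pit_id` in the WHERE
# clause composed earlier in query_statements.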
- new_search_after = response[-1]["emission_time"].isoformat() - new_pit_id = str(response[-1]["event_id"]) - - return StatementQueryResult( - statements=[document["event"] for document in response], - search_after=new_search_after, - pit_id=new_pit_id, - ) - - def _find( - self, parameters: dict, where: List = None, limit: int = None, sort: str = None - ): - """Wrap the ClickHouse query method. - - Raises: - BackendException: raised for any failure. - """ - sql = """ - SELECT event_id, emission_time, event - FROM {event_table_name:Identifier} - """ - if where: - filter_str = "WHERE 1=1 AND " - filter_str += """ - AND - """.join( - where - ) - sql += filter_str - if sort: - sql += f"\nORDER BY {sort}" - - if limit: - sql += f"\nLIMIT {limit}" - - parameters["event_table_name"] = self.event_table_name - - try: - return self.client.query(sql, parameters=parameters).named_results() - except (ClickHouseError, IndexError, TypeError, ValueError) as error: - msg = "Failed to execute ClickHouse query" - logger.error("%s. %s", msg, error) - raise BackendException(msg, *error.args) from error diff --git a/src/ralph/backends/database/es.py b/src/ralph/backends/database/es.py deleted file mode 100644 index 0f33b2f39..000000000 --- a/src/ralph/backends/database/es.py +++ /dev/null @@ -1,297 +0,0 @@ -"""Elasticsearch database backend for Ralph.""" - -import json -import logging -from enum import Enum -from typing import Callable, Generator, List, Optional, TextIO - -from elasticsearch import ApiError -from elasticsearch import ConnectionError as ESConnectionError -from elasticsearch import Elasticsearch -from elasticsearch.client import CatClient -from elasticsearch.helpers import BulkIndexError, scan, streaming_bulk - -from ralph.conf import ESClientOptions, settings -from ralph.exceptions import BackendException, BackendParameterException - -from .base import ( - AgentParameters, - BaseDatabase, - BaseQuery, - DatabaseStatus, - RalphStatementsQuery, - StatementQueryResult, - enforce_query_checks, -) - -es_settings = settings.BACKENDS.DATABASE.ES -logger = logging.getLogger(__name__) - - -class OpType(Enum): - """Elasticsearch operation types.""" - - INDEX = "index" - CREATE = "create" - DELETE = "delete" - UPDATE = "update" - - -class ESQuery(BaseQuery): - """Elasticsearch body query model.""" - - query: Optional[dict] - - -class ESDatabase(BaseDatabase): - """Elasticsearch database backend.""" - - name = "es" - query_model = ESQuery - - def __init__( - self, - hosts: list = es_settings.HOSTS, - index: str = es_settings.INDEX, - client_options: ESClientOptions = es_settings.CLIENT_OPTIONS, - op_type: str = es_settings.OP_TYPE, - ): - """Instantiates the Elasticsearch client. - - Args: - hosts (list): List of Elasticsearch nodes we should connect to. - index (str): The Elasticsearch index name. - client_options (dict): A dictionary of valid options for the - Elasticsearch class initialization. - op_type (str): The Elasticsearch operation type for every document sent to - Elasticsearch (should be one of: index, create, delete, update). 
- """ - self._hosts = hosts - self.index = index - - self.client = Elasticsearch(self._hosts, **client_options.dict()) - if op_type not in [op.value for op in OpType]: - raise BackendParameterException( - f"{op_type} is not an allowed operation type" - ) - self.op_type = op_type - - def status(self) -> DatabaseStatus: - """Check Elasticsearch cluster (connection) status.""" - # Check ES cluster connection - try: - self.client.info() - except ESConnectionError: - return DatabaseStatus.AWAY - - # Check cluster status - if "green" not in CatClient(self.client).health(): - return DatabaseStatus.ERROR - - return DatabaseStatus.OK - - @enforce_query_checks - def get(self, query: ESQuery = None, chunk_size: int = 500): - """Get index documents and yields them. - - The `query` dictionary should only contain kwargs compatible with the - elasticsearch.helpers.scan function signature (API reference - documentation: - https://elasticsearch-py.readthedocs.io/en/latest/helpers.html#scan). - """ - for document in scan( - self.client, index=self.index, size=chunk_size, **query.dict() - ): - yield document - - def to_documents( - self, stream: TextIO, get_id: Callable[[dict], str] - ) -> Generator[dict, None, None]: - """Convert `stream` lines to ES documents.""" - for line in stream: - item = json.loads(line) if isinstance(line, str) else line - action = { - "_index": self.index, - "_id": get_id(item), - "_op_type": self.op_type, - } - if self.op_type == "update": - action.update({"doc": item}) - elif self.op_type in ("create", "index"): - action.update({"_source": item}) - yield action - - def put( - self, stream: TextIO, chunk_size: int = 500, ignore_errors: bool = False - ) -> int: - """Write documents from the `stream` to the instance index.""" - logger.debug( - "Start writing to the %s index (chunk size: %d)", self.index, chunk_size - ) - - documents = 0 - try: - for success, action in streaming_bulk( - client=self.client, - actions=self.to_documents(stream, lambda d: d.get("id", None)), - chunk_size=chunk_size, - raise_on_error=(not ignore_errors), - ): - documents += success - logger.debug( - "Wrote %d documents [action: %s ok: %d]", documents, action, success - ) - except BulkIndexError as error: - raise BackendException( - *error.args, f"{documents} succeeded writes" - ) from error - return documents - - @staticmethod - def _add_agent_filters( - es_query_filters: list, agent_params: AgentParameters, target_field: str - ): - """Add filters relative to agents to es_query_filters. 
- - Args: - es_query_filters: list of filters to be passed to elasticsearch - agent_params: query parameters that represent the agent to search for - target_field: the field in the database in which to perform the search - """ - if agent_params.mbox: - es_query_filters += [ - {"term": {f"{target_field}.mbox.keyword": agent_params.mbox}} - ] - - if agent_params.mbox_sha1sum: - es_query_filters += [ - { - "term": { - f"{target_field}.mbox_sha1sum.keyword": agent_params.mbox_sha1sum # noqa: E501 # pylint: disable=line-too-long - } - } - ] - - if agent_params.openid: - es_query_filters += [ - {"term": {f"{target_field}.openid.keyword": agent_params.openid}} - ] - - if agent_params.account__name: - es_query_filters += [ - { - "term": { - f"{target_field}.account.name.keyword": agent_params.account__name # noqa: E501 # pylint: disable=line-too-long - } - } - ] - es_query_filters += [ - { - "term": { - f"{target_field}.account.homePage.keyword": agent_params.account__home_page # noqa: E501 # pylint: disable=line-too-long - } - } - ] - - def query_statements(self, params: RalphStatementsQuery) -> StatementQueryResult: - """Return the results of a statements query using xAPI parameters.""" - es_query_filters = [] - - if params.statement_id: - es_query_filters += [{"term": {"_id": params.statement_id}}] - - self._add_agent_filters( - es_query_filters, params.__dict__["agent"], target_field="actor" - ) - self._add_agent_filters( - es_query_filters, params.__dict__["authority"], target_field="authority" - ) - - if params.verb: - es_query_filters += [{"term": {"verb.id.keyword": params.verb}}] - - if params.activity: - es_query_filters += [ - {"term": {"object.objectType.keyword": "Activity"}}, - {"term": {"object.id.keyword": params.activity}}, - ] - - if params.since: - es_query_filters += [{"range": {"timestamp": {"gt": params.since}}}] - - if params.until: - es_query_filters += [{"range": {"timestamp": {"lte": params.until}}}] - - if len(es_query_filters) > 0: - es_query = {"query": {"bool": {"filter": es_query_filters}}} - else: - es_query = {"query": {"match_all": {}}} - - # Honor the "ascending" parameter, otherwise show most recent statements first - es_query.update( - {"sort": [{"timestamp": {"order": "asc" if params.ascending else "desc"}}]} - ) - - if params.search_after: - es_query.update({"search_after": params.search_after.split("|")}) - - # Disable total hits counting for performance as we're not using it. 
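# Sketch of the resulting pagination flow (hypothetical values): a first
# page returns a pit_id such as "46ToAw..." and a search_after string such
# as "1692366439000|42" (the "|"-joined sort values of the last hit); the
# client passes both back, and the point-in-time handling below reuses the
# PIT with a refreshed keep_alive instead of opening a new one.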
- es_query.update({"track_total_hits": False}) - - if not params.pit_id: - pit_response = self._open_point_in_time( - index=self.index, keep_alive=settings.RUNSERVER_POINT_IN_TIME_KEEP_ALIVE - ) - params.pit_id = pit_response["id"] - - es_query.update( - { - "pit": { - "id": params.pit_id, - # extend duration of PIT whenever it is used - "keep_alive": settings.RUNSERVER_POINT_IN_TIME_KEEP_ALIVE, - } - } - ) - es_response = self._search(body=es_query, size=params.limit) - es_documents = es_response["hits"]["hits"] - search_after = None - if es_documents: - search_after = "|".join([str(part) for part in es_documents[-1]["sort"]]) - - return StatementQueryResult( - statements=[document["_source"] for document in es_documents], - pit_id=es_response["pit_id"], - search_after=search_after, - ) - - def query_statements_by_ids(self, ids: List[str]) -> List: - """Return the list of matching statement IDs from the database.""" - body = {"query": {"terms": {"_id": ids}}} - return self._search(index=self.index, body=body)["hits"]["hits"] - - def _search(self, **kwargs): - """Wrap the ElasticSearch.search method. - - Raises: - BackendException: raised for any failure. - """ - try: - return self.client.search(**kwargs) - except ApiError as error: - msg = "Failed to execute ElasticSearch query" - logger.error("%s. %s", msg, error) - raise BackendException(msg, *error.args) from error - - def _open_point_in_time(self, **kwargs): - """Wrap the ElasticSearch.open_point_in_time method. - - Raises: - BackendException: raised for any failure. - """ - try: - return self.client.open_point_in_time(**kwargs) - except (ApiError, ValueError) as error: - msg = "Failed to open ElasticSearch point in time" - logger.error("%s. %s", msg, error) - raise BackendException(msg, *error.args) from error diff --git a/src/ralph/backends/database/mongo.py b/src/ralph/backends/database/mongo.py deleted file mode 100644 index c9b6eba10..000000000 --- a/src/ralph/backends/database/mongo.py +++ /dev/null @@ -1,300 +0,0 @@ -"""MongoDB database backend for Ralph.""" - -import hashlib -import json -import logging -import struct -from typing import Generator, List, Optional, TextIO, Union - -from bson.objectid import ObjectId -from dateutil.parser import isoparse -from pymongo import ASCENDING, DESCENDING, MongoClient -from pymongo.errors import BulkWriteError, ConnectionFailure, PyMongoError - -from ralph.conf import MongoClientOptions, settings -from ralph.exceptions import BackendException, BadFormatException - -from .base import ( - AgentParameters, - BaseDatabase, - BaseQuery, - DatabaseStatus, - RalphStatementsQuery, - StatementQueryResult, - enforce_query_checks, -) - -mongo_settings = settings.BACKENDS.DATABASE.MONGO -logger = logging.getLogger(__name__) - - -class MongoQuery(BaseQuery): - """Mongo query model.""" - - filter: Optional[dict] - projection: Optional[dict] - - -class MongoDatabase(BaseDatabase): - """Mongo database backend.""" - - name = "mongo" - query_model = MongoQuery - - def __init__( - self, - connection_uri: str = mongo_settings.CONNECTION_URI, - database: str = mongo_settings.DATABASE, - collection: str = mongo_settings.COLLECTION, - client_options: MongoClientOptions = mongo_settings.CLIENT_OPTIONS, - ): - """Instantiates the Mongo client. - - Args: - connection_uri (str): MongoDB connection URI. - database (str): MongoDB database to connect to. - collection (str): MongoDB database collection to get objects from. 
- client_options (MongoClientOptions): A dictionary of valid options - for the MongoClient class initialization. - """ - self.client = MongoClient(connection_uri, **client_options.dict()) - self.database = getattr(self.client, database) - self.collection = getattr(self.database, collection) - - def status(self) -> DatabaseStatus: - """Check MongoDB cluster connection status.""" - # Check Mongo cluster connection - try: - self.client.admin.command("ping") - except ConnectionFailure: - return DatabaseStatus.AWAY - - # Check cluster status - if self.client.admin.command("serverStatus").get("ok", 0.0) < 1.0: - return DatabaseStatus.ERROR - - return DatabaseStatus.OK - - @enforce_query_checks - def get(self, query: MongoQuery = None, chunk_size: int = 500): - """Get collection documents and yields them. - - The `query` dictionary should only contain kwargs compatible with the - pymongo.collection.Collection.find method signature (API reference - documentation: https://pymongo.readthedocs.io/en/stable/api/pymongo/). - """ - for document in self.collection.find(batch_size=chunk_size, **query.dict()): - # Make the document json-serializable - document.update({"_id": str(document.get("_id"))}) - yield document - - @staticmethod - def to_documents( - stream: Union[TextIO, list], ignore_errors: bool = False - ) -> Generator[dict, None, None]: - """Convert `stream` lines (one statement per line) to Mongo documents. - - We expect statements to have at least an `id` and a `timestamp` field that will - be used to compute a unique MongoDB Object ID. This ensures that we will not - duplicate statements in our database and allows us to support pagination. - """ - for line in stream: - statement = json.loads(line) if isinstance(line, str) else line - if "id" not in statement: - msg = f"statement {statement} has no 'id' field" - if ignore_errors: - logger.warning(msg) - continue - raise BadFormatException(msg) - if "timestamp" not in statement: - msg = f"statement {statement} has no 'timestamp' field" - if ignore_errors: - logger.warning(msg) - continue - raise BadFormatException(msg) - try: - timestamp = int(isoparse(statement["timestamp"]).timestamp()) - except ValueError as err: - msg = f"statement {statement} has an invalid 'timestamp' field" - if ignore_errors: - logger.warning(msg) - continue - raise BadFormatException(msg) from err - document = { - "_id": ObjectId( - # This might become a problem in February 2106. - # Meanwhile, we use the timestamp in the _id field for pagination. 
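# Layout sketch for the ObjectId built below: 4 big-endian bytes of the
# parsed timestamp followed by the first 8 bytes (16 hex characters) of
# sha256(statement id). E.g. (hypothetical) timestamp 0x64df6b17 plus
# digest prefix "a3f1c2d4e5b60718" gives the sortable 12-byte ObjectId
# "64df6b17a3f1c2d4e5b60718".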
- struct.pack(">I", timestamp) - + bytes.fromhex( - hashlib.sha256(bytes(statement["id"], "utf-8")).hexdigest()[:16] - ) - ), - "_source": statement, - } - - yield document - - def bulk_import(self, batch: list, ignore_errors: bool = False): - """Insert a batch of documents into the selected database collection.""" - try: - new_documents = self.collection.insert_many(batch) - except BulkWriteError as error: - if not ignore_errors: - raise BackendException( - *error.args, f"{error.details['nInserted']} succeeded writes" - ) from error - logger.warning( - "Bulk importation failed for current documents chunk but you choose " - "to ignore it.", - ) - return error.details["nInserted"] - - inserted_count = len(new_documents.inserted_ids) - logger.debug("Inserted %d documents chunk with success", inserted_count) - - return inserted_count - - def put( - self, - stream: Union[TextIO, list], - chunk_size: int = 500, - ignore_errors: bool = False, - ) -> int: - """Write documents from the `stream` to the instance collection.""" - logger.debug( - "Start writing to the %s collection of the %s database (chunk size: %d)", - self.collection, - self.database, - chunk_size, - ) - - success = 0 - batch = [] - for document in self.to_documents(stream, ignore_errors=ignore_errors): - batch.append(document) - if len(batch) < chunk_size: - continue - - success += self.bulk_import(batch, ignore_errors=ignore_errors) - batch = [] - - # Edge case: if the total number of documents is lower than the chunk size - if len(batch) > 0: - success += self.bulk_import(batch, ignore_errors=ignore_errors) - - logger.debug("Inserted a total of %d documents with success", success) - - return success - - @staticmethod - def _add_agent_filters( - mongo_query_filters: dict, agent_params: AgentParameters, target_field: str - ): - """Add filters relative to agents to mongo_query_filters. 
- - Args: - mongo_query_filters: filters to be passed to mongo - agent_params: query parameters that represent the agent to search for - target_field: the field in the database in which to perform the search - """ - if agent_params.mbox: - mongo_query_filters.update( - {f"_source.{target_field}.mbox": agent_params.mbox} - ) - - if agent_params.mbox_sha1sum: - mongo_query_filters.update( - {f"_source.{target_field}.mbox_sha1sum": agent_params.mbox_sha1sum} - ) - - if agent_params.openid: - mongo_query_filters.update( - {f"_source.{target_field}.openid": agent_params.openid} - ) - - if agent_params.account__name: - mongo_query_filters.update( - {f"_source.{target_field}.account.name": agent_params.account__name} - ) - mongo_query_filters.update( - { - f"_source.{target_field}.account.homePage": agent_params.account__home_page # noqa: E501 # pylint: disable=line-too-long - } - ) - - def query_statements(self, params: RalphStatementsQuery) -> StatementQueryResult: - """Return the results of a statements query using xAPI parameters.""" - # pylint: disable=too-many-branches - mongo_query_filters = {} - - if params.statement_id: - mongo_query_filters.update({"_source.id": params.statement_id}) - - self._add_agent_filters( - mongo_query_filters, params.__dict__["agent"], target_field="actor" - ) - self._add_agent_filters( - mongo_query_filters, params.__dict__["authority"], target_field="authority" - ) - - if params.verb: - mongo_query_filters.update({"_source.verb.id": params.verb}) - - if params.activity: - mongo_query_filters.update( - { - "_source.object.objectType": "Activity", - "_source.object.id": params.activity, - }, - ) - - if params.since: - mongo_query_filters.update({"_source.timestamp": {"$gt": params.since}}) - - if params.until: - mongo_query_filters.update({"_source.timestamp": {"$lte": params.until}}) - - if params.search_after: - search_order = "$gt" if params.ascending else "$lt" - mongo_query_filters.update( - {"_id": {search_order: ObjectId(params.search_after)}} - ) - - mongo_sort_order = ASCENDING if params.ascending else DESCENDING - mongo_query_sort = [ - ("_source.timestamp", mongo_sort_order), - ("_id", mongo_sort_order), - ] - - mongo_response = self._find( - filter=mongo_query_filters, limit=params.limit, sort=mongo_query_sort - ) - search_after = None - if mongo_response: - search_after = mongo_response[-1]["_id"] - - return StatementQueryResult( - statements=[document["_source"] for document in mongo_response], - pit_id=None, - search_after=search_after, - ) - - def query_statements_by_ids(self, ids: List[str]) -> List: - """Return the list of matching statements from the database.""" - return [ - {"_id": statement["_source"]["id"], "_source": statement["_source"]} - for statement in self._find(filter={"_source.id": {"$in": ids}}) - ] - - def _find(self, **kwargs): - """Wrap the MongoClient.collection.find method. - - Raises: - BackendException: raised for any failure. - """ - try: - return list(self.collection.find(**kwargs)) - except (PyMongoError, IndexError, TypeError, ValueError) as error: - msg = "Failed to execute MongoDB query" - logger.error("%s. 
%s", msg, error) - raise BackendException(msg, *error.args) from error diff --git a/src/ralph/backends/lrs/clickhouse.py b/src/ralph/backends/lrs/clickhouse.py index 423318b35..7c97ecd54 100644 --- a/src/ralph/backends/lrs/clickhouse.py +++ b/src/ralph/backends/lrs/clickhouse.py @@ -165,13 +165,13 @@ def _add_agent_filters( f"event.{target_field}.openid = {{{target_field}__openid:String}}" ) elif agent_params.account__name: - ch_params[f"{target_field}__account_name"] = agent_params.account__name + ch_params[f"{target_field}__account__name"] = agent_params.account__name where.append( - f"event.{target_field}.account_name = {{{target_field}__account_name:String}}" # noqa: E501 # pylint: disable=line-too-long + f"event.{target_field}.account.name = {{{target_field}__account__name:String}}" # noqa: E501 # pylint: disable=line-too-long ) ch_params[ - f"{target_field}__account_homepage" + f"{target_field}__account_home_page" ] = agent_params.account__home_page where.append( - f"event.{target_field}.account_homepage = {{{target_field}__account_homepage:String}}" # noqa: E501 # pylint: disable=line-too-long + f"event.{target_field}.account.homePage = {{{target_field}__account_home_page:String}}" # noqa: E501 # pylint: disable=line-too-long ) diff --git a/src/ralph/backends/storage/__init__.py b/src/ralph/backends/storage/__init__.py deleted file mode 100644 index 6e031999e..000000000 --- a/src/ralph/backends/storage/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# noqa: D104 diff --git a/src/ralph/backends/storage/base.py b/src/ralph/backends/storage/base.py deleted file mode 100644 index a94b492e2..000000000 --- a/src/ralph/backends/storage/base.py +++ /dev/null @@ -1,26 +0,0 @@ -"""Base storage backend for Ralph.""" - -from abc import ABC, abstractmethod -from typing import Iterable - - -class BaseStorage(ABC): - """Base storage backend interface.""" - - name = "base" - - @abstractmethod - def list(self, details=False, new=False): - """List files in the storage backend.""" - - @abstractmethod - def url(self, name): - """Get `name` file absolute URL.""" - - @abstractmethod - def read(self, name, chunk_size: int = 4096): - """Read `name` file and yields its content by chunks of a given size.""" - - @abstractmethod - def write(self, stream: Iterable, name, overwrite=False): - """Write content to the `name` target.""" diff --git a/src/ralph/backends/storage/fs.py b/src/ralph/backends/storage/fs.py deleted file mode 100644 index b77f3f77c..000000000 --- a/src/ralph/backends/storage/fs.py +++ /dev/null @@ -1,123 +0,0 @@ -"""FileSystem storage backend for Ralph.""" - -import datetime -import logging -from pathlib import Path - -from ralph.conf import settings -from ralph.utils import now - -from ..mixins import HistoryMixin -from .base import BaseStorage - -logger = logging.getLogger(__name__) - - -class FSStorage(HistoryMixin, BaseStorage): - """FileSystem storage backend.""" - - name = "fs" - - def __init__(self, path: str = settings.BACKENDS.STORAGE.FS.PATH): - """Create the path directory if it does not exist.""" - self._path = Path(path) - if not self._path.is_dir(): - logger.info("FS storage directory doesn't exist, creating: %s", self._path) - self._path.mkdir(parents=True) - - logger.debug("File system storage path: %s", self._path) - - def _get_filepath(self, name, strict=False): - """Get path for `name` file. - - Raises: - FileNotFoundError: When the file_path is not found. - - Returns: - file_path (Path): path of the archive in the FS storage. 
- """ - file_path = self._path / Path(name) - if strict and not file_path.exists(): - msg = "%s file does not exist" - logger.error(msg, file_path) - raise FileNotFoundError(msg % file_path) - return file_path - - def _details(self, name): - """Get `name` archive details.""" - file_path = self._get_filepath(name) - stats = file_path.stat() - - return { - "filename": name, - "size": stats.st_size, - "modified_at": datetime.datetime.fromtimestamp( - int(stats.st_mtime), tz=datetime.timezone.utc - ).isoformat(), - } - - def list(self, details=False, new=False): - """List files in the storage backend.""" - archives = [archive.name for archive in self._path.iterdir()] - logger.debug("Found %d archives", len(archives)) - - if new: - archives = set(archives) - set(self.get_command_history(self.name, "read")) - logger.debug("New archives: %d", len(archives)) - - for archive in archives: - yield self._details(archive) if details else archive - - def url(self, name): - """Get `name` file absolute URL.""" - return str(self._get_filepath(name).resolve(strict=True)) - - def read(self, name, chunk_size: int = 4096): - """Read `name` file and yields its content by chunks of a given size.""" - logger.debug("Getting archive: %s", name) - - with self._get_filepath(name).open("rb") as file: - while chunk := file.read(chunk_size): - yield chunk - - details = self._details(name) - # Archive is supposed to have been fully fetched, add a new entry to - # the history. - self.append_to_history( - { - "backend": self.name, - "command": "read", - "id": name, - "filename": details.get("filename"), - "size": details.get("size"), - "fetched_at": now(), - } - ) - - def write(self, stream, name, overwrite=False): - """Write content to the `name` target.""" - logger.debug("Creating archive: %s", name) - - file_path = self._get_filepath(name) - if file_path.is_file() and not overwrite: - msg = "%s already exists and overwrite is not allowed" - logger.error(msg, name) - raise FileExistsError(msg, name) - - with file_path.open("wb") as file: - for chunk in stream: - file.write(chunk) - - details = self._details(name) - # Archive is supposed to have been fully created, add a new entry to - # the history. 
- self.append_to_history( - { - "backend": self.name, - "command": "write", - "id": name, - "filename": details.get("filename"), - "size": details.get("size"), - "pushed_at": now(), - } - ) diff --git a/src/ralph/backends/storage/ldp.py b/src/ralph/backends/storage/ldp.py deleted file mode 100644 index d62431cd5..000000000 --- a/src/ralph/backends/storage/ldp.py +++ /dev/null @@ -1,145 +0,0 @@ -"""OVH's LDP storage backend for Ralph.""" - -import logging - -import ovh -import requests - -from ralph.conf import settings -from ralph.exceptions import BackendParameterException -from ralph.utils import now - -from ..mixins import HistoryMixin -from .base import BaseStorage - -ldp_settings = settings.BACKENDS.STORAGE.LDP -logger = logging.getLogger(__name__) - - -class LDPStorage(HistoryMixin, BaseStorage): - """OVH's LDP storage backend.""" - - # pylint: disable=too-many-arguments - - name = "ldp" - - def __init__( - self, - endpoint: str = ldp_settings.ENDPOINT, - application_key: str = ldp_settings.APPLICATION_KEY, - application_secret: str = ldp_settings.APPLICATION_SECRET, - consumer_key: str = ldp_settings.CONSUMER_KEY, - service_name: str = ldp_settings.SERVICE_NAME, - stream_id: str = ldp_settings.STREAM_ID, - ): - """Instantiate the OVH's LDP client.""" - self._endpoint = endpoint - self._application_key = application_key - self._application_secret = application_secret - self._consumer_key = consumer_key - self.service_name = service_name - self.stream_id = stream_id - - self.client = ovh.Client( - endpoint=self._endpoint, - application_key=self._application_key, - application_secret=self._application_secret, - consumer_key=self._consumer_key, - ) - - @property - def _archive_endpoint(self): - if None in (self.service_name, self.stream_id): - msg = ( - "LDPStorage backend instance requires to set both " - "service_name and stream_id" - ) - logger.error(msg) - raise BackendParameterException(msg) - return ( - f"/dbaas/logs/{self.service_name}/" - f"output/graylog/stream/{self.stream_id}/archive" - ) - - def _details(self, name): - """Return `name` archive details. - - Expected JSON response looks like: - - { - "archiveId": "5d49d1b3-a3eb-498c-9039-6a482166f888", - "createdAt": "2020-06-18T04:38:59.436634+02:00", - "filename": "2020-06-16.gz", - "md5": "01585b394be0495e38dbb60b20cb40a9", - "retrievalDelay": 0, - "retrievalState": "sealed", - "sha256": "645d8e21e6fdb8aa7ffc5c[...]9ce612d06df8dcf67cb29a45ca", - "size": 67906662, - } - """ - return self.client.get(f"{self._archive_endpoint}/{name}") - - def url(self, name): - """Get archive absolute URL.""" - download_url_endpoint = f"{self._archive_endpoint}/{name}/url" - - response = self.client.post(download_url_endpoint) - download_url = response.get("url") - logger.debug("Temporary URL: %s", download_url) - - return download_url - - def list(self, details=False, new=False): - """List archives for a given stream. - - Args: - details (bool): Get detailed archive information instead of just ids. - new (bool): Given the history, list only not already fetched archives. 
- """ - list_archives_endpoint = self._archive_endpoint - logger.debug("List archives endpoint: %s", list_archives_endpoint) - logger.debug("List archives details: %s", str(details)) - - archives = self.client.get(list_archives_endpoint) - logger.debug("Found %d archives", len(archives)) - - if new: - archives = set(archives) - set(self.get_command_history(self.name, "read")) - logger.debug("New archives: %d", len(archives)) - - for archive in archives: - yield self._details(archive) if details else archive - - def read(self, name, chunk_size=4096): - """Read the `name` archive file and yields its content.""" - logger.debug("Getting archive: %s", name) - - # Get detailed information about the archive to fetch - details = self._details(name) - - # Stream response (archive content) - with requests.get( # pylint: disable=missing-timeout # nosec - self.url(name), stream=True - ) as result: - result.raise_for_status() - for chunk in result.iter_content(chunk_size=chunk_size): - yield chunk - - # Archive is supposed to have been fully fetched, add a new entry to - # the history. - self.append_to_history( - { - "backend": self.name, - "command": "read", - "id": name, - "filename": details.get("filename"), - "size": details.get("size"), - "fetched_at": now(), - } - ) - - def write(self, stream, name, overwrite=False): - """LDP storage backend is read-only, calling this method will raise an error.""" - msg = "LDP storage backend is read-only, cannot write to %s" - logger.error(msg, name) - raise NotImplementedError(msg % name) diff --git a/src/ralph/backends/storage/s3.py b/src/ralph/backends/storage/s3.py deleted file mode 100644 index c342b798b..000000000 --- a/src/ralph/backends/storage/s3.py +++ /dev/null @@ -1,148 +0,0 @@ -"""S3 storage backend for Ralph.""" - -import logging - -import boto3 -from botocore.exceptions import ClientError, ParamValidationError - -from ralph.conf import settings -from ralph.exceptions import BackendException, BackendParameterException -from ralph.utils import now - -from ..mixins import HistoryMixin -from .base import BaseStorage - -s3_settings = settings.BACKENDS.STORAGE.S3 -logger = logging.getLogger(__name__) - - -class S3Storage( - HistoryMixin, BaseStorage -): # pylint: disable=too-many-instance-attributes - """AWS S3 storage backend.""" - - name = "s3" - - # pylint: disable=too-many-arguments - - def __init__( - self, - access_key_id: str = s3_settings.ACCESS_KEY_ID, - secret_access_key: str = s3_settings.SECRET_ACCESS_KEY, - session_token: str = s3_settings.SESSION_TOKEN, - default_region: str = s3_settings.DEFAULT_REGION, - bucket_name: str = s3_settings.BUCKET_NAME, - endpoint_url: str = s3_settings.ENDPOINT_URL, - ): - """Instantiate the AWS S3 client.""" - self.access_key_id = access_key_id - self.secret_access_key = secret_access_key - self.session_token = session_token - self.default_region = default_region - self.bucket_name = bucket_name - self.endpoint_url = endpoint_url - - self.client = boto3.client( - "s3", - aws_access_key_id=self.access_key_id, - aws_secret_access_key=self.secret_access_key, - aws_session_token=self.session_token, - region_name=self.default_region, - endpoint_url=self.endpoint_url, - ) - - # Check whether bucket exists and is accessible - try: - self.client.head_bucket(Bucket=self.bucket_name) - except ClientError as err: - error_msg = err.response["Error"]["Message"] - msg = "Unable to connect to the requested bucket: %s" - logger.error(msg, error_msg) - raise BackendParameterException(msg % error_msg) from err - - def 
list(self, details=False, new=False): - """List archives in the storage backend.""" - archives_to_skip = set() - if new: - archives_to_skip = set(self.get_command_history(self.name, "read")) - - try: - paginator = self.client.get_paginator("list_objects_v2") - page_iterator = paginator.paginate(Bucket=self.bucket_name) - for archives in page_iterator: - if "Contents" not in archives: - continue - for archive in archives["Contents"]: - if new and archive["Key"] in archives_to_skip: - continue - if details: - archive["LastModified"] = archive["LastModified"].strftime( - "%Y-%m-%d %H:%M:%S" - ) - yield archive - else: - yield archive["Key"] - except ClientError as err: - error_msg = err.response["Error"]["Message"] - msg = "Failed to list the bucket %s: %s" - logger.error(msg, self.bucket_name, error_msg) - raise BackendException(msg % (self.bucket_name, error_msg)) from err - - def url(self, name): - """Get `name` file absolute URL.""" - return f"{self.bucket_name}.s3.{self.default_region}.amazonaws.com/{name}" - - def read(self, name, chunk_size: int = 4096): - """Read `name` file and yields its content by chunks of a given size.""" - logger.debug("Getting archive: %s", name) - - try: - obj = self.client.get_object(Bucket=self.bucket_name, Key=name) - except ClientError as err: - error_msg = err.response["Error"]["Message"] - msg = "Failed to download %s: %s" - logger.error(msg, name, error_msg) - raise BackendException(msg % (name, error_msg)) from err - - size = 0 - for chunk in obj["Body"].iter_chunks(chunk_size): - logger.debug("Chunk length %s", len(chunk)) - size += len(chunk) - yield chunk - - # Archive fetched, add a new entry to the history - self.append_to_history( - { - "backend": self.name, - "command": "read", - "id": name, - "size": size, - "fetched_at": now(), - } - ) - - def write(self, stream, name, overwrite=False): - """Write data from `stream` to the `name` target.""" - if not overwrite and name in list(self.list()): - msg = "%s already exists and overwrite is not allowed" - logger.error(msg, name) - raise FileExistsError(msg % name) - - logger.debug("Creating archive: %s", name) - - try: - self.client.upload_fileobj(stream, self.bucket_name, name) - except (ClientError, ParamValidationError) as exc: - msg = "Failed to upload" - logger.error(msg) - raise BackendException(msg) from exc - - # Archive written, add a new entry to the history - self.append_to_history( - { - "backend": self.name, - "command": "write", - "id": name, - "pushed_at": now(), - } - ) diff --git a/src/ralph/backends/storage/swift.py b/src/ralph/backends/storage/swift.py deleted file mode 100644 index 818e07ef1..000000000 --- a/src/ralph/backends/storage/swift.py +++ /dev/null @@ -1,160 +0,0 @@ -"""Swift storage backend for Ralph.""" - -import logging -from functools import cached_property -from urllib.parse import urlparse - -from swiftclient.service import SwiftService, SwiftUploadObject - -from ralph.conf import settings -from ralph.exceptions import BackendException, BackendParameterException -from ralph.utils import now - -from ..mixins import HistoryMixin -from .base import BaseStorage - -swift_settings = settings.BACKENDS.STORAGE.SWIFT -logger = logging.getLogger(__name__) - - -class SwiftStorage( - HistoryMixin, BaseStorage -): # pylint: disable=too-many-instance-attributes - """OpenStack's Swift storage backend.""" - - name = "swift" - - # pylint: disable=too-many-arguments - - def __init__( - self, - os_tenant_id: str = swift_settings.OS_TENANT_ID, - os_tenant_name: str = 
swift_settings.OS_TENANT_NAME, - os_username: str = swift_settings.OS_USERNAME, - os_password: str = swift_settings.OS_PASSWORD, - os_region_name: str = swift_settings.OS_REGION_NAME, - os_storage_url: str = swift_settings.OS_STORAGE_URL, - os_user_domain_name: str = swift_settings.OS_USER_DOMAIN_NAME, - os_project_domain_name: str = swift_settings.OS_PROJECT_DOMAIN_NAME, - os_auth_url: str = swift_settings.OS_AUTH_URL, - os_identity_api_version: str = swift_settings.OS_IDENTITY_API_VERSION, - ): - """Prepares the options for the SwiftService.""" - self.os_tenant_id = os_tenant_id - self.os_tenant_name = os_tenant_name - self.os_username = os_username - self.os_password = os_password - self.os_region_name = os_region_name - self.os_user_domain_name = os_user_domain_name - self.os_project_domain_name = os_project_domain_name - self.os_auth_url = os_auth_url - self.os_identity_api_version = os_identity_api_version - self.container = urlparse(os_storage_url).path.rpartition("/")[-1] - self.os_storage_url = os_storage_url - if os_storage_url.endswith(f"/{self.container}"): - self.os_storage_url = os_storage_url[: -len(f"/{self.container}")] - - with SwiftService(self.options) as swift: - stats = swift.stat() - if not stats["success"]: - msg = "Unable to connect to the requested container: %s" - logger.error(msg, stats["error"]) - raise BackendParameterException(msg % stats["error"]) - - @cached_property - def options(self): - """Return the required options for the SwiftService.""" - return { - "os_auth_url": self.os_auth_url, - "os_identity_api_version": self.os_identity_api_version, - "os_password": self.os_password, - "os_project_domain_name": self.os_project_domain_name, - "os_region_name": self.os_region_name, - "os_storage_url": self.os_storage_url, - "os_tenant_id": self.os_tenant_id, - "os_tenant_name": self.os_tenant_name, - "os_username": self.os_username, - "os_user_domain_name": self.os_user_domain_name, - } - - def list(self, details=False, new=False): - """List files in the storage backend.""" - archives_to_skip = set() - if new: - archives_to_skip = set(self.get_command_history(self.name, "read")) - with SwiftService(self.options) as swift: - for page in swift.list(self.container): - if not page["success"]: - msg = "Failed to list container %s: %s" - logger.error(msg, page["container"], page["error"]) - raise BackendException(msg % (page["container"], page["error"])) - for archive in page["listing"]: - if new and archive["name"] in archives_to_skip: - continue - yield archive if details else archive["name"] - - def url(self, name): - """Get `name` file absolute URL.""" - # What's the purpose of this function ? Seems not used anywhere. - return f"{self.options.get('os_storage_url')}/{name}" - - def read(self, name, chunk_size=None): - """Read `name` object and yields its content in chunks of (max) 2 ** 16. - - Why chunks of (max) 2 ** 16 ? 
- Because SwiftService opens a file to stream the object into: - See swiftclient.service.py:2082 open(filename, 'rb', DISK_BUFFER) - Where filename = "/dev/stdout" and DISK_BUFFER = 2 ** 16 - """ - logger.debug("Getting archive: %s", name) - - with SwiftService(self.options) as swift: - options = {"out_file": "-"} - download = next(swift.download(self.container, [name], options), {}) - if "contents" not in download: - msg = "Failed to download %s: %s" - error = download.get("error", "swift.download did not yield") - logger.error(msg, download.get("object", name), error) - raise BackendException(msg % (download.get("object", name), error)) - size = 0 - for chunk in download["contents"]: - logger.debug("Chunk %s", len(chunk)) - size += len(chunk) - yield chunk - - # Archive fetched, add a new entry to the history - self.append_to_history( - { - "backend": self.name, - "command": "read", - "id": name, - "size": size, - "fetched_at": now(), - } - ) - - def write(self, stream, name, overwrite=False): - """Write data from `stream` to the `name` target in chunks of (max) 2 ** 16.""" - if not overwrite and name in list(self.list()): - msg = "%s already exists and overwrite is not allowed" - logger.error(msg, name) - raise FileExistsError(msg % name) - - logger.debug("Creating archive: %s", name) - - swift_object = SwiftUploadObject(stream, object_name=name) - with SwiftService(self.options) as swift: - for upload in swift.upload(self.container, [swift_object]): - if not upload["success"]: - logger.error(upload["error"]) - raise BackendException(upload["error"]) - - # Archive written, add a new entry to the history - self.append_to_history( - { - "backend": self.name, - "command": "write", - "id": name, - "pushed_at": now(), - } - ) diff --git a/src/ralph/backends/stream/base.py b/src/ralph/backends/stream/base.py index 6d3a3addc..1f2b1f11e 100644 --- a/src/ralph/backends/stream/base.py +++ b/src/ralph/backends/stream/base.py @@ -19,7 +19,7 @@ class Config(BaseSettingsConfig): env_file_encoding = core_settings.LOCALE_ENCODING -class BaseStream(ABC): +class BaseStreamBackend(ABC): """Base stream backend interface.""" name = "base" diff --git a/src/ralph/cli.py b/src/ralph/cli.py index 5dc311446..b3272a021 100644 --- a/src/ralph/cli.py +++ b/src/ralph/cli.py @@ -4,7 +4,7 @@ import logging import re import sys -from inspect import isclass +from inspect import isclass, isasyncgen from pathlib import Path from tempfile import NamedTemporaryFile from typing import List @@ -29,6 +29,8 @@ from pydantic import BaseModel from ralph import __version__ as ralph_version +from ralph.backends.conf import backends_settings +from ralph.backends.data.base import BaseOperationType from ralph.conf import ClientOptions, CommaSeparatedTuple, HeadersParameters, settings from ralph.exceptions import UnsupportedBackendException from ralph.logger import configure_logging @@ -40,6 +42,7 @@ get_backend_type, get_root_logger, import_string, + iter_over_async ) # cli module logger @@ -207,7 +210,9 @@ def wrapper(command): backend_names.append(backend_name) for field_name, field in backend: field_type = backend.__fields__[field_name].type_ - field_name = f"{backend_name}-{field_name}".replace("_", "-") + field_name = f"{backend_name}-{field_name.lower()}".replace( + "_", "-" + ) option = f"--{field_name}" option_kwargs = {} # If the field is a boolean, convert it to a flag option @@ -224,6 +229,8 @@ def wrapper(command): field_type, HeadersParameters ): option_kwargs["type"] = HeadersParametersParamType(field_type) + elif 
field_type is Path: + option_kwargs["type"] = click.Path() command = optgroup.option( option.lower(), default=field, **option_kwargs @@ -553,8 +560,15 @@ def convert(from_, to_, ignore_errors, fail_on_unknown, **conversion_set_kwargs) click.echo(event) +read_backend_types = [ + backends_settings.BACKENDS.DATA, + backends_settings.BACKENDS.HTTP, + backends_settings.BACKENDS.STREAM, +] + + @click.argument("archive", required=False) -@backends_options(backend_types=[backend for _, backend in settings.BACKENDS]) +@backends_options(backend_types=read_backend_types) @click.option( "-c", "--chunk-size", @@ -576,7 +590,23 @@ def convert(from_, to_, ignore_errors, fail_on_unknown, **conversion_set_kwargs) default=None, help="Query object as a JSON string (database and HTTP backends ONLY)", ) -def read(backend, archive, chunk_size, target, query, **options): +@click.option( + "-i", + "--ignore_errors", + is_flag=False, + show_default=True, + default=False, + help="Ignore errors during the encoding operation.", +) +def read( + backend, + archive, + chunk_size, + target, + query, + ignore_errors, + **options, +): # pylint: disable=too-many-arguments """Read an archive or records from a configured backend.""" logger.info( ( @@ -591,25 +621,22 @@ def read(backend, archive, chunk_size, target, query, **options): ) logger.debug("Backend parameters: %s", options) - backend_type = get_backend_type(settings.BACKENDS, backend) + backend_type = get_backend_type(read_backend_types, backend) backend = get_backend_instance(backend_type, backend, options) - if backend_type == settings.BACKENDS.STORAGE: - for data in backend.read(archive, chunk_size=chunk_size): - click.echo(data, nl=False) - elif backend_type == settings.BACKENDS.DATABASE: - if query is not None: - query = backend.query_model.parse_obj(query) - for document in backend.get(query=query, chunk_size=chunk_size): - click.echo( - bytes( - json.dumps(document) if isinstance(document, dict) else document, - encoding="utf-8", - ) - ) - elif backend_type == settings.BACKENDS.STREAM: + if backend_type == backends_settings.BACKENDS.DATA: + for statement in backend.read( + query=query, + target=target, + chunk_size=chunk_size, + raw_output=True, + ignore_errors=ignore_errors, + ): + click.echo(statement) + + elif backend_type == backends_settings.BACKENDS.STREAM: backend.stream(sys.stdout.buffer) - elif backend_type == settings.BACKENDS.HTTP: + elif backend_type == backends_settings.BACKENDS.HTTP: if query is not None: query = backend.query(query=query) for statement in backend.read( @@ -627,15 +654,14 @@ def read(backend, archive, chunk_size, target, query, **options): raise UnsupportedBackendException(msg, backend) +write_backend_types = [ + backends_settings.BACKENDS.DATA, + backends_settings.BACKENDS.HTTP, +] + + # pylint: disable=unnecessary-direct-lambda-call, too-many-arguments -@click.argument("archive", required=False) -@backends_options( - backend_types=[ - settings.BACKENDS.DATABASE, - settings.BACKENDS.STORAGE, - settings.BACKENDS.HTTP, - ] -) +@backends_options(backend_types=write_backend_types) @click.option( "-c", "--chunk-size", @@ -679,11 +705,10 @@ def read(backend, archive, chunk_size, target, query, **options): "--target", type=str, default=None, - help="Endpoint in which to write events (e.g. 
`statements`)", + help="The target container to write into", ) def write( backend, - archive, chunk_size, force, ignore_errors, @@ -693,21 +718,27 @@ def write( **options, ): """Write an archive to a configured backend.""" - logger.info("Writing archive %s to the configured %s backend", archive, backend) + logger.info("Writing to target %s for the configured %s backend", target, backend) logger.debug("Backend parameters: %s", options) if max_num_simultaneous == 1: max_num_simultaneous = None - backend_type = get_backend_type(settings.BACKENDS, backend) + backend_type = get_backend_type(write_backend_types, backend) backend = get_backend_instance(backend_type, backend, options) - if backend_type == settings.BACKENDS.STORAGE: - backend.write(sys.stdin.buffer, archive, overwrite=force) - elif backend_type == settings.BACKENDS.DATABASE: - backend.put(sys.stdin, chunk_size=chunk_size, ignore_errors=ignore_errors) - elif backend_type == settings.BACKENDS.HTTP: + if backend_type == backends_settings.BACKENDS.DATA: + backend.write( + data=sys.stdin.buffer, + target=target, + chunk_size=chunk_size, + ignore_errors=ignore_errors, + operation_type=BaseOperationType.UPDATE + if force + else BaseOperationType.INDEX, + ) + elif backend_type == backends_settings.BACKENDS.HTTP: backend.write( target=target, data=sys.stdin.buffer, @@ -722,39 +753,54 @@ def write( raise UnsupportedBackendException(msg, backend) -@backends_options(name="list", backend_types=[settings.BACKENDS.STORAGE]) +list_backend_types = [backends_settings.BACKENDS.DATA] + + +@backends_options(name="list", backend_types=list_backend_types) +@click.option( + "-t", + "--target", + type=str, + default=None, + help="Container to list events from", +) @click.option( "-n/-a", "--new/--all", default=False, - help="List not fetched (or all) archives", + help="List not fetched (or all) documents", ) @click.option( "-D/-I", "--details/--ids", default=False, - help="Get archives detailed output (JSON)", + help="Get documents detailed output (JSON)", ) -def list_(details, new, backend, **options): - """List available archives from a configured storage backend.""" - logger.info("Listing archives for the configured %s backend", backend) +def list_(target, details, new, backend, **options): + """List available documents from a configured data backend.""" + logger.info("Listing documents for the configured %s backend", backend) + logger.debug("Target container: %s", target) logger.debug("Fetch details: %s", str(details)) logger.debug("Backend parameters: %s", options) - storage = get_backend_instance(settings.BACKENDS.STORAGE, backend, options) - - archives = storage.list(details=details, new=new) + backend_type = get_backend_type(list_backend_types, backend) + backend = get_backend_instance(backend_type, backend, options) + documents = backend.list(target=target, details=details, new=new) + documents = iter_over_async(documents) if isasyncgen(documents) else documents counter = 0 - for archive in archives: - click.echo(json.dumps(archive) if details else archive) + for document in documents: + click.echo(json.dumps(document) if details else document) counter += 1 if counter == 0: - logger.warning("Configured %s backend contains no archive", backend) + logger.warning("Configured %s backend contains no document", backend.name) + + +runserver_backend_types = [backends_settings.BACKENDS.LRS] -@backends_options(name="runserver", backend_types=[settings.BACKENDS.DATABASE]) +@backends_options(name="runserver", backend_types=runserver_backend_types) 
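+# `backends_options` expands, for each backend declared in
+# `runserver_backend_types`, every settings field into a
+# `--<backend>-<field>` CLI option (see the `wrapper` helper above).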
@click.option( "-h", "--host", @@ -794,7 +840,7 @@ def runserver(backend: str, host: str, port: int, **options): if value is None: continue backend_name, field_name = key.split(sep="_", maxsplit=1) - key = f"RALPH_BACKENDS__DATABASE__{backend_name}__{field_name}".upper() + key = f"RALPH_BACKENDS__LRS__{backend_name}__{field_name}".upper() if isinstance(value, tuple): value = ",".join(value) if issubclass(type(value), ClientOptions): diff --git a/src/ralph/conf.py b/src/ralph/conf.py index f10c1ca0e..0415a5ee3 100644 --- a/src/ralph/conf.py +++ b/src/ralph/conf.py @@ -19,7 +19,7 @@ from unittest.mock import Mock get_app_dir = Mock(return_value=".") -from pydantic import AnyHttpUrl, AnyUrl, BaseModel, BaseSettings, Extra, Field +from pydantic import AnyHttpUrl, AnyUrl, BaseModel, BaseSettings, Extra from .utils import import_string @@ -32,6 +32,7 @@ class BaseSettingsConfig: case_sensitive = True env_nested_delimiter = "__" env_prefix = "RALPH_" + extra = "ignore" class CoreSettings(BaseSettings): @@ -78,9 +79,6 @@ def get_instance(self, **init_parameters): return import_string(self._class_path)(**init_parameters) -# Active database backend Settings. - - class ClientOptions(BaseModel): """Pydantic model for additional client options.""" @@ -88,74 +86,6 @@ class Config: # pylint: disable=missing-class-docstring # noqa: D106 extra = Extra.forbid -class ClickhouseClientOptions(ClientOptions): - """Pydantic model for `clickhouse` client options.""" - - date_time_input_format: str = "best_effort" - allow_experimental_object_type: Literal[0, 1] = None - - -class ESClientOptions(ClientOptions): - """Pydantic model for Elasticsearch additional client options.""" - - ca_certs: Path = None - verify_certs: bool = None - - -class ClickhouseDatabaseBackendSettings(InstantiableSettingsItem): - """Pydantic model for ClickHouse database backend configuration settings.""" - - _class_path: str = "ralph.backends.database.clickhouse.ClickHouseDatabase" - - HOST: str = "localhost" - PORT: int = 8123 - DATABASE: str = "xapi" - EVENT_TABLE_NAME: str = "xapi_events_all" - USERNAME: str = None - PASSWORD: str = None - CLIENT_OPTIONS: ClickhouseClientOptions = None - - -class MongoClientOptions(ClientOptions): - """Pydantic model for MongoDB additional client options.""" - - document_class: str = None - tz_aware: bool = None - - -class ESDatabaseBackendSettings(InstantiableSettingsItem): - """Pydantic model for Elasticsearch database backend configuration settings.""" - - _class_path: str = "ralph.backends.database.es.ESDatabase" - - HOSTS: CommaSeparatedTuple = ("http://localhost:9200",) - INDEX: str = "statements" - CLIENT_OPTIONS: ESClientOptions = ESClientOptions() - OP_TYPE: Literal["index", "create", "delete", "update"] = "index" - - -class MongoDatabaseBackendSettings(InstantiableSettingsItem): - """Pydantic model for Mongo database backend configuration settings.""" - - _class_path: str = "ralph.backends.database.mongo.MongoDatabase" - - CONNECTION_URI: str = "mongodb://localhost:27017/" - DATABASE: str = "statements" - COLLECTION: str = "marsha" - CLIENT_OPTIONS: MongoClientOptions = MongoClientOptions() - - -class DatabaseBackendSettings(BaseModel): - """Pydantic model for database backend configuration settings.""" - - ES: ESDatabaseBackendSettings = ESDatabaseBackendSettings() - MONGO: MongoDatabaseBackendSettings = MongoDatabaseBackendSettings() - CLICKHOUSE: ClickhouseDatabaseBackendSettings = ClickhouseDatabaseBackendSettings() - - -# Active HTTP backend Settings. 
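These per-backend settings models (database above; HTTP, storage and stream below) all leave `ralph.conf` in this patch, and the CLI now imports their replacements from `ralph.backends.conf` as `backends_settings`. As a reminder of how such nested settings resolve, here is a minimal, self-contained sketch (illustrative class names, not Ralph's real ones; pydantic v1 API as used in this codebase) of the `RALPH_` prefix and `__` nested delimiter configured in `BaseSettingsConfig`:

import os

from pydantic import BaseModel, BaseSettings  # pydantic v1 API


class ESSettings(BaseModel):
    HOSTS: str = "http://localhost:9200"


class LRSBackendSettings(BaseModel):
    ES: ESSettings = ESSettings()


class Backends(BaseModel):
    LRS: LRSBackendSettings = LRSBackendSettings()


class Settings(BaseSettings):
    class Config:
        env_prefix = "RALPH_"
        env_nested_delimiter = "__"

    BACKENDS: Backends = Backends()


# The same naming scheme `runserver` rebuilds above:
os.environ["RALPH_BACKENDS__LRS__ES__HOSTS"] = "http://elasticsearch:9200"
assert Settings().BACKENDS.LRS.ES.HOSTS == "http://elasticsearch:9200"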
- - class HeadersParameters(BaseModel): """Pydantic model for headers parameters.""" @@ -163,124 +93,6 @@ class Config: # pylint: disable=missing-class-docstring # noqa: D106 extra = Extra.allow -class LRSHeaders(HeadersParameters): - """Pydantic model for LRS headers.""" - - X_EXPERIENCE_API_VERSION: str = Field("1.0.3", alias="X-Experience-API-Version") - CONTENT_TYPE: str = Field("application/json", alias="content-type") - - -class LRSHTTPBackendSettings(InstantiableSettingsItem): - """Pydantic model for LRS HTTP backend configuration settings.""" - - _class_path: str = "ralph.backends.http.lrs.LRSHTTP" - - BASE_URL: AnyHttpUrl = Field("http://0.0.0.0:8100") - USERNAME: str = "ralph" - PASSWORD: str = "secret" - HEADERS: LRSHeaders = LRSHeaders() - STATUS_ENDPOINT: str = "/__heartbeat__" - STATEMENTS_ENDPOINT: str = "/xAPI/statements" - - -class HTTPBackendSettings(BaseModel): - """Pydantic model for HTTP backend configuration settings.""" - - LRS: LRSHTTPBackendSettings = LRSHTTPBackendSettings() - - -# Active storage backend Settings. - - -class FSStorageBackendSettings(InstantiableSettingsItem): - """Pydantic model for FileSystem storage backend configuration settings.""" - - _class_path: str = "ralph.backends.storage.fs.FSStorage" - - PATH: str = str(core_settings.APP_DIR / "archives") - - -class LDPStorageBackendSettings(InstantiableSettingsItem): - """Pydantic model for LDP storage backend configuration settings.""" - - _class_path: str = "ralph.backends.storage.ldp.LDPStorage" - - ENDPOINT: str = None - APPLICATION_KEY: str = None - APPLICATION_SECRET: str = None - CONSUMER_KEY: str = None - SERVICE_NAME: str = None - STREAM_ID: str = None - - -class SWIFTStorageBackendSettings(InstantiableSettingsItem): - """Pydantic model for SWIFT storage backend configuration settings.""" - - _class_path: str = "ralph.backends.storage.swift.SwiftStorage" - - OS_TENANT_ID: str = None - OS_TENANT_NAME: str = None - OS_USERNAME: str = None - OS_PASSWORD: str = None - OS_REGION_NAME: str = None - OS_STORAGE_URL: str = None - OS_USER_DOMAIN_NAME: str = "Default" - OS_PROJECT_DOMAIN_NAME: str = "Default" - OS_AUTH_URL: str = "https://auth.cloud.ovh.net/" - OS_IDENTITY_API_VERSION: str = "3" - - -class S3StorageBackendSettings(InstantiableSettingsItem): - """Represents the S3 storage backend configuration settings.""" - - _class_path: str = "ralph.backends.storage.s3.S3Storage" - - ACCESS_KEY_ID: str = None - SECRET_ACCESS_KEY: str = None - SESSION_TOKEN: str = None - DEFAULT_REGION: str = None - BUCKET_NAME: str = None - ENDPOINT_URL: str = None - - -class StorageBackendSettings(BaseModel): - """Pydantic model for storage backend configuration settings.""" - - LDP: LDPStorageBackendSettings = LDPStorageBackendSettings() - FS: FSStorageBackendSettings = FSStorageBackendSettings() - SWIFT: SWIFTStorageBackendSettings = SWIFTStorageBackendSettings() - S3: S3StorageBackendSettings = S3StorageBackendSettings() - - -# Active storage backend Settings. - - -class WSStreamBackendSettings(InstantiableSettingsItem): - """Pydantic model for Websocket stream backend configuration settings.""" - - _class_path: str = "ralph.backends.stream.ws.WSStream" - - URI: str = None - - -class StreamBackendSettings(BaseModel): - """Pydantic model for stream backend configuration settings.""" - - WS: WSStreamBackendSettings = WSStreamBackendSettings() - - -# Active backend Settings. 
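The aggregating `BackendSettings` model just below is deleted too; each CLI command now declares the backend types it supports (`read_backend_types`, `write_backend_types`, `list_backend_types`, `runserver_backend_types` above) and resolves a backend name against that list with `ralph.utils.get_backend_type`. A small sketch of that resolution logic, using hypothetical stand-in settings objects in place of the real per-type models:

from types import SimpleNamespace
from typing import List, Optional

# Hypothetical stand-ins for per-type settings models.
DATA = SimpleNamespace(ES="es settings", FS="fs settings")
HTTP = SimpleNamespace(LRS="lrs settings")


def get_backend_type(
    backend_types: List[object], backend_name: str
) -> Optional[object]:
    """Return the first settings model exposing `backend_name` (upper-cased)."""
    backend_name = backend_name.upper()
    for backend_type in backend_types:
        if hasattr(backend_type, backend_name):
            return backend_type
    return None


assert get_backend_type([DATA, HTTP], "es") is DATA
assert get_backend_type([DATA, HTTP], "lrs") is HTTP
assert get_backend_type([DATA, HTTP], "swift") is None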
- - -class BackendSettings(BaseModel): - """Pydantic model for backends configuration settings.""" - - DATABASE: DatabaseBackendSettings = DatabaseBackendSettings() - HTTP: HTTPBackendSettings = HTTPBackendSettings() - STORAGE: StorageBackendSettings = StorageBackendSettings() - STREAM: StreamBackendSettings = StreamBackendSettings() - - # Active parser Settings. @@ -336,7 +148,6 @@ class AuthBackends(Enum): AUTH_FILE: Path = _CORE.APP_DIR / "auth.json" AUTH_CACHE_MAX_SIZE = 100 AUTH_CACHE_TTL = 3600 - BACKENDS: BackendSettings = BackendSettings() CONVERTER_EDX_XAPI_UUID_NAMESPACE: str = None DEFAULT_BACKEND_CHUNK_SIZE: int = 500 EXECUTION_ENVIRONMENT: str = "development" diff --git a/src/ralph/utils.py b/src/ralph/utils.py index 090085534..e0f3f9eec 100644 --- a/src/ralph/utils.py +++ b/src/ralph/utils.py @@ -7,6 +7,7 @@ import operator from functools import reduce from importlib import import_module +from inspect import getmembers, isclass from typing import Any, Dict, Iterable, Iterator, List, Union from pydantic import BaseModel @@ -14,12 +15,29 @@ from ralph.exceptions import BackendException +def import_subclass(dotted_path, parent_class): + """Import a dotted module path. + + Return the class that is a subclass of `parent_class` inside this module. + Raise ImportError if the import failed. + """ + module = import_module(dotted_path) + + for _, class_ in getmembers(module, isclass): + if issubclass(class_, parent_class): + return class_ + + raise ImportError( + f'Module "{dotted_path}" does not define a subclass of "{parent_class}" class' + ) + + # Taken from Django utilities # https://docs.djangoproject.com/en/3.1/_modules/django/utils/module_loading/#import_string def import_string(dotted_path): """Import a dotted module path. - Returns the attribute/class designated by the last name in the path. + Return the attribute/class designated by the last name in the path. Raise ImportError if the import failed. 
""" try: @@ -37,23 +55,60 @@ def import_string(dotted_path): ) from err -def get_backend_type(backends: BaseModel, backend_name: str): +def get_backend_type(backend_types: List[BaseModel], backend_name: str): """Return the backend type from a backend name.""" backend_name = backend_name.upper() - for _, backend_type in backends: + for backend_type in backend_types: if hasattr(backend_type, backend_name): return backend_type return None -def get_backend_instance(backend_type: BaseModel, backend_name: str, options: dict): - """Return the instantiated backend instance given backend-name-prefixed options.""" +def get_backend_class(backend_type: BaseModel, backend_name: str): + """Return the backend class given the backend type and backend name.""" + # Get type name from backend_type class name + backend_type_name = backend_type.__class__.__name__[ + : -len("BackendSettings") + ].lower() + backend_name = backend_name.lower() + + module = import_module(f"ralph.backends.{backend_type_name}.{backend_name}") + for _, class_ in getmembers(module, isclass): + if ( + getattr(class_, "type", None) == backend_type_name + and getattr(class_, "name", None) == backend_name + ): + backend_class = class_ + break + + if not backend_class: + raise BackendException( + f'No backend named "{backend_name}" ' + f'under the backend type "{backend_type_name}"' + ) + + return backend_class + + +def get_backend_instance( + backend_type: BaseModel, + backend_name: str, + options: Union[dict, None] = None, +): + """Return the instantiated backend given the backend type, name and options.""" + backend_class = get_backend_class(backend_type, backend_name) + backend_settings = getattr(backend_type, backend_name.upper()) + + if not options: + return backend_class(backend_settings) + prefix = f"{backend_name}_" # Filter backend-related parameters. 
Parameter name is supposed to start # with the backend name names = filter(lambda key: key.startswith(prefix), options.keys()) - options = {name.replace(prefix, ""): options[name] for name in names} - return getattr(backend_type, backend_name.upper()).get_instance(**options) + options = {name.replace(prefix, "").upper(): options[name] for name in names} + + return backend_class(backend_settings.__class__(**options)) def get_root_logger(): diff --git a/tests/backends/database/__init__.py b/tests/backends/database/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/backends/database/test_clickhouse.py b/tests/backends/database/test_clickhouse.py deleted file mode 100644 index 2f3e78f8c..000000000 --- a/tests/backends/database/test_clickhouse.py +++ /dev/null @@ -1,533 +0,0 @@ -"""Tests for Ralph clickhouse database backend.""" - -import logging -import uuid -from datetime import datetime, timedelta - -import pytest -import pytz -from clickhouse_connect.driver.exceptions import ClickHouseError -from clickhouse_connect.driver.httpclient import HttpClient - -from ralph.backends.database.base import DatabaseStatus, RalphStatementsQuery -from ralph.backends.database.clickhouse import ClickHouseDatabase, ClickHouseQuery -from ralph.exceptions import ( - BackendException, - BackendParameterException, - BadFormatException, -) - -from tests.fixtures.backends import ( - CLICKHOUSE_TEST_DATABASE, - CLICKHOUSE_TEST_HOST, - CLICKHOUSE_TEST_PORT, - CLICKHOUSE_TEST_TABLE_NAME, - get_clickhouse_test_backend, -) - - -def test_backends_db_clickhouse_database_instantiation(): - """Test the ClickHouse backend instantiation.""" - assert ClickHouseDatabase.name == "clickhouse" - - backend = get_clickhouse_test_backend() - - assert isinstance(backend.client, HttpClient) - assert backend.database == CLICKHOUSE_TEST_DATABASE - - -# pylint: disable=unused-argument -def test_backends_db_clickhouse_get_method(clickhouse): - """Test the clickhouse backend get method.""" - # Create records - date_1 = (datetime.now() - timedelta(seconds=3)).isoformat() - date_2 = (datetime.now() - timedelta(seconds=2)).isoformat() - date_3 = (datetime.now() - timedelta(seconds=1)).isoformat() - - statements = [ - {"id": str(uuid.uuid4()), "bool": 1, "timestamp": date_1}, - {"id": str(uuid.uuid4()), "bool": 0, "timestamp": date_2}, - {"id": str(uuid.uuid4()), "bool": 1, "timestamp": date_3}, - ] - documents = list(ClickHouseDatabase.to_documents(statements)) - - backend = get_clickhouse_test_backend() - backend.bulk_import(documents) - - results = list(backend.get()) - assert len(results) == 3 - assert results[0]["event"] == statements[0] - assert results[1]["event"] == statements[1] - assert results[2]["event"] == statements[2] - - results = list(backend.get(chunk_size=1)) - assert len(results) == 3 - assert results[0]["event"] == statements[0] - assert results[1]["event"] == statements[1] - assert results[2]["event"] == statements[2] - - results = list(backend.get(chunk_size=1000)) - assert len(results) == 3 - assert results[0]["event"] == statements[0] - assert results[1]["event"] == statements[1] - assert results[2]["event"] == statements[2] - - -# pylint: disable=unused-argument -def test_backends_db_clickhouse_get_method_on_timestamp_boundary(clickhouse): - """Make sure no rows are lost on pagination if they have the same timestamp.""" - # Create records - date_1 = "2023-02-17T16:55:17.721627" - date_2 = "2023-02-17T16:55:14.721633" - - # Using fixed UUIDs here to make sure they always come back in the 
same order - statements = [ - {"id": "9e1310cb-875f-4b14-9410-6443399be63c", "timestamp": date_1}, - {"id": "f93b5796-e0b1-4221-a867-7c2c820f9b68", "timestamp": date_2}, - {"id": "af8effc0-26eb-42b6-8f64-3a0d6b26c16c", "timestamp": date_2}, - ] - documents = list(ClickHouseDatabase.to_documents(statements)) - - backend = get_clickhouse_test_backend() - backend.bulk_import(documents) - - # First get all 3 rows with default settings - results = backend.query_statements(RalphStatementsQuery.construct()) - result_statements = results.statements - assert len(result_statements) == 3 - assert result_statements[0] == statements[0] - assert result_statements[1] == statements[1] - assert result_statements[2] == statements[2] - - # Next get them one at a time, starting with the first - params = RalphStatementsQuery.construct(limit=1) - results = backend.query_statements(params) - result_statements = results.statements - assert len(result_statements) == 1 - assert result_statements[0] == statements[0] - - # Next get the second row with an appropriate search after - params = RalphStatementsQuery.construct( - limit=1, - search_after=results.search_after, - pit_id=results.pit_id, - ) - results = backend.query_statements(params) - result_statements = results.statements - assert len(result_statements) == 1 - assert result_statements[0] == statements[1] - - # And finally the third - params = RalphStatementsQuery.construct( - limit=1, - search_after=results.search_after, - pit_id=results.pit_id, - ) - results = backend.query_statements(params) - result_statements = results.statements - assert len(result_statements) == 1 - assert result_statements[0] == statements[2] - - -# pylint: disable=unused-argument -def test_backends_db_clickhouse_get_method_with_a_custom_query(clickhouse): - """Test the clickhouse backend get method with a custom query.""" - date_1 = (datetime.now() - timedelta(seconds=3)).isoformat() - date_2 = (datetime.now() - timedelta(seconds=2)).isoformat() - date_3 = (datetime.now() - timedelta(seconds=1)).isoformat() - - statements = [ - {"id": str(uuid.uuid4()), "bool": 1, "timestamp": date_1}, - {"id": str(uuid.uuid4()), "bool": 0, "timestamp": date_2}, - {"id": str(uuid.uuid4()), "bool": 1, "timestamp": date_3}, - ] - documents = list(ClickHouseDatabase.to_documents(statements)) - - backend = get_clickhouse_test_backend() - backend.bulk_import(documents) - - # Test filtering - query = ClickHouseQuery(where_clause="event.bool = 1") - results = list(backend.get(query=query)) - assert len(results) == 2 - assert results[0]["event"] == statements[0] - assert results[1]["event"] == statements[2] - - # Test fields - query = ClickHouseQuery(return_fields=["event_id", "event.bool"]) - results = list(backend.get(query=query)) - assert len(results) == 3 - assert len(results[0]) == 2 - assert results[0]["event_id"] == documents[0][0] - assert results[0]["event.bool"] == statements[0]["bool"] - assert results[1]["event_id"] == documents[1][0] - assert results[1]["event.bool"] == statements[1]["bool"] - assert results[2]["event_id"] == documents[2][0] - assert results[2]["event.bool"] == statements[2]["bool"] - - # Test filtering and projection - query = ClickHouseQuery( - where_clause="event.bool = 0", return_fields=["event_id", "event.bool"] - ) - results = list(backend.get(query=query)) - assert len(results) == 1 - assert len(results[0]) == 2 - assert results[0]["event_id"] == documents[1][0] - assert results[0]["event.bool"] == statements[1]["bool"] - - # Check query argument type - with 
pytest.raises( - BackendParameterException, - match="'query' argument is expected to be a ClickHouseQuery instance.", - ): - list(backend.get(query="foo")) - - -def test_backends_db_clickhouse_to_documents_method(): - """Test the clickhouse backend to_documents method.""" - native_statements = [ - { - "id": uuid.uuid4(), - "timestamp": datetime.now(pytz.utc) - timedelta(seconds=1), - }, - {"id": uuid.uuid4(), "timestamp": datetime.now(pytz.utc)}, - ] - # Add a duplicate row to ensure statement transformation is idempotent - native_statements.append(native_statements[1]) - - statements = [ - {"id": str(x["id"]), "timestamp": x["timestamp"].isoformat()} - for x in native_statements - ] - documents = ClickHouseDatabase.to_documents(statements) - - doc = next(documents) - assert doc[0] == native_statements[0]["id"] - assert doc[1] == native_statements[0]["timestamp"].replace(tzinfo=pytz.UTC) - assert doc[2] == statements[0] - - doc = next(documents) - assert doc[0] == native_statements[1]["id"] - assert doc[1] == native_statements[1]["timestamp"].replace(tzinfo=pytz.UTC) - assert doc[2] == statements[1] - - # Identical statement ID produces the same Object - doc = next(documents) - assert doc[0] == native_statements[1]["id"] - assert doc[1] == native_statements[1]["timestamp"].replace(tzinfo=pytz.UTC) - assert doc[2] == statements[1] - - -def test_backends_db_clickhouse_to_documents_method_when_statement_has_no_id( - caplog, -): - """Test the clickhouse to_documents method when a statement has no id field.""" - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [ - {"id": str(uuid.uuid4()), **timestamp}, - {**timestamp}, - {"id": str(uuid.uuid4()), **timestamp}, - ] - - documents = ClickHouseDatabase.to_documents(statements, ignore_errors=False) - assert next(documents)[0] == uuid.UUID(statements[0]["id"], version=4) - - with pytest.raises( - BadFormatException, - match="Statement has an invalid or missing id or " "timestamp field", - ): - next(documents) - - documents = ClickHouseDatabase.to_documents(statements, ignore_errors=True) - assert next(documents)[0] == uuid.UUID(statements[0]["id"], version=4) - assert next(documents)[0] == uuid.UUID(statements[2]["id"], version=4) - - assert len(caplog.records) == 1 - assert caplog.records[0].levelname == "WARNING" - assert ( - "Statement has an invalid or missing id or timestamp field" - in caplog.records[0].message - ) - - -def test_backends_db_clickhouse_to_documents_method_when_statement_has_no_timestamp( - caplog, -): - """Test the clickhouse to_documents method when a statement has no timestamp.""" - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [ - {"id": str(uuid.uuid4()), **timestamp}, - {"id": str(uuid.uuid4())}, - {"id": str(uuid.uuid4()), **timestamp}, - ] - - documents = ClickHouseDatabase.to_documents(statements, ignore_errors=False) - assert next(documents)[0] == uuid.UUID(statements[0]["id"], version=4) - - with pytest.raises( - BadFormatException, - match="Statement has an invalid or missing id or " "timestamp field", - ): - next(documents) - - documents = ClickHouseDatabase.to_documents(statements, ignore_errors=True) - assert next(documents)[0] == uuid.UUID(statements[0]["id"], version=4) - assert next(documents)[0] == uuid.UUID(statements[2]["id"], version=4) - - assert len(caplog.records) == 1 - assert caplog.records[0].levelname == "WARNING" - assert ( - "Statement has an invalid or missing id or timestamp field" - in caplog.records[0].message - ) - - -def 
test_backends_db_clickhouse_to_documents_method_with_invalid_timestamp( - caplog, -): - """Test the clickhouse to_documents method with an invalid timestamp.""" - valid_timestamp = {"timestamp": "2022-06-27T15:36:50"} - valid_timestamp_2 = {"timestamp": "2022-06-27T15:36:51"} - invalid_timestamp = {"timestamp": "This is not a valid timestamp!"} - invalid_statement = {"id": str(uuid.uuid4()), **invalid_timestamp} - statements = [ - {"id": str(uuid.uuid4()), **valid_timestamp}, - invalid_statement, - {"id": str(uuid.uuid4()), **valid_timestamp_2}, - ] - - with pytest.raises( - BadFormatException, - match="Statement has an invalid or missing id or timestamp field", - ): - # Since this is a generator the error won't happen until the failing - # statement is processed. - list(ClickHouseDatabase.to_documents(statements, ignore_errors=False)) - - documents = ClickHouseDatabase.to_documents(statements, ignore_errors=True) - assert next(documents)[0] == uuid.UUID(statements[0]["id"], version=4) - assert next(documents)[0] == uuid.UUID(statements[2]["id"], version=4) - assert len(caplog.records) == 1 - assert caplog.records[0].levelname == "WARNING" - assert ( - "Statement has an invalid or missing id or timestamp field" - in caplog.records[0].message - ) - - -def test_backends_db_clickhouse_bulk_import_method(clickhouse): - """Test the clickhouse backend bulk_import method.""" - # pylint: disable=unused-argument - - backend = ClickHouseDatabase( - host=CLICKHOUSE_TEST_HOST, - port=CLICKHOUSE_TEST_PORT, - database=CLICKHOUSE_TEST_DATABASE, - event_table_name=CLICKHOUSE_TEST_TABLE_NAME, - ) - - native_statements = [ - {"id": uuid.uuid4(), "timestamp": datetime.utcnow() - timedelta(seconds=1)}, - {"id": uuid.uuid4(), "timestamp": datetime.utcnow()}, - ] - statements = [ - {"id": str(x["id"]), "timestamp": x["timestamp"].isoformat()} - for x in native_statements - ] - - docs = list(ClickHouseDatabase.to_documents(statements)) - backend.bulk_import(docs) - - res = backend.client.query(f"SELECT * FROM {CLICKHOUSE_TEST_TABLE_NAME}") - result = res.named_results() - - db_statement = next(result) - assert db_statement["event_id"] == native_statements[0]["id"] - assert db_statement["emission_time"] == native_statements[0]["timestamp"] - assert db_statement["event"] == statements[0] - - db_statement = next(result) - assert db_statement["event_id"] == native_statements[1]["id"] - assert db_statement["emission_time"] == native_statements[1]["timestamp"] - assert db_statement["event"] == statements[1] - - -def test_backends_db_clickhouse_bulk_import_method_with_duplicated_key( - clickhouse, -): - """Test the clickhouse backend bulk_import method with a duplicated key conflict.""" - backend = get_clickhouse_test_backend() - - timestamp = {"timestamp": "2022-06-27T15:36:50"} - dupe_id = str(uuid.uuid4()) - statements = [ - {"id": str(uuid.uuid4()), **timestamp}, - {"id": dupe_id, **timestamp}, - {"id": dupe_id, **timestamp}, - ] - documents = list(ClickHouseDatabase.to_documents(statements)) - with pytest.raises(BackendException, match="Duplicate IDs found in batch"): - backend.bulk_import(documents) - - success = backend.bulk_import(documents, ignore_errors=True) - assert success == 0 - - -def test_backends_db_clickhouse_bulk_import_method_import_partial_chunks_on_error( - clickhouse, -): - """Test the clickhouse bulk_import method imports partial chunks while raising a - BulkWriteError and ignoring errors. 
- """ - # pylint: disable=unused-argument - - backend = get_clickhouse_test_backend() - - # Identical statement ID produces the same ObjectId, leading to a - # duplicated key write error while trying to bulk import this batch - timestamp = {"timestamp": "2022-06-27T15:36:50"} - dupe_id = str(uuid.uuid4()) - statements = [ - {"id": str(uuid.uuid4()), **timestamp}, - {"id": dupe_id, **timestamp}, - {"id": str(uuid.uuid4()), **timestamp}, - {"id": str(uuid.uuid4()), **timestamp}, - {"id": dupe_id, **timestamp}, - ] - documents = list(ClickHouseDatabase.to_documents(statements)) - assert backend.bulk_import(documents, ignore_errors=True) == 0 - - -def test_backends_db_clickhouse_put_method(clickhouse): - """Test the clickhouse backend put method.""" - sql = f"""SELECT count(*) FROM {CLICKHOUSE_TEST_TABLE_NAME}""" - result = clickhouse.query(sql).result_set - assert result[0][0] == 0 - - native_statements = [ - {"id": uuid.uuid4(), "timestamp": datetime.utcnow() - timedelta(seconds=1)}, - {"id": uuid.uuid4(), "timestamp": datetime.utcnow()}, - ] - statements = [ - {"id": str(x["id"]), "timestamp": x["timestamp"].isoformat()} - for x in native_statements - ] - backend = get_clickhouse_test_backend() - success = backend.put(statements) - - assert success == 2 - - result = clickhouse.query(sql).result_set - assert result[0][0] == 2 - - sql = f"""SELECT * FROM {CLICKHOUSE_TEST_TABLE_NAME} ORDER BY event.timestamp""" - result = list(clickhouse.query(sql).named_results()) - - assert result[0]["event_id"] == native_statements[0]["id"] - assert result[0]["emission_time"] == native_statements[0]["timestamp"] - assert result[0]["event"] == statements[0] - - assert result[1]["event_id"] == native_statements[1]["id"] - assert result[1]["emission_time"] == native_statements[1]["timestamp"] - assert result[1]["event"] == statements[1] - - -def test_backends_db_clickhouse_put_method_with_custom_chunk_size(clickhouse): - """Test the clickhouse backend put method with a custom chunk_size.""" - sql = f"""SELECT count(*) FROM {CLICKHOUSE_TEST_TABLE_NAME}""" - result = clickhouse.query(sql).result_set - assert result[0][0] == 0 - - native_statements = [ - {"id": uuid.uuid4(), "timestamp": datetime.utcnow() - timedelta(seconds=1)}, - {"id": uuid.uuid4(), "timestamp": datetime.utcnow()}, - ] - statements = [ - {"id": str(x["id"]), "timestamp": x["timestamp"].isoformat()} - for x in native_statements - ] - - backend = get_clickhouse_test_backend() - success = backend.put(statements, chunk_size=1) - assert success == 2 - - result = clickhouse.query(sql).result_set - assert result[0][0] == 2 - - sql = f"""SELECT * FROM {CLICKHOUSE_TEST_TABLE_NAME} ORDER BY event.timestamp""" - result = list(clickhouse.query(sql).named_results()) - - assert result[0]["event_id"] == native_statements[0]["id"] - assert result[0]["emission_time"] == native_statements[0]["timestamp"] - assert result[0]["event"] == statements[0] - - assert result[1]["event_id"] == native_statements[1]["id"] - assert result[1]["emission_time"] == native_statements[1]["timestamp"] - assert result[1]["event"] == statements[1] - - -def test_backends_db_clickhouse_query_statements_with_search_query_failure( - monkeypatch, caplog, clickhouse -): - """Test the clickhouse query_statements method, given a search query failure, - should raise a BackendException and log the error. 
- """ - # pylint: disable=unused-argument - - def mock_query(*_, **__): - """Mock the ClickHouseClient.collection.find method.""" - raise ClickHouseError("Something is wrong") - - backend = get_clickhouse_test_backend() - monkeypatch.setattr(backend.client, "query", mock_query) - - caplog.set_level(logging.ERROR) - - msg = "'Failed to execute ClickHouse query', 'Something is wrong'" - with pytest.raises(BackendException, match=msg): - backend.query_statements(RalphStatementsQuery.construct()) - - logger_name = "ralph.backends.database.clickhouse" - msg = "Failed to execute ClickHouse query. Something is wrong" - assert caplog.record_tuples == [(logger_name, logging.ERROR, msg)] - - -def test_backends_db_clickhouse_query_statements_by_ids_with_search_query_failure( - monkeypatch, caplog, clickhouse -): - """Test the clickhouse backend query_statements_by_ids method, given a search query - failure, should raise a BackendException and log the error. - """ - # pylint: disable=unused-argument - - def mock_find(**_): - """Mock the ClickHouseClient.collection.find method.""" - raise ClickHouseError("Something is wrong") - - backend = get_clickhouse_test_backend() - monkeypatch.setattr(backend.client, "query", mock_find) - caplog.set_level(logging.ERROR) - - msg = "'Failed to execute ClickHouse query', 'Something is wrong'" - with pytest.raises(BackendException, match=msg): - backend.query_statements_by_ids( - [ - "abcdefg", - ] - ) - - logger_name = "ralph.backends.database.clickhouse" - msg = "Failed to execute ClickHouse query. Something is wrong" - assert caplog.record_tuples == [(logger_name, logging.ERROR, msg)] - - -def test_backends_db_clickhouse_status(clickhouse): - """Test the ClickHouse status method. - - As pyclickhouse is monkeypatching the ClickHouse client to add admin object, it's - barely untestable. 
😢 - """ - # pylint: disable=unused-argument - - database = get_clickhouse_test_backend() - assert database.status() == DatabaseStatus.OK diff --git a/tests/backends/database/test_es.py b/tests/backends/database/test_es.py deleted file mode 100644 index 02341309e..000000000 --- a/tests/backends/database/test_es.py +++ /dev/null @@ -1,545 +0,0 @@ -"""Tests for Ralph es database backend.""" - -import json -import logging -import random -import sys -from collections.abc import Iterable -from datetime import datetime -from io import StringIO -from pathlib import Path - -import pytest -from elastic_transport import ApiResponseMeta -from elasticsearch import ApiError -from elasticsearch import ConnectionError as ESConnectionError -from elasticsearch import Elasticsearch -from elasticsearch.client import CatClient -from elasticsearch.helpers import bulk - -from ralph.backends.database.base import DatabaseStatus, RalphStatementsQuery -from ralph.backends.database.es import ESDatabase, ESQuery -from ralph.conf import ESClientOptions, settings -from ralph.exceptions import BackendException, BackendParameterException - -from tests.fixtures.backends import ( - ES_TEST_FORWARDING_INDEX, - ES_TEST_HOSTS, - ES_TEST_INDEX, -) - - -def test_backends_database_es_database_instantiation(es): - """Test the ES backend instantiation.""" - # pylint: disable=invalid-name,unused-argument,protected-access - - assert ESDatabase.name == "es" - assert ESDatabase.query_model == ESQuery - - database = ESDatabase( - hosts=ES_TEST_HOSTS, - index=ES_TEST_INDEX, - ) - - # When running locally host is 'elasticsearch', while it's localhost when - # running from the CI - assert any( - ( - "http://elasticsearch:9200" in database._hosts, - "http://localhost:9200" in database._hosts, - ) - ) - assert database.index == ES_TEST_INDEX - assert isinstance(database.client, Elasticsearch) - assert database.op_type == "index" - - for op_type in ("index", "create", "delete", "update"): - database = ESDatabase(hosts=ES_TEST_HOSTS, index=ES_TEST_INDEX, op_type=op_type) - assert database.op_type == op_type - - -def test_backends_database_es_database_instantiation_with_forbidden_op_type(es): - """Test the ES backend instantiation with an op_type that is not allowed.""" - # pylint: disable=invalid-name,unused-argument,protected-access - - with pytest.raises(BackendParameterException): - ESDatabase(hosts=ES_TEST_HOSTS, index=ES_TEST_INDEX, op_type="foo") - - -def test_backends_database_es_client_kwargs(es): - """Test the ES backend client instantiation using client_options that must be - passed to the http(s) connection pool. 
- """ - # pylint: disable=invalid-name,unused-argument,protected-access - - database = ESDatabase( - hosts=[ - "https://elasticsearch:9200", - ], - index=ES_TEST_INDEX, - client_options=ESClientOptions( - ca_certs="/path/to/ca/bundle", verify_certs=True - ), - ) - - assert database.client.transport.node_pool.get().config.ca_certs == Path( - "/path/to/ca/bundle" - ) - - assert database.client.transport.node_pool.get().config.verify_certs is True - - -def test_backends_database_es_to_documents_method(es): - """Test to_documents method.""" - # pylint: disable=invalid-name,unused-argument - - # Create stream data - stream = StringIO("\n".join([json.dumps({"id": idx}) for idx in range(10)])) - stream.seek(0) - - database = ESDatabase( - hosts=ES_TEST_HOSTS, - index=ES_TEST_INDEX, - ) - documents = database.to_documents(stream, lambda item: item.get("id")) - assert isinstance(documents, Iterable) - - documents = list(documents) - assert len(documents) == 10 - assert documents == [ - { - "_index": database.index, - "_id": idx, - "_op_type": "index", - "_source": {"id": idx}, - } - for idx in range(10) - ] - - -def test_backends_database_es_to_documents_method_with_create_op_type(es): - """Test to_documents method using the create op_type.""" - # pylint: disable=invalid-name,unused-argument - - # Create stream data - stream = StringIO("\n".join([json.dumps({"id": idx}) for idx in range(10)])) - stream.seek(0) - - database = ESDatabase(hosts=ES_TEST_HOSTS, index=ES_TEST_INDEX, op_type="create") - documents = database.to_documents(stream, lambda item: item.get("id")) - assert isinstance(documents, Iterable) - - documents = list(documents) - assert len(documents) == 10 - assert documents == [ - { - "_index": database.index, - "_id": idx, - "_op_type": "create", - "_source": {"id": idx}, - } - for idx in range(10) - ] - - -def test_backends_database_es_get_method(es): - """Test ES get method.""" - # pylint: disable=invalid-name - - # Insert documents - bulk( - es, - ( - {"_index": ES_TEST_INDEX, "_id": idx, "_source": {"id": idx}} - for idx in range(10) - ), - ) - # As we bulk insert documents, the index needs to be refreshed before making - # queries. - es.indices.refresh(index=ES_TEST_INDEX) - - database = ESDatabase( - hosts=ES_TEST_HOSTS, - index=ES_TEST_INDEX, - ) - - expected = [{"id": idx} for idx in range(10)] - assert list(map(lambda x: x.get("_source"), database.get())) == expected - - -def test_backends_database_es_get_method_with_a_custom_query(es): - """Test ES get method with a custom query.""" - # pylint: disable=invalid-name - - # Insert documents - bulk( - es, - ( - { - "_index": ES_TEST_INDEX, - "_id": idx, - "_source": {"id": idx, "modulo": idx % 2}, - } - for idx in range(10) - ), - ) - # As we bulk insert documents, the index needs to be refreshed before making - # queries. 
- es.indices.refresh(index=ES_TEST_INDEX) - - database = ESDatabase( - hosts=ES_TEST_HOSTS, - index=ES_TEST_INDEX, - ) - - # Find every even item - query = ESQuery(query={"query": {"term": {"modulo": 0}}}) - results = list(database.get(query=query)) - assert len(results) == 5 - assert results[0]["_source"]["id"] == 0 - assert results[1]["_source"]["id"] == 2 - assert results[2]["_source"]["id"] == 4 - assert results[3]["_source"]["id"] == 6 - assert results[4]["_source"]["id"] == 8 - - # Check query argument type - with pytest.raises( - BackendParameterException, - match="'query' argument is expected to be a ESQuery instance.", - ): - list(database.get(query="foo")) - - -def test_backends_database_es_put_method(es, fs, monkeypatch): - """Test ES put method.""" - # pylint: disable=invalid-name - - # Prepare fake file system - fs.create_dir(str(settings.APP_DIR)) - # Force Path instantiation with fake FS - history_file = Path(settings.HISTORY_FILE) - assert not history_file.exists() - - monkeypatch.setattr( - "sys.stdin", StringIO("\n".join([json.dumps({"id": idx}) for idx in range(10)])) - ) - - assert len(es.search(index=ES_TEST_INDEX)["hits"]["hits"]) == 0 - - database = ESDatabase( - hosts=ES_TEST_HOSTS, - index=ES_TEST_INDEX, - ) - success_count = database.put(sys.stdin, chunk_size=5) - - # As we bulk insert documents, the index needs to be refreshed before making - # queries. - es.indices.refresh(index=ES_TEST_INDEX) - - hits = es.search(index=ES_TEST_INDEX)["hits"]["hits"] - assert len(hits) == 10 - assert success_count == 10 - assert sorted([hit["_source"]["id"] for hit in hits]) == list(range(10)) - - -def test_backends_database_es_put_method_with_update_op_type(es, fs, monkeypatch): - """Test ES put method using the update op_type.""" - # pylint: disable=invalid-name - - # Prepare fake file system - fs.create_dir(settings.APP_DIR) - # Force Path instantiation with fake FS - history_file = Path(settings.HISTORY_FILE) - assert not history_file.exists() - - monkeypatch.setattr( - "sys.stdin", - StringIO( - "\n".join([json.dumps({"id": idx, "value": str(idx)}) for idx in range(10)]) - ), - ) - - assert len(es.search(index=ES_TEST_INDEX)["hits"]["hits"]) == 0 - - database = ESDatabase(hosts=ES_TEST_HOSTS, index=ES_TEST_INDEX) - database.put(sys.stdin, chunk_size=5) - - # As we bulk insert documents, the index needs to be refreshed before making - # queries. - es.indices.refresh(index=ES_TEST_INDEX) - - hits = es.search(index=ES_TEST_INDEX)["hits"]["hits"] - assert len(hits) == 10 - assert sorted([hit["_source"]["id"] for hit in hits]) == list(range(10)) - assert sorted([hit["_source"]["value"] for hit in hits]) == list( - map(str, range(10)) - ) - - monkeypatch.setattr( - "sys.stdin", - StringIO( - "\n".join( - [json.dumps({"id": idx, "value": str(10 + idx)}) for idx in range(10)] - ) - ), - ) - - database = ESDatabase(hosts=ES_TEST_HOSTS, index=ES_TEST_INDEX, op_type="update") - success_count = database.put(sys.stdin, chunk_size=5) - - # As we bulk insert documents, the index needs to be refreshed before making - # queries. 
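-    # (With op_type="update", this second `put` overwrites the `value` field
-    # of the ten existing documents: "0".."9" become "10".."19", as asserted
-    # below.)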
- es.indices.refresh(index=ES_TEST_INDEX) - - hits = es.search(index=ES_TEST_INDEX)["hits"]["hits"] - assert len(hits) == 10 - assert success_count == 10 - assert sorted([hit["_source"]["id"] for hit in hits]) == list(range(10)) - assert sorted([hit["_source"]["value"] for hit in hits]) == list( - map(lambda x: str(x + 10), range(10)) - ) - - -def test_backends_database_es_put_with_badly_formatted_data_raises_a_backend_exception( - es, fs, monkeypatch -): - """Test ES put method with badly formatted data.""" - # pylint: disable=invalid-name,unused-argument - - records = [{"id": idx, "count": random.randint(0, 100)} for idx in range(10)] - # Patch a record with a non-expected type for the count field (should be - # assigned as long) - records[4].update({"count": "wrong"}) - - monkeypatch.setattr( - "sys.stdin", StringIO("\n".join([json.dumps(record) for record in records])) - ) - - assert len(es.search(index=ES_TEST_INDEX)["hits"]["hits"]) == 0 - - database = ESDatabase( - hosts=ES_TEST_HOSTS, - index=ES_TEST_INDEX, - ) - - # By default, we should raise an error and stop the importation - msg = "\\('1 document\\(s\\) failed to index.', '5 succeeded writes'\\)" - with pytest.raises(BackendException, match=msg) as exception_info: - database.put(sys.stdin, chunk_size=2) - es.indices.refresh(index=ES_TEST_INDEX) - hits = es.search(index=ES_TEST_INDEX)["hits"]["hits"] - assert len(hits) == 5 - assert exception_info.value.args[-1] == "5 succeeded writes" - assert sorted([hit["_source"]["id"] for hit in hits]) == [0, 1, 2, 3, 5] - - -def test_backends_database_es_put_with_badly_formatted_data_in_force_mode( - es, fs, monkeypatch -): - """Test ES put method with badly formatted data when the force mode is active.""" - # pylint: disable=invalid-name,unused-argument - - records = [{"id": idx, "count": random.randint(0, 100)} for idx in range(10)] - # Patch a record with a non-expected type for the count field (should be - # assigned as long) - records[2].update({"count": "wrong"}) - - monkeypatch.setattr( - "sys.stdin", StringIO("\n".join([json.dumps(record) for record in records])) - ) - - assert len(es.search(index=ES_TEST_INDEX)["hits"]["hits"]) == 0 - - database = ESDatabase( - hosts=ES_TEST_HOSTS, - index=ES_TEST_INDEX, - ) - # When forcing import, We expect the record with non-expected type to have - # been dropped - database.put(sys.stdin, chunk_size=5, ignore_errors=True) - es.indices.refresh(index=ES_TEST_INDEX) - hits = es.search(index=ES_TEST_INDEX)["hits"]["hits"] - assert len(hits) == 9 - assert sorted([hit["_source"]["id"] for hit in hits]) == [ - i for i in range(10) if i != 2 - ] - - -def test_backends_database_es_put_with_datastream(es_data_stream, fs, monkeypatch): - """Test ES put method when using a configured data stream.""" - # pylint: disable=invalid-name,unused-argument - - monkeypatch.setattr( - "sys.stdin", - StringIO( - "\n".join( - [ - json.dumps({"id": idx, "@timestamp": datetime.now().isoformat()}) - for idx in range(10) - ] - ) - ), - ) - - assert len(es_data_stream.search(index=ES_TEST_INDEX)["hits"]["hits"]) == 0 - - database = ESDatabase(hosts=ES_TEST_HOSTS, index=ES_TEST_INDEX, op_type="create") - database.put(sys.stdin, chunk_size=5) - - # As we bulk insert documents, the index needs to be refreshed before making - # queries. 
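-    # (Elasticsearch data streams only accept the "create" operation and
-    # require an @timestamp field on every document, hence op_type="create"
-    # and the @timestamp key injected above.)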
- es_data_stream.indices.refresh(index=ES_TEST_INDEX) - - hits = es_data_stream.search(index=ES_TEST_INDEX)["hits"]["hits"] - assert len(hits) == 10 - assert sorted([hit["_source"]["id"] for hit in hits]) == list(range(10)) - - -def test_backends_database_es_query_statements_with_pit_query_failure( - monkeypatch, caplog, es -): - """Test the ES query_statements method, given a point in time query failure, should - raise a BackendException and log the error. - """ - # pylint: disable=invalid-name,unused-argument - - def mock_open_point_in_time(**_): - """Mock the Elasticsearch.open_point_in_time method.""" - raise ValueError("ES failure") - - database = ESDatabase(hosts=ES_TEST_HOSTS, index=ES_TEST_INDEX) - monkeypatch.setattr(database.client, "open_point_in_time", mock_open_point_in_time) - - caplog.set_level(logging.ERROR) - - msg = "'Failed to open ElasticSearch point in time', 'ES failure'" - with pytest.raises(BackendException, match=msg): - database.query_statements(RalphStatementsQuery.construct()) - - logger_name = "ralph.backends.database.es" - msg = "Failed to open ElasticSearch point in time. ES failure" - assert caplog.record_tuples == [(logger_name, logging.ERROR, msg)] - - -def test_backends_database_es_query_statements_with_search_query_failure( - monkeypatch, caplog, es -): - """Test the ES query_statements method, given a search query failure, should - raise a BackendException and log the error. - """ - # pylint: disable=invalid-name,unused-argument - - def mock_search(**_): - """Mock the Elasticsearch.search method.""" - raise ApiError("Something is wrong", ApiResponseMeta(*([None] * 5)), None) - - database = ESDatabase(hosts=ES_TEST_HOSTS, index=ES_TEST_INDEX) - monkeypatch.setattr(database.client, "search", mock_search) - - caplog.set_level(logging.ERROR) - - msg = "'Failed to execute ElasticSearch query', 'Something is wrong'" - with pytest.raises(BackendException, match=msg): - database.query_statements(RalphStatementsQuery.construct()) - - logger_name = "ralph.backends.database.es" - msg = "Failed to execute ElasticSearch query. ApiError(None, 'Something is wrong')" - assert caplog.record_tuples == [(logger_name, logging.ERROR, msg)] - - -def test_backends_database_es_query_statements_by_ids_with_search_query_failure( - monkeypatch, caplog, es -): - """Test the ES query_statements_by_ids method, given a search query failure, should - raise a BackendException and log the error. - """ - # pylint: disable=invalid-name,unused-argument - - def mock_search(**_): - """Mock the Elasticsearch.search method.""" - raise ApiError("Something is wrong", ApiResponseMeta(*([None] * 5)), None) - - database = ESDatabase(hosts=ES_TEST_HOSTS, index=ES_TEST_INDEX) - monkeypatch.setattr(database.client, "search", mock_search) - - caplog.set_level(logging.ERROR) - - msg = "'Failed to execute ElasticSearch query', 'Something is wrong'" - with pytest.raises(BackendException, match=msg): - database.query_statements_by_ids(RalphStatementsQuery()) - - logger_name = "ralph.backends.database.es" - msg = "Failed to execute ElasticSearch query. ApiError(None, 'Something is wrong')" - assert caplog.record_tuples == [(logger_name, logging.ERROR, msg)] - - -def test_backends_database_es_query_statements_by_ids_with_multiple_indexes( - es, es_forwarding -): - """Test the ES query_statements_by_ids method, given a valid search - query, should execute the query uniquely on the specified index and return the - expected results. 
- """ - # pylint: disable=invalid-name,use-implicit-booleaness-not-comparison - - # Insert documents - index_1_document = {"_index": ES_TEST_INDEX, "_id": "1", "_source": {"id": "1"}} - index_2_document = { - "_index": ES_TEST_FORWARDING_INDEX, - "_id": "2", - "_source": {"id": "2"}, - } - bulk(es, [index_1_document]) - bulk(es_forwarding, [index_2_document]) - - # As we bulk insert documents, the index needs to be refreshed before making queries - es.indices.refresh(index=ES_TEST_INDEX) - es_forwarding.indices.refresh(index=ES_TEST_FORWARDING_INDEX) - - # Instantiate ES Databases - database = ESDatabase(hosts=ES_TEST_HOSTS, index=ES_TEST_INDEX) - database_2 = ESDatabase(hosts=ES_TEST_HOSTS, index=ES_TEST_FORWARDING_INDEX) - - # Check the expected search query results - index_1_document = dict(index_1_document, **{"_score": 1.0}) - index_2_document = dict(index_2_document, **{"_score": 1.0}) - assert database.query_statements_by_ids(["1"]) == [index_1_document] - assert database.query_statements_by_ids(["2"]) == [] - assert database_2.query_statements_by_ids(["1"]) == [] - assert database_2.query_statements_by_ids(["2"]) == [index_2_document] - - -def test_backends_database_es_status(es, monkeypatch): - """Test the ES status method.""" - # pylint: disable=invalid-name,unused-argument - - database = ESDatabase(hosts=ES_TEST_HOSTS, index=ES_TEST_INDEX) - - with monkeypatch.context() as mkpch: - mkpch.setattr( - CatClient, - "health", - lambda client: ( - "1664532320 10:05:20 docker-cluster green 1 1 2 2 0 0 1 0 - 66.7%" - ), - ) - assert database.status() == DatabaseStatus.OK - - with monkeypatch.context() as mkpch: - mkpch.setattr( - CatClient, - "health", - lambda client: ( - "1664532320 10:05:20 docker-cluster yellow 1 1 2 2 0 0 1 0 - 66.7%" - ), - ) - assert database.status() == DatabaseStatus.ERROR - - with monkeypatch.context() as mkpch: - - def mock_connection_error(*args, **kwargs): - """ES client info mock that raises a connection error.""" - raise ESConnectionError("Mocked connection error") - - mkpch.setattr(Elasticsearch, "info", mock_connection_error) - assert database.status() == DatabaseStatus.AWAY diff --git a/tests/backends/database/test_mongo.py b/tests/backends/database/test_mongo.py deleted file mode 100644 index 85e95e4c8..000000000 --- a/tests/backends/database/test_mongo.py +++ /dev/null @@ -1,502 +0,0 @@ -"""Tests for Ralph mongo database backend.""" - -import logging -from datetime import datetime - -import pytest -from bson.objectid import ObjectId -from pymongo import MongoClient -from pymongo.errors import PyMongoError - -from ralph.backends.database.base import DatabaseStatus, RalphStatementsQuery -from ralph.backends.database.mongo import MongoDatabase, MongoQuery -from ralph.exceptions import ( - BackendException, - BackendParameterException, - BadFormatException, -) - -from tests.fixtures.backends import ( - MONGO_TEST_COLLECTION, - MONGO_TEST_CONNECTION_URI, - MONGO_TEST_DATABASE, - MONGO_TEST_FORWARDING_COLLECTION, -) - - -def test_backends_database_mongo_database_instantiation(): - """Test the Mongo backend instantiation.""" - assert MongoDatabase.name == "mongo" - - backend = MongoDatabase( - connection_uri=MONGO_TEST_CONNECTION_URI, - database=MONGO_TEST_DATABASE, - collection=MONGO_TEST_COLLECTION, - ) - - assert isinstance(backend.client, MongoClient) - assert hasattr(backend.client, MONGO_TEST_DATABASE) - database = getattr(backend.client, MONGO_TEST_DATABASE) - assert hasattr(database, MONGO_TEST_COLLECTION) - - -def 
test_backends_database_mongo_get_method(mongo): - """Test the mongo backend get method.""" - # Create records - timestamp = {"timestamp": "2022-06-27T15:36:50"} - documents = MongoDatabase.to_documents( - [ - {"id": "foo", **timestamp}, - {"id": "bar", **timestamp}, - ] - ) - database = getattr(mongo, MONGO_TEST_DATABASE) - collection = getattr(database, MONGO_TEST_COLLECTION) - collection.insert_many(documents) - - # Get backend - backend = MongoDatabase( - connection_uri=MONGO_TEST_CONNECTION_URI, - database=MONGO_TEST_DATABASE, - collection=MONGO_TEST_COLLECTION, - ) - expected = [ - {"_id": "62b9ce922c26b46b68ffc68f", "_source": {"id": "foo", **timestamp}}, - {"_id": "62b9ce92fcde2b2edba56bf4", "_source": {"id": "bar", **timestamp}}, - ] - assert list(backend.get()) == expected - assert list(backend.get(chunk_size=1)) == expected - assert list(backend.get(chunk_size=1000)) == expected - - -def test_backends_database_mongo_get_method_with_a_custom_query(mongo): - """Test the mongo backend get method with a custom query.""" - # Create records - timestamp = {"timestamp": datetime.now().isoformat()} - documents = MongoDatabase.to_documents( - [ - {"id": "foo", "bool": 1, **timestamp}, - {"id": "bar", "bool": 0, **timestamp}, - {"id": "lol", "bool": 1, **timestamp}, - ] - ) - database = getattr(mongo, MONGO_TEST_DATABASE) - collection = getattr(database, MONGO_TEST_COLLECTION) - collection.insert_many(documents) - - # Get backend - backend = MongoDatabase( - connection_uri=MONGO_TEST_CONNECTION_URI, - database=MONGO_TEST_DATABASE, - collection=MONGO_TEST_COLLECTION, - ) - - # Test filtering - query = MongoQuery(filter={"_source.bool": {"$eq": 1}}) - results = list(backend.get(query=query)) - assert len(results) == 2 - assert results[0]["_source"]["id"] == "foo" - assert results[1]["_source"]["id"] == "lol" - - # Test projection - query = MongoQuery(projection={"_source.bool": 1}) - results = list(backend.get(query=query)) - assert len(results) == 3 - assert list(results[0]["_source"].keys()) == ["bool"] - assert list(results[1]["_source"].keys()) == ["bool"] - assert list(results[2]["_source"].keys()) == ["bool"] - - # Test filtering and projection - query = MongoQuery( - filter={"_source.bool": {"$eq": 0}}, projection={"_source.id": 1} - ) - results = list(backend.get(query=query)) - assert len(results) == 1 - assert results[0]["_source"]["id"] == "bar" - assert list(results[0]["_source"].keys()) == ["id"] - - # Check query argument type - with pytest.raises( - BackendParameterException, - match="'query' argument is expected to be a MongoQuery instance.", - ): - list(backend.get(query="foo")) - - -def test_backends_database_mongo_to_documents_method(): - """Test the mongo backend to_documents method.""" - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [ - {"id": "foo", **timestamp}, - {"id": "bar", **timestamp}, - {"id": "bar", **timestamp}, - ] - documents = MongoDatabase.to_documents(statements) - - assert next(documents) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", **timestamp}, - } - assert next(documents) == { - "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), - "_source": {"id": "bar", **timestamp}, - } - # Identical statement ID produces the same ObjectId - assert next(documents) == { - "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), - "_source": {"id": "bar", **timestamp}, - } - - -def test_backends_database_mongo_to_documents_method_when_statement_has_no_id(caplog): - """Test the mongo backend to_documents method when a statement has no 
id field.""" - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [{"id": "foo", **timestamp}, timestamp, {"id": "bar", **timestamp}] - - documents = MongoDatabase.to_documents(statements, ignore_errors=False) - assert next(documents) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", **timestamp}, - } - with pytest.raises( - BadFormatException, match=f"statement {timestamp} has no 'id' field" - ): - next(documents) - - documents = MongoDatabase.to_documents(statements, ignore_errors=True) - assert next(documents) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", **timestamp}, - } - assert next(documents) == { - "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), - "_source": {"id": "bar", **timestamp}, - } - assert len(caplog.records) == 1 - assert caplog.records[0].levelname == "WARNING" - assert caplog.records[0].message == f"statement {timestamp} has no 'id' field" - - -def test_backends_database_mongo_to_documents_method_when_statement_has_no_timestamp( - caplog, -): - """Test the mongo backend to_documents method when a statement has no timestamp.""" - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [{"id": "foo", **timestamp}, {"id": "bar"}, {"id": "baz", **timestamp}] - - documents = MongoDatabase.to_documents(statements, ignore_errors=False) - assert next(documents) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", **timestamp}, - } - - with pytest.raises( - BadFormatException, match="statement {'id': 'bar'} has no 'timestamp' field" - ): - next(documents) - - documents = MongoDatabase.to_documents(statements, ignore_errors=True) - assert next(documents) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", **timestamp}, - } - assert next(documents) == { - "_id": ObjectId("62b9ce92baa5a0964d3320fb"), - "_source": {"id": "baz", **timestamp}, - } - assert len(caplog.records) == 1 - assert caplog.records[0].levelname == "WARNING" - assert caplog.records[0].message == ( - "statement {'id': 'bar'} has no 'timestamp' field" - ) - - -def test_backends_database_mongo_to_documents_method_with_invalid_timestamp(caplog): - """Test the mongo backend to_documents method given a statement with an invalid - timestamp. 
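-
-    For context, `to_documents` derives each document `_id` deterministically
-    from the statement, which is why identical statement IDs map to identical
-    ObjectIds in the expected values above. A minimal sketch of a derivation of
-    this kind (illustrative only, assuming an MD5-based scheme; not necessarily
-    Ralph's exact implementation):
-
-        import hashlib
-
-        from bson.objectid import ObjectId
-
-        def statement_object_id(statement: dict) -> ObjectId:
-            # Hash the statement id into a stable 24-hex-digit ObjectId.
-            digest = hashlib.md5(statement["id"].encode()).hexdigest()
-            return ObjectId(digest[:24])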
- """ - valid_timestamp = {"timestamp": "2022-06-27T15:36:50"} - invalid_timestamp = {"timestamp": "This is not a valid timestamp!"} - invalid_statement = {"id": "bar", **invalid_timestamp} - statements = [ - {"id": "foo", **valid_timestamp}, - invalid_statement, - {"id": "baz", **valid_timestamp}, - ] - - documents = MongoDatabase.to_documents(statements, ignore_errors=False) - assert next(documents) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", **valid_timestamp}, - } - - with pytest.raises( - BadFormatException, - match=f"statement {invalid_statement} has an invalid 'timestamp' field", - ): - next(documents) - - documents = MongoDatabase.to_documents(statements, ignore_errors=True) - assert next(documents) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", **valid_timestamp}, - } - assert next(documents) == { - "_id": ObjectId("62b9ce92baa5a0964d3320fb"), - "_source": {"id": "baz", **valid_timestamp}, - } - assert len(caplog.records) == 1 - assert caplog.records[0].levelname == "WARNING" - assert caplog.records[0].message == ( - f"statement {invalid_statement} has an invalid 'timestamp' field" - ) - - -def test_backends_database_mongo_bulk_import_method(mongo): - """Test the mongo backend bulk_import method.""" - # pylint: disable=unused-argument - - backend = MongoDatabase( - connection_uri=MONGO_TEST_CONNECTION_URI, - database=MONGO_TEST_DATABASE, - collection=MONGO_TEST_COLLECTION, - ) - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}] - backend.bulk_import(MongoDatabase.to_documents(statements)) - - results = backend.collection.find() - assert next(results) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", **timestamp}, - } - assert next(results) == { - "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), - "_source": {"id": "bar", **timestamp}, - } - - -def test_backends_database_mongo_bulk_import_method_with_duplicated_key(mongo): - """Test the mongo backend bulk_import method with a duplicated key conflict.""" - # pylint: disable=unused-argument - - backend = MongoDatabase( - connection_uri=MONGO_TEST_CONNECTION_URI, - database=MONGO_TEST_DATABASE, - collection=MONGO_TEST_COLLECTION, - ) - - # Identical statement ID produces the same ObjectId, leading to a - # duplicated key write error while trying to bulk import this batch - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [ - {"id": "foo", **timestamp}, - {"id": "bar", **timestamp}, - {"id": "bar", **timestamp}, - ] - documents = list(MongoDatabase.to_documents(statements)) - with pytest.raises(BackendException, match="E11000 duplicate key error collection"): - backend.bulk_import(documents) - - success = backend.bulk_import(documents, ignore_errors=True) - assert success == 0 - - -def test_backends_database_mongo_bulk_import_method_import_partial_chunks_on_error( - mongo, -): - """Test the mongo backend bulk_import method imports partial chunks while raising a - BulkWriteError and ignoring errors. 
- """ - # pylint: disable=unused-argument - - backend = MongoDatabase( - connection_uri=MONGO_TEST_CONNECTION_URI, - database=MONGO_TEST_DATABASE, - collection=MONGO_TEST_COLLECTION, - ) - - # Identical statement ID produces the same ObjectId, leading to a - # duplicated key write error while trying to bulk import this batch - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [ - {"id": "foo", **timestamp}, - {"id": "bar", **timestamp}, - {"id": "baz", **timestamp}, - {"id": "bar", **timestamp}, - {"id": "lol", **timestamp}, - ] - documents = list(MongoDatabase.to_documents(statements)) - assert backend.bulk_import(documents, ignore_errors=True) == 3 - - -def test_backends_database_mongo_put_method(mongo): - """Test the mongo backend put method.""" - database = getattr(mongo, MONGO_TEST_DATABASE) - collection = getattr(database, MONGO_TEST_COLLECTION) - assert collection.estimated_document_count() == 0 - - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}] - backend = MongoDatabase( - connection_uri=MONGO_TEST_CONNECTION_URI, - database=MONGO_TEST_DATABASE, - collection=MONGO_TEST_COLLECTION, - ) - - success = backend.put(statements) - assert success == 2 - assert collection.estimated_document_count() == 2 - - results = collection.find() - assert next(results) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", **timestamp}, - } - assert next(results) == { - "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), - "_source": {"id": "bar", **timestamp}, - } - - -def test_backends_database_mongo_put_method_with_custom_chunk_size(mongo): - """Test the mongo backend put method with a custom chunk_size.""" - database = getattr(mongo, MONGO_TEST_DATABASE) - collection = getattr(database, MONGO_TEST_COLLECTION) - assert collection.estimated_document_count() == 0 - - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statements = [{"id": "foo", **timestamp}, {"id": "bar", **timestamp}] - backend = MongoDatabase( - connection_uri=MONGO_TEST_CONNECTION_URI, - database=MONGO_TEST_DATABASE, - collection=MONGO_TEST_COLLECTION, - ) - - success = backend.put(statements, chunk_size=1) - assert success == 2 - assert collection.estimated_document_count() == 2 - - results = collection.find() - assert next(results) == { - "_id": ObjectId("62b9ce922c26b46b68ffc68f"), - "_source": {"id": "foo", **timestamp}, - } - assert next(results) == { - "_id": ObjectId("62b9ce92fcde2b2edba56bf4"), - "_source": {"id": "bar", **timestamp}, - } - - -def test_backends_database_mongo_query_statements_with_search_query_failure( - monkeypatch, caplog, mongo -): - """Test the mongo backend query_statements method, given a search query failure, - should raise a BackendException and log the error. - """ - # pylint: disable=unused-argument - - def mock_find(**_): - """Mock the MongoClient.collection.find method.""" - raise PyMongoError("Something is wrong") - - backend = MongoDatabase( - connection_uri=MONGO_TEST_CONNECTION_URI, - database=MONGO_TEST_DATABASE, - collection=MONGO_TEST_COLLECTION, - ) - monkeypatch.setattr(backend.collection, "find", mock_find) - - caplog.set_level(logging.ERROR) - - msg = "'Failed to execute MongoDB query', 'Something is wrong'" - with pytest.raises(BackendException, match=msg): - backend.query_statements(RalphStatementsQuery.construct()) - - logger_name = "ralph.backends.database.mongo" - msg = "Failed to execute MongoDB query. 
Something is wrong" - assert caplog.record_tuples == [(logger_name, logging.ERROR, msg)] - - -def test_backends_database_mongo_query_statements_by_ids_with_search_query_failure( - monkeypatch, caplog, mongo -): - """Test the mongo backend query_statements_by_ids method, given a search query - failure, should raise a BackendException and log the error. - """ - # pylint: disable=unused-argument - - def mock_find(**_): - """Mock the MongoClient.collection.find method.""" - raise ValueError("Something is wrong") - - backend = MongoDatabase( - connection_uri=MONGO_TEST_CONNECTION_URI, - database=MONGO_TEST_DATABASE, - collection=MONGO_TEST_COLLECTION, - ) - monkeypatch.setattr(backend.collection, "find", mock_find) - caplog.set_level(logging.ERROR) - - msg = "'Failed to execute MongoDB query', 'Something is wrong'" - with pytest.raises(BackendException, match=msg): - backend.query_statements_by_ids(RalphStatementsQuery()) - - logger_name = "ralph.backends.database.mongo" - msg = "Failed to execute MongoDB query. Something is wrong" - assert caplog.record_tuples == [(logger_name, logging.ERROR, msg)] - - -def test_backends_database_mongo_query_statements_by_ids_with_multiple_collections( - mongo, mongo_forwarding -): - """Test the mongo backend query_statements_by_ids method, given a valid search - query, should execute the query uniquely on the specified collection and return the - expected results. - """ - # pylint: disable=unused-argument,use-implicit-booleaness-not-comparison - - # Instantiate Mongo Databases - backend_1 = MongoDatabase( - connection_uri=MONGO_TEST_CONNECTION_URI, - database=MONGO_TEST_DATABASE, - collection=MONGO_TEST_COLLECTION, - ) - backend_2 = MongoDatabase( - connection_uri=MONGO_TEST_CONNECTION_URI, - database=MONGO_TEST_DATABASE, - collection=MONGO_TEST_FORWARDING_COLLECTION, - ) - - # Insert documents - timestamp = {"timestamp": "2022-06-27T15:36:50"} - statement_1 = {"id": "1", **timestamp} - statement_1_expected = [{"_id": "1", "_source": statement_1}] - statement_2 = {"id": "2", **timestamp} - statement_2_expected = [{"_id": "2", "_source": statement_2}] - collection_1_document = list(MongoDatabase.to_documents([statement_1])) - collection_2_document = list(MongoDatabase.to_documents([statement_2])) - backend_1.bulk_import(collection_1_document) - backend_2.bulk_import(collection_2_document) - - # Check the expected search query results - assert backend_1.query_statements_by_ids(["1"]) == statement_1_expected - assert backend_2.query_statements_by_ids(["1"]) == [] - assert backend_2.query_statements_by_ids(["2"]) == statement_2_expected - assert backend_1.query_statements_by_ids(["2"]) == [] - - -def test_backends_database_mongo_status(mongo): - """Test the Mongo status method. - - As pymongo is monkeypatching the MongoDB client to add admin object, it's - barely untestable. 
😢 - """ - # pylint: disable=unused-argument - - database = MongoDatabase( - connection_uri=MONGO_TEST_CONNECTION_URI, - database=MONGO_TEST_DATABASE, - collection=MONGO_TEST_COLLECTION, - ) - assert database.status() == DatabaseStatus.OK diff --git a/tests/backends/lrs/test_clickhouse.py b/tests/backends/lrs/test_clickhouse.py index 7b44246ac..e08dfa9df 100644 --- a/tests/backends/lrs/test_clickhouse.py +++ b/tests/backends/lrs/test_clickhouse.py @@ -104,13 +104,13 @@ { "where": [ "event_id = {statementId:UUID}", - "event.actor.account_name = {actor__account_name:String}", - "event.actor.account_homepage = {actor__account_homepage:String}", + "event.actor.account.name = {actor__account__name:String}", + "event.actor.account.homePage = {actor__account_home_page:String}", ], "params": { "statementId": "test_id", - "actor__account_name": "13936749", - "actor__account_homepage": "http://www.example.com", + "actor__account__name": "13936749", + "actor__account_home_page": "http://www.example.com", "ascending": True, "format": "exact", }, diff --git a/tests/backends/storage/__init__.py b/tests/backends/storage/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/backends/storage/test_base.py b/tests/backends/storage/test_base.py deleted file mode 100644 index 3235ecaf5..000000000 --- a/tests/backends/storage/test_base.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Tests for Ralph base storage backend.""" - -from ralph.backends.storage.base import BaseStorage - - -def test_backends_storage_base_abstract_interface_with_implemented_abstract_method(): - """Test the interface mechanism with properly implemented abstract methods.""" - - class GoodStorage(BaseStorage): - """Correct implementation with required abstract methods.""" - - name = "good" - - def list(self, details=False, new=False): - """Fake the list method.""" - - def url(self, name): - """Fake the url method.""" - - def read(self, name, chunk_size=0): - """Fake the read method.""" - - def write(self, stream, name, overwrite=False): - """Fake the write method.""" - - GoodStorage() - - assert GoodStorage.name == "good" diff --git a/tests/backends/storage/test_fs.py b/tests/backends/storage/test_fs.py deleted file mode 100644 index d61d1ae8f..000000000 --- a/tests/backends/storage/test_fs.py +++ /dev/null @@ -1,110 +0,0 @@ -"""Tests for Ralph fs storage backend.""" - -from collections.abc import Iterable -from pathlib import Path - -import pytest - -from ralph.backends.storage.fs import FSStorage -from ralph.conf import settings - - -# pylint: disable=invalid-name -# pylint: disable=unused-argument -def test_backends_storage_fs_storage_instantiation(fs): - """Test the FSStorage backend instantiation.""" - # pylint: disable=protected-access - - assert FSStorage.name == "fs" - - storage = FSStorage() - - assert str(storage._path) == settings.BACKENDS.STORAGE.FS.PATH - - deep_path = "deep/directories/path" - - storage = FSStorage(deep_path) - - assert storage._path == Path(deep_path) - assert storage._path.is_dir() - - # Check that a storage with the same path doesn't throw an exception - FSStorage(deep_path) - - -# pylint: disable=invalid-name -# pylint: disable=unused-argument -def test_backends_storage_fs_getfile(fs): - """Test that an existing path can be returned, and throws an exception - otherwise. 
- """ - # pylint: disable=protected-access - - path = "test_fs/" - filename = "some_file" - storage = FSStorage(path) - - storage._get_filepath(filename) - with pytest.raises(FileNotFoundError): - storage._get_filepath(filename, strict=True) - storage._get_filepath(filename, strict=False) - - fs.create_file(Path(path, filename)) - - assert storage._get_filepath(filename, strict=True) == Path(path, filename) - - -# pylint: disable=invalid-name -# pylint: disable=unused-argument -def test_backends_storage_fs_url(fs): - """Test that the full URL of the file can be returned.""" - path = "test_fs/" - filename = "some_file" - storage = FSStorage(path) - - fs.create_file(Path(path, filename)) - - assert storage.url(filename) == "/test_fs/some_file" - - -# pylint: disable=invalid-name -# pylint: disable=unused-argument -def test_backends_storage_fs_list(fs, settings_fs): - """Test archives listing in FSStorage.""" - fs.create_dir(settings.APP_DIR) - - path = "test_fs/" - filename1 = "file1" - filename2 = "file2" - storage = FSStorage(path) - - fs.create_file(path + filename1, contents="content") - fs.create_file(path + filename2, contents="some more content") - - assert isinstance(storage.list(), Iterable) - assert isinstance(storage.list(new=True), Iterable) - assert isinstance(storage.list(details=True), Iterable) - - simple_list = list(storage.list()) - assert filename1 in simple_list - assert filename2 in simple_list - assert len(simple_list) == 2 - - # Fetch it so it's not new anymore - list(storage.read(filename1)) - - new_list = list(storage.list(new=True)) - assert filename1 not in new_list - assert filename2 in new_list - assert len(new_list) == 1 - - detail_list = list(storage.list(details=True)) - assert any( - (archive["filename"] == filename1 and archive["size"] == 7) - for archive in detail_list - ) - assert any( - (archive["filename"] == filename2 and archive["size"] == 17) - for archive in detail_list - ) - assert len(simple_list) == 2 diff --git a/tests/backends/storage/test_ldp.py b/tests/backends/storage/test_ldp.py deleted file mode 100644 index 6bff7cf29..000000000 --- a/tests/backends/storage/test_ldp.py +++ /dev/null @@ -1,459 +0,0 @@ -"""Tests for Ralph ldp storage backend.""" - -import datetime -import gzip -import json -import os.path -import uuid -from collections.abc import Iterable -from pathlib import Path, PurePath -from urllib.parse import urlparse -from xmlrpc.client import gzip_decode - -import ovh -import pytest -import requests - -from ralph.backends.storage.ldp import LDPStorage -from ralph.conf import settings -from ralph.exceptions import BackendParameterException - - -def test_backends_storage_ldp_storage_instantiation(): - """Test the LDPStorage backend instantiation.""" - # pylint: disable=protected-access - - assert LDPStorage.name == "ldp" - - storage = LDPStorage( - endpoint="ovh-eu", - application_key="fake_key", - application_secret="fake_secret", - consumer_key="another_fake_key", - ) - - assert storage._endpoint == "ovh-eu" - assert storage._application_key == "fake_key" - assert storage._application_secret == "fake_secret" - assert storage._consumer_key == "another_fake_key" - assert storage.service_name is None - assert storage.stream_id is None - assert isinstance(storage.client, ovh.Client) - - -def test_backends_storage_ldp_archive_endpoint_property(): - """Test the LDPStorage _archive_endpoint property.""" - # pylint: disable=protected-access, pointless-statement - - storage = LDPStorage( - endpoint="ovh-eu", - application_key="fake_key", - 
application_secret="fake_secret", - consumer_key="another_fake_key", - service_name="foo", - stream_id="bar", - ) - assert ( - storage._archive_endpoint == "/dbaas/logs/foo/output/graylog/stream/bar/archive" - ) - - storage.service_name = None - with pytest.raises( - BackendParameterException, - match=( - "LDPStorage backend instance requires to set " - "both service_name and stream_id" - ), - ): - storage._archive_endpoint - - storage.service_name = "foo" - storage.stream_id = None - with pytest.raises( - BackendParameterException, - match=( - "LDPStorage backend instance requires to set " - "both service_name and stream_id" - ), - ): - storage._archive_endpoint - - storage.service_name = None - with pytest.raises( - BackendParameterException, - match=( - "LDPStorage backend instance requires to set " - "both service_name and stream_id" - ), - ): - storage._archive_endpoint - - -def test_backends_storage_ldp_details_method(monkeypatch): - """Test the LDPStorage _details method.""" - # pylint: disable=protected-access - - def mock_get(url): - """Mock the OVH client get request.""" - name = PurePath(urlparse(url).path).name - return { - "archiveId": str(uuid.UUID(name)), - "createdAt": "2020-06-18T04:38:59.436634+02:00", - "filename": "2020-06-16.gz", - "md5": "01585b394be0495e38dbb60b20cb40a9", - "retrievalDelay": 0, - "retrievalState": "sealed", - "sha256": "645d8e21e6fdb8aa7ffc5c[...]9ce612d06df8dcf67cb29a45ca", - "size": 67906662, - } - - storage = LDPStorage( - endpoint="ovh-eu", - application_key="fake_key", - application_secret="fake_secret", - consumer_key="another_fake_key", - service_name="ldp_fake", - stream_id="bbf2d9fb-b092-4003-958b-1262dc902a1c", - ) - - # Apply the monkeypatch for requests.get to mock_get - monkeypatch.setattr(storage.client, "get", mock_get) - - details = storage._details("5d49d1b3a3eb498c90396a482166f888") - assert details.get("archiveId") == "5d49d1b3-a3eb-498c-9039-6a482166f888" - - -def test_backends_storage_ldp_url_method(monkeypatch): - """Test the LDPStorage url method.""" - - def mock_post(url): - """Mock the OVH Client post request.""" - # pylint: disable=unused-argument - return { - "expirationDate": "2020-10-13T12:59:37.326131+00:00", - "url": ( - "https://storage.gra.cloud.ovh.net/v1/" - "AUTH_-c3b123f595c46e789acdd1227eefc13/" - "gra2-pcs/5eba98fb4fcb481001180e4b/" - "2020-06-01.gz?" - "temp_url_sig=e1b3ab10a9149a4ff5dcb95f40f21063780d26f7&" - "temp_url_expires=1602593977" - ), - } - - storage = LDPStorage( - endpoint="ovh-eu", - application_key="fake_key", - application_secret="fake_secret", - consumer_key="another_fake_key", - service_name="ldp_fake", - stream_id="bbf2d9fb-b092-4003-958b-1262dc902a1c", - ) - - # Apply the monkeypatch for requests.post to mock_get - monkeypatch.setattr(storage.client, "post", mock_post) - - assert storage.url("5d49d1b3-a3eb-498c-9039-6a482166f888") == ( - "https://storage.gra.cloud.ovh.net/v1/" - "AUTH_-c3b123f595c46e789acdd1227eefc13/" - "gra2-pcs/5eba98fb4fcb481001180e4b/" - "2020-06-01.gz?" 
- "temp_url_sig=e1b3ab10a9149a4ff5dcb95f40f21063780d26f7&" - "temp_url_expires=1602593977" - ) - - -def test_backends_storage_ldp_list_method(monkeypatch): - """Test the LDPStorage list method with a blank history.""" - - def mock_list(url): - """Mock OVH client list stream archives get request.""" - # pylint: disable=unused-argument - return [ - "5d5c4c93-04a4-42c5-9860-f51fa4044aa1", - "997db3eb-b9ca-485d-810f-b530a6cef7c6", - "08075b54-8d24-42ea-a509-9f10b0e3b416", - "75c865fd-b4eb-4b2b-9290-e8166a187d50", - "72e82041-7245-4ef1-b876-01964c6a8c50", - ] - - storage = LDPStorage( - endpoint="ovh-eu", - application_key="fake_key", - application_secret="fake_secret", - consumer_key="another_fake_key", - service_name="ldp_fake", - stream_id="bbf2d9fb-b092-4003-958b-1262dc902a1c", - ) - - # Apply the monkeypatch for requests.post to mock_get - monkeypatch.setattr(storage.client, "get", mock_list) - - archives = storage.list(details=False, new=False) - assert isinstance(archives, Iterable) - assert list(archives) == [ - "5d5c4c93-04a4-42c5-9860-f51fa4044aa1", - "997db3eb-b9ca-485d-810f-b530a6cef7c6", - "08075b54-8d24-42ea-a509-9f10b0e3b416", - "75c865fd-b4eb-4b2b-9290-e8166a187d50", - "72e82041-7245-4ef1-b876-01964c6a8c50", - ] - - -def test_backends_storage_ldp_list_method_history_management( - monkeypatch, fs, settings_fs -): - """Test the LDPStorage list method with a history.""" - # pylint: disable=invalid-name,unused-argument - - def mock_list(url): - """Mock the OVH client list stream archives get request.""" - # pylint: disable=unused-argument - return [ - "5d5c4c93-04a4-42c5-9860-f51fa4044aa1", - "997db3eb-b9ca-485d-810f-b530a6cef7c6", - "08075b54-8d24-42ea-a509-9f10b0e3b416", - "75c865fd-b4eb-4b2b-9290-e8166a187d50", - "72e82041-7245-4ef1-b876-01964c6a8c50", - ] - - storage = LDPStorage( - endpoint="ovh-eu", - application_key="fake_key", - application_secret="fake_secret", - consumer_key="another_fake_key", - service_name="ldp_fake", - stream_id="bbf2d9fb-b092-4003-958b-1262dc902a1c", - ) - - # Apply the monkeypatch for requests.post to mock_get - monkeypatch.setattr(storage.client, "get", mock_list) - - # Create a read history - fs.create_file( - settings.HISTORY_FILE, - contents=json.dumps( - [ - { - "backend": "ldp", - "command": "read", - "id": "5d5c4c93-04a4-42c5-9860-f51fa4044aa1", - "filename": "20201002.tgz", - "size": 23424233, - "fetched_at": "2020-10-07T16:37:25.887664+00:00", - }, - { - "backend": "ldp", - "command": "read", - "id": "997db3eb-b9ca-485d-810f-b530a6cef7c6", - "filename": "20201002.tgz", - "size": 23424233, - "fetched_at": "2020-10-07T16:40:25.887664+00:00", - }, - { - "backend": "ldp", - "command": "read", - "id": "08075b54-8d24-42ea-a509-9f10b0e3b416", - "filename": "20201002.tgz", - "size": 23424233, - "fetched_at": "2020-10-07T19:37:25.887664+00:00", - }, - ] - ), - ) - - archives = storage.list(details=False, new=True) - assert isinstance(archives, Iterable) - assert sorted(list(archives)) == sorted( - [ - "75c865fd-b4eb-4b2b-9290-e8166a187d50", - "72e82041-7245-4ef1-b876-01964c6a8c50", - ] - ) - - -def test_backends_storage_ldp_list_method_with_details(monkeypatch): - """Test the LDPStorage list method with detailed output.""" - details_responses = [ - { - "archiveId": "5d5c4c93-04a4-42c5-9860-f51fa4044aa1", - "createdAt": "2020-06-18T04:38:59.436634+02:00", - "filename": "2020-06-16.gz", - "md5": "01585b394be0495e38dbb60b20cb40a9", - "retrievalDelay": 0, - "retrievalState": "sealed", - "sha256": "645d8e21e6fdb8aa7ffc5c[...]9ce612d06df8dcf67cb29a45ca", - 
"size": 67906662, - }, - { - "archiveId": "997db3eb-b9ca-485d-810f-b530a6cef7c6", - "createdAt": "2020-06-18T04:38:59.436634+02:00", - "filename": "2020-06-17.gz", - "md5": "01585b394be0495e38dbb60b20cb40a9", - "retrievalDelay": 0, - "retrievalState": "sealed", - "sha256": "645d8e21e6fdb8aa7ffc5c[...]9ce612d06df8dcf67cb29a45ca", - "size": 67906662, - }, - ] - get_details_response = (response for response in details_responses) - - def mock_get(url): - """Mock OVH client get requests.""" - # list request - if url.endswith("archive"): - return [ - "5d5c4c93-04a4-42c5-9860-f51fa4044aa1", - "997db3eb-b9ca-485d-810f-b530a6cef7c6", - ] - # details request - return next(get_details_response) - - storage = LDPStorage( - endpoint="ovh-eu", - application_key="fake_key", - application_secret="fake_secret", - consumer_key="another_fake_key", - service_name="ldp_fake", - stream_id="bbf2d9fb-b092-4003-958b-1262dc902a1c", - ) - - # Apply the monkeypatch for requests.post to mock_get - monkeypatch.setattr(storage.client, "get", mock_get) - - archives = storage.list(details=True, new=False) - assert isinstance(archives, Iterable) - assert list(archives) == details_responses - - -def test_backends_storage_ldp_read_method(monkeypatch, fs, settings_fs): - """Test the LDPStorage read method with detailed output.""" - # pylint: disable=invalid-name,unused-argument - - # Create fake archive to stream - archive_path = Path("/tmp/2020-06-16.gz") - archive_content = {"foo": "bar"} - with gzip.open(archive_path, "wb") as archive_file: - archive_file.write(bytes(json.dumps(archive_content), encoding="utf-8")) - - def mock_ovh_post(url): - """Mock the OVH Client post request.""" - # pylint: disable=unused-argument - - return { - "expirationDate": "2020-10-13T12:59:37.326131+00:00", - "url": ( - "https://storage.gra.cloud.ovh.net/v1/" - "AUTH_-c3b123f595c46e789acdd1227eefc13/" - "gra2-pcs/5eba98fb4fcb481001180e4b/" - "2020-06-01.gz?" 
- "temp_url_sig=e1b3ab10a9149a4ff5dcb95f40f21063780d26f7&" - "temp_url_expires=1602593977" - ), - } - - def mock_ovh_get(url): - """Mock the OVH client get requests.""" - # pylint: disable=unused-argument - - return { - "archiveId": "5d5c4c93-04a4-42c5-9860-f51fa4044aa1", - "createdAt": "2020-06-18T04:38:59.436634+02:00", - "filename": "2020-06-16.gz", - "md5": "01585b394be0495e38dbb60b20cb40a9", - "retrievalDelay": 0, - "retrievalState": "sealed", - "sha256": "645d8e21e6fdb8aa7ffc5c[...]9ce612d06df8dcf67cb29a45ca", - "size": 67906662, - } - - class MockRequestsResponse: - """A basic mock for a requests response.""" - - def __enter__(self): - return self - - def __exit__(self, *args): - pass - - def iter_content(self, chunk_size): - """Fake content file iteration.""" - # pylint: disable=no-self-use - - with archive_path.open("rb") as archive: - while chunk := archive.read(chunk_size): - yield chunk - - def raise_for_status(self): - """Do nothing for now.""" - - def mock_requests_get(url, stream=True): - """Mock the requests get method.""" - # pylint: disable=unused-argument - - return MockRequestsResponse() - - # Freeze the datetime.datetime.now() value - freezed_now = datetime.datetime.now(tz=datetime.timezone.utc) - - class MockDatetime: - """A mock class for a fixed datetime.now() value.""" - - @classmethod - def now(cls, **kwargs): - """Always return the same testable now value.""" - # pylint: disable=unused-argument - - return freezed_now - - storage = LDPStorage( - endpoint="ovh-eu", - application_key="fake_key", - application_secret="fake_secret", - consumer_key="another_fake_key", - service_name="ldp_fake", - stream_id="bbf2d9fb-b092-4003-958b-1262dc902a1c", - ) - - # Apply monkeypatches - monkeypatch.setattr(storage.client, "post", mock_ovh_post) - monkeypatch.setattr(storage.client, "get", mock_ovh_get) - monkeypatch.setattr(requests, "get", mock_requests_get) - monkeypatch.setattr(datetime, "datetime", MockDatetime) - - fs.create_dir(settings.APP_DIR) - assert not os.path.exists(settings.HISTORY_FILE) - - result = b"".join(storage.read(name="5d5c4c93-04a4-42c5-9860-f51fa4044aa1")) - - assert os.path.exists(settings.HISTORY_FILE) - assert storage.history == [ - { - "backend": "ldp", - "command": "read", - "id": "5d5c4c93-04a4-42c5-9860-f51fa4044aa1", - "filename": "2020-06-16.gz", - "size": 67906662, - "fetched_at": freezed_now.isoformat(), - } - ] - - assert json.loads(gzip_decode(result)) == archive_content - - -def test_backends_storage_ldp_write_method_with_details(): - """Test the LDPStorage write method.""" - storage = LDPStorage( - endpoint="ovh-eu", - application_key="fake_key", - application_secret="fake_secret", - consumer_key="another_fake_key", - service_name="ldp_fake", - stream_id="bbf2d9fb-b092-4003-958b-1262dc902a1c", - ) - - with pytest.raises( - NotImplementedError, - match="LDP storage backend is read-only, cannot write to fake", - ): - storage.write("truly", "fake", "content") diff --git a/tests/backends/storage/test_s3.py b/tests/backends/storage/test_s3.py deleted file mode 100644 index 31cecf833..000000000 --- a/tests/backends/storage/test_s3.py +++ /dev/null @@ -1,398 +0,0 @@ -"""Tests for Ralph S3 storage backend.""" - -import datetime -import json -import logging -import sys -from io import BytesIO - -import boto3 -import pytest -from moto import mock_s3 - -from ralph.conf import settings -from ralph.exceptions import BackendException, BackendParameterException - - -@mock_s3 -def test_backends_storage_s3_storage_instantiation_should_raise_exception( - 
s3, caplog -): # pylint:disable=invalid-name - """S3 backend instantiation test. - - Check that S3Storage raises BackendParameterException on failure. - """ - # Regions outside us-east-1 require the appropriate LocationConstraint - s3_client = boto3.client("s3", region_name="us-east-1") - # Create an invalid bucket in Moto's 'virtual' AWS account - bucket_name = "my-test-bucket" - s3_client.create_bucket(Bucket=bucket_name) - - error = "Not Found" - caplog.set_level(logging.ERROR) - - with pytest.raises(BackendParameterException): - s3() - logger_name = "ralph.backends.storage.s3" - msg = f"Unable to connect to the requested bucket: {error}" - assert caplog.record_tuples == [(logger_name, logging.ERROR, msg)] - - -@mock_s3 -def test_backends_storage_s3_storage_instantiation_failure_should_not_raise_exception( - s3, -): # pylint:disable=invalid-name - """S3 backend instantiation test. - - Check that S3Storage doesn't raise exceptions when the connection is - successful. - """ - # Regions outside us-east-1 require the appropriate LocationConstraint - s3_client = boto3.client("s3", region_name="us-east-1") - # Create a valid bucket in Moto's 'virtual' AWS account - bucket_name = "bucket_name" - s3_client.create_bucket(Bucket=bucket_name) - - try: - s3() - except Exception: # pylint:disable=broad-except - pytest.fail("S3Storage should not raise exception on successful connection") - - -@mock_s3 -def test_backends_storage_s3_list_should_yield_archive_names( - moto_fs, s3, fs, settings_fs -): # pylint:disable=unused-argument, invalid-name - """S3 backend list test. - - Test that given S3Service.list method successfully connects to the S3 - storage, the S3Storage list method should yield the archives. - """ - # Regions outside of us-east-1 require the appropriate LocationConstraint - s3_client = boto3.client("s3", region_name="us-east-1") - # Create a valid bucket - bucket_name = "bucket_name" - s3_client.create_bucket(Bucket=bucket_name) - - s3_client.put_object( - Bucket=bucket_name, - Key="2022-04-29.gz", - Body=json.dumps({"id": "1", "foo": "bar"}), - ) - - s3_client.put_object( - Bucket=bucket_name, - Key="2022-04-30.gz", - Body=json.dumps({"id": "2", "some": "data"}), - ) - - s3_client.put_object( - Bucket=bucket_name, - Key="2022-10-01.gz", - Body=json.dumps({"id": "3", "other": "info"}), - ) - - listing = [ - {"name": "2022-04-29.gz"}, - {"name": "2022-04-30.gz"}, - {"name": "2022-10-01.gz"}, - ] - - history = [ - {"id": "2022-04-29.gz", "backend": "s3", "command": "read"}, - {"id": "2022-04-30.gz", "backend": "s3", "command": "read"}, - ] - - s3 = s3() - try: - response_list = s3.list() - response_list_new = s3.list(new=True) - response_list_details = s3.list(details=True) - except Exception: # pylint:disable=broad-except - pytest.fail("S3Storage should not raise exception on successful list") - - fs.create_file(settings.HISTORY_FILE, contents=json.dumps(history)) - - assert list(response_list) == [x["name"] for x in listing] - assert list(response_list_new) == ["2022-10-01.gz"] - assert [x["Key"] for x in response_list_details] == [x["name"] for x in listing] - - -@mock_s3 -def test_backends_storage_s3_list_on_empty_bucket_should_do_nothing( - moto_fs, s3, fs -): # pylint:disable=unused-argument, invalid-name - """S3 backend list test. - - Test that given S3Service.list method successfully connects to the S3 - storage, the S3Storage list method on an empty bucket should do nothing. 
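-
-    Note that `boto3` omits the `Contents` key from a `list_objects_v2`
-    response when the bucket is empty, so a listing loop has to default to an
-    empty list. A minimal sketch with plain boto3 (outside Ralph):
-
-        import boto3
-
-        client = boto3.client("s3", region_name="us-east-1")
-        response = client.list_objects_v2(Bucket="bucket_name")
-        # "Contents" is absent for an empty bucket, hence the default.
-        for entry in response.get("Contents", []):
-            print(entry["Key"])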
- """ - # Regions outside of us-east-1 require the appropriate LocationConstraint - s3_client = boto3.client("s3", region_name="us-east-1") - # Create a valid bucket - bucket_name = "bucket_name" - s3_client.create_bucket(Bucket=bucket_name) - - listing = [] - - history = [] - - s3 = s3() - try: - response_list = s3.list() - except Exception: # pylint:disable=broad-except - pytest.fail("S3Storage should not raise exception on successful list") - - fs.create_file(settings.HISTORY_FILE, contents=json.dumps(history)) - - assert list(response_list) == [x["name"] for x in listing] - - -@mock_s3 -def test_backends_storage_s3_list_with_failed_connection_should_log_the_error( - moto_fs, s3, fs, caplog, settings_fs -): # pylint:disable=unused-argument, invalid-name - """S3 backend list test. - - Test that given S3Service.list method fails to retrieve the list of archives, - the S3Storage list method should log the error and raise a BackendException. - """ - # Regions outside of us-east-1 require the appropriate LocationConstraint - s3_client = boto3.client("s3", region_name="us-east-1") - # Create a valid bucket in Moto's 'virtual' AWS account - bucket_name = "bucket_name" - s3_client.create_bucket(Bucket=bucket_name) - - s3_client.put_object( - Bucket=bucket_name, - Key="2022-04-29.gz", - Body=json.dumps({"id": "1", "foo": "bar"}), - ) - - s3 = s3() - s3.bucket_name = "wrong_name" - - fs.create_file(settings.HISTORY_FILE, contents=json.dumps([])) - caplog.set_level(logging.ERROR) - error = "The specified bucket does not exist" - msg = f"Failed to list the bucket wrong_name: {error}" - - with pytest.raises(BackendException, match=msg): - next(s3.list()) - with pytest.raises(BackendException, match=msg): - next(s3.list(new=True)) - with pytest.raises(BackendException, match=msg): - next(s3.list(details=True)) - logger_name = "ralph.backends.storage.s3" - assert caplog.record_tuples == [(logger_name, logging.ERROR, msg)] * 3 - - -@mock_s3 -def test_backends_storage_s3_read_with_valid_name_should_write_to_history( - moto_fs, s3, monkeypatch, fs, settings_fs -): # pylint:disable=unused-argument, invalid-name - """S3 backend read test. - - Test that given S3Service.download method successfully retrieves from the - S3 storage the object with the provided name (the object exists), - the S3Storage read method should write the entry to the history. - """ - # Regions outside of us-east-1 require the appropriate LocationConstraint - s3_client = boto3.client("s3", region_name="us-east-1") - # Create a valid bucket in Moto's 'virtual' AWS account - bucket_name = "bucket_name" - s3_client.create_bucket(Bucket=bucket_name) - - body = b"some contents in the body" - - s3_client.put_object( - Bucket=bucket_name, - Key="2022-09-29.gz", - Body=body, - ) - - freezed_now = datetime.datetime.now(tz=datetime.timezone.utc).isoformat() - monkeypatch.setattr("ralph.backends.storage.s3.now", lambda: freezed_now) - fs.create_file(settings.HISTORY_FILE, contents=json.dumps([])) - - try: - s3 = s3() - list(s3.read("2022-09-29.gz")) - except Exception: # pylint:disable=broad-except - pytest.fail("S3Storage should not raise exception on successful read") - - assert s3.history == [ - { - "backend": "s3", - "command": "read", - "id": "2022-09-29.gz", - "size": len(body), - "fetched_at": freezed_now, - } - ] - - -@mock_s3 -def test_backends_storage_s3_read_with_invalid_name_should_log_the_error( - moto_fs, s3, fs, caplog, settings_fs -): # pylint:disable=unused-argument, invalid-name - """S3 backend read test. 
-
-    Test that given S3Service.download method fails to retrieve from the S3
-    storage the object with the provided name (the object does not exist on S3),
-    the S3Storage read method should log the error, not write to history and raise a
-    BackendException.
-    """
-    # Regions outside of us-east-1 require the appropriate LocationConstraint
-    s3_client = boto3.client("s3", region_name="us-east-1")
-    # Create a valid bucket in Moto's 'virtual' AWS account
-    bucket_name = "bucket_name"
-    s3_client.create_bucket(Bucket=bucket_name)
-
-    body = b"some contents in the body"
-
-    s3_client.put_object(
-        Bucket=bucket_name,
-        Key="2022-09-29.gz",
-        Body=body,
-    )
-
-    fs.create_file(settings.HISTORY_FILE, contents=json.dumps([]))
-    caplog.set_level(logging.ERROR)
-    error = "The specified key does not exist."
-
-    with pytest.raises(BackendException):
-        s3 = s3()
-        list(s3.read("invalid_name.gz"))
-    logger_name = "ralph.backends.storage.s3"
-    msg = f"Failed to download invalid_name.gz: {error}"
-    assert caplog.record_tuples == [(logger_name, logging.ERROR, msg)]
-    assert s3.history == []
-
-
-# pylint: disable=line-too-long
-@pytest.mark.parametrize("overwrite", [False, True])
-@pytest.mark.parametrize("new_archive", [False, True])
-@mock_s3
-def test_backends_storage_s3_write_should_write_to_history_new_or_overwritten_archives(  # noqa
-    moto_fs, overwrite, new_archive, s3, monkeypatch, fs, caplog, settings_fs
-):  # pylint:disable=unused-argument, invalid-name, too-many-arguments, too-many-locals
-    """S3 backend write test.
-
-    Test that given S3Service list/upload method successfully connects to the
-    S3 storage, the S3Storage write method should update the history file when
-    overwrite is True or when the name of the archive is not in the history.
-    In case overwrite is False and the archive is in the history, the write method
-    should raise a FileExistsError.
- """ - # Regions outside of us-east-1 require the appropriate LocationConstraint - s3_client = boto3.client("s3", region_name="us-east-1") - # Create a valid bucket in Moto's 'virtual' AWS account - bucket_name = "bucket_name" - s3_client.create_bucket(Bucket=bucket_name) - - body = b"some contents in the body" - - s3_client.put_object( - Bucket=bucket_name, - Key="2022-09-29.gz", - Body=body, - ) - - history = [ - {"id": "2022-09-29.gz", "backend": "s3", "command": "read"}, - {"id": "2022-09-30.gz", "backend": "s3", "command": "read"}, - ] - - freezed_now = datetime.datetime.now(tz=datetime.timezone.utc).isoformat() - archive_name = "not_in_history.gz" if new_archive else "2022-09-29.gz" - new_history_entry = [ - { - "backend": "s3", - "command": "write", - "id": archive_name, - "pushed_at": freezed_now, - } - ] - - stream_content = b"some contents in the stream file to upload" - monkeypatch.setattr(sys, "stdin", BytesIO(stream_content)) - monkeypatch.setattr("ralph.backends.storage.s3.now", lambda: freezed_now) - fs.create_file(settings.HISTORY_FILE, contents=json.dumps(history)) - caplog.set_level(logging.ERROR) - - s3 = s3() - if not overwrite and not new_archive: - new_history_entry = [] - msg = f"{archive_name} already exists and overwrite is not allowed" - with pytest.raises(FileExistsError, match=msg): - s3.write(sys.stdin, archive_name, overwrite=overwrite) - logger_name = "ralph.backends.storage.s3" - assert caplog.record_tuples == [(logger_name, logging.ERROR, msg)] - else: - s3.write(sys.stdin, archive_name, overwrite=overwrite) - assert s3.history == history + new_history_entry - - -@mock_s3 -def test_backends_storage_s3_write_should_log_the_error( - moto_fs, s3, monkeypatch, fs, caplog, settings_fs -): # pylint:disable=unused-argument, invalid-name,too-many-arguments - """S3 backend write test. - - Test that given S3Service.upload method fails to write the archive, - the S3Storage write method should log the error, raise a BackendException - and not write to history. - """ - # Regions outside of us-east-1 require the appropriate LocationConstraint - s3_client = boto3.client("s3", region_name="us-east-1") - # Create a valid bucket in Moto's 'virtual' AWS account - bucket_name = "bucket_name" - s3_client.create_bucket(Bucket=bucket_name) - - body = b"some contents in the body" - - s3_client.put_object( - Bucket=bucket_name, - Key="2022-09-29.gz", - Body=body, - ) - - history = [ - {"id": "2022-09-29.gz", "backend": "s3", "command": "read"}, - {"id": "2022-09-30.gz", "backend": "s3", "command": "read"}, - ] - - fs.create_file(settings.HISTORY_FILE, contents=json.dumps(history)) - caplog.set_level(logging.ERROR) - - s3 = s3() - - error = "Failed to upload" - - stream_content = b"some contents in the stream file to upload" - monkeypatch.setattr(sys, "stdin", BytesIO(stream_content)) - - with pytest.raises(BackendException): - s3.write(sys.stdin, "", overwrite=True) - logger_name = "ralph.backends.storage.s3" - assert caplog.record_tuples == [(logger_name, logging.ERROR, error)] - assert s3.history == history - - -@mock_s3 -def test_backends_storage_url_should_concatenate_the_storage_url_and_name( - s3, -): # pylint:disable=invalid-name - """S3 backend url test. - - Check the url method returns `bucket_name.s3.default_region - .amazonaws.com/name`. 
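-
-    A sketch of the expected virtual-hosted-style URL construction (an
-    assumption inferred from the assertion below, not necessarily the
-    backend's exact code):
-
-        def s3_url(bucket_name: str, region: str, name: str) -> str:
-            # Scheme-less virtual-hosted-style S3 object URL.
-            return f"{bucket_name}.s3.{region}.amazonaws.com/{name}"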
- """ - # Regions outside of us-east-1 require the appropriate LocationConstraint - s3_client = boto3.client("s3", region_name="us-east-1") - # Create a valid bucket in Moto's 'virtual' AWS account - bucket_name = "bucket_name" - s3_client.create_bucket(Bucket=bucket_name) - - assert s3().url("name") == "bucket_name.s3.default-region.amazonaws.com/name" diff --git a/tests/backends/storage/test_swift.py b/tests/backends/storage/test_swift.py deleted file mode 100644 index 404916042..000000000 --- a/tests/backends/storage/test_swift.py +++ /dev/null @@ -1,293 +0,0 @@ -"""Tests for Ralph swift storage backend.""" - -import datetime -import json -import logging -import sys - -import pytest -from swiftclient.service import SwiftService - -from ralph.conf import settings -from ralph.exceptions import BackendException, BackendParameterException - - -def test_backends_storage_swift_storage_instantiation_failure_should_raise_exception( - monkeypatch, swift, caplog -): - """Check that SwiftStorage raises BackendParameterException on failure.""" - error = "Unauthorized. Check username/id" - - def mock_failed_stat(*args, **kwargs): # pylint:disable=unused-argument - return {"success": False, "error": error} - - monkeypatch.setattr(SwiftService, "stat", mock_failed_stat) - caplog.set_level(logging.ERROR) - - with pytest.raises(BackendParameterException, match=error): - swift() - logger_name = "ralph.backends.storage.swift" - msg = f"Unable to connect to the requested container: {error}" - assert caplog.record_tuples == [(logger_name, logging.ERROR, msg)] - - -def test_backends_storage_swift_storage_instantiation_should_not_raise_exception( - monkeypatch, swift -): - """Check that SwiftStorage doesn't raise exceptions when the connection is - successful. - """ - - def mock_successful_stat(*args, **kwargs): # pylint:disable=unused-argument - return {"success": True} - - monkeypatch.setattr(SwiftService, "stat", mock_successful_stat) - try: - swift() - except Exception: # pylint:disable=broad-except - pytest.fail("SwiftStorage should not raise exception on successful connection") - - -@pytest.mark.parametrize("pages_count", [1, 2]) -def test_backends_storage_swift_list_should_yield_archive_names( - pages_count, swift, monkeypatch, fs, settings_fs -): # pylint:disable=invalid-name,unused-argument - """Test that given SwiftService.list method successfully connects to the Swift - storage, the SwiftStorage list method should yield the archives. 
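-
-    `SwiftService.list` yields result pages shaped like
-    `{"success": bool, "listing": [{"name": ...}, ...]}` (see the mocks
-    below), so the backend is expected to flatten them. A minimal flattening
-    sketch with python-swiftclient (outside Ralph):
-
-        from swiftclient.service import SwiftService
-
-        def iter_archive_names(container: str):
-            with SwiftService() as service:
-                for page in service.list(container=container):
-                    if page["success"]:
-                        for item in page["listing"]:
-                            yield item["name"]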
- """ - listing = [ - {"name": "2020-04-29.gz"}, - {"name": "2020-04-30.gz"}, - {"name": "2020-05-01.gz"}, - ] - history = [ - {"id": "2020-04-29.gz", "backend": "swift", "command": "read"}, - {"id": "2020-04-30.gz", "backend": "swift", "command": "read"}, - ] - - def mock_list_with_pages(*args, **kwargs): # pylint:disable=unused-argument - return [{"success": True, "listing": listing}] * pages_count - - def mock_successful_stat(*args, **kwargs): # pylint:disable=unused-argument - return {"success": True} - - monkeypatch.setattr(SwiftService, "list", mock_list_with_pages) - monkeypatch.setattr(SwiftService, "stat", mock_successful_stat) - fs.create_file(settings.HISTORY_FILE, contents=json.dumps(history)) - swift = swift() - assert list(swift.list()) == [x["name"] for x in listing] * pages_count - assert list(swift.list(new=True)) == ["2020-05-01.gz"] * pages_count - assert list(swift.list(details=True)) == listing * pages_count - - -@pytest.mark.parametrize("pages_count", [1, 2]) -def test_backends_storage_swift_list_with_failed_connection_should_log_the_error( - pages_count, swift, monkeypatch, fs, caplog, settings_fs -): # pylint:disable=invalid-name,unused-argument,too-many-arguments - """Test that given SwiftService.list method fails to retrieve the list of archives, - the SwiftStorage list method should log the error and raise a BackendException. - """ - - def mock_list_with_pages(*args, **kwargs): # pylint:disable=unused-argument - return [ - { - "success": False, - "container": "ralph_logs_container", - "error": "Container not found", - } - ] * pages_count - - def mock_successful_stat(*args, **kwargs): # pylint:disable=unused-argument - return {"success": True} - - monkeypatch.setattr(SwiftService, "list", mock_list_with_pages) - monkeypatch.setattr(SwiftService, "stat", mock_successful_stat) - fs.create_file(settings.HISTORY_FILE, contents=json.dumps([])) - caplog.set_level(logging.ERROR) - swift = swift() - msg = "Failed to list container ralph_logs_container: Container not found" - with pytest.raises(BackendException, match=msg): - next(swift.list()) - with pytest.raises(BackendException, match=msg): - next(swift.list(new=True)) - with pytest.raises(BackendException, match=msg): - next(swift.list(details=True)) - logger_name = "ralph.backends.storage.swift" - assert caplog.record_tuples == [(logger_name, logging.ERROR, msg)] * 3 - - -def test_backends_storage_swift_read_with_valid_name_should_write_to_history( - swift, monkeypatch, fs, settings_fs -): # pylint:disable=invalid-name,unused-argument - """Test that given SwiftService.download method successfully retrieves from the - Swift storage the object with the provided name (the object exists), - the SwiftStorage read method should write the entry to the history. 
-    """
-
-    def mock_successful_download(*args, **kwargs):  # pylint:disable=unused-argument
-        yield {"contents": [b"some", b"contents"]}
-
-    def mock_successful_stat(*args, **kwargs):  # pylint:disable=unused-argument
-        return {"success": True}
-
-    freezed_now = datetime.datetime.now(tz=datetime.timezone.utc).isoformat()
-    monkeypatch.setattr(SwiftService, "download", mock_successful_download)
-    monkeypatch.setattr(SwiftService, "stat", mock_successful_stat)
-    monkeypatch.setattr("ralph.backends.storage.swift.now", lambda: freezed_now)
-    fs.create_file(settings.HISTORY_FILE, contents=json.dumps([]))
-
-    swift = swift()
-    list(swift.read("2020-04-29.gz"))
-    assert swift.history == [
-        {
-            "backend": "swift",
-            "command": "read",
-            "id": "2020-04-29.gz",
-            "size": 12,
-            "fetched_at": freezed_now,
-        }
-    ]
-
-
-def test_backends_storage_swift_read_with_invalid_name_should_log_the_error(
-    swift, monkeypatch, fs, caplog, settings_fs
-):  # pylint:disable=invalid-name,unused-argument
-    """Test that given SwiftService.download method fails to retrieve from the Swift
-    storage the object with the provided name (the object does not exist on Swift),
-    the SwiftStorage read method should log the error, not write to history and raise a
-    BackendException.
-    """
-    error = "ClientException Object GET failed"
-
-    def mock_failed_download(*args, **kwargs):  # pylint:disable=unused-argument
-        yield {"object": "2020-04-31.gz", "error": error}
-
-    def mock_successful_stat(*args, **kwargs):  # pylint:disable=unused-argument
-        return {"success": True}
-
-    monkeypatch.setattr(SwiftService, "download", mock_failed_download)
-    monkeypatch.setattr(SwiftService, "stat", mock_successful_stat)
-    fs.create_file(settings.HISTORY_FILE, contents=json.dumps([]))
-    caplog.set_level(logging.ERROR)
-
-    swift = swift()
-    msg = f"Failed to download 2020-04-31.gz: {error}"
-    with pytest.raises(BackendException, match=msg):
-        list(swift.read("2020-04-31.gz"))
-    logger_name = "ralph.backends.storage.swift"
-    assert caplog.record_tuples == [(logger_name, logging.ERROR, msg)]
-    assert swift.history == []
-
-
-# pylint: disable=line-too-long
-@pytest.mark.parametrize("overwrite", [False, True])
-@pytest.mark.parametrize("new_archive", [False, True])
-def test_backends_storage_swift_write_should_write_to_history_new_or_overwritten_archives(  # noqa
-    overwrite, new_archive, swift, monkeypatch, fs, caplog, settings_fs
-):  # pylint:disable=invalid-name, too-many-arguments, too-many-locals,unused-argument
-    """Test that given SwiftService list/upload method successfully connects to the
-    Swift storage, the SwiftStorage write method should update the history file when
-    overwrite is True or when the name of the archive is not in the history.
-    In case overwrite is False and the archive is in the history, the write method
-    should raise a FileExistsError.
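-
-    The four parametrized cases therefore resolve as follows:
-
-        overwrite  new_archive  outcome
-        ---------  -----------  ------------------------------------
-        any        True         archive written, history updated
-        True       False        archive overwritten, history updated
-        False      False        FileExistsError, history unchanged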
- """ - history = [ - {"id": "2020-04-29.gz", "backend": "swift", "command": "read"}, - {"id": "2020-04-30.gz", "backend": "swift", "command": "read"}, - ] - listing = [ - {"name": "2020-04-29.gz"}, - {"name": "2020-04-30.gz"}, - {"name": "2020-05-01.gz"}, - ] - freezed_now = datetime.datetime.now(tz=datetime.timezone.utc).isoformat() - archive_name = "not_in_history.gz" if new_archive else "2020-04-29.gz" - new_history_entry = [ - { - "backend": "swift", - "command": "write", - "id": archive_name, - "pushed_at": freezed_now, - } - ] - - def mock_successful_upload(*args, **kwargs): # pylint:disable=unused-argument - yield {"success": True} - - def mock_successful_list(*args, **kwargs): # pylint:disable=unused-argument - return [{"success": True, "listing": listing}] - - def mock_successful_stat(*args, **kwargs): # pylint:disable=unused-argument - return {"success": True} - - monkeypatch.setattr(SwiftService, "upload", mock_successful_upload) - monkeypatch.setattr(SwiftService, "list", mock_successful_list) - monkeypatch.setattr(SwiftService, "stat", mock_successful_stat) - monkeypatch.setattr("ralph.backends.storage.swift.now", lambda: freezed_now) - fs.create_file(settings.HISTORY_FILE, contents=json.dumps(history)) - caplog.set_level(logging.ERROR) - - swift = swift() - if not overwrite and not new_archive: - new_history_entry = [] - msg = f"{archive_name} already exists and overwrite is not allowed" - with pytest.raises(FileExistsError, match=msg): - swift.write(sys.stdin.buffer, archive_name, overwrite=overwrite) - logger_name = "ralph.backends.storage.swift" - assert caplog.record_tuples == [(logger_name, logging.ERROR, msg)] - else: - swift.write(sys.stdin.buffer, archive_name, overwrite=overwrite) - assert swift.history == history + new_history_entry - - -def test_backends_storage_swift_write_should_log_the_error( - swift, monkeypatch, fs, caplog, settings_fs -): # pylint:disable=invalid-name,unused-argument - """Test that given SwiftService.upload method fails to write the archive, - the SwiftStorage write method should log the error, raise a BackendException - and not write to history. - """ - error = "Unauthorized. 
Check username/id, password" - history = [ - {"id": "2020-04-29.gz", "backend": "swift", "command": "read"}, - {"id": "2020-04-30.gz", "backend": "swift", "command": "read"}, - ] - listing = [ - {"name": "2020-04-29.gz"}, - {"name": "2020-04-30.gz"}, - {"name": "2020-05-01.gz"}, - ] - - def mock_failed_upload(*args, **kwargs): # pylint:disable=unused-argument - yield {"success": False, "error": error} - - def mock_successful_list(*args, **kwargs): # pylint:disable=unused-argument - return [{"success": True, "listing": listing}] - - def mock_successful_stat(*args, **kwargs): # pylint:disable=unused-argument - return {"success": True} - - monkeypatch.setattr(SwiftService, "upload", mock_failed_upload) - monkeypatch.setattr(SwiftService, "list", mock_successful_list) - monkeypatch.setattr(SwiftService, "stat", mock_successful_stat) - fs.create_file(settings.HISTORY_FILE, contents=json.dumps(history)) - caplog.set_level(logging.ERROR) - - swift = swift() - with pytest.raises(BackendException, match=error): - swift.write(sys.stdin.buffer, "2020-04-29.gz", overwrite=True) - logger_name = "ralph.backends.storage.swift" - assert caplog.record_tuples == [(logger_name, logging.ERROR, error)] - assert swift.history == history - - -def test_backends_storage_url_should_concatenate_the_storage_url_and_name( - swift, monkeypatch -): - """Check the url method returns `os_storage_url/name`.""" - - def mock_successful_stat(*args, **kwargs): # pylint:disable=unused-argument - return {"success": True} - - monkeypatch.setattr(SwiftService, "stat", mock_successful_stat) - assert swift().url("name") == "os_storage_url/name" diff --git a/tests/backends/stream/test_base.py b/tests/backends/stream/test_base.py index 923cf70ea..d6c4a3cb9 100644 --- a/tests/backends/stream/test_base.py +++ b/tests/backends/stream/test_base.py @@ -1,12 +1,12 @@ """Tests for Ralph base stream backend.""" -from ralph.backends.stream.base import BaseStream +from ralph.backends.stream.base import BaseStreamBackend def test_backends_stream_base_abstract_interface_with_implemented_abstract_method(): """Test the interface mechanism with properly implemented abstract methods.""" - class GoodStream(BaseStream): + class GoodStream(BaseStreamBackend): """Correct implementation with required abstract methods.""" name = "good" diff --git a/tests/backends/test_conf.py b/tests/backends/test_conf.py new file mode 100644 index 000000000..aa4a6dd09 --- /dev/null +++ b/tests/backends/test_conf.py @@ -0,0 +1,144 @@ +"""Tests for Ralph's backends configuration loading.""" + +from pathlib import PosixPath + +import pytest +from pydantic import ValidationError + +from ralph.backends.conf import Backends, BackendSettings, DataBackendSettings +from ralph.backends.data.es import ESDataBackendSettings + + +def test_conf_settings_field_value_priority(fs, monkeypatch): + """Test that the BackendSettings object field values are defined in the following + descending order of priority: + + 1. Arguments passed to the initializer. + 2. Environment variables. + 3. Dotenv variables (.env) + 4. Default values. + """ + # pylint: disable=invalid-name + + # 4. Using default value. + assert str(BackendSettings().BACKENDS.DATA.ES.LOCALE_ENCODING) == "utf8" + + # 3. Using dotenv variables (overrides default value). + fs.create_file(".env", contents="RALPH_BACKENDS__DATA__ES__LOCALE_ENCODING=toto\n") + assert str(BackendSettings().BACKENDS.DATA.ES.LOCALE_ENCODING) == "toto" + + # 2. Using environment variable value (overrides dotenv value). 
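+    # The double underscore in RALPH_BACKENDS__DATA__ES__LOCALE_ENCODING acts
+    # as a nested-field delimiter, mapping the variable onto
+    # BACKENDS.DATA.ES.LOCALE_ENCODING (assuming pydantic's
+    # env_nested_delimiter="__" convention).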
+ monkeypatch.setenv("RALPH_BACKENDS__DATA__ES__LOCALE_ENCODING", "foo") + assert str(BackendSettings().BACKENDS.DATA.ES.LOCALE_ENCODING) == "foo" + + # 1. Using argument value (overrides environment value). + assert ( + str( + BackendSettings( + BACKENDS=Backends( + DATA=DataBackendSettings( + ES=ESDataBackendSettings(LOCALE_ENCODING="bar") + ) + ) + ).BACKENDS.DATA.ES.LOCALE_ENCODING + ) + == "bar" + ) + + +@pytest.mark.parametrize( + "ca_certs,verify_certs,expected", + [ + ("/path", "True", {"ca_certs": PosixPath("/path"), "verify_certs": True}), + ("/path2", "f", {"ca_certs": PosixPath("/path2"), "verify_certs": False}), + (None, None, {"ca_certs": None, "verify_certs": None}), + ], +) +def test_conf_es_client_options_with_valid_values( + ca_certs, verify_certs, expected, monkeypatch +): + """Test the ESClientOptions pydantic data type with valid values.""" + # Using None here as in "not set by user" + if ca_certs is not None: + monkeypatch.setenv( + "RALPH_BACKENDS__DATA__ES__CLIENT_OPTIONS__ca_certs", f"{ca_certs}" + ) + # Using None here as in "not set by user" + if verify_certs is not None: + monkeypatch.setenv( + "RALPH_BACKENDS__DATA__ES__CLIENT_OPTIONS__verify_certs", + f"{verify_certs}", + ) + assert BackendSettings().BACKENDS.DATA.ES.CLIENT_OPTIONS.dict() == expected + + +@pytest.mark.parametrize( + "ca_certs,verify_certs", + [ + ("/path", 3), + ("/path", None), + ], +) +def test_conf_es_client_options_with_invalid_values( + ca_certs, verify_certs, monkeypatch +): + """Test the ESClientOptions pydantic data type with invalid values.""" + monkeypatch.setenv( + "RALPH_BACKENDS__DATA__ES__CLIENT_OPTIONS__ca_certs", f"{ca_certs}" + ) + monkeypatch.setenv( + "RALPH_BACKENDS__DATA__ES__CLIENT_OPTIONS__verify_certs", + f"{verify_certs}", + ) + with pytest.raises(ValidationError, match="1 validation error for"): + BackendSettings().BACKENDS.DATA.ES.CLIENT_OPTIONS.dict() + + +@pytest.mark.parametrize( + "document_class,tz_aware,expected", + [ + ("dict", "True", {"document_class": "dict", "tz_aware": True}), + ("str", "f", {"document_class": "str", "tz_aware": False}), + (None, None, {"document_class": None, "tz_aware": None}), + ], +) +def test_conf_mongo_client_options_with_valid_values( + document_class, tz_aware, expected, monkeypatch +): + """Test the MongoClientOptions pydantic data type with valid values.""" + # Using None here as in "not set by user" + if document_class is not None: + monkeypatch.setenv( + "RALPH_BACKENDS__DATA__MONGO__CLIENT_OPTIONS__document_class", + f"{document_class}", + ) + # Using None here as in "not set by user" + if tz_aware is not None: + monkeypatch.setenv( + "RALPH_BACKENDS__DATA__MONGO__CLIENT_OPTIONS__tz_aware", + f"{tz_aware}", + ) + assert BackendSettings().BACKENDS.DATA.MONGO.CLIENT_OPTIONS.dict() == expected + + +@pytest.mark.parametrize( + "document_class,tz_aware", + [ + ("dict", 3), + ("str", None), + ], +) +def test_conf_mongo_client_options_with_invalid_values( + document_class, tz_aware, monkeypatch +): + """Test the MongoClientOptions pydantic data type with invalid values.""" + monkeypatch.setenv( + "RALPH_BACKENDS__DATA__MONGO__CLIENT_OPTIONS__document_class", + f"{document_class}", + ) + monkeypatch.setenv( + "RALPH_BACKENDS__DATA__MONGO__CLIENT_OPTIONS__tz_aware", + f"{tz_aware}", + ) + with pytest.raises(ValidationError, match="1 validation error for"): + BackendSettings().BACKENDS.DATA.MONGO.CLIENT_OPTIONS.dict() diff --git a/tests/conftest.py b/tests/conftest.py index 77a5a6d25..10b819ee3 100644 --- a/tests/conftest.py +++ 
b/tests/conftest.py @@ -36,10 +36,8 @@ mongo_forwarding, mongo_lrs_backend, moto_fs, - s3, s3_backend, settings_fs, - swift, swift_backend, ws, ) diff --git a/tests/fixtures/backends.py b/tests/fixtures/backends.py index d99a897f0..1acf1bf61 100644 --- a/tests/fixtures/backends.py +++ b/tests/fixtures/backends.py @@ -23,69 +23,65 @@ from ralph.backends.data.async_es import AsyncESDataBackend from ralph.backends.data.async_mongo import AsyncMongoDataBackend -from ralph.backends.data.clickhouse import ClickHouseDataBackend +from ralph.backends.data.clickhouse import ( + ClickHouseClientOptions, + ClickHouseDataBackend, +) from ralph.backends.data.es import ESDataBackend from ralph.backends.data.fs import FSDataBackend, FSDataBackendSettings from ralph.backends.data.ldp import LDPDataBackend from ralph.backends.data.mongo import MongoDataBackend from ralph.backends.data.s3 import S3DataBackend, S3DataBackendSettings from ralph.backends.data.swift import SwiftDataBackend, SwiftDataBackendSettings -from ralph.backends.database.clickhouse import ClickHouseDatabase -from ralph.backends.database.es import ESDatabase -from ralph.backends.database.mongo import MongoDatabase from ralph.backends.lrs.async_es import AsyncESLRSBackend from ralph.backends.lrs.async_mongo import AsyncMongoLRSBackend from ralph.backends.lrs.clickhouse import ClickHouseLRSBackend from ralph.backends.lrs.es import ESLRSBackend from ralph.backends.lrs.fs import FSLRSBackend from ralph.backends.lrs.mongo import MongoLRSBackend -from ralph.backends.storage.s3 import S3Storage -from ralph.backends.storage.swift import SwiftStorage -from ralph.conf import ClickhouseClientOptions, Settings, core_settings +from ralph.conf import Settings, core_settings # ClickHouse backend defaults CLICKHOUSE_TEST_DATABASE = os.environ.get( - "RALPH_BACKENDS__DATABASE__CLICKHOUSE__TEST_DATABASE", "test_statements" + "RALPH_BACKENDS__DATA__CLICKHOUSE__TEST_DATABASE", "test_statements" ) CLICKHOUSE_TEST_HOST = os.environ.get( - "RALPH_BACKENDS__DATABASE__CLICKHOUSE__TEST_HOST", "localhost" + "RALPH_BACKENDS__DATA__CLICKHOUSE__TEST_HOST", "localhost" ) CLICKHOUSE_TEST_PORT = os.environ.get( - "RALPH_BACKENDS__DATABASE__CLICKHOUSE__TEST_PORT", 8123 + "RALPH_BACKENDS__DATA__CLICKHOUSE__TEST_PORT", 8123 ) CLICKHOUSE_TEST_TABLE_NAME = os.environ.get( - "RALPH_BACKENDS__DATABASE__CLICKHOUSE__TEST_TABLE_NAME", "test_xapi_events_all" + "RALPH_BACKENDS__DATA__CLICKHOUSE__TEST_TABLE_NAME", "test_xapi_events_all" ) # Elasticsearch backend defaults -ES_TEST_INDEX = os.environ.get( - "RALPH_BACKENDS__DATABASE__ES__TEST_INDEX", "test-index-foo" -) +ES_TEST_INDEX = os.environ.get("RALPH_BACKENDS__DATA__ES__TEST_INDEX", "test-index-foo") ES_TEST_FORWARDING_INDEX = os.environ.get( - "RALPH_BACKENDS__DATABASE__ES__TEST_FORWARDING_INDEX", "test-index-foo-2" + "RALPH_BACKENDS__DATA__ES__TEST_FORWARDING_INDEX", "test-index-foo-2" ) ES_TEST_INDEX_TEMPLATE = os.environ.get( - "RALPH_BACKENDS__DATABASE__ES__INDEX_TEMPLATE", "test-index" + "RALPH_BACKENDS__DATA__ES__INDEX_TEMPLATE", "test-index" ) ES_TEST_INDEX_PATTERN = os.environ.get( - "RALPH_BACKENDS__DATABASE__ES__TEST_INDEX_PATTERN", "test-index-*" + "RALPH_BACKENDS__DATA__ES__TEST_INDEX_PATTERN", "test-index-*" ) ES_TEST_HOSTS = os.environ.get( - "RALPH_BACKENDS__DATABASE__ES__TEST_HOSTS", "http://localhost:9200" + "RALPH_BACKENDS__DATA__ES__TEST_HOSTS", "http://localhost:9200" ).split(",") # Mongo backend defaults MONGO_TEST_COLLECTION = os.environ.get( - "RALPH_BACKENDS__DATABASE__MONGO__TEST_COLLECTION", "marsha" + 
"RALPH_BACKENDS__DATA__MONGO__TEST_COLLECTION", "marsha" ) MONGO_TEST_FORWARDING_COLLECTION = os.environ.get( - "RALPH_BACKENDS__DATABASE__MONGO__TEST_FORWARDING_COLLECTION", "marsha-2" + "RALPH_BACKENDS__DATA__MONGO__TEST_FORWARDING_COLLECTION", "marsha-2" ) MONGO_TEST_DATABASE = os.environ.get( - "RALPH_BACKENDS__DATABASE__MONGO__TEST_DATABASE", "statements" + "RALPH_BACKENDS__DATA__MONGO__TEST_DATABASE", "statements" ) MONGO_TEST_CONNECTION_URI = os.environ.get( - "RALPH_BACKENDS__DATABASE__MONGO__TEST_CONNECTION_URI", "mongodb://localhost:27017/" + "RALPH_BACKENDS__DATA__MONGO__TEST_CONNECTION_URI", "mongodb://localhost:27017/" ) RUNSERVER_TEST_HOST = os.environ.get("RALPH_RUNSERVER_TEST_HOST", "0.0.0.0") @@ -98,8 +94,9 @@ @lru_cache() def get_clickhouse_test_backend(): - """Return a ClickHouseDatabase backend instance using test defaults.""" - return ClickHouseDatabase( + """Return a ClickHouseLRSBackend backend instance using test defaults.""" + + return ClickHouseLRSBackend( host=CLICKHOUSE_TEST_HOST, port=CLICKHOUSE_TEST_PORT, database=CLICKHOUSE_TEST_DATABASE, @@ -109,18 +106,19 @@ def get_clickhouse_test_backend(): @lru_cache def get_es_test_backend(): - """Return a ESDatabase backend instance using test defaults.""" - return ESDatabase(hosts=ES_TEST_HOSTS, index=ES_TEST_INDEX) + """Return a ESLRSBackend backend instance using test defaults.""" + return ESLRSBackend(hosts=ES_TEST_HOSTS, index=ES_TEST_INDEX) @lru_cache def get_mongo_test_backend(): """Returns a MongoDatabase backend instance using test defaults.""" - return MongoDatabase( + settings = MongoLRSBackend.settings_class( connection_uri=MONGO_TEST_CONNECTION_URI, database=MONGO_TEST_DATABASE, collection=MONGO_TEST_COLLECTION, ) + return MongoLRSBackend(settings) def get_es_fixture(host=ES_TEST_HOSTS, index=ES_TEST_INDEX): @@ -331,7 +329,7 @@ def get_clickhouse_fixture( """Create / delete a ClickHouse test database + table and yields an instantiated client. 
""" - client_options = ClickhouseClientOptions( + client_options = ClickHouseClientOptions( date_time_input_format="best_effort", # Allows RFC dates allow_experimental_object_type=1, # Allows JSON data type ).dict() @@ -600,24 +598,6 @@ def get_es_lrs_backend(index: str = ES_TEST_INDEX): return get_es_lrs_backend -@pytest.fixture -def swift(): - """Return get_swift_storage function.""" - - def get_swift_storage(): - """Returns an instance of SwiftStorage.""" - return SwiftStorage( - os_tenant_id="os_tenant_id", - os_tenant_name="os_tenant_name", - os_username="os_username", - os_password="os_password", - os_region_name="os_region_name", - os_storage_url="os_storage_url/ralph_logs_container", - ) - - return get_swift_storage - - @pytest.fixture def swift_backend(): """Return get_swift_data_backend function.""" @@ -653,26 +633,6 @@ def moto_fs(fs): fs.add_real_directory(module_dir, lazy_read=False) -@pytest.fixture -def s3(): - """Return get_s3_storage function.""" - # pylint:disable=invalid-name - - def get_s3_storage(): - """Returns an instance of S3Storage.""" - - return S3Storage( - access_key_id="access_key_id", - secret_access_key="secret_access_key", - session_token="session_token", - default_region="default-region", - bucket_name="bucket_name", - endpoint_url=None, - ) - - return get_s3_storage - - @pytest.fixture def s3_backend(): """Return the `get_s3_data_backend` function.""" diff --git a/tests/test_cli.py b/tests/test_cli.py index 4cc2f1af9..5fec60bee 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -11,8 +11,9 @@ from hypothesis import settings as hypothesis_settings from pydantic import ValidationError -from ralph.backends.storage.fs import FSStorage -from ralph.backends.storage.ldp import LDPStorage +from ralph.backends.conf import backends_settings +from ralph.backends.data.fs import FSDataBackend, FSDataBackendSettings +from ralph.backends.data.ldp import LDPDataBackend from ralph.cli import ( CommaSeparatedKeyValueParamType, CommaSeparatedTupleParamType, @@ -533,13 +534,13 @@ def test_cli_read_command_with_ldp_backend(monkeypatch): """Test the read command using the LDP backend.""" archive_content = {"foo": "bar"} - def mock_read(this, name, chunk_size=500): + def mock_read(*_, **__): """Always return the same archive.""" # pylint: disable=unused-argument yield bytes(json.dumps(archive_content), encoding="utf-8") - monkeypatch.setattr(LDPStorage, "read", mock_read) + monkeypatch.setattr(LDPDataBackend, "read", mock_read) runner = CliRunner() command = "read -b ldp --ldp-endpoint ovh-eu a547d9b3-6f2f-4913-a872-cf4efe699a66" @@ -555,13 +556,12 @@ def test_cli_read_command_with_fs_backend(fs, monkeypatch): """Test the read command using the FS backend.""" archive_content = {"foo": "bar"} - def mock_read(this, name, chunk_size): + def mock_read(*_, **__): """Always return the same archive.""" - # pylint: disable=unused-argument yield bytes(json.dumps(archive_content), encoding="utf-8") - monkeypatch.setattr(FSStorage, "read", mock_read) + monkeypatch.setattr(FSDataBackend, "read", mock_read) runner = CliRunner() result = runner.invoke(cli, "read -b fs foo".split()) @@ -589,7 +589,8 @@ def test_cli_read_command_with_es_backend(es): runner = CliRunner() es_hosts = ",".join(ES_TEST_HOSTS) es_client_options = "verify_certs=True" - command = f"""-v ERROR read -b es --es-hosts {es_hosts} --es-index {ES_TEST_INDEX} + command = f"""-v ERROR read -b es --es-hosts {es_hosts} + --es-default-index {ES_TEST_INDEX} --es-client-options {es_client_options}""" result = 
runner.invoke(cli, command.split()) assert result.exit_code == 0 @@ -648,7 +649,7 @@ def test_cli_read_command_with_es_backend_query(es): runner = CliRunner() es_hosts = ",".join(ES_TEST_HOSTS) - query = {"query": {"query": {"term": {"modulo": 0}}}} + query = {"query": {"term": {"modulo": 0}}} query_str = json.dumps(query, separators=(",", ":")) command = ( @@ -656,7 +657,7 @@ def test_cli_read_command_with_es_backend_query(es): "read " "-b es " f"--es-hosts {es_hosts} " - f"--es-index {ES_TEST_INDEX} " + f"--es-default-index {ES_TEST_INDEX} " f"--query {query_str}" ) result = runner.invoke(cli, command.split()) @@ -721,7 +722,7 @@ def test_cli_list_command_with_ldp_backend(monkeypatch): }, ] - def mock_list(this, details=False, new=False): + def mock_list(this, target=None, details=False, new=False): """Mock LDP backend list method.""" # pylint: disable=unused-argument @@ -732,16 +733,16 @@ def mock_list(this, details=False, new=False): response = response[1:] return response - monkeypatch.setattr(LDPStorage, "list", mock_list) + monkeypatch.setattr(LDPDataBackend, "list", mock_list) runner = CliRunner() - # List archives with default options + # List documents with default options result = runner.invoke(cli, ["list", "-b", "ldp", "--ldp-endpoint", "ovh-eu"]) assert result.exit_code == 0 assert "\n".join(archive_list) in result.output - # List archives with detailed output + # List documents with detailed output result = runner.invoke(cli, ["list", "-b", "ldp", "--ldp-endpoint", "ovh-eu", "-D"]) assert result.exit_code == 0 assert ( @@ -749,17 +750,17 @@ def mock_list(this, details=False, new=False): in result.output ) - # List new archives only + # List new documents only result = runner.invoke(cli, ["list", "-b", "ldp", "--ldp-endpoint", "ovh-eu", "-n"]) assert result.exit_code == 0 assert "997db3eb-b9ca-485d-810f-b530a6cef7c6" in result.output assert "5d5c4c93-04a4-42c5-9860-f51fa4044aa1" not in result.output - # Edge case: stream contains no archive - monkeypatch.setattr(LDPStorage, "list", lambda this, details, new: ()) + # Edge case: stream contains no document + monkeypatch.setattr(LDPDataBackend, "list", lambda this, target, details, new: ()) result = runner.invoke(cli, ["list", "-b", "ldp", "--ldp-endpoint", "ovh-eu"]) assert result.exit_code == 0 - assert "Configured ldp backend contains no archive" in result.output + assert "Configured ldp backend contains no document" in result.output # pylint: disable=invalid-name @@ -783,7 +784,7 @@ def test_cli_list_command_with_fs_backend(fs, monkeypatch): }, ] - def mock_list(this, details=False, new=False): + def mock_list(this, target=None, details=False, new=False): """Mock LDP backend list method.""" # pylint: disable=unused-argument @@ -794,16 +795,16 @@ def mock_list(this, details=False, new=False): response = response[1:] return response - monkeypatch.setattr(FSStorage, "list", mock_list) + monkeypatch.setattr(FSDataBackend, "list", mock_list) runner = CliRunner() - # List archives with default options + # List documents with default options result = runner.invoke(cli, ["list", "-b", "fs"]) assert result.exit_code == 0 assert "\n".join(archive_list) in result.output - # List archives with detailed output + # List documents with detailed output result = runner.invoke(cli, ["list", "-b", "fs", "-D"]) assert result.exit_code == 0 assert ( @@ -811,17 +812,17 @@ def mock_list(this, details=False, new=False): in result.output ) - # List new archives only + # List new documents only result = runner.invoke(cli, ["list", "-b", "fs", "-n"]) 
assert result.exit_code == 0 assert "file2" in result.output assert "file1" not in result.output - # Edge case: stream contains no archive - monkeypatch.setattr(FSStorage, "list", lambda this, details, new: ()) + # Edge case: stream contains no document + monkeypatch.setattr(FSDataBackend, "list", lambda this, target, details, new: ()) result = runner.invoke(cli, ["list", "-b", "fs"]) assert result.exit_code == 0 - assert "Configured fs backend contains no archive" in result.output + assert "Configured fs backend contains no document" in result.output # pylint: disable=invalid-name @@ -830,35 +831,37 @@ def test_cli_write_command_with_fs_backend(fs): fs.create_dir(str(settings.APP_DIR)) filename = Path("file1") - file_path = Path(settings.BACKENDS.STORAGE.FS.PATH) / filename + file_path = Path(FSDataBackendSettings().DEFAULT_DIRECTORY_PATH) / filename # Create a file runner = CliRunner() - result = runner.invoke(cli, "write -b fs file1".split(), input="test content") + result = runner.invoke(cli, "write -b fs -t file1".split(), input=b"test content") assert result.exit_code == 0 - with file_path.open("r", encoding=settings.LOCALE_ENCODING) as test_file: + with file_path.open("rb") as test_file: content = test_file.read() - assert "test content" in content + assert b"test content" in content # Trying to create the same file without -f should raise an error runner = CliRunner() - result = runner.invoke(cli, "write -b fs file1".split(), input="other content") + result = runner.invoke(cli, "write -b fs -t file1".split(), input=b"other content") assert result.exit_code == 1 assert "file1 already exists and overwrite is not allowed" in result.output # Try to create the same file with -f runner = CliRunner() - result = runner.invoke(cli, "write -b fs -f file1".split(), input="other content") + result = runner.invoke( + cli, "write -b fs -t file1 -f".split(), input=b"other content" + ) assert result.exit_code == 0 - with file_path.open("r", encoding=settings.LOCALE_ENCODING) as test_file: + with file_path.open("rb") as test_file: content = test_file.read() - assert "other content" in content + assert b"other content" in content def test_cli_write_command_with_es_backend(es): @@ -872,7 +875,7 @@ def test_cli_write_command_with_es_backend(es): es_hosts = ",".join(ES_TEST_HOSTS) result = runner.invoke( cli, - f"write -b es --es-hosts {es_hosts} --es-index {ES_TEST_INDEX}".split(), + f"write -b es --es-hosts {es_hosts} --es-default-index {ES_TEST_INDEX}".split(), input="\n".join(json.dumps(record) for record in records), ) assert result.exit_code == 0 @@ -911,33 +914,53 @@ def mock_uvicorn_run(_, env_file=None, **kwargs): with open(env_file, mode="r", encoding=settings.LOCALE_ENCODING) as file: env_lines = [ f"RALPH_RUNSERVER_BACKEND={settings.RUNSERVER_BACKEND}\n", - "RALPH_BACKENDS__DATABASE__ES__INDEX=foo\n", - "RALPH_BACKENDS__DATABASE__ES__CLIENT_OPTIONS__verify_certs=True\n", - "RALPH_BACKENDS__DATABASE__CLICKHOUSE__EVENT_TABLE_NAME=" - f"{settings.BACKENDS.DATABASE.CLICKHOUSE.EVENT_TABLE_NAME}\n", - "RALPH_BACKENDS__DATABASE__CLICKHOUSE__DATABASE=" - f"{settings.BACKENDS.DATABASE.CLICKHOUSE.DATABASE}\n", - "RALPH_BACKENDS__DATABASE__CLICKHOUSE__PORT=" - f"{settings.BACKENDS.DATABASE.CLICKHOUSE.PORT}\n", - "RALPH_BACKENDS__DATABASE__CLICKHOUSE__HOST=" - f"{settings.BACKENDS.DATABASE.CLICKHOUSE.HOST}\n", - "RALPH_BACKENDS__DATABASE__MONGO__COLLECTION=" - f"{settings.BACKENDS.DATABASE.MONGO.COLLECTION}\n", - "RALPH_BACKENDS__DATABASE__MONGO__DATABASE=" - 
f"{settings.BACKENDS.DATABASE.MONGO.DATABASE}\n", - "RALPH_BACKENDS__DATABASE__MONGO__CONNECTION_URI=" - f"{settings.BACKENDS.DATABASE.MONGO.CONNECTION_URI}\n", - "RALPH_BACKENDS__DATABASE__ES__OP_TYPE=" - f"{settings.BACKENDS.DATABASE.ES.OP_TYPE}\n", - "RALPH_BACKENDS__DATABASE__ES__HOSTS=" - f"{','.join(settings.BACKENDS.DATABASE.ES.HOSTS)}\n", + "RALPH_BACKENDS__LRS__ES__DEFAULT_INDEX=foo\n", + "RALPH_BACKENDS__LRS__ES__CLIENT_OPTIONS__verify_certs=True\n", + "RALPH_BACKENDS__LRS__MONGO__DEFAULT_CHUNK_SIZE=" + f"{backends_settings.BACKENDS.LRS.MONGO.DEFAULT_CHUNK_SIZE}\n", + "RALPH_BACKENDS__LRS__MONGO__DEFAULT_COLLECTION=" + f"{backends_settings.BACKENDS.LRS.MONGO.DEFAULT_COLLECTION}\n", + "RALPH_BACKENDS__LRS__MONGO__DEFAULT_DATABASE=" + f"{backends_settings.BACKENDS.LRS.MONGO.DEFAULT_DATABASE}\n", + "RALPH_BACKENDS__LRS__MONGO__CONNECTION_URI=" + f"{backends_settings.BACKENDS.LRS.MONGO.CONNECTION_URI}\n", + "RALPH_BACKENDS__LRS__FS__DEFAULT_LRS_FILE=" + f"{backends_settings.BACKENDS.LRS.FS.DEFAULT_LRS_FILE}\n", + "RALPH_BACKENDS__LRS__FS__DEFAULT_QUERY_STRING=" + f"{backends_settings.BACKENDS.LRS.FS.DEFAULT_QUERY_STRING}\n", + "RALPH_BACKENDS__LRS__FS__DEFAULT_DIRECTORY_PATH=" + f"{backends_settings.BACKENDS.LRS.FS.DEFAULT_DIRECTORY_PATH}\n", + "RALPH_BACKENDS__LRS__FS__DEFAULT_CHUNK_SIZE=" + f"{backends_settings.BACKENDS.LRS.FS.DEFAULT_CHUNK_SIZE}\n", + "RALPH_BACKENDS__LRS__ES__POINT_IN_TIME_KEEP_ALIVE=" + f"{backends_settings.BACKENDS.LRS.ES.POINT_IN_TIME_KEEP_ALIVE}\n", + "RALPH_BACKENDS__LRS__ES__HOSTS=" + f"{','.join(backends_settings.BACKENDS.LRS.ES.HOSTS)}\n", + "RALPH_BACKENDS__LRS__ES__DEFAULT_CHUNK_SIZE=" + f"{backends_settings.BACKENDS.LRS.ES.DEFAULT_CHUNK_SIZE}\n", + "RALPH_BACKENDS__LRS__ES__ALLOW_YELLOW_STATUS=" + f"{backends_settings.BACKENDS.LRS.ES.ALLOW_YELLOW_STATUS}\n", + "RALPH_BACKENDS__LRS__CLICKHOUSE__IDS_CHUNK_SIZE=" + f"{backends_settings.BACKENDS.LRS.CLICKHOUSE.IDS_CHUNK_SIZE}\n", + "RALPH_BACKENDS__LRS__CLICKHOUSE__DEFAULT_CHUNK_SIZE=" + f"{backends_settings.BACKENDS.LRS.CLICKHOUSE.DEFAULT_CHUNK_SIZE}\n", + "RALPH_BACKENDS__LRS__CLICKHOUSE__EVENT_TABLE_NAME=" + f"{backends_settings.BACKENDS.LRS.CLICKHOUSE.EVENT_TABLE_NAME}\n", + "RALPH_BACKENDS__LRS__CLICKHOUSE__DATABASE=" + f"{backends_settings.BACKENDS.LRS.CLICKHOUSE.DATABASE}\n", + "RALPH_BACKENDS__LRS__CLICKHOUSE__PORT=" + f"{backends_settings.BACKENDS.LRS.CLICKHOUSE.PORT}\n", + "RALPH_BACKENDS__LRS__CLICKHOUSE__HOST=" + f"{backends_settings.BACKENDS.LRS.CLICKHOUSE.HOST}\n", ] - assert file.readlines() == env_lines + env_lines_created = file.readlines() + assert all(line in env_lines_created for line in env_lines) monkeypatch.setattr("ralph.cli.uvicorn.run", mock_uvicorn_run) runner = CliRunner() result = runner.invoke( cli, - "runserver -b es --es-index foo --es-client-options verify_certs=True".split(), + "runserver -b es --es-default-index foo " + "--es-client-options verify_certs=True".split(), ) assert result.exit_code == 0 diff --git a/tests/test_cli_usage.py b/tests/test_cli_usage.py index baa4dc330..6101e90f3 100644 --- a/tests/test_cli_usage.py +++ b/tests/test_cli_usage.py @@ -108,46 +108,75 @@ def test_cli_read_command_usage(): assert result.exit_code == 0 assert ( + "Usage: ralph read [OPTIONS] [ARCHIVE]\n\n" + " Read an archive or records from a configured backend.\n\n" "Options:\n" - " -b, --backend [es|mongo|clickhouse|lrs|ldp|fs|swift|s3|ws]\n" + " -b, --backend [async_es|async_mongo|clickhouse|es|fs|ldp|mongo|swift|s3|lrs|" + "ws]\n" " Backend [required]\n" " ws backend: \n" " --ws-uri 
TEXT\n" + " lrs backend: \n" + " --lrs-statements-endpoint TEXT\n" + " --lrs-status-endpoint TEXT\n" + " --lrs-headers KEY=VALUE,KEY=VALUE\n" + " --lrs-password TEXT\n" + " --lrs-username TEXT\n" + " --lrs-base-url TEXT\n" " s3 backend: \n" - " --s3-endpoint-url TEXT\n" - " --s3-bucket-name TEXT\n" + " --s3-locale-encoding TEXT\n" + " --s3-default-chunk-size INTEGER\n" + " --s3-default-bucket-name TEXT\n" " --s3-default-region TEXT\n" + " --s3-endpoint-url TEXT\n" " --s3-session-token TEXT\n" " --s3-secret-access-key TEXT\n" " --s3-access-key-id TEXT\n" " swift backend: \n" - " --swift-os-identity-api-version TEXT\n" - " --swift-os-auth-url TEXT\n" - " --swift-os-project-domain-name TEXT\n" - " --swift-os-user-domain-name TEXT\n" - " --swift-os-storage-url TEXT\n" - " --swift-os-region-name TEXT\n" - " --swift-os-password TEXT\n" - " --swift-os-username TEXT\n" - " --swift-os-tenant-name TEXT\n" - " --swift-os-tenant-id TEXT\n" - " fs backend: \n" - " --fs-path TEXT\n" + " --swift-locale-encoding TEXT\n" + " --swift-default-container TEXT\n" + " --swift-user-domain-name TEXT\n" + " --swift-object-storage-url TEXT\n" + " --swift-region-name TEXT\n" + " --swift-project-domain-name TEXT\n" + " --swift-tenant-name TEXT\n" + " --swift-tenant-id TEXT\n" + " --swift-identity-api-version TEXT\n" + " --swift-password TEXT\n" + " --swift-username TEXT\n" + " --swift-auth-url TEXT\n" + " mongo backend: \n" + " --mongo-locale-encoding TEXT\n" + " --mongo-default-chunk-size INTEGER\n" + " --mongo-client-options KEY=VALUE,KEY=VALUE\n" + " --mongo-default-collection TEXT\n" + " --mongo-default-database TEXT\n" + " --mongo-connection-uri TEXT\n" " ldp backend: \n" - " --ldp-stream-id TEXT\n" " --ldp-service-name TEXT\n" + " --ldp-request-timeout TEXT\n" + " --ldp-endpoint TEXT\n" + " --ldp-default-stream-id TEXT\n" " --ldp-consumer-key TEXT\n" " --ldp-application-secret TEXT\n" " --ldp-application-key TEXT\n" - " --ldp-endpoint TEXT\n" - " lrs backend: \n" - " --lrs-statements-endpoint TEXT\n" - " --lrs-status-endpoint TEXT\n" - " --lrs-headers KEY=VALUE,KEY=VALUE\n" - " --lrs-password TEXT\n" - " --lrs-username TEXT\n" - " --lrs-base-url TEXT\n" + " fs backend: \n" + " --fs-locale-encoding TEXT\n" + " --fs-default-query-string TEXT\n" + " --fs-default-directory-path PATH\n" + " --fs-default-chunk-size INTEGER\n" + " es backend: \n" + " --es-refresh-after-write TEXT\n" + " --es-point-in-time-keep-alive TEXT\n" + " --es-locale-encoding TEXT\n" + " --es-hosts VALUE1,VALUE2,VALUE3\n" + " --es-default-index TEXT\n" + " --es-default-chunk-size INTEGER\n" + " --es-client-options KEY=VALUE,KEY=VALUE\n" + " --es-allow-yellow-status / --no-es-allow-yellow-status\n" " clickhouse backend: \n" + " --clickhouse-locale-encoding TEXT\n" + " --clickhouse-default-chunk-size INTEGER\n" " --clickhouse-client-options KEY=VALUE,KEY=VALUE\n" " --clickhouse-password TEXT\n" " --clickhouse-username TEXT\n" @@ -155,31 +184,41 @@ def test_cli_read_command_usage(): " --clickhouse-database TEXT\n" " --clickhouse-port INTEGER\n" " --clickhouse-host TEXT\n" - " mongo backend: \n" - " --mongo-client-options KEY=VALUE,KEY=VALUE\n" - " --mongo-collection TEXT\n" - " --mongo-database TEXT\n" - " --mongo-connection-uri TEXT\n" - " es backend: \n" - " --es-op-type TEXT\n" - " --es-client-options KEY=VALUE,KEY=VALUE\n" - " --es-index TEXT\n" - " --es-hosts VALUE1,VALUE2,VALUE3\n" + " async_mongo backend: \n" + " --async-mongo-locale-encoding TEXT\n" + " --async-mongo-default-chunk-size INTEGER\n" + " --async-mongo-client-options 
KEY=VALUE,KEY=VALUE\n" + " --async-mongo-default-collection TEXT\n" + " --async-mongo-default-database TEXT\n" + " --async-mongo-connection-uri TEXT\n" + " async_es backend: \n" + " --async-es-refresh-after-write TEXT\n" + " --async-es-point-in-time-keep-alive TEXT\n" + " --async-es-locale-encoding TEXT\n" + " --async-es-hosts VALUE1,VALUE2,VALUE3\n" + " --async-es-default-index TEXT\n" + " --async-es-default-chunk-size INTEGER\n" + " --async-es-client-options KEY=VALUE,KEY=VALUE\n" + " --async-es-allow-yellow-status / --no-async-es-allow-yellow-status\n" " -c, --chunk-size INTEGER Get events by chunks of size #\n" " -t, --target TEXT Endpoint from which to read events (e.g.\n" " `/statements`)\n" ' -q, --query \'{"KEY": "VALUE", "KEY": "VALUE"}\'\n' - " Query object as a JSON string (database " - "and\n" + " Query object as a JSON string (database and" + "\n" " HTTP backends ONLY)\n" + " -i, --ignore_errors BOOLEAN Ignore errors during the encoding operation." + "\n" + " [default: False]\n" + " --help Show this message and exit." ) in result.output logging.warning(result.output) result = runner.invoke(cli, ["read"]) assert result.exit_code > 0 assert ( "Error: Missing option '-b' / '--backend'. " - "Choose from:\n\tes,\n\tmongo,\n\tclickhouse,\n\tlrs,\n\tldp,\n\tfs,\n\tswift," - "\n\ts3,\n\tws\n" + "Choose from:\n\tasync_es,\n\tasync_mongo,\n\tclickhouse,\n\tes,\n\tfs,\n\tldp," + "\n\tmongo,\n\tswift,\n\ts3,\n\tlrs,\n\tws\n" ) in result.output @@ -190,45 +229,100 @@ def test_cli_list_command_usage(): assert result.exit_code == 0 assert ( + "Usage: ralph list [OPTIONS]\n\n" + " List available documents from a configured data backend.\n\n" "Options:\n" - " -b, --backend [ldp|fs|swift|s3]\n" + " -b, --backend [async_es|async_mongo|clickhouse|es|fs|ldp|mongo|swift|s3]\n" " Backend [required]\n" " s3 backend: \n" - " --s3-endpoint-url TEXT\n" - " --s3-bucket-name TEXT\n" + " --s3-locale-encoding TEXT\n" + " --s3-default-chunk-size INTEGER\n" + " --s3-default-bucket-name TEXT\n" " --s3-default-region TEXT\n" + " --s3-endpoint-url TEXT\n" " --s3-session-token TEXT\n" " --s3-secret-access-key TEXT\n" " --s3-access-key-id TEXT\n" " swift backend: \n" - " --swift-os-identity-api-version TEXT\n" - " --swift-os-auth-url TEXT\n" - " --swift-os-project-domain-name TEXT\n" - " --swift-os-user-domain-name TEXT\n" - " --swift-os-storage-url TEXT\n" - " --swift-os-region-name TEXT\n" - " --swift-os-password TEXT\n" - " --swift-os-username TEXT\n" - " --swift-os-tenant-name TEXT\n" - " --swift-os-tenant-id TEXT\n" - " fs backend: \n" - " --fs-path TEXT\n" + " --swift-locale-encoding TEXT\n" + " --swift-default-container TEXT\n" + " --swift-user-domain-name TEXT\n" + " --swift-object-storage-url TEXT\n" + " --swift-region-name TEXT\n" + " --swift-project-domain-name TEXT\n" + " --swift-tenant-name TEXT\n" + " --swift-tenant-id TEXT\n" + " --swift-identity-api-version TEXT\n" + " --swift-password TEXT\n" + " --swift-username TEXT\n" + " --swift-auth-url TEXT\n" + " mongo backend: \n" + " --mongo-locale-encoding TEXT\n" + " --mongo-default-chunk-size INTEGER\n" + " --mongo-client-options KEY=VALUE,KEY=VALUE\n" + " --mongo-default-collection TEXT\n" + " --mongo-default-database TEXT\n" + " --mongo-connection-uri TEXT\n" " ldp backend: \n" - " --ldp-stream-id TEXT\n" " --ldp-service-name TEXT\n" + " --ldp-request-timeout TEXT\n" + " --ldp-endpoint TEXT\n" + " --ldp-default-stream-id TEXT\n" " --ldp-consumer-key TEXT\n" " --ldp-application-secret TEXT\n" " --ldp-application-key TEXT\n" - " --ldp-endpoint TEXT\n" 
- " -n, --new / -a, --all List not fetched (or all) archives\n" - " -D, --details / -I, --ids Get archives detailed output (JSON)\n" + " fs backend: \n" + " --fs-locale-encoding TEXT\n" + " --fs-default-query-string TEXT\n" + " --fs-default-directory-path PATH\n" + " --fs-default-chunk-size INTEGER\n" + " es backend: \n" + " --es-refresh-after-write TEXT\n" + " --es-point-in-time-keep-alive TEXT\n" + " --es-locale-encoding TEXT\n" + " --es-hosts VALUE1,VALUE2,VALUE3\n" + " --es-default-index TEXT\n" + " --es-default-chunk-size INTEGER\n" + " --es-client-options KEY=VALUE,KEY=VALUE\n" + " --es-allow-yellow-status / --no-es-allow-yellow-status\n" + " clickhouse backend: \n" + " --clickhouse-locale-encoding TEXT\n" + " --clickhouse-default-chunk-size INTEGER\n" + " --clickhouse-client-options KEY=VALUE,KEY=VALUE\n" + " --clickhouse-password TEXT\n" + " --clickhouse-username TEXT\n" + " --clickhouse-event-table-name TEXT\n" + " --clickhouse-database TEXT\n" + " --clickhouse-port INTEGER\n" + " --clickhouse-host TEXT\n" + " async_mongo backend: \n" + " --async-mongo-locale-encoding TEXT\n" + " --async-mongo-default-chunk-size INTEGER\n" + " --async-mongo-client-options KEY=VALUE,KEY=VALUE\n" + " --async-mongo-default-collection TEXT\n" + " --async-mongo-default-database TEXT\n" + " --async-mongo-connection-uri TEXT\n" + " async_es backend: \n" + " --async-es-refresh-after-write TEXT\n" + " --async-es-point-in-time-keep-alive TEXT\n" + " --async-es-locale-encoding TEXT\n" + " --async-es-hosts VALUE1,VALUE2,VALUE3\n" + " --async-es-default-index TEXT\n" + " --async-es-default-chunk-size INTEGER\n" + " --async-es-client-options KEY=VALUE,KEY=VALUE\n" + " --async-es-allow-yellow-status / --no-async-es-allow-yellow-status\n" + " -t, --target TEXT Container to list events from\n" + " -n, --new / -a, --all List not fetched (or all) documents\n" + " -D, --details / -I, --ids Get documents detailed output (JSON)\n" + " --help Show this message and exit.\n" ) in result.output result = runner.invoke(cli, ["list"]) assert result.exit_code > 0 assert ( - "Error: Missing option '-b' / '--backend'. Choose from:\n\tldp,\n\tfs,\n\t" - "swift,\n\ts3\n" + "Error: Missing option '-b' / '--backend'. 
Choose from:\n\tasync_es,\n\t" + "async_mongo,\n\tclickhouse,\n\tes,\n\tfs,\n\tldp,\n\tmongo,\n\tswift," + "\n\ts3\n" ) in result.output @@ -240,12 +334,11 @@ def test_cli_write_command_usage(): assert result.exit_code == 0 expected_output = ( - "Usage: ralph write [OPTIONS] [ARCHIVE]\n" - "\n" - " Write an archive to a configured backend.\n" - "\n" + "Usage: ralph write [OPTIONS]\n\n" + " Write an archive to a configured backend.\n\n" "Options:\n" - " -b, --backend [es|mongo|clickhouse|ldp|fs|swift|s3|lrs]\n" + " -b, --backend [async_es|async_mongo|clickhouse|es|fs|ldp|mongo|swift|s3|lrs]" + "\n" " Backend [required]\n" " lrs backend: \n" " --lrs-statements-endpoint TEXT\n" @@ -255,33 +348,59 @@ def test_cli_write_command_usage(): " --lrs-username TEXT\n" " --lrs-base-url TEXT\n" " s3 backend: \n" - " --s3-endpoint-url TEXT\n" - " --s3-bucket-name TEXT\n" + " --s3-locale-encoding TEXT\n" + " --s3-default-chunk-size INTEGER\n" + " --s3-default-bucket-name TEXT\n" " --s3-default-region TEXT\n" + " --s3-endpoint-url TEXT\n" " --s3-session-token TEXT\n" " --s3-secret-access-key TEXT\n" " --s3-access-key-id TEXT\n" " swift backend: \n" - " --swift-os-identity-api-version TEXT\n" - " --swift-os-auth-url TEXT\n" - " --swift-os-project-domain-name TEXT\n" - " --swift-os-user-domain-name TEXT\n" - " --swift-os-storage-url TEXT\n" - " --swift-os-region-name TEXT\n" - " --swift-os-password TEXT\n" - " --swift-os-username TEXT\n" - " --swift-os-tenant-name TEXT\n" - " --swift-os-tenant-id TEXT\n" - " fs backend: \n" - " --fs-path TEXT\n" + " --swift-locale-encoding TEXT\n" + " --swift-default-container TEXT\n" + " --swift-user-domain-name TEXT\n" + " --swift-object-storage-url TEXT\n" + " --swift-region-name TEXT\n" + " --swift-project-domain-name TEXT\n" + " --swift-tenant-name TEXT\n" + " --swift-tenant-id TEXT\n" + " --swift-identity-api-version TEXT\n" + " --swift-password TEXT\n" + " --swift-username TEXT\n" + " --swift-auth-url TEXT\n" + " mongo backend: \n" + " --mongo-locale-encoding TEXT\n" + " --mongo-default-chunk-size INTEGER\n" + " --mongo-client-options KEY=VALUE,KEY=VALUE\n" + " --mongo-default-collection TEXT\n" + " --mongo-default-database TEXT\n" + " --mongo-connection-uri TEXT\n" " ldp backend: \n" - " --ldp-stream-id TEXT\n" " --ldp-service-name TEXT\n" + " --ldp-request-timeout TEXT\n" + " --ldp-endpoint TEXT\n" + " --ldp-default-stream-id TEXT\n" " --ldp-consumer-key TEXT\n" " --ldp-application-secret TEXT\n" " --ldp-application-key TEXT\n" - " --ldp-endpoint TEXT\n" + " fs backend: \n" + " --fs-locale-encoding TEXT\n" + " --fs-default-query-string TEXT\n" + " --fs-default-directory-path PATH\n" + " --fs-default-chunk-size INTEGER\n" + " es backend: \n" + " --es-refresh-after-write TEXT\n" + " --es-point-in-time-keep-alive TEXT\n" + " --es-locale-encoding TEXT\n" + " --es-hosts VALUE1,VALUE2,VALUE3\n" + " --es-default-index TEXT\n" + " --es-default-chunk-size INTEGER\n" + " --es-client-options KEY=VALUE,KEY=VALUE\n" + " --es-allow-yellow-status / --no-es-allow-yellow-status\n" " clickhouse backend: \n" + " --clickhouse-locale-encoding TEXT\n" + " --clickhouse-default-chunk-size INTEGER\n" " --clickhouse-client-options KEY=VALUE,KEY=VALUE\n" " --clickhouse-password TEXT\n" " --clickhouse-username TEXT\n" @@ -289,30 +408,35 @@ def test_cli_write_command_usage(): " --clickhouse-database TEXT\n" " --clickhouse-port INTEGER\n" " --clickhouse-host TEXT\n" - " mongo backend: \n" - " --mongo-client-options KEY=VALUE,KEY=VALUE\n" - " --mongo-collection TEXT\n" - " --mongo-database 
TEXT\n" - " --mongo-connection-uri TEXT\n" - " es backend: \n" - " --es-op-type TEXT\n" - " --es-client-options KEY=VALUE,KEY=VALUE\n" - " --es-index TEXT\n" - " --es-hosts VALUE1,VALUE2,VALUE3\n" + " async_mongo backend: \n" + " --async-mongo-locale-encoding TEXT\n" + " --async-mongo-default-chunk-size INTEGER\n" + " --async-mongo-client-options KEY=VALUE,KEY=VALUE\n" + " --async-mongo-default-collection TEXT\n" + " --async-mongo-default-database TEXT\n" + " --async-mongo-connection-uri TEXT\n" + " async_es backend: \n" + " --async-es-refresh-after-write TEXT\n" + " --async-es-point-in-time-keep-alive TEXT\n" + " --async-es-locale-encoding TEXT\n" + " --async-es-hosts VALUE1,VALUE2,VALUE3\n" + " --async-es-default-index TEXT\n" + " --async-es-default-chunk-size INTEGER\n" + " --async-es-client-options KEY=VALUE,KEY=VALUE\n" + " --async-es-allow-yellow-status / --no-async-es-allow-yellow-status\n" " -c, --chunk-size INTEGER Get events by chunks of size #\n" " -f, --force Overwrite existing archives or records\n" - " -I, --ignore-errors Continue writing regardless of raised " - "errors\n" + " -I, --ignore-errors Continue writing regardless of raised errors" + "\n" " -s, --simultaneous With HTTP backend, POST all chunks\n" " simultaneously (instead of sequentially)\n" " -m, --max-num-simultaneous INTEGER\n" - " The maximum number of chunks to send at " - "once,\n" - " when using `--simultaneous`. Use `-1` to " - "not\n" + " The maximum number of chunks to send at once" + ",\n" + " when using `--simultaneous`. Use `-1` to not" + "\n" " set a limit.\n" - " -t, --target TEXT Endpoint in which to write events (e.g.\n" - " `statements`)\n" + " -t, --target TEXT The target container to write into\n" " --help Show this message and exit.\n" ) assert expected_output in result.output @@ -320,8 +444,8 @@ def test_cli_write_command_usage(): result = runner.invoke(cli, ["write"]) assert result.exit_code > 0 assert ( - "Missing option '-b' / '--backend'. Choose from:\n\tes,\n\tmongo," - "\n\tclickhouse,\n\tldp,\n\tfs,\n\tswift,\n\ts3,\n\tlrs\n" + "Missing option '-b' / '--backend'. 
Choose from:\n\tasync_es,\n\tasync_mongo,\n" + "\tclickhouse,\n\tes,\n\tfs,\n\tldp,\n\tmongo,\n\tswift,\n\ts3,\n\tlrs\n" ) in result.output @@ -331,10 +455,38 @@ def test_cli_runserver_command_usage(): result = runner.invoke(cli, ["runserver", "--help"]) expected_output = ( + "Usage: ralph runserver [OPTIONS]\n\n" + " Run the API server for the development environment.\n\n" + " Starts uvicorn programmatically for convenience and documentation.\n\n" "Options:\n" - " -b, --backend [es|mongo|clickhouse]\n" + " -b, --backend [async_es|async_mongo|clickhouse|es|fs|mongo]\n" " Backend [required]\n" + " mongo backend: \n" + " --mongo-locale-encoding TEXT\n" + " --mongo-default-chunk-size INTEGER\n" + " --mongo-client-options KEY=VALUE,KEY=VALUE\n" + " --mongo-default-collection TEXT\n" + " --mongo-default-database TEXT\n" + " --mongo-connection-uri TEXT\n" + " fs backend: \n" + " --fs-default-lrs-file TEXT\n" + " --fs-locale-encoding TEXT\n" + " --fs-default-query-string TEXT\n" + " --fs-default-directory-path PATH\n" + " --fs-default-chunk-size INTEGER\n" + " es backend: \n" + " --es-refresh-after-write TEXT\n" + " --es-point-in-time-keep-alive TEXT\n" + " --es-locale-encoding TEXT\n" + " --es-hosts VALUE1,VALUE2,VALUE3\n" + " --es-default-index TEXT\n" + " --es-default-chunk-size INTEGER\n" + " --es-client-options KEY=VALUE,KEY=VALUE\n" + " --es-allow-yellow-status / --no-es-allow-yellow-status\n" " clickhouse backend: \n" + " --clickhouse-ids-chunk-size INTEGER\n" + " --clickhouse-locale-encoding TEXT\n" + " --clickhouse-default-chunk-size INTEGER\n" " --clickhouse-client-options KEY=VALUE,KEY=VALUE\n" " --clickhouse-password TEXT\n" " --clickhouse-username TEXT\n" @@ -342,19 +494,25 @@ def test_cli_runserver_command_usage(): " --clickhouse-database TEXT\n" " --clickhouse-port INTEGER\n" " --clickhouse-host TEXT\n" - " mongo backend: \n" - " --mongo-client-options KEY=VALUE,KEY=VALUE\n" - " --mongo-collection TEXT\n" - " --mongo-database TEXT\n" - " --mongo-connection-uri TEXT\n" - " es backend: \n" - " --es-op-type TEXT\n" - " --es-client-options KEY=VALUE,KEY=VALUE\n" - " --es-index TEXT\n" - " --es-hosts VALUE1,VALUE2,VALUE3\n" + " async_mongo backend: \n" + " --async-mongo-locale-encoding TEXT\n" + " --async-mongo-default-chunk-size INTEGER\n" + " --async-mongo-client-options KEY=VALUE,KEY=VALUE\n" + " --async-mongo-default-collection TEXT\n" + " --async-mongo-default-database TEXT\n" + " --async-mongo-connection-uri TEXT\n" + " async_es backend: \n" + " --async-es-refresh-after-write TEXT\n" + " --async-es-point-in-time-keep-alive TEXT\n" + " --async-es-locale-encoding TEXT\n" + " --async-es-hosts VALUE1,VALUE2,VALUE3\n" + " --async-es-default-index TEXT\n" + " --async-es-default-chunk-size INTEGER\n" + " --async-es-client-options KEY=VALUE,KEY=VALUE\n" + " --async-es-allow-yellow-status / --no-async-es-allow-yellow-status\n" " -h, --host TEXT LRS server host name\n" " -p, --port INTEGER LRS server port\n" + " --help Show this message and exit.\n" ) - assert result.exit_code == 0 assert expected_output in result.output diff --git a/tests/test_conf.py b/tests/test_conf.py index 346fb1564..670bf5ba6 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -1,15 +1,12 @@ """Tests for Ralph's configuration loading.""" from importlib import reload -from inspect import signature -from pathlib import PosixPath import pytest -from pydantic import ValidationError from ralph import conf +from ralph.backends.conf import BackendSettings from ralph.conf import CommaSeparatedTuple, Settings, 
settings -from ralph.utils import import_string def test_conf_settings_field_value_priority(fs, monkeypatch): @@ -50,8 +47,8 @@ def test_conf_settings_field_value_priority(fs, monkeypatch): def test_conf_comma_separated_list_with_valid_values(value, expected, monkeypatch): """Test the CommaSeparatedTuple pydantic data type with valid values.""" assert next(CommaSeparatedTuple.__get_validators__())(value) == expected - monkeypatch.setenv("RALPH_BACKENDS__DATABASE__ES__HOSTS", "".join(value)) - assert Settings().BACKENDS.DATABASE.ES.HOSTS == expected + monkeypatch.setenv("RALPH_BACKENDS__DATA__ES__HOSTS", "".join(value)) + assert BackendSettings().BACKENDS.DATA.ES.HOSTS == expected @pytest.mark.parametrize("value", [{}, None]) @@ -61,116 +58,6 @@ def test_conf_comma_separated_list_with_invalid_values(value): next(CommaSeparatedTuple.__get_validators__())(value) -@pytest.mark.parametrize( - "ca_certs,verify_certs,expected", - [ - ("/path", "True", {"ca_certs": PosixPath("/path"), "verify_certs": True}), - ("/path2", "f", {"ca_certs": PosixPath("/path2"), "verify_certs": False}), - (None, None, {"ca_certs": None, "verify_certs": None}), - ], -) -def test_conf_es_client_options_with_valid_values( - ca_certs, verify_certs, expected, monkeypatch -): - """Test the ESClientOptions pydantic data type with valid values.""" - # Using None here as in "not set by user" - if ca_certs is not None: - monkeypatch.setenv( - "RALPH_BACKENDS__DATABASE__ES__CLIENT_OPTIONS__ca_certs", f"{ca_certs}" - ) - # Using None here as in "not set by user" - if verify_certs is not None: - monkeypatch.setenv( - "RALPH_BACKENDS__DATABASE__ES__CLIENT_OPTIONS__verify_certs", - f"{verify_certs}", - ) - assert Settings().BACKENDS.DATABASE.ES.CLIENT_OPTIONS.dict() == expected - - -@pytest.mark.parametrize( - "ca_certs,verify_certs", - [ - ("/path", 3), - ("/path", None), - ], -) -def test_conf_es_client_options_with_invalid_values( - ca_certs, verify_certs, monkeypatch -): - """Test the ESClientOptions pydantic data type with invalid values.""" - monkeypatch.setenv( - "RALPH_BACKENDS__DATABASE__ES__CLIENT_OPTIONS__ca_certs", f"{ca_certs}" - ) - monkeypatch.setenv( - "RALPH_BACKENDS__DATABASE__ES__CLIENT_OPTIONS__verify_certs", - f"{verify_certs}", - ) - with pytest.raises(ValidationError, match="1 validation error for"): - Settings().BACKENDS.DATABASE.ES.CLIENT_OPTIONS.dict() - - -@pytest.mark.parametrize( - "document_class,tz_aware,expected", - [ - ("dict", "True", {"document_class": "dict", "tz_aware": True}), - ("str", "f", {"document_class": "str", "tz_aware": False}), - (None, None, {"document_class": None, "tz_aware": None}), - ], -) -def test_conf_mongo_client_options_with_valid_values( - document_class, tz_aware, expected, monkeypatch -): - """Test the MongoClientOptions pydantic data type with valid values.""" - # Using None here as in "not set by user" - if document_class is not None: - monkeypatch.setenv( - "RALPH_BACKENDS__DATABASE__MONGO__CLIENT_OPTIONS__document_class", - f"{document_class}", - ) - # Using None here as in "not set by user" - if tz_aware is not None: - monkeypatch.setenv( - "RALPH_BACKENDS__DATABASE__MONGO__CLIENT_OPTIONS__tz_aware", - f"{tz_aware}", - ) - assert Settings().BACKENDS.DATABASE.MONGO.CLIENT_OPTIONS.dict() == expected - - -@pytest.mark.parametrize( - "document_class,tz_aware", - [ - ("dict", 3), - ("str", None), - ], -) -def test_conf_mongo_client_options_with_invalid_values( - document_class, tz_aware, monkeypatch -): - """Test the MongoClientOptions pydantic data type with invalid 
values.""" - monkeypatch.setenv( - "RALPH_BACKENDS__DATABASE__MONGO__CLIENT_OPTIONS__document_class", - f"{document_class}", - ) - monkeypatch.setenv( - "RALPH_BACKENDS__DATABASE__MONGO__CLIENT_OPTIONS__tz_aware", - f"{tz_aware}", - ) - with pytest.raises(ValidationError, match="1 validation error for"): - Settings().BACKENDS.DATABASE.MONGO.CLIENT_OPTIONS.dict() - - -def test_conf_settings_should_define_all_backends_options(): - """Test that Settings model defines all backends options.""" - for _, backends in settings.BACKENDS: - for _, backend in backends: - # pylint: disable=protected-access - backend_class = import_string(backend._class_path) - for parameter in signature(backend_class.__init__).parameters.values(): - if parameter.name == "self": - continue - assert hasattr(backend, parameter.name.upper()) - - def test_conf_core_settings_should_impact_settings_defaults(monkeypatch): """Test that core settings update application settings values.""" monkeypatch.setenv("RALPH_APP_DIR", "/foo") @@ -186,4 +73,3 @@ def test_conf_core_settings_should_impact_settings_defaults(monkeypatch): # Defaults. assert str(conf.settings.AUTH_FILE) == "/foo/auth.json" - assert conf.settings.BACKENDS.STORAGE.FS.PATH == "/foo/archives" diff --git a/tests/test_logger.py b/tests/test_logger.py index ac29d5d9a..17625e6df 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -42,12 +42,12 @@ def test_logger_exists(fs, monkeypatch): runner = CliRunner() result = runner.invoke( cli, - ["write", "-b", "fs", "test_file"], + ["write", "-b", "fs", "-t", "test_file"], input="test input", ) assert result.exit_code == 0 - assert "Writing archive test_file to the configured fs backend" in result.output + assert "Writing to target test_file for the configured fs backend" in result.output assert "Backend parameters:" in result.output diff --git a/tests/test_utils.py b/tests/test_utils.py index 79e279075..654eba3c1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,10 +1,14 @@ """Tests for Ralph utils.""" +from abc import ABC +from types import ModuleType + import pytest from pydantic import BaseModel from ralph import utils as ralph_utils -from ralph.conf import InstantiableSettingsItem, settings +from ralph.backends.conf import backends_settings +from ralph.conf import InstantiableSettingsItem def test_utils_import_string(): @@ -23,19 +27,20 @@ def test_utils_import_string(): def test_utils_get_backend_type(): """Test get_backend_type utility.""" + backend_types = [backend_type[1] for backend_type in backends_settings.BACKENDS] assert ( - ralph_utils.get_backend_type(settings.BACKENDS, "es") - == settings.BACKENDS.DATABASE + ralph_utils.get_backend_type(backend_types, "es") + == backends_settings.BACKENDS.DATA ) assert ( - ralph_utils.get_backend_type(settings.BACKENDS, "ldp") - == settings.BACKENDS.STORAGE + ralph_utils.get_backend_type(backend_types, "lrs") + == backends_settings.BACKENDS.HTTP ) assert ( - ralph_utils.get_backend_type(settings.BACKENDS, "ws") - == settings.BACKENDS.STREAM + ralph_utils.get_backend_type(backend_types, "ws") + == backends_settings.BACKENDS.STREAM ) - assert ralph_utils.get_backend_type(settings.BACKENDS, "foo") is None + assert ralph_utils.get_backend_type(backend_types, "foo") is None @pytest.mark.parametrize( @@ -45,32 +50,54 @@ def test_utils_get_backend_type(): ({}, {}), # Options not matching the backend name are ignored. ({"foo": "bar", "not_dummy_foo": "baz"}, {}), - # One option matches the backend name and overrides the default. 
-        ({"dummy_foo": "bar", "not_dummy_foo": "baz"}, {"foo": "bar"}),
     ],
 )
-def test_utils_get_backend_instance(options, expected):
+def test_utils_get_backend_instance(monkeypatch, options, expected):
     """Test get_backend_instance utility should return the expected result."""
 
-    class DummyBackendSettings(InstantiableSettingsItem):
+    class DummyTestBackendSettings(InstantiableSettingsItem):
         """Represents a dummy backend setting."""
 
-        foo: str = "foo"  # pylint: disable=disallowed-name
+        FOO: str = "FOO"  # pylint: disable=disallowed-name
 
         def get_instance(self, **init_parameters):  # pylint: disable=no-self-use
             """Returns the init_parameters."""
             return init_parameters
 
-    class TestBackendType(BaseModel):
-        """A backend type including the DummyBackendSettings."""
+    class DummyTestBackend(ABC):
+        """Represents a dummy backend instance."""
+
+        type = "test"
+        name = "dummy"
+
+        def __init__(self, *args, **kwargs):  # pylint: disable=unused-argument
+            return
+
+        def __call__(self, *args, **kwargs):  # pylint: disable=unused-argument
+            return {}
+
+    def mock_import_module(*args, **kwargs):  # pylint: disable=unused-argument
+        """Mock `import_module` to return a module exposing the dummy backend."""
+        test_module = ModuleType(name="ralph.backends.test.dummy")
+
+        test_module.DummyTestBackendSettings = DummyTestBackendSettings
+        test_module.DummyTestBackend = DummyTestBackend
+
+        return test_module
+
+    class TestBackendSettings(BaseModel):
+        """A backend type including the DummyTestBackendSettings."""
 
-        DUMMY: DummyBackendSettings = DummyBackendSettings()
+        DUMMY: DummyTestBackendSettings = (
+            DummyTestBackendSettings()
+        )
 
+    monkeypatch.setattr(ralph_utils, "import_module", mock_import_module)
     backend_instance = ralph_utils.get_backend_instance(
-        TestBackendType(), "dummy", options
+        TestBackendSettings(), "dummy", options
     )
 
-    assert isinstance(backend_instance, dict)
-    assert backend_instance == expected
+    assert isinstance(backend_instance, DummyTestBackend)
+    assert backend_instance() == expected
 
 
 @pytest.mark.parametrize("path,value", [(["foo", "bar"], "bar_value")])


From 61422a18dafc4ee92f165fbac98740253b159321 Mon Sep 17 00:00:00 2001
From: Wilfried BARADAT
Date: Wed, 6 Sep 2023 14:10:28 +0200
Subject: [PATCH 26/65] =?UTF-8?q?=E2=9C=A8(cli)=20make=20async=20backends?=
 =?UTF-8?q?=20usable=20in=20cli?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

With the addition of the new asynchronous backends, it is useful to be
able to use them from the CLI as well.
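The bridging technique is sketched below as a minimal standalone example,
not the patched code itself (the actual helpers this commit adds are
`iter_over_async` and `execute_async` in `ralph.utils`): a synchronous
generator drives the event loop one step at a time, so the CLI can consume
an asynchronous backend's `read()` output with a plain `for` loop.

    import asyncio

    async def produce():
        # Stand-in for an asynchronous backend's read() generator.
        for record in ({"id": 1}, {"id": 2}):
            yield record

    def iter_over_async(agenerator):
        # Consume an async generator from synchronous code by running
        # the event loop just long enough to fetch each item.
        loop = asyncio.new_event_loop()
        aiterator = agenerator.__aiter__()
        while True:
            try:
                yield loop.run_until_complete(aiterator.__anext__())
            except StopAsyncIteration:
                break

    for record in iter_over_async(produce()):
        print(record)

Writes go through the symmetric wrapper: when the backend's `write` is a
coroutine function, it is run to completion on the loop before the CLI
returns.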
--- src/ralph/backends/data/async_es.py | 4 ++-- src/ralph/backends/data/async_mongo.py | 8 +++++-- src/ralph/backends/data/base.py | 16 ++++++++++++- src/ralph/cli.py | 21 ++++++++++++----- src/ralph/utils.py | 31 ++++++++++++++++++++++++++ 5 files changed, 69 insertions(+), 11 deletions(-) diff --git a/src/ralph/backends/data/async_es.py b/src/ralph/backends/data/async_es.py index 3b39d7fcd..f94d2a64c 100644 --- a/src/ralph/backends/data/async_es.py +++ b/src/ralph/backends/data/async_es.py @@ -12,7 +12,7 @@ BaseAsyncDataBackend, BaseOperationType, DataBackendStatus, - enforce_query_checks, + async_enforce_query_checks, ) from ralph.backends.data.es import ESDataBackend, ESDataBackendSettings, ESQuery from ralph.exceptions import BackendException, BackendParameterException @@ -109,7 +109,7 @@ async def list( for index in indices: yield index - @enforce_query_checks + @async_enforce_query_checks async def read( self, *, diff --git a/src/ralph/backends/data/async_mongo.py b/src/ralph/backends/data/async_mongo.py index 7b332f311..8d2d99907 100644 --- a/src/ralph/backends/data/async_mongo.py +++ b/src/ralph/backends/data/async_mongo.py @@ -20,7 +20,11 @@ from ralph.exceptions import BackendException, BackendParameterException from ralph.utils import parse_bytes_to_dict -from ..data.base import BaseAsyncDataBackend, DataBackendStatus, enforce_query_checks +from ..data.base import ( + BaseAsyncDataBackend, + DataBackendStatus, + async_enforce_query_checks, +) logger = logging.getLogger(__name__) @@ -110,7 +114,7 @@ async def list( logger.error(msg, error) raise BackendException(msg % error) from error - @enforce_query_checks + @async_enforce_query_checks async def read( self, *, diff --git a/src/ralph/backends/data/base.py b/src/ralph/backends/data/base.py index da20328f3..641e08b53 100644 --- a/src/ralph/backends/data/base.py +++ b/src/ralph/backends/data/base.py @@ -236,6 +236,20 @@ def close(self) -> None: """ +def async_enforce_query_checks(method): + """Enforce query argument type checking for methods using it.""" + + @functools.wraps(method) + async def wrapper(*args, **kwargs): + """Wrap method execution.""" + query = kwargs.pop("query", None) + self_ = args[0] + async for result in method(*args, query=self_.validate_query(query), **kwargs): + yield result + + return wrapper + + class BaseAsyncDataBackend(ABC): """Base async data backend interface.""" @@ -313,7 +327,7 @@ async def list( """ @abstractmethod - @enforce_query_checks + @async_enforce_query_checks async def read( self, *, diff --git a/src/ralph/cli.py b/src/ralph/cli.py index b3272a021..1281330be 100644 --- a/src/ralph/cli.py +++ b/src/ralph/cli.py @@ -4,7 +4,7 @@ import logging import re import sys -from inspect import isclass, isasyncgen +from inspect import isasyncgen, isclass, iscoroutinefunction from pathlib import Path from tempfile import NamedTemporaryFile from typing import List @@ -38,11 +38,12 @@ from ralph.models.selector import ModelSelector from ralph.models.validator import Validator from ralph.utils import ( + execute_async, get_backend_instance, get_backend_type, get_root_logger, import_string, - iter_over_async + iter_over_async, ) # cli module logger @@ -625,15 +626,18 @@ def read( backend = get_backend_instance(backend_type, backend, options) if backend_type == backends_settings.BACKENDS.DATA: - for statement in backend.read( + statements = backend.read( query=query, target=target, chunk_size=chunk_size, raw_output=True, ignore_errors=ignore_errors, - ): + ) + statements = ( + iter_over_async(statements) 
if isasyncgen(statements) else statements
+        )
+        for statement in statements:
             click.echo(statement)
-
     elif backend_type == backends_settings.BACKENDS.STREAM:
         backend.stream(sys.stdout.buffer)
     elif backend_type == backends_settings.BACKENDS.HTTP:
@@ -729,7 +733,12 @@ def write(
     backend = get_backend_instance(backend_type, backend, options)
 
     if backend_type == backends_settings.BACKENDS.DATA:
-        backend.write(
+        writer = (
+            execute_async(backend.write)
+            if iscoroutinefunction(backend.write)
+            else backend.write
+        )
+        writer(
             data=sys.stdin.buffer,
             target=target,
             chunk_size=chunk_size,
diff --git a/src/ralph/utils.py b/src/ralph/utils.py
index e0f3f9eec..07ee27a6b 100644
--- a/src/ralph/utils.py
+++ b/src/ralph/utils.py
@@ -235,3 +235,34 @@ def read_raw(
             continue
         logger_class.error(msg, error)
         raise BackendException(msg % error) from error
+
+
+def iter_over_async(agenerator) -> Iterable:
+    """Iterate synchronously over an asynchronous generator."""
+    loop = asyncio.get_event_loop()
+    aiterator = aiter(agenerator)
+
+    async def get_next():
+        """Get the next element from the async iterator."""
+        try:
+            obj = await anext(aiterator)
+            return False, obj
+        except StopAsyncIteration:
+            return True, None
+
+    while True:
+        done, obj = loop.run_until_complete(get_next())
+        if done:
+            break
+        yield obj
+
+
+def execute_async(method):
+    """Run asynchronous method in a synchronous context."""
+
+    def wrapper(*args, **kwargs):
+        """Wrap method execution."""
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(method(*args, **kwargs))
+
+    return wrapper


From c48a2059e7916083d2d2114c8b925994b0b36526 Mon Sep 17 00:00:00 2001
From: Wilfried BARADAT
Date: Wed, 6 Sep 2023 15:11:47 +0200
Subject: [PATCH 27/65] =?UTF-8?q?=F0=9F=90=9B(backends)=20fix=20clickhouse?=
 =?UTF-8?q?=20client=20options=20int=20type?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Adding a default value for the ClickHouse client option
`allow_experimental_object_type` highlights a pydantic validation error
with the type `Literal[0,1]`. Switching to `conint`.
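To illustrate the difference with a throwaway pydantic (v1) model rather
than the actual Ralph settings class: `Literal[0, 1]` accepts only the
exact integers 0 and 1, so string values such as those produced by
environment-variable parsing fail validation, while `conint(ge=0, le=1)`
coerces them and enforces the same range.

    from pydantic import BaseModel, ValidationError, conint

    class Options(BaseModel):
        # Same constraint as the patched ClickHouse client option.
        allow_experimental_object_type: conint(ge=0, le=1) = 1

    print(Options(allow_experimental_object_type="1"))  # coerced to 1
    try:
        Options(allow_experimental_object_type=2)  # out of range
    except ValidationError as error:
        print(error)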
---
 src/ralph/backends/data/clickhouse.py | 16 +++-------------
 1 file changed, 3 insertions(+), 13 deletions(-)

diff --git a/src/ralph/backends/data/clickhouse.py b/src/ralph/backends/data/clickhouse.py
index 5b5c09a4c..66c659947 100755
--- a/src/ralph/backends/data/clickhouse.py
+++ b/src/ralph/backends/data/clickhouse.py
@@ -5,22 +5,12 @@
 from datetime import datetime
 from io import IOBase
 from itertools import chain
-from typing import (
-    Any,
-    Dict,
-    Generator,
-    Iterable,
-    Iterator,
-    List,
-    Literal,
-    NamedTuple,
-    Union,
-)
+from typing import Any, Dict, Generator, Iterable, Iterator, List, NamedTuple, Union
 from uuid import UUID, uuid4
 
 import clickhouse_connect
 from clickhouse_connect.driver.exceptions import ClickHouseError
-from pydantic import BaseModel, Json, ValidationError
+from pydantic import BaseModel, Json, ValidationError, conint
 
 from ralph.backends.data.base import (
     BaseDataBackend,
@@ -47,7 +37,7 @@ class ClickHouseClientOptions(ClientOptions):
     """Pydantic model for `clickhouse` client options."""
 
     date_time_input_format: str = "best_effort"
-    allow_experimental_object_type: Literal[0, 1] = 1
+    allow_experimental_object_type: conint(ge=0, le=1) = 1
 
 
 class InsertTuple(NamedTuple):

From 4797c353fef971483dcb8400f0d9f011be5312a5 Mon Sep 17 00:00:00 2001
From: Wilfried BARADAT
Date: Thu, 7 Sep 2023 11:44:43 +0200
Subject: [PATCH 28/65] =?UTF-8?q?=F0=9F=8F=97=EF=B8=8F(backends)=20integra?=
 =?UTF-8?q?te=20unified=20backends=20in=20API?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

With the addition of unified backends and changes to the configuration
files, the API router needs some changes to be able to get the backend
instances.
---
 src/ralph/api/routers/health.py        |  14 +--
 src/ralph/api/routers/statements.py    |  71 +++++++------
 src/ralph/backends/data/base.py        |   2 +
 src/ralph/backends/data/es.py          |   5 +-
 src/ralph/backends/http/base.py        |   1 +
 src/ralph/backends/lrs/async_es.py     |   4 +-
 src/ralph/backends/lrs/async_mongo.py  |   4 +-
 src/ralph/backends/lrs/base.py         |  36 +++----
 src/ralph/backends/lrs/clickhouse.py   |  39 +++++---
 src/ralph/backends/lrs/es.py           |  34 +++++--
 src/ralph/backends/lrs/fs.py           |  41 ++++++---
 src/ralph/backends/lrs/mongo.py        |  39 +++++---
 src/ralph/backends/stream/base.py      |   1 +
 src/ralph/utils.py                     |   2 +-
 tests/api/test_health.py               |  12 +--
 tests/api/test_statements.py           |  16 ++--
 tests/api/test_statements_get.py       |  39 ++++++---
 tests/api/test_statements_post.py      |  60 ++++++------
 tests/api/test_statements_put.py       |  60 ++++++------
 tests/backends/lrs/test_async_es.py    |  39 ++++-----
 tests/backends/lrs/test_async_mongo.py |  34 ++++----
 tests/backends/lrs/test_clickhouse.py  | 114 ++++++++++++++++-------
 tests/backends/lrs/test_es.py          |  39 ++++-----
 tests/backends/lrs/test_fs.py          |   4 +-
 tests/backends/lrs/test_mongo.py       |  44 +++++-----
 tests/fixtures/backends.py             |  24 +++---
 26 files changed, 443 insertions(+), 335 deletions(-)

diff --git a/src/ralph/api/routers/health.py b/src/ralph/api/routers/health.py
index a7f1823cd..bb51551c8 100644
--- a/src/ralph/api/routers/health.py
+++ b/src/ralph/api/routers/health.py
@@ -1,20 +1,24 @@
 """API routes related to application health checking."""
 
 import logging
+from typing import Union
 
 from fastapi import APIRouter, status
 from fastapi.responses import JSONResponse
 
-from ralph.backends.database.base import BaseDatabase
+from ralph.backends.conf import backends_settings
+from ralph.backends.lrs.base import BaseAsyncLRSBackend, BaseLRSBackend
 from ralph.conf import settings
+from ralph.utils import await_if_coroutine, get_backend_instance
 
 logger = logging.getLogger(__name__)
 
 router = APIRouter()
 
-DATABASE_CLIENT: BaseDatabase = getattr(
-    settings.BACKENDS.DATABASE, settings.RUNSERVER_BACKEND.upper()
-).get_instance()
+BACKEND_CLIENT: Union[BaseLRSBackend, BaseAsyncLRSBackend] = get_backend_instance(
+    backend_type=backends_settings.BACKENDS.LRS,
+    backend_name=settings.RUNSERVER_BACKEND,
+)
 
 
 @router.get("/__lbheartbeat__")
@@ -32,7 +36,7 @@ async def heartbeat():
 
     Returns a 200 if all checks are successful.
    """
-    content = {"database": DATABASE_CLIENT.status().value}
+    content = {"database": (await await_if_coroutine(BACKEND_CLIENT.status())).value}
 
     status_code = (
         status.HTTP_200_OK
         if all(v == "ok" for v in content.values())
diff --git a/src/ralph/api/routers/statements.py b/src/ralph/api/routers/statements.py
index 9c1d31ad9..2fc82e3e4 100644
--- a/src/ralph/api/routers/statements.py
+++ b/src/ralph/api/routers/statements.py
@@ -26,9 +26,10 @@
 from ralph.api.auth.user import AuthenticatedUser
 from ralph.api.forwarding import forward_xapi_statements, get_active_xapi_forwardings
 from ralph.api.models import ErrorDetail, LaxStatement
-from ralph.backends.database.base import (
+from ralph.backends.conf import backends_settings
+from ralph.backends.lrs.base import (
     AgentParameters,
-    BaseDatabase,
+    BaseLRSBackend,
     RalphStatementsQuery,
 )
 from ralph.conf import settings
@@ -41,7 +42,7 @@
     BaseXapiAgentWithOpenId,
 )
 from ralph.models.xapi.base.common import IRI
-from ralph.utils import now, statements_are_equivalent
+from ralph.utils import await_if_coroutine, get_backend_instance, now, statements_are_equivalent
 
 logger = logging.getLogger(__name__)
 
@@ -51,9 +52,10 @@
 )
 
 
-DATABASE_CLIENT: BaseDatabase = getattr(
-    settings.BACKENDS.DATABASE, settings.RUNSERVER_BACKEND.upper()
-).get_instance()
+BACKEND_CLIENT: BaseLRSBackend = get_backend_instance(
+    backend_type=backends_settings.BACKENDS.LRS,
+    backend_name=settings.RUNSERVER_BACKEND,
+)
 
 POST_PUT_RESPONSES = {
     400: {
@@ -343,8 +345,10 @@ async def get(
 
     # Query Database
     try:
-        query_result = DATABASE_CLIENT.query_statements(
-            RalphStatementsQuery.construct(**{**query_params, "limit": limit})
+        query_result = await await_if_coroutine(
+            BACKEND_CLIENT.query_statements(
+                RalphStatementsQuery.construct(**{**query_params, "limit": limit})
+            )
         )
     except BackendException as error:
         raise HTTPException(
@@ -388,7 +392,7 @@ async def get(
 
 @router.put("/", responses=POST_PUT_RESPONSES, status_code=status.HTTP_204_NO_CONTENT)
 @router.put("", responses=POST_PUT_RESPONSES, status_code=status.HTTP_204_NO_CONTENT)
-# pylint: disable=unused-argument
+# pylint: disable=unused-argument, too-many-branches
 async def put(
     current_user: Annotated[AuthenticatedUser, Depends(get_authenticated_user)],
     statement: LaxStatement,
@@ -424,19 +428,26 @@ async def put(
     _enrich_statement_with_authority(statement_as_dict, current_user)
 
     try:
-        existing_statement = DATABASE_CLIENT.query_statements_by_ids([statement_id])
+        if isinstance(BACKEND_CLIENT, BaseLRSBackend):
+            existing_statements = list(
+                BACKEND_CLIENT.query_statements_by_ids([statement_id])
+            )
+        else:
+            existing_statements = [
+                x async for x in BACKEND_CLIENT.query_statements_by_ids([statement_id])
+            ]
     except BackendException as error:
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail="xAPI statements query failed",
         ) from error
 
-    if existing_statement:
+    if existing_statements:
         # The LRS specification calls for deep comparison of duplicate statement ids.
# In the case that the current statement is not equivalent to one found # in the database we return a 409, otherwise the usual 204. - for existing in existing_statement: - if not statements_are_equivalent(statement_as_dict, existing["_source"]): + for existing in existing_statements: + if not statements_are_equivalent(statement_as_dict, existing): raise HTTPException( status_code=status.HTTP_409_CONFLICT, detail="A different statement already exists with the same ID", @@ -445,7 +456,9 @@ async def put( # For valid requests, perform the bulk indexing of all incoming statements try: - success_count = DATABASE_CLIENT.put([statement_as_dict], ignore_errors=False) + success_count = BACKEND_CLIENT.write( + data=[statement_as_dict], ignore_errors=False + ) except (BackendException, BadFormatException) as exc: logger.error("Failed to index submitted statement") raise HTTPException( @@ -458,6 +471,7 @@ async def put( @router.post("/", responses=POST_PUT_RESPONSES) @router.post("", responses=POST_PUT_RESPONSES) +# pylint: disable = too-many-branches async def post( current_user: Annotated[AuthenticatedUser, Depends(get_authenticated_user)], statements: Union[LaxStatement, List[LaxStatement]], @@ -498,9 +512,17 @@ async def post( ) try: - existing_statements = DATABASE_CLIENT.query_statements_by_ids( - list(statements_dict) - ) + if isinstance(BACKEND_CLIENT, BaseLRSBackend): + existing_statements = list( + BACKEND_CLIENT.query_statements_by_ids(list(statements_dict)) + ) + else: + existing_statements = [ + x + async for x in BACKEND_CLIENT.query_statements_by_ids( + list(statements_dict) + ) + ] except BackendException as error: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -515,16 +537,15 @@ async def post( if existing_statements: existing_ids = set() for existing in existing_statements: - existing_ids.add(existing["_id"]) + existing_ids.add(existing["id"]) + # The LRS specification calls for deep comparison of duplicates. This # is done here. If they are not exactly the same, we raise an error. 
-            if not statements_are_equivalent(
-                statements_dict[existing["_id"]], existing["_source"]
-            ):
+            if not statements_are_equivalent(statements_dict[existing["id"]], existing):
                 raise HTTPException(
                     status_code=status.HTTP_409_CONFLICT,
                     detail="Differing statements already exist with the same ID: "
-                    f"{existing['_id']}",
+                    f"{existing['id']}",
                 )
 
     # Filter existing statements from the incoming statements
@@ -533,16 +554,14 @@ async def post(
         for key, value in statements_dict.items()
         if key not in existing_ids
     }
-
-    # Return if all incoming statements already exist
     if not statements_dict:
         response.status_code = status.HTTP_204_NO_CONTENT
         return
 
     # For valid requests, perform the bulk indexing of all incoming statements
     try:
-        success_count = DATABASE_CLIENT.put(
-            statements_dict.values(), ignore_errors=False
+        success_count = BACKEND_CLIENT.write(
+            data=statements_dict.values(), ignore_errors=False
         )
     except (BackendException, BadFormatException) as exc:
         logger.error("Failed to index submitted statements")
diff --git a/src/ralph/backends/data/base.py b/src/ralph/backends/data/base.py
index 641e08b53..00abbc5be 100644
--- a/src/ralph/backends/data/base.py
+++ b/src/ralph/backends/data/base.py
@@ -82,6 +82,7 @@ def wrapper(*args, **kwargs):
 class BaseDataBackend(ABC):
     """Base data backend interface."""
 
+    type = "data"
     name = "base"
     query_model = BaseQuery
     default_operation_type = BaseOperationType.INDEX
@@ -253,6 +254,7 @@ async def wrapper(*args, **kwargs):
 class BaseAsyncDataBackend(ABC):
     """Base async data backend interface."""
 
+    type = "data"
     name = "base"
     query_model = BaseQuery
     default_operation_type = BaseOperationType.INDEX
diff --git a/src/ralph/backends/data/es.py b/src/ralph/backends/data/es.py
index db9ea11d7..ece828bc5 100644
--- a/src/ralph/backends/data/es.py
+++ b/src/ralph/backends/data/es.py
@@ -248,7 +248,10 @@ def read(
             kwargs["q"] = query.query_string
 
         count = chunk_size
-        while limit or chunk_size == count:
+        # The first condition holds when `limit` is None (the backend query
+        # has no `size` parameter) or when `limit` is a positive value, and
+        # fails once `limit` equals zero.
+ while limit != 0 and chunk_size == count: kwargs["size"] = limit if limit and limit < chunk_size else chunk_size try: documents = self.client.search(**kwargs)["hits"]["hits"] diff --git a/src/ralph/backends/http/base.py b/src/ralph/backends/http/base.py index 8418fd7c4..cc0379b6d 100644 --- a/src/ralph/backends/http/base.py +++ b/src/ralph/backends/http/base.py @@ -81,6 +81,7 @@ class Config: class BaseHTTPBackend(ABC): """Base HTTP backend interface.""" + type = "http" name = "base" query = BaseQuery diff --git a/src/ralph/backends/lrs/async_es.py b/src/ralph/backends/lrs/async_es.py index c9dae7da5..df8cf2e98 100644 --- a/src/ralph/backends/lrs/async_es.py +++ b/src/ralph/backends/lrs/async_es.py @@ -6,7 +6,7 @@ from ralph.backends.data.async_es import AsyncESDataBackend from ralph.backends.lrs.base import ( BaseAsyncLRSBackend, - StatementParameters, + RalphStatementsQuery, StatementQueryResult, ) from ralph.backends.lrs.es import ESLRSBackend @@ -21,7 +21,7 @@ class AsyncESLRSBackend(BaseAsyncLRSBackend, AsyncESDataBackend): settings_class = AsyncESDataBackend.settings_class async def query_statements( - self, params: StatementParameters + self, params: RalphStatementsQuery ) -> StatementQueryResult: """Return the statements query payload using xAPI parameters.""" query = ESLRSBackend.get_query(params=params) diff --git a/src/ralph/backends/lrs/async_mongo.py b/src/ralph/backends/lrs/async_mongo.py index 3b26c0f78..aed815d44 100644 --- a/src/ralph/backends/lrs/async_mongo.py +++ b/src/ralph/backends/lrs/async_mongo.py @@ -7,7 +7,7 @@ from ralph.backends.data.async_mongo import AsyncMongoDataBackend from ralph.backends.lrs.base import ( BaseAsyncLRSBackend, - StatementParameters, + RalphStatementsQuery, StatementQueryResult, ) from ralph.backends.lrs.mongo import MongoLRSBackend @@ -22,7 +22,7 @@ class AsyncMongoLRSBackend(BaseAsyncLRSBackend, AsyncMongoDataBackend): settings_class = AsyncMongoDataBackend.settings_class async def query_statements( - self, params: StatementParameters + self, params: RalphStatementsQuery ) -> StatementQueryResult: """Return the statements query payload using xAPI parameters.""" query = MongoLRSBackend.get_query(params) diff --git a/src/ralph/backends/lrs/base.py b/src/ralph/backends/lrs/base.py index 008d60dfe..0cf552f25 100644 --- a/src/ralph/backends/lrs/base.py +++ b/src/ralph/backends/lrs/base.py @@ -2,9 +2,7 @@ from abc import abstractmethod from dataclasses import dataclass -from datetime import datetime -from typing import Iterator, List, Literal, Optional -from uuid import UUID +from typing import Iterator, List, Optional from pydantic import BaseModel @@ -13,6 +11,7 @@ BaseDataBackend, BaseDataBackendSettings, ) +from ralph.backends.http.async_lrs import LRSStatementsQuery class BaseLRSBackendSettings(BaseDataBackendSettings): @@ -41,38 +40,24 @@ class AgentParameters(BaseModel): account__home_page: Optional[str] -class StatementParameters(BaseModel): - """LRS statements query parameters.""" - - # pylint: disable=too-many-instance-attributes - - statementId: Optional[str] # pylint: disable=invalid-name - voidedStatementId: Optional[str] # pylint: disable=invalid-name - agent: Optional[AgentParameters] - verb: Optional[str] - activity: Optional[str] - registration: Optional[UUID] - related_activities: Optional[bool] - related_agents: Optional[bool] - since: Optional[datetime] - until: Optional[datetime] - limit: Optional[int] - format: Optional[Literal["ids", "exact", "canonical"]] = "exact" - attachments: Optional[bool] - ascending: 
Optional[bool] +class RalphStatementsQuery(LRSStatementsQuery): + """Represents a dictionary of possible LRS query parameters.""" + + agent: Optional[AgentParameters] = AgentParameters.construct() search_after: Optional[str] pit_id: Optional[str] - authority: Optional[AgentParameters] + authority: Optional[AgentParameters] = AgentParameters.construct() ignore_order: Optional[bool] class BaseLRSBackend(BaseDataBackend): """Base LRS backend interface.""" + type = "lrs" settings_class = BaseLRSBackendSettings @abstractmethod - def query_statements(self, params: StatementParameters) -> StatementQueryResult: + def query_statements(self, params: RalphStatementsQuery) -> StatementQueryResult: """Return the statements query payload using xAPI parameters.""" @abstractmethod @@ -83,11 +68,12 @@ def query_statements_by_ids(self, ids: List[str]) -> Iterator[dict]: class BaseAsyncLRSBackend(BaseAsyncDataBackend): """Base async LRS backend interface.""" + type = "lrs" settings_class = BaseLRSBackendSettings @abstractmethod async def query_statements( - self, params: StatementParameters + self, params: RalphStatementsQuery ) -> StatementQueryResult: """Return the statements query payload using xAPI parameters.""" diff --git a/src/ralph/backends/lrs/clickhouse.py b/src/ralph/backends/lrs/clickhouse.py index 7c97ecd54..1721879b7 100644 --- a/src/ralph/backends/lrs/clickhouse.py +++ b/src/ralph/backends/lrs/clickhouse.py @@ -11,7 +11,7 @@ AgentParameters, BaseLRSBackend, BaseLRSBackendSettings, - StatementParameters, + RalphStatementsQuery, StatementQueryResult, ) from ralph.exceptions import BackendException, BackendParameterException @@ -36,12 +36,12 @@ class ClickHouseLRSBackend(BaseLRSBackend, ClickHouseDataBackend): settings_class = ClickHouseLRSBackendSettings - def query_statements(self, params: StatementParameters) -> StatementQueryResult: + def query_statements(self, params: RalphStatementsQuery) -> StatementQueryResult: """Return the statements query payload using xAPI parameters.""" ch_params = params.dict(exclude_none=True) where = [] - if params.statementId: + if params.statement_id: where.append("event_id = {statementId:UUID}") self._add_agent_filters(ch_params, where, params.agent, "actor") @@ -131,11 +131,12 @@ def chunk_id_list(chunk_size=self.settings.IDS_CHUNK_SIZE): try: for chunk_ids in chunk_id_list(): query["parameters"]["ids"] = chunk_ids - yield from self.read( + ch_response = self.read( query=query, target=self.event_table_name, ignore_errors=True, ) + yield from (document["event"] for document in ch_response) except (BackendException, BackendParameterException) as error: msg = "Failed to read from ClickHouse" logger.error(msg) @@ -151,27 +152,33 @@ def _add_agent_filters( """Add filters relative to agents to `where`.""" if not agent_params: return - if agent_params.mbox: - ch_params[f"{target_field}__mbox"] = agent_params.mbox + if not isinstance(agent_params, dict): + agent_params = agent_params.dict() + if agent_params.get("mbox"): + ch_params[f"{target_field}__mbox"] = agent_params.get("mbox") where.append(f"event.{target_field}.mbox = {{{target_field}__mbox:String}}") - elif agent_params.mbox_sha1sum: - ch_params[f"{target_field}__mbox_sha1sum"] = agent_params.mbox_sha1sum + elif agent_params.get("mbox_sha1sum"): + ch_params[f"{target_field}__mbox_sha1sum"] = agent_params.get( + "mbox_sha1sum" + ) where.append( f"event.{target_field}.mbox_sha1sum = {{{target_field}__mbox_sha1sum:String}}" # noqa: E501 # pylint: disable=line-too-long ) - elif agent_params.openid: - 
ch_params[f"{target_field}__openid"] = agent_params.openid + elif agent_params.get("openid"): + ch_params[f"{target_field}__openid"] = agent_params.get("openid") where.append( f"event.{target_field}.openid = {{{target_field}__openid:String}}" ) - elif agent_params.account__name: - ch_params[f"{target_field}__account__name"] = agent_params.account__name + elif agent_params.get("account__name"): + ch_params[f"{target_field}__account__name"] = agent_params.get( + "account__name" + ) where.append( f"event.{target_field}.account.name = {{{target_field}__account__name:String}}" # noqa: E501 # pylint: disable=line-too-long ) - ch_params[ - f"{target_field}__account_home_page" - ] = agent_params.account__home_page + ch_params[f"{target_field}__account__home_page"] = agent_params.get( + "account__home_page" + ) where.append( - f"event.{target_field}.account.homePage = {{{target_field}__account_home_page:String}}" # noqa: E501 # pylint: disable=line-too-long + f"event.{target_field}.account.homePage = {{{target_field}__account__home_page:String}}" # noqa: E501 # pylint: disable=line-too-long ) diff --git a/src/ralph/backends/lrs/es.py b/src/ralph/backends/lrs/es.py index 5bf7d749e..6c498c772 100644 --- a/src/ralph/backends/lrs/es.py +++ b/src/ralph/backends/lrs/es.py @@ -7,7 +7,7 @@ from ralph.backends.lrs.base import ( AgentParameters, BaseLRSBackend, - StatementParameters, + RalphStatementsQuery, StatementQueryResult, ) from ralph.exceptions import BackendException, BackendParameterException @@ -20,7 +20,7 @@ class ESLRSBackend(BaseLRSBackend, ESDataBackend): settings_class = ESDataBackend.settings_class - def query_statements(self, params: StatementParameters) -> StatementQueryResult: + def query_statements(self, params: RalphStatementsQuery) -> StatementQueryResult: """Return the statements query payload using xAPI parameters.""" query = self.get_query(params=params) try: @@ -46,12 +46,12 @@ def query_statements_by_ids(self, ids: List[str]) -> Iterator[dict]: raise error @staticmethod - def get_query(params: StatementParameters) -> ESQuery: + def get_query(params: RalphStatementsQuery) -> ESQuery: """Construct query from statement parameters.""" es_query_filters = [] - if params.statementId: - es_query_filters += [{"term": {"_id": params.statementId}}] + if params.statement_id: + es_query_filters += [{"term": {"_id": params.statement_id}}] ESLRSBackend._add_agent_filters(es_query_filters, params.agent, "actor") ESLRSBackend._add_agent_filters(es_query_filters, params.authority, "authority") @@ -95,17 +95,23 @@ def _add_agent_filters( """Add filters relative to agents to `es_query_filters`.""" if not agent_params: return - if agent_params.mbox: + + if not isinstance(agent_params, dict): + agent_params = agent_params.dict() + + if agent_params.get("mbox"): field = f"{target_field}.mbox.keyword" - es_query_filters += [{"term": {field: agent_params.mbox}}] - elif agent_params.mbox_sha1sum: + es_query_filters += [{"term": {field: agent_params.get("mbox")}}] + elif agent_params.get("mbox_sha1sum"): field = f"{target_field}.mbox_sha1sum.keyword" - es_query_filters += [{"term": {field: agent_params.mbox_sha1sum}}] - elif agent_params.openid: + es_query_filters += [{"term": {field: agent_params.get("mbox_sha1sum")}}] + elif agent_params.get("openid"): field = f"{target_field}.openid.keyword" - es_query_filters += [{"term": {field: agent_params.openid}}] - elif agent_params.account__name: + es_query_filters += [{"term": {field: agent_params.get("openid")}}] + elif agent_params.get("account__name"): field 
= f"{target_field}.account.name.keyword" - es_query_filters += [{"term": {field: agent_params.account__name}}] + es_query_filters += [{"term": {field: agent_params.get("account__name")}}] field = f"{target_field}.account.homePage.keyword" - es_query_filters += [{"term": {field: agent_params.account__home_page}}] + es_query_filters += [ + {"term": {field: agent_params.get("account__home_page")}} + ] diff --git a/src/ralph/backends/lrs/fs.py b/src/ralph/backends/lrs/fs.py index 648719150..5407f075b 100644 --- a/src/ralph/backends/lrs/fs.py +++ b/src/ralph/backends/lrs/fs.py @@ -12,7 +12,7 @@ AgentParameters, BaseLRSBackend, BaseLRSBackendSettings, - StatementParameters, + RalphStatementsQuery, StatementQueryResult, ) @@ -49,10 +49,10 @@ def write( # pylint: disable=too-many-arguments target = target if target else self.settings.DEFAULT_LRS_FILE return super().write(data, target, chunk_size, ignore_errors, operation_type) - def query_statements(self, params: StatementParameters) -> StatementQueryResult: + def query_statements(self, params: RalphStatementsQuery) -> StatementQueryResult: """Return the statements query payload using xAPI parameters.""" filters = [] - self._add_filter_by_id(filters, params.statementId) + self._add_filter_by_id(filters, params.statement_id) self._add_filter_by_agent(filters, params.agent, params.related_agents) self._add_filter_by_authority(filters, params.authority) self._add_filter_by_verb(filters, params.verb) @@ -105,11 +105,18 @@ def _add_filter_by_agent( if not agent: return - FSLRSBackend._add_filter_by_mbox(filters, agent.mbox, related) - FSLRSBackend._add_filter_by_sha1sum(filters, agent.mbox_sha1sum, related) - FSLRSBackend._add_filter_by_openid(filters, agent.openid, related) + if not isinstance(agent, dict): + agent = agent.dict() + FSLRSBackend._add_filter_by_mbox(filters, agent.get("mbox", None), related) + FSLRSBackend._add_filter_by_sha1sum( + filters, agent.get("mbox_sha1sum", None), related + ) + FSLRSBackend._add_filter_by_openid(filters, agent.get("openid", None), related) FSLRSBackend._add_filter_by_account( - filters, agent.account__name, agent.account__home_page, related + filters, + agent.get("account__name", None), + agent.get("account__home_page", None), + related, ) @staticmethod @@ -121,15 +128,21 @@ def _add_filter_by_authority( if not authority: return - FSLRSBackend._add_filter_by_mbox(filters, authority.mbox, field="authority") + if not isinstance(authority, dict): + authority = authority.dict() + FSLRSBackend._add_filter_by_mbox( + filters, authority.get("mbox", None), field="authority" + ) FSLRSBackend._add_filter_by_sha1sum( - filters, authority.mbox_sha1sum, field="authority" + filters, authority.get("mbox_sha1sum", None), field="authority" + ) + FSLRSBackend._add_filter_by_openid( + filters, authority.get("openid", None), field="authority" ) - FSLRSBackend._add_filter_by_openid(filters, authority.openid, field="authority") FSLRSBackend._add_filter_by_account( filters, - authority.account__name, - authority.account__home_page, + authority.get("account__name", None), + authority.get("account__home_page", None), field="authority", ) @@ -312,6 +325,8 @@ def _add_filter_by_timestamp_since( filters: list, timestamp: Union[datetime, None] ) -> None: """Add the `match_since` filter if `timestamp` is set.""" + if isinstance(timestamp, str): + timestamp = datetime.fromisoformat(timestamp) def match_since(statement: dict) -> bool: """Return `True` if the statement was created after `timestamp`.""" @@ -331,6 +346,8 @@ def 
_add_filter_by_timestamp_until( filters: list, timestamp: Union[datetime, None] ) -> None: """Add the `match_until` function if `timestamp` is set.""" + if isinstance(timestamp, str): + timestamp = datetime.fromisoformat(timestamp) def match_until(statement: dict) -> bool: """Return `True` if the statement was created before `timestamp`.""" diff --git a/src/ralph/backends/lrs/mongo.py b/src/ralph/backends/lrs/mongo.py index fdbe83315..2d2a0a64e 100644 --- a/src/ralph/backends/lrs/mongo.py +++ b/src/ralph/backends/lrs/mongo.py @@ -10,7 +10,7 @@ from ralph.backends.lrs.base import ( AgentParameters, BaseLRSBackend, - StatementParameters, + RalphStatementsQuery, StatementQueryResult, ) from ralph.exceptions import BackendException, BackendParameterException @@ -23,7 +23,7 @@ class MongoLRSBackend(BaseLRSBackend, MongoDataBackend): settings_class = MongoDataBackend.settings_class - def query_statements(self, params: StatementParameters) -> StatementQueryResult: + def query_statements(self, params: RalphStatementsQuery) -> StatementQueryResult: """Return the results of a statements query using xAPI parameters.""" query = self.get_query(params) try: @@ -52,12 +52,12 @@ def query_statements_by_ids(self, ids: List[str]) -> Iterator[dict]: raise error @staticmethod - def get_query(params: StatementParameters) -> MongoQuery: + def get_query(params: RalphStatementsQuery) -> MongoQuery: """Construct query from statement parameters.""" mongo_query_filters = {} - if params.statementId: - mongo_query_filters.update({"_source.id": params.statementId}) + if params.statement_id: + mongo_query_filters.update({"_source.id": params.statement_id}) MongoLRSBackend._add_agent_filters(mongo_query_filters, params.agent, "actor") MongoLRSBackend._add_agent_filters( @@ -76,16 +76,12 @@ def get_query(params: StatementParameters) -> MongoQuery: ) if params.since: - mongo_query_filters.update( - {"_source.timestamp": {"$gt": params.since.isoformat()}} - ) + mongo_query_filters.update({"_source.timestamp": {"$gt": params.since}}) if params.until: if not params.since: mongo_query_filters["_source.timestamp"] = {} - mongo_query_filters["_source.timestamp"].update( - {"$lte": params.until.isoformat()} - ) + mongo_query_filters["_source.timestamp"].update({"$lte": params.until}) if params.search_after: search_order = "$gt" if params.ascending else "$lt" @@ -118,20 +114,23 @@ def _add_agent_filters( if not agent_params: return - if agent_params.mbox: + if not isinstance(agent_params, dict): + agent_params = agent_params.dict() + + if agent_params.get("mbox"): key = f"_source.{target_field}.mbox" - mongo_query_filters.update({key: agent_params.mbox}) + mongo_query_filters.update({key: agent_params.get("mbox")}) - if agent_params.mbox_sha1sum: + if agent_params.get("mbox_sha1sum"): key = f"_source.{target_field}.mbox_sha1sum" - mongo_query_filters.update({key: agent_params.mbox_sha1sum}) + mongo_query_filters.update({key: agent_params.get("mbox_sha1sum")}) - if agent_params.openid: + if agent_params.get("openid"): key = f"_source.{target_field}.openid" - mongo_query_filters.update({key: agent_params.openid}) + mongo_query_filters.update({key: agent_params.get("openid")}) - if agent_params.account__name: + if agent_params.get("account__name"): key = f"_source.{target_field}.account.name" - mongo_query_filters.update({key: agent_params.account__name}) + mongo_query_filters.update({key: agent_params.get("account__name")}) key = f"_source.{target_field}.account.homePage" - mongo_query_filters.update({key: 
agent_params.account__home_page}) + mongo_query_filters.update({key: agent_params.get("account__home_page")}) diff --git a/src/ralph/backends/stream/base.py b/src/ralph/backends/stream/base.py index 1f2b1f11e..5a6861203 100644 --- a/src/ralph/backends/stream/base.py +++ b/src/ralph/backends/stream/base.py @@ -22,6 +22,7 @@ class Config(BaseSettingsConfig): class BaseStreamBackend(ABC): """Base stream backend interface.""" + type = "stream" name = "base" settings_class = BaseStreamBackendSettings diff --git a/src/ralph/utils.py b/src/ralph/utils.py index 07ee27a6b..40bb33dfb 100644 --- a/src/ralph/utils.py +++ b/src/ralph/utils.py @@ -8,7 +8,7 @@ from functools import reduce from importlib import import_module from inspect import getmembers, isclass -from typing import Any, Dict, Iterable, Iterator, List, Union +from typing import Any, Dict, Iterable, Iterator, List, Optional, Union from pydantic import BaseModel diff --git a/tests/api/test_health.py b/tests/api/test_health.py index d415cf6db..9fdfddfee 100644 --- a/tests/api/test_health.py +++ b/tests/api/test_health.py @@ -6,7 +6,7 @@ from ralph.api import app from ralph.api.routers import health -from ralph.backends.database.base import DatabaseStatus +from ralph.backends.data.base import DataBackendStatus from tests.fixtures.backends import ( get_clickhouse_test_backend, @@ -23,7 +23,7 @@ ) def test_api_health_lbheartbeat(backend, monkeypatch): """Test the load balancer heartbeat healthcheck.""" - monkeypatch.setattr(health, "DATABASE_CLIENT", backend()) + monkeypatch.setattr(health, "BACKEND_CLIENT", backend()) response = client.get("/__lbheartbeat__") assert response.status_code == 200 @@ -37,19 +37,21 @@ def test_api_health_lbheartbeat(backend, monkeypatch): # pylint: disable=unused-argument def test_api_health_heartbeat(backend, monkeypatch, clickhouse): """Test the heartbeat healthcheck.""" - monkeypatch.setattr(health, "DATABASE_CLIENT", backend()) + monkeypatch.setattr(health, "BACKEND_CLIENT", backend()) response = client.get("/__heartbeat__") logging.warning(response.read()) assert response.status_code == 200 assert response.json() == {"database": "ok"} - monkeypatch.setattr(health.DATABASE_CLIENT, "status", lambda: DatabaseStatus.AWAY) + monkeypatch.setattr(health.BACKEND_CLIENT, "status", lambda: DataBackendStatus.AWAY) response = client.get("/__heartbeat__") assert response.json() == {"database": "away"} assert response.status_code == 500 - monkeypatch.setattr(health.DATABASE_CLIENT, "status", lambda: DatabaseStatus.ERROR) + monkeypatch.setattr( + health.BACKEND_CLIENT, "status", lambda: DataBackendStatus.ERROR + ) response = client.get("/__heartbeat__") assert response.json() == {"database": "error"} assert response.status_code == 500 diff --git a/tests/api/test_statements.py b/tests/api/test_statements.py index ffbca72fe..0d629356e 100644 --- a/tests/api/test_statements.py +++ b/tests/api/test_statements.py @@ -4,29 +4,29 @@ from ralph import conf from ralph.api.routers import statements -from ralph.backends.database.clickhouse import ClickHouseDatabase -from ralph.backends.database.es import ESDatabase -from ralph.backends.database.mongo import MongoDatabase +from ralph.backends.data.clickhouse import ClickHouseDataBackend +from ralph.backends.data.es import ESDataBackend +from ralph.backends.data.mongo import MongoDataBackend def test_api_statements_backend_instance_with_runserver_backend_env(monkeypatch): """Tests that given the RALPH_RUNSERVER_BACKEND environment variable, the backend - instance `DATABASE_CLIENT` 
should be updated accordingly. + instance `BACKEND_CLIENT` should be updated accordingly. """ # Default backend - assert isinstance(statements.DATABASE_CLIENT, ESDatabase) + assert isinstance(statements.BACKEND_CLIENT, ESDataBackend) # Mongo backend monkeypatch.setenv("RALPH_RUNSERVER_BACKEND", "mongo") reload(conf) - assert isinstance(reload(statements).DATABASE_CLIENT, MongoDatabase) + assert isinstance(reload(statements).BACKEND_CLIENT, MongoDataBackend) # Elasticsearch backend monkeypatch.setenv("RALPH_RUNSERVER_BACKEND", "es") reload(conf) - assert isinstance(reload(statements).DATABASE_CLIENT, ESDatabase) + assert isinstance(reload(statements).BACKEND_CLIENT, ESDataBackend) # ClickHouse backend monkeypatch.setenv("RALPH_RUNSERVER_BACKEND", "clickhouse") reload(conf) - assert isinstance(reload(statements).DATABASE_CLIENT, ClickHouseDatabase) + assert isinstance(reload(statements).BACKEND_CLIENT, ClickHouseDataBackend) diff --git a/tests/api/test_statements_get.py b/tests/api/test_statements_get.py index 163b4abde..9aa0ff742 100644 --- a/tests/api/test_statements_get.py +++ b/tests/api/test_statements_get.py @@ -10,8 +10,9 @@ from ralph.api import app from ralph.api.auth.basic import get_authenticated_user -from ralph.backends.database.clickhouse import ClickHouseDatabase -from ralph.backends.database.mongo import MongoDatabase +from ralph.backends.data.base import BaseOperationType +from ralph.backends.data.clickhouse import ClickHouseDataBackend +from ralph.backends.data.mongo import MongoDataBackend from ralph.exceptions import BackendException from tests.fixtures.backends import ( @@ -54,18 +55,28 @@ def insert_mongo_statements(mongo_client, statements): """Insert a bunch of example statements into MongoDB for testing.""" database = getattr(mongo_client, MONGO_TEST_DATABASE) collection = getattr(database, MONGO_TEST_COLLECTION) - collection.insert_many(list(MongoDatabase.to_documents(statements))) + collection.insert_many( + list( + MongoDataBackend.to_documents( + data=statements, + ignore_errors=True, + operation_type=BaseOperationType.CREATE, + logger_class=None, + ) + ) + ) def insert_clickhouse_statements(statements): - """Insert a bunch of example statements into ClickHouse for testing.""" - backend = ClickHouseDatabase( - host=CLICKHOUSE_TEST_HOST, - port=CLICKHOUSE_TEST_PORT, - database=CLICKHOUSE_TEST_DATABASE, - event_table_name=CLICKHOUSE_TEST_TABLE_NAME, + """Inserts a bunch of example statements into ClickHouse for testing.""" + settings = ClickHouseDataBackend.settings_class( + HOST=CLICKHOUSE_TEST_HOST, + PORT=CLICKHOUSE_TEST_PORT, + DATABASE=CLICKHOUSE_TEST_DATABASE, + EVENT_TABLE_NAME=CLICKHOUSE_TEST_TABLE_NAME, ) - success = backend.put(statements) + backend = ClickHouseDataBackend(settings=settings) + success = backend.write(statements) assert success == len(statements) @@ -78,8 +89,8 @@ def insert_statements_and_monkeypatch_backend( # pylint: disable=invalid-name def _insert_statements_and_monkeypatch_backend(statements): - """Insert statements once into each backend.""" - database_client_class_path = "ralph.api.routers.statements.DATABASE_CLIENT" + """Inserts statements once into each backend.""" + database_client_class_path = "ralph.api.routers.statements.BACKEND_CLIENT" if request.param == "mongo": insert_mongo_statements(mongo, statements) monkeypatch.setattr(database_client_class_path, get_mongo_test_backend()) @@ -666,11 +677,11 @@ def test_api_statements_get_statements_with_database_query_failure( # pylint: disable=redefined-outer-name def 
mock_query_statements(*_): - """Mock the DATABASE_CLIENT.query_statements method.""" + """Mocks the BACKEND_CLIENT.query_statements method.""" raise BackendException() monkeypatch.setattr( - "ralph.api.routers.statements.DATABASE_CLIENT.query_statements", + "ralph.api.routers.statements.BACKEND_CLIENT.query_statements", mock_query_statements, ) diff --git a/tests/api/test_statements_post.py b/tests/api/test_statements_post.py index 5c743ec37..5d9377979 100644 --- a/tests/api/test_statements_post.py +++ b/tests/api/test_statements_post.py @@ -8,8 +8,8 @@ from httpx import AsyncClient from ralph.api import app -from ralph.backends.database.es import ESDatabase -from ralph.backends.database.mongo import MongoDatabase +from ralph.backends.lrs.es import ESLRSBackend +from ralph.backends.lrs.mongo import MongoLRSBackend from ralph.conf import XapiForwardingConfigurationSettings from ralph.exceptions import BackendException @@ -75,7 +75,7 @@ def test_api_statements_post_single_statement_directly( """Test the post statements API route with one statement.""" # pylint: disable=invalid-name,unused-argument - monkeypatch.setattr("ralph.api.routers.statements.DATABASE_CLIENT", backend()) + monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) statement = { "actor": { "account": { @@ -118,7 +118,7 @@ def test_api_statements_post_enriching_without_existing_values( # pylint: disable=invalid-name,unused-argument monkeypatch.setattr( - "ralph.api.routers.statements.DATABASE_CLIENT", get_es_test_backend() + "ralph.api.routers.statements.BACKEND_CLIENT", get_es_test_backend() ) statement = { "actor": { @@ -182,7 +182,7 @@ def test_api_statements_post_enriching_with_existing_values( # pylint: disable=invalid-name,unused-argument monkeypatch.setattr( - "ralph.api.routers.statements.DATABASE_CLIENT", get_es_test_backend() + "ralph.api.routers.statements.BACKEND_CLIENT", get_es_test_backend() ) statement = { "actor": { @@ -235,7 +235,7 @@ def test_api_statements_post_single_statement_no_trailing_slash( """Test that the statements endpoint also works without the trailing slash.""" # pylint: disable=invalid-name,unused-argument - monkeypatch.setattr("ralph.api.routers.statements.DATABASE_CLIENT", backend()) + monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) statement = { "actor": { "account": { @@ -271,7 +271,7 @@ def test_api_statements_post_statements_list_of_one( """Test the post statements API route with one statement in a list.""" # pylint: disable=invalid-name,unused-argument - monkeypatch.setattr("ralph.api.routers.statements.DATABASE_CLIENT", backend()) + monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) statement = { "actor": { "account": { @@ -316,7 +316,7 @@ def test_api_statements_post_statements_list( """Test the post statements API route with two statements in a list.""" # pylint: disable=invalid-name,unused-argument - monkeypatch.setattr("ralph.api.routers.statements.DATABASE_CLIENT", backend()) + monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) statements = [ { "actor": { @@ -387,7 +387,7 @@ def test_api_statements_post_statements_list_with_duplicates( """Test the post statements API route with duplicate statement IDs should fail.""" # pylint: disable=invalid-name,unused-argument - monkeypatch.setattr("ralph.api.routers.statements.DATABASE_CLIENT", backend()) + monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) statement = { "actor": { "account": { @@ -434,7 +434,7 
@@ def test_api_statements_post_statements_list_with_duplicate_of_existing_statemen """ # pylint: disable=invalid-name,unused-argument - monkeypatch.setattr("ralph.api.routers.statements.DATABASE_CLIENT", backend()) + monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) statement_uuid = str(uuid4()) statement = { @@ -505,15 +505,13 @@ def test_api_statements_post_statements_with_a_failure_during_storage( """Test the post statements API route with a failure happening during storage.""" # pylint: disable=invalid-name,unused-argument, too-many-arguments - def put_mock(*args, **kwargs): - """Raise an exception. Mock the database.put method.""" + def write_mock(*args, **kwargs): + """Raises an exception. Mocks the database.write method.""" raise BackendException() backend_instance = backend() - monkeypatch.setattr(backend_instance, "put", put_mock) - monkeypatch.setattr( - "ralph.api.routers.statements.DATABASE_CLIENT", backend_instance - ) + monkeypatch.setattr(backend_instance, "write", write_mock) + monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend_instance) statement = { "actor": { "account": { @@ -556,9 +554,7 @@ def query_statements_by_ids_mock(*args, **kwargs): monkeypatch.setattr( backend_instance, "query_statements_by_ids", query_statements_by_ids_mock ) - monkeypatch.setattr( - "ralph.api.routers.statements.DATABASE_CLIENT", backend_instance - ) + monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend_instance) statement = { "actor": { "account": { @@ -588,7 +584,7 @@ def query_statements_by_ids_mock(*args, **kwargs): [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) # pylint: disable=too-many-arguments -def test_post_statements_list_without_statement_forwarding( +def test_api_statements_post_statements_list_without_statement_forwarding( backend, auth_credentials, monkeypatch, es, mongo, clickhouse ): """Test the post statements API route, given an empty forwarding configuration, @@ -609,7 +605,7 @@ def spy_mock_forward_xapi_statements(_): monkeypatch.setattr( "ralph.api.routers.statements.get_active_xapi_forwardings", lambda: [] ) - monkeypatch.setattr("ralph.api.routers.statements.DATABASE_CLIENT", backend()) + monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) statement = { "actor": { @@ -642,15 +638,21 @@ def spy_mock_forward_xapi_statements(_): @pytest.mark.parametrize( "forwarding_backend", [ - lambda: ESDatabase(hosts=ES_TEST_HOSTS, index=ES_TEST_FORWARDING_INDEX), - lambda: MongoDatabase( - connection_uri=MONGO_TEST_CONNECTION_URI, - database=MONGO_TEST_DATABASE, - collection=MONGO_TEST_FORWARDING_COLLECTION, + lambda: ESLRSBackend( + settings=ESLRSBackend.settings_class( + HOSTS=ES_TEST_HOSTS, DEFAULT_INDEX=ES_TEST_FORWARDING_INDEX + ) + ), + lambda: MongoLRSBackend( + settings=MongoLRSBackend.settings_class( + CONNECTION_URI=MONGO_TEST_CONNECTION_URI, + DEFAULT_DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=MONGO_TEST_FORWARDING_COLLECTION, + ) ), ], ) -async def test_post_statements_list_with_statement_forwarding( +async def test_api_statements_post_statements_list_with_statement_forwarding( receiving_backend, forwarding_backend, monkeypatch, @@ -690,7 +692,7 @@ async def test_post_statements_list_with_statement_forwarding( ) # Receiving client should use the receiving Elasticsearch client for storage receiving_patch.setattr( - "ralph.api.routers.statements.DATABASE_CLIENT", receiving_backend() + "ralph.api.routers.statements.BACKEND_CLIENT", 
receiving_backend() ) lrs_context = lrs(app) # Start receiving LRS client @@ -720,7 +722,7 @@ async def test_post_statements_list_with_statement_forwarding( # Forwarding client should use the forwarding Elasticsearch client for storage forwarding_patch.setattr( - "ralph.api.routers.statements.DATABASE_CLIENT", forwarding_backend() + "ralph.api.routers.statements.BACKEND_CLIENT", forwarding_backend() ) # Start forwarding LRS client async with AsyncClient( diff --git a/tests/api/test_statements_put.py b/tests/api/test_statements_put.py index 4700c38ab..a4f43e29c 100644 --- a/tests/api/test_statements_put.py +++ b/tests/api/test_statements_put.py @@ -7,8 +7,8 @@ from httpx import AsyncClient from ralph.api import app -from ralph.backends.database.es import ESDatabase -from ralph.backends.database.mongo import MongoDatabase +from ralph.backends.lrs.es import ESLRSBackend +from ralph.backends.lrs.mongo import MongoLRSBackend from ralph.conf import XapiForwardingConfigurationSettings from ralph.exceptions import BackendException @@ -69,7 +69,7 @@ def test_api_statements_put_single_statement_directly( """Test the put statements API route with one statement.""" # pylint: disable=invalid-name,unused-argument - monkeypatch.setattr("ralph.api.routers.statements.DATABASE_CLIENT", backend()) + monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) statement = { "actor": { "account": { @@ -111,7 +111,7 @@ def test_api_statements_put_enriching_without_existing_values( # pylint: disable=invalid-name,unused-argument monkeypatch.setattr( - "ralph.api.routers.statements.DATABASE_CLIENT", get_es_test_backend() + "ralph.api.routers.statements.BACKEND_CLIENT", get_es_test_backend() ) statement = { "actor": { @@ -174,7 +174,7 @@ def test_api_statements_put_enriching_with_existing_values( # pylint: disable=invalid-name,unused-argument monkeypatch.setattr( - "ralph.api.routers.statements.DATABASE_CLIENT", get_es_test_backend() + "ralph.api.routers.statements.BACKEND_CLIENT", get_es_test_backend() ) statement = { "actor": { @@ -227,7 +227,7 @@ def test_api_statements_put_single_statement_no_trailing_slash( """Test that the statements endpoint also works without the trailing slash.""" # pylint: disable=invalid-name,unused-argument - monkeypatch.setattr("ralph.api.routers.statements.DATABASE_CLIENT", backend()) + monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) statement = { "actor": { "account": { @@ -261,7 +261,7 @@ def test_api_statements_put_statement_id_mismatch( ): # pylint: disable=invalid-name,unused-argument """Test the put statements API route when the statementId doesn't match.""" - monkeypatch.setattr("ralph.api.routers.statements.DATABASE_CLIENT", backend()) + monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) statement = { "actor": { "account": { @@ -299,7 +299,7 @@ def test_api_statements_put_statements_list_of_one( ): # pylint: disable=invalid-name,unused-argument """Test that we fail on PUTs with a list, even if it's one statement.""" - monkeypatch.setattr("ralph.api.routers.statements.DATABASE_CLIENT", backend()) + monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) statement = { "actor": { "account": { @@ -336,7 +336,7 @@ def test_api_statements_put_statement_duplicate_of_existing_statement( """ # pylint: disable=invalid-name,unused-argument - monkeypatch.setattr("ralph.api.routers.statements.DATABASE_CLIENT", backend()) + monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", 
backend()) statement = { "actor": { "account": { @@ -387,21 +387,19 @@ def test_api_statements_put_statement_duplicate_of_existing_statement( "backend", [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) -def test_api_statement_put_statements_with_a_failure_during_storage( +def test_api_statements_put_statements_with_a_failure_during_storage( backend, monkeypatch, auth_credentials, es, mongo, clickhouse ): """Test the put statements API route with a failure happening during storage.""" # pylint: disable=invalid-name,unused-argument, too-many-arguments - def put_mock(*args, **kwargs): - """Raise an exception. Mock the database.put method.""" + def write_mock(*args, **kwargs): + """Raises an exception. Mocks the database.write method.""" raise BackendException() backend_instance = backend() - monkeypatch.setattr(backend_instance, "put", put_mock) - monkeypatch.setattr( - "ralph.api.routers.statements.DATABASE_CLIENT", backend_instance - ) + monkeypatch.setattr(backend_instance, "write", write_mock) + monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend_instance) statement = { "actor": { "account": { @@ -444,9 +442,7 @@ def query_statements_by_ids_mock(*args, **kwargs): monkeypatch.setattr( backend_instance, "query_statements_by_ids", query_statements_by_ids_mock ) - monkeypatch.setattr( - "ralph.api.routers.statements.DATABASE_CLIENT", backend_instance - ) + monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend_instance) statement = { "actor": { "account": { @@ -476,7 +472,7 @@ def query_statements_by_ids_mock(*args, **kwargs): [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) # pylint: disable=too-many-arguments -def test_put_statement_without_statement_forwarding( +def test_api_statements_put_statement_without_statement_forwarding( backend, auth_credentials, monkeypatch, es, mongo, clickhouse ): """Test the put statements API route, given an empty forwarding configuration, @@ -497,7 +493,7 @@ def spy_mock_forward_xapi_statements(_): monkeypatch.setattr( "ralph.api.routers.statements.get_active_xapi_forwardings", lambda: [] ) - monkeypatch.setattr("ralph.api.routers.statements.DATABASE_CLIENT", backend()) + monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) statement = { "actor": { @@ -529,15 +525,21 @@ def spy_mock_forward_xapi_statements(_): @pytest.mark.parametrize( "forwarding_backend", [ - lambda: ESDatabase(hosts=ES_TEST_HOSTS, index=ES_TEST_FORWARDING_INDEX), - lambda: MongoDatabase( - connection_uri=MONGO_TEST_CONNECTION_URI, - database=MONGO_TEST_DATABASE, - collection=MONGO_TEST_FORWARDING_COLLECTION, + lambda: ESLRSBackend( + settings=ESLRSBackend.settings_class( + HOSTS=ES_TEST_HOSTS, DEFAULT_INDEX=ES_TEST_FORWARDING_INDEX + ) + ), + lambda: MongoLRSBackend( + settings=MongoLRSBackend.settings_class( + CONNECTION_URI=MONGO_TEST_CONNECTION_URI, + DEFAULT_DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=MONGO_TEST_FORWARDING_COLLECTION, + ) ), ], ) -async def test_put_statement_with_statement_forwarding( +async def test_api_statements_put_statement_with_statement_forwarding( receiving_backend, forwarding_backend, monkeypatch, @@ -577,7 +579,7 @@ async def test_put_statement_with_statement_forwarding( ) # Receiving client should use the receiving Elasticsearch client for storage receiving_patch.setattr( - "ralph.api.routers.statements.DATABASE_CLIENT", receiving_backend() + "ralph.api.routers.statements.BACKEND_CLIENT", receiving_backend() ) lrs_context = 
lrs(app) # Start receiving LRS client @@ -610,7 +612,7 @@ async def test_put_statement_with_statement_forwarding( # Forwarding client should use the forwarding Elasticsearch client for storage forwarding_patch.setattr( - "ralph.api.routers.statements.DATABASE_CLIENT", forwarding_backend() + "ralph.api.routers.statements.BACKEND_CLIENT", forwarding_backend() ) # Start forwarding LRS client async with AsyncClient( diff --git a/tests/backends/lrs/test_async_es.py b/tests/backends/lrs/test_async_es.py index 9dd9e7466..4d034922c 100644 --- a/tests/backends/lrs/test_async_es.py +++ b/tests/backends/lrs/test_async_es.py @@ -2,14 +2,13 @@ import logging import re -from datetime import datetime import pytest from elastic_transport import ApiResponseMeta from elasticsearch import ApiError from elasticsearch.helpers import bulk -from ralph.backends.lrs.base import StatementParameters +from ralph.backends.lrs.base import RalphStatementsQuery from ralph.exceptions import BackendException from tests.fixtures.backends import ES_TEST_FORWARDING_INDEX, ES_TEST_INDEX @@ -26,7 +25,7 @@ "query": {"match_all": {}}, "query_string": None, "search_after": None, - "size": None, + "size": 0, "sort": [{"timestamp": {"order": "desc"}}], "track_total_hits": False, }, @@ -39,7 +38,7 @@ "query": {"bool": {"filter": [{"term": {"_id": "statementId"}}]}}, "query_string": None, "search_after": None, - "size": None, + "size": 0, "sort": [{"timestamp": {"order": "desc"}}], "track_total_hits": False, }, @@ -59,7 +58,7 @@ }, "query_string": None, "search_after": None, - "size": None, + "size": 0, "sort": [{"timestamp": {"order": "desc"}}], "track_total_hits": False, }, @@ -88,7 +87,7 @@ }, "query_string": None, "search_after": None, - "size": None, + "size": 0, "sort": [{"timestamp": {"order": "desc"}}], "track_total_hits": False, }, @@ -117,7 +116,7 @@ }, "query_string": None, "search_after": None, - "size": None, + "size": 0, "sort": [{"timestamp": {"order": "desc"}}], "track_total_hits": False, }, @@ -150,7 +149,7 @@ }, "query_string": None, "search_after": None, - "size": None, + "size": 0, "sort": [{"timestamp": {"order": "desc"}}], "track_total_hits": False, }, @@ -186,7 +185,7 @@ }, "query_string": None, "search_after": None, - "size": None, + "size": 0, "sort": [{"timestamp": {"order": "desc"}}], "track_total_hits": False, }, @@ -205,18 +204,14 @@ { "range": { "timestamp": { - "gt": datetime.fromisoformat( - "2021-06-24T00:00:20.194929+00:00" - ) + "gt": "2021-06-24T00:00:20.194929+00:00" } } }, { "range": { "timestamp": { - "lte": datetime.fromisoformat( - "2023-06-24T00:00:20.194929+00:00" - ) + "lte": "2023-06-24T00:00:20.194929+00:00" } } }, @@ -225,7 +220,7 @@ }, "query_string": None, "search_after": None, - "size": None, + "size": 0, "sort": [{"timestamp": {"order": "desc"}}], "track_total_hits": False, }, @@ -238,7 +233,7 @@ "query": {"match_all": {}}, "query_string": None, "search_after": ["1686557542970", "0"], - "size": None, + "size": 0, "sort": [{"timestamp": {"order": "desc"}}], "track_total_hits": False, }, @@ -251,7 +246,7 @@ "query": {"match_all": {}}, "query_string": None, "search_after": None, - "size": None, + "size": 0, "sort": "_shard_doc", "track_total_hits": False, }, @@ -276,7 +271,7 @@ async def mock_read(query, chunk_size): backend = async_es_lrs_backend() monkeypatch.setattr(backend, "read", mock_read) - result = await backend.query_statements(StatementParameters(**params)) + result = await backend.query_statements(RalphStatementsQuery.construct(**params)) assert result.statements == [{}] 
assert result.pit_id == "foo_pit_id" assert result.search_after == "bar_search_after|baz_search_after" @@ -299,7 +294,7 @@ async def test_backends_lrs_async_es_lrs_backend_query_statements( assert await backend.write(documents) == 1 # Check the expected search query results. - result = await backend.query_statements(StatementParameters(limit=10)) + result = await backend.query_statements(RalphStatementsQuery.construct(limit=10)) assert result.statements == documents assert re.match(r"[0-9]+\|0", result.search_after) @@ -326,7 +321,7 @@ async def mock_read(**_): msg = "Query error" with pytest.raises(BackendException, match=msg): with caplog.at_level(logging.ERROR): - await backend.query_statements(StatementParameters()) + await backend.query_statements(RalphStatementsQuery.construct()) await backend.close() @@ -359,7 +354,7 @@ def mock_search(**_): _ = [ statement async for statement in backend.query_statements_by_ids( - StatementParameters() + RalphStatementsQuery.construct() ) ] diff --git a/tests/backends/lrs/test_async_mongo.py b/tests/backends/lrs/test_async_mongo.py index 75ee8cd56..b3f3a9108 100644 --- a/tests/backends/lrs/test_async_mongo.py +++ b/tests/backends/lrs/test_async_mongo.py @@ -6,7 +6,7 @@ from bson.objectid import ObjectId from pymongo import ASCENDING, DESCENDING -from ralph.backends.lrs.base import StatementParameters +from ralph.backends.lrs.base import RalphStatementsQuery from ralph.exceptions import BackendException from tests.fixtures.backends import MONGO_TEST_FORWARDING_COLLECTION @@ -20,7 +20,7 @@ {}, { "filter": {}, - "limit": None, + "limit": 0, "projection": None, "sort": [ ("_source.timestamp", DESCENDING), @@ -34,7 +34,7 @@ {"statementId": "statementId"}, { "filter": {"_source.id": "statementId"}, - "limit": None, + "limit": 0, "projection": None, "sort": [ ("_source.timestamp", DESCENDING), @@ -51,7 +51,7 @@ "_source.id": "statementId", "_source.actor.mbox": "mailto:foo@bar.baz", }, - "limit": None, + "limit": 0, "projection": None, "sort": [ ("_source.timestamp", DESCENDING), @@ -73,7 +73,7 @@ "a7a5b7462b862c8c8767d43d43e865ffff754a64" ), }, - "limit": None, + "limit": 0, "projection": None, "sort": [ ("_source.timestamp", DESCENDING), @@ -93,7 +93,7 @@ "_source.id": "statementId", "_source.actor.openid": "http://toby.openid.example.org/", }, - "limit": None, + "limit": 0, "projection": None, "sort": [ ("_source.timestamp", DESCENDING), @@ -117,7 +117,7 @@ "_source.actor.account.name": "13936749", "_source.actor.account.homePage": "http://www.example.com", }, - "limit": None, + "limit": 0, "projection": None, "sort": [ ("_source.timestamp", DESCENDING), @@ -138,7 +138,7 @@ "_source.object.id": "http://www.example.com/meetings/34534", "_source.object.objectType": "Activity", }, - "limit": None, + "limit": 0, "projection": None, "sort": [ ("_source.timestamp", DESCENDING), @@ -160,7 +160,7 @@ "$lte": "2023-06-24T00:00:20.194929+00:00", }, }, - "limit": None, + "limit": 0, "projection": None, "sort": [ ("_source.timestamp", DESCENDING), @@ -180,7 +180,7 @@ "$lte": "2023-06-24T00:00:20.194929+00:00", }, }, - "limit": None, + "limit": 0, "projection": None, "sort": [ ("_source.timestamp", DESCENDING), @@ -196,7 +196,7 @@ "filter": { "_id": {"$lt": ObjectId("666f6f2d6261722d71757578")}, }, - "limit": None, + "limit": 0, "projection": None, "sort": [ ("_source.timestamp", DESCENDING), @@ -212,7 +212,7 @@ "filter": { "_id": {"$gt": ObjectId("666f6f2d6261722d71757578")}, }, - "limit": None, + "limit": 0, "projection": None, "sort": [ ("_source.timestamp", 
ASCENDING), @@ -239,7 +239,7 @@ async def mock_read(query, chunk_size): backend = async_mongo_lrs_backend() monkeypatch.setattr(backend, "read", mock_read) - result = await backend.query_statements(StatementParameters(**params)) + result = await backend.query_statements(RalphStatementsQuery.construct(**params)) assert result.statements == [{}] assert not result.pit_id assert result.search_after == "search_after_id" @@ -270,8 +270,8 @@ async def test_backends_lrs_async_mongo_lrs_backend_query_statements_with_succes ] assert await backend.write(documents) == 2 - statement_parameters = StatementParameters( - statementId="62b9ce922c26b46b68ffc68f", + statement_parameters = RalphStatementsQuery.construct( + statement_id="62b9ce922c26b46b68ffc68f", agent={ "account__name": "test_name", "account__home_page": "http://example.com", @@ -312,7 +312,7 @@ async def mock_read(**_): with caplog.at_level(logging.ERROR): with pytest.raises(BackendException, match=msg): - await backend.query_statements(StatementParameters()) + await backend.query_statements(RalphStatementsQuery.construct()) assert ( "ralph.backends.lrs.async_mongo", @@ -345,7 +345,7 @@ async def mock_read(**_): _ = [ statement async for statement in backend.query_statements_by_ids( - StatementParameters() + RalphStatementsQuery.construct() ) ] diff --git a/tests/backends/lrs/test_clickhouse.py b/tests/backends/lrs/test_clickhouse.py index e08dfa9df..c7d4fb8a4 100644 --- a/tests/backends/lrs/test_clickhouse.py +++ b/tests/backends/lrs/test_clickhouse.py @@ -7,7 +7,7 @@ import pytest from clickhouse_connect.driver.exceptions import ClickHouseError -from ralph.backends.lrs.base import StatementParameters +from ralph.backends.lrs.base import RalphStatementsQuery from ralph.exceptions import BackendException @@ -19,8 +19,15 @@ {}, { "where": [], - "params": {"format": "exact"}, - "limit": None, + "params": { + "ascending": False, + "attachments": False, + "format": "exact", + "limit": 0, + "related_activities": False, + "related_agents": False, + }, + "limit": 0, "sort": "emission_time DESCENDING, event_id DESCENDING", }, ), @@ -29,8 +36,17 @@ {"statementId": "test_id"}, { "where": ["event_id = {statementId:UUID}"], - "params": {"statementId": "test_id", "format": "exact"}, - "limit": None, + "params": { + "ascending": False, + "attachments": False, + "format": "exact", + "limit": 0, + "related_activities": False, + "related_agents": False, + "statementId": "test_id", + "statement_id": "test_id", + }, + "limit": 0, "sort": "emission_time DESCENDING, event_id DESCENDING", }, ), @@ -43,11 +59,17 @@ "event.actor.mbox = {actor__mbox:String}", ], "params": { - "statementId": "test_id", "actor__mbox": "mailto:foo@bar.baz", + "ascending": False, + "attachments": False, "format": "exact", + "limit": 0, + "related_activities": False, + "related_agents": False, + "statementId": "test_id", + "statement_id": "test_id", }, - "limit": None, + "limit": 0, "sort": "emission_time DESCENDING, event_id DESCENDING", }, ), @@ -63,11 +85,17 @@ "event.actor.mbox_sha1sum = {actor__mbox_sha1sum:String}", ], "params": { - "statementId": "test_id", "actor__mbox_sha1sum": "a7a5b7462b862c8c8767d43d43e865ffff754a64", + "ascending": False, + "attachments": False, "format": "exact", + "limit": 0, + "related_activities": False, + "related_agents": False, + "statementId": "test_id", + "statement_id": "test_id", }, - "limit": None, + "limit": 0, "sort": "emission_time DESCENDING, event_id DESCENDING", }, ), @@ -83,11 +111,17 @@ "event.actor.openid = {actor__openid:String}", ], 
"params": { - "statementId": "test_id", "actor__openid": "http://toby.openid.example.org/", + "ascending": False, + "attachments": False, "format": "exact", + "limit": 0, + "related_activities": False, + "related_agents": False, + "statementId": "test_id", + "statement_id": "test_id", }, - "limit": None, + "limit": 0, "sort": "emission_time DESCENDING, event_id DESCENDING", }, ), @@ -105,16 +139,21 @@ "where": [ "event_id = {statementId:UUID}", "event.actor.account.name = {actor__account__name:String}", - "event.actor.account.homePage = {actor__account_home_page:String}", + "event.actor.account.homePage = {actor__account__home_page:String}", ], "params": { - "statementId": "test_id", "actor__account__name": "13936749", - "actor__account_home_page": "http://www.example.com", + "actor__account__home_page": "http://www.example.com", "ascending": True, + "attachments": False, "format": "exact", + "limit": 0, + "related_activities": False, + "related_agents": False, + "statementId": "test_id", + "statement_id": "test_id", }, - "limit": None, + "limit": 0, "sort": "emission_time ASCENDING, event_id ASCENDING", }, ), @@ -132,10 +171,14 @@ "event.object.id = {activity:String}", ], "params": { - "verb": "http://adlnet.gov/expapi/verbs/attended", + "ascending": False, "activity": "http://www.example.com/meetings/34534", - "limit": 100, + "attachments": False, "format": "exact", + "limit": 100, + "related_activities": False, + "related_agents": False, + "verb": "http://adlnet.gov/expapi/verbs/attended", }, "limit": 100, "sort": "emission_time DESCENDING, event_id DESCENDING", @@ -153,15 +196,20 @@ "emission_time <= {until:DateTime64(6)}", ], "params": { + "ascending": False, + "attachments": False, + "format": "exact", + "limit": 0, + "related_activities": False, + "related_agents": False, "since": datetime( 2021, 6, 24, 0, 0, 20, 194929, tzinfo=timezone.utc - ), + ).isoformat(), "until": datetime( 2023, 6, 24, 0, 0, 20, 194929, tzinfo=timezone.utc - ), - "format": "exact", + ).isoformat(), }, - "limit": None, + "limit": 0, "sort": "emission_time DESCENDING, event_id DESCENDING", }, ), @@ -179,17 +227,22 @@ ), ], "params": { - "search_after": "1686557542970|0", - "pit_id": "46ToAwMDaWR5BXV1a", + "ascending": False, + "attachments": False, "format": "exact", + "limit": 0, + "pit_id": "46ToAwMDaWR5BXV1a", + "related_activities": False, + "related_agents": False, + "search_after": "1686557542970|0", }, - "limit": None, + "limit": 0, "sort": "emission_time DESCENDING, event_id DESCENDING", }, ), ], ) -def test_backends_database_clickhouse_query_statements( +def test_backends_database_clickhouse_query_statements_query( params, expected_params, monkeypatch, @@ -216,8 +269,7 @@ def mock_read(query, target, ignore_errors): backend = clickhouse_lrs_backend() monkeypatch.setattr(backend, "read", mock_read) - - backend.query_statements(StatementParameters(**params)) + backend.query_statements(RalphStatementsQuery.construct(**params)) backend.close() @@ -249,7 +301,7 @@ def test_backends_lrs_clickhouse_lrs_backend_query_statements( # Check the expected search query results. result = backend.query_statements( - StatementParameters(statementId=test_id, limit=10) + RalphStatementsQuery.construct(statementId=test_id, limit=10) ) assert result.statements == statements backend.close() @@ -279,7 +331,7 @@ def test_backends_lrs_clickhouse_lrs_backend__find(clickhouse, clickhouse_lrs_ba assert success == 1 # Check the expected search query results. 
- result = backend.query_statements(StatementParameters()) + result = backend.query_statements(RalphStatementsQuery.construct()) assert result.statements == statements backend.close() @@ -312,7 +364,7 @@ def test_backends_lrs_clickhouse_lrs_backend_query_statements_by_ids( # Check the expected search query results. result = list(backend.query_statements_by_ids([test_id])) - assert result[0]["event"] == statements[0] + assert result[0] == statements[0] backend.close() @@ -335,7 +387,7 @@ def mock_query(*args, **kwargs): msg = "Failed to read documents: Query error" with pytest.raises(BackendException, match=msg): - next(backend.query_statements(StatementParameters())) + next(backend.query_statements(RalphStatementsQuery.construct())) assert ( "ralph.backends.lrs.clickhouse", diff --git a/tests/backends/lrs/test_es.py b/tests/backends/lrs/test_es.py index 91bb56f31..151ae3af3 100644 --- a/tests/backends/lrs/test_es.py +++ b/tests/backends/lrs/test_es.py @@ -2,14 +2,13 @@ import logging import re -from datetime import datetime import pytest from elastic_transport import ApiResponseMeta from elasticsearch import ApiError from elasticsearch.helpers import bulk -from ralph.backends.lrs.base import StatementParameters +from ralph.backends.lrs.base import RalphStatementsQuery from ralph.exceptions import BackendException from tests.fixtures.backends import ES_TEST_FORWARDING_INDEX, ES_TEST_INDEX @@ -26,7 +25,7 @@ "query": {"match_all": {}}, "query_string": None, "search_after": None, - "size": None, + "size": 0, "sort": [{"timestamp": {"order": "desc"}}], "track_total_hits": False, }, @@ -39,7 +38,7 @@ "query": {"bool": {"filter": [{"term": {"_id": "statementId"}}]}}, "query_string": None, "search_after": None, - "size": None, + "size": 0, "sort": [{"timestamp": {"order": "desc"}}], "track_total_hits": False, }, @@ -59,7 +58,7 @@ }, "query_string": None, "search_after": None, - "size": None, + "size": 0, "sort": [{"timestamp": {"order": "desc"}}], "track_total_hits": False, }, @@ -88,7 +87,7 @@ }, "query_string": None, "search_after": None, - "size": None, + "size": 0, "sort": [{"timestamp": {"order": "desc"}}], "track_total_hits": False, }, @@ -117,7 +116,7 @@ }, "query_string": None, "search_after": None, - "size": None, + "size": 0, "sort": [{"timestamp": {"order": "desc"}}], "track_total_hits": False, }, @@ -150,7 +149,7 @@ }, "query_string": None, "search_after": None, - "size": None, + "size": 0, "sort": [{"timestamp": {"order": "desc"}}], "track_total_hits": False, }, @@ -186,7 +185,7 @@ }, "query_string": None, "search_after": None, - "size": None, + "size": 0, "sort": [{"timestamp": {"order": "desc"}}], "track_total_hits": False, }, @@ -205,18 +204,14 @@ { "range": { "timestamp": { - "gt": datetime.fromisoformat( - "2021-06-24T00:00:20.194929+00:00" - ) + "gt": "2021-06-24T00:00:20.194929+00:00" } } }, { "range": { "timestamp": { - "lte": datetime.fromisoformat( - "2023-06-24T00:00:20.194929+00:00" - ) + "lte": "2023-06-24T00:00:20.194929+00:00" } } }, @@ -225,7 +220,7 @@ }, "query_string": None, "search_after": None, - "size": None, + "size": 0, "sort": [{"timestamp": {"order": "desc"}}], "track_total_hits": False, }, @@ -238,7 +233,7 @@ "query": {"match_all": {}}, "query_string": None, "search_after": ["1686557542970", "0"], - "size": None, + "size": 0, "sort": [{"timestamp": {"order": "desc"}}], "track_total_hits": False, }, @@ -251,7 +246,7 @@ "query": {"match_all": {}}, "query_string": None, "search_after": None, - "size": None, + "size": 0, "sort": "_shard_doc", "track_total_hits": 
False, }, @@ -275,7 +270,7 @@ def mock_read(query, chunk_size): backend = es_lrs_backend() monkeypatch.setattr(backend, "read", mock_read) - result = backend.query_statements(StatementParameters(**params)) + result = backend.query_statements(RalphStatementsQuery.construct(**params)) assert not result.statements assert result.pit_id == "foo_pit_id" assert result.search_after == "bar_search_after|baz_search_after" @@ -295,7 +290,7 @@ def test_backends_lrs_es_lrs_backend_query_statements(es, es_lrs_backend): assert backend.write(documents) == 1 # Check the expected search query results. - result = backend.query_statements(StatementParameters(limit=10)) + result = backend.query_statements(RalphStatementsQuery.construct(limit=10)) assert result.statements == documents assert re.match(r"[0-9]+\|0", result.search_after) @@ -320,7 +315,7 @@ def mock_read(**_): msg = "Query error" with pytest.raises(BackendException, match=msg): with caplog.at_level(logging.ERROR): - backend.query_statements(StatementParameters()) + backend.query_statements(RalphStatementsQuery.construct()) assert ( "ralph.backends.lrs.es", @@ -349,7 +344,7 @@ def mock_search(**_): msg = r"Failed to execute Elasticsearch query: ApiError\(None, 'Query error'\)" with pytest.raises(BackendException, match=msg): with caplog.at_level(logging.ERROR): - list(backend.query_statements_by_ids(StatementParameters())) + list(backend.query_statements_by_ids(RalphStatementsQuery.construct())) assert ( "ralph.backends.lrs.es", diff --git a/tests/backends/lrs/test_fs.py b/tests/backends/lrs/test_fs.py index 2a3968719..b64bd518f 100644 --- a/tests/backends/lrs/test_fs.py +++ b/tests/backends/lrs/test_fs.py @@ -2,7 +2,7 @@ import pytest -from ralph.backends.lrs.base import StatementParameters +from ralph.backends.lrs.base import RalphStatementsQuery @pytest.mark.parametrize( @@ -260,7 +260,7 @@ def test_backends_lrs_fs_lrs_backend_query_statements_query( ] backend = fs_lrs_backend() backend.write(statements) - result = backend.query_statements(StatementParameters(**params)) + result = backend.query_statements(RalphStatementsQuery.construct(**params)) ids = [statement.get("id") for statement in result.statements] assert ids == expected_statement_ids diff --git a/tests/backends/lrs/test_mongo.py b/tests/backends/lrs/test_mongo.py index 85edc3f0d..612b9c0a7 100644 --- a/tests/backends/lrs/test_mongo.py +++ b/tests/backends/lrs/test_mongo.py @@ -6,7 +6,7 @@ from bson.objectid import ObjectId from pymongo import ASCENDING, DESCENDING -from ralph.backends.lrs.base import StatementParameters +from ralph.backends.lrs.base import AgentParameters, RalphStatementsQuery from ralph.exceptions import BackendException from tests.fixtures.backends import MONGO_TEST_FORWARDING_COLLECTION @@ -20,7 +20,7 @@ {}, { "filter": {}, - "limit": None, + "limit": 0, "projection": None, "sort": [ ("_source.timestamp", DESCENDING), @@ -34,7 +34,7 @@ {"statementId": "statementId"}, { "filter": {"_source.id": "statementId"}, - "limit": None, + "limit": 0, "projection": None, "sort": [ ("_source.timestamp", DESCENDING), @@ -51,7 +51,7 @@ "_source.id": "statementId", "_source.actor.mbox": "mailto:foo@bar.baz", }, - "limit": None, + "limit": 0, "projection": None, "sort": [ ("_source.timestamp", DESCENDING), @@ -73,7 +73,7 @@ "a7a5b7462b862c8c8767d43d43e865ffff754a64" ), }, - "limit": None, + "limit": 0, "projection": None, "sort": [ ("_source.timestamp", DESCENDING), @@ -93,7 +93,7 @@ "_source.id": "statementId", "_source.actor.openid": "http://toby.openid.example.org/", }, - 
"limit": None, + "limit": 0, "projection": None, "sort": [ ("_source.timestamp", DESCENDING), @@ -117,7 +117,7 @@ "_source.actor.account.name": "13936749", "_source.actor.account.homePage": "http://www.example.com", }, - "limit": None, + "limit": 0, "projection": None, "sort": [ ("_source.timestamp", DESCENDING), @@ -138,7 +138,7 @@ "_source.object.id": "http://www.example.com/meetings/34534", "_source.object.objectType": "Activity", }, - "limit": None, + "limit": 0, "projection": None, "sort": [ ("_source.timestamp", DESCENDING), @@ -160,7 +160,7 @@ "$lte": "2023-06-24T00:00:20.194929+00:00", }, }, - "limit": None, + "limit": 0, "projection": None, "sort": [ ("_source.timestamp", DESCENDING), @@ -180,7 +180,7 @@ "$lte": "2023-06-24T00:00:20.194929+00:00", }, }, - "limit": None, + "limit": 0, "projection": None, "sort": [ ("_source.timestamp", DESCENDING), @@ -196,7 +196,7 @@ "filter": { "_id": {"$lt": ObjectId("666f6f2d6261722d71757578")}, }, - "limit": None, + "limit": 0, "projection": None, "sort": [ ("_source.timestamp", DESCENDING), @@ -212,7 +212,7 @@ "filter": { "_id": {"$gt": ObjectId("666f6f2d6261722d71757578")}, }, - "limit": None, + "limit": 0, "projection": None, "sort": [ ("_source.timestamp", ASCENDING), @@ -238,7 +238,7 @@ def mock_read(query, chunk_size): backend = mongo_lrs_backend() monkeypatch.setattr(backend, "read", mock_read) - result = backend.query_statements(StatementParameters(**params)) + result = backend.query_statements(RalphStatementsQuery.construct(**params)) assert result.statements == [{}] assert not result.pit_id assert result.search_after == "search_after_id" @@ -258,7 +258,7 @@ def test_backends_lrs_mongo_lrs_backend_query_statements_with_success( timestamp = {"timestamp": "2022-06-27T15:36:50"} meta = { "actor": {"account": {"name": "test_name", "homePage": "http://example.com"}}, - "verb": {"id": "verb_id"}, + "verb": {"id": "https://xapi-example.com/verb-id"}, "object": {"id": "http://example.com", "objectType": "Activity"}, } documents = [ @@ -267,13 +267,13 @@ def test_backends_lrs_mongo_lrs_backend_query_statements_with_success( ] assert backend.write(documents) == 2 - statement_parameters = StatementParameters( + statement_parameters = RalphStatementsQuery.construct( statementId="62b9ce922c26b46b68ffc68f", - agent={ - "account__name": "test_name", - "account__home_page": "http://example.com", - }, - verb="verb_id", + agent=AgentParameters.construct( + account__name="test_name", + account__home_page="http://example.com", + ), + verb="https://xapi-example.com/verb-id", activity="http://example.com", since="2020-01-01T00:00:00.000000+00:00", until="2022-12-01T15:36:50", @@ -309,7 +309,7 @@ def mock_read(**_): with caplog.at_level(logging.ERROR): with pytest.raises(BackendException, match=msg): - backend.query_statements(StatementParameters()) + backend.query_statements(RalphStatementsQuery.construct()) assert ( "ralph.backends.lrs.mongo", @@ -339,7 +339,7 @@ def mock_read(**_): with caplog.at_level(logging.ERROR): with pytest.raises(BackendException, match=msg): - list(backend.query_statements_by_ids(StatementParameters())) + list(backend.query_statements_by_ids(RalphStatementsQuery.construct())) assert ( "ralph.backends.lrs.mongo", diff --git a/tests/fixtures/backends.py b/tests/fixtures/backends.py index 1acf1bf61..b8f0bf9d1 100644 --- a/tests/fixtures/backends.py +++ b/tests/fixtures/backends.py @@ -96,27 +96,31 @@ def get_clickhouse_test_backend(): """Return a ClickHouseLRSBackend backend instance using test defaults.""" - return 
ClickHouseLRSBackend( - host=CLICKHOUSE_TEST_HOST, - port=CLICKHOUSE_TEST_PORT, - database=CLICKHOUSE_TEST_DATABASE, - event_table_name=CLICKHOUSE_TEST_TABLE_NAME, + settings = ClickHouseLRSBackend.settings_class( + HOST=CLICKHOUSE_TEST_HOST, + PORT=CLICKHOUSE_TEST_PORT, + DATABASE=CLICKHOUSE_TEST_DATABASE, + EVENT_TABLE_NAME=CLICKHOUSE_TEST_TABLE_NAME, ) + return ClickHouseLRSBackend(settings) @lru_cache def get_es_test_backend(): - """Return a ESLRSBackend backend instance using test defaults.""" - return ESLRSBackend(hosts=ES_TEST_HOSTS, index=ES_TEST_INDEX) + """Returns a ESLRSBackend backend instance using test defaults.""" + settings = ESLRSBackend.settings_class( + HOSTS=ES_TEST_HOSTS, DEFAULT_INDEX=ES_TEST_INDEX + ) + return ESLRSBackend(settings) @lru_cache def get_mongo_test_backend(): """Returns a MongoDatabase backend instance using test defaults.""" settings = MongoLRSBackend.settings_class( - connection_uri=MONGO_TEST_CONNECTION_URI, - database=MONGO_TEST_DATABASE, - collection=MONGO_TEST_COLLECTION, + CONNECTION_URI=MONGO_TEST_CONNECTION_URI, + DEFAULT_DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=MONGO_TEST_COLLECTION, ) return MongoLRSBackend(settings) From 7205ea39a84b2ac7f4f91a3317fb1bbb61433a0c Mon Sep 17 00:00:00 2001 From: Wilfried BARADAT Date: Sun, 10 Sep 2023 21:26:50 +0200 Subject: [PATCH 29/65] =?UTF-8?q?=F0=9F=8F=97=EF=B8=8F(backends)=20integra?= =?UTF-8?q?te=20async=20backends=20in=20API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit With the addition of unified backends, the API router needs some changes to be able to use asynchronous backends. --- src/ralph/api/routers/health.py | 2 +- src/ralph/api/routers/statements.py | 15 ++++++++++----- src/ralph/utils.py | 11 +++++++++-- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/src/ralph/api/routers/health.py b/src/ralph/api/routers/health.py index bb51551c8..5f30b2847 100644 --- a/src/ralph/api/routers/health.py +++ b/src/ralph/api/routers/health.py @@ -9,7 +9,7 @@ from ralph.backends.conf import backends_settings from ralph.backends.lrs.base import BaseAsyncLRSBackend, BaseLRSBackend from ralph.conf import settings -from ralph.utils import get_backend_instance +from ralph.utils import await_if_coroutine, get_backend_instance logger = logging.getLogger(__name__) diff --git a/src/ralph/api/routers/statements.py b/src/ralph/api/routers/statements.py index 2fc82e3e4..49a3433cb 100644 --- a/src/ralph/api/routers/statements.py +++ b/src/ralph/api/routers/statements.py @@ -42,7 +42,12 @@ BaseXapiAgentWithOpenId, ) from ralph.models.xapi.base.common import IRI -from ralph.utils import get_backend_instance, now, statements_are_equivalent +from ralph.utils import ( + await_if_coroutine, + get_backend_instance, + now, + statements_are_equivalent, +) logger = logging.getLogger(__name__) @@ -456,8 +461,8 @@ async def put( # For valid requests, perform the bulk indexing of all incoming statements try: - success_count = BACKEND_CLIENT.write( - data=[statement_as_dict], ignore_errors=False + success_count = await await_if_coroutine( + BACKEND_CLIENT.write(data=[statement_as_dict], ignore_errors=False) ) except (BackendException, BadFormatException) as exc: logger.error("Failed to index submitted statement") @@ -560,8 +565,8 @@ async def post( # For valid requests, perform the bulk indexing of all incoming statements try: - success_count = await await_if_coroutine( + 
BACKEND_CLIENT.write(data=statements_dict.values(), ignore_errors=False) ) except (BackendException, BadFormatException) as exc: logger.error("Failed to index submitted statements") diff --git a/src/ralph/utils.py b/src/ralph/utils.py index 40bb33dfb..3a2e476c9 100644 --- a/src/ralph/utils.py +++ b/src/ralph/utils.py @@ -7,8 +7,8 @@ import operator from functools import reduce from importlib import import_module -from inspect import getmembers, isclass -from typing import Any, Dict, Iterable, Iterator, List, Optional, Union +from inspect import getmembers, isclass, iscoroutine +from typing import Any, Dict, Iterable, Iterator, List, Union from pydantic import BaseModel @@ -266,3 +266,10 @@ def wrapper(*args, **kwargs): loop.run_until_complete(method(*args, **kwargs)) return wrapper + + +async def await_if_coroutine(value): + """Await the value if it is a coroutine, else return synchronously.""" + if iscoroutine(value): + return await value + return value From 72db48163f0253d0d96bab7c82dd07e3e3c34f91 Mon Sep 17 00:00:00 2001 From: Quitterie Lucas Date: Wed, 11 Oct 2023 16:11:58 +0200 Subject: [PATCH 30/65] =?UTF-8?q?=F0=9F=90=9B(test)=20fix=20pyfakefs=20fai?= =?UTF-8?q?lure=20of=20file=20creation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tests using the filesystem failed with pyfakefs in the CI, as pyfakefs does not succeed in creating the requested files in the default directory path. The latter is therefore defined specifically for these tests and forced to be used in the ralph command. --- tests/test_cli.py | 26 ++++++++++++++++++-------- tests/test_logger.py | 6 ++++-- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index 5fec60bee..6a303884c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -12,7 +12,7 @@ from pydantic import ValidationError from ralph.backends.conf import backends_settings -from ralph.backends.data.fs import FSDataBackend, FSDataBackendSettings +from ralph.backends.data.fs import FSDataBackend from ralph.backends.data.ldp import LDPDataBackend from ralph.cli import ( CommaSeparatedKeyValueParamType, @@ -829,36 +829,46 @@ def mock_list(this, target=None, details=False, new=False): def test_cli_write_command_with_fs_backend(fs): """Test the write command using the FS backend.""" fs.create_dir(str(settings.APP_DIR)) + fs.create_dir("foo") - filename = Path("file1") - file_path = Path(FSDataBackendSettings().DEFAULT_DIRECTORY_PATH) / filename + filename = Path("foo/file1") # Create a file runner = CliRunner() - result = runner.invoke(cli, "write -b fs -t file1".split(), input=b"test content") + result = runner.invoke( + cli, + "write -b fs -t file1 --fs-default-directory-path foo".split(), + input=b"test content", + ) assert result.exit_code == 0 - with file_path.open("rb") as test_file: + with filename.open("rb") as test_file: content = test_file.read() assert b"test content" in content # Trying to create the same file without -f should raise an error runner = CliRunner() - result = runner.invoke(cli, "write -b fs -t file1".split(), input=b"other content") + result = runner.invoke( + cli, + "write -b fs -t file1 --fs-default-directory-path foo".split(), + input=b"other content", + ) assert result.exit_code == 1 assert "file1 already exists and overwrite is not allowed" in result.output # Try to create the same file with -f runner = CliRunner() result = runner.invoke( - cli, "write -b fs -t file1 -f".split(), input=b"other content" + cli, "write -b fs -t file1 -f 
--fs-default-directory-path foo".split(), + input=b"other content", ) assert result.exit_code == 0 - with file_path.open("rb") as test_file: + with filename.open("rb") as test_file: content = test_file.read() assert b"other content" in content diff --git a/tests/test_logger.py b/tests/test_logger.py index 17625e6df..1112fd6d4 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -5,6 +5,7 @@ import ralph.logger from ralph.cli import cli +from ralph.conf import settings from ralph.exceptions import ConfigurationException @@ -35,14 +36,15 @@ def test_logger_exists(fs, monkeypatch): }, } - fs.create_dir("/dev") + fs.create_dir(str(settings.APP_DIR)) + fs.create_dir("foo") monkeypatch.setattr(ralph.logger.settings, "LOGGING", mock_default_config) runner = CliRunner() result = runner.invoke( cli, - ["write", "-b", "fs", "-t", "test_file"], + ["write", "-b", "fs", "-t", "test_file", "--fs-default-directory-path", "foo"], input="test input", ) From 95e856a7fa1ff321b4351e88bd4d1674c9869f87 Mon Sep 17 00:00:00 2001 From: Wilfried BARADAT Date: Thu, 12 Oct 2023 09:07:02 +0200 Subject: [PATCH 31/65] =?UTF-8?q?=F0=9F=90=9B(tray)=20fix=20elasticsearch?= =?UTF-8?q?=20env=20variables?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Environment variables `RALPH_BACKENDS__DATABASE__ES__*` have been renamed to `RALPH_BACKENDS__DATA__ES__*`. Changing them in the `tray`. --- Makefile | 2 +- src/helm/ralph/templates/cronjob.yaml | 2 +- src/helm/ralph/vault.yaml | 4 ++-- src/tray/templates/services/app/deploy.yml.j2 | 2 +- src/tray/vars/vault/main.yml.j2 | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/Makefile b/Makefile index 1802b28fa..7db49b6a3 100644 --- a/Makefile +++ b/Makefile @@ -79,7 +79,7 @@ arnold-bootstrap: \ $(ARNOLD) -d -c $(ARNOLD_CUSTOMER) -e $(ARNOLD_ENVIRONMENT) -a $(ARNOLD_APP) create_app_vaults && \ $(ARNOLD) -d -c $(ARNOLD_CUSTOMER) -e $(ARNOLD_ENVIRONMENT) -a elasticsearch create_app_vaults && \ $(ARNOLD) -d -c $(ARNOLD_CUSTOMER) -e $(ARNOLD_ENVIRONMENT) -- vault -a $(ARNOLD_APP) decrypt - sed -i 's/^# RALPH_BACKENDS__DATABASE__ES/RALPH_BACKENDS__DATABASE__ES/g' group_vars/customer/$(ARNOLD_CUSTOMER)/$(ARNOLD_ENVIRONMENT)/secrets/$(ARNOLD_APP).vault.yml + sed -i 's/^# RALPH_BACKENDS__DATA__ES/RALPH_BACKENDS__DATA__ES/g' group_vars/customer/$(ARNOLD_CUSTOMER)/$(ARNOLD_ENVIRONMENT)/secrets/$(ARNOLD_APP).vault.yml source .k3d-cluster.env.sh && \ $(ARNOLD) -d -c $(ARNOLD_CUSTOMER) -e $(ARNOLD_ENVIRONMENT) -- vault -a $(ARNOLD_APP) encrypt echo "skip_verification: True" > $(ARNOLD_APP_VARS) diff --git a/src/helm/ralph/templates/cronjob.yaml b/src/helm/ralph/templates/cronjob.yaml index a608d0328..811549138 100644 --- a/src/helm/ralph/templates/cronjob.yaml +++ b/src/helm/ralph/templates/cronjob.yaml @@ -52,7 +52,7 @@ spec: - name: RALPH_SENTRY_IGNORE_HEALTH_CHECKS value: "{{ .Values.sentryIgnoreHealthChecks }}" {{- if and .Values.elastic.enabled .Values.elastic.mountCACert }} - - name: RALPH_BACKENDS__DATABASE__ES__CLIENT_OPTIONS__ca_certs + - name: RALPH_BACKENDS__DATA__ES__CLIENT_OPTIONS__ca_certs value: "/usr/local/share/ca-certificates/ca.crt" {{- end }} envFrom: diff --git a/src/helm/ralph/vault.yaml b/src/helm/ralph/vault.yaml index e682464bd..ed122679e 100644 --- a/src/helm/ralph/vault.yaml +++ b/src/helm/ralph/vault.yaml @@ -1,5 +1,5 @@ -RALPH_BACKENDS__DATABASE__ES__HOSTS: http://elasticsearch:9200 -RALPH_BACKENDS__DATABASE__ES__INDEX: statements +RALPH_BACKENDS__DATA__ES__HOSTS: http://elasticsearch:9200 
+RALPH_BACKENDS__DATA__ES__INDEX: statements RALPH_SENTRY_DSN: https://fake@key.ingest.sentry.io/1234567 RALPH_EXECUTION_ENVIRONMENT: production diff --git a/src/tray/templates/services/app/deploy.yml.j2 b/src/tray/templates/services/app/deploy.yml.j2 index 22e94dd2a..3e9cbd8f2 100644 --- a/src/tray/templates/services/app/deploy.yml.j2 +++ b/src/tray/templates/services/app/deploy.yml.j2 @@ -77,7 +77,7 @@ spec: - name: RALPH_SENTRY_IGNORE_HEALTH_CHECKS value: "{{ ralph_sentry_ignore_health_checks }}" {% if ralph_mount_es_ca_secret %} - - name: RALPH_BACKENDS__DATABASE__ES__CLIENT_OPTIONS__ca_certs + - name: RALPH_BACKENDS__DATA__ES__CLIENT_OPTIONS__ca_certs value: "/usr/local/share/ca-certificates/es-cluster.pem" {% endif %} envFrom: diff --git a/src/tray/vars/vault/main.yml.j2 b/src/tray/vars/vault/main.yml.j2 index 85a61ae32..f0c28c23a 100644 --- a/src/tray/vars/vault/main.yml.j2 +++ b/src/tray/vars/vault/main.yml.j2 @@ -2,8 +2,8 @@ # env_type: {{ env_type }} # ES database backend -# RALPH_BACKENDS__DATABASE__ES__HOSTS: http://elasticsearch:9200 -# RALPH_BACKENDS__DATABASE__ES__INDEX: statements +# RALPH_BACKENDS__DATA__ES__HOSTS: http://elasticsearch:9200 +# RALPH_BACKENDS__DATA__ES__INDEX: statements # If you have self-generated a CA certificate for your ES cluster nodes, you may # also need this CA certificate to check certificates while requesting the From 9ff6dcd89ce41d40acec69ee2a0aac05bf867060 Mon Sep 17 00:00:00 2001 From: Wilfried BARADAT Date: Thu, 12 Oct 2023 10:07:13 +0200 Subject: [PATCH 32/65] =?UTF-8?q?=F0=9F=93=9D(project)=20update=20CHANGELO?= =?UTF-8?q?G.md?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changed - Refactor `database` and `storage` backends under the unified `data` backend interface [BC] - Refactor LRS `query_statements` and `query_statements_by_ids` backends methods under the unified `lrs` backend interface [BC] --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a75a1a41..149ad5fe7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,10 @@ and this project adheres to ### Changed +- Refactor `database` and `storage` backends under the unified `data` backend +interface [BC] +- Refactor LRS `query_statements` and `query_statements_by_ids` backends +methods under the unified `lrs` backend interface [BC] - Refactor LRS Statements resource query parameters defined for `ralph` API - Helm chart: improve chart modularity - User credentials must now include an "agent" field which can be created From 7cda5606105845d0240131e60296a91bd5d50759 Mon Sep 17 00:00:00 2001 From: Wilfried BARADAT Date: Thu, 12 Oct 2023 10:07:45 +0200 Subject: [PATCH 33/65] =?UTF-8?q?=F0=9F=9A=A8(backends)=20fix=20`too-many-?= =?UTF-8?q?arguments`=20pylint=20warning?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Following the `pylint` upgrade to version > 3.0, a false negative was corrected, resulting in many warnings about methods having too many arguments. Escaping these warnings.
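To make the escape concrete, here is a minimal sketch of the pattern applied to every backend below (the class and parameter list are hypothetical stand-ins, not Ralph code; only the placement of the in-body disable comment mirrors the actual change):

from typing import Iterator, Optional, Union


class DummyDataBackend:
    """Illustrative stand-in for a Ralph data backend."""

    def read(
        self,
        query: Optional[str] = None,
        target: Optional[str] = None,
        chunk_size: Optional[int] = None,
        raw_output: bool = False,
        ignore_errors: bool = False,
    ) -> Iterator[Union[bytes, dict]]:
        # pylint: disable=too-many-arguments
        """Read records matching the `query` in `target` and yield them."""
        # The disable comment lives in the method body, right after the
        # signature, so the warning is silenced for this method only.
        yield from ()

A file- or block-level disable would also work, but the in-body comment keeps the suppression scope as narrow as possible.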
--- src/ralph/backends/data/async_es.py | 1 + src/ralph/backends/data/async_mongo.py | 1 + src/ralph/backends/data/base.py | 2 ++ src/ralph/backends/data/clickhouse.py | 1 + src/ralph/backends/data/es.py | 1 + src/ralph/backends/data/fs.py | 1 + src/ralph/backends/data/ldp.py | 1 + src/ralph/backends/data/mongo.py | 1 + src/ralph/backends/data/s3.py | 1 + src/ralph/backends/data/swift.py | 1 + 10 files changed, 11 insertions(+) diff --git a/src/ralph/backends/data/async_es.py b/src/ralph/backends/data/async_es.py index f94d2a64c..f187717ac 100644 --- a/src/ralph/backends/data/async_es.py +++ b/src/ralph/backends/data/async_es.py @@ -119,6 +119,7 @@ async def read( raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: + # pylint: disable=too-many-arguments """Read documents matching the query in the target index and yield them. Args: diff --git a/src/ralph/backends/data/async_mongo.py b/src/ralph/backends/data/async_mongo.py index 8d2d99907..76d8954c4 100644 --- a/src/ralph/backends/data/async_mongo.py +++ b/src/ralph/backends/data/async_mongo.py @@ -124,6 +124,7 @@ async def read( raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: + # pylint: disable=too-many-arguments """Read documents matching the `query` from `target` collection and yield them. Args: diff --git a/src/ralph/backends/data/base.py b/src/ralph/backends/data/base.py index 00abbc5be..c5277d6cc 100644 --- a/src/ralph/backends/data/base.py +++ b/src/ralph/backends/data/base.py @@ -167,6 +167,7 @@ def read( raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: + # pylint: disable=too-many-arguments """Read records matching the `query` in the `target` container and yield them. Args: @@ -339,6 +340,7 @@ async def read( raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: + # pylint: disable=too-many-arguments """Read records matching the `query` in the `target` container and yield them. Args: diff --git a/src/ralph/backends/data/clickhouse.py b/src/ralph/backends/data/clickhouse.py index 66c659947..795e0c798 100755 --- a/src/ralph/backends/data/clickhouse.py +++ b/src/ralph/backends/data/clickhouse.py @@ -197,6 +197,7 @@ def read( raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: + # pylint: disable=too-many-arguments """Read documents matching the query in the target table and yield them. Args: diff --git a/src/ralph/backends/data/es.py b/src/ralph/backends/data/es.py index ece828bc5..2d1bbb02e 100644 --- a/src/ralph/backends/data/es.py +++ b/src/ralph/backends/data/es.py @@ -205,6 +205,7 @@ def read( raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: + # pylint: disable=too-many-arguments """Read documents matching the query in the target index and yield them. Args: diff --git a/src/ralph/backends/data/fs.py b/src/ralph/backends/data/fs.py index 8bba06374..85ce0fbf0 100644 --- a/src/ralph/backends/data/fs.py +++ b/src/ralph/backends/data/fs.py @@ -151,6 +151,7 @@ def read( raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: + # pylint: disable=too-many-arguments """Read files matching the query in the target folder and yield them. 
Args: diff --git a/src/ralph/backends/data/ldp.py b/src/ralph/backends/data/ldp.py index cfa8cf18d..e3cd499fb 100644 --- a/src/ralph/backends/data/ldp.py +++ b/src/ralph/backends/data/ldp.py @@ -156,6 +156,7 @@ def read( raw_output: bool = True, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: + # pylint: disable=too-many-arguments """Read an archive matching the query in the target stream_id and yield it. Args: diff --git a/src/ralph/backends/data/mongo.py b/src/ralph/backends/data/mongo.py index 432952678..05dd83789 100644 --- a/src/ralph/backends/data/mongo.py +++ b/src/ralph/backends/data/mongo.py @@ -180,6 +180,7 @@ def read( raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: + # pylint: disable=too-many-arguments """Read documents matching the `query` from `target` collection and yield them. Args: diff --git a/src/ralph/backends/data/s3.py b/src/ralph/backends/data/s3.py index c20521d80..22ce05573 100644 --- a/src/ralph/backends/data/s3.py +++ b/src/ralph/backends/data/s3.py @@ -163,6 +163,7 @@ def read( raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: + # pylint: disable=too-many-arguments """Read an object matching the `query` in the `target` bucket and yields it. Args: diff --git a/src/ralph/backends/data/swift.py b/src/ralph/backends/data/swift.py index 18516d570..e4846f848 100644 --- a/src/ralph/backends/data/swift.py +++ b/src/ralph/backends/data/swift.py @@ -168,6 +168,7 @@ def read( raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: + # pylint: disable=too-many-arguments """Read objects matching the `query` in the `target` container and yields them. Args: From ea0edc04e1778b1de6b3adc0537d0263400b7a45 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Mon, 16 Oct 2023 00:52:02 +0000 Subject: [PATCH 34/65] =?UTF-8?q?=E2=AC=86=EF=B8=8F(project)=20upgrade=20p?= =?UTF-8?q?ython=20dependencies?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit | datasource | package | from | to | | ---------- | --------------- | ------ | ------ | | pypi | hypothesis | 6.87.3 | 6.88.0 | | pypi | mkdocs-material | 9.4.4 | 9.4.6 | | pypi | moto | 4.2.5 | 4.2.6 | | pypi | pyfakefs | 5.2.4 | 5.3.0 | | pypi | sentry_sdk | 1.31.0 | 1.32.0 | --- setup.cfg | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/setup.cfg b/setup.cfg index 7885a82e4..d719bd36b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -75,16 +75,16 @@ dev = cryptography==41.0.4 factory-boy==3.3.0 flake8==6.1.0 - hypothesis==6.87.3 + hypothesis==6.88.0 isort==5.12.0 logging-gelf==0.0.31 mkdocs==1.5.3 mkdocs-click==0.8.1 - mkdocs-material==9.4.4 + mkdocs-material==9.4.6 mkdocstrings[python-legacy]==0.23.0 - moto==4.2.5 + moto==4.2.6 pydocstyle==6.3.0 - pyfakefs==5.2.4 + pyfakefs==5.3.0 pylint==3.0.1 pytest==7.4.2 pytest-asyncio==0.21.1 @@ -104,7 +104,7 @@ lrs = ; See: https://github.com/encode/httpx/issues/2244 h11>=0.11.0 httpx<0.25.0 # pin as Python 3.7 is no longer supported from release 0.25.0 - sentry_sdk==1.31.0 + sentry_sdk==1.32.0 python-jose==3.3.0 uvicorn[standard]==0.23.2 From a07e6d512b1d3bfdaf12a70d33182f3d9c195e42 Mon Sep 17 00:00:00 2001 From: Quitterie Lucas Date: Mon, 16 Oct 2023 10:31:45 +0200 Subject: [PATCH 35/65] =?UTF-8?q?=F0=9F=93=9D(project)=20update=20CHANGELO?= =?UTF-8?q?G.md?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changed: 
- Upgrade `sentry_sdk` to `1.32.0` --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 149ad5fe7..97118b8f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,7 +29,7 @@ have an authority field matching that of the user - CLI: change `push` to `write` and `fetch` to `read` [BC] - Upgrade `fastapi` to `0.103.2` - Upgrade `more-itertools` to `10.1.0` -- Upgrade `sentry_sdk` to `1.31.0` +- Upgrade `sentry_sdk` to `1.32.0` - Upgrade `uvicorn` to `0.23.2` - API: Invalid parameters now return 400 status code - API: Forwarding PUT now uses PUT (instead of POST) From a00b43949b58350fee1f40645c2f7b82e9f5a0cf Mon Sep 17 00:00:00 2001 From: SergioSim Date: Fri, 13 Oct 2023 12:32:29 +0200 Subject: [PATCH 36/65] =?UTF-8?q?=F0=9F=90=9B(backends)=20fix=20default=20?= =?UTF-8?q?RALPH=5FBACKENDS=5F=5FHTTP=5F=5FLRS=5F=5FHEADERS=20variable?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit It seems that pydantic ignores field aliases when building the Settings objects from environment variables. Thus we choose to use the field names for the RALPH_BACKENDS__HTTP__LRS__HEADERS variable. --- .env.dist | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.env.dist b/.env.dist index 55a38fd4f..73df46ee6 100644 --- a/.env.dist +++ b/.env.dist @@ -99,7 +99,8 @@ RALPH_BACKENDS__DATA__CLICKHOUSE__TEST_TABLE_NAME=test_xapi_events_all RALPH_BACKENDS__HTTP__LRS__BASE_URL=http://ralph:secret@0.0.0.0:8100/ RALPH_BACKENDS__HTTP__LRS__USERNAME=ralph RALPH_BACKENDS__HTTP__LRS__PASSWORD=secret -RALPH_BACKENDS__HTTP__LRS__HEADERS={"X-Experience-API-Version": "1.0.3", "content-type": "application/json"} +RALPH_BACKENDS__HTTP__LRS__HEADERS__X_EXPERIENCE_API_VERSION=1.0.3 +RALPH_BACKENDS__HTTP__LRS__HEADERS__CONTENT_TYPE=application/json RALPH_BACKENDS__HTTP__LRS__STATUS_ENDPOINT=/__heartbeat__ RALPH_BACKENDS__HTTP__LRS__STATEMENTS_ENDPOINT=/xAPI/statements From 95c6edeaa3b57fa6ba5c4a86da837709eb8a6ef0 Mon Sep 17 00:00:00 2001 From: lleeoo Date: Thu, 12 Oct 2023 12:09:04 +0200 Subject: [PATCH 37/65] =?UTF-8?q?=F0=9F=8E=A8(tests)=20clean=20tests=20by?= =?UTF-8?q?=20factoring=20statements=20and=20renaming?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mocking statements in tests now has its own function, to avoid duplicating large dict objects. Also, tests now have shorter names for clarity.
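As a rough sketch of the factored-out helper (a hedged reconstruction: the real implementation lives in tests/helpers.py and is not shown in this patch; the default field values mirror the literal dicts removed below):

from uuid import uuid4


def mock_statement(id_=None, timestamp="2022-06-22T08:31:38Z"):
    """Generate a dummy xAPI statement with optional `id_` and `timestamp`."""
    return {
        "actor": {
            "account": {
                "homePage": "https://example.com/homepage/",
                "name": str(uuid4()),
            },
            "objectType": "Agent",
        },
        # Callers may pin the statement id (e.g. to test duplicates) or let
        # the helper draw a fresh UUID.
        "id": id_ if id_ is not None else str(uuid4()),
        "object": {"id": "https://example.com/object-id/1/"},
        "timestamp": timestamp,
        "verb": {"id": "https://example.com/verb-id/1/"},
    }

Call sites below then read mock_statement(), mock_statement(timestamp="2022-03-15T14:07:51Z") or mock_statement(id_=statement_uuid) instead of repeating the same nine-line dict in every test.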
--- tests/api/test_statements_get.py | 78 +++++----- tests/api/test_statements_post.py | 206 ++++---------------------- tests/api/test_statements_put.py | 187 ++++------------------- tests/backends/http/test_async_lrs.py | 61 +++----- tests/fixtures/auth.py | 4 +- tests/helpers.py | 93 ++++++++++-- tests/test_helpers.py | 61 ++++++++ 7 files changed, 259 insertions(+), 431 deletions(-) diff --git a/tests/api/test_statements_get.py b/tests/api/test_statements_get.py index 9aa0ff742..56cdd2baf 100644 --- a/tests/api/test_statements_get.py +++ b/tests/api/test_statements_get.py @@ -28,8 +28,8 @@ get_mongo_test_backend, ) -from ..fixtures.auth import create_user -from ..helpers import create_mock_activity, create_mock_agent +from ..fixtures.auth import mock_basic_auth_user +from ..helpers import mock_activity, mock_agent client = TestClient(app) @@ -117,7 +117,7 @@ def _insert_statements_and_monkeypatch_backend(statements): "account_different_home_page", ], ) -def test_api_statements_get_statements_mine( +def test_api_statements_get_mine( monkeypatch, fs, insert_statements_and_monkeypatch_backend, ifi ): """(Security) Test that the get statements API route, given a "mine=True" @@ -128,27 +128,29 @@ def test_api_statements_get_statements_mine( # Create two distinct agents if ifi == "account_same_home_page": - agent_1 = create_mock_agent("account", 1, home_page_id=1) - agent_1_bis = create_mock_agent( + agent_1 = mock_agent("account", 1, home_page_id=1) + agent_1_bis = mock_agent( "account", 1, home_page_id=1, name="name", use_object_type=False ) - agent_2 = create_mock_agent("account", 2, home_page_id=1) + agent_2 = mock_agent("account", 2, home_page_id=1) elif ifi == "account_different_home_page": - agent_1 = create_mock_agent("account", 1, home_page_id=1) - agent_1_bis = create_mock_agent( + agent_1 = mock_agent("account", 1, home_page_id=1) + agent_1_bis = mock_agent( "account", 1, home_page_id=1, name="name", use_object_type=False ) - agent_2 = create_mock_agent("account", 1, home_page_id=2) + agent_2 = mock_agent("account", 1, home_page_id=2) else: - agent_1 = create_mock_agent(ifi, 1) - agent_1_bis = create_mock_agent(ifi, 1, name="name", use_object_type=False) - agent_2 = create_mock_agent(ifi, 2) + agent_1 = mock_agent(ifi, 1) + agent_1_bis = mock_agent(ifi, 1, name="name", use_object_type=False) + agent_2 = mock_agent(ifi, 2) username_1 = "jane" password_1 = "janepwd" scopes = [] - credentials_1_bis = create_user(fs, username_1, password_1, scopes, agent_1_bis) + credentials_1_bis = mock_basic_auth_user( + fs, username_1, password_1, scopes, agent_1_bis + ) # Clear cache before each test iteration get_authenticated_user.cache_clear() @@ -230,7 +232,7 @@ def test_api_statements_get_statements_mine( assert response.status_code == 422 -def test_api_statements_get_statements( +def test_api_statements_get( insert_statements_and_monkeypatch_backend, auth_credentials ): """Test the get statements API route without any filters set up.""" @@ -258,7 +260,7 @@ def test_api_statements_get_statements( assert response.json() == {"statements": [statements[1], statements[0]]} -def test_api_statements_get_statements_ascending( +def test_api_statements_get_ascending( insert_statements_and_monkeypatch_backend, auth_credentials ): """Test the get statements API route, given an "ascending" query parameter, should @@ -287,7 +289,7 @@ def test_api_statements_get_statements_ascending( assert response.json() == {"statements": [statements[0], statements[1]]} -def 
test_api_statements_get_statements_by_statement_id( +def test_api_statements_get_by_statement_id( insert_statements_and_monkeypatch_backend, auth_credentials ): """Test the get statements API route, given a "statementId" query parameter, should @@ -326,7 +328,7 @@ def test_api_statements_get_statements_by_statement_id( "account_different_home_page", ], ) -def test_api_statements_get_statements_by_agent( +def test_api_statements_get_by_agent( ifi, insert_statements_and_monkeypatch_backend, auth_credentials ): """Test the get statements API route, given an "agent" query parameter, should @@ -336,14 +338,14 @@ def test_api_statements_get_statements_by_agent( # Create two distinct agents if ifi == "account_same_home_page": - agent_1 = create_mock_agent("account", 1, home_page_id=1) - agent_2 = create_mock_agent("account", 2, home_page_id=1) + agent_1 = mock_agent("account", 1, home_page_id=1) + agent_2 = mock_agent("account", 2, home_page_id=1) elif ifi == "account_different_home_page": - agent_1 = create_mock_agent("account", 1, home_page_id=1) - agent_2 = create_mock_agent("account", 1, home_page_id=2) + agent_1 = mock_agent("account", 1, home_page_id=1) + agent_2 = mock_agent("account", 1, home_page_id=2) else: - agent_1 = create_mock_agent(ifi, 1) - agent_2 = create_mock_agent(ifi, 2) + agent_1 = mock_agent(ifi, 1) + agent_2 = mock_agent(ifi, 2) statements = [ { @@ -370,7 +372,7 @@ def test_api_statements_get_statements_by_agent( assert response.json() == {"statements": [statements[0]]} -def test_api_statements_get_statements_by_verb( +def test_api_statements_get_by_verb( insert_statements_and_monkeypatch_backend, auth_credentials ): """Test the get statements API route, given a "verb" query parameter, should @@ -401,7 +403,7 @@ def test_api_statements_get_statements_by_verb( assert response.json() == {"statements": [statements[1]]} -def test_api_statements_get_statements_by_activity( +def test_api_statements_get_by_activity( insert_statements_and_monkeypatch_backend, auth_credentials ): """Test the get statements API route, given an "activity" query parameter, should @@ -409,8 +411,8 @@ def test_api_statements_get_statements_by_activity( """ # pylint: disable=redefined-outer-name - activity_0 = create_mock_activity(0) - activity_1 = create_mock_activity(1) + activity_0 = mock_activity(0) + activity_1 = mock_activity(1) statements = [ { @@ -444,7 +446,7 @@ def test_api_statements_get_statements_by_activity( assert response.json()["detail"][0]["msg"] == "'INVALID_IRI' is not a valid 'IRI'." 
-def test_api_statements_get_statements_since_timestamp( +def test_api_statements_get_since_timestamp( insert_statements_and_monkeypatch_backend, auth_credentials ): """Test the get statements API route, given a "since" query parameter, should @@ -474,7 +476,7 @@ def test_api_statements_get_statements_since_timestamp( assert response.json() == {"statements": [statements[1]]} -def test_api_statements_get_statements_until_timestamp( +def test_api_statements_get_until_timestamp( insert_statements_and_monkeypatch_backend, auth_credentials ): """Test the get statements API route, given an "until" query parameter, @@ -504,7 +506,7 @@ def test_api_statements_get_statements_until_timestamp( assert response.json() == {"statements": [statements[0]]} -def test_api_statements_get_statements_with_pagination( +def test_api_statements_get_with_pagination( monkeypatch, insert_statements_and_monkeypatch_backend, auth_credentials ): """Test the get statements API route, given a request leading to more results than @@ -574,7 +576,7 @@ def test_api_statements_get_statements_with_pagination( assert third_response.json() == {"statements": [statements[0]]} -def test_api_statements_get_statements_with_pagination_and_query( +def test_api_statements_get_with_pagination_and_query( monkeypatch, insert_statements_and_monkeypatch_backend, auth_credentials ): """Test the get statements API route, given a request with a query parameter @@ -639,7 +641,7 @@ def test_api_statements_get_statements_with_pagination_and_query( assert second_response.json() == {"statements": [statements[0]]} -def test_api_statements_get_statements_with_no_matching_statement( +def test_api_statements_get_with_no_matching_statement( insert_statements_and_monkeypatch_backend, auth_credentials ): """Test the get statements API route, given a query yielding no matching statement, @@ -668,9 +670,7 @@ def test_api_statements_get_statements_with_no_matching_statement( assert response.json() == {"statements": []} -def test_api_statements_get_statements_with_database_query_failure( - auth_credentials, monkeypatch -): +def test_api_statements_get_with_database_query_failure(auth_credentials, monkeypatch): """Test the get statements API route, given a query raising a BackendException, should return an error response with HTTP code 500. 
""" @@ -694,9 +694,7 @@ def mock_query_statements(*_): @pytest.mark.parametrize("id_param", ["statementId", "voidedStatementId"]) -def test_api_statements_get_statements_invalid_query_parameters( - auth_credentials, id_param -): +def test_api_statements_get_invalid_query_parameters(auth_credentials, id_param): """Test error response for invalid query parameters""" id_1 = "be67b160-d958-4f51-b8b8-1892002dbac6" @@ -721,8 +719,8 @@ def test_api_statements_get_statements_invalid_query_parameters( # Check for 400 status code when invalid parameters are provided with a statementId for invalid_param, value in [ - ("activity", create_mock_activity()["id"]), - ("agent", json.dumps(create_mock_agent("mbox", 1))), + ("activity", mock_activity()["id"]), + ("agent", json.dumps(mock_agent("mbox", 1))), ("verb", "verb_1"), ]: response = client.get( diff --git a/tests/api/test_statements_post.py b/tests/api/test_statements_post.py index 5d9377979..350e4a11c 100644 --- a/tests/api/test_statements_post.py +++ b/tests/api/test_statements_post.py @@ -28,6 +28,7 @@ from ..helpers import ( assert_statement_get_responses_are_equivalent, + mock_statement, string_is_date, string_is_uuid, ) @@ -38,19 +39,7 @@ def test_api_statements_post_invalid_parameters(auth_credentials): """Test that using invalid parameters returns the proper status code.""" - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "id": str(uuid4()), - "object": {"id": "https://example.com/object-id/1/"}, - "timestamp": "2022-06-22T08:31:38Z", - "verb": {"id": "https://example.com/verb-id/1/"}, - } + statement = mock_statement() # Check for 400 status code when unknown parameters are provided response = client.post( @@ -76,19 +65,7 @@ def test_api_statements_post_single_statement_directly( # pylint: disable=invalid-name,unused-argument monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "id": str(uuid4()), - "object": {"id": "https://example.com/object-id/1/"}, - "timestamp": "2022-06-22T08:31:38Z", - "verb": {"id": "https://example.com/verb-id/1/"}, - } + statement = mock_statement() response = client.post( "/xAPI/statements/", @@ -184,17 +161,8 @@ def test_api_statements_post_enriching_with_existing_values( monkeypatch.setattr( "ralph.api.routers.statements.BACKEND_CLIENT", get_es_test_backend() ) - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "object": {"id": "https://example.com/object-id/1/"}, - "verb": {"id": "https://example.com/verb-id/1/"}, - } + statement = mock_statement() + # Add the field to be tested statement[field] = value @@ -236,19 +204,7 @@ def test_api_statements_post_single_statement_no_trailing_slash( # pylint: disable=invalid-name,unused-argument monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "id": str(uuid4()), - "object": {"id": "https://example.com/object-id/1/"}, - "timestamp": "2022-06-22T08:31:38Z", - "verb": {"id": "https://example.com/verb-id/1/"}, - } + statement = mock_statement() response = client.post( "/xAPI/statements", @@ -265,26 +221,14 @@ def 
test_api_statements_post_single_statement_no_trailing_slash( [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) # pylint: disable=too-many-arguments -def test_api_statements_post_statements_list_of_one( +def test_api_statements_post_list_of_one( backend, monkeypatch, auth_credentials, es, mongo, clickhouse ): """Test the post statements API route with one statement in a list.""" # pylint: disable=invalid-name,unused-argument monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "id": str(uuid4()), - "object": {"id": "https://example.com/object-id/1/"}, - "timestamp": "2022-03-15T14:07:51Z", - "verb": {"id": "https://example.com/verb-id/1/"}, - } + statement = mock_statement() response = client.post( "/xAPI/statements/", @@ -310,41 +254,21 @@ def test_api_statements_post_statements_list_of_one( [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) # pylint: disable=too-many-arguments -def test_api_statements_post_statements_list( +def test_api_statements_post_list( backend, monkeypatch, auth_credentials, es, mongo, clickhouse ): """Test the post statements API route with two statements in a list.""" # pylint: disable=invalid-name,unused-argument monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) - statements = [ - { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "id": str(uuid4()), - "object": {"id": "https://example.com/object-id/1/"}, - "timestamp": "2022-03-15T14:07:52Z", - "verb": {"id": "https://example.com/verb-id/1/"}, - }, - { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - # Note the second statement has no preexisting ID - "object": {"id": "https://example.com/object-id/1/"}, - "timestamp": "2022-03-15T14:07:51Z", - "verb": {"id": "https://example.com/verb-id/1/"}, - }, - ] + + statement_1 = mock_statement(timestamp="2022-03-15T14:07:52Z") + + # Note the second statement has no preexisting ID + statement_2 = mock_statement(timestamp="2022-03-15T14:07:51Z") + statement_2.pop("id") + + statements = [statement_1, statement_2] response = client.post( "/xAPI/statements/", @@ -381,26 +305,14 @@ def test_api_statements_post_statements_list( ], ) # pylint: disable=too-many-arguments -def test_api_statements_post_statements_list_with_duplicates( +def test_api_statements_post_list_with_duplicates( backend, monkeypatch, auth_credentials, es_data_stream, mongo, clickhouse ): """Test the post statements API route with duplicate statement IDs should fail.""" # pylint: disable=invalid-name,unused-argument monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "id": str(uuid4()), - "object": {"id": "https://example.com/object-id/1/"}, - "timestamp": "2022-03-15T14:07:51Z", - "verb": {"id": "https://example.com/verb-id/1/"}, - } + statement = mock_statement() response = client.post( "/xAPI/statements/", @@ -426,7 +338,7 @@ def test_api_statements_post_statements_list_with_duplicates( [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) # pylint: disable=too-many-arguments -def 
test_api_statements_post_statements_list_with_duplicate_of_existing_statement( +def test_api_statements_post_list_with_duplicate_of_existing_statement( backend, monkeypatch, auth_credentials, es, mongo, clickhouse ): """Test the post statements API route, given a statement that already exist in the @@ -437,19 +349,7 @@ def test_api_statements_post_statements_list_with_duplicate_of_existing_statemen monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) statement_uuid = str(uuid4()) - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "id": statement_uuid, - "object": {"id": "https://example.com/object-id/1/"}, - "timestamp": "2022-03-15T14:07:51Z", - "verb": {"id": "https://example.com/verb-id/1/"}, - } + statement = mock_statement(id_=statement_uuid) # Post the statement once. response = client.post( @@ -499,7 +399,7 @@ def test_api_statements_post_statements_list_with_duplicate_of_existing_statemen "backend", [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) -def test_api_statements_post_statements_with_a_failure_during_storage( +def test_api_statements_post_with_failure_during_storage( backend, monkeypatch, auth_credentials, es, mongo, clickhouse ): """Test the post statements API route with a failure happening during storage.""" @@ -512,19 +412,7 @@ def write_mock(*args, **kwargs): backend_instance = backend() monkeypatch.setattr(backend_instance, "write", write_mock) monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend_instance) - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "id": str(uuid4()), - "object": {"id": "https://example.com/object-id/1/"}, - "timestamp": "2022-03-15T14:07:51Z", - "verb": {"id": "https://example.com/verb-id/1/"}, - } + statement = mock_statement() response = client.post( "/xAPI/statements/", @@ -540,7 +428,7 @@ def write_mock(*args, **kwargs): "backend", [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) -def test_api_statements_post_statements_with_a_failure_during_id_query( +def test_api_statements_post_with_failure_during_id_query( backend, monkeypatch, auth_credentials, es, mongo, clickhouse ): """Test the post statements API route with a failure during query execution.""" @@ -555,19 +443,7 @@ def query_statements_by_ids_mock(*args, **kwargs): backend_instance, "query_statements_by_ids", query_statements_by_ids_mock ) monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend_instance) - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "id": str(uuid4()), - "object": {"id": "https://example.com/object-id/1/"}, - "timestamp": "2022-03-15T14:07:51Z", - "verb": {"id": "https://example.com/verb-id/1/"}, - } + statement = mock_statement() response = client.post( "/xAPI/statements/", @@ -584,7 +460,7 @@ def query_statements_by_ids_mock(*args, **kwargs): [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) # pylint: disable=too-many-arguments -def test_api_statements_post_statements_list_without_statement_forwarding( +def test_api_statements_post_list_without_forwarding( backend, auth_credentials, monkeypatch, es, mongo, clickhouse ): """Test the post statements API route, given an empty forwarding configuration, @@ -607,19 
+483,7 @@ def spy_mock_forward_xapi_statements(_): ) monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "id": str(uuid4()), - "object": {"id": "https://example.com/object-id/1/"}, - "timestamp": "2022-06-22T08:31:38Z", - "verb": {"id": "https://example.com/verb-id/1/"}, - } + statement = mock_statement() response = client.post( "/xAPI/statements/", @@ -652,7 +516,7 @@ def spy_mock_forward_xapi_statements(_): ), ], ) -async def test_api_statements_post_statements_list_with_statement_forwarding( +async def test_api_statements_post_list_with_forwarding( receiving_backend, forwarding_backend, monkeypatch, @@ -670,19 +534,7 @@ async def test_api_statements_post_statements_list_with_statement_forwarding( """ # pylint: disable=invalid-name,unused-argument,too-many-arguments,too-many-locals - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "id": str(uuid4()), - "object": {"id": "https://example.com/object-id/1/"}, - "timestamp": "2022-03-15T14:07:51Z", - "verb": {"id": "https://example.com/verb-id/1/"}, - } + statement = mock_statement() # Set-up receiving LRS client with monkeypatch.context() as receiving_patch: diff --git a/tests/api/test_statements_put.py b/tests/api/test_statements_put.py index a4f43e29c..330bccd0f 100644 --- a/tests/api/test_statements_put.py +++ b/tests/api/test_statements_put.py @@ -25,26 +25,18 @@ get_mongo_test_backend, ) -from ..helpers import assert_statement_get_responses_are_equivalent, string_is_date +from ..helpers import ( + assert_statement_get_responses_are_equivalent, + mock_statement, + string_is_date, +) client = TestClient(app) def test_api_statements_put_invalid_parameters(auth_credentials): """Test that using invalid parameters returns the proper status code.""" - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "id": str(uuid4()), - "object": {"id": "https://example.com/object-id/1/"}, - "timestamp": "2022-06-22T08:31:38Z", - "verb": {"id": "https://example.com/verb-id/1/"}, - } + statement = mock_statement() # Check for 400 status code when unknown parameters are provided response = client.put( @@ -70,19 +62,7 @@ def test_api_statements_put_single_statement_directly( # pylint: disable=invalid-name,unused-argument monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "id": str(uuid4()), - "object": {"id": "https://example.com/object-id/1/"}, - "timestamp": "2022-06-22T08:31:38Z", - "verb": {"id": "https://example.com/verb-id/1/"}, - } + statement = mock_statement() response = client.put( f"/xAPI/statements/?statementId={statement['id']}", @@ -113,18 +93,7 @@ def test_api_statements_put_enriching_without_existing_values( monkeypatch.setattr( "ralph.api.routers.statements.BACKEND_CLIENT", get_es_test_backend() ) - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "object": {"id": "https://example.com/object-id/1/"}, - "verb": {"id": "https://example.com/verb-id/1/"}, - "id": str(uuid4()), - } + statement = 
mock_statement() response = client.put( f"/xAPI/statements/?statementId={statement['id']}", @@ -176,18 +145,8 @@ def test_api_statements_put_enriching_with_existing_values( monkeypatch.setattr( "ralph.api.routers.statements.BACKEND_CLIENT", get_es_test_backend() ) - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "object": {"id": "https://example.com/object-id/1/"}, - "verb": {"id": "https://example.com/verb-id/1/"}, - "id": str(uuid4()), - } + statement = mock_statement() + # Add the field to be tested statement[field] = value @@ -228,19 +187,7 @@ def test_api_statements_put_single_statement_no_trailing_slash( # pylint: disable=invalid-name,unused-argument monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "id": str(uuid4()), - "object": {"id": "https://example.com/object-id/1/"}, - "timestamp": "2022-06-22T08:31:38Z", - "verb": {"id": "https://example.com/verb-id/1/"}, - } + statement = mock_statement() response = client.put( f"/xAPI/statements?statementId={statement['id']}", @@ -256,25 +203,13 @@ def test_api_statements_put_single_statement_no_trailing_slash( [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) # pylint: disable=too-many-arguments -def test_api_statements_put_statement_id_mismatch( +def test_api_statements_put_id_mismatch( backend, monkeypatch, auth_credentials, es, mongo, clickhouse ): # pylint: disable=invalid-name,unused-argument """Test the put statements API route when the statementId doesn't match.""" monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "id": str(uuid4()), - "object": {"id": "https://example.com/object-id/1/"}, - "timestamp": "2022-06-22T08:31:38Z", - "verb": {"id": "https://example.com/verb-id/1/"}, - } + statement = mock_statement(id_=str(uuid4())) different_statement_id = str(uuid4()) response = client.put( @@ -294,25 +229,13 @@ def test_api_statements_put_statement_id_mismatch( [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) # pylint: disable=too-many-arguments -def test_api_statements_put_statements_list_of_one( +def test_api_statements_put_list_of_one( backend, monkeypatch, auth_credentials, es, mongo, clickhouse ): # pylint: disable=invalid-name,unused-argument """Test that we fail on PUTs with a list, even if it's one statement.""" monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "id": str(uuid4()), - "object": {"id": "https://example.com/object-id/1/"}, - "timestamp": "2022-03-15T14:07:51Z", - "verb": {"id": "https://example.com/verb-id/1/"}, - } + statement = mock_statement() response = client.put( f"/xAPI/statements/?statementId={statement['id']}", @@ -328,7 +251,7 @@ def test_api_statements_put_statements_list_of_one( [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) # pylint: disable=too-many-arguments -def test_api_statements_put_statement_duplicate_of_existing_statement( +def 
test_api_statements_put_duplicate_of_existing_statement( backend, monkeypatch, auth_credentials, es, mongo, clickhouse ): """Test the put statements API route, given a statement that already exist in the @@ -337,19 +260,7 @@ def test_api_statements_put_statement_duplicate_of_existing_statement( # pylint: disable=invalid-name,unused-argument monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "id": str(uuid4()), - "object": {"id": "https://example.com/object-id/1/"}, - "timestamp": "2022-03-15T14:07:51Z", - "verb": {"id": "https://example.com/verb-id/1/"}, - } + statement = mock_statement() # Put the statement once. response = client.put( @@ -387,7 +298,7 @@ def test_api_statements_put_statement_duplicate_of_existing_statement( "backend", [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) -def test_api_statements_put_statements_with_a_failure_during_storage( +def test_api_statements_put_with_failure_during_storage( backend, monkeypatch, auth_credentials, es, mongo, clickhouse ): """Test the put statements API route with a failure happening during storage.""" @@ -400,19 +311,7 @@ def write_mock(*args, **kwargs): backend_instance = backend() monkeypatch.setattr(backend_instance, "write", write_mock) monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend_instance) - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "id": str(uuid4()), - "object": {"id": "https://example.com/object-id/1/"}, - "timestamp": "2022-03-15T14:07:51Z", - "verb": {"id": "https://example.com/verb-id/1/"}, - } + statement = mock_statement() response = client.put( f"/xAPI/statements/?statementId={statement['id']}", @@ -428,7 +327,7 @@ def write_mock(*args, **kwargs): "backend", [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) -def test_api_statements_put_statement_with_a_failure_during_id_query( +def test_api_statements_put_with_a_failure_during_id_query( backend, monkeypatch, auth_credentials, es, mongo, clickhouse ): """Test the put statements API route with a failure during query execution.""" @@ -443,19 +342,7 @@ def query_statements_by_ids_mock(*args, **kwargs): backend_instance, "query_statements_by_ids", query_statements_by_ids_mock ) monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend_instance) - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "id": str(uuid4()), - "object": {"id": "https://example.com/object-id/1/"}, - "timestamp": "2022-03-15T14:07:51Z", - "verb": {"id": "https://example.com/verb-id/1/"}, - } + statement = mock_statement() response = client.put( f"/xAPI/statements/?statementId={statement['id']}", @@ -472,7 +359,7 @@ def query_statements_by_ids_mock(*args, **kwargs): [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) # pylint: disable=too-many-arguments -def test_api_statements_put_statement_without_statement_forwarding( +def test_api_statements_put_without_forwarding( backend, auth_credentials, monkeypatch, es, mongo, clickhouse ): """Test the put statements API route, given an empty forwarding configuration, @@ -495,19 +382,7 @@ def spy_mock_forward_xapi_statements(_): ) 
monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "id": str(uuid4()), - "object": {"id": "https://example.com/object-id/1/"}, - "timestamp": "2022-06-22T08:31:38Z", - "verb": {"id": "https://example.com/verb-id/1/"}, - } + statement = mock_statement() response = client.put( f"/xAPI/statements/?statementId={statement['id']}", @@ -539,7 +414,7 @@ def spy_mock_forward_xapi_statements(_): ), ], ) -async def test_api_statements_put_statement_with_statement_forwarding( +async def test_api_statements_put_with_forwarding( receiving_backend, forwarding_backend, monkeypatch, @@ -557,19 +432,7 @@ async def test_api_statements_put_statement_with_statement_forwarding( """ # pylint: disable=invalid-name,unused-argument,too-many-arguments,too-many-locals - statement = { - "actor": { - "account": { - "homePage": "https://example.com/homepage/", - "name": str(uuid4()), - }, - "objectType": "Agent", - }, - "id": str(uuid4()), - "object": {"id": "https://example.com/object-id/1/"}, - "timestamp": "2022-03-15T14:07:51Z", - "verb": {"id": "https://example.com/verb-id/1/"}, - } + statement = mock_statement() # Set-up receiving LRS client with monkeypatch.context() as receiving_patch: diff --git a/tests/backends/http/test_async_lrs.py b/tests/backends/http/test_async_lrs.py index 44d4a1ff0..f89f21aec 100644 --- a/tests/backends/http/test_async_lrs.py +++ b/tests/backends/http/test_async_lrs.py @@ -3,12 +3,9 @@ import asyncio import json import logging -import random import time -from datetime import datetime from functools import partial from urllib.parse import ParseResult, parse_qsl, urlencode, urljoin, urlparse -from uuid import uuid4 import httpx import pytest @@ -26,6 +23,8 @@ from ralph.backends.http.base import HTTPBackendStatus from ralph.exceptions import BackendException, BackendParameterException +from ...helpers import mock_statement + # pylint: disable=too-many-lines @@ -37,26 +36,6 @@ async def _unpack_async_generator(async_gen): return result -def _gen_statement(id_=None, verb=None, timestamp=None): - """Generate fake statements with random or provided parameters.""" - if id_ is None: - id_ = str(uuid4()) - if verb is None: - verb = {"id": f"https://w3id.org/xapi/video/verbs/{random.random()}"} - elif isinstance(verb, int): - verb = {"id": f"https://w3id.org/xapi/video/verbs/{verb}"} - if timestamp is None: - timestamp = datetime.strftime( - datetime.fromtimestamp(time.time() - random.random()), - "%Y-%m-%dT%H:%M:%S", - ) - elif isinstance(timestamp, int): - timestamp = datetime.strftime( - datetime.fromtimestamp((time.time() - timestamp), "%Y-%m-%dT%H:%M:%S") - ) - return {"id": id_, "verb": verb, "timestamp": timestamp} - - def test_backend_http_lrs_default_instantiation( monkeypatch, fs ): # pylint:disable = invalid-name @@ -250,11 +229,11 @@ async def test_backends_http_lrs_read_max_statements( chunk_size = 3 statements = { - "statements": [_gen_statement() for _ in range(chunk_size)], + "statements": [mock_statement() for _ in range(chunk_size)], "more": more_target, } more_statements = { - "statements": [_gen_statement() for _ in range(chunk_size)], + "statements": [mock_statement() for _ in range(chunk_size)], } # Mock GET response of HTTPX for target and "more" target without query parameter @@ -328,7 +307,7 @@ async def test_backends_http_lrs_read_without_target( ) backend = AsyncLRSHTTPBackend(settings) - 
statements = {"statements": [_gen_statement() for _ in range(3)]} + statements = {"statements": [mock_statement() for _ in range(3)]} # Mock HTTPX GET default_params = LRSStatementsQuery(limit=500).dict( @@ -419,9 +398,9 @@ async def test_backends_http_lrs_read_without_pagination( statements = { "statements": [ - _gen_statement(verb={"id": "https://w3id.org/xapi/video/verbs/played"}), - _gen_statement(verb={"id": "https://w3id.org/xapi/video/verbs/played"}), - _gen_statement(verb={"id": "https://w3id.org/xapi/video/verbs/paused"}), + mock_statement(verb={"id": "https://w3id.org/xapi/video/verbs/played"}), + mock_statement(verb={"id": "https://w3id.org/xapi/video/verbs/played"}), + mock_statement(verb={"id": "https://w3id.org/xapi/video/verbs/paused"}), ] } @@ -517,19 +496,19 @@ async def test_backends_http_lrs_read_with_pagination(httpx_mock: HTTPXMock): more_target = "/xAPI/statements/?pit_id=fake-pit-id" statements = { "statements": [ - _gen_statement(verb={"id": "https://w3id.org/xapi/video/verbs/played"}), - _gen_statement( + mock_statement(verb={"id": "https://w3id.org/xapi/video/verbs/played"}), + mock_statement( verb={"id": "https://w3id.org/xapi/video/verbs/initialized"} ), - _gen_statement(verb={"id": "https://w3id.org/xapi/video/verbs/paused"}), + mock_statement(verb={"id": "https://w3id.org/xapi/video/verbs/paused"}), ], "more": more_target, } more_statements = { "statements": [ - _gen_statement(verb={"id": "https://w3id.org/xapi/video/verbs/seeked"}), - _gen_statement(verb={"id": "https://w3id.org/xapi/video/verbs/played"}), - _gen_statement(verb={"id": "https://w3id.org/xapi/video/verbs/paused"}), + mock_statement(verb={"id": "https://w3id.org/xapi/video/verbs/seeked"}), + mock_statement(verb={"id": "https://w3id.org/xapi/video/verbs/played"}), + mock_statement(verb={"id": "https://w3id.org/xapi/video/verbs/paused"}), ] } @@ -672,7 +651,7 @@ async def test_backends_http_lrs_write_without_operation( base_url = "http://fake-lrs.com" target = "/xAPI/statements/" - data = [_gen_statement() for _ in range(6)] + data = [mock_statement() for _ in range(6)] settings = LRSHTTPBackendSettings( BASE_URL=base_url, @@ -831,7 +810,7 @@ async def test_backends_http_lrs_write_without_target(httpx_mock: HTTPXMock, cap ) backend = AsyncLRSHTTPBackend(settings) - data = [_gen_statement() for _ in range(3)] + data = [mock_statement() for _ in range(3)] # Mock HTTPX POST httpx_mock.add_response( @@ -871,7 +850,7 @@ async def test_backends_http_lrs_write_with_create_or_index_operation( ) backend = AsyncLRSHTTPBackend(settings) - data = [_gen_statement() for _ in range(3)] + data = [mock_statement() for _ in range(3)] # Mock HTTPX POST httpx_mock.add_response(url=urljoin(base_url, target), method="POST", json=data) @@ -905,7 +884,7 @@ async def test_backends_http_lrs_write_backend_exception( ) backend = AsyncLRSHTTPBackend(settings) - data = [_gen_statement()] + data = [mock_statement()] # Mock HTTPX POST httpx_mock.add_response( @@ -968,7 +947,7 @@ async def _simulate_slow_processing(): all_statements = {} for index in range(num_pages): all_statements[index] = { - "statements": [_gen_statement() for _ in range(chunk_size)] + "statements": [mock_statement() for _ in range(chunk_size)] } if index < num_pages - 1: all_statements[index]["more"] = targets[index + 1] @@ -1032,7 +1011,7 @@ async def test_backends_http_lrs_write_concurrency( base_url = "http://fake-lrs.com" - data = [_gen_statement() for _ in range(6)] + data = [mock_statement() for _ in range(6)] # Changing data length might break tests 
assert len(data) == 6 diff --git a/tests/fixtures/auth.py b/tests/fixtures/auth.py index 23173ed85..da4c83868 100644 --- a/tests/fixtures/auth.py +++ b/tests/fixtures/auth.py @@ -22,7 +22,7 @@ PUBLIC_KEY_ID = "example-key-id" -def create_user( +def mock_basic_auth_user( fs_, username: str, password: str, @@ -91,7 +91,7 @@ def auth_credentials(fs, user_scopes=None, agent=None): if agent is None: agent = {"mbox": "mailto:test_ralph@example.com"} - credentials = create_user(fs, username, password, user_scopes, agent) + credentials = mock_basic_auth_user(fs, username, password, user_scopes, agent) return credentials diff --git a/tests/helpers.py b/tests/helpers.py index 6d3fdb223..3aceccad1 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,8 +1,11 @@ """Utilities for testing Ralph.""" -import datetime import hashlib +import random +import time import uuid -from typing import Optional +from datetime import datetime +from typing import Optional, Union +from uuid import UUID from ralph.utils import statements_are_equivalent @@ -10,7 +13,7 @@ def string_is_date(string: str): """Check if string can be parsed as a date.""" try: - datetime.datetime.fromisoformat(string) + datetime.fromisoformat(string) return True except ValueError: return False @@ -52,7 +55,7 @@ def _all_but_statements(response): ), "Statements in get responses are not equivalent, or not in the same order." -def create_mock_activity(id_: int = 0): +def mock_activity(id_: int = 0): """Create distinct activites with valid IRIs. Args: @@ -65,9 +68,9 @@ def create_mock_activity(id_: int = 0): } -def create_mock_agent( - ifi: str, - id_: int, +def mock_agent( + ifi: str = "mbox", + id_: int = 1, home_page_id: Optional[int] = None, name: Optional[str] = None, use_object_type: bool = True, @@ -111,7 +114,7 @@ def create_mock_agent( if ifi == "account": if home_page_id is None: raise ValueError( - "home_page_id must be defined if using create_mock_agent if " + "home_page_id must be defined when calling mock_agent " "using ifi=='account'" ) agent["account"] = { @@ -120,4 +123,76 @@ def create_mock_agent( } return agent - raise ValueError("No valid ifi was provided to create_mock_agent") + raise ValueError("No valid ifi was provided to mock_agent") + + +def mock_statement( + id_: Optional[Union[UUID, int]] = None, + actor: Optional[Union[dict, int]] = None, + verb: Optional[Union[dict, int]] = None, + object: Optional[Union[dict, int]] = None, + timestamp: Optional[Union[str, int]] = None, +): + """Generate fake statements with random or provided parameters. + Fields `actor`, `verb`, `object`, `timestamp` accept integer values which + can be used to create distinct values identifiable by that integer. For each + argument, using `None` will assign a default value. `timestamp` may be omitted + by using the value `""`. + Args: + id_: id of the statement + actor: actor of the statement + verb: verb of the statement + object: object of the statement + timestamp: timestamp of the statement.
Use `""` to omit the timestamp. + """ + # pylint: disable=redefined-builtin + + # Id + if id_ is None: + id_ = str(uuid.uuid4()) + + # Actor + if actor is None: + actor = mock_agent() + elif isinstance(actor, int): + actor = mock_agent(id_=actor) + + # Verb + if verb is None: + verb = {"id": f"https://w3id.org/xapi/video/verbs/{random.random()}"} + elif isinstance(verb, int): + verb = {"id": f"https://w3id.org/xapi/video/verbs/{verb}"} + + # Object + if object is None: + object = { + "id": f"http://example.adlnet.gov/xapi/example/activity_{random.random()}" + } + elif isinstance(object, int): + object = {"id": f"http://example.adlnet.gov/xapi/example/activity_{object}"} + + # Timestamp + if timestamp is None: + timestamp = datetime.strftime( + datetime.fromtimestamp(time.time() - random.random()), + "%Y-%m-%dT%H:%M:%S+00:00", + ) + elif isinstance(timestamp, int): + timestamp = datetime.strftime( + datetime.fromtimestamp(1696236665 + timestamp), "%Y-%m-%dT%H:%M:%S+00:00" + ) + elif timestamp == "": + return { + "id": id_, + "actor": actor, + "verb": verb, + "object": object, + } + + return { + "id": id_, + "actor": actor, + "verb": verb, + "object": object, + "timestamp": timestamp, + } diff --git a/tests/test_helpers.py b/tests/test_helpers.py index fc5177665..e40b8dff9 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -6,6 +6,7 @@ from .helpers import ( assert_statement_get_responses_are_equivalent, + mock_statement, string_is_date, string_is_uuid, ) @@ -121,3 +122,62 @@ def test_helpers_assert_statement_get_responses_are_equivalent_length_error(): assert_statement_get_responses_are_equivalent(get_response_1, get_response_2) with pytest.raises(AssertionError): assert_statement_get_responses_are_equivalent(get_response_2, get_response_1) + + +def test_helpers_mock_statement_no_input(): + """Test that a mocked statement has the expected fields.""" + + statement = mock_statement() + + assert "id" in statement + assert "actor" in statement + assert "verb" in statement + assert "object" in statement + assert "timestamp" in statement + + statement = mock_statement(timestamp="") + assert "timestamp" not in statement + + +def test_helpers_mock_statement_value_input(): + """Test that a mocked statement has the expected fields with value input.""" + + reference_statement = { + "id": str(uuid4()), + "actor": { + "account": { + "homePage": "https://example.com/homepage/", + "name": str(uuid4()), + }, + "objectType": "Agent", + }, + "object": {"id": "https://example.com/object-id/1/"}, + "timestamp": "2022-03-15T14:07:51Z", + "verb": {"id": "https://example.com/verb-id/1/"}, + } + + statement = mock_statement( + id_=reference_statement["id"], + actor=reference_statement["actor"], + verb=reference_statement["verb"], + object=reference_statement["object"], + timestamp=reference_statement["timestamp"], + ) + + assert statement == reference_statement + + +@pytest.mark.parametrize("field", ["actor", "verb", "object", "timestamp"]) +@pytest.mark.parametrize("integer", [0, 1, 5]) +def test_helpers_mock_statement_integer_input(field, integer): + """Test that mocked statement fields behave properly with integer input.""" + + # Test that fields have same values for same integer input + statement_1 = mock_statement(**{field: integer}) + statement_2 = mock_statement(**{field: integer}) + assert statement_1[field] == statement_2[field] + + # Test that fields have different values for different integer input + statement_1 = mock_statement(**{field:
integer}) + statement_2 = mock_statement(**{field: integer + 1}) + assert statement_1[field] != statement_2[field] From d7afb5bbbf059788321af5f92adb40d425dfbce7 Mon Sep 17 00:00:00 2001 From: SergioSim Date: Tue, 17 Oct 2023 14:57:27 +0200 Subject: [PATCH 38/65] =?UTF-8?q?=E2=9C=A8(cli)=20sort=20backend=20argumen?= =?UTF-8?q?ts=20alphabetically?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reading the ralph cli usage strings for the read/write/list/runserver commands is easier when backends and their arguments are listed in alphabetical order. --- CHANGELOG.md | 1 + src/ralph/cli.py | 71 +++--- tests/test_cli_usage.py | 521 ++++++++++++++++++++-------------------- 3 files changed, 302 insertions(+), 291 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 97118b8f6..cef870b93 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,7 @@ have an authority field matching that of the user with camelCase alias, in `LRSStatementsQuery` - API: Add `RALPH_LRS_RESTRICT_BY_AUTHORITY` option making `?mine=True` implicit +- CLI: list cli usage strings in alphabetical order ### Fixed diff --git a/src/ralph/cli.py b/src/ralph/cli.py index 1281330be..fd2fd3fdd 100644 --- a/src/ralph/cli.py +++ b/src/ralph/cli.py @@ -205,44 +205,47 @@ def backends_options(name=None, backend_types: List[BaseModel] = None): def wrapper(command): backend_names = [] - for backend_type in backend_types: - for backend_name, backend in backend_type: - backend_name = backend_name.lower() - backend_names.append(backend_name) - for field_name, field in backend: - field_type = backend.__fields__[field_name].type_ - field_name = f"{backend_name}-{field_name.lower()}".replace( - "_", "-" - ) - option = f"--{field_name}" - option_kwargs = {} - # If the field is a boolean, convert it to a flag option - if field_type is bool: - option = f"{option}/--no-{field_name}" - option_kwargs["is_flag"] = True - elif field_type is dict: - option_kwargs["type"] = CommaSeparatedKeyValueParamType() - elif field_type is CommaSeparatedTuple: - option_kwargs["type"] = CommaSeparatedTupleParamType() - elif isclass(field_type) and issubclass(field_type, ClientOptions): - option_kwargs["type"] = ClientOptionsParamType(field_type) - elif isclass(field_type) and issubclass( - field_type, HeadersParameters - ): - option_kwargs["type"] = HeadersParametersParamType(field_type) - elif field_type is Path: - option_kwargs["type"] = click.Path() - - command = optgroup.option( - option.lower(), default=field, **option_kwargs - )(command) - - command = (optgroup.group(f"{backend_name} backend"))(command) + for backend_name, backend in sorted( + [ + name_backend + for backend_type in backend_types + for name_backend in backend_type + ], + key=lambda x: x[0], + reverse=True, + ): + backend_name = backend_name.lower() + backend_names.append(backend_name) + for field_name, field in sorted(backend, key=lambda x: x[0], reverse=True): + field_type = backend.__fields__[field_name].type_ + field_name = f"{backend_name}-{field_name.lower()}".replace("_", "-") + option = f"--{field_name}" + option_kwargs = {} + # If the field is a boolean, convert it to a flag option + if field_type is bool: + option = f"{option}/--no-{field_name}" + option_kwargs["is_flag"] = True + elif field_type is dict: + option_kwargs["type"] = CommaSeparatedKeyValueParamType() + elif field_type is CommaSeparatedTuple: + option_kwargs["type"] = CommaSeparatedTupleParamType() + elif isclass(field_type) and issubclass(field_type, ClientOptions): + 
option_kwargs["type"] = ClientOptionsParamType(field_type) + elif isclass(field_type) and issubclass(field_type, HeadersParameters): + option_kwargs["type"] = HeadersParametersParamType(field_type) + elif field_type is Path: + option_kwargs["type"] = click.Path() + + command = optgroup.option( + option.lower(), default=field, **option_kwargs + )(command) + + command = (optgroup.group(f"{backend_name} backend"))(command) command = click.option( "-b", "--backend", - type=click.Choice(backend_names), + type=click.Choice(sorted(backend_names)), required=True, help="Backend", )(command) diff --git a/tests/test_cli_usage.py b/tests/test_cli_usage.py index 6101e90f3..859eb0141 100644 --- a/tests/test_cli_usage.py +++ b/tests/test_cli_usage.py @@ -111,95 +111,95 @@ def test_cli_read_command_usage(): "Usage: ralph read [OPTIONS] [ARCHIVE]\n\n" " Read an archive or records from a configured backend.\n\n" "Options:\n" - " -b, --backend [async_es|async_mongo|clickhouse|es|fs|ldp|mongo|swift|s3|lrs|" + " -b, --backend [async_es|async_mongo|clickhouse|es|fs|ldp|lrs|mongo|s3|swift|" "ws]\n" " Backend [required]\n" - " ws backend: \n" - " --ws-uri TEXT\n" + " async_es backend: \n" + " --async-es-allow-yellow-status / --no-async-es-allow-yellow-status\n" + " --async-es-client-options KEY=VALUE,KEY=VALUE\n" + " --async-es-default-chunk-size INTEGER\n" + " --async-es-default-index TEXT\n" + " --async-es-hosts VALUE1,VALUE2,VALUE3\n" + " --async-es-locale-encoding TEXT\n" + " --async-es-point-in-time-keep-alive TEXT\n" + " --async-es-refresh-after-write TEXT\n" + " async_mongo backend: \n" + " --async-mongo-client-options KEY=VALUE,KEY=VALUE\n" + " --async-mongo-connection-uri TEXT\n" + " --async-mongo-default-chunk-size INTEGER\n" + " --async-mongo-default-collection TEXT\n" + " --async-mongo-default-database TEXT\n" + " --async-mongo-locale-encoding TEXT\n" + " clickhouse backend: \n" + " --clickhouse-client-options KEY=VALUE,KEY=VALUE\n" + " --clickhouse-database TEXT\n" + " --clickhouse-default-chunk-size INTEGER\n" + " --clickhouse-event-table-name TEXT\n" + " --clickhouse-host TEXT\n" + " --clickhouse-locale-encoding TEXT\n" + " --clickhouse-password TEXT\n" + " --clickhouse-port INTEGER\n" + " --clickhouse-username TEXT\n" + " es backend: \n" + " --es-allow-yellow-status / --no-es-allow-yellow-status\n" + " --es-client-options KEY=VALUE,KEY=VALUE\n" + " --es-default-chunk-size INTEGER\n" + " --es-default-index TEXT\n" + " --es-hosts VALUE1,VALUE2,VALUE3\n" + " --es-locale-encoding TEXT\n" + " --es-point-in-time-keep-alive TEXT\n" + " --es-refresh-after-write TEXT\n" + " fs backend: \n" + " --fs-default-chunk-size INTEGER\n" + " --fs-default-directory-path PATH\n" + " --fs-default-query-string TEXT\n" + " --fs-locale-encoding TEXT\n" + " ldp backend: \n" + " --ldp-application-key TEXT\n" + " --ldp-application-secret TEXT\n" + " --ldp-consumer-key TEXT\n" + " --ldp-default-stream-id TEXT\n" + " --ldp-endpoint TEXT\n" + " --ldp-request-timeout TEXT\n" + " --ldp-service-name TEXT\n" " lrs backend: \n" - " --lrs-statements-endpoint TEXT\n" - " --lrs-status-endpoint TEXT\n" + " --lrs-base-url TEXT\n" " --lrs-headers KEY=VALUE,KEY=VALUE\n" " --lrs-password TEXT\n" + " --lrs-statements-endpoint TEXT\n" + " --lrs-status-endpoint TEXT\n" " --lrs-username TEXT\n" - " --lrs-base-url TEXT\n" + " mongo backend: \n" + " --mongo-client-options KEY=VALUE,KEY=VALUE\n" + " --mongo-connection-uri TEXT\n" + " --mongo-default-chunk-size INTEGER\n" + " --mongo-default-collection TEXT\n" + " --mongo-default-database TEXT\n" + 
" --mongo-locale-encoding TEXT\n" " s3 backend: \n" - " --s3-locale-encoding TEXT\n" - " --s3-default-chunk-size INTEGER\n" + " --s3-access-key-id TEXT\n" " --s3-default-bucket-name TEXT\n" + " --s3-default-chunk-size INTEGER\n" " --s3-default-region TEXT\n" " --s3-endpoint-url TEXT\n" - " --s3-session-token TEXT\n" + " --s3-locale-encoding TEXT\n" " --s3-secret-access-key TEXT\n" - " --s3-access-key-id TEXT\n" + " --s3-session-token TEXT\n" " swift backend: \n" - " --swift-locale-encoding TEXT\n" + " --swift-auth-url TEXT\n" " --swift-default-container TEXT\n" - " --swift-user-domain-name TEXT\n" + " --swift-identity-api-version TEXT\n" + " --swift-locale-encoding TEXT\n" " --swift-object-storage-url TEXT\n" - " --swift-region-name TEXT\n" + " --swift-password TEXT\n" " --swift-project-domain-name TEXT\n" - " --swift-tenant-name TEXT\n" + " --swift-region-name TEXT\n" " --swift-tenant-id TEXT\n" - " --swift-identity-api-version TEXT\n" - " --swift-password TEXT\n" + " --swift-tenant-name TEXT\n" " --swift-username TEXT\n" - " --swift-auth-url TEXT\n" - " mongo backend: \n" - " --mongo-locale-encoding TEXT\n" - " --mongo-default-chunk-size INTEGER\n" - " --mongo-client-options KEY=VALUE,KEY=VALUE\n" - " --mongo-default-collection TEXT\n" - " --mongo-default-database TEXT\n" - " --mongo-connection-uri TEXT\n" - " ldp backend: \n" - " --ldp-service-name TEXT\n" - " --ldp-request-timeout TEXT\n" - " --ldp-endpoint TEXT\n" - " --ldp-default-stream-id TEXT\n" - " --ldp-consumer-key TEXT\n" - " --ldp-application-secret TEXT\n" - " --ldp-application-key TEXT\n" - " fs backend: \n" - " --fs-locale-encoding TEXT\n" - " --fs-default-query-string TEXT\n" - " --fs-default-directory-path PATH\n" - " --fs-default-chunk-size INTEGER\n" - " es backend: \n" - " --es-refresh-after-write TEXT\n" - " --es-point-in-time-keep-alive TEXT\n" - " --es-locale-encoding TEXT\n" - " --es-hosts VALUE1,VALUE2,VALUE3\n" - " --es-default-index TEXT\n" - " --es-default-chunk-size INTEGER\n" - " --es-client-options KEY=VALUE,KEY=VALUE\n" - " --es-allow-yellow-status / --no-es-allow-yellow-status\n" - " clickhouse backend: \n" - " --clickhouse-locale-encoding TEXT\n" - " --clickhouse-default-chunk-size INTEGER\n" - " --clickhouse-client-options KEY=VALUE,KEY=VALUE\n" - " --clickhouse-password TEXT\n" - " --clickhouse-username TEXT\n" - " --clickhouse-event-table-name TEXT\n" - " --clickhouse-database TEXT\n" - " --clickhouse-port INTEGER\n" - " --clickhouse-host TEXT\n" - " async_mongo backend: \n" - " --async-mongo-locale-encoding TEXT\n" - " --async-mongo-default-chunk-size INTEGER\n" - " --async-mongo-client-options KEY=VALUE,KEY=VALUE\n" - " --async-mongo-default-collection TEXT\n" - " --async-mongo-default-database TEXT\n" - " --async-mongo-connection-uri TEXT\n" - " async_es backend: \n" - " --async-es-refresh-after-write TEXT\n" - " --async-es-point-in-time-keep-alive TEXT\n" - " --async-es-locale-encoding TEXT\n" - " --async-es-hosts VALUE1,VALUE2,VALUE3\n" - " --async-es-default-index TEXT\n" - " --async-es-default-chunk-size INTEGER\n" - " --async-es-client-options KEY=VALUE,KEY=VALUE\n" - " --async-es-allow-yellow-status / --no-async-es-allow-yellow-status\n" + " --swift-user-domain-name TEXT\n" + " ws backend: \n" + " --ws-uri TEXT\n" " -c, --chunk-size INTEGER Get events by chunks of size #\n" " -t, --target TEXT Endpoint from which to read events (e.g.\n" " `/statements`)\n" @@ -218,7 +218,7 @@ def test_cli_read_command_usage(): assert ( "Error: Missing option '-b' / '--backend'. 
" "Choose from:\n\tasync_es,\n\tasync_mongo,\n\tclickhouse,\n\tes,\n\tfs,\n\tldp," - "\n\tmongo,\n\tswift,\n\ts3,\n\tlrs,\n\tws\n" + "\n\tlrs,\n\tmongo,\n\ts3,\n\tswift,\n\tws\n" ) in result.output @@ -232,85 +232,85 @@ def test_cli_list_command_usage(): "Usage: ralph list [OPTIONS]\n\n" " List available documents from a configured data backend.\n\n" "Options:\n" - " -b, --backend [async_es|async_mongo|clickhouse|es|fs|ldp|mongo|swift|s3]\n" + " -b, --backend [async_es|async_mongo|clickhouse|es|fs|ldp|mongo|s3|swift]\n" " Backend [required]\n" + " async_es backend: \n" + " --async-es-allow-yellow-status / --no-async-es-allow-yellow-status\n" + " --async-es-client-options KEY=VALUE,KEY=VALUE\n" + " --async-es-default-chunk-size INTEGER\n" + " --async-es-default-index TEXT\n" + " --async-es-hosts VALUE1,VALUE2,VALUE3\n" + " --async-es-locale-encoding TEXT\n" + " --async-es-point-in-time-keep-alive TEXT\n" + " --async-es-refresh-after-write TEXT\n" + " async_mongo backend: \n" + " --async-mongo-client-options KEY=VALUE,KEY=VALUE\n" + " --async-mongo-connection-uri TEXT\n" + " --async-mongo-default-chunk-size INTEGER\n" + " --async-mongo-default-collection TEXT\n" + " --async-mongo-default-database TEXT\n" + " --async-mongo-locale-encoding TEXT\n" + " clickhouse backend: \n" + " --clickhouse-client-options KEY=VALUE,KEY=VALUE\n" + " --clickhouse-database TEXT\n" + " --clickhouse-default-chunk-size INTEGER\n" + " --clickhouse-event-table-name TEXT\n" + " --clickhouse-host TEXT\n" + " --clickhouse-locale-encoding TEXT\n" + " --clickhouse-password TEXT\n" + " --clickhouse-port INTEGER\n" + " --clickhouse-username TEXT\n" + " es backend: \n" + " --es-allow-yellow-status / --no-es-allow-yellow-status\n" + " --es-client-options KEY=VALUE,KEY=VALUE\n" + " --es-default-chunk-size INTEGER\n" + " --es-default-index TEXT\n" + " --es-hosts VALUE1,VALUE2,VALUE3\n" + " --es-locale-encoding TEXT\n" + " --es-point-in-time-keep-alive TEXT\n" + " --es-refresh-after-write TEXT\n" + " fs backend: \n" + " --fs-default-chunk-size INTEGER\n" + " --fs-default-directory-path PATH\n" + " --fs-default-query-string TEXT\n" + " --fs-locale-encoding TEXT\n" + " ldp backend: \n" + " --ldp-application-key TEXT\n" + " --ldp-application-secret TEXT\n" + " --ldp-consumer-key TEXT\n" + " --ldp-default-stream-id TEXT\n" + " --ldp-endpoint TEXT\n" + " --ldp-request-timeout TEXT\n" + " --ldp-service-name TEXT\n" + " mongo backend: \n" + " --mongo-client-options KEY=VALUE,KEY=VALUE\n" + " --mongo-connection-uri TEXT\n" + " --mongo-default-chunk-size INTEGER\n" + " --mongo-default-collection TEXT\n" + " --mongo-default-database TEXT\n" + " --mongo-locale-encoding TEXT\n" " s3 backend: \n" - " --s3-locale-encoding TEXT\n" - " --s3-default-chunk-size INTEGER\n" + " --s3-access-key-id TEXT\n" " --s3-default-bucket-name TEXT\n" + " --s3-default-chunk-size INTEGER\n" " --s3-default-region TEXT\n" " --s3-endpoint-url TEXT\n" - " --s3-session-token TEXT\n" + " --s3-locale-encoding TEXT\n" " --s3-secret-access-key TEXT\n" - " --s3-access-key-id TEXT\n" + " --s3-session-token TEXT\n" " swift backend: \n" - " --swift-locale-encoding TEXT\n" + " --swift-auth-url TEXT\n" " --swift-default-container TEXT\n" - " --swift-user-domain-name TEXT\n" + " --swift-identity-api-version TEXT\n" + " --swift-locale-encoding TEXT\n" " --swift-object-storage-url TEXT\n" - " --swift-region-name TEXT\n" + " --swift-password TEXT\n" " --swift-project-domain-name TEXT\n" - " --swift-tenant-name TEXT\n" + " --swift-region-name TEXT\n" " --swift-tenant-id TEXT\n" - " 
--swift-identity-api-version TEXT\n" - " --swift-password TEXT\n" + " --swift-tenant-name TEXT\n" " --swift-username TEXT\n" - " --swift-auth-url TEXT\n" - " mongo backend: \n" - " --mongo-locale-encoding TEXT\n" - " --mongo-default-chunk-size INTEGER\n" - " --mongo-client-options KEY=VALUE,KEY=VALUE\n" - " --mongo-default-collection TEXT\n" - " --mongo-default-database TEXT\n" - " --mongo-connection-uri TEXT\n" - " ldp backend: \n" - " --ldp-service-name TEXT\n" - " --ldp-request-timeout TEXT\n" - " --ldp-endpoint TEXT\n" - " --ldp-default-stream-id TEXT\n" - " --ldp-consumer-key TEXT\n" - " --ldp-application-secret TEXT\n" - " --ldp-application-key TEXT\n" - " fs backend: \n" - " --fs-locale-encoding TEXT\n" - " --fs-default-query-string TEXT\n" - " --fs-default-directory-path PATH\n" - " --fs-default-chunk-size INTEGER\n" - " es backend: \n" - " --es-refresh-after-write TEXT\n" - " --es-point-in-time-keep-alive TEXT\n" - " --es-locale-encoding TEXT\n" - " --es-hosts VALUE1,VALUE2,VALUE3\n" - " --es-default-index TEXT\n" - " --es-default-chunk-size INTEGER\n" - " --es-client-options KEY=VALUE,KEY=VALUE\n" - " --es-allow-yellow-status / --no-es-allow-yellow-status\n" - " clickhouse backend: \n" - " --clickhouse-locale-encoding TEXT\n" - " --clickhouse-default-chunk-size INTEGER\n" - " --clickhouse-client-options KEY=VALUE,KEY=VALUE\n" - " --clickhouse-password TEXT\n" - " --clickhouse-username TEXT\n" - " --clickhouse-event-table-name TEXT\n" - " --clickhouse-database TEXT\n" - " --clickhouse-port INTEGER\n" - " --clickhouse-host TEXT\n" - " async_mongo backend: \n" - " --async-mongo-locale-encoding TEXT\n" - " --async-mongo-default-chunk-size INTEGER\n" - " --async-mongo-client-options KEY=VALUE,KEY=VALUE\n" - " --async-mongo-default-collection TEXT\n" - " --async-mongo-default-database TEXT\n" - " --async-mongo-connection-uri TEXT\n" - " async_es backend: \n" - " --async-es-refresh-after-write TEXT\n" - " --async-es-point-in-time-keep-alive TEXT\n" - " --async-es-locale-encoding TEXT\n" - " --async-es-hosts VALUE1,VALUE2,VALUE3\n" - " --async-es-default-index TEXT\n" - " --async-es-default-chunk-size INTEGER\n" - " --async-es-client-options KEY=VALUE,KEY=VALUE\n" - " --async-es-allow-yellow-status / --no-async-es-allow-yellow-status\n" + " --swift-user-domain-name TEXT\n" " -t, --target TEXT Container to list events from\n" " -n, --new / -a, --all List not fetched (or all) documents\n" " -D, --details / -I, --ids Get documents detailed output (JSON)\n" @@ -321,8 +321,8 @@ def test_cli_list_command_usage(): assert result.exit_code > 0 assert ( "Error: Missing option '-b' / '--backend'. 
Choose from:\n\tasync_es,\n\t" - "async_mongo,\n\tclickhouse,\n\tes,\n\tfs,\n\tldp,\n\tmongo,\n\tswift," - "\n\ts3\n" + "async_mongo,\n\tclickhouse,\n\tes,\n\tfs,\n\tldp,\n\tmongo,\n\ts3," + "\n\tswift\n" ) in result.output @@ -337,93 +337,93 @@ def test_cli_write_command_usage(): "Usage: ralph write [OPTIONS]\n\n" " Write an archive to a configured backend.\n\n" "Options:\n" - " -b, --backend [async_es|async_mongo|clickhouse|es|fs|ldp|mongo|swift|s3|lrs]" + " -b, --backend [async_es|async_mongo|clickhouse|es|fs|ldp|lrs|mongo|s3|swift]" "\n" " Backend [required]\n" + " async_es backend: \n" + " --async-es-allow-yellow-status / --no-async-es-allow-yellow-status\n" + " --async-es-client-options KEY=VALUE,KEY=VALUE\n" + " --async-es-default-chunk-size INTEGER\n" + " --async-es-default-index TEXT\n" + " --async-es-hosts VALUE1,VALUE2,VALUE3\n" + " --async-es-locale-encoding TEXT\n" + " --async-es-point-in-time-keep-alive TEXT\n" + " --async-es-refresh-after-write TEXT\n" + " async_mongo backend: \n" + " --async-mongo-client-options KEY=VALUE,KEY=VALUE\n" + " --async-mongo-connection-uri TEXT\n" + " --async-mongo-default-chunk-size INTEGER\n" + " --async-mongo-default-collection TEXT\n" + " --async-mongo-default-database TEXT\n" + " --async-mongo-locale-encoding TEXT\n" + " clickhouse backend: \n" + " --clickhouse-client-options KEY=VALUE,KEY=VALUE\n" + " --clickhouse-database TEXT\n" + " --clickhouse-default-chunk-size INTEGER\n" + " --clickhouse-event-table-name TEXT\n" + " --clickhouse-host TEXT\n" + " --clickhouse-locale-encoding TEXT\n" + " --clickhouse-password TEXT\n" + " --clickhouse-port INTEGER\n" + " --clickhouse-username TEXT\n" + " es backend: \n" + " --es-allow-yellow-status / --no-es-allow-yellow-status\n" + " --es-client-options KEY=VALUE,KEY=VALUE\n" + " --es-default-chunk-size INTEGER\n" + " --es-default-index TEXT\n" + " --es-hosts VALUE1,VALUE2,VALUE3\n" + " --es-locale-encoding TEXT\n" + " --es-point-in-time-keep-alive TEXT\n" + " --es-refresh-after-write TEXT\n" + " fs backend: \n" + " --fs-default-chunk-size INTEGER\n" + " --fs-default-directory-path PATH\n" + " --fs-default-query-string TEXT\n" + " --fs-locale-encoding TEXT\n" + " ldp backend: \n" + " --ldp-application-key TEXT\n" + " --ldp-application-secret TEXT\n" + " --ldp-consumer-key TEXT\n" + " --ldp-default-stream-id TEXT\n" + " --ldp-endpoint TEXT\n" + " --ldp-request-timeout TEXT\n" + " --ldp-service-name TEXT\n" " lrs backend: \n" - " --lrs-statements-endpoint TEXT\n" - " --lrs-status-endpoint TEXT\n" + " --lrs-base-url TEXT\n" " --lrs-headers KEY=VALUE,KEY=VALUE\n" " --lrs-password TEXT\n" + " --lrs-statements-endpoint TEXT\n" + " --lrs-status-endpoint TEXT\n" " --lrs-username TEXT\n" - " --lrs-base-url TEXT\n" + " mongo backend: \n" + " --mongo-client-options KEY=VALUE,KEY=VALUE\n" + " --mongo-connection-uri TEXT\n" + " --mongo-default-chunk-size INTEGER\n" + " --mongo-default-collection TEXT\n" + " --mongo-default-database TEXT\n" + " --mongo-locale-encoding TEXT\n" " s3 backend: \n" - " --s3-locale-encoding TEXT\n" - " --s3-default-chunk-size INTEGER\n" + " --s3-access-key-id TEXT\n" " --s3-default-bucket-name TEXT\n" + " --s3-default-chunk-size INTEGER\n" " --s3-default-region TEXT\n" " --s3-endpoint-url TEXT\n" - " --s3-session-token TEXT\n" + " --s3-locale-encoding TEXT\n" " --s3-secret-access-key TEXT\n" - " --s3-access-key-id TEXT\n" + " --s3-session-token TEXT\n" " swift backend: \n" - " --swift-locale-encoding TEXT\n" + " --swift-auth-url TEXT\n" " --swift-default-container TEXT\n" - " 
--swift-user-domain-name TEXT\n" + " --swift-identity-api-version TEXT\n" + " --swift-locale-encoding TEXT\n" " --swift-object-storage-url TEXT\n" - " --swift-region-name TEXT\n" + " --swift-password TEXT\n" " --swift-project-domain-name TEXT\n" - " --swift-tenant-name TEXT\n" + " --swift-region-name TEXT\n" " --swift-tenant-id TEXT\n" - " --swift-identity-api-version TEXT\n" - " --swift-password TEXT\n" + " --swift-tenant-name TEXT\n" " --swift-username TEXT\n" - " --swift-auth-url TEXT\n" - " mongo backend: \n" - " --mongo-locale-encoding TEXT\n" - " --mongo-default-chunk-size INTEGER\n" - " --mongo-client-options KEY=VALUE,KEY=VALUE\n" - " --mongo-default-collection TEXT\n" - " --mongo-default-database TEXT\n" - " --mongo-connection-uri TEXT\n" - " ldp backend: \n" - " --ldp-service-name TEXT\n" - " --ldp-request-timeout TEXT\n" - " --ldp-endpoint TEXT\n" - " --ldp-default-stream-id TEXT\n" - " --ldp-consumer-key TEXT\n" - " --ldp-application-secret TEXT\n" - " --ldp-application-key TEXT\n" - " fs backend: \n" - " --fs-locale-encoding TEXT\n" - " --fs-default-query-string TEXT\n" - " --fs-default-directory-path PATH\n" - " --fs-default-chunk-size INTEGER\n" - " es backend: \n" - " --es-refresh-after-write TEXT\n" - " --es-point-in-time-keep-alive TEXT\n" - " --es-locale-encoding TEXT\n" - " --es-hosts VALUE1,VALUE2,VALUE3\n" - " --es-default-index TEXT\n" - " --es-default-chunk-size INTEGER\n" - " --es-client-options KEY=VALUE,KEY=VALUE\n" - " --es-allow-yellow-status / --no-es-allow-yellow-status\n" - " clickhouse backend: \n" - " --clickhouse-locale-encoding TEXT\n" - " --clickhouse-default-chunk-size INTEGER\n" - " --clickhouse-client-options KEY=VALUE,KEY=VALUE\n" - " --clickhouse-password TEXT\n" - " --clickhouse-username TEXT\n" - " --clickhouse-event-table-name TEXT\n" - " --clickhouse-database TEXT\n" - " --clickhouse-port INTEGER\n" - " --clickhouse-host TEXT\n" - " async_mongo backend: \n" - " --async-mongo-locale-encoding TEXT\n" - " --async-mongo-default-chunk-size INTEGER\n" - " --async-mongo-client-options KEY=VALUE,KEY=VALUE\n" - " --async-mongo-default-collection TEXT\n" - " --async-mongo-default-database TEXT\n" - " --async-mongo-connection-uri TEXT\n" - " async_es backend: \n" - " --async-es-refresh-after-write TEXT\n" - " --async-es-point-in-time-keep-alive TEXT\n" - " --async-es-locale-encoding TEXT\n" - " --async-es-hosts VALUE1,VALUE2,VALUE3\n" - " --async-es-default-index TEXT\n" - " --async-es-default-chunk-size INTEGER\n" - " --async-es-client-options KEY=VALUE,KEY=VALUE\n" - " --async-es-allow-yellow-status / --no-async-es-allow-yellow-status\n" + " --swift-user-domain-name TEXT\n" " -c, --chunk-size INTEGER Get events by chunks of size #\n" " -f, --force Overwrite existing archives or records\n" " -I, --ignore-errors Continue writing regardless of raised errors" @@ -445,7 +445,7 @@ def test_cli_write_command_usage(): assert result.exit_code > 0 assert ( "Missing option '-b' / '--backend'. 
Choose from:\n\tasync_es,\n\tasync_mongo,\n" - "\tclickhouse,\n\tes,\n\tfs,\n\tldp,\n\tmongo,\n\tswift,\n\ts3,\n\tlrs\n" + "\tclickhouse,\n\tes,\n\tfs,\n\tldp,\n\tlrs,\n\tmongo,\n\ts3,\n\tswift\n" ) in result.output @@ -461,58 +461,65 @@ def test_cli_runserver_command_usage(): "Options:\n" " -b, --backend [async_es|async_mongo|clickhouse|es|fs|mongo]\n" " Backend [required]\n" - " mongo backend: \n" - " --mongo-locale-encoding TEXT\n" - " --mongo-default-chunk-size INTEGER\n" - " --mongo-client-options KEY=VALUE,KEY=VALUE\n" - " --mongo-default-collection TEXT\n" - " --mongo-default-database TEXT\n" - " --mongo-connection-uri TEXT\n" - " fs backend: \n" - " --fs-default-lrs-file TEXT\n" - " --fs-locale-encoding TEXT\n" - " --fs-default-query-string TEXT\n" - " --fs-default-directory-path PATH\n" - " --fs-default-chunk-size INTEGER\n" - " es backend: \n" - " --es-refresh-after-write TEXT\n" - " --es-point-in-time-keep-alive TEXT\n" - " --es-locale-encoding TEXT\n" - " --es-hosts VALUE1,VALUE2,VALUE3\n" - " --es-default-index TEXT\n" - " --es-default-chunk-size INTEGER\n" - " --es-client-options KEY=VALUE,KEY=VALUE\n" - " --es-allow-yellow-status / --no-es-allow-yellow-status\n" + " async_es backend: \n" + " --async-es-allow-yellow-status / --no-async-es-allow-yellow-status\n" + " --async-es-client-options KEY=VALUE,KEY=VALUE\n" + " --async-es-default-chunk-size INTEGER\n" + " --async-es-default-index TEXT\n" + " --async-es-hosts VALUE1,VALUE2,VALUE3\n" + " --async-es-locale-encoding TEXT\n" + " --async-es-point-in-time-keep-alive TEXT\n" + " --async-es-refresh-after-write TEXT\n" + " async_mongo backend: \n" + " --async-mongo-client-options KEY=VALUE,KEY=VALUE\n" + " --async-mongo-connection-uri TEXT\n" + " --async-mongo-default-chunk-size INTEGER\n" + " --async-mongo-default-collection TEXT\n" + " --async-mongo-default-database TEXT\n" + " --async-mongo-locale-encoding TEXT\n" " clickhouse backend: \n" + " --clickhouse-client-options KEY=VALUE,KEY=VALUE\n" + " --clickhouse-database TEXT\n" + " --clickhouse-default-chunk-size INTEGER\n" + " --clickhouse-event-table-name TEXT\n" + " --clickhouse-host TEXT\n" " --clickhouse-ids-chunk-size INTEGER\n" " --clickhouse-locale-encoding TEXT\n" - " --clickhouse-default-chunk-size INTEGER\n" - " --clickhouse-client-options KEY=VALUE,KEY=VALUE\n" " --clickhouse-password TEXT\n" - " --clickhouse-username TEXT\n" - " --clickhouse-event-table-name TEXT\n" - " --clickhouse-database TEXT\n" " --clickhouse-port INTEGER\n" - " --clickhouse-host TEXT\n" - " async_mongo backend: \n" - " --async-mongo-locale-encoding TEXT\n" - " --async-mongo-default-chunk-size INTEGER\n" - " --async-mongo-client-options KEY=VALUE,KEY=VALUE\n" - " --async-mongo-default-collection TEXT\n" - " --async-mongo-default-database TEXT\n" - " --async-mongo-connection-uri TEXT\n" - " async_es backend: \n" - " --async-es-refresh-after-write TEXT\n" - " --async-es-point-in-time-keep-alive TEXT\n" - " --async-es-locale-encoding TEXT\n" - " --async-es-hosts VALUE1,VALUE2,VALUE3\n" - " --async-es-default-index TEXT\n" - " --async-es-default-chunk-size INTEGER\n" - " --async-es-client-options KEY=VALUE,KEY=VALUE\n" - " --async-es-allow-yellow-status / --no-async-es-allow-yellow-status\n" + " --clickhouse-username TEXT\n" + " es backend: \n" + " --es-allow-yellow-status / --no-es-allow-yellow-status\n" + " --es-client-options KEY=VALUE,KEY=VALUE\n" + " --es-default-chunk-size INTEGER\n" + " --es-default-index TEXT\n" + " --es-hosts VALUE1,VALUE2,VALUE3\n" + " --es-locale-encoding TEXT\n" + " 
--es-point-in-time-keep-alive TEXT\n" + " --es-refresh-after-write TEXT\n" + " fs backend: \n" + " --fs-default-chunk-size INTEGER\n" + " --fs-default-directory-path PATH\n" + " --fs-default-lrs-file TEXT\n" + " --fs-default-query-string TEXT\n" + " --fs-locale-encoding TEXT\n" + " mongo backend: \n" + " --mongo-client-options KEY=VALUE,KEY=VALUE\n" + " --mongo-connection-uri TEXT\n" + " --mongo-default-chunk-size INTEGER\n" + " --mongo-default-collection TEXT\n" + " --mongo-default-database TEXT\n" + " --mongo-locale-encoding TEXT\n" " -h, --host TEXT LRS server host name\n" " -p, --port INTEGER LRS server port\n" " --help Show this message and exit.\n" ) assert result.exit_code == 0 assert expected_output in result.output + + result = runner.invoke(cli, ["runserver"]) + assert result.exit_code > 0 + assert ( + "Missing option '-b' / '--backend'. Choose from:\n\tasync_es,\n\tasync_mongo,\n" + "\tclickhouse,\n\tes,\n\tfs,\n\tmongo\n" + ) in result.output From cf9977fb0b75f4bd9d2f0abda78c816f32645f39 Mon Sep 17 00:00:00 2001 From: lleeoo Date: Thu, 19 Oct 2023 11:41:31 +0200 Subject: [PATCH 39/65] =?UTF-8?q?=F0=9F=93=9D(docs)=20add=20useful=20comma?= =?UTF-8?q?nds=20to=20"contributing"=20in=20README?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit As a new developer on Ralph, it is not obvious what commands are useful and where to find them. This commit adds info on testing, linting, and updating dependencies to the README in the contributing section. --- README.md | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/README.md b/README.md index 62ea70414..fc18c976b 100644 --- a/README.md +++ b/README.md @@ -171,6 +171,39 @@ We try to raise our code quality standards and expect contributors to follow the recommendations from our [handbook](https://handbook.openfun.fr). +### Useful commands + +Bootstrap the project: + +``` +$ make bootstrap +``` + +Run tests: + +``` +$ make test +``` + +Run all linters: + +``` +$ make lint +``` + +If you add new dependencies to the project, you will have to rebuild the Docker +image (and the development environment): + +``` +$ make down && make bootstrap +``` + +You can explore all available rules using: + +``` +$ make help +``` + ## License This work is released under the MIT License (see [LICENSE](./LICENSE.md)). From f8b75c88323104cad063054befb06e55bf95b6d3 Mon Sep 17 00:00:00 2001 From: Claude Dioudonnat Date: Fri, 20 Oct 2023 17:03:41 +0200 Subject: [PATCH 40/65] =?UTF-8?q?=F0=9F=90=9B(helm)=20fix=20clickhouse=20H?= =?UTF-8?q?elm's=20version=20(#477)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Version 23.x is the ClickHouse Docker image version, not the chart version. The latest major version of the ClickHouse Helm chart is 4.
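For reference, a minimal sketch of how such a pin could be sanity-checked from Python — assuming PyYAML is installed and the chart manifest lives at `src/helm/ralph/Chart.yaml` (both are assumptions, not part of this change):

```
# Hypothetical guard against chart-version drift; PyYAML is an assumed dependency.
from pathlib import Path

import yaml


def check_clickhouse_chart_pin(chart_path="src/helm/ralph/Chart.yaml"):
    """Ensure the clickhouse dependency stays on the 4.x.x chart line."""
    chart = yaml.safe_load(Path(chart_path).read_text())
    dependency = next(
        dep for dep in chart["dependencies"] if dep["name"] == "clickhouse"
    )
    # The chart version (4.x.x) is distinct from the ClickHouse image tag (23.x).
    assert dependency["version"].startswith("4."), dependency["version"]
```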
--- CHANGELOG.md | 1 + src/helm/ralph/Chart.yaml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cef870b93..4b6548f45 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,6 +45,7 @@ have an authority field matching that of the user - API: Add `RALPH_LRS_RESTRICT_BY_AUTHORITY` option making `?mine=True` implicit - CLI: list cli usage strings in alphabetical order +- Helm: Fix clickhouse version ### Fixed diff --git a/src/helm/ralph/Chart.yaml b/src/helm/ralph/Chart.yaml index 9638b02a3..bc4130102 100644 --- a/src/helm/ralph/Chart.yaml +++ b/src/helm/ralph/Chart.yaml @@ -12,6 +12,6 @@ dependencies: repository: oci://registry-1.docker.io/bitnamicharts condition: mongodb.enabled - name: clickhouse - version: 23.x.x + version: 4.x.x repository: oci://registry-1.docker.io/bitnamicharts condition: clickhouse.enabled From 8493ac62174bf814d1bfba80eda9b92017c1d740 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Mon, 23 Oct 2023 00:51:44 +0000 Subject: [PATCH 41/65] =?UTF-8?q?=E2=AC=86=EF=B8=8F(project)=20upgrade=20p?= =?UTF-8?q?ython=20dependencies?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit | datasource | package | from | to | | ---------- | ---------- | ------- | ------- | | pypi | black | 23.9.1 | 23.10.0 | | pypi | fastapi | 0.103.2 | 0.104.0 | | pypi | hypothesis | 6.88.0 | 6.88.1 | | pypi | pylint | 3.0.1 | 3.0.2 | --- setup.cfg | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.cfg b/setup.cfg index d719bd36b..b232ec5c4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -71,11 +71,11 @@ cli = dev = anyio<4.0.1 # unpin until fastapi supports new major version of anyio bandit==1.7.5 - black==23.9.1 + black==23.10.0 cryptography==41.0.4 factory-boy==3.3.0 flake8==6.1.0 - hypothesis==6.88.0 + hypothesis==6.88.1 isort==5.12.0 logging-gelf==0.0.31 mkdocs==1.5.3 @@ -85,7 +85,7 @@ dev = moto==4.2.6 pydocstyle==6.3.0 pyfakefs==5.3.0 - pylint==3.0.1 + pylint==3.0.2 pytest==7.4.2 pytest-asyncio==0.21.1 pytest-cov==4.1.0 @@ -96,7 +96,7 @@ ci = twine==4.0.2 lrs = bcrypt==4.0.1 - fastapi==0.103.2 + fastapi==0.104.0 cachetools==5.3.1 ; We temporary pin `h11` to avoid pip downloading the latest version to solve a ; dependency conflict caused by `httpx` which requires httpcore>=0.15.0,<0.16.0 and From 8495825901782c799cca09bc0f02ea1477717f9b Mon Sep 17 00:00:00 2001 From: Quitterie Lucas Date: Mon, 23 Oct 2023 10:27:14 +0200 Subject: [PATCH 42/65] =?UTF-8?q?=F0=9F=93=9D(project)=20update=20CHANGELO?= =?UTF-8?q?G.md?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changed: - Upgrade `fastapi` to `0.104.0` --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b6548f45..aa5d4f9dd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,7 +27,7 @@ methods under the unified `lrs` backend interface [BC] - `GET /statements` now has "mine" option which matches statements that have an authority field matching that of the user - CLI: change `push` to `write` and `fetch` to `read` [BC] -- Upgrade `fastapi` to `0.103.2` +- Upgrade `fastapi` to `0.104.0` - Upgrade `more-itertools` to `10.1.0` - Upgrade `sentry_sdk` to `1.32.0` - Upgrade `uvicorn` to `0.23.2` From 85be17eae92f01d6f3b4bc932455c01ad7ea0793 Mon Sep 17 00:00:00 2001 From: Rodolphe Prin Date: Fri, 15 Sep 2023 10:16:04 +0200 Subject: [PATCH 43/65] 
=?UTF-8?q?=F0=9F=94=A7(helm)=20improve=20volumes=20?= =?UTF-8?q?and=20ingress=20configurations?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We need more flexibility to better integrate Ralph deployments using the official Helm Chart. --- CHANGELOG.md | 1 + src/helm/ralph/templates/deployment.yaml | 2 +- src/helm/ralph/templates/ingress.yaml | 12 +++++++++--- src/helm/ralph/templates/pvc.yml | 2 ++ src/helm/ralph/values.yaml | 21 ++++++++++----------- 5 files changed, 23 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index aa5d4f9dd..79aaa12f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,7 @@ have an authority field matching that of the user implicit - CLI: list cli usage strings in alphabetical order - Helm: Fix clickhouse version +- Helm: improve volumes and ingress configurations ### Fixed diff --git a/src/helm/ralph/templates/deployment.yaml b/src/helm/ralph/templates/deployment.yaml index c0279f9db..ade4a3096 100644 --- a/src/helm/ralph/templates/deployment.yaml +++ b/src/helm/ralph/templates/deployment.yaml @@ -72,7 +72,7 @@ spec: - name: {{ .Values.volumes.history.name }} {{- if .Values.volumes.history.enabled }} persistentVolumeClaim: - claimName: {{ .Values.volumes.history.claimName }} + claimName: {{ if .Values.volumes.history.existingClaim }}{{ .Values.volumes.history.existingClaim }}{{- else }}{{ .Values.volumes.history.claimName }}{{- end }} {{- else }} emptyDir: {} {{- end }} diff --git a/src/helm/ralph/templates/ingress.yaml b/src/helm/ralph/templates/ingress.yaml index 0bc660216..e42c545a5 100644 --- a/src/helm/ralph/templates/ingress.yaml +++ b/src/helm/ralph/templates/ingress.yaml @@ -11,11 +11,17 @@ metadata: {{- toYaml . | nindent 4 }} {{- end }} spec: - ingressClassName: {{ .Values.ingress.className | quote }} + ingressClassName: {{ .Values.ingress.ingressClassName | quote }} + {{- if .Values.ingress.tls }} tls: + {{- range .Values.ingress.tls }} - hosts: - - {{ .Values.ingress.hostname | quote }} - secretName: {{ printf "%s-tls" .Values.ingress.hostname }} + {{- range .hosts }} + - {{ . | quote }} + {{- end }} + secretName: {{ .secretName }} + {{- end }} + {{- end }} rules: - host: {{ .Values.ingress.hostname | quote }} http: diff --git a/src/helm/ralph/templates/pvc.yml b/src/helm/ralph/templates/pvc.yml index 61b04ac4f..fde813187 100644 --- a/src/helm/ralph/templates/pvc.yml +++ b/src/helm/ralph/templates/pvc.yml @@ -1,4 +1,5 @@ {{- if .Values.volumes.history.enabled }} +{{- if not .Values.volumes.history.existingClaim -}} apiVersion: v1 kind: PersistentVolumeClaim metadata: @@ -14,3 +15,4 @@ spec: storage: {{ .Values.volumes.history.size }} storageClassName: {{ .Values.volumes.history.storageClass }} {{- end }} +{{- end }} diff --git a/src/helm/ralph/values.yaml b/src/helm/ralph/values.yaml index 9abfa4433..41b917546 100644 --- a/src/helm/ralph/values.yaml +++ b/src/helm/ralph/values.yaml @@ -27,16 +27,12 @@ service: ingress: enabled: false ingressClassName: "" - hostname: "" + hostname: "ralph.example.com" annotations: {} - -persistence: - enabled: true - storageClass: "local-storage" - accessModes: - - ReadWriteMany - size: 2Gi - existingClaim: "" + tls: + - hosts: + - "ralph.example.com" + secretName: "ralph-example-com-tls" affinity: podAntiAffinity: @@ -64,7 +60,7 @@ tolerations: [] resources: {} -envFromSecret: 'ralph-env' +envFromSecret: "ralph-env" envSecrets: {} existingSecret: false @@ -79,10 +75,13 @@ volumes: size: 2Gi accessModes: ReadWriteMany storageClass: "" + # Use an existing claim. If specified, the **history** + # PersistentVolumeClaim will **not** be created. + existingClaim: "" lrs: port: 8080 - authSecretName: 'ralph-lrs-auth' + authSecretName: "ralph-lrs-auth" # Authentication # # For each entry, we expect the following keys: From 247be51047fd284d2a40b870bc7d64943df9cdcd Mon Sep 17 00:00:00 2001 From: Claude Dioudonnat Date: Mon, 23 Oct 2023 10:14:17 +0200 Subject: [PATCH 44/65] =?UTF-8?q?=F0=9F=94=A7(helm)=20improve=20ingress's?= =?UTF-8?q?=20host=20management=20(#476)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit For each host, you can define a list of domains and a common TLS certificate. The ingress is also fixed to add a route for every domain. The ingress values `tls` and `hostname` are removed.
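To illustrate, the expansion the template now performs can be sketched in Python (a minimal sketch only: the `expand_hosts` helper is hypothetical, and its input mirrors the new `ingress.hosts` values layout; the actual logic lives in the Helm template):

```
# Hypothetical sketch of what the new ingress template renders:
# one HTTP rule per domain, plus one TLS entry per host group that
# declares a certificate.
from typing import Dict, List, Tuple


def expand_hosts(hosts: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
    """Build ingress rules and TLS entries from `ingress.hosts` values."""
    rules, tls = [], []
    for host in hosts:
        # A route is added for every domain of every host entry.
        for domain in host.get("domains", []):
            rules.append({"host": domain, "path": "/", "pathType": "Prefix"})
        # TLS entries are only collected for host groups defining `tls`.
        if host.get("tls"):
            tls.append(
                {"hosts": host["domains"], "secretName": host["tls"]["secretName"]}
            )
    return rules, tls


rules, tls = expand_hosts(
    [{"domains": ["ralph.example.com"], "tls": {"secretName": "ralph-example-com-tls"}}]
)
```

Grouping TLS by host entry is what lets several domains share a single certificate secret.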
Co-authored-by: Julien Maupetit --- CHANGELOG.md | 2 ++ src/helm/ralph/templates/ingress.yaml | 34 +++++++++++++++++---------- src/helm/ralph/values.yaml | 10 ++++---- 3 files changed, 28 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 79aaa12f6..cfe5fa78f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to - Implement Pydantic model for LRS Statements resource query parameters - Implement xAPI LMS Profile statements validation - `EdX` to `xAPI` converters for enrollment events +- Helm: Add variable ``ingress.hosts`` ### Changed @@ -57,6 +58,7 @@ have an authority field matching that of the user - `school`, `course`, `module` context extensions in Edx to xAPI base converter - `name` field in `VideoActivity` xAPI model mistakenly used in `video` profile +- Helm: remove variable ``ingress.hostname`` and ``ingress.tls`` ## [3.9.0] - 2023-07-21 diff --git a/src/helm/ralph/templates/ingress.yaml b/src/helm/ralph/templates/ingress.yaml index e42c545a5..2fe0d4c9b 100644 --- a/src/helm/ralph/templates/ingress.yaml +++ b/src/helm/ralph/templates/ingress.yaml @@ -12,25 +12,33 @@ metadata: {{- end }} spec: ingressClassName: {{ .Values.ingress.ingressClassName | quote }} - {{- if .Values.ingress.tls }} - tls: - {{- range .Values.ingress.tls }} - - hosts: - {{- range .hosts }} - - {{ . | quote }} - {{- end }} - secretName: {{ .secretName }} - {{- end }} - {{- end }} + + {{- $tls := (list) }} rules: + {{- $outer := . }} + {{- range .Values.ingress.hosts }} + {{- if .tls }} + {{- $tls = (concat $tls (list .) ) }} + {{- end }} + {{- range .domains }} + - host: {{ . | quote}} http: paths: - path: / pathType: Prefix backend: service: - name: {{ include "ralph.fullname" . }} + name: {{ template "ralph.fullname" $outer }} port: - number: {{ .Values.service.port }} + number: {{ $outer.Values.service.port }} + {{- end }} + {{- end }} + tls: + {{- range $tls }} + - hosts: + {{- range .domains }} + - {{ .| quote }} + {{- end }} + secretName: {{ .tls.secretName }} + {{- end }} {{- end }} diff --git a/src/helm/ralph/values.yaml b/src/helm/ralph/values.yaml index 41b917546..383db4b92 100644 --- a/src/helm/ralph/values.yaml +++ b/src/helm/ralph/values.yaml @@ -27,12 +27,12 @@ service: ingress: enabled: false ingressClassName: "" - hostname: "ralph.example.com" + hosts: + - domains: + - ralph.example.com + tls: + secretName: "ralph-example-com-tls" annotations: {} - tls: - - hosts: - - "ralph.example.com" - secretName: "ralph-example-com-tls" affinity: podAntiAffinity: From 04c636f745822810a90621b149e37d17a1a0a328 Mon Sep 17 00:00:00 2001 From: lleeoo Date: Tue, 24 Oct 2023 17:16:28 +0200 Subject: [PATCH 45/65] =?UTF-8?q?=E2=9C=A8(api)=20add=20option=20to=20enfo?= =?UTF-8?q?rce=20scopes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The current state of Ralph allows restricting users by authority, but does not allow a/ an admin user b/ finer access control (read, write). This PR aims to solve this issue by implementing `RESTRICT_BY_SCOPES` (the `scopes` field is already present in user accounts), which restricts access when enabled.
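The intended usage can be sketched as follows (illustrative only: the endpoint is simplified, while the dependency and scope names are those introduced by this patch):

```
# Sketch: with RALPH_LRS_RESTRICT_BY_SCOPES=True, an endpoint declares
# the scope it requires through FastAPI's Security dependency, and the
# authentication dependency rejects users whose scopes do not cover it.
from fastapi import FastAPI, Security
from typing_extensions import Annotated

from ralph.api.auth import get_authenticated_user
from ralph.api.auth.user import AuthenticatedUser

app = FastAPI()


@app.get("/xAPI/statements/")
async def get_statements(
    current_user: Annotated[
        AuthenticatedUser,
        Security(get_authenticated_user, scopes=["statements/read/mine"]),
    ],
) -> dict:
    # Users holding a broader scope such as "statements/read" or "all" are
    # also authorized, since user scopes are expanded before the check.
    return {"statements": []}
```

A user whose scopes do not cover the requested scope receives a 401 response. Note that `RALPH_LRS_RESTRICT_BY_SCOPES=True` also requires `RALPH_LRS_RESTRICT_BY_AUTHORITY=True`.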
--- CHANGELOG.md | 2 + docs/api.md | 24 ++- setup.cfg | 2 +- src/ralph/api/auth/__init__.py | 13 +- src/ralph/api/auth/basic.py | 28 +++- src/ralph/api/auth/oidc.py | 30 +++- src/ralph/api/auth/user.py | 49 +++++- src/ralph/api/routers/statements.py | 39 +++-- src/ralph/conf.py | 18 +- tests/api/auth/test_basic.py | 44 ++--- tests/api/auth/test_oidc.py | 84 ++-------- tests/api/test_statements_get.py | 252 +++++++++++++++++++++++----- tests/api/test_statements_post.py | 156 +++++++++++++---- tests/api/test_statements_put.py | 144 ++++++++++++---- tests/conftest.py | 2 +- tests/fixtures/auth.py | 87 ++++++++-- tests/test_cli.py | 2 +- tests/test_conf.py | 18 ++ 18 files changed, 731 insertions(+), 263 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cfe5fa78f..0c4d4410a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,8 @@ have an authority field matching that of the user - CLI: list cli usage strings in alphabetical order - Helm: Fix clickhouse version - Helm: improve volumes and ingress configurations +- API: Add `RALPH_LRS_RESTRICT_BY_SCOPES` option enabling endpoint access + control by user scopes ### Fixed diff --git a/docs/api.md b/docs/api.md index 652405e2e..3e62c6a77 100644 --- a/docs/api.md +++ b/docs/api.md @@ -178,7 +178,7 @@ By default, all authenticated users have full read and write access to the serve ### Filtering results by authority (multitenancy) -In Ralph, all incoming statements are assigned an `authority` (or ownership) derived from the user that makes the call. You may restrict read access to users "own" statements (thus enabling multitenancy) by setting the following environment variable: +In Ralph LRS, all incoming statements are assigned an `authority` (or ownership) derived from the user that makes the call. You may restrict read access to users "own" statements (thus enabling multitenancy) by setting the following environment variable: ``` RALPH_LRS_RESTRICT_BY_AUTHORITY = True # Default: False ``` @@ -190,7 +190,27 @@ NB: If not using "scopes", or for users with limited "scopes", using this option #### Scopes -(Work In Progress) +In Ralph, users are assigned scopes which may be used to restrict endpoint access or +functionalities.
You may enable this option by setting the following environment variable: + +``` +RALPH_LRS_RESTRICT_BY_SCOPES = True # Default: False +``` + +Valid scopes are a slight variation on those proposed by the +[xAPI specification](https://github.com/adlnet/xAPI-Spec/blob/master/xAPI-Communication.md#details-15): + + +- statements/write +- statements/read/mine +- statements/read +- state/write +- state/read +- define +- profile/write +- profile/read +- all/read +- all ## Forwarding statements diff --git a/setup.cfg b/setup.cfg index b232ec5c4..472414dc5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -142,7 +142,7 @@ match = ^(?!(setup)\.(py)$).*\.(py)$ [isort] known_ralph=ralph sections=FUTURE,STDLIB,THIRDPARTY,RALPH,FIRSTPARTY,LOCALFOLDER -skip_glob=venv +skip_glob=venv,*/.conda/* profile=black [tool:pytest] diff --git a/src/ralph/api/auth/__init__.py b/src/ralph/api/auth/__init__.py index f5e80b737..80aa52fff 100644 --- a/src/ralph/api/auth/__init__.py +++ b/src/ralph/api/auth/__init__.py @@ -1,12 +1,11 @@ """Main module for Ralph's LRS API authentication.""" -from ralph.api.auth.basic import get_authenticated_user as get_basic_user -from ralph.api.auth.oidc import get_authenticated_user as get_oidc_user +from ralph.api.auth.basic import get_basic_auth_user +from ralph.api.auth.oidc import get_oidc_user from ralph.conf import settings # At startup, select the authentication mode that will be used -get_authenticated_user = ( - get_oidc_user - if settings.RUNSERVER_AUTH_BACKEND == settings.AuthBackends.OIDC - else get_basic_user -) +if settings.RUNSERVER_AUTH_BACKEND == settings.AuthBackends.OIDC: + get_authenticated_user = get_oidc_user +else: + get_authenticated_user = get_basic_auth_user diff --git a/src/ralph/api/auth/basic.py b/src/ralph/api/auth/basic.py index ddabb5add..04dfcce59 100644 --- a/src/ralph/api/auth/basic.py +++ b/src/ralph/api/auth/basic.py @@ -9,7 +9,7 @@ import bcrypt from cachetools import TTLCache, cached from fastapi import Depends, HTTPException, status -from fastapi.security import HTTPBasic, HTTPBasicCredentials +from fastapi.security import HTTPBasic, HTTPBasicCredentials, SecurityScopes from pydantic import BaseModel, root_validator from starlette.authentication import AuthenticationError @@ -102,15 +102,17 @@ def get_stored_credentials(auth_file: Path) -> ServerUsersCredentials: @cached( TTLCache(maxsize=settings.AUTH_CACHE_MAX_SIZE, ttl=settings.AUTH_CACHE_TTL), lock=Lock(), - key=lambda credentials: ( + key=lambda credentials, security_scopes: ( credentials.username, credentials.password, + security_scopes.scope_str, ) if credentials is not None else None, ) -def get_authenticated_user( +def get_basic_auth_user( credentials: Union[HTTPBasicCredentials, None] = Depends(security), + security_scopes: SecurityScopes = SecurityScopes([]), ) -> AuthenticatedUser: """Checks valid auth parameters. 
@@ -119,13 +121,10 @@ def get_authenticated_user( Args: credentials (iterator): auth parameters from the Authorization header - - Return: - AuthenticatedUser (AuthenticatedUser) + security_scopes: scopes requested for access Raises: HTTPException - """ if not credentials: logger.error("The basic authentication mode requires a Basic Auth header") @@ -156,6 +155,7 @@ def get_authenticated_user( status_code=status.HTTP_403_FORBIDDEN, detail=str(exc) ) from exc + # Check that a password was passed if not hashed_password: # We're doing a bogus password check anyway to avoid timing attacks on # usernames @@ -168,6 +168,7 @@ def get_authenticated_user( headers={"WWW-Authenticate": "Basic"}, ) + # Check password validity if not bcrypt.checkpw( credentials.password.encode(settings.LOCALE_ENCODING), hashed_password.encode(settings.LOCALE_ENCODING), @@ -182,4 +183,15 @@ def get_authenticated_user( headers={"WWW-Authenticate": "Basic"}, ) - return AuthenticatedUser(scopes=user.scopes, agent=user.agent) + user = AuthenticatedUser(scopes=user.scopes, agent=dict(user.agent)) + + # Restrict access by scopes + if settings.LRS_RESTRICT_BY_SCOPES: + for requested_scope in security_scopes.scopes: + if not user.scopes.is_authorized(requested_scope): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f'Access not authorized to scope: "{requested_scope}".', + headers={"WWW-Authenticate": "Basic"}, + ) + return user diff --git a/src/ralph/api/auth/oidc.py b/src/ralph/api/auth/oidc.py index 423cfbb5c..d4476f3f0 100644 --- a/src/ralph/api/auth/oidc.py +++ b/src/ralph/api/auth/oidc.py @@ -2,16 +2,17 @@ import logging from functools import lru_cache -from typing import Optional, Union +from typing import Optional import requests from fastapi import Depends, HTTPException, status -from fastapi.security import OpenIdConnect +from fastapi.security import OpenIdConnect, SecurityScopes from jose import ExpiredSignatureError, JWTError, jwt from jose.exceptions import JWTClaimsError from pydantic import AnyUrl, BaseModel, Extra +from typing_extensions import Annotated -from ralph.api.auth.user import AuthenticatedUser +from ralph.api.auth.user import AuthenticatedUser, UserScopes from ralph.conf import settings OPENID_CONFIGURATION_PATH = "/.well-known/openid-configuration" @@ -92,8 +93,9 @@ def get_public_keys(jwks_uri: AnyUrl) -> dict: ) from exc -def get_authenticated_user( - auth_header: Union[str, None] = Depends(oauth2_scheme) +def get_oidc_user( + auth_header: Annotated[Optional[str], Depends(oauth2_scheme)], + security_scopes: SecurityScopes = SecurityScopes([]), ) -> AuthenticatedUser: """Decode and validate OpenId Connect ID token against issuer in config. 
@@ -143,7 +145,19 @@ def get_authenticated_user( id_token = IDToken.parse_obj(decoded_token) - return AuthenticatedUser( - agent={"openid": id_token.sub}, - scopes=id_token.scope.split(" ") if id_token.scope else [], + user = AuthenticatedUser( + agent={"openid": f"{id_token.iss}/{id_token.sub}"}, + scopes=UserScopes(id_token.scope.split(" ") if id_token.scope else []), ) + + # Restrict access by scopes + if settings.LRS_RESTRICT_BY_SCOPES: + for requested_scope in security_scopes.scopes: + if not user.scopes.is_authorized(requested_scope): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f'Access not authorized to scope: "{requested_scope}".', + headers={"WWW-Authenticate": "Basic"}, + ) + + return user diff --git a/src/ralph/api/auth/user.py b/src/ralph/api/auth/user.py index 6184a7611..9d61f0c4d 100644 --- a/src/ralph/api/auth/user.py +++ b/src/ralph/api/auth/user.py @@ -1,6 +1,7 @@ """Authenticated user for the Ralph API.""" -from typing import Dict, List, Literal +from functools import lru_cache +from typing import Dict, FrozenSet, Literal from pydantic import BaseModel @@ -18,6 +19,50 @@ ] +class UserScopes(FrozenSet[Scope]): + """Scopes available to users.""" + + @lru_cache(maxsize=1024) + def is_authorized(self, requested_scope: Scope): + """Check if the requested scope can be accessed based on user scopes.""" + expanded_scopes = { + "statements/read": {"statements/read/mine", "statements/read"}, + "all/read": { + "statements/read/mine", + "statements/read", + "state/read", + "profile/read", + "all/read", + }, + "all": { + "statements/write", + "statements/read/mine", + "statements/read", + "state/read", + "state/write", + "define", + "profile/read", + "profile/write", + "all/read", + "all", + }, + } + + expanded_user_scopes = set() + for scope in self: + expanded_user_scopes.update(expanded_scopes.get(scope, {scope})) + + return requested_scope in expanded_user_scopes + + @classmethod + def __get_validators__(cls): # noqa: D105 + def validate(value: FrozenSet[Scope]): + """Transform value to an instance of UserScopes.""" + return cls(value) + + yield validate + + class AuthenticatedUser(BaseModel): """Pydantic model for user authentication. 
@@ -27,4 +72,4 @@ class AuthenticatedUser(BaseModel): """ agent: Dict - scopes: List[Scope] + scopes: UserScopes diff --git a/src/ralph/api/routers/statements.py b/src/ralph/api/routers/statements.py index 49a3433cb..0bfb27212 100644 --- a/src/ralph/api/routers/statements.py +++ b/src/ralph/api/routers/statements.py @@ -15,6 +15,7 @@ Query, Request, Response, + Security, status, ) from fastapi.dependencies.models import Dependant @@ -101,6 +102,7 @@ def _enrich_statement_with_authority(statement: dict, current_user: Authenticate def _parse_agent_parameters(agent_obj: dict): """Parse a dict and return an AgentParameters object to use in queries.""" # Transform agent to `dict` as FastAPI cannot parse JSON (seen as string) + agent = parse_obj_as(BaseXapiAgent, agent_obj) agent_query_params = {} @@ -137,10 +139,12 @@ def strict_query_params(request: Request): @router.get("") @router.get("/") -# pylint: disable=too-many-arguments, too-many-locals async def get( request: Request, - current_user: Annotated[AuthenticatedUser, Depends(get_authenticated_user)], + current_user: Annotated[ + AuthenticatedUser, + Security(get_authenticated_user, scopes=["statements/read/mine"]), + ], ### # Query string parameters defined by the LRS specification ### @@ -170,7 +174,6 @@ async def get( "of the Statement is an Activity with the specified id" ), ), - # pylint: disable=unused-argument registration: Optional[UUID] = Query( None, description=( @@ -178,7 +181,6 @@ async def get( "Filter, only return Statements matching the specified registration id" ), ), - # pylint: disable=unused-argument related_activities: Optional[bool] = Query( False, description=( @@ -189,7 +191,6 @@ async def get( "instead of that parameter's normal behaviour" ), ), - # pylint: disable=unused-argument related_agents: Optional[bool] = Query( False, description=( @@ -221,7 +222,6 @@ async def get( "0 indicates return the maximum the server will allow" ), ), - # pylint: disable=unused-argument, redefined-builtin format: Optional[Literal["ids", "exact", "canonical"]] = Query( "exact", description=( @@ -240,7 +240,6 @@ async def get( 'as in "exact" mode.' 
), ), - # pylint: disable=unused-argument attachments: Optional[bool] = Query( False, description=( @@ -286,6 +285,9 @@ async def get( LRS Specification: https://github.com/adlnet/xAPI-Spec/blob/1.0.3/xAPI-Communication.md#213-get-statements """ + # pylint: disable=unused-argument,redefined-builtin,too-many-arguments + # pylint: disable=too-many-locals + # Make sure the limit does not go above max from settings limit = min(limit, settings.RUNSERVER_MAX_SEARCH_HITS_COUNT) @@ -334,14 +336,15 @@ async def get( json.loads(query_params["agent"]) ) - if settings.LRS_RESTRICT_BY_AUTHORITY: - # If using scopes, only restrict results when appropriate - if settings.LRS_RESTRICT_BY_SCOPES: - raise NotImplementedError("Scopes are not yet implemented in Ralph.") - - # Otherwise, enforce mine for all users + # mine: If using scopes, only restrict users with limited scopes + if settings.LRS_RESTRICT_BY_SCOPES: + if not current_user.scopes.is_authorized("statements/read"): + mine = True + # mine: If using only authority, always restrict (otherwise, use the default value) + elif settings.LRS_RESTRICT_BY_AUTHORITY: mine = True + # Filter by authority if using `mine` if mine: query_params["authority"] = _parse_agent_parameters(current_user.agent) @@ -399,7 +402,10 @@ async def get( @router.put("", responses=POST_PUT_RESPONSES, status_code=status.HTTP_204_NO_CONTENT) # pylint: disable=unused-argument, too-many-branches async def put( - current_user: Annotated[AuthenticatedUser, Depends(get_authenticated_user)], + current_user: Annotated[ + AuthenticatedUser, + Security(get_authenticated_user, scopes=["statements/write"]), + ], statement: LaxStatement, background_tasks: BackgroundTasks, statement_id: UUID = Query(alias="statementId"), @@ -478,7 +484,10 @@ async def put( @router.post("", responses=POST_PUT_RESPONSES) # pylint: disable = too-many-branches async def post( - current_user: Annotated[AuthenticatedUser, Depends(get_authenticated_user)], + current_user: Annotated[ + AuthenticatedUser, + Security(get_authenticated_user, scopes=["statements/write"]), + ], statements: Union[LaxStatement, List[LaxStatement]], background_tasks: BackgroundTasks, response: Response, diff --git a/src/ralph/conf.py b/src/ralph/conf.py index 0415a5ee3..ad91785ae 100644 --- a/src/ralph/conf.py +++ b/src/ralph/conf.py @@ -19,7 +19,10 @@ from unittest.mock import Mock get_app_dir = Mock(return_value=".") -from pydantic import AnyHttpUrl, AnyUrl, BaseModel, BaseSettings, Extra + +from pydantic import AnyHttpUrl, AnyUrl, BaseModel, BaseSettings, Extra, root_validator + +from ralph.exceptions import ConfigurationException from .utils import import_string @@ -210,5 +213,18 @@ def LOCALE_ENCODING(self) -> str: # pylint: disable=invalid-name """Returns Ralph's default locale encoding.""" return self._CORE.LOCALE_ENCODING + @root_validator(allow_reuse=True) + @classmethod + def check_restriction_compatibility(cls, values): + """Raise an error if scopes are being used without authority restriction.""" + if values.get("LRS_RESTRICT_BY_SCOPES") and not values.get( + "LRS_RESTRICT_BY_AUTHORITY" + ): + raise ConfigurationException( + "LRS_RESTRICT_BY_AUTHORITY must be set to True if using " + "LRS_RESTRICT_BY_SCOPES=True" + ) + return values + settings = Settings() diff --git a/tests/api/auth/test_basic.py b/tests/api/auth/test_basic.py index ebcbf3aea..211f5e411 100644 --- a/tests/api/auth/test_basic.py +++ b/tests/api/auth/test_basic.py @@ -6,15 +6,15 @@ import bcrypt import pytest from fastapi.exceptions import HTTPException -from 
fastapi.security import HTTPBasicCredentials +from fastapi.security import HTTPBasicCredentials, SecurityScopes from ralph.api.auth.basic import ( ServerUsersCredentials, UserCredentials, - get_authenticated_user, + get_basic_auth_user, get_stored_credentials, ) -from ralph.api.auth.user import AuthenticatedUser +from ralph.api.auth.user import AuthenticatedUser, UserScopes from ralph.conf import Settings, settings STORED_CREDENTIALS = json.dumps( @@ -97,18 +97,21 @@ def test_api_auth_basic_caching_credentials(fs): auth_file_path = settings.APP_DIR / "auth.json" fs.create_file(auth_file_path, contents=STORED_CREDENTIALS) - get_authenticated_user.cache_clear() + get_basic_auth_user.cache_clear() + get_stored_credentials.cache_clear() credentials = HTTPBasicCredentials(username="ralph", password="admin") # Call function as in a first request with these credentials - get_authenticated_user(credentials) + get_basic_auth_user( + security_scopes=SecurityScopes(["profile/read"]), credentials=credentials + ) - assert get_authenticated_user.cache.popitem() == ( - ("ralph", "admin"), + assert get_basic_auth_user.cache.popitem() == ( + ("ralph", "admin", "profile/read"), AuthenticatedUser( agent={"mbox": "mailto:ralph@example.com"}, - scopes=["statements/read/mine", "statements/write"], + scopes=UserScopes(["statements/read/mine", "statements/write"]), ), ) @@ -118,13 +121,13 @@ def test_api_auth_basic_with_wrong_password(fs): auth_file_path = settings.APP_DIR / "auth.json" fs.create_file(auth_file_path, contents=STORED_CREDENTIALS) - get_authenticated_user.cache_clear() + get_basic_auth_user.cache_clear() credentials = HTTPBasicCredentials(username="ralph", password="wrong_password") # Call function as in a first request with these credentials with pytest.raises(HTTPException): - get_authenticated_user(credentials) + get_basic_auth_user(credentials, SecurityScopes(["all"])) def test_api_auth_basic_no_credential_file_found(fs, monkeypatch): @@ -132,12 +135,12 @@ def test_api_auth_basic_no_credential_file_found(fs, monkeypatch): monkeypatch.setenv("RALPH_AUTH_FILE", "other_file") monkeypatch.setattr("ralph.api.auth.basic.settings", Settings()) - get_stored_credentials.cache_clear() + get_basic_auth_user.cache_clear() credentials = HTTPBasicCredentials(username="ralph", password="admin") with pytest.raises(HTTPException): - get_authenticated_user(credentials) + get_basic_auth_user(credentials, SecurityScopes(["all"])) def test_get_whoami_no_credentials(basic_auth_test_client): @@ -173,7 +176,7 @@ def test_get_whoami_username_not_found(basic_auth_test_client, fs): """Whoami route returns a 401 error when the username cannot be found.""" credential_bytes = base64.b64encode("john:admin".encode("utf-8")) credentials = str(credential_bytes, "utf-8") - get_authenticated_user.cache_clear() + get_basic_auth_user.cache_clear() auth_file_path = settings.APP_DIR / "auth.json" fs.create_file(auth_file_path, contents=STORED_CREDENTIALS) @@ -195,7 +198,7 @@ def test_get_whoami_wrong_password(basic_auth_test_client, fs): auth_file_path = settings.APP_DIR / "auth.json" fs.create_file(auth_file_path, contents=STORED_CREDENTIALS) - get_authenticated_user.cache_clear() + get_basic_auth_user.cache_clear() response = basic_auth_test_client.get( "/whoami", headers={"Authorization": f"Basic {credentials}"} @@ -217,14 +220,17 @@ def test_get_whoami_correct_credentials(basic_auth_test_client, fs): auth_file_path = settings.APP_DIR / "auth.json" fs.create_file(auth_file_path, contents=STORED_CREDENTIALS) - 
get_authenticated_user.cache_clear() + get_basic_auth_user.cache_clear() response = basic_auth_test_client.get( "/whoami", headers={"Authorization": f"Basic {credentials}"} ) assert response.status_code == 200 - assert response.json() == { - "agent": {"mbox": "mailto:ralph@example.com"}, - "scopes": ["statements/read/mine", "statements/write"], - } + + assert len(response.json().keys()) == 2 + assert response.json()["agent"] == {"mbox": "mailto:ralph@example.com"} + assert sorted(response.json()["scopes"]) == [ + "statements/read/mine", + "statements/write", + ] diff --git a/tests/api/auth/test_oidc.py b/tests/api/auth/test_oidc.py index 0c044bfe6..a0b621f01 100644 --- a/tests/api/auth/test_oidc.py +++ b/tests/api/auth/test_oidc.py @@ -1,47 +1,29 @@ """Tests for the api.auth.oidc module.""" import responses +from pydantic import parse_obj_as from ralph.api.auth.oidc import discover_provider, get_public_keys +from ralph.models.xapi.base.agents import BaseXapiAgentWithOpenId -from tests.fixtures.auth import ISSUER_URI +from tests.fixtures.auth import ISSUER_URI, mock_oidc_user @responses.activate -def test_api_auth_oidc_valid( - oidc_auth_test_client, mock_discovery_response, mock_oidc_jwks, encoded_token -): +def test_api_auth_oidc_valid(oidc_auth_test_client): """Test a valid OpenId Connect authentication.""" - # Clear LRU cache - discover_provider.cache_clear() - get_public_keys.cache_clear() - - # Mock request to get provider configuration - responses.add( - responses.GET, - f"{ISSUER_URI}/.well-known/openid-configuration", - json=mock_discovery_response, - status=200, - ) - - # Mock request to get keys - responses.add( - responses.GET, - mock_discovery_response["jwks_uri"], - json=mock_oidc_jwks, - status=200, - ) + oidc_token = mock_oidc_user(scopes=["all", "profile/read"]) response = oidc_auth_test_client.get( "/whoami", - headers={"Authorization": f"Bearer {encoded_token}"}, + headers={"Authorization": f"Bearer {oidc_token}"}, ) assert response.status_code == 200 - assert response.json() == { - "scopes": ["all", "statements/read"], - "agent": {"openid": "123|oidc"}, - } + assert len(response.json().keys()) == 2 + assert response.json()["agent"] == {"openid": "https://iss.example.com/123|oidc"} + assert parse_obj_as(BaseXapiAgentWithOpenId, response.json()["agent"]) + assert sorted(response.json()["scopes"]) == ["all", "profile/read"] @responses.activate @@ -50,25 +32,7 @@ def test_api_auth_invalid_token( ): """Test API with an invalid audience.""" - # Clear LRU cache - discover_provider.cache_clear() - get_public_keys.cache_clear() - - # Mock request to get provider configuration - responses.add( - responses.GET, - f"{ISSUER_URI}/.well-known/openid-configuration", - json=mock_discovery_response, - status=200, - ) - - # Mock request to get keys - responses.add( - responses.GET, - mock_discovery_response["jwks_uri"], - json=mock_oidc_jwks, - status=200, - ) + mock_oidc_user() response = oidc_auth_test_client.get( "/whoami", @@ -143,34 +107,14 @@ def test_api_auth_invalid_keys( @responses.activate -def test_api_auth_invalid_header( - oidc_auth_test_client, mock_discovery_response, mock_oidc_jwks, encoded_token -): +def test_api_auth_invalid_header(oidc_auth_test_client): """Test API with an invalid request header.""" - # Clear LRU cache - discover_provider.cache_clear() - get_public_keys.cache_clear() - - # Mock request to get provider configuration - responses.add( - responses.GET, - f"{ISSUER_URI}/.well-known/openid-configuration", - json=mock_discovery_response, - status=200, - ) - - 
# Mock request to get keys - responses.add( - responses.GET, - mock_discovery_response["jwks_uri"], - json=mock_oidc_jwks, - status=200, - ) + oidc_token = mock_oidc_user() response = oidc_auth_test_client.get( "/whoami", - headers={"Authorization": f"Wrong header {encoded_token}"}, + headers={"Authorization": f"Wrong header {oidc_token}"}, ) assert response.status_code == 401 diff --git a/tests/api/test_statements_get.py b/tests/api/test_statements_get.py index 56cdd2baf..ec8a24085 100644 --- a/tests/api/test_statements_get.py +++ b/tests/api/test_statements_get.py @@ -5,11 +5,14 @@ from urllib.parse import parse_qs, quote_plus, urlparse import pytest +import responses from elasticsearch.helpers import bulk from fastapi.testclient import TestClient from ralph.api import app -from ralph.api.auth.basic import get_authenticated_user +from ralph.api.auth import get_authenticated_user +from ralph.api.auth.basic import get_basic_auth_user +from ralph.api.auth.oidc import get_oidc_user from ralph.backends.data.base import BaseOperationType from ralph.backends.data.clickhouse import ClickHouseDataBackend from ralph.backends.data.mongo import MongoDataBackend @@ -28,7 +31,7 @@ get_mongo_test_backend, ) -from ..fixtures.auth import mock_basic_auth_user +from ..fixtures.auth import mock_basic_auth_user, mock_oidc_user from ..helpers import mock_activity, mock_agent client = TestClient(app) @@ -81,28 +84,27 @@ def insert_clickhouse_statements(statements): @pytest.fixture(params=["es", "mongo", "clickhouse"]) -# pylint: disable=unused-argument def insert_statements_and_monkeypatch_backend( request, es, mongo, clickhouse, monkeypatch ): """(Security) Return a function that inserts statements into each backend.""" - # pylint: disable=invalid-name + # pylint: disable=invalid-name,unused-argument def _insert_statements_and_monkeypatch_backend(statements): """Inserts statements once into each backend.""" - database_client_class_path = "ralph.api.routers.statements.BACKEND_CLIENT" + backend_client_class_path = "ralph.api.routers.statements.BACKEND_CLIENT" if request.param == "mongo": insert_mongo_statements(mongo, statements) - monkeypatch.setattr(database_client_class_path, get_mongo_test_backend()) + monkeypatch.setattr(backend_client_class_path, get_mongo_test_backend()) return if request.param == "clickhouse": insert_clickhouse_statements(statements) monkeypatch.setattr( - database_client_class_path, get_clickhouse_test_backend() + backend_client_class_path, get_clickhouse_test_backend() ) return insert_es_statements(es, statements) - monkeypatch.setattr(database_client_class_path, get_es_test_backend()) + monkeypatch.setattr(backend_client_class_path, get_es_test_backend()) return _insert_statements_and_monkeypatch_backend @@ -123,8 +125,7 @@ def test_api_statements_get_mine( """(Security) Test that the get statements API route, given a "mine=True" query parameter returns a list of statements filtered by authority. 
""" - # pylint: disable=redefined-outer-name - # pylint: disable=invalid-name + # pylint: disable=redefined-outer-name,invalid-name # Create two distinct agents if ifi == "account_same_home_page": @@ -153,7 +154,7 @@ def test_api_statements_get_mine( ) # Clear cache before each test iteration - get_authenticated_user.cache_clear() + get_basic_auth_user.cache_clear() statements = [ { @@ -233,7 +234,7 @@ def test_api_statements_get_mine( def test_api_statements_get( - insert_statements_and_monkeypatch_backend, auth_credentials + insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route without any filters set up.""" # pylint: disable=redefined-outer-name @@ -253,7 +254,7 @@ def test_api_statements_get( # Confirm that calling this with and without the trailing slash both work for path in ("/xAPI/statements", "/xAPI/statements/"): response = client.get( - path, headers={"Authorization": f"Basic {auth_credentials}"} + path, headers={"Authorization": f"Basic {basic_auth_credentials}"} ) assert response.status_code == 200 @@ -261,7 +262,7 @@ def test_api_statements_get( def test_api_statements_get_ascending( - insert_statements_and_monkeypatch_backend, auth_credentials + insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given an "ascending" query parameter, should return statements in ascending order by their timestamp. @@ -282,7 +283,7 @@ def test_api_statements_get_ascending( response = client.get( "/xAPI/statements/?ascending=true", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 @@ -290,7 +291,7 @@ def test_api_statements_get_ascending( def test_api_statements_get_by_statement_id( - insert_statements_and_monkeypatch_backend, auth_credentials + insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given a "statementId" query parameter, should return a list of statements matching the given statementId. @@ -311,7 +312,7 @@ def test_api_statements_get_by_statement_id( response = client.get( f"/xAPI/statements/?statementId={statements[1]['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 @@ -329,7 +330,7 @@ def test_api_statements_get_by_statement_id( ], ) def test_api_statements_get_by_agent( - ifi, insert_statements_and_monkeypatch_backend, auth_credentials + ifi, insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given an "agent" query parameter, should return a list of statements filtered by the given agent. @@ -365,7 +366,7 @@ def test_api_statements_get_by_agent( response = client.get( f"/xAPI/statements/?agent={quote_plus(json.dumps(agent_1))}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 @@ -373,7 +374,7 @@ def test_api_statements_get_by_agent( def test_api_statements_get_by_verb( - insert_statements_and_monkeypatch_backend, auth_credentials + insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given a "verb" query parameter, should return a list of statements filtered by the given verb id. 
@@ -396,7 +397,7 @@ def test_api_statements_get_by_verb( response = client.get( "/xAPI/statements/?verb=" + quote_plus("http://adlnet.gov/expapi/verbs/played"), - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 @@ -404,7 +405,7 @@ def test_api_statements_get_by_verb( def test_api_statements_get_by_activity( - insert_statements_and_monkeypatch_backend, auth_credentials + insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given an "activity" query parameter, should return a list of statements filtered by the given activity id. @@ -430,7 +431,7 @@ def test_api_statements_get_by_activity( response = client.get( f"/xAPI/statements/?activity={activity_1['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 @@ -439,7 +440,7 @@ def test_api_statements_get_by_activity( # Check that badly formated activity returns an error response = client.get( "/xAPI/statements/?activity=INVALID_IRI", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 422 @@ -447,7 +448,7 @@ def test_api_statements_get_by_activity( def test_api_statements_get_since_timestamp( - insert_statements_and_monkeypatch_backend, auth_credentials + insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given a "since" query parameter, should return a list of statements filtered by the given timestamp. @@ -469,7 +470,7 @@ def test_api_statements_get_since_timestamp( since = (datetime.now() - timedelta(minutes=30)).isoformat() response = client.get( f"/xAPI/statements/?since={since}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 @@ -477,7 +478,7 @@ def test_api_statements_get_since_timestamp( def test_api_statements_get_until_timestamp( - insert_statements_and_monkeypatch_backend, auth_credentials + insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given an "until" query parameter, should return a list of statements filtered by the given timestamp. @@ -499,7 +500,7 @@ def test_api_statements_get_until_timestamp( until = (datetime.now() - timedelta(minutes=30)).isoformat() response = client.get( f"/xAPI/statements/?until={until}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 @@ -507,7 +508,7 @@ def test_api_statements_get_until_timestamp( def test_api_statements_get_with_pagination( - monkeypatch, insert_statements_and_monkeypatch_backend, auth_credentials + monkeypatch, insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given a request leading to more results than can fit on the first page, should return a list of statements non-exceeding the page @@ -546,7 +547,8 @@ def test_api_statements_get_with_pagination( # First response gets the first two results, with a "more" entry as # we have more results to return on a later page. 
first_response = client.get( - "/xAPI/statements/", headers={"Authorization": f"Basic {auth_credentials}"} + "/xAPI/statements/", + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert first_response.status_code == 200 assert first_response.json()["statements"] == [statements[4], statements[3]] @@ -558,7 +560,7 @@ def test_api_statements_get_with_pagination( # Second response gets the missing result from the first response. second_response = client.get( first_response.json()["more"], - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert second_response.status_code == 200 assert second_response.json()["statements"] == [statements[2], statements[1]] @@ -570,14 +572,14 @@ def test_api_statements_get_with_pagination( # Third response gets the missing result from the first response third_response = client.get( second_response.json()["more"], - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert third_response.status_code == 200 assert third_response.json() == {"statements": [statements[0]]} def test_api_statements_get_with_pagination_and_query( - monkeypatch, insert_statements_and_monkeypatch_backend, auth_credentials + monkeypatch, insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given a request with a query parameter leading to more results than can fit on the first page, should return a list @@ -623,7 +625,7 @@ def test_api_statements_get_with_pagination_and_query( first_response = client.get( "/xAPI/statements/?verb=" + quote_plus("https://w3id.org/xapi/video/verbs/played"), - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert first_response.status_code == 200 assert first_response.json()["statements"] == [statements[2], statements[1]] @@ -635,14 +637,14 @@ def test_api_statements_get_with_pagination_and_query( # Second response gets the missing result from the first response. second_response = client.get( first_response.json()["more"], - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert second_response.status_code == 200 assert second_response.json() == {"statements": [statements[0]]} def test_api_statements_get_with_no_matching_statement( - insert_statements_and_monkeypatch_backend, auth_credentials + insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given a query yielding no matching statement, should return an empty list. @@ -663,14 +665,16 @@ def test_api_statements_get_with_no_matching_statement( response = client.get( "/xAPI/statements/?statementId=66c81e98-1763-4730-8cfc-f5ab34f1bad5", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 assert response.json() == {"statements": []} -def test_api_statements_get_with_database_query_failure(auth_credentials, monkeypatch): +def test_api_statements_get_with_database_query_failure( + basic_auth_credentials, monkeypatch +): """Test the get statements API route, given a query raising a BackendException, should return an error response with HTTP code 500. 
""" @@ -687,14 +691,14 @@ def mock_query_statements(*_): response = client.get( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 500 assert response.json() == {"detail": "xAPI statements query failed"} @pytest.mark.parametrize("id_param", ["statementId", "voidedStatementId"]) -def test_api_statements_get_invalid_query_parameters(auth_credentials, id_param): +def test_api_statements_get_invalid_query_parameters(basic_auth_credentials, id_param): """Test error response for invalid query parameters""" id_1 = "be67b160-d958-4f51-b8b8-1892002dbac6" @@ -703,7 +707,7 @@ def test_api_statements_get_invalid_query_parameters(auth_credentials, id_param) # Check for 400 status code when unknown parameters are provided response = client.get( "/xAPI/statements/?mamamia=herewegoagain", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 400 assert response.json() == { @@ -713,7 +717,7 @@ def test_api_statements_get_invalid_query_parameters(auth_credentials, id_param) # Check for 400 status code when both statementId and voidedStatementId are provided response = client.get( f"/xAPI/statements/?statementId={id_1}&voidedStatementId={id_2}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 400 @@ -725,7 +729,7 @@ def test_api_statements_get_invalid_query_parameters(auth_credentials, id_param) ]: response = client.get( f"/xAPI/statements/?{id_param}={id_1}&{invalid_param}={value}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 400 assert response.json() == { @@ -739,6 +743,166 @@ def test_api_statements_get_invalid_query_parameters(auth_credentials, id_param) for valid_param, value in [("format", "ids"), ("attachments", "true")]: response = client.get( f"/xAPI/statements/?{id_param}={id_1}&{valid_param}={value}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code != 400 + + +@responses.activate +@pytest.mark.parametrize("auth_method", ["basic", "oidc"]) +@pytest.mark.parametrize( + "scopes,is_authorized", + [ + (["all"], True), + (["all/read"], True), + (["statements/read/mine"], True), + (["statements/read"], True), + (["profile/write", "statements/read", "all/write"], True), + (["statements/write"], False), + (["profile/read"], False), + (["all/write"], False), + ([], False), + ], +) +def test_api_statements_get_scopes( + monkeypatch, fs, es, auth_method, scopes, is_authorized +): + """Test that getting statements behaves properly according to user scopes.""" + # pylint: disable=invalid-name,too-many-locals,too-many-arguments + + monkeypatch.setattr( + "ralph.api.routers.statements.settings.LRS_RESTRICT_BY_SCOPES", True + ) + monkeypatch.setattr("ralph.api.auth.basic.settings.LRS_RESTRICT_BY_SCOPES", True) + + if auth_method == "basic": + agent = mock_agent("mbox", 1) + credentials = mock_basic_auth_user(fs, scopes=scopes, agent=agent) + headers = {"Authorization": f"Basic {credentials}"} + + app.dependency_overrides[get_authenticated_user] = get_basic_auth_user + get_basic_auth_user.cache_clear() + + elif auth_method == "oidc": + sub = "123|oidc" + iss = 
"https://iss.example.com" + agent = {"openid": f"{iss}/{sub}"} + oidc_token = mock_oidc_user(sub=sub, scopes=scopes) + headers = {"Authorization": f"Bearer {oidc_token}"} + + monkeypatch.setattr( + "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_ISSUER_URI", + "http://providerHost:8080/auth/realms/real_name", + ) + monkeypatch.setattr( + "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_AUDIENCE", + "http://clientHost:8100", + ) + + app.dependency_overrides[get_authenticated_user] = get_oidc_user + + statements = [ + { + "id": "be67b160-d958-4f51-b8b8-1892002dbac6", + "timestamp": (datetime.now() - timedelta(hours=1)).isoformat(), + "actor": agent, + "authority": agent, + }, + { + "id": "72c81e98-1763-4730-8cfc-f5ab34f1bad2", + "timestamp": datetime.now().isoformat(), + "actor": agent, + "authority": agent, + }, + ] + + # NB: scopes are not linked to statements and backends, we therefore test with ES + backend_client_class_path = "ralph.api.routers.statements.BACKEND_CLIENT" + insert_es_statements(es, statements) + monkeypatch.setattr(backend_client_class_path, get_es_test_backend()) + + response = client.get( + "/xAPI/statements/", + headers=headers, + ) + + if is_authorized: + assert response.status_code == 200 + assert response.json() == {"statements": [statements[1], statements[0]]} + else: + assert response.status_code == 401 + assert response.json() == { + "detail": 'Access not authorized to scope: "statements/read/mine".' + } + + app.dependency_overrides.pop(get_authenticated_user, None) + + +@pytest.mark.parametrize( + "scopes,read_all_access", + [ + (["all"], True), + (["all/read", "statements/read/mine"], True), + (["statements/read"], True), + (["statements/read/mine"], False), + ], +) +def test_api_statements_get_scopes_with_authority( + monkeypatch, fs, es, scopes, read_all_access +): + """Test that restricting by scope and by authority behaves properly. + Getting statements should be restricted to mine for users which only have + `statements/read/mine` scope but should not be restricted when the user + has wider scopes. 
+ """ + # pylint: disable=invalid-name + monkeypatch.setattr( + "ralph.api.routers.statements.settings.LRS_RESTRICT_BY_AUTHORITY", True + ) + monkeypatch.setattr( + "ralph.api.routers.statements.settings.LRS_RESTRICT_BY_SCOPES", True + ) + monkeypatch.setattr("ralph.api.auth.basic.settings.LRS_RESTRICT_BY_SCOPES", True) + + agent = mock_agent("mbox", 1) + agent_2 = mock_agent("mbox", 2) + username = "jane" + password = "janepwd" + credentials = mock_basic_auth_user(fs, username, password, scopes, agent) + headers = {"Authorization": f"Basic {credentials}"} + + get_basic_auth_user.cache_clear() + + statements = [ + { + "id": "be67b160-d958-4f51-b8b8-1892002dbac6", + "timestamp": (datetime.now() - timedelta(hours=1)).isoformat(), + "actor": agent, + "authority": agent, + }, + { + "id": "72c81e98-1763-4730-8cfc-f5ab34f1bad2", + "timestamp": datetime.now().isoformat(), + "actor": agent, + "authority": agent_2, + }, + ] + + # NB: scopes are not linked to statements and backends, we therefore test with ES + backend_client_class_path = "ralph.api.routers.statements.BACKEND_CLIENT" + insert_es_statements(es, statements) + monkeypatch.setattr(backend_client_class_path, get_es_test_backend()) + + response = client.get( + "/xAPI/statements/", + headers=headers, + ) + + assert response.status_code == 200 + + if read_all_access: + assert response.json() == {"statements": [statements[1], statements[0]]} + else: + assert response.json() == {"statements": [statements[0]]} + + app.dependency_overrides.pop(get_authenticated_user, None) diff --git a/tests/api/test_statements_post.py b/tests/api/test_statements_post.py index 350e4a11c..fe3e63691 100644 --- a/tests/api/test_statements_post.py +++ b/tests/api/test_statements_post.py @@ -4,15 +4,20 @@ from uuid import uuid4 import pytest +import responses from fastapi.testclient import TestClient from httpx import AsyncClient from ralph.api import app +from ralph.api.auth import get_authenticated_user +from ralph.api.auth.basic import get_basic_auth_user +from ralph.api.auth.oidc import get_oidc_user from ralph.backends.lrs.es import ESLRSBackend from ralph.backends.lrs.mongo import MongoLRSBackend from ralph.conf import XapiForwardingConfigurationSettings from ralph.exceptions import BackendException +from tests.fixtures.auth import mock_basic_auth_user, mock_oidc_user from tests.fixtures.backends import ( ES_TEST_FORWARDING_INDEX, ES_TEST_HOSTS, @@ -28,6 +33,7 @@ from ..helpers import ( assert_statement_get_responses_are_equivalent, + mock_agent, mock_statement, string_is_date, string_is_uuid, @@ -36,7 +42,7 @@ client = TestClient(app) -def test_api_statements_post_invalid_parameters(auth_credentials): +def test_api_statements_post_invalid_parameters(basic_auth_credentials): """Test that using invalid parameters returns the proper status code.""" statement = mock_statement() @@ -44,7 +50,7 @@ def test_api_statements_post_invalid_parameters(auth_credentials): # Check for 400 status code when unknown parameters are provided response = client.post( "/xAPI/statements/?mamamia=herewegoagain", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) assert response.status_code == 400 @@ -59,7 +65,7 @@ def test_api_statements_post_invalid_parameters(auth_credentials): ) # pylint: disable=too-many-arguments def test_api_statements_post_single_statement_directly( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, 
clickhouse ): """Test the post statements API route with one statement.""" # pylint: disable=invalid-name,unused-argument @@ -69,7 +75,7 @@ def test_api_statements_post_single_statement_directly( response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -79,7 +85,8 @@ def test_api_statements_post_single_statement_directly( es.indices.refresh() response = client.get( - "/xAPI/statements/", headers={"Authorization": f"Basic {auth_credentials}"} + "/xAPI/statements/", + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 assert_statement_get_responses_are_equivalent( @@ -89,7 +96,7 @@ def test_api_statements_post_single_statement_directly( # pylint: disable=too-many-arguments def test_api_statements_post_enriching_without_existing_values( - monkeypatch, auth_credentials, es + monkeypatch, basic_auth_credentials, es ): """Test that statements are properly enriched when statement provides no values.""" # pylint: disable=invalid-name,unused-argument @@ -111,7 +118,7 @@ def test_api_statements_post_enriching_without_existing_values( response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -120,7 +127,8 @@ def test_api_statements_post_enriching_without_existing_values( es.indices.refresh() response = client.get( - "/xAPI/statements/", headers={"Authorization": f"Basic {auth_credentials}"} + "/xAPI/statements/", + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) statement = response.json()["statements"][0] @@ -153,7 +161,7 @@ def test_api_statements_post_enriching_without_existing_values( ) # pylint: disable=too-many-arguments def test_api_statements_post_enriching_with_existing_values( - field, value, status, monkeypatch, auth_credentials, es + field, value, status, monkeypatch, basic_auth_credentials, es ): """Test that statements are properly enriched when values are provided.""" # pylint: disable=invalid-name,unused-argument @@ -168,7 +176,7 @@ def test_api_statements_post_enriching_with_existing_values( response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -178,7 +186,8 @@ def test_api_statements_post_enriching_with_existing_values( if status == 200: es.indices.refresh() response = client.get( - "/xAPI/statements/", headers={"Authorization": f"Basic {auth_credentials}"} + "/xAPI/statements/", + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) statement = response.json()["statements"][0] @@ -198,7 +207,7 @@ def test_api_statements_post_enriching_with_existing_values( ) # pylint: disable=too-many-arguments def test_api_statements_post_single_statement_no_trailing_slash( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test that the statements endpoint also works without the trailing slash.""" # pylint: disable=invalid-name,unused-argument @@ -208,7 +217,7 @@ def test_api_statements_post_single_statement_no_trailing_slash( response = client.post( "/xAPI/statements", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -222,7 +231,7 @@ def 
test_api_statements_post_single_statement_no_trailing_slash( ) # pylint: disable=too-many-arguments def test_api_statements_post_list_of_one( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the post statements API route with one statement in a list.""" # pylint: disable=invalid-name,unused-argument @@ -232,7 +241,7 @@ def test_api_statements_post_list_of_one( response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=[statement], ) @@ -241,7 +250,8 @@ def test_api_statements_post_list_of_one( es.indices.refresh() response = client.get( - "/xAPI/statements/", headers={"Authorization": f"Basic {auth_credentials}"} + "/xAPI/statements/", + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 assert_statement_get_responses_are_equivalent( @@ -255,7 +265,7 @@ def test_api_statements_post_list_of_one( ) # pylint: disable=too-many-arguments def test_api_statements_post_list( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the post statements API route with two statements in a list.""" # pylint: disable=invalid-name,unused-argument @@ -272,7 +282,7 @@ def test_api_statements_post_list( response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statements, ) @@ -284,7 +294,8 @@ def test_api_statements_post_list( es.indices.refresh() get_response = client.get( - "/xAPI/statements/", headers={"Authorization": f"Basic {auth_credentials}"} + "/xAPI/statements/", + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert get_response.status_code == 200 @@ -306,7 +317,7 @@ def test_api_statements_post_list( ) # pylint: disable=too-many-arguments def test_api_statements_post_list_with_duplicates( - backend, monkeypatch, auth_credentials, es_data_stream, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es_data_stream, mongo, clickhouse ): """Test the post statements API route with duplicate statement IDs should fail.""" # pylint: disable=invalid-name,unused-argument @@ -316,7 +327,7 @@ def test_api_statements_post_list_with_duplicates( response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=[statement, statement], ) @@ -327,7 +338,8 @@ def test_api_statements_post_list_with_duplicates( # The failure should imply no statement insertion. es_data_stream.indices.refresh() response = client.get( - "/xAPI/statements/", headers={"Authorization": f"Basic {auth_credentials}"} + "/xAPI/statements/", + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 assert response.json() == {"statements": []} @@ -339,7 +351,7 @@ def test_api_statements_post_list_with_duplicates( ) # pylint: disable=too-many-arguments def test_api_statements_post_list_with_duplicate_of_existing_statement( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the post statements API route, given a statement that already exist in the database (has the same ID), should fail. 
@@ -354,7 +366,7 @@ def test_api_statements_post_list_with_duplicate_of_existing_statement( # Post the statement once. response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) assert response.status_code == 200 @@ -366,7 +378,7 @@ def test_api_statements_post_list_with_duplicate_of_existing_statement( # include the ID in the response as it wasn't inserted. response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) assert response.status_code == 204 @@ -376,7 +388,7 @@ def test_api_statements_post_list_with_duplicate_of_existing_statement( # Post the statement again, trying to change the timestamp which is not allowed. response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=[dict(statement, **{"timestamp": "2023-03-15T14:07:51Z"})], ) @@ -387,7 +399,8 @@ def test_api_statements_post_list_with_duplicate_of_existing_statement( } response = client.get( - "/xAPI/statements/", headers={"Authorization": f"Basic {auth_credentials}"} + "/xAPI/statements/", + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 assert_statement_get_responses_are_equivalent( @@ -400,7 +413,7 @@ def test_api_statements_post_list_with_duplicate_of_existing_statement( [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) def test_api_statements_post_with_failure_during_storage( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the post statements API route with a failure happening during storage.""" # pylint: disable=invalid-name,unused-argument, too-many-arguments @@ -416,7 +429,7 @@ def write_mock(*args, **kwargs): response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -429,7 +442,7 @@ def write_mock(*args, **kwargs): [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) def test_api_statements_post_with_failure_during_id_query( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the post statements API route with a failure during query execution.""" # pylint: disable=invalid-name,unused-argument,too-many-arguments @@ -447,7 +460,7 @@ def query_statements_by_ids_mock(*args, **kwargs): response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -461,7 +474,7 @@ def query_statements_by_ids_mock(*args, **kwargs): ) # pylint: disable=too-many-arguments def test_api_statements_post_list_without_forwarding( - backend, auth_credentials, monkeypatch, es, mongo, clickhouse + backend, basic_auth_credentials, monkeypatch, es, mongo, clickhouse ): """Test the post statements API route, given an empty forwarding configuration, should not start the forwarding background task. 
@@ -487,7 +500,7 @@ def spy_mock_forward_xapi_statements(_): response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -520,7 +533,7 @@ async def test_api_statements_post_list_with_forwarding( receiving_backend, forwarding_backend, monkeypatch, - auth_credentials, + basic_auth_credentials, es, es_forwarding, mongo, @@ -594,7 +607,7 @@ async def test_api_statements_post_list_with_forwarding( # The statement should be stored on the forwarding client response = await forwarding_client.get( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 assert_statement_get_responses_are_equivalent( @@ -605,7 +618,7 @@ async def test_api_statements_post_list_with_forwarding( async with AsyncClient() as receiving_client: response = await receiving_client.get( f"http://{RUNSERVER_TEST_HOST}:{RUNSERVER_TEST_PORT}/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 assert_statement_get_responses_are_equivalent( @@ -614,3 +627,74 @@ async def test_api_statements_post_list_with_forwarding( # Stop receiving LRS client await lrs_context.__aexit__(None, None, None) + + +@responses.activate +@pytest.mark.parametrize("auth_method", ["basic", "oidc"]) +@pytest.mark.parametrize( + "scopes,is_authorized", + [ + (["all"], True), + (["profile/read", "statements/write"], True), + (["all/read"], False), + (["statements/read/mine"], False), + (["profile/write"], False), + ([], False), + ], +) +def test_api_statements_post_scopes( + monkeypatch, fs, es, auth_method, scopes, is_authorized +): + """Test that posting statements behaves properly according to user scopes.""" + # pylint: disable=invalid-name,unused-argument + monkeypatch.setattr( + "ralph.api.routers.statements.settings.LRS_RESTRICT_BY_SCOPES", True + ) + monkeypatch.setattr("ralph.api.auth.basic.settings.LRS_RESTRICT_BY_SCOPES", True) + + if auth_method == "basic": + agent = mock_agent("mbox", 1) + credentials = mock_basic_auth_user(fs, scopes=scopes, agent=agent) + headers = {"Authorization": f"Basic {credentials}"} + + app.dependency_overrides[get_authenticated_user] = get_basic_auth_user + get_basic_auth_user.cache_clear() + + elif auth_method == "oidc": + sub = "123|oidc" + agent = {"openid": sub} + oidc_token = mock_oidc_user(sub=sub, scopes=scopes) + headers = {"Authorization": f"Bearer {oidc_token}"} + + monkeypatch.setattr( + "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_ISSUER_URI", + "http://providerHost:8080/auth/realms/real_name", + ) + monkeypatch.setattr( + "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_AUDIENCE", + "http://clientHost:8100", + ) + + app.dependency_overrides[get_authenticated_user] = get_oidc_user + + statement = mock_statement() + + # NB: scopes are not linked to statements and backends, we therefore test with ES + backend_client_class_path = "ralph.api.routers.statements.BACKEND_CLIENT" + monkeypatch.setattr(backend_client_class_path, get_es_test_backend()) + + response = client.post( + "/xAPI/statements/", + headers=headers, + json=statement, + ) + + if is_authorized: + assert response.status_code == 200 + else: + assert response.status_code == 401 + assert response.json() == { + "detail": 'Access not authorized to scope: "statements/write".' 
+ } + + app.dependency_overrides.pop(get_authenticated_user, None) diff --git a/tests/api/test_statements_put.py b/tests/api/test_statements_put.py index 330bccd0f..ae30b2b73 100644 --- a/tests/api/test_statements_put.py +++ b/tests/api/test_statements_put.py @@ -1,17 +1,23 @@ """Tests for the PUT statements endpoint of the Ralph API.""" - +from importlib import reload from uuid import uuid4 import pytest +import responses from fastapi.testclient import TestClient from httpx import AsyncClient +from ralph import api from ralph.api import app +from ralph.api.auth import get_authenticated_user +from ralph.api.auth.basic import get_basic_auth_user +from ralph.api.auth.oidc import get_oidc_user from ralph.backends.lrs.es import ESLRSBackend from ralph.backends.lrs.mongo import MongoLRSBackend from ralph.conf import XapiForwardingConfigurationSettings from ralph.exceptions import BackendException +from tests.fixtures.auth import mock_basic_auth_user, mock_oidc_user from tests.fixtures.backends import ( ES_TEST_FORWARDING_INDEX, ES_TEST_HOSTS, @@ -27,21 +33,23 @@ from ..helpers import ( assert_statement_get_responses_are_equivalent, + mock_agent, mock_statement, string_is_date, ) +reload(api) client = TestClient(app) -def test_api_statements_put_invalid_parameters(auth_credentials): +def test_api_statements_put_invalid_parameters(basic_auth_credentials): """Test that using invalid parameters returns the proper status code.""" statement = mock_statement() # Check for 400 status code when unknown parameters are provided response = client.put( "/xAPI/statements/?mamamia=herewegoagain", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) assert response.status_code == 400 @@ -56,7 +64,7 @@ def test_api_statements_put_invalid_parameters(auth_credentials): ) # pylint: disable=too-many-arguments def test_api_statements_put_single_statement_directly( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the put statements API route with one statement.""" # pylint: disable=invalid-name,unused-argument @@ -66,7 +74,7 @@ def test_api_statements_put_single_statement_directly( response = client.put( f"/xAPI/statements/?statementId={statement['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -75,7 +83,8 @@ def test_api_statements_put_single_statement_directly( es.indices.refresh() response = client.get( - "/xAPI/statements/", headers={"Authorization": f"Basic {auth_credentials}"} + "/xAPI/statements/", + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 assert_statement_get_responses_are_equivalent( @@ -85,7 +94,7 @@ def test_api_statements_put_single_statement_directly( # pylint: disable=too-many-arguments def test_api_statements_put_enriching_without_existing_values( - monkeypatch, auth_credentials, es + monkeypatch, basic_auth_credentials, es ): """Test that statements are properly enriched when statement provides no values.""" # pylint: disable=invalid-name,unused-argument @@ -97,7 +106,7 @@ def test_api_statements_put_enriching_without_existing_values( response = client.put( f"/xAPI/statements/?statementId={statement['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) assert 
response.status_code == 204 @@ -105,7 +114,8 @@ def test_api_statements_put_enriching_without_existing_values( es.indices.refresh() response = client.get( - "/xAPI/statements/", headers={"Authorization": f"Basic {auth_credentials}"} + "/xAPI/statements/", + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) statement = response.json()["statements"][0] @@ -137,7 +147,7 @@ def test_api_statements_put_enriching_without_existing_values( ) # pylint: disable=too-many-arguments def test_api_statements_put_enriching_with_existing_values( - field, value, status, monkeypatch, auth_credentials, es + field, value, status, monkeypatch, basic_auth_credentials, es ): """Test that statements are properly enriched when values are provided.""" # pylint: disable=invalid-name,unused-argument @@ -152,7 +162,7 @@ def test_api_statements_put_enriching_with_existing_values( response = client.put( f"/xAPI/statements/?statementId={statement['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -162,7 +172,8 @@ def test_api_statements_put_enriching_with_existing_values( if status == 204: es.indices.refresh() response = client.get( - "/xAPI/statements/", headers={"Authorization": f"Basic {auth_credentials}"} + "/xAPI/statements/", + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) statement = response.json()["statements"][0] @@ -181,7 +192,7 @@ def test_api_statements_put_enriching_with_existing_values( ) # pylint: disable=too-many-arguments def test_api_statements_put_single_statement_no_trailing_slash( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test that the statements endpoint also works without the trailing slash.""" # pylint: disable=invalid-name,unused-argument @@ -191,7 +202,7 @@ def test_api_statements_put_single_statement_no_trailing_slash( response = client.put( f"/xAPI/statements?statementId={statement['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -204,7 +215,7 @@ def test_api_statements_put_single_statement_no_trailing_slash( ) # pylint: disable=too-many-arguments def test_api_statements_put_id_mismatch( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): # pylint: disable=invalid-name,unused-argument """Test the put statements API route when the statementId doesn't match.""" @@ -214,7 +225,7 @@ def test_api_statements_put_id_mismatch( different_statement_id = str(uuid4()) response = client.put( f"/xAPI/statements/?statementId={different_statement_id}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -230,7 +241,7 @@ def test_api_statements_put_id_mismatch( ) # pylint: disable=too-many-arguments def test_api_statements_put_list_of_one( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): # pylint: disable=invalid-name,unused-argument """Test that we fail on PUTs with a list, even if it's one statement.""" @@ -239,7 +250,7 @@ def test_api_statements_put_list_of_one( response = client.put( f"/xAPI/statements/?statementId={statement['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": 
f"Basic {basic_auth_credentials}"}, json=[statement], ) @@ -252,7 +263,7 @@ def test_api_statements_put_list_of_one( ) # pylint: disable=too-many-arguments def test_api_statements_put_duplicate_of_existing_statement( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the put statements API route, given a statement that already exist in the database (has the same ID), should fail. @@ -265,7 +276,7 @@ def test_api_statements_put_duplicate_of_existing_statement( # Put the statement once. response = client.put( f"/xAPI/statements/?statementId={statement['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) assert response.status_code == 204 @@ -275,7 +286,7 @@ def test_api_statements_put_duplicate_of_existing_statement( # Put the statement twice, trying to change the timestamp, which is not allowed response = client.put( f"/xAPI/statements/?statementId={statement['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=dict(statement, **{"timestamp": "2023-03-15T14:07:51Z"}), ) @@ -286,7 +297,7 @@ def test_api_statements_put_duplicate_of_existing_statement( response = client.get( f"/xAPI/statements/?statementId={statement['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 assert_statement_get_responses_are_equivalent( @@ -299,7 +310,7 @@ def test_api_statements_put_duplicate_of_existing_statement( [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) def test_api_statements_put_with_failure_during_storage( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the put statements API route with a failure happening during storage.""" # pylint: disable=invalid-name,unused-argument, too-many-arguments @@ -315,7 +326,7 @@ def write_mock(*args, **kwargs): response = client.put( f"/xAPI/statements/?statementId={statement['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -328,7 +339,7 @@ def write_mock(*args, **kwargs): [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) def test_api_statements_put_with_a_failure_during_id_query( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the put statements API route with a failure during query execution.""" # pylint: disable=invalid-name,unused-argument,too-many-arguments @@ -346,7 +357,7 @@ def query_statements_by_ids_mock(*args, **kwargs): response = client.put( f"/xAPI/statements/?statementId={statement['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -360,7 +371,7 @@ def query_statements_by_ids_mock(*args, **kwargs): ) # pylint: disable=too-many-arguments def test_api_statements_put_without_forwarding( - backend, auth_credentials, monkeypatch, es, mongo, clickhouse + backend, basic_auth_credentials, monkeypatch, es, mongo, clickhouse ): """Test the put statements API route, given an empty forwarding configuration, should not start the forwarding 
background task. @@ -386,7 +397,7 @@ def spy_mock_forward_xapi_statements(_): response = client.put( f"/xAPI/statements/?statementId={statement['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -418,7 +429,7 @@ async def test_api_statements_put_with_forwarding( receiving_backend, forwarding_backend, monkeypatch, - auth_credentials, + basic_auth_credentials, es, es_forwarding, mongo, @@ -495,7 +506,7 @@ async def test_api_statements_put_with_forwarding( # The statement should be stored on the forwarding client response = await forwarding_client.get( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 assert_statement_get_responses_are_equivalent( @@ -506,7 +517,7 @@ async def test_api_statements_put_with_forwarding( async with AsyncClient() as receiving_client: response = await receiving_client.get( f"http://{RUNSERVER_TEST_HOST}:{RUNSERVER_TEST_PORT}/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 assert_statement_get_responses_are_equivalent( @@ -515,3 +526,74 @@ async def test_api_statements_put_with_forwarding( # Stop receiving LRS client await lrs_context.__aexit__(None, None, None) + + +@responses.activate +@pytest.mark.parametrize("auth_method", ["basic", "oidc"]) +@pytest.mark.parametrize( + "scopes,is_authorized", + [ + (["all"], True), + (["profile/read", "statements/write"], True), + (["all/read"], False), + (["statements/read/mine"], False), + (["profile/write"], False), + ([], False), + ], +) +def test_api_statements_put_scopes( + monkeypatch, fs, es, auth_method, scopes, is_authorized +): + """Test that putting statements behaves properly according to user scopes.""" + # pylint: disable=invalid-name,unused-argument,duplicate-code + monkeypatch.setattr( + "ralph.api.routers.statements.settings.LRS_RESTRICT_BY_SCOPES", True + ) + monkeypatch.setattr("ralph.api.auth.basic.settings.LRS_RESTRICT_BY_SCOPES", True) + + if auth_method == "basic": + agent = mock_agent("mbox", 1) + credentials = mock_basic_auth_user(fs, scopes=scopes, agent=agent) + headers = {"Authorization": f"Basic {credentials}"} + + app.dependency_overrides[get_authenticated_user] = get_basic_auth_user + get_basic_auth_user.cache_clear() + + elif auth_method == "oidc": + sub = "123|oidc" + agent = {"openid": sub} + oidc_token = mock_oidc_user(sub=sub, scopes=scopes) + headers = {"Authorization": f"Bearer {oidc_token}"} + + monkeypatch.setattr( + "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_ISSUER_URI", + "http://providerHost:8080/auth/realms/real_name", + ) + monkeypatch.setattr( + "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_AUDIENCE", + "http://clientHost:8100", + ) + + app.dependency_overrides[get_authenticated_user] = get_oidc_user + + statement = mock_statement() + + # NB: scopes are not linked to statements and backends, we therefore test with ES + backend_client_class_path = "ralph.api.routers.statements.BACKEND_CLIENT" + monkeypatch.setattr(backend_client_class_path, get_es_test_backend()) + + response = client.put( + f"/xAPI/statements/?statementId={statement['id']}", + headers=headers, + json=statement, + ) + + if is_authorized: + assert response.status_code == 204 + else: + assert response.status_code == 401 + assert response.json() == { + 
"detail": 'Access not authorized to scope: "statements/write".' + } + + app.dependency_overrides.pop(get_authenticated_user, None) diff --git a/tests/conftest.py b/tests/conftest.py index 10b819ee3..8165d3458 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ from .fixtures import hypothesis_configuration # noqa: F401 from .fixtures import hypothesis_strategies # noqa: F401 from .fixtures.auth import ( # noqa: F401 - auth_credentials, + basic_auth_credentials, basic_auth_test_client, encoded_token, mock_discovery_response, diff --git a/tests/fixtures/auth.py b/tests/fixtures/auth.py index da4c83868..7e44149b3 100644 --- a/tests/fixtures/auth.py +++ b/tests/fixtures/auth.py @@ -2,9 +2,11 @@ import base64 import json import os +from typing import Optional import bcrypt import pytest +import responses from cryptography.hazmat.primitives import serialization from fastapi.testclient import TestClient from jose import jwt @@ -12,6 +14,7 @@ from ralph.api import app, get_authenticated_user from ralph.api.auth.basic import get_stored_credentials +from ralph.api.auth.oidc import discover_provider, get_public_keys from ralph.conf import settings from . import private_key, public_key @@ -24,10 +27,10 @@ def mock_basic_auth_user( fs_, - username: str, - password: str, - scopes: list, - agent: dict, + username: str = "jane", + password: str = "pwd", + scopes: Optional[list] = None, + agent: Optional[dict] = None, ): """Create a user using Basic Auth in the (fake) file system. @@ -39,6 +42,12 @@ def mock_basic_auth_user( agent (dict): an agent that represents the user and may be used as authority """ + # Default values for `scopes` and `agent` + if scopes is None: + scopes = [] + if agent is None: + agent = {"mbox": "mailto:jane@ralphlrs.com"} + # Basic HTTP auth credential_bytes = base64.b64encode(f"{username}:{password}".encode("utf-8")) credentials = str(credential_bytes, "utf-8") @@ -71,7 +80,7 @@ def mock_basic_auth_user( # pylint: disable=invalid-name @pytest.fixture -def auth_credentials(fs, user_scopes=None, agent=None): +def basic_auth_credentials(fs, user_scopes=None, agent=None): """Set up the credentials file for request authentication. 
Args: @@ -92,7 +101,6 @@ def auth_credentials(fs, user_scopes=None, agent=None): agent = {"mbox": "mailto:test_ralph@example.com"} credentials = mock_basic_auth_user(fs, username, password, user_scopes, agent) - return credentials @@ -101,10 +109,10 @@ def basic_auth_test_client(): """Return a TestClient with HTTP basic authentication mode.""" # pylint:disable=import-outside-toplevel from ralph.api.auth.basic import ( - get_authenticated_user as get_basic, # pylint:disable=import-outside-toplevel + get_basic_auth_user, # pylint:disable=import-outside-toplevel ) - app.dependency_overrides[get_authenticated_user] = get_basic + app.dependency_overrides[get_authenticated_user] = get_basic_auth_user with TestClient(app) as test_client: yield test_client @@ -122,15 +130,14 @@ def oidc_auth_test_client(monkeypatch): "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_AUDIENCE", AUDIENCE, ) - from ralph.api.auth.oidc import get_authenticated_user as get_oidc + from ralph.api.auth.oidc import get_oidc_user - app.dependency_overrides[get_authenticated_user] = get_oidc + app.dependency_overrides[get_authenticated_user] = get_oidc_user with TestClient(app) as test_client: yield test_client -@pytest.fixture -def mock_discovery_response(): +def _mock_discovery_response(): """Return an example discovery response.""" return { "issuer": "http://providerHost", @@ -219,6 +226,12 @@ def mock_discovery_response(): } +@pytest.fixture +def mock_discovery_response(): + """Return an example discovery response (fixture).""" + return _mock_discovery_response() + + def get_jwk(pub_key): """Return a JWK representation of the public key.""" public_numbers = pub_key.public_numbers() @@ -233,23 +246,27 @@ def get_jwk(pub_key): } -@pytest.fixture -def mock_oidc_jwks(): +def _mock_oidc_jwks(): """Mock OpenID Connect keys.""" return {"keys": [get_jwk(public_key)]} @pytest.fixture -def encoded_token(): +def mock_oidc_jwks(): + """Mock OpenID Connect keys (fixture).""" + return _mock_oidc_jwks() + + +def _create_oidc_token(sub, scopes): """Encode token with the private key.""" return jwt.encode( claims={ - "sub": "123|oidc", + "sub": sub, "iss": "https://iss.example.com", "aud": AUDIENCE, "iat": 0, # Issued the 1/1/1970 "exp": 9999999999, # Expiring in 11/20/2286 - "scope": "all statements/read", + "scope": " ".join(scopes), }, key=private_key.private_bytes( serialization.Encoding.PEM, @@ -261,3 +278,39 @@ def encoded_token(): "kid": PUBLIC_KEY_ID, }, ) + + +def mock_oidc_user(sub="123|oidc", scopes=None): + """Instantiate mock oidc user and return auth token.""" + # Default value for scope + if scopes is None: + scopes = ["all", "statements/read"] + + # Clear LRU cache + discover_provider.cache_clear() + get_public_keys.cache_clear() + + # Mock request to get provider configuration + responses.add( + responses.GET, + f"{ISSUER_URI}/.well-known/openid-configuration", + json=_mock_discovery_response(), + status=200, + ) + + # Mock request to get keys + responses.add( + responses.GET, + _mock_discovery_response()["jwks_uri"], + json=_mock_oidc_jwks(), + status=200, + ) + + oidc_token = _create_oidc_token(sub=sub, scopes=scopes) + return oidc_token + + +@pytest.fixture +def encoded_token(): + """Encode token with the private key (fixture).""" + return _create_oidc_token(sub="123|oidc", scopes=["all", "statements/read"]) diff --git a/tests/test_cli.py b/tests/test_cli.py index 6a303884c..890576027 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -212,7 +212,7 @@ def _assert_matching_basic_auth_credentials( assert "hash" in 
credentials
     if hash_:
         assert credentials["hash"] == hash_
-    assert credentials["scopes"] == scopes
+    assert sorted(credentials["scopes"]) == sorted(scopes)
 
     assert "agent" in credentials
     if agent_name is not None:
diff --git a/tests/test_conf.py b/tests/test_conf.py
index 670bf5ba6..e9c681d79 100644
--- a/tests/test_conf.py
+++ b/tests/test_conf.py
@@ -7,6 +7,7 @@
 from ralph import conf
 from ralph.backends.conf import BackendSettings
 from ralph.conf import CommaSeparatedTuple, Settings, settings
+from ralph.exceptions import ConfigurationException
 
 
 def test_conf_settings_field_value_priority(fs, monkeypatch):
@@ -73,3 +74,20 @@ def test_conf_core_settings_should_impact_settings_defaults(monkeypatch):
 
     # Defaults.
     assert str(conf.settings.AUTH_FILE) == "/foo/auth.json"
+
+
+def test_conf_forbidden_scopes_without_authority(monkeypatch):
+    """Test that using RESTRICT_BY_SCOPES without RESTRICT_BY_AUTHORITY raises an
+    error."""
+
+    monkeypatch.setenv("RALPH_LRS_RESTRICT_BY_AUTHORITY", False)
+    monkeypatch.setenv("RALPH_LRS_RESTRICT_BY_SCOPES", True)
+
+    with pytest.raises(
+        ConfigurationException,
+        match=(
+            "LRS_RESTRICT_BY_AUTHORITY must be set to True if using "
+            "LRS_RESTRICT_BY_SCOPES=True"
+        ),
+    ):
+        reload(conf)

From 04c655e6146c4224807c72bad24560e7149239da Mon Sep 17 00:00:00 2001
From: Wilfried BARADAT
Date: Thu, 12 Oct 2023 18:49:24 +0200
Subject: [PATCH 46/65] =?UTF-8?q?=F0=9F=90=9B(backends)=20fix=20limit=20in?=
 =?UTF-8?q?=20`read`=20method=20of=20`async=5Fes`=20backend?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

As previously found in the synchronous Elasticsearch backend, the
`limit` parameter was not correctly taken into account in the `read`
method of the asynchronous Elasticsearch backend. Fixing it by porting
the fix from the synchronous Elasticsearch backend.
---
 src/ralph/backends/data/async_es.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/ralph/backends/data/async_es.py b/src/ralph/backends/data/async_es.py
index f187717ac..20b906dd9 100644
--- a/src/ralph/backends/data/async_es.py
+++ b/src/ralph/backends/data/async_es.py
@@ -165,7 +165,10 @@ async def read(
             kwargs["q"] = query.query_string
 
         count = chunk_size
-        while limit or chunk_size == count:
+        # The first condition covers both `limit=None` (when the backend
+        # query has no `size` parameter) and a positive `limit` value;
+        # with `limit=0` nothing is read.
+        while limit != 0 and chunk_size == count:
             kwargs["size"] = limit if limit and limit < chunk_size else chunk_size
             try:
                 documents = (await self.client.search(**kwargs))["hits"]["hits"]

From 71113355ee5d220879eb32d4ee7c3cbb2621c328 Mon Sep 17 00:00:00 2001
From: Wilfried BARADAT
Date: Mon, 16 Oct 2023 16:51:31 +0200
Subject: [PATCH 47/65] =?UTF-8?q?=E2=9C=A8(api)=20add=20ability=20to=20use?=
 =?UTF-8?q?=20async=20backends=20for=20`runserver`=20command?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Following the recent addition of asynchronous backends such as
`async_es` and `async_mongo`, add the ability to use them with the
`runserver` command.
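For context, here is a rough sketch (illustrative only, not the actual
Ralph implementation) of the mapping this enables between the values of
the `RALPH_RUNSERVER_BACKEND` setting and the LRS backend classes. The
class names and the environment variable are taken from the diff below,
while the `resolve_backend` helper is hypothetical:

    # Illustrative sketch: pick an LRS backend class from the
    # RALPH_RUNSERVER_BACKEND environment variable. The classes are the
    # ones imported by the updated test below; `resolve_backend` is
    # made up for this example.
    import os

    from ralph.backends.lrs.async_es import AsyncESLRSBackend
    from ralph.backends.lrs.async_mongo import AsyncMongoLRSBackend
    from ralph.backends.lrs.clickhouse import ClickHouseLRSBackend
    from ralph.backends.lrs.es import ESLRSBackend
    from ralph.backends.lrs.mongo import MongoLRSBackend

    BACKENDS = {
        "async_es": AsyncESLRSBackend,
        "async_mongo": AsyncMongoLRSBackend,
        "clickhouse": ClickHouseLRSBackend,
        "es": ESLRSBackend,
        "mongo": MongoLRSBackend,
    }

    def resolve_backend():
        """Return the LRS backend class selected by the environment."""
        return BACKENDS[os.environ.get("RALPH_RUNSERVER_BACKEND", "es")]

With such a mapping, running e.g. `RALPH_RUNSERVER_BACKEND=async_es
ralph runserver` should serve the API on top of the asynchronous
Elasticsearch backend, which is what the updated test below asserts
through `reload(conf)` and `reload(statements)`.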
---
 src/ralph/conf.py            |  4 +++-
 tests/api/test_statements.py | 26 +++++++++++++++++++-------
 2 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/src/ralph/conf.py b/src/ralph/conf.py
index ad91785ae..1a149577d 100644
--- a/src/ralph/conf.py
+++ b/src/ralph/conf.py
@@ -190,7 +190,9 @@ class AuthBackends(Enum):
     RUNSERVER_AUTH_BACKEND: AuthBackends = AuthBackends.BASIC
     RUNSERVER_AUTH_OIDC_AUDIENCE: str = None
     RUNSERVER_AUTH_OIDC_ISSUER_URI: AnyHttpUrl = None
-    RUNSERVER_BACKEND: Literal["clickhouse", "es", "mongo"] = "es"
+    RUNSERVER_BACKEND: Literal[
+        "async_es", "async_mongo", "clickhouse", "es", "mongo"
+    ] = "es"
     RUNSERVER_HOST: str = "0.0.0.0"  # nosec
     RUNSERVER_MAX_SEARCH_HITS_COUNT: int = 100
     RUNSERVER_POINT_IN_TIME_KEEP_ALIVE: str = "1m"
diff --git a/tests/api/test_statements.py b/tests/api/test_statements.py
index 0d629356e..1a468875c 100644
--- a/tests/api/test_statements.py
+++ b/tests/api/test_statements.py
@@ -4,9 +4,11 @@
 
 from ralph import conf
 from ralph.api.routers import statements
-from ralph.backends.data.clickhouse import ClickHouseDataBackend
-from ralph.backends.data.es import ESDataBackend
-from ralph.backends.data.mongo import MongoDataBackend
+from ralph.backends.lrs.async_es import AsyncESLRSBackend
+from ralph.backends.lrs.async_mongo import AsyncMongoLRSBackend
+from ralph.backends.lrs.clickhouse import ClickHouseLRSBackend
+from ralph.backends.lrs.es import ESLRSBackend
+from ralph.backends.lrs.mongo import MongoLRSBackend
 
 
 def test_api_statements_backend_instance_with_runserver_backend_env(monkeypatch):
@@ -14,19 +16,29 @@ def test_api_statements_backend_instance_with_runserver_backend_env(monkeypatch)
     instance `BACKEND_CLIENT` should be updated accordingly.
     """
     # Default backend
-    assert isinstance(statements.BACKEND_CLIENT, ESDataBackend)
+    assert isinstance(statements.BACKEND_CLIENT, ESLRSBackend)
 
     # Mongo backend
     monkeypatch.setenv("RALPH_RUNSERVER_BACKEND", "mongo")
     reload(conf)
-    assert isinstance(reload(statements).BACKEND_CLIENT, MongoDataBackend)
+    assert isinstance(reload(statements).BACKEND_CLIENT, MongoLRSBackend)
 
     # Elasticsearch backend
     monkeypatch.setenv("RALPH_RUNSERVER_BACKEND", "es")
     reload(conf)
-    assert isinstance(reload(statements).BACKEND_CLIENT, ESDataBackend)
+    assert isinstance(reload(statements).BACKEND_CLIENT, ESLRSBackend)
 
     # ClickHouse backend
     monkeypatch.setenv("RALPH_RUNSERVER_BACKEND", "clickhouse")
     reload(conf)
-    assert isinstance(reload(statements).BACKEND_CLIENT, ClickHouseDataBackend)
+    assert isinstance(reload(statements).BACKEND_CLIENT, ClickHouseLRSBackend)
+
+    # Async Elasticsearch backend
+    monkeypatch.setenv("RALPH_RUNSERVER_BACKEND", "async_es")
+    reload(conf)
+    assert isinstance(reload(statements).BACKEND_CLIENT, AsyncESLRSBackend)
+
+    # Async Mongo backend
+    monkeypatch.setenv("RALPH_RUNSERVER_BACKEND", "async_mongo")
+    reload(conf)
+    assert isinstance(reload(statements).BACKEND_CLIENT, AsyncMongoLRSBackend)

From 2b199752ecd5a958e8115c6358d12bc0d9f8d397 Mon Sep 17 00:00:00 2001
From: Wilfried BARADAT
Date: Tue, 17 Oct 2023 15:19:49 +0200
Subject: [PATCH 48/65] =?UTF-8?q?=F0=9F=94=A7(project)=20remove=20elasticse?=
 =?UTF-8?q?arch=20volume?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

A volume was added to the elasticsearch docker compose service as a
good practice to speed up tests, but it prevents reproducibility
between test runs and across machines. Removing this volume from the
elasticsearch service.
---
 docker-compose.yml | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/docker-compose.yml b/docker-compose.yml
index 0209de9d2..708179969 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -34,8 +34,6 @@ services:
       xpack.security.enabled: "false"
     ports:
       - "9200:9200"
-    volumes:
-      - esdata:/usr/share/elasticsearch/data
     mem_limit: 2g
     ulimits:
       memlock:
@@ -74,7 +72,3 @@ services:
   # -- tools
   dockerize:
     image: jwilder/dockerize
-
-volumes:
-  esdata:
-    driver: local

From ae2cec60c57d07273a1ad6dcc666e105fb6d2537 Mon Sep 17 00:00:00 2001
From: Wilfried BARADAT
Date: Tue, 17 Oct 2023 15:29:09 +0200
Subject: [PATCH 49/65] =?UTF-8?q?=E2=9C=85(api)=20update=20tests=20to=20co?=
 =?UTF-8?q?ver=20async=20backends?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Update API tests to cover the newly added asynchronous backends. We no
longer use `TestClient` for these tests but `httpx.AsyncClient`
instead, as described in the FastAPI documentation.
---
 src/ralph/backends/http/async_lrs.py  |   2 -
 tests/api/test_forwarding.py          |   9 +-
 tests/api/test_health.py              |  38 +++--
 tests/api/test_statements_get.py      | 169 +++++++++++--------
 tests/api/test_statements_post.py     | 227 +++++++++++++++++---------
 tests/api/test_statements_put.py      | 198 +++++++++++++---------
 tests/backends/http/test_async_lrs.py |  27 ++-
 tests/conftest.py                     |   1 +
 tests/fixtures/api.py                 |  15 ++
 tests/fixtures/backends.py            |  79 +++++----
 10 files changed, 474 insertions(+), 291 deletions(-)
 create mode 100644 tests/fixtures/api.py

diff --git a/src/ralph/backends/http/async_lrs.py b/src/ralph/backends/http/async_lrs.py
index 09c50ac3f..6f58859d6 100644
--- a/src/ralph/backends/http/async_lrs.py
+++ b/src/ralph/backends/http/async_lrs.py
@@ -335,7 +335,6 @@ async def _fetch_statements(self, target, raw_output, query_params: dict):
         while True:
             response = await client.get(target, params=query_params)
             response.raise_for_status()
-
             statements_response = StatementResponse.parse_obj(response.json())
             statements = statements_response.statements
             statements = (
@@ -370,7 +369,6 @@ async def fetch_all_statements(queue):
                 target=target, raw_output=raw_output, query_params=query_params
             ):
                 await queue.put(statement)
-
         # Re-raising exceptions is necessary as create_task fails silently
         except Exception as exception:
             # None signals that the queue is done
diff --git a/tests/api/test_forwarding.py b/tests/api/test_forwarding.py
index ee7cc9e4e..2aeb019e1 100644
--- a/tests/api/test_forwarding.py
+++ b/tests/api/test_forwarding.py
@@ -1,6 +1,5 @@
 """Tests for the xAPI statements forwarding background task."""
 
-import asyncio
 import json
 import logging
 
@@ -139,7 +138,7 @@ def test_api_forwarding_get_active_xapi_forwardings_with_inactive_forwardings(
         is_active=st.just(True),
     )
 )
-def test_api_forwarding_forward_xapi_statements_with_successful_request(
+async def test_api_forwarding_forward_xapi_statements_with_successful_request(
     monkeypatch, caplog, statements, forwarding
 ):
     """Test the forward_xapi_statements function should log the forwarded statements
@@ -164,7 +163,7 @@ async def post_success(*args, **kwargs):  # pylint: disable=unused-argument
 
     caplog.clear()
     with caplog.at_level(logging.DEBUG):
-        asyncio.run(forward_xapi_statements(statements, method="post"))
+        await forward_xapi_statements(statements, method="post")
 
     assert [
         f"Forwarded {len(statements)} statements to {forwarding.url} with success."
@@ -185,7 +184,7 @@ async def post_success(*args, **kwargs): # pylint: disable=unused-argument is_active=st.just(True), ) ) -def test_api_forwarding_forward_xapi_statements_with_unsuccessful_request( +async def test_api_forwarding_forward_xapi_statements_with_unsuccessful_request( monkeypatch, caplog, statements, forwarding ): """Test the forward_xapi_statements function should log the error if the request @@ -211,7 +210,7 @@ async def post_fail(*args, **kwargs): # pylint: disable=unused-argument caplog.clear() with caplog.at_level(logging.ERROR): - asyncio.run(forward_xapi_statements(statements, method="post")) + await forward_xapi_statements(statements, method="post") assert ["Failed to forward xAPI statements. Failure during request."] == [ message diff --git a/tests/api/test_health.py b/tests/api/test_health.py index 9fdfddfee..0832c3bbe 100644 --- a/tests/api/test_health.py +++ b/tests/api/test_health.py @@ -2,56 +2,68 @@ import logging import pytest -from fastapi.testclient import TestClient -from ralph.api import app from ralph.api.routers import health from ralph.backends.data.base import DataBackendStatus from tests.fixtures.backends import ( + get_async_es_test_backend, + get_async_mongo_test_backend, get_clickhouse_test_backend, get_es_test_backend, get_mongo_test_backend, ) -client = TestClient(app) - +@pytest.mark.anyio @pytest.mark.parametrize( "backend", - [get_clickhouse_test_backend, get_es_test_backend, get_mongo_test_backend], + [ + get_async_es_test_backend, + get_async_mongo_test_backend, + get_clickhouse_test_backend, + get_es_test_backend, + get_mongo_test_backend, + ], ) -def test_api_health_lbheartbeat(backend, monkeypatch): +async def test_api_health_lbheartbeat(client, backend, monkeypatch): """Test the load balancer heartbeat healthcheck.""" monkeypatch.setattr(health, "BACKEND_CLIENT", backend()) - response = client.get("/__lbheartbeat__") + response = await client.get("/__lbheartbeat__") assert response.status_code == 200 assert response.json() is None +@pytest.mark.anyio @pytest.mark.parametrize( "backend", - [get_clickhouse_test_backend, get_es_test_backend, get_mongo_test_backend], + [ + get_async_es_test_backend, + get_async_mongo_test_backend, + get_clickhouse_test_backend, + get_es_test_backend, + get_mongo_test_backend, + ], ) -# pylint: disable=unused-argument -def test_api_health_heartbeat(backend, monkeypatch, clickhouse): +async def test_api_health_heartbeat(client, backend, monkeypatch, clickhouse): + # pylint: disable=unused-argument """Test the heartbeat healthcheck.""" monkeypatch.setattr(health, "BACKEND_CLIENT", backend()) - response = client.get("/__heartbeat__") + response = await client.get("/__heartbeat__") logging.warning(response.read()) assert response.status_code == 200 assert response.json() == {"database": "ok"} monkeypatch.setattr(health.BACKEND_CLIENT, "status", lambda: DataBackendStatus.AWAY) - response = client.get("/__heartbeat__") + response = await client.get("/__heartbeat__") assert response.json() == {"database": "away"} assert response.status_code == 500 monkeypatch.setattr( health.BACKEND_CLIENT, "status", lambda: DataBackendStatus.ERROR ) - response = client.get("/__heartbeat__") + response = await client.get("/__heartbeat__") assert response.json() == {"database": "error"} assert response.status_code == 500 diff --git a/tests/api/test_statements_get.py b/tests/api/test_statements_get.py index ec8a24085..856bf3e6b 100644 --- a/tests/api/test_statements_get.py +++ b/tests/api/test_statements_get.py @@ -7,7 +7,6 @@ import 
pytest import responses from elasticsearch.helpers import bulk -from fastapi.testclient import TestClient from ralph.api import app from ralph.api.auth import get_authenticated_user @@ -26,6 +25,8 @@ ES_TEST_INDEX, MONGO_TEST_COLLECTION, MONGO_TEST_DATABASE, + get_async_es_test_backend, + get_async_mongo_test_backend, get_clickhouse_test_backend, get_es_test_backend, get_mongo_test_backend, @@ -34,8 +35,6 @@ from ..fixtures.auth import mock_basic_auth_user, mock_oidc_user from ..helpers import mock_activity, mock_agent -client = TestClient(app) - def insert_es_statements(es_client, statements): """Insert a bunch of example statements into Elasticsearch for testing.""" @@ -83,7 +82,7 @@ def insert_clickhouse_statements(statements): assert success == len(statements) -@pytest.fixture(params=["es", "mongo", "clickhouse"]) +@pytest.fixture(params=["async_es", "async_mongo", "es", "mongo", "clickhouse"]) def insert_statements_and_monkeypatch_backend( request, es, mongo, clickhouse, monkeypatch ): @@ -93,6 +92,20 @@ def insert_statements_and_monkeypatch_backend( def _insert_statements_and_monkeypatch_backend(statements): """Inserts statements once into each backend.""" backend_client_class_path = "ralph.api.routers.statements.BACKEND_CLIENT" + if request.param == "async_es": + insert_es_statements(es, statements) + monkeypatch.setattr(backend_client_class_path, get_async_es_test_backend()) + return + if request.param == "async_mongo": + insert_mongo_statements(mongo, statements) + monkeypatch.setattr( + backend_client_class_path, get_async_mongo_test_backend() + ) + return + if request.param == "es": + insert_es_statements(es, statements) + monkeypatch.setattr(backend_client_class_path, get_es_test_backend()) + return if request.param == "mongo": insert_mongo_statements(mongo, statements) monkeypatch.setattr(backend_client_class_path, get_mongo_test_backend()) @@ -103,12 +116,11 @@ def _insert_statements_and_monkeypatch_backend(statements): backend_client_class_path, get_clickhouse_test_backend() ) return - insert_es_statements(es, statements) - monkeypatch.setattr(backend_client_class_path, get_es_test_backend()) return _insert_statements_and_monkeypatch_backend +@pytest.mark.anyio @pytest.mark.parametrize( "ifi", [ @@ -119,8 +131,8 @@ def _insert_statements_and_monkeypatch_backend(statements): "account_different_home_page", ], ) -def test_api_statements_get_mine( - monkeypatch, fs, insert_statements_and_monkeypatch_backend, ifi +async def test_api_statements_get_mine( + client, monkeypatch, fs, insert_statements_and_monkeypatch_backend, ifi ): """(Security) Test that the get statements API route, given a "mine=True" query parameter returns a list of statements filtered by authority. 
@@ -173,7 +185,7 @@ def test_api_statements_get_mine( insert_statements_and_monkeypatch_backend(statements) # No restriction on "mine" (implicit) : Return all statements - response = client.get( + response = await client.get( "/xAPI/statements/", headers={"Authorization": f"Basic {credentials_1_bis}"}, ) @@ -181,7 +193,7 @@ def test_api_statements_get_mine( assert response.json() == {"statements": [statements[1], statements[0]]} # No restriction on "mine" (explicit) : Return all statements - response = client.get( + response = await client.get( "/xAPI/statements/?mine=False", headers={"Authorization": f"Basic {credentials_1_bis}"}, ) @@ -189,7 +201,7 @@ def test_api_statements_get_mine( assert response.json() == {"statements": [statements[1], statements[0]]} # Only fetch mine (explicit) : Return filtered statements - response = client.get( + response = await client.get( "/xAPI/statements/?mine=True", headers={"Authorization": f"Basic {credentials_1_bis}"}, ) @@ -201,7 +213,7 @@ def test_api_statements_get_mine( monkeypatch.setattr( "ralph.api.routers.statements.settings.LRS_RESTRICT_BY_AUTHORITY", True ) - response = client.get( + response = await client.get( "/xAPI/statements/", headers={"Authorization": f"Basic {credentials_1_bis}"}, ) @@ -210,7 +222,7 @@ def test_api_statements_get_mine( # Only fetch mine (implicit) with contradictory user request: Return filtered # statements - response = client.get( + response = await client.get( "/xAPI/statements/?mine=False", headers={"Authorization": f"Basic {credentials_1_bis}"}, ) @@ -218,7 +230,7 @@ def test_api_statements_get_mine( assert response.json() == {"statements": [statements[0]]} # Fetch "mine" by id with a single forbidden statement : Return empty list - response = client.get( + response = await client.get( f"/xAPI/statements/?statementId={statements[1]['id']}&mine=True", headers={"Authorization": f"Basic {credentials_1_bis}"}, ) @@ -226,15 +238,16 @@ def test_api_statements_get_mine( assert response.json() == {"statements": []} # Check that invalid parameters returns an error - response = client.get( + response = await client.get( "/xAPI/statements/?mine=BigBoat", headers={"Authorization": f"Basic {credentials_1_bis}"}, ) assert response.status_code == 422 -def test_api_statements_get( - insert_statements_and_monkeypatch_backend, basic_auth_credentials +@pytest.mark.anyio +async def test_api_statements_get( + client, insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route without any filters set up.""" # pylint: disable=redefined-outer-name @@ -253,7 +266,7 @@ def test_api_statements_get( # Confirm that calling this with and without the trailing slash both work for path in ("/xAPI/statements", "/xAPI/statements/"): - response = client.get( + response = await client.get( path, headers={"Authorization": f"Basic {basic_auth_credentials}"} ) @@ -261,8 +274,9 @@ def test_api_statements_get( assert response.json() == {"statements": [statements[1], statements[0]]} -def test_api_statements_get_ascending( - insert_statements_and_monkeypatch_backend, basic_auth_credentials +@pytest.mark.anyio +async def test_api_statements_get_ascending( + client, insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given an "ascending" query parameter, should return statements in ascending order by their timestamp. 
@@ -281,7 +295,7 @@ def test_api_statements_get_ascending( ] insert_statements_and_monkeypatch_backend(statements) - response = client.get( + response = await client.get( "/xAPI/statements/?ascending=true", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -290,8 +304,9 @@ def test_api_statements_get_ascending( assert response.json() == {"statements": [statements[0], statements[1]]} -def test_api_statements_get_by_statement_id( - insert_statements_and_monkeypatch_backend, basic_auth_credentials +@pytest.mark.anyio +async def test_api_statements_get_by_statement_id( + client, insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given a "statementId" query parameter, should return a list of statements matching the given statementId. @@ -310,7 +325,7 @@ def test_api_statements_get_by_statement_id( ] insert_statements_and_monkeypatch_backend(statements) - response = client.get( + response = await client.get( f"/xAPI/statements/?statementId={statements[1]['id']}", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -319,6 +334,7 @@ def test_api_statements_get_by_statement_id( assert response.json() == {"statements": [statements[1]]} +@pytest.mark.anyio @pytest.mark.parametrize( "ifi", [ @@ -329,8 +345,8 @@ def test_api_statements_get_by_statement_id( "account_different_home_page", ], ) -def test_api_statements_get_by_agent( - ifi, insert_statements_and_monkeypatch_backend, basic_auth_credentials +async def test_api_statements_get_by_agent( + client, ifi, insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given an "agent" query parameter, should return a list of statements filtered by the given agent. @@ -364,7 +380,7 @@ def test_api_statements_get_by_agent( ] insert_statements_and_monkeypatch_backend(statements) - response = client.get( + response = await client.get( f"/xAPI/statements/?agent={quote_plus(json.dumps(agent_1))}", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -373,8 +389,9 @@ def test_api_statements_get_by_agent( assert response.json() == {"statements": [statements[0]]} -def test_api_statements_get_by_verb( - insert_statements_and_monkeypatch_backend, basic_auth_credentials +@pytest.mark.anyio +async def test_api_statements_get_by_verb( + client, insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given a "verb" query parameter, should return a list of statements filtered by the given verb id. @@ -395,7 +412,7 @@ def test_api_statements_get_by_verb( ] insert_statements_and_monkeypatch_backend(statements) - response = client.get( + response = await client.get( "/xAPI/statements/?verb=" + quote_plus("http://adlnet.gov/expapi/verbs/played"), headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -404,8 +421,9 @@ def test_api_statements_get_by_verb( assert response.json() == {"statements": [statements[1]]} -def test_api_statements_get_by_activity( - insert_statements_and_monkeypatch_backend, basic_auth_credentials +@pytest.mark.anyio +async def test_api_statements_get_by_activity( + client, insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given an "activity" query parameter, should return a list of statements filtered by the given activity id. 
@@ -429,7 +447,7 @@ def test_api_statements_get_by_activity( ] insert_statements_and_monkeypatch_backend(statements) - response = client.get( + response = await client.get( f"/xAPI/statements/?activity={activity_1['id']}", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -438,7 +456,7 @@ def test_api_statements_get_by_activity( assert response.json() == {"statements": [statements[1]]} # Check that badly formated activity returns an error - response = client.get( + response = await client.get( "/xAPI/statements/?activity=INVALID_IRI", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -447,8 +465,9 @@ def test_api_statements_get_by_activity( assert response.json()["detail"][0]["msg"] == "'INVALID_IRI' is not a valid 'IRI'." -def test_api_statements_get_since_timestamp( - insert_statements_and_monkeypatch_backend, basic_auth_credentials +@pytest.mark.anyio +async def test_api_statements_get_since_timestamp( + client, insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given a "since" query parameter, should return a list of statements filtered by the given timestamp. @@ -468,7 +487,7 @@ def test_api_statements_get_since_timestamp( insert_statements_and_monkeypatch_backend(statements) since = (datetime.now() - timedelta(minutes=30)).isoformat() - response = client.get( + response = await client.get( f"/xAPI/statements/?since={since}", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -477,8 +496,9 @@ def test_api_statements_get_since_timestamp( assert response.json() == {"statements": [statements[1]]} -def test_api_statements_get_until_timestamp( - insert_statements_and_monkeypatch_backend, basic_auth_credentials +@pytest.mark.anyio +async def test_api_statements_get_until_timestamp( + client, insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given an "until" query parameter, should return a list of statements filtered by the given timestamp. @@ -498,7 +518,7 @@ def test_api_statements_get_until_timestamp( insert_statements_and_monkeypatch_backend(statements) until = (datetime.now() - timedelta(minutes=30)).isoformat() - response = client.get( + response = await client.get( f"/xAPI/statements/?until={until}", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -507,8 +527,12 @@ def test_api_statements_get_until_timestamp( assert response.json() == {"statements": [statements[0]]} -def test_api_statements_get_with_pagination( - monkeypatch, insert_statements_and_monkeypatch_backend, basic_auth_credentials +@pytest.mark.anyio +async def test_api_statements_get_with_pagination( + client, + monkeypatch, + insert_statements_and_monkeypatch_backend, + basic_auth_credentials, ): """Test the get statements API route, given a request leading to more results than can fit on the first page, should return a list of statements non-exceeding the page @@ -546,7 +570,7 @@ def test_api_statements_get_with_pagination( # First response gets the first two results, with a "more" entry as # we have more results to return on a later page. - first_response = client.get( + first_response = await client.get( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -558,7 +582,7 @@ def test_api_statements_get_with_pagination( assert all(key in more_query_params for key in ("pit_id", "search_after")) # Second response gets the missing result from the first response. 
- second_response = client.get( + second_response = await client.get( first_response.json()["more"], headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -570,7 +594,7 @@ def test_api_statements_get_with_pagination( assert all(key in more_query_params for key in ("pit_id", "search_after")) # Third response gets the missing result from the first response - third_response = client.get( + third_response = await client.get( second_response.json()["more"], headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -578,8 +602,12 @@ def test_api_statements_get_with_pagination( assert third_response.json() == {"statements": [statements[0]]} -def test_api_statements_get_with_pagination_and_query( - monkeypatch, insert_statements_and_monkeypatch_backend, basic_auth_credentials +@pytest.mark.anyio +async def test_api_statements_get_with_pagination_and_query( + client, + monkeypatch, + insert_statements_and_monkeypatch_backend, + basic_auth_credentials, ): """Test the get statements API route, given a request with a query parameter leading to more results than can fit on the first page, should return a list @@ -622,7 +650,7 @@ def test_api_statements_get_with_pagination_and_query( # First response gets the first two results, with a "more" entry as # we have more results to return on a later page. - first_response = client.get( + first_response = await client.get( "/xAPI/statements/?verb=" + quote_plus("https://w3id.org/xapi/video/verbs/played"), headers={"Authorization": f"Basic {basic_auth_credentials}"}, @@ -635,7 +663,7 @@ def test_api_statements_get_with_pagination_and_query( assert all(key in more_query_params for key in ("verb", "pit_id", "search_after")) # Second response gets the missing result from the first response. - second_response = client.get( + second_response = await client.get( first_response.json()["more"], headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -643,8 +671,9 @@ def test_api_statements_get_with_pagination_and_query( assert second_response.json() == {"statements": [statements[0]]} -def test_api_statements_get_with_no_matching_statement( - insert_statements_and_monkeypatch_backend, basic_auth_credentials +@pytest.mark.anyio +async def test_api_statements_get_with_no_matching_statement( + client, insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given a query yielding no matching statement, should return an empty list. @@ -663,7 +692,7 @@ def test_api_statements_get_with_no_matching_statement( ] insert_statements_and_monkeypatch_backend(statements) - response = client.get( + response = await client.get( "/xAPI/statements/?statementId=66c81e98-1763-4730-8cfc-f5ab34f1bad5", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -672,8 +701,9 @@ def test_api_statements_get_with_no_matching_statement( assert response.json() == {"statements": []} -def test_api_statements_get_with_database_query_failure( - basic_auth_credentials, monkeypatch +@pytest.mark.anyio +async def test_api_statements_get_with_database_query_failure( + client, basic_auth_credentials, monkeypatch ): """Test the get statements API route, given a query raising a BackendException, should return an error response with HTTP code 500. 
@@ -689,7 +719,7 @@ def mock_query_statements(*_): mock_query_statements, ) - response = client.get( + response = await client.get( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -697,15 +727,18 @@ def mock_query_statements(*_): assert response.json() == {"detail": "xAPI statements query failed"} +@pytest.mark.anyio @pytest.mark.parametrize("id_param", ["statementId", "voidedStatementId"]) -def test_api_statements_get_invalid_query_parameters(basic_auth_credentials, id_param): +async def test_api_statements_get_invalid_query_parameters( + client, basic_auth_credentials, id_param +): """Test error response for invalid query parameters""" id_1 = "be67b160-d958-4f51-b8b8-1892002dbac6" id_2 = "66c81e98-1763-4730-8cfc-f5ab34f1bad5" # Check for 400 status code when unknown parameters are provided - response = client.get( + response = await client.get( "/xAPI/statements/?mamamia=herewegoagain", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -715,7 +748,7 @@ def test_api_statements_get_invalid_query_parameters(basic_auth_credentials, id_ } # Check for 400 status code when both statementId and voidedStatementId are provided - response = client.get( + response = await client.get( f"/xAPI/statements/?statementId={id_1}&voidedStatementId={id_2}", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -727,7 +760,7 @@ def test_api_statements_get_invalid_query_parameters(basic_auth_credentials, id_ ("agent", json.dumps(mock_agent("mbox", 1))), ("verb", "verb_1"), ]: - response = client.get( + response = await client.get( f"/xAPI/statements/?{id_param}={id_1}&{invalid_param}={value}", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -741,13 +774,14 @@ def test_api_statements_get_invalid_query_parameters(basic_auth_credentials, id_ # Check for NO 400 status code when statementId is passed with authorized parameters for valid_param, value in [("format", "ids"), ("attachments", "true")]: - response = client.get( + response = await client.get( f"/xAPI/statements/?{id_param}={id_1}&{valid_param}={value}", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code != 400 +@pytest.mark.anyio @responses.activate @pytest.mark.parametrize("auth_method", ["basic", "oidc"]) @pytest.mark.parametrize( @@ -764,8 +798,8 @@ def test_api_statements_get_invalid_query_parameters(basic_auth_credentials, id_ ([], False), ], ) -def test_api_statements_get_scopes( - monkeypatch, fs, es, auth_method, scopes, is_authorized +async def test_api_statements_get_scopes( + client, monkeypatch, fs, es, auth_method, scopes, is_authorized ): """Test that getting statements behaves properly according to user scopes.""" # pylint: disable=invalid-name,too-many-locals,too-many-arguments @@ -821,7 +855,7 @@ def test_api_statements_get_scopes( insert_es_statements(es, statements) monkeypatch.setattr(backend_client_class_path, get_es_test_backend()) - response = client.get( + response = await client.get( "/xAPI/statements/", headers=headers, ) @@ -838,6 +872,7 @@ def test_api_statements_get_scopes( app.dependency_overrides.pop(get_authenticated_user, None) +@pytest.mark.anyio @pytest.mark.parametrize( "scopes,read_all_access", [ @@ -847,15 +882,15 @@ def test_api_statements_get_scopes( (["statements/read/mine"], False), ], ) -def test_api_statements_get_scopes_with_authority( - monkeypatch, fs, es, scopes, read_all_access +async def test_api_statements_get_scopes_with_authority( + client, monkeypatch, fs, es, scopes, 
read_all_access ): """Test that restricting by scope and by authority behaves properly. Getting statements should be restricted to mine for users which only have `statements/read/mine` scope but should not be restricted when the user has wider scopes. """ - # pylint: disable=invalid-name + # pylint: disable=invalid-name,too-many-arguments monkeypatch.setattr( "ralph.api.routers.statements.settings.LRS_RESTRICT_BY_AUTHORITY", True ) @@ -893,7 +928,7 @@ def test_api_statements_get_scopes_with_authority( insert_es_statements(es, statements) monkeypatch.setattr(backend_client_class_path, get_es_test_backend()) - response = client.get( + response = await client.get( "/xAPI/statements/", headers=headers, ) diff --git a/tests/api/test_statements_post.py b/tests/api/test_statements_post.py index fe3e63691..7b7858042 100644 --- a/tests/api/test_statements_post.py +++ b/tests/api/test_statements_post.py @@ -5,7 +5,6 @@ import pytest import responses -from fastapi.testclient import TestClient from httpx import AsyncClient from ralph.api import app @@ -26,6 +25,8 @@ MONGO_TEST_FORWARDING_COLLECTION, RUNSERVER_TEST_HOST, RUNSERVER_TEST_PORT, + get_async_es_test_backend, + get_async_mongo_test_backend, get_clickhouse_test_backend, get_es_test_backend, get_mongo_test_backend, @@ -39,16 +40,15 @@ string_is_uuid, ) -client = TestClient(app) - -def test_api_statements_post_invalid_parameters(basic_auth_credentials): +@pytest.mark.anyio +async def test_api_statements_post_invalid_parameters(client, basic_auth_credentials): """Test that using invalid parameters returns the proper status code.""" statement = mock_statement() # Check for 400 status code when unknown parameters are provided - response = client.post( + response = await client.post( "/xAPI/statements/?mamamia=herewegoagain", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, @@ -59,13 +59,20 @@ def test_api_statements_post_invalid_parameters(basic_auth_credentials): } +@pytest.mark.anyio @pytest.mark.parametrize( "backend", - [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], + [ + get_async_es_test_backend, + get_async_mongo_test_backend, + get_es_test_backend, + get_clickhouse_test_backend, + get_mongo_test_backend, + ], ) # pylint: disable=too-many-arguments -def test_api_statements_post_single_statement_directly( - backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse +async def test_api_statements_post_single_statement_directly( + client, backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the post statements API route with one statement.""" # pylint: disable=invalid-name,unused-argument @@ -73,7 +80,7 @@ def test_api_statements_post_single_statement_directly( monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) statement = mock_statement() - response = client.post( + response = await client.post( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, @@ -84,7 +91,7 @@ def test_api_statements_post_single_statement_directly( es.indices.refresh() - response = client.get( + response = await client.get( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -94,12 +101,12 @@ def test_api_statements_post_single_statement_directly( ) -# pylint: disable=too-many-arguments -def test_api_statements_post_enriching_without_existing_values( - monkeypatch, basic_auth_credentials, es +@pytest.mark.anyio +async def 
test_api_statements_post_enriching_without_existing_values( + client, monkeypatch, basic_auth_credentials, es ): """Test that statements are properly enriched when statement provides no values.""" - # pylint: disable=invalid-name,unused-argument + # pylint: disable=invalid-name monkeypatch.setattr( "ralph.api.routers.statements.BACKEND_CLIENT", get_es_test_backend() @@ -116,7 +123,7 @@ def test_api_statements_post_enriching_without_existing_values( "verb": {"id": "https://example.com/verb-id/1/"}, } - response = client.post( + response = await client.post( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, @@ -126,7 +133,7 @@ def test_api_statements_post_enriching_without_existing_values( es.indices.refresh() - response = client.get( + response = await client.get( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -150,6 +157,7 @@ def test_api_statements_post_enriching_without_existing_values( assert statement["authority"] == {"mbox": "mailto:test_ralph@example.com"} +@pytest.mark.anyio @pytest.mark.parametrize( "field,value,status", [ @@ -159,12 +167,11 @@ def test_api_statements_post_enriching_without_existing_values( ("authority", {"mbox": "mailto:test_ralph@example.com"}, 200), ], ) -# pylint: disable=too-many-arguments -def test_api_statements_post_enriching_with_existing_values( - field, value, status, monkeypatch, basic_auth_credentials, es +async def test_api_statements_post_enriching_with_existing_values( + client, field, value, status, monkeypatch, basic_auth_credentials, es ): """Test that statements are properly enriched when values are provided.""" - # pylint: disable=invalid-name,unused-argument + # pylint: disable=invalid-name,unused-argument,too-many-arguments monkeypatch.setattr( "ralph.api.routers.statements.BACKEND_CLIENT", get_es_test_backend() @@ -174,7 +181,7 @@ def test_api_statements_post_enriching_with_existing_values( # Add the field to be tested statement[field] = value - response = client.post( + response = await client.post( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, @@ -185,7 +192,7 @@ def test_api_statements_post_enriching_with_existing_values( # Check that values match when they should if status == 200: es.indices.refresh() - response = client.get( + response = await client.get( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -201,21 +208,27 @@ def test_api_statements_post_enriching_with_existing_values( assert statement[field] == value +@pytest.mark.anyio @pytest.mark.parametrize( "backend", - [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], + [ + get_async_es_test_backend, + get_async_mongo_test_backend, + get_es_test_backend, + get_clickhouse_test_backend, + get_mongo_test_backend, + ], ) -# pylint: disable=too-many-arguments -def test_api_statements_post_single_statement_no_trailing_slash( - backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse +async def test_api_statements_post_single_statement_no_trailing_slash( + client, backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test that the statements endpoint also works without the trailing slash.""" - # pylint: disable=invalid-name,unused-argument + # pylint: disable=invalid-name,unused-argument,too-many-arguments monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) statement = mock_statement() - response = client.post( + response = await 
client.post( "/xAPI/statements", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, @@ -225,21 +238,27 @@ def test_api_statements_post_single_statement_no_trailing_slash( assert response.json() == [statement["id"]] +@pytest.mark.anyio @pytest.mark.parametrize( "backend", - [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], + [ + get_async_es_test_backend, + get_async_mongo_test_backend, + get_es_test_backend, + get_clickhouse_test_backend, + get_mongo_test_backend, + ], ) -# pylint: disable=too-many-arguments -def test_api_statements_post_list_of_one( - backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse +async def test_api_statements_post_list_of_one( + client, backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the post statements API route with one statement in a list.""" - # pylint: disable=invalid-name,unused-argument + # pylint: disable=invalid-name,unused-argument,too-many-arguments monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) statement = mock_statement() - response = client.post( + response = await client.post( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=[statement], @@ -249,7 +268,7 @@ def test_api_statements_post_list_of_one( assert response.json() == [statement["id"]] es.indices.refresh() - response = client.get( + response = await client.get( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -259,16 +278,22 @@ def test_api_statements_post_list_of_one( ) +@pytest.mark.anyio @pytest.mark.parametrize( "backend", - [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], + [ + get_async_es_test_backend, + get_async_mongo_test_backend, + get_es_test_backend, + get_clickhouse_test_backend, + get_mongo_test_backend, + ], ) -# pylint: disable=too-many-arguments -def test_api_statements_post_list( - backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse +async def test_api_statements_post_list( + client, backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the post statements API route with two statements in a list.""" - # pylint: disable=invalid-name,unused-argument + # pylint: disable=invalid-name,unused-argument,too-many-arguments monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) @@ -280,7 +305,7 @@ def test_api_statements_post_list( statements = [statement_1, statement_2] - response = client.post( + response = await client.post( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statements, @@ -293,7 +318,7 @@ def test_api_statements_post_list( assert regex.match(generated_id) es.indices.refresh() - get_response = client.get( + get_response = await client.get( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -307,25 +332,33 @@ def test_api_statements_post_list( ) +@pytest.mark.anyio @pytest.mark.parametrize( "backend", [ + get_async_es_test_backend, + get_async_mongo_test_backend, get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend, ], ) -# pylint: disable=too-many-arguments -def test_api_statements_post_list_with_duplicates( - backend, monkeypatch, basic_auth_credentials, es_data_stream, mongo, clickhouse +async def test_api_statements_post_list_with_duplicates( + client, + backend, + monkeypatch, + basic_auth_credentials, + es_data_stream, + mongo, + clickhouse, ): """Test the post statements 
API route with duplicate statement IDs should fail.""" - # pylint: disable=invalid-name,unused-argument + # pylint: disable=invalid-name,unused-argument,too-many-arguments monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) statement = mock_statement() - response = client.post( + response = await client.post( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=[statement, statement], @@ -337,7 +370,7 @@ def test_api_statements_post_list_with_duplicates( } # The failure should imply no statement insertion. es_data_stream.indices.refresh() - response = client.get( + response = await client.get( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -345,18 +378,24 @@ def test_api_statements_post_list_with_duplicates( assert response.json() == {"statements": []} +@pytest.mark.anyio @pytest.mark.parametrize( "backend", - [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], + [ + get_async_es_test_backend, + get_async_mongo_test_backend, + get_es_test_backend, + get_clickhouse_test_backend, + get_mongo_test_backend, + ], ) -# pylint: disable=too-many-arguments -def test_api_statements_post_list_with_duplicate_of_existing_statement( - backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse +async def test_api_statements_post_list_with_duplicate_of_existing_statement( + client, backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the post statements API route, given a statement that already exist in the database (has the same ID), should fail. """ - # pylint: disable=invalid-name,unused-argument + # pylint: disable=invalid-name,unused-argument,too-many-arguments monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) @@ -364,7 +403,7 @@ def test_api_statements_post_list_with_duplicate_of_existing_statement( statement = mock_statement(id_=statement_uuid) # Post the statement once. - response = client.post( + response = await client.post( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, @@ -376,7 +415,7 @@ def test_api_statements_post_list_with_duplicate_of_existing_statement( # Post the statement twice, the data is identical so it should succeed but not # include the ID in the response as it wasn't inserted. - response = client.post( + response = await client.post( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, @@ -386,7 +425,7 @@ def test_api_statements_post_list_with_duplicate_of_existing_statement( es.indices.refresh() # Post the statement again, trying to change the timestamp which is not allowed. 
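
# (Illustrative aside, not part of the diff.) The immutability rule the next
# request exercises: reposting an existing statement ID is only accepted when
# the new body is equivalent to the stored one. A sketch of the server-side
# check, assuming a hypothetical `statements_are_equivalent` helper:
#
#     existing = backend.query_statements_by_ids([statement["id"]])
#     if existing and not statements_are_equivalent(existing[0], statement):
#         raise HTTPException(
#             status_code=409,
#             detail="Differing statements already exist with the same ID: "
#             f"{statement['id']}",
#         )
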
- response = client.post( + response = await client.post( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=[dict(statement, **{"timestamp": "2023-03-15T14:07:51Z"})], @@ -398,7 +437,7 @@ def test_api_statements_post_list_with_duplicate_of_existing_statement( f"{statement_uuid}" } - response = client.get( + response = await client.get( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -408,17 +447,24 @@ def test_api_statements_post_list_with_duplicate_of_existing_statement( ) +@pytest.mark.anyio @pytest.mark.parametrize( "backend", - [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], + [ + get_async_es_test_backend, + get_async_mongo_test_backend, + get_es_test_backend, + get_clickhouse_test_backend, + get_mongo_test_backend, + ], ) -def test_api_statements_post_with_failure_during_storage( - backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse +async def test_api_statements_post_with_failure_during_storage( + client, backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the post statements API route with a failure happening during storage.""" - # pylint: disable=invalid-name,unused-argument, too-many-arguments + # pylint: disable=invalid-name,unused-argument,too-many-arguments - def write_mock(*args, **kwargs): + async def write_mock(*args, **kwargs): """Raises an exception. Mocks the database.write method.""" raise BackendException() @@ -427,7 +473,7 @@ def write_mock(*args, **kwargs): monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend_instance) statement = mock_statement() - response = client.post( + response = await client.post( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, @@ -437,12 +483,19 @@ def write_mock(*args, **kwargs): assert response.json() == {"detail": "Statements bulk indexation failed"} +@pytest.mark.anyio @pytest.mark.parametrize( "backend", - [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], + [ + get_async_es_test_backend, + get_async_mongo_test_backend, + get_es_test_backend, + get_clickhouse_test_backend, + get_mongo_test_backend, + ], ) -def test_api_statements_post_with_failure_during_id_query( - backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse +async def test_api_statements_post_with_failure_during_id_query( + client, backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the post statements API route with a failure during query execution.""" # pylint: disable=invalid-name,unused-argument,too-many-arguments @@ -458,7 +511,7 @@ def query_statements_by_ids_mock(*args, **kwargs): monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend_instance) statement = mock_statement() - response = client.post( + response = await client.post( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, @@ -468,22 +521,28 @@ def query_statements_by_ids_mock(*args, **kwargs): assert response.json() == {"detail": "xAPI statements query failed"} +@pytest.mark.anyio @pytest.mark.parametrize( "backend", - [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], + [ + get_async_es_test_backend, + get_async_mongo_test_backend, + get_es_test_backend, + get_clickhouse_test_backend, + get_mongo_test_backend, + ], ) -# pylint: disable=too-many-arguments -def test_api_statements_post_list_without_forwarding( - backend, 
basic_auth_credentials, monkeypatch, es, mongo, clickhouse +async def test_api_statements_post_list_without_forwarding( + client, backend, basic_auth_credentials, monkeypatch, es, mongo, clickhouse ): """Test the post statements API route, given an empty forwarding configuration, should not start the forwarding background task. """ - # pylint: disable=invalid-name,unused-argument + # pylint: disable=invalid-name,unused-argument,too-many-arguments spy = {} - def spy_mock_forward_xapi_statements(_): + async def spy_mock_forward_xapi_statements(_): """Mock the forward_xapi_statements; spies over whether it has been called.""" spy["error"] = "forward_xapi_statements should not have been called!" @@ -498,7 +557,7 @@ def spy_mock_forward_xapi_statements(_): statement = mock_statement() - response = client.post( + response = await client.post( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, @@ -508,9 +567,13 @@ def spy_mock_forward_xapi_statements(_): assert "error" not in spy -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize( - "receiving_backend", [get_es_test_backend, get_mongo_test_backend] + "receiving_backend", + [ + get_es_test_backend, + get_mongo_test_backend, + ], ) @pytest.mark.parametrize( "forwarding_backend", @@ -555,10 +618,11 @@ async def test_api_statements_post_list_with_forwarding( receiving_patch.setattr( "ralph.api.forwarding.get_active_xapi_forwardings", lambda: [] ) - # Receiving client should use the receiving Elasticsearch client for storage + # Receiving client should use the receiving backend for storage receiving_patch.setattr( "ralph.api.routers.statements.BACKEND_CLIENT", receiving_backend() ) + lrs_context = lrs(app) # Start receiving LRS client await lrs_context.__aenter__() # pylint: disable=unnecessary-dunder-call @@ -629,6 +693,7 @@ async def test_api_statements_post_list_with_forwarding( await lrs_context.__aexit__(None, None, None) +@pytest.mark.anyio @responses.activate @pytest.mark.parametrize("auth_method", ["basic", "oidc"]) @pytest.mark.parametrize( @@ -642,11 +707,11 @@ async def test_api_statements_post_list_with_forwarding( ([], False), ], ) -def test_api_statements_post_scopes( - monkeypatch, fs, es, auth_method, scopes, is_authorized +async def test_api_statements_post_scopes( + client, monkeypatch, fs, es, auth_method, scopes, is_authorized ): """Test that posting statements behaves properly according to user scopes.""" - # pylint: disable=invalid-name,unused-argument + # pylint: disable=invalid-name,unused-argument,too-many-arguments monkeypatch.setattr( "ralph.api.routers.statements.settings.LRS_RESTRICT_BY_SCOPES", True ) @@ -683,7 +748,7 @@ def test_api_statements_post_scopes( backend_client_class_path = "ralph.api.routers.statements.BACKEND_CLIENT" monkeypatch.setattr(backend_client_class_path, get_es_test_backend()) - response = client.post( + response = await client.post( "/xAPI/statements/", headers=headers, json=statement, diff --git a/tests/api/test_statements_put.py b/tests/api/test_statements_put.py index ae30b2b73..5512987b6 100644 --- a/tests/api/test_statements_put.py +++ b/tests/api/test_statements_put.py @@ -1,13 +1,10 @@ """Tests for the PUT statements endpoint of the Ralph API.""" -from importlib import reload from uuid import uuid4 import pytest import responses -from fastapi.testclient import TestClient from httpx import AsyncClient -from ralph import api from ralph.api import app from ralph.api.auth import get_authenticated_user from ralph.api.auth.basic import 
get_basic_auth_user @@ -26,6 +23,8 @@ MONGO_TEST_FORWARDING_COLLECTION, RUNSERVER_TEST_HOST, RUNSERVER_TEST_PORT, + get_async_es_test_backend, + get_async_mongo_test_backend, get_clickhouse_test_backend, get_es_test_backend, get_mongo_test_backend, @@ -38,16 +37,14 @@ string_is_date, ) -reload(api) -client = TestClient(app) - -def test_api_statements_put_invalid_parameters(basic_auth_credentials): +@pytest.mark.anyio +async def test_api_statements_put_invalid_parameters(client, basic_auth_credentials): """Test that using invalid parameters returns the proper status code.""" statement = mock_statement() # Check for 400 status code when unknown parameters are provided - response = client.put( + response = await client.put( "/xAPI/statements/?mamamia=herewegoagain", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, @@ -60,19 +57,25 @@ def test_api_statements_put_invalid_parameters(basic_auth_credentials): @pytest.mark.parametrize( "backend", - [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], + [ + get_async_es_test_backend, + get_async_mongo_test_backend, + get_es_test_backend, + get_clickhouse_test_backend, + get_mongo_test_backend, + ], ) -# pylint: disable=too-many-arguments -def test_api_statements_put_single_statement_directly( - backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse +@pytest.mark.anyio +async def test_api_statements_put_single_statement_directly( + client, backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the put statements API route with one statement.""" - # pylint: disable=invalid-name,unused-argument + # pylint: disable=invalid-name,unused-argument,too-many-arguments monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) statement = mock_statement() - response = client.put( + response = await client.put( f"/xAPI/statements/?statementId={statement['id']}", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, @@ -82,7 +85,7 @@ def test_api_statements_put_single_statement_directly( es.indices.refresh() - response = client.get( + response = await client.get( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -92,9 +95,9 @@ def test_api_statements_put_single_statement_directly( ) -# pylint: disable=too-many-arguments -def test_api_statements_put_enriching_without_existing_values( - monkeypatch, basic_auth_credentials, es +@pytest.mark.anyio +async def test_api_statements_put_enriching_without_existing_values( + client, monkeypatch, basic_auth_credentials, es ): """Test that statements are properly enriched when statement provides no values.""" # pylint: disable=invalid-name,unused-argument @@ -104,7 +107,7 @@ def test_api_statements_put_enriching_without_existing_values( ) statement = mock_statement() - response = client.put( + response = await client.put( f"/xAPI/statements/?statementId={statement['id']}", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, @@ -113,7 +116,7 @@ def test_api_statements_put_enriching_without_existing_values( es.indices.refresh() - response = client.get( + response = await client.get( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -137,6 +140,7 @@ def test_api_statements_put_enriching_without_existing_values( assert statement["authority"] == {"mbox": "mailto:test_ralph@example.com"} +@pytest.mark.anyio @pytest.mark.parametrize( "field,value,status", [ @@ -145,12 +149,11 @@ def 
test_api_statements_put_enriching_without_existing_values( ("authority", {"mbox": "mailto:test_ralph@example.com"}, 204), ], ) -# pylint: disable=too-many-arguments -def test_api_statements_put_enriching_with_existing_values( - field, value, status, monkeypatch, basic_auth_credentials, es +async def test_api_statements_put_enriching_with_existing_values( + client, field, value, status, monkeypatch, basic_auth_credentials, es ): """Test that statements are properly enriched when values are provided.""" - # pylint: disable=invalid-name,unused-argument + # pylint: disable=invalid-name,unused-argument, too-many-arguments monkeypatch.setattr( "ralph.api.routers.statements.BACKEND_CLIENT", get_es_test_backend() @@ -160,7 +163,7 @@ def test_api_statements_put_enriching_with_existing_values( # Add the field to be tested statement[field] = value - response = client.put( + response = await client.put( f"/xAPI/statements/?statementId={statement['id']}", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, @@ -171,7 +174,7 @@ def test_api_statements_put_enriching_with_existing_values( # Check that values match when they should if status == 204: es.indices.refresh() - response = client.get( + response = await client.get( "/xAPI/statements/", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -186,21 +189,27 @@ def test_api_statements_put_enriching_with_existing_values( assert statement[field] == value +@pytest.mark.anyio @pytest.mark.parametrize( "backend", - [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], + [ + get_async_es_test_backend, + get_async_mongo_test_backend, + get_es_test_backend, + get_clickhouse_test_backend, + get_mongo_test_backend, + ], ) -# pylint: disable=too-many-arguments -def test_api_statements_put_single_statement_no_trailing_slash( - backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse +async def test_api_statements_put_single_statement_no_trailing_slash( + client, backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): + # pylint: disable=invalid-name,unused-argument,too-many-arguments """Test that the statements endpoint also works without the trailing slash.""" - # pylint: disable=invalid-name,unused-argument monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) statement = mock_statement() - response = client.put( + response = await client.put( f"/xAPI/statements?statementId={statement['id']}", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, @@ -209,21 +218,27 @@ def test_api_statements_put_single_statement_no_trailing_slash( assert response.status_code == 204 +@pytest.mark.anyio @pytest.mark.parametrize( "backend", - [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], + [ + get_async_es_test_backend, + get_async_mongo_test_backend, + get_es_test_backend, + get_clickhouse_test_backend, + get_mongo_test_backend, + ], ) -# pylint: disable=too-many-arguments -def test_api_statements_put_id_mismatch( - backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse +async def test_api_statements_put_id_mismatch( + client, backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): - # pylint: disable=invalid-name,unused-argument + # pylint: disable=invalid-name,unused-argument,too-many-arguments """Test the put statements API route when the statementId doesn't match.""" monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) statement = 
mock_statement(id_=str(uuid4())) different_statement_id = str(uuid4()) - response = client.put( + response = await client.put( f"/xAPI/statements/?statementId={different_statement_id}", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, @@ -235,20 +250,26 @@ def test_api_statements_put_id_mismatch( } +@pytest.mark.anyio @pytest.mark.parametrize( "backend", - [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], + [ + get_async_es_test_backend, + get_async_mongo_test_backend, + get_es_test_backend, + get_clickhouse_test_backend, + get_mongo_test_backend, + ], ) -# pylint: disable=too-many-arguments -def test_api_statements_put_list_of_one( - backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse +async def test_api_statements_put_list_of_one( + client, backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): - # pylint: disable=invalid-name,unused-argument + # pylint: disable=invalid-name,unused-argument,too-many-arguments """Test that we fail on PUTs with a list, even if it's one statement.""" monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) statement = mock_statement() - response = client.put( + response = await client.put( f"/xAPI/statements/?statementId={statement['id']}", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=[statement], @@ -257,24 +278,30 @@ def test_api_statements_put_list_of_one( assert response.status_code == 422 +@pytest.mark.anyio @pytest.mark.parametrize( "backend", - [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], + [ + get_async_es_test_backend, + get_async_mongo_test_backend, + get_es_test_backend, + get_clickhouse_test_backend, + get_mongo_test_backend, + ], ) -# pylint: disable=too-many-arguments -def test_api_statements_put_duplicate_of_existing_statement( - backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse +async def test_api_statements_put_duplicate_of_existing_statement( + client, backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): + # pylint: disable=invalid-name,unused-argument,too-many-arguments """Test the put statements API route, given a statement that already exist in the database (has the same ID), should fail. """ - # pylint: disable=invalid-name,unused-argument monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend()) statement = mock_statement() # Put the statement once. 
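
# (Illustrative aside, not part of the diff.) The PUT contract covered by the
# tests above and below: the body must be a single statement, never a list (a
# list fails validation with 422), its `id` must match the `statementId`
# query parameter, and replaying an existing ID only succeeds when the body
# is identical. A sketch of the id guard, with illustrative error wording:
#
#     if str(statement_id) != statement["id"]:
#         raise HTTPException(
#             status_code=400, detail="xAPI statement id mismatch"
#         )
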
- response = client.put( + response = await client.put( f"/xAPI/statements/?statementId={statement['id']}", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, @@ -284,7 +311,7 @@ def test_api_statements_put_duplicate_of_existing_statement( es.indices.refresh() # Put the statement twice, trying to change the timestamp, which is not allowed - response = client.put( + response = await client.put( f"/xAPI/statements/?statementId={statement['id']}", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=dict(statement, **{"timestamp": "2023-03-15T14:07:51Z"}), @@ -295,7 +322,7 @@ def test_api_statements_put_duplicate_of_existing_statement( "detail": "A different statement already exists with the same ID" } - response = client.get( + response = await client.get( f"/xAPI/statements/?statementId={statement['id']}", headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) @@ -305,15 +332,22 @@ def test_api_statements_put_duplicate_of_existing_statement( ) +@pytest.mark.anyio @pytest.mark.parametrize( "backend", - [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], + [ + get_async_es_test_backend, + get_async_mongo_test_backend, + get_es_test_backend, + get_clickhouse_test_backend, + get_mongo_test_backend, + ], ) -def test_api_statements_put_with_failure_during_storage( - backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse +async def test_api_statements_put_with_failure_during_storage( + client, backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the put statements API route with a failure happening during storage.""" - # pylint: disable=invalid-name,unused-argument, too-many-arguments + # pylint: disable=invalid-name,unused-argument,too-many-arguments def write_mock(*args, **kwargs): """Raises an exception. 
Mocks the database.write method.""" @@ -324,7 +358,7 @@ def write_mock(*args, **kwargs): monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend_instance) statement = mock_statement() - response = client.put( + response = await client.put( f"/xAPI/statements/?statementId={statement['id']}", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, @@ -334,12 +368,19 @@ def write_mock(*args, **kwargs): assert response.json() == {"detail": "Statement indexation failed"} +@pytest.mark.anyio @pytest.mark.parametrize( "backend", - [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], + [ + get_async_es_test_backend, + get_async_mongo_test_backend, + get_es_test_backend, + get_clickhouse_test_backend, + get_mongo_test_backend, + ], ) -def test_api_statements_put_with_a_failure_during_id_query( - backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse +async def test_api_statements_put_with_a_failure_during_id_query( + client, backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the put statements API route with a failure during query execution.""" # pylint: disable=invalid-name,unused-argument,too-many-arguments @@ -355,7 +396,7 @@ def query_statements_by_ids_mock(*args, **kwargs): monkeypatch.setattr("ralph.api.routers.statements.BACKEND_CLIENT", backend_instance) statement = mock_statement() - response = client.put( + response = await client.put( f"/xAPI/statements/?statementId={statement['id']}", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, @@ -365,18 +406,24 @@ def query_statements_by_ids_mock(*args, **kwargs): assert response.json() == {"detail": "xAPI statements query failed"} +@pytest.mark.anyio @pytest.mark.parametrize( "backend", - [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], + [ + get_async_es_test_backend, + get_async_mongo_test_backend, + get_es_test_backend, + get_clickhouse_test_backend, + get_mongo_test_backend, + ], ) -# pylint: disable=too-many-arguments -def test_api_statements_put_without_forwarding( - backend, basic_auth_credentials, monkeypatch, es, mongo, clickhouse +async def test_api_statements_put_without_forwarding( + client, backend, basic_auth_credentials, monkeypatch, es, mongo, clickhouse ): + # pylint: disable=invalid-name,unused-argument,too-many-arguments """Test the put statements API route, given an empty forwarding configuration, should not start the forwarding background task. 
""" - # pylint: disable=invalid-name,unused-argument spy = {} @@ -395,7 +442,7 @@ def spy_mock_forward_xapi_statements(_): statement = mock_statement() - response = client.put( + response = await client.put( f"/xAPI/statements/?statementId={statement['id']}", headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, @@ -404,9 +451,13 @@ def spy_mock_forward_xapi_statements(_): assert response.status_code == 204 -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize( - "receiving_backend", [get_es_test_backend, get_mongo_test_backend] + "receiving_backend", + [ + get_es_test_backend, + get_mongo_test_backend, + ], ) @pytest.mark.parametrize( "forwarding_backend", @@ -528,6 +579,7 @@ async def test_api_statements_put_with_forwarding( await lrs_context.__aexit__(None, None, None) +@pytest.mark.anyio @responses.activate @pytest.mark.parametrize("auth_method", ["basic", "oidc"]) @pytest.mark.parametrize( @@ -541,11 +593,11 @@ async def test_api_statements_put_with_forwarding( ([], False), ], ) -def test_api_statements_put_scopes( - monkeypatch, fs, es, auth_method, scopes, is_authorized +async def test_api_statements_put_scopes( + client, monkeypatch, fs, es, auth_method, scopes, is_authorized ): """Test that putting statements behaves properly according to user scopes.""" - # pylint: disable=invalid-name,unused-argument,duplicate-code + # pylint: disable=invalid-name,unused-argument,duplicate-code,too-many-arguments monkeypatch.setattr( "ralph.api.routers.statements.settings.LRS_RESTRICT_BY_SCOPES", True ) @@ -582,7 +634,7 @@ def test_api_statements_put_scopes( backend_client_class_path = "ralph.api.routers.statements.BACKEND_CLIENT" monkeypatch.setattr(backend_client_class_path, get_es_test_backend()) - response = client.put( + response = await client.put( f"/xAPI/statements/?statementId={statement['id']}", headers=headers, json=statement, diff --git a/tests/backends/http/test_async_lrs.py b/tests/backends/http/test_async_lrs.py index f89f21aec..9bddefa0c 100644 --- a/tests/backends/http/test_async_lrs.py +++ b/tests/backends/http/test_async_lrs.py @@ -253,20 +253,19 @@ async def test_backends_http_lrs_read_max_statements( json=statements, ) - if (max_statements is None) or (max_statements > chunk_size): - default_params.update(dict(parse_qsl(urlparse(more_target).query))) - httpx_mock.add_response( - url=ParseResult( - scheme=urlparse(base_url).scheme, - netloc=urlparse(base_url).netloc, - path=urlparse(more_target).path, - query=urlencode(default_params).lower(), - params="", - fragment="", - ).geturl(), - method="GET", - json=more_statements, - ) + default_params.update(dict(parse_qsl(urlparse(more_target).query))) + httpx_mock.add_response( + url=ParseResult( + scheme=urlparse(base_url).scheme, + netloc=urlparse(base_url).netloc, + path=urlparse(more_target).path, + query=urlencode(default_params).lower(), + params="", + fragment="", + ).geturl(), + method="GET", + json=more_statements, + ) settings = AsyncLRSHTTPBackend.settings_class( BASE_URL=base_url, USERNAME="user", PASSWORD="pass" diff --git a/tests/conftest.py b/tests/conftest.py index 8165d3458..281917e9e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ from .fixtures import hypothesis_configuration # noqa: F401 from .fixtures import hypothesis_strategies # noqa: F401 +from .fixtures.api import client # noqa: F401 from .fixtures.auth import ( # noqa: F401 basic_auth_credentials, basic_auth_test_client, diff --git a/tests/fixtures/api.py b/tests/fixtures/api.py new file mode 
100644 index 000000000..3d969b86d --- /dev/null +++ b/tests/fixtures/api.py @@ -0,0 +1,15 @@ +"""Test fixtures related to the API.""" + +import pytest +from httpx import AsyncClient + +from ralph.api import app + + +@pytest.mark.anyio +@pytest.fixture(scope="session") +async def client(): + """Return an AsyncClient for the FastAPI app.""" + + async with AsyncClient(app=app, base_url="http://test") as async_client: + yield async_client diff --git a/tests/fixtures/backends.py b/tests/fixtures/backends.py index b8f0bf9d1..46df6812c 100644 --- a/tests/fixtures/backends.py +++ b/tests/fixtures/backends.py @@ -107,16 +107,32 @@ def get_clickhouse_test_backend(): @lru_cache def get_es_test_backend(): - """Returns a ESLRSBackend backend instance using test defaults.""" + """Return a ESLRSBackend backend instance using test defaults.""" settings = ESLRSBackend.settings_class( HOSTS=ES_TEST_HOSTS, DEFAULT_INDEX=ES_TEST_INDEX ) return ESLRSBackend(settings) +@lru_cache +def get_async_es_test_backend(index: str = ES_TEST_INDEX): + """Return an AsyncESLRSBackend backend instance using test defaults.""" + settings = AsyncESLRSBackend.settings_class( + ALLOW_YELLOW_STATUS=False, + CLIENT_OPTIONS={"ca_certs": None, "verify_certs": None}, + DEFAULT_CHUNK_SIZE=500, + DEFAULT_INDEX=index, + HOSTS=ES_TEST_HOSTS, + LOCALE_ENCODING="utf8", + POINT_IN_TIME_KEEP_ALIVE="1m", + REFRESH_AFTER_WRITE=True, + ) + return AsyncESLRSBackend(settings) + + @lru_cache def get_mongo_test_backend(): - """Returns a MongoDatabase backend instance using test defaults.""" + """Return a MongoDatabase backend instance using test defaults.""" settings = MongoLRSBackend.settings_class( CONNECTION_URI=MONGO_TEST_CONNECTION_URI, DEFAULT_DATABASE=MONGO_TEST_DATABASE, @@ -125,6 +141,24 @@ def get_mongo_test_backend(): return MongoLRSBackend(settings) +@lru_cache +def get_async_mongo_test_backend( + connection_uri: str = MONGO_TEST_CONNECTION_URI, + default_collection: str = MONGO_TEST_COLLECTION, + client_options: dict = None, +): + """Return an AsyncMongoDatabase backend instance using test defaults.""" + settings = AsyncMongoLRSBackend.settings_class( + CONNECTION_URI=connection_uri, + DEFAULT_DATABASE=MONGO_TEST_DATABASE, + DEFAULT_COLLECTION=default_collection, + CLIENT_OPTIONS=client_options if client_options else {}, + DEFAULT_CHUNK_SIZE=500, + LOCALE_ENCODING="utf8", + ) + return AsyncMongoLRSBackend(settings) + + def get_es_fixture(host=ES_TEST_HOSTS, index=ES_TEST_INDEX): """Create / delete an Elasticsearch test index and yield an instantiated client.""" client = Elasticsearch(host) @@ -191,7 +225,7 @@ def get_fs_lrs_backend(path: str = "foo"): return get_fs_lrs_backend -@pytest.fixture +@pytest.fixture(scope="session") def anyio_backend(): """Select asyncio backend for pytest anyio.""" return "asyncio" @@ -222,25 +256,11 @@ def get_mongo_data_backend( @pytest.fixture def async_mongo_lrs_backend(): - """Return the `async_get_mongo_lrs_backend` function.""" + """Return the `get_async_mongo_test_backend` function.""" - def async_get_mongo_lrs_backend( - connection_uri: str = MONGO_TEST_CONNECTION_URI, - default_collection: str = MONGO_TEST_COLLECTION, - client_options: dict = None, - ): - """Return an instance of AsyncMongoLRSBackend.""" - settings = AsyncMongoLRSBackend.settings_class( - CONNECTION_URI=connection_uri, - DEFAULT_DATABASE=MONGO_TEST_DATABASE, - DEFAULT_COLLECTION=default_collection, - CLIENT_OPTIONS=client_options if client_options else {}, - DEFAULT_CHUNK_SIZE=500, - LOCALE_ENCODING="utf8", - ) - return 
AsyncMongoLRSBackend(settings)
+    get_async_mongo_test_backend.cache_clear()

-    return async_get_mongo_lrs_backend
+    return get_async_mongo_test_backend


 def get_mongo_fixture(
@@ -488,24 +508,11 @@ def get_async_es_data_backend():

 @pytest.fixture
 def async_es_lrs_backend():
-    """Return the `get_async_es_lrs_backend` function."""
-    # pylint: disable=invalid-name,redefined-outer-name,unused-argument
+    """Return the `get_async_es_test_backend` function."""

-    def get_async_es_lrs_backend(index: str = ES_TEST_INDEX):
-        """Return an instance of AsyncESLRSBackend."""
-        settings = AsyncESLRSBackend.settings_class(
-            ALLOW_YELLOW_STATUS=False,
-            CLIENT_OPTIONS={"ca_certs": None, "verify_certs": None},
-            DEFAULT_CHUNK_SIZE=500,
-            DEFAULT_INDEX=index,
-            HOSTS=ES_TEST_HOSTS,
-            LOCALE_ENCODING="utf8",
-            POINT_IN_TIME_KEEP_ALIVE="1m",
-            REFRESH_AFTER_WRITE=True,
-        )
-        return AsyncESLRSBackend(settings)
+    get_async_es_test_backend.cache_clear()

-    return get_async_es_lrs_backend
+    return get_async_es_test_backend


 @pytest.fixture

From 2e37c9da955850df7fa8cb6635215fe93ec16e3a Mon Sep 17 00:00:00 2001
From: Quitterie Lucas
Date: Tue, 9 May 2023 12:15:36 +0200
Subject: [PATCH 50/65] =?UTF-8?q?=E2=9C=A8(project)=20add=20mypy=20configu?=
 =?UTF-8?q?ration=20for=20type=20checking?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

It has been decided to use type checking in the Ralph project. The
configuration for the local development tooling and for the CI has been
added.
---
 Makefile  |  5 +++++
 setup.cfg | 48 +++++++++++++++++++++++++++++++++++++++++++-----
 2 files changed, 48 insertions(+), 5 deletions(-)

diff --git a/Makefile b/Makefile
index 7db49b6a3..3a47f5eaf 100644
--- a/Makefile
+++ b/Makefile
@@ -222,6 +222,11 @@ lint-pydocstyle: ## lint Python docstrings with pydocstyle
 	@$(COMPOSE_TEST_RUN_APP) pydocstyle
 .PHONY: lint-pydocstyle

+lint-mypy: ## lint back-end python sources with mypy
+	@echo 'lint:mypy started…'
+	@$(COMPOSE_TEST_RUN_APP) mypy
+.PHONY: lint-mypy
+
 logs: ## display app logs (follow mode)
 	@$(COMPOSE) logs -f app
 .PHONY: logs

diff --git a/setup.cfg b/setup.cfg
index 472414dc5..0c7250166 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -83,6 +83,7 @@ dev =
     mkdocs-material==9.4.6
     mkdocstrings[python-legacy]==0.23.0
     moto==4.2.6
+    mypy==1.2.0
     pydocstyle==6.3.0
     pyfakefs==5.3.0
     pylint==3.0.2
@@ -92,6 +93,10 @@ dev =
     pytest-httpx<0.23.0 # pin as Python 3.7 and 3.8 is no longer supported from release 0.23.0
     requests-mock==1.11.0
     responses<0.23.2 # pin until boto3 supports urllib3>=2
+    types-python-dateutil == 2.8.19.14
+    types-python-jose == 3.3.4.8
+    types-requests<2.31.0.7
+    types-cachetools == 5.3.0.6
 ci =
     twine==4.0.2
 lrs =
@@ -134,17 +139,50 @@ exclude =
     node_modules,
     */migrations/*

-[pydocstyle]
-convention = google
-match_dir = ^(?!tests|venv|build|scripts).*
-match = ^(?!(setup)\.(py)$).*\.(py)$
-
 [isort]
 known_ralph=ralph
 sections=FUTURE,STDLIB,THIRDPARTY,RALPH,FIRSTPARTY,LOCALFOLDER
 skip_glob=venv,*/.conda/*
 profile=black

+[pydocstyle]
+convention = google
+match_dir = ^(?!tests|venv|build|scripts).*
+match = ^(?!(setup)\.(py)$).*\.(py)$
+
+[mypy]
+warn_return_any = True
+warn_unused_configs = True
+disallow_untyped_defs = True
+files=src/ralph/**/*.py
+plugins = pydantic.mypy
+
+[mypy-rfc3987.*]
+ignore_missing_imports = True
+
+[mypy-requests_toolbelt.*]
+ignore_missing_imports = True
+
+[mypy-botocore.*]
+ignore_missing_imports = True
+
+[mypy-boto3.*]
+ignore_missing_imports = True
+
+[mypy-clickhouse_connect.*]
+ignore_missing_imports = True
+
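
# (Illustrative aside, not part of the patch.) Each `[mypy-<package>.*]`
# section above and below silences missing-import errors for one third-party
# dependency that ships no type stubs. A rough command-line equivalent for a
# quick local check would be something like:
#
#   mypy --ignore-missing-imports src/ralph
#
# Listing packages one by one in setup.cfg is stricter: a newly added
# untyped dependency still gets flagged until it is explicitly ignored.
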
+[mypy-ovh.*]
+ignore_missing_imports = True
+
+[mypy-swiftclient.service.*]
+ignore_missing_imports = True
+
+[pydantic-mypy]
+init_forbid_extra = True
+init_typed = True
+warn_required_dynamic_aliases = True
+
 [tool:pytest]
 addopts = -v --cov-report term-missing --cov-config=.coveragerc --cov=ralph
 python_files =

From 10c53e6471a226f24c026c2a0d85670728383315 Mon Sep 17 00:00:00 2001
From: Quitterie Lucas
Date: Wed, 24 May 2023 19:45:40 +0200
Subject: [PATCH 51/65] =?UTF-8?q?=F0=9F=9A=A8(project)=20fix=20mypy=20lint?=
 =?UTF-8?q?er=20warnings?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Now that `mypy` is configured, we need to fix all the type checking
errors it detects in Ralph's source code.
---
 src/ralph/api/__init__.py                     |  9 ++--
 src/ralph/api/auth/basic.py                   | 14 ++---
 src/ralph/api/auth/oidc.py                    |  6 +--
 src/ralph/api/forwarding.py                   |  4 +-
 src/ralph/api/routers/health.py               |  4 +-
 src/ralph/api/routers/statements.py           | 22 ++++----
 src/ralph/backends/data/async_es.py           | 10 ++--
 src/ralph/backends/data/async_mongo.py        | 18 +++----
 src/ralph/backends/data/base.py               | 34 ++++++------
 src/ralph/backends/data/clickhouse.py         | 33 ++++++++----
 src/ralph/backends/data/es.py                 | 22 ++++----
 src/ralph/backends/data/fs.py                 | 20 +++----
 src/ralph/backends/data/ldp.py                | 30 +++++------
 src/ralph/backends/data/mongo.py              | 28 +++++-----
 src/ralph/backends/data/s3.py                 | 18 +++----
 src/ralph/backends/data/swift.py              | 18 +++----
 src/ralph/backends/http/async_lrs.py          | 10 ++--
 src/ralph/backends/http/base.py               | 10 ++--
 src/ralph/backends/http/lrs.py                | 10 ++--
 src/ralph/backends/lrs/clickhouse.py          |  6 +--
 src/ralph/backends/lrs/es.py                  |  2 +-
 src/ralph/backends/lrs/fs.py                  | 40 +++++++-------
 src/ralph/backends/lrs/mongo.py               |  2 +-
 src/ralph/backends/stream/base.py             |  2 +-
 src/ralph/backends/stream/ws.py               | 10 ++--
 src/ralph/cli.py                              |  8 +--
 src/ralph/conf.py                             | 21 ++++---
 src/ralph/filters.py                          |  4 +-
 src/ralph/logger.py                           |  2 +-
 src/ralph/models/converter.py                 | 53 +++++++++++++------
 src/ralph/models/edx/base.py                  |  9 ++--
 src/ralph/models/edx/browser.py               | 11 ++--
 src/ralph/models/edx/converters/xapi/base.py  |  3 +-
 .../models/edx/converters/xapi/enrollment.py  |  3 +-
 .../edx/converters/xapi/navigational.py       |  3 +-
 .../models/edx/converters/xapi/server.py      |  4 +-
 src/ralph/models/edx/converters/xapi/video.py | 13 ++---
 .../models/edx/enrollment/fields/contexts.py  |  9 ++--
 .../models/edx/enrollment/fields/events.py    |  9 ++--
 src/ralph/models/edx/enrollment/statements.py | 11 ++--
 .../models/edx/navigational/statements.py     | 19 ++++---
 .../open_response_assessment/fields/events.py | 12 ++---
 .../open_response_assessment/statements.py    | 12 ++---
 .../models/edx/peer_instruction/statements.py | 11 ++--
 .../edx/problem_interaction/fields/events.py  | 11 ++--
 .../edx/problem_interaction/statements.py     | 11 ++--
 src/ralph/models/edx/server.py                | 11 ++--
 .../edx/textbook_interaction/fields/events.py | 11 ++--
 .../edx/textbook_interaction/statements.py    | 11 ++--
 src/ralph/models/edx/video/fields/events.py   | 10 ++--
 src/ralph/models/edx/video/statements.py      | 11 ++--
 src/ralph/models/selector.py                  | 14 ++---
 src/ralph/models/validator.py                 | 16 +++---
 src/ralph/models/xapi/base/agents.py          | 11 ++--
 src/ralph/models/xapi/base/common.py          | 14 ++---
 src/ralph/models/xapi/base/groups.py          | 11 ++--
 src/ralph/models/xapi/base/objects.py         | 11 ++--
 src/ralph/models/xapi/base/results.py         |  4 +-
 src/ralph/models/xapi/base/statements.py      |  6 +--
 .../models/xapi/base/unnested_objects.py      | 16 +++---
 .../activity_types/acrossx_profile.py         | 11 
++-- .../activity_streams_vocabulary.py | 14 +++-- .../xapi/concepts/activity_types/audio.py | 9 ++-- .../concepts/activity_types/scorm_profile.py | 18 +++++-- .../activity_types/tincan_vocabulary.py | 12 +++-- .../xapi/concepts/activity_types/video.py | 9 ++-- .../activity_types/virtual_classroom.py | 10 ++-- .../xapi/concepts/verbs/acrossx_profile.py | 11 ++-- .../verbs/activity_streams_vocabulary.py | 11 ++-- .../xapi/concepts/verbs/adl_vocabulary.py | 11 ++-- .../verbs/navy_common_reference_profile.py | 11 ++-- .../xapi/concepts/verbs/scorm_profile.py | 11 ++-- .../xapi/concepts/verbs/tincan_vocabulary.py | 12 ++--- src/ralph/models/xapi/concepts/verbs/video.py | 11 ++-- .../xapi/concepts/verbs/virtual_classroom.py | 12 ++--- src/ralph/models/xapi/lms/contexts.py | 13 ++--- src/ralph/models/xapi/lms/objects.py | 15 ++++-- src/ralph/models/xapi/video/contexts.py | 16 +++--- src/ralph/models/xapi/video/results.py | 11 ++-- .../models/xapi/virtual_classroom/contexts.py | 17 +++--- src/ralph/parsers.py | 11 ++-- src/ralph/utils.py | 33 ++++++------ 82 files changed, 599 insertions(+), 468 deletions(-) diff --git a/src/ralph/api/__init__.py b/src/ralph/api/__init__.py index 23c3e16f4..2a33df53a 100644 --- a/src/ralph/api/__init__.py +++ b/src/ralph/api/__init__.py @@ -1,5 +1,6 @@ """Main module for Ralph's LRS API.""" from functools import lru_cache +from typing import Any, Dict, List, Union from urllib.parse import urlparse import sentry_sdk @@ -14,12 +15,14 @@ @lru_cache(maxsize=None) -def get_health_check_routes(): +def get_health_check_routes() -> List: """Return the health check routes.""" return [route.path for route in health.router.routes] -def filter_transactions(event, hint): # pylint: disable=unused-argument +def filter_transactions( + event: Dict, hint # pylint: disable=unused-argument +) -> Union[Dict, None]: """Filter transactions for Sentry.""" url = urlparse(event["request"]["url"]) @@ -47,6 +50,6 @@ def filter_transactions(event, hint): # pylint: disable=unused-argument @app.get("/whoami") async def whoami( user: AuthenticatedUser = Depends(get_authenticated_user), -): +) -> Dict[str, Any]: """Return the current user's username along with their scopes.""" return {"agent": user.agent, "scopes": user.scopes} diff --git a/src/ralph/api/auth/basic.py b/src/ralph/api/auth/basic.py index 04dfcce59..027552595 100644 --- a/src/ralph/api/auth/basic.py +++ b/src/ralph/api/auth/basic.py @@ -4,7 +4,7 @@ from functools import lru_cache from pathlib import Path from threading import Lock -from typing import List, Union +from typing import Any, Iterator, List, Optional import bcrypt from cachetools import TTLCache, cached @@ -53,21 +53,21 @@ class ServerUsersCredentials(BaseModel): __root__: List[UserCredentials] - def __add__(self, other): # noqa: D105 + def __add__(self, other) -> Any: # noqa: D105 return ServerUsersCredentials.parse_obj(self.__root__ + other.__root__) - def __getitem__(self, item: int): # noqa: D105 + def __getitem__(self, item: int) -> UserCredentials: # noqa: D105 return self.__root__[item] - def __len__(self): # noqa: D105 + def __len__(self) -> int: # noqa: D105 return len(self.__root__) - def __iter__(self): # noqa: D105 + def __iter__(self) -> Iterator[UserCredentials]: # noqa: D105 return iter(self.__root__) @root_validator @classmethod - def ensure_unique_username(cls, values): + def ensure_unique_username(cls, values: Any) -> Any: """Every username should be unique among registered users.""" usernames = [entry.username for entry in values.get("__root__")] if 
len(usernames) != len(set(usernames)): @@ -111,7 +111,7 @@ def get_stored_credentials(auth_file: Path) -> ServerUsersCredentials: else None, ) def get_basic_auth_user( - credentials: Union[HTTPBasicCredentials, None] = Depends(security), + credentials: Optional[HTTPBasicCredentials] = Depends(security), security_scopes: SecurityScopes = SecurityScopes([]), ) -> AuthenticatedUser: """Checks valid auth parameters. diff --git a/src/ralph/api/auth/oidc.py b/src/ralph/api/auth/oidc.py index d4476f3f0..2a2d107b0 100644 --- a/src/ralph/api/auth/oidc.py +++ b/src/ralph/api/auth/oidc.py @@ -2,7 +2,7 @@ import logging from functools import lru_cache -from typing import Optional +from typing import Dict, Optional import requests from fastapi import Depends, HTTPException, status @@ -54,7 +54,7 @@ class Config: # pylint: disable=missing-class-docstring # noqa: D106 @lru_cache() -def discover_provider(base_url: AnyUrl) -> dict: +def discover_provider(base_url: AnyUrl) -> Dict: """Discover the authentication server (or OpenId Provider) configuration.""" try: response = requests.get(f"{base_url}{OPENID_CONFIGURATION_PATH}", timeout=5) @@ -72,7 +72,7 @@ def discover_provider(base_url: AnyUrl) -> dict: @lru_cache() -def get_public_keys(jwks_uri: AnyUrl) -> dict: +def get_public_keys(jwks_uri: AnyUrl) -> Dict: """Retrieve the public keys used by the provider server for signing.""" try: response = requests.get(jwks_uri, timeout=5) diff --git a/src/ralph/api/forwarding.py b/src/ralph/api/forwarding.py index d685f3e88..6c85cc8b6 100644 --- a/src/ralph/api/forwarding.py +++ b/src/ralph/api/forwarding.py @@ -14,7 +14,7 @@ @lru_cache def get_active_xapi_forwardings() -> List[XapiForwardingConfigurationSettings]: """Return a list of active xAPI forwarding configuration settings.""" - active_forwardings = [] + active_forwardings: List = [] if not settings.XAPI_FORWARDINGS: logger.info("No xAPI forwarding configured; forwarding is disabled.") return active_forwardings @@ -34,7 +34,7 @@ def get_active_xapi_forwardings() -> List[XapiForwardingConfigurationSettings]: async def forward_xapi_statements( statements: Union[dict, List[dict]], method: Literal["post", "put"] -): +) -> None: """Forward xAPI statements.""" for forwarding in get_active_xapi_forwardings(): transport = AsyncHTTPTransport(retries=forwarding.max_retries) diff --git a/src/ralph/api/routers/health.py b/src/ralph/api/routers/health.py index 5f30b2847..1bafb3467 100644 --- a/src/ralph/api/routers/health.py +++ b/src/ralph/api/routers/health.py @@ -22,7 +22,7 @@ @router.get("/__lbheartbeat__") -async def lbheartbeat(): +async def lbheartbeat() -> None: """Load balancer heartbeat. Returns a 200 when the server is running. @@ -31,7 +31,7 @@ async def lbheartbeat(): @router.get("/__heartbeat__") -async def heartbeat(): +async def heartbeat() -> JSONResponse: """Application heartbeat. Returns a 200 if all checks are successful. diff --git a/src/ralph/api/routers/statements.py b/src/ralph/api/routers/statements.py index 0bfb27212..f163e2f99 100644 --- a/src/ralph/api/routers/statements.py +++ b/src/ralph/api/routers/statements.py @@ -3,7 +3,7 @@ import json import logging from datetime import datetime -from typing import List, Literal, Optional, Union +from typing import Dict, List, Literal, Optional, Union from urllib.parse import ParseResult, urlencode from uuid import UUID, uuid4 @@ -75,31 +75,33 @@ } -def _enrich_statement_with_id(statement: dict): +def _enrich_statement_with_id(statement: dict) -> None: # id: Statement UUID identifier. 
# https://github.com/adlnet/xAPI-Spec/blob/master/xAPI-Data.md#24-statement-properties statement["id"] = str(statement.get("id", uuid4())) -def _enrich_statement_with_stored(statement: dict): +def _enrich_statement_with_stored(statement: dict) -> None: # stored: The time at which a Statement is stored by the LRS. # https://github.com/adlnet/xAPI-Spec/blob/1.0.3/xAPI-Data.md#248-stored statement["stored"] = now() -def _enrich_statement_with_timestamp(statement: dict): +def _enrich_statement_with_timestamp(statement: dict) -> None: # timestamp: Time of the action. If not provided, it takes the same value as stored. # https://github.com/adlnet/xAPI-Spec/blob/master/xAPI-Data.md#247-timestamp statement["timestamp"] = statement.get("timestamp", statement["stored"]) -def _enrich_statement_with_authority(statement: dict, current_user: AuthenticatedUser): +def _enrich_statement_with_authority( + statement: dict, current_user: AuthenticatedUser +) -> None: # authority: Information about whom or what has asserted the statement is true. # https://github.com/adlnet/xAPI-Spec/blob/master/xAPI-Data.md#249-authority statement["authority"] = current_user.agent -def _parse_agent_parameters(agent_obj: dict): +def _parse_agent_parameters(agent_obj: dict) -> AgentParameters: """Parse a dict and return an AgentParameters object to use in queries.""" # Transform agent to `dict` as FastAPI cannot parse JSON (seen as string) @@ -120,7 +122,7 @@ def _parse_agent_parameters(agent_obj: dict): return AgentParameters.construct(**agent_query_params) -def strict_query_params(request: Request): +def strict_query_params(request: Request) -> None: """Raise a 400 error when using extra query parameters.""" dependant: Dependant = request.scope["route"].dependant allowed_params = [ @@ -279,7 +281,7 @@ async def get( ), ), _=Depends(strict_query_params), -): +) -> Dict: """Read a single xAPI Statement or multiple xAPI Statements. LRS Specification: @@ -410,7 +412,7 @@ async def put( background_tasks: BackgroundTasks, statement_id: UUID = Query(alias="statementId"), _=Depends(strict_query_params), -): +) -> None: """Store a single statement as a single member of a set. LRS Specification: @@ -492,7 +494,7 @@ async def post( background_tasks: BackgroundTasks, response: Response, _=Depends(strict_query_params), -): +) -> Union[List, None]: """Store a set of statements (or a single statement as a single member of a set). NB: at this time, using POST to make a GET request, is not supported. diff --git a/src/ralph/backends/data/async_es.py b/src/ralph/backends/data/async_es.py index 20b906dd9..7ec7dba20 100644 --- a/src/ralph/backends/data/async_es.py +++ b/src/ralph/backends/data/async_es.py @@ -3,7 +3,7 @@ import logging from io import IOBase from itertools import chain -from typing import Iterable, Iterator, Union +from typing import Iterable, Iterator, Optional, Union from elasticsearch import ApiError, AsyncElasticsearch, TransportError from elasticsearch.helpers import BulkIndexError, async_streaming_bulk @@ -30,7 +30,7 @@ class AsyncESDataBackend(BaseAsyncDataBackend): query_model = ESQuery settings_class = ESDataBackendSettings - def __init__(self, settings: Union[settings_class, None] = None): + def __init__(self, settings: Optional[ESDataBackendSettings] = None): """Instantiate the asynchronous Elasticsearch client. 
Args: @@ -70,7 +70,7 @@ async def status(self) -> DataBackendStatus: return DataBackendStatus.ERROR async def list( - self, target: str = None, details: bool = False, new: bool = False + self, target: Optional[str] = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: """List available Elasticsearch indices, data streams and aliases. @@ -113,8 +113,8 @@ async def list( async def read( self, *, - query: Union[str, ESQuery] = None, - target: str = None, + query: Optional[Union[str, ESQuery]] = None, + target: Optional[str] = None, chunk_size: Union[None, int] = None, raw_output: bool = False, ignore_errors: bool = False, diff --git a/src/ralph/backends/data/async_mongo.py b/src/ralph/backends/data/async_mongo.py index 76d8954c4..8230a11be 100644 --- a/src/ralph/backends/data/async_mongo.py +++ b/src/ralph/backends/data/async_mongo.py @@ -4,7 +4,7 @@ import logging from io import IOBase from itertools import chain -from typing import Any, Dict, Iterable, Iterator, Union +from typing import Any, Dict, Iterable, Iterator, Optional, Union from bson.errors import BSONError from motor.motor_asyncio import AsyncIOMotorClient @@ -36,7 +36,7 @@ class AsyncMongoDataBackend(BaseAsyncDataBackend): query_model = MongoQuery settings_class = MongoDataBackendSettings - def __init__(self, settings: Union[settings_class, None] = None): + def __init__(self, settings: Optional[MongoDataBackendSettings] = None): """Instantiate the asynchronous MongoDB client. Args: @@ -74,7 +74,7 @@ async def status(self) -> DataBackendStatus: return DataBackendStatus.OK async def list( - self, target: Union[str, None] = None, details: bool = False, new: bool = False + self, target: Optional[str] = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: """List collections in the target database. @@ -118,9 +118,9 @@ async def list( async def read( self, *, - query: Union[str, MongoQuery] = None, - target: Union[str, None] = None, - chunk_size: Union[int, None] = None, + query: Optional[Union[str, MongoQuery]] = None, + target: Optional[str] = None, + chunk_size: Optional[int] = None, raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: @@ -179,10 +179,10 @@ async def read( async def write( # pylint: disable=too-many-arguments self, data: Union[IOBase, Iterable[bytes], Iterable[dict]], - target: Union[str, None] = None, - chunk_size: Union[int, None] = None, + target: Optional[str] = None, + chunk_size: Optional[int] = None, ignore_errors: bool = False, - operation_type: Union[BaseOperationType, None] = None, + operation_type: Optional[BaseOperationType] = None, ) -> int: """Write data documents to the target collection and return their count. diff --git a/src/ralph/backends/data/base.py b/src/ralph/backends/data/base.py index c5277d6cc..4b9454819 100644 --- a/src/ralph/backends/data/base.py +++ b/src/ralph/backends/data/base.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from enum import Enum, unique from io import IOBase -from typing import Iterable, Iterator, Union +from typing import Iterable, Iterator, Optional, Union from pydantic import BaseModel, BaseSettings, ValidationError @@ -89,7 +89,7 @@ class BaseDataBackend(ABC): settings_class = BaseDataBackendSettings @abstractmethod - def __init__(self, settings: Union[settings_class, None] = None): + def __init__(self, settings: Optional[BaseDataBackendSettings] = None): """Instantiate the data backend. 
Args: @@ -137,7 +137,7 @@ def status(self) -> DataBackendStatus: @abstractmethod def list( - self, target: Union[str, None] = None, details: bool = False, new: bool = False + self, target: Optional[str] = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: """List containers in the data backend. E.g., collections, files, indexes. @@ -161,9 +161,9 @@ def list( def read( self, *, - query: Union[str, BaseQuery] = None, - target: Union[str, None] = None, - chunk_size: Union[int, None] = None, + query: Optional[Union[str, BaseQuery]] = None, + target: Optional[str] = None, + chunk_size: Optional[int] = None, raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: @@ -199,10 +199,10 @@ def read( def write( # pylint: disable=too-many-arguments self, data: Union[IOBase, Iterable[bytes], Iterable[dict]], - target: Union[str, None] = None, - chunk_size: Union[int, None] = None, + target: Optional[str] = None, + chunk_size: Optional[int] = None, ignore_errors: bool = False, - operation_type: Union[BaseOperationType, None] = None, + operation_type: Optional[BaseOperationType] = None, ) -> int: """Write `data` records to the `target` container and return their count. @@ -262,7 +262,7 @@ class BaseAsyncDataBackend(ABC): settings_class = BaseDataBackendSettings @abstractmethod - def __init__(self, settings: Union[settings_class, None] = None): + def __init__(self, settings: Optional[BaseDataBackendSettings] = None): """Instantiate the data backend. Args: @@ -310,7 +310,7 @@ async def status(self) -> DataBackendStatus: @abstractmethod async def list( - self, target: Union[str, None] = None, details: bool = False, new: bool = False + self, target: Optional[str] = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: """List containers in the data backend. E.g., collections, files, indexes. @@ -334,9 +334,9 @@ async def list( async def read( self, *, - query: Union[str, BaseQuery] = None, - target: Union[str, None] = None, - chunk_size: Union[int, None] = None, + query: Optional[Union[str, BaseQuery]] = None, + target: Optional[str] = None, + chunk_size: Optional[int] = None, raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: @@ -372,10 +372,10 @@ async def read( async def write( # pylint: disable=too-many-arguments self, data: Union[IOBase, Iterable[bytes], Iterable[dict]], - target: Union[str, None] = None, - chunk_size: Union[int, None] = None, + target: Optional[str] = None, + chunk_size: Optional[int] = None, ignore_errors: bool = False, - operation_type: Union[BaseOperationType, None] = None, + operation_type: Optional[BaseOperationType] = None, ) -> int: """Write `data` records to the `target` container and return their count. 
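A note on the recurring signature change above: `Union[X, None]` and `Optional[X]` are the same type to a checker, and the patch standardizes on the explicit `Optional` spelling, which also avoids relying on PEP 484's deprecated implicit-Optional behavior. A minimal, self-contained sketch of the equivalence, with hypothetical names not taken from this patch:

from typing import Optional, Union

def read_legacy(target: Union[str, None] = None) -> str:
    # Old spelling: the union with None is written out inline.
    return target if target is not None else "default"

def read_current(target: Optional[str] = None) -> str:
    # New spelling: same type, but the "may be absent" intent is explicit.
    return target if target is not None else "default"

# Both accept a str or None and behave identically.
assert read_legacy() == read_current() == "default"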
diff --git a/src/ralph/backends/data/clickhouse.py b/src/ralph/backends/data/clickhouse.py index 795e0c798..81a978760 100755 --- a/src/ralph/backends/data/clickhouse.py +++ b/src/ralph/backends/data/clickhouse.py @@ -5,7 +5,17 @@ from datetime import datetime from io import IOBase from itertools import chain -from typing import Any, Dict, Generator, Iterable, Iterator, List, NamedTuple, Union +from typing import ( + Any, + Dict, + Generator, + Iterable, + Iterator, + List, + NamedTuple, + Optional, + Union, +) from uuid import UUID, uuid4 import clickhouse_connect @@ -107,7 +117,7 @@ class ClickHouseDataBackend(BaseDataBackend): default_operation_type = BaseOperationType.CREATE settings_class = ClickHouseDataBackendSettings - def __init__(self, settings: Union[settings_class, None] = None): + def __init__(self, settings: Optional[ClickHouseDataBackendSettings] = None): """Instantiate the ClickHouse configuration. Args: @@ -156,7 +166,7 @@ def status(self) -> DataBackendStatus: return DataBackendStatus.OK def list( - self, target: Union[str, None] = None, details: bool = False, new: bool = False + self, target: Optional[str] = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: """List tables for a given database. @@ -191,9 +201,9 @@ def list( def read( self, *, - query: Union[str, ClickHouseQuery] = None, - target: Union[str, None] = None, - chunk_size: Union[int, None] = None, + query: Optional[Union[str, ClickHouseQuery]] = None, + target: Optional[str] = None, + chunk_size: Optional[int] = None, raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: @@ -285,10 +295,10 @@ def read( def write( # pylint: disable=too-many-arguments self, data: Union[IOBase, Iterable[bytes], Iterable[dict]], - target: Union[str, None] = None, - chunk_size: Union[int, None] = None, + target: Optional[str] = None, + chunk_size: Optional[int] = None, ignore_errors: bool = False, - operation_type: Union[BaseOperationType, None] = None, + operation_type: Optional[BaseOperationType] = None, ) -> int: """Write `data` documents to the `target` table and return their count. 
@@ -420,7 +430,10 @@ def _to_insert_tuples( yield insert_tuple def _bulk_import( - self, batch: list, ignore_errors: bool = False, event_table_name: str = None + self, + batch: list, + ignore_errors: bool = False, + event_table_name: Optional[str] = None, ): """Insert a batch of documents into the selected database table.""" try: diff --git a/src/ralph/backends/data/es.py b/src/ralph/backends/data/es.py index 2d1bbb02e..05b6e1296 100644 --- a/src/ralph/backends/data/es.py +++ b/src/ralph/backends/data/es.py @@ -4,7 +4,7 @@ from io import IOBase from itertools import chain from pathlib import Path -from typing import Iterable, Iterator, List, Literal, Union +from typing import Iterable, Iterator, List, Literal, Optional, Union from elasticsearch import ApiError, Elasticsearch, TransportError from elasticsearch.helpers import BulkIndexError, streaming_bulk @@ -28,8 +28,8 @@ class ESClientOptions(ClientOptions): """Elasticsearch additional client options.""" - ca_certs: Path = None - verify_certs: bool = None + ca_certs: Optional[Path] = None + verify_certs: Optional[bool] = None class ESDataBackendSettings(BaseDataBackendSettings): @@ -116,7 +116,7 @@ class ESDataBackend(BaseDataBackend): query_model = ESQuery settings_class = ESDataBackendSettings - def __init__(self, settings: Union[settings_class, None] = None): + def __init__(self, settings: Optional[ESDataBackendSettings] = None): """Instantiate the Elasticsearch data backend. Args: @@ -156,7 +156,7 @@ def status(self) -> DataBackendStatus: return DataBackendStatus.ERROR def list( - self, target: Union[str, None] = None, details: bool = False, new: bool = False + self, target: Optional[str] = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: """List available Elasticsearch indices, data streams and aliases. @@ -199,9 +199,9 @@ def list( def read( self, *, - query: Union[str, ESQuery] = None, - target: Union[str, None] = None, - chunk_size: Union[int, None] = None, + query: Optional[Union[str, ESQuery]] = None, + target: Optional[str] = None, + chunk_size: Optional[int] = None, raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: @@ -277,10 +277,10 @@ def read( def write( # pylint: disable=too-many-arguments self, data: Union[IOBase, Iterable[bytes], Iterable[dict]], - target: Union[str, None] = None, - chunk_size: Union[int, None] = None, + target: Optional[str] = None, + chunk_size: Optional[int] = None, ignore_errors: bool = False, - operation_type: Union[BaseOperationType, None] = None, + operation_type: Optional[BaseOperationType] = None, ) -> int: """Write data documents to the target index and return their count. diff --git a/src/ralph/backends/data/fs.py b/src/ralph/backends/data/fs.py index 85ce0fbf0..1eb024cea 100644 --- a/src/ralph/backends/data/fs.py +++ b/src/ralph/backends/data/fs.py @@ -7,7 +7,7 @@ from io import IOBase from itertools import chain from pathlib import Path -from typing import IO, Iterable, Iterator, Union +from typing import IO, Iterable, Iterator, Optional, Union from uuid import uuid4 from ralph.backends.data.base import ( @@ -56,7 +56,7 @@ class FSDataBackend(HistoryMixin, BaseDataBackend): default_operation_type = BaseOperationType.CREATE settings_class = FSDataBackendSettings - def __init__(self, settings: Union[settings_class, None] = None): + def __init__(self, settings: Optional[FSDataBackendSettings] = None): """Create the default target directory if it does not exist. 
Args: @@ -90,7 +90,7 @@ def status(self) -> DataBackendStatus: return DataBackendStatus.OK def list( - self, target: Union[str, None] = None, details: bool = False, new: bool = False + self, target: Optional[str] = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: """List files and directories in the target directory. @@ -110,7 +110,7 @@ def list( Raises: BackendParameterException: If the `target` argument is not a directory path. """ - target = Path(target) if target else self.default_directory + target: Path = Path(target) if target else self.default_directory if not target.is_absolute() and target != self.default_directory: target = self.default_directory / target try: @@ -145,9 +145,9 @@ def list( def read( self, *, - query: Union[str, BaseQuery] = None, - target: Union[str, None] = None, - chunk_size: Union[int, None] = None, + query: Optional[Union[str, BaseQuery]] = None, + target: Optional[str] = None, + chunk_size: Optional[int] = None, raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: @@ -218,10 +218,10 @@ def read( def write( # pylint: disable=too-many-arguments self, data: Union[IOBase, Iterable[bytes], Iterable[dict]], - target: Union[str, None] = None, - chunk_size: Union[int, None] = None, + target: Optional[str] = None, + chunk_size: Optional[int] = None, ignore_errors: bool = False, - operation_type: Union[BaseOperationType, None] = None, + operation_type: Optional[BaseOperationType] = None, ) -> int: """Write data records to the target file and return their count. diff --git a/src/ralph/backends/data/ldp.py b/src/ralph/backends/data/ldp.py index e3cd499fb..83881c4ad 100644 --- a/src/ralph/backends/data/ldp.py +++ b/src/ralph/backends/data/ldp.py @@ -1,7 +1,7 @@ """OVH's LDP data backend for Ralph.""" import logging -from typing import Iterable, Iterator, Literal, Union +from typing import Iterable, Iterator, Literal, Optional, Union import ovh import requests @@ -40,10 +40,10 @@ class Config(BaseSettingsConfig): env_prefix = "RALPH_BACKENDS__DATA__LDP__" - APPLICATION_KEY: str = None - APPLICATION_SECRET: str = None - CONSUMER_KEY: str = None - DEFAULT_STREAM_ID: str = None + APPLICATION_KEY: Optional[str] = None + APPLICATION_SECRET: Optional[str] = None + CONSUMER_KEY: Optional[str] = None + DEFAULT_STREAM_ID: Optional[str] = None ENDPOINT: Literal[ "ovh-eu", "ovh-us", @@ -53,8 +53,8 @@ class Config(BaseSettingsConfig): "soyoustart-eu", "soyoustart-ca", ] = "ovh-eu" - REQUEST_TIMEOUT: int = None - SERVICE_NAME: str = None + REQUEST_TIMEOUT: Optional[int] = None + SERVICE_NAME: Optional[str] = None class LDPDataBackend(HistoryMixin, BaseDataBackend): @@ -63,7 +63,7 @@ class LDPDataBackend(HistoryMixin, BaseDataBackend): name = "ldp" settings_class = LDPDataBackendSettings - def __init__(self, settings: Union[settings_class, None] = None): + def __init__(self, settings: Optional[LDPDataBackendSettings] = None): """Instantiate the OVH LDP client. Args: @@ -101,7 +101,7 @@ def status(self) -> DataBackendStatus: return DataBackendStatus.OK def list( - self, target: Union[str, None] = None, details: bool = False, new: bool = False + self, target: Optional[str] = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: """List archives for a given target stream_id. 
@@ -150,9 +150,9 @@ def list( def read( self, *, - query: Union[str, BaseQuery] = None, - target: Union[str, None] = None, - chunk_size: Union[int, None] = 4096, + query: Optional[Union[str, BaseQuery]] = None, + target: Optional[str] = None, + chunk_size: Optional[int] = 4096, raw_output: bool = True, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: @@ -217,10 +217,10 @@ def read( def write( # pylint: disable=too-many-arguments self, data: Iterable[Union[bytes, dict]], - target: Union[str, None] = None, - chunk_size: Union[int, None] = None, + target: Optional[str] = None, + chunk_size: Optional[int] = None, ignore_errors: bool = False, - operation_type: Union[BaseOperationType, None] = None, + operation_type: Optional[BaseOperationType] = None, ) -> int: """LDP data backend is read-only, calling this method will raise an error.""" msg = "LDP data backend is read-only, cannot write to %s" diff --git a/src/ralph/backends/data/mongo.py b/src/ralph/backends/data/mongo.py index 05dd83789..506beb46f 100644 --- a/src/ralph/backends/data/mongo.py +++ b/src/ralph/backends/data/mongo.py @@ -7,7 +7,7 @@ import struct from io import IOBase from itertools import chain -from typing import Generator, Iterable, Iterator, List, Tuple, Union +from typing import Generator, Iterable, Iterator, List, Optional, Tuple, Union from uuid import uuid4 from bson.errors import BSONError @@ -42,8 +42,8 @@ class MongoClientOptions(ClientOptions): """MongoDB additional client options.""" - document_class: str = None - tz_aware: bool = None + document_class: Optional[str] = None + tz_aware: Optional[bool] = None class MongoDataBackendSettings(BaseDataBackendSettings): @@ -96,7 +96,7 @@ class MongoDataBackend(BaseDataBackend): query_model = MongoQuery settings_class = MongoDataBackendSettings - def __init__(self, settings: Union[settings_class, None] = None): + def __init__(self, settings: Optional[MongoDataBackendSettings] = None): """Instantiate the MongoDB client. Args: @@ -135,7 +135,7 @@ def status(self) -> DataBackendStatus: return DataBackendStatus.OK def list( - self, target: Union[str, None] = None, details: bool = False, new: bool = False + self, target: Optional[str] = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: """List collections in the `target` database. @@ -174,9 +174,9 @@ def list( def read( self, *, - query: Union[str, MongoQuery] = None, - target: Union[str, None] = None, - chunk_size: Union[int, None] = None, + query: Optional[Union[str, MongoQuery]] = None, + target: Optional[str] = None, + chunk_size: Optional[int] = None, raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: @@ -231,10 +231,10 @@ def read( def write( # pylint: disable=too-many-arguments self, data: Union[IOBase, Iterable[bytes], Iterable[dict]], - target: Union[str, None] = None, - chunk_size: Union[int, None] = None, + target: Optional[str] = None, + chunk_size: Optional[int] = None, ignore_errors: bool = False, - operation_type: Union[BaseOperationType, None] = None, + operation_type: Optional[BaseOperationType] = None, ) -> int: """Write `data` documents to the `target` collection and return their count. 
@@ -397,7 +397,7 @@ def to_documents( yield document @staticmethod - def _bulk_import(batch: list, ignore_errors: bool, collection: Collection): + def _bulk_import(batch: List, ignore_errors: bool, collection: Collection) -> int: """Insert a `batch` of documents into the MongoDB `collection`.""" try: new_documents = collection.insert_many(batch) @@ -414,7 +414,7 @@ def _bulk_import(batch: list, ignore_errors: bool, collection: Collection): return inserted_count @staticmethod - def _bulk_delete(batch: list, ignore_errors: bool, collection: Collection): + def _bulk_delete(batch: List, ignore_errors: bool, collection: Collection) -> int: """Delete a `batch` of documents from the MongoDB `collection`.""" try: deleted_documents = collection.delete_many({"_source.id": {"$in": batch}}) @@ -431,7 +431,7 @@ def _bulk_delete(batch: list, ignore_errors: bool, collection: Collection): return deleted_count @staticmethod - def _bulk_update(batch: list, ignore_errors: bool, collection: Collection): + def _bulk_update(batch: List, ignore_errors: bool, collection: Collection) -> int: """Update a `batch` of documents into the MongoDB `collection`.""" try: updated_documents = collection.bulk_write(batch) diff --git a/src/ralph/backends/data/s3.py b/src/ralph/backends/data/s3.py index 22ce05573..d4c30af5c 100644 --- a/src/ralph/backends/data/s3.py +++ b/src/ralph/backends/data/s3.py @@ -4,7 +4,7 @@ import logging from io import IOBase from itertools import chain -from typing import Iterable, Iterator, Union +from typing import Iterable, Iterator, Optional, Union from uuid import uuid4 import boto3 @@ -72,7 +72,7 @@ class S3DataBackend(HistoryMixin, BaseDataBackend): default_operation_type = BaseOperationType.CREATE settings_class = S3DataBackendSettings - def __init__(self, settings: Union[settings_class, None] = None): + def __init__(self, settings: Optional[S3DataBackendSettings] = None): """Instantiate the AWS S3 client.""" self.settings = settings if settings else self.settings_class() @@ -109,7 +109,7 @@ def status(self) -> DataBackendStatus: return DataBackendStatus.OK def list( - self, target: Union[str, None] = None, details: bool = False, new: bool = False + self, target: Optional[str] = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: """List objects for the target bucket. @@ -157,9 +157,9 @@ def list( def read( self, *, - query: Union[str, BaseQuery] = None, - target: Union[str, None] = None, - chunk_size: Union[int, None] = None, + query: Optional[Union[str, BaseQuery]] = None, + target: Optional[str] = None, + chunk_size: Optional[int] = None, raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: @@ -230,10 +230,10 @@ def read( def write( # pylint: disable=too-many-arguments self, data: Union[IOBase, Iterable[bytes], Iterable[dict]], - target: Union[str, None] = None, - chunk_size: Union[int, None] = None, + target: Optional[str] = None, + chunk_size: Optional[int] = None, ignore_errors: bool = False, - operation_type: Union[BaseOperationType, None] = None, + operation_type: Optional[BaseOperationType] = None, ) -> int: """Write `data` records to the `target` bucket and return their count. 
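The `__init__` hunks above all replace `settings: Union[settings_class, None]` with the concrete settings class. One likely motivation, sketched below with illustrative names rather than Ralph's actual classes: `settings_class` is a class attribute, i.e. a variable, and a checker such as mypy rejects a variable used as a type, so the old spelling silently disabled static checking of the parameter.

from typing import Optional

class DemoSettings:
    BUCKET: str = "demo-bucket"

class DemoBackend:
    settings_class = DemoSettings

    # Before: `settings: Union[settings_class, None] = None`, which a type
    # checker flags because `settings_class` is a variable, not a type.
    def __init__(self, settings: Optional[DemoSettings] = None) -> None:
        self.settings = settings if settings else self.settings_class()

print(DemoBackend().settings.BUCKET)  # -> "demo-bucket"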
diff --git a/src/ralph/backends/data/swift.py b/src/ralph/backends/data/swift.py index e4846f848..b0b75c53d 100644 --- a/src/ralph/backends/data/swift.py +++ b/src/ralph/backends/data/swift.py @@ -4,7 +4,7 @@ import logging from functools import cached_property from io import IOBase -from typing import Iterable, Iterator, Union +from typing import Iterable, Iterator, Optional, Union from uuid import uuid4 from swiftclient.service import ClientException, Connection @@ -71,7 +71,7 @@ class SwiftDataBackend(HistoryMixin, BaseDataBackend): default_operation_type = BaseOperationType.CREATE settings_class = SwiftDataBackendSettings - def __init__(self, settings: Union[settings_class, None] = None): + def __init__(self, settings: Optional[SwiftDataBackendSettings] = None): """Prepares the options for the SwiftService.""" self.settings = settings if settings else self.settings_class() @@ -120,7 +120,7 @@ def status(self) -> DataBackendStatus: return DataBackendStatus.OK def list( - self, target: Union[str, None] = None, details: bool = False, new: bool = False + self, target: Optional[str] = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: """List files for the target container. @@ -162,9 +162,9 @@ def list( def read( self, *, - query: Union[str, BaseQuery] = None, - target: Union[str, None] = None, - chunk_size: Union[int, None] = 500, + query: Optional[Union[str, BaseQuery]] = None, + target: Optional[str] = None, + chunk_size: Optional[int] = 500, raw_output: bool = False, ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: @@ -241,10 +241,10 @@ def read( def write( # pylint: disable=too-many-arguments, disable=too-many-branches self, data: Union[IOBase, Iterable[bytes], Iterable[dict]], - target: Union[str, None] = None, - chunk_size: Union[int, None] = None, + target: Optional[str] = None, + chunk_size: Optional[int] = None, ignore_errors: bool = False, - operation_type: Union[BaseOperationType, None] = None, + operation_type: Optional[BaseOperationType] = None, ) -> int: """Write `data` records to the `target` container and returns their count. diff --git a/src/ralph/backends/http/async_lrs.py b/src/ralph/backends/http/async_lrs.py index 6f58859d6..8309397a0 100644 --- a/src/ralph/backends/http/async_lrs.py +++ b/src/ralph/backends/http/async_lrs.py @@ -106,7 +106,7 @@ class AsyncLRSHTTPBackend(BaseHTTPBackend): settings_class = LRSHTTPBackendSettings def __init__( # pylint: disable=too-many-arguments - self, settings: settings_class = None + self, settings: Optional[LRSHTTPBackendSettings] = None ): """Instantiate the LRS HTTP (basic auth) backend client. 
@@ -119,7 +119,7 @@ def __init__( # pylint: disable=too-many-arguments self.base_url = parse_obj_as(AnyHttpUrl, self.settings.BASE_URL) self.auth = (self.settings.USERNAME, self.settings.PASSWORD) - async def status(self): + async def status(self) -> HTTPBackendStatus: """HTTP backend check for server status.""" status_url = urljoin(self.base_url, self.settings.STATUS_ENDPOINT) @@ -139,7 +139,7 @@ async def status(self): return HTTPBackendStatus.OK async def list( - self, target: str = None, details: bool = False, new: bool = False + self, target: Optional[str] = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: """Raise error for unsupported `list` method.""" msg = "LRS HTTP backend does not support `list` method, cannot list from %s" @@ -150,8 +150,8 @@ async def list( @enforce_query_checks async def read( # pylint: disable=too-many-arguments self, - query: Union[str, LRSStatementsQuery] = None, - target: str = None, + query: Optional[Union[str, LRSStatementsQuery]] = None, + target: Optional[str] = None, chunk_size: Optional[PositiveInt] = 500, raw_output: bool = False, ignore_errors: bool = False, diff --git a/src/ralph/backends/http/base.py b/src/ralph/backends/http/base.py index cc0379b6d..ae5003b35 100644 --- a/src/ralph/backends/http/base.py +++ b/src/ralph/backends/http/base.py @@ -85,7 +85,9 @@ class BaseHTTPBackend(ABC): name = "base" query = BaseQuery - def validate_query(self, query: Union[str, dict, BaseQuery] = None) -> BaseQuery: + def validate_query( + self, query: Optional[Union[str, dict, BaseQuery]] = None + ) -> BaseQuery: """Validate and transforms the query.""" if query is None: query = self.query() @@ -114,7 +116,7 @@ def validate_query(self, query: Union[str, dict, BaseQuery] = None) -> BaseQuery @abstractmethod async def list( - self, target: str = None, details: bool = False, new: bool = False + self, target: Optional[str] = None, details: bool = False, new: bool = False ) -> Iterator[Union[str, dict]]: """List containers in the data backend. 
E.g., collections, files, indexes.""" @@ -126,8 +128,8 @@ async def status(self) -> HTTPBackendStatus: @enforce_query_checks async def read( # pylint: disable=too-many-arguments self, - query: Union[str, BaseQuery] = None, - target: str = None, + query: Optional[Union[str, BaseQuery]] = None, + target: Optional[str] = None, chunk_size: Optional[PositiveInt] = 500, raw_output: bool = False, ignore_errors: bool = False, diff --git a/src/ralph/backends/http/lrs.py b/src/ralph/backends/http/lrs.py index 3daf87c07..da6d43cf3 100644 --- a/src/ralph/backends/http/lrs.py +++ b/src/ralph/backends/http/lrs.py @@ -1,7 +1,9 @@ """LRS HTTP backend for Ralph.""" import asyncio +from typing import Iterator, Union from ralph.backends.http.async_lrs import AsyncLRSHTTPBackend +from ralph.backends.http.base import HTTPBackendStatus def _ensure_running_loop_uniqueness(func): @@ -33,21 +35,21 @@ class LRSHTTPBackend(AsyncLRSHTTPBackend): name = "lrs" @_ensure_running_loop_uniqueness - def status(self, *args, **kwargs): + def status(self, *args, **kwargs) -> HTTPBackendStatus: """HTTP backend check for server status.""" return asyncio.get_event_loop().run_until_complete( super().status(*args, **kwargs) ) @_ensure_running_loop_uniqueness - def list(self, *args, **kwargs): + def list(self, *args, **kwargs) -> Iterator[Union[str, dict]]: """Raise error for unsupported `list` method.""" return asyncio.get_event_loop().run_until_complete( super().list(*args, **kwargs) ) @_ensure_running_loop_uniqueness - def read(self, *args, **kwargs): + def read(self, *args, **kwargs) -> Iterator[Union[bytes, dict]]: """Get statements from LRS `target` endpoint. See AsyncLRSHTTP.read for more information. @@ -61,7 +63,7 @@ def read(self, *args, **kwargs): pass @_ensure_running_loop_uniqueness - def write(self, *args, **kwargs): + def write(self, *args, **kwargs) -> int: """Write `data` records to the `target` endpoint and return their count. See AsyncLRSHTTP.write for more information. 
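The synchronous `LRSHTTPBackend` methods annotated above each delegate to an async parent through `run_until_complete`. A reduced sketch of that sync-over-async pattern, assuming a single event loop that is not already running (illustrative classes, not the backend itself):

import asyncio

class AsyncClient:
    async def status(self) -> str:
        return "OK"

class SyncClient(AsyncClient):
    def status(self) -> str:  # type: ignore[override]
        # Drive the inherited coroutine to completion. Newer Python versions
        # deprecate get_event_loop() outside a running loop, where
        # asyncio.run() is the usual substitute.
        return asyncio.get_event_loop().run_until_complete(super().status())

print(SyncClient().status())  # -> "OK"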
diff --git a/src/ralph/backends/lrs/clickhouse.py b/src/ralph/backends/lrs/clickhouse.py index 1721879b7..b04c4f545 100644 --- a/src/ralph/backends/lrs/clickhouse.py +++ b/src/ralph/backends/lrs/clickhouse.py @@ -1,7 +1,7 @@ """ClickHouse LRS backend for Ralph.""" import logging -from typing import Iterator, List +from typing import Generator, Iterator, List from ralph.backends.data.clickhouse import ( ClickHouseDataBackend, @@ -118,7 +118,7 @@ def query_statements(self, params: RalphStatementsQuery) -> StatementQueryResult def query_statements_by_ids(self, ids: List[str]) -> Iterator[dict]: """Yield statements with matching ids from the backend.""" - def chunk_id_list(chunk_size=self.settings.IDS_CHUNK_SIZE): + def chunk_id_list(chunk_size: int = self.settings.IDS_CHUNK_SIZE) -> Generator: for i in range(0, len(ids), chunk_size): yield ids[i : i + chunk_size] @@ -148,7 +148,7 @@ def _add_agent_filters( where: list, agent_params: AgentParameters, target_field: str, - ): + ) -> None: """Add filters relative to agents to `where`.""" if not agent_params: return
diff --git a/src/ralph/backends/lrs/es.py b/src/ralph/backends/lrs/es.py index 6c498c772..3b57511a2 100644 --- a/src/ralph/backends/lrs/es.py +++ b/src/ralph/backends/lrs/es.py @@ -91,7 +91,7 @@ def get_query(params: RalphStatementsQuery) -> ESQuery: @staticmethod def _add_agent_filters( es_query_filters: list, agent_params: AgentParameters, target_field: str - ): + ) -> None: """Add filters relative to agents to `es_query_filters`.""" if not agent_params: return
diff --git a/src/ralph/backends/lrs/fs.py b/src/ralph/backends/lrs/fs.py index 5407f075b..4423bfdc4 100644 --- a/src/ralph/backends/lrs/fs.py +++ b/src/ralph/backends/lrs/fs.py @@ -3,7 +3,7 @@ import logging from datetime import datetime from io import IOBase -from typing import Iterable, List, Literal, Union +from typing import Iterable, List, Literal, Optional, Union from uuid import UUID from ralph.backends.data.base import BaseOperationType @@ -99,7 +99,7 @@ def query_statements_by_ids(self, ids: List[str]) -> List: @staticmethod def _add_filter_by_agent( - filters: list, agent: Union[AgentParameters, None], related: Union[bool, None] + filters: list, agent: Optional[AgentParameters], related: Optional[bool] ) -> None: """Add agent filters to `filters` if `agent` is set.""" if not agent: @@ -122,7 +122,7 @@ def _add_filter_by_agent( @staticmethod def _add_filter_by_authority( filters: list, - authority: Union[AgentParameters, None], + authority: Optional[AgentParameters], ) -> None: """Add authority filters to `filters` if `authority` is set.""" if not authority: @@ -147,7 +147,7 @@ def _add_filter_by_authority( ) @staticmethod - def _add_filter_by_id(filters: list, statement_id: Union[str, None]) -> None: + def _add_filter_by_id(filters: list, statement_id: Optional[str]) -> None: """Add the `match_statement_id` filter if `statement_id` is set.""" def match_statement_id(statement: dict) -> bool: @@ -169,8 +169,8 @@ def _get_related_agents(statement: dict) -> Iterable[dict]: @staticmethod def _add_filter_by_mbox( filters: list, - mbox: Union[str, None], - related: Union[bool, None] = False, + mbox: Optional[str], + related: Optional[bool] = False, field: Literal["actor", "authority"] = "actor", ) -> None: """Add the `match_mbox` filter if `mbox` is set.""" @@ -196,8 +196,8 @@ def match_related_mbox(statement: dict) -> bool: @staticmethod def _add_filter_by_sha1sum( filters: list, - sha1sum: Union[str, None], - related: Union[bool, None] = False, + sha1sum: Optional[str], + related: Optional[bool] = False, field: Literal["actor", "authority"] = "actor", ) -> None: """Add the `match_sha1sum` filter if `sha1sum` is set.""" @@ -223,8 +223,8 @@ def match_related_sha1sum(statement: dict) -> bool: @staticmethod def _add_filter_by_openid( filters: list, - openid: Union[str, None], - related: Union[bool, None] = False, + openid: Optional[str], + related: Optional[bool] = False, field: Literal["actor", "authority"] = "actor", ) -> None: """Add the `match_openid` filter if `openid` is set.""" @@ -250,9 +250,9 @@ def match_related_openid(statement: dict) -> bool: @staticmethod def _add_filter_by_account( filters: list, - name: Union[str, None], - home_page: Union[str, None], - related: Union[bool, None] = False, + name: Optional[str], + home_page: Optional[str], + related: Optional[bool] = False, field: Literal["actor", "authority"] = "actor", ) -> None: """Add the `match_account` filter if `name` or `home_page` is set.""" @@ -278,7 +278,7 @@ def match_related_account(statement: dict) -> bool: filters.append(match_related_account if related else match_account) @staticmethod - def _add_filter_by_verb(filters: list, verb_id: Union[str, None]) -> None: + def _add_filter_by_verb(filters: list, verb_id: Optional[str]) -> None: """Add the `match_verb_id` filter if `verb_id` is set.""" def match_verb_id(statement: dict) -> bool: @@ -290,7 +290,7 @@ def match_verb_id(statement: dict) -> bool: @staticmethod def _add_filter_by_activity( - filters: list, object_id: Union[str, None], related: Union[bool, None] + filters: list, object_id: Optional[str], related: Optional[bool] ) -> None: """Add the `match_object_id` filter if `object_id` is set.""" @@ -322,7 +322,7 @@ def match_related_object_id(statement: dict) -> bool: @staticmethod def _add_filter_by_timestamp_since( - filters: list, timestamp: Union[datetime, None] + filters: list, timestamp: Optional[datetime] ) -> None: """Add the `match_since` filter if `timestamp` is set.""" if isinstance(timestamp, str): @@ -343,7 +343,7 @@ def match_since(statement: dict) -> bool: @staticmethod def _add_filter_by_timestamp_until( - filters: list, timestamp: Union[datetime, None] + filters: list, timestamp: Optional[datetime] ) -> None: """Add the `match_until` function if `timestamp` is set.""" if isinstance(timestamp, str): @@ -363,9 +363,7 @@ def match_until(statement: dict) -> bool: filters.append(match_until) @staticmethod - def _add_filter_by_search_after( - filters: list, search_after: Union[str, None] - ) -> None: + def _add_filter_by_search_after(filters: list, search_after: Optional[str]) -> None: """Add the `match_search_after` filter if `search_after` is set.""" search_after_state = {"state": False} @@ -382,7 +380,7 @@ def match_search_after(statement: dict) -> bool: @staticmethod def _add_filter_by_registration( - filters: list, registration: Union[UUID, None] + filters: list, registration: Optional[UUID] ) -> None: """Add the `match_registration` filter if `registration` is set.""" registration_str = str(registration)
diff --git a/src/ralph/backends/lrs/mongo.py b/src/ralph/backends/lrs/mongo.py index 2d2a0a64e..3fd1e2bae 100644 --- a/src/ralph/backends/lrs/mongo.py +++ b/src/ralph/backends/lrs/mongo.py @@ -103,7 +103,7 @@ def get_query(params: RalphStatementsQuery) -> MongoQuery: @staticmethod def _add_agent_filters( mongo_query_filters: dict, agent_params: AgentParameters, target_field: str - ): + ) -> None: """Add filters relative to agents to mongo_query_filters.
Args: diff --git a/src/ralph/backends/stream/base.py b/src/ralph/backends/stream/base.py index 5a6861203..008e68d81 100644 --- a/src/ralph/backends/stream/base.py +++ b/src/ralph/backends/stream/base.py @@ -27,5 +27,5 @@ class BaseStreamBackend(ABC): settings_class = BaseStreamBackendSettings @abstractmethod - def stream(self, target: BinaryIO): + def stream(self, target: BinaryIO) -> None: """Read records and stream them to target.""" diff --git a/src/ralph/backends/stream/ws.py b/src/ralph/backends/stream/ws.py index 0cad5b029..2f70651cf 100644 --- a/src/ralph/backends/stream/ws.py +++ b/src/ralph/backends/stream/ws.py @@ -2,7 +2,7 @@ import asyncio import logging -from typing import BinaryIO +from typing import BinaryIO, Optional import websockets @@ -25,7 +25,7 @@ class Config(BaseSettingsConfig): env_prefix = "RALPH_BACKENDS__STREAM__WS__" - URI: str = None + URI: Optional[str] = None class WSStreamBackend(BaseStreamBackend): @@ -34,7 +34,7 @@ class WSStreamBackend(BaseStreamBackend): name = "ws" settings_class = WSStreamBackendSettings - def __init__(self, settings: settings_class = None): + def __init__(self, settings: Optional[WSStreamBackendSettings] = None): """Instantiate the websocket client. Args: @@ -43,13 +43,13 @@ def __init__(self, settings: settings_class = None): """ self.settings = settings if settings else self.settings_class() - def stream(self, target: BinaryIO): + def stream(self, target: BinaryIO) -> None: """Stream websocket content to target.""" # pylint: disable=no-member logger.debug("Streaming from websocket uri: %s", self.settings.URI) - async def _stream(): + async def _stream() -> None: async with websockets.connect(self.settings.URI) as websocket: while event := await websocket.recv(): target.write(bytes(f"{event}" + "\n", encoding="utf-8")) diff --git a/src/ralph/cli.py b/src/ralph/cli.py index fd2fd3fdd..932957bb5 100644 --- a/src/ralph/cli.py +++ b/src/ralph/cli.py @@ -7,7 +7,7 @@ from inspect import isasyncgen, isclass, iscoroutinefunction from pathlib import Path from tempfile import NamedTemporaryFile -from typing import List +from typing import Any, Optional, Sequence import bcrypt @@ -115,7 +115,7 @@ def convert(self, value, param, ctx): class ClientOptionsParamType(CommaSeparatedKeyValueParamType): """Comma separated key=value parameter type for client options.""" - def __init__(self, client_options_type): + def __init__(self, client_options_type: Any) -> None: """Instantiates ClientOptionsParamType for a client_options_type. Args: @@ -137,7 +137,7 @@ def convert(self, value, param, ctx): class HeadersParametersParamType(CommaSeparatedKeyValueParamType): """Comma separated key=value parameter type for headers parameters.""" - def __init__(self, headers_parameters_type): + def __init__(self, headers_parameters_type: Any) -> None: """Instantiates HeadersParametersParamType for a headers_parameters_type. 
Args: @@ -200,7 +200,7 @@ def cli(verbosity=None): handler.setLevel(level) -def backends_options(name=None, backend_types: List[BaseModel] = None): +def backends_options(name=None, backend_types: Optional[Sequence[BaseModel]] = None): """Backend-related options decorator for Ralph commands.""" def wrapper(command): diff --git a/src/ralph/conf.py b/src/ralph/conf.py index 1a149577d..00affd056 100644 --- a/src/ralph/conf.py +++ b/src/ralph/conf.py @@ -1,13 +1,20 @@ """Configurations for Ralph.""" import io +import sys from enum import Enum from pathlib import Path -from typing import List, Tuple, Union +from typing import List, Sequence, Union -try: +from pydantic import AnyHttpUrl, AnyUrl, BaseModel, BaseSettings, Extra, root_validator + +from ralph.exceptions import ConfigurationException + +from .utils import import_string + +if sys.version_info >= (3, 8): from typing import Literal -except ImportError: +else: from typing_extensions import Literal try: @@ -20,12 +27,6 @@ get_app_dir = Mock(return_value=".") -from pydantic import AnyHttpUrl, AnyUrl, BaseModel, BaseSettings, Extra, root_validator - -from ralph.exceptions import ConfigurationException - -from .utils import import_string - MODEL_PATH_SEPARATOR = "__" @@ -56,7 +57,7 @@ class CommaSeparatedTuple(str): @classmethod def __get_validators__(cls): # noqa: D105 - def validate(value: Union[str, Tuple[str], List[str]]) -> Tuple[str]: + def validate(value: Union[str, Sequence[str]]) -> Sequence[str]: """Check whether the value is a comma separated string or a list/tuple.""" if isinstance(value, (tuple, list)): return tuple(value) diff --git a/src/ralph/filters.py b/src/ralph/filters.py index 526d98af5..327c4d44e 100644 --- a/src/ralph/filters.py +++ b/src/ralph/filters.py @@ -1,9 +1,11 @@ """Ralph tracking logs filters.""" +from typing import Any, Union + from .exceptions import EventKeyError -def anonymous(event): +def anonymous(event: dict) -> Union[dict, Any]: """Remove anonymous events. Args: diff --git a/src/ralph/logger.py b/src/ralph/logger.py index b17807294..63945e33a 100644 --- a/src/ralph/logger.py +++ b/src/ralph/logger.py @@ -6,7 +6,7 @@ from ralph.exceptions import ConfigurationException -def configure_logging(): +def configure_logging() -> None: """Set up Ralph logging configuration.""" try: dictConfig(settings.LOGGING) diff --git a/src/ralph/models/converter.py b/src/ralph/models/converter.py index 85509af3d..7ddbe96e0 100644 --- a/src/ralph/models/converter.py +++ b/src/ralph/models/converter.py @@ -7,7 +7,18 @@ from importlib import import_module from inspect import getmembers, isclass from types import ModuleType -from typing import Any, Callable, Set, TextIO, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterator, + Optional, + Set, + TextIO, + Tuple, + Union, +) from pydantic import BaseModel, ValidationError @@ -34,7 +45,13 @@ class ConversionItem: transformers: Tuple[Callable[[Any], Any]] raw_input: bool - def __init__(self, dest: str, src=None, transformers=lambda _: _, raw_input=False): + def __init__( + self, + dest: str, + src: Optional[str] = None, + transformers=lambda _: _, + raw_input: bool = False, + ) -> None: """Initialize ConversionItem. 
Args: @@ -55,7 +72,7 @@ def __init__(self, dest: str, src=None, transformers=lambda _: _, raw_input=Fals object.__setattr__(self, "transformers", transformers) object.__setattr__(self, "raw_input", raw_input) - def get_value(self, data: Union[dict, str]): + def get_value(self, data: Union[Dict, str]) -> Union[Dict, str]: """Return fetched source value after having applied all transformers to it. Args: @@ -84,7 +101,7 @@ class BaseConversionSet(ABC): __src__: BaseModel __dest__: BaseModel - def __init__(self): + def __init__(self) -> None: """Initializes BaseConversionSet.""" self._conversion_items = self._get_conversion_items() @@ -92,13 +109,13 @@ def __init__(self): def _get_conversion_items(self) -> Set[ConversionItem]: """Returns a set of ConversionItems used for conversion.""" - def __iter__(self): # noqa: D105 + def __iter__(self) -> Iterator[ConversionItem]: # noqa: D105 return iter(self._conversion_items) def convert_dict_event( event: dict, event_str: str, conversion_set: BaseConversionSet -) -> BaseModel: +) -> Any: """Convert the event dictionary with a conversion_set. Args: @@ -151,10 +168,10 @@ class Converter: def __init__( self, - model_selector=ModelSelector(), - module="ralph.models.edx.converters.xapi", - **conversion_set_kwargs, - ): + model_selector: ModelSelector = ModelSelector(), + module: str = "ralph.models.edx.converters.xapi", + **conversion_set_kwargs: Any, + ) -> None: """Initializes the Converter.""" self.model_selector = model_selector self.src_conversion_set = self.get_src_conversion_set( @@ -162,7 +179,9 @@ def __init__( ) @staticmethod - def get_src_conversion_set(module: ModuleType, **conversion_set_kwargs): + def get_src_conversion_set( + module: ModuleType, **conversion_set_kwargs: Any + ) -> dict: """Return a dictionary of initialized conversion_sets defined in the module.""" src_conversion_set = {} for _, class_ in getmembers(module, isclass): @@ -170,7 +189,9 @@ def get_src_conversion_set(module: ModuleType, **conversion_set_kwargs): src_conversion_set[class_.__src__] = class_(**conversion_set_kwargs) return src_conversion_set - def convert(self, input_file: TextIO, ignore_errors: bool, fail_on_unknown: bool): + def convert( + self, input_file: TextIO, ignore_errors: bool, fail_on_unknown: bool + ) -> Generator: """Convert JSON event strings line by line.""" total = 0 success = 0 @@ -201,7 +222,7 @@ def convert(self, input_file: TextIO, ignore_errors: bool, fail_on_unknown: bool raise err logger.info("Total events: %d, Invalid events: %d", total, total - success) - def _convert_event(self, event_str: str): + def _convert_event(self, event_str: str) -> Any: """Convert a single JSON string event. Args: @@ -219,7 +240,7 @@ def _convert_event(self, event_str: str): ConversionException: When a field transformation fails. ValidationError: When the final converted event is invalid. 
""" - error = None + error: Optional[BaseException] = None event = json.loads(event_str) for model in self.model_selector.get_models(event): conversion_set = self.src_conversion_set.get(model, None) @@ -236,6 +257,8 @@ def _convert_event(self, event_str: str): raise error @staticmethod - def _log_error(message, event_str, error=None): + def _log_error( + message: object, event_str: str, error: Optional[BaseException] = None + ) -> None: logger.error(message) logger.debug("Raised error: %s, for event : %s", error, event_str) diff --git a/src/ralph/models/edx/base.py b/src/ralph/models/edx/base.py index 89af90028..7a38e2487 100644 --- a/src/ralph/models/edx/base.py +++ b/src/ralph/models/edx/base.py @@ -1,17 +1,18 @@ """Base event model definitions.""" +import sys from datetime import datetime from ipaddress import IPv4Address from pathlib import Path from typing import Dict, Optional, Union -try: +from pydantic import AnyHttpUrl, BaseModel, constr + +if sys.version_info >= (3, 8): from typing import Literal -except ImportError: +else: from typing_extensions import Literal -from pydantic import AnyHttpUrl, BaseModel, constr - class BaseModelWithConfig(BaseModel): """Pydantic model for base configuration shared among all models.""" diff --git a/src/ralph/models/edx/browser.py b/src/ralph/models/edx/browser.py index 39c45d8fa..fdf230473 100644 --- a/src/ralph/models/edx/browser.py +++ b/src/ralph/models/edx/browser.py @@ -1,16 +1,17 @@ """Browser event model definitions.""" +import sys from typing import Union -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from pydantic import AnyUrl, constr from .base import BaseEdxModel +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class BaseBrowserModel(BaseEdxModel): """Pydantic model for core browser statements. 
diff --git a/src/ralph/models/edx/converters/xapi/base.py b/src/ralph/models/edx/converters/xapi/base.py index da7a413e8..2c3496cb3 100644 --- a/src/ralph/models/edx/converters/xapi/base.py +++ b/src/ralph/models/edx/converters/xapi/base.py @@ -1,5 +1,6 @@ """Base xAPI Converter.""" +from typing import Set from uuid import UUID, uuid5 from ralph.exceptions import ConfigurationException @@ -27,7 +28,7 @@ def __init__(self, uuid_namespace: str, platform_url: str): raise ConfigurationException("Invalid UUID namespace") from err super().__init__() - def _get_conversion_items(self): + def _get_conversion_items(self) -> Set[ConversionItem]: """Return a set of ConversionItems used for conversion.""" return { ConversionItem(
diff --git a/src/ralph/models/edx/converters/xapi/enrollment.py b/src/ralph/models/edx/converters/xapi/enrollment.py index a82e30fd5..7f1feb145 100644 --- a/src/ralph/models/edx/converters/xapi/enrollment.py +++ b/src/ralph/models/edx/converters/xapi/enrollment.py @@ -1,4 +1,5 @@ """Enrollment event xAPI Converter.""" +from typing import Set from ralph.models.converter import ConversionItem from ralph.models.edx.enrollment.statements import ( @@ -13,7 +14,7 @@ class LMSBaseXapiConverter(BaseXapiConverter): """Base LMS xAPI Converter.""" - def _get_conversion_items(self): + def _get_conversion_items(self) -> Set[ConversionItem]: """Return a set of ConversionItems used for conversion.""" conversion_items = super()._get_conversion_items() return conversion_items.union(
diff --git a/src/ralph/models/edx/converters/xapi/navigational.py b/src/ralph/models/edx/converters/xapi/navigational.py index c9f2f9e83..5d4935e4a 100644 --- a/src/ralph/models/edx/converters/xapi/navigational.py +++ b/src/ralph/models/edx/converters/xapi/navigational.py @@ -1,4 +1,5 @@ """Navigational event xAPI Converter.""" +from typing import Set from ralph.models.converter import ConversionItem from ralph.models.edx.navigational.statements import UIPageClose @@ -19,7 +20,7 @@ class UIPageCloseToPageTerminated(BaseXapiConverter): __src__ = UIPageClose __dest__ = PageTerminated - def _get_conversion_items(self): + def _get_conversion_items(self) -> Set[ConversionItem]: """Return a set of ConversionItems used for conversion.""" conversion_items = super()._get_conversion_items() return conversion_items.union({ConversionItem("object__id", "page")})
diff --git a/src/ralph/models/edx/converters/xapi/server.py b/src/ralph/models/edx/converters/xapi/server.py index a9c59596c..6fb94f4c4 100644 --- a/src/ralph/models/edx/converters/xapi/server.py +++ b/src/ralph/models/edx/converters/xapi/server.py @@ -1,5 +1,7 @@ """Server event xAPI Converter.""" +from typing import Set + from ralph.models.converter import ConversionItem from ralph.models.edx.server import Server from ralph.models.xapi.navigation.statements import PageViewed @@ -16,7 +18,7 @@ class ServerEventToPageViewed(BaseXapiConverter): __src__ = Server __dest__ = PageViewed - def _get_conversion_items(self): + def _get_conversion_items(self) -> Set[ConversionItem]: """Return a set of ConversionItems used for conversion.""" conversion_items = super()._get_conversion_items() return conversion_items.union(
diff --git a/src/ralph/models/edx/converters/xapi/video.py b/src/ralph/models/edx/converters/xapi/video.py index cb876886d..0abccadeb 100644 --- a/src/ralph/models/edx/converters/xapi/video.py +++ b/src/ralph/models/edx/converters/xapi/video.py @@ -1,4 +1,5 @@ """Video event xAPI Converter.""" +from typing import Set from ralph.models.converter import ConversionItem from ralph.models.edx.video.statements import ( @@ -32,7 +33,7 @@ class VideoBaseXapiConverter(BaseXapiConverter): """Base Video xAPI Converter.""" - def _get_conversion_items(self): + def _get_conversion_items(self) -> Set[ConversionItem]: """Return a set of ConversionItems used for conversion.""" conversion_items = super()._get_conversion_items() return conversion_items.union( @@ -70,7 +71,7 @@ class UILoadVideoToVideoInitialized(VideoBaseXapiConverter): __src__ = UILoadVideo __dest__ = VideoInitialized - def _get_conversion_items(self): + def _get_conversion_items(self) -> Set[ConversionItem]: """Return a set of ConversionItems used for conversion.""" conversion_items = super()._get_conversion_items() return conversion_items.union( @@ -100,7 +101,7 @@ class UIPlayVideoToVideoPlayed(VideoBaseXapiConverter): __src__ = UIPlayVideo __dest__ = VideoPlayed - def _get_conversion_items(self): + def _get_conversion_items(self) -> Set[ConversionItem]: """Return a set of ConversionItems used for conversion.""" conversion_items = super()._get_conversion_items() return conversion_items.union( @@ -123,7 +124,7 @@ class UIPauseVideoToVideoPaused(VideoBaseXapiConverter): __src__ = UIPauseVideo __dest__ = VideoPaused - def _get_conversion_items(self): + def _get_conversion_items(self) -> Set[ConversionItem]: """Return a set of ConversionItems used for conversion.""" conversion_items = super()._get_conversion_items() return conversion_items.union( @@ -154,7 +155,7 @@ class UIStopVideoToVideoTerminated(VideoBaseXapiConverter): __src__ = UIStopVideo __dest__ = VideoTerminated - def _get_conversion_items(self): + def _get_conversion_items(self) -> Set[ConversionItem]: """Return a set of ConversionItems used for conversion.""" conversion_items = super()._get_conversion_items() return conversion_items.union( @@ -193,7 +194,7 @@ class UISeekVideoToVideoSeeked(VideoBaseXapiConverter): __src__ = UISeekVideo __dest__ = VideoSeeked - def _get_conversion_items(self): + def _get_conversion_items(self) -> Set[ConversionItem]: """Return a set of ConversionItems used for conversion.""" conversion_items = super()._get_conversion_items() return conversion_items.union(
diff --git a/src/ralph/models/edx/enrollment/fields/contexts.py b/src/ralph/models/edx/enrollment/fields/contexts.py index b2a3622fb..478086935 100644 --- a/src/ralph/models/edx/enrollment/fields/contexts.py +++ b/src/ralph/models/edx/enrollment/fields/contexts.py @@ -1,14 +1,15 @@ """Enrollment event models context fields definitions.""" +import sys from typing import Union -try: +from ...base import BaseContextField + +if sys.version_info >= (3, 8): from typing import Literal -except ImportError: +else: from typing_extensions import Literal -from ...base import BaseContextField - class EdxCourseEnrollmentUpgradeClickedContextField(BaseContextField): """Pydantic model for `edx.course.enrollment.upgrade_clicked`.`context` field.
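With `_get_conversion_items` now annotated as `Set[ConversionItem]` everywhere, each converter's contribution composes through `set.union`. A hedged sketch of that contract, using `ConversionItem` as the diff shows it but with hypothetical field paths:

from typing import Set

from ralph.models.converter import ConversionItem

class BaseDummyConverter:
    def _get_conversion_items(self) -> Set[ConversionItem]:
        return {ConversionItem("object__id", "page")}

class ChildDummyConverter(BaseDummyConverter):
    def _get_conversion_items(self) -> Set[ConversionItem]:
        # Extend rather than replace the parent's conversion items.
        items = super()._get_conversion_items()
        return items.union({ConversionItem("timestamp", "time")})

print(len(ChildDummyConverter()._get_conversion_items()))  # -> 2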
diff --git a/src/ralph/models/edx/enrollment/fields/events.py b/src/ralph/models/edx/enrollment/fields/events.py index 9cd198e10..60ef60c62 100644 --- a/src/ralph/models/edx/enrollment/fields/events.py +++ b/src/ralph/models/edx/enrollment/fields/events.py @@ -1,14 +1,15 @@ """Enrollment models event field definition.""" +import sys from typing import Union -try: +from ...base import AbstractBaseEventField + +if sys.version_info >= (3, 8): from typing import Literal -except ImportError: +else: from typing_extensions import Literal -from ...base import AbstractBaseEventField - class EnrollmentEventField(AbstractBaseEventField): """Pydantic model for enrollment `event` field. diff --git a/src/ralph/models/edx/enrollment/statements.py b/src/ralph/models/edx/enrollment/statements.py index 2ee34a3b3..1d342fcdd 100644 --- a/src/ralph/models/edx/enrollment/statements.py +++ b/src/ralph/models/edx/enrollment/statements.py @@ -1,12 +1,8 @@ """Enrollment event model definitions.""" +import sys from typing import Union -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from pydantic import Json from ralph.models.selector import selector @@ -19,6 +15,11 @@ ) from .fields.events import EnrollmentEventField +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class EdxCourseEnrollmentActivated(BaseServerModel): """Pydantic model for `edx.course.enrollment.activated` statement. diff --git a/src/ralph/models/edx/navigational/statements.py b/src/ralph/models/edx/navigational/statements.py index 9d9b4ade8..65376244e 100644 --- a/src/ralph/models/edx/navigational/statements.py +++ b/src/ralph/models/edx/navigational/statements.py @@ -1,12 +1,8 @@ """Navigational event model definitions.""" +import sys from typing import Union -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from pydantic import Json, validator from ralph.models.selector import selector @@ -14,6 +10,11 @@ from ..browser import BaseBrowserModel from .fields.events import NavigationalEventField +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class UIPageClose(BaseBrowserModel): """Pydantic model for `page_close` statement. 
@@ -76,7 +77,9 @@ class UISeqNext(BaseBrowserModel): @validator("event") @classmethod - def validate_next_jump_event_field(cls, value): + def validate_next_jump_event_field( + cls, value: Union[Json[NavigationalEventField], NavigationalEventField] + ) -> Union[Json[NavigationalEventField], NavigationalEventField]: """Check that event.new is equal to event.old + 1.""" if value.new != value.old + 1: raise ValueError("event.new - event.old should be equal to 1") @@ -106,7 +109,9 @@ class UISeqPrev(BaseBrowserModel): @validator("event") @classmethod - def validate_prev_jump_event_field(cls, value): + def validate_prev_jump_event_field( + cls, value: Union[Json[NavigationalEventField], NavigationalEventField] + ) -> Union[Json[NavigationalEventField], NavigationalEventField]: """Check that event.new is equal to event.old - 1.""" if value.new != value.old - 1: raise ValueError("event.old - event.new should be equal to 1") diff --git a/src/ralph/models/edx/open_response_assessment/fields/events.py b/src/ralph/models/edx/open_response_assessment/fields/events.py index 304a096f2..e5b658dde 100644 --- a/src/ralph/models/edx/open_response_assessment/fields/events.py +++ b/src/ralph/models/edx/open_response_assessment/fields/events.py @@ -1,19 +1,19 @@ """Open Response Assessment events model event fields definitions.""" +import sys from datetime import datetime from typing import Dict, List, Optional, Union - -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from uuid import UUID from pydantic import constr from ralph.models.edx.base import AbstractBaseEventField, BaseModelWithConfig +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class ORAGetPeerSubmissionEventField(AbstractBaseEventField): """Pydantic model for `openassessmentblock.get_peer_submission`.`event` field. diff --git a/src/ralph/models/edx/open_response_assessment/statements.py b/src/ralph/models/edx/open_response_assessment/statements.py index 8d170fbf8..5aa5b8a33 100644 --- a/src/ralph/models/edx/open_response_assessment/statements.py +++ b/src/ralph/models/edx/open_response_assessment/statements.py @@ -1,13 +1,8 @@ """Open Response Assessment events model definitions.""" +import sys from typing import Union -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - - from pydantic import Json from ralph.models.edx.browser import BaseBrowserModel @@ -26,6 +21,11 @@ ORAUploadFileEventField, ) +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class ORAGetPeerSubmission(BaseServerModel): """Pydantic model for `openassessmentblock.get_peer_submission` statement. 
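The validator hunks above only add annotations, but they document a useful pydantic v1 idiom: the `@validator` / `@classmethod` stack with a typed `value`. Reduced to a standalone model, with field names invented for the example:

from pydantic import BaseModel, validator  # pydantic v1 API, as used in this codebase

class SeqEvent(BaseModel):
    old: int
    new: int

    @validator("new")
    @classmethod
    def validate_next_jump(cls, value: int, values: dict) -> int:
        # Mirrors UISeqNext's check that event.new == event.old + 1.
        if "old" in values and value != values["old"] + 1:
            raise ValueError("event.new - event.old should be equal to 1")
        return value

print(SeqEvent(old=1, new=2))  # validates; SeqEvent(old=1, new=3) raises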
diff --git a/src/ralph/models/edx/peer_instruction/statements.py b/src/ralph/models/edx/peer_instruction/statements.py index 833fd5d06..721e21461 100644 --- a/src/ralph/models/edx/peer_instruction/statements.py +++ b/src/ralph/models/edx/peer_instruction/statements.py @@ -1,12 +1,8 @@ """Peer instruction events model definitions.""" +import sys from typing import Union -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from pydantic import Json from ralph.models.selector import selector @@ -14,6 +10,11 @@ from ..server import BaseServerModel from .fields.events import PeerInstructionEventField +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class PeerInstructionAccessed(BaseServerModel): """Pydantic model for `ubc.peer_instruction.accessed` statement. diff --git a/src/ralph/models/edx/problem_interaction/fields/events.py b/src/ralph/models/edx/problem_interaction/fields/events.py index b919e3888..5f0b868a4 100644 --- a/src/ralph/models/edx/problem_interaction/fields/events.py +++ b/src/ralph/models/edx/problem_interaction/fields/events.py @@ -1,17 +1,18 @@ """Problem interaction events model event fields definitions.""" +import sys from datetime import datetime from typing import Dict, List, Optional, Union -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from pydantic import constr from ...base import AbstractBaseEventField, BaseModelWithConfig +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class QueueState(BaseModelWithConfig): """Pydantic model for problem interaction `event`.`correct_map`.`queuestate` field. diff --git a/src/ralph/models/edx/problem_interaction/statements.py b/src/ralph/models/edx/problem_interaction/statements.py index 88702b406..da44acf87 100644 --- a/src/ralph/models/edx/problem_interaction/statements.py +++ b/src/ralph/models/edx/problem_interaction/statements.py @@ -1,12 +1,8 @@ """Problem interaction events model definitions.""" +import sys from typing import List, Union -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from pydantic import Json from ralph.models.selector import selector @@ -29,6 +25,11 @@ UIProblemShowEventField, ) +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class EdxProblemHintDemandhintDisplayed(BaseServerModel): """Pydantic model for `edx.problem.hint.demandhint_displayed` statement. 
diff --git a/src/ralph/models/edx/server.py b/src/ralph/models/edx/server.py index 9f8bc6af1..943368f81 100644 --- a/src/ralph/models/edx/server.py +++ b/src/ralph/models/edx/server.py @@ -1,19 +1,20 @@ """Server event model definitions.""" +import sys from pathlib import Path from typing import Union -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from pydantic import Json from ralph.models.selector import LazyModelField, selector from .base import AbstractBaseEventField, BaseEdxModel +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class BaseServerModel(BaseEdxModel): """Pydantic model for core server statement.""" diff --git a/src/ralph/models/edx/textbook_interaction/fields/events.py b/src/ralph/models/edx/textbook_interaction/fields/events.py index a399bc1c5..6cabf6f07 100644 --- a/src/ralph/models/edx/textbook_interaction/fields/events.py +++ b/src/ralph/models/edx/textbook_interaction/fields/events.py @@ -1,16 +1,17 @@ """Textbook interaction event fields definitions.""" +import sys from typing import Optional, Union -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from pydantic import Field, constr from ...base import AbstractBaseEventField +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + # pylint: disable=line-too-long class TextbookInteractionBaseEventField(AbstractBaseEventField): diff --git a/src/ralph/models/edx/textbook_interaction/statements.py b/src/ralph/models/edx/textbook_interaction/statements.py index 5f571fc27..9b707e55c 100644 --- a/src/ralph/models/edx/textbook_interaction/statements.py +++ b/src/ralph/models/edx/textbook_interaction/statements.py @@ -1,12 +1,8 @@ """Textbook interaction event model definitions.""" +import sys from typing import Union -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from pydantic import Json from ralph.models.selector import selector @@ -29,6 +25,11 @@ TextbookPdfZoomMenuChangedEventField, ) +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class UIBook(BaseBrowserModel): """Pydantic model for `book` statement. diff --git a/src/ralph/models/edx/video/fields/events.py b/src/ralph/models/edx/video/fields/events.py index 328c1f594..e2533e167 100644 --- a/src/ralph/models/edx/video/fields/events.py +++ b/src/ralph/models/edx/video/fields/events.py @@ -1,12 +1,14 @@ """Video event fields definitions.""" -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal +import sys from ...base import AbstractBaseEventField +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class VideoBaseEventField(AbstractBaseEventField): """Pydantic model for video core `event` field. 
diff --git a/src/ralph/models/edx/video/statements.py b/src/ralph/models/edx/video/statements.py index e468dd1c9..19ed0cdc2 100644 --- a/src/ralph/models/edx/video/statements.py +++ b/src/ralph/models/edx/video/statements.py @@ -1,12 +1,8 @@ """Video event model definitions.""" +import sys from typing import Optional, Union -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from pydantic import Json from ralph.models.edx.video.fields.events import ( @@ -23,6 +19,11 @@ from ..browser import BaseBrowserModel +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class UILoadVideo(BaseBrowserModel): """Pydantic model for `load_video` statement. diff --git a/src/ralph/models/selector.py b/src/ralph/models/selector.py index 65ee026a0..a39980d87 100644 --- a/src/ralph/models/selector.py +++ b/src/ralph/models/selector.py @@ -6,7 +6,7 @@ from inspect import getmembers, isclass from itertools import chain from types import ModuleType -from typing import Any, Tuple, Union +from typing import Any, Dict, List, Tuple, Union from pydantic import BaseModel @@ -21,7 +21,7 @@ class LazyModelField: path: Tuple[str] - def __init__(self, path: str): + def __init__(self, path: str) -> None: """Initialize Lazy Model Field.""" object.__setattr__(self, "path", tuple(path.split(MODEL_PATH_SEPARATOR))) @@ -33,7 +33,7 @@ class Rule: field: LazyModelField value: Union[LazyModelField, Any] # pylint: disable=unsubscriptable-object - def check(self, event): + def check(self, event: Dict) -> bool: """Check if event matches the rule. Args: @@ -46,7 +46,7 @@ def check(self, event): return event_value == expected_value -def selector(**filters): +def selector(**filters: Any) -> List[Rule]: """Return a list of rules that should match in order to select an event. Args: @@ -66,13 +66,13 @@ class ModelSelector: decision_tree (dict): Stores the rule checking order for model selection. """ - def __init__(self, module="ralph.models.edx"): + def __init__(self, module: str = "ralph.models.edx") -> None: """Instantiates ModelSelector.""" self.model_rules = ModelSelector.build_model_rules(import_module(module)) self.decision_tree = self.get_decision_tree(self.model_rules) @staticmethod - def build_model_rules(module: ModuleType): + def build_model_rules(module: ModuleType) -> Dict: """Build the model_rules dictionary. Using BaseModel classes defined in the module. @@ -83,7 +83,7 @@ def build_model_rules(module: ModuleType): model_rules[class_] = class_.__selector__ return model_rules - def get_first_model(self, event: dict): + def get_first_model(self, event: Dict) -> Any: """Return the first matching model for the event. 
See `self.get_models`.""" return self.get_models(event)[0] diff --git a/src/ralph/models/validator.py b/src/ralph/models/validator.py index 72b6eec04..fe970e46e 100644 --- a/src/ralph/models/validator.py +++ b/src/ralph/models/validator.py @@ -3,7 +3,7 @@ import json import logging -from typing import TextIO +from typing import Any, Generator, Optional, TextIO from pydantic import ValidationError @@ -20,7 +20,9 @@ def __init__(self, model_selector: ModelSelector): """Initializes Validator.""" self.model_selector = model_selector - def validate(self, input_file: TextIO, ignore_errors: bool, fail_on_unknown: bool): + def validate( + self, input_file: TextIO, ignore_errors: bool, fail_on_unknown: bool + ) -> Generator: """Validates JSON event strings line by line.""" total = 0 success = 0 @@ -45,14 +47,14 @@ def validate(self, input_file: TextIO, ignore_errors: bool, fail_on_unknown: boo raise BadFormatException(message) from err logger.info("Total events: %d, Invalid events: %d", total, total - success) - def get_first_valid_model(self, event: dict): + def get_first_valid_model(self, event: dict) -> Any: """Returns the first successfully instantiated model for the event. Raises: UnknownEventException: When the event does not match any model. ValidationError: When the last validated event is invalid. """ - error = None + error: Optional[BaseException] = None for model in self.model_selector.get_models(event): try: return model(**event) @@ -61,7 +63,7 @@ def get_first_valid_model(self, event: dict): raise error - def _validate_event(self, event_str: str): + def _validate_event(self, event_str: str) -> Any: """Validate a single JSON string event. Raises: @@ -77,6 +79,8 @@ def _validate_event(self, event_str: str): return self.get_first_valid_model(event).json() @staticmethod - def _log_error(message, event_str, error=None): + def _log_error( + message: object, event_str: str, error: Optional[BaseException] = None + ) -> None: logger.error(message) logger.debug("Raised error: %s, for event : %s", error, event_str) diff --git a/src/ralph/models/xapi/base/agents.py b/src/ralph/models/xapi/base/agents.py index 66ed91c24..9f6ce53f5 100644 --- a/src/ralph/models/xapi/base/agents.py +++ b/src/ralph/models/xapi/base/agents.py @@ -1,13 +1,9 @@ """Base xAPI `Agent` definitions.""" +import sys from abc import ABC from typing import Optional, Union -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from pydantic import StrictStr from ..config import BaseModelWithConfig @@ -19,6 +15,11 @@ BaseXapiOpenIdIFI, ) +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class BaseXapiAgentAccount(BaseModelWithConfig): """Pydantic model for `Agent` type `account` property. 
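The `get_first_valid_model` method typed above implements a first-match-wins loop: each candidate model returned by the selector is tried in order, only the last `ValidationError` is kept, and it is re-raised when no model fits. A simplified sketch of that control flow, using a hypothetical `first_valid_model` helper in place of the method:

    from typing import Any, List, Optional, Type

    from pydantic import BaseModel, ValidationError


    def first_valid_model(models: List[Type[BaseModel]], event: dict) -> Any:
        """Return the first model that successfully validates `event`."""
        error: Optional[ValidationError] = None
        for model in models:
            try:
                return model(**event)
            except ValidationError as err:
                # Keep only the most recent failure, as the Validator does.
                error = err
        if error is None:
            raise ValueError("No candidate model was provided")
        raise error
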
diff --git a/src/ralph/models/xapi/base/common.py b/src/ralph/models/xapi/base/common.py index d4e50ddc0..14c27d1e7 100644 --- a/src/ralph/models/xapi/base/common.py +++ b/src/ralph/models/xapi/base/common.py @@ -1,6 +1,6 @@ """Common for xAPI base definitions.""" -from typing import Dict +from typing import Dict, Generator, Type from langcodes import tag_is_valid from pydantic import StrictStr, validate_email @@ -11,8 +11,8 @@ class IRI(str): """Pydantic custom data type validating RFC 3987 IRIs.""" @classmethod - def __get_validators__(cls): # noqa: D105 - def validate(iri: str): + def __get_validators__(cls) -> Generator: # noqa: D105 + def validate(iri: str) -> Type["IRI"]: """Check whether the provided IRI is a valid RFC 3987 IRI.""" parse(iri, rule="IRI") return cls(iri) @@ -24,8 +24,8 @@ class LanguageTag(str): """Pydantic custom data type validating RFC 5646 Language tags.""" @classmethod - def __get_validators__(cls): # noqa: D105 - def validate(tag: str): + def __get_validators__(cls) -> Generator: # noqa: D105 + def validate(tag: str) -> Type["LanguageTag"]: """Check whether the provided tag is a valid RFC 5646 Language tag.""" if not tag_is_valid(tag): raise TypeError("Invalid RFC 5646 Language tag") @@ -41,8 +41,8 @@ class MailtoEmail(str): """Pydantic custom data type validating `mailto:email` format.""" @classmethod - def __get_validators__(cls): # noqa: D105 - def validate(mailto: str): + def __get_validators__(cls) -> Generator: # noqa: D105 + def validate(mailto: str) -> Type["MailtoEmail"]: """Check whether the provided value follows the `mailto:email` format.""" if not mailto.startswith("mailto:"): raise TypeError("Invalid `mailto:email` value") diff --git a/src/ralph/models/xapi/base/groups.py b/src/ralph/models/xapi/base/groups.py index d4f034a24..73c320058 100644 --- a/src/ralph/models/xapi/base/groups.py +++ b/src/ralph/models/xapi/base/groups.py @@ -1,13 +1,9 @@ """Base xAPI `Group` definitions.""" +import sys from abc import ABC from typing import List, Optional, Union -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from pydantic import StrictStr from ..config import BaseModelWithConfig @@ -19,6 +15,11 @@ BaseXapiOpenIdIFI, ) +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class BaseXapiGroupCommonProperties(BaseModelWithConfig, ABC): """Pydantic model for core `Group` type property. diff --git a/src/ralph/models/xapi/base/objects.py b/src/ralph/models/xapi/base/objects.py index ef76ee635..74180040f 100644 --- a/src/ralph/models/xapi/base/objects.py +++ b/src/ralph/models/xapi/base/objects.py @@ -3,14 +3,10 @@ # Nota bene: we split object definitions into `objects.py` and `unnested_objects.py` # because of the circular dependency : objects -> context -> objects. +import sys from datetime import datetime from typing import List, Optional, Union -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from ..config import BaseModelWithConfig from .agents import BaseXapiAgent from .attachments import BaseXapiAttachment @@ -20,6 +16,11 @@ from .unnested_objects import BaseXapiUnnestedObject from .verbs import BaseXapiVerb +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class BaseXapiSubStatement(BaseModelWithConfig): """Pydantic model for `SubStatement` type property. 
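The custom data types annotated in `common.py` above rely on pydantic v1's `__get_validators__` hook: the class yields one or more callables that coerce and check raw input. A minimal sketch of the same mechanism with a hypothetical `Slug` type, not part of Ralph:

    from typing import Generator

    from pydantic import BaseModel


    class Slug(str):
        """Hypothetical custom string type, mirroring IRI and LanguageTag."""

        @classmethod
        def __get_validators__(cls) -> Generator:
            def validate(value: str) -> "Slug":
                """Check that the value only holds lowercase letters and dashes."""
                if not all(char.islower() or char == "-" for char in value):
                    raise TypeError("Invalid slug value")
                return cls(value)

            yield validate


    class Page(BaseModel):
        """Minimal model exercising the custom type."""

        slug: Slug


    print(Page(slug="hello-world").slug)  # hello-world
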
diff --git a/src/ralph/models/xapi/base/results.py b/src/ralph/models/xapi/base/results.py index 3eee3ec01..bd3d49ec9 100644 --- a/src/ralph/models/xapi/base/results.py +++ b/src/ralph/models/xapi/base/results.py @@ -2,7 +2,7 @@ from datetime import timedelta from decimal import Decimal -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional, Union from pydantic import StrictBool, StrictStr, conint, root_validator @@ -27,7 +27,7 @@ class BaseXapiResultScore(BaseModelWithConfig): @root_validator @classmethod - def check_raw_min_max_relation(cls, values): + def check_raw_min_max_relation(cls, values: Any) -> Any: """Check the relationship `min < raw < max`.""" raw_value = values.get("raw", None) min_value = values.get("min", None) diff --git a/src/ralph/models/xapi/base/statements.py b/src/ralph/models/xapi/base/statements.py index d4f57a227..5282b51aa 100644 --- a/src/ralph/models/xapi/base/statements.py +++ b/src/ralph/models/xapi/base/statements.py @@ -1,7 +1,7 @@ """Base xAPI `Statement` definitions.""" from datetime import datetime -from typing import List, Optional, Union +from typing import Any, List, Optional, Union from uuid import UUID from pydantic import constr, root_validator @@ -47,7 +47,7 @@ class BaseXapiStatement(BaseModelWithConfig): @root_validator(pre=True) @classmethod - def check_abscence_of_empty_and_invalid_values(cls, values): + def check_absence_of_empty_and_invalid_values(cls, values: Any) -> Any: """Check the model for empty and invalid values. Check that the `context` field contains `platform` and `revision` fields @@ -57,7 +57,7 @@ def check_abscence_of_empty_and_invalid_values(cls, values): if value in [None, "", {}]: raise ValueError(f"{field}: invalid empty value") if isinstance(value, dict) and field != "extensions": - cls.check_abscence_of_empty_and_invalid_values(value) + cls.check_absence_of_empty_and_invalid_values(value) context = dict(values.get("context", {})) if context: diff --git a/src/ralph/models/xapi/base/unnested_objects.py b/src/ralph/models/xapi/base/unnested_objects.py index fa2129677..d0f2a0bd4 100644 --- a/src/ralph/models/xapi/base/unnested_objects.py +++ b/src/ralph/models/xapi/base/unnested_objects.py @@ -1,12 +1,7 @@ """Base xAPI `Object` definitions (1).""" -from typing import Dict, List, Optional, Union - -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - +import sys +from typing import Any, Dict, List, Optional, Union from uuid import UUID from pydantic import AnyUrl, StrictStr, constr, validator @@ -14,6 +9,11 @@ from ..config import BaseModelWithConfig from .common import IRI, LanguageMap +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class BaseXapiActivityDefinition(BaseModelWithConfig): """Pydantic model for `Activity` type `definition` property. 
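The renamed `check_absence_of_empty_and_invalid_values` validator above runs as a `pre=True` root validator and calls itself on nested dictionaries. A stripped-down sketch of that recursion (the `StatementLike` model and its fields are illustrative only, and the `extensions` exemption is omitted):

    from typing import Any

    from pydantic import BaseModel, root_validator


    class StatementLike(BaseModel):
        """Hypothetical model reproducing the recursive empty-value check."""

        actor: dict
        verb: dict

        @root_validator(pre=True)
        @classmethod
        def check_absence_of_empty_values(cls, values: Any) -> Any:
            """Reject None, empty strings and empty dicts at any nesting level."""
            for field, value in values.items():
                if value in [None, "", {}]:
                    raise ValueError(f"{field}: invalid empty value")
                if isinstance(value, dict):
                    # Recurse into nested dictionaries.
                    cls.check_absence_of_empty_values(value)
            return values


    # Passes; swapping {"id": "played"} for {} would raise a ValidationError.
    print(StatementLike(actor={"name": "actor"}, verb={"id": "played"}))
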
@@ -81,7 +81,7 @@ class BaseXapiActivityInteractionDefinition(BaseXapiActivityDefinition): @validator("choices", "scale", "source", "target", "steps") @classmethod - def check_unique_ids(cls, value): + def check_unique_ids(cls, value: Any) -> None: """Check the uniqueness of interaction components IDs.""" if len(value) != len({x.id for x in value}): raise ValueError("Duplicate InteractionComponents are not valid") diff --git a/src/ralph/models/xapi/concepts/activity_types/acrossx_profile.py b/src/ralph/models/xapi/concepts/activity_types/acrossx_profile.py index 645bb5a77..be747865f 100644 --- a/src/ralph/models/xapi/concepts/activity_types/acrossx_profile.py +++ b/src/ralph/models/xapi/concepts/activity_types/acrossx_profile.py @@ -1,15 +1,18 @@ """`AcrossX Profile` activity types definitions.""" -try: +import sys + +from ...base.unnested_objects import BaseXapiActivity, BaseXapiActivityDefinition + +if sys.version_info >= (3, 8): from typing import Literal -except ImportError: +else: from typing_extensions import Literal -from ...base.unnested_objects import BaseXapiActivity, BaseXapiActivityDefinition +# Message -# Message class MessageActivityDefinition(BaseXapiActivityDefinition): """Pydantic model for message `Activity` type `definition` property. diff --git a/src/ralph/models/xapi/concepts/activity_types/activity_streams_vocabulary.py b/src/ralph/models/xapi/concepts/activity_types/activity_streams_vocabulary.py index 4b2c5da74..10b800d8b 100644 --- a/src/ralph/models/xapi/concepts/activity_types/activity_streams_vocabulary.py +++ b/src/ralph/models/xapi/concepts/activity_types/activity_streams_vocabulary.py @@ -1,14 +1,18 @@ """`Activity streams vocabulary` activity types definitions.""" -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal +import sys from ...base.unnested_objects import BaseXapiActivity, BaseXapiActivityDefinition +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + # Page + + class PageActivityDefinition(BaseXapiActivityDefinition): """Pydantic model for page `Activity` type `definition` property. @@ -32,6 +36,8 @@ class PageActivity(BaseXapiActivity): # File + + class FileActivityDefinition(BaseXapiActivityDefinition): """Pydantic model for file `Activity` type `definition` property. 
diff --git a/src/ralph/models/xapi/concepts/activity_types/audio.py b/src/ralph/models/xapi/concepts/activity_types/audio.py index e14357855..c8ad1dd89 100644 --- a/src/ralph/models/xapi/concepts/activity_types/audio.py +++ b/src/ralph/models/xapi/concepts/activity_types/audio.py @@ -1,11 +1,14 @@ """`Audio` activity types definitions.""" -try: +import sys + +from ...base.unnested_objects import BaseXapiActivity, BaseXapiActivityDefinition + +if sys.version_info >= (3, 8): from typing import Literal -except ImportError: +else: from typing_extensions import Literal -from ...base.unnested_objects import BaseXapiActivity, BaseXapiActivityDefinition # Audio diff --git a/src/ralph/models/xapi/concepts/activity_types/scorm_profile.py b/src/ralph/models/xapi/concepts/activity_types/scorm_profile.py index 3be231060..5f0ee0abf 100644 --- a/src/ralph/models/xapi/concepts/activity_types/scorm_profile.py +++ b/src/ralph/models/xapi/concepts/activity_types/scorm_profile.py @@ -1,14 +1,18 @@ """`Scorm Profile` activity types definitions.""" -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal +import sys from ...base.unnested_objects import BaseXapiActivity, BaseXapiActivityDefinition +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + # CMI Interaction + + class CMIInteractionActivityDefinition(BaseXapiActivityDefinition): """Pydantic model for CMI Interaction `Activity` type `definition` property. @@ -33,6 +37,8 @@ class CMIInteractionActivity(BaseXapiActivity): # Profile + + class ProfileActivityDefinition(BaseXapiActivityDefinition): """Pydantic model for profile `Activity` type `definition` property. @@ -57,6 +63,8 @@ class ProfileActivity(BaseXapiActivity): # Course + + class CourseActivityDefinition(BaseXapiActivityDefinition): """Pydantic model for course `Activity` type `definition` property. @@ -81,6 +89,8 @@ class CourseActivity(BaseXapiActivity): # Module + + class ModuleActivityDefinition(BaseXapiActivityDefinition): """Pydantic model for module `Activity` type `definition` property. diff --git a/src/ralph/models/xapi/concepts/activity_types/tincan_vocabulary.py b/src/ralph/models/xapi/concepts/activity_types/tincan_vocabulary.py index 1cf8aad4e..b9b10d29d 100644 --- a/src/ralph/models/xapi/concepts/activity_types/tincan_vocabulary.py +++ b/src/ralph/models/xapi/concepts/activity_types/tincan_vocabulary.py @@ -1,14 +1,18 @@ """`Scorm Profile` activity types definitions.""" -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal +import sys from ...base.unnested_objects import BaseXapiActivity, BaseXapiActivityDefinition +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + # Document + + class DocumentActivityDefinition(BaseXapiActivityDefinition): """Pydantic model for document `Activity` type `definition` property. 
diff --git a/src/ralph/models/xapi/concepts/activity_types/video.py b/src/ralph/models/xapi/concepts/activity_types/video.py index 17aaaa09b..e7616fd5a 100644 --- a/src/ralph/models/xapi/concepts/activity_types/video.py +++ b/src/ralph/models/xapi/concepts/activity_types/video.py @@ -1,11 +1,14 @@ """`Video` activity types definitions.""" -try: +import sys + +from ...base.unnested_objects import BaseXapiActivity, BaseXapiActivityDefinition + +if sys.version_info >= (3, 8): from typing import Literal -except ImportError: +else: from typing_extensions import Literal -from ...base.unnested_objects import BaseXapiActivity, BaseXapiActivityDefinition # Video diff --git a/src/ralph/models/xapi/concepts/activity_types/virtual_classroom.py b/src/ralph/models/xapi/concepts/activity_types/virtual_classroom.py index a0eced420..cf1b842ca 100644 --- a/src/ralph/models/xapi/concepts/activity_types/virtual_classroom.py +++ b/src/ralph/models/xapi/concepts/activity_types/virtual_classroom.py @@ -1,13 +1,15 @@ """`Virtual classroom` activity types definitions.""" -try: +import sys + +from ...base.unnested_objects import BaseXapiActivity, BaseXapiActivityDefinition + +if sys.version_info >= (3, 8): from typing import Literal -except ImportError: +else: from typing_extensions import Literal -from ...base.unnested_objects import BaseXapiActivity, BaseXapiActivityDefinition - # Virtual classroom diff --git a/src/ralph/models/xapi/concepts/verbs/acrossx_profile.py b/src/ralph/models/xapi/concepts/verbs/acrossx_profile.py index 5ebcbe498..f0d4d5e5b 100644 --- a/src/ralph/models/xapi/concepts/verbs/acrossx_profile.py +++ b/src/ralph/models/xapi/concepts/verbs/acrossx_profile.py @@ -1,15 +1,16 @@ """`AcrossX Profile` verbs definitions.""" +import sys from typing import Dict, Optional -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from ...base.verbs import BaseXapiVerb from ...constants import LANG_EN_US_DISPLAY +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class PostedVerb(BaseXapiVerb): """Pydantic model for posted `verb`. diff --git a/src/ralph/models/xapi/concepts/verbs/activity_streams_vocabulary.py b/src/ralph/models/xapi/concepts/verbs/activity_streams_vocabulary.py index 52c302623..10b6cef1c 100644 --- a/src/ralph/models/xapi/concepts/verbs/activity_streams_vocabulary.py +++ b/src/ralph/models/xapi/concepts/verbs/activity_streams_vocabulary.py @@ -1,15 +1,16 @@ """`Activity streams vocabulary` verbs definitions.""" +import sys from typing import Dict, Optional -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from ...base.verbs import BaseXapiVerb from ...constants import LANG_EN_US_DISPLAY +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class JoinVerb(BaseXapiVerb): """Pydantic model for join verb. 
diff --git a/src/ralph/models/xapi/concepts/verbs/adl_vocabulary.py b/src/ralph/models/xapi/concepts/verbs/adl_vocabulary.py index 7b8505aec..da1b6804c 100644 --- a/src/ralph/models/xapi/concepts/verbs/adl_vocabulary.py +++ b/src/ralph/models/xapi/concepts/verbs/adl_vocabulary.py @@ -1,15 +1,16 @@ """`ADL Vocabulary` verbs definitions.""" +import sys from typing import Dict, Optional -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from ...base.verbs import BaseXapiVerb from ...constants import LANG_EN_US_DISPLAY +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class AskedVerb(BaseXapiVerb): """Pydantic model for asked `verb`. diff --git a/src/ralph/models/xapi/concepts/verbs/navy_common_reference_profile.py b/src/ralph/models/xapi/concepts/verbs/navy_common_reference_profile.py index a7296970d..53f027ceb 100644 --- a/src/ralph/models/xapi/concepts/verbs/navy_common_reference_profile.py +++ b/src/ralph/models/xapi/concepts/verbs/navy_common_reference_profile.py @@ -1,15 +1,16 @@ """`Navy Common Reference Profile` verbs definitions.""" +import sys from typing import Dict, Optional -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from ...base.verbs import BaseXapiVerb from ...constants import LANG_EN_US_DISPLAY +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class AccessedVerb(BaseXapiVerb): """Pydantic model for accessed `verb`. diff --git a/src/ralph/models/xapi/concepts/verbs/scorm_profile.py b/src/ralph/models/xapi/concepts/verbs/scorm_profile.py index 066edcac1..12dbd1b11 100644 --- a/src/ralph/models/xapi/concepts/verbs/scorm_profile.py +++ b/src/ralph/models/xapi/concepts/verbs/scorm_profile.py @@ -1,15 +1,16 @@ """`Scorm Profile` verbs definitions.""" +import sys from typing import Dict, Optional -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from ...base.verbs import BaseXapiVerb from ...constants import LANG_EN_US_DISPLAY +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class CompletedVerb(BaseXapiVerb): """Pydantic model for completed `verb`. diff --git a/src/ralph/models/xapi/concepts/verbs/tincan_vocabulary.py b/src/ralph/models/xapi/concepts/verbs/tincan_vocabulary.py index 39ed9c1b1..d32c25a81 100644 --- a/src/ralph/models/xapi/concepts/verbs/tincan_vocabulary.py +++ b/src/ralph/models/xapi/concepts/verbs/tincan_vocabulary.py @@ -1,16 +1,16 @@ """`TinCan Vocabulary` verbs definitions.""" +import sys from typing import Dict, Optional -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - - from ...base.verbs import BaseXapiVerb from ...constants import LANG_EN_US_DISPLAY +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class ViewedVerb(BaseXapiVerb): """Pydantic model for viewed `verb`. 
diff --git a/src/ralph/models/xapi/concepts/verbs/video.py b/src/ralph/models/xapi/concepts/verbs/video.py index be875cf4c..d2e83d0b4 100644 --- a/src/ralph/models/xapi/concepts/verbs/video.py +++ b/src/ralph/models/xapi/concepts/verbs/video.py @@ -1,15 +1,16 @@ """`Video` verbs definitions.""" +import sys from typing import Dict, Optional -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from ...base.verbs import BaseXapiVerb from ...constants import LANG_EN_US_DISPLAY +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class PlayedVerb(BaseXapiVerb): """Pydantic model for played `verb`. diff --git a/src/ralph/models/xapi/concepts/verbs/virtual_classroom.py b/src/ralph/models/xapi/concepts/verbs/virtual_classroom.py index 54d6953d1..fcd2320a6 100644 --- a/src/ralph/models/xapi/concepts/verbs/virtual_classroom.py +++ b/src/ralph/models/xapi/concepts/verbs/virtual_classroom.py @@ -1,16 +1,16 @@ """`Virtual classroom` verbs definitions.""" +import sys from typing import Dict, Optional -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - - from ...base.verbs import BaseXapiVerb from ...constants import LANG_EN_US_DISPLAY +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class MutedVerb(BaseXapiVerb): """Pydantic model for muted `verb`. diff --git a/src/ralph/models/xapi/lms/contexts.py b/src/ralph/models/xapi/lms/contexts.py index 0d49420f1..3de7a3c2b 100644 --- a/src/ralph/models/xapi/lms/contexts.py +++ b/src/ralph/models/xapi/lms/contexts.py @@ -1,14 +1,10 @@ """LMS xAPI events context fields definitions.""" +import sys from datetime import datetime from typing import List, Optional, Union from uuid import UUID -try: - from typing import Literal # pylint: disable = ungrouped-imports -except ImportError: - from typing_extensions import Literal - from pydantic import Field, NonNegativeFloat, PositiveInt, condecimal, validator from ..base.contexts import BaseXapiContext, BaseXapiContextContextActivities @@ -26,6 +22,11 @@ ) from ..config import BaseExtensionModelWithConfig +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class LMSProfileActivity(ProfileActivity): """Pydantic model for LMS `profile` activity type. 
@@ -55,7 +56,7 @@ def check_presence_of_profile_activity_category( value: Union[ LMSProfileActivity, List[Union[LMSProfileActivity, BaseXapiActivity]] ], - ): + ) -> Union[LMSProfileActivity, List[Union[LMSProfileActivity, BaseXapiActivity]]]: """Check that the category list contains a `LMSProfileActivity`.""" if isinstance(value, LMSProfileActivity): return value diff --git a/src/ralph/models/xapi/lms/objects.py b/src/ralph/models/xapi/lms/objects.py index 4a09f0bdc..7bc701bff 100644 --- a/src/ralph/models/xapi/lms/objects.py +++ b/src/ralph/models/xapi/lms/objects.py @@ -1,12 +1,8 @@ """LMS xAPI events object fields definitions.""" +import sys from typing import Optional -try: - from typing import Literal # pylint: disable = ungrouped-imports -except ImportError: - from typing_extensions import Literal - from pydantic import Field from ..concepts.activity_types.acrossx_profile import ( @@ -20,8 +16,15 @@ from ..concepts.constants.acrossx_profile import ACTIVITY_EXTENSIONS_TYPE from ..config import BaseExtensionModelWithConfig +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + # Page + + class LMSPageObjectDefinitionExtensions(BaseExtensionModelWithConfig): """Pydantic model for LMS page `object`.`definition`.`extensions` property. @@ -56,6 +59,8 @@ class LMSPageObject(WebpageActivity): # File + + class LMSFileObjectDefinitionExtensions(BaseExtensionModelWithConfig): """Pydantic model for LMS file `object`.`definition`.`extensions` property. diff --git a/src/ralph/models/xapi/video/contexts.py b/src/ralph/models/xapi/video/contexts.py index f3adfa667..4bdda8ff3 100644 --- a/src/ralph/models/xapi/video/contexts.py +++ b/src/ralph/models/xapi/video/contexts.py @@ -1,12 +1,7 @@ """Video xAPI events context fields definitions.""" +import sys from typing import List, Optional, Union - -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from uuid import UUID from pydantic import Field, NonNegativeFloat, validator @@ -29,6 +24,11 @@ ) from ..config import BaseExtensionModelWithConfig +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class VideoProfileActivity(ProfileActivity): """Pydantic model for video profile `Activity` type. @@ -58,7 +58,9 @@ def check_presence_of_profile_activity_category( value: Union[ VideoProfileActivity, List[Union[VideoProfileActivity, BaseXapiActivity]] ], - ): + ) -> Union[ + VideoProfileActivity, List[Union[VideoProfileActivity, BaseXapiActivity]] + ]: """Check that the category list contains a `VideoProfileActivity`.""" if isinstance(value, VideoProfileActivity): return value diff --git a/src/ralph/models/xapi/video/results.py b/src/ralph/models/xapi/video/results.py index 5db4fd85f..7c515ad53 100644 --- a/src/ralph/models/xapi/video/results.py +++ b/src/ralph/models/xapi/video/results.py @@ -1,13 +1,9 @@ """Video xAPI events result fields definitions.""" +import sys from datetime import timedelta from typing import Optional -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from pydantic import Field, NonNegativeFloat from ..base.results import BaseXapiResult @@ -21,6 +17,11 @@ ) from ..config import BaseExtensionModelWithConfig +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class VideoResultExtensions(BaseExtensionModelWithConfig): """Pydantic model for video `result`.`extensions` property. 
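The `check_presence_of_profile_activity_category` validators above now spell out their full `Union` return types. A self-contained sketch of the underlying membership check, with hypothetical `ProfileLike` and `ActivityLike` models standing in for the xAPI activity types:

    from typing import List, Union

    from pydantic import BaseModel, validator


    class ProfileLike(BaseModel):
        """Hypothetical profile activity; `definition` marks it as such."""

        id: str
        definition: str


    class ActivityLike(BaseModel):
        """Hypothetical plain activity."""

        id: str


    class CategoryHolder(BaseModel):
        """Hypothetical model mirroring the category check."""

        category: Union[ProfileLike, List[Union[ProfileLike, ActivityLike]]]

        @validator("category")
        @classmethod
        def check_presence_of_profile_activity(
            cls, value: Union[ProfileLike, List[Union[ProfileLike, ActivityLike]]]
        ) -> Union[ProfileLike, List[Union[ProfileLike, ActivityLike]]]:
            """Check that the category holds at least one profile activity."""
            if isinstance(value, ProfileLike):
                return value
            if any(isinstance(item, ProfileLike) for item in value):
                return value
            raise ValueError("The category field should contain a profile activity")

With pydantic v1's left-to-right `Union` coercion, a dict carrying the extra `definition` key validates as `ProfileLike`, so `CategoryHolder(category=[{"id": "a"}, {"id": "p", "definition": "profile"}])` passes while a list of plain activities is rejected.
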
diff --git a/src/ralph/models/xapi/virtual_classroom/contexts.py b/src/ralph/models/xapi/virtual_classroom/contexts.py index 03256ae74..57784faf2 100644 --- a/src/ralph/models/xapi/virtual_classroom/contexts.py +++ b/src/ralph/models/xapi/virtual_classroom/contexts.py @@ -1,13 +1,8 @@ """Virtual classroom xAPI events context fields definitions.""" +import sys from datetime import datetime from typing import List, Optional, Union - -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from uuid import UUID from pydantic import Field, validator @@ -20,6 +15,11 @@ from ..concepts.constants.tincan_vocabulary import CONTEXT_EXTENSION_PLANNED_DURATION from ..config import BaseExtensionModelWithConfig +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + class VirtualClassroomProfileActivity(ProfileActivity): """Pydantic model for virtual classroom profile `Activity` type. @@ -53,7 +53,10 @@ def check_presence_of_profile_activity_category( VirtualClassroomProfileActivity, List[Union[VirtualClassroomProfileActivity, BaseXapiActivity]], ], - ): + ) -> Union[ + VirtualClassroomProfileActivity, + List[Union[VirtualClassroomProfileActivity, BaseXapiActivity]], + ]: """Check that the category list contains a `VirtualClassroomProfileActivity`.""" if isinstance(value, VirtualClassroomProfileActivity): return value diff --git a/src/ralph/parsers.py b/src/ralph/parsers.py index 6487fce29..545b4edfc 100644 --- a/src/ralph/parsers.py +++ b/src/ralph/parsers.py @@ -3,6 +3,7 @@ import json import logging from abc import ABC, abstractmethod +from typing import BinaryIO, Generator, TextIO, Union logger = logging.getLogger(__name__) @@ -13,7 +14,7 @@ class BaseParser(ABC): name = "base" @abstractmethod - def parse(self, input_file): + def parse(self, input_file: Union[TextIO, BinaryIO]) -> Generator: """Parse GELF formatted logs (one JSON string event per row). Args: @@ -33,8 +34,8 @@ class GELFParser(BaseParser): name = "gelf" - def parse(self, input_file): - """Parse GELF formatted logs (one JSON string event per row). + def parse(self, input_file: Union[TextIO, BinaryIO]) -> Generator: + """Parses GELF formatted logs (one JSON string event per row). Args: input_file (file-like): The log file to parse. @@ -65,8 +66,8 @@ class ElasticSearchParser(BaseParser): name = "es" - def parse(self, input_file): - """Parse Elasticsearch JSON documents. + def parse(self, input_file: Union[TextIO, BinaryIO]) -> Generator: + """Parses Elasticsearch JSON documents. Args: input_file (file-like): The file containing Elasticsearch JSON documents. diff --git a/src/ralph/utils.py b/src/ralph/utils.py index 3a2e476c9..1e715f206 100644 --- a/src/ralph/utils.py +++ b/src/ralph/utils.py @@ -8,14 +8,15 @@ from functools import reduce from importlib import import_module from inspect import getmembers, isclass, iscoroutine -from typing import Any, Dict, Iterable, Iterator, List, Union +from logging import Logger, getLogger +from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Union from pydantic import BaseModel from ralph.exceptions import BackendException -def import_subclass(dotted_path, parent_class): +def import_subclass(dotted_path: str, parent_class: Any) -> Any: """Import a dotted module path. Return the class that is a subclass of `parent_class` inside this module. 
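The parser signatures typed above advertise `Union[TextIO, BinaryIO]` input and a `Generator` return value. A minimal sketch of a parser honoring that contract, with a hypothetical `parse_json_lines` function standing in for the `BaseParser.parse` implementations:

    import json
    from typing import BinaryIO, Generator, TextIO, Union


    def parse_json_lines(input_file: Union[TextIO, BinaryIO]) -> Generator:
        """Yield one parsed event per non-empty line of `input_file`."""
        for line in input_file:
            if isinstance(line, bytes):
                # Binary handles yield bytes; decode before parsing.
                line = line.decode("utf-8")
            if line.strip():
                yield json.loads(line)


    # Usage with a text handle:
    # with open("events.jsonl", encoding="utf-8") as events:
    #     for event in parse_json_lines(events):
    #         ...
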
@@ -34,7 +35,7 @@ def import_subclass(dotted_path, parent_class): # Taken from Django utilities # https://docs.djangoproject.com/en/3.1/_modules/django/utils/module_loading/#import_string -def import_string(dotted_path): +def import_string(dotted_path: str) -> Any: """Import a dotted module path. Return the attribute/class designated by the last name in the path. @@ -55,7 +56,9 @@ def import_string(dotted_path): ) from err -def get_backend_type(backend_types: List[BaseModel], backend_name: str): +def get_backend_type( + backend_types: List[BaseModel], backend_name: str +) -> Union[BaseModel, None]: """Return the backend type from a backend name.""" backend_name = backend_name.upper() for backend_type in backend_types: @@ -64,7 +67,7 @@ def get_backend_type(backend_types: List[BaseModel], backend_name: str): return None -def get_backend_class(backend_type: BaseModel, backend_name: str): +def get_backend_class(backend_type: BaseModel, backend_name: str) -> Any: """Return the backend class given the backend type and backend name.""" # Get type name from backend_type class name backend_type_name = backend_type.__class__.__name__[ @@ -93,8 +96,8 @@ def get_backend_class(backend_type: BaseModel, backend_name: str): def get_backend_instance( backend_type: BaseModel, backend_name: str, - options: Union[dict, None] = None, -): + options: Optional[Dict] = None, +) -> Any: """Return the instantiated backend given the backend type, name and options.""" backend_class = get_backend_class(backend_type, backend_name) backend_settings = getattr(backend_type, backend_name.upper()) @@ -111,20 +114,20 @@ def get_backend_instance( return backend_class(backend_settings.__class__(**options)) -def get_root_logger(): +def get_root_logger() -> Logger: """Get main Ralph logger.""" - ralph_logger = logging.getLogger("ralph") + ralph_logger = getLogger("ralph") ralph_logger.propagate = True return ralph_logger -def now(): +def now() -> str: """Return the current UTC time in ISO format.""" return datetime.datetime.now(tz=datetime.timezone.utc).isoformat() -def get_dict_value_from_path(dict_: dict, path: List[str]): +def get_dict_value_from_path(dict_: Dict, path: Sequence[str]) -> Union[Dict, None]: """Get a nested dictionary value. Args: @@ -140,7 +143,7 @@ def get_dict_value_from_path(dict_: dict, path: List[str]): return None -def set_dict_value_from_path(dict_: dict, path: List[str], value: any): +def set_dict_value_from_path(dict_: Dict, path: List[str], value: Any) -> None: """Set a nested dictionary value. Args: @@ -153,7 +156,7 @@ def set_dict_value_from_path(dict_: dict, path: List[str], value: any): dict_[path[-1]] = value -async def gather_with_limited_concurrency(num_tasks: Union[None, int], *tasks): +async def gather_with_limited_concurrency(num_tasks: Optional[int], *tasks: Any) -> Any: """Gather no more than `num_tasks` tasks at time. Args: @@ -164,7 +167,7 @@ async def gather_with_limited_concurrency(num_tasks: Union[None, int], *tasks): if num_tasks is not None: semaphore = asyncio.Semaphore(num_tasks) - async def sem_task(task): + async def sem_task(task: Any) -> Any: async with semaphore: return await task @@ -180,7 +183,7 @@ async def sem_task(task): raise exception -def statements_are_equivalent(statement_1: dict, statement_2: dict): +def statements_are_equivalent(statement_1: dict, statement_2: dict) -> bool: """Check if statements are equivalent. 
To be equivalent, they must be identical on all fields not modified on input by the

From 55189e5307a6e2d7ff890870f76f9f0f4321c1ba Mon Sep 17 00:00:00 2001
From: Wilfried BARADAT
Date: Thu, 26 Oct 2023 09:40:15 +0200
Subject: [PATCH 52/65] =?UTF-8?q?=F0=9F=94=A7(docs)=20add=20mike=20tools?=
 =?UTF-8?q?=20for=20docs=20versioning?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add the Python package `mike` that builds and deploys docs for a
specified version to the `gh-pages` branch. Also add a Makefile rule to
be able to deploy manually.
---
 .circleci/config.yml |  4 ++--
 Dockerfile           |  6 ++++++
 Makefile             | 18 ++++++++++++++----
 mkdocs.yml           |  7 +++++++
 setup.cfg            |  1 +
 5 files changed, 30 insertions(+), 6 deletions(-)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 436e9e2eb..7dc027e04 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -670,8 +670,8 @@ workflows:
             only: /^v.*/

   # Publish the documentation website to GitHub Pages.
-  # Only do it for master as tagged releases are supposed to tag their own version of the
-  # documentation in the release commit on master before they go out.
+  # Only do it for master and for tagged releases with a tag starting with
+  # the letter v.
   - deploy-docs:
       requires:
         - tray
diff --git a/Dockerfile b/Dockerfile
index 044edc092..f28098c65 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -59,6 +59,12 @@ RUN if [ "$TARGETPLATFORM" = "linux/arm64" ]; \
       rm -rf /var/lib/apt/lists/*; \
     fi;

+# Install git for documentation deployment
+RUN apt-get update && \
+    apt-get install -y \
+    git && \
+    rm -rf /var/lib/apt/lists/*;
+
 # Uninstall ralph and re-install it in editable mode along with development
 # dependencies
 RUN pip uninstall -y ralph-malph
diff --git a/Makefile b/Makefile
index 3a47f5eaf..40b670cde 100644
--- a/Makefile
+++ b/Makefile
@@ -10,7 +10,14 @@ COMPOSE = DOCKER_USER=$(DOCKER_USER) docker compose
 COMPOSE_RUN = $(COMPOSE) run --rm
 COMPOSE_TEST_RUN = $(COMPOSE_RUN)
 COMPOSE_TEST_RUN_APP = $(COMPOSE_TEST_RUN) app
-MKDOCS = $(COMPOSE_RUN) --no-deps --publish "8000:8000" app mkdocs
+COMPOSE_RUN_DOCS = $(COMPOSE_RUN) --no-deps --publish "8000:8000" app
+
+
+# -- Documentation
+DOCS_COMMITTER_NAME = "FUN MOOC Bot"
+DOCS_COMMITTER_EMAIL = funmoocbot@users.noreply.github.com
+MKDOCS = $(COMPOSE_RUN_DOCS) mkdocs
+MIKE = GIT_COMMITTER_NAME=$(DOCS_COMMITTER_NAME) GIT_COMMITTER_EMAIL=$(DOCS_COMMITTER_EMAIL) $(COMPOSE_RUN_DOCS) mike

 # -- Elasticsearch
 ES_PROTOCOL = http
@@ -140,11 +147,14 @@ docs-build: ## build documentation site
 .PHONY: docs-build

 docs-deploy: ## deploy documentation site
-	@$(MKDOCS) gh-deploy
+# Using env variables GIT_COMMITTER_NAME and GIT_COMMITTER_EMAIL will work with mike 2.0
+# Until then, you need to set the local git config user.name and user.email manually
+	@echo "Deploying docs with version dev"
+	@${MIKE} deploy dev
.PHONY: docs-deploy

-docs-serve: ## run mkdocs live server
-	@$(MKDOCS) serve --dev-addr 0.0.0.0:8000
+docs-serve: ## run mike live server
+	@$(MIKE) serve --dev-addr 0.0.0.0:8000
 .PHONY: docs-serve

 down: ## stop and remove backend containers
diff --git a/mkdocs.yml b/mkdocs.yml
index 5bb6ae812..788a2252e 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -43,3 +43,10 @@ nav:
 plugins:
   - search
   - mkdocstrings
+  - mike:
+      canonical_version: latest
+      version_selector: true
+
+extra:
+  version:
+    provider: mike
diff --git a/setup.cfg b/setup.cfg
index 0c7250166..4876289a9 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -78,6 +78,7 @@ dev =
     hypothesis==6.88.1
    isort==5.12.0
    logging-gelf==0.0.31
+    mike==1.1.2
     mkdocs==1.5.3
     mkdocs-click==0.8.1
     mkdocs-material==9.4.6

From 06d20388317a0e91d83d751e11ae37565ecafda6 Mon Sep 17 00:00:00 2001
From: Wilfried BARADAT
Date: Thu, 26 Oct 2023 09:41:00 +0200
Subject: [PATCH 53/65] =?UTF-8?q?=F0=9F=92=9A(docs)=20add=20docs=20version?=
 =?UTF-8?q?ing=20in=20CI?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

With the help of `mike`, change the `deploy-docs` job to deploy docs
with versioning. The job is executed for tags and for the master branch.

Example:
- For tags, docs will be tagged with the version 1.1 (for git tag
  v1.1.2) and alias "latest"
- For the master branch, docs will be tagged with version `dev` without
  any alias.
---
 .circleci/config.yml | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 7dc027e04..abc31673b 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -495,8 +495,17 @@ jobs:
           command: |
             git config --global user.email "funmoocbot@users.noreply.github.com"
             git config --global user.name "FUN MOOC Bot"
-            ~/.local/bin/mkdocs gh-deploy
-
+            # Deploy docs with either:
+            # - DOCS_VERSION: 1.1 (for git tag v1.1.2)
+            # - DOCS_ALIAS: latest
+            # or
+            # - DOCS_VERSION: dev (for master branch)
+            # - No DOCS_ALIAS
+            DOCS_VERSION=$([[ -z "$CIRCLE_TAG" ]] && echo $CIRCLE_BRANCH || echo ${CIRCLE_TAG} | sed 's/^v\([0-9]\.[0-9]*\)\..*/\1/')
+            DOCS_ALIAS=$([[ -z "$CIRCLE_TAG" ]] && echo "" || echo "latest")
+            echo "DOCS_VERSION: ${DOCS_VERSION}"
+            echo "DOCS_ALIAS: ${DOCS_ALIAS}"
+            ~/.local/bin/mike deploy --push --update-aliases ${DOCS_VERSION} ${DOCS_ALIAS}
 # Make a new github release
 release:
   docker:
@@ -680,7 +689,7 @@ workflows:
             branches:
               only: master
             tags:
-              only: /.*/
+              only: /^v.*/

       # Release
       - release:

From 7aa8f4568a1da6382a85441a243100f963ee215a Mon Sep 17 00:00:00 2001
From: Quitterie Lucas
Date: Fri, 27 Oct 2023 15:17:10 +0200
Subject: [PATCH 54/65] =?UTF-8?q?=F0=9F=93=9D(project)=20unify=20infinitiv?=
 =?UTF-8?q?e=20use=20in=20docstrings?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

In ralph, it has been established that all docstrings shall begin with
an infinitive verb. The remaining docstrings that did not follow this
convention are now unified.
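The convention referenced here matches the imperative first line recommended by PEP 257 and checked by pydocstyle's D401 rule, presumably the checker behind the project's `# noqa: D...` pragmas. A before/after sketch on a hypothetical helper modeled on `ralph.utils.now`:

    import datetime


    # Before: third-person docstring ("Returns ..."), rejected by the convention.
    def utcnow_before() -> str:
        """Returns the current UTC time in ISO format."""
        return datetime.datetime.now(tz=datetime.timezone.utc).isoformat()


    # After: infinitive form ("Return ..."), as unified across the code base.
    def utcnow_after() -> str:
        """Return the current UTC time in ISO format."""
        return datetime.datetime.now(tz=datetime.timezone.utc).isoformat()
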
--- gitlint/gitlint_emoji.py | 2 +- src/ralph/api/auth/basic.py | 2 +- src/ralph/api/routers/health.py | 4 +-- src/ralph/api/routers/statements.py | 2 +- src/ralph/backends/data/s3.py | 2 +- src/ralph/backends/data/swift.py | 2 +- src/ralph/backends/mixins.py | 4 +-- src/ralph/cli.py | 18 +++++----- src/ralph/conf.py | 6 ++-- src/ralph/models/converter.py | 6 ++-- src/ralph/models/selector.py | 2 +- src/ralph/models/validator.py | 6 ++-- src/ralph/parsers.py | 4 +-- tests/api/auth/test_basic.py | 2 +- tests/api/test_forwarding.py | 2 +- tests/api/test_statements.py | 2 +- tests/api/test_statements_get.py | 4 +-- tests/api/test_statements_post.py | 2 +- tests/api/test_statements_put.py | 2 +- tests/backends/http/test_async_lrs.py | 5 +-- tests/backends/lrs/test_async_mongo.py | 2 +- tests/backends/lrs/test_mongo.py | 2 +- tests/backends/stream/test_base.py | 2 +- tests/fixtures/backends.py | 33 ++++++++++++++----- tests/fixtures/hypothesis_strategies.py | 2 +- tests/helpers.py | 1 + tests/models/edx/converters/xapi/test_base.py | 6 ++-- .../models/edx/converters/xapi/test_video.py | 10 +++--- .../open_response_assessment/test_events.py | 6 ++-- .../edx/peer_instruction/test_statements.py | 9 +++-- tests/models/test_converter.py | 8 ++--- tests/models/xapi/base/test_statements.py | 16 ++++----- tests/models/xapi/test_video.py | 3 +- tests/models/xapi/test_virtual_classroom.py | 6 ++-- tests/test_cli.py | 28 ++++++++-------- tests/test_helpers.py | 2 +- tests/test_utils.py | 8 ++--- 37 files changed, 124 insertions(+), 99 deletions(-) diff --git a/gitlint/gitlint_emoji.py b/gitlint/gitlint_emoji.py index eb9040432..efad682bb 100644 --- a/gitlint/gitlint_emoji.py +++ b/gitlint/gitlint_emoji.py @@ -23,7 +23,7 @@ class GitmojiTitle(LineRule): target = CommitMessageTitle def validate(self, title, _commit): - """Validates Gitmoji title rule. + """Validate Gitmoji title rule. Downloads the list possible gitmojis from the project's GitHub repository and check that title contains one of them. diff --git a/src/ralph/api/auth/basic.py b/src/ralph/api/auth/basic.py index 027552595..f309ba371 100644 --- a/src/ralph/api/auth/basic.py +++ b/src/ralph/api/auth/basic.py @@ -114,7 +114,7 @@ def get_basic_auth_user( credentials: Optional[HTTPBasicCredentials] = Depends(security), security_scopes: SecurityScopes = SecurityScopes([]), ) -> AuthenticatedUser: - """Checks valid auth parameters. + """Check valid auth parameters. Get the basic auth parameters from the Authorization header, and checks them against our own list of hashed credentials. diff --git a/src/ralph/api/routers/health.py b/src/ralph/api/routers/health.py index 1bafb3467..c8ca015f1 100644 --- a/src/ralph/api/routers/health.py +++ b/src/ralph/api/routers/health.py @@ -25,7 +25,7 @@ async def lbheartbeat() -> None: """Load balancer heartbeat. - Returns a 200 when the server is running. + Return a 200 when the server is running. """ return @@ -34,7 +34,7 @@ async def lbheartbeat() -> None: async def heartbeat() -> JSONResponse: """Application heartbeat. - Returns a 200 if all checks are successful. + Return a 200 if all checks are successful. 
""" content = {"database": (await await_if_coroutine(BACKEND_CLIENT.status())).value} status_code = ( diff --git a/src/ralph/api/routers/statements.py b/src/ralph/api/routers/statements.py index f163e2f99..436f82c52 100644 --- a/src/ralph/api/routers/statements.py +++ b/src/ralph/api/routers/statements.py @@ -588,5 +588,5 @@ async def post( logger.info("Indexed %d statements with success", success_count) - # Returns the list of IDs in the same order they were stored + # Return the list of IDs in the same order they were stored return list(statements_dict) diff --git a/src/ralph/backends/data/s3.py b/src/ralph/backends/data/s3.py index d4c30af5c..4198b9b5f 100644 --- a/src/ralph/backends/data/s3.py +++ b/src/ralph/backends/data/s3.py @@ -164,7 +164,7 @@ def read( ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: # pylint: disable=too-many-arguments - """Read an object matching the `query` in the `target` bucket and yields it. + """Read an object matching the `query` in the `target` bucket and yield it. Args: query: (str or BaseQuery): The ID of the object to read. diff --git a/src/ralph/backends/data/swift.py b/src/ralph/backends/data/swift.py index b0b75c53d..ac3b0ef9b 100644 --- a/src/ralph/backends/data/swift.py +++ b/src/ralph/backends/data/swift.py @@ -169,7 +169,7 @@ def read( ignore_errors: bool = False, ) -> Iterator[Union[bytes, dict]]: # pylint: disable=too-many-arguments - """Read objects matching the `query` in the `target` container and yields them. + """Read objects matching the `query` in the `target` container and yield them. Args: query: (str or BaseQuery): The query to select objects to read. diff --git a/src/ralph/backends/mixins.py b/src/ralph/backends/mixins.py index ae6e53ac4..6304abe08 100644 --- a/src/ralph/backends/mixins.py +++ b/src/ralph/backends/mixins.py @@ -59,10 +59,10 @@ def append_to_history(self, event): self.write_history(self.history + [event]) def get_command_history(self, backend_name, command): - """Extracts entry ids from the history for a given command and backend_name.""" + """Extract entry ids from the history for a given command and backend_name.""" def filter_by_name_and_command(entry): - """Checks whether the history entry matches the backend_name and command.""" + """Check whether the history entry matches the backend_name and command.""" return entry.get("backend") == backend_name and ( command in [entry.get("command"), entry.get("action")] ) diff --git a/src/ralph/cli.py b/src/ralph/cli.py index 932957bb5..782944c34 100644 --- a/src/ralph/cli.py +++ b/src/ralph/cli.py @@ -56,7 +56,7 @@ class CommaSeparatedTupleParamType(click.ParamType): name = "value1,value2,value3" def convert(self, value, param, ctx): - """Splits the value by comma to return a tuple of values.""" + """Split the value by comma to return a tuple of values.""" if isinstance(value, str): return tuple(value.split(",")) @@ -76,9 +76,9 @@ class CommaSeparatedKeyValueParamType(click.ParamType): name = "key=value,key=value" def convert(self, value, param, ctx): - """Splits the values by comma and equal sign. + """Split the values by comma and equal sign. - Returns a dictionary build with key/value pairs. + Return a dictionary build with key/value pairs. """ if isinstance(value, dict): return value @@ -116,7 +116,7 @@ class ClientOptionsParamType(CommaSeparatedKeyValueParamType): """Comma separated key=value parameter type for client options.""" def __init__(self, client_options_type: Any) -> None: - """Instantiates ClientOptionsParamType for a client_options_type. 
+ """Instantiate ClientOptionsParamType for a client_options_type. Args: client_options_type (any): Pydantic model used for client options. @@ -124,9 +124,9 @@ def __init__(self, client_options_type: Any) -> None: self.client_options_type = client_options_type def convert(self, value, param, ctx): - """Splits the values by comma and equal sign. + """Split the values by comma and equal sign. - Returns an instance of client_options_type build with key/value pairs. + Return an instance of client_options_type build with key/value pairs. """ if isinstance(value, self.client_options_type): return value @@ -138,7 +138,7 @@ class HeadersParametersParamType(CommaSeparatedKeyValueParamType): """Comma separated key=value parameter type for headers parameters.""" def __init__(self, headers_parameters_type: Any) -> None: - """Instantiates HeadersParametersParamType for a headers_parameters_type. + """Instantiate HeadersParametersParamType for a headers_parameters_type. Args: headers_parameters_type (any): Pydantic model used for headers parameters. @@ -146,9 +146,9 @@ def __init__(self, headers_parameters_type: Any) -> None: self.headers_parameters_type = headers_parameters_type def convert(self, value, param, ctx): - """Splits the values by comma and equal sign. + """Split the values by comma and equal sign. - Returns an instance of headers_parameters_type build with key/value pairs. + Return an instance of headers_parameters_type build with key/value pairs. """ if isinstance(value, self.headers_parameters_type): return value diff --git a/src/ralph/conf.py b/src/ralph/conf.py index 00affd056..5b415fa7f 100644 --- a/src/ralph/conf.py +++ b/src/ralph/conf.py @@ -79,7 +79,7 @@ class Config: # pylint: disable=missing-class-docstring # noqa: D106 _class_path: str = None def get_instance(self, **init_parameters): - """Returns an instance of the settings item class using its `_class_path`.""" + """Return an instance of the settings item class using its `_class_path`.""" return import_string(self._class_path)(**init_parameters) @@ -208,12 +208,12 @@ class AuthBackends(Enum): @property def APP_DIR(self) -> Path: # pylint: disable=invalid-name - """Returns the path to Ralph's configuration directory.""" + """Return the path to Ralph's configuration directory.""" return self._CORE.APP_DIR @property def LOCALE_ENCODING(self) -> str: # pylint: disable=invalid-name - """Returns Ralph's default locale encoding.""" + """Return Ralph's default locale encoding.""" return self._CORE.LOCALE_ENCODING @root_validator(allow_reuse=True) diff --git a/src/ralph/models/converter.py b/src/ralph/models/converter.py index 7ddbe96e0..49b412365 100644 --- a/src/ralph/models/converter.py +++ b/src/ralph/models/converter.py @@ -102,12 +102,12 @@ class BaseConversionSet(ABC): __dest__: BaseModel def __init__(self) -> None: - """Initializes BaseConversionSet.""" + """Initialize BaseConversionSet.""" self._conversion_items = self._get_conversion_items() @abstractmethod def _get_conversion_items(self) -> Set[ConversionItem]: - """Returns a set of ConversionItems used for conversion.""" + """Return a set of ConversionItems used for conversion.""" def __iter__(self) -> Iterator[ConversionItem]: # noqa: D105 return iter(self._conversion_items) @@ -172,7 +172,7 @@ def __init__( module: str = "ralph.models.edx.converters.xapi", **conversion_set_kwargs: Any, ) -> None: - """Initializes the Converter.""" + """Initialize the Converter.""" self.model_selector = model_selector self.src_conversion_set = self.get_src_conversion_set( 
import_module(module), **conversion_set_kwargs diff --git a/src/ralph/models/selector.py b/src/ralph/models/selector.py index a39980d87..9375a3cfe 100644 --- a/src/ralph/models/selector.py +++ b/src/ralph/models/selector.py @@ -67,7 +67,7 @@ class ModelSelector: """ def __init__(self, module: str = "ralph.models.edx") -> None: - """Instantiates ModelSelector.""" + """Instantiate ModelSelector.""" self.model_rules = ModelSelector.build_model_rules(import_module(module)) self.decision_tree = self.get_decision_tree(self.model_rules) diff --git a/src/ralph/models/validator.py b/src/ralph/models/validator.py index fe970e46e..78bebe7d5 100644 --- a/src/ralph/models/validator.py +++ b/src/ralph/models/validator.py @@ -17,13 +17,13 @@ class Validator: """Events validator using pydantic models.""" def __init__(self, model_selector: ModelSelector): - """Initializes Validator.""" + """Initialize Validator.""" self.model_selector = model_selector def validate( self, input_file: TextIO, ignore_errors: bool, fail_on_unknown: bool ) -> Generator: - """Validates JSON event strings line by line.""" + """Validate JSON event strings line by line.""" total = 0 success = 0 for event_str in input_file: @@ -48,7 +48,7 @@ def validate( logger.info("Total events: %d, Invalid events: %d", total, total - success) def get_first_valid_model(self, event: dict) -> Any: - """Returns the first successfully instantiated model for the event. + """Return the first successfully instantiated model for the event. Raises: UnknownEventException: When the event does not match any model. diff --git a/src/ralph/parsers.py b/src/ralph/parsers.py index 545b4edfc..8552afc32 100644 --- a/src/ralph/parsers.py +++ b/src/ralph/parsers.py @@ -35,7 +35,7 @@ class GELFParser(BaseParser): name = "gelf" def parse(self, input_file: Union[TextIO, BinaryIO]) -> Generator: - """Parses GELF formatted logs (one JSON string event per row). + """Parse GELF formatted logs (one JSON string event per row). Args: input_file (file-like): The log file to parse. @@ -67,7 +67,7 @@ class ElasticSearchParser(BaseParser): name = "es" def parse(self, input_file: Union[TextIO, BinaryIO]) -> Generator: - """Parses Elasticsearch JSON documents. + """Parse Elasticsearch JSON documents. Args: input_file (file-like): The file containing Elasticsearch JSON documents. diff --git a/tests/api/auth/test_basic.py b/tests/api/auth/test_basic.py index 211f5e411..d692fb619 100644 --- a/tests/api/auth/test_basic.py +++ b/tests/api/auth/test_basic.py @@ -213,7 +213,7 @@ def test_get_whoami_wrong_password(basic_auth_test_client, fs): def test_get_whoami_correct_credentials(basic_auth_test_client, fs): """Whoami returns a 200 response when the credentials are correct. - Returns the username and associated scopes. + Return the username and associated scopes. 
""" credential_bytes = base64.b64encode("ralph:admin".encode("utf-8")) credentials = str(credential_bytes, "utf-8") diff --git a/tests/api/test_forwarding.py b/tests/api/test_forwarding.py index 2aeb019e1..5ed686195 100644 --- a/tests/api/test_forwarding.py +++ b/tests/api/test_forwarding.py @@ -200,7 +200,7 @@ def raise_for_status(): raise RequestError("Failure during request.") async def post_fail(*args, **kwargs): # pylint: disable=unused-argument - """Returns a MockUnsuccessfulResponse instance.""" + """Return a MockUnsuccessfulResponse instance.""" return MockUnsuccessfulResponse() monkeypatch.setattr("ralph.api.forwarding.AsyncClient.post", post_fail) diff --git a/tests/api/test_statements.py b/tests/api/test_statements.py index 1a468875c..d05f0d3a9 100644 --- a/tests/api/test_statements.py +++ b/tests/api/test_statements.py @@ -12,7 +12,7 @@ def test_api_statements_backend_instance_with_runserver_backend_env(monkeypatch): - """Tests that given the RALPH_RUNSERVER_BACKEND environment variable, the backend + """Test that given the RALPH_RUNSERVER_BACKEND environment variable, the backend instance `BACKEND_CLIENT` should be updated accordingly. """ # Default backend diff --git a/tests/api/test_statements_get.py b/tests/api/test_statements_get.py index 856bf3e6b..8d675fb90 100644 --- a/tests/api/test_statements_get.py +++ b/tests/api/test_statements_get.py @@ -70,7 +70,7 @@ def insert_mongo_statements(mongo_client, statements): def insert_clickhouse_statements(statements): - """Inserts a bunch of example statements into ClickHouse for testing.""" + """Insert a bunch of example statements into ClickHouse for testing.""" settings = ClickHouseDataBackend.settings_class( HOST=CLICKHOUSE_TEST_HOST, PORT=CLICKHOUSE_TEST_PORT, @@ -90,7 +90,7 @@ def insert_statements_and_monkeypatch_backend( # pylint: disable=invalid-name,unused-argument def _insert_statements_and_monkeypatch_backend(statements): - """Inserts statements once into each backend.""" + """Insert statements once into each backend.""" backend_client_class_path = "ralph.api.routers.statements.BACKEND_CLIENT" if request.param == "async_es": insert_es_statements(es, statements) diff --git a/tests/api/test_statements_post.py b/tests/api/test_statements_post.py index 7b7858042..1a040a0df 100644 --- a/tests/api/test_statements_post.py +++ b/tests/api/test_statements_post.py @@ -465,7 +465,7 @@ async def test_api_statements_post_with_failure_during_storage( # pylint: disable=invalid-name,unused-argument,too-many-arguments async def write_mock(*args, **kwargs): - """Raises an exception. Mocks the database.write method.""" + """Raise an exception. Mocks the database.write method.""" raise BackendException() backend_instance = backend() diff --git a/tests/api/test_statements_put.py b/tests/api/test_statements_put.py index 5512987b6..f65966f90 100644 --- a/tests/api/test_statements_put.py +++ b/tests/api/test_statements_put.py @@ -350,7 +350,7 @@ async def test_api_statements_put_with_failure_during_storage( # pylint: disable=invalid-name,unused-argument,too-many-arguments def write_mock(*args, **kwargs): - """Raises an exception. Mocks the database.write method.""" + """Raise an exception. 
Mocks the database.write method.""" raise BackendException() backend_instance = backend() diff --git a/tests/backends/http/test_async_lrs.py b/tests/backends/http/test_async_lrs.py index 9bddefa0c..371f0c48f 100644 --- a/tests/backends/http/test_async_lrs.py +++ b/tests/backends/http/test_async_lrs.py @@ -798,7 +798,8 @@ async def test_backends_http_lrs_write_with_invalid_parameters( @pytest.mark.anyio async def test_backends_http_lrs_write_without_target(httpx_mock: HTTPXMock, caplog): """Test the LRS backend `write` method without target parameter value writes - statements to '/xAPI/statements' default endpoint.""" + statements to '/xAPI/statements' default endpoint. + """ base_url = "http://fake-lrs.com" @@ -872,7 +873,7 @@ async def test_backends_http_lrs_write_backend_exception( httpx_mock: HTTPXMock, caplog, ): - """Test the `LRSHTTP.write` method with HTTP error""" + """Test the `LRSHTTP.write` method with HTTP error.""" base_url = "http://fake-lrs.com" target = "/xAPI/statements" diff --git a/tests/backends/lrs/test_async_mongo.py b/tests/backends/lrs/test_async_mongo.py index b3f3a9108..b0ed2d09a 100644 --- a/tests/backends/lrs/test_async_mongo.py +++ b/tests/backends/lrs/test_async_mongo.py @@ -360,7 +360,7 @@ async def mock_read(**_): async def test_backends_lrs_mongo_lrs_backend_query_statements_by_ids_two_collections( mongo, mongo_forwarding, async_mongo_lrs_backend ): - """Tests the `AsyncMongoLRSBackend.query_statements_by_ids` method, given a valid + """Test the `AsyncMongoLRSBackend.query_statements_by_ids` method, given a valid search query, should execute the query only on the specified collection and return the expected results. """ diff --git a/tests/backends/lrs/test_mongo.py b/tests/backends/lrs/test_mongo.py index 612b9c0a7..aa643c6ff 100644 --- a/tests/backends/lrs/test_mongo.py +++ b/tests/backends/lrs/test_mongo.py @@ -352,7 +352,7 @@ def mock_read(**_): def test_backends_lrs_mongo_lrs_backend_query_statements_by_ids_with_two_collections( mongo, mongo_forwarding, mongo_lrs_backend ): - """Tests the `MongoLRSBackend.query_statements_by_ids` method, given a valid search + """Test the `MongoLRSBackend.query_statements_by_ids` method, given a valid search query, should execute the query only on the specified collection and return the expected results. """ diff --git a/tests/backends/stream/test_base.py b/tests/backends/stream/test_base.py index d6c4a3cb9..2e4282f5d 100644 --- a/tests/backends/stream/test_base.py +++ b/tests/backends/stream/test_base.py @@ -12,7 +12,7 @@ class GoodStream(BaseStreamBackend): name = "good" def stream(self, target): - """Fakes the stream method.""" + """Fake the stream method.""" GoodStream() diff --git a/tests/fixtures/backends.py b/tests/fixtures/backends.py index 46df6812c..ab73d270e 100644 --- a/tests/fixtures/backends.py +++ b/tests/fixtures/backends.py @@ -174,7 +174,10 @@ def get_es_fixture(host=ES_TEST_HOSTS, index=ES_TEST_INDEX): @pytest.fixture def es(): - """Yield an Elasticsearch test client. See get_es_fixture above.""" + """Yield an Elasticsearch test client. + + See get_es_fixture above. + """ # pylint: disable=invalid-name for es_client in get_es_fixture(): yield es_client @@ -182,7 +185,10 @@ def es(): @pytest.fixture def es_forwarding(): - """Yield a second Elasticsearch test client. See get_es_fixture above.""" + """Yield a second Elasticsearch test client. + + See get_es_fixture above. 
+ """ for es_client in get_es_fixture(index=ES_TEST_FORWARDING_INDEX): yield es_client @@ -268,7 +274,7 @@ def get_mongo_fixture( database=MONGO_TEST_DATABASE, collection=MONGO_TEST_COLLECTION, ): - """Create / delete a Mongo test database + collection and yields an + """Create / delete a Mongo test database + collection and yield an instantiated client. """ client = MongoClient(connection_uri) @@ -286,7 +292,10 @@ def get_mongo_fixture( @pytest.fixture def mongo(): - """Yield a Mongo test client. See get_mongo_fixture above.""" + """Yield a Mongo test client. + + See get_mongo_fixture above. + """ for mongo_client in get_mongo_fixture(): yield mongo_client @@ -339,7 +348,10 @@ def get_mongo_lrs_backend( @pytest.fixture def mongo_forwarding(): - """Yield a second Mongo test client. See get_mongo_fixture above.""" + """Yield a second Mongo test client. + + See get_mongo_fixture above. + """ for mongo_client in get_mongo_fixture(collection=MONGO_TEST_FORWARDING_COLLECTION): yield mongo_client @@ -350,7 +362,7 @@ def get_clickhouse_fixture( database=CLICKHOUSE_TEST_DATABASE, event_table_name=CLICKHOUSE_TEST_TABLE_NAME, ): - """Create / delete a ClickHouse test database + table and yields an + """Create / delete a ClickHouse test database + table and yield an instantiated client. """ client_options = ClickHouseClientOptions( @@ -396,7 +408,10 @@ def get_clickhouse_fixture( @pytest.fixture def clickhouse(): - """Yield a ClickHouse test client. See get_clickhouse_fixture above.""" + """Yield a ClickHouse test client. + + See get_clickhouse_fixture above. + """ for clickhouse_client in get_clickhouse_fixture(): yield clickhouse_client @@ -455,7 +470,7 @@ def es_data_stream(): @pytest.fixture def settings_fs(fs, monkeypatch): - """Force Path instantiation with fake FS in Ralph's Settings.""" + """Force Path instantiation with fake FS in ralph settings.""" # pylint:disable=invalid-name,unused-argument monkeypatch.setattr( @@ -636,7 +651,7 @@ def get_swift_data_backend(): @pytest.fixture() def moto_fs(fs): - """Fix the incompatibility between moto and pyfakefs""" + """Fix the incompatibility between moto and pyfakefs.""" # pylint:disable=invalid-name for module in [boto3, botocore]: diff --git a/tests/fixtures/hypothesis_strategies.py b/tests/fixtures/hypothesis_strategies.py index b874cf6a4..fb5e9f30e 100644 --- a/tests/fixtures/hypothesis_strategies.py +++ b/tests/fixtures/hypothesis_strategies.py @@ -96,7 +96,7 @@ def custom_builds( def custom_given(*args: Union[st.SearchStrategy, BaseModel], **kwargs): - """Wrap the Hypothesis `given` function. Replaces st.builds with custom_builds.""" + """Wrap the Hypothesis `given` function. Replace st.builds with custom_builds.""" strategies = [] for arg in args: strategies.append(custom_builds(arg) if is_base_model(arg) else arg) diff --git a/tests/helpers.py b/tests/helpers.py index 3aceccad1..c361f1dfd 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -134,6 +134,7 @@ def mock_statement( timestamp: Optional[Union[str, int]] = None, ): """Generate fake statements with random or provided parameters. + Fields `actor`, `verb`, `object`, `timestamp` accept integer values which can be used to create distinct values identifiable by this integer. For each variable, using `None` will assign a default value. 
`timestamp` may be ommited diff --git a/tests/models/edx/converters/xapi/test_base.py b/tests/models/edx/converters/xapi/test_base.py index f7babb58f..ddf4975e8 100644 --- a/tests/models/edx/converters/xapi/test_base.py +++ b/tests/models/edx/converters/xapi/test_base.py @@ -18,7 +18,7 @@ class DummyBaseXapiConverter(BaseXapiConverter): """Dummy implementation of abstract BaseXapiConverter.""" def _get_conversion_items(self): # pylint: disable=no-self-use - """Returns a set of ConversionItems used for conversion.""" + """Return a set of ConversionItems used for conversion.""" return set() converter = DummyBaseXapiConverter(uuid_namespace, "https://fun-mooc.fr") @@ -27,13 +27,13 @@ def _get_conversion_items(self): # pylint: disable=no-self-use def test_models_edx_converters_xapi_base_xapi_converter_unsuccessful_initialization(): - """Tests BaseXapiConverter failed initialization.""" + """Test BaseXapiConverter failed initialization.""" class DummyBaseXapiConverter(BaseXapiConverter): """Dummy implementation of abstract BaseXapiConverter.""" def _get_conversion_items(self): # pylint: disable=no-self-use - """Returns a set of ConversionItems used for conversion.""" + """Return a set of ConversionItems used for conversion.""" return set() with pytest.raises(ConfigurationException, match="Invalid UUID namespace"): diff --git a/tests/models/edx/converters/xapi/test_video.py b/tests/models/edx/converters/xapi/test_video.py index ebafc0674..914c05a46 100644 --- a/tests/models/edx/converters/xapi/test_video.py +++ b/tests/models/edx/converters/xapi/test_video.py @@ -30,7 +30,7 @@ def test_models_edx_converters_xapi_video_ui_load_video_to_video_initialized( uuid_namespace, event, platform_url ): - """Tests that converting with `UILoadVideoToVideoInitialized` returns the + """Test that converting with `UILoadVideoToVideoInitialized` returns the expected xAPI statement. """ event.context.user_id = "1" @@ -85,7 +85,7 @@ def test_models_edx_converters_xapi_video_ui_load_video_to_video_initialized( def test_models_edx_converters_xapi_video_ui_play_video_to_video_played( uuid_namespace, event, platform_url ): - """Tests that converting with `UIPlayVideoToVideoPlayed` returns the expected + """Test that converting with `UIPlayVideoToVideoPlayed` returns the expected xAPI statement. """ event.context.user_id = "1" @@ -144,7 +144,7 @@ def test_models_edx_converters_xapi_video_ui_play_video_to_video_played( def test_models_edx_converters_xapi_video_ui_pause_video_to_video_paused( uuid_namespace, event, platform_url ): - """Tests that converting with `UIPauseVideoToVideoPaused` returns the expected xAPI + """Test that converting with `UIPauseVideoToVideoPaused` returns the expected xAPI statement. """ event.context.user_id = "1" @@ -204,7 +204,7 @@ def test_models_edx_converters_xapi_video_ui_pause_video_to_video_paused( def test_models_edx_converters_xapi_video_ui_stop_video_to_video_terminated( uuid_namespace, event, platform_url ): - """Tests that converting with `UIStopVideoToVideoTerminated` returns the expected + """Test that converting with `UIStopVideoToVideoTerminated` returns the expected xAPI statement. 
""" event.context.user_id = "1" @@ -265,7 +265,7 @@ def test_models_edx_converters_xapi_video_ui_stop_video_to_video_terminated( def test_models_edx_converters_xapi_video_ui_seek_video_to_video_seeked( uuid_namespace, event, platform_url ): - """Tests that converting with `UISeekVideoToVideoSeeked` returns the expected + """Test that converting with `UISeekVideoToVideoSeeked` returns the expected xAPI statement. """ event.context.user_id = "1" diff --git a/tests/models/edx/open_response_assessment/test_events.py b/tests/models/edx/open_response_assessment/test_events.py index 7af940c45..614f2bc98 100644 --- a/tests/models/edx/open_response_assessment/test_events.py +++ b/tests/models/edx/open_response_assessment/test_events.py @@ -19,7 +19,8 @@ @custom_given(ORAGetPeerSubmissionEventField) def test_models_edx_ora_get_peer_submission_event_field_with_valid_values(field): """Test that a valid `ORAGetPeerSubmissionEventField` does not raise a - `ValidationError`.""" + `ValidationError`. + """ assert re.match( r"^block-v1:.+\+.+\+.+type@openassessment+block@[a-f0-9]{32}$", field.item_id @@ -31,7 +32,8 @@ def test_models_edx_ora_get_submission_for_staff_grading_event_field_with_valid_ field, ): """Test that a valid `ORAGetSubmissionForStaffGradingEventField` does not raise a - `ValidationError`.""" + `ValidationError`. + """ assert re.match( r"^block-v1:.+\+.+\+.+type@openassessment+block@[a-f0-9]{32}$", field.item_id diff --git a/tests/models/edx/peer_instruction/test_statements.py b/tests/models/edx/peer_instruction/test_statements.py index 2aa760135..3c2841573 100644 --- a/tests/models/edx/peer_instruction/test_statements.py +++ b/tests/models/edx/peer_instruction/test_statements.py @@ -38,7 +38,8 @@ def test_models_edx_peer_instruction_accessed_with_valid_statement( statement, ): """Test that a `ubc.peer_instruction.accessed` statement has the expected - `event_type`.""" + `event_type`. + """ assert statement.event_type == "ubc.peer_instruction.accessed" assert statement.name == "ubc.peer_instruction.accessed" @@ -48,7 +49,8 @@ def test_models_edx_peer_instruction_original_submitted_with_valid_statement( statement, ): """Test that a `ubc.peer_instruction.original_submitted` statement has the - expected `event_type`.""" + expected `event_type`. + """ assert statement.event_type == "ubc.peer_instruction.original_submitted" assert statement.name == "ubc.peer_instruction.original_submitted" @@ -58,6 +60,7 @@ def test_models_edx_peer_instruction_revised_submitted_with_valid_statement( statement, ): """Test that a `ubc.peer_instruction.revised_submitted` statement has the - expected `event_type`.""" + expected `event_type`. 
+ """ assert statement.event_type == "ubc.peer_instruction.revised_submitted" assert statement.name == "ubc.peer_instruction.revised_submitted" diff --git a/tests/models/test_converter.py b/tests/models/test_converter.py index 74d678c31..e9b8b4ef3 100644 --- a/tests/models/test_converter.py +++ b/tests/models/test_converter.py @@ -113,7 +113,7 @@ class DummyBaseConversionSet(BaseConversionSet): __dest__ = BaseModel def _get_conversion_items(self): # pylint: disable=no-self-use - """Returns a set of ConversionItems used for conversion.""" + """Return a set of ConversionItems used for conversion.""" return set() assert not convert_dict_event(event, "", DummyBaseConversionSet()).dict() @@ -153,7 +153,7 @@ class DummyBaseConversionSet(BaseConversionSet): __dest__ = DummyBaseModel def _get_conversion_items(self): # pylint: disable=no-self-use - """Returns a set of ConversionItems used for conversion.""" + """Return a set of ConversionItems used for conversion.""" return {ConversionItem("converted", source, transformer)} converted = convert_dict_event(event, "", DummyBaseConversionSet()) @@ -172,7 +172,7 @@ class DummyBaseConversionSet(BaseConversionSet): __dest__ = BaseModel def _get_conversion_items(self): # pylint: disable=no-self-use - """Returns a set of ConversionItems used for conversion.""" + """Return a set of ConversionItems used for conversion.""" return {item} msg = "Failed to get the transformed value for field: None" @@ -190,7 +190,7 @@ class DummyBaseConversionSet(BaseConversionSet): __dest__ = BaseModel def _get_conversion_items(self): # pylint: disable=no-self-use - """Returns a set of ConversionItems used for conversion.""" + """Return a set of ConversionItems used for conversion.""" return set() msg = "Failed to parse the event, invalid JSON string" diff --git a/tests/models/xapi/base/test_statements.py b/tests/models/xapi/base/test_statements.py index 3fd1dcfc5..a88fca426 100644 --- a/tests/models/xapi/base/test_statements.py +++ b/tests/models/xapi/base/test_statements.py @@ -40,7 +40,7 @@ def test_models_xapi_base_statement_with_invalid_null_values(path, value, statem XAPI-00001 An LRS rejects with error code 400 Bad Request any Statement having a property whose value is set to "null", an empty object, or has no value, except in an "extensions" - property + property. """ statement = statement.dict(exclude_none=True) set_dict_value_from_path(statement, path.split("__"), value) @@ -64,7 +64,7 @@ def test_models_xapi_base_statement_with_valid_null_values(path, value, statemen XAPI-00001 An LRS rejects with error code 400 Bad Request any Statement having a property whose value is set to "null", an empty object, or has no value, except in an "extensions" - property + property. """ statement = statement.dict(exclude_none=True) set_dict_value_from_path(statement, path.split("__"), value) @@ -108,13 +108,13 @@ def test_models_xapi_base_statement_must_use_actor_verb_and_object(field, statem XAPI-00003 An LRS rejects with error code 400 Bad Request a Statement which does not contain an - "actor" property + "actor" property. XAPI-00004 An LRS rejects with error code 400 Bad Request a Statement which does not contain a - "verb" property + "verb" property. XAPI-00005 An LRS rejects with error code 400 Bad Request a Statement which does not contain an - "object" property + "object" property. 
""" statement = statement.dict(exclude_none=True) del statement[field] @@ -142,7 +142,7 @@ def test_models_xapi_base_statement_with_invalid_data_types(path, value, stateme XAPI-00006 An LRS rejects with error code 400 Bad Request a Statement which uses the wrong data - type + type. """ statement = statement.dict(exclude_none=True) set_dict_value_from_path(statement, path.split("__"), value) @@ -469,7 +469,7 @@ def test_models_xapi_base_statement_with_invalid_version(value, statement): """Test that the statement does not accept an invalid version field. An LRS MUST reject all Statements with a version specified that does not start with - 1.0.. + 1.0. """ statement = statement.dict(exclude_none=True) set_dict_value_from_path(statement, ["version"], value) @@ -482,7 +482,7 @@ def test_models_xapi_base_statement_with_valid_version(statement): """Test that the statement does accept a valid version field. Statements returned by an LRS MUST retain the version they are accepted with. - If they lack a version, the version MUST be set to 1.0.0 + If they lack a version, the version MUST be set to 1.0.0. """ statement = statement.dict(exclude_none=True) set_dict_value_from_path(statement, ["version"], "1.0.3") diff --git a/tests/models/xapi/test_video.py b/tests/models/xapi/test_video.py index 52f25f078..52d1b974f 100644 --- a/tests/models/xapi/test_video.py +++ b/tests/models/xapi/test_video.py @@ -115,7 +115,8 @@ def test_models_xapi_video_paused_with_valid_statement(statement): @custom_given(VideoSeeked) def test_models_xapi_video_seeked_with_valid_statement(statement): """Test that a video seeked statement has the expected `verb`.`id` and - `object`.`definition`.`type` property values.""" + `object`.`definition`.`type` property values. + """ assert statement.verb.id == "https://w3id.org/xapi/video/verbs/seeked" assert ( diff --git a/tests/models/xapi/test_virtual_classroom.py b/tests/models/xapi/test_virtual_classroom.py index 854db14d9..b3eeadeb4 100644 --- a/tests/models/xapi/test_virtual_classroom.py +++ b/tests/models/xapi/test_virtual_classroom.py @@ -78,7 +78,8 @@ def test_models_xapi_virtual_classroom_initialized_with_valid_statement(statemen @custom_given(VirtualClassroomJoined) def test_models_xapi_virtual_classroom_joined_with_valid_statement(statement): """Test that a virtual classroom joined statement has the expected - `verb`.`id` and `object`.`definition`.`type` property values.""" + `verb`.`id` and `object`.`definition`.`type` property values. + """ assert statement.verb.id == "http://activitystrea.ms/join" assert ( statement.object.definition.type @@ -89,7 +90,8 @@ def test_models_xapi_virtual_classroom_joined_with_valid_statement(statement): @custom_given(VirtualClassroomLeft) def test_models_xapi_virtual_classroom_left_with_valid_statement(statement): """Test that a virtual classroom left statement has the expected - `verb`.`id` and `object`.`definition`.`type` property values.""" + `verb`.`id` and `object`.`definition`.`type` property values. + """ assert statement.verb.id == "http://activitystrea.ms/leave" assert ( statement.object.definition.type diff --git a/tests/test_cli.py b/tests/test_cli.py index 890576027..827543378 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -198,7 +198,7 @@ def _assert_matching_basic_auth_credentials( ): """Assert that credentials match other arguments. 
- args: + Args: credentials: credentials to match against username: username that should match credentials scopes: scopes that should match credentials @@ -236,7 +236,7 @@ def _assert_matching_basic_auth_credentials( def _ifi_type_from_command(ifi_command): - """Return the ifi_type associated to the command being passed to cli""" + """Return the ifi_type associated to the command being passed to cli.""" if ifi_command not in ["-M", "-S", "-O", "-A"]: raise ValueError('The ifi_command must be one of: "-M", "-S", "-O", "-A"') @@ -246,7 +246,7 @@ def _ifi_type_from_command(ifi_command): def _ifi_value_from_command(ifi_value, ifi_type): - """Parse ifi_value returned by cli to generate dict when `ifi_type` is `account`""" + """Parse ifi_value returned by cli to generate dict when `ifi_type` is `account`.""" if ifi_type == "account": # Parse arguments from cli return {"name": ifi_value.split()[0], "homePage": ifi_value.split()[1]} @@ -436,7 +436,7 @@ def test_cli_auth_command_when_writing_auth_file_with_incorrect_auth_file(fs): def test_cli_extract_command_with_gelf_parser(gelf_logger): - """Test the extract command using the GELF parser.""" + """Test ralph extract command using the GELF parser.""" gelf_logger.info('{"username": "foo"}') runner = CliRunner() @@ -449,7 +449,7 @@ def test_cli_extract_command_with_gelf_parser(gelf_logger): def test_cli_extract_command_with_es_parser(): - """Test the extract command using the ElasticSearchParser.""" + """Test ralph extract command using the ElasticSearchParser.""" es_output = ( "\n".join( [ @@ -476,7 +476,7 @@ def test_cli_extract_command_with_es_parser(): @custom_given(UIPageClose) def test_cli_validate_command_with_edx_format(event): - """Test the validate command using the edx format.""" + """Test ralph validate command using the edx format.""" event_str = event.json() runner = CliRunner() result = runner.invoke(cli, ["validate", "-f", "edx"], input=event_str) @@ -487,7 +487,7 @@ def test_cli_validate_command_with_edx_format(event): @custom_given(UIPageClose) @pytest.mark.parametrize("valid_uuid", ["ee241f8b-174f-5bdb-bae9-c09de5fe017f"]) def test_cli_convert_command_from_edx_to_xapi_format(valid_uuid, event): - """Test the convert command from edx to xapi format.""" + """Test ralph convert command from edx to xapi format.""" event_str = event.json() runner = CliRunner() command = f"-v ERROR convert -f edx -t xapi -u {valid_uuid} -p https://fun-mooc.fr" @@ -501,8 +501,8 @@ def test_cli_convert_command_from_edx_to_xapi_format(valid_uuid, event): @pytest.mark.parametrize("invalid_uuid", ["", None, 1, {}]) def test_cli_convert_command_with_invalid_uuid(invalid_uuid): - """Test that the convert command raises an exception when the uuid namespace is - invalid. + """Test that the ralph convert command raises an exception when the uuid namespace + is invalid. 
""" runner = CliRunner() command = f"convert -f edx -t xapi -u '{invalid_uuid}' -p https://fun-mooc.fr" @@ -531,7 +531,7 @@ def test_cli_verbosity_option_should_impact_logging_behaviour(verbosity): def test_cli_read_command_with_ldp_backend(monkeypatch): - """Test the read command using the LDP backend.""" + """Test ralph read command using the LDP backend.""" archive_content = {"foo": "bar"} def mock_read(*_, **__): @@ -553,7 +553,7 @@ def mock_read(*_, **__): # pylint: disable=invalid-name # pylint: disable=unused-argument def test_cli_read_command_with_fs_backend(fs, monkeypatch): - """Test the read command using the FS backend.""" + """Test ralph read command using the FS backend.""" archive_content = {"foo": "bar"} def mock_read(*_, **__): @@ -694,7 +694,7 @@ def test_cli_read_command_with_ws_backend(events, ws): def test_cli_list_command_with_ldp_backend(monkeypatch): - """Test the list command using the LDP backend.""" + """Test ralph list command using the LDP backend.""" archive_list = [ "5d5c4c93-04a4-42c5-9860-f51fa4044aa1", "997db3eb-b9ca-485d-810f-b530a6cef7c6", @@ -766,7 +766,7 @@ def mock_list(this, target=None, details=False, new=False): # pylint: disable=invalid-name # pylint: disable=unused-argument def test_cli_list_command_with_fs_backend(fs, monkeypatch): - """Test the list command using the LDP backend.""" + """Test ralph list command using the LDP backend.""" archive_list = [ "file1", "file2", @@ -827,7 +827,7 @@ def mock_list(this, target=None, details=False, new=False): # pylint: disable=invalid-name def test_cli_write_command_with_fs_backend(fs): - """Test the write command using the FS backend.""" + """Test ralph write command using the FS backend.""" fs.create_dir(str(settings.APP_DIR)) fs.create_dir("foo") diff --git a/tests/test_helpers.py b/tests/test_helpers.py index e40b8dff9..d915ed6db 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -140,7 +140,7 @@ def test_helpers_mock_statement_no_input(): def test_helpers_mock_statement_value_input(): - """Test that mocked statement have the expected fields with value input.""" + """Test that mocked statement has the expected fields with value input.""" reference_statement = { "id": str(uuid4()), diff --git a/tests/test_utils.py b/tests/test_utils.py index 654eba3c1..39e9b1ba6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -56,16 +56,16 @@ def test_utils_get_backend_instance(monkeypatch, options, expected): """Test get_backend_instance utility should return the expected result.""" class DummyTestBackendSettings(InstantiableSettingsItem): - """Represents a dummy backend setting.""" + """Represent a dummy backend setting.""" FOO: str = "FOO" # pylint: disable=disallowed-name def get_instance(self, **init_parameters): # pylint: disable=no-self-use - """Returns the init_parameters.""" + """Return the init_parameters.""" return init_parameters class DummyTestBackend(ABC): - """Represents a dummy backend instance.""" + """Represent a dummy backend instance.""" type = "test" name = "dummy" @@ -77,7 +77,7 @@ def __call__(self, *args, **kwargs): # pylint: disable=unused-argument return {} def mock_import_module(*args, **kwargs): # pylint: disable=unused-argument - """""" + """Mock import_module.""" test_module = ModuleType(name="ralph.backends.test.dummy") test_module.DummyTestBackendSettings = DummyTestBackendSettings From 91c0523be7e108b813d36981e5c70f40f9cf3b90 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Mon, 30 Oct 2023 09:05:00 
+0000 Subject: [PATCH 55/65] =?UTF-8?q?=E2=AC=86=EF=B8=8F(project)=20upgrade=20p?= =?UTF-8?q?ython=20dependencies?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit | datasource | package | from | to | | ---------- | ---------------- | -------- | --------- | | pypi | black | 23.10.0 | 23.10.1 | | pypi | cachetools | 5.3.1 | 5.3.2 | | pypi | cryptography | 41.0.4 | 41.0.5 | | pypi | mkdocs-material | 9.4.6 | 9.4.7 | | pypi | moto | 4.2.6 | 4.2.7 | | pypi | mypy | 1.2.0 | 1.6.1 | | pypi | pytest | 7.4.2 | 7.4.3 | | pypi | types-cachetools | 5.3.0.6 | 5.3.0.7 | | pypi | types-requests | 2.31.0.6 | 2.31.0.10 | --- setup.cfg | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/setup.cfg b/setup.cfg index 4876289a9..e5c1bf673 100644 --- a/setup.cfg +++ b/setup.cfg @@ -71,8 +71,8 @@ cli = dev = anyio<4.0.1 # unpin until fastapi supports new major version of anyio bandit==1.7.5 - black==23.10.0 - cryptography==41.0.4 + black==23.10.1 + cryptography==41.0.5 factory-boy==3.3.0 flake8==6.1.0 hypothesis==6.88.1 @@ -81,14 +81,14 @@ dev = mike==1.1.2 mkdocs==1.5.3 mkdocs-click==0.8.1 - mkdocs-material==9.4.6 + mkdocs-material==9.4.7 mkdocstrings[python-legacy]==0.23.0 - moto==4.2.6 - mypy==1.2.0 + moto==4.2.7 + mypy==1.6.1 pydocstyle==6.3.0 pyfakefs==5.3.0 pylint==3.0.2 - pytest==7.4.2 + pytest==7.4.3 pytest-asyncio==0.21.1 pytest-cov==4.1.0 pytest-httpx<0.23.0 # pin as Python 3.7 and 3.8 is no longer supported from release 0.23.0 @@ -96,14 +96,14 @@ dev = responses<0.23.2 # pin until boto3 supports urllib3>=2 types-python-dateutil == 2.8.19.14 types-python-jose == 3.3.4.8 - types-requests<2.31.0.7 - types-cachetools == 5.3.0.6 + types-requests<2.31.0.11 + types-cachetools ==5.3.0.7 ci = twine==4.0.2 lrs = bcrypt==4.0.1 fastapi==0.104.0 - cachetools==5.3.1 + cachetools==5.3.2 ; We temporary pin `h11` to avoid pip downloading the latest version to solve a ; dependency conflict caused by `httpx` which requires httpcore>=0.15.0,<0.16.0 and ; `httpcore` depends on h11>=0.11,<0.13. 
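
After applying an upgrade commit like the one above, it can be worth confirming that the environment actually picked up the new pins. A minimal sketch, not part of the patch itself, using only the standard library; the expected versions are copied from the renovate table above, and the package selection is an arbitrary subset:

```python
# Verify a few of the upgraded pins against the active environment.
# Expected versions mirror the renovate table above; extend as needed.
from importlib.metadata import PackageNotFoundError, version

EXPECTED = {
    "black": "23.10.1",
    "cachetools": "5.3.2",
    "moto": "4.2.7",
    "mypy": "1.6.1",
    "pytest": "7.4.3",
}

for package, expected in EXPECTED.items():
    try:
        installed = version(package)
    except PackageNotFoundError:
        print(f"{package}: not installed")
        continue
    status = "OK" if installed == expected else f"expected {expected}"
    print(f"{package}: {installed} ({status})")
```
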
From ce0e618b71082cdf27946b591d34e64c30a9d8f9 Mon Sep 17 00:00:00 2001
From: Quitterie Lucas
Date: Mon, 30 Oct 2023 10:06:41 +0100
Subject: [PATCH 56/65] =?UTF-8?q?=F0=9F=93=9D(project)=20update=20CHANGELO?=
 =?UTF-8?q?G.md?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Changed:
- Upgrade `cachetools` to `5.3.2`
---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0c4d4410a..1ac6f7f09 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to
 
 ### Changed
 
+- Upgrade `cachetools` to `5.3.2`
 - Refactor `database` and `storage` backends under the unified `data` backend
   interface [BC]
 - Refactor LRS `query_statements` and `query_statements_by_ids` backends

From a4802c9caef30ef4e215c6d90a97a57538b250e5 Mon Sep 17 00:00:00 2001
From: lleeoo
Date: Tue, 31 Oct 2023 16:06:26 +0100
Subject: [PATCH 57/65] allow multiple auth after unification

---
 .env.dist                         |  4 +-
 CHANGELOG.md                      |  2 +
 docs/api.md                       |  5 +-
 src/ralph/api/__init__.py         |  1 +
 src/ralph/api/auth/__init__.py    | 49 ++++++++++++++++---
 src/ralph/api/auth/basic.py       | 47 +++++-------------
 src/ralph/api/auth/oidc.py        | 47 ++++++++----------
 src/ralph/conf.py                 | 38 ++++++++++++---
 tests/api/auth/test_basic.py      | 81 ++++++++++++++++---------------
 tests/api/auth/test_oidc.py       | 63 +++++++++++++++---------
 tests/api/test_statements_get.py  | 38 ++++++++++-----
 tests/api/test_statements_post.py | 24 ++++-----
 tests/api/test_statements_put.py  | 24 ++++-----
 tests/conftest.py                 |  2 -
 tests/fixtures/auth.py            | 35 -------------
 tests/helpers.py                  | 23 +++++++++
 tests/test_conf.py                |  5 ++
 17 files changed, 276 insertions(+), 212 deletions(-)

diff --git a/.env.dist b/.env.dist
index 73df46ee6..0779a7330 100644
--- a/.env.dist
+++ b/.env.dist
@@ -116,9 +116,9 @@ RALPH_BACKENDS__HTTP__LRS__STATEMENTS_ENDPOINT=/xAPI/statements
 
 # RALPH_CONVERTER_EDX_XAPI_UUID_NAMESPACE=
 
-# LRS API
+# LRS API
 
-RALPH_RUNSERVER_AUTH_BACKEND=basic
+RALPH_RUNSERVER_AUTH_BACKENDS=Basic
 RALPH_RUNSERVER_AUTH_OIDC_AUDIENCE=http://localhost:8100
 RALPH_RUNSERVER_AUTH_OIDC_ISSUER_URI=http://learning-analytics-playground_keycloak_1:8080/auth/realms/fun-mooc
 RALPH_RUNSERVER_BACKEND=es
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1ac6f7f09..1e668e162 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -51,6 +51,8 @@ have an authority field matching that of the user
 - Helm: improve volumes and ingress configurations
 - API: Add `RALPH_LRS_RESTRICT_BY_SCOPE` option enabling endpoint access
   control by user scopes
+- API: Variable `RUNSERVER_AUTH_BACKEND` becomes `RUNSERVER_AUTH_BACKENDS`, and
+  multiple authentication methods are supported simultaneously
 
 ### Fixed
 
diff --git a/docs/api.md b/docs/api.md
index 3e62c6a77..d0aff35f7 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -108,9 +108,10 @@ $ curl --user john.doe@example.com:PASSWORD http://localhost:8100/whoami
 
 Ralph LRS API server supports OpenID Connect (OIDC) on top of OAuth 2.0
 for authentication and authorization.
-To enable OIDC auth, you should set the `RALPH_RUNSERVER_AUTH_BACKEND` environment variable as follows:
+
+To enable OIDC auth, you should add `OIDC` to the comma separated list in the `RALPH_RUNSERVER_AUTH_BACKENDS` environment variable:
 ```bash
-RALPH_RUNSERVER_AUTH_BACKEND=oidc
+RALPH_RUNSERVER_AUTH_BACKENDS=Basic,OIDC
 ```
 and you should define the `RALPH_RUNSERVER_AUTH_OIDC_ISSUER_URI` environment variable with your identity provider's Issuer Identifier URI as follows:
 ```bash
diff --git a/src/ralph/api/__init__.py b/src/ralph/api/__init__.py
index 2a33df53a..1360e260c 100644
--- a/src/ralph/api/__init__.py
+++ b/src/ralph/api/__init__.py
@@ -43,6 +43,7 @@ def filter_transactions(
 )
 
 app = FastAPI()
+
 app.include_router(statements.router)
 app.include_router(health.router)
 
diff --git a/src/ralph/api/auth/__init__.py b/src/ralph/api/auth/__init__.py
index 80aa52fff..037d8163a 100644
--- a/src/ralph/api/auth/__init__.py
+++ b/src/ralph/api/auth/__init__.py
@@ -1,11 +1,48 @@
 """Main module for Ralph's LRS API authentication."""
 
+from fastapi import Depends, HTTPException, status
+from fastapi.security import SecurityScopes
+
 from ralph.api.auth.basic import get_basic_auth_user
 from ralph.api.auth.oidc import get_oidc_user
-from ralph.conf import settings
+from ralph.conf import AuthBackend, settings
+
+
+def get_authenticated_user(
+    security_scopes: SecurityScopes = SecurityScopes([]),
+    basic_auth_user=Depends(get_basic_auth_user),
+    oidc_auth_user=Depends(get_oidc_user),
+):
+    """Authenticate user with any allowed method, using credentials in the header."""
+    if AuthBackend.BASIC not in settings.RUNSERVER_AUTH_BACKENDS:
+        basic_auth_user = None
+    if AuthBackend.OIDC not in settings.RUNSERVER_AUTH_BACKENDS:
+        oidc_auth_user = None
 
+    if basic_auth_user is not None:
+        user = basic_auth_user
+        auth_method = "Basic"
+    elif oidc_auth_user is not None:
+        user = oidc_auth_user
+        auth_method = "Bearer"
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Invalid authentication credentials",
+            headers={
+                "WWW-Authenticate": ",".join(
+                    [val.value for val in settings.RUNSERVER_AUTH_BACKENDS]
+                )
+            },
+        )
 
-# At startup, select the authentication mode that will be used
-if settings.RUNSERVER_AUTH_BACKEND == settings.AuthBackends.OIDC:
-    get_authenticated_user = get_oidc_user
-else:
-    get_authenticated_user = get_basic_auth_user
+    # Restrict access by scopes
+    if settings.LRS_RESTRICT_BY_SCOPES:
+        for requested_scope in security_scopes.scopes:
+            if not user.scopes.is_authorized(requested_scope):
+                raise HTTPException(
+                    status_code=status.HTTP_401_UNAUTHORIZED,
+                    detail=f'Access not authorized to scope: "{requested_scope}".',
+                    headers={"WWW-Authenticate": auth_method},
+                )
+    return user
diff --git a/src/ralph/api/auth/basic.py b/src/ralph/api/auth/basic.py
index f309ba371..2bc4a82ea 100644
--- a/src/ralph/api/auth/basic.py
+++ b/src/ralph/api/auth/basic.py
@@ -8,8 +8,8 @@
 
 import bcrypt
 from cachetools import TTLCache, cached
-from fastapi import Depends, HTTPException, status
-from fastapi.security import HTTPBasic, HTTPBasicCredentials, SecurityScopes
+from fastapi import Depends
+from fastapi.security import HTTPBasic, HTTPBasicCredentials
 from pydantic import BaseModel, root_validator
 from starlette.authentication import AuthenticationError
 
@@ -102,17 +102,15 @@ def get_stored_credentials(auth_file: Path) -> ServerUsersCredentials:
 @cached(
     TTLCache(maxsize=settings.AUTH_CACHE_MAX_SIZE, ttl=settings.AUTH_CACHE_TTL),
     lock=Lock(),
-    key=lambda credentials, security_scopes: (
+    key=lambda credentials: (
         credentials.username,
         credentials.password,
-        security_scopes.scope_str,
     )
     if credentials is not None
     else None,
 )
 def get_basic_auth_user(
     credentials: Optional[HTTPBasicCredentials] = Depends(security),
-    security_scopes: SecurityScopes = SecurityScopes([]),
 ) -> AuthenticatedUser:
     """Check valid auth parameters.
@@ -121,18 +119,13 @@ def get_basic_auth_user(
 
     Args:
         credentials (iterator): auth parameters from the Authorization header
-        security_scopes: scopes requested for access
 
     Raises:
         HTTPException
     """
     if not credentials:
-        logger.error("The basic authentication mode requires a Basic Auth header")
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail="Could not validate credentials",
-            headers={"WWW-Authenticate": "Basic"},
-        )
+        logger.info("No credentials were found for Basic auth")
+        return None
 
     try:
         user = next(
@@ -145,15 +138,14 @@ def get_basic_auth_user(
     except StopIteration:
         # next() gets the first item in the enumerable; if there is none, it
         # raises a StopIteration error as it is out of bounds.
-        logger.warning(
+        logger.info(
             "User %s tried to authenticate but this account does not exists",
             credentials.username,
         )
         hashed_password = None
-    except AuthenticationError as exc:
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN, detail=str(exc)
-        ) from exc
+    except AuthenticationError:
+        logger.info("Error while authenticating using Basic auth")
+        return None
 
     # Check that a password was passed
     if not hashed_password:
@@ -162,11 +154,7 @@ def get_basic_auth_user(
         bcrypt.checkpw(
             credentials.password.encode(settings.LOCALE_ENCODING), UNUSED_PASSWORD
         )
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail="Invalid authentication credentials",
-            headers={"WWW-Authenticate": "Basic"},
-        )
+        return None
 
     # Check password validity
     if not bcrypt.checkpw(
@@ -177,21 +165,8 @@ def get_basic_auth_user(
             "Authentication failed for user %s",
             credentials.username,
         )
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail="Invalid authentication credentials",
-            headers={"WWW-Authenticate": "Basic"},
-        )
+        return None
 
     user = AuthenticatedUser(scopes=user.scopes, agent=dict(user.agent))
 
-    # Restrict access by scopes
-    if settings.LRS_RESTRICT_BY_SCOPES:
-        for requested_scope in security_scopes.scopes:
-            if not user.scopes.is_authorized(requested_scope):
-                raise HTTPException(
-                    status_code=status.HTTP_401_UNAUTHORIZED,
-                    detail=f'Access not authorized to scope: "{requested_scope}".',
-                    headers={"WWW-Authenticate": "Basic"},
-                )
+
     return user
diff --git a/src/ralph/api/auth/oidc.py b/src/ralph/api/auth/oidc.py
index 2a2d107b0..f11cef628 100644
--- a/src/ralph/api/auth/oidc.py
+++ b/src/ralph/api/auth/oidc.py
@@ -6,7 +6,7 @@
 
 import requests
 from fastapi import Depends, HTTPException, status
-from fastapi.security import OpenIdConnect, SecurityScopes
+from fastapi.security import HTTPBearer, OpenIdConnect
 from jose import ExpiredSignatureError, JWTError, jwt
 from jose.exceptions import JWTClaimsError
 from pydantic import AnyUrl, BaseModel, Extra
@@ -88,14 +88,13 @@ def get_public_keys(jwks_uri: AnyUrl) -> Dict:
         )
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail="Could not validate credentials",
             headers={"WWW-Authenticate": "Bearer"},
         ) from exc
 
 
 def get_oidc_user(
-    auth_header: Annotated[Optional[str], Depends(oauth2_scheme)],
-    security_scopes: SecurityScopes = SecurityScopes([]),
+    auth_header: Annotated[Optional[HTTPBearer], Depends(oauth2_scheme)],
 ) -> AuthenticatedUser:
     """Decode and validate OpenId Connect ID token against issuer in config.
@@ -109,17 +108,25 @@ def get_oidc_user(
     Raises:
         HTTPException
     """
+
     if auth_header is None or "Bearer" not in auth_header:
-        logger.error("The OpenID Connect authentication mode requires a Bearer token")
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail="Could not validate credentials",
-            headers={"WWW-Authenticate": "Bearer"},
+        logger.info(
+            "Not using OIDC auth. The OpenID Connect authentication mode requires a "
+            "Bearer token"
         )
+        return None
 
     id_token = auth_header.split(" ")[-1]
-    provider_config = discover_provider(settings.RUNSERVER_AUTH_OIDC_ISSUER_URI)
-    key = get_public_keys(provider_config["jwks_uri"])
+
+    try:
+        provider_config = discover_provider(settings.RUNSERVER_AUTH_OIDC_ISSUER_URI)
+    except HTTPException:
+        return None
+
+    try:
+        key = get_public_keys(provider_config["jwks_uri"])
+    except HTTPException:
+        return None
+
     algorithms = provider_config["id_token_signing_alg_values_supported"]
     audience = settings.RUNSERVER_AUTH_OIDC_AUDIENCE
     options = {
@@ -137,11 +144,7 @@ def get_oidc_user(
         )
     except (ExpiredSignatureError, JWTError, JWTClaimsError) as exc:
         logger.error("Unable to decode the ID token: %s", exc)
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail="Could not validate credentials",
-            headers={"WWW-Authenticate": "Bearer"},
-        ) from exc
+        return None
 
     id_token = IDToken.parse_obj(decoded_token)
 
@@ -150,14 +153,4 @@ def get_oidc_user(
         scopes=UserScopes(id_token.scope.split(" ") if id_token.scope else []),
     )
 
-    # Restrict access by scopes
-    if settings.LRS_RESTRICT_BY_SCOPES:
-        for requested_scope in security_scopes.scopes:
-            if not user.scopes.is_authorized(requested_scope):
-                raise HTTPException(
-                    status_code=status.HTTP_401_UNAUTHORIZED,
-                    detail=f'Access not authorized to scope: "{requested_scope}".',
-                    headers={"WWW-Authenticate": "Basic"},
-                )
-
     return user
diff --git a/src/ralph/conf.py b/src/ralph/conf.py
index 5b415fa7f..923407e93 100644
--- a/src/ralph/conf.py
+++ b/src/ralph/conf.py
@@ -133,6 +133,36 @@ class Config:  # pylint: disable=missing-class-docstring # noqa: D106
     timeout: float
 
 
+class AuthBackend(Enum):
+    """Model for valid authentication methods."""
+
+    BASIC = "Basic"
+    OIDC = "OIDC"
+
+
+class AuthBackends(str):
+    """Model representing a list of authentication backends."""
+
+    @classmethod
+    def __get_validators__(cls):  # noqa: D105
+        """Check whether the value is a comma separated string or a tuple
+        representing an AuthBackend."""
+
+        def validate(
+            value: Union[AuthBackend, Tuple[AuthBackend], List[AuthBackend]]
+        ) -> Tuple[AuthBackend]:
+            """Check whether the value is a comma separated string or a list/tuple."""
+            if isinstance(value, (tuple, list)):
+                return tuple(AuthBackend(val) for val in value)
+
+            if isinstance(value, str):
+                return tuple(AuthBackend(val) for val in value.split(","))
+
+            raise TypeError("Invalid comma separated list")
+
+        yield validate
+
+
 class Settings(BaseSettings):
     """Pydantic model for Ralph's global environment & configuration settings."""
 
@@ -142,12 +172,6 @@ class Settings(BaseSettings):
     class Config(BaseSettingsConfig):
         """Pydantic Configuration."""
 
         env_file = ".env"
         env_file_encoding = core_settings.LOCALE_ENCODING
- class AuthBackends(Enum): - """Enum of the authentication backends.""" - - BASIC = "basic" - OIDC = "oidc" - _CORE: CoreSettings = core_settings AUTH_FILE: Path = _CORE.APP_DIR / "auth.json" AUTH_CACHE_MAX_SIZE = 100 @@ -188,7 +212,7 @@ class AuthBackends(Enum): }, } PARSERS: ParserSettings = ParserSettings() - RUNSERVER_AUTH_BACKEND: AuthBackends = AuthBackends.BASIC + RUNSERVER_AUTH_BACKENDS: AuthBackends = AuthBackends([AuthBackend.BASIC]) RUNSERVER_AUTH_OIDC_AUDIENCE: str = None RUNSERVER_AUTH_OIDC_ISSUER_URI: AnyHttpUrl = None RUNSERVER_BACKEND: Literal[ diff --git a/tests/api/auth/test_basic.py b/tests/api/auth/test_basic.py index d692fb619..dda4d8bce 100644 --- a/tests/api/auth/test_basic.py +++ b/tests/api/auth/test_basic.py @@ -5,9 +5,10 @@ import bcrypt import pytest -from fastapi.exceptions import HTTPException -from fastapi.security import HTTPBasicCredentials, SecurityScopes +from fastapi.security import HTTPBasicCredentials +from fastapi.testclient import TestClient +from ralph.api import app from ralph.api.auth.basic import ( ServerUsersCredentials, UserCredentials, @@ -15,7 +16,9 @@ get_stored_credentials, ) from ralph.api.auth.user import AuthenticatedUser, UserScopes -from ralph.conf import Settings, settings +from ralph.conf import AuthBackend, Settings, settings + +from tests.helpers import configure_env_for_mock_oidc_auth STORED_CREDENTIALS = json.dumps( [ @@ -29,6 +32,9 @@ ) +client = TestClient(app) + + def test_api_auth_basic_model_serveruserscredentials(): """Test api.auth ServerUsersCredentials model.""" @@ -103,12 +109,10 @@ def test_api_auth_basic_caching_credentials(fs): credentials = HTTPBasicCredentials(username="ralph", password="admin") # Call function as in a first request with these credentials - get_basic_auth_user( - security_scopes=SecurityScopes(["profile/read"]), credentials=credentials - ) + get_basic_auth_user(credentials=credentials) assert get_basic_auth_user.cache.popitem() == ( - ("ralph", "admin", "profile/read"), + ("ralph", "admin"), AuthenticatedUser( agent={"mbox": "mailto:ralph@example.com"}, scopes=UserScopes(["statements/read/mine", "statements/write"]), @@ -126,8 +130,7 @@ def test_api_auth_basic_with_wrong_password(fs): credentials = HTTPBasicCredentials(username="ralph", password="wrong_password") # Call function as in a first request with these credentials - with pytest.raises(HTTPException): - get_basic_auth_user(credentials, SecurityScopes(["all"])) + assert get_basic_auth_user(credentials) is None def test_api_auth_basic_no_credential_file_found(fs, monkeypatch): @@ -139,40 +142,39 @@ def test_api_auth_basic_no_credential_file_found(fs, monkeypatch): credentials = HTTPBasicCredentials(username="ralph", password="admin") - with pytest.raises(HTTPException): - get_basic_auth_user(credentials, SecurityScopes(["all"])) + assert get_basic_auth_user(credentials) is None -def test_get_whoami_no_credentials(basic_auth_test_client): +def test_get_whoami_no_credentials(): """Whoami route returns a 401 error when no credentials are sent.""" - response = basic_auth_test_client.get("/whoami") + response = client.get("/whoami") assert response.status_code == 401 - assert response.headers["www-authenticate"] == "Basic" - assert response.json() == {"detail": "Could not validate credentials"} + assert response.headers["www-authenticate"] == ",".join( + [val.value for val in settings.RUNSERVER_AUTH_BACKENDS] + ) + assert response.json() == {"detail": "Invalid authentication credentials"} -def 
test_get_whoami_credentials_wrong_scheme(basic_auth_test_client): +def test_get_whoami_credentials_wrong_scheme(): """Whoami route returns a 401 error when wrong scheme is used for authorization.""" - response = basic_auth_test_client.get( - "/whoami", headers={"Authorization": "Bearer sometoken"} - ) + response = client.get("/whoami", headers={"Authorization": "Bearer sometoken"}) assert response.status_code == 401 - assert response.headers["www-authenticate"] == "Basic" - assert response.json() == {"detail": "Could not validate credentials"} + assert response.headers["www-authenticate"] == ",".join( + [val.value for val in settings.RUNSERVER_AUTH_BACKENDS] + ) + assert response.json() == {"detail": "Invalid authentication credentials"} -def test_get_whoami_credentials_encoding_error(basic_auth_test_client): +def test_get_whoami_credentials_encoding_error(): """Whoami route returns a 401 error when the credentials encoding is broken.""" - response = basic_auth_test_client.get( - "/whoami", headers={"Authorization": "Basic not-base64"} - ) + response = client.get("/whoami", headers={"Authorization": "Basic not-base64"}) assert response.status_code == 401 assert response.headers["www-authenticate"] == "Basic" assert response.json() == {"detail": "Invalid authentication credentials"} # pylint: disable=invalid-name -def test_get_whoami_username_not_found(basic_auth_test_client, fs): +def test_get_whoami_username_not_found(fs): """Whoami route returns a 401 error when the username cannot be found.""" credential_bytes = base64.b64encode("john:admin".encode("utf-8")) credentials = str(credential_bytes, "utf-8") @@ -181,17 +183,17 @@ def test_get_whoami_username_not_found(basic_auth_test_client, fs): auth_file_path = settings.APP_DIR / "auth.json" fs.create_file(auth_file_path, contents=STORED_CREDENTIALS) - response = basic_auth_test_client.get( - "/whoami", headers={"Authorization": f"Basic {credentials}"} - ) + response = client.get("/whoami", headers={"Authorization": f"Basic {credentials}"}) assert response.status_code == 401 - assert response.headers["www-authenticate"] == "Basic" + assert response.headers["www-authenticate"] == ",".join( + [val.value for val in settings.RUNSERVER_AUTH_BACKENDS] + ) assert response.json() == {"detail": "Invalid authentication credentials"} # pylint: disable=invalid-name -def test_get_whoami_wrong_password(basic_auth_test_client, fs): +def test_get_whoami_wrong_password(fs): """Whoami route returns a 401 error when the password is wrong.""" credential_bytes = base64.b64encode("john:not-admin".encode("utf-8")) credentials = str(credential_bytes, "utf-8") @@ -200,21 +202,24 @@ def test_get_whoami_wrong_password(basic_auth_test_client, fs): fs.create_file(auth_file_path, contents=STORED_CREDENTIALS) get_basic_auth_user.cache_clear() - response = basic_auth_test_client.get( - "/whoami", headers={"Authorization": f"Basic {credentials}"} - ) + response = client.get("/whoami", headers={"Authorization": f"Basic {credentials}"}) assert response.status_code == 401 - assert response.headers["www-authenticate"] == "Basic" assert response.json() == {"detail": "Invalid authentication credentials"} # pylint: disable=invalid-name -def test_get_whoami_correct_credentials(basic_auth_test_client, fs): +@pytest.mark.parametrize( + "runserver_auth_backends", + [[AuthBackend.BASIC, AuthBackend.OIDC], [AuthBackend.BASIC]], +) +def test_get_whoami_correct_credentials(fs, monkeypatch, runserver_auth_backends): """Whoami returns a 200 response when the credentials are correct. 
Return the username and associated scopes. """ + configure_env_for_mock_oidc_auth(monkeypatch, runserver_auth_backends) + credential_bytes = base64.b64encode("ralph:admin".encode("utf-8")) credentials = str(credential_bytes, "utf-8") @@ -222,9 +227,7 @@ def test_get_whoami_correct_credentials(basic_auth_test_client, fs): fs.create_file(auth_file_path, contents=STORED_CREDENTIALS) get_basic_auth_user.cache_clear() - response = basic_auth_test_client.get( - "/whoami", headers={"Authorization": f"Basic {credentials}"} - ) + response = client.get("/whoami", headers={"Authorization": f"Basic {credentials}"}) assert response.status_code == 200 diff --git a/tests/api/auth/test_oidc.py b/tests/api/auth/test_oidc.py index a0b621f01..553737c94 100644 --- a/tests/api/auth/test_oidc.py +++ b/tests/api/auth/test_oidc.py @@ -1,23 +1,36 @@ """Tests for the api.auth.oidc module.""" - +import pytest import responses +from fastapi.testclient import TestClient from pydantic import parse_obj_as +from ralph.api import app from ralph.api.auth.oidc import discover_provider, get_public_keys +from ralph.conf import AuthBackend from ralph.models.xapi.base.agents import BaseXapiAgentWithOpenId from tests.fixtures.auth import ISSUER_URI, mock_oidc_user +from tests.helpers import configure_env_for_mock_oidc_auth + +client = TestClient(app) +@pytest.mark.parametrize( + "runserver_auth_backends", + [[AuthBackend.BASIC, AuthBackend.OIDC], [AuthBackend.OIDC]], +) @responses.activate -def test_api_auth_oidc_valid(oidc_auth_test_client): +def test_api_auth_oidc_valid(monkeypatch, runserver_auth_backends): """Test a valid OpenId Connect authentication.""" + configure_env_for_mock_oidc_auth(monkeypatch, runserver_auth_backends) + oidc_token = mock_oidc_user(scopes=["all", "profile/read"]) - response = oidc_auth_test_client.get( + headers = {"Authorization": f"Bearer {oidc_token}"} + response = client.get( "/whoami", - headers={"Authorization": f"Bearer {oidc_token}"}, + headers=headers, ) assert response.status_code == 200 assert len(response.json().keys()) == 2 @@ -27,27 +40,29 @@ def test_api_auth_oidc_valid(oidc_auth_test_client): @responses.activate -def test_api_auth_invalid_token( - oidc_auth_test_client, mock_discovery_response, mock_oidc_jwks -): +def test_api_auth_invalid_token(monkeypatch, mock_discovery_response, mock_oidc_jwks): """Test API with an invalid audience.""" + configure_env_for_mock_oidc_auth(monkeypatch) + mock_oidc_user() - response = oidc_auth_test_client.get( + response = client.get( "/whoami", headers={"Authorization": "Bearer wrong_token"}, ) assert response.status_code == 401 - assert response.headers["www-authenticate"] == "Bearer" - assert response.json() == {"detail": "Could not validate credentials"} + # assert response.headers["www-authenticate"] == "Bearer" + assert response.json() == {"detail": "Invalid authentication credentials"} @responses.activate -def test_api_auth_invalid_discovery(oidc_auth_test_client, encoded_token): +def test_api_auth_invalid_discovery(monkeypatch, encoded_token): """Test API with an invalid provider discovery.""" + configure_env_for_mock_oidc_auth(monkeypatch) + # Clear LRU cache discover_provider.cache_clear() get_public_keys.cache_clear() @@ -60,22 +75,24 @@ def test_api_auth_invalid_discovery(oidc_auth_test_client, encoded_token): status=500, ) - response = oidc_auth_test_client.get( + response = client.get( "/whoami", headers={"Authorization": f"Bearer {encoded_token}"}, ) assert response.status_code == 401 - assert response.headers["www-authenticate"] == 
"Bearer" - assert response.json() == {"detail": "Could not validate credentials"} + # assert response.headers["www-authenticate"] == "Bearer" + assert response.json() == {"detail": "Invalid authentication credentials"} @responses.activate def test_api_auth_invalid_keys( - oidc_auth_test_client, mock_discovery_response, mock_oidc_jwks, encoded_token + monkeypatch, mock_discovery_response, mock_oidc_jwks, encoded_token ): """Test API with an invalid request for keys.""" + configure_env_for_mock_oidc_auth(monkeypatch) + # Clear LRU cache discover_provider.cache_clear() get_public_keys.cache_clear() @@ -96,27 +113,29 @@ def test_api_auth_invalid_keys( status=500, ) - response = oidc_auth_test_client.get( + response = client.get( "/whoami", headers={"Authorization": f"Bearer {encoded_token}"}, ) assert response.status_code == 401 - assert response.headers["www-authenticate"] == "Bearer" - assert response.json() == {"detail": "Could not validate credentials"} + # assert response.headers["www-authenticate"] == "Bearer" + assert response.json() == {"detail": "Invalid authentication credentials"} @responses.activate -def test_api_auth_invalid_header(oidc_auth_test_client): +def test_api_auth_invalid_header(monkeypatch): """Test API with an invalid request header.""" + configure_env_for_mock_oidc_auth(monkeypatch) + oidc_token = mock_oidc_user() - response = oidc_auth_test_client.get( + response = client.get( "/whoami", headers={"Authorization": f"Wrong header {oidc_token}"}, ) assert response.status_code == 401 - assert response.headers["www-authenticate"] == "Bearer" - assert response.json() == {"detail": "Could not validate credentials"} + # assert response.headers["www-authenticate"] == "Bearer" + assert response.json() == {"detail": "Invalid authentication credentials"} diff --git a/tests/api/test_statements_get.py b/tests/api/test_statements_get.py index 8d675fb90..938543c46 100644 --- a/tests/api/test_statements_get.py +++ b/tests/api/test_statements_get.py @@ -9,12 +9,11 @@ from elasticsearch.helpers import bulk from ralph.api import app -from ralph.api.auth import get_authenticated_user from ralph.api.auth.basic import get_basic_auth_user -from ralph.api.auth.oidc import get_oidc_user from ralph.backends.data.base import BaseOperationType from ralph.backends.data.clickhouse import ClickHouseDataBackend from ralph.backends.data.mongo import MongoDataBackend +from ralph.conf import AuthBackend from ralph.exceptions import BackendException from tests.fixtures.backends import ( @@ -32,7 +31,7 @@ get_mongo_test_backend, ) -from ..fixtures.auth import mock_basic_auth_user, mock_oidc_user +from ..fixtures.auth import AUDIENCE, ISSUER_URI, mock_basic_auth_user, mock_oidc_user from ..helpers import mock_activity, mock_agent @@ -807,17 +806,38 @@ async def test_api_statements_get_scopes( monkeypatch.setattr( "ralph.api.routers.statements.settings.LRS_RESTRICT_BY_SCOPES", True ) - monkeypatch.setattr("ralph.api.auth.basic.settings.LRS_RESTRICT_BY_SCOPES", True) + monkeypatch.setattr( + f"ralph.api.auth.{auth_method}.settings.LRS_RESTRICT_BY_SCOPES", True + ) + + monkeypatch.setattr( + "ralph.api.routers.statements.settings.LRS_RESTRICT_BY_AUTHORITY", True + ) + monkeypatch.setattr( + f"ralph.api.auth.{auth_method}.settings.LRS_RESTRICT_BY_AUTHORITY", True + ) if auth_method == "basic": agent = mock_agent("mbox", 1) credentials = mock_basic_auth_user(fs, scopes=scopes, agent=agent) headers = {"Authorization": f"Basic {credentials}"} - app.dependency_overrides[get_authenticated_user] = 
get_basic_auth_user
        get_basic_auth_user.cache_clear()
 
     elif auth_method == "oidc":
+        monkeypatch.setenv("RUNSERVER_AUTH_BACKENDS", AuthBackend.OIDC.value)
+        monkeypatch.setattr(
+            "ralph.api.auth.settings.RUNSERVER_AUTH_BACKENDS", [AuthBackend.OIDC]
+        )
+        monkeypatch.setattr(
+            "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_ISSUER_URI",
+            ISSUER_URI,
+        )
+        monkeypatch.setattr(
+            "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_AUDIENCE",
+            AUDIENCE,
+        )
+
         sub = "123|oidc"
         iss = "https://iss.example.com"
         agent = {"openid": f"{iss}/{sub}"}
@@ -833,8 +853,6 @@ async def test_api_statements_get_scopes(
             "http://clientHost:8100",
         )
 
-        app.dependency_overrides[get_authenticated_user] = get_oidc_user
-
     statements = [
         {
             "id": "be67b160-d958-4f51-b8b8-1892002dbac6",
@@ -859,7 +877,6 @@ async def test_api_statements_get_scopes(
         "/xAPI/statements/",
         headers=headers,
     )
-
     if is_authorized:
         assert response.status_code == 200
         assert response.json() == {"statements": [statements[1], statements[0]]}
@@ -869,8 +886,6 @@ async def test_api_statements_get_scopes(
             "detail": 'Access not authorized to scope: "statements/read/mine".'
         }
 
-    app.dependency_overrides.pop(get_authenticated_user, None)
-
 
 @pytest.mark.anyio
 @pytest.mark.parametrize(
@@ -898,6 +913,7 @@ async def test_api_statements_get_scopes_with_authority(
         "ralph.api.routers.statements.settings.LRS_RESTRICT_BY_SCOPES", True
     )
     monkeypatch.setattr("ralph.api.auth.basic.settings.LRS_RESTRICT_BY_SCOPES", True)
+    monkeypatch.setattr("ralph.api.auth.oidc.settings.LRS_RESTRICT_BY_SCOPES", True)
 
     agent = mock_agent("mbox", 1)
     agent_2 = mock_agent("mbox", 2)
@@ -939,5 +955,3 @@ async def test_api_statements_get_scopes_with_authority(
         assert response.json() == {"statements": [statements[1], statements[0]]}
     else:
         assert response.json() == {"statements": [statements[0]]}
-
-    app.dependency_overrides.pop(get_authenticated_user, None)
diff --git a/tests/api/test_statements_post.py b/tests/api/test_statements_post.py
index 1a040a0df..58fc2f79a 100644
--- a/tests/api/test_statements_post.py
+++ b/tests/api/test_statements_post.py
@@ -8,15 +8,18 @@
 from httpx import AsyncClient
 
 from ralph.api import app
-from ralph.api.auth import get_authenticated_user
 from ralph.api.auth.basic import get_basic_auth_user
-from ralph.api.auth.oidc import get_oidc_user
 from ralph.backends.lrs.es import ESLRSBackend
 from ralph.backends.lrs.mongo import MongoLRSBackend
-from ralph.conf import XapiForwardingConfigurationSettings
+from ralph.conf import AuthBackend, XapiForwardingConfigurationSettings
 from ralph.exceptions import BackendException
 
-from tests.fixtures.auth import mock_basic_auth_user, mock_oidc_user
+from tests.fixtures.auth import (
+    AUDIENCE,
+    ISSUER_URI,
+    mock_basic_auth_user,
+    mock_oidc_user,
+)
 from tests.fixtures.backends import (
     ES_TEST_FORWARDING_INDEX,
     ES_TEST_HOSTS,
@@ -722,7 +725,6 @@ async def test_api_statements_post_scopes(
         credentials = mock_basic_auth_user(fs, scopes=scopes, agent=agent)
         headers = {"Authorization": f"Basic {credentials}"}
 
-        app.dependency_overrides[get_authenticated_user] = get_basic_auth_user
         get_basic_auth_user.cache_clear()
 
     elif auth_method == "oidc":
@@ -731,17 +733,19 @@ async def test_api_statements_post_scopes(
         oidc_token = mock_oidc_user(sub=sub, scopes=scopes)
         headers = {"Authorization": f"Bearer {oidc_token}"}
 
+        monkeypatch.setenv("RUNSERVER_AUTH_BACKENDS", AuthBackend.OIDC.value)
+        monkeypatch.setattr(
+            "ralph.api.auth.settings.RUNSERVER_AUTH_BACKENDS", [AuthBackend.OIDC]
+        )
         monkeypatch.setattr(
"ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_ISSUER_URI", - "http://providerHost:8080/auth/realms/real_name", + ISSUER_URI, ) monkeypatch.setattr( "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_AUDIENCE", - "http://clientHost:8100", + AUDIENCE, ) - app.dependency_overrides[get_authenticated_user] = get_oidc_user - statement = mock_statement() # NB: scopes are not linked to statements and backends, we therefore test with ES @@ -761,5 +765,3 @@ async def test_api_statements_post_scopes( assert response.json() == { "detail": 'Access not authorized to scope: "statements/write".' } - - app.dependency_overrides.pop(get_authenticated_user, None) diff --git a/tests/api/test_statements_put.py b/tests/api/test_statements_put.py index f65966f90..418d011f0 100644 --- a/tests/api/test_statements_put.py +++ b/tests/api/test_statements_put.py @@ -6,15 +6,18 @@ from httpx import AsyncClient from ralph.api import app -from ralph.api.auth import get_authenticated_user from ralph.api.auth.basic import get_basic_auth_user -from ralph.api.auth.oidc import get_oidc_user from ralph.backends.lrs.es import ESLRSBackend from ralph.backends.lrs.mongo import MongoLRSBackend -from ralph.conf import XapiForwardingConfigurationSettings +from ralph.conf import AuthBackend, XapiForwardingConfigurationSettings from ralph.exceptions import BackendException -from tests.fixtures.auth import mock_basic_auth_user, mock_oidc_user +from tests.fixtures.auth import ( + AUDIENCE, + ISSUER_URI, + mock_basic_auth_user, + mock_oidc_user, +) from tests.fixtures.backends import ( ES_TEST_FORWARDING_INDEX, ES_TEST_HOSTS, @@ -608,7 +611,6 @@ async def test_api_statements_put_scopes( credentials = mock_basic_auth_user(fs, scopes=scopes, agent=agent) headers = {"Authorization": f"Basic {credentials}"} - app.dependency_overrides[get_authenticated_user] = get_basic_auth_user get_basic_auth_user.cache_clear() elif auth_method == "oidc": @@ -617,17 +619,19 @@ async def test_api_statements_put_scopes( oidc_token = mock_oidc_user(sub=sub, scopes=scopes) headers = {"Authorization": f"Bearer {oidc_token}"} + monkeypatch.setenv("RUNSERVER_AUTH_BACKENDS", [AuthBackend.OIDC]) + monkeypatch.setattr( + "ralph.api.auth.settings.RUNSERVER_AUTH_BACKENDS", [AuthBackend.OIDC] + ) monkeypatch.setattr( "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_ISSUER_URI", - "http://providerHost:8080/auth/realms/real_name", + ISSUER_URI, ) monkeypatch.setattr( "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_AUDIENCE", - "http://clientHost:8100", + AUDIENCE, ) - app.dependency_overrides[get_authenticated_user] = get_oidc_user - statement = mock_statement() # NB: scopes are not linked to statements and backends, we therefore test with ES @@ -647,5 +651,3 @@ async def test_api_statements_put_scopes( assert response.json() == { "detail": 'Access not authorized to scope: "statements/write".' 
        }
-
-    app.dependency_overrides.pop(get_authenticated_user, None)
diff --git a/tests/conftest.py b/tests/conftest.py
index 281917e9e..033644d8b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,11 +7,9 @@
 from .fixtures.api import client  # noqa: F401
 from .fixtures.auth import (  # noqa: F401
     basic_auth_credentials,
-    basic_auth_test_client,
     encoded_token,
     mock_discovery_response,
     mock_oidc_jwks,
-    oidc_auth_test_client,
 )
 from .fixtures.backends import (  # noqa: F401
     anyio_backend,
diff --git a/tests/fixtures/auth.py b/tests/fixtures/auth.py
index 7e44149b3..2b0872842 100644
--- a/tests/fixtures/auth.py
+++ b/tests/fixtures/auth.py
@@ -8,11 +8,9 @@
 import pytest
 import responses
 from cryptography.hazmat.primitives import serialization
-from fastapi.testclient import TestClient
 from jose import jwt
 from jose.utils import long_to_base64
 
-from ralph.api import app, get_authenticated_user
 from ralph.api.auth.basic import get_stored_credentials
 from ralph.api.auth.oidc import discover_provider, get_public_keys
 from ralph.conf import settings
@@ -104,39 +102,6 @@ def basic_auth_credentials(fs, user_scopes=None, agent=None):
     return credentials
 
 
-@pytest.fixture
-def basic_auth_test_client():
-    """Return a TestClient with HTTP basic authentication mode."""
-    # pylint:disable=import-outside-toplevel
-    from ralph.api.auth.basic import (
-        get_basic_auth_user,  # pylint:disable=import-outside-toplevel
-    )
-
-    app.dependency_overrides[get_authenticated_user] = get_basic_auth_user
-
-    with TestClient(app) as test_client:
-        yield test_client
-
-
-@pytest.fixture
-def oidc_auth_test_client(monkeypatch):
-    """Return a TestClient with OpenId Connect authentication mode."""
-    # pylint:disable=import-outside-toplevel
-    monkeypatch.setattr(
-        "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_ISSUER_URI",
-        ISSUER_URI,
-    )
-    monkeypatch.setattr(
-        "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_AUDIENCE",
-        AUDIENCE,
-    )
-    from ralph.api.auth.oidc import get_oidc_user
-
-    app.dependency_overrides[get_authenticated_user] = get_oidc_user
-    with TestClient(app) as test_client:
-        yield test_client
-
-
 def _mock_discovery_response():
     """Return an example discovery response."""
     return {
diff --git a/tests/helpers.py b/tests/helpers.py
index c361f1dfd..bf08db3c0 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -7,8 +7,11 @@
 from typing import Optional, Union
 from uuid import UUID
 
+from ralph.api.auth import AuthBackend
 from ralph.utils import statements_are_equivalent
 
+from tests.fixtures.auth import AUDIENCE, ISSUER_URI
+
 
 def string_is_date(string: str):
     """Check if string can be parsed as a date."""
@@ -197,3 +200,26 @@ def mock_statement(
         "object": object,
         "timestamp": timestamp,
     }
+
+
+def configure_env_for_mock_oidc_auth(monkeypatch, runserver_auth_backends=None):
+    """Configure environment variables to simulate OIDC use."""
+
+    if runserver_auth_backends is None:
+        runserver_auth_backends = [AuthBackend.OIDC]
+
+    # `monkeypatch.setenv` expects a string value: serialize the backend list
+    # to a comma separated string instead of passing the list itself.
+    monkeypatch.setenv(
+        "RUNSERVER_AUTH_BACKENDS",
+        ",".join(backend.value for backend in runserver_auth_backends),
+    )
+    monkeypatch.setattr(
+        "ralph.api.auth.settings.RUNSERVER_AUTH_BACKENDS", runserver_auth_backends
+    )
+    monkeypatch.setattr(
+        "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_ISSUER_URI",
+        ISSUER_URI,
+    )
+    monkeypatch.setattr(
+        "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_AUDIENCE",
+        AUDIENCE,
+    )

From d436ad1f04377f3a597d65c9c039df9de1179398 Mon Sep 17 00:00:00 2001
From: lleeoo
Date: Tue, 31 Oct 2023 16:13:25 +0100
Subject: [PATCH 58/65] add missing import

---
 src/ralph/conf.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ralph/conf.py b/src/ralph/conf.py
index 923407e93..90788d6ab 100644
--- a/src/ralph/conf.py
+++ b/src/ralph/conf.py
@@ -4,7 +4,7 @@
 import sys
 from enum import Enum
 from pathlib import Path
-from typing import List, Sequence, Union
+from typing import List, Sequence, Union, Tuple
 
 from pydantic import AnyHttpUrl, AnyUrl, BaseModel, BaseSettings, Extra, root_validator
 

From ceaf70834e07a0f9a1dc97d27adfa44af94e61e0 Mon Sep 17 00:00:00 2001
From: lleeoo
Date: Tue, 31 Oct 2023 18:45:08 +0100
Subject: [PATCH 59/65] start migrating to pydantic v2

---
 setup.cfg                                     |   3 +-
 src/ralph/api/auth/oidc.py                    |  10 +-
 src/ralph/api/auth/user.py                    |   2 +
 src/ralph/api/models.py                       |  13 +-
 src/ralph/backends/conf.py                    |   5 +-
 src/ralph/backends/data/base.py               |  13 +-
 src/ralph/backends/data/clickhouse.py         |  17 ++-
 src/ralph/backends/data/es.py                 |  10 +-
 src/ralph/backends/data/fs.py                 |   2 +
 src/ralph/backends/data/ldp.py                |   2 +
 src/ralph/backends/data/mongo.py              |  23 ++--
 src/ralph/backends/data/s3.py                 |   2 +
 src/ralph/backends/data/swift.py              |   2 +
 src/ralph/backends/http/async_lrs.py          |  16 ++-
 src/ralph/backends/http/base.py               |  13 +-
 src/ralph/backends/lrs/base.py                |  16 +--
 src/ralph/backends/stream/base.py             |   5 +-
 src/ralph/backends/stream/ws.py               |   2 +
 src/ralph/conf.py                             | 122 +++++++++++-------
 src/ralph/models/edx/base.py                  |  29 ++---
 src/ralph/models/edx/browser.py               |   5 +-
 .../models/edx/enrollment/fields/events.py    |   2 +-
 .../models/edx/navigational/fields/events.py  |   9 +-
 .../models/edx/navigational/statements.py     |   8 +-
 .../open_response_assessment/fields/events.py |  35 ++---
 .../edx/peer_instruction/fields/events.py     |   5 +-
 .../edx/problem_interaction/fields/events.py  |  81 ++++++------
 src/ralph/models/edx/video/fields/events.py   |   5 +-
 src/ralph/models/edx/video/statements.py      |  14 +-
 src/ralph/models/xapi/base/agents.py          |   4 +-
 src/ralph/models/xapi/base/attachments.py     |   4 +-
 src/ralph/models/xapi/base/common.py          |   6 +
 src/ralph/models/xapi/base/contexts.py        |  26 ++--
 src/ralph/models/xapi/base/groups.py          |   2 +-
 src/ralph/models/xapi/base/ifi.py             |   5 +-
 src/ralph/models/xapi/base/objects.py         |   8 +-
 src/ralph/models/xapi/base/results.py         |  23 ++--
 src/ralph/models/xapi/base/statements.py      |  22 ++--
 src/ralph/models/xapi/base/verbs.py           |   2 +-
 src/ralph/models/xapi/config.py               |  12 +-
 src/ralph/models/xapi/lms/objects.py          |   2 +-
 src/ralph/models/xapi/video/results.py        |   8 +-
 42 files changed, 325 insertions(+), 270 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index e5c1bf673..2f0c437d5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -30,7 +30,8 @@ install_requires =
     ; By default, we only consider core dependencies required to use Ralph as a
     ; library (mostly models).
    langcodes>=3.2.0
-    pydantic[dotenv,email]>=1.10.0, <2.0
+    pydantic[email]>=2.0
+    pydantic-settings>=2.0
     rfc3987>=1.3.0
 package_dir =
     =src
diff --git a/src/ralph/api/auth/oidc.py b/src/ralph/api/auth/oidc.py
index f11cef628..d41b5687f 100644
--- a/src/ralph/api/auth/oidc.py
+++ b/src/ralph/api/auth/oidc.py
@@ -9,7 +9,7 @@
 from fastapi.security import HTTPBearer, OpenIdConnect
 from jose import ExpiredSignatureError, JWTError, jwt
 from jose.exceptions import JWTClaimsError
-from pydantic import AnyUrl, BaseModel, Extra
+from pydantic import ConfigDict, AnyUrl, BaseModel
 from typing_extensions import Annotated
 
 from ralph.api.auth.user import AuthenticatedUser, UserScopes
@@ -44,13 +44,11 @@ class IDToken(BaseModel):
 
     iss: str
     sub: str
-    aud: Optional[str]
+    aud: Optional[str] = None
     exp: int
     iat: int
-    scope: Optional[str]
-
-    class Config:  # pylint: disable=missing-class-docstring # noqa: D106
-        extra = Extra.ignore
+    scope: Optional[str] = None
+    model_config = ConfigDict(extra="ignore")
 
 
 @lru_cache()
diff --git a/src/ralph/api/auth/user.py b/src/ralph/api/auth/user.py
index 9d61f0c4d..a5476a212 100644
--- a/src/ralph/api/auth/user.py
+++ b/src/ralph/api/auth/user.py
@@ -55,6 +55,8 @@ def is_authorized(self, requested_scope: Scope):
         return requested_scope in expanded_user_scopes
 
     @classmethod
+    # TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually.
+    # Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information.
     def __get_validators__(cls):  # noqa: D105
         def validate(value: FrozenSet[Scope]):
             """Transform value to an instance of UserScopes."""
diff --git a/src/ralph/api/models.py b/src/ralph/api/models.py
index 94a8ee3d1..3bf051e2b 100644
--- a/src/ralph/api/models.py
+++ b/src/ralph/api/models.py
@@ -6,7 +6,7 @@
 from typing import Optional, Union
 from uuid import UUID
 
-from pydantic import AnyUrl, BaseModel, Extra
+from pydantic import ConfigDict, AnyUrl, BaseModel
 
 from ..models.xapi.base.agents import BaseXapiAgent
 from ..models.xapi.base.groups import BaseXapiGroup
@@ -28,14 +28,7 @@ class BaseModelWithLaxConfig(BaseModel):
     Common base lax model to perform light input validation as we receive
     statements through the API.
     """
-
-    class Config:
-        """Enable extra properties.
-
-        Useful for not having to perform comprehensive validation.
-        """
-
-        extra = Extra.allow
+    model_config = ConfigDict(extra="allow")
 
 
 class LaxObjectField(BaseModelWithLaxConfig):
@@ -64,6 +57,6 @@ class LaxStatement(BaseModelWithLaxConfig):
     """
 
     actor: Union[BaseXapiAgent, BaseXapiGroup]
-    id: Optional[UUID]
+    id: Optional[UUID] = None
     object: LaxObjectField
     verb: LaxVerbField
diff --git a/src/ralph/backends/conf.py b/src/ralph/backends/conf.py
index acdb3de87..14f1fa0dd 100644
--- a/src/ralph/backends/conf.py
+++ b/src/ralph/backends/conf.py
@@ -1,6 +1,6 @@
 """Configurations for Ralph backends."""
 
-from pydantic import BaseModel, BaseSettings
+from pydantic import BaseModel
 
 from ralph.backends.data.clickhouse import ClickHouseDataBackendSettings
 from ralph.backends.data.es import ESDataBackendSettings
@@ -14,6 +14,7 @@
 from ralph.backends.lrs.fs import FSLRSBackendSettings
 from ralph.backends.stream.ws import WSStreamBackendSettings
 from ralph.conf import BaseSettingsConfig, core_settings
+from pydantic_settings import BaseSettings
 
 
 # Active Data backend Settings.
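For context on the recurring TODO markers this patch introduces: in pydantic v2
the nested `class Config` on settings classes becomes a `model_config` built with
`SettingsConfigDict`. A minimal sketch of that target idiom, assuming
pydantic-settings>=2.0 — the class name and env prefix below are illustrative,
not part of this patch:

    from pydantic_settings import BaseSettings, SettingsConfigDict

    class ExampleBackendSettings(BaseSettings):
        """Example data backend configuration (illustrative only)."""

        # Replaces the v1 `class Config(BaseSettingsConfig)` block.
        model_config = SettingsConfigDict(
            env_prefix="RALPH_BACKENDS__DATA__EXAMPLE__",
            env_file=".env",
            env_file_encoding="utf8",
        )

        DEFAULT_CHUNK_SIZE: int = 500

With this in place, setting `RALPH_BACKENDS__DATA__EXAMPLE__DEFAULT_CHUNK_SIZE=100`
in the environment (or in `.env`) overrides the declared default.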
@@ -79,6 +80,8 @@ class Backends(BaseModel):
 class BackendSettings(BaseSettings):
     """Pydantic model for Ralph's backends environment & configuration settings."""
 
+    # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually.
+    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
     class Config(BaseSettingsConfig):
         """Pydantic Configuration."""
 
diff --git a/src/ralph/backends/data/base.py b/src/ralph/backends/data/base.py
index 4b9454819..96f56f5fd 100644
--- a/src/ralph/backends/data/base.py
+++ b/src/ralph/backends/data/base.py
@@ -7,10 +7,11 @@
 from io import IOBase
 from typing import Iterable, Iterator, Optional, Union
 
-from pydantic import BaseModel, BaseSettings, ValidationError
+from pydantic import BaseModel, ConfigDict, ValidationError
 
 from ralph.conf import BaseSettingsConfig, core_settings
 from ralph.exceptions import BackendParameterException
+from pydantic_settings import BaseSettings
 
 logger = logging.getLogger(__name__)
 
@@ -18,6 +19,8 @@
 class BaseDataBackendSettings(BaseSettings):
     """Data backend default configuration."""
 
+    # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually.
+    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
     class Config(BaseSettingsConfig):
         """Pydantic Configuration."""
 
@@ -28,13 +31,9 @@ class Config(BaseSettingsConfig):
 
 class BaseQuery(BaseModel):
     """Base query model."""
 
-    class Config:
-        """Base query model configuration."""
-
-        extra = "forbid"
+    # BaseQuery is a plain model, not a settings class: the environment options
+    # stay on `BaseDataBackendSettings` above, so a simple `ConfigDict` suffices.
+    model_config = ConfigDict(extra="forbid")
 
-    query_string: Union[str, None]
+    query_string: Union[str, None] = None
 
 
 @unique
diff --git a/src/ralph/backends/data/clickhouse.py b/src/ralph/backends/data/clickhouse.py
index 81a978760..c9ffafc1d 100755
--- a/src/ralph/backends/data/clickhouse.py
+++ b/src/ralph/backends/data/clickhouse.py
@@ -20,7 +20,7 @@
 
 import clickhouse_connect
 from clickhouse_connect.driver.exceptions import ClickHouseError
-from pydantic import BaseModel, Json, ValidationError, conint
+from pydantic import Field, BaseModel, Json, ValidationError
 
 from ralph.backends.data.base import (
     BaseDataBackend,
@@ -32,6 +32,7 @@
 )
 from ralph.conf import BaseSettingsConfig, ClientOptions
 from ralph.exceptions import BackendException, BackendParameterException
+from typing_extensions import Annotated
 
 logger = logging.getLogger(__name__)
 
@@ -47,7 +48,7 @@ class ClickHouseClientOptions(ClientOptions):
     """Pydantic model for `clickhouse` client options."""
 
     date_time_input_format: str = "best_effort"
-    allow_experimental_object_type: conint(ge=0, le=1) = 1
+    allow_experimental_object_type: Annotated[int, Field(ge=0, le=1)] = 1
 
 
 class InsertTuple(NamedTuple):
@@ -75,6 +76,8 @@ class ClickHouseDataBackendSettings(BaseDataBackendSettings):
         LOCALE_ENCODING (str): The locale encoding to use when none is provided.
     """
 
+    # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually.
+    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
class Config(BaseSettingsConfig): """Pydantic Configuration.""" @@ -95,10 +98,10 @@ class BaseClickHouseQuery(BaseQuery): """Base ClickHouse query model.""" select: Union[str, List[str]] = "event" - where: Union[str, List[str], None] - parameters: Union[Dict, None] - limit: Union[int, None] - sort: Union[str, None] + where: Union[str, List[str], None] = None + parameters: Union[Dict, None] = None + limit: Union[int, None] = None + sort: Union[str, None] = None column_oriented: Union[bool, None] = False @@ -106,7 +109,7 @@ class ClickHouseQuery(BaseClickHouseQuery): """ClickHouse query model.""" # pylint: disable=unsubscriptable-object - query_string: Union[Json[BaseClickHouseQuery], None] + query_string: Union[Json[BaseClickHouseQuery], None] = None class ClickHouseDataBackend(BaseDataBackend): diff --git a/src/ralph/backends/data/es.py b/src/ralph/backends/data/es.py index 05b6e1296..9c4556d8a 100644 --- a/src/ralph/backends/data/es.py +++ b/src/ralph/backends/data/es.py @@ -52,6 +52,8 @@ class ESDataBackendSettings(BaseDataBackendSettings): refreshed after the write operation. """ + # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually. + # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. class Config(BaseSettingsConfig): """Pydantic Configuration.""" @@ -76,8 +78,8 @@ class ESQueryPit(BaseModel): time alive. """ - id: Union[str, None] - keep_alive: Union[str, None] + id: Union[str, None] = None + keep_alive: Union[str, None] = None class ESQuery(BaseQuery): @@ -103,9 +105,9 @@ class ESQuery(BaseQuery): query: dict = {"match_all": {}} pit: ESQueryPit = ESQueryPit() - size: Union[int, None] + size: Union[int, None] = None sort: Union[str, List[dict]] = "_shard_doc" - search_after: Union[list, None] + search_after: Union[list, None] = None track_total_hits: Literal[False] = False diff --git a/src/ralph/backends/data/fs.py b/src/ralph/backends/data/fs.py index 1eb024cea..951908136 100644 --- a/src/ralph/backends/data/fs.py +++ b/src/ralph/backends/data/fs.py @@ -38,6 +38,8 @@ class FSDataBackendSettings(BaseDataBackendSettings): LOCALE_ENCODING (str): The encoding used for writing dictionaries to files. """ + # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually. + # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. class Config(BaseSettingsConfig): """Pydantic Configuration.""" diff --git a/src/ralph/backends/data/ldp.py b/src/ralph/backends/data/ldp.py index 83881c4ad..6fc5b3fbd 100644 --- a/src/ralph/backends/data/ldp.py +++ b/src/ralph/backends/data/ldp.py @@ -35,6 +35,8 @@ class LDPDataBackendSettings(BaseDataBackendSettings): SERVICE_NAME (str): The default LDP account name. """ + # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually. + # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. 
class Config(BaseSettingsConfig): """Pydantic Configuration.""" diff --git a/src/ralph/backends/data/mongo.py b/src/ralph/backends/data/mongo.py index 506beb46f..f0fd75801 100644 --- a/src/ralph/backends/data/mongo.py +++ b/src/ralph/backends/data/mongo.py @@ -13,7 +13,7 @@ from bson.errors import BSONError from bson.objectid import ObjectId from dateutil.parser import isoparse -from pydantic import Json, MongoDsn, constr +from pydantic import StringConstraints, Json, MongoDsn from pymongo import MongoClient, ReplaceOne from pymongo.collection import Collection from pymongo.errors import ( @@ -35,6 +35,7 @@ from ralph.conf import BaseSettingsConfig, ClientOptions from ralph.exceptions import BackendException, BackendParameterException from ralph.utils import parse_bytes_to_dict, read_raw +from typing_extensions import Annotated logger = logging.getLogger(__name__) @@ -58,16 +59,18 @@ class MongoDataBackendSettings(BaseDataBackendSettings): LOCALE_ENCODING (str): The locale encoding to use when none is provided. """ + # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually. + # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. class Config(BaseSettingsConfig): """Pydantic Configuration.""" env_prefix = "RALPH_BACKENDS__DATA__MONGO__" CONNECTION_URI: MongoDsn = MongoDsn("mongodb://localhost:27017/", scheme="mongodb") - DEFAULT_DATABASE: constr(regex=r"^[^\s.$/\\\"\x00]+$") = "statements" # noqa : F722 - DEFAULT_COLLECTION: constr( - regex=r"^(?!.*\.\.)[^.$\x00]+(?:\.[^.$\x00]+)*$" # noqa : F722 - ) = "marsha" + DEFAULT_DATABASE: Annotated[str, StringConstraints(pattern=r"^[^\s.$/\\\"\x00]+$")] = "statements" # noqa : F722 + DEFAULT_COLLECTION: Annotated[str, StringConstraints( + pattern=r"^(?!.*\.\.)[^.$\x00]+(?:\.[^.$\x00]+)*$" # noqa : F722 + )] = "marsha" CLIENT_OPTIONS: MongoClientOptions = MongoClientOptions() DEFAULT_CHUNK_SIZE: int = 500 LOCALE_ENCODING: str = "utf8" @@ -76,17 +79,17 @@ class Config(BaseSettingsConfig): class BaseMongoQuery(BaseQuery): """Base MongoDB query model.""" - filter: Union[dict, None] - limit: Union[int, None] - projection: Union[dict, None] - sort: Union[List[Tuple], None] + filter: Union[dict, None] = None + limit: Union[int, None] = None + projection: Union[dict, None] = None + sort: Union[List[Tuple], None] = None class MongoQuery(BaseMongoQuery): """MongoDB query model.""" # pylint: disable=unsubscriptable-object - query_string: Union[Json[BaseMongoQuery], None] + query_string: Union[Json[BaseMongoQuery], None] = None class MongoDataBackend(BaseDataBackend): diff --git a/src/ralph/backends/data/s3.py b/src/ralph/backends/data/s3.py index 4198b9b5f..d5598b4de 100644 --- a/src/ralph/backends/data/s3.py +++ b/src/ralph/backends/data/s3.py @@ -50,6 +50,8 @@ class S3DataBackendSettings(BaseDataBackendSettings): LOCALE_ENCODING (str): The encoding used for writing dictionaries to objects. """ + # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually. + # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. 
    class Config(BaseSettingsConfig):
        """Pydantic Configuration."""
 
diff --git a/src/ralph/backends/data/swift.py b/src/ralph/backends/data/swift.py
index ac3b0ef9b..c43dd3ae7 100644
--- a/src/ralph/backends/data/swift.py
+++ b/src/ralph/backends/data/swift.py
@@ -43,6 +43,8 @@ class SwiftDataBackendSettings(BaseDataBackendSettings):
         LOCALE_ENCODING (str): The encoding used for reading/writing documents.
     """
 
+    # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually.
+    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
     class Config(BaseSettingsConfig):
         """Pydantic Configuration."""
 
diff --git a/src/ralph/backends/http/async_lrs.py b/src/ralph/backends/http/async_lrs.py
index 8309397a0..ada37de73 100644
--- a/src/ralph/backends/http/async_lrs.py
+++ b/src/ralph/backends/http/async_lrs.py
@@ -52,6 +52,8 @@ class LRSHTTPBackendSettings(BaseHTTPBackendSettings):
         STATEMENTS_ENDPOINT (str): Default endpoint for LRS statements resource.
     """
 
+    # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually.
+    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
     class Config(BaseSettingsConfig):
         """Pydantic Configuration."""
 
@@ -69,7 +71,7 @@ class StatementResponse(BaseModel):
     """Pydantic model for `get` statements response."""
 
     statements: Union[List[dict], dict]
-    more: Optional[str]
+    more: Optional[str] = None
 
 
 class LRSStatementsQuery(BaseQuery):
@@ -83,14 +85,14 @@ class LRSStatementsQuery(BaseQuery):
 
     statement_id: Optional[str] = Field(None, alias="statementId")
     voided_statement_id: Optional[str] = Field(None, alias="voidedStatementId")
-    agent: Optional[Union[BaseXapiAgent, BaseXapiGroup]]
-    verb: Optional[IRI]
-    activity: Optional[IRI]
-    registration: Optional[UUID]
+    agent: Optional[Union[BaseXapiAgent, BaseXapiGroup]] = None
+    verb: Optional[IRI] = None
+    activity: Optional[IRI] = None
+    registration: Optional[UUID] = None
     related_activities: Optional[bool] = False
     related_agents: Optional[bool] = False
-    since: Optional[datetime]
-    until: Optional[datetime]
+    since: Optional[datetime] = None
+    until: Optional[datetime] = None
     limit: Optional[NonNegativeInt] = 0
     format: Optional[Literal["ids", "exact", "canonical"]] = "exact"
     attachments: Optional[bool] = False
diff --git a/src/ralph/backends/http/base.py b/src/ralph/backends/http/base.py
index ae5003b35..ff9bd22aa 100644
--- a/src/ralph/backends/http/base.py
+++ b/src/ralph/backends/http/base.py
@@ -6,11 +6,12 @@
 from enum import Enum, unique
 from typing import Iterator, List, Optional, Union
 
-from pydantic import BaseModel, BaseSettings, ValidationError
+from pydantic import BaseModel, ConfigDict, ValidationError
 from pydantic.types import PositiveInt
 
 from ralph.conf import BaseSettingsConfig, core_settings
 from ralph.exceptions import BackendParameterException
+from pydantic_settings import BaseSettings
 
 logger = logging.getLogger(__name__)
 
@@ -18,6 +19,8 @@
 class BaseHTTPBackendSettings(BaseSettings):
     """Data backend default configuration."""
 
+    # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually.
+    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
     class Config(BaseSettingsConfig):
         """Pydantic Configuration."""
 
@@ -69,13 +72,9 @@ def wrapper(*args, **kwargs):
 class BaseQuery(BaseModel):
     """Base query model."""
 
-    class Config:
-        """Base query model configuration."""
-
-        extra = "forbid"
+    # BaseQuery is a plain model, not a settings class: the environment options
+    # stay on `BaseHTTPBackendSettings` above, so a simple `ConfigDict` suffices.
+    model_config = ConfigDict(extra="forbid")
 
-    query_string: Optional[str]
+    query_string: Optional[str] = None
 
 
 class BaseHTTPBackend(ABC):
diff --git a/src/ralph/backends/lrs/base.py b/src/ralph/backends/lrs/base.py
index 0cf552f25..96d434afd 100644
--- a/src/ralph/backends/lrs/base.py
+++ b/src/ralph/backends/lrs/base.py
@@ -33,21 +33,21 @@ class AgentParameters(BaseModel):
     NB: Agent refers to the data structure, NOT to the LRS query parameter.
     """
 
-    mbox: Optional[str]
-    mbox_sha1sum: Optional[str]
-    openid: Optional[str]
-    account__name: Optional[str]
-    account__home_page: Optional[str]
+    mbox: Optional[str] = None
+    mbox_sha1sum: Optional[str] = None
+    openid: Optional[str] = None
+    account__name: Optional[str] = None
+    account__home_page: Optional[str] = None
 
 
 class RalphStatementsQuery(LRSStatementsQuery):
     """Represents a dictionary of possible LRS query parameters."""
 
     agent: Optional[AgentParameters] = AgentParameters.construct()
-    search_after: Optional[str]
-    pit_id: Optional[str]
+    search_after: Optional[str] = None
+    pit_id: Optional[str] = None
     authority: Optional[AgentParameters] = AgentParameters.construct()
-    ignore_order: Optional[bool]
+    ignore_order: Optional[bool] = None
 
 
 class BaseLRSBackend(BaseDataBackend):
diff --git a/src/ralph/backends/stream/base.py b/src/ralph/backends/stream/base.py
index 008e68d81..cc2845e2a 100644
--- a/src/ralph/backends/stream/base.py
+++ b/src/ralph/backends/stream/base.py
@@ -3,14 +3,15 @@
 from abc import ABC, abstractmethod
 from typing import BinaryIO
 
-from pydantic import BaseSettings
-
 from ralph.conf import BaseSettingsConfig, core_settings
+from pydantic_settings import BaseSettings
 
 
 class BaseStreamBackendSettings(BaseSettings):
     """Data backend default configuration."""
 
+    # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually.
+    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
     class Config(BaseSettingsConfig):
         """Pydantic Configuration."""
 
diff --git a/src/ralph/backends/stream/ws.py b/src/ralph/backends/stream/ws.py
index 2f70651cf..c201dffec 100644
--- a/src/ralph/backends/stream/ws.py
+++ b/src/ralph/backends/stream/ws.py
@@ -20,6 +20,8 @@ class WSStreamBackendSettings(BaseStreamBackendSettings):
         URI (str): The URI to connect to.
     """
 
+    # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually.
+    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
    class Config(BaseSettingsConfig):
        """Pydantic Configuration."""
 
diff --git a/src/ralph/conf.py b/src/ralph/conf.py
index 90788d6ab..9414a39bf 100644
--- a/src/ralph/conf.py
+++ b/src/ralph/conf.py
@@ -4,13 +4,14 @@
 import sys
 from enum import Enum
 from pathlib import Path
-from typing import List, Sequence, Union, Tuple
+from typing import Annotated, List, Optional, Sequence, Union, Tuple
 
-from pydantic import AnyHttpUrl, AnyUrl, BaseModel, BaseSettings, Extra, root_validator
+from pydantic import AfterValidator, model_validator, ConfigDict, AnyHttpUrl, AnyUrl, BaseModel
 
 from ralph.exceptions import ConfigurationException
 
 from .utils import import_string
+from pydantic_settings import BaseSettings
 
 if sys.version_info >= (3, 8):
     from typing import Literal
@@ -42,6 +43,8 @@
 class CoreSettings(BaseSettings):
     """Pydantic model for Ralph's core settings."""
 
+    # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually.
+    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
     class Config(BaseSettingsConfig):
         """Pydantic Configuration."""
 
@@ -52,29 +55,44 @@ class Config(BaseSettingsConfig):
 core_settings = CoreSettings()
 
 
-class CommaSeparatedTuple(str):
-    """Pydantic field type validating comma separated strings or lists/tuples."""
-
-    @classmethod
-    def __get_validators__(cls):  # noqa: D105
-        def validate(value: Union[str, Sequence[str]]) -> Sequence[str]:
-            """Check whether the value is a comma separated string or a list/tuple."""
-            if isinstance(value, (tuple, list)):
-                return tuple(value)
-
-            if isinstance(value, str):
-                return tuple(value.split(","))
-
-            raise TypeError("Invalid comma separated list")
-
-        yield validate
+def validate_comma_separated_tuple(
+    value: Union[str, Tuple[str, ...], List[str]]
+) -> Tuple[str, ...]:
+    """Check whether the value is a comma separated string or a list/tuple."""
+    if isinstance(value, (tuple, list)):
+        return tuple(value)
+
+    if isinstance(value, str):
+        return tuple(value.split(","))
+
+    raise TypeError("Invalid comma separated list")
+
+
+CommaSeparatedTuple = Annotated[
+    Union[str, Tuple[str, ...], List[str]],
+    AfterValidator(validate_comma_separated_tuple),
+]
 
 
 class InstantiableSettingsItem(BaseModel):
     """Pydantic model for a settings configuration item that can be instantiated."""
 
-    class Config:  # pylint: disable=missing-class-docstring # noqa: D106
-        underscore_attrs_are_private = True
+    # `underscore_attrs_are_private` was removed in pydantic v2: attributes with
+    # a leading underscore are now treated as private by default, so no
+    # replacement configuration is required here.
 
     _class_path: str = None
 
@@ -85,16 +103,12 @@ def get_instance(self, **init_parameters):
 
 class ClientOptions(BaseModel):
     """Pydantic model for additional client options."""
-
-    class Config:  # pylint: disable=missing-class-docstring # noqa: D106
-        extra = Extra.forbid
+    model_config = ConfigDict(extra="forbid")
 
 
 class HeadersParameters(BaseModel):
     """Pydantic model for headers parameters."""
-
-    class Config:  # pylint: disable=missing-class-docstring # noqa: D106
-        extra = Extra.allow
+    model_config = ConfigDict(extra="allow")
 
 
 # Active parser Settings.
@@ -121,9 +135,7 @@ class ParserSettings(BaseModel):
 
 class XapiForwardingConfigurationSettings(BaseModel):
     """Pydantic model for xAPI forwarding configuration item."""
-
-    class Config:  # pylint: disable=missing-class-docstring # noqa: D106
-        min_anystr_length = 1
+    model_config = ConfigDict(str_min_length=1)
 
     url: AnyUrl
     is_active: bool
@@ -140,32 +152,50 @@ class AuthBackend(Enum):
     OIDC = "OIDC"
 
 
-class AuthBackends(str):
-    """Model representing a list of authentication backends."""
-
-    @classmethod
-    def __get_validators__(cls):  # noqa: D105
-        """Checks whether the value is a comma separated string or a tuple representing
-        an AuthBackend."""
-
-        def validate(
-            value: Union[AuthBackend, Tuple[AuthBackend], List[AuthBackend]]
-        ) -> Tuple[AuthBackend]:
-            """Check whether the value is a comma separated string or a list/tuple."""
-            if isinstance(value, (tuple, list)):
-                return tuple(AuthBackend(value))
-
-            if isinstance(value, str):
-                return tuple(AuthBackend(val) for val in value.split(","))
-
-            raise TypeError("Invalid comma separated list")
-
-        yield validate
+def validate_auth_backends(
+    value: Union[AuthBackend, Tuple[AuthBackend], List[AuthBackend]]
+) -> Tuple[AuthBackend, ...]:
+    """Check whether the value is a comma separated string or a list/tuple."""
+    if isinstance(value, (tuple, list)):
+        # Coerce every element of the sequence, not the sequence itself.
+        return tuple(AuthBackend(val) for val in value)
+
+    if isinstance(value, str):
+        return tuple(AuthBackend(val) for val in value.split(","))
+
+    raise TypeError("Invalid comma separated list")
+
+
+AuthBackends = Annotated[
+    Union[str, Tuple[str, ...], List[str]],
+    AfterValidator(validate_auth_backends),
+]
 
 
 class Settings(BaseSettings):
     """Pydantic model for Ralph's global environment & configuration settings."""
 
+    # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually.
+    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
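+    # A possible shape for that manual `model_config` replacement — an
+    # illustrative sketch only, assuming `BaseSettingsConfig` currently
+    # contributes the `RALPH_` prefix, case sensitivity and the `__` nested
+    # delimiter:
+    #
+    #     model_config = SettingsConfigDict(
+    #         case_sensitive=True,
+    #         env_nested_delimiter="__",
+    #         env_prefix="RALPH_",
+    #         env_file=".env",
+    #         env_file_encoding=core_settings.LOCALE_ENCODING,
+    #     )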
    class Config(BaseSettingsConfig):
        """Pydantic Configuration."""
 
@@ -174,9 +204,9 @@ class Config(BaseSettingsConfig):
     _CORE: CoreSettings = core_settings
 
     AUTH_FILE: Path = _CORE.APP_DIR / "auth.json"
-    AUTH_CACHE_MAX_SIZE = 100
-    AUTH_CACHE_TTL = 3600
-    CONVERTER_EDX_XAPI_UUID_NAMESPACE: str = None
+    AUTH_CACHE_MAX_SIZE: int = 100
+    AUTH_CACHE_TTL: int = 3600
+    CONVERTER_EDX_XAPI_UUID_NAMESPACE: Optional[str] = None
     DEFAULT_BACKEND_CHUNK_SIZE: int = 500
     EXECUTION_ENVIRONMENT: str = "development"
     HISTORY_FILE: Path = _CORE.APP_DIR / "history.json"
@@ -240,7 +270,7 @@ def LOCALE_ENCODING(self) -> str:  # pylint: disable=invalid-name
         """Return Ralph's default locale encoding."""
         return self._CORE.LOCALE_ENCODING
 
-    @root_validator(allow_reuse=True)
+    # `mode="before"` keeps the dict-based `(cls, values)` signature of the v1
+    # root validator; `mode="after"` would receive the model instance instead.
+    @model_validator(mode="before")
     @classmethod
     def check_restriction_compatibility(cls, values):
         """Raise an error if scopes are being used without authority restriction."""
diff --git a/src/ralph/models/edx/base.py b/src/ralph/models/edx/base.py
index 7a38e2487..4c37c1fd4 100644
--- a/src/ralph/models/edx/base.py
+++ b/src/ralph/models/edx/base.py
@@ -6,7 +6,8 @@
 from pathlib import Path
 from typing import Dict, Optional, Union
 
-from pydantic import AnyHttpUrl, BaseModel, constr
+from pydantic import StringConstraints, ConfigDict, AnyHttpUrl, BaseModel
+from typing_extensions import Annotated
 
 if sys.version_info >= (3, 8):
     from typing import Literal
@@ -16,9 +17,7 @@
 
 class BaseModelWithConfig(BaseModel):
     """Pydantic model for base configuration shared among all models."""
-
-    class Config:  # pylint: disable=missing-class-docstring # noqa: D106
-        extra = "forbid"
+    model_config = ConfigDict(extra="forbid")
 
 
 class ContextModuleField(BaseModelWithConfig):
@@ -29,14 +28,14 @@ class ContextModuleField(BaseModelWithConfig):
         display_name (str): Consists of a short description or title of the component.
     """
 
-    usage_key: constr(regex=r"^block-v1:.+\+.+\+.+type@.+@[a-f0-9]{32}$")  # noqa:F722
+    usage_key: Annotated[str, StringConstraints(pattern=r"^block-v1:.+\+.+\+.+type@.+@[a-f0-9]{32}$")]  # noqa:F722
     display_name: str
     original_usage_key: Optional[
-        constr(
-            regex=r"^block-v1:.+\+.+\+.+type@problem\+block@[a-f0-9]{32}$"  # noqa:F722
-        )
-    ]
-    original_usage_version: Optional[str]
+        Annotated[str, StringConstraints(
+            pattern=r"^block-v1:.+\+.+\+.+type@problem\+block@[a-f0-9]{32}$"  # noqa:F722
+        )]
+    ] = None
+    original_usage_version: Optional[str] = None
 
 
 class BaseContextField(BaseModelWithConfig):
@@ -81,12 +80,12 @@ class BaseContextField(BaseModelWithConfig):
             `request.META['PATH_INFO']`
     """
 
-    course_id: constr(regex=r"^$|^course-v1:.+\+.+\+.+$")  # noqa:F722
-    course_user_tags: Optional[Dict[str, str]]
-    module: Optional[ContextModuleField]
+    course_id: Annotated[str, StringConstraints(pattern=r"^$|^course-v1:.+\+.+\+.+$")]  # noqa:F722
+    course_user_tags: Optional[Dict[str, str]] = None
+    module: Optional[ContextModuleField] = None
     org_id: str
     path: Path
-    user_id: Union[int, Literal[""], None]
+    user_id: Union[int, Literal[""], None] = None
 
 
 class AbstractBaseEventField(BaseModelWithConfig):
@@ -151,7 +150,7 @@ class BaseEdxModel(BaseModelWithConfig):
             In JSON the value is `null` instead of `None`.
""" - username: Union[constr(min_length=2, max_length=30), Literal[""]] + username: Union[Annotated[str, StringConstraints(min_length=2, max_length=30)], Literal[""]] ip: Union[IPv4Address, Literal[""]] agent: str host: str diff --git a/src/ralph/models/edx/browser.py b/src/ralph/models/edx/browser.py index fdf230473..f40572540 100644 --- a/src/ralph/models/edx/browser.py +++ b/src/ralph/models/edx/browser.py @@ -3,9 +3,10 @@ import sys from typing import Union -from pydantic import AnyUrl, constr +from pydantic import StringConstraints, AnyUrl from .base import BaseEdxModel +from typing_extensions import Annotated if sys.version_info >= (3, 8): from typing import Literal @@ -29,4 +30,4 @@ class BaseBrowserModel(BaseEdxModel): event_source: Literal["browser"] page: AnyUrl - session: Union[constr(regex=r"^[a-f0-9]{32}$"), Literal[""]] # noqa: F722 + session: Union[Annotated[str, StringConstraints(pattern=r"^[a-f0-9]{32}$")], Literal[""]] # noqa: F722 diff --git a/src/ralph/models/edx/enrollment/fields/events.py b/src/ralph/models/edx/enrollment/fields/events.py index 60ef60c62..012e7be31 100644 --- a/src/ralph/models/edx/enrollment/fields/events.py +++ b/src/ralph/models/edx/enrollment/fields/events.py @@ -28,4 +28,4 @@ class EnrollmentEventField(AbstractBaseEventField): mode: Union[ Literal["audit"], Literal["honor"], Literal["professional"], Literal["verified"] ] - user_id: Union[int, Literal[""], None] + user_id: Union[int, Literal[""], None] = None diff --git a/src/ralph/models/edx/navigational/fields/events.py b/src/ralph/models/edx/navigational/fields/events.py index d13531978..fb345fc97 100644 --- a/src/ralph/models/edx/navigational/fields/events.py +++ b/src/ralph/models/edx/navigational/fields/events.py @@ -1,8 +1,9 @@ """Navigational event field definition.""" -from pydantic import constr +from pydantic import StringConstraints from ...base import AbstractBaseEventField +from typing_extensions import Annotated class NavigationalEventField(AbstractBaseEventField): @@ -20,11 +21,11 @@ class NavigationalEventField(AbstractBaseEventField): being navigated away from. 
""" - id: constr( - regex=( + id: Annotated[str, StringConstraints( + pattern=( r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+type" # noqa : F722 r"@sequential\+block@[a-f0-9]{32}$" # noqa : F722 ) - ) + )] new: int old: int diff --git a/src/ralph/models/edx/navigational/statements.py b/src/ralph/models/edx/navigational/statements.py index 65376244e..ed41c1563 100644 --- a/src/ralph/models/edx/navigational/statements.py +++ b/src/ralph/models/edx/navigational/statements.py @@ -3,7 +3,7 @@ import sys from typing import Union -from pydantic import Json, validator +from pydantic import field_validator, Json from ralph.models.selector import selector @@ -75,7 +75,8 @@ class UISeqNext(BaseBrowserModel): event_type: Literal["seq_next"] name: Literal["seq_next"] - @validator("event") + @field_validator("event") + @classmethod @classmethod def validate_next_jump_event_field( cls, value: Union[Json[NavigationalEventField], NavigationalEventField] @@ -107,7 +108,8 @@ class UISeqPrev(BaseBrowserModel): event_type: Literal["seq_prev"] name: Literal["seq_prev"] - @validator("event") + @field_validator("event") + @classmethod @classmethod def validate_prev_jump_event_field( cls, value: Union[Json[NavigationalEventField], NavigationalEventField] diff --git a/src/ralph/models/edx/open_response_assessment/fields/events.py b/src/ralph/models/edx/open_response_assessment/fields/events.py index e5b658dde..cfde8da85 100644 --- a/src/ralph/models/edx/open_response_assessment/fields/events.py +++ b/src/ralph/models/edx/open_response_assessment/fields/events.py @@ -5,9 +5,10 @@ from typing import Dict, List, Optional, Union from uuid import UUID -from pydantic import constr +from pydantic import StringConstraints from ralph.models.edx.base import AbstractBaseEventField, BaseModelWithConfig +from typing_extensions import Annotated if sys.version_info >= (3, 8): from typing import Literal @@ -29,15 +30,15 @@ class ORAGetPeerSubmissionEventField(AbstractBaseEventField): available. """ - course_id: constr(max_length=255) - item_id: constr( - regex=( + course_id: Annotated[str, StringConstraints(max_length=255)] + item_id: Annotated[str, StringConstraints( + pattern=( r"^block-v1:.+\+.+\+.+type@openassessment" # noqa : F722 r"+block@[a-f0-9]{32}$" # noqa : F722 ) - ) + )] requesting_student_id: str - submission_returned_uuid: Union[str, None] + submission_returned_uuid: Union[str, None] = None class ORAGetSubmissionForStaffGradingEventField(AbstractBaseEventField): @@ -57,13 +58,13 @@ class ORAGetSubmissionForStaffGradingEventField(AbstractBaseEventField): Currently, set to `full-grade`. """ - item_id: constr( - regex=( + item_id: Annotated[str, StringConstraints( + pattern=( r"^block-v1:.+\+.+\+.+type@openassessment" # noqa : F722 r"+block@[a-f0-9]{32}$" # noqa : F722 ) - ) - submission_returned_uuid: Union[str, None] + )] + submission_returned_uuid: Union[str, None] = None requesting_staff_id: str type: Literal["full-grade"] @@ -93,7 +94,7 @@ class ORAAssessEventPartsField(BaseModelWithConfig): option: str criterion: ORAAssessEventPartsCriterionField - feedback: Optional[str] + feedback: Optional[str] = None class ORAAssessEventRubricField(BaseModelWithConfig): @@ -109,7 +110,7 @@ class ORAAssessEventRubricField(BaseModelWithConfig): assess the response. 
""" - content_hash: constr(regex=r"^[a-f0-9]{1,40}$") # noqa: F722 + content_hash: Annotated[str, StringConstraints(pattern=r"^[a-f0-9]{1,40}$")] # noqa: F722 class ORAAssessEventField(AbstractBaseEventField): @@ -138,7 +139,7 @@ class ORAAssessEventField(AbstractBaseEventField): parts: List[ORAAssessEventPartsField] rubric: ORAAssessEventRubricField scored_at: datetime - scorer_id: constr(max_length=40) + scorer_id: Annotated[str, StringConstraints(max_length=40)] score_type: Literal["PE", "SE", "ST"] submission_uuid: UUID @@ -187,8 +188,8 @@ class ORACreateSubmissionEventAnswerField(BaseModelWithConfig): """ parts: List[Dict[Literal["text"], str]] - file_keys: Optional[List[str]] - files_descriptions: Optional[List[str]] + file_keys: Optional[List[str]] = None + files_descriptions: Optional[List[str]] = None class ORACreateSubmissionEventField(AbstractBaseEventField): @@ -223,7 +224,7 @@ class ORASaveSubmissionEventSavedResponseField(BaseModelWithConfig): """ text: str - file_upload_key: Optional[str] + file_upload_key: Optional[str] = None class ORASaveSubmissionEventField(AbstractBaseEventField): @@ -270,6 +271,6 @@ class ORAUploadFileEventField(BaseModelWithConfig): fileType (str): Consists of the MIME type of the uploaded file. """ - fileName: constr(max_length=255) + fileName: Annotated[str, StringConstraints(max_length=255)] fileSize: int fileType: str diff --git a/src/ralph/models/edx/peer_instruction/fields/events.py b/src/ralph/models/edx/peer_instruction/fields/events.py index 83b8af10e..ad30f6294 100644 --- a/src/ralph/models/edx/peer_instruction/fields/events.py +++ b/src/ralph/models/edx/peer_instruction/fields/events.py @@ -1,8 +1,9 @@ """Peer instruction event field definition.""" -from pydantic import constr +from pydantic import StringConstraints from ...base import AbstractBaseEventField +from typing_extensions import Annotated class PeerInstructionEventField(AbstractBaseEventField): @@ -18,5 +19,5 @@ class PeerInstructionEventField(AbstractBaseEventField): """ answer: int - rationale: constr(max_length=12500) + rationale: Annotated[str, StringConstraints(max_length=12500)] truncated: bool diff --git a/src/ralph/models/edx/problem_interaction/fields/events.py b/src/ralph/models/edx/problem_interaction/fields/events.py index 5f0b868a4..b4cf4119f 100644 --- a/src/ralph/models/edx/problem_interaction/fields/events.py +++ b/src/ralph/models/edx/problem_interaction/fields/events.py @@ -4,9 +4,10 @@ from datetime import datetime from typing import Dict, List, Optional, Union -from pydantic import constr +from pydantic import StringConstraints from ...base import AbstractBaseEventField, BaseModelWithConfig +from typing_extensions import Annotated if sys.version_info >= (3, 8): from typing import Literal @@ -41,13 +42,13 @@ class CorrectMap(BaseModelWithConfig): queuestate (json): see QueueStateField. 
""" - answervariable: Union[Literal[None], None, str] + answervariable: Union[Literal[None], None, str] = None correctness: Union[Literal["correct"], Literal["incorrect"]] - hint: Optional[str] - hintmode: Optional[Union[Literal["on_request"], Literal["always"]]] + hint: Optional[str] = None + hintmode: Optional[Union[Literal["on_request"], Literal["always"]]] = None msg: str - npoints: Optional[int] - queuestate: Optional[QueueState] + npoints: Optional[int] = None + queuestate: Optional[QueueState] = None class State(BaseModelWithConfig): @@ -62,10 +63,10 @@ class State(BaseModelWithConfig): """ correct_map: Dict[ - constr(regex=r"^[a-f0-9]{32}_[0-9]_[0-9]$"), # noqa : F722 + Annotated[str, StringConstraints(pattern=r"^[a-f0-9]{32}_[0-9]_[0-9]$")], # noqa : F722 CorrectMap, ] - done: Optional[bool] + done: Optional[bool] = None input_state: dict seed: int student_answers: dict @@ -135,7 +136,7 @@ class EdxProblemHintFeedbackDisplayedEventField(AbstractBaseEventField): `student_answer` response. Consists either of `single` or `compound` value. """ - choice_all: Optional[List[str]] + choice_all: Optional[List[str]] = None correctness: bool hint_label: str hints: List[dict] @@ -170,23 +171,23 @@ class ProblemCheckEventField(AbstractBaseEventField): """ answers: Dict[ - constr(regex=r"^[a-f0-9]{32}_[0-9]_[0-9]$"), # noqa : F722 + Annotated[str, StringConstraints(pattern=r"^[a-f0-9]{32}_[0-9]_[0-9]$")], # noqa : F722 Union[List[str], str], ] attempts: int correct_map: Dict[ - constr(regex=r"^[a-f0-9]{32}_[0-9]_[0-9]$"), # noqa : F722 + Annotated[str, StringConstraints(pattern=r"^[a-f0-9]{32}_[0-9]_[0-9]$")], # noqa : F722 CorrectMap, ] grade: int max_grade: int - problem_id: constr( - regex=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 + problem_id: Annotated[str, StringConstraints( + pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 - ) + )] state: State submission: Dict[ - constr(regex=r"^[a-f0-9]{32}_[0-9]_[0-9]$"), # noqa : F722 + Annotated[str, StringConstraints(pattern=r"^[a-f0-9]{32}_[0-9]_[0-9]$")], # noqa : F722 SubmissionAnswerField, ] success: Union[Literal["correct"], Literal["incorrect"]] @@ -204,14 +205,14 @@ class ProblemCheckFailEventField(AbstractBaseEventField): """ answers: Dict[ - constr(regex=r"^[a-f0-9]{32}_[0-9]_[0-9]$"), # noqa : F722 + Annotated[str, StringConstraints(pattern=r"^[a-f0-9]{32}_[0-9]_[0-9]$")], # noqa : F722 Union[List[str], str], ] failure: Union[Literal["closed"], Literal["unreset"]] - problem_id: constr( - regex=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 + problem_id: Annotated[str, StringConstraints( + pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 - ) + )] state: State @@ -235,10 +236,10 @@ class ProblemRescoreEventField(AbstractBaseEventField): new_total: int orig_score: int orig_total: int - problem_id: constr( - regex=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 + problem_id: Annotated[str, StringConstraints( + pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 - ) + )] state: State success: Union[Literal["correct"], Literal["incorrect"]] @@ -253,10 +254,10 @@ class ProblemRescoreFailEventField(AbstractBaseEventField): """ failure: Union[Literal["closed"], Literal["unreset"]] - problem_id: constr( - regex=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 + 
problem_id: Annotated[str, StringConstraints( + pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 - ) + )] state: State @@ -293,10 +294,10 @@ class ResetProblemEventField(AbstractBaseEventField): new_state: State old_state: State - problem_id: constr( - regex=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 + problem_id: Annotated[str, StringConstraints( + pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 - ) + )] class ResetProblemFailEventField(AbstractBaseEventField): @@ -310,10 +311,10 @@ class ResetProblemFailEventField(AbstractBaseEventField): failure: Union[Literal["closed"], Literal["not_done"]] old_state: State - problem_id: constr( - regex=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 + problem_id: Annotated[str, StringConstraints( + pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 - ) + )] class SaveProblemFailEventField(AbstractBaseEventField): @@ -329,10 +330,10 @@ class SaveProblemFailEventField(AbstractBaseEventField): answers: Dict[str, Union[int, str, list, dict]] failure: Union[Literal["closed"], Literal["done"]] - problem_id: constr( - regex=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 + problem_id: Annotated[str, StringConstraints( + pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 - ) + )] state: State @@ -347,10 +348,10 @@ class SaveProblemSuccessEventField(AbstractBaseEventField): """ answers: Dict[str, Union[int, str, list, dict]] - problem_id: constr( - regex=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 + problem_id: Annotated[str, StringConstraints( + pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 - ) + )] state: State @@ -361,7 +362,7 @@ class ShowAnswerEventField(AbstractBaseEventField): problem_id (str): Consists of the ID of the problem being shown. """ - problem_id: constr( - regex=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 + problem_id: Annotated[str, StringConstraints( + pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 - ) + )] diff --git a/src/ralph/models/edx/video/fields/events.py b/src/ralph/models/edx/video/fields/events.py index e2533e167..6a9c926c8 100644 --- a/src/ralph/models/edx/video/fields/events.py +++ b/src/ralph/models/edx/video/fields/events.py @@ -3,6 +3,7 @@ import sys from ...base import AbstractBaseEventField +from pydantic import ConfigDict if sys.version_info >= (3, 8): from typing import Literal @@ -19,9 +20,7 @@ class VideoBaseEventField(AbstractBaseEventField): id (str): Consists of the additional videos name if given by the course creators, or the system-generated hash code otherwise. 
""" - - class Config: # pylint: disable=missing-class-docstring # noqa: D106 - extra = "allow" + model_config = ConfigDict(extra="allow") code: str id: str diff --git a/src/ralph/models/edx/video/statements.py b/src/ralph/models/edx/video/statements.py index 19ed0cdc2..f0be1e45d 100644 --- a/src/ralph/models/edx/video/statements.py +++ b/src/ralph/models/edx/video/statements.py @@ -66,7 +66,7 @@ class UIPlayVideo(BaseBrowserModel): PlayVideoEventField, ] event_type: Literal["play_video"] - name: Optional[Literal["play_video", "edx.video.played"]] + name: Optional[Literal["play_video", "edx.video.played"]] = None class UIPauseVideo(BaseBrowserModel): @@ -88,7 +88,7 @@ class UIPauseVideo(BaseBrowserModel): PauseVideoEventField, ] event_type: Literal["pause_video"] - name: Optional[Literal["pause_video", "edx.video.paused"]] + name: Optional[Literal["pause_video", "edx.video.paused"]] = None class UISeekVideo(BaseBrowserModel): @@ -111,7 +111,7 @@ class UISeekVideo(BaseBrowserModel): SeekVideoEventField, ] event_type: Literal["seek_video"] - name: Optional[Literal["seek_video", "edx.video.position.changed"]] + name: Optional[Literal["seek_video", "edx.video.position.changed"]] = None class UIStopVideo(BaseBrowserModel): @@ -133,7 +133,7 @@ class UIStopVideo(BaseBrowserModel): StopVideoEventField, ] event_type: Literal["stop_video"] - name: Optional[Literal["stop_video", "edx.video.stopped"]] + name: Optional[Literal["stop_video", "edx.video.stopped"]] = None class UIHideTranscript(BaseBrowserModel): @@ -200,7 +200,7 @@ class UISpeedChangeVideo(BaseBrowserModel): SpeedChangeVideoEventField, ] event_type: Literal["speed_change_video"] - name: Optional[Literal["speed_change_video"]] + name: Optional[Literal["speed_change_video"]] = None class UIVideoHideCCMenu(BaseBrowserModel): @@ -221,7 +221,7 @@ class UIVideoHideCCMenu(BaseBrowserModel): VideoBaseEventField, ] event_type: Literal["video_hide_cc_menu"] - name: Optional[Literal["video_hide_cc_menu"]] + name: Optional[Literal["video_hide_cc_menu"]] = None class UIVideoShowCCMenu(BaseBrowserModel): @@ -244,4 +244,4 @@ class UIVideoShowCCMenu(BaseBrowserModel): VideoBaseEventField, ] event_type: Literal["video_show_cc_menu"] - name: Optional[Literal["video_show_cc_menu"]] + name: Optional[Literal["video_show_cc_menu"]] = None diff --git a/src/ralph/models/xapi/base/agents.py b/src/ralph/models/xapi/base/agents.py index 9f6ce53f5..1c7761d2b 100644 --- a/src/ralph/models/xapi/base/agents.py +++ b/src/ralph/models/xapi/base/agents.py @@ -43,8 +43,8 @@ class BaseXapiAgentCommonProperties(BaseModelWithConfig, ABC): name (str): Consists of the full name of the Agent. 
""" - objectType: Optional[Literal["Agent"]] - name: Optional[StrictStr] + objectType: Optional[Literal["Agent"]] = None + name: Optional[StrictStr] = None class BaseXapiAgentWithMbox(BaseXapiAgentCommonProperties, BaseXapiMboxIFI): diff --git a/src/ralph/models/xapi/base/attachments.py b/src/ralph/models/xapi/base/attachments.py index 91ffdf93a..7ae7d37cb 100644 --- a/src/ralph/models/xapi/base/attachments.py +++ b/src/ralph/models/xapi/base/attachments.py @@ -23,8 +23,8 @@ class BaseXapiAttachment(BaseModelWithConfig): usageType: IRI display: LanguageMap - description: Optional[LanguageMap] + description: Optional[LanguageMap] = None contentType: str length: int sha2: str - fileUrl: Optional[AnyUrl] + fileUrl: Optional[AnyUrl] = None diff --git a/src/ralph/models/xapi/base/common.py b/src/ralph/models/xapi/base/common.py index 14c27d1e7..719a96162 100644 --- a/src/ralph/models/xapi/base/common.py +++ b/src/ralph/models/xapi/base/common.py @@ -11,6 +11,8 @@ class IRI(str): """Pydantic custom data type validating RFC 3987 IRIs.""" @classmethod + # TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually. + # Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information. def __get_validators__(cls) -> Generator: # noqa: D105 def validate(iri: str) -> Type["IRI"]: """Check whether the provided IRI is a valid RFC 3987 IRI.""" @@ -24,6 +26,8 @@ class LanguageTag(str): """Pydantic custom data type validating RFC 5646 Language tags.""" @classmethod + # TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually. + # Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information. def __get_validators__(cls) -> Generator: # noqa: D105 def validate(tag: str) -> Type["LanguageTag"]: """Check whether the provided tag is a valid RFC 5646 Language tag.""" @@ -41,6 +45,8 @@ class MailtoEmail(str): """Pydantic custom data type validating `mailto:email` format.""" @classmethod + # TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually. + # Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information. def __get_validators__(cls) -> Generator: # noqa: D105 def validate(mailto: str) -> Type["MailtoEmail"]: """Check whether the provided value follows the `mailto:email` format.""" diff --git a/src/ralph/models/xapi/base/contexts.py b/src/ralph/models/xapi/base/contexts.py index febd78754..4e6e8088f 100644 --- a/src/ralph/models/xapi/base/contexts.py +++ b/src/ralph/models/xapi/base/contexts.py @@ -25,10 +25,10 @@ class BaseXapiContextContextActivities(BaseModelWithConfig): properties. """ - parent: Optional[Union[BaseXapiActivity, List[BaseXapiActivity]]] - grouping: Optional[Union[BaseXapiActivity, List[BaseXapiActivity]]] - category: Optional[Union[BaseXapiActivity, List[BaseXapiActivity]]] - other: Optional[Union[BaseXapiActivity, List[BaseXapiActivity]]] + parent: Optional[Union[BaseXapiActivity, List[BaseXapiActivity]]] = None + grouping: Optional[Union[BaseXapiActivity, List[BaseXapiActivity]]] = None + category: Optional[Union[BaseXapiActivity, List[BaseXapiActivity]]] = None + other: Optional[Union[BaseXapiActivity, List[BaseXapiActivity]]] = None class BaseXapiContext(BaseModelWithConfig): @@ -46,12 +46,12 @@ class BaseXapiContext(BaseModelWithConfig): extensions (dict): Consists of a dictionary of other properties as needed. 
""" - registration: Optional[UUID] - instructor: Optional[BaseXapiAgent] - team: Optional[BaseXapiGroup] - contextActivities: Optional[BaseXapiContextContextActivities] - revision: Optional[StrictStr] - platform: Optional[StrictStr] - language: Optional[LanguageTag] - statement: Optional[BaseXapiStatementRef] - extensions: Optional[Dict[IRI, Union[str, int, bool, list, dict, None]]] + registration: Optional[UUID] = None + instructor: Optional[BaseXapiAgent] = None + team: Optional[BaseXapiGroup] = None + contextActivities: Optional[BaseXapiContextContextActivities] = None + revision: Optional[StrictStr] = None + platform: Optional[StrictStr] = None + language: Optional[LanguageTag] = None + statement: Optional[BaseXapiStatementRef] = None + extensions: Optional[Dict[IRI, Union[str, int, bool, list, dict, None]]] = None diff --git a/src/ralph/models/xapi/base/groups.py b/src/ralph/models/xapi/base/groups.py index 73c320058..705036f66 100644 --- a/src/ralph/models/xapi/base/groups.py +++ b/src/ralph/models/xapi/base/groups.py @@ -32,7 +32,7 @@ class BaseXapiGroupCommonProperties(BaseModelWithConfig, ABC): """ objectType: Literal["Group"] - name: Optional[StrictStr] + name: Optional[StrictStr] = None class BaseXapiAnonymousGroup(BaseXapiGroupCommonProperties): diff --git a/src/ralph/models/xapi/base/ifi.py b/src/ralph/models/xapi/base/ifi.py index 149d157b8..e36eac372 100644 --- a/src/ralph/models/xapi/base/ifi.py +++ b/src/ralph/models/xapi/base/ifi.py @@ -1,9 +1,10 @@ """Base xAPI `Inverse Functional Identifier` definitions.""" -from pydantic import AnyUrl, StrictStr, constr +from pydantic import StringConstraints, AnyUrl, StrictStr from ..config import BaseModelWithConfig from .common import IRI, MailtoEmail +from typing_extensions import Annotated class BaseXapiAccount(BaseModelWithConfig): @@ -35,7 +36,7 @@ class BaseXapiMboxSha1SumIFI(BaseModelWithConfig): mbox_sha1sum (str): Consists of the SHA1 hash of the Agent's email address. 
""" - mbox_sha1sum: constr(regex=r"^[0-9a-f]{40}$") # noqa:F722 + mbox_sha1sum: Annotated[str, StringConstraints(pattern=r"^[0-9a-f]{40}$")] # noqa:F722 class BaseXapiOpenIdIFI(BaseModelWithConfig): diff --git a/src/ralph/models/xapi/base/objects.py b/src/ralph/models/xapi/base/objects.py index 74180040f..7b6a5cd89 100644 --- a/src/ralph/models/xapi/base/objects.py +++ b/src/ralph/models/xapi/base/objects.py @@ -36,10 +36,10 @@ class BaseXapiSubStatement(BaseModelWithConfig): verb: BaseXapiVerb object: BaseXapiUnnestedObject objectType: Literal["SubStatement"] - result: Optional[BaseXapiResult] - context: Optional[BaseXapiContext] - timestamp: Optional[datetime] - attachments: Optional[List[BaseXapiAttachment]] + result: Optional[BaseXapiResult] = None + context: Optional[BaseXapiContext] = None + timestamp: Optional[datetime] = None + attachments: Optional[List[BaseXapiAttachment]] = None BaseXapiObject = Union[ diff --git a/src/ralph/models/xapi/base/results.py b/src/ralph/models/xapi/base/results.py index bd3d49ec9..39e3d9d4f 100644 --- a/src/ralph/models/xapi/base/results.py +++ b/src/ralph/models/xapi/base/results.py @@ -4,10 +4,11 @@ from decimal import Decimal from typing import Any, Dict, Optional, Union -from pydantic import StrictBool, StrictStr, conint, root_validator +from pydantic import Field, StrictBool, StrictStr, root_validator from ..config import BaseModelWithConfig from .common import IRI +from typing_extensions import Annotated class BaseXapiResultScore(BaseModelWithConfig): @@ -20,10 +21,10 @@ class BaseXapiResultScore(BaseModelWithConfig): max (Decimal): Consists of the highest possible score. """ - scaled: Optional[conint(ge=-1, le=1)] - raw: Optional[Decimal] - min: Optional[Decimal] - max: Optional[Decimal] + scaled: Optional[Annotated[int, Field(ge=-1, le=1)]] = None + raw: Optional[Decimal] = None + min: Optional[Decimal] = None + max: Optional[Decimal] = None @root_validator @classmethod @@ -58,9 +59,9 @@ class BaseXapiResult(BaseModelWithConfig): extensions (dict): Consists of a dictionary of other properties as needed. """ - score: Optional[BaseXapiResultScore] - success: Optional[StrictBool] - completion: Optional[StrictBool] - response: Optional[StrictStr] - duration: Optional[timedelta] - extensions: Optional[Dict[IRI, Union[str, int, bool, list, dict, None]]] + score: Optional[BaseXapiResultScore] = None + success: Optional[StrictBool] = None + completion: Optional[StrictBool] = None + response: Optional[StrictStr] = None + duration: Optional[timedelta] = None + extensions: Optional[Dict[IRI, Union[str, int, bool, list, dict, None]]] = None diff --git a/src/ralph/models/xapi/base/statements.py b/src/ralph/models/xapi/base/statements.py index 5282b51aa..984ce0dd5 100644 --- a/src/ralph/models/xapi/base/statements.py +++ b/src/ralph/models/xapi/base/statements.py @@ -4,7 +4,7 @@ from typing import Any, List, Optional, Union from uuid import UUID -from pydantic import constr, root_validator +from pydantic import model_validator, StringConstraints from ..config import BaseModelWithConfig from .agents import BaseXapiAgent @@ -14,6 +14,7 @@ from .objects import BaseXapiObject from .results import BaseXapiResult from .verbs import BaseXapiVerb +from typing_extensions import Annotated class BaseXapiStatement(BaseModelWithConfig): @@ -33,19 +34,20 @@ class BaseXapiStatement(BaseModelWithConfig): attachments (list): Consists of a list of attachments. 
""" - id: Optional[UUID] + id: Optional[UUID] = None actor: Union[BaseXapiAgent, BaseXapiGroup] verb: BaseXapiVerb object: BaseXapiObject - result: Optional[BaseXapiResult] - context: Optional[BaseXapiContext] - timestamp: Optional[datetime] - stored: Optional[datetime] - authority: Optional[Union[BaseXapiAgent, BaseXapiGroup]] - version: constr(regex=r"^1\.0\.[0-9]+$") = "1.0.0" # noqa:F722 - attachments: Optional[List[BaseXapiAttachment]] + result: Optional[BaseXapiResult] = None + context: Optional[BaseXapiContext] = None + timestamp: Optional[datetime] = None + stored: Optional[datetime] = None + authority: Optional[Union[BaseXapiAgent, BaseXapiGroup]] = None + version: Annotated[str, StringConstraints(pattern=r"^1\.0\.[0-9]+$")] = "1.0.0" # noqa:F722 + attachments: Optional[List[BaseXapiAttachment]] = None - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod @classmethod def check_absence_of_empty_and_invalid_values(cls, values: Any) -> Any: """Check the model for empty and invalid values. diff --git a/src/ralph/models/xapi/base/verbs.py b/src/ralph/models/xapi/base/verbs.py index aa91a6bea..2b86a738d 100644 --- a/src/ralph/models/xapi/base/verbs.py +++ b/src/ralph/models/xapi/base/verbs.py @@ -15,4 +15,4 @@ class BaseXapiVerb(BaseModelWithConfig): """ id: IRI - display: Optional[LanguageMap] + display: Optional[LanguageMap] = None diff --git a/src/ralph/models/xapi/config.py b/src/ralph/models/xapi/config.py index ee0ba9438..bc8218bda 100644 --- a/src/ralph/models/xapi/config.py +++ b/src/ralph/models/xapi/config.py @@ -1,19 +1,13 @@ """Base xAPI model configuration.""" -from pydantic import BaseModel, Extra +from pydantic import ConfigDict, BaseModel class BaseModelWithConfig(BaseModel): """Pydantic model for base configuration shared among all models.""" - - class Config: # pylint: disable=missing-class-docstring # noqa: D106 - extra = Extra.forbid - min_anystr_length = 1 + model_config = ConfigDict(extra="forbid", str_min_length=1) class BaseExtensionModelWithConfig(BaseModel): """Pydantic model for extension configuration shared among all models.""" - - class Config: # pylint: disable=missing-class-docstring # noqa: D106 - extra = Extra.allow - min_anystr_length = 0 + model_config = ConfigDict(extra="allow", str_min_length=0) diff --git a/src/ralph/models/xapi/lms/objects.py b/src/ralph/models/xapi/lms/objects.py index 7bc701bff..4d8d947b6 100644 --- a/src/ralph/models/xapi/lms/objects.py +++ b/src/ralph/models/xapi/lms/objects.py @@ -34,7 +34,7 @@ class LMSPageObjectDefinitionExtensions(BaseExtensionModelWithConfig): """ type: Optional[Literal["course", "course_list", "user_space"]] = Field( - alias=ACTIVITY_EXTENSIONS_TYPE + None, alias=ACTIVITY_EXTENSIONS_TYPE ) diff --git a/src/ralph/models/xapi/video/results.py b/src/ralph/models/xapi/video/results.py index 7c515ad53..c4c065ecf 100644 --- a/src/ralph/models/xapi/video/results.py +++ b/src/ralph/models/xapi/video/results.py @@ -34,7 +34,7 @@ class VideoResultExtensions(BaseExtensionModelWithConfig): """ time: NonNegativeFloat = Field(alias=RESULT_EXTENSION_TIME) - playedSegments: Optional[str] = Field(alias=CONTEXT_EXTENSION_PLAYED_SEGMENTS) + playedSegments: Optional[str] = Field(None, alias=CONTEXT_EXTENSION_PLAYED_SEGMENTS) class VideoPausedResultExtensions(VideoResultExtensions): @@ -44,7 +44,7 @@ class VideoPausedResultExtensions(VideoResultExtensions): progress (float): Consists of the ratio of media consumed by the actor. 
""" - progress: Optional[NonNegativeFloat] = Field(alias=RESULT_EXTENSION_PROGRESS) + progress: Optional[NonNegativeFloat] = Field(None, alias=RESULT_EXTENSION_PROGRESS) class VideoSeekedResultExtensions(BaseExtensionModelWithConfig): @@ -132,8 +132,8 @@ class VideoCompletedResult(BaseXapiResult): """ extensions: VideoCompletedResultExtensions - completion: Optional[Literal[True]] - duration: Optional[timedelta] + completion: Optional[Literal[True]] = None + duration: Optional[timedelta] = None class VideoTerminatedResult(BaseXapiResult): From fef360f4d375867e4efbcfa5c76ea1d52da6f15d Mon Sep 17 00:00:00 2001 From: lleeoo Date: Thu, 2 Nov 2023 11:24:45 +0100 Subject: [PATCH 60/65] progress with pydantic v2 --- src/ralph/conf.py | 10 ++++----- .../edx/textbook_interaction/fields/events.py | 22 +++++++++---------- src/ralph/models/xapi/base/results.py | 4 ++-- .../models/xapi/base/unnested_objects.py | 16 +++++++------- .../xapi/concepts/verbs/acrossx_profile.py | 2 +- .../verbs/activity_streams_vocabulary.py | 4 ++-- .../xapi/concepts/verbs/adl_vocabulary.py | 6 ++--- .../verbs/navy_common_reference_profile.py | 4 ++-- .../xapi/concepts/verbs/scorm_profile.py | 8 +++---- .../xapi/concepts/verbs/tincan_vocabulary.py | 6 ++--- src/ralph/models/xapi/concepts/verbs/video.py | 6 ++--- .../xapi/concepts/verbs/virtual_classroom.py | 16 +++++++------- tests/fixtures/hypothesis_configuration.py | 20 +++++++++-------- tests/fixtures/hypothesis_strategies.py | 2 +- 14 files changed, 63 insertions(+), 63 deletions(-) diff --git a/src/ralph/conf.py b/src/ralph/conf.py index 9414a39bf..5349743bb 100644 --- a/src/ralph/conf.py +++ b/src/ralph/conf.py @@ -6,7 +6,7 @@ from pathlib import Path from typing import Annotated, List, Optional, Sequence, Union, Tuple -from pydantic import AfterValidator, model_validator, ConfigDict, AnyHttpUrl, AnyUrl, BaseModel +from pydantic import AfterValidator, model_validator, ConfigDict, AnyHttpUrl, AnyUrl, BaseModel, parse_obj_as from ralph.exceptions import ConfigurationException @@ -242,7 +242,7 @@ class Config(BaseSettingsConfig): }, } PARSERS: ParserSettings = ParserSettings() - RUNSERVER_AUTH_BACKENDS: AuthBackends = AuthBackends([AuthBackend.BASIC]) + RUNSERVER_AUTH_BACKENDS: AuthBackends = parse_obj_as(AuthBackends, 'Basic') RUNSERVER_AUTH_OIDC_AUDIENCE: str = None RUNSERVER_AUTH_OIDC_ISSUER_URI: AnyHttpUrl = None RUNSERVER_BACKEND: Literal[ @@ -255,7 +255,7 @@ class Config(BaseSettingsConfig): LRS_RESTRICT_BY_AUTHORITY: bool = False LRS_RESTRICT_BY_SCOPES: bool = False SENTRY_CLI_TRACES_SAMPLE_RATE: float = 1.0 - SENTRY_DSN: str = None + SENTRY_DSN: Optional[str] = None SENTRY_IGNORE_HEALTH_CHECKS: bool = False SENTRY_LRS_TRACES_SAMPLE_RATE: float = 1.0 XAPI_FORWARDINGS: List[XapiForwardingConfigurationSettings] = [] @@ -274,9 +274,7 @@ def LOCALE_ENCODING(self) -> str: # pylint: disable=invalid-name @classmethod def check_restriction_compatibility(cls, values): """Raise an error if scopes are being used without authority restriction.""" - if values.get("LRS_RESTRICT_BY_SCOPES") and not values.get( - "LRS_RESTRICT_BY_AUTHORITY" - ): + if values.LRS_RESTRICT_BY_SCOPES and not values.LRS_RESTRICT_BY_AUTHORITY: raise ConfigurationException( "LRS_RESTRICT_BY_AUTHORITY must be set to True if using " "LRS_RESTRICT_BY_SCOPES=True" diff --git a/src/ralph/models/edx/textbook_interaction/fields/events.py b/src/ralph/models/edx/textbook_interaction/fields/events.py index 6cabf6f07..01a56fe46 100644 --- a/src/ralph/models/edx/textbook_interaction/fields/events.py +++ 
b/src/ralph/models/edx/textbook_interaction/fields/events.py @@ -1,9 +1,9 @@ """Textbook interaction event fields definitions.""" import sys -from typing import Optional, Union +from typing import Annotated, Optional, Union -from pydantic import Field, constr +from pydantic import Field, StringConstraints from ...base import AbstractBaseEventField @@ -24,11 +24,11 @@ class TextbookInteractionBaseEventField(AbstractBaseEventField): """ page: int - chapter: constr( - regex=( + chapter: Annotated[str, StringConstraints( + pattern=( r"^\/asset-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+type@asset\+block.+$" # noqa ) - ) + )] class TextbookPdfThumbnailsToggledEventField(TextbookInteractionBaseEventField): @@ -74,11 +74,11 @@ class TextbookPdfChapterNavigatedEventField(AbstractBaseEventField): """ name: Literal["textbook.pdf.chapter.navigated"] - chapter: constr( - regex=( + chapter: Annotated[str, StringConstraints( + pattern=( r"^\/asset-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+type@asset\+block.+$" # noqa ) - ) + )] chapter_title: str @@ -263,11 +263,11 @@ class BookEventField(AbstractBaseEventField): clicked or `nextpage` value when the previous page button is clicked. """ - chapter: constr( - regex=( + chapter: Annotated[str, StringConstraints( + pattern=( r"^\/asset-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+type@asset\+block.+$" # noqa ) - ) + )] name: Union[ Literal["textbook.pdf.page.loaded"], Literal["textbook.pdf.page.navigatednext"] ] diff --git a/src/ralph/models/xapi/base/results.py b/src/ralph/models/xapi/base/results.py index 39e3d9d4f..13c7824ec 100644 --- a/src/ralph/models/xapi/base/results.py +++ b/src/ralph/models/xapi/base/results.py @@ -4,7 +4,7 @@ from decimal import Decimal from typing import Any, Dict, Optional, Union -from pydantic import Field, StrictBool, StrictStr, root_validator +from pydantic import Field, StrictBool, StrictStr, model_validator from ..config import BaseModelWithConfig from .common import IRI @@ -26,7 +26,7 @@ class BaseXapiResultScore(BaseModelWithConfig): min: Optional[Decimal] = None max: Optional[Decimal] = None - @root_validator + @model_validator(mode='after') # TODO: needs review @classmethod def check_raw_min_max_relation(cls, values: Any) -> Any: """Check the relationship `min < raw < max`.""" diff --git a/src/ralph/models/xapi/base/unnested_objects.py b/src/ralph/models/xapi/base/unnested_objects.py index d0f2a0bd4..1ed0de2a9 100644 --- a/src/ralph/models/xapi/base/unnested_objects.py +++ b/src/ralph/models/xapi/base/unnested_objects.py @@ -1,10 +1,10 @@ """Base xAPI `Object` definitions (1).""" import sys -from typing import Any, Dict, List, Optional, Union +from typing import Annotated, Any, Dict, List, Optional, Union from uuid import UUID -from pydantic import AnyUrl, StrictStr, constr, validator +from pydantic import AnyUrl, StrictStr, StringConstraints, validator from ..config import BaseModelWithConfig from .common import IRI, LanguageMap @@ -26,11 +26,11 @@ class BaseXapiActivityDefinition(BaseModelWithConfig): extensions (dict): Consists of a dictionary of other properties as needed. 
""" - name: Optional[LanguageMap] - description: Optional[LanguageMap] - type: Optional[IRI] - moreInfo: Optional[AnyUrl] - extensions: Optional[Dict[IRI, Union[str, int, bool, list, dict, None]]] + name: Optional[LanguageMap] = None + description: Optional[LanguageMap] = None + type: Optional[IRI] = None + moreInfo: Optional[AnyUrl] = None + extensions: Optional[Dict[IRI, Union[str, int, bool, list, dict, None]]] = None class BaseXapiInteractionComponent(BaseModelWithConfig): @@ -41,7 +41,7 @@ class BaseXapiInteractionComponent(BaseModelWithConfig): description (LanguageMap): Consists of the description of the interaction. """ - id: constr(regex=r"^[^\s]+$") # noqa:F722 + id: Annotated[str, StringConstraints(pattern=r"^[^\s]+$")] # #noqa:F722 description: Optional[LanguageMap] diff --git a/src/ralph/models/xapi/concepts/verbs/acrossx_profile.py b/src/ralph/models/xapi/concepts/verbs/acrossx_profile.py index f0d4d5e5b..317aafd75 100644 --- a/src/ralph/models/xapi/concepts/verbs/acrossx_profile.py +++ b/src/ralph/models/xapi/concepts/verbs/acrossx_profile.py @@ -23,4 +23,4 @@ class PostedVerb(BaseXapiVerb): id: Literal[ "https://w3id.org/xapi/acrossx/verbs/posted" ] = "https://w3id.org/xapi/acrossx/verbs/posted" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["posted"]]] + display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["posted"]]] = None diff --git a/src/ralph/models/xapi/concepts/verbs/activity_streams_vocabulary.py b/src/ralph/models/xapi/concepts/verbs/activity_streams_vocabulary.py index 10b6cef1c..79ab6fd8d 100644 --- a/src/ralph/models/xapi/concepts/verbs/activity_streams_vocabulary.py +++ b/src/ralph/models/xapi/concepts/verbs/activity_streams_vocabulary.py @@ -21,7 +21,7 @@ class JoinVerb(BaseXapiVerb): """ id: Literal["http://activitystrea.ms/join"] = "http://activitystrea.ms/join" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["joined"]]] + display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["joined"]]] = None class LeaveVerb(BaseXapiVerb): @@ -33,4 +33,4 @@ class LeaveVerb(BaseXapiVerb): """ id: Literal["http://activitystrea.ms/leave"] = "http://activitystrea.ms/leave" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["left"]]] + display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["left"]]] = None diff --git a/src/ralph/models/xapi/concepts/verbs/adl_vocabulary.py b/src/ralph/models/xapi/concepts/verbs/adl_vocabulary.py index da1b6804c..3c3dcb9f7 100644 --- a/src/ralph/models/xapi/concepts/verbs/adl_vocabulary.py +++ b/src/ralph/models/xapi/concepts/verbs/adl_vocabulary.py @@ -23,7 +23,7 @@ class AskedVerb(BaseXapiVerb): id: Literal[ "http://adlnet.gov/expapi/verbs/asked" ] = "http://adlnet.gov/expapi/verbs/asked" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["asked"]]] + display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["asked"]]] = None class AnsweredVerb(BaseXapiVerb): @@ -37,7 +37,7 @@ class AnsweredVerb(BaseXapiVerb): id: Literal[ "http://adlnet.gov/expapi/verbs/answered" ] = "http://adlnet.gov/expapi/verbs/answered" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["answered"]]] + display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["answered"]]] = None class RegisteredVerb(BaseXapiVerb): @@ -51,4 +51,4 @@ class RegisteredVerb(BaseXapiVerb): id: Literal[ "http://adlnet.gov/expapi/verbs/registered" ] = "http://adlnet.gov/expapi/verbs/registered" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["registered"]]] + display: 
Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["registered"]]] = None diff --git a/src/ralph/models/xapi/concepts/verbs/navy_common_reference_profile.py b/src/ralph/models/xapi/concepts/verbs/navy_common_reference_profile.py index 53f027ceb..400f36dc1 100644 --- a/src/ralph/models/xapi/concepts/verbs/navy_common_reference_profile.py +++ b/src/ralph/models/xapi/concepts/verbs/navy_common_reference_profile.py @@ -23,7 +23,7 @@ class AccessedVerb(BaseXapiVerb): id: Literal[ "https://w3id.org/xapi/netc/verbs/accessed" ] = "https://w3id.org/xapi/netc/verbs/accessed" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["accessed"]]] + display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["accessed"]]] = None class UploadedVerb(BaseXapiVerb): @@ -37,4 +37,4 @@ class UploadedVerb(BaseXapiVerb): id: Literal[ "https://w3id.org/xapi/netc/verbs/uploaded" ] = "https://w3id.org/xapi/netc/verbs/uploaded" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["uploaded"]]] + display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["uploaded"]]] = None diff --git a/src/ralph/models/xapi/concepts/verbs/scorm_profile.py b/src/ralph/models/xapi/concepts/verbs/scorm_profile.py index 12dbd1b11..0c377064a 100644 --- a/src/ralph/models/xapi/concepts/verbs/scorm_profile.py +++ b/src/ralph/models/xapi/concepts/verbs/scorm_profile.py @@ -23,7 +23,7 @@ class CompletedVerb(BaseXapiVerb): id: Literal[ "http://adlnet.gov/expapi/verbs/completed" ] = "http://adlnet.gov/expapi/verbs/completed" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["completed"]]] + display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["completed"]]] = None class InitializedVerb(BaseXapiVerb): @@ -37,7 +37,7 @@ class InitializedVerb(BaseXapiVerb): id: Literal[ "http://adlnet.gov/expapi/verbs/initialized" ] = "http://adlnet.gov/expapi/verbs/initialized" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["initialized"]]] + display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["initialized"]]] = None class InteractedVerb(BaseXapiVerb): @@ -51,7 +51,7 @@ class InteractedVerb(BaseXapiVerb): id: Literal[ "http://adlnet.gov/expapi/verbs/interacted" ] = "http://adlnet.gov/expapi/verbs/interacted" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["interacted"]]] + display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["interacted"]]] = None class TerminatedVerb(BaseXapiVerb): @@ -65,4 +65,4 @@ class TerminatedVerb(BaseXapiVerb): id: Literal[ "http://adlnet.gov/expapi/verbs/terminated" ] = "http://adlnet.gov/expapi/verbs/terminated" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["terminated"]]] + display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["terminated"]]] = None diff --git a/src/ralph/models/xapi/concepts/verbs/tincan_vocabulary.py b/src/ralph/models/xapi/concepts/verbs/tincan_vocabulary.py index d32c25a81..defcbc266 100644 --- a/src/ralph/models/xapi/concepts/verbs/tincan_vocabulary.py +++ b/src/ralph/models/xapi/concepts/verbs/tincan_vocabulary.py @@ -23,7 +23,7 @@ class ViewedVerb(BaseXapiVerb): id: Literal[ "http://id.tincanapi.com/verb/viewed" ] = "http://id.tincanapi.com/verb/viewed" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["viewed"]]] + display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["viewed"]]] = None class DownloadedVerb(BaseXapiVerb): @@ -37,7 +37,7 @@ class DownloadedVerb(BaseXapiVerb): id: Literal[ "http://id.tincanapi.com/verb/downloaded" ] = "http://id.tincanapi.com/verb/downloaded" - display: 
Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["downloaded"]]] + display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["downloaded"]]] = None # TODO: remove literal for LANG_EN_US_DISPLAY ? class UnregisteredVerb(BaseXapiVerb): @@ -51,4 +51,4 @@ class UnregisteredVerb(BaseXapiVerb): id: Literal[ "http://id.tincanapi.com/verb/unregistered" ] = "http://id.tincanapi.com/verb/unregistered" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["unregistered"]]] + display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["unregistered"]]] = None diff --git a/src/ralph/models/xapi/concepts/verbs/video.py b/src/ralph/models/xapi/concepts/verbs/video.py index d2e83d0b4..4d8129dcd 100644 --- a/src/ralph/models/xapi/concepts/verbs/video.py +++ b/src/ralph/models/xapi/concepts/verbs/video.py @@ -23,7 +23,7 @@ class PlayedVerb(BaseXapiVerb): id: Literal[ "https://w3id.org/xapi/video/verbs/played" ] = "https://w3id.org/xapi/video/verbs/played" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["played"]]] + display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["played"]]] = None class PausedVerb(BaseXapiVerb): @@ -37,7 +37,7 @@ class PausedVerb(BaseXapiVerb): id: Literal[ "https://w3id.org/xapi/video/verbs/paused" ] = "https://w3id.org/xapi/video/verbs/paused" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["paused"]]] + display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["paused"]]] = None class SeekedVerb(BaseXapiVerb): @@ -51,4 +51,4 @@ class SeekedVerb(BaseXapiVerb): id: Literal[ "https://w3id.org/xapi/video/verbs/seeked" ] = "https://w3id.org/xapi/video/verbs/seeked" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["seeked"]]] + display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["seeked"]]] = None diff --git a/src/ralph/models/xapi/concepts/verbs/virtual_classroom.py b/src/ralph/models/xapi/concepts/verbs/virtual_classroom.py index fcd2320a6..9aaded225 100644 --- a/src/ralph/models/xapi/concepts/verbs/virtual_classroom.py +++ b/src/ralph/models/xapi/concepts/verbs/virtual_classroom.py @@ -24,7 +24,7 @@ class MutedVerb(BaseXapiVerb): id: Literal[ "https://w3id.org/xapi/virtual-classroom/verbs/muted" ] = "https://w3id.org/xapi/virtual-classroom/verbs/muted" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["muted"]]] + display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["muted"]]] = None class UnmutedVerb(BaseXapiVerb): @@ -39,7 +39,7 @@ class UnmutedVerb(BaseXapiVerb): id: Literal[ "https://w3id.org/xapi/virtual-classroom/verbs/unmuted" ] = "https://w3id.org/xapi/virtual-classroom/verbs/unmuted" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["unmuted"]]] + display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["unmuted"]]] = None class StartedCameraVerb(BaseXapiVerb): @@ -54,7 +54,7 @@ class StartedCameraVerb(BaseXapiVerb): id: Literal[ "https://w3id.org/xapi/virtual-classroom/verbs/started-camera" ] = "https://w3id.org/xapi/virtual-classroom/verbs/started-camera" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["started camera"]]] + display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["started camera"]]] = None class StoppedCameraVerb(BaseXapiVerb): @@ -69,7 +69,7 @@ class StoppedCameraVerb(BaseXapiVerb): id: Literal[ "https://w3id.org/xapi/virtual-classroom/verbs/stopped-camera" ] = "https://w3id.org/xapi/virtual-classroom/verbs/stopped-camera" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["stopped camera"]]] + display: 
Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["stopped camera"]]] = None
 
 
 class SharedScreenVerb(BaseXapiVerb):
@@ -84,7 +84,7 @@ class SharedScreenVerb(BaseXapiVerb):
     id: Literal[
         "https://w3id.org/xapi/virtual-classroom/verbs/shared-screen"
     ] = "https://w3id.org/xapi/virtual-classroom/verbs/shared-screen"
-    display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["shared screen"]]]
+    display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["shared screen"]]] = None
 
 
 class UnsharedScreenVerb(BaseXapiVerb):
@@ -99,7 +99,7 @@ class UnsharedScreenVerb(BaseXapiVerb):
     id: Literal[
         "https://w3id.org/xapi/virtual-classroom/verbs/unshared-screen"
     ] = "https://w3id.org/xapi/virtual-classroom/verbs/unshared-screen"
-    display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["unshared screen"]]]
+    display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["unshared screen"]]] = None
 
 
 class RaisedHandVerb(BaseXapiVerb):
@@ -114,7 +114,7 @@ class RaisedHandVerb(BaseXapiVerb):
     id: Literal[
         "https://w3id.org/xapi/virtual-classroom/verbs/raised-hand"
     ] = "https://w3id.org/xapi/virtual-classroom/verbs/raised-hand"
-    display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["raised hand"]]]
+    display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["raised hand"]]] = None
 
 
 class LoweredHandVerb(BaseXapiVerb):
@@ -129,4 +129,4 @@ class LoweredHandVerb(BaseXapiVerb):
     id: Literal[
         "https://w3id.org/xapi/virtual-classroom/verbs/lowered-hand"
     ] = "https://w3id.org/xapi/virtual-classroom/verbs/lowered-hand"
-    display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["lowered hand"]]]
+    display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["lowered hand"]]] = None
diff --git a/tests/fixtures/hypothesis_configuration.py b/tests/fixtures/hypothesis_configuration.py
index f7c7844b0..7b295cac1 100644
--- a/tests/fixtures/hypothesis_configuration.py
+++ b/tests/fixtures/hypothesis_configuration.py
@@ -11,12 +11,14 @@
 settings.register_profile("development", max_examples=1)
 settings.load_profile("development")
 
-st.register_type_strategy(str, st.text(min_size=1))
-st.register_type_strategy(StrictStr, st.text(min_size=1))
-st.register_type_strategy(AnyUrl, provisional.urls())
-st.register_type_strategy(AnyHttpUrl, provisional.urls())
-st.register_type_strategy(IRI, provisional.urls())
-st.register_type_strategy(
-    MailtoEmail, st.builds(operator.add, st.just("mailto:"), st.emails())
-)
-st.register_type_strategy(LanguageTag, st.just("en-US"))
+
+# TODO: uncomment and fix below
+# st.register_type_strategy(str, st.text(min_size=1))
+# st.register_type_strategy(StrictStr, st.text(min_size=1))
+# st.register_type_strategy(AnyUrl, provisional.urls())
+# st.register_type_strategy(AnyHttpUrl, provisional.urls())
+# st.register_type_strategy(IRI, provisional.urls())
+# st.register_type_strategy(
+#     MailtoEmail, st.builds(operator.add, st.just("mailto:"), st.emails())
+# )
+# st.register_type_strategy(LanguageTag, st.just("en-US"))
diff --git a/tests/fixtures/hypothesis_strategies.py b/tests/fixtures/hypothesis_strategies.py
index fb5e9f30e..40646eb3d 100644
--- a/tests/fixtures/hypothesis_strategies.py
+++ b/tests/fixtures/hypothesis_strategies.py
@@ -83,7 +83,7 @@ def custom_builds(
         arg = kwargs.get(name, None)
         if arg is False:
             continue
-        is_required = field.required or (arg is not None and _overwrite_default)
+        is_required = field.is_required() or (arg is not None and _overwrite_default)
         required_optional = required if is_required or arg is not None else optional
         field_strategy = 
get_strategy_from(field.outer_type_) if arg is None else arg required_optional[field.alias] = field_strategy From 739b5417c18129c264568fc2b42710cdb4e607dc Mon Sep 17 00:00:00 2001 From: lleeoo Date: Thu, 2 Nov 2023 18:03:36 +0100 Subject: [PATCH 61/65] wip --- src/ralph/api/auth/basic.py | 8 +- src/ralph/api/auth/oidc.py | 2 +- src/ralph/api/models.py | 3 +- src/ralph/backends/conf.py | 16 ++- src/ralph/backends/data/base.py | 28 ++-- src/ralph/backends/data/clickhouse.py | 20 +-- src/ralph/backends/data/es.py | 15 ++- src/ralph/backends/data/fs.py | 14 +- src/ralph/backends/data/ldp.py | 13 +- src/ralph/backends/data/mongo.py | 38 ++++-- src/ralph/backends/data/s3.py | 25 ++-- src/ralph/backends/data/swift.py | 27 ++-- src/ralph/backends/http/async_lrs.py | 13 +- src/ralph/backends/http/base.py | 28 ++-- src/ralph/backends/stream/base.py | 23 ++-- src/ralph/backends/stream/ws.py | 10 +- src/ralph/conf.py | 71 +++++++--- src/ralph/models/edx/base.py | 24 +++- src/ralph/models/edx/browser.py | 8 +- .../models/edx/navigational/fields/events.py | 17 ++- .../models/edx/navigational/statements.py | 2 +- .../open_response_assessment/fields/events.py | 36 ++++-- .../edx/peer_instruction/fields/events.py | 2 +- .../edx/problem_interaction/fields/events.py | 121 ++++++++++++------ .../edx/textbook_interaction/fields/events.py | 39 +++--- src/ralph/models/edx/video/fields/events.py | 4 +- src/ralph/models/xapi/base/ifi.py | 8 +- src/ralph/models/xapi/base/results.py | 4 +- src/ralph/models/xapi/base/statements.py | 8 +- .../xapi/concepts/verbs/tincan_vocabulary.py | 4 +- .../xapi/concepts/verbs/virtual_classroom.py | 16 ++- src/ralph/models/xapi/config.py | 4 +- tests/fixtures/hypothesis_strategies.py | 4 +- tests/test_cli.py | 8 +- 34 files changed, 432 insertions(+), 231 deletions(-) diff --git a/src/ralph/api/auth/basic.py b/src/ralph/api/auth/basic.py index 2bc4a82ea..0c503d25c 100644 --- a/src/ralph/api/auth/basic.py +++ b/src/ralph/api/auth/basic.py @@ -10,7 +10,7 @@ from cachetools import TTLCache, cached from fastapi import Depends from fastapi.security import HTTPBasic, HTTPBasicCredentials -from pydantic import BaseModel, root_validator +from pydantic import BaseModel, RootModel, model_validator from starlette.authentication import AuthenticationError from ralph.api.auth.user import AuthenticatedUser @@ -40,7 +40,7 @@ class UserCredentials(AuthenticatedUser): username: str -class ServerUsersCredentials(BaseModel): +class ServerUsersCredentials(RootModel[List[UserCredentials]]): """Custom root pydantic model. Describe expected list of all server users credentials as stored in @@ -51,8 +51,6 @@ class ServerUsersCredentials(BaseModel): list of all server users credentials. 
""" - __root__: List[UserCredentials] - def __add__(self, other) -> Any: # noqa: D105 return ServerUsersCredentials.parse_obj(self.__root__ + other.__root__) @@ -65,7 +63,7 @@ def __len__(self) -> int: # noqa: D105 def __iter__(self) -> Iterator[UserCredentials]: # noqa: D105 return iter(self.__root__) - @root_validator + @model_validator(mode="after") @classmethod def ensure_unique_username(cls, values: Any) -> Any: """Every username should be unique among registered users.""" diff --git a/src/ralph/api/auth/oidc.py b/src/ralph/api/auth/oidc.py index d41b5687f..79a8a59f3 100644 --- a/src/ralph/api/auth/oidc.py +++ b/src/ralph/api/auth/oidc.py @@ -9,7 +9,7 @@ from fastapi.security import HTTPBearer, OpenIdConnect from jose import ExpiredSignatureError, JWTError, jwt from jose.exceptions import JWTClaimsError -from pydantic import ConfigDict, AnyUrl, BaseModel +from pydantic import AnyUrl, BaseModel, ConfigDict from typing_extensions import Annotated from ralph.api.auth.user import AuthenticatedUser, UserScopes diff --git a/src/ralph/api/models.py b/src/ralph/api/models.py index 3bf051e2b..9e2a64504 100644 --- a/src/ralph/api/models.py +++ b/src/ralph/api/models.py @@ -6,7 +6,7 @@ from typing import Optional, Union from uuid import UUID -from pydantic import ConfigDict, AnyUrl, BaseModel +from pydantic import AnyUrl, BaseModel, ConfigDict from ..models.xapi.base.agents import BaseXapiAgent from ..models.xapi.base.groups import BaseXapiGroup @@ -28,6 +28,7 @@ class BaseModelWithLaxConfig(BaseModel): Common base lax model to perform light input validation as we receive statements through the API. """ + model_config = ConfigDict(extra="allow") diff --git a/src/ralph/backends/conf.py b/src/ralph/backends/conf.py index 14f1fa0dd..27d4eedcf 100644 --- a/src/ralph/backends/conf.py +++ b/src/ralph/backends/conf.py @@ -1,6 +1,7 @@ """Configurations for Ralph backends.""" from pydantic import BaseModel +from pydantic_settings import BaseSettings, SettingsConfigDict from ralph.backends.data.clickhouse import ClickHouseDataBackendSettings from ralph.backends.data.es import ESDataBackendSettings @@ -13,8 +14,7 @@ from ralph.backends.lrs.clickhouse import ClickHouseLRSBackendSettings from ralph.backends.lrs.fs import FSLRSBackendSettings from ralph.backends.stream.ws import WSStreamBackendSettings -from ralph.conf import BaseSettingsConfig, core_settings -from pydantic_settings import BaseSettings +from ralph.conf import BASE_SETTINGS_CONFIG, core_settings # Active Data backend Settings. @@ -82,11 +82,15 @@ class BackendSettings(BaseSettings): # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually. # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. 
- class Config(BaseSettingsConfig): - """Pydantic Configuration.""" + # class Config(BaseSettingsConfig): + # """Pydantic Configuration.""" - env_file = ".env" - env_file_encoding = core_settings.LOCALE_ENCODING + # env_file = ".env" + # env_file_encoding = core_settings.LOCALE_ENCODING + + model_config = BASE_SETTINGS_CONFIG | SettingsConfigDict( + env_file=".env", env_file_encoding=core_settings.LOCALE_ENCODING + ) BACKENDS: Backends = Backends() diff --git a/src/ralph/backends/data/base.py b/src/ralph/backends/data/base.py index 96f56f5fd..c2cc881eb 100644 --- a/src/ralph/backends/data/base.py +++ b/src/ralph/backends/data/base.py @@ -8,10 +8,10 @@ from typing import Iterable, Iterator, Optional, Union from pydantic import BaseModel, ValidationError +from pydantic_settings import BaseSettings, SettingsConfigDict -from ralph.conf import BaseSettingsConfig, core_settings +from ralph.conf import BASE_SETTINGS_CONFIG, core_settings from ralph.exceptions import BackendParameterException -from pydantic_settings import BaseSettings, SettingsConfigDict logger = logging.getLogger(__name__) @@ -21,17 +21,29 @@ class BaseDataBackendSettings(BaseSettings): # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually. # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. - class Config(BaseSettingsConfig): - """Pydantic Configuration.""" + # class Config(BaseSettingsConfig): + # """Pydantic Configuration.""" - env_prefix = "RALPH_BACKENDS__DATA__" - env_file = ".env" - env_file_encoding = core_settings.LOCALE_ENCODING + # env_prefix = "RALPH_BACKENDS__DATA__" + # env_file = ".env" + # env_file_encoding = core_settings.LOCALE_ENCODING + + model_config = BASE_SETTINGS_CONFIG | SettingsConfigDict( + env_prefix="RALPH_BACKENDS__DATA__", + env_file=".env", + env_file_encoding=core_settings.LOCALE_ENCODING, + ) class BaseQuery(BaseModel): """Base query model.""" - model_config = SettingsConfigDict(env_prefix="RALPH_BACKENDS__DATA__", env_file=".env", env_file_encoding=core_settings.LOCALE_ENCODING, extra="forbid") + + model_config = SettingsConfigDict( + env_prefix="RALPH_BACKENDS__DATA__", + env_file=".env", + env_file_encoding=core_settings.LOCALE_ENCODING, + extra="forbid", + ) query_string: Union[str, None] = None diff --git a/src/ralph/backends/data/clickhouse.py b/src/ralph/backends/data/clickhouse.py index c9ffafc1d..9829106c1 100755 --- a/src/ralph/backends/data/clickhouse.py +++ b/src/ralph/backends/data/clickhouse.py @@ -20,7 +20,9 @@ import clickhouse_connect from clickhouse_connect.driver.exceptions import ClickHouseError -from pydantic import Field, BaseModel, Json, ValidationError +from pydantic import BaseModel, Field, Json, ValidationError +from pydantic_settings import SettingsConfigDict +from typing_extensions import Annotated from ralph.backends.data.base import ( BaseDataBackend, @@ -30,9 +32,8 @@ DataBackendStatus, enforce_query_checks, ) -from ralph.conf import BaseSettingsConfig, ClientOptions +from ralph.conf import BASE_SETTINGS_CONFIG, ClientOptions from ralph.exceptions import BackendException, BackendParameterException -from typing_extensions import Annotated logger = logging.getLogger(__name__) @@ -78,17 +79,20 @@ class ClickHouseDataBackendSettings(BaseDataBackendSettings): # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually. # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. 
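The recurring `model_config = BASE_SETTINGS_CONFIG | SettingsConfigDict(...)` assignment that replaces the inherited `Config` classes works because `SettingsConfigDict` is a `TypedDict`, hence a plain `dict` at runtime: the `|` union merges the shared defaults with backend-specific keys, the right-hand side winning on conflicts. A hedged sketch of the pattern (the shared keys shown are assumptions for the example):

    from pydantic_settings import BaseSettings, SettingsConfigDict

    # Stand-in for the shared BASE_SETTINGS_CONFIG defined in ralph.conf.
    BASE_SETTINGS_CONFIG = SettingsConfigDict(env_file=".env", extra="ignore")


    class ExampleBackendSettings(BaseSettings):
        """Example backend settings."""

        # Plain dict union: backend-specific keys override shared ones.
        model_config = BASE_SETTINGS_CONFIG | SettingsConfigDict(
            env_prefix="RALPH_BACKENDS__DATA__EXAMPLE__",
        )

        HOST: str = "localhost"
        PORT: int = 8123


    # Setting RALPH_BACKENDS__DATA__EXAMPLE__PORT=9000 in the environment
    # (or in .env) would override the PORT default.
    print(ExampleBackendSettings().PORT)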
- class Config(BaseSettingsConfig): - """Pydantic Configuration.""" + # class Config(BaseSettingsConfig): + # """Pydantic Configuration.""" - env_prefix = "RALPH_BACKENDS__DATA__CLICKHOUSE__" + # env_prefix = "RALPH_BACKENDS__DATA__CLICKHOUSE__" + model_config = BASE_SETTINGS_CONFIG | SettingsConfigDict( + env_prefix="RALPH_BACKENDS__DATA__CLICKHOUSE__" + ) HOST: str = "localhost" PORT: int = 8123 DATABASE: str = "xapi" EVENT_TABLE_NAME: str = "xapi_events_all" - USERNAME: str = None - PASSWORD: str = None + USERNAME: Optional[str] = None + PASSWORD: Optional[str] = None CLIENT_OPTIONS: ClickHouseClientOptions = ClickHouseClientOptions() DEFAULT_CHUNK_SIZE: int = 500 LOCALE_ENCODING: str = "utf8" diff --git a/src/ralph/backends/data/es.py b/src/ralph/backends/data/es.py index 9c4556d8a..7bb92ff81 100644 --- a/src/ralph/backends/data/es.py +++ b/src/ralph/backends/data/es.py @@ -9,6 +9,7 @@ from elasticsearch import ApiError, Elasticsearch, TransportError from elasticsearch.helpers import BulkIndexError, streaming_bulk from pydantic import BaseModel +from pydantic_settings import SettingsConfigDict from ralph.backends.data.base import ( BaseDataBackend, @@ -18,7 +19,7 @@ DataBackendStatus, enforce_query_checks, ) -from ralph.conf import BaseSettingsConfig, ClientOptions, CommaSeparatedTuple +from ralph.conf import BASE_SETTINGS_CONFIG, ClientOptions, CommaSeparatedTuple from ralph.exceptions import BackendException, BackendParameterException from ralph.utils import parse_bytes_to_dict, read_raw @@ -54,10 +55,14 @@ class ESDataBackendSettings(BaseDataBackendSettings): # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually. # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. - class Config(BaseSettingsConfig): - """Pydantic Configuration.""" + # class Config(BaseSettingsConfig): + # """Pydantic Configuration.""" - env_prefix = "RALPH_BACKENDS__DATA__ES__" + # env_prefix = "RALPH_BACKENDS__DATA__ES__" + + model_config = BASE_SETTINGS_CONFIG | SettingsConfigDict( + env_prefix="RALPH_BACKENDS__DATA__ES__" + ) ALLOW_YELLOW_STATUS: bool = False CLIENT_OPTIONS: ESClientOptions = ESClientOptions() @@ -66,7 +71,7 @@ class Config(BaseSettingsConfig): HOSTS: CommaSeparatedTuple = ("http://localhost:9200",) LOCALE_ENCODING: str = "utf8" POINT_IN_TIME_KEEP_ALIVE: str = "1m" - REFRESH_AFTER_WRITE: Union[Literal["false", "true", "wait_for"], bool, str, None] + REFRESH_AFTER_WRITE: Union[Literal["false", "true", "wait_for"], bool, str, None] = False # TODO: check that this is the good default class ESQueryPit(BaseModel): diff --git a/src/ralph/backends/data/fs.py b/src/ralph/backends/data/fs.py index 951908136..f0cd8f841 100644 --- a/src/ralph/backends/data/fs.py +++ b/src/ralph/backends/data/fs.py @@ -10,6 +10,8 @@ from typing import IO, Iterable, Iterator, Optional, Union from uuid import uuid4 +from pydantic_settings import SettingsConfigDict + from ralph.backends.data.base import ( BaseDataBackend, BaseDataBackendSettings, @@ -19,7 +21,7 @@ enforce_query_checks, ) from ralph.backends.mixins import HistoryMixin -from ralph.conf import BaseSettingsConfig +from ralph.conf import BASE_SETTINGS_CONFIG from ralph.exceptions import BackendException, BackendParameterException from ralph.utils import now @@ -40,10 +42,14 @@ class FSDataBackendSettings(BaseDataBackendSettings): # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually. 
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. - class Config(BaseSettingsConfig): - """Pydantic Configuration.""" + # class Config(BaseSettingsConfig): + # """Pydantic Configuration.""" + + # env_prefix = "RALPH_BACKENDS__DATA__FS__" - env_prefix = "RALPH_BACKENDS__DATA__FS__" + model_config = BASE_SETTINGS_CONFIG | SettingsConfigDict( + env_prefix="RALPH_BACKENDS__DATA__FS__" + ) DEFAULT_CHUNK_SIZE: int = 4096 DEFAULT_DIRECTORY_PATH: Path = Path(".") diff --git a/src/ralph/backends/data/ldp.py b/src/ralph/backends/data/ldp.py index 6fc5b3fbd..76d966a5d 100644 --- a/src/ralph/backends/data/ldp.py +++ b/src/ralph/backends/data/ldp.py @@ -5,6 +5,7 @@ import ovh import requests +from pydantic_settings import SettingsConfigDict from ralph.backends.data.base import ( BaseDataBackend, @@ -15,7 +16,7 @@ enforce_query_checks, ) from ralph.backends.mixins import HistoryMixin -from ralph.conf import BaseSettingsConfig +from ralph.conf import BASE_SETTINGS_CONFIG from ralph.exceptions import BackendException, BackendParameterException from ralph.utils import now @@ -37,10 +38,14 @@ class LDPDataBackendSettings(BaseDataBackendSettings): # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually. # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. - class Config(BaseSettingsConfig): - """Pydantic Configuration.""" + # class Config(BaseSettingsConfig): + # """Pydantic Configuration.""" - env_prefix = "RALPH_BACKENDS__DATA__LDP__" + # env_prefix = "RALPH_BACKENDS__DATA__LDP__" + + model_config = BASE_SETTINGS_CONFIG | SettingsConfigDict( + env_prefix="RALPH_BACKENDS__DATA__LDP__" + ) APPLICATION_KEY: Optional[str] = None APPLICATION_SECRET: Optional[str] = None diff --git a/src/ralph/backends/data/mongo.py b/src/ralph/backends/data/mongo.py index f0fd75801..c0b4c50af 100644 --- a/src/ralph/backends/data/mongo.py +++ b/src/ralph/backends/data/mongo.py @@ -13,7 +13,8 @@ from bson.errors import BSONError from bson.objectid import ObjectId from dateutil.parser import isoparse -from pydantic import StringConstraints, Json, MongoDsn +from pydantic import Json, MongoDsn, StringConstraints +from pydantic_settings import SettingsConfigDict from pymongo import MongoClient, ReplaceOne from pymongo.collection import Collection from pymongo.errors import ( @@ -23,6 +24,7 @@ InvalidOperation, PyMongoError, ) +from typing_extensions import Annotated from ralph.backends.data.base import ( BaseDataBackend, @@ -32,10 +34,9 @@ DataBackendStatus, enforce_query_checks, ) -from ralph.conf import BaseSettingsConfig, ClientOptions +from ralph.conf import BASE_SETTINGS_CONFIG, ClientOptions from ralph.exceptions import BackendException, BackendParameterException from ralph.utils import parse_bytes_to_dict, read_raw -from typing_extensions import Annotated logger = logging.getLogger(__name__) @@ -61,16 +62,27 @@ class MongoDataBackendSettings(BaseDataBackendSettings): # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually. # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. 
- class Config(BaseSettingsConfig): - """Pydantic Configuration.""" - - env_prefix = "RALPH_BACKENDS__DATA__MONGO__" - - CONNECTION_URI: MongoDsn = MongoDsn("mongodb://localhost:27017/", scheme="mongodb") - DEFAULT_DATABASE: Annotated[str, StringConstraints(pattern=r"^[^\s.$/\\\"\x00]+$")] = "statements" # noqa : F722 - DEFAULT_COLLECTION: Annotated[str, StringConstraints( - pattern=r"^(?!.*\.\.)[^.$\x00]+(?:\.[^.$\x00]+)*$" # noqa : F722 - )] = "marsha" + # class Config(BaseSettingsConfig): + # """Pydantic Configuration.""" + + # env_prefix = "RALPH_BACKENDS__DATA__MONGO__" + + model_config = BASE_SETTINGS_CONFIG | SettingsConfigDict( + env_prefix="RALPH_BACKENDS__DATA__MONGO__" + ) + + CONNECTION_URI: MongoDsn = MongoDsn("mongodb://localhost:27017/") + #CONNECTION_URI: MongoDsn = MongoDsn("mongodb://localhost:27017/", scheme="mongodb") # TODO: check why we remove scheme + DEFAULT_DATABASE: Annotated[ + str, StringConstraints(pattern=r"^[^\s.$/\\\"\x00]+$") + ] = "statements" # noqa : F722 + DEFAULT_COLLECTION: str = "marsha" + # DEFAULT_COLLECTION: Annotated[ # TODO: Uncomment after pydantic 2.5 https://github.com/pydantic/pydantic/issues/7058 + # str, + # StringConstraints( + # pattern=r"^(?!.*\.\.)[^.$\x00]+(?:\.[^.$\x00]+)*$" # noqa : F722 + # ), + # ] = "marsha" CLIENT_OPTIONS: MongoClientOptions = MongoClientOptions() DEFAULT_CHUNK_SIZE: int = 500 LOCALE_ENCODING: str = "utf8" diff --git a/src/ralph/backends/data/s3.py b/src/ralph/backends/data/s3.py index d5598b4de..dff8518cd 100644 --- a/src/ralph/backends/data/s3.py +++ b/src/ralph/backends/data/s3.py @@ -17,6 +17,7 @@ ResponseStreamingError, ) from botocore.response import StreamingBody +from pydantic_settings import SettingsConfigDict from requests_toolbelt import StreamingIterator from ralph.backends.data.base import ( @@ -28,7 +29,7 @@ enforce_query_checks, ) from ralph.backends.mixins import HistoryMixin -from ralph.conf import BaseSettingsConfig +from ralph.conf import BASE_SETTINGS_CONFIG from ralph.exceptions import BackendException, BackendParameterException from ralph.utils import now @@ -52,17 +53,21 @@ class S3DataBackendSettings(BaseDataBackendSettings): # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually. # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. 
- class Config(BaseSettingsConfig): - """Pydantic Configuration.""" + # class Config(BaseSettingsConfig): + # """Pydantic Configuration.""" - env_prefix = "RALPH_BACKENDS__DATA__S3__" + # env_prefix = "RALPH_BACKENDS__DATA__S3__" - ACCESS_KEY_ID: str = None - SECRET_ACCESS_KEY: str = None - SESSION_TOKEN: str = None - ENDPOINT_URL: str = None - DEFAULT_REGION: str = None - DEFAULT_BUCKET_NAME: str = None + model_config = BASE_SETTINGS_CONFIG | SettingsConfigDict( + env_prefix="RALPH_BACKENDS__DATA__S3__" + ) + + ACCESS_KEY_ID: Optional[str] = None + SECRET_ACCESS_KEY: Optional[str] = None + SESSION_TOKEN: Optional[str] = None + ENDPOINT_URL: Optional[str] = None + DEFAULT_REGION: Optional[str] = None + DEFAULT_BUCKET_NAME: Optional[str] = None DEFAULT_CHUNK_SIZE: int = 4096 LOCALE_ENCODING: str = "utf8" diff --git a/src/ralph/backends/data/swift.py b/src/ralph/backends/data/swift.py index c43dd3ae7..100c31e24 100644 --- a/src/ralph/backends/data/swift.py +++ b/src/ralph/backends/data/swift.py @@ -7,6 +7,7 @@ from typing import Iterable, Iterator, Optional, Union from uuid import uuid4 +from pydantic_settings import SettingsConfigDict from swiftclient.service import ClientException, Connection from ralph.backends.data.base import ( @@ -18,7 +19,7 @@ enforce_query_checks, ) from ralph.backends.mixins import HistoryMixin -from ralph.conf import BaseSettingsConfig +from ralph.conf import BASE_SETTINGS_CONFIG from ralph.exceptions import BackendException, BackendParameterException from ralph.utils import now @@ -45,22 +46,26 @@ class SwiftDataBackendSettings(BaseDataBackendSettings): # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually. # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. 
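
# A minimal sketch of why `Optional[...]` is added to the S3 and Swift fields
# in this patch: pydantic v2 no longer infers an optional type from a `None`
# default, so `FIELD: str = None` keeps the field typed as a plain `str`.
# `Example` is hypothetical.
from typing import Optional

from pydantic import BaseModel

class Example(BaseModel):
    token: Optional[str] = None  # v2-safe: the None default matches the annotation
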
- class Config(BaseSettingsConfig): - """Pydantic Configuration.""" + # class Config(BaseSettingsConfig): + # """Pydantic Configuration.""" - env_prefix = "RALPH_BACKENDS__DATA__SWIFT__" + # env_prefix = "RALPH_BACKENDS__DATA__SWIFT__" + + model_config = BASE_SETTINGS_CONFIG | SettingsConfigDict( + env_prefix="RALPH_BACKENDS__DATA__SWIFT__" + ) AUTH_URL: str = "https://auth.cloud.ovh.net/" - USERNAME: str = None - PASSWORD: str = None + USERNAME: Optional[str] = None + PASSWORD: Optional[str] = None IDENTITY_API_VERSION: str = "3" - TENANT_ID: str = None - TENANT_NAME: str = None + TENANT_ID: Optional[str] = None + TENANT_NAME: Optional[str] = None PROJECT_DOMAIN_NAME: str = "Default" - REGION_NAME: str = None - OBJECT_STORAGE_URL: str = None + REGION_NAME: Optional[str] = None + OBJECT_STORAGE_URL: Optional[str] = None USER_DOMAIN_NAME: str = "Default" - DEFAULT_CONTAINER: str = None + DEFAULT_CONTAINER: Optional[str] = None LOCALE_ENCODING: str = "utf8" diff --git a/src/ralph/backends/http/async_lrs.py b/src/ralph/backends/http/async_lrs.py index ada37de73..dc59b7f90 100644 --- a/src/ralph/backends/http/async_lrs.py +++ b/src/ralph/backends/http/async_lrs.py @@ -13,8 +13,9 @@ from more_itertools import chunked from pydantic import AnyHttpUrl, BaseModel, Field, NonNegativeInt, parse_obj_as from pydantic.types import PositiveInt +from pydantic_settings import SettingsConfigDict -from ralph.conf import BaseSettingsConfig, HeadersParameters +from ralph.conf import BASE_SETTINGS_CONFIG, HeadersParameters from ralph.exceptions import BackendException, BackendParameterException from ralph.models.xapi.base.agents import BaseXapiAgent from ralph.models.xapi.base.common import IRI @@ -54,10 +55,14 @@ class LRSHTTPBackendSettings(BaseHTTPBackendSettings): # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually. # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. - class Config(BaseSettingsConfig): - """Pydantic Configuration.""" + # class Config(BaseSettingsConfig): + # """Pydantic Configuration.""" - env_prefix = "RALPH_BACKENDS__HTTP__LRS__" + # env_prefix = "RALPH_BACKENDS__HTTP__LRS__" + + model_config = BASE_SETTINGS_CONFIG | SettingsConfigDict( + env_prefix="RALPH_BACKENDS__HTTP__LRS__" + ) BASE_URL: AnyHttpUrl = Field("http://0.0.0.0:8100") USERNAME: str = "ralph" diff --git a/src/ralph/backends/http/base.py b/src/ralph/backends/http/base.py index ff9bd22aa..f75b68571 100644 --- a/src/ralph/backends/http/base.py +++ b/src/ralph/backends/http/base.py @@ -8,10 +8,10 @@ from pydantic import BaseModel, ValidationError from pydantic.types import PositiveInt +from pydantic_settings import BaseSettings, SettingsConfigDict -from ralph.conf import BaseSettingsConfig, core_settings +from ralph.conf import BASE_SETTINGS_CONFIG, core_settings from ralph.exceptions import BackendParameterException -from pydantic_settings import BaseSettings, SettingsConfigDict logger = logging.getLogger(__name__) @@ -21,12 +21,18 @@ class BaseHTTPBackendSettings(BaseSettings): # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually. # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. 
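
# A brief usage sketch (hypothetical values) of the environment-variable
# prefixes configured in these settings classes: pydantic-settings reads any
# variable matching the declared prefix when the class is instantiated.
import os

os.environ["RALPH_BACKENDS__HTTP__LRS__USERNAME"] = "ralph"
# LRSHTTPBackendSettings().USERNAME would now evaluate to "ralph".
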
- class Config(BaseSettingsConfig): - """Pydantic Configuration.""" + # class Config(BaseSettingsConfig): + # """Pydantic Configuration.""" - env_prefix = "RALPH_BACKENDS__HTTP__" - env_file = ".env" - env_file_encoding = core_settings.LOCALE_ENCODING + # env_prefix = "RALPH_BACKENDS__HTTP__" + # env_file = ".env" + # env_file_encoding = core_settings.LOCALE_ENCODING + + model_config = BASE_SETTINGS_CONFIG | SettingsConfigDict( + env_prefix="RALPH_BACKENDS__HTTP__", + env_file=".env", + env_file_encoding=core_settings.LOCALE_ENCODING, + ) @unique @@ -72,7 +78,13 @@ def wrapper(*args, **kwargs): class BaseQuery(BaseModel): """Base query model.""" - model_config = SettingsConfigDict(env_prefix="RALPH_BACKENDS__HTTP__", env_file=".env", env_file_encoding=core_settings.LOCALE_ENCODING, extra="forbid") + + model_config = SettingsConfigDict( + env_prefix="RALPH_BACKENDS__HTTP__", + env_file=".env", + env_file_encoding=core_settings.LOCALE_ENCODING, + extra="forbid", + ) query_string: Optional[str] = None diff --git a/src/ralph/backends/stream/base.py b/src/ralph/backends/stream/base.py index cc2845e2a..fd5d8bf3c 100644 --- a/src/ralph/backends/stream/base.py +++ b/src/ralph/backends/stream/base.py @@ -3,8 +3,9 @@ from abc import ABC, abstractmethod from typing import BinaryIO -from ralph.conf import BaseSettingsConfig, core_settings -from pydantic_settings import BaseSettings +from pydantic_settings import BaseSettings, SettingsConfigDict + +from ralph.conf import BASE_SETTINGS_CONFIG, core_settings class BaseStreamBackendSettings(BaseSettings): @@ -12,12 +13,18 @@ class BaseStreamBackendSettings(BaseSettings): # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually. # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. - class Config(BaseSettingsConfig): - """Pydantic Configuration.""" - - env_prefix = "RALPH_BACKENDS__STREAM__" - env_file = ".env" - env_file_encoding = core_settings.LOCALE_ENCODING + # class Config(BaseSettingsConfig): + # """Pydantic Configuration.""" + + # env_prefix = "RALPH_BACKENDS__STREAM__" + # env_file = ".env" + # env_file_encoding = core_settings.LOCALE_ENCODING + + model_config = BASE_SETTINGS_CONFIG | SettingsConfigDict( + env_prefix="RALPH_BACKENDS__STREAM__", + env_file=".env", + env_file_encoding=core_settings.LOCALE_ENCODING, + ) class BaseStreamBackend(ABC): diff --git a/src/ralph/backends/stream/ws.py b/src/ralph/backends/stream/ws.py index c201dffec..512689554 100644 --- a/src/ralph/backends/stream/ws.py +++ b/src/ralph/backends/stream/ws.py @@ -5,8 +5,9 @@ from typing import BinaryIO, Optional import websockets +from pydantic_settings import SettingsConfigDict -from ralph.conf import BaseSettingsConfig +from ralph.conf import BASE_SETTINGS_CONFIG from .base import BaseStreamBackend, BaseStreamBackendSettings @@ -22,10 +23,11 @@ class WSStreamBackendSettings(BaseStreamBackendSettings): # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually. # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. 
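
# A minimal sketch of the `extra="forbid"` behaviour configured on `BaseQuery`
# above; `ForbidQuery` and the `unknown` field are illustrative only.
from typing import Optional

from pydantic import BaseModel, ConfigDict, ValidationError

class ForbidQuery(BaseModel):
    model_config = ConfigDict(extra="forbid")
    query_string: Optional[str] = None

try:
    ForbidQuery(query_string="foo", unknown="bar")
except ValidationError:
    pass  # unexpected keys are rejected when extra="forbid"
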
- class Config(BaseSettingsConfig): - """Pydantic Configuration.""" + # class Config(BaseSettingsConfig): + # """Pydantic Configuration.""" - env_prefix = "RALPH_BACKENDS__STREAM__WS__" + # env_prefix = "RALPH_BACKENDS__STREAM__WS__" + model_config = BASE_SETTINGS_CONFIG URI: Optional[str] = None diff --git a/src/ralph/conf.py b/src/ralph/conf.py index 5349743bb..55b0a8cc8 100644 --- a/src/ralph/conf.py +++ b/src/ralph/conf.py @@ -4,14 +4,22 @@ import sys from enum import Enum from pathlib import Path -from typing import Annotated, List, Optional, Sequence, Union, Tuple - -from pydantic import AfterValidator, model_validator, ConfigDict, AnyHttpUrl, AnyUrl, BaseModel, parse_obj_as +from typing import Annotated, List, Optional, Sequence, Tuple, Union + +from pydantic import ( + AfterValidator, + AnyHttpUrl, + AnyUrl, + BaseModel, + ConfigDict, + model_validator, + parse_obj_as, +) +from pydantic_settings import BaseSettings, SettingsConfigDict from ralph.exceptions import ConfigurationException from .utils import import_string -from pydantic_settings import BaseSettings, SettingsConfigDict if sys.version_info >= (3, 8): from typing import Literal @@ -31,13 +39,17 @@ MODEL_PATH_SEPARATOR = "__" -class BaseSettingsConfig: - """Pydantic model for BaseSettings Configuration.""" +# class BaseSettingsConfig: +# """Pydantic model for BaseSettings Configuration.""" + +# case_sensitive = True +# env_nested_delimiter = "__" +# env_prefix = "RALPH_" +# extra = "ignore" - case_sensitive = True - env_nested_delimiter = "__" - env_prefix = "RALPH_" - extra = "ignore" +BASE_SETTINGS_CONFIG = SettingsConfigDict( + case_sensitive=True, env_nested_delimiter="__", env_prefix="RALPH_", extra="ignore" +) class CoreSettings(BaseSettings): @@ -45,8 +57,9 @@ class CoreSettings(BaseSettings): # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually. # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. - class Config(BaseSettingsConfig): - """Pydantic Configuration.""" + # class Config(BaseSettingsConfig): + # """Pydantic Configuration.""" + model_config = BASE_SETTINGS_CONFIG APP_DIR: Path = get_app_dir("ralph") LOCALE_ENCODING: str = getattr(io, "LOCALE_ENCODING", "utf8") @@ -74,6 +87,7 @@ class Config(BaseSettingsConfig): # yield validate + def validate_comma_separated_tuple(value: Union[str, Tuple[str, ...]]) -> Tuple[str]: """Checks whether the value is a comma separated string or a tuple.""" @@ -85,16 +99,20 @@ def validate_comma_separated_tuple(value: Union[str, Tuple[str, ...]]) -> Tuple[ raise TypeError("Invalid comma separated list") -CommaSeparatedTuple = Annotated[Union[str, Tuple[str, ...]], AfterValidator(validate_comma_separated_tuple)] + +CommaSeparatedTuple = Annotated[ + Union[str, Tuple[str, ...]], AfterValidator(validate_comma_separated_tuple) +] class InstantiableSettingsItem(BaseModel): """Pydantic model for a settings configuration item that can be instantiated.""" + # TODO[pydantic]: The following keys were removed: `underscore_attrs_are_private`. # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. 
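
# A short sketch of the v2 behaviour behind this TODO: underscore-prefixed
# attributes are private by default in pydantic v2, which is why the
# `underscore_attrs_are_private` option was removed. `Example` is hypothetical.
from typing import Optional

from pydantic import BaseModel, PrivateAttr

class Example(BaseModel):
    _class_path: Optional[str] = PrivateAttr(default=None)  # private, not a field
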
model_config = SettingsConfigDict(underscore_attrs_are_private=True) - _class_path: str = None + _class_path: Optional[str] = None def get_instance(self, **init_parameters): """Return an instance of the settings item class using its `_class_path`.""" @@ -103,11 +121,13 @@ def get_instance(self, **init_parameters): class ClientOptions(BaseModel): """Pydantic model for additional client options.""" + model_config = ConfigDict(extra="forbid") class HeadersParameters(BaseModel): """Pydantic model for headers parameters.""" + model_config = ConfigDict(extra="allow") @@ -135,6 +155,7 @@ class ParserSettings(BaseModel): class XapiForwardingConfigurationSettings(BaseModel): """Pydantic model for xAPI forwarding configuration item.""" + model_config = ConfigDict(str_min_length=1) url: AnyUrl @@ -176,6 +197,7 @@ class AuthBackend(Enum): # yield validate + def validate_auth_backends( value: Union[AuthBackend, Tuple[AuthBackend], List[AuthBackend]] ) -> Tuple[AuthBackend]: @@ -188,7 +210,10 @@ def validate_auth_backends( raise TypeError("Invalid comma separated list") -AuthBackends = Annotated[Union[str, Tuple[str, ...], List[str]], AfterValidator(validate_auth_backends)] + +AuthBackends = Annotated[ + Union[str, Tuple[str, ...], List[str]], AfterValidator(validate_auth_backends) +] class Settings(BaseSettings): @@ -196,11 +221,15 @@ class Settings(BaseSettings): # TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually. # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. - class Config(BaseSettingsConfig): - """Pydantic Configuration.""" + # class Config(BaseSettingsConfig): + # """Pydantic Configuration.""" + + # env_file = ".env" + # env_file_encoding = core_settings.LOCALE_ENCODING - env_file = ".env" - env_file_encoding = core_settings.LOCALE_ENCODING + model_config = BASE_SETTINGS_CONFIG | SettingsConfigDict( + env_file=".env", env_file_encoding=core_settings.LOCALE_ENCODING + ) _CORE: CoreSettings = core_settings AUTH_FILE: Path = _CORE.APP_DIR / "auth.json" @@ -242,7 +271,7 @@ class Config(BaseSettingsConfig): }, } PARSERS: ParserSettings = ParserSettings() - RUNSERVER_AUTH_BACKENDS: AuthBackends = parse_obj_as(AuthBackends, 'Basic') + RUNSERVER_AUTH_BACKENDS: AuthBackends = parse_obj_as(AuthBackends, "Basic") RUNSERVER_AUTH_OIDC_AUDIENCE: str = None RUNSERVER_AUTH_OIDC_ISSUER_URI: AnyHttpUrl = None RUNSERVER_BACKEND: Literal[ @@ -270,7 +299,7 @@ def LOCALE_ENCODING(self) -> str: # pylint: disable=invalid-name """Return Ralph's default locale encoding.""" return self._CORE.LOCALE_ENCODING - @model_validator(mode='after') + @model_validator(mode="after") @classmethod def check_restriction_compatibility(cls, values): """Raise an error if scopes are being used without authority restriction.""" diff --git a/src/ralph/models/edx/base.py b/src/ralph/models/edx/base.py index 4c37c1fd4..7a8af1772 100644 --- a/src/ralph/models/edx/base.py +++ b/src/ralph/models/edx/base.py @@ -6,7 +6,7 @@ from pathlib import Path from typing import Dict, Optional, Union -from pydantic import StringConstraints, ConfigDict, AnyHttpUrl, BaseModel +from pydantic import AnyHttpUrl, BaseModel, ConfigDict, StringConstraints from typing_extensions import Annotated if sys.version_info >= (3, 8): @@ -17,6 +17,7 @@ class BaseModelWithConfig(BaseModel): """Pydantic model for base configuration shared among all models.""" + model_config = ConfigDict(extra="forbid") @@ -28,12 +29,17 @@ class ContextModuleField(BaseModelWithConfig): display_name 
(str): Consists of a short description or title of the component. """ - usage_key: Annotated[str, StringConstraints(pattern=r"^block-v1:.+\+.+\+.+type@.+@[a-f0-9]{32}$")] # noqa:F722 + usage_key: Annotated[ + str, StringConstraints(pattern=r"^block-v1:.+\+.+\+.+type@.+@[a-f0-9]{32}$") + ] # noqa:F722 display_name: str original_usage_key: Optional[ - Annotated[str, StringConstraints( - pattern=r"^block-v1:.+\+.+\+.+type@problem\+block@[a-f0-9]{32}$" # noqa:F722 - )] + Annotated[ + str, + StringConstraints( + pattern=r"^block-v1:.+\+.+\+.+type@problem\+block@[a-f0-9]{32}$" # noqa:F722 + ), + ] ] = None original_usage_version: Optional[str] = None @@ -80,7 +86,9 @@ class BaseContextField(BaseModelWithConfig): `request.META['PATH_INFO']` """ - course_id: Annotated[str, StringConstraints(pattern=r"^$|^course-v1:.+\+.+\+.+$")] # noqa:F722 + course_id: Annotated[ + str, StringConstraints(pattern=r"^$|^course-v1:.+\+.+\+.+$") + ] # noqa:F722 course_user_tags: Optional[Dict[str, str]] = None module: Optional[ContextModuleField] = None org_id: str @@ -150,7 +158,9 @@ class BaseEdxModel(BaseModelWithConfig): In JSON the value is `null` instead of `None`. """ - username: Union[Annotated[str, StringConstraints(min_length=2, max_length=30)], Literal[""]] + username: Union[ + Annotated[str, StringConstraints(min_length=2, max_length=30)], Literal[""] + ] ip: Union[IPv4Address, Literal[""]] agent: str host: str diff --git a/src/ralph/models/edx/browser.py b/src/ralph/models/edx/browser.py index f40572540..ad4747f6b 100644 --- a/src/ralph/models/edx/browser.py +++ b/src/ralph/models/edx/browser.py @@ -3,10 +3,10 @@ import sys from typing import Union -from pydantic import StringConstraints, AnyUrl +from pydantic import AnyUrl, StringConstraints +from typing_extensions import Annotated from .base import BaseEdxModel -from typing_extensions import Annotated if sys.version_info >= (3, 8): from typing import Literal @@ -30,4 +30,6 @@ class BaseBrowserModel(BaseEdxModel): event_source: Literal["browser"] page: AnyUrl - session: Union[Annotated[str, StringConstraints(pattern=r"^[a-f0-9]{32}$")], Literal[""]] # noqa: F722 + session: Union[ + Annotated[str, StringConstraints(pattern=r"^[a-f0-9]{32}$")], Literal[""] + ] # noqa: F722 diff --git a/src/ralph/models/edx/navigational/fields/events.py b/src/ralph/models/edx/navigational/fields/events.py index fb345fc97..d0990446d 100644 --- a/src/ralph/models/edx/navigational/fields/events.py +++ b/src/ralph/models/edx/navigational/fields/events.py @@ -1,9 +1,9 @@ """Navigational event field definition.""" from pydantic import StringConstraints +from typing_extensions import Annotated from ...base import AbstractBaseEventField -from typing_extensions import Annotated class NavigationalEventField(AbstractBaseEventField): @@ -21,11 +21,14 @@ class NavigationalEventField(AbstractBaseEventField): being navigated away from. 
""" - id: Annotated[str, StringConstraints( - pattern=( - r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+type" # noqa : F722 - r"@sequential\+block@[a-f0-9]{32}$" # noqa : F722 - ) - )] + id: Annotated[ + str, + StringConstraints( + pattern=( + r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+type" # noqa : F722 + r"@sequential\+block@[a-f0-9]{32}$" # noqa : F722 + ) + ), + ] new: int old: int diff --git a/src/ralph/models/edx/navigational/statements.py b/src/ralph/models/edx/navigational/statements.py index ed41c1563..9df3e9b30 100644 --- a/src/ralph/models/edx/navigational/statements.py +++ b/src/ralph/models/edx/navigational/statements.py @@ -3,7 +3,7 @@ import sys from typing import Union -from pydantic import field_validator, Json +from pydantic import Json, field_validator from ralph.models.selector import selector diff --git a/src/ralph/models/edx/open_response_assessment/fields/events.py b/src/ralph/models/edx/open_response_assessment/fields/events.py index cfde8da85..d7b576293 100644 --- a/src/ralph/models/edx/open_response_assessment/fields/events.py +++ b/src/ralph/models/edx/open_response_assessment/fields/events.py @@ -6,9 +6,9 @@ from uuid import UUID from pydantic import StringConstraints +from typing_extensions import Annotated from ralph.models.edx.base import AbstractBaseEventField, BaseModelWithConfig -from typing_extensions import Annotated if sys.version_info >= (3, 8): from typing import Literal @@ -31,12 +31,15 @@ class ORAGetPeerSubmissionEventField(AbstractBaseEventField): """ course_id: Annotated[str, StringConstraints(max_length=255)] - item_id: Annotated[str, StringConstraints( - pattern=( - r"^block-v1:.+\+.+\+.+type@openassessment" # noqa : F722 - r"+block@[a-f0-9]{32}$" # noqa : F722 - ) - )] + item_id: Annotated[ + str, + StringConstraints( + pattern=( + r"^block-v1:.+\+.+\+.+type@openassessment" # noqa : F722 + r"+block@[a-f0-9]{32}$" # noqa : F722 + ) + ), + ] requesting_student_id: str submission_returned_uuid: Union[str, None] = None @@ -58,12 +61,15 @@ class ORAGetSubmissionForStaffGradingEventField(AbstractBaseEventField): Currently, set to `full-grade`. """ - item_id: Annotated[str, StringConstraints( - pattern=( - r"^block-v1:.+\+.+\+.+type@openassessment" # noqa : F722 - r"+block@[a-f0-9]{32}$" # noqa : F722 - ) - )] + item_id: Annotated[ + str, + StringConstraints( + pattern=( + r"^block-v1:.+\+.+\+.+type@openassessment" # noqa : F722 + r"+block@[a-f0-9]{32}$" # noqa : F722 + ) + ), + ] submission_returned_uuid: Union[str, None] = None requesting_staff_id: str type: Literal["full-grade"] @@ -110,7 +116,9 @@ class ORAAssessEventRubricField(BaseModelWithConfig): assess the response. 
""" - content_hash: Annotated[str, StringConstraints(pattern=r"^[a-f0-9]{1,40}$")] # noqa: F722 + content_hash: Annotated[ + str, StringConstraints(pattern=r"^[a-f0-9]{1,40}$") + ] # noqa: F722 class ORAAssessEventField(AbstractBaseEventField): diff --git a/src/ralph/models/edx/peer_instruction/fields/events.py b/src/ralph/models/edx/peer_instruction/fields/events.py index ad30f6294..dc4e2ac44 100644 --- a/src/ralph/models/edx/peer_instruction/fields/events.py +++ b/src/ralph/models/edx/peer_instruction/fields/events.py @@ -1,9 +1,9 @@ """Peer instruction event field definition.""" from pydantic import StringConstraints +from typing_extensions import Annotated from ...base import AbstractBaseEventField -from typing_extensions import Annotated class PeerInstructionEventField(AbstractBaseEventField): diff --git a/src/ralph/models/edx/problem_interaction/fields/events.py b/src/ralph/models/edx/problem_interaction/fields/events.py index b4cf4119f..6f8b8edbb 100644 --- a/src/ralph/models/edx/problem_interaction/fields/events.py +++ b/src/ralph/models/edx/problem_interaction/fields/events.py @@ -5,9 +5,9 @@ from typing import Dict, List, Optional, Union from pydantic import StringConstraints +from typing_extensions import Annotated from ...base import AbstractBaseEventField, BaseModelWithConfig -from typing_extensions import Annotated if sys.version_info >= (3, 8): from typing import Literal @@ -63,7 +63,9 @@ class State(BaseModelWithConfig): """ correct_map: Dict[ - Annotated[str, StringConstraints(pattern=r"^[a-f0-9]{32}_[0-9]_[0-9]$")], # noqa : F722 + Annotated[ + str, StringConstraints(pattern=r"^[a-f0-9]{32}_[0-9]_[0-9]$") + ], # noqa : F722 CorrectMap, ] done: Optional[bool] = None @@ -171,23 +173,32 @@ class ProblemCheckEventField(AbstractBaseEventField): """ answers: Dict[ - Annotated[str, StringConstraints(pattern=r"^[a-f0-9]{32}_[0-9]_[0-9]$")], # noqa : F722 + Annotated[ + str, StringConstraints(pattern=r"^[a-f0-9]{32}_[0-9]_[0-9]$") + ], # noqa : F722 Union[List[str], str], ] attempts: int correct_map: Dict[ - Annotated[str, StringConstraints(pattern=r"^[a-f0-9]{32}_[0-9]_[0-9]$")], # noqa : F722 + Annotated[ + str, StringConstraints(pattern=r"^[a-f0-9]{32}_[0-9]_[0-9]$") + ], # noqa : F722 CorrectMap, ] grade: int max_grade: int - problem_id: Annotated[str, StringConstraints( - pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 - r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 - )] + problem_id: Annotated[ + str, + StringConstraints( + pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 + r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 + ), + ] state: State submission: Dict[ - Annotated[str, StringConstraints(pattern=r"^[a-f0-9]{32}_[0-9]_[0-9]$")], # noqa : F722 + Annotated[ + str, StringConstraints(pattern=r"^[a-f0-9]{32}_[0-9]_[0-9]$") + ], # noqa : F722 SubmissionAnswerField, ] success: Union[Literal["correct"], Literal["incorrect"]] @@ -205,14 +216,19 @@ class ProblemCheckFailEventField(AbstractBaseEventField): """ answers: Dict[ - Annotated[str, StringConstraints(pattern=r"^[a-f0-9]{32}_[0-9]_[0-9]$")], # noqa : F722 + Annotated[ + str, StringConstraints(pattern=r"^[a-f0-9]{32}_[0-9]_[0-9]$") + ], # noqa : F722 Union[List[str], str], ] failure: Union[Literal["closed"], Literal["unreset"]] - problem_id: Annotated[str, StringConstraints( - pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 - r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 - )] + problem_id: Annotated[ + str, + StringConstraints( + 
pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 + r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 + ), + ] state: State @@ -236,10 +252,13 @@ class ProblemRescoreEventField(AbstractBaseEventField): new_total: int orig_score: int orig_total: int - problem_id: Annotated[str, StringConstraints( - pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 - r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 - )] + problem_id: Annotated[ + str, + StringConstraints( + pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 + r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 + ), + ] state: State success: Union[Literal["correct"], Literal["incorrect"]] @@ -254,10 +273,13 @@ class ProblemRescoreFailEventField(AbstractBaseEventField): """ failure: Union[Literal["closed"], Literal["unreset"]] - problem_id: Annotated[str, StringConstraints( - pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 - r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 - )] + problem_id: Annotated[ + str, + StringConstraints( + pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 + r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 + ), + ] state: State @@ -294,10 +316,13 @@ class ResetProblemEventField(AbstractBaseEventField): new_state: State old_state: State - problem_id: Annotated[str, StringConstraints( - pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 - r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 - )] + problem_id: Annotated[ + str, + StringConstraints( + pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 + r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 + ), + ] class ResetProblemFailEventField(AbstractBaseEventField): @@ -311,10 +336,13 @@ class ResetProblemFailEventField(AbstractBaseEventField): failure: Union[Literal["closed"], Literal["not_done"]] old_state: State - problem_id: Annotated[str, StringConstraints( - pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 - r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 - )] + problem_id: Annotated[ + str, + StringConstraints( + pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 + r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 + ), + ] class SaveProblemFailEventField(AbstractBaseEventField): @@ -330,10 +358,13 @@ class SaveProblemFailEventField(AbstractBaseEventField): answers: Dict[str, Union[int, str, list, dict]] failure: Union[Literal["closed"], Literal["done"]] - problem_id: Annotated[str, StringConstraints( - pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 - r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 - )] + problem_id: Annotated[ + str, + StringConstraints( + pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 + r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 + ), + ] state: State @@ -348,10 +379,13 @@ class SaveProblemSuccessEventField(AbstractBaseEventField): """ answers: Dict[str, Union[int, str, list, dict]] - problem_id: Annotated[str, StringConstraints( - pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 - r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 - )] + problem_id: Annotated[ + str, + StringConstraints( + pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 + r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 + ), + ] state: State @@ -362,7 +396,10 @@ class ShowAnswerEventField(AbstractBaseEventField): problem_id (str): Consists of the ID of the problem being shown. 
""" - problem_id: Annotated[str, StringConstraints( - pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 - r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 - )] + problem_id: Annotated[ + str, + StringConstraints( + pattern=r"^block-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+" # noqa : F722 + r"type@problem\+block@[a-f0-9]{32}$" # noqa : F722 + ), + ] diff --git a/src/ralph/models/edx/textbook_interaction/fields/events.py b/src/ralph/models/edx/textbook_interaction/fields/events.py index 01a56fe46..7c58cba8c 100644 --- a/src/ralph/models/edx/textbook_interaction/fields/events.py +++ b/src/ralph/models/edx/textbook_interaction/fields/events.py @@ -24,11 +24,14 @@ class TextbookInteractionBaseEventField(AbstractBaseEventField): """ page: int - chapter: Annotated[str, StringConstraints( - pattern=( - r"^\/asset-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+type@asset\+block.+$" # noqa - ) - )] + chapter: Annotated[ + str, + StringConstraints( + pattern=( + r"^\/asset-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+type@asset\+block.+$" # noqa + ) + ), + ] class TextbookPdfThumbnailsToggledEventField(TextbookInteractionBaseEventField): @@ -74,11 +77,14 @@ class TextbookPdfChapterNavigatedEventField(AbstractBaseEventField): """ name: Literal["textbook.pdf.chapter.navigated"] - chapter: Annotated[str, StringConstraints( - pattern=( - r"^\/asset-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+type@asset\+block.+$" # noqa - ) - )] + chapter: Annotated[ + str, + StringConstraints( + pattern=( + r"^\/asset-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+type@asset\+block.+$" # noqa + ) + ), + ] chapter_title: str @@ -263,11 +269,14 @@ class BookEventField(AbstractBaseEventField): clicked or `nextpage` value when the previous page button is clicked. """ - chapter: Annotated[str, StringConstraints( - pattern=( - r"^\/asset-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+type@asset\+block.+$" # noqa - ) - )] + chapter: Annotated[ + str, + StringConstraints( + pattern=( + r"^\/asset-v1:[^\/+]+(\/|\+)[^\/+]+(\/|\+)[^\/?]+type@asset\+block.+$" # noqa + ) + ), + ] name: Union[ Literal["textbook.pdf.page.loaded"], Literal["textbook.pdf.page.navigatednext"] ] diff --git a/src/ralph/models/edx/video/fields/events.py b/src/ralph/models/edx/video/fields/events.py index 6a9c926c8..6249c64ee 100644 --- a/src/ralph/models/edx/video/fields/events.py +++ b/src/ralph/models/edx/video/fields/events.py @@ -2,9 +2,10 @@ import sys -from ...base import AbstractBaseEventField from pydantic import ConfigDict +from ...base import AbstractBaseEventField + if sys.version_info >= (3, 8): from typing import Literal else: @@ -20,6 +21,7 @@ class VideoBaseEventField(AbstractBaseEventField): id (str): Consists of the additional videos name if given by the course creators, or the system-generated hash code otherwise. 
""" + model_config = ConfigDict(extra="allow") code: str diff --git a/src/ralph/models/xapi/base/ifi.py b/src/ralph/models/xapi/base/ifi.py index e36eac372..cf9fa0b89 100644 --- a/src/ralph/models/xapi/base/ifi.py +++ b/src/ralph/models/xapi/base/ifi.py @@ -1,10 +1,10 @@ """Base xAPI `Inverse Functional Identifier` definitions.""" -from pydantic import StringConstraints, AnyUrl, StrictStr +from pydantic import AnyUrl, StrictStr, StringConstraints +from typing_extensions import Annotated from ..config import BaseModelWithConfig from .common import IRI, MailtoEmail -from typing_extensions import Annotated class BaseXapiAccount(BaseModelWithConfig): @@ -36,7 +36,9 @@ class BaseXapiMboxSha1SumIFI(BaseModelWithConfig): mbox_sha1sum (str): Consists of the SHA1 hash of the Agent's email address. """ - mbox_sha1sum: Annotated[str, StringConstraints(pattern=r"^[0-9a-f]{40}$")] # noqa:F722 + mbox_sha1sum: Annotated[ + str, StringConstraints(pattern=r"^[0-9a-f]{40}$") + ] # noqa:F722 class BaseXapiOpenIdIFI(BaseModelWithConfig): diff --git a/src/ralph/models/xapi/base/results.py b/src/ralph/models/xapi/base/results.py index 13c7824ec..0b60a9938 100644 --- a/src/ralph/models/xapi/base/results.py +++ b/src/ralph/models/xapi/base/results.py @@ -5,10 +5,10 @@ from typing import Any, Dict, Optional, Union from pydantic import Field, StrictBool, StrictStr, model_validator +from typing_extensions import Annotated from ..config import BaseModelWithConfig from .common import IRI -from typing_extensions import Annotated class BaseXapiResultScore(BaseModelWithConfig): @@ -26,7 +26,7 @@ class BaseXapiResultScore(BaseModelWithConfig): min: Optional[Decimal] = None max: Optional[Decimal] = None - @model_validator(mode='after') # TODO: needs review + @model_validator(mode="after") # TODO: needs review @classmethod def check_raw_min_max_relation(cls, values: Any) -> Any: """Check the relationship `min < raw < max`.""" diff --git a/src/ralph/models/xapi/base/statements.py b/src/ralph/models/xapi/base/statements.py index 984ce0dd5..3fc2a6a58 100644 --- a/src/ralph/models/xapi/base/statements.py +++ b/src/ralph/models/xapi/base/statements.py @@ -4,7 +4,8 @@ from typing import Any, List, Optional, Union from uuid import UUID -from pydantic import model_validator, StringConstraints +from pydantic import StringConstraints, model_validator +from typing_extensions import Annotated from ..config import BaseModelWithConfig from .agents import BaseXapiAgent @@ -14,7 +15,6 @@ from .objects import BaseXapiObject from .results import BaseXapiResult from .verbs import BaseXapiVerb -from typing_extensions import Annotated class BaseXapiStatement(BaseModelWithConfig): @@ -43,7 +43,9 @@ class BaseXapiStatement(BaseModelWithConfig): timestamp: Optional[datetime] = None stored: Optional[datetime] = None authority: Optional[Union[BaseXapiAgent, BaseXapiGroup]] = None - version: Annotated[str, StringConstraints(pattern=r"^1\.0\.[0-9]+$")] = "1.0.0" # noqa:F722 + version: Annotated[ + str, StringConstraints(pattern=r"^1\.0\.[0-9]+$") + ] = "1.0.0" # noqa:F722 attachments: Optional[List[BaseXapiAttachment]] = None @model_validator(mode="before") diff --git a/src/ralph/models/xapi/concepts/verbs/tincan_vocabulary.py b/src/ralph/models/xapi/concepts/verbs/tincan_vocabulary.py index defcbc266..8e6019120 100644 --- a/src/ralph/models/xapi/concepts/verbs/tincan_vocabulary.py +++ b/src/ralph/models/xapi/concepts/verbs/tincan_vocabulary.py @@ -37,7 +37,9 @@ class DownloadedVerb(BaseXapiVerb): id: Literal[ "http://id.tincanapi.com/verb/downloaded" 
] = "http://id.tincanapi.com/verb/downloaded" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["downloaded"]]] = None # TODO: remove literal for LANG_EN_US_DISPLAY ? + display: Optional[ + Dict[Literal[LANG_EN_US_DISPLAY], Literal["downloaded"]] + ] = None # TODO: remove literal for LANG_EN_US_DISPLAY ? class UnregisteredVerb(BaseXapiVerb): diff --git a/src/ralph/models/xapi/concepts/verbs/virtual_classroom.py b/src/ralph/models/xapi/concepts/verbs/virtual_classroom.py index 9aaded225..b9e618933 100644 --- a/src/ralph/models/xapi/concepts/verbs/virtual_classroom.py +++ b/src/ralph/models/xapi/concepts/verbs/virtual_classroom.py @@ -54,7 +54,9 @@ class StartedCameraVerb(BaseXapiVerb): id: Literal[ "https://w3id.org/xapi/virtual-classroom/verbs/started-camera" ] = "https://w3id.org/xapi/virtual-classroom/verbs/started-camera" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["started camera"]]] = None + display: Optional[ + Dict[Literal[LANG_EN_US_DISPLAY], Literal["started camera"]] + ] = None class StoppedCameraVerb(BaseXapiVerb): @@ -69,7 +71,9 @@ class StoppedCameraVerb(BaseXapiVerb): id: Literal[ "https://w3id.org/xapi/virtual-classroom/verbs/stopped-camera" ] = "https://w3id.org/xapi/virtual-classroom/verbs/stopped-camera" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["stopped camera"]]] = None + display: Optional[ + Dict[Literal[LANG_EN_US_DISPLAY], Literal["stopped camera"]] + ] = None class SharedScreenVerb(BaseXapiVerb): @@ -84,7 +88,9 @@ class SharedScreenVerb(BaseXapiVerb): id: Literal[ "https://w3id.org/xapi/virtual-classroom/verbs/shared-screen" ] = "https://w3id.org/xapi/virtual-classroom/verbs/shared-screen" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["shared screen"]]] = None + display: Optional[ + Dict[Literal[LANG_EN_US_DISPLAY], Literal["shared screen"]] + ] = None class UnsharedScreenVerb(BaseXapiVerb): @@ -99,7 +105,9 @@ class UnsharedScreenVerb(BaseXapiVerb): id: Literal[ "https://w3id.org/xapi/virtual-classroom/verbs/unshared-screen" ] = "https://w3id.org/xapi/virtual-classroom/verbs/unshared-screen" - display: Optional[Dict[Literal[LANG_EN_US_DISPLAY], Literal["unshared screen"]]] = None + display: Optional[ + Dict[Literal[LANG_EN_US_DISPLAY], Literal["unshared screen"]] + ] = None class RaisedHandVerb(BaseXapiVerb): diff --git a/src/ralph/models/xapi/config.py b/src/ralph/models/xapi/config.py index bc8218bda..1dce8bb0c 100644 --- a/src/ralph/models/xapi/config.py +++ b/src/ralph/models/xapi/config.py @@ -1,13 +1,15 @@ """Base xAPI model configuration.""" -from pydantic import ConfigDict, BaseModel +from pydantic import BaseModel, ConfigDict class BaseModelWithConfig(BaseModel): """Pydantic model for base configuration shared among all models.""" + model_config = ConfigDict(extra="forbid", str_min_length=1) class BaseExtensionModelWithConfig(BaseModel): """Pydantic model for extension configuration shared among all models.""" + model_config = ConfigDict(extra="allow", str_min_length=0) diff --git a/tests/fixtures/hypothesis_strategies.py b/tests/fixtures/hypothesis_strategies.py index 40646eb3d..73f351e00 100644 --- a/tests/fixtures/hypothesis_strategies.py +++ b/tests/fixtures/hypothesis_strategies.py @@ -85,7 +85,9 @@ def custom_builds( continue is_required = field.is_required or (arg is not None and _overwrite_default) required_optional = required if is_required or arg is not None else optional - field_strategy = get_strategy_from(field.outer_type_) if arg is None else arg + field_strategy = ( + 
get_strategy_from(field.annotation) if arg is None else arg + ) # TODO: validate this change is not failing silently required_optional[field.alias] = field_strategy if not required: # To avoid generating empty values diff --git a/tests/test_cli.py b/tests/test_cli.py index 827543378..2b6d5abeb 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -2,7 +2,7 @@ import json import logging from pathlib import Path -from typing import Union +from typing import Optional, Union import pytest from click.exceptions import BadParameter @@ -170,7 +170,7 @@ def _gen_cli_auth_args( scopes: list, ifi_command: str, ifi_value: Union[str, dict], - agent_name: str = None, + agent_name: Optional[str] = None, write: bool = False, ): """Generate arguments for cli to create user.""" @@ -193,8 +193,8 @@ def _assert_matching_basic_auth_credentials( scopes: list, ifi_type: str, ifi_value: Union[str, dict], - agent_name: str = None, - hash_: str = None, + agent_name: Optional[str] = None, + hash_: Optional[str] = None, ): """Assert that credentials match other arguments. From f2ff59d44c8fe977e7b98943f397bf41f3f05bf7 Mon Sep 17 00:00:00 2001 From: lleeoo Date: Tue, 7 Nov 2023 18:04:05 +0100 Subject: [PATCH 62/65] wip --- src/ralph/api/auth/basic.py | 17 ++-- src/ralph/api/auth/user.py | 23 +++-- src/ralph/api/routers/statements.py | 4 +- src/ralph/backends/data/es.py | 4 +- src/ralph/backends/data/mongo.py | 2 +- src/ralph/backends/lrs/base.py | 4 +- src/ralph/backends/lrs/es.py | 4 +- src/ralph/backends/lrs/mongo.py | 2 +- src/ralph/cli.py | 35 +++++-- src/ralph/conf.py | 4 +- src/ralph/models/xapi/base/common.py | 95 +++++++++++-------- .../models/xapi/base/unnested_objects.py | 4 +- src/ralph/models/xapi/lms/contexts.py | 4 +- src/ralph/models/xapi/video/contexts.py | 4 +- .../models/xapi/virtual_classroom/contexts.py | 4 +- tests/backends/lrs/test_async_es.py | 8 +- tests/backends/lrs/test_async_mongo.py | 8 +- tests/backends/lrs/test_clickhouse.py | 8 +- tests/backends/lrs/test_es.py | 8 +- tests/backends/lrs/test_fs.py | 2 +- tests/backends/lrs/test_mongo.py | 10 +- tests/fixtures/hypothesis_strategies.py | 2 +- tests/models/test_converter.py | 27 +++--- tests/models/test_validator.py | 3 +- tests/test_cli.py | 23 +++++ 25 files changed, 190 insertions(+), 119 deletions(-) diff --git a/src/ralph/api/auth/basic.py b/src/ralph/api/auth/basic.py index 0c503d25c..400406e01 100644 --- a/src/ralph/api/auth/basic.py +++ b/src/ralph/api/auth/basic.py @@ -50,24 +50,25 @@ class ServerUsersCredentials(RootModel[List[UserCredentials]]): __root__ (List): Custom root consisting of the list of all server users credentials. 
""" - - def __add__(self, other) -> Any: # noqa: D105 - return ServerUsersCredentials.parse_obj(self.__root__ + other.__root__) + root: List[UserCredentials] + + def __add__(self, other): # noqa: D105 + return ServerUsersCredentials.parse_obj(self.root + other.root) def __getitem__(self, item: int) -> UserCredentials: # noqa: D105 - return self.__root__[item] + return self.root[item] def __len__(self) -> int: # noqa: D105 - return len(self.__root__) + return len(self.root) def __iter__(self) -> Iterator[UserCredentials]: # noqa: D105 - return iter(self.__root__) + return iter(self.root) - @model_validator(mode="after") + @model_validator(mode="before") @classmethod def ensure_unique_username(cls, values: Any) -> Any: """Every username should be unique among registered users.""" - usernames = [entry.username for entry in values.get("__root__")] + usernames = [entry.username for entry in values] if len(usernames) != len(set(usernames)): raise ValueError( "You cannot create multiple credentials with the same username" diff --git a/src/ralph/api/auth/user.py b/src/ralph/api/auth/user.py index a5476a212..630b5e83b 100644 --- a/src/ralph/api/auth/user.py +++ b/src/ralph/api/auth/user.py @@ -19,7 +19,9 @@ ] -class UserScopes(FrozenSet[Scope]): +from pydantic import RootModel + +class UserScopes(RootModel[FrozenSet[Scope]]): """Scopes available to users.""" @lru_cache(maxsize=1024) @@ -54,16 +56,17 @@ def is_authorized(self, requested_scope: Scope): return requested_scope in expanded_user_scopes - @classmethod - # TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually. - # Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information. - def __get_validators__(cls): # noqa: D105 - def validate(value: FrozenSet[Scope]): - """Transform value to an instance of UserScopes.""" - return cls(value) + # @classmethod + # # TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually. + # # Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information. + # def __get_validators__(cls): # noqa: D105 + # def validate(value: FrozenSet[Scope]): + # """Transform value to an instance of UserScopes.""" + # return cls(value) - yield validate + # yield validate +from ralph.models.xapi.base.agents import BaseXapiAgent class AuthenticatedUser(BaseModel): """Pydantic model for user authentication. @@ -73,5 +76,5 @@ class AuthenticatedUser(BaseModel): scopes (list): The scopes the user has access to. 
""" - agent: Dict + agent: BaseXapiAgent scopes: UserScopes diff --git a/src/ralph/api/routers/statements.py b/src/ralph/api/routers/statements.py index 436f82c52..fcf37b8cb 100644 --- a/src/ralph/api/routers/statements.py +++ b/src/ralph/api/routers/statements.py @@ -119,7 +119,7 @@ def _parse_agent_parameters(agent_obj: dict) -> AgentParameters: agent_query_params["account__home_page"] = agent.account.homePage # Overwrite `agent` field - return AgentParameters.construct(**agent_query_params) + return AgentParameters.model_construct(**agent_query_params) def strict_query_params(request: Request) -> None: @@ -357,7 +357,7 @@ async def get( try: query_result = await await_if_coroutine( BACKEND_CLIENT.query_statements( - RalphStatementsQuery.construct(**{**query_params, "limit": limit}) + RalphStatementsQuery.model_construct(**{**query_params, "limit": limit}) ) ) except BackendException as error: diff --git a/src/ralph/backends/data/es.py b/src/ralph/backends/data/es.py index 7bb92ff81..cee41b93f 100644 --- a/src/ralph/backends/data/es.py +++ b/src/ralph/backends/data/es.py @@ -71,7 +71,9 @@ class ESDataBackendSettings(BaseDataBackendSettings): HOSTS: CommaSeparatedTuple = ("http://localhost:9200",) LOCALE_ENCODING: str = "utf8" POINT_IN_TIME_KEEP_ALIVE: str = "1m" - REFRESH_AFTER_WRITE: Union[Literal["false", "true", "wait_for"], bool, str, None] = False # TODO: check that this is the good default + REFRESH_AFTER_WRITE: Union[ + Literal["false", "true", "wait_for"], bool, str, None + ] = False # TODO: check that this is the good default class ESQueryPit(BaseModel): diff --git a/src/ralph/backends/data/mongo.py b/src/ralph/backends/data/mongo.py index c0b4c50af..d615ef25d 100644 --- a/src/ralph/backends/data/mongo.py +++ b/src/ralph/backends/data/mongo.py @@ -72,7 +72,7 @@ class MongoDataBackendSettings(BaseDataBackendSettings): ) CONNECTION_URI: MongoDsn = MongoDsn("mongodb://localhost:27017/") - #CONNECTION_URI: MongoDsn = MongoDsn("mongodb://localhost:27017/", scheme="mongodb") # TODO: check why we remove scheme + # CONNECTION_URI: MongoDsn = MongoDsn("mongodb://localhost:27017/", scheme="mongodb") # TODO: check why we remove scheme DEFAULT_DATABASE: Annotated[ str, StringConstraints(pattern=r"^[^\s.$/\\\"\x00]+$") ] = "statements" # noqa : F722 diff --git a/src/ralph/backends/lrs/base.py b/src/ralph/backends/lrs/base.py index 96d434afd..f9a1637c6 100644 --- a/src/ralph/backends/lrs/base.py +++ b/src/ralph/backends/lrs/base.py @@ -43,10 +43,10 @@ class AgentParameters(BaseModel): class RalphStatementsQuery(LRSStatementsQuery): """Represents a dictionary of possible LRS query parameters.""" - agent: Optional[AgentParameters] = AgentParameters.construct() + agent: Optional[AgentParameters] = AgentParameters.model_construct() search_after: Optional[str] = None pit_id: Optional[str] = None - authority: Optional[AgentParameters] = AgentParameters.construct() + authority: Optional[AgentParameters] = AgentParameters.model_construct() ignore_order: Optional[bool] = None diff --git a/src/ralph/backends/lrs/es.py b/src/ralph/backends/lrs/es.py index 3b57511a2..af476aacc 100644 --- a/src/ralph/backends/lrs/es.py +++ b/src/ralph/backends/lrs/es.py @@ -72,7 +72,7 @@ def get_query(params: RalphStatementsQuery) -> ESQuery: es_query_filters += [{"range": {"timestamp": {"lte": params.until}}}] es_query = { - "pit": ESQueryPit.construct(id=params.pit_id), + "pit": ESQueryPit.model_construct(id=params.pit_id), "size": params.limit, "sort": [{"timestamp": {"order": "asc" if params.ascending else "desc"}}], 
        }
@@ -86,7 +86,7 @@ def get_query(params: RalphStatementsQuery) -> ESQuery:
             es_query["search_after"] = params.search_after.split("|")

         # Note: `params` fields are validated thus we skip their validation in ESQuery.
-        return ESQuery.construct(**es_query)
+        return ESQuery.model_construct(**es_query)

     @staticmethod
     def _add_agent_filters(
diff --git a/src/ralph/backends/lrs/mongo.py b/src/ralph/backends/lrs/mongo.py
index 3fd1e2bae..f49b1344b 100644
--- a/src/ralph/backends/lrs/mongo.py
+++ b/src/ralph/backends/lrs/mongo.py
@@ -96,7 +96,7 @@ def get_query(params: RalphStatementsQuery) -> MongoQuery:
         ]

         # Note: `params` fields are validated thus we skip MongoQuery validation.
-        return MongoQuery.construct(
+        return MongoQuery.model_construct(
             filter=mongo_query_filters, limit=params.limit, sort=mongo_query_sort
         )

diff --git a/src/ralph/cli.py b/src/ralph/cli.py
index 782944c34..1d9ee21e2 100644
--- a/src/ralph/cli.py
+++ b/src/ralph/cli.py
@@ -205,7 +205,7 @@ def backends_options(name=None, backend_types: Optional[Sequence[BaseModel]] = N
     def wrapper(command):
         backend_names = []
-        for backend_name, backend in sorted(
+        for backend_name, backend in sorted(  # e.g. ("ASYNC_ES", ESDataBackendSettings())
             [
                 name_backend
                 for backend_type in backend_types
@@ -217,7 +217,7 @@
             backend_name = backend_name.lower()
             backend_names.append(backend_name)
             for field_name, field in sorted(backend, key=lambda x: x[0], reverse=True):
-                field_type = backend.__fields__[field_name].type_
+                field_type = backend.model_fields[field_name].annotation  # v2: FieldInfo.annotation replaces v1 `.type_`
                 field_name = f"{backend_name}-{field_name.lower()}".replace("_", "-")
                 option = f"--{field_name}"
                 option_kwargs = {}
@@ -365,9 +365,11 @@
     # Import required Pydantic models dynamically so that we don't create a
     # direct dependency between the CLI and the LRS
     # pylint: disable=invalid-name
+    logger.warning('ok aaa')
     ServerUsersCredentials = import_string(
         "ralph.api.auth.basic.ServerUsersCredentials"
     )
+    logger.warning('ok bbb')
     UserCredentialsBasicAuth = import_string("ralph.api.auth.basic.UserCredentials")

     # NB: renaming classes below for clarity
@@ -381,12 +383,14 @@
         "ralph.models.xapi.base.agents.BaseXapiAgentWithAccount"
     )

+    logger.warning('ok ccc')
     if agent_ifi_mbox:
         if agent_ifi_mbox[:7] != "mailto:":
             raise click.UsageError(
                 'Mbox field must start with "mailto:" (e.g.: "mailto:foo@bar.com")'
             )
         agent = AgentMbox(mbox=agent_ifi_mbox, name=agent_name, objectType="Agent")
+    logger.warning('ok ddd')
     if agent_ifi_mbox_sha1sum:
         agent = AgentMboxSha1sum(
             mbox_sha1sum=agent_ifi_mbox_sha1sum, name=agent_name, objectType="Agent"
@@ -408,6 +412,7 @@
         scopes=scope,
         agent=agent,
     )
+    logger.warning('ok eee')

     if write_to_disk:
         logger.info("Will append new credentials to: %s", settings.AUTH_FILE)
@@ -419,23 +424,39 @@
             auth_file.parent.mkdir(parents=True, exist_ok=True)
             auth_file.touch()

-        users = ServerUsersCredentials.parse_obj([])
+        logger.warning('ok fff')
+        users = ServerUsersCredentials.model_validate([])
+        logger.warning('ok fffgloser')
+
+        logger.warning(auth_file)
         # Parse credentials file if not empty
         if auth_file.stat().st_size:
-            users = ServerUsersCredentials.parse_file(auth_file)
-        users += ServerUsersCredentials.parse_obj(
+            # `model_validate_json` expects a JSON string, not the object
+            # returned by `json.load`, hence `f.read()` here.
+            with open(auth_file, encoding=settings.LOCALE_ENCODING) as f:
+                users = ServerUsersCredentials.model_validate_json(f.read())
+
+        logger.warning('ok fffa')
+        logger.warning(type(ServerUsersCredentials.model_validate(
+            [
+                credentials,
+            ]
+        )))
+
+        users += 
ServerUsersCredentials.model_validate( [ credentials, ] ) - auth_file.write_text(users.json(indent=2), encoding=settings.LOCALE_ENCODING) + + logger.warning('ok fffb') + + auth_file.write_text(users.model_dump_json(indent=2), encoding=settings.LOCALE_ENCODING) logger.info("User %s has been added to: %s", username, settings.AUTH_FILE) else: click.echo( ( f"Copy/paste the following credentials to your LRS authentication " f"file located in: {settings.AUTH_FILE}\n" - f"{credentials.json(indent=2)}" + f"{credentials.model_dump_json(indent=2)}" ) ) diff --git a/src/ralph/conf.py b/src/ralph/conf.py index 55b0a8cc8..33e4e3219 100644 --- a/src/ralph/conf.py +++ b/src/ralph/conf.py @@ -108,9 +108,7 @@ def validate_comma_separated_tuple(value: Union[str, Tuple[str, ...]]) -> Tuple[ class InstantiableSettingsItem(BaseModel): """Pydantic model for a settings configuration item that can be instantiated.""" - # TODO[pydantic]: The following keys were removed: `underscore_attrs_are_private`. - # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. - model_config = SettingsConfigDict(underscore_attrs_are_private=True) + model_config = SettingsConfigDict() _class_path: Optional[str] = None diff --git a/src/ralph/models/xapi/base/common.py b/src/ralph/models/xapi/base/common.py index 719a96162..0bcd7d449 100644 --- a/src/ralph/models/xapi/base/common.py +++ b/src/ralph/models/xapi/base/common.py @@ -1,58 +1,77 @@ """Common for xAPI base definitions.""" -from typing import Dict, Generator, Type +from typing import Annotated, Dict, Generator, Type from langcodes import tag_is_valid -from pydantic import StrictStr, validate_email +from pydantic import AfterValidator, StrictStr, validate_email from rfc3987 import parse +def validate_iri(iri): + """Check whether the provided IRI is a valid RFC 3987 IRI.""" + parse(iri, rule="IRI") + return iri -class IRI(str): - """Pydantic custom data type validating RFC 3987 IRIs.""" +IRI = Annotated[str, AfterValidator(validate_iri)] - @classmethod - # TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually. - # Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information. - def __get_validators__(cls) -> Generator: # noqa: D105 - def validate(iri: str) -> Type["IRI"]: - """Check whether the provided IRI is a valid RFC 3987 IRI.""" - parse(iri, rule="IRI") - return cls(iri) +# class IRI(str): +# """Pydantic custom data type validating RFC 3987 IRIs.""" - yield validate +# @classmethod +# # TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually. +# # Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information. +# def __get_validators__(cls) -> Generator: # noqa: D105 +# def validate(iri: str) -> Type["IRI"]: -class LanguageTag(str): - """Pydantic custom data type validating RFC 5646 Language tags.""" +# yield validate - @classmethod - # TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually. - # Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information. 
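
# A minimal usage sketch of the Annotated-based `IRI` type defined above; the
# `Resource` model is hypothetical. The `AfterValidator` runs `validate_iri`
# whenever a model field of this type is validated.
from pydantic import BaseModel

class Resource(BaseModel):
    id: IRI

Resource(id="http://adlnet.gov/expapi/verbs/completed")  # passes the RFC 3987 check
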
- def __get_validators__(cls) -> Generator: # noqa: D105 - def validate(tag: str) -> Type["LanguageTag"]: - """Check whether the provided tag is a valid RFC 5646 Language tag.""" - if not tag_is_valid(tag): - raise TypeError("Invalid RFC 5646 Language tag") - return cls(tag) - yield validate +# class LanguageTag(str): +# """Pydantic custom data type validating RFC 5646 Language tags.""" +# @classmethod +# # TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually. +# # Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information. +# def __get_validators__(cls) -> Generator: # noqa: D105 +# def validate(tag: str) -> Type["LanguageTag"]: +# """Check whether the provided tag is a valid RFC 5646 Language tag.""" +# if not tag_is_valid(tag): +# raise TypeError("Invalid RFC 5646 Language tag") +# return cls(tag) +# yield validate + +def validate_language_tag(tag): + """Check whether the provided tag is a valid RFC 5646 Language tag.""" + if not tag_is_valid(tag): + raise TypeError("Invalid RFC 5646 Language tag") + return tag + +LanguageTag = Annotated[str, AfterValidator(validate_language_tag)] LanguageMap = Dict[LanguageTag, StrictStr] -class MailtoEmail(str): - """Pydantic custom data type validating `mailto:email` format.""" +def validate_mailto_email(mailto: str): + """Check whether the provided value follows the `mailto:email` format.""" + if not mailto.startswith("mailto:"): + raise TypeError("Invalid `mailto:email` value") + valid = validate_email(mailto[7:]) + return f"mailto:{valid[1]}" + +MailtoEmail = Annotated[str, AfterValidator(validate_mailto_email)] + +# class MailtoEmail(str): +# """Pydantic custom data type validating `mailto:email` format.""" - @classmethod - # TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually. - # Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information. - def __get_validators__(cls) -> Generator: # noqa: D105 - def validate(mailto: str) -> Type["MailtoEmail"]: - """Check whether the provided value follows the `mailto:email` format.""" - if not mailto.startswith("mailto:"): - raise TypeError("Invalid `mailto:email` value") - valid = validate_email(mailto[7:]) - return cls(f"mailto:{valid[1]}") +# @classmethod +# # TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually. +# # Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information. 
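
# A short sketch (with a hypothetical `Display` model) of the `LanguageMap`
# defined above: keys are validated as RFC 5646 tags, values as strict strings.
from pydantic import BaseModel

class Display(BaseModel):
    display: LanguageMap

Display(display={"en-US": "played"})  # a valid language tag
# Display(display={"not a tag": "x"}) would raise a validation error
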
+# def __get_validators__(cls) -> Generator: # noqa: D105 +# def validate(mailto: str) -> Type["MailtoEmail"]: +# """Check whether the provided value follows the `mailto:email` format.""" +# if not mailto.startswith("mailto:"): +# raise TypeError("Invalid `mailto:email` value") +# valid = validate_email(mailto[7:]) +# return cls(f"mailto:{valid[1]}") - yield validate +# yield validate diff --git a/src/ralph/models/xapi/base/unnested_objects.py b/src/ralph/models/xapi/base/unnested_objects.py index 1ed0de2a9..133c66581 100644 --- a/src/ralph/models/xapi/base/unnested_objects.py +++ b/src/ralph/models/xapi/base/unnested_objects.py @@ -4,7 +4,7 @@ from typing import Annotated, Any, Dict, List, Optional, Union from uuid import UUID -from pydantic import AnyUrl, StrictStr, StringConstraints, validator +from pydantic import AnyUrl, StrictStr, StringConstraints, field_validator from ..config import BaseModelWithConfig from .common import IRI, LanguageMap @@ -79,7 +79,7 @@ class BaseXapiActivityInteractionDefinition(BaseXapiActivityDefinition): target: Optional[List[BaseXapiInteractionComponent]] steps: Optional[List[BaseXapiInteractionComponent]] - @validator("choices", "scale", "source", "target", "steps") + @field_validator("choices", "scale", "source", "target", "steps") @classmethod def check_unique_ids(cls, value: Any) -> None: """Check the uniqueness of interaction components IDs.""" diff --git a/src/ralph/models/xapi/lms/contexts.py b/src/ralph/models/xapi/lms/contexts.py index 3de7a3c2b..ebc76addc 100644 --- a/src/ralph/models/xapi/lms/contexts.py +++ b/src/ralph/models/xapi/lms/contexts.py @@ -5,7 +5,7 @@ from typing import List, Optional, Union from uuid import UUID -from pydantic import Field, NonNegativeFloat, PositiveInt, condecimal, validator +from pydantic import Field, NonNegativeFloat, PositiveInt, condecimal, field_validator from ..base.contexts import BaseXapiContext, BaseXapiContextContextActivities from ..base.unnested_objects import BaseXapiActivity @@ -49,7 +49,7 @@ class LMSContextContextActivities(BaseXapiContextContextActivities): LMSProfileActivity, List[Union[LMSProfileActivity, BaseXapiActivity]] ] - @validator("category") + @field_validator("category") @classmethod def check_presence_of_profile_activity_category( cls, diff --git a/src/ralph/models/xapi/video/contexts.py b/src/ralph/models/xapi/video/contexts.py index 4bdda8ff3..aaf8a8611 100644 --- a/src/ralph/models/xapi/video/contexts.py +++ b/src/ralph/models/xapi/video/contexts.py @@ -4,7 +4,7 @@ from typing import List, Optional, Union from uuid import UUID -from pydantic import Field, NonNegativeFloat, validator +from pydantic import Field, NonNegativeFloat, field_validator from ..base.contexts import BaseXapiContext, BaseXapiContextContextActivities from ..base.unnested_objects import BaseXapiActivity @@ -51,7 +51,7 @@ class VideoContextContextActivities(BaseXapiContextContextActivities): VideoProfileActivity, List[Union[VideoProfileActivity, BaseXapiActivity]] ] - @validator("category") + @field_validator("category") @classmethod def check_presence_of_profile_activity_category( cls, diff --git a/src/ralph/models/xapi/virtual_classroom/contexts.py b/src/ralph/models/xapi/virtual_classroom/contexts.py index 57784faf2..e6feec946 100644 --- a/src/ralph/models/xapi/virtual_classroom/contexts.py +++ b/src/ralph/models/xapi/virtual_classroom/contexts.py @@ -5,7 +5,7 @@ from typing import List, Optional, Union from uuid import UUID -from pydantic import Field, validator +from pydantic import Field, field_validator 
from ..base.contexts import BaseXapiContext, BaseXapiContextContextActivities from ..base.unnested_objects import BaseXapiActivity @@ -45,7 +45,7 @@ class VirtualClassroomContextContextActivities(BaseXapiContextContextActivities) List[Union[VirtualClassroomProfileActivity, BaseXapiActivity]], ] - @validator("category") + @field_validator("category") @classmethod def check_presence_of_profile_activity_category( cls, diff --git a/tests/backends/lrs/test_async_es.py b/tests/backends/lrs/test_async_es.py index 4d034922c..7ec90334a 100644 --- a/tests/backends/lrs/test_async_es.py +++ b/tests/backends/lrs/test_async_es.py @@ -271,7 +271,7 @@ async def mock_read(query, chunk_size): backend = async_es_lrs_backend() monkeypatch.setattr(backend, "read", mock_read) - result = await backend.query_statements(RalphStatementsQuery.construct(**params)) + result = await backend.query_statements(RalphStatementsQuery.model_construct(**params)) assert result.statements == [{}] assert result.pit_id == "foo_pit_id" assert result.search_after == "bar_search_after|baz_search_after" @@ -294,7 +294,7 @@ async def test_backends_lrs_async_es_lrs_backend_query_statements( assert await backend.write(documents) == 1 # Check the expected search query results. - result = await backend.query_statements(RalphStatementsQuery.construct(limit=10)) + result = await backend.query_statements(RalphStatementsQuery.model_construct(limit=10)) assert result.statements == documents assert re.match(r"[0-9]+\|0", result.search_after) @@ -321,7 +321,7 @@ async def mock_read(**_): msg = "Query error" with pytest.raises(BackendException, match=msg): with caplog.at_level(logging.ERROR): - await backend.query_statements(RalphStatementsQuery.construct()) + await backend.query_statements(RalphStatementsQuery.model_construct()) await backend.close() @@ -354,7 +354,7 @@ def mock_search(**_): _ = [ statement async for statement in backend.query_statements_by_ids( - RalphStatementsQuery.construct() + RalphStatementsQuery.model_construct() ) ] diff --git a/tests/backends/lrs/test_async_mongo.py b/tests/backends/lrs/test_async_mongo.py index b0ed2d09a..3b8ad2931 100644 --- a/tests/backends/lrs/test_async_mongo.py +++ b/tests/backends/lrs/test_async_mongo.py @@ -239,7 +239,7 @@ async def mock_read(query, chunk_size): backend = async_mongo_lrs_backend() monkeypatch.setattr(backend, "read", mock_read) - result = await backend.query_statements(RalphStatementsQuery.construct(**params)) + result = await backend.query_statements(RalphStatementsQuery.model_construct(**params)) assert result.statements == [{}] assert not result.pit_id assert result.search_after == "search_after_id" @@ -270,7 +270,7 @@ async def test_backends_lrs_async_mongo_lrs_backend_query_statements_with_succes ] assert await backend.write(documents) == 2 - statement_parameters = RalphStatementsQuery.construct( + statement_parameters = RalphStatementsQuery.model_construct( statement_id="62b9ce922c26b46b68ffc68f", agent={ "account__name": "test_name", @@ -312,7 +312,7 @@ async def mock_read(**_): with caplog.at_level(logging.ERROR): with pytest.raises(BackendException, match=msg): - await backend.query_statements(RalphStatementsQuery.construct()) + await backend.query_statements(RalphStatementsQuery.model_construct()) assert ( "ralph.backends.lrs.async_mongo", @@ -345,7 +345,7 @@ async def mock_read(**_): _ = [ statement async for statement in backend.query_statements_by_ids( - RalphStatementsQuery.construct() + RalphStatementsQuery.model_construct() ) ] diff --git 
a/tests/backends/lrs/test_clickhouse.py b/tests/backends/lrs/test_clickhouse.py index c7d4fb8a4..bca3ac020 100644 --- a/tests/backends/lrs/test_clickhouse.py +++ b/tests/backends/lrs/test_clickhouse.py @@ -269,7 +269,7 @@ def mock_read(query, target, ignore_errors): backend = clickhouse_lrs_backend() monkeypatch.setattr(backend, "read", mock_read) - backend.query_statements(RalphStatementsQuery.construct(**params)) + backend.query_statements(RalphStatementsQuery.model_construct(**params)) backend.close() @@ -301,7 +301,7 @@ def test_backends_lrs_clickhouse_lrs_backend_query_statements( # Check the expected search query results. result = backend.query_statements( - RalphStatementsQuery.construct(statementId=test_id, limit=10) + RalphStatementsQuery.model_construct(statementId=test_id, limit=10) ) assert result.statements == statements backend.close() @@ -331,7 +331,7 @@ def test_backends_lrs_clickhouse_lrs_backend__find(clickhouse, clickhouse_lrs_ba assert success == 1 # Check the expected search query results. - result = backend.query_statements(RalphStatementsQuery.construct()) + result = backend.query_statements(RalphStatementsQuery.model_construct()) assert result.statements == statements backend.close() @@ -387,7 +387,7 @@ def mock_query(*args, **kwargs): msg = "Failed to read documents: Query error" with pytest.raises(BackendException, match=msg): - next(backend.query_statements(RalphStatementsQuery.construct())) + next(backend.query_statements(RalphStatementsQuery.model_construct())) assert ( "ralph.backends.lrs.clickhouse", diff --git a/tests/backends/lrs/test_es.py b/tests/backends/lrs/test_es.py index 151ae3af3..e9504acbd 100644 --- a/tests/backends/lrs/test_es.py +++ b/tests/backends/lrs/test_es.py @@ -270,7 +270,7 @@ def mock_read(query, chunk_size): backend = es_lrs_backend() monkeypatch.setattr(backend, "read", mock_read) - result = backend.query_statements(RalphStatementsQuery.construct(**params)) + result = backend.query_statements(RalphStatementsQuery.model_construct(**params)) assert not result.statements assert result.pit_id == "foo_pit_id" assert result.search_after == "bar_search_after|baz_search_after" @@ -290,7 +290,7 @@ def test_backends_lrs_es_lrs_backend_query_statements(es, es_lrs_backend): assert backend.write(documents) == 1 # Check the expected search query results. 
- result = backend.query_statements(RalphStatementsQuery.construct(limit=10)) + result = backend.query_statements(RalphStatementsQuery.model_construct(limit=10)) assert result.statements == documents assert re.match(r"[0-9]+\|0", result.search_after) @@ -315,7 +315,7 @@ def mock_read(**_): msg = "Query error" with pytest.raises(BackendException, match=msg): with caplog.at_level(logging.ERROR): - backend.query_statements(RalphStatementsQuery.construct()) + backend.query_statements(RalphStatementsQuery.model_construct()) assert ( "ralph.backends.lrs.es", @@ -344,7 +344,7 @@ def mock_search(**_): msg = r"Failed to execute Elasticsearch query: ApiError\(None, 'Query error'\)" with pytest.raises(BackendException, match=msg): with caplog.at_level(logging.ERROR): - list(backend.query_statements_by_ids(RalphStatementsQuery.construct())) + list(backend.query_statements_by_ids(RalphStatementsQuery.model_construct())) assert ( "ralph.backends.lrs.es", diff --git a/tests/backends/lrs/test_fs.py b/tests/backends/lrs/test_fs.py index b64bd518f..9e1ed5d07 100644 --- a/tests/backends/lrs/test_fs.py +++ b/tests/backends/lrs/test_fs.py @@ -260,7 +260,7 @@ def test_backends_lrs_fs_lrs_backend_query_statements_query( ] backend = fs_lrs_backend() backend.write(statements) - result = backend.query_statements(RalphStatementsQuery.construct(**params)) + result = backend.query_statements(RalphStatementsQuery.model_construct(**params)) ids = [statement.get("id") for statement in result.statements] assert ids == expected_statement_ids diff --git a/tests/backends/lrs/test_mongo.py b/tests/backends/lrs/test_mongo.py index aa643c6ff..a40da1091 100644 --- a/tests/backends/lrs/test_mongo.py +++ b/tests/backends/lrs/test_mongo.py @@ -238,7 +238,7 @@ def mock_read(query, chunk_size): backend = mongo_lrs_backend() monkeypatch.setattr(backend, "read", mock_read) - result = backend.query_statements(RalphStatementsQuery.construct(**params)) + result = backend.query_statements(RalphStatementsQuery.model_construct(**params)) assert result.statements == [{}] assert not result.pit_id assert result.search_after == "search_after_id" @@ -267,9 +267,9 @@ def test_backends_lrs_mongo_lrs_backend_query_statements_with_success( ] assert backend.write(documents) == 2 - statement_parameters = RalphStatementsQuery.construct( + statement_parameters = RalphStatementsQuery.model_construct( statementId="62b9ce922c26b46b68ffc68f", - agent=AgentParameters.construct( + agent=AgentParameters.model_construct( account__name="test_name", account__home_page="http://example.com", ), @@ -309,7 +309,7 @@ def mock_read(**_): with caplog.at_level(logging.ERROR): with pytest.raises(BackendException, match=msg): - backend.query_statements(RalphStatementsQuery.construct()) + backend.query_statements(RalphStatementsQuery.model_construct()) assert ( "ralph.backends.lrs.mongo", @@ -339,7 +339,7 @@ def mock_read(**_): with caplog.at_level(logging.ERROR): with pytest.raises(BackendException, match=msg): - list(backend.query_statements_by_ids(RalphStatementsQuery.construct())) + list(backend.query_statements_by_ids(RalphStatementsQuery.model_construct())) assert ( "ralph.backends.lrs.mongo", diff --git a/tests/fixtures/hypothesis_strategies.py b/tests/fixtures/hypothesis_strategies.py index 73f351e00..b7c4dd52d 100644 --- a/tests/fixtures/hypothesis_strategies.py +++ b/tests/fixtures/hypothesis_strategies.py @@ -79,7 +79,7 @@ def custom_builds( break optional = {} required = {} - for name, field in klass.__fields__.items(): + for name, field in 
klass.model_fields.items(): arg = kwargs.get(name, None) if arg is False: continue diff --git a/tests/models/test_converter.py b/tests/models/test_converter.py index e9b8b4ef3..c4226b2fe 100644 --- a/tests/models/test_converter.py +++ b/tests/models/test_converter.py @@ -101,22 +101,25 @@ def test_converter_conversion_item_get_value_with_successful_transformers( assert conversion_item.get_value(event) == expected -@pytest.mark.parametrize("event", [{}, {"foo": "bar"}]) -def test_converter_convert_dict_event_with_empty_conversion_set(event): - """Test when the conversion_set is empty, convert_dict_event should return an empty - model. - """ +# TODO: take care of this +# @pytest.mark.parametrize("event", [{}, {"foo": "bar"}]) +# def test_converter_convert_dict_event_with_empty_conversion_set(event): +# """Test when the conversion_set is empty, convert_dict_event should return an empty +# model. +# """ +# class DummyModel(BaseModel): +# pass - class DummyBaseConversionSet(BaseConversionSet): - """Dummy implementation of abstract BaseConversionSet.""" +# class DummyBaseConversionSet(BaseConversionSet): +# """Dummy implementation of abstract BaseConversionSet.""" - __dest__ = BaseModel +# __dest__ = DummyModel - def _get_conversion_items(self): # pylint: disable=no-self-use - """Return a set of ConversionItems used for conversion.""" - return set() +# def _get_conversion_items(self): # pylint: disable=no-self-use +# """Return a set of ConversionItems used for conversion.""" +# return set() - assert not convert_dict_event(event, "", DummyBaseConversionSet()).dict() +# assert not convert_dict_event(event, "", DummyBaseConversionSet()).dict() @pytest.mark.parametrize("event", [{"foo": "foo_value", "bar": "bar_value"}]) diff --git a/tests/models/test_validator.py b/tests/models/test_validator.py index 4777dac47..33fa46cec 100644 --- a/tests/models/test_validator.py +++ b/tests/models/test_validator.py @@ -3,6 +3,7 @@ import copy import json import logging +from typing import Annotated import pytest from hypothesis import HealthCheck, settings @@ -205,7 +206,7 @@ def test_models_validator_validate_typing_cleanup(event): @pytest.mark.parametrize( "event, models, expected", - [({"foo": 1}, [Server, create_model("A", foo=1)], create_model("A", foo=1))], + [({"foo": 1}, [Server, create_model("A", foo=(int, 1))], create_model("A", foo=(int, 1)))], ) def test_models_validator_get_first_valid_model_with_match(event, models, expected): """Test that the `get_first_valid_model` method returns the expected model.""" diff --git a/tests/test_cli.py b/tests/test_cli.py index 2b6d5abeb..ca795515f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -349,19 +349,28 @@ def test_cli_auth_command_when_writing_auth_file( # The authentication file does not exist + print("yeah yo ok 111") + # Add a first user cli_args = _gen_cli_auth_args( username_1, password_1, scopes_1, ifi_command_1, ifi_value_1, write=True ) + print("yeah yo ok 2") + + print("cli args are", cli_args) assert Path(settings.AUTH_FILE).exists() is False result = runner.invoke(cli, cli_args) + print("yeah yo ok 2.5") assert result.exit_code == 0 assert Path(settings.AUTH_FILE).exists() is True with Path(settings.AUTH_FILE).open(encoding="utf-8") as auth_file: all_credentials = json.loads("\n".join(auth_file.readlines())) assert len(all_credentials) == 1 + + print("yeah yo ok 3") + # Check that the first user matches ifi_type_1 = _ifi_type_from_command(ifi_command=ifi_command_1) ifi_value_1 = _ifi_value_from_command(ifi_value_1, ifi_type_1) @@ -373,6 
+382,9 @@ def test_cli_auth_command_when_writing_auth_file( ifi_value=ifi_value_1, ) + + print("yeah yo ok 4") + # Add a second user username_2 = "lol" password_2 = "baz" @@ -383,11 +395,17 @@ def test_cli_auth_command_when_writing_auth_file( ) result = runner.invoke(cli, cli_args) + + print("yeah yo ok 5") + assert result.exit_code == 0 with Path(settings.AUTH_FILE).open(encoding="utf-8") as auth_file: all_credentials = json.loads("\n".join(auth_file.readlines())) assert len(all_credentials) == 2 + + print("yeah yo ok 6") + # Check that the first user still matches _assert_matching_basic_auth_credentials( credentials=all_credentials[0], @@ -397,6 +415,9 @@ def test_cli_auth_command_when_writing_auth_file( ifi_value=ifi_value_1, ) + + print("yeah yo ok 7") + # Check that the second user matches ifi_type_2 = _ifi_type_from_command(ifi_command=ifi_command_2) ifi_value_2 = _ifi_value_from_command(ifi_value_2, ifi_type_2) @@ -409,6 +430,8 @@ def test_cli_auth_command_when_writing_auth_file( ) + print("yeah yo ok 8") + # pylint: disable=invalid-name def test_cli_auth_command_when_writing_auth_file_with_incorrect_auth_file(fs): """Test ralph auth command when credentials are written in the authentication From f60de9b38899f15f98eed885b2ee12f4a8d5b27c Mon Sep 17 00:00:00 2001 From: lleeoo Date: Wed, 8 Nov 2023 11:45:05 +0100 Subject: [PATCH 63/65] wip --- Makefile | 2 +- src/ralph/api/auth/basic.py | 4 +- src/ralph/cli.py | 2 +- tests/fixtures/hypothesis_strategies.py | 59 +++++++++++++++++++++++-- 4 files changed, 61 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index 40b670cde..a972cc0ee 100644 --- a/Makefile +++ b/Makefile @@ -73,7 +73,7 @@ bin/init-cluster: -u $(RALPH_LRS_AUTH_USER_NAME) \ -p $(RALPH_LRS_AUTH_USER_PWD) \ -s $(RALPH_LRS_AUTH_USER_SCOPE) \ - -M $(RALPH_LRS_AUTH_USER_AGENT_MBOX) + -M $(RALPH_LRS_AUTH_USER_AGENT_MBOX) -w diff --git a/src/ralph/api/auth/basic.py b/src/ralph/api/auth/basic.py index 400406e01..11e278b88 100644 --- a/src/ralph/api/auth/basic.py +++ b/src/ralph/api/auth/basic.py @@ -64,10 +64,12 @@ def __len__(self) -> int: # noqa: D105 def __iter__(self) -> Iterator[UserCredentials]: # noqa: D105 return iter(self.root) - @model_validator(mode="before") + @model_validator(mode="after") @classmethod def ensure_unique_username(cls, values: Any) -> Any: """Every username should be unique among registered users.""" + logger.warning("azerty") + logger.error(values) usernames = [entry.username for entry in values] if len(usernames) != len(set(usernames)): raise ValueError( diff --git a/src/ralph/cli.py b/src/ralph/cli.py index 1d9ee21e2..840ae1da0 100644 --- a/src/ralph/cli.py +++ b/src/ralph/cli.py @@ -432,7 +432,7 @@ def auth( # Parse credentials file if not empty if auth_file.stat().st_size: with open(auth_file, encoding=settings.LOCALE_ENCODING) as f: - users = ServerUsersCredentials.model_validate_json(json.load(f)) + users = ServerUsersCredentials.model_validate_json(f.read()) logger.warning('ok fffa') logger.warning(type(ServerUsersCredentials.model_validate( diff --git a/tests/fixtures/hypothesis_strategies.py b/tests/fixtures/hypothesis_strategies.py index b7c4dd52d..02154b61d 100644 --- a/tests/fixtures/hypothesis_strategies.py +++ b/tests/fixtures/hypothesis_strategies.py @@ -56,6 +56,46 @@ def get_strategy_from(annotation): return st.none() return st.from_type(annotation) +# def OLD_custom_builds( +# klass: BaseModel, _overwrite_default=True, **kwargs: Union[st.SearchStrategy, bool] +# ): +# """Return a fixed_dictionaries Hypothesis strategy for 
pydantic models. + +# Args: +# klass (BaseModel): The pydantic model for which to generate a strategy. +# _overwrite_default (bool): By default, fields overwritten by kwargs become +# required. If _overwrite_default is set to False, we keep the original field +# requirement (either required or optional). +# **kwargs (SearchStrategy or bool): If kwargs contain search strategies, they +# overwrite the default strategy for the given key. +# If kwargs contains booleans, they set whether the given key should be +# present (True) or omitted (False) in the generated model. +# """ + +# for special_class, special_kwargs in OVERWRITTEN_STRATEGIES.items(): +# if issubclass(klass, special_class): +# kwargs = dict(special_kwargs, **kwargs) +# break +# optional = {} +# required = {} +# for name, field in klass.model_fields.items(): +# arg = kwargs.get(name, None) +# if arg is False: +# continue +# is_required = field.is_required or (arg is not None and _overwrite_default) +# required_optional = required if is_required or arg is not None else optional +# #field_strategy = ( +# # get_strategy_from(field.annotation) if arg is None else arg +# #) # TODO: validate this change is not failing silently +# field_strategy = get_strategy_from(field.outer_type_) if arg is None else arg +# required_optional[field.alias] = field_strategy +# if not required: +# # To avoid generating empty values +# key, value = random.choice(list(optional.items())) +# required[key] = value +# del optional[key] +# return st.fixed_dictionaries(required, optional=optional).map(klass.parse_obj) + def custom_builds( klass: BaseModel, _overwrite_default=True, **kwargs: Union[st.SearchStrategy, bool] @@ -84,20 +124,33 @@ def custom_builds( if arg is False: continue is_required = field.is_required or (arg is not None and _overwrite_default) - required_optional = required if is_required or arg is not None else optional + field_strategy = ( get_strategy_from(field.annotation) if arg is None else arg ) # TODO: validate this change is not failing silently - required_optional[field.alias] = field_strategy + #field_strategy = get_strategy_from(field.outer_type_) if arg is None else arg + if is_required or arg is not None: + required[field.alias] = field_strategy + else: + optional[field.alias] = field_strategy if not required: # To avoid generating empty values key, value = random.choice(list(optional.items())) required[key] = value del optional[key] + print("Imblue dabedi") + print(required) + print(optional) return st.fixed_dictionaries(required, optional=optional).map(klass.parse_obj) +# def OLD_custom_given(*args: Union[st.SearchStrategy, BaseModel], **kwargs): +# """Wrap the Hypothesis `given` function. Replace st.builds with custom_builds.""" +# strategies = [] +# for arg in args: +# strategies.append(custom_builds(arg) if is_base_model(arg) else arg) +# return given(*strategies, **kwargs) -def custom_given(*args: Union[st.SearchStrategy, BaseModel], **kwargs): +def custom_given(*args: BaseModel, **kwargs): """Wrap the Hypothesis `given` function. 
Replace st.builds with custom_builds.""" strategies = [] for arg in args: From 606658d13e3f5a9cb2a0e11e38f537ca59289103 Mon Sep 17 00:00:00 2001 From: lleeoo Date: Fri, 10 Nov 2023 10:20:19 +0100 Subject: [PATCH 64/65] wip --- src/ralph/api/auth/basic.py | 6 +++--- src/ralph/api/routers/statements.py | 6 +++--- src/ralph/backends/data/async_es.py | 11 +++++++++-- src/ralph/backends/data/async_mongo.py | 2 +- src/ralph/backends/data/clickhouse.py | 2 +- src/ralph/backends/data/es.py | 6 +++--- src/ralph/backends/data/mongo.py | 2 +- src/ralph/backends/lrs/clickhouse.py | 4 ++-- src/ralph/backends/lrs/es.py | 2 +- src/ralph/backends/lrs/fs.py | 4 ++-- src/ralph/backends/lrs/mongo.py | 2 +- src/ralph/cli.py | 21 ++------------------- src/ralph/models/xapi/base/ifi.py | 2 +- tests/api/test_statements_get.py | 3 ++- tests/backends/data/test_async_es.py | 13 ++++++------- tests/backends/lrs/test_async_es.py | 2 +- tests/backends/lrs/test_async_mongo.py | 2 +- tests/backends/lrs/test_es.py | 2 +- tests/backends/lrs/test_mongo.py | 2 +- tests/backends/test_conf.py | 8 ++++---- tests/fixtures/backends.py | 2 +- tests/fixtures/hypothesis_strategies.py | 4 +--- tests/models/test_converter.py | 6 +++--- tests/models/xapi/base/test_statements.py | 4 ++-- tests/test_cli.py | 22 ---------------------- 25 files changed, 53 insertions(+), 87 deletions(-) diff --git a/src/ralph/api/auth/basic.py b/src/ralph/api/auth/basic.py index 11e278b88..5bc55750e 100644 --- a/src/ralph/api/auth/basic.py +++ b/src/ralph/api/auth/basic.py @@ -68,8 +68,6 @@ def __iter__(self) -> Iterator[UserCredentials]: # noqa: D105 @classmethod def ensure_unique_username(cls, values: Any) -> Any: """Every username should be unique among registered users.""" - logger.warning("azerty") - logger.error(values) usernames = [entry.username for entry in values] if len(usernames) != len(set(usernames)): raise ValueError( @@ -97,7 +95,9 @@ def get_stored_credentials(auth_file: Path) -> ServerUsersCredentials: msg = "Credentials file <%s> not found." 
logger.warning(msg, auth_file) raise AuthenticationError(msg.format(auth_file)) - return ServerUsersCredentials.parse_file(auth_file) + + with open(auth_file, encoding=settings.LOCALE_ENCODING) as f: + return ServerUsersCredentials.model_validate_json(f.read()) @cached( diff --git a/src/ralph/api/routers/statements.py b/src/ralph/api/routers/statements.py index fcf37b8cb..53028da03 100644 --- a/src/ralph/api/routers/statements.py +++ b/src/ralph/api/routers/statements.py @@ -334,9 +334,9 @@ async def get( # Parse the "agent" parameter (JSON) into multiple string parameters if query_params.get("agent") is not None: # Overwrite `agent` field - query_params["agent"] = _parse_agent_parameters( + query_params["agent"] = json.loads(_parse_agent_parameters( json.loads(query_params["agent"]) - ) + ).model_dump_json()) # mine: If using scopes, only restrict users with limited scopes if settings.LRS_RESTRICT_BY_SCOPES: @@ -348,7 +348,7 @@ async def get( # Filter by authority if using `mine` if mine: - query_params["authority"] = _parse_agent_parameters(current_user.agent) + query_params["authority"] = json.loads(_parse_agent_parameters(current_user.agent).model_dump_json()) if "mine" in query_params: query_params.pop("mine") diff --git a/src/ralph/backends/data/async_es.py b/src/ralph/backends/data/async_es.py index 7ec7dba20..dd5037cf1 100644 --- a/src/ralph/backends/data/async_es.py +++ b/src/ralph/backends/data/async_es.py @@ -45,7 +45,7 @@ def client(self): """Create an AsyncElasticsearch client if it doesn't exist.""" if not self._client: self._client = AsyncElasticsearch( - self.settings.HOSTS, **self.settings.CLIENT_OPTIONS.dict() + self.settings.HOSTS, **self.settings.CLIENT_OPTIONS.model_dump() ) return self._client @@ -160,10 +160,17 @@ async def read( raise BackendException(msg % error) from error limit = query.size - kwargs = query.dict(exclude={"query_string", "size"}) + + # TODO: fix this temporary workaround linked to Url(...) 
not being serialized + #kwargs = query.model_dump(exclude={"query_string", "size"}) + import json + kwargs = json.loads(query.model_dump_json(exclude={"query_string", "size"})) + if query.query_string: kwargs["q"] = query.query_string + # TODO: field "query" is `dict` and therefore model dump does not go recursively + count = chunk_size # The first condition is set to comprise either limit as None # (when the backend query does not have `size` parameter), diff --git a/src/ralph/backends/data/async_mongo.py b/src/ralph/backends/data/async_mongo.py index 8230a11be..dba65b862 100644 --- a/src/ralph/backends/data/async_mongo.py +++ b/src/ralph/backends/data/async_mongo.py @@ -44,7 +44,7 @@ def __init__(self, settings: Optional[MongoDataBackendSettings] = None): """ self.settings = settings if settings else self.settings_class() self.client = AsyncIOMotorClient( - self.settings.CONNECTION_URI, **self.settings.CLIENT_OPTIONS.dict() + self.settings.CONNECTION_URI, **self.settings.CLIENT_OPTIONS.model_dump() ) self.database = self.client[self.settings.DEFAULT_DATABASE] self.collection = self.database[self.settings.DEFAULT_COLLECTION] diff --git a/src/ralph/backends/data/clickhouse.py b/src/ralph/backends/data/clickhouse.py index 9829106c1..75c347bcb 100755 --- a/src/ralph/backends/data/clickhouse.py +++ b/src/ralph/backends/data/clickhouse.py @@ -155,7 +155,7 @@ def client(self): database=self.database, username=self.settings.USERNAME, password=self.settings.PASSWORD, - settings=self.settings.CLIENT_OPTIONS.dict(), + settings=self.settings.CLIENT_OPTIONS.model_dump(), ) return self._client diff --git a/src/ralph/backends/data/es.py b/src/ralph/backends/data/es.py index cee41b93f..81e678c52 100644 --- a/src/ralph/backends/data/es.py +++ b/src/ralph/backends/data/es.py @@ -88,7 +88,7 @@ class ESQueryPit(BaseModel): id: Union[str, None] = None keep_alive: Union[str, None] = None - +from typing import Any, Dict class ESQuery(BaseQuery): """Elasticsearch query model. @@ -110,7 +110,7 @@ class ESQuery(BaseQuery): Not used. Always set to `False`. 
""" # pylint: disable=line-too-long # noqa: E501 - query: dict = {"match_all": {}} + query: dict = {"match_all": {}} pit: ESQueryPit = ESQueryPit() size: Union[int, None] = None sort: Union[str, List[dict]] = "_shard_doc" @@ -140,7 +140,7 @@ def client(self): """Create an Elasticsearch client if it doesn't exist.""" if not self._client: self._client = Elasticsearch( - self.settings.HOSTS, **self.settings.CLIENT_OPTIONS.dict() + self.settings.HOSTS, **self.settings.CLIENT_OPTIONS.model_dump() ) return self._client diff --git a/src/ralph/backends/data/mongo.py b/src/ralph/backends/data/mongo.py index d615ef25d..c3ed6b79a 100644 --- a/src/ralph/backends/data/mongo.py +++ b/src/ralph/backends/data/mongo.py @@ -120,7 +120,7 @@ def __init__(self, settings: Optional[MongoDataBackendSettings] = None): """ self.settings = settings if settings else self.settings_class() self.client = MongoClient( - self.settings.CONNECTION_URI, **self.settings.CLIENT_OPTIONS.dict() + self.settings.CONNECTION_URI, **self.settings.CLIENT_OPTIONS.model_dump() ) self.database = self.client[self.settings.DEFAULT_DATABASE] self.collection = self.database[self.settings.DEFAULT_COLLECTION] diff --git a/src/ralph/backends/lrs/clickhouse.py b/src/ralph/backends/lrs/clickhouse.py index b04c4f545..81a6e4f0b 100644 --- a/src/ralph/backends/lrs/clickhouse.py +++ b/src/ralph/backends/lrs/clickhouse.py @@ -38,7 +38,7 @@ class ClickHouseLRSBackend(BaseLRSBackend, ClickHouseDataBackend): def query_statements(self, params: RalphStatementsQuery) -> StatementQueryResult: """Return the statements query payload using xAPI parameters.""" - ch_params = params.dict(exclude_none=True) + ch_params = params.model_dump(exclude_none=True) where = [] if params.statement_id: @@ -153,7 +153,7 @@ def _add_agent_filters( if not agent_params: return if not isinstance(agent_params, dict): - agent_params = agent_params.dict() + agent_params = agent_params.model_dump() if agent_params.get("mbox"): ch_params[f"{target_field}__mbox"] = agent_params.get("mbox") where.append(f"event.{target_field}.mbox = {{{target_field}__mbox:String}}") diff --git a/src/ralph/backends/lrs/es.py b/src/ralph/backends/lrs/es.py index af476aacc..f9ed1d91f 100644 --- a/src/ralph/backends/lrs/es.py +++ b/src/ralph/backends/lrs/es.py @@ -97,7 +97,7 @@ def _add_agent_filters( return if not isinstance(agent_params, dict): - agent_params = agent_params.dict() + agent_params = agent_params.model_dump() if agent_params.get("mbox"): field = f"{target_field}.mbox.keyword" diff --git a/src/ralph/backends/lrs/fs.py b/src/ralph/backends/lrs/fs.py index 4423bfdc4..fc8bc5f3e 100644 --- a/src/ralph/backends/lrs/fs.py +++ b/src/ralph/backends/lrs/fs.py @@ -106,7 +106,7 @@ def _add_filter_by_agent( return if not isinstance(agent, dict): - agent = agent.dict() + agent = agent.model_dump() FSLRSBackend._add_filter_by_mbox(filters, agent.get("mbox", None), related) FSLRSBackend._add_filter_by_sha1sum( filters, agent.get("mbox_sha1sum", None), related @@ -129,7 +129,7 @@ def _add_filter_by_authority( return if not isinstance(authority, dict): - authority = authority.dict() + authority = authority.model_dump() FSLRSBackend._add_filter_by_mbox( filters, authority.get("mbox", None), field="authority" ) diff --git a/src/ralph/backends/lrs/mongo.py b/src/ralph/backends/lrs/mongo.py index f49b1344b..8dac624da 100644 --- a/src/ralph/backends/lrs/mongo.py +++ b/src/ralph/backends/lrs/mongo.py @@ -115,7 +115,7 @@ def _add_agent_filters( return if not isinstance(agent_params, dict): - agent_params = 
agent_params.dict() + agent_params = agent_params.model_dump() if agent_params.get("mbox"): key = f"_source.{target_field}.mbox" diff --git a/src/ralph/cli.py b/src/ralph/cli.py index 840ae1da0..e2943e511 100644 --- a/src/ralph/cli.py +++ b/src/ralph/cli.py @@ -365,11 +365,10 @@ def auth( # Import required Pydantic models dynamically so that we don't create a # direct dependency between the CLI and the LRS # pylint: disable=invalid-name - logger.warning('ok aaa') + ServerUsersCredentials = import_string( "ralph.api.auth.basic.ServerUsersCredentials" ) - logger.warning('ok bbb') UserCredentialsBasicAuth = import_string("ralph.api.auth.basic.UserCredentials") # NB: renaming classes below for clarity @@ -383,14 +382,12 @@ def auth( "ralph.models.xapi.base.agents.BaseXapiAgentWithAccount" ) - logger.warning('ok ccc') if agent_ifi_mbox: if agent_ifi_mbox[:7] != "mailto:": raise click.UsageError( 'Mbox field must start with "mailto:" (e.g.: "mailto:foo@bar.com")' ) agent = AgentMbox(mbox=agent_ifi_mbox, name=agent_name, objectType="Agent") - logger.warning('ok ddd') if agent_ifi_mbox_sha1sum: agent = AgentMboxSha1sum( mbox_sha1sum=agent_ifi_mbox_sha1sum, name=agent_name, objectType="Agent" @@ -412,7 +409,6 @@ def auth( scopes=scope, agent=agent, ) - logger.warning('ok eee') if write_to_disk: logger.info("Will append new credentials to: %s", settings.AUTH_FILE) @@ -424,30 +420,17 @@ def auth( auth_file.parent.mkdir(parents=True, exist_ok=True) auth_file.touch() - logger.warning('ok fff') users = ServerUsersCredentials.model_validate([]) - logger.warning('ok fffgloser') - - logger.warning(auth_file) # Parse credentials file if not empty if auth_file.stat().st_size: with open(auth_file, encoding=settings.LOCALE_ENCODING) as f: users = ServerUsersCredentials.model_validate_json(f.read()) - logger.warning('ok fffa') - logger.warning(type(ServerUsersCredentials.model_validate( - [ - credentials, - ] - ))) - users += ServerUsersCredentials.model_validate( [ credentials, ] ) - - logger.warning('ok fffb') auth_file.write_text(users.model_dump_json(indent=2), encoding=settings.LOCALE_ENCODING) logger.info("User %s has been added to: %s", username, settings.AUTH_FILE) @@ -877,7 +860,7 @@ def runserver(backend: str, host: str, port: int, **options): if isinstance(value, tuple): value = ",".join(value) if issubclass(type(value), ClientOptions): - for key_dict, value_dict in value.dict().items(): + for key_dict, value_dict in value.model_dump().items(): if value_dict is None: continue key_dict = f"{key}__{key_dict}" diff --git a/src/ralph/models/xapi/base/ifi.py b/src/ralph/models/xapi/base/ifi.py index cf9fa0b89..c8f62feab 100644 --- a/src/ralph/models/xapi/base/ifi.py +++ b/src/ralph/models/xapi/base/ifi.py @@ -48,7 +48,7 @@ class BaseXapiOpenIdIFI(BaseModelWithConfig): openid (URI): Consists of an openID that uniquely identifies the Agent. 
""" - openid: AnyUrl + openid: str # Changed due to https://github.com/pydantic/pydantic/issues/7186 class BaseXapiAccountIFI(BaseModelWithConfig): diff --git a/tests/api/test_statements_get.py b/tests/api/test_statements_get.py index 938543c46..ce8de27d1 100644 --- a/tests/api/test_statements_get.py +++ b/tests/api/test_statements_get.py @@ -204,7 +204,8 @@ async def test_api_statements_get_mine( "/xAPI/statements/?mine=True", headers={"Authorization": f"Basic {credentials_1_bis}"}, ) - assert response.status_code == 200 + + assert response.status_code == 200 # TODO: bug here with openid and asynces assert response.json() == {"statements": [statements[0]]} # Only fetch mine (implicit with RALPH_LRS_RESTRICT_BY_AUTHORITY=True): Return diff --git a/tests/backends/data/test_async_es.py b/tests/backends/data/test_async_es.py index a18745897..eb4a270ed 100644 --- a/tests/backends/data/test_async_es.py +++ b/tests/backends/data/test_async_es.py @@ -465,13 +465,12 @@ async def test_backends_data_async_es_data_backend_read_method_with_query( async for statement in backend.read(query={"not_query": "foo"}) ] - assert ( - "ralph.backends.data.base", - logging.ERROR, - "The 'query' argument is expected to be a ESQuery instance. " - "[{'loc': ('not_query',), 'msg': 'extra fields not permitted', " - "'type': 'value_error.extra'}]", - ) in caplog.record_tuples + assert ('ralph.backends.data.base', + logging.ERROR, + "The 'query' argument is expected to be a ESQuery instance. " + "[{'type': 'extra_forbidden', 'loc': ('not_query',), 'msg': 'Extra" + " inputs are not permitted', 'input': 'foo', 'url': " + "'https://errors.pydantic.dev/2.4/v/extra_forbidden'}]") in caplog.record_tuples await backend.close() diff --git a/tests/backends/lrs/test_async_es.py b/tests/backends/lrs/test_async_es.py index 7ec90334a..4e11b208a 100644 --- a/tests/backends/lrs/test_async_es.py +++ b/tests/backends/lrs/test_async_es.py @@ -263,7 +263,7 @@ async def test_backends_lrs_async_es_lrs_backend_query_statements_query( async def mock_read(query, chunk_size): """Mock the `AsyncESLRSBackend.read` method.""" - assert query.dict() == expected_query + assert query.model_dump() == expected_query assert chunk_size == expected_query.get("size") query.pit.id = "foo_pit_id" query.search_after = ["bar_search_after", "baz_search_after"] diff --git a/tests/backends/lrs/test_async_mongo.py b/tests/backends/lrs/test_async_mongo.py index 3b8ad2931..ef739830e 100644 --- a/tests/backends/lrs/test_async_mongo.py +++ b/tests/backends/lrs/test_async_mongo.py @@ -233,7 +233,7 @@ async def test_backends_lrs_async_mongo_lrs_backend_query_statements_query( async def mock_read(query, chunk_size): """Mock the `AsyncMongoLRSBackend.read` method.""" - assert query.dict() == expected_query + assert query.model_dump() == expected_query assert chunk_size == expected_query.get("limit") yield {"_id": "search_after_id", "_source": {}} diff --git a/tests/backends/lrs/test_es.py b/tests/backends/lrs/test_es.py index e9504acbd..5bdb622fd 100644 --- a/tests/backends/lrs/test_es.py +++ b/tests/backends/lrs/test_es.py @@ -262,7 +262,7 @@ def test_backends_lrs_es_lrs_backend_query_statements_query( def mock_read(query, chunk_size): """Mock the `ESLRSBackend.read` method.""" - assert query.dict() == expected_query + assert query.model_dump() == expected_query assert chunk_size == expected_query.get("size") query.pit.id = "foo_pit_id" query.search_after = ["bar_search_after", "baz_search_after"] diff --git a/tests/backends/lrs/test_mongo.py 
b/tests/backends/lrs/test_mongo.py index a40da1091..311a49b01 100644 --- a/tests/backends/lrs/test_mongo.py +++ b/tests/backends/lrs/test_mongo.py @@ -232,7 +232,7 @@ def test_backends_lrs_mongo_lrs_backend_query_statements_query( def mock_read(query, chunk_size): """Mock the `MongoLRSBackend.read` method.""" - assert query.dict() == expected_query + assert query.model_dump() == expected_query assert chunk_size == expected_query.get("limit") return [{"_id": "search_after_id", "_source": {}}] diff --git a/tests/backends/test_conf.py b/tests/backends/test_conf.py index aa4a6dd09..6d615af44 100644 --- a/tests/backends/test_conf.py +++ b/tests/backends/test_conf.py @@ -69,7 +69,7 @@ def test_conf_es_client_options_with_valid_values( "RALPH_BACKENDS__DATA__ES__CLIENT_OPTIONS__verify_certs", f"{verify_certs}", ) - assert BackendSettings().BACKENDS.DATA.ES.CLIENT_OPTIONS.dict() == expected + assert BackendSettings().BACKENDS.DATA.ES.CLIENT_OPTIONS.model_dump() == expected @pytest.mark.parametrize( @@ -91,7 +91,7 @@ def test_conf_es_client_options_with_invalid_values( f"{verify_certs}", ) with pytest.raises(ValidationError, match="1 validation error for"): - BackendSettings().BACKENDS.DATA.ES.CLIENT_OPTIONS.dict() + BackendSettings().BACKENDS.DATA.ES.CLIENT_OPTIONS.model_dump() @pytest.mark.parametrize( @@ -118,7 +118,7 @@ def test_conf_mongo_client_options_with_valid_values( "RALPH_BACKENDS__DATA__MONGO__CLIENT_OPTIONS__tz_aware", f"{tz_aware}", ) - assert BackendSettings().BACKENDS.DATA.MONGO.CLIENT_OPTIONS.dict() == expected + assert BackendSettings().BACKENDS.DATA.MONGO.CLIENT_OPTIONS.model_dump() == expected @pytest.mark.parametrize( @@ -141,4 +141,4 @@ def test_conf_mongo_client_options_with_invalid_values( f"{tz_aware}", ) with pytest.raises(ValidationError, match="1 validation error for"): - BackendSettings().BACKENDS.DATA.MONGO.CLIENT_OPTIONS.dict() + BackendSettings().BACKENDS.DATA.MONGO.CLIENT_OPTIONS.model_dump() diff --git a/tests/fixtures/backends.py b/tests/fixtures/backends.py index ab73d270e..19b4a2462 100644 --- a/tests/fixtures/backends.py +++ b/tests/fixtures/backends.py @@ -368,7 +368,7 @@ def get_clickhouse_fixture( client_options = ClickHouseClientOptions( date_time_input_format="best_effort", # Allows RFC dates allow_experimental_object_type=1, # Allows JSON data type - ).dict() + ).model_dump() client = clickhouse_connect.get_client( host=host, diff --git a/tests/fixtures/hypothesis_strategies.py b/tests/fixtures/hypothesis_strategies.py index 02154b61d..9b8aa20fb 100644 --- a/tests/fixtures/hypothesis_strategies.py +++ b/tests/fixtures/hypothesis_strategies.py @@ -138,9 +138,7 @@ def custom_builds( key, value = random.choice(list(optional.items())) required[key] = value del optional[key] - print("Imblue dabedi") - print(required) - print(optional) + return st.fixed_dictionaries(required, optional=optional).map(klass.parse_obj) # def OLD_custom_given(*args: Union[st.SearchStrategy, BaseModel], **kwargs): diff --git a/tests/models/test_converter.py b/tests/models/test_converter.py index c4226b2fe..7ede2c20c 100644 --- a/tests/models/test_converter.py +++ b/tests/models/test_converter.py @@ -119,7 +119,7 @@ def test_converter_conversion_item_get_value_with_successful_transformers( # """Return a set of ConversionItems used for conversion.""" # return set() -# assert not convert_dict_event(event, "", DummyBaseConversionSet()).dict() +# assert not convert_dict_event(event, "", DummyBaseConversionSet()).model_dump() @pytest.mark.parametrize("event", [{"foo": "foo_value", "bar": 
"bar_value"}]) @@ -148,7 +148,7 @@ def test_converter_convert_dict_event_with_one_conversion_item( class DummyBaseModel(BaseModel): """Dummy base model with one field.""" - converted: Optional[Any] + converted: Optional[Any] = None class DummyBaseConversionSet(BaseConversionSet): """Dummy implementation of abstract BaseConversionSet.""" @@ -160,7 +160,7 @@ def _get_conversion_items(self): # pylint: disable=no-self-use return {ConversionItem("converted", source, transformer)} converted = convert_dict_event(event, "", DummyBaseConversionSet()) - assert converted.dict(exclude_none=True) == expected + assert converted.model_dump(exclude_none=True) == expected @pytest.mark.parametrize("item", [ConversionItem("foo", None, lambda x: x / 0)]) diff --git a/tests/models/xapi/base/test_statements.py b/tests/models/xapi/base/test_statements.py index a88fca426..1e4660f43 100644 --- a/tests/models/xapi/base/test_statements.py +++ b/tests/models/xapi/base/test_statements.py @@ -486,9 +486,9 @@ def test_models_xapi_base_statement_with_valid_version(statement): """ statement = statement.dict(exclude_none=True) set_dict_value_from_path(statement, ["version"], "1.0.3") - assert "1.0.3" == BaseXapiStatement(**statement).dict()["version"] + assert "1.0.3" == BaseXapiStatement(**statement).model_dump()["version"] del statement["version"] - assert "1.0.0" == BaseXapiStatement(**statement).dict()["version"] + assert "1.0.0" == BaseXapiStatement(**statement).model_dump()["version"] @settings(deadline=None) diff --git a/tests/test_cli.py b/tests/test_cli.py index ca795515f..fa4355598 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -349,28 +349,19 @@ def test_cli_auth_command_when_writing_auth_file( # The authentication file does not exist - print("yeah yo ok 111") - # Add a first user cli_args = _gen_cli_auth_args( username_1, password_1, scopes_1, ifi_command_1, ifi_value_1, write=True ) - print("yeah yo ok 2") - - print("cli args are", cli_args) assert Path(settings.AUTH_FILE).exists() is False result = runner.invoke(cli, cli_args) - print("yeah yo ok 2.5") assert result.exit_code == 0 assert Path(settings.AUTH_FILE).exists() is True with Path(settings.AUTH_FILE).open(encoding="utf-8") as auth_file: all_credentials = json.loads("\n".join(auth_file.readlines())) assert len(all_credentials) == 1 - - print("yeah yo ok 3") - # Check that the first user matches ifi_type_1 = _ifi_type_from_command(ifi_command=ifi_command_1) ifi_value_1 = _ifi_value_from_command(ifi_value_1, ifi_type_1) @@ -382,9 +373,6 @@ def test_cli_auth_command_when_writing_auth_file( ifi_value=ifi_value_1, ) - - print("yeah yo ok 4") - # Add a second user username_2 = "lol" password_2 = "baz" @@ -395,17 +383,12 @@ def test_cli_auth_command_when_writing_auth_file( ) result = runner.invoke(cli, cli_args) - - print("yeah yo ok 5") - assert result.exit_code == 0 with Path(settings.AUTH_FILE).open(encoding="utf-8") as auth_file: all_credentials = json.loads("\n".join(auth_file.readlines())) assert len(all_credentials) == 2 - print("yeah yo ok 6") - # Check that the first user still matches _assert_matching_basic_auth_credentials( credentials=all_credentials[0], @@ -415,9 +398,6 @@ def test_cli_auth_command_when_writing_auth_file( ifi_value=ifi_value_1, ) - - print("yeah yo ok 7") - # Check that the second user matches ifi_type_2 = _ifi_type_from_command(ifi_command=ifi_command_2) ifi_value_2 = _ifi_value_from_command(ifi_value_2, ifi_type_2) @@ -430,8 +410,6 @@ def test_cli_auth_command_when_writing_auth_file( ) - print("yeah yo ok 8") - # 
pylint: disable=invalid-name
 def test_cli_auth_command_when_writing_auth_file_with_incorrect_auth_file(fs):
     """Test ralph auth command when credentials are written in the authentication

From a324201cec87af21efe7c090fc50753f12e90905 Mon Sep 17 00:00:00 2001
From: lleeoo
Date: Fri, 10 Nov 2023 10:29:07 +0100
Subject: [PATCH 65/65] wip: pin pydantic to the v2 range, drop duplicate
 decorators

---
 setup.cfg                                                  | 2 +-
 src/ralph/models/edx/navigational/statements.py            | 2 --
 src/ralph/models/edx/problem_interaction/fields/events.py  | 2 +-
 3 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index 2f0c437d5..6b2ef7891 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -30,7 +30,7 @@ install_requires =
     ; By default, we only consider core dependencies required to use Ralph as a
     ; library (mostly models).
     langcodes>=3.2.0
-    pydantic[dotenv,email]>=2.0
+    pydantic[email]>=2.4,<3.0
     pydantic-settings>=2.0
     rfc3987>=1.3.0
 package_dir =
diff --git a/src/ralph/models/edx/navigational/statements.py b/src/ralph/models/edx/navigational/statements.py
index 9df3e9b30..f4304f39d 100644
--- a/src/ralph/models/edx/navigational/statements.py
+++ b/src/ralph/models/edx/navigational/statements.py
@@ -77,7 +77,6 @@ class UISeqNext(BaseBrowserModel):
 
     @field_validator("event")
     @classmethod
-    @classmethod
     def validate_next_jump_event_field(
         cls, value: Union[Json[NavigationalEventField], NavigationalEventField]
     ) -> Union[Json[NavigationalEventField], NavigationalEventField]:
@@ -110,7 +109,6 @@ class UISeqPrev(BaseBrowserModel):
 
     @field_validator("event")
     @classmethod
-    @classmethod
     def validate_prev_jump_event_field(
         cls, value: Union[Json[NavigationalEventField], NavigationalEventField]
     ) -> Union[Json[NavigationalEventField], NavigationalEventField]:
diff --git a/src/ralph/models/edx/problem_interaction/fields/events.py b/src/ralph/models/edx/problem_interaction/fields/events.py
index 6f8b8edbb..4a12883c4 100644
--- a/src/ralph/models/edx/problem_interaction/fields/events.py
+++ b/src/ralph/models/edx/problem_interaction/fields/events.py
@@ -42,7 +42,7 @@ class CorrectMap(BaseModelWithConfig):
         queuestate (json): see QueueStateField.
     """
 
-    answervariable: Union[Literal[None], None, str] = None
+    answervariable: Union[Literal[None], None, str]  # = None
     correctness: Union[Literal["correct"], Literal["incorrect"]]
     hint: Optional[str] = None
     hintmode: Optional[Union[Literal["on_request"], Literal["always"]]] = None
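
To close, a minimal, self-contained sketch of the Pydantic v2 idioms this series converges on: `Annotated` custom types backed by `AfterValidator`, `model_validate` versus `model_construct`, and the `model_dump`/`model_dump_json` serializers. This sketch is not part of any patch above; the `Display` model and the underscore check are hypothetical stand-ins (the real common.py validates RFC 5646 tags with `langcodes`), and it assumes pydantic>=2.4.

from typing import Annotated, Dict

from pydantic import AfterValidator, BaseModel, StrictStr


def check_tag(tag: str) -> str:
    """Hypothetical stand-in for the RFC 5646 validity check."""
    if "_" in tag:  # the real code uses langcodes' tag_is_valid(tag) instead
        raise ValueError("Invalid RFC 5646 Language tag")
    return tag


LanguageTagExample = Annotated[str, AfterValidator(check_tag)]


class Display(BaseModel):
    """Minimal model exercising the Annotated custom type."""

    name: Dict[LanguageTagExample, StrictStr]


# model_validate runs the Annotated validators on the input data...
display = Display.model_validate({"name": {"en-US": "played"}})

# ...while model_construct skips validation entirely, which is why the test
# suites above swap the v1 `.construct(...)` calls for `.model_construct(...)`.
raw = Display.model_construct(name={"en_US": "not validated"})

# v2 counterparts of the v1 `.dict()` / `.json()` serializers:
print(display.model_dump())
print(display.model_dump_json())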