diff --git a/.github/workflows/run-tests-on-push.yml b/.github/workflows/run-tests-on-push.yml
index 6cb7d18e..f07da233 100644
--- a/.github/workflows/run-tests-on-push.yml
+++ b/.github/workflows/run-tests-on-push.yml
@@ -1,6 +1,7 @@
-name: Run Tests (On Push)
+name: Run Tests
 on:
   push:
+    # Run all tests on main, fast tests on other branches

 env:
   LOG_CONFIG: test
@@ -50,7 +51,12 @@ jobs:
       - run: pip install --upgrade pip
       - run: pip install poetry
       - run: poetry install --with dev
-      - run: poetry run pytest tests/
+      - name: Run fast tests on non-main branches
+        if: github.event_name == 'push' && github.ref != 'refs/heads/main'
+        run: poetry run pytest tests/ -m "not network and not slow"
+      - name: Run full tests on main
+        if: github.event_name == 'push' && github.ref == 'refs/heads/main'
+        run: poetry run pytest tests/

   run-tests-3_11:
     runs-on: ubuntu-latest
@@ -66,7 +72,12 @@ jobs:
       - run: pip install --upgrade pip
       - run: pip install poetry
       - run: poetry install --with dev --extras server
-      - run: poetry run pytest tests/ --show-capture=stdout --cov=src
+      - name: Run fast tests on non-main branches
+        if: github.ref != 'refs/heads/main'
+        run: poetry run pytest tests/ -m "not network and not slow" --show-capture=stdout
+      - name: Run all tests with coverage on main branch
+        if: github.ref == 'refs/heads/main'
+        run: poetry run pytest tests/ --show-capture=stdout --cov=src

   run-tests-3_12-core-dependencies:
     runs-on: ubuntu-latest
@@ -80,7 +91,12 @@ jobs:
       - run: pip install --upgrade pip
       - run: pip install poetry
       - run: poetry install --with dev
-      - run: poetry run pytest tests/
+      - name: Run fast tests on non-main branches
+        if: github.ref != 'refs/heads/main'
+        run: poetry run pytest tests/ -m "not network and not slow"
+      - name: Run all tests on main branch
+        if: github.ref == 'refs/heads/main'
+        run: poetry run pytest tests/

   run-tests-3_12:
     runs-on: ubuntu-latest
@@ -96,4 +112,9 @@ jobs:
       - run: pip install --upgrade pip
       - run: pip install poetry
       - run: poetry install --with dev --extras server
-      - run: poetry run pytest tests/ --show-capture=stdout --cov=src
+      - name: Run fast tests on non-main branches
+        if: github.ref != 'refs/heads/main'
+        run: poetry run pytest tests/ -m "not network and not slow" --show-capture=stdout
+      - name: Run all tests with coverage on main branch
+        if: github.ref == 'refs/heads/main'
+        run: poetry run pytest tests/ --show-capture=stdout --cov=src
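All four jobs above implement the same split: pushes to non-main branches run only the fast subset via `-m "not network and not slow"`, while pushes to `main` run everything (with coverage where it was already collected). The split depends on the `network` and `slow` pytest markers being registered; unregistered markers only produce a warning by default and become hard errors under `--strict-markers`. A minimal sketch of that registration in a `conftest.py`, assuming the repository does not already declare these markers in `pyproject.toml` or `pytest.ini` (the marker descriptions here are illustrative, not taken from this PR):

```python
# conftest.py -- illustrative sketch only; the project may instead register
# these markers under [tool.pytest.ini_options] in pyproject.toml.
import pytest


def pytest_configure(config: pytest.Config) -> None:
    # Register the two markers the workflow selects against with
    # `-m "not network and not slow"`.
    config.addinivalue_line("markers", "network: test needs outbound network access")
    config.addinivalue_line("markers", "slow: long-running test, skipped on non-main branches")
```

Individual tests then opt in with `@pytest.mark.network` or `@pytest.mark.slow`, and the non-main CI legs deselect them.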
diff --git a/alembic/versions/8de33cc35cd7_add_pipeline_and_job_tracking_tables.py b/alembic/versions/8de33cc35cd7_add_pipeline_and_job_tracking_tables.py
new file mode 100644
index 00000000..af7eb945
--- /dev/null
+++ b/alembic/versions/8de33cc35cd7_add_pipeline_and_job_tracking_tables.py
@@ -0,0 +1,222 @@
+"""add pipeline and job tracking tables
+
+Revision ID: 8de33cc35cd7
+Revises: dcf8572d3a17
+Create Date: 2026-01-28 10:08:36.906494
+
+"""
+
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "8de33cc35cd7"
+down_revision = "dcf8572d3a17"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table(
+        "pipelines",
+        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
+        sa.Column("urn", sa.String(length=255), nullable=True),
+        sa.Column("name", sa.String(length=500), nullable=False),
+        sa.Column("description", sa.Text(), nullable=True),
+        sa.Column("status", sa.String(length=50), nullable=False),
+        sa.Column("correlation_id", sa.String(length=255), nullable=True),
+        sa.Column(
+            "metadata",
+            postgresql.JSONB(astext_type=sa.Text()),
+            server_default="{}",
+            nullable=False,
+            comment="Flexible metadata storage for pipeline-specific data",
+        ),
+        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
+        sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
+        sa.Column("finished_at", sa.DateTime(timezone=True), nullable=True),
+        sa.Column("created_by_user_id", sa.Integer(), nullable=True),
+        sa.Column("mavedb_version", sa.String(length=50), nullable=True),
+        sa.CheckConstraint(
+            "status IN ('created', 'running', 'succeeded', 'failed', 'cancelled', 'paused', 'partial')",
+            name="ck_pipelines_status_valid",
+        ),
+        sa.ForeignKeyConstraint(["created_by_user_id"], ["users.id"], ondelete="SET NULL"),
+        sa.PrimaryKeyConstraint("id"),
+        sa.UniqueConstraint("urn"),
+    )
+    op.create_index("ix_pipelines_correlation_id", "pipelines", ["correlation_id"], unique=False)
+    op.create_index("ix_pipelines_created_at", "pipelines", ["created_at"], unique=False)
+    op.create_index("ix_pipelines_created_by_user_id", "pipelines", ["created_by_user_id"], unique=False)
+    op.create_index("ix_pipelines_status", "pipelines", ["status"], unique=False)
+    op.create_table(
+        "job_runs",
+        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
+        sa.Column("urn", sa.String(length=255), nullable=True),
+        sa.Column("job_type", sa.String(length=100), nullable=False),
+        sa.Column("job_function", sa.String(length=255), nullable=False),
+        sa.Column("job_params", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
+        sa.Column("status", sa.String(length=50), nullable=False),
+        sa.Column("pipeline_id", sa.Integer(), nullable=True),
+        sa.Column("priority", sa.Integer(), nullable=False),
+        sa.Column("max_retries", sa.Integer(), nullable=False),
+        sa.Column("retry_count", sa.Integer(), nullable=False),
+        sa.Column("retry_delay_seconds", sa.Integer(), nullable=True),
+        sa.Column("scheduled_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
+        sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
+        sa.Column("finished_at", sa.DateTime(timezone=True), nullable=True),
+        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
+        sa.Column("error_message", sa.Text(), nullable=True),
+        sa.Column("error_traceback", sa.Text(), nullable=True),
+        sa.Column("failure_category", sa.String(length=100), nullable=True),
+        sa.Column("progress_current", sa.Integer(), nullable=True),
+        sa.Column("progress_total", sa.Integer(), nullable=True),
+        sa.Column("progress_message", sa.String(length=500), nullable=True),
+        sa.Column("correlation_id", sa.String(length=255), nullable=True),
+        sa.Column("metadata", postgresql.JSONB(astext_type=sa.Text()), server_default="{}", nullable=False),
+        sa.Column("mavedb_version", sa.String(length=50), nullable=True),
+        sa.CheckConstraint(
+            "status IN ('pending', 'queued', 'running', 'succeeded', 'failed', 'cancelled', 'skipped')",
+            name="ck_job_runs_status_valid",
+        ),
name="ck_job_runs_max_retries_positive"), + sa.CheckConstraint("priority >= 0", name="ck_job_runs_priority_positive"), + sa.CheckConstraint("retry_count >= 0", name="ck_job_runs_retry_count_positive"), + sa.ForeignKeyConstraint(["pipeline_id"], ["pipelines.id"], ondelete="SET NULL"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("urn"), + ) + op.create_index("ix_job_runs_correlation_id", "job_runs", ["correlation_id"], unique=False) + op.create_index("ix_job_runs_created_at", "job_runs", ["created_at"], unique=False) + op.create_index("ix_job_runs_job_type", "job_runs", ["job_type"], unique=False) + op.create_index("ix_job_runs_pipeline_id", "job_runs", ["pipeline_id"], unique=False) + op.create_index("ix_job_runs_scheduled_at", "job_runs", ["scheduled_at"], unique=False) + op.create_index("ix_job_runs_status", "job_runs", ["status"], unique=False) + op.create_index("ix_job_runs_status_scheduled", "job_runs", ["status", "scheduled_at"], unique=False) + op.create_table( + "job_dependencies", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("depends_on_job_id", sa.Integer(), nullable=False), + sa.Column("dependency_type", sa.String(length=50), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column("metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.CheckConstraint( + "dependency_type IS NULL OR dependency_type IN ('success_required', 'completion_required')", + name="ck_job_dependencies_type_valid", + ), + sa.ForeignKeyConstraint(["depends_on_job_id"], ["job_runs.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["id"], ["job_runs.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id", "depends_on_job_id"), + ) + op.create_index("ix_job_dependencies_created_at", "job_dependencies", ["created_at"], unique=False) + op.create_index("ix_job_dependencies_depends_on_job_id", "job_dependencies", ["depends_on_job_id"], unique=False) + op.create_table( + "variant_annotation_status", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("variant_id", sa.Integer(), nullable=False), + sa.Column( + "annotation_type", + sa.String(length=50), + nullable=False, + comment="Type of annotation: vrs, clinvar, gnomad, etc.", + ), + sa.Column( + "version", + sa.String(length=50), + nullable=True, + comment="Version of the annotation source used (if applicable)", + ), + sa.Column("status", sa.String(length=50), nullable=False, comment="success, failed, skipped, pending"), + sa.Column("error_message", sa.Text(), nullable=True), + sa.Column("failure_category", sa.String(length=100), nullable=True), + sa.Column( + "success_data", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + comment="Annotation results when successful", + ), + sa.Column( + "current", + sa.Boolean(), + server_default="true", + nullable=False, + comment="Whether this is the current status for the variant and annotation type", + ), + sa.Column("job_run_id", sa.Integer(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.CheckConstraint( + "annotation_type IN ('vrs_mapping', 'clingen_allele_id', 'mapped_hgvs', 'variant_translation', 'gnomad_allele_frequency', 'clinvar_control', 'vep_functional_consequence', 'ldh_submission')", + name="ck_variant_annotation_type_valid", + ), + sa.CheckConstraint("status IN 
+        sa.CheckConstraint("status IN ('success', 'failed', 'skipped')", name="ck_variant_annotation_status_valid"),
+        sa.ForeignKeyConstraint(["job_run_id"], ["job_runs.id"], ondelete="SET NULL"),
+        sa.ForeignKeyConstraint(["variant_id"], ["variants.id"], ondelete="CASCADE"),
+        sa.PrimaryKeyConstraint("id"),
+    )
+    op.create_index(
+        "ix_variant_annotation_status_annotation_type", "variant_annotation_status", ["annotation_type"], unique=False
+    )
+    op.create_index(
+        "ix_variant_annotation_status_created_at", "variant_annotation_status", ["created_at"], unique=False
+    )
+    op.create_index("ix_variant_annotation_status_current", "variant_annotation_status", ["current"], unique=False)
+    op.create_index(
+        "ix_variant_annotation_status_job_run_id", "variant_annotation_status", ["job_run_id"], unique=False
+    )
+    op.create_index("ix_variant_annotation_status_status", "variant_annotation_status", ["status"], unique=False)
+    op.create_index(
+        "ix_variant_annotation_status_variant_id", "variant_annotation_status", ["variant_id"], unique=False
+    )
+    op.create_index(
+        "ix_variant_annotation_status_variant_type_version_current",
+        "variant_annotation_status",
+        ["variant_id", "annotation_type", "version", "current"],
+        unique=False,
+    )
+    op.create_index("ix_variant_annotation_status_version", "variant_annotation_status", ["version"], unique=False)
+    op.create_index(
+        "ix_variant_annotation_type_status", "variant_annotation_status", ["annotation_type", "status"], unique=False
+    )
+    op.create_index(
+        "ix_variant_annotation_variant_type_status",
+        "variant_annotation_status",
+        ["variant_id", "annotation_type", "status"],
+        unique=False,
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_index("ix_variant_annotation_variant_type_status", table_name="variant_annotation_status")
+    op.drop_index("ix_variant_annotation_type_status", table_name="variant_annotation_status")
+    op.drop_index("ix_variant_annotation_status_version", table_name="variant_annotation_status")
+    op.drop_index("ix_variant_annotation_status_variant_type_version_current", table_name="variant_annotation_status")
+    op.drop_index("ix_variant_annotation_status_variant_id", table_name="variant_annotation_status")
+    op.drop_index("ix_variant_annotation_status_status", table_name="variant_annotation_status")
+    op.drop_index("ix_variant_annotation_status_job_run_id", table_name="variant_annotation_status")
+    op.drop_index("ix_variant_annotation_status_current", table_name="variant_annotation_status")
+    op.drop_index("ix_variant_annotation_status_created_at", table_name="variant_annotation_status")
+    op.drop_index("ix_variant_annotation_status_annotation_type", table_name="variant_annotation_status")
+    op.drop_table("variant_annotation_status")
+    op.drop_index("ix_job_dependencies_depends_on_job_id", table_name="job_dependencies")
+    op.drop_index("ix_job_dependencies_created_at", table_name="job_dependencies")
+    op.drop_table("job_dependencies")
+    op.drop_index("ix_job_runs_status_scheduled", table_name="job_runs")
+    op.drop_index("ix_job_runs_status", table_name="job_runs")
+    op.drop_index("ix_job_runs_scheduled_at", table_name="job_runs")
+    op.drop_index("ix_job_runs_pipeline_id", table_name="job_runs")
+    op.drop_index("ix_job_runs_job_type", table_name="job_runs")
+    op.drop_index("ix_job_runs_created_at", table_name="job_runs")
+    op.drop_index("ix_job_runs_correlation_id", table_name="job_runs")
+    op.drop_table("job_runs")
+    op.drop_index("ix_pipelines_status", table_name="pipelines")
+    op.drop_index("ix_pipelines_created_by_user_id", table_name="pipelines")
+    op.drop_index("ix_pipelines_created_at", table_name="pipelines")
+    op.drop_index("ix_pipelines_correlation_id", table_name="pipelines")
+    op.drop_table("pipelines")
+    # ### end Alembic commands ###
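The migration above stores dependency edges in `job_dependencies` as plain data; nothing in this diff defines how a scheduler consumes them. A plausible reading of the two `dependency_type` values, sketched in plain Python (the status and type literals come from the CHECK constraints; the gating rule itself is an assumption based on the names, not code from this PR):

```python
# Illustrative sketch of how a worker *might* gate a job_runs row on its
# job_dependencies edges. Statuses and dependency types mirror the
# migration's CHECK constraints; the interpretation is an assumption.

TERMINAL_STATUSES = {"succeeded", "failed", "cancelled", "skipped"}


def dependency_satisfied(dependency_type: str, dependency_status: str) -> bool:
    """Decide whether one dependency edge allows the dependent job to run."""
    if dependency_type == "success_required":
        # The depended-on job must have finished successfully.
        return dependency_status == "succeeded"
    if dependency_type == "completion_required":
        # The depended-on job merely has to be in a terminal state.
        return dependency_status in TERMINAL_STATUSES
    raise ValueError(f"unknown dependency_type: {dependency_type!r}")


def job_is_runnable(job_status: str, edges: list[tuple[str, str]]) -> bool:
    """edges holds (dependency_type, status of the depended-on job) pairs."""
    return job_status == "pending" and all(
        dependency_satisfied(dep_type, dep_status) for dep_type, dep_status in edges
    )


# A pending job blocked by a still-running hard dependency...
assert not job_is_runnable("pending", [("success_required", "running")])
# ...while a soft dependency is satisfied by any terminal outcome.
assert job_is_runnable("pending", [("completion_required", "failed")])
```

Under this reading, `completion_required` lets fan-in jobs proceed even when an upstream job failed, while `success_required` blocks them until the upstream succeeds.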
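The next two files (below) add a LocalStack container to the dev compose stack and pre-create the `score-set-csv-uploads-dev` bucket via an init script. The diff contains no client-side wiring; a hedged sketch of how application code could target the local endpoint (the URL and dummy credentials are assumed dev values, not settings taken from this PR):

```python
# Illustrative only: pointing boto3 at the LocalStack S3 service defined in
# docker-compose-dev.yml below. Endpoint and credentials are assumptions
# about the dev environment, not values from this diff.
import boto3

s3 = boto3.client(
    "s3",
    endpoint_url="http://localhost:4566",  # LocalStack edge port mapped by compose
    aws_access_key_id="test",  # LocalStack accepts arbitrary credentials
    aws_secret_access_key="test",
    region_name="us-east-1",
)

# bin/localstack-init.sh creates this bucket when the container becomes ready.
s3.put_object(Bucket="score-set-csv-uploads-dev", Key="example.csv", Body=b"hgvs,score\n")
print(s3.list_objects_v2(Bucket="score-set-csv-uploads-dev")["KeyCount"])
```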
op.drop_index("ix_pipelines_created_by_user_id", table_name="pipelines") + op.drop_index("ix_pipelines_created_at", table_name="pipelines") + op.drop_index("ix_pipelines_correlation_id", table_name="pipelines") + op.drop_table("pipelines") + # ### end Alembic commands ### diff --git a/bin/localstack-init.sh b/bin/localstack-init.sh new file mode 100755 index 00000000..1a00cfcb --- /dev/null +++ b/bin/localstack-init.sh @@ -0,0 +1,4 @@ +#!/bin/sh +echo "Initializing local S3 bucket..." +awslocal s3 mb s3://score-set-csv-uploads-dev +echo "S3 bucket 'score-set-csv-uploads-dev' created." \ No newline at end of file diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml index d9d430af..972eb410 100644 --- a/docker-compose-dev.yml +++ b/docker-compose-dev.yml @@ -95,6 +95,18 @@ services: volumes: - mavedb-redis-dev:/data + localstack: + image: localstack/localstack:latest + ports: + - "4566:4566" + env_file: + - settings/.env.dev + environment: + - SERVICES=s3:4566 # We only need S3 for MaveDB + volumes: + - mavedb-localstack-dev:/var/lib/localstack + - "./bin/localstack-init.sh:/etc/localstack/init/ready.d/localstack-init.sh" + seqrepo: image: biocommons/seqrepo:2024-12-20 volumes: @@ -104,3 +116,4 @@ volumes: mavedb-data-dev: mavedb-redis-dev: mavedb-seqrepo-dev: + mavedb-localstack-dev: diff --git a/poetry.lock b/poetry.lock index 18ecdd5e..2bd65bd7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -154,6 +154,21 @@ files = [ {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"}, ] +[[package]] +name = "asyncclick" +version = "8.3.0.7" +description = "Composable command line interface toolkit, async fork" +optional = false +python-versions = ">=3.11" +groups = ["main"] +files = [ + {file = "asyncclick-8.3.0.7-py3-none-any.whl", hash = "sha256:7607046de39a3f315867cad818849f973e29d350c10d92f251db3ff7600c6c7d"}, + {file = "asyncclick-8.3.0.7.tar.gz", hash = "sha256:8a80d8ac613098ee6a9a8f0248f60c66c273e22402cf3f115ed7f071acfc71d3"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + [[package]] name = "attrs" version = "25.3.0" @@ -165,7 +180,6 @@ files = [ {file = "attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3"}, {file = "attrs-25.3.0.tar.gz", hash = "sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b"}, ] -markers = {main = "extra == \"server\""} [package.extras] benchmark = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] @@ -263,7 +277,6 @@ description = "miscellaneous simple bioinformatics utilities and lookup tables" optional = false python-versions = ">=3.10" groups = ["main"] -markers = "extra == \"server\"" files = [ {file = "bioutils-0.6.1-py3-none-any.whl", hash = "sha256:9928297331b9fc0a4fd4235afdef9a80a0916d8b5c2811ab781bded0dad4b9b6"}, {file = "bioutils-0.6.1.tar.gz", hash = "sha256:6ad7a9b6da73beea798a935499339d8b60a434edc37dfc803474d2e93e0e64aa"}, @@ -301,411 +314,441 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "boto3-stubs" -version = "1.34.162" -description = "Type annotations for boto3 1.34.162 generated with mypy-boto3-builder 7.26.0" +version = "1.42.33" 
+description = "Type annotations for boto3 1.42.33 generated with mypy-boto3-builder 8.12.0" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" groups = ["dev"] files = [ - {file = "boto3_stubs-1.34.162-py3-none-any.whl", hash = "sha256:47c651272782a2e894082087eeaeb87a7e809e7e282748560cf39c155031abef"}, - {file = "boto3_stubs-1.34.162.tar.gz", hash = "sha256:6d60b7b9652e1c99f3caba00779e1b94ba7062b0431147a00543af8b1f5252f4"}, + {file = "boto3_stubs-1.42.33-py3-none-any.whl", hash = "sha256:ea2887aaab8b29db446a8260a19069ad8ad614d7a9ffe34ae87b9a2396c7a57e"}, + {file = "boto3_stubs-1.42.33.tar.gz", hash = "sha256:c6b508b3541d48d63892a3eb2a7b36ec4d24435e8cf8233b6ae3f8f2122f0b61"}, ] [package.dependencies] botocore-stubs = "*" +mypy-boto3-s3 = {version = ">=1.42.0,<1.43.0", optional = true, markers = "extra == \"s3\""} types-s3transfer = "*" typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.12\""} [package.extras] -accessanalyzer = ["mypy-boto3-accessanalyzer (>=1.34.0,<1.35.0)"] -account = ["mypy-boto3-account (>=1.34.0,<1.35.0)"] -acm = ["mypy-boto3-acm (>=1.34.0,<1.35.0)"] -acm-pca = ["mypy-boto3-acm-pca (>=1.34.0,<1.35.0)"] -all = ["mypy-boto3-accessanalyzer (>=1.34.0,<1.35.0)", "mypy-boto3-account (>=1.34.0,<1.35.0)", "mypy-boto3-acm (>=1.34.0,<1.35.0)", "mypy-boto3-acm-pca (>=1.34.0,<1.35.0)", "mypy-boto3-amp (>=1.34.0,<1.35.0)", "mypy-boto3-amplify (>=1.34.0,<1.35.0)", "mypy-boto3-amplifybackend (>=1.34.0,<1.35.0)", "mypy-boto3-amplifyuibuilder (>=1.34.0,<1.35.0)", "mypy-boto3-apigateway (>=1.34.0,<1.35.0)", "mypy-boto3-apigatewaymanagementapi (>=1.34.0,<1.35.0)", "mypy-boto3-apigatewayv2 (>=1.34.0,<1.35.0)", "mypy-boto3-appconfig (>=1.34.0,<1.35.0)", "mypy-boto3-appconfigdata (>=1.34.0,<1.35.0)", "mypy-boto3-appfabric (>=1.34.0,<1.35.0)", "mypy-boto3-appflow (>=1.34.0,<1.35.0)", "mypy-boto3-appintegrations (>=1.34.0,<1.35.0)", "mypy-boto3-application-autoscaling (>=1.34.0,<1.35.0)", "mypy-boto3-application-insights (>=1.34.0,<1.35.0)", "mypy-boto3-application-signals (>=1.34.0,<1.35.0)", "mypy-boto3-applicationcostprofiler (>=1.34.0,<1.35.0)", "mypy-boto3-appmesh (>=1.34.0,<1.35.0)", "mypy-boto3-apprunner (>=1.34.0,<1.35.0)", "mypy-boto3-appstream (>=1.34.0,<1.35.0)", "mypy-boto3-appsync (>=1.34.0,<1.35.0)", "mypy-boto3-apptest (>=1.34.0,<1.35.0)", "mypy-boto3-arc-zonal-shift (>=1.34.0,<1.35.0)", "mypy-boto3-artifact (>=1.34.0,<1.35.0)", "mypy-boto3-athena (>=1.34.0,<1.35.0)", "mypy-boto3-auditmanager (>=1.34.0,<1.35.0)", "mypy-boto3-autoscaling (>=1.34.0,<1.35.0)", "mypy-boto3-autoscaling-plans (>=1.34.0,<1.35.0)", "mypy-boto3-b2bi (>=1.34.0,<1.35.0)", "mypy-boto3-backup (>=1.34.0,<1.35.0)", "mypy-boto3-backup-gateway (>=1.34.0,<1.35.0)", "mypy-boto3-batch (>=1.34.0,<1.35.0)", "mypy-boto3-bcm-data-exports (>=1.34.0,<1.35.0)", "mypy-boto3-bedrock (>=1.34.0,<1.35.0)", "mypy-boto3-bedrock-agent (>=1.34.0,<1.35.0)", "mypy-boto3-bedrock-agent-runtime (>=1.34.0,<1.35.0)", "mypy-boto3-bedrock-runtime (>=1.34.0,<1.35.0)", "mypy-boto3-billingconductor (>=1.34.0,<1.35.0)", "mypy-boto3-braket (>=1.34.0,<1.35.0)", "mypy-boto3-budgets (>=1.34.0,<1.35.0)", "mypy-boto3-ce (>=1.34.0,<1.35.0)", "mypy-boto3-chatbot (>=1.34.0,<1.35.0)", "mypy-boto3-chime (>=1.34.0,<1.35.0)", "mypy-boto3-chime-sdk-identity (>=1.34.0,<1.35.0)", "mypy-boto3-chime-sdk-media-pipelines (>=1.34.0,<1.35.0)", "mypy-boto3-chime-sdk-meetings (>=1.34.0,<1.35.0)", "mypy-boto3-chime-sdk-messaging (>=1.34.0,<1.35.0)", "mypy-boto3-chime-sdk-voice (>=1.34.0,<1.35.0)", "mypy-boto3-cleanrooms 
(>=1.34.0,<1.35.0)", "mypy-boto3-cleanroomsml (>=1.34.0,<1.35.0)", "mypy-boto3-cloud9 (>=1.34.0,<1.35.0)", "mypy-boto3-cloudcontrol (>=1.34.0,<1.35.0)", "mypy-boto3-clouddirectory (>=1.34.0,<1.35.0)", "mypy-boto3-cloudformation (>=1.34.0,<1.35.0)", "mypy-boto3-cloudfront (>=1.34.0,<1.35.0)", "mypy-boto3-cloudfront-keyvaluestore (>=1.34.0,<1.35.0)", "mypy-boto3-cloudhsm (>=1.34.0,<1.35.0)", "mypy-boto3-cloudhsmv2 (>=1.34.0,<1.35.0)", "mypy-boto3-cloudsearch (>=1.34.0,<1.35.0)", "mypy-boto3-cloudsearchdomain (>=1.34.0,<1.35.0)", "mypy-boto3-cloudtrail (>=1.34.0,<1.35.0)", "mypy-boto3-cloudtrail-data (>=1.34.0,<1.35.0)", "mypy-boto3-cloudwatch (>=1.34.0,<1.35.0)", "mypy-boto3-codeartifact (>=1.34.0,<1.35.0)", "mypy-boto3-codebuild (>=1.34.0,<1.35.0)", "mypy-boto3-codecatalyst (>=1.34.0,<1.35.0)", "mypy-boto3-codecommit (>=1.34.0,<1.35.0)", "mypy-boto3-codeconnections (>=1.34.0,<1.35.0)", "mypy-boto3-codedeploy (>=1.34.0,<1.35.0)", "mypy-boto3-codeguru-reviewer (>=1.34.0,<1.35.0)", "mypy-boto3-codeguru-security (>=1.34.0,<1.35.0)", "mypy-boto3-codeguruprofiler (>=1.34.0,<1.35.0)", "mypy-boto3-codepipeline (>=1.34.0,<1.35.0)", "mypy-boto3-codestar (>=1.34.0,<1.35.0)", "mypy-boto3-codestar-connections (>=1.34.0,<1.35.0)", "mypy-boto3-codestar-notifications (>=1.34.0,<1.35.0)", "mypy-boto3-cognito-identity (>=1.34.0,<1.35.0)", "mypy-boto3-cognito-idp (>=1.34.0,<1.35.0)", "mypy-boto3-cognito-sync (>=1.34.0,<1.35.0)", "mypy-boto3-comprehend (>=1.34.0,<1.35.0)", "mypy-boto3-comprehendmedical (>=1.34.0,<1.35.0)", "mypy-boto3-compute-optimizer (>=1.34.0,<1.35.0)", "mypy-boto3-config (>=1.34.0,<1.35.0)", "mypy-boto3-connect (>=1.34.0,<1.35.0)", "mypy-boto3-connect-contact-lens (>=1.34.0,<1.35.0)", "mypy-boto3-connectcampaigns (>=1.34.0,<1.35.0)", "mypy-boto3-connectcases (>=1.34.0,<1.35.0)", "mypy-boto3-connectparticipant (>=1.34.0,<1.35.0)", "mypy-boto3-controlcatalog (>=1.34.0,<1.35.0)", "mypy-boto3-controltower (>=1.34.0,<1.35.0)", "mypy-boto3-cost-optimization-hub (>=1.34.0,<1.35.0)", "mypy-boto3-cur (>=1.34.0,<1.35.0)", "mypy-boto3-customer-profiles (>=1.34.0,<1.35.0)", "mypy-boto3-databrew (>=1.34.0,<1.35.0)", "mypy-boto3-dataexchange (>=1.34.0,<1.35.0)", "mypy-boto3-datapipeline (>=1.34.0,<1.35.0)", "mypy-boto3-datasync (>=1.34.0,<1.35.0)", "mypy-boto3-datazone (>=1.34.0,<1.35.0)", "mypy-boto3-dax (>=1.34.0,<1.35.0)", "mypy-boto3-deadline (>=1.34.0,<1.35.0)", "mypy-boto3-detective (>=1.34.0,<1.35.0)", "mypy-boto3-devicefarm (>=1.34.0,<1.35.0)", "mypy-boto3-devops-guru (>=1.34.0,<1.35.0)", "mypy-boto3-directconnect (>=1.34.0,<1.35.0)", "mypy-boto3-discovery (>=1.34.0,<1.35.0)", "mypy-boto3-dlm (>=1.34.0,<1.35.0)", "mypy-boto3-dms (>=1.34.0,<1.35.0)", "mypy-boto3-docdb (>=1.34.0,<1.35.0)", "mypy-boto3-docdb-elastic (>=1.34.0,<1.35.0)", "mypy-boto3-drs (>=1.34.0,<1.35.0)", "mypy-boto3-ds (>=1.34.0,<1.35.0)", "mypy-boto3-dynamodb (>=1.34.0,<1.35.0)", "mypy-boto3-dynamodbstreams (>=1.34.0,<1.35.0)", "mypy-boto3-ebs (>=1.34.0,<1.35.0)", "mypy-boto3-ec2 (>=1.34.0,<1.35.0)", "mypy-boto3-ec2-instance-connect (>=1.34.0,<1.35.0)", "mypy-boto3-ecr (>=1.34.0,<1.35.0)", "mypy-boto3-ecr-public (>=1.34.0,<1.35.0)", "mypy-boto3-ecs (>=1.34.0,<1.35.0)", "mypy-boto3-efs (>=1.34.0,<1.35.0)", "mypy-boto3-eks (>=1.34.0,<1.35.0)", "mypy-boto3-eks-auth (>=1.34.0,<1.35.0)", "mypy-boto3-elastic-inference (>=1.34.0,<1.35.0)", "mypy-boto3-elasticache (>=1.34.0,<1.35.0)", "mypy-boto3-elasticbeanstalk (>=1.34.0,<1.35.0)", "mypy-boto3-elastictranscoder (>=1.34.0,<1.35.0)", "mypy-boto3-elb (>=1.34.0,<1.35.0)", 
"mypy-boto3-elbv2 (>=1.34.0,<1.35.0)", "mypy-boto3-emr (>=1.34.0,<1.35.0)", "mypy-boto3-emr-containers (>=1.34.0,<1.35.0)", "mypy-boto3-emr-serverless (>=1.34.0,<1.35.0)", "mypy-boto3-entityresolution (>=1.34.0,<1.35.0)", "mypy-boto3-es (>=1.34.0,<1.35.0)", "mypy-boto3-events (>=1.34.0,<1.35.0)", "mypy-boto3-evidently (>=1.34.0,<1.35.0)", "mypy-boto3-finspace (>=1.34.0,<1.35.0)", "mypy-boto3-finspace-data (>=1.34.0,<1.35.0)", "mypy-boto3-firehose (>=1.34.0,<1.35.0)", "mypy-boto3-fis (>=1.34.0,<1.35.0)", "mypy-boto3-fms (>=1.34.0,<1.35.0)", "mypy-boto3-forecast (>=1.34.0,<1.35.0)", "mypy-boto3-forecastquery (>=1.34.0,<1.35.0)", "mypy-boto3-frauddetector (>=1.34.0,<1.35.0)", "mypy-boto3-freetier (>=1.34.0,<1.35.0)", "mypy-boto3-fsx (>=1.34.0,<1.35.0)", "mypy-boto3-gamelift (>=1.34.0,<1.35.0)", "mypy-boto3-glacier (>=1.34.0,<1.35.0)", "mypy-boto3-globalaccelerator (>=1.34.0,<1.35.0)", "mypy-boto3-glue (>=1.34.0,<1.35.0)", "mypy-boto3-grafana (>=1.34.0,<1.35.0)", "mypy-boto3-greengrass (>=1.34.0,<1.35.0)", "mypy-boto3-greengrassv2 (>=1.34.0,<1.35.0)", "mypy-boto3-groundstation (>=1.34.0,<1.35.0)", "mypy-boto3-guardduty (>=1.34.0,<1.35.0)", "mypy-boto3-health (>=1.34.0,<1.35.0)", "mypy-boto3-healthlake (>=1.34.0,<1.35.0)", "mypy-boto3-iam (>=1.34.0,<1.35.0)", "mypy-boto3-identitystore (>=1.34.0,<1.35.0)", "mypy-boto3-imagebuilder (>=1.34.0,<1.35.0)", "mypy-boto3-importexport (>=1.34.0,<1.35.0)", "mypy-boto3-inspector (>=1.34.0,<1.35.0)", "mypy-boto3-inspector-scan (>=1.34.0,<1.35.0)", "mypy-boto3-inspector2 (>=1.34.0,<1.35.0)", "mypy-boto3-internetmonitor (>=1.34.0,<1.35.0)", "mypy-boto3-iot (>=1.34.0,<1.35.0)", "mypy-boto3-iot-data (>=1.34.0,<1.35.0)", "mypy-boto3-iot-jobs-data (>=1.34.0,<1.35.0)", "mypy-boto3-iot1click-devices (>=1.34.0,<1.35.0)", "mypy-boto3-iot1click-projects (>=1.34.0,<1.35.0)", "mypy-boto3-iotanalytics (>=1.34.0,<1.35.0)", "mypy-boto3-iotdeviceadvisor (>=1.34.0,<1.35.0)", "mypy-boto3-iotevents (>=1.34.0,<1.35.0)", "mypy-boto3-iotevents-data (>=1.34.0,<1.35.0)", "mypy-boto3-iotfleethub (>=1.34.0,<1.35.0)", "mypy-boto3-iotfleetwise (>=1.34.0,<1.35.0)", "mypy-boto3-iotsecuretunneling (>=1.34.0,<1.35.0)", "mypy-boto3-iotsitewise (>=1.34.0,<1.35.0)", "mypy-boto3-iotthingsgraph (>=1.34.0,<1.35.0)", "mypy-boto3-iottwinmaker (>=1.34.0,<1.35.0)", "mypy-boto3-iotwireless (>=1.34.0,<1.35.0)", "mypy-boto3-ivs (>=1.34.0,<1.35.0)", "mypy-boto3-ivs-realtime (>=1.34.0,<1.35.0)", "mypy-boto3-ivschat (>=1.34.0,<1.35.0)", "mypy-boto3-kafka (>=1.34.0,<1.35.0)", "mypy-boto3-kafkaconnect (>=1.34.0,<1.35.0)", "mypy-boto3-kendra (>=1.34.0,<1.35.0)", "mypy-boto3-kendra-ranking (>=1.34.0,<1.35.0)", "mypy-boto3-keyspaces (>=1.34.0,<1.35.0)", "mypy-boto3-kinesis (>=1.34.0,<1.35.0)", "mypy-boto3-kinesis-video-archived-media (>=1.34.0,<1.35.0)", "mypy-boto3-kinesis-video-media (>=1.34.0,<1.35.0)", "mypy-boto3-kinesis-video-signaling (>=1.34.0,<1.35.0)", "mypy-boto3-kinesis-video-webrtc-storage (>=1.34.0,<1.35.0)", "mypy-boto3-kinesisanalytics (>=1.34.0,<1.35.0)", "mypy-boto3-kinesisanalyticsv2 (>=1.34.0,<1.35.0)", "mypy-boto3-kinesisvideo (>=1.34.0,<1.35.0)", "mypy-boto3-kms (>=1.34.0,<1.35.0)", "mypy-boto3-lakeformation (>=1.34.0,<1.35.0)", "mypy-boto3-lambda (>=1.34.0,<1.35.0)", "mypy-boto3-launch-wizard (>=1.34.0,<1.35.0)", "mypy-boto3-lex-models (>=1.34.0,<1.35.0)", "mypy-boto3-lex-runtime (>=1.34.0,<1.35.0)", "mypy-boto3-lexv2-models (>=1.34.0,<1.35.0)", "mypy-boto3-lexv2-runtime (>=1.34.0,<1.35.0)", "mypy-boto3-license-manager (>=1.34.0,<1.35.0)", "mypy-boto3-license-manager-linux-subscriptions 
(>=1.34.0,<1.35.0)", "mypy-boto3-license-manager-user-subscriptions (>=1.34.0,<1.35.0)", "mypy-boto3-lightsail (>=1.34.0,<1.35.0)", "mypy-boto3-location (>=1.34.0,<1.35.0)", "mypy-boto3-logs (>=1.34.0,<1.35.0)", "mypy-boto3-lookoutequipment (>=1.34.0,<1.35.0)", "mypy-boto3-lookoutmetrics (>=1.34.0,<1.35.0)", "mypy-boto3-lookoutvision (>=1.34.0,<1.35.0)", "mypy-boto3-m2 (>=1.34.0,<1.35.0)", "mypy-boto3-machinelearning (>=1.34.0,<1.35.0)", "mypy-boto3-macie2 (>=1.34.0,<1.35.0)", "mypy-boto3-mailmanager (>=1.34.0,<1.35.0)", "mypy-boto3-managedblockchain (>=1.34.0,<1.35.0)", "mypy-boto3-managedblockchain-query (>=1.34.0,<1.35.0)", "mypy-boto3-marketplace-agreement (>=1.34.0,<1.35.0)", "mypy-boto3-marketplace-catalog (>=1.34.0,<1.35.0)", "mypy-boto3-marketplace-deployment (>=1.34.0,<1.35.0)", "mypy-boto3-marketplace-entitlement (>=1.34.0,<1.35.0)", "mypy-boto3-marketplacecommerceanalytics (>=1.34.0,<1.35.0)", "mypy-boto3-mediaconnect (>=1.34.0,<1.35.0)", "mypy-boto3-mediaconvert (>=1.34.0,<1.35.0)", "mypy-boto3-medialive (>=1.34.0,<1.35.0)", "mypy-boto3-mediapackage (>=1.34.0,<1.35.0)", "mypy-boto3-mediapackage-vod (>=1.34.0,<1.35.0)", "mypy-boto3-mediapackagev2 (>=1.34.0,<1.35.0)", "mypy-boto3-mediastore (>=1.34.0,<1.35.0)", "mypy-boto3-mediastore-data (>=1.34.0,<1.35.0)", "mypy-boto3-mediatailor (>=1.34.0,<1.35.0)", "mypy-boto3-medical-imaging (>=1.34.0,<1.35.0)", "mypy-boto3-memorydb (>=1.34.0,<1.35.0)", "mypy-boto3-meteringmarketplace (>=1.34.0,<1.35.0)", "mypy-boto3-mgh (>=1.34.0,<1.35.0)", "mypy-boto3-mgn (>=1.34.0,<1.35.0)", "mypy-boto3-migration-hub-refactor-spaces (>=1.34.0,<1.35.0)", "mypy-boto3-migrationhub-config (>=1.34.0,<1.35.0)", "mypy-boto3-migrationhuborchestrator (>=1.34.0,<1.35.0)", "mypy-boto3-migrationhubstrategy (>=1.34.0,<1.35.0)", "mypy-boto3-mq (>=1.34.0,<1.35.0)", "mypy-boto3-mturk (>=1.34.0,<1.35.0)", "mypy-boto3-mwaa (>=1.34.0,<1.35.0)", "mypy-boto3-neptune (>=1.34.0,<1.35.0)", "mypy-boto3-neptune-graph (>=1.34.0,<1.35.0)", "mypy-boto3-neptunedata (>=1.34.0,<1.35.0)", "mypy-boto3-network-firewall (>=1.34.0,<1.35.0)", "mypy-boto3-networkmanager (>=1.34.0,<1.35.0)", "mypy-boto3-networkmonitor (>=1.34.0,<1.35.0)", "mypy-boto3-nimble (>=1.34.0,<1.35.0)", "mypy-boto3-oam (>=1.34.0,<1.35.0)", "mypy-boto3-omics (>=1.34.0,<1.35.0)", "mypy-boto3-opensearch (>=1.34.0,<1.35.0)", "mypy-boto3-opensearchserverless (>=1.34.0,<1.35.0)", "mypy-boto3-opsworks (>=1.34.0,<1.35.0)", "mypy-boto3-opsworkscm (>=1.34.0,<1.35.0)", "mypy-boto3-organizations (>=1.34.0,<1.35.0)", "mypy-boto3-osis (>=1.34.0,<1.35.0)", "mypy-boto3-outposts (>=1.34.0,<1.35.0)", "mypy-boto3-panorama (>=1.34.0,<1.35.0)", "mypy-boto3-payment-cryptography (>=1.34.0,<1.35.0)", "mypy-boto3-payment-cryptography-data (>=1.34.0,<1.35.0)", "mypy-boto3-pca-connector-ad (>=1.34.0,<1.35.0)", "mypy-boto3-pca-connector-scep (>=1.34.0,<1.35.0)", "mypy-boto3-personalize (>=1.34.0,<1.35.0)", "mypy-boto3-personalize-events (>=1.34.0,<1.35.0)", "mypy-boto3-personalize-runtime (>=1.34.0,<1.35.0)", "mypy-boto3-pi (>=1.34.0,<1.35.0)", "mypy-boto3-pinpoint (>=1.34.0,<1.35.0)", "mypy-boto3-pinpoint-email (>=1.34.0,<1.35.0)", "mypy-boto3-pinpoint-sms-voice (>=1.34.0,<1.35.0)", "mypy-boto3-pinpoint-sms-voice-v2 (>=1.34.0,<1.35.0)", "mypy-boto3-pipes (>=1.34.0,<1.35.0)", "mypy-boto3-polly (>=1.34.0,<1.35.0)", "mypy-boto3-pricing (>=1.34.0,<1.35.0)", "mypy-boto3-privatenetworks (>=1.34.0,<1.35.0)", "mypy-boto3-proton (>=1.34.0,<1.35.0)", "mypy-boto3-qapps (>=1.34.0,<1.35.0)", "mypy-boto3-qbusiness (>=1.34.0,<1.35.0)", "mypy-boto3-qconnect 
(>=1.34.0,<1.35.0)", "mypy-boto3-qldb (>=1.34.0,<1.35.0)", "mypy-boto3-qldb-session (>=1.34.0,<1.35.0)", "mypy-boto3-quicksight (>=1.34.0,<1.35.0)", "mypy-boto3-ram (>=1.34.0,<1.35.0)", "mypy-boto3-rbin (>=1.34.0,<1.35.0)", "mypy-boto3-rds (>=1.34.0,<1.35.0)", "mypy-boto3-rds-data (>=1.34.0,<1.35.0)", "mypy-boto3-redshift (>=1.34.0,<1.35.0)", "mypy-boto3-redshift-data (>=1.34.0,<1.35.0)", "mypy-boto3-redshift-serverless (>=1.34.0,<1.35.0)", "mypy-boto3-rekognition (>=1.34.0,<1.35.0)", "mypy-boto3-repostspace (>=1.34.0,<1.35.0)", "mypy-boto3-resiliencehub (>=1.34.0,<1.35.0)", "mypy-boto3-resource-explorer-2 (>=1.34.0,<1.35.0)", "mypy-boto3-resource-groups (>=1.34.0,<1.35.0)", "mypy-boto3-resourcegroupstaggingapi (>=1.34.0,<1.35.0)", "mypy-boto3-robomaker (>=1.34.0,<1.35.0)", "mypy-boto3-rolesanywhere (>=1.34.0,<1.35.0)", "mypy-boto3-route53 (>=1.34.0,<1.35.0)", "mypy-boto3-route53-recovery-cluster (>=1.34.0,<1.35.0)", "mypy-boto3-route53-recovery-control-config (>=1.34.0,<1.35.0)", "mypy-boto3-route53-recovery-readiness (>=1.34.0,<1.35.0)", "mypy-boto3-route53domains (>=1.34.0,<1.35.0)", "mypy-boto3-route53profiles (>=1.34.0,<1.35.0)", "mypy-boto3-route53resolver (>=1.34.0,<1.35.0)", "mypy-boto3-rum (>=1.34.0,<1.35.0)", "mypy-boto3-s3 (>=1.34.0,<1.35.0)", "mypy-boto3-s3control (>=1.34.0,<1.35.0)", "mypy-boto3-s3outposts (>=1.34.0,<1.35.0)", "mypy-boto3-sagemaker (>=1.34.0,<1.35.0)", "mypy-boto3-sagemaker-a2i-runtime (>=1.34.0,<1.35.0)", "mypy-boto3-sagemaker-edge (>=1.34.0,<1.35.0)", "mypy-boto3-sagemaker-featurestore-runtime (>=1.34.0,<1.35.0)", "mypy-boto3-sagemaker-geospatial (>=1.34.0,<1.35.0)", "mypy-boto3-sagemaker-metrics (>=1.34.0,<1.35.0)", "mypy-boto3-sagemaker-runtime (>=1.34.0,<1.35.0)", "mypy-boto3-savingsplans (>=1.34.0,<1.35.0)", "mypy-boto3-scheduler (>=1.34.0,<1.35.0)", "mypy-boto3-schemas (>=1.34.0,<1.35.0)", "mypy-boto3-sdb (>=1.34.0,<1.35.0)", "mypy-boto3-secretsmanager (>=1.34.0,<1.35.0)", "mypy-boto3-securityhub (>=1.34.0,<1.35.0)", "mypy-boto3-securitylake (>=1.34.0,<1.35.0)", "mypy-boto3-serverlessrepo (>=1.34.0,<1.35.0)", "mypy-boto3-service-quotas (>=1.34.0,<1.35.0)", "mypy-boto3-servicecatalog (>=1.34.0,<1.35.0)", "mypy-boto3-servicecatalog-appregistry (>=1.34.0,<1.35.0)", "mypy-boto3-servicediscovery (>=1.34.0,<1.35.0)", "mypy-boto3-ses (>=1.34.0,<1.35.0)", "mypy-boto3-sesv2 (>=1.34.0,<1.35.0)", "mypy-boto3-shield (>=1.34.0,<1.35.0)", "mypy-boto3-signer (>=1.34.0,<1.35.0)", "mypy-boto3-simspaceweaver (>=1.34.0,<1.35.0)", "mypy-boto3-sms (>=1.34.0,<1.35.0)", "mypy-boto3-sms-voice (>=1.34.0,<1.35.0)", "mypy-boto3-snow-device-management (>=1.34.0,<1.35.0)", "mypy-boto3-snowball (>=1.34.0,<1.35.0)", "mypy-boto3-sns (>=1.34.0,<1.35.0)", "mypy-boto3-sqs (>=1.34.0,<1.35.0)", "mypy-boto3-ssm (>=1.34.0,<1.35.0)", "mypy-boto3-ssm-contacts (>=1.34.0,<1.35.0)", "mypy-boto3-ssm-incidents (>=1.34.0,<1.35.0)", "mypy-boto3-ssm-quicksetup (>=1.34.0,<1.35.0)", "mypy-boto3-ssm-sap (>=1.34.0,<1.35.0)", "mypy-boto3-sso (>=1.34.0,<1.35.0)", "mypy-boto3-sso-admin (>=1.34.0,<1.35.0)", "mypy-boto3-sso-oidc (>=1.34.0,<1.35.0)", "mypy-boto3-stepfunctions (>=1.34.0,<1.35.0)", "mypy-boto3-storagegateway (>=1.34.0,<1.35.0)", "mypy-boto3-sts (>=1.34.0,<1.35.0)", "mypy-boto3-supplychain (>=1.34.0,<1.35.0)", "mypy-boto3-support (>=1.34.0,<1.35.0)", "mypy-boto3-support-app (>=1.34.0,<1.35.0)", "mypy-boto3-swf (>=1.34.0,<1.35.0)", "mypy-boto3-synthetics (>=1.34.0,<1.35.0)", "mypy-boto3-taxsettings (>=1.34.0,<1.35.0)", "mypy-boto3-textract (>=1.34.0,<1.35.0)", "mypy-boto3-timestream-influxdb 
(>=1.34.0,<1.35.0)", "mypy-boto3-timestream-query (>=1.34.0,<1.35.0)", "mypy-boto3-timestream-write (>=1.34.0,<1.35.0)", "mypy-boto3-tnb (>=1.34.0,<1.35.0)", "mypy-boto3-transcribe (>=1.34.0,<1.35.0)", "mypy-boto3-transfer (>=1.34.0,<1.35.0)", "mypy-boto3-translate (>=1.34.0,<1.35.0)", "mypy-boto3-trustedadvisor (>=1.34.0,<1.35.0)", "mypy-boto3-verifiedpermissions (>=1.34.0,<1.35.0)", "mypy-boto3-voice-id (>=1.34.0,<1.35.0)", "mypy-boto3-vpc-lattice (>=1.34.0,<1.35.0)", "mypy-boto3-waf (>=1.34.0,<1.35.0)", "mypy-boto3-waf-regional (>=1.34.0,<1.35.0)", "mypy-boto3-wafv2 (>=1.34.0,<1.35.0)", "mypy-boto3-wellarchitected (>=1.34.0,<1.35.0)", "mypy-boto3-wisdom (>=1.34.0,<1.35.0)", "mypy-boto3-workdocs (>=1.34.0,<1.35.0)", "mypy-boto3-worklink (>=1.34.0,<1.35.0)", "mypy-boto3-workmail (>=1.34.0,<1.35.0)", "mypy-boto3-workmailmessageflow (>=1.34.0,<1.35.0)", "mypy-boto3-workspaces (>=1.34.0,<1.35.0)", "mypy-boto3-workspaces-thin-client (>=1.34.0,<1.35.0)", "mypy-boto3-workspaces-web (>=1.34.0,<1.35.0)", "mypy-boto3-xray (>=1.34.0,<1.35.0)"] -amp = ["mypy-boto3-amp (>=1.34.0,<1.35.0)"] -amplify = ["mypy-boto3-amplify (>=1.34.0,<1.35.0)"] -amplifybackend = ["mypy-boto3-amplifybackend (>=1.34.0,<1.35.0)"] -amplifyuibuilder = ["mypy-boto3-amplifyuibuilder (>=1.34.0,<1.35.0)"] -apigateway = ["mypy-boto3-apigateway (>=1.34.0,<1.35.0)"] -apigatewaymanagementapi = ["mypy-boto3-apigatewaymanagementapi (>=1.34.0,<1.35.0)"] -apigatewayv2 = ["mypy-boto3-apigatewayv2 (>=1.34.0,<1.35.0)"] -appconfig = ["mypy-boto3-appconfig (>=1.34.0,<1.35.0)"] -appconfigdata = ["mypy-boto3-appconfigdata (>=1.34.0,<1.35.0)"] -appfabric = ["mypy-boto3-appfabric (>=1.34.0,<1.35.0)"] -appflow = ["mypy-boto3-appflow (>=1.34.0,<1.35.0)"] -appintegrations = ["mypy-boto3-appintegrations (>=1.34.0,<1.35.0)"] -application-autoscaling = ["mypy-boto3-application-autoscaling (>=1.34.0,<1.35.0)"] -application-insights = ["mypy-boto3-application-insights (>=1.34.0,<1.35.0)"] -application-signals = ["mypy-boto3-application-signals (>=1.34.0,<1.35.0)"] -applicationcostprofiler = ["mypy-boto3-applicationcostprofiler (>=1.34.0,<1.35.0)"] -appmesh = ["mypy-boto3-appmesh (>=1.34.0,<1.35.0)"] -apprunner = ["mypy-boto3-apprunner (>=1.34.0,<1.35.0)"] -appstream = ["mypy-boto3-appstream (>=1.34.0,<1.35.0)"] -appsync = ["mypy-boto3-appsync (>=1.34.0,<1.35.0)"] -apptest = ["mypy-boto3-apptest (>=1.34.0,<1.35.0)"] -arc-zonal-shift = ["mypy-boto3-arc-zonal-shift (>=1.34.0,<1.35.0)"] -artifact = ["mypy-boto3-artifact (>=1.34.0,<1.35.0)"] -athena = ["mypy-boto3-athena (>=1.34.0,<1.35.0)"] -auditmanager = ["mypy-boto3-auditmanager (>=1.34.0,<1.35.0)"] -autoscaling = ["mypy-boto3-autoscaling (>=1.34.0,<1.35.0)"] -autoscaling-plans = ["mypy-boto3-autoscaling-plans (>=1.34.0,<1.35.0)"] -b2bi = ["mypy-boto3-b2bi (>=1.34.0,<1.35.0)"] -backup = ["mypy-boto3-backup (>=1.34.0,<1.35.0)"] -backup-gateway = ["mypy-boto3-backup-gateway (>=1.34.0,<1.35.0)"] -batch = ["mypy-boto3-batch (>=1.34.0,<1.35.0)"] -bcm-data-exports = ["mypy-boto3-bcm-data-exports (>=1.34.0,<1.35.0)"] -bedrock = ["mypy-boto3-bedrock (>=1.34.0,<1.35.0)"] -bedrock-agent = ["mypy-boto3-bedrock-agent (>=1.34.0,<1.35.0)"] -bedrock-agent-runtime = ["mypy-boto3-bedrock-agent-runtime (>=1.34.0,<1.35.0)"] -bedrock-runtime = ["mypy-boto3-bedrock-runtime (>=1.34.0,<1.35.0)"] -billingconductor = ["mypy-boto3-billingconductor (>=1.34.0,<1.35.0)"] -boto3 = ["boto3 (==1.34.162)", "botocore (==1.34.162)"] -braket = ["mypy-boto3-braket (>=1.34.0,<1.35.0)"] -budgets = ["mypy-boto3-budgets (>=1.34.0,<1.35.0)"] -ce 
= ["mypy-boto3-ce (>=1.34.0,<1.35.0)"] -chatbot = ["mypy-boto3-chatbot (>=1.34.0,<1.35.0)"] -chime = ["mypy-boto3-chime (>=1.34.0,<1.35.0)"] -chime-sdk-identity = ["mypy-boto3-chime-sdk-identity (>=1.34.0,<1.35.0)"] -chime-sdk-media-pipelines = ["mypy-boto3-chime-sdk-media-pipelines (>=1.34.0,<1.35.0)"] -chime-sdk-meetings = ["mypy-boto3-chime-sdk-meetings (>=1.34.0,<1.35.0)"] -chime-sdk-messaging = ["mypy-boto3-chime-sdk-messaging (>=1.34.0,<1.35.0)"] -chime-sdk-voice = ["mypy-boto3-chime-sdk-voice (>=1.34.0,<1.35.0)"] -cleanrooms = ["mypy-boto3-cleanrooms (>=1.34.0,<1.35.0)"] -cleanroomsml = ["mypy-boto3-cleanroomsml (>=1.34.0,<1.35.0)"] -cloud9 = ["mypy-boto3-cloud9 (>=1.34.0,<1.35.0)"] -cloudcontrol = ["mypy-boto3-cloudcontrol (>=1.34.0,<1.35.0)"] -clouddirectory = ["mypy-boto3-clouddirectory (>=1.34.0,<1.35.0)"] -cloudformation = ["mypy-boto3-cloudformation (>=1.34.0,<1.35.0)"] -cloudfront = ["mypy-boto3-cloudfront (>=1.34.0,<1.35.0)"] -cloudfront-keyvaluestore = ["mypy-boto3-cloudfront-keyvaluestore (>=1.34.0,<1.35.0)"] -cloudhsm = ["mypy-boto3-cloudhsm (>=1.34.0,<1.35.0)"] -cloudhsmv2 = ["mypy-boto3-cloudhsmv2 (>=1.34.0,<1.35.0)"] -cloudsearch = ["mypy-boto3-cloudsearch (>=1.34.0,<1.35.0)"] -cloudsearchdomain = ["mypy-boto3-cloudsearchdomain (>=1.34.0,<1.35.0)"] -cloudtrail = ["mypy-boto3-cloudtrail (>=1.34.0,<1.35.0)"] -cloudtrail-data = ["mypy-boto3-cloudtrail-data (>=1.34.0,<1.35.0)"] -cloudwatch = ["mypy-boto3-cloudwatch (>=1.34.0,<1.35.0)"] -codeartifact = ["mypy-boto3-codeartifact (>=1.34.0,<1.35.0)"] -codebuild = ["mypy-boto3-codebuild (>=1.34.0,<1.35.0)"] -codecatalyst = ["mypy-boto3-codecatalyst (>=1.34.0,<1.35.0)"] -codecommit = ["mypy-boto3-codecommit (>=1.34.0,<1.35.0)"] -codeconnections = ["mypy-boto3-codeconnections (>=1.34.0,<1.35.0)"] -codedeploy = ["mypy-boto3-codedeploy (>=1.34.0,<1.35.0)"] -codeguru-reviewer = ["mypy-boto3-codeguru-reviewer (>=1.34.0,<1.35.0)"] -codeguru-security = ["mypy-boto3-codeguru-security (>=1.34.0,<1.35.0)"] -codeguruprofiler = ["mypy-boto3-codeguruprofiler (>=1.34.0,<1.35.0)"] -codepipeline = ["mypy-boto3-codepipeline (>=1.34.0,<1.35.0)"] -codestar = ["mypy-boto3-codestar (>=1.34.0,<1.35.0)"] -codestar-connections = ["mypy-boto3-codestar-connections (>=1.34.0,<1.35.0)"] -codestar-notifications = ["mypy-boto3-codestar-notifications (>=1.34.0,<1.35.0)"] -cognito-identity = ["mypy-boto3-cognito-identity (>=1.34.0,<1.35.0)"] -cognito-idp = ["mypy-boto3-cognito-idp (>=1.34.0,<1.35.0)"] -cognito-sync = ["mypy-boto3-cognito-sync (>=1.34.0,<1.35.0)"] -comprehend = ["mypy-boto3-comprehend (>=1.34.0,<1.35.0)"] -comprehendmedical = ["mypy-boto3-comprehendmedical (>=1.34.0,<1.35.0)"] -compute-optimizer = ["mypy-boto3-compute-optimizer (>=1.34.0,<1.35.0)"] -config = ["mypy-boto3-config (>=1.34.0,<1.35.0)"] -connect = ["mypy-boto3-connect (>=1.34.0,<1.35.0)"] -connect-contact-lens = ["mypy-boto3-connect-contact-lens (>=1.34.0,<1.35.0)"] -connectcampaigns = ["mypy-boto3-connectcampaigns (>=1.34.0,<1.35.0)"] -connectcases = ["mypy-boto3-connectcases (>=1.34.0,<1.35.0)"] -connectparticipant = ["mypy-boto3-connectparticipant (>=1.34.0,<1.35.0)"] -controlcatalog = ["mypy-boto3-controlcatalog (>=1.34.0,<1.35.0)"] -controltower = ["mypy-boto3-controltower (>=1.34.0,<1.35.0)"] -cost-optimization-hub = ["mypy-boto3-cost-optimization-hub (>=1.34.0,<1.35.0)"] -cur = ["mypy-boto3-cur (>=1.34.0,<1.35.0)"] -customer-profiles = ["mypy-boto3-customer-profiles (>=1.34.0,<1.35.0)"] -databrew = ["mypy-boto3-databrew (>=1.34.0,<1.35.0)"] -dataexchange = 
["mypy-boto3-dataexchange (>=1.34.0,<1.35.0)"] -datapipeline = ["mypy-boto3-datapipeline (>=1.34.0,<1.35.0)"] -datasync = ["mypy-boto3-datasync (>=1.34.0,<1.35.0)"] -datazone = ["mypy-boto3-datazone (>=1.34.0,<1.35.0)"] -dax = ["mypy-boto3-dax (>=1.34.0,<1.35.0)"] -deadline = ["mypy-boto3-deadline (>=1.34.0,<1.35.0)"] -detective = ["mypy-boto3-detective (>=1.34.0,<1.35.0)"] -devicefarm = ["mypy-boto3-devicefarm (>=1.34.0,<1.35.0)"] -devops-guru = ["mypy-boto3-devops-guru (>=1.34.0,<1.35.0)"] -directconnect = ["mypy-boto3-directconnect (>=1.34.0,<1.35.0)"] -discovery = ["mypy-boto3-discovery (>=1.34.0,<1.35.0)"] -dlm = ["mypy-boto3-dlm (>=1.34.0,<1.35.0)"] -dms = ["mypy-boto3-dms (>=1.34.0,<1.35.0)"] -docdb = ["mypy-boto3-docdb (>=1.34.0,<1.35.0)"] -docdb-elastic = ["mypy-boto3-docdb-elastic (>=1.34.0,<1.35.0)"] -drs = ["mypy-boto3-drs (>=1.34.0,<1.35.0)"] -ds = ["mypy-boto3-ds (>=1.34.0,<1.35.0)"] -dynamodb = ["mypy-boto3-dynamodb (>=1.34.0,<1.35.0)"] -dynamodbstreams = ["mypy-boto3-dynamodbstreams (>=1.34.0,<1.35.0)"] -ebs = ["mypy-boto3-ebs (>=1.34.0,<1.35.0)"] -ec2 = ["mypy-boto3-ec2 (>=1.34.0,<1.35.0)"] -ec2-instance-connect = ["mypy-boto3-ec2-instance-connect (>=1.34.0,<1.35.0)"] -ecr = ["mypy-boto3-ecr (>=1.34.0,<1.35.0)"] -ecr-public = ["mypy-boto3-ecr-public (>=1.34.0,<1.35.0)"] -ecs = ["mypy-boto3-ecs (>=1.34.0,<1.35.0)"] -efs = ["mypy-boto3-efs (>=1.34.0,<1.35.0)"] -eks = ["mypy-boto3-eks (>=1.34.0,<1.35.0)"] -eks-auth = ["mypy-boto3-eks-auth (>=1.34.0,<1.35.0)"] -elastic-inference = ["mypy-boto3-elastic-inference (>=1.34.0,<1.35.0)"] -elasticache = ["mypy-boto3-elasticache (>=1.34.0,<1.35.0)"] -elasticbeanstalk = ["mypy-boto3-elasticbeanstalk (>=1.34.0,<1.35.0)"] -elastictranscoder = ["mypy-boto3-elastictranscoder (>=1.34.0,<1.35.0)"] -elb = ["mypy-boto3-elb (>=1.34.0,<1.35.0)"] -elbv2 = ["mypy-boto3-elbv2 (>=1.34.0,<1.35.0)"] -emr = ["mypy-boto3-emr (>=1.34.0,<1.35.0)"] -emr-containers = ["mypy-boto3-emr-containers (>=1.34.0,<1.35.0)"] -emr-serverless = ["mypy-boto3-emr-serverless (>=1.34.0,<1.35.0)"] -entityresolution = ["mypy-boto3-entityresolution (>=1.34.0,<1.35.0)"] -es = ["mypy-boto3-es (>=1.34.0,<1.35.0)"] -essential = ["mypy-boto3-cloudformation (>=1.34.0,<1.35.0)", "mypy-boto3-dynamodb (>=1.34.0,<1.35.0)", "mypy-boto3-ec2 (>=1.34.0,<1.35.0)", "mypy-boto3-lambda (>=1.34.0,<1.35.0)", "mypy-boto3-rds (>=1.34.0,<1.35.0)", "mypy-boto3-s3 (>=1.34.0,<1.35.0)", "mypy-boto3-sqs (>=1.34.0,<1.35.0)"] -events = ["mypy-boto3-events (>=1.34.0,<1.35.0)"] -evidently = ["mypy-boto3-evidently (>=1.34.0,<1.35.0)"] -finspace = ["mypy-boto3-finspace (>=1.34.0,<1.35.0)"] -finspace-data = ["mypy-boto3-finspace-data (>=1.34.0,<1.35.0)"] -firehose = ["mypy-boto3-firehose (>=1.34.0,<1.35.0)"] -fis = ["mypy-boto3-fis (>=1.34.0,<1.35.0)"] -fms = ["mypy-boto3-fms (>=1.34.0,<1.35.0)"] -forecast = ["mypy-boto3-forecast (>=1.34.0,<1.35.0)"] -forecastquery = ["mypy-boto3-forecastquery (>=1.34.0,<1.35.0)"] -frauddetector = ["mypy-boto3-frauddetector (>=1.34.0,<1.35.0)"] -freetier = ["mypy-boto3-freetier (>=1.34.0,<1.35.0)"] -fsx = ["mypy-boto3-fsx (>=1.34.0,<1.35.0)"] -gamelift = ["mypy-boto3-gamelift (>=1.34.0,<1.35.0)"] -glacier = ["mypy-boto3-glacier (>=1.34.0,<1.35.0)"] -globalaccelerator = ["mypy-boto3-globalaccelerator (>=1.34.0,<1.35.0)"] -glue = ["mypy-boto3-glue (>=1.34.0,<1.35.0)"] -grafana = ["mypy-boto3-grafana (>=1.34.0,<1.35.0)"] -greengrass = ["mypy-boto3-greengrass (>=1.34.0,<1.35.0)"] -greengrassv2 = ["mypy-boto3-greengrassv2 (>=1.34.0,<1.35.0)"] -groundstation = 
["mypy-boto3-groundstation (>=1.34.0,<1.35.0)"] -guardduty = ["mypy-boto3-guardduty (>=1.34.0,<1.35.0)"] -health = ["mypy-boto3-health (>=1.34.0,<1.35.0)"] -healthlake = ["mypy-boto3-healthlake (>=1.34.0,<1.35.0)"] -iam = ["mypy-boto3-iam (>=1.34.0,<1.35.0)"] -identitystore = ["mypy-boto3-identitystore (>=1.34.0,<1.35.0)"] -imagebuilder = ["mypy-boto3-imagebuilder (>=1.34.0,<1.35.0)"] -importexport = ["mypy-boto3-importexport (>=1.34.0,<1.35.0)"] -inspector = ["mypy-boto3-inspector (>=1.34.0,<1.35.0)"] -inspector-scan = ["mypy-boto3-inspector-scan (>=1.34.0,<1.35.0)"] -inspector2 = ["mypy-boto3-inspector2 (>=1.34.0,<1.35.0)"] -internetmonitor = ["mypy-boto3-internetmonitor (>=1.34.0,<1.35.0)"] -iot = ["mypy-boto3-iot (>=1.34.0,<1.35.0)"] -iot-data = ["mypy-boto3-iot-data (>=1.34.0,<1.35.0)"] -iot-jobs-data = ["mypy-boto3-iot-jobs-data (>=1.34.0,<1.35.0)"] -iot1click-devices = ["mypy-boto3-iot1click-devices (>=1.34.0,<1.35.0)"] -iot1click-projects = ["mypy-boto3-iot1click-projects (>=1.34.0,<1.35.0)"] -iotanalytics = ["mypy-boto3-iotanalytics (>=1.34.0,<1.35.0)"] -iotdeviceadvisor = ["mypy-boto3-iotdeviceadvisor (>=1.34.0,<1.35.0)"] -iotevents = ["mypy-boto3-iotevents (>=1.34.0,<1.35.0)"] -iotevents-data = ["mypy-boto3-iotevents-data (>=1.34.0,<1.35.0)"] -iotfleethub = ["mypy-boto3-iotfleethub (>=1.34.0,<1.35.0)"] -iotfleetwise = ["mypy-boto3-iotfleetwise (>=1.34.0,<1.35.0)"] -iotsecuretunneling = ["mypy-boto3-iotsecuretunneling (>=1.34.0,<1.35.0)"] -iotsitewise = ["mypy-boto3-iotsitewise (>=1.34.0,<1.35.0)"] -iotthingsgraph = ["mypy-boto3-iotthingsgraph (>=1.34.0,<1.35.0)"] -iottwinmaker = ["mypy-boto3-iottwinmaker (>=1.34.0,<1.35.0)"] -iotwireless = ["mypy-boto3-iotwireless (>=1.34.0,<1.35.0)"] -ivs = ["mypy-boto3-ivs (>=1.34.0,<1.35.0)"] -ivs-realtime = ["mypy-boto3-ivs-realtime (>=1.34.0,<1.35.0)"] -ivschat = ["mypy-boto3-ivschat (>=1.34.0,<1.35.0)"] -kafka = ["mypy-boto3-kafka (>=1.34.0,<1.35.0)"] -kafkaconnect = ["mypy-boto3-kafkaconnect (>=1.34.0,<1.35.0)"] -kendra = ["mypy-boto3-kendra (>=1.34.0,<1.35.0)"] -kendra-ranking = ["mypy-boto3-kendra-ranking (>=1.34.0,<1.35.0)"] -keyspaces = ["mypy-boto3-keyspaces (>=1.34.0,<1.35.0)"] -kinesis = ["mypy-boto3-kinesis (>=1.34.0,<1.35.0)"] -kinesis-video-archived-media = ["mypy-boto3-kinesis-video-archived-media (>=1.34.0,<1.35.0)"] -kinesis-video-media = ["mypy-boto3-kinesis-video-media (>=1.34.0,<1.35.0)"] -kinesis-video-signaling = ["mypy-boto3-kinesis-video-signaling (>=1.34.0,<1.35.0)"] -kinesis-video-webrtc-storage = ["mypy-boto3-kinesis-video-webrtc-storage (>=1.34.0,<1.35.0)"] -kinesisanalytics = ["mypy-boto3-kinesisanalytics (>=1.34.0,<1.35.0)"] -kinesisanalyticsv2 = ["mypy-boto3-kinesisanalyticsv2 (>=1.34.0,<1.35.0)"] -kinesisvideo = ["mypy-boto3-kinesisvideo (>=1.34.0,<1.35.0)"] -kms = ["mypy-boto3-kms (>=1.34.0,<1.35.0)"] -lakeformation = ["mypy-boto3-lakeformation (>=1.34.0,<1.35.0)"] -lambda = ["mypy-boto3-lambda (>=1.34.0,<1.35.0)"] -launch-wizard = ["mypy-boto3-launch-wizard (>=1.34.0,<1.35.0)"] -lex-models = ["mypy-boto3-lex-models (>=1.34.0,<1.35.0)"] -lex-runtime = ["mypy-boto3-lex-runtime (>=1.34.0,<1.35.0)"] -lexv2-models = ["mypy-boto3-lexv2-models (>=1.34.0,<1.35.0)"] -lexv2-runtime = ["mypy-boto3-lexv2-runtime (>=1.34.0,<1.35.0)"] -license-manager = ["mypy-boto3-license-manager (>=1.34.0,<1.35.0)"] -license-manager-linux-subscriptions = ["mypy-boto3-license-manager-linux-subscriptions (>=1.34.0,<1.35.0)"] -license-manager-user-subscriptions = ["mypy-boto3-license-manager-user-subscriptions (>=1.34.0,<1.35.0)"] 
-lightsail = ["mypy-boto3-lightsail (>=1.34.0,<1.35.0)"] -location = ["mypy-boto3-location (>=1.34.0,<1.35.0)"] -logs = ["mypy-boto3-logs (>=1.34.0,<1.35.0)"] -lookoutequipment = ["mypy-boto3-lookoutequipment (>=1.34.0,<1.35.0)"] -lookoutmetrics = ["mypy-boto3-lookoutmetrics (>=1.34.0,<1.35.0)"] -lookoutvision = ["mypy-boto3-lookoutvision (>=1.34.0,<1.35.0)"] -m2 = ["mypy-boto3-m2 (>=1.34.0,<1.35.0)"] -machinelearning = ["mypy-boto3-machinelearning (>=1.34.0,<1.35.0)"] -macie2 = ["mypy-boto3-macie2 (>=1.34.0,<1.35.0)"] -mailmanager = ["mypy-boto3-mailmanager (>=1.34.0,<1.35.0)"] -managedblockchain = ["mypy-boto3-managedblockchain (>=1.34.0,<1.35.0)"] -managedblockchain-query = ["mypy-boto3-managedblockchain-query (>=1.34.0,<1.35.0)"] -marketplace-agreement = ["mypy-boto3-marketplace-agreement (>=1.34.0,<1.35.0)"] -marketplace-catalog = ["mypy-boto3-marketplace-catalog (>=1.34.0,<1.35.0)"] -marketplace-deployment = ["mypy-boto3-marketplace-deployment (>=1.34.0,<1.35.0)"] -marketplace-entitlement = ["mypy-boto3-marketplace-entitlement (>=1.34.0,<1.35.0)"] -marketplacecommerceanalytics = ["mypy-boto3-marketplacecommerceanalytics (>=1.34.0,<1.35.0)"] -mediaconnect = ["mypy-boto3-mediaconnect (>=1.34.0,<1.35.0)"] -mediaconvert = ["mypy-boto3-mediaconvert (>=1.34.0,<1.35.0)"] -medialive = ["mypy-boto3-medialive (>=1.34.0,<1.35.0)"] -mediapackage = ["mypy-boto3-mediapackage (>=1.34.0,<1.35.0)"] -mediapackage-vod = ["mypy-boto3-mediapackage-vod (>=1.34.0,<1.35.0)"] -mediapackagev2 = ["mypy-boto3-mediapackagev2 (>=1.34.0,<1.35.0)"] -mediastore = ["mypy-boto3-mediastore (>=1.34.0,<1.35.0)"] -mediastore-data = ["mypy-boto3-mediastore-data (>=1.34.0,<1.35.0)"] -mediatailor = ["mypy-boto3-mediatailor (>=1.34.0,<1.35.0)"] -medical-imaging = ["mypy-boto3-medical-imaging (>=1.34.0,<1.35.0)"] -memorydb = ["mypy-boto3-memorydb (>=1.34.0,<1.35.0)"] -meteringmarketplace = ["mypy-boto3-meteringmarketplace (>=1.34.0,<1.35.0)"] -mgh = ["mypy-boto3-mgh (>=1.34.0,<1.35.0)"] -mgn = ["mypy-boto3-mgn (>=1.34.0,<1.35.0)"] -migration-hub-refactor-spaces = ["mypy-boto3-migration-hub-refactor-spaces (>=1.34.0,<1.35.0)"] -migrationhub-config = ["mypy-boto3-migrationhub-config (>=1.34.0,<1.35.0)"] -migrationhuborchestrator = ["mypy-boto3-migrationhuborchestrator (>=1.34.0,<1.35.0)"] -migrationhubstrategy = ["mypy-boto3-migrationhubstrategy (>=1.34.0,<1.35.0)"] -mq = ["mypy-boto3-mq (>=1.34.0,<1.35.0)"] -mturk = ["mypy-boto3-mturk (>=1.34.0,<1.35.0)"] -mwaa = ["mypy-boto3-mwaa (>=1.34.0,<1.35.0)"] -neptune = ["mypy-boto3-neptune (>=1.34.0,<1.35.0)"] -neptune-graph = ["mypy-boto3-neptune-graph (>=1.34.0,<1.35.0)"] -neptunedata = ["mypy-boto3-neptunedata (>=1.34.0,<1.35.0)"] -network-firewall = ["mypy-boto3-network-firewall (>=1.34.0,<1.35.0)"] -networkmanager = ["mypy-boto3-networkmanager (>=1.34.0,<1.35.0)"] -networkmonitor = ["mypy-boto3-networkmonitor (>=1.34.0,<1.35.0)"] -nimble = ["mypy-boto3-nimble (>=1.34.0,<1.35.0)"] -oam = ["mypy-boto3-oam (>=1.34.0,<1.35.0)"] -omics = ["mypy-boto3-omics (>=1.34.0,<1.35.0)"] -opensearch = ["mypy-boto3-opensearch (>=1.34.0,<1.35.0)"] -opensearchserverless = ["mypy-boto3-opensearchserverless (>=1.34.0,<1.35.0)"] -opsworks = ["mypy-boto3-opsworks (>=1.34.0,<1.35.0)"] -opsworkscm = ["mypy-boto3-opsworkscm (>=1.34.0,<1.35.0)"] -organizations = ["mypy-boto3-organizations (>=1.34.0,<1.35.0)"] -osis = ["mypy-boto3-osis (>=1.34.0,<1.35.0)"] -outposts = ["mypy-boto3-outposts (>=1.34.0,<1.35.0)"] -panorama = ["mypy-boto3-panorama (>=1.34.0,<1.35.0)"] -payment-cryptography = 
["mypy-boto3-payment-cryptography (>=1.34.0,<1.35.0)"] -payment-cryptography-data = ["mypy-boto3-payment-cryptography-data (>=1.34.0,<1.35.0)"] -pca-connector-ad = ["mypy-boto3-pca-connector-ad (>=1.34.0,<1.35.0)"] -pca-connector-scep = ["mypy-boto3-pca-connector-scep (>=1.34.0,<1.35.0)"] -personalize = ["mypy-boto3-personalize (>=1.34.0,<1.35.0)"] -personalize-events = ["mypy-boto3-personalize-events (>=1.34.0,<1.35.0)"] -personalize-runtime = ["mypy-boto3-personalize-runtime (>=1.34.0,<1.35.0)"] -pi = ["mypy-boto3-pi (>=1.34.0,<1.35.0)"] -pinpoint = ["mypy-boto3-pinpoint (>=1.34.0,<1.35.0)"] -pinpoint-email = ["mypy-boto3-pinpoint-email (>=1.34.0,<1.35.0)"] -pinpoint-sms-voice = ["mypy-boto3-pinpoint-sms-voice (>=1.34.0,<1.35.0)"] -pinpoint-sms-voice-v2 = ["mypy-boto3-pinpoint-sms-voice-v2 (>=1.34.0,<1.35.0)"] -pipes = ["mypy-boto3-pipes (>=1.34.0,<1.35.0)"] -polly = ["mypy-boto3-polly (>=1.34.0,<1.35.0)"] -pricing = ["mypy-boto3-pricing (>=1.34.0,<1.35.0)"] -privatenetworks = ["mypy-boto3-privatenetworks (>=1.34.0,<1.35.0)"] -proton = ["mypy-boto3-proton (>=1.34.0,<1.35.0)"] -qapps = ["mypy-boto3-qapps (>=1.34.0,<1.35.0)"] -qbusiness = ["mypy-boto3-qbusiness (>=1.34.0,<1.35.0)"] -qconnect = ["mypy-boto3-qconnect (>=1.34.0,<1.35.0)"] -qldb = ["mypy-boto3-qldb (>=1.34.0,<1.35.0)"] -qldb-session = ["mypy-boto3-qldb-session (>=1.34.0,<1.35.0)"] -quicksight = ["mypy-boto3-quicksight (>=1.34.0,<1.35.0)"] -ram = ["mypy-boto3-ram (>=1.34.0,<1.35.0)"] -rbin = ["mypy-boto3-rbin (>=1.34.0,<1.35.0)"] -rds = ["mypy-boto3-rds (>=1.34.0,<1.35.0)"] -rds-data = ["mypy-boto3-rds-data (>=1.34.0,<1.35.0)"] -redshift = ["mypy-boto3-redshift (>=1.34.0,<1.35.0)"] -redshift-data = ["mypy-boto3-redshift-data (>=1.34.0,<1.35.0)"] -redshift-serverless = ["mypy-boto3-redshift-serverless (>=1.34.0,<1.35.0)"] -rekognition = ["mypy-boto3-rekognition (>=1.34.0,<1.35.0)"] -repostspace = ["mypy-boto3-repostspace (>=1.34.0,<1.35.0)"] -resiliencehub = ["mypy-boto3-resiliencehub (>=1.34.0,<1.35.0)"] -resource-explorer-2 = ["mypy-boto3-resource-explorer-2 (>=1.34.0,<1.35.0)"] -resource-groups = ["mypy-boto3-resource-groups (>=1.34.0,<1.35.0)"] -resourcegroupstaggingapi = ["mypy-boto3-resourcegroupstaggingapi (>=1.34.0,<1.35.0)"] -robomaker = ["mypy-boto3-robomaker (>=1.34.0,<1.35.0)"] -rolesanywhere = ["mypy-boto3-rolesanywhere (>=1.34.0,<1.35.0)"] -route53 = ["mypy-boto3-route53 (>=1.34.0,<1.35.0)"] -route53-recovery-cluster = ["mypy-boto3-route53-recovery-cluster (>=1.34.0,<1.35.0)"] -route53-recovery-control-config = ["mypy-boto3-route53-recovery-control-config (>=1.34.0,<1.35.0)"] -route53-recovery-readiness = ["mypy-boto3-route53-recovery-readiness (>=1.34.0,<1.35.0)"] -route53domains = ["mypy-boto3-route53domains (>=1.34.0,<1.35.0)"] -route53profiles = ["mypy-boto3-route53profiles (>=1.34.0,<1.35.0)"] -route53resolver = ["mypy-boto3-route53resolver (>=1.34.0,<1.35.0)"] -rum = ["mypy-boto3-rum (>=1.34.0,<1.35.0)"] -s3 = ["mypy-boto3-s3 (>=1.34.0,<1.35.0)"] -s3control = ["mypy-boto3-s3control (>=1.34.0,<1.35.0)"] -s3outposts = ["mypy-boto3-s3outposts (>=1.34.0,<1.35.0)"] -sagemaker = ["mypy-boto3-sagemaker (>=1.34.0,<1.35.0)"] -sagemaker-a2i-runtime = ["mypy-boto3-sagemaker-a2i-runtime (>=1.34.0,<1.35.0)"] -sagemaker-edge = ["mypy-boto3-sagemaker-edge (>=1.34.0,<1.35.0)"] -sagemaker-featurestore-runtime = ["mypy-boto3-sagemaker-featurestore-runtime (>=1.34.0,<1.35.0)"] -sagemaker-geospatial = ["mypy-boto3-sagemaker-geospatial (>=1.34.0,<1.35.0)"] -sagemaker-metrics = ["mypy-boto3-sagemaker-metrics (>=1.34.0,<1.35.0)"] 
-sagemaker-runtime = ["mypy-boto3-sagemaker-runtime (>=1.34.0,<1.35.0)"] -savingsplans = ["mypy-boto3-savingsplans (>=1.34.0,<1.35.0)"] -scheduler = ["mypy-boto3-scheduler (>=1.34.0,<1.35.0)"] -schemas = ["mypy-boto3-schemas (>=1.34.0,<1.35.0)"] -sdb = ["mypy-boto3-sdb (>=1.34.0,<1.35.0)"] -secretsmanager = ["mypy-boto3-secretsmanager (>=1.34.0,<1.35.0)"] -securityhub = ["mypy-boto3-securityhub (>=1.34.0,<1.35.0)"] -securitylake = ["mypy-boto3-securitylake (>=1.34.0,<1.35.0)"] -serverlessrepo = ["mypy-boto3-serverlessrepo (>=1.34.0,<1.35.0)"] -service-quotas = ["mypy-boto3-service-quotas (>=1.34.0,<1.35.0)"] -servicecatalog = ["mypy-boto3-servicecatalog (>=1.34.0,<1.35.0)"] -servicecatalog-appregistry = ["mypy-boto3-servicecatalog-appregistry (>=1.34.0,<1.35.0)"] -servicediscovery = ["mypy-boto3-servicediscovery (>=1.34.0,<1.35.0)"] -ses = ["mypy-boto3-ses (>=1.34.0,<1.35.0)"] -sesv2 = ["mypy-boto3-sesv2 (>=1.34.0,<1.35.0)"] -shield = ["mypy-boto3-shield (>=1.34.0,<1.35.0)"] -signer = ["mypy-boto3-signer (>=1.34.0,<1.35.0)"] -simspaceweaver = ["mypy-boto3-simspaceweaver (>=1.34.0,<1.35.0)"] -sms = ["mypy-boto3-sms (>=1.34.0,<1.35.0)"] -sms-voice = ["mypy-boto3-sms-voice (>=1.34.0,<1.35.0)"] -snow-device-management = ["mypy-boto3-snow-device-management (>=1.34.0,<1.35.0)"] -snowball = ["mypy-boto3-snowball (>=1.34.0,<1.35.0)"] -sns = ["mypy-boto3-sns (>=1.34.0,<1.35.0)"] -sqs = ["mypy-boto3-sqs (>=1.34.0,<1.35.0)"] -ssm = ["mypy-boto3-ssm (>=1.34.0,<1.35.0)"] -ssm-contacts = ["mypy-boto3-ssm-contacts (>=1.34.0,<1.35.0)"] -ssm-incidents = ["mypy-boto3-ssm-incidents (>=1.34.0,<1.35.0)"] -ssm-quicksetup = ["mypy-boto3-ssm-quicksetup (>=1.34.0,<1.35.0)"] -ssm-sap = ["mypy-boto3-ssm-sap (>=1.34.0,<1.35.0)"] -sso = ["mypy-boto3-sso (>=1.34.0,<1.35.0)"] -sso-admin = ["mypy-boto3-sso-admin (>=1.34.0,<1.35.0)"] -sso-oidc = ["mypy-boto3-sso-oidc (>=1.34.0,<1.35.0)"] -stepfunctions = ["mypy-boto3-stepfunctions (>=1.34.0,<1.35.0)"] -storagegateway = ["mypy-boto3-storagegateway (>=1.34.0,<1.35.0)"] -sts = ["mypy-boto3-sts (>=1.34.0,<1.35.0)"] -supplychain = ["mypy-boto3-supplychain (>=1.34.0,<1.35.0)"] -support = ["mypy-boto3-support (>=1.34.0,<1.35.0)"] -support-app = ["mypy-boto3-support-app (>=1.34.0,<1.35.0)"] -swf = ["mypy-boto3-swf (>=1.34.0,<1.35.0)"] -synthetics = ["mypy-boto3-synthetics (>=1.34.0,<1.35.0)"] -taxsettings = ["mypy-boto3-taxsettings (>=1.34.0,<1.35.0)"] -textract = ["mypy-boto3-textract (>=1.34.0,<1.35.0)"] -timestream-influxdb = ["mypy-boto3-timestream-influxdb (>=1.34.0,<1.35.0)"] -timestream-query = ["mypy-boto3-timestream-query (>=1.34.0,<1.35.0)"] -timestream-write = ["mypy-boto3-timestream-write (>=1.34.0,<1.35.0)"] -tnb = ["mypy-boto3-tnb (>=1.34.0,<1.35.0)"] -transcribe = ["mypy-boto3-transcribe (>=1.34.0,<1.35.0)"] -transfer = ["mypy-boto3-transfer (>=1.34.0,<1.35.0)"] -translate = ["mypy-boto3-translate (>=1.34.0,<1.35.0)"] -trustedadvisor = ["mypy-boto3-trustedadvisor (>=1.34.0,<1.35.0)"] -verifiedpermissions = ["mypy-boto3-verifiedpermissions (>=1.34.0,<1.35.0)"] -voice-id = ["mypy-boto3-voice-id (>=1.34.0,<1.35.0)"] -vpc-lattice = ["mypy-boto3-vpc-lattice (>=1.34.0,<1.35.0)"] -waf = ["mypy-boto3-waf (>=1.34.0,<1.35.0)"] -waf-regional = ["mypy-boto3-waf-regional (>=1.34.0,<1.35.0)"] -wafv2 = ["mypy-boto3-wafv2 (>=1.34.0,<1.35.0)"] -wellarchitected = ["mypy-boto3-wellarchitected (>=1.34.0,<1.35.0)"] -wisdom = ["mypy-boto3-wisdom (>=1.34.0,<1.35.0)"] -workdocs = ["mypy-boto3-workdocs (>=1.34.0,<1.35.0)"] -worklink = ["mypy-boto3-worklink (>=1.34.0,<1.35.0)"] -workmail = 
["mypy-boto3-workmail (>=1.34.0,<1.35.0)"] -workmailmessageflow = ["mypy-boto3-workmailmessageflow (>=1.34.0,<1.35.0)"] -workspaces = ["mypy-boto3-workspaces (>=1.34.0,<1.35.0)"] -workspaces-thin-client = ["mypy-boto3-workspaces-thin-client (>=1.34.0,<1.35.0)"] -workspaces-web = ["mypy-boto3-workspaces-web (>=1.34.0,<1.35.0)"] -xray = ["mypy-boto3-xray (>=1.34.0,<1.35.0)"] +accessanalyzer = ["mypy-boto3-accessanalyzer (>=1.42.0,<1.43.0)"] +account = ["mypy-boto3-account (>=1.42.0,<1.43.0)"] +acm = ["mypy-boto3-acm (>=1.42.0,<1.43.0)"] +acm-pca = ["mypy-boto3-acm-pca (>=1.42.0,<1.43.0)"] +aiops = ["mypy-boto3-aiops (>=1.42.0,<1.43.0)"] +all = ["mypy-boto3-accessanalyzer (>=1.42.0,<1.43.0)", "mypy-boto3-account (>=1.42.0,<1.43.0)", "mypy-boto3-acm (>=1.42.0,<1.43.0)", "mypy-boto3-acm-pca (>=1.42.0,<1.43.0)", "mypy-boto3-aiops (>=1.42.0,<1.43.0)", "mypy-boto3-amp (>=1.42.0,<1.43.0)", "mypy-boto3-amplify (>=1.42.0,<1.43.0)", "mypy-boto3-amplifybackend (>=1.42.0,<1.43.0)", "mypy-boto3-amplifyuibuilder (>=1.42.0,<1.43.0)", "mypy-boto3-apigateway (>=1.42.0,<1.43.0)", "mypy-boto3-apigatewaymanagementapi (>=1.42.0,<1.43.0)", "mypy-boto3-apigatewayv2 (>=1.42.0,<1.43.0)", "mypy-boto3-appconfig (>=1.42.0,<1.43.0)", "mypy-boto3-appconfigdata (>=1.42.0,<1.43.0)", "mypy-boto3-appfabric (>=1.42.0,<1.43.0)", "mypy-boto3-appflow (>=1.42.0,<1.43.0)", "mypy-boto3-appintegrations (>=1.42.0,<1.43.0)", "mypy-boto3-application-autoscaling (>=1.42.0,<1.43.0)", "mypy-boto3-application-insights (>=1.42.0,<1.43.0)", "mypy-boto3-application-signals (>=1.42.0,<1.43.0)", "mypy-boto3-applicationcostprofiler (>=1.42.0,<1.43.0)", "mypy-boto3-appmesh (>=1.42.0,<1.43.0)", "mypy-boto3-apprunner (>=1.42.0,<1.43.0)", "mypy-boto3-appstream (>=1.42.0,<1.43.0)", "mypy-boto3-appsync (>=1.42.0,<1.43.0)", "mypy-boto3-arc-region-switch (>=1.42.0,<1.43.0)", "mypy-boto3-arc-zonal-shift (>=1.42.0,<1.43.0)", "mypy-boto3-artifact (>=1.42.0,<1.43.0)", "mypy-boto3-athena (>=1.42.0,<1.43.0)", "mypy-boto3-auditmanager (>=1.42.0,<1.43.0)", "mypy-boto3-autoscaling (>=1.42.0,<1.43.0)", "mypy-boto3-autoscaling-plans (>=1.42.0,<1.43.0)", "mypy-boto3-b2bi (>=1.42.0,<1.43.0)", "mypy-boto3-backup (>=1.42.0,<1.43.0)", "mypy-boto3-backup-gateway (>=1.42.0,<1.43.0)", "mypy-boto3-backupsearch (>=1.42.0,<1.43.0)", "mypy-boto3-batch (>=1.42.0,<1.43.0)", "mypy-boto3-bcm-dashboards (>=1.42.0,<1.43.0)", "mypy-boto3-bcm-data-exports (>=1.42.0,<1.43.0)", "mypy-boto3-bcm-pricing-calculator (>=1.42.0,<1.43.0)", "mypy-boto3-bcm-recommended-actions (>=1.42.0,<1.43.0)", "mypy-boto3-bedrock (>=1.42.0,<1.43.0)", "mypy-boto3-bedrock-agent (>=1.42.0,<1.43.0)", "mypy-boto3-bedrock-agent-runtime (>=1.42.0,<1.43.0)", "mypy-boto3-bedrock-agentcore (>=1.42.0,<1.43.0)", "mypy-boto3-bedrock-agentcore-control (>=1.42.0,<1.43.0)", "mypy-boto3-bedrock-data-automation (>=1.42.0,<1.43.0)", "mypy-boto3-bedrock-data-automation-runtime (>=1.42.0,<1.43.0)", "mypy-boto3-bedrock-runtime (>=1.42.0,<1.43.0)", "mypy-boto3-billing (>=1.42.0,<1.43.0)", "mypy-boto3-billingconductor (>=1.42.0,<1.43.0)", "mypy-boto3-braket (>=1.42.0,<1.43.0)", "mypy-boto3-budgets (>=1.42.0,<1.43.0)", "mypy-boto3-ce (>=1.42.0,<1.43.0)", "mypy-boto3-chatbot (>=1.42.0,<1.43.0)", "mypy-boto3-chime (>=1.42.0,<1.43.0)", "mypy-boto3-chime-sdk-identity (>=1.42.0,<1.43.0)", "mypy-boto3-chime-sdk-media-pipelines (>=1.42.0,<1.43.0)", "mypy-boto3-chime-sdk-meetings (>=1.42.0,<1.43.0)", "mypy-boto3-chime-sdk-messaging (>=1.42.0,<1.43.0)", "mypy-boto3-chime-sdk-voice (>=1.42.0,<1.43.0)", "mypy-boto3-cleanrooms 
(>=1.42.0,<1.43.0)", "mypy-boto3-cleanroomsml (>=1.42.0,<1.43.0)", "mypy-boto3-cloud9 (>=1.42.0,<1.43.0)", "mypy-boto3-cloudcontrol (>=1.42.0,<1.43.0)", "mypy-boto3-clouddirectory (>=1.42.0,<1.43.0)", "mypy-boto3-cloudformation (>=1.42.0,<1.43.0)", "mypy-boto3-cloudfront (>=1.42.0,<1.43.0)", "mypy-boto3-cloudfront-keyvaluestore (>=1.42.0,<1.43.0)", "mypy-boto3-cloudhsm (>=1.42.0,<1.43.0)", "mypy-boto3-cloudhsmv2 (>=1.42.0,<1.43.0)", "mypy-boto3-cloudsearch (>=1.42.0,<1.43.0)", "mypy-boto3-cloudsearchdomain (>=1.42.0,<1.43.0)", "mypy-boto3-cloudtrail (>=1.42.0,<1.43.0)", "mypy-boto3-cloudtrail-data (>=1.42.0,<1.43.0)", "mypy-boto3-cloudwatch (>=1.42.0,<1.43.0)", "mypy-boto3-codeartifact (>=1.42.0,<1.43.0)", "mypy-boto3-codebuild (>=1.42.0,<1.43.0)", "mypy-boto3-codecatalyst (>=1.42.0,<1.43.0)", "mypy-boto3-codecommit (>=1.42.0,<1.43.0)", "mypy-boto3-codeconnections (>=1.42.0,<1.43.0)", "mypy-boto3-codedeploy (>=1.42.0,<1.43.0)", "mypy-boto3-codeguru-reviewer (>=1.42.0,<1.43.0)", "mypy-boto3-codeguru-security (>=1.42.0,<1.43.0)", "mypy-boto3-codeguruprofiler (>=1.42.0,<1.43.0)", "mypy-boto3-codepipeline (>=1.42.0,<1.43.0)", "mypy-boto3-codestar-connections (>=1.42.0,<1.43.0)", "mypy-boto3-codestar-notifications (>=1.42.0,<1.43.0)", "mypy-boto3-cognito-identity (>=1.42.0,<1.43.0)", "mypy-boto3-cognito-idp (>=1.42.0,<1.43.0)", "mypy-boto3-cognito-sync (>=1.42.0,<1.43.0)", "mypy-boto3-comprehend (>=1.42.0,<1.43.0)", "mypy-boto3-comprehendmedical (>=1.42.0,<1.43.0)", "mypy-boto3-compute-optimizer (>=1.42.0,<1.43.0)", "mypy-boto3-compute-optimizer-automation (>=1.42.0,<1.43.0)", "mypy-boto3-config (>=1.42.0,<1.43.0)", "mypy-boto3-connect (>=1.42.0,<1.43.0)", "mypy-boto3-connect-contact-lens (>=1.42.0,<1.43.0)", "mypy-boto3-connectcampaigns (>=1.42.0,<1.43.0)", "mypy-boto3-connectcampaignsv2 (>=1.42.0,<1.43.0)", "mypy-boto3-connectcases (>=1.42.0,<1.43.0)", "mypy-boto3-connectparticipant (>=1.42.0,<1.43.0)", "mypy-boto3-controlcatalog (>=1.42.0,<1.43.0)", "mypy-boto3-controltower (>=1.42.0,<1.43.0)", "mypy-boto3-cost-optimization-hub (>=1.42.0,<1.43.0)", "mypy-boto3-cur (>=1.42.0,<1.43.0)", "mypy-boto3-customer-profiles (>=1.42.0,<1.43.0)", "mypy-boto3-databrew (>=1.42.0,<1.43.0)", "mypy-boto3-dataexchange (>=1.42.0,<1.43.0)", "mypy-boto3-datapipeline (>=1.42.0,<1.43.0)", "mypy-boto3-datasync (>=1.42.0,<1.43.0)", "mypy-boto3-datazone (>=1.42.0,<1.43.0)", "mypy-boto3-dax (>=1.42.0,<1.43.0)", "mypy-boto3-deadline (>=1.42.0,<1.43.0)", "mypy-boto3-detective (>=1.42.0,<1.43.0)", "mypy-boto3-devicefarm (>=1.42.0,<1.43.0)", "mypy-boto3-devops-guru (>=1.42.0,<1.43.0)", "mypy-boto3-directconnect (>=1.42.0,<1.43.0)", "mypy-boto3-discovery (>=1.42.0,<1.43.0)", "mypy-boto3-dlm (>=1.42.0,<1.43.0)", "mypy-boto3-dms (>=1.42.0,<1.43.0)", "mypy-boto3-docdb (>=1.42.0,<1.43.0)", "mypy-boto3-docdb-elastic (>=1.42.0,<1.43.0)", "mypy-boto3-drs (>=1.42.0,<1.43.0)", "mypy-boto3-ds (>=1.42.0,<1.43.0)", "mypy-boto3-ds-data (>=1.42.0,<1.43.0)", "mypy-boto3-dsql (>=1.42.0,<1.43.0)", "mypy-boto3-dynamodb (>=1.42.0,<1.43.0)", "mypy-boto3-dynamodbstreams (>=1.42.0,<1.43.0)", "mypy-boto3-ebs (>=1.42.0,<1.43.0)", "mypy-boto3-ec2 (>=1.42.0,<1.43.0)", "mypy-boto3-ec2-instance-connect (>=1.42.0,<1.43.0)", "mypy-boto3-ecr (>=1.42.0,<1.43.0)", "mypy-boto3-ecr-public (>=1.42.0,<1.43.0)", "mypy-boto3-ecs (>=1.42.0,<1.43.0)", "mypy-boto3-efs (>=1.42.0,<1.43.0)", "mypy-boto3-eks (>=1.42.0,<1.43.0)", "mypy-boto3-eks-auth (>=1.42.0,<1.43.0)", "mypy-boto3-elasticache (>=1.42.0,<1.43.0)", "mypy-boto3-elasticbeanstalk (>=1.42.0,<1.43.0)", 
"mypy-boto3-elb (>=1.42.0,<1.43.0)", "mypy-boto3-elbv2 (>=1.42.0,<1.43.0)", "mypy-boto3-emr (>=1.42.0,<1.43.0)", "mypy-boto3-emr-containers (>=1.42.0,<1.43.0)", "mypy-boto3-emr-serverless (>=1.42.0,<1.43.0)", "mypy-boto3-entityresolution (>=1.42.0,<1.43.0)", "mypy-boto3-es (>=1.42.0,<1.43.0)", "mypy-boto3-events (>=1.42.0,<1.43.0)", "mypy-boto3-evidently (>=1.42.0,<1.43.0)", "mypy-boto3-evs (>=1.42.0,<1.43.0)", "mypy-boto3-finspace (>=1.42.0,<1.43.0)", "mypy-boto3-finspace-data (>=1.42.0,<1.43.0)", "mypy-boto3-firehose (>=1.42.0,<1.43.0)", "mypy-boto3-fis (>=1.42.0,<1.43.0)", "mypy-boto3-fms (>=1.42.0,<1.43.0)", "mypy-boto3-forecast (>=1.42.0,<1.43.0)", "mypy-boto3-forecastquery (>=1.42.0,<1.43.0)", "mypy-boto3-frauddetector (>=1.42.0,<1.43.0)", "mypy-boto3-freetier (>=1.42.0,<1.43.0)", "mypy-boto3-fsx (>=1.42.0,<1.43.0)", "mypy-boto3-gamelift (>=1.42.0,<1.43.0)", "mypy-boto3-gameliftstreams (>=1.42.0,<1.43.0)", "mypy-boto3-geo-maps (>=1.42.0,<1.43.0)", "mypy-boto3-geo-places (>=1.42.0,<1.43.0)", "mypy-boto3-geo-routes (>=1.42.0,<1.43.0)", "mypy-boto3-glacier (>=1.42.0,<1.43.0)", "mypy-boto3-globalaccelerator (>=1.42.0,<1.43.0)", "mypy-boto3-glue (>=1.42.0,<1.43.0)", "mypy-boto3-grafana (>=1.42.0,<1.43.0)", "mypy-boto3-greengrass (>=1.42.0,<1.43.0)", "mypy-boto3-greengrassv2 (>=1.42.0,<1.43.0)", "mypy-boto3-groundstation (>=1.42.0,<1.43.0)", "mypy-boto3-guardduty (>=1.42.0,<1.43.0)", "mypy-boto3-health (>=1.42.0,<1.43.0)", "mypy-boto3-healthlake (>=1.42.0,<1.43.0)", "mypy-boto3-iam (>=1.42.0,<1.43.0)", "mypy-boto3-identitystore (>=1.42.0,<1.43.0)", "mypy-boto3-imagebuilder (>=1.42.0,<1.43.0)", "mypy-boto3-importexport (>=1.42.0,<1.43.0)", "mypy-boto3-inspector (>=1.42.0,<1.43.0)", "mypy-boto3-inspector-scan (>=1.42.0,<1.43.0)", "mypy-boto3-inspector2 (>=1.42.0,<1.43.0)", "mypy-boto3-internetmonitor (>=1.42.0,<1.43.0)", "mypy-boto3-invoicing (>=1.42.0,<1.43.0)", "mypy-boto3-iot (>=1.42.0,<1.43.0)", "mypy-boto3-iot-data (>=1.42.0,<1.43.0)", "mypy-boto3-iot-jobs-data (>=1.42.0,<1.43.0)", "mypy-boto3-iot-managed-integrations (>=1.42.0,<1.43.0)", "mypy-boto3-iotanalytics (>=1.42.0,<1.43.0)", "mypy-boto3-iotdeviceadvisor (>=1.42.0,<1.43.0)", "mypy-boto3-iotevents (>=1.42.0,<1.43.0)", "mypy-boto3-iotevents-data (>=1.42.0,<1.43.0)", "mypy-boto3-iotfleetwise (>=1.42.0,<1.43.0)", "mypy-boto3-iotsecuretunneling (>=1.42.0,<1.43.0)", "mypy-boto3-iotsitewise (>=1.42.0,<1.43.0)", "mypy-boto3-iotthingsgraph (>=1.42.0,<1.43.0)", "mypy-boto3-iottwinmaker (>=1.42.0,<1.43.0)", "mypy-boto3-iotwireless (>=1.42.0,<1.43.0)", "mypy-boto3-ivs (>=1.42.0,<1.43.0)", "mypy-boto3-ivs-realtime (>=1.42.0,<1.43.0)", "mypy-boto3-ivschat (>=1.42.0,<1.43.0)", "mypy-boto3-kafka (>=1.42.0,<1.43.0)", "mypy-boto3-kafkaconnect (>=1.42.0,<1.43.0)", "mypy-boto3-kendra (>=1.42.0,<1.43.0)", "mypy-boto3-kendra-ranking (>=1.42.0,<1.43.0)", "mypy-boto3-keyspaces (>=1.42.0,<1.43.0)", "mypy-boto3-keyspacesstreams (>=1.42.0,<1.43.0)", "mypy-boto3-kinesis (>=1.42.0,<1.43.0)", "mypy-boto3-kinesis-video-archived-media (>=1.42.0,<1.43.0)", "mypy-boto3-kinesis-video-media (>=1.42.0,<1.43.0)", "mypy-boto3-kinesis-video-signaling (>=1.42.0,<1.43.0)", "mypy-boto3-kinesis-video-webrtc-storage (>=1.42.0,<1.43.0)", "mypy-boto3-kinesisanalytics (>=1.42.0,<1.43.0)", "mypy-boto3-kinesisanalyticsv2 (>=1.42.0,<1.43.0)", "mypy-boto3-kinesisvideo (>=1.42.0,<1.43.0)", "mypy-boto3-kms (>=1.42.0,<1.43.0)", "mypy-boto3-lakeformation (>=1.42.0,<1.43.0)", "mypy-boto3-lambda (>=1.42.0,<1.43.0)", "mypy-boto3-launch-wizard (>=1.42.0,<1.43.0)", "mypy-boto3-lex-models 
(>=1.42.0,<1.43.0)", "mypy-boto3-lex-runtime (>=1.42.0,<1.43.0)", "mypy-boto3-lexv2-models (>=1.42.0,<1.43.0)", "mypy-boto3-lexv2-runtime (>=1.42.0,<1.43.0)", "mypy-boto3-license-manager (>=1.42.0,<1.43.0)", "mypy-boto3-license-manager-linux-subscriptions (>=1.42.0,<1.43.0)", "mypy-boto3-license-manager-user-subscriptions (>=1.42.0,<1.43.0)", "mypy-boto3-lightsail (>=1.42.0,<1.43.0)", "mypy-boto3-location (>=1.42.0,<1.43.0)", "mypy-boto3-logs (>=1.42.0,<1.43.0)", "mypy-boto3-lookoutequipment (>=1.42.0,<1.43.0)", "mypy-boto3-m2 (>=1.42.0,<1.43.0)", "mypy-boto3-machinelearning (>=1.42.0,<1.43.0)", "mypy-boto3-macie2 (>=1.42.0,<1.43.0)", "mypy-boto3-mailmanager (>=1.42.0,<1.43.0)", "mypy-boto3-managedblockchain (>=1.42.0,<1.43.0)", "mypy-boto3-managedblockchain-query (>=1.42.0,<1.43.0)", "mypy-boto3-marketplace-agreement (>=1.42.0,<1.43.0)", "mypy-boto3-marketplace-catalog (>=1.42.0,<1.43.0)", "mypy-boto3-marketplace-deployment (>=1.42.0,<1.43.0)", "mypy-boto3-marketplace-entitlement (>=1.42.0,<1.43.0)", "mypy-boto3-marketplace-reporting (>=1.42.0,<1.43.0)", "mypy-boto3-marketplacecommerceanalytics (>=1.42.0,<1.43.0)", "mypy-boto3-mediaconnect (>=1.42.0,<1.43.0)", "mypy-boto3-mediaconvert (>=1.42.0,<1.43.0)", "mypy-boto3-medialive (>=1.42.0,<1.43.0)", "mypy-boto3-mediapackage (>=1.42.0,<1.43.0)", "mypy-boto3-mediapackage-vod (>=1.42.0,<1.43.0)", "mypy-boto3-mediapackagev2 (>=1.42.0,<1.43.0)", "mypy-boto3-mediastore (>=1.42.0,<1.43.0)", "mypy-boto3-mediastore-data (>=1.42.0,<1.43.0)", "mypy-boto3-mediatailor (>=1.42.0,<1.43.0)", "mypy-boto3-medical-imaging (>=1.42.0,<1.43.0)", "mypy-boto3-memorydb (>=1.42.0,<1.43.0)", "mypy-boto3-meteringmarketplace (>=1.42.0,<1.43.0)", "mypy-boto3-mgh (>=1.42.0,<1.43.0)", "mypy-boto3-mgn (>=1.42.0,<1.43.0)", "mypy-boto3-migration-hub-refactor-spaces (>=1.42.0,<1.43.0)", "mypy-boto3-migrationhub-config (>=1.42.0,<1.43.0)", "mypy-boto3-migrationhuborchestrator (>=1.42.0,<1.43.0)", "mypy-boto3-migrationhubstrategy (>=1.42.0,<1.43.0)", "mypy-boto3-mpa (>=1.42.0,<1.43.0)", "mypy-boto3-mq (>=1.42.0,<1.43.0)", "mypy-boto3-mturk (>=1.42.0,<1.43.0)", "mypy-boto3-mwaa (>=1.42.0,<1.43.0)", "mypy-boto3-mwaa-serverless (>=1.42.0,<1.43.0)", "mypy-boto3-neptune (>=1.42.0,<1.43.0)", "mypy-boto3-neptune-graph (>=1.42.0,<1.43.0)", "mypy-boto3-neptunedata (>=1.42.0,<1.43.0)", "mypy-boto3-network-firewall (>=1.42.0,<1.43.0)", "mypy-boto3-networkflowmonitor (>=1.42.0,<1.43.0)", "mypy-boto3-networkmanager (>=1.42.0,<1.43.0)", "mypy-boto3-networkmonitor (>=1.42.0,<1.43.0)", "mypy-boto3-notifications (>=1.42.0,<1.43.0)", "mypy-boto3-notificationscontacts (>=1.42.0,<1.43.0)", "mypy-boto3-nova-act (>=1.42.0,<1.43.0)", "mypy-boto3-oam (>=1.42.0,<1.43.0)", "mypy-boto3-observabilityadmin (>=1.42.0,<1.43.0)", "mypy-boto3-odb (>=1.42.0,<1.43.0)", "mypy-boto3-omics (>=1.42.0,<1.43.0)", "mypy-boto3-opensearch (>=1.42.0,<1.43.0)", "mypy-boto3-opensearchserverless (>=1.42.0,<1.43.0)", "mypy-boto3-organizations (>=1.42.0,<1.43.0)", "mypy-boto3-osis (>=1.42.0,<1.43.0)", "mypy-boto3-outposts (>=1.42.0,<1.43.0)", "mypy-boto3-panorama (>=1.42.0,<1.43.0)", "mypy-boto3-partnercentral-account (>=1.42.0,<1.43.0)", "mypy-boto3-partnercentral-benefits (>=1.42.0,<1.43.0)", "mypy-boto3-partnercentral-channel (>=1.42.0,<1.43.0)", "mypy-boto3-partnercentral-selling (>=1.42.0,<1.43.0)", "mypy-boto3-payment-cryptography (>=1.42.0,<1.43.0)", "mypy-boto3-payment-cryptography-data (>=1.42.0,<1.43.0)", "mypy-boto3-pca-connector-ad (>=1.42.0,<1.43.0)", "mypy-boto3-pca-connector-scep (>=1.42.0,<1.43.0)", 
"mypy-boto3-pcs (>=1.42.0,<1.43.0)", "mypy-boto3-personalize (>=1.42.0,<1.43.0)", "mypy-boto3-personalize-events (>=1.42.0,<1.43.0)", "mypy-boto3-personalize-runtime (>=1.42.0,<1.43.0)", "mypy-boto3-pi (>=1.42.0,<1.43.0)", "mypy-boto3-pinpoint (>=1.42.0,<1.43.0)", "mypy-boto3-pinpoint-email (>=1.42.0,<1.43.0)", "mypy-boto3-pinpoint-sms-voice (>=1.42.0,<1.43.0)", "mypy-boto3-pinpoint-sms-voice-v2 (>=1.42.0,<1.43.0)", "mypy-boto3-pipes (>=1.42.0,<1.43.0)", "mypy-boto3-polly (>=1.42.0,<1.43.0)", "mypy-boto3-pricing (>=1.42.0,<1.43.0)", "mypy-boto3-proton (>=1.42.0,<1.43.0)", "mypy-boto3-qapps (>=1.42.0,<1.43.0)", "mypy-boto3-qbusiness (>=1.42.0,<1.43.0)", "mypy-boto3-qconnect (>=1.42.0,<1.43.0)", "mypy-boto3-quicksight (>=1.42.0,<1.43.0)", "mypy-boto3-ram (>=1.42.0,<1.43.0)", "mypy-boto3-rbin (>=1.42.0,<1.43.0)", "mypy-boto3-rds (>=1.42.0,<1.43.0)", "mypy-boto3-rds-data (>=1.42.0,<1.43.0)", "mypy-boto3-redshift (>=1.42.0,<1.43.0)", "mypy-boto3-redshift-data (>=1.42.0,<1.43.0)", "mypy-boto3-redshift-serverless (>=1.42.0,<1.43.0)", "mypy-boto3-rekognition (>=1.42.0,<1.43.0)", "mypy-boto3-repostspace (>=1.42.0,<1.43.0)", "mypy-boto3-resiliencehub (>=1.42.0,<1.43.0)", "mypy-boto3-resource-explorer-2 (>=1.42.0,<1.43.0)", "mypy-boto3-resource-groups (>=1.42.0,<1.43.0)", "mypy-boto3-resourcegroupstaggingapi (>=1.42.0,<1.43.0)", "mypy-boto3-rolesanywhere (>=1.42.0,<1.43.0)", "mypy-boto3-route53 (>=1.42.0,<1.43.0)", "mypy-boto3-route53-recovery-cluster (>=1.42.0,<1.43.0)", "mypy-boto3-route53-recovery-control-config (>=1.42.0,<1.43.0)", "mypy-boto3-route53-recovery-readiness (>=1.42.0,<1.43.0)", "mypy-boto3-route53domains (>=1.42.0,<1.43.0)", "mypy-boto3-route53globalresolver (>=1.42.0,<1.43.0)", "mypy-boto3-route53profiles (>=1.42.0,<1.43.0)", "mypy-boto3-route53resolver (>=1.42.0,<1.43.0)", "mypy-boto3-rtbfabric (>=1.42.0,<1.43.0)", "mypy-boto3-rum (>=1.42.0,<1.43.0)", "mypy-boto3-s3 (>=1.42.0,<1.43.0)", "mypy-boto3-s3control (>=1.42.0,<1.43.0)", "mypy-boto3-s3outposts (>=1.42.0,<1.43.0)", "mypy-boto3-s3tables (>=1.42.0,<1.43.0)", "mypy-boto3-s3vectors (>=1.42.0,<1.43.0)", "mypy-boto3-sagemaker (>=1.42.0,<1.43.0)", "mypy-boto3-sagemaker-a2i-runtime (>=1.42.0,<1.43.0)", "mypy-boto3-sagemaker-edge (>=1.42.0,<1.43.0)", "mypy-boto3-sagemaker-featurestore-runtime (>=1.42.0,<1.43.0)", "mypy-boto3-sagemaker-geospatial (>=1.42.0,<1.43.0)", "mypy-boto3-sagemaker-metrics (>=1.42.0,<1.43.0)", "mypy-boto3-sagemaker-runtime (>=1.42.0,<1.43.0)", "mypy-boto3-savingsplans (>=1.42.0,<1.43.0)", "mypy-boto3-scheduler (>=1.42.0,<1.43.0)", "mypy-boto3-schemas (>=1.42.0,<1.43.0)", "mypy-boto3-sdb (>=1.42.0,<1.43.0)", "mypy-boto3-secretsmanager (>=1.42.0,<1.43.0)", "mypy-boto3-security-ir (>=1.42.0,<1.43.0)", "mypy-boto3-securityhub (>=1.42.0,<1.43.0)", "mypy-boto3-securitylake (>=1.42.0,<1.43.0)", "mypy-boto3-serverlessrepo (>=1.42.0,<1.43.0)", "mypy-boto3-service-quotas (>=1.42.0,<1.43.0)", "mypy-boto3-servicecatalog (>=1.42.0,<1.43.0)", "mypy-boto3-servicecatalog-appregistry (>=1.42.0,<1.43.0)", "mypy-boto3-servicediscovery (>=1.42.0,<1.43.0)", "mypy-boto3-ses (>=1.42.0,<1.43.0)", "mypy-boto3-sesv2 (>=1.42.0,<1.43.0)", "mypy-boto3-shield (>=1.42.0,<1.43.0)", "mypy-boto3-signer (>=1.42.0,<1.43.0)", "mypy-boto3-signin (>=1.42.0,<1.43.0)", "mypy-boto3-simspaceweaver (>=1.42.0,<1.43.0)", "mypy-boto3-snow-device-management (>=1.42.0,<1.43.0)", "mypy-boto3-snowball (>=1.42.0,<1.43.0)", "mypy-boto3-sns (>=1.42.0,<1.43.0)", "mypy-boto3-socialmessaging (>=1.42.0,<1.43.0)", "mypy-boto3-sqs (>=1.42.0,<1.43.0)", "mypy-boto3-ssm 
(>=1.42.0,<1.43.0)", "mypy-boto3-ssm-contacts (>=1.42.0,<1.43.0)", "mypy-boto3-ssm-guiconnect (>=1.42.0,<1.43.0)", "mypy-boto3-ssm-incidents (>=1.42.0,<1.43.0)", "mypy-boto3-ssm-quicksetup (>=1.42.0,<1.43.0)", "mypy-boto3-ssm-sap (>=1.42.0,<1.43.0)", "mypy-boto3-sso (>=1.42.0,<1.43.0)", "mypy-boto3-sso-admin (>=1.42.0,<1.43.0)", "mypy-boto3-sso-oidc (>=1.42.0,<1.43.0)", "mypy-boto3-stepfunctions (>=1.42.0,<1.43.0)", "mypy-boto3-storagegateway (>=1.42.0,<1.43.0)", "mypy-boto3-sts (>=1.42.0,<1.43.0)", "mypy-boto3-supplychain (>=1.42.0,<1.43.0)", "mypy-boto3-support (>=1.42.0,<1.43.0)", "mypy-boto3-support-app (>=1.42.0,<1.43.0)", "mypy-boto3-swf (>=1.42.0,<1.43.0)", "mypy-boto3-synthetics (>=1.42.0,<1.43.0)", "mypy-boto3-taxsettings (>=1.42.0,<1.43.0)", "mypy-boto3-textract (>=1.42.0,<1.43.0)", "mypy-boto3-timestream-influxdb (>=1.42.0,<1.43.0)", "mypy-boto3-timestream-query (>=1.42.0,<1.43.0)", "mypy-boto3-timestream-write (>=1.42.0,<1.43.0)", "mypy-boto3-tnb (>=1.42.0,<1.43.0)", "mypy-boto3-transcribe (>=1.42.0,<1.43.0)", "mypy-boto3-transfer (>=1.42.0,<1.43.0)", "mypy-boto3-translate (>=1.42.0,<1.43.0)", "mypy-boto3-trustedadvisor (>=1.42.0,<1.43.0)", "mypy-boto3-verifiedpermissions (>=1.42.0,<1.43.0)", "mypy-boto3-voice-id (>=1.42.0,<1.43.0)", "mypy-boto3-vpc-lattice (>=1.42.0,<1.43.0)", "mypy-boto3-waf (>=1.42.0,<1.43.0)", "mypy-boto3-waf-regional (>=1.42.0,<1.43.0)", "mypy-boto3-wafv2 (>=1.42.0,<1.43.0)", "mypy-boto3-wellarchitected (>=1.42.0,<1.43.0)", "mypy-boto3-wickr (>=1.42.0,<1.43.0)", "mypy-boto3-wisdom (>=1.42.0,<1.43.0)", "mypy-boto3-workdocs (>=1.42.0,<1.43.0)", "mypy-boto3-workmail (>=1.42.0,<1.43.0)", "mypy-boto3-workmailmessageflow (>=1.42.0,<1.43.0)", "mypy-boto3-workspaces (>=1.42.0,<1.43.0)", "mypy-boto3-workspaces-instances (>=1.42.0,<1.43.0)", "mypy-boto3-workspaces-thin-client (>=1.42.0,<1.43.0)", "mypy-boto3-workspaces-web (>=1.42.0,<1.43.0)", "mypy-boto3-xray (>=1.42.0,<1.43.0)"] +amp = ["mypy-boto3-amp (>=1.42.0,<1.43.0)"] +amplify = ["mypy-boto3-amplify (>=1.42.0,<1.43.0)"] +amplifybackend = ["mypy-boto3-amplifybackend (>=1.42.0,<1.43.0)"] +amplifyuibuilder = ["mypy-boto3-amplifyuibuilder (>=1.42.0,<1.43.0)"] +apigateway = ["mypy-boto3-apigateway (>=1.42.0,<1.43.0)"] +apigatewaymanagementapi = ["mypy-boto3-apigatewaymanagementapi (>=1.42.0,<1.43.0)"] +apigatewayv2 = ["mypy-boto3-apigatewayv2 (>=1.42.0,<1.43.0)"] +appconfig = ["mypy-boto3-appconfig (>=1.42.0,<1.43.0)"] +appconfigdata = ["mypy-boto3-appconfigdata (>=1.42.0,<1.43.0)"] +appfabric = ["mypy-boto3-appfabric (>=1.42.0,<1.43.0)"] +appflow = ["mypy-boto3-appflow (>=1.42.0,<1.43.0)"] +appintegrations = ["mypy-boto3-appintegrations (>=1.42.0,<1.43.0)"] +application-autoscaling = ["mypy-boto3-application-autoscaling (>=1.42.0,<1.43.0)"] +application-insights = ["mypy-boto3-application-insights (>=1.42.0,<1.43.0)"] +application-signals = ["mypy-boto3-application-signals (>=1.42.0,<1.43.0)"] +applicationcostprofiler = ["mypy-boto3-applicationcostprofiler (>=1.42.0,<1.43.0)"] +appmesh = ["mypy-boto3-appmesh (>=1.42.0,<1.43.0)"] +apprunner = ["mypy-boto3-apprunner (>=1.42.0,<1.43.0)"] +appstream = ["mypy-boto3-appstream (>=1.42.0,<1.43.0)"] +appsync = ["mypy-boto3-appsync (>=1.42.0,<1.43.0)"] +arc-region-switch = ["mypy-boto3-arc-region-switch (>=1.42.0,<1.43.0)"] +arc-zonal-shift = ["mypy-boto3-arc-zonal-shift (>=1.42.0,<1.43.0)"] +artifact = ["mypy-boto3-artifact (>=1.42.0,<1.43.0)"] +athena = ["mypy-boto3-athena (>=1.42.0,<1.43.0)"] +auditmanager = ["mypy-boto3-auditmanager (>=1.42.0,<1.43.0)"] +autoscaling = 
["mypy-boto3-autoscaling (>=1.42.0,<1.43.0)"] +autoscaling-plans = ["mypy-boto3-autoscaling-plans (>=1.42.0,<1.43.0)"] +b2bi = ["mypy-boto3-b2bi (>=1.42.0,<1.43.0)"] +backup = ["mypy-boto3-backup (>=1.42.0,<1.43.0)"] +backup-gateway = ["mypy-boto3-backup-gateway (>=1.42.0,<1.43.0)"] +backupsearch = ["mypy-boto3-backupsearch (>=1.42.0,<1.43.0)"] +batch = ["mypy-boto3-batch (>=1.42.0,<1.43.0)"] +bcm-dashboards = ["mypy-boto3-bcm-dashboards (>=1.42.0,<1.43.0)"] +bcm-data-exports = ["mypy-boto3-bcm-data-exports (>=1.42.0,<1.43.0)"] +bcm-pricing-calculator = ["mypy-boto3-bcm-pricing-calculator (>=1.42.0,<1.43.0)"] +bcm-recommended-actions = ["mypy-boto3-bcm-recommended-actions (>=1.42.0,<1.43.0)"] +bedrock = ["mypy-boto3-bedrock (>=1.42.0,<1.43.0)"] +bedrock-agent = ["mypy-boto3-bedrock-agent (>=1.42.0,<1.43.0)"] +bedrock-agent-runtime = ["mypy-boto3-bedrock-agent-runtime (>=1.42.0,<1.43.0)"] +bedrock-agentcore = ["mypy-boto3-bedrock-agentcore (>=1.42.0,<1.43.0)"] +bedrock-agentcore-control = ["mypy-boto3-bedrock-agentcore-control (>=1.42.0,<1.43.0)"] +bedrock-data-automation = ["mypy-boto3-bedrock-data-automation (>=1.42.0,<1.43.0)"] +bedrock-data-automation-runtime = ["mypy-boto3-bedrock-data-automation-runtime (>=1.42.0,<1.43.0)"] +bedrock-runtime = ["mypy-boto3-bedrock-runtime (>=1.42.0,<1.43.0)"] +billing = ["mypy-boto3-billing (>=1.42.0,<1.43.0)"] +billingconductor = ["mypy-boto3-billingconductor (>=1.42.0,<1.43.0)"] +boto3 = ["boto3 (==1.42.33)"] +braket = ["mypy-boto3-braket (>=1.42.0,<1.43.0)"] +budgets = ["mypy-boto3-budgets (>=1.42.0,<1.43.0)"] +ce = ["mypy-boto3-ce (>=1.42.0,<1.43.0)"] +chatbot = ["mypy-boto3-chatbot (>=1.42.0,<1.43.0)"] +chime = ["mypy-boto3-chime (>=1.42.0,<1.43.0)"] +chime-sdk-identity = ["mypy-boto3-chime-sdk-identity (>=1.42.0,<1.43.0)"] +chime-sdk-media-pipelines = ["mypy-boto3-chime-sdk-media-pipelines (>=1.42.0,<1.43.0)"] +chime-sdk-meetings = ["mypy-boto3-chime-sdk-meetings (>=1.42.0,<1.43.0)"] +chime-sdk-messaging = ["mypy-boto3-chime-sdk-messaging (>=1.42.0,<1.43.0)"] +chime-sdk-voice = ["mypy-boto3-chime-sdk-voice (>=1.42.0,<1.43.0)"] +cleanrooms = ["mypy-boto3-cleanrooms (>=1.42.0,<1.43.0)"] +cleanroomsml = ["mypy-boto3-cleanroomsml (>=1.42.0,<1.43.0)"] +cloud9 = ["mypy-boto3-cloud9 (>=1.42.0,<1.43.0)"] +cloudcontrol = ["mypy-boto3-cloudcontrol (>=1.42.0,<1.43.0)"] +clouddirectory = ["mypy-boto3-clouddirectory (>=1.42.0,<1.43.0)"] +cloudformation = ["mypy-boto3-cloudformation (>=1.42.0,<1.43.0)"] +cloudfront = ["mypy-boto3-cloudfront (>=1.42.0,<1.43.0)"] +cloudfront-keyvaluestore = ["mypy-boto3-cloudfront-keyvaluestore (>=1.42.0,<1.43.0)"] +cloudhsm = ["mypy-boto3-cloudhsm (>=1.42.0,<1.43.0)"] +cloudhsmv2 = ["mypy-boto3-cloudhsmv2 (>=1.42.0,<1.43.0)"] +cloudsearch = ["mypy-boto3-cloudsearch (>=1.42.0,<1.43.0)"] +cloudsearchdomain = ["mypy-boto3-cloudsearchdomain (>=1.42.0,<1.43.0)"] +cloudtrail = ["mypy-boto3-cloudtrail (>=1.42.0,<1.43.0)"] +cloudtrail-data = ["mypy-boto3-cloudtrail-data (>=1.42.0,<1.43.0)"] +cloudwatch = ["mypy-boto3-cloudwatch (>=1.42.0,<1.43.0)"] +codeartifact = ["mypy-boto3-codeartifact (>=1.42.0,<1.43.0)"] +codebuild = ["mypy-boto3-codebuild (>=1.42.0,<1.43.0)"] +codecatalyst = ["mypy-boto3-codecatalyst (>=1.42.0,<1.43.0)"] +codecommit = ["mypy-boto3-codecommit (>=1.42.0,<1.43.0)"] +codeconnections = ["mypy-boto3-codeconnections (>=1.42.0,<1.43.0)"] +codedeploy = ["mypy-boto3-codedeploy (>=1.42.0,<1.43.0)"] +codeguru-reviewer = ["mypy-boto3-codeguru-reviewer (>=1.42.0,<1.43.0)"] +codeguru-security = ["mypy-boto3-codeguru-security 
(>=1.42.0,<1.43.0)"] +codeguruprofiler = ["mypy-boto3-codeguruprofiler (>=1.42.0,<1.43.0)"] +codepipeline = ["mypy-boto3-codepipeline (>=1.42.0,<1.43.0)"] +codestar-connections = ["mypy-boto3-codestar-connections (>=1.42.0,<1.43.0)"] +codestar-notifications = ["mypy-boto3-codestar-notifications (>=1.42.0,<1.43.0)"] +cognito-identity = ["mypy-boto3-cognito-identity (>=1.42.0,<1.43.0)"] +cognito-idp = ["mypy-boto3-cognito-idp (>=1.42.0,<1.43.0)"] +cognito-sync = ["mypy-boto3-cognito-sync (>=1.42.0,<1.43.0)"] +comprehend = ["mypy-boto3-comprehend (>=1.42.0,<1.43.0)"] +comprehendmedical = ["mypy-boto3-comprehendmedical (>=1.42.0,<1.43.0)"] +compute-optimizer = ["mypy-boto3-compute-optimizer (>=1.42.0,<1.43.0)"] +compute-optimizer-automation = ["mypy-boto3-compute-optimizer-automation (>=1.42.0,<1.43.0)"] +config = ["mypy-boto3-config (>=1.42.0,<1.43.0)"] +connect = ["mypy-boto3-connect (>=1.42.0,<1.43.0)"] +connect-contact-lens = ["mypy-boto3-connect-contact-lens (>=1.42.0,<1.43.0)"] +connectcampaigns = ["mypy-boto3-connectcampaigns (>=1.42.0,<1.43.0)"] +connectcampaignsv2 = ["mypy-boto3-connectcampaignsv2 (>=1.42.0,<1.43.0)"] +connectcases = ["mypy-boto3-connectcases (>=1.42.0,<1.43.0)"] +connectparticipant = ["mypy-boto3-connectparticipant (>=1.42.0,<1.43.0)"] +controlcatalog = ["mypy-boto3-controlcatalog (>=1.42.0,<1.43.0)"] +controltower = ["mypy-boto3-controltower (>=1.42.0,<1.43.0)"] +cost-optimization-hub = ["mypy-boto3-cost-optimization-hub (>=1.42.0,<1.43.0)"] +cur = ["mypy-boto3-cur (>=1.42.0,<1.43.0)"] +customer-profiles = ["mypy-boto3-customer-profiles (>=1.42.0,<1.43.0)"] +databrew = ["mypy-boto3-databrew (>=1.42.0,<1.43.0)"] +dataexchange = ["mypy-boto3-dataexchange (>=1.42.0,<1.43.0)"] +datapipeline = ["mypy-boto3-datapipeline (>=1.42.0,<1.43.0)"] +datasync = ["mypy-boto3-datasync (>=1.42.0,<1.43.0)"] +datazone = ["mypy-boto3-datazone (>=1.42.0,<1.43.0)"] +dax = ["mypy-boto3-dax (>=1.42.0,<1.43.0)"] +deadline = ["mypy-boto3-deadline (>=1.42.0,<1.43.0)"] +detective = ["mypy-boto3-detective (>=1.42.0,<1.43.0)"] +devicefarm = ["mypy-boto3-devicefarm (>=1.42.0,<1.43.0)"] +devops-guru = ["mypy-boto3-devops-guru (>=1.42.0,<1.43.0)"] +directconnect = ["mypy-boto3-directconnect (>=1.42.0,<1.43.0)"] +discovery = ["mypy-boto3-discovery (>=1.42.0,<1.43.0)"] +dlm = ["mypy-boto3-dlm (>=1.42.0,<1.43.0)"] +dms = ["mypy-boto3-dms (>=1.42.0,<1.43.0)"] +docdb = ["mypy-boto3-docdb (>=1.42.0,<1.43.0)"] +docdb-elastic = ["mypy-boto3-docdb-elastic (>=1.42.0,<1.43.0)"] +drs = ["mypy-boto3-drs (>=1.42.0,<1.43.0)"] +ds = ["mypy-boto3-ds (>=1.42.0,<1.43.0)"] +ds-data = ["mypy-boto3-ds-data (>=1.42.0,<1.43.0)"] +dsql = ["mypy-boto3-dsql (>=1.42.0,<1.43.0)"] +dynamodb = ["mypy-boto3-dynamodb (>=1.42.0,<1.43.0)"] +dynamodbstreams = ["mypy-boto3-dynamodbstreams (>=1.42.0,<1.43.0)"] +ebs = ["mypy-boto3-ebs (>=1.42.0,<1.43.0)"] +ec2 = ["mypy-boto3-ec2 (>=1.42.0,<1.43.0)"] +ec2-instance-connect = ["mypy-boto3-ec2-instance-connect (>=1.42.0,<1.43.0)"] +ecr = ["mypy-boto3-ecr (>=1.42.0,<1.43.0)"] +ecr-public = ["mypy-boto3-ecr-public (>=1.42.0,<1.43.0)"] +ecs = ["mypy-boto3-ecs (>=1.42.0,<1.43.0)"] +efs = ["mypy-boto3-efs (>=1.42.0,<1.43.0)"] +eks = ["mypy-boto3-eks (>=1.42.0,<1.43.0)"] +eks-auth = ["mypy-boto3-eks-auth (>=1.42.0,<1.43.0)"] +elasticache = ["mypy-boto3-elasticache (>=1.42.0,<1.43.0)"] +elasticbeanstalk = ["mypy-boto3-elasticbeanstalk (>=1.42.0,<1.43.0)"] +elb = ["mypy-boto3-elb (>=1.42.0,<1.43.0)"] +elbv2 = ["mypy-boto3-elbv2 (>=1.42.0,<1.43.0)"] +emr = ["mypy-boto3-emr (>=1.42.0,<1.43.0)"] 
+emr-containers = ["mypy-boto3-emr-containers (>=1.42.0,<1.43.0)"] +emr-serverless = ["mypy-boto3-emr-serverless (>=1.42.0,<1.43.0)"] +entityresolution = ["mypy-boto3-entityresolution (>=1.42.0,<1.43.0)"] +es = ["mypy-boto3-es (>=1.42.0,<1.43.0)"] +essential = ["mypy-boto3-cloudformation (>=1.42.0,<1.43.0)", "mypy-boto3-dynamodb (>=1.42.0,<1.43.0)", "mypy-boto3-ec2 (>=1.42.0,<1.43.0)", "mypy-boto3-lambda (>=1.42.0,<1.43.0)", "mypy-boto3-rds (>=1.42.0,<1.43.0)", "mypy-boto3-s3 (>=1.42.0,<1.43.0)", "mypy-boto3-sqs (>=1.42.0,<1.43.0)"] +events = ["mypy-boto3-events (>=1.42.0,<1.43.0)"] +evidently = ["mypy-boto3-evidently (>=1.42.0,<1.43.0)"] +evs = ["mypy-boto3-evs (>=1.42.0,<1.43.0)"] +finspace = ["mypy-boto3-finspace (>=1.42.0,<1.43.0)"] +finspace-data = ["mypy-boto3-finspace-data (>=1.42.0,<1.43.0)"] +firehose = ["mypy-boto3-firehose (>=1.42.0,<1.43.0)"] +fis = ["mypy-boto3-fis (>=1.42.0,<1.43.0)"] +fms = ["mypy-boto3-fms (>=1.42.0,<1.43.0)"] +forecast = ["mypy-boto3-forecast (>=1.42.0,<1.43.0)"] +forecastquery = ["mypy-boto3-forecastquery (>=1.42.0,<1.43.0)"] +frauddetector = ["mypy-boto3-frauddetector (>=1.42.0,<1.43.0)"] +freetier = ["mypy-boto3-freetier (>=1.42.0,<1.43.0)"] +fsx = ["mypy-boto3-fsx (>=1.42.0,<1.43.0)"] +full = ["boto3-stubs-full (>=1.42.0,<1.43.0)"] +gamelift = ["mypy-boto3-gamelift (>=1.42.0,<1.43.0)"] +gameliftstreams = ["mypy-boto3-gameliftstreams (>=1.42.0,<1.43.0)"] +geo-maps = ["mypy-boto3-geo-maps (>=1.42.0,<1.43.0)"] +geo-places = ["mypy-boto3-geo-places (>=1.42.0,<1.43.0)"] +geo-routes = ["mypy-boto3-geo-routes (>=1.42.0,<1.43.0)"] +glacier = ["mypy-boto3-glacier (>=1.42.0,<1.43.0)"] +globalaccelerator = ["mypy-boto3-globalaccelerator (>=1.42.0,<1.43.0)"] +glue = ["mypy-boto3-glue (>=1.42.0,<1.43.0)"] +grafana = ["mypy-boto3-grafana (>=1.42.0,<1.43.0)"] +greengrass = ["mypy-boto3-greengrass (>=1.42.0,<1.43.0)"] +greengrassv2 = ["mypy-boto3-greengrassv2 (>=1.42.0,<1.43.0)"] +groundstation = ["mypy-boto3-groundstation (>=1.42.0,<1.43.0)"] +guardduty = ["mypy-boto3-guardduty (>=1.42.0,<1.43.0)"] +health = ["mypy-boto3-health (>=1.42.0,<1.43.0)"] +healthlake = ["mypy-boto3-healthlake (>=1.42.0,<1.43.0)"] +iam = ["mypy-boto3-iam (>=1.42.0,<1.43.0)"] +identitystore = ["mypy-boto3-identitystore (>=1.42.0,<1.43.0)"] +imagebuilder = ["mypy-boto3-imagebuilder (>=1.42.0,<1.43.0)"] +importexport = ["mypy-boto3-importexport (>=1.42.0,<1.43.0)"] +inspector = ["mypy-boto3-inspector (>=1.42.0,<1.43.0)"] +inspector-scan = ["mypy-boto3-inspector-scan (>=1.42.0,<1.43.0)"] +inspector2 = ["mypy-boto3-inspector2 (>=1.42.0,<1.43.0)"] +internetmonitor = ["mypy-boto3-internetmonitor (>=1.42.0,<1.43.0)"] +invoicing = ["mypy-boto3-invoicing (>=1.42.0,<1.43.0)"] +iot = ["mypy-boto3-iot (>=1.42.0,<1.43.0)"] +iot-data = ["mypy-boto3-iot-data (>=1.42.0,<1.43.0)"] +iot-jobs-data = ["mypy-boto3-iot-jobs-data (>=1.42.0,<1.43.0)"] +iot-managed-integrations = ["mypy-boto3-iot-managed-integrations (>=1.42.0,<1.43.0)"] +iotanalytics = ["mypy-boto3-iotanalytics (>=1.42.0,<1.43.0)"] +iotdeviceadvisor = ["mypy-boto3-iotdeviceadvisor (>=1.42.0,<1.43.0)"] +iotevents = ["mypy-boto3-iotevents (>=1.42.0,<1.43.0)"] +iotevents-data = ["mypy-boto3-iotevents-data (>=1.42.0,<1.43.0)"] +iotfleetwise = ["mypy-boto3-iotfleetwise (>=1.42.0,<1.43.0)"] +iotsecuretunneling = ["mypy-boto3-iotsecuretunneling (>=1.42.0,<1.43.0)"] +iotsitewise = ["mypy-boto3-iotsitewise (>=1.42.0,<1.43.0)"] +iotthingsgraph = ["mypy-boto3-iotthingsgraph (>=1.42.0,<1.43.0)"] +iottwinmaker = ["mypy-boto3-iottwinmaker (>=1.42.0,<1.43.0)"] 
+iotwireless = ["mypy-boto3-iotwireless (>=1.42.0,<1.43.0)"] +ivs = ["mypy-boto3-ivs (>=1.42.0,<1.43.0)"] +ivs-realtime = ["mypy-boto3-ivs-realtime (>=1.42.0,<1.43.0)"] +ivschat = ["mypy-boto3-ivschat (>=1.42.0,<1.43.0)"] +kafka = ["mypy-boto3-kafka (>=1.42.0,<1.43.0)"] +kafkaconnect = ["mypy-boto3-kafkaconnect (>=1.42.0,<1.43.0)"] +kendra = ["mypy-boto3-kendra (>=1.42.0,<1.43.0)"] +kendra-ranking = ["mypy-boto3-kendra-ranking (>=1.42.0,<1.43.0)"] +keyspaces = ["mypy-boto3-keyspaces (>=1.42.0,<1.43.0)"] +keyspacesstreams = ["mypy-boto3-keyspacesstreams (>=1.42.0,<1.43.0)"] +kinesis = ["mypy-boto3-kinesis (>=1.42.0,<1.43.0)"] +kinesis-video-archived-media = ["mypy-boto3-kinesis-video-archived-media (>=1.42.0,<1.43.0)"] +kinesis-video-media = ["mypy-boto3-kinesis-video-media (>=1.42.0,<1.43.0)"] +kinesis-video-signaling = ["mypy-boto3-kinesis-video-signaling (>=1.42.0,<1.43.0)"] +kinesis-video-webrtc-storage = ["mypy-boto3-kinesis-video-webrtc-storage (>=1.42.0,<1.43.0)"] +kinesisanalytics = ["mypy-boto3-kinesisanalytics (>=1.42.0,<1.43.0)"] +kinesisanalyticsv2 = ["mypy-boto3-kinesisanalyticsv2 (>=1.42.0,<1.43.0)"] +kinesisvideo = ["mypy-boto3-kinesisvideo (>=1.42.0,<1.43.0)"] +kms = ["mypy-boto3-kms (>=1.42.0,<1.43.0)"] +lakeformation = ["mypy-boto3-lakeformation (>=1.42.0,<1.43.0)"] +lambda = ["mypy-boto3-lambda (>=1.42.0,<1.43.0)"] +launch-wizard = ["mypy-boto3-launch-wizard (>=1.42.0,<1.43.0)"] +lex-models = ["mypy-boto3-lex-models (>=1.42.0,<1.43.0)"] +lex-runtime = ["mypy-boto3-lex-runtime (>=1.42.0,<1.43.0)"] +lexv2-models = ["mypy-boto3-lexv2-models (>=1.42.0,<1.43.0)"] +lexv2-runtime = ["mypy-boto3-lexv2-runtime (>=1.42.0,<1.43.0)"] +license-manager = ["mypy-boto3-license-manager (>=1.42.0,<1.43.0)"] +license-manager-linux-subscriptions = ["mypy-boto3-license-manager-linux-subscriptions (>=1.42.0,<1.43.0)"] +license-manager-user-subscriptions = ["mypy-boto3-license-manager-user-subscriptions (>=1.42.0,<1.43.0)"] +lightsail = ["mypy-boto3-lightsail (>=1.42.0,<1.43.0)"] +location = ["mypy-boto3-location (>=1.42.0,<1.43.0)"] +logs = ["mypy-boto3-logs (>=1.42.0,<1.43.0)"] +lookoutequipment = ["mypy-boto3-lookoutequipment (>=1.42.0,<1.43.0)"] +m2 = ["mypy-boto3-m2 (>=1.42.0,<1.43.0)"] +machinelearning = ["mypy-boto3-machinelearning (>=1.42.0,<1.43.0)"] +macie2 = ["mypy-boto3-macie2 (>=1.42.0,<1.43.0)"] +mailmanager = ["mypy-boto3-mailmanager (>=1.42.0,<1.43.0)"] +managedblockchain = ["mypy-boto3-managedblockchain (>=1.42.0,<1.43.0)"] +managedblockchain-query = ["mypy-boto3-managedblockchain-query (>=1.42.0,<1.43.0)"] +marketplace-agreement = ["mypy-boto3-marketplace-agreement (>=1.42.0,<1.43.0)"] +marketplace-catalog = ["mypy-boto3-marketplace-catalog (>=1.42.0,<1.43.0)"] +marketplace-deployment = ["mypy-boto3-marketplace-deployment (>=1.42.0,<1.43.0)"] +marketplace-entitlement = ["mypy-boto3-marketplace-entitlement (>=1.42.0,<1.43.0)"] +marketplace-reporting = ["mypy-boto3-marketplace-reporting (>=1.42.0,<1.43.0)"] +marketplacecommerceanalytics = ["mypy-boto3-marketplacecommerceanalytics (>=1.42.0,<1.43.0)"] +mediaconnect = ["mypy-boto3-mediaconnect (>=1.42.0,<1.43.0)"] +mediaconvert = ["mypy-boto3-mediaconvert (>=1.42.0,<1.43.0)"] +medialive = ["mypy-boto3-medialive (>=1.42.0,<1.43.0)"] +mediapackage = ["mypy-boto3-mediapackage (>=1.42.0,<1.43.0)"] +mediapackage-vod = ["mypy-boto3-mediapackage-vod (>=1.42.0,<1.43.0)"] +mediapackagev2 = ["mypy-boto3-mediapackagev2 (>=1.42.0,<1.43.0)"] +mediastore = ["mypy-boto3-mediastore (>=1.42.0,<1.43.0)"] +mediastore-data = 
["mypy-boto3-mediastore-data (>=1.42.0,<1.43.0)"] +mediatailor = ["mypy-boto3-mediatailor (>=1.42.0,<1.43.0)"] +medical-imaging = ["mypy-boto3-medical-imaging (>=1.42.0,<1.43.0)"] +memorydb = ["mypy-boto3-memorydb (>=1.42.0,<1.43.0)"] +meteringmarketplace = ["mypy-boto3-meteringmarketplace (>=1.42.0,<1.43.0)"] +mgh = ["mypy-boto3-mgh (>=1.42.0,<1.43.0)"] +mgn = ["mypy-boto3-mgn (>=1.42.0,<1.43.0)"] +migration-hub-refactor-spaces = ["mypy-boto3-migration-hub-refactor-spaces (>=1.42.0,<1.43.0)"] +migrationhub-config = ["mypy-boto3-migrationhub-config (>=1.42.0,<1.43.0)"] +migrationhuborchestrator = ["mypy-boto3-migrationhuborchestrator (>=1.42.0,<1.43.0)"] +migrationhubstrategy = ["mypy-boto3-migrationhubstrategy (>=1.42.0,<1.43.0)"] +mpa = ["mypy-boto3-mpa (>=1.42.0,<1.43.0)"] +mq = ["mypy-boto3-mq (>=1.42.0,<1.43.0)"] +mturk = ["mypy-boto3-mturk (>=1.42.0,<1.43.0)"] +mwaa = ["mypy-boto3-mwaa (>=1.42.0,<1.43.0)"] +mwaa-serverless = ["mypy-boto3-mwaa-serverless (>=1.42.0,<1.43.0)"] +neptune = ["mypy-boto3-neptune (>=1.42.0,<1.43.0)"] +neptune-graph = ["mypy-boto3-neptune-graph (>=1.42.0,<1.43.0)"] +neptunedata = ["mypy-boto3-neptunedata (>=1.42.0,<1.43.0)"] +network-firewall = ["mypy-boto3-network-firewall (>=1.42.0,<1.43.0)"] +networkflowmonitor = ["mypy-boto3-networkflowmonitor (>=1.42.0,<1.43.0)"] +networkmanager = ["mypy-boto3-networkmanager (>=1.42.0,<1.43.0)"] +networkmonitor = ["mypy-boto3-networkmonitor (>=1.42.0,<1.43.0)"] +notifications = ["mypy-boto3-notifications (>=1.42.0,<1.43.0)"] +notificationscontacts = ["mypy-boto3-notificationscontacts (>=1.42.0,<1.43.0)"] +nova-act = ["mypy-boto3-nova-act (>=1.42.0,<1.43.0)"] +oam = ["mypy-boto3-oam (>=1.42.0,<1.43.0)"] +observabilityadmin = ["mypy-boto3-observabilityadmin (>=1.42.0,<1.43.0)"] +odb = ["mypy-boto3-odb (>=1.42.0,<1.43.0)"] +omics = ["mypy-boto3-omics (>=1.42.0,<1.43.0)"] +opensearch = ["mypy-boto3-opensearch (>=1.42.0,<1.43.0)"] +opensearchserverless = ["mypy-boto3-opensearchserverless (>=1.42.0,<1.43.0)"] +organizations = ["mypy-boto3-organizations (>=1.42.0,<1.43.0)"] +osis = ["mypy-boto3-osis (>=1.42.0,<1.43.0)"] +outposts = ["mypy-boto3-outposts (>=1.42.0,<1.43.0)"] +panorama = ["mypy-boto3-panorama (>=1.42.0,<1.43.0)"] +partnercentral-account = ["mypy-boto3-partnercentral-account (>=1.42.0,<1.43.0)"] +partnercentral-benefits = ["mypy-boto3-partnercentral-benefits (>=1.42.0,<1.43.0)"] +partnercentral-channel = ["mypy-boto3-partnercentral-channel (>=1.42.0,<1.43.0)"] +partnercentral-selling = ["mypy-boto3-partnercentral-selling (>=1.42.0,<1.43.0)"] +payment-cryptography = ["mypy-boto3-payment-cryptography (>=1.42.0,<1.43.0)"] +payment-cryptography-data = ["mypy-boto3-payment-cryptography-data (>=1.42.0,<1.43.0)"] +pca-connector-ad = ["mypy-boto3-pca-connector-ad (>=1.42.0,<1.43.0)"] +pca-connector-scep = ["mypy-boto3-pca-connector-scep (>=1.42.0,<1.43.0)"] +pcs = ["mypy-boto3-pcs (>=1.42.0,<1.43.0)"] +personalize = ["mypy-boto3-personalize (>=1.42.0,<1.43.0)"] +personalize-events = ["mypy-boto3-personalize-events (>=1.42.0,<1.43.0)"] +personalize-runtime = ["mypy-boto3-personalize-runtime (>=1.42.0,<1.43.0)"] +pi = ["mypy-boto3-pi (>=1.42.0,<1.43.0)"] +pinpoint = ["mypy-boto3-pinpoint (>=1.42.0,<1.43.0)"] +pinpoint-email = ["mypy-boto3-pinpoint-email (>=1.42.0,<1.43.0)"] +pinpoint-sms-voice = ["mypy-boto3-pinpoint-sms-voice (>=1.42.0,<1.43.0)"] +pinpoint-sms-voice-v2 = ["mypy-boto3-pinpoint-sms-voice-v2 (>=1.42.0,<1.43.0)"] +pipes = ["mypy-boto3-pipes (>=1.42.0,<1.43.0)"] +polly = ["mypy-boto3-polly (>=1.42.0,<1.43.0)"] 
+pricing = ["mypy-boto3-pricing (>=1.42.0,<1.43.0)"] +proton = ["mypy-boto3-proton (>=1.42.0,<1.43.0)"] +qapps = ["mypy-boto3-qapps (>=1.42.0,<1.43.0)"] +qbusiness = ["mypy-boto3-qbusiness (>=1.42.0,<1.43.0)"] +qconnect = ["mypy-boto3-qconnect (>=1.42.0,<1.43.0)"] +quicksight = ["mypy-boto3-quicksight (>=1.42.0,<1.43.0)"] +ram = ["mypy-boto3-ram (>=1.42.0,<1.43.0)"] +rbin = ["mypy-boto3-rbin (>=1.42.0,<1.43.0)"] +rds = ["mypy-boto3-rds (>=1.42.0,<1.43.0)"] +rds-data = ["mypy-boto3-rds-data (>=1.42.0,<1.43.0)"] +redshift = ["mypy-boto3-redshift (>=1.42.0,<1.43.0)"] +redshift-data = ["mypy-boto3-redshift-data (>=1.42.0,<1.43.0)"] +redshift-serverless = ["mypy-boto3-redshift-serverless (>=1.42.0,<1.43.0)"] +rekognition = ["mypy-boto3-rekognition (>=1.42.0,<1.43.0)"] +repostspace = ["mypy-boto3-repostspace (>=1.42.0,<1.43.0)"] +resiliencehub = ["mypy-boto3-resiliencehub (>=1.42.0,<1.43.0)"] +resource-explorer-2 = ["mypy-boto3-resource-explorer-2 (>=1.42.0,<1.43.0)"] +resource-groups = ["mypy-boto3-resource-groups (>=1.42.0,<1.43.0)"] +resourcegroupstaggingapi = ["mypy-boto3-resourcegroupstaggingapi (>=1.42.0,<1.43.0)"] +rolesanywhere = ["mypy-boto3-rolesanywhere (>=1.42.0,<1.43.0)"] +route53 = ["mypy-boto3-route53 (>=1.42.0,<1.43.0)"] +route53-recovery-cluster = ["mypy-boto3-route53-recovery-cluster (>=1.42.0,<1.43.0)"] +route53-recovery-control-config = ["mypy-boto3-route53-recovery-control-config (>=1.42.0,<1.43.0)"] +route53-recovery-readiness = ["mypy-boto3-route53-recovery-readiness (>=1.42.0,<1.43.0)"] +route53domains = ["mypy-boto3-route53domains (>=1.42.0,<1.43.0)"] +route53globalresolver = ["mypy-boto3-route53globalresolver (>=1.42.0,<1.43.0)"] +route53profiles = ["mypy-boto3-route53profiles (>=1.42.0,<1.43.0)"] +route53resolver = ["mypy-boto3-route53resolver (>=1.42.0,<1.43.0)"] +rtbfabric = ["mypy-boto3-rtbfabric (>=1.42.0,<1.43.0)"] +rum = ["mypy-boto3-rum (>=1.42.0,<1.43.0)"] +s3 = ["mypy-boto3-s3 (>=1.42.0,<1.43.0)"] +s3control = ["mypy-boto3-s3control (>=1.42.0,<1.43.0)"] +s3outposts = ["mypy-boto3-s3outposts (>=1.42.0,<1.43.0)"] +s3tables = ["mypy-boto3-s3tables (>=1.42.0,<1.43.0)"] +s3vectors = ["mypy-boto3-s3vectors (>=1.42.0,<1.43.0)"] +sagemaker = ["mypy-boto3-sagemaker (>=1.42.0,<1.43.0)"] +sagemaker-a2i-runtime = ["mypy-boto3-sagemaker-a2i-runtime (>=1.42.0,<1.43.0)"] +sagemaker-edge = ["mypy-boto3-sagemaker-edge (>=1.42.0,<1.43.0)"] +sagemaker-featurestore-runtime = ["mypy-boto3-sagemaker-featurestore-runtime (>=1.42.0,<1.43.0)"] +sagemaker-geospatial = ["mypy-boto3-sagemaker-geospatial (>=1.42.0,<1.43.0)"] +sagemaker-metrics = ["mypy-boto3-sagemaker-metrics (>=1.42.0,<1.43.0)"] +sagemaker-runtime = ["mypy-boto3-sagemaker-runtime (>=1.42.0,<1.43.0)"] +savingsplans = ["mypy-boto3-savingsplans (>=1.42.0,<1.43.0)"] +scheduler = ["mypy-boto3-scheduler (>=1.42.0,<1.43.0)"] +schemas = ["mypy-boto3-schemas (>=1.42.0,<1.43.0)"] +sdb = ["mypy-boto3-sdb (>=1.42.0,<1.43.0)"] +secretsmanager = ["mypy-boto3-secretsmanager (>=1.42.0,<1.43.0)"] +security-ir = ["mypy-boto3-security-ir (>=1.42.0,<1.43.0)"] +securityhub = ["mypy-boto3-securityhub (>=1.42.0,<1.43.0)"] +securitylake = ["mypy-boto3-securitylake (>=1.42.0,<1.43.0)"] +serverlessrepo = ["mypy-boto3-serverlessrepo (>=1.42.0,<1.43.0)"] +service-quotas = ["mypy-boto3-service-quotas (>=1.42.0,<1.43.0)"] +servicecatalog = ["mypy-boto3-servicecatalog (>=1.42.0,<1.43.0)"] +servicecatalog-appregistry = ["mypy-boto3-servicecatalog-appregistry (>=1.42.0,<1.43.0)"] +servicediscovery = ["mypy-boto3-servicediscovery (>=1.42.0,<1.43.0)"] 
+ses = ["mypy-boto3-ses (>=1.42.0,<1.43.0)"] +sesv2 = ["mypy-boto3-sesv2 (>=1.42.0,<1.43.0)"] +shield = ["mypy-boto3-shield (>=1.42.0,<1.43.0)"] +signer = ["mypy-boto3-signer (>=1.42.0,<1.43.0)"] +signin = ["mypy-boto3-signin (>=1.42.0,<1.43.0)"] +simspaceweaver = ["mypy-boto3-simspaceweaver (>=1.42.0,<1.43.0)"] +snow-device-management = ["mypy-boto3-snow-device-management (>=1.42.0,<1.43.0)"] +snowball = ["mypy-boto3-snowball (>=1.42.0,<1.43.0)"] +sns = ["mypy-boto3-sns (>=1.42.0,<1.43.0)"] +socialmessaging = ["mypy-boto3-socialmessaging (>=1.42.0,<1.43.0)"] +sqs = ["mypy-boto3-sqs (>=1.42.0,<1.43.0)"] +ssm = ["mypy-boto3-ssm (>=1.42.0,<1.43.0)"] +ssm-contacts = ["mypy-boto3-ssm-contacts (>=1.42.0,<1.43.0)"] +ssm-guiconnect = ["mypy-boto3-ssm-guiconnect (>=1.42.0,<1.43.0)"] +ssm-incidents = ["mypy-boto3-ssm-incidents (>=1.42.0,<1.43.0)"] +ssm-quicksetup = ["mypy-boto3-ssm-quicksetup (>=1.42.0,<1.43.0)"] +ssm-sap = ["mypy-boto3-ssm-sap (>=1.42.0,<1.43.0)"] +sso = ["mypy-boto3-sso (>=1.42.0,<1.43.0)"] +sso-admin = ["mypy-boto3-sso-admin (>=1.42.0,<1.43.0)"] +sso-oidc = ["mypy-boto3-sso-oidc (>=1.42.0,<1.43.0)"] +stepfunctions = ["mypy-boto3-stepfunctions (>=1.42.0,<1.43.0)"] +storagegateway = ["mypy-boto3-storagegateway (>=1.42.0,<1.43.0)"] +sts = ["mypy-boto3-sts (>=1.42.0,<1.43.0)"] +supplychain = ["mypy-boto3-supplychain (>=1.42.0,<1.43.0)"] +support = ["mypy-boto3-support (>=1.42.0,<1.43.0)"] +support-app = ["mypy-boto3-support-app (>=1.42.0,<1.43.0)"] +swf = ["mypy-boto3-swf (>=1.42.0,<1.43.0)"] +synthetics = ["mypy-boto3-synthetics (>=1.42.0,<1.43.0)"] +taxsettings = ["mypy-boto3-taxsettings (>=1.42.0,<1.43.0)"] +textract = ["mypy-boto3-textract (>=1.42.0,<1.43.0)"] +timestream-influxdb = ["mypy-boto3-timestream-influxdb (>=1.42.0,<1.43.0)"] +timestream-query = ["mypy-boto3-timestream-query (>=1.42.0,<1.43.0)"] +timestream-write = ["mypy-boto3-timestream-write (>=1.42.0,<1.43.0)"] +tnb = ["mypy-boto3-tnb (>=1.42.0,<1.43.0)"] +transcribe = ["mypy-boto3-transcribe (>=1.42.0,<1.43.0)"] +transfer = ["mypy-boto3-transfer (>=1.42.0,<1.43.0)"] +translate = ["mypy-boto3-translate (>=1.42.0,<1.43.0)"] +trustedadvisor = ["mypy-boto3-trustedadvisor (>=1.42.0,<1.43.0)"] +verifiedpermissions = ["mypy-boto3-verifiedpermissions (>=1.42.0,<1.43.0)"] +voice-id = ["mypy-boto3-voice-id (>=1.42.0,<1.43.0)"] +vpc-lattice = ["mypy-boto3-vpc-lattice (>=1.42.0,<1.43.0)"] +waf = ["mypy-boto3-waf (>=1.42.0,<1.43.0)"] +waf-regional = ["mypy-boto3-waf-regional (>=1.42.0,<1.43.0)"] +wafv2 = ["mypy-boto3-wafv2 (>=1.42.0,<1.43.0)"] +wellarchitected = ["mypy-boto3-wellarchitected (>=1.42.0,<1.43.0)"] +wickr = ["mypy-boto3-wickr (>=1.42.0,<1.43.0)"] +wisdom = ["mypy-boto3-wisdom (>=1.42.0,<1.43.0)"] +workdocs = ["mypy-boto3-workdocs (>=1.42.0,<1.43.0)"] +workmail = ["mypy-boto3-workmail (>=1.42.0,<1.43.0)"] +workmailmessageflow = ["mypy-boto3-workmailmessageflow (>=1.42.0,<1.43.0)"] +workspaces = ["mypy-boto3-workspaces (>=1.42.0,<1.43.0)"] +workspaces-instances = ["mypy-boto3-workspaces-instances (>=1.42.0,<1.43.0)"] +workspaces-thin-client = ["mypy-boto3-workspaces-thin-client (>=1.42.0,<1.43.0)"] +workspaces-web = ["mypy-boto3-workspaces-web (>=1.42.0,<1.43.0)"] +xray = ["mypy-boto3-xray (>=1.42.0,<1.43.0)"] [[package]] name = "botocore" @@ -769,7 +812,6 @@ description = "Canonical JSON" optional = false python-versions = ">=3.7" groups = ["main"] -markers = "extra == \"server\"" files = [ {file = "canonicaljson-2.0.0-py3-none-any.whl", hash = 
"sha256:c38a315de3b5a0532f1ec1f9153cd3d716abfc565a558d00a4835428a34fca5b"}, {file = "canonicaljson-2.0.0.tar.gz", hash = "sha256:e2fdaef1d7fadc5d9cb59bd3d0d41b064ddda697809ac4325dced721d12f113f"}, @@ -1015,7 +1057,7 @@ files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] -markers = {main = "extra == \"server\" and (platform_system == \"Windows\" or sys_platform == \"win32\")", dev = "sys_platform == \"win32\""} +markers = {main = "platform_system == \"Windows\" or extra == \"server\" and sys_platform == \"win32\"", dev = "sys_platform == \"win32\""} [[package]] name = "coloredlogs" @@ -1501,7 +1543,6 @@ description = "GA4GH Categorical Variation Representation (Cat-VRS) reference im optional = false python-versions = ">=3.10" groups = ["main"] -markers = "extra == \"server\"" files = [ {file = "ga4gh_cat_vrs-0.7.1-py3-none-any.whl", hash = "sha256:549e726182d9fdc28d049b9adc6a8c65189bbade06b2ceed8cb20a35cbdefc45"}, {file = "ga4gh_cat_vrs-0.7.1.tar.gz", hash = "sha256:ac8d11ea5f474e8a9745107673d4e8b6949819ccdc9debe2ab8ad8e5f853f87c"}, @@ -1523,7 +1564,6 @@ description = "GA4GH Variant Annotation (VA) reference implementation" optional = false python-versions = ">=3.10" groups = ["main"] -markers = "extra == \"server\"" files = [ {file = "ga4gh_va_spec-0.4.2-py3-none-any.whl", hash = "sha256:c165a96dfa225845b5d63740d3ad40c9f2dcb26808cf759b73bc122a68a9a60e"}, {file = "ga4gh_va_spec-0.4.2.tar.gz", hash = "sha256:13eda6a8cfc7a2baa395e33d17e3296c2ec1c63ec85fe38085751c112cf1c902"}, @@ -1546,7 +1586,6 @@ description = "GA4GH Variation Representation Specification (VRS) reference impl optional = false python-versions = ">=3.10" groups = ["main"] -markers = "extra == \"server\"" files = [ {file = "ga4gh_vrs-2.1.3-py3-none-any.whl", hash = "sha256:15b20363d9d4a4604be0930b41b14c9b4e6dc15a6e8be813544f0775b873bc5b"}, {file = "ga4gh_vrs-2.1.3.tar.gz", hash = "sha256:48af6de1eb40e00aa68ed5a935061917b4017468ef366e8e68bbbc17ffaa60f3"}, @@ -2496,6 +2535,21 @@ install-types = ["pip"] mypyc = ["setuptools (>=50)"] reports = ["lxml"] +[[package]] +name = "mypy-boto3-s3" +version = "1.42.21" +description = "Type annotations for boto3 S3 1.42.21 service generated with mypy-boto3-builder 8.12.0" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "mypy_boto3_s3-1.42.21-py3-none-any.whl", hash = "sha256:f5b7d1ed718ba5b00f67e95a9a38c6a021159d3071ea235e6cf496e584115ded"}, + {file = "mypy_boto3_s3-1.42.21.tar.gz", hash = "sha256:cab71c918aac7d98c4d742544c722e37d8e7178acb8bc88a0aead7b1035026d2"}, +] + +[package.dependencies] +typing-extensions = {version = "*", markers = "python_version < \"3.12\""} + [[package]] name = "mypy-extensions" version = "1.1.0" @@ -3815,7 +3869,6 @@ files = [ {file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"}, {file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"}, ] -markers = {main = "extra == \"server\""} [package.extras] check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""] @@ -4794,9 +4847,9 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more_it type = ["pytest-mypy"] [extras] -server = ["alembic", 
"alembic-utils", "arq", "authlib", "biocommons", "boto3", "cdot", "cryptography", "fastapi", "ga4gh-va-spec", "hgvs", "orcid", "psycopg2", "pyathena", "python-jose", "python-multipart", "requests", "slack-sdk", "starlette", "starlette-context", "uvicorn", "watchtower"] +server = ["alembic", "alembic-utils", "arq", "authlib", "biocommons", "boto3", "cdot", "cryptography", "fastapi", "hgvs", "orcid", "psycopg2", "pyathena", "python-jose", "python-multipart", "requests", "slack-sdk", "starlette", "starlette-context", "uvicorn", "watchtower"] [metadata] lock-version = "2.1" python-versions = "^3.11" -content-hash = "cb94d5f7faedc07aa0e3457fdb0735b6526b2f40f02c6d438cab46b733123fd6" +content-hash = "452148c0c5ee1b9cbb12087a27c8d6d3e650ad1eb4fed99b4470b4db16f041c6" diff --git a/pyproject.toml b/pyproject.toml index ca00ecf0..cc4e938c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,12 +62,13 @@ starlette-context = { version = "^0.3.6", optional = true } slack-sdk = { version = "~3.21.3", optional = true } uvicorn = { extras = ["standard"], version = "*", optional = true } watchtower = { version = "~3.2.0", optional = true } +asyncclick = "^8.3.0.7" [tool.poetry.group.dev] optional = true [tool.poetry.group.dev.dependencies] -boto3-stubs = "~1.34.97" +boto3-stubs = { extras = ["s3"], version = "~1.42.33" } mypy = "~1.10.0" pre-commit = "*" jsonschema = "*" @@ -88,7 +89,7 @@ SQLAlchemy = { extras = ["mypy"], version = "~2.0.0" } [tool.poetry.extras] -server = ["alembic", "alembic-utils", "arq", "authlib", "biocommons", "boto3", "cdot", "cryptography", "fastapi", "hgvs", "ga4gh-va-spec", "orcid", "psycopg2", "python-jose", "python-multipart", "pyathena", "requests", "starlette", "starlette-context", "slack-sdk", "uvicorn", "watchtower"] +server = ["alembic", "alembic-utils", "arq", "authlib", "biocommons", "boto3", "cdot", "cryptography", "fastapi", "hgvs", "orcid", "psycopg2", "python-jose", "python-multipart", "pyathena", "requests", "starlette", "starlette-context", "slack-sdk", "uvicorn", "watchtower"] [tool.mypy] @@ -100,11 +101,17 @@ plugins = [ mypy_path = "mypy_stubs" [tool.pytest.ini_options] -addopts = "-v -rP --import-mode=importlib --disable-socket --allow-unix-socket --allow-hosts localhost,::1,127.0.0.1" +addopts = "-v -rP --import-mode=importlib" asyncio_mode = 'strict' testpaths = "tests/" pythonpath = "." norecursedirs = "tests/helpers/" +markers = """ + integration: mark a test as an integration test. + unit: mark a test as a unit test. + network: mark a test that requires network access. + slow: mark a test as slow-running. +""" # Uncomment the following lines to include application log output in Pytest logs. 
 
 # Uncomment the following lines to include application log output in Pytest logs.
 # log_cli = true
 # log_cli_level = "DEBUG"
diff --git a/settings/.env.template b/settings/.env.template
index fbb5b861..a11bbbbb 100644
--- a/settings/.env.template
+++ b/settings/.env.template
@@ -98,3 +98,12 @@ AWS_REGION_NAME=us-west-2
 ATHENA_SCHEMA_NAME=default
 ATHENA_S3_STAGING_DIR=s3://your-bucket/path/to/staging/
 GNOMAD_DATA_VERSION=v4.1
+
+####################################################################################################
+# Environment variables for S3 connection
+####################################################################################################
+
+AWS_ACCESS_KEY_ID=test
+AWS_SECRET_ACCESS_KEY=test
+S3_ENDPOINT_URL=http://localstack:4566
+UPLOAD_S3_BUCKET_NAME=score-set-csv-uploads-dev
\ No newline at end of file
diff --git a/src/mavedb/data_providers/services.py b/src/mavedb/data_providers/services.py
index eed9b01d..a94c16d6 100644
--- a/src/mavedb/data_providers/services.py
+++ b/src/mavedb/data_providers/services.py
@@ -1,10 +1,14 @@
 import os
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
-from cdot.hgvs.dataproviders import SeqFetcher, ChainedSeqFetcher, FastaSeqFetcher, RESTDataProvider
+import boto3
+from cdot.hgvs.dataproviders import ChainedSeqFetcher, FastaSeqFetcher, RESTDataProvider, SeqFetcher
 
 from mavedb.lib.mapping import VRSMap
 
+if TYPE_CHECKING:
+    from mypy_boto3_s3.client import S3Client
+
 GENOMIC_FASTA_FILES = [
     "/data/GCF_000001405.39_GRCh38.p13_genomic.fna.gz",
     "/data/GCF_000001405.25_GRCh37.p13_genomic.fna.gz",
@@ -12,6 +16,7 @@
 DCD_MAP_URL = os.environ.get("DCD_MAPPING_URL", "http://dcd-mapping:8000")
 CDOT_URL = os.environ.get("CDOT_URL", "http://cdot-rest:8000")
+CSV_UPLOAD_S3_BUCKET_NAME = os.getenv("UPLOAD_S3_BUCKET_NAME", "score-set-csv-uploads-dev")
 
 
 def seqfetcher() -> ChainedSeqFetcher:
@@ -24,3 +29,13 @@ def cdot_rest() -> RESTDataProvider:
 
 def vrs_mapper(url: Optional[str] = None) -> VRSMap:
     return VRSMap(DCD_MAP_URL) if not url else VRSMap(url)
+
+
+def s3_client() -> "S3Client":
+    # S3_ENDPOINT_URL lets local environments point this client at a stand-in
+    # for AWS S3 (e.g. the LocalStack endpoint in settings/.env.template).
+    return boto3.client(
+        "s3",
+        endpoint_url=os.getenv("S3_ENDPOINT_URL"),
+        aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
+        aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
+        region_name=os.getenv("AWS_REGION_NAME", "us-west-2"),
+    )
diff --git a/src/mavedb/db/session.py b/src/mavedb/db/session.py
index ab75604a..4fe2baa1 100644
--- a/src/mavedb/db/session.py
+++ b/src/mavedb/db/session.py
@@ -1,4 +1,5 @@
 import os
+from contextlib import contextmanager
 
 from sqlalchemy import create_engine
 from sqlalchemy.orm import sessionmaker
@@ -15,8 +16,23 @@
 engine = create_engine(
     # For PostgreSQL:
-    DB_URL
+    DB_URL,
+    pool_size=10,
     # For SQLite:
     # DB_URL, connect_args={"check_same_thread": False}
 )
 
 SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+
+
+@contextmanager
+def db_session():
+    """Provide a transactional scope around a series of operations."""
+    session = SessionLocal()
+    try:
+        yield session
+        session.commit()
+    except Exception:
+        session.rollback()
+        raise
+    finally:
+        session.close()
diff --git a/src/mavedb/lib/annotation_status_manager.py b/src/mavedb/lib/annotation_status_manager.py
new file mode 100644
index 00000000..29b17bc0
--- /dev/null
+++ b/src/mavedb/lib/annotation_status_manager.py
@@ -0,0 +1,146 @@
+"""Manage annotation statuses for variants.
+
+This module provides functionality to insert and retrieve annotation statuses
+for genetic variants, ensuring that only one current status exists per
+(variant, annotation type, version) combination.
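+
+A sketch of typical usage (illustrative only; the enum member names here are
+assumptions, not taken from this changeset):
+
+    with db_session() as session:
+        manager = AnnotationStatusManager(session)
+        manager.add_annotation(
+            variant_id=42,
+            annotation_type=AnnotationType.VRS,
+            status=AnnotationStatus.SUCCEEDED,
+            version="2.1",
+        )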
+""" + +import logging +from typing import Optional + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from mavedb.models.enums.annotation_type import AnnotationType +from mavedb.models.enums.job_pipeline import AnnotationStatus +from mavedb.models.variant_annotation_status import VariantAnnotationStatus + +logger = logging.getLogger(__name__) + + +class AnnotationStatusManager: + """ + Manager for handling variant annotation statuses. + + Attributes: + session (Session): The SQLAlchemy session used for database operations. + + Methods: + add_annotation( + variant_id: int, + annotation_type: AnnotationType, + version: Optional[str], + annotation_data: dict, + current: bool = True + ) -> VariantAnnotationStatus: + Inserts a new annotation status and marks previous ones as not current. + + get_current_annotation( + variant_id: int, + annotation_type: AnnotationType, + version: Optional[str] = None + ) -> Optional[VariantAnnotationStatus]: + Retrieves the current annotation status for a given variant/type/version. + """ + + def __init__(self, session: Session): + self.session = session + + def add_annotation( + self, + variant_id: int, + annotation_type: AnnotationType, + status: AnnotationStatus, + version: Optional[str] = None, + annotation_data: dict = {}, + current: bool = True, + ) -> VariantAnnotationStatus: + """ + Insert a new annotation and mark previous ones as not current for the same (variant, type, version). + Callers should take care to ensure only one current annotation exists per (variant, type, version). Note + + Args: + variant_id (int): The ID of the variant being annotated. + annotation_type (AnnotationType): The type of annotation (e.g., 'vrs', 'clinvar'). + version (Optional[str]): The version of the annotation source. + annotation_data (dict): Additional data for the annotation status. + current (bool): Whether this annotation is the current one. + + Returns: + VariantAnnotationStatus: The newly created annotation status record. + + Side Effects: + - Updates existing records to set current=False for the same (variant, type, version). + - Adds a new VariantAnnotationStatus record to the database session. + + NOTE: + - This method does not commit the session and only flushes to the database. The caller + is responsible for persisting any changes (e.g., by calling session.commit()). 
+ """ + logger.debug( + f"Adding annotation for variant_id={variant_id}, annotation_type={annotation_type}, version={version}" + ) + + # Find existing current annotations to be replaced + existing_current = ( + self.session.execute( + select(VariantAnnotationStatus).where( + VariantAnnotationStatus.variant_id == variant_id, + VariantAnnotationStatus.annotation_type == annotation_type, + VariantAnnotationStatus.version == version, + VariantAnnotationStatus.current.is_(True), + ) + ) + .scalars() + .all() + ) + for var_ann in existing_current: + logger.debug( + f"Replacing current annotation {var_ann.id} for variant_id={variant_id}, annotation_type={annotation_type}, version={version}" + ) + var_ann.current = False + + self.session.flush() + + new_status = VariantAnnotationStatus( + variant_id=variant_id, + annotation_type=annotation_type, + status=status, + version=version, + current=current, + **annotation_data, + ) # type: ignore[call-arg] + + self.session.add(new_status) + self.session.flush() + + logger.info( + f"Successfully added annotation for variant_id={variant_id}, annotation_type={annotation_type}, version={version}" + ) + return new_status + + def get_current_annotation( + self, variant_id: int, annotation_type: AnnotationType, version: Optional[str] = None + ) -> Optional[VariantAnnotationStatus]: + """ + Retrieve the current annotation for a given variant/type/version. + + Args: + variant_id (int): The ID of the variant. + annotation_type (AnnotationType): The type of annotation. + version (Optional[str]): The version of the annotation source. + + Returns: + Optional[VariantAnnotationStatus]: The current annotation status record, or None if not found. + """ + stmt = select(VariantAnnotationStatus).where( + VariantAnnotationStatus.variant_id == variant_id, + VariantAnnotationStatus.annotation_type == annotation_type, + VariantAnnotationStatus.current.is_(True), + ) + + if version is not None: + stmt = stmt.where(VariantAnnotationStatus.version == version) + + result = self.session.execute(stmt) + return result.scalar_one_or_none() diff --git a/src/mavedb/lib/clingen/allele_registry.py b/src/mavedb/lib/clingen/allele_registry.py index 5e025b14..a7951255 100644 --- a/src/mavedb/lib/clingen/allele_registry.py +++ b/src/mavedb/lib/clingen/allele_registry.py @@ -1,4 +1,5 @@ import logging + import requests logger = logging.getLogger(__name__) @@ -43,3 +44,18 @@ def get_matching_registered_ca_ids(clingen_pa_id: str) -> list[str]: ca_ids.extend([allele["@id"].split("/")[-1] for allele in allele["matchingRegisteredTranscripts"]]) return ca_ids + + +def get_associated_clinvar_allele_id(clingen_allele_id: str) -> str | None: + """Retrieve the associated ClinVar Allele ID for a given ClinGen Allele ID from the ClinGen API.""" + response = requests.get(f"{CLINGEN_API_URL}/{clingen_allele_id}") + if response.status_code != 200: + logger.error(f"Failed to query ClinGen API for {clingen_allele_id}: {response.status_code}") + return None + + data = response.json() + clinvar_allele_id = data.get("externalRecords", {}).get("ClinVarAlleles", [{}])[0].get("alleleId") + if clinvar_allele_id: + return str(clinvar_allele_id) + + return None diff --git a/src/mavedb/lib/clingen/constants.py b/src/mavedb/lib/clingen/constants.py index 2bc6979b..77a33a53 100644 --- a/src/mavedb/lib/clingen/constants.py +++ b/src/mavedb/lib/clingen/constants.py @@ -17,5 +17,3 @@ LDH_SUBMISSION_ENDPOINT = f"https://genboree.org/mq/brdg/pulsar/{CLIN_GEN_TENANT}/ldh/submissions/{LDH_ENTITY_ENDPOINT}" LDH_ACCESS_ENDPOINT = 
os.getenv("LDH_ACCESS_ENDPOINT", "https://genboree.org/ldh") LDH_MAVE_ACCESS_ENDPOINT = f"{LDH_ACCESS_ENDPOINT}/{LDH_ENTITY_NAME}/id" - -LINKED_DATA_RETRY_THRESHOLD = 0.95 diff --git a/src/mavedb/lib/clingen/services.py b/src/mavedb/lib/clingen/services.py index 1bcb7778..a9e41fcb 100644 --- a/src/mavedb/lib/clingen/services.py +++ b/src/mavedb/lib/clingen/services.py @@ -1,19 +1,16 @@ import hashlib import logging -import requests import os import time from datetime import datetime from typing import Optional -from urllib import parse - +import requests from jose import jwt -from mavedb.lib.logging.context import logging_context, save_to_logging_context, format_raised_exception_info_as_dict -from mavedb.lib.clingen.constants import GENBOREE_ACCOUNT_NAME, GENBOREE_ACCOUNT_PASSWORD, LDH_MAVE_ACCESS_ENDPOINT - -from mavedb.lib.types.clingen import LdhSubmission, ClinGenAllele +from mavedb.lib.clingen.constants import GENBOREE_ACCOUNT_NAME, GENBOREE_ACCOUNT_PASSWORD +from mavedb.lib.logging.context import format_raised_exception_info_as_dict, logging_context, save_to_logging_context +from mavedb.lib.types.clingen import ClinGenAllele, LdhSubmission from mavedb.lib.utils import batched logger = logging.getLogger(__name__) @@ -279,50 +276,6 @@ def _existing_jwt(self) -> Optional[str]: return None -def get_clingen_variation(urn: str) -> Optional[dict]: - """ - Fetches ClinGen variation data for a given URN (Uniform Resource Name) from the Linked Data Hub. - - Args: - urn (str): The URN of the variation to fetch. - - Returns: - Optional[dict]: A dictionary containing the variation data if the request is successful, - or None if the request fails. - """ - response = requests.get( - f"{LDH_MAVE_ACCESS_ENDPOINT}/{parse.quote_plus(urn)}", - headers={"Accept": "application/json"}, - ) - - if response.status_code == 200: - return response.json() - else: - logger.error(f"Failed to fetch data for URN {urn}: {response.status_code} - {response.text}") - return None - - -def clingen_allele_id_from_ldh_variation(variation: Optional[dict]) -> Optional[str]: - """ - Extracts the ClinGen allele ID from a given variation dictionary. - - Args: - variation (Optional[dict]): A dictionary containing variation data, otherwise None. - - Returns: - Optional[str]: The ClinGen allele ID if found, otherwise None. 
- """ - if not variation: - return None - - try: - return variation["data"]["ldFor"]["Variant"][0]["entId"] - except (KeyError, IndexError) as exc: - save_to_logging_context(format_raised_exception_info_as_dict(exc)) - logger.error("Failed to extract ClinGen allele ID from variation data.", extra=logging_context()) - return None - - def get_allele_registry_associations( content_submissions: list[str], submission_response: list[ClinGenAllele] ) -> dict[str, str]: diff --git a/src/mavedb/lib/clinvar/__init__.py b/src/mavedb/lib/clinvar/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/mavedb/lib/clinvar/constants.py b/src/mavedb/lib/clinvar/constants.py new file mode 100644 index 00000000..b0d5397f --- /dev/null +++ b/src/mavedb/lib/clinvar/constants.py @@ -0,0 +1 @@ +TSV_VARIANT_ARCHIVE_BASE_URL = "https://ftp.ncbi.nlm.nih.gov/pub/clinvar/tab_delimited/archive" diff --git a/src/mavedb/lib/clinvar/utils.py b/src/mavedb/lib/clinvar/utils.py new file mode 100644 index 00000000..845dcec9 --- /dev/null +++ b/src/mavedb/lib/clinvar/utils.py @@ -0,0 +1,112 @@ +import csv +import gzip +import io +import sys +from datetime import datetime +from typing import Dict + +import requests + +from mavedb.lib.clinvar.constants import TSV_VARIANT_ARCHIVE_BASE_URL + + +def validate_clinvar_variant_summary_date(month: int, year: int) -> None: + """ + Validates the provided month and year for fetching ClinVar variant summary data. + + Ensures that: + - The year is not earlier than 2015 (ClinVar archived data is only available from 2015 onwards). + - The year is not in the future. + - If the year is the current year, the month is not in the future. + + Raises: + ValueError: If the provided year is before 2015, in the future, or if the month is in the future for the current year. + + Args: + month (int): The month to validate (1-12). + year (int): The year to validate. + """ + current_year = datetime.now().year + current_month = datetime.now().month + + if month < 1 or month > 12: + raise ValueError("Month must be an integer between 1 and 12.") + + if year < 2015 or (year == 2015 and month < 2): + raise ValueError("ClinVar archived data is only available from February 2015 onwards.") + elif year > current_year: + raise ValueError("Cannot fetch ClinVar data for future years.") + elif year == current_year and month > current_month: + raise ValueError("Cannot fetch ClinVar data for future months.") + + +def fetch_clinvar_variant_summary_tsv(month: int, year: int) -> bytes: + """ + Fetches the ClinVar variant summary TSV file for a specified month and year. + + This function attempts to download the variant summary file from the ClinVar FTP archive. + It first tries the top-level directory for recent files, and if not found, falls back to the year-based subdirectory. + The function validates the provided month and year before attempting the download. + + Args: + month (int): The month for which to fetch the variant summary (as an integer). + year (int): The year for which to fetch the variant summary. + + Returns: + bytes: The contents of the downloaded variant summary TSV file (gzipped). + + Raises: + requests.RequestException: If the file cannot be downloaded from either location. + ValueError: If the provided month or year is invalid. + """ + validate_clinvar_variant_summary_date(month, year) + + # Construct URLs for the variant summary TSV file. ClinVar stores recent files at the top level and older files in year-based subdirectories. 
# The cadence at which files are moved is not documented, so we try both locations with a preference for the top-level URL.
+    url_top_level = f"{TSV_VARIANT_ARCHIVE_BASE_URL}/variant_summary_{year}-{month:02d}.txt.gz"
+    url_archive = f"{TSV_VARIANT_ARCHIVE_BASE_URL}/{year}/variant_summary_{year}-{month:02d}.txt.gz"
+
+    try:
+        response = requests.get(url_top_level, stream=True)
+        response.raise_for_status()
+        return response.content
+    except requests.exceptions.HTTPError:
+        response = requests.get(url_archive, stream=True)
+        response.raise_for_status()
+        return response.content
+
+
+def parse_clinvar_variant_summary(tsv_content: bytes) -> Dict[str, Dict[str, str]]:
+    """
+    Parses gzipped TSV file content and returns a dictionary mapping Allele IDs to row data.
+
+    Args:
+        tsv_content (bytes): The gzipped TSV file content as bytes.
+
+    Returns:
+        Dict[str, Dict[str, str]]: A dictionary where each key is a string Allele ID (from the '#AlleleID' column),
+        and each value is a dictionary representing the corresponding row with column names as keys.
+
+    Raises:
+        KeyError: If the '#AlleleID' column is missing in any row.
+        csv.Error: If there is an error parsing the TSV content.
+
+    Note:
+        The function temporarily increases the CSV field size limit to handle large fields in the TSV file. Some old ClinVar
+        variant summary files may have fields larger than the default limit.
+    """
+    default_csv_field_size_limit = csv.field_size_limit()
+
+    try:
+        csv.field_size_limit(sys.maxsize)
+
+        with gzip.open(filename=io.BytesIO(tsv_content), mode="rt") as f:
+            # readlines() yields str here because the file is opened in "rt" mode; it would yield bytes only in "rb" mode.
+            reader = csv.DictReader(f.readlines(), delimiter="\t")  # type: ignore
+            data = {str(row["#AlleleID"]): row for row in reader}
+
+    finally:
+        csv.field_size_limit(default_csv_field_size_limit)
+
+    return data
diff --git a/src/mavedb/lib/exceptions.py b/src/mavedb/lib/exceptions.py
index 8734becb..2dadeb95 100644
--- a/src/mavedb/lib/exceptions.py
+++ b/src/mavedb/lib/exceptions.py
@@ -168,6 +168,12 @@ class NonexistentMappingResultsError(ValueError):
     pass
 
 
+class NonexistentMappingScoresError(ValueError):
+    """Raised when score set mapping results do not contain mapping scores"""
+
+    pass
+
+
 class NonexistentMappingReferenceError(ValueError):
     """Raised when score set mapping results do not contain a valid reference sequence"""
 
     pass
@@ -202,3 +208,39 @@ class UniProtPollingEnqueueError(ValueError):
     """Raised when a UniProt ID polling job fails to be enqueued despite appearing as if it should have been"""
 
     pass
+
+
+class UniprotMappingResultNotFoundError(ValueError):
+    """Raised when no UniProt ID is found in the mapping results for a target gene."""
+
+    pass
+
+
+class UniprotAmbiguousMappingResultError(ValueError):
+    """Raised when ambiguous UniProt IDs are found in the mapping results for a target gene."""
+
+    pass
+
+
+class NonExistentTargetGeneError(ValueError):
+    """Raised when a target gene does not exist in the database."""
+
+    pass
+
+
+class LDHSubmissionFailureError(Exception):
+    """Raised when submission to ClinGen Linked Data Hub (LDH) fails for all submissions."""
+
+    pass
+
+
+class PipelineNotFoundError(Exception):
+    """Raised when a pipeline associated with a job is not found."""
+
+    pass
+
+
+class NoMappedVariantsError(Exception):
+    """Raised when no variants were mapped during the variant mapping process."""
+
+    pass
diff --git a/src/mavedb/lib/gnomad.py
b/src/mavedb/lib/gnomad.py index 02a7da2d..ea76d613 100644 --- a/src/mavedb/lib/gnomad.py +++ b/src/mavedb/lib/gnomad.py @@ -1,19 +1,21 @@ +import logging import os import re -import logging from typing import Any, Sequence, Union -from sqlalchemy import text, select, Row +from sqlalchemy import Connection, Row, select, text from sqlalchemy.orm import Session +from mavedb.lib.annotation_status_manager import AnnotationStatusManager from mavedb.lib.logging.context import logging_context, save_to_logging_context from mavedb.lib.utils import batched -from mavedb.db.athena import engine as athena_engine +from mavedb.models.enums.annotation_type import AnnotationType +from mavedb.models.enums.job_pipeline import AnnotationStatus from mavedb.models.gnomad_variant import GnomADVariant from mavedb.models.mapped_variant import MappedVariant GNOMAD_DB_NAME = "gnomAD" -GNOMAD_DATA_VERSION = os.getenv("GNOMAD_DATA_VERSION") +GNOMAD_DATA_VERSION = os.getenv("GNOMAD_DATA_VERSION", "v4.1") # e.g., "v4.1" logger = logging.getLogger(__name__) @@ -66,7 +68,9 @@ def allele_list_from_list_like_string(alleles_string: str) -> list[str]: return alleles -def gnomad_variant_data_for_caids(caids: Sequence[str]) -> Sequence[Row[Any]]: # pragma: no cover +def gnomad_variant_data_for_caids( + athena_session: Connection, caids: Sequence[str] +) -> Sequence[Row[Any]]: # pragma: no cover """ Fetches variant rows from the gnomAD table for a list of CAIDs. Athena has a maximum character limit of 262144 in queries. CAIDs are about 12 characters long on average + 4 for two quotes, a comma and a space. Chunk our list @@ -94,36 +98,33 @@ def gnomad_variant_data_for_caids(caids: Sequence[str]) -> Sequence[Row[Any]]: caid_strs = [",".join(f"'{caid}'" for caid in chunk) for chunk in chunked_caids] save_to_logging_context({"num_caids": len(caids), "num_chunks": len(caid_strs)}) - with athena_engine.connect() as athena_connection: - logger.debug(msg="Connected to Athena", extra=logging_context()) - - result_rows: list[Row[Any]] = [] - for chunk_index, caid_str in enumerate(caid_strs): - athena_query = f""" - SELECT - "locus.contig", - "locus.position", - "alleles", - "caid", - "joint.freq.all.ac", - "joint.freq.all.an", - "joint.fafmax.faf95_max_gen_anc", - "joint.fafmax.faf95_max" - FROM - {gnomad_table_name()} - WHERE - caid IN ({caid_str}) - """ - logger.debug( - msg=f"Fetching gnomAD variants from Athena (batch {chunk_index}) with query:\n{athena_query}", - extra=logging_context(), - ) + result_rows: list[Row[Any]] = [] + for chunk_index, caid_str in enumerate(caid_strs): + athena_query = f""" + SELECT + "locus.contig", + "locus.position", + "alleles", + "caid", + "joint.freq.all.ac", + "joint.freq.all.an", + "joint.fafmax.faf95_max_gen_anc", + "joint.fafmax.faf95_max" + FROM + {gnomad_table_name()} + WHERE + caid IN ({caid_str}) + """ + logger.debug( + msg=f"Fetching gnomAD variants from Athena (batch {chunk_index}) with query:\n{athena_query}", + extra=logging_context(), + ) - result = athena_connection.execute(text(athena_query)) - rows = result.fetchall() - result_rows.extend(rows) + result = athena_session.execute(text(athena_query)) + rows = result.fetchall() + result_rows.extend(rows) - logger.debug(f"Fetched {len(rows)} gnomAD variants from Athena (batch {chunk_index}).") + logger.debug(f"Fetched {len(rows)} gnomAD variants from Athena (batch {chunk_index}).") save_to_logging_context({"num_gnomad_variant_rows_fetched": len(result_rows)}) logger.debug(msg="Done fetching gnomAD variants from Athena", 
extra=logging_context())
@@ -170,6 +171,7 @@ def link_gnomad_variants_to_mapped_variants(
     if faf95_max is not None:
         faf95_max = float(faf95_max)
 
+    annotation_manager = AnnotationStatusManager(db)
     for mapped_variant in mapped_variants_with_caids:
         # Remove any existing gnomAD variants for this mapped variant that match the current gnomAD data version to avoid data duplication.
         # There should only be one gnomAD variant per mapped variant per gnomAD data version, since each gnomAD variant can only match to one
@@ -217,6 +219,18 @@
             linked_gnomad_variants += 1
 
             db.add(gnomad_variant)
+            annotation_manager.add_annotation(
+                variant_id=mapped_variant.variant_id,  # type: ignore
+                annotation_type=AnnotationType.GNOMAD_ALLELE_FREQUENCY,
+                version=GNOMAD_DATA_VERSION,
+                status=AnnotationStatus.SUCCESS,
+                annotation_data={
+                    "success_data": {
+                        "gnomad_db_identifier": gnomad_variant.db_identifier,
+                    }
+                },
+                current=True,
+            )
 
             logger.debug(
                 msg=f"Linked gnomAD variant {gnomad_variant.db_identifier} to mapped variant {mapped_variant.id} ({mapped_variant.clingen_allele_id})",
diff --git a/src/mavedb/lib/logging/context.py b/src/mavedb/lib/logging/context.py
index 6771f760..075efb58 100644
--- a/src/mavedb/lib/logging/context.py
+++ b/src/mavedb/lib/logging/context.py
@@ -55,15 +55,7 @@ def save_to_logging_context(ctx: dict) -> dict:
         return {}
 
     for k, v in ctx.items():
-        # Don't overwrite existing context mappings but create a list if a duplicated key is added.
-        if k in context:
-            existing_ctx = context[k]
-            if isinstance(existing_ctx, list):
-                context[k].append(v)
-            else:
-                context[k] = [existing_ctx, v]
-        else:
-            context[k] = v
+        context[k] = v
 
     return context.data
diff --git a/src/mavedb/lib/mapping.py b/src/mavedb/lib/mapping.py
index d3915f53..0f601e85 100644
--- a/src/mavedb/lib/mapping.py
+++ b/src/mavedb/lib/mapping.py
@@ -9,6 +9,8 @@
     "c": "cdna",
 }
 
+EXCLUDED_PREMAPPED_ANNOTATION_KEYS = {"sequence"}
+
 
 class VRSMap:
     url: str
diff --git a/src/mavedb/lib/types/workflow.py b/src/mavedb/lib/types/workflow.py
new file mode 100644
index 00000000..b0e6413e
--- /dev/null
+++ b/src/mavedb/lib/types/workflow.py
@@ -0,0 +1,16 @@
+from typing import Any, TypedDict
+
+from mavedb.models.enums.job_pipeline import DependencyType
+
+
+class JobDefinition(TypedDict):
+    key: str
+    type: str
+    function: str
+    params: dict[str, Any]
+    dependencies: list[tuple[str, DependencyType]]
+
+
+class PipelineDefinition(TypedDict):
+    description: str
+    job_definitions: list[JobDefinition]
diff --git a/src/mavedb/lib/urns.py b/src/mavedb/lib/urns.py
index e3903ac8..55a59e70 100644
--- a/src/mavedb/lib/urns.py
+++ b/src/mavedb/lib/urns.py
@@ -153,3 +153,25 @@ def generate_calibration_urn():
     :return: A new calibration URN
     """
     return f"urn:mavedb:calibration-{uuid4()}"
+
+
+def generate_pipeline_urn():
+    """
+    Generate a new URN for a pipeline.
+
+    Pipeline URNs include a randomly generated version 4 UUID.
+
+    :return: A new pipeline URN
+    """
+    return f"urn:mavedb:pipeline-{uuid4()}"
+
+
+def generate_job_run_urn():
+    """
+    Generate a new URN for a job run.
+
+    Job run URNs include a randomly generated version 4 UUID.
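+
+    Example output (illustrative UUID): urn:mavedb:job-1b4e28ba-2fa1-4d3b-883f-0016d3cca427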
+ + :return: A new job run URN + """ + return f"urn:mavedb:job-{uuid4()}" diff --git a/src/mavedb/lib/workflow/__init__.py b/src/mavedb/lib/workflow/__init__.py new file mode 100644 index 00000000..65be1386 --- /dev/null +++ b/src/mavedb/lib/workflow/__init__.py @@ -0,0 +1,9 @@ +from .definitions import PIPELINE_DEFINITIONS +from .job_factory import JobFactory +from .pipeline_factory import PipelineFactory + +__all__ = [ + "JobFactory", + "PipelineFactory", + "PIPELINE_DEFINITIONS", +] diff --git a/src/mavedb/lib/workflow/definitions.py b/src/mavedb/lib/workflow/definitions.py new file mode 100644 index 00000000..72c83e42 --- /dev/null +++ b/src/mavedb/lib/workflow/definitions.py @@ -0,0 +1,238 @@ +from mavedb.lib.types.workflow import JobDefinition, PipelineDefinition +from mavedb.models.enums.job_pipeline import DependencyType, JobType + +# As a general rule, job keys should match function names for clarity. In some cases of +# repeated jobs, a suffix may be added to the key for uniqueness. + + +def annotation_pipeline_job_definitions() -> list[JobDefinition]: + return [ + { + "key": "submit_score_set_mappings_to_car", + "function": "submit_score_set_mappings_to_car", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "updater_id": None, # Required param to be filled in at runtime + }, + "dependencies": [("map_variants_for_score_set", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "link_gnomad_variants", + "function": "link_gnomad_variants", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "submit_uniprot_mapping_jobs_for_score_set", + "function": "submit_uniprot_mapping_jobs_for_score_set", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + }, + "dependencies": [("map_variants_for_score_set", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "poll_uniprot_mapping_jobs_for_score_set", + "function": "poll_uniprot_mapping_jobs_for_score_set", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "mapping_jobs": {}, # Required param to be filled in at runtime by previous job + }, + "dependencies": [("submit_uniprot_mapping_jobs_for_score_set", DependencyType.SUCCESS_REQUIRED)], + }, + # TODO#650: Simplify or automate the generation of these repetitive job definitions + { + "key": "refresh_clinvar_controls_201502", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2015, + "month": 2, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "refresh_clinvar_controls_201601", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime 
+ "score_set_id": None, # Required param to be filled in at runtime + "year": 2016, + "month": 1, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "refresh_clinvar_controls_201701", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2017, + "month": 1, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "refresh_clinvar_controls_201801", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2018, + "month": 1, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "refresh_clinvar_controls_201901", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2019, + "month": 1, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "refresh_clinvar_controls_202001", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2020, + "month": 1, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "refresh_clinvar_controls_202101", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2021, + "month": 1, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "refresh_clinvar_controls_202201", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2022, + "month": 1, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "refresh_clinvar_controls_202301", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2023, + "month": 1, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "refresh_clinvar_controls_202401", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2024, + "month": 1, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + 
"key": "refresh_clinvar_controls_202501", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2025, + "month": 1, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "refresh_clinvar_controls_202601", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2026, + "month": 1, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + ] + + +PIPELINE_DEFINITIONS: dict[str, PipelineDefinition] = { + "validate_map_annotate_score_set": { + "description": "Pipeline to validate, map, and annotate variants for a score set.", + "job_definitions": [ + { + "key": "create_variants_for_score_set", + "function": "create_variants_for_score_set", + "type": JobType.VARIANT_CREATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "updater_id": None, # Required param to be filled in at runtime + "scores_file_key": None, # Required param to be filled in at runtime + "counts_file_key": None, # Required param to be filled in at runtime + "score_columns_metadata": None, # Required param to be filled in at runtime + "count_columns_metadata": None, # Required param to be filled in at runtime + }, + "dependencies": [], + }, + { + "key": "map_variants_for_score_set", + "function": "map_variants_for_score_set", + "type": JobType.VARIANT_MAPPING, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "updater_id": None, # Required param to be filled in at runtime + }, + "dependencies": [("create_variants_for_score_set", DependencyType.SUCCESS_REQUIRED)], + }, + *annotation_pipeline_job_definitions(), + ], + }, + "annotate_score_set": { + "description": "Pipeline to annotate variants for a score set.", + "job_definitions": annotation_pipeline_job_definitions(), + }, + # Add more pipelines here +} diff --git a/src/mavedb/lib/workflow/job_factory.py b/src/mavedb/lib/workflow/job_factory.py new file mode 100644 index 00000000..556c9c09 --- /dev/null +++ b/src/mavedb/lib/workflow/job_factory.py @@ -0,0 +1,102 @@ +from copy import deepcopy +from typing import Optional + +from sqlalchemy.orm import Session + +from mavedb import __version__ as mavedb_version +from mavedb.lib.types.workflow import JobDefinition +from mavedb.models.enums.job_pipeline import DependencyType +from mavedb.models.job_dependency import JobDependency +from mavedb.models.job_run import JobRun + + +class JobFactory: + """ + JobFactory is responsible for creating and persisting JobRun instances based on + provided job definitions and pipeline parameters. + + Attributes: + session (Session): The SQLAlchemy session used for database operations. 
+
+    Methods:
+        create_job_run(
+            job_def: JobDefinition,
+            correlation_id: str,
+            pipeline_params: dict,
+            pipeline_id: Optional[int] = None
+        ) -> JobRun:
+            Creates and persists a JobRun based on a job definition, filling required parameters
+            from pipeline_params.
+
+        create_job_dependency(
+            parent_job_run_id: int,
+            child_job_run_id: int,
+            dependency_type: DependencyType = DependencyType.SUCCESS_REQUIRED
+        ) -> JobDependency:
+            Creates and persists a JobDependency linking a parent job run to a child job run.
+    """
+
+    def __init__(self, session: Session):
+        self.session = session
+
+    def create_job_run(
+        self, job_def: JobDefinition, correlation_id: str, pipeline_params: dict, pipeline_id: Optional[int] = None
+    ) -> JobRun:
+        """
+        Creates and persists a new JobRun instance based on the provided job definition and pipeline parameters.
+
+        Args:
+            job_def (JobDefinition): The job definition containing job type, function, and parameter template.
+            correlation_id (str): A unique identifier for correlating this job run with external systems or logs.
+            pipeline_params (dict): A dictionary of parameters to fill in required job parameters and allow for extensibility.
+            pipeline_id (Optional[int]): The ID of the pipeline this job run is associated with.
+
+        Returns:
+            JobRun: The newly created JobRun instance (not yet committed to the database).
+
+        Raises:
+            ValueError: If any required parameter defined in the job definition is missing from pipeline_params.
+        """
+        job_params = deepcopy(job_def["params"])
+
+        # Fill in required params from pipeline_params
+        for key in job_params:
+            if job_params[key] is None:
+                if key not in pipeline_params:
+                    raise ValueError(f"Missing required param: {key}")
+                job_params[key] = pipeline_params[key]
+
+        job_run = JobRun(
+            job_type=job_def["type"],
+            job_function=job_def["function"],
+            job_params=job_params,
+            pipeline_id=pipeline_id,
+            mavedb_version=mavedb_version,
+            correlation_id=correlation_id,
+        )  # type: ignore[call-arg]
+
+        self.session.add(job_run)
+        return job_run
+
+    def create_job_dependency(
+        self,
+        parent_job_run_id: int,
+        child_job_run_id: int,
+        dependency_type: DependencyType = DependencyType.SUCCESS_REQUIRED,
+    ) -> JobDependency:
+        """
+        Creates and persists a JobDependency instance linking a parent job run to a child job run.
+
+        Args:
+            parent_job_run_id (int): The ID of the parent job run.
+            child_job_run_id (int): The ID of the child job run.
+            dependency_type (DependencyType): The type of dependency (default is SUCCESS_REQUIRED).
+
+        Returns:
+            JobDependency: The newly created JobDependency instance (not yet committed to the database).
+
+        Raises:
+            ValueError: If the parent or child job run IDs do not exist in the database.
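+
+        Example (illustrative; assumes both job runs have been flushed so their IDs exist):
+            factory = JobFactory(session)
+            factory.create_job_dependency(
+                parent_job_run_id=parent.id,
+                child_job_run_id=child.id,
+                dependency_type=DependencyType.SUCCESS_REQUIRED,
+            )
+            session.commit()  # the factory only adds to the session; callers persist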
+ """ + + # Validate that the parent and child job runs exist + parent_exists = self.session.query(JobRun.id).filter(JobRun.id == parent_job_run_id).first() is not None + child_exists = self.session.query(JobRun.id).filter(JobRun.id == child_job_run_id).first() is not None + if not parent_exists: + raise ValueError(f"Parent job run ID {parent_job_run_id} does not exist.") + if not child_exists: + raise ValueError(f"Child job run ID {child_job_run_id} does not exist.") + + job_dependency = JobDependency( + id=child_job_run_id, + depends_on_job_id=parent_job_run_id, + dependency_type=dependency_type, + ) # type: ignore[call-arg] + + self.session.add(job_dependency) + return job_dependency diff --git a/src/mavedb/lib/workflow/pipeline_factory.py b/src/mavedb/lib/workflow/pipeline_factory.py new file mode 100644 index 00000000..42ec1e00 --- /dev/null +++ b/src/mavedb/lib/workflow/pipeline_factory.py @@ -0,0 +1,116 @@ +from sqlalchemy.orm import Session + +from mavedb import __version__ as mavedb_version +from mavedb.lib.logging.context import correlation_id_for_context +from mavedb.lib.workflow.definitions import PIPELINE_DEFINITIONS +from mavedb.lib.workflow.job_factory import JobFactory +from mavedb.models.enums.job_pipeline import JobType +from mavedb.models.job_dependency import JobDependency +from mavedb.models.job_run import JobRun +from mavedb.models.pipeline import Pipeline +from mavedb.models.user import User + + +class PipelineFactory: + """ + PipelineFactory is responsible for creating Pipeline instances and their associated JobRun and JobDependency records in the database. + + Attributes: + session (Session): The SQLAlchemy session used for database operations. + + Methods: + __init__(session: Session): + Initializes the PipelineFactory with a database session. + + create_pipeline( + pipeline_name: str, + pipeline_description: Optional[str], + creating_user: User, + pipeline_params: dict + ) -> Pipeline: + Creates a new Pipeline along with its JobRun and JobDependency records, + commits them to the database, and returns the created Pipeline object. + """ + + def __init__(self, session: Session): + self.session = session + + def create_pipeline( + self, pipeline_name: str, creating_user: User, pipeline_params: dict + ) -> tuple[Pipeline, JobRun]: + """ + Creates a new Pipeline instance along with its associated JobRun and JobDependency records. + + Args: + pipeline_name (str): The name of the pipeline to create. + pipeline_description (Optional[str]): A description for the pipeline. + creating_user (User): The user object representing the user creating the pipeline. + pipeline_params (dict): Additional parameters for pipeline creation, such as correlation_id. + + Returns: + Pipeline: The created Pipeline object. + JobRun: The JobRun object representing the start of the pipeline. + + Raises: + KeyError: If the specified pipeline_name is not found in PIPELINE_DEFINITIONS. + Exception: If there is an error during database operations. + + Side Effects: + - Adds and commits new Pipeline, JobRun, and JobDependency records to the database session. 
+ """ + pipeline_def = PIPELINE_DEFINITIONS[pipeline_name] + jobs = pipeline_def["job_definitions"] + job_runs: dict[str, JobRun] = {} + + correlation_id = pipeline_params.get("correlation_id", correlation_id_for_context()) + + pipeline = Pipeline( + name=pipeline_name, + description=pipeline_def["description"], + correlation_id=correlation_id, + created_by_user_id=creating_user.id, + mavedb_version=mavedb_version, + ) # type: ignore[call-arg] + self.session.add(pipeline) + self.session.flush() # To get pipeline.id + + start_pipeline_job = JobRun( + job_type=JobType.PIPELINE_MANAGEMENT, + job_function="start_pipeline", + job_params={}, + pipeline_id=pipeline.id, + mavedb_version=mavedb_version, + correlation_id=correlation_id, + ) # type: ignore[call-arg] + self.session.add(start_pipeline_job) + self.session.flush() # to get start_pipeline_job.id + + job_factory = JobFactory(self.session) + for job_def in jobs: + job_run = job_factory.create_job_run( + job_def=job_def, + pipeline_id=pipeline.id, + correlation_id=correlation_id, + pipeline_params=pipeline_params, + ) + job_runs[job_def["key"]] = job_run + + self.session.flush() # to get job_run IDs + + for job_def in jobs: + job_deps = job_def["dependencies"] + + job_run = job_runs[job_def["key"]] + for dep_key, dependency_type in job_deps: + dep_job_run = job_runs[dep_key] + + dep_job = JobDependency( + id=job_run.id, + depends_on_job_id=dep_job_run.id, + dependency_type=dependency_type, + ) # type: ignore[call-arg] + + self.session.add(dep_job) + + self.session.commit() + return pipeline, start_pipeline_job diff --git a/src/mavedb/lib/workflow/py.typed b/src/mavedb/lib/workflow/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/mavedb/models/__init__.py b/src/mavedb/models/__init__.py index 684b3c98..191fdc51 100644 --- a/src/mavedb/models/__init__.py +++ b/src/mavedb/models/__init__.py @@ -10,9 +10,12 @@ "experiment_set", "genome_identifier", "gnomad_variant", + "job_dependency", + "job_run", "legacy_keyword", "license", "mapped_variant", + "pipeline", "publication_identifier", "published_variant", "raw_read_identifier", @@ -27,6 +30,7 @@ "uniprot_identifier", "uniprot_offset", "user", + "variant_annotation_status", "variant", "variant_translation", ] diff --git a/src/mavedb/models/enums/__init__.py b/src/mavedb/models/enums/__init__.py index e69de29b..80c3a7de 100644 --- a/src/mavedb/models/enums/__init__.py +++ b/src/mavedb/models/enums/__init__.py @@ -0,0 +1,25 @@ +""" +Enums used by MaveDB models. 
+""" + +from .contribution_role import ContributionRole +from .job_pipeline import AnnotationStatus, DependencyType, FailureCategory, JobStatus, PipelineStatus +from .mapping_state import MappingState +from .processing_state import ProcessingState +from .score_calibration_relation import ScoreCalibrationRelation +from .target_category import TargetCategory +from .user_role import UserRole + +__all__ = [ + "ContributionRole", + "JobStatus", + "PipelineStatus", + "DependencyType", + "FailureCategory", + "AnnotationStatus", + "MappingState", + "ProcessingState", + "ScoreCalibrationRelation", + "TargetCategory", + "UserRole", +] diff --git a/src/mavedb/models/enums/annotation_type.py b/src/mavedb/models/enums/annotation_type.py new file mode 100644 index 00000000..b1595347 --- /dev/null +++ b/src/mavedb/models/enums/annotation_type.py @@ -0,0 +1,12 @@ +from enum import Enum + + +class AnnotationType(str, Enum): + VRS_MAPPING = "vrs_mapping" + CLINGEN_ALLELE_ID = "clingen_allele_id" + MAPPED_HGVS = "mapped_hgvs" + VARIANT_TRANSLATION = "variant_translation" + GNOMAD_ALLELE_FREQUENCY = "gnomad_allele_frequency" + CLINVAR_CONTROL = "clinvar_control" + VEP_FUNCTIONAL_CONSEQUENCE = "vep_functional_consequence" + LDH_SUBMISSION = "ldh_submission" diff --git a/src/mavedb/models/enums/job_pipeline.py b/src/mavedb/models/enums/job_pipeline.py new file mode 100644 index 00000000..8a70eb3f --- /dev/null +++ b/src/mavedb/models/enums/job_pipeline.py @@ -0,0 +1,93 @@ +""" +Job and pipeline related enums. +""" + +from enum import Enum + + +class JobStatus(str, Enum): + """Status of a job execution.""" + + SUCCEEDED = "succeeded" + FAILED = "failed" + PENDING = "pending" + QUEUED = "queued" + RUNNING = "running" + CANCELLED = "cancelled" + SKIPPED = "skipped" + + +class PipelineStatus(str, Enum): + """Status of a pipeline execution.""" + + SUCCEEDED = "succeeded" + FAILED = "failed" + CREATED = "created" + RUNNING = "running" + PAUSED = "paused" + CANCELLED = "cancelled" + PARTIAL = "partial" # Pipeline completed with mixed results (some succeeded, some skipped/cancelled) + + +class DependencyType(str, Enum): + """Types of job dependencies.""" + + SUCCESS_REQUIRED = "success_required" # Job only runs if dependency succeeded + COMPLETION_REQUIRED = "completion_required" # Job runs if dependency completed (success OR failure) + + +class FailureCategory(str, Enum): + """Categories of job failures for better classification and handling.""" + + # System-level failures + SYSTEM_ERROR = "system_error" + TIMEOUT = "timeout" + RESOURCE_EXHAUSTION = "resource_exhaustion" + CONFIGURATION_ERROR = "configuration_error" + DEPENDENCY_FAILURE = "dependency_failure" + + # Queue and scheduling failures + ENQUEUE_ERROR = "enqueue_error" + SCHEDULING_ERROR = "scheduling_error" + CANCELLED = "cancelled" + + # Data and validation failures + VALIDATION_ERROR = "validation_error" + DATA_ERROR = "data_error" + + # External service failures + NETWORK_ERROR = "network_error" + API_RATE_LIMITED = "api_rate_limited" + SERVICE_UNAVAILABLE = "service_unavailable" + AUTHENTICATION_FAILED = "authentication_failed" + + # Permission and access failures + PERMISSION_ERROR = "permission_error" + QUOTA_EXCEEDED = "quota_exceeded" + + # Variant processing specific + INVALID_HGVS = "invalid_hgvs" + REFERENCE_MISMATCH = "reference_mismatch" + VRS_MAPPING_FAILED = "vrs_mapping_failed" + TRANSCRIPT_NOT_FOUND = "transcript_not_found" + + # Catch-all + UNKNOWN = "unknown" + + +class AnnotationStatus(str, Enum): + """Status of individual variant 
annotations."""
+
+    SUCCESS = "success"
+    FAILED = "failed"
+    SKIPPED = "skipped"
+
+
+class JobType(str, Enum):
+    """Types of jobs in the pipeline."""
+
+    VARIANT_CREATION = "variant_creation"
+    VARIANT_MAPPING = "variant_mapping"
+    MAPPED_VARIANT_ANNOTATION = "mapped_variant_annotation"
+    PIPELINE_MANAGEMENT = "pipeline_management"
+    DATA_MANAGEMENT = "data_management"
diff --git a/src/mavedb/models/job_dependency.py b/src/mavedb/models/job_dependency.py
new file mode 100644
index 00000000..ac851c7d
--- /dev/null
+++ b/src/mavedb/models/job_dependency.py
@@ -0,0 +1,65 @@
+"""
+SQLAlchemy models for job dependencies.
+"""
+
+from datetime import datetime
+from typing import TYPE_CHECKING, Any, Dict, Optional
+
+from sqlalchemy import CheckConstraint, DateTime, ForeignKey, Index, Integer, String, func
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.ext.mutable import MutableDict
+from sqlalchemy.orm import Mapped, mapped_column, relationship
+
+from mavedb.db.base import Base
+from mavedb.models.enums import DependencyType
+
+if TYPE_CHECKING:
+    from mavedb.models.job_run import JobRun
+
+
+class JobDependency(Base):
+    """
+    Defines dependencies between jobs within a pipeline.
+
+    Each row links a job run to another job run it depends on, defining execution order.
+
+    NOTE: JSONB fields are automatically tracked as mutable objects in this class via MutableDict.
+    This tracker only works for top-level mutations. If you mutate nested objects, you must call
+    `flag_modified(instance, "metadata_")` to ensure changes are persisted.
+    """
+
+    __tablename__ = "job_dependencies"
+
+    # The job being defined (references job_runs.id). Composite primary key with the dependency we are defining.
+    id: Mapped[int] = mapped_column(Integer, ForeignKey("job_runs.id", ondelete="CASCADE"), primary_key=True)
+    depends_on_job_id: Mapped[int] = mapped_column(
+        Integer, ForeignKey("job_runs.id", ondelete="CASCADE"), nullable=False, primary_key=True
+    )
+
+    # Type of dependency
+    dependency_type: Mapped[Optional[DependencyType]] = mapped_column(String(50), nullable=False)
+
+    # Timestamps
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now())
+
+    # Flexible metadata
+    metadata_: Mapped[Optional[Dict[str, Any]]] = mapped_column(
+        "metadata", MutableDict.as_mutable(JSONB), nullable=True
+    )
+
+    # Relationships
+    job_run: Mapped["JobRun"] = relationship("JobRun", back_populates="job_dependencies", foreign_keys=[id])
+    depends_on_job: Mapped["JobRun"] = relationship("JobRun", foreign_keys=[depends_on_job_id], remote_side="JobRun.id")
+
+    # Indexes
+    __table_args__ = (
+        Index("ix_job_dependencies_depends_on_job_id", "depends_on_job_id"),
+        Index("ix_job_dependencies_created_at", "created_at"),
+        CheckConstraint(
+            "dependency_type IS NULL OR dependency_type IN ('success_required', 'completion_required')",
+            name="ck_job_dependencies_type_valid",
+        ),
+    )
+
+    def __repr__(self) -> str:
+        return f"<JobDependency(id={self.id}, depends_on_job_id={self.depends_on_job_id}, dependency_type={self.dependency_type})>"
diff --git a/src/mavedb/models/job_run.py b/src/mavedb/models/job_run.py
new file mode 100644
index 00000000..9ec039cd
--- /dev/null
+++ b/src/mavedb/models/job_run.py
@@ -0,0 +1,112 @@
+"""
+SQLAlchemy models for job runs.
+""" + +from datetime import datetime +from typing import TYPE_CHECKING, Any, Dict, Optional + +from sqlalchemy import CheckConstraint, DateTime, ForeignKey, Index, Integer, String, Text, func +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.ext.mutable import MutableDict +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from mavedb.db.base import Base +from mavedb.lib.urns import generate_job_run_urn +from mavedb.models.enums import JobStatus + +if TYPE_CHECKING: + from mavedb.models.job_dependency import JobDependency + from mavedb.models.pipeline import Pipeline + + +class JobRun(Base): + """ + Represents a single execution of a job. + + Jobs can be retried, so there may be multiple JobRun records for the same logical job. + + NOTE: JSONB fields are automatically tracked as mutable objects in this class via MutableDict. + This tracker only works for top-level mutations. If you mutate nested objects, you must call + `flag_modified(instance, "metadata_")` to ensure changes are persisted. + """ + + __tablename__ = "job_runs" + + # Primary identification + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + urn: Mapped[str] = mapped_column(String(255), nullable=True, unique=True, default=generate_job_run_urn) + + # Job definition + job_type: Mapped[str] = mapped_column(String(100), nullable=False) + job_function: Mapped[str] = mapped_column(String(255), nullable=False) + job_params: Mapped[Optional[Dict[str, Any]]] = mapped_column(MutableDict.as_mutable(JSONB), nullable=True) + + # Execution tracking + status: Mapped[JobStatus] = mapped_column(String(50), nullable=False, default=JobStatus.PENDING) + + # Pipeline association + pipeline_id: Mapped[Optional[int]] = mapped_column( + Integer, ForeignKey("pipelines.id", ondelete="SET NULL"), nullable=True + ) + + # Priority and scheduling + priority: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + max_retries: Mapped[int] = mapped_column(Integer, nullable=False, default=3) + retry_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + retry_delay_seconds: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + + # Timing + scheduled_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now()) + started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now()) + + # Error handling + error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + error_traceback: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + failure_category: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) + + # Progress tracking + progress_current: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + progress_total: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + progress_message: Mapped[Optional[str]] = mapped_column(String(500), nullable=True) + + # Correlation for tracing + correlation_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + + # Flexible metadata + metadata_: Mapped[Dict[str, Any]] = mapped_column( + "metadata", MutableDict.as_mutable(JSONB), nullable=False, server_default="{}" + ) + + # Version tracking + mavedb_version: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) + + 
# Relationships
+    job_dependencies: Mapped[list["JobDependency"]] = relationship(
+        "JobDependency", back_populates="job_run", uselist=True, foreign_keys="[JobDependency.id]"
+    )
+    pipeline: Mapped[Optional["Pipeline"]] = relationship(
+        "Pipeline", back_populates="job_runs", foreign_keys="[JobRun.pipeline_id]"
+    )
+
+    # Indexes
+    __table_args__ = (
+        Index("ix_job_runs_status", "status"),
+        Index("ix_job_runs_job_type", "job_type"),
+        Index("ix_job_runs_pipeline_id", "pipeline_id"),
+        Index("ix_job_runs_scheduled_at", "scheduled_at"),
+        Index("ix_job_runs_created_at", "created_at"),
+        Index("ix_job_runs_correlation_id", "correlation_id"),
+        Index("ix_job_runs_status_scheduled", "status", "scheduled_at"),
+        CheckConstraint(
+            "status IN ('pending', 'queued', 'running', 'succeeded', 'failed', 'cancelled', 'skipped')",
+            name="ck_job_runs_status_valid",
+        ),
+        CheckConstraint("priority >= 0", name="ck_job_runs_priority_positive"),
+        CheckConstraint("max_retries >= 0", name="ck_job_runs_max_retries_positive"),
+        CheckConstraint("retry_count >= 0", name="ck_job_runs_retry_count_positive"),
+    )
+
+    def __repr__(self) -> str:
+        return f"<JobRun(id={self.id}, urn={self.urn}, job_function={self.job_function}, status={self.status})>"
diff --git a/src/mavedb/models/pipeline.py b/src/mavedb/models/pipeline.py
new file mode 100644
index 00000000..717ec24c
--- /dev/null
+++ b/src/mavedb/models/pipeline.py
@@ -0,0 +1,89 @@
+"""
+SQLAlchemy models for job pipelines.
+"""
+
+from datetime import datetime
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from sqlalchemy import CheckConstraint, DateTime, ForeignKey, Index, Integer, String, Text, func
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.ext.mutable import MutableDict
+from sqlalchemy.orm import Mapped, mapped_column, relationship
+
+from mavedb.db.base import Base
+from mavedb.lib.urns import generate_pipeline_urn
+from mavedb.models.enums import PipelineStatus
+from mavedb.models.job_run import JobRun
+
+if TYPE_CHECKING:
+    from mavedb.models.user import User
+
+
+class Pipeline(Base):
+    """
+    Represents a high-level workflow that groups related jobs.
+
+    Examples:
+        - Processing a score set upload
+        - Batch re-annotation of variants
+        - Database migration workflows
+
+    NOTE: JSONB fields are automatically tracked as mutable objects in this class via MutableDict.
+    This tracker only works for top-level mutations. If you mutate nested objects, you must call
+    `flag_modified(instance, "metadata_")` to ensure changes are persisted.
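+
+    Example of the nested-mutation caveat (illustrative):
+        from sqlalchemy.orm.attributes import flag_modified
+
+        pipeline.metadata_["progress"]["done"] += 1  # nested change; MutableDict does not see it
+        flag_modified(pipeline, "metadata_")         # explicitly mark the JSONB column dirty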
+ """ + + __tablename__ = "pipelines" + + # Primary identification + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + urn: Mapped[str] = mapped_column(String(255), nullable=True, unique=True, default=generate_pipeline_urn) + name: Mapped[str] = mapped_column(String(500), nullable=False) + description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + # Status and lifecycle + status: Mapped[PipelineStatus] = mapped_column(String(50), nullable=False, default=PipelineStatus.CREATED) + + # Correlation for end-to-end tracing + correlation_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + + # Flexible metadata storage + metadata_: Mapped[Dict[str, Any]] = mapped_column( + "metadata", + MutableDict.as_mutable(JSONB), + nullable=False, + comment="Flexible metadata storage for pipeline-specific data", + server_default="{}", + ) + + # Timestamps + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now()) + started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + + # User tracking + created_by_user_id: Mapped[Optional[int]] = mapped_column( + Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True + ) + + # Version tracking + mavedb_version: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) + + # Relationships + job_runs: Mapped[List["JobRun"]] = relationship("JobRun", back_populates="pipeline", cascade="all, delete-orphan") + created_by_user: Mapped[Optional["User"]] = relationship("User", foreign_keys=[created_by_user_id]) + + # Indexes + __table_args__ = ( + Index("ix_pipelines_status", "status"), + Index("ix_pipelines_created_at", "created_at"), + Index("ix_pipelines_correlation_id", "correlation_id"), + Index("ix_pipelines_created_by_user_id", "created_by_user_id"), + CheckConstraint( + "status IN ('created', 'running', 'succeeded', 'failed', 'cancelled', 'paused', 'partial')", + name="ck_pipelines_status_valid", + ), + ) + + def __repr__(self) -> str: + return f"" diff --git a/src/mavedb/models/variant_annotation_status.py b/src/mavedb/models/variant_annotation_status.py new file mode 100644 index 00000000..3051b4d3 --- /dev/null +++ b/src/mavedb/models/variant_annotation_status.py @@ -0,0 +1,113 @@ +""" +SQLAlchemy models for variant annotation status. +""" + +from datetime import datetime +from typing import TYPE_CHECKING, Any, Dict, Optional + +from sqlalchemy import CheckConstraint, DateTime, ForeignKey, Index, Integer, String, Text, func +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.ext.mutable import MutableDict +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from mavedb.db.base import Base +from mavedb.models.enums.job_pipeline import AnnotationStatus + +if TYPE_CHECKING: + from mavedb.models.job_run import JobRun + from mavedb.models.variant import Variant + + +class VariantAnnotationStatus(Base): + """ + Tracks annotation status for individual variants. + + Allows us to see which variants failed annotation and why. + + NOTE: JSONB fields are automatically tracked as mutable objects in this class via MutableDict. + This tracker only works for top-level mutations. If you mutate nested objects, you must call + `flag_modified(instance, "metadata_")` to ensure changes are persisted. 
+ """ + + __tablename__ = "variant_annotation_status" + + # Primary key + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + + # Composite primary key + variant_id: Mapped[int] = mapped_column(Integer, ForeignKey("variants.id", ondelete="CASCADE"), primary_key=True) + annotation_type: Mapped[str] = mapped_column( + String(50), primary_key=True, comment="Type of annotation: vrs, clinvar, gnomad, etc." + ) + + # Source version + version: Mapped[Optional[str]] = mapped_column( + String(50), nullable=True, comment="Version of the annotation source used (if applicable)" + ) + + # Status tracking + status: Mapped[AnnotationStatus] = mapped_column(String(50), nullable=False, comment="success, failed, skipped") + + # Error information + error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + failure_category: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) + + # Success data (flexible JSONB for annotation results) + success_data: Mapped[Optional[Dict[str, Any]]] = mapped_column( + MutableDict.as_mutable(JSONB), nullable=True, comment="Annotation results when successful" + ) + + # Current flag + current: Mapped[bool] = mapped_column( + nullable=False, + server_default="true", + comment="Whether this is the current status for the variant and annotation type", + ) + + # Job tracking + job_run_id: Mapped[Optional[int]] = mapped_column( + Integer, ForeignKey("job_runs.id", ondelete="SET NULL"), nullable=True + ) + + # Timestamps + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now()) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() + ) + + # Relationships + variant: Mapped["Variant"] = relationship("Variant") + job_run: Mapped[Optional["JobRun"]] = relationship("JobRun") + + # Indexes + __table_args__ = ( + Index("ix_variant_annotation_status_variant_id", "variant_id"), + Index("ix_variant_annotation_status_annotation_type", "annotation_type"), + Index("ix_variant_annotation_status_status", "status"), + Index("ix_variant_annotation_status_job_run_id", "job_run_id"), + Index("ix_variant_annotation_status_created_at", "created_at"), + # Composite index for common queries + Index("ix_variant_annotation_variant_type_status", "variant_id", "annotation_type", "status"), + Index("ix_variant_annotation_type_status", "annotation_type", "status"), + Index("ix_variant_annotation_status_current", "current"), + Index("ix_variant_annotation_status_version", "version"), + Index( + "ix_variant_annotation_status_variant_type_version_current", + "variant_id", + "annotation_type", + "version", + "current", + ), + CheckConstraint( + "annotation_type IN ('vrs_mapping', 'clingen_allele_id', 'mapped_hgvs', 'variant_translation', 'gnomad_allele_frequency', 'clinvar_control', 'vep_functional_consequence', 'ldh_submission')", + name="ck_variant_annotation_type_valid", + ), + CheckConstraint( + "status IN ('success', 'failed', 'skipped')", + name="ck_variant_annotation_status_valid", + ), + ## Although un-enforced at the DB level, we should ensure only one 'current' record per (variant_id, annotation_type, version) + ) + + def __repr__(self) -> str: + return f"" diff --git a/src/mavedb/routers/score_sets.py b/src/mavedb/routers/score_sets.py index 959f9133..395baedf 100644 --- a/src/mavedb/routers/score_sets.py +++ b/src/mavedb/routers/score_sets.py @@ -1,3 +1,4 @@ +import io import json import logging import time @@ 
-20,6 +21,7 @@ from sqlalchemy.orm import Session, contains_eager from mavedb import deps +from mavedb.data_providers.services import CSV_UPLOAD_S3_BUCKET_NAME, s3_client from mavedb.lib.annotation.annotate import ( variant_functional_impact_statement, variant_pathogenicity_evidence, @@ -66,6 +68,7 @@ generate_experiment_urn, generate_score_set_urn, ) +from mavedb.lib.workflow.pipeline_factory import PipelineFactory from mavedb.models.clinical_control import ClinicalControl from mavedb.models.contributor import Contributor from mavedb.models.enums.processing_state import ProcessingState @@ -111,6 +114,7 @@ async def enqueue_variant_creation( new_score_columns_metadata: Optional[dict[str, DatasetColumnMetadata]] = None, new_count_columns_metadata: Optional[dict[str, DatasetColumnMetadata]] = None, worker: ArqRedis, + db: Session, ) -> None: assert item.dataset_columns is not None @@ -136,25 +140,67 @@ async def enqueue_variant_creation( variants_to_csv_rows(item.variants, columns=count_columns, namespaced=False) ).replace("NA", np.NaN) + scores_file_to_upload = existing_scores_df if new_scores_df is None else new_scores_df + counts_file_to_upload = existing_counts_df if new_counts_df is None else new_counts_df + + scores_file_key = None + counts_file_key = None + if scores_file_to_upload is not None or counts_file_to_upload is not None: + timestamp = date.today().isoformat() + unique_id = str(int(time.time() * 1000)) + user_id = user_data.user.id + score_set_id = item.id + + s3 = s3_client() + + if scores_file_to_upload is not None: + save_to_logging_context({"num_scores": len(scores_file_to_upload)}) + scores_file_key = f"{score_set_id}/{user_id}/{timestamp}-{unique_id}-scores.csv" + s3.upload_fileobj( + Fileobj=io.BytesIO(scores_file_to_upload.to_csv(index=False).encode("utf-8")), + Bucket=CSV_UPLOAD_S3_BUCKET_NAME, + Key=scores_file_key, + ) + + if counts_file_to_upload is not None: + save_to_logging_context({"num_counts": len(counts_file_to_upload)}) + counts_file_key = f"{score_set_id}/{user_id}/{timestamp}-{unique_id}-counts.csv" + s3.upload_fileobj( + Fileobj=io.BytesIO(counts_file_to_upload.to_csv(index=False).encode("utf-8")), + Bucket=CSV_UPLOAD_S3_BUCKET_NAME, + Key=counts_file_key, + ) + + pipeline_factory = PipelineFactory(session=db) + pipeline, pipeline_entrypoint = pipeline_factory.create_pipeline( + pipeline_name="validate_map_annotate_score_set", + creating_user=user_data.user, + pipeline_params={ + "correlation_id": correlation_id_for_context(), + "score_set_id": item.id, + "updater_id": user_data.user.id, + "scores_file_key": scores_file_key, + "counts_file_key": counts_file_key, + "score_columns_metadata": item.dataset_columns.get("score_columns_metadata") + if new_score_columns_metadata is None + else new_score_columns_metadata, + "count_columns_metadata": item.dataset_columns.get("count_columns_metadata") + if new_count_columns_metadata is None + else new_count_columns_metadata, + }, + ) + # Await the insertion of this job into the worker queue, not the job itself. # Uses provided score and counts dataframes and metadata files, or falls back to existing data on the score set if not provided. 
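+    # Only the pipeline's entrypoint job is enqueued here; once it finishes, the
+    # pipeline machinery (see with_pipeline_management) coordinates any dependent
+    # jobs in the pipeline.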
job = await worker.enqueue_job( - "create_variants_for_score_set", - correlation_id_for_context(), - item.id, - user_data.user.id, - existing_scores_df if new_scores_df is None else new_scores_df, - existing_counts_df if new_counts_df is None else new_counts_df, - item.dataset_columns.get("score_columns_metadata") - if new_score_columns_metadata is None - else new_score_columns_metadata, - item.dataset_columns.get("count_columns_metadata") - if new_count_columns_metadata is None - else new_count_columns_metadata, + pipeline_entrypoint.job_function, pipeline_entrypoint.id, _job_id=pipeline_entrypoint.urn ) if job is not None: save_to_logging_context({"worker_job_id": job.job_id}) - logger.info(msg="Enqueued variant creation job.", extra=logging_context()) + logger.info( + msg="Enqueued validate_map_annotate_score_set pipeline (job_id: {}).".format(job.job_id), + extra=logging_context(), + ) class ScoreSetUpdateResult(TypedDict): @@ -1747,6 +1793,7 @@ async def upload_score_set_variant_data( new_score_columns_metadata=dataset_column_metadata.get("score_columns_metadata", {}), new_count_columns_metadata=dataset_column_metadata.get("count_columns_metadata", {}), worker=worker, + db=db, ) db.add(item) @@ -1871,6 +1918,7 @@ async def update_score_set_with_variants( new_count_columns_metadata=dataset_column_metadata.get("count_columns_metadata") if did_count_columns_metadata_change else existing_count_columns_metadata, + db=db, ) db.add(updatedItem) @@ -1918,7 +1966,12 @@ async def update_score_set( updatedItem.processing_state = ProcessingState.processing logger.info(msg="Enqueuing variant creation job.", extra=logging_context()) - await enqueue_variant_creation(item=updatedItem, user_data=user_data, worker=worker) + await enqueue_variant_creation( + item=updatedItem, + user_data=user_data, + worker=worker, + db=db, + ) db.add(updatedItem) db.commit() diff --git a/src/mavedb/scripts/clingen_car_submission.py b/src/mavedb/scripts/clingen_car_submission.py index 29ea5fd8..492c6c3e 100644 --- a/src/mavedb/scripts/clingen_car_submission.py +++ b/src/mavedb/scripts/clingen_car_submission.py @@ -1,133 +1,72 @@ -import click +import datetime import logging from typing import Sequence + +import asyncclick as click from sqlalchemy import select -from sqlalchemy.orm import Session +from mavedb.db.session import SessionLocal +from mavedb.lib.workflow.job_factory import JobFactory from mavedb.models.score_set import ScoreSet -from mavedb.models.variant import Variant -from mavedb.models.mapped_variant import MappedVariant -from mavedb.scripts.environment import with_database_session -from mavedb.lib.clingen.services import ClinGenAlleleRegistryService, get_allele_registry_associations -from mavedb.lib.clingen.constants import CAR_SUBMISSION_ENDPOINT -from mavedb.lib.variants import get_hgvs_from_post_mapped +from mavedb.worker.jobs.external_services.clingen import submit_score_set_mappings_to_car +from mavedb.worker.jobs.registry import STANDALONE_JOB_DEFINITIONS +from mavedb.worker.settings.lifecycle import standalone_ctx logger = logging.getLogger(__name__) -def submit_urns_to_car(db: Session, urns: Sequence[str], debug: bool) -> list[str]: - if not CAR_SUBMISSION_ENDPOINT: - logger.error("`CAR_SUBMISSION_ENDPOINT` is not set. Please check your configuration.") - return [] - - car_service = ClinGenAlleleRegistryService(url=CAR_SUBMISSION_ENDPOINT) - submitted_entities = [] - - if debug: - logger.debug("Debug mode enabled. 
Submitting only one request to ClinGen CAR.") - urns = urns[:1] - - for idx, urn in enumerate(urns): - logger.info(f"Processing URN: {urn}. (Scoreset {idx + 1}/{len(urns)})") - try: - score_set = db.scalars(select(ScoreSet).where(ScoreSet.urn == urn)).one_or_none() - if not score_set: - logger.warning(f"No score set found for URN: {urn}") - continue - - logger.info(f"Submitting mapped variants to CAR service for score set with URN: {urn}") - variant_objects = db.execute( - select(Variant, MappedVariant) - .join(MappedVariant, MappedVariant.variant_id == Variant.id) - .join(ScoreSet) - .where(ScoreSet.urn == urn) - .where(MappedVariant.post_mapped.is_not(None)) - .where(MappedVariant.current.is_(True)) - ).all() - - if not variant_objects: - logger.warning(f"No mapped variants found for score set with URN: {urn}") - continue - - if debug: - logger.debug(f"Debug mode enabled. Submitting only one variant to ClinGen CAR for URN: {urn}") - variant_objects = variant_objects[:1] - - logger.debug(f"Preparing {len(variant_objects)} mapped variants for CAR submission") - hgvs_to_mapped_variant: dict[str, list[int]] = {} - for variant, mapped_variant in variant_objects: - hgvs = get_hgvs_from_post_mapped(mapped_variant.post_mapped) - if hgvs and hgvs not in hgvs_to_mapped_variant: - hgvs_to_mapped_variant[hgvs] = [mapped_variant.id] - elif hgvs and hgvs in hgvs_to_mapped_variant: - hgvs_to_mapped_variant[hgvs].append(mapped_variant.id) - else: - logger.warning(f"No HGVS string found for mapped variant {variant.urn}") - - if not hgvs_to_mapped_variant: - logger.warning(f"No HGVS strings to submit for URN: {urn}") - continue - - logger.info(f"Submitting {len(hgvs_to_mapped_variant)} HGVS strings to CAR service for URN: {urn}") - response = car_service.dispatch_submissions(list(hgvs_to_mapped_variant.keys())) - - if not response: - logger.error(f"CAR submission failed for URN: {urn}") - else: - logger.info(f"Successfully submitted to CAR for URN: {urn}") - # Associate CAIDs with mapped variants - associations = get_allele_registry_associations(list(hgvs_to_mapped_variant.keys()), response) - for hgvs, caid in associations.items(): - mapped_variant_ids = hgvs_to_mapped_variant.get(hgvs, []) - for mv_id in mapped_variant_ids: - mapped_variant = db.scalar(select(MappedVariant).where(MappedVariant.id == mv_id)) - if not mapped_variant: - logger.warning(f"Mapped variant with ID {mv_id} not found for HGVS {hgvs}.") - continue - - mapped_variant.clingen_allele_id = caid - db.add(mapped_variant) - - submitted_entities.extend([variant.urn for variant, _ in variant_objects]) - - except Exception as e: - logger.error(f"Error processing URN {urn}", exc_info=e) - - return submitted_entities - - @click.command() -@with_database_session @click.argument("urns", nargs=-1) @click.option("--all", help="Submit variants for every score set in MaveDB.", is_flag=True) -@click.option("--suppress-output", help="Suppress final print output to the console.", is_flag=True) -@click.option("--debug", help="Enable debug mode. This will send only one request at most to ClinGen CAR", is_flag=True) -def submit_car_urns_command( - db: Session, - urns: Sequence[str], - all: bool, - suppress_output: bool, - debug: bool, -) -> None: +async def main(urns: Sequence[str], all: bool) -> None: """ Submit data to ClinGen Allele Registry for mapped variant CAID generation for the given URNs. 
""" + db = SessionLocal() + if urns and all: logger.error("Cannot provide both URNs and --all option.") return if all: - urns = db.scalars(select(ScoreSet.urn)).all() # type: ignore - - if not urns: - logger.error("No URNs provided. Please provide at least one URN.") - return - - submitted_variant_urns = submit_urns_to_car(db, urns, debug) - - if not suppress_output: - print(", ".join(submitted_variant_urns)) + score_set_ids = db.scalars(select(ScoreSet.id)).all() + logger.info(f"Command invoked with --all. Routine will submit CAR data for {len(score_set_ids)} score sets.") + else: + score_set_ids = db.scalars(select(ScoreSet.id).where(ScoreSet.urn.in_(urns))).all() + logger.info(f"Submitting CAR data for the provided score sets ({len(score_set_ids)}).") + + # Unique correlation ID for this batch run + correlation_id = f"populate_mapped_variants_{datetime.datetime.now().isoformat()}" + + # Job definition for CAR submission + job_def = STANDALONE_JOB_DEFINITIONS[submit_score_set_mappings_to_car] + job_factory = JobFactory(db) + + # Use a standalone context for job execution outside of ARQ worker. + ctx = standalone_ctx() + ctx["db"] = db + + for score_set_id in score_set_ids: + logger.info(f"Submitting CAR data for score set ID {score_set_id}...") + + job_run = job_factory.create_job_run( + job_def=job_def, + pipeline_id=None, + correlation_id=correlation_id, + pipeline_params={ + "score_set_id": score_set_id, + "correlation_id": correlation_id, + }, + ) + db.add(job_run) + db.flush() + logger.info(f"Submitted job run ID {job_run.id} for score set ID {score_set_id}.") + + # Despite accepting a third argument for the job manager and MyPy expecting it, this + # argument will be injected automatically by the decorator. We only need to pass + # the ctx and job_run.id here for the decorator to generate the job manager. 
+        await submit_score_set_mappings_to_car(ctx, job_run.id)  # type: ignore
 
 
 if __name__ == "__main__":
-    submit_car_urns_command()
+    main()
diff --git a/src/mavedb/scripts/clingen_ldh_submission.py b/src/mavedb/scripts/clingen_ldh_submission.py
index 94f16520..17178287 100644
--- a/src/mavedb/scripts/clingen_ldh_submission.py
+++ b/src/mavedb/scripts/clingen_ldh_submission.py
@@ -1,19 +1,18 @@
-import click
+import datetime
 import logging
 import re
-from typing import Optional, Sequence
+from typing import Sequence
 
-from sqlalchemy import and_, select
+import asyncclick as click
+from sqlalchemy import select
 from sqlalchemy.orm import Session
 
+from mavedb.db.session import SessionLocal
+from mavedb.lib.workflow.job_factory import JobFactory
 from mavedb.models.score_set import ScoreSet
-from mavedb.models.variant import Variant
-from mavedb.models.mapped_variant import MappedVariant
-from mavedb.scripts.environment import with_database_session
-from mavedb.lib.clingen.services import ClinGenLdhService
-from mavedb.lib.clingen.constants import DEFAULT_LDH_SUBMISSION_BATCH_SIZE, LDH_SUBMISSION_ENDPOINT
-from mavedb.lib.clingen.content_constructors import construct_ldh_submission
-from mavedb.lib.variants import get_hgvs_from_post_mapped
+from mavedb.worker.jobs.external_services.clingen import submit_score_set_mappings_to_ldh
+from mavedb.worker.jobs.registry import STANDALONE_JOB_DEFINITIONS
+from mavedb.worker.settings.lifecycle import standalone_ctx
 
 logger = logging.getLogger(__name__)
 
@@ -21,177 +20,58 @@
 
 variant_with_reference_regex = re.compile(r":")
 
 
-def submit_urns_to_clingen(
-    db: Session, urns: Sequence[str], unlinked_only: bool, prefer_unmapped_hgvs: bool, debug: bool
-) -> list[str]:
-    ldh_service = ClinGenLdhService(url=LDH_SUBMISSION_ENDPOINT)
-    ldh_service.authenticate()
-
-    submitted_entities = []
-
-    if debug:
-        logger.debug("Debug mode enabled. Submitting only one request to ClinGen.")
-        urns = urns[:1]
-
-    for idx, urn in enumerate(urns):
-        logger.info(f"Processing URN: {urn}. (Scoreset {idx + 1}/{len(urns)})")
-
-        try:
-            score_set = db.scalars(select(ScoreSet).where(ScoreSet.urn == urn)).one_or_none()
-            if not score_set:
-                logger.warning(f"No score set found for URN: {urn}")
-                continue
-
-            logger.info(f"Submitting mapped variants to LDH service for score set with URN: {urn}")
-            mapped_variant_join_clause = and_(
-                MappedVariant.variant_id == Variant.id,
-                MappedVariant.post_mapped.is_not(None),
-                MappedVariant.current.is_(True),
-            )
-            variant_objects = db.execute(
-                select(Variant, MappedVariant)
-                .join(MappedVariant, mapped_variant_join_clause, isouter=True)
-                .join(ScoreSet)
-                .where(ScoreSet.urn == urn)
-            ).all()
-
-            if not variant_objects:
-                logger.warning(f"No mapped variants found for score set with URN: {urn}")
-                continue
-
-            logger.debug(f"Preparing {len(variant_objects)} mapped variants for submission")
-
-            variant_content: list[tuple[str, Variant, Optional[MappedVariant]]] = []
-            for variant, mapped_variant in variant_objects:
-                if mapped_variant is None:
-                    if variant.hgvs_nt is not None and intronic_variant_with_reference_regex.search(variant.hgvs_nt):
-                        # Use the hgvs_nt string for unmapped intronic variants. This is because our mapper does not yet
-                        # support mapping intronic variants.
- variation = variant.hgvs_nt - if variation: - logger.info(f"Using hgvs_nt for unmapped intronic variant {variant.urn}: {variation}") - elif variant.hgvs_nt is not None and variant_with_reference_regex.search(variant.hgvs_nt): - # Use the hgvs_nt string for other unmapped NT variants in accession-based score sets. - variation = variant.hgvs_nt - if variation: - logger.info(f"Using hgvs_nt for unmapped non-intronic variant {variant.urn}: {variation}") - elif variant.hgvs_pro is not None and variant_with_reference_regex.search(variant.hgvs_pro): - # Use the hgvs_pro string for unmapped PRO variants in accession-based score sets. - variation = variant.hgvs_pro - if variation: - logger.info(f"Using hgvs_pro for unmapped non-intronic variant {variant.urn}: {variation}") - else: - logger.warning( - f"No variation found for unmapped variant {variant.urn} (nt: {variant.hgvs_nt}, aa: {variant.hgvs_pro}, splice: {variant.hgvs_splice})." - ) - continue - else: - if unlinked_only and mapped_variant.clingen_allele_id: - continue - # If the script was run with the --prefer-unmapped-hgvs flag, use the hgvs_nt string rather than the - # mapped variant, as long as the variant is accession-based. - if ( - prefer_unmapped_hgvs - and variant.hgvs_nt is not None - and variant_with_reference_regex.search(variant.hgvs_nt) - ): - variation = variant.hgvs_nt - if variation: - logger.info(f"Using hgvs_nt for mapped variant {variant.urn}: {variation}") - elif ( - prefer_unmapped_hgvs - and variant.hgvs_pro is not None - and variant_with_reference_regex.search(variant.hgvs_pro) - ): - variation = variant.hgvs_pro - if variation: - logger.info( - f"Using hgvs_pro for mapped variant {variant.urn}: {variation}" - ) # continue # TEMPORARY. Only submit unmapped variants. - else: - variation = get_hgvs_from_post_mapped(mapped_variant) - if variation: - logger.info(f"Using mapped variant for {variant.urn}: {variation}") - - if not variation: - logger.warning( - f"No variation found for mapped variant {variant.urn} (nt: {variant.hgvs_nt}, aa: {variant.hgvs_pro}, splice: {variant.hgvs_splice})." - ) - continue - - variant_content.append((variation, variant, mapped_variant)) - - if debug: - logger.debug("Debug mode enabled. Submitting only one request to ClinGen.") - variant_content = variant_content[:1] - - logger.debug(f"Constructing LDH submission for {len(variant_content)} variants") - submission_content = construct_ldh_submission(variant_content) - submission_successes, submission_failures = ldh_service.dispatch_submissions( - submission_content, DEFAULT_LDH_SUBMISSION_BATCH_SIZE - ) - - if submission_failures: - logger.error(f"Failed to submit some variants for URN: {urn}") - else: - logger.info(f"Successfully submitted all variants for URN: {urn}") - - submitted_entities.extend([variant.urn for _, variant, _ in variant_content]) - - except Exception as e: - logger.error(f"Error processing URN {urn}", exc_info=e) - - # TODO#372: non-nullable urns. 
-    return submitted_entities  # type: ignore
-
-
 @click.command()
-@with_database_session
 @click.argument("urns", nargs=-1)
 @click.option("--all", help="Submit variants for every score set in MaveDB.", is_flag=True)
-@click.option(
-    "--unlinked",
-    default=False,
-    help="Only submit variants that have not already been linked to ClinGen alleles.",
-    is_flag=True,
-)
-@click.option(
-    "--prefer-unmapped-hgvs",
-    default=False,
-    help="If the unmapped HGVS string is accession-based, use it in the submission instead of the mapped variant.",
-    is_flag=True,
-)
-@click.option("--suppress-output", help="Suppress final print output to the console.", is_flag=True)
-@click.option("--debug", help="Enable debug mode. This will send only one request at most to ClinGen", is_flag=True)
-def submit_clingen_urns_command(
-    db: Session,
-    urns: Sequence[str],
-    all: bool,
-    unlinked: bool,
-    prefer_unmapped_hgvs: bool,
-    suppress_output: bool,
-    debug: bool,
-) -> None:
+async def main(urns: Sequence[str], all: bool) -> None:
     """
-    Submit data to ClinGen for mapped variant allele ID generation for the given URNs.
+    Submit data to ClinGen LDH for mapped variant allele ID generation for the given URNs.
     """
+    db = SessionLocal()
+
     if urns and all:
         logger.error("Cannot provide both URNs and --all option.")
         return
 
     if all:
-        # TODO#372: non-nullable urns.
-        urns = db.scalars(select(ScoreSet.urn)).all()  # type: ignore
-
-    if not urns:
-        logger.error("No URNs provided. Please provide at least one URN.")
-        return
-
-    submitted_variant_urns = submit_urns_to_clingen(db, urns, unlinked, prefer_unmapped_hgvs, debug)
-
-    if not suppress_output:
-        print(", ".join(submitted_variant_urns))
+        score_set_ids = db.scalars(select(ScoreSet.id)).all()
+        logger.info(f"Command invoked with --all. Routine will submit LDH data for {len(score_set_ids)} score sets.")
+    else:
+        score_set_ids = db.scalars(select(ScoreSet.id).where(ScoreSet.urn.in_(urns))).all()
+        logger.info(f"Submitting LDH data for the provided score sets ({len(score_set_ids)}).")
+
+    # Unique correlation ID for this batch run
+    correlation_id = f"clingen_ldh_submission_{datetime.datetime.now().isoformat()}"
+
+    # Job definition for LDH submission
+    job_def = STANDALONE_JOB_DEFINITIONS[submit_score_set_mappings_to_ldh]
+    job_factory = JobFactory(db)
+
+    # Use a standalone context for job execution outside of ARQ worker.
+    ctx = standalone_ctx()
+    ctx["db"] = db
+
+    for score_set_id in score_set_ids:
+        logger.info(f"Submitting LDH data for score set ID {score_set_id}...")
+
+        job_run = job_factory.create_job_run(
+            job_def=job_def,
+            pipeline_id=None,
+            correlation_id=correlation_id,
+            pipeline_params={
+                "score_set_id": score_set_id,
+                "correlation_id": correlation_id,
+            },
+        )
+        db.add(job_run)
+        db.flush()
+        logger.info(f"Submitted job run ID {job_run.id} for score set ID {score_set_id}.")
+
+        # Despite accepting a third argument for the job manager and MyPy expecting it, this
+        # argument will be injected automatically by the decorator. We only need to pass
+        # the ctx and job_run.id here for the decorator to generate the job manager.
+ await submit_score_set_mappings_to_ldh(ctx, job_run.id) # type: ignore if __name__ == "__main__": - submit_clingen_urns_command() + main() diff --git a/src/mavedb/scripts/environment.py b/src/mavedb/scripts/environment.py index 66bdbb78..831da7a4 100644 --- a/src/mavedb/scripts/environment.py +++ b/src/mavedb/scripts/environment.py @@ -4,16 +4,14 @@ import enum import logging -import click from functools import wraps - +import asyncclick as click from sqlalchemy.orm import configure_mappers from mavedb import deps from mavedb.models import * # noqa: F403 - logger = logging.getLogger(__name__) diff --git a/src/mavedb/scripts/link_clingen_variants.py b/src/mavedb/scripts/link_clingen_variants.py deleted file mode 100644 index 2ca3c069..00000000 --- a/src/mavedb/scripts/link_clingen_variants.py +++ /dev/null @@ -1,75 +0,0 @@ -import click -import logging -from typing import Sequence - -from sqlalchemy import and_, select -from sqlalchemy.orm import Session - -from mavedb.lib.clingen.services import get_clingen_variation, clingen_allele_id_from_ldh_variation -from mavedb.models.score_set import ScoreSet -from mavedb.models.variant import Variant -from mavedb.models.mapped_variant import MappedVariant -from mavedb.scripts.environment import with_database_session - -logger = logging.getLogger(__name__) - - -@click.command() -@with_database_session -@click.argument("urns", nargs=-1) -@click.option("--score-sets/--variants", default=False) -@click.option("--unlinked", default=False, is_flag=True) -def link_clingen_variants(db: Session, urns: Sequence[str], score_sets: bool, unlinked: bool) -> None: - """ - Submit data to ClinGen for mapped variant allele ID generation for the given URNs. - """ - if not urns: - logger.error("No URNs provided. Please provide at least one URN.") - return - - # Convert score set URNs to variant URNs. - if score_sets: - query = ( - select(Variant.urn) - .join(MappedVariant) - .join(ScoreSet) - .where(MappedVariant.current.is_(True), MappedVariant.post_mapped.is_not(None)) - ) - - if unlinked: - query = query.where(MappedVariant.clingen_allele_id.is_(None)) - - variants = [db.scalars(query.where(ScoreSet.urn == urn)).all() for urn in urns] - urns = [variant for sublist in variants for variant in sublist if variant is not None] - - failed_urns = [] - for urn in urns: - ldh_variation = get_clingen_variation(urn) - allele_id = clingen_allele_id_from_ldh_variation(ldh_variation) - - if not allele_id: - failed_urns.append(urn) - continue - - mapped_variant = db.scalar( - select(MappedVariant).join(Variant).where(and_(Variant.urn == urn, MappedVariant.current.is_(True))) - ) - - if not mapped_variant: - logger.warning(f"No mapped variant found for URN {urn}.") - failed_urns.append(urn) - continue - - mapped_variant.clingen_allele_id = allele_id - db.add(mapped_variant) - - logger.info(f"Successfully linked URN {urn} to ClinGen variation {allele_id}.") - - if failed_urns: - logger.warning(f"Failed to link the following {len(failed_urns)} URNs: {', '.join(failed_urns)}") - - logger.info(f"Linking process completed. 
Linked {len(urns) - len(failed_urns)}/{len(urns)} URNs successfully.") - - -if __name__ == "__main__": - link_clingen_variants() diff --git a/src/mavedb/scripts/link_gnomad_variants.py b/src/mavedb/scripts/link_gnomad_variants.py index e7f0fa49..af684683 100644 --- a/src/mavedb/scripts/link_gnomad_variants.py +++ b/src/mavedb/scripts/link_gnomad_variants.py @@ -1,80 +1,66 @@ +import datetime import logging -from typing import Sequence -import click -from sqlalchemy import select -from sqlalchemy.orm import Session +import asyncclick as click -from mavedb.lib.gnomad import gnomad_variant_data_for_caids, link_gnomad_variants_to_mapped_variants +from mavedb.db.session import SessionLocal +from mavedb.lib.workflow.job_factory import JobFactory from mavedb.models.score_set import ScoreSet -from mavedb.models.mapped_variant import MappedVariant -from mavedb.models.variant import Variant -from mavedb.scripts.environment import with_database_session - +from mavedb.worker.jobs.external_services.gnomad import link_gnomad_variants +from mavedb.worker.jobs.registry import STANDALONE_JOB_DEFINITIONS +from mavedb.worker.settings.lifecycle import standalone_ctx logger = logging.getLogger(__name__) @click.command() -@with_database_session -@click.option( - "--score-set-urn", multiple=True, type=str, help="Score set URN(s) to process. Can be used multiple times." -) +@click.argument("urns", nargs=-1) @click.option("--all", "all_score_sets", is_flag=True, help="Process all score sets in the database.", default=False) -@click.option("--only-current", is_flag=True, help="Only process current mapped variants.", default=True) -def link_gnomad_variants(db: Session, score_set_urn: list[str], all_score_sets: bool, only_current: bool) -> None: +async def main(urns: list[str], all_score_sets: bool) -> None: """ Query AWS Athena for gnomAD variants matching mapped variant CAIDs for one or more score sets. """ - # 1. Collect all CAIDs for mapped variants in the selected score sets + db = SessionLocal() + if all_score_sets: - score_sets = db.query(ScoreSet.id).all() - score_set_ids = [s.id for s in score_sets] + logger.info("Processing all score sets in the database.") + score_sets = db.query(ScoreSet).all() else: - if not score_set_urn: - logger.error("No score set URNs specified.") - return - - score_sets = db.query(ScoreSet.id).filter(ScoreSet.urn.in_(score_set_urn)).all() - score_set_ids = [s.id for s in score_sets] - if len(score_set_ids) != len(score_set_urn): - logger.warning("Some provided URNs were not found in the database.") - - if not score_set_ids: - logger.error("No score sets found.") - return - - caid_query = ( - select(MappedVariant.clingen_allele_id) - .join(Variant) - .where(Variant.score_set_id.in_(score_set_ids), MappedVariant.clingen_allele_id.is_not(None)) - ) - - if only_current: - caid_query = caid_query.where(MappedVariant.current.is_(True)) - - # We filter out Nonetype CAIDs to avoid issues with Athena queries, so we can type this as Sequence[str] and ignore MyPy warnings - caids: Sequence[str] = db.scalars(caid_query.distinct()).all() # type: ignore - if not caids: - logger.error("No CAIDs found for the selected score sets.") - return - - logger.info(f"Found {len(caids)} CAIDs for the selected score sets to link to gnomAD variants.") - - # 2. 
Query Athena for gnomAD variants matching the CAIDs
-    gnomad_variant_data = gnomad_variant_data_for_caids(caids)
-
-    if not gnomad_variant_data:
-        logger.error("No gnomAD records found for the provided CAIDs.")
-        return
-
-    logger.info(f"Fetched {len(gnomad_variant_data)} gnomAD records from Athena.")
-
-    # 3. Link gnomAD variants to mapped variants in the database
-    link_gnomad_variants_to_mapped_variants(db, gnomad_variant_data, only_current=only_current)
-
-    logger.info("Done linking gnomAD variants.")
+        logger.info(f"Processing score sets with URNs: {urns}")
+        score_sets = db.query(ScoreSet).filter(ScoreSet.urn.in_(urns)).all()
+
+    # Unique correlation ID for this batch run
+    correlation_id = f"link_gnomad_variants_{datetime.datetime.now().isoformat()}"
+
+    # Job definition for gnomAD linking
+    job_def = STANDALONE_JOB_DEFINITIONS[link_gnomad_variants]
+    job_factory = JobFactory(db)
+
+    # Use a standalone context for job execution outside of ARQ worker.
+    ctx = standalone_ctx()
+    ctx["db"] = db
+
+    for score_set in score_sets:
+        logger.info(f"Linking gnomAD variants for score set ID {score_set.id} (URN: {score_set.urn})...")
+
+        job_run = job_factory.create_job_run(
+            job_def=job_def,
+            pipeline_id=None,
+            correlation_id=correlation_id,
+            pipeline_params={
+                "score_set_id": score_set.id,
+                "correlation_id": correlation_id,
+            },
+        )
+        db.add(job_run)
+        db.flush()
+        logger.info(f"Submitted job run ID {job_run.id} for score set ID {score_set.id}.")
+
+        # Despite accepting a third argument for the job manager and MyPy expecting it, this
+        # argument will be injected automatically by the decorator. We only need to pass
+        # the ctx and job_run.id here for the decorator to generate the job manager.
+        await link_gnomad_variants(ctx, job_run.id)  # type: ignore
 
 
 if __name__ == "__main__":
-    link_gnomad_variants()
+    main()
diff --git a/src/mavedb/scripts/map_to_uniprot_id_from_mapped_metadata.py b/src/mavedb/scripts/map_to_uniprot_id_from_mapped_metadata.py
index c681babc..1e37b103 100644
--- a/src/mavedb/scripts/map_to_uniprot_id_from_mapped_metadata.py
+++ b/src/mavedb/scripts/map_to_uniprot_id_from_mapped_metadata.py
@@ -1,126 +1,129 @@
-import click
+import asyncio
+import datetime
 import logging
-from typing import Optional
 
-from sqlalchemy.orm import Session
+import asyncclick as click  # using asyncclick to allow async commands
 
-from mavedb.scripts.environment import with_database_session
+from mavedb.db.session import SessionLocal
+from mavedb.lib.workflow.job_factory import JobFactory
+from mavedb.models.enums.job_pipeline import JobStatus
 from mavedb.models.score_set import ScoreSet
-from mavedb.lib.uniprot.id_mapping import UniProtIDMappingAPI
-from mavedb.lib.uniprot.utils import infer_db_name_from_sequence_accession
-from mavedb.lib.mapping import extract_ids_from_post_mapped_metadata
-
-VALID_UNIPROT_DBS = [
-    "UniProtKB",
-    "UniProtKB_AC-ID",
-    "UniProtKB-Swiss-Prot",
-    "UniParc",
-    "UniRef50",
-    "UniRef90",
-    "UniRef100",
-]
+from mavedb.worker.jobs.external_services.uniprot import (
+    poll_uniprot_mapping_jobs_for_score_set,
+    submit_uniprot_mapping_jobs_for_score_set,
+)
+from mavedb.worker.jobs.registry import STANDALONE_JOB_DEFINITIONS
+from mavedb.worker.lib.managers.job_manager import JobManager
+from mavedb.worker.lib.managers.types import JobResultData
+from mavedb.worker.settings.lifecycle import standalone_ctx
 
 logger = logging.getLogger(__name__)
 
 
 @click.command()
-@with_database_session
-@click.option("--score-set-urn", type=str, default=None, help="Score set URN to
process. If not provided, process all.") +@click.argument("score_set_urn", type=str, required=True) @click.option("--polling-interval", type=int, default=30, help="Polling interval in seconds for checking job status.") @click.option("--polling-attempts", type=int, default=5, help="Number of tries to poll for job completion.") -@click.option("--to-db", type=str, default="UniProtKB", help="Target UniProt database for ID mapping.") -@click.option( - "--prefer-swiss-prot", is_flag=True, default=True, help="Prefer Swiss-Prot entries in the mapping results." -) @click.option( - "--refresh-mapped-identifier", + "--refresh", is_flag=True, default=False, help="Refresh the existing mapped identifier, if one exists.", ) -def main( - db: Session, - score_set_urn: Optional[str], +async def main( + score_set_urn: str, polling_interval: int, polling_attempts: int, - to_db: str, - prefer_swiss_prot: bool = True, - refresh_mapped_identifier: bool = False, + refresh: bool = False, ) -> None: - if to_db not in VALID_UNIPROT_DBS: - raise ValueError(f"Invalid target database: {to_db}. Must be one of {VALID_UNIPROT_DBS}.") + db = SessionLocal() + if score_set_urn: - score_sets = db.query(ScoreSet).filter(ScoreSet.urn == score_set_urn).all() - else: - score_sets = db.query(ScoreSet).all() - - api = UniProtIDMappingAPI(polling_interval=polling_interval, polling_tries=polling_attempts) - - logger.info(f"Processing {len(score_sets)} score sets.") - for score_set in score_sets: - logger.info(f"Processing score set: {score_set.urn}") - - if not score_set.target_genes: - logger.warning(f"No target gene for score set {score_set.urn}. Skipped mapping this score set.") - continue - - for target_gene in score_set.target_genes: - if target_gene.uniprot_id_from_mapped_metadata and not refresh_mapped_identifier: - logger.debug( - f"Target gene {target_gene.id} already has UniProt ID {target_gene.uniprot_id_from_mapped_metadata} and refresh_mapped_identifier is False. Skipped mapping this target." - ) - continue - - if not target_gene.post_mapped_metadata: - logger.warning( - f"No post-mapped metadata for target gene {target_gene.id}. Skipped mapping this target." - ) - continue - - ids = extract_ids_from_post_mapped_metadata(target_gene.post_mapped_metadata) # type: ignore - if not ids: - logger.warning( - f"No IDs found in post_mapped_metadata for target gene {target_gene.id}. Skipped mapping this target." - ) - continue - if len(ids) > 1: - logger.warning( - f"More than one accession ID found in post_mapped_metadata for target gene {target_gene.id}. Skipped mapping this target." - ) - continue - - id_to_map = ids[0] - from_db = infer_db_name_from_sequence_accession(id_to_map) - job_id = api.submit_id_mapping(from_db, to_db=to_db, ids=[id_to_map]) - - if not job_id: - logger.warning(f"Failed to submit job for target gene {target_gene.id}. Skipped mapping this target.") - continue - if not api.check_id_mapping_results_ready(job_id): - logger.warning(f"Job {job_id} not ready for target gene {target_gene.id}. Skipped mapping this target.") - continue - - results = api.get_id_mapping_results(job_id) - mapped_results = api.extract_uniprot_id_from_results(results, prefer_swiss_prot=prefer_swiss_prot) - - if not mapped_results: - logger.warning(f"No UniProt ID found for target gene {target_gene.id}. Skipped mapping this target.") - continue - if len(mapped_results) > 1: - logger.warning( - f"Could not unambiguously map target gene {target_gene.id}. Found multiple UniProt IDs ({len(mapped_results)})." 
-                )
-                continue
-
-            uniprot_id = mapped_results[0][id_to_map]["uniprot_id"]
-            target_gene.uniprot_id_from_mapped_metadata = uniprot_id
-            db.add(target_gene)
-
-            logger.info(f"Updated target gene {target_gene.id} with UniProt ID {uniprot_id}.")
-
-        logger.info(f"Processed score set {score_set.urn} with {len(score_set.target_genes)} target genes.")
-
-    logger.info(f"Done processing {len(score_sets)} score sets.")
+        score_set = db.query(ScoreSet).filter(ScoreSet.urn == score_set_urn).one()
+
+    score_set_id = score_set.id
+    if not refresh and any(tg.uniprot_id_from_mapped_metadata for tg in score_set.target_genes):
+        logger.info(f"Score set {score_set_urn} already has mapped UniProt IDs. Use --refresh to re-map.")
+        return
+
+    # Unique correlation ID for this batch run
+    correlation_id = f"uniprot_id_mapping_{datetime.datetime.now().isoformat()}"
+
+    # Job definitions
+    submission_def = STANDALONE_JOB_DEFINITIONS[submit_uniprot_mapping_jobs_for_score_set]
+    polling_def = STANDALONE_JOB_DEFINITIONS[poll_uniprot_mapping_jobs_for_score_set]
+    job_factory = JobFactory(db)
+
+    # Use a standalone context for job execution outside of ARQ worker.
+    ctx = standalone_ctx()
+    ctx["db"] = db
+
+    submission_run = job_factory.create_job_run(
+        job_def=submission_def,
+        pipeline_id=None,
+        correlation_id=correlation_id,
+        pipeline_params={
+            "score_set_id": score_set_id,
+            "correlation_id": correlation_id,
+        },
+    )
+    db.add(submission_run)
+    db.flush()
+
+    polling_run = job_factory.create_job_run(
+        job_def=polling_def,
+        pipeline_id=None,
+        correlation_id=correlation_id,
+        pipeline_params={
+            "score_set_id": score_set_id,
+            "correlation_id": correlation_id,
+            "mapping_jobs": {},  # Will be filled in by the submission job
+        },
+    )
+    db.add(polling_run)
+    db.flush()
+
+    # Dependencies are still valid outside of pipeline contexts, but we must invoke
+    # dependent jobs manually.
+    polling_dependency = job_factory.create_job_dependency(
+        parent_job_run_id=submission_run.id, child_job_run_id=polling_run.id
+    )
+    db.add(polling_dependency)
+    db.flush()
+
+    logger.info(
+        f"Submitted UniProt ID mapping submission job run ID {submission_run.id} for score set URN {score_set_urn}."
+    )
+
+    # Despite accepting a third argument for the job manager and MyPy expecting it, this
+    # argument will be injected automatically by the decorator. We only need to pass
+    # the ctx and job_run.id here for the decorator to generate the job manager.
+    await submit_uniprot_mapping_jobs_for_score_set(ctx, submission_run.id)  # type: ignore[call-arg]
+
+    job_manager = JobManager(db, None, polling_run.id)
+    for i in range(polling_attempts):
+        logger.info(
+            f"Submitted UniProt ID mapping polling job run ID {polling_run.id} for score set URN {score_set_urn}, attempt {i + 1}."
+        )
+
+        # Despite accepting a third argument for the job manager and MyPy expecting it, this
+        # argument will be injected automatically by the decorator. We only need to pass
+        # the ctx and job_run.id here for the decorator to generate the job manager.
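+        # The decorated call returns the job's result payload directly; refreshing
+        # the ORM object below picks up the status the JobManager recorded.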
+ polling_result: JobResultData = await poll_uniprot_mapping_jobs_for_score_set(ctx, polling_run.id) # type: ignore[call-arg] + db.refresh(polling_run) + + if polling_run.status == JobStatus.SUCCEEDED: + logger.info(f"Polling job for score set URN {score_set_urn} succeeded on attempt {i + 1}.") + break + + logger.info( + f"Polling job for score set URN {score_set_urn} failed on attempt {i + 1} with error: {polling_result.get('exception')}" + ) + db.refresh(polling_run) + job_manager.prepare_retry(f"Polling job failed. Attempting retry in {polling_interval} seconds.") + await asyncio.sleep(polling_interval) + + logger.info(f"Completed UniProt ID mapping for score set URN {score_set_urn}. Polling result : {polling_result}") if __name__ == "__main__": diff --git a/src/mavedb/scripts/populate_mapped_variants.py b/src/mavedb/scripts/populate_mapped_variants.py index de9eedbd..759026bf 100644 --- a/src/mavedb/scripts/populate_mapped_variants.py +++ b/src/mavedb/scripts/populate_mapped_variants.py @@ -1,178 +1,72 @@ +import datetime import logging -from datetime import date -from typing import Optional, Sequence, Union +from typing import Optional, Sequence -import click -from sqlalchemy import cast, select -from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.orm import Session +import asyncclick as click # using asyncclick to allow async commands +from sqlalchemy import select -from mavedb.data_providers.services import vrs_mapper -from mavedb.lib.exceptions import NonexistentMappingReferenceError -from mavedb.lib.logging.context import format_raised_exception_info_as_dict -from mavedb.lib.mapping import ANNOTATION_LAYERS -from mavedb.models.enums.mapping_state import MappingState -from mavedb.models.mapped_variant import MappedVariant +from mavedb.db.session import SessionLocal +from mavedb.lib.workflow.job_factory import JobFactory from mavedb.models.score_set import ScoreSet -from mavedb.models.variant import Variant -from mavedb.scripts.environment import script_environment, with_database_session +from mavedb.scripts.environment import script_environment +from mavedb.worker.jobs import STANDALONE_JOB_DEFINITIONS, map_variants_for_score_set +from mavedb.worker.settings.lifecycle import standalone_ctx logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -def variant_from_mapping(db: Session, mapping: dict, dcd_mapping_version: str) -> MappedVariant: - variant_urn = mapping.get("mavedb_id") - variant = db.scalars(select(Variant).where(Variant.urn == variant_urn)).one() - - return MappedVariant( - variant_id=variant.id, - pre_mapped=mapping.get("pre_mapped"), - post_mapped=mapping.get("post_mapped"), - modification_date=date.today(), - mapped_date=date.today(), # since this is a one-time script, assume mapping was done today - vrs_version=mapping.get("vrs_version"), - mapping_api_version=dcd_mapping_version, - error_message=mapping.get("error_message"), - current=True, - ) - - @script_environment.command() -@with_database_session @click.argument("urns", nargs=-1) @click.option("--all", help="Populate mapped variants for every score set in MaveDB.", is_flag=True) -def populate_mapped_variant_data(db: Session, urns: Sequence[Optional[str]], all: bool): +@click.option("--as-user-id", type=int, help="User ID to attribute as the updater of the mapped variants.") +async def populate_mapped_variant_data(urns: Sequence[Optional[str]], all: bool, as_user_id: Optional[int]): score_set_ids: Sequence[Optional[int]] + db = SessionLocal() + if all: score_set_ids = 
db.scalars(select(ScoreSet.id)).all() logger.info( - f"Command invoked with --all. Routine will populate mapped variant data for {len(urns)} score sets." + f"Command invoked with --all. Routine will populate mapped variant data for {len(score_set_ids)} score sets." ) else: score_set_ids = db.scalars(select(ScoreSet.id).where(ScoreSet.urn.in_(urns))).all() - logger.info(f"Populating mapped variant data for the provided score sets ({len(urns)}).") - - vrs = vrs_mapper() - - for idx, ss_id in enumerate(score_set_ids): - if not ss_id: - continue - - score_set = db.scalar(select(ScoreSet).where(ScoreSet.id == ss_id)) - if not score_set: - logger.warning(f"Could not fetch score set with id={ss_id}.") - continue - - try: - existing_mapped_variants = ( - db.query(MappedVariant) - .join(Variant) - .join(ScoreSet) - .filter(ScoreSet.id == ss_id, MappedVariant.current.is_(True)) - .all() - ) - - for variant in existing_mapped_variants: - variant.current = False - - assert score_set.urn - logger.info(f"Mapping score set {score_set.urn}.") - mapped_scoreset = vrs.map_score_set(score_set.urn) - logger.info(f"Done mapping score set {score_set.urn}.") - - dcd_mapping_version = mapped_scoreset["dcd_mapping_version"] - mapped_scores = mapped_scoreset.get("mapped_scores") - - if not mapped_scores: - # if there are no mapped scores, the score set failed to map. - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": mapped_scoreset.get("error_message")} - db.commit() - logger.info(f"No mapped variants available for {score_set.urn}.") - else: - reference_metadata = mapped_scoreset.get("reference_sequences") - if not reference_metadata: - raise NonexistentMappingReferenceError() - - for target_gene_identifier in reference_metadata: - target_gene = next( - ( - target_gene - for target_gene in score_set.target_genes - if target_gene.name == target_gene_identifier - ), - None, - ) - if not target_gene: - raise ValueError( - f"Target gene {target_gene_identifier} not found in database for score set {score_set.urn}." 
- ) - # allow for multiple annotation layers - pre_mapped_metadata = {} - post_mapped_metadata: dict[str, Union[Optional[str], dict[str, dict[str, str | list[str]]]]] = {} - excluded_pre_mapped_keys = {"sequence"} - - gene_info = reference_metadata[target_gene_identifier].get("gene_info") - if gene_info: - target_gene.mapped_hgnc_name = gene_info.get("hgnc_symbol") - post_mapped_metadata["hgnc_name_selection_method"] = gene_info.get("selection_method") - - for annotation_layer in reference_metadata[target_gene_identifier]["layers"]: - layer_premapped = reference_metadata[target_gene_identifier]["layers"][annotation_layer].get( - "computed_reference_sequence" - ) - if layer_premapped: - pre_mapped_metadata[ANNOTATION_LAYERS[annotation_layer]] = { - k: layer_premapped[k] - for k in set(list(layer_premapped.keys())) - excluded_pre_mapped_keys - } - layer_postmapped = reference_metadata[target_gene_identifier]["layers"][annotation_layer].get( - "mapped_reference_sequence" - ) - if layer_postmapped: - post_mapped_metadata[ANNOTATION_LAYERS[annotation_layer]] = layer_postmapped - target_gene.pre_mapped_metadata = cast(pre_mapped_metadata, JSONB) - target_gene.post_mapped_metadata = cast(post_mapped_metadata, JSONB) - - mapped_variants = [ - variant_from_mapping(db=db, mapping=mapped_score, dcd_mapping_version=dcd_mapping_version) - for mapped_score in mapped_scores - ] - logger.debug(f"Done constructing {len(mapped_variants)} mapped variant objects.") - - num_successful_variants = len( - [variant for variant in mapped_variants if variant.post_mapped is not None] - ) - logger.debug( - f"{num_successful_variants}/{len(mapped_variants)} variants generated a post-mapped VRS object." - ) - - if num_successful_variants == 0: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "All variants failed to map"} - elif num_successful_variants < len(mapped_variants): - score_set.mapping_state = MappingState.incomplete - else: - score_set.mapping_state = MappingState.complete - - db.bulk_save_objects(mapped_variants) - db.commit() - logger.info(f"Done populating {len(mapped_variants)} mapped variants for {score_set.urn}.") - - except Exception as e: - logging_context = { - "mapped_score_sets": urns[:idx], - "unmapped_score_sets": urns[idx:], - } - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error(f"Score set {score_set.urn} failed to map.", extra=logging_context) - logger.info(f"Rolling back all changes for scoreset {score_set.urn}") - db.rollback() - - logger.info(f"Done with score set {score_set.urn}. ({idx+1}/{len(urns)}).") - - logger.info("Done populating mapped variant data.") + logger.info(f"Populating mapped variant data for the provided score sets ({len(score_set_ids)}).") + + # Unique correlation ID for this batch run + correlation_id = f"populate_mapped_variants_{datetime.datetime.now().isoformat()}" + + # Job definition for mapping variants + job_def = STANDALONE_JOB_DEFINITIONS[map_variants_for_score_set] + job_factory = JobFactory(db) + + # Use a standalone context for job execution outside of ARQ worker. 
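+    # (Presumably the standalone context supplies everything an ARQ worker context
+    # would, minus the Redis connection; the database session is attached manually.)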
+ ctx = standalone_ctx() + ctx["db"] = db + + for score_set_id in score_set_ids: + logger.info(f"Populating mapped variant data for score set ID {score_set_id}...") + + job_run = job_factory.create_job_run( + job_def=job_def, + pipeline_id=None, + correlation_id=correlation_id, + pipeline_params={ + "score_set_id": score_set_id, + "updater_id": as_user_id + if as_user_id is not None + else 1, # Use provided user ID or default to System user + "correlation_id": correlation_id, + }, + ) + db.add(job_run) + db.flush() + logger.info(f"Submitted job run ID {job_run.id} for score set ID {score_set_id}.") + + # Despite accepting a third argument for the job manager and MyPy expecting it, this + # argument will be injected automatically by the decorator. We only need to pass + # the ctx and job_run.id here for the decorator to generate the job manager. + await map_variants_for_score_set(ctx, job_run.id) # type: ignore[call-arg] if __name__ == "__main__": diff --git a/src/mavedb/scripts/refresh_clinvar_variant_data.py b/src/mavedb/scripts/refresh_clinvar_variant_data.py index b043272c..5505aa15 100644 --- a/src/mavedb/scripts/refresh_clinvar_variant_data.py +++ b/src/mavedb/scripts/refresh_clinvar_variant_data.py @@ -1,172 +1,78 @@ -import click -from mavedb.models.score_set import ScoreSet -from mavedb.models.variant import Variant -import requests -import csv -import time +import datetime import logging -import gzip -import random -import io -import sys - -from typing import Dict, Any, Optional, Sequence -from datetime import date +from typing import Sequence -from sqlalchemy import and_, select, distinct -from sqlalchemy.orm import Session +import asyncclick as click +from sqlalchemy import select -from mavedb.models.mapped_variant import MappedVariant -from mavedb.models.clinical_control import ClinicalControl -from mavedb.scripts.environment import with_database_session +from mavedb.db.session import SessionLocal +from mavedb.lib.workflow.job_factory import JobFactory +from mavedb.models.score_set import ScoreSet +from mavedb.worker.jobs.external_services.clinvar import refresh_clinvar_controls +from mavedb.worker.jobs.registry import STANDALONE_JOB_DEFINITIONS +from mavedb.worker.settings.lifecycle import standalone_ctx logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - - -# Some older variant summary files have larger field sizes than the default CSV reader can handle. -csv.field_size_limit(sys.maxsize) - - -def fetch_clinvar_variant_summary_tsv(month: Optional[str], year: str) -> bytes: - if month is None and year is None: - url = "https://ftp.ncbi.nlm.nih.gov/pub/clinvar/tab_delimited/variant_summary.txt.gz" - else: - if int(year) <= 2023: - url = f"https://ftp.ncbi.nlm.nih.gov/pub/clinvar/tab_delimited/archive/{year}/variant_summary_{year}-{month}.txt.gz" - else: - url = ( - f"https://ftp.ncbi.nlm.nih.gov/pub/clinvar/tab_delimited/archive/variant_summary_{year}-{month}.txt.gz" - ) - - response = requests.get(url, stream=True) - response.raise_for_status() - return response.content - - -def parse_tsv(tsv_content: bytes) -> Dict[int, Dict[str, str]]: - with gzip.open(filename=io.BytesIO(tsv_content), mode="rt") as f: - # This readlines object will only be a list of bytes if the file is opened in "rb" mode. 
- reader = csv.DictReader(f.readlines(), delimiter="\t") # type: ignore - data = {int(row["#AlleleID"]): row for row in reader} - - return data - - -def query_clingen_allele_api(allele_id: str) -> Dict[str, Any]: - url = f"https://reg.clinicalgenome.org/allele/{allele_id}" - retries = 5 - for i in range(retries): - try: - response = requests.get(url) - response.raise_for_status() - break - except requests.RequestException as e: - if i < retries - 1: - wait_time = (2**i) + random.uniform(0, 1) - logger.warning(f"Request failed ({e}), retrying in {wait_time:.2f} seconds...") - time.sleep(wait_time) - else: - logger.error(f"Request failed after {retries} attempts: {e}") - raise - - logger.debug(f"Fetched ClinGen data for allele ID {allele_id}.") - return response.json() - -def refresh_clinvar_variants(db: Session, month: Optional[str], year: str, urns: Sequence[str]) -> None: - tsv_content = fetch_clinvar_variant_summary_tsv(month, year) - tsv_data = parse_tsv(tsv_content) - version = f"{month}_{year}" if month and year else f"{date.today().month}_{date.today().year}" - logger.info(f"Fetched TSV variant data for ClinVar for {version}.") - if urns: - clingen_ids = db.scalars( - select(distinct(MappedVariant.clingen_allele_id)) - .join(Variant) - .join(ScoreSet) - .where( - and_( - MappedVariant.clingen_allele_id.is_not(None), - MappedVariant.current.is_(True), - ScoreSet.urn.in_(urns), - ) - ) - ).all() +@click.command() +@click.argument("urns", nargs=-1) +@click.option("--all", help="Refresh ClinVar variant data for all score sets.", is_flag=True) +@click.option("--month", type=int, help="Month of the ClinVar data release to use (1-12).", required=True) +@click.option("--year", type=int, help="Year of the ClinVar data release to use (e.g., 2024).", required=True) +async def main(urns: Sequence[str], all: bool, month: int, year: int) -> None: + """ + Refresh ClinVar variant data for mapped variants in the given score sets. + """ + db = SessionLocal() + + if urns and all: + logger.error("Cannot provide both URNs and --all option.") + return + + if all: + score_set_ids = db.scalars(select(ScoreSet.id)).all() + logger.info( + f"Command invoked with --all. Routine will refresh ClinVar variant data for {len(score_set_ids)} score sets." + ) else: - clingen_ids = db.scalars( - select(distinct(MappedVariant.clingen_allele_id)).where(MappedVariant.clingen_allele_id.is_not(None)) - ).all() - total_variants_with_clingen_ids = len(clingen_ids) - - logger.info(f"Fetching ClinGen data for {total_variants_with_clingen_ids} variants.") - for index, clingen_id in enumerate(clingen_ids): - if total_variants_with_clingen_ids > 0 and index % (max(total_variants_with_clingen_ids // 100, 1)) == 0: - logger.info(f"Progress: {index / total_variants_with_clingen_ids:.0%}") - - if clingen_id is not None and "," in clingen_id: - logger.debug("Detected a multi-variant ClinGen allele ID, skipping.") - continue - - # Guaranteed based on our query filters. - clingen_data = query_clingen_allele_api(clingen_id) # type: ignore - clinvar_allele_id = clingen_data.get("externalRecords", {}).get("ClinVarAlleles", [{}])[0].get("alleleId") - - if not clinvar_allele_id or clinvar_allele_id not in tsv_data: - logger.debug( - f"No ClinVar variant data found for ClinGen allele ID {clingen_id}. ({index + 1}/{total_variants_with_clingen_ids})." 
-            )
-            continue
-
-        variant_data = tsv_data[clinvar_allele_id]
-        identifier = str(clinvar_allele_id)
-
-        clinvar_variant = db.scalars(
-            select(ClinicalControl).where(
-                ClinicalControl.db_identifier == identifier,
-                ClinicalControl.db_version == version,
-                ClinicalControl.db_name == "ClinVar",
-            )
-        ).one_or_none()
-        if clinvar_variant:
-            clinvar_variant.gene_symbol = variant_data.get("GeneSymbol")
-            clinvar_variant.clinical_significance = variant_data.get("ClinicalSignificance")
-            clinvar_variant.clinical_review_status = variant_data.get("ReviewStatus")
-        else:
-            clinvar_variant = ClinicalControl(
-                db_identifier=identifier,
-                gene_symbol=variant_data.get("GeneSymbol"),
-                clinical_significance=variant_data.get("ClinicalSignificance"),
-                clinical_review_status=variant_data.get("ReviewStatus"),
-                db_version=version,
-                db_name="ClinVar",
-            )
-
-        db.add(clinvar_variant)
-
-        variants_with_clingen_allele_id = db.scalars(
-            select(MappedVariant).where(MappedVariant.clingen_allele_id == clingen_id)
-        ).all()
-        for mapped_variant in variants_with_clingen_allele_id:
-            if clinvar_variant.id in [c.id for c in mapped_variant.clinical_controls]:
-                continue
-            mapped_variant.clinical_controls.append(clinvar_variant)
-            db.add(mapped_variant)
-
-        db.commit()
-        logger.debug(
-            f"Added ClinVar variant data ({identifier}) for ClinGen allele ID {clingen_id}. ({index + 1}/{total_variants_with_clingen_ids})."
+        score_set_ids = db.scalars(select(ScoreSet.id).where(ScoreSet.urn.in_(urns))).all()
+        logger.info(f"Refreshing ClinVar variant data for the provided score sets ({len(score_set_ids)}).")
+
+    # Unique correlation ID for this batch run
+    correlation_id = f"refresh_clinvar_controls_{datetime.datetime.now().isoformat()}"
+
+    # Job definition for ClinVar controls refresh
+    job_def = STANDALONE_JOB_DEFINITIONS[refresh_clinvar_controls]
+    job_factory = JobFactory(db)
+
+    # Use a standalone context for job execution outside of ARQ worker.
+    ctx = standalone_ctx()
+    ctx["db"] = db
+
+    for score_set_id in score_set_ids:
+        logger.info(f"Refreshing ClinVar variant data for score set ID {score_set_id}...")
+
+        job_run = job_factory.create_job_run(
+            job_def=job_def,
+            pipeline_id=None,
+            correlation_id=correlation_id,
+            pipeline_params={
+                "score_set_id": score_set_id,
+                "correlation_id": correlation_id,
+                "month": month,
+                "year": year,
+            },
         )
+        db.add(job_run)
+        db.flush()
+        logger.info(f"Submitted job run ID {job_run.id} for score set ID {score_set_id}.")
 
-
-@click.command()
-@with_database_session
-@click.argument("urns", nargs=-1)
-@click.option("--month", default=None, help="Populate mapped variants for every score set in MaveDB.")
-@click.option("--year", required=True, help="Populate mapped variants for every score set in MaveDB.")
-def refresh_clinvar_variants_command(db: Session, month: Optional[str], year: str, urns: Sequence[str]) -> None:
-    refresh_clinvar_variants(db, month, year, urns)
+        # Despite accepting a third argument for the job manager and MyPy expecting it, this
+        # argument will be injected automatically by the decorator. We only need to pass
+        # the ctx and job_run.id here for the decorator to generate the job manager.
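+        # The month/year params select which ClinVar release the job fetches,
+        # presumably mirroring the archive selection the old
+        # fetch_clinvar_variant_summary_tsv helper performed.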
+ await refresh_clinvar_controls(ctx, job_run.id) # type: ignore if __name__ == "__main__": - refresh_clinvar_variants_command() + main() diff --git a/src/mavedb/worker/README.md b/src/mavedb/worker/README.md new file mode 100644 index 00000000..45745205 --- /dev/null +++ b/src/mavedb/worker/README.md @@ -0,0 +1,12 @@ +# ARQ Worker Jobs Developer Documentation + +This documentation provides an overview and detailed guidance for developers working with the ARQ worker jobs, decorators, and managers in the MaveDB API codebase. It is organized into the following sections: + +- [Job System Overview](jobs_overview.md) +- [Job Decorators](job_decorators.md) +- [Job Managers](job_managers.md) +- [Pipeline Management](pipeline_management.md) +- [Job Registry and Configuration](job_registry.md) +- [Best Practices & Patterns](best_practices.md) + +Each section is a separate markdown file for clarity and maintainability. Start with `jobs_overview.md` for a high-level understanding, then refer to the other files for implementation details and usage patterns. diff --git a/src/mavedb/worker/best_practices.md b/src/mavedb/worker/best_practices.md new file mode 100644 index 00000000..65301284 --- /dev/null +++ b/src/mavedb/worker/best_practices.md @@ -0,0 +1,31 @@ +# Best Practices & Patterns + +## General Principles +- Use decorators to ensure all jobs are tracked, auditable, and robust to errors. +- Keep job functions focused and stateless; use the database and JobManager for state. +- Prefer async functions for jobs to maximize concurrency. +- Use the appropriate manager (JobManager or PipelineManager) for state transitions and coordination. +- Write unit tests for job logic and integration tests for job orchestration. + +## Error Handling +- Always handle exceptions at the job or pipeline boundary. Legacy score set and mapping jobs track status at the +item level, but this will be remedied in a future update. +- Use custom exception types for clarity and recovery strategies. +- Log all errors with sufficient context for debugging and audit. + +## Job Design +- Use `with_guaranteed_job_run_record` for standalone jobs that require audit. +- Use `with_pipeline_management` for jobs that are part of a pipeline. +- Avoid side effects outside the job context; use dependency injection for testability. + +## Testing +- Mock external services in unit tests. +- Use integration tests to verify job and pipeline orchestration. +- Test error paths and recovery logic. + +## Documentation +- Document each job's purpose, parameters, and expected side effects. +- Update the registry and README when adding new jobs. + +## References +- See the other markdown files in this directory for detailed usage and examples. diff --git a/src/mavedb/worker/job_decorators.md b/src/mavedb/worker/job_decorators.md new file mode 100644 index 00000000..c3511b07 --- /dev/null +++ b/src/mavedb/worker/job_decorators.md @@ -0,0 +1,48 @@ +# Job Decorators + +Job decorators provide lifecycle management, error handling, and audit guarantees for ARQ worker jobs. They are essential for ensuring that jobs are tracked, failures are handled robustly, and pipelines are coordinated correctly. + +## Key Decorators + +### `with_guaranteed_job_run_record(job_type)` +- Ensures a `JobRun` record is created and persisted before job execution begins. +- Should be applied before any job management decorators. +- Not supported for pipeline jobs. 
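+- Typically stacked with `with_job_management`, as in the example below.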
+- Example: + ```python + @with_guaranteed_job_run_record("cron_job") + @with_job_management + async def my_cron_job(ctx, ...): + ... + ``` + +### `with_job_management` +- Adds automatic job lifecycle management to ARQ worker functions. +- Tracks job start/completion, injects a `JobManager` for progress and state updates, and handles errors robustly. +- Supports both sync and async functions. +- Example: + ```python + @with_job_management + async def my_job(ctx, job_manager: JobManager): + job_manager.update_progress(10, message="Starting work") + ... + ``` + +### `with_pipeline_management` +- Adds pipeline lifecycle management to jobs that are part of a pipeline. +- Coordinates the pipeline after the job completes (success or failure). +- Built on top of `with_job_management`. +- Example: + ```python + @with_pipeline_management + async def my_pipeline_job(ctx, ...): + ... + ``` + +## Stacking Order +- If using both `with_guaranteed_job_run_record` and `with_job_management`, always apply `with_guaranteed_job_run_record` first. +- For pipeline jobs, use only `with_pipeline_management` (which includes job management). + +## See Also +- [Job Managers](job_managers.md) +- [Pipeline Management](pipeline_management.md) diff --git a/src/mavedb/worker/job_managers.md b/src/mavedb/worker/job_managers.md new file mode 100644 index 00000000..b099b4de --- /dev/null +++ b/src/mavedb/worker/job_managers.md @@ -0,0 +1,36 @@ +# Job Managers + +Job managers are responsible for the lifecycle, state transitions, and progress tracking of jobs and pipelines. They provide atomic operations, robust error handling, and ensure data consistency. + +## JobManager +- Manages the lifecycle of a single job (start, progress, success, failure, retry, cancel). +- Ensures atomic state transitions and safe rollback on failure. +- Does not commit database changes (only flushes); the caller is responsible for commits. +- Handles progress tracking, retry logic, and session cleanup. +- Example usage: + ```python + manager = JobManager(db, redis, job_id=123) + manager.start_job() + manager.update_progress(25, message="Starting validation") + manager.succeed_job(result={"count": 100}) + ``` + +## PipelineManager +- Coordinates pipeline execution, manages job dependencies, and updates pipeline status. +- Handles pausing, unpausing, and cancellation of pipelines. +- Uses the same exception hierarchy as JobManager for consistency. +- Example usage: + ```python + pipeline_manager = PipelineManager(db, redis, pipeline_id=456) + await pipeline_manager.coordinate_pipeline() + new_status = pipeline_manager.transition_pipeline_status() + cancelled_count = pipeline_manager.cancel_remaining_jobs(reason="Dependency failed") + ``` + +## Exception Handling +- Both managers use custom exceptions for database errors, state errors, and coordination errors. +- Always handle exceptions at the job or pipeline boundary to ensure robust recovery and logging. + +## See Also +- [Job Decorators](job_decorators.md) +- [Pipeline Management](pipeline_management.md) diff --git a/src/mavedb/worker/job_registry.md b/src/mavedb/worker/job_registry.md new file mode 100644 index 00000000..c470c1ed --- /dev/null +++ b/src/mavedb/worker/job_registry.md @@ -0,0 +1,39 @@ +# Job Registry and Configuration + +All ARQ worker jobs must be registered for execution and scheduling. The registry provides a centralized list of available jobs and cron jobs for ARQ configuration. + +## Job Registry +- Located in `jobs/registry.py`. 
+- Lists all job functions in `BACKGROUND_FUNCTIONS` for ARQ worker discovery. +- Defines scheduled (cron) jobs in `BACKGROUND_CRONJOBS` using ARQ's `cron` utility. + +## Example +```python +from mavedb.worker.jobs.data_management import refresh_materialized_views +from mavedb.worker.jobs.external_services import submit_score_set_mappings_to_car + +BACKGROUND_FUNCTIONS = [ + refresh_materialized_views, + submit_score_set_mappings_to_car, + ... +] + +BACKGROUND_CRONJOBS = [ + cron( + refresh_materialized_views, + name="refresh_all_materialized_views", + hour=20, + minute=0, + keep_result=timedelta(minutes=2).total_seconds(), + ), +] +``` + +## Adding a New Job +1. Implement the job function in the appropriate submodule. +2. Add the function to `BACKGROUND_FUNCTIONS` in `registry.py`. +3. (Optional) Add a cron job to `BACKGROUND_CRONJOBS` if scheduling is needed. + +## See Also +- [Job System Overview](jobs_overview.md) +- [Best Practices](best_practices.md) diff --git a/src/mavedb/worker/jobs.py b/src/mavedb/worker/jobs.py deleted file mode 100644 index 3a690d97..00000000 --- a/src/mavedb/worker/jobs.py +++ /dev/null @@ -1,1766 +0,0 @@ -import asyncio -import functools -import logging -from contextlib import asynccontextmanager -from datetime import date, timedelta -from typing import Any, Optional, Sequence - -import pandas as pd -from arq import ArqRedis -from arq.jobs import Job, JobStatus -from cdot.hgvs.dataproviders import RESTDataProvider -from sqlalchemy import cast, delete, null, select -from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.orm import Session - -from mavedb.data_providers.services import vrs_mapper -from mavedb.db.view import refresh_all_mat_views -from mavedb.lib.clingen.constants import ( - CAR_SUBMISSION_ENDPOINT, - CLIN_GEN_SUBMISSION_ENABLED, - DEFAULT_LDH_SUBMISSION_BATCH_SIZE, - LDH_SUBMISSION_ENDPOINT, - LINKED_DATA_RETRY_THRESHOLD, -) -from mavedb.lib.clingen.content_constructors import construct_ldh_submission -from mavedb.lib.clingen.services import ( - ClinGenAlleleRegistryService, - ClinGenLdhService, - clingen_allele_id_from_ldh_variation, - get_allele_registry_associations, - get_clingen_variation, -) -from mavedb.lib.exceptions import ( - LinkingEnqueueError, - MappingEnqueueError, - NonexistentMappingReferenceError, - NonexistentMappingResultsError, - SubmissionEnqueueError, - UniProtIDMappingEnqueueError, - UniProtPollingEnqueueError, -) -from mavedb.lib.gnomad import gnomad_variant_data_for_caids, link_gnomad_variants_to_mapped_variants -from mavedb.lib.logging.context import format_raised_exception_info_as_dict -from mavedb.lib.mapping import ANNOTATION_LAYERS, extract_ids_from_post_mapped_metadata -from mavedb.lib.score_sets import ( - columns_for_dataset, - create_variants, - create_variants_data, -) -from mavedb.lib.slack import log_and_send_slack_message, send_slack_error, send_slack_message -from mavedb.lib.uniprot.constants import UNIPROT_ID_MAPPING_ENABLED -from mavedb.lib.uniprot.id_mapping import UniProtIDMappingAPI -from mavedb.lib.uniprot.utils import infer_db_name_from_sequence_accession -from mavedb.lib.validation.dataframe.dataframe import ( - validate_and_standardize_dataframe_pair, -) -from mavedb.lib.validation.exceptions import ValidationError -from mavedb.lib.variants import get_hgvs_from_post_mapped -from mavedb.models.enums.mapping_state import MappingState -from mavedb.models.enums.processing_state import ProcessingState -from mavedb.models.mapped_variant import MappedVariant -from 
mavedb.models.published_variant import PublishedVariantsMV -from mavedb.models.score_set import ScoreSet -from mavedb.models.user import User -from mavedb.models.variant import Variant -from mavedb.view_models.score_set_dataset_columns import DatasetColumnMetadata - -logger = logging.getLogger(__name__) - -MAPPING_QUEUE_NAME = "vrs_mapping_queue" -MAPPING_CURRENT_ID_NAME = "vrs_mapping_current_job_id" -BACKOFF_LIMIT = 5 -MAPPING_BACKOFF_IN_SECONDS = 15 -LINKING_BACKOFF_IN_SECONDS = 15 * 60 - - -#################################################################################################### -# Job utilities -#################################################################################################### - - -def setup_job_state( - ctx, invoker: Optional[int], resource: Optional[str], correlation_id: Optional[str] -) -> dict[str, Any]: - ctx["state"][ctx["job_id"]] = { - "application": "mavedb-worker", - "user": invoker, - "resource": resource, - "correlation_id": correlation_id, - } - return ctx["state"][ctx["job_id"]] - - -async def enqueue_job_with_backoff( - redis: ArqRedis, job_name: str, attempt: int, backoff: int, *args -) -> tuple[Optional[str], bool, Any]: - new_job_id = None - limit_reached = attempt > BACKOFF_LIMIT - if not limit_reached: - limit_reached = True - backoff = backoff * (2**attempt) - attempt = attempt + 1 - - # NOTE: for jobs supporting backoff, `attempt` should be the final argument. - new_job = await redis.enqueue_job( - job_name, - *args, - attempt, - _defer_by=timedelta(seconds=backoff), - ) - - if new_job: - new_job_id = new_job.job_id - - return (new_job_id, not limit_reached, backoff) - - -#################################################################################################### -# Creating variants -#################################################################################################### - - -async def create_variants_for_score_set( - ctx, - correlation_id: str, - score_set_id: int, - updater_id: int, - scores: pd.DataFrame, - counts: pd.DataFrame, - score_columns_metadata: Optional[dict[str, DatasetColumnMetadata]] = None, - count_columns_metadata: Optional[dict[str, DatasetColumnMetadata]] = None, -): - """ - Create variants for a score set. Intended to be run within a worker. - On any raised exception, ensure ProcessingState of score set is set to `failed` prior - to exiting. 
- """ - logging_context = {} - try: - db: Session = ctx["db"] - hdp: RESTDataProvider = ctx["hdp"] - redis: ArqRedis = ctx["redis"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() - - logging_context = setup_job_state(ctx, updater_id, score_set.urn, correlation_id) - logger.info(msg="Began processing of score set variants.", extra=logging_context) - - updated_by = db.scalars(select(User).where(User.id == updater_id)).one() - - score_set.modified_by = updated_by - score_set.processing_state = ProcessingState.processing - score_set.mapping_state = MappingState.pending_variant_processing - logging_context["processing_state"] = score_set.processing_state.name - logging_context["mapping_state"] = score_set.mapping_state.name - - db.add(score_set) - db.commit() - db.refresh(score_set) - - if not score_set.target_genes: - logger.warning( - msg="No targets are associated with this score set; could not create variants.", - extra=logging_context, - ) - raise ValueError("Can't create variants when score set has no targets.") - - validated_scores, validated_counts, validated_score_columns_metadata, validated_count_columns_metadata = ( - validate_and_standardize_dataframe_pair( - scores_df=scores, - counts_df=counts, - score_columns_metadata=score_columns_metadata, - count_columns_metadata=count_columns_metadata, - targets=score_set.target_genes, - hdp=hdp, - ) - ) - - score_set.dataset_columns = { - "score_columns": columns_for_dataset(validated_scores), - "count_columns": columns_for_dataset(validated_counts), - "score_columns_metadata": validated_score_columns_metadata - if validated_score_columns_metadata is not None - else {}, - "count_columns_metadata": validated_count_columns_metadata - if validated_count_columns_metadata is not None - else {}, - } - - # Delete variants after validation occurs so we don't overwrite them in the case of a bad update. - if score_set.variants: - existing_variants = db.scalars(select(Variant.id).where(Variant.score_set_id == score_set.id)).all() - db.execute(delete(MappedVariant).where(MappedVariant.variant_id.in_(existing_variants))) - db.execute(delete(Variant).where(Variant.id.in_(existing_variants))) - logging_context["deleted_variants"] = score_set.num_variants - score_set.num_variants = 0 - - logger.info(msg="Deleted existing variants from score set.", extra=logging_context) - - db.flush() - db.refresh(score_set) - - variants_data = create_variants_data(validated_scores, validated_counts, None) - create_variants(db, score_set, variants_data) - - # Validation errors arise from problematic user data. These should be inserted into the database so failures can - # be persisted to them. - except ValidationError as e: - db.rollback() - score_set.processing_state = ProcessingState.failed - score_set.processing_errors = {"exception": str(e), "detail": e.triggering_exceptions} - score_set.mapping_state = MappingState.not_attempted - - if score_set.num_variants: - score_set.processing_errors["exception"] = ( - f"Update failed, variants were not updated. 
{score_set.processing_errors.get('exception', '')}" - ) - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logging_context["processing_state"] = score_set.processing_state.name - logging_context["mapping_state"] = score_set.mapping_state.name - logging_context["created_variants"] = 0 - logger.warning(msg="Encountered a validation error while processing variants.", extra=logging_context) - - return {"success": False} - - # NOTE: Since these are likely to be internal errors, it makes less sense to add them to the DB and surface them to the end user. - # Catch all non-system exiting exceptions. - except Exception as e: - db.rollback() - score_set.processing_state = ProcessingState.failed - score_set.processing_errors = {"exception": str(e), "detail": []} - score_set.mapping_state = MappingState.not_attempted - - if score_set.num_variants: - score_set.processing_errors["exception"] = ( - f"Update failed, variants were not updated. {score_set.processing_errors.get('exception', '')}" - ) - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logging_context["processing_state"] = score_set.processing_state.name - logging_context["mapping_state"] = score_set.mapping_state.name - logging_context["created_variants"] = 0 - logger.warning(msg="Encountered an internal exception while processing variants.", extra=logging_context) - - send_slack_error(err=e) - return {"success": False} - - # Catch all other exceptions. The exceptions caught here were intented to be system exiting. - except BaseException as e: - db.rollback() - score_set.processing_state = ProcessingState.failed - score_set.mapping_state = MappingState.not_attempted - db.commit() - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logging_context["processing_state"] = score_set.processing_state.name - logging_context["mapping_state"] = score_set.mapping_state.name - logging_context["created_variants"] = 0 - logger.error( - msg="Encountered an unhandled exception while creating variants for score set.", extra=logging_context - ) - - # Don't raise BaseExceptions so we may emit canonical logs (TODO: Perhaps they are so problematic we want to raise them anyway). 
- return {"success": False} - - else: - score_set.processing_state = ProcessingState.success - score_set.processing_errors = null() - - logging_context["created_variants"] = score_set.num_variants - logging_context["processing_state"] = score_set.processing_state.name - logger.info(msg="Finished creating variants in score set.", extra=logging_context) - - await redis.lpush(MAPPING_QUEUE_NAME, score_set.id) # type: ignore - await redis.enqueue_job("variant_mapper_manager", correlation_id, updater_id) - score_set.mapping_state = MappingState.queued - finally: - db.add(score_set) - db.commit() - db.refresh(score_set) - logger.info(msg="Committed new variants to score set.", extra=logging_context) - - ctx["state"][ctx["job_id"]] = logging_context.copy() - return {"success": True} - - -#################################################################################################### -# Mapping variants -#################################################################################################### - - -@asynccontextmanager -async def mapping_in_execution(redis: ArqRedis, job_id: str): - await redis.set(MAPPING_CURRENT_ID_NAME, job_id) - try: - yield - finally: - await redis.set(MAPPING_CURRENT_ID_NAME, "") - - -async def map_variants_for_score_set( - ctx: dict, correlation_id: str, score_set_id: int, updater_id: int, attempt: int = 1 -) -> dict: - async with mapping_in_execution(redis=ctx["redis"], job_id=ctx["job_id"]): - logging_context = {} - score_set = None - try: - db: Session = ctx["db"] - redis: ArqRedis = ctx["redis"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() - - logging_context = setup_job_state(ctx, updater_id, score_set.urn, correlation_id) - logging_context["attempt"] = attempt - logger.info(msg="Started variant mapping", extra=logging_context) - - score_set.mapping_state = MappingState.processing - score_set.mapping_errors = null() - db.add(score_set) - db.commit() - - mapping_urn = score_set.urn - assert mapping_urn, "A valid URN is needed to map this score set." - - logging_context["current_mapping_resource"] = mapping_urn - logging_context["mapping_state"] = score_set.mapping_state - logger.debug(msg="Fetched score set metadata for mapping job.", extra=logging_context) - - # Do not block Worker event loop during mapping, see: https://arq-docs.helpmanual.io/#synchronous-jobs. - vrs = vrs_mapper() - blocking = functools.partial(vrs.map_score_set, mapping_urn) - loop = asyncio.get_running_loop() - - except Exception as e: - send_slack_error(e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="Variant mapper encountered an unexpected error during setup. This job will not be retried.", - extra=logging_context, - ) - - db.rollback() - if score_set: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} - db.add(score_set) - db.commit() - - return {"success": False, "retried": False, "enqueued_jobs": []} - - mapping_results = None - try: - mapping_results = await loop.run_in_executor(ctx["pool"], blocking) - logger.debug(msg="Done mapping variants.", extra=logging_context) - - except Exception as e: - db.rollback() - score_set.mapping_errors = { - "error_message": f"Encountered an internal server error during mapping. Mapping will be automatically retried up to 5 times for this score set (attempt {attempt}/5)." 
- } - db.add(score_set) - db.commit() - - send_slack_error(e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.warning( - msg="Variant mapper encountered an unexpected error while mapping variants. This job will be retried.", - extra=logging_context, - ) - - new_job_id = None - max_retries_exceeded = None - try: - await redis.lpush(MAPPING_QUEUE_NAME, score_set.id) # type: ignore - new_job_id, max_retries_exceeded, backoff_time = await enqueue_job_with_backoff( - redis, "variant_mapper_manager", attempt, MAPPING_BACKOFF_IN_SECONDS, correlation_id, updater_id - ) - # If we fail to enqueue a mapping manager for this score set, evict it from the queue. - if new_job_id is None: - await redis.lpop(MAPPING_QUEUE_NAME, score_set.id) # type: ignore - - logging_context["backoff_limit_exceeded"] = max_retries_exceeded - logging_context["backoff_deferred_in_seconds"] = backoff_time - logging_context["backoff_job_id"] = new_job_id - - except Exception as backoff_e: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} - db.add(score_set) - db.commit() - send_slack_error(backoff_e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(backoff_e)} - logger.critical( - msg="While attempting to re-enqueue a mapping job that exited in error, another exception was encountered. This score set will not be mapped.", - extra=logging_context, - ) - else: - if new_job_id and not max_retries_exceeded: - score_set.mapping_state = MappingState.queued - db.add(score_set) - db.commit() - logger.info( - msg="After encountering an error while mapping variants, another mapping job was queued.", - extra=logging_context, - ) - elif new_job_id is None and not max_retries_exceeded: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} - db.add(score_set) - db.commit() - logger.error( - msg="After encountering an error while mapping variants, another mapping job was unable to be queued. This score set will not be mapped.", - extra=logging_context, - ) - else: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} - db.add(score_set) - db.commit() - logger.error( - msg="After encountering an error while mapping variants, the maximum retries for this job were exceeded. This score set will not be mapped.", - extra=logging_context, - ) - finally: - return { - "success": False, - "retried": (not max_retries_exceeded and new_job_id is not None), - "enqueued_jobs": [job for job in [new_job_id] if job], - } - - try: - if mapping_results: - mapped_scores = mapping_results.get("mapped_scores") - if not mapped_scores: - # if there are no mapped scores, the score set failed to map. - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": mapping_results.get("error_message")} - else: - reference_metadata = mapping_results.get("reference_sequences") - if not reference_metadata: - raise NonexistentMappingReferenceError() - - for target_gene_identifier in reference_metadata: - target_gene = next( - ( - target_gene - for target_gene in score_set.target_genes - if target_gene.name == target_gene_identifier - ), - None, - ) - if not target_gene: - raise ValueError( - f"Target gene {target_gene_identifier} not found in database for score set {score_set.urn}." 
- ) - # allow for multiple annotation layers - pre_mapped_metadata: dict[str, Any] = {} - post_mapped_metadata: dict[str, Any] = {} - excluded_pre_mapped_keys = {"sequence"} - - gene_info = reference_metadata[target_gene_identifier].get("gene_info") - if gene_info: - target_gene.mapped_hgnc_name = gene_info.get("hgnc_symbol") - post_mapped_metadata["hgnc_name_selection_method"] = gene_info.get("selection_method") - - for annotation_layer in reference_metadata[target_gene_identifier]["layers"]: - layer_premapped = reference_metadata[target_gene_identifier]["layers"][ - annotation_layer - ].get("computed_reference_sequence") - if layer_premapped: - pre_mapped_metadata[ANNOTATION_LAYERS[annotation_layer]] = { - k: layer_premapped[k] - for k in set(list(layer_premapped.keys())) - excluded_pre_mapped_keys - } - layer_postmapped = reference_metadata[target_gene_identifier]["layers"][ - annotation_layer - ].get("mapped_reference_sequence") - if layer_postmapped: - post_mapped_metadata[ANNOTATION_LAYERS[annotation_layer]] = layer_postmapped - target_gene.pre_mapped_metadata = cast(pre_mapped_metadata, JSONB) - target_gene.post_mapped_metadata = cast(post_mapped_metadata, JSONB) - - total_variants = 0 - successful_mapped_variants = 0 - for mapped_score in mapped_scores: - total_variants += 1 - variant_urn = mapped_score.get("mavedb_id") - variant = db.scalars(select(Variant).where(Variant.urn == variant_urn)).one() - - # there should only be one current mapped variant per variant id, so update old mapped variant to current = false - existing_mapped_variant = ( - db.query(MappedVariant) - .filter(MappedVariant.variant_id == variant.id, MappedVariant.current.is_(True)) - .one_or_none() - ) - - if existing_mapped_variant: - existing_mapped_variant.current = False - db.add(existing_mapped_variant) - - if mapped_score.get("pre_mapped") and mapped_score.get("post_mapped"): - successful_mapped_variants += 1 - - mapped_variant = MappedVariant( - pre_mapped=mapped_score.get("pre_mapped", null()), - post_mapped=mapped_score.get("post_mapped", null()), - variant_id=variant.id, - modification_date=date.today(), - mapped_date=mapping_results["mapped_date_utc"], - vrs_version=mapped_score.get("vrs_version", null()), - mapping_api_version=mapping_results["dcd_mapping_version"], - error_message=mapped_score.get("error_message", null()), - current=True, - ) - db.add(mapped_variant) - - if successful_mapped_variants == 0: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "All variants failed to map"} - elif successful_mapped_variants < total_variants: - score_set.mapping_state = MappingState.incomplete - else: - score_set.mapping_state = MappingState.complete - - logging_context["mapped_variants_inserted_db"] = len(mapped_scores) - logging_context["variants_successfully_mapped"] = successful_mapped_variants - logging_context["mapping_state"] = score_set.mapping_state.name - logging_context["mapping_errors"] = score_set.mapping_errors - logger.info(msg="Inserted mapped variants into db.", extra=logging_context) - - else: - raise NonexistentMappingResultsError() - - db.add(score_set) - db.commit() - - except Exception as e: - db.rollback() - score_set.mapping_errors = { - "error_message": f"Encountered an unexpected error while parsing mapped variants. Mapping will be automatically retried up to 5 times for this score set (attempt {attempt}/5)." 
- } - db.add(score_set) - db.commit() - - send_slack_error(e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.warning( - msg="An unexpected error occurred during variant mapping. This job will be attempted again.", - extra=logging_context, - ) - - new_job_id = None - max_retries_exceeded = None - try: - await redis.lpush(MAPPING_QUEUE_NAME, score_set.id) # type: ignore - new_job_id, max_retries_exceeded, backoff_time = await enqueue_job_with_backoff( - redis, "variant_mapper_manager", attempt, MAPPING_BACKOFF_IN_SECONDS, correlation_id, updater_id - ) - # If we fail to enqueue a mapping manager for this score set, evict it from the queue. - if new_job_id is None: - await redis.lpop(MAPPING_QUEUE_NAME, score_set.id) # type: ignore - - logging_context["backoff_limit_exceeded"] = max_retries_exceeded - logging_context["backoff_deferred_in_seconds"] = backoff_time - logging_context["backoff_job_id"] = new_job_id - - except Exception as backoff_e: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} - send_slack_error(backoff_e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(backoff_e)} - logger.critical( - msg="While attempting to re-enqueue a mapping job that exited in error, another exception was encountered. This score set will not be mapped.", - extra=logging_context, - ) - else: - if new_job_id and not max_retries_exceeded: - score_set.mapping_state = MappingState.queued - logger.info( - msg="After encountering an error while parsing mapped variants, another mapping job was queued.", - extra=logging_context, - ) - elif new_job_id is None and not max_retries_exceeded: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} - logger.error( - msg="After encountering an error while parsing mapped variants, another mapping job was unable to be queued. This score set will not be mapped.", - extra=logging_context, - ) - else: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} - logger.error( - msg="After encountering an error while parsing mapped variants, the maximum retries for this job were exceeded. This score set will not be mapped.", - extra=logging_context, - ) - finally: - db.add(score_set) - db.commit() - return { - "success": False, - "retried": (not max_retries_exceeded and new_job_id is not None), - "enqueued_jobs": [job for job in [new_job_id] if job], - } - - new_uniprot_job_id = None - try: - if UNIPROT_ID_MAPPING_ENABLED: - new_job = await redis.enqueue_job( - "submit_uniprot_mapping_jobs_for_score_set", - score_set.id, - correlation_id, - ) - - if new_job: - new_uniprot_job_id = new_job.job_id - - logging_context["submit_uniprot_mapping_job_id"] = new_uniprot_job_id - logger.info(msg="Queued a new UniProt mapping job.", extra=logging_context) - - else: - raise UniProtIDMappingEnqueueError() - else: - logger.warning( - msg="UniProt ID mapping is disabled, skipped submission of UniProt mapping jobs.", - extra=logging_context, - ) - - except Exception as e: - send_slack_error(e) - send_slack_message( - f"Could not enqueue UniProt mapping job for score set {score_set.urn}. UniProt mappings for this score set should be submitted manually." 
- ) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="Mapped variant UniProt submission encountered an unexpected error while attempting to enqueue a mapping job. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_jobs": [job for job in [new_uniprot_job_id] if job]} - - new_clingen_job_id = None - try: - if CLIN_GEN_SUBMISSION_ENABLED: - new_job = await redis.enqueue_job( - "submit_score_set_mappings_to_car", - correlation_id, - score_set.id, - ) - - if new_job: - new_clingen_job_id = new_job.job_id - - logging_context["submit_clingen_variants_job_id"] = new_clingen_job_id - logger.info(msg="Queued a new ClinGen submission job.", extra=logging_context) - - else: - raise SubmissionEnqueueError() - else: - logger.warning( - msg="ClinGen submission is disabled, skipped submission of mapped variants to CAR and LDH.", - extra=logging_context, - ) - - except Exception as e: - send_slack_error(e) - send_slack_message( - f"Could not submit mappings to CAR and/or LDH mappings for score set {score_set.urn}. Mappings for this score set should be submitted manually." - ) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="Mapped variant ClinGen submission encountered an unexpected error while attempting to enqueue a submission job. This job will not be retried.", - extra=logging_context, - ) - - return { - "success": False, - "retried": False, - "enqueued_jobs": [job for job in [new_uniprot_job_id, new_clingen_job_id] if job], - } - - ctx["state"][ctx["job_id"]] = logging_context.copy() - return { - "success": True, - "retried": False, - "enqueued_jobs": [job for job in [new_uniprot_job_id, new_clingen_job_id] if job], - } - - -async def variant_mapper_manager(ctx: dict, correlation_id: str, updater_id: int, attempt: int = 1) -> dict: - logging_context = {} - mapping_job_id = None - mapping_job_status = None - queued_score_set = None - try: - redis: ArqRedis = ctx["redis"] - db: Session = ctx["db"] - - logging_context = setup_job_state(ctx, updater_id, None, correlation_id) - logging_context["attempt"] = attempt - logger.debug(msg="Variant mapping manager began execution", extra=logging_context) - - queue_length = await redis.llen(MAPPING_QUEUE_NAME) # type: ignore - queued_id = await redis.rpop(MAPPING_QUEUE_NAME) # type: ignore - logging_context["variant_mapping_queue_length"] = queue_length - - # Setup the job id cache if it does not already exist. 
- if not await redis.exists(MAPPING_CURRENT_ID_NAME): - await redis.set(MAPPING_CURRENT_ID_NAME, "") - - if not queued_id: - logger.debug(msg="No mapping jobs exist in the queue.", extra=logging_context) - return {"success": True, "enqueued_job": None} - else: - queued_id = queued_id.decode("utf-8") - queued_score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == queued_id)).one() - - logging_context["upcoming_mapping_resource"] = queued_score_set.urn - logger.debug(msg="Found mapping job(s) still in queue.", extra=logging_context) - - mapping_job_id = await redis.get(MAPPING_CURRENT_ID_NAME) - if mapping_job_id: - mapping_job_id = mapping_job_id.decode("utf-8") - mapping_job_status = (await Job(job_id=mapping_job_id, redis=redis).status()).value - - logging_context["existing_mapping_job_status"] = mapping_job_status - logging_context["existing_mapping_job_id"] = mapping_job_id - - except Exception as e: - send_slack_error(e) - - # Attempt to remove this item from the mapping queue. - try: - await redis.lrem(MAPPING_QUEUE_NAME, 1, queued_id) # type: ignore - logger.warning(msg="Removed un-queueable score set from the queue.", extra=logging_context) - except Exception: - pass - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error(msg="Variant mapper manager encountered an unexpected error during setup.", extra=logging_context) - - return {"success": False, "enqueued_job": None} - - new_job = None - new_job_id = None - try: - if not mapping_job_id or mapping_job_status in (JobStatus.not_found, JobStatus.complete): - logger.debug(msg="No mapping jobs are running, queuing a new one.", extra=logging_context) - - new_job = await redis.enqueue_job( - "map_variants_for_score_set", correlation_id, queued_score_set.id, updater_id, attempt - ) - - if new_job: - new_job_id = new_job.job_id - - logging_context["new_mapping_job_id"] = new_job_id - logger.info(msg="Queued a new mapping job.", extra=logging_context) - - return {"success": True, "enqueued_job": new_job_id} - - logger.info( - msg="A mapping job is already running, or a new job was unable to be enqueued. Deferring mapping by 5 minutes.", - extra=logging_context, - ) - - new_job = await redis.enqueue_job( - "variant_mapper_manager", - correlation_id, - updater_id, - attempt, - _defer_by=timedelta(minutes=5), - ) - - if new_job: - # Ensure this score set remains in the front of the queue. - queued_id = await redis.rpush(MAPPING_QUEUE_NAME, queued_score_set.id) # type: ignore - new_job_id = new_job.job_id - - logging_context["new_mapping_manager_job_id"] = new_job_id - logger.info(msg="Deferred a new mapping manager job.", extra=logging_context) - - # Our persistent Redis queue and ARQ's execution rules ensure that even if the worker is stopped and not restarted - # before the deferred time, these deferred jobs will still run once able. - return {"success": True, "enqueued_job": new_job_id} - - raise MappingEnqueueError() - - except Exception as e: - send_slack_error(e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="Variant mapper manager encountered an unexpected error while enqueing a mapping job. This job will not be retried.", - extra=logging_context, - ) - - db.rollback() - - # We shouldn't rely on the passed score set id matching the score set we are operating upon. - if not queued_score_set: - return {"success": False, "enqueued_job": new_job_id} - - # Attempt to remove this item from the mapping queue. 
- try: - await redis.lrem(MAPPING_QUEUE_NAME, 1, queued_id) # type: ignore - logger.warning(msg="Removed un-queueable score set from the queue.", extra=logging_context) - except Exception: - pass - - score_set_exc = db.scalars(select(ScoreSet).where(ScoreSet.id == queued_score_set.id)).one_or_none() - if score_set_exc: - score_set_exc.mapping_state = MappingState.failed - score_set_exc.mapping_errors = "Unable to queue a new mapping job or defer score set mapping." - db.add(score_set_exc) - db.commit() - - return {"success": False, "enqueued_job": new_job_id} - - -#################################################################################################### -# Materialized Views -#################################################################################################### - - -# TODO#405: Refresh materialized views within an executor. -async def refresh_materialized_views(ctx: dict): - logging_context = setup_job_state(ctx, None, None, None) - logger.debug(msg="Began refresh materialized views.", extra=logging_context) - refresh_all_mat_views(ctx["db"]) - ctx["db"].commit() - logger.debug(msg="Done refreshing materialized views.", extra=logging_context) - return {"success": True} - - -async def refresh_published_variants_view(ctx: dict, correlation_id: str): - logging_context = setup_job_state(ctx, None, None, correlation_id) - logger.debug(msg="Began refresh of published variants materialized view.", extra=logging_context) - PublishedVariantsMV.refresh(ctx["db"]) - ctx["db"].commit() - logger.debug(msg="Done refreshing published variants materialized view.", extra=logging_context) - return {"success": True} - - -#################################################################################################### -# ClinGen resource creation / linkage -#################################################################################################### - - -async def submit_score_set_mappings_to_car(ctx: dict, correlation_id: str, score_set_id: int): - logging_context = {} - score_set = None - text = "Could not submit mappings to ClinGen Allele Registry for score set %s. Mappings for this score set should be submitted manually." - try: - db: Session = ctx["db"] - redis: ArqRedis = ctx["redis"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() - - logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) - logger.info(msg="Started CAR mapped resource submission", extra=logging_context) - - submission_urn = score_set.urn - assert submission_urn, "A valid URN is needed to submit CAR objects for this score set." - - logging_context["current_car_submission_resource"] = submission_urn - logger.debug(msg="Fetched score set metadata for CAR mapped resource submission.", extra=logging_context) - - except Exception as e: - send_slack_error(e) - if score_set: - send_slack_message(text=text % score_set.urn) - else: - send_slack_message(text=text % score_set_id) - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="CAR mapped resource submission encountered an unexpected error during setup. 
This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - variant_post_mapped_objects = db.execute( - select(MappedVariant.id, MappedVariant.post_mapped) - .join(Variant) - .join(ScoreSet) - .where(ScoreSet.urn == score_set.urn) - .where(MappedVariant.post_mapped.is_not(None)) - .where(MappedVariant.current.is_(True)) - ).all() - - if not variant_post_mapped_objects: - logger.warning( - msg="No current mapped variants with post mapped metadata were found for this score set. Skipping CAR submission.", - extra=logging_context, - ) - return {"success": True, "retried": False, "enqueued_job": None} - - variant_post_mapped_hgvs: dict[str, list[int]] = {} - for mapped_variant_id, post_mapped in variant_post_mapped_objects: - hgvs_for_post_mapped = get_hgvs_from_post_mapped(post_mapped) - - if not hgvs_for_post_mapped: - logger.warning( - msg=f"Could not construct a valid HGVS string for mapped variant {mapped_variant_id}. Skipping submission of this variant.", - extra=logging_context, - ) - continue - - if hgvs_for_post_mapped in variant_post_mapped_hgvs: - variant_post_mapped_hgvs[hgvs_for_post_mapped].append(mapped_variant_id) - else: - variant_post_mapped_hgvs[hgvs_for_post_mapped] = [mapped_variant_id] - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error while attempting to construct post mapped HGVS strings. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - if not CAR_SUBMISSION_ENDPOINT: - logger.warning( - msg="ClinGen Allele Registry submission is disabled (no submission endpoint), skipping submission of mapped variants to CAR.", - extra=logging_context, - ) - return {"success": False, "retried": False, "enqueued_job": None} - - car_service = ClinGenAlleleRegistryService(url=CAR_SUBMISSION_ENDPOINT) - registered_alleles = car_service.dispatch_submissions(list(variant_post_mapped_hgvs.keys())) - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error while attempting to authenticate to the LDH. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - linked_alleles = get_allele_registry_associations(list(variant_post_mapped_hgvs.keys()), registered_alleles) - for hgvs_string, caid in linked_alleles.items(): - mapped_variant_ids = variant_post_mapped_hgvs[hgvs_string] - mapped_variants = db.scalars(select(MappedVariant).where(MappedVariant.id.in_(mapped_variant_ids))).all() - - for mapped_variant in mapped_variants: - mapped_variant.clingen_allele_id = caid - db.add(mapped_variant) - - db.commit() - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error while attempting to authenticate to the LDH. 
This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - new_job_id = None - try: - new_job = await redis.enqueue_job( - "submit_score_set_mappings_to_ldh", - correlation_id, - score_set.id, - ) - - if new_job: - new_job_id = new_job.job_id - - logging_context["submit_clingen_ldh_variants_job_id"] = new_job_id - logger.info(msg="Queued a new ClinGen submission job.", extra=logging_context) - - else: - raise SubmissionEnqueueError() - - except Exception as e: - send_slack_error(e) - send_slack_message( - f"Could not submit mappings to LDH for score set {score_set.urn}. Mappings for this score set should be submitted manually." - ) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="Mapped variant ClinGen submission encountered an unexpected error while attempting to enqueue a submission job. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": new_job_id} - - ctx["state"][ctx["job_id"]] = logging_context.copy() - return {"success": True, "retried": False, "enqueued_job": new_job_id} - - -async def submit_score_set_mappings_to_ldh(ctx: dict, correlation_id: str, score_set_id: int): - logging_context = {} - score_set = None - text = ( - "Could not submit mappings to LDH for score set %s. Mappings for this score set should be submitted manually." - ) - try: - db: Session = ctx["db"] - redis: ArqRedis = ctx["redis"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() - - logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) - logger.info(msg="Started LDH mapped resource submission", extra=logging_context) - - submission_urn = score_set.urn - assert submission_urn, "A valid URN is needed to submit LDH objects for this score set." - - logging_context["current_ldh_submission_resource"] = submission_urn - logger.debug(msg="Fetched score set metadata for ldh mapped resource submission.", extra=logging_context) - - except Exception as e: - send_slack_error(e) - if score_set: - send_slack_message(text=text % score_set.urn) - else: - send_slack_message(text=text % score_set_id) - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error during setup. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - ldh_service = ClinGenLdhService(url=LDH_SUBMISSION_ENDPOINT) - ldh_service.authenticate() - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error while attempting to authenticate to the LDH. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - variant_objects = db.execute( - select(Variant, MappedVariant) - .join(MappedVariant) - .join(ScoreSet) - .where(ScoreSet.urn == score_set.urn) - .where(MappedVariant.post_mapped.is_not(None)) - .where(MappedVariant.current.is_(True)) - ).all() - - if not variant_objects: - logger.warning( - msg="No current mapped variants with post mapped metadata were found for this score set. 
Skipping LDH submission.", - extra=logging_context, - ) - return {"success": True, "retried": False, "enqueued_job": None} - - variant_content = [] - for variant, mapped_variant in variant_objects: - variation = get_hgvs_from_post_mapped(mapped_variant.post_mapped) - - if not variation: - logger.warning( - msg=f"Could not construct a valid HGVS string for mapped variant {mapped_variant.id}. Skipping submission of this variant.", - extra=logging_context, - ) - continue - - variant_content.append((variation, variant, mapped_variant)) - - submission_content = construct_ldh_submission(variant_content) - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error while attempting to construct submission objects. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - blocking = functools.partial( - ldh_service.dispatch_submissions, submission_content, DEFAULT_LDH_SUBMISSION_BATCH_SIZE - ) - loop = asyncio.get_running_loop() - submission_successes, submission_failures = await loop.run_in_executor(ctx["pool"], blocking) - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error while dispatching submissions. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - assert not submission_failures, f"{len(submission_failures)} submissions failed to be dispatched to the LDH." - logger.info(msg="Dispatched all variant mapping submissions to the LDH.", extra=logging_context) - except AssertionError as e: - send_slack_error(e) - send_slack_message( - text=f"{len(submission_failures)} submissions failed to be dispatched to the LDH for score set {score_set.urn}." - ) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission failed to submit all mapping resources. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - new_job_id = None - try: - new_job = await redis.enqueue_job( - "link_clingen_variants", - correlation_id, - score_set.id, - 1, - _defer_by=timedelta(seconds=LINKING_BACKOFF_IN_SECONDS), - ) - - if new_job: - new_job_id = new_job.job_id - - logging_context["link_clingen_variants_job_id"] = new_job_id - logger.info(msg="Queued a new ClinGen linking job.", extra=logging_context) - - else: - raise LinkingEnqueueError() - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error while attempting to enqueue a linking job. 
This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": new_job_id} - - return {"success": True, "retried": False, "enqueued_job": new_job_id} - - -def do_clingen_fetch(variant_urns): - return [(variant_urn, get_clingen_variation(variant_urn)) for variant_urn in variant_urns] - - -async def link_clingen_variants(ctx: dict, correlation_id: str, score_set_id: int, attempt: int) -> dict: - logging_context = {} - score_set = None - text = "Could not link mappings to LDH for score set %s. Mappings for this score set should be linked manually." - try: - db: Session = ctx["db"] - redis: ArqRedis = ctx["redis"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() - - logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) - logging_context["linkage_retry_threshold"] = LINKED_DATA_RETRY_THRESHOLD - logging_context["attempt"] = attempt - logging_context["max_attempts"] = BACKOFF_LIMIT - logger.info(msg="Started LDH mapped resource linkage", extra=logging_context) - - submission_urn = score_set.urn - assert submission_urn, "A valid URN is needed to link LDH objects for this score set." - - logging_context["current_ldh_linking_resource"] = submission_urn - logger.debug(msg="Fetched score set metadata for ldh mapped resource linkage.", extra=logging_context) - - except Exception as e: - send_slack_error(e) - if score_set: - send_slack_message(text=text % score_set.urn) - else: - send_slack_message(text=text % score_set_id) - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error during setup. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - variant_urns = db.scalars( - select(Variant.urn) - .join(MappedVariant) - .join(ScoreSet) - .where( - ScoreSet.urn == score_set.urn, MappedVariant.current.is_(True), MappedVariant.post_mapped.is_not(None) - ) - ).all() - num_variant_urns = len(variant_urns) - - logging_context["variants_to_link_ldh"] = num_variant_urns - - if not variant_urns: - logger.warning( - msg="No current mapped variants with post mapped metadata were found for this score set. Skipping LDH linkage (nothing to do). A gnomAD linkage job will not be enqueued, as no variants will have a CAID.", - extra=logging_context, - ) - - return {"success": True, "retried": False, "enqueued_job": None} - - logger.info( - msg="Found current mapped variants with post mapped metadata for this score set. Attempting to link them to LDH submissions.", - extra=logging_context, - ) - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error while attempting to build linkage urn list. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - logger.info(msg="Attempting to link mapped variants to LDH submissions.", extra=logging_context) - - # TODO#372: Non-nullable variant urns. 
- blocking = functools.partial( - do_clingen_fetch, - variant_urns, # type: ignore - ) - loop = asyncio.get_running_loop() - linked_data = await loop.run_in_executor(ctx["pool"], blocking) - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error while attempting to link LDH submissions. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - linked_allele_ids = [ - (variant_urn, clingen_allele_id_from_ldh_variation(clingen_variation)) - for variant_urn, clingen_variation in linked_data - ] - - linkage_failures = [] - for variant_urn, ldh_variation in linked_allele_ids: - # XXX: Should we unlink variation if it is not found? Does this constitute a failure? - if not ldh_variation: - logger.warning( - msg=f"Failed to link mapped variant {variant_urn} to LDH submission. No LDH variation found.", - extra=logging_context, - ) - linkage_failures.append(variant_urn) - continue - - mapped_variant = db.scalars( - select(MappedVariant).join(Variant).where(Variant.urn == variant_urn, MappedVariant.current.is_(True)) - ).one_or_none() - - if not mapped_variant: - logger.warning( - msg=f"Failed to link mapped variant {variant_urn} to LDH submission. No mapped variant found.", - extra=logging_context, - ) - linkage_failures.append(variant_urn) - continue - - mapped_variant.clingen_allele_id = ldh_variation - db.add(mapped_variant) - - db.commit() - - except Exception as e: - db.rollback() - - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error while attempting to link LDH submissions. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - num_linkage_failures = len(linkage_failures) - ratio_failed_linking = round(num_linkage_failures / num_variant_urns, 3) - logging_context["linkage_failure_rate"] = ratio_failed_linking - logging_context["linkage_failures"] = num_linkage_failures - logging_context["linkage_successes"] = num_variant_urns - num_linkage_failures - - assert ( - len(linked_allele_ids) == num_variant_urns - ), f"{num_variant_urns - len(linked_allele_ids)} appear to not have been attempted to be linked." - - job_succeeded = False - if not linkage_failures: - logger.info( - msg="Successfully linked all mapped variants to LDH submissions.", - extra=logging_context, - ) - - job_succeeded = True - - elif ratio_failed_linking < LINKED_DATA_RETRY_THRESHOLD: - logger.warning( - msg="Linkage failures exist, but did not exceed the retry threshold.", - extra=logging_context, - ) - send_slack_message( - text=f"Failed to link {len(linkage_failures)} mapped variants to LDH submissions for score set {score_set.urn}." - f"The retry threshold was not exceeded and this job will not be retried. URNs failed to link: {', '.join(linkage_failures)}." 
- ) - - job_succeeded = True - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error while attempting to finalize linkage. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - if job_succeeded: - gnomad_linking_job_id = None - try: - new_job = await redis.enqueue_job( - "link_gnomad_variants", - correlation_id, - score_set.id, - ) - - if new_job: - gnomad_linking_job_id = new_job.job_id - - logging_context["link_gnomad_variants_job_id"] = gnomad_linking_job_id - logger.info(msg="Queued a new gnomAD linking job.", extra=logging_context) - - else: - raise LinkingEnqueueError() - - except Exception as e: - job_succeeded = False - - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error while attempting to enqueue a gnomAD linking job. GnomAD variants should be linked manually for this score set. This job will not be retried.", - extra=logging_context, - ) - finally: - return {"success": job_succeeded, "retried": False, "enqueued_job": gnomad_linking_job_id} - - # If we reach this point, we should consider the job failed (there were failures which exceeded our retry threshold). - new_job_id = None - max_retries_exceeded = None - try: - new_job_id, max_retries_exceeded, backoff_time = await enqueue_job_with_backoff( - ctx["redis"], "variant_mapper_manager", attempt, LINKING_BACKOFF_IN_SECONDS, correlation_id - ) - - logging_context["backoff_limit_exceeded"] = max_retries_exceeded - logging_context["backoff_deferred_in_seconds"] = backoff_time - logging_context["backoff_job_id"] = new_job_id - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.critical( - msg="LDH mapped resource linkage encountered an unexpected error while attempting to retry a failed linkage job. This job will not be retried.", - extra=logging_context, - ) - else: - if new_job_id and not max_retries_exceeded: - logger.info( - msg="After a failure condition while linking mapped variants to LDH submissions, another linkage job was queued.", - extra=logging_context, - ) - send_slack_message( - text=f"Failed to link {len(linkage_failures)} ({ratio_failed_linking * 100}% of total mapped variants for {score_set.urn})." - f"This job was successfully retried. This was attempt {attempt}. Retry will occur in {backoff_time} seconds. URNs failed to link: {', '.join(linkage_failures)}." - ) - elif new_job_id is None and not max_retries_exceeded: - logger.error( - msg="After a failure condition while linking mapped variants to LDH submissions, another linkage job was unable to be queued.", - extra=logging_context, - ) - send_slack_message( - text=f"Failed to link {len(linkage_failures)} ({ratio_failed_linking} of total mapped variants for {score_set.urn})." - f"This job could not be retried due to an unexpected issue while attempting to enqueue another linkage job. This was attempt {attempt}. URNs failed to link: {', '.join(linkage_failures)}." 
- ) - else: - logger.error( - msg="After a failure condition while linking mapped variants to LDH submissions, the maximum retries for this job were exceeded. The reamining linkage failures will not be retried.", - extra=logging_context, - ) - send_slack_message( - text=f"Failed to link {len(linkage_failures)} ({ratio_failed_linking} of total mapped variants for {score_set.urn})." - f"The retry threshold was exceeded and this job will not be retried. URNs failed to link: {', '.join(linkage_failures)}." - ) - - finally: - return { - "success": False, - "retried": (not max_retries_exceeded and new_job_id is not None), - "enqueued_job": new_job_id, - } - - -######################################################################################################## -# Mapping between Mapped Metadata and UniProt IDs -######################################################################################################## - - -async def submit_uniprot_mapping_jobs_for_score_set(ctx, score_set_id: int, correlation_id: Optional[str] = None): - logging_context = {} - score_set = None - spawned_mapping_jobs: dict[int, Optional[str]] = {} - text = "Could not submit mapping jobs to UniProt for this score set %s. Mapping jobs for this score set should be submitted manually." - try: - db: Session = ctx["db"] - redis: ArqRedis = ctx["redis"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() - logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) - logger.info(msg="Started UniProt mapping job", extra=logging_context) - - if not score_set or not score_set.target_genes: - msg = f"No target genes for score set {score_set_id}. Skipped mapping targets to UniProt." - log_and_send_slack_message(msg=msg, ctx=logging_context, level=logging.WARNING) - - return {"success": True, "retried": False, "enqueued_jobs": []} - - except Exception as e: - send_slack_error(e) - if score_set: - msg = text % score_set.urn - else: - msg = text % score_set_id - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - log_and_send_slack_message(msg=msg, ctx=logging_context, level=logging.ERROR) - - return {"success": False, "retried": False, "enqueued_jobs": []} - - try: - uniprot_api = UniProtIDMappingAPI() - logging_context["total_target_genes_to_map_to_uniprot"] = len(score_set.target_genes) - for target_gene in score_set.target_genes: - spawned_mapping_jobs[target_gene.id] = None # type: ignore - - acs = extract_ids_from_post_mapped_metadata(target_gene.post_mapped_metadata) # type: ignore - if not acs: - msg = f"No accession IDs found in post_mapped_metadata for target gene {target_gene.id} in score set {score_set.urn}. This target will be skipped." - log_and_send_slack_message(msg, logging_context, logging.WARNING) - continue - - if len(acs) != 1: - msg = f"More than one accession ID is associated with target gene {target_gene.id} in score set {score_set.urn}. This target will be skipped." - log_and_send_slack_message(msg, logging_context, logging.WARNING) - continue - - ac_to_map = acs[0] - from_db = infer_db_name_from_sequence_accession(ac_to_map) - - try: - spawned_mapping_jobs[target_gene.id] = uniprot_api.submit_id_mapping(from_db, "UniProtKB", [ac_to_map]) # type: ignore - except Exception as e: - log_and_send_slack_message( - msg=f"Failed to submit UniProt mapping job for target gene {target_gene.id}: {e}. 
This target will be skipped.", - ctx=logging_context, - level=logging.WARNING, - ) - - except Exception as e: - send_slack_error(e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - log_and_send_slack_message( - msg=f"UniProt mapping job encountered an unexpected error while attempting to submit mapping jobs for score set {score_set.urn}. This job will not be retried.", - ctx=logging_context, - level=logging.ERROR, - ) - - return {"success": False, "retried": False, "enqueued_jobs": []} - - new_job_id = None - try: - successfully_spawned_mapping_jobs = sum(1 for job in spawned_mapping_jobs.values() if job is not None) - logging_context["successfully_spawned_mapping_jobs"] = successfully_spawned_mapping_jobs - - if not successfully_spawned_mapping_jobs: - msg = f"No UniProt mapping jobs were successfully spawned for score set {score_set.urn}. Skipped enqueuing polling job." - log_and_send_slack_message(msg, logging_context, logging.WARNING) - return {"success": True, "retried": False, "enqueued_jobs": []} - - new_job = await redis.enqueue_job( - "poll_uniprot_mapping_jobs_for_score_set", - spawned_mapping_jobs, - score_set_id, - correlation_id, - ) - - if new_job: - new_job_id = new_job.job_id - - logging_context["poll_uniprot_mapping_job_id"] = new_job_id - logger.info(msg="Enqueued polling jobs for UniProt mapping jobs.", extra=logging_context) - - else: - raise UniProtPollingEnqueueError() - - except Exception as e: - send_slack_error(e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - log_and_send_slack_message( - msg="UniProt mapping job encountered an unexpected error while attempting to enqueue polling jobs for mapping jobs. This job will not be retried.", - ctx=logging_context, - level=logging.ERROR, - ) - - return {"success": False, "retried": False, "enqueued_jobs": [job for job in [new_job_id] if job]} - - return {"success": True, "retried": False, "enqueued_jobs": [job for job in [new_job_id] if job]} - - -async def poll_uniprot_mapping_jobs_for_score_set( - ctx, mapping_jobs: dict[int, Optional[str]], score_set_id: int, correlation_id: Optional[str] = None -): - logging_context = {} - score_set = None - text = "Could not poll mapping jobs from UniProt for this Target %s. Mapping jobs for this score set should be submitted manually." - try: - db: Session = ctx["db"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() - logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) - logger.info(msg="Started UniProt polling job", extra=logging_context) - - if not score_set or not score_set.target_genes: - msg = f"No target genes for score set {score_set_id}. Skipped polling targets for UniProt mapping results." 
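The removed submission and polling pair above reflects UniProt's asynchronous ID-mapping workflow: a batch is submitted, a job ID comes back, and results are fetched only once the job reports ready. A minimal self-contained sketch of that request cycle against UniProt's public idmapping REST endpoints (the `map_accession` helper and its timeout policy are illustrative, not MaveDB code):

```python
import time

import requests

UNIPROT_IDMAPPING = "https://rest.uniprot.org/idmapping"


def map_accession(from_db: str, accession: str, timeout: float = 300.0) -> list[dict]:
    """Submit one accession for ID mapping and block until results arrive."""
    # Submission returns a job ID; results are computed asynchronously.
    submitted = requests.post(
        f"{UNIPROT_IDMAPPING}/run",
        data={"from": from_db, "to": "UniProtKB", "ids": accession},
    )
    submitted.raise_for_status()
    job_id = submitted.json()["jobId"]

    # Poll until the job stops reporting RUNNING (or the deadline passes).
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        status = requests.get(f"{UNIPROT_IDMAPPING}/status/{job_id}").json()
        if status.get("jobStatus", "FINISHED") != "RUNNING":
            break
        time.sleep(5)

    # Each result pairs a submitted ID with zero or more UniProtKB matches.
    return requests.get(f"{UNIPROT_IDMAPPING}/results/{job_id}").json().get("results", [])


# e.g. map_accession("RefSeq_Protein", "NP_000537.3")
```

Splitting submission and polling into separate worker jobs, as this module does, avoids holding a worker slot open for the entire UniProt turnaround.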
- log_and_send_slack_message(msg=msg, ctx=logging_context, level=logging.WARNING) - - return {"success": True, "retried": False, "enqueued_jobs": []} - - except Exception as e: - send_slack_error(e) - if score_set: - msg = text % score_set.urn - else: - msg = text % score_set_id - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - log_and_send_slack_message(msg=msg, ctx=logging_context, level=logging.ERROR) - - return {"success": False, "retried": False, "enqueued_jobs": []} - - try: - uniprot_api = UniProtIDMappingAPI() - for target_gene in score_set.target_genes: - acs = extract_ids_from_post_mapped_metadata(target_gene.post_mapped_metadata) # type: ignore - if not acs: - msg = f"No accession IDs found in post_mapped_metadata for target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target." - log_and_send_slack_message(msg, logging_context, logging.WARNING) - continue - - if len(acs) != 1: - msg = f"More than one accession ID is associated with target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target." - log_and_send_slack_message(msg, logging_context, logging.WARNING) - continue - - mapped_ac = acs[0] - job_id = mapping_jobs.get(target_gene.id) # type: ignore - - if not job_id: - msg = f"No job ID found for target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target." - # This issue has already been sent to Slack in the job submission function, so we just log it here. - logger.debug(msg=msg, extra=logging_context) - continue - - if not uniprot_api.check_id_mapping_results_ready(job_id): - msg = f"Job {job_id} not ready for target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target" - log_and_send_slack_message(msg, logging_context, logging.WARNING) - continue - - results = uniprot_api.get_id_mapping_results(job_id) - mapped_ids = uniprot_api.extract_uniprot_id_from_results(results) - - if not mapped_ids: - msg = f"No UniProt ID found for target gene {target_gene.id} in score set {score_set.urn}. Cannot add UniProt ID for this target." - log_and_send_slack_message(msg, logging_context, logging.WARNING) - continue - - if len(mapped_ids) != 1: - msg = f"Found ambiguous Uniprot ID mapping results for target gene {target_gene.id} in score set {score_set.urn}. Cannot add UniProt ID for this target." - log_and_send_slack_message(msg, logging_context, logging.WARNING) - continue - - mapped_uniprot_id = mapped_ids[0][mapped_ac]["uniprot_id"] - target_gene.uniprot_id_from_mapped_metadata = mapped_uniprot_id - db.add(target_gene) - logger.info( - msg=f"Updated target gene {target_gene.id} with UniProt ID {mapped_uniprot_id}", extra=logging_context - ) - - except Exception as e: - send_slack_error(e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - log_and_send_slack_message( - msg="UniProt mapping job encountered an unexpected error while attempting to poll mapping jobs. 
This job will not be retried.", - ctx=logging_context, - level=logging.ERROR, - ) - - return {"success": False, "retried": False, "enqueued_jobs": []} - - db.commit() - return {"success": True, "retried": False, "enqueued_jobs": []} - - -#################################################################################################### -# gnomAD Variant Linkage -#################################################################################################### - - -async def link_gnomad_variants(ctx: dict, correlation_id: str, score_set_id: int) -> dict: - logging_context = {} - score_set = None - text = "Could not link mappings to gnomAD variants for score set %s. Mappings for this score set should be linked manually." - try: - db: Session = ctx["db"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() - - logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) - logger.info(msg="Started gnomAD variant linkage", extra=logging_context) - - submission_urn = score_set.urn - assert submission_urn, "A valid URN is needed to link gnomAD objects for this score set." - - logging_context["current_gnomad_linking_resource"] = submission_urn - logger.debug(msg="Fetched score set metadata for gnomAD mapped resource linkage.", extra=logging_context) - - except Exception as e: - send_slack_error(e) - if score_set: - send_slack_message(text=text % score_set.urn) - else: - send_slack_message(text=text % score_set_id) - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error during setup. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - # We filter out mapped variants that do not have a CAID, so this query is typed # as a Sequence[str]. Ignore MyPy's type checking here. - variant_caids: Sequence[str] = db.scalars( - select(MappedVariant.clingen_allele_id) - .join(Variant) - .join(ScoreSet) - .where( - ScoreSet.urn == score_set.urn, - MappedVariant.current.is_(True), - MappedVariant.clingen_allele_id.is_not(None), - ) - ).all() # type: ignore - num_variant_caids = len(variant_caids) - - logging_context["num_variants_to_link_gnomad"] = num_variant_caids - - if not variant_caids: - logger.warning( - msg="No current mapped variants with CAIDs were found for this score set. Skipping gnomAD linkage (nothing to do).", - extra=logging_context, - ) - - return {"success": True, "retried": False, "enqueued_job": None} - - logger.info( - msg="Found current mapped variants with CAIDs for this score set. Attempting to link them to gnomAD variants.", - extra=logging_context, - ) - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="gnomAD mapped resource linkage encountered an unexpected error while attempting to build linkage urn list. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - gnomad_variant_data = gnomad_variant_data_for_caids(variant_caids) - num_gnomad_variants_with_caid_match = len(gnomad_variant_data) - logging_context["num_gnomad_variants_with_caid_match"] = num_gnomad_variants_with_caid_match - - if not gnomad_variant_data: - logger.warning( - msg="No gnomAD variants with CAID matches were found for this score set. 
Skipping gnomAD linkage (nothing to do).", - extra=logging_context, - ) - - return {"success": True, "retried": False, "enqueued_job": None} - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="gnomAD mapped resource linkage encountered an unexpected error while attempting to fetch gnomAD variant data from S3 via Athena. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - logger.info(msg="Attempting to link mapped variants to gnomAD variants.", extra=logging_context) - num_linked_gnomad_variants = link_gnomad_variants_to_mapped_variants(db, gnomad_variant_data) - db.commit() - logging_context["num_mapped_variants_linked_to_gnomad_variants"] = num_linked_gnomad_variants - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error while attempting to link LDH submissions. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - logger.info(msg="Done linking gnomAD variants to mapped variants.", extra=logging_context) - return {"success": True, "retried": False, "enqueued_job": None} diff --git a/src/mavedb/worker/jobs/__init__.py b/src/mavedb/worker/jobs/__init__.py new file mode 100644 index 00000000..e421bbad --- /dev/null +++ b/src/mavedb/worker/jobs/__init__.py @@ -0,0 +1,54 @@ +"""MaveDB Worker Job Functions. + +This package contains all worker job functions organized by domain: +- variant_processing: Variant creation and VRS mapping jobs +- external_services: Third-party service integration jobs (ClinGen, UniProt, gnomAD) +- data_management: Database and materialized view management jobs +- utils: Shared utilities for job state, retry logic, and constants + +All job functions are exported at the package level for easy import +by the worker settings and other modules. Additionally, a job registry +is provided for ARQ worker configuration. 
+""" + +from mavedb.worker.jobs.data_management.views import ( + refresh_materialized_views, + refresh_published_variants_view, +) +from mavedb.worker.jobs.external_services.clingen import ( + submit_score_set_mappings_to_car, + submit_score_set_mappings_to_ldh, +) +from mavedb.worker.jobs.external_services.gnomad import link_gnomad_variants +from mavedb.worker.jobs.external_services.uniprot import ( + poll_uniprot_mapping_jobs_for_score_set, + submit_uniprot_mapping_jobs_for_score_set, +) +from mavedb.worker.jobs.registry import ( + BACKGROUND_CRONJOBS, + BACKGROUND_FUNCTIONS, + STANDALONE_JOB_DEFINITIONS, +) +from mavedb.worker.jobs.variant_processing.creation import create_variants_for_score_set +from mavedb.worker.jobs.variant_processing.mapping import ( + map_variants_for_score_set, +) + +__all__ = [ + # Variant processing jobs + "create_variants_for_score_set", + "map_variants_for_score_set", + # External service integration jobs + "submit_score_set_mappings_to_car", + "submit_score_set_mappings_to_ldh", + "poll_uniprot_mapping_jobs_for_score_set", + "submit_uniprot_mapping_jobs_for_score_set", + "link_gnomad_variants", + # Data management jobs + "refresh_materialized_views", + "refresh_published_variants_view", + # Job registry and utilities + "BACKGROUND_FUNCTIONS", + "BACKGROUND_CRONJOBS", + "STANDALONE_JOB_DEFINITIONS", +] diff --git a/src/mavedb/worker/jobs/data_management/__init__.py b/src/mavedb/worker/jobs/data_management/__init__.py new file mode 100644 index 00000000..63502581 --- /dev/null +++ b/src/mavedb/worker/jobs/data_management/__init__.py @@ -0,0 +1,16 @@ +"""Data management job functions. + +This module exports jobs for database and view management: +- Materialized view refresh for optimized query performance +- Database maintenance and cleanup operations +""" + +from .views import ( + refresh_materialized_views, + refresh_published_variants_view, +) + +__all__ = [ + "refresh_materialized_views", + "refresh_published_variants_view", +] diff --git a/src/mavedb/worker/jobs/data_management/py.typed b/src/mavedb/worker/jobs/data_management/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/mavedb/worker/jobs/data_management/views.py b/src/mavedb/worker/jobs/data_management/views.py new file mode 100644 index 00000000..abf787c2 --- /dev/null +++ b/src/mavedb/worker/jobs/data_management/views.py @@ -0,0 +1,114 @@ +"""Database materialized view refresh jobs. + +This module contains jobs for refreshing materialized views used throughout +the MaveDB application. Materialized views provide optimized, pre-computed +data for complex queries and are refreshed periodically to maintain +data consistency and performance. +""" + +import logging + +from mavedb.db.view import refresh_all_mat_views +from mavedb.models.published_variant import PublishedVariantsMV +from mavedb.worker.jobs.utils.setup import validate_job_params +from mavedb.worker.lib.decorators.job_guarantee import with_guaranteed_job_run_record +from mavedb.worker.lib.decorators.job_management import with_job_management +from mavedb.worker.lib.decorators.pipeline_management import with_pipeline_management +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.types import JobResultData + +logger = logging.getLogger(__name__) + + +# TODO#405: Refresh materialized views within an executor. 
+@with_guaranteed_job_run_record("cron_job") +@with_job_management +async def refresh_materialized_views(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData: + """Refresh all materialized views in the database. + + This job refreshes all materialized views to ensure that they are up-to-date + with the latest data. It is typically run as a scheduled cron job and meant + to be invoked indirectly via a job queue system. + + Args: + ctx (dict): The job context dictionary. + job_id (int): The ID of the job run. + job_manager (JobManager): Manager for job lifecycle and DB operations. + + Side Effects: + - Refreshes all materialized views in the database. + + Returns: + dict: Result indicating success and any exception details + """ + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "refresh_materialized_views", + "resource": "all_materialized_views", + "correlation_id": None, + } + ) + job_manager.update_progress(0, 100, "Starting refresh of all materialized views.") + logger.debug(msg="Began refresh of all materialized views.", extra=job_manager.logging_context()) + + # Do refresh + refresh_all_mat_views(job_manager.db) + job_manager.db.flush() + + # Finalize job state + job_manager.update_progress(100, 100, "Completed refresh of all materialized views.") + logger.debug(msg="Done refreshing materialized views.", extra=job_manager.logging_context()) + + return {"status": "ok", "data": {}, "exception": None} + + +@with_pipeline_management +async def refresh_published_variants_view(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData: + """Refresh the published variants materialized view. + + This job refreshes the PublishedVariantsMV materialized view to ensure that it + is up-to-date with the latest data. It is meant to be invoked as part of a job queue system. + + Args: + ctx (dict): The job context dictionary. + job_id (int): The ID of the job run. + job_manager (JobManager): Manager for job lifecycle and DB operations. + + Side Effects: + - Refreshes the PublishedVariantsMV materialized view in the database. + + Returns: + dict: Result indicating success and any exception details + """ + # Get the job definition we are working on + job = job_manager.get_job() + + _job_required_params = ["correlation_id"] + validate_job_params(_job_required_params, job) + + # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. 
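`validate_job_params` comes from `mavedb.worker.jobs.utils.setup`, whose body is outside this hunk; what matters at these call sites is the contract that every name in the required list is present in `job.job_params` before the unchecked lookups that follow. An illustrative sketch of that contract (the error behavior is an assumption, not the real helper):

```python
from typing import Any, Optional


def validate_job_params(required: list[str], job: Any) -> None:
    """Fail fast if a JobRun lacks parameters its handler requires."""
    # job_params is a nullable JSONB column, so treat None as an empty dict.
    params: Optional[dict] = getattr(job, "job_params", None) or {}
    missing = [name for name in required if name not in params]
    if missing:
        raise ValueError(f"Job run {job.id} is missing required params: {', '.join(missing)}")
```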
+ correlation_id = job.job_params["correlation_id"] # type: ignore + + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "refresh_published_variants_view", + "resource": "published_variants_materialized_view", + "correlation_id": correlation_id, + } + ) + job_manager.update_progress(0, 100, "Starting refresh of published variants materialized view.") + logger.info(msg="Started refresh of published variants materialized view", extra=job_manager.logging_context()) + + # Do refresh + PublishedVariantsMV.refresh(job_manager.db) + job_manager.db.flush() + + # Finalize job state + job_manager.update_progress(100, 100, "Completed refresh of published variants materialized view.") + logger.debug(msg="Done refreshing published variants materialized view.", extra=job_manager.logging_context()) + + return {"status": "ok", "data": {}, "exception": None} diff --git a/src/mavedb/worker/jobs/external_services/__init__.py b/src/mavedb/worker/jobs/external_services/__init__.py new file mode 100644 index 00000000..eb88b7e9 --- /dev/null +++ b/src/mavedb/worker/jobs/external_services/__init__.py @@ -0,0 +1,28 @@ +"""External service integration job functions. + +This module exports jobs for integrating with third-party services: +- ClinGen (Clinical Genome Resource) for allele registration and data submission +- UniProt for protein sequence annotation and ID mapping +- gnomAD for population frequency and genomic context data +""" + +# External services job functions +from .clingen import ( + submit_score_set_mappings_to_car, + submit_score_set_mappings_to_ldh, +) +from .clinvar import refresh_clinvar_controls +from .gnomad import link_gnomad_variants +from .uniprot import ( + poll_uniprot_mapping_jobs_for_score_set, + submit_uniprot_mapping_jobs_for_score_set, +) + +__all__ = [ + "submit_score_set_mappings_to_car", + "submit_score_set_mappings_to_ldh", + "refresh_clinvar_controls", + "link_gnomad_variants", + "poll_uniprot_mapping_jobs_for_score_set", + "submit_uniprot_mapping_jobs_for_score_set", +] diff --git a/src/mavedb/worker/jobs/external_services/clingen.py b/src/mavedb/worker/jobs/external_services/clingen.py new file mode 100644 index 00000000..e67e4337 --- /dev/null +++ b/src/mavedb/worker/jobs/external_services/clingen.py @@ -0,0 +1,414 @@ +"""ClinGen integration jobs for variant submission and linking. + +This module contains jobs for submitting mapped variants to ClinGen services: +- ClinGen Allele Registry (CAR) for allele registration +- ClinGen Linked Data Hub (LDH) for data submission +- Variant linking and association management + +These jobs enable integration with the ClinGen ecosystem for clinical +variant interpretation and data sharing. 
+""" + +import asyncio +import functools +import logging + +from sqlalchemy import select + +from mavedb.lib.annotation_status_manager import AnnotationStatusManager +from mavedb.lib.clingen.constants import ( + CAR_SUBMISSION_ENDPOINT, + CLIN_GEN_SUBMISSION_ENABLED, + DEFAULT_LDH_SUBMISSION_BATCH_SIZE, + LDH_SUBMISSION_ENDPOINT, +) +from mavedb.lib.clingen.content_constructors import construct_ldh_submission +from mavedb.lib.clingen.services import ( + ClinGenAlleleRegistryService, + ClinGenLdhService, + get_allele_registry_associations, +) +from mavedb.lib.exceptions import LDHSubmissionFailureError +from mavedb.lib.variants import get_hgvs_from_post_mapped +from mavedb.models.enums.annotation_type import AnnotationType +from mavedb.models.enums.job_pipeline import AnnotationStatus +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.score_set import ScoreSet +from mavedb.models.variant import Variant +from mavedb.worker.jobs.utils.setup import validate_job_params +from mavedb.worker.lib.decorators.pipeline_management import with_pipeline_management +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.types import JobResultData + +logger = logging.getLogger(__name__) + + +@with_pipeline_management +async def submit_score_set_mappings_to_car(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData: + """ + Submit mapped variants for a score set to the ClinGen Allele Registry (CAR). + + This job registers mapped variants with CAR, assigns ClinGen Allele IDs (CAIDs), + and updates the database with the results. Progress is tracked throughout the submission. + + Required job_params in the JobRun: + - score_set_id (int): ID of the ScoreSet to process + - correlation_id (str): Correlation ID for tracking + + Args: + ctx (dict): Worker context containing DB and Redis connections + job_manager (JobManager): Manager for job lifecycle and DB operations + + Side Effects: + - Updates MappedVariant records with ClinGen Allele IDs + - Submits data to ClinGen Allele Registry + + Returns: + dict: Result indicating success and any exception details + """ + # Get the job definition we are working on + job = job_manager.get_job() + + _job_required_params = ["score_set_id", "correlation_id"] + validate_job_params(_job_required_params, job) + + # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. + score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one() # type: ignore + correlation_id = job.job_params["correlation_id"] # type: ignore + + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "submit_score_set_mappings_to_car", + "resource": score_set.urn, + "correlation_id": correlation_id, + } + ) + job_manager.update_progress(0, 100, "Starting CAR mapped resource submission.") + logger.info(msg="Started CAR mapped resource submission", extra=job_manager.logging_context()) + + # Ensure we've enabled ClinGen submission + if not CLIN_GEN_SUBMISSION_ENABLED: + job_manager.update_progress(100, 100, "ClinGen submission is disabled. 
Skipping CAR submission.") + logger.warning( + msg="ClinGen submission is disabled via configuration, skipping submission of mapped variants to CAR.", + extra=job_manager.logging_context(), + ) + return {"status": "skipped", "data": {}, "exception": None} + + # Check for CAR submission endpoint + if not CAR_SUBMISSION_ENDPOINT: + job_manager.update_progress(100, 100, "CAR submission endpoint not configured. Can't complete submission.") + logger.warning( + msg="ClinGen Allele Registry submission is disabled (no submission endpoint), unable to complete submission of mapped variants to CAR.", + extra=job_manager.logging_context(), + ) + return { + "status": "failed", + "data": {}, + "exception": ValueError("ClinGen Allele Registry submission endpoint is not configured."), + } + + # Fetch mapped variants with post-mapped data for the score set + variant_post_mapped_objects = job_manager.db.execute( + select(MappedVariant.id, MappedVariant.post_mapped) + .join(Variant) + .join(ScoreSet) + .where(ScoreSet.urn == score_set.urn) + .where(MappedVariant.post_mapped.is_not(None)) + .where(MappedVariant.current.is_(True)) + ).all() + + # Track total variants to submit + job_manager.save_to_context({"total_variants_to_submit_car": len(variant_post_mapped_objects)}) + if not variant_post_mapped_objects: + job_manager.update_progress(100, 100, "No mapped variants to submit to CAR. Skipped submission.") + logger.warning( + msg="No current mapped variants with post mapped metadata were found for this score set. Skipping CAR submission.", + extra=job_manager.logging_context(), + ) + return {"status": "ok", "data": {}, "exception": None} + + job_manager.update_progress( + 10, 100, f"Preparing {len(variant_post_mapped_objects)} mapped variants for CAR submission." + ) + + # Build HGVS strings for submission. Don't do duplicate submissions-- store mapped variant IDs by HGVS. + variant_post_mapped_hgvs: dict[str, list[int]] = {} + for mapped_variant_id, post_mapped in variant_post_mapped_objects: + hgvs_for_post_mapped = get_hgvs_from_post_mapped(post_mapped) + + if not hgvs_for_post_mapped: + logger.warning( + msg=f"Could not construct a valid HGVS string for mapped variant {mapped_variant_id}. 
Skipping submission of this variant.",
+                extra=job_manager.logging_context(),
+            )
+            continue
+
+        if hgvs_for_post_mapped in variant_post_mapped_hgvs:
+            variant_post_mapped_hgvs[hgvs_for_post_mapped].append(mapped_variant_id)
+        else:
+            variant_post_mapped_hgvs[hgvs_for_post_mapped] = [mapped_variant_id]
+
+    job_manager.save_to_context({"unique_variants_to_submit_car": len(variant_post_mapped_hgvs)})
+    job_manager.update_progress(15, 100, "Submitting mapped variants to CAR.")
+
+    # Do submission
+    car_service = ClinGenAlleleRegistryService(url=CAR_SUBMISSION_ENDPOINT)
+    registered_alleles = car_service.dispatch_submissions(list(variant_post_mapped_hgvs.keys()))
+    job_manager.update_progress(60, 100, "Processing registered alleles from CAR.")
+
+    # Process registered alleles and update mapped variants
+    linked_alleles = get_allele_registry_associations(list(variant_post_mapped_hgvs.keys()), registered_alleles)
+    total = len(linked_alleles)
+    processed = 0
+    # Setup annotation manager
+    annotation_manager = AnnotationStatusManager(job_manager.db)
+    registered_mapped_variant_ids = []
+    for hgvs_string, caid in linked_alleles.items():
+        mapped_variant_ids = variant_post_mapped_hgvs[hgvs_string]
+        registered_mapped_variant_ids.extend(mapped_variant_ids)
+        mapped_variants = job_manager.db.scalars(
+            select(MappedVariant).where(MappedVariant.id.in_(mapped_variant_ids))
+        ).all()
+
+        for mapped_variant in mapped_variants:
+            mapped_variant.clingen_allele_id = caid
+            job_manager.db.add(mapped_variant)
+
+            annotation_manager.add_annotation(
+                variant_id=mapped_variant.variant_id,  # type: ignore
+                annotation_type=AnnotationType.CLINGEN_ALLELE_ID,
+                version=None,
+                status=AnnotationStatus.SUCCESS,
+                annotation_data={
+                    "success_data": {"clingen_allele_id": caid},
+                },
+                current=True,
+            )
+
+        processed += 1
+
+        # Calculate progress: 50% + (processed/total)*45, rounded to the nearest 5%
+        if processed % 20 == 0 or processed == total:
+            progress = 50 + round((processed / total) * 45 / 5) * 5
+            job_manager.update_progress(progress, 100, f"Processed {processed} of {total} registered alleles.")
+
+    # For mapped variants that did not receive a CAID, record a failure annotation
+    failed_submissions = set(obj[0] for obj in variant_post_mapped_objects) - set(registered_mapped_variant_ids)
+    for mapped_variant_id in failed_submissions:
+        mapped_variant = job_manager.db.scalars(
+            select(MappedVariant).where(MappedVariant.id == mapped_variant_id)
+        ).one()
+
+        annotation_manager.add_annotation(
+            variant_id=mapped_variant.variant_id,  # type: ignore
+            annotation_type=AnnotationType.CLINGEN_ALLELE_ID,
+            version=None,
+            status=AnnotationStatus.FAILED,
+            annotation_data={
+                "error_message": "Failed to register variant with ClinGen Allele Registry.",
+            },
+            current=True,
+        )
+
+    # Finalize progress
+    job_manager.update_progress(100, 100, "Completed CAR mapped resource submission.")
+    job_manager.db.flush()
+    logger.info(msg="Completed CAR mapped resource submission", extra=job_manager.logging_context())
+    return {"status": "ok", "data": {}, "exception": None}
+
+
+@with_pipeline_management
+async def submit_score_set_mappings_to_ldh(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData:
+    """
+    Submit mapped variants for a score set to the ClinGen Linked Data Hub (LDH).
+
+    This job submits mapped variant data to LDH for a given score set, handling authentication,
+    submission batching, and error reporting. Progress and errors are logged and reported to Slack.
+ + Required job_params in the JobRun: + - score_set_id (int): ID of the ScoreSet to process + - correlation_id (str): Correlation ID for tracking + + Args: + ctx (dict): Worker context containing DB and Redis connections + job_manager (JobManager): Manager for job lifecycle and DB operations + + Side Effects: + - Submits data to ClinGen Linked Data Hub + + Returns: + dict: Result indicating success and any exception details + """ + # Get the job definition we are working on + job = job_manager.get_job() + + _job_required_params = ["score_set_id", "correlation_id"] + validate_job_params(_job_required_params, job) + + # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. + score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one() # type: ignore + correlation_id = job.job_params["correlation_id"] # type: ignore + + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "submit_score_set_mappings_to_ldh", + "resource": score_set.urn, + "correlation_id": correlation_id, + } + ) + job_manager.update_progress(0, 100, "Starting LDH mapped resource submission.") + logger.info(msg="Started LDH mapped resource submission", extra=job_manager.logging_context()) + + # Connect to LDH service + ldh_service = ClinGenLdhService(url=LDH_SUBMISSION_ENDPOINT) + ldh_service.authenticate() + + # Fetch mapped variants with post-mapped data for the score set + variant_objects = job_manager.db.execute( + select(Variant, MappedVariant) + .join(MappedVariant) + .join(ScoreSet) + .where(ScoreSet.urn == score_set.urn) + .where(MappedVariant.post_mapped.is_not(None)) + .where(MappedVariant.current.is_(True)) + ).all() + + # Track total variants to submit + job_manager.save_to_context({"total_variants_to_submit_ldh": len(variant_objects)}) + if not variant_objects: + job_manager.update_progress(100, 100, "No mapped variants to submit to LDH. Skipping submission.") + logger.warning( + msg="No current mapped variants with post mapped metadata were found for this score set. Skipping LDH submission.", + extra=job_manager.logging_context(), + ) + return {"status": "ok", "data": {}, "exception": None} + job_manager.update_progress(10, 100, f"Submitting {len(variant_objects)} mapped variants to LDH.") + + # Build submission content + variant_content = [] + variant_for_urn = {} + for variant, mapped_variant in variant_objects: + variation = get_hgvs_from_post_mapped(mapped_variant.post_mapped) + + if not variation: + logger.warning( + msg=f"Could not construct a valid HGVS string for mapped variant {mapped_variant.id}. Skipping submission of this variant.", + extra=job_manager.logging_context(), + ) + continue + + variant_content.append((variation, variant, mapped_variant)) + variant_for_urn[variant.urn] = variant + + if not variant_content: + job_manager.update_progress(100, 100, "No valid mapped variants to submit to LDH. Skipping submission.") + logger.warning( + msg="No valid mapped variants with post mapped metadata were found for this score set. 
Skipping LDH submission.", + extra=job_manager.logging_context(), + ) + return {"status": "ok", "data": {}, "exception": None} + + job_manager.save_to_context({"unique_variants_to_submit_ldh": len(variant_content)}) + job_manager.update_progress(30, 100, f"Dispatching submissions for {len(variant_content)} unique variants to LDH.") + submission_content = construct_ldh_submission(variant_content) + + blocking = functools.partial( + ldh_service.dispatch_submissions, submission_content, DEFAULT_LDH_SUBMISSION_BATCH_SIZE + ) + loop = asyncio.get_running_loop() + submission_successes, submission_failures = await loop.run_in_executor(ctx["pool"], blocking) + job_manager.update_progress(90, 100, "Finalizing LDH mapped resource submission.") + job_manager.save_to_context( + { + "ldh_submission_successes": len(submission_successes), + "ldh_submission_failures": len(submission_failures), + } + ) + + # TODO prior to finalizing: Verify typing of ClinGen submission responses. See https://reg.clinicalgenome.org/doc/AlleleRegistry_1.01.xx_api_v1.pdf + annotation_manager = AnnotationStatusManager(job_manager.db) + submitted_variant_urns = set() + for success in submission_successes: + logger.debug( + msg=f"Successfully submitted mapped variant to LDH: {success}", + extra=job_manager.logging_context(), + ) + + submitted_urn = success["data"]["entId"] + submitted_variant = variant_for_urn[submitted_urn] + + annotation_manager.add_annotation( + variant_id=submitted_variant.id, + annotation_type=AnnotationType.LDH_SUBMISSION, + version=None, + status=AnnotationStatus.SUCCESS, + annotation_data={ + "success_data": {"ldh_iri": success["data"]["ldhIri"], "ldh_id": success["data"]["ldhId"]}, + }, + current=True, + ) + submitted_variant_urns.add(submitted_urn) + + # It isn't trivial to map individual failures back to their corresponding variants, + # especially when submission occurred in batch. Save all failures generically here. + # Note that failures may not be present in the submission failures list, but they are + # guaranteed to be absent from the successes list. + for failure_urn in set(variant_for_urn.keys()) - submitted_variant_urns: + logger.error( + msg=f"Failed to submit mapped variant to LDH: {failure_urn}", + extra=job_manager.logging_context(), + ) + + failed_variant = variant_for_urn[failure_urn] + + annotation_manager.add_annotation( + variant_id=failed_variant.id, + annotation_type=AnnotationType.LDH_SUBMISSION, + version=None, + status=AnnotationStatus.FAILED, + annotation_data={ + "error_message": "Failed to submit variant to ClinGen Linked Data Hub.", + }, + current=True, + ) + + if submission_failures: + logger.warning( + msg=f"LDH mapped resource submission encountered {len(submission_failures)} failures.", + extra=job_manager.logging_context(), + ) + + if not submission_successes: + job_manager.update_progress(100, 100, "All mapped variant submissions to LDH failed.") + error_message = f"All LDH submissions failed for score set {score_set.urn}." + logger.error( + msg=error_message, + extra=job_manager.logging_context(), + ) + + # Return a failure state here rather than raising to indicate to the manager + # we should still commit any successful annotations. 
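As the comment above explains, batch failures from the LDH cannot be attributed to individual variants, so the job treats everything absent from the success payloads as failed. That bookkeeping reduces to a set complement; a small sketch, with the payload shape taken from the success-handling loop above:

```python
def split_ldh_outcomes(
    submitted_urns: set[str], submission_successes: list[dict]
) -> tuple[set[str], set[str]]:
    """Partition submitted variant URNs into (succeeded, failed)."""
    # A URN counts as succeeded only if some success payload names it in
    # data.entId; batch error payloads carry no per-variant identity.
    succeeded = {success["data"]["entId"] for success in submission_successes}
    return succeeded, submitted_urns - succeeded
```

In the job above, `submitted_urns` corresponds to `set(variant_for_urn.keys())`.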
+ return { + "status": "failed", + "data": {}, + "exception": LDHSubmissionFailureError(error_message), + } + + logger.info( + msg="Completed LDH mapped resource submission", + extra=job_manager.logging_context(), + ) + + # Finalize progress + job_manager.update_progress( + 100, + 100, + f"Finalized LDH mapped resource submission ({len(submission_successes)} successes, {len(submission_failures)} failures).", + ) + job_manager.db.flush() + return {"status": "ok", "data": {}, "exception": None} diff --git a/src/mavedb/worker/jobs/external_services/clinvar.py b/src/mavedb/worker/jobs/external_services/clinvar.py new file mode 100644 index 00000000..e66de3e5 --- /dev/null +++ b/src/mavedb/worker/jobs/external_services/clinvar.py @@ -0,0 +1,271 @@ +"""ClinVar integration jobs for variant annotation + +This module contains job definitions and utility functions for integrating ClinVar +variant data into MaveDB. It includes functions to fetch and parse ClinVar variant +summary data, and update MaveDB records with the latest ClinVar annotations. +""" + +import asyncio +import functools +import logging + +import requests +from sqlalchemy import select + +from mavedb.lib.annotation_status_manager import AnnotationStatusManager +from mavedb.lib.clingen.allele_registry import get_associated_clinvar_allele_id +from mavedb.lib.clinvar.utils import ( + fetch_clinvar_variant_summary_tsv, + parse_clinvar_variant_summary, + validate_clinvar_variant_summary_date, +) +from mavedb.models.clinical_control import ClinicalControl +from mavedb.models.enums.annotation_type import AnnotationType +from mavedb.models.enums.job_pipeline import AnnotationStatus +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.score_set import ScoreSet +from mavedb.models.variant import Variant +from mavedb.worker.jobs.utils.setup import validate_job_params +from mavedb.worker.lib.decorators.pipeline_management import with_pipeline_management +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.types import JobResultData + +logger = logging.getLogger(__name__) + + +# TODO#649: This function is currently called multiple times to fill in controls for each month/year. +# We should consider caching both fetched TSV data and/or ClinGen API results. This would +# significantly speed up large jobs annotating many variants. + + +@with_pipeline_management +async def refresh_clinvar_controls(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData: + """ + Job to refresh ClinVar clinical control data in MaveDB. + + This job fetches the latest ClinVar variant summary data and updates + the clinical control records in MaveDB accordingly. + + Args: + ctx (dict): The job context containing necessary information. + job_id (int): The ID of the job being executed. + job_manager (JobManager): The job manager instance for managing job state. + + Returns: + JobResultData: The result of the job execution. + """ + # Get the job definition we are working on + job = job_manager.get_job() + + _job_required_params = ["score_set_id", "correlation_id", "year", "month"] + validate_job_params(_job_required_params, job) + + # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. 
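ClinVar clinical controls are versioned by the monthly variant-summary archive they came from, keyed as MM_YYYY; the code below derives `clinvar_version` the same way. A sketch of the convention, with a hypothetical validator standing in for `validate_clinvar_variant_summary_date`:

```python
from datetime import date


def clinvar_version_key(month: int, year: int) -> str:
    """Return the MM_YYYY key under which a ClinVar monthly summary is stored."""
    # Hypothetical validation: reject impossible months and archives that
    # cannot exist yet because the requested month is in the future.
    if not 1 <= month <= 12:
        raise ValueError(f"Invalid month: {month}")
    if date(year, month, 1) > date.today():
        raise ValueError(f"No ClinVar summary can exist yet for {month:02d}/{year}")
    return f"{month:02d}_{year}"


assert clinvar_version_key(1, 2024) == "01_2024"
```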
+ score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one() # type: ignore + correlation_id = job.job_params["correlation_id"] # type: ignore + year = int(job.job_params["year"]) # type: ignore + month = int(job.job_params["month"]) # type: ignore + + validate_clinvar_variant_summary_date(month, year) + # Version must be in MM_YYYY format + clinvar_version = f"{month:02d}_{year}" + + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "refresh_clinvar_controls", + "resource": score_set.urn, + "correlation_id": correlation_id, + "clinvar_year": year, + "clinvar_month": month, + } + ) + job_manager.update_progress(0, 100, f"Starting ClinVar clinical control refresh for version {clinvar_version}.") + logger.info(msg="Started ClinVar clinical control refresh", extra=job_manager.logging_context()) + + job_manager.update_progress(1, 100, "Fetching ClinVar variant summary TSV data.") + logger.debug("Fetching ClinVar variant summary TSV data.", extra=job_manager.logging_context()) + + # Fetch and parse ClinVar variant summary TSV data + blocking = functools.partial(fetch_clinvar_variant_summary_tsv, month, year) + loop = asyncio.get_running_loop() + tsv_content = await loop.run_in_executor(ctx["pool"], blocking) + tsv_data = parse_clinvar_variant_summary(tsv_content) + + job_manager.update_progress(10, 100, "Fetched and parsed ClinVar variant summary TSV data.") + logger.debug("Fetched and parsed ClinVar variant summary TSV data.", extra=job_manager.logging_context()) + + variants_to_refresh = job_manager.db.scalars( + select(MappedVariant) + .join(Variant) + .where( + Variant.score_set_id == score_set.id, + MappedVariant.current.is_(True), + ) + ).all() + total_variants_to_refresh = len(variants_to_refresh) + job_manager.save_to_context({"total_variants_to_refresh": total_variants_to_refresh}) + + logger.info( + f"Refreshing ClinVar data for {total_variants_to_refresh} variants.", extra=job_manager.logging_context() + ) + annotation_manager = AnnotationStatusManager(job_manager.db) + for index, mapped_variant in enumerate(variants_to_refresh): + job_manager.save_to_context({"mapped_variant_id": mapped_variant.id, "progress_index": index}) + if total_variants_to_refresh > 0 and index % (max(total_variants_to_refresh // 100, 1)) == 0: + job_manager.update_progress( + 10 + int((index / total_variants_to_refresh) * 90), + 100, + f"Refreshing ClinVar data for {total_variants_to_refresh} variants ({index} completed).", + ) + + clingen_id = mapped_variant.clingen_allele_id + job_manager.save_to_context({"clingen_allele_id": clingen_id}) + + if clingen_id is None: + annotation_manager.add_annotation( + variant_id=mapped_variant.variant_id, # type: ignore + annotation_type=AnnotationType.CLINVAR_CONTROL, + version=clinvar_version, + status=AnnotationStatus.SKIPPED, + annotation_data={ + "job_run_id": job_manager.job_id, + "error_message": "Mapped variant does not have an associated ClinGen allele ID.", + "failure_category": "missing_clingen_allele_id", + }, + ) + logger.debug( + "Mapped variant does not have an associated ClinGen allele ID.", extra=job_manager.logging_context() + ) + continue + + if clingen_id is not None and "," in clingen_id: + annotation_manager.add_annotation( + variant_id=mapped_variant.variant_id, # type: ignore + annotation_type=AnnotationType.CLINVAR_CONTROL, + version=clinvar_version, + status=AnnotationStatus.SKIPPED, + annotation_data={ + "job_run_id": 
job_manager.job_id, + "error_message": "Multi-variant ClinGen allele IDs cannot be associated with ClinVar data.", + "failure_category": "multi_variant_clingen_allele_id", + }, + ) + logger.debug("Detected a multi-variant ClinGen allele ID, skipping.", extra=job_manager.logging_context()) + continue + + # Fetch associated ClinVar Allele ID from ClinGen API + try: + # Guaranteed based on our query filters. + clinvar_allele_id = get_associated_clinvar_allele_id(clingen_id) # type: ignore + except requests.exceptions.RequestException as exc: + annotation_manager.add_annotation( + variant_id=mapped_variant.variant_id, # type: ignore + annotation_type=AnnotationType.CLINVAR_CONTROL, + version=clinvar_version, + status=AnnotationStatus.FAILED, + annotation_data={ + "job_run_id": job_manager.job_id, + "error_message": f"Failed to retrieve ClinVar allele ID from ClinGen API: {str(exc)}", + "failure_category": "clingen_api_error", + }, + ) + logger.error( + f"Failed to retrieve ClinVar allele ID from ClinGen API for ClinGen allele ID {clingen_id}.", + extra=job_manager.logging_context(), + exc_info=exc, + ) + continue + + job_manager.save_to_context({"clinvar_allele_id": clinvar_allele_id}) + + if clinvar_allele_id is None: + annotation_manager.add_annotation( + variant_id=mapped_variant.variant_id, # type: ignore + annotation_type=AnnotationType.CLINVAR_CONTROL, + version=clinvar_version, + status=AnnotationStatus.SKIPPED, + annotation_data={ + "job_run_id": job_manager.job_id, + "error_message": "No ClinVar allele ID found for ClinGen allele ID.", + "failure_category": "no_associated_clinvar_allele_id", + }, + current=True, + ) + logger.debug("No ClinVar allele ID found for ClinGen allele ID.", extra=job_manager.logging_context()) + continue + + if clinvar_allele_id not in tsv_data: + annotation_manager.add_annotation( + variant_id=mapped_variant.variant_id, # type: ignore + annotation_type=AnnotationType.CLINVAR_CONTROL, + version=clinvar_version, + status=AnnotationStatus.SKIPPED, + annotation_data={ + "job_run_id": job_manager.job_id, + "error_message": "No ClinVar data found for ClinVar allele ID.", + "failure_category": "no_clinvar_variant_data", + }, + ) + logger.debug("No ClinVar variant data found for ClinGen allele ID.", extra=job_manager.logging_context()) + continue + + variant_data = tsv_data[clinvar_allele_id] + identifier = str(clinvar_allele_id) + + clinvar_variant = job_manager.db.scalars( + select(ClinicalControl).where( + ClinicalControl.db_identifier == identifier, + ClinicalControl.db_version == clinvar_version, + ClinicalControl.db_name == "ClinVar", + ) + ).one_or_none() + if clinvar_variant is None: + job_manager.save_to_context({"creating_new_clinvar_variant": True}) + clinvar_variant = ClinicalControl( + db_identifier=identifier, + gene_symbol=variant_data.get("GeneSymbol"), + clinical_significance=variant_data.get("ClinicalSignificance"), + clinical_review_status=variant_data.get("ReviewStatus"), + db_version=clinvar_version, + db_name="ClinVar", + ) + else: + job_manager.save_to_context({"creating_new_clinvar_variant": False}) + clinvar_variant.gene_symbol = variant_data.get("GeneSymbol") + clinvar_variant.clinical_significance = variant_data.get("ClinicalSignificance") + clinvar_variant.clinical_review_status = variant_data.get("ReviewStatus") + + # Add and flush the updated/new clinical control + job_manager.db.add(clinvar_variant) + job_manager.db.flush() + + # Link the clinical control to the mapped variant if not already linked + if clinvar_variant not in 
mapped_variant.clinical_controls: + mapped_variant.clinical_controls.append(clinvar_variant) + job_manager.db.add(mapped_variant) + logger.debug("Linked ClinicalControl to MappedVariant.", extra=job_manager.logging_context()) + + annotation_manager.add_annotation( + variant_id=mapped_variant.variant_id, # type: ignore + annotation_type=AnnotationType.CLINVAR_CONTROL, + version=clinvar_version, + status=AnnotationStatus.SUCCESS, + annotation_data={ + "job_run_id": job_manager.job_id, + "success_data": { + "clinvar_allele_id": clinvar_allele_id, + }, + }, + current=True, + ) + + logger.debug("Updated ClinVar data for ClinGen allele ID.", extra=job_manager.logging_context()) + + logger.info( + msg=f"Fetched ClinVar variant summary data version {clinvar_version}", extra=job_manager.logging_context() + ) + job_manager.update_progress(100, 100, "Completed ClinVar clinical control refresh.") + + return {"status": "ok", "data": {}, "exception": None} diff --git a/src/mavedb/worker/jobs/external_services/gnomad.py b/src/mavedb/worker/jobs/external_services/gnomad.py new file mode 100644 index 00000000..b1e33785 --- /dev/null +++ b/src/mavedb/worker/jobs/external_services/gnomad.py @@ -0,0 +1,155 @@ +"""gnomAD variant linking jobs for population frequency annotation. + +This module handles linking of mapped variants to gnomAD (Genome Aggregation Database) +variants to provide population frequency and other genomic context information. +This enrichment helps researchers understand the clinical significance and +rarity of variants in their datasets. +""" + +import logging +from typing import Sequence + +from sqlalchemy import select + +from mavedb.db import athena +from mavedb.lib.annotation_status_manager import AnnotationStatusManager +from mavedb.lib.gnomad import ( + GNOMAD_DATA_VERSION, + gnomad_variant_data_for_caids, + link_gnomad_variants_to_mapped_variants, +) +from mavedb.models.enums.annotation_type import AnnotationType +from mavedb.models.enums.job_pipeline import AnnotationStatus +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.score_set import ScoreSet +from mavedb.models.variant import Variant +from mavedb.worker.jobs.utils.setup import validate_job_params +from mavedb.worker.lib.decorators.pipeline_management import with_pipeline_management +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.types import JobResultData + +logger = logging.getLogger(__name__) + + +@with_pipeline_management +async def link_gnomad_variants(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData: + """ + Link mapped variants to gnomAD variants based on ClinGen Allele IDs (CAIDs). + This job fetches mapped variants associated with a given score set that have CAIDs, + retrieves corresponding gnomAD variant data, and establishes links between them + in the database. + + Job Parameters: + - score_set_id (int): The ID of the ScoreSet containing mapped variants to process. + - correlation_id (str): Correlation ID for tracing requests across services. + + Args: + ctx (dict): The job context dictionary. + job_id (int): The ID of the job being executed. + job_manager (JobManager): The job manager instance for database and logging operations. + + Side Effects: + - Updates MappedVariant records to link to gnomAD variants. 
+ + Returns: + dict: Result indicating success and any exception details + """ + # Get the job definition we are working on + job = job_manager.get_job() + + _job_required_params = ["score_set_id", "correlation_id"] + validate_job_params(_job_required_params, job) + + # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. + score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one() # type: ignore + correlation_id = job.job_params["correlation_id"] # type: ignore + + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "link_gnomad_variants", + "resource": score_set.urn, + "correlation_id": correlation_id, + } + ) + job_manager.update_progress(0, 100, "Starting gnomAD mapped resource linkage.") + logger.info(msg="Started gnomAD mapped resource linkage", extra=job_manager.logging_context()) + + # We filter out mapped variants that do not have a CAID, so this query is typed # as a Sequence[str]. Ignore MyPy's type checking here. + variant_caids: Sequence[str] = job_manager.db.scalars( + select(MappedVariant.clingen_allele_id) + .join(Variant) + .join(ScoreSet) + .where( + ScoreSet.urn == score_set.urn, + MappedVariant.current.is_(True), + MappedVariant.clingen_allele_id.is_not(None), + ) + ).all() # type: ignore + + num_variant_caids = len(variant_caids) + job_manager.save_to_context({"num_variants_to_link_gnomad": num_variant_caids}) + + if not variant_caids: + job_manager.update_progress(100, 100, "No variants with CAIDs found to link to gnomAD variants. Nothing to do.") + logger.warning( + msg="No current mapped variants with CAIDs were found for this score set. Skipping gnomAD linkage (nothing to do).", + extra=job_manager.logging_context(), + ) + return {"status": "ok", "data": {}, "exception": None} + + job_manager.update_progress(10, 100, f"Found {num_variant_caids} variants with CAIDs to link to gnomAD variants.") + logger.info( + msg="Found current mapped variants with CAIDs for this score set. Attempting to link them to gnomAD variants.", + extra=job_manager.logging_context(), + ) + + # Fetch gnomAD variant data for the CAIDs + with athena.engine.connect() as athena_session: + logger.debug("Fetching gnomAD variants from Athena.") + gnomad_variant_data = gnomad_variant_data_for_caids(athena_session, variant_caids) + + num_gnomad_variants_with_caid_match = len(gnomad_variant_data) + + # NOTE: Proceed intentionally with linking even if no matches were found, to record skipped annotations. 
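The gnomAD lookup above goes through a dedicated SQLAlchemy engine (`mavedb.db.athena`) rather than the application's Postgres session, since the frequency data lives in S3 and is queried via AWS Athena. A sketch of what such an engine and a CAID lookup can look like using the PyAthena SQLAlchemy dialect; the URL, schema, table, and column names here are placeholders, not MaveDB's actual configuration:

```python
from sqlalchemy import bindparam, create_engine, text

# PyAthena registers the awsathena+rest dialect; region, schema, and the
# s3_staging_dir for query results are deployment-specific placeholders.
engine = create_engine(
    "awsathena+rest://@athena.us-east-1.amazonaws.com:443/gnomad"
    "?s3_staging_dir=s3://example-athena-query-results/"
)

caid_query = text(
    "SELECT caid, genome_af FROM gnomad_variants WHERE caid IN :caids"
).bindparams(bindparam("caids", expanding=True))


def frequencies_for_caids(caids: list[str]) -> list[tuple]:
    with engine.connect() as conn:
        return list(conn.execute(caid_query, {"caids": caids}))
```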
+ + job_manager.save_to_context({"num_gnomad_variants_with_caid_match": num_gnomad_variants_with_caid_match}) + job_manager.update_progress(75, 100, f"Found {num_gnomad_variants_with_caid_match} gnomAD variants matching CAIDs.") + + # Link mapped variants to gnomAD variants + logger.info(msg="Attempting to link mapped variants to gnomAD variants.", extra=job_manager.logging_context()) + num_linked_gnomad_variants = link_gnomad_variants_to_mapped_variants(job_manager.db, gnomad_variant_data) + job_manager.db.flush() + + # For variants which are not linked, create annotation status records indicating skipped linkage + mapped_variants_with_caids = job_manager.db.scalars( + select(MappedVariant) + .join(Variant) + .join(ScoreSet) + .where( + ScoreSet.urn == score_set.urn, + MappedVariant.current.is_(True), + MappedVariant.clingen_allele_id.is_not(None), + ) + ).all() + annotation_manager = AnnotationStatusManager(job_manager.db) + for mapped_variant in mapped_variants_with_caids: + if not mapped_variant.gnomad_variants: + annotation_manager.add_annotation( + variant_id=mapped_variant.variant_id, # type: ignore + annotation_type=AnnotationType.GNOMAD_ALLELE_FREQUENCY, + version=GNOMAD_DATA_VERSION, + status=AnnotationStatus.SKIPPED, + annotation_data={ + "error_message": "No gnomAD variant could be linked for this mapped variant.", + "failure_category": "not_found", + }, + current=True, + ) + + # Save final context and progress + job_manager.save_to_context({"num_mapped_variants_linked_to_gnomad_variants": num_linked_gnomad_variants}) + job_manager.update_progress(100, 100, f"Linked {num_linked_gnomad_variants} mapped variants to gnomAD variants.") + logger.info(msg="Done linking gnomAD variants to mapped variants.", extra=job_manager.logging_context()) + return {"status": "ok", "data": {}, "exception": None} diff --git a/src/mavedb/worker/jobs/external_services/py.typed b/src/mavedb/worker/jobs/external_services/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/mavedb/worker/jobs/external_services/uniprot.py b/src/mavedb/worker/jobs/external_services/uniprot.py new file mode 100644 index 00000000..637ff162 --- /dev/null +++ b/src/mavedb/worker/jobs/external_services/uniprot.py @@ -0,0 +1,322 @@ +"""UniProt ID mapping jobs for protein sequence annotation. + +This module handles the submission and polling of UniProt ID mapping jobs +to enrich target gene metadata with UniProt identifiers. This enables +linking of genomic variants to protein-level functional information. + +The mapping process is asynchronous, requiring both submission and polling +jobs to handle the UniProt API's batch processing workflow. 
+""" + +import logging +from typing import Optional, TypedDict + +from sqlalchemy import select +from sqlalchemy.orm.attributes import flag_modified + +from mavedb.lib.exceptions import ( + NonExistentTargetGeneError, + UniprotAmbiguousMappingResultError, + UniprotMappingResultNotFoundError, + UniProtPollingEnqueueError, +) +from mavedb.lib.mapping import extract_ids_from_post_mapped_metadata +from mavedb.lib.uniprot.id_mapping import UniProtIDMappingAPI +from mavedb.lib.uniprot.utils import infer_db_name_from_sequence_accession +from mavedb.models.job_dependency import JobDependency +from mavedb.models.score_set import ScoreSet +from mavedb.worker.jobs.utils.setup import validate_job_params +from mavedb.worker.lib.decorators.pipeline_management import with_pipeline_management +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.types import JobResultData + +logger = logging.getLogger(__name__) + + +class MappingJob(TypedDict): + job_id: Optional[str] + accession: str + + +@with_pipeline_management +async def submit_uniprot_mapping_jobs_for_score_set(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData: + """Submit UniProt ID mapping jobs for all target genes in a given ScoreSet. + + NOTE: This function assumes that a dependent polling job has already been created + for the same ScoreSet. It is the responsibility of this function to ensure that + the polling job exists and to set the `mapping_jobs` parameter on the polling job. + + Without running the polling job, the results of the submitted UniProt mapping jobs + will never be retrieved or processed, so running this function alone is insufficient + to complete the UniProt mapping workflow. + + Job Parameters: + - score_set_id (int): The ID of the ScoreSet containing target genes to map. + - correlation_id (str): Correlation ID for tracing requests across services. + + Args: + ctx (dict): The job context dictionary. + job_id (int): The ID of the job being executed. + job_manager (JobManager): Manager for job lifecycle and DB operations. + + Side Effects: + - Submits UniProt ID mapping jobs for each target gene in the ScoreSet. + - Fetches the dependent job for this function, which is the polling job for UniProt results. + Sets the parameter `mapping_jobs` on the polling job with a dictionary of target gene IDs to UniProt job IDs. + TODO#646: Split mapping jobs into one per target gene so that polling can be more granular. + + Raises: + - UniProtPollingEnqueueError: If the dependent polling job cannot be found. + + Returns: + dict: Result indicating success and any exception details + """ + # Get the job definition we are working on + job = job_manager.get_job() + + _job_required_params = ["score_set_id", "correlation_id"] + validate_job_params(_job_required_params, job) + + # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. 
+    score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one()  # type: ignore
+    correlation_id = job.job_params["correlation_id"]  # type: ignore
+
+    # Setup initial context and progress
+    job_manager.save_to_context(
+        {
+            "application": "mavedb-worker",
+            "function": "submit_uniprot_mapping_jobs_for_score_set",
+            "resource": score_set.urn,
+            "correlation_id": correlation_id,
+        }
+    )
+    job_manager.update_progress(0, 100, "Starting UniProt mapping job submission.")
+    logger.info(msg="Started UniProt mapping job submission", extra=job_manager.logging_context())
+
+    # Preset submitted jobs metadata so it persists even if no jobs are submitted.
+    job.metadata_["submitted_jobs"] = {}
+    job_manager.db.flush()
+
+    if not score_set.target_genes:
+        job_manager.update_progress(100, 100, "No target genes found. Skipped UniProt mapping job submission.")
+        logger.error(
+            msg=f"No target genes found for score set {score_set.urn}. Skipped UniProt mapping job submission.",
+            extra=job_manager.logging_context(),
+        )
+
+        return {"status": "ok", "data": {}, "exception": None}
+
+    uniprot_api = UniProtIDMappingAPI()
+    job_manager.save_to_context({"total_target_genes_to_map_to_uniprot": len(score_set.target_genes)})
+
+    mapping_jobs: dict[str, MappingJob] = {}
+    for idx, target_gene in enumerate(score_set.target_genes):
+        acs = extract_ids_from_post_mapped_metadata(target_gene.post_mapped_metadata)  # type: ignore
+        if not acs:
+            logger.warning(
+                msg=f"No accession IDs found in post_mapped_metadata for target gene {target_gene.id} in score set {score_set.urn}. Skipped mapping this target.",
+                extra=job_manager.logging_context(),
+            )
+            continue
+
+        if len(acs) != 1:
+            logger.warning(
+                msg=f"More than one accession ID is associated with target gene {target_gene.id} in score set {score_set.urn}. Skipped mapping this target.",
+                extra=job_manager.logging_context(),
+            )
+            continue
+
+        ac_to_map = acs[0]
+        from_db = infer_db_name_from_sequence_accession(ac_to_map)
+        spawned_job = uniprot_api.submit_id_mapping(from_db, "UniProtKB", [ac_to_map])  # type: ignore
+
+        # Explicitly cast ints to strs in mapping job keys. These are converted to strings internally
+        # by SQLAlchemy when storing job_params as JSON, so be explicit here to avoid confusion.
+        mapping_jobs[str(target_gene.id)] = {"job_id": spawned_job, "accession": ac_to_map}
+
+        job_manager.save_to_context(
+            {
+                "submitted_uniprot_mapping_jobs": {
+                    **job_manager.logging_context().get("submitted_uniprot_mapping_jobs", {}),
+                    str(target_gene.id): mapping_jobs[str(target_gene.id)],
+                }
+            }
+        )
+        job_manager.update_progress(
+            int(((idx + 1) / len(score_set.target_genes)) * 95),
+            100,
+            f"Submitted UniProt mapping job for target gene {target_gene.name}.",
+        )
+        logger.info(
+            msg=f"Submitted UniProt ID mapping job for target gene {target_gene.id}.",
+            extra=job_manager.logging_context(),
+        )
+
+    # Save submitted jobs to job metadata for auditing purposes
+    job.metadata_["submitted_jobs"] = mapping_jobs
+    flag_modified(job, "metadata_")
+    job_manager.db.flush()
+
+    # If no mapping jobs were submitted, log and exit early.
+    if not mapping_jobs or not any((job_info["job_id"] for job_info in mapping_jobs.values())):
+        job_manager.update_progress(100, 100, "No UniProt mapping jobs were submitted.")
+        logger.warning(msg="No UniProt mapping jobs were submitted.", extra=job_manager.logging_context())
+
+        return {"status": "ok", "data": {}, "exception": None}
+
+    # It's an essential responsibility of the submit job (when submissions exist) to ensure that the polling job exists.
+    dependent_polling_job = job_manager.db.scalars(
+        select(JobDependency).where(JobDependency.depends_on_job_id == job.id)
+    ).all()
+    if not dependent_polling_job or len(dependent_polling_job) != 1:
+        job_manager.update_progress(100, 100, "Failed to submit UniProt mapping jobs.")
+        logger.error(
+            msg=f"Could not find unique dependent polling job for UniProt mapping job {job.id}.",
+            extra=job_manager.logging_context(),
+        )
+
+        # Return a failure state here rather than raising to indicate to the manager
+        # we should still commit any successful annotations.
+        return {
+            "status": "failed",
+            "data": {},
+            "exception": UniProtPollingEnqueueError(
+                f"Could not find unique dependent polling job for UniProt mapping job {job.id}."
+            ),
+        }
+
+    # Set mapping jobs on dependent polling job. Only one polling job per score set should be created.
+    polling_job = dependent_polling_job[0].job_run
+    polling_job.job_params = {
+        **(polling_job.job_params or {}),
+        "mapping_jobs": mapping_jobs,
+    }
+
+    job_manager.update_progress(100, 100, "Completed submission of UniProt mapping jobs.")
+    logger.info(msg="Completed UniProt mapping job submission", extra=job_manager.logging_context())
+    job_manager.db.flush()
+    return {"status": "ok", "data": {}, "exception": None}
+
+
+@with_pipeline_management
+async def poll_uniprot_mapping_jobs_for_score_set(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData:
+    """Poll UniProt ID mapping jobs for all target genes in a given ScoreSet.
+
+    Job Parameters:
+        - score_set_id (int): The ID of the ScoreSet containing target genes to map.
+        - correlation_id (str): Correlation ID for tracing requests across services.
+        - mapping_jobs (dict): Dictionary of target gene IDs to UniProt job IDs.
+
+    Args:
+        ctx (dict): The job context dictionary.
+        job_id (int): The ID of the job being processed.
+        job_manager (JobManager): Manager for job lifecycle and DB operations.
+
+    Side Effects:
+        - Polls UniProt ID mapping jobs for each target gene in the ScoreSet.
+        - Updates target genes with mapped UniProt IDs in the database.
+
+    TODO#646: Split mapping jobs into one per target gene so that polling can be more granular.
+
+    Returns:
+        dict: Result indicating success and any exception details
+    """
+    # Get the job definition we are working on
+    job = job_manager.get_job()
+
+    _job_required_params = ["score_set_id", "correlation_id", "mapping_jobs"]
+    validate_job_params(_job_required_params, job)
+
+    # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above.
+ score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one() # type: ignore + correlation_id = job.job_params["correlation_id"] # type: ignore + mapping_jobs: dict[str, MappingJob] = job.job_params.get("mapping_jobs", {}) # type: ignore + + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "poll_uniprot_mapping_jobs_for_score_set", + "resource": score_set.urn, + "correlation_id": correlation_id, + } + ) + job_manager.update_progress(0, 100, "Starting UniProt mapping job polling.") + logger.info(msg="Started UniProt mapping job polling", extra=job_manager.logging_context()) + + if not mapping_jobs or not any(mapping_jobs.values()): + job_manager.update_progress(100, 100, "No mapping jobs found to poll.") + logger.warning( + msg=f"No mapping jobs found in job parameters for polling UniProt mapping jobs for score set {score_set.urn}.", + extra=job_manager.logging_context(), + ) + return {"status": "ok", "data": {}, "exception": None} + + # Poll each mapping job and update target genes with UniProt IDs + uniprot_api = UniProtIDMappingAPI() + for target_gene_id, mapping_job in mapping_jobs.items(): + mapping_job_id = mapping_job["job_id"] + + if not mapping_job_id: + logger.warning( + msg=f"No UniProt mapping job ID found for target gene ID {target_gene_id}. Skipped polling this job.", + extra=job_manager.logging_context(), + ) + continue + + # Check if the mapping job is ready + if not uniprot_api.check_id_mapping_results_ready(mapping_job_id): + logger.warning( + msg=f"Job {mapping_job_id} not ready. Skipped polling this job.", + extra=job_manager.logging_context(), + ) + # TODO#XXX: When results are not ready, we want to signal to the manager a desire to retry + # this polling job later. For now, we just skip and log. + continue + + # Extract mapped UniProt IDs from results + results = uniprot_api.get_id_mapping_results(mapping_job_id) + mapped_ids = uniprot_api.extract_uniprot_id_from_results(results) + mapped_ac = mapping_job["accession"] + + # Handle cases where no or ambiguous results are found + if not mapped_ids: + msg = f"No UniProt ID found for accession {mapped_ac}. Cannot add UniProt ID." + job_manager.update_progress(100, 100, msg) + logger.error(msg=msg, extra=job_manager.logging_context()) + raise UniprotMappingResultNotFoundError() + + if len(mapped_ids) != 1: + msg = f"Ambiguous UniProt ID mapping results for accession {mapped_ac}. Cannot add UniProt ID." + job_manager.update_progress(100, 100, msg) + logger.error(msg=msg, extra=job_manager.logging_context()) + raise UniprotAmbiguousMappingResultError() + + mapped_uniprot_id = mapped_ids[0][mapped_ac]["uniprot_id"] + + # Update target gene with mapped UniProt ID + target_gene = next( + (tg for tg in score_set.target_genes if str(tg.id) == str(target_gene_id)), + None, + ) + if not target_gene: + msg = f"Target gene ID {target_gene_id} not found in score set {score_set.urn}. Cannot add UniProt ID." 
+ job_manager.update_progress(100, 100, msg) + logger.error(msg=msg, extra=job_manager.logging_context()) + raise NonExistentTargetGeneError() + + target_gene.uniprot_id_from_mapped_metadata = mapped_uniprot_id + job_manager.db.add(target_gene) + logger.info( + msg=f"Updated target gene {target_gene.id} with UniProt ID {mapped_uniprot_id}", + extra=job_manager.logging_context(), + ) + job_manager.update_progress( + int((list(score_set.target_genes).index(target_gene) + 1) / len(score_set.target_genes) * 95), + 100, + f"Polled UniProt mapping job for target gene {target_gene.name}.", + ) + + job_manager.update_progress(100, 100, "Completed polling of UniProt mapping jobs.") + job_manager.db.flush() + return {"status": "ok", "data": {}, "exception": None} diff --git a/src/mavedb/worker/jobs/pipeline_management/__init__.py b/src/mavedb/worker/jobs/pipeline_management/__init__.py new file mode 100644 index 00000000..95470f75 --- /dev/null +++ b/src/mavedb/worker/jobs/pipeline_management/__init__.py @@ -0,0 +1,12 @@ +""" +Pipeline management job entrypoints. + +This module exposes job functions for pipeline management, such as starting a pipeline. +Import job functions here and add them to __all__ for job discovery and import convenience. +""" + +from .start_pipeline import start_pipeline + +__all__ = [ + "start_pipeline", +] diff --git a/src/mavedb/worker/jobs/pipeline_management/start_pipeline.py b/src/mavedb/worker/jobs/pipeline_management/start_pipeline.py new file mode 100644 index 00000000..7dbed7d4 --- /dev/null +++ b/src/mavedb/worker/jobs/pipeline_management/start_pipeline.py @@ -0,0 +1,65 @@ +import logging + +from mavedb.lib.exceptions import PipelineNotFoundError +from mavedb.worker.lib.decorators.pipeline_management import with_pipeline_management +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.pipeline_manager import PipelineManager +from mavedb.worker.lib.managers.types import JobResultData + +logger = logging.getLogger(__name__) + + +@with_pipeline_management +async def start_pipeline(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData: + """Start the pipeline associated with the given job. + + This job initializes and starts the pipeline execution process. + It sets up the necessary pipeline management context and triggers + the pipeline coordination. + + NOTE: This function requires a dedicated 'start_pipeline' job run record + in the database. This job run must be created prior to invoking this function + and should be associated with the pipeline to be started. + + Args: + ctx (dict): The job context dictionary. + job_id (int): The ID of the job run. + job_manager (JobManager): Manager for job lifecycle and DB operations. + + Side Effects: + - Initializes and starts the pipeline execution. + + Returns: + dict: Result indicating success and any exception details + """ + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "start_pipeline", + "resource": f"pipeline_for_job_{job_id}", + "correlation_id": None, + } + ) + job_manager.update_progress(0, 100, "Coordinating pipeline for the first time.") + logger.debug(msg="Coordinating pipeline for the first time.", extra=job_manager.logging_context()) + + if not job_manager.pipeline_id: + return { + "status": "exception", + "data": {}, + "exception": PipelineNotFoundError("No pipeline associated with this job."), + } + + # Initialize PipelineManager and coordinate pipeline. 
The pipeline manager decorator + # will have started the pipeline for us already, but doesn't coordinate on start automatically. + redis = job_manager.redis or ctx["redis"] + pipeline_manager = PipelineManager(job_manager.db, redis, job_manager.pipeline_id) + await pipeline_manager.coordinate_pipeline() + + # Finalize job state + job_manager.db.flush() + job_manager.update_progress(100, 100, "Initial pipeline coordination complete.") + logger.debug(msg="Done starting pipeline.", extra=job_manager.logging_context()) + + return {"status": "ok", "data": {}, "exception": None} diff --git a/src/mavedb/worker/jobs/py.typed b/src/mavedb/worker/jobs/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/mavedb/worker/jobs/registry.py b/src/mavedb/worker/jobs/registry.py new file mode 100644 index 00000000..d2aab06b --- /dev/null +++ b/src/mavedb/worker/jobs/registry.py @@ -0,0 +1,154 @@ +"""Job registry for worker configuration. + +This module provides a centralized registry of all available worker jobs +as simple lists for ARQ worker configuration. +""" + +from datetime import timedelta +from typing import Callable, List + +from arq.cron import CronJob, cron + +from mavedb.lib.types.workflow import JobDefinition +from mavedb.models.enums.job_pipeline import JobType +from mavedb.worker.jobs.data_management import ( + refresh_materialized_views, + refresh_published_variants_view, +) +from mavedb.worker.jobs.external_services import ( + link_gnomad_variants, + poll_uniprot_mapping_jobs_for_score_set, + refresh_clinvar_controls, + submit_score_set_mappings_to_car, + submit_score_set_mappings_to_ldh, + submit_uniprot_mapping_jobs_for_score_set, +) +from mavedb.worker.jobs.pipeline_management import start_pipeline +from mavedb.worker.jobs.variant_processing import ( + create_variants_for_score_set, + map_variants_for_score_set, +) + +# All job functions for ARQ worker +BACKGROUND_FUNCTIONS: List[Callable] = [ + # Variant processing jobs + create_variants_for_score_set, + map_variants_for_score_set, + # External service jobs + submit_score_set_mappings_to_car, + submit_score_set_mappings_to_ldh, + refresh_clinvar_controls, + submit_uniprot_mapping_jobs_for_score_set, + poll_uniprot_mapping_jobs_for_score_set, + link_gnomad_variants, + # Data management jobs + refresh_materialized_views, + refresh_published_variants_view, + # Pipeline management jobs + start_pipeline, +] + +# Cron job definitions for ARQ worker +BACKGROUND_CRONJOBS: List[CronJob] = [ + cron( + refresh_materialized_views, + name="refresh_all_materialized_views", + hour=20, + minute=0, + keep_result=timedelta(minutes=2).total_seconds(), + ), +] + + +STANDALONE_JOB_DEFINITIONS: dict[Callable, JobDefinition] = { + create_variants_for_score_set: { + "dependencies": [], + "params": { + "score_set_id": None, + "updater_id": None, + "correlation_id": None, + "scores_file_key": None, + "counts_file_key": None, + "score_columns_metadata": None, + "count_columns_metadata": None, + }, + "function": "create_variants_for_score_set", + "key": "create_variants_for_score_set", + "type": JobType.VARIANT_CREATION, + }, + map_variants_for_score_set: { + "dependencies": [], + "params": {"score_set_id": None, "updater_id": None, "correlation_id": None}, + "function": "map_variants_for_score_set", + "key": "map_variants_for_score_set", + "type": JobType.VARIANT_MAPPING, + }, + submit_score_set_mappings_to_car: { + "dependencies": [], + "params": {"score_set_id": None, "correlation_id": None}, + "function": "submit_score_set_mappings_to_car", + 
"key": "submit_score_set_mappings_to_car", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + }, + submit_score_set_mappings_to_ldh: { + "dependencies": [], + "params": {"score_set_id": None, "correlation_id": None}, + "function": "submit_score_set_mappings_to_ldh", + "key": "submit_score_set_mappings_to_ldh", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + }, + refresh_clinvar_controls: { + "dependencies": [], + "params": {"score_set_id": None, "correlation_id": None, "year": None, "month": None}, + "function": "refresh_clinvar_controls", + "key": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + }, + submit_uniprot_mapping_jobs_for_score_set: { + "dependencies": [], + "params": {"score_set_id": None, "correlation_id": None}, + "function": "submit_uniprot_mapping_jobs_for_score_set", + "key": "submit_uniprot_mapping_jobs_for_score_set", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + }, + poll_uniprot_mapping_jobs_for_score_set: { + "dependencies": [], + "params": {"score_set_id": None, "correlation_id": None}, + "function": "poll_uniprot_mapping_jobs_for_score_set", + "key": "poll_uniprot_mapping_jobs_for_score_set", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + }, + link_gnomad_variants: { + "dependencies": [], + "params": {"score_set_id": None, "correlation_id": None}, + "function": "link_gnomad_variants", + "key": "link_gnomad_variants", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + }, + refresh_materialized_views: { + "dependencies": [], + "params": {"correlation_id": None}, + "function": "refresh_materialized_views", + "key": "refresh_materialized_views", + "type": JobType.DATA_MANAGEMENT, + }, + refresh_published_variants_view: { + "dependencies": [], + "params": {"correlation_id": None}, + "function": "refresh_published_variants_view", + "key": "refresh_published_variants_view", + "type": JobType.DATA_MANAGEMENT, + }, +} +""" +Standalone job definitions for direct job submission outside of pipelines. +All job definitions in this dict must correspond to a job function in BACKGROUND_FUNCTIONS +and must not have any dependencies on other jobs. +""" + + +__all__ = [ + "BACKGROUND_FUNCTIONS", + "BACKGROUND_CRONJOBS", + "STANDALONE_JOB_DEFINITIONS", +] diff --git a/src/mavedb/worker/jobs/utils/__init__.py b/src/mavedb/worker/jobs/utils/__init__.py new file mode 100644 index 00000000..4bdb3409 --- /dev/null +++ b/src/mavedb/worker/jobs/utils/__init__.py @@ -0,0 +1,28 @@ +"""Worker job utility functions and constants. + +This module provides shared utilities used across worker jobs: +- Job state management and context setup +- Retry logic with exponential backoff +- Configuration constants for queues and timeouts + +These utilities help ensure consistent behavior and error handling +across all worker job implementations. +""" + +from .constants import ( + ENQUEUE_BACKOFF_ATTEMPT_LIMIT, + LINKING_BACKOFF_IN_SECONDS, + MAPPING_BACKOFF_IN_SECONDS, + MAPPING_CURRENT_ID_NAME, + MAPPING_QUEUE_NAME, +) +from .setup import validate_job_params + +__all__ = [ + "validate_job_params", + "MAPPING_QUEUE_NAME", + "MAPPING_CURRENT_ID_NAME", + "MAPPING_BACKOFF_IN_SECONDS", + "LINKING_BACKOFF_IN_SECONDS", + "ENQUEUE_BACKOFF_ATTEMPT_LIMIT", +] diff --git a/src/mavedb/worker/jobs/utils/constants.py b/src/mavedb/worker/jobs/utils/constants.py new file mode 100644 index 00000000..cca5a02c --- /dev/null +++ b/src/mavedb/worker/jobs/utils/constants.py @@ -0,0 +1,17 @@ +"""Constants used across worker jobs. 
+ +This module centralizes configuration constants used by various worker jobs +including queue names, timeouts, and retry limits. This provides a single +source of truth for job configuration values. +""" + +### Mapping job constants +MAPPING_QUEUE_NAME = "vrs_mapping_queue" +MAPPING_CURRENT_ID_NAME = "vrs_mapping_current_job_id" +MAPPING_BACKOFF_IN_SECONDS = 15 + +### Linking job constants +LINKING_BACKOFF_IN_SECONDS = 15 * 60 + +### Backoff constants +ENQUEUE_BACKOFF_ATTEMPT_LIMIT = 5 diff --git a/src/mavedb/worker/jobs/utils/py.typed b/src/mavedb/worker/jobs/utils/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/mavedb/worker/jobs/utils/setup.py b/src/mavedb/worker/jobs/utils/setup.py new file mode 100644 index 00000000..b569bb0e --- /dev/null +++ b/src/mavedb/worker/jobs/utils/setup.py @@ -0,0 +1,24 @@ +"""Job state management utilities. + +This module provides utilities for managing job state and context across +the worker job lifecycle. It handles setup of logging context, correlation +IDs, and other state information needed for job traceability and monitoring. +""" + +import logging + +from mavedb.models.job_run import JobRun + +logger = logging.getLogger(__name__) + + +def validate_job_params(required_params: list[str], job: JobRun) -> None: + """ + Validate that the given job has all required parameters present in its job_params. + """ + if not job.job_params: + raise ValueError("Job has no job_params defined.") + + for param in required_params: + if param not in job.job_params: + raise ValueError(f"Missing required job param: {param}") diff --git a/src/mavedb/worker/jobs/variant_processing/__init__.py b/src/mavedb/worker/jobs/variant_processing/__init__.py new file mode 100644 index 00000000..a6df0975 --- /dev/null +++ b/src/mavedb/worker/jobs/variant_processing/__init__.py @@ -0,0 +1,17 @@ +"""Variant processing job functions. + +This module exports jobs responsible for variant creation and mapping: +- Variant creation from uploaded score/count data +- VRS mapping to standardized genomic coordinates +- Queue management for mapping workflows +""" + +from .creation import create_variants_for_score_set +from .mapping import ( + map_variants_for_score_set, +) + +__all__ = [ + "create_variants_for_score_set", + "map_variants_for_score_set", +] diff --git a/src/mavedb/worker/jobs/variant_processing/creation.py b/src/mavedb/worker/jobs/variant_processing/creation.py new file mode 100644 index 00000000..cee4ff5f --- /dev/null +++ b/src/mavedb/worker/jobs/variant_processing/creation.py @@ -0,0 +1,252 @@ +"""Variant creation jobs for score sets. + +This module contains jobs responsible for creating and validating variants +from uploaded score and count data. It handles the full variant creation +pipeline including data validation, standardization, and database persistence. 
+""" + +import io +import logging + +import pandas as pd +from sqlalchemy import delete, null, select + +from mavedb.data_providers.services import CSV_UPLOAD_S3_BUCKET_NAME, RESTDataProvider, s3_client +from mavedb.lib.logging.context import format_raised_exception_info_as_dict +from mavedb.lib.score_sets import columns_for_dataset, create_variants, create_variants_data +from mavedb.lib.validation.dataframe.dataframe import validate_and_standardize_dataframe_pair +from mavedb.lib.validation.exceptions import ValidationError +from mavedb.models.enums.mapping_state import MappingState +from mavedb.models.enums.processing_state import ProcessingState +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.score_set import ScoreSet +from mavedb.models.user import User +from mavedb.models.variant import Variant +from mavedb.worker.jobs.utils.setup import validate_job_params +from mavedb.worker.lib.decorators.pipeline_management import with_pipeline_management +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.types import JobResultData + +logger = logging.getLogger(__name__) + + +@with_pipeline_management +async def create_variants_for_score_set(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData: + """ + Create variants for a given ScoreSet based on uploaded score and count data. + + Args: + ctx: The job context dictionary. + job_id: The ID of the job being executed. + job_manager: Manager for job lifecycle and DB operations. + + Job Parameters: + - score_set_id (int): The ID of the ScoreSet to create variants for. + - correlation_id (str): Correlation ID for tracing requests across services. + - updater_id (int): The ID of the user performing the update. + - scores_file_key (str): S3 key for the uploaded scores CSV file. + - counts_file_key (str): S3 key for the uploaded counts CSV file. + - score_columns_metadata (dict): Metadata for score columns. + - count_columns_metadata (dict): Metadata for count columns. + + Side Effects: + - Creates Variant and MappedVariant records in the database. + + Returns: + dict: Result indicating success and any exception details + """ + # Handle everything prior to score set fetch in an outer layer. Any issues prior to + # fetching the score set should fail the job outright and we will be unable to set + # a processing state on the score set itself. + logger.info(msg="Starting create_variants_for_score_set job", extra=job_manager.logging_context()) + hdp: RESTDataProvider = ctx["hdp"] + + # Get the job definition we are working on + job = job_manager.get_job() + + _job_required_params = [ + "score_set_id", + "correlation_id", + "updater_id", + "scores_file_key", + "counts_file_key", + "score_columns_metadata", + "count_columns_metadata", + ] + validate_job_params(_job_required_params, job) + + # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. + score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one() # type: ignore + + # Main processing block. Handled in a try/except to ensure we can set score set state appropriately, + # which is handled independently of the job state. + # TODO:647 In a future iteration, we should rely on the job manager itself for maintaining processing + # state for better cohesion. This try/except is redundant in it's duties with the job manager. 
+ try: + correlation_id = job.job_params["correlation_id"] # type: ignore + updater_id = job.job_params["updater_id"] # type: ignore + score_file_key = job.job_params["scores_file_key"] # type: ignore + count_file_key = job.job_params["counts_file_key"] # type: ignore + score_columns_metadata = job.job_params["score_columns_metadata"] # type: ignore + count_columns_metadata = job.job_params["count_columns_metadata"] # type: ignore + + job_manager.save_to_context( + { + "score_set_id": score_set.id, + "updater_id": updater_id, + "correlation_id": correlation_id, + "score_file_key": score_file_key, + "count_file_key": count_file_key, + "bucket_name": CSV_UPLOAD_S3_BUCKET_NAME, + } + ) + logger.debug(msg="Fetching file resources from S3 for variant creation", extra=job_manager.logging_context()) + + s3 = s3_client() + scores = io.BytesIO() + s3.download_fileobj(Bucket=CSV_UPLOAD_S3_BUCKET_NAME, Key=score_file_key, Fileobj=scores) + scores.seek(0) + scores_df = pd.read_csv(scores) + + # Counts file is optional + counts_df = None + if count_file_key: + counts = io.BytesIO() + s3.download_fileobj(Bucket=CSV_UPLOAD_S3_BUCKET_NAME, Key=count_file_key, Fileobj=counts) + counts.seek(0) + counts_df = pd.read_csv(counts) + + logger.debug(msg="Successfully fetched file resources from S3", extra=job_manager.logging_context()) + + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "create_variants_for_score_set", + "resource": score_set.urn, + "correlation_id": correlation_id, + } + ) + job_manager.update_progress(0, 100, "Starting variant creation job.") + logger.info(msg="Started variant creation job", extra=job_manager.logging_context()) + + updated_by = job_manager.db.scalars(select(User).where(User.id == updater_id)).one() + + score_set.modified_by = updated_by + score_set.processing_state = ProcessingState.processing + score_set.mapping_state = MappingState.pending_variant_processing + + job_manager.save_to_context( + {"processing_state": score_set.processing_state.name, "mapping_state": score_set.mapping_state.name} + ) + + # Flush initial score set state + job_manager.db.add(score_set) + job_manager.db.flush() + job_manager.db.refresh(score_set) + + job_manager.update_progress(10, 100, "Validated score set metadata and beginning data validation.") + + if not score_set.target_genes: + job_manager.update_progress(100, 100, "Score set has no targets; cannot create variants.") + logger.warning( + msg="No targets are associated with this score set; could not create variants.", + extra=job_manager.logging_context(), + ) + raise ValueError("Can't create variants when score set has no targets.") + + validated_scores, validated_counts, validated_score_columns_metadata, validated_count_columns_metadata = ( + validate_and_standardize_dataframe_pair( + scores_df=scores_df, + counts_df=counts_df, + score_columns_metadata=score_columns_metadata, + count_columns_metadata=count_columns_metadata, + targets=score_set.target_genes, + hdp=hdp, + ) + ) + + job_manager.update_progress(80, 100, "Data validation complete; creating variants in database.") + + score_set.dataset_columns = { + "score_columns": columns_for_dataset(validated_scores), + "count_columns": columns_for_dataset(validated_counts), + "score_columns_metadata": validated_score_columns_metadata + if validated_score_columns_metadata is not None + else {}, + "count_columns_metadata": validated_count_columns_metadata + if validated_count_columns_metadata is not None + else {}, + } + + # 
Delete variants after validation occurs so we don't overwrite them in the case of a bad update.
+        if score_set.variants:
+            existing_variants = job_manager.db.scalars(
+                select(Variant.id).where(Variant.score_set_id == score_set.id)
+            ).all()
+            job_manager.db.execute(delete(MappedVariant).where(MappedVariant.variant_id.in_(existing_variants)))
+            job_manager.db.execute(delete(Variant).where(Variant.id.in_(existing_variants)))
+
+            job_manager.save_to_context({"deleted_variants": len(existing_variants)})
+            score_set.num_variants = 0
+
+            logger.info(msg="Deleted existing variants from score set.", extra=job_manager.logging_context())
+
+        job_manager.db.flush()
+        job_manager.db.refresh(score_set)
+
+        variants_data = create_variants_data(validated_scores, validated_counts, None)
+        create_variants(job_manager.db, score_set, variants_data)
+
+    except Exception as e:
+        job_manager.db.rollback()
+        score_set.processing_state = ProcessingState.failed
+        score_set.mapping_state = MappingState.not_attempted
+
+        # Capture exception details in score set processing errors for all exceptions.
+        score_set.processing_errors = {"exception": str(e), "detail": []}
+        # ValidationErrors arise from problematic input data; capture their details specifically.
+        if isinstance(e, ValidationError):
+            score_set.processing_errors["detail"] = e.triggering_exceptions
+
+        if score_set.num_variants:
+            score_set.processing_errors["exception"] = (
+                f"Update failed, variants were not updated. {score_set.processing_errors.get('exception', '')}"
+            )
+
+        job_manager.save_to_context(
+            {
+                "processing_state": score_set.processing_state.name,
+                "mapping_state": score_set.mapping_state.name,
+                **format_raised_exception_info_as_dict(e),
+                "created_variants": 0,
+            }
+        )
+        job_manager.update_progress(100, 100, "Variant creation job failed due to an internal error.")
+        logger.error(
+            msg="Encountered an internal exception while processing variants.", extra=job_manager.logging_context()
+        )
+
+        return {"status": "failed" if isinstance(e, ValidationError) else "exception", "data": {}, "exception": e}
+
+    else:
+        score_set.processing_state = ProcessingState.success
+        score_set.mapping_state = MappingState.queued
+        score_set.processing_errors = null()
+
+        job_manager.save_to_context(
+            {
+                "processing_state": score_set.processing_state.name,
+                "mapping_state": score_set.mapping_state.name,
+                "created_variants": score_set.num_variants,
+            }
+        )
+
+    finally:
+        job_manager.db.add(score_set)
+        job_manager.db.flush()
+        job_manager.db.refresh(score_set)
+
+    # NOTE: Returning from inside the finally block would silently override the failure
+    # results returned by the except branch above, so the success return lives here instead.
+    job_manager.update_progress(100, 100, "Completed variant creation job.")
+    logger.info(msg="Added new variants to score set.", extra=job_manager.logging_context())
+    return {"status": "ok", "data": {}, "exception": None}
diff --git a/src/mavedb/worker/jobs/variant_processing/mapping.py b/src/mavedb/worker/jobs/variant_processing/mapping.py
new file mode 100644
index 00000000..eee55a32
--- /dev/null
+++ b/src/mavedb/worker/jobs/variant_processing/mapping.py
@@ -0,0 +1,318 @@
+"""Variant mapping jobs using VRS (Variant Representation Specification).
+
+This module handles the mapping of variants to standardized genomic coordinates
+using the VRS mapping service. It includes queue management, retry logic,
+and coordination with downstream services like ClinGen and UniProt.
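+
+Mapping itself runs in a worker thread pool executor so that the blocking VRS call
+does not stall the ARQ event loop.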
+""" + +import asyncio +import functools +import logging +from datetime import date +from typing import Any + +from sqlalchemy import cast, null, select +from sqlalchemy.dialects.postgresql import JSONB + +from mavedb.data_providers.services import vrs_mapper +from mavedb.lib.annotation_status_manager import AnnotationStatusManager +from mavedb.lib.exceptions import ( + NoMappedVariantsError, + NonexistentMappingReferenceError, + NonexistentMappingResultsError, + NonexistentMappingScoresError, +) +from mavedb.lib.logging.context import format_raised_exception_info_as_dict +from mavedb.lib.mapping import ANNOTATION_LAYERS, EXCLUDED_PREMAPPED_ANNOTATION_KEYS +from mavedb.lib.slack import send_slack_error +from mavedb.lib.variants import get_hgvs_from_post_mapped +from mavedb.models.enums.annotation_type import AnnotationType +from mavedb.models.enums.job_pipeline import AnnotationStatus +from mavedb.models.enums.mapping_state import MappingState +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.score_set import ScoreSet +from mavedb.models.user import User +from mavedb.models.variant import Variant +from mavedb.worker.jobs.utils.setup import validate_job_params +from mavedb.worker.lib.decorators.pipeline_management import with_pipeline_management +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.types import JobResultData + +logger = logging.getLogger(__name__) + + +@with_pipeline_management +async def map_variants_for_score_set(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData: + """Map variants for a given score set using VRS.""" + # Handle everything prior to score set fetch in an outer layer. Any issues prior to + # fetching the score set should fail the job outright and we will be unable to set + # a processing state on the score set itself. + + job = job_manager.get_job() + + _job_required_params = [ + "score_set_id", + "correlation_id", + "updater_id", + ] + validate_job_params(_job_required_params, job) + + # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. 
+ score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one() # type: ignore + + # Handle everything within try/except to persist appropriate mapping state + try: + correlation_id = job.job_params["correlation_id"] # type: ignore + updater_id = job.job_params["updater_id"] # type: ignore + updated_by = job_manager.db.scalars(select(User).where(User.id == updater_id)).one() + + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "map_variants_for_score_set", + "resource": score_set.urn, + "correlation_id": correlation_id, + } + ) + job_manager.update_progress(0, 100, "Starting variant mapping job.") + logger.info(msg="Started variant mapping job", extra=job_manager.logging_context()) + + # TODO#372: non-nullable URNs + if not score_set.urn: # pragma: no cover + raise ValueError("Score set URN is required for variant mapping.") + + # Setup score set state for mapping + score_set.mapping_state = MappingState.processing + score_set.mapping_errors = null() + score_set.modified_by = updated_by + score_set.modification_date = date.today() + + job_manager.db.add(score_set) + job_manager.db.flush() + + job_manager.save_to_context({"mapping_state": score_set.mapping_state.name}) + job_manager.update_progress(10, 100, "Score set prepared for variant mapping.") + logger.debug(msg="Score set prepared for variant mapping.", extra=job_manager.logging_context()) + + # Do not block Worker event loop during mapping, see: https://arq-docs.helpmanual.io/#synchronous-jobs. + vrs = vrs_mapper() + blocking = functools.partial(vrs.map_score_set, score_set.urn) + loop = asyncio.get_running_loop() + + mapping_results = None + + logger.debug(msg="Mapping variants using VRS mapping service.", extra=job_manager.logging_context()) + job_manager.update_progress(30, 100, "Mapping variants using VRS mapping service.") + mapping_results = await loop.run_in_executor(ctx["pool"], blocking) + + logger.debug(msg="Done mapping variants.", extra=job_manager.logging_context()) + job_manager.update_progress(80, 100, "Processing mapped variants.") + + ## Check our assumptions about mapping results and handle errors appropriately. 
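+        # Three invariants are checked in order below: the mapping service returned a result
+        # payload at all, the payload contains mapped scores, and it contains reference
+        # sequence metadata. Each violation rolls back the session, records a mapping error
+        # on the score set, and raises a typed exception handled by the except blocks below.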
+ + # Ensure we have mapping results + if not mapping_results: + job_manager.db.rollback() + score_set.mapping_errors = {"error_message": "Mapping results were not returned from VRS mapping service."} + job_manager.update_progress(100, 100, "Variant mapping failed due to missing results.") + logger.error( + msg="Mapping results were not returned from VRS mapping service.", extra=job_manager.logging_context() + ) + raise NonexistentMappingResultsError("Mapping results were not returned from VRS mapping service.") + + # Ensure we have mapped scores + mapped_scores = mapping_results.get("mapped_scores") + if not mapped_scores: + job_manager.db.rollback() + score_set.mapping_errors = {"error_message": mapping_results.get("error_message")} + job_manager.update_progress(100, 100, "Variant mapping failed; no variants were mapped.") + logger.error(msg="No variants were mapped for this score set.", extra=job_manager.logging_context()) + raise NonexistentMappingScoresError("No variants were mapped for this score set.") + + # Ensure we have reference metadata + reference_metadata = mapping_results.get("reference_sequences") + if not reference_metadata: + job_manager.db.rollback() + score_set.mapping_errors = {"error_message": "Reference metadata missing from mapping results."} + job_manager.update_progress(100, 100, "Variant mapping failed due to missing reference metadata.") + logger.error(msg="Reference metadata missing from mapping results.", extra=job_manager.logging_context()) + raise NonexistentMappingReferenceError("Reference metadata missing from mapping results.") + + # Process and store mapped variants + for target_gene_identifier in reference_metadata: + target_gene = next( + (target_gene for target_gene in score_set.target_genes if target_gene.name == target_gene_identifier), + None, + ) + + if not target_gene: + raise ValueError( + f"Target gene {target_gene_identifier} not found in database for score set {score_set.urn}." 
+ ) + + job_manager.save_to_context({"processing_target_gene": target_gene.id}) + logger.debug(f"Processing target gene {target_gene.name}.", extra=job_manager.logging_context()) + + # allow for multiple annotation layers + pre_mapped_metadata: dict[str, Any] = {} + post_mapped_metadata: dict[str, Any] = {} + + # add gene-level info + gene_info = reference_metadata[target_gene_identifier].get("gene_info") + if gene_info: + target_gene.mapped_hgnc_name = gene_info.get("hgnc_symbol") + post_mapped_metadata["hgnc_name_selection_method"] = gene_info.get("selection_method") + + job_manager.save_to_context({"mapped_hgnc_name": target_gene.mapped_hgnc_name}) + logger.debug("Added mapped HGNC name to target gene.", extra=job_manager.logging_context()) + + # add annotation layer info + for annotation_layer in reference_metadata[target_gene_identifier]["layers"]: + layer_premapped = reference_metadata[target_gene_identifier]["layers"][annotation_layer].get( + "computed_reference_sequence" + ) + if layer_premapped: + pre_mapped_metadata[ANNOTATION_LAYERS[annotation_layer]] = { + k: layer_premapped[k] + for k in set(list(layer_premapped.keys())) - EXCLUDED_PREMAPPED_ANNOTATION_KEYS + } + job_manager.save_to_context({"pre_mapped_layer_exists": True}) + + layer_postmapped = reference_metadata[target_gene_identifier]["layers"][annotation_layer].get( + "mapped_reference_sequence" + ) + if layer_postmapped: + post_mapped_metadata[ANNOTATION_LAYERS[annotation_layer]] = layer_postmapped + job_manager.save_to_context({"post_mapped_layer_exists": True}) + + logger.debug( + f"Added annotation layer mapping metadata for {annotation_layer}.", + extra=job_manager.logging_context(), + ) + + target_gene.pre_mapped_metadata = cast(pre_mapped_metadata, JSONB) + target_gene.post_mapped_metadata = cast(post_mapped_metadata, JSONB) + job_manager.db.add(target_gene) + logger.debug("Added mapping metadata to target gene.", extra=job_manager.logging_context()) + + total_variants = len(mapped_scores) + job_manager.save_to_context({"total_variants_to_process": total_variants}) + job_manager.update_progress(90, 100, "Saving mapped variants.") + + successful_mapped_variants = 0 + annotation_manager = AnnotationStatusManager(job_manager.db) + for mapped_score in mapped_scores: + variant_urn = mapped_score.get("mavedb_id") + variant = job_manager.db.scalars(select(Variant).where(Variant.urn == variant_urn)).one() + + job_manager.save_to_context({"processing_variant": variant.id}) + logger.debug(f"Processing variant {variant.id}.", extra=job_manager.logging_context()) + + # there should only be one current mapped variant per variant id, so update old mapped variant to current = false + existing_mapped_variant = ( + job_manager.db.query(MappedVariant) + .filter(MappedVariant.variant_id == variant.id, MappedVariant.current.is_(True)) + .one_or_none() + ) + + if existing_mapped_variant: + job_manager.save_to_context({"existing_mapped_variant": existing_mapped_variant.id}) + existing_mapped_variant.current = False + job_manager.db.add(existing_mapped_variant) + logger.debug(msg="Set existing mapped variant to current = false.", extra=job_manager.logging_context()) + + annotation_was_successful = mapped_score.get("pre_mapped") and mapped_score.get("post_mapped") + if annotation_was_successful: + successful_mapped_variants += 1 + job_manager.save_to_context({"successful_mapped_variants": successful_mapped_variants}) + + mapped_variant = MappedVariant( + pre_mapped=mapped_score.get("pre_mapped", null()), + 
post_mapped=mapped_score.get("post_mapped", null()), + variant_id=variant.id, + modification_date=date.today(), + mapped_date=mapping_results["mapped_date_utc"], + vrs_version=mapped_score.get("vrs_version", null()), + mapping_api_version=mapping_results["dcd_mapping_version"], + error_message=mapped_score.get("error_message", null()), + current=True, + ) + + annotation_manager.add_annotation( + variant_id=variant.id, # type: ignore + annotation_type=AnnotationType.VRS_MAPPING, + version=mapped_score.get("vrs_version", null()), + status=AnnotationStatus.SUCCESS if annotation_was_successful else AnnotationStatus.FAILED, + annotation_data={ + "error_message": mapped_score.get("error_message", null()), + "job_run_id": job.id, + "success_data": { + "mapped_assay_level_hgvs": get_hgvs_from_post_mapped(mapped_score.get("post_mapped", {})), + }, + }, + current=True, + ) + + job_manager.db.add(mapped_variant) + logger.debug(msg="Added new mapped variant to session.", extra=job_manager.logging_context()) + + if successful_mapped_variants == 0: + score_set.mapping_state = MappingState.failed + score_set.mapping_errors = {"error_message": "All variants failed to map."} + elif successful_mapped_variants < total_variants: + score_set.mapping_state = MappingState.incomplete + else: + score_set.mapping_state = MappingState.complete + + job_manager.save_to_context( + { + "successful_mapped_variants": successful_mapped_variants, + "mapping_state": score_set.mapping_state.name, + "mapping_errors": score_set.mapping_errors, + "inserted_mapped_variants": len(mapped_scores), + } + ) + except (NonexistentMappingResultsError, NonexistentMappingScoresError, NonexistentMappingReferenceError) as e: + send_slack_error(e) + logging_context = {**job_manager.logging_context(), **format_raised_exception_info_as_dict(e)} + logger.error(msg="Known error during variant mapping.", extra=logging_context) + + score_set.mapping_state = MappingState.failed + # These exceptions have already set mapping_errors appropriately + + return {"status": "exception", "data": {}, "exception": e} + + except Exception as e: + send_slack_error(e) + logging_context = {**job_manager.logging_context(), **format_raised_exception_info_as_dict(e)} + logger.error(msg="Encountered an unexpected error while parsing mapped variants.", extra=logging_context) + + job_manager.db.rollback() + + score_set.mapping_state = MappingState.failed + if not score_set.mapping_errors: + score_set.mapping_errors = { + "error_message": f"Encountered an unexpected error while parsing mapped variants. This job will be retried up to {job.max_retries} times (this was attempt {job.retry_count})." 
+ } + job_manager.update_progress(100, 100, "Variant mapping failed due to an unexpected error.") + + return {"status": "exception", "data": {}, "exception": e} + + finally: + job_manager.db.add(score_set) + job_manager.db.flush() + + logger.info(msg="Inserted mapped variants into db.", extra=job_manager.logging_context()) + job_manager.update_progress(100, 100, "Finished processing mapped variants.") + + if successful_mapped_variants == 0: + logger.error(msg="No variants were successfully mapped.", extra=job_manager.logging_context()) + return { + "status": "failed", + "data": {}, + "exception": NoMappedVariantsError("No variants were successfully mapped."), + } + + logger.info(msg="Variant mapping job completed successfully.", extra=job_manager.logging_context()) + return {"status": "ok", "data": {}, "exception": None} diff --git a/src/mavedb/worker/jobs/variant_processing/py.typed b/src/mavedb/worker/jobs/variant_processing/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/mavedb/worker/jobs_overview.md b/src/mavedb/worker/jobs_overview.md new file mode 100644 index 00000000..ec14b421 --- /dev/null +++ b/src/mavedb/worker/jobs_overview.md @@ -0,0 +1,32 @@ +# Job System Overview + +The ARQ worker job system in MaveDB provides a robust, scalable, and auditable framework for background processing, data management, and integration with external services. It is designed to support both simple jobs and complex pipelines with dependency management, error handling, and progress tracking. + +## Key Concepts + +- **Job**: A discrete unit of work, typically implemented as an async function, executed by the ARQ worker. +- **Pipeline**: A sequence of jobs with defined dependencies, managed as a single workflow. +- **JobRun**: A database record tracking the execution state, progress, and results of a job. +- **JobManager**: A class responsible for managing the lifecycle and state transitions of a single job. +- **PipelineManager**: A class responsible for coordinating pipelines, managing dependencies, and updating pipeline status. +- **Decorators**: Utilities that add lifecycle management, error handling, and audit guarantees to job functions. + +## Directory Structure + +- `jobs/` — Entrypoints and registry for all ARQ worker jobs. +- `jobs/data_management/`, `jobs/external_services/`, `jobs/variant_processing/`, etc. — Job implementations grouped by domain. +- `lib/decorators/` — Decorators for job and pipeline management. +- `lib/managers/` — JobManager, PipelineManager, and related utilities. + +## Job Lifecycle + +1. **Job Registration**: All available jobs are registered in `jobs/registry.py` for ARQ configuration. +2. **Job Execution**: Jobs are executed by the ARQ worker, with decorators ensuring audit, error handling, and state management. +3. **State Tracking**: Each job run is tracked in the database via a `JobRun` record. +4. **Pipeline Coordination**: For jobs that are part of a pipeline, the `PipelineManager` coordinates dependencies and status. + +## When to Add a Job +- When you need background processing, integration with external APIs, or scheduled/cron tasks. +- When you want robust error handling, progress tracking, and auditability for long-running or critical operations. + +See the following sections for details on decorators, managers, and best practices. 
diff --git a/src/mavedb/worker/lib/__init__.py b/src/mavedb/worker/lib/__init__.py
new file mode 100644
index 00000000..8ab17989
--- /dev/null
+++ b/src/mavedb/worker/lib/__init__.py
@@ -0,0 +1,7 @@
+"""
+Worker library modules for job management and pipeline coordination.
+"""
+
+from .managers import JobManager, PipelineManager
+
+__all__ = ["JobManager", "PipelineManager"]
diff --git a/src/mavedb/worker/lib/decorators/__init__.py b/src/mavedb/worker/lib/decorators/__init__.py
new file mode 100644
index 00000000..4bef68d5
--- /dev/null
+++ b/src/mavedb/worker/lib/decorators/__init__.py
@@ -0,0 +1,28 @@
+"""
+Decorator utilities for job and pipeline management.
+
+This module exposes decorators for managing job and pipeline lifecycle hooks, error handling,
+and logging in worker functions. Use these decorators to ensure consistent state management
+and observability for background jobs and pipelines.
+
+Available decorators:
+- with_job_management: Handles job context and state transitions
+- with_pipeline_management: Handles pipeline context and coordination in addition to job management
+
+Example usage::
+    from mavedb.worker.lib.decorators import with_job_management, with_pipeline_management
+
+    @with_pipeline_management
+    async def my_worker_function_in_a_pipeline(...):
+        ...
+
+    @with_job_management
+    async def my_standalone_job_function(...):
+        ...
+"""
+
+from .job_guarantee import with_guaranteed_job_run_record
+from .job_management import with_job_management
+from .pipeline_management import with_pipeline_management
+
+__all__ = ["with_job_management", "with_pipeline_management", "with_guaranteed_job_run_record"]
diff --git a/src/mavedb/worker/lib/decorators/job_guarantee.py b/src/mavedb/worker/lib/decorators/job_guarantee.py
new file mode 100644
index 00000000..d93c08d6
--- /dev/null
+++ b/src/mavedb/worker/lib/decorators/job_guarantee.py
@@ -0,0 +1,98 @@
+"""
+Job Guarantee Decorator - Ensures a JobRun record is persisted before job execution.
+
+This decorator guarantees that a corresponding JobRun record is created and tracked for the decorated
+function in the database before execution begins. It is designed to be stacked before managed job
+decorators (such as with_job_management) to provide a consistent audit trail and robust error handling
+for all job entrypoints, including cron-triggered jobs.
+
+NOTE
+- This decorator must be applied before any job management decorators.
+- This decorator is not supported as part of pipeline management; stacking it
+  with pipeline management decorators is not allowed and it should only be used with
+  standalone jobs.
+
+Features:
+- Persists a JobRun record with its job_type and function name
+- Integrates cleanly with managed job and pipeline decorators
+
+Example:
+    @with_guaranteed_job_run_record("cron_job")
+    @with_job_management
+    async def my_cron_job(ctx, ...):
+        ...
+"""
+
+import functools
+from typing import Any, Awaitable, Callable, TypeVar
+
+from sqlalchemy.orm import Session
+
+from mavedb import __version__
+from mavedb.models.enums.job_pipeline import JobStatus
+from mavedb.models.job_run import JobRun
+from mavedb.worker.lib.decorators.utils import ensure_ctx, ensure_session_ctx, is_test_mode
+from mavedb.worker.lib.managers.types import JobResultData
+
+F = TypeVar("F", bound=Callable[..., Awaitable[Any]])
+
+
+def with_guaranteed_job_run_record(job_type: str) -> Callable[[F], F]:
+    """
+    Async decorator to ensure a JobRun record is created and persisted before executing the job function.
+    Should be applied before the managed job decorator.
+
+    Args:
+        job_type (str): The type/category of the job (e.g., "cron_job", "data_processing").
+
+    Returns:
+        Decorated async function with job run persistence guarantee.
+
+    Example:
+        ```
+        @with_guaranteed_job_run_record("cron_job")
+        @with_job_management
+        async def my_cron_job(ctx, ...):
+            ...
+        ```
+    """

+    def decorator(func: F) -> F:
+        @functools.wraps(func)
+        async def async_wrapper(*args, **kwargs):
+            with ensure_session_ctx(ctx=ensure_ctx(args)):
+                # No-op in test mode
+                if is_test_mode():
+                    return await func(*args, **kwargs)
+
+                # The job id must be passed as the second argument to the wrapped function.
+                job = _create_job_run(job_type, func, args, kwargs)
+                args = list(args)
+                args.insert(1, job.id)
+                args = tuple(args)
+
+                return await func(*args, **kwargs)
+
+        return async_wrapper  # type: ignore
+
+    return decorator
+
+
+def _create_job_run(job_type: str, func: Callable[..., Awaitable[JobResultData]], args: tuple, kwargs: dict) -> JobRun:
+    """
+    Creates and persists a JobRun record for a function before job execution.
+    """
+    # Extract context (implicit first argument by ARQ convention)
+    ctx = ensure_ctx(args)
+    db: Session = ctx["db"]
+
+    job_run = JobRun(
+        job_type=job_type,
+        job_function=func.__name__,
+        status=JobStatus.PENDING,
+        mavedb_version=__version__,
+    )  # type: ignore[call-arg]
+    db.add(job_run)
+    db.commit()
+
+    return job_run
diff --git a/src/mavedb/worker/lib/decorators/job_management.py b/src/mavedb/worker/lib/decorators/job_management.py
new file mode 100644
index 00000000..5b8a8ca0
--- /dev/null
+++ b/src/mavedb/worker/lib/decorators/job_management.py
@@ -0,0 +1,185 @@
+"""
+Managed Job Decorator - Unified decorator for complete job lifecycle management.
+
+Provides automatic job lifecycle tracking for async ARQ worker functions.
+Includes JobManager injection for advanced operations and robust error handling.
+"""
+
+import functools
+import inspect
+import logging
+from typing import Any, Awaitable, Callable, TypeVar, cast
+
+from arq import ArqRedis
+from sqlalchemy.orm import Session
+
+from mavedb.lib.slack import send_slack_error
+from mavedb.models.enums.job_pipeline import JobStatus
+from mavedb.worker.lib.decorators.utils import ensure_ctx, ensure_job_id, ensure_session_ctx, is_test_mode
+from mavedb.worker.lib.managers import JobManager
+from mavedb.worker.lib.managers.types import JobResultData
+
+logger = logging.getLogger(__name__)
+
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+def with_job_management(func: F) -> F:
+    """
+    Decorator that adds automatic job lifecycle management to ARQ worker functions.
+
+    Features:
+    - Job start/completion tracking with error handling
+    - JobManager injection for advanced operations
+    - Robust error handling with guaranteed state persistence
+
+    The decorator injects a 'job_manager' parameter into the function that provides
+    access to progress updates and the underlying JobManager.
+
+    Example:
+        ```
+        @with_job_management
+        async def my_job_function(ctx, param1, param2, job_manager: JobManager):
+            job_manager.update_progress(10, message="Starting work")
+
+            # Access JobManager for advanced operations
+            job_info = job_manager.get_job_info()
+
+            # Do work...
+            job_manager.update_progress(50, message="Halfway done")
+
+            # More work...
+ job_manager.update_progress(100, message="Complete") + + return {"result": "success"} + ``` + + Args: + func: The async function to decorate + + Returns: + Decorated async function with lifecycle management + """ + if not inspect.iscoroutinefunction(func): # pragma: no cover + raise ValueError("with_job_management decorator can only be applied to async functions") + + @functools.wraps(func) + async def async_wrapper(*args, **kwargs): + with ensure_session_ctx(ctx=ensure_ctx(args)): + # No-op in test mode + if is_test_mode(): + return await func(*args, **kwargs) + + return await _execute_managed_job(func, args, kwargs) + + return cast(F, async_wrapper) + + +async def _execute_managed_job(func: Callable[..., Awaitable[JobResultData]], args: tuple, kwargs: dict) -> Any: + """ + Execute a managed ARQ job with full lifecycle tracking. + + This function handles the complete job lifecycle including: + - JobManager initialization from context + - Job start tracking + - ProgressTracker injection + - Async function execution + - Job completion tracking + - Error handling and cleanup + + Args: + func: Async function to execute + args: Function arguments + kwargs: Function keyword arguments + + Returns: + Function result + + Raises: + Exception: Re-raises any exception after proper job failure tracking + """ + try: + ctx = ensure_ctx(args) + db_session: Session = ctx["db"] + job_id = ensure_job_id(args) + + if "redis" not in ctx: + raise ValueError("Redis connection not found in job context") + redis_pool: ArqRedis = ctx["redis"] + except Exception as e: + logger.critical(f"Failed to initialize job management context: {e}") + send_slack_error(e) + raise + + try: + # Initialize JobManager + job_manager = JobManager(db_session, redis_pool, job_id) + + # Inject the job manager into kwargs for access within the function + kwargs["job_manager"] = job_manager + + # Mark job as started and persist state + job_manager.start_job() + db_session.commit() + + # Execute the async function + result = await func(*args, **kwargs) + + # Move job to final state based on result + if result.get("status") == "failed" or result.get("exception"): + # Exception info should always be present for failed jobs + job_manager.fail_job(result=result, error=result["exception"]) # type: ignore[arg-type] + send_slack_error(result["exception"]) + + elif result.get("status") == "skipped": + job_manager.skip_job(result=result) + else: + job_manager.succeed_job(result=result) + db_session.commit() + + # If the job is not marked as succeeded, check if we should retry + if job_manager.get_job_status() != JobStatus.SUCCEEDED and job_manager.should_retry(): + job_manager.prepare_retry(reason="Job did not complete successfully") + db_session.commit() + + return result + + except Exception as e: + # Prioritize salvaging lifecycle state + try: + db_session.rollback() + + # Build failure result data + result = {"status": "exception", "data": {}, "exception": e} + + # Mark job as failed + job_manager.fail_job(result=result, error=e) + db_session.commit() + + # TODO: Decide on retry logic based on exception type and result. + if job_manager.should_retry(): + # Prepare job for retry and persist state + job_manager.prepare_retry(reason=str(e)) + db_session.commit() + + # short circuit raising the exception. We indicate to the caller + # we did encounter a terminal failure and coordination should proceed. 
+                return result
+
+        except Exception as inner_e:
+            logger.critical(f"Failed to mark job {job_id} as failed: {inner_e}")
+
+            # Notify separately about inner failure, which affects job persistence
+            send_slack_error(inner_e)
+
+            # Fall through to the finally block; the original failure is logged and
+            # reported there, and the failure result is still returned to the caller.
+        finally:
+            logger.error(f"Job {job_id} failed: {e}")
+
+            # Notify about the original exception
+            send_slack_error(e)
+
+        # Swallow the exception after alerting so ARQ can finish the job cleanly and log results.
+        # We don't mind that we lose ARQ's built-in job marking, since we perform our own job
+        # lifecycle management via with_job_management.
+        return result
diff --git a/src/mavedb/worker/lib/decorators/pipeline_management.py b/src/mavedb/worker/lib/decorators/pipeline_management.py
new file mode 100644
index 00000000..a181c72e
--- /dev/null
+++ b/src/mavedb/worker/lib/decorators/pipeline_management.py
@@ -0,0 +1,190 @@
+"""
+Managed Pipeline Decorator - Unified decorator for complete pipeline lifecycle management.
+
+Provides automatic pipeline coordination for async ARQ worker functions, layered on top
+of the job lifecycle tracking provided by with_job_management. Includes robust error
+handling and alerting on failures.
+"""
+
+import functools
+import inspect
+import logging
+from typing import Any, Awaitable, Callable, TypeVar, cast
+
+from arq import ArqRedis
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from mavedb.lib.slack import send_slack_error
+from mavedb.models.enums.job_pipeline import PipelineStatus
+from mavedb.models.job_run import JobRun
+from mavedb.worker.lib.decorators import with_job_management
+from mavedb.worker.lib.decorators.utils import ensure_ctx, ensure_job_id, ensure_session_ctx, is_test_mode
+from mavedb.worker.lib.managers import PipelineManager
+from mavedb.worker.lib.managers.types import JobResultData
+
+logger = logging.getLogger(__name__)
+
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+def with_pipeline_management(func: F) -> F:
+    """
+    Decorator that adds automatic pipeline lifecycle management to ARQ worker functions. Practically,
+    this means calling `PipelineManager.coordinate_pipeline()` after the decorated function completes.
+
+    This decorator performs no pipeline coordination prior to function execution; it only
+    coordinates the pipeline after the function has run (whether it succeeds or fails).
+    As a result, this decorator is best suited for jobs that represent discrete steps within a pipeline.
+    Pipelines are expected to be pre-defined and associated with jobs prior to execution and should be transitioned
+    to a running state by other means (e.g. a dedicated pipeline starter job). Attempting to start pipelines
+    within this decorator is not supported, and doing so may lead to unexpected behavior.
+
+    Because pipeline management depends on job management, this decorator is built on top of the
+    `with_job_management` decorator.
+
+    This decorator may be added to jobs that may or may not belong to a pipeline. If the job does not
+    belong to a pipeline, the decorator simply skips the pipeline coordination steps. Although pipeline
+    membership is optional, the decorator always enforces job lifecycle management via
+    `with_job_management`.
+
+    Features:
+    - Pipeline lifecycle tracking
+    - Job lifecycle tracking via with_job_management
+    - Robust error handling, logging, and alerting on failures
+
+    Example:
+        @with_pipeline_management
+        async def my_job_function(ctx, param1, param2):
+            ... job logic ...
+ + On decorator exit, pipeline coordination is attempted. + + Args: + func: The async function to decorate + + Returns: + Decorated async function with lifecycle management + """ + if not inspect.iscoroutinefunction(func): # pragma: no cover + raise ValueError("with_pipeline_management decorator can only be applied to async functions") + + # Wrap the function with job management. It isn't as simple as stacking decorators + # as we can only call job management after setting up pipeline management. + + @functools.wraps(func) + async def async_wrapper(*args, **kwargs): + with ensure_session_ctx(ctx=ensure_ctx(args)): + # No-op in test mode + if is_test_mode(): + return await func(*args, **kwargs) + + return await _execute_managed_pipeline(func, args, kwargs) + + return cast(F, async_wrapper) + + +async def _execute_managed_pipeline(func: Callable[..., Awaitable[JobResultData]], args: tuple, kwargs: dict) -> Any: + """ + Execute the managed pipeline function with lifecycle management. + + Args: + func: The async function to execute. + args: Positional arguments for the function. + kwargs: Keyword arguments for the function. + + Returns: + Any: The result of the function execution. + + Raises: + Exception: Propagates any exception raised during function execution. + """ + try: + ctx = ensure_ctx(args) + job_id = ensure_job_id(args) + db_session: Session = ctx["db"] + + if "redis" not in ctx: + raise ValueError("Redis connection not found in pipeline context") + redis_pool: ArqRedis = ctx["redis"] + except Exception as e: + logger.critical(f"Failed to initialize pipeline management context: {e}") + send_slack_error(e) + raise + + pipeline_manager = None + pipeline_id = None + try: + # Attempt to load the pipeline ID from the job. + # - If pipeline_id is not None, initialize PipelineManager + # - If None, skip pipeline coordination. We do not enforce every job to belong to a pipeline. + # - If error occurs, handle below + pipeline_id = db_session.execute(select(JobRun.pipeline_id).where(JobRun.id == job_id)).scalar_one() + if pipeline_id: + pipeline_manager = PipelineManager(db=db_session, redis=redis_pool, pipeline_id=pipeline_id) + + logger.info(f"Pipeline ID for job {job_id} is {pipeline_id}. Coordinating pipeline.") + + # If the pipeline is still in the created state, start it now. From this context, + # we do not wish to coordinate the pipeline. Doing so would result in the current + # job being re-queued before it has been marked as running, leading to potential state + # inconsistencies. + if pipeline_manager and pipeline_manager.get_pipeline_status() == PipelineStatus.CREATED: + await pipeline_manager.start_pipeline(coordinate=False) + db_session.commit() + + logger.info(f"Pipeline {pipeline_id} associated with job {job_id} started successfully") + + # Wrap the function with job management, then execute. 
+        # This ensures both:
+        #  - Job lifecycle management is nested within pipeline management
+        #  - Exceptions from the job management layer are caught here for pipeline coordination
+        job_managed_func = with_job_management(func)
+        result = await job_managed_func(*args, **kwargs)
+
+        # Attempt to coordinate pipeline next steps after successful job execution
+        if pipeline_manager:
+            await pipeline_manager.coordinate_pipeline()
+
+            # Commit any changes made during pipeline coordination
+            db_session.commit()
+
+            logger.info(f"Pipeline {pipeline_id} associated with job {job_id} coordinated successfully")
+        else:
+            logger.info(f"No pipeline associated with job {job_id}; skipping coordination")
+
+        return result
+
+    except Exception as e:
+        try:
+            # Rollback any uncommitted changes
+            db_session.rollback()
+
+            # Attempt one final coordination to clean up any stubborn pipeline state
+            if pipeline_manager:
+                await pipeline_manager.coordinate_pipeline()
+
+                # Commit any changes made during final coordination
+                db_session.commit()
+
+        except Exception as inner_e:
+            logger.critical(
+                f"Unable to perform cleanup coordination on pipeline {pipeline_id} associated with job {job_id} after error: {inner_e}"
+            )
+
+            # Notify about the internal error, as it indicates a serious problem with pipeline state persistence
+            send_slack_error(inner_e)
+
+            # No further work is needed here. The notification hooks below alert on the original
+            # failure, and result generation proceeds as normal so the job can still be logged.
+        finally:
+            logger.error(f"Pipeline {pipeline_id} associated with job {job_id} failed to coordinate: {e}")
+
+            # Build job result data for failure
+            result = {"status": "failed", "data": {}, "exception": e}
+
+            # Notify about the original failure
+            send_slack_error(e)
+
+        # Swallow the exception after alerting so ARQ can finish the job cleanly and log results.
+        # We don't mind that we lose ARQ's built-in job marking, since we perform our own job
+        # lifecycle management via with_job_management.
+        return result
diff --git a/src/mavedb/worker/lib/decorators/py.typed b/src/mavedb/worker/lib/decorators/py.typed
new file mode 100644
index 00000000..e69de29b
diff --git a/src/mavedb/worker/lib/decorators/utils.py b/src/mavedb/worker/lib/decorators/utils.py
new file mode 100644
index 00000000..4315b6e0
--- /dev/null
+++ b/src/mavedb/worker/lib/decorators/utils.py
@@ -0,0 +1,53 @@
+import os
+from contextlib import contextmanager
+
+from mavedb.db.session import db_session
+
+
+def is_test_mode() -> bool:
+    """Check if the application is running in test mode based on the MAVEDB_TEST_MODE environment variable.
+
+    Returns:
+        bool: True if in test mode, False otherwise.
+    """
+    # Although not ideal, we use an environment variable to detect whether
+    # the application is in test mode. In the context of decorators, test
+    # mode makes them no-ops to facilitate unit testing without side effects.
+    #
+    # This is necessary because decorators are applied at import time, making
+    # it difficult to mock their behavior in tests when they must be imported
+    # up front and provided to the ARQ worker.
+    #
+    # This pattern allows us to control decorator behavior in tests without
+    # altering production code paths.
+    return os.getenv("MAVEDB_TEST_MODE") == "1"
+
+
+@contextmanager
+def ensure_session_ctx(ctx):
+    """Yield the session from ctx, creating (and later clearing) one if absent."""
+    if "db" in ctx and ctx["db"] is not None:
+        # No-op context manager
+        yield ctx["db"]
+    else:
+        with db_session() as session:
+            ctx["db"] = session
+            yield session
+            ctx["db"] = None  # Clear the reference once the managed session closes
+
+
+def ensure_ctx(args) -> dict:
+    # Extract context (first argument by ARQ convention)
+    if not args or len(args) < 1 or not isinstance(args[0], dict):
+        raise ValueError("Managed functions must receive context as first argument")
+
+    ctx = args[0]
+    return ctx
+
+
+def ensure_job_id(args) -> int:
+    # Extract job_id (second argument by MaveDB convention)
+    if not args or len(args) < 2 or not isinstance(args[1], int):
+        raise ValueError("Job ID not found in function arguments")
+
+    job_id = args[1]
+    return job_id
diff --git a/src/mavedb/worker/lib/managers/__init__.py b/src/mavedb/worker/lib/managers/__init__.py
new file mode 100644
index 00000000..b75eb40f
--- /dev/null
+++ b/src/mavedb/worker/lib/managers/__init__.py
@@ -0,0 +1,67 @@
+"""Manager classes and shared utilities for job and pipeline coordination.
+
+This package provides managers for job lifecycle and pipeline coordination,
+along with shared constants, exceptions, and types used across the worker system.
+
+Main Classes:
+    JobManager: Individual job lifecycle management
+    PipelineManager: Pipeline coordination and dependency management
+
+Shared Utilities:
+    Constants: Job statuses, timeouts, retry limits
+    Exceptions: Standardized error hierarchy
+    Types: TypedDict definitions and common type hints
+
+Example Usage:
+    >>> from mavedb.worker.lib.managers import JobManager, PipelineManager
+    >>> from mavedb.worker.lib.managers import JobStateError, TERMINAL_JOB_STATUSES
+    >>>
+    >>> job_manager = JobManager(db, redis, job_id)
+    >>> pipeline_manager = PipelineManager(db, redis, pipeline_id)
+    >>>
+    >>> # Individual job operations
+    >>> job_manager.start_job()
+    >>> job_manager.succeed_job({"output": "success"})
+    >>>
+    >>> # Pipeline coordination
+    >>> await pipeline_manager.coordinate_pipeline()
+"""
+
+# Base manager class
+from .base_manager import BaseManager
+
+# Commonly used constants
+from .constants import (
+    ACTIVE_JOB_STATUSES,
+    TERMINAL_JOB_STATUSES,
+)
+
+# Exception hierarchy
+from .exceptions import (
+    DatabaseConnectionError,
+    JobStateError,
+    JobTransitionError,
+    PipelineCoordinationError,
+)
+
+# Main manager classes
+from .job_manager import JobManager
+from .pipeline_manager import PipelineManager
+
+# Type definitions
+from .types import JobResultData, RetryHistoryEntry
+
+__all__ = [
+    # Main classes
+    "BaseManager",
+    "JobManager",
+    "PipelineManager",
+    # Constants
+    "ACTIVE_JOB_STATUSES",
+    "TERMINAL_JOB_STATUSES",
+    # Exceptions
+    "DatabaseConnectionError",
+    "JobStateError",
+    "JobTransitionError",
+    "PipelineCoordinationError",
+    # Types
+    "JobResultData",
+    "RetryHistoryEntry",
+]
+""" + +import logging +from abc import ABC +from typing import Optional + +from arq import ArqRedis +from sqlalchemy.orm import Session + +logger = logging.getLogger(__name__) + + +class BaseManager(ABC): + """Base class for all manager classes providing common interface. + + Provides standardized pattern for initializing a manager with database + and Redis connections. + + Features: + - Common initialization pattern + + Attributes: + db: SQLAlchemy database session for queries and transactions + redis: ARQ Redis client for job queue operations + """ + + def __init__(self, db: Session, redis: Optional[ArqRedis]): + """Initialize base manager with database and Redis connections. + + Args: + db: SQLAlchemy database session for job and pipeline queries + redis(Optional[ArqRedis]): ARQ Redis client for job queue operations + + Raises: + DatabaseConnectionError: Cannot connect to database + """ + self.db = db + self.redis = redis diff --git a/src/mavedb/worker/lib/managers/constants.py b/src/mavedb/worker/lib/managers/constants.py new file mode 100644 index 00000000..4eabd684 --- /dev/null +++ b/src/mavedb/worker/lib/managers/constants.py @@ -0,0 +1,56 @@ +"""Constants for job management and pipeline coordination. + +This module defines commonly used job status groupings that are used throughout +the job management system for state validation, dependency checking, and +pipeline coordination. +""" + +from mavedb.models.enums.job_pipeline import FailureCategory, JobStatus, PipelineStatus + +# Job status constants for common groupings +STARTABLE_JOB_STATUSES = [JobStatus.QUEUED, JobStatus.PENDING] +"""Job statuses that can be transitioned to RUNNING state.""" + +COMPLETED_JOB_STATUSES = [JobStatus.SUCCEEDED, JobStatus.FAILED] +"""Job statuses indicating finished execution (completed states).""" + +TERMINAL_JOB_STATUSES = [JobStatus.SUCCEEDED, JobStatus.FAILED, JobStatus.CANCELLED, JobStatus.SKIPPED] +"""Job statuses indicating finished execution (terminal states).""" + +CANCELLED_JOB_STATUSES = [JobStatus.CANCELLED, JobStatus.SKIPPED, JobStatus.FAILED] +"""Job statuses that should stop execution (termination conditions).""" + +RETRYABLE_JOB_STATUSES = [JobStatus.FAILED, JobStatus.CANCELLED, JobStatus.SKIPPED] +"""Job statuses that can be retried.""" + +ACTIVE_JOB_STATUSES = [JobStatus.PENDING, JobStatus.QUEUED, JobStatus.RUNNING] +"""Job statuses that can be cancelled/skipped when pipeline fails.""" + +RETRYABLE_FAILURE_CATEGORIES = ( + FailureCategory.NETWORK_ERROR, + FailureCategory.TIMEOUT, + FailureCategory.SERVICE_UNAVAILABLE, + # TODO: Add more retryable exception types as needed +) +"""Failure categories that are considered retryable errors.""" + +# Pipeline coordination constants +STARTABLE_PIPELINE_STATUSES = [PipelineStatus.PAUSED, PipelineStatus.CREATED] +"""Pipeline statuses that can be transitioned to RUNNING state.""" + +TERMINAL_PIPELINE_STATUSES = [ + PipelineStatus.SUCCEEDED, + PipelineStatus.FAILED, + PipelineStatus.PARTIAL, + PipelineStatus.CANCELLED, +] +"""Pipeline statuses indicating finished execution (terminal states).""" + +CANCELLED_PIPELINE_STATUSES = [PipelineStatus.CANCELLED, PipelineStatus.FAILED] +"""Pipeline statuses indicating the pipeline has been cancelled or failed.""" + +CANCELLABLE_PIPELINE_STATUSES = [PipelineStatus.CREATED, PipelineStatus.RUNNING, PipelineStatus.PAUSED] +"""Pipeline statuses that can be cancelled/skipped.""" + +RUNNING_PIPELINE_STATUSES = [PipelineStatus.RUNNING] +"""Pipeline statuses indicating active execution.""" diff --git 
a/src/mavedb/worker/lib/managers/exceptions.py b/src/mavedb/worker/lib/managers/exceptions.py new file mode 100644 index 00000000..48fa4b83 --- /dev/null +++ b/src/mavedb/worker/lib/managers/exceptions.py @@ -0,0 +1,63 @@ +""" +Manager Exceptions for explicit error handling. +""" + + +class ManagerError(Exception): + """Base exception for Manager operations.""" + + pass + + +## Pipeline Manager Exceptions + + +class PipelineManagerError(ManagerError): + """Pipeline Manager specific errors.""" + + pass + + +class PipelineCoordinationError(PipelineManagerError): + """Pipeline coordination failed - may be recoverable.""" + + pass + + +class PipelineTransitionError(PipelineManagerError): + """Pipeline is in wrong state for requested operation.""" + + pass + + +class PipelineStateError(PipelineManagerError): + """Critical pipeline state operations failed - database issues preventing state persistence.""" + + pass + + +## Job Manager Exceptions + + +class JobManagerError(ManagerError): + """Job Manager specific errors.""" + + pass + + +class JobStateError(JobManagerError): + """Critical job state operations failed - database issues preventing state persistence.""" + + pass + + +class JobTransitionError(JobManagerError): + """Job is in wrong state for requested operation.""" + + pass + + +class DatabaseConnectionError(JobStateError): + """Database connection issues preventing any operations.""" + + pass diff --git a/src/mavedb/worker/lib/managers/job_manager.py b/src/mavedb/worker/lib/managers/job_manager.py new file mode 100644 index 00000000..e762ada0 --- /dev/null +++ b/src/mavedb/worker/lib/managers/job_manager.py @@ -0,0 +1,910 @@ +"""Job lifecycle management for individual job state transitions. + +This module provides the JobManager class for managing individual job state transitions +with atomic operations and explicit error handling to ensure data consistency. +Pipeline coordination is handled separately by the PipelineManager. + +Example usage: + >>> from mavedb.worker.lib.job_manager import JobManager + >>> + >>> # Initialize with database and Redis connections + >>> job_manager = JobManager(db_session, redis_client, job_id=123) + >>> + >>> # Start job execution + >>> job_manager.start_job() + >>> + >>> # Update progress during execution + >>> job_manager.update_progress(50, 100, "Processing variants...") + >>> + >>> # Complete job (pipeline coordination handled separately) + >>> job_manager.complete_job( + ... status=JobStatus.SUCCEEDED, + ... result={"variants_processed": 1000} + ... 
) + +Error Handling: + The JobManager uses specific exception types to distinguish between different + failure modes, allowing callers to implement appropriate recovery strategies: + + - DatabaseConnectionError: Database connectivity issues + - JobStateError: Critical state persistence failures + - JobTransitionError: Invalid state transitions +""" + +import logging +import traceback +from datetime import datetime +from typing import Any, Optional + +from arq import ArqRedis +from sqlalchemy import select +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session +from sqlalchemy.orm.attributes import flag_modified + +from mavedb.lib.logging.context import format_raised_exception_info_as_dict +from mavedb.models.enums.job_pipeline import FailureCategory, JobStatus +from mavedb.models.job_run import JobRun +from mavedb.worker.lib.managers.base_manager import BaseManager +from mavedb.worker.lib.managers.constants import ( + CANCELLED_JOB_STATUSES, + RETRYABLE_FAILURE_CATEGORIES, + RETRYABLE_JOB_STATUSES, + STARTABLE_JOB_STATUSES, + TERMINAL_JOB_STATUSES, +) +from mavedb.worker.lib.managers.exceptions import ( + DatabaseConnectionError, + JobStateError, + JobTransitionError, +) +from mavedb.worker.lib.managers.types import JobResultData, RetryHistoryEntry + +logger = logging.getLogger(__name__) + + +class JobManager(BaseManager): + """Manages individual job lifecycle with atomic state transitions. + + The JobManager provides a high-level interface for managing individual job execution + while ensuring database consistency. It handles job state transitions, progress updates, + and retry logic. Pipeline coordination is handled separately by the PipelineManager. + + Key Features: + - Atomic state transitions with rollback on failure + - Explicit exception handling for different failure modes + - Progress tracking and retry mechanisms + - Automatic session cleanup on object manipulation failures + - Focus on individual job lifecycle only + + Note: + To avoid persisting inconsistent job state to the database, any failures + during job manipulation (e.g., fetching job, updating fields) will result + in a safe rollback of the current transaction. This ensures that partial + updates do not corrupt job state. This manager DOES NOT COMMIT database + changes, only flushes them. Commit responsibility lies with the caller. + + Usage Patterns: + + Basic job execution: + >>> manager = JobManager(db, redis, job_id=123) + >>> manager.start_job() + >>> manager.update_progress(25, message="Starting validation") + >>> manager.succeed_job(result={"count": 100}) + + Progress tracking convenience: + >>> manager.set_progress_total(1000, "Processing 1000 records") + >>> for record in records: + ... process_record(record) + ... manager.increment_progress() # Increment by 1 + ... if manager.is_cancelled(): + ... break + + Job failure handling: + >>> try: + ... process_data() + ... except ValidationError as e: + ... manager.fail_job(error=e, result={"partial_results": partial_data}) + + Direct completion control: + >>> manager.complete_job(status=JobStatus.SUCCEEDED, result=data) + + Error handling: + >>> try: + ... manager.complete_job(status=JobStatus.SUCCEEDED, result=data) + ... except JobStateError as e: + ... logger.critical(f"Critical state failure: {e}") + ... # Job completion failed - state not saved + + Job retry: + >>> try: + ... manager.retry_job(reason="Transient network error") + ... except JobTransitionError as e: + ... 
logger.error(f"Cannot retry job in current state: {e}")
+
+    Exception Hierarchy:
+        - DatabaseConnectionError: Cannot connect to database
+        - JobStateError: Critical state persistence failures
+        - JobTransitionError: Invalid state transitions (e.g., start already running job)
+
+    Thread Safety:
+        JobManager is not thread-safe. Each instance should be used by a single
+        worker thread and should not be shared across concurrent operations.
+    """
+
+    def __init__(self, db: Session, redis: Optional[ArqRedis], job_id: int):
+        """Initialize JobManager for a specific job.
+
+        Args:
+            db: Active SQLAlchemy session for database operations. Session should
+                be configured for the appropriate database and have proper
+                transaction isolation.
+            redis: ARQ Redis client for job queue operations. Must be connected
+                and ready for enqueue operations. Optional; can be None if Redis is not used.
+            job_id: Unique identifier of the job to manage. Must correspond to
+                an existing JobRun record in the database.
+
+        Raises:
+            DatabaseConnectionError: If the job cannot be fetched from database,
+                indicating connectivity issues or invalid job_id.
+
+        Example:
+            >>> db_session = get_database_session()
+            >>> redis_client = get_arq_redis_client()
+            >>> manager = JobManager(db_session, redis_client, 12345)
+            >>> # Manager is now ready to handle job 12345
+        """
+        super().__init__(db, redis)
+
+        # Keep the logging context per-instance so state is not shared across managers.
+        self.context: dict[str, Any] = {}
+
+        self.job_id = job_id
+        job = self.get_job()
+        self.pipeline_id = job.pipeline_id if job else None
+
+        self.save_to_context(
+            {"job_id": str(self.job_id), "pipeline_id": str(self.pipeline_id) if self.pipeline_id else None}
+        )
+
+    def save_to_context(self, ctx: dict) -> dict[str, Any]:
+        for k, v in ctx.items():
+            self.context[k] = v
+
+        return self.context
+
+    def logging_context(self) -> dict[str, Any]:
+        return self.context
+    def start_job(self) -> None:
+        """Mark job as started and initialize execution tracking. This method does
+        not flush or commit the database session; the caller is responsible for persisting changes.
+
+        Transitions job from QUEUED or PENDING to RUNNING state, setting the start
+        timestamp and a default progress message. This method should be called
+        once at the beginning of job execution.
+
+        State Changes:
+            - Sets status to JobStatus.RUNNING
+            - Records started_at timestamp
+            - Sets progress_message to "Job began execution"
+
+        Raises:
+            DatabaseConnectionError: Cannot fetch job from database
+            JobStateError: Cannot save job start state to database
+            JobTransitionError: Job not in valid state to start (must be QUEUED or PENDING)
+
+        Example:
+            >>> manager = JobManager(db, redis, 123)
+            >>> manager.start_job()  # Job 123 now marked as RUNNING
+            >>> # Proceed with job execution logic...
+        """
+        job_run = self.get_job()
+        if job_run.status not in STARTABLE_JOB_STATUSES:
+            self.save_to_context({"job_status": str(job_run.status)})
+            logger.error(
+                "Invalid job start attempt: status not in STARTABLE_JOB_STATUSES", extra=self.logging_context()
+            )
+            raise JobTransitionError(f"Cannot start job {self.job_id} from status {job_run.status}")
+
+        try:
+            job_run.status = JobStatus.RUNNING
+            job_run.started_at = datetime.now()
+            job_run.progress_message = "Job began execution"
+        except (AttributeError, TypeError, KeyError, ValueError) as e:
+            self.save_to_context(format_raised_exception_info_as_dict(e))
+            logger.debug("Encountered an unexpected error while updating job start state", extra=self.logging_context())
+            raise JobStateError(f"Failed to update job start state: {e}")
+
+        self.save_to_context({"job_status": str(job_run.status)})
+        logger.info("Job marked as started", extra=self.logging_context())
+
+    def complete_job(self, status: JobStatus, result: JobResultData, error: Optional[Exception] = None) -> None:
+        """Mark job as completed with the specified final status. This method does
+        not flush or commit the database session; the caller is responsible for persisting changes.
+
+        Transitions job to the passed terminal status (SUCCEEDED, FAILED, CANCELLED, SKIPPED),
+        recording the finished_at timestamp, result data, and error details if applicable.
+
+        Args:
+            status: Final job status - must be a terminal status
+                (SUCCEEDED, FAILED, CANCELLED, SKIPPED)
+            result: JobResultData to store in metadata. Should be JSON-serializable
+                dictionary containing any outputs, metrics, or artifacts produced.
+            error: Exception that caused job failure, if applicable. Error details
+                will be logged and stored for debugging.
+
+        State Changes:
+            - Sets status to the specified terminal status
+            - Sets finished_at timestamp
+            - Stores result in job metadata
+            - Records error details if provided and status is FAILED
+
+        Raises:
+            DatabaseConnectionError: Cannot fetch job or connect to database
+            JobStateError: Cannot save job completion state - critical error
+            JobTransitionError: Invalid terminal status provided
+
+        Examples:
+            Successful completion:
+                >>> result_data = {"records_processed": 1500, "errors": 0}
+                >>> manager.complete_job(
+                ...     status=JobStatus.SUCCEEDED,
+                ...     result=result_data
+                ... )
+
+            Failed completion with error:
+                >>> try:
+                ...     process_data()
+                ... except ValidationError as e:
+                ...     manager.complete_job(
+                ...         status=JobStatus.FAILED,
+                ...         result={"partial_results": data},
+                ...         error=e
+                ...     )
+
+        Note:
+            Job completion state is saved independently of any pipeline
+            coordination. Use PipelineManager for coordinating dependent jobs.
+        """
+        # Validate terminal status
+        if status not in TERMINAL_JOB_STATUSES:
+            self.save_to_context({"job_status": str(status)})
+            logger.error("Invalid job completion status: not in TERMINAL_JOB_STATUSES", extra=self.logging_context())
+            raise JobTransitionError(
+                f"Cannot complete job to status: {status}. 
Must complete to a terminal status: {TERMINAL_JOB_STATUSES}" + ) + + job_run = self.get_job() + try: + job_run.status = status + job_run.metadata_["result"] = { + "status": result["status"], + "data": result["data"], + "exception_details": format_raised_exception_info_as_dict(result["exception"]) # type: ignore + if result.get("exception") + else None, + } + job_run.finished_at = datetime.now() + + if status == JobStatus.FAILED: + job_run.failure_category = FailureCategory.UNKNOWN + + if error: + job_run.error_message = str(error) + job_run.error_traceback = traceback.format_exc() + # TODO: Classify failure category based on error type + job_run.failure_category = FailureCategory.UNKNOWN + + self.save_to_context({"failure_category": str(job_run.failure_category)}) + + except (AttributeError, TypeError, KeyError, ValueError) as e: + self.save_to_context(format_raised_exception_info_as_dict(e)) + logger.debug( + "Encountered an unexpected error while updating job completion state", extra=self.logging_context() + ) + raise JobStateError(f"Failed to update job completion state: {e}") + + self.save_to_context({"job_status": str(job_run.status)}) + logger.info("Job marked as completed", extra=self.logging_context()) + + def fail_job(self, error: Exception, result: JobResultData) -> None: + """Mark job as failed and record error details. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Convenience method for marking job execution as failed. This is equivalent + to calling complete_job(status=JobStatus.FAILED, error=error, result=result) but + provides clearer intent and a more focused API for failure scenarios. + + Args: + error: Exception that caused job failure. Error details will be logged + and stored for debugging. Used to populate error message and traceback. + result: Partial results to store in metadata. Should be + JSON-serializable dictionary containing any partial outputs, + metrics, or debugging information produced before failure. + + Raises: + DatabaseConnectionError: Cannot fetch job or connect to database + JobStateError: Cannot save job completion state - critical error + + Examples: + Basic failure with exception: + >>> try: + ... validate_data(input_data) + ... except ValidationError as e: + ... manager.fail_job(error=e, result={}) + + Failure with partial results: + >>> try: + ... results = process_batch(records) + ... except ProcessingError as e: + ... partial_results = {"processed": len(results), "failed_at": e.record_id} + ... manager.fail_job(error=e, result=partial_results) + + Note: + This method is equivalent to complete_job(status=JobStatus.FAILED, error=error, result=result). + Use this method when job failure is the primary outcome to make intent clearer. + """ + self.complete_job(status=JobStatus.FAILED, result=result, error=error) + + def succeed_job(self, result: JobResultData) -> None: + """Mark job as succeeded and record results. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Convenience method for marking job execution as successful. This is equivalent + to calling complete_job(status=JobStatus.SUCCEEDED, result=result) but provides clearer + intent and a more focused API for success scenarios. + + Args: + result: Job result data to store in metadata. Should be JSON-serializable + dictionary containing any outputs, metrics, or artifacts produced. 
+ + Raises: + DatabaseConnectionError: Cannot fetch job or connect to database + JobStateError: Cannot save job completion state - critical error + + Examples: + Successful completion: + >>> result_data = {"records_processed": 1500, "errors": 0, "duration": 45.2} + >>> manager.succeed_job(result=result_data) + + Success with metrics: + >>> metrics = { + ... "input_count": 10000, + ... "output_count": 9847, + ... "skipped": 153, + ... "processing_time": 120.5, + ... "memory_peak": "2.1GB" + ... } + >>> manager.succeed_job(result=metrics) + + Note: + This method is equivalent to complete_job(status=JobStatus.SUCCEEDED, result=result). + Use this method when job success is the primary outcome to make intent clearer. + """ + self.complete_job(status=JobStatus.SUCCEEDED, result=result) + + def cancel_job(self, result: JobResultData) -> None: + """Mark job as cancelled. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Convenience method for marking job execution as cancelled. This is equivalent + to calling complete_job(status=JobStatus.CANCELLED, result=result) but provides + clearer intent and a more focused API for cancellation scenarios. + + Args: + reason: Human-readable reason for cancellation (e.g., "user_requested", + "pipeline_cancelled", "timeout"). Used for debugging and audit trails. + result: Partial results to store in metadata. Should be JSON-serializable + dictionary containing any partial outputs or cancellation details. + If None, defaults to cancellation metadata. + + Raises: + DatabaseConnectionError: Cannot fetch job or connect to database + JobStateError: Cannot save job completion state - critical error + + Examples: + Basic cancellation: + >>> manager.cancel_job({"reason": "user_requested"}) + + Note: + This method is equivalent to complete_job(status=JobStatus.CANCELLED, result=result). + Use this method when job cancellation is the primary outcome to make intent clearer. + """ + self.complete_job(status=JobStatus.CANCELLED, result=result) + + def skip_job(self, result: JobResultData) -> None: + """Mark job as skipped. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Convenience method for marking job as skipped (not executed). This is equivalent + to calling complete_job(status=JobStatus.SKIPPED, result=result) but provides + clearer intent and a more focused API for skip scenarios. + + Args: + result: Skip details to store in metadata. Should be JSON-serializable + dictionary containing skip reason and context. + If None, defaults to skip metadata. + + Raises: + DatabaseConnectionError: Cannot fetch job or connect to database + JobStateError: Cannot save job completion state - critical error + + Examples: + Basic skip: + >>> manager.skip_job({"reason": "No work to perform"}) + + Note: + This method is equivalent to complete_job(status=JobStatus.SKIPPED, result=result). + Use this method when job skipping is the primary outcome to make intent clearer. + """ + self.complete_job(status=JobStatus.SKIPPED, result=result) + + def prepare_retry(self, reason: str = "retry_requested") -> None: + """Prepare a failed job for retry by resetting state to PENDING. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Resets a failed job back to PENDING status so it can be re-enqueued + by the pipeline coordination system. 
This is similar to job completion + but transitions to PENDING instead of a terminal state. + + Args: + reason: Human-readable reason for the retry (e.g., "transient_network_error", + "memory_limit_exceeded"). Used for debugging and audit trails. + + State Changes: + - Increments retry_count + - Resets status from FAILED, SKIPPED, CANCELLED to PENDING + - Clears error_message, error_traceback, failure_category + - Clears finished_at timestamp + - Adds retry attempt to metadata history + + Raises: + DatabaseConnectionError: Cannot fetch job from database + JobTransitionError: Job not in FAILED state (cannot retry) + JobStateError: Cannot save retry state changes + + Examples: + Basic retry preparation: + >>> try: + ... manager.prepare_retry("network_timeout") + ... except JobTransitionError: + ... logger.error("Cannot retry job - not in failed state") + + Conditional retry with limits: + >>> job = manager.get_job() + >>> if job and job.retry_count < 3: + ... manager.prepare_retry(f"attempt_{job.retry_count + 1}") + ... # PipelineManager will handle enqueueing + ... else: + ... logger.error("Max retries exceeded") + + Retry History: + Each retry attempt is recorded in job metadata with: + - retry_attempt: Sequential attempt number + - timestamp: When retry was initiated + - result: Previous execution results (for debugging) + - reason: Provided retry reason + + Note: + After calling this method, use PipelineManager.enqueue_ready_jobs() + to actually enqueue the job for execution. + """ + job_run = self.get_job() + if job_run.status not in RETRYABLE_JOB_STATUSES: + self.save_to_context({"job_status": str(job_run.status)}) + logger.error("Invalid job retry status: status not in RETRYABLE_JOB_STATUSES", extra=self.logging_context()) + raise JobTransitionError(f"Cannot retry job {self.job_id} due to invalid state ({job_run.status})") + + try: + job_run.status = JobStatus.PENDING + current_result: JobResultData = job_run.metadata_.get("result", {}) + job_run.retry_count = (job_run.retry_count or 0) + 1 + job_run.progress_message = "Job retry prepared" + job_run.error_message = None + job_run.error_traceback = None + job_run.failure_category = None + job_run.finished_at = None + job_run.started_at = None + + # Add retry history - metadata manipulation (risky) + retry_history: list[RetryHistoryEntry] = job_run.metadata_.setdefault("retry_history", []) + retry_history.append( + { + "attempt": job_run.retry_count, + "timestamp": datetime.now().isoformat(), + "result": current_result, + "reason": reason, + } + ) + job_run.metadata_.pop("result", None) # Clear previous result + flag_modified(job_run, "metadata_") + + except (AttributeError, TypeError, KeyError, ValueError) as e: + self.save_to_context(format_raised_exception_info_as_dict(e)) + logger.debug("Encountered an unexpected error while updating job retry state", extra=self.logging_context()) + raise JobStateError(f"Failed to update job retry state: {e}") + + self.save_to_context({"job_status": str(job_run.status), "retry_attempt": job_run.retry_count}) + logger.info("Job successfully prepared for retry", extra=self.logging_context()) + + def prepare_queue(self) -> None: + """Prepare job for enqueueing by setting QUEUED status. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Transitions job from PENDING to QUEUED status before ARQ enqueueing. + This ensures proper state tracking and validates the transition. 
+ + Raises: + JobTransitionError: Job not in PENDING state + JobStateError: Cannot save state change + """ + job_run = self.get_job() + if job_run.status != JobStatus.PENDING: + self.save_to_context({"job_status": str(job_run.status)}) + logger.error("Invalid job queue attempt: status not PENDING", extra=self.logging_context()) + raise JobTransitionError(f"Cannot queue job {self.job_id} from status {job_run.status}") + + try: + job_run.status = JobStatus.QUEUED + job_run.progress_message = "Job queued for execution" + except (AttributeError, TypeError, KeyError, ValueError) as e: + self.save_to_context(format_raised_exception_info_as_dict(e)) + logger.debug("Encountered an unexpected error while updating job queue state", extra=self.logging_context()) + raise JobStateError(f"Failed to update job queue state: {e}") + + self.save_to_context({"job_status": str(job_run.status)}) + logger.debug("Job successfully prepared for queueing", extra=self.logging_context()) + + def reset_job(self) -> None: + """Reset job to initial state for re-execution. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Resets all job state fields to their initial values, allowing the job + to be re-executed from scratch. This is useful for testing or manual + re-runs of jobs without retaining any prior execution history. + + State Changes: + - Sets status to PENDING + - Clears started_at and finished_at timestamps + - Resets progress to 0/100 with default message + - Clears error details and failure category + - Resets retry_count to 0 + - Clears metadata + + Raises: + DatabaseConnectionError: Cannot fetch job from database + JobStateError: Cannot save reset state changes + Examples: + Basic job reset: + >>> manager.reset_job() + >>> # Job is now reset to initial state for re-execution + """ + job_run = self.get_job() + try: + job_run.status = JobStatus.PENDING + job_run.started_at = None + job_run.finished_at = None + job_run.progress_current = None + job_run.progress_total = None + job_run.progress_message = None + job_run.error_message = None + job_run.error_traceback = None + job_run.failure_category = None + job_run.retry_count = 0 + job_run.metadata_ = {} + + except (AttributeError, TypeError, KeyError, ValueError) as e: + self.save_to_context(format_raised_exception_info_as_dict(e)) + logger.debug("Encountered an unexpected error while resetting job state", extra=self.logging_context()) + raise JobStateError(f"Failed to reset job state: {e}") + + self.save_to_context({"job_status": str(job_run.status), "retry_attempt": job_run.retry_count}) + logger.info("Job successfully reset to initial state", extra=self.logging_context()) + + def update_progress(self, current: int, total: int = 100, message: Optional[str] = None) -> None: + """Update job progress information during execution. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Provides real-time progress updates for long-running jobs. Progress updates + are best-effort operations that won't interrupt job execution if they fail. + This allows jobs to continue even if progress tracking has issues. 
+ + Args: + current: Current progress value (e.g., records processed so far) + total: Total expected progress value (default: 100 for percentage) + message: Optional human-readable progress description + + Examples: + Percentage-based progress: + >>> manager.update_progress(25, 100, "Validating input data") + >>> manager.update_progress(50, 100, "Processing records") + >>> manager.update_progress(100, 100, "Finalizing results") + + Count-based progress: + >>> total_records = 50000 + >>> for i, record in enumerate(records): + ... process_record(record) + ... if i % 1000 == 0: # Update every 1000 records + ... manager.update_progress( + ... current=i, + ... total=total_records, + ... message=f"Processed {i}/{total_records} records" + ... ) + + Handling progress failures: + >>> try: + ... manager.update_progress(75, message="Almost done") + ... except DatabaseConnectionError: + ... logger.debug("Progress update failed, continuing job") + ... # Job continues normally + + Note: + Progress updates are non-blocking and failure-tolerant. If a progress + update fails, the job may choose to continue execution normally. Failed + progress updates are logged at debug level. + """ + job_run = self.get_job() + try: + job_run.progress_current = current + job_run.progress_total = total + if message: + job_run.progress_message = message + + except (AttributeError, TypeError, KeyError, ValueError) as e: + self.save_to_context(format_raised_exception_info_as_dict(e)) + logger.debug("Encountered an unexpected error while updating job progress", extra=self.logging_context()) + raise JobStateError(f"Failed to update job progress state: {e}") + + self.save_to_context( + {"job_progress_current": current, "job_progress_total": total, "job_progress_message": message} + ) + logger.debug("Updated progress successfully for job", extra=self.logging_context()) + + def update_status_message(self, message: str) -> None: + """Update job status message without changing progress. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Convenience method for updating the progress message while keeping + current progress values unchanged. Useful for status updates during + long-running operations. + + Args: + message: Human-readable status message describing current activity + + Raises: + DatabaseConnectionError: Cannot fetch job from database + JobStateError: Cannot save status message update + + Example: + >>> manager.update_status_message("Connecting to external API...") + >>> # Do API work + >>> manager.update_status_message("Processing API response...") + """ + job_run = self.get_job() + try: + job_run.progress_message = message + except (AttributeError, TypeError, KeyError, ValueError) as e: + self.save_to_context(format_raised_exception_info_as_dict(e)) + logger.debug( + "Encountered an unexpected error while updating job status message", extra=self.logging_context() + ) + raise JobStateError(f"Failed to update job status message state: {e}") + + self.save_to_context({"job_progress_message": message}) + logger.debug("Updated status message successfully for job", extra=self.logging_context()) + + def increment_progress(self, amount: int = 1, message: Optional[str] = None) -> None: + """Increment job progress by a specified amount. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Convenience method for incrementing progress without needing to track + the current progress value. 
Useful for batch processing where you want + to increment by 1 for each item processed. + + Args: + amount: Amount to increment progress by (default: 1) + message: Optional message to update along with progress + + Raises: + DatabaseConnectionError: Cannot fetch job from database + JobStateError: Cannot save progress update + + Examples: + >>> # Process items one by one + >>> for item in items: + ... process_item(item) + ... manager.increment_progress() # Increment by 1 + + >>> # Process in batches + >>> for batch in batches: + ... process_batch(batch) + ... manager.increment_progress(len(batch), f"Processed batch {i}") + """ + job_run = self.get_job() + try: + current = job_run.progress_current or 0 + job_run.progress_current = current + amount + if message: + job_run.progress_message = message + except (AttributeError, TypeError, KeyError, ValueError) as e: + self.save_to_context(format_raised_exception_info_as_dict(e)) + logger.debug( + "Encountered an unexpected error while incrementing job progress", extra=self.logging_context() + ) + raise JobStateError(f"Failed to increment job progress state: {e}") + + self.save_to_context( + { + "job_progress_current": current, + "job_progress_total": job_run.progress_total, + "job_progress_message": message or "", + } + ) + logger.debug("Incremented progress successfully for job", extra=self.logging_context()) + + def set_progress_total(self, total: int, message: Optional[str] = None) -> None: + """Update the total progress value, useful when total becomes known during execution. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Convenience method for updating progress total when it's discovered during + job execution (e.g., after counting records to process). + + Args: + total: New total progress value + message: Optional message to update along with total + + Raises: + DatabaseConnectionError: Cannot fetch job from database + JobStateError: Cannot save progress total update + + Example: + >>> # Initially unknown total + >>> manager.start_job() + >>> records = load_all_records() # Discovers actual count + >>> manager.set_progress_total(len(records), f"Processing {len(records)} records") + """ + job_run = self.get_job() + try: + job_run.progress_total = total + if message: + job_run.progress_message = message + except (AttributeError, TypeError, KeyError, ValueError) as e: + self.save_to_context(format_raised_exception_info_as_dict(e)) + logger.debug( + "Encountered an unexpected error while updating job progress total", extra=self.logging_context() + ) + raise JobStateError(f"Failed to update job progress total state: {e}") + + self.save_to_context({"job_progress_total": total, "job_progress_message": message}) + logger.debug("Updated progress total successfully for job", extra=self.logging_context()) + + def is_cancelled(self) -> bool: + """Check if job has been cancelled or should stop execution. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Convenience method for checking if the job should stop execution due to + cancellation, pipeline failure, or other termination conditions. Jobs + can use this for graceful shutdown. + + Returns: + bool: True if job should stop execution, False if it can continue + + Raises: + DatabaseConnectionError: Cannot fetch job status from database + + Example: + >>> for item in large_dataset: + ... if manager.is_cancelled(): + ... 
logger.info("Job cancelled, stopping gracefully") + ... break + ... process_item(item) + """ + return self.get_job_status() in CANCELLED_JOB_STATUSES + + def should_retry(self) -> bool: + """Check if job should be retried based on error type and retry count. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Convenience method that implements common retry logic. Checks current + retry count against maximum and evaluates if the error type is retryable. + + Returns: + bool: True if job should be retried, False otherwise + + Raises: + DatabaseConnectionError: Cannot fetch job info from database + + Examples: + >>> try: + ... result = do_work() + ... except NetworkError as e: + ... manager.fail_job(e, result) + ... if manager.should_retry(): + ... manager.retry_job() + ... else: + ... manager.fail_job(e, result) + """ + job_run = self.get_job() + try: + self.save_to_context( + { + "job_retry_count": job_run.retry_count, + "job_max_retries": job_run.max_retries, + "job_failure_category": str(job_run.failure_category) if job_run.failure_category else None, + "job_status": str(job_run.status), + } + ) + + # Check if job is in FAILED state + if job_run.status != JobStatus.FAILED: + logger.debug("Job cannot be retried: not in FAILED state", extra=self.logging_context()) + return False + + # Check retry count + current_retries = job_run.retry_count or 0 + if current_retries >= job_run.max_retries: + logger.debug("Job cannot be retried: max retries reached", extra=self.logging_context()) + return False + + # Check if failure category is retryable + if job_run.failure_category not in RETRYABLE_FAILURE_CATEGORIES: + logger.debug("Job cannot be retried: failure category not retryable", extra=self.logging_context()) + return False + + logger.debug("Job is retryable", extra=self.logging_context()) + return True + + except (AttributeError, TypeError, KeyError, ValueError) as e: + self.save_to_context(format_raised_exception_info_as_dict(e)) + logger.debug("Unexpected error checking retry eligibility", extra=self.logging_context()) + raise JobStateError(f"Failed to check retry eligibility state: {e}") + + def get_job_status(self) -> JobStatus: # pragma: no cover + """Get current job status for monitoring and debugging. + + Provides non-blocking access to job status without affecting job + execution. Used by decorators and monitoring systems to check job state. + + Returns: + JobStatus: Current job status (QUEUED, RUNNING, SUCCEEDED, + FAILED, etc.). + + Raises: + DatabaseConnectionError: Cannot connect to database, SQL query failed, + or job not found (indicates data inconsistency) + + Examples: + >>> status = manager.get_job_status() + >>> if status == JobStatus.RUNNING: + ... logger.info("Job is currently executing") + """ + return self.get_job().status + + def get_job(self) -> JobRun: + """Get complete job information for monitoring and debugging. + + Retrieves full JobRun instance with all fields populated. Used by + decorators and monitoring systems that need access to job metadata, + progress, error details, or other comprehensive job information. + + Returns: + JobRun: Complete job instance with all fields. + + Raises: + DatabaseConnectionError: Cannot connect to database, SQL query failed, + or job not found (indicates data inconsistency) + + Example: + >>> job = manager.get_job() + >>> if job: + ... logger.info(f"Job {job.urn} progress: {job.progress_current}/{job.progress_total}") + ... if job.error_message: + ... 
logger.error(f"Job error: {job.error_message}") + """ + try: + return self.db.execute(select(JobRun).where(JobRun.id == self.job_id)).scalar_one() + except SQLAlchemyError as e: + self.save_to_context(format_raised_exception_info_as_dict(e)) + logger.debug("Unexpected error fetching job info", extra=self.logging_context()) + raise DatabaseConnectionError(f"Failed to fetch job {self.job_id}: {e}") diff --git a/src/mavedb/worker/lib/managers/pipeline_manager.py b/src/mavedb/worker/lib/managers/pipeline_manager.py new file mode 100644 index 00000000..b0ecfcf1 --- /dev/null +++ b/src/mavedb/worker/lib/managers/pipeline_manager.py @@ -0,0 +1,1145 @@ +"""Pipeline coordination management for job dependencies and status. + +This module provides the PipelineManager class for coordinating pipeline execution, +managing job dependencies, and updating pipeline status. The PipelineManager is +separated from individual job lifecycle management to provide clean separation of concerns. + +Example usage: + >>> from mavedb.worker.lib.pipeline_manager import PipelineManager + >>> + >>> # Initialize with database and Redis connections + >>> pipeline_manager = PipelineManager(db_session, redis_client, pipeline_id=456) + >>> + >>> # Coordinate after a job completes + >>> await pipeline_manager.coordinate_pipeline() + >>> + >>> # Update pipeline status + >>> new_status = pipeline_manager.transition_pipeline_status() + >>> + >>> # Cancel remaining jobs when pipeline fails + >>> cancelled_count = pipeline_manager.cancel_remaining_jobs( + ... reason="Dependency failed" + ... ) + >>> + >>> # Pause/unpause pipeline + >>> was_paused = pipeline_manager.pause_pipeline("Maintenance") + >>> was_unpaused = await pipeline_manager.unpause_pipeline("Complete") + +Error Handling: + The PipelineManager uses the same exception hierarchy as JobManager for consistency: + + - DatabaseConnectionError: Database connectivity issues + - JobStateError: Critical state persistence failures + - PipelineCoordinationError: Pipeline coordination failures +""" + +import logging +from datetime import datetime, timedelta +from typing import Sequence + +from arq import ArqRedis +from sqlalchemy import and_, func, select +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session + +from mavedb.lib.slack import send_slack_message +from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus +from mavedb.models.job_dependency import JobDependency +from mavedb.models.job_run import JobRun +from mavedb.models.pipeline import Pipeline +from mavedb.worker.lib.managers import BaseManager, JobManager +from mavedb.worker.lib.managers.constants import ( + ACTIVE_JOB_STATUSES, + CANCELLED_JOB_STATUSES, + CANCELLED_PIPELINE_STATUSES, + RUNNING_PIPELINE_STATUSES, + TERMINAL_PIPELINE_STATUSES, +) +from mavedb.worker.lib.managers.exceptions import ( + DatabaseConnectionError, + PipelineCoordinationError, + PipelineStateError, + PipelineTransitionError, +) +from mavedb.worker.lib.managers.utils import ( + construct_bulk_cancellation_result, + job_dependency_is_met, + job_should_be_skipped_due_to_unfulfillable_dependency, +) + +logger = logging.getLogger(__name__) + + +class PipelineManager(BaseManager): + """Manages pipeline coordination and job dependencies with atomic operations. + + The PipelineManager provides a focused interface for coordinating pipeline execution + without coupling to individual job lifecycle management. It handles dependency + checking, status updates, and pipeline-wide operations like cancellation. 
+ + Key Features: + - Atomic pipeline status transitions with rollback on failure + - Dependency-based job enqueueing with race condition prevention + - Pipeline-wide cancellation with proper error handling + - Separation from individual job lifecycle management + - Consistent exception handling and logging + + Usage Patterns: + + Pipeline coordination after job completion: + >>> manager = PipelineManager(db, redis, pipeline_id=123) + >>> await manager.coordinate_pipeline() + + Manual pipeline operations: + >>> # Update pipeline status based on current job states + >>> new_status = manager.transition_pipeline_status() + >>> + >>> # Cancel remaining jobs + >>> cancelled_count = manager.cancel_remaining_jobs( + ... reason="Manual cancellation" + ... ) + >>> + >>> # Pause pipeline execution + >>> was_paused = manager.pause_pipeline( + ... reason="System maintenance" + ... ) + >>> + >>> # Resume pipeline execution + >>> was_unpaused = await manager.unpause_pipeline( + ... reason="Maintenance complete" + ... ) + + Dependency management: + >>> # Check if a job can be enqueued + >>> can_run = manager.can_enqueue_job(job) + >>> + >>> # Enqueue all ready jobs (independent and dependent) + >>> await manager.enqueue_ready_jobs() + + Pipeline monitoring: + >>> # Get detailed progress statistics + >>> progress = manager.get_pipeline_progress() + >>> print(f"Pipeline {progress['completion_percentage']:.1f}% complete") + >>> + >>> # Get job counts by status + >>> counts = manager.get_job_counts_by_status() + >>> print(f"Failed jobs: {counts.get(JobStatus.FAILED, 0)}") + + Job retry and pipeline restart: + >>> # Retry all failed jobs + >>> retried_count = await manager.retry_failed_jobs() + >>> + >>> # Restart entire pipeline + >>> restarted = await manager.restart_pipeline("Fixed issue") + + Thread Safety: + PipelineManager is not thread-safe. Each instance should be used by a single + worker thread and should not be shared across concurrent operations. + """ + + def __init__(self, db: Session, redis: ArqRedis, pipeline_id: int): + """Initialize pipeline manager with database and Redis connections. + + Args: + db: SQLAlchemy database session for job and pipeline queries + redis: ARQ Redis client for job queue operations. Note that although the Redis + client is optional for base managers, PipelineManager requires it for + job coordination. + pipeline_id: ID of the pipeline this manager instance will coordinate + + Raises: + DatabaseConnectionError: Cannot connect to database + + Example: + >>> db_session = get_database_session() + >>> redis_client = get_arq_redis_client() + >>> manager = PipelineManager(db_session, redis_client, pipeline_id=456) + """ + super().__init__(db, redis) + self.pipeline_id = pipeline_id + self.get_pipeline() # Validate pipeline exists on init + + async def start_pipeline(self, coordinate: bool = True) -> None: + """Start the pipeline + + Entry point to start pipeline execution. Sets pipeline status to RUNNING + and enqueues independent jobs using coordinate pipeline if coordinate is True. 
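For orientation, here is a minimal sketch of driving a pipeline end to end with this class. It uses only the methods defined in this module; the session/Redis handles and the commit policy are assumptions, not part of this PR:

```python
from arq import ArqRedis
from sqlalchemy.orm import Session

from mavedb.worker.lib.managers.pipeline_manager import PipelineManager


async def drive_pipeline(db: Session, redis: ArqRedis, pipeline_id: int) -> None:
    manager = PipelineManager(db, redis, pipeline_id=pipeline_id)

    # CREATED -> RUNNING; independent jobs are enqueued via coordinate_pipeline().
    await manager.start_pipeline()

    # As each job finishes (normally via its management decorator), re-coordinate
    # to enqueue newly unblocked jobs or cancel the remainder on failure.
    await manager.coordinate_pipeline()

    # Managers flush but do not commit; persistence is the caller's responsibility.
    db.commit()
```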
+ + Raises: + DatabaseConnectionError: Cannot query or update pipeline + PipelineStateError: Cannot update pipeline state + PipelineCoordinationError: Failed to enqueue ready jobs + + Example: + >>> # Start a new pipeline + >>> await pipeline_manager.start_pipeline() + """ + status = self.get_pipeline_status() + + if status != PipelineStatus.CREATED: + logger.error( + f"Pipeline {self.pipeline_id} is in a non-created state (current status: {status}) and may not be started" + ) + raise PipelineTransitionError(f"Pipeline {self.pipeline_id} is in state {status} and may not be started") + + self.set_pipeline_status(PipelineStatus.RUNNING) + self.db.flush() + + logger.info(f"Pipeline {self.pipeline_id} started successfully") + + # Allow controllable coordination logic. By default, we want to coordinate + # immediately after starting to enqueue independent jobs. However, if a job + # has already been enqueued and is beginning execution and starts the pipeline, + # as a result of its job management decorator, we want to skip coordination here + # so we do not double-enqueue jobs. + if coordinate: + await self.coordinate_pipeline() + + async def coordinate_pipeline(self) -> None: + """Coordinate pipeline after a job completes. + + This is the main coordination entry point called after jobs complete. + It updates pipeline status and enqueues ready jobs or cancels remaining jobs + based on the completion result. The method operates on the entire pipeline + state rather than tracking individual job completions. + + Raises: + DatabaseConnectionError: Cannot query job or pipeline info + PipelineStateError: Cannot update pipeline state + PipelineCoordinationError: Failed to enqueue jobs or cancel remaining jobs + JobStateError: Critical job state persistence failure + JobTransitionError: Job cannot be transitioned from current state to new state + + + Example: + >>> # Called after successful job completion + >>> await pipeline_manager.coordinate_pipeline() + """ + new_status = self.transition_pipeline_status() + self.db.flush() + + if new_status in CANCELLED_PIPELINE_STATUSES: + self.cancel_remaining_jobs(reason="Pipeline failed or cancelled") + + # Only enqueue new jobs if pipeline is running + if new_status in RUNNING_PIPELINE_STATUSES: + await self.enqueue_ready_jobs() + + # After enqueuing jobs, re-evaluate pipeline status in case it changed. + # We only expect the status to change if jobs with unsatisfiable dependencies were skipped. + self.transition_pipeline_status() + self.db.flush() + + def transition_pipeline_status(self) -> PipelineStatus: + """Update pipeline status based on current job states. + + Analyzes the status distribution of all jobs in the pipeline to determine + the appropriate pipeline status. Updates pipeline status and finished_at + timestamp when the status changes to a terminal state. + + Returns: + PipelineStatus: The current pipeline status after update. If unchanged, the + previous status is returned. 
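The ``coordinate`` flag exists to avoid double-enqueueing when an already-running job starts the pipeline. A hedged sketch of the two call sites (the decorator call site is assumed from the comment above):

```python
from mavedb.worker.lib.managers.pipeline_manager import PipelineManager


async def start_externally(manager: PipelineManager) -> None:
    # External caller: mark the pipeline RUNNING and enqueue independent jobs.
    await manager.start_pipeline()


async def start_from_running_job(manager: PipelineManager) -> None:
    # A job that is already executing (e.g. under with_pipeline_management)
    # marks the pipeline RUNNING but skips coordination, since the job was
    # enqueued out-of-band; coordination runs when the job completes.
    await manager.start_pipeline(coordinate=False)
```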
+ + Raises: + DatabaseConnectionError: Cannot query job statuses or pipeline info + JobStateError: Cannot update pipeline status or corrupted job data + + Status Logic: + - FAILED: Any job has FAILED status + - RUNNING: Any job is RUNNING or QUEUED + - SUCCEEDED: All jobs are SUCCEEDED + - PARTIAL: Mix of SUCCEEDED/SKIPPED/CANCELLED with no FAILED/RUNNING + - CANCELLED: All remaining jobs are CANCELLED + - No Change: If pipeline is PAUSED, CANCELLED, or has no jobs: status remains unchanged + + Example: + >>> new_status = pipeline_manager.transition_pipeline_status() + >>> print(f"Pipeline status is now {new_status}") + """ + pipeline = self.get_pipeline() + status_counts = self.get_job_counts_by_status() + + old_status = pipeline.status + try: + total_jobs = sum(status_counts.values()) + if old_status in TERMINAL_PIPELINE_STATUSES: + logger.debug(f"Pipeline {self.pipeline_id} is in terminal status {old_status}; skipping update") + return old_status # No change from terminal state + + if old_status == PipelineStatus.PAUSED: + logger.debug(f"Pipeline {self.pipeline_id} is paused; skipping status update") + return old_status # No change from paused state + + # The pipeline must not be in a terminal state (from above), but has no jobs. Consider it complete. + if total_jobs == 0: + logger.debug(f"No jobs found in pipeline {self.pipeline_id} - considering pipeline complete") + + self.set_pipeline_status(PipelineStatus.SUCCEEDED) + return PipelineStatus.SUCCEEDED + + except (AttributeError, TypeError, KeyError, ValueError) as e: + logger.debug(f"Invalid job status data for pipeline {self.pipeline_id}: {e}") + raise PipelineStateError(f"Corrupted job status data for pipeline {self.pipeline_id}: {e}") + + # The pipeline is not in a terminal state and has jobs - determine new status + try: + if status_counts.get(JobStatus.FAILED, 0) > 0: + new_status = PipelineStatus.FAILED + elif status_counts.get(JobStatus.RUNNING, 0) > 0 or status_counts.get(JobStatus.QUEUED, 0) > 0: + new_status = PipelineStatus.RUNNING + + # Pending jobs still exist, don't change the status. + # These might be picked up soon, or they may be proactively + # skipped later if dependencies cannot be met. + # + # Although there is a tension between having only pending + # and succeeded jobs (which would suggest partial/succeeded), + # we leave the status as-is until jobs are actually processed. 
+ # + # *A pipeline with a terminal status must not have pending jobs* + elif status_counts.get(JobStatus.PENDING, 0) > 0: + new_status = old_status + + elif status_counts.get(JobStatus.SUCCEEDED, 0) > 0: + succeeded_jobs = status_counts.get(JobStatus.SUCCEEDED, 0) + skipped_jobs = status_counts.get(JobStatus.SKIPPED, 0) + cancelled_jobs = status_counts.get(JobStatus.CANCELLED, 0) + + if succeeded_jobs == total_jobs: + new_status = PipelineStatus.SUCCEEDED + logger.debug(f"All jobs succeeded in pipeline {self.pipeline_id}") + elif (succeeded_jobs + skipped_jobs + cancelled_jobs) == total_jobs: + new_status = PipelineStatus.PARTIAL + logger.debug(f"Pipeline {self.pipeline_id} completed partially: {status_counts}") + else: + new_status = PipelineStatus.PARTIAL + logger.warning(f"Inconsistent job counts detected for pipeline {self.pipeline_id}: {status_counts}") + send_slack_message( + f"Inconsistent job counts detected for pipeline {self.pipeline_id}: {status_counts}" + ) + + else: + new_status = PipelineStatus.CANCELLED + + if pipeline.status != new_status: + self.set_pipeline_status(new_status) + + except (AttributeError, TypeError, KeyError, ValueError) as e: + logger.debug(f"Object manipulation failed updating pipeline status for {self.pipeline_id}: {e}") + raise PipelineStateError(f"Failed to update pipeline status for {self.pipeline_id}: {e}") + + if new_status != old_status: + logger.info(f"Pipeline {self.pipeline_id} status successfully updated to {new_status} from {old_status}") + else: + logger.debug(f"No status change for pipeline {self.pipeline_id} (remains {old_status})") + + return new_status + + async def enqueue_ready_jobs(self) -> None: + """Find and enqueue all jobs that are ready to run. + + Identifies pending jobs in the pipeline (including retries) whose dependencies + are satisfied, updates their status to QUEUED, and enqueues them in ARQ. + This handles both independent jobs and jobs with dependencies, as well as + jobs that have been prepared for retry. + + Does not enqueue jobs if the pipeline is paused. + + Raises: + DatabaseConnectionError: Cannot query pending jobs or job dependencies + JobStateError: Cannot update job state to QUEUED (critical failure) + PipelineCoordinationError: One or more jobs failed to enqueue in ARQ + + Process: + 1. Ensure pipeline is running (skip enqueues if not) + 2. Query all PENDING jobs in pipeline (includes retries) + 3. Check dependency requirements for each job + 4. For jobs ready to run: flush status change and enqueue in ARQ + + Note: + - This method handles both independent and dependent jobs uniformly - + any job in PENDING status that meets its dependency requirements + (including jobs with no dependencies) will be enqueued, unless the + pipeline is paused. + + Examples: + Basic usage: + >>> # Enqueue all ready jobs in the pipeline + >>> await pipeline_manager.enqueue_ready_jobs() + + Handling coordination errors: + >>> try: + ... await pipeline_manager.enqueue_ready_jobs() + ... except PipelineCoordinationError as e: + ... logger.error(f"Failed to enqueue some jobs: {e}") + ... 
# Optionally cancel pipeline or take other recovery actions + """ + current_status = self.get_pipeline_status() + if current_status not in RUNNING_PIPELINE_STATUSES: + logger.error(f"Pipeline {self.pipeline_id} is not running - skipping job enqueue") + raise PipelineStateError( + f"Pipeline {self.pipeline_id} is in status {current_status} and cannot enqueue jobs" + ) + + jobs_to_queue: list[JobRun] = [] + for job in self.get_pending_jobs(): + job_manager = JobManager(self.db, self.redis, job.id) + + # Attempt to enqueue the job if dependencies are met + if self.can_enqueue_job(job): + job_manager.prepare_queue() + jobs_to_queue.append(job) + continue + + should_skip, reason = self.should_skip_job_due_to_dependencies(job) + if should_skip: + job_manager.update_status_message(f"Job skipped: {reason}") + job_manager.skip_job( + { + "status": "skipped", + "exception": None, + "data": {"result": reason, "timestamp": datetime.now().isoformat()}, + } + ) + logger.info(f"Skipped job {job.urn} due to unreachable dependencies: {reason}") + continue + + # Ensure enqueued jobs can view the status change and pipelines + # can view skipped jobs by flushing transactions. + self.db.flush() + + if not jobs_to_queue: + logger.debug(f"No ready jobs to enqueue in pipeline {self.pipeline_id}") + return + + successfully_enqueued = [] + for job in jobs_to_queue: + await self._enqueue_in_arq(job, is_retry=False) + successfully_enqueued.append(job.urn) + logger.info(f"Successfully enqueued job {job.urn}") + + logger.info(f"Successfully enqueued {len(successfully_enqueued)} jobs: {successfully_enqueued}.") + + def cancel_remaining_jobs(self, reason: str = "Pipeline cancelled") -> None: + """Cancel all remaining jobs in the pipeline when the pipeline fails. + + Finds all active pipeline jobs and marks them as SKIPPED or CANCELLED + to prevent further execution when the pipeline has failed. Records the + cancellation reason and timestamp for audit purposes. + + Args: + reason: Human-readable reason for cancellation + + Raises: + DatabaseConnectionError: Cannot query jobs to cancel + PipelineCoordinationError: Failed to cancel one or more jobs + """ + remaining_jobs = self.get_active_jobs() + if not remaining_jobs: + logger.debug(f"No jobs to cancel in pipeline {self.pipeline_id}") + else: + bulk_cancellation_result = construct_bulk_cancellation_result(reason) + + for job in remaining_jobs: + job_manager = JobManager(self.db, self.redis, job.id) + + # Skip PENDING jobs, cancel RUNNING/QUEUED jobs + if job_manager.get_job_status() == JobStatus.PENDING: + job_manager.skip_job(result=bulk_cancellation_result) + logger.debug(f"Skipped job {job.urn}: {reason}") + else: + job_manager.cancel_job(result=bulk_cancellation_result) + logger.debug(f"Cancelled job {job.urn}: {reason}") + + logger.info(f"Cancelled all remaining jobs in pipeline {self.pipeline_id}") + + async def cancel_pipeline(self, reason: str = "Pipeline cancelled") -> None: + """Cancel the entire pipeline and all remaining jobs. + + Sets the pipeline status to CANCELLED and cancels all PENDING and QUEUED + jobs in the pipeline. Records the cancellation reason for audit purposes. + + Args: + reason: Human-readable reason for pipeline cancellation + + Raises: + DatabaseConnectionError: Cannot query or update pipeline/jobs + PipelineCoordinationError: Failed to cancel pipeline or jobs + + Example: + >>> # Cancel a running pipeline due to external event + >>> await pipeline_manager.cancel_pipeline( + ... reason="User requested cancellation" + ... 
) + """ + current_status = self.get_pipeline_status() + + if current_status in TERMINAL_PIPELINE_STATUSES: + logger.error(f"Pipeline {self.pipeline_id} is already in terminal status {current_status}") + raise PipelineTransitionError( + f"Pipeline {self.pipeline_id} is in terminal state {current_status} and may not be cancelled" + ) + + self.set_pipeline_status(PipelineStatus.CANCELLED) + self.db.flush() + logger.info(f"Pipeline {self.pipeline_id} cancelled: {reason}") + + await self.coordinate_pipeline() + + async def pause_pipeline(self, reason: str = "Pipeline paused") -> None: + """Pause the pipeline to stop further job execution. + + Sets the pipeline status to PAUSED, preventing new jobs from being enqueued + while allowing currently running jobs to complete. This provides a way to + temporarily halt pipeline execution without cancelling remaining jobs. + + Args: + reason: Human-readable reason for pausing the pipeline + + Raises: + DatabaseConnectionError: Cannot query or update pipeline + JobStateError: Cannot update pipeline state + PipelineTransitionError: Pipeline cannot be paused due to current state + + Example: + >>> # Pause pipeline for maintenance + >>> was_paused = manager.pause_pipeline( + ... reason="System maintenance" + ... ) + """ + current_status = self.get_pipeline_status() + + if current_status in TERMINAL_PIPELINE_STATUSES: + logger.error(f"Pipeline {self.pipeline_id} cannot be paused (current status: {current_status})") + raise PipelineTransitionError( + f"Pipeline {self.pipeline_id} is in terminal state {current_status} and may not be paused" + ) + + if current_status == PipelineStatus.PAUSED: + logger.error(f"Pipeline {self.pipeline_id} is already paused") + raise PipelineTransitionError(f"Pipeline {self.pipeline_id} is already paused") + + self.set_pipeline_status(PipelineStatus.PAUSED) + self.db.flush() + + logger.info(f"Pipeline {self.pipeline_id} paused (was {current_status}): {reason}") + await self.coordinate_pipeline() + + async def unpause_pipeline(self, reason: str = "Pipeline unpaused") -> None: + """Unpause the pipeline and resume job execution. + + Sets the pipeline status from PAUSED back to RUNNING and enqueues any + jobs that are ready to run. This resumes normal pipeline execution + after a pause. + + Args: + reason: Human-readable reason for unpausing the pipeline + + Raises: + DatabaseConnectionError: Cannot query or update pipeline + PipelineStateError: Cannot update pipeline state + PipelineCoordinationError: Failed to enqueue ready jobs after unpause + + Example: + >>> # Resume pipeline after maintenance + >>> was_unpaused = await manager.unpause_pipeline( + ... reason="Maintenance complete" + ... ) + """ + current_status = self.get_pipeline_status() + + if current_status != PipelineStatus.PAUSED: + logger.error( + f"Pipeline {self.pipeline_id} is not paused (current status: {current_status}) and may not be unpaused" + ) + raise PipelineTransitionError( + f"Pipeline {self.pipeline_id} is not paused (current status: {current_status}) and may not be unpaused" + ) + + self.set_pipeline_status(PipelineStatus.RUNNING) + self.db.flush() + + logger.info(f"Pipeline {self.pipeline_id} unpaused (was {current_status}): {reason}") + await self.coordinate_pipeline() + + async def restart_pipeline(self) -> None: + """Restart the entire pipeline from the beginning. + + Resets ALL jobs in the pipeline to PENDING status, resets pipeline state to RUNNING, and re-enqueues + independent jobs. This is useful for recovering from pipeline-wide issues. 
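Pause and resume are status-only transitions; running jobs are allowed to drain. A short illustration of guarding them against invalid states, using only the methods and exceptions shown here:

```python
from mavedb.worker.lib.managers.exceptions import PipelineTransitionError
from mavedb.worker.lib.managers.pipeline_manager import PipelineManager


async def maintenance_window(manager: PipelineManager) -> None:
    try:
        await manager.pause_pipeline(reason="System maintenance")
    except PipelineTransitionError:
        return  # already paused or in a terminal state

    # ... perform maintenance while currently running jobs complete ...

    await manager.unpause_pipeline(reason="Maintenance complete")
```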
+
+        Raises:
+            PipelineCoordinationError: If restart operations fail
+            DatabaseConnectionError: If database operations fail
+
+        Example:
+            >>> await manager.restart_pipeline()
+        """
+        all_jobs = self.get_all_jobs()
+        if not all_jobs:
+            logger.debug(f"No jobs found for pipeline {self.pipeline_id} restart")
+            return
+
+        # Reset all jobs to PENDING status
+        for job in all_jobs:
+            job_manager = JobManager(self.db, self.redis, job.id)
+            job_manager.reset_job()
+
+        # Reset pipeline status to created
+        self.set_pipeline_status(PipelineStatus.CREATED)
+        self.db.flush()
+
+        logger.info(f"Pipeline {self.pipeline_id} reset for restart successfully")
+        await self.start_pipeline()
+
+    def can_enqueue_job(self, job: JobRun) -> bool:
+        """Check if a job can be enqueued based on dependency requirements.
+
+        Validates that all job dependencies are satisfied according to their
+        dependency types before allowing enqueue. Prevents premature execution
+        of jobs that depend on incomplete predecessors.
+
+        Args:
+            job: JobRun instance to check dependencies for
+
+        Returns:
+            bool: True if all dependencies are satisfied and job can be enqueued,
+                False if dependencies are still pending
+
+        Raises:
+            DatabaseConnectionError: Cannot query job dependencies
+            PipelineStateError: Corrupted dependency data detected
+
+        Dependency Types:
+            - SUCCESS_REQUIRED: Dependent job must have SUCCEEDED status
+            - COMPLETION_REQUIRED: Dependent job must be SUCCEEDED or FAILED
+        """
+        for dependency, dependent_job in self.get_dependencies_for_job(job):
+            try:
+                if not job_dependency_is_met(
+                    dependency_type=dependency.dependency_type,
+                    dependent_job_status=dependent_job.status,
+                ):
+                    logger.debug(f"Job {job.urn} cannot be enqueued; dependency on job {dependent_job.urn} not met")
+                    return False
+
+            except (AttributeError, KeyError, TypeError, ValueError) as e:
+                logger.debug(f"Invalid dependency data detected for job {job.id}: {e}")
+                raise PipelineStateError(f"Corrupted dependency data during enqueue check for job {job.id}: {e}")
+
+        logger.debug(f"All dependencies satisfied for job {job.urn}; ready to enqueue")
+        return True
+
+    def should_skip_job_due_to_dependencies(self, job: JobRun) -> tuple[bool, str]:
+        """Check if a job's dependencies are unsatisfiable and the job should be skipped.
+
+        Validates whether a job's dependencies can still be met based on the
+        current status of dependent jobs. This helps identify jobs that should
+        be skipped because their dependencies are in terminal non-success states.
+
+        Args:
+            job: JobRun instance to check dependencies for
+
+        Returns:
+            tuple[bool, str]: (True, reason) if dependencies cannot be met and job
+                              should be skipped, (False, "") if dependencies may
+                              still be satisfied
+
+        Raises:
+            DatabaseConnectionError: Cannot query job dependencies
+            PipelineStateError: Corrupted dependency data detected
+
+        Notes:
+            - A job is considered unreachable if any of its dependencies that
+              require SUCCESS have FAILED, SKIPPED, or CANCELLED status.
+            - A job is considered unreachable if any of its dependencies that
+              require COMPLETION have SKIPPED or CANCELLED status.
+
+        Examples:
+            Basic usage:
+                >>> should_skip, reason = manager.should_skip_job_due_to_dependencies(job)
+                >>> if should_skip:
+                ...     print(f"Job should be skipped: {reason}")
+                >>> else:
+                ...
print("Job dependencies may still be satisfied") + """ + for dependency, dep_job in self.get_dependencies_for_job(job): + try: + should_skip, reason = job_should_be_skipped_due_to_unfulfillable_dependency( + dependency_type=dependency.dependency_type, + dependent_job_status=dep_job.status, + ) + + if should_skip: + logger.debug(f"Job {job.urn} should be skipped due to dependency on job {dep_job.urn}: {reason}") + # guaranteed to be str if should_skip is True + return True, reason # type: ignore + + except (AttributeError, KeyError, TypeError, ValueError) as e: + logger.debug(f"Invalid dependency data detected for job {job.id}: {e}") + raise PipelineStateError(f"Corrupted dependency data during skip check for job {job.id}: {e}") + + logger.debug(f"Job {job.urn} dependencies may still be satisfied; not skipping") + return False, "" + + async def retry_failed_jobs(self) -> None: + """Retry all failed jobs in the pipeline. + + Resets failed jobs to PENDING status and re-enqueues them for execution. + Only affects jobs with FAILED status; other jobs remain unchanged. + + Raises: + PipelineCoordinationError: If job retry fails + DatabaseConnectionError: If database operations fail + + Example: + >>> await manager.retry_failed_jobs() + >>> print("Successfully retried failed jobs") + """ + failed_jobs = self.get_failed_jobs() + if not failed_jobs: + logger.debug(f"No failed jobs found for pipeline {self.pipeline_id}") + return + + for job in failed_jobs: + job_manager = JobManager(self.db, self.redis, job.id) + job_manager.prepare_retry() + + # Ensure the pipeline status is set to running so jobs are picked up + self.set_pipeline_status(PipelineStatus.RUNNING) + self.db.flush() + + await self.coordinate_pipeline() + + async def retry_unsuccessful_jobs(self) -> None: + """Retry all unsuccessful jobs in the pipeline. + + Resets unsuccessful jobs (CANCELLED, SKIPPED, FAILED) to PENDING status + and re-enqueues them for execution. This is useful for recovering from + partial failures or interruptions. + + Raises: + PipelineCoordinationError: If job retry fails + DatabaseConnectionError: If database operations fail + + Example: + >>> await manager.retry_unsuccessful_jobs() + >>> print("Successfully retried unsuccessful jobs") + """ + unsuccessful_jobs = self.get_unsuccessful_jobs() + if not unsuccessful_jobs: + logger.debug(f"No unsuccessful jobs found for pipeline {self.pipeline_id}") + return + + for job in unsuccessful_jobs: + job_manager = JobManager(self.db, self.redis, job.id) + job_manager.prepare_retry() + + # Ensure the pipeline status is set to running so jobs are picked up + self.set_pipeline_status(PipelineStatus.RUNNING) + self.db.flush() + + await self.coordinate_pipeline() + + async def retry_pipeline(self) -> None: + """Retry all unsuccessful jobs in the pipeline. + + Convenience method to retry all jobs that did not complete successfully, + including CANCELLED, SKIPPED, and FAILED jobs. Resets their status to PENDING + and re-enqueues them for execution. + + This is equivalent to calling `retry_unsuccessful_jobs` but provides a clearer + semantic for pipeline-level retries. + """ + await self.retry_unsuccessful_jobs() + + def get_jobs_by_status(self, status: list[JobStatus]) -> Sequence[JobRun]: + """Get all jobs in the pipeline with a specific status. 
+ + Args: + status: JobStatus to filter jobs by + + Returns: + Sequence[JobRun]: List of jobs with the specified status ordered by creation time + + Raises: + DatabaseConnectionError: Cannot query job information + + Example: + >>> running_jobs = manager.get_jobs_by_status([JobStatus.RUNNING]) + >>> print(f"Found {len(running_jobs)} running jobs") + """ + try: + return ( + self.db.execute( + select(JobRun) + .where(and_(JobRun.pipeline_id == self.pipeline_id, JobRun.status.in_(status))) + .order_by(JobRun.created_at) + ) + .scalars() + .all() + ) + except SQLAlchemyError as e: + logger.debug( + f"Database query failed getting jobs with status {status} for pipeline {self.pipeline_id}: {e}" + ) + raise DatabaseConnectionError(f"Failed to get jobs with status {status}: {e}") + + def get_pending_jobs(self) -> Sequence[JobRun]: + """Get all PENDING jobs in the pipeline. + + Convenience method for fetching all pending jobs. This is equivalent + to calling get_jobs_by_status([JobStatus.PENDING]) but provides + clearer intent and a more focused API. + + Returns: + Sequence[JobRun]: List of pending jobs ordered by creation time + + Raises: + DatabaseConnectionError: Cannot query job information + + Example: + >>> pending_jobs = manager.get_pending_jobs() + >>> print(f"Found {len(pending_jobs)} pending jobs") + """ + return self.get_jobs_by_status([JobStatus.PENDING]) + + def get_running_jobs(self) -> Sequence[JobRun]: + """Get all RUNNING jobs in the pipeline. + + Convenience method for fetching all running jobs. This is equivalent + to calling get_jobs_by_status([JobStatus.RUNNING]) but provides + clearer intent and a more focused API. + + Returns: + Sequence[JobRun]: List of running jobs ordered by creation time + + Raises: + DatabaseConnectionError: Cannot query job information + + Example: + >>> running_jobs = manager.get_running_jobs() + >>> print(f"Found {len(running_jobs)} running jobs") + """ + return self.get_jobs_by_status([JobStatus.RUNNING]) + + def get_active_jobs(self) -> Sequence[JobRun]: + """Get all active jobs in the pipeline. + + Convenience method for fetching all active jobs. This is equivalent + to calling get_jobs_by_status(ACTIVE_JOB_STATUSES) but provides + clearer intent and a more focused API. + + Returns: + Sequence[JobRun]: List of remaining jobs ordered by creation time + + Raises: + DatabaseConnectionError: Cannot query job information + + Example: + >>> active_jobs = manager.get_active_jobs() + >>> print(f"Found {len(active_jobs)} active jobs") + """ + return self.get_jobs_by_status(ACTIVE_JOB_STATUSES) + + def get_failed_jobs(self) -> Sequence[JobRun]: + """Get all failed jobs in the pipeline. + + Convenience method for fetching all failed jobs. This is equivalent + to calling get_jobs_by_status([JobStatus.FAILED]) but provides + clearer intent and a more focused API. + + Returns: + Sequence[JobRun]: List of failed jobs ordered by creation time + + Raises: + DatabaseConnectionError: Cannot query job information + + Example: + >>> failed_jobs = manager.get_failed_jobs() + >>> print(f"Found {len(failed_jobs)} failed jobs for potential retry") + """ + return self.get_jobs_by_status([JobStatus.FAILED]) + + def get_unsuccessful_jobs(self) -> Sequence[JobRun]: + """Get all unsuccessful jobs in the pipeline. + + Convenience method for fetching all unsuccessful (but terminated) jobs. This is equivalent + to calling get_jobs_by_status([JobStatus.FAILED, JobStatus.CANCELLED, JobStatus.SKIPPED]) + but provides clearer intent and a more focused API. 
+
+        Returns:
+            Sequence[JobRun]: List of unsuccessful jobs ordered by creation time
+
+        Raises:
+            DatabaseConnectionError: Cannot query job information
+
+        Example:
+            >>> unsuccessful_jobs = manager.get_unsuccessful_jobs()
+            >>> print(f"Found {len(unsuccessful_jobs)} unsuccessful jobs")
+        """
+        return self.get_jobs_by_status(CANCELLED_JOB_STATUSES)
+
+    def get_all_jobs(self) -> Sequence[JobRun]:
+        """Get all jobs in the pipeline regardless of status.
+
+        Returns:
+            Sequence[JobRun]: List of all jobs in pipeline ordered by creation time
+
+        Raises:
+            DatabaseConnectionError: Cannot query job information
+
+        Examples:
+            >>> all_jobs = manager.get_all_jobs()
+            >>> print(f"Total jobs in pipeline: {len(all_jobs)}")
+        """
+        try:
+            return (
+                self.db.execute(
+                    select(JobRun).where(JobRun.pipeline_id == self.pipeline_id).order_by(JobRun.created_at)
+                )
+                .scalars()
+                .all()
+            )
+        except SQLAlchemyError as e:
+            logger.debug(f"Database query failed getting all jobs for pipeline {self.pipeline_id}: {e}")
+            raise DatabaseConnectionError(f"Failed to get all jobs: {e}")
+
+    def get_dependencies_for_job(self, job: JobRun) -> Sequence[tuple[JobDependency, JobRun]]:
+        """Get all dependencies for a specific job.
+
+        Args:
+            job: JobRun instance to fetch dependencies for
+
+        Returns:
+            Sequence[tuple[JobDependency, JobRun]]: List of dependencies with associated JobRun instances
+
+        Raises:
+            DatabaseConnectionError: Cannot query job dependencies
+
+        Examples:
+            >>> dependencies = manager.get_dependencies_for_job(job)
+            >>> for dependency, dep_job in dependencies:
+            ...     print(f"Job {job.urn} depends on job {dep_job.urn} with dependency type {dependency.dependency_type}")
+        """
+        try:
+            # Although the returned type wraps tuples in a row, the contents are still accessible as tuples.
+            # This allows unpacking as shown in the example, and we can ignore the type checker warning so
+            # callers can have access to the simpler interface.
+            return self.db.execute(
+                select(JobDependency, JobRun)
+                .join(JobRun, JobDependency.depends_on_job_id == JobRun.id)
+                .where(JobDependency.job_id == job.id)
+            ).all()  # type: ignore
+        except SQLAlchemyError as e:
+            logger.debug(f"SQL query failed for dependencies of job {job.id}: {e}")
+            raise DatabaseConnectionError(f"Failed to get job dependencies for job {job.id}: {e}")
+
+    def get_pipeline(self) -> Pipeline:
+        """Get the Pipeline instance for this manager.
+
+        Returns:
+            Pipeline: The Pipeline instance associated with this manager
+
+        Raises:
+            DatabaseConnectionError: Cannot query pipeline information
+
+        Examples:
+            >>> pipeline = manager.get_pipeline()
+            >>> print(f"Pipeline ID: {pipeline.id}, Status: {pipeline.status}")
+        """
+
+        try:
+            return self.db.execute(select(Pipeline).where(Pipeline.id == self.pipeline_id)).scalar_one()
+        except SQLAlchemyError as e:
+            logger.debug(f"Database query failed getting pipeline {self.pipeline_id}: {e}")
+            raise DatabaseConnectionError(f"Failed to get pipeline {self.pipeline_id}: {e}")
+
+    def get_job_counts_by_status(self) -> dict[JobStatus, int]:
+        """Get count of jobs by status for monitoring.
+
+        Returns a simple dictionary mapping job statuses to their counts,
+        useful for dashboard displays and monitoring systems.
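A small illustration of turning these counts into a one-line summary for logs or dashboards (the function name is hypothetical; the "finished" definition mirrors completed_jobs in get_pipeline_progress below):

```python
from mavedb.models.enums.job_pipeline import JobStatus
from mavedb.worker.lib.managers.pipeline_manager import PipelineManager


def summarize_pipeline(manager: PipelineManager) -> str:
    counts = manager.get_job_counts_by_status()
    total = sum(counts.values())
    finished = sum(
        counts.get(s, 0)
        for s in (JobStatus.SUCCEEDED, JobStatus.FAILED, JobStatus.SKIPPED, JobStatus.CANCELLED)
    )
    return f"{finished}/{total} jobs finished ({counts.get(JobStatus.FAILED, 0)} failed)"
```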
+ + Returns: + dict[JobStatus, int]: Dictionary mapping JobStatus to count + + Raises: + DatabaseConnectionError: Cannot query job information + + Example: + >>> counts = manager.get_job_counts_by_status() + >>> print(f"Failed jobs: {counts.get(JobStatus.FAILED, 0)}") + """ + try: + job_counts = self.db.execute( + select(JobRun.status, func.count(JobRun.id)) + .where(JobRun.pipeline_id == self.pipeline_id) + .group_by(JobRun.status) + ).all() + except SQLAlchemyError as e: + logger.debug(f"Database query failed getting job counts for pipeline {self.pipeline_id}: {e}") + raise DatabaseConnectionError(f"Failed to get job counts for pipeline {self.pipeline_id}: {e}") + + return {status: count for status, count in job_counts} + + def get_pipeline_progress(self) -> dict: + """Get detailed pipeline progress statistics. + + Provides comprehensive pipeline progress information including job counts, + completion percentage, duration, and estimated completion time. + + Returns: + dict: Pipeline progress statistics with the following keys: + - total_jobs: Total number of jobs in pipeline + - completed_jobs: Number of jobs in terminal states + - successful_jobs: Number of successfully completed jobs + - failed_jobs: Number of failed jobs + - running_jobs: Number of currently running jobs + - pending_jobs: Number of jobs waiting to run + - completion_percentage: Percentage of jobs completed (0-100) + - duration: Time pipeline has been running (in seconds) + - status_counts: Dictionary of job counts by status + + Raises: + DatabaseConnectionError: Cannot query pipeline or job information + + Example: + >>> progress = manager.get_pipeline_progress() + >>> print(f"Pipeline {progress['completion_percentage']:.1f}% complete") + """ + status_counts = self.get_job_counts_by_status() + pipeline = self.get_pipeline() + + try: + total_jobs = sum(status_counts.values()) + + if total_jobs == 0: + return { + "total_jobs": 0, + "completed_jobs": 0, + "successful_jobs": 0, + "failed_jobs": 0, + "running_jobs": 0, + "pending_jobs": 0, + "completion_percentage": 100.0, + "duration": 0, + "status_counts": {}, + } + + # Calculate progress metrics + successful_jobs = status_counts.get(JobStatus.SUCCEEDED, 0) + failed_jobs = status_counts.get(JobStatus.FAILED, 0) + running_jobs = status_counts.get(JobStatus.RUNNING, 0) + status_counts.get(JobStatus.QUEUED, 0) + pending_jobs = status_counts.get(JobStatus.PENDING, 0) + skipped_jobs = status_counts.get(JobStatus.SKIPPED, 0) + cancelled_jobs = status_counts.get(JobStatus.CANCELLED, 0) + + completed_jobs = successful_jobs + failed_jobs + skipped_jobs + cancelled_jobs + completion_percentage = (completed_jobs / total_jobs) * 100 if total_jobs > 0 else 0 + + # Calculate duration + duration = 0 + if pipeline.created_at: + end_time = pipeline.finished_at or datetime.now() + duration = int((end_time - pipeline.created_at).total_seconds()) + + except (AttributeError, TypeError, KeyError, ValueError) as e: + logger.debug(f"Invalid data detected calculating progress for pipeline {self.pipeline_id}: {e}") + raise PipelineStateError(f"Corrupted data during progress calculation for pipeline {self.pipeline_id}: {e}") + + return { + "total_jobs": total_jobs, + "completed_jobs": completed_jobs, + "successful_jobs": successful_jobs, + "failed_jobs": failed_jobs, + "running_jobs": running_jobs, + "pending_jobs": pending_jobs, + "completion_percentage": completion_percentage, + "duration": duration, + "status_counts": status_counts, + } + + def get_pipeline_status(self) -> PipelineStatus: + 
"""Get the current status of the pipeline. + + Returns: + PipelineStatus: Current status of the pipeline + + Raises: + DatabaseConnectionError: Cannot query pipeline information + + Example: + >>> status = manager.get_pipeline_status() + >>> print(f"Pipeline status: {status}") + """ + return self.get_pipeline().status + + def set_pipeline_status(self, new_status: PipelineStatus) -> None: + """Set the status of the pipeline. + + Args: + new_status: PipelineStatus enum value to set the pipeline to + + Raises: + DatabaseConnectionError: Cannot query or update pipeline information + PipelineStateError: Cannot update pipeline status + + Example: + >>> manager.set_pipeline_status(PipelineStatus.PAUSED) + >>> print("Pipeline paused") + + Note: + This method does not perform any validation on the status transition, + nor does it attempt to coordinate the pipeline after the status change + or flush the change to the database. + """ + pipeline = self.get_pipeline() + try: + pipeline.status = new_status + + # Ensure finished_at is set/cleared appropriately + if new_status in TERMINAL_PIPELINE_STATUSES: + pipeline.finished_at = datetime.now() + else: + pipeline.finished_at = None + + # Ensure started_at is set/cleared appropriately + if new_status == PipelineStatus.CREATED: + pipeline.started_at = None + elif new_status == PipelineStatus.RUNNING and pipeline.started_at is None: + pipeline.started_at = datetime.now() + + except (AttributeError, TypeError, KeyError, ValueError) as e: + logger.debug(f"Object manipulation failed setting status for pipeline {self.pipeline_id}: {e}") + raise PipelineStateError(f"Failed to set pipeline status for {self.pipeline_id}: {e}") + + logger.info(f"Pipeline {self.pipeline_id} status set to {new_status}") + + async def _enqueue_in_arq(self, job: JobRun, is_retry: bool) -> None: + """Enqueue a job in ARQ with proper error handling and retry delay. 
+ + Args: + job: JobRun instance to enqueue + is_retry: Whether this is a retry attempt + + Raises: + PipelineCoordinationError: If ARQ enqueuing fails + """ + if not self.redis: + logger.error(f"Redis client is not configured for PipelineManager; cannot enqueue job {job.urn}") + raise PipelineCoordinationError("Redis client is not configured for job enqueueing; cannot proceed.") + + try: + defer_by = timedelta(seconds=job.retry_delay_seconds if is_retry and job.retry_delay_seconds else 0) + arq_success = await self.redis.enqueue_job(job.job_function, job.id, _defer_by=defer_by, _job_id=job.urn) + except Exception as e: + logger.debug(f"ARQ enqueue operation failed for job {job.urn}: {e}") + raise PipelineCoordinationError(f"Failed to enqueue job in ARQ: {e}") + + if arq_success: + logger.info(f"{'Retried' if is_retry else 'Enqueued'} job {job.urn} in ARQ") + else: + logger.info(f"Job {job.urn} has already been enqueued in ARQ") diff --git a/src/mavedb/worker/lib/managers/py.typed b/src/mavedb/worker/lib/managers/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/mavedb/worker/lib/managers/types.py b/src/mavedb/worker/lib/managers/types.py new file mode 100644 index 00000000..475b28a2 --- /dev/null +++ b/src/mavedb/worker/lib/managers/types.py @@ -0,0 +1,26 @@ +from typing import Literal, Optional, TypedDict + + +class JobResultData(TypedDict): + status: Literal["ok", "failed", "skipped", "exception", "cancelled"] + data: dict + exception: Optional[Exception] + + +class RetryHistoryEntry(TypedDict): + attempt: int + timestamp: str + result: JobResultData + reason: str + + +class PipelineProgress(TypedDict): + total_jobs: int + completed_jobs: int + successful_jobs: int + failed_jobs: int + running_jobs: int + pending_jobs: int + completion_percentage: float + duration: int # seconds + status_counts: dict diff --git a/src/mavedb/worker/lib/managers/utils.py b/src/mavedb/worker/lib/managers/utils.py new file mode 100644 index 00000000..975fc7d6 --- /dev/null +++ b/src/mavedb/worker/lib/managers/utils.py @@ -0,0 +1,107 @@ +"""Utility functions for job and pipeline management. + +This module provides helper functions for common operations in job and pipeline +management, such as creating standardized result structures, data formatting, and +dependency checking. +""" + +import logging +from datetime import datetime +from typing import Literal, Optional, Union + +from mavedb.models.enums.job_pipeline import DependencyType, JobStatus +from mavedb.worker.lib.managers.constants import COMPLETED_JOB_STATUSES +from mavedb.worker.lib.managers.types import JobResultData + +logger = logging.getLogger(__name__) + + +def construct_bulk_cancellation_result(reason: str) -> JobResultData: + """Construct a standardized JobResultData structure for bulk job cancellations. + + Args: + reason: Human-readable reason for the cancellation + + Returns: + JobResultData: Standardized result data with cancellation metadata + """ + return { + "status": "cancelled", + "data": { + "reason": reason, + "timestamp": datetime.now().isoformat(), + }, + "exception": None, + } + + +def job_dependency_is_met(dependency_type: Optional[DependencyType], dependent_job_status: JobStatus) -> bool: + """Check if a job dependency is met based on the dependency type and the status of the dependent job. 
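The two dependency types behave as follows; a doctest-style sketch using only the enum members referenced in this module (COMPLETION_REQUIRED membership is defined by COMPLETED_JOB_STATUSES in constants.py):

```python
from mavedb.models.enums.job_pipeline import DependencyType, JobStatus
from mavedb.worker.lib.managers.utils import job_dependency_is_met

# SUCCESS_REQUIRED: only a SUCCEEDED dependency unblocks the job.
assert job_dependency_is_met(DependencyType.SUCCESS_REQUIRED, JobStatus.SUCCEEDED)
assert not job_dependency_is_met(DependencyType.SUCCESS_REQUIRED, JobStatus.FAILED)

# COMPLETION_REQUIRED: any completed status unblocks it (per the docstrings
# above, this includes FAILED).

# No dependency type recorded: treated as met.
assert job_dependency_is_met(None, JobStatus.PENDING)
```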
+ + Args: + dependency_type: Type of dependency ('hard' or 'soft') + dependent_job_status: Status of the dependent job + + Returns: + bool: True if the dependency is met, False otherwise + + Notes: + - For 'hard' dependencies, the dependent job must have succeeded. + - For 'soft' dependencies, the dependent job must be in a terminal state. + - If no dependency type is specified, the dependency is considered met. + """ + if not dependency_type: + logger.debug("No dependency type specified; assuming dependency is met.") + return True + + if dependency_type == DependencyType.SUCCESS_REQUIRED: + if dependent_job_status != JobStatus.SUCCEEDED: + logger.debug(f"Dependency not met: dependent job did not succeed ({dependent_job_status}).") + return False + + if dependency_type == DependencyType.COMPLETION_REQUIRED: + if dependent_job_status not in COMPLETED_JOB_STATUSES: + logger.debug( + f"Dependency not met: dependent job has not reached a completed status ({dependent_job_status})." + ) + return False + + return True + + +def job_should_be_skipped_due_to_unfulfillable_dependency( + dependency_type: Optional[DependencyType], dependent_job_status: JobStatus +) -> Union[tuple[Literal[False], None], tuple[Literal[True], str]]: + """Determine if a job should be skipped due to an unfulfillable dependency. + + Args: + dependency_type: Type of dependency ('hard' or 'soft') + dependent_job_status: Status of the dependent job + + Returns: + Union[tuple[Literal[False], None], tuple[Literal[True], str]]: Tuple indicating + if the job should be skipped and the reason + + Notes: + - A job should be skipped if it has a 'hard' dependency and the dependent job did not succeed. + """ + + # If dependency must have SUCCEEDED but is in a terminal non-success state, skip. + if dependency_type == DependencyType.SUCCESS_REQUIRED: + if dependent_job_status in (JobStatus.FAILED, JobStatus.SKIPPED, JobStatus.CANCELLED): + logger.debug( + f"Job should be skipped due to unfulfillable 'success_required' dependency " + f"({dependent_job_status})." + ) + return True, f"Dependency did not succeed ({dependent_job_status})" + + # If dependency requires 'completion' and you want CANCELLED to NOT qualify, skip here too. + if dependency_type == DependencyType.COMPLETION_REQUIRED: + if dependent_job_status in (JobStatus.CANCELLED, JobStatus.SKIPPED): + logger.debug( + f"Job should be skipped due to unfulfillable 'completion_required' dependency " + f"({dependent_job_status})." + ) + return True, f"Dependency was not completed successfully ({dependent_job_status})" + + return False, None diff --git a/src/mavedb/worker/lib/py.typed b/src/mavedb/worker/lib/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/mavedb/worker/pipeline_management.md b/src/mavedb/worker/pipeline_management.md new file mode 100644 index 00000000..02ee5694 --- /dev/null +++ b/src/mavedb/worker/pipeline_management.md @@ -0,0 +1,29 @@ +# Pipeline Management + +Pipeline management in the ARQ worker system allows for the orchestration of complex workflows composed of multiple dependent jobs. Pipelines are coordinated using the `PipelineManager` and the `with_pipeline_management` decorator. + +## Key Concepts +- **Pipeline**: A collection of jobs with defined dependencies and a shared execution context. +- **PipelineManager**: Handles pipeline status, job dependencies, pausing/unpausing, and cancellation. +- **with_pipeline_management**: Decorator that ensures pipeline coordination after job completion. 
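A minimal sketch of this flow, assuming a database session (`db`), an ARQ Redis handle (`redis`), and an existing `pipeline` row are already available:

```python
from mavedb.worker.lib.managers.pipeline_manager import PipelineManager

manager = PipelineManager(db, redis, pipeline_id=pipeline.id)
await manager.start_pipeline()  # CREATED -> RUNNING; enqueues independent jobs

# Each decorated job re-coordinates as it finishes, enqueueing newly
# unblocked jobs and updating pipeline status.
progress = manager.get_pipeline_progress()
print(f"{progress['completion_percentage']:.1f}% complete")
```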
+ +## Usage Patterns +- Use pipelines for workflows that require multiple jobs to run in sequence or with dependencies. +- Each job in a pipeline should be decorated with `with_pipeline_management`. +- Pipelines are defined and started outside the decorator; the decorator only coordinates after job completion. + +## Example +```python +@with_pipeline_management +async def validate_and_map_variants(ctx, ...): + ... +``` + +## Features +- Automatic pipeline status updates +- Dependency management and job coordination +- Robust error handling and logging + +## See Also +- [Job Managers](job_managers.md) +- [Job Decorators](job_decorators.md) diff --git a/src/mavedb/worker/py.typed b/src/mavedb/worker/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/mavedb/worker/settings.py b/src/mavedb/worker/settings.py deleted file mode 100644 index 0a9359d5..00000000 --- a/src/mavedb/worker/settings.py +++ /dev/null @@ -1,94 +0,0 @@ -import os -from concurrent import futures -from datetime import timedelta -from typing import Callable - -from arq.connections import RedisSettings -from arq.cron import CronJob, cron - -from mavedb.data_providers.services import cdot_rest -from mavedb.db.session import SessionLocal -from mavedb.lib.logging.canonical import log_job -from mavedb.worker.jobs import ( - create_variants_for_score_set, - map_variants_for_score_set, - variant_mapper_manager, - refresh_materialized_views, - refresh_published_variants_view, - submit_score_set_mappings_to_ldh, - link_clingen_variants, - poll_uniprot_mapping_jobs_for_score_set, - submit_uniprot_mapping_jobs_for_score_set, - link_gnomad_variants, - submit_score_set_mappings_to_car, -) - -# ARQ requires at least one task on startup. -BACKGROUND_FUNCTIONS: list[Callable] = [ - create_variants_for_score_set, - variant_mapper_manager, - map_variants_for_score_set, - refresh_published_variants_view, - submit_score_set_mappings_to_ldh, - link_clingen_variants, - poll_uniprot_mapping_jobs_for_score_set, - submit_uniprot_mapping_jobs_for_score_set, - link_gnomad_variants, - submit_score_set_mappings_to_car, -] -# In UTC time. Depending on daylight savings time, this will bounce around by an hour but should always be very early in the morning -# for all of the USA. -BACKGROUND_CRONJOBS: list[CronJob] = [ - cron( - refresh_materialized_views, - name="refresh_all_materialized_views", - hour=20, - minute=0, - keep_result=timedelta(minutes=2).total_seconds(), - ) -] - -REDIS_IP = os.getenv("REDIS_IP") or "localhost" -REDIS_PORT = int(os.getenv("REDIS_PORT") or 6379) -REDIS_SSL = (os.getenv("REDIS_SSL") or "false").lower() == "true" - - -RedisWorkerSettings = RedisSettings(host=REDIS_IP, port=REDIS_PORT, ssl=REDIS_SSL) - - -async def startup(ctx): - ctx["pool"] = futures.ProcessPoolExecutor() - - -async def shutdown(ctx): - pass - - -async def on_job_start(ctx): - db = SessionLocal() - db.current_user_id = None - ctx["db"] = db - ctx["hdp"] = cdot_rest() - ctx["state"] = {} - - -async def on_job_end(ctx): - db = ctx["db"] - db.close() - - -class ArqWorkerSettings: - """ - Settings for the ARQ worker. - """ - - on_startup = startup - on_shutdown = shutdown - on_job_start = on_job_start - on_job_end = on_job_end - after_job_end = log_job - redis_settings = RedisWorkerSettings - functions: list = BACKGROUND_FUNCTIONS - cron_jobs: list = BACKGROUND_CRONJOBS - - job_timeout = 5 * 60 * 60 # Keep jobs alive for a long while... 
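The settings refactor below preserves the same worker entry point; a hedged sketch of starting the worker, assuming arq's standard CLI and run_worker interfaces:

```python
# Shell: arq mavedb.worker.settings.ArqWorkerSettings
from arq.worker import run_worker

from mavedb.worker.settings import ArqWorkerSettings

if __name__ == "__main__":
    run_worker(ArqWorkerSettings)
```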
diff --git a/src/mavedb/worker/settings/__init__.py b/src/mavedb/worker/settings/__init__.py new file mode 100644 index 00000000..af2e6a27 --- /dev/null +++ b/src/mavedb/worker/settings/__init__.py @@ -0,0 +1,19 @@ +"""Worker settings configuration. + +This module provides ARQ worker settings organized by concern: +- constants: Environment variable configuration +- redis: Redis connection settings +- lifecycle: Worker startup/shutdown hooks +- worker: Main ARQ worker configuration class + +The settings are designed to be modular and easily testable, +with clear separation between infrastructure and application concerns. +""" + +from .redis import RedisWorkerSettings +from .worker import ArqWorkerSettings + +__all__ = [ + "ArqWorkerSettings", + "RedisWorkerSettings", +] diff --git a/src/mavedb/worker/settings/constants.py b/src/mavedb/worker/settings/constants.py new file mode 100644 index 00000000..b5e8f23d --- /dev/null +++ b/src/mavedb/worker/settings/constants.py @@ -0,0 +1,12 @@ +"""Environment configuration constants for worker settings. + +This module centralizes all environment variable handling for the worker, +providing sensible defaults and type conversion for configuration values. +All worker-related environment variables should be defined here. +""" + +import os + +REDIS_IP = os.getenv("REDIS_IP") or "localhost" +REDIS_PORT = int(os.getenv("REDIS_PORT") or 6379) +REDIS_SSL = (os.getenv("REDIS_SSL") or "false").lower() == "true" diff --git a/src/mavedb/worker/settings/lifecycle.py b/src/mavedb/worker/settings/lifecycle.py new file mode 100644 index 00000000..7e5f933f --- /dev/null +++ b/src/mavedb/worker/settings/lifecycle.py @@ -0,0 +1,44 @@ +"""Worker lifecycle management hooks. + +This module defines the startup, shutdown, and job lifecycle hooks +for the ARQ worker. These hooks manage: +- Process pool for CPU-intensive tasks +- HGVS data provider setup +- Job state initialization and cleanup +""" + +from concurrent import futures + +from mavedb.data_providers.services import cdot_rest + + +def standalone_ctx(): + """Create a standalone worker context dictionary.""" + ctx = {} + ctx["pool"] = futures.ProcessPoolExecutor() + ctx["redis"] = None # Redis connection can be set up here if needed. + ctx["hdp"] = cdot_rest() + ctx["state"] = {} + + # Additional context setup can be added here as needed. + # This function should not drift from the lifecycle hooks + # below and is useful for invoking worker jobs outside of ARQ. + + return ctx + + +async def startup(ctx): + ctx["pool"] = futures.ProcessPoolExecutor() + + +async def shutdown(ctx): + pass + + +async def on_job_start(ctx): + ctx["hdp"] = cdot_rest() + ctx["state"] = {} + + +async def on_job_end(ctx): + pass diff --git a/src/mavedb/worker/settings/redis.py b/src/mavedb/worker/settings/redis.py new file mode 100644 index 00000000..2773f77f --- /dev/null +++ b/src/mavedb/worker/settings/redis.py @@ -0,0 +1,12 @@ +"""Redis connection settings for ARQ worker. + +This module provides Redis connection configuration using environment +variables with appropriate defaults. The settings are compatible with +ARQ's RedisSettings class and handle SSL connections. 
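As the lifecycle module notes, standalone_ctx() makes jobs invokable outside of ARQ. A hedged sketch (the job_id key mirrors the test context; the commented job call is illustrative):

```python
import asyncio

from mavedb.worker.settings.lifecycle import standalone_ctx


async def main() -> None:
    ctx = standalone_ctx()        # pool, hdp, state; redis is left unset
    ctx["job_id"] = "manual-run"  # assumed: mirrors the test context's job_id
    # await refresh_materialized_views(ctx)  # any BACKGROUND_FUNCTIONS entry


asyncio.run(main())
```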
+""" + +from arq.connections import RedisSettings + +from mavedb.worker.settings.constants import REDIS_IP, REDIS_PORT, REDIS_SSL + +RedisWorkerSettings = RedisSettings(host=REDIS_IP, port=REDIS_PORT, ssl=REDIS_SSL) diff --git a/src/mavedb/worker/settings/worker.py b/src/mavedb/worker/settings/worker.py new file mode 100644 index 00000000..03bad1f3 --- /dev/null +++ b/src/mavedb/worker/settings/worker.py @@ -0,0 +1,33 @@ +"""Main ARQ worker configuration class. + +This module defines the primary ArqWorkerSettings class that brings together +all worker configuration including: +- Job functions and cron jobs from the jobs registry +- Redis connection settings +- Lifecycle hooks for startup/shutdown and job execution +- Timeout and logging configuration + +This is the main configuration class used to start the ARQ worker. +""" + +from mavedb.lib.logging.canonical import log_job +from mavedb.worker.jobs import BACKGROUND_CRONJOBS, BACKGROUND_FUNCTIONS +from mavedb.worker.settings.lifecycle import on_job_end, on_job_start, shutdown, startup +from mavedb.worker.settings.redis import RedisWorkerSettings + + +class ArqWorkerSettings: + """ + Settings for the ARQ worker. + """ + + on_startup = startup + on_shutdown = shutdown + on_job_start = on_job_start + on_job_end = on_job_end + after_job_end = log_job + redis_settings = RedisWorkerSettings + functions: list = BACKGROUND_FUNCTIONS + cron_jobs: list = BACKGROUND_CRONJOBS + + job_timeout = 5 * 60 * 60 # Keep jobs alive for a long while... diff --git a/tests/conftest.py b/tests/conftest.py index b11f728c..acebc569 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,15 @@ import logging # noqa: F401 +import os import sys +from contextlib import contextmanager from datetime import datetime from unittest import mock import email_validator import pytest import pytest_postgresql -from sqlalchemy import create_engine +import pytest_socket +from sqlalchemy import create_engine, text from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import NullPool @@ -57,6 +60,21 @@ email_validator.TEST_ENVIRONMENT = True +def pytest_runtest_setup(item): + # Only block sockets for tests not marked with 'network' + if "network" not in item.keywords: + try: + pytest_socket.socket_allow_hosts(["localhost", "127.0.0.1", "::1"], allow_unix_socket=True) + except ImportError: + pass + + else: + try: + pytest_socket.enable_socket() + except ImportError: + pass + + @pytest.fixture() def session(postgresql): # Un-comment this line to log all database queries: @@ -72,6 +90,15 @@ def session(postgresql): Base.metadata.create_all(bind=engine) + # Create a unique index for the published_variants_materialized_view to + # enforce uniqueness on (variant_id, mapped_variant_id, score_set_id). This + # allows us to test mat view refreshes that require this constraint. + session.execute( + text("""CREATE UNIQUE INDEX IF NOT EXISTS published_variants_mv_unique_idx + ON published_variants_materialized_view (variant_id, mapped_variant_id, score_set_id)"""), + ) + session.commit() + try: yield session finally: @@ -79,6 +106,36 @@ def session(postgresql): Base.metadata.drop_all(bind=engine) +@pytest.fixture +def db_session_fixture(session): + @contextmanager + def _db_session_cm(): + yield session + + return _db_session_cm + + +# ALL locations which use the db_session fixture need to be patched to use +# the test version. +@pytest.fixture +def patch_db_session_ctxmgr(db_session_fixture): + """Patches all known locations of the db_session fixture to use the test version. 
+ + To use this fixture, add it to the pytestmark list of a test module: + pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + + If you see an error about a test being unable to connect to the database, you + likely need to add another patch here for the module that is trying to use + db_session or include the above mark in your test module. + """ + with ( + mock.patch("mavedb.db.session.db_session", db_session_fixture), + mock.patch("mavedb.worker.lib.decorators.utils.db_session", db_session_fixture), + # Add other modules that use db_session here as needed + ): + yield + + @pytest.fixture def setup_lib_db(session): """ @@ -336,3 +393,13 @@ def test_needing_publication_identifier_mock(mock_publication_fetch, ...): mocked_publications.append(publication_to_mock) # Return a single dict (original behavior) if only one was provided; otherwise the list. return mocked_publications[0] if len(mocked_publications) == 1 else mocked_publications + + +# Automatically set MAVEDB_TEST_MODE=1 for unit tests, unset for integration tests. +@pytest.fixture(autouse=True) +def set_mavedb_test_mode_flag(request): + # If 'unit' marker is present, set the flag; otherwise, unset it. + if request.node.get_closest_marker("unit"): + os.environ["MAVEDB_TEST_MODE"] = "1" + else: + os.environ.pop("MAVEDB_TEST_MODE", None) diff --git a/tests/conftest_optional.py b/tests/conftest_optional.py index 8597c4f9..579fbd5c 100644 --- a/tests/conftest_optional.py +++ b/tests/conftest_optional.py @@ -1,9 +1,10 @@ import os +import shutil +import tempfile from concurrent import futures from inspect import getsourcefile from posixpath import abspath -import shutil -import tempfile +from unittest.mock import patch import cdot.hgvs.dataproviders import pytest @@ -12,16 +13,18 @@ from biocommons.seqrepo import SeqRepo from fastapi.testclient import TestClient from httpx import AsyncClient -from unittest.mock import patch +from sqlalchemy import Column, Float, Integer, MetaData, String, Table +from mavedb.db.session import create_engine, sessionmaker +from mavedb.deps import get_db, get_seqrepo, get_worker, hgvs_data_provider from mavedb.lib.authentication import UserData, get_current_user from mavedb.lib.authorization import require_current_user +from mavedb.lib.gnomad import gnomad_table_name from mavedb.models.user import User from mavedb.server_main import app -from mavedb.deps import get_db, get_worker, hgvs_data_provider, get_seqrepo -from mavedb.worker.settings import BACKGROUND_FUNCTIONS, BACKGROUND_CRONJOBS - -from tests.helpers.constants import ADMIN_USER, EXTRA_USER, TEST_SEQREPO_INITIAL_STATE, TEST_USER +from mavedb.worker.jobs import BACKGROUND_CRONJOBS, BACKGROUND_FUNCTIONS +from mavedb.worker.lib.managers.types import JobResultData +from tests.helpers.constants import ADMIN_USER, EXTRA_USER, TEST_SEQREPO_INITIAL_STATE, TEST_USER, VALID_CAID #################################################################################################### # REDIS @@ -78,6 +81,10 @@ def some_test(client, arq_redis): await redis_.aclose(close_connection_pool=True) +async def dummy_arq_function(ctx, *args, **kwargs) -> JobResultData: + return {"status": "ok", "data": {}, "exception_details": None} + + @pytest_asyncio.fixture() async def arq_worker(data_provider, session, arq_redis): """ @@ -87,7 +94,7 @@ async def arq_worker(data_provider, session, arq_redis): ``` async def worker_test(arq_redis, arq_worker): - await arq_redis.enqueue_job('some_job') + await arq_redis.enqueue_job('dummy_arq_function') await 
arq_worker.async_run() await arq_worker.run_check() ``` @@ -103,7 +110,7 @@ async def on_job(ctx): ctx["pool"] = futures.ProcessPoolExecutor() worker_ = Worker( - functions=BACKGROUND_FUNCTIONS, + functions=BACKGROUND_FUNCTIONS + [dummy_arq_function], cron_jobs=BACKGROUND_CRONJOBS, redis_pool=arq_redis, burst=True, @@ -120,9 +127,8 @@ async def on_job(ctx): @pytest.fixture -def standalone_worker_context(session, data_provider, arq_redis): +def standalone_worker_context(data_provider, arq_redis): yield { - "db": session, "hdp": data_provider, "state": {}, "job_id": "test_job", @@ -401,3 +407,58 @@ def client(app_): async def async_client(app_): async with AsyncClient(app=app_, base_url="http://testserver") as ac: yield ac + + +##################################################################################################### +# Athena +##################################################################################################### + + +@pytest.fixture +def athena_engine(): + """Create and yield a SQLAlchemy engine connected to a mock Athena database.""" + engine = create_engine("sqlite:///:memory:") + metadata = MetaData() + + # Mock of the gnomAD Athena table schema, limited to the columns these tests read + my_table = Table( + gnomad_table_name(), + metadata, + Column("id", Integer, primary_key=True), + Column("locus.contig", String), + Column("locus.position", Integer), + Column("alleles", String), + Column("caid", String), + Column("joint.freq.all.ac", Integer), + Column("joint.freq.all.an", Integer), + Column("joint.fafmax.faf95_max_gen_anc", String), + Column("joint.fafmax.faf95_max", Float), + ) + metadata.create_all(engine) + + session = sessionmaker(autocommit=False, autoflush=False, bind=engine)() + + # Insert test data + session.execute( + my_table.insert(), + [ + { + "id": 1, + "locus.contig": "chr1", + "locus.position": 12345, + "alleles": "[G, A]", + "caid": VALID_CAID, + "joint.freq.all.ac": 23, + "joint.freq.all.an": 32432423, + "joint.fafmax.faf95_max_gen_anc": "anc1", + "joint.fafmax.faf95_max": 0.000006763700000000002, + } + ], + ) + session.commit() + session.close() + + try: + yield engine + finally: + engine.dispose() diff --git a/tests/helpers/constants.py b/tests/helpers/constants.py index 32918235..208a61e2 100644 --- a/tests/helpers/constants.py +++ b/tests/helpers/constants.py @@ -43,6 +43,7 @@ VALID_PRO_ACCESSION = "NP_001637.4" VALID_GENE = "BRCA1" VALID_UNIPROT_ACCESSION = "P05067" +VALID_CAID = "CA9765210" VALID_ENSEMBL_IDENTIFIER = "ENST00000530893.6" @@ -1209,52 +1210,35 @@ }, } -TEST_CODING_LAYER = { +TEST_PROTEIN_LAYER = { + "computed_reference_sequence": { + "sequence_type": "protein", + "sequence_id": "ga4gh:SQ.ref_protein_test", + "sequence": "MKTIIALSYIFCLVFADYKDDDDK", + }, "mapped_reference_sequence": { - "sequence_accessions": [VALID_NT_ACCESSION], + "sequence_type": "protein", + "sequence_id": "ga4gh:SQ.map_protein_test", + "sequence_accessions": [VALID_PRO_ACCESSION], }, } -TEST_SEQ_SCORESET_VARIANT_MAPPING_SCAFFOLD = { - "metadata": {}, - "reference_sequences": { - "TEST1": { - "gene_info": TEST_GENE_INFO, - "layers": {"g": TEST_GENOMIC_LAYER, "c": TEST_CODING_LAYER}, - } + +TEST_CODING_LAYER = { + "computed_reference_sequence": { + "sequence_type": "coding", + "sequence_id": "ga4gh:SQ.ref_coding_test", + "sequence": "ATGAAGACGATTATTGCTCTTATCTTTCCTCTTTTGCTGATATACGACGACGACAAA", }, - "mapped_scores": [], - "vrs_version": "2.0", - "dcd_mapping_version": "pytest.0.0", - "mapped_date_utc": datetime.isoformat(datetime.now()), -} - -TEST_ACC_SCORESET_VARIANT_MAPPING_SCAFFOLD = { - "metadata": 
{}, - "reference_sequences": { - "TEST2": { - "gene_info": TEST_GENE_INFO, - "layers": {"g": TEST_GENOMIC_LAYER, "c": TEST_CODING_LAYER}, - } + "mapped_reference_sequence": { + "sequence_type": "coding", + "sequence_id": "ga4gh:SQ.map_coding_test", + "sequence_accessions": [VALID_NT_ACCESSION], }, - "mapped_scores": [], - "vrs_version": "2.0", - "dcd_mapping_version": "pytest.0.0", - "mapped_date_utc": datetime.isoformat(datetime.now()), } -TEST_MULTI_TARGET_SCORESET_VARIANT_MAPPING_SCAFFOLD = { +TEST_MAPPING_SCAFFOLD = { "metadata": {}, - "reference_sequences": { - "TEST3": { - "gene_info": TEST_GENE_INFO, - "layers": {"g": TEST_GENOMIC_LAYER, "c": TEST_CODING_LAYER}, - }, - "TEST4": { - "gene_info": TEST_GENE_INFO, - "layers": {"g": TEST_GENOMIC_LAYER, "c": TEST_CODING_LAYER}, - }, - }, + "reference_sequences": {}, "mapped_scores": [], "vrs_version": "2.0", "dcd_mapping_version": "pytest.0.0", diff --git a/tests/helpers/transaction_spy.py b/tests/helpers/transaction_spy.py new file mode 100644 index 00000000..4381aa75 --- /dev/null +++ b/tests/helpers/transaction_spy.py @@ -0,0 +1,222 @@ +from contextlib import contextmanager +from typing import Generator, TypedDict, Union +from unittest.mock import AsyncMock, MagicMock, patch + +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session + +from tests.helpers.util.common import create_failing_side_effect + + +class TransactionSpy: + """Factory for creating database transaction spy context managers.""" + + class Spies(TypedDict): + flush: Union[MagicMock, AsyncMock] + rollback: Union[MagicMock, AsyncMock] + commit: Union[MagicMock, AsyncMock] + + class SpiesWithException(Spies): + exception: Exception + + @staticmethod + @contextmanager + def spy( + session: Session, + expect_rollback: bool = False, + expect_flush: bool = False, + expect_commit: bool = False, + ) -> Generator[Spies, None, None]: + """ + Create spies for database transaction methods. + + Args: + session: Database session to spy on + expect_rollback: Whether to assert db.rollback to be called + expect_flush: Whether to assert db.flush to be called + expect_commit: Whether to assert db.commit to be called + + Yields: + dict: Dictionary containing all the spies for granular assertion + + Note: + Use caution when combining expectations. For example, if expect_commit + is True, you may wish to set expect_flush to True as well, since commit + typically implies a flush operation within SQLAlchemy internals. + + Example: + ``` + with TransactionSpy.spy(session, expect_rollback=True) as spies: + # perform operation + ... + + # Make manual granular assertions on spies if desired + spies['rollback'].assert_called_once() + + # if assert_XXX=True is set, automatic assertions will be made at context exit. + # In this example, expect_rollback=True will ensure rollback was called at some point. + ``` + """ + with ( + patch.object(session, "rollback", wraps=session.rollback) as rollback_spy, + patch.object(session, "flush", wraps=session.flush) as flush_spy, + patch.object(session, "commit", wraps=session.commit) as commit_spy, + ): + spies: TransactionSpy.Spies = { + "flush": flush_spy, + "rollback": rollback_spy, + "commit": commit_spy, + } + + yield spies + + # Automatic assertions based on session expectations. 
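+        # Note that these expectations are strict in both directions: a spy whose expect_* flag is False is asserted *not* to have been called. + # Calling TransactionSpy.spy(session) with the default flags therefore asserts that the block under test performed no flush, rollback, or commit at all.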
+ if expect_flush: + flush_spy.assert_called() + else: + flush_spy.assert_not_called() + if expect_rollback: + rollback_spy.assert_called() + else: + rollback_spy.assert_not_called() + if expect_commit: + commit_spy.assert_called() + else: + commit_spy.assert_not_called() + + @staticmethod + @contextmanager + def mock_database_execution_failure( + session: Session, + exception=None, + fail_on_call=1, + expect_rollback: bool = False, + expect_flush: bool = False, + expect_commit: bool = False, + ) -> Generator[SpiesWithException, None, None]: + """ + Create a context that mocks database execution failures with transaction spies. This context + will automatically assert calls to rollback, flush, and commit based on the provided expectations + which all default to False. + + Args: + session: Database session to mock + exception: Exception to raise (defaults to SQLAlchemyError) + fail_on_call: Which call should fail (defaults to first call) + expect_rollback: Whether to assert rollback called (defaults to False) + expect_flush: Whether to assert flush called (defaults to False) + expect_commit: Whether to assert commit called (defaults to False) + Yields: + dict: Dictionary containing spies and the exception that will be raised + """ + exception = exception or SQLAlchemyError("DB Error") + + with ( + patch.object( + session, + "execute", + side_effect=create_failing_side_effect(exception, session.execute, fail_on_call), + ), + TransactionSpy.spy( + session, + expect_rollback=expect_rollback, + expect_flush=expect_flush, + expect_commit=expect_commit, + ) as transaction_spies, + ): + spies: TransactionSpy.SpiesWithException = { + **transaction_spies, + "exception": exception, + } + + yield spies + + @staticmethod + @contextmanager + def mock_database_flush_failure( + session: Session, + exception=None, + fail_on_call=1, + expect_rollback: bool = True, + expect_flush: bool = True, + expect_commit: bool = False, + ) -> Generator[SpiesWithException, None, None]: + """ + Create a context that mocks flush failures specifically. This context will automatically + assert that rollback and flush are called, and that commit is not called. These automatic + assertions can be overridden via the expect_XXX parameters. + + Args: + session: Database session to mock + exception: Exception to raise on flush (defaults to SQLAlchemyError) + fail_on_call: Which flush call should fail (defaults to first call) + expect_rollback: Whether to assert rollback called (defaults to True) + expect_flush: Whether to assert flush called (defaults to True) + expect_commit: Whether to assert commit called (defaults to False) + Yields: + dict: Dictionary containing spies and the exception + """ + exception = exception or SQLAlchemyError("Flush Error") + + with ( + patch.object( + session, "flush", side_effect=create_failing_side_effect(exception, session.flush, fail_on_call) + ), + TransactionSpy.spy( + session, + expect_rollback=expect_rollback, + expect_flush=expect_flush, + expect_commit=expect_commit, + ) as transaction_spies, + ): + spies: TransactionSpy.SpiesWithException = { + **transaction_spies, + "exception": exception, + } + + yield spies + + @staticmethod + @contextmanager + def mock_database_rollback_failure( + session: Session, + exception=None, + fail_on_call=1, + expect_rollback: bool = True, + expect_flush: bool = False, + expect_commit: bool = False, + ) -> Generator[SpiesWithException, None, None]: + """ + Create a context that mocks rollback failures specifically. 
This context will automatically + assert that rollback is called, flush is not called, and commit is not called. These automatic + assertions can be overridden via the expect_XXX parameters. + + Args: + session: Database session to mock + exception: Exception to raise on rollback (defaults to SQLAlchemyError) + fail_on_call: Which rollback call should fail (defaults to first call) + expect_rollback: Whether to assert rollback called (defaults to True) + expect_flush: Whether to assert flush called (defaults to False) + expect_commit: Whether to assert commit called (defaults to False) + Yields: + dict: Dictionary containing spies and the exception + """ + exception = exception or SQLAlchemyError("Rollback Error") + + with ( + patch.object( + session, "rollback", side_effect=create_failing_side_effect(exception, session.rollback, fail_on_call) + ), + TransactionSpy.spy( + session, + expect_rollback=expect_rollback, + expect_flush=expect_flush, + expect_commit=expect_commit, + ) as transaction_spies, + ): + spies: TransactionSpy.SpiesWithException = { + **transaction_spies, + "exception": exception, + } + + yield spies diff --git a/tests/helpers/util/common.py b/tests/helpers/util/common.py index 407cf101..0acf2c1e 100644 --- a/tests/helpers/util/common.py +++ b/tests/helpers/util/common.py @@ -56,3 +56,34 @@ def deepcamelize(data: Any) -> Any: return [deepcamelize(item) for item in data] else: return data + + +def create_failing_side_effect(exception, original_method, fail_on_call=1): + """ + Create a side effect function that fails on a specific call number, then delegates to original method. + + Args: + exception: The exception to raise on the failing call + original_method: The original method to delegate to after the failure + fail_on_call: Which call number should fail (1-indexed, defaults to first call) + + Returns: + A callable that can be used as a side_effect in mock.patch + + Example: + with patch.object(session, "execute", side_effect=create_failing_side_effect( + SQLAlchemyError("DB Error"), session.execute + )): + # First call will raise SQLAlchemyError, subsequent calls work normally + pass + """ + call_count = 0 + + def side_effect_function(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == fail_on_call: + raise exception + return original_method(*args, **kwargs) + + return side_effect_function diff --git a/tests/helpers/util/setup/worker.py b/tests/helpers/util/setup/worker.py new file mode 100644 index 00000000..2723b90f --- /dev/null +++ b/tests/helpers/util/setup/worker.py @@ -0,0 +1,145 @@ +from asyncio.unix_events import _UnixSelectorEventLoop +from copy import deepcopy +from unittest.mock import patch + +from sqlalchemy import select + +from mavedb.models.score_set import ScoreSet as ScoreSetDbModel +from mavedb.models.variant import Variant +from mavedb.worker.jobs import ( + create_variants_for_score_set, + map_variants_for_score_set, +) +from mavedb.worker.lib.managers.job_manager import JobManager +from tests.helpers.constants import ( + TEST_CODING_LAYER, + TEST_GENE_INFO, + TEST_GENOMIC_LAYER, + TEST_MAPPING_SCAFFOLD, + TEST_PROTEIN_LAYER, + TEST_VALID_POST_MAPPED_VRS_ALLELE_VRS2_X, + TEST_VALID_PRE_MAPPED_VRS_ALLELE_VRS2_X, +) + + +async def create_variants_in_score_set( + session, mock_s3_client, score_df, count_df, mock_worker_ctx, variant_creation_run +): + """Add variants to a given score set in the database.""" + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + patch( + 
"mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[score_df, count_df], + ), + ): + # Guard against both possible function signatures, with some uses of this function coming from + # integration tests that need not pass a JobManager. + try: + result = await create_variants_for_score_set( + mock_worker_ctx, + variant_creation_run.id, + ) + except TypeError: + result = await create_variants_for_score_set( + mock_worker_ctx, + variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], variant_creation_run.id), + ) + + assert result["status"] == "ok" + session.commit() + + +async def create_mappings_in_score_set( + session, mock_s3_client, mock_worker_ctx, score_df, count_df, variant_creation_run, variant_mapping_run +): + await create_variants_in_score_set( + session, mock_s3_client, score_df, count_df, mock_worker_ctx, variant_creation_run + ) + + score_set = session.execute( + select(ScoreSetDbModel).where(ScoreSetDbModel.id == variant_creation_run.job_params["score_set_id"]) + ).scalar_one() + + async def dummy_mapping_job(): + return await construct_mock_mapping_output(session, score_set, with_layers={"g", "c", "p"}) + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + # Guard against both possible function signatures, with some uses of this function coming from + # integration tests that need not pass a JobManager. + try: + result = await map_variants_for_score_set(mock_worker_ctx, variant_mapping_run.id) + except TypeError: + result = await map_variants_for_score_set( + mock_worker_ctx, + variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], variant_mapping_run.id), + ) + + assert result["status"] == "ok" + session.commit() + + +async def construct_mock_mapping_output( + session, + score_set, + with_layers, + with_gene_info=True, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, +): + """Construct mapping output for a given score set in the database.""" + mapping_output = deepcopy(TEST_MAPPING_SCAFFOLD) + + if with_reference_metadata: + for target in score_set.target_genes: + mapping_output["reference_sequences"][target.name] = { + "gene_info": TEST_GENE_INFO if with_gene_info else {}, + } + + for target in score_set.target_genes: + mapping_output["reference_sequences"][target.name]["layers"] = {} + if "g" in with_layers: + mapping_output["reference_sequences"][target.name]["layers"]["g"] = TEST_GENOMIC_LAYER + if "c" in with_layers: + mapping_output["reference_sequences"][target.name]["layers"]["c"] = TEST_CODING_LAYER + if "p" in with_layers: + mapping_output["reference_sequences"][target.name]["layers"]["p"] = TEST_PROTEIN_LAYER + + if with_mapped_scores: + variants = session.scalars( + select(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) + ).all() + + for idx, variant in enumerate(variants): + mapped_score = { + "pre_mapped": deepcopy(TEST_VALID_PRE_MAPPED_VRS_ALLELE_VRS2_X) if with_pre_mapped else {}, + "post_mapped": deepcopy(TEST_VALID_POST_MAPPED_VRS_ALLELE_VRS2_X) if with_post_mapped else {}, + "mavedb_id": variant.urn, + } + + # Don't alter HGVS strings in post mapped output. This makes it considerably + # easier to assert correctness in tests. 
+ if with_post_mapped: + mapped_score["post_mapped"]["expressions"][0]["value"] = variant.hgvs_nt or variant.hgvs_pro + + # Skip every other variant if not with_all_variants + if not with_all_variants and idx % 2 == 0: + mapped_score["post_mapped"] = {} + + mapping_output["mapped_scores"].append(mapped_score) + + if not mapping_output["mapped_scores"]: + mapping_output["error_message"] = "test error: no mapped scores" + + return mapping_output diff --git a/tests/helpers/util/variant.py b/tests/helpers/util/variant.py index 5fcc05db..eede1e61 100644 --- a/tests/helpers/util/variant.py +++ b/tests/helpers/util/variant.py @@ -36,7 +36,11 @@ def mock_worker_variant_insertion( with ( open(scores_csv_path, "rb") as score_file, patch.object(ArqRedis, "enqueue_job", return_value=None) as worker_queue, + patch("mavedb.routers.score_sets.s3_client") as mock_s3_client, ): + mock_s3 = mock_s3_client.return_value + mock_s3.upload_fileobj.return_value = None  # Simulate a successful upload; the return value is not used + files = {"scores_file": (scores_csv_path.name, score_file, "rb")} if counts_csv_path is not None: @@ -69,6 +73,7 @@ def mock_worker_variant_insertion( # Assert we have mocked a job being added to the queue, and that the request succeeded. The # response value here isn't important- we will add variants to the score set manually. + mock_s3.upload_fileobj.assert_called() worker_queue.assert_called_once() assert response.status_code == 200 diff --git a/tests/lib/clingen/network/test_allele_registry.py b/tests/lib/clingen/network/test_allele_registry.py new file mode 100644 index 00000000..f2ab2bff --- /dev/null +++ b/tests/lib/clingen/network/test_allele_registry.py @@ -0,0 +1,72 @@ +import pytest + +from mavedb.lib.clingen.allele_registry import ( + get_associated_clinvar_allele_id, + get_canonical_pa_ids, + get_matching_registered_ca_ids, +) + + +@pytest.mark.network +class TestGetCanonicalPaIdsNetwork: + def test_get_canonical_pa_ids_known_caid(self): + # Using a known ClinGen Allele ID with MANE transcripts + clingen_allele_id = "CA321211"  # Example ClinGen Allele ID + result = get_canonical_pa_ids(clingen_allele_id) + assert isinstance(result, list) + assert result == ["PA2573050890", "PA321212"]  # Expected MANE PA IDs + + def test_get_canonical_pa_ids_known_no_mane(self): + # Using a ClinGen Allele ID for a protein change, as this will not have MANE transcripts + clingen_allele_id = "PA102264"  # Example ClinGen Allele ID with no MANE + result = get_canonical_pa_ids(clingen_allele_id) + assert result == [] + + def test_get_canonical_pa_ids_invalid_id(self): + # Using an invalid ClinGen Allele ID + clingen_allele_id = "INVALID_ID" + result = get_canonical_pa_ids(clingen_allele_id) + assert result == [] + + +@pytest.mark.network +class TestGetMatchingRegisteredCaIdsNetwork: + def test_get_matching_registered_ca_ids_known_paid(self): + # Using a known ClinGen PA ID with registered CA IDs + clingen_pa_id = "PA2573050890"  # Example ClinGen PA ID + result = get_matching_registered_ca_ids(clingen_pa_id) + assert isinstance(result, list) + assert "CA321211" in result  # Expected registered CA ID + + def test_get_matching_registered_ca_ids_known_no_caids(self): + # Using a ClinGen PA ID with no registered CA IDs + clingen_pa_id = "PA3051398879"  # Example ClinGen PA ID with no registered CA IDs + result = get_matching_registered_ca_ids(clingen_pa_id) + assert result == [] + + def test_get_matching_registered_ca_ids_invalid_id(self): + # Using an invalid ClinGen PA ID + clingen_pa_id = "INVALID_ID" + result = 
get_matching_registered_ca_ids(clingen_pa_id) + assert result == [] + + +@pytest.mark.network +class TestGetAssociatedClinvarAlleleIdNetwork: + def test_get_associated_clinvar_allele_id_known_caid(self): + # Using a known ClinGen Allele ID with associated ClinVar Allele ID + clingen_allele_id = "CA321211" # Example ClinGen Allele ID + result = get_associated_clinvar_allele_id(clingen_allele_id) + assert result == "211565" # Expected ClinVar Allele ID + + def test_get_associated_clinvar_allele_id_no_association(self): + # Using a ClinGen Allele ID with no associated ClinVar Allele ID + clingen_allele_id = "CA9532274" # Example ClinGen Allele ID with no association + result = get_associated_clinvar_allele_id(clingen_allele_id) + assert result is None + + def test_get_associated_clinvar_allele_id_invalid_id(self): + # Using an invalid ClinGen Allele ID + clingen_allele_id = "INVALID_ID" + result = get_associated_clinvar_allele_id(clingen_allele_id) + assert result is None diff --git a/tests/lib/clingen/test_allele_registry.py b/tests/lib/clingen/test_allele_registry.py new file mode 100644 index 00000000..d54b6d4a --- /dev/null +++ b/tests/lib/clingen/test_allele_registry.py @@ -0,0 +1,189 @@ +from unittest import mock + +import pytest + +from mavedb.lib.clingen.allele_registry import ( + get_associated_clinvar_allele_id, + get_canonical_pa_ids, + get_matching_registered_ca_ids, +) + + +@pytest.mark.unit +@mock.patch("mavedb.lib.clingen.allele_registry.requests.get") +class TestGetCanonicalPaIds: + def test_get_canonical_pa_ids_success(self, mock_request): + # Mock response object + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "transcriptAlleles": [ + {"MANE": True, "@id": "https://reg.genome.network/allele/PA12345"}, + {"MANE": False, "@id": "https://reg.genome.network/allele/PA54321"}, + {"MANE": True, "@id": "https://reg.genome.network/allele/PA67890"}, + {"@id": "https://reg.genome.network/allele/PA00000"}, # No MANE + ] + } + mock_request.return_value = mock_response + + result = get_canonical_pa_ids("CA00001") + assert result == ["PA12345", "PA67890"] + + def test_get_canonical_pa_ids_no_transcript_alleles(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {} + mock_request.return_value = mock_response + + result = get_canonical_pa_ids("CA00002") + assert result == [] + + def test_get_canonical_pa_ids_empty_transcript_alleles(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"transcriptAlleles": []} + mock_request.return_value = mock_response + + result = get_canonical_pa_ids("CA00003") + assert result == [] + + def test_get_canonical_pa_ids_missing_mane_or_id(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "transcriptAlleles": [ + {"MANE": True}, # Missing @id + {"@id": "https://reg.genome.network/allele/PA99999"}, # Missing MANE + {}, # Missing both + ] + } + mock_request.return_value = mock_response + + result = get_canonical_pa_ids("CA00004") + assert result == [] + + def test_get_canonical_pa_ids_api_error(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 404 + mock_request.return_value = mock_response + + result = get_canonical_pa_ids("CA404") + assert result == [] + + +@pytest.mark.unit +@mock.patch("mavedb.lib.clingen.allele_registry.requests.get") +class 
TestGetMatchingRegisteredCaIds: + def test_get_matching_registered_ca_ids_success(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "aminoAcidAlleles": [ + { + "matchingRegisteredTranscripts": [ + {"@id": "https://reg.genome.network/allele/CA11111"}, + {"@id": "https://reg.genome.network/allele/CA22222"}, + ] + }, + { + "matchingRegisteredTranscripts": [ + {"@id": "https://reg.genome.network/allele/CA33333"}, + ] + }, + { + # No matchingRegisteredTranscripts + }, + ] + } + mock_request.return_value = mock_response + + result = get_matching_registered_ca_ids("PA12345") + assert result == ["CA11111", "CA22222", "CA33333"] + + def test_get_matching_registered_ca_ids_no_amino_acid_alleles(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {} + mock_request.return_value = mock_response + + result = get_matching_registered_ca_ids("PA00000") + assert result == [] + + def test_get_matching_registered_ca_ids_empty_amino_acid_alleles(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"aminoAcidAlleles": []} + mock_request.return_value = mock_response + + result = get_matching_registered_ca_ids("PA00001") + assert result == [] + + def test_get_matching_registered_ca_ids_missing_matching_registered_transcripts(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "aminoAcidAlleles": [ + {}, # No matchingRegisteredTranscripts + {"matchingRegisteredTranscripts": []}, # Empty list + ] + } + mock_request.return_value = mock_response + + result = get_matching_registered_ca_ids("PA00002") + assert result == [] + + def test_get_matching_registered_ca_ids_api_error(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 500 + mock_request.return_value = mock_response + + result = get_matching_registered_ca_ids("PAERROR") + assert result == [] + + +@pytest.mark.unit +@mock.patch("mavedb.lib.clingen.allele_registry.requests.get") +class TestGetAssociatedClinvarAlleleId: + def test_get_associated_clinvar_allele_id_success(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"externalRecords": {"ClinVarAlleles": [{"alleleId": "123456"}]}} + mock_request.return_value = mock_response + + result = get_associated_clinvar_allele_id("CA00001") + assert result == "123456" + + def test_get_associated_clinvar_allele_id_no_external_records(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {} + mock_request.return_value = mock_response + + result = get_associated_clinvar_allele_id("CA00002") + assert result is None + + def test_get_associated_clinvar_allele_id_no_clinvar_alleles(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"externalRecords": {}} + mock_request.return_value = mock_response + + result = get_associated_clinvar_allele_id("CA00003") + assert result is None + + def test_get_associated_clinvar_allele_id_missing_allele_id(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"externalRecords": {"ClinVarAlleles": [{}]}} + mock_request.return_value = mock_response + + result = get_associated_clinvar_allele_id("CA00004") + assert 
result is None + + def test_get_associated_clinvar_allele_id_api_error(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 404 + mock_request.return_value = mock_response + + result = get_associated_clinvar_allele_id("CA404") + assert result is None diff --git a/tests/lib/clingen/test_services.py b/tests/lib/clingen/test_services.py index 34828649..7141eea3 100644 --- a/tests/lib/clingen/test_services.py +++ b/tests/lib/clingen/test_services.py @@ -1,27 +1,23 @@ # ruff: noqa: E402 import os +from datetime import datetime +from unittest.mock import MagicMock, patch + import pytest import requests -from datetime import datetime -from unittest.mock import patch, MagicMock -from urllib import parse arq = pytest.importorskip("arq") cdot = pytest.importorskip("cdot") fastapi = pytest.importorskip("fastapi") -from mavedb.lib.clingen.constants import LDH_MAVE_ACCESS_ENDPOINT, GENBOREE_ACCOUNT_NAME, GENBOREE_ACCOUNT_PASSWORD -from mavedb.lib.utils import batched +from mavedb.lib.clingen.constants import GENBOREE_ACCOUNT_NAME, GENBOREE_ACCOUNT_PASSWORD from mavedb.lib.clingen.services import ( ClinGenAlleleRegistryService, ClinGenLdhService, - get_clingen_variation, - clingen_allele_id_from_ldh_variation, get_allele_registry_associations, ) - -from tests.helpers.constants import VALID_CLINGEN_CA_ID +from mavedb.lib.utils import batched TEST_CLINGEN_URL = "https://pytest.clingen.com" TEST_CAR_URL = "https://pytest.car.clingen.com" @@ -219,66 +215,6 @@ def test_dispatch_submissions_no_batching(self, mock_batched, mock_authenticate, ) -@patch("mavedb.lib.clingen.services.requests.get") -def test_get_clingen_variation_success(mock_get): - mocked_response_json = {"data": {"ldFor": {"Variant": [{"id": "variant_1", "name": "Test Variant"}]}}} - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = mocked_response_json - mock_get.return_value = mock_response - - urn = "urn:example:variant" - result = get_clingen_variation(urn) - - assert result == mocked_response_json - mock_get.assert_called_once_with( - f"{LDH_MAVE_ACCESS_ENDPOINT}/{parse.quote_plus(urn)}", - headers={"Accept": "application/json"}, - ) - - -@patch("mavedb.lib.clingen.services.requests.get") -def test_get_clingen_variation_failure(mock_get): - mock_response = MagicMock() - mock_response.status_code = 404 - mock_response.text = "Not Found" - mock_get.return_value = mock_response - - urn = "urn:example:nonexistent_variant" - result = get_clingen_variation(urn) - - assert result is None - mock_get.assert_called_once_with( - f"{LDH_MAVE_ACCESS_ENDPOINT}/{parse.quote_plus(urn)}", - headers={"Accept": "application/json"}, - ) - - -def test_clingen_allele_id_from_ldh_variation_success(): - variation = {"data": {"ldFor": {"Variant": [{"entId": VALID_CLINGEN_CA_ID}]}}} - result = clingen_allele_id_from_ldh_variation(variation) - assert result == VALID_CLINGEN_CA_ID - - -def test_clingen_allele_id_from_ldh_variation_missing_key(): - variation = {"data": {"ldFor": {"Variant": []}}} - - result = clingen_allele_id_from_ldh_variation(variation) - assert result is None - - -def test_clingen_allele_id_from_ldh_variation_no_variation(): - result = clingen_allele_id_from_ldh_variation(None) - assert result is None - - -def test_clingen_allele_id_from_ldh_variation_key_error(): - variation = {"data": {}} - - result = clingen_allele_id_from_ldh_variation(variation) - assert result is None - - class TestClinGenAlleleRegistryService: def test_init(self, car_service): assert car_service.url 
== TEST_CAR_URL diff --git a/tests/lib/clinvar/network/test_utils.py b/tests/lib/clinvar/network/test_utils.py new file mode 100644 index 00000000..6bbf3650 --- /dev/null +++ b/tests/lib/clinvar/network/test_utils.py @@ -0,0 +1,23 @@ +from datetime import datetime + +import pytest + +from mavedb.lib.clinvar.utils import fetch_clinvar_variant_summary_tsv + + +@pytest.mark.network +@pytest.mark.slow +class TestFetchClinvarVariantSummaryTSVIntegration: + def test_fetch_recent_variant_summary(self): + now = datetime.now() + # Attempt to fetch the most recent available month (previous month) + month = now.month - 1 if now.month > 1 else 12 + year = now.year if now.month > 1 else now.year - 1 + + content = fetch_clinvar_variant_summary_tsv(month, year) + assert content.startswith(b"\x1f\x8b")  # Gzip magic number + + def test_fetch_older_variant_summary(self): + # Fetch an older known date + content = fetch_clinvar_variant_summary_tsv(2, 2015) + assert content.startswith(b"\x1f\x8b")  # Gzip magic number diff --git a/tests/lib/clinvar/test_utils.py b/tests/lib/clinvar/test_utils.py new file mode 100644 index 00000000..7dd19089 --- /dev/null +++ b/tests/lib/clinvar/test_utils.py @@ -0,0 +1,150 @@ +import csv +import gzip +import io +from datetime import datetime + +import pytest +import requests + +from mavedb.lib.clinvar.utils import ( + fetch_clinvar_variant_summary_tsv, + parse_clinvar_variant_summary, + validate_clinvar_variant_summary_date, +) + + +@pytest.mark.unit +class TestValidateClinvarVariantSummaryDate: + def test_valid_past_date(self): + # Should not raise for a valid past date + validate_clinvar_variant_summary_date(2, 2015) + + def test_valid_current_month_and_year(self): + now = datetime.now() + # Should not raise for current month and year + validate_clinvar_variant_summary_date(now.month, now.year) + + def test_invalid_month_low(self): + with pytest.raises(ValueError, match="Month must be an integer between 1 and 12."): + validate_clinvar_variant_summary_date(0, 2020) + + def test_invalid_month_high(self): + with pytest.raises(ValueError, match="Month must be an integer between 1 and 12."): + validate_clinvar_variant_summary_date(13, 2020) + + def test_year_before_2015(self): + with pytest.raises(ValueError, match="ClinVar archived data is only available from February 2015 onwards."): + validate_clinvar_variant_summary_date(6, 2014) + + def test_year_2015_before_february(self): + with pytest.raises(ValueError, match="ClinVar archived data is only available from February 2015 onwards."): + validate_clinvar_variant_summary_date(1, 2015) + + def test_year_in_future(self): + future_year = datetime.now().year + 1 + with pytest.raises(ValueError, match="Cannot fetch ClinVar data for future years."): + validate_clinvar_variant_summary_date(6, future_year) + + def test_month_in_future_for_current_year(self): + now = datetime.now() + if now.month == 12: + pytest.skip("December, no future month in current year") + + # pytest.skip raises, so now.month < 12 is guaranteed past this point. + future_month = now.month + 1 + with pytest.raises(ValueError, match="Cannot fetch ClinVar data for future months."): + validate_clinvar_variant_summary_date(future_month, now.year) + + +@pytest.mark.unit +class TestFetchClinvarVariantSummaryTSV: + class MockResponse: + def __init__(self, content, status_code=200, raise_exc=None): + self.content = content + self.status_code = status_code + self._raise_exc = raise_exc + + def raise_for_status(self): + if self._raise_exc: + raise self._raise_exc + + def 
test_fetch_clinvar_variant_summary_tsv_top_level_success(self, monkeypatch): + # Simulate successful fetch from top-level URL + mock_content = b"mock gzipped content" + + def mock_get(url, stream=True): + return self.MockResponse(mock_content) + + monkeypatch.setattr("requests.get", mock_get) + result = fetch_clinvar_variant_summary_tsv(1, 2016) + assert result == mock_content + + def test_fetch_clinvar_variant_summary_tsv_archive_success(self, monkeypatch): + # Simulate top-level fails, archive succeeds. The date in the mocked URL check + # must match the fetched date (January 2016) so the non-archive request fails. + mock_content = b"archive gzipped content" + + def mock_get(url, stream=True): + if "variant_summary_2016-01.txt.gz" in url and "/2016/" not in url: + raise requests.RequestException("Top-level not found") + return self.MockResponse(mock_content) + + monkeypatch.setattr("requests.get", mock_get) + result = fetch_clinvar_variant_summary_tsv(1, 2016) + assert result == mock_content + + def test_fetch_clinvar_variant_summary_tsv_both_fail(self, monkeypatch): + # Simulate both URLs failing + def mock_get(url, stream=True): + raise requests.RequestException("Not found") + + monkeypatch.setattr("requests.get", mock_get) + with pytest.raises(requests.RequestException, match="Not found"): + fetch_clinvar_variant_summary_tsv(1, 2016) + + def test_fetch_clinvar_variant_summary_tsv_invalid_date(self): + # Should raise ValueError before any network call + with pytest.raises(ValueError, match="Month must be an integer between 1 and 12."): + fetch_clinvar_variant_summary_tsv(0, 2020) + + +class TestParseClinvarVariantSummary: + def make_gzipped_tsv(self, text: str) -> bytes: + buf = io.BytesIO() + with gzip.GzipFile(fileobj=buf, mode="wb") as gz: + gz.write(text.encode("utf-8")) + return buf.getvalue() + + def test_parse_clinvar_variant_summary_basic(self): + tsv = "#AlleleID\tGeneSymbol\tClinicalSignificance\n" "123\tBRCA1\tPathogenic\n" "456\tTP53\tBenign\n" + gzipped = self.make_gzipped_tsv(tsv) + result = parse_clinvar_variant_summary(gzipped) + assert "123" in result + assert "456" in result + assert result["123"]["GeneSymbol"] == "BRCA1" + assert result["456"]["ClinicalSignificance"] == "Benign" + + def test_parse_clinvar_variant_summary_missing_alleleid_column(self): + tsv = "GeneSymbol\tClinicalSignificance\n" "BRCA1\tPathogenic\n" + gzipped = self.make_gzipped_tsv(tsv) + with pytest.raises(KeyError): + parse_clinvar_variant_summary(gzipped) + + def test_parse_clinvar_variant_summary_empty_content(self): + # Should not raise on empty input + gzipped = self.make_gzipped_tsv("") + parse_clinvar_variant_summary(gzipped) + + def test_parse_clinvar_variant_summary_large_field(self): + large_field = "A" * (csv.field_size_limit() + 100) + tsv = f"#AlleleID\tGeneSymbol\n999\t{large_field}\n" + gzipped = self.make_gzipped_tsv(tsv) + result = parse_clinvar_variant_summary(gzipped) + assert result["999"]["GeneSymbol"] == large_field + + def test_parse_clinvar_variant_summary_does_not_alter_field_size_limit(self): + default_limit = csv.field_size_limit() + tsv = "#AlleleID\tGeneSymbol\n1\tBRCA1\n" + gzipped = self.make_gzipped_tsv(tsv) + parse_clinvar_variant_summary(gzipped) + assert csv.field_size_limit() == default_limit diff --git a/tests/lib/test_annotation_status_manager.py b/tests/lib/test_annotation_status_manager.py new file mode 100644 index 00000000..df78ce69 --- /dev/null +++ b/tests/lib/test_annotation_status_manager.py @@ -0,0 +1,499 @@ +# ruff: noqa: E402 + +import pytest + +pytest.importorskip("psycopg2") + +from mavedb.lib.annotation_status_manager import AnnotationStatusManager +from 
mavedb.models.enums.annotation_type import AnnotationType +from mavedb.models.enums.job_pipeline import AnnotationStatus +from mavedb.models.variant import Variant + + +@pytest.fixture +def annotation_status_manager(session): + """Fixture to provide an AnnotationStatusManager instance.""" + return AnnotationStatusManager(session) + + +@pytest.fixture +def existing_annotation_status(session, annotation_status_manager, setup_lib_db_with_variant): + """Fixture to create an existing annotation status in the database.""" + + # Add initial annotation + annotation = annotation_status_manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version="v1", + annotation_data={}, + status=AnnotationStatus.SUCCESS, + current=True, + ) + session.commit() + + assert annotation.id is not None + assert annotation.current is True + + return annotation + + +@pytest.fixture +def existing_unversioned_annotation_status(session, annotation_status_manager, setup_lib_db_with_variant): + """Fixture to create an existing annotation status in the database.""" + + # Add initial annotation + annotation = annotation_status_manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version=None, + annotation_data={}, + status=AnnotationStatus.SUCCESS, + current=True, + ) + session.commit() + + assert annotation.id is not None + assert annotation.current is True + + return annotation + + +@pytest.mark.unit +class TestAnnotationStatusManagerCreateAnnotationUnit: + """Unit tests for AnnotationStatusManager.add_annotation method.""" + + @pytest.mark.parametrize( + "annotation_type", + AnnotationType._member_map_.values(), + ) + @pytest.mark.parametrize( + "status", + AnnotationStatus._member_map_.values(), + ) + def test_add_annotation_creates_entry_with_annotation_type_version_status( + self, session, annotation_status_manager, annotation_type, status, setup_lib_db_with_variant + ): + """Test that adding an annotation creates a new entry with correct type and version.""" + annotation = annotation_status_manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=annotation_type, + version="v1.0", + annotation_data={}, + current=True, + status=status, + ) + session.commit() + + assert annotation.annotation_type == annotation_type + assert annotation.status == status + assert annotation.version == "v1.0" + + def test_add_annotation_persists_annotation_data( + self, session, annotation_status_manager, setup_lib_db_with_variant + ): + """Test that adding an annotation persists the provided annotation data.""" + annotation_data = { + "success_data": {"some_key": "some_value"}, + "error_message": None, + "failure_category": None, + } + annotation = annotation_status_manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + status=AnnotationStatus.SUCCESS, + version="v1.0", + annotation_data=annotation_data, + current=True, + ) + session.commit() + + for key, value in annotation_data.items(): + assert getattr(annotation, key) == value + + def test_add_annotation_creates_entry_and_marks_previous_not_current( + self, session, existing_annotation_status, setup_lib_db_with_variant + ): + """Test that adding an annotation creates a new entry and marks previous ones as not current.""" + manager = AnnotationStatusManager(session) + + # Add second annotation for same (variant, type, version) + annotation = manager.add_annotation( + 
variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version="v1", + annotation_data={}, + status=AnnotationStatus.FAILED, + current=True, + ) + session.commit() + + assert annotation.id is not None + assert annotation.current is True + + # Refresh first annotation from DB + session.refresh(existing_annotation_status) + assert existing_annotation_status.current is False + + def test_add_annotation_with_different_version_keeps_previous_current( + self, session, existing_annotation_status, setup_lib_db_with_variant + ): + """Test that adding an annotation with a different version keeps previous current.""" + manager = AnnotationStatusManager(session) + + # Add second annotation for same (variant, type) but different version + annotation = manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version="v2", + annotation_data={}, + status=AnnotationStatus.SUCCESS, + current=True, + ) + session.commit() + + assert annotation.id is not None + assert annotation.current is True + + # Refresh first annotation from DB + session.refresh(existing_annotation_status) + assert existing_annotation_status.current is True + + def test_add_annotation_with_different_type_keeps_previous_current( + self, session, existing_annotation_status, setup_lib_db_with_variant + ): + """Test that adding an annotation with a different type keeps previous current.""" + manager = AnnotationStatusManager(session) + + # Add second annotation for same variant but different type + annotation = manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.CLINGEN_ALLELE_ID, + version="v1", + annotation_data={}, + status=AnnotationStatus.SUCCESS, + current=True, + ) + session.commit() + + assert annotation.id is not None + assert annotation.current is True + + # Refresh first annotation from DB + session.refresh(existing_annotation_status) + assert existing_annotation_status.current is True + + def test_add_annotation_without_version(self, session, annotation_status_manager, setup_lib_db_with_variant): + """Test that adding an annotation without specifying version works correctly.""" + annotation = annotation_status_manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VEP_FUNCTIONAL_CONSEQUENCE, + version=None, + annotation_data={}, + status=AnnotationStatus.SKIPPED, + current=True, + ) + session.commit() + + assert annotation.id is not None + assert annotation.version is None + assert annotation.current is True + + def test_add_annotation_multiple_without_version_marks_previous_not_current( + self, session, annotation_status_manager, existing_unversioned_annotation_status, setup_lib_db_with_variant + ): + """Test that adding multiple annotations without version marks previous ones as not current.""" + + # Add second annotation without version + second_annotation = annotation_status_manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version=None, + annotation_data={}, + status=AnnotationStatus.FAILED, + current=True, + ) + session.commit() + + assert second_annotation.id is not None + assert second_annotation.current is True + + # Refresh first annotation from DB + session.refresh(existing_unversioned_annotation_status) + assert existing_unversioned_annotation_status.current is False + + def test_add_annotation_different_type_without_version_keeps_previous_current( + self, session, 
annotation_status_manager, existing_unversioned_annotation_status, setup_lib_db_with_variant + ): + """Test that adding an annotation of different type without version keeps previous current.""" + + # Add second annotation of different type without version + second_annotation = annotation_status_manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.CLINGEN_ALLELE_ID, + version=None, + annotation_data={}, + status=AnnotationStatus.SUCCESS, + current=True, + ) + session.commit() + + assert second_annotation.id is not None + assert second_annotation.current is True + + # Refresh first annotation from DB + session.refresh(existing_unversioned_annotation_status) + assert existing_unversioned_annotation_status.current is True + + def test_add_annotation_multiple_variants_independent_current_flags( + self, session, annotation_status_manager, setup_lib_db_with_score_set + ): + """Test that adding annotations for different variants maintains independent current flags.""" + + variant1 = Variant(score_set_id=1, hgvs_nt="NM_000000.1:c.1A>G", hgvs_pro="NP_000000.1:p.Met1Val", data={}) + variant2 = Variant(score_set_id=1, hgvs_nt="NM_000000.1:c.2A>T", hgvs_pro="NP_000000.1:p.Met2Val", data={}) + session.add_all([variant1, variant2]) + session.commit() + session.refresh(variant1) + session.refresh(variant2) + + # Add annotation for variant 1 + annotation1 = annotation_status_manager.add_annotation( + variant_id=variant1.id, + annotation_type=AnnotationType.VRS_MAPPING, + version="v1", + annotation_data={}, + status=AnnotationStatus.SUCCESS, + current=True, + ) + session.commit() + + # Add annotation for variant 2 + annotation2 = annotation_status_manager.add_annotation( + variant_id=variant2.id, + annotation_type=AnnotationType.VRS_MAPPING, + version="v1", + annotation_data={}, + status=AnnotationStatus.SUCCESS, + current=True, + ) + session.commit() + + assert annotation1.id is not None + assert annotation1.current is True + + assert annotation2.id is not None + assert annotation2.current is True + + +class TestAnnotationStatusManagerGetCurrentAnnotationUnit: + """Unit tests for AnnotationStatusManager.get_current_annotation method.""" + + def test_get_current_annotation_returns_none_when_no_entry( + self, annotation_status_manager, setup_lib_db_with_variant + ): + """Test that getting current annotation returns None when no entry exists.""" + annotation = annotation_status_manager.get_current_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version="v1", + ) + assert annotation is None + + def test_get_current_annotation_returns_correct_entry( + self, session, annotation_status_manager, existing_annotation_status, setup_lib_db_with_variant + ): + """Test that getting current annotation returns the correct entry.""" + annotation = annotation_status_manager.get_current_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version="v1", + ) + assert annotation.id == existing_annotation_status.id + assert annotation.current is True + + def test_get_current_annotation_returns_none_for_non_current( + self, session, annotation_status_manager, existing_annotation_status, setup_lib_db_with_variant + ): + """Test that getting current annotation returns None when the entry is not current.""" + # Mark existing annotation as not current + existing_annotation_status.current = False + session.commit() + + annotation = annotation_status_manager.get_current_annotation( + 
variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version="v1", + ) + assert annotation is None + + def test_get_current_annotation_with_different_version_returns_none( + self, session, annotation_status_manager, existing_annotation_status, setup_lib_db_with_variant + ): + """Test that getting current annotation with different version returns None.""" + annotation = annotation_status_manager.get_current_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version="v2", + ) + assert annotation is None + + def test_get_current_annotation_with_different_type_returns_none( + self, session, annotation_status_manager, existing_annotation_status, setup_lib_db_with_variant + ): + """Test that getting current annotation with different type returns None.""" + annotation = annotation_status_manager.get_current_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.CLINGEN_ALLELE_ID, + version="v1", + ) + assert annotation is None + + def test_get_current_annotation_without_version_returns_correct_entry( + self, session, annotation_status_manager, existing_unversioned_annotation_status, setup_lib_db_with_variant + ): + """Test that getting current annotation without version returns the correct entry.""" + annotation = annotation_status_manager.get_current_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version=None, + ) + assert annotation.id == existing_unversioned_annotation_status.id + assert annotation.current is True + + +class TestAnnotationStatusManagerIntegration: + """Integration tests for AnnotationStatusManager methods.""" + + def test_add_and_get_current_annotation_work_together( + self, session, annotation_status_manager, setup_lib_db_with_variant + ): + """Test that adding and getting current annotation work together correctly.""" + # Add annotation + added_annotation = annotation_status_manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version="v1", + annotation_data={}, + status=AnnotationStatus.SUCCESS, + current=True, + ) + session.commit() + + # Get current annotation + retrieved_annotation = annotation_status_manager.get_current_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version="v1", + ) + + assert retrieved_annotation is not None + assert retrieved_annotation.id == added_annotation.id + assert retrieved_annotation.current is True + assert retrieved_annotation.status == AnnotationStatus.SUCCESS + + @pytest.mark.parametrize( + "version", + ["v1.0", "v2.0", None], + ) + def test_add_multiple_and_get_current_returns_latest( + self, session, annotation_status_manager, version, setup_lib_db_with_variant + ): + """Test that adding multiple annotations and getting current returns the latest one.""" + # Add first annotation + annotation_status_manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version=version, + annotation_data={}, + status=AnnotationStatus.FAILED, + current=True, + ) + session.commit() + + # Add second annotation + second_annotation = annotation_status_manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version=version, + annotation_data={}, + status=AnnotationStatus.SUCCESS, + current=True, + ) + session.commit() + + # Get current annotation + 
retrieved_annotation = annotation_status_manager.get_current_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version=version, + ) + + assert retrieved_annotation is not None + assert retrieved_annotation.id == second_annotation.id + assert retrieved_annotation.current is True + assert retrieved_annotation.version == version + assert retrieved_annotation.status == AnnotationStatus.SUCCESS + + @pytest.mark.parametrize( + "version", + ["v1.0", "v2.0", None], + ) + def test_add_annotations_for_different_variants_and_get_current_independent( + self, session, annotation_status_manager, version, setup_lib_db_with_score_set + ): + """Test that adding annotations for different variants and getting current works independently.""" + + variant1 = Variant(score_set_id=1, hgvs_nt="NM_000000.1:c.1A>G", hgvs_pro="NP_000000.1:p.Met1Val", data={}) + variant2 = Variant(score_set_id=1, hgvs_nt="NM_000000.1:c.2A>T", hgvs_pro="NP_000000.1:p.Met2Val", data={}) + session.add_all([variant1, variant2]) + session.commit() + session.refresh(variant1) + session.refresh(variant2) + + # Add annotation for variant 1 + annotation1 = annotation_status_manager.add_annotation( + variant_id=variant1.id, + annotation_type=AnnotationType.VRS_MAPPING, + version=version, + annotation_data={}, + status=AnnotationStatus.SUCCESS, + current=True, + ) + session.commit() + + # Add annotation for variant 2 + annotation2 = annotation_status_manager.add_annotation( + variant_id=variant2.id, + annotation_type=AnnotationType.VRS_MAPPING, + version=version, + annotation_data={}, + status=AnnotationStatus.FAILED, + current=True, + ) + session.commit() + + # Get current annotation for variant 1 + retrieved_annotation1 = annotation_status_manager.get_current_annotation( + variant_id=variant1.id, + annotation_type=AnnotationType.VRS_MAPPING, + version=version, + ) + + assert retrieved_annotation1 is not None + assert retrieved_annotation1.id == annotation1.id + assert retrieved_annotation1.current is True + assert retrieved_annotation1.status == AnnotationStatus.SUCCESS + assert retrieved_annotation1.version == version + + # Get current annotation for variant 2 + retrieved_annotation2 = annotation_status_manager.get_current_annotation( + variant_id=variant2.id, + annotation_type=AnnotationType.VRS_MAPPING, + version=version, + ) + + assert retrieved_annotation2 is not None + assert retrieved_annotation2.id == annotation2.id + assert retrieved_annotation2.current is True + assert retrieved_annotation2.status == AnnotationStatus.FAILED + assert retrieved_annotation2.version == version diff --git a/tests/lib/test_gnomad.py b/tests/lib/test_gnomad.py index 043c6c56..14dde952 100644 --- a/tests/lib/test_gnomad.py +++ b/tests/lib/test_gnomad.py @@ -1,25 +1,26 @@ # ruff: noqa: E402 -import pytest -import importlib from unittest.mock import patch +import pytest + +from mavedb.models.variant_annotation_status import VariantAnnotationStatus + pyathena = pytest.importorskip("pyathena") fastapi = pytest.importorskip("fastapi") from mavedb.lib.gnomad import ( - gnomad_identifier, allele_list_from_list_like_string, + gnomad_identifier, + gnomad_table_name, link_gnomad_variants_to_mapped_variants, ) -from mavedb.models.mapped_variant import MappedVariant from mavedb.models.gnomad_variant import GnomADVariant - +from mavedb.models.mapped_variant import MappedVariant from tests.helpers.constants import ( - TEST_GNOMAD_ALLELE_NUMBER, + TEST_GNOMAD_DATA_VERSION, TEST_GNOMAD_VARIANT, 
TEST_MINIMAL_MAPPED_VARIANT, - TEST_GNOMAD_DATA_VERSION, ) ### Tests for gnomad_identifier function ### @@ -63,22 +64,17 @@ def test_gnomad_identifier_raises_with_no_alleles(): ### Tests for gnomad_table_name function ### -def test_gnomad_table_name_returns_expected(monkeypatch): - monkeypatch.setenv("GNOMAD_DATA_VERSION", TEST_GNOMAD_DATA_VERSION) - # Reload the module to update GNOMAD_DATA_VERSION global - import mavedb.lib.gnomad as gnomad_mod - - importlib.reload(gnomad_mod) - assert gnomad_mod.gnomad_table_name() == TEST_GNOMAD_DATA_VERSION.replace(".", "_") - +def test_gnomad_table_name_returns_expected(): + with patch("mavedb.lib.gnomad.GNOMAD_DATA_VERSION", TEST_GNOMAD_DATA_VERSION): + assert gnomad_table_name() == TEST_GNOMAD_DATA_VERSION.replace(".", "_") -def test_gnomad_table_name_raises_if_env_not_set(monkeypatch): - monkeypatch.delenv("GNOMAD_DATA_VERSION", raising=False) - import mavedb.lib.gnomad as gnomad_mod - importlib.reload(gnomad_mod) - with pytest.raises(ValueError, match="GNOMAD_DATA_VERSION environment variable is not set."): - gnomad_mod.gnomad_table_name() +def test_gnomad_table_name_raises_if_env_not_set(): + with ( + pytest.raises(ValueError, match="GNOMAD_DATA_VERSION environment variable is not set."), + patch("mavedb.lib.gnomad.GNOMAD_DATA_VERSION", None), + ): + gnomad_table_name() ### Tests for allele_list_from_list_like_string function ### @@ -125,6 +121,16 @@ def test_allele_list_from_list_like_string_invalid_format_not_list(): ### Tests for link_gnomad_variants_to_mapped_variants function ### +def _verify_annotation_status(session, mapped_variants, expected_version): + annotations = session.query(VariantAnnotationStatus).all() + assert len(annotations) == len(mapped_variants) + + for mapped_variant, annotation in zip(mapped_variants, annotations): + assert annotation.variant_id == mapped_variant.variant_id + assert annotation.annotation_type == "gnomad_allele_frequency" + assert annotation.version == expected_version + + def test_links_new_gnomad_variant_to_mapped_variant( session, mocked_gnomad_variant_row, setup_lib_db_with_mapped_variant ): @@ -148,6 +154,8 @@ def test_links_new_gnomad_variant_to_mapped_variant( for attr in edited_saved_gnomad_variant: assert getattr(mapped_variant.gnomad_variants[0], attr) == edited_saved_gnomad_variant[attr] + _verify_annotation_status(session, [mapped_variant], TEST_GNOMAD_DATA_VERSION) + def test_can_link_gnomad_variants_with_none_type_faf_fields( session, mocked_gnomad_variant_row, setup_lib_db_with_mapped_variant @@ -175,6 +183,8 @@ def test_can_link_gnomad_variants_with_none_type_faf_fields( for attr in gnomad_variant_comparator: assert getattr(mapped_variant.gnomad_variants[0], attr) == gnomad_variant_comparator[attr] + _verify_annotation_status(session, [mapped_variant], TEST_GNOMAD_DATA_VERSION) + def test_links_existing_gnomad_variant(session, mocked_gnomad_variant_row, setup_lib_db_with_mapped_variant): gnomad_variant = GnomADVariant(**TEST_GNOMAD_VARIANT) @@ -199,8 +209,10 @@ def test_links_existing_gnomad_variant(session, mocked_gnomad_variant_row, setup for attr in edited_saved_gnomad_variant: assert getattr(mapped_variant.gnomad_variants[0], attr) == edited_saved_gnomad_variant[attr] + _verify_annotation_status(session, [mapped_variant], TEST_GNOMAD_DATA_VERSION) -def test_removes_existing_gnomad_variant_with_same_version( + +def test_adding_existing_gnomad_variant_with_same_version_does_not_result_in_duplication( session, mocked_gnomad_variant_row, setup_lib_db_with_mapped_variant ): mapped_variant = 
setup_lib_db_with_mapped_variant @@ -212,7 +224,6 @@ def test_removes_existing_gnomad_variant_with_same_version( result = link_gnomad_variants_to_mapped_variants(session, [mocked_gnomad_variant_row]) assert result == 1 - setattr(mocked_gnomad_variant_row, "joint.freq.all.ac", "1234") with patch("mavedb.lib.gnomad.GNOMAD_DATA_VERSION", TEST_GNOMAD_DATA_VERSION): result = link_gnomad_variants_to_mapped_variants(session, [mocked_gnomad_variant_row]) assert result == 1 @@ -221,8 +232,6 @@ def test_removes_existing_gnomad_variant_with_same_version( session.refresh(mapped_variant) edited_saved_gnomad_variant = TEST_GNOMAD_VARIANT.copy() - edited_saved_gnomad_variant["allele_count"] = 1234 - edited_saved_gnomad_variant["allele_frequency"] = float(1234 / int(TEST_GNOMAD_ALLELE_NUMBER)) edited_saved_gnomad_variant.pop("creation_date") edited_saved_gnomad_variant.pop("modification_date") @@ -230,6 +239,8 @@ def test_removes_existing_gnomad_variant_with_same_version( for attr in edited_saved_gnomad_variant: assert getattr(mapped_variant.gnomad_variants[0], attr) == edited_saved_gnomad_variant[attr] + _verify_annotation_status(session, [mapped_variant, mapped_variant], TEST_GNOMAD_DATA_VERSION) + def test_links_multiple_rows_and_variants(session, mocked_gnomad_variant_row, setup_lib_db_with_mapped_variant): mapped_variant1 = setup_lib_db_with_mapped_variant @@ -256,11 +267,15 @@ def test_links_multiple_rows_and_variants(session, mocked_gnomad_variant_row, se for attr in gnomad_variant_comparator: assert getattr(mv.gnomad_variants[0], attr) == gnomad_variant_comparator[attr] + _verify_annotation_status(session, [mapped_variant1, mapped_variant2], TEST_GNOMAD_DATA_VERSION) + def test_returns_zero_when_no_mapped_variants(session, mocked_gnomad_variant_row): result = link_gnomad_variants_to_mapped_variants(session, [mocked_gnomad_variant_row]) assert result == 0 + _verify_annotation_status(session, [], TEST_GNOMAD_DATA_VERSION) + def test_only_current_flag_filters_variants(session, mocked_gnomad_variant_row, setup_lib_db_with_mapped_variant): mapped_variant1 = setup_lib_db_with_mapped_variant @@ -287,6 +302,8 @@ def test_only_current_flag_filters_variants(session, mocked_gnomad_variant_row, for attr in gnomad_variant_comparator: assert getattr(mapped_variant2.gnomad_variants[0], attr) == gnomad_variant_comparator[attr] + _verify_annotation_status(session, [mapped_variant2], TEST_GNOMAD_DATA_VERSION) + def test_only_current_flag_is_false_operates_on_all_variants( session, mocked_gnomad_variant_row, setup_lib_db_with_mapped_variant @@ -313,3 +330,5 @@ def test_only_current_flag_is_false_operates_on_all_variants( assert len(mv.gnomad_variants) == 1 for attr in gnomad_variant_comparator: assert getattr(mv.gnomad_variants[0], attr) == gnomad_variant_comparator[attr] + + _verify_annotation_status(session, [mapped_variant1, mapped_variant2], TEST_GNOMAD_DATA_VERSION) diff --git a/tests/lib/workflow/conftest.py b/tests/lib/workflow/conftest.py new file mode 100644 index 00000000..0f9d9e50 --- /dev/null +++ b/tests/lib/workflow/conftest.py @@ -0,0 +1,111 @@ +from unittest.mock import patch + +import pytest + +from mavedb.models.enums.job_pipeline import DependencyType +from mavedb.models.job_run import JobRun +from mavedb.models.user import User +from tests.helpers.constants import TEST_USER + +try: + from .conftest_optional import * # noqa: F403, F401 +except ImportError: + pass + + +@pytest.fixture +def sample_job_definition(): + """Provides a sample job definition for testing.""" + return { + "key": 
"sample_job", + "type": "data_processing", + "function": "process_data", + "params": {"param1": "value1", "param2": "value2", "required_param": None}, + "dependencies": [], + } + + +@pytest.fixture +def sample_independent_pipeline_definition(sample_job_definition): + """Provides a sample pipeline definition for testing.""" + return { + "name": "sample_pipeline", + "description": "A sample pipeline for testing purposes.", + "job_definitions": [sample_job_definition], + } + + +@pytest.fixture +def sample_dependent_pipeline_definition(): + """Provides a sample pipeline definition with job dependencies for testing.""" + job_def_1 = { + "key": "job_1", + "type": "data_processing", + "function": "process_data_1", + "params": {"paramA": None}, + "dependencies": [], + } + job_def_2 = { + "key": "job_2", + "type": "data_processing", + "function": "process_data_2", + "params": {"paramB": None}, + "dependencies": [("job_1", DependencyType.SUCCESS_REQUIRED)], + } + return { + "name": "dependent_pipeline", + "description": "A sample pipeline with job dependencies for testing.", + "job_definitions": [job_def_1, job_def_2], + } + + +@pytest.fixture +def with_test_pipeline_definition_ctx(sample_dependent_pipeline_definition, sample_independent_pipeline_definition): + """Fixture to temporarily add a test pipeline definition.""" + test_pipeline_definitions = { + sample_dependent_pipeline_definition["name"]: sample_dependent_pipeline_definition, + sample_independent_pipeline_definition["name"]: sample_independent_pipeline_definition, + } + + with patch("mavedb.lib.workflow.pipeline_factory.PIPELINE_DEFINITIONS", test_pipeline_definitions): + yield + + +@pytest.fixture +def test_user(session): + """Fixture to create and provide a test user in the database.""" + db = session + user = User(**TEST_USER) + db.add(user) + db.commit() + yield user + + +@pytest.fixture +def test_workflow_parent_job_run(session, test_user): + """Fixture to create and provide a test parent job run for workflow tests.""" + parent_job_run = JobRun( + job_type="test_type", + job_function="test_function", + job_params={}, + correlation_id="test_correlation_id", + ) + session.add(parent_job_run) + session.commit() + + yield parent_job_run + + +@pytest.fixture +def test_workflow_child_job_run(session, test_user, test_workflow_parent_job_run): + """Fixture to create and provide a test child job run for workflow tests.""" + child_job_run = JobRun( + job_type="test_type", + job_function="test_function", + job_params={}, + correlation_id="test_correlation_id", + ) + session.add(child_job_run) + session.commit() + + yield child_job_run diff --git a/tests/lib/workflow/conftest_optional.py b/tests/lib/workflow/conftest_optional.py new file mode 100644 index 00000000..f165cc74 --- /dev/null +++ b/tests/lib/workflow/conftest_optional.py @@ -0,0 +1,16 @@ +import pytest + +from mavedb.lib.workflow.job_factory import JobFactory +from mavedb.lib.workflow.pipeline_factory import PipelineFactory + + +@pytest.fixture +def job_factory(session): + """Fixture to provide a mocked JobFactory instance.""" + yield JobFactory(session) + + +@pytest.fixture +def pipeline_factory(session): + """Fixture to provide a mocked PipelineFactory instance.""" + yield PipelineFactory(session) diff --git a/tests/lib/workflow/test_job_factory.py b/tests/lib/workflow/test_job_factory.py new file mode 100644 index 00000000..bf2e13ba --- /dev/null +++ b/tests/lib/workflow/test_job_factory.py @@ -0,0 +1,316 @@ +# ruff: noqa: E402 +import pytest + +from mavedb.models.job_dependency 
import JobDependency + +pytest.importorskip("fastapi") + +from unittest.mock import patch + +from mavedb.models.pipeline import Pipeline + + +@pytest.mark.unit +class TestJobFactoryCreateJobRunUnit: + """Unit tests for the JobFactory create_job_run method.""" + + def test_create_job_run_persists_preset_params_from_definition(self, job_factory, sample_job_definition): + existing_params = {"param1": "new_value1", "param2": "new_value2", "required_param": "required_value"} + job_run = job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="test-correlation-id", + pipeline_params=existing_params, + pipeline_id=1, + ) + + assert job_run.job_params["param1"] == "value1" + assert job_run.job_params["param2"] == "value2" + + def test_create_job_run_raises_error_for_missing_params(self, job_factory, sample_job_definition): + incomplete_params = {"param1": "new_value1"} # Missing required_param + + with pytest.raises(ValueError) as exc_info: + job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="test-correlation-id", + pipeline_params=incomplete_params, + pipeline_id=1, + ) + + assert "Missing required param: required_param" in str(exc_info.value) + + def test_create_job_run_fills_in_required_params(self, job_factory, sample_job_definition): + pipeline_params = {"required_param": "required_value"} + job_run = job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="test-correlation-id", + pipeline_params=pipeline_params, + pipeline_id=1, + ) + + assert job_run.job_params["param1"] == "value1" + assert job_run.job_params["param2"] == "value2" + assert job_run.job_params["required_param"] == "required_value" + + def test_create_job_run_persists_correlation_id(self, job_factory, sample_job_definition): + job_run = job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="test-correlation-id", + pipeline_params={"param1": "value1", "param2": "value2", "required_param": "required_value"}, + pipeline_id=1, + ) + + assert job_run.correlation_id == "test-correlation-id" + + def test_create_job_run_persists_mavedb_version(self, job_factory, sample_job_definition): + with patch("mavedb.lib.workflow.job_factory.mavedb_version", "1.2.3"): + job_run = job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="test-correlation-id", + pipeline_params={"param1": "value1", "param2": "value2", "required_param": "required_value"}, + pipeline_id=1, + ) + + assert job_run.mavedb_version == "1.2.3" + + def test_create_job_run_persists_job_type_and_function(self, job_factory, sample_job_definition): + job_run = job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="test-correlation-id", + pipeline_params={"param1": "value1", "param2": "value2", "required_param": "required_value"}, + pipeline_id=1, + ) + + assert job_run.job_type == sample_job_definition["type"] + assert job_run.job_function == sample_job_definition["function"] + + def test_create_job_run_ignores_extra_pipeline_params(self, job_factory, sample_job_definition): + pipeline_params = { + "param1": "new_value1", + "param2": "new_value2", + "required_param": "required_value", + "extra_param": "should_be_ignored", + } + job_run = job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="test-correlation-id", + pipeline_params=pipeline_params, + pipeline_id=1, + ) + + assert "extra_param" not in job_run.job_params + + def test_create_job_run_with_no_pipeline_id(self, job_factory, sample_job_definition): + job_run = 
job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="test-correlation-id", + pipeline_params={"param1": "value1", "param2": "value2", "required_param": "required_value"}, + ) + + assert job_run.pipeline_id is None + + def test_create_job_run_associates_with_pipeline(self, job_factory, sample_job_definition): + job_run = job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="test-correlation-id", + pipeline_params={"param1": "value1", "param2": "value2", "required_param": "required_value"}, + pipeline_id=42, + ) + + assert job_run.pipeline_id == 42 + + def test_create_job_run_adds_to_session(self, job_factory, sample_job_definition): + job_run = job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="test-correlation-id", + pipeline_params={"param1": "value1", "param2": "value2", "required_param": "required_value"}, + pipeline_id=1, + ) + + assert job_run in job_factory.session.new + + +@pytest.mark.integration +class TestJobFactoryCreateJobRunIntegration: + """Integration tests for the JobFactory create_job_run method within pipeline execution.""" + + def test_create_job_run_independent(self, job_factory, sample_job_definition): + pipeline_params = {"required_param": "required_value"} + job_run = job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="integration-correlation-id", + pipeline_params=pipeline_params, + ) + job_factory.session.commit() + + retrieved_job_run = job_factory.session.get(type(job_run), job_run.id) + + assert retrieved_job_run is not None + assert retrieved_job_run.job_type == sample_job_definition["type"] + assert retrieved_job_run.job_function == sample_job_definition["function"] + assert retrieved_job_run.job_params["param1"] == "value1" + assert retrieved_job_run.job_params["param2"] == "value2" + assert retrieved_job_run.job_params["required_param"] == "required_value" + assert retrieved_job_run.correlation_id == "integration-correlation-id" + assert retrieved_job_run.pipeline_id is None + + def test_create_job_run_with_pipeline(self, job_factory, sample_job_definition): + pipeline = Pipeline( + name="Test Pipeline", + description="A pipeline for testing JobFactory integration.", + ) + job_factory.session.add(pipeline) + job_factory.session.flush() + + pipeline_params = {"required_param": "required_value"} + job_run = job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="integration-correlation-id", + pipeline_params=pipeline_params, + pipeline_id=pipeline.id, + ) + job_factory.session.commit() + + retrieved_job_run = job_factory.session.get(type(job_run), job_run.id) + + assert retrieved_job_run is not None + assert retrieved_job_run.job_type == sample_job_definition["type"] + assert retrieved_job_run.job_function == sample_job_definition["function"] + assert retrieved_job_run.job_params["param1"] == "value1" + assert retrieved_job_run.job_params["param2"] == "value2" + assert retrieved_job_run.job_params["required_param"] == "required_value" + assert retrieved_job_run.correlation_id == "integration-correlation-id" + assert retrieved_job_run.pipeline_id == pipeline.id + + def test_create_job_run_missing_params_raises_error(self, job_factory, sample_job_definition): + incomplete_params = {"param1": "new_value1"} # Missing required_param + + with pytest.raises(ValueError) as exc_info: + job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="integration-correlation-id", + pipeline_params=incomplete_params, + 
pipeline_id=100, + ) + + assert "Missing required param: required_param" in str(exc_info.value) + + +@pytest.mark.unit +class TestJobFactoryCreateJobDependencyUnit: + """Unit tests for the JobFactory create_job_dependency method.""" + + def test_create_job_dependency_persists_fields( + self, job_factory, test_workflow_parent_job_run, test_workflow_child_job_run + ): + parent_job_run_id = test_workflow_parent_job_run.id + child_job_run_id = test_workflow_child_job_run.id + dependency_type = "success_required" + + job_dependency = job_factory.create_job_dependency( + parent_job_run_id=parent_job_run_id, + child_job_run_id=child_job_run_id, + dependency_type=dependency_type, + ) + + assert job_dependency.id == child_job_run_id + assert job_dependency.depends_on_job_id == parent_job_run_id + assert job_dependency.dependency_type == dependency_type + + def test_create_job_dependency_defaults_dependency_type( + self, job_factory, test_workflow_parent_job_run, test_workflow_child_job_run + ): + parent_job_run_id = test_workflow_parent_job_run.id + child_job_run_id = test_workflow_child_job_run.id + + job_dependency = job_factory.create_job_dependency( + parent_job_run_id=parent_job_run_id, + child_job_run_id=child_job_run_id, + ) + + assert job_dependency.id == child_job_run_id + assert job_dependency.depends_on_job_id == parent_job_run_id + assert job_dependency.dependency_type == "success_required" + + def test_create_job_dependency_raises_error_for_nonexistent_parent(self, job_factory, test_workflow_child_job_run): + parent_job_run_id = 9999 # Assuming this ID does not exist + child_job_run_id = test_workflow_child_job_run.id + + with pytest.raises(ValueError) as exc_info: + job_factory.create_job_dependency( + parent_job_run_id=parent_job_run_id, + child_job_run_id=child_job_run_id, + ) + + assert f"Parent job run ID {parent_job_run_id} does not exist." in str(exc_info.value) + + def test_create_job_dependency_raises_error_for_nonexistent_child(self, job_factory, test_workflow_parent_job_run): + parent_job_run_id = test_workflow_parent_job_run.id + child_job_run_id = 9999 # Assuming this ID does not exist + + with pytest.raises(ValueError) as exc_info: + job_factory.create_job_dependency( + parent_job_run_id=parent_job_run_id, + child_job_run_id=child_job_run_id, + ) + + assert f"Child job run ID {child_job_run_id} does not exist." 
in str(exc_info.value) + + +@pytest.mark.integration +class TestJobFactoryCreateJobDependencyIntegration: + """Integration tests for the JobFactory create_job_dependency method within job execution.""" + + def test_create_job_dependency(self, job_factory, test_workflow_parent_job_run, test_workflow_child_job_run): + parent_job_run_id = test_workflow_parent_job_run.id + child_job_run_id = test_workflow_child_job_run.id + dependency_type = "success_required" + + job_dependency = job_factory.create_job_dependency( + parent_job_run_id=parent_job_run_id, + child_job_run_id=child_job_run_id, + dependency_type=dependency_type, + ) + job_factory.session.commit() + + retrieved_dependency = ( + job_factory.session.query(type(job_dependency)) + .filter( + type(job_dependency).id == child_job_run_id, + type(job_dependency).depends_on_job_id == parent_job_run_id, + ) + .first() + ) + + assert retrieved_dependency is not None + assert retrieved_dependency.id == child_job_run_id + assert retrieved_dependency.depends_on_job_id == parent_job_run_id + assert retrieved_dependency.dependency_type == dependency_type + + def test_create_job_dependency_missing_parent_raises_error(self, session, job_factory, test_workflow_child_job_run): + parent_job_run_id = 9999 # Assuming this ID does not exist + child_job_run_id = test_workflow_child_job_run.id + + with pytest.raises(ValueError) as exc_info: + job_factory.create_job_dependency( + parent_job_run_id=parent_job_run_id, + child_job_run_id=child_job_run_id, + ) + + assert f"Parent job run ID {parent_job_run_id} does not exist." in str(exc_info.value) + job_dependencies = session.query(JobDependency).all() + assert not job_dependencies + + def test_create_job_dependency_missing_child_raises_error(self, session, job_factory, test_workflow_parent_job_run): + parent_job_run_id = test_workflow_parent_job_run.id + child_job_run_id = 9999 # Assuming this ID does not exist + + with pytest.raises(ValueError) as exc_info: + job_factory.create_job_dependency( + parent_job_run_id=parent_job_run_id, + child_job_run_id=child_job_run_id, + ) + + assert f"Child job run ID {child_job_run_id} does not exist." 
in str(exc_info.value) + job_dependencies = session.query(JobDependency).all() + assert not job_dependencies diff --git a/tests/lib/workflow/test_pipeline_factory.py b/tests/lib/workflow/test_pipeline_factory.py new file mode 100644 index 00000000..b944e469 --- /dev/null +++ b/tests/lib/workflow/test_pipeline_factory.py @@ -0,0 +1,242 @@ +# ruff: noqa: E402 +import pytest + +pytest.importorskip("fastapi") + +from sqlalchemy import select + +from mavedb.lib.workflow.pipeline_factory import PipelineFactory +from mavedb.models.job_run import JobRun + + +@pytest.mark.unit +class TestPipelineFactoryUnit: + """Unit tests for the PipelineFactory class.""" + + def test_create_pipeline_raises_if_pipeline_not_found(self, session, test_user): + """Test that creating a pipeline with an unknown name raises a KeyError.""" + pipeline_factory = PipelineFactory(session=session) + + with pytest.raises(KeyError) as exc_info: + pipeline_factory.create_pipeline( + pipeline_name="unknown_pipeline", + creating_user=test_user, + pipeline_params={}, + ) + + assert "unknown_pipeline" in str(exc_info.value) + + def test_create_pipeline_prioritizes_correlation_id_from_params( + self, + session, + with_test_pipeline_definition_ctx, + pipeline_factory, + sample_independent_pipeline_definition, + test_user, + ): + """Test that the correlation_id from pipeline_params is used when creating a pipeline.""" + pipeline_name = sample_independent_pipeline_definition["name"] + test_correlation_id = "test-correlation-id-123" + + pipeline, job_run = pipeline_factory.create_pipeline( + pipeline_name=pipeline_name, + creating_user=test_user, + pipeline_params={"correlation_id": test_correlation_id, "required_param": "some_value"}, + ) + + assert job_run.correlation_id == test_correlation_id + + def test_create_pipeline_creates_start_pipeline_job( + self, + session, + with_test_pipeline_definition_ctx, + pipeline_factory, + sample_independent_pipeline_definition, + test_user, + ): + """Test that creating a pipeline results in a JobRun of type 'start_pipeline'.""" + pipeline_name = sample_independent_pipeline_definition["name"] + + pipeline, job_run = pipeline_factory.create_pipeline( + pipeline_name=pipeline_name, + creating_user=test_user, + pipeline_params={"required_param": "some_value"}, + ) + + stmt = select(JobRun).where(JobRun.pipeline_id == pipeline.id) + job_runs = session.execute(stmt).scalars().all() + + start_pipeline_jobs = [jr for jr in job_runs if jr.job_function == "start_pipeline"] + assert len(start_pipeline_jobs) == 1 + assert start_pipeline_jobs[0].id == job_run.id + + def test_create_pipeline_creates_job_runs( + self, + session, + with_test_pipeline_definition_ctx, + pipeline_factory, + sample_independent_pipeline_definition, + test_user, + ): + """Test that creating a pipeline results in the correct number of JobRun instances.""" + pipeline_name = sample_independent_pipeline_definition["name"] + expected_job_count = len(sample_independent_pipeline_definition["job_definitions"]) + + pipeline, job_run = pipeline_factory.create_pipeline( + pipeline_name=pipeline_name, + creating_user=test_user, + pipeline_params={"required_param": "some_value"}, + ) + + stmt = select(JobRun).where(JobRun.pipeline_id == pipeline.id) + job_runs = session.execute(stmt).scalars().all() + + # One additional job run for the start_pipeline job + assert len(job_runs) == expected_job_count + 1 + + def test_create_pipeline_creates_job_dependencies( + self, + session, + with_test_pipeline_definition_ctx, + pipeline_factory, + 
sample_dependent_pipeline_definition, + test_user, + ): + """Test that creating a pipeline with job dependencies results in correct JobDependency records.""" + pipeline_name = sample_dependent_pipeline_definition["name"] + jobs = sample_dependent_pipeline_definition["job_definitions"] + + pipeline, job_run = pipeline_factory.create_pipeline( + pipeline_name=pipeline_name, + creating_user=test_user, + pipeline_params={"paramA": "valueA", "paramB": "valueB", "required_param": "some_value"}, + ) + + stmt = select(JobRun).where(JobRun.pipeline_id == pipeline.id) + job_runs = session.execute(stmt).scalars().all() + job_run_dict = {jr.job_function: jr for jr in job_runs} + + # Verify dependencies + for job_def in jobs: + job_deps = job_def["dependencies"] + job_run = job_run_dict[job_def["function"]] + + # For each dependency, check that a JobDependency record exists + # and verify its properties + for dep_key, dependency_type in job_deps: + dep_job_run = job_run_dict[[jd for jd in jobs if jd["key"] == dep_key][0]["function"]] + + assert len(job_run.job_dependencies) == 1 + for jd in job_run.job_dependencies: + assert jd.depends_on_job_id == dep_job_run.id + assert jd.dependency_type == dependency_type + + def test_create_pipeline_creates_pipeline( + self, + session, + with_test_pipeline_definition_ctx, + pipeline_factory, + sample_independent_pipeline_definition, + test_user, + ): + """Test that creating a pipeline results in a Pipeline record in the database.""" + pipeline_name = sample_independent_pipeline_definition["name"] + + pipeline, job_run = pipeline_factory.create_pipeline( + pipeline_name=pipeline_name, + creating_user=test_user, + pipeline_params={"required_param": "some_value"}, + ) + + stmt = select(pipeline.__class__).where(pipeline.__class__.id == pipeline.id) + retrieved_pipeline = session.execute(stmt).scalars().first() + + assert retrieved_pipeline is not None + assert retrieved_pipeline.id == pipeline.id + + +@pytest.mark.integration +class TestPipelineFactoryIntegration: + """Integration tests for the PipelineFactory class.""" + + def test_create_pipeline_independent( + self, + session, + with_test_pipeline_definition_ctx, + pipeline_factory, + sample_independent_pipeline_definition, + test_user, + ): + """Integration test for creating an independent pipeline.""" + pipeline_name = sample_independent_pipeline_definition["name"] + + pipeline, job_run = pipeline_factory.create_pipeline( + pipeline_name=pipeline_name, + creating_user=test_user, + pipeline_params={"required_param": "some_value"}, + ) + + assert pipeline.name == pipeline_name + assert job_run.job_function == "start_pipeline" + + for job_def in sample_independent_pipeline_definition["job_definitions"]: + stmt = select(JobRun).where( + JobRun.pipeline_id == pipeline.id, + JobRun.job_function == job_def["function"], + ) + job_run = session.execute(stmt).scalars().first() + assert job_run is not None + assert job_run.job_params["param1"] == "value1" + assert job_run.job_params["param2"] == "value2" + assert job_run.pipeline_id == pipeline.id + assert job_run.job_dependencies == [] + + def test_create_pipeline_dependent( + self, + session, + with_test_pipeline_definition_ctx, + pipeline_factory, + sample_dependent_pipeline_definition, + test_user, + ): + """Integration test for creating a dependent pipeline.""" + pipeline_name = sample_dependent_pipeline_definition["name"] + + passed_params = {"paramA": "valueA", "paramB": "valueB", "required_param": "some_value"} + pipeline, job_run = 
pipeline_factory.create_pipeline( + pipeline_name=pipeline_name, + creating_user=test_user, + pipeline_params=passed_params, + ) + + assert pipeline.name == pipeline_name + assert job_run.job_function == "start_pipeline" + + job_runs = {} + for job_def in sample_dependent_pipeline_definition["job_definitions"]: + stmt = select(JobRun).where( + JobRun.pipeline_id == pipeline.id, + JobRun.job_function == job_def["function"], + ) + jr = session.execute(stmt).scalars().first() + assert jr is not None + assert jr.pipeline_id == pipeline.id + for param_key, param_value in job_def["params"].items(): + if param_value is not None: + assert jr.job_params[param_key] == param_value + else: + assert jr.job_params[param_key] == passed_params[param_key] + + job_runs[job_def["key"]] = jr + + # Verify dependencies + for job_def in sample_dependent_pipeline_definition["job_definitions"]: + job_deps = job_def["dependencies"] + job_run = job_runs[job_def["key"]] + for dep_key, dependency_type in job_deps: + dep_job_run = job_runs[dep_key] + + assert len(job_run.job_dependencies) == 1 + for jd in job_run.job_dependencies: + assert jd.depends_on_job_id == dep_job_run.id + assert jd.dependency_type == dependency_type diff --git a/tests/routers/conftest.py b/tests/routers/conftest.py index d54b18d8..ba34c548 100644 --- a/tests/routers/conftest.py +++ b/tests/routers/conftest.py @@ -4,32 +4,36 @@ import pytest from mavedb.models.clinical_control import ClinicalControl -from mavedb.models.controlled_keyword import ControlledKeyword from mavedb.models.contributor import Contributor +from mavedb.models.controlled_keyword import ControlledKeyword from mavedb.models.enums.user_role import UserRole -from mavedb.models.publication_identifier import PublicationIdentifier from mavedb.models.gnomad_variant import GnomADVariant from mavedb.models.license import License +from mavedb.models.publication_identifier import PublicationIdentifier from mavedb.models.role import Role from mavedb.models.taxonomy import Taxonomy from mavedb.models.user import User - from tests.helpers.constants import ( ADMIN_USER, - TEST_CLINVAR_CONTROL, - TEST_GENERIC_CLINICAL_CONTROL, - EXTRA_USER, EXTRA_CONTRIBUTOR, + EXTRA_LICENSE, + EXTRA_USER, + TEST_CLINVAR_CONTROL, TEST_DB_KEYWORDS, - TEST_LICENSE, + TEST_GENERIC_CLINICAL_CONTROL, + TEST_GNOMAD_VARIANT, TEST_INACTIVE_LICENSE, - EXTRA_LICENSE, + TEST_LICENSE, + TEST_PUBMED_PUBLICATION, TEST_SAVED_TAXONOMY, TEST_USER, - TEST_PUBMED_PUBLICATION, - TEST_GNOMAD_VARIANT, ) +try: + from .conftest_optional import * # noqa: F403, F401 +except ImportError: + pass + @pytest.fixture def setup_router_db(session): diff --git a/tests/routers/conftest_optional.py b/tests/routers/conftest_optional.py new file mode 100644 index 00000000..efbd119b --- /dev/null +++ b/tests/routers/conftest_optional.py @@ -0,0 +1,14 @@ +from unittest import mock + +import pytest +from mypy_boto3_s3 import S3Client + + +@pytest.fixture +def mock_s3_client(): + """Mock S3 client for tests that interact with S3.""" + + with mock.patch("mavedb.routers.score_sets.s3_client") as mock_s3_client_func: + mock_s3 = mock.MagicMock(spec=S3Client) + mock_s3_client_func.return_value = mock_s3 + yield mock_s3 diff --git a/tests/routers/test_score_set.py b/tests/routers/test_score_set.py index 86234392..5cb29ab6 100644 --- a/tests/routers/test_score_set.py +++ b/tests/routers/test_score_set.py @@ -448,7 +448,7 @@ def test_can_patch_score_set_data_before_publication( indirect=["mock_publication_fetch"], ) def 
test_can_patch_score_set_data_with_files_before_publication( - client, setup_router_db, form_field, filename, mime_type, data_files, mock_publication_fetch + client, setup_router_db, form_field, filename, mime_type, data_files, mock_publication_fetch, mock_s3_client ): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) @@ -460,7 +460,10 @@ def test_can_patch_score_set_data_with_files_before_publication( if form_field == "counts_file" or form_field == "scores_file": data_file_path = data_files / filename files = {form_field: (filename, open(data_file_path, "rb"), mime_type)} - with patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as worker_queue: + with ( + patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as worker_queue, + patch.object(mock_s3_client, "upload_fileobj", return_value=None), + ): response = client.patch(f"/api/v1/score-sets-with-variants/{score_set['urn']}", files=files) worker_queue.assert_called_once() assert response.status_code == 200 @@ -871,13 +874,14 @@ def test_creating_user_can_view_all_score_calibrations_in_score_set(client, setu ######################################################################################################################## -def test_add_score_set_variants_scores_only_endpoint(client, setup_router_db, data_files): +def test_add_score_set_variants_scores_only_endpoint(client, setup_router_db, data_files, mock_s3_client): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) scores_csv_path = data_files / "scores.csv" with ( open(scores_csv_path, "rb") as scores_file, patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as queue, + patch.object(mock_s3_client, "upload_fileobj", return_value=None), ): response = client.post( f"/api/v1/score-sets/{score_set['urn']}/variants/data", @@ -895,7 +899,9 @@ def test_add_score_set_variants_scores_only_endpoint(client, setup_router_db, da assert score_set == response_data -def test_add_score_set_variants_scores_and_counts_endpoint(session, client, setup_router_db, data_files): +def test_add_score_set_variants_scores_and_counts_endpoint( + session, client, setup_router_db, data_files, mock_s3_client +): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) scores_csv_path = data_files / "scores.csv" @@ -904,6 +910,7 @@ def test_add_score_set_variants_scores_and_counts_endpoint(session, client, setu open(scores_csv_path, "rb") as scores_file, open(counts_csv_path, "rb") as counts_file, patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as queue, + patch.object(mock_s3_client, "upload_fileobj", return_value=None), ): response = client.post( f"/api/v1/score-sets/{score_set['urn']}/variants/data", @@ -925,7 +932,7 @@ def test_add_score_set_variants_scores_and_counts_endpoint(session, client, setu def test_add_score_set_variants_scores_counts_and_column_metadata_endpoint( - session, client, setup_router_db, data_files + session, client, setup_router_db, data_files, mock_s3_client ): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) @@ -939,6 +946,7 @@ def test_add_score_set_variants_scores_counts_and_column_metadata_endpoint( open(score_columns_metadata_path, "rb") as score_columns_metadata_file, open(count_columns_metadata_path, "rb") as count_columns_metadata_file, patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as queue, + patch.object(mock_s3_client, "upload_fileobj", 
return_value=None), ): score_columns_metadata = json.load(score_columns_metadata_file) count_columns_metadata = json.load(count_columns_metadata_file) @@ -965,13 +973,14 @@ def test_add_score_set_variants_scores_counts_and_column_metadata_endpoint( assert score_set == response_data -def test_add_score_set_variants_scores_only_endpoint_utf8_encoded(client, setup_router_db, data_files): +def test_add_score_set_variants_scores_only_endpoint_utf8_encoded(client, setup_router_db, data_files, mock_s3_client): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) scores_csv_path = data_files / "scores_utf8_encoded.csv" with ( open(scores_csv_path, "rb") as scores_file, patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as queue, + patch.object(mock_s3_client, "upload_fileobj", return_value=None), ): response = client.post( f"/api/v1/score-sets/{score_set['urn']}/variants/data", @@ -989,7 +998,9 @@ def test_add_score_set_variants_scores_only_endpoint_utf8_encoded(client, setup_ assert score_set == response_data -def test_add_score_set_variants_scores_and_counts_endpoint_utf8_encoded(session, client, setup_router_db, data_files): +def test_add_score_set_variants_scores_and_counts_endpoint_utf8_encoded( + session, client, setup_router_db, data_files, mock_s3_client +): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) scores_csv_path = data_files / "scores_utf8_encoded.csv" @@ -998,6 +1009,7 @@ def test_add_score_set_variants_scores_and_counts_endpoint_utf8_encoded(session, open(scores_csv_path, "rb") as scores_file, open(counts_csv_path, "rb") as counts_file, patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as queue, + patch.object(mock_s3_client, "upload_fileobj", return_value=None), ): response = client.post( f"/api/v1/score-sets/{score_set['urn']}/variants/data", @@ -1073,7 +1085,9 @@ def test_anonymous_cannot_add_scores_to_other_user_score_set( assert "Could not validate credentials" in response_data["detail"] -def test_contributor_can_add_scores_to_other_user_score_set(session, client, setup_router_db, data_files): +def test_contributor_can_add_scores_to_other_user_score_set( + session, client, setup_router_db, data_files, mock_s3_client +): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) change_ownership(session, score_set["urn"], ScoreSetDbModel) @@ -1090,6 +1104,7 @@ def test_contributor_can_add_scores_to_other_user_score_set(session, client, set with ( open(scores_csv_path, "rb") as scores_file, patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as queue, + patch.object(mock_s3_client, "upload_fileobj", return_value=None), ): response = client.post( f"/api/v1/score-sets/{score_set['urn']}/variants/data", @@ -1127,7 +1142,9 @@ def test_contributor_can_add_scores_to_other_user_score_set(session, client, set assert score_set == response_data -def test_contributor_can_add_scores_and_counts_to_other_user_score_set(session, client, setup_router_db, data_files): +def test_contributor_can_add_scores_and_counts_to_other_user_score_set( + session, client, setup_router_db, data_files, mock_s3_client +): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) change_ownership(session, score_set["urn"], ScoreSetDbModel) @@ -1146,6 +1163,7 @@ def test_contributor_can_add_scores_and_counts_to_other_user_score_set(session, open(scores_csv_path, "rb") as scores_file, open(counts_csv_path, 
"rb") as counts_file, patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as queue, + patch.object(mock_s3_client, "upload_fileobj", return_value=None), ): response = client.post( f"/api/v1/score-sets/{score_set['urn']}/variants/data", @@ -1187,7 +1205,7 @@ def test_contributor_can_add_scores_and_counts_to_other_user_score_set(session, def test_admin_can_add_scores_to_other_user_score_set( - session, client, setup_router_db, data_files, admin_app_overrides + session, client, setup_router_db, data_files, mock_s3_client, admin_app_overrides ): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) @@ -1197,6 +1215,7 @@ def test_admin_can_add_scores_to_other_user_score_set( open(scores_csv_path, "rb") as scores_file, DependencyOverrider(admin_app_overrides), patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as queue, + patch.object(mock_s3_client, "upload_fileobj", return_value=None), ): response = client.post( f"/api/v1/score-sets/{score_set['urn']}/variants/data", @@ -1214,7 +1233,9 @@ def test_admin_can_add_scores_to_other_user_score_set( assert score_set == response_data -def test_admin_can_add_scores_and_counts_to_other_user_score_set(session, client, setup_router_db, data_files): +def test_admin_can_add_scores_and_counts_to_other_user_score_set( + session, client, setup_router_db, data_files, mock_s3_client +): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) scores_csv_path = data_files / "scores.csv" @@ -1223,6 +1244,7 @@ def test_admin_can_add_scores_and_counts_to_other_user_score_set(session, client open(scores_csv_path, "rb") as scores_file, open(counts_csv_path, "rb") as counts_file, patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as queue, + patch.object(mock_s3_client, "upload_fileobj", return_value=None), ): response = client.post( f"/api/v1/score-sets/{score_set['urn']}/variants/data", diff --git a/tests/worker/conftest.py b/tests/worker/conftest.py index 49dad88f..4f1f32e3 100644 --- a/tests/worker/conftest.py +++ b/tests/worker/conftest.py @@ -1,34 +1,287 @@ +""" +Test configuration and fixtures for worker lib tests. +""" + +from datetime import datetime from pathlib import Path from shutil import copytree from unittest.mock import Mock +import pandas as pd import pytest +from mavedb.models.enums.job_pipeline import DependencyType, JobStatus, PipelineStatus +from mavedb.models.experiment import Experiment +from mavedb.models.experiment_set import ExperimentSet +from mavedb.models.job_dependency import JobDependency +from mavedb.models.job_run import JobRun from mavedb.models.license import License -from mavedb.models.taxonomy import Taxonomy +from mavedb.models.pipeline import Pipeline +from mavedb.models.score_set import ScoreSet +from mavedb.models.target_gene import TargetGene +from mavedb.models.target_sequence import TargetSequence from mavedb.models.user import User +from tests.helpers.constants import EXTRA_USER, TEST_LICENSE, TEST_USER + +# Attempt to import optional top level fixtures. If the modules they depend on are not installed, +# we won't have access to our full fixture suite and only a limited subset of tests can be run. 
+try: + from .conftest_optional import * # noqa: F401, F403 + +except ModuleNotFoundError: + pass + + +@pytest.fixture +def sample_job_run(sample_pipeline): + """Create a sample JobRun instance for testing.""" + return JobRun( + id=1, + urn="test:job:1", + job_type="test_job", + job_function="test_function", + status=JobStatus.PENDING, + pipeline_id=sample_pipeline.id, + progress_current=0, + progress_total=100, + progress_message="Ready to start", + created_at=datetime.now(), + ) + + +@pytest.fixture +def sample_dependent_job_run(sample_pipeline): + """Create a sample dependent JobRun instance for testing.""" + return JobRun( + id=2, + urn="test:job:2", + job_type="dependent_job", + job_function="dependent_function", + status=JobStatus.PENDING, + pipeline_id=sample_pipeline.id, + progress_current=0, + progress_total=100, + progress_message="Waiting for dependency", + created_at=datetime.now(), + ) + + +@pytest.fixture +def sample_independent_job_run(): + """Create a sample independent JobRun instance for testing.""" + return JobRun( + id=3, + urn="test:job:3", + job_type="independent_job", + job_function="independent_function", + status=JobStatus.PENDING, + pipeline_id=None, + progress_current=0, + progress_total=100, + progress_message="Ready to start", + created_at=datetime.now(), + ) + + +@pytest.fixture +def sample_pipeline(): + """Create a sample Pipeline instance for testing.""" + return Pipeline( + id=1, + urn="test:pipeline:1", + name="Test Pipeline", + description="A test pipeline", + status=PipelineStatus.CREATED, + correlation_id="test_correlation_123", + created_at=datetime.now(), + ) + + +@pytest.fixture +def sample_empty_pipeline(): + """Create a sample Pipeline instance with no jobs for testing.""" + return Pipeline( + id=999, + urn="test:pipeline:999", + name="Empty Pipeline", + description="A pipeline with no jobs", + status=PipelineStatus.CREATED, + correlation_id="empty_correlation_456", + created_at=datetime.now(), + ) + + +@pytest.fixture +def sample_job_dependency(sample_dependent_job_run, sample_job_run): + """Create a sample JobDependency instance for testing.""" + return JobDependency( + id=sample_dependent_job_run.id, # dependent job + depends_on_job_id=sample_job_run.id, # depends on job 1 + dependency_type=DependencyType.SUCCESS_REQUIRED, + created_at=datetime.now(), + ) + + +@pytest.fixture +def sample_user(): + """Create a sample User instance for testing.""" + return User(**TEST_USER) -from tests.helpers.constants import ( - EXTRA_USER, - TEST_LICENSE, - TEST_INACTIVE_LICENSE, - TEST_SAVED_TAXONOMY, - TEST_USER, - TEST_MAVEDB_ATHENA_ROW, -) + +@pytest.fixture +def sample_extra_user(): + """Create an extra sample User instance for testing.""" + return User(**EXTRA_USER) + + +@pytest.fixture +def sample_license(): + """Create a sample License instance for testing.""" + return License(**TEST_LICENSE) + + +@pytest.fixture +def sample_experiment_set(sample_user): + """Create a sample ExperimentSet instance for testing.""" + return ExperimentSet( + extra_metadata={}, + created_by=sample_user, + ) + + +@pytest.fixture +def sample_experiment(sample_experiment_set, sample_user): + """Create a sample Experiment instance for testing.""" + return Experiment( + title="Sample Experiment", + short_description="A sample experiment for testing purposes", + abstract_text="This is an abstract for the sample experiment.", + method_text="This is a method description for the sample experiment.", + extra_metadata={}, + experiment_set=sample_experiment_set, + 
created_by=sample_user, + ) + + +@pytest.fixture +def sample_score_set(sample_experiment, sample_user, sample_license): + """Create a sample ScoreSet instance for testing.""" + return ScoreSet( + title="Sample Score Set", + short_description="A sample score set for testing purposes", + abstract_text="This is an abstract for the sample score set.", + method_text="This is a method description for the sample score set.", + extra_metadata={}, + experiment=sample_experiment, + created_by=sample_user, + license=sample_license, + target_genes=[ + TargetGene( + name="Sample Gene", + category="protein_coding", + target_sequence=TargetSequence(label="testsequence", sequence_type="dna", sequence="ATGCAT"), + ) + ], + ) @pytest.fixture -def setup_worker_db(session): +def with_populated_domain_data( + session, + sample_user, + sample_extra_user, + sample_experiment_set, + sample_experiment, + sample_score_set, + sample_license, +): db = session - db.add(User(**TEST_USER)) - db.add(User(**EXTRA_USER)) - db.add(Taxonomy(**TEST_SAVED_TAXONOMY)) - db.add(License(**TEST_LICENSE)) - db.add(License(**TEST_INACTIVE_LICENSE)) + db.add(sample_user) + db.add(sample_extra_user) + db.add(sample_experiment_set) + db.add(sample_experiment) + db.add(sample_score_set) + db.add(sample_license) db.commit() +@pytest.fixture +def with_populated_job_data( + session, + sample_job_run, + sample_pipeline, + sample_empty_pipeline, + sample_job_dependency, + sample_dependent_job_run, + sample_independent_job_run, +): + """Set up the database with sample data for worker tests.""" + session.add(sample_pipeline) + session.add(sample_empty_pipeline) + session.add(sample_job_run) + session.add(sample_dependent_job_run) + session.add(sample_independent_job_run) + session.add(sample_job_dependency) + session.commit() + + +@pytest.fixture +def mock_pipeline(): + """Create a mock Pipeline instance. By default, + properties are identical to a default new Pipeline entered into the db + with sensible defaults for non-nullable but unset fields. + """ + return Mock( + spec=Pipeline, + id=1, + urn="test:pipeline:1", + name="Test Pipeline", + description="A test pipeline", + status=PipelineStatus.CREATED, + correlation_id="test_correlation_123", + metadata_={}, + created_at=datetime.now(), + started_at=None, + finished_at=None, + created_by_user_id=None, + mavedb_version=None, + ) + + +@pytest.fixture +def mock_job_run(mock_pipeline): + """Create a mock JobRun instance. By default, + properties are identical to a default new JobRun entered into the db + with sensible defaults for non-nullable but unset fields. 
+ """ + return Mock( + spec=JobRun, + id=123, + urn="test:job:123", + job_type="test_job", + job_function="test_function", + status=JobStatus.PENDING, + pipeline_id=mock_pipeline.id, + priority=0, + max_retries=3, + retry_count=0, + retry_delay_seconds=None, + scheduled_at=datetime.now(), + started_at=None, + finished_at=None, + created_at=datetime.now(), + error_message=None, + error_traceback=None, + failure_category=None, + progress_current=None, + progress_total=None, + progress_message=None, + correlation_id=None, + metadata_={}, + mavedb_version=None, + ) + + @pytest.fixture def data_files(tmp_path): copytree(Path(__file__).absolute().parent / "data", tmp_path / "data") @@ -36,10 +289,10 @@ def data_files(tmp_path): @pytest.fixture -def mocked_gnomad_variant_row(): - gnomad_variant = Mock() +def sample_score_dataframe(data_files): + return pd.read_csv(data_files / "scores.csv") - for key, value in TEST_MAVEDB_ATHENA_ROW.items(): - setattr(gnomad_variant, key, value) - return gnomad_variant +@pytest.fixture +def sample_count_dataframe(data_files): + return pd.read_csv(data_files / "counts.csv") diff --git a/tests/worker/conftest_optional.py b/tests/worker/conftest_optional.py new file mode 100644 index 00000000..f6da4b7c --- /dev/null +++ b/tests/worker/conftest_optional.py @@ -0,0 +1,63 @@ +from concurrent.futures import ProcessPoolExecutor +from unittest.mock import Mock, patch + +import pytest +from arq import ArqRedis +from cdot.hgvs.dataproviders import RESTDataProvider +from sqlalchemy.orm import Session + +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.pipeline_manager import PipelineManager + + +@pytest.fixture +def mock_job_manager(mock_job_run): + """Create a JobManager with mocked database and Redis dependencies.""" + mock_db = Mock(spec=Session) + mock_redis = Mock(spec=ArqRedis) + + # Don't call the real constructor since it tries to load the job from DB + manager = object.__new__(JobManager) + manager.db = mock_db + manager.redis = mock_redis + manager.job_id = mock_job_run.id + + with patch.object(manager, "get_job", return_value=mock_job_run): + yield manager + + +@pytest.fixture +def mock_pipeline_manager(mock_job_manager, mock_pipeline): + """Create a PipelineManager with mocked database, Redis dependencies, and job manager.""" + mock_db = Mock(spec=Session) + mock_redis = Mock(spec=ArqRedis) + + # Don't call the real constructor since it tries to validate the pipeline + manager = object.__new__(PipelineManager) + manager.db = mock_db + manager.redis = mock_redis + manager.pipeline_id = 123 + + with ( + patch("mavedb.worker.lib.managers.pipeline_manager.JobManager") as mock_job_manager_class, + patch.object(manager, "get_pipeline", return_value=mock_pipeline), + ): + mock_job_manager_class.return_value = mock_job_manager + yield manager + + +@pytest.fixture +def mock_worker_ctx(): + """Create a mock worker context dictionary for testing.""" + mock_redis = Mock(spec=ArqRedis) + mock_hdp = Mock(spec=RESTDataProvider) + mock_pool = Mock(spec=ProcessPoolExecutor) + + # Don't mock the session itself to allow real DB interactions in tests + # It's generally more pain than it's worth to mock out SQLAlchemy sessions, + # although it can sometimes be useful when raising specific exceptions. 
+ return { + "redis": mock_redis, + "hdp": mock_hdp, + "pool": mock_pool, + } diff --git a/tests/worker/data/counts.csv b/tests/worker/data/counts.csv index 0cc1e742..4821232a 100644 --- a/tests/worker/data/counts.csv +++ b/tests/worker/data/counts.csv @@ -1,4 +1,5 @@ -hgvs_nt,hgvs_pro,c_0,c_1 -c.1A>T,p.Thr1Ser,10,20 -c.2C>T,p.Thr1Met,8,8 -c.6T>A,p.Phe2Leu,90,2 +hgvs_nt,hgvs_splice,hgvs_pro,c_0,c_1 +c.1A>T,NA,p.Met1Leu,10,20 +c.2T>A,NA,p.Met1Lys,8,8 +c.3G>C,NA,p.Met1Ile,90,2 +c.4C>G,NA,p.His2Asp,12,1 diff --git a/tests/worker/data/scores.csv b/tests/worker/data/scores.csv index 11fce498..bd8e3bae 100644 --- a/tests/worker/data/scores.csv +++ b/tests/worker/data/scores.csv @@ -1,4 +1,5 @@ -hgvs_nt,hgvs_pro,score,s_0,s_1 -c.1A>T,p.Thr1Ser,0.3,val1,val1 -c.2C>T,p.Thr1Met,0.0,val2,val2 -c.6T>A,p.Phe2Leu,-1.65,val3,val3 +hgvs_nt,hgvs_splice,hgvs_pro,score,s_0,s_1 +c.1A>T,NA,p.Met1Leu,0.3,val1,val1 +c.2T>A,NA,p.Met1Lys,0,val2,val2 +c.3G>C,NA,p.Met1Ile,-1.65,val3,val3 +c.4C>G,NA,p.His2Asp,NA,val5,val4 diff --git a/tests/worker/jobs/conftest.py b/tests/worker/jobs/conftest.py new file mode 100644 index 00000000..677b4955 --- /dev/null +++ b/tests/worker/jobs/conftest.py @@ -0,0 +1,872 @@ +import pytest + +from mavedb.models.enums.job_pipeline import DependencyType +from mavedb.models.job_dependency import JobDependency +from mavedb.models.job_run import JobRun +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.pipeline import Pipeline +from mavedb.models.score_set import ScoreSet +from mavedb.models.variant import Variant +from tests.helpers.constants import VALID_CAID + +try: + from .conftest_optional import * # noqa: F403, F401 +except ImportError: + pass + + +## Param Fixtures for Job Runs ## + + +@pytest.fixture +def create_variants_sample_params(with_populated_domain_data, sample_score_set, sample_user): + """Provide sample parameters for create_variants_for_score_set job.""" + + return { + "scores_file_key": "sample_scores.csv", + "counts_file_key": "sample_counts.csv", + "correlation_id": "sample-correlation-id", + "updater_id": sample_user.id, + "score_set_id": sample_score_set.id, + "score_columns_metadata": {"s_0": {"description": "metadataS", "details": "detailsS"}}, + "count_columns_metadata": {"c_0": {"description": "metadataC", "details": "detailsC"}}, + } + + +@pytest.fixture +def map_variants_sample_params(with_populated_domain_data, sample_score_set, sample_user): + """Provide sample parameters for map_variants_for_score_set job.""" + + return { + "score_set_id": sample_score_set.id, + "correlation_id": "sample-mapping-correlation-id", + "updater_id": sample_user.id, + } + + +@pytest.fixture +def link_gnomad_variants_sample_params(with_populated_domain_data, sample_score_set): + """Provide sample parameters for link_gnomad_variants job.""" + + return { + "correlation_id": "sample-correlation-id", + "score_set_id": sample_score_set.id, + } + + +@pytest.fixture +def submit_uniprot_mapping_jobs_sample_params(with_populated_domain_data, sample_score_set): + """Provide sample parameters for submit_uniprot_mapping_jobs_for_score_set job.""" + + return { + "correlation_id": "sample-correlation-id", + "score_set_id": sample_score_set.id, + } + + +@pytest.fixture +def poll_uniprot_mapping_jobs_sample_params( + submit_uniprot_mapping_jobs_sample_params, + with_dependent_polling_job_for_submission_run, +): + """Provide sample parameters for poll_uniprot_mapping_jobs_for_score_set job.""" + + return { + "correlation_id": 
submit_uniprot_mapping_jobs_sample_params["correlation_id"], + "score_set_id": submit_uniprot_mapping_jobs_sample_params["score_set_id"], + "mapping_jobs": {}, + } + + +@pytest.fixture +def submit_score_set_mappings_to_car_params(with_populated_domain_data, sample_score_set): + """Provide sample parameters for submit_score_set_mappings_to_car job.""" + + return { + "correlation_id": "sample-correlation-id", + "score_set_id": sample_score_set.id, + } + + +@pytest.fixture +def refresh_clinvar_controls_sample_params(with_populated_domain_data, sample_score_set): + """Provide sample parameters for refresh_clinvar_controls job.""" + + return { + "correlation_id": "sample-correlation-id", + "score_set_id": sample_score_set.id, + "month": 1, + "year": 2026, + } + + +## Sample Pipeline Fixtures ## + + +@pytest.fixture +def sample_pipeline(): + """Create a sample Pipeline instance for testing.""" + + return Pipeline( + name="Sample Pipeline", + description="A sample pipeline for testing purposes", + ) + + +@pytest.fixture +def with_sample_pipeline(session, sample_pipeline): + """Fixture to ensure sample pipeline exists in the database.""" + session.add(sample_pipeline) + session.commit() + + +## Variant Creation Job Fixtures ## + + +@pytest.fixture +def dummy_variant_creation_job_run(create_variants_sample_params): + """Create a dummy variant creation job run for testing.""" + + return JobRun( + urn="test:dummy_variant_creation_job", + job_type="dummy_variant_creation", + job_function="dummy_variant_creation_function", + max_retries=3, + retry_count=0, + job_params=create_variants_sample_params, + ) + + +@pytest.fixture +def dummy_variant_mapping_job_run(map_variants_sample_params): + """Create a dummy variant mapping job run for testing.""" + + return JobRun( + urn="test:dummy_variant_mapping_job", + job_type="dummy_variant_mapping", + job_function="dummy_variant_mapping_function", + max_retries=3, + retry_count=0, + job_params=map_variants_sample_params, + ) + + +@pytest.fixture +def with_dummy_setup_jobs( + session, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, +): + """Add dummy variant creation and mapping job runs to the session.""" + + session.add(dummy_variant_creation_job_run) + session.add(dummy_variant_mapping_job_run) + session.commit() + + +## gnomAD Linkage Job Fixtures ## + + +@pytest.fixture +def sample_link_gnomad_variants_pipeline(): + """Create a pipeline instance for link_gnomad_variants job.""" + + return Pipeline( + urn="test:link_gnomad_variants_pipeline", + name="Link gnomAD Variants Pipeline", + ) + + +@pytest.fixture +def sample_link_gnomad_variants_run(link_gnomad_variants_sample_params): + """Create a JobRun instance for link_gnomad_variants job.""" + + return JobRun( + urn="test:link_gnomad_variants", + job_type="link_gnomad_variants", + job_function="link_gnomad_variants", + max_retries=3, + retry_count=0, + job_params=link_gnomad_variants_sample_params, + ) + + +@pytest.fixture +def with_gnomad_linking_job(session, sample_link_gnomad_variants_run): + """Add a link_gnomad_variants job run to the session.""" + + session.add(sample_link_gnomad_variants_run) + session.commit() + + +@pytest.fixture +def with_gnomad_linking_pipeline(session, sample_link_gnomad_variants_pipeline): + """Add a link_gnomad_variants pipeline to the session.""" + + session.add(sample_link_gnomad_variants_pipeline) + session.commit() + + +@pytest.fixture +def sample_link_gnomad_variants_run_pipeline( + session, + with_gnomad_linking_job, + with_gnomad_linking_pipeline, 
sample_link_gnomad_variants_run, + sample_link_gnomad_variants_pipeline, +): + """Provide a context with a link_gnomad_variants job run and pipeline.""" + + sample_link_gnomad_variants_run.pipeline_id = sample_link_gnomad_variants_pipeline.id + session.commit() + return sample_link_gnomad_variants_run + + +@pytest.fixture +def setup_sample_variants_with_caid( + session, with_populated_domain_data, mock_worker_ctx, sample_link_gnomad_variants_run +): + """Set up variants and mapped variants in the database for testing.""" + score_set = session.get(ScoreSet, sample_link_gnomad_variants_run.job_params["score_set_id"]) + + # Add a variant and mapped variant to the database with a CAID + variant = Variant( + urn="urn:variant:test-variant-with-caid", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.1A>G", + hgvs_pro="NP_000000.1:p.Met1Val", + data={"hgvs_c": "NM_000000.1:c.1A>G", "hgvs_p": "NP_000000.1:p.Met1Val"}, + ) + session.add(variant) + session.commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + clingen_allele_id=VALID_CAID, + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant) + session.commit() + return variant, mapped_variant + + +## UniProt Job Fixtures ## + + +@pytest.fixture +def sample_submit_uniprot_mapping_jobs_pipeline(): + """Create a pipeline instance for submit_uniprot_mapping_jobs_for_score_set job.""" + + return Pipeline( + urn="test:submit_uniprot_mapping_jobs_pipeline", + name="Submit UniProt Mapping Jobs Pipeline", + ) + + +@pytest.fixture +def sample_poll_uniprot_mapping_jobs_pipeline(): + """Create a pipeline instance for poll_uniprot_mapping_jobs_for_score_set job.""" + + return Pipeline( + urn="test:poll_uniprot_mapping_jobs_pipeline", + name="Poll UniProt Mapping Jobs Pipeline", + ) + + +@pytest.fixture +def sample_submit_uniprot_mapping_jobs_run(submit_uniprot_mapping_jobs_sample_params): + """Create a JobRun instance for submit_uniprot_mapping_jobs_for_score_set job.""" + + return JobRun( + urn="test:submit_uniprot_mapping_jobs", + job_type="submit_uniprot_mapping_jobs", + job_function="submit_uniprot_mapping_jobs_for_score_set", + max_retries=3, + retry_count=0, + job_params=submit_uniprot_mapping_jobs_sample_params, + ) + + +@pytest.fixture +def sample_dummy_polling_job_for_submission_run( + session, + with_submit_uniprot_mapping_job, + sample_submit_uniprot_mapping_jobs_run, +): + """Create a sample dummy dependent polling job for the submission run.""" + + dependent_job = JobRun( + urn="test:dummy_poll_uniprot_mapping_jobs", + job_type="dummy_poll_uniprot_mapping_jobs", + job_function="dummy_arq_function", + max_retries=3, + retry_count=0, + job_params={ + "correlation_id": sample_submit_uniprot_mapping_jobs_run.job_params["correlation_id"], + "score_set_id": sample_submit_uniprot_mapping_jobs_run.job_params["score_set_id"], + "mapping_jobs": {}, + }, + ) + + return dependent_job + + +@pytest.fixture +def sample_polling_job_for_submission_run( + session, + with_submit_uniprot_mapping_job, + sample_submit_uniprot_mapping_jobs_run, +): + """Create a sample dependent polling job for the submission run.""" + + dependent_job = JobRun( + urn="test:dependent_poll_uniprot_mapping_jobs", + job_type="dependent_poll_uniprot_mapping_jobs", + job_function="poll_uniprot_mapping_jobs_for_score_set", + max_retries=3, + retry_count=0, + job_params={ + "correlation_id": sample_submit_uniprot_mapping_jobs_run.job_params["correlation_id"], + "score_set_id": 
sample_submit_uniprot_mapping_jobs_run.job_params["score_set_id"],
+            "mapping_jobs": {},
+        },
+    )
+
+    return dependent_job
+
+
+@pytest.fixture
+def with_dummy_polling_job_for_submission_run(
+    session,
+    with_submit_uniprot_mapping_job,
+    sample_submit_uniprot_mapping_jobs_run,
+    sample_dummy_polling_job_for_submission_run,
+):
+    """Add the dummy polling job, and its dependency on the submission run, to the database."""
+    session.add(sample_dummy_polling_job_for_submission_run)
+    session.commit()
+
+    dependency = JobDependency(
+        id=sample_dummy_polling_job_for_submission_run.id,
+        depends_on_job_id=sample_submit_uniprot_mapping_jobs_run.id,
+        dependency_type=DependencyType.SUCCESS_REQUIRED,
+    )
+    session.add(dependency)
+    session.commit()
+
+
+@pytest.fixture
+def with_dependent_polling_job_for_submission_run(
+    session,
+    with_submit_uniprot_mapping_job,
+    sample_submit_uniprot_mapping_jobs_run,
+    sample_polling_job_for_submission_run,
+):
+    """Add the dependent polling job, and its dependency on the submission run, to the database."""
+    session.add(sample_polling_job_for_submission_run)
+    session.commit()
+
+    dependency = JobDependency(
+        id=sample_polling_job_for_submission_run.id,
+        depends_on_job_id=sample_submit_uniprot_mapping_jobs_run.id,
+        dependency_type=DependencyType.SUCCESS_REQUIRED,
+    )
+    session.add(dependency)
+    session.commit()
+
+
+@pytest.fixture
+def with_independent_polling_job_for_submission_run(
+    session,
+    sample_polling_job_for_submission_run,
+):
+    """Add the polling job to the database without any dependency on the submission run."""
+    session.add(sample_polling_job_for_submission_run)
+    session.commit()
+
+
+@pytest.fixture
+def with_submit_uniprot_mapping_job(session, sample_submit_uniprot_mapping_jobs_run):
+    """Add a submit_uniprot_mapping_jobs job run to the session."""
+
+    session.add(sample_submit_uniprot_mapping_jobs_run)
+    session.commit()
+
+
+@pytest.fixture
+def with_poll_uniprot_mapping_job(session, sample_poll_uniprot_mapping_jobs_run):
+    """Add a poll_uniprot_mapping_jobs job run to the session."""
+
+    session.add(sample_poll_uniprot_mapping_jobs_run)
+    session.commit()
+
+
+@pytest.fixture
+def sample_submit_uniprot_mapping_jobs_run_in_pipeline(
+    session,
+    with_submit_uniprot_mapping_job,
+    with_submit_uniprot_mapping_jobs_pipeline,
+    sample_submit_uniprot_mapping_jobs_run,
+    sample_submit_uniprot_mapping_jobs_pipeline,
+):
+    """Provide a context with a submit_uniprot_mapping_jobs job run and pipeline."""
+
+    sample_submit_uniprot_mapping_jobs_run.pipeline_id = sample_submit_uniprot_mapping_jobs_pipeline.id
+    session.commit()
+    return sample_submit_uniprot_mapping_jobs_run
+
+
+@pytest.fixture
+def sample_poll_uniprot_mapping_jobs_run_in_pipeline(
+    session,
+    with_independent_polling_job_for_submission_run,
+    with_poll_uniprot_mapping_jobs_pipeline,
+    sample_polling_job_for_submission_run,
+    sample_poll_uniprot_mapping_jobs_pipeline,
+):
+    """Provide a context with a poll_uniprot_mapping_jobs job run and pipeline."""
+
+    sample_polling_job_for_submission_run.pipeline_id = sample_poll_uniprot_mapping_jobs_pipeline.id
+    session.commit()
+    return sample_polling_job_for_submission_run
+
+
+@pytest.fixture
+def sample_dummy_polling_job_for_submission_run_in_pipeline(
+    session,
+    with_dummy_polling_job_for_submission_run,
+    with_submit_uniprot_mapping_jobs_pipeline,
+    with_submit_uniprot_mapping_job,
+    sample_submit_uniprot_mapping_jobs_pipeline,
+    sample_submit_uniprot_mapping_jobs_run_in_pipeline,
+    sample_dummy_polling_job_for_submission_run,
+):
+    """Provide a context with a dummy dependent polling job run in the pipeline."""
+
+    dependent_job = sample_dummy_polling_job_for_submission_run
+    dependent_job.pipeline_id = sample_submit_uniprot_mapping_jobs_pipeline.id
+    session.commit()
+    return dependent_job
+
+
+@pytest.fixture
+def sample_polling_job_for_submission_run_in_pipeline(
+    session,
+    with_dependent_polling_job_for_submission_run,
+    with_submit_uniprot_mapping_jobs_pipeline,
+    with_submit_uniprot_mapping_job,
+    sample_submit_uniprot_mapping_jobs_pipeline,
+    sample_submit_uniprot_mapping_jobs_run_in_pipeline,
+    sample_polling_job_for_submission_run,
+):
+    """Provide a context with a dependent polling job run in the pipeline."""
+
+    dependent_job = sample_polling_job_for_submission_run
+    dependent_job.pipeline_id = sample_submit_uniprot_mapping_jobs_pipeline.id
+    session.commit()
+    return dependent_job
+
+
+@pytest.fixture
+def with_submit_uniprot_mapping_jobs_pipeline(
+    session,
+    sample_submit_uniprot_mapping_jobs_pipeline,
+):
+    """Add a submit_uniprot_mapping_jobs pipeline to the session."""
+
+    session.add(sample_submit_uniprot_mapping_jobs_pipeline)
+    session.commit()
+
+
+@pytest.fixture
+def with_poll_uniprot_mapping_jobs_pipeline(
+    session,
+    sample_poll_uniprot_mapping_jobs_pipeline,
+):
+    """Add a poll_uniprot_mapping_jobs pipeline to the session."""
+    session.add(sample_poll_uniprot_mapping_jobs_pipeline)
+    session.commit()
+
+
+## Clingen Job Fixtures ##
+
+
+@pytest.fixture
+def submit_score_set_mappings_to_car_sample_pipeline():
+    """Create a pipeline instance for submit_score_set_mappings_to_car job."""
+
+    return Pipeline(
+        urn="test:submit_score_set_mappings_to_car_pipeline",
+        name="Submit Score Set Mappings to ClinGen Allele Registry Pipeline",
+    )
+
+
+@pytest.fixture
+def submit_score_set_mappings_to_ldh_sample_pipeline():
+    """Create a pipeline instance for submit_score_set_mappings_to_ldh job."""
+
+    return Pipeline(
+        urn="test:submit_score_set_mappings_to_ldh_pipeline",
+        name="Submit Score Set Mappings to ClinGen LDH Pipeline",
+    )
+
+
+@pytest.fixture
+def submit_score_set_mappings_to_car_sample_job_run(submit_score_set_mappings_to_car_params):
+    """Create a JobRun instance for submit_score_set_mappings_to_car job."""
+
+    return JobRun(
+        urn="test:submit_score_set_mappings_to_car",
+        job_type="submit_score_set_mappings_to_car",
+        job_function="submit_score_set_mappings_to_car",
+        max_retries=3,
+        retry_count=0,
+        job_params=submit_score_set_mappings_to_car_params,
+    )
+
+
+@pytest.fixture
+def submit_score_set_mappings_to_ldh_sample_job_run(submit_score_set_mappings_to_car_params):
+    """Create a JobRun instance for submit_score_set_mappings_to_ldh job."""
+
+    # The LDH job takes the same parameter shape as the CAR job (correlation_id and
+    # score_set_id), so the CAR params fixture is reused here.
+    return JobRun(
+        urn="test:submit_score_set_mappings_to_ldh",
+        job_type="submit_score_set_mappings_to_ldh",
+        job_function="submit_score_set_mappings_to_ldh",
+        max_retries=3,
+        retry_count=0,
+        job_params=submit_score_set_mappings_to_car_params,
+    )
+
+
+@pytest.fixture
+def submit_score_set_mappings_to_car_sample_job_run_in_pipeline(
+    session,
+    with_submit_score_set_mappings_to_car_pipeline,
+    with_submit_score_set_mappings_to_car_job,
+    submit_score_set_mappings_to_car_sample_pipeline,
+    submit_score_set_mappings_to_car_sample_job_run,
+):
+    """Provide a context with a submit_score_set_mappings_to_car job run and pipeline."""
+
+    submit_score_set_mappings_to_car_sample_job_run.pipeline_id = submit_score_set_mappings_to_car_sample_pipeline.id
+    session.commit()
+    return submit_score_set_mappings_to_car_sample_job_run
+
+
+@pytest.fixture
+def submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline(
+    session,
+    with_submit_score_set_mappings_to_ldh_pipeline,
+    with_submit_score_set_mappings_to_ldh_job,
+    submit_score_set_mappings_to_ldh_sample_pipeline,
+    submit_score_set_mappings_to_ldh_sample_job_run,
+):
+    """Provide a context with a submit_score_set_mappings_to_ldh job run and pipeline."""
+
+    submit_score_set_mappings_to_ldh_sample_job_run.pipeline_id = submit_score_set_mappings_to_ldh_sample_pipeline.id
+    session.commit()
+    return submit_score_set_mappings_to_ldh_sample_job_run
+
+
+@pytest.fixture
+def with_submit_score_set_mappings_to_car_job(session, submit_score_set_mappings_to_car_sample_job_run):
+    """Add a submit_score_set_mappings_to_car job run to the session."""
+
+    session.add(submit_score_set_mappings_to_car_sample_job_run)
+    session.commit()
+
+
+@pytest.fixture
+def with_submit_score_set_mappings_to_ldh_job(session, submit_score_set_mappings_to_ldh_sample_job_run):
+    """Add a submit_score_set_mappings_to_ldh job run to the session."""
+
+    session.add(submit_score_set_mappings_to_ldh_sample_job_run)
+    session.commit()
+
+
+@pytest.fixture
+def with_submit_score_set_mappings_to_car_pipeline(
+    session,
+    submit_score_set_mappings_to_car_sample_pipeline,
+):
+    """Add a submit_score_set_mappings_to_car pipeline to the session."""
+
+    session.add(submit_score_set_mappings_to_car_sample_pipeline)
+    session.commit()
+
+
+@pytest.fixture
+def with_submit_score_set_mappings_to_ldh_pipeline(
+    session,
+    submit_score_set_mappings_to_ldh_sample_pipeline,
+):
+    """Add a submit_score_set_mappings_to_ldh pipeline to the session."""
+
+    session.add(submit_score_set_mappings_to_ldh_sample_pipeline)
+    session.commit()
+
+
+@pytest.fixture
+def sample_independent_variant_creation_run(create_variants_sample_params):
+    """Create a JobRun instance for variant creation job."""
+
+    return JobRun(
+        urn="test:create_variants_for_score_set",
+        job_type="create_variants_for_score_set",
+        job_function="create_variants_for_score_set",
+        max_retries=3,
+        retry_count=0,
+        job_params=create_variants_sample_params,
+    )
+
+
+@pytest.fixture
+def sample_independent_variant_mapping_run(map_variants_sample_params):
+    """Create a JobRun instance for variant mapping job."""
+
+    return JobRun(
+        urn="test:map_variants_for_score_set",
+        job_type="map_variants_for_score_set",
+        job_function="map_variants_for_score_set",
+        max_retries=3,
+        retry_count=0,
+        job_params=map_variants_sample_params,
+    )
+
+
+@pytest.fixture
+def dummy_pipeline_step():
+    """Create a dummy pipeline step JobRun for testing."""
+
+    return JobRun(
+        urn="test:dummy_pipeline_step",
+        job_type="dummy_pipeline_step",
+        job_function="dummy_arq_function",
+        max_retries=3,
+        retry_count=0,
+    )
+
+
+@pytest.fixture
+def sample_pipeline_variant_creation_run(
+    session,
+    with_variant_creation_pipeline,
+    sample_variant_creation_pipeline,
+    sample_independent_variant_creation_run,
+):
+    """Create and persist a JobRun for the variant creation job, attached to its pipeline."""
+
+    sample_independent_variant_creation_run.pipeline_id = sample_variant_creation_pipeline.id
+    session.add(sample_independent_variant_creation_run)
+    session.commit()
+    return sample_independent_variant_creation_run
+
+
+@pytest.fixture
+def sample_pipeline_variant_mapping_run(
+    session,
+    with_variant_mapping_pipeline,
+    sample_independent_variant_mapping_run,
+    sample_variant_mapping_pipeline,
+):
+    """Create and persist a JobRun for the variant mapping job, attached to its pipeline."""
+
+    sample_independent_variant_mapping_run.pipeline_id = sample_variant_mapping_pipeline.id
+    session.add(sample_independent_variant_mapping_run)
+    session.commit()
+    return sample_independent_variant_mapping_run
+
+
+@pytest.fixture
+def sample_variant_creation_pipeline():
+    """Create a Pipeline instance."""
+
+    return Pipeline(
+        name="variant_creation_pipeline",
+        description="Pipeline for creating variants",
+    )
+
+
+@pytest.fixture
+def sample_variant_mapping_pipeline():
+    """Create a Pipeline instance."""
+
+    return Pipeline(
+        name="variant_mapping_pipeline",
+        description="Pipeline for mapping variants",
+    )
+
+
+@pytest.fixture
+def with_independent_processing_runs(
+    session,
+    sample_independent_variant_creation_run,
+    sample_independent_variant_mapping_run,
+):
+    """Fixture to ensure independent variant processing runs exist in the database."""
+
+    session.add(sample_independent_variant_creation_run)
+    session.add(sample_independent_variant_mapping_run)
+    session.commit()
+
+
+@pytest.fixture
+def with_variant_creation_pipeline(session, sample_variant_creation_pipeline):
+    """Fixture to ensure the variant creation pipeline exists in the database."""
+    session.add(sample_variant_creation_pipeline)
+    session.commit()
+
+
+@pytest.fixture
+def with_variant_creation_pipeline_runs(
+    session,
+    with_variant_creation_pipeline,
+    sample_variant_creation_pipeline,
+    sample_pipeline_variant_creation_run,
+    dummy_pipeline_step,
+):
+    """Fixture to ensure pipeline variant processing runs exist in the database."""
+    session.add(sample_pipeline_variant_creation_run)
+    dummy_pipeline_step.pipeline_id = sample_variant_creation_pipeline.id
+    session.add(dummy_pipeline_step)
+    session.commit()
+
+
+@pytest.fixture
+def with_variant_mapping_pipeline(session, sample_variant_mapping_pipeline):
+    """Fixture to ensure the variant mapping pipeline exists in the database."""
+    session.add(sample_variant_mapping_pipeline)
+    session.commit()
+
+
+@pytest.fixture
+def with_variant_mapping_pipeline_runs(
+    session,
+    with_variant_mapping_pipeline,
+    sample_variant_mapping_pipeline,
+    sample_pipeline_variant_mapping_run,
+    dummy_pipeline_step,
+):
+    """Fixture to ensure pipeline variant processing runs exist in the database."""
+    session.add(sample_pipeline_variant_mapping_run)
+    dummy_pipeline_step.pipeline_id = sample_variant_mapping_pipeline.id
+    session.add(dummy_pipeline_step)
+    session.commit()
+
+
+@pytest.fixture
+def sample_dummy_pipeline():
+    """Create a sample Pipeline instance for testing."""
+
+    return Pipeline(
+        name="Dummy Pipeline",
+        description="A dummy pipeline for testing purposes",
+    )
+
+
+@pytest.fixture
+def with_dummy_pipeline(session, sample_dummy_pipeline):
+    """Fixture to ensure dummy pipeline exists in the database."""
+    session.add(sample_dummy_pipeline)
+    session.commit()
+
+
+@pytest.fixture
+def sample_dummy_pipeline_start(session, with_dummy_pipeline, sample_dummy_pipeline):
+    """Create a sample JobRun instance for starting the dummy pipeline."""
+    start_job_run = JobRun(
+        pipeline_id=sample_dummy_pipeline.id,
+        job_type="start_pipeline",
+        job_function="start_pipeline",
+    )
+    session.add(start_job_run)
+    session.commit()
+
+    return start_job_run
+
+
+@pytest.fixture
+def with_dummy_pipeline_start(session, with_dummy_pipeline, sample_dummy_pipeline_start):
+    """Fixture to ensure a start pipeline job run for the dummy pipeline exists in the database."""
+    session.add(sample_dummy_pipeline_start)
+    session.commit()
+
+
+@pytest.fixture
+def sample_dummy_pipeline_step(session, sample_dummy_pipeline):
+    """Create a sample pipeline-step JobRun for the dummy pipeline."""
+    step = JobRun(
+        pipeline_id=sample_dummy_pipeline.id,
+        job_type="dummy_step",
+        job_function="dummy_arq_function",
+    )
+    session.add(step)
+    session.commit()
+    return step
+
+
+@pytest.fixture
+def with_full_dummy_pipeline(session, with_dummy_pipeline_start, sample_dummy_pipeline, sample_dummy_pipeline_step):
+    """Fixture to ensure dummy pipeline steps exist in the database."""
+    session.add(sample_dummy_pipeline_step)
+    session.commit()
+
+
+@pytest.fixture
+def sample_refresh_clinvar_controls_job_run(refresh_clinvar_controls_sample_params):
+    """Create a JobRun instance for refresh_clinvar_controls job."""
+
+    return JobRun(
+        urn="test:refresh_clinvar_controls",
+        job_type="refresh_clinvar_controls",
+        job_function="refresh_clinvar_controls",
+        max_retries=3,
+        retry_count=0,
+        job_params=refresh_clinvar_controls_sample_params,
+    )
+
+
+@pytest.fixture
+def with_refresh_clinvar_controls_job(session, sample_refresh_clinvar_controls_job_run):
+    """Add a refresh_clinvar_controls job run to the session."""
+
+    session.add(sample_refresh_clinvar_controls_job_run)
+    session.commit()
+
+
+@pytest.fixture
+def sample_refresh_clinvar_controls_pipeline():
+    """Create a pipeline instance for refresh_clinvar_controls job."""
+
+    return Pipeline(
+        urn="test:refresh_clinvar_controls_pipeline",
+        name="Refresh ClinVar Controls Pipeline",
+    )
+
+
+@pytest.fixture
+def with_refresh_clinvar_controls_pipeline(
+    session,
+    sample_refresh_clinvar_controls_pipeline,
+):
+    """Add a refresh_clinvar_controls pipeline to the session."""
+
+    session.add(sample_refresh_clinvar_controls_pipeline)
+    session.commit()
+
+
+@pytest.fixture
+def sample_refresh_clinvar_controls_job_in_pipeline(
+    session,
+    with_refresh_clinvar_controls_job,
+    with_refresh_clinvar_controls_pipeline,
+    sample_refresh_clinvar_controls_job_run,
+    sample_refresh_clinvar_controls_pipeline,
+):
+    """Provide a context with a refresh_clinvar_controls job run and pipeline."""
+
+    sample_refresh_clinvar_controls_job_run.pipeline_id = sample_refresh_clinvar_controls_pipeline.id
+    session.commit()
+    return sample_refresh_clinvar_controls_job_run
diff --git a/tests/worker/jobs/conftest_optional.py b/tests/worker/jobs/conftest_optional.py
new file mode 100644
index 00000000..3ca408cb
--- /dev/null
+++ b/tests/worker/jobs/conftest_optional.py
@@ -0,0 +1,14 @@
+from unittest import mock
+
+import pytest
+from mypy_boto3_s3 import S3Client
+
+
+@pytest.fixture
+def mock_s3_client():
+    """Mock S3 client for tests that interact with S3."""
+
+    with mock.patch("mavedb.worker.jobs.variant_processing.creation.s3_client") as mock_s3_client_func:
+        mock_s3 = mock.MagicMock(spec=S3Client)
+        mock_s3_client_func.return_value = mock_s3
+        yield mock_s3
diff --git a/tests/worker/jobs/data_management/test_views.py b/tests/worker/jobs/data_management/test_views.py
new file mode 100644
index 00000000..26ab0426
--- /dev/null
+++ b/tests/worker/jobs/data_management/test_views.py
@@ -0,0 +1,300 @@
+# ruff: noqa: E402
+
+import pytest
+
+pytest.importorskip("arq")  # Skip tests if arq is not installed
+
+from unittest.mock import call, patch
+
+from sqlalchemy import select
+
+from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus
+from mavedb.models.job_run import JobRun
+from mavedb.models.pipeline import Pipeline
+from mavedb.models.published_variant import PublishedVariantsMV
+from mavedb.worker.jobs.data_management.views import refresh_materialized_views, refresh_published_variants_view
+from
tests.helpers.transaction_spy import TransactionSpy + +pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + +############################################################################################################################################ +# refresh_materialized_views +############################################################################################################################################ + + +@pytest.mark.asyncio +@pytest.mark.unit +class TestRefreshMaterializedViewsUnit: + """Unit tests for the refresh_materialized_views function.""" + + async def test_refresh_materialized_views_calls_refresh_function(self, mock_worker_ctx, mock_job_manager): + """Test that refresh_materialized_views calls the refresh function.""" + with ( + patch("mavedb.worker.jobs.data_management.views.refresh_all_mat_views") as mock_refresh, + TransactionSpy.spy(mock_job_manager.db, expect_commit=False, expect_flush=True), + ): + result = await refresh_materialized_views(mock_worker_ctx, 999, job_manager=mock_job_manager) + + mock_refresh.assert_called_once_with(mock_job_manager.db) + assert result == {"status": "ok", "data": {}, "exception": None} + + async def test_refresh_materialized_views_updates_progress(self, mock_worker_ctx, mock_job_manager): + """Test that refresh_materialized_views updates progress correctly.""" + with ( + patch("mavedb.worker.jobs.data_management.views.refresh_all_mat_views"), + patch.object(mock_job_manager, "update_progress", return_value=None) as mock_update_progress, + TransactionSpy.spy(mock_job_manager.db, expect_commit=False, expect_flush=True), + ): + result = await refresh_materialized_views(mock_worker_ctx, 999, job_manager=mock_job_manager) + + expected_calls = [ + call(0, 100, "Starting refresh of all materialized views."), + call(100, 100, "Completed refresh of all materialized views."), + ] + mock_update_progress.assert_has_calls(expected_calls) + assert result == {"status": "ok", "data": {}, "exception": None} + + +@pytest.mark.asyncio +@pytest.mark.integration +class TestRefreshMaterializedViewsIntegration: + """Integration tests for the refresh_materialized_views function and decorator logic.""" + + async def test_refresh_materialized_views_integration(self, standalone_worker_context, session): + """Integration test that runs refresh_materialized_views end-to-end.""" + + # Flush will be called implicitly when the transaction is committed + with TransactionSpy.spy(session, expect_flush=True, expect_commit=True): + result = await refresh_materialized_views(standalone_worker_context) + + job = session.execute( + select(JobRun).where(JobRun.job_function == "refresh_materialized_views") + ).scalar_one_or_none() + assert job is not None + assert job.status == JobStatus.SUCCEEDED + assert job.job_type == "cron_job" + + assert result == {"status": "ok", "data": {}, "exception": None} + + async def test_refresh_materialized_views_handles_exceptions(self, standalone_worker_context, session): + """Integration test that ensures exceptions during refresh are handled properly.""" + + with ( + patch( + "mavedb.worker.jobs.data_management.views.refresh_all_mat_views", + side_effect=Exception("Test exception during refresh"), + ), + TransactionSpy.spy(session, expect_rollback=True, expect_flush=True, expect_commit=True), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + result = await refresh_materialized_views(standalone_worker_context) + mock_send_slack_error.assert_called_once() + + job = 
session.execute( + select(JobRun).where(JobRun.job_function == "refresh_materialized_views") + ).scalar_one_or_none() + + assert job is not None + assert job.status == JobStatus.FAILED + assert job.job_type == "cron_job" + assert job.error_message == "Test exception during refresh" + assert result["status"] == "exception" + assert isinstance(result["exception"], Exception) + + +@pytest.mark.asyncio +@pytest.mark.integration +class TestRefreshMaterializedViewsArqContext: + """Integration tests for refresh_materialized_views within an ARQ worker context.""" + + async def test_refresh_materialized_views_arq_integration( + self, arq_redis, arq_worker, standalone_worker_context, session + ): + """Integration test that runs refresh_materialized_views end-to-end using ARQ context.""" + await arq_redis.enqueue_job("refresh_materialized_views") + await arq_worker.async_run() + await arq_worker.run_check() + + job = session.execute( + select(JobRun).where(JobRun.job_function == "refresh_materialized_views") + ).scalar_one_or_none() + assert job is not None + assert job.status == JobStatus.SUCCEEDED + assert job.job_type == "cron_job" + + +############################################################################################################################################ +# refresh_published_variants_view +############################################################################################################################################ + + +@pytest.mark.asyncio +@pytest.mark.unit +class TestRefreshPublishedVariantsViewUnit: + """Unit tests for the refresh_published_variants_view function.""" + + async def test_refresh_published_variants_view_calls_refresh_function( + self, mock_worker_ctx, mock_job_manager, mock_job_run + ): + """Test that refresh_published_variants_view calls the refresh function.""" + mock_job_run.job_params = {"correlation_id": "test-corr-id"} + + with ( + patch.object(PublishedVariantsMV, "refresh") as mock_refresh, + patch("mavedb.worker.jobs.data_management.views.validate_job_params"), + TransactionSpy.spy(mock_job_manager.db, expect_commit=False, expect_flush=True), + ): + result = await refresh_published_variants_view(mock_worker_ctx, 999, job_manager=mock_job_manager) + + mock_refresh.assert_called_once_with(mock_job_manager.db) + assert result == {"status": "ok", "data": {}, "exception": None} + + async def test_refresh_published_variants_view_updates_progress( + self, mock_worker_ctx, mock_job_manager, mock_job_run + ): + """Test that refresh_published_variants_view updates progress correctly.""" + mock_job_run.job_params = {"correlation_id": "test-corr-id"} + + with ( + patch.object(PublishedVariantsMV, "refresh"), + patch("mavedb.worker.jobs.data_management.views.validate_job_params"), + patch.object(mock_job_manager, "update_progress", return_value=None) as mock_update_progress, + TransactionSpy.spy(mock_job_manager.db, expect_commit=False, expect_flush=True), + ): + result = await refresh_published_variants_view(mock_worker_ctx, 999, job_manager=mock_job_manager) + + expected_calls = [ + call(0, 100, "Starting refresh of published variants materialized view."), + call(100, 100, "Completed refresh of published variants materialized view."), + ] + mock_update_progress.assert_has_calls(expected_calls) + assert result == {"status": "ok", "data": {}, "exception": None} + + +@pytest.mark.asyncio +@pytest.mark.integration +class TestRefreshPublishedVariantsViewIntegration: + """Integration tests for the refresh_published_variants_view function and decorator 
logic.""" + + @pytest.fixture() + def setup_refresh_job_run(self, session): + """Add a refresh_published_variants_view job run to the DB before each test.""" + job_run = JobRun( + job_type="data_management", + job_function="refresh_published_variants_view", + status=JobStatus.PENDING, + job_params={"correlation_id": "test-corr-id"}, + ) + session.add(job_run) + session.commit() + return job_run + + async def test_refresh_published_variants_view_integration_standalone( + self, standalone_worker_context, session, setup_refresh_job_run + ): + """Integration test that runs refresh_published_variants_view end-to-end.""" + # Flush will be called implicitly when the transaction is committed + with TransactionSpy.spy(session, expect_flush=True, expect_commit=True): + result = await refresh_published_variants_view(standalone_worker_context, setup_refresh_job_run.id) + + session.refresh(setup_refresh_job_run) + assert setup_refresh_job_run.status == JobStatus.SUCCEEDED + assert result == {"status": "ok", "data": {}, "exception": None} + + async def test_refresh_published_variants_view_integration_pipeline( + self, standalone_worker_context, session, setup_refresh_job_run + ): + """Integration test that runs refresh_published_variants_view end-to-end.""" + # Create a pipeline for the job run and associate it + pipeline = Pipeline( + name="Test Pipeline for Published Variants View Refresh", + ) + session.add(pipeline) + session.commit() + session.refresh(pipeline) + setup_refresh_job_run.pipeline_id = pipeline.id + session.add(setup_refresh_job_run) + session.commit() + + # Flush will be called implicitly when the transaction is committed + with TransactionSpy.spy(session, expect_flush=True, expect_commit=True): + result = await refresh_published_variants_view(standalone_worker_context, setup_refresh_job_run.id) + + session.refresh(setup_refresh_job_run) + assert setup_refresh_job_run.status == JobStatus.SUCCEEDED + assert result == {"status": "ok", "data": {}, "exception": None} + session.refresh(pipeline) + assert pipeline.status == PipelineStatus.SUCCEEDED + + async def test_refresh_published_variants_view_handles_exceptions( + self, standalone_worker_context, session, setup_refresh_job_run + ): + """Integration test that ensures exceptions during refresh are handled properly.""" + with ( + patch.object( + PublishedVariantsMV, + "refresh", + side_effect=Exception("Test exception during published variants view refresh"), + ), + TransactionSpy.spy(session, expect_rollback=True, expect_flush=True, expect_commit=True), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + result = await refresh_published_variants_view(standalone_worker_context, setup_refresh_job_run.id) + mock_send_slack_error.assert_called_once() + + session.refresh(setup_refresh_job_run) + assert setup_refresh_job_run.status == JobStatus.FAILED + assert setup_refresh_job_run.error_message == "Test exception during published variants view refresh" + assert result["status"] == "exception" + assert isinstance(result["exception"], Exception) + + async def test_refresh_published_variants_view_requires_params( + self, setup_refresh_job_run, standalone_worker_context, session + ): + """Integration test that ensures required job params are validated.""" + setup_refresh_job_run.job_params = {} # Clear required params + session.add(setup_refresh_job_run) + session.commit() + + with ( + TransactionSpy.spy(session, expect_rollback=True, expect_flush=True, expect_commit=True), + 
patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + result = await refresh_published_variants_view(standalone_worker_context, setup_refresh_job_run.id) + mock_send_slack_error.assert_called_once() + + session.refresh(setup_refresh_job_run) + assert setup_refresh_job_run.status == JobStatus.FAILED + assert "Job has no job_params defined" in setup_refresh_job_run.error_message + assert result["status"] == "exception" + assert isinstance(result["exception"], Exception) + + +@pytest.mark.asyncio +@pytest.mark.integration +class TestRefreshPublishedVariantsViewArqContext: + """Integration tests for refresh_published_variants_view within an ARQ worker context.""" + + @pytest.fixture() + def setup_refresh_job_run(self, session): + """Add a refresh_published_variants_view job run to the DB before each test.""" + job_run = JobRun( + job_type="data_management", + job_function="refresh_published_variants_view", + status=JobStatus.PENDING, + job_params={"correlation_id": "test-corr-id"}, + ) + session.add(job_run) + session.commit() + return job_run + + async def test_refresh_published_variants_view_arq_integration( + self, arq_redis, arq_worker, standalone_worker_context, session, setup_refresh_job_run + ): + """Integration test that runs refresh_published_variants_view end-to-end using ARQ context.""" + await arq_redis.enqueue_job("refresh_published_variants_view", setup_refresh_job_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + session.refresh(setup_refresh_job_run) + assert setup_refresh_job_run.status == JobStatus.SUCCEEDED diff --git a/tests/worker/jobs/external_services/network/test_clingen.py b/tests/worker/jobs/external_services/network/test_clingen.py new file mode 100644 index 00000000..2bd8645a --- /dev/null +++ b/tests/worker/jobs/external_services/network/test_clingen.py @@ -0,0 +1,141 @@ +# ruff: noqa: E402 + +import pytest + +pytest.importorskip("arq") + +from unittest.mock import patch + +from sqlalchemy import select + +from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus +from mavedb.models.mapped_variant import MappedVariant +from tests.helpers.util.setup.worker import create_mappings_in_score_set + +pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + + +# XXX: Connect with ClinGen to resolve the invalid credentials issue on test site. 
+@pytest.mark.skip(reason="invalid credentials, despite what is provided in documentation.") +@pytest.mark.asyncio +@pytest.mark.integration +@pytest.mark.network +class TestE2EClingenSubmitScoreSetMappingsToCar: + """End-to-end tests for ClinGen CAR submission jobs.""" + + async def test_clingen_car_submission_e2e( + self, + session, + arq_redis, + arq_worker, + standalone_worker_context, + mock_s3_client, + sample_score_set, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_pipeline, + submit_score_set_mappings_to_car_sample_job_run_in_pipeline, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + """Test the end-to-end flow of submitting score set mappings to ClinGen CAR.""" + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + with ( + patch( + "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", + "https://reg.test.genome.network", + ), + patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_NAME", "testuser"), + patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_PASSWORD", "testuser"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + ): + await arq_redis.enqueue_job( + "submit_score_set_mappings_to_car", submit_score_set_mappings_to_car_sample_job_run_in_pipeline.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that the submission job was completed successfully + session.refresh(submit_score_set_mappings_to_car_sample_job_run_in_pipeline) + assert submit_score_set_mappings_to_car_sample_job_run_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify that the pipeline run status is succeeded + session.refresh(submit_score_set_mappings_to_car_sample_pipeline) + assert submit_score_set_mappings_to_car_sample_pipeline.status == PipelineStatus.SUCCEEDED + + # Verify that variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 4 + for variant in variants: + assert variant.clingen_allele_id is not None + + +# XXX: Connect with ClinGen to resolve the invalid credentials issue on test site. 
+@pytest.mark.skip(reason="invalid credentials, despite what is provided in documentation.") +@pytest.mark.integration +@pytest.mark.asyncio +@pytest.mark.network +class TestE2EClingenSubmitScoreSetMappingsToLdh: + """End-to-end tests for ClinGen LDH submission jobs.""" + + async def test_clingen_ldh_submission_e2e( + self, + session, + arq_redis, + arq_worker, + standalone_worker_context, + mock_s3_client, + sample_score_set, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_pipeline, + submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + """Test the end-to-end flow of submitting score set mappings to ClinGen LDH.""" + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenLdhService to simulate all submissions failing + with ( + patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_NAME", "testuser"), + patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_PASSWORD", "testpassword"), + patch("mavedb.lib.clingen.constants.LDH_ACCESS_ENDPOINT", "https://genboree.org/ldh-stg/srvc"), + patch("mavedb.lib.clingen.constants.CLIN_GEN_TENANT", "dev-clingen"), + ): + await arq_redis.enqueue_job( + "submit_score_set_mappings_to_ldh", submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that the submission job succeeded + session.refresh(submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline) + assert submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify that the pipeline run status is succeeded + session.refresh(submit_score_set_mappings_to_ldh_sample_pipeline) + assert submit_score_set_mappings_to_ldh_sample_pipeline.status == PipelineStatus.SUCCEEDED diff --git a/tests/worker/jobs/external_services/network/test_clinvar.py b/tests/worker/jobs/external_services/network/test_clinvar.py new file mode 100644 index 00000000..54ae2fff --- /dev/null +++ b/tests/worker/jobs/external_services/network/test_clinvar.py @@ -0,0 +1,48 @@ +# ruff: noqa: E402 + +import pytest + +pytest.importorskip("arq") + +from sqlalchemy import select + +from mavedb.models.clinical_control import ClinicalControl +from mavedb.models.enums.annotation_type import AnnotationType +from mavedb.models.enums.job_pipeline import AnnotationStatus, JobStatus +from mavedb.models.variant_annotation_status import VariantAnnotationStatus + + +@pytest.mark.asyncio +@pytest.mark.integration +@pytest.mark.network +@pytest.mark.slow +class TestE2ERefreshClinvarControls: + async def test_refresh_clinvar_controls_e2e( + self, + session, + arq_redis, + arq_worker, + standalone_worker_context, + setup_sample_variants_with_caid, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + ): + """Test the end-to-end flow of refreshing ClinVar clinical controls.""" + await arq_redis.enqueue_job("refresh_clinvar_controls", sample_refresh_clinvar_controls_job_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that clinical controls were added successfully + clinical_controls = session.scalars(select(ClinicalControl)).all() + assert len(clinical_controls) == 1 
+        assert clinical_controls[0].db_identifier == "3045425"
+
+        # Verify that annotation status was added
+        annotation_statuses = session.scalars(select(VariantAnnotationStatus)).all()
+        assert len(annotation_statuses) == 1
+        assert annotation_statuses[0].status == AnnotationStatus.SUCCESS
+        assert annotation_statuses[0].annotation_type == AnnotationType.CLINVAR_CONTROL
+
+        # Verify that the job run was completed successfully
+        session.refresh(sample_refresh_clinvar_controls_job_run)
+        assert sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED
diff --git a/tests/worker/jobs/external_services/test_clingen.py b/tests/worker/jobs/external_services/test_clingen.py
new file mode 100644
index 00000000..365f9483
--- /dev/null
+++ b/tests/worker/jobs/external_services/test_clingen.py
@@ -0,0 +1,2293 @@
+# ruff: noqa: E402
+
+import pytest
+
+pytest.importorskip("arq")
+
+from asyncio.unix_events import _UnixSelectorEventLoop
+from unittest.mock import call, patch
+
+from sqlalchemy import select
+
+from mavedb.lib.exceptions import LDHSubmissionFailureError
+from mavedb.lib.variants import get_hgvs_from_post_mapped
+from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus
+from mavedb.models.mapped_variant import MappedVariant
+from mavedb.models.variant import Variant
+from mavedb.models.variant_annotation_status import VariantAnnotationStatus
+from mavedb.worker.jobs.external_services.clingen import (
+    submit_score_set_mappings_to_car,
+    submit_score_set_mappings_to_ldh,
+)
+from mavedb.worker.lib.managers.job_manager import JobManager
+from tests.helpers.constants import TEST_CLINGEN_LDH_LINKING_RESPONSE_BAD_REQUEST
+from tests.helpers.util.setup.worker import create_mappings_in_score_set
+
+pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr")
+
+
+@pytest.mark.unit
+@pytest.mark.asyncio
+class TestClingenSubmitScoreSetMappingsToCarUnit:
+    """Tests for the ClinGen submit_score_set_mappings_to_car function."""
+
+    async def test_submit_score_set_mappings_to_car_submission_disabled(
+        self,
+        mock_worker_ctx,
+        session,
+        with_submit_score_set_mappings_to_car_job,
+        submit_score_set_mappings_to_car_sample_job_run,
+    ):
+        # Patch to disable ClinGen submission via the feature flag
+        with (
+            patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", False),
+            patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress,
+        ):
+            result = await submit_score_set_mappings_to_car(
+                mock_worker_ctx,
+                submit_score_set_mappings_to_car_sample_job_run.id,
+                JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id),
+            )
+
+        mock_update_progress.assert_called_with(100, 100, "ClinGen submission is disabled. Skipping CAR submission.")
+        assert result["status"] == "skipped"
+
+        # Verify no variants have CAIDs assigned
+        variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all()
+        assert len(variants) == 0
+
+    async def test_submit_score_set_mappings_to_car_no_mappings(
+        self,
+        mock_worker_ctx,
+        session,
+        with_submit_score_set_mappings_to_car_job,
+        submit_score_set_mappings_to_car_sample_job_run,
+    ):
+        """Test submitting score set mappings to ClinGen when there are no mappings."""
+        with (
+            patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress,
+            patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"),
+            patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True),
+        ):
+            result = await submit_score_set_mappings_to_car(
+                mock_worker_ctx,
+                submit_score_set_mappings_to_car_sample_job_run.id,
+                JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id),
+            )
+
+        mock_update_progress.assert_called_with(100, 100, "No mapped variants to submit to CAR. Skipped submission.")
+        assert result["status"] == "ok"
+
+        # Verify no variants have CAIDs assigned
+        variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all()
+        assert len(variants) == 0
+
+    async def test_submit_score_set_mappings_to_car_submission_endpoint_not_set(
+        self,
+        mock_worker_ctx,
+        session,
+        with_submit_score_set_mappings_to_car_job,
+        submit_score_set_mappings_to_car_sample_job_run,
+    ):
+        # Patch to leave the CAR submission endpoint unconfigured
+        with (
+            patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", ""),
+            patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True),
+            patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress,
+        ):
+            result = await submit_score_set_mappings_to_car(
+                mock_worker_ctx,
+                submit_score_set_mappings_to_car_sample_job_run.id,
+                JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id),
+            )
+
+        mock_update_progress.assert_called_with(
+            100, 100, "CAR submission endpoint not configured. Can't complete submission."
+        )
+        assert result["status"] == "failed"
+        assert isinstance(result["exception"], ValueError)
+
+        # Verify no variants have CAIDs assigned
+        variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all()
+        assert len(variants) == 0
+
+    async def test_submit_score_set_mappings_to_car_no_registered_alleles(
+        self,
+        mock_worker_ctx,
+        session,
+        with_submit_score_set_mappings_to_car_job,
+        submit_score_set_mappings_to_car_sample_job_run,
+        mock_s3_client,
+        sample_score_dataframe,
+        sample_count_dataframe,
+        with_dummy_setup_jobs,
+        dummy_variant_creation_job_run,
+        dummy_variant_mapping_job_run,
+    ):
+        # Create mappings in the score set
+        await create_mappings_in_score_set(
+            session,
+            mock_s3_client,
+            mock_worker_ctx,
+            sample_score_dataframe,
+            sample_count_dataframe,
+            dummy_variant_creation_job_run,
+            dummy_variant_mapping_job_run,
+        )
+
+        # Patch ClinGenAlleleRegistryService to return no registered alleles
+        with (
+            patch(
+                "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions",
+                return_value=[],
+            ),
+            patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"),
+            patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True),
+            patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress,
+        ):
+            result = await submit_score_set_mappings_to_car(
+                mock_worker_ctx,
+                submit_score_set_mappings_to_car_sample_job_run.id,
+                JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id),
+            )
+
+        mock_update_progress.assert_called_with(100, 100, "Completed CAR mapped resource submission.")
+        assert result["status"] == "ok"
+
+        # Verify no variants have CAIDs assigned
+        variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all()
+        assert len(variants) == 0
+
+        # Verify annotation statuses were rendered as failed
+        annotation_statuses = session.scalars(
+            select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id")
+        ).all()
+        assert len(annotation_statuses) == 4
+        for ann in annotation_statuses:
+            assert ann.status == "failed"
+            assert ann.annotation_type == "clingen_allele_id"
+
+    async def test_submit_score_set_mappings_to_car_no_linked_alleles(
+        self,
+
mock_worker_ctx, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + mock_worker_ctx, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenAlleleRegistryService to return registered alleles that do not match submitted HGVS + registered_alleles_mock = [ + {"@id": "CA123456", "type": "nucleotide", "genomicAlleles": [{"hgvs": "NC_000007.14:g.140453136A>C"}]}, + {"@id": "CA234567", "type": "nucleotide", "genomicAlleles": [{"hgvs": "NC_000007.14:g.140453136A>G"}]}, + ] + + with ( + patch( + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=registered_alleles_mock, + ), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, + ): + result = await submit_score_set_mappings_to_car( + mock_worker_ctx, + submit_score_set_mappings_to_car_sample_job_run.id, + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id), + ) + + mock_update_progress.assert_called_with(100, 100, "Completed CAR mapped resource submission.") + assert result["status"] == "ok" + + # Verify no variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 0 + + # Verify annotation statuses were rendered as failed + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == 4 + for ann in annotation_statuses: + assert ann.status == "failed" + assert ann.annotation_type == "clingen_allele_id" + + async def test_submit_score_set_mappings_to_car_repeated_hgvs( + self, + mock_worker_ctx, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + mock_worker_ctx, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenAlleleRegistryService to return registered alleles with repeated HGVS + mapped_variants = session.scalars(select(MappedVariant)).all() + registered_alleles_mock = [ + { + "@id": "CA_DUPLICATE", + "type": "nucleotide", + "genomicAlleles": [{"hgvs": get_hgvs_from_post_mapped(mapped_variants[0].post_mapped)}], + } + ] + + with ( + patch( + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=registered_alleles_mock, + ), + # Patch get_hgvs_from_post_mapped to return the same HGVS for all variants + patch( + "mavedb.worker.jobs.external_services.clingen.get_hgvs_from_post_mapped", + 
return_value=get_hgvs_from_post_mapped(mapped_variants[0].post_mapped), + ), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, + ): + result = await submit_score_set_mappings_to_car( + mock_worker_ctx, + submit_score_set_mappings_to_car_sample_job_run.id, + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id), + ) + + mock_update_progress.assert_called_with(100, 100, "Completed CAR mapped resource submission.") + assert result["status"] == "ok" + + # Verify variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 4 + for variant in variants: + assert variant.clingen_allele_id == "CA_DUPLICATE" + + # Verify annotation statuses were rendered as success + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == 4 + for ann in annotation_statuses: + assert ann.status == "success" + assert ann.annotation_type == "clingen_allele_id" + + async def test_submit_score_set_mappings_to_car_hgvs_not_found( + self, + mock_worker_ctx, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + mock_worker_ctx, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Get the mapped variants from score set before submission + mapped_variants = session.scalars( + select(MappedVariant) + .join(Variant) + .where(Variant.score_set_id == submit_score_set_mappings_to_car_sample_job_run.job_params["score_set_id"]) + ).all() + + # Patch ClinGenAlleleRegistryService to return registered alleles + registered_alleles_mock = [ + { + "@id": f"CA{mv.id}", + "type": "nucleotide", + "genomicAlleles": [{"hgvs": get_hgvs_from_post_mapped(mv.post_mapped)}], + } + for mv in mapped_variants + ] + + with ( + patch( + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=registered_alleles_mock, + ), + # Patch get_hgvs_from_post_mapped to not find any HGVS in registered alleles + patch("mavedb.worker.jobs.external_services.clingen.get_hgvs_from_post_mapped", return_value=None), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, + ): + result = await submit_score_set_mappings_to_car( + mock_worker_ctx, + submit_score_set_mappings_to_car_sample_job_run.id, + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id), + ) + + mock_update_progress.assert_called_with(100, 100, "Completed CAR mapped resource submission.") + assert result["status"] == "ok" + + # Verify no variants have CAIDs assigned + variants = 
session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 0 + + # Verify annotation statuses were rendered as failed + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == 4 + for ann in annotation_statuses: + assert ann.status == "failed" + assert ann.annotation_type == "clingen_allele_id" + + async def test_submit_score_set_mappings_to_car_propagates_exception( + self, + mock_worker_ctx, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + mock_worker_ctx, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenAlleleRegistryService to raise an exception + with ( + patch( + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + side_effect=Exception("ClinGen service error"), + ), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + pytest.raises(Exception) as exc_info, + ): + await submit_score_set_mappings_to_car( + mock_worker_ctx, + submit_score_set_mappings_to_car_sample_job_run.id, + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id), + ) + + assert str(exc_info.value) == "ClinGen service error" + + async def test_submit_score_set_mappings_to_car_success( + self, + mock_worker_ctx, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + sample_score_set, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + mock_worker_ctx, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Get the mapped variants from score set before submission + mapped_variants = session.scalars( + select(MappedVariant).join(Variant).where(Variant.score_set_id == sample_score_set.id) + ).all() + assert len(mapped_variants) == 4 + + # Patch ClinGenAlleleRegistryService to return registered alleles + registered_alleles_mock = [ + { + "@id": f"CA{mv.id}", + "type": "nucleotide", + "genomicAlleles": [{"hgvs": get_hgvs_from_post_mapped(mv.post_mapped)}], + } + for mv in mapped_variants + ] + + with ( + patch( + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=registered_alleles_mock, + ), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, + ): + result = await submit_score_set_mappings_to_car( + mock_worker_ctx, + 
submit_score_set_mappings_to_car_sample_job_run.id, + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id), + ) + + mock_update_progress.assert_called_with(100, 100, "Completed CAR mapped resource submission.") + assert result["status"] == "ok" + + # Verify variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 4 + for variant in variants: + assert variant.clingen_allele_id == f"CA{variant.id}" + + # Verify annotation statuses were rendered as success + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == 4 + for ann in annotation_statuses: + assert ann.status == "success" + assert ann.annotation_type == "clingen_allele_id" + + async def test_submit_score_set_mappings_to_car_updates_progress( + self, + mock_worker_ctx, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + sample_score_set, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + mock_worker_ctx, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Get the mapped variants from score set before submission + mapped_variants = session.scalars( + select(MappedVariant).join(Variant).where(Variant.score_set_id == sample_score_set.id) + ).all() + assert len(mapped_variants) == 4 + + # Patch ClinGenAlleleRegistryService to return registered alleles + registered_alleles_mock = [ + { + "@id": f"CA{mv.id}", + "type": "nucleotide", + "genomicAlleles": [{"hgvs": get_hgvs_from_post_mapped(mv.post_mapped)}], + } + for mv in mapped_variants + ] + + with ( + patch( + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=registered_alleles_mock, + ), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + ): + await submit_score_set_mappings_to_car( + mock_worker_ctx, + submit_score_set_mappings_to_car_sample_job_run.id, + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id), + ) + + mock_update_progress.assert_has_calls( + [ + call(0, 100, "Starting CAR mapped resource submission."), + call(10, 100, "Preparing 4 mapped variants for CAR submission."), + call(15, 100, "Submitting mapped variants to CAR."), + call(60, 100, "Processing registered alleles from CAR."), + call(95, 100, "Processed 4 of 4 registered alleles."), + call(100, 100, "Completed CAR mapped resource submission."), + ] + ) + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestClingenSubmitScoreSetMappingsToCarIntegration: + """Integration tests for the Clingen submit_score_set_mappings_to_car function.""" + + async def test_submit_score_set_mappings_to_car_independent_ctx( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_car_job, + 
submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenAlleleRegistryService to return registered alleles + mapped_variants = session.scalars(select(MappedVariant)).all() + registered_alleles_mock = [ + { + "@id": f"CA{mv.id}", + "type": "nucleotide", + "genomicAlleles": [{"hgvs": get_hgvs_from_post_mapped(mv.post_mapped)}], + } + for mv in mapped_variants + ] + + with ( + patch( + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=registered_alleles_mock, + ), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + ): + result = await submit_score_set_mappings_to_car( + standalone_worker_context, submit_score_set_mappings_to_car_sample_job_run.id + ) + + assert result["status"] == "ok" + + # Verify variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == len(mapped_variants) + for variant in variants: + assert variant.clingen_allele_id == f"CA{variant.id}" + + # Verify annotation statuses were rendered as success + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == len(mapped_variants) + for ann in annotation_statuses: + assert ann.status == "success" + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_job_run) + assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.SUCCEEDED + + async def test_submit_score_set_mappings_to_car_pipeline_ctx( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run_in_pipeline, + submit_score_set_mappings_to_car_sample_pipeline, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenAlleleRegistryService to return registered alleles + mapped_variants = session.scalars(select(MappedVariant)).all() + registered_alleles_mock = [ + { + "@id": f"CA{mv.id}", + "type": "nucleotide", + "genomicAlleles": [{"hgvs": get_hgvs_from_post_mapped(mv.post_mapped)}], + } + for mv in mapped_variants + ] + + with ( + patch( + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=registered_alleles_mock, + ), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + ): 
+ result = await submit_score_set_mappings_to_car( + standalone_worker_context, submit_score_set_mappings_to_car_sample_job_run_in_pipeline.id + ) + + assert result["status"] == "ok" + + # Verify variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == len(mapped_variants) + for variant in variants: + assert variant.clingen_allele_id == f"CA{variant.id}" + + # Verify annotation statuses were rendered as success + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == len(mapped_variants) + for ann in annotation_statuses: + assert ann.status == "success" + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_job_run_in_pipeline) + assert submit_score_set_mappings_to_car_sample_job_run_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify the pipeline status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_pipeline) + assert submit_score_set_mappings_to_car_sample_pipeline.status == PipelineStatus.SUCCEEDED + + async def test_submit_score_set_mappings_to_car_submission_disabled( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Patch to disable ClinGen submission endpoint + with ( + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", False), + ): + result = await submit_score_set_mappings_to_car( + standalone_worker_context, submit_score_set_mappings_to_car_sample_job_run.id + ) + + assert result["status"] == "skipped" + + # Verify no variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 0 + + # Verify no annotation statuses were created + annotation_statuses = session.scalars(select(VariantAnnotationStatus)).all() + assert len(annotation_statuses) == 0 + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_job_run) + assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.SKIPPED + + async def test_submit_score_set_mappings_to_car_no_submission_endpoint( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Patch to disable ClinGen submission endpoint + with ( + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", ""), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + result = await submit_score_set_mappings_to_car( + standalone_worker_context, submit_score_set_mappings_to_car_sample_job_run.id + ) + + mock_send_slack_error.assert_called_once() + assert result["status"] == "failed" + assert isinstance(result["exception"], ValueError) + + # Verify no variants have 
CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 0 + + # Verify no annotation statuses were created + annotation_statuses = session.scalars(select(VariantAnnotationStatus)).all() + assert len(annotation_statuses) == 0 + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_job_run) + assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.FAILED + + async def test_submit_score_set_mappings_to_car_no_mappings( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + ): + """Test submitting score set mappings to ClinGen when there are no mappings.""" + with ( + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + ): + result = await submit_score_set_mappings_to_car( + standalone_worker_context, submit_score_set_mappings_to_car_sample_job_run.id + ) + + assert result["status"] == "ok" + + # Verify no variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 0 + + # Verify no annotation statuses were created + annotation_statuses = session.scalars(select(VariantAnnotationStatus)).all() + assert len(annotation_statuses) == 0 + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_job_run) + assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.SUCCEEDED + + async def test_submit_score_set_mappings_to_car_no_registered_alleles( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenAlleleRegistryService to return no registered alleles + with ( + patch( + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=[], + ), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + ): + result = await submit_score_set_mappings_to_car( + standalone_worker_context, submit_score_set_mappings_to_car_sample_job_run.id + ) + + assert result["status"] == "ok" + + # Verify no variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 0 + + # Verify annotation statuses were rendered as failed + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == 4 + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_job_run) + assert 
submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.SUCCEEDED + + async def test_submit_score_set_mappings_to_car_no_linked_alleles( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenAlleleRegistryService to return registered alleles that do not match submitted HGVS + registered_alleles_mock = [ + {"@id": "CA123456", "type": "nucleotide", "genomicAlleles": [{"hgvs": "NC_000007.14:g.140453136A>C"}]}, + {"@id": "CA234567", "type": "nucleotide", "genomicAlleles": [{"hgvs": "NC_000007.14:g.140453136A>G"}]}, + ] + + with ( + patch( + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=registered_alleles_mock, + ), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + ): + result = await submit_score_set_mappings_to_car( + standalone_worker_context, submit_score_set_mappings_to_car_sample_job_run.id + ) + + assert result["status"] == "ok" + + # Verify no variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 0 + + # Verify annotation statuses were rendered as failed + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == 4 + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_job_run) + assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.SUCCEEDED + + async def test_submit_score_set_mappings_to_car_propagates_exception_to_decorator( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenAlleleRegistryService to raise an exception + with ( + patch( + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + side_effect=Exception("ClinGen service error"), + ), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + result = await submit_score_set_mappings_to_car( + standalone_worker_context, 
submit_score_set_mappings_to_car_sample_job_run.id + ) + + mock_send_slack_error.assert_called_once() + assert result["status"] == "exception" + assert isinstance(result["exception"], Exception) + assert str(result["exception"]) == "ClinGen service error" + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_job_run) + assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.FAILED + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestClingenSubmitScoreSetMappingsToCarArqContext: + """Tests for the Clingen submit_score_set_mappings_to_car function with ARQ context.""" + + async def test_submit_score_set_mappings_to_car_with_arq_context_independent( + self, + standalone_worker_context, + session, + arq_redis, + arq_worker, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenAlleleRegistryService to return registered alleles + mapped_variants = session.scalars(select(MappedVariant)).all() + registered_alleles_mock = [ + { + "@id": f"CA{mv.id}", + "type": "nucleotide", + "genomicAlleles": [{"hgvs": get_hgvs_from_post_mapped(mv.post_mapped)}], + } + for mv in mapped_variants + ] + + with ( + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch( + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=registered_alleles_mock, + ), + ): + await arq_redis.enqueue_job( + "submit_score_set_mappings_to_car", submit_score_set_mappings_to_car_sample_job_run.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_job_run) + assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.SUCCEEDED + + # Verify variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == len(mapped_variants) + for variant in variants: + assert variant.clingen_allele_id == f"CA{variant.id}" + + # Verify annotation statuses were rendered as success + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == 4 + for ann in annotation_statuses: + assert ann.status == "success" + + async def test_submit_score_set_mappings_to_car_with_arq_context_pipeline( + self, + standalone_worker_context, + session, + arq_redis, + arq_worker, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run_in_pipeline, + submit_score_set_mappings_to_car_sample_pipeline, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await 
create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenAlleleRegistryService to return registered alleles + mapped_variants = session.scalars(select(MappedVariant)).all() + registered_alleles_mock = [ + { + "@id": f"CA{mv.id}", + "type": "nucleotide", + "genomicAlleles": [{"hgvs": get_hgvs_from_post_mapped(mv.post_mapped)}], + } + for mv in mapped_variants + ] + + with ( + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch( + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=registered_alleles_mock, + ), + ): + await arq_redis.enqueue_job( + "submit_score_set_mappings_to_car", submit_score_set_mappings_to_car_sample_job_run_in_pipeline.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_job_run_in_pipeline) + assert submit_score_set_mappings_to_car_sample_job_run_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify the pipeline status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_pipeline) + assert submit_score_set_mappings_to_car_sample_pipeline.status == PipelineStatus.SUCCEEDED + + # Verify variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == len(mapped_variants) + for variant in variants: + assert variant.clingen_allele_id == f"CA{variant.id}" + + # Verify annotation statuses were rendered as success + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == 4 + for ann in annotation_statuses: + assert ann.status == "success" + + async def test_submit_score_set_mappings_to_car_with_arq_context_exception_handling_independent( + self, + standalone_worker_context, + session, + arq_redis, + arq_worker, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenAlleleRegistryService to raise an exception + with ( + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch( + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + side_effect=Exception("ClinGen service error"), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + await arq_redis.enqueue_job( + "submit_score_set_mappings_to_car", submit_score_set_mappings_to_car_sample_job_run.id + ) + await arq_worker.async_run() + await 
arq_worker.run_check()
+
+        mock_send_slack_error.assert_called_once()
+        # Verify the job status is updated in the database
+        session.refresh(submit_score_set_mappings_to_car_sample_job_run)
+        assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.FAILED
+        assert submit_score_set_mappings_to_car_sample_job_run.error_message == "ClinGen service error"
+
+        # Verify no variants have CAIDs assigned
+        variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all()
+        assert len(variants) == 0
+
+        # Verify no annotation statuses were created
+        annotation_statuses = session.scalars(
+            select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id")
+        ).all()
+        assert len(annotation_statuses) == 0
+
+    async def test_submit_score_set_mappings_to_car_with_arq_context_exception_handling_pipeline(
+        self,
+        standalone_worker_context,
+        session,
+        arq_redis,
+        arq_worker,
+        with_submit_score_set_mappings_to_car_job,
+        submit_score_set_mappings_to_car_sample_job_run_in_pipeline,
+        submit_score_set_mappings_to_car_sample_pipeline,
+        mock_s3_client,
+        sample_score_dataframe,
+        sample_count_dataframe,
+        with_dummy_setup_jobs,
+        dummy_variant_creation_job_run,
+        dummy_variant_mapping_job_run,
+    ):
+        # Create mappings in the score set
+        await create_mappings_in_score_set(
+            session,
+            mock_s3_client,
+            standalone_worker_context,
+            sample_score_dataframe,
+            sample_count_dataframe,
+            dummy_variant_creation_job_run,
+            dummy_variant_mapping_job_run,
+        )
+
+        # Patch ClinGenAlleleRegistryService to raise an exception
+        with (
+            patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True),
+            patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"),
+            patch(
+                "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions",
+                side_effect=Exception("ClinGen service error"),
+            ),
+            patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error,
+        ):
+            await arq_redis.enqueue_job(
+                "submit_score_set_mappings_to_car", submit_score_set_mappings_to_car_sample_job_run_in_pipeline.id
+            )
+            await arq_worker.async_run()
+            await arq_worker.run_check()
+
+        mock_send_slack_error.assert_called_once()
+        # Verify the job status is updated in the database
+        session.refresh(submit_score_set_mappings_to_car_sample_job_run_in_pipeline)
+        assert submit_score_set_mappings_to_car_sample_job_run_in_pipeline.status == JobStatus.FAILED
+        assert submit_score_set_mappings_to_car_sample_job_run_in_pipeline.error_message == "ClinGen service error"
+
+        # Verify the pipeline status is updated in the database
+        session.refresh(submit_score_set_mappings_to_car_sample_pipeline)
+        assert submit_score_set_mappings_to_car_sample_pipeline.status == PipelineStatus.FAILED
+
+        # Verify no variants have CAIDs assigned
+        variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all()
+        assert len(variants) == 0
+
+        # Verify no annotation statuses were created
+        annotation_statuses = session.scalars(
+            select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id")
+        ).all()
+        assert len(annotation_statuses) == 0
+
+
+@pytest.mark.unit
+@pytest.mark.asyncio
+class TestClingenSubmitScoreSetMappingsToLdhUnit:
+    """Unit tests for the Clingen submit_score_set_mappings_to_ldh function."""
+
+    async def test_submit_score_set_mappings_to_ldh_no_variants(
+        self,
+        mock_worker_ctx,
+        session,
+        with_submit_score_set_mappings_to_ldh_job,
+        submit_score_set_mappings_to_ldh_sample_job_run,
+        mock_s3_client,
+        sample_score_dataframe,
+        sample_count_dataframe,
+        with_dummy_setup_jobs,
+        dummy_variant_creation_job_run,
+        dummy_variant_mapping_job_run,
+    ):
+        with (
+            patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None),
+            patch("mavedb.worker.jobs.external_services.clingen.LDH_SUBMISSION_ENDPOINT", "http://fake-endpoint"),
+            patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress,
+        ):
+            result = await submit_score_set_mappings_to_ldh(
+                mock_worker_ctx,
+                submit_score_set_mappings_to_ldh_sample_job_run.id,
+                JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id),
+            )
+
+        mock_update_progress.assert_called_with(100, 100, "No mapped variants to submit to LDH. Skipping submission.")
+        assert result["status"] == "ok"
+
+    async def test_submit_score_set_mappings_to_ldh_all_submissions_failed(
+        self,
+        mock_worker_ctx,
+        session,
+        with_submit_score_set_mappings_to_ldh_job,
+        submit_score_set_mappings_to_ldh_sample_job_run,
+        mock_s3_client,
+        sample_score_dataframe,
+        sample_count_dataframe,
+        with_dummy_setup_jobs,
+        dummy_variant_creation_job_run,
+        dummy_variant_mapping_job_run,
+    ):
+        # Create mappings in the score set
+        await create_mappings_in_score_set(
+            session,
+            mock_s3_client,
+            mock_worker_ctx,
+            sample_score_dataframe,
+            sample_count_dataframe,
+            dummy_variant_creation_job_run,
+            dummy_variant_mapping_job_run,
+        )
+
+        async def dummy_submission_failure(*args, **kwargs):
+            return ([], [TEST_CLINGEN_LDH_LINKING_RESPONSE_BAD_REQUEST] * 4)
+
+        # Patch ClinGenLdhService to simulate all submissions failing. Passing the pre-built
+        # coroutine as return_value works because the job awaits run_in_executor exactly once,
+        # so the await resolves to the dummy (successes, failures) tuple.
+        with (
+            patch.object(
+                _UnixSelectorEventLoop,
+                "run_in_executor",
+                return_value=dummy_submission_failure(),
+            ),
+            patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None),
+            patch("mavedb.worker.jobs.external_services.clingen.LDH_SUBMISSION_ENDPOINT", "http://fake-endpoint"),
+            patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress,
+        ):
+            result = await submit_score_set_mappings_to_ldh(
+                mock_worker_ctx,
+                submit_score_set_mappings_to_ldh_sample_job_run.id,
+                JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id),
+            )
+
+        assert result["status"] == "failed"
+        assert isinstance(result["exception"], LDHSubmissionFailureError)
+        mock_update_progress.assert_called_with(100, 100, "All mapped variant submissions to LDH failed.")
+
+    async def test_submit_score_set_mappings_to_ldh_hgvs_not_found(
+        self,
+        mock_worker_ctx,
+        session,
+        with_submit_score_set_mappings_to_ldh_job,
+        submit_score_set_mappings_to_ldh_sample_job_run,
+        mock_s3_client,
+        sample_score_dataframe,
+        sample_count_dataframe,
+        with_dummy_setup_jobs,
+        dummy_variant_creation_job_run,
+        dummy_variant_mapping_job_run,
+    ):
+        # Create mappings in the score set
+        await create_mappings_in_score_set(
+            session,
+            mock_s3_client,
+            mock_worker_ctx,
+            sample_score_dataframe,
+            sample_count_dataframe,
+            dummy_variant_creation_job_run,
+            dummy_variant_mapping_job_run,
+        )
+
+        # Patch get_hgvs_from_post_mapped so that no HGVS string can be derived for any variant
+        with (
+            patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None),
+            patch("mavedb.worker.jobs.external_services.clingen.LDH_SUBMISSION_ENDPOINT",
"http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.get_hgvs_from_post_mapped", return_value=None), + patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, + ): + result = await submit_score_set_mappings_to_ldh( + mock_worker_ctx, + submit_score_set_mappings_to_ldh_sample_job_run.id, + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id), + ) + + mock_update_progress.assert_called_with( + 100, 100, "No valid mapped variants to submit to LDH. Skipping submission." + ) + assert result["status"] == "ok" + + async def test_submit_score_set_mappings_to_ldh_propagates_exception( + self, + mock_worker_ctx, + session, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + mock_worker_ctx, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenLdhService to raise an exception + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + side_effect=Exception("LDH service error"), + ), + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + patch("mavedb.worker.jobs.external_services.clingen.LDH_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + pytest.raises(Exception) as exc_info, + ): + await submit_score_set_mappings_to_ldh( + mock_worker_ctx, + submit_score_set_mappings_to_ldh_sample_job_run.id, + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id), + ) + + assert str(exc_info.value) == "LDH service error" + + async def test_submit_score_set_mappings_to_ldh_partial_submission( + self, + mock_worker_ctx, + session, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + mock_worker_ctx, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + variants = session.scalars(select(Variant)).all() + + async def dummy_partial_submission(*args, **kwargs): + return ( + [ + { + "data": { + "entId": v.urn, + "ldhId": f"LDH123400{idx}", + "ldhIri": f"https://10.15.55.128/ldh-stg/MaveDBMapping/id/LDH123400{idx}", + }, + "status": {"code": 200, "name": "OK"}, + } + for idx, v in enumerate(variants[2:], start=1) + ], + [TEST_CLINGEN_LDH_LINKING_RESPONSE_BAD_REQUEST] * 2, + ) + + # Patch ClinGenLdhService to simulate partial submission success + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_partial_submission(), + ), + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + patch("mavedb.worker.jobs.external_services.clingen.LDH_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, + ): + result = await submit_score_set_mappings_to_ldh( + mock_worker_ctx, 
+ submit_score_set_mappings_to_ldh_sample_job_run.id, + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id), + ) + + assert result["status"] == "ok" + mock_update_progress.assert_called_with( + 100, 100, "Finalized LDH mapped resource submission (2 successes, 2 failures)." + ) + + async def test_submit_score_set_mappings_to_ldh_all_successful_submission( + self, + mock_worker_ctx, + session, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + mock_worker_ctx, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + variants = session.scalars(select(Variant)).all() + + async def dummy_successful_submission(*args, **kwargs): + return ( + [ + { + "data": { + "entId": v.urn, + "ldhId": f"LDH123400{idx}", + "ldhIri": f"https://10.15.55.128/ldh-stg/MaveDBMapping/id/LDH123400{idx}", + }, + "status": {"code": 200, "name": "OK"}, + } + for idx, v in enumerate(variants, start=1) + ], + [], + ) + + # Patch ClinGenLdhService to simulate all submissions succeeding + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_successful_submission(), + ), + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + patch("mavedb.worker.jobs.external_services.clingen.LDH_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, + ): + result = await submit_score_set_mappings_to_ldh( + mock_worker_ctx, + submit_score_set_mappings_to_ldh_sample_job_run.id, + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id), + ) + + assert result["status"] == "ok" + mock_update_progress.assert_called_with( + 100, 100, "Finalized LDH mapped resource submission (4 successes, 0 failures)." 
+        )
+
+
+@pytest.mark.integration
+@pytest.mark.asyncio
+class TestClingenSubmitScoreSetMappingsToLdhIntegration:
+    """Integration tests for the Clingen submit_score_set_mappings_to_ldh function."""
+
+    async def test_submit_score_set_mappings_to_ldh_independent(
+        self,
+        standalone_worker_context,
+        session,
+        with_submit_score_set_mappings_to_ldh_job,
+        submit_score_set_mappings_to_ldh_sample_job_run,
+        mock_s3_client,
+        sample_score_dataframe,
+        sample_count_dataframe,
+        with_dummy_setup_jobs,
+        dummy_variant_creation_job_run,
+        dummy_variant_mapping_job_run,
+    ):
+        # Create mappings in the score set
+        await create_mappings_in_score_set(
+            session,
+            mock_s3_client,
+            standalone_worker_context,
+            sample_score_dataframe,
+            sample_count_dataframe,
+            dummy_variant_creation_job_run,
+            dummy_variant_mapping_job_run,
+        )
+
+        variants = session.scalars(select(Variant)).all()
+
+        async def dummy_ldh_submission(*args, **kwargs):
+            return (
+                [
+                    {
+                        "data": {
+                            "entId": v.urn,
+                            "ldhId": f"LDH123400{idx}",
+                            "ldhIri": f"https://10.15.55.128/ldh-stg/MaveDBMapping/id/LDH123400{idx}",
+                        },
+                        "status": {"code": 200, "name": "OK"},
+                    }
+                    for idx, v in enumerate(variants, start=1)
+                ],
+                [],
+            )
+
+        # Patch ClinGenLdhService to simulate all submissions succeeding
+        with (
+            patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None),
+            patch.object(
+                _UnixSelectorEventLoop,
+                "run_in_executor",
+                return_value=dummy_ldh_submission(),
+            ),
+        ):
+            result = await submit_score_set_mappings_to_ldh(
+                standalone_worker_context, submit_score_set_mappings_to_ldh_sample_job_run.id
+            )
+
+        assert result["status"] == "ok"
+
+        # Verify annotation statuses were created
+        annotation_statuses = session.scalars(
+            select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission")
+        ).all()
+        assert len(annotation_statuses) == 4
+        for ann in annotation_statuses:
+            assert ann.status == "success"
+
+        # Verify the job status is updated in the database
+        session.refresh(submit_score_set_mappings_to_ldh_sample_job_run)
+        assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.SUCCEEDED
+
+    async def test_submit_score_set_mappings_to_ldh_pipeline_ctx(
+        self,
+        standalone_worker_context,
+        session,
+        with_submit_score_set_mappings_to_ldh_job,
+        submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline,
+        submit_score_set_mappings_to_ldh_sample_pipeline,
+        mock_s3_client,
+        sample_score_dataframe,
+        sample_count_dataframe,
+        with_dummy_setup_jobs,
+        dummy_variant_creation_job_run,
+        dummy_variant_mapping_job_run,
+    ):
+        # Create mappings in the score set
+        await create_mappings_in_score_set(
+            session,
+            mock_s3_client,
+            standalone_worker_context,
+            sample_score_dataframe,
+            sample_count_dataframe,
+            dummy_variant_creation_job_run,
+            dummy_variant_mapping_job_run,
+        )
+
+        variants = session.scalars(select(Variant)).all()
+
+        async def dummy_ldh_submission(*args, **kwargs):
+            return (
+                [
+                    {
+                        "data": {
+                            "entId": v.urn,
+                            "ldhId": f"LDH123400{idx}",
+                            "ldhIri": f"https://10.15.55.128/ldh-stg/MaveDBMapping/id/LDH123400{idx}",
+                        },
+                        "status": {"code": 200, "name": "OK"},
+                    }
+                    for idx, v in enumerate(variants, start=1)
+                ],
+                [],
+            )
+
+        # Patch ClinGenLdhService to simulate all submissions succeeding
+        with (
+            patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None),
+            patch.object(
+                _UnixSelectorEventLoop,
+                "run_in_executor",
+                return_value=dummy_ldh_submission(),
+            ),
+        ):
+            result = await
submit_score_set_mappings_to_ldh( + standalone_worker_context, submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline.id + ) + + assert result["status"] == "ok" + + # Verify annotation statuses were created + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission") + ).all() + assert len(annotation_statuses) == 4 + for ann in annotation_statuses: + assert ann.status == "success" + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline) + assert submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify the pipeline status is updated in the database + session.refresh(submit_score_set_mappings_to_ldh_sample_pipeline) + assert submit_score_set_mappings_to_ldh_sample_pipeline.status == PipelineStatus.SUCCEEDED + + async def test_submit_score_set_mappings_to_ldh_propagates_exception_to_decorator( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenLdhService to raise an exception + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + side_effect=Exception("LDH service error"), + ), + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + result = await submit_score_set_mappings_to_ldh( + standalone_worker_context, submit_score_set_mappings_to_ldh_sample_job_run.id + ) + + mock_send_slack_error.assert_called_once() + assert result["status"] == "exception" + assert isinstance(result["exception"], Exception) + assert str(result["exception"]) == "LDH service error" + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_ldh_sample_job_run) + assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.FAILED + + async def test_submit_score_set_mappings_to_ldh_no_linked_alleles( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + async def dummy_no_linked_alleles_submission(*args, **kwargs): + return ([], []) + + # Patch ClinGenLdhService to simulate no linked alleles found + with ( + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_no_linked_alleles_submission(), + ), + ): + result = await 
submit_score_set_mappings_to_ldh(
+                standalone_worker_context, submit_score_set_mappings_to_ldh_sample_job_run.id
+            )
+
+        assert result["status"] == "ok"
+
+        # Verify annotation statuses were created with failures
+        annotation_statuses = session.scalars(
+            select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission")
+        ).all()
+        assert len(annotation_statuses) == 4
+        for ann in annotation_statuses:
+            assert ann.status == "failed"
+
+        # Verify the job status is updated in the database
+        session.refresh(submit_score_set_mappings_to_ldh_sample_job_run)
+        assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.SUCCEEDED
+
+    async def test_submit_score_set_mappings_to_ldh_hgvs_not_found(
+        self,
+        standalone_worker_context,
+        session,
+        with_submit_score_set_mappings_to_ldh_job,
+        submit_score_set_mappings_to_ldh_sample_job_run,
+        mock_s3_client,
+        sample_score_dataframe,
+        sample_count_dataframe,
+        with_dummy_setup_jobs,
+        dummy_variant_creation_job_run,
+        dummy_variant_mapping_job_run,
+    ):
+        # Create mappings in the score set
+        await create_mappings_in_score_set(
+            session,
+            mock_s3_client,
+            standalone_worker_context,
+            sample_score_dataframe,
+            sample_count_dataframe,
+            dummy_variant_creation_job_run,
+            dummy_variant_mapping_job_run,
+        )
+
+        # Patch get_hgvs_from_post_mapped so that no HGVS string can be derived for any variant
+        with (
+            patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None),
+            patch("mavedb.worker.jobs.external_services.clingen.get_hgvs_from_post_mapped", return_value=None),
+        ):
+            result = await submit_score_set_mappings_to_ldh(
+                standalone_worker_context, submit_score_set_mappings_to_ldh_sample_job_run.id
+            )
+
+        assert result["status"] == "ok"
+
+        # Verify no annotation statuses were created
+        annotation_statuses = session.scalars(
+            select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission")
+        ).all()
+        assert len(annotation_statuses) == 0
+
+        # Verify the job status is updated in the database
+        session.refresh(submit_score_set_mappings_to_ldh_sample_job_run)
+        assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.SUCCEEDED
+
+    async def test_submit_score_set_mappings_to_ldh_all_submissions_failed(
+        self,
+        standalone_worker_context,
+        session,
+        with_submit_score_set_mappings_to_ldh_job,
+        submit_score_set_mappings_to_ldh_sample_job_run,
+        mock_s3_client,
+        sample_score_dataframe,
+        sample_count_dataframe,
+        with_dummy_setup_jobs,
+        dummy_variant_creation_job_run,
+        dummy_variant_mapping_job_run,
+    ):
+        # Create mappings in the score set
+        await create_mappings_in_score_set(
+            session,
+            mock_s3_client,
+            standalone_worker_context,
+            sample_score_dataframe,
+            sample_count_dataframe,
+            dummy_variant_creation_job_run,
+            dummy_variant_mapping_job_run,
+        )
+
+        async def dummy_submission_failure(*args, **kwargs):
+            return ([], [TEST_CLINGEN_LDH_LINKING_RESPONSE_BAD_REQUEST] * 4)
+
+        # Patch ClinGenLdhService to simulate all submissions failing
+        with (
+            patch.object(
+                _UnixSelectorEventLoop,
+                "run_in_executor",
+                return_value=dummy_submission_failure(),
+            ),
+            patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None),
+            patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error,
+        ):
+            result = await submit_score_set_mappings_to_ldh(
+                standalone_worker_context, submit_score_set_mappings_to_ldh_sample_job_run.id
+            )
+
+        mock_send_slack_error.assert_called_once()
+        assert result["status"] == "failed"
+        assert isinstance(result["exception"], LDHSubmissionFailureError)
+
+        # Verify annotation statuses were created with failures
+        annotation_statuses = session.scalars(
+            select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission")
+        ).all()
+        assert len(annotation_statuses) == 4
+        for ann in annotation_statuses:
+            assert ann.status == "failed"
+
+        # Verify the job status is updated in the database
+        session.refresh(submit_score_set_mappings_to_ldh_sample_job_run)
+        assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.FAILED
+
+    async def test_submit_score_set_mappings_to_ldh_partial_submission(
+        self,
+        standalone_worker_context,
+        session,
+        with_submit_score_set_mappings_to_ldh_job,
+        submit_score_set_mappings_to_ldh_sample_job_run,
+        mock_s3_client,
+        sample_score_dataframe,
+        sample_count_dataframe,
+        with_dummy_setup_jobs,
+        dummy_variant_creation_job_run,
+        dummy_variant_mapping_job_run,
+    ):
+        # Create mappings in the score set
+        await create_mappings_in_score_set(
+            session,
+            mock_s3_client,
+            standalone_worker_context,
+            sample_score_dataframe,
+            sample_count_dataframe,
+            dummy_variant_creation_job_run,
+            dummy_variant_mapping_job_run,
+        )
+
+        variants = session.scalars(select(Variant)).all()
+
+        async def dummy_partial_submission(*args, **kwargs):
+            return (
+                [
+                    {
+                        "data": {
+                            "entId": variants[0].urn,
+                            "ldhId": "LDH1234001",
+                            "ldhIri": "https://10.15.55.128/ldh-stg/MaveDBMapping/id/LDH1234001",
+                        },
+                        "status": {"code": 200, "name": "OK"},
+                    }
+                ],
+                [TEST_CLINGEN_LDH_LINKING_RESPONSE_BAD_REQUEST] * 3,
+            )
+
+        # Patch ClinGenLdhService to simulate partial submission success
+        with (
+            patch.object(
+                _UnixSelectorEventLoop,
+                "run_in_executor",
+                return_value=dummy_partial_submission(),
+            ),
+            patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None),
+        ):
+            result = await submit_score_set_mappings_to_ldh(
+                standalone_worker_context, submit_score_set_mappings_to_ldh_sample_job_run.id
+            )
+
+        assert result["status"] == "ok"
+
+        # Verify annotation statuses were created
+        annotation_statuses = session.scalars(
+            select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission")
+        ).all()
+        assert len(annotation_statuses) == 4
+        success_count = 0
+        failure_count = 0
+        for ann in annotation_statuses:
+            if ann.status == "success":
+                success_count += 1
+            elif ann.status == "failed":
+                failure_count += 1
+
+        assert success_count == 1
+        assert failure_count == 3
+
+        # Verify the job status is updated in the database
+        session.refresh(submit_score_set_mappings_to_ldh_sample_job_run)
+        assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.SUCCEEDED
+
+    async def test_submit_score_set_mappings_to_ldh_all_successful_submission(
+        self,
+        standalone_worker_context,
+        session,
+        with_submit_score_set_mappings_to_ldh_job,
+        submit_score_set_mappings_to_ldh_sample_job_run,
+        mock_s3_client,
+        sample_score_dataframe,
+        sample_count_dataframe,
+        with_dummy_setup_jobs,
+        dummy_variant_creation_job_run,
+        dummy_variant_mapping_job_run,
+    ):
+        # Create mappings in the score set
+        await create_mappings_in_score_set(
+            session,
+            mock_s3_client,
+            standalone_worker_context,
+            sample_score_dataframe,
+            sample_count_dataframe,
+            dummy_variant_creation_job_run,
+            dummy_variant_mapping_job_run,
+        )
+
+        variants = session.scalars(select(Variant)).all()
+
+        async def dummy_ldh_submission(*args, **kwargs):
+            return (
+                [
+                    {
+                        "data": {
+                            "entId": v.urn,
+                            "ldhId": f"LDH123400{idx}",
+                            "ldhIri": f"https://10.15.55.128/ldh-stg/MaveDBMapping/id/LDH123400{idx}",
+                        },
+                        "status": {"code": 200, "name": "OK"},
+                    }
+                    for idx, v in enumerate(variants, start=1)
+                ],
+                [],
+            )
+
+        # Patch ClinGenLdhService to simulate all submissions succeeding
+        with (
+            patch.object(
+                _UnixSelectorEventLoop,
+                "run_in_executor",
+                return_value=dummy_ldh_submission(),
+            ),
+            patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None),
+        ):
+            result = await submit_score_set_mappings_to_ldh(
+                standalone_worker_context, submit_score_set_mappings_to_ldh_sample_job_run.id
+            )
+
+        assert result["status"] == "ok"
+
+        # Verify annotation statuses were created
+        annotation_statuses = session.scalars(
+            select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission")
+        ).all()
+        assert len(annotation_statuses) == 4
+        for ann in annotation_statuses:
+            assert ann.status == "success"
+
+        # Verify the job status is updated in the database
+        session.refresh(submit_score_set_mappings_to_ldh_sample_job_run)
+        assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.SUCCEEDED
+
+
+@pytest.mark.integration
+@pytest.mark.asyncio
+class TestClingenSubmitScoreSetMappingsToLdhArqIntegration:
+    """ARQ Integration tests for the Clingen submit_score_set_mappings_to_ldh function."""
+
+    async def test_submit_score_set_mappings_to_ldh_independent(
+        self,
+        standalone_worker_context,
+        session,
+        arq_redis,
+        arq_worker,
+        with_submit_score_set_mappings_to_ldh_job,
+        submit_score_set_mappings_to_ldh_sample_job_run,
+        mock_s3_client,
+        sample_score_dataframe,
+        sample_count_dataframe,
+        with_dummy_setup_jobs,
+        dummy_variant_creation_job_run,
+        dummy_variant_mapping_job_run,
+    ):
+        # Create mappings in the score set
+        await create_mappings_in_score_set(
+            session,
+            mock_s3_client,
+            standalone_worker_context,
+            sample_score_dataframe,
+            sample_count_dataframe,
+            dummy_variant_creation_job_run,
+            dummy_variant_mapping_job_run,
+        )
+
+        variants = session.scalars(select(Variant)).all()
+
+        async def dummy_ldh_submission(*args, **kwargs):
+            return (
+                [
+                    {
+                        "data": {
+                            "entId": v.urn,
+                            "ldhId": f"LDH123400{idx}",
+                            "ldhIri": f"https://10.15.55.128/ldh-stg/MaveDBMapping/id/LDH123400{idx}",
+                        },
+                        "status": {"code": 200, "name": "OK"},
+                    }
+                    for idx, v in enumerate(variants, start=1)
+                ],
+                [],
+            )
+
+        # Patch ClinGenLdhService to simulate all submissions succeeding
+        with (
+            patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None),
+            patch.object(
+                _UnixSelectorEventLoop,
+                "run_in_executor",
+                return_value=dummy_ldh_submission(),
+            ),
+        ):
+            await arq_redis.enqueue_job(
+                "submit_score_set_mappings_to_ldh", submit_score_set_mappings_to_ldh_sample_job_run.id
+            )
+            await arq_worker.async_run()
+            await arq_worker.run_check()
+
+        # Verify annotation statuses were created
+        annotation_statuses = session.scalars(
+            select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission")
+        ).all()
+        assert len(annotation_statuses) == 4
+        for ann in annotation_statuses:
+            assert ann.status == "success"
+
+        # Verify the job status is updated in the database
+        session.refresh(submit_score_set_mappings_to_ldh_sample_job_run)
+        assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.SUCCEEDED
+
+    async def test_submit_score_set_mappings_to_ldh_with_arq_context_in_pipeline(
+        self,
+        standalone_worker_context,
+        session,
+        arq_redis,
+        arq_worker,
+        with_submit_score_set_mappings_to_ldh_job,
+        submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline,
+        submit_score_set_mappings_to_ldh_sample_pipeline,
+        mock_s3_client,
+        sample_score_dataframe,
+        sample_count_dataframe,
+        with_dummy_setup_jobs,
+        dummy_variant_creation_job_run,
+        dummy_variant_mapping_job_run,
+    ):
+        # Create mappings in the score set
+        await create_mappings_in_score_set(
+            session,
+            mock_s3_client,
+            standalone_worker_context,
+            sample_score_dataframe,
+            sample_count_dataframe,
+            dummy_variant_creation_job_run,
+            dummy_variant_mapping_job_run,
+        )
+
+        variants = session.scalars(select(Variant)).all()
+
+        async def dummy_ldh_submission(*args, **kwargs):
+            return (
+                [
+                    {
+                        "data": {
+                            "entId": v.urn,
+                            "ldhId": f"LDH123400{idx}",
+                            "ldhIri": f"https://10.15.55.128/ldh-stg/MaveDBMapping/id/LDH123400{idx}",
+                        },
+                        "status": {"code": 200, "name": "OK"},
+                    }
+                    for idx, v in enumerate(variants, start=1)
+                ],
+                [],
+            )
+
+        # Patch ClinGenLdhService to simulate all submissions succeeding
+        with (
+            patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None),
+            patch.object(
+                _UnixSelectorEventLoop,
+                "run_in_executor",
+                return_value=dummy_ldh_submission(),
+            ),
+        ):
+            await arq_redis.enqueue_job(
+                "submit_score_set_mappings_to_ldh", submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline.id
+            )
+            await arq_worker.async_run()
+            await arq_worker.run_check()
+
+        # Verify annotation statuses were created
+        annotation_statuses = session.scalars(
+            select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission")
+        ).all()
+        assert len(annotation_statuses) == 4
+        for ann in annotation_statuses:
+            assert ann.status == "success"
+
+        # Verify the job status is updated in the database
+        session.refresh(submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline)
+        assert submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline.status == JobStatus.SUCCEEDED
+
+        # Verify the pipeline status is updated in the database
+        session.refresh(submit_score_set_mappings_to_ldh_sample_pipeline)
+        assert submit_score_set_mappings_to_ldh_sample_pipeline.status == PipelineStatus.SUCCEEDED
+
+    async def test_submit_score_set_mappings_to_ldh_with_arq_context_exception_handling(
+        self,
+        standalone_worker_context,
+        session,
+        arq_redis,
+        arq_worker,
+        with_submit_score_set_mappings_to_ldh_job,
+        submit_score_set_mappings_to_ldh_sample_job_run,
+        mock_s3_client,
+        sample_score_dataframe,
+        sample_count_dataframe,
+        with_dummy_setup_jobs,
+        dummy_variant_creation_job_run,
+        dummy_variant_mapping_job_run,
+    ):
+        # Create mappings in the score set
+        await create_mappings_in_score_set(
+            session,
+            mock_s3_client,
+            standalone_worker_context,
+            sample_score_dataframe,
+            sample_count_dataframe,
+            dummy_variant_creation_job_run,
+            dummy_variant_mapping_job_run,
+        )
+
+        # Patch ClinGenLdhService to raise an exception
+        with (
+            patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None),
+            patch.object(
+                _UnixSelectorEventLoop,
+                "run_in_executor",
+                side_effect=Exception("LDH service error"),
+            ),
+            patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error,
+        ):
+            await arq_redis.enqueue_job(
"submit_score_set_mappings_to_ldh", submit_score_set_mappings_to_ldh_sample_job_run.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + mock_send_slack_error.assert_called_once() + # Verify no annotation statuses were created + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission") + ).all() + assert len(annotation_statuses) == 0 + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_ldh_sample_job_run) + assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.FAILED + assert submit_score_set_mappings_to_ldh_sample_job_run.error_message == "LDH service error" + + async def test_submit_score_set_mappings_to_ldh_with_arq_context_exception_handling_pipeline_ctx( + self, + standalone_worker_context, + session, + arq_redis, + arq_worker, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline, + submit_score_set_mappings_to_ldh_sample_pipeline, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenLdhService to raise an exception + with ( + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + side_effect=Exception("LDH service error"), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + await arq_redis.enqueue_job( + "submit_score_set_mappings_to_ldh", submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + mock_send_slack_error.assert_called_once() + # Verify no annotation statuses were created + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission") + ).all() + assert len(annotation_statuses) == 0 + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline) + assert submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline.status == JobStatus.FAILED + assert submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline.error_message == "LDH service error" + + # Verify the pipeline status is updated in the database + session.refresh(submit_score_set_mappings_to_ldh_sample_pipeline) + assert submit_score_set_mappings_to_ldh_sample_pipeline.status == PipelineStatus.FAILED diff --git a/tests/worker/jobs/external_services/test_clinvar.py b/tests/worker/jobs/external_services/test_clinvar.py new file mode 100644 index 00000000..a7eeb6f2 --- /dev/null +++ b/tests/worker/jobs/external_services/test_clinvar.py @@ -0,0 +1,1470 @@ +# ruff: noqa: E402 + +import pytest +import requests + +from mavedb.models.clinical_control import ClinicalControl +from mavedb.models.enums.annotation_type import AnnotationType +from mavedb.models.enums.job_pipeline import AnnotationStatus, JobStatus, PipelineStatus +from mavedb.models.variant_annotation_status import VariantAnnotationStatus + +pytest.importorskip("arq") + 
+import gzip +from asyncio.unix_events import _UnixSelectorEventLoop +from unittest.mock import call, patch + +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.score_set import ScoreSet +from mavedb.models.variant import Variant +from mavedb.worker.jobs.external_services.clinvar import refresh_clinvar_controls +from mavedb.worker.lib.managers.job_manager import JobManager + +pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + + +async def mock_fetch_tsv(*args, **kwargs): + data = b"#AlleleID\tClinicalSignificance\tGeneSymbol\tReviewStatus\nVCV000000123\tbenign\tTEST\treviewed by expert panel" + return gzip.compress(data) + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestRefreshClinvarControlsUnit: + """Tests for the refresh_clinvar_controls job function.""" + + async def test_refresh_clinvar_controls_invalid_month_raises( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + ): + # edit the job run to have an invalid month + sample_refresh_clinvar_controls_job_run.job_params["month"] = 13 + session.commit() + + with pytest.raises(ValueError, match="Month must be an integer between 1 and 12."): + await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + async def test_refresh_clinvar_controls_invalid_year_raises( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + ): + # edit the job run to have an invalid year + sample_refresh_clinvar_controls_job_run.job_params["year"] = 1999 + session.commit() + + with pytest.raises(ValueError, match="ClinVar archived data is only available from February 2015 onwards."): + await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + async def test_refresh_clinvar_controls_propagates_exception_during_fetch( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + ): + # Mock the fetch_clinvar_variant_data function to raise an exception + async def awaitable_exception(*args, **kwargs): + raise Exception("Network error") + + with ( + pytest.raises(Exception, match="Network error"), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=awaitable_exception(), + ), + ): + await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + async def test_refresh_clinvar_controls_no_mapped_variants( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + ): + """Test that the job completes successfully when there are no mapped variants.""" + + async def awaitable_noop(*args, **kwargs): + return {} + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=awaitable_noop(), + ), + patch("mavedb.worker.jobs.external_services.clinvar.parse_clinvar_variant_summary"), + ): + result = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result["status"] == "ok" + + 
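+    # NOTE: run_in_executor is patched with return_value=<coroutine object>. A
+    # coroutine can only be awaited once, so any test that drives the job twice
+    # (e.g. the idempotency tests below) must instead pass
+    # side_effect=[mock_fetch_tsv(), mock_fetch_tsv()] to supply a fresh awaitable
+    # per invocation.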
async def test_refresh_clinvar_controls_no_variants_have_caids( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + ): + """Test that the job completes successfully when no variants have CAIDs.""" + # Add a variant without a CAID + score_set = session.get(ScoreSet, sample_refresh_clinvar_controls_job_run.job_params["score_set_id"]) + variant = Variant( + urn="urn:variant:test-variant-no-caid", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.2G>A", + hgvs_pro="NP_000000.1:p.Val2Ile", + data={"hgvs_c": "NM_000000.1:c.2G>A", "hgvs_p": "NP_000000.1:p.Val2Ile"}, + ) + session.add(variant) + session.commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant) + session.commit() + + with patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ): + result = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant without a CAID + variant_no_caid = ( + session.query(VariantAnnotationStatus).filter(VariantAnnotationStatus.variant_id == variant.id).one() + ) + assert variant_no_caid.status == AnnotationStatus.SKIPPED + assert variant_no_caid.annotation_type == AnnotationType.CLINVAR_CONTROL + assert variant_no_caid.error_message == "Mapped variant does not have an associated ClinGen allele ID." + + async def test_refresh_clinvar_controls_variants_are_multivariants( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Test that the job completes successfully when all variants are multi-variant CAIDs.""" + # Update the mapped variant to have a multi-variant CAID + mapped_variant = session.query(MappedVariant).first() + mapped_variant.clingen_allele_id = "CA-MULTI-001,CA-MULTI-002" + session.commit() + + with patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ): + result = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the multi-variant CAID + variant_with_multicid = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert variant_with_multicid.status == AnnotationStatus.SKIPPED + assert variant_with_multicid.annotation_type == AnnotationType.CLINVAR_CONTROL + assert ( + variant_with_multicid.error_message + == "Multi-variant ClinGen allele IDs cannot be associated with ClinVar data." 
+ ) + + async def test_refresh_clinvar_controls_clingen_api_failure( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Test that the job handles ClinGen API failures gracefully.""" + + # Mock the get_associated_clinvar_allele_id function to raise an exception + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + side_effect=requests.exceptions.RequestException("ClinGen API error"), + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant due to ClinGen API failure + mapped_variant = session.query(MappedVariant).first() + variant_with_api_failure = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert variant_with_api_failure.status == AnnotationStatus.FAILED + assert variant_with_api_failure.annotation_type == AnnotationType.CLINVAR_CONTROL + assert "Failed to retrieve ClinVar allele ID from ClinGen API" in variant_with_api_failure.error_message + + async def test_refresh_clinvar_controls_no_associated_clinvar_allele_id( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Test that the job handles no associated ClinVar Allele ID gracefully.""" + + # Mock the get_associated_clinvar_allele_id function to return None + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value=None, + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant due to no associated ClinVar Allele ID + mapped_variant = session.query(MappedVariant).first() + variant_no_clinvar_allele = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert variant_no_clinvar_allele.status == AnnotationStatus.SKIPPED + assert variant_no_clinvar_allele.annotation_type == AnnotationType.CLINVAR_CONTROL + assert "No ClinVar allele ID found for ClinGen allele ID" in variant_no_clinvar_allele.error_message + + async def test_refresh_clinvar_controls_no_clinvar_data_found( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Test that the job handles no ClinVar data found for the associated ClinVar Allele ID.""" + + async def mock_fetch_tsv(*args, **kwargs): + data = b"#AlleleID\tClinicalSignificance\tGeneSymbol\tReviewStatus\nVCV000000001\tbenign\tTEST\treviewed by expert panel" + return gzip.compress(data) + + # Mock the get_associated_clinvar_allele_id function to return a ClinVar Allele ID + with ( + patch( + 
"mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value="VCV000000123", + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant due to no ClinVar data found + mapped_variant = session.query(MappedVariant).first() + variant_no_clinvar_data = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert variant_no_clinvar_data.status == AnnotationStatus.SKIPPED + assert variant_no_clinvar_data.annotation_type == AnnotationType.CLINVAR_CONTROL + assert "No ClinVar data found for ClinVar allele ID" in variant_no_clinvar_data.error_message + + async def test_refresh_clinvar_controls_successful_annotation_existing_control( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Test that the job successfully annotates a variant with ClinVar control data.""" + + # Mock the get_associated_clinvar_allele_id function to return a ClinVar Allele ID + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value="VCV000000123", + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant with successful annotation + mapped_variant = session.query(MappedVariant).first() + annotated_variant = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert annotated_variant.status == AnnotationStatus.SUCCESS + assert annotated_variant.annotation_type == AnnotationType.CLINVAR_CONTROL + assert annotated_variant.error_message is None + + async def test_refresh_clinvar_controls_successful_annotation_new_control( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + ): + """Test that the job successfully annotates a variant with ClinVar control data when no prior status exists.""" + # Add a variant and mapped variant to the database with a CAID + score_set = session.get(ScoreSet, sample_refresh_clinvar_controls_job_run.job_params["score_set_id"]) + variant = Variant( + urn="urn:variant:test-variant-with-caid-2", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.3C>T", + hgvs_pro="NP_000000.1:p.Ala3Val", + data={"hgvs_c": "NM_000000.1:c.3C>T", "hgvs_p": "NP_000000.1:p.Ala3Val"}, + ) + session.add(variant) + session.commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + clingen_allele_id="CA124", + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant) + session.commit() + + # Mock the get_associated_clinvar_allele_id function to return a ClinVar Allele ID + with ( + patch( + 
"mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value="VCV000000123", + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant with successful annotation + annotated_variant = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert annotated_variant.status == AnnotationStatus.SUCCESS + assert annotated_variant.annotation_type == AnnotationType.CLINVAR_CONTROL + assert annotated_variant.error_message is None + + async def test_refresh_clinvar_controls_idempotent_run( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Test that running the job multiple times does not create duplicate annotation statuses.""" + + # Mock the get_associated_clinvar_allele_id function to return a ClinVar Allele ID + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value="VCV000000123", + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + side_effect=[mock_fetch_tsv(), mock_fetch_tsv()], + ), + ): + # First run + result1 = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + session.commit() + + # Second run + result2 = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result1["status"] == "ok" + assert result2["status"] == "ok" + + # Verify only one clinical control annotation exists for the variant + clinical_controls = session.query(ClinicalControl).all() + assert len(clinical_controls) == 1 + + # Verify two annotated variants exist but both reflect the same successful annotation, and only + # one is current + annotated_variants = session.query(VariantAnnotationStatus).all() + assert len(annotated_variants) == 2 + statuses = [av.status for av in annotated_variants] + assert statuses.count(AnnotationStatus.SUCCESS) == 2 + current_statuses = [av for av in annotated_variants if av.current] + assert len(current_statuses) == 1 + + async def test_refresh_clinvar_controls_partial_failure( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Test that the job handles partial failures gracefully.""" + + variant1, mapped_variant1 = setup_sample_variants_with_caid + + # Add an additional mapped variant to the database with a CAID + score_set = session.get(ScoreSet, sample_refresh_clinvar_controls_job_run.job_params["score_set_id"]) + variant2 = Variant( + urn="urn:variant:test-variant-with-caid-2", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.4G>C", + hgvs_pro="NP_000000.1:p.Gly4Ala", + data={"hgvs_c": "NM_000000.1:c.4G>C", "hgvs_p": "NP_000000.1:p.Gly4Ala"}, + ) + session.add(variant2) + session.commit() + mapped_variant2 = MappedVariant( + 
variant_id=variant2.id, + clingen_allele_id="CA125", + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant2) + session.commit() + + # Mock the get_associated_clinvar_allele_id function to raise an exception for the first call + def side_effect_get_associated_clinvar_allele_id(clingen_allele_id): + if clingen_allele_id == "CA125": + raise requests.exceptions.RequestException("ClinGen API error") + return "VCV000000123" + + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + side_effect=side_effect_get_associated_clinvar_allele_id, + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result["status"] == "ok" + + # Verify annotation statuses for both variants + variant_with_api_failure = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant2.variant_id) + .one() + ) + assert variant_with_api_failure.status == AnnotationStatus.FAILED + assert variant_with_api_failure.annotation_type == AnnotationType.CLINVAR_CONTROL + assert "Failed to retrieve ClinVar allele ID from ClinGen API" in variant_with_api_failure.error_message + + annotated_variant2 = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant1.variant_id) + .one() + ) + assert annotated_variant2.status == AnnotationStatus.SUCCESS + assert annotated_variant2.annotation_type == AnnotationType.CLINVAR_CONTROL + assert annotated_variant2.error_message is None + + async def test_refresh_clinvar_controls_updates_progress( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Test that the job updates progress correctly.""" + + # Mock the get_associated_clinvar_allele_id function to return a ClinVar Allele ID + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value="VCV000000123", + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + result = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result["status"] == "ok" + + mock_update_progress.assert_has_calls( + [ + call(0, 100, "Starting ClinVar clinical control refresh for version 01_2026."), + call(1, 100, "Fetching ClinVar variant summary TSV data."), + call(10, 100, "Fetched and parsed ClinVar variant summary TSV data."), + call(10, 100, "Refreshing ClinVar data for 1 variants (0 completed)."), + call(100, 100, "Completed ClinVar clinical control refresh."), + ] + ) + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestRefreshClinvarControlsIntegration: + """Integration tests for the refresh_clinvar_controls job function.""" + + async def test_refresh_clinvar_controls_no_mapped_variants( + self, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run, + ): + """Integration 
test: job completes successfully when there are no mapped variants.""" + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls(mock_worker_ctx, sample_refresh_clinvar_controls_job_run.id) + + assert result["status"] == "ok" + + # Verify no controls were added + clinical_controls = session.query(ClinicalControl).all() + assert len(clinical_controls) == 0 + + # Verify no annotation statuses were created + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + + # Verify job run status is marked as completed + session.refresh(sample_refresh_clinvar_controls_job_run) + assert sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED + + async def test_refresh_clinvar_controls_no_variants_with_caid( + self, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run, + ): + """Integration test: job completes successfully when no variants have CAIDs.""" + # Add a variant without a CAID + score_set = session.get(ScoreSet, sample_refresh_clinvar_controls_job_run.job_params["score_set_id"]) + variant = Variant( + urn="urn:variant:integration-test-variant-no-caid", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.5T>A", + hgvs_pro="NP_000000.1:p.Leu5Gln", + data={"hgvs_c": "NM_000000.1:c.5T>A", "hgvs_p": "NP_000000.1:p.Leu5Gln"}, + ) + session.add(variant) + session.commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant) + session.commit() + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls(mock_worker_ctx, sample_refresh_clinvar_controls_job_run.id) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant without a CAID + variant_no_caid = ( + session.query(VariantAnnotationStatus).filter(VariantAnnotationStatus.variant_id == variant.id).one() + ) + assert variant_no_caid.status == AnnotationStatus.SKIPPED + assert variant_no_caid.annotation_type == AnnotationType.CLINVAR_CONTROL + assert variant_no_caid.error_message == "Mapped variant does not have an associated ClinGen allele ID." 
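+        # SKIPPED (rather than FAILED) is the expected status here: a missing
+        # ClinGen allele ID is an anticipated no-data condition, while FAILED is
+        # reserved for upstream API errors (see the ClinGen API failure tests).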
+
+        # Verify no clinical controls were added
+        clinical_controls = session.query(ClinicalControl).all()
+        assert len(clinical_controls) == 0
+
+        # Verify job run status is marked as completed
+        session.refresh(sample_refresh_clinvar_controls_job_run)
+        assert sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED
+
+    async def test_refresh_clinvar_controls_variants_are_multivariants(
+        self,
+        session,
+        with_populated_domain_data,
+        with_refresh_clinvar_controls_job,
+        mock_worker_ctx,
+        sample_refresh_clinvar_controls_job_run,
+    ):
+        """Integration test: job completes successfully when all variants are multi-variant CAIDs."""
+        # Add a variant with a multi-variant CAID
+        score_set = session.get(ScoreSet, sample_refresh_clinvar_controls_job_run.job_params["score_set_id"])
+        variant = Variant(
+            urn="urn:variant:integration-test-variant-multicid",
+            score_set_id=score_set.id,
+            hgvs_nt="NM_000000.1:c.6A>G",
+            hgvs_pro="NP_000000.1:p.Thr6Ala",
+            data={"hgvs_c": "NM_000000.1:c.6A>G", "hgvs_p": "NP_000000.1:p.Thr6Ala"},
+        )
+        session.add(variant)
+        session.commit()
+        mapped_variant = MappedVariant(
+            variant_id=variant.id,
+            clingen_allele_id="CA-MULTI-003,CA-MULTI-004",
+            current=True,
+            mapped_date="2024-01-01T00:00:00Z",
+            mapping_api_version="1.0.0",
+        )
+        session.add(mapped_variant)
+        session.commit()
+
+        with (
+            patch.object(
+                _UnixSelectorEventLoop,
+                "run_in_executor",
+                return_value=mock_fetch_tsv(),
+            ),
+        ):
+            result = await refresh_clinvar_controls(mock_worker_ctx, sample_refresh_clinvar_controls_job_run.id)
+
+        assert result["status"] == "ok"
+
+        # Verify an annotation status was created for the multi-variant CAID
+        variant_with_multicid = (
+            session.query(VariantAnnotationStatus)
+            .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id)
+            .one()
+        )
+        assert variant_with_multicid.status == AnnotationStatus.SKIPPED
+        assert variant_with_multicid.annotation_type == AnnotationType.CLINVAR_CONTROL
+        assert (
+            variant_with_multicid.error_message
+            == "Multi-variant ClinGen allele IDs cannot be associated with ClinVar data."
+ ) + + # Verify no clinical controls were added + clinical_controls = session.query(ClinicalControl).all() + assert len(clinical_controls) == 0 + + # Verify job run status is marked as completed + session.refresh(sample_refresh_clinvar_controls_job_run) + assert sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED + + async def test_refresh_clinvar_controls_no_associated_clinvar_allele_id( + self, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run, + ): + """Integration test: job handles no associated ClinVar Allele ID gracefully.""" + # Add a variant with a CAID + score_set = session.get(ScoreSet, sample_refresh_clinvar_controls_job_run.job_params["score_set_id"]) + variant = Variant( + urn="urn:variant:integration-test-variant-with-caid", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.7C>A", + hgvs_pro="NP_000000.1:p.Ser7Tyr", + data={"hgvs_c": "NM_000000.1:c.7C>A", "hgvs_p": "NP_000000.1:p.Ser7Tyr"}, + ) + session.add(variant) + session.commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + clingen_allele_id="CA126", + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant) + session.commit() + + # Mock the get_associated_clinvar_allele_id function to return None + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value=None, + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls(mock_worker_ctx, sample_refresh_clinvar_controls_job_run.id) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant due to no associated ClinVar Allele ID + variant_no_clinvar_allele = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert variant_no_clinvar_allele.status == AnnotationStatus.SKIPPED + assert variant_no_clinvar_allele.annotation_type == AnnotationType.CLINVAR_CONTROL + assert "No ClinVar allele ID found for ClinGen allele ID" in variant_no_clinvar_allele.error_message + + # Verify no clinical controls were added + clinical_controls = session.query(ClinicalControl).all() + assert len(clinical_controls) == 0 + + # Verify job run status is marked as completed + session.refresh(sample_refresh_clinvar_controls_job_run) + assert sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED + + async def test_refresh_clinvar_controls_no_clinvar_data( + self, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run, + ): + """Integration test: job handles no ClinVar data found for the associated ClinVar Allele ID.""" + # Add a variant with a CAID + score_set = session.get(ScoreSet, sample_refresh_clinvar_controls_job_run.job_params["score_set_id"]) + variant = Variant( + urn="urn:variant:integration-test-variant-with-caid", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.8G>T", + hgvs_pro="NP_000000.1:p.Val8Phe", + data={"hgvs_c": "NM_000000.1:c.8G>T", "hgvs_p": "NP_000000.1:p.Val8Phe"}, + ) + session.add(variant) + session.commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + clingen_allele_id="CA127", + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant) + 
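+        # The earlier commit assigned variant.id, which the MappedVariant foreign
+        # key above relies on; this second commit persists the mapping row itself.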
session.commit() + + # Mock the get_associated_clinvar_allele_id function to return a ClinVar Allele ID + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value="VCV000000001", + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls(mock_worker_ctx, sample_refresh_clinvar_controls_job_run.id) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant due to no ClinVar data found + variant_no_clinvar_data = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert variant_no_clinvar_data.status == AnnotationStatus.SKIPPED + assert variant_no_clinvar_data.annotation_type == AnnotationType.CLINVAR_CONTROL + assert "No ClinVar data found for ClinVar allele ID" in variant_no_clinvar_data.error_message + + # Verify no clinical controls were added + clinical_controls = session.query(ClinicalControl).all() + assert len(clinical_controls) == 0 + + # Verify job run status is marked as completed + session.refresh(sample_refresh_clinvar_controls_job_run) + assert sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED + + async def test_refresh_clinvar_controls_successful_annotation_existing_control( + self, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run, + ): + """Integration test: job successfully annotates a variant with ClinVar control data.""" + # Add a variant with a CAID + score_set = session.get(ScoreSet, sample_refresh_clinvar_controls_job_run.job_params["score_set_id"]) + variant = Variant( + urn="urn:variant:integration-test-variant-with-caid", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.9A>C", + hgvs_pro="NP_000000.1:p.Lys9Thr", + data={"hgvs_c": "NM_000000.1:c.9A>C", "hgvs_p": "NP_000000.1:p.Lys9Thr"}, + ) + session.add(variant) + session.commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + clingen_allele_id="CA128", + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant) + session.commit() + clinical_control = ClinicalControl( + db_name="ClinVar", + db_identifier="VCV000000123", + clinical_significance="likely pathogenic", + gene_symbol="TEST", + clinical_review_status="criteria provided, single submitter", + db_version="01_2026", + ) + session.add(clinical_control) + session.commit() + + mapped_variant.clinical_controls.append(clinical_control) + session.commit() + + # Mock the get_associated_clinvar_allele_id function to return a ClinVar Allele ID + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value="VCV000000123", + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls(mock_worker_ctx, sample_refresh_clinvar_controls_job_run.id) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant with successful annotation + annotated_variant = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert annotated_variant.status == AnnotationStatus.SUCCESS + assert annotated_variant.annotation_type == AnnotationType.CLINVAR_CONTROL + assert 
annotated_variant.error_message is None + + # Verify the clinical control was updated + session.refresh(clinical_control) + assert clinical_control.clinical_significance == "benign" + assert clinical_control.clinical_review_status == "reviewed by expert panel" + assert mapped_variant in clinical_control.mapped_variants + + # Verify job run status is marked as completed + session.refresh(sample_refresh_clinvar_controls_job_run) + assert sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED + + async def test_refresh_clinvar_controls_successful_annotation_new_control( + self, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run, + ): + """Integration test: job successfully annotates a variant with ClinVar control data when no prior status exists.""" + # Add a variant with a CAID + score_set = session.get(ScoreSet, sample_refresh_clinvar_controls_job_run.job_params["score_set_id"]) + variant = Variant( + urn="urn:variant:integration-test-variant-with-caid", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.10C>G", + hgvs_pro="NP_000000.1:p.Pro10Arg", + data={"hgvs_c": "NM_000000.1:c.10C>G", "hgvs_p": "NP_000000.1:p.Pro10Arg"}, + ) + session.add(variant) + session.commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + clingen_allele_id="CA129", + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant) + session.commit() + + # Mock the get_associated_clinvar_allele_id function to return a ClinVar Allele ID + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value="VCV000000123", + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls(mock_worker_ctx, sample_refresh_clinvar_controls_job_run.id) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant with successful annotation + annotated_variant = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert annotated_variant.status == AnnotationStatus.SUCCESS + assert annotated_variant.annotation_type == AnnotationType.CLINVAR_CONTROL + assert annotated_variant.error_message is None + + # Verify the clinical control was added + clinical_control = ( + session.query(ClinicalControl).filter(ClinicalControl.mapped_variants.contains(mapped_variant)).one() + ) + assert clinical_control.db_identifier == "VCV000000123" + + # Verify job run status is marked as completed + session.refresh(sample_refresh_clinvar_controls_job_run) + assert sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED + + async def test_refresh_clinvar_controls_successful_annotation_pipeline_context( + self, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + mock_worker_ctx, + sample_refresh_clinvar_controls_pipeline, + sample_refresh_clinvar_controls_job_in_pipeline, + ): + """Integration test: job successfully annotates a variant with ClinVar control data in a pipeline context.""" + # Add a variant with a CAID + score_set = session.get(ScoreSet, sample_refresh_clinvar_controls_job_in_pipeline.job_params["score_set_id"]) + variant = Variant( + urn="urn:variant:integration-test-variant-with-caid", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.12G>A", + 
hgvs_pro="NP_000000.1:p.Met12Ile", + data={"hgvs_c": "NM_000000.1:c.12G>A", "hgvs_p": "NP_000000.1:p.Met12Ile"}, + ) + session.add(variant) + session.commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + clingen_allele_id="CA130", + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant) + session.commit() + + # Mock the get_associated_clinvar_allele_id function to return a ClinVar Allele ID + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value="VCV000000123", + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls(mock_worker_ctx, sample_refresh_clinvar_controls_job_in_pipeline.id) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant with successful annotation + annotated_variant = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert annotated_variant.status == AnnotationStatus.SUCCESS + assert annotated_variant.annotation_type == AnnotationType.CLINVAR_CONTROL + assert annotated_variant.error_message is None + + # Verify the clinical control was added + clinical_control = ( + session.query(ClinicalControl).filter(ClinicalControl.mapped_variants.contains(mapped_variant)).one() + ) + assert clinical_control.db_identifier == "VCV000000123" + + # Verify job run status is marked as completed + session.refresh(sample_refresh_clinvar_controls_job_in_pipeline) + assert sample_refresh_clinvar_controls_job_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify the pipeline is marked as completed + session.refresh(sample_refresh_clinvar_controls_pipeline) + assert sample_refresh_clinvar_controls_pipeline.status == PipelineStatus.SUCCEEDED + + async def test_refresh_clinvar_controls_idempotent_run( + self, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Integration test: running the job multiple times does not create duplicate annotation statuses.""" + + # Mock the get_associated_clinvar_allele_id function to return a ClinVar Allele ID + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value="VCV000000123", + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + side_effect=[mock_fetch_tsv(), mock_fetch_tsv()], + ), + ): + # First run + result1 = await refresh_clinvar_controls(mock_worker_ctx, sample_refresh_clinvar_controls_job_run.id) + + session.commit() + # reset the job run status to pending for the second run + sample_refresh_clinvar_controls_job_run.status = JobStatus.PENDING + session.commit() + + # Second run + result2 = await refresh_clinvar_controls(mock_worker_ctx, sample_refresh_clinvar_controls_job_run.id) + + assert result1["status"] == "ok" + assert result2["status"] == "ok" + + # Verify only one clinical control annotation exists for the variant + clinical_controls = session.query(ClinicalControl).all() + assert len(clinical_controls) == 1 + + # Verify two annotated variants exist but both reflect the same successful annotation, and only + # one is current + annotated_variants = session.query(VariantAnnotationStatus).all() + assert len(annotated_variants) == 2 + statuses = [av.status for av in annotated_variants] + 
assert statuses.count(AnnotationStatus.SUCCESS) == 2 + current_statuses = [av for av in annotated_variants if av.current] + assert len(current_statuses) == 1 + + # Verify job run status is marked as completed + session.refresh(sample_refresh_clinvar_controls_job_run) + assert sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED + + async def test_refresh_clinvar_controls_partial_failure( + self, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Integration test: job handles partial failures gracefully.""" + + variant1, mapped_variant1 = setup_sample_variants_with_caid + # Add an additional mapped variant to the database with a CAID + score_set = session.get(ScoreSet, sample_refresh_clinvar_controls_job_run.job_params["score_set_id"]) + variant2 = Variant( + urn="urn:variant:integration-test-variant-with-caid-2", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.11G>C", + hgvs_pro="NP_000000.1:p.Gly11Ala", + data={"hgvs_c": "NM_000000.1:c.11G>C", "hgvs_p": "NP_000000.1:p.Gly11Ala"}, + ) + session.add(variant2) + session.commit() + mapped_variant2 = MappedVariant( + variant_id=variant2.id, + clingen_allele_id="CA130", + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant2) + session.commit() + + # Mock the get_associated_clinvar_allele_id function to raise an exception for the first call + def side_effect_get_associated_clinvar_allele_id(clingen_allele_id): + if clingen_allele_id == "CA130": + raise requests.exceptions.RequestException("ClinGen API error") + return "VCV000000123" + + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + side_effect=side_effect_get_associated_clinvar_allele_id, + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls(mock_worker_ctx, sample_refresh_clinvar_controls_job_run.id) + + assert result["status"] == "ok" + + # Verify annotation statuses for both variants + variant_with_api_failure = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant2.variant_id) + .one() + ) + assert variant_with_api_failure.status == AnnotationStatus.FAILED + assert variant_with_api_failure.annotation_type == AnnotationType.CLINVAR_CONTROL + assert "Failed to retrieve ClinVar allele ID from ClinGen API" in variant_with_api_failure.error_message + + annotated_variant2 = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant1.variant_id) + .one() + ) + assert annotated_variant2.status == AnnotationStatus.SUCCESS + assert annotated_variant2.annotation_type == AnnotationType.CLINVAR_CONTROL + assert annotated_variant2.error_message is None + + # Verify a clinical control was added for the successfully annotated variant and not the unsuccessful one + clinical_control1 = ( + session.query(ClinicalControl).filter(ClinicalControl.mapped_variants.contains(mapped_variant1)).one() + ) + assert clinical_control1.db_identifier == "VCV000000123" + + clinical_control2 = ( + session.query(ClinicalControl).filter(ClinicalControl.mapped_variants.contains(mapped_variant2)).all() + ) + assert len(clinical_control2) == 0 + + # Verify job run status is marked as completed + session.refresh(sample_refresh_clinvar_controls_job_run) + assert 
sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED
+
+    async def test_refresh_clinvar_controls_propagates_exceptions_to_decorator(
+        self,
+        mock_worker_ctx,
+        session,
+        with_refresh_clinvar_controls_job,
+        sample_refresh_clinvar_controls_job_run,
+        setup_sample_variants_with_caid,
+    ):
+        """Test that unexpected exceptions are propagated."""
+
+        # Mock the get_associated_clinvar_allele_id function to raise an unexpected exception
+        with (
+            patch(
+                "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id",
+                side_effect=ValueError("Unexpected error"),
+            ),
+            patch.object(
+                _UnixSelectorEventLoop,
+                "run_in_executor",
+                return_value=mock_fetch_tsv(),
+            ),
+        ):
+            result = await refresh_clinvar_controls(
+                mock_worker_ctx,
+                sample_refresh_clinvar_controls_job_run.id,
+                JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id),
+            )
+
+        assert result["status"] == "exception"
+
+        # Verify no annotation statuses were created
+        annotation_statuses = session.query(VariantAnnotationStatus).all()
+        assert len(annotation_statuses) == 0
+
+        # Verify no clinical controls were added
+        clinical_controls = session.query(ClinicalControl).all()
+        assert len(clinical_controls) == 0
+
+        # Verify job run status is marked as failed
+        session.refresh(sample_refresh_clinvar_controls_job_run)
+        assert sample_refresh_clinvar_controls_job_run.status == JobStatus.FAILED
+
+
+@pytest.mark.asyncio
+@pytest.mark.integration
+class TestRefreshClinvarControlsArqContext:
+    """Tests for running the refresh_clinvar_controls job function within an ARQ worker context."""
+
+    async def test_refresh_clinvar_controls_with_arq_context_independent(
+        self,
+        arq_redis,
+        arq_worker,
+        session,
+        with_populated_domain_data,
+        with_refresh_clinvar_controls_job,
+        sample_refresh_clinvar_controls_job_run,
+        setup_sample_variants_with_caid,
+    ):
+        """Integration test: job completes successfully within an ARQ worker context."""
+
+        # Patch external service calls
+        with (
+            patch(
+                "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id",
+                return_value="VCV000000123",
+            ),
+            patch.object(
+                _UnixSelectorEventLoop,
+                "run_in_executor",
+                return_value=mock_fetch_tsv(),
+            ),
+        ):
+            await arq_redis.enqueue_job("refresh_clinvar_controls", sample_refresh_clinvar_controls_job_run.id)
+            await arq_worker.async_run()
+            await arq_worker.run_check()
+
+        # Verify that clinical controls were added
+        clinical_controls = session.query(ClinicalControl).all()
+        assert len(clinical_controls) > 0
+
+        # Verify annotation status was created
+        annotation_statuses = session.query(VariantAnnotationStatus).all()
+        assert len(annotation_statuses) == 1
+        assert annotation_statuses[0].status == AnnotationStatus.SUCCESS
+        assert annotation_statuses[0].annotation_type == AnnotationType.CLINVAR_CONTROL
+
+        # Verify that the job completed successfully
+        session.refresh(sample_refresh_clinvar_controls_job_run)
+        assert sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED
+
+    async def test_refresh_clinvar_controls_with_arq_context_pipeline(
+        self,
+        arq_redis,
+        arq_worker,
+        session,
+        with_populated_domain_data,
+        with_refresh_clinvar_controls_job,
+        sample_refresh_clinvar_controls_pipeline,
+        sample_refresh_clinvar_controls_job_in_pipeline,
+        setup_sample_variants_with_caid,
+    ):
+        """Integration test: job completes successfully within an ARQ worker context in a pipeline context."""
+
+        # Patch external service calls
+        with (
+            patch(
+                "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id",
+                return_value="VCV000000123",
+            ),
+            patch.object(
+                _UnixSelectorEventLoop,
+                "run_in_executor",
+                return_value=mock_fetch_tsv(),
+            ),
+        ):
+            await arq_redis.enqueue_job(
+                "refresh_clinvar_controls", sample_refresh_clinvar_controls_job_in_pipeline.id
+            )
+            await arq_worker.async_run()
+            await arq_worker.run_check()
+
+        # Verify that clinical controls were added
+        clinical_controls = session.query(ClinicalControl).all()
+        assert len(clinical_controls) > 0
+
+        # Verify annotation status was created
+        annotation_statuses = session.query(VariantAnnotationStatus).all()
+        assert len(annotation_statuses) == 1
+        assert annotation_statuses[0].status == AnnotationStatus.SUCCESS
+        assert annotation_statuses[0].annotation_type == AnnotationType.CLINVAR_CONTROL
+
+        # Verify that the job completed successfully
+        session.refresh(sample_refresh_clinvar_controls_job_in_pipeline)
+        assert sample_refresh_clinvar_controls_job_in_pipeline.status == JobStatus.SUCCEEDED
+
+        # Verify the pipeline is marked as completed
+        session.refresh(sample_refresh_clinvar_controls_pipeline)
+        assert sample_refresh_clinvar_controls_pipeline.status == PipelineStatus.SUCCEEDED
+
+    async def test_refresh_clinvar_controls_with_arq_context_exception_handling_independent(
+        self,
+        arq_redis,
+        arq_worker,
+        session,
+        with_populated_domain_data,
+        with_refresh_clinvar_controls_job,
+        sample_refresh_clinvar_controls_job_run,
+        setup_sample_variants_with_caid,
+    ):
+        """Integration test: job handles exceptions properly within an ARQ worker context."""
+        # Patch external service calls to raise an exception
+        with (
+            patch(
+                "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id",
+                side_effect=ValueError("Unexpected error"),
+            ),
+            patch.object(
+                _UnixSelectorEventLoop,
+                "run_in_executor",
+                return_value=mock_fetch_tsv(),
+            ),
+        ):
+            await arq_redis.enqueue_job("refresh_clinvar_controls", sample_refresh_clinvar_controls_job_run.id)
+            await arq_worker.async_run()
+            await arq_worker.run_check()
+
+        # Verify no annotation statuses were created
+        annotation_statuses = session.query(VariantAnnotationStatus).all()
+        assert len(annotation_statuses) == 0
+
+        # Verify no clinical controls were added
+        clinical_controls = session.query(ClinicalControl).all()
+        assert len(clinical_controls) == 0
+
+        # Verify job run status is marked as failed
+        session.refresh(sample_refresh_clinvar_controls_job_run)
+        assert sample_refresh_clinvar_controls_job_run.status == JobStatus.FAILED
+
+    async def test_refresh_clinvar_controls_with_arq_context_exception_handling_pipeline(
+        self,
+        arq_redis,
+        arq_worker,
+        session,
+        with_populated_domain_data,
+        with_refresh_clinvar_controls_job,
+        sample_refresh_clinvar_controls_pipeline,
+        sample_refresh_clinvar_controls_job_in_pipeline,
+        setup_sample_variants_with_caid,
+    ):
+        """Integration test: job handles exceptions properly within an ARQ worker context in a pipeline context."""
+        # Patch external service calls to raise an exception
+        with (
+            patch(
+                "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id",
+                side_effect=ValueError("Unexpected error"),
+            ),
+            patch.object(
+                _UnixSelectorEventLoop,
+                "run_in_executor",
+                return_value=mock_fetch_tsv(),
+            ),
+        ):
+            await arq_redis.enqueue_job(
+                "refresh_clinvar_controls", sample_refresh_clinvar_controls_job_in_pipeline.id
+            )
+            await arq_worker.async_run()
+            await arq_worker.run_check()
+
+        # Verify no annotation statuses were created
+        annotation_statuses = session.query(VariantAnnotationStatus).all()
+        assert len(annotation_statuses) == 0
+
+        # Verify no clinical controls were added
+        clinical_controls = session.query(ClinicalControl).all()
+        assert len(clinical_controls) == 0
+
+        # Verify job run status is marked as failed
+        session.refresh(sample_refresh_clinvar_controls_job_in_pipeline)
+        assert sample_refresh_clinvar_controls_job_in_pipeline.status == JobStatus.FAILED
+
+        # Verify the pipeline is marked as failed
+        session.refresh(sample_refresh_clinvar_controls_pipeline)
+        assert sample_refresh_clinvar_controls_pipeline.status == PipelineStatus.FAILED
diff --git a/tests/worker/jobs/external_services/test_gnomad.py b/tests/worker/jobs/external_services/test_gnomad.py
new file mode 100644
index 00000000..92f515c1
--- /dev/null
+++ b/tests/worker/jobs/external_services/test_gnomad.py
@@ -0,0 +1,500 @@
+# ruff: noqa: E402
+
+import pytest
+
+pytest.importorskip("arq")
+
+from unittest.mock import MagicMock, call, patch
+
+from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus
+from mavedb.models.gnomad_variant import GnomADVariant
+from mavedb.models.mapped_variant import MappedVariant
+from mavedb.models.variant_annotation_status import VariantAnnotationStatus
+from mavedb.worker.jobs.external_services.gnomad import link_gnomad_variants
+from mavedb.worker.lib.managers.job_manager import JobManager
+
+pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr")
+
+
+@pytest.mark.asyncio
+@pytest.mark.unit
+class TestLinkGnomadVariantsUnit:
+    """Unit tests for the link_gnomad_variants job."""
+
+    async def test_link_gnomad_variants_no_variants_with_caids(
+        self,
+        session,
+        with_populated_domain_data,
+        with_gnomad_linking_job,
+        mock_worker_ctx,
+        sample_link_gnomad_variants_run,
+    ):
+        """Test linking gnomAD variants when no mapped variants have CAIDs."""
+        with patch.object(JobManager, "update_progress") as mock_update_progress:
+            result = await link_gnomad_variants(
+                mock_worker_ctx,
+                1,
+                JobManager(session, mock_worker_ctx["redis"], sample_link_gnomad_variants_run.id),
+            )
+
+        assert result["status"] == "ok"
+        mock_update_progress.assert_any_call(
+            100, 100, "No variants with CAIDs found to link to gnomAD variants. Nothing to do."
+ ) + + async def test_link_gnomad_variants_no_gnomad_matches( + self, + session, + with_populated_domain_data, + with_gnomad_linking_job, + mock_worker_ctx, + sample_link_gnomad_variants_run, + setup_sample_variants_with_caid, + athena_engine, + ): + """Test linking gnomAD variants when no gnomAD variants match the CAIDs.""" + + with ( + patch.object(JobManager, "update_progress") as mock_update_progress, + patch( + "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", + return_value={}, + ), + patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine), + ): + result = await link_gnomad_variants( + mock_worker_ctx, + 1, + JobManager(session, mock_worker_ctx["redis"], sample_link_gnomad_variants_run.id), + ) + + assert result["status"] == "ok" + mock_update_progress.assert_any_call(100, 100, "Linked 0 mapped variants to gnomAD variants.") + + async def test_link_gnomad_variants_call_linking_method( + self, + session, + with_populated_domain_data, + with_gnomad_linking_job, + mock_worker_ctx, + sample_link_gnomad_variants_run, + setup_sample_variants_with_caid, + athena_engine, + ): + """Test that the linking method is called when gnomAD variants match CAIDs.""" + + with ( + patch.object(JobManager, "update_progress") as mock_update_progress, + patch( + "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", + return_value=[MagicMock()], + ), + patch( + "mavedb.worker.jobs.external_services.gnomad.link_gnomad_variants_to_mapped_variants", + return_value=1, + ) as mock_linking_method, + patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine), + ): + result = await link_gnomad_variants( + mock_worker_ctx, + 1, + JobManager(session, mock_worker_ctx["redis"], sample_link_gnomad_variants_run.id), + ) + + assert result["status"] == "ok" + mock_linking_method.assert_called_once() + mock_update_progress.assert_any_call(100, 100, "Linked 1 mapped variants to gnomAD variants.") + + async def test_link_gnomad_variants_updates_progress( + self, + session, + with_populated_domain_data, + with_gnomad_linking_job, + mock_worker_ctx, + sample_link_gnomad_variants_run, + setup_sample_variants_with_caid, + athena_engine, + ): + """Test that progress updates are made during the linking process.""" + + with ( + patch.object(JobManager, "update_progress") as mock_update_progress, + patch( + "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", + return_value=[MagicMock()], + ), + patch( + "mavedb.worker.jobs.external_services.gnomad.link_gnomad_variants_to_mapped_variants", + return_value=1, + ), + patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine), + ): + result = await link_gnomad_variants( + mock_worker_ctx, + 1, + JobManager(session, mock_worker_ctx["redis"], sample_link_gnomad_variants_run.id), + ) + + assert result["status"] == "ok" + mock_update_progress.assert_has_calls( + [ + call(0, 100, "Starting gnomAD mapped resource linkage."), + call(10, 100, "Found 1 variants with CAIDs to link to gnomAD variants."), + call(75, 100, "Found 1 gnomAD variants matching CAIDs."), + call(100, 100, "Linked 1 mapped variants to gnomAD variants."), + ] + ) + + async def test_link_gnomad_variants_propagates_exceptions( + self, + session, + with_populated_domain_data, + with_gnomad_linking_job, + mock_worker_ctx, + sample_link_gnomad_variants_run, + setup_sample_variants_with_caid, + athena_engine, + ): + """Test that exceptions during the linking process are propagated.""" + 
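+        # This unit-level call is expected to let the exception escape to the
+        # caller; the integration suite below exercises the decorator path instead,
+        # where the same failure becomes a {"status": "exception"} result, the run
+        # is marked FAILED, and a Slack alert is sent.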
with ( + patch( + "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", + side_effect=Exception("Test exception"), + ), + patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine), + ): + with pytest.raises(Exception) as exc_info: + await link_gnomad_variants( + mock_worker_ctx, + 1, + JobManager(session, mock_worker_ctx["redis"], sample_link_gnomad_variants_run.id), + ) + + assert str(exc_info.value) == "Test exception" + + +@pytest.mark.asyncio +@pytest.mark.integration +class TestLinkGnomadVariantsIntegration: + """Integration tests for the link_gnomad_variants job.""" + + async def test_link_gnomad_variants_no_variants_with_caids( + self, + session, + with_populated_domain_data, + with_gnomad_linking_job, + mock_worker_ctx, + sample_link_gnomad_variants_run, + ): + """Test the end-to-end functionality of the link_gnomad_variants job when no variants have CAIDs.""" + + result = await link_gnomad_variants(mock_worker_ctx, sample_link_gnomad_variants_run.id) + assert result["status"] == "ok" + + # Verify that no gnomAD variants were linked + gnomad_variants = session.query(GnomADVariant).all() + assert len(gnomad_variants) == 0 + + # Verify no annotations were rendered (since there were no variants with CAIDs) + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + + # Verify job status updates + session.refresh(sample_link_gnomad_variants_run) + assert sample_link_gnomad_variants_run.status == JobStatus.SUCCEEDED + + async def test_link_gnomad_variants_no_matching_caids( + self, + session, + with_populated_domain_data, + with_gnomad_linking_job, + mock_worker_ctx, + sample_link_gnomad_variants_run, + setup_sample_variants_with_caid, + athena_engine, + ): + """Test the end-to-end functionality of the link_gnomad_variants job when no matching CAIDs are found.""" + # Update the created mapped variant to have a CAID that won't match any gnomAD data + mapped_variant = session.query(MappedVariant).first() + mapped_variant.clingen_allele_id = "NON_MATCHING_CAID" + session.commit() + + # Patch the athena engine to use the mock athena_engine fixture + with patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine): + result = await link_gnomad_variants(mock_worker_ctx, sample_link_gnomad_variants_run.id) + + assert result["status"] == "ok" + + # Verify that no gnomAD variants were linked + gnomad_variants = session.query(GnomADVariant).all() + assert len(gnomad_variants) == 0 + + # Verify a skipped annotation status was rendered (since there were variants with CAIDs) + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 1 + assert annotation_statuses[0].status == "skipped" + assert annotation_statuses[0].annotation_type == "gnomad_allele_frequency" + + # Verify job status updates + session.refresh(sample_link_gnomad_variants_run) + assert sample_link_gnomad_variants_run.status == JobStatus.SUCCEEDED + + async def test_link_gnomad_variants_successful_linking_independent( + self, + session, + with_populated_domain_data, + with_gnomad_linking_job, + mock_worker_ctx, + sample_link_gnomad_variants_run, + setup_sample_variants_with_caid, + athena_engine, + ): + """Test the end-to-end functionality of the link_gnomad_variants job with successful linking.""" + + # Patch the athena engine to use the mock athena_engine fixture + with patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine): + result = await 
link_gnomad_variants(mock_worker_ctx, sample_link_gnomad_variants_run.id) + + assert result["status"] == "ok" + + # Verify that gnomAD variants were linked + gnomad_variants = session.query(GnomADVariant).all() + assert len(gnomad_variants) > 0 + + # Verify annotation status was rendered + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 1 + assert annotation_statuses[0].status == "success" + assert annotation_statuses[0].annotation_type == "gnomad_allele_frequency" + + # Verify job status updates + session.refresh(sample_link_gnomad_variants_run) + assert sample_link_gnomad_variants_run.status == JobStatus.SUCCEEDED + + async def test_link_gnomad_variants_successful_linking_pipeline( + self, + session, + with_populated_domain_data, + mock_worker_ctx, + sample_link_gnomad_variants_run_pipeline, + sample_link_gnomad_variants_pipeline, + setup_sample_variants_with_caid, + athena_engine, + ): + """Test the end-to-end functionality of the link_gnomad_variants job with successful linking in a pipeline.""" + + # Patch the athena engine to use the mock athena_engine fixture + with patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine): + result = await link_gnomad_variants(mock_worker_ctx, sample_link_gnomad_variants_run_pipeline.id) + + assert result["status"] == "ok" + + # Verify that gnomAD variants were linked + gnomad_variants = session.query(GnomADVariant).all() + assert len(gnomad_variants) > 0 + + # Verify annotation status was rendered + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 1 + assert annotation_statuses[0].status == "success" + assert annotation_statuses[0].annotation_type == "gnomad_allele_frequency" + + # Verify job status updates + session.refresh(sample_link_gnomad_variants_run_pipeline) + assert sample_link_gnomad_variants_run_pipeline.status == JobStatus.SUCCEEDED + + # Verify pipeline status updates + session.refresh(sample_link_gnomad_variants_pipeline) + assert sample_link_gnomad_variants_pipeline.status == PipelineStatus.SUCCEEDED + + async def test_link_gnomad_variants_exceptions_handled_by_decorators( + self, + session, + with_populated_domain_data, + with_gnomad_linking_job, + mock_worker_ctx, + sample_link_gnomad_variants_run, + setup_sample_variants_with_caid, + athena_engine, + ): + """Test that exceptions during the linking process are handled by decorators.""" + + # Patch the athena engine to use the mock athena_engine fixture + with ( + patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine), + patch( + "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", + side_effect=Exception("Test exception"), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + result = await link_gnomad_variants( + mock_worker_ctx, + sample_link_gnomad_variants_run.id, + ) + + mock_send_slack_error.assert_called_once() + assert result["status"] == "exception" + assert isinstance(result["exception"], Exception) + + # Verify job status updates + session.refresh(sample_link_gnomad_variants_run) + assert sample_link_gnomad_variants_run.status == JobStatus.FAILED + + +@pytest.mark.asyncio +@pytest.mark.integration +class TestLinkGnomadVariantsArqContext: + """Tests for link_gnomad_variants job using the ARQ context fixture.""" + + async def test_link_gnomad_variants_with_arq_context_independent( + self, + arq_redis, + arq_worker, + session, + 
with_populated_domain_data, + with_gnomad_linking_job, + athena_engine, + sample_link_gnomad_variants_run, + setup_sample_variants_with_caid, + ): + """Test that the link_gnomad_variants job works with the ARQ context fixture.""" + + with ( + patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine), + ): + await arq_redis.enqueue_job("link_gnomad_variants", sample_link_gnomad_variants_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that gnomAD variants were linked + gnomad_variants = session.query(GnomADVariant).all() + assert len(gnomad_variants) > 0 + + # Verify annotation status was rendered + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 1 + assert annotation_statuses[0].status == "success" + assert annotation_statuses[0].annotation_type == "gnomad_allele_frequency" + + # Verify that the job completed successfully + session.refresh(sample_link_gnomad_variants_run) + assert sample_link_gnomad_variants_run.status == JobStatus.SUCCEEDED + + async def test_link_gnomad_variants_with_arq_context_pipeline( + self, + arq_redis, + arq_worker, + session, + with_populated_domain_data, + athena_engine, + sample_link_gnomad_variants_run_pipeline, + sample_link_gnomad_variants_pipeline, + setup_sample_variants_with_caid, + ): + """Test that the link_gnomad_variants job works with the ARQ context fixture in a pipeline.""" + + with ( + patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine), + ): + await arq_redis.enqueue_job("link_gnomad_variants", sample_link_gnomad_variants_run_pipeline.id) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that gnomAD variants were linked + gnomad_variants = session.query(GnomADVariant).all() + assert len(gnomad_variants) > 0 + + # Verify annotation status was rendered + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 1 + assert annotation_statuses[0].status == "success" + assert annotation_statuses[0].annotation_type == "gnomad_allele_frequency" + + # Verify that the job completed successfully + session.refresh(sample_link_gnomad_variants_run_pipeline) + assert sample_link_gnomad_variants_run_pipeline.status == JobStatus.SUCCEEDED + + # Verify pipeline status updates + session.refresh(sample_link_gnomad_variants_pipeline) + assert sample_link_gnomad_variants_pipeline.status == PipelineStatus.SUCCEEDED + + async def test_link_gnomad_variants_with_arq_context_exception_handling_independent( + self, + arq_redis, + arq_worker, + session, + with_populated_domain_data, + with_gnomad_linking_job, + athena_engine, + sample_link_gnomad_variants_run, + setup_sample_variants_with_caid, + ): + """Test that exceptions in the link_gnomad_variants job are handled with the ARQ context fixture.""" + + with ( + patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine), + patch( + "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", + side_effect=Exception("Test exception"), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + await arq_redis.enqueue_job("link_gnomad_variants", sample_link_gnomad_variants_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + mock_send_slack_error.assert_called_once() + # Verify that no gnomAD variants were linked + gnomad_variants = session.query(GnomADVariant).all() + assert len(gnomad_variants) == 0 + + 
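# The mocked exception fires before any linking work happens, so downstream annotation + # bookkeeping should be absent as well.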
# Verify no annotations were rendered + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + + # Verify that the job failed + session.refresh(sample_link_gnomad_variants_run) + assert sample_link_gnomad_variants_run.status == JobStatus.FAILED + + async def test_link_gnomad_variants_with_arq_context_exception_handling_pipeline( + self, + arq_redis, + arq_worker, + session, + with_populated_domain_data, + athena_engine, + sample_link_gnomad_variants_pipeline, + sample_link_gnomad_variants_run_pipeline, + setup_sample_variants_with_caid, + ): + """Test that exceptions in the link_gnomad_variants job are handled with the ARQ context fixture.""" + + with ( + patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine), + patch( + "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", + side_effect=Exception("Test exception"), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + await arq_redis.enqueue_job("link_gnomad_variants", sample_link_gnomad_variants_run_pipeline.id) + await arq_worker.async_run() + await arq_worker.run_check() + + mock_send_slack_error.assert_called_once() + # Verify that no gnomAD variants were linked + gnomad_variants = session.query(GnomADVariant).all() + assert len(gnomad_variants) == 0 + + # Verify no annotations were rendered + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + + # Verify that the job failed + session.refresh(sample_link_gnomad_variants_run_pipeline) + assert sample_link_gnomad_variants_run_pipeline.status == JobStatus.FAILED + + # Verify that the pipeline failed + session.refresh(sample_link_gnomad_variants_pipeline) + assert sample_link_gnomad_variants_pipeline.status == PipelineStatus.FAILED diff --git a/tests/worker/jobs/external_services/test_uniprot.py b/tests/worker/jobs/external_services/test_uniprot.py new file mode 100644 index 00000000..99ab3a07 --- /dev/null +++ b/tests/worker/jobs/external_services/test_uniprot.py @@ -0,0 +1,2048 @@ +# ruff: noqa: E402 + +import pytest + +pytest.importorskip("arq") + +from unittest.mock import call, patch + +from mavedb.lib.exceptions import ( + NonExistentTargetGeneError, + UniprotAmbiguousMappingResultError, + UniprotMappingResultNotFoundError, + UniProtPollingEnqueueError, +) +from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus +from mavedb.models.target_gene import TargetGene +from mavedb.models.target_sequence import TargetSequence +from mavedb.worker.jobs.external_services.uniprot import ( + poll_uniprot_mapping_jobs_for_score_set, + submit_uniprot_mapping_jobs_for_score_set, +) +from mavedb.worker.lib.managers.job_manager import JobManager +from tests.helpers.constants import ( + TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, + TEST_UNIPROT_SWISS_PROT_TYPE, + VALID_NT_ACCESSION, + VALID_UNIPROT_ACCESSION, +) + +pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestSubmitUniprotMappingJobsForScoreSetUnit: + """Unit tests for submit_uniprot_mapping_jobs_for_score_set function.""" + + async def test_submit_uniprot_mapping_jobs_no_targets( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + ): + """Test submitting UniProt mapping jobs when no target genes are present.""" + + # Ensure the sample 
score set has no target genes + sample_score_set.target_genes = [] + session.commit() + + with ( + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=session, + redis=mock_worker_ctx["redis"], + job_id=sample_submit_uniprot_mapping_jobs_run.id, + ), + ) + + mock_update_progress.assert_called_with( + 100, 100, "No target genes found. Skipped UniProt mapping job submission." + ) + assert job_result["status"] == "ok" + + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == {} + + async def test_submit_uniprot_mapping_jobs_no_acs_in_post_mapped_metadata( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + ): + """Test submitting UniProt mapping jobs when no ACs are present in post mapped metadata.""" + + with ( + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=session, + redis=mock_worker_ctx["redis"], + job_id=sample_submit_uniprot_mapping_jobs_run.id, + ), + ) + + mock_update_progress.assert_called_with(100, 100, "No UniProt mapping jobs were submitted.") + assert job_result["status"] == "ok" + + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == {} + + async def test_submit_uniprot_mapping_jobs_too_many_acs_in_post_mapped_metadata( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + ): + """Test submitting UniProt mapping jobs when too many ACs are present in post mapped metadata.""" + + # Arrange the post mapped metadata to have multiple ACs + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION, "P67890"]}} + session.commit() + + with ( + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=session, + redis=mock_worker_ctx["redis"], + job_id=sample_submit_uniprot_mapping_jobs_run.id, + ), + ) + + mock_update_progress.assert_called_with(100, 100, "No UniProt mapping jobs were submitted.") + assert job_result["status"] == "ok" + + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == {} + + async def test_submit_uniprot_mapping_jobs_no_jobs_submitted( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + ): + """Test submitting UniProt mapping jobs when no jobs are submitted.""" + + # Arrange the post mapped metadata to have a single AC + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": 
{"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value=None, + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=session, + redis=mock_worker_ctx["redis"], + job_id=sample_submit_uniprot_mapping_jobs_run.id, + ), + ) + + mock_update_progress.assert_called_with(100, 100, "No UniProt mapping jobs were submitted.") + assert job_result["status"] == "ok" + + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == { + "1": {"job_id": None, "accession": VALID_NT_ACCESSION} + } + + async def test_submit_uniprot_mapping_jobs_api_failure_raises( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + ): + """Test handling of UniProt API failure during job submission.""" + + # Arrange the post mapped metadata to have a single AC + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + side_effect=Exception("UniProt API failure"), + ), + patch.object(JobManager, "update_progress"), + pytest.raises(Exception, match="UniProt API failure"), + ): + await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=session, + redis=mock_worker_ctx["redis"], + job_id=sample_submit_uniprot_mapping_jobs_run.id, + ), + ) + + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == {} + + async def test_submit_uniprot_mapping_jobs_raises_dependent_job_not_available( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + ): + """Test handling when dependent polling job is not available.""" + + # Arrange the post mapped metadata to have a single AC + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=session, + redis=mock_worker_ctx["redis"], + job_id=sample_submit_uniprot_mapping_jobs_run.id, + ), + ) + + mock_update_progress.assert_called_with(100, 100, "Failed to submit UniProt mapping jobs.") + assert result["status"] == "failed" + assert isinstance(result["exception"], UniProtPollingEnqueueError) + + # Verify that the job metadata contains the submitted jobs (which were submitted before the error) + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == { + "1": 
{"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + + async def test_submit_uniprot_mapping_jobs_successful_submission( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, + ): + """Test successful submission of UniProt mapping jobs.""" + + # Arrange the post mapped metadata to have a single AC + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ), + patch.object(JobManager, "update_progress"), + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=session, + redis=mock_worker_ctx["redis"], + job_id=sample_submit_uniprot_mapping_jobs_run.id, + ), + ) + + assert job_result["status"] == "ok" + + expected_submitted_jobs = {"1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION}} + + # Verify that the job metadata contains the submitted job + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == expected_submitted_jobs + + # Verify that polling job params have been updated correctly + session.refresh(sample_dummy_polling_job_for_submission_run) + assert sample_dummy_polling_job_for_submission_run.job_params["mapping_jobs"] == expected_submitted_jobs + + async def test_submit_uniprot_mapping_jobs_partial_submission( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, + ): + """Test partial submission of UniProt mapping jobs.""" + + # Add another target gene to the score set to simulate multiple submissions + new_target_gene = TargetGene( + score_set_id=sample_score_set.id, + name="TP53", + category="protein_coding", + target_sequence=TargetSequence(sequence="MEEPQSDPSV", sequence_type="protein"), + ) + session.add(new_target_gene) + session.commit() + + # Arrange the post mapped metadata to have a single AC for both target genes + target_gene_1 = sample_score_set.target_genes[0] + target_gene_1.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + target_gene_2 = new_target_gene + target_gene_2.post_mapped_metadata = {"protein": {"sequence_accessions": ["NM_000546"]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + side_effect=["job_12345", None], + ), + patch.object(JobManager, "update_progress"), + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=session, + redis=mock_worker_ctx["redis"], + job_id=sample_submit_uniprot_mapping_jobs_run.id, + ), + ) + + assert job_result["status"] == "ok" + + expected_submitted_jobs = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION}, + "2": {"job_id": None, "accession": "NM_000546"}, + } + + # Verify that the job metadata contains both submitted and failed jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert 
sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == expected_submitted_jobs + + # Verify that polling job params have been updated correctly + session.refresh(sample_dummy_polling_job_for_submission_run) + assert sample_dummy_polling_job_for_submission_run.job_params["mapping_jobs"] == expected_submitted_jobs + + async def test_submit_uniprot_mapping_jobs_updates_progress( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + ): + """Test that progress updates are made during UniProt mapping job submission.""" + + # Arrange the post mapped metadata to have a single AC + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=session, + redis=mock_worker_ctx["redis"], + job_id=sample_submit_uniprot_mapping_jobs_run.id, + ), + ) + + assert job_result["status"] == "ok" + + # Verify that progress updates were made + mock_update_progress.assert_has_calls( + [ + call(0, 100, "Starting UniProt mapping job submission."), + call( + 95, 100, f"Submitted UniProt mapping job for target gene {sample_score_set.target_genes[0].name}." + ), + call(100, 100, "Completed submission of UniProt mapping jobs."), + ] + ) + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestSubmitUniprotMappingJobsForScoreSetIntegration: + """Integration tests for submit_uniprot_mapping_jobs_for_score_set function.""" + + async def test_submit_uniprot_mapping_jobs_success_independent_ctx( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, + ): + """Integration test for submitting UniProt mapping jobs.""" + + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ) as mock_submit_id_mapping, + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_submit_uniprot_mapping_jobs_run.id + ) + + mock_submit_id_mapping.assert_called_once() + assert job_result["status"] == "ok" + + expected_submitted_jobs = {"1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION}} + + # Verify that the job metadata contains the submitted job + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == expected_submitted_jobs + + # Verify that polling job params have been updated correctly + session.refresh(sample_dummy_polling_job_for_submission_run) + assert sample_dummy_polling_job_for_submission_run.job_params["mapping_jobs"] == expected_submitted_jobs + + # Verify that the submission job was completed successfully + 
session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.SUCCEEDED + + # Verify that the dependent polling job is still pending (non-pipeline ctx) + session.refresh(sample_dummy_polling_job_for_submission_run) + assert sample_dummy_polling_job_for_submission_run.status == JobStatus.PENDING + + async def test_submit_uniprot_mapping_jobs_success_pipeline_ctx( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_jobs_pipeline, + with_dummy_polling_job_for_submission_run, + sample_submit_uniprot_mapping_jobs_run_in_pipeline, + sample_submit_uniprot_mapping_jobs_pipeline, + sample_dummy_polling_job_for_submission_run_in_pipeline, + sample_score_set, + ): + """Integration test for submitting UniProt mapping jobs in a pipeline context.""" + + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ) as mock_submit_id_mapping, + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_submit_uniprot_mapping_jobs_run_in_pipeline.id + ) + + mock_submit_id_mapping.assert_called_once() + assert job_result["status"] == "ok" + + expected_submitted_jobs = {"1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION}} + + # Verify that the job metadata contains the submitted job + session.refresh(sample_submit_uniprot_mapping_jobs_run_in_pipeline) + assert sample_submit_uniprot_mapping_jobs_run_in_pipeline.metadata_["submitted_jobs"] == expected_submitted_jobs + + # Verify that polling job params have been updated correctly + session.refresh(sample_dummy_polling_job_for_submission_run_in_pipeline) + assert ( + sample_dummy_polling_job_for_submission_run_in_pipeline.job_params["mapping_jobs"] + == expected_submitted_jobs + ) + + # Verify that the submission job was completed successfully + session.refresh(sample_submit_uniprot_mapping_jobs_run_in_pipeline) + assert sample_submit_uniprot_mapping_jobs_run_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify that the dependent polling job is now queued (pipeline ctx) + session.refresh(sample_dummy_polling_job_for_submission_run_in_pipeline) + assert sample_dummy_polling_job_for_submission_run_in_pipeline.status == JobStatus.QUEUED + + # Verify that the pipeline run status is running + session.refresh(sample_submit_uniprot_mapping_jobs_pipeline) + assert sample_submit_uniprot_mapping_jobs_pipeline.status == PipelineStatus.RUNNING + + async def test_submit_uniprot_mapping_jobs_no_targets( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, + ): + """Integration test for submitting UniProt mapping jobs when no target genes are present.""" + + # Ensure the sample score set has no target genes + sample_score_set.target_genes = [] + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ) as mock_submit_id_mapping, + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, 
sample_submit_uniprot_mapping_jobs_run.id + ) + + mock_submit_id_mapping.assert_not_called() + assert job_result["status"] == "ok" + + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == {} + + # Verify that the submission job was completed successfully + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.SUCCEEDED + + # Verify that the dependent polling job is still pending and no param changes were made + assert sample_dummy_polling_job_for_submission_run.status == JobStatus.PENDING + assert sample_dummy_polling_job_for_submission_run.job_params.get("mapping_jobs") == {} + + async def test_submit_uniprot_mapping_jobs_no_acs_in_post_mapped_metadata( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, + ): + """Integration test for submitting UniProt mapping jobs when no ACs are present in post mapped metadata.""" + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ) as mock_submit_id_mapping, + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_submit_uniprot_mapping_jobs_run.id + ) + + mock_submit_id_mapping.assert_not_called() + assert job_result["status"] == "ok" + + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == {} + + # Verify that the submission job was completed successfully + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.SUCCEEDED + + # Verify that the dependent polling job is still pending and no param changes were made + assert sample_dummy_polling_job_for_submission_run.status == JobStatus.PENDING + assert sample_dummy_polling_job_for_submission_run.job_params.get("mapping_jobs") == {} + + async def test_submit_uniprot_mapping_jobs_too_many_acs_in_post_mapped_metadata( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, + ): + """Integration test for submitting UniProt mapping jobs when too many ACs are present in post mapped metadata.""" + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ) as mock_submit_id_mapping, + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_submit_uniprot_mapping_jobs_run.id + ) + + mock_submit_id_mapping.assert_not_called() + assert job_result["status"] == "ok" + + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == {} + + # Verify that the submission job was completed successfully + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.SUCCEEDED + + # 
Verify that the dependent polling job is still pending and no param changes were made + assert sample_dummy_polling_job_for_submission_run.status == JobStatus.PENDING + assert sample_dummy_polling_job_for_submission_run.job_params.get("mapping_jobs") == {} + + async def test_submit_uniprot_mapping_jobs_propagates_exceptions( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, + ): + """Integration test to ensure exceptions during UniProt mapping job submission are propagated to decorators.""" + + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + side_effect=Exception("UniProt API failure"), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_submit_uniprot_mapping_jobs_run.id + ) + + mock_send_slack_error.assert_called_once() + assert result["status"] == "exception" + assert isinstance(result["exception"], Exception) + + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_.get("submitted_jobs") is None + + # Verify that the submission job failed + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.FAILED + + # Verify that the dependent polling job is still pending and no param changes were made + assert sample_dummy_polling_job_for_submission_run.status == JobStatus.PENDING + assert sample_dummy_polling_job_for_submission_run.job_params.get("mapping_jobs") == {} + + async def test_submit_uniprot_mapping_jobs_no_jobs_submitted( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, + ): + """Integration test for submitting UniProt mapping jobs when no jobs are submitted.""" + + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value=None, + ), + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_submit_uniprot_mapping_jobs_run.id + ) + + assert job_result["status"] == "ok" + + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == { + "1": {"job_id": None, "accession": VALID_NT_ACCESSION} + } + + # Verify that the submission job was completed successfully + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.SUCCEEDED + + # Verify that the 
dependent polling job is still pending and no param changes were made + assert sample_dummy_polling_job_for_submission_run.status == JobStatus.PENDING + assert sample_dummy_polling_job_for_submission_run.job_params.get("mapping_jobs") == {} + + async def test_submit_uniprot_mapping_jobs_partial_submission( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, + ): + """Integration test for partial submission of UniProt mapping jobs.""" + + # Add another target gene to the score set to simulate multiple submissions + new_target_gene = TargetGene( + score_set_id=sample_score_set.id, + name="TP53", + category="protein_coding", + target_sequence=TargetSequence(sequence="MEEPQSDPSV", sequence_type="protein"), + ) + session.add(new_target_gene) + session.commit() + + # Add accessions to both target genes' post mapped metadata + for idx, tg in enumerate(sample_score_set.target_genes): + tg.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION + f"{idx:05d}"]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + side_effect=["job_12345", None], + ), + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_submit_uniprot_mapping_jobs_run.id + ) + + assert job_result["status"] == "ok" + + expected_submitted_jobs = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION + "00000"}, + "2": {"job_id": None, "accession": VALID_NT_ACCESSION + "00001"}, + } + + # Verify that the job metadata contains both submitted and failed jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == expected_submitted_jobs + + # Verify that the submission job was completed successfully + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.SUCCEEDED + + # Verify that the dependent polling job is still pending and params were updated correctly + assert sample_dummy_polling_job_for_submission_run.status == JobStatus.PENDING + assert sample_dummy_polling_job_for_submission_run.job_params.get("mapping_jobs") == expected_submitted_jobs + + async def test_submit_uniprot_mapping_jobs_no_dependent_job_raises( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + ): + """Integration test to ensure error is raised to the decorator when dependent polling job is not available.""" + + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_submit_uniprot_mapping_jobs_run.id + ) + + mock_send_slack_error.assert_called_once() + assert result["status"] == "failed" + assert isinstance(result["exception"], UniProtPollingEnqueueError) + + # 
Verify that the job metadata contains the job we submitted before the error + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + + # Verify that the submission job failed + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.FAILED + + # Nothing to verify for the dependent polling job since it does not exist + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestSubmitUniprotMappingJobsArqContext: + """Integration tests for submit_uniprot_mapping_jobs_for_score_set function in ARQ context.""" + + async def test_submit_uniprot_mapping_jobs_with_arq_context_independent( + self, + session, + arq_redis, + arq_worker, + athena_engine, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, + ): + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ), + ): + await arq_redis.enqueue_job( + "submit_uniprot_mapping_jobs_for_score_set", sample_submit_uniprot_mapping_jobs_run.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + expected_submitted_jobs = {"1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION}} + + # Verify that the job metadata contains the submitted job + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == expected_submitted_jobs + + # Verify that polling job params have been updated correctly + session.refresh(sample_dummy_polling_job_for_submission_run) + assert sample_dummy_polling_job_for_submission_run.job_params["mapping_jobs"] == expected_submitted_jobs + + # Verify that the submission job was completed successfully + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.SUCCEEDED + + # Verify that the dependent polling job is still pending (non-pipeline ctx) + session.refresh(sample_dummy_polling_job_for_submission_run) + assert sample_dummy_polling_job_for_submission_run.status == JobStatus.PENDING + + async def test_submit_uniprot_mapping_jobs_with_arq_context_pipeline( + self, + session, + arq_redis, + arq_worker, + athena_engine, + with_populated_domain_data, + with_submit_uniprot_mapping_jobs_pipeline, + with_dummy_polling_job_for_submission_run, + sample_submit_uniprot_mapping_jobs_run_in_pipeline, + sample_submit_uniprot_mapping_jobs_pipeline, + sample_dummy_polling_job_for_submission_run_in_pipeline, + sample_score_set, + ): + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ), + ): + await arq_redis.enqueue_job( + "submit_uniprot_mapping_jobs_for_score_set", 
sample_submit_uniprot_mapping_jobs_run_in_pipeline.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + expected_submitted_jobs = {"1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION}} + + # Verify that the job metadata contains the submitted job + session.refresh(sample_submit_uniprot_mapping_jobs_run_in_pipeline) + assert sample_submit_uniprot_mapping_jobs_run_in_pipeline.metadata_["submitted_jobs"] == expected_submitted_jobs + + # Verify that polling job params have been updated correctly + session.refresh(sample_dummy_polling_job_for_submission_run_in_pipeline) + assert ( + sample_dummy_polling_job_for_submission_run_in_pipeline.job_params["mapping_jobs"] + == expected_submitted_jobs + ) + + # Verify that the submission job was completed successfully + session.refresh(sample_submit_uniprot_mapping_jobs_run_in_pipeline) + assert sample_submit_uniprot_mapping_jobs_run_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify that the dependent polling job is now queued (pipeline ctx) + session.refresh(sample_dummy_polling_job_for_submission_run_in_pipeline) + assert sample_dummy_polling_job_for_submission_run_in_pipeline.status == JobStatus.QUEUED + + # Verify that the pipeline run status is running + session.refresh(sample_submit_uniprot_mapping_jobs_pipeline) + assert sample_submit_uniprot_mapping_jobs_pipeline.status == PipelineStatus.RUNNING + + async def test_submit_uniprot_mapping_jobs_with_arq_context_exception_handling_independent( + self, + session, + arq_redis, + arq_worker, + athena_engine, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, + ): + """Integration test to ensure exceptions during UniProt mapping job submission are propagated to decorators.""" + + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + side_effect=Exception("UniProt API failure"), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + await arq_redis.enqueue_job( + "submit_uniprot_mapping_jobs_for_score_set", sample_submit_uniprot_mapping_jobs_run.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + mock_send_slack_error.assert_called_once() + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_.get("submitted_jobs") is None + + # Verify that the submission job failed + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.FAILED + + # Verify that the dependent polling job is still pending and no param changes were made + assert sample_dummy_polling_job_for_submission_run.status == JobStatus.PENDING + assert sample_dummy_polling_job_for_submission_run.job_params.get("mapping_jobs") == {} + + async def test_submit_uniprot_mapping_jobs_with_arq_context_exception_handling_pipeline( + self, + session, + arq_redis, + arq_worker, + athena_engine, + with_populated_domain_data, + with_submit_uniprot_mapping_jobs_pipeline, + with_dummy_polling_job_for_submission_run, + 
sample_submit_uniprot_mapping_jobs_run_in_pipeline, + sample_submit_uniprot_mapping_jobs_pipeline, + sample_dummy_polling_job_for_submission_run_in_pipeline, + sample_score_set, + ): + """Integration test to ensure exceptions during UniProt mapping job submission are propagated to decorators.""" + + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + side_effect=Exception("UniProt API failure"), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + await arq_redis.enqueue_job( + "submit_uniprot_mapping_jobs_for_score_set", sample_submit_uniprot_mapping_jobs_run_in_pipeline.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + mock_send_slack_error.assert_called_once() + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run_in_pipeline) + assert sample_submit_uniprot_mapping_jobs_run_in_pipeline.metadata_.get("submitted_jobs") is None + + # Verify that the submission job failed + session.refresh(sample_submit_uniprot_mapping_jobs_run_in_pipeline) + assert sample_submit_uniprot_mapping_jobs_run_in_pipeline.status == JobStatus.FAILED + + # Verify that the dependent polling job is now skipped and no param changes were made + assert sample_dummy_polling_job_for_submission_run_in_pipeline.status == JobStatus.SKIPPED + assert sample_dummy_polling_job_for_submission_run_in_pipeline.job_params.get("mapping_jobs") == {} + + # Verify that the pipeline run status is failed + session.refresh(sample_submit_uniprot_mapping_jobs_pipeline) + assert sample_submit_uniprot_mapping_jobs_pipeline.status == PipelineStatus.FAILED + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestPollUniprotMappingJobsForScoreSetUnit: + """Unit tests for poll_uniprot_mapping_jobs_for_score_set function.""" + + async def test_poll_uniprot_mapping_jobs_no_mapping_jobs( + self, + session, + mock_worker_ctx, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Ensure there are no mapping jobs in the polling job params + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = {} + session.commit() + + with ( + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + job_result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=session, + redis=mock_worker_ctx["redis"], + job_id=sample_polling_job_for_submission_run.id, + ), + ) + + mock_update_progress.assert_called_with(100, 100, "No mapping jobs found to poll.") + assert job_result["status"] == "ok" + + # Verify the target gene uniprot id remains unchanged + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata is None + + # TODO:XXX -- We will eventually want to make sure the job indicates to the manager + # its desire to be retried. For now, we just verify that no changes are made + # when results are not ready. 
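# As a purely hypothetical sketch (request_retry is not an implemented API), such a + # signal might look like: + #     manager.request_retry(delay_seconds=job_run.retry_delay_seconds) + # letting the manager re-enqueue the run instead of completing it as a no-op.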
+ async def test_poll_uniprot_mapping_jobs_results_not_ready( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=False, + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + job_result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=session, + redis=mock_worker_ctx["redis"], + job_id=sample_polling_job_for_submission_run.id, + ), + ) + + assert job_result["status"] == "ok" + + # Verify that progress updates were made + mock_update_progress.assert_called_with(100, 100, "Completed polling of UniProt mapping jobs.") + + # Verify the target gene uniprot id remains unchanged + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata is None + + async def test_poll_uniprot_mapping_jobs_no_results( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=True, + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + return_value={"results": []}, # minimal response with no results + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + pytest.raises(UniprotMappingResultNotFoundError), + ): + await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=session, + redis=mock_worker_ctx["redis"], + job_id=sample_polling_job_for_submission_run.id, + ), + ) + + mock_update_progress.assert_called_with( + 100, 100, f"No UniProt ID found for accession {VALID_NT_ACCESSION}. Cannot add UniProt ID." 
+ ) + + async def test_poll_uniprot_mapping_jobs_ambiguous_results( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=True, + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + return_value={ + "results": [ + { + "from": VALID_NT_ACCESSION, + "to": { + "primaryAccession": f"{VALID_UNIPROT_ACCESSION}", + "entryType": TEST_UNIPROT_SWISS_PROT_TYPE, + }, + }, + { + "from": VALID_NT_ACCESSION, + "to": { + "primaryAccession": "P67890", + "entryType": TEST_UNIPROT_SWISS_PROT_TYPE, + }, + }, + ] + }, + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + pytest.raises(UniprotAmbiguousMappingResultError), + ): + await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=session, + redis=mock_worker_ctx["redis"], + job_id=sample_polling_job_for_submission_run.id, + ), + ) + + mock_update_progress.assert_called_with( + 100, + 100, + f"Ambiguous UniProt ID mapping results for accession {VALID_NT_ACCESSION}. Cannot add UniProt ID.", + ) + + async def test_poll_uniprot_mapping_jobs_nonexistent_target( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job with a non-existent target gene ID + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "999": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=True, + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + pytest.raises(NonExistentTargetGeneError), + ): + await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=session, + redis=mock_worker_ctx["redis"], + job_id=sample_polling_job_for_submission_run.id, + ), + ) + + mock_update_progress.assert_called_with( + 100, + 100, + f"Target gene ID 999 not found in score set {sample_score_set.urn}. 
Cannot add UniProt ID.", + ) + + async def test_poll_uniprot_mapping_jobs_successful_update( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=True, + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + job_result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=session, + redis=mock_worker_ctx["redis"], + job_id=sample_polling_job_for_submission_run.id, + ), + ) + + assert job_result["status"] == "ok" + + # Verify that progress updates were made + mock_update_progress.assert_called_with(100, 100, "Completed polling of UniProt mapping jobs.") + + # Verify the target gene uniprot id has been updated + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata == VALID_UNIPROT_ACCESSION + + async def test_poll_uniprot_mapping_jobs_partial_success( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have two mapping jobs + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION}, + "2": {"job_id": "job_67890", "accession": "NONEXISTENT_AC"}, + } + session.commit() + + # Add another target gene to the score set to correspond to the second mapping job + new_target_gene = TargetGene( + score_set_id=sample_score_set.id, + name="TP53", + category="protein_coding", + target_sequence=TargetSequence(sequence="MEEPQSDPSV", sequence_type="protein"), + ) + session.add(new_target_gene) + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + side_effect=[True, False], + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + side_effect=[ + TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, # Successful result for the first mapping job + {"results": []}, # No results for the second mapping job + ], + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + job_result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=session, + redis=mock_worker_ctx["redis"], + job_id=sample_polling_job_for_submission_run.id, + ), + ) + + assert job_result["status"] == "ok" + + # Verify that progress updates were made + mock_update_progress.assert_called_with(100, 100, "Completed polling of UniProt mapping jobs.") + + # Verify the target gene uniprot id has been updated for the successful mapping and + # remains None for the failed mapping + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata == VALID_UNIPROT_ACCESSION + assert 
sample_score_set.target_genes[1].uniprot_id_from_mapped_metadata is None + + async def test_poll_uniprot_mapping_jobs_updates_progress( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have one mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_11111", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + side_effect=[True, True, True], + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + side_effect=[TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE], + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + job_result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=session, + redis=mock_worker_ctx["redis"], + job_id=sample_polling_job_for_submission_run.id, + ), + ) + + assert job_result["status"] == "ok" + + # Verify that progress updates were made incrementally + mock_update_progress.assert_has_calls( + [ + call(0, 100, "Starting UniProt mapping job polling."), + call(95, 100, "Polled UniProt mapping job for target gene Sample Gene."), + call(100, 100, "Completed polling of UniProt mapping jobs."), + ] + ) + + # Verify the target gene uniprot ids have been updated + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata == VALID_UNIPROT_ACCESSION + + async def test_poll_uniprot_mapping_jobs_propagates_exceptions( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + side_effect=Exception("UniProt API failure"), + ), + pytest.raises(Exception) as exc_info, + ): + await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=session, + redis=mock_worker_ctx["redis"], + job_id=sample_polling_job_for_submission_run.id, + ), + ) + + assert str(exc_info.value) == "UniProt API failure" + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestPollUniprotMappingJobsForScoreSetIntegration: + """Integration tests for poll_uniprot_mapping_jobs_for_score_set function.""" + + async def test_poll_uniprot_mapping_jobs_success_independent_ctx( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + with_submit_uniprot_mapping_job, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + 
session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=True, + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, + ), + ): + job_result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_polling_job_for_submission_run.id + ) + + assert job_result["status"] == "ok" + + # Verify the target gene uniprot id has been updated + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata == VALID_UNIPROT_ACCESSION + + # Verify that the polling job was completed successfully + session.refresh(sample_polling_job_for_submission_run) + assert sample_polling_job_for_submission_run.status == JobStatus.SUCCEEDED + + async def test_poll_uniprot_mapping_jobs_success_pipeline_ctx( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_poll_uniprot_mapping_jobs_pipeline, + sample_score_set, + sample_poll_uniprot_mapping_jobs_run_in_pipeline, + sample_poll_uniprot_mapping_jobs_pipeline, + ): + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + # Arrange the polling job params to have a single mapping job + sample_poll_uniprot_mapping_jobs_run_in_pipeline.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=True, + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, + ), + ): + job_result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_poll_uniprot_mapping_jobs_run_in_pipeline.id + ) + + assert job_result["status"] == "ok" + + # Verify the target gene uniprot id has been updated + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata == VALID_UNIPROT_ACCESSION + + # Verify that the polling job was completed successfully + session.refresh(sample_poll_uniprot_mapping_jobs_run_in_pipeline) + assert sample_poll_uniprot_mapping_jobs_run_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify that the pipeline run status is succeeded (this is the only job in the test pipeline) + session.refresh(sample_poll_uniprot_mapping_jobs_pipeline) + assert sample_poll_uniprot_mapping_jobs_pipeline.status == PipelineStatus.SUCCEEDED + + async def test_poll_uniprot_mapping_jobs_no_mapping_jobs( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Ensure there are no mapping jobs in the polling job params + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = {} + session.commit() + + job_result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_polling_job_for_submission_run.id + ) + + assert job_result["status"] == "ok" + + # Verify the target gene uniprot id remains unchanged + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata is None + + 
# Verify that the polling job succeeded + session.refresh(sample_polling_job_for_submission_run) + assert sample_polling_job_for_submission_run.status == JobStatus.SUCCEEDED + + async def test_poll_uniprot_mapping_jobs_partial_mapping_jobs( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have two mapping jobs + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION}, + "2": {"job_id": None, "accession": "NONEXISTENT_AC"}, + } + session.commit() + + # Add another target gene to the score set to correspond to the second mapping job + new_target_gene = TargetGene( + score_set_id=sample_score_set.id, + name="TP53", + category="protein_coding", + target_sequence=TargetSequence(sequence="MEEPQSDPSV", sequence_type="protein"), + ) + session.add(new_target_gene) + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + side_effect=[True], + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + side_effect=[TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE], + ), + ): + job_result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_polling_job_for_submission_run.id + ) + + assert job_result["status"] == "ok" + + # Verify the target gene uniprot id has been updated for the successful mapping and + # remains None for the mapping with no job id + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata == VALID_UNIPROT_ACCESSION + assert sample_score_set.target_genes[1].uniprot_id_from_mapped_metadata is None + + # Verify that the polling job succeeded + session.refresh(sample_polling_job_for_submission_run) + assert sample_polling_job_for_submission_run.status == JobStatus.SUCCEEDED + + async def test_poll_uniprot_mapping_jobs_results_not_ready( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=False, + ): + job_result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_polling_job_for_submission_run.id + ) + + assert job_result["status"] == "ok" + + # Verify the target gene uniprot id remains unchanged + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata is None + + # Verify that the polling job succeeded + # TODO#XXX -- For now, we mark the job as succeeded even if no updates were made. + # In the future, we may want to have the job indicate it should be retried. 
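+        # (Hypothetical sketch, not current behavior: the job could signal this by returning
+        # a status other than "ok", or by raising a dedicated results-not-ready error, so that
+        # the managing decorator could re-enqueue the poll instead of marking it SUCCEEDED.)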
+ session.refresh(sample_polling_job_for_submission_run) + assert sample_polling_job_for_submission_run.status == JobStatus.SUCCEEDED + + async def test_poll_uniprot_mapping_jobs_no_results( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=True, + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + return_value={"results": []}, # minimal response with no results + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_polling_job_for_submission_run.id + ) + + mock_send_slack_error.assert_called_once() + assert result["status"] == "exception" + assert isinstance(result["exception"], UniprotMappingResultNotFoundError) + + # Verify the target gene uniprot id remains unchanged + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata is None + + # Verify that the polling job failed + session.refresh(sample_polling_job_for_submission_run) + assert sample_polling_job_for_submission_run.status == JobStatus.FAILED + + async def test_poll_uniprot_mapping_jobs_ambiguous_results( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=True, + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + return_value={ + "results": [ + { + "from": VALID_NT_ACCESSION, + "to": { + "primaryAccession": f"{VALID_UNIPROT_ACCESSION}", + "entryType": TEST_UNIPROT_SWISS_PROT_TYPE, + }, + }, + { + "from": VALID_NT_ACCESSION, + "to": { + "primaryAccession": "P67890", + "entryType": TEST_UNIPROT_SWISS_PROT_TYPE, + }, + }, + ] + }, + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_polling_job_for_submission_run.id + ) + + mock_send_slack_error.assert_called_once() + assert result["status"] == "exception" + assert isinstance(result["exception"], UniprotAmbiguousMappingResultError) + + # Verify the target gene uniprot id remains unchanged + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata is None + + # Verify that the polling job failed + session.refresh(sample_polling_job_for_submission_run) + assert sample_polling_job_for_submission_run.status == JobStatus.FAILED + + async def test_poll_uniprot_mapping_jobs_nonexistent_target( + self, + session, + mock_worker_ctx, + 
with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job with a non-existent target gene ID + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "999": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=True, + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_polling_job_for_submission_run.id + ) + + mock_send_slack_error.assert_called_once() + assert result["status"] == "exception" + assert isinstance(result["exception"], NonExistentTargetGeneError) + + # Verify the target gene uniprot id remains unchanged + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata is None + + # Verify that the polling job failed + session.refresh(sample_polling_job_for_submission_run) + assert sample_polling_job_for_submission_run.status == JobStatus.FAILED + + async def test_poll_uniprot_mapping_jobs_propagates_exceptions_to_decorator( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + side_effect=Exception("UniProt API failure"), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_polling_job_for_submission_run.id + ) + + mock_send_slack_error.assert_called_once() + assert result["status"] == "exception" + assert isinstance(result["exception"], Exception) + + # Verify the target gene uniprot id remains unchanged + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata is None + + # Verify that the polling job failed + session.refresh(sample_polling_job_for_submission_run) + assert sample_polling_job_for_submission_run.status == JobStatus.FAILED + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestPollUniprotMappingJobsForScoreSetArqContext: + """Integration tests for poll_uniprot_mapping_jobs_for_score_set function with ARQ context.""" + + async def test_poll_uniprot_mapping_jobs_with_arq_context_independent( + self, + session, + arq_worker, + arq_redis, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + with_submit_uniprot_mapping_job, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + 
session.commit() + + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=True, + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, + ), + ): + await arq_redis.enqueue_job( + "poll_uniprot_mapping_jobs_for_score_set", sample_polling_job_for_submission_run.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify the target gene uniprot id has been updated + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata == VALID_UNIPROT_ACCESSION + + # Verify that the polling job was completed successfully + session.refresh(sample_polling_job_for_submission_run) + assert sample_polling_job_for_submission_run.status == JobStatus.SUCCEEDED + + async def test_poll_uniprot_mapping_jobs_with_arq_context_pipeline( + self, + session, + arq_worker, + arq_redis, + with_populated_domain_data, + with_poll_uniprot_mapping_jobs_pipeline, + sample_score_set, + sample_poll_uniprot_mapping_jobs_run_in_pipeline, + sample_poll_uniprot_mapping_jobs_pipeline, + ): + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + # Arrange the polling job params to have a single mapping job + sample_poll_uniprot_mapping_jobs_run_in_pipeline.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=True, + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, + ), + ): + await arq_redis.enqueue_job( + "poll_uniprot_mapping_jobs_for_score_set", + sample_poll_uniprot_mapping_jobs_run_in_pipeline.id, + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify the target gene uniprot id has been updated + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata == VALID_UNIPROT_ACCESSION + + # Verify that the polling job was completed successfully + session.refresh(sample_poll_uniprot_mapping_jobs_run_in_pipeline) + assert sample_poll_uniprot_mapping_jobs_run_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify that the pipeline run status is succeeded (this is the only job in the test pipeline) + session.refresh(sample_poll_uniprot_mapping_jobs_pipeline) + assert sample_poll_uniprot_mapping_jobs_pipeline.status == PipelineStatus.SUCCEEDED + + async def test_poll_uniprot_mapping_jobs_with_arq_context_exception_handling_independent( + self, + session, + arq_worker, + arq_redis, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", 
"accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + side_effect=Exception("UniProt API failure"), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + await arq_redis.enqueue_job( + "poll_uniprot_mapping_jobs_for_score_set", sample_polling_job_for_submission_run.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + mock_send_slack_error.assert_called_once() + # Verify that the polling job failed + session.refresh(sample_polling_job_for_submission_run) + assert sample_polling_job_for_submission_run.status == JobStatus.FAILED + + # Verify the target gene uniprot id remains unchanged + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata is None + + async def test_poll_uniprot_mapping_jobs_with_arq_context_exception_handling_pipeline( + self, + session, + arq_worker, + arq_redis, + mock_worker_ctx, + with_populated_domain_data, + with_poll_uniprot_mapping_jobs_pipeline, + sample_score_set, + sample_poll_uniprot_mapping_jobs_run_in_pipeline, + sample_poll_uniprot_mapping_jobs_pipeline, + ): + # Arrange the polling job params to have a single mapping job + sample_poll_uniprot_mapping_jobs_run_in_pipeline.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + side_effect=Exception("UniProt API failure"), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + await arq_redis.enqueue_job( + "poll_uniprot_mapping_jobs_for_score_set", + sample_poll_uniprot_mapping_jobs_run_in_pipeline.id, + ) + await arq_worker.async_run() + await arq_worker.run_check() + + mock_send_slack_error.assert_called_once() + # Verify that the polling job failed + session.refresh(sample_poll_uniprot_mapping_jobs_run_in_pipeline) + assert sample_poll_uniprot_mapping_jobs_run_in_pipeline.status == JobStatus.FAILED + + # Verify that the pipeline run status is failed + session.refresh(sample_poll_uniprot_mapping_jobs_pipeline) + assert sample_poll_uniprot_mapping_jobs_pipeline.status == PipelineStatus.FAILED + + # Verify the target gene uniprot id remains unchanged + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata is None diff --git a/tests/worker/jobs/pipeline_management/test_start_pipeline.py b/tests/worker/jobs/pipeline_management/test_start_pipeline.py new file mode 100644 index 00000000..08179374 --- /dev/null +++ b/tests/worker/jobs/pipeline_management/test_start_pipeline.py @@ -0,0 +1,316 @@ +# ruff: noqa: E402 + +import pytest + +pytest.importorskip("arq") + +from unittest.mock import call, patch + +from sqlalchemy import select + +from mavedb.lib.exceptions import PipelineNotFoundError +from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus +from mavedb.models.job_run import JobRun +from mavedb.worker.jobs.pipeline_management.start_pipeline import start_pipeline +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.pipeline_manager import PipelineManager + +pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestStartPipelineUnit: + """Unit tests for 
starting pipelines.""" + + @pytest.fixture(autouse=True) + def setup_start_pipeline_job_run(self, session, with_dummy_pipeline, sample_dummy_pipeline): + """Fixture to ensure a start pipeline job run exists in the database.""" + job_run = JobRun( + pipeline_id=sample_dummy_pipeline.id, + job_type="start_pipeline", + job_function="start_pipeline", + ) + session.add(job_run) + session.commit() + + return job_run + + async def test_start_pipeline_raises_exception_when_no_pipeline_associated_with_job( + self, + session, + mock_worker_ctx, + setup_start_pipeline_job_run, + ): + """Test that starting a pipeline raises an exception when no pipeline is associated with the job.""" + + # Remove pipeline association from job run + setup_start_pipeline_job_run.pipeline_id = None + session.commit() + + result = await start_pipeline( + mock_worker_ctx, + setup_start_pipeline_job_run.id, + JobManager(session, mock_worker_ctx["redis"], setup_start_pipeline_job_run.id), + ) + + assert result["status"] == "exception" + assert isinstance(result["exception"], PipelineNotFoundError) + + async def test_start_pipeline_starts_pipeline_successfully( + self, + session, + mock_worker_ctx, + mock_pipeline_manager, + setup_start_pipeline_job_run, + ): + """Test that starting a pipeline completes successfully.""" + + with ( + patch("mavedb.worker.lib.managers.pipeline_manager.PipelineManager") as mock_pipeline_manager_class, + patch.object(PipelineManager, "coordinate_pipeline", return_value=None) as mock_coordinate_pipeline, + ): + mock_pipeline_manager_class.return_value = mock_pipeline_manager + + result = await start_pipeline( + mock_worker_ctx, + setup_start_pipeline_job_run.id, + JobManager(session, mock_worker_ctx["redis"], setup_start_pipeline_job_run.id), + ) + + assert result["status"] == "ok" + mock_coordinate_pipeline.assert_called_once() + + async def test_start_pipeline_updates_progress( + self, + session, + mock_worker_ctx, + mock_pipeline_manager, + setup_start_pipeline_job_run, + ): + """Test that starting a pipeline updates job progress.""" + + with ( + patch("mavedb.worker.lib.managers.pipeline_manager.PipelineManager") as mock_pipeline_manager_class, + patch.object(PipelineManager, "coordinate_pipeline", return_value=None), + patch.object( + JobManager, + "update_progress", + return_value=None, + ) as mock_update_progress, + ): + mock_pipeline_manager_class.return_value = mock_pipeline_manager + + result = await start_pipeline( + mock_worker_ctx, + setup_start_pipeline_job_run.id, + JobManager(session, mock_worker_ctx["redis"], setup_start_pipeline_job_run.id), + ) + + assert result["status"] == "ok" + + mock_update_progress.assert_has_calls( + [ + call(0, 100, "Coordinating pipeline for the first time."), + call(100, 100, "Initial pipeline coordination complete."), + ] + ) + + async def test_start_pipeline_raises_exception( + self, + session, + mock_worker_ctx, + mock_pipeline_manager, + setup_start_pipeline_job_run, + ): + """Test that starting a pipeline raises an exception.""" + + with ( + patch("mavedb.worker.lib.managers.pipeline_manager.PipelineManager") as mock_pipeline_manager_class, + patch.object( + PipelineManager, + "coordinate_pipeline", + side_effect=Exception("Simulated pipeline start failure"), + ), + pytest.raises(Exception, match="Simulated pipeline start failure"), + ): + mock_pipeline_manager_class.return_value = mock_pipeline_manager + + await start_pipeline( + mock_worker_ctx, + setup_start_pipeline_job_run.id, + JobManager(session, mock_worker_ctx["redis"], 
setup_start_pipeline_job_run.id), + ) + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestStartPipelineIntegration: + """Integration tests for starting pipelines.""" + + async def test_start_pipeline_on_job_without_pipeline_fails( + self, + session, + mock_worker_ctx, + with_full_dummy_pipeline, + sample_dummy_pipeline_start, + ): + """Test that starting a pipeline on a job without an associated pipeline fails.""" + + sample_dummy_pipeline_start.pipeline_id = None + session.commit() + + with patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error: + result = await start_pipeline(mock_worker_ctx, sample_dummy_pipeline_start.id) + assert result["status"] == "exception" + mock_send_slack_error.assert_called_once() + + # Verify the start job run status + session.refresh(sample_dummy_pipeline_start) + assert sample_dummy_pipeline_start.status == JobStatus.FAILED + + async def test_start_pipeline_on_valid_job_succeeds_and_coordinates_pipeline( + self, session, mock_worker_ctx, with_full_dummy_pipeline, sample_dummy_pipeline_start, sample_dummy_pipeline + ): + """Test that starting a pipeline on a valid job succeeds and coordinates the pipeline.""" + + result = await start_pipeline(mock_worker_ctx, sample_dummy_pipeline_start.id) + assert result["status"] == "ok" + + # Verify the start job run status + session.refresh(sample_dummy_pipeline_start) + assert sample_dummy_pipeline_start.status == JobStatus.SUCCEEDED + + # Verify that the pipeline state is updated appropriately + session.refresh(sample_dummy_pipeline) + assert sample_dummy_pipeline.status == PipelineStatus.RUNNING + + async def test_start_pipeline_handles_exceptions_gracefully( + self, + session, + mock_worker_ctx, + with_full_dummy_pipeline, + sample_dummy_pipeline, + sample_dummy_pipeline_start, + ): + """Test that starting a pipeline handles exceptions gracefully.""" + # Mock a coordination failure during pipeline start. Realistically if this failed in pipeline start + # it would likely also fail during the final coordination attempt in the exception handler, but for testing purposes + # we only mock the initial failure here. In a real-world scenario, we'd likely have to rely on our alerting here and + # intervene manually or via a separate recovery job to fix the pipeline state. 
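+        # The side effect below fails only the first coordination call and delegates any
+        # subsequent call to the real coordinate_pipeline (through a fresh PipelineManager),
+        # so the exception handler's final coordination attempt can still run.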
+        real_coordinate_pipeline = PipelineManager.coordinate_pipeline
+        call_count = {"n": 0}
+
+        async def custom_side_effect(*args, **kwargs):
+            if call_count["n"] == 0:
+                call_count["n"] += 1
+                raise Exception("Simulated pipeline start failure")
+            return await real_coordinate_pipeline(
+                PipelineManager(session, session, sample_dummy_pipeline.id), *args, **kwargs
+            )  # Allow the final coordination attempt to proceed 'normally'
+
+        with (
+            patch(
+                "mavedb.worker.lib.managers.pipeline_manager.PipelineManager.coordinate_pipeline",
+                side_effect=custom_side_effect,
+            ),
+            patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error,
+        ):
+            result = await start_pipeline(mock_worker_ctx, sample_dummy_pipeline_start.id)
+            assert result["status"] == "exception"
+            mock_send_slack_error.assert_called_once()
+
+        # Verify the start job run status
+        session.refresh(sample_dummy_pipeline_start)
+        assert sample_dummy_pipeline_start.status == JobStatus.FAILED
+
+        # Verify that the pipeline state is updated to FAILED
+        session.refresh(sample_dummy_pipeline)
+        assert sample_dummy_pipeline.status == PipelineStatus.FAILED
+
+    async def test_start_pipeline_no_jobs_in_pipeline(
+        self,
+        session,
+        mock_worker_ctx,
+        with_dummy_pipeline,
+        sample_dummy_pipeline_start,
+        sample_dummy_pipeline,
+    ):
+        """Test starting a pipeline that has no jobs defined."""
+
+        result = await start_pipeline(mock_worker_ctx, sample_dummy_pipeline_start.id)
+        assert result["status"] == "ok"
+
+        # Verify that a JobRun was created for the start_pipeline job and it succeeded
+        session.refresh(sample_dummy_pipeline_start)
+        assert sample_dummy_pipeline_start.status == JobStatus.SUCCEEDED
+
+        # Verify that the pipeline state is updated appropriately
+        session.refresh(sample_dummy_pipeline)
+        assert sample_dummy_pipeline.status == PipelineStatus.SUCCEEDED
+
+
+@pytest.mark.integration
+@pytest.mark.asyncio
+class TestStartPipelineArqContext:
+    """Test starting pipelines using an ARQ worker context."""
+
+    async def test_start_pipeline_with_arq_context(
+        self,
+        session,
+        arq_redis,
+        arq_worker,
+        with_full_dummy_pipeline,
+        sample_dummy_pipeline_start,
+        sample_dummy_pipeline,
+    ):
+        """Test starting a pipeline using an ARQ worker context."""
+
+        await arq_redis.enqueue_job("start_pipeline", sample_dummy_pipeline_start.id)
+        await arq_worker.async_run()
+        await arq_worker.run_check()
+
+        # Verify the start job run status
+        session.refresh(sample_dummy_pipeline_start)
+        assert sample_dummy_pipeline_start.status == JobStatus.SUCCEEDED
+
+        # Verify that the pipeline state is updated appropriately
+        session.refresh(sample_dummy_pipeline)
+        assert sample_dummy_pipeline.status == PipelineStatus.RUNNING
+
+        # Verify that other pipeline steps have been queued
+        pipeline_steps = (
+            session.execute(
+                select(JobRun).where(
+                    JobRun.pipeline_id == sample_dummy_pipeline.id, JobRun.id != sample_dummy_pipeline_start.id
+                )
+            )
+            .scalars()
+            .all()
+        )
+        assert len(pipeline_steps) == 1
+        assert pipeline_steps[0].job_type == "dummy_step"
+        assert pipeline_steps[0].status == JobStatus.QUEUED
+
+    async def test_start_pipeline_with_arq_context_no_jobs_in_pipeline(
+        self,
+        session,
+        arq_redis,
+        arq_worker,
+        with_dummy_pipeline,
+        sample_dummy_pipeline_start,
+        sample_dummy_pipeline,
+    ):
+        """Test starting a pipeline with no jobs using an ARQ worker context."""
+
+        await arq_redis.enqueue_job("start_pipeline", sample_dummy_pipeline_start.id)
+        await arq_worker.async_run()
+        await arq_worker.run_check()
+
+        # 
Verify that a JobRun was created for the start_pipeline job and it succeeded + session.refresh(sample_dummy_pipeline_start) + assert sample_dummy_pipeline_start.status == JobStatus.SUCCEEDED + + # Verify that the pipeline state is updated appropriately + session.refresh(sample_dummy_pipeline) + assert sample_dummy_pipeline.status == PipelineStatus.SUCCEEDED diff --git a/tests/worker/jobs/utils/test_setup.py b/tests/worker/jobs/utils/test_setup.py new file mode 100644 index 00000000..70c40759 --- /dev/null +++ b/tests/worker/jobs/utils/test_setup.py @@ -0,0 +1,34 @@ +# ruff: noqa: E402 + +import pytest + +pytest.importorskip("arq") + +from unittest.mock import Mock + +from mavedb.models.job_run import JobRun +from mavedb.worker.jobs.utils.setup import validate_job_params + + +@pytest.mark.unit +def test_validate_job_params_success(): + job = Mock(spec=JobRun, job_params={"foo": 1, "bar": 2}) + + # Should not raise + validate_job_params(["foo", "bar"], job) + + +@pytest.mark.unit +def test_validate_job_params_missing_param(): + job = Mock(spec=JobRun, job_params={"foo": 1}) + + with pytest.raises(ValueError, match="Missing required job param: bar"): + validate_job_params(["foo", "bar"], job) + + +@pytest.mark.unit +def test_validate_job_params_no_params(): + job = Mock(spec=JobRun, job_params=None) + + with pytest.raises(ValueError, match="Job has no job_params defined."): + validate_job_params(["foo"], job) diff --git a/tests/worker/jobs/variant_processing/test_creation.py b/tests/worker/jobs/variant_processing/test_creation.py new file mode 100644 index 00000000..b2b15fca --- /dev/null +++ b/tests/worker/jobs/variant_processing/test_creation.py @@ -0,0 +1,1400 @@ +# ruff: noqa: E402 + +import pytest + +pytest.importorskip("arq") + +import math +from unittest.mock import ANY, MagicMock, call, patch + +from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus +from mavedb.models.enums.mapping_state import MappingState +from mavedb.models.enums.processing_state import ProcessingState +from mavedb.models.job_run import JobRun +from mavedb.models.pipeline import Pipeline +from mavedb.models.variant import Variant +from mavedb.worker.jobs.variant_processing.creation import create_variants_for_score_set +from mavedb.worker.lib.managers.job_manager import JobManager + +pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + + +@pytest.mark.unit +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_db_session_ctxmgr") +class TestCreateVariantsForScoreSetUnit: + """Unit tests for create_variants_for_score_set job.""" + + async def test_create_variants_for_score_set_raises_key_error_on_missing_hdp_from_ctx( + self, + mock_worker_ctx, + mock_job_manager, + ): + ctx = mock_worker_ctx.copy() + del ctx["hdp"] + + with pytest.raises(KeyError) as exc_info: + await create_variants_for_score_set(ctx, 999, mock_job_manager) + + assert str(exc_info.value) == "'hdp'" + + async def test_create_variants_for_score_set_calls_s3_client_with_correct_parameters( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None) as mock_download_fileobj, + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, 
sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + return_value=( + sample_score_dataframe, + sample_count_dataframe, + create_variants_sample_params["score_columns_metadata"], + create_variants_sample_params["count_columns_metadata"], + ), + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.create_variants_data", + return_value=[MagicMock(spec=Variant)], + ), + patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), + ): + await create_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), + ) + + # Use ANY for dynamically created Fileobj parameters. + mock_download_fileobj.assert_has_calls( + [ + call(Bucket="score-set-csv-uploads-dev", Key="sample_scores.csv", Fileobj=ANY), + call(Bucket="score-set-csv-uploads-dev", Key="sample_counts.csv", Fileobj=ANY), + ] + ) + + async def test_create_variants_for_score_set_s3_file_not_found( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + sample_score_set, + sample_independent_variant_creation_run, + ): + with ( + patch.object( + mock_s3_client, + "download_fileobj", + side_effect=Exception("The specified key does not exist."), + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + result = await create_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), + ) + + mock_update_progress.assert_any_call(100, 100, "Variant creation job failed due to an internal error.") + assert result["status"] == "exception" + assert isinstance(result["exception"], Exception) + session.refresh(sample_score_set) + assert sample_score_set.processing_state == ProcessingState.failed + assert sample_score_set.mapping_state == MappingState.not_attempted + + async def test_create_variants_for_score_set_counts_file_can_be_optional( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + # Remove counts_file_key to test optional behavior + create_variants_sample_params_without_counts = create_variants_sample_params.copy() + create_variants_sample_params_without_counts["counts_file_key"] = None + create_variants_sample_params_without_counts["count_columns_metadata"] = None + sample_independent_variant_creation_run.job_params = create_variants_sample_params_without_counts + session.add(sample_independent_variant_creation_run) + session.commit() + + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample score dataframe only + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + return_value=( + sample_score_dataframe, + None, + create_variants_sample_params_without_counts["score_columns_metadata"], + None, + ), + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.create_variants_data", + return_value=[MagicMock(spec=Variant)], + ), + 
patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), + ): + await create_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), + ) + + async def test_create_variants_for_score_set_raises_when_no_targets_exist( + self, + session, + with_independent_processing_runs, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + # Remove all TargetGene entries to simulate no targets existing + sample_score_set.target_genes = [] + session.commit() + + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + result = await create_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), + ) + + mock_update_progress.assert_any_call(100, 100, "Score set has no targets; cannot create variants.") + assert result["status"] == "exception" + assert isinstance(result["exception"], ValueError) + + async def test_create_variants_for_score_set_calls_validate_standardize_dataframe_with_correct_parameters( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + return_value=( + sample_score_dataframe, + sample_count_dataframe, + create_variants_sample_params["score_columns_metadata"], + create_variants_sample_params["count_columns_metadata"], + ), + ) as mock_validate, + patch( + "mavedb.worker.jobs.variant_processing.creation.create_variants_data", + return_value=[MagicMock(spec=Variant)], + ), + patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), + ): + await create_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), + ) + + mock_validate.assert_called_once_with( + scores_df=sample_score_dataframe, + counts_df=sample_count_dataframe, + score_columns_metadata=create_variants_sample_params["score_columns_metadata"], + count_columns_metadata=create_variants_sample_params["count_columns_metadata"], + targets=sample_score_set.target_genes, + hdp=mock_worker_ctx["hdp"], + ) + + async def test_create_variants_for_score_set_calls_create_variants_data_with_correct_parameters( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + 
sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + return_value=( + sample_score_dataframe, + sample_count_dataframe, + create_variants_sample_params["score_columns_metadata"], + create_variants_sample_params["count_columns_metadata"], + ), + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.create_variants_data", + return_value=[MagicMock(spec=Variant)], + ) as mock_create_variants_data, + patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), + ): + await create_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), + ) + + mock_create_variants_data.assert_called_once_with(sample_score_dataframe, sample_count_dataframe, None) + + async def test_create_variants_for_score_set_calls_create_variants_with_correct_parameters( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + mock_variant = MagicMock(spec=Variant) + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + return_value=( + sample_score_dataframe, + sample_count_dataframe, + create_variants_sample_params["score_columns_metadata"], + create_variants_sample_params["count_columns_metadata"], + ), + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.create_variants_data", + return_value=[mock_variant], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.create_variants", + return_value=None, + ) as mock_create_variants, + ): + await create_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), + ) + + mock_create_variants.assert_called_once_with(session, sample_score_set, [mock_variant]) + + async def test_create_variants_for_score_set_handles_empty_variant_data( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + return_value=( + sample_score_dataframe, + sample_count_dataframe, + 
create_variants_sample_params["score_columns_metadata"], + create_variants_sample_params["count_columns_metadata"], + ), + ), + patch("mavedb.worker.jobs.variant_processing.creation.create_variants_data", return_value=[]), + patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), + ): + await create_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), + ) + # If no exceptions are raised, the test passes for handling empty variant data. + + async def test_create_variants_for_score_set_removes_existing_variants_before_creation( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + # Add existing variants to the score set to test removal + sample_score_set.num_variants = 1 + variant = Variant(data={}, score_set_id=sample_score_set.id) + session.add(variant) + session.commit() + + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + return_value=( + sample_score_dataframe, + sample_count_dataframe, + create_variants_sample_params["score_columns_metadata"], + create_variants_sample_params["count_columns_metadata"], + ), + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.create_variants_data", + return_value=[MagicMock(spec=Variant)], + ), + patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), + ): + await create_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), + ) + + # Verify that existing variants have been removed + remaining_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(remaining_variants) == 0 + session.refresh(sample_score_set) + assert sample_score_set.num_variants == 0 # Updated after creation + + async def test_create_variants_for_score_set_updates_processing_state( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + return_value=( + sample_score_dataframe, + sample_count_dataframe, + create_variants_sample_params["score_columns_metadata"], + create_variants_sample_params["count_columns_metadata"], + ), + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.create_variants_data", + return_value=[MagicMock(spec=Variant)], + ), + 
patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), + ): + await create_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), + ) + + session.refresh(sample_score_set) + assert sample_score_set.processing_state == ProcessingState.success + assert sample_score_set.mapping_state == MappingState.queued + assert sample_score_set.processing_errors is None + + async def test_create_variants_for_score_set_updates_progress( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + return_value=( + sample_score_dataframe, + sample_count_dataframe, + create_variants_sample_params["score_columns_metadata"], + create_variants_sample_params["count_columns_metadata"], + ), + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.create_variants_data", + return_value=[MagicMock(spec=Variant)], + ), + patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + await create_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), + ) + + mock_update_progress.assert_has_calls( + [ + call(0, 100, "Starting variant creation job."), + call(10, 100, "Validated score set metadata and beginning data validation."), + call(80, 100, "Data validation complete; creating variants in database."), + call(100, 100, "Completed variant creation job."), + ] + ) + + async def test_create_variants_for_score_set_retains_existing_variants_when_exception_occurs( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + # Add existing variants to the score set to test retention on failure + sample_score_set.num_variants = 1 + variant = Variant(data={}, score_set_id=sample_score_set.id) + session.add(variant) + session.commit() + + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + side_effect=Exception("Test exception during data validation"), + ), + ): + result = await create_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), + ) + + assert result["status"] == "exception" + assert 
isinstance(result["exception"], Exception) + + # Verify that existing variants are still present + remaining_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(remaining_variants) == 1 + session.refresh(sample_score_set) + assert sample_score_set.num_variants == 1 # Should remain unchanged + + async def test_create_variants_for_score_set_handles_exception_and_updates_state( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + side_effect=Exception("Test exception during data validation"), + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + result = await create_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), + ) + + assert result["status"] == "exception" + assert isinstance(result["exception"], Exception) + + # Verify that the score set's processing state is updated to failed + session.refresh(sample_score_set) + assert sample_score_set.processing_state == ProcessingState.failed + assert sample_score_set.mapping_state == MappingState.not_attempted + assert "Test exception during data validation" in sample_score_set.processing_errors["exception"] + mock_update_progress.assert_any_call(100, 100, "Variant creation job failed due to an internal error.") + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestCreateVariantsForScoreSetIntegration: + """Integration tests for create_variants_for_score_set job.""" + + ## Common success workflows + + async def test_create_variants_for_score_set_independent_job( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + with ( + # Assume the S3 client works as expected. + # + # Moto is omitted here for brevity since this + # function doesn't have S3 side effects. We assume the file is already in S3 for this test, + # and any cases where the file is not present will be handled by the job manager and tested + # in unit tests. + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes. + # + # A side effect of not mocking S3 more thoroughly + # is that our S3 download has no return value and just side effects data into a file-like object, + # so we mock pd.read_csv directly to avoid it trying to read from an empty file. 
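+            # side_effect returns the score frame on the first read_csv call and the count
+            # frame on the second, matching the order in which the job reads the two files.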
+ patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + ): + await create_variants_for_score_set(mock_worker_ctx, sample_independent_variant_creation_run.id) + + # Verify that variants have been created in the database + created_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(created_variants) == sample_score_dataframe.shape[0] + session.refresh(sample_score_set) + assert sample_score_set.num_variants == len(created_variants) + assert sample_score_set.processing_state == ProcessingState.success + assert sample_score_set.mapping_state == MappingState.queued + + # Verify that the created variants have expected data + for variant in created_variants: + assert variant.data # Ensure data is not empty + assert "score_data" in variant.data # Ensure score_data is present + expected_score = sample_score_dataframe.loc[ + sample_score_dataframe["hgvs_nt"] == variant.hgvs_nt, "score" + ].values[0] + actual_score = variant.data["score_data"]["score"] + if actual_score is None and (isinstance(expected_score, float) and math.isnan(expected_score)): + pass # None in variant, NaN in DataFrame: OK + else: + assert actual_score == expected_score # Ensure score matches + assert "count_data" in variant.data # Ensure count_data is present + expected_count = sample_count_dataframe.loc[ + sample_count_dataframe["hgvs_nt"] == variant.hgvs_nt, "c_0" + ].values[0] + actual_count = variant.data["count_data"]["c_0"] + if actual_count is None and (isinstance(expected_count, float) and math.isnan(expected_count)): + pass # None in variant, NaN in DataFrame: OK + else: + assert actual_count == expected_count # Ensure count matches + + # Verify that no extra variants were created + all_variants = session.query(Variant).all() + assert len(all_variants) == len(created_variants) + + # Verify that job state is as expected + job_run = ( + session.query(sample_independent_variant_creation_run.__class__) + .filter(sample_independent_variant_creation_run.__class__.id == sample_independent_variant_creation_run.id) + .one() + ) + assert job_run.progress_current == 100 + assert job_run.status == JobStatus.SUCCEEDED + + async def test_create_variants_for_score_set_pipeline_job( + self, + session, + with_variant_creation_pipeline_runs, + sample_variant_creation_pipeline, + sample_pipeline_variant_creation_run, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes. 
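+            # Same S3 and read_csv stubbing as the independent-job test above; only the
+            # pipeline-scoped run fixture differs.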
+ patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + ): + await create_variants_for_score_set(mock_worker_ctx, sample_pipeline_variant_creation_run.id) + + # Verify that variants have been created in the database + created_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(created_variants) == sample_score_dataframe.shape[0] + session.refresh(sample_score_set) + assert sample_score_set.num_variants == len(created_variants) + assert sample_score_set.processing_state == ProcessingState.success + assert sample_score_set.mapping_state == MappingState.queued + + # Verify that the created variants have expected data + for variant in created_variants: + assert variant.data # Ensure data is not empty + assert "score_data" in variant.data # Ensure score_data is present + expected_score = sample_score_dataframe.loc[ + sample_score_dataframe["hgvs_nt"] == variant.hgvs_nt, "score" + ].values[0] + actual_score = variant.data["score_data"]["score"] + if actual_score is None and (isinstance(expected_score, float) and math.isnan(expected_score)): + pass # None in variant, NaN in DataFrame: OK + else: + assert actual_score == expected_score # Ensure score matches + assert "count_data" in variant.data # Ensure count_data is present + expected_count = sample_count_dataframe.loc[ + sample_count_dataframe["hgvs_nt"] == variant.hgvs_nt, "c_0" + ].values[0] + actual_count = variant.data["count_data"]["c_0"] + if actual_count is None and (isinstance(expected_count, float) and math.isnan(expected_count)): + pass # None in variant, NaN in DataFrame: OK + else: + assert actual_count == expected_count # Ensure count matches + + # Verify that no extra variants were created + all_variants = session.query(Variant).all() + assert len(all_variants) == len(created_variants) + + # Verify that pipeline job state is as expected + job_run = ( + session.query(sample_pipeline_variant_creation_run.__class__) + .filter(sample_pipeline_variant_creation_run.__class__.id == sample_pipeline_variant_creation_run.id) + .one() + ) + assert job_run.progress_current == 100 + assert job_run.status == JobStatus.SUCCEEDED + + # Verify that pipeline status is updated. Pipeline will remain RUNNING + # as our default test pipeline includes the mapping job as well. 
+ session.refresh(sample_variant_creation_pipeline) + assert sample_variant_creation_pipeline.status == PipelineStatus.RUNNING + + ## Common edge cases + + async def test_create_variants_for_score_set_replaces_variants( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + # First run to create initial variants + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + ): + await create_variants_for_score_set(mock_worker_ctx, sample_independent_variant_creation_run.id) + + initial_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(initial_variants) == sample_score_dataframe.shape[0] + + # Modify dataframes to simulate updated data + updated_score_dataframe = sample_score_dataframe.copy() + updated_score_dataframe["score"] += 10 # Increment scores by 10 + + updated_count_dataframe = sample_count_dataframe.copy() + updated_count_dataframe["c_0"] += 5 # Increment counts by 5 + + # Mock a second run with updated dataframes + sample_independent_variant_creation_run.status = JobStatus.PENDING + session.commit() + + # Second run to replace existing variants + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[updated_score_dataframe, updated_count_dataframe], + ), + ): + await create_variants_for_score_set(mock_worker_ctx, sample_independent_variant_creation_run.id) + + replaced_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(replaced_variants) == sample_score_dataframe.shape[0] + + # Verify that the variants have been replaced with updated data + for variant in replaced_variants: + assert variant.data # Ensure data is not empty + assert "score_data" in variant.data # Ensure score_data is present + expected_score = updated_score_dataframe.loc[ + updated_score_dataframe["hgvs_nt"] == variant.hgvs_nt, "score" + ].values[0] + actual_score = variant.data["score_data"]["score"] + if actual_score is None and (isinstance(expected_score, float) and math.isnan(expected_score)): + pass # None in variant, NaN in DataFrame: OK + else: + assert actual_score == expected_score # Ensure score matches + assert "count_data" in variant.data # Ensure count_data is present + expected_count = updated_count_dataframe.loc[ + updated_count_dataframe["hgvs_nt"] == variant.hgvs_nt, "c_0" + ].values[0] + actual_count = variant.data["count_data"]["c_0"] + if actual_count is None and (isinstance(expected_count, float) and math.isnan(expected_count)): + pass # None in variant, NaN in DataFrame: OK + else: + assert actual_count == expected_count # Ensure count matches + + # Verify that no extra variants were created + all_variants = session.query(Variant).all() + assert len(all_variants) == len(replaced_variants) + + # Verify that job state is as expected + job_run = ( + session.query(sample_independent_variant_creation_run.__class__) + .filter(sample_independent_variant_creation_run.__class__.id == sample_independent_variant_creation_run.id) + .one() + ) + assert job_run.progress_current == 100 + assert job_run.status == JobStatus.SUCCEEDED + + async def 
test_create_variants_for_score_set_handles_missing_counts_file( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + sample_score_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + sample_independent_variant_creation_run.job_params["counts_file_key"] = None + sample_independent_variant_creation_run.job_params["count_columns_metadata"] = {} + session.commit() + + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return only the score dataframe + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe], + ), + ): + await create_variants_for_score_set(mock_worker_ctx, sample_independent_variant_creation_run.id) + + # Verify that variants have been created in the database + created_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(created_variants) == sample_score_dataframe.shape[0] + session.refresh(sample_score_set) + assert sample_score_set.num_variants == len(created_variants) + assert sample_score_set.processing_state == ProcessingState.success + assert sample_score_set.mapping_state == MappingState.queued + + # Verify that the created variants have expected data + for variant in created_variants: + assert variant.data # Ensure data is not empty + assert "score_data" in variant.data # Ensure score_data is present + expected_score = sample_score_dataframe.loc[ + sample_score_dataframe["hgvs_nt"] == variant.hgvs_nt, "score" + ].values[0] + actual_score = variant.data["score_data"]["score"] + if actual_score is None and (isinstance(expected_score, float) and math.isnan(expected_score)): + pass # None in variant, NaN in DataFrame: OK + else: + assert actual_score == expected_score # Ensure score matches + assert "count_data" in variant.data # Ensure count_data is present but... 
+ assert variant.data["count_data"] == {} # ...ensure count_data is empty since no counts file was provided + + # Verify that no extra variants were created + all_variants = session.query(Variant).all() + assert len(all_variants) == len(created_variants) + + # Verify that job state is as expected + job_run = ( + session.query(sample_independent_variant_creation_run.__class__) + .filter(sample_independent_variant_creation_run.__class__.id == sample_independent_variant_creation_run.id) + .one() + ) + assert job_run.progress_current == 100 + assert job_run.status == JobStatus.SUCCEEDED + + ## Common failure workflows + + async def test_create_variants_for_score_set_validation_error_during_creation( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + sample_score_dataframe.loc[0, "hgvs_nt"] = "c.G>X" # Introduce invalid value to trigger validation error + + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + await create_variants_for_score_set(mock_worker_ctx, sample_independent_variant_creation_run.id) + + mock_send_slack_error.assert_called_once() + # Verify that the score set's processing state is updated to failed + session.refresh(sample_score_set) + assert sample_score_set.processing_state == ProcessingState.failed + assert sample_score_set.mapping_state == MappingState.not_attempted + assert "encountered 1 invalid variant strings" in sample_score_set.processing_errors["exception"] + assert len(sample_score_set.processing_errors["detail"]) > 0 + + # Verify that no variants were created + created_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(created_variants) == 0 + + # Verify that job state is as expected + job_run = ( + session.query(sample_independent_variant_creation_run.__class__) + .filter(sample_independent_variant_creation_run.__class__.id == sample_independent_variant_creation_run.id) + .one() + ) + assert job_run.progress_current == 100 + assert job_run.status == JobStatus.FAILED + + async def test_create_variants_for_score_set_generic_exception_handling_during_creation( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + side_effect=Exception("Generic exception during data validation"), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + await create_variants_for_score_set(mock_worker_ctx, sample_independent_variant_creation_run.id) + + 
mock_send_slack_error.assert_called_once() + # Verify that the score set's processing state is updated to failed + session.refresh(sample_score_set) + assert sample_score_set.processing_state == ProcessingState.failed + assert sample_score_set.mapping_state == MappingState.not_attempted + assert "Generic exception during data validation" in sample_score_set.processing_errors["exception"] + + # Verify that job state is as expected + job_run = ( + session.query(sample_independent_variant_creation_run.__class__) + .filter(sample_independent_variant_creation_run.__class__.id == sample_independent_variant_creation_run.id) + .one() + ) + assert job_run.progress_current == 100 + assert job_run.status == JobStatus.FAILED + + async def test_create_variants_for_score_set_generic_exception_handling_during_replacement( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + # First run to create initial variants + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + ): + await create_variants_for_score_set(mock_worker_ctx, sample_independent_variant_creation_run.id) + + initial_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(initial_variants) == sample_score_dataframe.shape[0] + + # Mock a second run to replace existing variants + sample_independent_variant_creation_run.status = JobStatus.PENDING + session.commit() + + # Second run to replace existing variants but trigger a generic exception + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + side_effect=Exception("Generic exception during data validation"), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + await create_variants_for_score_set(mock_worker_ctx, sample_independent_variant_creation_run.id) + + mock_send_slack_error.assert_called_once() + # Verify that the score set's processing state is updated to failed + session.refresh(sample_score_set) + assert sample_score_set.processing_state == ProcessingState.failed + assert sample_score_set.mapping_state == MappingState.not_attempted + assert "Generic exception during data validation" in sample_score_set.processing_errors["exception"] + + # Verify that initial variants are still present + remaining_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(remaining_variants) == len(initial_variants) + + # Verify that job state is as expected + job_run = ( + session.query(sample_independent_variant_creation_run.__class__) + .filter(sample_independent_variant_creation_run.__class__.id == sample_independent_variant_creation_run.id) + .one() + ) + assert job_run.progress_current == 100 + assert job_run.status == JobStatus.FAILED + + ## Pipeline failure workflow + + async def test_create_variants_for_score_set_pipeline_job_generic_exception_handling( + self, + session, + 
with_variant_creation_pipeline_runs,
+        sample_variant_creation_pipeline,
+        sample_pipeline_variant_creation_run,
+        with_populated_domain_data,
+        mock_worker_ctx,
+        mock_s3_client,
+        create_variants_sample_params,
+        sample_score_dataframe,
+        sample_count_dataframe,
+        sample_score_set,
+    ):
+        with (
+            patch.object(mock_s3_client, "download_fileobj", return_value=None),
+            # Mock pd.read_csv to return sample dataframes
+            patch(
+                "mavedb.worker.jobs.variant_processing.creation.pd.read_csv",
+                side_effect=[sample_score_dataframe, sample_count_dataframe],
+            ),
+            patch(
+                "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair",
+                side_effect=Exception("Generic exception during data validation"),
+            ),
+            patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error,
+        ):
+            await create_variants_for_score_set(mock_worker_ctx, sample_pipeline_variant_creation_run.id)
+
+        mock_send_slack_error.assert_called_once()
+        # Verify that the score set's processing state is updated to failed
+        session.refresh(sample_score_set)
+        assert sample_score_set.processing_state == ProcessingState.failed
+        assert sample_score_set.mapping_state == MappingState.not_attempted
+        assert "Generic exception during data validation" in sample_score_set.processing_errors["exception"]
+
+        # Verify that job state is as expected
+        job_run = (
+            session.query(sample_pipeline_variant_creation_run.__class__)
+            .filter(sample_pipeline_variant_creation_run.__class__.id == sample_pipeline_variant_creation_run.id)
+            .one()
+        )
+        assert job_run.progress_current == 100
+        assert job_run.status == JobStatus.FAILED
+
+        # Verify that pipeline status is updated.
+        session.refresh(sample_variant_creation_pipeline)
+        assert sample_variant_creation_pipeline.status == PipelineStatus.FAILED
+        # Verify other pipeline runs are marked as skipped
+        other_runs = (
+            session.query(JobRun)
+            .filter(
+                JobRun.pipeline_id == sample_variant_creation_pipeline.id,
+                JobRun.id != sample_pipeline_variant_creation_run.id,
+            )
+            .all()
+        )
+        for run in other_runs:
+            assert run.status == JobStatus.SKIPPED
+
+
+@pytest.mark.asyncio
+@pytest.mark.integration
+class TestCreateVariantsForScoreSetArqContext:
+    """Integration tests for create_variants_for_score_set job using ARQ worker context."""
+
+    async def test_create_variants_for_score_set_with_arq_context_independent_ctx(
+        self,
+        session,
+        arq_redis,
+        arq_worker,
+        with_independent_processing_runs,
+        with_populated_domain_data,
+        mock_s3_client,
+        sample_score_dataframe,
+        sample_count_dataframe,
+        sample_score_set,
+        sample_independent_variant_creation_run,
+    ):
+        with (
+            patch.object(mock_s3_client, "download_fileobj", return_value=None),
+            # Mock pd.read_csv to return sample dataframes.
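+            # The enqueue/async_run/run_check cycle below drives the job through a real
+            # arq worker; run_check re-raises job failures, so the assertions that follow
+            # only execute when the worker ran the job cleanly.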
+ patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + ): + await arq_redis.enqueue_job("create_variants_for_score_set", sample_independent_variant_creation_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that variants have been created in the database + created_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(created_variants) == sample_score_dataframe.shape[0] + session.refresh(sample_score_set) + assert sample_score_set.num_variants == len(created_variants) + assert sample_score_set.processing_state == ProcessingState.success + assert sample_score_set.mapping_state == MappingState.queued + + # Verify that the created variants have expected data + for variant in created_variants: + assert variant.data # Ensure data is not empty + assert "score_data" in variant.data # Ensure score_data is present + expected_score = sample_score_dataframe.loc[ + sample_score_dataframe["hgvs_nt"] == variant.hgvs_nt, "score" + ].values[0] + actual_score = variant.data["score_data"]["score"] + if actual_score is None and (isinstance(expected_score, float) and math.isnan(expected_score)): + pass # None in variant, NaN in DataFrame: OK + else: + assert actual_score == expected_score # Ensure score matches + assert "count_data" in variant.data # Ensure count_data is present + expected_count = sample_count_dataframe.loc[ + sample_count_dataframe["hgvs_nt"] == variant.hgvs_nt, "c_0" + ].values[0] + actual_count = variant.data["count_data"]["c_0"] + if actual_count is None and (isinstance(expected_count, float) and math.isnan(expected_count)): + pass # None in variant, NaN in DataFrame: OK + else: + assert actual_count == expected_count # Ensure count matches + + # Verify that no extra variants were created + all_variants = session.query(Variant).all() + assert len(all_variants) == len(created_variants) + + # Verify that job state is as expected + job_run = ( + session.query(sample_independent_variant_creation_run.__class__) + .filter(sample_independent_variant_creation_run.__class__.id == sample_independent_variant_creation_run.id) + .one() + ) + assert job_run.progress_current == 100 + assert job_run.status == JobStatus.SUCCEEDED + + async def test_create_variants_for_score_set_with_arq_context_pipeline_ctx( + self, + session, + arq_redis, + arq_worker, + with_variant_creation_pipeline_runs, + sample_variant_creation_pipeline, + sample_pipeline_variant_creation_run, + with_populated_domain_data, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes. 
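+            # As in the independent arq test above, the worker cycle should leave this job
+            # SUCCEEDED while the parent pipeline stays RUNNING (its mapping step remains).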
+ patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + ): + await arq_redis.enqueue_job("create_variants_for_score_set", sample_pipeline_variant_creation_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that variants have been created in the database + created_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(created_variants) == sample_score_dataframe.shape[0] + session.refresh(sample_score_set) + assert sample_score_set.num_variants == len(created_variants) + assert sample_score_set.processing_state == ProcessingState.success + assert sample_score_set.mapping_state == MappingState.queued + + # Verify that the created variants have expected data + for variant in created_variants: + assert variant.data # Ensure data is not empty + assert "score_data" in variant.data # Ensure score_data is present + expected_score = sample_score_dataframe.loc[ + sample_score_dataframe["hgvs_nt"] == variant.hgvs_nt, "score" + ].values[0] + actual_score = variant.data["score_data"]["score"] + if actual_score is None and (isinstance(expected_score, float) and math.isnan(expected_score)): + pass # None in variant, NaN in DataFrame: OK + else: + assert actual_score == expected_score # Ensure score matches + assert "count_data" in variant.data # Ensure count_data is present + expected_count = sample_count_dataframe.loc[ + sample_count_dataframe["hgvs_nt"] == variant.hgvs_nt, "c_0" + ].values[0] + actual_count = variant.data["count_data"]["c_0"] + if actual_count is None and (isinstance(expected_count, float) and math.isnan(expected_count)): + pass # None in variant, NaN in DataFrame: OK + else: + assert actual_count == expected_count # Ensure count matches + + # Verify that no extra variants were created + all_variants = session.query(Variant).all() + assert len(all_variants) == len(created_variants) + + # Verify that pipeline job state is as expected + job_run = ( + session.query(sample_pipeline_variant_creation_run.__class__) + .filter(sample_pipeline_variant_creation_run.__class__.id == sample_pipeline_variant_creation_run.id) + .one() + ) + assert job_run.progress_current == 100 + assert job_run.status == JobStatus.SUCCEEDED + + # Verify that pipeline status is updated. Pipeline will remain RUNNING + # as our default test pipeline includes the mapping job as well. 
+ session.refresh(sample_variant_creation_pipeline) + assert sample_variant_creation_pipeline.status == PipelineStatus.RUNNING + + async def test_create_variants_for_score_set_with_arq_context_generic_exception_handling_independent_ctx( + self, + session, + arq_redis, + arq_worker, + with_variant_creation_pipeline_runs, + sample_variant_creation_pipeline, + sample_independent_variant_creation_run, + with_populated_domain_data, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + side_effect=Exception("Generic exception during data validation"), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + await arq_redis.enqueue_job("create_variants_for_score_set", sample_independent_variant_creation_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + mock_send_slack_error.assert_called_once() + # Verify that the score set's processing state is updated to failed + session.refresh(sample_score_set) + assert sample_score_set.processing_state == ProcessingState.failed + assert sample_score_set.mapping_state == MappingState.not_attempted + assert "Generic exception during data validation" in sample_score_set.processing_errors["exception"] + + # Verify that job state is as expected + job_run = ( + session.query(sample_independent_variant_creation_run.__class__) + .filter(sample_independent_variant_creation_run.__class__.id == sample_independent_variant_creation_run.id) + .one() + ) + assert job_run.progress_current == 100 + assert job_run.status == JobStatus.FAILED + + async def test_create_variants_for_score_set_with_arq_context_generic_exception_handling_pipeline_ctx( + self, + session, + arq_redis, + arq_worker, + with_variant_creation_pipeline_runs, + sample_variant_creation_pipeline, + sample_pipeline_variant_creation_run, + with_populated_domain_data, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + side_effect=Exception("Generic exception during data validation"), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + await arq_redis.enqueue_job("create_variants_for_score_set", sample_pipeline_variant_creation_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + mock_send_slack_error.assert_called_once() + # Verify that the score set's processing state is updated to failed + session.refresh(sample_score_set) + assert sample_score_set.processing_state == ProcessingState.failed + assert sample_score_set.mapping_state == MappingState.not_attempted + assert "Generic exception during data validation" in sample_score_set.processing_errors["exception"] + + # Verify 
that job state is as expected
+        job_run = (
+            session.query(sample_pipeline_variant_creation_run.__class__)
+            .filter(sample_pipeline_variant_creation_run.__class__.id == sample_pipeline_variant_creation_run.id)
+            .one()
+        )
+        assert job_run.progress_current == 100
+        assert job_run.status == JobStatus.FAILED
+
+        # Verify that pipeline status is updated.
+        session.refresh(sample_variant_creation_pipeline)
+        assert sample_variant_creation_pipeline.status == PipelineStatus.FAILED
+
+        # Verify other pipeline runs are marked as skipped
+        other_runs = (
+            session.query(JobRun)
+            .filter(
+                JobRun.pipeline_id == sample_variant_creation_pipeline.id,
+                JobRun.id != sample_pipeline_variant_creation_run.id,
+            )
+            .all()
+        )
+        for run in other_runs:
+            assert run.status == JobStatus.SKIPPED
diff --git a/tests/worker/jobs/variant_processing/test_mapping.py b/tests/worker/jobs/variant_processing/test_mapping.py
new file mode 100644
index 00000000..61357984
--- /dev/null
+++ b/tests/worker/jobs/variant_processing/test_mapping.py
@@ -0,0 +1,1867 @@
+# ruff: noqa: E402
+
+import pytest
+
+pytest.importorskip("arq")
+
+from asyncio.unix_events import _UnixSelectorEventLoop
+from unittest.mock import MagicMock, call, patch
+
+from sqlalchemy.exc import NoResultFound
+
+from mavedb.lib.exceptions import (
+    NoMappedVariantsError,
+    NonexistentMappingReferenceError,
+    NonexistentMappingResultsError,
+    NonexistentMappingScoresError,
+)
+from mavedb.lib.mapping import EXCLUDED_PREMAPPED_ANNOTATION_KEYS
+from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus
+from mavedb.models.enums.mapping_state import MappingState
+from mavedb.models.mapped_variant import MappedVariant
+from mavedb.models.variant import Variant
+from mavedb.models.variant_annotation_status import VariantAnnotationStatus
+from mavedb.worker.jobs.variant_processing.mapping import map_variants_for_score_set
+from mavedb.worker.lib.managers.job_manager import JobManager
+from tests.helpers.constants import TEST_CODING_LAYER, TEST_GENOMIC_LAYER, TEST_PROTEIN_LAYER
+from tests.helpers.util.setup.worker import construct_mock_mapping_output, create_variants_in_score_set
+
+pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr")
+
+
+@pytest.mark.unit
+@pytest.mark.asyncio
+class TestMapVariantsForScoreSetUnit:
+    """Unit tests for map_variants_for_score_set job."""
+
+    async def dummy_mapping_output(self, output_data=None):
+        return output_data if output_data is not None else {}
+
+    async def test_map_variants_for_score_set_no_mapping_results(
+        self,
+        session,
+        with_independent_processing_runs,
+        mock_worker_ctx,
+        sample_independent_variant_mapping_run,
+        sample_score_set,
+    ):
+        """Test mapping variants when no mapping results are found."""
+
+        # Network requests occur within an event loop. Mock result of mapping call
+        # with return value from run_in_executor.
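+        # The patched run_in_executor returns a coroutine rather than a Future; the job
+        # awaits whatever it receives, so the await resolves to the canned empty dict.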
+ with ( + patch.object(_UnixSelectorEventLoop, "run_in_executor", return_value=self.dummy_mapping_output({})), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), + ) + + mock_update_progress.assert_any_call(100, 100, "Variant mapping failed due to missing results.") + assert result["status"] == "exception" + assert result["data"] == {} + assert isinstance(result["exception"], NonexistentMappingResultsError) + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + assert ( + "Mapping results were not returned from VRS mapping service" + in sample_score_set.mapping_errors["error_message"] + ) + + # Verify no annotations were created + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 0 + + async def test_map_variants_for_score_set_no_mapped_scores( + self, + session, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test mapping variants when no scores are mapped.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=self.dummy_mapping_output( + {"mapped_scores": [], "error_message": "No variants were mapped for this score set"} + ), + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), + ) + + mock_update_progress.assert_any_call(100, 100, "Variant mapping failed; no variants were mapped.") + assert result["status"] == "exception" + assert result["data"] == {} + assert isinstance(result["exception"], NonexistentMappingScoresError) + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + assert "No variants were mapped for this score set" in sample_score_set.mapping_errors["error_message"] + + # Verify no annotations were created + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 0 + + async def test_map_variants_for_score_set_no_reference_data( + self, + session, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test mapping variants when no reference data is available.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. 
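+        # The canned error_message in the mocked output is expected to surface verbatim
+        # in score_set.mapping_errors, as asserted below.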
+ with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=self.dummy_mapping_output( + {"mapped_scores": [MagicMock()], "error_message": "Reference metadata missing from mapping results"} + ), + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), + ) + + mock_update_progress.assert_any_call(100, 100, "Variant mapping failed due to missing reference metadata.") + assert result["status"] == "exception" + assert result["data"] == {} + assert isinstance(result["exception"], NonexistentMappingReferenceError) + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + assert "Reference metadata missing from mapping results" in sample_score_set.mapping_errors["error_message"] + + # Verify no annotations were created + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 0 + + async def test_map_variants_for_score_set_nonexistent_target_gene( + self, + session, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test mapping variants when the target gene does not exist.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=self.dummy_mapping_output( + { + "mapped_scores": [MagicMock()], + "reference_sequences": {"some_key": "some_value"}, + } + ), + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), + ) + + mock_update_progress.assert_any_call(100, 100, "Variant mapping failed due to an unexpected error.") + assert result["status"] == "exception" + assert result["data"] == {} + assert isinstance(result["exception"], ValueError) + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + assert ( + "Encountered an unexpected error while parsing mapped variants" + in sample_score_set.mapping_errors["error_message"] + ) + + # Verify no annotations were created + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 0 + + async def test_map_variants_for_score_set_returns_variants_not_in_score_set( + self, + session, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test mapping variants when variants not in score set are returned.""" + # Add a non-existent variant to the mapped output to ensure at least one invalid mapping + mapping_output = await construct_mock_mapping_output( + session=session, score_set=sample_score_set, with_layers={"g", "c", "p"} + ) + mapping_output["mapped_scores"].append({"variant_id": "not_in_score_set", 
"some_other_data": "value"}) + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=self.dummy_mapping_output(mapping_output), + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), + ) + + mock_update_progress.assert_any_call(100, 100, "Variant mapping failed due to an unexpected error.") + assert result["status"] == "exception" + assert result["data"] == {} + assert isinstance(result["exception"], NoResultFound) + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + assert ( + "Encountered an unexpected error while parsing mapped variants" + in sample_score_set.mapping_errors["error_message"] + ) + + # Verify no annotations were created + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 0 + + async def test_map_variants_for_score_set_success_missing_gene_info( + self, + session, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test successful mapping variants with missing gene info.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=session, + score_set=sample_score_set, + with_gene_info=False, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, + ) + + # Create a variant in the score set to be mapped + variant = Variant( + score_set_id=sample_score_set.id, hgvs_nt="NM_000000.1:c.1A>G", hgvs_pro="NP_000000.1:p.Met1Val", data={} + ) + session.add(variant) + session.commit() + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), + ) + + assert result["status"] == "ok" + assert result["data"] == {} + assert result["exception"] is None + + assert sample_score_set.mapping_state == MappingState.complete + assert sample_score_set.mapping_errors is None + + # Verify the gene info is missing from the target gene reference sequence + for target in sample_score_set.target_genes: + assert target.mapped_hgnc_name is None + + # Verify that a mapped variant was created + mapped_variants = session.query(MappedVariant).all() + assert len(mapped_variants) == 1 + + # Verify that annotation statuses were created and correct + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 1 + assert annotation_statuses[0].annotation_type == "vrs_mapping" + assert annotation_statuses[0].status == "success" 
+ + @pytest.mark.parametrize( + "with_layers", + [ + {"g"}, + {"c"}, + {"p"}, + {"g", "c"}, + {"g", "p"}, + {"c", "p"}, + {"g", "c", "p"}, + ], + ) + async def test_map_variants_for_score_set_success_layer_permutations( + self, + session, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + with_layers, + ): + """Test successful mapping variants with annotation layer permutations.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=session, + score_set=sample_score_set, + with_gene_info=True, + with_layers=with_layers, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, + ) + + # Create a variant in the score set to be mapped + variant = Variant( + score_set_id=sample_score_set.id, hgvs_nt="NM_000000.1:c.1A>G", hgvs_pro="NP_000000.1:p.Met1Val", data={} + ) + session.add(variant) + session.commit() + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), + ) + + assert result["status"] == "ok" + assert result["data"] == {} + assert result["exception"] is None + + assert sample_score_set.mapping_state == MappingState.complete + assert sample_score_set.mapping_errors is None + + # Verify the annotation layers presence/absence + for target in sample_score_set.target_genes: + if "g" in with_layers: + assert target.pre_mapped_metadata["genomic"] is not None + assert target.post_mapped_metadata["genomic"] is not None + pre_mapped_comparator = TEST_GENOMIC_LAYER["computed_reference_sequence"].copy() + for key in EXCLUDED_PREMAPPED_ANNOTATION_KEYS: + pre_mapped_comparator.pop(key, None) + + assert target.pre_mapped_metadata["genomic"] == pre_mapped_comparator + assert target.post_mapped_metadata["genomic"] == TEST_GENOMIC_LAYER["mapped_reference_sequence"] + else: + assert target.post_mapped_metadata.get("genomic") is None + + if "c" in with_layers: + assert target.pre_mapped_metadata["cdna"] is not None + assert target.post_mapped_metadata["cdna"] is not None + pre_mapped_comparator = TEST_CODING_LAYER["computed_reference_sequence"].copy() + for key in EXCLUDED_PREMAPPED_ANNOTATION_KEYS: + pre_mapped_comparator.pop(key, None) + + assert target.pre_mapped_metadata["cdna"] == pre_mapped_comparator + assert target.post_mapped_metadata["cdna"] == TEST_CODING_LAYER["mapped_reference_sequence"] + else: + assert target.post_mapped_metadata.get("cdna") is None + + if "p" in with_layers: + assert target.pre_mapped_metadata["protein"] is not None + assert target.post_mapped_metadata["protein"] is not None + pre_mapped_comparator = TEST_PROTEIN_LAYER["computed_reference_sequence"].copy() + for key in EXCLUDED_PREMAPPED_ANNOTATION_KEYS: + pre_mapped_comparator.pop(key, None) + + assert target.pre_mapped_metadata["protein"] == pre_mapped_comparator + assert target.post_mapped_metadata["protein"] == TEST_PROTEIN_LAYER["mapped_reference_sequence"] + else: + assert target.post_mapped_metadata.get("protein") is None + + # Verify that a mapped variant was created + mapped_variants = session.query(MappedVariant).all() + assert len(mapped_variants) == 1 
+ + # Verify that annotation statuses were created and correct + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 1 + assert annotation_statuses[0].annotation_type == "vrs_mapping" + assert annotation_statuses[0].status == "success" + + async def test_map_variants_for_score_set_success_no_successful_mapping( + self, + session, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test successful mapping variants with no successful mapping.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=session, + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=False, # Missing post-mapped + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, + ) + + # Create a variant in the score set to be mapped + variant = Variant( + score_set_id=sample_score_set.id, hgvs_nt="NM_000000.1:c.1A>G", hgvs_pro="NP_000000.1:p.Met1Val", data={} + ) + session.add(variant) + session.commit() + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), + ) + + assert result["status"] == "failed" + assert result["data"] == {} + assert isinstance(result["exception"], NoMappedVariantsError) + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors["error_message"] == "All variants failed to map." + + # Verify that one mapped variant was created. Although no successful mapping, an entry is still created. + mapped_variants = session.query(MappedVariant).all() + assert len(mapped_variants) == 1 + + # Verify that the mapped variant has no post-mapped data + mapped_variant = mapped_variants[0] + assert mapped_variant.post_mapped == {} + + # Verify that annotation statuses were created and correct + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 1 + assert annotation_statuses[0].annotation_type == "vrs_mapping" + assert annotation_statuses[0].status == "failed" + + async def test_map_variants_for_score_set_incomplete_mapping( + self, + session, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test successful mapping variants with incomplete mapping.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. 
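+        # with_all_variants=False leaves one of the two variants seeded below out of the
+        # mocked mapped scores, which should leave the score set MappingState.incomplete.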
+ async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=session, + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=False, # Only some variants mapped + ) + + # Create two variants in the score set to be mapped + variant1 = Variant( + score_set_id=sample_score_set.id, + hgvs_nt="NM_000000.1:c.1A>G", + hgvs_pro="NP_000000.1:p.Met1Val", + data={}, + urn="variant:1", + ) + variant2 = Variant( + score_set_id=sample_score_set.id, + hgvs_nt="NM_000000.1:c.2G>T", + hgvs_pro="NP_000000.1:p.Val2Leu", + data={}, + urn="variant:2", + ) + session.add_all([variant1, variant2]) + session.commit() + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), + ) + + assert result["status"] == "ok" + assert result["data"] == {} + assert result["exception"] is None + + assert sample_score_set.mapping_state == MappingState.incomplete + assert sample_score_set.mapping_errors is None + + # Although only one variant was successfully mapped, verify that an entity was created + # for each variant in the score set + mapped_variants = session.query(MappedVariant).all() + assert len(mapped_variants) == 2 + + # Verify that only one variant has post-mapped data + mapped_variant_with_post_data = ( + session.query(MappedVariant).filter(MappedVariant.post_mapped != {}).one_or_none() + ) + assert mapped_variant_with_post_data is not None + + mapped_variant_without_post_data = ( + session.query(MappedVariant).filter(MappedVariant.post_mapped == {}).one_or_none() + ) + assert mapped_variant_without_post_data is not None + + # Verify that annotation statuses were created and correct + annotation_status_success = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id, VariantAnnotationStatus.status == "success") + .all() + ) + assert len(annotation_status_success) == 1 + assert annotation_status_success[0].annotation_type == "vrs_mapping" + annotation_status_failed = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id, VariantAnnotationStatus.status == "failed") + .all() + ) + assert len(annotation_status_failed) == 1 + assert annotation_status_failed[0].annotation_type == "vrs_mapping" + + async def test_map_variants_for_score_set_complete_mapping( + self, + session, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test successful mapping variants with complete mapping.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. 
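+        # Both seeded variants appear in the mocked output, so the score set is expected
+        # to reach MappingState.complete with a successful annotation status per variant.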
+ async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=session, + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, # All variants mapped + ) + + # Create two variants in the score set to be mapped + variant1 = Variant( + score_set_id=sample_score_set.id, + hgvs_nt="NM_000000.1:c.1A>G", + hgvs_pro="NP_000000.1:p.Met1Val", + data={}, + urn="variant:1", + ) + variant2 = Variant( + score_set_id=sample_score_set.id, + hgvs_nt="NM_000000.1:c.2G>T", + hgvs_pro="NP_000000.1:p.Val2Leu", + data={}, + urn="variant:2", + ) + session.add_all([variant1, variant2]) + session.commit() + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), + ) + + assert result["status"] == "ok" + assert result["data"] == {} + assert result["exception"] is None + + assert sample_score_set.mapping_state == MappingState.complete + assert sample_score_set.mapping_errors is None + + # Verify that mapped variants were created + mapped_variants = session.query(MappedVariant).all() + assert len(mapped_variants) == 2 + + # Verify that both variants have post-mapped data. I'm comfortable assuming the + # data is correct given our layer permutation tests above. + for urn in ["variant:1", "variant:2"]: + mapped_variant = session.query(MappedVariant).filter(MappedVariant.variant.has(urn=urn)).one_or_none() + assert mapped_variant is not None + assert mapped_variant.post_mapped != {} + + # Verify that annotation statuses were created and correct + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 2 + for status in annotation_statuses: + assert status.annotation_type == "vrs_mapping" + assert status.status == "success" + + async def test_map_variants_for_score_set_updates_existing_mapped_variants( + self, + with_independent_processing_runs, + session, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test mapping variants updates existing mapped variants.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. 
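+        # A re-run of mapping should demote the pre-seeded MappedVariant and annotation
+        # status below to current=False and create fresh current rows in their place.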
+ async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=session, + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, + ) + + # Create a variant and associated mapped data/annotation status in the score set to be updated + variant = Variant( + score_set_id=sample_score_set.id, hgvs_nt="NM_000000.1:c.1A>G", hgvs_pro="NP_000000.1:p.Met1Val", data={} + ) + session.add(variant) + session.commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + current=True, + mapped_date="2023-01-01T00:00:00Z", + mapping_api_version="v1.0.0", + ) + session.add(mapped_variant) + session.commit() + variant_annotation_status = VariantAnnotationStatus( + variant_id=variant.id, current=True, annotation_type="vrs_mapping", status="success" + ) + session.add(variant_annotation_status) + session.commit() + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), + ) + + assert result["status"] == "ok" + assert result["data"] == {} + assert result["exception"] is None + + assert sample_score_set.mapping_state == MappingState.complete + assert sample_score_set.mapping_errors is None + + # Verify the existing mapped variant was marked as non-current + non_current_mapped_variant = ( + session.query(MappedVariant) + .filter(MappedVariant.id == mapped_variant.id, MappedVariant.current.is_(False)) + .one_or_none() + ) + assert non_current_mapped_variant is not None + + # Verify a new mapped variant entry was created + new_mapped_variant = ( + session.query(MappedVariant) + .filter(MappedVariant.variant_id == variant.id, MappedVariant.current.is_(True)) + .one_or_none() + ) + assert new_mapped_variant is not None + + # Verify that the new mapped variant has updated mapping data + assert new_mapped_variant.mapped_date != "2023-01-01T00:00:00Z" + assert new_mapped_variant.mapping_api_version != "v1.0.0" + + # Verify the non-current annotation status still exists + old_annotation_status = ( + session.query(VariantAnnotationStatus) + .filter( + VariantAnnotationStatus.variant_id == non_current_mapped_variant.variant_id, + VariantAnnotationStatus.current.is_(False), + ) + .one_or_none() + ) + assert old_annotation_status is not None + + # Verify that a new annotation status was created + new_annotation_status = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == variant.id, VariantAnnotationStatus.current.is_(True)) + .one_or_none() + ) + assert new_annotation_status is not None + + async def test_map_variants_for_score_set_progress_updates( + self, + session, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test mapping variants reports progress updates.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. 
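+ # update_progress is patched at the class level below, so the spy observes
+ # calls made through any JobManager instance, including the one this test
+ # constructs and passes into the job.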
+ async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=session, + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, + ) + + # Create a variant in the score set to be mapped + variant = Variant( + score_set_id=sample_score_set.id, hgvs_nt="NM_000000.1:c.1A>G", hgvs_pro="NP_000000.1:p.Met1Val", data={} + ) + session.add(variant) + session.commit() + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), + ) + + assert result["status"] == "ok" + assert result["data"] == {} + assert result["exception"] is None + + assert sample_score_set.mapping_state == MappingState.complete + assert sample_score_set.mapping_errors is None + + # Verify progress updates were reported + mock_update_progress.assert_has_calls( + [ + call(0, 100, "Starting variant mapping job."), + call(10, 100, "Score set prepared for variant mapping."), + call(30, 100, "Mapping variants using VRS mapping service."), + call(80, 100, "Processing mapped variants."), + call(90, 100, "Saving mapped variants."), + call(100, 100, "Finished processing mapped variants."), + ] + ) + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestMapVariantsForScoreSetIntegration: + """Integration tests for map_variants_for_score_set job.""" + + async def test_map_variants_for_score_set_independent_job( + self, + session, + with_independent_processing_runs, + mock_s3_client, + mock_worker_ctx, + sample_independent_variant_creation_run, + sample_independent_variant_mapping_run, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + ): + """Test mapping variants for an independent processing run.""" + + # First, create variants in the score set + await create_variants_in_score_set( + session, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + mock_worker_ctx, + sample_independent_variant_creation_run, + ) + + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=session, + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, + ) + + # Mock mapping output + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + # Now, map variants for the score set + result = await map_variants_for_score_set(mock_worker_ctx, sample_independent_variant_mapping_run.id) + + assert result["status"] == "ok" + assert result["data"] == {} + assert result["exception"] is None + + # Verify that mapped variants were created + mapped_variants = session.query(MappedVariant).all() + assert len(mapped_variants) == 4 + + # Verify score set mapping state + assert sample_score_set.mapping_state == MappingState.complete + assert sample_score_set.mapping_errors is None + + # Verify that target gene info was updated + for target in sample_score_set.target_genes: + assert target.mapped_hgnc_name is not None + assert target.post_mapped_metadata is not None + + # 
Verify that each variant has a corresponding mapped variant + variants = ( + session.query(Variant) + .join(MappedVariant, MappedVariant.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id, MappedVariant.current.is_(True)) + .all() + ) + assert len(variants) == 4 + + # Verify that each variant has an annotation status + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 4 + + # Verify that the job status was updated + processing_run = ( + session.query(sample_independent_variant_mapping_run.__class__) + .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) + .one() + ) + assert processing_run.status == JobStatus.SUCCEEDED + + async def test_map_variants_for_score_set_pipeline_context( + self, + session, + with_variant_creation_pipeline_runs, + with_variant_mapping_pipeline_runs, + mock_s3_client, + mock_worker_ctx, + sample_pipeline_variant_creation_run, + sample_pipeline_variant_mapping_run, + sample_score_set, + sample_score_dataframe, + sample_count_dataframe, + ): + """Test mapping variants for a pipeline processing run.""" + + # First, create variants in the score set + await create_variants_in_score_set( + session, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + mock_worker_ctx, + sample_pipeline_variant_creation_run, + ) + + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=session, + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, + ) + + # Mock mapping output + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + # Now, map variants for the score set + result = await map_variants_for_score_set(mock_worker_ctx, sample_pipeline_variant_mapping_run.id) + + assert result["status"] == "ok" + assert result["data"] == {} + assert result["exception"] is None + + # Verify that mapped variants were created + mapped_variants = session.query(MappedVariant).all() + assert len(mapped_variants) == 4 + + # Verify score set mapping state + assert sample_score_set.mapping_state == MappingState.complete + assert sample_score_set.mapping_errors is None + + # Verify that target gene info was updated + for target in sample_score_set.target_genes: + assert target.mapped_hgnc_name is not None + assert target.post_mapped_metadata is not None + + # Verify that each variant has a corresponding mapped variant + variants = ( + session.query(Variant) + .join(MappedVariant, MappedVariant.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id, MappedVariant.current.is_(True)) + .all() + ) + assert len(variants) == 4 + + # Verify that each variant has an annotation status + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 4 + + # Verify that the job status was updated + processing_run = ( + session.query(sample_pipeline_variant_mapping_run.__class__) + .filter(sample_pipeline_variant_mapping_run.__class__.id == sample_pipeline_variant_mapping_run.id) + 
.one() + ) + assert processing_run.status == JobStatus.SUCCEEDED + + # Verify that the pipeline run status was updated. We expect RUNNING here because + # the mapping job is not the only job in our dummy pipeline. + pipeline_run = ( + session.query(sample_pipeline_variant_mapping_run.pipeline.__class__) + .filter( + sample_pipeline_variant_mapping_run.pipeline.__class__.id + == sample_pipeline_variant_mapping_run.pipeline.id + ) + .one() + ) + assert pipeline_run.status == PipelineStatus.RUNNING + + async def test_map_variants_for_score_set_empty_mapping_results( + self, + session, + mock_s3_client, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + sample_score_dataframe, + sample_count_dataframe, + sample_independent_variant_creation_run, + ): + """Test mapping variants when no mapping results are returned.""" + + # First, create variants in the score set + await create_variants_in_score_set( + session, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + mock_worker_ctx, + sample_independent_variant_creation_run, + ) + + async def dummy_mapping_job(): + return {} + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + with ( + patch.object(_UnixSelectorEventLoop, "run_in_executor", return_value=dummy_mapping_job()), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + ) + + mock_send_slack_error.assert_called_once() + assert result["status"] == "exception" + assert isinstance(result["exception"], NonexistentMappingResultsError) + assert result["data"] == {} + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + assert ( + "Mapping results were not returned from VRS mapping service" + in sample_score_set.mapping_errors["error_message"] + ) + + # Verify that no mapped variants were created + mapped_variants = session.query(MappedVariant).all() + assert len(mapped_variants) == 0 + + # Verify that no annotation statuses were created + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + + # Verify that the job status was updated. 
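+ # (Re-querying through __class__ keeps this assertion agnostic to the
+ # concrete processing-run model backing the fixture.)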
+ processing_run = ( + session.query(sample_independent_variant_mapping_run.__class__) + .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) + .one() + ) + assert processing_run.status == JobStatus.FAILED + + async def test_map_variants_for_score_set_no_mapped_scores( + self, + session, + mock_s3_client, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + sample_score_dataframe, + sample_count_dataframe, + sample_independent_variant_creation_run, + ): + """Test mapping variants when no variants are mapped.""" + + # First, create variants in the score set + await create_variants_in_score_set( + session, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + mock_worker_ctx, + sample_independent_variant_creation_run, + ) + + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=session, + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=False, + with_reference_metadata=True, + with_mapped_scores=False, # No mapped scores + with_all_variants=True, + ) + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + ) + + mock_send_slack_error.assert_called_once() + assert result["status"] == "exception" + assert isinstance(result["exception"], NonexistentMappingScoresError) + assert result["data"] == {} + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + # Error message originates from our mock mapping construction function + assert "test error: no mapped scores" in sample_score_set.mapping_errors["error_message"] + + # Verify that no mapped variants were created + mapped_variants = session.query(MappedVariant).all() + assert len(mapped_variants) == 0 + + # Verify that no annotation statuses were created + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + + # Verify that the job status was updated. 
+ processing_run = ( + session.query(sample_independent_variant_mapping_run.__class__) + .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) + .one() + ) + assert processing_run.status == JobStatus.FAILED + + async def test_map_variants_for_score_set_no_reference_data( + self, + session, + mock_s3_client, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + sample_score_dataframe, + sample_count_dataframe, + sample_independent_variant_creation_run, + ): + """Test mapping variants when no reference data is provided.""" + + # First, create variants in the score set + await create_variants_in_score_set( + session, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + mock_worker_ctx, + sample_independent_variant_creation_run, + ) + + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=session, + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=False, # No reference metadata + with_mapped_scores=True, + with_all_variants=True, + ) + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + ) + + mock_send_slack_error.assert_called_once() + assert result["status"] == "exception" + assert isinstance(result["exception"], NonexistentMappingReferenceError) + assert result["data"] == {} + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + assert "Reference metadata missing from mapping results" in sample_score_set.mapping_errors["error_message"] + + # Verify that no mapped variants were created + mapped_variants = session.query(MappedVariant).all() + assert len(mapped_variants) == 0 + + # Verify that no annotation statuses were created + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + + # Verify that the job status was updated. 
+ processing_run = ( + session.query(sample_independent_variant_mapping_run.__class__) + .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) + .one() + ) + assert processing_run.status == JobStatus.FAILED + + async def test_map_variants_for_score_set_updates_current_mapped_variants( + self, + session, + mock_s3_client, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + sample_score_dataframe, + sample_count_dataframe, + sample_independent_variant_creation_run, + ): + """Test mapping variants updates current mapped variants even if no changes occur.""" + + # First, create variants in the score set + await create_variants_in_score_set( + session, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + mock_worker_ctx, + sample_independent_variant_creation_run, + ) + + # Associate mapped variants with all variants just created in the score set + variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + for variant in variants: + mapped_variant = MappedVariant( + variant_id=variant.id, + current=True, + mapped_date="2023-01-01T00:00:00Z", + mapping_api_version="v1.0.0", + ) + annotation_status = VariantAnnotationStatus( + variant_id=variant.id, current=True, annotation_type="vrs_mapping", status="success" + ) + session.add(annotation_status) + session.add(mapped_variant) + session.commit() + + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=session, + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, + ) + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. 
+ with (
+ patch.object(
+ _UnixSelectorEventLoop,
+ "run_in_executor",
+ return_value=dummy_mapping_job(),
+ ),
+ ):
+ result = await map_variants_for_score_set(
+ mock_worker_ctx,
+ sample_independent_variant_mapping_run.id,
+ )
+
+ assert result["status"] == "ok"
+ assert result["data"] == {}
+ assert result["exception"] is None
+
+ assert sample_score_set.mapping_state == MappingState.complete
+ assert sample_score_set.mapping_errors is None
+
+ # Verify that mapped variants were marked as non-current and new entries created
+ mapped_variants = session.query(MappedVariant).all()
+ assert len(mapped_variants) == len(variants) * 2 # Each variant has two mapped entries now
+ for variant in variants:
+ non_current_mapped_variant = (
+ session.query(MappedVariant)
+ .filter(MappedVariant.variant_id == variant.id, MappedVariant.current.is_(False))
+ .one_or_none()
+ )
+ assert non_current_mapped_variant is not None
+
+ new_mapped_variant = (
+ session.query(MappedVariant)
+ .filter(MappedVariant.variant_id == variant.id, MappedVariant.current.is_(True))
+ .one_or_none()
+ )
+ assert new_mapped_variant is not None
+
+ # Verify that the new mapped variant has updated mapping data
+ assert new_mapped_variant.mapped_date != "2023-01-01T00:00:00Z"
+ assert new_mapped_variant.mapping_api_version != "v1.0.0"
+
+ # Verify that annotation statuses were marked as non-current and new entries created
+ annotation_statuses = session.query(VariantAnnotationStatus).all()
+ assert len(annotation_statuses) == len(variants) * 2 # Each variant has two annotation statuses now
+ for variant in variants:
+ old_annotation_status = (
+ session.query(VariantAnnotationStatus)
+ .filter(VariantAnnotationStatus.variant_id == variant.id, VariantAnnotationStatus.current.is_(False))
+ .one_or_none()
+ )
+ assert old_annotation_status is not None
+
+ new_annotation_status = (
+ session.query(VariantAnnotationStatus)
+ .filter(VariantAnnotationStatus.variant_id == variant.id, VariantAnnotationStatus.current.is_(True))
+ .one_or_none()
+ )
+ assert new_annotation_status is not None
+
+ # Verify that the job status was updated.
+ processing_run = (
+ session.query(sample_independent_variant_mapping_run.__class__)
+ .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id)
+ .one()
+ )
+ assert processing_run.status == JobStatus.SUCCEEDED
+
+ async def test_map_variants_for_score_set_no_variants(
+ self,
+ session,
+ with_independent_processing_runs,
+ mock_worker_ctx,
+ sample_independent_variant_mapping_run,
+ sample_score_set,
+ ):
+ """Test mapping variants when no variants exist in the score set."""
+
+ async def dummy_mapping_job():
+ return await construct_mock_mapping_output(
+ session=session,
+ score_set=sample_score_set,
+ with_gene_info=True,
+ with_layers={"g", "c", "p"},
+ with_pre_mapped=True,
+ with_post_mapped=True,
+ with_reference_metadata=True,
+ with_mapped_scores=True,
+ with_all_variants=True,
+ )
+
+ # Network requests occur within an event loop. Mock result of mapping call
+ # with return value from run_in_executor.
+ with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + ) + + mock_send_slack_error.assert_called_once() + assert result["status"] == "exception" + assert result["data"] == {} + assert isinstance(result["exception"], NonexistentMappingScoresError) + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + assert "test error: no mapped scores" in sample_score_set.mapping_errors["error_message"] + + # Verify that no mapped variants were created + mapped_variants = session.query(MappedVariant).all() + assert len(mapped_variants) == 0 + + # Verify that no annotation statuses were created + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + + # Verify that the job status was updated. + processing_run = ( + session.query(sample_independent_variant_mapping_run.__class__) + .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) + .one() + ) + assert processing_run.status == JobStatus.FAILED + + async def test_map_variants_for_score_set_exception_in_mapping( + self, + session, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test mapping variants when an exception occurs during mapping.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + async def dummy_mapping_job(): + raise ValueError("test exception during mapping") + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + ) + + mock_send_slack_error.assert_called_once() + assert result["status"] == "exception" + assert result["data"] == {} + assert isinstance(result["exception"], ValueError) + # exception messages are persisted in internal properties + assert "test exception during mapping" in str(result["exception"]) + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + # but replaced with generic error message for external visibility + assert ( + "Encountered an unexpected error while parsing mapped variants" + in sample_score_set.mapping_errors["error_message"] + ) + + # Verify that no mapped variants were created + mapped_variants = session.query(MappedVariant).all() + assert len(mapped_variants) == 0 + + # Verify that no annotation statuses were created + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + + # Verify that the job status was updated. 
+ processing_run = (
+ session.query(sample_independent_variant_mapping_run.__class__)
+ .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id)
+ .one()
+ )
+ assert processing_run.status == JobStatus.FAILED
+
+
+@pytest.mark.integration
+@pytest.mark.asyncio
+class TestMapVariantsForScoreSetArqContext:
+ """Integration tests for map_variants_for_score_set job using ARQ worker context."""
+
+ async def test_map_variants_for_score_set_with_arq_context_independent_ctx(
+ self,
+ session,
+ arq_redis,
+ arq_worker,
+ standalone_worker_context,
+ with_independent_processing_runs,
+ with_populated_domain_data,
+ mock_s3_client,
+ sample_score_dataframe,
+ sample_count_dataframe,
+ sample_score_set,
+ sample_independent_variant_creation_run,
+ sample_independent_variant_mapping_run,
+ ):
+ """Test mapping variants for an independent processing run using ARQ context."""
+
+ await create_variants_in_score_set(
+ session,
+ mock_s3_client,
+ sample_score_dataframe,
+ sample_count_dataframe,
+ standalone_worker_context,
+ sample_independent_variant_creation_run,
+ )
+
+ async def dummy_mapping_job():
+ return await construct_mock_mapping_output(
+ session=session,
+ score_set=sample_score_set,
+ with_gene_info=True,
+ with_layers={"g", "c", "p"},
+ with_pre_mapped=True,
+ with_post_mapped=True,
+ with_reference_metadata=True,
+ with_mapped_scores=True,
+ with_all_variants=True,
+ )
+
+ with (
+ patch.object(
+ _UnixSelectorEventLoop,
+ "run_in_executor",
+ return_value=dummy_mapping_job(),
+ ),
+ ):
+ await arq_redis.enqueue_job("map_variants_for_score_set", sample_independent_variant_mapping_run.id)
+ await arq_worker.async_run()
+ await arq_worker.run_check()
+
+ # Verify that mapped variants were created
+ mapped_variants = session.query(MappedVariant).all()
+ assert len(mapped_variants) == 4
+
+ # Verify score set mapping state
+ assert sample_score_set.mapping_state == MappingState.complete
+ assert sample_score_set.mapping_errors is None
+
+ # Verify that each variant has a corresponding mapped variant
+ variants = (
+ session.query(Variant)
+ .join(MappedVariant, MappedVariant.variant_id == Variant.id)
+ .filter(Variant.score_set_id == sample_score_set.id, MappedVariant.current.is_(True))
+ .all()
+ )
+ assert len(variants) == 4
+
+ # Verify that each variant has an annotation status
+ annotation_statuses = (
+ session.query(VariantAnnotationStatus)
+ .join(Variant, VariantAnnotationStatus.variant_id == Variant.id)
+ .filter(Variant.score_set_id == sample_score_set.id)
+ .all()
+ )
+ assert len(annotation_statuses) == 4
+
+ # Verify that the job status was updated
+ processing_run = (
+ session.query(sample_independent_variant_mapping_run.__class__)
+ .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id)
+ .one()
+ )
+ assert processing_run.status == JobStatus.SUCCEEDED
+
+ async def test_map_variants_for_score_set_with_arq_context_pipeline_ctx(
+ self,
+ session,
+ arq_redis,
+ arq_worker,
+ standalone_worker_context,
+ with_variant_creation_pipeline_runs,
+ with_variant_mapping_pipeline_runs,
+ with_populated_domain_data,
+ mock_s3_client,
+ sample_score_dataframe,
+ sample_count_dataframe,
+ sample_score_set,
+ sample_pipeline_variant_creation_run,
+ sample_pipeline_variant_mapping_run,
+ ):
+ """Test mapping variants for a pipeline processing run using ARQ context."""
+
+ # First, create variants in the score set
+ await create_variants_in_score_set(
+ session,
+ mock_s3_client,
+ sample_score_dataframe,
+ sample_count_dataframe,
+ standalone_worker_context,
+
sample_pipeline_variant_creation_run, + ) + + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=session, + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, + ) + + # Mock mapping output + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + # Now, map variants for the score set + await arq_redis.enqueue_job("map_variants_for_score_set", sample_pipeline_variant_mapping_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that mapped variants were created + mapped_variants = session.query(MappedVariant).all() + assert len(mapped_variants) == 4 + + # Verify score set mapping state + assert sample_score_set.mapping_state == MappingState.complete + assert sample_score_set.mapping_errors is None + + # Verify that each variant has a corresponding mapped variant + variants = ( + session.query(Variant) + .join(MappedVariant, MappedVariant.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id, MappedVariant.current.is_(True)) + .all() + ) + assert len(variants) == 4 + + # Verify that each variant has an annotation status + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 4 + + # Verify that the job status was updated + processing_run = ( + session.query(sample_pipeline_variant_mapping_run.__class__) + .filter(sample_pipeline_variant_mapping_run.__class__.id == sample_pipeline_variant_mapping_run.id) + .one() + ) + assert processing_run.status == JobStatus.SUCCEEDED + + # Verify that the pipeline run status was updated. We expect RUNNING here because + # the mapping job is not the only job in our dummy pipeline. + pipeline_run = ( + session.query(sample_pipeline_variant_mapping_run.pipeline.__class__) + .filter( + sample_pipeline_variant_mapping_run.pipeline.__class__.id + == sample_pipeline_variant_mapping_run.pipeline.id + ) + .one() + ) + assert pipeline_run.status == PipelineStatus.RUNNING + + async def test_map_variants_for_score_set_with_arq_context_generic_exception_handling( + self, + session, + arq_redis, + arq_worker, + standalone_worker_context, + with_independent_processing_runs, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test mapping variants with ARQ context when an exception occurs during mapping.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. 
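+ # Returning a coroutine that raises means the ValueError surfaces at the
+ # await point inside the job, exercising the decorator's unexpected-exception
+ # handling path.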
+ async def dummy_mapping_job(): + raise ValueError("test exception during mapping") + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + await arq_redis.enqueue_job("map_variants_for_score_set", sample_independent_variant_mapping_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + mock_send_slack_error.assert_called_once() + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + # but replaced with generic error message for external visibility + assert ( + "Encountered an unexpected error while parsing mapped variants" + in sample_score_set.mapping_errors["error_message"] + ) + + # Verify that no mapped variants were created + mapped_variants = session.query(MappedVariant).all() + assert len(mapped_variants) == 0 + + # Verify that no annotation statuses were created + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + + # Verify that the job status was updated. + processing_run = ( + session.query(sample_independent_variant_mapping_run.__class__) + .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) + .one() + ) + assert processing_run.status == JobStatus.FAILED + + async def test_map_variants_for_score_set_with_arq_context_generic_exception_in_pipeline_ctx( + self, + session, + arq_redis, + arq_worker, + standalone_worker_context, + with_variant_mapping_pipeline_runs, + sample_pipeline_variant_mapping_run, + sample_score_set, + ): + """Test mapping variants with ARQ context in pipeline when an exception occurs during mapping.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + async def dummy_mapping_job(): + raise ValueError("test exception during mapping") + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + await arq_redis.enqueue_job("map_variants_for_score_set", sample_pipeline_variant_mapping_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + mock_send_slack_error.assert_called_once() + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + # but replaced with generic error message for external visibility + assert ( + "Encountered an unexpected error while parsing mapped variants" + in sample_score_set.mapping_errors["error_message"] + ) + + # Verify that no mapped variants were created + mapped_variants = session.query(MappedVariant).all() + assert len(mapped_variants) == 0 + + # Verify that no annotation statuses were created + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + + # Verify that the job status was updated. + processing_run = ( + session.query(sample_pipeline_variant_mapping_run.__class__) + .filter(sample_pipeline_variant_mapping_run.__class__.id == sample_pipeline_variant_mapping_run.id) + .one() + ) + assert processing_run.status == JobStatus.FAILED + + # Verify that the pipeline run status was updated to FAILED. 
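+ # (A failure in one pipeline job should cascade: the pipeline itself is marked
+ # FAILED and its remaining jobs are skipped, as asserted below.)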
+ pipeline_run = ( + session.query(sample_pipeline_variant_mapping_run.pipeline.__class__) + .filter( + sample_pipeline_variant_mapping_run.pipeline.__class__.id + == sample_pipeline_variant_mapping_run.pipeline.id + ) + .one() + ) + assert pipeline_run.status == PipelineStatus.FAILED + + # Verify that other jobs in the pipeline were skipped + for job_run in pipeline_run.job_runs: + if job_run.id != sample_pipeline_variant_mapping_run.id: + assert job_run.status == JobStatus.SKIPPED diff --git a/tests/worker/lib/decorators/conftest.py b/tests/worker/lib/decorators/conftest.py new file mode 100644 index 00000000..851d7497 --- /dev/null +++ b/tests/worker/lib/decorators/conftest.py @@ -0,0 +1,10 @@ +import os + +import pytest + + +# Unset test mode flag before each test to ensure decorator logic is executed +# during unit testing of the decorator itself. +@pytest.fixture(autouse=True) +def unset_test_mode_flag(): + os.environ.pop("MAVEDB_TEST_MODE", None) diff --git a/tests/worker/lib/decorators/test_job_guarantee.py b/tests/worker/lib/decorators/test_job_guarantee.py new file mode 100644 index 00000000..23db1d94 --- /dev/null +++ b/tests/worker/lib/decorators/test_job_guarantee.py @@ -0,0 +1,77 @@ +# ruff: noqa: E402 +""" +Unit and integration tests for the with_guaranteed_job_run_record async decorator. +Covers JobRun creation, status transitions, error handling, and DB persistence. +""" + +import pytest + +pytest.importorskip("arq") # Skip tests if arq is not installed + +from sqlalchemy import select + +from mavedb import __version__ +from mavedb.models.enums.job_pipeline import JobStatus +from mavedb.models.job_run import JobRun +from mavedb.worker.lib.decorators.job_guarantee import with_guaranteed_job_run_record +from tests.helpers.transaction_spy import TransactionSpy + +pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + + +@with_guaranteed_job_run_record("test_job") +async def sample_job(ctx: dict, job_id: int): + """Sample job function to test the decorator. + + NOTE: The job_id parameter is injected by the decorator + and is not passed explicitly when calling the function. + + Args: + ctx (dict): Worker context dictionary. + job_id (int): ID of the JobRun record created by the decorator. 
+ """ + return {"status": "ok"} + + +@pytest.mark.asyncio +@pytest.mark.unit +class TestJobGuaranteeDecoratorUnit: + async def test_decorator_must_receive_ctx_as_first_argument(self, mock_worker_ctx): + with pytest.raises(ValueError) as exc_info: + await sample_job() + + assert "Managed functions must receive context as first argument" in str(exc_info.value) + + async def test_decorator_calls_wrapped_function(self, mock_worker_ctx): + result = await sample_job(mock_worker_ctx) + assert result == {"status": "ok"} + + async def test_decorator_creates_job_run(self, mock_worker_ctx, session): + with ( + TransactionSpy.spy(session, expect_flush=True, expect_commit=True), + ): + await sample_job(mock_worker_ctx) + + job_run = session.execute(select(JobRun)).scalars().first() + assert job_run is not None + assert job_run.status == JobStatus.PENDING + assert job_run.job_type == "test_job" + assert job_run.job_function == "sample_job" + assert job_run.mavedb_version == __version__ + + +@pytest.mark.asyncio +@pytest.mark.integration +class TestJobGuaranteeDecoratorIntegration: + async def test_decorator_persists_job_run_record(self, session, standalone_worker_context): + # Flush called implicitly by commit + with TransactionSpy.spy(session, expect_flush=True, expect_commit=True): + job_task = await sample_job(standalone_worker_context) + + assert job_task == {"status": "ok"} + + job_run = session.execute(select(JobRun).order_by(JobRun.id.desc())).scalars().first() + assert job_run.status == JobStatus.PENDING + assert job_run.job_type == "test_job" + assert job_run.job_function == "sample_job" + assert job_run.mavedb_version is not None diff --git a/tests/worker/lib/decorators/test_job_management.py b/tests/worker/lib/decorators/test_job_management.py new file mode 100644 index 00000000..c887588f --- /dev/null +++ b/tests/worker/lib/decorators/test_job_management.py @@ -0,0 +1,390 @@ +# ruff : noqa: E402 + +""" +Unit and integration tests for the with_job_management async decorator. +Covers status transitions, error handling, and JobManager interaction. +""" + +import pytest + + +pytest.importorskip("arq") # Skip tests if arq is not installed + +import asyncio +from unittest.mock import patch + +from sqlalchemy import select + +from mavedb.models.enums.job_pipeline import JobStatus +from mavedb.models.job_run import JobRun +from mavedb.worker.lib.decorators.job_management import with_job_management +from mavedb.worker.lib.managers.constants import RETRYABLE_FAILURE_CATEGORIES +from mavedb.worker.lib.managers.exceptions import JobStateError +from mavedb.worker.lib.managers.job_manager import JobManager +from tests.helpers.transaction_spy import TransactionSpy + +pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + + +@with_job_management +async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + """Sample job function to test the decorator. + + NOTE: The job_manager parameter is injected by the decorator + and is not passed explicitly when calling the function. + + Args: + ctx (dict): Worker context dictionary. + job_id (int): ID of the JobRun record created by the decorator. + """ + return {"status": "ok"} + + +@with_job_management +async def sample_raise(ctx: dict, job_id: int, job_manager: JobManager): + """Sample job function to test the decorator in cases where the wrapped function raises an exception. + + NOTE: The job_manager parameter is injected by the decorator + and is not passed explicitly when calling the function. 
+ + Args: + ctx (dict): Worker context dictionary. + job_id (int): ID of the JobRun record created by the decorator. + """ + raise RuntimeError("error in wrapped function") + + +@pytest.mark.asyncio +@pytest.mark.unit +class TestManagedJobDecoratorUnit: + async def test_decorator_must_receive_ctx_as_first_argument(self, mock_job_manager): + with pytest.raises(ValueError) as exc_info, TransactionSpy.spy(mock_job_manager.db): + await sample_job() + + assert "Managed functions must receive context as first argument" in str(exc_info.value) + + async def test_decorator_calls_wrapped_function_and_returns_result( + self, session, mock_job_manager, mock_worker_ctx + ): + with ( + patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, + patch.object(mock_job_manager, "start_job", return_value=None), + patch.object(mock_job_manager, "succeed_job", return_value=None), + TransactionSpy.spy(session, expect_commit=True), + ): + mock_job_manager_class.return_value = mock_job_manager + + result = await sample_job(mock_worker_ctx, 999) + assert result == {"status": "ok"} + + async def test_decorator_calls_start_job_and_succeed_job_when_wrapped_function_succeeds( + self, session, mock_worker_ctx, mock_job_manager + ): + with ( + patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, + patch.object(mock_job_manager, "start_job", return_value=None) as mock_start_job, + patch.object(mock_job_manager, "succeed_job", return_value=None) as mock_succeed_job, + TransactionSpy.spy(session, expect_commit=True), + ): + mock_job_manager_class.return_value = mock_job_manager + await sample_job(mock_worker_ctx, 999) + + mock_start_job.assert_called_once() + mock_succeed_job.assert_called_once() + + @pytest.mark.parametrize( + "status", + [ + "failed", + "exception", + ], + ) + async def test_decorator_calls_start_job_and_fail_job_when_wrapped_function_returns_failed_status( + self, session, mock_worker_ctx, mock_job_manager, status + ): + @with_job_management + async def sample_fail(ctx: dict, job_id: int, job_manager: JobManager): + return {"status": status, "data": {}, "exception": RuntimeError("simulated failure")} + + with ( + patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, + patch.object(mock_job_manager, "start_job", return_value=None) as mock_start_job, + patch.object(mock_job_manager, "fail_job", return_value=None) as mock_fail_job, + TransactionSpy.spy(session, expect_commit=True), + ): + mock_job_manager_class.return_value = mock_job_manager + await sample_fail(mock_worker_ctx, 999) + + mock_start_job.assert_called_once() + mock_fail_job.assert_called_once() + + async def test_decorator_calls_start_job_and_skip_job_when_wrapped_function_returns_skipped_status( + self, session, mock_worker_ctx, mock_job_manager + ): + @with_job_management + async def sample_skip(ctx: dict, job_id: int, job_manager: JobManager): + return {"status": "skipped", "data": {}, "exception": None} + + with ( + patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, + patch.object(mock_job_manager, "start_job", return_value=None) as mock_start_job, + patch.object(mock_job_manager, "skip_job", return_value=None) as mock_skip_job, + TransactionSpy.spy(session, expect_commit=True), + ): + mock_job_manager_class.return_value = mock_job_manager + await sample_skip(mock_worker_ctx, 999) + + mock_start_job.assert_called_once() + mock_skip_job.assert_called_once() + + async def 
test_decorator_calls_start_job_and_fail_job_when_wrapped_function_raises_and_no_retry( + self, session, mock_worker_ctx, mock_job_manager + ): + with ( + patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + patch.object(mock_job_manager, "start_job", return_value=None) as mock_start_job, + patch.object(mock_job_manager, "should_retry", return_value=False), + patch.object(mock_job_manager, "fail_job", return_value=None) as mock_fail_job, + TransactionSpy.spy(session, expect_commit=True, expect_rollback=True), + ): + mock_job_manager_class.return_value = mock_job_manager + await sample_raise(mock_worker_ctx, 999) + + mock_start_job.assert_called_once() + mock_fail_job.assert_called_once() + mock_send_slack_error.assert_called_once() + + async def test_decorator_calls_start_job_and_retries_job_when_wrapped_function_raises_and_retry( + self, session, mock_worker_ctx, mock_job_manager + ): + with ( + patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + patch.object(mock_job_manager, "start_job", return_value=None) as mock_start_job, + patch.object(mock_job_manager, "should_retry", return_value=True), + patch.object(mock_job_manager, "prepare_retry", return_value=None) as mock_prepare_retry, + TransactionSpy.spy(session, expect_commit=True, expect_rollback=True), + ): + mock_job_manager_class.return_value = mock_job_manager + await sample_raise(mock_worker_ctx, 999) + + mock_start_job.assert_called_once() + mock_prepare_retry.assert_called_once_with(reason="error in wrapped function") + mock_send_slack_error.assert_called_once() + + @pytest.mark.parametrize("missing_key", ["redis"]) + async def test_decorator_raises_value_error_if_required_context_missing( + self, mock_job_manager, mock_worker_ctx, missing_key + ): + del mock_worker_ctx[missing_key] + + with ( + pytest.raises(ValueError) as exc_info, + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + await sample_job(mock_worker_ctx, 999) + + mock_send_slack_error.assert_called_once() + assert missing_key.replace("_", " ") in str(exc_info.value).lower() + assert "not found in job context" in str(exc_info.value).lower() + + async def test_decorator_swallows_exception_from_lifecycle_state_outside_except( + self, session, mock_job_manager, mock_worker_ctx + ): + raised_exc = JobStateError("error in job start") + with ( + patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + patch.object(mock_job_manager, "start_job", side_effect=raised_exc), + patch.object(mock_job_manager, "should_retry", return_value=False), + patch.object(mock_job_manager, "fail_job", return_value=None), + TransactionSpy.spy(session, expect_rollback=True, expect_commit=True), + ): + mock_job_manager_class.return_value = mock_job_manager + result = await sample_job(mock_worker_ctx, 999) + + assert result["status"] == "exception" + assert raised_exc == result["exception"] + mock_send_slack_error.assert_called_once() + + async def test_decorator_raises_value_error_if_job_id_missing(self, session, mock_job_manager, mock_worker_ctx): + # Remove job_id from args to simulate missing job_id + with ( + pytest.raises(ValueError) 
as exc_info, + TransactionSpy.spy(session), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + await sample_job(mock_worker_ctx) + + mock_send_slack_error.assert_called_once() + assert "job id not found in function arguments" in str(exc_info.value).lower() + + async def test_decorator_swallows_exception_from_wrapped_function_inside_except( + self, session, mock_job_manager, mock_worker_ctx + ): + with ( + patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, + patch.object(mock_job_manager, "start_job", return_value=None), + patch.object(mock_job_manager, "should_retry", return_value=False), + patch.object(mock_job_manager, "fail_job", side_effect=JobStateError("error in job fail")), + TransactionSpy.spy(session, expect_commit=True, expect_rollback=True), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): + mock_job_manager_class.return_value = mock_job_manager + result = await sample_raise(mock_worker_ctx, 999) + + # Should notify for internal and job error + assert mock_send_slack_error.call_count == 2 + # Errors within the main try block should take precedence + assert result["status"] == "exception" + assert str(result["exception"]) == "error in wrapped function" + + async def test_decorator_passes_job_manager_to_wrapped(self, session, mock_job_manager, mock_worker_ctx): + @with_job_management + async def assert_manager_passed_job(ctx, job_id: int, job_manager): + assert isinstance(job_manager, JobManager) + return {"status": "ok", "data": {}, "exception": None} + + with ( + patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, + patch.object(mock_job_manager, "start_job", return_value=None), + patch.object(mock_job_manager, "succeed_job", return_value=None), + TransactionSpy.spy(session, expect_commit=True), + ): + mock_job_manager_class.return_value = mock_job_manager + assert await assert_manager_passed_job(mock_worker_ctx, 999) + + +@pytest.mark.asyncio +@pytest.mark.integration +class TestManagedJobDecoratorIntegration: + """Integration tests for with_job_management decorator.""" + + async def test_decorator_integrated_job_lifecycle_success( + self, session, arq_redis, sample_job_run, standalone_worker_context, with_populated_job_data + ): + # Use an event to control when the job completes + event = asyncio.Event() + + @with_job_management + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + await event.wait() # Simulate async work, block until test signals + return {"status": "ok", "data": {}, "exception": None} + + # Start the job (it will block at event.wait()) + job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id)) + + # At this point, the job should be started but not completed + await asyncio.sleep(0.1) # Give the event loop a moment to start the job + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING + + # Now allow the job to complete + event.set() + await job_task + + # After completion, status should be SUCCEEDED + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.SUCCEEDED + + async def test_decorator_integrated_job_lifecycle_skipped( + self, session, arq_redis, sample_job_run, standalone_worker_context, with_populated_job_data + ): + @with_job_management + async def sample_job(ctx: dict, 
job_id: int, job_manager: JobManager): + return {"status": "skipped", "data": {}, "exception": None} + + # Run the job + await sample_job(standalone_worker_context, sample_job_run.id) + + # After completion, status should be SKIPPED + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.SKIPPED + + async def test_decorator_integrated_job_lifecycle_failed( + self, session, arq_redis, sample_job_run, standalone_worker_context, with_populated_job_data + ): + @with_job_management + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + return {"status": "failed", "data": {}, "exception": RuntimeError("Simulated job failure")} + + with patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error: + # Run the job + await sample_job(standalone_worker_context, sample_job_run.id) + + mock_send_slack_error.assert_called_once() + # After completion, status should be FAILED + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.FAILED + assert job.error_message == "Simulated job failure" + + async def test_decorator_integrated_job_lifecycle_raised_exception( + self, session, arq_redis, sample_job_run, standalone_worker_context, with_populated_job_data + ): + # Use an event to control when the job completes + event = asyncio.Event() + + @with_job_management + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + await event.wait() # Simulate async work, block until test signals + raise RuntimeError("Simulated job failure") + + # Start the job (it will block at event.wait()) + with patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error: + job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id)) + + # At this point, the job should be started but not in error + await asyncio.sleep(0.1) # Give the event loop a moment to start the job + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING + + # Now allow the job to complete with failure. This failure + # should be swallowed by the job_task. 
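+ # (with_job_management converts the raised exception into a failure result
+ # rather than re-raising it, which is why awaiting job_task does not raise.)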
+ event.set() + await job_task + + mock_send_slack_error.assert_called_once() + + # After failure, status should be FAILED + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.FAILED + assert job.error_message == "Simulated job failure" + + async def test_decorator_integrated_job_lifecycle_retry( + self, session, arq_redis, sample_job_run, standalone_worker_context, with_populated_job_data + ): + # Use an event to control when the job completes + event = asyncio.Event() + + @with_job_management + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + sample_job_run.failure_category = RETRYABLE_FAILURE_CATEGORIES[0] # Set a retryable failure category + await event.wait() # Simulate async work, block until test signals + raise RuntimeError("Simulated job failure for retry") + + with patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error: + # Start the job (it will block at event.wait()) + job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id)) + + # At this point, the job should be started but not in error + await asyncio.sleep(0.1) # Give the event loop a moment to start the job + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING + + # TODO: We patch `should_retry` to return True to force a retry scenario. After implementing failure + # categorization in the worker, this patch can be removed and we should directly test retry logic based + # on failure categories. + # + # Now allow the job to complete with failure that triggers a retry. This failure + # should be swallowed by the job_task. + with patch.object(JobManager, "should_retry", return_value=True): + event.set() + await job_task + + mock_send_slack_error.assert_called_once() + + # After failure with retry, status should be PENDING + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + assert job.retry_count == 1 # Ensure it attempted once before retrying diff --git a/tests/worker/lib/decorators/test_pipeline_management.py b/tests/worker/lib/decorators/test_pipeline_management.py new file mode 100644 index 00000000..45c7c3d2 --- /dev/null +++ b/tests/worker/lib/decorators/test_pipeline_management.py @@ -0,0 +1,553 @@ +# ruff : noqa: E402 + +""" +Unit tests for the with_pipeline_management async decorator. +Covers orchestration steps, error handling, and PipelineManager interaction. +""" + +import pytest + +pytest.importorskip("arq") # Skip tests if arq is not installed + +import asyncio +from unittest.mock import patch + +from sqlalchemy import select + +from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus +from mavedb.models.job_run import JobRun +from mavedb.models.pipeline import Pipeline +from mavedb.worker.lib.decorators.pipeline_management import with_pipeline_management +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.pipeline_manager import PipelineManager +from tests.helpers.transaction_spy import TransactionSpy + +pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + + +async def sample_job(ctx=None, job_id=None): + """Sample job function to test the decorator. When called, it patches + the with_job_management decorator to be a no-op so we can test the + with_pipeline_management decorator in isolation. 
+
+    NOTE: The job_manager parameter is normally injected by the with_job_management
+    decorator. Since we are patching that decorator to be a no-op here,
+    we do not include it in the function signature.
+
+    Args:
+        ctx (dict): Worker context dictionary.
+        job_id (int): ID of the JobRun record managed by the decorator.
+    """
+    # patch the with_job_management decorator to be a no-op
+    with patch(
+        "mavedb.worker.lib.decorators.pipeline_management.with_job_management", wraps=lambda f: f
+    ) as mock_job_mgmt:
+
+        @with_pipeline_management
+        async def patched_sample_job(ctx: dict, job_id: int):
+            return {"status": "ok"}
+
+        result = await patched_sample_job(ctx, job_id)
+
+    # Ensure the mock was called
+    mock_job_mgmt.assert_called_once()
+    return result
+
+
+async def sample_raise(ctx: dict, job_id: int):
+    """Sample job function to test the decorator when a job raises.
+    When called, it patches the with_job_management decorator to be
+    a no-op so we can test the with_pipeline_management decorator in isolation.
+
+    NOTE: The job_manager parameter is normally injected by the with_job_management
+    decorator. Since we are patching that decorator to be a no-op here,
+    we do not include it in the function signature.
+
+    Args:
+        ctx (dict): Worker context dictionary.
+        job_id (int): ID of the JobRun record managed by the decorator.
+    """
+    # patch the with_job_management decorator to be a no-op
+    with patch(
+        "mavedb.worker.lib.decorators.pipeline_management.with_job_management", wraps=lambda f: f
+    ) as mock_job_mgmt:
+
+        @with_pipeline_management
+        async def patched_sample_job(ctx: dict, job_id: int):
+            raise RuntimeError("error in wrapped function")
+
+        result = await patched_sample_job(ctx, job_id)
+
+    # Ensure the mock was called
+    mock_job_mgmt.assert_called_once()
+    return result
+
+
+@pytest.mark.asyncio
+@pytest.mark.unit
+class TestPipelineManagementDecoratorUnit:
+    """Unit tests for the with_pipeline_management decorator."""
+
+    async def test_decorator_must_receive_ctx_as_first_argument(self, mock_pipeline_manager):
+        with pytest.raises(ValueError) as exc_info, TransactionSpy.spy(mock_pipeline_manager.db):
+            await sample_job()
+
+        assert "Managed functions must receive context as first argument" in str(exc_info.value)
+
+    @pytest.mark.parametrize("missing_key", ["redis"])
+    async def test_decorator_raises_value_error_if_required_context_missing(
+        self, mock_pipeline_manager, mock_worker_ctx, missing_key
+    ):
+        del mock_worker_ctx[missing_key]
+
+        with (
+            pytest.raises(ValueError) as exc_info,
+            TransactionSpy.spy(mock_pipeline_manager.db),
+            patch("mavedb.worker.lib.decorators.pipeline_management.send_slack_error") as mock_send_slack_error,
+        ):
+            await sample_job(mock_worker_ctx, 999)
+
+        assert missing_key.replace("_", " ") in str(exc_info.value).lower()
+        assert "not found in pipeline context" in str(exc_info.value).lower()
+        mock_send_slack_error.assert_called_once()
+
+    async def test_decorator_raises_value_error_if_job_id_missing(self, mock_pipeline_manager, mock_worker_ctx):
+        # Remove job_id from args to simulate missing job_id
+        with (
+            pytest.raises(ValueError) as exc_info,
+            TransactionSpy.spy(mock_pipeline_manager.db),
+            patch("mavedb.worker.lib.decorators.pipeline_management.send_slack_error") as mock_send_slack_error,
+        ):
+            await sample_job(mock_worker_ctx)
+
+        assert "job id not found in function arguments" in str(exc_info.value).lower()
+        mock_send_slack_error.assert_called_once()
+
+    async def test_decorator_swallows_exception_if_cant_fetch_pipeline_id(
+        self, session, mock_pipeline_manager, mock_worker_ctx
+    ):
+        with (
+            TransactionSpy.mock_database_execution_failure(
+                session,
+                exception=ValueError("job id not found in pipeline context"),
+                expect_rollback=True,
+            ),
+            patch("mavedb.worker.lib.decorators.pipeline_management.send_slack_error") as mock_send_slack_error,
+        ):
+            await sample_job(mock_worker_ctx, 999)
+        mock_send_slack_error.assert_called_once()
+
+    async def test_decorator_fetches_pipeline_from_db_and_constructs_pipeline_manager(
+        self, session, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data
+    ):
+        with (
+            patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class,
+            patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None),
+            patch.object(mock_pipeline_manager, "start_pipeline", return_value=None),
+            TransactionSpy.spy(session, expect_commit=True),
+        ):
+            mock_pipeline_manager_class.return_value = mock_pipeline_manager
+            result = await sample_job(mock_worker_ctx, sample_job_run.id)
+
+        assert result == {"status": "ok"}
+
+    async def test_decorator_skips_coordination_and_start_when_no_pipeline_exists(
+        self, session, mock_pipeline_manager, mock_worker_ctx, sample_independent_job_run, with_populated_job_data
+    ):
+        with (
+            patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class,
+            patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate_pipeline,
+            patch.object(mock_pipeline_manager, "start_pipeline", return_value=None) as mock_start_pipeline,
+            # We shouldn't expect any commits since no pipeline coordination occurs
+            TransactionSpy.spy(session),
+        ):
+            mock_pipeline_manager_class.return_value = mock_pipeline_manager
+            result = await sample_job(mock_worker_ctx, sample_independent_job_run.id)
+
+        mock_coordinate_pipeline.assert_not_called()
+        mock_start_pipeline.assert_not_called()
+        assert result == {"status": "ok"}
+
+    async def test_decorator_starts_pipeline_when_in_created_state(
+        self, session, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data
+    ):
+        with (
+            patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class,
+            patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED),
+            patch.object(mock_pipeline_manager, "start_pipeline", return_value=None) as mock_start_pipeline,
+            patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None),
+            TransactionSpy.spy(session, expect_commit=True),
+        ):
+            mock_pipeline_manager_class.return_value = mock_pipeline_manager
+            result = await sample_job(mock_worker_ctx, sample_job_run.id)
+
+        mock_start_pipeline.assert_called_once()
+        assert result == {"status": "ok"}
+
+    @pytest.mark.parametrize(
+        "pipeline_state",
+        [status for status in PipelineStatus._member_map_.values() if status != PipelineStatus.CREATED],
+    )
+    async def test_decorator_does_not_start_pipeline_when_not_in_created_state(
+        self, session, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data, pipeline_state
+    ):
+        with (
+            patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class,
+            patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=pipeline_state),
+            patch.object(mock_pipeline_manager, "start_pipeline", return_value=None) as mock_start_pipeline,
+            patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None),
+            TransactionSpy.spy(session, expect_commit=True),
+        ):
+            mock_pipeline_manager_class.return_value = mock_pipeline_manager
+            result = await sample_job(mock_worker_ctx, sample_job_run.id)
+
+        mock_start_pipeline.assert_not_called()
+        assert result == {"status": "ok"}
+
+    async def test_decorator_calls_pipeline_manager_coordinate_pipeline_after_wrapped_function(
+        self, session, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data
+    ):
+        with (
+            patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class,
+            patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED),
+            patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate_pipeline,
+            patch.object(mock_pipeline_manager, "start_pipeline", return_value=None),
+            TransactionSpy.spy(session, expect_commit=True),
+        ):
+            mock_pipeline_manager_class.return_value = mock_pipeline_manager
+            await sample_job(mock_worker_ctx, sample_job_run.id)
+
+        mock_coordinate_pipeline.assert_called_once()
+
+    async def test_decorator_swallows_exception_from_wrapped_function(
+        self, session, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data
+    ):
+        with (
+            patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class,
+            patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None),
+            patch.object(mock_pipeline_manager, "start_pipeline", return_value=None),
+            patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED),
+            TransactionSpy.spy(session, expect_commit=True, expect_rollback=True),
+            patch("mavedb.worker.lib.decorators.pipeline_management.send_slack_error") as mock_send_slack_error,
+        ):
+            mock_pipeline_manager_class.return_value = mock_pipeline_manager
+            await sample_raise(mock_worker_ctx, sample_job_run.id)
+
+        mock_send_slack_error.assert_called_once()
+
+    async def test_decorator_swallows_exception_from_pipeline_manager_coordinate_pipeline(
+        self, session, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data
+    ):
+        with (
+            patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class,
+            patch.object(
+                mock_pipeline_manager,
+                "coordinate_pipeline",
+                side_effect=RuntimeError("error in coordinate_pipeline"),
+            ),
+            patch.object(mock_pipeline_manager, "start_pipeline", return_value=None),
+            patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED),
+            # Exception raised from coordinate_pipeline should trigger rollback,
+            # and commit will be called when pipeline status is set to running
+            TransactionSpy.spy(session, expect_commit=True, expect_rollback=True),
+            patch("mavedb.worker.lib.decorators.pipeline_management.send_slack_error") as mock_send_slack_error,
+        ):
+            mock_pipeline_manager_class.return_value = mock_pipeline_manager
+            await sample_job(mock_worker_ctx, sample_job_run.id)
+
+        assert mock_send_slack_error.call_count == 2
+
+    async def test_decorator_swallows_exception_from_job_management_decorator(
+        self, session, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data
+    ):
+        def passthrough_decorator(f):
+            return f
+
+        with (
+            # patch the with_job_management decorator to raise an error
+            patch(
+                "mavedb.worker.lib.decorators.pipeline_management.with_job_management",
+                wraps=passthrough_decorator,
+                side_effect=ValueError("error in job management decorator"),
+            ) as mock_with_job_mgmt,
+            patch.object(mock_pipeline_manager, "start_pipeline", return_value=None),
+            patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED),
+            patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None),
+            patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class,
+            TransactionSpy.spy(session, expect_commit=True, expect_rollback=True),
+            patch("mavedb.worker.lib.decorators.pipeline_management.send_slack_error") as mock_send_slack_error,
+        ):
+            mock_pipeline_manager_class.return_value = mock_pipeline_manager
+
+            @with_pipeline_management
+            async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager):
+                return {"status": "ok"}
+
+            await sample_job(mock_worker_ctx, sample_job_run.id, pipeline_manager=mock_pipeline_manager)
+
+        mock_with_job_mgmt.assert_called_once()
+        mock_send_slack_error.assert_called_once()
+
+
+@pytest.mark.asyncio
+@pytest.mark.integration
+class TestPipelineManagementDecoratorIntegration:
+    """Integration tests for the with_pipeline_management decorator."""
+
+    @pytest.mark.parametrize("initial_status", [PipelineStatus.CREATED, PipelineStatus.RUNNING])
+    async def test_decorator_integrated_pipeline_lifecycle_success(
+        self,
+        session,
+        arq_redis,
+        sample_job_run,
+        sample_dependent_job_run,
+        standalone_worker_context,
+        with_populated_job_data,
+        sample_pipeline,
+        initial_status,
+    ):
+        # Use an event to control when the job completes
+        event = asyncio.Event()
+        dep_event = asyncio.Event()
+
+        # Set initial pipeline status to the parameterized value.
+        # This allows testing both CREATED and RUNNING start states.
+        sample_pipeline.status = initial_status
+        session.commit()
+
+        @with_pipeline_management
+        async def sample_job(ctx: dict, job_id: int, job_manager: JobManager):
+            await event.wait()  # Simulate async work, block until test signals
+            return {"status": "ok", "data": {}, "exception": None}
+
+        @with_pipeline_management
+        async def sample_dependent_job(ctx: dict, job_id: int, job_manager: JobManager):
+            await dep_event.wait()  # Simulate async work, block until test signals
+            return {"status": "ok", "data": {}, "exception": None}
+
+        # Start the job (it will block at event.wait())
+        job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id))
+
+        # At this point, the job should be started but not completed
+        await asyncio.sleep(0.1)  # Give the event loop a moment to start the job
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.RUNNING
+
+        pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one()
+        assert pipeline.status == PipelineStatus.RUNNING
+
+        # Now allow the job to complete and flush the Redis queue. Flush the queue first to ensure
+        # we don't mistakenly flush our queued job.
+        await arq_redis.flushdb()
+        event.set()
+        await job_task
+
+        # After completion, status should be SUCCEEDED
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.SUCCEEDED
+
+        pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one()
+
+        # Pipeline remains RUNNING after job success because another job was queued.
+        assert pipeline.status == PipelineStatus.RUNNING
+
+        queued_jobs = await arq_redis.queued_jobs()
+        assert len(queued_jobs) == 1  # Ensure the next job was queued
+
+        # Simulate execution of next job by running the dependent job.
+        # Start the job (it will block at event.wait())
+        dependent_job_task = asyncio.create_task(
+            sample_dependent_job(standalone_worker_context, sample_dependent_job_run.id)
+        )
+
+        # At this point, the job should be started but not completed
+        await asyncio.sleep(0.1)  # Give the event loop a moment to start the job
+        job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one()
+        assert job.status == JobStatus.RUNNING
+
+        pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one()
+        assert pipeline.status == PipelineStatus.RUNNING
+
+        # Now allow the job to complete and flush the Redis queue. Flush the queue first to ensure
+        # we don't mistakenly flush our queued job.
+        await arq_redis.flushdb()
+        dep_event.set()
+        await dependent_job_task
+
+        # After completion, status should be SUCCEEDED
+        job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one()
+        assert job.status == JobStatus.SUCCEEDED
+
+        # Now that all jobs are complete, the pipeline should be SUCCEEDED
+        pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one()
+        assert pipeline.status == PipelineStatus.SUCCEEDED
+
+        # No further jobs should be queued
+        queued_jobs = await arq_redis.queued_jobs()
+        assert len(queued_jobs) == 0
+
+    async def test_decorator_integrated_pipeline_lifecycle_retryable_failure(
+        self,
+        session,
+        arq_redis,
+        sample_job_run,
+        sample_dependent_job_run,
+        standalone_worker_context,
+        with_populated_job_data,
+        sample_pipeline,
+    ):
+        # Use an event to control when the job completes
+        event = asyncio.Event()
+        retry_event = asyncio.Event()
+        dep_event = asyncio.Event()
+
+        @with_pipeline_management
+        async def sample_job(ctx: dict, job_id: int, job_manager: JobManager):
+            await event.wait()  # Simulate async work, block until test signals
+            raise RuntimeError("Simulated job failure for retry")
+
+        @with_pipeline_management
+        async def sample_retried_job(ctx: dict, job_id: int, job_manager: JobManager):
+            await retry_event.wait()  # Simulate async work, block until test signals
+            return {"status": "ok", "data": {}, "exception": None}
+
+        @with_pipeline_management
+        async def sample_dependent_job(ctx: dict, job_id: int, job_manager: JobManager):
+            await dep_event.wait()  # Simulate async work, block until test signals
+            return {"status": "ok", "data": {}, "exception": None}
+
+        # job management handles slack alerting in this context
+        with patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error:
+            # Start the job (it will block at event.wait())
+            job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id))
+
+            # At this point, the job should be started but not completed
+            await asyncio.sleep(0.1)  # Give the event loop a moment to start the job
+            job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+            assert job.status == JobStatus.RUNNING
+
+            pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one()
+            assert pipeline.status == PipelineStatus.RUNNING
+
+            # Now allow the job to complete with failure that triggers a retry. This failure
+            # should be swallowed by the job_task.
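+            # (Forcing should_retry to True makes the manager requeue the job rather than fail
+            # the pipeline; the QUEUED status and retry_count assertions below depend on this.)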
+            with patch.object(JobManager, "should_retry", return_value=True):
+                event.set()
+                await job_task
+
+            mock_send_slack_error.assert_called_once()
+
+            # After failure with retry, status should be QUEUED
+            job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+            assert job.status == JobStatus.QUEUED
+            assert job.retry_count == 1  # Ensure it attempted once before retrying
+
+            # Now start the retried job (it will block at retry_event.wait())
+            retried_job_task = asyncio.create_task(sample_retried_job(standalone_worker_context, sample_job_run.id))
+            await asyncio.sleep(0.1)  # Give the event loop a moment to start the job
+            job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+            assert job.status == JobStatus.RUNNING
+
+            # The pipeline should remain running
+            pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one()
+            assert pipeline.status == PipelineStatus.RUNNING
+
+            # Now allow the retried job to complete successfully
+            await arq_redis.flushdb()
+            retry_event.set()
+            await retried_job_task
+
+            # After completion, status should be SUCCEEDED
+            job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+            assert job.status == JobStatus.SUCCEEDED
+
+            pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one()
+            assert pipeline.status == PipelineStatus.RUNNING
+
+            queued_jobs = await arq_redis.queued_jobs()
+            assert len(queued_jobs) == 1  # Ensure the next job was queued
+
+            # Simulate execution of next job by running the dependent job.
+            # Start the job (it will block at event.wait())
+            dependent_job_task = asyncio.create_task(
+                sample_dependent_job(standalone_worker_context, sample_dependent_job_run.id)
+            )
+
+            # At this point, the job should be started but not completed
+            await asyncio.sleep(0.1)  # Give the event loop a moment to start the job
+            job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one()
+            assert job.status == JobStatus.RUNNING
+
+            pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one()
+            assert pipeline.status == PipelineStatus.RUNNING
+
+            # Now allow the job to complete and flush the Redis queue. Flush the queue first to ensure
+            # we don't mistakenly flush our queued job.
+            await arq_redis.flushdb()
+            dep_event.set()
+            await dependent_job_task
+
+            # After completion, status should be SUCCEEDED
+            job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one()
+            assert job.status == JobStatus.SUCCEEDED
+
+            # Now that all jobs are complete, the pipeline should be SUCCEEDED
+            pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one()
+            assert pipeline.status == PipelineStatus.SUCCEEDED
+
+            queued_jobs = await arq_redis.queued_jobs()
+            assert len(queued_jobs) == 0  # Ensure no further jobs were queued
+
+    async def test_decorator_integrated_pipeline_lifecycle_non_retryable_failure(
+        self,
+        session,
+        arq_redis,
+        sample_job_run,
+        sample_dependent_job_run,
+        standalone_worker_context,
+        with_populated_job_data,
+        sample_pipeline,
+    ):
+        # Use an event to control when the job completes
+        event = asyncio.Event()
+
+        @with_pipeline_management
+        async def sample_job(ctx: dict, job_id: int, job_manager: JobManager):
+            await event.wait()  # Simulate async work, block until test signals
+            raise RuntimeError("Simulated job failure")
+
+        # job management handles slack alerting in this context
+        with patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error:
+            # Start the job (it will block at event.wait())
+            job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id))
+
+            # At this point, the job should be started but not completed
+            await asyncio.sleep(0.1)  # Give the event loop a moment to start the job
+            job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+            assert job.status == JobStatus.RUNNING
+
+            pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one()
+            assert pipeline.status == PipelineStatus.RUNNING
+
+            # Now allow the job to complete with failure and flush the Redis queue. This failure
+            # should be swallowed by the pipeline manager
+            await arq_redis.flushdb()
+            event.set()
+            await job_task
+
+            mock_send_slack_error.assert_called_once()
+
+        # After failure with no retry, status should be FAILED
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.FAILED
+
+        pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one()
+
+        # Pipeline should be marked FAILED after job failure
+        assert pipeline.status == PipelineStatus.FAILED
+
+        # No further jobs should be queued
+        queued_jobs = await arq_redis.queued_jobs()
+        assert len(queued_jobs) == 0
+
+        # Dependent job should transition to skipped since it was never queued
+        job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one()
+        assert job.status == JobStatus.SKIPPED
diff --git a/tests/worker/lib/managers/test_base_manager.py b/tests/worker/lib/managers/test_base_manager.py
new file mode 100644
index 00000000..7f5c3a91
--- /dev/null
+++ b/tests/worker/lib/managers/test_base_manager.py
@@ -0,0 +1,19 @@
+# ruff: noqa: E402
+import pytest
+
+pytest.importorskip("arq")
+
+from mavedb.worker.lib.managers.base_manager import BaseManager
+
+
+@pytest.mark.integration
+class TestInitialization:
+    """Tests for BaseManager initialization."""
+
+    def test_initialization(self, session, arq_redis):
+        """Test that BaseManager initializes with db and redis attributes."""
+
+        manager = BaseManager(db=session, redis=arq_redis)
+
+        assert manager.db == session
+        assert manager.redis == arq_redis
diff --git a/tests/worker/lib/managers/test_job_manager.py b/tests/worker/lib/managers/test_job_manager.py
new file mode 100644
index 00000000..ad6b6ef1
--- /dev/null
+++ b/tests/worker/lib/managers/test_job_manager.py
@@ -0,0 +1,2190 @@
+# ruff: noqa: E402
+"""
+Comprehensive test suite for JobManager class.
+
+Tests cover all aspects of job lifecycle management, pipeline coordination,
+error handling, and database interactions.
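+
+Unit tests exercise JobManager against mocked JobRun objects, while integration tests run
+against a real database session and Redis connection.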
+""" + +import pytest + +pytest.importorskip("arq") + +import re +from unittest.mock import Mock, PropertyMock, patch + +from arq import ArqRedis +from sqlalchemy import select +from sqlalchemy.orm import Session + +from mavedb.lib.logging.context import format_raised_exception_info_as_dict +from mavedb.models.enums.job_pipeline import FailureCategory, JobStatus +from mavedb.models.job_run import JobRun +from mavedb.worker.lib.managers.constants import ( + CANCELLED_JOB_STATUSES, + RETRYABLE_FAILURE_CATEGORIES, + RETRYABLE_JOB_STATUSES, + STARTABLE_JOB_STATUSES, + TERMINAL_JOB_STATUSES, +) +from mavedb.worker.lib.managers.exceptions import ( + DatabaseConnectionError, + JobStateError, + JobTransitionError, +) +from mavedb.worker.lib.managers.job_manager import JobManager +from tests.helpers.transaction_spy import TransactionSpy + +HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION = ( + AttributeError("Mock attribute error"), + KeyError("Mock key error"), + TypeError("Mock type error"), + ValueError("Mock value error"), +) + + +@pytest.mark.integration +class TestJobManagerInitialization: + """Test JobManager initialization and setup.""" + + def test_init_with_valid_job(self, session, arq_redis, with_populated_job_data, sample_job_run): + """Test successful initialization with valid job ID.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + assert manager.db == session + assert manager.job_id == sample_job_run.id + assert manager.pipeline_id == sample_job_run.pipeline_id + + def test_init_with_no_pipeline(self, session, arq_redis, with_populated_job_data, sample_independent_job_run): + """Test initialization with job that has no pipeline.""" + manager = JobManager(session, arq_redis, sample_independent_job_run.id) + + assert manager.job_id == sample_independent_job_run.id + assert manager.pipeline_id is None + + def test_init_with_invalid_job_id(self, session, arq_redis): + """Test initialization failure with non-existent job ID.""" + job_id = 999 # Assuming this ID does not exist + with pytest.raises(DatabaseConnectionError, match=f"Failed to fetch job {job_id}"): + JobManager(session, arq_redis, job_id) + + +@pytest.mark.unit +class TestJobStartUnit: + """Unit tests for job start lifecycle management.""" + + @pytest.mark.parametrize( + "invalid_status", + [status for status in JobStatus._member_map_.values() if status not in STARTABLE_JOB_STATUSES], + ) + def test_start_job_raises_job_transition_error_when_managed_job_has_unstartable_status( + self, mock_job_manager, invalid_status, mock_job_run + ): + # Set initial job status to an invalid (unstartable) status. + mock_job_run.status = invalid_status + + # Start job. Verify a JobTransitionError is raised due to invalid state in the mocked + # job run. Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely. + with ( + pytest.raises( + JobTransitionError, + match=f"Cannot start job {mock_job_manager.job_id} from status {invalid_status}", + ), + TransactionSpy.spy(mock_job_manager.db), + ): + mock_job_manager.start_job() + + # Verify job state on the mocked object remains unchanged. 
+        assert mock_job_run.status == invalid_status
+        assert mock_job_run.started_at is None
+        assert mock_job_run.progress_message is None
+
+    @pytest.mark.parametrize(
+        "exception",
+        HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION,
+    )
+    @pytest.mark.parametrize(
+        "valid_status",
+        [status for status in JobStatus._member_map_.values() if status in STARTABLE_JOB_STATUSES],
+    )
+    def test_start_job_raises_job_state_error_when_handled_error_is_raised_during_object_manipulation(
+        self, mock_job_manager, exception, mock_job_run, valid_status
+    ):
+        """Test job start failure due to exception during job object manipulation."""
+        # Set initial job status to a valid status. Job status must be startable for this test.
+        mock_job_run.status = valid_status
+
+        # Trigger: if the status property is set, raise the exception; if it is only read, return the
+        # parametrized startable status.
+        def get_or_error(*args):
+            if args:
+                raise exception
+            return valid_status
+
+        # Start job. Verify a JobStateError is raised by our trigger.
+        # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with (
+            TransactionSpy.spy(mock_job_manager.db),
+            pytest.raises(JobStateError, match="Failed to update job start state"),
+        ):
+            type(mock_job_run).status = PropertyMock(side_effect=get_or_error)
+            mock_job_manager.start_job()
+
+        # Verify job state on the mocked object remains unchanged. Although it's theoretically
+        # possible some job state is manipulated prior to an error being raised, our specific
+        # trigger should prevent any changes from being made.
+        assert mock_job_run.status == valid_status
+        assert mock_job_run.started_at is None
+        assert mock_job_run.progress_message is None
+
+    @pytest.mark.parametrize(
+        "valid_status",
+        [status for status in JobStatus._member_map_.values() if status in STARTABLE_JOB_STATUSES],
+    )
+    def test_start_job_success(self, mock_job_manager, mock_job_run, valid_status):
+        """Test successful job start."""
+        # Set initial job status to a valid status. Job status must be startable for this test.
+        mock_job_run.status = valid_status
+
+        # Start job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(mock_job_manager.db):
+            mock_job_manager.start_job()
+
+        # Verify job state was updated on our mock object with expected values.
+        # These changes would normally be persisted by the caller after this method returns.
+        assert mock_job_run.status == JobStatus.RUNNING
+        assert mock_job_run.started_at is not None
+        assert mock_job_run.progress_message == "Job began execution"
+
+
+@pytest.mark.integration
+class TestJobStartIntegration:
+    """Integration tests for job start lifecycle management."""
+
+    @pytest.mark.parametrize(
+        "invalid_status",
+        [status for status in JobStatus._member_map_.values() if status not in STARTABLE_JOB_STATUSES],
+    )
+    def test_job_exception_is_raised_when_job_has_invalid_status(
+        self, session, arq_redis, with_populated_job_data, sample_job_run, invalid_status
+    ):
+        """Test job start failure due to invalid job status."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Manually set job to invalid status and commit changes.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        job.status = invalid_status
+        session.commit()
+
+        # Start job. Verify a JobTransitionError is raised due to the previously set invalid state.
+        # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        # Although the job might still set some attributes before the error is raised, the exception
+        # indicates to the caller that the job was not started successfully and the transaction should be rolled back.
+        with (
+            TransactionSpy.spy(manager.db),
+            pytest.raises(
+                JobTransitionError,
+                match=f"Cannot start job {sample_job_run.id} from status {invalid_status.value}",
+            ),
+        ):
+            manager.start_job()
+
+    @pytest.mark.parametrize(
+        "valid_status",
+        [status for status in JobStatus._member_map_.values() if status in STARTABLE_JOB_STATUSES],
+    )
+    def test_job_updated_successfully(self, session, arq_redis, with_populated_job_data, sample_job_run, valid_status):
+        """Test successful job start."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Manually set job to a startable status and commit changes.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        job.status = valid_status
+        session.commit()
+
+        # Start job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(manager.db):
+            manager.start_job()
+
+        # Commit pending changes made by start_job.
+        session.commit()
+
+        # Verify job state was updated in transaction with expected values.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.RUNNING
+        assert job.started_at is not None
+        assert job.progress_message == "Job began execution"
+
+
+@pytest.mark.unit
+class TestJobCompletionUnit:
+    """Unit tests for job completion lifecycle management."""
+
+    @pytest.mark.parametrize(
+        "invalid_status",
+        [status for status in JobStatus._member_map_.values() if status not in TERMINAL_JOB_STATUSES],
+    )
+    def test_complete_job_raises_job_transition_error_when_managed_job_has_non_terminal_status(
+        self, mock_job_manager, mock_job_run, invalid_status
+    ):
+        # Set initial job status to an invalid (non-terminal) status.
+        mock_job_run.status = invalid_status
+
+        # Complete job. Verify a JobTransitionError is raised due to invalid state in the mocked
+        # job run. Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with (
+            pytest.raises(
+                JobTransitionError,
+                match=re.escape(
+                    f"Cannot commplete job to status: {invalid_status}. Must complete to a terminal status: {TERMINAL_JOB_STATUSES}"
+                ),
+            ),
+            TransactionSpy.spy(mock_job_manager.db),
+        ):
+            mock_job_manager.complete_job(status=invalid_status, result={})
+
+        # Verify job state on the mocked object remains unchanged.
+        assert mock_job_run.status == invalid_status
+        assert mock_job_run.finished_at is None
+        assert mock_job_run.metadata_ == {}
+        assert mock_job_run.progress_message is None
+        assert mock_job_run.error_message is None
+        assert mock_job_run.error_traceback is None
+        assert mock_job_run.failure_category is None
+
+    @pytest.mark.parametrize(
+        "exception",
+        HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION,
+    )
+    @pytest.mark.parametrize(
+        "valid_status",
+        [status for status in JobStatus._member_map_.values() if status in TERMINAL_JOB_STATUSES],
+    )
+    def test_complete_job_raises_job_state_error_when_handled_error_is_raised_during_object_manipulation(
+        self, mock_job_manager, mock_job_run, exception, valid_status
+    ):
+        """Test job completion failure due to exception during job object manipulation."""
+        # Trigger: if the status property is set, raise the exception; if it is only read, return the mock
+        # object's original status (the starting job status doesn't matter for this test).
+        base_status = mock_job_run.status
+
+        def get_or_error(*args):
+            if args:
+                raise exception
+            return base_status
+
+        # Complete job. Verify a JobStateError is raised by our trigger.
+        # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with (
+            pytest.raises(JobStateError, match="Failed to update job completion state"),
+            TransactionSpy.spy(mock_job_manager.db),
+        ):
+            type(mock_job_run).status = PropertyMock(side_effect=get_or_error)
+            mock_job_manager.complete_job(status=valid_status, result={})
+
+        # Verify job state on the mocked object remains unchanged. Although it's theoretically
+        # possible some job state is manipulated prior to an error being raised, our specific
+        # trigger should prevent any changes from being made.
+        assert mock_job_run.status == base_status
+        assert mock_job_run.finished_at is None
+        assert mock_job_run.metadata_ == {}
+        assert mock_job_run.progress_message is None
+        assert mock_job_run.error_message is None
+        assert mock_job_run.error_traceback is None
+        assert mock_job_run.failure_category is None
+
+    def test_complete_job_sets_default_failure_category_when_job_failed(self, mock_job_manager, mock_job_run):
+        """Test job completion sets default failure category when job failed without error."""
+
+        # Complete job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(mock_job_manager.db):
+            mock_job_manager.complete_job(
+                status=JobStatus.FAILED, result={"status": "failed", "data": {}, "exception": Exception()}
+            )
+
+        # Verify job state was updated on our mock object with expected values.
+        assert mock_job_run.status == JobStatus.FAILED
+        assert mock_job_run.finished_at is not None
+        assert mock_job_run.metadata_ == {
+            "result": {
+                "status": "failed",
+                "data": {},
+                "exception_details": format_raised_exception_info_as_dict(Exception()),
+            }
+        }
+        assert mock_job_run.error_message is None
+        assert mock_job_run.error_traceback is None
+        assert mock_job_run.failure_category == FailureCategory.UNKNOWN
+
+    @pytest.mark.parametrize(
+        "valid_status",
+        [status for status in JobStatus._member_map_.values() if status in TERMINAL_JOB_STATUSES],
+    )
+    @pytest.mark.parametrize(
+        "exception",
+        [ValueError("Test error"), None],
+    )
+    def test_complete_job_success(self, mock_job_manager, valid_status, exception, mock_job_run):
+        """Test successful job completion."""
+
+        # Complete job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(mock_job_manager.db):
+            mock_job_manager.complete_job(
+                status=valid_status,
+                result={"status": "ok", "data": {"output": "test"}, "exception": exception},
+                error=exception,
+            )
+
+        # Verify job state was updated on our mock object with expected values.
+        assert mock_job_run.status == valid_status
+        assert mock_job_run.finished_at is not None
+        assert mock_job_run.metadata_["result"] == {
+            "status": "ok",
+            "data": {"output": "test"},
+            "exception_details": format_raised_exception_info_as_dict(exception) if exception else None,
+        }
+
+        # If an exception was provided, verify error fields are set appropriately.
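+        # (error_message mirrors str(exception) and a traceback is recorded; failure_category
+        # defaults to UNKNOWN when an error accompanies the completion.)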
+        if exception:
+            assert mock_job_run.error_message == str(exception)
+            assert mock_job_run.error_traceback is not None
+            assert mock_job_run.failure_category == FailureCategory.UNKNOWN
+
+        else:
+            assert mock_job_run.error_message is None
+            assert mock_job_run.error_traceback is None
+
+        # Proper handling of failure category only applies to FAILED status. See
+        # test_complete_job_sets_default_failure_category_when_job_failed for that case.
+
+
+@pytest.mark.integration
+class TestJobCompletionIntegration:
+    """Test job completion lifecycle management."""
+
+    @pytest.mark.parametrize(
+        "invalid_status",
+        [status for status in JobStatus._member_map_.values() if status not in TERMINAL_JOB_STATUSES],
+    )
+    def test_job_exception_is_raised_when_job_has_invalid_status(
+        self, session, arq_redis, with_populated_job_data, sample_job_run, invalid_status
+    ):
+        """Test job completion failure due to invalid job status."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Complete job. Verify a JobTransitionError is raised due to the passed invalid state.
+        # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        # Although the job might still set some attributes before the error is raised, the exception
+        # indicates to the caller that the job was not completed successfully and the transaction should be rolled back.
+        with (
+            TransactionSpy.spy(manager.db),
+            pytest.raises(
+                JobTransitionError,
+                match=re.escape(
+                    f"Cannot commplete job to status: {invalid_status}. Must complete to a terminal status: {TERMINAL_JOB_STATUSES}"
+                ),
+            ),
+        ):
+            manager.complete_job(status=invalid_status, result={"output": "test"})
+
+    @pytest.mark.parametrize(
+        "valid_status",
+        [status for status in JobStatus._member_map_.values() if status in TERMINAL_JOB_STATUSES],
+    )
+    def test_job_updated_successfully_without_error(
+        self, session, arq_redis, with_populated_job_data, sample_job_run, valid_status
+    ):
+        """Test successful job completion."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Complete job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(manager.db):
+            manager.complete_job(
+                status=valid_status, result={"status": "ok", "data": {"output": "test"}, "exception": None}
+            )
+
+        # Flush pending changes made by complete_job.
+        session.flush()
+
+        # Verify job state was updated in transaction with expected values.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+
+        assert job.status == valid_status
+        assert job.finished_at is not None
+        assert job.metadata_ == {"result": {"status": "ok", "data": {"output": "test"}, "exception_details": None}}
+        assert job.error_message is None
+        assert job.error_traceback is None
+
+        # For cases where no error is provided, verify failure category is set appropriately based
+        # on status. We automatically set UNKNOWN for FAILED status if no error is given.
+        if valid_status == JobStatus.FAILED:
+            assert job.failure_category == FailureCategory.UNKNOWN
+        else:
+            assert job.failure_category is None
+
+    @pytest.mark.parametrize(
+        "valid_status",
+        [status for status in JobStatus._member_map_.values() if status in TERMINAL_JOB_STATUSES],
+    )
+    def test_job_updated_successfully_with_error(
+        self, session, arq_redis, with_populated_job_data, sample_job_run, valid_status
+    ):
+        """Test successful job completion with an error attached."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Complete job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(manager.db):
+            manager.complete_job(
+                status=valid_status,
+                result={
+                    "status": "ok",
+                    "data": {"output": "test"},
+                    "exception": ValueError("Test error"),
+                },
+                error=ValueError("Test error"),
+            )
+
+        # Flush pending changes made by complete_job.
+        session.flush()
+
+        # Verify job state was updated in transaction with expected values.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+
+        assert job.status == valid_status
+        assert job.finished_at is not None
+        assert job.metadata_ == {
+            "result": {
+                "status": "ok",
+                "data": {"output": "test"},
+                "exception_details": format_raised_exception_info_as_dict(ValueError("Test error")),
+            }
+        }
+        assert job.error_message == "Test error"
+        assert job.error_traceback is not None
+        assert job.failure_category == FailureCategory.UNKNOWN
+
+
+@pytest.mark.unit
+class TestJobFailureUnit:
+    """Unit tests for job failure lifecycle management."""
+
+    def test_fail_job_success(self, mock_job_manager, mock_job_run):
+        """Test that fail_job calls complete_job with status=JobStatus.FAILED."""
+
+        # Fail job with a test exception. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        # This convenience expects an exception to be provided. To fail a job without an exception, callers should use complete_job directly.
+        test_exception = Exception("Test exception")
+        with (
+            patch.object(mock_job_manager, "complete_job", wraps=mock_job_manager.complete_job) as mock_complete_job,
+            TransactionSpy.spy(mock_job_manager.db),
+        ):
+            mock_job_manager.fail_job(
+                error=test_exception,
+                result={"status": "failed", "data": {"output": "test"}, "exception": test_exception},
+            )
+
+        # Verify this function is a thin wrapper around complete_job with expected parameters.
+        mock_complete_job.assert_called_once_with(
+            status=JobStatus.FAILED,
+            result={"status": "failed", "data": {"output": "test"}, "exception": test_exception},
+            error=test_exception,
+        )
+
+        # Verify job state was updated on our mock object with expected values.
+        assert mock_job_run.status == JobStatus.FAILED
+        assert mock_job_run.finished_at is not None
+        assert mock_job_run.metadata_ == {
+            "result": {
+                "status": "failed",
+                "data": {"output": "test"},
+                "exception_details": format_raised_exception_info_as_dict(test_exception),
+            }
+        }
+        assert mock_job_run.error_message == str(test_exception)
+        assert mock_job_run.error_traceback is not None
+        assert mock_job_run.failure_category == FailureCategory.UNKNOWN
+
+
+@pytest.mark.integration
+class TestJobFailureIntegration:
+    """Test job failure lifecycle management."""
+
+    def test_job_updated_successfully(self, session, arq_redis, with_populated_job_data, sample_job_run):
+        """Test successful job failure."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Fail job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        exc = ValueError("Test error")
+        with TransactionSpy.spy(manager.db):
+            manager.fail_job(result={"status": "failed", "data": {}, "exception": exc}, error=exc)
+
+        # Flush pending changes made by fail_job.
+        session.flush()
+
+        # Verify job state was updated in transaction with expected values.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+
+        assert job.status == JobStatus.FAILED
+        assert job.finished_at is not None
+        assert job.metadata_ == {
+            "result": {"status": "failed", "data": {}, "exception_details": format_raised_exception_info_as_dict(exc)}
+        }
+        assert job.error_message == "Test error"
+        assert job.error_traceback is not None
+        assert job.failure_category == FailureCategory.UNKNOWN
+
+
+@pytest.mark.unit
+class TestJobSuccessUnit:
+    """Unit tests for job success lifecycle management."""
+
+    def test_succeed_job_success(self, mock_job_manager, mock_job_run):
+        """Test that succeed_job calls complete_job with status=JobStatus.SUCCEEDED."""
+
+        # Succeed job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with (
+            patch.object(mock_job_manager, "complete_job", wraps=mock_job_manager.complete_job) as mock_complete_job,
+            TransactionSpy.spy(mock_job_manager.db),
+        ):
+            mock_job_manager.succeed_job(result={"status": "ok", "data": {"output": "test"}, "exception": None})
+
+        # Verify this function is a thin wrapper around complete_job with expected parameters.
+        mock_complete_job.assert_called_once_with(
+            status=JobStatus.SUCCEEDED, result={"status": "ok", "data": {"output": "test"}, "exception": None}
+        )
+
+        # Verify job state was updated on our mock object with expected values.
+        assert mock_job_run.status == JobStatus.SUCCEEDED
+        assert mock_job_run.finished_at is not None
+        assert mock_job_run.metadata_ == {
+            "result": {"status": "ok", "data": {"output": "test"}, "exception_details": None}
+        }
+        assert mock_job_run.error_message is None
+        assert mock_job_run.error_traceback is None
+        assert mock_job_run.failure_category is None
+
+
+@pytest.mark.integration
+class TestJobSuccessIntegration:
+    """Test job success lifecycle management."""
+
+    def test_job_updated_successfully(self, session, arq_redis, with_populated_job_data, sample_job_run):
+        """Test successful job succeeding."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Succeed job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(manager.db):
+            manager.succeed_job(result={"status": "ok", "data": {"output": "test"}, "exception": None})
+
+        # Flush pending changes made by succeed_job.
+        session.flush()
+
+        # Verify job state was updated in transaction with expected values.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+
+        assert job.status == JobStatus.SUCCEEDED
+        assert job.finished_at is not None
+        assert job.metadata_ == {"result": {"status": "ok", "data": {"output": "test"}, "exception_details": None}}
+        assert job.error_message is None
+        assert job.error_traceback is None
+        assert job.failure_category is None
+
+
+@pytest.mark.unit
+class TestJobCancellationUnit:
+    """Unit tests for job cancellation lifecycle management."""
+
+    def test_cancel_job_success(self, mock_job_manager, mock_job_run):
+        """Test that cancel_job calls complete_job with status=JobStatus.CANCELLED."""
+
+        # Cancel job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with (
+            patch.object(mock_job_manager, "complete_job", wraps=mock_job_manager.complete_job) as mock_complete_job,
+            TransactionSpy.spy(mock_job_manager.db),
+        ):
+            mock_job_manager.cancel_job(result={"status": "ok", "data": {"output": "test"}, "exception": None})
+
+        # Verify this function is a thin wrapper around complete_job with expected parameters.
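+        # (succeed_job, fail_job, cancel_job, and skip_job all delegate to complete_job and
+        # differ only in the terminal status they pass.)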
+        mock_complete_job.assert_called_once_with(
+            status=JobStatus.CANCELLED, result={"status": "ok", "data": {"output": "test"}, "exception": None}
+        )
+
+        # Verify job state was updated on our mock object with expected values.
+        assert mock_job_run.status == JobStatus.CANCELLED
+        assert mock_job_run.finished_at is not None
+        assert mock_job_run.metadata_ == {
+            "result": {"status": "ok", "data": {"output": "test"}, "exception_details": None}
+        }
+        assert mock_job_run.error_message is None
+        assert mock_job_run.error_traceback is None
+        assert mock_job_run.failure_category is None
+
+
+@pytest.mark.integration
+class TestJobCancellationIntegration:
+    """Test job cancellation lifecycle management."""
+
+    def test_job_updated_successfully(self, session, arq_redis, with_populated_job_data, sample_job_run):
+        """Test successful job cancellation."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Cancel job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(manager.db):
+            manager.cancel_job(result={"status": "ok", "data": {"output": "test"}, "exception": None})
+
+        # Flush pending changes made by cancel_job.
+        session.flush()
+
+        # Verify job state was updated in transaction with expected values.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+
+        assert job.status == JobStatus.CANCELLED
+        assert job.finished_at is not None
+        assert job.metadata_ == {"result": {"status": "ok", "data": {"output": "test"}, "exception_details": None}}
+        assert job.error_message is None
+        assert job.error_traceback is None
+        assert job.failure_category is None
+
+
+@pytest.mark.unit
+class TestJobSkipUnit:
+    """Unit tests for job skip lifecycle management."""
+
+    def test_skip_job_success(self, mock_job_manager, mock_job_run):
+        """Test that skip_job calls complete_job with status=JobStatus.SKIPPED."""
+
+        # Skip job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with (
+            patch.object(mock_job_manager, "complete_job", wraps=mock_job_manager.complete_job) as mock_complete_job,
+            TransactionSpy.spy(mock_job_manager.db),
+        ):
+            mock_job_manager.skip_job(result={"status": "ok", "data": {"output": "test"}, "exception": None})
+
+        # Verify this function is a thin wrapper around complete_job with expected parameters.
+        mock_complete_job.assert_called_once_with(
+            status=JobStatus.SKIPPED, result={"status": "ok", "data": {"output": "test"}, "exception": None}
+        )
+
+        # Verify job state was updated on our mock object with expected values.
+        assert mock_job_run.status == JobStatus.SKIPPED
+        assert mock_job_run.finished_at is not None
+        assert mock_job_run.metadata_ == {
+            "result": {"status": "ok", "data": {"output": "test"}, "exception_details": None}
+        }
+        assert mock_job_run.error_message is None
+        assert mock_job_run.error_traceback is None
+        assert mock_job_run.failure_category is None
+
+
+@pytest.mark.integration
+class TestJobSkipIntegration:
+    """Test job skip lifecycle management."""
+
+    def test_job_updated_successfully(self, session, arq_redis, with_populated_job_data, sample_job_run):
+        """Test successful job skipping."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Skip job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(manager.db):
+            manager.skip_job(result={"status": "ok", "data": {"output": "test"}, "exception": None})
+
+        # Flush pending changes made by skip_job.
+        session.flush()
+
+        # Verify job state was updated in transaction with expected values.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+
+        assert job.status == JobStatus.SKIPPED
+        assert job.finished_at is not None
+        assert job.metadata_ == {"result": {"status": "ok", "data": {"output": "test"}, "exception_details": None}}
+        assert job.error_message is None
+        assert job.error_traceback is None
+        assert job.failure_category is None
+
+
+@pytest.mark.unit
+class TestPrepareRetryUnit:
+    """Unit tests for job retry lifecycle management."""
+
+    @pytest.mark.parametrize(
+        "invalid_status",
+        [status for status in JobStatus._member_map_.values() if status not in RETRYABLE_JOB_STATUSES],
+    )
+    def test_prepare_retry_raises_job_transition_error_when_managed_job_has_unretryable_status(
+        self, mock_job_manager, invalid_status, mock_job_run
+    ):
+        # Set initial job status to an invalid (unretryable) status.
+        mock_job_run.status = invalid_status
+
+        # Prepare retry job. Verify a JobTransitionError is raised due to invalid state in the mocked
+        # job run. Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with (
+            pytest.raises(
+                JobTransitionError,
+                match=re.escape(f"Cannot retry job {mock_job_manager.job_id} due to invalid state ({invalid_status})"),
+            ),
+            TransactionSpy.spy(mock_job_manager.db),
+        ):
+            mock_job_manager.prepare_retry()
+
+        # Verify job state on the mocked object remains unchanged.
+        assert mock_job_run.status == invalid_status
+        assert mock_job_run.retry_count == 0
+        assert mock_job_run.started_at is None
+        assert mock_job_run.progress_message is None
+        assert mock_job_run.error_message is None
+        assert mock_job_run.error_traceback is None
+        assert mock_job_run.failure_category is None
+        assert mock_job_run.finished_at is None
+        assert mock_job_run.metadata_ == {}
+
+    @pytest.mark.parametrize(
+        "exception",
+        HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION,
+    )
+    def test_prepare_retry_raises_job_state_error_when_handled_error_is_raised_during_object_manipulation(
+        self, mock_job_manager, exception, mock_job_run
+    ):
+        """Test job prepare retry failure due to exception during job object manipulation."""
+        # Set initial job status to FAILED. Job status must be retryable for this test.
+        initial_status = JobStatus.FAILED
+        mock_job_run.status = initial_status
+
+        # Trigger: if the status property is set, raise the exception; if it is only read, return FAILED.
+        def get_or_error(*args):
+            if args:
+                raise exception
+            return initial_status
+
+        # Prepare retry. Verify a JobStateError is raised by our trigger.
+        # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with (
+            TransactionSpy.spy(mock_job_manager.db),
+            pytest.raises(
+                JobStateError,
+                match="Failed to update job retry state",
+            ),
+        ):
+            type(mock_job_run).status = PropertyMock(side_effect=get_or_error)
+            mock_job_manager.prepare_retry()
+
+        # Verify job state on the mocked object remains unchanged. Although it's theoretically
+        # possible some job state is manipulated prior to an error being raised, our specific
+        # trigger should prevent any changes from being made.
+        assert mock_job_run.status == JobStatus.FAILED
+        assert mock_job_run.retry_count == 0
+        assert mock_job_run.started_at is None
+        assert mock_job_run.progress_message is None
+        assert mock_job_run.error_message is None
+        assert mock_job_run.error_traceback is None
+        assert mock_job_run.failure_category is None
+        assert mock_job_run.finished_at is None
+        assert mock_job_run.metadata_ == {}
+
+    def test_prepare_retry_success(self, mock_job_manager, mock_job_run):
+        """Test successful job prepare retry."""
+        # Set initial job status to FAILED. Job status must be retryable for this test.
+        mock_job_run.status = JobStatus.FAILED
+
+        # Prepare retry. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        # Mock the flag_modified function: mock objects don't have the _sa_instance_state attribute required by
+        # SQLAlchemy funcs and it's easier to mock the functions that manipulate the state than to fully mock the
+        # state itself.
+        with (
+            patch("mavedb.worker.lib.managers.job_manager.flag_modified") as mock_flag_modified,
+            TransactionSpy.spy(mock_job_manager.db),
+        ):
+            mock_job_manager.prepare_retry()
+
+        # Verify flag_modified was called for the metadata_ field.
+        mock_flag_modified.assert_called_once_with(mock_job_run, "metadata_")
+
+        # Verify job state was updated on our mock object with expected values.
+        # These changes would normally be persisted by the caller after this method returns.
+        assert mock_job_run.status == JobStatus.PENDING
+        assert mock_job_run.retry_count == 1
+        assert mock_job_run.progress_message == "Job retry prepared"
+        assert mock_job_run.error_message is None
+        assert mock_job_run.error_traceback is None
+        assert mock_job_run.failure_category is None
+        assert mock_job_run.finished_at is None
+        assert mock_job_run.metadata_["retry_history"] is not None
+        assert mock_job_run.started_at is None
+        assert mock_job_run.metadata_.get("result") is None
+
+
+@pytest.mark.integration
+class TestPrepareRetryIntegration:
+    """Test job retry lifecycle management."""
+
+    @pytest.mark.parametrize(
+        "job_status",
+        [status for status in JobStatus._member_map_.values() if status not in RETRYABLE_JOB_STATUSES],
+    )
+    def test_prepare_retry_failed_due_to_invalid_status(
+        self, session, arq_redis, with_populated_job_data, sample_job_run, job_status
+    ):
+        """Test job retry failure due to invalid job status."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Update job to a non-retryable state
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        job.status = job_status
+        session.commit()
+
+        # Prepare retry job. Verify a JobTransitionError is raised due to the passed invalid state.
+        # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with (
+            TransactionSpy.spy(manager.db),
+            pytest.raises(
+                JobTransitionError,
+                match=re.escape(f"Cannot retry job {job.id} due to invalid state ({job.status})"),
+            ),
+        ):
+            manager.prepare_retry()
+
+    def test_prepare_retry_success(self, session, arq_redis, with_populated_job_data, sample_job_run):
+        """Test successful job retry."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Manually set job to FAILED status and commit changes.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        job.status = JobStatus.FAILED
+        session.commit()
+
+        # Prepare retry. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
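+        # (prepare_retry should reset the job to PENDING, increment retry_count, clear error
+        # fields, and append an entry to metadata_["retry_history"]; see assertions below.)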
+
+
+@pytest.mark.integration
+class TestPrepareRetryIntegration:
+    """Test job retry lifecycle management."""
+
+    @pytest.mark.parametrize(
+        "job_status",
+        [status for status in JobStatus._member_map_.values() if status not in RETRYABLE_JOB_STATUSES],
+    )
+    def test_prepare_retry_failed_due_to_invalid_status(
+        self, session, arq_redis, with_populated_job_data, sample_job_run, job_status
+    ):
+        """Test job retry failure due to invalid job status."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Update job to non-failed state
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        job.status = job_status
+        session.commit()
+
+        # Prepare retry job. Verify a JobTransitionError is raised due to the passed invalid state.
+        # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with (
+            TransactionSpy.spy(manager.db),
+            pytest.raises(
+                JobTransitionError,
+                match=re.escape(f"Cannot retry job {job.id} due to invalid state ({job.status})"),
+            ),
+        ):
+            manager.prepare_retry()
+
+    def test_prepare_retry_success(self, session, arq_redis, with_populated_job_data, sample_job_run):
+        """Test successful job retry."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Manually set job to FAILED status and commit changes.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        job.status = JobStatus.FAILED
+        session.commit()
+
+        # Prepare retry. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(manager.db):
+            manager.prepare_retry()
+
+        # Commit pending changes made by prepare retry.
+        session.commit()
+
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.PENDING
+        assert job.retry_count == 1
+        assert job.progress_message == "Job retry prepared"
+        assert job.error_message is None
+        assert job.error_traceback is None
+        assert job.failure_category is None
+        assert job.finished_at is None
+        assert job.metadata_["retry_history"] is not None
+
+
+@pytest.mark.unit
+class TestPrepareQueueUnit:
+    """Unit tests for job prepare for queue lifecycle management."""
+
+    @pytest.mark.parametrize(
+        "invalid_status",
+        [status for status in JobStatus._member_map_.values() if status != JobStatus.PENDING],
+    )
+    def test_prepare_queue_raises_job_transition_error_when_managed_job_has_unqueueable_status(
+        self, mock_job_manager, invalid_status, mock_job_run
+    ):
+        """Test job prepare queue failure due to invalid job status."""
+        # Set initial job status to an invalid (non-pending) status.
+        mock_job_run.status = invalid_status
+
+        # Prepare queue job. Verify a JobTransitionError is raised due to invalid state in the mocked
+        # job run. Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with (
+            pytest.raises(
+                JobTransitionError,
+                match=re.escape(f"Cannot queue job {mock_job_manager.job_id} from status {invalid_status}"),
+            ),
+            TransactionSpy.spy(mock_job_manager.db),
+        ):
+            mock_job_manager.prepare_queue()
+
+        # Verify job state on the mocked object remains unchanged.
+        assert mock_job_run.status == invalid_status
+        assert mock_job_run.progress_message is None
+
+    @pytest.mark.parametrize(
+        "exception",
+        HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION,
+    )
+    def test_prepare_queue_raises_job_state_error_when_handled_error_is_raised_during_object_manipulation(
+        self, mock_job_manager, exception, mock_job_run
+    ):
+        """Test job prepare queue failure due to exception during job object manipulation."""
+        # Set initial job status to PENDING. Job status must be valid for this test.
+        initial_status = JobStatus.PENDING
+        mock_job_run.status = initial_status
+
+        # Trigger: if the job's status attribute is written, raise the exception; plain reads return PENDING.
+        def get_or_error(*args):
+            if args:
+                raise exception
+            return initial_status
+
+        # Prepare queue. Verify a JobStateError is raised by our trigger.
+        # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with (
+            TransactionSpy.spy(mock_job_manager.db),
+            pytest.raises(
+                JobStateError,
+                match="Failed to update job queue state",
+            ),
+        ):
+            type(mock_job_run).status = PropertyMock(side_effect=get_or_error)
+            mock_job_manager.prepare_queue()
+
+        # Verify job state on the mocked object remains unchanged. Although it's theoretically
+        # possible some job state is manipulated prior to an error being raised, our specific
+        # trigger should prevent any changes from being made.
+        assert mock_job_run.status == JobStatus.PENDING
+        assert mock_job_run.progress_message is None
+
+    def test_prepare_queue_success(self, mock_job_manager, mock_job_run):
+        """Test successful job prepare queue."""
+        # Set initial job status to PENDING. Job status must be valid for this test.
+        mock_job_run.status = JobStatus.PENDING
+
+        # Prepare queue. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        # Patch get_job to return our mock job run directly: the mocked session cannot execute
+        # real queries, so we bypass the database lookup entirely.
+        with (
+            patch.object(mock_job_manager, "get_job", return_value=mock_job_run),
+            TransactionSpy.spy(mock_job_manager.db),
+        ):
+            mock_job_manager.prepare_queue()
+
+        # Verify job state was updated on our mock object with expected values.
+        # These changes would normally be persisted by the caller after this method returns.
+        assert mock_job_run.status == JobStatus.QUEUED
+        assert mock_job_run.progress_message == "Job queued for execution"
+
+
+@pytest.mark.integration
+class TestPrepareQueueIntegration:
+    """Test job prepare for queue lifecycle management."""
+
+    @pytest.mark.parametrize(
+        "job_status",
+        [status for status in JobStatus._member_map_.values() if status != JobStatus.PENDING],
+    )
+    def test_prepare_queue_failed_due_to_invalid_status(
+        self, session, arq_redis, with_populated_job_data, sample_job_run, job_status
+    ):
+        """Test job prepare for queue failure due to invalid job status."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Update job to invalid state
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        job.status = job_status
+        session.flush()
+
+        # Prepare queue job. Verify a JobTransitionError is raised due to the passed invalid state.
+        # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with (
+            TransactionSpy.spy(manager.db),
+            pytest.raises(
+                JobTransitionError,
+                match=f"Cannot queue job {job.id} from status {job.status}",
+            ),
+        ):
+            manager.prepare_queue()
+
+    def test_prepare_queue_success(self, session, arq_redis, with_populated_job_data, sample_job_run):
+        """Test successful job prepare for queue."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Sample run should be in PENDING state from fixture setup, but verify to be sure.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.PENDING, "Sample job run must be in PENDING state for this test."
+
+        # Prepare queue. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(manager.db):
+            manager.prepare_queue()
+
+        # Flush pending changes made by prepare queue.
+        session.flush()
+
+        # Verify job state was updated in transaction with expected values.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.QUEUED
+        assert job.progress_message == "Job queued for execution"
+
+
+@pytest.mark.unit
+class TestResetJobUnit:
+    """Unit tests for job reset lifecycle management."""
+
+    @pytest.mark.parametrize(
+        "exception",
+        HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION,
+    )
+    def test_reset_job_raises_job_state_error_when_handled_error_is_raised_during_object_manipulation(
+        self, mock_job_manager, exception, mock_job_run
+    ):
+        """Test job reset failure due to exception during job object manipulation."""
+        # Set initial job status to FAILED. Job status is unimportant for this test (all statuses are resettable).
+        initial_status = JobStatus.FAILED
+        mock_job_run.status = initial_status
+
+        # Trigger: if the job's status attribute is written, raise the exception; plain reads return FAILED.
+        def get_or_error(*args):
+            if args:
+                raise exception
+            return initial_status
+
+        # Reset job. Verify a JobStateError is raised by our trigger.
+        # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with (
+            TransactionSpy.spy(mock_job_manager.db),
+            pytest.raises(
+                JobStateError,
+                match="Failed to reset job state",
+            ),
+        ):
+            type(mock_job_run).status = PropertyMock(side_effect=get_or_error)
+            mock_job_manager.reset_job()
+
+        # Verify job state on the mocked object remains unchanged. Although it's theoretically
+        # possible some job state is manipulated prior to an error being raised, our specific
+        # trigger should prevent any changes from being made.
+        assert mock_job_run.status == JobStatus.FAILED
+        assert mock_job_run.started_at is None
+        assert mock_job_run.finished_at is None
+        assert mock_job_run.progress_current is None
+        assert mock_job_run.progress_total is None
+        assert mock_job_run.progress_message is None
+        assert mock_job_run.error_message is None
+        assert mock_job_run.error_traceback is None
+        assert mock_job_run.failure_category is None
+        assert mock_job_run.retry_count == 0
+        assert mock_job_run.metadata_ == {}
+
+    def test_reset_job_success(self, mock_job_manager, mock_job_run):
+        """Test successful job reset."""
+        # Set initial job status to FAILED. All statuses are resettable, so the actual status is not important.
+        mock_job_run.status = JobStatus.FAILED
+
+        # Reset job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(mock_job_manager.db):
+            mock_job_manager.reset_job()
+
+        # Verify job state was updated on our mock object with expected values.
+        # These changes would normally be persisted by the caller after this method returns.
+        assert mock_job_run.status == JobStatus.PENDING
+        assert mock_job_run.started_at is None
+        assert mock_job_run.finished_at is None
+        assert mock_job_run.progress_current is None
+        assert mock_job_run.progress_total is None
+        assert mock_job_run.progress_message is None
+        assert mock_job_run.error_message is None
+        assert mock_job_run.error_traceback is None
+        assert mock_job_run.failure_category is None
+        assert mock_job_run.retry_count == 0
+        assert mock_job_run.metadata_ == {}
+
+
+@pytest.mark.integration
+class TestResetJobIntegration:
+    """Test job reset lifecycle management."""
+
+    def test_reset_job_success(self, session, arq_redis, with_populated_job_data, sample_job_run):
+        """Test successful job reset."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Manually set job to a non-pending status and set various fields to non-default values.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        job.status = JobStatus.FAILED
+        job.started_at = "2023-12-31T23:59:59Z"
+        job.finished_at = "2024-01-01T00:00:00Z"
+        job.progress_current = 50
+        job.progress_total = 100
+        job.progress_message = "Halfway done"
+        job.error_message = "Test error message"
+        job.error_traceback = "Test error traceback"
+        job.failure_category = FailureCategory.UNKNOWN
+        job.retry_count = 2
+        job.metadata_ = {"result": {}, "retry_history": [{"attempt": 1}, {"attempt": 2}]}
+        session.commit()
+
+        # Reset job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(manager.db):
+            manager.reset_job()
+
+        # Commit pending changes made by reset job.
+        session.commit()
+
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.PENDING
+        assert job.progress_current is None
+        assert job.progress_total is None
+        assert job.progress_message is None
+        assert job.error_message is None
+        assert job.error_traceback is None
+        assert job.failure_category is None
+        assert job.started_at is None
+        assert job.finished_at is None
+        assert job.retry_count == 0
+        assert job.metadata_.get("retry_history") is None
+
+
+@pytest.mark.unit
+class TestJobProgressUpdateUnit:
+    """Unit tests for job progress update lifecycle management."""
+
+    @pytest.mark.parametrize(
+        "exception",
+        HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION,
+    )
+    def test_update_progress_raises_job_state_error_when_handled_error_is_raised_during_object_manipulation(
+        self, mock_job_manager, exception, mock_job_run
+    ):
+        """Test job progress update failure due to exception during job object manipulation."""
+        # Trigger: if the progress_current attribute is written, raise the exception; plain reads
+        # return the initial progress.
+        initial_progress_current = mock_job_run.progress_current
+
+        def get_or_error(*args):
+            if args:
+                raise exception
+            return initial_progress_current
+
+        # Update progress. Verify a JobStateError is raised by our trigger.
+        # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with (
+            TransactionSpy.spy(mock_job_manager.db),
+            pytest.raises(
+                JobStateError,
+                match="Failed to update job progress",
+            ),
+        ):
+            type(mock_job_run).progress_current = PropertyMock(side_effect=get_or_error)
+            mock_job_manager.update_progress(50, 100, "Halfway done")
+
+        # Verify job state on the mocked object remains unchanged.
+        assert mock_job_run.progress_current is None
+        assert mock_job_run.progress_total is None
+        assert mock_job_run.progress_message is None
+
+    def test_update_progress_success(self, mock_job_manager, mock_job_run):
+        """Test successful job progress update."""
+
+        # Update progress. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(mock_job_manager.db):
+            mock_job_manager.update_progress(50, 100, "Halfway done")
+
+        # Verify job state was updated on our mock object with expected values.
+        # These changes would normally be persisted by the caller after this method returns.
+        assert mock_job_run.progress_current == 50
+        assert mock_job_run.progress_total == 100
+        assert mock_job_run.progress_message == "Halfway done"
+
+    def test_update_progress_does_not_overwrite_old_message_when_no_new_message_is_provided(
+        self, mock_job_manager, mock_job_run
+    ):
+        """Test successful job progress update without message."""
+
+        # Set initial progress message to verify it is not overwritten.
+        mock_job_run.progress_message = "Old message"
+
+        # Update progress without message. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(mock_job_manager.db):
+            mock_job_manager.update_progress(75, 200)
+
+        # Verify job state was updated on our mock object with expected values.
+        # These changes would normally be persisted by the caller after this method returns.
+        assert mock_job_run.progress_current == 75
+        assert mock_job_run.progress_total == 200
+        assert mock_job_run.progress_message == "Old message"  # Message should remain unchanged from initial set.
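+
+# A note on the get_or_error triggers used throughout the unit tests in this module: they rely
+# on PropertyMock's calling convention, in which the underlying mock is invoked with no
+# arguments on attribute reads and with the new value on attribute writes. A minimal,
+# self-contained sketch of the pattern (the names `obj` and `attr` are illustrative, not part
+# of this suite):
+#
+#     from unittest.mock import Mock, PropertyMock
+#
+#     def get_or_error(*args):
+#         if args:  # setter call: args == (new_value,)
+#             raise ValueError("write rejected")
+#         return "initial"  # getter call: no arguments
+#
+#     obj = Mock()  # each Mock instance has its own type, so the patch below stays local
+#     type(obj).attr = PropertyMock(side_effect=get_or_error)
+#     assert obj.attr == "initial"  # reads succeed
+#     try:
+#         obj.attr = "new"  # writes raise
+#     except ValueError:
+#         pass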
+
+
+@pytest.mark.integration
+class TestJobProgressUpdateIntegration:
+    """Test job progress update lifecycle management."""
+
+    def test_update_progress_success(self, session, arq_redis, with_populated_job_data, sample_job_run):
+        """Test successful progress update."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Set initial progress to None to verify update.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        job.progress_current = None
+        job.progress_total = None
+        job.progress_message = None
+        session.commit()
+
+        # Update progress. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(manager.db):
+            manager.update_progress(50, 100, "Halfway done")
+
+        # Commit pending changes made by update progress.
+        session.commit()
+
+        # Verify job state was updated in transaction with expected values.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.progress_current == 50
+        assert job.progress_total == 100
+        assert job.progress_message == "Halfway done"
+
+    def test_update_progress_success_does_not_overwrite_old_message_when_no_new_message_is_provided(
+        self, session, arq_redis, with_populated_job_data, sample_job_run
+    ):
+        """Test successful progress update without message."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Set initial progress to None and seed an initial message to verify it is preserved.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        job.progress_current = None
+        job.progress_total = None
+        job.progress_message = "Old message"
+        session.commit()
+
+        # Update progress without message. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(manager.db):
+            manager.update_progress(75, 200)
+
+        # Flush pending changes made by update progress.
+        session.flush()
+
+        # Verify job state was updated in transaction with expected values.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.progress_current == 75
+        assert job.progress_total == 200
+        assert job.progress_message == "Old message"  # Message should remain unchanged from initial set.
+
+
+@pytest.mark.unit
+class TestJobProgressStatusUpdateUnit:
+    """Unit tests for job progress status update lifecycle management."""
+
+    @pytest.mark.parametrize(
+        "exception",
+        HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION,
+    )
+    def test_update_status_message_raises_job_state_error_when_handled_error_is_raised_during_object_manipulation(
+        self, mock_job_manager, exception, mock_job_run
+    ):
+        """Test job status message update failure due to exception during job object manipulation."""
+        # Trigger: if the progress_message attribute is written, raise the exception; plain reads
+        # return the initial message.
+        initial_progress_message = mock_job_run.progress_message
+
+        def get_or_error(*args):
+            if args:
+                raise exception
+            return initial_progress_message
+
+        # Update status message. Verify a JobStateError is raised by our trigger.
+        # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with (
+            TransactionSpy.spy(mock_job_manager.db),
+            pytest.raises(
+                JobStateError,
+                match="Failed to update job status message",
+            ),
+        ):
+            type(mock_job_run).progress_message = PropertyMock(side_effect=get_or_error)
+            mock_job_manager.update_status_message("New status message")
+
+        # Verify job state on the mocked object remains unchanged.
+        assert mock_job_run.progress_message == initial_progress_message
+
+    def test_update_status_message_success(self, mock_job_manager, mock_job_run):
+        """Test successful job status message update."""
+
+        # Update status message. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(mock_job_manager.db):
+            mock_job_manager.update_status_message("New status message")
+
+        # Verify job state was updated on our mock object with expected values.
+        # These changes would normally be persisted by the caller after this method returns.
+        assert mock_job_run.progress_message == "New status message"
+
+
+@pytest.mark.integration
+class TestJobProgressStatusUpdateIntegration:
+    """Test job progress status update lifecycle management."""
+
+    def test_update_status_message_success(self, session, arq_redis, with_populated_job_data, sample_job_run):
+        """Test successful status message update."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Set initial progress message to verify update.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        job.progress_message = "Old status message"
+        session.commit()
+
+        # Update status message. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(manager.db):
+            manager.update_status_message("New status message")
+
+        # Commit pending changes made by update status message.
+        session.commit()
+
+        # Verify job state was updated in transaction with expected values.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.progress_message == "New status message"
+
+
+@pytest.mark.unit
+class TestJobProgressIncrementationUnit:
+    """Unit tests for job progress incrementation lifecycle management."""
+
+    @pytest.mark.parametrize(
+        "exception",
+        HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION,
+    )
+    def test_increment_progress_raises_job_state_error_when_handled_error_is_raised_during_object_manipulation(
+        self, mock_job_manager, exception, mock_job_run
+    ):
+        """Test job progress incrementation failure due to exception during job object manipulation."""
+        # Trigger: if the progress_current attribute is written, raise the exception; plain reads
+        # return the initial progress.
+        initial_progress_current = mock_job_run.progress_current
+
+        def get_or_error(*args):
+            if args:
+                raise exception
+            return initial_progress_current
+
+        # Increment progress. Verify a JobStateError is raised by our trigger.
+        # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with (
+            TransactionSpy.spy(mock_job_manager.db),
+            pytest.raises(
+                JobStateError,
+                match="Failed to increment job progress",
+            ),
+        ):
+            type(mock_job_run).progress_current = PropertyMock(side_effect=get_or_error)
+            mock_job_manager.increment_progress(10, "Incrementing progress")
+
+        # Verify job state on the mocked object remains unchanged.
+ assert mock_job_run.progress_current is None + assert mock_job_run.progress_message is None + + def test_increment_progress_success(self, mock_job_manager, mock_job_run): + """Test successful job progress incrementation.""" + + # Increment progress. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(mock_job_manager.db): + mock_job_manager.increment_progress(10, "Incrementing progress") + + # Verify job state was updated on our mock object with expected values. + # These changes would normally be persisted by the caller after this method returns. + assert mock_job_run.progress_current == 10 + assert mock_job_run.progress_message == "Incrementing progress" + + def test_increment_progress_success_old_message_is_not_overwritten_when_none_provided( + self, mock_job_manager, mock_job_run + ): + """Test successful job progress incrementation without message.""" + + # Set initial progress message to verify it is not overwritten. + mock_job_run.progress_message = "Old message" + + # Increment progress without message. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(mock_job_manager.db): + mock_job_manager.increment_progress(15) + + # Verify job state was updated on our mock object with expected values. + # These changes would normally be persisted by the caller after this method returns. + assert mock_job_run.progress_current == 15 + assert mock_job_run.progress_message == "Old message" # Message should remain unchanged from initial set. + + +@pytest.mark.integration +class TestJobProgressIncrementationIntegration: + """Test job progress incrementation lifecycle management.""" + + @pytest.mark.parametrize( + "msg", + [None, "Incremented progress successfully"], + ) + def test_increment_progress_success(self, session, arq_redis, with_populated_job_data, sample_job_run, msg): + """Test successful progress incrementation.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Set initial progress to 0 to verify incrementation. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + job.progress_current = 0 + job.progress_total = 100 + job.progress_message = "Test incrementation message" + session.commit() + + # Increment progress. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + manager.increment_progress(10, msg) + + # Commit pending changes made by increment progress. + session.commit() + + # Verify job state was updated in transaction with expected values. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.progress_current == 10 + assert job.progress_total == 100 + assert job.progress_message == ( + msg if msg else "Test incrementation message" + ) # Message should remain unchanged if None + + def test_increment_progress_success_multiple_times( + self, session, arq_redis, with_populated_job_data, sample_job_run + ): + """Test successful progress incrementation multiple times.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Set initial progress to 0 to verify incrementation. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + job.progress_current = 0 + job.progress_total = 100 + session.commit() + + # Increment progress multiple times. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. 
+        with TransactionSpy.spy(manager.db):
+            manager.increment_progress(20)
+            manager.increment_progress(30)
+
+        # Commit pending changes made by increment progress.
+        session.commit()
+
+        # Verify job state was updated in transaction with expected values.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.progress_current == 50
+        assert job.progress_total == 100
+
+    def test_increment_progress_success_exceeding_total(
+        self, session, arq_redis, with_populated_job_data, sample_job_run
+    ):
+        """Test successful progress incrementation exceeding total."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Set initial progress to 0 to verify incrementation.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        job.progress_current = 0
+        job.progress_total = 100
+        session.commit()
+
+        # Increment progress exceeding total. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(manager.db):
+            manager.increment_progress(150)
+
+        # Commit pending changes made by increment progress.
+        session.commit()
+
+        # Verify job state was updated in transaction with expected values.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.progress_current == 150
+        assert job.progress_total == 100
+
+
+@pytest.mark.unit
+class TestJobProgressTotalUpdateUnit:
+    """Unit tests for job progress total update lifecycle management."""
+
+    @pytest.mark.parametrize(
+        "exception",
+        HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION,
+    )
+    def test_set_progress_total_raises_job_state_error_when_handled_error_is_raised_during_object_manipulation(
+        self, mock_job_manager, exception, mock_job_run
+    ):
+        """Test job progress total update failure due to exception during job object manipulation."""
+        # Trigger: if the progress_total attribute is written, raise the exception; plain reads
+        # return the initial total.
+        initial_progress_total = mock_job_run.progress_total
+
+        def get_or_error(*args):
+            if args:
+                raise exception
+            return initial_progress_total
+
+        # Set progress total. Verify a JobStateError is raised by our trigger.
+        # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with (
+            TransactionSpy.spy(mock_job_manager.db),
+            pytest.raises(
+                JobStateError,
+                match="Failed to update job progress total state",
+            ),
+        ):
+            type(mock_job_run).progress_total = PropertyMock(side_effect=get_or_error)
+            mock_job_manager.set_progress_total(200)
+
+        # Verify job state on the mocked object remains unchanged.
+        assert mock_job_run.progress_total == initial_progress_total
+
+    def test_set_progress_total_success(self, mock_job_manager, mock_job_run):
+        """Test successful job progress total update."""
+
+        # Set progress total. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(mock_job_manager.db):
+            mock_job_manager.set_progress_total(200)
+
+        # Verify job state was updated on our mock object with expected values.
+        # These changes would normally be persisted by the caller after this method returns.
+        assert mock_job_run.progress_total == 200
+
+    def test_set_progress_total_does_not_overwrite_old_message_when_no_new_message_is_provided(
+        self, mock_job_manager, mock_job_run
+    ):
+        """Test successful job progress total update without message."""
+
+        # Set initial progress message to verify it is not overwritten.
+        mock_job_run.progress_message = "Old message"
+
+        # Set progress total without message. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(mock_job_manager.db):
+            mock_job_manager.set_progress_total(300)
+
+        # Verify job state was updated on our mock object with expected values.
+        # These changes would normally be persisted by the caller after this method returns.
+        assert mock_job_run.progress_total == 300
+        assert mock_job_run.progress_message == "Old message"  # Message should remain unchanged from initial set.
+
+
+@pytest.mark.integration
+class TestJobProgressTotalUpdateIntegration:
+    """Test job progress total update lifecycle management."""
+
+    def test_set_progress_total_success(self, session, arq_redis, with_populated_job_data, sample_job_run):
+        """Test successful progress total update."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Set initial progress total and message to verify update.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        job.progress_total = 100
+        job.progress_message = "Ready to start"
+        session.commit()
+
+        # Set progress total. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(manager.db):
+            manager.set_progress_total(200, message="Updated total progress")
+
+        # Commit pending changes made by set progress total.
+        session.commit()
+
+        # Verify job state was updated in transaction with expected values.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.progress_total == 200
+        assert job.progress_message == "Updated total progress"
+
+
+@pytest.mark.unit
+class TestJobIsCancelledUnit:
+    """Unit tests for job is_cancelled lifecycle management."""
+
+    @pytest.mark.parametrize(
+        "status,expected_result",
+        [(status, status in CANCELLED_JOB_STATUSES) for status in JobStatus._member_map_.values()],
+    )
+    def test_is_cancelled_returns_expected_result_for_each_status(
+        self, mock_job_manager, mock_job_run, status, expected_result
+    ):
+        """Test the is_cancelled check returns the expected result for each job status."""
+        # Set initial job status to the provided status.
+        mock_job_run.status = status
+
+        # Check is_cancelled. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(mock_job_manager.db):
+            result = mock_job_manager.is_cancelled()
+
+        assert result == expected_result
+
+
+@pytest.mark.integration
+class TestJobIsCancelledIntegration:
+    """Test job is_cancelled lifecycle management."""
+
+    @pytest.mark.parametrize(
+        "job_status",
+        [status for status in JobStatus._member_map_.values() if status in CANCELLED_JOB_STATUSES],
+    )
+    def test_is_cancelled_success_cancelled(
+        self, session, arq_redis, with_populated_job_data, sample_job_run, job_status
+    ):
+        """Test successful is_cancelled check when cancelled."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Mark the job as cancelled in the database
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        job.status = job_status
+        session.commit()
+
+        # Check is_cancelled. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(manager.db):
+            result = manager.is_cancelled()
+
+        # Verify the job is marked as cancelled. This method requires no persistence.
+        assert result is True
+
+    @pytest.mark.parametrize(
+        "job_status",
+        [status for status in JobStatus._member_map_.values() if status not in CANCELLED_JOB_STATUSES],
+    )
+    def test_is_cancelled_success_not_cancelled(
+        self, session, arq_redis, with_populated_job_data, sample_job_run, job_status
+    ):
+        """Test successful is_cancelled check when not cancelled."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Mark the job as not cancelled in the database
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        job.status = job_status
+        session.commit()
+
+        # Check is_cancelled. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(manager.db):
+            result = manager.is_cancelled()
+
+        # Verify the job is not marked as cancelled. This method requires no persistence.
+        assert result is False
+
+
+@pytest.mark.unit
+class TestJobShouldRetryUnit:
+    """Unit tests for job should_retry lifecycle management."""
+
+    @pytest.mark.parametrize(
+        "exception",
+        [
+            pytest.param(
+                exc,
+                marks=pytest.mark.skip(
+                    reason=(
+                        "AttributeError is not propagated by mock objects: "
+                        "Python's attribute lookup swallows AttributeError and mock returns a new mock instead. "
+                        "See unittest.mock docs for details."
+                    )
+                )
+                if isinstance(exc, AttributeError)
+                else (),
+                # ^ Only AttributeError is marked for skip; the others run as normal.
+            )
+            for exc in HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION
+        ],
+    )
+    def test_should_retry_raises_job_state_error_when_handled_error_is_raised_during_object_manipulation(
+        self, mock_job_manager, exception, mock_job_run
+    ):
+        """
+        Test should_retry check failure due to exception during job object manipulation.
+
+        AttributeError is skipped in this test because Python's attribute protocol treats an
+        AttributeError raised by a property getter as "attribute missing" and falls back to
+        Mock's __getattr__, which returns a new child mock, so the exception is not propagated
+        as expected. See the unittest.mock documentation for details.
+        """
+
+        # Trigger: any access to the job's status attribute raises the exception.
+        def get_or_error(*args):
+            raise exception
+
+        # Remove any instance attribute that could shadow the property
+        if "status" in mock_job_run.__dict__:
+            del mock_job_run.__dict__["status"]
+
+        # When we want to raise on attribute access (not just writes), we must override the
+        # entire property; an AttributeError raised through a PropertyMock would be swallowed
+        # by Mock's __getattr__ fallback and never propagated.
+        type(mock_job_run).status = property(get_or_error)
+
+        # Check should_retry. Verify a JobStateError is raised by our trigger.
+        # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with (
+            TransactionSpy.spy(mock_job_manager.db),
+            pytest.raises(
+                JobStateError,
+                match="Failed to check retry eligibility state",
+            ),
+        ):
+            mock_job_manager.should_retry()
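+
+    # A minimal, self-contained sketch of the AttributeError-swallowing behavior described in
+    # the skip reason above (the names `m` and `raising_getter` are illustrative, not part of
+    # this suite):
+    #
+    #     from unittest.mock import Mock
+    #
+    #     def raising_getter(self):
+    #         raise AttributeError("swallowed")
+    #
+    #     m = Mock()  # each Mock instance has its own type, so the patch stays local to `m`
+    #     type(m).status = property(raising_getter)
+    #     print(m.status)  # prints a child Mock; the error was treated as "attribute missing"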
+
+    @pytest.mark.parametrize(
+        "status,expected_result",
+        [
+            (JobStatus.SUCCEEDED, False),
+            (JobStatus.CANCELLED, False),
+            (JobStatus.QUEUED, False),
+            (JobStatus.RUNNING, False),
+            (JobStatus.PENDING, False),
+        ],
+    )
+    def test_should_retry_success_for_non_failed_statuses(
+        self, mock_job_manager, mock_job_run, status, expected_result
+    ):
+        """Test successful should_retry check."""
+        # Set initial job status to provided status.
+        mock_job_run.status = status
+
+        # Check should_retry. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(mock_job_manager.db):
+            result = mock_job_manager.should_retry()
+
+        # Verify the result matches expected.
+        assert result == expected_result
+
+    @pytest.mark.parametrize(
+        "retry_count,max_retries,failure_category,expected_result",
+        (
+            [(0, 3, cat, True) for cat in RETRYABLE_FAILURE_CATEGORIES]  # Initial retry
+            + [(2, 3, RETRYABLE_FAILURE_CATEGORIES[0], True)]  # Within retry limit (barely)
+            + [(3, 3, RETRYABLE_FAILURE_CATEGORIES[0], False)]  # Exceeded retries
+            + [
+                (1, 3, cat, False)
+                for cat in FailureCategory._member_map_.values()
+                if cat not in RETRYABLE_FAILURE_CATEGORIES
+            ]  # Non-retryable failure categories
+        ),
+    )
+    def test_should_retry_success_for_failed_status(
+        self, mock_job_manager, mock_job_run, retry_count, max_retries, failure_category, expected_result
+    ):
+        """Test successful should_retry check for failed status."""
+        # Set initial job status to FAILED with provided parameters.
+        mock_job_run.status = JobStatus.FAILED
+        mock_job_run.retry_count = retry_count
+        mock_job_run.max_retries = max_retries
+        mock_job_run.failure_category = failure_category
+
+        # Check should_retry. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(mock_job_manager.db):
+            result = mock_job_manager.should_retry()
+
+        # Verify the result matches expected.
+        assert result == expected_result
+
+
+@pytest.mark.integration
+class TestJobShouldRetryIntegration:
+    """Test job should_retry lifecycle management."""
+
+    @pytest.mark.parametrize(
+        "job_status",
+        [status for status in JobStatus._member_map_.values() if status != JobStatus.FAILED],
+    )
+    def test_should_retry_success_non_failed_jobs_should_not_retry(
+        self, session, arq_redis, with_populated_job_data, sample_job_run, job_status
+    ):
+        """Test successful should_retry check (only jobs in failed states may retry)."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Update job to non-failed state
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        job.status = job_status
+        session.commit()
+
+        # Check should_retry. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(manager.db):
+            result = manager.should_retry()
+
+        # Verify the job should not retry. This method requires no persistence.
+        assert result is False
+
+    def test_should_retry_success_exceeded_retry_attempts_should_not_retry(
+        self, session, arq_redis, with_populated_job_data, sample_job_run
+    ):
+        """Test successful should_retry check with no retry attempts left."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Update job to failed state with no retries left
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        job.status = JobStatus.FAILED
+        job.max_retries = 3
+        job.retry_count = 3
+        session.commit()
+
+        # Check should_retry. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(manager.db):
+            result = manager.should_retry()
+
+        # Verify the job should not retry. This method requires no persistence.
+        assert result is False
+
+    def test_should_retry_success_failure_category_is_not_retryable(
+        self, session, arq_redis, with_populated_job_data, sample_job_run
+    ):
+        """Test successful should_retry check with non-retryable failure category."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Update job to failed state with non-retryable failure category
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        job.status = JobStatus.FAILED
+        job.max_retries = 3
+        job.retry_count = 1
+        job.failure_category = FailureCategory.UNKNOWN
+        session.commit()
+
+        # Check should_retry. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(manager.db):
+            result = manager.should_retry()
+
+        # Verify the job should not retry. This method requires no persistence.
+        assert result is False
+
+    def test_should_retry_success(self, session, arq_redis, with_populated_job_data, sample_job_run):
+        """Test successful should_retry check with retryable failure category."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Update job to failed state with retryable failure category
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        job.status = JobStatus.FAILED
+        job.max_retries = 3
+        job.retry_count = 1
+        job.failure_category = RETRYABLE_FAILURE_CATEGORIES[0]
+        session.commit()
+
+        # Check should_retry. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(manager.db):
+            result = manager.should_retry()
+
+        # Verify the job should retry. This method requires no persistence.
+        assert result is True
+
+
+@pytest.mark.unit
+class TestGetJobUnit:
+    """Unit tests for job retrieval."""
+
+    def test_get_job_wraps_database_connection_error_when_encounters_sqlalchemy_error(self, mock_job_run):
+        """Test job retrieval failure during job fetch."""
+
+        # Prepare a mock JobManager with a mocked DB session that will raise SQLAlchemyError on query.
+        # We don't use the default fixture here since it usually wraps this function.
+        mock_db = Mock(spec=Session)
+        mock_redis = Mock(spec=ArqRedis)
+        manager = object.__new__(JobManager)
+        manager.db = mock_db
+        manager.redis = mock_redis
+        manager.job_id = mock_job_run.id
+
+        with (
+            TransactionSpy.mock_database_execution_failure(manager.db),
+            pytest.raises(DatabaseConnectionError, match=f"Failed to fetch job {mock_job_run.id}"),
+        ):
+            manager.get_job()
+
+
+@pytest.mark.integration
+class TestGetJobIntegration:
+    """Test job retrieval."""
+
+    def test_get_job_success(self, session, arq_redis, with_populated_job_data, sample_job_run):
+        """Test successful job retrieval."""
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Retrieve job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely.
+        with TransactionSpy.spy(manager.db):
+            job = manager.get_job()
+
+        # Verify the retrieved job matches expected.
+        assert job.id == sample_job_run.id
+        assert job.status == JobStatus.PENDING
+
+    def test_get_job_raises_database_connection_error_when_job_does_not_exist(
+        self, session, arq_redis, with_populated_job_data
+    ):
+        """Test job retrieval failure when job does not exist."""
+        with pytest.raises(DatabaseConnectionError, match="Failed to fetch job 9999"), TransactionSpy.spy(session):
+            JobManager(session, arq_redis, job_id=9999)  # Non-existent job ID
+
+
+@pytest.mark.integration
+class TestJobManagerJob:
+    """Test overall job lifecycle management."""
+
+    def test_full_successful_job_lifecycle(self, session, arq_redis, with_populated_job_data, sample_job_run):
+        """Test full job lifecycle from start to completion."""
+        # Pre-manager: Job is created in DB in Pending state. Verify initial state.
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.PENDING, "Initial job status should be PENDING"
+
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Prepare job to be enqueued
+        with TransactionSpy.spy(manager.db):
+            manager.prepare_queue()
+        session.flush()
+
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.QUEUED, "Job status should be QUEUED after preparing queue"
+
+        # Start job
+        with TransactionSpy.spy(manager.db):
+            manager.start_job()
+        session.flush()
+
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.RUNNING, "Job status should be RUNNING after starting"
+        assert job.started_at is not None, "Job started_at should be set after starting"
+
+        # Set initial progress
+        with TransactionSpy.spy(manager.db):
+            manager.update_progress(0, 100, "Job started")
+        session.flush()
+
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.progress_current == 0
+        assert job.progress_total == 100
+        assert job.progress_message == "Job started"
+
+        # Update status message
+        with TransactionSpy.spy(manager.db):
+            manager.update_status_message("Began processing data")
+        session.flush()
+
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.progress_message == "Began processing data"
+
+        # Set progress total
+        with TransactionSpy.spy(manager.db):
+            manager.set_progress_total(200, "Set total work units")
+        session.flush()
+
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.progress_total == 200
+        assert job.progress_message == "Set total work units"
+
+        # Increment progress
+        with TransactionSpy.spy(manager.db):
+            manager.increment_progress(100, "Halfway done")
+        session.flush()
+
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.progress_current == 100
+        assert job.progress_message == "Halfway done"
+
+        # Increment progress again
+        with TransactionSpy.spy(manager.db):
+            manager.increment_progress(100, "All done")
+        session.flush()
+
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.progress_current == 200
+        assert job.progress_message == "All done"
+
+        # Complete job
+        with TransactionSpy.spy(manager.db):
+            manager.succeed_job(result={"status": "ok", "data": {"output": "test"}, "exception": None})
+        session.flush()
+
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status
== JobStatus.SUCCEEDED + assert job.finished_at is not None + + # Verify job is not cancelled and should not retry + assert manager.is_cancelled() is False + assert manager.should_retry() is False + + # Verify final job state + final_job = manager.get_job() + assert final_job.status == JobStatus.SUCCEEDED + assert final_job.progress_current == 200 + assert final_job.progress_total == 200 + + def test_full_cancelled_job_lifecycle(self, session, arq_redis, with_populated_job_data, sample_job_run): + """Test full job lifecycle for a cancelled job.""" + # Pre-manager: Job is created in DB in Pending state. Verify initial state. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING, "Initial job status should be PENDING" + + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Prepare job to be enqueued + with TransactionSpy.spy(manager.db): + manager.prepare_queue() + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED, "Job status should be QUEUED after preparing queue" + + # Start job + with TransactionSpy.spy(manager.db): + manager.start_job() + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING + + # Cancel job + with TransactionSpy.spy(manager.db): + manager.cancel_job({"status": "ok", "data": {"reason": "User requested cancellation"}, "exception": None}) + session.flush() + + # Verify job is cancelled + assert manager.is_cancelled() is True + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.CANCELLED + assert job.finished_at is not None + + def test_full_skipped_job_lifecycle(self, session, arq_redis, with_populated_job_data, sample_job_run): + """Test full job lifecycle for a skipped job.""" + # Pre-manager: Job is created in DB in Pending state. Verify initial state. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING, "Initial job status should be PENDING" + + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Skip job + with TransactionSpy.spy(manager.db): + manager.skip_job(result={"status": "ok", "data": {"reason": "Job not needed"}, "exception": None}) + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.SKIPPED + assert job.finished_at is not None + + def test_full_failed_job_lifecycle(self, session, arq_redis, with_populated_job_data, sample_job_run): + """Test full job lifecycle for a failed job.""" + # Pre-manager: Job is created in DB in Pending state. Verify initial state. 
+ job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING, "Initial job status should be PENDING" + + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Prepare job to be enqueued + with TransactionSpy.spy(manager.db): + manager.prepare_queue() + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED, "Job status should be QUEUED after preparing queue" + + # Start job + with TransactionSpy.spy(manager.db): + manager.start_job() + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING + + # Fail job + exc = Exception("An error occurred") + with TransactionSpy.spy(manager.db): + manager.fail_job(error=exc, result={"status": "failed", "data": {}, "exception": exc}) + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.FAILED + assert job.finished_at is not None + assert job.error_message == "An error occurred" + assert job.error_traceback is not None + + def test_full_retried_job_lifecycle(self, session, arq_redis, with_populated_job_data, sample_job_run): + """Test full job lifecycle for a retried job.""" + # Pre-manager: Job is created in DB in Pending state. Verify initial state. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING, "Initial job status should be PENDING" + + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Prepare job to be enqueued + with TransactionSpy.spy(manager.db): + manager.prepare_queue() + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED, "Job status should be QUEUED after preparing queue" + + # Start job + with TransactionSpy.spy(manager.db): + manager.start_job() + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING + + # Fail job + exc = Exception("Temporary error") + with TransactionSpy.spy(manager.db): + manager.fail_job( + error=exc, + result={"status": "failed", "data": {}, "exception": exc}, + ) + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.FAILED + + # TODO: Use some failure method added later to set failure category to retryable during the + # call to fail_job above. For now, we manually set it here. + job.failure_category = RETRYABLE_FAILURE_CATEGORIES[0] + session.commit() + + # Should retry + assert manager.should_retry() is True + + # Prepare retry + with TransactionSpy.spy(manager.db): + manager.prepare_retry() + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + assert job.retry_count == 1 + + def test_full_reset_job_lifecycle(self, session, arq_redis, with_populated_job_data, sample_job_run): + """Test full job lifecycle for a reset job.""" + # Pre-manager: Job is created in DB in Pending state. Verify initial state. 
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.PENDING, "Initial job status should be PENDING"
+
+        manager = JobManager(session, arq_redis, sample_job_run.id)
+
+        # Prepare job to be enqueued
+        with TransactionSpy.spy(manager.db):
+            manager.prepare_queue()
+        session.flush()
+
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.QUEUED, "Job status should be QUEUED after preparing queue"
+
+        # Start job
+        with TransactionSpy.spy(manager.db):
+            manager.start_job()
+        session.flush()
+
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.RUNNING
+
+        # Fail job
+        exc = Exception("Some error")
+        with TransactionSpy.spy(manager.db):
+            manager.fail_job(
+                error=exc,
+                result={"status": "failed", "data": {}, "exception": exc},
+            )
+        session.flush()
+
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.FAILED
+
+        # Retry job
+        with TransactionSpy.spy(manager.db):
+            manager.prepare_retry()
+        session.flush()
+
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.PENDING
+        assert job.retry_count == 1
+
+        # Queue job again
+        with TransactionSpy.spy(manager.db):
+            manager.prepare_queue()
+        session.flush()
+
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.QUEUED, "Job status should be QUEUED after preparing queue"
+
+        # Start job again
+        with TransactionSpy.spy(manager.db):
+            manager.start_job()
+        session.flush()
+
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.RUNNING
+
+        # Fail job again
+        exc = Exception("Another error")
+        with TransactionSpy.spy(manager.db):
+            manager.fail_job(
+                error=exc,
+                result={"status": "failed", "data": {}, "exception": exc},
+            )
+        session.flush()
+
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.FAILED
+        assert job.retry_count == 1
+
+        # Reset job
+        with TransactionSpy.spy(manager.db):
+            manager.reset_job()
+        session.flush()
+
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.PENDING
+        assert job.progress_current is None
+        assert job.progress_total is None
+        assert job.retry_count == 0
diff --git a/tests/worker/lib/managers/test_pipeline_manager.py b/tests/worker/lib/managers/test_pipeline_manager.py
new file mode 100644
index 00000000..7cb7931e
--- /dev/null
+++ b/tests/worker/lib/managers/test_pipeline_manager.py
@@ -0,0 +1,3757 @@
+# ruff: noqa: E402
+"""
+Comprehensive test suite for PipelineManager class.
+
+Tests cover all aspects of pipeline coordination, job dependency management,
+status updates, error handling, and database interactions, including new methods
+for pipeline monitoring, job retry management, and restart functionality.
+""" + +import pytest + +pytest.importorskip("arq") + +import datetime +from unittest.mock import Mock, PropertyMock, patch + +from arq import ArqRedis +from arq.jobs import Job as ArqJob +from sqlalchemy import select +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session + +from mavedb.models.enums.job_pipeline import DependencyType, JobStatus, PipelineStatus +from mavedb.models.job_dependency import JobDependency +from mavedb.models.job_run import JobRun +from mavedb.models.pipeline import Pipeline +from mavedb.worker.lib.managers import JobManager +from mavedb.worker.lib.managers.constants import ( + ACTIVE_JOB_STATUSES, + CANCELLED_PIPELINE_STATUSES, + RUNNING_PIPELINE_STATUSES, + TERMINAL_PIPELINE_STATUSES, +) +from mavedb.worker.lib.managers.exceptions import ( + DatabaseConnectionError, + PipelineCoordinationError, + PipelineStateError, + PipelineTransitionError, +) +from mavedb.worker.lib.managers.pipeline_manager import PipelineManager +from tests.helpers.transaction_spy import TransactionSpy + +HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION = ( + AttributeError("Mock attribute error"), + KeyError("Mock key error"), + TypeError("Mock type error"), + ValueError("Mock value error"), +) + + +@pytest.mark.integration +class TestPipelineManagerInitialization: + """Test PipelineManager initialization and setup.""" + + def test_init_with_valid_pipeline(self, session, arq_redis, with_populated_job_data, sample_pipeline): + """Test successful initialization with valid pipeline ID.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + assert manager.db == session + assert manager.redis == arq_redis + assert manager.pipeline_id == sample_pipeline.id + + def test_init_with_invalid_pipeline_id(self, session, arq_redis): + """Test initialization failure with non-existent pipeline ID.""" + pipeline_id = 999 # Assuming this ID does not exist + with pytest.raises(DatabaseConnectionError, match=f"Failed to get pipeline {pipeline_id}"): + PipelineManager(session, arq_redis, pipeline_id) + + def test_init_with_database_error(self, session, arq_redis, with_populated_job_data, sample_pipeline): + """Test initialization failure with database connection error.""" + pipeline_id = sample_pipeline.id + + with ( + TransactionSpy.mock_database_execution_failure(session), + pytest.raises(DatabaseConnectionError, match=f"Failed to get pipeline {pipeline_id}"), + ): + PipelineManager(session, arq_redis, pipeline_id) + + +@pytest.mark.unit +class TestStartPipelineUnit: + """Unit tests for starting a pipeline.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "coordinate_after_start", + [True, False], + ) + async def test_start_pipeline_successful(self, mock_pipeline_manager, coordinate_after_start): + """Test successful pipeline start from CREATED state.""" + with ( + patch.object( + mock_pipeline_manager, + "get_pipeline", + return_value=Mock(spec=Pipeline, status=PipelineStatus.CREATED), + ), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.start_pipeline(coordinate=coordinate_after_start) + + mock_set_status.assert_called_once_with(PipelineStatus.RUNNING) + if coordinate_after_start: + mock_coordinate.assert_called_once() + else: + mock_coordinate.assert_not_called() + + @pytest.mark.asyncio + 
@pytest.mark.parametrize( + "current_status", + [status for status in PipelineStatus._member_map_.values() if status != PipelineStatus.CREATED], + ) + async def test_start_pipeline_non_created_state(self, mock_pipeline_manager, current_status): + """Test pipeline start failure when not in CREATED state.""" + with ( + patch.object( + mock_pipeline_manager, + "get_pipeline_status", + return_value=current_status, + ), + pytest.raises( + PipelineTransitionError, + match=f"Pipeline {mock_pipeline_manager.pipeline_id} is in state {current_status} and may not be started", + ), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + await mock_pipeline_manager.start_pipeline() + + mock_set_status.assert_not_called() + mock_coordinate.assert_not_called() + + +@pytest.mark.integration +class TestStartPipelineIntegration: + """Integration tests for starting a pipeline.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "coordinate_after_start", + [True, False], + ) + async def test_start_pipeline_successful( + self, session, arq_redis, with_populated_job_data, sample_pipeline, sample_job_run, coordinate_after_start + ): + """Test successful pipeline start from CREATED state.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with TransactionSpy.spy(session, expect_flush=True): + await manager.start_pipeline(coordinate=coordinate_after_start) + + # Commit the session to persist changes + session.commit() + + # Verify pipeline status is now RUNNING + pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + assert pipeline.status == PipelineStatus.RUNNING + + # Verify the initial job was queued if we are coordinating after start + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + jobs = await arq_redis.queued_jobs() + + if coordinate_after_start: + assert job.status == JobStatus.QUEUED + assert jobs[0].function == sample_job_run.job_function + else: + assert job.status == JobStatus.PENDING + assert len(jobs) == 0 + + @pytest.mark.asyncio + async def test_start_pipeline_no_jobs(self, session, arq_redis, with_populated_job_data, sample_empty_pipeline): + """Test pipeline start when there are no jobs in the pipeline.""" + manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id) + + with TransactionSpy.spy(session, expect_flush=True): + await manager.start_pipeline(coordinate=True) + + # Commit the session to persist changes + session.commit() + + # Verify pipeline status is now SUCCEEDED since there are no jobs + pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_empty_pipeline.id)).scalar_one() + assert pipeline.status == PipelineStatus.SUCCEEDED + + # Verify no jobs were enqueued in Redis + jobs = await arq_redis.queued_jobs() + assert len(jobs) == 0 + + +@pytest.mark.unit +class TestCoordinatePipelineUnit: + """Unit tests for pipeline coordination logic.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "new_status", + CANCELLED_PIPELINE_STATUSES, + ) + async def test_coordinate_pipeline_cancels_remaining_jobs_status_transitions_to_cancellable( + self, + mock_pipeline_manager, + new_status, + ): + """Test that remaining jobs are cancelled if pipeline transitions to a cancelable status.""" + with ( + patch.object( + mock_pipeline_manager, 
"transition_pipeline_status", return_value=new_status + ) as mock_transition, + patch.object(mock_pipeline_manager, "cancel_remaining_jobs", return_value=None) as mock_cancel, + patch.object(mock_pipeline_manager, "enqueue_ready_jobs", return_value=None) as mock_enqueue, + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.coordinate_pipeline() + + mock_transition.assert_called_once() + mock_cancel.assert_called_once_with(reason="Pipeline failed or cancelled") + mock_enqueue.assert_not_called() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "new_status", + RUNNING_PIPELINE_STATUSES, + ) + async def test_coordinate_pipeline_enqueues_jobs_when_status_transitions_to_running( + self, mock_pipeline_manager, new_status + ): + """Test coordination after successful job completion.""" + with ( + patch.object( + mock_pipeline_manager, "transition_pipeline_status", return_value=new_status + ) as mock_transition, + patch.object(mock_pipeline_manager, "cancel_remaining_jobs", return_value=None) as mock_cancel, + patch.object(mock_pipeline_manager, "enqueue_ready_jobs", return_value=None) as mock_enqueue, + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.coordinate_pipeline() + + assert mock_transition.call_count == 2 # Called once before and once after enqueuing jobs + mock_cancel.assert_not_called() + mock_enqueue.assert_called_once() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "new_status", + [ + status + for status in PipelineStatus._member_map_.values() + if status not in CANCELLED_PIPELINE_STATUSES + RUNNING_PIPELINE_STATUSES + ], + ) + async def test_coordinate_pipeline_noop_for_other_status_transitions(self, mock_pipeline_manager, new_status): + """Test coordination no-op for non-cancelled/running status transitions.""" + with ( + patch.object( + mock_pipeline_manager, "transition_pipeline_status", return_value=new_status + ) as mock_transition, + patch.object(mock_pipeline_manager, "cancel_remaining_jobs", return_value=None) as mock_cancel, + patch.object(mock_pipeline_manager, "enqueue_ready_jobs", return_value=None) as mock_enqueue, + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.coordinate_pipeline() + + mock_transition.assert_called_once() + mock_cancel.assert_not_called() + mock_enqueue.assert_not_called() + + +@pytest.mark.integration +class TestCoordinatePipelineIntegration: + """Test pipeline coordination after job completion.""" + + @pytest.mark.asyncio + async def test_coordinate_pipeline_transitions_pipeline_to_failed_after_job_failure( + self, session, arq_redis, with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run + ): + """Test successful pipeline coordination and job enqueuing after job completion.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the job in the pipeline to a terminal status + sample_job_run.status = JobStatus.FAILED + session.commit() + + with ( + TransactionSpy.spy(session, expect_flush=True), + patch.object(manager, "cancel_remaining_jobs", wraps=manager.cancel_remaining_jobs) as mock_cancel, + patch.object(manager, "enqueue_ready_jobs", wraps=manager.enqueue_ready_jobs) as mock_enqueue, + ): + await manager.coordinate_pipeline() + + # Ensure no new jobs were enqueued but that jobs were cancelled + mock_cancel.assert_called_once() + mock_enqueue.assert_not_called() + + # Verify that the pipeline status is now FAILED + assert 
manager.get_pipeline().status == PipelineStatus.FAILED + + # Verify that the failed job remains failed + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.FAILED + + # Verify that the pending job transitions to skipped + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.SKIPPED + + @pytest.mark.asyncio + async def test_coordinate_pipeline_transitions_pipeline_to_cancelled_after_pipeline_is_cancelled( + self, session, arq_redis, with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run + ): + """Test successful pipeline coordination and job enqueuing after pipeline cancellation .""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to a cancelled status + manager.set_pipeline_status(PipelineStatus.CANCELLED) + session.commit() + + # Set the job in the pipeline to a running status + sample_job_run.status = JobStatus.RUNNING + session.commit() + + with ( + TransactionSpy.spy(session, expect_flush=True), + patch.object(manager, "cancel_remaining_jobs", wraps=manager.cancel_remaining_jobs) as mock_cancel, + patch.object(manager, "enqueue_ready_jobs", wraps=manager.enqueue_ready_jobs) as mock_enqueue, + ): + await manager.coordinate_pipeline() + + # Ensure no new jobs were enqueued but that jobs were cancelled + mock_cancel.assert_called_once() + mock_enqueue.assert_not_called() + + # Verify that the pipeline status is now CANCELLED + assert manager.get_pipeline().status == PipelineStatus.CANCELLED + + # Verify that the running job transitions to cancelled + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.CANCELLED + + # Verify that the pending dependent job transitions to skipped + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.SKIPPED + + @pytest.mark.asyncio + async def test_coordinate_running_pipeline_enqueues_ready_jobs( + self, session, arq_redis, with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run + ): + """Test successful pipeline coordination and job enqueuing when jobs are still pending.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to a running status + manager.set_pipeline_status(PipelineStatus.RUNNING) + session.commit() + + with ( + TransactionSpy.spy(session, expect_flush=True), + patch.object(manager, "cancel_remaining_jobs", wraps=manager.cancel_remaining_jobs) as mock_cancel, + patch.object(manager, "enqueue_ready_jobs", wraps=manager.enqueue_ready_jobs) as mock_enqueue, + ): + await manager.coordinate_pipeline() + + # Ensure no new jobs were cancelled but that jobs were enqueued + mock_cancel.assert_not_called() + mock_enqueue.assert_called_once() + + # Verify that the non-dependent job is now queued + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + + # Verify that the dependent job is still pending (since its dependency is not yet complete) + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "initial_status", + [PipelineStatus.CREATED, PipelineStatus.PAUSED, PipelineStatus.SUCCEEDED, 
PipelineStatus.PARTIAL], + ) + async def test_coordinate_pipeline_noop( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + initial_status, + ): + """Test successful pipeline coordination and job enqueuing when jobs are still pending.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to a cancelled status + manager.set_pipeline_status(initial_status) + session.commit() + + with ( + TransactionSpy.spy(session, expect_flush=True), + patch.object(manager, "cancel_remaining_jobs", wraps=manager.cancel_remaining_jobs) as mock_cancel, + patch.object(manager, "enqueue_ready_jobs", wraps=manager.enqueue_ready_jobs) as mock_enqueue, + ): + await manager.coordinate_pipeline() + + # Ensure no new jobs were enqueued or cancelled + mock_cancel.assert_not_called() + mock_enqueue.assert_not_called() + + # Verify that the job is still pending + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + + # Verify that the dependent job is still pending + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + + +@pytest.mark.unit +class TestTransitionPipelineStatusUnit: + """Test pipeline status transition logic.""" + + @pytest.mark.parametrize( + "existing_status", + TERMINAL_PIPELINE_STATUSES, + ) + def test_terminal_state_results_in_retention_of_terminal_states( + self, mock_pipeline_manager, existing_status, mock_pipeline + ): + """No jobs in pipeline should result in no status change, so long as the pipeline is in a terminal state.""" + mock_pipeline.status = existing_status + + with ( + patch.object(mock_pipeline_manager, "get_job_counts_by_status", return_value={}), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + result = mock_pipeline_manager.transition_pipeline_status() + assert result is existing_status + + mock_set_status.assert_not_called() + + def test_paused_state_results_in_retention_of_paused_state(self, mock_pipeline_manager, mock_pipeline): + """No jobs in pipeline should result in no status change when pipeline is paused.""" + mock_pipeline.status = PipelineStatus.PAUSED + + with ( + patch.object(mock_pipeline_manager, "get_job_counts_by_status", return_value={}), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + result = mock_pipeline_manager.transition_pipeline_status() + assert result is PipelineStatus.PAUSED + + mock_set_status.assert_not_called() + + @pytest.mark.parametrize( + "existing_status", + [ + status + for status in PipelineStatus._member_map_.values() + if status not in TERMINAL_PIPELINE_STATUSES + [PipelineStatus.PAUSED] + ], + ) + def test_no_jobs_results_in_succeeded_state_if_not_terminal( + self, mock_pipeline_manager, existing_status, mock_pipeline + ): + """No jobs in pipeline should result in SUCCEEDED state if not already terminal.""" + mock_pipeline.status = existing_status + mock_pipeline.finished_at = None + with ( + patch.object(mock_pipeline_manager, "get_job_counts_by_status", return_value={}), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + result = 
mock_pipeline_manager.transition_pipeline_status()
+            assert result == PipelineStatus.SUCCEEDED
+
+        mock_set_status.assert_called_once_with(PipelineStatus.SUCCEEDED)
+
+    @pytest.mark.parametrize(
+        "job_counts,expected_status",
+        [
+            # Any failure trumps everything
+            ({JobStatus.SUCCEEDED: 10, JobStatus.FAILED: 1}, PipelineStatus.FAILED),
+            # Running or queued jobs without failures keep pipeline running
+            ({JobStatus.SUCCEEDED: 5, JobStatus.FAILED: 0, JobStatus.RUNNING: 2}, PipelineStatus.RUNNING),
+            ({JobStatus.SUCCEEDED: 5, JobStatus.FAILED: 0, JobStatus.QUEUED: 3}, PipelineStatus.RUNNING),
+            # All succeeded
+            ({JobStatus.SUCCEEDED: 5}, PipelineStatus.SUCCEEDED),
+            # Mix of terminal states without failures
+            ({JobStatus.SUCCEEDED: 3, JobStatus.SKIPPED: 2}, PipelineStatus.PARTIAL),
+            ({JobStatus.SUCCEEDED: 1, JobStatus.CANCELLED: 1}, PipelineStatus.PARTIAL),
+            # All cancelled
+            ({JobStatus.CANCELLED: 5}, PipelineStatus.CANCELLED),
+            # All skipped
+            ({JobStatus.SKIPPED: 4}, PipelineStatus.CANCELLED),
+            # Some cancelled and skipped
+            ({JobStatus.CANCELLED: 2, JobStatus.SKIPPED: 3}, PipelineStatus.CANCELLED),
+            # Inconsistent state
+            ({JobStatus.CANCELLED: 2, JobStatus.SKIPPED: 1, JobStatus.SUCCEEDED: 1, None: 3}, PipelineStatus.PARTIAL),
+        ],
+    )
+    def test_pipeline_status_determination_based_on_job_counts(
+        self, mock_pipeline_manager, job_counts, expected_status, mock_pipeline
+    ):
+        """Test pipeline status determination based on job counts."""
+        mock_pipeline.status = PipelineStatus.CREATED
+        mock_pipeline.finished_at = None
+
+        with (
+            patch.object(mock_pipeline_manager, "get_job_counts_by_status", return_value=job_counts),
+            patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status,
+            TransactionSpy.spy(mock_pipeline_manager.db),
+        ):
+            result = mock_pipeline_manager.transition_pipeline_status()
+            assert result == expected_status
+
+        mock_set_status.assert_called_once_with(expected_status)
+
+    @pytest.mark.parametrize(
+        "job_counts,existing_status",
+        [
+            ({JobStatus.PENDING: 5}, PipelineStatus.CREATED),
+            ({JobStatus.SUCCEEDED: 5, JobStatus.PENDING: 3}, PipelineStatus.RUNNING),
+            ({JobStatus.PENDING: 2, JobStatus.SKIPPED: 4}, PipelineStatus.RUNNING),
+            ({JobStatus.PENDING: 1, JobStatus.CANCELLED: 1}, PipelineStatus.RUNNING),
+        ],
+    )
+    def test_pipeline_status_determination_pending_jobs_do_not_change_status(
+        self, mock_pipeline_manager, job_counts, existing_status, mock_pipeline
+    ):
+        """Test that presence of pending jobs does not change pipeline status."""
+        mock_pipeline.status = existing_status
+
+        with (
+            patch.object(
+                mock_pipeline_manager,
+                "get_job_counts_by_status",
+                return_value=job_counts,
+            ),
+            patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status,
+            TransactionSpy.spy(mock_pipeline_manager.db),
+        ):
+            result = mock_pipeline_manager.transition_pipeline_status()
+            assert result == existing_status
+
+        mock_set_status.assert_not_called()
+
+    @pytest.mark.parametrize(
+        "exception",
+        HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION,
+    )
+    def test_pipeline_status_determination_throws_state_error_for_handled_exceptions(
+        self, mock_pipeline_manager, exception
+    ):
+        """Test that handled exceptions during status determination raise PipelineStateError."""
+
+        # Mocks exception in first try/except: raise the parametrized exception directly
+        # when job counts are fetched, rather than returning a Mock whose misuse would
+        # always surface as a TypeError regardless of the parametrized exception.
+        with (
+            patch.object(
+                mock_pipeline_manager,
+                "get_job_counts_by_status",
+                side_effect=exception,
+            ),
+            patch.object(mock_pipeline_manager, "set_pipeline_status", 
return_value=None) as mock_set_status,
+            TransactionSpy.spy(mock_pipeline_manager.db),
+            pytest.raises(PipelineStateError),
+        ):
+            mock_pipeline_manager.transition_pipeline_status()
+
+        mock_set_status.assert_not_called()
+
+        # Mocks exception in second try/except
+        with (
+            patch.object(
+                mock_pipeline_manager,
+                "get_job_counts_by_status",
+                return_value={JobStatus.SUCCEEDED: 5},
+            ),
+            patch.object(mock_pipeline_manager, "set_pipeline_status", side_effect=exception) as mock_set_status,
+            patch.object(
+                mock_pipeline_manager, "get_pipeline", return_value=Mock(spec=Pipeline, status=PipelineStatus.CREATED)
+            ),
+            pytest.raises(PipelineStateError),
+            TransactionSpy.spy(mock_pipeline_manager.db),
+        ):
+            mock_pipeline_manager.transition_pipeline_status()
+
+    def test_pipeline_status_determination_no_change(self, mock_pipeline_manager, mock_pipeline):
+        """Test that no status change occurs if pipeline status remains the same."""
+        mock_pipeline.status = PipelineStatus.SUCCEEDED
+        with (
+            patch.object(mock_pipeline_manager, "get_job_counts_by_status", return_value={JobStatus.SUCCEEDED: 5}),
+            patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status,
+            TransactionSpy.spy(mock_pipeline_manager.db),
+        ):
+            result = mock_pipeline_manager.transition_pipeline_status()
+            assert result == PipelineStatus.SUCCEEDED
+
+        mock_set_status.assert_not_called()
+
+
+@pytest.mark.integration
+class TestTransitionPipelineStatusIntegration:
+    """Integration tests for pipeline status transition logic."""
+
+    @pytest.mark.parametrize(
+        "initial_status",
+        TERMINAL_PIPELINE_STATUSES,
+    )
+    def test_pipeline_status_transition_noop_when_status_is_terminal(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_pipeline,
+        initial_status,
+    ):
+        """Test that pipeline status remains unchanged when already in a terminal state."""
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        # Set initial pipeline status
+        manager.set_pipeline_status(initial_status)
+        session.commit()
+
+        with TransactionSpy.spy(session):
+            new_status = manager.transition_pipeline_status()
+
+        # Commit the transaction
+        session.commit()
+
+        # Verify that the pipeline status remains unchanged
+        assert new_status == initial_status
+        assert manager.get_pipeline_status() == initial_status
+
+    def test_pipeline_status_transition_noop_when_status_is_paused(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_pipeline,
+    ):
+        """Test that pipeline status remains unchanged when in PAUSED state."""
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        # Set initial pipeline status to PAUSED
+        manager.set_pipeline_status(PipelineStatus.PAUSED)
+        session.commit()
+
+        with TransactionSpy.spy(session):
+            new_status = manager.transition_pipeline_status()
+
+        # Commit the transaction
+        session.commit()
+
+        # Verify that the pipeline status remains unchanged
+        assert new_status == PipelineStatus.PAUSED
+        assert manager.get_pipeline_status() == PipelineStatus.PAUSED
+
+    @pytest.mark.parametrize(
+        "initial_status,expected_status",
+        [
+            (
+                status,
+                status if status in TERMINAL_PIPELINE_STATUSES + [PipelineStatus.PAUSED] else PipelineStatus.SUCCEEDED,
+            )
+            for status in PipelineStatus._member_map_.values()
+        ],
+    )
+    def test_pipeline_status_transition_when_no_jobs_in_pipeline(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        initial_status,
+        expected_status,
+        sample_empty_pipeline,
+    ):
+        """Test that pipeline status transitions to SUCCEEDED when there are no jobs in 
a + non-terminal pipeline. If the pipeline is already in a terminal state, it should remain unchanged.""" + manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id) + + # Set initial pipeline status + manager.set_pipeline_status(initial_status) + session.commit() + + with TransactionSpy.spy(session): + new_status = manager.transition_pipeline_status() + + # Commit the transaction + session.commit() + + # Verify that the pipeline status is the expected status and that + # the status was persisted to the transaction + assert new_status == expected_status + assert manager.get_pipeline_status() == expected_status + + @pytest.mark.parametrize( + "initial_status,job_updates,expected_status", + [ + # Some failed -> failed + (PipelineStatus.CREATED, {1: JobStatus.SUCCEEDED, 2: JobStatus.FAILED}, PipelineStatus.FAILED), + # Some running -> running + (PipelineStatus.CREATED, {1: JobStatus.SUCCEEDED, 2: JobStatus.RUNNING}, PipelineStatus.RUNNING), + # Some queued -> running + (PipelineStatus.CREATED, {1: JobStatus.SUCCEEDED, 2: JobStatus.QUEUED}, PipelineStatus.RUNNING), + # Some pending => no change (handled separately via a second call to transition after enqueuing jobs) + (PipelineStatus.CREATED, {1: JobStatus.SUCCEEDED, 2: JobStatus.PENDING}, PipelineStatus.CREATED), + (PipelineStatus.RUNNING, {1: JobStatus.SUCCEEDED, 2: JobStatus.PENDING}, PipelineStatus.RUNNING), + # All succeeded -> succeeded + (PipelineStatus.CREATED, {1: JobStatus.SUCCEEDED, 2: JobStatus.SUCCEEDED}, PipelineStatus.SUCCEEDED), + # All cancelled -> cancelled + (PipelineStatus.RUNNING, {1: JobStatus.CANCELLED, 2: JobStatus.CANCELLED}, PipelineStatus.CANCELLED), + # Mix of succeeded and skipped -> partial + (PipelineStatus.CREATED, {1: JobStatus.SUCCEEDED, 2: JobStatus.SKIPPED}, PipelineStatus.PARTIAL), + # Mix of succeeded and cancelled -> partial + (PipelineStatus.CREATED, {1: JobStatus.SUCCEEDED, 2: JobStatus.CANCELLED}, PipelineStatus.PARTIAL), + # Mix of cancelled and skipped -> cancelled + (PipelineStatus.CREATED, {1: JobStatus.CANCELLED, 2: JobStatus.SKIPPED}, PipelineStatus.CANCELLED), + ], + ) + def test_pipeline_status_transitions( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + initial_status, + job_updates, + expected_status, + ): + """Test pipeline status transitions based on job status updates.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set initial pipeline status + manager.set_pipeline_status(initial_status) + session.commit() + + # Update job statuses as per test case + for job_run in sample_pipeline.job_runs: + if job_run.id in job_updates: + job_run.status = job_updates[job_run.id] + session.commit() + + # Perform status transition and verify return state + with TransactionSpy.spy(session): + new_status = manager.transition_pipeline_status() + assert new_status == expected_status + session.commit() + + # Verify expected pipeline status is persisted + pipeline = manager.get_pipeline() + assert pipeline.status == expected_status + + +@pytest.mark.unit +class TestEnqueueReadyJobsUnit: + """Test enqueuing of ready jobs (both independent and dependent).""" + + @pytest.mark.parametrize( + "pipeline_status", + [status for status in PipelineStatus._member_map_.values() if status not in RUNNING_PIPELINE_STATUSES], + ) + @pytest.mark.asyncio + async def test_enqueue_ready_jobs_raises_if_pipeline_not_running(self, mock_pipeline_manager, pipeline_status): + """Test that job enqueuing raises a state error if pipeline is not in RUNNING 
status.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=pipeline_status), + pytest.raises(PipelineStateError, match="cannot enqueue jobs"), + TransactionSpy.spy(mock_pipeline_manager.db), + ): + await mock_pipeline_manager.enqueue_ready_jobs() + + @pytest.mark.asyncio + async def test_enqueue_ready_jobs_skips_if_no_jobs(self, mock_pipeline_manager): + """Test that job enqueuing skips if there are no pending jobs.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.RUNNING), + patch.object( + mock_pipeline_manager, + "get_pending_jobs", + return_value=[], + ), + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.enqueue_ready_jobs() + # Should complete without error + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "should_skip", + [False, True], + ) + async def test_enqueue_ready_jobs_checks_if_jobs_are_reachable_if_cant_enqueue( + self, mock_pipeline_manager, mock_job_manager, should_skip + ): + """Test that job enqueuing skips jobs which are unreachable if any exist.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.RUNNING), + patch.object( + mock_pipeline_manager, "get_pending_jobs", return_value=[Mock(spec=JobRun, id=1, urn="test:job:1")] + ), + patch.object(mock_pipeline_manager, "can_enqueue_job", return_value=False), + patch.object( + mock_pipeline_manager, "should_skip_job_due_to_dependencies", return_value=(should_skip, "Reason") + ) as mock_should_skip, + patch.object(mock_job_manager, "skip_job", return_value=None) as mock_skip_job, + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.enqueue_ready_jobs() + + mock_should_skip.assert_called_once() + mock_skip_job.assert_called_once() if should_skip else mock_skip_job.assert_not_called() + + @pytest.mark.asyncio + async def test_enqueue_ready_jobs_raises_if_arq_enqueue_fails(self, mock_pipeline_manager, mock_job_manager): + """Test that job enqueuing raises an error if ARQ enqueue fails.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.RUNNING), + patch.object( + mock_pipeline_manager, "get_pending_jobs", return_value=[Mock(spec=JobRun, id=1, urn="test:job:1")] + ), + patch.object(mock_pipeline_manager, "can_enqueue_job", return_value=True), + patch.object(mock_job_manager, "prepare_queue", return_value=None) as mock_prepare_queue, + patch.object( + mock_pipeline_manager, "_enqueue_in_arq", side_effect=PipelineCoordinationError("ARQ enqueue failed") + ), + pytest.raises(PipelineCoordinationError, match="ARQ enqueue failed"), + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.enqueue_ready_jobs() + + mock_prepare_queue.assert_called_once() + + @pytest.mark.asyncio + async def test_enqueue_ready_jobs_successful_enqueue(self, mock_pipeline_manager, mock_job_manager): + """Test successful job enqueuing.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.RUNNING), + patch.object( + mock_pipeline_manager, "get_pending_jobs", return_value=[Mock(spec=JobRun, id=1, urn="test:job:1")] + ), + patch.object(mock_pipeline_manager, "can_enqueue_job", return_value=True), + patch.object(mock_pipeline_manager, "_enqueue_in_arq", return_value=None) as mock_enqueue, + patch.object(mock_job_manager, "prepare_queue", return_value=None) as mock_prepare_queue, + 
TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.enqueue_ready_jobs() + + mock_prepare_queue.assert_called_once() + mock_enqueue.assert_called_once() + + +@pytest.mark.integration +class TestEnqueueReadyJobsIntegration: + """Integration tests for enqueuing of ready jobs.""" + + @pytest.mark.asyncio + async def test_enqueue_ready_jobs_integration( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test successful enqueuing of ready jobs in a pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to RUNNING status + manager.set_pipeline_status(PipelineStatus.RUNNING) + session.commit() + + with TransactionSpy.spy(session, expect_flush=True): + await manager.enqueue_ready_jobs() + + # Verify that the independent job is now queued + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + + # Verify that the dependent job is still pending (since its dependency is not yet complete) + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + + # Verify the queued ARQ job exists and is the job we expect + arq_job = await arq_redis.queued_jobs() + assert len(arq_job) == 1 + assert arq_job[0].function == sample_job_run.job_function + + # Verify the pipeline is still in RUNNING status + assert manager.get_pipeline_status() == PipelineStatus.RUNNING + + @pytest.mark.asyncio + async def test_enqueue_ready_jobs_integration_with_unreachable_job( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + sample_job_dependency, + ): + """Test enqueuing of ready jobs skips unreachable jobs in a pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to RUNNING status + manager.set_pipeline_status(PipelineStatus.RUNNING) + session.commit() + + # Make the dependent job unreachable by setting the sample_job to cancelled. 
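+        # A cancelled dependency can never complete successfully, so the dependent
+        # job's dependency is unfulfillable and enqueue_ready_jobs should mark it
+        # SKIPPED rather than leaving it PENDING forever.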
+ sample_job_run.status = JobStatus.CANCELLED + session.commit() + + with TransactionSpy.spy(session, expect_flush=True): + await manager.enqueue_ready_jobs() + + # Verify that the dependent job is marked as skipped + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.SKIPPED + + # Verify nothing was enqueued for the dependent job + arq_job = await arq_redis.queued_jobs() + assert len(arq_job) == 0 + + # Verify the pipeline is still in RUNNING status + assert manager.get_pipeline_status() == PipelineStatus.RUNNING + + @pytest.mark.asyncio + async def test_enqueue_ready_jobs_with_empty_pipeline( + self, session, arq_redis, with_populated_job_data, sample_empty_pipeline + ): + """Test enqueuing of ready jobs in an empty pipeline.""" + manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id) + + # Set the pipeline to RUNNING status + manager.set_pipeline_status(PipelineStatus.RUNNING) + session.commit() + + with TransactionSpy.spy(session, expect_flush=True): + await manager.enqueue_ready_jobs() + + # Verify nothing was enqueued + arq_job = await arq_redis.queued_jobs() + assert len(arq_job) == 0 + + # Verify the pipeline is still in RUNNING status + assert manager.get_pipeline_status() == PipelineStatus.RUNNING + + @pytest.mark.asyncio + async def test_enqueue_ready_jobs_bubbles_pipeline_coordination_error_for_any_exception_during_enqueue( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + sample_job_run, + ): + """Test that any exception during job enqueuing raises PipelineCoordinationError.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to RUNNING status + manager.set_pipeline_status(PipelineStatus.RUNNING) + session.commit() + + with ( + TransactionSpy.spy(session, expect_flush=True), + patch.object( + manager.redis, + "enqueue_job", + side_effect=Exception("Unexpected error during enqueue"), + ), + pytest.raises(PipelineCoordinationError, match="Failed to enqueue job in ARQ"), + ): + await manager.enqueue_ready_jobs() + + +@pytest.mark.unit +class TestCancelRemainingJobsUnit: + """Test cancellation of remaining jobs.""" + + def test_cancel_remaining_jobs_no_active_jobs(self, mock_pipeline_manager, mock_job_manager): + """Test job cancellation when there are no active jobs.""" + with ( + patch.object( + mock_pipeline_manager, + "get_active_jobs", + return_value=[], + ), + patch.object(mock_job_manager, "cancel_job", return_value=None) as mock_cancel_job, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.cancel_remaining_jobs() + + mock_cancel_job.assert_not_called() + + @pytest.mark.parametrize( + "job_status, expected_status", + [(JobStatus.QUEUED, JobStatus.CANCELLED), (JobStatus.RUNNING, JobStatus.CANCELLED)], + ) + def test_cancel_remaining_jobs_cancels_queued_and_running_jobs( + self, mock_pipeline_manager, mock_job_manager, mock_job_run, job_status, expected_status + ): + """Test successful cancellation of remaining jobs.""" + mock_job_run.status = job_status + cancellation_result = {"status": expected_status, "reason": "Pipeline cancelled"} + + with ( + patch.object( + mock_pipeline_manager, + "get_active_jobs", + return_value=[mock_job_run], + ), + patch.object(mock_job_manager, "cancel_job", return_value=None) as mock_cancel_job, + patch( + "mavedb.worker.lib.managers.pipeline_manager.construct_bulk_cancellation_result", + return_value=cancellation_result, + ), + 
TransactionSpy.spy(mock_pipeline_manager.db),
+        ):
+            mock_pipeline_manager.cancel_remaining_jobs()
+
+        mock_cancel_job.assert_called_once_with(result=cancellation_result)
+
+    @pytest.mark.parametrize(
+        "job_status, expected_status",
+        [
+            (JobStatus.PENDING, JobStatus.SKIPPED),
+        ],
+    )
+    def test_cancel_remaining_jobs_skips_pending_jobs(
+        self, mock_pipeline_manager, mock_job_manager, mock_job_run, job_status, expected_status
+    ):
+        """Test that pending jobs are skipped rather than cancelled."""
+        mock_job_run.status = job_status
+        cancellation_result = {"status": expected_status, "reason": "Pipeline cancelled"}
+
+        with (
+            patch.object(
+                mock_pipeline_manager,
+                "get_active_jobs",
+                return_value=[mock_job_run],
+            ),
+            patch.object(mock_job_manager, "skip_job", return_value=None) as mock_skip_job,
+            patch(
+                "mavedb.worker.lib.managers.pipeline_manager.construct_bulk_cancellation_result",
+                return_value=cancellation_result,
+            ),
+            TransactionSpy.spy(mock_pipeline_manager.db),
+        ):
+            mock_pipeline_manager.cancel_remaining_jobs()
+
+        mock_skip_job.assert_called_once_with(result=cancellation_result)
+
+
+@pytest.mark.integration
+class TestCancelRemainingJobsIntegration:
+    """Integration tests for cancellation of remaining jobs."""
+
+    def test_cancel_remaining_jobs_integration(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_pipeline,
+        sample_job_run,
+        sample_dependent_job_run,
+    ):
+        """Test successful cancellation of remaining jobs in a pipeline."""
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        # Set the job statuses
+        sample_job_run.status = JobStatus.RUNNING
+        sample_dependent_job_run.status = JobStatus.PENDING
+        session.commit()
+
+        with (
+            TransactionSpy.spy(session),
+        ):
+            manager.cancel_remaining_jobs()
+
+        # Commit the transaction
+        session.commit()
+
+        # Verify that the running job transitions to cancelled
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.CANCELLED
+
+        # Verify that the pending dependent job transitions to skipped
+        job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one()
+        assert job.status == JobStatus.SKIPPED
+
+    def test_cancel_remaining_jobs_integration_no_active_jobs(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_empty_pipeline,
+    ):
+        """Test cancellation of remaining jobs when there are no active jobs."""
+        manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id)
+
+        with (
+            TransactionSpy.spy(session),
+        ):
+            manager.cancel_remaining_jobs()
+
+        # Commit the transaction
+        session.commit()
+
+        # Should complete without error
+
+
+@pytest.mark.unit
+class TestCancelPipelineUnit:
+    """Test cancellation of pipelines."""
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "pipeline_status",
+        TERMINAL_PIPELINE_STATUSES,
+    )
+    async def test_cancel_pipeline_raises_transition_error_if_already_in_terminal_status(
+        self, mock_pipeline_manager, pipeline_status
+    ):
+        """Test that pipeline cancellation raises an error if already in terminal status."""
+        with (
+            patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=pipeline_status),
+            pytest.raises(
+                PipelineTransitionError,
+                match=f"Pipeline {mock_pipeline_manager.pipeline_id} is in terminal state",
+            ),
+            patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status,
+            patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as 
mock_coordinate, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + await mock_pipeline_manager.cancel_pipeline(reason="Testing cancellation") + + mock_set_status.assert_not_called() + mock_coordinate.assert_not_called() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "pipeline_status", + [status for status in PipelineStatus._member_map_.values() if status not in TERMINAL_PIPELINE_STATUSES], + ) + async def test_cancel_pipeline_successful_cancellation_if_not_in_terminal_status( + self, mock_pipeline_manager, pipeline_status + ): + """Test successful pipeline cancellation if not already in terminal status.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=pipeline_status), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.cancel_pipeline(reason="Testing cancellation") + + mock_coordinate.assert_called_once() + mock_set_status.assert_called_once_with(PipelineStatus.CANCELLED) + + +@pytest.mark.integration +class TestCancelPipelineIntegration: + """Integration tests for cancellation of pipelines.""" + + @pytest.mark.asyncio + async def test_cancel_pipeline_integration( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test successful cancellation of a pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to RUNNING status + manager.set_pipeline_status(PipelineStatus.RUNNING) + session.commit() + + # Set the job statuses + sample_job_run.status = JobStatus.RUNNING + sample_dependent_job_run.status = JobStatus.PENDING + session.commit() + + with ( + TransactionSpy.spy(session, expect_flush=True), + ): + await manager.cancel_pipeline(reason="Testing cancellation") + + # Commit the transaction + session.commit() + + # Verify that the pipeline is now in CANCELLED status + assert manager.get_pipeline_status() == PipelineStatus.CANCELLED + + # Verify that the running job transitions to cancelled + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.CANCELLED + + # Verify that the pending dependent job transitions to skipped + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.SKIPPED + + @pytest.mark.asyncio + async def test_cancel_pipeline_integration_already_terminal( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + sample_job_run, + ): + """Test that cancelling a pipeline already in terminal status raises an error.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to SUCCEEDED status + manager.set_pipeline_status(PipelineStatus.SUCCEEDED) + session.commit() + + # Set the job status to something that would normally be cancellable + sample_job_run.status = JobStatus.PENDING + session.commit() + + with ( + pytest.raises( + PipelineTransitionError, + match=f"Pipeline {manager.pipeline_id} is in terminal state", + ), + TransactionSpy.spy(session), + ): + await manager.cancel_pipeline(reason="Testing cancellation") + + # Commit the transaction + session.commit() + + # Verify the pipeline status remains SUCCEEDED + assert manager.get_pipeline_status() == 
PipelineStatus.SUCCEEDED + + # Verify that the job status remains unchanged + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + + +@pytest.mark.unit +class TestPausePipelineUnit: + """Test pausing of pipelines.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "pipeline_status", + TERMINAL_PIPELINE_STATUSES, + ) + async def test_pause_pipeline_raises_transition_error_if_already_in_terminal_status( + self, mock_pipeline_manager, pipeline_status + ): + """Test that pipeline pausing raises an error if already in terminal status.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=pipeline_status), + pytest.raises( + PipelineTransitionError, + match=f"Pipeline {mock_pipeline_manager.pipeline_id} is in terminal state", + ), + TransactionSpy.spy(mock_pipeline_manager.db), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + ): + await mock_pipeline_manager.pause_pipeline() + + mock_set_status.assert_not_called() + mock_coordinate.assert_not_called() + + @pytest.mark.asyncio + async def test_pause_pipeline_raises_transition_error_if_already_paused(self, mock_pipeline_manager): + """Test that pipeline pausing raises an error if already paused.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.PAUSED), + pytest.raises( + PipelineTransitionError, + match=f"Pipeline {mock_pipeline_manager.pipeline_id} is already paused", + ), + TransactionSpy.spy(mock_pipeline_manager.db), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + ): + await mock_pipeline_manager.pause_pipeline() + + mock_set_status.assert_not_called() + mock_coordinate.assert_not_called() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "pipeline_status", + [ + status + for status in PipelineStatus._member_map_.values() + if status not in TERMINAL_PIPELINE_STATUSES and status != PipelineStatus.PAUSED + ], + ) + async def test_pause_pipeline_successful_pausing_if_not_in_terminal_status( + self, mock_pipeline_manager, pipeline_status + ): + """Test successful pipeline pausing if not already in terminal status.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=pipeline_status), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.pause_pipeline() + + mock_coordinate.assert_called_once() + mock_set_status.assert_called_once_with(PipelineStatus.PAUSED) + + +@pytest.mark.integration +class TestPausePipelineIntegration: + """Integration tests for pausing of pipelines.""" + + @pytest.mark.asyncio + async def test_pause_pipeline_integration( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + ): + """Test successful pausing of a pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to RUNNING status + manager.set_pipeline_status(PipelineStatus.RUNNING) + session.commit() + + with ( + TransactionSpy.spy(session, 
expect_flush=True), + ): + await manager.pause_pipeline() + + # Commit the transaction + session.commit() + + # Verify that the pipeline is now in PAUSED status + assert manager.get_pipeline_status() == PipelineStatus.PAUSED + + # Verify that all jobs remain in their original statuses + # (coordinate_pipeline is called by pause_pipeline but should not change job statuses + # while paused). + for job_run in sample_pipeline.job_runs: + assert job_run.status == JobStatus.PENDING + + +@pytest.mark.unit +class TestUnpausePipelineUnit: + """Test unpausing of pipelines.""" + + @pytest.mark.asyncio + async def test_unpause_pipeline_raises_transition_error_if_not_paused(self, mock_pipeline_manager): + """Test that pipeline unpausing raises an error if not currently paused.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.RUNNING), + pytest.raises( + PipelineTransitionError, + match=f"Pipeline {mock_pipeline_manager.pipeline_id} is not paused", + ), + TransactionSpy.spy(mock_pipeline_manager.db), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + ): + await mock_pipeline_manager.unpause_pipeline() + + mock_set_status.assert_not_called() + mock_coordinate.assert_not_called() + + @pytest.mark.asyncio + async def test_unpause_pipeline_successful_unpausing_if_currently_paused(self, mock_pipeline_manager): + """Test successful pipeline unpausing if currently paused.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.PAUSED), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.unpause_pipeline() + + mock_coordinate.assert_called_once() + mock_set_status.assert_called_once_with(PipelineStatus.RUNNING) + + +@pytest.mark.integration +class TestUnpausePipelineIntegration: + """Integration tests for unpausing of pipelines.""" + + @pytest.mark.asyncio + async def test_unpause_pipeline_integration( + self, session, arq_redis, with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run + ): + """Test successful unpausing of a pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to PAUSED status + manager.set_pipeline_status(PipelineStatus.PAUSED) + session.commit() + + with ( + TransactionSpy.spy(session, expect_flush=True), + ): + await manager.unpause_pipeline() + + # Commit the transaction + session.commit() + + # Verify that the pipeline is now in RUNNING status + assert manager.get_pipeline_status() == PipelineStatus.RUNNING + + # Verify that the non-dependent job was queued + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + + +@pytest.mark.unit +class TestRestartPipelineUnit: + """Test restarting of pipelines.""" + + @pytest.mark.asyncio + async def test_restart_pipeline_skips_if_no_jobs_in_pipeline(self, mock_pipeline_manager): + """Test that pipeline restart skips if there are no jobs in the pipeline.""" + with ( + patch.object( + mock_pipeline_manager, + "get_all_jobs", + return_value=[], + ), + patch.object(mock_pipeline_manager, "set_pipeline_status", 
return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + await mock_pipeline_manager.restart_pipeline() + + mock_set_status.assert_not_called() + mock_coordinate.assert_not_called() + + @pytest.mark.asyncio + async def test_restart_pipeline_successful_restart(self, mock_pipeline_manager, mock_job_manager): + """Test successful pipeline restart.""" + with ( + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "start_pipeline", return_value=None) as mock_start_pipeline, + patch.object( + mock_pipeline_manager, + "get_all_jobs", + return_value=[Mock(spec=JobRun, id=1), Mock(spec=JobRun, id=2)], + ), + patch.object( + mock_job_manager, + "reset_job", + return_value=None, + ) as mock_reset_job, + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.restart_pipeline() + + assert mock_reset_job.call_count == 2 + mock_set_status.assert_called_once_with(PipelineStatus.CREATED) + mock_start_pipeline.assert_called_once() + + +@pytest.mark.integration +class TestRestartPipelineIntegration: + """Integration tests for restarting of pipelines.""" + + @pytest.mark.asyncio + async def test_restart_pipeline_integration( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test successful restarting of a pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the job statuses to terminal states + sample_job_run.status = JobStatus.SUCCEEDED + sample_dependent_job_run.status = JobStatus.FAILED + session.commit() + + with ( + TransactionSpy.spy(session, expect_flush=True), + ): + await manager.restart_pipeline() + + # Commit the transaction + session.commit() + + # Verify that the pipeline is now in RUNNING status + assert manager.get_pipeline_status() == PipelineStatus.RUNNING + + # Verify that the non-dependent job is now queued + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + + # Verify that the dependent job is now pending + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + + @pytest.mark.asyncio + async def test_restart_pipeline_integration_skips_if_no_jobs( + self, + session, + arq_redis, + with_populated_job_data, + sample_empty_pipeline, + ): + """Test that restarting a pipeline with no jobs skips without error.""" + manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id) + + # Set the pipeline to a terminal status + manager.set_pipeline_status(PipelineStatus.SUCCEEDED) + session.commit() + + with ( + TransactionSpy.spy(session), + ): + await manager.restart_pipeline() + + # Commit the transaction + session.commit() + + # Verify that the pipeline status remains unchanged + assert manager.get_pipeline_status() == PipelineStatus.SUCCEEDED + + +@pytest.mark.unit +class TestCanEnqueueJobUnit: + """Test job dependency checking.""" + + def test_can_enqueue_job_with_no_dependencies(self, mock_pipeline_manager): + """Test that a job with no dependencies can be enqueued.""" + mock_job = Mock(spec=JobRun, id=1) + + with ( + patch.object( + mock_pipeline_manager, + "get_dependencies_for_job", + return_value=[], + ), + 
TransactionSpy.spy(mock_pipeline_manager.db), + ): + result = mock_pipeline_manager.can_enqueue_job(mock_job) + + assert result is True + + def test_cannot_enqueue_job_with_unmet_dependencies(self, mock_pipeline_manager): + """Test that a job with unmet dependencies cannot be enqueued.""" + mock_job = Mock(spec=JobRun, id=1, status=JobStatus.PENDING) + mock_dependency = Mock(spec=JobDependency, dependency_type=DependencyType.COMPLETION_REQUIRED) + + with ( + patch.object( + mock_pipeline_manager, + "get_dependencies_for_job", + return_value=[(mock_dependency, mock_job)], + ), + patch( + "mavedb.worker.lib.managers.pipeline_manager.job_dependency_is_met", return_value=False + ) as mock_job_dependency_is_met, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + result = mock_pipeline_manager.can_enqueue_job(mock_job) + + mock_job_dependency_is_met.assert_called_once_with( + dependency_type=DependencyType.COMPLETION_REQUIRED, dependent_job_status=JobStatus.PENDING + ) + assert result is False + + def test_can_enqueue_job_with_met_dependencies(self, mock_pipeline_manager): + """Test that a job with met dependencies can be enqueued.""" + mock_job = Mock(spec=JobRun, id=1, status=JobStatus.SUCCEEDED) + mock_dependency = Mock(spec=JobDependency, dependency_type=DependencyType.COMPLETION_REQUIRED) + + with ( + patch.object( + mock_pipeline_manager, + "get_dependencies_for_job", + return_value=[(mock_dependency, mock_job)], + ), + patch( + "mavedb.worker.lib.managers.pipeline_manager.job_dependency_is_met", return_value=True + ) as mock_job_dependency_is_met, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + result = mock_pipeline_manager.can_enqueue_job(mock_job) + + mock_job_dependency_is_met.assert_called_once_with( + dependency_type=DependencyType.COMPLETION_REQUIRED, dependent_job_status=JobStatus.SUCCEEDED + ) + assert result is True + + @pytest.mark.parametrize( + "exception", + HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION, + ) + def test_can_enqueue_job_raises_pipeline_state_error_on_handled_exceptions(self, mock_pipeline_manager, exception): + """Test that handled exceptions during dependency checking raise PipelineStateError.""" + mock_job = Mock(spec=JobRun, id=1, status=JobStatus.SUCCEEDED) + mock_dependency = Mock(spec=JobDependency, dependency_type=DependencyType.COMPLETION_REQUIRED) + + with ( + patch.object( + mock_pipeline_manager, + "get_dependencies_for_job", + return_value=[(mock_dependency, mock_job)], + ), + patch("mavedb.worker.lib.managers.pipeline_manager.job_dependency_is_met", side_effect=exception), + pytest.raises(PipelineStateError, match="Corrupted dependency data"), + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.can_enqueue_job(mock_job) + + +@pytest.mark.integration +class TestCanEnqueueJobIntegration: + """Integration tests for job dependency checking.""" + + def test_can_enqueue_job_integration_with_no_dependencies( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + sample_job_run, + ): + """Test that a job with no dependencies can be enqueued.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + result = manager.can_enqueue_job(sample_job_run) + + assert result is True + + def test_can_enqueue_job_integration_with_unmet_dependencies( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + sample_dependent_job_run, + ): + """Test that a job with unmet dependencies cannot be enqueued.""" + manager = 
PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        with (
+            TransactionSpy.spy(session),
+        ):
+            result = manager.can_enqueue_job(sample_dependent_job_run)
+
+        assert result is False
+
+    def test_can_enqueue_job_integration_with_met_dependencies(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_pipeline,
+        sample_job_run,
+        sample_dependent_job_run,
+    ):
+        """Test that a job with met dependencies can be enqueued."""
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        # Set the dependency job to a succeeded status
+        sample_job_run.status = JobStatus.SUCCEEDED
+        session.commit()
+
+        with (
+            TransactionSpy.spy(session),
+        ):
+            result = manager.can_enqueue_job(sample_dependent_job_run)
+
+        assert result is True
+
+
+@pytest.mark.unit
+class TestShouldSkipJobDueToDependenciesUnit:
+    """Test job skipping due to unmet dependencies."""
+
+    def test_should_not_skip_job_with_no_dependencies(self, mock_pipeline_manager):
+        """Test that a job with no dependencies should not be skipped."""
+        mock_job = Mock(spec=JobRun, id=1)
+
+        with (
+            patch.object(
+                mock_pipeline_manager,
+                "get_dependencies_for_job",
+                return_value=[],
+            ),
+            patch(
+                "mavedb.worker.lib.managers.pipeline_manager.job_should_be_skipped_due_to_unfulfillable_dependency",
+                return_value=(False, ""),
+            ) as mock_job_should_be_skipped,
+            TransactionSpy.spy(mock_pipeline_manager.db),
+        ):
+            should_skip, reason = mock_pipeline_manager.should_skip_job_due_to_dependencies(mock_job)
+
+        mock_job_should_be_skipped.assert_not_called()
+        assert should_skip is False
+        assert reason == ""
+
+    def test_should_skip_job_with_unreachable_dependency(self, mock_pipeline_manager):
+        """Test that a job with unreachable dependencies should be skipped."""
+        mock_job = Mock(spec=JobRun, id=1, status=JobStatus.FAILED)
+        mock_dependency = Mock(spec=JobDependency, dependency_type=DependencyType.SUCCESS_REQUIRED)
+
+        with (
+            patch.object(
+                mock_pipeline_manager,
+                "get_dependencies_for_job",
+                return_value=[(mock_dependency, mock_job)],
+            ),
+            patch(
+                "mavedb.worker.lib.managers.pipeline_manager.job_should_be_skipped_due_to_unfulfillable_dependency",
+                return_value=(True, "Unfulfillable dependency detected"),
+            ) as mock_job_should_be_skipped,
+            TransactionSpy.spy(mock_pipeline_manager.db),
+        ):
+            should_skip, reason = mock_pipeline_manager.should_skip_job_due_to_dependencies(mock_job)
+
+        mock_job_should_be_skipped.assert_called_once_with(
+            dependency_type=DependencyType.SUCCESS_REQUIRED, dependent_job_status=JobStatus.FAILED
+        )
+        assert should_skip is True
+        assert reason == "Unfulfillable dependency detected"
+
+    def test_should_not_skip_job_with_reachable_dependency(self, mock_pipeline_manager):
+        """Test that a job whose dependencies remain fulfillable should not be skipped."""
+        mock_job = Mock(spec=JobRun, id=1, status=JobStatus.SUCCEEDED)
+        mock_dependency = Mock(spec=JobDependency, dependency_type=DependencyType.COMPLETION_REQUIRED)
+
+        with (
+            patch.object(
+                mock_pipeline_manager,
+                "get_dependencies_for_job",
+                return_value=[(mock_dependency, mock_job)],
+            ),
+            patch(
+                "mavedb.worker.lib.managers.pipeline_manager.job_should_be_skipped_due_to_unfulfillable_dependency",
+                return_value=(False, ""),
+            ) as mock_job_should_be_skipped,
+            TransactionSpy.spy(mock_pipeline_manager.db),
+        ):
+            should_skip, reason = mock_pipeline_manager.should_skip_job_due_to_dependencies(mock_job)
+            mock_job_should_be_skipped.assert_called_once_with(
+                dependency_type=DependencyType.COMPLETION_REQUIRED, dependent_job_status=JobStatus.SUCCEEDED
+            )
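+            # No unfulfillable dependency was reported, so the job is left alone:
+            # it should be neither skipped nor given a skip reason.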
assert should_skip is False
+        assert reason == ""
+
+    @pytest.mark.parametrize(
+        "exception",
+        HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION,
+    )
+    def test_should_skip_job_due_to_dependencies_raises_pipeline_state_error_on_handled_exceptions(
+        self, mock_pipeline_manager, exception
+    ):
+        """Test that handled exceptions during dependency checking raise PipelineStateError."""
+        mock_job = Mock(spec=JobRun, id=1, status=JobStatus.SUCCEEDED)
+        mock_dependency = Mock(spec=JobDependency, dependency_type=DependencyType.COMPLETION_REQUIRED)
+
+        with (
+            patch.object(
+                mock_pipeline_manager,
+                "get_dependencies_for_job",
+                return_value=[(mock_dependency, mock_job)],
+            ),
+            patch(
+                "mavedb.worker.lib.managers.pipeline_manager.job_should_be_skipped_due_to_unfulfillable_dependency",
+                side_effect=exception,
+            ),
+            pytest.raises(PipelineStateError, match="Corrupted dependency data"),
+            TransactionSpy.spy(mock_pipeline_manager.db),
+        ):
+            mock_pipeline_manager.should_skip_job_due_to_dependencies(mock_job)
+
+
+@pytest.mark.integration
+class TestShouldSkipJobDueToDependenciesIntegration:
+    """Integration tests for job skipping due to unmet dependencies."""
+
+    def test_should_not_skip_job_with_no_dependencies(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_pipeline,
+        sample_job_run,
+    ):
+        """Test that a job with no dependencies should not be skipped."""
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        with (
+            TransactionSpy.spy(session),
+        ):
+            should_skip, reason = manager.should_skip_job_due_to_dependencies(sample_job_run)
+
+        assert should_skip is False
+        assert reason == ""
+
+    def test_should_skip_job_with_unreachable_dependency(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_pipeline,
+        sample_job_run,
+        sample_dependent_job_run,
+    ):
+        """Test that a job with unreachable dependencies should be skipped."""
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        # Set the job the dependency depends on to a failed status
+        sample_job_run.status = JobStatus.FAILED
+        session.commit()
+
+        with (
+            TransactionSpy.spy(session),
+        ):
+            should_skip, reason = manager.should_skip_job_due_to_dependencies(sample_dependent_job_run)
+
+        assert should_skip is True
+        assert reason == "Dependency did not succeed (failed)"
+
+    def test_should_not_skip_job_with_reachable_dependency(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_pipeline,
+        sample_job_run,
+        sample_dependent_job_run,
+    ):
+        """Test that a job with reachable dependencies should not be skipped."""
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        # Set the job the dependency depends on to a pending (non-terminal) status
+        sample_job_run.status = JobStatus.PENDING
+        session.commit()
+
+        with (
+            TransactionSpy.spy(session),
+        ):
+            should_skip, reason = manager.should_skip_job_due_to_dependencies(sample_dependent_job_run)
+
+        assert should_skip is False
+        assert reason == ""
+
+
+@pytest.mark.unit
+class TestRetryFailedJobsUnit:
+    """Test retrying of failed jobs."""
+
+    @pytest.mark.asyncio
+    async def test_retry_failed_jobs_no_failed_jobs(self, mock_pipeline_manager, mock_job_manager):
+        """Test that retrying failed jobs skips if there are no failed jobs."""
+        with (
+            patch.object(
+                mock_pipeline_manager,
+                "get_failed_jobs",
+                return_value=[],
+            ),
+            patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status,
+            patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate,
+            
patch.object(mock_job_manager, "prepare_retry", return_value=None) as mock_prepare_retry, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + await mock_pipeline_manager.retry_failed_jobs() + + mock_prepare_retry.assert_not_called() + mock_set_status.assert_not_called() + mock_coordinate.assert_not_called() + + @pytest.mark.asyncio + async def test_retry_failed_jobs_successful_retry(self, mock_pipeline_manager, mock_job_manager): + """Test successful retrying of failed jobs.""" + mock_failed_job1 = Mock(spec=JobRun, id=1) + mock_failed_job2 = Mock(spec=JobRun, id=2) + + with ( + patch.object( + mock_pipeline_manager, + "get_failed_jobs", + return_value=[mock_failed_job1, mock_failed_job2], + ), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + patch.object( + mock_job_manager, + "prepare_retry", + return_value=None, + ) as mock_prepare_retry, + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.retry_failed_jobs() + + assert mock_prepare_retry.call_count == 2 + mock_set_status.assert_called_once_with(PipelineStatus.RUNNING) + mock_coordinate.assert_called_once() + + +@pytest.mark.integration +class TestRetryFailedJobsIntegration: + """Integration tests for retrying of failed jobs.""" + + @pytest.mark.asyncio + async def test_retry_failed_jobs_integration( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test successful retrying of failed jobs in a pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to RUNNING status + manager.set_pipeline_status(PipelineStatus.RUNNING) + session.commit() + + # Set the job statuses + sample_job_run.status = JobStatus.FAILED + sample_dependent_job_run.status = JobStatus.PENDING + session.commit() + + with ( + TransactionSpy.spy(session, expect_flush=True), + ): + await manager.retry_failed_jobs() + + # Commit the transaction + session.commit() + + # Verify that the pipeline is now in RUNNING status + assert manager.get_pipeline_status() == PipelineStatus.RUNNING + + # Verify that the failed job is now queued + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + + # Verify that the dependent job is still pending + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + + @pytest.mark.asyncio + async def test_retry_failed_jobs_integration_no_failed_jobs( + self, + session, + arq_redis, + with_populated_job_data, + sample_empty_pipeline, + ): + """Test that retrying failed jobs skips if there are no failed jobs.""" + manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id) + + # Set the pipeline to RUNNING status + manager.set_pipeline_status(PipelineStatus.RUNNING) + session.commit() + + with ( + TransactionSpy.spy(session), + ): + await manager.retry_failed_jobs() + + # Commit the transaction + session.commit() + + # Verify that the pipeline status is not changed + assert manager.get_pipeline_status() == PipelineStatus.RUNNING + + +@pytest.mark.unit +class TestRetryUnsuccessfulJobsUnit: + """Test retrying of unsuccessful jobs.""" + + @pytest.mark.asyncio + async def test_retry_unsuccessful_jobs_no_unsuccessful_jobs(self, mock_pipeline_manager, 
mock_job_manager):
+        """Test that retrying unsuccessful jobs skips if there are no unsuccessful jobs."""
+        with (
+            patch.object(
+                mock_pipeline_manager,
+                "get_unsuccessful_jobs",
+                return_value=[],
+            ),
+            patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status,
+            patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate,
+            patch.object(mock_job_manager, "prepare_retry", return_value=None) as mock_prepare_retry,
+            TransactionSpy.spy(mock_pipeline_manager.db),
+        ):
+            await mock_pipeline_manager.retry_unsuccessful_jobs()
+
+        mock_prepare_retry.assert_not_called()
+        mock_set_status.assert_not_called()
+        mock_coordinate.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_retry_unsuccessful_jobs_successful_retry(self, mock_pipeline_manager, mock_job_manager):
+        """Test successful retrying of unsuccessful jobs."""
+        mock_unsuccessful_job1 = Mock(spec=JobRun, id=1)
+        mock_unsuccessful_job2 = Mock(spec=JobRun, id=2)
+
+        with (
+            patch.object(
+                mock_pipeline_manager,
+                "get_unsuccessful_jobs",
+                return_value=[mock_unsuccessful_job1, mock_unsuccessful_job2],
+            ),
+            patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status,
+            patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate,
+            patch.object(
+                mock_job_manager,
+                "prepare_retry",
+                return_value=None,
+            ) as mock_prepare_retry,
+            TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True),
+        ):
+            await mock_pipeline_manager.retry_unsuccessful_jobs()
+
+        assert mock_prepare_retry.call_count == 2
+        mock_set_status.assert_called_once_with(PipelineStatus.RUNNING)
+        mock_coordinate.assert_called_once()
+
+
+@pytest.mark.integration
+class TestRetryUnsuccessfulJobsIntegration:
+    """Integration tests for retrying of unsuccessful jobs."""
+
+    @pytest.mark.asyncio
+    async def test_retry_unsuccessful_jobs_integration(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_pipeline,
+        sample_job_run,
+        sample_dependent_job_run,
+    ):
+        """Test successful retrying of unsuccessful jobs in a pipeline."""
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        # Set the pipeline to RUNNING status
+        manager.set_pipeline_status(PipelineStatus.RUNNING)
+        session.commit()
+
+        # Set the job statuses
+        sample_job_run.status = JobStatus.FAILED
+        sample_dependent_job_run.status = JobStatus.CANCELLED
+        session.commit()
+
+        with (
+            TransactionSpy.spy(session, expect_flush=True),
+        ):
+            await manager.retry_unsuccessful_jobs()
+
+        # Commit the transaction
+        session.commit()
+
+        # Verify that the pipeline is now in RUNNING status
+        assert manager.get_pipeline_status() == PipelineStatus.RUNNING
+
+        # Verify that the failed job is now queued
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.QUEUED
+
+        # Verify that the cancelled dependent job is now pending
+        job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one()
+        assert job.status == JobStatus.PENDING
+
+    @pytest.mark.asyncio
+    async def test_retry_unsuccessful_jobs_integration_no_unsuccessful_jobs(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_empty_pipeline,
+    ):
+        """Test that retrying unsuccessful jobs skips if there are no unsuccessful jobs."""
+        manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id)
+
+        # Set the pipeline to RUNNING status
+        manager.set_pipeline_status(PipelineStatus.RUNNING)
+        
session.commit()
+
+        with (
+            TransactionSpy.spy(session),
+        ):
+            await manager.retry_unsuccessful_jobs()
+
+        # Commit the transaction
+        session.commit()
+
+        # Verify that the pipeline status is not changed
+        assert manager.get_pipeline_status() == PipelineStatus.RUNNING
+
+
+@pytest.mark.unit
+class TestRetryPipelineUnit:
+    """Test retrying of entire pipelines."""
+
+    @pytest.mark.asyncio
+    async def test_retry_pipeline_calls_retry_unsuccessful_jobs(self, mock_pipeline_manager, mock_job_manager):
+        """Test that retrying a pipeline calls retrying unsuccessful jobs."""
+        with (
+            patch.object(
+                mock_pipeline_manager,
+                "retry_unsuccessful_jobs",
+                return_value=None,
+            ) as mock_retry_unsuccessful_jobs,
+            TransactionSpy.spy(mock_pipeline_manager.db),  # flush is handled in retry_unsuccessful_jobs, which we mock
+        ):
+            await mock_pipeline_manager.retry_pipeline()
+
+        mock_retry_unsuccessful_jobs.assert_called_once()
+
+
+@pytest.mark.integration
+class TestRetryPipelineIntegration:
+    """Integration tests for retrying of entire pipelines."""
+
+    @pytest.mark.asyncio
+    async def test_retry_pipeline_integration(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_pipeline,
+        sample_job_run,
+        sample_dependent_job_run,
+    ):
+        """Test successful retrying of an entire pipeline."""
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        # Set the pipeline to RUNNING status
+        manager.set_pipeline_status(PipelineStatus.RUNNING)
+        session.commit()
+
+        # Set the job statuses
+        sample_job_run.status = JobStatus.CANCELLED
+        sample_dependent_job_run.status = JobStatus.SKIPPED
+        session.commit()
+
+        with (
+            TransactionSpy.spy(session, expect_flush=True),
+        ):
+            await manager.retry_pipeline()
+
+        # Commit the transaction
+        session.commit()
+
+        # Verify that the pipeline is now in RUNNING status
+        assert manager.get_pipeline_status() == PipelineStatus.RUNNING
+
+        # Verify that the cancelled job is now queued
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.QUEUED
+
+        # Verify that the skipped dependent job is now pending
+        job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one()
+        assert job.status == JobStatus.PENDING
+
+
+@pytest.mark.unit
+class TestGetJobsByStatusUnit:
+    """Test job retrieval by status with mocked database."""
+
+    def test_get_jobs_by_status_wraps_sqlalchemy_error_with_database_error(self, mock_pipeline_manager):
+        """Test database error handling."""
+        with (
+            patch.object(mock_pipeline_manager.db, "execute", side_effect=SQLAlchemyError("DB error")),
+            pytest.raises(DatabaseConnectionError, match="Failed to get jobs with status"),
+            TransactionSpy.spy(mock_pipeline_manager.db),
+        ):
+            mock_pipeline_manager.get_jobs_by_status([JobStatus.RUNNING])
+
+
+@pytest.mark.integration
+class TestGetJobsByStatusIntegration:
+    """Integration tests for job retrieval by status."""
+
+    @pytest.mark.parametrize(
+        "status",
+        JobStatus._member_map_.values(),
+    )
+    def test_get_jobs_by_status_integration(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_pipeline,
+        sample_job_run,
+        sample_dependent_job_run,
+        status,
+    ):
+        """Test retrieval of jobs by status."""
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        # Set job statuses
+        sample_job_run.status = status
+        sample_dependent_job_run.status = [s for s in JobStatus if s != status][0]
+        session.commit()
+
+        with (
+            TransactionSpy.spy(session),
+        ):
+            running_jobs 
= manager.get_jobs_by_status([status])
+
+        assert len(running_jobs) == 1
+        assert running_jobs[0].id == sample_job_run.id
+
+    def test_get_jobs_by_status_integration_no_matching_jobs(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_pipeline,
+    ):
+        """Test retrieval of jobs by status when no jobs match."""
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        with (
+            TransactionSpy.spy(session),
+        ):
+            jobs = manager.get_jobs_by_status([JobStatus.SUCCEEDED])
+
+        assert len(jobs) == 0
+
+    def test_get_jobs_by_status_integration_multiple_matching_jobs(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_pipeline,
+        sample_job_run,
+        sample_dependent_job_run,
+    ):
+        """Test retrieval of jobs by status when multiple jobs match."""
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        # Set both job statuses to RUNNING
+        sample_job_run.status = JobStatus.RUNNING
+        sample_dependent_job_run.status = JobStatus.RUNNING
+        session.commit()
+
+        with (
+            TransactionSpy.spy(session),
+        ):
+            running_jobs = manager.get_jobs_by_status([JobStatus.RUNNING])
+
+        assert len(running_jobs) == 2
+        job_ids = {job.id for job in running_jobs}
+        assert sample_job_run.id in job_ids
+        assert sample_dependent_job_run.id in job_ids
+
+    def test_get_jobs_by_status_integration_no_jobs_in_pipeline(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_empty_pipeline,
+    ):
+        """Test retrieval of jobs by status when there are no jobs in the pipeline."""
+        manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id)
+
+        with (
+            TransactionSpy.spy(session),
+        ):
+            jobs = manager.get_jobs_by_status([JobStatus.RUNNING])
+
+        assert len(jobs) == 0
+
+    def test_get_jobs_by_status_multiple_statuses(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_pipeline,
+        sample_job_run,
+        sample_dependent_job_run,
+    ):
+        """Test retrieval of jobs by multiple statuses."""
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        # Set job statuses
+        sample_job_run.status = JobStatus.RUNNING
+        sample_dependent_job_run.status = JobStatus.PENDING
+        session.commit()
+
+        with (
+            TransactionSpy.spy(session),
+        ):
+            jobs = manager.get_jobs_by_status([JobStatus.RUNNING, JobStatus.PENDING])
+
+        assert len(jobs) == 2
+        job_ids = {job.id for job in jobs}
+        assert sample_job_run.id in job_ids
+        assert sample_dependent_job_run.id in job_ids
+
+        # Assert jobs are ordered by created_at timestamp
+        assert jobs[0].created_at <= jobs[1].created_at
+
+
+@pytest.mark.unit
+class TestGetPendingJobsUnit:
+    """Test retrieval of pending jobs."""
+
+    def test_get_pending_jobs_success(self, mock_pipeline_manager):
+        """Test successful retrieval of pending jobs."""
+
+        with (
+            patch.object(
+                mock_pipeline_manager, "get_jobs_by_status", return_value=[Mock(), Mock()]
+            ) as mock_get_jobs_by_status,
+            TransactionSpy.spy(mock_pipeline_manager.db),
+        ):
+            jobs = mock_pipeline_manager.get_pending_jobs()
+
+        assert len(jobs) == 2
+        mock_get_jobs_by_status.assert_called_once_with([JobStatus.PENDING])
+
+
+@pytest.mark.integration
+class TestGetPendingJobsIntegration:
+    """Integration tests for retrieval of pending jobs."""
+
+    def test_get_pending_jobs_integration(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_pipeline,
+        sample_job_run,
+        sample_dependent_job_run,
+    ):
+        """Test retrieval of pending jobs."""
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        # Set job statuses
+        sample_job_run.status = 
JobStatus.PENDING + sample_dependent_job_run.status = JobStatus.RUNNING + session.commit() + + with ( + TransactionSpy.spy(session), + ): + pending_jobs = manager.get_pending_jobs() + + assert len(pending_jobs) == 1 + assert pending_jobs[0].id == sample_job_run.id + + def test_get_pending_jobs_integration_no_pending_jobs( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of pending jobs when there are no pending jobs.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = JobStatus.RUNNING + sample_dependent_job_run.status = JobStatus.SUCCEEDED + session.commit() + + with ( + TransactionSpy.spy(session), + ): + pending_jobs = manager.get_pending_jobs() + + assert len(pending_jobs) == 0 + + +@pytest.mark.unit +class TestGetRunningJobsUnit: + """Test retrieval of running jobs.""" + + def test_get_running_jobs_success(self, mock_pipeline_manager): + """Test successful retrieval of running jobs.""" + + with ( + patch.object(mock_pipeline_manager, "get_jobs_by_status") as mock_get_jobs_by_status, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.get_running_jobs() + mock_get_jobs_by_status.assert_called_once_with([JobStatus.RUNNING]) + + +@pytest.mark.unit +class TestGetActiveJobsUnit: + """Test retrieval of active jobs.""" + + def test_get_active_jobs_success(self, mock_pipeline_manager): + """Test successful retrieval of active jobs.""" + + with ( + patch.object(mock_pipeline_manager, "get_jobs_by_status") as mock_get_jobs_by_status, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.get_active_jobs() + mock_get_jobs_by_status.assert_called_once_with(ACTIVE_JOB_STATUSES) + + +@pytest.mark.integration +class TestGetActiveJobsIntegration: + """Integration tests for retrieval of active jobs.""" + + def test_get_active_jobs_integration( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of active jobs.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = JobStatus.RUNNING + sample_dependent_job_run.status = JobStatus.PENDING + session.commit() + + with ( + TransactionSpy.spy(session), + ): + active_jobs = manager.get_active_jobs() + + assert len(active_jobs) == 2 + job_ids = {job.id for job in active_jobs} + assert sample_job_run.id in job_ids + assert sample_dependent_job_run.id in job_ids + + def test_get_active_jobs_integration_no_active_jobs( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of active jobs when there are no active jobs.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = JobStatus.SUCCEEDED + sample_dependent_job_run.status = JobStatus.FAILED + session.commit() + + with ( + TransactionSpy.spy(session), + ): + active_jobs = manager.get_active_jobs() + + assert len(active_jobs) == 0 + + +@pytest.mark.integration +class TestGetRunningJobsIntegration: + """Integration tests for retrieval of running jobs.""" + + def test_get_running_jobs_integration( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of running jobs.""" + manager = 
PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = JobStatus.RUNNING + sample_dependent_job_run.status = JobStatus.PENDING + session.commit() + + with ( + TransactionSpy.spy(session), + ): + running_jobs = manager.get_running_jobs() + + assert len(running_jobs) == 1 + assert running_jobs[0].id == sample_job_run.id + + def test_get_running_jobs_integration_no_running_jobs( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of running jobs when there are no running jobs.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = JobStatus.SUCCEEDED + sample_dependent_job_run.status = JobStatus.PENDING + session.commit() + + with ( + TransactionSpy.spy(session), + ): + running_jobs = manager.get_running_jobs() + + assert len(running_jobs) == 0 + + +@pytest.mark.unit +class TestGetFailedJobsUnit: + """Test retrieval of failed jobs.""" + + def test_get_failed_jobs_success(self, mock_pipeline_manager): + """Test successful retrieval of failed jobs.""" + + with ( + patch.object(mock_pipeline_manager, "get_jobs_by_status") as mock_get_jobs_by_status, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.get_failed_jobs() + + mock_get_jobs_by_status.assert_called_once_with([JobStatus.FAILED]) + + +@pytest.mark.integration +class TestGetFailedJobsIntegration: + """Integration tests for retrieval of failed jobs.""" + + def test_get_failed_jobs_integration( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of failed jobs.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = JobStatus.FAILED + sample_dependent_job_run.status = JobStatus.PENDING + session.commit() + + with ( + TransactionSpy.spy(session), + ): + failed_jobs = manager.get_failed_jobs() + + assert len(failed_jobs) == 1 + assert failed_jobs[0].id == sample_job_run.id + + def test_get_failed_jobs_integration_no_failed_jobs( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of failed jobs when there are no failed jobs.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = JobStatus.SUCCEEDED + sample_dependent_job_run.status = JobStatus.PENDING + session.commit() + + with ( + TransactionSpy.spy(session), + ): + failed_jobs = manager.get_failed_jobs() + + assert len(failed_jobs) == 0 + + +@pytest.mark.unit +class TestGetUnsuccessfulJobsUnit: + """Test retrieval of unsuccessful jobs.""" + + def test_get_unsuccessful_jobs_success(self, mock_pipeline_manager): + """Test successful retrieval of unsuccessful jobs.""" + + with ( + patch.object(mock_pipeline_manager, "get_jobs_by_status") as mock_get_jobs_by_status, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.get_unsuccessful_jobs() + mock_get_jobs_by_status.assert_called_once_with([JobStatus.CANCELLED, JobStatus.SKIPPED, JobStatus.FAILED]) + + +@pytest.mark.integration +class TestGetUnsuccessfulJobsIntegration: + """Integration tests for retrieval of unsuccessful jobs.""" + + def test_get_unsuccessful_jobs_integration( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + 
sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of unsuccessful jobs.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = JobStatus.FAILED + sample_dependent_job_run.status = JobStatus.CANCELLED + session.commit() + + with ( + TransactionSpy.spy(session), + ): + unsuccessful_jobs = manager.get_unsuccessful_jobs() + + assert len(unsuccessful_jobs) == 2 + job_ids = {job.id for job in unsuccessful_jobs} + assert sample_job_run.id in job_ids + assert sample_dependent_job_run.id in job_ids + + def test_get_unsuccessful_jobs_integration_no_unsuccessful_jobs( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of unsuccessful jobs when there are no unsuccessful jobs.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = JobStatus.SUCCEEDED + sample_dependent_job_run.status = JobStatus.PENDING + session.commit() + + with ( + TransactionSpy.spy(session), + ): + unsuccessful_jobs = manager.get_unsuccessful_jobs() + + assert len(unsuccessful_jobs) == 0 + + +@pytest.mark.unit +class TestGetAllJobsUnit: + """Test retrieval of all jobs.""" + + def test_get_all_jobs_wraps_sqlalchemy_errors_with_database_error(self, mock_pipeline_manager): + """Test database error handling during retrieval of all jobs.""" + + with ( + patch.object(mock_pipeline_manager.db, "execute", side_effect=SQLAlchemyError("DB error")), + pytest.raises(DatabaseConnectionError, match="Failed to get all jobs"), + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.get_all_jobs() + + +@pytest.mark.integration +class TestGetAllJobsIntegration: + """Integration tests for retrieval of all jobs.""" + + def test_get_all_jobs_integration( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of all jobs in a pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + all_jobs = manager.get_all_jobs() + + assert len(all_jobs) == 2 + job_ids = {job.id for job in all_jobs} + assert sample_job_run.id in job_ids + assert sample_dependent_job_run.id in job_ids + + def test_get_all_jobs_integration_no_jobs( + self, + session, + arq_redis, + with_populated_job_data, + sample_empty_pipeline, + ): + """Test retrieval of all jobs when there are no jobs in the pipeline.""" + manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + all_jobs = manager.get_all_jobs() + + assert len(all_jobs) == 0 + + def test_get_all_jobs_integration_multiple_jobs( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of all jobs when there are multiple jobs in the pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Add an additional job to the pipeline + new_job = JobRun( + id=99, + urn="job:additional_job:999", + pipeline_id=sample_pipeline.id, + job_type="Additional Job", + job_function="additional_function", + status=JobStatus.PENDING, + ) + session.add(new_job) + session.commit() + + with ( + TransactionSpy.spy(session), + ): + all_jobs = manager.get_all_jobs() + + assert len(all_jobs) == 3 + job_ids = {job.id for job in all_jobs} + assert 
sample_job_run.id in job_ids
+        assert sample_dependent_job_run.id in job_ids
+        assert new_job.id in job_ids
+
+        # Assert jobs are ordered by created_at timestamp
+        assert all_jobs[0].created_at <= all_jobs[1].created_at <= all_jobs[2].created_at
+
+
+@pytest.mark.unit
+class TestGetDependenciesForJobUnit:
+    """Test retrieval of job dependencies."""
+
+    def test_get_dependencies_for_job_wraps_sqlalchemy_error_with_database_error(self, mock_pipeline_manager):
+        """Test database error handling during retrieval of job dependencies."""
+        mock_job = Mock(spec=JobRun)
+
+        with (
+            patch.object(mock_pipeline_manager.db, "execute", side_effect=SQLAlchemyError("DB error")),
+            pytest.raises(DatabaseConnectionError, match="Failed to get job dependencies for job"),
+            TransactionSpy.spy(mock_pipeline_manager.db),
+        ):
+            mock_pipeline_manager.get_dependencies_for_job(mock_job)
+
+
+@pytest.mark.integration
+class TestGetDependenciesForJobIntegration:
+    """Integration tests for retrieval of job dependencies."""
+
+    def test_get_dependencies_for_job_integration(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_pipeline,
+        sample_job_run,
+        sample_dependent_job_run,
+        sample_job_dependency,
+    ):
+        """Test retrieval of job dependencies."""
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        with (
+            TransactionSpy.spy(session),
+        ):
+            dependencies = manager.get_dependencies_for_job(sample_dependent_job_run)
+
+        assert len(dependencies) == 1
+        dependency, job = dependencies[0]
+        assert dependency.id == sample_job_dependency.id
+        assert job.id == sample_job_run.id
+
+    def test_get_dependencies_for_job_integration_no_dependencies(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_pipeline,
+        sample_job_run,
+    ):
+        """Test retrieval of job dependencies when there are no dependencies."""
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        with (
+            TransactionSpy.spy(session),
+        ):
+            dependencies = manager.get_dependencies_for_job(sample_job_run)
+
+        assert len(dependencies) == 0
+
+    def test_get_dependencies_for_job_integration_multiple_dependencies(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_pipeline,
+        sample_job_run,
+        sample_dependent_job_run,
+    ):
+        """Test retrieval of job dependencies when there are multiple dependencies."""
+        # Create additional job and dependency
+        additional_job = JobRun(
+            id=99,
+            urn="job:additional_job:999",
+            pipeline_id=sample_pipeline.id,
+            job_type="Additional Job",
+            job_function="additional_function",
+            status=JobStatus.PENDING,
+        )
+        session.add(additional_job)
+        session.commit()
+
+        additional_dependency = JobDependency(
+            id=sample_dependent_job_run.id,
+            depends_on_job_id=additional_job.id,
+            dependency_type=DependencyType.COMPLETION_REQUIRED,
+        )
+        session.add(additional_dependency)
+        session.commit()
+
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        with (
+            TransactionSpy.spy(session),
+        ):
+            dependencies = manager.get_dependencies_for_job(sample_dependent_job_run)
+
+        assert len(dependencies) == 2
+        fetched_dependency_ids = {dep.id for dep, job in dependencies}
+        implicit_dependency_ids = {dep.id for dep in sample_dependent_job_run.job_dependencies}
+        assert fetched_dependency_ids == implicit_dependency_ids
+
+
+@pytest.mark.unit
+class TestGetPipelineUnit:
+    """Test retrieval of pipeline."""
+
+    def test_get_pipeline_wraps_sqlalchemy_errors_with_database_error(self, mock_pipeline):
+        """Test database error handling during retrieval of 
pipeline.""" + + # Prepare mock PipelineManager with mocked DB session that will raise SQLAlchemyError on query. + # We don't use the default fixture here since it usually wraps this function. + mock_db = Mock(spec=Session) + mock_redis = Mock(spec=ArqRedis) + manager = object.__new__(PipelineManager) + manager.db = mock_db + manager.redis = mock_redis + manager.pipeline_id = mock_pipeline.id + + with ( + patch.object(manager.db, "execute", side_effect=SQLAlchemyError("DB error")), + pytest.raises(DatabaseConnectionError, match="Failed to get pipeline"), + TransactionSpy.spy(manager.db), + ): + manager.get_pipeline() + + +@pytest.mark.integration +class TestGetPipelineIntegration: + """Integration tests for retrieval of pipeline.""" + + def test_get_pipeline_integration( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + ): + """Test retrieval of pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + pipeline = manager.get_pipeline() + + assert pipeline.id == sample_pipeline.id + assert pipeline.name == sample_pipeline.name + + def test_get_pipeline_integration_nonexistent_pipeline( + self, + session, + arq_redis, + with_populated_job_data, + ): + """Test retrieval of a nonexistent pipeline raises PipelineNotFoundError.""" + with ( + pytest.raises(DatabaseConnectionError, match="Failed to get pipeline 9999"), + TransactionSpy.spy(session), + ): + # get_pipeline is called implicitly during PipelineManager initialization + PipelineManager(session, arq_redis, pipeline_id=9999) + + +@pytest.mark.unit +class TestGetJobCountsByStatusUnit: + """Test retrieval of job counts by status.""" + + def test_get_job_counts_by_status_wraps_sqlalchemy_errors_with_database_error(self, mock_pipeline_manager): + """Test database error handling during retrieval of job counts by status.""" + + with ( + patch.object(mock_pipeline_manager.db, "execute", side_effect=SQLAlchemyError("DB error")), + pytest.raises(DatabaseConnectionError, match="Failed to get job counts for pipeline"), + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.get_job_counts_by_status() + + +@pytest.mark.integration +class TestGetJobCountsByStatusIntegration: + """Integration tests for retrieval of job counts by status.""" + + def test_get_job_counts_by_status_integration( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of job counts by status.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = JobStatus.RUNNING + sample_dependent_job_run.status = JobStatus.PENDING + session.commit() + + with ( + TransactionSpy.spy(session), + ): + counts = manager.get_job_counts_by_status() + + assert counts[JobStatus.RUNNING] == 1 + assert counts[JobStatus.PENDING] == 1 + assert counts.get(JobStatus.SUCCEEDED, 0) == 0 + + def test_get_job_counts_by_status_integration_no_jobs( + self, + session, + arq_redis, + with_populated_job_data, + sample_empty_pipeline, + ): + """Test retrieval of job counts by status when there are no jobs in the pipeline.""" + manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + counts = manager.get_job_counts_by_status() + + assert counts == {} + + +@pytest.mark.unit +class TestGetPipelineProgressUnit: + """Test retrieval of pipeline progress.""" + + pass + + 
+@pytest.mark.integration +class TestGetPipelineProgressIntegration: + """Integration tests for retrieval of pipeline progress.""" + + pass + + +@pytest.mark.unit +class TestGetPipelineStatusUnit: + """Test retrieval of pipeline status.""" + + def test_get_pipeline_status_success(self, mock_pipeline_manager): + """Test successful retrieval of pipeline status.""" + with ( + TransactionSpy.spy(mock_pipeline_manager.db), + patch.object( + mock_pipeline_manager, + "get_pipeline", + wraps=mock_pipeline_manager.get_pipeline, + ) as mock_get_pipeline, + ): + mock_pipeline_manager.get_pipeline_status() + mock_get_pipeline.assert_called_once() + + +@pytest.mark.integration +class TestGetPipelineStatusIntegration: + """Integration tests for retrieval of pipeline status.""" + + def test_get_pipeline_status_integration( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + ): + """Test retrieval of pipeline status.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + status = manager.get_pipeline_status() + + assert status == sample_pipeline.status + + +@pytest.mark.unit +class TestSetPipelineStatusUnit: + """Test setting of pipeline status.""" + + @pytest.mark.parametrize("pipeline_status", [status for status in PipelineStatus._member_map_.values()]) + def test_set_pipeline_status_success(self, mock_pipeline_manager, pipeline_status): + """Test successful setting of pipeline status.""" + mock_pipeline = Mock(spec=Pipeline, status=None) + + with ( + patch.object( + mock_pipeline_manager, + "get_pipeline", + return_value=mock_pipeline, + ) as mock_get_pipeline, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.set_pipeline_status(pipeline_status) + assert mock_pipeline.status == pipeline_status + + mock_get_pipeline.assert_called_once() + + @pytest.mark.parametrize( + "pipeline_status", + TERMINAL_PIPELINE_STATUSES, + ) + def test_set_pipeline_status_sets_finished_at_property_for_terminal_status( + self, mock_pipeline_manager, mock_pipeline, pipeline_status + ): + """Test that setting a terminal status updates the finished_at property.""" + # Set initial finished_at to None + mock_pipeline.finished_at = None + + with TransactionSpy.spy(mock_pipeline_manager.db): + before_update = datetime.datetime.now() + mock_pipeline_manager.set_pipeline_status(pipeline_status) + after_update = datetime.datetime.now() + + assert mock_pipeline.status == pipeline_status + assert mock_pipeline.finished_at is not None + assert before_update <= mock_pipeline.finished_at <= after_update + + def test_set_pipeline_status_clears_started_at_property_for_created_status( + self, mock_pipeline_manager, mock_pipeline + ): + """Test that setting status to CREATED clears the started_at property.""" + + with TransactionSpy.spy(mock_pipeline_manager.db): + mock_pipeline_manager.set_pipeline_status(PipelineStatus.CREATED) + assert mock_pipeline.status == PipelineStatus.CREATED + assert mock_pipeline.started_at is None + + @pytest.mark.parametrize( + "initial_started_at", + [None, datetime.datetime.now() - datetime.timedelta(hours=1)], + ) + def test_set_pipeline_status_sets_started_at_property_for_running_status( + self, mock_pipeline_manager, mock_pipeline, initial_started_at + ): + """Test that setting status to RUNNING sets the started_at property if not already set.""" + mock_pipeline.started_at = initial_started_at + with TransactionSpy.spy(mock_pipeline_manager.db): + before_update = datetime.datetime.now() + 
mock_pipeline_manager.set_pipeline_status(PipelineStatus.RUNNING) + after_update = datetime.datetime.now() + + assert mock_pipeline.status == PipelineStatus.RUNNING + + if initial_started_at is None: + assert mock_pipeline.started_at is not None + assert before_update <= mock_pipeline.started_at <= after_update + else: + assert mock_pipeline.started_at == initial_started_at + + @pytest.mark.parametrize( + "exception", + HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION, + ) + def test_set_pipeline_status_handled_exception_raises_pipeline_state_error(self, mock_pipeline_manager, exception): + """Test that handled exceptions during setting of pipeline status raise PipelineStateError.""" + + def get_or_error(*args): + if args: + raise exception + return PipelineStatus.CREATED + + with ( + patch.object(mock_pipeline_manager, "get_pipeline") as mock_pipeline, + pytest.raises(PipelineStateError, match="Failed to set pipeline status"), + TransactionSpy.spy(mock_pipeline_manager.db), + ): + # Mock exception when setting pipeline status + mock_pipeline.return_value = Mock(spec=Pipeline) + type(mock_pipeline.return_value).status = PropertyMock(side_effect=get_or_error) + + mock_pipeline_manager.set_pipeline_status(PipelineStatus.RUNNING) + + +@pytest.mark.integration +class TestSetPipelineStatusIntegration: + """Integration tests for setting of pipeline status.""" + + @pytest.mark.parametrize("pipeline_status", [status for status in PipelineStatus._member_map_.values()]) + def test_set_pipeline_status_integration( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + pipeline_status, + ): + """Test setting of pipeline status.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + manager.set_pipeline_status(pipeline_status) + + # Commit the transaction + session.commit() + + # Verify that the pipeline status is updated + updated_pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + assert updated_pipeline.status == pipeline_status + + @pytest.mark.parametrize( + "pipeline_status", + TERMINAL_PIPELINE_STATUSES, + ) + def test_set_pipeline_status_integration_terminal_status_sets_finished_at( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + pipeline_status, + ): + """Test that setting a terminal status updates the finished_at property.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + before_update = datetime.datetime.now(tz=datetime.timezone.utc) + manager.set_pipeline_status(pipeline_status) + after_update = datetime.datetime.now(tz=datetime.timezone.utc) + + # Commit the transaction + session.commit() + + # Verify that the pipeline status and finished_at are updated + updated_pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + assert updated_pipeline.status == pipeline_status + assert updated_pipeline.finished_at is not None + assert before_update <= updated_pipeline.finished_at <= after_update + + def test_set_pipeline_status_integration_created_status_clears_started_at( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + ): + """Test that setting status to CREATED clears the started_at property.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with TransactionSpy.spy(session): + manager.set_pipeline_status(PipelineStatus.CREATED) + + # Commit the transaction + 
session.commit() + + # Verify that the pipeline status is updated and started_at is None + updated_pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + assert updated_pipeline.status == PipelineStatus.CREATED + assert updated_pipeline.started_at is None + + @pytest.mark.parametrize( + "initial_started_at", + [None, datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(hours=1)], + ) + def test_set_pipeline_status_integration_running_status_sets_started_at( + self, + session, + arq_redis, + with_populated_job_data, + sample_pipeline, + initial_started_at, + ): + """Test that setting status to RUNNING sets the started_at property if not already set.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set initial started_at + sample_pipeline.started_at = initial_started_at + session.commit() + + with TransactionSpy.spy(session): + before_update = datetime.datetime.now(tz=datetime.timezone.utc) + manager.set_pipeline_status(PipelineStatus.RUNNING) + after_update = datetime.datetime.now(tz=datetime.timezone.utc) + + # Commit the transaction + session.commit() + + # Verify that the pipeline status and started_at are updated + updated_pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + assert updated_pipeline.status == PipelineStatus.RUNNING + + if initial_started_at is None: + assert before_update <= updated_pipeline.started_at <= after_update + else: + assert updated_pipeline.started_at == initial_started_at + + +@pytest.mark.unit +class TestEnqueueInArqUnit: + """Test enqueuing jobs in ARQ.""" + + @pytest.mark.asyncio + async def test_enqueue_in_arq_without_redis_raises_pipeline_coordination_error(self, mock_pipeline_manager): + """Test that attempting to enqueue a job without a Redis connection raises PipelineCoordinationError.""" + mock_job = Mock(spec=JobRun, job_function="test_func", id=1, urn="urn:example", retry_delay_seconds=10) + mock_pipeline_manager.redis = None + + with ( + pytest.raises( + PipelineCoordinationError, match="Redis client is not configured for job enqueueing; cannot proceed." 
+            ),
+            TransactionSpy.spy(mock_pipeline_manager.db),
+        ):
+            await mock_pipeline_manager._enqueue_in_arq(job=mock_job, is_retry=False)
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("enqueued", [Mock(spec=ArqJob), None])
+    @pytest.mark.parametrize("retry", [True, False])
+    async def test_enqueue_in_arq_success(self, mock_pipeline_manager, retry, enqueued):
+        """Test successful enqueuing of a job in ARQ."""
+        mock_job = Mock(spec=JobRun, job_function="test_func", id=1, urn="urn:example", retry_delay_seconds=10)
+        with (
+            patch.object(mock_pipeline_manager.redis, "enqueue_job", return_value=enqueued) as mock_enqueue_job,
+            TransactionSpy.spy(mock_pipeline_manager.db),
+        ):
+            await mock_pipeline_manager._enqueue_in_arq(job=mock_job, is_retry=retry)
+
+        mock_enqueue_job.assert_called_once_with(
+            mock_job.job_function,
+            mock_job.id,
+            _defer_by=datetime.timedelta(seconds=mock_job.retry_delay_seconds if retry else 0),
+            _job_id=mock_job.urn,
+        )
+
+    @pytest.mark.asyncio
+    async def test_any_enqueue_exception_raises_pipeline_coordination_error(self, mock_pipeline_manager):
+        """Test that any exception during enqueuing raises PipelineCoordinationError."""
+        mock_job = Mock(spec=JobRun, job_function="test_func", id=1, urn="urn:example", retry_delay_seconds=10)
+
+        with (
+            patch.object(
+                mock_pipeline_manager.redis,
+                "enqueue_job",
+                side_effect=Exception("Test exception"),
+            ),
+            pytest.raises(PipelineCoordinationError, match="Failed to enqueue job in ARQ"),
+            TransactionSpy.spy(mock_pipeline_manager.db),
+        ):
+            await mock_pipeline_manager._enqueue_in_arq(job=mock_job, is_retry=False)
+
+
+@pytest.mark.integration
+class TestEnqueueInArqIntegration:
+    """Integration tests for enqueuing jobs in ARQ."""
+
+    @pytest.mark.asyncio
+    async def test_enqueue_in_arq_integration(
+        self,
+        session,
+        arq_redis: ArqRedis,
+        with_populated_job_data,
+        sample_pipeline,
+        sample_job_run,
+    ):
+        """Test enqueuing of a job in ARQ."""
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        with (
+            TransactionSpy.spy(session),
+        ):
+            await manager._enqueue_in_arq(job=sample_job_run, is_retry=False)
+
+        queued_jobs = await arq_redis.queued_jobs()
+        assert len(queued_jobs) == 1
+        assert queued_jobs[0].function == sample_job_run.job_function
+
+
+@pytest.mark.integration
+class TestPipelineManagerLifecycle:
+    """Integration tests for PipelineManager lifecycle."""
+
+    @pytest.mark.asyncio
+    async def test_full_pipeline_lifecycle(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_pipeline,
+        sample_job_run,
+    ):
+        """Test full lifecycle of PipelineManager including initialization and job retrieval."""
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        # pipeline is created with pending jobs
+        pipeline = manager.get_pipeline()
+        all_jobs = manager.get_all_jobs()
+
+        assert pipeline.id == sample_pipeline.id
+        assert len(all_jobs) == 2
+        assert all_jobs[0].id == sample_job_run.id
+        assert all_jobs[0].status == JobStatus.PENDING
+
+        # pipeline started
+        await manager.start_pipeline()
+        session.commit()
+
+        # verify pipeline status is running
+        updated_pipeline = manager.get_pipeline()
+        assert updated_pipeline.status == PipelineStatus.RUNNING
+
+        # Verify job status and enqueued in ARQ
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.QUEUED
+        queued_jobs = await arq_redis.queued_jobs()
+        assert len(queued_jobs) == 1
+        assert queued_jobs[0].function == sample_job_run.job_function
+
+        # 
Simulate pipeline lifecycle for a two-job sample pipeline. The workflow here should be as follows:
+        # - Enter pipeline manager decorator. We don't make any calls when a pipeline begins
+        # - Enter the job manager decorator. This sets the job to RUNNING.
+        # - Job runs...
+        # - Exit the job manager decorator. This sets the job to some terminal state.
+        # - Exit the pipeline manager decorator. This coordinates the pipeline, either
+        #   enqueuing any newly queueable jobs or terminating it.
+
+        # enter pipeline manager decorator: no work
+        pass
+
+        # enter job manager decorator: set job to RUNNING
+        job_manager = JobManager(session, arq_redis, sample_job_run.id)
+        job_manager.start_job()
+        session.commit()
+
+        # job runs... Actual job execution is out of scope for this test. Instead, evict the job from redis to simulate completion.
+        await arq_redis.flushdb()
+
+        # exit job manager decorator: set job to SUCCEEDED
+        job_manager.succeed_job({"status": "ok", "data": {}, "exception": None})
+        session.commit()
+
+        # exit pipeline manager decorator: enqueue newly queueable jobs or terminate pipeline
+        await manager.coordinate_pipeline()
+        session.commit()
+
+        # Verify pipeline status is still RUNNING (since there is a dependent job)
+        updated_pipeline = manager.get_pipeline()
+        assert updated_pipeline.status == PipelineStatus.RUNNING
+
+        # Verify that the completed job is now SUCCEEDED in the database
+        completed_job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert completed_job.status == JobStatus.SUCCEEDED
+
+        # Verify that the dependent job is now QUEUED in the database and ARQ
+        dependent_job = session.execute(
+            select(JobRun).where(JobRun.pipeline_id == sample_pipeline.id).filter(JobRun.id != sample_job_run.id)
+        ).scalar_one()
+        assert dependent_job.status == JobStatus.QUEUED
+        queued_jobs = await arq_redis.queued_jobs()
+        assert len(queued_jobs) == 1
+        assert queued_jobs[0].function == dependent_job.job_function
+
+        # Simulate the next iteration of pipeline lifecycle. We've now entered a new context manager with
+        # steps identical to those described above but executing in the context of the newly enqueued dependent job.
+        job_manager = JobManager(session, arq_redis, dependent_job.id)
+
+        # enter pipeline manager decorator: no work
+        pass
+
+        # enter job manager decorator: set dependent job to RUNNING
+        job_manager.start_job()
+        session.commit()
+
+        # job runs... Actual job execution is out of scope for this test. Instead, evict the job from redis to simulate completion.
+ await arq_redis.flushdb() + + # exit job manager decorator: set dependent job to SUCCEEDED + job_manager.succeed_job({"status": "ok", "data": {}, "exception": None}) + session.commit() + + # exit pipeline manager decorator: enqueue newly queueable jobs or terminate pipeline + await manager.coordinate_pipeline() + session.commit() + + # Verify pipeline status is now SUCCEEDED + updated_pipeline = manager.get_pipeline() + assert updated_pipeline.status == PipelineStatus.SUCCEEDED + + # Verify that the dependent job is now SUCCEEDED in the database + dependent_job = session.execute(select(JobRun).where(JobRun.id == dependent_job.id)).scalar_one() + assert dependent_job.status == JobStatus.SUCCEEDED + + @pytest.mark.asyncio + async def test_paused_pipeline_lifecycle( + self, session, arq_redis, with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run + ): + """Test lifecycle of a paused pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Start the pipeline + await manager.start_pipeline() + session.commit() + + # Verify pipeline status is running + updated_pipeline = manager.get_pipeline() + assert updated_pipeline.status == PipelineStatus.RUNNING + + # Verify job status and enqueued in ARQ + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + queued_jobs = await arq_redis.queued_jobs() + assert len(queued_jobs) == 1 + assert queued_jobs[0].function == sample_job_run.job_function + + # Simulate job start + job_manager = JobManager(session, arq_redis, sample_job_run.id) + job_manager.start_job() + session.commit() + + # Pause the pipeline. Pausing the pipeline while a job is running DOES NOT affect the job. + await manager.pause_pipeline() + session.commit() + + # Verify that the pipeline is paused + updated_pipeline = manager.get_pipeline() + assert updated_pipeline.status == PipelineStatus.PAUSED + + # Evict the job from redis to simulate completion. 
+        await arq_redis.flushdb()
+
+        # Simulate job completion
+        job_manager.succeed_job({"status": "ok", "data": {}, "exception": None})
+        session.commit()
+
+        # Coordinate the pipeline
+        await manager.coordinate_pipeline()
+        session.commit()
+
+        # Verify that the pipeline remains paused
+        updated_pipeline = manager.get_pipeline()
+        assert updated_pipeline.status == PipelineStatus.PAUSED
+
+        # Verify that no jobs were enqueued in ARQ
+        queued_jobs = await arq_redis.queued_jobs()
+        assert len(queued_jobs) == 0
+
+        # Verify that the dependent job remains pending in the database
+        dependent_job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one()
+        assert dependent_job.status == JobStatus.PENDING
+
+        # Unpause the pipeline
+        await manager.unpause_pipeline()
+        session.commit()
+
+        # Verify that the pipeline is now running
+        updated_pipeline = manager.get_pipeline()
+        assert updated_pipeline.status == PipelineStatus.RUNNING
+
+        # Verify that the dependent job is now queued in ARQ
+        dependent_job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one()
+        assert dependent_job.status == JobStatus.QUEUED
+        queued_jobs = await arq_redis.queued_jobs()
+        assert len(queued_jobs) == 1
+        assert queued_jobs[0].function == sample_dependent_job_run.job_function
+
+        # Simulate dependent job start
+        dependent_job_manager = JobManager(session, arq_redis, sample_dependent_job_run.id)
+        dependent_job_manager.start_job()
+        session.commit()
+
+        # Evict the dependent job from redis to simulate completion.
+        await arq_redis.flushdb()
+
+        # Simulate dependent job completion
+        dependent_job_manager.succeed_job({"status": "ok", "data": {}, "exception": None})
+        session.commit()
+
+        # Coordinate the pipeline
+        await manager.coordinate_pipeline()
+        session.commit()
+
+        # Verify that the pipeline is now succeeded
+        updated_pipeline = manager.get_pipeline()
+        assert updated_pipeline.status == PipelineStatus.SUCCEEDED
+
+        # Verify that the dependent job is now succeeded in the database
+        dependent_job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one()
+        assert dependent_job.status == JobStatus.SUCCEEDED
+
+    @pytest.mark.asyncio
+    async def test_cancelled_pipeline_lifecycle(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_pipeline,
+        sample_job_run,
+        sample_dependent_job_run,
+    ):
+        """Test lifecycle of a cancelled pipeline."""
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        # Start the pipeline
+        await manager.start_pipeline()
+        session.commit()
+
+        # Verify pipeline status is running
+        updated_pipeline = manager.get_pipeline()
+        assert updated_pipeline.status == PipelineStatus.RUNNING
+
+        # Verify job status and enqueued in ARQ
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.QUEUED
+        queued_jobs = await arq_redis.queued_jobs()
+        assert len(queued_jobs) == 1
+        assert queued_jobs[0].function == sample_job_run.job_function
+
+        # Simulate job start
+        job_manager = JobManager(session, arq_redis, sample_job_run.id)
+        job_manager.start_job()
+        session.commit()
+
+        # Evict the job from redis to simulate completion.
+        await arq_redis.flushdb()
+
+        # Cancel the pipeline. This DOES have an effect on the running job.
+        await manager.cancel_pipeline()
+        session.commit()
+
+        # Verify that the pipeline is now cancelled
+        updated_pipeline = manager.get_pipeline()
+        assert updated_pipeline.status == PipelineStatus.CANCELLED
+
+        # Verify that the job is now cancelled in the database
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.CANCELLED
+
+        # Verify that the dependent job is now skipped in the database
+        dependent_job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one()
+        assert dependent_job.status == JobStatus.SKIPPED
+
+        # Verify that no jobs were enqueued in ARQ
+        queued_jobs = await arq_redis.queued_jobs()
+        assert len(queued_jobs) == 0
+
+    @pytest.mark.asyncio
+    async def test_restart_pipeline_lifecycle(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_pipeline,
+        sample_job_run,
+    ):
+        """Test lifecycle of a restarted pipeline."""
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        # Start the pipeline
+        await manager.start_pipeline()
+        session.commit()
+
+        # Verify pipeline status is running
+        updated_pipeline = manager.get_pipeline()
+        assert updated_pipeline.status == PipelineStatus.RUNNING
+
+        # Verify job status and enqueued in ARQ
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.QUEUED
+        queued_jobs = await arq_redis.queued_jobs()
+        assert len(queued_jobs) == 1
+        assert queued_jobs[0].function == sample_job_run.job_function
+
+        # Start the job
+        job_manager = JobManager(session, arq_redis, sample_job_run.id)
+        job_manager.start_job()
+        session.commit()
+
+        # Evict the job from redis to simulate completion.
+        await arq_redis.flushdb()
+
+        exc = Exception("Simulated job failure")
+        job_manager.fail_job(error=exc, result={"status": "error", "data": {}, "exception": exc})
+        session.commit()
+
+        # Coordinate the pipeline
+        await manager.coordinate_pipeline()
+        session.commit()
+
+        # Verify the pipeline failed
+        updated_pipeline = manager.get_pipeline()
+        assert updated_pipeline.status == PipelineStatus.FAILED
+
+        # Verify that the job is now failed in the database
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.FAILED
+
+        # Restart the pipeline
+        await manager.restart_pipeline()
+        session.commit()
+
+        # Verify that the pipeline is now running
+        updated_pipeline = manager.get_pipeline()
+        assert updated_pipeline.status == PipelineStatus.RUNNING
+
+        # Verify job status and enqueued in ARQ
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.QUEUED
+        queued_jobs = await arq_redis.queued_jobs()
+        assert len(queued_jobs) == 1
+        assert queued_jobs[0].function == sample_job_run.job_function
+
+    @pytest.mark.asyncio
+    async def test_retry_pipeline_lifecycle(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_pipeline,
+        sample_job_run,
+    ):
+        """Test lifecycle of a retried pipeline."""
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        # Add a cancelled job to the pipeline
+        cancelled_job = JobRun(
+            id=99,
+            pipeline_id=sample_pipeline.id,
+            job_function="cancelled_job_function",
+            job_type="CANCELLED_JOB",
+            status=JobStatus.CANCELLED,
+            urn="urn:cancelled_job",
+        )
+        session.add(cancelled_job)
+        session.commit()
+
+        # Start the pipeline
+        await manager.start_pipeline()
+        session.commit()
+
+        
+    @pytest.mark.asyncio
+    async def test_retry_pipeline_lifecycle(
+        self,
+        session,
+        arq_redis,
+        with_populated_job_data,
+        sample_pipeline,
+        sample_job_run,
+    ):
+        """Test lifecycle of a retried pipeline."""
+        manager = PipelineManager(session, arq_redis, sample_pipeline.id)
+
+        # Add a cancelled job to the pipeline
+        cancelled_job = JobRun(
+            id=99,
+            pipeline_id=sample_pipeline.id,
+            job_function="cancelled_job_function",
+            job_type="CANCELLED_JOB",
+            status=JobStatus.CANCELLED,
+            urn="urn:cancelled_job",
+        )
+        session.add(cancelled_job)
+        session.commit()
+
+        # Start the pipeline
+        await manager.start_pipeline()
+        session.commit()
+
+        # Verify pipeline status is running
+        updated_pipeline = manager.get_pipeline()
+        assert updated_pipeline.status == PipelineStatus.RUNNING
+
+        # Verify the job is queued and was enqueued in ARQ
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.QUEUED
+        queued_jobs = await arq_redis.queued_jobs()
+        assert len(queued_jobs) == 1
+        assert queued_jobs[0].function == sample_job_run.job_function
+
+        # Start the job
+        job_manager = JobManager(session, arq_redis, sample_job_run.id)
+        job_manager.start_job()
+        session.commit()
+
+        # Evict the job from redis to simulate completion.
+        await arq_redis.flushdb()
+
+        exc = Exception("Simulated job failure")
+        job_manager.fail_job(error=exc, result={"status": "error", "data": {}, "exception": exc})
+        session.commit()
+
+        # Coordinate the pipeline
+        await manager.coordinate_pipeline()
+        session.commit()
+
+        # Verify the pipeline failed
+        updated_pipeline = manager.get_pipeline()
+        assert updated_pipeline.status == PipelineStatus.FAILED
+
+        # Verify that the job is now failed in the database
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.FAILED
+
+        # Retry the pipeline
+        await manager.retry_pipeline()
+        session.commit()
+
+        # Verify that the pipeline is running again
+        updated_pipeline = manager.get_pipeline()
+        assert updated_pipeline.status == PipelineStatus.RUNNING
+
+        # Verify the previously failed job is now queued
+        job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one()
+        assert job.status == JobStatus.QUEUED
+
+        # Verify the previously cancelled job is now queued
+        job = session.execute(select(JobRun).where(JobRun.id == cancelled_job.id)).scalar_one()
+        assert job.status == JobStatus.QUEUED
+        queued_jobs = await arq_redis.queued_jobs()
+        assert len(queued_jobs) == 2
diff --git a/tests/worker/lib/managers/test_utils.py b/tests/worker/lib/managers/test_utils.py
new file mode 100644
index 00000000..eb5adb81
--- /dev/null
+++ b/tests/worker/lib/managers/test_utils.py
@@ -0,0 +1,94 @@
+# ruff: noqa: E402
+
+import pytest
+
+pytest.importorskip("arq")
+
+from mavedb.models.enums.job_pipeline import DependencyType, JobStatus
+from mavedb.worker.lib.managers.constants import COMPLETED_JOB_STATUSES
+from mavedb.worker.lib.managers.utils import (
+    construct_bulk_cancellation_result,
+    job_dependency_is_met,
+    job_should_be_skipped_due_to_unfulfillable_dependency,
+)
+
+
+@pytest.mark.unit
+class TestConstructBulkCancellationResultUnit:
+    def test_construct_bulk_cancellation_result(self):
+        reason = "Test cancellation reason"
+        result = construct_bulk_cancellation_result(reason)
+
+        assert result["status"] == "cancelled"
+        assert result["data"]["reason"] == reason
+        assert "timestamp" in result["data"]
+        assert result["exception"] is None
+
+
+@pytest.mark.unit
+class TestJobDependencyIsMetUnit:
+    @pytest.mark.parametrize(
+        "dependency_type, dependent_job_status, expected",
+        [
+            (None, "any_status", True),
+            # success required dependencies -- should only be met if the dependent job succeeded
+            (DependencyType.SUCCESS_REQUIRED, JobStatus.SUCCEEDED, True),
+            *[
+                (DependencyType.SUCCESS_REQUIRED, dependent_job_status, False)
+                for dependent_job_status in JobStatus._member_map_.values()
+                if dependent_job_status != JobStatus.SUCCEEDED
+            ],
+            # completion required dependencies -- should be met if the dependent job is in any terminal state
+            *[
+                (
+                    DependencyType.COMPLETION_REQUIRED,
+                    dependent_job_status,
+                    dependent_job_status in COMPLETED_JOB_STATUSES,
+                )
+                for dependent_job_status in JobStatus._member_map_.values()
+            ],
+        ],
+    )
+    def test_job_dependency_is_met(self, dependency_type, dependent_job_status, expected):
+        result = job_dependency_is_met(dependency_type, dependent_job_status)
+        assert result == expected
+
+
+@pytest.mark.unit
+class TestJobShouldBeSkippedDueToUnfulfillableDependencyUnit:
+    @pytest.mark.parametrize(
+        "dependency_type, dependent_job_status, expected",
+        [
+            # No dependency -- should not be skipped
+            (None, "any_status", False),
+            # success required dependencies -- should be skipped if the dependent job is in a terminal non-success state
+            (DependencyType.SUCCESS_REQUIRED, JobStatus.SUCCEEDED, False),
+            *[
+                (
+                    DependencyType.SUCCESS_REQUIRED,
+                    dependent_job_status,
+                    dependent_job_status in (JobStatus.FAILED, JobStatus.SKIPPED, JobStatus.CANCELLED),
+                )
+                for dependent_job_status in JobStatus._member_map_.values()
+            ],
+            # completion required dependencies -- should be skipped if the dependent job was cancelled or skipped
+            *[
+                (
+                    DependencyType.COMPLETION_REQUIRED,
+                    dependent_job_status,
+                    dependent_job_status in (JobStatus.CANCELLED, JobStatus.SKIPPED),
+                )
+                for dependent_job_status in JobStatus._member_map_.values()
+            ],
+        ],
+    )
+    def test_job_should_be_skipped_due_to_unfulfillable_dependency(
+        self, dependency_type, dependent_job_status, expected
+    ):
+        result = job_should_be_skipped_due_to_unfulfillable_dependency(dependency_type, dependent_job_status)
+
+        if expected:
+            assert result[0] is True
+            assert isinstance(result[1], str)
+        else:
+            assert result == (False, None)
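For reference: the two dependency helpers exercised above encode the pipeline's gating rules. The sketch below is inferred from the parametrized cases only -- it assumes COMPLETED_JOB_STATUSES enumerates the terminal job states, and it is an illustration, not the actual implementation in mavedb.worker.lib.managers.utils.

from typing import Optional, Tuple

from mavedb.models.enums.job_pipeline import DependencyType, JobStatus
from mavedb.worker.lib.managers.constants import COMPLETED_JOB_STATUSES


def job_dependency_is_met(dependency_type, dependent_job_status) -> bool:
    """Return True when an upstream job no longer blocks the dependent job."""
    # No declared dependency is trivially satisfied.
    if dependency_type is None:
        return True
    # SUCCESS_REQUIRED is satisfied only by a successful upstream job.
    if dependency_type == DependencyType.SUCCESS_REQUIRED:
        return dependent_job_status == JobStatus.SUCCEEDED
    # COMPLETION_REQUIRED is satisfied by any terminal upstream state.
    return dependent_job_status in COMPLETED_JOB_STATUSES


def job_should_be_skipped_due_to_unfulfillable_dependency(
    dependency_type, dependent_job_status
) -> Tuple[bool, Optional[str]]:
    """Return (True, reason) when the dependency can never be satisfied."""
    if dependency_type == DependencyType.SUCCESS_REQUIRED and dependent_job_status in (
        JobStatus.FAILED,
        JobStatus.SKIPPED,
        JobStatus.CANCELLED,
    ):
        # The upstream job reached a terminal non-success state, so success can never occur.
        return True, f"required job finished in state {dependent_job_status}"
    if dependency_type == DependencyType.COMPLETION_REQUIRED and dependent_job_status in (
        JobStatus.CANCELLED,
        JobStatus.SKIPPED,
    ):
        return True, f"required job was {dependent_job_status}"
    return False, None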
diff --git a/tests/worker/test_jobs.py b/tests/worker/test_jobs.py
deleted file mode 100644
index e7fd0b39..00000000
--- a/tests/worker/test_jobs.py
+++ /dev/null
@@ -1,3479 +0,0 @@
-# ruff: noqa: E402
-
-import json
-from asyncio.unix_events import _UnixSelectorEventLoop
-from copy import deepcopy
-from datetime import date
-from unittest.mock import patch
-from uuid import uuid4
-
-import jsonschema
-import pandas as pd
-import pytest
-from requests import HTTPError
-from sqlalchemy import not_, select
-
-arq = pytest.importorskip("arq")
-cdot = pytest.importorskip("cdot")
-fastapi = pytest.importorskip("fastapi")
-pyathena = pytest.importorskip("pyathena")
-
-from mavedb.data_providers.services import VRSMap
-from mavedb.lib.clingen.services import (
-    ClinGenAlleleRegistryService,
-    ClinGenLdhService,
-    clingen_allele_id_from_ldh_variation,
-)
-from mavedb.lib.mave.constants import HGVS_NT_COLUMN
-from mavedb.lib.score_sets import csv_data_to_df
-from mavedb.lib.uniprot.id_mapping import UniProtIDMappingAPI
-from mavedb.lib.validation.exceptions import ValidationError
-from mavedb.models.enums.mapping_state import MappingState
-from mavedb.models.enums.processing_state import ProcessingState
-from mavedb.models.mapped_variant import MappedVariant
-from mavedb.models.score_set import ScoreSet as ScoreSetDbModel
-from mavedb.models.variant import Variant
-from mavedb.view_models.experiment import Experiment, ExperimentCreate
-from mavedb.view_models.score_set import ScoreSet, ScoreSetCreate
-from mavedb.worker.jobs import (
-    BACKOFF_LIMIT,
-    MAPPING_CURRENT_ID_NAME,
-    MAPPING_QUEUE_NAME,
-    create_variants_for_score_set,
-    link_clingen_variants,
-    link_gnomad_variants,
-    map_variants_for_score_set,
-    poll_uniprot_mapping_jobs_for_score_set,
-    submit_score_set_mappings_to_car,
-    submit_score_set_mappings_to_ldh,
-    submit_uniprot_mapping_jobs_for_score_set,
-    variant_mapper_manager,
-)
-from 
tests.helpers.constants import ( - TEST_ACC_SCORESET_VARIANT_MAPPING_SCAFFOLD, - TEST_CLINGEN_ALLELE_OBJECT, - TEST_CLINGEN_LDH_LINKING_RESPONSE, - TEST_CLINGEN_SUBMISSION_BAD_RESQUEST_RESPONSE, - TEST_CLINGEN_SUBMISSION_RESPONSE, - TEST_CLINGEN_SUBMISSION_UNAUTHORIZED_RESPONSE, - TEST_GNOMAD_DATA_VERSION, - TEST_MINIMAL_ACC_SCORESET, - TEST_MINIMAL_EXPERIMENT, - TEST_MINIMAL_MULTI_TARGET_SCORESET, - TEST_MINIMAL_SEQ_SCORESET, - TEST_MULTI_TARGET_SCORESET_VARIANT_MAPPING_SCAFFOLD, - TEST_NT_CDOT_TRANSCRIPT, - TEST_SEQ_SCORESET_VARIANT_MAPPING_SCAFFOLD, - TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, - TEST_UNIPROT_JOB_SUBMISSION_RESPONSE, - TEST_UNIPROT_SWISS_PROT_TYPE, - TEST_VALID_POST_MAPPED_VRS_ALLELE_VRS2_X, - TEST_VALID_PRE_MAPPED_VRS_ALLELE_VRS2_X, - VALID_CHR_ACCESSION, - VALID_CLINGEN_CA_ID, - VALID_NT_ACCESSION, - VALID_UNIPROT_ACCESSION, -) -from tests.helpers.util.exceptions import awaitable_exception -from tests.helpers.util.experiment import create_experiment -from tests.helpers.util.score_set import create_acc_score_set, create_multi_target_score_set, create_seq_score_set - - -@pytest.fixture -def populate_worker_db(data_files, client): - # create score set via API. In production, the API would invoke this worker job - experiment = create_experiment(client) - seq_score_set = create_seq_score_set(client, experiment["urn"]) - acc_score_set = create_acc_score_set(client, experiment["urn"]) - multi_target_score_set = create_multi_target_score_set(client, experiment["urn"]) - - return [seq_score_set["urn"], acc_score_set["urn"], multi_target_score_set["urn"]] - - -async def setup_records_and_files(async_client, data_files, input_score_set): - experiment_payload = deepcopy(TEST_MINIMAL_EXPERIMENT) - jsonschema.validate(instance=experiment_payload, schema=ExperimentCreate.model_json_schema()) - experiment_response = await async_client.post("/api/v1/experiments/", json=experiment_payload) - assert experiment_response.status_code == 200 - experiment = experiment_response.json() - jsonschema.validate(instance=experiment, schema=Experiment.model_json_schema()) - - score_set_payload = deepcopy(input_score_set) - score_set_payload["experimentUrn"] = experiment["urn"] - jsonschema.validate(instance=score_set_payload, schema=ScoreSetCreate.model_json_schema()) - score_set_response = await async_client.post("/api/v1/score-sets/", json=score_set_payload) - assert score_set_response.status_code == 200 - score_set = score_set_response.json() - jsonschema.validate(instance=score_set, schema=ScoreSet.model_json_schema()) - - scores_fp = ( - "scores_multi_target.csv" - if len(score_set["targetGenes"]) > 1 - else ("scores.csv" if "targetSequence" in score_set["targetGenes"][0] else "scores_acc.csv") - ) - counts_fp = ( - "counts_multi_target.csv" - if len(score_set["targetGenes"]) > 1 - else ("counts.csv" if "targetSequence" in score_set["targetGenes"][0] else "counts_acc.csv") - ) - with ( - open(data_files / scores_fp, "rb") as score_file, - open(data_files / counts_fp, "rb") as count_file, - open(data_files / "score_columns_metadata.json", "rb") as score_columns_file, - open(data_files / "count_columns_metadata.json", "rb") as count_columns_file, - ): - scores = csv_data_to_df(score_file) - counts = csv_data_to_df(count_file) - score_columns_metadata = json.load(score_columns_file) - count_columns_metadata = json.load(count_columns_file) - - return score_set["urn"], scores, counts, score_columns_metadata, count_columns_metadata - - -async def setup_records_files_and_variants(session, 
async_client, data_files, input_score_set, worker_ctx): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - # Patch CDOT `_get_transcript`, in the event this function is called on an accesssion based scoreset. - with patch.object( - cdot.hgvs.dataproviders.RESTDataProvider, - "_get_transcript", - return_value=TEST_NT_CDOT_TRANSCRIPT, - ): - result = await create_variants_for_score_set( - worker_ctx, uuid4().hex, score_set.id, 1, scores, counts, score_columns_metadata, count_columns_metadata - ) - - score_set_with_variants = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - assert result["success"] - assert score_set.processing_state is ProcessingState.success - assert score_set_with_variants.num_variants == 3 - - return score_set_with_variants - - -async def setup_records_files_and_variants_with_mapping( - session, async_client, data_files, input_score_set, standalone_worker_context -): - score_set = await setup_records_files_and_variants( - session, async_client, data_files, input_score_set, standalone_worker_context - ) - await sanitize_mapping_queue(standalone_worker_context, score_set) - - async def dummy_mapping_job(): - return await setup_mapping_output(async_client, session, score_set) - - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_mapping_job(), - ), - patch("mavedb.worker.jobs.CLIN_GEN_SUBMISSION_ENABLED", False), - ): - result = await map_variants_for_score_set(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - return session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - - -async def sanitize_mapping_queue(standalone_worker_context, score_set): - queued_job = await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME) - assert int(queued_job.decode("utf-8")) == score_set.id - - -async def setup_mapping_output( - async_client, session, score_set, score_set_is_seq_based=True, score_set_is_multi_target=False, empty=False -): - score_set_response = await async_client.get(f"/api/v1/score-sets/{score_set.urn}") - - if score_set_is_seq_based: - if score_set_is_multi_target: - # If this is a multi-target sequence based score set, use the scaffold for that. 
- mapping_output = deepcopy(TEST_MULTI_TARGET_SCORESET_VARIANT_MAPPING_SCAFFOLD) - else: - mapping_output = deepcopy(TEST_SEQ_SCORESET_VARIANT_MAPPING_SCAFFOLD) - else: - # there is not currently a multi-target accession-based score set test - mapping_output = deepcopy(TEST_ACC_SCORESET_VARIANT_MAPPING_SCAFFOLD) - mapping_output["metadata"] = score_set_response.json() - - if empty: - return mapping_output - - variants = session.scalars(select(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).all() - for variant in variants: - mapped_score = { - "pre_mapped": TEST_VALID_PRE_MAPPED_VRS_ALLELE_VRS2_X, - "post_mapped": TEST_VALID_POST_MAPPED_VRS_ALLELE_VRS2_X, - "mavedb_id": variant.urn, - } - - mapping_output["mapped_scores"].append(mapped_score) - - return mapping_output - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set,validation_error", - [ - ( - TEST_MINIMAL_SEQ_SCORESET, - { - "exception": "encountered 1 invalid variant strings.", - "detail": ["target sequence mismatch for 'c.1T>A' at row 0 for sequence TEST1"], - }, - ), - ( - TEST_MINIMAL_ACC_SCORESET, - { - "exception": "encountered 1 invalid variant strings.", - "detail": [ - "Failed to parse row 0 with HGVS exception: NM_001637.3:c.1T>A: Variant reference (T) does not agree with reference sequence (G)." - ], - }, - ), - ( - TEST_MINIMAL_MULTI_TARGET_SCORESET, - { - "exception": "encountered 1 invalid variant strings.", - "detail": ["target sequence mismatch for 'n.1T>A' at row 0 for sequence TEST3"], - }, - ), - ], -) -async def test_create_variants_for_score_set_with_validation_error( - input_score_set, - validation_error, - setup_worker_db, - async_client, - standalone_worker_context, - session, - data_files, -): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - if input_score_set == TEST_MINIMAL_SEQ_SCORESET: - scores.loc[:, HGVS_NT_COLUMN].iloc[0] = "c.1T>A" - elif input_score_set == TEST_MINIMAL_ACC_SCORESET: - scores.loc[:, HGVS_NT_COLUMN].iloc[0] = f"{VALID_NT_ACCESSION}:c.1T>A" - elif input_score_set == TEST_MINIMAL_MULTI_TARGET_SCORESET: - scores.loc[:, HGVS_NT_COLUMN].iloc[0] = "TEST3:n.1T>A" - - with ( - patch.object( - cdot.hgvs.dataproviders.RESTDataProvider, - "_get_transcript", - return_value=TEST_NT_CDOT_TRANSCRIPT, - ) as hdp, - ): - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - - # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. 
- if all(["targetSequence" in target for target in input_score_set["targetGenes"]]): - hdp.assert_not_called() - else: - hdp.assert_called_once() - - db_variants = session.scalars(select(Variant)).all() - - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - assert score_set.num_variants == 0 - assert len(db_variants) == 0 - assert score_set.processing_state == ProcessingState.failed - assert score_set.processing_errors == validation_error - assert not result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) -) -async def test_create_variants_for_score_set_with_caught_exception( - input_score_set, - setup_worker_db, - async_client, - standalone_worker_context, - session, - data_files, -): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - # This is somewhat dumb and wouldn't actually happen like this, but it serves as an effective way to guarantee - # some exception will be raised no matter what in the async job. - with ( - patch.object(pd.DataFrame, "isnull", side_effect=Exception) as mocked_exc, - ): - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - mocked_exc.assert_called() - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - - assert score_set.num_variants == 0 - assert len(db_variants) == 0 - assert score_set.processing_state == ProcessingState.failed - assert score_set.processing_errors == {"detail": [], "exception": ""} - assert not result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) -) -async def test_create_variants_for_score_set_with_caught_base_exception( - input_score_set, - setup_worker_db, - async_client, - standalone_worker_context, - session, - data_files, -): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - # This is somewhat (extra) dumb and wouldn't actually happen like this, but it serves as an effective way to guarantee - # some base exception will be handled no matter what in the async job. 
- with ( - patch.object(pd.DataFrame, "isnull", side_effect=BaseException), - ): - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - - assert score_set.num_variants == 0 - assert len(db_variants) == 0 - assert score_set.processing_state == ProcessingState.failed - assert score_set.processing_errors is None - assert not result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) -) -async def test_create_variants_for_score_set_with_existing_variants( - input_score_set, - setup_worker_db, - async_client, - standalone_worker_context, - session, - data_files, -): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - with patch.object( - cdot.hgvs.dataproviders.RESTDataProvider, - "_get_transcript", - return_value=TEST_NT_CDOT_TRANSCRIPT, - ) as hdp: - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - - # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. - if all(["targetSequence" in target for target in input_score_set["targetGenes"]]): - hdp.assert_not_called() - else: - hdp.assert_called_once() - - await sanitize_mapping_queue(standalone_worker_context, score_set) - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - - assert score_set.num_variants == 3 - assert len(db_variants) == 3 - assert score_set.processing_state == ProcessingState.success - - with patch.object( - cdot.hgvs.dataproviders.RESTDataProvider, - "_get_transcript", - return_value=TEST_NT_CDOT_TRANSCRIPT, - ) as hdp: - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - - assert score_set.num_variants == 3 - assert len(db_variants) == 3 - assert score_set.processing_state == ProcessingState.success - assert score_set.processing_errors is None - assert result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 1 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) -) -async def test_create_variants_for_score_set_with_existing_exceptions( - input_score_set, - setup_worker_db, - async_client, - standalone_worker_context, - session, - data_files, -): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - 
score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - # This is somewhat dumb and wouldn't actually happen like this, but it serves as an effective way to guarantee - # some exception will be raised no matter what in the async job. - with ( - patch.object( - pd.DataFrame, - "isnull", - side_effect=ValidationError("Test Exception", triggers=["exc_1", "exc_2"]), - ) as mocked_exc, - ): - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - mocked_exc.assert_called() - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - - assert score_set.num_variants == 0 - assert len(db_variants) == 0 - assert score_set.processing_state == ProcessingState.failed - assert score_set.processing_errors == { - "exception": "Test Exception", - "detail": ["exc_1", "exc_2"], - } - - with patch.object( - cdot.hgvs.dataproviders.RESTDataProvider, - "_get_transcript", - return_value=TEST_NT_CDOT_TRANSCRIPT, - ) as hdp: - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - - # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. - if all(["targetSequence" in target for target in input_score_set["targetGenes"]]): - hdp.assert_not_called() - else: - hdp.assert_called_once() - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - - assert score_set.num_variants == 3 - assert len(db_variants) == 3 - assert score_set.processing_state == ProcessingState.success - assert score_set.processing_errors is None - assert result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 1 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) -) -async def test_create_variants_for_score_set( - input_score_set, - setup_worker_db, - async_client, - standalone_worker_context, - session, - data_files, -): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - with patch.object( - cdot.hgvs.dataproviders.RESTDataProvider, - "_get_transcript", - return_value=TEST_NT_CDOT_TRANSCRIPT, - ) as hdp: - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - - # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. 
- if all(["targetSequence" in target for target in input_score_set["targetGenes"]]): - hdp.assert_not_called() - else: - hdp.assert_called_once() - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - - assert score_set.num_variants == 3 - assert len(db_variants) == 3 - assert score_set.processing_state == ProcessingState.success - assert result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 1 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) -) -async def test_create_variants_for_score_set_enqueues_manager_and_successful_mapping( - input_score_set, - setup_worker_db, - session, - async_client, - data_files, - arq_worker, - arq_redis, -): - score_set_is_seq = all(["targetSequence" in target for target in input_score_set["targetGenes"]]) - score_set_is_multi_target = len(input_score_set["targetGenes"]) > 1 - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - async def dummy_mapping_job(): - return await setup_mapping_output(async_client, session, score_set, score_set_is_seq, score_set_is_multi_target) - - async def dummy_car_submission_job(): - return TEST_CLINGEN_ALLELE_OBJECT - - async def dummy_ldh_submission_job(): - return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] - - # Variants have not yet been created, so infer their URNs. - async def dummy_linking_job(): - return [(f"{score_set_urn}#{i}", TEST_CLINGEN_LDH_LINKING_RESPONSE) for i in range(1, len(scores) + 1)] - - with ( - patch.object( - cdot.hgvs.dataproviders.RESTDataProvider, - "_get_transcript", - return_value=TEST_NT_CDOT_TRANSCRIPT, - ) as hdp, - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - side_effect=[ - dummy_mapping_job(), - dummy_car_submission_job(), - dummy_ldh_submission_job(), - dummy_linking_job(), - ], - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - patch("mavedb.worker.jobs.MAPPING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.LINKING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.CLIN_GEN_SUBMISSION_ENABLED", True), - ): - await arq_redis.enqueue_job( - "create_variants_for_score_set", - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - await arq_worker.async_run() - await arq_worker.run_check() - - # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. 
- if score_set_is_seq: - hdp.assert_not_called() - else: - hdp.assert_called_once() - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert score_set.num_variants == 3 - assert len(db_variants) == 3 - assert score_set.processing_state == ProcessingState.success - assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 - assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert len(mapped_variants_for_score_set) == score_set.num_variants - assert score_set.mapping_state == MappingState.complete - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) -) -async def test_create_variants_for_score_set_exception_skips_mapping( - input_score_set, - setup_worker_db, - session, - async_client, - data_files, - arq_worker, - arq_redis, -): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - with patch.object(pd.DataFrame, "isnull", side_effect=Exception) as mocked_exc: - await arq_redis.enqueue_job( - "create_variants_for_score_set", - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - await arq_worker.async_run() - await arq_worker.run_check() - - mocked_exc.assert_called() - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert score_set.num_variants == 0 - assert len(db_variants) == 0 - assert score_set.processing_state == ProcessingState.failed - assert score_set.processing_errors == {"detail": [], "exception": ""} - assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 - assert len(mapped_variants_for_score_set) == 0 - assert score_set.mapping_state == MappingState.not_attempted - assert score_set.mapping_errors is None - - -# NOTE: These tests operate under the assumption that mapping output is consistent between accession based and sequence based score sets. If -# this assumption changes in the future, tests reflecting this difference in output should be added for accession based score sets. - - -@pytest.mark.asyncio -async def test_create_mapped_variants_for_scoreset( - setup_worker_db, - async_client, - standalone_worker_context, - session, - data_files, -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - # The call to `create_variants_from_score_set` within the above `setup_records_files_and_variants` will - # add a score set to the queue. Since we are executing the mapping independent of the manager job, we should - # sanitize the queue as if the mananger process had run. 
- await sanitize_mapping_queue(standalone_worker_context, score_set) - - async def dummy_mapping_job(): - return await setup_mapping_output(async_client, session, score_set) - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_mapping_job(), - ), - patch("mavedb.worker.jobs.CLIN_GEN_SUBMISSION_ENABLED", True), - ): - result = await map_variants_for_score_set(standalone_worker_context, uuid4().hex, score_set.id, 1) - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert result["success"] - assert not result["retried"] - assert result["enqueued_jobs"] - assert len(mapped_variants_for_score_set) == score_set.num_variants - assert score_set.mapping_state == MappingState.complete - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -async def test_create_mapped_variants_for_scoreset_with_existing_mapped_variants( - setup_worker_db, async_client, standalone_worker_context, session, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - # The call to `create_variants_from_score_set` within the above `setup_records_files_and_variants` will - # add a score set to the queue. Since we are executing the mapping independent of the manager job, we should - # sanitize the queue as if the mananger process had run. - await sanitize_mapping_queue(standalone_worker_context, score_set) - - async def dummy_mapping_job(): - return await setup_mapping_output(async_client, session, score_set) - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. 
- with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_mapping_job(), - ), - patch("mavedb.worker.jobs.CLIN_GEN_SUBMISSION_ENABLED", True), - ): - existing_variant = session.scalars(select(Variant)).first() - - if not existing_variant: - raise ValueError - - session.add( - MappedVariant( - pre_mapped={"preexisting": "variant"}, - post_mapped={"preexisting": "variant"}, - variant_id=existing_variant.id, - modification_date=date.today(), - mapped_date=date.today(), - vrs_version="2.0", - mapping_api_version="0.0.0", - current=True, - ) - ) - session.commit() - - result = await map_variants_for_score_set(standalone_worker_context, uuid4().hex, score_set.id, 1) - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - preexisting_variants = session.scalars( - select(MappedVariant) - .join(Variant) - .join(ScoreSetDbModel) - .filter(ScoreSetDbModel.urn == score_set.urn, not_(MappedVariant.current)) - ).all() - new_variants = session.scalars( - select(MappedVariant) - .join(Variant) - .join(ScoreSetDbModel) - .filter(ScoreSetDbModel.urn == score_set.urn, MappedVariant.current) - ).all() - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert result["success"] - assert not result["retried"] - assert result["enqueued_jobs"] - assert len(mapped_variants_for_score_set) == score_set.num_variants + 1 - assert len(preexisting_variants) == 1 - assert len(new_variants) == score_set.num_variants - assert score_set.mapping_state == MappingState.complete - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -async def test_create_mapped_variants_for_scoreset_exception_in_mapping_setup_score_set_selection( - setup_worker_db, async_client, standalone_worker_context, session, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - # The call to `create_variants_from_score_set` within the above `setup_records_files_and_variants` will - # add a score set to the queue. Since we are executing the mapping independent of the manager job, we should - # sanitize the queue as if the mananger process had run. - await sanitize_mapping_queue(standalone_worker_context, score_set) - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. 
- with patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=awaitable_exception(), - ): - result = await map_variants_for_score_set(standalone_worker_context, uuid4().hex, score_set.id + 5, 1) - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - assert len(mapped_variants_for_score_set) == 0 - # When we cannot fetch a score set, these fields are unable to be updated. - assert score_set.mapping_state == MappingState.queued - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -async def test_create_mapped_variants_for_scoreset_exception_in_mapping_setup_vrs_object( - setup_worker_db, async_client, standalone_worker_context, session, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - # The call to `create_variants_from_score_set` within the above `setup_records_files_and_variants` will - # add a score set to the queue. Since we are executing the mapping independent of the manager job, we should - # sanitize the queue as if the mananger process had run. - await sanitize_mapping_queue(standalone_worker_context, score_set) - - with patch.object( - VRSMap, - "__init__", - return_value=Exception(), - ): - result = await map_variants_for_score_set(standalone_worker_context, uuid4().hex, score_set.id, 1) - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - assert len(mapped_variants_for_score_set) == 0 - assert score_set.mapping_state == MappingState.failed - assert score_set.mapping_errors is not None - - -@pytest.mark.asyncio -async def test_create_mapped_variants_for_scoreset_mapping_exception( - setup_worker_db, async_client, standalone_worker_context, session, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - # The call to `create_variants_from_score_set` within the above `setup_records_files_and_variants` will - # add a score set to the queue. Since we are executing the mapping independent of the manager job, we should - # sanitize the queue as if the mananger process had run. - await sanitize_mapping_queue(standalone_worker_context, score_set) - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. 
- with patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=awaitable_exception(), - ): - result = await map_variants_for_score_set(standalone_worker_context, uuid4().hex, score_set.id, 1) - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 1 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert not result["success"] - assert result["retried"] - assert result["enqueued_jobs"] - assert len(mapped_variants_for_score_set) == 0 - assert score_set.mapping_state == MappingState.queued - assert score_set.mapping_errors is not None - - -@pytest.mark.asyncio -async def test_create_mapped_variants_for_scoreset_mapping_exception_retry_limit_reached( - setup_worker_db, async_client, standalone_worker_context, session, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - # The call to `create_variants_from_score_set` within the above `setup_records_files_and_variants` will - # add a score set to the queue. Since we are executing the mapping independent of the manager job, we should - # sanitize the queue as if the mananger process had run. - await sanitize_mapping_queue(standalone_worker_context, score_set) - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. - with patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=awaitable_exception(), - ): - result = await map_variants_for_score_set( - standalone_worker_context, uuid4().hex, score_set.id, 1, BACKOFF_LIMIT + 1 - ) - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - assert len(mapped_variants_for_score_set) == 0 - assert score_set.mapping_state == MappingState.failed - assert score_set.mapping_errors is not None - - -@pytest.mark.asyncio -async def test_create_mapped_variants_for_scoreset_mapping_exception_retry_failed( - setup_worker_db, async_client, standalone_worker_context, session, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - # The call to `create_variants_from_score_set` within the above `setup_records_files_and_variants` will - # add a score set to the queue. Since we are executing the mapping independent of the manager job, we should - # sanitize the queue as if the mananger process had run. 
- await sanitize_mapping_queue(standalone_worker_context, score_set) - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=awaitable_exception(), - ), - patch.object(arq.ArqRedis, "lpush", awaitable_exception()), - ): - result = await map_variants_for_score_set(standalone_worker_context, uuid4().hex, score_set.id, 1) - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - assert len(mapped_variants_for_score_set) == 0 - # Behavior for exception in mapping is retried job - assert score_set.mapping_state == MappingState.failed - assert score_set.mapping_errors is not None - - -@pytest.mark.asyncio -async def test_create_mapped_variants_for_scoreset_parsing_exception_with_retry( - setup_worker_db, async_client, standalone_worker_context, session, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - # The call to `create_variants_from_score_set` within the above `setup_records_files_and_variants` will - # add a score set to the queue. Since we are executing the mapping independent of the manager job, we should - # sanitize the queue as if the mananger process had run. - await sanitize_mapping_queue(standalone_worker_context, score_set) - - async def dummy_mapping_job(): - mapping_test_output_for_score_set = await setup_mapping_output(async_client, session, score_set) - mapping_test_output_for_score_set.pop("computed_genomic_reference_sequence") - return mapping_test_output_for_score_set - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. 
- with patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_mapping_job(), - ): - result = await map_variants_for_score_set(standalone_worker_context, uuid4().hex, score_set.id, 1) - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 1 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert not result["success"] - assert result["retried"] - assert result["enqueued_jobs"] - assert len(mapped_variants_for_score_set) == 0 - assert score_set.mapping_state == MappingState.queued - assert score_set.mapping_errors is not None - - -@pytest.mark.asyncio -async def test_create_mapped_variants_for_scoreset_parsing_exception_retry_failed( - setup_worker_db, async_client, standalone_worker_context, session, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - # The call to `create_variants_from_score_set` within the above `setup_records_files_and_variants` will - # add a score set to the queue. Since we are executing the mapping independent of the manager job, we should - # sanitize the queue as if the mananger process had run. - await sanitize_mapping_queue(standalone_worker_context, score_set) - - async def dummy_mapping_job(): - mapping_test_output_for_score_set = await setup_mapping_output(async_client, session, score_set) - mapping_test_output_for_score_set.pop("computed_genomic_reference_sequence") - return mapping_test_output_for_score_set - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. 
- with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_mapping_job(), - ), - patch.object(arq.ArqRedis, "lpush", awaitable_exception()), - ): - result = await map_variants_for_score_set(standalone_worker_context, uuid4().hex, score_set.id, 1) - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - assert len(mapped_variants_for_score_set) == 0 - # Behavior for exception outside mapping is failed job - assert score_set.mapping_state == MappingState.failed - assert score_set.mapping_errors is not None - - -@pytest.mark.asyncio -async def test_create_mapped_variants_for_scoreset_parsing_exception_retry_limit_reached( - setup_worker_db, async_client, standalone_worker_context, session, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - # The call to `create_variants_from_score_set` within the above `setup_records_files_and_variants` will - # add a score set to the queue. Since we are executing the mapping independent of the manager job, we should - # sanitize the queue as if the mananger process had run. - await sanitize_mapping_queue(standalone_worker_context, score_set) - - async def dummy_mapping_job(): - mapping_test_output_for_score_set = await setup_mapping_output(async_client, session, score_set) - mapping_test_output_for_score_set.pop("computed_genomic_reference_sequence") - return mapping_test_output_for_score_set - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. 
- with patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_mapping_job(), - ): - result = await map_variants_for_score_set( - standalone_worker_context, uuid4().hex, score_set.id, 1, BACKOFF_LIMIT + 1 - ) - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - assert len(mapped_variants_for_score_set) == 0 - # Behavior for exception outside mapping is failed job - assert score_set.mapping_state == MappingState.failed - assert score_set.mapping_errors is not None - - -@pytest.mark.asyncio -async def test_create_mapped_variants_for_scoreset_no_mapping_output( - setup_worker_db, async_client, standalone_worker_context, session, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - # The call to `create_variants_from_score_set` within the above `setup_records_files_and_variants` will - # add a score set to the queue. Since we are executing the mapping independent of the manager job, we should - # sanitize the queue as if the mananger process had run. - await sanitize_mapping_queue(standalone_worker_context, score_set) - - # Do not await, we need a co-routine object to be the return value of our `run_in_executor` mock. - async def dummy_mapping_job(): - return await setup_mapping_output(async_client, session, score_set, empty=True) - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_mapping_job(), - ), - patch("mavedb.worker.jobs.CLIN_GEN_SUBMISSION_ENABLED", True), - ): - result = await map_variants_for_score_set(standalone_worker_context, uuid4().hex, score_set.id, 1) - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert result["success"] - assert not result["retried"] - assert result["enqueued_jobs"] - assert len(mapped_variants_for_score_set) == 0 - assert score_set.mapping_state == MappingState.failed - - -@pytest.mark.asyncio -async def test_mapping_manager_empty_queue(setup_worker_db, standalone_worker_context): - result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # No new jobs should have been created if nothing is in the queue, and the queue should remain empty. 
- assert result["enqueued_job"] is None - assert result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - - -@pytest.mark.asyncio -async def test_mapping_manager_empty_queue_error_during_setup(setup_worker_db, standalone_worker_context): - await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "") - with patch.object(arq.ArqRedis, "rpop", Exception()): - result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # No new jobs should have been created if nothing is in the queue, and the queue should remain empty. - assert result["enqueued_job"] is None - assert not result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - - -@pytest.mark.asyncio -async def test_mapping_manager_occupied_queue_mapping_in_progress( - setup_worker_db, standalone_worker_context, session, async_client, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "5") - with patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.in_progress): - result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # Execution should be deferred if a job is in progress, and the queue should contain one entry which is the deferred ID. - assert result["enqueued_job"] is not None - assert ( - await arq.jobs.Job(result["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.deferred - assert result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 1 - assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set.id) - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "5" - assert score_set.mapping_state == MappingState.queued - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -async def test_mapping_manager_occupied_queue_mapping_not_in_progress( - setup_worker_db, standalone_worker_context, session, async_client, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "") - with patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.not_found): - result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # Mapping job should be queued if none is currently running, and the queue should now be empty. - assert result["enqueued_job"] is not None - assert ( - await arq.jobs.Job(result["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.queued - assert result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - # We don't actually start processing these score sets. 
-    assert score_set.mapping_state == MappingState.queued
-    assert score_set.mapping_errors is None
-
-
-@pytest.mark.asyncio
-async def test_mapping_manager_occupied_queue_mapping_in_progress_error_during_enqueue(
-    setup_worker_db, standalone_worker_context, session, async_client, data_files
-):
-    score_set = await setup_records_files_and_variants(
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "5")
-    with (
-        patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.in_progress),
-        patch.object(arq.ArqRedis, "enqueue_job", return_value=awaitable_exception()),
-    ):
-        result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1)
-
-    # The enqueue failed, so no job was enqueued, the queued item is removed, and the score set is marked failed.
-    assert result["enqueued_job"] is None
-    assert not result["success"]
-    assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0
-    assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "5"
-    assert score_set.mapping_state == MappingState.failed
-    assert score_set.mapping_errors is not None
-
-
-@pytest.mark.asyncio
-async def test_mapping_manager_occupied_queue_mapping_not_in_progress_error_during_enqueue(
-    setup_worker_db, standalone_worker_context, session, async_client, data_files
-):
-    score_set = await setup_records_files_and_variants(
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "")
-    with (
-        patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.not_found),
-        patch.object(arq.ArqRedis, "enqueue_job", return_value=awaitable_exception()),
-    ):
-        result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1)
-
-    # The enqueue would have failed, so the job is unsuccessful and we remove the queued item.
-    assert result["enqueued_job"] is None
-    assert not result["success"]
-    assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0
-    assert score_set.mapping_state == MappingState.failed
-    assert score_set.mapping_errors is not None
-
-
-@pytest.mark.asyncio
-async def test_mapping_manager_multiple_score_sets_occupy_queue_mapping_in_progress(
-    setup_worker_db, standalone_worker_context, session, async_client, data_files
-):
-    score_set_id_1 = (
-        await setup_records_files_and_variants(
-            session,
-            async_client,
-            data_files,
-            TEST_MINIMAL_SEQ_SCORESET,
-            standalone_worker_context,
-        )
-    ).id
-    score_set_id_2 = (
-        await setup_records_files_and_variants(
-            session,
-            async_client,
-            data_files,
-            TEST_MINIMAL_SEQ_SCORESET,
-            standalone_worker_context,
-        )
-    ).id
-    score_set_id_3 = (
-        await setup_records_files_and_variants(
-            session,
-            async_client,
-            data_files,
-            TEST_MINIMAL_SEQ_SCORESET,
-            standalone_worker_context,
-        )
-    ).id
-
-    await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "5")
-    with patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.in_progress):
-        result1 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1)
-        result2 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1)
-        result3 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1)
-
-    # All three jobs should complete successfully...
- assert result1["success"] - assert result2["success"] - assert result3["success"] - - # ...with a new job enqueued... - assert result1["enqueued_job"] is not None - assert result2["enqueued_job"] is not None - assert result3["enqueued_job"] is not None - - # ...of which all should be deferred jobs of the "variant_mapper_manager" variety... - assert ( - await arq.jobs.Job(result1["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.deferred - assert ( - await arq.jobs.Job(result2["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.deferred - assert ( - await arq.jobs.Job(result3["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.deferred - - assert ( - await arq.jobs.Job(result1["enqueued_job"], standalone_worker_context["redis"]).info() - ).function == "variant_mapper_manager" - assert ( - await arq.jobs.Job(result2["enqueued_job"], standalone_worker_context["redis"]).info() - ).function == "variant_mapper_manager" - assert ( - await arq.jobs.Job(result3["enqueued_job"], standalone_worker_context["redis"]).info() - ).function == "variant_mapper_manager" - - # ...and the queue state should have three jobs, each of our three created score sets. - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 3 - assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set_id_1) - assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set_id_2) - assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set_id_3) - - score_set1 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_1)).one() - score_set2 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_2)).one() - score_set3 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_3)).one() - # Each score set should remain queued with no mapping errors. 
- assert score_set1.mapping_state == MappingState.queued - assert score_set2.mapping_state == MappingState.queued - assert score_set3.mapping_state == MappingState.queued - assert score_set1.mapping_errors is None - assert score_set2.mapping_errors is None - assert score_set3.mapping_errors is None - - -@pytest.mark.asyncio -async def test_mapping_manager_multiple_score_sets_occupy_queue_mapping_not_in_progress( - setup_worker_db, standalone_worker_context, session, async_client, data_files -): - score_set_id_1 = ( - await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - ).id - score_set_id_2 = ( - await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - ).id - score_set_id_3 = ( - await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - ).id - - await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "") - with patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.not_found): - result1 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # Mock the first job being in-progress - await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, str(score_set_id_1)) - with patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.in_progress): - result2 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - result3 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # All three jobs should complete successfully... - assert result1["success"] - assert result2["success"] - assert result3["success"] - - # ...with a new job enqueued... - assert result1["enqueued_job"] is not None - assert result2["enqueued_job"] is not None - assert result3["enqueued_job"] is not None - - # ...of which the first should be a queued job of the "map_variants_for_score_set" variety and the other two should be - # deferred jobs of the "variant_mapper_manager" variety... - assert ( - await arq.jobs.Job(result1["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.queued - assert ( - await arq.jobs.Job(result2["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.deferred - assert ( - await arq.jobs.Job(result3["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.deferred - - assert ( - await arq.jobs.Job(result1["enqueued_job"], standalone_worker_context["redis"]).info() - ).function == "map_variants_for_score_set" - assert ( - await arq.jobs.Job(result2["enqueued_job"], standalone_worker_context["redis"]).info() - ).function == "variant_mapper_manager" - assert ( - await arq.jobs.Job(result3["enqueued_job"], standalone_worker_context["redis"]).info() - ).function == "variant_mapper_manager" - - # ...and the queue state should have two jobs, neither of which should be the first score set. 
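-    # (The first score set's ID was popped from the queue when its mapping job was enqueued; IDs 2 and 3 remain.)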
-    assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 2
-    assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set_id_2)
-    assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set_id_3)
-
-    score_set1 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_1)).one()
-    score_set2 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_2)).one()
-    score_set3 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_3)).one()
-    # We don't actually process any score sets in the manager job, and each should have no mapping errors.
-    assert score_set1.mapping_state == MappingState.queued
-    assert score_set2.mapping_state == MappingState.queued
-    assert score_set3.mapping_state == MappingState.queued
-    assert score_set1.mapping_errors is None
-    assert score_set2.mapping_errors is None
-    assert score_set3.mapping_errors is None
-
-
-@pytest.mark.asyncio
-async def test_mapping_manager_enqueues_mapping_process_with_successful_mapping(
-    setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis
-):
-    score_set = await setup_records_files_and_variants(
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    async def dummy_mapping_job():
-        return await setup_mapping_output(async_client, session, score_set)
-
-    async def dummy_ldh_submission_job():
-        return [TEST_CLINGEN_SUBMISSION_RESPONSE, None]
-
-    async def dummy_linking_job():
-        return [
-            (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE)
-            for variant_urn in session.scalars(
-                select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)
-            ).all()
-        ]
-
-    # We seem unable to mock requests via requests_mock that occur inside another event loop. Work around
-    # this limitation by instead patching the _UnixSelectorEventLoop's executor function with a coroutine
-    # object that sets up test mapping output.
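-    # Each coroutine in the side_effect list below is consumed by exactly one run_in_executor call, in order;
-    # a plain return_value coroutine could only be awaited once.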
-    with (
-        patch.object(
-            _UnixSelectorEventLoop,
-            "run_in_executor",
-            side_effect=[dummy_mapping_job(), dummy_ldh_submission_job(), dummy_linking_job()],
-        ),
-        patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]),
-        patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"),
-        patch.object(UniProtIDMappingAPI, "submit_id_mapping", return_value=TEST_UNIPROT_JOB_SUBMISSION_RESPONSE),
-        patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=True),
-        patch.object(
-            UniProtIDMappingAPI, "get_id_mapping_results", return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE
-        ),
-        patch("mavedb.worker.jobs.MAPPING_BACKOFF_IN_SECONDS", 0),
-        patch("mavedb.worker.jobs.LINKING_BACKOFF_IN_SECONDS", 0),
-        patch("mavedb.worker.jobs.UNIPROT_ID_MAPPING_ENABLED", True),
-        patch("mavedb.worker.jobs.CLIN_GEN_SUBMISSION_ENABLED", True),
-        patch("mavedb.worker.jobs.CAR_SUBMISSION_ENDPOINT", "https://reg.test.genome.network/pytest"),
-        patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_NAME", "testuser"),
-        patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_PASSWORD", "testpassword"),
-        patch("mavedb.lib.gnomad.GNOMAD_DATA_VERSION", TEST_GNOMAD_DATA_VERSION),
-        patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]),
-    ):
-        await arq_worker.async_run()
-        num_completed_jobs = await arq_worker.run_check()
-
-    # We should have completed all jobs exactly once.
-    assert num_completed_jobs == 8
-
-    score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one()
-    mapped_variants_for_score_set = session.scalars(
-        select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn)
-    ).all()
-    assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0
-    assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == ""
-    assert len(mapped_variants_for_score_set) == score_set.num_variants
-    assert score_set.mapping_state == MappingState.complete
-    assert score_set.mapping_errors is None
-
-
-@pytest.mark.asyncio
-async def test_mapping_manager_enqueues_mapping_process_with_successful_mapping_linking_disabled_uniprot_disabled(
-    setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis
-):
-    score_set = await setup_records_files_and_variants(
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    async def dummy_mapping_job():
-        return await setup_mapping_output(async_client, session, score_set)
-
-    # We seem unable to mock requests via requests_mock that occur inside another event loop. Work around
-    # this limitation by instead patching the _UnixSelectorEventLoop's executor function with a coroutine
-    # object that sets up test mapping output.
-    with (
-        patch.object(
-            _UnixSelectorEventLoop,
-            "run_in_executor",
-            side_effect=[dummy_mapping_job()],
-        ),
-        patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"),
-        patch("mavedb.worker.jobs.MAPPING_BACKOFF_IN_SECONDS", 0),
-        patch("mavedb.worker.jobs.LINKING_BACKOFF_IN_SECONDS", 0),
-        patch("mavedb.worker.jobs.UNIPROT_ID_MAPPING_ENABLED", False),
-        patch("mavedb.worker.jobs.CLIN_GEN_SUBMISSION_ENABLED", False),
-    ):
-        await arq_worker.async_run()
-        num_completed_jobs = await arq_worker.run_check()
-
-    # We should have completed the manager and mapping jobs, but not the submission, linking, or uniprot mapping jobs.
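-    # (variant_mapper_manager + map_variants_for_score_set = 2 completed jobs.)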
-    assert num_completed_jobs == 2
-
-    score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one()
-    mapped_variants_for_score_set = session.scalars(
-        select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn)
-    ).all()
-    assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0
-    assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == ""
-    assert len(mapped_variants_for_score_set) == score_set.num_variants
-    assert score_set.mapping_state == MappingState.complete
-    assert score_set.mapping_errors is None
-
-
-@pytest.mark.asyncio
-async def test_mapping_manager_enqueues_mapping_process_with_successful_mapping_linking_disabled_uniprot_enabled(
-    setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis
-):
-    score_set = await setup_records_files_and_variants(
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    async def dummy_mapping_job():
-        return await setup_mapping_output(async_client, session, score_set)
-
-    # We seem unable to mock requests via requests_mock that occur inside another event loop. Work around
-    # this limitation by instead patching the _UnixSelectorEventLoop's executor function with a coroutine
-    # object that sets up test mapping output.
-    with (
-        patch.object(
-            _UnixSelectorEventLoop,
-            "run_in_executor",
-            side_effect=[dummy_mapping_job()],
-        ),
-        patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"),
-        patch.object(UniProtIDMappingAPI, "submit_id_mapping", return_value=TEST_UNIPROT_JOB_SUBMISSION_RESPONSE),
-        patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=True),
-        patch.object(
-            UniProtIDMappingAPI, "get_id_mapping_results", return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE
-        ),
-        patch("mavedb.worker.jobs.MAPPING_BACKOFF_IN_SECONDS", 0),
-        patch("mavedb.worker.jobs.LINKING_BACKOFF_IN_SECONDS", 0),
-        patch("mavedb.worker.jobs.UNIPROT_ID_MAPPING_ENABLED", True),
-        patch("mavedb.worker.jobs.CLIN_GEN_SUBMISSION_ENABLED", False),
-    ):
-        await arq_worker.async_run()
-        num_completed_jobs = await arq_worker.run_check()
-
-    # We should have completed the manager, mapping, and uniprot jobs, but not the submission or linking jobs.
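-    # (Manager, mapping, plus the UniProt submission and polling jobs = 4 completed jobs.)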
-    assert num_completed_jobs == 4
-
-    score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one()
-    mapped_variants_for_score_set = session.scalars(
-        select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn)
-    ).all()
-    assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0
-    assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == ""
-    assert len(mapped_variants_for_score_set) == score_set.num_variants
-    assert score_set.mapping_state == MappingState.complete
-    assert score_set.mapping_errors is None
-
-
-@pytest.mark.asyncio
-async def test_mapping_manager_enqueues_mapping_process_with_successful_mapping_linking_enabled_uniprot_disabled(
-    setup_worker_db,
-    standalone_worker_context,
-    session,
-    async_client,
-    data_files,
-    arq_worker,
-    arq_redis,
-    mocked_gnomad_variant_row,
-):
-    score_set = await setup_records_files_and_variants(
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    async def dummy_mapping_job():
-        return await setup_mapping_output(async_client, session, score_set)
-
-    async def dummy_submission_job():
-        return [TEST_CLINGEN_SUBMISSION_RESPONSE, None]
-
-    async def dummy_linking_job():
-        return [
-            (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE)
-            for variant_urn in session.scalars(
-                select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)
-            ).all()
-        ]
-
-    # We seem unable to mock requests via requests_mock that occur inside another event loop. Work around
-    # this limitation by instead patching the _UnixSelectorEventLoop's executor function with a coroutine
-    # object that sets up test mapping output.
-    with (
-        patch.object(
-            _UnixSelectorEventLoop,
-            "run_in_executor",
-            side_effect=[dummy_mapping_job(), dummy_submission_job(), dummy_linking_job()],
-        ),
-        patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"),
-        patch("mavedb.worker.jobs.MAPPING_BACKOFF_IN_SECONDS", 0),
-        patch("mavedb.worker.jobs.LINKING_BACKOFF_IN_SECONDS", 0),
-        patch("mavedb.worker.jobs.UNIPROT_ID_MAPPING_ENABLED", False),
-        patch("mavedb.worker.jobs.CLIN_GEN_SUBMISSION_ENABLED", True),
-        patch("mavedb.worker.jobs.CAR_SUBMISSION_ENDPOINT", "https://reg.test.genome.network/pytest"),
-        patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_NAME", "testuser"),
-        patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_PASSWORD", "testpassword"),
-        patch("mavedb.lib.gnomad.GNOMAD_DATA_VERSION", TEST_GNOMAD_DATA_VERSION),
-        patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]),
-        patch("mavedb.worker.jobs.gnomad_variant_data_for_caids", return_value=[mocked_gnomad_variant_row]),
-    ):
-        await arq_worker.async_run()
-        num_completed_jobs = await arq_worker.run_check()
-
-    # We should have completed the manager, mapping, submission, and linking jobs, but not the uniprot jobs.
-    assert num_completed_jobs == 6
-
-    score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one()
-    mapped_variants_for_score_set = session.scalars(
-        select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn)
-    ).all()
-    assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0
-    assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == ""
-    assert len(mapped_variants_for_score_set) == score_set.num_variants
-    assert score_set.mapping_state == MappingState.complete
-    assert score_set.mapping_errors is None
-
-
-@pytest.mark.asyncio
-async def test_mapping_manager_enqueues_mapping_process_with_retried_mapping_successful_mapping_on_retry(
-    setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis
-):
-    score_set = await setup_records_files_and_variants(
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    async def failed_mapping_job():
-        return Exception()
-
-    async def dummy_mapping_job():
-        return await setup_mapping_output(async_client, session, score_set)
-
-    async def dummy_ldh_submission_job():
-        return [TEST_CLINGEN_SUBMISSION_RESPONSE, None]
-
-    async def dummy_linking_job():
-        return [
-            (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE)
-            for variant_urn in session.scalars(
-                select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)
-            ).all()
-        ]
-
-    # We seem unable to mock requests via requests_mock that occur inside another event loop. Work around
-    # this limitation by instead patching the _UnixSelectorEventLoop's executor function with a coroutine
-    # object that sets up test mapping output.
-    with (
-        patch.object(
-            _UnixSelectorEventLoop,
-            "run_in_executor",
-            side_effect=[failed_mapping_job(), dummy_mapping_job(), dummy_ldh_submission_job(), dummy_linking_job()],
-        ),
-        patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"),
-        patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]),
-        patch("mavedb.worker.jobs.MAPPING_BACKOFF_IN_SECONDS", 0),
-        patch("mavedb.worker.jobs.LINKING_BACKOFF_IN_SECONDS", 0),
-        patch("mavedb.worker.jobs.UNIPROT_ID_MAPPING_ENABLED", False),
-        patch("mavedb.worker.jobs.CLIN_GEN_SUBMISSION_ENABLED", True),
-        patch("mavedb.worker.jobs.CAR_SUBMISSION_ENDPOINT", "https://reg.test.genome.network/pytest"),
-        patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_NAME", "testuser"),
-        patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_PASSWORD", "testpassword"),
-        patch("mavedb.lib.gnomad.GNOMAD_DATA_VERSION", TEST_GNOMAD_DATA_VERSION),
-        patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]),
-    ):
-        await arq_worker.async_run()
-        num_completed_jobs = await arq_worker.run_check()
-
-    # We should have completed the mapping manager job twice, the mapping job twice, the two submission jobs, and both linking jobs.
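-    # (2 manager + 2 mapping + 2 submission + 2 linking = 8 completed jobs.)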
-    assert num_completed_jobs == 8
-
-    score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one()
-    mapped_variants_for_score_set = session.scalars(
-        select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn)
-    ).all()
-    assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0
-    assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == ""
-    assert len(mapped_variants_for_score_set) == score_set.num_variants
-    assert score_set.mapping_state == MappingState.complete
-    assert score_set.mapping_errors is None
-
-
-@pytest.mark.asyncio
-async def test_mapping_manager_enqueues_mapping_process_with_unsuccessful_mapping(
-    setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis
-):
-    score_set = await setup_records_files_and_variants(
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    async def failed_mapping_job():
-        return Exception()
-
-    # We seem unable to mock requests via requests_mock that occur inside another event loop. Work around
-    # this limitation by instead patching the _UnixSelectorEventLoop's executor function with a coroutine
-    # object that sets up test mapping output.
-    with (
-        patch.object(
-            _UnixSelectorEventLoop,
-            "run_in_executor",
-            side_effect=[failed_mapping_job()] * 5,
-        ),
-        patch("mavedb.worker.jobs.MAPPING_BACKOFF_IN_SECONDS", 0),
-    ):
-        await arq_worker.async_run()
-        num_completed_jobs = await arq_worker.run_check()
-
-    # We should have completed 6 mapping jobs and 6 management jobs.
-    assert num_completed_jobs == 12
-
-    score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one()
-    mapped_variants_for_score_set = session.scalars(
-        select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn)
-    ).all()
-    assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0
-    assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == ""
-    assert len(mapped_variants_for_score_set) == 0
-    assert score_set.mapping_state == MappingState.failed
-    assert score_set.mapping_errors is not None
-
-
-############################################################################################################################################
-# ClinGen CAR Submission
-############################################################################################################################################
-
-
-@pytest.mark.asyncio
-async def test_submit_score_set_mappings_to_car_success(
-    setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis
-):
-    score_set = await setup_records_files_and_variants_with_mapping(
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    with (
-        patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]),
-        patch("mavedb.worker.jobs.CAR_SUBMISSION_ENDPOINT", "https://reg.test.genome.network/pytest"),
-    ):
-        result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id)
-
-    mapped_variants_with_caid_for_score_set = session.scalars(
-        select(MappedVariant)
-        .join(Variant)
-        .join(ScoreSetDbModel)
-        .filter(ScoreSetDbModel.urn == score_set.urn, MappedVariant.clingen_allele_id.is_not(None))
-    ).all()
-
-    assert len(mapped_variants_with_caid_for_score_set) == score_set.num_variants
-
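-    # On success, the CAR job reports success and enqueues the follow-on LDH submission job.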
assert result["success"] - assert not result["retried"] - assert result["enqueued_job"] is not None - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_car_exception_in_setup( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch( - "mavedb.worker.jobs.setup_job_state", - side_effect=Exception(), - ): - result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_car_no_variants_exist( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_car_exception_in_hgvs_dict_creation( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch( - "mavedb.worker.jobs.get_hgvs_from_post_mapped", - side_effect=Exception(), - ): - result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_car_exception_during_submission( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", side_effect=Exception()), - patch("mavedb.worker.jobs.CAR_SUBMISSION_ENDPOINT", "https://reg.test.genome.network/pytest"), - ): - result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_car_exception_in_allele_association( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch("mavedb.worker.jobs.get_allele_registry_associations", side_effect=Exception()), - patch("mavedb.worker.jobs.CAR_SUBMISSION_ENDPOINT", "https://reg.test.genome.network/pytest"), - ): - result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not 
result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_car_exception_during_ldh_enqueue( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch("mavedb.worker.jobs.CAR_SUBMISSION_ENDPOINT", "https://reg.test.genome.network/pytest"), - patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]), - patch.object(arq.ArqRedis, "enqueue_job", side_effect=Exception()), - ): - result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) - - mapped_variants_with_caid_for_score_set = session.scalars( - select(MappedVariant) - .join(Variant) - .join(ScoreSetDbModel) - .filter(ScoreSetDbModel.urn == score_set.urn, MappedVariant.clingen_allele_id.is_not(None)) - ).all() - - assert len(mapped_variants_with_caid_for_score_set) == score_set.num_variants - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -############################################################################################################################################ -# ClinGen LDH Submission -############################################################################################################################################ - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_success( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_submission_job(): - return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. 
- with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_submission_job(), - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert result["success"] - assert not result["retried"] - assert result["enqueued_job"] is not None - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_exception_in_setup( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch( - "mavedb.worker.jobs.setup_job_state", - side_effect=Exception(), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_exception_in_auth( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch.object( - ClinGenLdhService, - "_existing_jwt", - side_effect=Exception(), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_no_variants_exist( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_exception_in_hgvs_generation( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch( - "mavedb.lib.variants.get_hgvs_from_post_mapped", - side_effect=Exception(), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_exception_in_ldh_submission_construction( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch( - 
"mavedb.lib.clingen.content_constructors.construct_ldh_submission", - side_effect=Exception(), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_exception_during_submission( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def failed_submission_job(): - return Exception() - - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - side_effect=failed_submission_job(), - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "error_response", [TEST_CLINGEN_SUBMISSION_BAD_RESQUEST_RESPONSE, TEST_CLINGEN_SUBMISSION_UNAUTHORIZED_RESPONSE] -) -async def test_submit_score_set_mappings_to_ldh_submission_failures_exist( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis, error_response -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_submission_job(): - return [None, error_response] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_submission_job(), - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_exception_during_linking_enqueue( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_submission_job(): - return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. 
- with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_submission_job(), - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - patch.object(arq.ArqRedis, "enqueue_job", side_effect=Exception()), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_linking_not_queued_when_expected( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_submission_job(): - return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_submission_job(), - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - patch.object(arq.ArqRedis, "enqueue_job", return_value=None), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -############################################################################################################################################## -## ClinGen Linkage -############################################################################################################################################## - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_success( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_linking_job(): - return [ - (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. 
- with patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_linking_job(), - ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert result["success"] - assert not result["retried"] - assert result["enqueued_job"] - - for variant in session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ): - assert variant.clingen_allele_id == clingen_allele_id_from_ldh_variation(TEST_CLINGEN_LDH_LINKING_RESPONSE) - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_exception_in_setup( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch( - "mavedb.worker.jobs.setup_job_state", - side_effect=Exception(), - ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - for variant in session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ): - assert variant.clingen_allele_id is None - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_no_variants_to_link( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_exception_during_linkage( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - side_effect=Exception(), - ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_exception_while_parsing_linkages( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_linking_job(): - return [ - (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We are unable to mock requests via requests_mock that occur inside another event loop. 
Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_linking_job(), - ), - patch( - "mavedb.worker.jobs.clingen_allele_id_from_ldh_variation", - side_effect=Exception(), - ), - ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_failures_exist_but_do_not_eclipse_retry_threshold( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_linking_job(): - return [ - (variant_urn, None) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_linking_job(), - ), - patch( - "mavedb.worker.jobs.LINKED_DATA_RETRY_THRESHOLD", - 2, - ), - ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert result["success"] - assert not result["retried"] - assert result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_failures_exist_and_eclipse_retry_threshold( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_linking_job(): - return [ - (variant_urn, None) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. 
- with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_linking_job(), - ), - patch( - "mavedb.worker.jobs.LINKED_DATA_RETRY_THRESHOLD", - 1, - ), - patch( - "mavedb.worker.jobs.LINKING_BACKOFF_IN_SECONDS", - 0, - ), - ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert not result["success"] - assert result["retried"] - assert result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_failures_exist_and_eclipse_retry_threshold_cant_enqueue( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_linking_job(): - return [ - (variant_urn, None) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_linking_job(), - ), - patch( - "mavedb.worker.jobs.LINKED_DATA_RETRY_THRESHOLD", - 1, - ), - patch.object(arq.ArqRedis, "enqueue_job", return_value=awaitable_exception()), - ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_failures_exist_and_eclipse_retry_threshold_retries_exceeded( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_linking_job(): - return [ - (variant_urn, None) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. 
- with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_linking_job(), - ), - patch( - "mavedb.worker.jobs.LINKED_DATA_RETRY_THRESHOLD", - 1, - ), - patch( - "mavedb.worker.jobs.LINKING_BACKOFF_IN_SECONDS", - 0, - ), - patch( - "mavedb.worker.jobs.BACKOFF_LIMIT", - 1, - ), - ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 2) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_error_in_gnomad_job_enqueue( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_linking_job(): - return [ - (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_linking_job(), - ), - patch.object(arq.ArqRedis, "enqueue_job", return_value=awaitable_exception()), - ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -################################################################################################################################################## -# UniProt ID mapping -################################################################################################################################################## - -### Test Submission - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_success( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch.object(UniProtIDMappingAPI, "submit_id_mapping", return_value=TEST_UNIPROT_JOB_SUBMISSION_RESPONSE): - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - - assert result["success"] - assert not result["retried"] - assert result["enqueued_jobs"] is not None - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_no_targets( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - score_set.target_genes = [] - session.add(score_set) - session.commit() - - with patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message: - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called_once() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - 
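The UniProt submission tests above all follow the same shape: the external client call is replaced with a unittest.mock patch, and the job is judged purely on the status dictionary it returns. A minimal, self-contained sketch of that pattern follows; FakeMappingAPI and submit_job are hypothetical stand-ins, and only the result-dict keys (success, retried, enqueued_jobs) mirror the jobs under test.

```python
# Hedged sketch of the patch-and-assert pattern used by these submission tests.
import asyncio
from unittest.mock import patch


class FakeMappingAPI:
    def submit_id_mapping(self, ids):
        raise RuntimeError("would hit the network")  # never reached once patched


async def submit_job(api):
    # Stand-in worker job reporting the same status dictionary shape the real jobs return.
    try:
        job_id = api.submit_id_mapping(["NP_000537.3"])
    except Exception:
        return {"success": False, "retried": False, "enqueued_jobs": []}
    return {"success": True, "retried": False, "enqueued_jobs": [job_id]}


async def main():
    # Patch the network call on the class, as the tests patch UniProtIDMappingAPI.submit_id_mapping.
    with patch.object(FakeMappingAPI, "submit_id_mapping", return_value="job_1"):
        result = await submit_job(FakeMappingAPI())
    assert result == {"success": True, "retried": False, "enqueued_jobs": ["job_1"]}


asyncio.run(main())
```

Because the job result carries its own success and retry flags, the tests never need to inspect the mocked client beyond the patch itself.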
-@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_exception_while_spawning_jobs( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(UniProtIDMappingAPI, "submit_id_mapping", side_effect=HTTPError()), - patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message, - ): - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_too_many_accessions( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch("mavedb.worker.jobs.extract_ids_from_post_mapped_metadata", return_value=["AC1", "AC2"]), - patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message, - ): - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_no_accessions( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message: - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_error_in_setup( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch("mavedb.worker.jobs.setup_job_state", side_effect=Exception()), - patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message, - ): - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called() - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_exception_during_submission_generation( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - 
patch("mavedb.worker.jobs.extract_ids_from_post_mapped_metadata", side_effect=Exception()), - patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message, - ): - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called() - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_no_spawned_jobs( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(UniProtIDMappingAPI, "submit_id_mapping", return_value=None), - patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message, - ): - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_exception_during_enqueue( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(UniProtIDMappingAPI, "submit_id_mapping", return_value=TEST_UNIPROT_JOB_SUBMISSION_RESPONSE), - patch.object(arq.ArqRedis, "enqueue_job", side_effect=Exception()), - patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message, - ): - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called() - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -### Test Polling - - -@pytest.mark.asyncio -async def test_poll_uniprot_id_mapping_success( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=True), - patch.object( - UniProtIDMappingAPI, "get_id_mapping_results", return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE - ), - ): - result = await poll_uniprot_mapping_jobs_for_score_set( - standalone_worker_context, - {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, - score_set.id, - uuid4().hex, - ) - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - for target_gene in score_set.target_genes: - assert target_gene.uniprot_id_from_mapped_metadata == VALID_UNIPROT_ACCESSION - - -@pytest.mark.asyncio -async def test_poll_uniprot_id_mapping_no_targets( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - 
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    score_set.target_genes = []
-    session.add(score_set)
-    session.commit()
-
-    with patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message:
-        result = await poll_uniprot_mapping_jobs_for_score_set(
-            standalone_worker_context,
-            {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)},
-            score_set.id,
-            uuid4().hex,
-        )
-        mock_slack_message.assert_called_once()
-
-    assert result["success"]
-    assert not result["retried"]
-    assert not result["enqueued_jobs"]
-
-    score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one()
-    for target_gene in score_set.target_genes:
-        assert target_gene.uniprot_id_from_mapped_metadata is None
-
-
-@pytest.mark.asyncio
-async def test_poll_uniprot_id_mapping_too_many_accessions(
-    setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis
-):
-    score_set = await setup_records_files_and_variants_with_mapping(
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    with (
-        patch("mavedb.worker.jobs.extract_ids_from_post_mapped_metadata", return_value=["AC1", "AC2"]),
-        patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message,
-    ):
-        result = await poll_uniprot_mapping_jobs_for_score_set(
-            standalone_worker_context,
-            {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)},
-            score_set.id,
-            uuid4().hex,
-        )
-        mock_slack_message.assert_called()
-
-    assert result["success"]
-    assert not result["retried"]
-    assert not result["enqueued_jobs"]
-
-
-@pytest.mark.asyncio
-async def test_poll_uniprot_id_mapping_no_accessions(
-    setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis
-):
-    score_set = await setup_records_files_and_variants_with_mapping(
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    with (
-        patch("mavedb.worker.jobs.extract_ids_from_post_mapped_metadata", return_value=[]),
-        patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message,
-    ):
-        result = await poll_uniprot_mapping_jobs_for_score_set(
-            standalone_worker_context,
-            {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)},
-            score_set.id,
-            uuid4().hex,
-        )
-        mock_slack_message.assert_called()
-
-    assert result["success"]
-    assert not result["retried"]
-    assert not result["enqueued_jobs"]
-
-
-@pytest.mark.asyncio
-async def test_poll_uniprot_id_mapping_jobs_not_ready(
-    setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis
-):
-    score_set = await setup_records_files_and_variants_with_mapping(
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    with (
-        patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=False),
-        patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message,
-    ):
-        result = await poll_uniprot_mapping_jobs_for_score_set(
-            standalone_worker_context,
-            {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)},
-            score_set.id,
-            uuid4().hex,
-        )
-        mock_slack_message.assert_called()
-
-    assert result["success"]
-    assert not result["retried"]
-    assert not result["enqueued_jobs"]
-
-    score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one()
-    for target_gene in score_set.target_genes:
-        assert target_gene.uniprot_id_from_mapped_metadata is None
-
-
-@pytest.mark.asyncio
-async def test_poll_uniprot_id_mapping_no_jobs(
-    setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis
-):
-    score_set = await setup_records_files_and_variants_with_mapping(
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    # This case does not get sent to slack
-    result = await poll_uniprot_mapping_jobs_for_score_set(
-        standalone_worker_context,
-        {},
-        score_set.id,
-        uuid4().hex,
-    )
-
-    assert result["success"]
-    assert not result["retried"]
-    assert not result["enqueued_jobs"]
-
-    score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one()
-    for target_gene in score_set.target_genes:
-        assert target_gene.uniprot_id_from_mapped_metadata is None
-
-
-@pytest.mark.asyncio
-async def test_poll_uniprot_id_mapping_no_ids_mapped(
-    setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis
-):
-    score_set = await setup_records_files_and_variants_with_mapping(
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    with (
-        patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=True),
-        patch.object(UniProtIDMappingAPI, "get_id_mapping_results", return_value={"failedIDs": [VALID_CHR_ACCESSION]}),
-        patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message,
-    ):
-        result = await poll_uniprot_mapping_jobs_for_score_set(
-            standalone_worker_context,
-            {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)},
-            score_set.id,
-            uuid4().hex,
-        )
-        mock_slack_message.assert_called()
-
-    assert result["success"]
-    assert not result["retried"]
-    assert not result["enqueued_jobs"]
-
-    score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one()
-    for target_gene in score_set.target_genes:
-        assert target_gene.uniprot_id_from_mapped_metadata is None
-
-
-@pytest.mark.asyncio
-async def test_poll_uniprot_id_mapping_too_many_mapped_accessions(
-    setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis
-):
-    score_set = await setup_records_files_and_variants_with_mapping(
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    # Simulate a response with too many mapped IDs
-    too_many_mapped_ids_response = TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE.copy()
-    too_many_mapped_ids_response["results"].append(
-        {"from": "AC3", "to": {"primaryAccession": "AC3", "entryType": TEST_UNIPROT_SWISS_PROT_TYPE}}
-    )
-
-    with (
-        patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=True),
-        patch.object(UniProtIDMappingAPI, "get_id_mapping_results", return_value=too_many_mapped_ids_response),
-        patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message,
-    ):
-        result = await poll_uniprot_mapping_jobs_for_score_set(
-            standalone_worker_context,
-            {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)},
-            score_set.id,
-            uuid4().hex,
-        )
-        mock_slack_message.assert_called()
-
-    assert result["success"]
-    assert not result["retried"]
-    assert not result["enqueued_jobs"]
-
-
-@pytest.mark.asyncio
-async def test_poll_uniprot_id_mapping_error_in_setup(
-    setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis
-):
-    score_set = await setup_records_files_and_variants_with_mapping(
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    with (
-        patch("mavedb.worker.jobs.setup_job_state", side_effect=Exception()),
-        patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message,
-    ):
-        result = await poll_uniprot_mapping_jobs_for_score_set(
-            standalone_worker_context,
-            {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)},
-            score_set.id,
-            uuid4().hex,
-        )
-        mock_slack_message.assert_called_once()
-
-    assert not result["success"]
-    assert not result["retried"]
-    assert not result["enqueued_jobs"]
-
-
-@pytest.mark.asyncio
-async def test_poll_uniprot_id_mapping_exception_during_polling(
-    setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis
-):
-    score_set = await setup_records_files_and_variants_with_mapping(
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    with (
-        patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", side_effect=Exception()),
-        patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message,
-    ):
-        result = await poll_uniprot_mapping_jobs_for_score_set(
-            standalone_worker_context,
-            {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)},
-            score_set.id,
-            uuid4().hex,
-        )
-        mock_slack_message.assert_called_once()
-
-    assert not result["success"]
-    assert not result["retried"]
-    assert not result["enqueued_jobs"]
-
-
-##################################################################################################################################################
-# gnomAD Linking
-##################################################################################################################################################
-
-
-@pytest.mark.asyncio
-async def test_link_score_set_mappings_to_gnomad_variants_success(
-    setup_worker_db,
-    standalone_worker_context,
-    session,
-    async_client,
-    data_files,
-    arq_worker,
-    arq_redis,
-    mocked_gnomad_variant_row,
-):
-    score_set = await setup_records_files_and_variants_with_mapping(
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    # We need to set the ClinGen Allele ID for the Mapped Variants, so that the gnomAD job can link them.
-    mapped_variants = session.scalars(
-        select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)
-    ).all()
-
-    for mapped_variant in mapped_variants:
-        mapped_variant.clingen_allele_id = VALID_CLINGEN_CA_ID
-    session.commit()
-
-    # Patch Athena connection with mock object which returns a mocked gnomAD variant row w/ CAID=VALID_CLINGEN_CA_ID.
-    with (
-        patch("mavedb.worker.jobs.gnomad_variant_data_for_caids", return_value=[mocked_gnomad_variant_row]),
-        patch("mavedb.lib.gnomad.GNOMAD_DATA_VERSION", TEST_GNOMAD_DATA_VERSION),
-    ):
-        result = await link_gnomad_variants(standalone_worker_context, uuid4().hex, score_set.id)
-
-    assert result["success"]
-    assert not result["retried"]
-    assert not result["enqueued_job"]
-
-    for variant in session.scalars(
-        select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)
-    ):
-        assert variant.gnomad_variants
-
-
-@pytest.mark.asyncio
-async def test_link_score_set_mappings_to_gnomad_variants_exception_in_setup(
-    setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis
-):
-    score_set = await setup_records_files_and_variants_with_mapping(
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    with patch(
-        "mavedb.worker.jobs.setup_job_state",
-        side_effect=Exception(),
-    ):
-        result = await link_gnomad_variants(standalone_worker_context, uuid4().hex, score_set.id)
-
-    assert not result["success"]
-    assert not result["retried"]
-    assert not result["enqueued_job"]
-
-    for variant in session.scalars(
-        select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)
-    ):
-        assert not variant.gnomad_variants
-
-
-@pytest.mark.asyncio
-async def test_link_score_set_mappings_to_gnomad_variants_no_variants_to_link(
-    setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis
-):
-    score_set = await setup_records_files_and_variants(
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    result = await link_gnomad_variants(standalone_worker_context, uuid4().hex, score_set.id)
-
-    assert result["success"]
-    assert not result["retried"]
-    assert not result["enqueued_job"]
-
-    for variant in session.scalars(
-        select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)
-    ):
-        assert not variant.gnomad_variants
-
-
-@pytest.mark.asyncio
-async def test_link_score_set_mappings_to_gnomad_variants_exception_while_fetching_variant_data(
-    setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis
-):
-    score_set = await setup_records_files_and_variants_with_mapping(
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    with (
-        patch(
-            "mavedb.worker.jobs.setup_job_state",
-            side_effect=Exception(),
-        ),
-        patch("mavedb.worker.jobs.gnomad_variant_data_for_caids", side_effect=Exception()),
-    ):
-        result = await link_gnomad_variants(standalone_worker_context, uuid4().hex, score_set.id)
-
-    assert not result["success"]
-    assert not result["retried"]
-    assert not result["enqueued_job"]
-
-    for variant in session.scalars(
-        select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)
-    ):
-        assert not variant.gnomad_variants
-
-
-@pytest.mark.asyncio
-async def test_link_score_set_mappings_to_gnomad_variants_exception_while_linking_variants(
-    setup_worker_db,
-    standalone_worker_context,
-    session,
-    async_client,
-    data_files,
-    arq_worker,
-    arq_redis,
-    mocked_gnomad_variant_row,
-):
-    score_set = await setup_records_files_and_variants_with_mapping(
-        session,
-        async_client,
-        data_files,
-        TEST_MINIMAL_SEQ_SCORESET,
-        standalone_worker_context,
-    )
-
-    # We need to set the ClinGen Allele ID for the Mapped Variants, so that the gnomAD job can link them.
-    mapped_variants = session.scalars(
-        select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)
-    ).all()
-
-    for mapped_variant in mapped_variants:
-        mapped_variant.clingen_allele_id = VALID_CLINGEN_CA_ID
-    session.commit()
-
-    with (
-        patch("mavedb.worker.jobs.gnomad_variant_data_for_caids", return_value=[mocked_gnomad_variant_row]),
-        patch("mavedb.worker.jobs.link_gnomad_variants_to_mapped_variants", side_effect=Exception()),
-    ):
-        result = await link_gnomad_variants(standalone_worker_context, uuid4().hex, score_set.id)
-
-    assert not result["success"]
-    assert not result["retried"]
-    assert not result["enqueued_job"]
-
-    for variant in session.scalars(
-        select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)
-    ):
-        assert not variant.gnomad_variants