From a94c2fb8c9f06cc407f74d2d467f486efb6bdc22 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Tue, 6 Jan 2026 17:02:45 -0800 Subject: [PATCH 01/70] refactor(worker): restructure monolithic jobs.py into modular architecture Break down 1767-line jobs.py into domain-driven modules, improving maintainability and developer experience. - variant_processing/: Variant creation and VRS mapping - external_services/: ClinGen, UniProt, gnomAD integrations - data_management/: Database and view operations - utils/: Shared utilities (state, retry, constants) - registry.py: Centralized ARQ job configuration - constants.py: Environment configuration - redis.py: Redis connection settings - lifecycle.py: Worker lifecycle hooks - worker.py: Main ArqWorkerSettings class - All job functions maintain identical behavior - Registry provides BACKGROUND_FUNCTIONS/BACKGROUND_CRONJOBS lists for ARQ initialization - Test structure mirrors source organization This refactor ensures ARQ worker initialization is backwards compatible. The modular architecture establishes a more maintainable foundation for MaveDB's automated processing workflows while preserving all existing functionality. --- src/mavedb/worker/jobs.py | 1766 --------- src/mavedb/worker/jobs/__init__.py | 56 + .../worker/jobs/data_management/__init__.py | 16 + .../worker/jobs/data_management/views.py | 34 + .../worker/jobs/external_services/__init__.py | 28 + .../worker/jobs/external_services/clingen.py | 637 +++ .../worker/jobs/external_services/gnomad.py | 140 + .../worker/jobs/external_services/uniprot.py | 230 ++ src/mavedb/worker/jobs/py.typed | 0 src/mavedb/worker/jobs/registry.py | 63 + src/mavedb/worker/jobs/utils/__init__.py | 30 + src/mavedb/worker/jobs/utils/constants.py | 17 + src/mavedb/worker/jobs/utils/job_state.py | 35 + src/mavedb/worker/jobs/utils/retry.py | 61 + .../jobs/variant_processing/__init__.py | 19 + .../jobs/variant_processing/creation.py | 196 + .../worker/jobs/variant_processing/mapping.py | 569 +++ src/mavedb/worker/py.typed | 0 src/mavedb/worker/settings.py | 94 - src/mavedb/worker/settings/__init__.py | 19 + src/mavedb/worker/settings/constants.py | 12 + src/mavedb/worker/settings/lifecycle.py | 35 + src/mavedb/worker/settings/redis.py | 12 + src/mavedb/worker/settings/worker.py | 33 + tests/conftest_optional.py | 11 +- tests/helpers/util/mapping.py | 6 + tests/helpers/util/setup/worker.py | 154 + .../jobs/external_services/test_clingen.py | 879 +++++ .../jobs/external_services/test_gnomad.py | 206 + .../jobs/external_services/test_uniprot.py | 603 +++ .../jobs/variant_processing/test_creation.py | 557 +++ .../jobs/variant_processing/test_mapping.py | 710 ++++ tests/worker/test_jobs.py | 3479 ----------------- 33 files changed, 5362 insertions(+), 5345 deletions(-) delete mode 100644 src/mavedb/worker/jobs.py create mode 100644 src/mavedb/worker/jobs/__init__.py create mode 100644 src/mavedb/worker/jobs/data_management/__init__.py create mode 100644 src/mavedb/worker/jobs/data_management/views.py create mode 100644 src/mavedb/worker/jobs/external_services/__init__.py create mode 100644 src/mavedb/worker/jobs/external_services/clingen.py create mode 100644 src/mavedb/worker/jobs/external_services/gnomad.py create mode 100644 src/mavedb/worker/jobs/external_services/uniprot.py create mode 100644 src/mavedb/worker/jobs/py.typed create mode 100644 src/mavedb/worker/jobs/registry.py create mode 100644 src/mavedb/worker/jobs/utils/__init__.py create mode 100644 src/mavedb/worker/jobs/utils/constants.py create mode 100644 src/mavedb/worker/jobs/utils/job_state.py create mode 100644 src/mavedb/worker/jobs/utils/retry.py create mode 100644 src/mavedb/worker/jobs/variant_processing/__init__.py create mode 100644 src/mavedb/worker/jobs/variant_processing/creation.py create mode 100644 src/mavedb/worker/jobs/variant_processing/mapping.py create mode 100644 src/mavedb/worker/py.typed delete mode 100644 src/mavedb/worker/settings.py create mode 100644 src/mavedb/worker/settings/__init__.py create mode 100644 src/mavedb/worker/settings/constants.py create mode 100644 src/mavedb/worker/settings/lifecycle.py create mode 100644 src/mavedb/worker/settings/redis.py create mode 100644 src/mavedb/worker/settings/worker.py create mode 100644 tests/helpers/util/mapping.py create mode 100644 tests/helpers/util/setup/worker.py create mode 100644 tests/worker/jobs/external_services/test_clingen.py create mode 100644 tests/worker/jobs/external_services/test_gnomad.py create mode 100644 tests/worker/jobs/external_services/test_uniprot.py create mode 100644 tests/worker/jobs/variant_processing/test_creation.py create mode 100644 tests/worker/jobs/variant_processing/test_mapping.py delete mode 100644 tests/worker/test_jobs.py diff --git a/src/mavedb/worker/jobs.py b/src/mavedb/worker/jobs.py deleted file mode 100644 index 3a690d97..00000000 --- a/src/mavedb/worker/jobs.py +++ /dev/null @@ -1,1766 +0,0 @@ -import asyncio -import functools -import logging -from contextlib import asynccontextmanager -from datetime import date, timedelta -from typing import Any, Optional, Sequence - -import pandas as pd -from arq import ArqRedis -from arq.jobs import Job, JobStatus -from cdot.hgvs.dataproviders import RESTDataProvider -from sqlalchemy import cast, delete, null, select -from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.orm import Session - -from mavedb.data_providers.services import vrs_mapper -from mavedb.db.view import refresh_all_mat_views -from mavedb.lib.clingen.constants import ( - CAR_SUBMISSION_ENDPOINT, - CLIN_GEN_SUBMISSION_ENABLED, - DEFAULT_LDH_SUBMISSION_BATCH_SIZE, - LDH_SUBMISSION_ENDPOINT, - LINKED_DATA_RETRY_THRESHOLD, -) -from mavedb.lib.clingen.content_constructors import construct_ldh_submission -from mavedb.lib.clingen.services import ( - ClinGenAlleleRegistryService, - ClinGenLdhService, - clingen_allele_id_from_ldh_variation, - get_allele_registry_associations, - get_clingen_variation, -) -from mavedb.lib.exceptions import ( - LinkingEnqueueError, - MappingEnqueueError, - NonexistentMappingReferenceError, - NonexistentMappingResultsError, - SubmissionEnqueueError, - UniProtIDMappingEnqueueError, - UniProtPollingEnqueueError, -) -from mavedb.lib.gnomad import gnomad_variant_data_for_caids, link_gnomad_variants_to_mapped_variants -from mavedb.lib.logging.context import format_raised_exception_info_as_dict -from mavedb.lib.mapping import ANNOTATION_LAYERS, extract_ids_from_post_mapped_metadata -from mavedb.lib.score_sets import ( - columns_for_dataset, - create_variants, - create_variants_data, -) -from mavedb.lib.slack import log_and_send_slack_message, send_slack_error, send_slack_message -from mavedb.lib.uniprot.constants import UNIPROT_ID_MAPPING_ENABLED -from mavedb.lib.uniprot.id_mapping import UniProtIDMappingAPI -from mavedb.lib.uniprot.utils import infer_db_name_from_sequence_accession -from mavedb.lib.validation.dataframe.dataframe import ( - validate_and_standardize_dataframe_pair, -) -from mavedb.lib.validation.exceptions import ValidationError -from mavedb.lib.variants import get_hgvs_from_post_mapped -from mavedb.models.enums.mapping_state import MappingState -from mavedb.models.enums.processing_state import ProcessingState -from mavedb.models.mapped_variant import MappedVariant -from mavedb.models.published_variant import PublishedVariantsMV -from mavedb.models.score_set import ScoreSet -from mavedb.models.user import User -from mavedb.models.variant import Variant -from mavedb.view_models.score_set_dataset_columns import DatasetColumnMetadata - -logger = logging.getLogger(__name__) - -MAPPING_QUEUE_NAME = "vrs_mapping_queue" -MAPPING_CURRENT_ID_NAME = "vrs_mapping_current_job_id" -BACKOFF_LIMIT = 5 -MAPPING_BACKOFF_IN_SECONDS = 15 -LINKING_BACKOFF_IN_SECONDS = 15 * 60 - - -#################################################################################################### -# Job utilities -#################################################################################################### - - -def setup_job_state( - ctx, invoker: Optional[int], resource: Optional[str], correlation_id: Optional[str] -) -> dict[str, Any]: - ctx["state"][ctx["job_id"]] = { - "application": "mavedb-worker", - "user": invoker, - "resource": resource, - "correlation_id": correlation_id, - } - return ctx["state"][ctx["job_id"]] - - -async def enqueue_job_with_backoff( - redis: ArqRedis, job_name: str, attempt: int, backoff: int, *args -) -> tuple[Optional[str], bool, Any]: - new_job_id = None - limit_reached = attempt > BACKOFF_LIMIT - if not limit_reached: - limit_reached = True - backoff = backoff * (2**attempt) - attempt = attempt + 1 - - # NOTE: for jobs supporting backoff, `attempt` should be the final argument. - new_job = await redis.enqueue_job( - job_name, - *args, - attempt, - _defer_by=timedelta(seconds=backoff), - ) - - if new_job: - new_job_id = new_job.job_id - - return (new_job_id, not limit_reached, backoff) - - -#################################################################################################### -# Creating variants -#################################################################################################### - - -async def create_variants_for_score_set( - ctx, - correlation_id: str, - score_set_id: int, - updater_id: int, - scores: pd.DataFrame, - counts: pd.DataFrame, - score_columns_metadata: Optional[dict[str, DatasetColumnMetadata]] = None, - count_columns_metadata: Optional[dict[str, DatasetColumnMetadata]] = None, -): - """ - Create variants for a score set. Intended to be run within a worker. - On any raised exception, ensure ProcessingState of score set is set to `failed` prior - to exiting. - """ - logging_context = {} - try: - db: Session = ctx["db"] - hdp: RESTDataProvider = ctx["hdp"] - redis: ArqRedis = ctx["redis"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() - - logging_context = setup_job_state(ctx, updater_id, score_set.urn, correlation_id) - logger.info(msg="Began processing of score set variants.", extra=logging_context) - - updated_by = db.scalars(select(User).where(User.id == updater_id)).one() - - score_set.modified_by = updated_by - score_set.processing_state = ProcessingState.processing - score_set.mapping_state = MappingState.pending_variant_processing - logging_context["processing_state"] = score_set.processing_state.name - logging_context["mapping_state"] = score_set.mapping_state.name - - db.add(score_set) - db.commit() - db.refresh(score_set) - - if not score_set.target_genes: - logger.warning( - msg="No targets are associated with this score set; could not create variants.", - extra=logging_context, - ) - raise ValueError("Can't create variants when score set has no targets.") - - validated_scores, validated_counts, validated_score_columns_metadata, validated_count_columns_metadata = ( - validate_and_standardize_dataframe_pair( - scores_df=scores, - counts_df=counts, - score_columns_metadata=score_columns_metadata, - count_columns_metadata=count_columns_metadata, - targets=score_set.target_genes, - hdp=hdp, - ) - ) - - score_set.dataset_columns = { - "score_columns": columns_for_dataset(validated_scores), - "count_columns": columns_for_dataset(validated_counts), - "score_columns_metadata": validated_score_columns_metadata - if validated_score_columns_metadata is not None - else {}, - "count_columns_metadata": validated_count_columns_metadata - if validated_count_columns_metadata is not None - else {}, - } - - # Delete variants after validation occurs so we don't overwrite them in the case of a bad update. - if score_set.variants: - existing_variants = db.scalars(select(Variant.id).where(Variant.score_set_id == score_set.id)).all() - db.execute(delete(MappedVariant).where(MappedVariant.variant_id.in_(existing_variants))) - db.execute(delete(Variant).where(Variant.id.in_(existing_variants))) - logging_context["deleted_variants"] = score_set.num_variants - score_set.num_variants = 0 - - logger.info(msg="Deleted existing variants from score set.", extra=logging_context) - - db.flush() - db.refresh(score_set) - - variants_data = create_variants_data(validated_scores, validated_counts, None) - create_variants(db, score_set, variants_data) - - # Validation errors arise from problematic user data. These should be inserted into the database so failures can - # be persisted to them. - except ValidationError as e: - db.rollback() - score_set.processing_state = ProcessingState.failed - score_set.processing_errors = {"exception": str(e), "detail": e.triggering_exceptions} - score_set.mapping_state = MappingState.not_attempted - - if score_set.num_variants: - score_set.processing_errors["exception"] = ( - f"Update failed, variants were not updated. {score_set.processing_errors.get('exception', '')}" - ) - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logging_context["processing_state"] = score_set.processing_state.name - logging_context["mapping_state"] = score_set.mapping_state.name - logging_context["created_variants"] = 0 - logger.warning(msg="Encountered a validation error while processing variants.", extra=logging_context) - - return {"success": False} - - # NOTE: Since these are likely to be internal errors, it makes less sense to add them to the DB and surface them to the end user. - # Catch all non-system exiting exceptions. - except Exception as e: - db.rollback() - score_set.processing_state = ProcessingState.failed - score_set.processing_errors = {"exception": str(e), "detail": []} - score_set.mapping_state = MappingState.not_attempted - - if score_set.num_variants: - score_set.processing_errors["exception"] = ( - f"Update failed, variants were not updated. {score_set.processing_errors.get('exception', '')}" - ) - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logging_context["processing_state"] = score_set.processing_state.name - logging_context["mapping_state"] = score_set.mapping_state.name - logging_context["created_variants"] = 0 - logger.warning(msg="Encountered an internal exception while processing variants.", extra=logging_context) - - send_slack_error(err=e) - return {"success": False} - - # Catch all other exceptions. The exceptions caught here were intented to be system exiting. - except BaseException as e: - db.rollback() - score_set.processing_state = ProcessingState.failed - score_set.mapping_state = MappingState.not_attempted - db.commit() - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logging_context["processing_state"] = score_set.processing_state.name - logging_context["mapping_state"] = score_set.mapping_state.name - logging_context["created_variants"] = 0 - logger.error( - msg="Encountered an unhandled exception while creating variants for score set.", extra=logging_context - ) - - # Don't raise BaseExceptions so we may emit canonical logs (TODO: Perhaps they are so problematic we want to raise them anyway). - return {"success": False} - - else: - score_set.processing_state = ProcessingState.success - score_set.processing_errors = null() - - logging_context["created_variants"] = score_set.num_variants - logging_context["processing_state"] = score_set.processing_state.name - logger.info(msg="Finished creating variants in score set.", extra=logging_context) - - await redis.lpush(MAPPING_QUEUE_NAME, score_set.id) # type: ignore - await redis.enqueue_job("variant_mapper_manager", correlation_id, updater_id) - score_set.mapping_state = MappingState.queued - finally: - db.add(score_set) - db.commit() - db.refresh(score_set) - logger.info(msg="Committed new variants to score set.", extra=logging_context) - - ctx["state"][ctx["job_id"]] = logging_context.copy() - return {"success": True} - - -#################################################################################################### -# Mapping variants -#################################################################################################### - - -@asynccontextmanager -async def mapping_in_execution(redis: ArqRedis, job_id: str): - await redis.set(MAPPING_CURRENT_ID_NAME, job_id) - try: - yield - finally: - await redis.set(MAPPING_CURRENT_ID_NAME, "") - - -async def map_variants_for_score_set( - ctx: dict, correlation_id: str, score_set_id: int, updater_id: int, attempt: int = 1 -) -> dict: - async with mapping_in_execution(redis=ctx["redis"], job_id=ctx["job_id"]): - logging_context = {} - score_set = None - try: - db: Session = ctx["db"] - redis: ArqRedis = ctx["redis"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() - - logging_context = setup_job_state(ctx, updater_id, score_set.urn, correlation_id) - logging_context["attempt"] = attempt - logger.info(msg="Started variant mapping", extra=logging_context) - - score_set.mapping_state = MappingState.processing - score_set.mapping_errors = null() - db.add(score_set) - db.commit() - - mapping_urn = score_set.urn - assert mapping_urn, "A valid URN is needed to map this score set." - - logging_context["current_mapping_resource"] = mapping_urn - logging_context["mapping_state"] = score_set.mapping_state - logger.debug(msg="Fetched score set metadata for mapping job.", extra=logging_context) - - # Do not block Worker event loop during mapping, see: https://arq-docs.helpmanual.io/#synchronous-jobs. - vrs = vrs_mapper() - blocking = functools.partial(vrs.map_score_set, mapping_urn) - loop = asyncio.get_running_loop() - - except Exception as e: - send_slack_error(e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="Variant mapper encountered an unexpected error during setup. This job will not be retried.", - extra=logging_context, - ) - - db.rollback() - if score_set: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} - db.add(score_set) - db.commit() - - return {"success": False, "retried": False, "enqueued_jobs": []} - - mapping_results = None - try: - mapping_results = await loop.run_in_executor(ctx["pool"], blocking) - logger.debug(msg="Done mapping variants.", extra=logging_context) - - except Exception as e: - db.rollback() - score_set.mapping_errors = { - "error_message": f"Encountered an internal server error during mapping. Mapping will be automatically retried up to 5 times for this score set (attempt {attempt}/5)." - } - db.add(score_set) - db.commit() - - send_slack_error(e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.warning( - msg="Variant mapper encountered an unexpected error while mapping variants. This job will be retried.", - extra=logging_context, - ) - - new_job_id = None - max_retries_exceeded = None - try: - await redis.lpush(MAPPING_QUEUE_NAME, score_set.id) # type: ignore - new_job_id, max_retries_exceeded, backoff_time = await enqueue_job_with_backoff( - redis, "variant_mapper_manager", attempt, MAPPING_BACKOFF_IN_SECONDS, correlation_id, updater_id - ) - # If we fail to enqueue a mapping manager for this score set, evict it from the queue. - if new_job_id is None: - await redis.lpop(MAPPING_QUEUE_NAME, score_set.id) # type: ignore - - logging_context["backoff_limit_exceeded"] = max_retries_exceeded - logging_context["backoff_deferred_in_seconds"] = backoff_time - logging_context["backoff_job_id"] = new_job_id - - except Exception as backoff_e: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} - db.add(score_set) - db.commit() - send_slack_error(backoff_e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(backoff_e)} - logger.critical( - msg="While attempting to re-enqueue a mapping job that exited in error, another exception was encountered. This score set will not be mapped.", - extra=logging_context, - ) - else: - if new_job_id and not max_retries_exceeded: - score_set.mapping_state = MappingState.queued - db.add(score_set) - db.commit() - logger.info( - msg="After encountering an error while mapping variants, another mapping job was queued.", - extra=logging_context, - ) - elif new_job_id is None and not max_retries_exceeded: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} - db.add(score_set) - db.commit() - logger.error( - msg="After encountering an error while mapping variants, another mapping job was unable to be queued. This score set will not be mapped.", - extra=logging_context, - ) - else: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} - db.add(score_set) - db.commit() - logger.error( - msg="After encountering an error while mapping variants, the maximum retries for this job were exceeded. This score set will not be mapped.", - extra=logging_context, - ) - finally: - return { - "success": False, - "retried": (not max_retries_exceeded and new_job_id is not None), - "enqueued_jobs": [job for job in [new_job_id] if job], - } - - try: - if mapping_results: - mapped_scores = mapping_results.get("mapped_scores") - if not mapped_scores: - # if there are no mapped scores, the score set failed to map. - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": mapping_results.get("error_message")} - else: - reference_metadata = mapping_results.get("reference_sequences") - if not reference_metadata: - raise NonexistentMappingReferenceError() - - for target_gene_identifier in reference_metadata: - target_gene = next( - ( - target_gene - for target_gene in score_set.target_genes - if target_gene.name == target_gene_identifier - ), - None, - ) - if not target_gene: - raise ValueError( - f"Target gene {target_gene_identifier} not found in database for score set {score_set.urn}." - ) - # allow for multiple annotation layers - pre_mapped_metadata: dict[str, Any] = {} - post_mapped_metadata: dict[str, Any] = {} - excluded_pre_mapped_keys = {"sequence"} - - gene_info = reference_metadata[target_gene_identifier].get("gene_info") - if gene_info: - target_gene.mapped_hgnc_name = gene_info.get("hgnc_symbol") - post_mapped_metadata["hgnc_name_selection_method"] = gene_info.get("selection_method") - - for annotation_layer in reference_metadata[target_gene_identifier]["layers"]: - layer_premapped = reference_metadata[target_gene_identifier]["layers"][ - annotation_layer - ].get("computed_reference_sequence") - if layer_premapped: - pre_mapped_metadata[ANNOTATION_LAYERS[annotation_layer]] = { - k: layer_premapped[k] - for k in set(list(layer_premapped.keys())) - excluded_pre_mapped_keys - } - layer_postmapped = reference_metadata[target_gene_identifier]["layers"][ - annotation_layer - ].get("mapped_reference_sequence") - if layer_postmapped: - post_mapped_metadata[ANNOTATION_LAYERS[annotation_layer]] = layer_postmapped - target_gene.pre_mapped_metadata = cast(pre_mapped_metadata, JSONB) - target_gene.post_mapped_metadata = cast(post_mapped_metadata, JSONB) - - total_variants = 0 - successful_mapped_variants = 0 - for mapped_score in mapped_scores: - total_variants += 1 - variant_urn = mapped_score.get("mavedb_id") - variant = db.scalars(select(Variant).where(Variant.urn == variant_urn)).one() - - # there should only be one current mapped variant per variant id, so update old mapped variant to current = false - existing_mapped_variant = ( - db.query(MappedVariant) - .filter(MappedVariant.variant_id == variant.id, MappedVariant.current.is_(True)) - .one_or_none() - ) - - if existing_mapped_variant: - existing_mapped_variant.current = False - db.add(existing_mapped_variant) - - if mapped_score.get("pre_mapped") and mapped_score.get("post_mapped"): - successful_mapped_variants += 1 - - mapped_variant = MappedVariant( - pre_mapped=mapped_score.get("pre_mapped", null()), - post_mapped=mapped_score.get("post_mapped", null()), - variant_id=variant.id, - modification_date=date.today(), - mapped_date=mapping_results["mapped_date_utc"], - vrs_version=mapped_score.get("vrs_version", null()), - mapping_api_version=mapping_results["dcd_mapping_version"], - error_message=mapped_score.get("error_message", null()), - current=True, - ) - db.add(mapped_variant) - - if successful_mapped_variants == 0: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "All variants failed to map"} - elif successful_mapped_variants < total_variants: - score_set.mapping_state = MappingState.incomplete - else: - score_set.mapping_state = MappingState.complete - - logging_context["mapped_variants_inserted_db"] = len(mapped_scores) - logging_context["variants_successfully_mapped"] = successful_mapped_variants - logging_context["mapping_state"] = score_set.mapping_state.name - logging_context["mapping_errors"] = score_set.mapping_errors - logger.info(msg="Inserted mapped variants into db.", extra=logging_context) - - else: - raise NonexistentMappingResultsError() - - db.add(score_set) - db.commit() - - except Exception as e: - db.rollback() - score_set.mapping_errors = { - "error_message": f"Encountered an unexpected error while parsing mapped variants. Mapping will be automatically retried up to 5 times for this score set (attempt {attempt}/5)." - } - db.add(score_set) - db.commit() - - send_slack_error(e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.warning( - msg="An unexpected error occurred during variant mapping. This job will be attempted again.", - extra=logging_context, - ) - - new_job_id = None - max_retries_exceeded = None - try: - await redis.lpush(MAPPING_QUEUE_NAME, score_set.id) # type: ignore - new_job_id, max_retries_exceeded, backoff_time = await enqueue_job_with_backoff( - redis, "variant_mapper_manager", attempt, MAPPING_BACKOFF_IN_SECONDS, correlation_id, updater_id - ) - # If we fail to enqueue a mapping manager for this score set, evict it from the queue. - if new_job_id is None: - await redis.lpop(MAPPING_QUEUE_NAME, score_set.id) # type: ignore - - logging_context["backoff_limit_exceeded"] = max_retries_exceeded - logging_context["backoff_deferred_in_seconds"] = backoff_time - logging_context["backoff_job_id"] = new_job_id - - except Exception as backoff_e: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} - send_slack_error(backoff_e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(backoff_e)} - logger.critical( - msg="While attempting to re-enqueue a mapping job that exited in error, another exception was encountered. This score set will not be mapped.", - extra=logging_context, - ) - else: - if new_job_id and not max_retries_exceeded: - score_set.mapping_state = MappingState.queued - logger.info( - msg="After encountering an error while parsing mapped variants, another mapping job was queued.", - extra=logging_context, - ) - elif new_job_id is None and not max_retries_exceeded: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} - logger.error( - msg="After encountering an error while parsing mapped variants, another mapping job was unable to be queued. This score set will not be mapped.", - extra=logging_context, - ) - else: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} - logger.error( - msg="After encountering an error while parsing mapped variants, the maximum retries for this job were exceeded. This score set will not be mapped.", - extra=logging_context, - ) - finally: - db.add(score_set) - db.commit() - return { - "success": False, - "retried": (not max_retries_exceeded and new_job_id is not None), - "enqueued_jobs": [job for job in [new_job_id] if job], - } - - new_uniprot_job_id = None - try: - if UNIPROT_ID_MAPPING_ENABLED: - new_job = await redis.enqueue_job( - "submit_uniprot_mapping_jobs_for_score_set", - score_set.id, - correlation_id, - ) - - if new_job: - new_uniprot_job_id = new_job.job_id - - logging_context["submit_uniprot_mapping_job_id"] = new_uniprot_job_id - logger.info(msg="Queued a new UniProt mapping job.", extra=logging_context) - - else: - raise UniProtIDMappingEnqueueError() - else: - logger.warning( - msg="UniProt ID mapping is disabled, skipped submission of UniProt mapping jobs.", - extra=logging_context, - ) - - except Exception as e: - send_slack_error(e) - send_slack_message( - f"Could not enqueue UniProt mapping job for score set {score_set.urn}. UniProt mappings for this score set should be submitted manually." - ) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="Mapped variant UniProt submission encountered an unexpected error while attempting to enqueue a mapping job. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_jobs": [job for job in [new_uniprot_job_id] if job]} - - new_clingen_job_id = None - try: - if CLIN_GEN_SUBMISSION_ENABLED: - new_job = await redis.enqueue_job( - "submit_score_set_mappings_to_car", - correlation_id, - score_set.id, - ) - - if new_job: - new_clingen_job_id = new_job.job_id - - logging_context["submit_clingen_variants_job_id"] = new_clingen_job_id - logger.info(msg="Queued a new ClinGen submission job.", extra=logging_context) - - else: - raise SubmissionEnqueueError() - else: - logger.warning( - msg="ClinGen submission is disabled, skipped submission of mapped variants to CAR and LDH.", - extra=logging_context, - ) - - except Exception as e: - send_slack_error(e) - send_slack_message( - f"Could not submit mappings to CAR and/or LDH mappings for score set {score_set.urn}. Mappings for this score set should be submitted manually." - ) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="Mapped variant ClinGen submission encountered an unexpected error while attempting to enqueue a submission job. This job will not be retried.", - extra=logging_context, - ) - - return { - "success": False, - "retried": False, - "enqueued_jobs": [job for job in [new_uniprot_job_id, new_clingen_job_id] if job], - } - - ctx["state"][ctx["job_id"]] = logging_context.copy() - return { - "success": True, - "retried": False, - "enqueued_jobs": [job for job in [new_uniprot_job_id, new_clingen_job_id] if job], - } - - -async def variant_mapper_manager(ctx: dict, correlation_id: str, updater_id: int, attempt: int = 1) -> dict: - logging_context = {} - mapping_job_id = None - mapping_job_status = None - queued_score_set = None - try: - redis: ArqRedis = ctx["redis"] - db: Session = ctx["db"] - - logging_context = setup_job_state(ctx, updater_id, None, correlation_id) - logging_context["attempt"] = attempt - logger.debug(msg="Variant mapping manager began execution", extra=logging_context) - - queue_length = await redis.llen(MAPPING_QUEUE_NAME) # type: ignore - queued_id = await redis.rpop(MAPPING_QUEUE_NAME) # type: ignore - logging_context["variant_mapping_queue_length"] = queue_length - - # Setup the job id cache if it does not already exist. - if not await redis.exists(MAPPING_CURRENT_ID_NAME): - await redis.set(MAPPING_CURRENT_ID_NAME, "") - - if not queued_id: - logger.debug(msg="No mapping jobs exist in the queue.", extra=logging_context) - return {"success": True, "enqueued_job": None} - else: - queued_id = queued_id.decode("utf-8") - queued_score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == queued_id)).one() - - logging_context["upcoming_mapping_resource"] = queued_score_set.urn - logger.debug(msg="Found mapping job(s) still in queue.", extra=logging_context) - - mapping_job_id = await redis.get(MAPPING_CURRENT_ID_NAME) - if mapping_job_id: - mapping_job_id = mapping_job_id.decode("utf-8") - mapping_job_status = (await Job(job_id=mapping_job_id, redis=redis).status()).value - - logging_context["existing_mapping_job_status"] = mapping_job_status - logging_context["existing_mapping_job_id"] = mapping_job_id - - except Exception as e: - send_slack_error(e) - - # Attempt to remove this item from the mapping queue. - try: - await redis.lrem(MAPPING_QUEUE_NAME, 1, queued_id) # type: ignore - logger.warning(msg="Removed un-queueable score set from the queue.", extra=logging_context) - except Exception: - pass - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error(msg="Variant mapper manager encountered an unexpected error during setup.", extra=logging_context) - - return {"success": False, "enqueued_job": None} - - new_job = None - new_job_id = None - try: - if not mapping_job_id or mapping_job_status in (JobStatus.not_found, JobStatus.complete): - logger.debug(msg="No mapping jobs are running, queuing a new one.", extra=logging_context) - - new_job = await redis.enqueue_job( - "map_variants_for_score_set", correlation_id, queued_score_set.id, updater_id, attempt - ) - - if new_job: - new_job_id = new_job.job_id - - logging_context["new_mapping_job_id"] = new_job_id - logger.info(msg="Queued a new mapping job.", extra=logging_context) - - return {"success": True, "enqueued_job": new_job_id} - - logger.info( - msg="A mapping job is already running, or a new job was unable to be enqueued. Deferring mapping by 5 minutes.", - extra=logging_context, - ) - - new_job = await redis.enqueue_job( - "variant_mapper_manager", - correlation_id, - updater_id, - attempt, - _defer_by=timedelta(minutes=5), - ) - - if new_job: - # Ensure this score set remains in the front of the queue. - queued_id = await redis.rpush(MAPPING_QUEUE_NAME, queued_score_set.id) # type: ignore - new_job_id = new_job.job_id - - logging_context["new_mapping_manager_job_id"] = new_job_id - logger.info(msg="Deferred a new mapping manager job.", extra=logging_context) - - # Our persistent Redis queue and ARQ's execution rules ensure that even if the worker is stopped and not restarted - # before the deferred time, these deferred jobs will still run once able. - return {"success": True, "enqueued_job": new_job_id} - - raise MappingEnqueueError() - - except Exception as e: - send_slack_error(e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="Variant mapper manager encountered an unexpected error while enqueing a mapping job. This job will not be retried.", - extra=logging_context, - ) - - db.rollback() - - # We shouldn't rely on the passed score set id matching the score set we are operating upon. - if not queued_score_set: - return {"success": False, "enqueued_job": new_job_id} - - # Attempt to remove this item from the mapping queue. - try: - await redis.lrem(MAPPING_QUEUE_NAME, 1, queued_id) # type: ignore - logger.warning(msg="Removed un-queueable score set from the queue.", extra=logging_context) - except Exception: - pass - - score_set_exc = db.scalars(select(ScoreSet).where(ScoreSet.id == queued_score_set.id)).one_or_none() - if score_set_exc: - score_set_exc.mapping_state = MappingState.failed - score_set_exc.mapping_errors = "Unable to queue a new mapping job or defer score set mapping." - db.add(score_set_exc) - db.commit() - - return {"success": False, "enqueued_job": new_job_id} - - -#################################################################################################### -# Materialized Views -#################################################################################################### - - -# TODO#405: Refresh materialized views within an executor. -async def refresh_materialized_views(ctx: dict): - logging_context = setup_job_state(ctx, None, None, None) - logger.debug(msg="Began refresh materialized views.", extra=logging_context) - refresh_all_mat_views(ctx["db"]) - ctx["db"].commit() - logger.debug(msg="Done refreshing materialized views.", extra=logging_context) - return {"success": True} - - -async def refresh_published_variants_view(ctx: dict, correlation_id: str): - logging_context = setup_job_state(ctx, None, None, correlation_id) - logger.debug(msg="Began refresh of published variants materialized view.", extra=logging_context) - PublishedVariantsMV.refresh(ctx["db"]) - ctx["db"].commit() - logger.debug(msg="Done refreshing published variants materialized view.", extra=logging_context) - return {"success": True} - - -#################################################################################################### -# ClinGen resource creation / linkage -#################################################################################################### - - -async def submit_score_set_mappings_to_car(ctx: dict, correlation_id: str, score_set_id: int): - logging_context = {} - score_set = None - text = "Could not submit mappings to ClinGen Allele Registry for score set %s. Mappings for this score set should be submitted manually." - try: - db: Session = ctx["db"] - redis: ArqRedis = ctx["redis"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() - - logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) - logger.info(msg="Started CAR mapped resource submission", extra=logging_context) - - submission_urn = score_set.urn - assert submission_urn, "A valid URN is needed to submit CAR objects for this score set." - - logging_context["current_car_submission_resource"] = submission_urn - logger.debug(msg="Fetched score set metadata for CAR mapped resource submission.", extra=logging_context) - - except Exception as e: - send_slack_error(e) - if score_set: - send_slack_message(text=text % score_set.urn) - else: - send_slack_message(text=text % score_set_id) - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="CAR mapped resource submission encountered an unexpected error during setup. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - variant_post_mapped_objects = db.execute( - select(MappedVariant.id, MappedVariant.post_mapped) - .join(Variant) - .join(ScoreSet) - .where(ScoreSet.urn == score_set.urn) - .where(MappedVariant.post_mapped.is_not(None)) - .where(MappedVariant.current.is_(True)) - ).all() - - if not variant_post_mapped_objects: - logger.warning( - msg="No current mapped variants with post mapped metadata were found for this score set. Skipping CAR submission.", - extra=logging_context, - ) - return {"success": True, "retried": False, "enqueued_job": None} - - variant_post_mapped_hgvs: dict[str, list[int]] = {} - for mapped_variant_id, post_mapped in variant_post_mapped_objects: - hgvs_for_post_mapped = get_hgvs_from_post_mapped(post_mapped) - - if not hgvs_for_post_mapped: - logger.warning( - msg=f"Could not construct a valid HGVS string for mapped variant {mapped_variant_id}. Skipping submission of this variant.", - extra=logging_context, - ) - continue - - if hgvs_for_post_mapped in variant_post_mapped_hgvs: - variant_post_mapped_hgvs[hgvs_for_post_mapped].append(mapped_variant_id) - else: - variant_post_mapped_hgvs[hgvs_for_post_mapped] = [mapped_variant_id] - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error while attempting to construct post mapped HGVS strings. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - if not CAR_SUBMISSION_ENDPOINT: - logger.warning( - msg="ClinGen Allele Registry submission is disabled (no submission endpoint), skipping submission of mapped variants to CAR.", - extra=logging_context, - ) - return {"success": False, "retried": False, "enqueued_job": None} - - car_service = ClinGenAlleleRegistryService(url=CAR_SUBMISSION_ENDPOINT) - registered_alleles = car_service.dispatch_submissions(list(variant_post_mapped_hgvs.keys())) - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error while attempting to authenticate to the LDH. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - linked_alleles = get_allele_registry_associations(list(variant_post_mapped_hgvs.keys()), registered_alleles) - for hgvs_string, caid in linked_alleles.items(): - mapped_variant_ids = variant_post_mapped_hgvs[hgvs_string] - mapped_variants = db.scalars(select(MappedVariant).where(MappedVariant.id.in_(mapped_variant_ids))).all() - - for mapped_variant in mapped_variants: - mapped_variant.clingen_allele_id = caid - db.add(mapped_variant) - - db.commit() - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error while attempting to authenticate to the LDH. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - new_job_id = None - try: - new_job = await redis.enqueue_job( - "submit_score_set_mappings_to_ldh", - correlation_id, - score_set.id, - ) - - if new_job: - new_job_id = new_job.job_id - - logging_context["submit_clingen_ldh_variants_job_id"] = new_job_id - logger.info(msg="Queued a new ClinGen submission job.", extra=logging_context) - - else: - raise SubmissionEnqueueError() - - except Exception as e: - send_slack_error(e) - send_slack_message( - f"Could not submit mappings to LDH for score set {score_set.urn}. Mappings for this score set should be submitted manually." - ) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="Mapped variant ClinGen submission encountered an unexpected error while attempting to enqueue a submission job. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": new_job_id} - - ctx["state"][ctx["job_id"]] = logging_context.copy() - return {"success": True, "retried": False, "enqueued_job": new_job_id} - - -async def submit_score_set_mappings_to_ldh(ctx: dict, correlation_id: str, score_set_id: int): - logging_context = {} - score_set = None - text = ( - "Could not submit mappings to LDH for score set %s. Mappings for this score set should be submitted manually." - ) - try: - db: Session = ctx["db"] - redis: ArqRedis = ctx["redis"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() - - logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) - logger.info(msg="Started LDH mapped resource submission", extra=logging_context) - - submission_urn = score_set.urn - assert submission_urn, "A valid URN is needed to submit LDH objects for this score set." - - logging_context["current_ldh_submission_resource"] = submission_urn - logger.debug(msg="Fetched score set metadata for ldh mapped resource submission.", extra=logging_context) - - except Exception as e: - send_slack_error(e) - if score_set: - send_slack_message(text=text % score_set.urn) - else: - send_slack_message(text=text % score_set_id) - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error during setup. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - ldh_service = ClinGenLdhService(url=LDH_SUBMISSION_ENDPOINT) - ldh_service.authenticate() - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error while attempting to authenticate to the LDH. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - variant_objects = db.execute( - select(Variant, MappedVariant) - .join(MappedVariant) - .join(ScoreSet) - .where(ScoreSet.urn == score_set.urn) - .where(MappedVariant.post_mapped.is_not(None)) - .where(MappedVariant.current.is_(True)) - ).all() - - if not variant_objects: - logger.warning( - msg="No current mapped variants with post mapped metadata were found for this score set. Skipping LDH submission.", - extra=logging_context, - ) - return {"success": True, "retried": False, "enqueued_job": None} - - variant_content = [] - for variant, mapped_variant in variant_objects: - variation = get_hgvs_from_post_mapped(mapped_variant.post_mapped) - - if not variation: - logger.warning( - msg=f"Could not construct a valid HGVS string for mapped variant {mapped_variant.id}. Skipping submission of this variant.", - extra=logging_context, - ) - continue - - variant_content.append((variation, variant, mapped_variant)) - - submission_content = construct_ldh_submission(variant_content) - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error while attempting to construct submission objects. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - blocking = functools.partial( - ldh_service.dispatch_submissions, submission_content, DEFAULT_LDH_SUBMISSION_BATCH_SIZE - ) - loop = asyncio.get_running_loop() - submission_successes, submission_failures = await loop.run_in_executor(ctx["pool"], blocking) - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error while dispatching submissions. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - assert not submission_failures, f"{len(submission_failures)} submissions failed to be dispatched to the LDH." - logger.info(msg="Dispatched all variant mapping submissions to the LDH.", extra=logging_context) - except AssertionError as e: - send_slack_error(e) - send_slack_message( - text=f"{len(submission_failures)} submissions failed to be dispatched to the LDH for score set {score_set.urn}." - ) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission failed to submit all mapping resources. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - new_job_id = None - try: - new_job = await redis.enqueue_job( - "link_clingen_variants", - correlation_id, - score_set.id, - 1, - _defer_by=timedelta(seconds=LINKING_BACKOFF_IN_SECONDS), - ) - - if new_job: - new_job_id = new_job.job_id - - logging_context["link_clingen_variants_job_id"] = new_job_id - logger.info(msg="Queued a new ClinGen linking job.", extra=logging_context) - - else: - raise LinkingEnqueueError() - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error while attempting to enqueue a linking job. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": new_job_id} - - return {"success": True, "retried": False, "enqueued_job": new_job_id} - - -def do_clingen_fetch(variant_urns): - return [(variant_urn, get_clingen_variation(variant_urn)) for variant_urn in variant_urns] - - -async def link_clingen_variants(ctx: dict, correlation_id: str, score_set_id: int, attempt: int) -> dict: - logging_context = {} - score_set = None - text = "Could not link mappings to LDH for score set %s. Mappings for this score set should be linked manually." - try: - db: Session = ctx["db"] - redis: ArqRedis = ctx["redis"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() - - logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) - logging_context["linkage_retry_threshold"] = LINKED_DATA_RETRY_THRESHOLD - logging_context["attempt"] = attempt - logging_context["max_attempts"] = BACKOFF_LIMIT - logger.info(msg="Started LDH mapped resource linkage", extra=logging_context) - - submission_urn = score_set.urn - assert submission_urn, "A valid URN is needed to link LDH objects for this score set." - - logging_context["current_ldh_linking_resource"] = submission_urn - logger.debug(msg="Fetched score set metadata for ldh mapped resource linkage.", extra=logging_context) - - except Exception as e: - send_slack_error(e) - if score_set: - send_slack_message(text=text % score_set.urn) - else: - send_slack_message(text=text % score_set_id) - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error during setup. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - variant_urns = db.scalars( - select(Variant.urn) - .join(MappedVariant) - .join(ScoreSet) - .where( - ScoreSet.urn == score_set.urn, MappedVariant.current.is_(True), MappedVariant.post_mapped.is_not(None) - ) - ).all() - num_variant_urns = len(variant_urns) - - logging_context["variants_to_link_ldh"] = num_variant_urns - - if not variant_urns: - logger.warning( - msg="No current mapped variants with post mapped metadata were found for this score set. Skipping LDH linkage (nothing to do). A gnomAD linkage job will not be enqueued, as no variants will have a CAID.", - extra=logging_context, - ) - - return {"success": True, "retried": False, "enqueued_job": None} - - logger.info( - msg="Found current mapped variants with post mapped metadata for this score set. Attempting to link them to LDH submissions.", - extra=logging_context, - ) - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error while attempting to build linkage urn list. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - logger.info(msg="Attempting to link mapped variants to LDH submissions.", extra=logging_context) - - # TODO#372: Non-nullable variant urns. - blocking = functools.partial( - do_clingen_fetch, - variant_urns, # type: ignore - ) - loop = asyncio.get_running_loop() - linked_data = await loop.run_in_executor(ctx["pool"], blocking) - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error while attempting to link LDH submissions. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - linked_allele_ids = [ - (variant_urn, clingen_allele_id_from_ldh_variation(clingen_variation)) - for variant_urn, clingen_variation in linked_data - ] - - linkage_failures = [] - for variant_urn, ldh_variation in linked_allele_ids: - # XXX: Should we unlink variation if it is not found? Does this constitute a failure? - if not ldh_variation: - logger.warning( - msg=f"Failed to link mapped variant {variant_urn} to LDH submission. No LDH variation found.", - extra=logging_context, - ) - linkage_failures.append(variant_urn) - continue - - mapped_variant = db.scalars( - select(MappedVariant).join(Variant).where(Variant.urn == variant_urn, MappedVariant.current.is_(True)) - ).one_or_none() - - if not mapped_variant: - logger.warning( - msg=f"Failed to link mapped variant {variant_urn} to LDH submission. No mapped variant found.", - extra=logging_context, - ) - linkage_failures.append(variant_urn) - continue - - mapped_variant.clingen_allele_id = ldh_variation - db.add(mapped_variant) - - db.commit() - - except Exception as e: - db.rollback() - - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error while attempting to link LDH submissions. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - num_linkage_failures = len(linkage_failures) - ratio_failed_linking = round(num_linkage_failures / num_variant_urns, 3) - logging_context["linkage_failure_rate"] = ratio_failed_linking - logging_context["linkage_failures"] = num_linkage_failures - logging_context["linkage_successes"] = num_variant_urns - num_linkage_failures - - assert ( - len(linked_allele_ids) == num_variant_urns - ), f"{num_variant_urns - len(linked_allele_ids)} appear to not have been attempted to be linked." - - job_succeeded = False - if not linkage_failures: - logger.info( - msg="Successfully linked all mapped variants to LDH submissions.", - extra=logging_context, - ) - - job_succeeded = True - - elif ratio_failed_linking < LINKED_DATA_RETRY_THRESHOLD: - logger.warning( - msg="Linkage failures exist, but did not exceed the retry threshold.", - extra=logging_context, - ) - send_slack_message( - text=f"Failed to link {len(linkage_failures)} mapped variants to LDH submissions for score set {score_set.urn}." - f"The retry threshold was not exceeded and this job will not be retried. URNs failed to link: {', '.join(linkage_failures)}." - ) - - job_succeeded = True - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error while attempting to finalize linkage. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - if job_succeeded: - gnomad_linking_job_id = None - try: - new_job = await redis.enqueue_job( - "link_gnomad_variants", - correlation_id, - score_set.id, - ) - - if new_job: - gnomad_linking_job_id = new_job.job_id - - logging_context["link_gnomad_variants_job_id"] = gnomad_linking_job_id - logger.info(msg="Queued a new gnomAD linking job.", extra=logging_context) - - else: - raise LinkingEnqueueError() - - except Exception as e: - job_succeeded = False - - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error while attempting to enqueue a gnomAD linking job. GnomAD variants should be linked manually for this score set. This job will not be retried.", - extra=logging_context, - ) - finally: - return {"success": job_succeeded, "retried": False, "enqueued_job": gnomad_linking_job_id} - - # If we reach this point, we should consider the job failed (there were failures which exceeded our retry threshold). - new_job_id = None - max_retries_exceeded = None - try: - new_job_id, max_retries_exceeded, backoff_time = await enqueue_job_with_backoff( - ctx["redis"], "variant_mapper_manager", attempt, LINKING_BACKOFF_IN_SECONDS, correlation_id - ) - - logging_context["backoff_limit_exceeded"] = max_retries_exceeded - logging_context["backoff_deferred_in_seconds"] = backoff_time - logging_context["backoff_job_id"] = new_job_id - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.critical( - msg="LDH mapped resource linkage encountered an unexpected error while attempting to retry a failed linkage job. This job will not be retried.", - extra=logging_context, - ) - else: - if new_job_id and not max_retries_exceeded: - logger.info( - msg="After a failure condition while linking mapped variants to LDH submissions, another linkage job was queued.", - extra=logging_context, - ) - send_slack_message( - text=f"Failed to link {len(linkage_failures)} ({ratio_failed_linking * 100}% of total mapped variants for {score_set.urn})." - f"This job was successfully retried. This was attempt {attempt}. Retry will occur in {backoff_time} seconds. URNs failed to link: {', '.join(linkage_failures)}." - ) - elif new_job_id is None and not max_retries_exceeded: - logger.error( - msg="After a failure condition while linking mapped variants to LDH submissions, another linkage job was unable to be queued.", - extra=logging_context, - ) - send_slack_message( - text=f"Failed to link {len(linkage_failures)} ({ratio_failed_linking} of total mapped variants for {score_set.urn})." - f"This job could not be retried due to an unexpected issue while attempting to enqueue another linkage job. This was attempt {attempt}. URNs failed to link: {', '.join(linkage_failures)}." - ) - else: - logger.error( - msg="After a failure condition while linking mapped variants to LDH submissions, the maximum retries for this job were exceeded. The reamining linkage failures will not be retried.", - extra=logging_context, - ) - send_slack_message( - text=f"Failed to link {len(linkage_failures)} ({ratio_failed_linking} of total mapped variants for {score_set.urn})." - f"The retry threshold was exceeded and this job will not be retried. URNs failed to link: {', '.join(linkage_failures)}." - ) - - finally: - return { - "success": False, - "retried": (not max_retries_exceeded and new_job_id is not None), - "enqueued_job": new_job_id, - } - - -######################################################################################################## -# Mapping between Mapped Metadata and UniProt IDs -######################################################################################################## - - -async def submit_uniprot_mapping_jobs_for_score_set(ctx, score_set_id: int, correlation_id: Optional[str] = None): - logging_context = {} - score_set = None - spawned_mapping_jobs: dict[int, Optional[str]] = {} - text = "Could not submit mapping jobs to UniProt for this score set %s. Mapping jobs for this score set should be submitted manually." - try: - db: Session = ctx["db"] - redis: ArqRedis = ctx["redis"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() - logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) - logger.info(msg="Started UniProt mapping job", extra=logging_context) - - if not score_set or not score_set.target_genes: - msg = f"No target genes for score set {score_set_id}. Skipped mapping targets to UniProt." - log_and_send_slack_message(msg=msg, ctx=logging_context, level=logging.WARNING) - - return {"success": True, "retried": False, "enqueued_jobs": []} - - except Exception as e: - send_slack_error(e) - if score_set: - msg = text % score_set.urn - else: - msg = text % score_set_id - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - log_and_send_slack_message(msg=msg, ctx=logging_context, level=logging.ERROR) - - return {"success": False, "retried": False, "enqueued_jobs": []} - - try: - uniprot_api = UniProtIDMappingAPI() - logging_context["total_target_genes_to_map_to_uniprot"] = len(score_set.target_genes) - for target_gene in score_set.target_genes: - spawned_mapping_jobs[target_gene.id] = None # type: ignore - - acs = extract_ids_from_post_mapped_metadata(target_gene.post_mapped_metadata) # type: ignore - if not acs: - msg = f"No accession IDs found in post_mapped_metadata for target gene {target_gene.id} in score set {score_set.urn}. This target will be skipped." - log_and_send_slack_message(msg, logging_context, logging.WARNING) - continue - - if len(acs) != 1: - msg = f"More than one accession ID is associated with target gene {target_gene.id} in score set {score_set.urn}. This target will be skipped." - log_and_send_slack_message(msg, logging_context, logging.WARNING) - continue - - ac_to_map = acs[0] - from_db = infer_db_name_from_sequence_accession(ac_to_map) - - try: - spawned_mapping_jobs[target_gene.id] = uniprot_api.submit_id_mapping(from_db, "UniProtKB", [ac_to_map]) # type: ignore - except Exception as e: - log_and_send_slack_message( - msg=f"Failed to submit UniProt mapping job for target gene {target_gene.id}: {e}. This target will be skipped.", - ctx=logging_context, - level=logging.WARNING, - ) - - except Exception as e: - send_slack_error(e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - log_and_send_slack_message( - msg=f"UniProt mapping job encountered an unexpected error while attempting to submit mapping jobs for score set {score_set.urn}. This job will not be retried.", - ctx=logging_context, - level=logging.ERROR, - ) - - return {"success": False, "retried": False, "enqueued_jobs": []} - - new_job_id = None - try: - successfully_spawned_mapping_jobs = sum(1 for job in spawned_mapping_jobs.values() if job is not None) - logging_context["successfully_spawned_mapping_jobs"] = successfully_spawned_mapping_jobs - - if not successfully_spawned_mapping_jobs: - msg = f"No UniProt mapping jobs were successfully spawned for score set {score_set.urn}. Skipped enqueuing polling job." - log_and_send_slack_message(msg, logging_context, logging.WARNING) - return {"success": True, "retried": False, "enqueued_jobs": []} - - new_job = await redis.enqueue_job( - "poll_uniprot_mapping_jobs_for_score_set", - spawned_mapping_jobs, - score_set_id, - correlation_id, - ) - - if new_job: - new_job_id = new_job.job_id - - logging_context["poll_uniprot_mapping_job_id"] = new_job_id - logger.info(msg="Enqueued polling jobs for UniProt mapping jobs.", extra=logging_context) - - else: - raise UniProtPollingEnqueueError() - - except Exception as e: - send_slack_error(e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - log_and_send_slack_message( - msg="UniProt mapping job encountered an unexpected error while attempting to enqueue polling jobs for mapping jobs. This job will not be retried.", - ctx=logging_context, - level=logging.ERROR, - ) - - return {"success": False, "retried": False, "enqueued_jobs": [job for job in [new_job_id] if job]} - - return {"success": True, "retried": False, "enqueued_jobs": [job for job in [new_job_id] if job]} - - -async def poll_uniprot_mapping_jobs_for_score_set( - ctx, mapping_jobs: dict[int, Optional[str]], score_set_id: int, correlation_id: Optional[str] = None -): - logging_context = {} - score_set = None - text = "Could not poll mapping jobs from UniProt for this Target %s. Mapping jobs for this score set should be submitted manually." - try: - db: Session = ctx["db"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() - logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) - logger.info(msg="Started UniProt polling job", extra=logging_context) - - if not score_set or not score_set.target_genes: - msg = f"No target genes for score set {score_set_id}. Skipped polling targets for UniProt mapping results." - log_and_send_slack_message(msg=msg, ctx=logging_context, level=logging.WARNING) - - return {"success": True, "retried": False, "enqueued_jobs": []} - - except Exception as e: - send_slack_error(e) - if score_set: - msg = text % score_set.urn - else: - msg = text % score_set_id - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - log_and_send_slack_message(msg=msg, ctx=logging_context, level=logging.ERROR) - - return {"success": False, "retried": False, "enqueued_jobs": []} - - try: - uniprot_api = UniProtIDMappingAPI() - for target_gene in score_set.target_genes: - acs = extract_ids_from_post_mapped_metadata(target_gene.post_mapped_metadata) # type: ignore - if not acs: - msg = f"No accession IDs found in post_mapped_metadata for target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target." - log_and_send_slack_message(msg, logging_context, logging.WARNING) - continue - - if len(acs) != 1: - msg = f"More than one accession ID is associated with target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target." - log_and_send_slack_message(msg, logging_context, logging.WARNING) - continue - - mapped_ac = acs[0] - job_id = mapping_jobs.get(target_gene.id) # type: ignore - - if not job_id: - msg = f"No job ID found for target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target." - # This issue has already been sent to Slack in the job submission function, so we just log it here. - logger.debug(msg=msg, extra=logging_context) - continue - - if not uniprot_api.check_id_mapping_results_ready(job_id): - msg = f"Job {job_id} not ready for target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target" - log_and_send_slack_message(msg, logging_context, logging.WARNING) - continue - - results = uniprot_api.get_id_mapping_results(job_id) - mapped_ids = uniprot_api.extract_uniprot_id_from_results(results) - - if not mapped_ids: - msg = f"No UniProt ID found for target gene {target_gene.id} in score set {score_set.urn}. Cannot add UniProt ID for this target." - log_and_send_slack_message(msg, logging_context, logging.WARNING) - continue - - if len(mapped_ids) != 1: - msg = f"Found ambiguous Uniprot ID mapping results for target gene {target_gene.id} in score set {score_set.urn}. Cannot add UniProt ID for this target." - log_and_send_slack_message(msg, logging_context, logging.WARNING) - continue - - mapped_uniprot_id = mapped_ids[0][mapped_ac]["uniprot_id"] - target_gene.uniprot_id_from_mapped_metadata = mapped_uniprot_id - db.add(target_gene) - logger.info( - msg=f"Updated target gene {target_gene.id} with UniProt ID {mapped_uniprot_id}", extra=logging_context - ) - - except Exception as e: - send_slack_error(e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - log_and_send_slack_message( - msg="UniProt mapping job encountered an unexpected error while attempting to poll mapping jobs. This job will not be retried.", - ctx=logging_context, - level=logging.ERROR, - ) - - return {"success": False, "retried": False, "enqueued_jobs": []} - - db.commit() - return {"success": True, "retried": False, "enqueued_jobs": []} - - -#################################################################################################### -# gnomAD Variant Linkage -#################################################################################################### - - -async def link_gnomad_variants(ctx: dict, correlation_id: str, score_set_id: int) -> dict: - logging_context = {} - score_set = None - text = "Could not link mappings to gnomAD variants for score set %s. Mappings for this score set should be linked manually." - try: - db: Session = ctx["db"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() - - logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) - logger.info(msg="Started gnomAD variant linkage", extra=logging_context) - - submission_urn = score_set.urn - assert submission_urn, "A valid URN is needed to link gnomAD objects for this score set." - - logging_context["current_gnomad_linking_resource"] = submission_urn - logger.debug(msg="Fetched score set metadata for gnomAD mapped resource linkage.", extra=logging_context) - - except Exception as e: - send_slack_error(e) - if score_set: - send_slack_message(text=text % score_set.urn) - else: - send_slack_message(text=text % score_set_id) - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error during setup. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - # We filter out mapped variants that do not have a CAID, so this query is typed # as a Sequence[str]. Ignore MyPy's type checking here. - variant_caids: Sequence[str] = db.scalars( - select(MappedVariant.clingen_allele_id) - .join(Variant) - .join(ScoreSet) - .where( - ScoreSet.urn == score_set.urn, - MappedVariant.current.is_(True), - MappedVariant.clingen_allele_id.is_not(None), - ) - ).all() # type: ignore - num_variant_caids = len(variant_caids) - - logging_context["num_variants_to_link_gnomad"] = num_variant_caids - - if not variant_caids: - logger.warning( - msg="No current mapped variants with CAIDs were found for this score set. Skipping gnomAD linkage (nothing to do).", - extra=logging_context, - ) - - return {"success": True, "retried": False, "enqueued_job": None} - - logger.info( - msg="Found current mapped variants with CAIDs for this score set. Attempting to link them to gnomAD variants.", - extra=logging_context, - ) - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="gnomAD mapped resource linkage encountered an unexpected error while attempting to build linkage urn list. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - gnomad_variant_data = gnomad_variant_data_for_caids(variant_caids) - num_gnomad_variants_with_caid_match = len(gnomad_variant_data) - logging_context["num_gnomad_variants_with_caid_match"] = num_gnomad_variants_with_caid_match - - if not gnomad_variant_data: - logger.warning( - msg="No gnomAD variants with CAID matches were found for this score set. Skipping gnomAD linkage (nothing to do).", - extra=logging_context, - ) - - return {"success": True, "retried": False, "enqueued_job": None} - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="gnomAD mapped resource linkage encountered an unexpected error while attempting to fetch gnomAD variant data from S3 via Athena. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - logger.info(msg="Attempting to link mapped variants to gnomAD variants.", extra=logging_context) - num_linked_gnomad_variants = link_gnomad_variants_to_mapped_variants(db, gnomad_variant_data) - db.commit() - logging_context["num_mapped_variants_linked_to_gnomad_variants"] = num_linked_gnomad_variants - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error while attempting to link LDH submissions. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - logger.info(msg="Done linking gnomAD variants to mapped variants.", extra=logging_context) - return {"success": True, "retried": False, "enqueued_job": None} diff --git a/src/mavedb/worker/jobs/__init__.py b/src/mavedb/worker/jobs/__init__.py new file mode 100644 index 00000000..15614fd0 --- /dev/null +++ b/src/mavedb/worker/jobs/__init__.py @@ -0,0 +1,56 @@ +"""MaveDB Worker Job Functions. + +This package contains all worker job functions organized by domain: +- variant_processing: Variant creation and VRS mapping jobs +- external_services: Third-party service integration jobs (ClinGen, UniProt, gnomAD) +- data_management: Database and materialized view management jobs +- utils: Shared utilities for job state, retry logic, and constants + +All job functions are exported at the package level for easy import +by the worker settings and other modules. Additionally, a job registry +is provided for ARQ worker configuration. +""" + +from mavedb.worker.jobs.data_management.views import ( + refresh_materialized_views, + refresh_published_variants_view, +) +from mavedb.worker.jobs.external_services.clingen import ( + link_clingen_variants, + submit_score_set_mappings_to_car, + submit_score_set_mappings_to_ldh, +) +from mavedb.worker.jobs.external_services.gnomad import link_gnomad_variants +from mavedb.worker.jobs.external_services.uniprot import ( + poll_uniprot_mapping_jobs_for_score_set, + submit_uniprot_mapping_jobs_for_score_set, +) +from mavedb.worker.jobs.registry import ( + BACKGROUND_CRONJOBS, + BACKGROUND_FUNCTIONS, +) +from mavedb.worker.jobs.variant_processing.creation import create_variants_for_score_set +from mavedb.worker.jobs.variant_processing.mapping import ( + map_variants_for_score_set, + variant_mapper_manager, +) + +__all__ = [ + # Variant processing jobs + "create_variants_for_score_set", + "map_variants_for_score_set", + "variant_mapper_manager", + # External service integration jobs + "link_clingen_variants", + "submit_score_set_mappings_to_car", + "submit_score_set_mappings_to_ldh", + "poll_uniprot_mapping_jobs_for_score_set", + "submit_uniprot_mapping_jobs_for_score_set", + "link_gnomad_variants", + # Data management jobs + "refresh_materialized_views", + "refresh_published_variants_view", + # Job registry and utilities + "BACKGROUND_FUNCTIONS", + "BACKGROUND_CRONJOBS", +] diff --git a/src/mavedb/worker/jobs/data_management/__init__.py b/src/mavedb/worker/jobs/data_management/__init__.py new file mode 100644 index 00000000..63502581 --- /dev/null +++ b/src/mavedb/worker/jobs/data_management/__init__.py @@ -0,0 +1,16 @@ +"""Data management job functions. + +This module exports jobs for database and view management: +- Materialized view refresh for optimized query performance +- Database maintenance and cleanup operations +""" + +from .views import ( + refresh_materialized_views, + refresh_published_variants_view, +) + +__all__ = [ + "refresh_materialized_views", + "refresh_published_variants_view", +] diff --git a/src/mavedb/worker/jobs/data_management/views.py b/src/mavedb/worker/jobs/data_management/views.py new file mode 100644 index 00000000..a6ddb2d6 --- /dev/null +++ b/src/mavedb/worker/jobs/data_management/views.py @@ -0,0 +1,34 @@ +"""Database materialized view refresh jobs. + +This module contains jobs for refreshing materialized views used throughout +the MaveDB application. Materialized views provide optimized, pre-computed +data for complex queries and are refreshed periodically to maintain +data consistency and performance. +""" + +import logging + +from mavedb.db.view import refresh_all_mat_views +from mavedb.models.published_variant import PublishedVariantsMV +from mavedb.worker.jobs.utils.job_state import setup_job_state + +logger = logging.getLogger(__name__) + + +# TODO#405: Refresh materialized views within an executor. +async def refresh_materialized_views(ctx: dict): + logging_context = setup_job_state(ctx, None, None, None) + logger.debug(msg="Began refresh materialized views.", extra=logging_context) + refresh_all_mat_views(ctx["db"]) + ctx["db"].commit() + logger.debug(msg="Done refreshing materialized views.", extra=logging_context) + return {"success": True} + + +async def refresh_published_variants_view(ctx: dict, correlation_id: str): + logging_context = setup_job_state(ctx, None, None, correlation_id) + logger.debug(msg="Began refresh of published variants materialized view.", extra=logging_context) + PublishedVariantsMV.refresh(ctx["db"]) + ctx["db"].commit() + logger.debug(msg="Done refreshing published variants materialized view.", extra=logging_context) + return {"success": True} diff --git a/src/mavedb/worker/jobs/external_services/__init__.py b/src/mavedb/worker/jobs/external_services/__init__.py new file mode 100644 index 00000000..60135efe --- /dev/null +++ b/src/mavedb/worker/jobs/external_services/__init__.py @@ -0,0 +1,28 @@ +"""External service integration job functions. + +This module exports jobs for integrating with third-party services: +- ClinGen (Clinical Genome Resource) for allele registration and data submission +- UniProt for protein sequence annotation and ID mapping +- gnomAD for population frequency and genomic context data +""" + +# External services job functions +from .clingen import ( + link_clingen_variants, + submit_score_set_mappings_to_car, + submit_score_set_mappings_to_ldh, +) +from .gnomad import link_gnomad_variants +from .uniprot import ( + poll_uniprot_mapping_jobs_for_score_set, + submit_uniprot_mapping_jobs_for_score_set, +) + +__all__ = [ + "link_clingen_variants", + "submit_score_set_mappings_to_car", + "submit_score_set_mappings_to_ldh", + "link_gnomad_variants", + "poll_uniprot_mapping_jobs_for_score_set", + "submit_uniprot_mapping_jobs_for_score_set", +] diff --git a/src/mavedb/worker/jobs/external_services/clingen.py b/src/mavedb/worker/jobs/external_services/clingen.py new file mode 100644 index 00000000..06a7c53d --- /dev/null +++ b/src/mavedb/worker/jobs/external_services/clingen.py @@ -0,0 +1,637 @@ +"""ClinGen integration jobs for variant submission and linking. + +This module contains jobs for submitting mapped variants to ClinGen services: +- ClinGen Allele Registry (CAR) for allele registration +- ClinGen Linked Data Hub (LDH) for data submission +- Variant linking and association management + +These jobs enable integration with the ClinGen ecosystem for clinical +variant interpretation and data sharing. +""" + +import asyncio +import functools +import logging +from datetime import timedelta + +from arq import ArqRedis +from sqlalchemy import select +from sqlalchemy.orm import Session + +from mavedb.lib.clingen.constants import ( + CAR_SUBMISSION_ENDPOINT, + DEFAULT_LDH_SUBMISSION_BATCH_SIZE, + LDH_SUBMISSION_ENDPOINT, + LINKED_DATA_RETRY_THRESHOLD, +) +from mavedb.lib.clingen.content_constructors import construct_ldh_submission +from mavedb.lib.clingen.services import ( + ClinGenAlleleRegistryService, + ClinGenLdhService, + clingen_allele_id_from_ldh_variation, + get_allele_registry_associations, + get_clingen_variation, +) +from mavedb.lib.exceptions import LinkingEnqueueError, SubmissionEnqueueError +from mavedb.lib.logging.context import format_raised_exception_info_as_dict +from mavedb.lib.slack import send_slack_error, send_slack_message +from mavedb.lib.variants import get_hgvs_from_post_mapped +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.score_set import ScoreSet +from mavedb.models.variant import Variant +from mavedb.worker.jobs.utils.constants import ENQUEUE_BACKOFF_ATTEMPT_LIMIT, LINKING_BACKOFF_IN_SECONDS +from mavedb.worker.jobs.utils.job_state import setup_job_state +from mavedb.worker.jobs.utils.retry import enqueue_job_with_backoff + +logger = logging.getLogger(__name__) + + +async def submit_score_set_mappings_to_car(ctx: dict, correlation_id: str, score_set_id: int): + logging_context = {} + score_set = None + text = "Could not submit mappings to ClinGen Allele Registry for score set %s. Mappings for this score set should be submitted manually." + try: + db: Session = ctx["db"] + redis: ArqRedis = ctx["redis"] + score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() + + logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) + logger.info(msg="Started CAR mapped resource submission", extra=logging_context) + + submission_urn = score_set.urn + assert submission_urn, "A valid URN is needed to submit CAR objects for this score set." + + logging_context["current_car_submission_resource"] = submission_urn + logger.debug(msg="Fetched score set metadata for CAR mapped resource submission.", extra=logging_context) + + except Exception as e: + send_slack_error(e) + if score_set: + send_slack_message(text=text % score_set.urn) + else: + send_slack_message(text=text % score_set_id) + + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="CAR mapped resource submission encountered an unexpected error during setup. This job will not be retried.", + extra=logging_context, + ) + + return {"success": False, "retried": False, "enqueued_job": None} + + try: + variant_post_mapped_objects = db.execute( + select(MappedVariant.id, MappedVariant.post_mapped) + .join(Variant) + .join(ScoreSet) + .where(ScoreSet.urn == score_set.urn) + .where(MappedVariant.post_mapped.is_not(None)) + .where(MappedVariant.current.is_(True)) + ).all() + + if not variant_post_mapped_objects: + logger.warning( + msg="No current mapped variants with post mapped metadata were found for this score set. Skipping CAR submission.", + extra=logging_context, + ) + return {"success": True, "retried": False, "enqueued_job": None} + + variant_post_mapped_hgvs: dict[str, list[int]] = {} + for mapped_variant_id, post_mapped in variant_post_mapped_objects: + hgvs_for_post_mapped = get_hgvs_from_post_mapped(post_mapped) + + if not hgvs_for_post_mapped: + logger.warning( + msg=f"Could not construct a valid HGVS string for mapped variant {mapped_variant_id}. Skipping submission of this variant.", + extra=logging_context, + ) + continue + + if hgvs_for_post_mapped in variant_post_mapped_hgvs: + variant_post_mapped_hgvs[hgvs_for_post_mapped].append(mapped_variant_id) + else: + variant_post_mapped_hgvs[hgvs_for_post_mapped] = [mapped_variant_id] + + except Exception as e: + send_slack_error(e) + send_slack_message(text=text % score_set.urn) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="LDH mapped resource submission encountered an unexpected error while attempting to construct post mapped HGVS strings. This job will not be retried.", + extra=logging_context, + ) + + return {"success": False, "retried": False, "enqueued_job": None} + + try: + if not CAR_SUBMISSION_ENDPOINT: + logger.warning( + msg="ClinGen Allele Registry submission is disabled (no submission endpoint), skipping submission of mapped variants to CAR.", + extra=logging_context, + ) + return {"success": False, "retried": False, "enqueued_job": None} + + car_service = ClinGenAlleleRegistryService(url=CAR_SUBMISSION_ENDPOINT) + registered_alleles = car_service.dispatch_submissions(list(variant_post_mapped_hgvs.keys())) + except Exception as e: + send_slack_error(e) + send_slack_message(text=text % score_set.urn) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="LDH mapped resource submission encountered an unexpected error while attempting to authenticate to the LDH. This job will not be retried.", + extra=logging_context, + ) + + return {"success": False, "retried": False, "enqueued_job": None} + + try: + linked_alleles = get_allele_registry_associations(list(variant_post_mapped_hgvs.keys()), registered_alleles) + for hgvs_string, caid in linked_alleles.items(): + mapped_variant_ids = variant_post_mapped_hgvs[hgvs_string] + mapped_variants = db.scalars(select(MappedVariant).where(MappedVariant.id.in_(mapped_variant_ids))).all() + + for mapped_variant in mapped_variants: + mapped_variant.clingen_allele_id = caid + db.add(mapped_variant) + + db.commit() + + except Exception as e: + send_slack_error(e) + send_slack_message(text=text % score_set.urn) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="LDH mapped resource submission encountered an unexpected error while attempting to authenticate to the LDH. This job will not be retried.", + extra=logging_context, + ) + + return {"success": False, "retried": False, "enqueued_job": None} + + new_job_id = None + try: + new_job = await redis.enqueue_job( + "submit_score_set_mappings_to_ldh", + correlation_id, + score_set.id, + ) + + if new_job: + new_job_id = new_job.job_id + + logging_context["submit_clingen_ldh_variants_job_id"] = new_job_id + logger.info(msg="Queued a new ClinGen submission job.", extra=logging_context) + + else: + raise SubmissionEnqueueError() + + except Exception as e: + send_slack_error(e) + send_slack_message( + f"Could not submit mappings to LDH for score set {score_set.urn}. Mappings for this score set should be submitted manually." + ) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="Mapped variant ClinGen submission encountered an unexpected error while attempting to enqueue a submission job. This job will not be retried.", + extra=logging_context, + ) + + return {"success": False, "retried": False, "enqueued_job": new_job_id} + + ctx["state"][ctx["job_id"]] = logging_context.copy() + return {"success": True, "retried": False, "enqueued_job": new_job_id} + + +async def submit_score_set_mappings_to_ldh(ctx: dict, correlation_id: str, score_set_id: int): + logging_context = {} + score_set = None + text = ( + "Could not submit mappings to LDH for score set %s. Mappings for this score set should be submitted manually." + ) + try: + db: Session = ctx["db"] + redis: ArqRedis = ctx["redis"] + score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() + + logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) + logger.info(msg="Started LDH mapped resource submission", extra=logging_context) + + submission_urn = score_set.urn + assert submission_urn, "A valid URN is needed to submit LDH objects for this score set." + + logging_context["current_ldh_submission_resource"] = submission_urn + logger.debug(msg="Fetched score set metadata for ldh mapped resource submission.", extra=logging_context) + + except Exception as e: + send_slack_error(e) + if score_set: + send_slack_message(text=text % score_set.urn) + else: + send_slack_message(text=text % score_set_id) + + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="LDH mapped resource submission encountered an unexpected error during setup. This job will not be retried.", + extra=logging_context, + ) + + return {"success": False, "retried": False, "enqueued_job": None} + + try: + ldh_service = ClinGenLdhService(url=LDH_SUBMISSION_ENDPOINT) + ldh_service.authenticate() + except Exception as e: + send_slack_error(e) + send_slack_message(text=text % score_set.urn) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="LDH mapped resource submission encountered an unexpected error while attempting to authenticate to the LDH. This job will not be retried.", + extra=logging_context, + ) + + return {"success": False, "retried": False, "enqueued_job": None} + + try: + variant_objects = db.execute( + select(Variant, MappedVariant) + .join(MappedVariant) + .join(ScoreSet) + .where(ScoreSet.urn == score_set.urn) + .where(MappedVariant.post_mapped.is_not(None)) + .where(MappedVariant.current.is_(True)) + ).all() + + if not variant_objects: + logger.warning( + msg="No current mapped variants with post mapped metadata were found for this score set. Skipping LDH submission.", + extra=logging_context, + ) + return {"success": True, "retried": False, "enqueued_job": None} + + variant_content = [] + for variant, mapped_variant in variant_objects: + variation = get_hgvs_from_post_mapped(mapped_variant.post_mapped) + + if not variation: + logger.warning( + msg=f"Could not construct a valid HGVS string for mapped variant {mapped_variant.id}. Skipping submission of this variant.", + extra=logging_context, + ) + continue + + variant_content.append((variation, variant, mapped_variant)) + + submission_content = construct_ldh_submission(variant_content) + + except Exception as e: + send_slack_error(e) + send_slack_message(text=text % score_set.urn) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="LDH mapped resource submission encountered an unexpected error while attempting to construct submission objects. This job will not be retried.", + extra=logging_context, + ) + + return {"success": False, "retried": False, "enqueued_job": None} + + try: + blocking = functools.partial( + ldh_service.dispatch_submissions, submission_content, DEFAULT_LDH_SUBMISSION_BATCH_SIZE + ) + loop = asyncio.get_running_loop() + submission_successes, submission_failures = await loop.run_in_executor(ctx["pool"], blocking) + + except Exception as e: + send_slack_error(e) + send_slack_message(text=text % score_set.urn) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="LDH mapped resource submission encountered an unexpected error while dispatching submissions. This job will not be retried.", + extra=logging_context, + ) + + return {"success": False, "retried": False, "enqueued_job": None} + + try: + assert not submission_failures, f"{len(submission_failures)} submissions failed to be dispatched to the LDH." + logger.info(msg="Dispatched all variant mapping submissions to the LDH.", extra=logging_context) + except AssertionError as e: + send_slack_error(e) + send_slack_message( + text=f"{len(submission_failures)} submissions failed to be dispatched to the LDH for score set {score_set.urn}." + ) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="LDH mapped resource submission failed to submit all mapping resources. This job will not be retried.", + extra=logging_context, + ) + + return {"success": False, "retried": False, "enqueued_job": None} + + new_job_id = None + try: + new_job = await redis.enqueue_job( + "link_clingen_variants", + correlation_id, + score_set.id, + 1, + _defer_by=timedelta(seconds=LINKING_BACKOFF_IN_SECONDS), + ) + + if new_job: + new_job_id = new_job.job_id + + logging_context["link_clingen_variants_job_id"] = new_job_id + logger.info(msg="Queued a new ClinGen linking job.", extra=logging_context) + + else: + raise LinkingEnqueueError() + + except Exception as e: + send_slack_error(e) + send_slack_message(text=text % score_set.urn) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="LDH mapped resource submission encountered an unexpected error while attempting to enqueue a linking job. This job will not be retried.", + extra=logging_context, + ) + + return {"success": False, "retried": False, "enqueued_job": new_job_id} + + return {"success": True, "retried": False, "enqueued_job": new_job_id} + + +def do_clingen_fetch(variant_urns): + return [(variant_urn, get_clingen_variation(variant_urn)) for variant_urn in variant_urns] + + +async def link_clingen_variants(ctx: dict, correlation_id: str, score_set_id: int, attempt: int) -> dict: + logging_context = {} + score_set = None + text = "Could not link mappings to LDH for score set %s. Mappings for this score set should be linked manually." + try: + db: Session = ctx["db"] + redis: ArqRedis = ctx["redis"] + score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() + + logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) + logging_context["linkage_retry_threshold"] = LINKED_DATA_RETRY_THRESHOLD + logging_context["attempt"] = attempt + logging_context["max_attempts"] = ENQUEUE_BACKOFF_ATTEMPT_LIMIT + logger.info(msg="Started LDH mapped resource linkage", extra=logging_context) + + submission_urn = score_set.urn + assert submission_urn, "A valid URN is needed to link LDH objects for this score set." + + logging_context["current_ldh_linking_resource"] = submission_urn + logger.debug(msg="Fetched score set metadata for ldh mapped resource linkage.", extra=logging_context) + + except Exception as e: + send_slack_error(e) + if score_set: + send_slack_message(text=text % score_set.urn) + else: + send_slack_message(text=text % score_set_id) + + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="LDH mapped resource linkage encountered an unexpected error during setup. This job will not be retried.", + extra=logging_context, + ) + + return {"success": False, "retried": False, "enqueued_job": None} + + try: + variant_urns = db.scalars( + select(Variant.urn) + .join(MappedVariant) + .join(ScoreSet) + .where( + ScoreSet.urn == score_set.urn, MappedVariant.current.is_(True), MappedVariant.post_mapped.is_not(None) + ) + ).all() + num_variant_urns = len(variant_urns) + + logging_context["variants_to_link_ldh"] = num_variant_urns + + if not variant_urns: + logger.warning( + msg="No current mapped variants with post mapped metadata were found for this score set. Skipping LDH linkage (nothing to do). A gnomAD linkage job will not be enqueued, as no variants will have a CAID.", + extra=logging_context, + ) + + return {"success": True, "retried": False, "enqueued_job": None} + + logger.info( + msg="Found current mapped variants with post mapped metadata for this score set. Attempting to link them to LDH submissions.", + extra=logging_context, + ) + + except Exception as e: + send_slack_error(e) + send_slack_message(text=text % score_set.urn) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="LDH mapped resource linkage encountered an unexpected error while attempting to build linkage urn list. This job will not be retried.", + extra=logging_context, + ) + + return {"success": False, "retried": False, "enqueued_job": None} + + try: + logger.info(msg="Attempting to link mapped variants to LDH submissions.", extra=logging_context) + + # TODO#372: Non-nullable variant urns. + blocking = functools.partial( + do_clingen_fetch, + variant_urns, # type: ignore + ) + loop = asyncio.get_running_loop() + linked_data = await loop.run_in_executor(ctx["pool"], blocking) + + except Exception as e: + send_slack_error(e) + send_slack_message(text=text % score_set.urn) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="LDH mapped resource linkage encountered an unexpected error while attempting to link LDH submissions. This job will not be retried.", + extra=logging_context, + ) + + return {"success": False, "retried": False, "enqueued_job": None} + + try: + linked_allele_ids = [ + (variant_urn, clingen_allele_id_from_ldh_variation(clingen_variation)) + for variant_urn, clingen_variation in linked_data + ] + + linkage_failures = [] + for variant_urn, ldh_variation in linked_allele_ids: + # XXX: Should we unlink variation if it is not found? Does this constitute a failure? + if not ldh_variation: + logger.warning( + msg=f"Failed to link mapped variant {variant_urn} to LDH submission. No LDH variation found.", + extra=logging_context, + ) + linkage_failures.append(variant_urn) + continue + + mapped_variant = db.scalars( + select(MappedVariant).join(Variant).where(Variant.urn == variant_urn, MappedVariant.current.is_(True)) + ).one_or_none() + + if not mapped_variant: + logger.warning( + msg=f"Failed to link mapped variant {variant_urn} to LDH submission. No mapped variant found.", + extra=logging_context, + ) + linkage_failures.append(variant_urn) + continue + + mapped_variant.clingen_allele_id = ldh_variation + db.add(mapped_variant) + + db.commit() + + except Exception as e: + db.rollback() + + send_slack_error(e) + send_slack_message(text=text % score_set.urn) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="LDH mapped resource linkage encountered an unexpected error while attempting to link LDH submissions. This job will not be retried.", + extra=logging_context, + ) + + return {"success": False, "retried": False, "enqueued_job": None} + + try: + num_linkage_failures = len(linkage_failures) + ratio_failed_linking = round(num_linkage_failures / num_variant_urns, 3) + logging_context["linkage_failure_rate"] = ratio_failed_linking + logging_context["linkage_failures"] = num_linkage_failures + logging_context["linkage_successes"] = num_variant_urns - num_linkage_failures + + assert ( + len(linked_allele_ids) == num_variant_urns + ), f"{num_variant_urns - len(linked_allele_ids)} appear to not have been attempted to be linked." + + job_succeeded = False + if not linkage_failures: + logger.info( + msg="Successfully linked all mapped variants to LDH submissions.", + extra=logging_context, + ) + + job_succeeded = True + + elif ratio_failed_linking < LINKED_DATA_RETRY_THRESHOLD: + logger.warning( + msg="Linkage failures exist, but did not exceed the retry threshold.", + extra=logging_context, + ) + send_slack_message( + text=f"Failed to link {len(linkage_failures)} mapped variants to LDH submissions for score set {score_set.urn}." + f"The retry threshold was not exceeded and this job will not be retried. URNs failed to link: {', '.join(linkage_failures)}." + ) + + job_succeeded = True + + except Exception as e: + send_slack_error(e) + send_slack_message(text=text % score_set.urn) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="LDH mapped resource linkage encountered an unexpected error while attempting to finalize linkage. This job will not be retried.", + extra=logging_context, + ) + + return {"success": False, "retried": False, "enqueued_job": None} + + if job_succeeded: + gnomad_linking_job_id = None + try: + new_job = await redis.enqueue_job( + "link_gnomad_variants", + correlation_id, + score_set.id, + ) + + if new_job: + gnomad_linking_job_id = new_job.job_id + + logging_context["link_gnomad_variants_job_id"] = gnomad_linking_job_id + logger.info(msg="Queued a new gnomAD linking job.", extra=logging_context) + + else: + raise LinkingEnqueueError() + + except Exception as e: + job_succeeded = False + + send_slack_error(e) + send_slack_message(text=text % score_set.urn) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="LDH mapped resource linkage encountered an unexpected error while attempting to enqueue a gnomAD linking job. GnomAD variants should be linked manually for this score set. This job will not be retried.", + extra=logging_context, + ) + finally: + return {"success": job_succeeded, "retried": False, "enqueued_job": gnomad_linking_job_id} + + # If we reach this point, we should consider the job failed (there were failures which exceeded our retry threshold). + new_job_id = None + max_retries_exceeded = None + try: + new_job_id, max_retries_exceeded, backoff_time = await enqueue_job_with_backoff( + ctx["redis"], "variant_mapper_manager", attempt, LINKING_BACKOFF_IN_SECONDS, correlation_id + ) + + logging_context["backoff_limit_exceeded"] = max_retries_exceeded + logging_context["backoff_deferred_in_seconds"] = backoff_time + logging_context["backoff_job_id"] = new_job_id + + except Exception as e: + send_slack_error(e) + send_slack_message(text=text % score_set.urn) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.critical( + msg="LDH mapped resource linkage encountered an unexpected error while attempting to retry a failed linkage job. This job will not be retried.", + extra=logging_context, + ) + else: + if new_job_id and not max_retries_exceeded: + logger.info( + msg="After a failure condition while linking mapped variants to LDH submissions, another linkage job was queued.", + extra=logging_context, + ) + send_slack_message( + text=f"Failed to link {len(linkage_failures)} ({ratio_failed_linking * 100}% of total mapped variants for {score_set.urn})." + f"This job was successfully retried. This was attempt {attempt}. Retry will occur in {backoff_time} seconds. URNs failed to link: {', '.join(linkage_failures)}." + ) + elif new_job_id is None and not max_retries_exceeded: + logger.error( + msg="After a failure condition while linking mapped variants to LDH submissions, another linkage job was unable to be queued.", + extra=logging_context, + ) + send_slack_message( + text=f"Failed to link {len(linkage_failures)} ({ratio_failed_linking} of total mapped variants for {score_set.urn})." + f"This job could not be retried due to an unexpected issue while attempting to enqueue another linkage job. This was attempt {attempt}. URNs failed to link: {', '.join(linkage_failures)}." + ) + else: + logger.error( + msg="After a failure condition while linking mapped variants to LDH submissions, the maximum retries for this job were exceeded. The reamining linkage failures will not be retried.", + extra=logging_context, + ) + send_slack_message( + text=f"Failed to link {len(linkage_failures)} ({ratio_failed_linking} of total mapped variants for {score_set.urn})." + f"The retry threshold was exceeded and this job will not be retried. URNs failed to link: {', '.join(linkage_failures)}." + ) + + finally: + return { + "success": False, + "retried": (not max_retries_exceeded and new_job_id is not None), + "enqueued_job": new_job_id, + } diff --git a/src/mavedb/worker/jobs/external_services/gnomad.py b/src/mavedb/worker/jobs/external_services/gnomad.py new file mode 100644 index 00000000..66be8fd9 --- /dev/null +++ b/src/mavedb/worker/jobs/external_services/gnomad.py @@ -0,0 +1,140 @@ +"""gnomAD variant linking jobs for population frequency annotation. + +This module handles linking of mapped variants to gnomAD (Genome Aggregation Database) +variants to provide population frequency and other genomic context information. +This enrichment helps researchers understand the clinical significance and +rarity of variants in their datasets. +""" + +import logging +from typing import Sequence + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from mavedb.lib.gnomad import gnomad_variant_data_for_caids, link_gnomad_variants_to_mapped_variants +from mavedb.lib.logging.context import format_raised_exception_info_as_dict +from mavedb.lib.slack import send_slack_error, send_slack_message +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.score_set import ScoreSet +from mavedb.models.variant import Variant +from mavedb.worker.jobs.utils.job_state import setup_job_state + +logger = logging.getLogger(__name__) + + +async def link_gnomad_variants(ctx: dict, correlation_id: str, score_set_id: int) -> dict: + logging_context = {} + score_set = None + text = "Could not link mappings to gnomAD variants for score set %s. Mappings for this score set should be linked manually." + try: + db: Session = ctx["db"] + score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() + + logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) + logger.info(msg="Started gnomAD variant linkage", extra=logging_context) + + submission_urn = score_set.urn + assert submission_urn, "A valid URN is needed to link gnomAD objects for this score set." + + logging_context["current_gnomad_linking_resource"] = submission_urn + logger.debug(msg="Fetched score set metadata for gnomAD mapped resource linkage.", extra=logging_context) + + except Exception as e: + send_slack_error(e) + if score_set: + send_slack_message(text=text % score_set.urn) + else: + send_slack_message(text=text % score_set_id) + + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="LDH mapped resource linkage encountered an unexpected error during setup. This job will not be retried.", + extra=logging_context, + ) + + return {"success": False, "retried": False, "enqueued_job": None} + + try: + # We filter out mapped variants that do not have a CAID, so this query is typed # as a Sequence[str]. Ignore MyPy's type checking here. + variant_caids: Sequence[str] = db.scalars( + select(MappedVariant.clingen_allele_id) + .join(Variant) + .join(ScoreSet) + .where( + ScoreSet.urn == score_set.urn, + MappedVariant.current.is_(True), + MappedVariant.clingen_allele_id.is_not(None), + ) + ).all() # type: ignore + num_variant_caids = len(variant_caids) + + logging_context["num_variants_to_link_gnomad"] = num_variant_caids + + if not variant_caids: + logger.warning( + msg="No current mapped variants with CAIDs were found for this score set. Skipping gnomAD linkage (nothing to do).", + extra=logging_context, + ) + + return {"success": True, "retried": False, "enqueued_job": None} + + logger.info( + msg="Found current mapped variants with CAIDs for this score set. Attempting to link them to gnomAD variants.", + extra=logging_context, + ) + + except Exception as e: + send_slack_error(e) + send_slack_message(text=text % score_set.urn) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="gnomAD mapped resource linkage encountered an unexpected error while attempting to build linkage urn list. This job will not be retried.", + extra=logging_context, + ) + + return {"success": False, "retried": False, "enqueued_job": None} + + try: + gnomad_variant_data = gnomad_variant_data_for_caids(variant_caids) + num_gnomad_variants_with_caid_match = len(gnomad_variant_data) + logging_context["num_gnomad_variants_with_caid_match"] = num_gnomad_variants_with_caid_match + + if not gnomad_variant_data: + logger.warning( + msg="No gnomAD variants with CAID matches were found for this score set. Skipping gnomAD linkage (nothing to do).", + extra=logging_context, + ) + + return {"success": True, "retried": False, "enqueued_job": None} + + except Exception as e: + send_slack_error(e) + send_slack_message(text=text % score_set.urn) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="gnomAD mapped resource linkage encountered an unexpected error while attempting to fetch gnomAD variant data from S3 via Athena. This job will not be retried.", + extra=logging_context, + ) + + return {"success": False, "retried": False, "enqueued_job": None} + + try: + logger.info(msg="Attempting to link mapped variants to gnomAD variants.", extra=logging_context) + num_linked_gnomad_variants = link_gnomad_variants_to_mapped_variants(db, gnomad_variant_data) + db.commit() + logging_context["num_mapped_variants_linked_to_gnomad_variants"] = num_linked_gnomad_variants + + except Exception as e: + send_slack_error(e) + send_slack_message(text=text % score_set.urn) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="LDH mapped resource linkage encountered an unexpected error while attempting to link LDH submissions. This job will not be retried.", + extra=logging_context, + ) + + return {"success": False, "retried": False, "enqueued_job": None} + + logger.info(msg="Done linking gnomAD variants to mapped variants.", extra=logging_context) + return {"success": True, "retried": False, "enqueued_job": None} diff --git a/src/mavedb/worker/jobs/external_services/uniprot.py b/src/mavedb/worker/jobs/external_services/uniprot.py new file mode 100644 index 00000000..a72cf9e2 --- /dev/null +++ b/src/mavedb/worker/jobs/external_services/uniprot.py @@ -0,0 +1,230 @@ +"""UniProt ID mapping jobs for protein sequence annotation. + +This module handles the submission and polling of UniProt ID mapping jobs +to enrich target gene metadata with UniProt identifiers. This enables +linking of genomic variants to protein-level functional information. + +The mapping process is asynchronous, requiring both submission and polling +jobs to handle the UniProt API's batch processing workflow. +""" + +import logging +from typing import Optional + +from arq import ArqRedis +from sqlalchemy import select +from sqlalchemy.orm import Session + +from mavedb.lib.exceptions import UniProtPollingEnqueueError +from mavedb.lib.logging.context import format_raised_exception_info_as_dict +from mavedb.lib.mapping import extract_ids_from_post_mapped_metadata +from mavedb.lib.slack import log_and_send_slack_message, send_slack_error +from mavedb.lib.uniprot.id_mapping import UniProtIDMappingAPI +from mavedb.lib.uniprot.utils import infer_db_name_from_sequence_accession +from mavedb.models.score_set import ScoreSet +from mavedb.worker.jobs.utils.job_state import setup_job_state + +logger = logging.getLogger(__name__) + + +async def submit_uniprot_mapping_jobs_for_score_set(ctx, score_set_id: int, correlation_id: Optional[str] = None): + logging_context = {} + score_set = None + spawned_mapping_jobs: dict[int, Optional[str]] = {} + text = "Could not submit mapping jobs to UniProt for this score set %s. Mapping jobs for this score set should be submitted manually." + try: + db: Session = ctx["db"] + redis: ArqRedis = ctx["redis"] + score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() + logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) + logger.info(msg="Started UniProt mapping job", extra=logging_context) + + if not score_set or not score_set.target_genes: + msg = f"No target genes for score set {score_set_id}. Skipped mapping targets to UniProt." + log_and_send_slack_message(msg=msg, ctx=logging_context, level=logging.WARNING) + + return {"success": True, "retried": False, "enqueued_jobs": []} + + except Exception as e: + send_slack_error(e) + if score_set: + msg = text % score_set.urn + else: + msg = text % score_set_id + + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + log_and_send_slack_message(msg=msg, ctx=logging_context, level=logging.ERROR) + + return {"success": False, "retried": False, "enqueued_jobs": []} + + try: + uniprot_api = UniProtIDMappingAPI() + logging_context["total_target_genes_to_map_to_uniprot"] = len(score_set.target_genes) + for target_gene in score_set.target_genes: + spawned_mapping_jobs[target_gene.id] = None # type: ignore + + acs = extract_ids_from_post_mapped_metadata(target_gene.post_mapped_metadata) # type: ignore + if not acs: + msg = f"No accession IDs found in post_mapped_metadata for target gene {target_gene.id} in score set {score_set.urn}. This target will be skipped." + log_and_send_slack_message(msg, logging_context, logging.WARNING) + continue + + if len(acs) != 1: + msg = f"More than one accession ID is associated with target gene {target_gene.id} in score set {score_set.urn}. This target will be skipped." + log_and_send_slack_message(msg, logging_context, logging.WARNING) + continue + + ac_to_map = acs[0] + from_db = infer_db_name_from_sequence_accession(ac_to_map) + + try: + spawned_mapping_jobs[target_gene.id] = uniprot_api.submit_id_mapping(from_db, "UniProtKB", [ac_to_map]) # type: ignore + except Exception as e: + log_and_send_slack_message( + msg=f"Failed to submit UniProt mapping job for target gene {target_gene.id}: {e}. This target will be skipped.", + ctx=logging_context, + level=logging.WARNING, + ) + + except Exception as e: + send_slack_error(e) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + log_and_send_slack_message( + msg=f"UniProt mapping job encountered an unexpected error while attempting to submit mapping jobs for score set {score_set.urn}. This job will not be retried.", + ctx=logging_context, + level=logging.ERROR, + ) + + return {"success": False, "retried": False, "enqueued_jobs": []} + + new_job_id = None + try: + successfully_spawned_mapping_jobs = sum(1 for job in spawned_mapping_jobs.values() if job is not None) + logging_context["successfully_spawned_mapping_jobs"] = successfully_spawned_mapping_jobs + + if not successfully_spawned_mapping_jobs: + msg = f"No UniProt mapping jobs were successfully spawned for score set {score_set.urn}. Skipped enqueuing polling job." + log_and_send_slack_message(msg, logging_context, logging.WARNING) + return {"success": True, "retried": False, "enqueued_jobs": []} + + new_job = await redis.enqueue_job( + "poll_uniprot_mapping_jobs_for_score_set", + spawned_mapping_jobs, + score_set_id, + correlation_id, + ) + + if new_job: + new_job_id = new_job.job_id + + logging_context["poll_uniprot_mapping_job_id"] = new_job_id + logger.info(msg="Enqueued polling jobs for UniProt mapping jobs.", extra=logging_context) + + else: + raise UniProtPollingEnqueueError() + + except Exception as e: + send_slack_error(e) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + log_and_send_slack_message( + msg="UniProt mapping job encountered an unexpected error while attempting to enqueue polling jobs for mapping jobs. This job will not be retried.", + ctx=logging_context, + level=logging.ERROR, + ) + + return {"success": False, "retried": False, "enqueued_jobs": [job for job in [new_job_id] if job]} + + return {"success": True, "retried": False, "enqueued_jobs": [job for job in [new_job_id] if job]} + + +async def poll_uniprot_mapping_jobs_for_score_set( + ctx, mapping_jobs: dict[int, Optional[str]], score_set_id: int, correlation_id: Optional[str] = None +): + logging_context = {} + score_set = None + text = "Could not poll mapping jobs from UniProt for this Target %s. Mapping jobs for this score set should be submitted manually." + try: + db: Session = ctx["db"] + score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() + logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) + logger.info(msg="Started UniProt polling job", extra=logging_context) + + if not score_set or not score_set.target_genes: + msg = f"No target genes for score set {score_set_id}. Skipped polling targets for UniProt mapping results." + log_and_send_slack_message(msg=msg, ctx=logging_context, level=logging.WARNING) + + return {"success": True, "retried": False, "enqueued_jobs": []} + + except Exception as e: + send_slack_error(e) + if score_set: + msg = text % score_set.urn + else: + msg = text % score_set_id + + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + log_and_send_slack_message(msg=msg, ctx=logging_context, level=logging.ERROR) + + return {"success": False, "retried": False, "enqueued_jobs": []} + + try: + uniprot_api = UniProtIDMappingAPI() + for target_gene in score_set.target_genes: + acs = extract_ids_from_post_mapped_metadata(target_gene.post_mapped_metadata) # type: ignore + if not acs: + msg = f"No accession IDs found in post_mapped_metadata for target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target." + log_and_send_slack_message(msg, logging_context, logging.WARNING) + continue + + if len(acs) != 1: + msg = f"More than one accession ID is associated with target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target." + log_and_send_slack_message(msg, logging_context, logging.WARNING) + continue + + mapped_ac = acs[0] + job_id = mapping_jobs.get(target_gene.id) # type: ignore + + if not job_id: + msg = f"No job ID found for target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target." + # This issue has already been sent to Slack in the job submission function, so we just log it here. + logger.debug(msg=msg, extra=logging_context) + continue + + if not uniprot_api.check_id_mapping_results_ready(job_id): + msg = f"Job {job_id} not ready for target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target" + log_and_send_slack_message(msg, logging_context, logging.WARNING) + continue + + results = uniprot_api.get_id_mapping_results(job_id) + mapped_ids = uniprot_api.extract_uniprot_id_from_results(results) + + if not mapped_ids: + msg = f"No UniProt ID found for target gene {target_gene.id} in score set {score_set.urn}. Cannot add UniProt ID for this target." + log_and_send_slack_message(msg, logging_context, logging.WARNING) + continue + + if len(mapped_ids) != 1: + msg = f"Found ambiguous Uniprot ID mapping results for target gene {target_gene.id} in score set {score_set.urn}. Cannot add UniProt ID for this target." + log_and_send_slack_message(msg, logging_context, logging.WARNING) + continue + + mapped_uniprot_id = mapped_ids[0][mapped_ac]["uniprot_id"] + target_gene.uniprot_id_from_mapped_metadata = mapped_uniprot_id + db.add(target_gene) + logger.info( + msg=f"Updated target gene {target_gene.id} with UniProt ID {mapped_uniprot_id}", extra=logging_context + ) + + except Exception as e: + send_slack_error(e) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + log_and_send_slack_message( + msg="UniProt mapping job encountered an unexpected error while attempting to poll mapping jobs. This job will not be retried.", + ctx=logging_context, + level=logging.ERROR, + ) + + return {"success": False, "retried": False, "enqueued_jobs": []} + + db.commit() + return {"success": True, "retried": False, "enqueued_jobs": []} diff --git a/src/mavedb/worker/jobs/py.typed b/src/mavedb/worker/jobs/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/mavedb/worker/jobs/registry.py b/src/mavedb/worker/jobs/registry.py new file mode 100644 index 00000000..a79ed3fa --- /dev/null +++ b/src/mavedb/worker/jobs/registry.py @@ -0,0 +1,63 @@ +"""Job registry for worker configuration. + +This module provides a centralized registry of all available worker jobs +as simple lists for ARQ worker configuration. +""" + +from datetime import timedelta +from typing import Callable, List + +from arq.cron import CronJob, cron + +from mavedb.worker.jobs.data_management import ( + refresh_materialized_views, + refresh_published_variants_view, +) +from mavedb.worker.jobs.external_services import ( + link_clingen_variants, + link_gnomad_variants, + poll_uniprot_mapping_jobs_for_score_set, + submit_score_set_mappings_to_car, + submit_score_set_mappings_to_ldh, + submit_uniprot_mapping_jobs_for_score_set, +) +from mavedb.worker.jobs.variant_processing import ( + create_variants_for_score_set, + map_variants_for_score_set, + variant_mapper_manager, +) + +# All job functions for ARQ worker +BACKGROUND_FUNCTIONS: List[Callable] = [ + # Variant processing jobs + create_variants_for_score_set, + map_variants_for_score_set, + variant_mapper_manager, + # External service jobs + submit_score_set_mappings_to_car, + submit_score_set_mappings_to_ldh, + link_clingen_variants, + submit_uniprot_mapping_jobs_for_score_set, + poll_uniprot_mapping_jobs_for_score_set, + link_gnomad_variants, + # Data management jobs + refresh_materialized_views, + refresh_published_variants_view, +] + +# Cron job definitions for ARQ worker +BACKGROUND_CRONJOBS: List[CronJob] = [ + cron( + refresh_materialized_views, + name="refresh_all_materialized_views", + hour=20, + minute=0, + keep_result=timedelta(minutes=2).total_seconds(), + ), +] + + +__all__ = [ + "BACKGROUND_FUNCTIONS", + "BACKGROUND_CRONJOBS", +] diff --git a/src/mavedb/worker/jobs/utils/__init__.py b/src/mavedb/worker/jobs/utils/__init__.py new file mode 100644 index 00000000..a63687b8 --- /dev/null +++ b/src/mavedb/worker/jobs/utils/__init__.py @@ -0,0 +1,30 @@ +"""Worker job utility functions and constants. + +This module provides shared utilities used across worker jobs: +- Job state management and context setup +- Retry logic with exponential backoff +- Configuration constants for queues and timeouts + +These utilities help ensure consistent behavior and error handling +across all worker job implementations. +""" + +from .constants import ( + ENQUEUE_BACKOFF_ATTEMPT_LIMIT, + LINKING_BACKOFF_IN_SECONDS, + MAPPING_BACKOFF_IN_SECONDS, + MAPPING_CURRENT_ID_NAME, + MAPPING_QUEUE_NAME, +) +from .job_state import setup_job_state +from .retry import enqueue_job_with_backoff + +__all__ = [ + "setup_job_state", + "enqueue_job_with_backoff", + "MAPPING_QUEUE_NAME", + "MAPPING_CURRENT_ID_NAME", + "MAPPING_BACKOFF_IN_SECONDS", + "LINKING_BACKOFF_IN_SECONDS", + "ENQUEUE_BACKOFF_ATTEMPT_LIMIT", +] diff --git a/src/mavedb/worker/jobs/utils/constants.py b/src/mavedb/worker/jobs/utils/constants.py new file mode 100644 index 00000000..cca5a02c --- /dev/null +++ b/src/mavedb/worker/jobs/utils/constants.py @@ -0,0 +1,17 @@ +"""Constants used across worker jobs. + +This module centralizes configuration constants used by various worker jobs +including queue names, timeouts, and retry limits. This provides a single +source of truth for job configuration values. +""" + +### Mapping job constants +MAPPING_QUEUE_NAME = "vrs_mapping_queue" +MAPPING_CURRENT_ID_NAME = "vrs_mapping_current_job_id" +MAPPING_BACKOFF_IN_SECONDS = 15 + +### Linking job constants +LINKING_BACKOFF_IN_SECONDS = 15 * 60 + +### Backoff constants +ENQUEUE_BACKOFF_ATTEMPT_LIMIT = 5 diff --git a/src/mavedb/worker/jobs/utils/job_state.py b/src/mavedb/worker/jobs/utils/job_state.py new file mode 100644 index 00000000..33c6887b --- /dev/null +++ b/src/mavedb/worker/jobs/utils/job_state.py @@ -0,0 +1,35 @@ +"""Job state management utilities. + +This module provides utilities for managing job state and context across +the worker job lifecycle. It handles setup of logging context, correlation +IDs, and other state information needed for job traceability and monitoring. +""" + +import logging +from typing import Any, Optional + +logger = logging.getLogger(__name__) + + +def setup_job_state( + ctx, invoker: Optional[int], resource: Optional[str], correlation_id: Optional[str] +) -> dict[str, Any]: + """ + Initialize and store job state information in the context dictionary for traceability. + + Args: + ctx: The job context dictionary, must contain 'state' and 'job_id' keys. + invoker: The user ID or identifier who initiated the job (may be None). + resource: The resource string associated with the job (may be None). + correlation_id: Optional correlation ID for tracing requests across services. + + Returns: + dict[str, Any]: The job state dictionary for the current job_id. + """ + ctx["state"][ctx["job_id"]] = { + "application": "mavedb-worker", + "user": invoker, + "resource": resource, + "correlation_id": correlation_id, + } + return ctx["state"][ctx["job_id"]] diff --git a/src/mavedb/worker/jobs/utils/retry.py b/src/mavedb/worker/jobs/utils/retry.py new file mode 100644 index 00000000..5150d95b --- /dev/null +++ b/src/mavedb/worker/jobs/utils/retry.py @@ -0,0 +1,61 @@ +"""Retry and backoff utilities for job error handling. + +This module provides utilities for implementing exponential backoff and +retry logic for failed jobs. It helps ensure reliable job execution +by automatically retrying transient failures with appropriate delays. +""" + +import logging +from datetime import timedelta +from typing import Any, Optional + +from arq import ArqRedis + +from mavedb.worker.jobs.utils.constants import ENQUEUE_BACKOFF_ATTEMPT_LIMIT + +logger = logging.getLogger(__name__) + + +async def enqueue_job_with_backoff( + redis: ArqRedis, job_name: str, attempt: int, backoff: int, *args +) -> tuple[Optional[str], bool, Any]: + """ + Enqueue a job with exponential backoff and attempt tracking, for robust retry logic. + + Args: + redis (ArqRedis): The Redis connection for job queueing. + job_name (str): The name of the job to enqueue. + attempt (int): The current attempt number (used for backoff calculation). + backoff (int): The base backoff time in seconds. + *args: Additional arguments to pass to the job. + + Returns: + tuple[Optional[str], bool, Any]: + - The new job ID if enqueued, else None. + - Boolean indicating if the backoff limit was NOT reached (True if retry scheduled). + - The updated backoff value (seconds). + + Notes: + - If the attempt exceeds ENQUEUE_BACKOFF_ATTEMPT_LIMIT, no job is enqueued and limit is considered reached. + - The attempt value is incremented and passed as the last argument to the job. + - The job is deferred by the calculated backoff time. + """ + new_job_id = None + limit_reached = attempt > ENQUEUE_BACKOFF_ATTEMPT_LIMIT + if not limit_reached: + limit_reached = True + backoff = backoff * (2**attempt) + attempt = attempt + 1 + + # NOTE: for jobs supporting backoff, `attempt` should be the final argument. + new_job = await redis.enqueue_job( + job_name, + *args, + attempt, + _defer_by=timedelta(seconds=backoff), + ) + + if new_job: + new_job_id = new_job.job_id + + return (new_job_id, not limit_reached, backoff) diff --git a/src/mavedb/worker/jobs/variant_processing/__init__.py b/src/mavedb/worker/jobs/variant_processing/__init__.py new file mode 100644 index 00000000..b9085659 --- /dev/null +++ b/src/mavedb/worker/jobs/variant_processing/__init__.py @@ -0,0 +1,19 @@ +"""Variant processing job functions. + +This module exports jobs responsible for variant creation and mapping: +- Variant creation from uploaded score/count data +- VRS mapping to standardized genomic coordinates +- Queue management for mapping workflows +""" + +from .creation import create_variants_for_score_set +from .mapping import ( + map_variants_for_score_set, + variant_mapper_manager, +) + +__all__ = [ + "create_variants_for_score_set", + "map_variants_for_score_set", + "variant_mapper_manager", +] diff --git a/src/mavedb/worker/jobs/variant_processing/creation.py b/src/mavedb/worker/jobs/variant_processing/creation.py new file mode 100644 index 00000000..3064581b --- /dev/null +++ b/src/mavedb/worker/jobs/variant_processing/creation.py @@ -0,0 +1,196 @@ +"""Variant creation jobs for score sets. + +This module contains jobs responsible for creating and validating variants +from uploaded score and count data. It handles the full variant creation +pipeline including data validation, standardization, and database persistence. +""" + +import logging +from typing import Optional + +import pandas as pd +from arq import ArqRedis +from sqlalchemy import delete, null, select +from sqlalchemy.orm import Session + +from mavedb.data_providers.services import RESTDataProvider +from mavedb.lib.logging.context import format_raised_exception_info_as_dict +from mavedb.lib.score_sets import columns_for_dataset, create_variants, create_variants_data +from mavedb.lib.slack import send_slack_error +from mavedb.lib.validation.dataframe.dataframe import validate_and_standardize_dataframe_pair +from mavedb.lib.validation.exceptions import ValidationError +from mavedb.models.enums.mapping_state import MappingState +from mavedb.models.enums.processing_state import ProcessingState +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.score_set import ScoreSet +from mavedb.models.user import User +from mavedb.models.variant import Variant +from mavedb.view_models.score_set_dataset_columns import DatasetColumnMetadata +from mavedb.worker.jobs.utils.constants import MAPPING_QUEUE_NAME +from mavedb.worker.jobs.utils.job_state import setup_job_state + +logger = logging.getLogger(__name__) + + +async def create_variants_for_score_set( + ctx, + correlation_id: str, + score_set_id: int, + updater_id: int, + scores: pd.DataFrame, + counts: pd.DataFrame, + score_columns_metadata: Optional[dict[str, DatasetColumnMetadata]] = None, + count_columns_metadata: Optional[dict[str, DatasetColumnMetadata]] = None, +): + """ + Create variants for a score set. Intended to be run within a worker. + On any raised exception, ensure ProcessingState of score set is set to `failed` prior + to exiting. + """ + logging_context = {} + try: + db: Session = ctx["db"] + hdp: RESTDataProvider = ctx["hdp"] + redis: ArqRedis = ctx["redis"] + score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() + + logging_context = setup_job_state(ctx, updater_id, score_set.urn, correlation_id) + logger.info(msg="Began processing of score set variants.", extra=logging_context) + + updated_by = db.scalars(select(User).where(User.id == updater_id)).one() + + score_set.modified_by = updated_by + score_set.processing_state = ProcessingState.processing + score_set.mapping_state = MappingState.pending_variant_processing + logging_context["processing_state"] = score_set.processing_state.name + logging_context["mapping_state"] = score_set.mapping_state.name + + db.add(score_set) + db.commit() + db.refresh(score_set) + + if not score_set.target_genes: + logger.warning( + msg="No targets are associated with this score set; could not create variants.", + extra=logging_context, + ) + raise ValueError("Can't create variants when score set has no targets.") + + validated_scores, validated_counts, validated_score_columns_metadata, validated_count_columns_metadata = ( + validate_and_standardize_dataframe_pair( + scores_df=scores, + counts_df=counts, + score_columns_metadata=score_columns_metadata, + count_columns_metadata=count_columns_metadata, + targets=score_set.target_genes, + hdp=hdp, + ) + ) + + score_set.dataset_columns = { + "score_columns": columns_for_dataset(validated_scores), + "count_columns": columns_for_dataset(validated_counts), + "score_columns_metadata": validated_score_columns_metadata + if validated_score_columns_metadata is not None + else {}, + "count_columns_metadata": validated_count_columns_metadata + if validated_count_columns_metadata is not None + else {}, + } + + # Delete variants after validation occurs so we don't overwrite them in the case of a bad update. + if score_set.variants: + existing_variants = db.scalars(select(Variant.id).where(Variant.score_set_id == score_set.id)).all() + db.execute(delete(MappedVariant).where(MappedVariant.variant_id.in_(existing_variants))) + db.execute(delete(Variant).where(Variant.id.in_(existing_variants))) + logging_context["deleted_variants"] = score_set.num_variants + score_set.num_variants = 0 + + logger.info(msg="Deleted existing variants from score set.", extra=logging_context) + + db.flush() + db.refresh(score_set) + + variants_data = create_variants_data(validated_scores, validated_counts, None) + create_variants(db, score_set, variants_data) + + # Validation errors arise from problematic user data. These should be inserted into the database so failures can + # be persisted to them. + except ValidationError as e: + db.rollback() + score_set.processing_state = ProcessingState.failed + score_set.processing_errors = {"exception": str(e), "detail": e.triggering_exceptions} + score_set.mapping_state = MappingState.not_attempted + + if score_set.num_variants: + score_set.processing_errors["exception"] = ( + f"Update failed, variants were not updated. {score_set.processing_errors.get('exception', '')}" + ) + + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logging_context["processing_state"] = score_set.processing_state.name + logging_context["mapping_state"] = score_set.mapping_state.name + logging_context["created_variants"] = 0 + logger.warning(msg="Encountered a validation error while processing variants.", extra=logging_context) + + return {"success": False} + + # NOTE: Since these are likely to be internal errors, it makes less sense to add them to the DB and surface them to the end user. + # Catch all non-system exiting exceptions. + except Exception as e: + db.rollback() + score_set.processing_state = ProcessingState.failed + score_set.processing_errors = {"exception": str(e), "detail": []} + score_set.mapping_state = MappingState.not_attempted + + if score_set.num_variants: + score_set.processing_errors["exception"] = ( + f"Update failed, variants were not updated. {score_set.processing_errors.get('exception', '')}" + ) + + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logging_context["processing_state"] = score_set.processing_state.name + logging_context["mapping_state"] = score_set.mapping_state.name + logging_context["created_variants"] = 0 + logger.warning(msg="Encountered an internal exception while processing variants.", extra=logging_context) + + send_slack_error(err=e) + return {"success": False} + + # Catch all other exceptions. The exceptions caught here were intented to be system exiting. + except BaseException as e: + db.rollback() + score_set.processing_state = ProcessingState.failed + score_set.mapping_state = MappingState.not_attempted + db.commit() + + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logging_context["processing_state"] = score_set.processing_state.name + logging_context["mapping_state"] = score_set.mapping_state.name + logging_context["created_variants"] = 0 + logger.error( + msg="Encountered an unhandled exception while creating variants for score set.", extra=logging_context + ) + + # Don't raise BaseExceptions so we may emit canonical logs (TODO: Perhaps they are so problematic we want to raise them anyway). + return {"success": False} + + else: + score_set.processing_state = ProcessingState.success + score_set.processing_errors = null() + + logging_context["created_variants"] = score_set.num_variants + logging_context["processing_state"] = score_set.processing_state.name + logger.info(msg="Finished creating variants in score set.", extra=logging_context) + + await redis.lpush(MAPPING_QUEUE_NAME, score_set.id) # type: ignore + await redis.enqueue_job("variant_mapper_manager", correlation_id, updater_id) + score_set.mapping_state = MappingState.queued + finally: + db.add(score_set) + db.commit() + db.refresh(score_set) + logger.info(msg="Committed new variants to score set.", extra=logging_context) + + ctx["state"][ctx["job_id"]] = logging_context.copy() + return {"success": True} diff --git a/src/mavedb/worker/jobs/variant_processing/mapping.py b/src/mavedb/worker/jobs/variant_processing/mapping.py new file mode 100644 index 00000000..91c6f0fe --- /dev/null +++ b/src/mavedb/worker/jobs/variant_processing/mapping.py @@ -0,0 +1,569 @@ +"""Variant mapping jobs using VRS (Variant Representation Specification). + +This module handles the mapping of variants to standardized genomic coordinates +using the VRS mapping service. It includes queue management, retry logic, +and coordination with downstream services like ClinGen and UniProt. +""" + +import asyncio +import functools +import logging +from contextlib import asynccontextmanager +from datetime import date, timedelta +from typing import Any + +from arq import ArqRedis +from arq.jobs import Job, JobStatus +from sqlalchemy import cast, null, select +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Session + +from mavedb.data_providers.services import vrs_mapper +from mavedb.lib.clingen.constants import CLIN_GEN_SUBMISSION_ENABLED +from mavedb.lib.exceptions import ( + MappingEnqueueError, + NonexistentMappingReferenceError, + NonexistentMappingResultsError, + SubmissionEnqueueError, + UniProtIDMappingEnqueueError, +) +from mavedb.lib.logging.context import format_raised_exception_info_as_dict +from mavedb.lib.mapping import ANNOTATION_LAYERS +from mavedb.lib.slack import send_slack_error, send_slack_message +from mavedb.lib.uniprot.constants import UNIPROT_ID_MAPPING_ENABLED +from mavedb.models.enums.mapping_state import MappingState +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.score_set import ScoreSet +from mavedb.models.variant import Variant +from mavedb.worker.jobs.utils.constants import MAPPING_BACKOFF_IN_SECONDS, MAPPING_CURRENT_ID_NAME, MAPPING_QUEUE_NAME +from mavedb.worker.jobs.utils.job_state import setup_job_state +from mavedb.worker.jobs.utils.retry import enqueue_job_with_backoff + +logger = logging.getLogger(__name__) + + +@asynccontextmanager +async def mapping_in_execution(redis: ArqRedis, job_id: str): + await redis.set(MAPPING_CURRENT_ID_NAME, job_id) + try: + yield + finally: + await redis.set(MAPPING_CURRENT_ID_NAME, "") + + +async def variant_mapper_manager(ctx: dict, correlation_id: str, updater_id: int, attempt: int = 1) -> dict: + logging_context = {} + mapping_job_id = None + mapping_job_status = None + queued_score_set = None + try: + redis: ArqRedis = ctx["redis"] + db: Session = ctx["db"] + + logging_context = setup_job_state(ctx, updater_id, None, correlation_id) + logging_context["attempt"] = attempt + logger.debug(msg="Variant mapping manager began execution", extra=logging_context) + + queue_length = await redis.llen(MAPPING_QUEUE_NAME) # type: ignore + queued_id = await redis.rpop(MAPPING_QUEUE_NAME) # type: ignore + logging_context["variant_mapping_queue_length"] = queue_length + + # Setup the job id cache if it does not already exist. + if not await redis.exists(MAPPING_CURRENT_ID_NAME): + await redis.set(MAPPING_CURRENT_ID_NAME, "") + + if not queued_id: + logger.debug(msg="No mapping jobs exist in the queue.", extra=logging_context) + return {"success": True, "enqueued_job": None} + else: + queued_id = queued_id.decode("utf-8") + queued_score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == queued_id)).one() + + logging_context["upcoming_mapping_resource"] = queued_score_set.urn + logger.debug(msg="Found mapping job(s) still in queue.", extra=logging_context) + + mapping_job_id = await redis.get(MAPPING_CURRENT_ID_NAME) + if mapping_job_id: + mapping_job_id = mapping_job_id.decode("utf-8") + mapping_job_status = (await Job(job_id=mapping_job_id, redis=redis).status()).value + + logging_context["existing_mapping_job_status"] = mapping_job_status + logging_context["existing_mapping_job_id"] = mapping_job_id + + except Exception as e: + send_slack_error(e) + + # Attempt to remove this item from the mapping queue. + try: + await redis.lrem(MAPPING_QUEUE_NAME, 1, queued_id) # type: ignore + logger.warning(msg="Removed un-queueable score set from the queue.", extra=logging_context) + except Exception: + pass + + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error(msg="Variant mapper manager encountered an unexpected error during setup.", extra=logging_context) + + return {"success": False, "enqueued_job": None} + + new_job = None + new_job_id = None + try: + if not mapping_job_id or mapping_job_status in (JobStatus.not_found, JobStatus.complete): + logger.debug(msg="No mapping jobs are running, queuing a new one.", extra=logging_context) + + new_job = await redis.enqueue_job( + "map_variants_for_score_set", correlation_id, queued_score_set.id, updater_id, attempt + ) + + if new_job: + new_job_id = new_job.job_id + + logging_context["new_mapping_job_id"] = new_job_id + logger.info(msg="Queued a new mapping job.", extra=logging_context) + + return {"success": True, "enqueued_job": new_job_id} + + logger.info( + msg="A mapping job is already running, or a new job was unable to be enqueued. Deferring mapping by 5 minutes.", + extra=logging_context, + ) + + new_job = await redis.enqueue_job( + "variant_mapper_manager", + correlation_id, + updater_id, + attempt, + _defer_by=timedelta(minutes=5), + ) + + if new_job: + # Ensure this score set remains in the front of the queue. + queued_id = await redis.rpush(MAPPING_QUEUE_NAME, queued_score_set.id) # type: ignore + new_job_id = new_job.job_id + + logging_context["new_mapping_manager_job_id"] = new_job_id + logger.info(msg="Deferred a new mapping manager job.", extra=logging_context) + + # Our persistent Redis queue and ARQ's execution rules ensure that even if the worker is stopped and not restarted + # before the deferred time, these deferred jobs will still run once able. + return {"success": True, "enqueued_job": new_job_id} + + raise MappingEnqueueError() + + except Exception as e: + send_slack_error(e) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="Variant mapper manager encountered an unexpected error while enqueing a mapping job. This job will not be retried.", + extra=logging_context, + ) + + db.rollback() + + # We shouldn't rely on the passed score set id matching the score set we are operating upon. + if not queued_score_set: + return {"success": False, "enqueued_job": new_job_id} + + # Attempt to remove this item from the mapping queue. + try: + await redis.lrem(MAPPING_QUEUE_NAME, 1, queued_id) # type: ignore + logger.warning(msg="Removed un-queueable score set from the queue.", extra=logging_context) + except Exception: + pass + + score_set_exc = db.scalars(select(ScoreSet).where(ScoreSet.id == queued_score_set.id)).one_or_none() + if score_set_exc: + score_set_exc.mapping_state = MappingState.failed + score_set_exc.mapping_errors = "Unable to queue a new mapping job or defer score set mapping." + db.add(score_set_exc) + db.commit() + + return {"success": False, "enqueued_job": new_job_id} + + +async def map_variants_for_score_set( + ctx: dict, correlation_id: str, score_set_id: int, updater_id: int, attempt: int = 1 +) -> dict: + async with mapping_in_execution(redis=ctx["redis"], job_id=ctx["job_id"]): + logging_context = {} + score_set = None + try: + db: Session = ctx["db"] + redis: ArqRedis = ctx["redis"] + score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() + + logging_context = setup_job_state(ctx, updater_id, score_set.urn, correlation_id) + logging_context["attempt"] = attempt + logger.info(msg="Started variant mapping", extra=logging_context) + + score_set.mapping_state = MappingState.processing + score_set.mapping_errors = null() + db.add(score_set) + db.commit() + + mapping_urn = score_set.urn + assert mapping_urn, "A valid URN is needed to map this score set." + + logging_context["current_mapping_resource"] = mapping_urn + logging_context["mapping_state"] = score_set.mapping_state + logger.debug(msg="Fetched score set metadata for mapping job.", extra=logging_context) + + # Do not block Worker event loop during mapping, see: https://arq-docs.helpmanual.io/#synchronous-jobs. + vrs = vrs_mapper() + blocking = functools.partial(vrs.map_score_set, mapping_urn) + loop = asyncio.get_running_loop() + + except Exception as e: + send_slack_error(e) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="Variant mapper encountered an unexpected error during setup. This job will not be retried.", + extra=logging_context, + ) + + db.rollback() + if score_set: + score_set.mapping_state = MappingState.failed + score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} + db.add(score_set) + db.commit() + + return {"success": False, "retried": False, "enqueued_jobs": []} + + mapping_results = None + try: + mapping_results = await loop.run_in_executor(ctx["pool"], blocking) + logger.debug(msg="Done mapping variants.", extra=logging_context) + + except Exception as e: + db.rollback() + score_set.mapping_errors = { + "error_message": f"Encountered an internal server error during mapping. Mapping will be automatically retried up to 5 times for this score set (attempt {attempt}/5)." + } + db.add(score_set) + db.commit() + + send_slack_error(e) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.warning( + msg="Variant mapper encountered an unexpected error while mapping variants. This job will be retried.", + extra=logging_context, + ) + + new_job_id = None + max_retries_exceeded = None + try: + await redis.lpush(MAPPING_QUEUE_NAME, score_set.id) # type: ignore + new_job_id, max_retries_exceeded, backoff_time = await enqueue_job_with_backoff( + redis, "variant_mapper_manager", attempt, MAPPING_BACKOFF_IN_SECONDS, correlation_id, updater_id + ) + # If we fail to enqueue a mapping manager for this score set, evict it from the queue. + if new_job_id is None: + await redis.lpop(MAPPING_QUEUE_NAME, score_set.id) # type: ignore + + logging_context["backoff_limit_exceeded"] = max_retries_exceeded + logging_context["backoff_deferred_in_seconds"] = backoff_time + logging_context["backoff_job_id"] = new_job_id + + except Exception as backoff_e: + score_set.mapping_state = MappingState.failed + score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} + db.add(score_set) + db.commit() + send_slack_error(backoff_e) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(backoff_e)} + logger.critical( + msg="While attempting to re-enqueue a mapping job that exited in error, another exception was encountered. This score set will not be mapped.", + extra=logging_context, + ) + else: + if new_job_id and not max_retries_exceeded: + score_set.mapping_state = MappingState.queued + db.add(score_set) + db.commit() + logger.info( + msg="After encountering an error while mapping variants, another mapping job was queued.", + extra=logging_context, + ) + elif new_job_id is None and not max_retries_exceeded: + score_set.mapping_state = MappingState.failed + score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} + db.add(score_set) + db.commit() + logger.error( + msg="After encountering an error while mapping variants, another mapping job was unable to be queued. This score set will not be mapped.", + extra=logging_context, + ) + else: + score_set.mapping_state = MappingState.failed + score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} + db.add(score_set) + db.commit() + logger.error( + msg="After encountering an error while mapping variants, the maximum retries for this job were exceeded. This score set will not be mapped.", + extra=logging_context, + ) + finally: + return { + "success": False, + "retried": (not max_retries_exceeded and new_job_id is not None), + "enqueued_jobs": [job for job in [new_job_id] if job], + } + + try: + if mapping_results: + mapped_scores = mapping_results.get("mapped_scores") + if not mapped_scores: + # if there are no mapped scores, the score set failed to map. + score_set.mapping_state = MappingState.failed + score_set.mapping_errors = {"error_message": mapping_results.get("error_message")} + else: + reference_metadata = mapping_results.get("reference_sequences") + if not reference_metadata: + raise NonexistentMappingReferenceError() + + for target_gene_identifier in reference_metadata: + target_gene = next( + ( + target_gene + for target_gene in score_set.target_genes + if target_gene.name == target_gene_identifier + ), + None, + ) + if not target_gene: + raise ValueError( + f"Target gene {target_gene_identifier} not found in database for score set {score_set.urn}." + ) + # allow for multiple annotation layers + pre_mapped_metadata: dict[str, Any] = {} + post_mapped_metadata: dict[str, Any] = {} + excluded_pre_mapped_keys = {"sequence"} + + gene_info = reference_metadata[target_gene_identifier].get("gene_info") + if gene_info: + target_gene.mapped_hgnc_name = gene_info.get("hgnc_symbol") + post_mapped_metadata["hgnc_name_selection_method"] = gene_info.get("selection_method") + + for annotation_layer in reference_metadata[target_gene_identifier]["layers"]: + layer_premapped = reference_metadata[target_gene_identifier]["layers"][ + annotation_layer + ].get("computed_reference_sequence") + if layer_premapped: + pre_mapped_metadata[ANNOTATION_LAYERS[annotation_layer]] = { + k: layer_premapped[k] + for k in set(list(layer_premapped.keys())) - excluded_pre_mapped_keys + } + layer_postmapped = reference_metadata[target_gene_identifier]["layers"][ + annotation_layer + ].get("mapped_reference_sequence") + if layer_postmapped: + post_mapped_metadata[ANNOTATION_LAYERS[annotation_layer]] = layer_postmapped + target_gene.pre_mapped_metadata = cast(pre_mapped_metadata, JSONB) + target_gene.post_mapped_metadata = cast(post_mapped_metadata, JSONB) + + total_variants = 0 + successful_mapped_variants = 0 + for mapped_score in mapped_scores: + total_variants += 1 + variant_urn = mapped_score.get("mavedb_id") + variant = db.scalars(select(Variant).where(Variant.urn == variant_urn)).one() + + # there should only be one current mapped variant per variant id, so update old mapped variant to current = false + existing_mapped_variant = ( + db.query(MappedVariant) + .filter(MappedVariant.variant_id == variant.id, MappedVariant.current.is_(True)) + .one_or_none() + ) + + if existing_mapped_variant: + existing_mapped_variant.current = False + db.add(existing_mapped_variant) + + if mapped_score.get("pre_mapped") and mapped_score.get("post_mapped"): + successful_mapped_variants += 1 + + mapped_variant = MappedVariant( + pre_mapped=mapped_score.get("pre_mapped", null()), + post_mapped=mapped_score.get("post_mapped", null()), + variant_id=variant.id, + modification_date=date.today(), + mapped_date=mapping_results["mapped_date_utc"], + vrs_version=mapped_score.get("vrs_version", null()), + mapping_api_version=mapping_results["dcd_mapping_version"], + error_message=mapped_score.get("error_message", null()), + current=True, + ) + db.add(mapped_variant) + + if successful_mapped_variants == 0: + score_set.mapping_state = MappingState.failed + score_set.mapping_errors = {"error_message": "All variants failed to map"} + elif successful_mapped_variants < total_variants: + score_set.mapping_state = MappingState.incomplete + else: + score_set.mapping_state = MappingState.complete + + logging_context["mapped_variants_inserted_db"] = len(mapped_scores) + logging_context["variants_successfully_mapped"] = successful_mapped_variants + logging_context["mapping_state"] = score_set.mapping_state.name + logging_context["mapping_errors"] = score_set.mapping_errors + logger.info(msg="Inserted mapped variants into db.", extra=logging_context) + + else: + raise NonexistentMappingResultsError() + + db.add(score_set) + db.commit() + + except Exception as e: + db.rollback() + score_set.mapping_errors = { + "error_message": f"Encountered an unexpected error while parsing mapped variants. Mapping will be automatically retried up to 5 times for this score set (attempt {attempt}/5)." + } + db.add(score_set) + db.commit() + + send_slack_error(e) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.warning( + msg="An unexpected error occurred during variant mapping. This job will be attempted again.", + extra=logging_context, + ) + + new_job_id = None + max_retries_exceeded = None + try: + await redis.lpush(MAPPING_QUEUE_NAME, score_set.id) # type: ignore + new_job_id, max_retries_exceeded, backoff_time = await enqueue_job_with_backoff( + redis, "variant_mapper_manager", attempt, MAPPING_BACKOFF_IN_SECONDS, correlation_id, updater_id + ) + # If we fail to enqueue a mapping manager for this score set, evict it from the queue. + if new_job_id is None: + await redis.lpop(MAPPING_QUEUE_NAME, score_set.id) # type: ignore + + logging_context["backoff_limit_exceeded"] = max_retries_exceeded + logging_context["backoff_deferred_in_seconds"] = backoff_time + logging_context["backoff_job_id"] = new_job_id + + except Exception as backoff_e: + score_set.mapping_state = MappingState.failed + score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} + send_slack_error(backoff_e) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(backoff_e)} + logger.critical( + msg="While attempting to re-enqueue a mapping job that exited in error, another exception was encountered. This score set will not be mapped.", + extra=logging_context, + ) + else: + if new_job_id and not max_retries_exceeded: + score_set.mapping_state = MappingState.queued + logger.info( + msg="After encountering an error while parsing mapped variants, another mapping job was queued.", + extra=logging_context, + ) + elif new_job_id is None and not max_retries_exceeded: + score_set.mapping_state = MappingState.failed + score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} + logger.error( + msg="After encountering an error while parsing mapped variants, another mapping job was unable to be queued. This score set will not be mapped.", + extra=logging_context, + ) + else: + score_set.mapping_state = MappingState.failed + score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} + logger.error( + msg="After encountering an error while parsing mapped variants, the maximum retries for this job were exceeded. This score set will not be mapped.", + extra=logging_context, + ) + finally: + db.add(score_set) + db.commit() + return { + "success": False, + "retried": (not max_retries_exceeded and new_job_id is not None), + "enqueued_jobs": [job for job in [new_job_id] if job], + } + + new_uniprot_job_id = None + try: + if UNIPROT_ID_MAPPING_ENABLED: + new_job = await redis.enqueue_job( + "submit_uniprot_mapping_jobs_for_score_set", + score_set.id, + correlation_id, + ) + + if new_job: + new_uniprot_job_id = new_job.job_id + + logging_context["submit_uniprot_mapping_job_id"] = new_uniprot_job_id + logger.info(msg="Queued a new UniProt mapping job.", extra=logging_context) + + else: + raise UniProtIDMappingEnqueueError() + else: + logger.warning( + msg="UniProt ID mapping is disabled, skipped submission of UniProt mapping jobs.", + extra=logging_context, + ) + + except Exception as e: + send_slack_error(e) + send_slack_message( + f"Could not enqueue UniProt mapping job for score set {score_set.urn}. UniProt mappings for this score set should be submitted manually." + ) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="Mapped variant UniProt submission encountered an unexpected error while attempting to enqueue a mapping job. This job will not be retried.", + extra=logging_context, + ) + + return {"success": False, "retried": False, "enqueued_jobs": [job for job in [new_uniprot_job_id] if job]} + + new_clingen_job_id = None + try: + if CLIN_GEN_SUBMISSION_ENABLED: + new_job = await redis.enqueue_job( + "submit_score_set_mappings_to_car", + correlation_id, + score_set.id, + ) + + if new_job: + new_clingen_job_id = new_job.job_id + + logging_context["submit_clingen_variants_job_id"] = new_clingen_job_id + logger.info(msg="Queued a new ClinGen submission job.", extra=logging_context) + + else: + raise SubmissionEnqueueError() + else: + logger.warning( + msg="ClinGen submission is disabled, skipped submission of mapped variants to CAR and LDH.", + extra=logging_context, + ) + + except Exception as e: + send_slack_error(e) + send_slack_message( + f"Could not submit mappings to CAR and/or LDH mappings for score set {score_set.urn}. Mappings for this score set should be submitted manually." + ) + logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + logger.error( + msg="Mapped variant ClinGen submission encountered an unexpected error while attempting to enqueue a submission job. This job will not be retried.", + extra=logging_context, + ) + + return { + "success": False, + "retried": False, + "enqueued_jobs": [job for job in [new_uniprot_job_id, new_clingen_job_id] if job], + } + + ctx["state"][ctx["job_id"]] = logging_context.copy() + return { + "success": True, + "retried": False, + "enqueued_jobs": [job for job in [new_uniprot_job_id, new_clingen_job_id] if job], + } diff --git a/src/mavedb/worker/py.typed b/src/mavedb/worker/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/mavedb/worker/settings.py b/src/mavedb/worker/settings.py deleted file mode 100644 index 0a9359d5..00000000 --- a/src/mavedb/worker/settings.py +++ /dev/null @@ -1,94 +0,0 @@ -import os -from concurrent import futures -from datetime import timedelta -from typing import Callable - -from arq.connections import RedisSettings -from arq.cron import CronJob, cron - -from mavedb.data_providers.services import cdot_rest -from mavedb.db.session import SessionLocal -from mavedb.lib.logging.canonical import log_job -from mavedb.worker.jobs import ( - create_variants_for_score_set, - map_variants_for_score_set, - variant_mapper_manager, - refresh_materialized_views, - refresh_published_variants_view, - submit_score_set_mappings_to_ldh, - link_clingen_variants, - poll_uniprot_mapping_jobs_for_score_set, - submit_uniprot_mapping_jobs_for_score_set, - link_gnomad_variants, - submit_score_set_mappings_to_car, -) - -# ARQ requires at least one task on startup. -BACKGROUND_FUNCTIONS: list[Callable] = [ - create_variants_for_score_set, - variant_mapper_manager, - map_variants_for_score_set, - refresh_published_variants_view, - submit_score_set_mappings_to_ldh, - link_clingen_variants, - poll_uniprot_mapping_jobs_for_score_set, - submit_uniprot_mapping_jobs_for_score_set, - link_gnomad_variants, - submit_score_set_mappings_to_car, -] -# In UTC time. Depending on daylight savings time, this will bounce around by an hour but should always be very early in the morning -# for all of the USA. -BACKGROUND_CRONJOBS: list[CronJob] = [ - cron( - refresh_materialized_views, - name="refresh_all_materialized_views", - hour=20, - minute=0, - keep_result=timedelta(minutes=2).total_seconds(), - ) -] - -REDIS_IP = os.getenv("REDIS_IP") or "localhost" -REDIS_PORT = int(os.getenv("REDIS_PORT") or 6379) -REDIS_SSL = (os.getenv("REDIS_SSL") or "false").lower() == "true" - - -RedisWorkerSettings = RedisSettings(host=REDIS_IP, port=REDIS_PORT, ssl=REDIS_SSL) - - -async def startup(ctx): - ctx["pool"] = futures.ProcessPoolExecutor() - - -async def shutdown(ctx): - pass - - -async def on_job_start(ctx): - db = SessionLocal() - db.current_user_id = None - ctx["db"] = db - ctx["hdp"] = cdot_rest() - ctx["state"] = {} - - -async def on_job_end(ctx): - db = ctx["db"] - db.close() - - -class ArqWorkerSettings: - """ - Settings for the ARQ worker. - """ - - on_startup = startup - on_shutdown = shutdown - on_job_start = on_job_start - on_job_end = on_job_end - after_job_end = log_job - redis_settings = RedisWorkerSettings - functions: list = BACKGROUND_FUNCTIONS - cron_jobs: list = BACKGROUND_CRONJOBS - - job_timeout = 5 * 60 * 60 # Keep jobs alive for a long while... diff --git a/src/mavedb/worker/settings/__init__.py b/src/mavedb/worker/settings/__init__.py new file mode 100644 index 00000000..af2e6a27 --- /dev/null +++ b/src/mavedb/worker/settings/__init__.py @@ -0,0 +1,19 @@ +"""Worker settings configuration. + +This module provides ARQ worker settings organized by concern: +- constants: Environment variable configuration +- redis: Redis connection settings +- lifecycle: Worker startup/shutdown hooks +- worker: Main ARQ worker configuration class + +The settings are designed to be modular and easily testable, +with clear separation between infrastructure and application concerns. +""" + +from .redis import RedisWorkerSettings +from .worker import ArqWorkerSettings + +__all__ = [ + "ArqWorkerSettings", + "RedisWorkerSettings", +] diff --git a/src/mavedb/worker/settings/constants.py b/src/mavedb/worker/settings/constants.py new file mode 100644 index 00000000..b5e8f23d --- /dev/null +++ b/src/mavedb/worker/settings/constants.py @@ -0,0 +1,12 @@ +"""Environment configuration constants for worker settings. + +This module centralizes all environment variable handling for the worker, +providing sensible defaults and type conversion for configuration values. +All worker-related environment variables should be defined here. +""" + +import os + +REDIS_IP = os.getenv("REDIS_IP") or "localhost" +REDIS_PORT = int(os.getenv("REDIS_PORT") or 6379) +REDIS_SSL = (os.getenv("REDIS_SSL") or "false").lower() == "true" diff --git a/src/mavedb/worker/settings/lifecycle.py b/src/mavedb/worker/settings/lifecycle.py new file mode 100644 index 00000000..7288c691 --- /dev/null +++ b/src/mavedb/worker/settings/lifecycle.py @@ -0,0 +1,35 @@ +"""Worker lifecycle management hooks. + +This module defines the startup, shutdown, and job lifecycle hooks +for the ARQ worker. These hooks manage: +- Process pool for CPU-intensive tasks +- Database session management per job +- HGVS data provider setup +- Job state initialization and cleanup +""" + +from concurrent import futures + +from mavedb.data_providers.services import cdot_rest +from mavedb.db.session import SessionLocal + + +async def startup(ctx): + ctx["pool"] = futures.ProcessPoolExecutor() + + +async def shutdown(ctx): + pass + + +async def on_job_start(ctx): + db = SessionLocal() + db.current_user_id = None + ctx["db"] = db + ctx["hdp"] = cdot_rest() + ctx["state"] = {} + + +async def on_job_end(ctx): + db = ctx["db"] + db.close() diff --git a/src/mavedb/worker/settings/redis.py b/src/mavedb/worker/settings/redis.py new file mode 100644 index 00000000..2773f77f --- /dev/null +++ b/src/mavedb/worker/settings/redis.py @@ -0,0 +1,12 @@ +"""Redis connection settings for ARQ worker. + +This module provides Redis connection configuration using environment +variables with appropriate defaults. The settings are compatible with +ARQ's RedisSettings class and handle SSL connections. +""" + +from arq.connections import RedisSettings + +from mavedb.worker.settings.constants import REDIS_IP, REDIS_PORT, REDIS_SSL + +RedisWorkerSettings = RedisSettings(host=REDIS_IP, port=REDIS_PORT, ssl=REDIS_SSL) diff --git a/src/mavedb/worker/settings/worker.py b/src/mavedb/worker/settings/worker.py new file mode 100644 index 00000000..03bad1f3 --- /dev/null +++ b/src/mavedb/worker/settings/worker.py @@ -0,0 +1,33 @@ +"""Main ARQ worker configuration class. + +This module defines the primary ArqWorkerSettings class that brings together +all worker configuration including: +- Job functions and cron jobs from the jobs registry +- Redis connection settings +- Lifecycle hooks for startup/shutdown and job execution +- Timeout and logging configuration + +This is the main configuration class used to start the ARQ worker. +""" + +from mavedb.lib.logging.canonical import log_job +from mavedb.worker.jobs import BACKGROUND_CRONJOBS, BACKGROUND_FUNCTIONS +from mavedb.worker.settings.lifecycle import on_job_end, on_job_start, shutdown, startup +from mavedb.worker.settings.redis import RedisWorkerSettings + + +class ArqWorkerSettings: + """ + Settings for the ARQ worker. + """ + + on_startup = startup + on_shutdown = shutdown + on_job_start = on_job_start + on_job_end = on_job_end + after_job_end = log_job + redis_settings = RedisWorkerSettings + functions: list = BACKGROUND_FUNCTIONS + cron_jobs: list = BACKGROUND_CRONJOBS + + job_timeout = 5 * 60 * 60 # Keep jobs alive for a long while... diff --git a/tests/conftest_optional.py b/tests/conftest_optional.py index 8597c4f9..028a4e05 100644 --- a/tests/conftest_optional.py +++ b/tests/conftest_optional.py @@ -1,9 +1,10 @@ import os +import shutil +import tempfile from concurrent import futures from inspect import getsourcefile from posixpath import abspath -import shutil -import tempfile +from unittest.mock import patch import cdot.hgvs.dataproviders import pytest @@ -12,15 +13,13 @@ from biocommons.seqrepo import SeqRepo from fastapi.testclient import TestClient from httpx import AsyncClient -from unittest.mock import patch +from mavedb.deps import get_db, get_seqrepo, get_worker, hgvs_data_provider from mavedb.lib.authentication import UserData, get_current_user from mavedb.lib.authorization import require_current_user from mavedb.models.user import User from mavedb.server_main import app -from mavedb.deps import get_db, get_worker, hgvs_data_provider, get_seqrepo -from mavedb.worker.settings import BACKGROUND_FUNCTIONS, BACKGROUND_CRONJOBS - +from mavedb.worker.jobs import BACKGROUND_CRONJOBS, BACKGROUND_FUNCTIONS from tests.helpers.constants import ADMIN_USER, EXTRA_USER, TEST_SEQREPO_INITIAL_STATE, TEST_USER #################################################################################################### diff --git a/tests/helpers/util/mapping.py b/tests/helpers/util/mapping.py new file mode 100644 index 00000000..828e7df8 --- /dev/null +++ b/tests/helpers/util/mapping.py @@ -0,0 +1,6 @@ +from mavedb.worker.jobs.utils.constants import MAPPING_QUEUE_NAME + + +async def sanitize_mapping_queue(standalone_worker_context, score_set): + queued_job = await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME) + assert int(queued_job.decode("utf-8")) == score_set.id diff --git a/tests/helpers/util/setup/worker.py b/tests/helpers/util/setup/worker.py new file mode 100644 index 00000000..50eee000 --- /dev/null +++ b/tests/helpers/util/setup/worker.py @@ -0,0 +1,154 @@ +import json +from asyncio.unix_events import _UnixSelectorEventLoop +from copy import deepcopy +from unittest.mock import patch +from uuid import uuid4 + +import cdot +import jsonschema +from sqlalchemy import select + +from mavedb.lib.score_sets import csv_data_to_df +from mavedb.models.enums.processing_state import ProcessingState +from mavedb.models.score_set import ScoreSet as ScoreSetDbModel +from mavedb.models.variant import Variant +from mavedb.view_models.experiment import Experiment, ExperimentCreate +from mavedb.view_models.score_set import ScoreSet, ScoreSetCreate +from mavedb.worker.jobs import ( + create_variants_for_score_set, + map_variants_for_score_set, +) +from tests.helpers.constants import ( + TEST_ACC_SCORESET_VARIANT_MAPPING_SCAFFOLD, + TEST_MINIMAL_EXPERIMENT, + TEST_MULTI_TARGET_SCORESET_VARIANT_MAPPING_SCAFFOLD, + TEST_NT_CDOT_TRANSCRIPT, + TEST_SEQ_SCORESET_VARIANT_MAPPING_SCAFFOLD, + TEST_VALID_POST_MAPPED_VRS_ALLELE_VRS2_X, + TEST_VALID_PRE_MAPPED_VRS_ALLELE_VRS2_X, +) +from tests.helpers.util.mapping import sanitize_mapping_queue + + +async def setup_records_and_files(async_client, data_files, input_score_set): + experiment_payload = deepcopy(TEST_MINIMAL_EXPERIMENT) + jsonschema.validate(instance=experiment_payload, schema=ExperimentCreate.model_json_schema()) + experiment_response = await async_client.post("/api/v1/experiments/", json=experiment_payload) + assert experiment_response.status_code == 200 + experiment = experiment_response.json() + jsonschema.validate(instance=experiment, schema=Experiment.model_json_schema()) + + score_set_payload = deepcopy(input_score_set) + score_set_payload["experimentUrn"] = experiment["urn"] + jsonschema.validate(instance=score_set_payload, schema=ScoreSetCreate.model_json_schema()) + score_set_response = await async_client.post("/api/v1/score-sets/", json=score_set_payload) + assert score_set_response.status_code == 200 + score_set = score_set_response.json() + jsonschema.validate(instance=score_set, schema=ScoreSet.model_json_schema()) + + scores_fp = ( + "scores_multi_target.csv" + if len(score_set["targetGenes"]) > 1 + else ("scores.csv" if "targetSequence" in score_set["targetGenes"][0] else "scores_acc.csv") + ) + counts_fp = ( + "counts_multi_target.csv" + if len(score_set["targetGenes"]) > 1 + else ("counts.csv" if "targetSequence" in score_set["targetGenes"][0] else "counts_acc.csv") + ) + with ( + open(data_files / scores_fp, "rb") as score_file, + open(data_files / counts_fp, "rb") as count_file, + open(data_files / "score_columns_metadata.json", "rb") as score_columns_file, + open(data_files / "count_columns_metadata.json", "rb") as count_columns_file, + ): + scores = csv_data_to_df(score_file) + counts = csv_data_to_df(count_file) + score_columns_metadata = json.load(score_columns_file) + count_columns_metadata = json.load(count_columns_file) + + return score_set["urn"], scores, counts, score_columns_metadata, count_columns_metadata + + +async def setup_records_files_and_variants(session, async_client, data_files, input_score_set, worker_ctx): + score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( + async_client, data_files, input_score_set + ) + score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() + + # Patch CDOT `_get_transcript`, in the event this function is called on an accesssion based scoreset. + with patch.object( + cdot.hgvs.dataproviders.RESTDataProvider, + "_get_transcript", + return_value=TEST_NT_CDOT_TRANSCRIPT, + ): + result = await create_variants_for_score_set( + worker_ctx, uuid4().hex, score_set.id, 1, scores, counts, score_columns_metadata, count_columns_metadata + ) + + score_set_with_variants = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() + + assert result["success"] + assert score_set.processing_state is ProcessingState.success + assert score_set_with_variants.num_variants == 3 + + return score_set_with_variants + + +async def setup_records_files_and_variants_with_mapping( + session, async_client, data_files, input_score_set, standalone_worker_context +): + score_set = await setup_records_files_and_variants( + session, async_client, data_files, input_score_set, standalone_worker_context + ) + await sanitize_mapping_queue(standalone_worker_context, score_set) + + async def dummy_mapping_job(): + return await setup_mapping_output(async_client, session, score_set) + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + patch("mavedb.worker.jobs.variant_processing.mapping.CLIN_GEN_SUBMISSION_ENABLED", False), + ): + result = await map_variants_for_score_set(standalone_worker_context, uuid4().hex, score_set.id, 1) + + assert result["success"] + assert not result["retried"] + assert not result["enqueued_jobs"] + return session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() + + +async def setup_mapping_output( + async_client, session, score_set, score_set_is_seq_based=True, score_set_is_multi_target=False, empty=False +): + score_set_response = await async_client.get(f"/api/v1/score-sets/{score_set.urn}") + + if score_set_is_seq_based: + if score_set_is_multi_target: + # If this is a multi-target sequence based score set, use the scaffold for that. + mapping_output = deepcopy(TEST_MULTI_TARGET_SCORESET_VARIANT_MAPPING_SCAFFOLD) + else: + mapping_output = deepcopy(TEST_SEQ_SCORESET_VARIANT_MAPPING_SCAFFOLD) + else: + # there is not currently a multi-target accession-based score set test + mapping_output = deepcopy(TEST_ACC_SCORESET_VARIANT_MAPPING_SCAFFOLD) + mapping_output["metadata"] = score_set_response.json() + + if empty: + return mapping_output + + variants = session.scalars(select(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).all() + for variant in variants: + mapped_score = { + "pre_mapped": TEST_VALID_PRE_MAPPED_VRS_ALLELE_VRS2_X, + "post_mapped": TEST_VALID_POST_MAPPED_VRS_ALLELE_VRS2_X, + "mavedb_id": variant.urn, + } + + mapping_output["mapped_scores"].append(mapped_score) + + return mapping_output diff --git a/tests/worker/jobs/external_services/test_clingen.py b/tests/worker/jobs/external_services/test_clingen.py new file mode 100644 index 00000000..28432297 --- /dev/null +++ b/tests/worker/jobs/external_services/test_clingen.py @@ -0,0 +1,879 @@ +# ruff: noqa: E402 + +from asyncio.unix_events import _UnixSelectorEventLoop +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy import select + +arq = pytest.importorskip("arq") + +from mavedb.lib.clingen.services import ( + ClinGenAlleleRegistryService, + ClinGenLdhService, + clingen_allele_id_from_ldh_variation, +) +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.score_set import ScoreSet as ScoreSetDbModel +from mavedb.models.variant import Variant +from mavedb.worker.jobs import ( + link_clingen_variants, + submit_score_set_mappings_to_car, + submit_score_set_mappings_to_ldh, +) +from tests.helpers.constants import ( + TEST_CLINGEN_ALLELE_OBJECT, + TEST_CLINGEN_LDH_LINKING_RESPONSE, + TEST_CLINGEN_SUBMISSION_BAD_RESQUEST_RESPONSE, + TEST_CLINGEN_SUBMISSION_RESPONSE, + TEST_CLINGEN_SUBMISSION_UNAUTHORIZED_RESPONSE, + TEST_MINIMAL_SEQ_SCORESET, +) +from tests.helpers.util.exceptions import awaitable_exception +from tests.helpers.util.setup.worker import ( + setup_records_files_and_variants, + setup_records_files_and_variants_with_mapping, +) + +############################################################################################################################################ +# ClinGen CAR Submission +############################################################################################################################################ + + +@pytest.mark.asyncio +async def test_submit_score_set_mappings_to_car_success( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with ( + patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]), + patch( + "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", + "https://reg.test.genome.network/pytest", + ), + ): + result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) + + mapped_variants_with_caid_for_score_set = session.scalars( + select(MappedVariant) + .join(Variant) + .join(ScoreSetDbModel) + .filter(ScoreSetDbModel.urn == score_set.urn, MappedVariant.clingen_allele_id.is_not(None)) + ).all() + + assert len(mapped_variants_with_caid_for_score_set) == score_set.num_variants + + assert result["success"] + assert not result["retried"] + assert result["enqueued_job"] is not None + + +@pytest.mark.asyncio +async def test_submit_score_set_mappings_to_car_exception_in_setup( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with patch( + "mavedb.worker.jobs.external_services.clingen.setup_job_state", + side_effect=Exception(), + ): + result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + +@pytest.mark.asyncio +async def test_submit_score_set_mappings_to_car_no_variants_exist( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) + + assert result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + +@pytest.mark.asyncio +async def test_submit_score_set_mappings_to_car_exception_in_hgvs_dict_creation( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with patch( + "mavedb.worker.jobs.external_services.clingen.get_hgvs_from_post_mapped", + side_effect=Exception(), + ): + result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + +@pytest.mark.asyncio +async def test_submit_score_set_mappings_to_car_exception_during_submission( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with ( + patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", side_effect=Exception()), + patch( + "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", + "https://reg.test.genome.network/pytest", + ), + ): + result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + +@pytest.mark.asyncio +async def test_submit_score_set_mappings_to_car_exception_in_allele_association( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with ( + patch("mavedb.worker.jobs.external_services.clingen.get_allele_registry_associations", side_effect=Exception()), + patch( + "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", + "https://reg.test.genome.network/pytest", + ), + ): + result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + +@pytest.mark.asyncio +async def test_submit_score_set_mappings_to_car_exception_during_ldh_enqueue( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with ( + patch( + "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", + "https://reg.test.genome.network/pytest", + ), + patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]), + patch.object(arq.ArqRedis, "enqueue_job", side_effect=Exception()), + ): + result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) + + mapped_variants_with_caid_for_score_set = session.scalars( + select(MappedVariant) + .join(Variant) + .join(ScoreSetDbModel) + .filter(ScoreSetDbModel.urn == score_set.urn, MappedVariant.clingen_allele_id.is_not(None)) + ).all() + + assert len(mapped_variants_with_caid_for_score_set) == score_set.num_variants + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + +############################################################################################################################################ +# ClinGen LDH Submission +############################################################################################################################################ + + +@pytest.mark.asyncio +async def test_submit_score_set_mappings_to_ldh_success( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + async def dummy_submission_job(): + return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] + + # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return + # value of the EventLoop itself, which would have made the request. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_submission_job(), + ), + patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), + ): + result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) + + assert result["success"] + assert not result["retried"] + assert result["enqueued_job"] is not None + + +@pytest.mark.asyncio +async def test_submit_score_set_mappings_to_ldh_exception_in_setup( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with patch( + "mavedb.worker.jobs.external_services.clingen.setup_job_state", + side_effect=Exception(), + ): + result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + +@pytest.mark.asyncio +async def test_submit_score_set_mappings_to_ldh_exception_in_auth( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with patch.object( + ClinGenLdhService, + "_existing_jwt", + side_effect=Exception(), + ): + result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + +@pytest.mark.asyncio +async def test_submit_score_set_mappings_to_ldh_no_variants_exist( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with ( + patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), + ): + result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) + + assert result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + +@pytest.mark.asyncio +async def test_submit_score_set_mappings_to_ldh_exception_in_hgvs_generation( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with patch( + "mavedb.worker.jobs.external_services.clingen.get_hgvs_from_post_mapped", + side_effect=Exception(), + ): + result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + +@pytest.mark.asyncio +async def test_submit_score_set_mappings_to_ldh_exception_in_ldh_submission_construction( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with patch( + "mavedb.lib.clingen.content_constructors.construct_ldh_submission", + side_effect=Exception(), + ): + result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + +@pytest.mark.asyncio +async def test_submit_score_set_mappings_to_ldh_exception_during_submission( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + async def failed_submission_job(): + return Exception() + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + side_effect=failed_submission_job(), + ), + patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), + ): + result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "error_response", [TEST_CLINGEN_SUBMISSION_BAD_RESQUEST_RESPONSE, TEST_CLINGEN_SUBMISSION_UNAUTHORIZED_RESPONSE] +) +async def test_submit_score_set_mappings_to_ldh_submission_failures_exist( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis, error_response +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + async def dummy_submission_job(): + return [None, error_response] + + # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return + # value of the EventLoop itself, which would have made the request. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_submission_job(), + ), + patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), + ): + result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + +@pytest.mark.asyncio +async def test_submit_score_set_mappings_to_ldh_exception_during_linking_enqueue( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + async def dummy_submission_job(): + return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] + + # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return + # value of the EventLoop itself, which would have made the request. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_submission_job(), + ), + patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), + patch.object(arq.ArqRedis, "enqueue_job", side_effect=Exception()), + ): + result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + +@pytest.mark.asyncio +async def test_submit_score_set_mappings_to_ldh_linking_not_queued_when_expected( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + async def dummy_submission_job(): + return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] + + # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return + # value of the EventLoop itself, which would have made the request. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_submission_job(), + ), + patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), + patch.object(arq.ArqRedis, "enqueue_job", return_value=None), + ): + result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + +############################################################################################################################################## +## ClinGen Linkage +############################################################################################################################################## + + +@pytest.mark.asyncio +async def test_link_score_set_mappings_to_ldh_objects_success( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + async def dummy_linking_job(): + return [ + (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE) + for variant_urn in session.scalars( + select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) + ).all() + ] + + # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return + # value of the EventLoop itself, which would have made the request. + with patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_linking_job(), + ): + result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) + + assert result["success"] + assert not result["retried"] + assert result["enqueued_job"] + + for variant in session.scalars( + select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) + ): + assert variant.clingen_allele_id == clingen_allele_id_from_ldh_variation(TEST_CLINGEN_LDH_LINKING_RESPONSE) + + +@pytest.mark.asyncio +async def test_link_score_set_mappings_to_ldh_objects_exception_in_setup( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with patch( + "mavedb.worker.jobs.external_services.clingen.setup_job_state", + side_effect=Exception(), + ): + result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + for variant in session.scalars( + select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) + ): + assert variant.clingen_allele_id is None + + +@pytest.mark.asyncio +async def test_link_score_set_mappings_to_ldh_objects_no_variants_to_link( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) + + assert result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + +@pytest.mark.asyncio +async def test_link_score_set_mappings_to_ldh_objects_exception_during_linkage( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return + # value of the EventLoop itself, which would have made the request. + with patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + side_effect=Exception(), + ): + result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + +@pytest.mark.asyncio +async def test_link_score_set_mappings_to_ldh_objects_exception_while_parsing_linkages( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + async def dummy_linking_job(): + return [ + (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE) + for variant_urn in session.scalars( + select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) + ).all() + ] + + # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return + # value of the EventLoop itself, which would have made the request. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_linking_job(), + ), + patch( + "mavedb.worker.jobs.external_services.clingen.clingen_allele_id_from_ldh_variation", + side_effect=Exception(), + ), + ): + result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + +@pytest.mark.asyncio +async def test_link_score_set_mappings_to_ldh_objects_failures_exist_but_do_not_eclipse_retry_threshold( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + async def dummy_linking_job(): + return [ + (variant_urn, None) + for variant_urn in session.scalars( + select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) + ).all() + ] + + # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return + # value of the EventLoop itself, which would have made the request. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_linking_job(), + ), + patch( + "mavedb.worker.jobs.external_services.clingen.LINKED_DATA_RETRY_THRESHOLD", + 2, + ), + ): + result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) + + assert result["success"] + assert not result["retried"] + assert result["enqueued_job"] + + +@pytest.mark.asyncio +async def test_link_score_set_mappings_to_ldh_objects_failures_exist_and_eclipse_retry_threshold( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + async def dummy_linking_job(): + return [ + (variant_urn, None) + for variant_urn in session.scalars( + select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) + ).all() + ] + + # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return + # value of the EventLoop itself, which would have made the request. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_linking_job(), + ), + patch( + "mavedb.worker.jobs.external_services.clingen.LINKED_DATA_RETRY_THRESHOLD", + 1, + ), + patch( + "mavedb.worker.jobs.external_services.clingen.LINKING_BACKOFF_IN_SECONDS", + 0, + ), + ): + result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) + + assert not result["success"] + assert result["retried"] + assert result["enqueued_job"] + + +@pytest.mark.asyncio +async def test_link_score_set_mappings_to_ldh_objects_failures_exist_and_eclipse_retry_threshold_cant_enqueue( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + async def dummy_linking_job(): + return [ + (variant_urn, None) + for variant_urn in session.scalars( + select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) + ).all() + ] + + # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return + # value of the EventLoop itself, which would have made the request. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_linking_job(), + ), + patch( + "mavedb.worker.jobs.external_services.clingen.LINKED_DATA_RETRY_THRESHOLD", + 1, + ), + patch.object(arq.ArqRedis, "enqueue_job", return_value=awaitable_exception()), + ): + result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + +@pytest.mark.asyncio +async def test_link_score_set_mappings_to_ldh_objects_failures_exist_and_eclipse_retry_threshold_retries_exceeded( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + async def dummy_linking_job(): + return [ + (variant_urn, None) + for variant_urn in session.scalars( + select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) + ).all() + ] + + # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return + # value of the EventLoop itself, which would have made the request. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_linking_job(), + ), + patch( + "mavedb.worker.jobs.external_services.clingen.LINKED_DATA_RETRY_THRESHOLD", + 1, + ), + patch( + "mavedb.worker.jobs.external_services.clingen.LINKING_BACKOFF_IN_SECONDS", + 0, + ), + patch( + "mavedb.worker.jobs.utils.retry.ENQUEUE_BACKOFF_ATTEMPT_LIMIT", + 1, + ), + ): + result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 2) + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + +@pytest.mark.asyncio +async def test_link_score_set_mappings_to_ldh_objects_error_in_gnomad_job_enqueue( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + async def dummy_linking_job(): + return [ + (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE) + for variant_urn in session.scalars( + select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) + ).all() + ] + + # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return + # value of the EventLoop itself, which would have made the request. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_linking_job(), + ), + patch.object(arq.ArqRedis, "enqueue_job", return_value=awaitable_exception()), + ): + result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_job"] diff --git a/tests/worker/jobs/external_services/test_gnomad.py b/tests/worker/jobs/external_services/test_gnomad.py new file mode 100644 index 00000000..c407462b --- /dev/null +++ b/tests/worker/jobs/external_services/test_gnomad.py @@ -0,0 +1,206 @@ +# ruff: noqa: E402 + +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy import select + +arq = pytest.importorskip("arq") + +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.score_set import ScoreSet as ScoreSetDbModel +from mavedb.models.variant import Variant +from mavedb.worker.jobs import ( + link_gnomad_variants, +) +from tests.helpers.constants import ( + TEST_GNOMAD_DATA_VERSION, + TEST_MINIMAL_SEQ_SCORESET, + VALID_CLINGEN_CA_ID, +) +from tests.helpers.util.setup.worker import ( + setup_records_files_and_variants, + setup_records_files_and_variants_with_mapping, +) + + +@pytest.mark.asyncio +async def test_link_score_set_mappings_to_gnomad_variants_success( + setup_worker_db, + standalone_worker_context, + session, + async_client, + data_files, + arq_worker, + arq_redis, + mocked_gnomad_variant_row, +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + # We need to set the ClinGen Allele ID for the Mapped Variants, so that the gnomAD job can link them. + mapped_variants = session.scalars( + select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) + ).all() + + for mapped_variant in mapped_variants: + mapped_variant.clingen_allele_id = VALID_CLINGEN_CA_ID + session.commit() + + # Patch Athena connection with mock object which returns a mocked gnomAD variant row w/ CAID=VALID_CLINGEN_CA_ID. + with ( + patch( + "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", + return_value=[mocked_gnomad_variant_row], + ), + patch("mavedb.lib.gnomad.GNOMAD_DATA_VERSION", TEST_GNOMAD_DATA_VERSION), + ): + result = await link_gnomad_variants(standalone_worker_context, uuid4().hex, score_set.id) + + assert result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + for variant in session.scalars( + select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) + ): + assert variant.gnomad_variants + + +@pytest.mark.asyncio +async def test_link_score_set_mappings_to_gnomad_variants_exception_in_setup( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with patch( + "mavedb.worker.jobs.external_services.gnomad.setup_job_state", + side_effect=Exception(), + ): + result = await link_gnomad_variants(standalone_worker_context, uuid4().hex, score_set.id) + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + for variant in session.scalars( + select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) + ): + assert not variant.gnomad_variants + + +@pytest.mark.asyncio +async def test_link_score_set_mappings_to_gnomad_variants_no_variants_to_link( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + result = await link_gnomad_variants(standalone_worker_context, uuid4().hex, score_set.id) + + assert result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + for variant in session.scalars( + select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) + ): + assert not variant.gnomad_variants + + +@pytest.mark.asyncio +async def test_link_score_set_mappings_to_gnomad_variants_exception_while_fetching_variant_data( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with ( + patch( + "mavedb.worker.jobs.external_services.gnomad.setup_job_state", + side_effect=Exception(), + ), + patch("mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", side_effect=Exception()), + ): + result = await link_gnomad_variants(standalone_worker_context, uuid4().hex, score_set.id) + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + for variant in session.scalars( + select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) + ): + assert not variant.gnomad_variants + + +@pytest.mark.asyncio +async def test_link_score_set_mappings_to_gnomad_variants_exception_while_linking_variants( + setup_worker_db, + standalone_worker_context, + session, + async_client, + data_files, + arq_worker, + arq_redis, + mocked_gnomad_variant_row, +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + # We need to set the ClinGen Allele ID for the Mapped Variants, so that the gnomAD job can link them. + mapped_variants = session.scalars( + select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) + ).all() + + for mapped_variant in mapped_variants: + mapped_variant.clingen_allele_id = VALID_CLINGEN_CA_ID + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", + return_value=[mocked_gnomad_variant_row], + ), + patch( + "mavedb.worker.jobs.external_services.gnomad.link_gnomad_variants_to_mapped_variants", + side_effect=Exception(), + ), + ): + result = await link_gnomad_variants(standalone_worker_context, uuid4().hex, score_set.id) + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_job"] + + for variant in session.scalars( + select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) + ): + assert not variant.gnomad_variants diff --git a/tests/worker/jobs/external_services/test_uniprot.py b/tests/worker/jobs/external_services/test_uniprot.py new file mode 100644 index 00000000..e3833f14 --- /dev/null +++ b/tests/worker/jobs/external_services/test_uniprot.py @@ -0,0 +1,603 @@ +# ruff: noqa: E402 + +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from requests import HTTPError +from sqlalchemy import select + +arq = pytest.importorskip("arq") + + +from mavedb.lib.uniprot.id_mapping import UniProtIDMappingAPI +from mavedb.models.score_set import ScoreSet as ScoreSetDbModel +from mavedb.worker.jobs import ( + poll_uniprot_mapping_jobs_for_score_set, + submit_uniprot_mapping_jobs_for_score_set, +) +from tests.helpers.constants import ( + TEST_MINIMAL_SEQ_SCORESET, + TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, + TEST_UNIPROT_JOB_SUBMISSION_RESPONSE, + TEST_UNIPROT_SWISS_PROT_TYPE, + VALID_CHR_ACCESSION, + VALID_UNIPROT_ACCESSION, +) +from tests.helpers.util.setup.worker import ( + setup_records_files_and_variants, + setup_records_files_and_variants_with_mapping, +) + +### Test Submission + + +@pytest.mark.asyncio +async def test_submit_uniprot_id_mapping_success( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with patch.object(UniProtIDMappingAPI, "submit_id_mapping", return_value=TEST_UNIPROT_JOB_SUBMISSION_RESPONSE): + result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) + + assert result["success"] + assert not result["retried"] + assert result["enqueued_jobs"] is not None + + +@pytest.mark.asyncio +async def test_submit_uniprot_id_mapping_no_targets( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + score_set.target_genes = [] + session.add(score_set) + session.commit() + + with patch( + "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None + ) as mock_slack_message: + result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) + mock_slack_message.assert_called_once() + + assert result["success"] + assert not result["retried"] + assert not result["enqueued_jobs"] + + +@pytest.mark.asyncio +async def test_submit_uniprot_id_mapping_exception_while_spawning_jobs( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with ( + patch.object(UniProtIDMappingAPI, "submit_id_mapping", side_effect=HTTPError()), + patch( + "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None + ) as mock_slack_message, + ): + result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) + mock_slack_message.assert_called() + + assert result["success"] + assert not result["retried"] + assert not result["enqueued_jobs"] + + +@pytest.mark.asyncio +async def test_submit_uniprot_id_mapping_too_many_accessions( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.extract_ids_from_post_mapped_metadata", + return_value=["AC1", "AC2"], + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None + ) as mock_slack_message, + ): + result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) + mock_slack_message.assert_called() + + assert result["success"] + assert not result["retried"] + assert not result["enqueued_jobs"] + + +@pytest.mark.asyncio +async def test_submit_uniprot_id_mapping_no_accessions( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with patch( + "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None + ) as mock_slack_message: + result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) + mock_slack_message.assert_called() + + assert result["success"] + assert not result["retried"] + assert not result["enqueued_jobs"] + + +@pytest.mark.asyncio +async def test_submit_uniprot_id_mapping_error_in_setup( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with ( + patch("mavedb.worker.jobs.external_services.uniprot.setup_job_state", side_effect=Exception()), + patch( + "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None + ) as mock_slack_message, + ): + result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) + mock_slack_message.assert_called() + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_jobs"] + + +@pytest.mark.asyncio +async def test_submit_uniprot_id_mapping_exception_during_submission_generation( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.extract_ids_from_post_mapped_metadata", + side_effect=Exception(), + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None + ) as mock_slack_message, + ): + result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) + mock_slack_message.assert_called() + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_jobs"] + + +@pytest.mark.asyncio +async def test_submit_uniprot_id_mapping_no_spawned_jobs( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with ( + patch.object(UniProtIDMappingAPI, "submit_id_mapping", return_value=None), + patch( + "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None + ) as mock_slack_message, + ): + result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) + mock_slack_message.assert_called() + + assert result["success"] + assert not result["retried"] + assert not result["enqueued_jobs"] + + +@pytest.mark.asyncio +async def test_submit_uniprot_id_mapping_exception_during_enqueue( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with ( + patch.object(UniProtIDMappingAPI, "submit_id_mapping", return_value=TEST_UNIPROT_JOB_SUBMISSION_RESPONSE), + patch.object(arq.ArqRedis, "enqueue_job", side_effect=Exception()), + patch( + "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None + ) as mock_slack_message, + ): + result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) + mock_slack_message.assert_called() + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_jobs"] + + +### Test Polling + + +@pytest.mark.asyncio +async def test_poll_uniprot_id_mapping_success( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with ( + patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=True), + patch.object( + UniProtIDMappingAPI, "get_id_mapping_results", return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE + ), + ): + result = await poll_uniprot_mapping_jobs_for_score_set( + standalone_worker_context, + {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, + score_set.id, + uuid4().hex, + ) + + assert result["success"] + assert not result["retried"] + assert not result["enqueued_jobs"] + + score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() + for target_gene in score_set.target_genes: + assert target_gene.uniprot_id_from_mapped_metadata == VALID_UNIPROT_ACCESSION + + +@pytest.mark.asyncio +async def test_poll_uniprot_id_mapping_no_targets( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + score_set.target_genes = [] + session.add(score_set) + session.commit() + + with patch( + "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None + ) as mock_slack_message: + result = await poll_uniprot_mapping_jobs_for_score_set( + standalone_worker_context, + {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, + score_set.id, + uuid4().hex, + ) + mock_slack_message.assert_called_once() + + assert result["success"] + assert not result["retried"] + assert not result["enqueued_jobs"] + + score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() + for target_gene in score_set.target_genes: + assert target_gene.uniprot_id_from_mapped_metadata is None + + +@pytest.mark.asyncio +async def test_poll_uniprot_id_mapping_too_many_accessions( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.extract_ids_from_post_mapped_metadata", + return_value=["AC1", "AC2"], + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None + ) as mock_slack_message, + ): + result = await poll_uniprot_mapping_jobs_for_score_set( + standalone_worker_context, + {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, + score_set.id, + uuid4().hex, + ) + mock_slack_message.assert_called() + + assert result["success"] + assert not result["retried"] + assert not result["enqueued_jobs"] + + +@pytest.mark.asyncio +async def test_poll_uniprot_id_mapping_no_accessions( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with ( + patch("mavedb.worker.jobs.external_services.uniprot.extract_ids_from_post_mapped_metadata", return_value=[]), + patch( + "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None + ) as mock_slack_message, + ): + result = await poll_uniprot_mapping_jobs_for_score_set( + standalone_worker_context, + {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, + score_set.id, + uuid4().hex, + ) + mock_slack_message.assert_called() + + assert result["success"] + assert not result["retried"] + assert not result["enqueued_jobs"] + + +@pytest.mark.asyncio +async def test_poll_uniprot_id_mapping_jobs_not_ready( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with ( + patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=False), + patch( + "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None + ) as mock_slack_message, + ): + result = await poll_uniprot_mapping_jobs_for_score_set( + standalone_worker_context, + {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, + score_set.id, + uuid4().hex, + ) + mock_slack_message.assert_called() + + assert result["success"] + assert not result["retried"] + assert not result["enqueued_jobs"] + + score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() + for target_gene in score_set.target_genes: + assert target_gene.uniprot_id_from_mapped_metadata is None + + +@pytest.mark.asyncio +async def test_poll_uniprot_id_mapping_no_jobs( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + # This case does not get sent to slack + result = await poll_uniprot_mapping_jobs_for_score_set( + standalone_worker_context, + {}, + score_set.id, + uuid4().hex, + ) + + assert result["success"] + assert not result["retried"] + assert not result["enqueued_jobs"] + + score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() + for target_gene in score_set.target_genes: + assert target_gene.uniprot_id_from_mapped_metadata is None + + +@pytest.mark.asyncio +async def test_poll_uniprot_id_mapping_no_ids_mapped( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with ( + patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=True), + patch.object(UniProtIDMappingAPI, "get_id_mapping_results", return_value={"failedIDs": [VALID_CHR_ACCESSION]}), + patch( + "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None + ) as mock_slack_message, + ): + result = await poll_uniprot_mapping_jobs_for_score_set( + standalone_worker_context, + {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, + score_set.id, + uuid4().hex, + ) + mock_slack_message.assert_called() + + assert result["success"] + assert not result["retried"] + assert not result["enqueued_jobs"] + + score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() + for target_gene in score_set.target_genes: + assert target_gene.uniprot_id_from_mapped_metadata is None + + +@pytest.mark.asyncio +async def test_poll_uniprot_id_mapping_too_many_mapped_accessions( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + # Simulate a response with too many mapped IDs + too_many_mapped_ids_response = TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE.copy() + too_many_mapped_ids_response["results"].append( + {"from": "AC3", "to": {"primaryAccession": "AC3", "entryType": TEST_UNIPROT_SWISS_PROT_TYPE}} + ) + + with ( + patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=True), + patch.object(UniProtIDMappingAPI, "get_id_mapping_results", return_value=too_many_mapped_ids_response), + patch( + "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None + ) as mock_slack_message, + ): + result = await poll_uniprot_mapping_jobs_for_score_set( + standalone_worker_context, + {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, + score_set.id, + uuid4().hex, + ) + mock_slack_message.assert_called() + + assert result["success"] + assert not result["retried"] + assert not result["enqueued_jobs"] + + +@pytest.mark.asyncio +async def test_poll_uniprot_id_mapping_error_in_setup( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with ( + patch("mavedb.worker.jobs.external_services.uniprot.setup_job_state", side_effect=Exception()), + patch( + "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None + ) as mock_slack_message, + ): + result = await poll_uniprot_mapping_jobs_for_score_set( + standalone_worker_context, + {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, + score_set.id, + uuid4().hex, + ) + mock_slack_message.assert_called_once() + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_jobs"] + + +@pytest.mark.asyncio +async def test_poll_uniprot_id_mapping_exception_during_polling( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with ( + patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", side_effect=Exception()), + patch( + "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None + ) as mock_slack_message, + ): + result = await poll_uniprot_mapping_jobs_for_score_set( + standalone_worker_context, + {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, + score_set.id, + uuid4().hex, + ) + mock_slack_message.assert_called_once() + + assert not result["success"] + assert not result["retried"] + assert not result["enqueued_jobs"] diff --git a/tests/worker/jobs/variant_processing/test_creation.py b/tests/worker/jobs/variant_processing/test_creation.py new file mode 100644 index 00000000..b5addb76 --- /dev/null +++ b/tests/worker/jobs/variant_processing/test_creation.py @@ -0,0 +1,557 @@ +# ruff: noqa: E402 + +from asyncio.unix_events import _UnixSelectorEventLoop +from unittest.mock import patch +from uuid import uuid4 + +import pandas as pd +import pytest +from sqlalchemy import select + +arq = pytest.importorskip("arq") +cdot = pytest.importorskip("cdot") + +from mavedb.lib.clingen.services import ( + ClinGenLdhService, +) +from mavedb.lib.mave.constants import HGVS_NT_COLUMN +from mavedb.lib.validation.exceptions import ValidationError +from mavedb.models.enums.mapping_state import MappingState +from mavedb.models.enums.processing_state import ProcessingState +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.score_set import ScoreSet as ScoreSetDbModel +from mavedb.models.variant import Variant +from mavedb.worker.jobs import ( + create_variants_for_score_set, +) +from mavedb.worker.jobs.utils.constants import MAPPING_CURRENT_ID_NAME, MAPPING_QUEUE_NAME +from tests.helpers.constants import ( + TEST_CLINGEN_ALLELE_OBJECT, + TEST_CLINGEN_LDH_LINKING_RESPONSE, + TEST_CLINGEN_SUBMISSION_RESPONSE, + TEST_MINIMAL_ACC_SCORESET, + TEST_MINIMAL_MULTI_TARGET_SCORESET, + TEST_MINIMAL_SEQ_SCORESET, + TEST_NT_CDOT_TRANSCRIPT, + VALID_NT_ACCESSION, +) +from tests.helpers.util.mapping import sanitize_mapping_queue +from tests.helpers.util.setup.worker import setup_mapping_output, setup_records_and_files + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "input_score_set,validation_error", + [ + ( + TEST_MINIMAL_SEQ_SCORESET, + { + "exception": "encountered 1 invalid variant strings.", + "detail": ["target sequence mismatch for 'c.1T>A' at row 0 for sequence TEST1"], + }, + ), + ( + TEST_MINIMAL_ACC_SCORESET, + { + "exception": "encountered 1 invalid variant strings.", + "detail": [ + "Failed to parse row 0 with HGVS exception: NM_001637.3:c.1T>A: Variant reference (T) does not agree with reference sequence (G)." + ], + }, + ), + ( + TEST_MINIMAL_MULTI_TARGET_SCORESET, + { + "exception": "encountered 1 invalid variant strings.", + "detail": ["target sequence mismatch for 'n.1T>A' at row 0 for sequence TEST3"], + }, + ), + ], +) +async def test_create_variants_for_score_set_with_validation_error( + input_score_set, + validation_error, + setup_worker_db, + async_client, + standalone_worker_context, + session, + data_files, +): + score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( + async_client, data_files, input_score_set + ) + score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() + + if input_score_set == TEST_MINIMAL_SEQ_SCORESET: + scores.loc[:, HGVS_NT_COLUMN].iloc[0] = "c.1T>A" + elif input_score_set == TEST_MINIMAL_ACC_SCORESET: + scores.loc[:, HGVS_NT_COLUMN].iloc[0] = f"{VALID_NT_ACCESSION}:c.1T>A" + elif input_score_set == TEST_MINIMAL_MULTI_TARGET_SCORESET: + scores.loc[:, HGVS_NT_COLUMN].iloc[0] = "TEST3:n.1T>A" + + with ( + patch.object( + cdot.hgvs.dataproviders.RESTDataProvider, + "_get_transcript", + return_value=TEST_NT_CDOT_TRANSCRIPT, + ) as hdp, + ): + result = await create_variants_for_score_set( + standalone_worker_context, + uuid4().hex, + score_set.id, + 1, + scores, + counts, + score_columns_metadata, + count_columns_metadata, + ) + + # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. + if all(["targetSequence" in target for target in input_score_set["targetGenes"]]): + hdp.assert_not_called() + else: + hdp.assert_called_once() + + db_variants = session.scalars(select(Variant)).all() + + score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() + assert score_set.num_variants == 0 + assert len(db_variants) == 0 + assert score_set.processing_state == ProcessingState.failed + assert score_set.processing_errors == validation_error + assert not result["success"] + assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) +) +async def test_create_variants_for_score_set_with_caught_exception( + input_score_set, + setup_worker_db, + async_client, + standalone_worker_context, + session, + data_files, +): + score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( + async_client, data_files, input_score_set + ) + score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() + + # This is somewhat dumb and wouldn't actually happen like this, but it serves as an effective way to guarantee + # some exception will be raised no matter what in the async job. + with ( + patch.object(pd.DataFrame, "isnull", side_effect=Exception) as mocked_exc, + ): + result = await create_variants_for_score_set( + standalone_worker_context, + uuid4().hex, + score_set.id, + 1, + scores, + counts, + score_columns_metadata, + count_columns_metadata, + ) + mocked_exc.assert_called() + + db_variants = session.scalars(select(Variant)).all() + score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() + + assert score_set.num_variants == 0 + assert len(db_variants) == 0 + assert score_set.processing_state == ProcessingState.failed + assert score_set.processing_errors == {"detail": [], "exception": ""} + assert not result["success"] + assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) +) +async def test_create_variants_for_score_set_with_caught_base_exception( + input_score_set, + setup_worker_db, + async_client, + standalone_worker_context, + session, + data_files, +): + score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( + async_client, data_files, input_score_set + ) + score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() + + # This is somewhat (extra) dumb and wouldn't actually happen like this, but it serves as an effective way to guarantee + # some base exception will be handled no matter what in the async job. + with ( + patch.object(pd.DataFrame, "isnull", side_effect=BaseException), + ): + result = await create_variants_for_score_set( + standalone_worker_context, + uuid4().hex, + score_set.id, + 1, + scores, + counts, + score_columns_metadata, + count_columns_metadata, + ) + + db_variants = session.scalars(select(Variant)).all() + score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() + + assert score_set.num_variants == 0 + assert len(db_variants) == 0 + assert score_set.processing_state == ProcessingState.failed + assert score_set.processing_errors is None + assert not result["success"] + assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) +) +async def test_create_variants_for_score_set_with_existing_variants( + input_score_set, + setup_worker_db, + async_client, + standalone_worker_context, + session, + data_files, +): + score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( + async_client, data_files, input_score_set + ) + score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() + + with patch.object( + cdot.hgvs.dataproviders.RESTDataProvider, + "_get_transcript", + return_value=TEST_NT_CDOT_TRANSCRIPT, + ) as hdp: + result = await create_variants_for_score_set( + standalone_worker_context, + uuid4().hex, + score_set.id, + 1, + scores, + counts, + score_columns_metadata, + count_columns_metadata, + ) + + # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. + if all(["targetSequence" in target for target in input_score_set["targetGenes"]]): + hdp.assert_not_called() + else: + hdp.assert_called_once() + + await sanitize_mapping_queue(standalone_worker_context, score_set) + db_variants = session.scalars(select(Variant)).all() + score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() + + assert score_set.num_variants == 3 + assert len(db_variants) == 3 + assert score_set.processing_state == ProcessingState.success + + with patch.object( + cdot.hgvs.dataproviders.RESTDataProvider, + "_get_transcript", + return_value=TEST_NT_CDOT_TRANSCRIPT, + ) as hdp: + result = await create_variants_for_score_set( + standalone_worker_context, + uuid4().hex, + score_set.id, + 1, + scores, + counts, + score_columns_metadata, + count_columns_metadata, + ) + + db_variants = session.scalars(select(Variant)).all() + score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() + + assert score_set.num_variants == 3 + assert len(db_variants) == 3 + assert score_set.processing_state == ProcessingState.success + assert score_set.processing_errors is None + assert result["success"] + assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) +) +async def test_create_variants_for_score_set_with_existing_exceptions( + input_score_set, + setup_worker_db, + async_client, + standalone_worker_context, + session, + data_files, +): + score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( + async_client, data_files, input_score_set + ) + score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() + + # This is somewhat dumb and wouldn't actually happen like this, but it serves as an effective way to guarantee + # some exception will be raised no matter what in the async job. + with ( + patch.object( + pd.DataFrame, + "isnull", + side_effect=ValidationError("Test Exception", triggers=["exc_1", "exc_2"]), + ) as mocked_exc, + ): + result = await create_variants_for_score_set( + standalone_worker_context, + uuid4().hex, + score_set.id, + 1, + scores, + counts, + score_columns_metadata, + count_columns_metadata, + ) + mocked_exc.assert_called() + + db_variants = session.scalars(select(Variant)).all() + score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() + + assert score_set.num_variants == 0 + assert len(db_variants) == 0 + assert score_set.processing_state == ProcessingState.failed + assert score_set.processing_errors == { + "exception": "Test Exception", + "detail": ["exc_1", "exc_2"], + } + + with patch.object( + cdot.hgvs.dataproviders.RESTDataProvider, + "_get_transcript", + return_value=TEST_NT_CDOT_TRANSCRIPT, + ) as hdp: + result = await create_variants_for_score_set( + standalone_worker_context, + uuid4().hex, + score_set.id, + 1, + scores, + counts, + score_columns_metadata, + count_columns_metadata, + ) + + # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. + if all(["targetSequence" in target for target in input_score_set["targetGenes"]]): + hdp.assert_not_called() + else: + hdp.assert_called_once() + + db_variants = session.scalars(select(Variant)).all() + score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() + + assert score_set.num_variants == 3 + assert len(db_variants) == 3 + assert score_set.processing_state == ProcessingState.success + assert score_set.processing_errors is None + assert result["success"] + assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) +) +async def test_create_variants_for_score_set( + input_score_set, + setup_worker_db, + async_client, + standalone_worker_context, + session, + data_files, +): + score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( + async_client, data_files, input_score_set + ) + score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() + + with patch.object( + cdot.hgvs.dataproviders.RESTDataProvider, + "_get_transcript", + return_value=TEST_NT_CDOT_TRANSCRIPT, + ) as hdp: + result = await create_variants_for_score_set( + standalone_worker_context, + uuid4().hex, + score_set.id, + 1, + scores, + counts, + score_columns_metadata, + count_columns_metadata, + ) + + # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. + if all(["targetSequence" in target for target in input_score_set["targetGenes"]]): + hdp.assert_not_called() + else: + hdp.assert_called_once() + + db_variants = session.scalars(select(Variant)).all() + score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() + + assert score_set.num_variants == 3 + assert len(db_variants) == 3 + assert score_set.processing_state == ProcessingState.success + assert result["success"] + assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) +) +async def test_create_variants_for_score_set_enqueues_manager_and_successful_mapping( + input_score_set, + setup_worker_db, + session, + async_client, + data_files, + arq_worker, + arq_redis, +): + score_set_is_seq = all(["targetSequence" in target for target in input_score_set["targetGenes"]]) + score_set_is_multi_target = len(input_score_set["targetGenes"]) > 1 + score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( + async_client, data_files, input_score_set + ) + score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() + + async def dummy_mapping_job(): + return await setup_mapping_output(async_client, session, score_set, score_set_is_seq, score_set_is_multi_target) + + async def dummy_car_submission_job(): + return TEST_CLINGEN_ALLELE_OBJECT + + async def dummy_ldh_submission_job(): + return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] + + # Variants have not yet been created, so infer their URNs. + async def dummy_linking_job(): + return [(f"{score_set_urn}#{i}", TEST_CLINGEN_LDH_LINKING_RESPONSE) for i in range(1, len(scores) + 1)] + + with ( + patch.object( + cdot.hgvs.dataproviders.RESTDataProvider, + "_get_transcript", + return_value=TEST_NT_CDOT_TRANSCRIPT, + ) as hdp, + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + side_effect=[ + dummy_mapping_job(), + dummy_car_submission_job(), + dummy_ldh_submission_job(), + dummy_linking_job(), + ], + ), + patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), + patch("mavedb.worker.jobs.variant_processing.mapping.MAPPING_BACKOFF_IN_SECONDS", 0), + patch("mavedb.worker.jobs.external_services.clingen.LINKING_BACKOFF_IN_SECONDS", 0), + patch("mavedb.worker.jobs.variant_processing.mapping.CLIN_GEN_SUBMISSION_ENABLED", True), + ): + await arq_redis.enqueue_job( + "create_variants_for_score_set", + uuid4().hex, + score_set.id, + 1, + scores, + counts, + score_columns_metadata, + count_columns_metadata, + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. + if score_set_is_seq: + hdp.assert_not_called() + else: + hdp.assert_called_once() + + db_variants = session.scalars(select(Variant)).all() + score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() + mapped_variants_for_score_set = session.scalars( + select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) + ).all() + + assert score_set.num_variants == 3 + assert len(db_variants) == 3 + assert score_set.processing_state == ProcessingState.success + assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 + assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" + assert len(mapped_variants_for_score_set) == score_set.num_variants + assert score_set.mapping_state == MappingState.complete + assert score_set.mapping_errors is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) +) +async def test_create_variants_for_score_set_exception_skips_mapping( + input_score_set, + setup_worker_db, + session, + async_client, + data_files, + arq_worker, + arq_redis, +): + score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( + async_client, data_files, input_score_set + ) + score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() + + with patch.object(pd.DataFrame, "isnull", side_effect=Exception) as mocked_exc: + await arq_redis.enqueue_job( + "create_variants_for_score_set", + uuid4().hex, + score_set.id, + 1, + scores, + counts, + score_columns_metadata, + count_columns_metadata, + ) + await arq_worker.async_run() + await arq_worker.run_check() + + mocked_exc.assert_called() + + db_variants = session.scalars(select(Variant)).all() + score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() + mapped_variants_for_score_set = session.scalars( + select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) + ).all() + + assert score_set.num_variants == 0 + assert len(db_variants) == 0 + assert score_set.processing_state == ProcessingState.failed + assert score_set.processing_errors == {"detail": [], "exception": ""} + assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 + assert len(mapped_variants_for_score_set) == 0 + assert score_set.mapping_state == MappingState.not_attempted + assert score_set.mapping_errors is None diff --git a/tests/worker/jobs/variant_processing/test_mapping.py b/tests/worker/jobs/variant_processing/test_mapping.py new file mode 100644 index 00000000..9606e2e0 --- /dev/null +++ b/tests/worker/jobs/variant_processing/test_mapping.py @@ -0,0 +1,710 @@ +# ruff: noqa: E402 + +from asyncio.unix_events import _UnixSelectorEventLoop +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy import select + +arq = pytest.importorskip("arq") + +from mavedb.lib.clingen.services import ( + ClinGenAlleleRegistryService, + ClinGenLdhService, +) +from mavedb.lib.uniprot.id_mapping import UniProtIDMappingAPI +from mavedb.models.enums.mapping_state import MappingState +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.score_set import ScoreSet as ScoreSetDbModel +from mavedb.models.variant import Variant +from mavedb.worker.jobs import ( + variant_mapper_manager, +) +from mavedb.worker.jobs.utils.constants import MAPPING_CURRENT_ID_NAME, MAPPING_QUEUE_NAME +from tests.helpers.constants import ( + TEST_CLINGEN_ALLELE_OBJECT, + TEST_CLINGEN_LDH_LINKING_RESPONSE, + TEST_CLINGEN_SUBMISSION_RESPONSE, + TEST_GNOMAD_DATA_VERSION, + TEST_MINIMAL_SEQ_SCORESET, + TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, + TEST_UNIPROT_JOB_SUBMISSION_RESPONSE, +) +from tests.helpers.util.exceptions import awaitable_exception +from tests.helpers.util.setup.worker import setup_mapping_output, setup_records_files_and_variants + + +@pytest.mark.asyncio +async def test_mapping_manager_empty_queue(setup_worker_db, standalone_worker_context): + result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) + + # No new jobs should have been created if nothing is in the queue, and the queue should remain empty. + assert result["enqueued_job"] is None + assert result["success"] + assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 + assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" + + +@pytest.mark.asyncio +async def test_mapping_manager_empty_queue_error_during_setup(setup_worker_db, standalone_worker_context): + await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "") + with patch.object(arq.ArqRedis, "rpop", Exception()): + result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) + + # No new jobs should have been created if nothing is in the queue, and the queue should remain empty. + assert result["enqueued_job"] is None + assert not result["success"] + assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 + assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" + + +@pytest.mark.asyncio +async def test_mapping_manager_occupied_queue_mapping_in_progress( + setup_worker_db, standalone_worker_context, session, async_client, data_files +): + score_set = await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "5") + with patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.in_progress): + result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) + + # Execution should be deferred if a job is in progress, and the queue should contain one entry which is the deferred ID. + assert result["enqueued_job"] is not None + assert ( + await arq.jobs.Job(result["enqueued_job"], standalone_worker_context["redis"]).status() + ) == arq.jobs.JobStatus.deferred + assert result["success"] + assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 1 + assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set.id) + assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "5" + assert score_set.mapping_state == MappingState.queued + assert score_set.mapping_errors is None + + +@pytest.mark.asyncio +async def test_mapping_manager_occupied_queue_mapping_not_in_progress( + setup_worker_db, standalone_worker_context, session, async_client, data_files +): + score_set = await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "") + with patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.not_found): + result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) + + # Mapping job should be queued if none is currently running, and the queue should now be empty. + assert result["enqueued_job"] is not None + assert ( + await arq.jobs.Job(result["enqueued_job"], standalone_worker_context["redis"]).status() + ) == arq.jobs.JobStatus.queued + assert result["success"] + assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 + # We don't actually start processing these score sets. + assert score_set.mapping_state == MappingState.queued + assert score_set.mapping_errors is None + + +@pytest.mark.asyncio +async def test_mapping_manager_occupied_queue_mapping_in_progress_error_during_enqueue( + setup_worker_db, standalone_worker_context, session, async_client, data_files +): + score_set = await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "5") + with ( + patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.in_progress), + patch.object(arq.ArqRedis, "enqueue_job", return_value=awaitable_exception()), + ): + result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) + + # Execution should be deferred if a job is in progress, and the queue should contain one entry which is the deferred ID. + assert result["enqueued_job"] is None + assert not result["success"] + assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 + assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "5" + assert score_set.mapping_state == MappingState.failed + assert score_set.mapping_errors is not None + + +@pytest.mark.asyncio +async def test_mapping_manager_occupied_queue_mapping_not_in_progress_error_during_enqueue( + setup_worker_db, standalone_worker_context, session, async_client, data_files +): + score_set = await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "") + with ( + patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.not_found), + patch.object(arq.ArqRedis, "enqueue_job", return_value=awaitable_exception()), + ): + result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) + + # Enqueue would have failed, the job is unsuccessful, and we remove the queued item. + assert result["enqueued_job"] is None + assert not result["success"] + assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 + assert score_set.mapping_state == MappingState.failed + assert score_set.mapping_errors is not None + + +@pytest.mark.asyncio +async def test_mapping_manager_multiple_score_sets_occupy_queue_mapping_in_progress( + setup_worker_db, standalone_worker_context, session, async_client, data_files +): + score_set_id_1 = ( + await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + ).id + score_set_id_2 = ( + await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + ).id + score_set_id_3 = ( + await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + ).id + + await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "5") + with patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.in_progress): + result1 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) + result2 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) + result3 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) + + # All three jobs should complete successfully... + assert result1["success"] + assert result2["success"] + assert result3["success"] + + # ...with a new job enqueued... + assert result1["enqueued_job"] is not None + assert result2["enqueued_job"] is not None + assert result3["enqueued_job"] is not None + + # ...of which all should be deferred jobs of the "variant_mapper_manager" variety... + assert ( + await arq.jobs.Job(result1["enqueued_job"], standalone_worker_context["redis"]).status() + ) == arq.jobs.JobStatus.deferred + assert ( + await arq.jobs.Job(result2["enqueued_job"], standalone_worker_context["redis"]).status() + ) == arq.jobs.JobStatus.deferred + assert ( + await arq.jobs.Job(result3["enqueued_job"], standalone_worker_context["redis"]).status() + ) == arq.jobs.JobStatus.deferred + + assert ( + await arq.jobs.Job(result1["enqueued_job"], standalone_worker_context["redis"]).info() + ).function == "variant_mapper_manager" + assert ( + await arq.jobs.Job(result2["enqueued_job"], standalone_worker_context["redis"]).info() + ).function == "variant_mapper_manager" + assert ( + await arq.jobs.Job(result3["enqueued_job"], standalone_worker_context["redis"]).info() + ).function == "variant_mapper_manager" + + # ...and the queue state should have three jobs, each of our three created score sets. + assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 3 + assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set_id_1) + assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set_id_2) + assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set_id_3) + + score_set1 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_1)).one() + score_set2 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_2)).one() + score_set3 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_3)).one() + # Each score set should remain queued with no mapping errors. + assert score_set1.mapping_state == MappingState.queued + assert score_set2.mapping_state == MappingState.queued + assert score_set3.mapping_state == MappingState.queued + assert score_set1.mapping_errors is None + assert score_set2.mapping_errors is None + assert score_set3.mapping_errors is None + + +@pytest.mark.asyncio +async def test_mapping_manager_multiple_score_sets_occupy_queue_mapping_not_in_progress( + setup_worker_db, standalone_worker_context, session, async_client, data_files +): + score_set_id_1 = ( + await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + ).id + score_set_id_2 = ( + await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + ).id + score_set_id_3 = ( + await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + ).id + + await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "") + with patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.not_found): + result1 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) + + # Mock the first job being in-progress + await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, str(score_set_id_1)) + with patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.in_progress): + result2 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) + result3 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) + + # All three jobs should complete successfully... + assert result1["success"] + assert result2["success"] + assert result3["success"] + + # ...with a new job enqueued... + assert result1["enqueued_job"] is not None + assert result2["enqueued_job"] is not None + assert result3["enqueued_job"] is not None + + # ...of which the first should be a queued job of the "map_variants_for_score_set" variety and the other two should be + # deferred jobs of the "variant_mapper_manager" variety... + assert ( + await arq.jobs.Job(result1["enqueued_job"], standalone_worker_context["redis"]).status() + ) == arq.jobs.JobStatus.queued + assert ( + await arq.jobs.Job(result2["enqueued_job"], standalone_worker_context["redis"]).status() + ) == arq.jobs.JobStatus.deferred + assert ( + await arq.jobs.Job(result3["enqueued_job"], standalone_worker_context["redis"]).status() + ) == arq.jobs.JobStatus.deferred + + assert ( + await arq.jobs.Job(result1["enqueued_job"], standalone_worker_context["redis"]).info() + ).function == "map_variants_for_score_set" + assert ( + await arq.jobs.Job(result2["enqueued_job"], standalone_worker_context["redis"]).info() + ).function == "variant_mapper_manager" + assert ( + await arq.jobs.Job(result3["enqueued_job"], standalone_worker_context["redis"]).info() + ).function == "variant_mapper_manager" + + # ...and the queue state should have two jobs, neither of which should be the first score set. + assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 2 + assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set_id_2) + assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set_id_3) + + score_set1 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_1)).one() + score_set2 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_2)).one() + score_set3 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_3)).one() + # We don't actually process any score sets in the manager job, and each should have no mapping errors. + assert score_set1.mapping_state == MappingState.queued + assert score_set2.mapping_state == MappingState.queued + assert score_set3.mapping_state == MappingState.queued + assert score_set1.mapping_errors is None + assert score_set2.mapping_errors is None + assert score_set3.mapping_errors is None + + +@pytest.mark.asyncio +async def test_mapping_manager_enqueues_mapping_process_with_successful_mapping( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + async def dummy_mapping_job(): + return await setup_mapping_output(async_client, session, score_set) + + async def dummy_ldh_submission_job(): + return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] + + async def dummy_linking_job(): + return [ + (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE) + for variant_urn in session.scalars( + select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) + ).all() + ] + + # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround + # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine + # object that sets up test mapping output. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + side_effect=[dummy_mapping_job(), dummy_ldh_submission_job(), dummy_linking_job()], + ), + patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]), + patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), + patch.object(UniProtIDMappingAPI, "submit_id_mapping", return_value=TEST_UNIPROT_JOB_SUBMISSION_RESPONSE), + patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=True), + patch.object( + UniProtIDMappingAPI, "get_id_mapping_results", return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE + ), + patch("mavedb.worker.jobs.variant_processing.mapping.MAPPING_BACKOFF_IN_SECONDS", 0), + patch("mavedb.worker.jobs.external_services.clingen.LINKING_BACKOFF_IN_SECONDS", 0), + patch("mavedb.worker.jobs.variant_processing.mapping.UNIPROT_ID_MAPPING_ENABLED", True), + patch("mavedb.worker.jobs.variant_processing.mapping.CLIN_GEN_SUBMISSION_ENABLED", True), + patch( + "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", + "https://reg.test.genome.network/pytest", + ), + patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_NAME", "testuser"), + patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_PASSWORD", "testpassword"), + patch("mavedb.lib.gnomad.GNOMAD_DATA_VERSION", TEST_GNOMAD_DATA_VERSION), + patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]), + ): + await arq_worker.async_run() + num_completed_jobs = await arq_worker.run_check() + + # We should have completed all jobs exactly once. + assert num_completed_jobs == 8 + + score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() + mapped_variants_for_score_set = session.scalars( + select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) + ).all() + assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 + assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" + assert len(mapped_variants_for_score_set) == score_set.num_variants + assert score_set.mapping_state == MappingState.complete + assert score_set.mapping_errors is None + + +@pytest.mark.asyncio +async def test_mapping_manager_enqueues_mapping_process_with_successful_mapping_linking_disabled_uniprot_disabled( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + async def dummy_mapping_job(): + return await setup_mapping_output(async_client, session, score_set) + + # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround + # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine + # object that sets up test mapping output. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + side_effect=[dummy_mapping_job()], + ), + patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), + patch("mavedb.worker.jobs.variant_processing.mapping.MAPPING_BACKOFF_IN_SECONDS", 0), + patch("mavedb.worker.jobs.external_services.clingen.LINKING_BACKOFF_IN_SECONDS", 0), + patch("mavedb.worker.jobs.variant_processing.mapping.UNIPROT_ID_MAPPING_ENABLED", False), + patch("mavedb.worker.jobs.variant_processing.mapping.CLIN_GEN_SUBMISSION_ENABLED", False), + ): + await arq_worker.async_run() + num_completed_jobs = await arq_worker.run_check() + + # We should have completed the manager and mapping jobs, but not the submission, linking, or uniprot mapping jobs. + assert num_completed_jobs == 2 + + score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() + mapped_variants_for_score_set = session.scalars( + select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) + ).all() + assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 + assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" + assert len(mapped_variants_for_score_set) == score_set.num_variants + assert score_set.mapping_state == MappingState.complete + assert score_set.mapping_errors is None + + +@pytest.mark.asyncio +async def test_mapping_manager_enqueues_mapping_process_with_successful_mapping_linking_disabled_uniprot_enabled( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + async def dummy_mapping_job(): + return await setup_mapping_output(async_client, session, score_set) + + # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround + # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine + # object that sets up test mapping output. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + side_effect=[dummy_mapping_job()], + ), + patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), + patch.object(UniProtIDMappingAPI, "submit_id_mapping", return_value=TEST_UNIPROT_JOB_SUBMISSION_RESPONSE), + patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=True), + patch.object( + UniProtIDMappingAPI, "get_id_mapping_results", return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE + ), + patch("mavedb.worker.jobs.variant_processing.mapping.MAPPING_BACKOFF_IN_SECONDS", 0), + patch("mavedb.worker.jobs.external_services.clingen.LINKING_BACKOFF_IN_SECONDS", 0), + patch("mavedb.worker.jobs.variant_processing.mapping.UNIPROT_ID_MAPPING_ENABLED", True), + patch("mavedb.worker.jobs.variant_processing.mapping.CLIN_GEN_SUBMISSION_ENABLED", False), + ): + await arq_worker.async_run() + num_completed_jobs = await arq_worker.run_check() + + # We should have completed the manager, mapping, and uniprot jobs, but not the submission or linking jobs. + assert num_completed_jobs == 4 + + score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() + mapped_variants_for_score_set = session.scalars( + select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) + ).all() + assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 + assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" + assert len(mapped_variants_for_score_set) == score_set.num_variants + assert score_set.mapping_state == MappingState.complete + assert score_set.mapping_errors is None + + +@pytest.mark.asyncio +async def test_mapping_manager_enqueues_mapping_process_with_successful_mapping_linking_enabled_uniprot_disabled( + setup_worker_db, + standalone_worker_context, + session, + async_client, + data_files, + arq_worker, + arq_redis, + mocked_gnomad_variant_row, +): + score_set = await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + async def dummy_mapping_job(): + return await setup_mapping_output(async_client, session, score_set) + + async def dummy_submission_job(): + return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] + + async def dummy_linking_job(): + return [ + (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE) + for variant_urn in session.scalars( + select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) + ).all() + ] + + # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround + # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine + # object that sets up test mapping output. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + side_effect=[dummy_mapping_job(), dummy_submission_job(), dummy_linking_job()], + ), + patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), + patch("mavedb.worker.jobs.variant_processing.mapping.MAPPING_BACKOFF_IN_SECONDS", 0), + patch("mavedb.worker.jobs.external_services.clingen.LINKING_BACKOFF_IN_SECONDS", 0), + patch("mavedb.worker.jobs.variant_processing.mapping.UNIPROT_ID_MAPPING_ENABLED", False), + patch("mavedb.worker.jobs.variant_processing.mapping.CLIN_GEN_SUBMISSION_ENABLED", True), + patch( + "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", + "https://reg.test.genome.network/pytest", + ), + patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_NAME", "testuser"), + patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_PASSWORD", "testpassword"), + patch("mavedb.lib.gnomad.GNOMAD_DATA_VERSION", TEST_GNOMAD_DATA_VERSION), + patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]), + patch( + "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", + return_value=[mocked_gnomad_variant_row], + ), + ): + await arq_worker.async_run() + num_completed_jobs = await arq_worker.run_check() + + # We should have completed the manager, mapping, submission, and linking jobs, but not the uniprot jobs. + assert num_completed_jobs == 6 + + score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() + mapped_variants_for_score_set = session.scalars( + select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) + ).all() + assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 + assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" + assert len(mapped_variants_for_score_set) == score_set.num_variants + assert score_set.mapping_state == MappingState.complete + assert score_set.mapping_errors is None + + +@pytest.mark.asyncio +async def test_mapping_manager_enqueues_mapping_process_with_retried_mapping_successful_mapping_on_retry( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + async def failed_mapping_job(): + return Exception() + + async def dummy_mapping_job(): + return await setup_mapping_output(async_client, session, score_set) + + async def dummy_ldh_submission_job(): + return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] + + async def dummy_linking_job(): + return [ + (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE) + for variant_urn in session.scalars( + select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) + ).all() + ] + + # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround + # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine + # object that sets up test mapping output. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + side_effect=[failed_mapping_job(), dummy_mapping_job(), dummy_ldh_submission_job(), dummy_linking_job()], + ), + patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), + patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]), + patch("mavedb.worker.jobs.variant_processing.mapping.MAPPING_BACKOFF_IN_SECONDS", 0), + patch("mavedb.worker.jobs.external_services.clingen.LINKING_BACKOFF_IN_SECONDS", 0), + patch("mavedb.worker.jobs.variant_processing.mapping.UNIPROT_ID_MAPPING_ENABLED", False), + patch("mavedb.worker.jobs.variant_processing.mapping.CLIN_GEN_SUBMISSION_ENABLED", True), + patch( + "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", + "https://reg.test.genome.network/pytest", + ), + patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_NAME", "testuser"), + patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_PASSWORD", "testpassword"), + patch("mavedb.lib.gnomad.GNOMAD_DATA_VERSION", TEST_GNOMAD_DATA_VERSION), + patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]), + ): + await arq_worker.async_run() + num_completed_jobs = await arq_worker.run_check() + + # We should have completed the mapping manager job twice, the mapping job twice, the two submission jobs, and both linking jobs. + assert num_completed_jobs == 8 + + score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() + mapped_variants_for_score_set = session.scalars( + select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) + ).all() + assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 + assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" + assert len(mapped_variants_for_score_set) == score_set.num_variants + assert score_set.mapping_state == MappingState.complete + assert score_set.mapping_errors is None + + +@pytest.mark.asyncio +async def test_mapping_manager_enqueues_mapping_process_with_unsuccessful_mapping( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + score_set = await setup_records_files_and_variants( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + async def failed_mapping_job(): + return Exception() + + # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround + # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine + # object that sets up test mapping output. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + side_effect=[failed_mapping_job()] * 5, + ), + patch("mavedb.worker.jobs.variant_processing.mapping.MAPPING_BACKOFF_IN_SECONDS", 0), + ): + await arq_worker.async_run() + num_completed_jobs = await arq_worker.run_check() + + # We should have completed 6 mapping jobs and 6 management jobs. + assert num_completed_jobs == 12 + + score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() + mapped_variants_for_score_set = session.scalars( + select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) + ).all() + assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 + assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" + assert len(mapped_variants_for_score_set) == 0 + assert score_set.mapping_state == MappingState.failed + assert score_set.mapping_errors is not None diff --git a/tests/worker/test_jobs.py b/tests/worker/test_jobs.py deleted file mode 100644 index e7fd0b39..00000000 --- a/tests/worker/test_jobs.py +++ /dev/null @@ -1,3479 +0,0 @@ -# ruff: noqa: E402 - -import json -from asyncio.unix_events import _UnixSelectorEventLoop -from copy import deepcopy -from datetime import date -from unittest.mock import patch -from uuid import uuid4 - -import jsonschema -import pandas as pd -import pytest -from requests import HTTPError -from sqlalchemy import not_, select - -arq = pytest.importorskip("arq") -cdot = pytest.importorskip("cdot") -fastapi = pytest.importorskip("fastapi") -pyathena = pytest.importorskip("pyathena") - -from mavedb.data_providers.services import VRSMap -from mavedb.lib.clingen.services import ( - ClinGenAlleleRegistryService, - ClinGenLdhService, - clingen_allele_id_from_ldh_variation, -) -from mavedb.lib.mave.constants import HGVS_NT_COLUMN -from mavedb.lib.score_sets import csv_data_to_df -from mavedb.lib.uniprot.id_mapping import UniProtIDMappingAPI -from mavedb.lib.validation.exceptions import ValidationError -from mavedb.models.enums.mapping_state import MappingState -from mavedb.models.enums.processing_state import ProcessingState -from mavedb.models.mapped_variant import MappedVariant -from mavedb.models.score_set import ScoreSet as ScoreSetDbModel -from mavedb.models.variant import Variant -from mavedb.view_models.experiment import Experiment, ExperimentCreate -from mavedb.view_models.score_set import ScoreSet, ScoreSetCreate -from mavedb.worker.jobs import ( - BACKOFF_LIMIT, - MAPPING_CURRENT_ID_NAME, - MAPPING_QUEUE_NAME, - create_variants_for_score_set, - link_clingen_variants, - link_gnomad_variants, - map_variants_for_score_set, - poll_uniprot_mapping_jobs_for_score_set, - submit_score_set_mappings_to_car, - submit_score_set_mappings_to_ldh, - submit_uniprot_mapping_jobs_for_score_set, - variant_mapper_manager, -) -from tests.helpers.constants import ( - TEST_ACC_SCORESET_VARIANT_MAPPING_SCAFFOLD, - TEST_CLINGEN_ALLELE_OBJECT, - TEST_CLINGEN_LDH_LINKING_RESPONSE, - TEST_CLINGEN_SUBMISSION_BAD_RESQUEST_RESPONSE, - TEST_CLINGEN_SUBMISSION_RESPONSE, - TEST_CLINGEN_SUBMISSION_UNAUTHORIZED_RESPONSE, - TEST_GNOMAD_DATA_VERSION, - TEST_MINIMAL_ACC_SCORESET, - TEST_MINIMAL_EXPERIMENT, - TEST_MINIMAL_MULTI_TARGET_SCORESET, - TEST_MINIMAL_SEQ_SCORESET, - TEST_MULTI_TARGET_SCORESET_VARIANT_MAPPING_SCAFFOLD, - TEST_NT_CDOT_TRANSCRIPT, - TEST_SEQ_SCORESET_VARIANT_MAPPING_SCAFFOLD, - TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, - TEST_UNIPROT_JOB_SUBMISSION_RESPONSE, - TEST_UNIPROT_SWISS_PROT_TYPE, - TEST_VALID_POST_MAPPED_VRS_ALLELE_VRS2_X, - TEST_VALID_PRE_MAPPED_VRS_ALLELE_VRS2_X, - VALID_CHR_ACCESSION, - VALID_CLINGEN_CA_ID, - VALID_NT_ACCESSION, - VALID_UNIPROT_ACCESSION, -) -from tests.helpers.util.exceptions import awaitable_exception -from tests.helpers.util.experiment import create_experiment -from tests.helpers.util.score_set import create_acc_score_set, create_multi_target_score_set, create_seq_score_set - - -@pytest.fixture -def populate_worker_db(data_files, client): - # create score set via API. In production, the API would invoke this worker job - experiment = create_experiment(client) - seq_score_set = create_seq_score_set(client, experiment["urn"]) - acc_score_set = create_acc_score_set(client, experiment["urn"]) - multi_target_score_set = create_multi_target_score_set(client, experiment["urn"]) - - return [seq_score_set["urn"], acc_score_set["urn"], multi_target_score_set["urn"]] - - -async def setup_records_and_files(async_client, data_files, input_score_set): - experiment_payload = deepcopy(TEST_MINIMAL_EXPERIMENT) - jsonschema.validate(instance=experiment_payload, schema=ExperimentCreate.model_json_schema()) - experiment_response = await async_client.post("/api/v1/experiments/", json=experiment_payload) - assert experiment_response.status_code == 200 - experiment = experiment_response.json() - jsonschema.validate(instance=experiment, schema=Experiment.model_json_schema()) - - score_set_payload = deepcopy(input_score_set) - score_set_payload["experimentUrn"] = experiment["urn"] - jsonschema.validate(instance=score_set_payload, schema=ScoreSetCreate.model_json_schema()) - score_set_response = await async_client.post("/api/v1/score-sets/", json=score_set_payload) - assert score_set_response.status_code == 200 - score_set = score_set_response.json() - jsonschema.validate(instance=score_set, schema=ScoreSet.model_json_schema()) - - scores_fp = ( - "scores_multi_target.csv" - if len(score_set["targetGenes"]) > 1 - else ("scores.csv" if "targetSequence" in score_set["targetGenes"][0] else "scores_acc.csv") - ) - counts_fp = ( - "counts_multi_target.csv" - if len(score_set["targetGenes"]) > 1 - else ("counts.csv" if "targetSequence" in score_set["targetGenes"][0] else "counts_acc.csv") - ) - with ( - open(data_files / scores_fp, "rb") as score_file, - open(data_files / counts_fp, "rb") as count_file, - open(data_files / "score_columns_metadata.json", "rb") as score_columns_file, - open(data_files / "count_columns_metadata.json", "rb") as count_columns_file, - ): - scores = csv_data_to_df(score_file) - counts = csv_data_to_df(count_file) - score_columns_metadata = json.load(score_columns_file) - count_columns_metadata = json.load(count_columns_file) - - return score_set["urn"], scores, counts, score_columns_metadata, count_columns_metadata - - -async def setup_records_files_and_variants(session, async_client, data_files, input_score_set, worker_ctx): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - # Patch CDOT `_get_transcript`, in the event this function is called on an accesssion based scoreset. - with patch.object( - cdot.hgvs.dataproviders.RESTDataProvider, - "_get_transcript", - return_value=TEST_NT_CDOT_TRANSCRIPT, - ): - result = await create_variants_for_score_set( - worker_ctx, uuid4().hex, score_set.id, 1, scores, counts, score_columns_metadata, count_columns_metadata - ) - - score_set_with_variants = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - assert result["success"] - assert score_set.processing_state is ProcessingState.success - assert score_set_with_variants.num_variants == 3 - - return score_set_with_variants - - -async def setup_records_files_and_variants_with_mapping( - session, async_client, data_files, input_score_set, standalone_worker_context -): - score_set = await setup_records_files_and_variants( - session, async_client, data_files, input_score_set, standalone_worker_context - ) - await sanitize_mapping_queue(standalone_worker_context, score_set) - - async def dummy_mapping_job(): - return await setup_mapping_output(async_client, session, score_set) - - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_mapping_job(), - ), - patch("mavedb.worker.jobs.CLIN_GEN_SUBMISSION_ENABLED", False), - ): - result = await map_variants_for_score_set(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - return session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - - -async def sanitize_mapping_queue(standalone_worker_context, score_set): - queued_job = await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME) - assert int(queued_job.decode("utf-8")) == score_set.id - - -async def setup_mapping_output( - async_client, session, score_set, score_set_is_seq_based=True, score_set_is_multi_target=False, empty=False -): - score_set_response = await async_client.get(f"/api/v1/score-sets/{score_set.urn}") - - if score_set_is_seq_based: - if score_set_is_multi_target: - # If this is a multi-target sequence based score set, use the scaffold for that. - mapping_output = deepcopy(TEST_MULTI_TARGET_SCORESET_VARIANT_MAPPING_SCAFFOLD) - else: - mapping_output = deepcopy(TEST_SEQ_SCORESET_VARIANT_MAPPING_SCAFFOLD) - else: - # there is not currently a multi-target accession-based score set test - mapping_output = deepcopy(TEST_ACC_SCORESET_VARIANT_MAPPING_SCAFFOLD) - mapping_output["metadata"] = score_set_response.json() - - if empty: - return mapping_output - - variants = session.scalars(select(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).all() - for variant in variants: - mapped_score = { - "pre_mapped": TEST_VALID_PRE_MAPPED_VRS_ALLELE_VRS2_X, - "post_mapped": TEST_VALID_POST_MAPPED_VRS_ALLELE_VRS2_X, - "mavedb_id": variant.urn, - } - - mapping_output["mapped_scores"].append(mapped_score) - - return mapping_output - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set,validation_error", - [ - ( - TEST_MINIMAL_SEQ_SCORESET, - { - "exception": "encountered 1 invalid variant strings.", - "detail": ["target sequence mismatch for 'c.1T>A' at row 0 for sequence TEST1"], - }, - ), - ( - TEST_MINIMAL_ACC_SCORESET, - { - "exception": "encountered 1 invalid variant strings.", - "detail": [ - "Failed to parse row 0 with HGVS exception: NM_001637.3:c.1T>A: Variant reference (T) does not agree with reference sequence (G)." - ], - }, - ), - ( - TEST_MINIMAL_MULTI_TARGET_SCORESET, - { - "exception": "encountered 1 invalid variant strings.", - "detail": ["target sequence mismatch for 'n.1T>A' at row 0 for sequence TEST3"], - }, - ), - ], -) -async def test_create_variants_for_score_set_with_validation_error( - input_score_set, - validation_error, - setup_worker_db, - async_client, - standalone_worker_context, - session, - data_files, -): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - if input_score_set == TEST_MINIMAL_SEQ_SCORESET: - scores.loc[:, HGVS_NT_COLUMN].iloc[0] = "c.1T>A" - elif input_score_set == TEST_MINIMAL_ACC_SCORESET: - scores.loc[:, HGVS_NT_COLUMN].iloc[0] = f"{VALID_NT_ACCESSION}:c.1T>A" - elif input_score_set == TEST_MINIMAL_MULTI_TARGET_SCORESET: - scores.loc[:, HGVS_NT_COLUMN].iloc[0] = "TEST3:n.1T>A" - - with ( - patch.object( - cdot.hgvs.dataproviders.RESTDataProvider, - "_get_transcript", - return_value=TEST_NT_CDOT_TRANSCRIPT, - ) as hdp, - ): - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - - # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. - if all(["targetSequence" in target for target in input_score_set["targetGenes"]]): - hdp.assert_not_called() - else: - hdp.assert_called_once() - - db_variants = session.scalars(select(Variant)).all() - - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - assert score_set.num_variants == 0 - assert len(db_variants) == 0 - assert score_set.processing_state == ProcessingState.failed - assert score_set.processing_errors == validation_error - assert not result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) -) -async def test_create_variants_for_score_set_with_caught_exception( - input_score_set, - setup_worker_db, - async_client, - standalone_worker_context, - session, - data_files, -): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - # This is somewhat dumb and wouldn't actually happen like this, but it serves as an effective way to guarantee - # some exception will be raised no matter what in the async job. - with ( - patch.object(pd.DataFrame, "isnull", side_effect=Exception) as mocked_exc, - ): - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - mocked_exc.assert_called() - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - - assert score_set.num_variants == 0 - assert len(db_variants) == 0 - assert score_set.processing_state == ProcessingState.failed - assert score_set.processing_errors == {"detail": [], "exception": ""} - assert not result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) -) -async def test_create_variants_for_score_set_with_caught_base_exception( - input_score_set, - setup_worker_db, - async_client, - standalone_worker_context, - session, - data_files, -): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - # This is somewhat (extra) dumb and wouldn't actually happen like this, but it serves as an effective way to guarantee - # some base exception will be handled no matter what in the async job. - with ( - patch.object(pd.DataFrame, "isnull", side_effect=BaseException), - ): - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - - assert score_set.num_variants == 0 - assert len(db_variants) == 0 - assert score_set.processing_state == ProcessingState.failed - assert score_set.processing_errors is None - assert not result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) -) -async def test_create_variants_for_score_set_with_existing_variants( - input_score_set, - setup_worker_db, - async_client, - standalone_worker_context, - session, - data_files, -): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - with patch.object( - cdot.hgvs.dataproviders.RESTDataProvider, - "_get_transcript", - return_value=TEST_NT_CDOT_TRANSCRIPT, - ) as hdp: - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - - # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. - if all(["targetSequence" in target for target in input_score_set["targetGenes"]]): - hdp.assert_not_called() - else: - hdp.assert_called_once() - - await sanitize_mapping_queue(standalone_worker_context, score_set) - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - - assert score_set.num_variants == 3 - assert len(db_variants) == 3 - assert score_set.processing_state == ProcessingState.success - - with patch.object( - cdot.hgvs.dataproviders.RESTDataProvider, - "_get_transcript", - return_value=TEST_NT_CDOT_TRANSCRIPT, - ) as hdp: - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - - assert score_set.num_variants == 3 - assert len(db_variants) == 3 - assert score_set.processing_state == ProcessingState.success - assert score_set.processing_errors is None - assert result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 1 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) -) -async def test_create_variants_for_score_set_with_existing_exceptions( - input_score_set, - setup_worker_db, - async_client, - standalone_worker_context, - session, - data_files, -): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - # This is somewhat dumb and wouldn't actually happen like this, but it serves as an effective way to guarantee - # some exception will be raised no matter what in the async job. - with ( - patch.object( - pd.DataFrame, - "isnull", - side_effect=ValidationError("Test Exception", triggers=["exc_1", "exc_2"]), - ) as mocked_exc, - ): - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - mocked_exc.assert_called() - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - - assert score_set.num_variants == 0 - assert len(db_variants) == 0 - assert score_set.processing_state == ProcessingState.failed - assert score_set.processing_errors == { - "exception": "Test Exception", - "detail": ["exc_1", "exc_2"], - } - - with patch.object( - cdot.hgvs.dataproviders.RESTDataProvider, - "_get_transcript", - return_value=TEST_NT_CDOT_TRANSCRIPT, - ) as hdp: - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - - # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. - if all(["targetSequence" in target for target in input_score_set["targetGenes"]]): - hdp.assert_not_called() - else: - hdp.assert_called_once() - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - - assert score_set.num_variants == 3 - assert len(db_variants) == 3 - assert score_set.processing_state == ProcessingState.success - assert score_set.processing_errors is None - assert result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 1 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) -) -async def test_create_variants_for_score_set( - input_score_set, - setup_worker_db, - async_client, - standalone_worker_context, - session, - data_files, -): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - with patch.object( - cdot.hgvs.dataproviders.RESTDataProvider, - "_get_transcript", - return_value=TEST_NT_CDOT_TRANSCRIPT, - ) as hdp: - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - - # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. - if all(["targetSequence" in target for target in input_score_set["targetGenes"]]): - hdp.assert_not_called() - else: - hdp.assert_called_once() - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - - assert score_set.num_variants == 3 - assert len(db_variants) == 3 - assert score_set.processing_state == ProcessingState.success - assert result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 1 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) -) -async def test_create_variants_for_score_set_enqueues_manager_and_successful_mapping( - input_score_set, - setup_worker_db, - session, - async_client, - data_files, - arq_worker, - arq_redis, -): - score_set_is_seq = all(["targetSequence" in target for target in input_score_set["targetGenes"]]) - score_set_is_multi_target = len(input_score_set["targetGenes"]) > 1 - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - async def dummy_mapping_job(): - return await setup_mapping_output(async_client, session, score_set, score_set_is_seq, score_set_is_multi_target) - - async def dummy_car_submission_job(): - return TEST_CLINGEN_ALLELE_OBJECT - - async def dummy_ldh_submission_job(): - return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] - - # Variants have not yet been created, so infer their URNs. - async def dummy_linking_job(): - return [(f"{score_set_urn}#{i}", TEST_CLINGEN_LDH_LINKING_RESPONSE) for i in range(1, len(scores) + 1)] - - with ( - patch.object( - cdot.hgvs.dataproviders.RESTDataProvider, - "_get_transcript", - return_value=TEST_NT_CDOT_TRANSCRIPT, - ) as hdp, - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - side_effect=[ - dummy_mapping_job(), - dummy_car_submission_job(), - dummy_ldh_submission_job(), - dummy_linking_job(), - ], - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - patch("mavedb.worker.jobs.MAPPING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.LINKING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.CLIN_GEN_SUBMISSION_ENABLED", True), - ): - await arq_redis.enqueue_job( - "create_variants_for_score_set", - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - await arq_worker.async_run() - await arq_worker.run_check() - - # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. - if score_set_is_seq: - hdp.assert_not_called() - else: - hdp.assert_called_once() - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert score_set.num_variants == 3 - assert len(db_variants) == 3 - assert score_set.processing_state == ProcessingState.success - assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 - assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert len(mapped_variants_for_score_set) == score_set.num_variants - assert score_set.mapping_state == MappingState.complete - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) -) -async def test_create_variants_for_score_set_exception_skips_mapping( - input_score_set, - setup_worker_db, - session, - async_client, - data_files, - arq_worker, - arq_redis, -): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - with patch.object(pd.DataFrame, "isnull", side_effect=Exception) as mocked_exc: - await arq_redis.enqueue_job( - "create_variants_for_score_set", - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - await arq_worker.async_run() - await arq_worker.run_check() - - mocked_exc.assert_called() - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert score_set.num_variants == 0 - assert len(db_variants) == 0 - assert score_set.processing_state == ProcessingState.failed - assert score_set.processing_errors == {"detail": [], "exception": ""} - assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 - assert len(mapped_variants_for_score_set) == 0 - assert score_set.mapping_state == MappingState.not_attempted - assert score_set.mapping_errors is None - - -# NOTE: These tests operate under the assumption that mapping output is consistent between accession based and sequence based score sets. If -# this assumption changes in the future, tests reflecting this difference in output should be added for accession based score sets. - - -@pytest.mark.asyncio -async def test_create_mapped_variants_for_scoreset( - setup_worker_db, - async_client, - standalone_worker_context, - session, - data_files, -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - # The call to `create_variants_from_score_set` within the above `setup_records_files_and_variants` will - # add a score set to the queue. Since we are executing the mapping independent of the manager job, we should - # sanitize the queue as if the mananger process had run. - await sanitize_mapping_queue(standalone_worker_context, score_set) - - async def dummy_mapping_job(): - return await setup_mapping_output(async_client, session, score_set) - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_mapping_job(), - ), - patch("mavedb.worker.jobs.CLIN_GEN_SUBMISSION_ENABLED", True), - ): - result = await map_variants_for_score_set(standalone_worker_context, uuid4().hex, score_set.id, 1) - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert result["success"] - assert not result["retried"] - assert result["enqueued_jobs"] - assert len(mapped_variants_for_score_set) == score_set.num_variants - assert score_set.mapping_state == MappingState.complete - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -async def test_create_mapped_variants_for_scoreset_with_existing_mapped_variants( - setup_worker_db, async_client, standalone_worker_context, session, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - # The call to `create_variants_from_score_set` within the above `setup_records_files_and_variants` will - # add a score set to the queue. Since we are executing the mapping independent of the manager job, we should - # sanitize the queue as if the mananger process had run. - await sanitize_mapping_queue(standalone_worker_context, score_set) - - async def dummy_mapping_job(): - return await setup_mapping_output(async_client, session, score_set) - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_mapping_job(), - ), - patch("mavedb.worker.jobs.CLIN_GEN_SUBMISSION_ENABLED", True), - ): - existing_variant = session.scalars(select(Variant)).first() - - if not existing_variant: - raise ValueError - - session.add( - MappedVariant( - pre_mapped={"preexisting": "variant"}, - post_mapped={"preexisting": "variant"}, - variant_id=existing_variant.id, - modification_date=date.today(), - mapped_date=date.today(), - vrs_version="2.0", - mapping_api_version="0.0.0", - current=True, - ) - ) - session.commit() - - result = await map_variants_for_score_set(standalone_worker_context, uuid4().hex, score_set.id, 1) - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - preexisting_variants = session.scalars( - select(MappedVariant) - .join(Variant) - .join(ScoreSetDbModel) - .filter(ScoreSetDbModel.urn == score_set.urn, not_(MappedVariant.current)) - ).all() - new_variants = session.scalars( - select(MappedVariant) - .join(Variant) - .join(ScoreSetDbModel) - .filter(ScoreSetDbModel.urn == score_set.urn, MappedVariant.current) - ).all() - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert result["success"] - assert not result["retried"] - assert result["enqueued_jobs"] - assert len(mapped_variants_for_score_set) == score_set.num_variants + 1 - assert len(preexisting_variants) == 1 - assert len(new_variants) == score_set.num_variants - assert score_set.mapping_state == MappingState.complete - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -async def test_create_mapped_variants_for_scoreset_exception_in_mapping_setup_score_set_selection( - setup_worker_db, async_client, standalone_worker_context, session, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - # The call to `create_variants_from_score_set` within the above `setup_records_files_and_variants` will - # add a score set to the queue. Since we are executing the mapping independent of the manager job, we should - # sanitize the queue as if the mananger process had run. - await sanitize_mapping_queue(standalone_worker_context, score_set) - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. - with patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=awaitable_exception(), - ): - result = await map_variants_for_score_set(standalone_worker_context, uuid4().hex, score_set.id + 5, 1) - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - assert len(mapped_variants_for_score_set) == 0 - # When we cannot fetch a score set, these fields are unable to be updated. - assert score_set.mapping_state == MappingState.queued - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -async def test_create_mapped_variants_for_scoreset_exception_in_mapping_setup_vrs_object( - setup_worker_db, async_client, standalone_worker_context, session, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - # The call to `create_variants_from_score_set` within the above `setup_records_files_and_variants` will - # add a score set to the queue. Since we are executing the mapping independent of the manager job, we should - # sanitize the queue as if the mananger process had run. - await sanitize_mapping_queue(standalone_worker_context, score_set) - - with patch.object( - VRSMap, - "__init__", - return_value=Exception(), - ): - result = await map_variants_for_score_set(standalone_worker_context, uuid4().hex, score_set.id, 1) - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - assert len(mapped_variants_for_score_set) == 0 - assert score_set.mapping_state == MappingState.failed - assert score_set.mapping_errors is not None - - -@pytest.mark.asyncio -async def test_create_mapped_variants_for_scoreset_mapping_exception( - setup_worker_db, async_client, standalone_worker_context, session, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - # The call to `create_variants_from_score_set` within the above `setup_records_files_and_variants` will - # add a score set to the queue. Since we are executing the mapping independent of the manager job, we should - # sanitize the queue as if the mananger process had run. - await sanitize_mapping_queue(standalone_worker_context, score_set) - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. - with patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=awaitable_exception(), - ): - result = await map_variants_for_score_set(standalone_worker_context, uuid4().hex, score_set.id, 1) - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 1 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert not result["success"] - assert result["retried"] - assert result["enqueued_jobs"] - assert len(mapped_variants_for_score_set) == 0 - assert score_set.mapping_state == MappingState.queued - assert score_set.mapping_errors is not None - - -@pytest.mark.asyncio -async def test_create_mapped_variants_for_scoreset_mapping_exception_retry_limit_reached( - setup_worker_db, async_client, standalone_worker_context, session, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - # The call to `create_variants_from_score_set` within the above `setup_records_files_and_variants` will - # add a score set to the queue. Since we are executing the mapping independent of the manager job, we should - # sanitize the queue as if the mananger process had run. - await sanitize_mapping_queue(standalone_worker_context, score_set) - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. - with patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=awaitable_exception(), - ): - result = await map_variants_for_score_set( - standalone_worker_context, uuid4().hex, score_set.id, 1, BACKOFF_LIMIT + 1 - ) - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - assert len(mapped_variants_for_score_set) == 0 - assert score_set.mapping_state == MappingState.failed - assert score_set.mapping_errors is not None - - -@pytest.mark.asyncio -async def test_create_mapped_variants_for_scoreset_mapping_exception_retry_failed( - setup_worker_db, async_client, standalone_worker_context, session, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - # The call to `create_variants_from_score_set` within the above `setup_records_files_and_variants` will - # add a score set to the queue. Since we are executing the mapping independent of the manager job, we should - # sanitize the queue as if the mananger process had run. - await sanitize_mapping_queue(standalone_worker_context, score_set) - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=awaitable_exception(), - ), - patch.object(arq.ArqRedis, "lpush", awaitable_exception()), - ): - result = await map_variants_for_score_set(standalone_worker_context, uuid4().hex, score_set.id, 1) - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - assert len(mapped_variants_for_score_set) == 0 - # Behavior for exception in mapping is retried job - assert score_set.mapping_state == MappingState.failed - assert score_set.mapping_errors is not None - - -@pytest.mark.asyncio -async def test_create_mapped_variants_for_scoreset_parsing_exception_with_retry( - setup_worker_db, async_client, standalone_worker_context, session, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - # The call to `create_variants_from_score_set` within the above `setup_records_files_and_variants` will - # add a score set to the queue. Since we are executing the mapping independent of the manager job, we should - # sanitize the queue as if the mananger process had run. - await sanitize_mapping_queue(standalone_worker_context, score_set) - - async def dummy_mapping_job(): - mapping_test_output_for_score_set = await setup_mapping_output(async_client, session, score_set) - mapping_test_output_for_score_set.pop("computed_genomic_reference_sequence") - return mapping_test_output_for_score_set - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. - with patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_mapping_job(), - ): - result = await map_variants_for_score_set(standalone_worker_context, uuid4().hex, score_set.id, 1) - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 1 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert not result["success"] - assert result["retried"] - assert result["enqueued_jobs"] - assert len(mapped_variants_for_score_set) == 0 - assert score_set.mapping_state == MappingState.queued - assert score_set.mapping_errors is not None - - -@pytest.mark.asyncio -async def test_create_mapped_variants_for_scoreset_parsing_exception_retry_failed( - setup_worker_db, async_client, standalone_worker_context, session, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - # The call to `create_variants_from_score_set` within the above `setup_records_files_and_variants` will - # add a score set to the queue. Since we are executing the mapping independent of the manager job, we should - # sanitize the queue as if the mananger process had run. - await sanitize_mapping_queue(standalone_worker_context, score_set) - - async def dummy_mapping_job(): - mapping_test_output_for_score_set = await setup_mapping_output(async_client, session, score_set) - mapping_test_output_for_score_set.pop("computed_genomic_reference_sequence") - return mapping_test_output_for_score_set - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_mapping_job(), - ), - patch.object(arq.ArqRedis, "lpush", awaitable_exception()), - ): - result = await map_variants_for_score_set(standalone_worker_context, uuid4().hex, score_set.id, 1) - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - assert len(mapped_variants_for_score_set) == 0 - # Behavior for exception outside mapping is failed job - assert score_set.mapping_state == MappingState.failed - assert score_set.mapping_errors is not None - - -@pytest.mark.asyncio -async def test_create_mapped_variants_for_scoreset_parsing_exception_retry_limit_reached( - setup_worker_db, async_client, standalone_worker_context, session, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - # The call to `create_variants_from_score_set` within the above `setup_records_files_and_variants` will - # add a score set to the queue. Since we are executing the mapping independent of the manager job, we should - # sanitize the queue as if the mananger process had run. - await sanitize_mapping_queue(standalone_worker_context, score_set) - - async def dummy_mapping_job(): - mapping_test_output_for_score_set = await setup_mapping_output(async_client, session, score_set) - mapping_test_output_for_score_set.pop("computed_genomic_reference_sequence") - return mapping_test_output_for_score_set - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. - with patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_mapping_job(), - ): - result = await map_variants_for_score_set( - standalone_worker_context, uuid4().hex, score_set.id, 1, BACKOFF_LIMIT + 1 - ) - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - assert len(mapped_variants_for_score_set) == 0 - # Behavior for exception outside mapping is failed job - assert score_set.mapping_state == MappingState.failed - assert score_set.mapping_errors is not None - - -@pytest.mark.asyncio -async def test_create_mapped_variants_for_scoreset_no_mapping_output( - setup_worker_db, async_client, standalone_worker_context, session, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - # The call to `create_variants_from_score_set` within the above `setup_records_files_and_variants` will - # add a score set to the queue. Since we are executing the mapping independent of the manager job, we should - # sanitize the queue as if the mananger process had run. - await sanitize_mapping_queue(standalone_worker_context, score_set) - - # Do not await, we need a co-routine object to be the return value of our `run_in_executor` mock. - async def dummy_mapping_job(): - return await setup_mapping_output(async_client, session, score_set, empty=True) - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_mapping_job(), - ), - patch("mavedb.worker.jobs.CLIN_GEN_SUBMISSION_ENABLED", True), - ): - result = await map_variants_for_score_set(standalone_worker_context, uuid4().hex, score_set.id, 1) - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert result["success"] - assert not result["retried"] - assert result["enqueued_jobs"] - assert len(mapped_variants_for_score_set) == 0 - assert score_set.mapping_state == MappingState.failed - - -@pytest.mark.asyncio -async def test_mapping_manager_empty_queue(setup_worker_db, standalone_worker_context): - result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # No new jobs should have been created if nothing is in the queue, and the queue should remain empty. - assert result["enqueued_job"] is None - assert result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - - -@pytest.mark.asyncio -async def test_mapping_manager_empty_queue_error_during_setup(setup_worker_db, standalone_worker_context): - await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "") - with patch.object(arq.ArqRedis, "rpop", Exception()): - result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # No new jobs should have been created if nothing is in the queue, and the queue should remain empty. - assert result["enqueued_job"] is None - assert not result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - - -@pytest.mark.asyncio -async def test_mapping_manager_occupied_queue_mapping_in_progress( - setup_worker_db, standalone_worker_context, session, async_client, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "5") - with patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.in_progress): - result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # Execution should be deferred if a job is in progress, and the queue should contain one entry which is the deferred ID. - assert result["enqueued_job"] is not None - assert ( - await arq.jobs.Job(result["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.deferred - assert result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 1 - assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set.id) - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "5" - assert score_set.mapping_state == MappingState.queued - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -async def test_mapping_manager_occupied_queue_mapping_not_in_progress( - setup_worker_db, standalone_worker_context, session, async_client, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "") - with patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.not_found): - result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # Mapping job should be queued if none is currently running, and the queue should now be empty. - assert result["enqueued_job"] is not None - assert ( - await arq.jobs.Job(result["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.queued - assert result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - # We don't actually start processing these score sets. - assert score_set.mapping_state == MappingState.queued - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -async def test_mapping_manager_occupied_queue_mapping_in_progress_error_during_enqueue( - setup_worker_db, standalone_worker_context, session, async_client, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "5") - with ( - patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.in_progress), - patch.object(arq.ArqRedis, "enqueue_job", return_value=awaitable_exception()), - ): - result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # Execution should be deferred if a job is in progress, and the queue should contain one entry which is the deferred ID. - assert result["enqueued_job"] is None - assert not result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "5" - assert score_set.mapping_state == MappingState.failed - assert score_set.mapping_errors is not None - - -@pytest.mark.asyncio -async def test_mapping_manager_occupied_queue_mapping_not_in_progress_error_during_enqueue( - setup_worker_db, standalone_worker_context, session, async_client, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "") - with ( - patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.not_found), - patch.object(arq.ArqRedis, "enqueue_job", return_value=awaitable_exception()), - ): - result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # Enqueue would have failed, the job is unsuccessful, and we remove the queued item. - assert result["enqueued_job"] is None - assert not result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert score_set.mapping_state == MappingState.failed - assert score_set.mapping_errors is not None - - -@pytest.mark.asyncio -async def test_mapping_manager_multiple_score_sets_occupy_queue_mapping_in_progress( - setup_worker_db, standalone_worker_context, session, async_client, data_files -): - score_set_id_1 = ( - await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - ).id - score_set_id_2 = ( - await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - ).id - score_set_id_3 = ( - await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - ).id - - await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "5") - with patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.in_progress): - result1 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - result2 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - result3 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # All three jobs should complete successfully... - assert result1["success"] - assert result2["success"] - assert result3["success"] - - # ...with a new job enqueued... - assert result1["enqueued_job"] is not None - assert result2["enqueued_job"] is not None - assert result3["enqueued_job"] is not None - - # ...of which all should be deferred jobs of the "variant_mapper_manager" variety... - assert ( - await arq.jobs.Job(result1["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.deferred - assert ( - await arq.jobs.Job(result2["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.deferred - assert ( - await arq.jobs.Job(result3["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.deferred - - assert ( - await arq.jobs.Job(result1["enqueued_job"], standalone_worker_context["redis"]).info() - ).function == "variant_mapper_manager" - assert ( - await arq.jobs.Job(result2["enqueued_job"], standalone_worker_context["redis"]).info() - ).function == "variant_mapper_manager" - assert ( - await arq.jobs.Job(result3["enqueued_job"], standalone_worker_context["redis"]).info() - ).function == "variant_mapper_manager" - - # ...and the queue state should have three jobs, each of our three created score sets. - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 3 - assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set_id_1) - assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set_id_2) - assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set_id_3) - - score_set1 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_1)).one() - score_set2 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_2)).one() - score_set3 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_3)).one() - # Each score set should remain queued with no mapping errors. - assert score_set1.mapping_state == MappingState.queued - assert score_set2.mapping_state == MappingState.queued - assert score_set3.mapping_state == MappingState.queued - assert score_set1.mapping_errors is None - assert score_set2.mapping_errors is None - assert score_set3.mapping_errors is None - - -@pytest.mark.asyncio -async def test_mapping_manager_multiple_score_sets_occupy_queue_mapping_not_in_progress( - setup_worker_db, standalone_worker_context, session, async_client, data_files -): - score_set_id_1 = ( - await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - ).id - score_set_id_2 = ( - await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - ).id - score_set_id_3 = ( - await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - ).id - - await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "") - with patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.not_found): - result1 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # Mock the first job being in-progress - await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, str(score_set_id_1)) - with patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.in_progress): - result2 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - result3 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # All three jobs should complete successfully... - assert result1["success"] - assert result2["success"] - assert result3["success"] - - # ...with a new job enqueued... - assert result1["enqueued_job"] is not None - assert result2["enqueued_job"] is not None - assert result3["enqueued_job"] is not None - - # ...of which the first should be a queued job of the "map_variants_for_score_set" variety and the other two should be - # deferred jobs of the "variant_mapper_manager" variety... - assert ( - await arq.jobs.Job(result1["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.queued - assert ( - await arq.jobs.Job(result2["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.deferred - assert ( - await arq.jobs.Job(result3["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.deferred - - assert ( - await arq.jobs.Job(result1["enqueued_job"], standalone_worker_context["redis"]).info() - ).function == "map_variants_for_score_set" - assert ( - await arq.jobs.Job(result2["enqueued_job"], standalone_worker_context["redis"]).info() - ).function == "variant_mapper_manager" - assert ( - await arq.jobs.Job(result3["enqueued_job"], standalone_worker_context["redis"]).info() - ).function == "variant_mapper_manager" - - # ...and the queue state should have two jobs, neither of which should be the first score set. - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 2 - assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set_id_2) - assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set_id_3) - - score_set1 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_1)).one() - score_set2 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_2)).one() - score_set3 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_3)).one() - # We don't actually process any score sets in the manager job, and each should have no mapping errors. - assert score_set1.mapping_state == MappingState.queued - assert score_set2.mapping_state == MappingState.queued - assert score_set3.mapping_state == MappingState.queued - assert score_set1.mapping_errors is None - assert score_set2.mapping_errors is None - assert score_set3.mapping_errors is None - - -@pytest.mark.asyncio -async def test_mapping_manager_enqueues_mapping_process_with_successful_mapping( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_mapping_job(): - return await setup_mapping_output(async_client, session, score_set) - - async def dummy_ldh_submission_job(): - return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] - - async def dummy_linking_job(): - return [ - (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - side_effect=[dummy_mapping_job(), dummy_ldh_submission_job(), dummy_linking_job()], - ), - patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - patch.object(UniProtIDMappingAPI, "submit_id_mapping", return_value=TEST_UNIPROT_JOB_SUBMISSION_RESPONSE), - patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=True), - patch.object( - UniProtIDMappingAPI, "get_id_mapping_results", return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE - ), - patch("mavedb.worker.jobs.MAPPING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.LINKING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.UNIPROT_ID_MAPPING_ENABLED", True), - patch("mavedb.worker.jobs.CLIN_GEN_SUBMISSION_ENABLED", True), - patch("mavedb.worker.jobs.CAR_SUBMISSION_ENDPOINT", "https://reg.test.genome.network/pytest"), - patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_NAME", "testuser"), - patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_PASSWORD", "testpassword"), - patch("mavedb.lib.gnomad.GNOMAD_DATA_VERSION", TEST_GNOMAD_DATA_VERSION), - patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]), - ): - await arq_worker.async_run() - num_completed_jobs = await arq_worker.run_check() - - # We should have completed all jobs exactly once. - assert num_completed_jobs == 8 - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 - assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert len(mapped_variants_for_score_set) == score_set.num_variants - assert score_set.mapping_state == MappingState.complete - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -async def test_mapping_manager_enqueues_mapping_process_with_successful_mapping_linking_disabled_uniprot_disabled( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_mapping_job(): - return await setup_mapping_output(async_client, session, score_set) - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - side_effect=[dummy_mapping_job()], - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - patch("mavedb.worker.jobs.MAPPING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.LINKING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.UNIPROT_ID_MAPPING_ENABLED", False), - patch("mavedb.worker.jobs.CLIN_GEN_SUBMISSION_ENABLED", False), - ): - await arq_worker.async_run() - num_completed_jobs = await arq_worker.run_check() - - # We should have completed the manager and mapping jobs, but not the submission, linking, or uniprot mapping jobs. - assert num_completed_jobs == 2 - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 - assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert len(mapped_variants_for_score_set) == score_set.num_variants - assert score_set.mapping_state == MappingState.complete - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -async def test_mapping_manager_enqueues_mapping_process_with_successful_mapping_linking_disabled_uniprot_enabled( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_mapping_job(): - return await setup_mapping_output(async_client, session, score_set) - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - side_effect=[dummy_mapping_job()], - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - patch.object(UniProtIDMappingAPI, "submit_id_mapping", return_value=TEST_UNIPROT_JOB_SUBMISSION_RESPONSE), - patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=True), - patch.object( - UniProtIDMappingAPI, "get_id_mapping_results", return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE - ), - patch("mavedb.worker.jobs.MAPPING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.LINKING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.UNIPROT_ID_MAPPING_ENABLED", True), - patch("mavedb.worker.jobs.CLIN_GEN_SUBMISSION_ENABLED", False), - ): - await arq_worker.async_run() - num_completed_jobs = await arq_worker.run_check() - - # We should have completed the manager, mapping, and uniprot jobs, but not the submission or linking jobs. - assert num_completed_jobs == 4 - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 - assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert len(mapped_variants_for_score_set) == score_set.num_variants - assert score_set.mapping_state == MappingState.complete - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -async def test_mapping_manager_enqueues_mapping_process_with_successful_mapping_linking_enabled_uniprot_disabled( - setup_worker_db, - standalone_worker_context, - session, - async_client, - data_files, - arq_worker, - arq_redis, - mocked_gnomad_variant_row, -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_mapping_job(): - return await setup_mapping_output(async_client, session, score_set) - - async def dummy_submission_job(): - return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] - - async def dummy_linking_job(): - return [ - (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - side_effect=[dummy_mapping_job(), dummy_submission_job(), dummy_linking_job()], - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - patch("mavedb.worker.jobs.MAPPING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.LINKING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.UNIPROT_ID_MAPPING_ENABLED", False), - patch("mavedb.worker.jobs.CLIN_GEN_SUBMISSION_ENABLED", True), - patch("mavedb.worker.jobs.CAR_SUBMISSION_ENDPOINT", "https://reg.test.genome.network/pytest"), - patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_NAME", "testuser"), - patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_PASSWORD", "testpassword"), - patch("mavedb.lib.gnomad.GNOMAD_DATA_VERSION", TEST_GNOMAD_DATA_VERSION), - patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]), - patch("mavedb.worker.jobs.gnomad_variant_data_for_caids", return_value=[mocked_gnomad_variant_row]), - ): - await arq_worker.async_run() - num_completed_jobs = await arq_worker.run_check() - - # We should have completed the manager, mapping, submission, and linking jobs, but not the uniprot jobs. - assert num_completed_jobs == 6 - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 - assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert len(mapped_variants_for_score_set) == score_set.num_variants - assert score_set.mapping_state == MappingState.complete - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -async def test_mapping_manager_enqueues_mapping_process_with_retried_mapping_successful_mapping_on_retry( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def failed_mapping_job(): - return Exception() - - async def dummy_mapping_job(): - return await setup_mapping_output(async_client, session, score_set) - - async def dummy_ldh_submission_job(): - return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] - - async def dummy_linking_job(): - return [ - (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - side_effect=[failed_mapping_job(), dummy_mapping_job(), dummy_ldh_submission_job(), dummy_linking_job()], - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]), - patch("mavedb.worker.jobs.MAPPING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.LINKING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.UNIPROT_ID_MAPPING_ENABLED", False), - patch("mavedb.worker.jobs.CLIN_GEN_SUBMISSION_ENABLED", True), - patch("mavedb.worker.jobs.CAR_SUBMISSION_ENDPOINT", "https://reg.test.genome.network/pytest"), - patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_NAME", "testuser"), - patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_PASSWORD", "testpassword"), - patch("mavedb.lib.gnomad.GNOMAD_DATA_VERSION", TEST_GNOMAD_DATA_VERSION), - patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]), - ): - await arq_worker.async_run() - num_completed_jobs = await arq_worker.run_check() - - # We should have completed the mapping manager job twice, the mapping job twice, the two submission jobs, and both linking jobs. - assert num_completed_jobs == 8 - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 - assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert len(mapped_variants_for_score_set) == score_set.num_variants - assert score_set.mapping_state == MappingState.complete - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -async def test_mapping_manager_enqueues_mapping_process_with_unsuccessful_mapping( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def failed_mapping_job(): - return Exception() - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mappingn output. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - side_effect=[failed_mapping_job()] * 5, - ), - patch("mavedb.worker.jobs.MAPPING_BACKOFF_IN_SECONDS", 0), - ): - await arq_worker.async_run() - num_completed_jobs = await arq_worker.run_check() - - # We should have completed 6 mapping jobs and 6 management jobs. - assert num_completed_jobs == 12 - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 - assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert len(mapped_variants_for_score_set) == 0 - assert score_set.mapping_state == MappingState.failed - assert score_set.mapping_errors is not None - - -############################################################################################################################################ -# ClinGen CAR Submission -############################################################################################################################################ - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_car_success( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]), - patch("mavedb.worker.jobs.CAR_SUBMISSION_ENDPOINT", "https://reg.test.genome.network/pytest"), - ): - result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) - - mapped_variants_with_caid_for_score_set = session.scalars( - select(MappedVariant) - .join(Variant) - .join(ScoreSetDbModel) - .filter(ScoreSetDbModel.urn == score_set.urn, MappedVariant.clingen_allele_id.is_not(None)) - ).all() - - assert len(mapped_variants_with_caid_for_score_set) == score_set.num_variants - - assert result["success"] - assert not result["retried"] - assert result["enqueued_job"] is not None - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_car_exception_in_setup( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch( - "mavedb.worker.jobs.setup_job_state", - side_effect=Exception(), - ): - result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_car_no_variants_exist( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_car_exception_in_hgvs_dict_creation( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch( - "mavedb.worker.jobs.get_hgvs_from_post_mapped", - side_effect=Exception(), - ): - result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_car_exception_during_submission( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", side_effect=Exception()), - patch("mavedb.worker.jobs.CAR_SUBMISSION_ENDPOINT", "https://reg.test.genome.network/pytest"), - ): - result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_car_exception_in_allele_association( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch("mavedb.worker.jobs.get_allele_registry_associations", side_effect=Exception()), - patch("mavedb.worker.jobs.CAR_SUBMISSION_ENDPOINT", "https://reg.test.genome.network/pytest"), - ): - result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_car_exception_during_ldh_enqueue( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch("mavedb.worker.jobs.CAR_SUBMISSION_ENDPOINT", "https://reg.test.genome.network/pytest"), - patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]), - patch.object(arq.ArqRedis, "enqueue_job", side_effect=Exception()), - ): - result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) - - mapped_variants_with_caid_for_score_set = session.scalars( - select(MappedVariant) - .join(Variant) - .join(ScoreSetDbModel) - .filter(ScoreSetDbModel.urn == score_set.urn, MappedVariant.clingen_allele_id.is_not(None)) - ).all() - - assert len(mapped_variants_with_caid_for_score_set) == score_set.num_variants - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -############################################################################################################################################ -# ClinGen LDH Submission -############################################################################################################################################ - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_success( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_submission_job(): - return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_submission_job(), - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert result["success"] - assert not result["retried"] - assert result["enqueued_job"] is not None - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_exception_in_setup( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch( - "mavedb.worker.jobs.setup_job_state", - side_effect=Exception(), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_exception_in_auth( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch.object( - ClinGenLdhService, - "_existing_jwt", - side_effect=Exception(), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_no_variants_exist( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_exception_in_hgvs_generation( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch( - "mavedb.lib.variants.get_hgvs_from_post_mapped", - side_effect=Exception(), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_exception_in_ldh_submission_construction( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch( - "mavedb.lib.clingen.content_constructors.construct_ldh_submission", - side_effect=Exception(), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_exception_during_submission( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def failed_submission_job(): - return Exception() - - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - side_effect=failed_submission_job(), - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "error_response", [TEST_CLINGEN_SUBMISSION_BAD_RESQUEST_RESPONSE, TEST_CLINGEN_SUBMISSION_UNAUTHORIZED_RESPONSE] -) -async def test_submit_score_set_mappings_to_ldh_submission_failures_exist( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis, error_response -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_submission_job(): - return [None, error_response] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_submission_job(), - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_exception_during_linking_enqueue( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_submission_job(): - return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_submission_job(), - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - patch.object(arq.ArqRedis, "enqueue_job", side_effect=Exception()), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_linking_not_queued_when_expected( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_submission_job(): - return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_submission_job(), - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - patch.object(arq.ArqRedis, "enqueue_job", return_value=None), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -############################################################################################################################################## -## ClinGen Linkage -############################################################################################################################################## - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_success( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_linking_job(): - return [ - (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_linking_job(), - ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert result["success"] - assert not result["retried"] - assert result["enqueued_job"] - - for variant in session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ): - assert variant.clingen_allele_id == clingen_allele_id_from_ldh_variation(TEST_CLINGEN_LDH_LINKING_RESPONSE) - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_exception_in_setup( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch( - "mavedb.worker.jobs.setup_job_state", - side_effect=Exception(), - ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - for variant in session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ): - assert variant.clingen_allele_id is None - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_no_variants_to_link( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_exception_during_linkage( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - side_effect=Exception(), - ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_exception_while_parsing_linkages( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_linking_job(): - return [ - (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_linking_job(), - ), - patch( - "mavedb.worker.jobs.clingen_allele_id_from_ldh_variation", - side_effect=Exception(), - ), - ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_failures_exist_but_do_not_eclipse_retry_threshold( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_linking_job(): - return [ - (variant_urn, None) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_linking_job(), - ), - patch( - "mavedb.worker.jobs.LINKED_DATA_RETRY_THRESHOLD", - 2, - ), - ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert result["success"] - assert not result["retried"] - assert result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_failures_exist_and_eclipse_retry_threshold( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_linking_job(): - return [ - (variant_urn, None) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_linking_job(), - ), - patch( - "mavedb.worker.jobs.LINKED_DATA_RETRY_THRESHOLD", - 1, - ), - patch( - "mavedb.worker.jobs.LINKING_BACKOFF_IN_SECONDS", - 0, - ), - ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert not result["success"] - assert result["retried"] - assert result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_failures_exist_and_eclipse_retry_threshold_cant_enqueue( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_linking_job(): - return [ - (variant_urn, None) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_linking_job(), - ), - patch( - "mavedb.worker.jobs.LINKED_DATA_RETRY_THRESHOLD", - 1, - ), - patch.object(arq.ArqRedis, "enqueue_job", return_value=awaitable_exception()), - ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_failures_exist_and_eclipse_retry_threshold_retries_exceeded( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_linking_job(): - return [ - (variant_urn, None) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_linking_job(), - ), - patch( - "mavedb.worker.jobs.LINKED_DATA_RETRY_THRESHOLD", - 1, - ), - patch( - "mavedb.worker.jobs.LINKING_BACKOFF_IN_SECONDS", - 0, - ), - patch( - "mavedb.worker.jobs.BACKOFF_LIMIT", - 1, - ), - ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 2) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_error_in_gnomad_job_enqueue( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_linking_job(): - return [ - (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_linking_job(), - ), - patch.object(arq.ArqRedis, "enqueue_job", return_value=awaitable_exception()), - ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -################################################################################################################################################## -# UniProt ID mapping -################################################################################################################################################## - -### Test Submission - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_success( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch.object(UniProtIDMappingAPI, "submit_id_mapping", return_value=TEST_UNIPROT_JOB_SUBMISSION_RESPONSE): - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - - assert result["success"] - assert not result["retried"] - assert result["enqueued_jobs"] is not None - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_no_targets( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - score_set.target_genes = [] - session.add(score_set) - session.commit() - - with patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message: - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called_once() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_exception_while_spawning_jobs( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(UniProtIDMappingAPI, "submit_id_mapping", side_effect=HTTPError()), - patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message, - ): - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_too_many_accessions( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch("mavedb.worker.jobs.extract_ids_from_post_mapped_metadata", return_value=["AC1", "AC2"]), - patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message, - ): - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_no_accessions( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message: - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_error_in_setup( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch("mavedb.worker.jobs.setup_job_state", side_effect=Exception()), - patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message, - ): - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called() - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_exception_during_submission_generation( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch("mavedb.worker.jobs.extract_ids_from_post_mapped_metadata", side_effect=Exception()), - patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message, - ): - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called() - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_no_spawned_jobs( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(UniProtIDMappingAPI, "submit_id_mapping", return_value=None), - patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message, - ): - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_exception_during_enqueue( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(UniProtIDMappingAPI, "submit_id_mapping", return_value=TEST_UNIPROT_JOB_SUBMISSION_RESPONSE), - patch.object(arq.ArqRedis, "enqueue_job", side_effect=Exception()), - patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message, - ): - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called() - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -### Test Polling - - -@pytest.mark.asyncio -async def test_poll_uniprot_id_mapping_success( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=True), - patch.object( - UniProtIDMappingAPI, "get_id_mapping_results", return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE - ), - ): - result = await poll_uniprot_mapping_jobs_for_score_set( - standalone_worker_context, - {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, - score_set.id, - uuid4().hex, - ) - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - for target_gene in score_set.target_genes: - assert target_gene.uniprot_id_from_mapped_metadata == VALID_UNIPROT_ACCESSION - - -@pytest.mark.asyncio -async def test_poll_uniprot_id_mapping_no_targets( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - score_set.target_genes = [] - session.add(score_set) - session.commit() - - with patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message: - result = await poll_uniprot_mapping_jobs_for_score_set( - standalone_worker_context, - {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, - score_set.id, - uuid4().hex, - ) - mock_slack_message.assert_called_once() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - for target_gene in score_set.target_genes: - assert target_gene.uniprot_id_from_mapped_metadata is None - - -@pytest.mark.asyncio -async def test_poll_uniprot_id_mapping_too_many_accessions( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch("mavedb.worker.jobs.extract_ids_from_post_mapped_metadata", return_value=["AC1", "AC2"]), - patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message, - ): - result = await poll_uniprot_mapping_jobs_for_score_set( - standalone_worker_context, - {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, - score_set.id, - uuid4().hex, - ) - mock_slack_message.assert_called() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_poll_uniprot_id_mapping_no_accessions( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch("mavedb.worker.jobs.extract_ids_from_post_mapped_metadata", return_value=[]), - patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message, - ): - result = await poll_uniprot_mapping_jobs_for_score_set( - standalone_worker_context, - {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, - score_set.id, - uuid4().hex, - ) - mock_slack_message.assert_called() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_poll_uniprot_id_mapping_jobs_not_ready( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=False), - patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message, - ): - result = await poll_uniprot_mapping_jobs_for_score_set( - standalone_worker_context, - {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, - score_set.id, - uuid4().hex, - ) - mock_slack_message.assert_called() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - for target_gene in score_set.target_genes: - assert target_gene.uniprot_id_from_mapped_metadata is None - - -@pytest.mark.asyncio -async def test_poll_uniprot_id_mapping_no_jobs( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - # This case does not get sent to slack - result = await poll_uniprot_mapping_jobs_for_score_set( - standalone_worker_context, - {}, - score_set.id, - uuid4().hex, - ) - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - for target_gene in score_set.target_genes: - assert target_gene.uniprot_id_from_mapped_metadata is None - - -@pytest.mark.asyncio -async def test_poll_uniprot_id_mapping_no_ids_mapped( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=True), - patch.object(UniProtIDMappingAPI, "get_id_mapping_results", return_value={"failedIDs": [VALID_CHR_ACCESSION]}), - patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message, - ): - result = await poll_uniprot_mapping_jobs_for_score_set( - standalone_worker_context, - {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, - score_set.id, - uuid4().hex, - ) - mock_slack_message.assert_called() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - for target_gene in score_set.target_genes: - assert target_gene.uniprot_id_from_mapped_metadata is None - - -@pytest.mark.asyncio -async def test_poll_uniprot_id_mapping_too_many_mapped_accessions( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - # Simulate a response with too many mapped IDs - too_many_mapped_ids_response = TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE.copy() - too_many_mapped_ids_response["results"].append( - {"from": "AC3", "to": {"primaryAccession": "AC3", "entryType": TEST_UNIPROT_SWISS_PROT_TYPE}} - ) - - with ( - patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=True), - patch.object(UniProtIDMappingAPI, "get_id_mapping_results", return_value=too_many_mapped_ids_response), - patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message, - ): - result = await poll_uniprot_mapping_jobs_for_score_set( - standalone_worker_context, - {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, - score_set.id, - uuid4().hex, - ) - mock_slack_message.assert_called() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_poll_uniprot_id_mapping_error_in_setup( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch("mavedb.worker.jobs.setup_job_state", side_effect=Exception()), - patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message, - ): - result = await poll_uniprot_mapping_jobs_for_score_set( - standalone_worker_context, - {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, - score_set.id, - uuid4().hex, - ) - mock_slack_message.assert_called_once() - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_poll_uniprot_id_mapping_exception_during_polling( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", side_effect=Exception()), - patch("mavedb.worker.jobs.log_and_send_slack_message", return_value=None) as mock_slack_message, - ): - result = await poll_uniprot_mapping_jobs_for_score_set( - standalone_worker_context, - {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, - score_set.id, - uuid4().hex, - ) - mock_slack_message.assert_called_once() - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -################################################################################################################################################## -# gnomAD Linking -################################################################################################################################################## - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_gnomad_variants_success( - setup_worker_db, - standalone_worker_context, - session, - async_client, - data_files, - arq_worker, - arq_redis, - mocked_gnomad_variant_row, -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - # We need to set the ClinGen Allele ID for the Mapped Variants, so that the gnomAD job can link them. - mapped_variants = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - - for mapped_variant in mapped_variants: - mapped_variant.clingen_allele_id = VALID_CLINGEN_CA_ID - session.commit() - - # Patch Athena connection with mock object which returns a mocked gnomAD variant row w/ CAID=VALID_CLINGEN_CA_ID. - with ( - patch("mavedb.worker.jobs.gnomad_variant_data_for_caids", return_value=[mocked_gnomad_variant_row]), - patch("mavedb.lib.gnomad.GNOMAD_DATA_VERSION", TEST_GNOMAD_DATA_VERSION), - ): - result = await link_gnomad_variants(standalone_worker_context, uuid4().hex, score_set.id) - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - for variant in session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ): - assert variant.gnomad_variants - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_gnomad_variants_exception_in_setup( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch( - "mavedb.worker.jobs.setup_job_state", - side_effect=Exception(), - ): - result = await link_gnomad_variants(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - for variant in session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ): - assert not variant.gnomad_variants - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_gnomad_variants_no_variants_to_link( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - result = await link_gnomad_variants(standalone_worker_context, uuid4().hex, score_set.id) - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - for variant in session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ): - assert not variant.gnomad_variants - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_gnomad_variants_exception_while_fetching_variant_data( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch( - "mavedb.worker.jobs.setup_job_state", - side_effect=Exception(), - ), - patch("mavedb.worker.jobs.gnomad_variant_data_for_caids", side_effect=Exception()), - ): - result = await link_gnomad_variants(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - for variant in session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ): - assert not variant.gnomad_variants - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_gnomad_variants_exception_while_linking_variants( - setup_worker_db, - standalone_worker_context, - session, - async_client, - data_files, - arq_worker, - arq_redis, - mocked_gnomad_variant_row, -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - # We need to set the ClinGen Allele ID for the Mapped Variants, so that the gnomAD job can link them. - mapped_variants = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - - for mapped_variant in mapped_variants: - mapped_variant.clingen_allele_id = VALID_CLINGEN_CA_ID - session.commit() - - with ( - patch("mavedb.worker.jobs.gnomad_variant_data_for_caids", return_value=[mocked_gnomad_variant_row]), - patch("mavedb.worker.jobs.link_gnomad_variants_to_mapped_variants", side_effect=Exception()), - ): - result = await link_gnomad_variants(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - for variant in session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ): - assert not variant.gnomad_variants From 1db6b687619d63c89cc2a7fc0a13887d7a453e1f Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 7 Jan 2026 11:20:43 -0800 Subject: [PATCH 02/70] feat: Add comprehensive job traceability system database schema MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement complete database foundation for pipeline-based job tracking and monitoring: Database Tables: • pipelines - High-level workflow grouping with correlation IDs for end-to-end tracing • job_runs - Individual job execution tracking with full lifecycle management • job_dependencies - Workflow orchestration with success/completion dependency types • job_metrics - Detailed performance metrics (CPU, memory, execution time, business metrics) • variant_annotation_status - Granular variant-level annotation tracking with success data Key Features: • Pipeline workflow management with dependency resolution • Comprehensive job lifecycle tracking (pending → running → completed/failed) • Retry logic with configurable limits and backoff strategies • Resource usage and performance metrics collection • Variant-level annotation status for debugging failures • Correlation ID support for request tracing across system • JSONB metadata fields for flexible job-specific data • Optimized indexes for common query patterns Schema Design: • Foreign key relationships maintain data integrity • Check constraints ensure valid enum values and positive numbers • Strategic indexes optimize dependency resolution and metrics queries • Cascade deletes prevent orphaned records • Version tracking for audit and debugging Models & Enums: • SQLAlchemy models with proper relationships and hybrid properties • Comprehensive enum definitions for job/pipeline status and failure categories --- ...d7_add_pipeline_and_job_tracking_tables.py | 222 ++++++++++++++++++ src/mavedb/models/__init__.py | 4 + src/mavedb/models/enums/__init__.py | 25 ++ src/mavedb/models/enums/annotation_type.py | 12 + src/mavedb/models/enums/job_pipeline.py | 75 ++++++ src/mavedb/models/job_dependency.py | 72 ++++++ src/mavedb/models/job_run.py | 113 +++++++++ src/mavedb/models/pipeline.py | 88 +++++++ .../models/variant_annotation_status.py | 107 +++++++++ tests/worker/conftest.py | 86 ++++++- 10 files changed, 801 insertions(+), 3 deletions(-) create mode 100644 alembic/versions/8de33cc35cd7_add_pipeline_and_job_tracking_tables.py create mode 100644 src/mavedb/models/enums/annotation_type.py create mode 100644 src/mavedb/models/enums/job_pipeline.py create mode 100644 src/mavedb/models/job_dependency.py create mode 100644 src/mavedb/models/job_run.py create mode 100644 src/mavedb/models/pipeline.py create mode 100644 src/mavedb/models/variant_annotation_status.py diff --git a/alembic/versions/8de33cc35cd7_add_pipeline_and_job_tracking_tables.py b/alembic/versions/8de33cc35cd7_add_pipeline_and_job_tracking_tables.py new file mode 100644 index 00000000..af7eb945 --- /dev/null +++ b/alembic/versions/8de33cc35cd7_add_pipeline_and_job_tracking_tables.py @@ -0,0 +1,222 @@ +"""add pipeline and job tracking tables + +Revision ID: 8de33cc35cd7 +Revises: dcf8572d3a17 +Create Date: 2026-01-28 10:08:36.906494 + +""" + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "8de33cc35cd7" +down_revision = "dcf8572d3a17" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "pipelines", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("urn", sa.String(length=255), nullable=True), + sa.Column("name", sa.String(length=500), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("status", sa.String(length=50), nullable=False), + sa.Column("correlation_id", sa.String(length=255), nullable=True), + sa.Column( + "metadata", + postgresql.JSONB(astext_type=sa.Text()), + server_default="{}", + nullable=False, + comment="Flexible metadata storage for pipeline-specific data", + ), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column("started_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("finished_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_by_user_id", sa.Integer(), nullable=True), + sa.Column("mavedb_version", sa.String(length=50), nullable=True), + sa.CheckConstraint( + "status IN ('created', 'running', 'succeeded', 'failed', 'cancelled', 'paused', 'partial')", + name="ck_pipelines_status_valid", + ), + sa.ForeignKeyConstraint(["created_by_user_id"], ["users.id"], ondelete="SET NULL"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("urn"), + ) + op.create_index("ix_pipelines_correlation_id", "pipelines", ["correlation_id"], unique=False) + op.create_index("ix_pipelines_created_at", "pipelines", ["created_at"], unique=False) + op.create_index("ix_pipelines_created_by_user_id", "pipelines", ["created_by_user_id"], unique=False) + op.create_index("ix_pipelines_status", "pipelines", ["status"], unique=False) + op.create_table( + "job_runs", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("urn", sa.String(length=255), nullable=True), + sa.Column("job_type", sa.String(length=100), nullable=False), + sa.Column("job_function", sa.String(length=255), nullable=False), + sa.Column("job_params", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("status", sa.String(length=50), nullable=False), + sa.Column("pipeline_id", sa.Integer(), nullable=True), + sa.Column("priority", sa.Integer(), nullable=False), + sa.Column("max_retries", sa.Integer(), nullable=False), + sa.Column("retry_count", sa.Integer(), nullable=False), + sa.Column("retry_delay_seconds", sa.Integer(), nullable=True), + sa.Column("scheduled_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column("started_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("finished_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column("error_message", sa.Text(), nullable=True), + sa.Column("error_traceback", sa.Text(), nullable=True), + sa.Column("failure_category", sa.String(length=100), nullable=True), + sa.Column("progress_current", sa.Integer(), nullable=True), + sa.Column("progress_total", sa.Integer(), nullable=True), + sa.Column("progress_message", sa.String(length=500), nullable=True), + sa.Column("correlation_id", sa.String(length=255), nullable=True), + sa.Column("metadata", postgresql.JSONB(astext_type=sa.Text()), server_default="{}", nullable=False), + sa.Column("mavedb_version", sa.String(length=50), nullable=True), + sa.CheckConstraint( + "status IN ('pending', 'queued', 'running', 'succeeded', 'failed', 'cancelled', 'skipped')", + name="ck_job_runs_status_valid", + ), + sa.CheckConstraint("max_retries >= 0", name="ck_job_runs_max_retries_positive"), + sa.CheckConstraint("priority >= 0", name="ck_job_runs_priority_positive"), + sa.CheckConstraint("retry_count >= 0", name="ck_job_runs_retry_count_positive"), + sa.ForeignKeyConstraint(["pipeline_id"], ["pipelines.id"], ondelete="SET NULL"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("urn"), + ) + op.create_index("ix_job_runs_correlation_id", "job_runs", ["correlation_id"], unique=False) + op.create_index("ix_job_runs_created_at", "job_runs", ["created_at"], unique=False) + op.create_index("ix_job_runs_job_type", "job_runs", ["job_type"], unique=False) + op.create_index("ix_job_runs_pipeline_id", "job_runs", ["pipeline_id"], unique=False) + op.create_index("ix_job_runs_scheduled_at", "job_runs", ["scheduled_at"], unique=False) + op.create_index("ix_job_runs_status", "job_runs", ["status"], unique=False) + op.create_index("ix_job_runs_status_scheduled", "job_runs", ["status", "scheduled_at"], unique=False) + op.create_table( + "job_dependencies", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("depends_on_job_id", sa.Integer(), nullable=False), + sa.Column("dependency_type", sa.String(length=50), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column("metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.CheckConstraint( + "dependency_type IS NULL OR dependency_type IN ('success_required', 'completion_required')", + name="ck_job_dependencies_type_valid", + ), + sa.ForeignKeyConstraint(["depends_on_job_id"], ["job_runs.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["id"], ["job_runs.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id", "depends_on_job_id"), + ) + op.create_index("ix_job_dependencies_created_at", "job_dependencies", ["created_at"], unique=False) + op.create_index("ix_job_dependencies_depends_on_job_id", "job_dependencies", ["depends_on_job_id"], unique=False) + op.create_table( + "variant_annotation_status", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("variant_id", sa.Integer(), nullable=False), + sa.Column( + "annotation_type", + sa.String(length=50), + nullable=False, + comment="Type of annotation: vrs, clinvar, gnomad, etc.", + ), + sa.Column( + "version", + sa.String(length=50), + nullable=True, + comment="Version of the annotation source used (if applicable)", + ), + sa.Column("status", sa.String(length=50), nullable=False, comment="success, failed, skipped, pending"), + sa.Column("error_message", sa.Text(), nullable=True), + sa.Column("failure_category", sa.String(length=100), nullable=True), + sa.Column( + "success_data", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + comment="Annotation results when successful", + ), + sa.Column( + "current", + sa.Boolean(), + server_default="true", + nullable=False, + comment="Whether this is the current status for the variant and annotation type", + ), + sa.Column("job_run_id", sa.Integer(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.CheckConstraint( + "annotation_type IN ('vrs_mapping', 'clingen_allele_id', 'mapped_hgvs', 'variant_translation', 'gnomad_allele_frequency', 'clinvar_control', 'vep_functional_consequence', 'ldh_submission')", + name="ck_variant_annotation_type_valid", + ), + sa.CheckConstraint("status IN ('success', 'failed', 'skipped')", name="ck_variant_annotation_status_valid"), + sa.ForeignKeyConstraint(["job_run_id"], ["job_runs.id"], ondelete="SET NULL"), + sa.ForeignKeyConstraint(["variant_id"], ["variants.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "ix_variant_annotation_status_annotation_type", "variant_annotation_status", ["annotation_type"], unique=False + ) + op.create_index( + "ix_variant_annotation_status_created_at", "variant_annotation_status", ["created_at"], unique=False + ) + op.create_index("ix_variant_annotation_status_current", "variant_annotation_status", ["current"], unique=False) + op.create_index( + "ix_variant_annotation_status_job_run_id", "variant_annotation_status", ["job_run_id"], unique=False + ) + op.create_index("ix_variant_annotation_status_status", "variant_annotation_status", ["status"], unique=False) + op.create_index( + "ix_variant_annotation_status_variant_id", "variant_annotation_status", ["variant_id"], unique=False + ) + op.create_index( + "ix_variant_annotation_status_variant_type_version_current", + "variant_annotation_status", + ["variant_id", "annotation_type", "version", "current"], + unique=False, + ) + op.create_index("ix_variant_annotation_status_version", "variant_annotation_status", ["version"], unique=False) + op.create_index( + "ix_variant_annotation_type_status", "variant_annotation_status", ["annotation_type", "status"], unique=False + ) + op.create_index( + "ix_variant_annotation_variant_type_status", + "variant_annotation_status", + ["variant_id", "annotation_type", "status"], + unique=False, + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("ix_variant_annotation_variant_type_status", table_name="variant_annotation_status") + op.drop_index("ix_variant_annotation_type_status", table_name="variant_annotation_status") + op.drop_index("ix_variant_annotation_status_version", table_name="variant_annotation_status") + op.drop_index("ix_variant_annotation_status_variant_type_version_current", table_name="variant_annotation_status") + op.drop_index("ix_variant_annotation_status_variant_id", table_name="variant_annotation_status") + op.drop_index("ix_variant_annotation_status_status", table_name="variant_annotation_status") + op.drop_index("ix_variant_annotation_status_job_run_id", table_name="variant_annotation_status") + op.drop_index("ix_variant_annotation_status_current", table_name="variant_annotation_status") + op.drop_index("ix_variant_annotation_status_created_at", table_name="variant_annotation_status") + op.drop_index("ix_variant_annotation_status_annotation_type", table_name="variant_annotation_status") + op.drop_table("variant_annotation_status") + op.drop_index("ix_job_dependencies_depends_on_job_id", table_name="job_dependencies") + op.drop_index("ix_job_dependencies_created_at", table_name="job_dependencies") + op.drop_table("job_dependencies") + op.drop_index("ix_job_runs_status_scheduled", table_name="job_runs") + op.drop_index("ix_job_runs_status", table_name="job_runs") + op.drop_index("ix_job_runs_scheduled_at", table_name="job_runs") + op.drop_index("ix_job_runs_pipeline_id", table_name="job_runs") + op.drop_index("ix_job_runs_job_type", table_name="job_runs") + op.drop_index("ix_job_runs_created_at", table_name="job_runs") + op.drop_index("ix_job_runs_correlation_id", table_name="job_runs") + op.drop_table("job_runs") + op.drop_index("ix_pipelines_status", table_name="pipelines") + op.drop_index("ix_pipelines_created_by_user_id", table_name="pipelines") + op.drop_index("ix_pipelines_created_at", table_name="pipelines") + op.drop_index("ix_pipelines_correlation_id", table_name="pipelines") + op.drop_table("pipelines") + # ### end Alembic commands ### diff --git a/src/mavedb/models/__init__.py b/src/mavedb/models/__init__.py index 684b3c98..191fdc51 100644 --- a/src/mavedb/models/__init__.py +++ b/src/mavedb/models/__init__.py @@ -10,9 +10,12 @@ "experiment_set", "genome_identifier", "gnomad_variant", + "job_dependency", + "job_run", "legacy_keyword", "license", "mapped_variant", + "pipeline", "publication_identifier", "published_variant", "raw_read_identifier", @@ -27,6 +30,7 @@ "uniprot_identifier", "uniprot_offset", "user", + "variant_annotation_status", "variant", "variant_translation", ] diff --git a/src/mavedb/models/enums/__init__.py b/src/mavedb/models/enums/__init__.py index e69de29b..80c3a7de 100644 --- a/src/mavedb/models/enums/__init__.py +++ b/src/mavedb/models/enums/__init__.py @@ -0,0 +1,25 @@ +""" +Enums used by MaveDB models. +""" + +from .contribution_role import ContributionRole +from .job_pipeline import AnnotationStatus, DependencyType, FailureCategory, JobStatus, PipelineStatus +from .mapping_state import MappingState +from .processing_state import ProcessingState +from .score_calibration_relation import ScoreCalibrationRelation +from .target_category import TargetCategory +from .user_role import UserRole + +__all__ = [ + "ContributionRole", + "JobStatus", + "PipelineStatus", + "DependencyType", + "FailureCategory", + "AnnotationStatus", + "MappingState", + "ProcessingState", + "ScoreCalibrationRelation", + "TargetCategory", + "UserRole", +] diff --git a/src/mavedb/models/enums/annotation_type.py b/src/mavedb/models/enums/annotation_type.py new file mode 100644 index 00000000..773f056e --- /dev/null +++ b/src/mavedb/models/enums/annotation_type.py @@ -0,0 +1,12 @@ +import enum + + +class AnnotationType(enum.Enum): + VRS_MAPPING = "vrs_mapping" + CLINGEN_ALLELE_ID = "clingen_allele_id" + MAPPED_HGVS = "mapped_hgvs" + VARIANT_TRANSLATION = "variant_translation" + GNOMAD_ALLELE_FREQUENCY = "gnomad_allele_frequency" + CLINVAR_CONTROLS = "clinvar_control" + VEP_FUNCTIONAL_CONSEQUENCE = "vep_functional_consequence" + LDH_SUBMISSION = "ldh_submission" diff --git a/src/mavedb/models/enums/job_pipeline.py b/src/mavedb/models/enums/job_pipeline.py new file mode 100644 index 00000000..c8cc78e8 --- /dev/null +++ b/src/mavedb/models/enums/job_pipeline.py @@ -0,0 +1,75 @@ +""" +Job and pipeline related enums. +""" + +from enum import Enum + + +class JobStatus(str, Enum): + """Status of a job execution.""" + + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + SKIPPED = "skipped" + + +class PipelineStatus(str, Enum): + """Status of a pipeline execution.""" + + CREATED = "created" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class DependencyType(str, Enum): + """Types of job dependencies.""" + + SUCCESS_REQUIRED = "success_required" # Job only runs if dependency succeeded + COMPLETION_REQUIRED = "completion_required" # Job runs if dependency completed (success OR failure) + + +class FailureCategory(str, Enum): + """Categories of job failures for better classification and handling.""" + + # System-level failures + SYSTEM_ERROR = "system_error" + TIMEOUT = "timeout" + RESOURCE_EXHAUSTION = "resource_exhaustion" + CONFIGURATION_ERROR = "configuration_error" + DEPENDENCY_FAILURE = "dependency_failure" + + # Data and validation failures + VALIDATION_ERROR = "validation_error" + DATA_ERROR = "data_error" + + # External service failures + NETWORK_ERROR = "network_error" + API_RATE_LIMITED = "api_rate_limited" + SERVICE_UNAVAILABLE = "service_unavailable" + AUTHENTICATION_FAILED = "authentication_failed" + + # Permission and access failures + PERMISSION_ERROR = "permission_error" + QUOTA_EXCEEDED = "quota_exceeded" + + # Variant processing specific + INVALID_HGVS = "invalid_hgvs" + REFERENCE_MISMATCH = "reference_mismatch" + VRS_MAPPING_FAILED = "vrs_mapping_failed" + TRANSCRIPT_NOT_FOUND = "transcript_not_found" + + # Catch-all + UNKNOWN = "unknown" + + +class AnnotationStatus(str, Enum): + """Status of individual variant annotations.""" + + SUCCESS = "success" + FAILED = "failed" + SKIPPED = "skipped" diff --git a/src/mavedb/models/job_dependency.py b/src/mavedb/models/job_dependency.py new file mode 100644 index 00000000..414c49c1 --- /dev/null +++ b/src/mavedb/models/job_dependency.py @@ -0,0 +1,72 @@ +""" +SQLAlchemy models for job dependencies. +""" + +from datetime import datetime +from typing import TYPE_CHECKING, Any, Dict, Optional + +from sqlalchemy import CheckConstraint, DateTime, ForeignKey, Index, String, func +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from mavedb.db.base import Base +from mavedb.models.enums import DependencyType + +if TYPE_CHECKING: + from mavedb.models.job_run import JobRun + from mavedb.models.pipeline import Pipeline + + +class JobDependency(Base): + """ + Defines dependencies between jobs within a pipeline. + + This table maps jobs to their pipeline and defines execution order. + """ + + __tablename__ = "job_dependencies" + + # The job being defined (references job_runs.id) + id: Mapped[str] = mapped_column(String(255), ForeignKey("job_runs.id", ondelete="CASCADE"), primary_key=True) + + # Pipeline this job belongs to + pipeline_id: Mapped[str] = mapped_column( + String(255), ForeignKey("pipelines.id", ondelete="CASCADE"), nullable=False + ) + + # Job this depends on (nullable for jobs with no dependencies) + depends_on_job_id: Mapped[Optional[str]] = mapped_column( + String(255), ForeignKey("job_runs.id", ondelete="CASCADE"), nullable=True + ) + + # Type of dependency + dependency_type: Mapped[Optional[DependencyType]] = mapped_column(String(50), nullable=True) + + # Timestamps + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now()) + + # Flexible metadata + metadata_: Mapped[Optional[Dict[str, Any]]] = mapped_column("metadata", JSONB, nullable=True) + + # Relationships + pipeline: Mapped["Pipeline"] = relationship("Pipeline", back_populates="job_dependencies") + job_run: Mapped["JobRun"] = relationship("JobRun", back_populates="job_dependency", foreign_keys=[id]) + depends_on_job: Mapped[Optional["JobRun"]] = relationship( + "JobRun", foreign_keys=[depends_on_job_id], remote_side="JobRun.id" + ) + + # Indexes + __table_args__ = ( + Index("ix_job_dependencies_pipeline_id", "pipeline_id"), + Index("ix_job_dependencies_depends_on_job_id", "depends_on_job_id"), + Index("ix_job_dependencies_created_at", "created_at"), + CheckConstraint( + "dependency_type IS NULL OR dependency_type IN ('success_required', 'completion_required')", + name="ck_job_dependencies_type_valid", + ), + ) + + def __repr__(self) -> str: + return ( + f"" + ) diff --git a/src/mavedb/models/job_run.py b/src/mavedb/models/job_run.py new file mode 100644 index 00000000..5b2c4160 --- /dev/null +++ b/src/mavedb/models/job_run.py @@ -0,0 +1,113 @@ +""" +SQLAlchemy models for job runs. +""" + +from datetime import datetime +from typing import TYPE_CHECKING, Any, Dict, Optional + +from sqlalchemy import CheckConstraint, DateTime, Index, Integer, String, Text, func +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from mavedb.db.base import Base +from mavedb.models.enums import JobStatus + +if TYPE_CHECKING: + from mavedb.models.job_dependency import JobDependency + + +class JobRun(Base): + """ + Represents a single execution of a job. + + Jobs can be retried, so there may be multiple JobRun records for the same logical job. + """ + + __tablename__ = "job_runs" + + # Primary identification + id: Mapped[str] = mapped_column(String(255), primary_key=True) + + # Job definition + job_type: Mapped[str] = mapped_column(String(100), nullable=False, index=True) + job_function: Mapped[str] = mapped_column(String(255), nullable=False) + job_params: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSONB, nullable=True) + + # Execution tracking + status: Mapped[JobStatus] = mapped_column(String(50), nullable=False, default=JobStatus.PENDING) + + # Priority and scheduling + priority: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + max_retries: Mapped[int] = mapped_column(Integer, nullable=False, default=3) + retry_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + retry_delay_seconds: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + + # Timing + scheduled_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now()) + started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now()) + + # Error handling + error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + error_traceback: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + failure_category: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) + + # Progress tracking + progress_current: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + progress_total: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + progress_message: Mapped[Optional[str]] = mapped_column(String(500), nullable=True) + + # Correlation for tracing + correlation_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True, index=True) + + # Flexible metadata + metadata_: Mapped[Optional[Dict[str, Any]]] = mapped_column("metadata", JSONB, nullable=True) + + # Version tracking + mavedb_version: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) + + # Relationships + job_dependency: Mapped[Optional["JobDependency"]] = relationship( + "JobDependency", back_populates="job_run", uselist=False, foreign_keys="[JobDependency.id]" + ) + + # Indexes + __table_args__ = ( + Index("ix_job_runs_status", "status"), + Index("ix_job_runs_job_type", "job_type"), + Index("ix_job_runs_scheduled_at", "scheduled_at"), + Index("ix_job_runs_created_at", "created_at"), + Index("ix_job_runs_correlation_id", "correlation_id"), + Index("ix_job_runs_status_scheduled", "status", "scheduled_at"), + CheckConstraint( + "status IN ('pending', 'running', 'completed', 'failed', 'cancelled', 'retrying')", + name="ck_job_runs_status_valid", + ), + CheckConstraint("priority >= 0", name="ck_job_runs_priority_positive"), + CheckConstraint("max_retries >= 0", name="ck_job_runs_max_retries_positive"), + CheckConstraint("retry_count >= 0", name="ck_job_runs_retry_count_positive"), + ) + + def __repr__(self) -> str: + return f"" + + @hybrid_property + def duration_seconds(self) -> Optional[int]: + """Calculate job duration in seconds.""" + if self.started_at and self.finished_at: + return int((self.finished_at - self.started_at).total_seconds()) + return None + + @hybrid_property + def progress_percentage(self) -> Optional[float]: + """Calculate progress as percentage.""" + if self.progress_total and self.progress_total > 0: + return (self.progress_current or 0) / self.progress_total * 100 + return None + + @property + def can_retry(self) -> bool: + """Check if job can be retried.""" + return self.status == JobStatus.FAILED and self.retry_count < self.max_retries diff --git a/src/mavedb/models/pipeline.py b/src/mavedb/models/pipeline.py new file mode 100644 index 00000000..cb4f5d37 --- /dev/null +++ b/src/mavedb/models/pipeline.py @@ -0,0 +1,88 @@ +""" +SQLAlchemy models for job pipelines. +""" + +from datetime import datetime +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from sqlalchemy import CheckConstraint, DateTime, ForeignKey, Index, Integer, String, Text, func +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from mavedb.db.base import Base +from mavedb.models.enums import PipelineStatus + +if TYPE_CHECKING: + from mavedb.models.job_dependency import JobDependency + from mavedb.models.user import User + + +class Pipeline(Base): + """ + Represents a high-level workflow that groups related jobs. + + Examples: + - Processing a score set upload + - Batch re-annotation of variants + - Database migration workflows + """ + + __tablename__ = "pipelines" + + # Primary identification + id: Mapped[str] = mapped_column(String(255), primary_key=True) + name: Mapped[str] = mapped_column(String(500), nullable=False) + description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + # Status and lifecycle + status: Mapped[PipelineStatus] = mapped_column(String(50), nullable=False, default=PipelineStatus.CREATED) + + # Correlation for end-to-end tracing + correlation_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True, index=True) + + # Flexible metadata storage + metadata_: Mapped[Optional[Dict[str, Any]]] = mapped_column( + "metadata", JSONB, nullable=True, comment="Flexible metadata storage for pipeline-specific data" + ) + + # Timestamps + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now()) + started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + + # User tracking + created_by_user_id: Mapped[Optional[int]] = mapped_column( + Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True + ) + + # Version tracking + mavedb_version: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) + + # Relationships + job_dependencies: Mapped[List["JobDependency"]] = relationship( + "JobDependency", back_populates="pipeline", cascade="all, delete-orphan" + ) + created_by_user: Mapped[Optional["User"]] = relationship("User", foreign_keys=[created_by_user_id]) + + # Indexes + __table_args__ = ( + Index("ix_pipelines_status", "status"), + Index("ix_pipelines_created_at", "created_at"), + Index("ix_pipelines_correlation_id", "correlation_id"), + Index("ix_pipelines_created_by_user_id", "created_by_user_id"), + CheckConstraint( + "status IN ('created', 'running', 'completed', 'failed', 'cancelled')", name="ck_pipelines_status_valid" + ), + ) + + def __repr__(self) -> str: + return f"" + + @hybrid_property + def duration_seconds(self) -> Optional[int]: + """Calculate pipeline duration in seconds.""" + if self.started_at and self.finished_at: + return int((self.finished_at - self.started_at).total_seconds()) + + return None diff --git a/src/mavedb/models/variant_annotation_status.py b/src/mavedb/models/variant_annotation_status.py new file mode 100644 index 00000000..9be7f01e --- /dev/null +++ b/src/mavedb/models/variant_annotation_status.py @@ -0,0 +1,107 @@ +""" +SQLAlchemy models for variant annotation status. +""" + +from datetime import datetime +from typing import TYPE_CHECKING, Any, Dict, Optional + +from sqlalchemy import CheckConstraint, DateTime, ForeignKey, Index, Integer, String, Text, func +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from mavedb.db.base import Base +from mavedb.models.enums.job_pipeline import AnnotationStatus + +if TYPE_CHECKING: + from mavedb.models.job_run import JobRun + from mavedb.models.variant import Variant + + +class VariantAnnotationStatus(Base): + """ + Tracks annotation status for individual variants. + + Allows us to see which variants failed annotation and why. + """ + + __tablename__ = "variant_annotation_status" + + # Primary key + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + + # Composite primary key + variant_id: Mapped[int] = mapped_column(Integer, ForeignKey("variants.id", ondelete="CASCADE"), primary_key=True) + annotation_type: Mapped[str] = mapped_column( + String(50), primary_key=True, comment="Type of annotation: vrs, clinvar, gnomad, etc." + ) + + # Source version + version: Mapped[Optional[str]] = mapped_column( + String(50), nullable=True, comment="Version of the annotation source used (if applicable)" + ) + + # Status tracking + status: Mapped[AnnotationStatus] = mapped_column(String(50), nullable=False, comment="success, failed, skipped") + + # Error information + error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + failure_category: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) + + # Success data (flexible JSONB for annotation results) + success_data: Mapped[Optional[Dict[str, Any]]] = mapped_column( + JSONB, nullable=True, comment="Annotation results when successful" + ) + + # Current flag + current: Mapped[bool] = mapped_column( + nullable=False, + server_default="true", + comment="Whether this is the current status for the variant and annotation type", + ) + + # Job tracking + job_run_id: Mapped[Optional[str]] = mapped_column( + String(255), ForeignKey("job_runs.id", ondelete="SET NULL"), nullable=True + ) + + # Timestamps + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now()) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() + ) + + # Relationships + variant: Mapped["Variant"] = relationship("Variant") + job_run: Mapped[Optional["JobRun"]] = relationship("JobRun") + + # Indexes + __table_args__ = ( + Index("ix_variant_annotation_status_variant_id", "variant_id"), + Index("ix_variant_annotation_status_annotation_type", "annotation_type"), + Index("ix_variant_annotation_status_status", "status"), + Index("ix_variant_annotation_status_job_run_id", "job_run_id"), + Index("ix_variant_annotation_status_created_at", "created_at"), + # Composite index for common queries + Index("ix_variant_annotation_type_status", "annotation_type", "status"), + Index("ix_variant_annotation_status_current", "current"), + Index("ix_variant_annotation_status_version", "version"), + Index( + "ix_variant_annotation_status_variant_type_version_current", + "variant_id", + "annotation_type", + "version", + "current", + ), + CheckConstraint( + "annotation_type IN ('vrs_mapping', 'clingen_allele_id', 'mapped_hgvs', 'variant_translation', 'gnomad_allele_frequency', 'clinvar_control', 'vep_functional_consequence', 'ldh_submission')", + name="ck_variant_annotation_type_valid", + ), + CheckConstraint( + "status IN ('success', 'failed', 'skipped')", + name="ck_variant_annotation_status_valid", + ), + ## Although un-enforced at the DB level, we should ensure only one 'current' record per (variant_id, annotation_type, version) + ) + + def __repr__(self) -> str: + return f"" diff --git a/tests/worker/conftest.py b/tests/worker/conftest.py index 49dad88f..cf996c1d 100644 --- a/tests/worker/conftest.py +++ b/tests/worker/conftest.py @@ -1,20 +1,23 @@ +from datetime import datetime from pathlib import Path from shutil import copytree from unittest.mock import Mock import pytest +from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus +from mavedb.models.job_run import JobRun from mavedb.models.license import License +from mavedb.models.pipeline import Pipeline from mavedb.models.taxonomy import Taxonomy from mavedb.models.user import User - from tests.helpers.constants import ( EXTRA_USER, - TEST_LICENSE, TEST_INACTIVE_LICENSE, + TEST_LICENSE, + TEST_MAVEDB_ATHENA_ROW, TEST_SAVED_TAXONOMY, TEST_USER, - TEST_MAVEDB_ATHENA_ROW, ) @@ -29,6 +32,83 @@ def setup_worker_db(session): db.commit() +@pytest.fixture +def with_populated_job_data( + session, + sample_job_run, + sample_pipeline, + sample_empty_pipeline, + sample_job_dependency, + sample_dependent_job_run, + sample_independent_job_run, +): + """Set up the database with sample data for worker tests.""" + session.add(sample_pipeline) + session.add(sample_empty_pipeline) + session.add(sample_job_run) + session.add(sample_dependent_job_run) + session.add(sample_independent_job_run) + session.add(sample_job_dependency) + session.commit() + + +@pytest.fixture +def mock_pipeline(): + """Create a mock Pipeline instance. By default, + properties are identical to a default new Pipeline entered into the db + with sensible defaults for non-nullable but unset fields. + """ + return Mock( + spec=Pipeline, + id=1, + urn="test:pipeline:1", + name="Test Pipeline", + description="A test pipeline", + status=PipelineStatus.CREATED, + correlation_id="test_correlation_123", + metadata_={}, + created_at=datetime.now(), + started_at=None, + finished_at=None, + created_by_user_id=None, + mavedb_version=None, + ) + + +@pytest.fixture +def mock_job_run(mock_pipeline): + """Create a mock JobRun instance. By default, + properties are identical to a default new JobRun entered into the db + with sensible defaults for non-nullable but unset fields. + """ + return Mock( + spec=JobRun, + id=123, + urn="test:job:123", + job_type="test_job", + job_function="test_function", + status=JobStatus.PENDING, + pipeline_id=mock_pipeline.id, + priority=0, + max_retries=3, + retry_count=0, + retry_delay_seconds=None, + scheduled_at=datetime.now(), + started_at=None, + finished_at=None, + created_at=datetime.now(), + error_message=None, + error_traceback=None, + failure_category=None, + progress_current=None, + progress_total=None, + progress_message=None, + correlation_id=None, + metadata_={}, + mavedb_version=None, + ) + + @pytest.fixture def data_files(tmp_path): copytree(Path(__file__).absolute().parent / "data", tmp_path / "data") From 510614cb1530158cc95199b475873291e145139c Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 7 Jan 2026 11:50:51 -0800 Subject: [PATCH 03/70] fix(logging): simplify context saving logic to overwrite existing mappings --- src/mavedb/lib/logging/context.py | 10 +--- src/mavedb/lib/urns.py | 22 +++++++ src/mavedb/models/enums/job_pipeline.py | 16 ++++-- src/mavedb/models/job_dependency.py | 41 ++++++------- src/mavedb/models/job_run.py | 57 +++++++++---------- src/mavedb/models/pipeline.py | 37 ++++++------ .../models/variant_annotation_status.py | 12 +++- 7 files changed, 108 insertions(+), 87 deletions(-) diff --git a/src/mavedb/lib/logging/context.py b/src/mavedb/lib/logging/context.py index 6771f760..075efb58 100644 --- a/src/mavedb/lib/logging/context.py +++ b/src/mavedb/lib/logging/context.py @@ -55,15 +55,7 @@ def save_to_logging_context(ctx: dict) -> dict: return {} for k, v in ctx.items(): - # Don't overwrite existing context mappings but create a list if a duplicated key is added. - if k in context: - existing_ctx = context[k] - if isinstance(existing_ctx, list): - context[k].append(v) - else: - context[k] = [existing_ctx, v] - else: - context[k] = v + context[k] = v return context.data diff --git a/src/mavedb/lib/urns.py b/src/mavedb/lib/urns.py index e3903ac8..55a59e70 100644 --- a/src/mavedb/lib/urns.py +++ b/src/mavedb/lib/urns.py @@ -153,3 +153,25 @@ def generate_calibration_urn(): :return: A new calibration URN """ return f"urn:mavedb:calibration-{uuid4()}" + + +def generate_pipeline_urn(): + """ + Generate a new URN for a pipeline. + + Pipeline URNs include a 16-digit UUID. + + :return: A new pipeline URN + """ + return f"urn:mavedb:pipeline-{uuid4()}" + + +def generate_job_run_urn(): + """ + Generate a new URN for a job run. + + Job run URNs include a 16-digit UUID. + + :return: A new job run URN + """ + return f"urn:mavedb:job-{uuid4()}" diff --git a/src/mavedb/models/enums/job_pipeline.py b/src/mavedb/models/enums/job_pipeline.py index c8cc78e8..0900b580 100644 --- a/src/mavedb/models/enums/job_pipeline.py +++ b/src/mavedb/models/enums/job_pipeline.py @@ -8,10 +8,11 @@ class JobStatus(str, Enum): """Status of a job execution.""" + SUCCEEDED = "succeeded" + FAILED = "failed" PENDING = "pending" + QUEUED = "queued" RUNNING = "running" - COMPLETED = "completed" - FAILED = "failed" CANCELLED = "cancelled" SKIPPED = "skipped" @@ -19,11 +20,13 @@ class JobStatus(str, Enum): class PipelineStatus(str, Enum): """Status of a pipeline execution.""" + SUCCEEDED = "succeeded" + FAILED = "failed" CREATED = "created" RUNNING = "running" - COMPLETED = "completed" - FAILED = "failed" + PAUSED = "paused" CANCELLED = "cancelled" + PARTIAL = "partial" # Pipeline completed with mixed results (some succeeded, some skipped/cancelled) class DependencyType(str, Enum): @@ -43,6 +46,11 @@ class FailureCategory(str, Enum): CONFIGURATION_ERROR = "configuration_error" DEPENDENCY_FAILURE = "dependency_failure" + # Queue and scheduling failures + ENQUEUE_ERROR = "enqueue_error" + SCHEDULING_ERROR = "scheduling_error" + CANCELLED = "cancelled" + # Data and validation failures VALIDATION_ERROR = "validation_error" DATA_ERROR = "data_error" diff --git a/src/mavedb/models/job_dependency.py b/src/mavedb/models/job_dependency.py index 414c49c1..ac851c7d 100644 --- a/src/mavedb/models/job_dependency.py +++ b/src/mavedb/models/job_dependency.py @@ -5,8 +5,9 @@ from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, Optional -from sqlalchemy import CheckConstraint, DateTime, ForeignKey, Index, String, func +from sqlalchemy import CheckConstraint, DateTime, ForeignKey, Index, Integer, String, func from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.orm import Mapped, mapped_column, relationship from mavedb.db.base import Base @@ -14,7 +15,6 @@ if TYPE_CHECKING: from mavedb.models.job_run import JobRun - from mavedb.models.pipeline import Pipeline class JobDependency(Base): @@ -22,42 +22,37 @@ class JobDependency(Base): Defines dependencies between jobs within a pipeline. This table maps jobs to their pipeline and defines execution order. + + NOTE: JSONB fields are automatically tracked as mutable objects in this class via MutableDict. + This tracker only works for top-level mutations. If you mutate nested objects, you must call + `flag_modified(instance, "metadata_")` to ensure changes are persisted. """ __tablename__ = "job_dependencies" - # The job being defined (references job_runs.id) - id: Mapped[str] = mapped_column(String(255), ForeignKey("job_runs.id", ondelete="CASCADE"), primary_key=True) - - # Pipeline this job belongs to - pipeline_id: Mapped[str] = mapped_column( - String(255), ForeignKey("pipelines.id", ondelete="CASCADE"), nullable=False - ) - - # Job this depends on (nullable for jobs with no dependencies) - depends_on_job_id: Mapped[Optional[str]] = mapped_column( - String(255), ForeignKey("job_runs.id", ondelete="CASCADE"), nullable=True + # The job being defined (references job_runs.id). Composite primary key with the dependency we are defining. + id: Mapped[int] = mapped_column(Integer, ForeignKey("job_runs.id", ondelete="CASCADE"), primary_key=True) + depends_on_job_id: Mapped[int] = mapped_column( + Integer, ForeignKey("job_runs.id", ondelete="CASCADE"), nullable=False, primary_key=True ) # Type of dependency - dependency_type: Mapped[Optional[DependencyType]] = mapped_column(String(50), nullable=True) + dependency_type: Mapped[Optional[DependencyType]] = mapped_column(String(50), nullable=False) # Timestamps created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now()) # Flexible metadata - metadata_: Mapped[Optional[Dict[str, Any]]] = mapped_column("metadata", JSONB, nullable=True) + metadata_: Mapped[Optional[Dict[str, Any]]] = mapped_column( + "metadata", MutableDict.as_mutable(JSONB), nullable=True + ) # Relationships - pipeline: Mapped["Pipeline"] = relationship("Pipeline", back_populates="job_dependencies") - job_run: Mapped["JobRun"] = relationship("JobRun", back_populates="job_dependency", foreign_keys=[id]) - depends_on_job: Mapped[Optional["JobRun"]] = relationship( - "JobRun", foreign_keys=[depends_on_job_id], remote_side="JobRun.id" - ) + job_run: Mapped["JobRun"] = relationship("JobRun", back_populates="job_dependencies", foreign_keys=[id]) + depends_on_job: Mapped["JobRun"] = relationship("JobRun", foreign_keys=[depends_on_job_id], remote_side="JobRun.id") # Indexes __table_args__ = ( - Index("ix_job_dependencies_pipeline_id", "pipeline_id"), Index("ix_job_dependencies_depends_on_job_id", "depends_on_job_id"), Index("ix_job_dependencies_created_at", "created_at"), CheckConstraint( @@ -67,6 +62,4 @@ class JobDependency(Base): ) def __repr__(self) -> str: - return ( - f"" - ) + return f"" diff --git a/src/mavedb/models/job_run.py b/src/mavedb/models/job_run.py index 5b2c4160..9ec039cd 100644 --- a/src/mavedb/models/job_run.py +++ b/src/mavedb/models/job_run.py @@ -5,16 +5,18 @@ from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, Optional -from sqlalchemy import CheckConstraint, DateTime, Index, Integer, String, Text, func +from sqlalchemy import CheckConstraint, DateTime, ForeignKey, Index, Integer, String, Text, func from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.orm import Mapped, mapped_column, relationship from mavedb.db.base import Base +from mavedb.lib.urns import generate_job_run_urn from mavedb.models.enums import JobStatus if TYPE_CHECKING: from mavedb.models.job_dependency import JobDependency + from mavedb.models.pipeline import Pipeline class JobRun(Base): @@ -22,21 +24,31 @@ class JobRun(Base): Represents a single execution of a job. Jobs can be retried, so there may be multiple JobRun records for the same logical job. + + NOTE: JSONB fields are automatically tracked as mutable objects in this class via MutableDict. + This tracker only works for top-level mutations. If you mutate nested objects, you must call + `flag_modified(instance, "metadata_")` to ensure changes are persisted. """ __tablename__ = "job_runs" # Primary identification - id: Mapped[str] = mapped_column(String(255), primary_key=True) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + urn: Mapped[str] = mapped_column(String(255), nullable=True, unique=True, default=generate_job_run_urn) # Job definition - job_type: Mapped[str] = mapped_column(String(100), nullable=False, index=True) + job_type: Mapped[str] = mapped_column(String(100), nullable=False) job_function: Mapped[str] = mapped_column(String(255), nullable=False) - job_params: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSONB, nullable=True) + job_params: Mapped[Optional[Dict[str, Any]]] = mapped_column(MutableDict.as_mutable(JSONB), nullable=True) # Execution tracking status: Mapped[JobStatus] = mapped_column(String(50), nullable=False, default=JobStatus.PENDING) + # Pipeline association + pipeline_id: Mapped[Optional[int]] = mapped_column( + Integer, ForeignKey("pipelines.id", ondelete="SET NULL"), nullable=True + ) + # Priority and scheduling priority: Mapped[int] = mapped_column(Integer, nullable=False, default=0) max_retries: Mapped[int] = mapped_column(Integer, nullable=False, default=3) @@ -60,29 +72,35 @@ class JobRun(Base): progress_message: Mapped[Optional[str]] = mapped_column(String(500), nullable=True) # Correlation for tracing - correlation_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True, index=True) + correlation_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) # Flexible metadata - metadata_: Mapped[Optional[Dict[str, Any]]] = mapped_column("metadata", JSONB, nullable=True) + metadata_: Mapped[Dict[str, Any]] = mapped_column( + "metadata", MutableDict.as_mutable(JSONB), nullable=False, server_default="{}" + ) # Version tracking mavedb_version: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) # Relationships - job_dependency: Mapped[Optional["JobDependency"]] = relationship( - "JobDependency", back_populates="job_run", uselist=False, foreign_keys="[JobDependency.id]" + job_dependencies: Mapped[list["JobDependency"]] = relationship( + "JobDependency", back_populates="job_run", uselist=True, foreign_keys="[JobDependency.id]" + ) + pipeline: Mapped[Optional["Pipeline"]] = relationship( + "Pipeline", back_populates="job_runs", foreign_keys="[JobRun.pipeline_id]" ) # Indexes __table_args__ = ( Index("ix_job_runs_status", "status"), Index("ix_job_runs_job_type", "job_type"), + Index("ix_job_runs_pipeline_id", "pipeline_id"), Index("ix_job_runs_scheduled_at", "scheduled_at"), Index("ix_job_runs_created_at", "created_at"), Index("ix_job_runs_correlation_id", "correlation_id"), Index("ix_job_runs_status_scheduled", "status", "scheduled_at"), CheckConstraint( - "status IN ('pending', 'running', 'completed', 'failed', 'cancelled', 'retrying')", + "status IN ('pending', 'queued', 'running', 'succeeded', 'failed', 'cancelled', 'skipped')", name="ck_job_runs_status_valid", ), CheckConstraint("priority >= 0", name="ck_job_runs_priority_positive"), @@ -92,22 +110,3 @@ class JobRun(Base): def __repr__(self) -> str: return f"" - - @hybrid_property - def duration_seconds(self) -> Optional[int]: - """Calculate job duration in seconds.""" - if self.started_at and self.finished_at: - return int((self.finished_at - self.started_at).total_seconds()) - return None - - @hybrid_property - def progress_percentage(self) -> Optional[float]: - """Calculate progress as percentage.""" - if self.progress_total and self.progress_total > 0: - return (self.progress_current or 0) / self.progress_total * 100 - return None - - @property - def can_retry(self) -> bool: - """Check if job can be retried.""" - return self.status == JobStatus.FAILED and self.retry_count < self.max_retries diff --git a/src/mavedb/models/pipeline.py b/src/mavedb/models/pipeline.py index cb4f5d37..717ec24c 100644 --- a/src/mavedb/models/pipeline.py +++ b/src/mavedb/models/pipeline.py @@ -7,14 +7,15 @@ from sqlalchemy import CheckConstraint, DateTime, ForeignKey, Index, Integer, String, Text, func from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.orm import Mapped, mapped_column, relationship from mavedb.db.base import Base +from mavedb.lib.urns import generate_pipeline_urn from mavedb.models.enums import PipelineStatus +from mavedb.models.job_run import JobRun if TYPE_CHECKING: - from mavedb.models.job_dependency import JobDependency from mavedb.models.user import User @@ -26,12 +27,17 @@ class Pipeline(Base): - Processing a score set upload - Batch re-annotation of variants - Database migration workflows + + NOTE: JSONB fields are automatically tracked as mutable objects in this class via MutableDict. + This tracker only works for top-level mutations. If you mutate nested objects, you must call + `flag_modified(instance, "metadata_")` to ensure changes are persisted. """ __tablename__ = "pipelines" # Primary identification - id: Mapped[str] = mapped_column(String(255), primary_key=True) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + urn: Mapped[str] = mapped_column(String(255), nullable=True, unique=True, default=generate_pipeline_urn) name: Mapped[str] = mapped_column(String(500), nullable=False) description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) @@ -39,11 +45,15 @@ class Pipeline(Base): status: Mapped[PipelineStatus] = mapped_column(String(50), nullable=False, default=PipelineStatus.CREATED) # Correlation for end-to-end tracing - correlation_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True, index=True) + correlation_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) # Flexible metadata storage - metadata_: Mapped[Optional[Dict[str, Any]]] = mapped_column( - "metadata", JSONB, nullable=True, comment="Flexible metadata storage for pipeline-specific data" + metadata_: Mapped[Dict[str, Any]] = mapped_column( + "metadata", + MutableDict.as_mutable(JSONB), + nullable=False, + comment="Flexible metadata storage for pipeline-specific data", + server_default="{}", ) # Timestamps @@ -60,9 +70,7 @@ class Pipeline(Base): mavedb_version: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) # Relationships - job_dependencies: Mapped[List["JobDependency"]] = relationship( - "JobDependency", back_populates="pipeline", cascade="all, delete-orphan" - ) + job_runs: Mapped[List["JobRun"]] = relationship("JobRun", back_populates="pipeline", cascade="all, delete-orphan") created_by_user: Mapped[Optional["User"]] = relationship("User", foreign_keys=[created_by_user_id]) # Indexes @@ -72,17 +80,10 @@ class Pipeline(Base): Index("ix_pipelines_correlation_id", "correlation_id"), Index("ix_pipelines_created_by_user_id", "created_by_user_id"), CheckConstraint( - "status IN ('created', 'running', 'completed', 'failed', 'cancelled')", name="ck_pipelines_status_valid" + "status IN ('created', 'running', 'succeeded', 'failed', 'cancelled', 'paused', 'partial')", + name="ck_pipelines_status_valid", ), ) def __repr__(self) -> str: return f"" - - @hybrid_property - def duration_seconds(self) -> Optional[int]: - """Calculate pipeline duration in seconds.""" - if self.started_at and self.finished_at: - return int((self.finished_at - self.started_at).total_seconds()) - - return None diff --git a/src/mavedb/models/variant_annotation_status.py b/src/mavedb/models/variant_annotation_status.py index 9be7f01e..3051b4d3 100644 --- a/src/mavedb/models/variant_annotation_status.py +++ b/src/mavedb/models/variant_annotation_status.py @@ -7,6 +7,7 @@ from sqlalchemy import CheckConstraint, DateTime, ForeignKey, Index, Integer, String, Text, func from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.orm import Mapped, mapped_column, relationship from mavedb.db.base import Base @@ -22,6 +23,10 @@ class VariantAnnotationStatus(Base): Tracks annotation status for individual variants. Allows us to see which variants failed annotation and why. + + NOTE: JSONB fields are automatically tracked as mutable objects in this class via MutableDict. + This tracker only works for top-level mutations. If you mutate nested objects, you must call + `flag_modified(instance, "metadata_")` to ensure changes are persisted. """ __tablename__ = "variant_annotation_status" @@ -49,7 +54,7 @@ class VariantAnnotationStatus(Base): # Success data (flexible JSONB for annotation results) success_data: Mapped[Optional[Dict[str, Any]]] = mapped_column( - JSONB, nullable=True, comment="Annotation results when successful" + MutableDict.as_mutable(JSONB), nullable=True, comment="Annotation results when successful" ) # Current flag @@ -60,8 +65,8 @@ class VariantAnnotationStatus(Base): ) # Job tracking - job_run_id: Mapped[Optional[str]] = mapped_column( - String(255), ForeignKey("job_runs.id", ondelete="SET NULL"), nullable=True + job_run_id: Mapped[Optional[int]] = mapped_column( + Integer, ForeignKey("job_runs.id", ondelete="SET NULL"), nullable=True ) # Timestamps @@ -82,6 +87,7 @@ class VariantAnnotationStatus(Base): Index("ix_variant_annotation_status_job_run_id", "job_run_id"), Index("ix_variant_annotation_status_created_at", "created_at"), # Composite index for common queries + Index("ix_variant_annotation_variant_type_status", "variant_id", "annotation_type", "status"), Index("ix_variant_annotation_type_status", "annotation_type", "status"), Index("ix_variant_annotation_status_current", "current"), Index("ix_variant_annotation_status_version", "version"), From 83b34d630b615c02deec8f6815552d5098607e08 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Sun, 11 Jan 2026 23:19:57 -0800 Subject: [PATCH 04/70] tests: add TransactionSpy class for mocking database transaction methods and failures --- tests/helpers/transaction_spy.py | 222 +++++++++++++++++++++++++++++++ tests/helpers/util/common.py | 31 +++++ 2 files changed, 253 insertions(+) create mode 100644 tests/helpers/transaction_spy.py diff --git a/tests/helpers/transaction_spy.py b/tests/helpers/transaction_spy.py new file mode 100644 index 00000000..4381aa75 --- /dev/null +++ b/tests/helpers/transaction_spy.py @@ -0,0 +1,222 @@ +from contextlib import contextmanager +from typing import Generator, TypedDict, Union +from unittest.mock import AsyncMock, MagicMock, patch + +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session + +from tests.helpers.util.common import create_failing_side_effect + + +class TransactionSpy: + """Factory for creating database transaction spy context managers.""" + + class Spies(TypedDict): + flush: Union[MagicMock, AsyncMock] + rollback: Union[MagicMock, AsyncMock] + commit: Union[MagicMock, AsyncMock] + + class SpiesWithException(Spies): + exception: Exception + + @staticmethod + @contextmanager + def spy( + session: Session, + expect_rollback: bool = False, + expect_flush: bool = False, + expect_commit: bool = False, + ) -> Generator[Spies, None, None]: + """ + Create spies for database transaction methods. + + Args: + session: Database session to spy on + expect_rollback: Whether to assert db.rollback to be called + expect_flush: Whether to assert db.flush to be called + expect_commit: Whether to assert db.commit to be called + + Yields: + dict: Dictionary containing all the spies for granular assertion + + Note: + Use caution when combining expectations. For example, if expect_commit + is True, you may wish to set expect_flush to True as well, since commit + typically implies a flush operation within SQLAlchemy internals. + + Example: + ``` + with TransactionSpy.spy(session, expect_rollback=True) as spies: + # perform operation + ... + + # Make manual granular assertions on spies if desired + spies['rollback'].assert_called_once() + + # if assert_XXX=True is set, automatic assertions will be made at context exit. + # In this example, expect_rollback=True will ensure rollback was called at some point. + ``` + """ + with ( + patch.object(session, "rollback", wraps=session.rollback) as rollback_spy, + patch.object(session, "flush", wraps=session.flush) as flush_spy, + patch.object(session, "commit", wraps=session.commit) as commit_spy, + ): + spies: TransactionSpy.Spies = { + "flush": flush_spy, + "rollback": rollback_spy, + "commit": commit_spy, + } + + yield spies + + # Automatic assertions based on session expectations. + if expect_flush: + flush_spy.assert_called() + else: + flush_spy.assert_not_called() + if expect_rollback: + rollback_spy.assert_called() + else: + rollback_spy.assert_not_called() + if expect_commit: + commit_spy.assert_called() + else: + commit_spy.assert_not_called() + + @staticmethod + @contextmanager + def mock_database_execution_failure( + session: Session, + exception=None, + fail_on_call=1, + expect_rollback: bool = False, + expect_flush: bool = False, + expect_commit: bool = False, + ) -> Generator[SpiesWithException, None, None]: + """ + Create a context that mocks database execution failures with transaction spies. This context + will automatically assert calls to rollback, flush, and commit based on the provided expectations + which all default to False. + + Args: + session: Database session to mock + exception: Exception to raise (defaults to SQLAlchemyError) + fail_on_call: Which call should fail (defaults to first call) + expect_rollback: Whether to assert rollback called (defaults to False) + expect_flush: Whether to assert flush called (defaults to False) + expect_commit: Whether to assert commit called (defaults to False) + Yields: + dict: Dictionary containing spies and the exception that will be raised + """ + exception = exception or SQLAlchemyError("DB Error") + + with ( + patch.object( + session, + "execute", + side_effect=create_failing_side_effect(exception, session.execute, fail_on_call), + ), + TransactionSpy.spy( + session, + expect_rollback=expect_rollback, + expect_flush=expect_flush, + expect_commit=expect_commit, + ) as transaction_spies, + ): + spies: TransactionSpy.SpiesWithException = { + **transaction_spies, + "exception": exception, + } + + yield spies + + @staticmethod + @contextmanager + def mock_database_flush_failure( + session: Session, + exception=None, + fail_on_call=1, + expect_rollback: bool = True, + expect_flush: bool = True, + expect_commit: bool = False, + ) -> Generator[SpiesWithException, None, None]: + """ + Create a context that mocks flush failures specifically. This context will automatically + assert that rollback and flush are called, and that commit is not called. These automatic + assertions can be overridden via the expect_XXX parameters. + + Args: + session: Database session to mock + exception: Exception to raise on flush (defaults to SQLAlchemyError) + fail_on_call: Which flush call should fail (defaults to first call) + expect_rollback: Whether to assert rollback called (defaults to True) + expect_flush: Whether to assert flush called (defaults to True) + expect_commit: Whether to assert commit called (defaults to False) + Yields: + dict: Dictionary containing spies and the exception + """ + exception = exception or SQLAlchemyError("Flush Error") + + with ( + patch.object( + session, "flush", side_effect=create_failing_side_effect(exception, session.flush, fail_on_call) + ), + TransactionSpy.spy( + session, + expect_rollback=expect_rollback, + expect_flush=expect_flush, + expect_commit=expect_commit, + ) as transaction_spies, + ): + spies: TransactionSpy.SpiesWithException = { + **transaction_spies, + "exception": exception, + } + + yield spies + + @staticmethod + @contextmanager + def mock_database_rollback_failure( + session: Session, + exception=None, + fail_on_call=1, + expect_rollback: bool = True, + expect_flush: bool = False, + expect_commit: bool = False, + ) -> Generator[SpiesWithException, None, None]: + """ + Create a context that mocks rollback failures specifically. This context will automatically + assert that rollback is called, flush is not called, and commit is not called. These automatic + assertions can be overridden via the expect_XXX parameters. + + Args: + session: Database session to mock + exception: Exception to raise on rollback (defaults to SQLAlchemyError) + fail_on_call: Which rollback call should fail (defaults to first call) + expect_rollback: Whether to assert rollback called (defaults to True) + expect_flush: Whether to assert flush called (defaults to False) + expect_commit: Whether to assert commit called (defaults to False) + Yields: + dict: Dictionary containing spies and the exception + """ + exception = exception or SQLAlchemyError("Rollback Error") + + with ( + patch.object( + session, "rollback", side_effect=create_failing_side_effect(exception, session.rollback, fail_on_call) + ), + TransactionSpy.spy( + session, + expect_rollback=expect_rollback, + expect_flush=expect_flush, + expect_commit=expect_commit, + ) as transaction_spies, + ): + spies: TransactionSpy.SpiesWithException = { + **transaction_spies, + "exception": exception, + } + + yield spies diff --git a/tests/helpers/util/common.py b/tests/helpers/util/common.py index 407cf101..0acf2c1e 100644 --- a/tests/helpers/util/common.py +++ b/tests/helpers/util/common.py @@ -56,3 +56,34 @@ def deepcamelize(data: Any) -> Any: return [deepcamelize(item) for item in data] else: return data + + +def create_failing_side_effect(exception, original_method, fail_on_call=1): + """ + Create a side effect function that fails on a specific call number, then delegates to original method. + + Args: + exception: The exception to raise on the failing call + original_method: The original method to delegate to after the failure + fail_on_call: Which call number should fail (1-indexed, defaults to first call) + + Returns: + A callable that can be used as a side_effect in mock.patch + + Example: + with patch.object(session, "execute", side_effect=create_failing_side_effect( + SQLAlchemyError("DB Error"), session.execute + )): + # First call will raise SQLAlchemyError, subsequent calls work normally + pass + """ + call_count = 0 + + def side_effect_function(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == fail_on_call: + raise exception + return original_method(*args, **kwargs) + + return side_effect_function From 224bbb3d5d9d4b7245a586bb528313972cc8a05c Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Sun, 11 Jan 2026 23:20:09 -0800 Subject: [PATCH 05/70] feat: add BaseManager class with transaction handling and rollback features --- .../worker/lib/managers/base_manager.py | 41 +++++++++++++++++++ .../worker/lib/managers/test_base_manager.py | 19 +++++++++ 2 files changed, 60 insertions(+) create mode 100644 src/mavedb/worker/lib/managers/base_manager.py create mode 100644 tests/worker/lib/managers/test_base_manager.py diff --git a/src/mavedb/worker/lib/managers/base_manager.py b/src/mavedb/worker/lib/managers/base_manager.py new file mode 100644 index 00000000..08da4670 --- /dev/null +++ b/src/mavedb/worker/lib/managers/base_manager.py @@ -0,0 +1,41 @@ +"""Base manager class providing common database transaction handling. + +This module provides the BaseManager class that encapsulates common database +session management patterns used across all manager classes. +""" + +import logging +from abc import ABC + +from arq import ArqRedis +from sqlalchemy.orm import Session + +logger = logging.getLogger(__name__) + + +class BaseManager(ABC): + """Base class for all manager classes providing common interface. + + Provides standardized pattern for initializing a manager with database + and Redis connections. + + Features: + - Common initialization pattern + + Attributes: + db: SQLAlchemy database session for queries and transactions + redis: ARQ Redis client for job queue operations + """ + + def __init__(self, db: Session, redis: ArqRedis): + """Initialize base manager with database and Redis connections. + + Args: + db: SQLAlchemy database session for job and pipeline queries + redis: ARQ Redis client for job queue operations + + Raises: + DatabaseConnectionError: Cannot connect to database + """ + self.db = db + self.redis = redis diff --git a/tests/worker/lib/managers/test_base_manager.py b/tests/worker/lib/managers/test_base_manager.py new file mode 100644 index 00000000..7f5c3a91 --- /dev/null +++ b/tests/worker/lib/managers/test_base_manager.py @@ -0,0 +1,19 @@ +# ruff: noqa: E402 +import pytest + +pytest.importorskip("arq") + +from mavedb.worker.lib.managers.base_manager import BaseManager + + +@pytest.mark.integration +class TestInitialization: + """Tests for BaseManager initialization.""" + + def test_initialization(self, session, arq_redis): + """Test that BaseManager initializes with db and redis attributes.""" + + manager = BaseManager(db=session, redis=arq_redis) + + assert manager.db == session + assert manager.redis == arq_redis From 05fc52ba0aa62cf994ea5df2c1f2b98d714c0df1 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Mon, 12 Jan 2026 10:17:46 -0800 Subject: [PATCH 06/70] feat: Job manager class, supporting utilities, and unit tests Add comprehensive job lifecycle management with status-based completion: * Implement convenience methods for common job outcomes: - succeed_job() for successful completion - fail_job() for error handling with exception details - cancel_job() for user/system cancellation - skip_job() for conditional job skipping * Enhance progress tracking with increment_progress() and set_progress_total() * Add comprehensive error handling with specific exception types * Improve job state validation and atomic transaction handling * Implement extensive test coverage for all job operations --- src/mavedb/worker/lib/__init__.py | 7 + src/mavedb/worker/lib/managers/__init__.py | 61 + src/mavedb/worker/lib/managers/constants.py | 35 + src/mavedb/worker/lib/managers/exceptions.py | 36 + src/mavedb/worker/lib/managers/job_manager.py | 840 +++++++ src/mavedb/worker/lib/managers/types.py | 14 + src/mavedb/worker/lib/py.typed | 0 tests/worker/lib/conftest.py | 191 ++ tests/worker/lib/managers/test_job_manager.py | 2132 +++++++++++++++++ 9 files changed, 3316 insertions(+) create mode 100644 src/mavedb/worker/lib/__init__.py create mode 100644 src/mavedb/worker/lib/managers/__init__.py create mode 100644 src/mavedb/worker/lib/managers/constants.py create mode 100644 src/mavedb/worker/lib/managers/exceptions.py create mode 100644 src/mavedb/worker/lib/managers/job_manager.py create mode 100644 src/mavedb/worker/lib/managers/types.py create mode 100644 src/mavedb/worker/lib/py.typed create mode 100644 tests/worker/lib/conftest.py create mode 100644 tests/worker/lib/managers/test_job_manager.py diff --git a/src/mavedb/worker/lib/__init__.py b/src/mavedb/worker/lib/__init__.py new file mode 100644 index 00000000..e011ce18 --- /dev/null +++ b/src/mavedb/worker/lib/__init__.py @@ -0,0 +1,7 @@ +""" +Worker library modules for job management and coordination. +""" + +from .managers import JobManager + +__all__ = ["JobManager"] diff --git a/src/mavedb/worker/lib/managers/__init__.py b/src/mavedb/worker/lib/managers/__init__.py new file mode 100644 index 00000000..f5a21c38 --- /dev/null +++ b/src/mavedb/worker/lib/managers/__init__.py @@ -0,0 +1,61 @@ +"""Manager classes and shared utilities for job coordination. + +This package provides managers for job lifecycle,along with shared constants, exceptions, +and types used across the worker system. + +Main Classes: + JobManager: Individual job lifecycle management + +Shared Utilities: + Constants: Job statuses, timeouts, retry limits + Exceptions: Standardized error hierarchy + Types: TypedDict definitions and common type hints + +Example Usage: + >>> from mavedb.worker.lib.managers import JobManager + >>> from mavedb.worker.lib.managers import JobStateError, TERMINAL_JOB_STATUSES + >>> + >>> job_manager = JobManager(db, redis, job_id) + >>> pipeline_manager = PipelineManager(db, redis) + >>> + >>> # Individual job operations + >>> job_manager.start_job() + >>> job_manager.succeed_job({"output": "success"}) + >>> +""" + +# Main manager classes +# Commonly used constants +# Main manager classes +from .base_manager import BaseManager +from .constants import ( + ACTIVE_JOB_STATUSES, + TERMINAL_JOB_STATUSES, +) + +# Exception hierarchy +from .exceptions import ( + DatabaseConnectionError, + JobStateError, + JobTransitionError, +) +from .job_manager import JobManager + +# Type definitions +from .types import JobResultData, RetryHistoryEntry + +__all__ = [ + # Main classes + "BaseManager", + "JobManager", + # Constants + "ACTIVE_JOB_STATUSES", + "TERMINAL_JOB_STATUSES", + # Exceptions + "DatabaseConnectionError", + "JobStateError", + "JobTransitionError", + # Types + "JobResultData", + "RetryHistoryEntry", +] diff --git a/src/mavedb/worker/lib/managers/constants.py b/src/mavedb/worker/lib/managers/constants.py new file mode 100644 index 00000000..acc95236 --- /dev/null +++ b/src/mavedb/worker/lib/managers/constants.py @@ -0,0 +1,35 @@ +"""Constants for job management and pipeline coordination. + +This module defines commonly used job status groupings that are used throughout +the job management system for state validation, dependency checking, and +pipeline coordination. +""" + +from mavedb.models.enums.job_pipeline import FailureCategory, JobStatus + +# Job status constants for common groupings +STARTABLE_JOB_STATUSES = [JobStatus.QUEUED, JobStatus.PENDING] +"""Job statuses that can be transitioned to RUNNING state.""" + +COMPLETED_JOB_STATUSES = [JobStatus.SUCCEEDED, JobStatus.FAILED] +"""Job statuses indicating finished execution (completed states).""" + +TERMINAL_JOB_STATUSES = [JobStatus.SUCCEEDED, JobStatus.FAILED, JobStatus.CANCELLED, JobStatus.SKIPPED] +"""Job statuses indicating finished execution (terminal states).""" + +CANCELLED_JOB_STATUSES = [JobStatus.CANCELLED, JobStatus.SKIPPED, JobStatus.FAILED] +"""Job statuses that should stop execution (termination conditions).""" + +RETRYABLE_JOB_STATUSES = [JobStatus.FAILED, JobStatus.CANCELLED, JobStatus.SKIPPED] +"""Job statuses that can be retried.""" + +ACTIVE_JOB_STATUSES = [JobStatus.PENDING, JobStatus.QUEUED, JobStatus.RUNNING] +"""Job statuses that can be cancelled/skipped when pipeline fails.""" + +RETRYABLE_FAILURE_CATEGORIES = ( + FailureCategory.NETWORK_ERROR, + FailureCategory.TIMEOUT, + FailureCategory.SERVICE_UNAVAILABLE, + # TODO: Add more retryable exception types as needed +) +"""Failure categories that are considered retryable errors.""" diff --git a/src/mavedb/worker/lib/managers/exceptions.py b/src/mavedb/worker/lib/managers/exceptions.py new file mode 100644 index 00000000..7a0ede6b --- /dev/null +++ b/src/mavedb/worker/lib/managers/exceptions.py @@ -0,0 +1,36 @@ +""" +Manager Exceptions for explicit error handling. +""" + + +class ManagerError(Exception): + """Base exception for Manager operations.""" + + pass + + +## Job Manager Exceptions + + +class JobManagerError(ManagerError): + """Job Manager specific errors.""" + + pass + + +class JobStateError(JobManagerError): + """Critical job state operations failed - database issues preventing state persistence.""" + + pass + + +class JobTransitionError(JobManagerError): + """Job is in wrong state for requested operation.""" + + pass + + +class DatabaseConnectionError(JobStateError): + """Database connection issues preventing any operations.""" + + pass diff --git a/src/mavedb/worker/lib/managers/job_manager.py b/src/mavedb/worker/lib/managers/job_manager.py new file mode 100644 index 00000000..1da3e581 --- /dev/null +++ b/src/mavedb/worker/lib/managers/job_manager.py @@ -0,0 +1,840 @@ +"""Job lifecycle management for individual job state transitions. + +This module provides the JobManager class for managing individual job state transitions +with atomic operations and explicit error handling to ensure data consistency. +Pipeline coordination is handled separately by the PipelineManager. + +Example usage: + >>> from mavedb.worker.lib.job_manager import JobManager + >>> + >>> # Initialize with database and Redis connections + >>> job_manager = JobManager(db_session, redis_client, job_id=123) + >>> + >>> # Start job execution + >>> job_manager.start_job() + >>> + >>> # Update progress during execution + >>> job_manager.update_progress(50, 100, "Processing variants...") + >>> + >>> # Complete job (pipeline coordination handled separately) + >>> job_manager.complete_job( + ... status=JobStatus.SUCCEEDED, + ... result={"variants_processed": 1000} + ... ) + +Error Handling: + The JobManager uses specific exception types to distinguish between different + failure modes, allowing callers to implement appropriate recovery strategies: + + - DatabaseConnectionError: Database connectivity issues + - JobStateError: Critical state persistence failures + - JobTransitionError: Invalid state transitions +""" + +import logging +import traceback +from datetime import datetime +from typing import Optional + +from arq import ArqRedis +from sqlalchemy import select +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session +from sqlalchemy.orm.attributes import flag_modified + +from mavedb.models.enums.job_pipeline import FailureCategory, JobStatus +from mavedb.models.job_run import JobRun +from mavedb.worker.lib.managers.base_manager import BaseManager +from mavedb.worker.lib.managers.constants import ( + CANCELLED_JOB_STATUSES, + RETRYABLE_FAILURE_CATEGORIES, + RETRYABLE_JOB_STATUSES, + STARTABLE_JOB_STATUSES, + TERMINAL_JOB_STATUSES, +) +from mavedb.worker.lib.managers.exceptions import ( + DatabaseConnectionError, + JobStateError, + JobTransitionError, +) +from mavedb.worker.lib.managers.types import JobResultData, RetryHistoryEntry + +logger = logging.getLogger(__name__) + + +class JobManager(BaseManager): + """Manages individual job lifecycle with atomic state transitions. + + The JobManager provides a high-level interface for managing individual job execution + while ensuring database consistency. It handles job state transitions, progress updates, + and retry logic. Pipeline coordination is handled separately by the PipelineManager. + + Key Features: + - Atomic state transitions with rollback on failure + - Explicit exception handling for different failure modes + - Progress tracking and retry mechanisms + - Automatic session cleanup on object manipulation failures + - Focus on individual job lifecycle only + + Note: + To avoid persisting inconsistent job state to the database, any failures + during job manipulation (e.g., fetching job, updating fields) will result + in a safe rollback of the current transaction. This ensures that partial + updates do not corrupt job state. This manager DOES NOT COMMIT database + changes, only flushes them. Commit responsibility lies with the caller. + + Usage Patterns: + + Basic job execution: + >>> manager = JobManager(db, redis, job_id=123) + >>> manager.start_job() + >>> manager.update_progress(25, message="Starting validation") + >>> manager.succeed_job(result={"count": 100}) + + Progress tracking convenience: + >>> manager.set_progress_total(1000, "Processing 1000 records") + >>> for record in records: + ... process_record(record) + ... manager.increment_progress() # Increment by 1 + ... if manager.is_cancelled(): + ... break + + Job failure handling: + >>> try: + ... process_data() + ... except ValidationError as e: + ... manager.fail_job(error=e, result={"partial_results": partial_data}) + + Direct completion control: + >>> manager.complete_job(status=JobStatus.SUCCEEDED, result=data) + + Error handling: + >>> try: + ... manager.complete_job(status=JobStatus.SUCCEEDED, result=data) + ... except JobStateError as e: + ... logger.critical(f"Critical state failure: {e}") + ... # Job completion failed - state not saved + + Job retry: + >>> try: + ... manager.retry_job(reason="Transient network error") + ... except JobTransitionError as e: + ... logger.error(f"Cannot retry job in current state: {e}") + + Exception Hierarchy: + - DatabaseConnectionError: Cannot connect to database + - JobStateError: Critical state persistence failures + - JobTransitionError: Invalid state transitions (e.g., start already running job) + + Thread Safety: + JobManager is not thread-safe. Each instance should be used by a single + worker thread and should not be shared across concurrent operations. + """ + + def __init__(self, db: Session, redis: ArqRedis, job_id: int): + """Initialize JobManager for a specific job. + + Args: + db: Active SQLAlchemy session for database operations. Session should + be configured for the appropriate database and have proper + transaction isolation. + redis: ARQ Redis client for job queue operations. Must be connected + and ready for enqueue operations. + job_id: Unique identifier of the job to manage. Must correspond to + an existing JobRun record in the database. + + Raises: + DatabaseConnectionError: If the job cannot be fetched from database, + indicating connectivity issues or invalid job_id. + + Example: + >>> db_session = get_database_session() + >>> redis_client = get_arq_redis_client() + >>> manager = JobManager(db_session, redis_client, 12345) + >>> # Manager is now ready to handle job 12345 + """ + super().__init__(db, redis) + + self.job_id = job_id + job = self.get_job() + self.pipeline_id = job.pipeline_id if job else None + + def start_job(self) -> None: + """Mark job as started and initialize execution tracking. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Transitions job from QUEUED or PENDING to RUNNING state, setting start + timestamp and a default progress message. This method should be called + once at the beginning of job execution. + + State Changes: + - Sets status to JobStatus.RUNNING + - Records started_at timestamp + - Initializes progress to 0/100 + - Sets progress_message to "Job began execution" + + Raises: + DatabaseConnectionError: Cannot fetch job from database + JobStateError: Cannot save job start state to database + JobTransitionError: Job not in valid state to start (must be QUEUED or PENDING) + + Example: + >>> manager = JobManager(db, redis, 123) + >>> manager.start_job() # Job 123 now marked as RUNNING + >>> # Proceed with job execution logic... + """ + job_run = self.get_job() + if job_run.status not in STARTABLE_JOB_STATUSES: + raise JobTransitionError(f"Cannot start job {self.job_id} from status {job_run.status}") + + try: + job_run.status = JobStatus.RUNNING + job_run.started_at = datetime.now() + job_run.progress_message = "Job began execution" + except (AttributeError, TypeError, KeyError, ValueError) as e: + logger.debug(f"Failed to update job start state for job {self.job_id}: {e}") + raise JobStateError(f"Failed to update job start state: {e}") + + logger.info(f"Job {self.job_id} marked as started") + + def complete_job(self, status: JobStatus, result: JobResultData, error: Optional[Exception] = None) -> None: + """Mark job as completed with the specified final status. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Transitions job to the passed terminal status (SUCCEEDED, FAILED, CANCELLED, SKIPPED), + recording the finished_at timestamp, result data, and error details if applicable. + + Args: + status: Final job status - must be a terminal status + (SUCCEEDED, FAILED, CANCELLED, SKIPPED) + result: JobResultData to store in metadata. Should be JSON-serializable + dictionary containing any outputs, metrics, or artifacts produced. + error: Exception that caused job failure, if applicable. Error details + will be logged and stored for debugging. + + State Changes: + - Sets status to the specified terminal status + - Sets finished_at timestamp + - Stores result in job metadata + - Records error details if provided and status is FAILED + + Raises: + DatabaseConnectionError: Cannot fetch job or connect to database + JobStateError: Cannot save job completion state - critical error + JobTransitionError: Invalid terminal status provided + + Examples: + Successful completion: + >>> result_data = {"records_processed": 1500, "errors": 0} + >>> manager.complete_job( + ... status=JobStatus.SUCCEEDED, + ... result=result_data + ... ) + + Failed completion with error: + >>> try: + ... process_data() + ... except ValidationError as e: + ... manager.complete_job( + ... status=JobStatus.FAILED, + ... result={"partial_results": data}, + ... error=e + ... ) + + Note: + Job completion state is saved independently of any pipeline + coordination. Use PipelineManager for coordinating dependent jobs. + """ + # Validate terminal status + if status not in TERMINAL_JOB_STATUSES: + raise JobTransitionError( + f"Cannot commplete job to status: {status}. Must complete to a terminal status: {TERMINAL_JOB_STATUSES}" + ) + + job_run = self.get_job() + try: + job_run.status = status + job_run.metadata_["result"] = result + job_run.finished_at = datetime.now() + + if status == JobStatus.SUCCEEDED: + job_run.progress_message = "Job completed successfully" + elif status == JobStatus.CANCELLED: + job_run.progress_message = "Job cancelled" + elif status == JobStatus.SKIPPED: + job_run.progress_message = "Job skipped" + elif status == JobStatus.FAILED: + job_run.progress_message = "Job failed" + job_run.failure_category = FailureCategory.UNKNOWN + + if error: + job_run.error_message = str(error) + job_run.error_traceback = traceback.format_exc() + # TODO: Classify failure category based on error type + job_run.failure_category = FailureCategory.UNKNOWN + + except (AttributeError, TypeError, KeyError, ValueError) as e: + logger.debug(f"Failed to update job completion state for job {self.job_id}: {e}") + raise JobStateError(f"Failed to update job completion state: {e}") + + logger.info(f"Job {self.job_id} marked as {status.value}") + + def fail_job(self, error: Exception, result: JobResultData) -> None: + """Mark job as failed and record error details. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Convenience method for marking job execution as failed. This is equivalent + to calling complete_job(status=JobStatus.FAILED, error=error, result=result) but + provides clearer intent and a more focused API for failure scenarios. + + Args: + error: Exception that caused job failure. Error details will be logged + and stored for debugging. Used to populate error message and traceback. + result: Partial results to store in metadata. Should be + JSON-serializable dictionary containing any partial outputs, + metrics, or debugging information produced before failure. + + Raises: + DatabaseConnectionError: Cannot fetch job or connect to database + JobStateError: Cannot save job completion state - critical error + + Examples: + Basic failure with exception: + >>> try: + ... validate_data(input_data) + ... except ValidationError as e: + ... manager.fail_job(error=e) + + Failure with partial results: + >>> try: + ... results = process_batch(records) + ... except ProcessingError as e: + ... partial_results = {"processed": len(results), "failed_at": e.record_id} + ... manager.fail_job(error=e, result=partial_results) + + Note: + This method is equivalent to complete_job(status=JobStatus.FAILED, error=error, result=result). + Use this method when job failure is the primary outcome to make intent clearer. + """ + self.complete_job(status=JobStatus.FAILED, result=result, error=error) + + def succeed_job(self, result: JobResultData) -> None: + """Mark job as succeeded and record results. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Convenience method for marking job execution as successful. This is equivalent + to calling complete_job(status=JobStatus.SUCCEEDED, result=result) but provides clearer + intent and a more focused API for success scenarios. + + Args: + result: Job result data to store in metadata. Should be JSON-serializable + dictionary containing any outputs, metrics, or artifacts produced. + + Raises: + DatabaseConnectionError: Cannot fetch job or connect to database + JobStateError: Cannot save job completion state - critical error + + Examples: + Successful completion: + >>> result_data = {"records_processed": 1500, "errors": 0, "duration": 45.2} + >>> manager.succeed_job(result=result_data) + + Success with metrics: + >>> metrics = { + ... "input_count": 10000, + ... "output_count": 9847, + ... "skipped": 153, + ... "processing_time": 120.5, + ... "memory_peak": "2.1GB" + ... } + >>> manager.succeed_job(result=metrics) + + Note: + This method is equivalent to complete_job(status=JobStatus.SUCCEEDED, result=result). + Use this method when job success is the primary outcome to make intent clearer. + """ + self.complete_job(status=JobStatus.SUCCEEDED, result=result) + + def cancel_job(self, result: JobResultData) -> None: + """Mark job as cancelled. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Convenience method for marking job execution as cancelled. This is equivalent + to calling complete_job(status=JobStatus.CANCELLED, result=result) but provides + clearer intent and a more focused API for cancellation scenarios. + + Args: + reason: Human-readable reason for cancellation (e.g., "user_requested", + "pipeline_cancelled", "timeout"). Used for debugging and audit trails. + result: Partial results to store in metadata. Should be JSON-serializable + dictionary containing any partial outputs or cancellation details. + If None, defaults to cancellation metadata. + + Raises: + DatabaseConnectionError: Cannot fetch job or connect to database + JobStateError: Cannot save job completion state - critical error + + Examples: + Basic cancellation: + >>> manager.cancel_job({"reason": "user_requested"}) + + Note: + This method is equivalent to complete_job(status=JobStatus.CANCELLED, result=result). + Use this method when job cancellation is the primary outcome to make intent clearer. + """ + self.complete_job(status=JobStatus.CANCELLED, result=result) + + def skip_job(self, result: JobResultData) -> None: + """Mark job as skipped. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Convenience method for marking job as skipped (not executed). This is equivalent + to calling complete_job(status=JobStatus.SKIPPED, result=result) but provides + clearer intent and a more focused API for skip scenarios. + + Args: + result: Skip details to store in metadata. Should be JSON-serializable + dictionary containing skip reason and context. + If None, defaults to skip metadata. + + Raises: + DatabaseConnectionError: Cannot fetch job or connect to database + JobStateError: Cannot save job completion state - critical error + + Examples: + Basic skip: + >>> manager.skip_job({"reason": "No work to perform"}) + + Note: + This method is equivalent to complete_job(status=JobStatus.SKIPPED, result=result). + Use this method when job skipping is the primary outcome to make intent clearer. + """ + self.complete_job(status=JobStatus.SKIPPED, result=result) + + def prepare_retry(self, reason: str = "retry_requested") -> None: + """Prepare a failed job for retry by resetting state to PENDING. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Resets a failed job back to PENDING status so it can be re-enqueued + by the pipeline coordination system. This is similar to job completion + but transitions to PENDING instead of a terminal state. + + Args: + reason: Human-readable reason for the retry (e.g., "transient_network_error", + "memory_limit_exceeded"). Used for debugging and audit trails. + + State Changes: + - Increments retry_count + - Resets status from FAILED, SKIPPED, CANCELLED to PENDING + - Clears error_message, error_traceback, failure_category + - Clears finished_at timestamp + - Adds retry attempt to metadata history + + Raises: + DatabaseConnectionError: Cannot fetch job from database + JobTransitionError: Job not in FAILED state (cannot retry) + JobStateError: Cannot save retry state changes + + Examples: + Basic retry preparation: + >>> try: + ... manager.prepare_retry("network_timeout") + ... except JobTransitionError: + ... logger.error("Cannot retry job - not in failed state") + + Conditional retry with limits: + >>> job = manager.get_job() + >>> if job and job.retry_count < 3: + ... manager.prepare_retry(f"attempt_{job.retry_count + 1}") + ... # PipelineManager will handle enqueueing + ... else: + ... logger.error("Max retries exceeded") + + Retry History: + Each retry attempt is recorded in job metadata with: + - retry_attempt: Sequential attempt number + - timestamp: When retry was initiated + - result: Previous execution results (for debugging) + - reason: Provided retry reason + + Note: + After calling this method, use PipelineManager.enqueue_ready_jobs() + to actually enqueue the job for execution. + """ + job_run = self.get_job() + if job_run.status not in RETRYABLE_JOB_STATUSES: + raise JobTransitionError(f"Cannot retry job {self.job_id} due to invalid state ({job_run.status})") + + try: + job_run.status = JobStatus.PENDING + current_result: JobResultData = job_run.metadata_.get("result", {}) + job_run.retry_count = (job_run.retry_count or 0) + 1 + job_run.progress_message = "Job retry prepared" + job_run.error_message = None + job_run.error_traceback = None + job_run.failure_category = None + job_run.finished_at = None + job_run.started_at = None + + # Add retry history - metadata manipulation (risky) + retry_history: list[RetryHistoryEntry] = job_run.metadata_.setdefault("retry_history", []) + retry_history.append( + { + "attempt": job_run.retry_count, + "timestamp": datetime.now().isoformat(), + "result": current_result, + "reason": reason, + } + ) + job_run.metadata_.pop("result", None) # Clear previous result + flag_modified(job_run, "metadata_") + + except (AttributeError, TypeError, KeyError, ValueError) as e: + logger.debug(f"Failed to update job retry state for job {self.job_id}: {e}") + raise JobStateError(f"Failed to update job retry state: {e}") + + logger.info(f"Job {self.job_id} successfully prepared for retry (attempt {job_run.retry_count})") + + def prepare_queue(self) -> None: + """Prepare job for enqueueing by setting QUEUED status. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Transitions job from PENDING to QUEUED status before ARQ enqueueing. + This ensures proper state tracking and validates the transition. + + Raises: + JobTransitionError: Job not in PENDING state + JobStateError: Cannot save state change + """ + job_run = self.get_job() + if job_run.status != JobStatus.PENDING: + raise JobTransitionError(f"Cannot queue job {self.job_id} from status {job_run.status}") + + try: + job_run.status = JobStatus.QUEUED + job_run.progress_message = "Job queued for execution" + except (AttributeError, TypeError, KeyError, ValueError) as e: + logger.debug(f"Failed to prepare job {self.job_id} for queueing: {e}") + raise JobStateError(f"Failed to update job queue state: {e}") + + logger.debug(f"Job {self.job_id} prepared for queueing") + + def reset_job(self) -> None: + """Reset job to initial state for re-execution. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Resets all job state fields to their initial values, allowing the job + to be re-executed from scratch. This is useful for testing or manual + re-runs of jobs without retaining any prior execution history. + + State Changes: + - Sets status to PENDING + - Clears started_at and finished_at timestamps + - Resets progress to 0/100 with default message + - Clears error details and failure category + - Resets retry_count to 0 + - Clears metadata + + Raises: + DatabaseConnectionError: Cannot fetch job from database + JobStateError: Cannot save reset state changes + Examples: + Basic job reset: + >>> manager.reset_job() + >>> # Job is now reset to initial state for re-execution + """ + job_run = self.get_job() + try: + job_run.status = JobStatus.PENDING + job_run.started_at = None + job_run.finished_at = None + job_run.progress_current = None + job_run.progress_total = None + job_run.progress_message = None + job_run.error_message = None + job_run.error_traceback = None + job_run.failure_category = None + job_run.retry_count = 0 + job_run.metadata_ = {} + + except (AttributeError, TypeError, KeyError, ValueError) as e: + logger.debug(f"Failed to update job reset state for job {self.job_id}: {e}") + raise JobStateError(f"Failed to reset job state: {e}") + + logger.info(f"Job {self.job_id} successfully reset to initial state") + + def update_progress(self, current: int, total: int = 100, message: Optional[str] = None) -> None: + """Update job progress information during execution. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Provides real-time progress updates for long-running jobs. Progress updates + are best-effort operations that won't interrupt job execution if they fail. + This allows jobs to continue even if progress tracking has issues. + + Args: + current: Current progress value (e.g., records processed so far) + total: Total expected progress value (default: 100 for percentage) + message: Optional human-readable progress description + + Examples: + Percentage-based progress: + >>> manager.update_progress(25, 100, "Validating input data") + >>> manager.update_progress(50, 100, "Processing records") + >>> manager.update_progress(100, 100, "Finalizing results") + + Count-based progress: + >>> total_records = 50000 + >>> for i, record in enumerate(records): + ... process_record(record) + ... if i % 1000 == 0: # Update every 1000 records + ... manager.update_progress( + ... current=i, + ... total=total_records, + ... message=f"Processed {i}/{total_records} records" + ... ) + + Handling progress failures: + >>> try: + ... manager.update_progress(75, message="Almost done") + ... except DatabaseConnectionError: + ... logger.debug("Progress update failed, continuing job") + ... # Job continues normally + + Note: + Progress updates are non-blocking and failure-tolerant. If a progress + update fails, the job may choose to continue execution normally. Failed + progress updates are logged at debug level. + """ + job_run = self.get_job() + try: + job_run.progress_current = current + job_run.progress_total = total + if message: + job_run.progress_message = message + + except (AttributeError, TypeError, KeyError, ValueError) as e: + logger.debug(f"Failed to update job progress for job {self.job_id}: {e}") + raise JobStateError(f"Failed to update job progress state: {e}") + + logger.debug(f"Updated progress for job {self.job_id}: {current}/{total}") + + def update_status_message(self, message: str) -> None: + """Update job status message without changing progress. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Convenience method for updating the progress message while keeping + current progress values unchanged. Useful for status updates during + long-running operations. + + Args: + message: Human-readable status message describing current activity + + Raises: + DatabaseConnectionError: Cannot fetch job from database + JobStateError: Cannot save status message update + + Example: + >>> manager.update_status_message("Connecting to external API...") + >>> # Do API work + >>> manager.update_status_message("Processing API response...") + """ + job_run = self.get_job() + try: + job_run.progress_message = message + except (AttributeError, TypeError, KeyError, ValueError) as e: + logger.debug(f"Failed to update job status message for job {self.job_id}: {e}") + raise JobStateError(f"Failed to update job status message state: {e}") + + logger.debug(f"Updated status message for job {self.job_id}: {message}") + + def increment_progress(self, amount: int = 1, message: Optional[str] = None) -> None: + """Increment job progress by a specified amount. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Convenience method for incrementing progress without needing to track + the current progress value. Useful for batch processing where you want + to increment by 1 for each item processed. + + Args: + amount: Amount to increment progress by (default: 1) + message: Optional message to update along with progress + + Raises: + DatabaseConnectionError: Cannot fetch job from database + JobStateError: Cannot save progress update + + Examples: + >>> # Process items one by one + >>> for item in items: + ... process_item(item) + ... manager.increment_progress() # Increment by 1 + + >>> # Process in batches + >>> for batch in batches: + ... process_batch(batch) + ... manager.increment_progress(len(batch), f"Processed batch {i}") + """ + job_run = self.get_job() + try: + current = job_run.progress_current or 0 + job_run.progress_current = current + amount + if message: + job_run.progress_message = message + except (AttributeError, TypeError, KeyError, ValueError) as e: + logger.debug(f"Failed to increment job progress for job {self.job_id}: {e}") + raise JobStateError(f"Failed to increment job progress state: {e}") + + logger.debug(f"Incremented progress for job {self.job_id} by {amount} to {job_run.progress_current}") + + def set_progress_total(self, total: int, message: Optional[str] = None) -> None: + """Update the total progress value, useful when total becomes known during execution. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Convenience method for updating progress total when it's discovered during + job execution (e.g., after counting records to process). + + Args: + total: New total progress value + message: Optional message to update along with total + + Raises: + DatabaseConnectionError: Cannot fetch job from database + JobStateError: Cannot save progress total update + + Example: + >>> # Initially unknown total + >>> manager.start_job() + >>> records = load_all_records() # Discovers actual count + >>> manager.set_progress_total(len(records), f"Processing {len(records)} records") + """ + job_run = self.get_job() + try: + job_run.progress_total = total + if message: + job_run.progress_message = message + except (AttributeError, TypeError, KeyError, ValueError) as e: + logger.debug(f"Failed to update job progress total for job {self.job_id}: {e}") + raise JobStateError(f"Failed to update job progress total state: {e}") + + logger.debug(f"Updated progress total for job {self.job_id} to {total}") + + def is_cancelled(self) -> bool: + """Check if job has been cancelled or should stop execution. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Convenience method for checking if the job should stop execution due to + cancellation, pipeline failure, or other termination conditions. Jobs + can use this for graceful shutdown. + + Returns: + bool: True if job should stop execution, False if it can continue + + Raises: + DatabaseConnectionError: Cannot fetch job status from database + + Example: + >>> for item in large_dataset: + ... if manager.is_cancelled(): + ... logger.info("Job cancelled, stopping gracefully") + ... break + ... process_item(item) + """ + return self.get_job_status() in CANCELLED_JOB_STATUSES + + def should_retry(self) -> bool: + """Check if job should be retried based on error type and retry count. This method does + not flush or commit the database session; the caller is responsible for persisting changes. + + Convenience method that implements common retry logic. Checks current + retry count against maximum and evaluates if the error type is retryable. + + Returns: + bool: True if job should be retried, False otherwise + + Raises: + DatabaseConnectionError: Cannot fetch job info from database + + Examples: + >>> try: + ... result = do_work() + ... except NetworkError as e: + ... manager.fail_job(e, result) + ... if manager.should_retry(): + ... manager.retry_job() + ... else: + ... manager.fail_job(e, result) + """ + job_run = self.get_job() + try: + # Check if job is in FAILED state + if job_run.status != JobStatus.FAILED: + logger.debug(f"Job {self.job_id} not in FAILED state ({job_run.status}), cannot retry") + return False + + # Check retry count + current_retries = job_run.retry_count or 0 + if current_retries >= job_run.max_retries: + logger.debug(f"Job {self.job_id} has reached max retries ({current_retries}/{job_run.max_retries})") + return False + + # Check if failure category is retryable + if job_run.failure_category in RETRYABLE_FAILURE_CATEGORIES: + logger.debug( + f"Job {self.job_id} error {job_run.failure_category} is retryable ({current_retries}/{job_run.max_retries})" + ) + return True + + logger.debug(f"Job {self.job_id} error {job_run.failure_category} is not retryable") + return False + + except (AttributeError, TypeError, KeyError, ValueError) as e: + logger.debug(f"Failed to check retry eligibility for job {self.job_id}: {e}") + raise JobStateError(f"Failed to check retry eligibility state: {e}") + + def get_job_status(self) -> JobStatus: # pragma: no cover + """Get current job status for monitoring and debugging. + + Provides non-blocking access to job status without affecting job + execution. Used by decorators and monitoring systems to check job state. + + Returns: + JobStatus: Current job status (QUEUED, RUNNING, SUCCEEDED, + FAILED, etc.). + + Raises: + DatabaseConnectionError: Cannot connect to database, SQL query failed, + or job not found (indicates data inconsistency) + + Examples: + >>> status = manager.get_job_status() + >>> if status == JobStatus.RUNNING: + ... logger.info("Job is currently executing") + """ + return self.get_job().status + + def get_job(self) -> JobRun: + """Get complete job information for monitoring and debugging. + + Retrieves full JobRun instance with all fields populated. Used by + decorators and monitoring systems that need access to job metadata, + progress, error details, or other comprehensive job information. + + Returns: + JobRun: Complete job instance with all fields. + + Raises: + DatabaseConnectionError: Cannot connect to database, SQL query failed, + or job not found (indicates data inconsistency) + + Example: + >>> job = manager.get_job() + >>> if job: + ... logger.info(f"Job {job.urn} progress: {job.progress_current}/{job.progress_total}") + ... if job.error_message: + ... logger.error(f"Job error: {job.error_message}") + """ + try: + return self.db.execute(select(JobRun).where(JobRun.id == self.job_id)).scalar_one() + except SQLAlchemyError as e: + logger.debug(f"SQL query failed getting job info for {self.job_id}: {e}") + raise DatabaseConnectionError(f"Failed to fetch job {self.job_id}: {e}") diff --git a/src/mavedb/worker/lib/managers/types.py b/src/mavedb/worker/lib/managers/types.py new file mode 100644 index 00000000..023338b6 --- /dev/null +++ b/src/mavedb/worker/lib/managers/types.py @@ -0,0 +1,14 @@ +from typing import TypedDict + + +class JobResultData(TypedDict): + output: dict + logs: str + metadata: dict + + +class RetryHistoryEntry(TypedDict): + attempt: int + timestamp: str + result: JobResultData + reason: str diff --git a/src/mavedb/worker/lib/py.typed b/src/mavedb/worker/lib/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/tests/worker/lib/conftest.py b/tests/worker/lib/conftest.py new file mode 100644 index 00000000..362642f0 --- /dev/null +++ b/tests/worker/lib/conftest.py @@ -0,0 +1,191 @@ +# ruff: noqa: E402 + +""" +Test configuration and fixtures for worker lib tests. +""" + +import pytest + +pytest.importorskip("arq") # Skip tests if arq is not installed + +from datetime import datetime +from unittest.mock import Mock, patch + +from arq import ArqRedis +from sqlalchemy.orm import Session + +from mavedb.models.enums.job_pipeline import DependencyType, JobStatus, PipelineStatus +from mavedb.models.job_dependency import JobDependency +from mavedb.models.job_run import JobRun +from mavedb.models.pipeline import Pipeline +from mavedb.worker.lib.managers.job_manager import JobManager + + +@pytest.fixture +def sample_job_run(): + """Create a sample JobRun instance for testing.""" + return JobRun( + id=1, + urn="test:job:1", + job_type="test_job", + job_function="test_function", + status=JobStatus.PENDING, + pipeline_id=1, + progress_current=0, + progress_total=100, + progress_message="Ready to start", + created_at=datetime.now(), + ) + + +@pytest.fixture +def sample_dependent_job_run(): + """Create a sample dependent JobRun instance for testing.""" + return JobRun( + id=2, + urn="test:job:2", + job_type="dependent_job", + job_function="dependent_function", + status=JobStatus.PENDING, + pipeline_id=1, + progress_current=0, + progress_total=100, + progress_message="Waiting for dependency", + created_at=datetime.now(), + ) + + +@pytest.fixture +def sample_independent_job_run(): + """Create a sample independent JobRun instance for testing.""" + return JobRun( + id=3, + urn="test:job:3", + job_type="independent_job", + job_function="independent_function", + status=JobStatus.PENDING, + pipeline_id=None, + progress_current=0, + progress_total=100, + progress_message="Ready to start", + created_at=datetime.now(), + ) + + +@pytest.fixture +def sample_pipeline(): + """Create a sample Pipeline instance for testing.""" + return Pipeline( + id=1, + urn="test:pipeline:1", + name="Test Pipeline", + description="A test pipeline", + status=PipelineStatus.CREATED, + correlation_id="test_correlation_123", + created_at=datetime.now(), + ) + + +@pytest.fixture +def sample_job_dependency(): + """Create a sample JobDependency instance for testing.""" + return JobDependency( + id=2, # dependent job + depends_on_job_id=1, # depends on job 1 + dependency_type=DependencyType.SUCCESS_REQUIRED, + created_at=datetime.now(), + ) + + +@pytest.fixture +def setup_worker_db( + session, + sample_job_run, + sample_pipeline, + sample_job_dependency, + sample_dependent_job_run, + sample_independent_job_run, +): + """Set up the database with sample data for worker tests.""" + session.add(sample_pipeline) + session.add(sample_job_run) + session.add(sample_dependent_job_run) + session.add(sample_independent_job_run) + session.add(sample_job_dependency) + session.commit() + + +@pytest.fixture +def job_manager_with_mocks(session, sample_job_run, sample_pipeline): + """Create a JobManager instance with mocked dependencies.""" + # Add test data to session + session.add(sample_job_run) + session.add(sample_pipeline) + session.commit() + + # Create JobManager instance + manager = JobManager(session, sample_job_run.id) + return manager + + +@pytest.fixture +def async_context(): + """Create a mock async context similar to ARQ worker context.""" + return { + "db": None, # Will be set by specific tests + "redis": None, # Will be set by specific tests + "job_id": 1, + "state": {}, + } + + +@pytest.fixture +def mock_job_run(): + """Create a mock JobRun instance. By default, + properties are identical to a default new JobRun entered into the db + with sensible defaults for non-nullable but unset fields. + """ + return Mock( + spec=JobRun, + id=123, + urn="test:job:123", + job_type="test_job", + job_function="test_function", + status=JobStatus.PENDING, + pipeline_id=None, + priority=0, + max_retries=3, + retry_count=0, + retry_delay_seconds=None, + scheduled_at=datetime.now(), + started_at=None, + finished_at=None, + created_at=datetime.now(), + error_message=None, + error_traceback=None, + failure_category=None, + worker_id=None, + worker_host=None, + progress_current=None, + progress_total=None, + progress_message=None, + correlation_id=None, + metadata_={}, + mavedb_version=None, + ) + + +@pytest.fixture +def mock_job_manager(mock_job_run): + """Create a JobManager with mocked database and Redis dependencies.""" + mock_db = Mock(spec=Session) + mock_redis = Mock(spec=ArqRedis) + + # Don't call the real constructor since it tries to load the job from DB + manager = object.__new__(JobManager) + manager.db = mock_db + manager.redis = mock_redis + manager.job_id = mock_job_run.id + + with patch.object(manager, "get_job", return_value=mock_job_run): + yield manager diff --git a/tests/worker/lib/managers/test_job_manager.py b/tests/worker/lib/managers/test_job_manager.py new file mode 100644 index 00000000..5950a10d --- /dev/null +++ b/tests/worker/lib/managers/test_job_manager.py @@ -0,0 +1,2132 @@ +# ruff: noqa: E402 +""" +Comprehensive test suite for JobManager class. + +Tests cover all aspects of job lifecycle management, pipeline coordination, +error handling, and database interactions. +""" + +import pytest +from arq import ArqRedis + +pytest.importorskip("arq") +import re +from unittest.mock import Mock, PropertyMock, patch + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from mavedb.models.enums.job_pipeline import FailureCategory, JobStatus +from mavedb.models.job_run import JobRun +from mavedb.worker.lib.managers.constants import ( + CANCELLED_JOB_STATUSES, + RETRYABLE_FAILURE_CATEGORIES, + RETRYABLE_JOB_STATUSES, + STARTABLE_JOB_STATUSES, + TERMINAL_JOB_STATUSES, +) +from mavedb.worker.lib.managers.exceptions import ( + DatabaseConnectionError, + JobStateError, + JobTransitionError, +) +from mavedb.worker.lib.managers.job_manager import JobManager +from tests.helpers.transaction_spy import TransactionSpy + +HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION = ( + AttributeError("Mock attribute error"), + KeyError("Mock key error"), + TypeError("Mock type error"), + ValueError("Mock value error"), +) + + +@pytest.mark.integration +class TestJobManagerInitialization: + """Test JobManager initialization and setup.""" + + def test_init_with_valid_job(self, session, arq_redis, setup_worker_db, sample_job_run): + """Test successful initialization with valid job ID.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + assert manager.db == session + assert manager.job_id == sample_job_run.id + assert manager.pipeline_id == sample_job_run.pipeline_id + + def test_init_with_no_pipeline(self, session, arq_redis, setup_worker_db, sample_independent_job_run): + """Test initialization with job that has no pipeline.""" + manager = JobManager(session, arq_redis, sample_independent_job_run.id) + + assert manager.job_id == sample_independent_job_run.id + assert manager.pipeline_id is None + + def test_init_with_invalid_job_id(self, session, arq_redis): + """Test initialization failure with non-existent job ID.""" + job_id = 999 # Assuming this ID does not exist + with pytest.raises(DatabaseConnectionError, match=f"Failed to fetch job {job_id}"): + JobManager(session, arq_redis, job_id) + + +@pytest.mark.unit +class TestJobStartUnit: + """Unit tests for job start lifecycle management.""" + + @pytest.mark.parametrize( + "invalid_status", + [status for status in JobStatus._member_map_.values() if status not in STARTABLE_JOB_STATUSES], + ) + def test_start_job_raises_job_transition_error_when_managed_job_has_unstartable_status( + self, mock_job_manager, invalid_status, mock_job_run + ): + # Set initial job status to an invalid (unstartable) status. + mock_job_run.status = invalid_status + + # Start job. Verify a JobTransitionError is raised due to invalid state in the mocked + # job run. Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely. + with ( + pytest.raises( + JobTransitionError, + match=f"Cannot start job {mock_job_manager.job_id} from status {invalid_status}", + ), + TransactionSpy.spy(mock_job_manager.db), + ): + mock_job_manager.start_job() + + # Verify job state on the mocked object remains unchanged. + assert mock_job_run.status == invalid_status + assert mock_job_run.started_at is None + assert mock_job_run.progress_message is None + + @pytest.mark.parametrize( + "exception", + HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION, + ) + @pytest.mark.parametrize( + "valid_status", + [status for status in JobStatus._member_map_.values() if status in STARTABLE_JOB_STATUSES], + ) + def test_start_job_raises_job_state_error_when_handled_error_is_raised_during_object_manipulation( + self, mock_job_manager, exception, mock_job_run, valid_status + ): + """Test job start failure due to exception during job object manipulation.""" + # Set initial job status to a valid status. Job status must be startable for this test. + mock_job_run.status = valid_status + + # Trigger: If any attribute access occurs on job, raise exception. If no access, return QUEUED. + def get_or_error(*args): + if args: + raise exception + return valid_status + + # Start job. Verify a JobStateError is raised by our trigger. + # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely. + with ( + TransactionSpy.spy(mock_job_manager.db), + pytest.raises(JobStateError, match="Failed to update job start state"), + ): + type(mock_job_run).status = PropertyMock(side_effect=get_or_error) + mock_job_manager.start_job() + + # Verify job state on the mocked object remains unchanged. Although it's theoretically + # possible some job state is manipulated prior to an error being raised, our specific + # trigger should prevent any changes from being made. + assert mock_job_run.status == valid_status + assert mock_job_run.started_at is None + assert mock_job_run.progress_message is None + + @pytest.mark.parametrize( + "valid_status", + [status for status in JobStatus._member_map_.values() if status in STARTABLE_JOB_STATUSES], + ) + def test_start_job_success(self, mock_job_manager, mock_job_run, valid_status): + """Test successful job start.""" + # Set initial job status to a valid status. Job status must be startable for this test. + mock_job_run.status = valid_status + + # Start job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(mock_job_manager.db): + mock_job_manager.start_job() + + # Verify job state was updated on our mock object with expected values. + # These changes would normally be persisted by the caller after this method returns. + assert mock_job_run.status == JobStatus.RUNNING + assert mock_job_run.started_at is not None + assert mock_job_run.progress_message == "Job began execution" + + +@pytest.mark.integration +class TestJobStartIntegration: + """Integration tests for job start lifecycle management.""" + + @pytest.mark.parametrize( + "invalid_status", + [status for status in JobStatus._member_map_.values() if status not in STARTABLE_JOB_STATUSES], + ) + def test_job_exception_is_raised_when_job_has_invalid_status( + self, session, arq_redis, setup_worker_db, sample_job_run, invalid_status + ): + """Test job start failure due to invalid job status.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Manually set job to invalid status and commit changes. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + job.status = invalid_status + session.commit() + + # Start job. Verify a JobTransitionError is raised due to the previously set invalid state. + # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely. + # Although the job might still set some attributes before the error is raised, the exception + # indicates to the caller that the job was not started successfully and the transaction should be rolled back. + with ( + TransactionSpy.spy(manager.db), + pytest.raises( + JobTransitionError, + match=f"Cannot start job {sample_job_run.id} from status {invalid_status.value}", + ), + ): + manager.start_job() + + @pytest.mark.parametrize( + "valid_status", + [status for status in JobStatus._member_map_.values() if status in STARTABLE_JOB_STATUSES], + ) + def test_job_updated_successfully(self, session, arq_redis, setup_worker_db, sample_job_run, valid_status): + """Test successful job start.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Manually set job to invalid status and commit changes. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + job.status = valid_status + session.commit() + + # Start job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + manager.start_job() + + # Commit pending changes made by start job. + session.commit() + + # Verify job state was updated in transaction with expected values. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING + assert job.started_at is not None + assert job.progress_message == "Job began execution" + + +@pytest.mark.unit +class TestJobCompletionUnit: + """Unit tests for job completion lifecycle management.""" + + @pytest.mark.parametrize( + "invalid_status", + [status for status in JobStatus._member_map_.values() if status not in TERMINAL_JOB_STATUSES], + ) + def test_complete_job_raises_job_transition_error_when_managed_job_has_non_terminal_status( + self, mock_job_manager, mock_job_run, invalid_status + ): + # Set initial job status to an invalid (non-terminal) status. + mock_job_run.status = invalid_status + + # Complete job. Verify a JobTransitionError is raised due to invalid state in the mocked + # job run. Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely. + with ( + pytest.raises( + JobTransitionError, + match=re.escape( + f"Cannot commplete job to status: {invalid_status}. Must complete to a terminal status: {TERMINAL_JOB_STATUSES}" + ), + ), + TransactionSpy.spy(mock_job_manager.db), + ): + mock_job_manager.complete_job(status=invalid_status, result={}) + + # Verify job state on the mocked object remains unchanged. + assert mock_job_run.status == invalid_status + assert mock_job_run.finished_at is None + assert mock_job_run.metadata_ == {} + assert mock_job_run.progress_message is None + assert mock_job_run.error_message is None + assert mock_job_run.error_traceback is None + assert mock_job_run.failure_category is None + + @pytest.mark.parametrize( + "exception", + HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION, + ) + @pytest.mark.parametrize( + "valid_status", + [status for status in JobStatus._member_map_.values() if status in TERMINAL_JOB_STATUSES], + ) + def test_complete_job_raises_job_state_error_when_handled_error_is_raised_during_object_manipulation( + self, mock_job_manager, mock_job_run, exception, valid_status + ): + """Test job completion failure due to exception during job object manipulation.""" + # Trigger: If any attribute setting on job status, raise exception. If only accessing, return whatever the mock + # objects original status was (starting job status doesn't matter for this test). + base_status = mock_job_run.status + + def get_or_error(*args): + if args: + raise exception + return base_status + + # Complete job. Verify a JobStateError is raised by our trigger. + # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely. + with ( + pytest.raises(JobStateError, match="Failed to update job completion state"), + TransactionSpy.spy(mock_job_manager.db), + ): + type(mock_job_run).status = PropertyMock(side_effect=get_or_error) + mock_job_manager.complete_job(status=valid_status, result={}) + + # Verify job state on the mocked object remains unchanged. Although it's theoretically + # possible some job state is manipulated prior to an error being raised, our specific + # trigger should prevent any changes from being made. + assert mock_job_run.status == base_status + assert mock_job_run.finished_at is None + assert mock_job_run.metadata_ == {} + assert mock_job_run.progress_message is None + assert mock_job_run.error_message is None + assert mock_job_run.error_traceback is None + assert mock_job_run.failure_category is None + + def test_complete_job_sets_default_failure_category_when_job_failed(self, mock_job_manager, mock_job_run): + """Test job completion sets default failure category when job failed without error.""" + + # Complete job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(mock_job_manager.db): + mock_job_manager.complete_job(status=JobStatus.FAILED, result={}) + + # Verify job state was updated on our mock object with expected values. + assert mock_job_run.status == JobStatus.FAILED + assert mock_job_run.finished_at is not None + assert mock_job_run.metadata_ == {"result": {}} + assert mock_job_run.progress_message == "Job failed" + assert mock_job_run.error_message is None + assert mock_job_run.error_traceback is None + assert mock_job_run.failure_category == FailureCategory.UNKNOWN + + @pytest.mark.parametrize( + "valid_status", + [status for status in JobStatus._member_map_.values() if status in TERMINAL_JOB_STATUSES], + ) + @pytest.mark.parametrize( + "exception", + [ValueError("Test error"), None], + ) + def test_complete_job_success(self, mock_job_manager, valid_status, exception, mock_job_run): + """Test successful job completion.""" + + # Complete job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(mock_job_manager.db): + mock_job_manager.complete_job(status=valid_status, result={"output": "test"}, error=exception) + + # Verify job state was updated on our mock object with expected values. + assert mock_job_run.status == valid_status + assert mock_job_run.finished_at is not None + assert mock_job_run.metadata_["result"] == {"output": "test"} + assert mock_job_run.progress_message is not None + + # If an exception was provided, verify error fields are set appropriately. + if exception: + assert mock_job_run.error_message == str(exception) + assert mock_job_run.error_traceback is not None + assert mock_job_run.failure_category == FailureCategory.UNKNOWN + + else: + assert mock_job_run.error_message is None + assert mock_job_run.error_traceback is None + + # Proper handling of failure category only applies to FAILED status. See + # test_complete_job_sets_default_failure_category_when_job_failed for that case. + + +@pytest.mark.integration +class TestJobCompletionIntegration: + """Test job completion lifecycle management.""" + + @pytest.mark.parametrize( + "invalid_status", + [status for status in JobStatus._member_map_.values() if status not in TERMINAL_JOB_STATUSES], + ) + def test_job_exception_is_raised_when_job_has_invalid_status( + self, session, arq_redis, setup_worker_db, sample_job_run, invalid_status + ): + """Test job completion failure due to invalid job status.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Complete job. Verify a JobTransitionError is raised due to the passed invalid state. + # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely. + # Although the job might still set some attributes before the error is raised, the exception + # indicates to the caller that the job was not completed successfully and the transaction should be rolled back. + with ( + TransactionSpy.spy(manager.db), + pytest.raises( + JobTransitionError, + match=re.escape( + f"Cannot commplete job to status: {invalid_status}. Must complete to a terminal status: {TERMINAL_JOB_STATUSES}" + ), + ), + ): + manager.complete_job(status=invalid_status, result={"output": "test"}) + + @pytest.mark.parametrize( + "valid_status", + [status for status in JobStatus._member_map_.values() if status in TERMINAL_JOB_STATUSES], + ) + def test_job_updated_successfully_without_error( + self, session, arq_redis, setup_worker_db, sample_job_run, valid_status + ): + """Test successful job completion.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Complete job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + manager.complete_job(status=valid_status, result={"output": "test"}) + + # Commit pending changes made by start job. + session.flush() + + # Verify job state was updated in transaction with expected values. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + + assert job.status == valid_status + assert job.finished_at is not None + assert job.metadata_ == {"result": {"output": "test"}} + assert job.error_message is None + assert job.error_traceback is None + + # For cases where no error is provided, verify failure category is set appropriately based + # on status. We automatically set UNKNOWN for FAILED status if no error is given. + if valid_status == JobStatus.FAILED: + assert job.failure_category == FailureCategory.UNKNOWN + else: + assert job.failure_category is None + + @pytest.mark.parametrize( + "valid_status", + [status for status in JobStatus._member_map_.values() if status in TERMINAL_JOB_STATUSES], + ) + def test_job_updated_successfully_with_error( + self, session, arq_redis, setup_worker_db, sample_job_run, valid_status + ): + """Test successful job completion.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Complete job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + manager.complete_job(status=valid_status, result={"output": "test"}, error=ValueError("Test error")) + + # Commit pending changes made by start job. + session.flush() + + # Verify job state was updated in transaction with expected values. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + + assert job.status == valid_status + assert job.finished_at is not None + assert job.metadata_ == {"result": {"output": "test"}} + assert job.error_message == "Test error" + assert job.error_traceback is not None + assert job.failure_category == FailureCategory.UNKNOWN + + +@pytest.mark.unit +class TestJobFailureUnit: + """Unit tests for job failure lifecycle management.""" + + def test_fail_job_success(self, mock_job_manager, mock_job_run): + """Test that fail_job calls complete_job with status=JobStatus.FAILED.""" + + # Fail job with a test exception. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + # This convenience expects an exception to be provided. To fail a job without an exception, callers should use complete_job directly. + test_exception = Exception("Test exception") + with ( + patch.object(mock_job_manager, "complete_job", wraps=mock_job_manager.complete_job) as mock_complete_job, + TransactionSpy.spy(mock_job_manager.db), + ): + mock_job_manager.fail_job(error=test_exception, result={"output": "test"}) + + # Verify this function is a thin wrapper around complete_job with expected parameters. + mock_complete_job.assert_called_once_with( + status=JobStatus.FAILED, result={"output": "test"}, error=test_exception + ) + + # Verify job state was updated on our mock object with expected values. + assert mock_job_run.status == JobStatus.FAILED + assert mock_job_run.finished_at is not None + assert mock_job_run.metadata_ == {"result": {"output": "test"}} + assert mock_job_run.progress_message == "Job failed" + assert mock_job_run.error_message == str(test_exception) + assert mock_job_run.error_traceback is not None + assert mock_job_run.failure_category == FailureCategory.UNKNOWN + + +class TestJobFailureIntegration: + """Test job failure lifecycle management.""" + + def test_job_updated_successfully(self, session, arq_redis, setup_worker_db, sample_job_run): + """Test successful job failure.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Fail job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + manager.fail_job(result={"output": "test"}, error=ValueError("Test error")) + + # Commit pending changes made by fail job. + session.flush() + + # Verify job state was updated in transaction with expected values. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + + assert job.status == JobStatus.FAILED + assert job.finished_at is not None + assert job.metadata_ == {"result": {"output": "test"}} + assert job.progress_message == "Job failed" + assert job.error_message == "Test error" + assert job.error_traceback is not None + assert job.failure_category == FailureCategory.UNKNOWN + + +@pytest.mark.unit +class TestJobSuccessUnit: + """Unit tests for job success lifecycle management.""" + + def test_succeed_job_success(self, mock_job_manager, mock_job_run): + """Test that succeed_job calls complete_job with status=JobStatus.SUCCEEDED.""" + + # Succeed job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with ( + patch.object(mock_job_manager, "complete_job", wraps=mock_job_manager.complete_job) as mock_complete_job, + TransactionSpy.spy(mock_job_manager.db), + ): + mock_job_manager.succeed_job(result={"output": "test"}) + + # Verify this function is a thin wrapper around complete_job with expected parameters. + mock_complete_job.assert_called_once_with(status=JobStatus.SUCCEEDED, result={"output": "test"}) + + # Verify job state was updated on our mock object with expected values. + assert mock_job_run.status == JobStatus.SUCCEEDED + assert mock_job_run.finished_at is not None + assert mock_job_run.metadata_ == {"result": {"output": "test"}} + assert mock_job_run.progress_message == "Job completed successfully" + assert mock_job_run.error_message is None + assert mock_job_run.error_traceback is None + assert mock_job_run.failure_category is None + + +class TestJobSuccessIntegration: + """Test job success lifecycle management.""" + + def test_job_updated_successfully(self, session, arq_redis, setup_worker_db, sample_job_run): + """Test successful job succeeding.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Complete job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + manager.succeed_job(result={"output": "test"}) + + # Commit pending changes made by start job. + session.flush() + + # Verify job state was updated in transaction with expected values. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + + assert job.status == JobStatus.SUCCEEDED + assert job.finished_at is not None + assert job.progress_message == "Job completed successfully" + assert job.metadata_ == {"result": {"output": "test"}} + assert job.error_message is None + assert job.error_traceback is None + assert job.failure_category is None + + +@pytest.mark.unit +class TestJobCancellationUnit: + """Unit tests for job cancellation lifecycle management.""" + + def test_cancel_job_success(self, mock_job_manager, mock_job_run): + """Test that cancel_job calls complete_job with status=JobStatus.CANCELLED.""" + + # Cancel job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with ( + patch.object(mock_job_manager, "complete_job", wraps=mock_job_manager.complete_job) as mock_complete_job, + TransactionSpy.spy(mock_job_manager.db), + ): + mock_job_manager.cancel_job(result={"error": "Job was cancelled"}) + + # Verify this function is a thin wrapper around complete_job with expected parameters. + mock_complete_job.assert_called_once_with(status=JobStatus.CANCELLED, result={"error": "Job was cancelled"}) + + # Verify job state was updated on our mock object with expected values. + assert mock_job_run.status == JobStatus.CANCELLED + assert mock_job_run.finished_at is not None + assert mock_job_run.metadata_ == {"result": {"error": "Job was cancelled"}} + assert mock_job_run.progress_message == "Job cancelled" + assert mock_job_run.error_message is None + assert mock_job_run.error_traceback is None + assert mock_job_run.failure_category is None + + +class TestJobCancellationIntegration: + """Test job cancellation lifecycle management.""" + + def test_job_updated_successfully(self, session, arq_redis, setup_worker_db, sample_job_run): + """Test successful job cancellation.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Complete job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + manager.cancel_job(result={"output": "test"}) + + # Commit pending changes made by start job. + session.flush() + + # Verify job state was updated in transaction with expected values. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + + assert job.status == JobStatus.CANCELLED + assert job.progress_message == "Job cancelled" + assert job.finished_at is not None + assert job.metadata_ == {"result": {"output": "test"}} + assert job.error_message is None + assert job.error_traceback is None + assert job.failure_category is None + + +@pytest.mark.unit +class TestJobSkipUnit: + """Unit tests for job skip lifecycle management.""" + + def test_skip_job_success(self, mock_job_manager, mock_job_run): + """Test that skip_job calls complete_job with status=JobStatus.SKIPPED.""" + + # Skip job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with ( + patch.object(mock_job_manager, "complete_job", wraps=mock_job_manager.complete_job) as mock_complete_job, + TransactionSpy.spy(mock_job_manager.db), + ): + mock_job_manager.skip_job(result={"output": "test"}) + + # Verify this function is a thin wrapper around complete_job with expected parameters. + mock_complete_job.assert_called_once_with(status=JobStatus.SKIPPED, result={"output": "test"}) + + # Verify job state was updated on our mock object with expected values. + assert mock_job_run.status == JobStatus.SKIPPED + assert mock_job_run.finished_at is not None + assert mock_job_run.metadata_ == {"result": {"output": "test"}} + assert mock_job_run.progress_message == "Job skipped" + assert mock_job_run.error_message is None + assert mock_job_run.error_traceback is None + assert mock_job_run.failure_category is None + + +@pytest.mark.integration +class TestJobSkipIntegration: + """Test job skip lifecycle management.""" + + def test_job_updated_successfully(self, session, arq_redis, setup_worker_db, sample_job_run): + """Test successful job skipping.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Skip job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + manager.skip_job(result={"output": "test"}) + + # Commit pending changes made by start job. + session.flush() + + # Verify job state was updated in transaction with expected values. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + + assert job.status == JobStatus.SKIPPED + assert job.progress_message == "Job skipped" + assert job.finished_at is not None + assert job.metadata_ == {"result": {"output": "test"}} + assert job.error_message is None + assert job.error_traceback is None + assert job.failure_category is None + + +@pytest.mark.unit +class TestPrepareRetryUnit: + """Unit tests for job retry lifecycle management.""" + + @pytest.mark.parametrize( + "invalid_status", + [status for status in JobStatus._member_map_.values() if status not in RETRYABLE_JOB_STATUSES], + ) + def test_prepare_retry_raises_job_transition_error_when_managed_job_has_unretryable_status( + self, mock_job_manager, invalid_status, mock_job_run + ): + # Set initial job status to an invalid (unretryable) status. + mock_job_run.status = invalid_status + + # Preprare retry job. Verify a JobTransitionError is raised due to invalid state in the mocked + # job run. Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely. + with ( + pytest.raises( + JobTransitionError, + match=re.escape(f"Cannot retry job {mock_job_manager.job_id} due to invalid state ({invalid_status})"), + ), + TransactionSpy.spy(mock_job_manager.db), + ): + mock_job_manager.prepare_retry() + + # Verify job state on the mocked object remains unchanged. + assert mock_job_run.status == invalid_status + assert mock_job_run.retry_count == 0 + assert mock_job_run.started_at is None + assert mock_job_run.progress_message is None + assert mock_job_run.error_message is None + assert mock_job_run.error_traceback is None + assert mock_job_run.failure_category is None + assert mock_job_run.finished_at is None + assert mock_job_run.metadata_ == {} + + @pytest.mark.parametrize( + "exception", + HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION, + ) + def test_prepare_retry_raises_job_state_error_when_handled_error_is_raised_during_object_manipulation( + self, mock_job_manager, exception, mock_job_run + ): + """Test job prepare retry failure due to exception during job object manipulation.""" + # Set initial job status to FAILED. Job status must be retryable for this test. + initial_status = JobStatus.FAILED + mock_job_run.status = initial_status + + # Trigger: If any attribute access occurs on job, raise exception. If no access, return FAILED. + def get_or_error(*args): + if args: + raise exception + return initial_status + + # Prepare retry. Verify a JobStateError is raised by our trigger. + # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely. + with ( + TransactionSpy.spy(mock_job_manager.db), + pytest.raises( + JobStateError, + match="Failed to update job retry state", + ), + ): + type(mock_job_run).status = PropertyMock(side_effect=get_or_error) + mock_job_manager.prepare_retry() + + # Verify job state on the mocked object remains unchanged. Although it's theoretically + # possible some job state is manipulated prior to an error being raised, our specific + # trigger should prevent any changes from being made. + assert mock_job_run.status == JobStatus.FAILED + assert mock_job_run.retry_count == 0 + assert mock_job_run.started_at is None + assert mock_job_run.progress_message is None + assert mock_job_run.error_message is None + assert mock_job_run.error_traceback is None + assert mock_job_run.failure_category is None + assert mock_job_run.finished_at is None + assert mock_job_run.metadata_ == {} + + def test_prepare_retry_success(self, mock_job_manager, mock_job_run): + """Test successful job prepare retry.""" + # Set initial job status to FAILED. Job status must be retryable for this test. + mock_job_run.status = JobStatus.FAILED + + # Prepare retry. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + # Mock the flag_modified function: mock objects don't have _sa_instance_state attribute required by SQLAlchemy + # funcs and it's easier to mock the functions that manipulate the state than to fully mock the state itself. + with ( + patch("mavedb.worker.lib.managers.job_manager.flag_modified") as mock_flag_modified, + TransactionSpy.spy(mock_job_manager.db), + ): + mock_job_manager.prepare_retry() + + # Verify flag_modified was called for metadata_ field. + mock_flag_modified.assert_called_once_with(mock_job_run, "metadata_") + + # Verify job state was updated on our mock object with expected values. + # These changes would normally be persisted by the caller after this method returns. + assert mock_job_run.status == JobStatus.PENDING + assert mock_job_run.retry_count == 1 + assert mock_job_run.progress_message == "Job retry prepared" + assert mock_job_run.error_message is None + assert mock_job_run.error_traceback is None + assert mock_job_run.failure_category is None + assert mock_job_run.finished_at is None + assert mock_job_run.metadata_["retry_history"] is not None + assert mock_job_run.started_at is None + assert mock_job_run.metadata_.get("result") is None + + +@pytest.mark.integration +class TestPrepareRetryIntegration: + """Test job retry lifecycle management.""" + + @pytest.mark.parametrize( + "job_status", + [status for status in JobStatus._member_map_.values() if status not in RETRYABLE_JOB_STATUSES], + ) + def test_prepare_retry_failed_due_to_invalid_status( + self, session, arq_redis, setup_worker_db, sample_job_run, job_status + ): + """Test job retry failure due to invalid job status.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Update job to non-failed state + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + job.status = job_status + session.commit() + + # Prepare retry job. Verify a JobTransitionError is raised due to the passed invalid state. + # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely. + with ( + TransactionSpy.spy(manager.db), + pytest.raises(JobTransitionError, match=f"Cannot retry job {job.id} due to invalid state \({job.status}\)"), + ): + manager.prepare_retry() + + def test_prepare_retry_success(self, session, arq_redis, setup_worker_db, sample_job_run): + """Test successful job retry.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Manually set job to FAILED status and commit changes. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + job.status = JobStatus.FAILED + session.commit() + + # Prepare retry. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + manager.prepare_retry() + + # Commit pending changes made by start job. + session.commit() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + assert job.retry_count == 1 + assert job.progress_message == "Job retry prepared" + assert job.error_message is None + assert job.error_traceback is None + assert job.failure_category is None + assert job.finished_at is None + assert job.metadata_["retry_history"] is not None + + +@pytest.mark.unit +class TestPrepareQueueUnit: + """Unit tests for job prepare for queue lifecycle management.""" + + @pytest.mark.parametrize( + "invalid_status", + [status for status in JobStatus._member_map_.values() if status != JobStatus.PENDING], + ) + def test_prepare_queue_raises_job_transition_error_when_managed_job_has_unretryable_status( + self, mock_job_manager, invalid_status, mock_job_run + ): + """Test job prepare queue failure due to invalid job status.""" + # Set initial job status to an invalid (non-pending) status. + mock_job_run.status = invalid_status + + # Prepare queue job. Verify a JobTransitionError is raised due to invalid state in the mocked + # job run. Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely. + with ( + pytest.raises( + JobTransitionError, + match=re.escape(f"Cannot queue job {mock_job_manager.job_id} from status {invalid_status}"), + ), + TransactionSpy.spy(mock_job_manager.db), + ): + mock_job_manager.prepare_queue() + + # Verify job state on the mocked object remains unchanged. + assert mock_job_run.status == invalid_status + assert mock_job_run.progress_message is None + + @pytest.mark.parametrize( + "exception", + HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION, + ) + def test_prepare_queue_raises_job_state_error_when_handled_error_is_raised_during_object_manipulation( + self, mock_job_manager, exception, mock_job_run + ): + """Test job prepare queue failure due to exception during job object manipulation.""" + # Set initial job status to PENDING. Job status must be valid for this test. + initial_status = JobStatus.PENDING + mock_job_run.status = initial_status + + # Trigger: If any attribute access occurs on job, raise exception. If no access, return FAILED. + def get_or_error(*args): + if args: + raise exception + return initial_status + + # Prepare queue. Verify a JobStateError is raised by our trigger. + # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely. + with ( + TransactionSpy.spy(mock_job_manager.db), + pytest.raises( + JobStateError, + match="Failed to update job queue state", + ), + ): + type(mock_job_run).status = PropertyMock(side_effect=get_or_error) + mock_job_manager.prepare_queue() + + # Verify job state on the mocked object remains unchanged. Although it's theoretically + # possible some job state is manipulated prior to an error being raised, our specific + # trigger should prevent any changes from being made. + assert mock_job_run.status == JobStatus.PENDING + assert mock_job_run.progress_message is None + + def test_prepare_queue_success(self, mock_job_manager, mock_job_run): + """Test successful job prepare queue.""" + # Set initial job status to PENDING. Job status must be valid for this test. + mock_job_run.status = JobStatus.PENDING + + # Prepare queue. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + # Mock the flag_modified function: mock objects don't have _sa_instance_state attribute required by SQLAlchemy + # funcs and it's easier to mock the functions that manipulate the state than to fully mock the state itself. + with ( + patch.object(mock_job_manager, "get_job", return_value=mock_job_run), + TransactionSpy.spy(mock_job_manager.db), + ): + mock_job_manager.prepare_queue() + + # Verify job state was updated on our mock object with expected values. + # These changes would normally be persisted by the caller after this method returns. + assert mock_job_run.status == JobStatus.QUEUED + assert mock_job_run.progress_message == "Job queued for execution" + + +@pytest.mark.integration +class TestPrepareQueue: + """Test job prepare for queue lifecycle management.""" + + @pytest.mark.parametrize( + "job_status", + [status for status in JobStatus._member_map_.values() if status != JobStatus.PENDING], + ) + def test_prepare_queue_failed_due_to_invalid_status( + self, session, arq_redis, setup_worker_db, sample_job_run, job_status + ): + """Test job prepare for queue failure due to invalid job status.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Update job to invalid state + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + job.status = job_status + session.flush() + + # Prepare queue job. Verify a JobTransitionError is raised due to the passed invalid state. + # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely. + with ( + TransactionSpy.spy(manager.db), + pytest.raises( + JobTransitionError, + match=f"Cannot queue job {job.id} from status {job.status}", + ), + ): + manager.prepare_queue() + + def test_prepare_queue_success(self, session, arq_redis, setup_worker_db, sample_job_run): + """Test successful job prepare for queue.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Sample run should be in PENDING state from fixture setup, but verify to be sure. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING, "Sample job run must be in PENDING state for this test." + + # Prepare queue. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + manager.prepare_queue() + + # Commit pending changes made by start job. + session.flush() + + # Verify job state was updated in transaction with expected values. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + assert job.progress_message == "Job queued for execution" + + +@pytest.mark.unit +class TestResetJobUnit: + """Unit tests for job reset lifecycle management.""" + + @pytest.mark.parametrize( + "exception", + HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION, + ) + def test_reset_job_raises_job_state_error_when_handled_error_is_raised_during_object_manipulation( + self, mock_job_manager, exception, mock_job_run + ): + """Test job reset job failure due to exception during job object manipulation.""" + + # Trigger: If any attribute setting occurs on job, raise exception. Otherwise return FAILED. + # Set initial job status to FAILED. Job status is unimportant for this test (all statuses are resettable). + initial_status = JobStatus.FAILED + mock_job_run.status = initial_status + + def get_or_error(*args): + if args: + raise exception + return initial_status + + # Prepare queue. Verify a JobStateError is raised by our trigger. + # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely. + with ( + TransactionSpy.spy(mock_job_manager.db), + pytest.raises( + JobStateError, + match="Failed to reset job state", + ), + ): + type(mock_job_run).status = PropertyMock(side_effect=get_or_error) + mock_job_manager.reset_job() + + # Verify job state on the mocked object remains unchanged. Although it's theoretically + # possible some job state is manipulated prior to an error being raised, our specific + # trigger should prevent any changes from being made. + assert mock_job_run.status == JobStatus.FAILED + assert mock_job_run.started_at is None + assert mock_job_run.finished_at is None + assert mock_job_run.progress_current is None + assert mock_job_run.progress_total is None + assert mock_job_run.progress_message is None + assert mock_job_run.error_message is None + assert mock_job_run.error_traceback is None + assert mock_job_run.failure_category is None + assert mock_job_run.retry_count == 0 + assert mock_job_run.metadata_ == {} + + def test_reset_job_success(self, mock_job_manager, mock_job_run): + """Test successful job reset.""" + # Set initial job status to provided status. All statuses are resettable, so the actual status is not important. + mock_job_run.status = JobStatus.FAILED + + # Prepare queue. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(mock_job_manager.db): + mock_job_manager.reset_job() + + # Verify job state was updated on our mock object with expected values. + # These changes would normally be persisted by the caller after this method returns. + assert mock_job_run.status == JobStatus.PENDING + assert mock_job_run.started_at is None + assert mock_job_run.finished_at is None + assert mock_job_run.progress_current is None + assert mock_job_run.progress_total is None + assert mock_job_run.progress_message is None + assert mock_job_run.error_message is None + assert mock_job_run.error_traceback is None + assert mock_job_run.failure_category is None + assert mock_job_run.retry_count == 0 + assert mock_job_run.metadata_ == {} + + +@pytest.mark.integration +class TestResetJobIntegration: + """Test job reset lifecycle management.""" + + def test_reset_job_success(self, session, arq_redis, setup_worker_db, sample_job_run): + """Test successful job reset.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Manually set job to a non-pending status and set various fields to non-default values. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + job.status = JobStatus.FAILED + job.started_at = "2023-12-31T23:59:59Z" + job.finished_at = "2024-01-01T00:00:00Z" + job.progress_current = 50 + job.progress_total = 100 + job.progress_message = "Halfway done" + job.error_message = "Test error message" + job.error_traceback = "Test error traceback" + job.failure_category = FailureCategory.UNKNOWN + job.retry_count = 2 + job.metadata_ = {"result": {}, "retry_history": [{"attempt": 1}, {"attempt": 2}]} + session.commit() + + # Reset job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + manager.reset_job() + + # Commit pending changes made by reset job. + session.commit() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + assert job.progress_current is None + assert job.progress_total is None + assert job.progress_message is None + assert job.error_message is None + assert job.error_traceback is None + assert job.failure_category is None + assert job.started_at is None + assert job.finished_at is None + assert job.retry_count == 0 + assert job.metadata_.get("retry_history") is None + + +@pytest.mark.unit +class TestJobProgressUpdateUnit: + """Unit tests for job progress update lifecycle management.""" + + @pytest.mark.parametrize( + "exception", + HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION, + ) + def test_update_progress_raises_job_state_error_when_handled_error_is_raised_during_object_manipulation( + self, mock_job_manager, exception, mock_job_run + ): + """Test job progress update failure due to exception during job object manipulation.""" + # Trigger: If any attribute setting occurs on job progress, raise exception. If only access, return initial progress. + initial_progress_current = mock_job_run.progress_current + + def get_or_error(*args): + if args: + raise exception + return initial_progress_current + + # Prepare queue. Verify a JobStateError is raised by our trigger. + # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely. + with ( + TransactionSpy.spy(mock_job_manager.db), + pytest.raises( + JobStateError, + match="Failed to update job progress", + ), + ): + type(mock_job_run).progress_current = PropertyMock(side_effect=get_or_error) + mock_job_manager.update_progress(50, 100, "Halfway done") + + # Verify job state on the mocked object remains unchanged. + assert mock_job_run.progress_current is None + assert mock_job_run.progress_total is None + assert mock_job_run.progress_message is None + + def test_update_progress_success(self, mock_job_manager, mock_job_run): + """Test successful job progress update.""" + + # Update progress. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(mock_job_manager.db): + mock_job_manager.update_progress(50, 100, "Halfway done") + + # Verify job state was updated on our mock object with expected values. + # These changes would normally be persisted by the caller after this method returns. + assert mock_job_run.progress_current == 50 + assert mock_job_run.progress_total == 100 + assert mock_job_run.progress_message == "Halfway done" + + def test_update_progress_does_not_overwrite_old_message_when_no_new_message_is_provided( + self, mock_job_manager, mock_job_run + ): + """Test successful job progress update without message.""" + + # Set initial progress message to verify it is not overwritten. + mock_job_run.progress_message = "Old message" + + # Update progress without message. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(mock_job_manager.db): + mock_job_manager.update_progress(75, 200) + + # Verify job state was updated on our mock object with expected values. + # These changes would normally be persisted by the caller after this method returns. + assert mock_job_run.progress_current == 75 + assert mock_job_run.progress_total == 200 + assert mock_job_run.progress_message == "Old message" # Message should remain unchanged from initial set. + + +@pytest.mark.integration +class TestJobProgressUpdateIntegration: + """Test job progress update lifecycle management.""" + + def test_update_progress_success(self, session, arq_redis, setup_worker_db, sample_job_run): + """Test successful progress update.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Set initial progress to None to verify update. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + job.progress_current = None + job.progress_total = None + job.progress_message = None + session.commit() + + # Update progress. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + manager.update_progress(50, 100, "Halfway done") + + # Commit pending changes made by update progress. + session.commit() + + # Verify job state was updated in transaction with expected values. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.progress_current == 50 + assert job.progress_total == 100 + assert job.progress_message == "Halfway done" + + def test_update_progress_success_does_not_overwrite_old_message_when_no_new_message_is_provided( + self, session, arq_redis, setup_worker_db, sample_job_run + ): + """Test successful progress update without message.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Set initial progress to None to verify update. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + job.progress_current = None + job.progress_total = None + job.progress_message = "Old message" + session.commit() + + # Update progress without message. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + manager.update_progress(75, 200) + + # Commit pending changes made by update progress. + session.flush() + + # Verify job state was updated in transaction with expected values. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.progress_current == 75 + assert job.progress_total == 200 + assert job.progress_message == "Old message" # Message should remain unchanged from initial set. + + +@pytest.mark.unit +class TestJobProgressStatusUpdateUnit: + """Unit tests for job progress status update lifecycle management.""" + + @pytest.mark.parametrize( + "exception", + HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION, + ) + def test_update_status_message_raises_job_state_error_when_handled_error_is_raised_during_object_manipulation( + self, mock_job_manager, exception, mock_job_run + ): + """Test job status message update failure due to exception during job object manipulation.""" + # Trigger: If any attribute setting occurs on job progress message, raise exception. If only access, return initial message. + initial_progress_message = mock_job_run.progress_message + + def get_or_error(*args): + if args: + raise exception + return initial_progress_message + + # Prepare queue. Verify a JobStateError is raised by our trigger. + # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely. + with ( + TransactionSpy.spy(mock_job_manager.db), + pytest.raises( + JobStateError, + match="Failed to update job status message", + ), + ): + type(mock_job_run).progress_message = PropertyMock(side_effect=get_or_error) + mock_job_manager.update_status_message("New status message") + + # Verify job state on the mocked object remains unchanged. + assert mock_job_run.progress_message == initial_progress_message + + def test_update_status_message_success(self, mock_job_manager, mock_job_run): + """Test successful job status message update.""" + + # Update status message. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(mock_job_manager.db): + mock_job_manager.update_status_message("New status message") + + # Verify job state was updated on our mock object with expected values. + # These changes would normally be persisted by the caller after this method returns. + assert mock_job_run.progress_message == "New status message" + + +@pytest.mark.integration +class TestJobProgressStatusUpdate: + """Test job progress status update lifecycle management.""" + + def test_update_status_message_success(self, session, arq_redis, setup_worker_db, sample_job_run): + """Test successful status message update.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Set initial progress message to verify update. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + job.progress_message = "Old status message" + session.commit() + + # Update status message. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + manager.update_status_message("New status message") + + # Commit pending changes made by update status message. + session.commit() + + # Verify job state was updated in transaction with expected values. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.progress_message == "New status message" + + +@pytest.mark.unit +class TestJobProgressIncrementationUnit: + """Unit tests for job progress incrementation lifecycle management.""" + + @pytest.mark.parametrize( + "exception", + HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION, + ) + def test_increment_progress_raises_job_state_error_when_handled_error_is_raised_during_object_manipulation( + self, mock_job_manager, exception, mock_job_run + ): + """Test job progress incrementation failure due to exception during job object manipulation.""" + # Trigger: If any attribute access occurs on job progress, raise exception. If no access, return initial progress. + initial_progress_current = mock_job_run.progress_current + + def get_or_error(*args): + if args: + raise exception + return initial_progress_current + + # Prepare queue. Verify a JobStateError is raised by our trigger. + # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely. + with ( + TransactionSpy.spy(mock_job_manager.db), + pytest.raises( + JobStateError, + match="Failed to increment job progress", + ), + ): + type(mock_job_run).progress_current = PropertyMock(side_effect=get_or_error) + mock_job_manager.increment_progress(10, "Incrementing progress") + + # Verify job state on the mocked object remains unchanged. + assert mock_job_run.progress_current is None + assert mock_job_run.progress_message is None + + def test_increment_progress_success(self, mock_job_manager, mock_job_run): + """Test successful job progress incrementation.""" + + # Increment progress. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(mock_job_manager.db): + mock_job_manager.increment_progress(10, "Incrementing progress") + + # Verify job state was updated on our mock object with expected values. + # These changes would normally be persisted by the caller after this method returns. + assert mock_job_run.progress_current == 10 + assert mock_job_run.progress_message == "Incrementing progress" + + def test_increment_progress_success_old_message_is_not_overwritten_when_none_provided( + self, mock_job_manager, mock_job_run + ): + """Test successful job progress incrementation without message.""" + + # Set initial progress message to verify it is not overwritten. + mock_job_run.progress_message = "Old message" + + # Increment progress without message. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(mock_job_manager.db): + mock_job_manager.increment_progress(15) + + # Verify job state was updated on our mock object with expected values. + # These changes would normally be persisted by the caller after this method returns. + assert mock_job_run.progress_current == 15 + assert mock_job_run.progress_message == "Old message" # Message should remain unchanged from initial set. + + +@pytest.mark.integration +class TestJobProgressIncrementationIntegration: + """Test job progress incrementation lifecycle management.""" + + @pytest.mark.parametrize( + "msg", + [None, "Incremented progress successfully"], + ) + def test_increment_progress_success(self, session, arq_redis, setup_worker_db, sample_job_run, msg): + """Test successful progress incrementation.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Set initial progress to 0 to verify incrementation. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + job.progress_current = 0 + job.progress_total = 100 + job.progress_message = "Test incrementation message" + session.commit() + + # Increment progress. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + manager.increment_progress(10, msg) + + # Commit pending changes made by increment progress. + session.commit() + + # Verify job state was updated in transaction with expected values. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.progress_current == 10 + assert job.progress_total == 100 + assert job.progress_message == ( + msg if msg else "Test incrementation message" + ) # Message should remain unchanged if None + + def test_increment_progress_success_multiple_times(self, session, arq_redis, setup_worker_db, sample_job_run): + """Test successful progress incrementation multiple times.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Set initial progress to 0 to verify incrementation. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + job.progress_current = 0 + job.progress_total = 100 + session.commit() + + # Increment progress multiple times. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + manager.increment_progress(20) + manager.increment_progress(30) + + # Commit pending changes made by increment progress. + session.commit() + + # Verify job state was updated in transaction with expected values. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.progress_current == 50 + assert job.progress_total == 100 + + def test_increment_progress_success_exceeding_total(self, session, arq_redis, setup_worker_db, sample_job_run): + """Test successful progress incrementation exceeding total.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Set initial progress to 0 to verify incrementation. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + job.progress_current = 0 + job.progress_total = 100 + session.commit() + + # Increment progress exceeding total. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + manager.increment_progress(150) + + # Commit pending changes made by increment progress. + session.commit() + + # Verify job state was updated in transaction with expected values. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.progress_current == 150 + assert job.progress_total == 100 + + +class TestJobProgressTotalUpdateUnit: + """Unit tests for job progress total update lifecycle management.""" + + @pytest.mark.parametrize( + "exception", + HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION, + ) + def test_set_progress_total_raises_job_state_error_when_handled_error_is_raised_during_object_manipulation( + self, mock_job_manager, exception, mock_job_run + ): + """Test job progress total update failure due to exception during job object manipulation.""" + # Trigger: If any attribute access occurs on job progress total, raise exception. If no access, return initial total. + initial_progress_total = mock_job_run.progress_total + + def get_or_error(*args): + if args: + raise exception + return initial_progress_total + + # Prepare queue. Verify a JobStateError is raised by our trigger. + # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely. + with ( + TransactionSpy.spy(mock_job_manager.db), + pytest.raises( + JobStateError, + match="Failed to update job progress total state", + ), + ): + type(mock_job_run).progress_total = PropertyMock(side_effect=get_or_error) + mock_job_manager.set_progress_total(200) + + # Verify job state on the mocked object remains unchanged. + assert mock_job_run.progress_total == initial_progress_total + + def test_set_progress_total_success(self, mock_job_manager, mock_job_run): + """Test successful job progress total update.""" + + # Set progress total. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(mock_job_manager.db): + mock_job_manager.set_progress_total(200) + + # Verify job state was updated on our mock object with expected values. + # These changes would normally be persisted by the caller after this method returns. + assert mock_job_run.progress_total == 200 + + def test_set_progress_total_does_not_overwrite_old_message_when_no_new_message_is_provided( + self, mock_job_manager, mock_job_run + ): + """Test successful job progress total update without message.""" + + # Set initial progress message to verify it is not overwritten. + mock_job_run.progress_message = "Old message" + + # Set progress total without message. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(mock_job_manager.db): + mock_job_manager.set_progress_total(300) + + # Verify job state was updated on our mock object with expected values. + # These changes would normally be persisted by the caller after this method returns. + assert mock_job_run.progress_total == 300 + assert mock_job_run.progress_message == "Old message" # Message should remain unchanged from initial set. + + +@pytest.mark.integration +class TestJobProgressTotalUpdateIntegration: + """Test job progress total update lifecycle management.""" + + def test_set_progress_total_success(self, session, arq_redis, setup_worker_db, sample_job_run): + """Test successful progress total update.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Set initial progress total and message to verify update. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + job.progress_total = 100 + job.progress_message = "Ready to start" + session.commit() + + # Set progress total. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + manager.set_progress_total(200, message="Updated total progress") + + # Commit pending changes made by set progress total. + session.commit() + + # Verify job state was updated in transaction with expected values. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.progress_total == 200 + assert job.progress_message == "Updated total progress" + + +@pytest.mark.unit +class TestJobIsCancelledUnit: + """Unit tests for job is_cancelled lifecycle management.""" + + @pytest.mark.parametrize( + "status,expected_result", + [(status, status in CANCELLED_JOB_STATUSES) for status in JobStatus._member_map_.values()], + ) + def test_is_cancelled_success_not_cancelled(self, mock_job_manager, mock_job_run, status, expected_result): + """Test successful is_cancelled check when not cancelled.""" + # Set initial job status to a non-cancelled status. + mock_job_run.status = status + + # Check is_cancelled. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(mock_job_manager.db): + result = mock_job_manager.is_cancelled() + + assert result == expected_result + + +@pytest.mark.integration +class TestJobIsCancelledIntegration: + """Test job is_cancelled lifecycle management.""" + + @pytest.mark.parametrize( + "job_status", + [status for status in JobStatus._member_map_.values() if status in CANCELLED_JOB_STATUSES], + ) + def test_is_cancelled_success_cancelled(self, session, arq_redis, setup_worker_db, sample_job_run, job_status): + """Test successful is_cancelled check when cancelled.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Mark the job as cancelled in the database + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + job.status = job_status + session.commit() + + # Check is_cancelled. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + result = manager.is_cancelled() + + # Verify the job is marked as cancelled. This method requires no persistance. + assert result is True + + @pytest.mark.parametrize( + "job_status", + [status for status in JobStatus._member_map_.values() if status not in CANCELLED_JOB_STATUSES], + ) + def test_is_cancelled_success_not_cancelled(self, session, arq_redis, setup_worker_db, sample_job_run, job_status): + """Test successful is_cancelled check when not cancelled.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Mark the job as not cancelled in the database + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + job.status = job_status + session.commit() + + # Check is_cancelled. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + result = manager.is_cancelled() + + # Verify the job is not marked as cancelled. This method requires no persistance. + assert result is False + + +@pytest.mark.unit +class TestJobShouldRetryUnit: + """Unit tests for job should_retry lifecycle management.""" + + @pytest.mark.parametrize( + "exception", + [ + pytest.param( + exc, + marks=pytest.mark.skip( + reason=( + "AttributeError is not propagated by mock objects: " + "Python's attribute lookup swallows AttributeError and mock returns a new mock instead. " + "See unittest.mock docs for details." + ) + ) + if isinstance(exc, AttributeError) + else (), + # ^ Only mark AttributeError for skip, others run as normal + ) + for exc in HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION + ], + ) + def test_should_retry_raises_job_state_error_when_handled_error_is_raised_during_object_manipulation( + self, mock_job_manager, exception, mock_job_run + ): + """ + Test should_retry check failure due to exception during job object manipulation. + + AttributeError is skipped in this test because Python's mock machinery swallows + AttributeError raised by property getters and instead returns a new mock, so the + exception is not propagated as expected. See unittest.mock documentation for details. + ^^ or something like that... don't ask me to explain why. + """ + + # Trigger: If any attribute access occurs on job, raise exception. + def get_or_error(*args): + raise exception + + # Remove any instance attribute that could shadow the property + if "status" in mock_job_run.__dict__: + del mock_job_run.__dict__["status"] + + # In cases where we want to raise on attribute access, we need to override the entire property + # or else AttributeError won't be raised due to some internal Mock nuances I don't understand. + type(mock_job_run).status = property(get_or_error) + + # Prepare queue. Verify a JobStateError is raised by our trigger. + # Spy on the transaction to ensure nothing is flushed/rolled back/committed prematurely. + with ( + TransactionSpy.spy(mock_job_manager.db), + pytest.raises( + JobStateError, + match="Failed to check retry eligibility state", + ), + ): + mock_job_manager.should_retry() + + @pytest.mark.parametrize( + "status,expected_result", + [ + (JobStatus.SUCCEEDED, False), + (JobStatus.CANCELLED, False), + (JobStatus.QUEUED, False), + (JobStatus.RUNNING, False), + (JobStatus.PENDING, False), + ], + ) + def test_should_retry_success_for_non_failed_statuses( + self, mock_job_manager, mock_job_run, status, expected_result + ): + """Test successful should_retry check.""" + # Set initial job status to provided status. + mock_job_run.status = status + + # Check should_retry. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(mock_job_manager.db): + result = mock_job_manager.should_retry() + + # Verify the result matches expected. + assert result == expected_result + + @pytest.mark.parametrize( + "retry_count,max_retries,failure_category,expected_result", + ( + [(0, 3, cat, True) for cat in RETRYABLE_FAILURE_CATEGORIES] # Initial retry, + + [(2, 3, RETRYABLE_FAILURE_CATEGORIES[0], True)] # Within retry limit (barely) + + [(3, 3, RETRYABLE_FAILURE_CATEGORIES[0], False)] # Exceeded retries + + [ + (1, 3, cat, False) + for cat in FailureCategory._member_map_.values() + if cat not in RETRYABLE_FAILURE_CATEGORIES + ] # Non-retryable failure categories + ), + ) + def test_should_retry_success_for_failed_status( + self, mock_job_manager, mock_job_run, retry_count, max_retries, failure_category, expected_result + ): + """Test successful should_retry check for failed status.""" + # Set initial job status to FAILED with provided parameters. + mock_job_run.status = JobStatus.FAILED + mock_job_run.retry_count = retry_count + mock_job_run.max_retries = max_retries + mock_job_run.failure_category = failure_category + + # Check should_retry. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(mock_job_manager.db): + result = mock_job_manager.should_retry() + + # Verify the result matches expected. + assert result == expected_result + + +@pytest.mark.integration +class TestJobShouldRetryIntegration: + """Test job should_retry lifecycle management.""" + + @pytest.mark.parametrize( + "job_status", + [status for status in JobStatus._member_map_.values() if status != JobStatus.FAILED], + ) + def test_should_retry_success_non_failed_jobs_should_not_retry( + self, session, arq_redis, setup_worker_db, sample_job_run, job_status + ): + """Test successful should_retry check (only jobs in failed states may retry).""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Update job to non-failed state + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + job.status = job_status + session.commit() + + # Check should_retry. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + result = manager.should_retry() + + # Verify the job should not retry. This method requires no persistance. + assert result is False + + def test_should_retry_success_exceeded_retry_attempts_should_not_retry( + self, session, arq_redis, setup_worker_db, sample_job_run + ): + """Test successful should_retry check with no retry attempts left.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Update job to failed state with no retries left + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + job.status = JobStatus.FAILED + job.max_retries = 3 + job.retry_count = 3 + session.commit() + + # Check should_retry. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + result = manager.should_retry() + + # Verify the job should not retry. This method requires no persistance. + assert result is False + + def test_should_retry_success_failure_category_is_not_retryable( + self, session, arq_redis, setup_worker_db, sample_job_run + ): + """Test successful should_retry check with non-retryable failure category.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Update job to failed state with non-retryable failure category + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + job.status = JobStatus.FAILED + job.max_retries = 3 + job.retry_count = 1 + job.failure_category = FailureCategory.UNKNOWN + session.commit() + + # Check should_retry. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + result = manager.should_retry() + + # Verify the job should not retry. This method requires no persistance. + assert result is False + + def test_should_retry_success(self, session, arq_redis, setup_worker_db, sample_job_run): + """Test successful should_retry check with retryable failure category.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Update job to failed state with retryable failure category + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + job.status = JobStatus.FAILED + job.max_retries = 3 + job.retry_count = 1 + job.failure_category = RETRYABLE_FAILURE_CATEGORIES[0] + session.commit() + + # Check should_retry. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + result = manager.should_retry() + + # Verify the job should retry. This method requires no persistance. + assert result is True + + +@pytest.mark.unit +class TestGetJobUnit: + """Unit tests for job retrieval.""" + + def test_get_job_wraps_database_connection_error_when_encounters_sqlalchemy_error(self, mock_job_run): + """Test job retrieval failure during job fetch.""" + + # Prepare mock JobManager with mocked DB session that will raise SQLAlchemyError on query. + # We don't use the default fixture here since it usually wraps this function. + mock_db = Mock(spec=Session) + mock_redis = Mock(spec=ArqRedis) + manager = object.__new__(JobManager) + manager.db = mock_db + manager.redis = mock_redis + manager.job_id = mock_job_run.id + + with ( + TransactionSpy.mock_database_execution_failure(manager.db), + pytest.raises(DatabaseConnectionError, match=f"Failed to fetch job {mock_job_run.id}"), + ): + manager.get_job() + + +@pytest.mark.integration +class TestGetJobIntegration: + """Test job retrieval.""" + + def test_get_job_success(self, session, arq_redis, setup_worker_db, sample_job_run): + """Test successful job retrieval.""" + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Retrieve job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + with TransactionSpy.spy(manager.db): + job = manager.get_job() + + # Verify the retrieved job matches expected. + assert job.id == sample_job_run.id + assert job.status == JobStatus.PENDING + + def test_get_job_raises_job_not_found_error_when_job_does_not_exist(self, session, arq_redis, setup_worker_db): + """Test job retrieval failure when job does not exist.""" + with pytest.raises(DatabaseConnectionError, match="Failed to fetch job 9999"), TransactionSpy.spy(session): + JobManager(session, arq_redis, job_id=9999) # Non-existent job ID + + +@pytest.mark.integration +class TestJobManagerJob: + """Test overall job lifecycle management.""" + + def test_full_successful_job_lifecycle(self, session, arq_redis, setup_worker_db, sample_job_run): + """Test full job lifecycle from start to completion.""" + # Pre-manager: Job is created in DB in Pending state. Verify initial state. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING, "Initial job status should be PENDING" + + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Prepare job to be enqueued + with TransactionSpy.spy(manager.db): + manager.prepare_queue() + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED, "Job status should be QUEUED after preparing queue" + + # Start job + with TransactionSpy.spy(manager.db): + manager.start_job() + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING, "Job status should be RUNNING after starting" + assert job.started_at is not None, "Job started_at should be set after starting" + + # Set initial progress + with TransactionSpy.spy(manager.db): + manager.update_progress(0, 100, "Job started") + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.progress_current == 0 + assert job.progress_total == 100 + assert job.progress_message == "Job started" + + # Update status message + with TransactionSpy.spy(manager.db): + manager.update_status_message("Began processing data") + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.progress_message == "Began processing data" + + # Set progress total + with TransactionSpy.spy(manager.db): + manager.set_progress_total(200, "Set total work units") + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.progress_total == 200 + assert job.progress_message == "Set total work units" + + # Increment progress + with TransactionSpy.spy(manager.db): + manager.increment_progress(100, "Halfway done") + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.progress_current == 100 + assert job.progress_message == "Halfway done" + + # Increment progress again + with TransactionSpy.spy(manager.db): + manager.increment_progress(100, "All done") + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.progress_current == 200 + assert job.progress_message == "All done" + + # Complete job + with TransactionSpy.spy(manager.db): + manager.succeed_job(result={"output": "success"}) + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.SUCCEEDED + assert job.finished_at is not None + + # Verify job is not cancelled and should not retry + assert manager.is_cancelled() is False + assert manager.should_retry() is False + + # Verify final job state + final_job = manager.get_job() + assert final_job.status == JobStatus.SUCCEEDED + assert final_job.progress_current == 200 + assert final_job.progress_total == 200 + assert final_job.progress_message == "Job completed successfully" + + def test_full_cancelled_job_lifecycle(self, session, arq_redis, setup_worker_db, sample_job_run): + """Test full job lifecycle for a cancelled job.""" + # Pre-manager: Job is created in DB in Pending state. Verify initial state. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING, "Initial job status should be PENDING" + + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Prepare job to be enqueued + with TransactionSpy.spy(manager.db): + manager.prepare_queue() + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED, "Job status should be QUEUED after preparing queue" + + # Start job + with TransactionSpy.spy(manager.db): + manager.start_job() + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING + + # Cancel job + with TransactionSpy.spy(manager.db): + manager.cancel_job({"reason": "User requested cancellation"}) + session.flush() + + # Verify job is cancelled + assert manager.is_cancelled() is True + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.CANCELLED + assert job.finished_at is not None + assert job.progress_message == "Job cancelled" + + def test_full_skipped_job_lifecycle(self, session, arq_redis, setup_worker_db, sample_job_run): + """Test full job lifecycle for a skipped job.""" + # Pre-manager: Job is created in DB in Pending state. Verify initial state. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING, "Initial job status should be PENDING" + + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Skip job + with TransactionSpy.spy(manager.db): + manager.skip_job(result={"reason": "Precondition not met"}) + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.SKIPPED + assert job.finished_at is not None + assert job.progress_message == "Job skipped" + + def test_full_failed_job_lifecycle(self, session, arq_redis, setup_worker_db, sample_job_run): + """Test full job lifecycle for a failed job.""" + # Pre-manager: Job is created in DB in Pending state. Verify initial state. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING, "Initial job status should be PENDING" + + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Prepare job to be enqueued + with TransactionSpy.spy(manager.db): + manager.prepare_queue() + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED, "Job status should be QUEUED after preparing queue" + + # Start job + with TransactionSpy.spy(manager.db): + manager.start_job() + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING + + # Fail job + with TransactionSpy.spy(manager.db): + manager.fail_job( + error=Exception("An error occurred"), + result={"details": "Traceback details here"}, + ) + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.FAILED + assert job.finished_at is not None + assert job.error_message == "An error occurred" + assert job.error_traceback is not None + + def test_full_retried_job_lifecycle(self, session, arq_redis, setup_worker_db, sample_job_run): + """Test full job lifecycle for a retried job.""" + # Pre-manager: Job is created in DB in Pending state. Verify initial state. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING, "Initial job status should be PENDING" + + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Prepare job to be enqueued + with TransactionSpy.spy(manager.db): + manager.prepare_queue() + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED, "Job status should be QUEUED after preparing queue" + + # Start job + with TransactionSpy.spy(manager.db): + manager.start_job() + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING + + # Fail job + with TransactionSpy.spy(manager.db): + manager.fail_job( + error=Exception("Temporary error"), + result={"details": "Traceback details here"}, + ) + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.FAILED + + # TODO: Use some failure method added later to set failure category to retryable during the + # call to fail_job above. For now, we manually set it here. + job.failure_category = RETRYABLE_FAILURE_CATEGORIES[0] + session.commit() + + # Should retry + assert manager.should_retry() is True + + # Prepare retry + with TransactionSpy.spy(manager.db): + manager.prepare_retry() + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + assert job.retry_count == 1 + + def test_full_reset_job_lifecycle(self, session, arq_redis, setup_worker_db, sample_job_run): + """Test full job lifecycle for a reset job.""" + # Pre-manager: Job is created in DB in Pending state. Verify initial state. + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING, "Initial job status should be PENDING" + + manager = JobManager(session, arq_redis, sample_job_run.id) + + # Prepare job to be enqueued + with TransactionSpy.spy(manager.db): + manager.prepare_queue() + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED, "Job status should be QUEUED after preparing queue" + + # Start job + with TransactionSpy.spy(manager.db): + manager.start_job() + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING + + # Fail job + with TransactionSpy.spy(manager.db): + manager.fail_job( + error=Exception("Some error"), + result={"details": "Traceback details here"}, + ) + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.FAILED + + # Retry job + with TransactionSpy.spy(manager.db): + manager.prepare_retry() + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + assert job.retry_count == 1 + + # Queeue job again + with TransactionSpy.spy(manager.db): + manager.prepare_queue() + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED, "Job status should be QUEUED after preparing queue" + + # Start job again + with TransactionSpy.spy(manager.db): + manager.start_job() + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING + + # Fail job again + with TransactionSpy.spy(manager.db): + manager.fail_job( + error=Exception("Another error"), + result={"details": "Traceback details here"}, + ) + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.FAILED + assert job.retry_count == 1 + + # Reset job + with TransactionSpy.spy(manager.db): + manager.reset_job() + session.flush() + + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + assert job.progress_current is None + assert job.progress_total is None + assert job.retry_count == 0 From ae18eebbbaa7ae1ae4c43da3bee7c7814c4f5fc6 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Tue, 13 Jan 2026 20:27:00 -0800 Subject: [PATCH 07/70] feat: Pipeline manager class, supporting utilities, and unit tests - Created PipelineManager capable of coordinating jobs within a pipeline context - Introduced `construct_bulk_cancellation_result` to standardize cancellation result structures. - Added `job_dependency_is_met` to check job dependencies based on their types and statuses. - Created comprehensive tests for PipelineManager covering initialization, job coordination, status transitions, and error handling. - Implemented mocks for database and Redis dependencies to isolate tests. - Added tests for job enqueuing, cancellation, pausing, unpausing, and retrying functionalities. --- src/mavedb/worker/lib/__init__.py | 6 +- src/mavedb/worker/lib/managers/__init__.py | 14 +- src/mavedb/worker/lib/managers/constants.py | 23 +- src/mavedb/worker/lib/managers/exceptions.py | 27 + .../worker/lib/managers/pipeline_manager.py | 1127 +++++ src/mavedb/worker/lib/managers/types.py | 12 + src/mavedb/worker/lib/managers/utils.py | 69 + tests/worker/lib/conftest.py | 66 +- .../lib/managers/test_pipeline_manager.py | 3731 +++++++++++++++++ 9 files changed, 5065 insertions(+), 10 deletions(-) create mode 100644 src/mavedb/worker/lib/managers/pipeline_manager.py create mode 100644 src/mavedb/worker/lib/managers/utils.py create mode 100644 tests/worker/lib/managers/test_pipeline_manager.py diff --git a/src/mavedb/worker/lib/__init__.py b/src/mavedb/worker/lib/__init__.py index e011ce18..8ab17989 100644 --- a/src/mavedb/worker/lib/__init__.py +++ b/src/mavedb/worker/lib/__init__.py @@ -1,7 +1,7 @@ """ -Worker library modules for job management and coordination. +Worker library modules for job management and pipeline coordination. """ -from .managers import JobManager +from .managers import JobManager, PipelineManager -__all__ = ["JobManager"] +__all__ = ["JobManager", "PipelineManager"] diff --git a/src/mavedb/worker/lib/managers/__init__.py b/src/mavedb/worker/lib/managers/__init__.py index f5a21c38..b75eb40f 100644 --- a/src/mavedb/worker/lib/managers/__init__.py +++ b/src/mavedb/worker/lib/managers/__init__.py @@ -1,10 +1,11 @@ -"""Manager classes and shared utilities for job coordination. +"""Manager classes and shared utilities for job and pipeline coordination. -This package provides managers for job lifecycle,along with shared constants, exceptions, -and types used across the worker system. +This package provides managers for job lifecycle and pipeline coordination, +along with shared constants, exceptions, and types used across the worker system. Main Classes: JobManager: Individual job lifecycle management + PipelineManager: Pipeline coordination and dependency management Shared Utilities: Constants: Job statuses, timeouts, retry limits @@ -12,7 +13,7 @@ Types: TypedDict definitions and common type hints Example Usage: - >>> from mavedb.worker.lib.managers import JobManager + >>> from mavedb.worker.lib.managers import JobManager, PipelineManager >>> from mavedb.worker.lib.managers import JobStateError, TERMINAL_JOB_STATUSES >>> >>> job_manager = JobManager(db, redis, job_id) @@ -22,6 +23,8 @@ >>> job_manager.start_job() >>> job_manager.succeed_job({"output": "success"}) >>> + >>> # Pipeline coordination + >>> await pipeline_manager.coordinate_after_completion(True) """ # Main manager classes @@ -40,6 +43,7 @@ JobTransitionError, ) from .job_manager import JobManager +from .pipeline_manager import PipelineManager # Type definitions from .types import JobResultData, RetryHistoryEntry @@ -48,6 +52,7 @@ # Main classes "BaseManager", "JobManager", + "PipelineManager", # Constants "ACTIVE_JOB_STATUSES", "TERMINAL_JOB_STATUSES", @@ -55,6 +60,7 @@ "DatabaseConnectionError", "JobStateError", "JobTransitionError", + "PipelineCoordinationError", # Types "JobResultData", "RetryHistoryEntry", diff --git a/src/mavedb/worker/lib/managers/constants.py b/src/mavedb/worker/lib/managers/constants.py index acc95236..4eabd684 100644 --- a/src/mavedb/worker/lib/managers/constants.py +++ b/src/mavedb/worker/lib/managers/constants.py @@ -5,7 +5,7 @@ pipeline coordination. """ -from mavedb.models.enums.job_pipeline import FailureCategory, JobStatus +from mavedb.models.enums.job_pipeline import FailureCategory, JobStatus, PipelineStatus # Job status constants for common groupings STARTABLE_JOB_STATUSES = [JobStatus.QUEUED, JobStatus.PENDING] @@ -33,3 +33,24 @@ # TODO: Add more retryable exception types as needed ) """Failure categories that are considered retryable errors.""" + +# Pipeline coordination constants +STARTABLE_PIPELINE_STATUSES = [PipelineStatus.PAUSED, PipelineStatus.CREATED] +"""Pipeline statuses that can be transitioned to RUNNING state.""" + +TERMINAL_PIPELINE_STATUSES = [ + PipelineStatus.SUCCEEDED, + PipelineStatus.FAILED, + PipelineStatus.PARTIAL, + PipelineStatus.CANCELLED, +] +"""Pipeline statuses indicating finished execution (terminal states).""" + +CANCELLED_PIPELINE_STATUSES = [PipelineStatus.CANCELLED, PipelineStatus.FAILED] +"""Pipeline statuses indicating the pipeline has been cancelled or failed.""" + +CANCELLABLE_PIPELINE_STATUSES = [PipelineStatus.CREATED, PipelineStatus.RUNNING, PipelineStatus.PAUSED] +"""Pipeline statuses that can be cancelled/skipped.""" + +RUNNING_PIPELINE_STATUSES = [PipelineStatus.RUNNING] +"""Pipeline statuses indicating active execution.""" diff --git a/src/mavedb/worker/lib/managers/exceptions.py b/src/mavedb/worker/lib/managers/exceptions.py index 7a0ede6b..48fa4b83 100644 --- a/src/mavedb/worker/lib/managers/exceptions.py +++ b/src/mavedb/worker/lib/managers/exceptions.py @@ -9,6 +9,33 @@ class ManagerError(Exception): pass +## Pipeline Manager Exceptions + + +class PipelineManagerError(ManagerError): + """Pipeline Manager specific errors.""" + + pass + + +class PipelineCoordinationError(PipelineManagerError): + """Pipeline coordination failed - may be recoverable.""" + + pass + + +class PipelineTransitionError(PipelineManagerError): + """Pipeline is in wrong state for requested operation.""" + + pass + + +class PipelineStateError(PipelineManagerError): + """Critical pipeline state operations failed - database issues preventing state persistence.""" + + pass + + ## Job Manager Exceptions diff --git a/src/mavedb/worker/lib/managers/pipeline_manager.py b/src/mavedb/worker/lib/managers/pipeline_manager.py new file mode 100644 index 00000000..b05f9706 --- /dev/null +++ b/src/mavedb/worker/lib/managers/pipeline_manager.py @@ -0,0 +1,1127 @@ +"""Pipeline coordination management for job dependencies and status. + +This module provides the PipelineManager class for coordinating pipeline execution, +managing job dependencies, and updating pipeline status. The PipelineManager is +separated from individual job lifecycle management to provide clean separation of concerns. + +Example usage: + >>> from mavedb.worker.lib.pipeline_manager import PipelineManager + >>> + >>> # Initialize with database and Redis connections + >>> pipeline_manager = PipelineManager(db_session, redis_client, pipeline_id=456) + >>> + >>> # Coordinate after a job completes + >>> await pipeline_manager.coordinate_pipeline() + >>> + >>> # Update pipeline status + >>> new_status = pipeline_manager.transition_pipeline_status() + >>> + >>> # Cancel remaining jobs when pipeline fails + >>> cancelled_count = pipeline_manager.cancel_remaining_jobs( + ... reason="Dependency failed" + ... ) + >>> + >>> # Pause/unpause pipeline + >>> was_paused = pipeline_manager.pause_pipeline("Maintenance") + >>> was_unpaused = await pipeline_manager.unpause_pipeline("Complete") + +Error Handling: + The PipelineManager uses the same exception hierarchy as JobManager for consistency: + + - DatabaseConnectionError: Database connectivity issues + - JobStateError: Critical state persistence failures + - PipelineCoordinationError: Pipeline coordination failures +""" + +import logging +from datetime import datetime, timedelta +from typing import Sequence + +from arq import ArqRedis +from sqlalchemy import and_, func, select +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session + +from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus +from mavedb.models.job_dependency import JobDependency +from mavedb.models.job_run import JobRun +from mavedb.models.pipeline import Pipeline +from mavedb.worker.lib.managers import BaseManager, JobManager +from mavedb.worker.lib.managers.constants import ( + ACTIVE_JOB_STATUSES, + CANCELLED_JOB_STATUSES, + CANCELLED_PIPELINE_STATUSES, + RUNNING_PIPELINE_STATUSES, + TERMINAL_PIPELINE_STATUSES, +) +from mavedb.worker.lib.managers.exceptions import ( + DatabaseConnectionError, + PipelineCoordinationError, + PipelineStateError, + PipelineTransitionError, +) +from mavedb.worker.lib.managers.utils import ( + construct_bulk_cancellation_result, + job_dependency_is_met, + job_should_be_skipped_due_to_unfulfillable_dependency, +) + +logger = logging.getLogger(__name__) + + +class PipelineManager(BaseManager): + """Manages pipeline coordination and job dependencies with atomic operations. + + The PipelineManager provides a focused interface for coordinating pipeline execution + without coupling to individual job lifecycle management. It handles dependency + checking, status updates, and pipeline-wide operations like cancellation. + + Key Features: + - Atomic pipeline status transitions with rollback on failure + - Dependency-based job enqueueing with race condition prevention + - Pipeline-wide cancellation with proper error handling + - Separation from individual job lifecycle management + - Consistent exception handling and logging + + Usage Patterns: + + Pipeline coordination after job completion: + >>> manager = PipelineManager(db, redis, pipeline_id=123) + >>> await manager.coordinate_pipeline() + + Manual pipeline operations: + >>> # Update pipeline status based on current job states + >>> new_status = manager.transition_pipeline_status() + >>> + >>> # Cancel remaining jobs + >>> cancelled_count = manager.cancel_remaining_jobs( + ... reason="Manual cancellation" + ... ) + >>> + >>> # Pause pipeline execution + >>> was_paused = manager.pause_pipeline( + ... reason="System maintenance" + ... ) + >>> + >>> # Resume pipeline execution + >>> was_unpaused = await manager.unpause_pipeline( + ... reason="Maintenance complete" + ... ) + + Dependency management: + >>> # Check if a job can be enqueued + >>> can_run = manager.can_enqueue_job(job) + >>> + >>> # Enqueue all ready jobs (independent and dependent) + >>> await manager.enqueue_ready_jobs() + + Pipeline monitoring: + >>> # Get detailed progress statistics + >>> progress = manager.get_pipeline_progress() + >>> print(f"Pipeline {progress['completion_percentage']:.1f}% complete") + >>> + >>> # Get job counts by status + >>> counts = manager.get_job_counts_by_status() + >>> print(f"Failed jobs: {counts.get(JobStatus.FAILED, 0)}") + + Job retry and pipeline restart: + >>> # Retry all failed jobs + >>> retried_count = await manager.retry_failed_jobs() + >>> + >>> # Restart entire pipeline + >>> restarted = await manager.restart_pipeline("Fixed issue") + + Thread Safety: + PipelineManager is not thread-safe. Each instance should be used by a single + worker thread and should not be shared across concurrent operations. + """ + + def __init__(self, db: Session, redis: ArqRedis, pipeline_id: int): + """Initialize pipeline manager with database and Redis connections. + + Args: + db: SQLAlchemy database session for job and pipeline queries + redis: ARQ Redis client for job queue operations + pipeline_id: ID of the pipeline this manager instance will coordinate + + Raises: + DatabaseConnectionError: Cannot connect to database + + Example: + >>> db_session = get_database_session() + >>> redis_client = get_arq_redis_client() + >>> manager = PipelineManager(db_session, redis_client, pipeline_id=456) + """ + super().__init__(db, redis) + self.pipeline_id = pipeline_id + self.get_pipeline() # Validate pipeline exists on init + + async def start_pipeline(self) -> None: + """Start the pipeline + + Entry point to start pipeline execution. Sets pipeline status to RUNNING + and enqueues independent jobs using coordinate pipeline. + + Raises: + DatabaseConnectionError: Cannot query or update pipeline + PipelineStateError: Cannot update pipeline state + PipelineCoordinationError: Failed to enqueue ready jobs + + Example: + >>> # Start a new pipeline + >>> await pipeline_manager.start_pipeline() + """ + status = self.get_pipeline_status() + + if status != PipelineStatus.CREATED: + logger.info( + f"Pipeline {self.pipeline_id} is in a non-created state (current status: {status}) and may not be started" + ) + raise PipelineTransitionError(f"Pipeline {self.pipeline_id} is in state {status} and may not be started") + + self.set_pipeline_status(PipelineStatus.RUNNING) + self.db.flush() + + logger.info(f"Pipeline {self.pipeline_id} started successfully") + await self.coordinate_pipeline() + + async def coordinate_pipeline(self) -> None: + """Coordinate pipeline after a job completes. + + This is the main coordination entry point called after jobs complete. + It updates pipeline status and enqueues ready jobs or cancels remaining jobs + based on the completion result. The method operates on the entire pipeline + state rather than tracking individual job completions. + + Raises: + DatabaseConnectionError: Cannot query job or pipeline info + PipelineStateError: Cannot update pipeline state + PipelineCoordinationError: Failed to enqueue jobs or cancel remaining jobs + JobStateError: Critical job state persistence failure + JobTransitionError: Job cannot be transitioned from current state to new state + + + Example: + >>> # Called after successful job completion + >>> await pipeline_manager.coordinate_pipeline() + """ + new_status = self.transition_pipeline_status() + self.db.flush() + + if new_status in CANCELLED_PIPELINE_STATUSES: + self.cancel_remaining_jobs(reason="Pipeline failed or cancelled") + + # Only enqueue new jobs if pipeline is running + if new_status in RUNNING_PIPELINE_STATUSES: + await self.enqueue_ready_jobs() + + # After enqueuing jobs, re-evaluate pipeline status in case it changed. + # We only expect the status to change if jobs with unsatisfiable dependencies were skipped. + self.transition_pipeline_status() + self.db.flush() + + def transition_pipeline_status(self) -> PipelineStatus: + """Update pipeline status based on current job states. + + Analyzes the status distribution of all jobs in the pipeline to determine + the appropriate pipeline status. Updates pipeline status and finished_at + timestamp when the status changes to a terminal state. + + Returns: + PipelineStatus: The current pipeline status after update. If unchanged, the + previous status is returned. + + Raises: + DatabaseConnectionError: Cannot query job statuses or pipeline info + JobStateError: Cannot update pipeline status or corrupted job data + + Status Logic: + - FAILED: Any job has FAILED status + - RUNNING: Any job is RUNNING or QUEUED + - SUCCEEDED: All jobs are SUCCEEDED + - PARTIAL: Mix of SUCCEEDED/SKIPPED/CANCELLED with no FAILED/RUNNING + - CANCELLED: All remaining jobs are CANCELLED + - No Change: If pipeline is PAUSED, CANCELLED, or has no jobs: status remains unchanged + + Example: + >>> new_status = pipeline_manager.transition_pipeline_status() + >>> print(f"Pipeline status is now {new_status}") + """ + pipeline = self.get_pipeline() + status_counts = self.get_job_counts_by_status() + + old_status = pipeline.status + try: + total_jobs = sum(status_counts.values()) + if old_status in TERMINAL_PIPELINE_STATUSES: + logger.debug(f"Pipeline {self.pipeline_id} is in terminal status {old_status}; skipping update") + return old_status # No change from terminal state + + if old_status == PipelineStatus.PAUSED: + logger.debug(f"Pipeline {self.pipeline_id} is paused; skipping status update") + return old_status # No change from paused state + + # The pipeline must not be in a terminal state (from above), but has no jobs. Consider it complete. + if total_jobs == 0: + logger.debug(f"No jobs found in pipeline {self.pipeline_id} - considering pipeline complete") + + self.set_pipeline_status(PipelineStatus.SUCCEEDED) + return PipelineStatus.SUCCEEDED + + except (AttributeError, TypeError, KeyError, ValueError) as e: + logger.debug(f"Invalid job status data for pipeline {self.pipeline_id}: {e}") + raise PipelineStateError(f"Corrupted job status data for pipeline {self.pipeline_id}: {e}") + + # The pipeline is not in a terminal state and has jobs - determine new status + try: + if status_counts.get(JobStatus.FAILED, 0) > 0: + new_status = PipelineStatus.FAILED + elif status_counts.get(JobStatus.RUNNING, 0) > 0 or status_counts.get(JobStatus.QUEUED, 0) > 0: + new_status = PipelineStatus.RUNNING + + # Pending jobs still exist, don't change the status. + # These might be picked up soon, or they may be proactively + # skipped later if dependencies cannot be met. + # + # Although there is a tension between having only pending + # and succeeded jobs (which would suggest partial/succeeded), + # we leave the status as-is until jobs are actually processed. + # + # *A pipeline with a terminal status must not have pending jobs* + elif status_counts.get(JobStatus.PENDING, 0) > 0: + new_status = old_status + + elif status_counts.get(JobStatus.SUCCEEDED, 0) > 0: + succeeded_jobs = status_counts.get(JobStatus.SUCCEEDED, 0) + skipped_jobs = status_counts.get(JobStatus.SKIPPED, 0) + cancelled_jobs = status_counts.get(JobStatus.CANCELLED, 0) + + if succeeded_jobs == total_jobs: + new_status = PipelineStatus.SUCCEEDED + logger.debug(f"All jobs succeeded in pipeline {self.pipeline_id}") + elif (succeeded_jobs + skipped_jobs + cancelled_jobs) == total_jobs: + new_status = PipelineStatus.PARTIAL + logger.debug(f"Pipeline {self.pipeline_id} completed partially: {status_counts}") + else: + new_status = PipelineStatus.PARTIAL + logger.warning(f"Inconsistent job counts detected for pipeline {self.pipeline_id}: {status_counts}") + # TODO: Notification hooks + else: + new_status = PipelineStatus.CANCELLED + + if pipeline.status != new_status: + self.set_pipeline_status(new_status) + + except (AttributeError, TypeError, KeyError, ValueError) as e: + logger.debug(f"Object manipulation failed updating pipeline status for {self.pipeline_id}: {e}") + raise PipelineStateError(f"Failed to update pipeline status for {self.pipeline_id}: {e}") + + if new_status != old_status: + logger.info(f"Pipeline {self.pipeline_id} status successfully updated to {new_status} from {old_status}") + else: + logger.debug(f"No status change for pipeline {self.pipeline_id} (remains {old_status})") + + return new_status + + async def enqueue_ready_jobs(self) -> None: + """Find and enqueue all jobs that are ready to run. + + Identifies pending jobs in the pipeline (including retries) whose dependencies + are satisfied, updates their status to QUEUED, and enqueues them in ARQ. + This handles both independent jobs and jobs with dependencies, as well as + jobs that have been prepared for retry. + + Does not enqueue jobs if the pipeline is paused. + + Raises: + DatabaseConnectionError: Cannot query pending jobs or job dependencies + JobStateError: Cannot update job state to QUEUED (critical failure) + PipelineCoordinationError: One or more jobs failed to enqueue in ARQ + + Process: + 1. Ensure pipeline is running (skip enqueues if not) + 2. Query all PENDING jobs in pipeline (includes retries) + 3. Check dependency requirements for each job + 4. For jobs ready to run: flush status change and enqueue in ARQ + + Note: + - This method handles both independent and dependent jobs uniformly - + any job in PENDING status that meets its dependency requirements + (including jobs with no dependencies) will be enqueued, unless the + pipeline is paused. + + Examples: + Basic usage: + >>> # Enqueue all ready jobs in the pipeline + >>> await pipeline_manager.enqueue_ready_jobs() + + Handling coordination errors: + >>> try: + ... await pipeline_manager.enqueue_ready_jobs() + ... except PipelineCoordinationError as e: + ... logger.error(f"Failed to enqueue some jobs: {e}") + ... # Optionally cancel pipeline or take other recovery actions + """ + current_status = self.get_pipeline_status() + if current_status not in RUNNING_PIPELINE_STATUSES: + logger.debug(f"Pipeline {self.pipeline_id} is not running - skipping job enqueue") + raise PipelineStateError( + f"Pipeline {self.pipeline_id} is in status {current_status} and cannot enqueue jobs" + ) + + jobs_to_queue: list[JobRun] = [] + for job in self.get_pending_jobs(): + job_manager = JobManager(self.db, self.redis, job.id) + + # Attempt to enqueue the job if dependencies are met + if self.can_enqueue_job(job): + job_manager.prepare_queue() + jobs_to_queue.append(job) + continue + + should_skip, reason = self.should_skip_job_due_to_dependencies(job) + if should_skip: + job_manager.skip_job( + { + "output": {}, + "logs": "", + "metadata": {"result": reason, "timestamp": datetime.now().isoformat()}, + } + ) + logger.info(f"Skipped job {job.urn} due to unmet dependencies: {reason}") + continue + + # Ensure enqueued jobs can view the status change and pipelines + # can view skipped jobs by flushing transactions. + self.db.flush() + + if not jobs_to_queue: + logger.debug(f"No ready jobs to enqueue in pipeline {self.pipeline_id}") + return + + successfully_enqueued = [] + for job in jobs_to_queue: + await self._enqueue_in_arq(job, is_retry=False) + successfully_enqueued.append(job.urn) + logger.info(f"Successfully enqueued job {job.urn}") + + logger.info(f"Successfully enqueued {len(successfully_enqueued)} jobs: {successfully_enqueued}.") + + def cancel_remaining_jobs(self, reason: str = "Pipeline cancelled") -> None: + """Cancel all remaining jobs in the pipeline when the pipeline fails. + + Finds all active pipeline jobs and marks them as SKIPPED or CANCELLED + to prevent further execution when the pipeline has failed. Records the + cancellation reason and timestamp for audit purposes. + + Args: + reason: Human-readable reason for cancellation + + Raises: + DatabaseConnectionError: Cannot query jobs to cancel + PipelineCoordinationError: Failed to cancel one or more jobs + """ + remaining_jobs = self.get_active_jobs() + if not remaining_jobs: + logger.debug(f"No jobs to cancel in pipeline {self.pipeline_id}") + else: + bulk_cancellation_result = construct_bulk_cancellation_result(reason) + + for job in remaining_jobs: + job_manager = JobManager(self.db, self.redis, job.id) + + # Skip PENDING jobs, cancel RUNNING/QUEUED jobs + if job_manager.get_job_status() == JobStatus.PENDING: + job_manager.skip_job(result=bulk_cancellation_result) + logger.debug(f"Skipped job {job.urn}: {reason}") + else: + job_manager.cancel_job(result=bulk_cancellation_result) + logger.debug(f"Cancelled job {job.urn}: {reason}") + + logger.info(f"Cancelled all remaining jobs in pipeline {self.pipeline_id}") + + async def cancel_pipeline(self, reason: str = "Pipeline cancelled") -> None: + """Cancel the entire pipeline and all remaining jobs. + + Sets the pipeline status to CANCELLED and cancels all PENDING and QUEUED + jobs in the pipeline. Records the cancellation reason for audit purposes. + + Args: + reason: Human-readable reason for pipeline cancellation + + Raises: + DatabaseConnectionError: Cannot query or update pipeline/jobs + PipelineCoordinationError: Failed to cancel pipeline or jobs + + Example: + >>> # Cancel a running pipeline due to external event + >>> await pipeline_manager.cancel_pipeline( + ... reason="User requested cancellation" + ... ) + """ + current_status = self.get_pipeline_status() + + if current_status in TERMINAL_PIPELINE_STATUSES: + logger.info(f"Pipeline {self.pipeline_id} is already in terminal status {current_status}") + raise PipelineTransitionError( + f"Pipeline {self.pipeline_id} is in terminal state {current_status} and may not be cancelled" + ) + + self.set_pipeline_status(PipelineStatus.CANCELLED) + self.db.flush() + logger.info(f"Pipeline {self.pipeline_id} cancelled: {reason}") + + await self.coordinate_pipeline() + + async def pause_pipeline(self, reason: str = "Pipeline paused") -> None: + """Pause the pipeline to stop further job execution. + + Sets the pipeline status to PAUSED, preventing new jobs from being enqueued + while allowing currently running jobs to complete. This provides a way to + temporarily halt pipeline execution without cancelling remaining jobs. + + Args: + reason: Human-readable reason for pausing the pipeline + + Raises: + DatabaseConnectionError: Cannot query or update pipeline + JobStateError: Cannot update pipeline state + PipelineTransitionError: Pipeline cannot be paused due to current state + + Example: + >>> # Pause pipeline for maintenance + >>> was_paused = manager.pause_pipeline( + ... reason="System maintenance" + ... ) + """ + current_status = self.get_pipeline_status() + + if current_status in TERMINAL_PIPELINE_STATUSES: + logger.info(f"Pipeline {self.pipeline_id} cannot be paused (current status: {current_status})") + raise PipelineTransitionError( + f"Pipeline {self.pipeline_id} is in terminal state {current_status} and may not be paused" + ) + + if current_status == PipelineStatus.PAUSED: + logger.info(f"Pipeline {self.pipeline_id} is already paused") + raise PipelineTransitionError(f"Pipeline {self.pipeline_id} is already paused") + + self.set_pipeline_status(PipelineStatus.PAUSED) + self.db.flush() + + logger.info(f"Pipeline {self.pipeline_id} paused (was {current_status}): {reason}") + await self.coordinate_pipeline() + + async def unpause_pipeline(self, reason: str = "Pipeline unpaused") -> None: + """Unpause the pipeline and resume job execution. + + Sets the pipeline status from PAUSED back to RUNNING and enqueues any + jobs that are ready to run. This resumes normal pipeline execution + after a pause. + + Args: + reason: Human-readable reason for unpausing the pipeline + + Raises: + DatabaseConnectionError: Cannot query or update pipeline + PipelineStateError: Cannot update pipeline state + PipelineCoordinationError: Failed to enqueue ready jobs after unpause + + Example: + >>> # Resume pipeline after maintenance + >>> was_unpaused = await manager.unpause_pipeline( + ... reason="Maintenance complete" + ... ) + """ + current_status = self.get_pipeline_status() + + if current_status != PipelineStatus.PAUSED: + logger.info( + f"Pipeline {self.pipeline_id} is not paused (current status: {current_status}) and may not be unpaused" + ) + raise PipelineTransitionError( + f"Pipeline {self.pipeline_id} is not paused (current status: {current_status}) and may not be unpaused" + ) + + self.set_pipeline_status(PipelineStatus.RUNNING) + self.db.flush() + + logger.info(f"Pipeline {self.pipeline_id} unpaused (was {current_status}): {reason}") + await self.coordinate_pipeline() + + async def restart_pipeline(self) -> None: + """Restart the entire pipeline from the beginning. + + Resets ALL jobs in the pipeline to PENDING status, resets pipeline state to RUNNING, and re-enqueues + independent jobs. This is useful for recovering from pipeline-wide issues. + + Raises: + PipelineCoordinationError: If restart operations fail + DatabaseConnectionError: If database operations fail + + Example: + >>> success = await manager.restart_pipeline("Fixed configuration issue") + >>> print(f"Pipeline restart: {'successful' if success else 'failed'}") + """ + all_jobs = self.get_all_jobs() + if not all_jobs: + logger.debug(f"No jobs found for pipeline {self.pipeline_id} restart") + return + + # Reset all jobs to PENDING status + for job in all_jobs: + job_manager = JobManager(self.db, self.redis, job.id) + job_manager.reset_job() + + # Reset pipeline status to created + self.set_pipeline_status(PipelineStatus.CREATED) + self.db.flush() + + logger.info(f"Pipeline {self.pipeline_id} reset for restart successfully") + await self.start_pipeline() + + def can_enqueue_job(self, job: JobRun) -> bool: + """Check if a job can be enqueued based on dependency requirements. + + Validates that all job dependencies are satisfied according to their + dependency types before allowing enqueue. Prevents premature execution + of jobs that depend on incomplete predecessors. + + Args: + job: JobRun instance to check dependencies for + + Returns: + bool: True if all dependencies are satisfied and job can be enqueued, + False if dependencies are still pending + + Raises: + DatabaseConnectionError: Cannot query job dependencies + JobStateError: Corrupted dependency data detected + + Dependency Types: + - SUCCESS_REQUIRED: Dependent job must have SUCCEEDED status + - COMPLETION_REQUIRED: Dependent job must be SUCCEEDED or FAILED + """ + for dependency, dependent_job in self.get_dependencies_for_job(job): + try: + if not job_dependency_is_met( + dependency_type=dependency.dependency_type, + dependent_job_status=dependent_job.status, + ): + logger.debug(f"Job {job.urn} cannot be enqueued; dependency on job {dependent_job.urn} not met") + return False + + except (AttributeError, KeyError, TypeError, ValueError) as e: + logger.debug(f"Invalid dependency data detected for job {job.id}: {e}") + raise PipelineStateError(f"Corrupted dependency data during enqueue check for job {job.id}: {e}") + + logger.debug(f"All dependencies satisfied for job {job.urn}; ready to enqueue") + return True + + def should_skip_job_due_to_dependencies(self, job: JobRun) -> tuple[bool, str]: + """Check if a job's dependencies are unsatisfiable and the job should be skipped. + + Validates whether a job's dependencies can still be met based on the + current status of dependent jobs. This helps identify jobs that should + be skipped because their dependencies are in terminal non-success states. + + Args: + job: JobRun instance to check dependencies for + + Returns: + tuple[bool, str]: (True, reason) if dependencies cannot be met and job + should be skipped, (False, "") if dependencies may + still be satisfied + + Raises: + DatabaseConnectionError: Cannot query job dependencies + PipelineStateError: Critical state persistence failure + + Notes: + - A job is considered unreachable if any of its dependencies that + require SUCCESS have FAILED, SKIPPED, or CANCELLED status. + - A job is considered unreachable if any of its dependencies that + require COMPLETION have SKIPPED or CANCELLED status. + + Examples: + Basic usage: + >>> should_skip, reason = manager.should_skip_job_due_to_dependencies(job) + >>> if should_skip: + ... print(f"Job should be skipped: {reason}") + >>> else: + ... print("Job dependencies may still be satisfied") + """ + for dependency, dep_job in self.get_dependencies_for_job(job): + try: + should_skip, reason = job_should_be_skipped_due_to_unfulfillable_dependency( + dependency_type=dependency.dependency_type, + dependent_job_status=dep_job.status, + ) + + if should_skip: + logger.debug(f"Job {job.urn} should be skipped due to dependency on job {dep_job.urn}: {reason}") + # guaranteed to be str if should_skip is True + return True, reason # type: ignore + + except (AttributeError, KeyError, TypeError, ValueError) as e: + logger.debug(f"Invalid dependency data detected for job {job.id}: {e}") + raise PipelineStateError(f"Corrupted dependency data during skip check for job {job.id}: {e}") + + logger.debug(f"Job {job.urn} dependencies may still be satisfied; not skipping") + return False, "" + + async def retry_failed_jobs(self) -> None: + """Retry all failed jobs in the pipeline. + + Resets failed jobs to PENDING status and re-enqueues them for execution. + Only affects jobs with FAILED status; other jobs remain unchanged. + + Raises: + PipelineCoordinationError: If job retry fails + DatabaseConnectionError: If database operations fail + + Example: + >>> await manager.retry_failed_jobs() + >>> print("Successfully retried failed jobs") + """ + failed_jobs = self.get_failed_jobs() + if not failed_jobs: + logger.debug(f"No failed jobs found for pipeline {self.pipeline_id}") + return + + for job in failed_jobs: + job_manager = JobManager(self.db, self.redis, job.id) + job_manager.prepare_retry() + + # Ensure the pipeline status is set to running so jobs are picked up + self.set_pipeline_status(PipelineStatus.RUNNING) + self.db.flush() + + await self.coordinate_pipeline() + + async def retry_unsuccessful_jobs(self) -> None: + """Retry all unsuccessful jobs in the pipeline. + + Resets unsuccessful jobs (CANCELLED, SKIPPED, FAILED) to PENDING status + and re-enqueues them for execution. This is useful for recovering from + partial failures or interruptions. + + Raises: + PipelineCoordinationError: If job retry fails + DatabaseConnectionError: If database operations fail + + Example: + >>> await manager.retry_unsuccessful_jobs() + >>> print("Successfully retried unsuccessful jobs") + """ + unsuccessful_jobs = self.get_unsuccessful_jobs() + if not unsuccessful_jobs: + logger.debug(f"No unsuccessful jobs found for pipeline {self.pipeline_id}") + return + + for job in unsuccessful_jobs: + job_manager = JobManager(self.db, self.redis, job.id) + job_manager.prepare_retry() + + # Ensure the pipeline status is set to running so jobs are picked up + self.set_pipeline_status(PipelineStatus.RUNNING) + self.db.flush() + + await self.coordinate_pipeline() + + async def retry_pipeline(self) -> None: + """Retry all unsuccessful jobs in the pipeline. + + Convenience method to retry all jobs that did not complete successfully, + including CANCELLED, SKIPPED, and FAILED jobs. Resets their status to PENDING + and re-enqueues them for execution. + + This is equivalent to calling `retry_unsuccessful_jobs` but provides a clearer + semantic for pipeline-level retries. + """ + await self.retry_unsuccessful_jobs() + + def get_jobs_by_status(self, status: list[JobStatus]) -> Sequence[JobRun]: + """Get all jobs in the pipeline with a specific status. + + Args: + status: JobStatus to filter jobs by + + Returns: + Sequence[JobRun]: List of jobs with the specified status ordered by creation time + + Raises: + DatabaseConnectionError: Cannot query job information + + Example: + >>> running_jobs = manager.get_jobs_by_status([JobStatus.RUNNING]) + >>> print(f"Found {len(running_jobs)} running jobs") + """ + try: + return ( + self.db.execute( + select(JobRun) + .where(and_(JobRun.pipeline_id == self.pipeline_id, JobRun.status.in_(status))) + .order_by(JobRun.created_at) + ) + .scalars() + .all() + ) + except SQLAlchemyError as e: + logger.debug( + f"Database query failed getting jobs with status {status} for pipeline {self.pipeline_id}: {e}" + ) + raise DatabaseConnectionError(f"Failed to get jobs with status {status}: {e}") + + def get_pending_jobs(self) -> Sequence[JobRun]: + """Get all PENDING jobs in the pipeline. + + Convenience method for fetching all pending jobs. This is equivalent + to calling get_jobs_by_status([JobStatus.PENDING]) but provides + clearer intent and a more focused API. + + Returns: + Sequence[JobRun]: List of pending jobs ordered by creation time + + Raises: + DatabaseConnectionError: Cannot query job information + + Example: + >>> pending_jobs = manager.get_pending_jobs() + >>> print(f"Found {len(pending_jobs)} pending jobs") + """ + return self.get_jobs_by_status([JobStatus.PENDING]) + + def get_running_jobs(self) -> Sequence[JobRun]: + """Get all RUNNING jobs in the pipeline. + + Convenience method for fetching all running jobs. This is equivalent + to calling get_jobs_by_status([JobStatus.RUNNING]) but provides + clearer intent and a more focused API. + + Returns: + Sequence[JobRun]: List of running jobs ordered by creation time + + Raises: + DatabaseConnectionError: Cannot query job information + + Example: + >>> running_jobs = manager.get_running_jobs() + >>> print(f"Found {len(running_jobs)} running jobs") + """ + return self.get_jobs_by_status([JobStatus.RUNNING]) + + def get_active_jobs(self) -> Sequence[JobRun]: + """Get all active jobs in the pipeline. + + Convenience method for fetching all active jobs. This is equivalent + to calling get_jobs_by_status(ACTIVE_JOB_STATUSES) but provides + clearer intent and a more focused API. + + Returns: + Sequence[JobRun]: List of remaining jobs ordered by creation time + + Raises: + DatabaseConnectionError: Cannot query job information + + Example: + >>> active_jobs = manager.get_active_jobs() + >>> print(f"Found {len(active_jobs)} active jobs") + """ + return self.get_jobs_by_status(ACTIVE_JOB_STATUSES) + + def get_failed_jobs(self) -> Sequence[JobRun]: + """Get all failed jobs in the pipeline. + + Convenience method for fetching all failed jobs. This is equivalent + to calling get_jobs_by_status([JobStatus.FAILED]) but provides + clearer intent and a more focused API. + + Returns: + Sequence[JobRun]: List of failed jobs ordered by creation time + + Raises: + DatabaseConnectionError: Cannot query job information + + Example: + >>> failed_jobs = manager.get_failed_jobs() + >>> print(f"Found {len(failed_jobs)} failed jobs for potential retry") + """ + return self.get_jobs_by_status([JobStatus.FAILED]) + + def get_unsuccessful_jobs(self) -> Sequence[JobRun]: + """Get all unsuccessful jobs in the pipeline. + + Convenience method for fetching all unsuccessful (but terminated) jobs. This is equivalent + to calling get_jobs_by_status([JobStatus.FAILED, JobStatus.CANCELLED, JobStatus.SKIPPED]) + but provides clearer intent and a more focused API. + + Returns: + Sequence[JobRun]: List of unsuccessful jobs ordered by creation time + + Raises: + DatabaseConnectionError: Cannot query job information + + Example: + >>> unsuccessful_jobs = manager.get_unsuccessful_jobs() + >>> print(f"Found {len(unsuccessful_jobs)} unsuccessful jobs") + """ + return self.get_jobs_by_status(CANCELLED_JOB_STATUSES) + + def get_all_jobs(self) -> Sequence[JobRun]: + """Get all jobs in the pipeline regardless of status. + + Returns: + Sequence[JobRun]: List of all jobs in pipeline ordered by creation time + + Raises: + DatabaseConnectionError: Cannot query job information + + Examples: + >>> all_jobs = manager.get_all_jobs() + >>> print(f"Total jobs in pipeline: {len(all_jobs)}") + """ + try: + return ( + self.db.execute( + select(JobRun).where(JobRun.pipeline_id == self.pipeline_id).order_by(JobRun.created_at) + ) + .scalars() + .all() + ) + except SQLAlchemyError as e: + logger.debug(f"Database query failed getting all jobs for pipeline {self.pipeline_id}: {e}") + raise DatabaseConnectionError(f"Failed to get all jobs: {e}") + + def get_dependencies_for_job(self, job: JobRun) -> Sequence[tuple[JobDependency, JobRun]]: + """Get all dependencies for a specific job. + + Args: + job: JobRun instance to fetch dependencies for + + Returns: + Sequence[Row[tuple[JobDependency, JobRun]]]: List of dependencies with associated JobRun instances + + Raises: + DatabaseConnectionError: Cannot query job dependencies + + Examples: + >>> dependencies = manager.get_dependencies_for_job(job) + >>> for dependency, dep_job in dependencies: + ... print(f"Job {job.urn} depends on job {dep_job.urn} with dependency type {dependency.dependency_type}") + """ + try: + # Although the returned type wraps tuples in a row, the contents are still accessible as tuples. + # This allows unpacking as shown in the example, and we can ignore the type checker warning so + # callers can have access to the simpler interface. + return self.db.execute( + select(JobDependency, JobRun) + .join(JobRun, JobDependency.depends_on_job_id == JobRun.id) + .where(JobDependency.id == job.id) + ).all() # type: ignore + except SQLAlchemyError as e: + logger.debug(f"SQL query failed for dependencies of job {job.id}: {e}") + raise DatabaseConnectionError(f"Failed to get job dependencies for job {job.id}: {e}") + + def get_pipeline(self) -> Pipeline: + """Get the Pipeline instance for this manager. + + Returns: + Pipeline: The Pipeline instance associated with this manager + + Raises: + DatabaseConnectionError: Cannot query pipeline information + + Examples: + >>> pipeline = manager.get_pipeline() + >>> print(f"Pipeline ID: {pipeline.id}, Status: {pipeline.status}") + """ + + try: + return self.db.execute(select(Pipeline).where(Pipeline.id == self.pipeline_id)).scalar_one() + except SQLAlchemyError as e: + logger.debug(f"Database query failed getting pipeline {self.pipeline_id}: {e}") + raise DatabaseConnectionError(f"Failed to get pipeline {self.pipeline_id}: {e}") + + def get_job_counts_by_status(self) -> dict[JobStatus, int]: + """Get count of jobs by status for monitoring. + + Returns a simple dictionary mapping job statuses to their counts, + useful for dashboard displays and monitoring systems. + + Returns: + dict[JobStatus, int]: Dictionary mapping JobStatus to count + + Raises: + DatabaseConnectionError: Cannot query job information + + Example: + >>> counts = manager.get_job_counts_by_status() + >>> print(f"Failed jobs: {counts.get(JobStatus.FAILED, 0)}") + """ + try: + job_counts = self.db.execute( + select(JobRun.status, func.count(JobRun.id)) + .where(JobRun.pipeline_id == self.pipeline_id) + .group_by(JobRun.status) + ).all() + except SQLAlchemyError as e: + logger.debug(f"Database query failed getting job counts for pipeline {self.pipeline_id}: {e}") + raise DatabaseConnectionError(f"Failed to get job counts for pipeline {self.pipeline_id}: {e}") + + return {status: count for status, count in job_counts} + + def get_pipeline_progress(self) -> dict: + """Get detailed pipeline progress statistics. + + Provides comprehensive pipeline progress information including job counts, + completion percentage, duration, and estimated completion time. + + Returns: + dict: Pipeline progress statistics with the following keys: + - total_jobs: Total number of jobs in pipeline + - completed_jobs: Number of jobs in terminal states + - successful_jobs: Number of successfully completed jobs + - failed_jobs: Number of failed jobs + - running_jobs: Number of currently running jobs + - pending_jobs: Number of jobs waiting to run + - completion_percentage: Percentage of jobs completed (0-100) + - duration: Time pipeline has been running (in seconds) + - status_counts: Dictionary of job counts by status + + Raises: + DatabaseConnectionError: Cannot query pipeline or job information + + Example: + >>> progress = manager.get_pipeline_progress() + >>> print(f"Pipeline {progress['completion_percentage']:.1f}% complete") + """ + status_counts = self.get_job_counts_by_status() + pipeline = self.get_pipeline() + + try: + total_jobs = sum(status_counts.values()) + + if total_jobs == 0: + return { + "total_jobs": 0, + "completed_jobs": 0, + "successful_jobs": 0, + "failed_jobs": 0, + "running_jobs": 0, + "pending_jobs": 0, + "completion_percentage": 100.0, + "duration": 0, + "status_counts": {}, + } + + # Calculate progress metrics + successful_jobs = status_counts.get(JobStatus.SUCCEEDED, 0) + failed_jobs = status_counts.get(JobStatus.FAILED, 0) + running_jobs = status_counts.get(JobStatus.RUNNING, 0) + status_counts.get(JobStatus.QUEUED, 0) + pending_jobs = status_counts.get(JobStatus.PENDING, 0) + skipped_jobs = status_counts.get(JobStatus.SKIPPED, 0) + cancelled_jobs = status_counts.get(JobStatus.CANCELLED, 0) + + completed_jobs = successful_jobs + failed_jobs + skipped_jobs + cancelled_jobs + completion_percentage = (completed_jobs / total_jobs) * 100 if total_jobs > 0 else 0 + + # Calculate duration + duration = 0 + if pipeline.created_at: + end_time = pipeline.finished_at or datetime.now() + duration = int((end_time - pipeline.created_at).total_seconds()) + + except (AttributeError, TypeError, KeyError, ValueError) as e: + logger.debug(f"Invalid data detected calculating progress for pipeline {self.pipeline_id}: {e}") + raise PipelineStateError(f"Corrupted data during progress calculation for pipeline {self.pipeline_id}: {e}") + + return { + "total_jobs": total_jobs, + "completed_jobs": completed_jobs, + "successful_jobs": successful_jobs, + "failed_jobs": failed_jobs, + "running_jobs": running_jobs, + "pending_jobs": pending_jobs, + "completion_percentage": completion_percentage, + "duration": duration, + "status_counts": status_counts, + } + + def get_pipeline_status(self) -> PipelineStatus: + """Get the current status of the pipeline. + + Returns: + PipelineStatus: Current status of the pipeline + + Raises: + DatabaseConnectionError: Cannot query pipeline information + + Example: + >>> status = manager.get_pipeline_status() + >>> print(f"Pipeline status: {status}") + """ + return self.get_pipeline().status + + def set_pipeline_status(self, new_status: PipelineStatus) -> None: + """Set the status of the pipeline. + + Args: + new_status: PipelineStatus enum value to set the pipeline to + + Raises: + DatabaseConnectionError: Cannot query or update pipeline information + PipelineStateError: Cannot update pipeline status + + Example: + >>> manager.set_pipeline_status(PipelineStatus.PAUSED) + >>> print("Pipeline paused") + + Note: + This method does not perform any validation on the status transition, + nor does it attempt to coordinate the pipeline after the status change + or flush the change to the database. + """ + pipeline = self.get_pipeline() + try: + pipeline.status = new_status + + # Ensure finished_at is set/cleared appropriately + if new_status in TERMINAL_PIPELINE_STATUSES: + pipeline.finished_at = datetime.now() + else: + pipeline.finished_at = None + + # Ensure started_at is set/cleared appropriately + if new_status == PipelineStatus.CREATED: + pipeline.started_at = None + elif new_status == PipelineStatus.RUNNING and pipeline.started_at is None: + pipeline.started_at = datetime.now() + + except (AttributeError, TypeError, KeyError, ValueError) as e: + logger.debug(f"Object manipulation failed setting status for pipeline {self.pipeline_id}: {e}") + raise PipelineStateError(f"Failed to set pipeline status for {self.pipeline_id}: {e}") + + logger.info(f"Pipeline {self.pipeline_id} status set to {new_status}") + + async def _enqueue_in_arq(self, job: JobRun, is_retry: bool) -> None: + """Enqueue a job in ARQ with proper error handling and retry delay. + + Args: + job: JobRun instance to enqueue + is_retry: Whether this is a retry attempt + + Raises: + PipelineCoordinationError: If ARQ enqueuing fails + """ + try: + defer_by = timedelta(seconds=job.retry_delay_seconds if is_retry and job.retry_delay_seconds else 0) + arq_success = await self.redis.enqueue_job(job.job_function, job.id, _defer_by=defer_by, _job_id=job.urn) + except Exception as e: + logger.debug(f"ARQ enqueue operation failed for job {job.urn}: {e}") + raise PipelineCoordinationError(f"Failed to enqueue job in ARQ: {e}") + + if arq_success: + logger.info(f"{'Retried' if is_retry else 'Enqueued'} job {job.urn} in ARQ") + else: + logger.info(f"Job {job.urn} has already been enqueued in ARQ") diff --git a/src/mavedb/worker/lib/managers/types.py b/src/mavedb/worker/lib/managers/types.py index 023338b6..68a5c217 100644 --- a/src/mavedb/worker/lib/managers/types.py +++ b/src/mavedb/worker/lib/managers/types.py @@ -12,3 +12,15 @@ class RetryHistoryEntry(TypedDict): timestamp: str result: JobResultData reason: str + + +class PipelineProgress(TypedDict): + total_jobs: int + completed_jobs: int + successful_jobs: int + failed_jobs: int + running_jobs: int + pending_jobs: int + completion_percentage: float + duration: int # seconds + status_counts: dict diff --git a/src/mavedb/worker/lib/managers/utils.py b/src/mavedb/worker/lib/managers/utils.py new file mode 100644 index 00000000..b7448e1e --- /dev/null +++ b/src/mavedb/worker/lib/managers/utils.py @@ -0,0 +1,69 @@ +"""Utility functions for job and pipeline management. + +This module provides helper functions for common operations in job and pipeline +management, such as creating standardized result structures, data formatting, and +dependency checking. +""" + +import logging +from datetime import datetime +from typing import Optional + +from mavedb.models.enums.job_pipeline import DependencyType, JobStatus +from mavedb.worker.lib.managers.constants import TERMINAL_JOB_STATUSES +from mavedb.worker.lib.managers.types import JobResultData + +logger = logging.getLogger(__name__) + + +def construct_bulk_cancellation_result(reason: str) -> JobResultData: + """Construct a standardized JobResultData structure for bulk job cancellations. + + Args: + reason: Human-readable reason for the cancellation + + Returns: + JobResultData: Standardized result data with cancellation metadata + """ + return { + "output": {}, + "logs": "", + "metadata": { + "reason": reason, + "timestamp": datetime.now().isoformat(), + }, + } + + +def job_dependency_is_met(dependency_type: Optional[DependencyType], dependent_job_status: JobStatus) -> bool: + """Check if a job dependency is met based on the dependency type and the status of the dependent job. + + Args: + dependency_type: Type of dependency ('hard' or 'soft') + dependent_job_status: Status of the dependent job + + Returns: + bool: True if the dependency is met, False otherwise + + Notes: + - For 'hard' dependencies, the dependent job must have succeeded. + - For 'soft' dependencies, the dependent job must be in a terminal state. + - If no dependency type is specified, the dependency is considered met. + """ + if not dependency_type: + logger.debug("No dependency type specified; assuming dependency is met.") + return True + + if dependency_type == DependencyType.SUCCESS_REQUIRED: + if dependent_job_status != JobStatus.SUCCEEDED: + logger.debug(f"Dependency not met: dependent job did not succeed ({dependent_job_status}).") + return False + + if dependency_type == DependencyType.COMPLETION_REQUIRED: + if dependent_job_status not in TERMINAL_JOB_STATUSES: + logger.debug( + f"Dependency not met: dependent job has not reached a terminal status ({dependent_job_status})." + ) + return False + + return True diff --git a/tests/worker/lib/conftest.py b/tests/worker/lib/conftest.py index 362642f0..fd707307 100644 --- a/tests/worker/lib/conftest.py +++ b/tests/worker/lib/conftest.py @@ -19,6 +19,7 @@ from mavedb.models.job_run import JobRun from mavedb.models.pipeline import Pipeline from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.pipeline_manager import PipelineManager @pytest.fixture @@ -86,6 +87,20 @@ def sample_pipeline(): ) +@pytest.fixture +def sample_empty_pipeline(): + """Create a sample Pipeline instance with no jobs for testing.""" + return Pipeline( + id=999, + urn="test:pipeline:999", + name="Empty Pipeline", + description="A pipeline with no jobs", + status=PipelineStatus.CREATED, + correlation_id="empty_correlation_456", + created_at=datetime.now(), + ) + + @pytest.fixture def sample_job_dependency(): """Create a sample JobDependency instance for testing.""" @@ -102,12 +117,14 @@ def setup_worker_db( session, sample_job_run, sample_pipeline, + sample_empty_pipeline, sample_job_dependency, sample_dependent_job_run, sample_independent_job_run, ): """Set up the database with sample data for worker tests.""" session.add(sample_pipeline) + session.add(sample_empty_pipeline) session.add(sample_job_run) session.add(sample_dependent_job_run) session.add(sample_independent_job_run) @@ -140,7 +157,30 @@ def async_context(): @pytest.fixture -def mock_job_run(): +def mock_pipeline(): + """Create a mock Pipeline instance. By default, + properties are identical to a default new Pipeline entered into the db + with sensible defaults for non-nullable but unset fields. + """ + return Mock( + spec=Pipeline, + id=1, + urn="test:pipeline:1", + name="Test Pipeline", + description="A test pipeline", + status=PipelineStatus.CREATED, + correlation_id="test_correlation_123", + metadata_={}, + created_at=datetime.now(), + started_at=None, + finished_at=None, + created_by_user_id=None, + mavedb_version=None, + ) + + +@pytest.fixture +def mock_job_run(mock_pipeline): """Create a mock JobRun instance. By default, properties are identical to a default new JobRun entered into the db with sensible defaults for non-nullable but unset fields. @@ -152,7 +192,7 @@ def mock_job_run(): job_type="test_job", job_function="test_function", status=JobStatus.PENDING, - pipeline_id=None, + pipeline_id=mock_pipeline.id, priority=0, max_retries=3, retry_count=0, @@ -188,4 +228,26 @@ def mock_job_manager(mock_job_run): manager.job_id = mock_job_run.id with patch.object(manager, "get_job", return_value=mock_job_run): + manager.job_id = 123 + + return manager + + +@pytest.fixture +def mock_pipeline_manager(mock_job_manager, mock_pipeline): + """Create a PipelineManager with mocked database, Redis dependencies, and job manager.""" + mock_db = Mock(spec=Session) + mock_redis = Mock(spec=ArqRedis) + + # Don't call the real constructor since it tries to validate the pipeline + manager = object.__new__(PipelineManager) + manager.db = mock_db + manager.redis = mock_redis + manager.pipeline_id = 123 + + with ( + patch("mavedb.worker.lib.managers.pipeline_manager.JobManager") as mock_job_manager_class, + patch.object(manager, "get_pipeline", return_value=mock_pipeline), + ): + mock_job_manager_class.return_value = mock_job_manager yield manager diff --git a/tests/worker/lib/managers/test_pipeline_manager.py b/tests/worker/lib/managers/test_pipeline_manager.py new file mode 100644 index 00000000..aedeffb3 --- /dev/null +++ b/tests/worker/lib/managers/test_pipeline_manager.py @@ -0,0 +1,3731 @@ +# ruff: noqa: E402 +""" +Comprehensive test suite for PipelineManager class. + +Tests cover all aspects of pipeline coordination, job dependency management, +status updates, error handling, and database interactions including new methods +for pipeline monitoring, job retry management, and restart functionality. +""" + +import pytest + +pytest.importorskip("arq") + +import datetime +from unittest.mock import Mock, PropertyMock, patch + +from arq import ArqRedis +from arq.jobs import Job as ArqJob +from sqlalchemy import select +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session + +from mavedb.models.enums.job_pipeline import DependencyType, JobStatus, PipelineStatus +from mavedb.models.job_dependency import JobDependency +from mavedb.models.job_run import JobRun +from mavedb.models.pipeline import Pipeline +from mavedb.worker.lib.managers import JobManager +from mavedb.worker.lib.managers.constants import ( + ACTIVE_JOB_STATUSES, + CANCELLED_PIPELINE_STATUSES, + RUNNING_PIPELINE_STATUSES, + TERMINAL_PIPELINE_STATUSES, +) +from mavedb.worker.lib.managers.exceptions import ( + DatabaseConnectionError, + PipelineCoordinationError, + PipelineStateError, + PipelineTransitionError, +) +from mavedb.worker.lib.managers.pipeline_manager import PipelineManager +from tests.helpers.transaction_spy import TransactionSpy + +HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION = ( + AttributeError("Mock attribute error"), + KeyError("Mock key error"), + TypeError("Mock type error"), + ValueError("Mock value error"), +) + + +@pytest.mark.integration +class TestPipelineManagerInitialization: + """Test PipelineManager initialization and setup.""" + + def test_init_with_valid_pipeline(self, session, arq_redis, setup_worker_db, sample_pipeline): + """Test successful initialization with valid pipeline ID.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + assert manager.db == session + assert manager.redis == arq_redis + assert manager.pipeline_id == sample_pipeline.id + + def test_init_with_invalid_pipeline_id(self, session, arq_redis): + """Test initialization failure with non-existent pipeline ID.""" + pipeline_id = 999 # Assuming this ID does not exist + with pytest.raises(DatabaseConnectionError, match=f"Failed to get pipeline {pipeline_id}"): + PipelineManager(session, arq_redis, pipeline_id) + + def test_init_with_database_error(self, session, arq_redis, setup_worker_db, sample_pipeline): + """Test initialization failure with database connection error.""" + pipeline_id = sample_pipeline.id + + with ( + TransactionSpy.mock_database_execution_failure(session), + pytest.raises(DatabaseConnectionError, match=f"Failed to get pipeline {pipeline_id}"), + ): + PipelineManager(session, arq_redis, pipeline_id) + + +@pytest.mark.unit +class TestStartPipelineUnit: + """Unit tests for starting a pipeline.""" + + @pytest.mark.asyncio + async def test_start_pipeline_successful(self, mock_pipeline_manager): + """Test successful pipeline start from CREATED state.""" + with ( + patch.object( + mock_pipeline_manager, + "get_pipeline", + return_value=Mock(spec=Pipeline, status=PipelineStatus.CREATED), + ), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.start_pipeline() + + mock_set_status.assert_called_once_with(PipelineStatus.RUNNING) + mock_coordinate.assert_called_once() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "current_status", + [status for status in PipelineStatus._member_map_.values() if status != PipelineStatus.CREATED], + ) + async def test_start_pipeline_non_created_state(self, mock_pipeline_manager, current_status): + """Test pipeline start failure when not in CREATED state.""" + with ( + patch.object( + mock_pipeline_manager, + "get_pipeline_status", + return_value=current_status, + ), + pytest.raises( + PipelineTransitionError, + match=f"Pipeline {mock_pipeline_manager.pipeline_id} is in state {current_status} and may not be started", + ), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + await mock_pipeline_manager.start_pipeline() + + mock_set_status.assert_not_called() + mock_coordinate.assert_not_called() + + +@pytest.mark.integration +class TestStartPipelineIntegration: + """Integration tests for starting a pipeline.""" + + @pytest.mark.asyncio + async def test_start_pipeline_successful( + self, session, arq_redis, setup_worker_db, sample_pipeline, sample_job_run + ): + """Test successful pipeline start from CREATED state.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with TransactionSpy.spy(session, expect_flush=True): + await manager.start_pipeline() + + # Commit the session to persist changes + session.commit() + + # Verify pipeline status is now RUNNING + pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + assert pipeline.status == PipelineStatus.RUNNING + + # Verify the initial job was queued + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + + # Verify the job was enqueued in Redis + jobs = await arq_redis.queued_jobs() + assert jobs[0].function == sample_job_run.job_function + + @pytest.mark.asyncio + async def test_start_pipeline_no_jobs(self, session, arq_redis, setup_worker_db, sample_empty_pipeline): + """Test pipeline start when there are no jobs in the pipeline.""" + manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id) + + with TransactionSpy.spy(session, expect_flush=True): + await manager.start_pipeline() + + # Commit the session to persist changes + session.commit() + + # Verify pipeline status is now SUCCEEDED since there are no jobs + pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_empty_pipeline.id)).scalar_one() + assert pipeline.status == PipelineStatus.SUCCEEDED + + # Verify no jobs were enqueued in Redis + jobs = await arq_redis.queued_jobs() + assert len(jobs) == 0 + + +@pytest.mark.unit +class TestCoordinatePipelineUnit: + """Unit tests for pipeline coordination logic.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "new_status", + CANCELLED_PIPELINE_STATUSES, + ) + async def test_coordinate_pipeline_cancels_remaining_jobs_status_transitions_to_cancellable( + self, + mock_pipeline_manager, + new_status, + ): + """Test that remaining jobs are cancelled if pipeline transitions to a cancelable status.""" + with ( + patch.object( + mock_pipeline_manager, "transition_pipeline_status", return_value=new_status + ) as mock_transition, + patch.object(mock_pipeline_manager, "cancel_remaining_jobs", return_value=None) as mock_cancel, + patch.object(mock_pipeline_manager, "enqueue_ready_jobs", return_value=None) as mock_enqueue, + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.coordinate_pipeline() + + mock_transition.assert_called_once() + mock_cancel.assert_called_once_with(reason="Pipeline failed or cancelled") + mock_enqueue.assert_not_called() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "new_status", + RUNNING_PIPELINE_STATUSES, + ) + async def test_coordinate_pipeline_enqueues_jobs_when_status_transitions_to_running( + self, mock_pipeline_manager, new_status + ): + """Test coordination after successful job completion.""" + with ( + patch.object( + mock_pipeline_manager, "transition_pipeline_status", return_value=new_status + ) as mock_transition, + patch.object(mock_pipeline_manager, "cancel_remaining_jobs", return_value=None) as mock_cancel, + patch.object(mock_pipeline_manager, "enqueue_ready_jobs", return_value=None) as mock_enqueue, + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.coordinate_pipeline() + + assert mock_transition.call_count == 2 # Called once before and once after enqueuing jobs + mock_cancel.assert_not_called() + mock_enqueue.assert_called_once() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "new_status", + [ + status + for status in PipelineStatus._member_map_.values() + if status not in CANCELLED_PIPELINE_STATUSES + RUNNING_PIPELINE_STATUSES + ], + ) + async def test_coordinate_pipeline_noop_for_other_status_transitions(self, mock_pipeline_manager, new_status): + """Test coordination no-op for non-cancelled/running status transitions.""" + with ( + patch.object( + mock_pipeline_manager, "transition_pipeline_status", return_value=new_status + ) as mock_transition, + patch.object(mock_pipeline_manager, "cancel_remaining_jobs", return_value=None) as mock_cancel, + patch.object(mock_pipeline_manager, "enqueue_ready_jobs", return_value=None) as mock_enqueue, + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.coordinate_pipeline() + + mock_transition.assert_called_once() + mock_cancel.assert_not_called() + mock_enqueue.assert_not_called() + + +@pytest.mark.integration +class TestCoordinatePipelineIntegration: + """Test pipeline coordination after job completion.""" + + @pytest.mark.asyncio + async def test_coordinate_pipeline_transitions_pipeline_to_failed_after_job_failure( + self, session, arq_redis, setup_worker_db, sample_pipeline, sample_job_run, sample_dependent_job_run + ): + """Test successful pipeline coordination and job enqueuing after job completion.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the job in the pipeline to a terminal status + sample_job_run.status = JobStatus.FAILED + session.commit() + + with ( + TransactionSpy.spy(session, expect_flush=True), + patch.object(manager, "cancel_remaining_jobs", wraps=manager.cancel_remaining_jobs) as mock_cancel, + patch.object(manager, "enqueue_ready_jobs", wraps=manager.enqueue_ready_jobs) as mock_enqueue, + ): + await manager.coordinate_pipeline() + + # Ensure no new jobs were enqueued but that jobs were cancelled + mock_cancel.assert_called_once() + mock_enqueue.assert_not_called() + + # Verify that the pipeline status is now FAILED + assert manager.get_pipeline().status == PipelineStatus.FAILED + + # Verify that the failed job remains failed + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.FAILED + + # Verify that the pending job transitions to skipped + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.SKIPPED + + @pytest.mark.asyncio + async def test_coordinate_pipeline_transitions_pipeline_to_cancelled_after_pipeline_is_cancelled( + self, session, arq_redis, setup_worker_db, sample_pipeline, sample_job_run, sample_dependent_job_run + ): + """Test successful pipeline coordination and job enqueuing after pipeline cancellation .""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to a cancelled status + manager.set_pipeline_status(PipelineStatus.CANCELLED) + session.commit() + + # Set the job in the pipeline to a running status + sample_job_run.status = JobStatus.RUNNING + session.commit() + + with ( + TransactionSpy.spy(session, expect_flush=True), + patch.object(manager, "cancel_remaining_jobs", wraps=manager.cancel_remaining_jobs) as mock_cancel, + patch.object(manager, "enqueue_ready_jobs", wraps=manager.enqueue_ready_jobs) as mock_enqueue, + ): + await manager.coordinate_pipeline() + + # Ensure no new jobs were enqueued but that jobs were cancelled + mock_cancel.assert_called_once() + mock_enqueue.assert_not_called() + + # Verify that the pipeline status is now CANCELLED + assert manager.get_pipeline().status == PipelineStatus.CANCELLED + + # Verify that the running job transitions to cancelled + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.CANCELLED + + # Verify that the pending dependent job transitions to skipped + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.SKIPPED + + @pytest.mark.asyncio + async def test_coordinate_running_pipeline_enqueues_ready_jobs( + self, session, arq_redis, setup_worker_db, sample_pipeline, sample_job_run, sample_dependent_job_run + ): + """Test successful pipeline coordination and job enqueuing when jobs are still pending.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to a running status + manager.set_pipeline_status(PipelineStatus.RUNNING) + session.commit() + + with ( + TransactionSpy.spy(session, expect_flush=True), + patch.object(manager, "cancel_remaining_jobs", wraps=manager.cancel_remaining_jobs) as mock_cancel, + patch.object(manager, "enqueue_ready_jobs", wraps=manager.enqueue_ready_jobs) as mock_enqueue, + ): + await manager.coordinate_pipeline() + + # Ensure no new jobs were cancelled but that jobs were enqueued + mock_cancel.assert_not_called() + mock_enqueue.assert_called_once() + + # Verify that the non-dependent job is now queued + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + + # Verify that the dependent job is still pending (since its dependency is not yet complete) + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "initial_status", + [PipelineStatus.CREATED, PipelineStatus.PAUSED, PipelineStatus.SUCCEEDED, PipelineStatus.PARTIAL], + ) + async def test_coordinate_pipeline_noop( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + initial_status, + ): + """Test successful pipeline coordination and job enqueuing when jobs are still pending.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to a cancelled status + manager.set_pipeline_status(initial_status) + session.commit() + + with ( + TransactionSpy.spy(session, expect_flush=True), + patch.object(manager, "cancel_remaining_jobs", wraps=manager.cancel_remaining_jobs) as mock_cancel, + patch.object(manager, "enqueue_ready_jobs", wraps=manager.enqueue_ready_jobs) as mock_enqueue, + ): + await manager.coordinate_pipeline() + + # Ensure no new jobs were enqueued or cancelled + mock_cancel.assert_not_called() + mock_enqueue.assert_not_called() + + # Verify that the job is still pending + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + + # Verify that the dependent job is still pending + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + + +@pytest.mark.unit +class TestTransitionPipelineStatusUnit: + """Test pipeline status transition logic.""" + + @pytest.mark.parametrize( + "existing_status", + TERMINAL_PIPELINE_STATUSES, + ) + def test_terminal_state_results_in_retention_of_terminal_states( + self, mock_pipeline_manager, existing_status, mock_pipeline + ): + """No jobs in pipeline should result in no status change, so long as the pipeline is in a terminal state.""" + mock_pipeline.status = existing_status + + with ( + patch.object(mock_pipeline_manager, "get_job_counts_by_status", return_value={}), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + result = mock_pipeline_manager.transition_pipeline_status() + assert result is existing_status + + mock_set_status.assert_not_called() + + def test_paused_state_results_in_retention_of_paused_state(self, mock_pipeline_manager, mock_pipeline): + """No jobs in pipeline should result in no status change when pipeline is paused.""" + mock_pipeline.status = PipelineStatus.PAUSED + + with ( + patch.object(mock_pipeline_manager, "get_job_counts_by_status", return_value={}), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + result = mock_pipeline_manager.transition_pipeline_status() + assert result is PipelineStatus.PAUSED + + mock_set_status.assert_not_called() + + @pytest.mark.parametrize( + "existing_status", + [ + status + for status in PipelineStatus._member_map_.values() + if status not in TERMINAL_PIPELINE_STATUSES + [PipelineStatus.PAUSED] + ], + ) + def test_no_jobs_results_in_succeeded_state_if_not_terminal( + self, mock_pipeline_manager, existing_status, mock_pipeline + ): + """No jobs in pipeline should result in SUCCEEDED state if not already terminal.""" + mock_pipeline.status = existing_status + mock_pipeline.finished_at = None + with ( + patch.object(mock_pipeline_manager, "get_job_counts_by_status", return_value={}), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + result = mock_pipeline_manager.transition_pipeline_status() + assert result == PipelineStatus.SUCCEEDED + + mock_set_status.assert_called_once_with(PipelineStatus.SUCCEEDED) + + @pytest.mark.parametrize( + "job_counts,expected_status", + [ + # Any failure trumps everything + ({JobStatus.SUCCEEDED: 10, JobStatus.FAILED: 1}, PipelineStatus.FAILED), + # Running or queued jobs without failures keep pipeline running + ({JobStatus.SUCCEEDED: 5, JobStatus.FAILED: 0, JobStatus.RUNNING: 2}, PipelineStatus.RUNNING), + ({JobStatus.SUCCEEDED: 5, JobStatus.FAILED: 0, JobStatus.QUEUED: 3}, PipelineStatus.RUNNING), + # All succeeded + ({JobStatus.SUCCEEDED: 5}, PipelineStatus.SUCCEEDED), + # Mix of terminal states without failures + ({JobStatus.SUCCEEDED: 3, JobStatus.SKIPPED: 2}, PipelineStatus.PARTIAL), + ({JobStatus.SUCCEEDED: 1, JobStatus.CANCELLED: 1}, PipelineStatus.PARTIAL), + # All cancelled + ({JobStatus.CANCELLED: 5}, PipelineStatus.CANCELLED), + # All skipped + ({JobStatus.SKIPPED: 4}, PipelineStatus.CANCELLED), + # Some cancelled and skipped + ({JobStatus.CANCELLED: 2, JobStatus.SKIPPED: 3}, PipelineStatus.CANCELLED), + # Inconsistent state + ({JobStatus.CANCELLED: 2, JobStatus.SKIPPED: 1, JobStatus.SUCCEEDED: 1, None: 3}, PipelineStatus.PARTIAL), + ], + ) + def test_pipeline_status_determination_based_on_job_counts( + self, mock_pipeline_manager, job_counts, expected_status, mock_pipeline + ): + """Test pipeline status determination based on job counts.""" + mock_pipeline.status = PipelineStatus.CREATED + mock_pipeline.finished_at = None + + with ( + patch.object(mock_pipeline_manager, "get_job_counts_by_status", return_value=job_counts), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + result = mock_pipeline_manager.transition_pipeline_status() + assert result == expected_status + + mock_set_status.assert_called_once_with(expected_status) + + @pytest.mark.parametrize( + "job_counts,existing_status", + [ + ({JobStatus.PENDING: 5}, PipelineStatus.CREATED), + ({JobStatus.SUCCEEDED: 5, JobStatus.PENDING: 3}, PipelineStatus.RUNNING), + ({JobStatus.PENDING: 2, JobStatus.SKIPPED: 4}, PipelineStatus.RUNNING), + ({JobStatus.PENDING: 1, JobStatus.CANCELLED: 1}, PipelineStatus.RUNNING), + ], + ) + def test_pipeline_status_determination_pending_jobs_do_not_change_status( + self, mock_pipeline_manager, job_counts, existing_status, mock_pipeline + ): + """Test that presence of pending jobs does not change pipeline status.""" + mock_pipeline.status = existing_status + + with ( + patch.object( + mock_pipeline_manager, + "get_job_counts_by_status", + return_value=job_counts, + ), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + result = mock_pipeline_manager.transition_pipeline_status() + assert result == existing_status + + mock_set_status.assert_not_called() + + @pytest.mark.parametrize( + "exception", + HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION, + ) + def test_pipeline_status_determination_throws_state_error_for_handled_exceptions( + self, mock_pipeline_manager, exception + ): + """Test that handled exceptions during status determination raise PipelineStateError.""" + + # Mocks exception in first try/except + with ( + patch.object( + mock_pipeline_manager, + "get_job_counts_by_status", + return_value=Mock(side_effect=exception), + ), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + TransactionSpy.spy(mock_pipeline_manager.db), + pytest.raises(PipelineStateError), + ): + mock_pipeline_manager.transition_pipeline_status() + mock_set_status.assert_not_called() + + # Mocks exception in second try/except + with ( + patch.object( + mock_pipeline_manager, + "get_job_counts_by_status", + return_value={JobStatus.SUCCEEDED: 5}, + ), + patch.object(mock_pipeline_manager, "set_pipeline_status", side_effect=exception) as mock_set_status, + patch.object( + mock_pipeline_manager, "get_pipeline", return_value=Mock(spec=Pipeline, status=PipelineStatus.CREATED) + ), + pytest.raises(PipelineStateError), + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.transition_pipeline_status() + + def test_pipeline_status_determination_no_change(self, mock_pipeline_manager, mock_pipeline): + """Test that no status change occurs if pipeline status remains the same.""" + mock_pipeline.status = PipelineStatus.SUCCEEDED + with ( + patch.object(mock_pipeline_manager, "get_job_counts_by_status", return_value={JobStatus.SUCCEEDED: 5}), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + result = mock_pipeline_manager.transition_pipeline_status() + assert result == PipelineStatus.SUCCEEDED + + mock_set_status.assert_not_called() + + +class TestTransitionPipelineStatusIntegration: + """Integration tests for pipeline status transition logic.""" + + @pytest.mark.parametrize( + "initial_status", + TERMINAL_PIPELINE_STATUSES, + ) + def test_pipeline_status_transition_noop_when_status_is_terminal( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + initial_status, + ): + """Test that pipeline status remains unchanged when already in a terminal state.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set initial pipeline status + manager.set_pipeline_status(initial_status) + session.commit() + + with TransactionSpy.spy(session): + new_status = manager.transition_pipeline_status() + + # Commit the transaction + session.commit() + + # Verify that the pipeline status remains unchanged + assert new_status == initial_status + assert manager.get_pipeline_status() == initial_status + + def test_pipeline_status_transition_noop_when_status_is_paused( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + ): + """Test that pipeline status remains unchanged when in PAUSED state.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set initial pipeline status to PAUSED + manager.set_pipeline_status(PipelineStatus.PAUSED) + session.commit() + + with TransactionSpy.spy(session): + new_status = manager.transition_pipeline_status() + + # Commit the transaction + session.commit() + + # Verify that the pipeline status remains unchanged + assert new_status == PipelineStatus.PAUSED + assert manager.get_pipeline_status() == PipelineStatus.PAUSED + + @pytest.mark.parametrize( + "initial_status,expected_status", + [ + ( + status, + status if status in TERMINAL_PIPELINE_STATUSES + [PipelineStatus.PAUSED] else PipelineStatus.SUCCEEDED, + ) + for status in PipelineStatus._member_map_.values() + ], + ) + def test_pipeline_status_transition_when_no_jobs_in_pipeline( + self, + session, + arq_redis, + setup_worker_db, + initial_status, + expected_status, + sample_empty_pipeline, + ): + """Test that pipeline status transitions to SUCCEEDED when there are no jobs in a + non-terminal pipeline. If the pipeline is already in a terminal state, it should remain unchanged.""" + manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id) + + # Set initial pipeline status + manager.set_pipeline_status(initial_status) + session.commit() + + with TransactionSpy.spy(session): + new_status = manager.transition_pipeline_status() + + # Commit the transaction + session.commit() + + # Verify that the pipeline status is the expected status and that + # the status was persisted to the transaction + assert new_status == expected_status + assert manager.get_pipeline_status() == expected_status + + @pytest.mark.parametrize( + "initial_status,job_updates,expected_status", + [ + # Some failed -> failed + (PipelineStatus.CREATED, {1: JobStatus.SUCCEEDED, 2: JobStatus.FAILED}, PipelineStatus.FAILED), + # Some running -> running + (PipelineStatus.CREATED, {1: JobStatus.SUCCEEDED, 2: JobStatus.RUNNING}, PipelineStatus.RUNNING), + # Some queued -> running + (PipelineStatus.CREATED, {1: JobStatus.SUCCEEDED, 2: JobStatus.QUEUED}, PipelineStatus.RUNNING), + # Some pending => no change (handled separately via a second call to transition after enqueuing jobs) + (PipelineStatus.CREATED, {1: JobStatus.SUCCEEDED, 2: JobStatus.PENDING}, PipelineStatus.CREATED), + (PipelineStatus.RUNNING, {1: JobStatus.SUCCEEDED, 2: JobStatus.PENDING}, PipelineStatus.RUNNING), + # All succeeded -> succeeded + (PipelineStatus.CREATED, {1: JobStatus.SUCCEEDED, 2: JobStatus.SUCCEEDED}, PipelineStatus.SUCCEEDED), + # All cancelled -> cancelled + (PipelineStatus.RUNNING, {1: JobStatus.CANCELLED, 2: JobStatus.CANCELLED}, PipelineStatus.CANCELLED), + # Mix of succeeded and skipped -> partial + (PipelineStatus.CREATED, {1: JobStatus.SUCCEEDED, 2: JobStatus.SKIPPED}, PipelineStatus.PARTIAL), + # Mix of succeeded and cancelled -> partial + (PipelineStatus.CREATED, {1: JobStatus.SUCCEEDED, 2: JobStatus.CANCELLED}, PipelineStatus.PARTIAL), + # Mix of cancelled and skipped -> cancelled + (PipelineStatus.CREATED, {1: JobStatus.CANCELLED, 2: JobStatus.SKIPPED}, PipelineStatus.CANCELLED), + ], + ) + def test_pipeline_status_transitions( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + initial_status, + job_updates, + expected_status, + ): + """Test pipeline status transitions based on job status updates.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set initial pipeline status + manager.set_pipeline_status(initial_status) + session.commit() + + # Update job statuses as per test case + for job_run in sample_pipeline.job_runs: + if job_run.id in job_updates: + job_run.status = job_updates[job_run.id] + session.commit() + + # Perform status transition and verify return state + with TransactionSpy.spy(session): + new_status = manager.transition_pipeline_status() + assert new_status == expected_status + session.commit() + + # Verify expected pipeline status is persisted + pipeline = manager.get_pipeline() + assert pipeline.status == expected_status + + +@pytest.mark.unit +class TestEnqueueReadyJobsUnit: + """Test enqueuing of ready jobs (both independent and dependent).""" + + @pytest.mark.parametrize( + "pipeline_status", + [status for status in PipelineStatus._member_map_.values() if status not in RUNNING_PIPELINE_STATUSES], + ) + @pytest.mark.asyncio + async def test_enqueue_ready_jobs_raises_if_pipeline_not_running(self, mock_pipeline_manager, pipeline_status): + """Test that job enqueuing raises a state error if pipeline is not in RUNNING status.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=pipeline_status), + pytest.raises(PipelineStateError, match="cannot enqueue jobs"), + TransactionSpy.spy(mock_pipeline_manager.db), + ): + await mock_pipeline_manager.enqueue_ready_jobs() + + @pytest.mark.asyncio + async def test_enqueue_ready_jobs_skips_if_no_jobs(self, mock_pipeline_manager): + """Test that job enqueuing skips if there are no pending jobs.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.RUNNING), + patch.object( + mock_pipeline_manager, + "get_pending_jobs", + return_value=[], + ), + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.enqueue_ready_jobs() + # Should complete without error + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "should_skip", + [False, True], + ) + async def test_enqueue_ready_jobs_checks_if_jobs_are_reachable_if_cant_enqueue( + self, mock_pipeline_manager, mock_job_manager, should_skip + ): + """Test that job enqueuing skips jobs which are unreachable if any exist.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.RUNNING), + patch.object( + mock_pipeline_manager, "get_pending_jobs", return_value=[Mock(spec=JobRun, id=1, urn="test:job:1")] + ), + patch.object(mock_pipeline_manager, "can_enqueue_job", return_value=False), + patch.object( + mock_pipeline_manager, "should_skip_job_due_to_dependencies", return_value=(should_skip, "Reason") + ) as mock_should_skip, + patch.object(mock_job_manager, "skip_job", return_value=None) as mock_skip_job, + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.enqueue_ready_jobs() + + mock_should_skip.assert_called_once() + mock_skip_job.assert_called_once() if should_skip else mock_skip_job.assert_not_called() + + @pytest.mark.asyncio + async def test_enqueue_ready_jobs_raises_if_arq_enqueue_fails(self, mock_pipeline_manager, mock_job_manager): + """Test that job enqueuing raises an error if ARQ enqueue fails.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.RUNNING), + patch.object( + mock_pipeline_manager, "get_pending_jobs", return_value=[Mock(spec=JobRun, id=1, urn="test:job:1")] + ), + patch.object(mock_pipeline_manager, "can_enqueue_job", return_value=True), + patch.object(mock_job_manager, "prepare_queue", return_value=None) as mock_prepare_queue, + patch.object( + mock_pipeline_manager, "_enqueue_in_arq", side_effect=PipelineCoordinationError("ARQ enqueue failed") + ), + pytest.raises(PipelineCoordinationError, match="ARQ enqueue failed"), + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.enqueue_ready_jobs() + + mock_prepare_queue.assert_called_once() + + @pytest.mark.asyncio + async def test_enqueue_ready_jobs_successful_enqueue(self, mock_pipeline_manager, mock_job_manager): + """Test successful job enqueuing.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.RUNNING), + patch.object( + mock_pipeline_manager, "get_pending_jobs", return_value=[Mock(spec=JobRun, id=1, urn="test:job:1")] + ), + patch.object(mock_pipeline_manager, "can_enqueue_job", return_value=True), + patch.object(mock_pipeline_manager, "_enqueue_in_arq", return_value=None) as mock_enqueue, + patch.object(mock_job_manager, "prepare_queue", return_value=None) as mock_prepare_queue, + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.enqueue_ready_jobs() + + mock_prepare_queue.assert_called_once() + mock_enqueue.assert_called_once() + + +@pytest.mark.integration +class TestEnqueueReadyJobsIntegration: + """Integration tests for enqueuing of ready jobs.""" + + @pytest.mark.asyncio + async def test_enqueue_ready_jobs_integration( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test successful enqueuing of ready jobs in a pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to RUNNING status + manager.set_pipeline_status(PipelineStatus.RUNNING) + session.commit() + + with TransactionSpy.spy(session, expect_flush=True): + await manager.enqueue_ready_jobs() + + # Verify that the independent job is now queued + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + + # Verify that the dependent job is still pending (since its dependency is not yet complete) + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + + # Verify the queued ARQ job exists and is the job we expect + arq_job = await arq_redis.queued_jobs() + assert len(arq_job) == 1 + assert arq_job[0].function == sample_job_run.job_function + + # Verify the pipeline is still in RUNNING status + assert manager.get_pipeline_status() == PipelineStatus.RUNNING + + @pytest.mark.asyncio + async def test_enqueue_ready_jobs_integration_with_unreachable_job( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + sample_job_dependency, + ): + """Test enqueuing of ready jobs skips unreachable jobs in a pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to RUNNING status + manager.set_pipeline_status(PipelineStatus.RUNNING) + session.commit() + + # Make the dependent job unreachable by setting the sample_job to cancelled. + sample_job_run.status = JobStatus.CANCELLED + session.commit() + + with TransactionSpy.spy(session, expect_flush=True): + await manager.enqueue_ready_jobs() + + # Verify that the dependent job is marked as skipped + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.SKIPPED + + # Verify nothing was enqueued for the dependent job + arq_job = await arq_redis.queued_jobs() + assert len(arq_job) == 0 + + # Verify the pipeline is still in RUNNING status + assert manager.get_pipeline_status() == PipelineStatus.RUNNING + + @pytest.mark.asyncio + async def test_enqueue_ready_jobs_with_empty_pipeline( + self, session, arq_redis, setup_worker_db, sample_empty_pipeline + ): + """Test enqueuing of ready jobs in an empty pipeline.""" + manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id) + + # Set the pipeline to RUNNING status + manager.set_pipeline_status(PipelineStatus.RUNNING) + session.commit() + + with TransactionSpy.spy(session, expect_flush=True): + await manager.enqueue_ready_jobs() + + # Verify nothing was enqueued + arq_job = await arq_redis.queued_jobs() + assert len(arq_job) == 0 + + # Verify the pipeline is still in RUNNING status + assert manager.get_pipeline_status() == PipelineStatus.RUNNING + + @pytest.mark.asyncio + async def test_enqueue_ready_jobs_bubbles_pipeline_coordination_error_for_any_exception_during_enqueue( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + ): + """Test that any exception during job enqueuing raises PipelineCoordinationError.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to RUNNING status + manager.set_pipeline_status(PipelineStatus.RUNNING) + session.commit() + + with ( + TransactionSpy.spy(session, expect_flush=True), + patch.object( + manager.redis, + "enqueue_job", + side_effect=Exception("Unexpected error during enqueue"), + ), + pytest.raises(PipelineCoordinationError, match="Failed to enqueue job in ARQ"), + ): + await manager.enqueue_ready_jobs() + + +@pytest.mark.unit +class TestCancelRemainingJobsUnit: + """Test cancellation of remaining jobs.""" + + def test_cancel_remaining_jobs_no_active_jobs(self, mock_pipeline_manager, mock_job_manager): + """Test job cancellation when there are no active jobs.""" + with ( + patch.object( + mock_pipeline_manager, + "get_active_jobs", + return_value=[], + ), + patch.object(mock_job_manager, "cancel_job", return_value=None) as mock_cancel_job, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.cancel_remaining_jobs() + + mock_cancel_job.assert_not_called() + + @pytest.mark.parametrize( + "job_status, expected_status", + [(JobStatus.QUEUED, JobStatus.CANCELLED), (JobStatus.RUNNING, JobStatus.CANCELLED)], + ) + def test_cancel_remaining_jobs_cancels_queued_and_running_jobs( + self, mock_pipeline_manager, mock_job_manager, mock_job_run, job_status, expected_status + ): + """Test successful cancellation of remaining jobs.""" + mock_job_run.status = job_status + cancellation_result = {"status": expected_status, "reason": "Pipeline cancelled"} + + with ( + patch.object( + mock_pipeline_manager, + "get_active_jobs", + return_value=[mock_job_run], + ), + patch.object(mock_job_manager, "cancel_job", return_value=None) as mock_cancel_job, + patch( + "mavedb.worker.lib.managers.pipeline_manager.construct_bulk_cancellation_result", + return_value=cancellation_result, + ), + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.cancel_remaining_jobs() + + mock_cancel_job.assert_called_once_with(result=cancellation_result) + + @pytest.mark.parametrize( + "job_status, expected_status", + [ + (JobStatus.PENDING, JobStatus.SKIPPED), + ], + ) + def test_cancel_remaining_jobs_skips_pending_jobs( + self, mock_pipeline_manager, mock_job_manager, mock_job_run, job_status, expected_status + ): + """Test successful cancellation of remaining jobs.""" + mock_job_run.status = job_status + cancellation_result = {"status": expected_status, "reason": "Pipeline cancelled"} + + with ( + patch.object( + mock_pipeline_manager, + "get_active_jobs", + return_value=[mock_job_run], + ), + patch.object(mock_job_manager, "skip_job", return_value=None) as mock_skip_job, + patch( + "mavedb.worker.lib.managers.pipeline_manager.construct_bulk_cancellation_result", + return_value=cancellation_result, + ), + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.cancel_remaining_jobs() + + mock_skip_job.assert_called_once_with(result=cancellation_result) + + +@pytest.mark.integration +class TestCancelRemainingJobsIntegration: + """Integration tests for cancellation of remaining jobs.""" + + def test_cancel_remaining_jobs_integration( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test successful cancellation of remaining jobs in a pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the job statuses + sample_job_run.status = JobStatus.RUNNING + sample_dependent_job_run.status = JobStatus.PENDING + session.commit() + + with ( + TransactionSpy.spy(session), + ): + manager.cancel_remaining_jobs() + + # Commit the transaction + session.commit() + + # Verify that the running job transitions to cancelled + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.CANCELLED + + # Verify that the pending dependent job transitions to skipped + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.SKIPPED + + def test_cancel_remaining_jobs_integration_no_active_jobs( + self, + session, + arq_redis, + setup_worker_db, + sample_empty_pipeline, + ): + """Test cancellation of remaining jobs when there are no active jobs.""" + manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + manager.cancel_remaining_jobs() + + # Commit the transaction + session.commit() + + # Should complete without error + + +@pytest.mark.unit +class TestCancelPipelineUnit: + """Test cancellation of pipelines.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "pipeline_status", + TERMINAL_PIPELINE_STATUSES, + ) + async def test_cancel_pipeline_raises_transition_error_if_already_in_terminal_status( + self, mock_pipeline_manager, pipeline_status + ): + """Test that pipeline cancellation raises an error if already in terminal status.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=pipeline_status), + pytest.raises( + PipelineTransitionError, + match=f"Pipeline {mock_pipeline_manager.pipeline_id} is in terminal state", + ), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + await mock_pipeline_manager.cancel_pipeline(reason="Testing cancellation") + + mock_set_status.assert_not_called() + mock_coordinate.assert_not_called() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "pipeline_status", + [status for status in PipelineStatus._member_map_.values() if status not in TERMINAL_PIPELINE_STATUSES], + ) + async def test_cancel_pipeline_successful_cancellation_if_not_in_terminal_status( + self, mock_pipeline_manager, pipeline_status + ): + """Test successful pipeline cancellation if not already in terminal status.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=pipeline_status), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.cancel_pipeline(reason="Testing cancellation") + + mock_coordinate.assert_called_once() + mock_set_status.assert_called_once_with(PipelineStatus.CANCELLED) + + +@pytest.mark.integration +class TestCancelPipelineIntegration: + """Integration tests for cancellation of pipelines.""" + + @pytest.mark.asyncio + async def test_cancel_pipeline_integration( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test successful cancellation of a pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to RUNNING status + manager.set_pipeline_status(PipelineStatus.RUNNING) + session.commit() + + # Set the job statuses + sample_job_run.status = JobStatus.RUNNING + sample_dependent_job_run.status = JobStatus.PENDING + session.commit() + + with ( + TransactionSpy.spy(session, expect_flush=True), + ): + await manager.cancel_pipeline(reason="Testing cancellation") + + # Commit the transaction + session.commit() + + # Verify that the pipeline is now in CANCELLED status + assert manager.get_pipeline_status() == PipelineStatus.CANCELLED + + # Verify that the running job transitions to cancelled + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.CANCELLED + + # Verify that the pending dependent job transitions to skipped + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.SKIPPED + + @pytest.mark.asyncio + async def test_cancel_pipeline_integration_already_terminal( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + ): + """Test that cancelling a pipeline already in terminal status raises an error.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to SUCCEEDED status + manager.set_pipeline_status(PipelineStatus.SUCCEEDED) + session.commit() + + # Set the job status to something that would normally be cancellable + sample_job_run.status = JobStatus.PENDING + session.commit() + + with ( + pytest.raises( + PipelineTransitionError, + match=f"Pipeline {manager.pipeline_id} is in terminal state", + ), + TransactionSpy.spy(session), + ): + await manager.cancel_pipeline(reason="Testing cancellation") + + # Commit the transaction + session.commit() + + # Verify the pipeline status remains SUCCEEDED + assert manager.get_pipeline_status() == PipelineStatus.SUCCEEDED + + # Verify that the job status remains unchanged + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + + +@pytest.mark.unit +class TestPausePipelineUnit: + """Test pausing of pipelines.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "pipeline_status", + TERMINAL_PIPELINE_STATUSES, + ) + async def test_pause_pipeline_raises_transition_error_if_already_in_terminal_status( + self, mock_pipeline_manager, pipeline_status + ): + """Test that pipeline pausing raises an error if already in terminal status.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=pipeline_status), + pytest.raises( + PipelineTransitionError, + match=f"Pipeline {mock_pipeline_manager.pipeline_id} is in terminal state", + ), + TransactionSpy.spy(mock_pipeline_manager.db), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + ): + await mock_pipeline_manager.pause_pipeline() + + mock_set_status.assert_not_called() + mock_coordinate.assert_not_called() + + @pytest.mark.asyncio + async def test_pause_pipeline_raises_transition_error_if_already_paused(self, mock_pipeline_manager): + """Test that pipeline pausing raises an error if already paused.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.PAUSED), + pytest.raises( + PipelineTransitionError, + match=f"Pipeline {mock_pipeline_manager.pipeline_id} is already paused", + ), + TransactionSpy.spy(mock_pipeline_manager.db), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + ): + await mock_pipeline_manager.pause_pipeline() + + mock_set_status.assert_not_called() + mock_coordinate.assert_not_called() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "pipeline_status", + [ + status + for status in PipelineStatus._member_map_.values() + if status not in TERMINAL_PIPELINE_STATUSES and status != PipelineStatus.PAUSED + ], + ) + async def test_pause_pipeline_successful_pausing_if_not_in_terminal_status( + self, mock_pipeline_manager, pipeline_status + ): + """Test successful pipeline pausing if not already in terminal status.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=pipeline_status), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.pause_pipeline() + + mock_coordinate.assert_called_once() + mock_set_status.assert_called_once_with(PipelineStatus.PAUSED) + + +@pytest.mark.integration +class TestPausePipelineIntegration: + """Integration tests for pausing of pipelines.""" + + @pytest.mark.asyncio + async def test_pause_pipeline_integration( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + ): + """Test successful pausing of a pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to RUNNING status + manager.set_pipeline_status(PipelineStatus.RUNNING) + session.commit() + + with ( + TransactionSpy.spy(session, expect_flush=True), + ): + await manager.pause_pipeline() + + # Commit the transaction + session.commit() + + # Verify that the pipeline is now in PAUSED status + assert manager.get_pipeline_status() == PipelineStatus.PAUSED + + # Verify that all jobs remain in their original statuses + # (coordinate_pipeline is called by pause_pipeline but should not change job statuses + # while paused). + for job_run in sample_pipeline.job_runs: + assert job_run.status == JobStatus.PENDING + + +@pytest.mark.unit +class TestUnpausePipelineUnit: + """Test unpausing of pipelines.""" + + @pytest.mark.asyncio + async def test_unpause_pipeline_raises_transition_error_if_not_paused(self, mock_pipeline_manager): + """Test that pipeline unpausing raises an error if not currently paused.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.RUNNING), + pytest.raises( + PipelineTransitionError, + match=f"Pipeline {mock_pipeline_manager.pipeline_id} is not paused", + ), + TransactionSpy.spy(mock_pipeline_manager.db), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + ): + await mock_pipeline_manager.unpause_pipeline() + + mock_set_status.assert_not_called() + mock_coordinate.assert_not_called() + + @pytest.mark.asyncio + async def test_unpause_pipeline_successful_unpausing_if_currently_paused(self, mock_pipeline_manager): + """Test successful pipeline unpausing if currently paused.""" + with ( + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.PAUSED), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.unpause_pipeline() + + mock_coordinate.assert_called_once() + mock_set_status.assert_called_once_with(PipelineStatus.RUNNING) + + +@pytest.mark.integration +class TestUnpausePipelineIntegration: + """Integration tests for unpausing of pipelines.""" + + @pytest.mark.asyncio + async def test_unpause_pipeline_integration( + self, session, arq_redis, setup_worker_db, sample_pipeline, sample_job_run, sample_dependent_job_run + ): + """Test successful unpausing of a pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to PAUSED status + manager.set_pipeline_status(PipelineStatus.PAUSED) + session.commit() + + with ( + TransactionSpy.spy(session, expect_flush=True), + ): + await manager.unpause_pipeline() + + # Commit the transaction + session.commit() + + # Verify that the pipeline is now in RUNNING status + assert manager.get_pipeline_status() == PipelineStatus.RUNNING + + # Verify that the non-dependent job was queued + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + + +@pytest.mark.unit +class TestRestartPipelineUnit: + """Test restarting of pipelines.""" + + @pytest.mark.asyncio + async def test_restart_pipeline_skips_if_no_jobs_in_pipeline(self, mock_pipeline_manager): + """Test that pipeline restart skips if there are no jobs in the pipeline.""" + with ( + patch.object( + mock_pipeline_manager, + "get_all_jobs", + return_value=[], + ), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + await mock_pipeline_manager.restart_pipeline() + + mock_set_status.assert_not_called() + mock_coordinate.assert_not_called() + + @pytest.mark.asyncio + async def test_restart_pipeline_successful_restart(self, mock_pipeline_manager, mock_job_manager): + """Test successful pipeline restart.""" + with ( + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "start_pipeline", return_value=None) as mock_start_pipeline, + patch.object( + mock_pipeline_manager, + "get_all_jobs", + return_value=[Mock(spec=JobRun, id=1), Mock(spec=JobRun, id=2)], + ), + patch.object( + mock_job_manager, + "reset_job", + return_value=None, + ) as mock_reset_job, + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.restart_pipeline() + + assert mock_reset_job.call_count == 2 + mock_set_status.assert_called_once_with(PipelineStatus.CREATED) + mock_start_pipeline.assert_called_once() + + +@pytest.mark.integration +class TestRestartPipelineIntegration: + """Integration tests for restarting of pipelines.""" + + @pytest.mark.asyncio + async def test_restart_pipeline_integration( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test successful restarting of a pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the job statuses to terminal states + sample_job_run.status = JobStatus.SUCCEEDED + sample_dependent_job_run.status = JobStatus.FAILED + session.commit() + + with ( + TransactionSpy.spy(session, expect_flush=True), + ): + await manager.restart_pipeline() + + # Commit the transaction + session.commit() + + # Verify that the pipeline is now in RUNNING status + assert manager.get_pipeline_status() == PipelineStatus.RUNNING + + # Verify that the non-dependent job is now queued + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + + # Verify that the dependent job is now pending + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + + @pytest.mark.asyncio + async def test_restart_pipeline_integration_skips_if_no_jobs( + self, + session, + arq_redis, + setup_worker_db, + sample_empty_pipeline, + ): + """Test that restarting a pipeline with no jobs skips without error.""" + manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id) + + # Set the pipeline to a terminal status + manager.set_pipeline_status(PipelineStatus.SUCCEEDED) + session.commit() + + with ( + TransactionSpy.spy(session), + ): + await manager.restart_pipeline() + + # Commit the transaction + session.commit() + + # Verify that the pipeline status remains unchanged + assert manager.get_pipeline_status() == PipelineStatus.SUCCEEDED + + +@pytest.mark.unit +class TestCanEnqueueJobUnit: + """Test job dependency checking.""" + + def test_can_enqueue_job_with_no_dependencies(self, mock_pipeline_manager): + """Test that a job with no dependencies can be enqueued.""" + mock_job = Mock(spec=JobRun, id=1) + + with ( + patch.object( + mock_pipeline_manager, + "get_dependencies_for_job", + return_value=[], + ), + TransactionSpy.spy(mock_pipeline_manager.db), + ): + result = mock_pipeline_manager.can_enqueue_job(mock_job) + + assert result is True + + def test_cannot_enqueue_job_with_unmet_dependencies(self, mock_pipeline_manager): + """Test that a job with unmet dependencies cannot be enqueued.""" + mock_job = Mock(spec=JobRun, id=1, status=JobStatus.PENDING) + mock_dependency = Mock(spec=JobDependency, dependency_type=DependencyType.COMPLETION_REQUIRED) + + with ( + patch.object( + mock_pipeline_manager, + "get_dependencies_for_job", + return_value=[(mock_dependency, mock_job)], + ), + patch( + "mavedb.worker.lib.managers.pipeline_manager.job_dependency_is_met", return_value=False + ) as mock_job_dependency_is_met, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + result = mock_pipeline_manager.can_enqueue_job(mock_job) + + mock_job_dependency_is_met.assert_called_once_with( + dependency_type=DependencyType.COMPLETION_REQUIRED, dependent_job_status=JobStatus.PENDING + ) + assert result is False + + def test_can_enqueue_job_with_met_dependencies(self, mock_pipeline_manager): + """Test that a job with met dependencies can be enqueued.""" + mock_job = Mock(spec=JobRun, id=1, status=JobStatus.SUCCEEDED) + mock_dependency = Mock(spec=JobDependency, dependency_type=DependencyType.COMPLETION_REQUIRED) + + with ( + patch.object( + mock_pipeline_manager, + "get_dependencies_for_job", + return_value=[(mock_dependency, mock_job)], + ), + patch( + "mavedb.worker.lib.managers.pipeline_manager.job_dependency_is_met", return_value=True + ) as mock_job_dependency_is_met, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + result = mock_pipeline_manager.can_enqueue_job(mock_job) + + mock_job_dependency_is_met.assert_called_once_with( + dependency_type=DependencyType.COMPLETION_REQUIRED, dependent_job_status=JobStatus.SUCCEEDED + ) + assert result is True + + @pytest.mark.parametrize( + "exception", + HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION, + ) + def test_can_enqueue_job_raises_pipeline_state_error_on_handled_exceptions(self, mock_pipeline_manager, exception): + """Test that handled exceptions during dependency checking raise PipelineStateError.""" + mock_job = Mock(spec=JobRun, id=1, status=JobStatus.SUCCEEDED) + mock_dependency = Mock(spec=JobDependency, dependency_type=DependencyType.COMPLETION_REQUIRED) + + with ( + patch.object( + mock_pipeline_manager, + "get_dependencies_for_job", + return_value=[(mock_dependency, mock_job)], + ), + patch("mavedb.worker.lib.managers.pipeline_manager.job_dependency_is_met", side_effect=exception), + pytest.raises(PipelineStateError, match="Corrupted dependency data"), + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.can_enqueue_job(mock_job) + + +@pytest.mark.integration +class TestCanEnqueueJobIntegration: + """Integration tests for job dependency checking.""" + + def test_can_enqueue_job_integration_with_no_dependencies( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + ): + """Test that a job with no dependencies can be enqueued.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + result = manager.can_enqueue_job(sample_job_run) + + assert result is True + + def test_can_enqueue_job_integration_with_unmet_dependencies( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_dependent_job_run, + ): + """Test that a job with unmet dependencies cannot be enqueued.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + result = manager.can_enqueue_job(sample_dependent_job_run) + + assert result is False + + def test_can_enqueue_job_integration_with_met_dependencies( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test that a job with met dependencies can be enqueued.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the dependency job to a succeeded status + sample_job_run.status = JobStatus.SUCCEEDED + session.commit() + + with ( + TransactionSpy.spy(session), + ): + result = manager.can_enqueue_job(sample_dependent_job_run) + + assert result is True + + +@pytest.mark.unit +class TestShouldSkipJobDueToDependenciesUnit: + """Test job skipping due to unmet dependencies.""" + + def test_should_not_skip_job_with_no_dependencies(self, mock_pipeline_manager): + """Test that a job with no dependencies should not be skipped.""" + mock_job = Mock(spec=JobRun, id=1) + + with ( + patch.object( + mock_pipeline_manager, + "get_dependencies_for_job", + return_value=[], + ), + patch( + "mavedb.worker.lib.managers.pipeline_manager.job_should_be_skipped_due_to_unfulfillable_dependency", + return_value=(False, ""), + ) as mock_job_should_be_skipped, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + should_skip, reason = mock_pipeline_manager.should_skip_job_due_to_dependencies(mock_job) + + mock_job_should_be_skipped.assert_not_called() + assert should_skip is False + assert reason == "" + + def test_should_skip_job_with_unreachable_dependency(self, mock_pipeline_manager): + """Test that a job with unreachable dependencies should be skipped.""" + mock_job = Mock(spec=JobRun, id=1, status=JobStatus.FAILED) + mock_dependency = Mock(spec=JobDependency, dependency_type=DependencyType.SUCCESS_REQUIRED) + + with ( + patch.object( + mock_pipeline_manager, + "get_dependencies_for_job", + return_value=[(mock_dependency, mock_job)], + ), + patch( + "mavedb.worker.lib.managers.pipeline_manager.job_should_be_skipped_due_to_unfulfillable_dependency", + return_value=(True, "Unfulfillable dependency detected"), + ) as mock_job_should_be_skipped, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + should_skip, reason = mock_pipeline_manager.should_skip_job_due_to_dependencies(mock_job) + + mock_job_should_be_skipped.assert_called_once_with( + dependency_type=DependencyType.SUCCESS_REQUIRED, dependent_job_status=JobStatus.FAILED + ) + assert should_skip is True + assert reason == "Unfulfillable dependency detected" + + def test_should_not_skip_job_with_reachable(self, mock_pipeline_manager): + """Test that a job with met dependencies can be enqueued.""" + mock_job = Mock(spec=JobRun, id=1, status=JobStatus.SUCCEEDED) + mock_dependency = Mock(spec=JobDependency, dependency_type=DependencyType.COMPLETION_REQUIRED) + + with ( + patch.object( + mock_pipeline_manager, + "get_dependencies_for_job", + return_value=[(mock_dependency, mock_job)], + ), + patch( + "mavedb.worker.lib.managers.pipeline_manager.job_should_be_skipped_due_to_unfulfillable_dependency", + return_value=(False, ""), + ) as mock_job_should_be_skipped, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + should_skip, reason = mock_pipeline_manager.should_skip_job_due_to_dependencies(mock_job) + mock_job_should_be_skipped.assert_called_once_with( + dependency_type=DependencyType.COMPLETION_REQUIRED, dependent_job_status=JobStatus.SUCCEEDED + ) + assert should_skip is False + assert reason == "" + + @pytest.mark.parametrize( + "exception", + HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION, + ) + def test_should_skip_job_due_to_dependencies_raises_pipeline_state_error_on_handled_exceptions( + self, mock_pipeline_manager, exception + ): + """Test that handled exceptions during dependency checking raise PipelineStateError.""" + mock_job = Mock(spec=JobRun, id=1, status=JobStatus.SUCCEEDED) + mock_dependency = Mock(spec=JobDependency, dependency_type=DependencyType.COMPLETION_REQUIRED) + + with ( + patch.object( + mock_pipeline_manager, + "get_dependencies_for_job", + return_value=[(mock_dependency, mock_job)], + ), + patch( + "mavedb.worker.lib.managers.pipeline_manager.job_should_be_skipped_due_to_unfulfillable_dependency", + side_effect=exception, + ), + pytest.raises(PipelineStateError, match="Corrupted dependency data"), + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.should_skip_job_due_to_dependencies(mock_job) + + +@pytest.mark.integration +class TestShouldSkipJobDueToDependenciesIntegration: + """Integration tests for job skipping due to unmet dependencies.""" + + def test_should_not_skip_job_with_no_dependencies( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + ): + """Test that a job with no dependencies should not be skipped.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + should_skip, reason = manager.should_skip_job_due_to_dependencies(sample_job_run) + + assert should_skip is False + assert reason == "" + + def test_should_skip_job_with_unreachable_dependency( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test that a job with unreachable dependencies should be skipped.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the job the dependency depends on to a failed status + sample_job_run.status = JobStatus.FAILED + session.commit() + + with ( + TransactionSpy.spy(session), + ): + should_skip, reason = manager.should_skip_job_due_to_dependencies(sample_dependent_job_run) + + assert should_skip is True + assert reason == "Dependency did not succeed (failed)" + + def test_should_not_skip_job_with_reachable_dependency( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test that a job with met dependencies can be enqueued.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the dependency job to a succeeded status + sample_job_run.status = JobStatus.PENDING + session.commit() + + with ( + TransactionSpy.spy(session), + ): + should_skip, reason = manager.should_skip_job_due_to_dependencies(sample_dependent_job_run) + + assert should_skip is False + assert reason == "" + + +@pytest.mark.unit +class TestRetryFailedJobsUnit: + """Test retrying of failed jobs.""" + + @pytest.mark.asyncio + async def test_retry_failed_jobs_no_failed_jobs(self, mock_pipeline_manager, mock_job_manager): + """Test that retrying failed jobs skips if there are no failed jobs.""" + with ( + patch.object( + mock_pipeline_manager, + "get_failed_jobs", + return_value=[], + ), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + patch.object(mock_job_manager, "prepare_retry", return_value=None) as mock_prepare_retry, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + await mock_pipeline_manager.retry_failed_jobs() + + mock_prepare_retry.assert_not_called() + mock_set_status.assert_not_called() + mock_coordinate.assert_not_called() + + @pytest.mark.asyncio + async def test_retry_failed_jobs_successful_retry(self, mock_pipeline_manager, mock_job_manager): + """Test successful retrying of failed jobs.""" + mock_failed_job1 = Mock(spec=JobRun, id=1) + mock_failed_job2 = Mock(spec=JobRun, id=2) + + with ( + patch.object( + mock_pipeline_manager, + "get_failed_jobs", + return_value=[mock_failed_job1, mock_failed_job2], + ), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + patch.object( + mock_job_manager, + "prepare_retry", + return_value=None, + ) as mock_prepare_retry, + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.retry_failed_jobs() + + assert mock_prepare_retry.call_count == 2 + mock_set_status.assert_called_once_with(PipelineStatus.RUNNING) + mock_coordinate.assert_called_once() + + +@pytest.mark.integration +class TestRetryFailedJobsIntegration: + """Integration tests for retrying of failed jobs.""" + + @pytest.mark.asyncio + async def test_retry_failed_jobs_integration( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test successful retrying of failed jobs in a pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to RUNNING status + manager.set_pipeline_status(PipelineStatus.RUNNING) + session.commit() + + # Set the job statuses + sample_job_run.status = JobStatus.FAILED + sample_dependent_job_run.status = JobStatus.PENDING + session.commit() + + with ( + TransactionSpy.spy(session, expect_flush=True), + ): + await manager.retry_failed_jobs() + + # Commit the transaction + session.commit() + + # Verify that the pipeline is now in RUNNING status + assert manager.get_pipeline_status() == PipelineStatus.RUNNING + + # Verify that the failed job is now queued + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + + # Verify that the dependent job is still pending + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + + @pytest.mark.asyncio + async def test_retry_failed_jobs_integration_no_failed_jobs( + self, + session, + arq_redis, + setup_worker_db, + sample_empty_pipeline, + ): + """Test that retrying failed jobs skips if there are no failed jobs.""" + manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id) + + # Set the pipeline to RUNNING status + manager.set_pipeline_status(PipelineStatus.RUNNING) + session.commit() + + with ( + TransactionSpy.spy(session), + ): + await manager.retry_failed_jobs() + + # Commit the transaction + session.commit() + + # Verify that the pipeline status is not changed + assert manager.get_pipeline_status() == PipelineStatus.RUNNING + + +@pytest.mark.unit +class TestRetryUnsuccessfulJobsUnit: + """Test retrying of unsuccessful jobs.""" + + @pytest.mark.asyncio + async def test_retry_unsuccessful_jobs_no_unsuccessful_jobs(self, mock_pipeline_manager, mock_job_manager): + """Test that retrying unsuccessful jobs skips if there are no unsuccessful jobs.""" + with ( + patch.object( + mock_pipeline_manager, + "get_unsuccessful_jobs", + return_value=[], + ), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + patch.object(mock_job_manager, "prepare_retry", return_value=None) as mock_prepare_retry, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + await mock_pipeline_manager.retry_unsuccessful_jobs() + + mock_prepare_retry.assert_not_called() + mock_set_status.assert_not_called() + mock_coordinate.assert_not_called() + + @pytest.mark.asyncio + async def test_retry_failed_jobs_successful_retry(self, mock_pipeline_manager, mock_job_manager): + """Test successful retrying of failed jobs.""" + mock_failed_job1 = Mock(spec=JobRun, id=1) + mock_failed_job2 = Mock(spec=JobRun, id=2) + + with ( + patch.object( + mock_pipeline_manager, + "get_unsuccessful_jobs", + return_value=[mock_failed_job1, mock_failed_job2], + ), + patch.object(mock_pipeline_manager, "set_pipeline_status", return_value=None) as mock_set_status, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, + patch.object( + mock_job_manager, + "prepare_retry", + return_value=None, + ) as mock_prepare_retry, + TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), + ): + await mock_pipeline_manager.retry_unsuccessful_jobs() + + assert mock_prepare_retry.call_count == 2 + mock_set_status.assert_called_once_with(PipelineStatus.RUNNING) + mock_coordinate.assert_called_once() + + +@pytest.mark.integration +class TestRetryUnsuccessfulJobsIntegration: + """Integration tests for retrying of unsuccessful jobs.""" + + @pytest.mark.asyncio + async def test_retry_unsuccessful_jobs_integration( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test successful retrying of unsuccessful jobs in a pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to RUNNING status + manager.set_pipeline_status(PipelineStatus.RUNNING) + session.commit() + + # Set the job statuses + sample_job_run.status = JobStatus.FAILED + sample_dependent_job_run.status = JobStatus.CANCELLED + session.commit() + + with ( + TransactionSpy.spy(session, expect_flush=True), + ): + await manager.retry_unsuccessful_jobs() + + # Commit the transaction + session.commit() + + # Verify that the pipeline is now in RUNNING status + assert manager.get_pipeline_status() == PipelineStatus.RUNNING + + # Verify that the failed job is now queued + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + + # Verify that the cancelled dependent job is now queued + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + + @pytest.mark.asyncio + async def test_retry_unsuccessful_jobs_integration_no_unsuccessful_jobs( + self, + session, + arq_redis, + setup_worker_db, + sample_empty_pipeline, + ): + """Test that retrying unsuccessful jobs skips if there are no unsuccessful jobs.""" + manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id) + + # Set the pipeline to RUNNING status + manager.set_pipeline_status(PipelineStatus.RUNNING) + session.commit() + + with ( + TransactionSpy.spy(session), + ): + await manager.retry_unsuccessful_jobs() + + # Commit the transaction + session.commit() + + # Verify that the pipeline status is not changed + assert manager.get_pipeline_status() == PipelineStatus.RUNNING + + +@pytest.mark.unit +class TestRetryPipelineUnit: + """Test retrying of entire pipelines.""" + + @pytest.mark.asyncio + async def test_retry_pipeline_calls_retry_unsuccessful_jobs(self, mock_pipeline_manager, mock_job_manager): + """Test that retrying a pipeline calls retrying unsuccessful jobs.""" + with ( + patch.object( + mock_pipeline_manager, + "retry_unsuccessful_jobs", + return_value=None, + ) as mock_retry_unsuccessful_jobs, + TransactionSpy.spy(mock_pipeline_manager.db), # flush is handled in retry_unsuccessful_jobs, which we mock + ): + await mock_pipeline_manager.retry_pipeline() + + mock_retry_unsuccessful_jobs.assert_called_once() + + +@pytest.mark.integration +class TestRetryPipelineIntegration: + """Integration tests for retrying of entire pipelines.""" + + @pytest.mark.asyncio + async def test_retry_pipeline_integration( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test successful retrying of an entire pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set the pipeline to RUNNING status + manager.set_pipeline_status(PipelineStatus.RUNNING) + session.commit() + + # Set the job statuses + sample_job_run.status = JobStatus.CANCELLED + sample_dependent_job_run.status = JobStatus.SKIPPED + session.commit() + + with ( + TransactionSpy.spy(session, expect_flush=True), + ): + await manager.retry_pipeline() + + # Commit the transaction + session.commit() + + # Verify that the pipeline is now in RUNNING status + assert manager.get_pipeline_status() == PipelineStatus.RUNNING + + # Verify that the failed job is now queued + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + + # Verify that the cancelled dependent job is now queued + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + + +@pytest.mark.unit +class TestGetJobsByStatusUnit: + """Test job retrieval by status with mocked database.""" + + def test_get_jobs_by_status_wraps_sqlalchemy_error_with_database_error(self, mock_pipeline_manager): + """Test database error handling.""" + with ( + patch.object(mock_pipeline_manager.db, "execute", side_effect=SQLAlchemyError("DB error")), + pytest.raises(DatabaseConnectionError, match="Failed to get jobs with status"), + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.get_jobs_by_status([JobStatus.RUNNING]) + + +@pytest.mark.integration +class TestGetJobsByStatusIntegration: + """Integration tests for job retrieval by status.""" + + @pytest.mark.parametrize( + "status", + JobStatus._member_map_.values(), + ) + def test_get_jobs_by_status_integration( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + status, + ): + """Test retrieval of jobs by status.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = status + sample_dependent_job_run.status = [s for s in JobStatus if s != status][0] + session.commit() + + with ( + TransactionSpy.spy(session), + ): + running_jobs = manager.get_jobs_by_status([status]) + + assert len(running_jobs) == 1 + assert running_jobs[0].id == sample_job_run.id + + def test_get_jobs_by_status_integration_no_matching_jobs( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + ): + """Test retrieval of jobs by status when no jobs match.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + jobs = manager.get_jobs_by_status([JobStatus.SUCCEEDED]) + + assert len(jobs) == 0 + + def test_get_jobs_by_status_integration_multiple_matching_jobs( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of jobs by status when multiple jobs match.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set both job statuses to RUNNING + sample_job_run.status = JobStatus.RUNNING + sample_dependent_job_run.status = JobStatus.RUNNING + session.commit() + + with ( + TransactionSpy.spy(session), + ): + running_jobs = manager.get_jobs_by_status([JobStatus.RUNNING]) + + assert len(running_jobs) == 2 + job_ids = {job.id for job in running_jobs} + assert sample_job_run.id in job_ids + assert sample_dependent_job_run.id in job_ids + + def test_get_jobs_by_status_integration_no_jobs_in_pipeline( + self, + session, + arq_redis, + setup_worker_db, + sample_empty_pipeline, + ): + """Test retrieval of jobs by status when there are no jobs in the pipeline.""" + manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + jobs = manager.get_jobs_by_status([JobStatus.RUNNING]) + + assert len(jobs) == 0 + + def test_get_jobs_by_status_multiple_statuses( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of jobs by multiple statuses.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = JobStatus.RUNNING + sample_dependent_job_run.status = JobStatus.PENDING + session.commit() + + with ( + TransactionSpy.spy(session), + ): + jobs = manager.get_jobs_by_status([JobStatus.RUNNING, JobStatus.PENDING]) + + assert len(jobs) == 2 + job_ids = {job.id for job in jobs} + assert sample_job_run.id in job_ids + assert sample_dependent_job_run.id in job_ids + + # Assert jobs are ordered by created by timestamp + assert jobs[0].created_at <= jobs[1].created_at + + +@pytest.mark.unit +class TestGetPendingJobsUnit: + """Test retrieval of pending jobs.""" + + def test_get_pending_jobs_success(self, mock_pipeline_manager): + """Test successful retrieval of pending jobs.""" + + with ( + patch.object( + mock_pipeline_manager, "get_jobs_by_status", return_value=[Mock(), Mock()] + ) as mock_get_jobs_by_status, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + jobs = mock_pipeline_manager.get_pending_jobs() + + assert len(jobs) == 2 + mock_get_jobs_by_status.assert_called_once_with([JobStatus.PENDING]) + + +@pytest.mark.integration +class TestGetPendingJobsIntegration: + """Integration tests for retrieval of pending jobs.""" + + def test_get_pending_jobs_integration( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of pending jobs.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = JobStatus.PENDING + sample_dependent_job_run.status = JobStatus.RUNNING + session.commit() + + with ( + TransactionSpy.spy(session), + ): + pending_jobs = manager.get_pending_jobs() + + assert len(pending_jobs) == 1 + assert pending_jobs[0].id == sample_job_run.id + + def test_get_pending_jobs_integration_no_pending_jobs( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of pending jobs when there are no pending jobs.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = JobStatus.RUNNING + sample_dependent_job_run.status = JobStatus.SUCCEEDED + session.commit() + + with ( + TransactionSpy.spy(session), + ): + pending_jobs = manager.get_pending_jobs() + + assert len(pending_jobs) == 0 + + +@pytest.mark.unit +class TestGetRunningJobsUnit: + """Test retrieval of running jobs.""" + + def test_get_running_jobs_success(self, mock_pipeline_manager): + """Test successful retrieval of running jobs.""" + + with ( + patch.object(mock_pipeline_manager, "get_jobs_by_status") as mock_get_jobs_by_status, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.get_running_jobs() + mock_get_jobs_by_status.assert_called_once_with([JobStatus.RUNNING]) + + +@pytest.mark.unit +class TestGetActiveJobsUnit: + """Test retrieval of active jobs.""" + + def test_get_active_jobs_success(self, mock_pipeline_manager): + """Test successful retrieval of active jobs.""" + + with ( + patch.object(mock_pipeline_manager, "get_jobs_by_status") as mock_get_jobs_by_status, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.get_active_jobs() + mock_get_jobs_by_status.assert_called_once_with(ACTIVE_JOB_STATUSES) + + +@pytest.mark.integration +class TestGetActiveJobsIntegration: + """Integration tests for retrieval of active jobs.""" + + def test_get_active_jobs_integration( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of active jobs.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = JobStatus.RUNNING + sample_dependent_job_run.status = JobStatus.PENDING + session.commit() + + with ( + TransactionSpy.spy(session), + ): + active_jobs = manager.get_active_jobs() + + assert len(active_jobs) == 2 + job_ids = {job.id for job in active_jobs} + assert sample_job_run.id in job_ids + assert sample_dependent_job_run.id in job_ids + + def test_get_active_jobs_integration_no_active_jobs( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of active jobs when there are no active jobs.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = JobStatus.SUCCEEDED + sample_dependent_job_run.status = JobStatus.FAILED + session.commit() + + with ( + TransactionSpy.spy(session), + ): + active_jobs = manager.get_active_jobs() + + assert len(active_jobs) == 0 + + +@pytest.mark.integration +class TestGetRunningJobsIntegration: + """Integration tests for retrieval of running jobs.""" + + def test_get_running_jobs_integration( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of running jobs.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = JobStatus.RUNNING + sample_dependent_job_run.status = JobStatus.PENDING + session.commit() + + with ( + TransactionSpy.spy(session), + ): + running_jobs = manager.get_running_jobs() + + assert len(running_jobs) == 1 + assert running_jobs[0].id == sample_job_run.id + + def test_get_running_jobs_integration_no_running_jobs( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of running jobs when there are no running jobs.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = JobStatus.SUCCEEDED + sample_dependent_job_run.status = JobStatus.PENDING + session.commit() + + with ( + TransactionSpy.spy(session), + ): + running_jobs = manager.get_running_jobs() + + assert len(running_jobs) == 0 + + +@pytest.mark.unit +class TestGetFailedJobsUnit: + """Test retrieval of failed jobs.""" + + def test_get_failed_jobs_success(self, mock_pipeline_manager): + """Test successful retrieval of failed jobs.""" + + with ( + patch.object(mock_pipeline_manager, "get_jobs_by_status") as mock_get_jobs_by_status, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.get_failed_jobs() + + mock_get_jobs_by_status.assert_called_once_with([JobStatus.FAILED]) + + +@pytest.mark.integration +class TestGetFailedJobsIntegration: + """Integration tests for retrieval of failed jobs.""" + + def test_get_failed_jobs_integration( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of failed jobs.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = JobStatus.FAILED + sample_dependent_job_run.status = JobStatus.PENDING + session.commit() + + with ( + TransactionSpy.spy(session), + ): + failed_jobs = manager.get_failed_jobs() + + assert len(failed_jobs) == 1 + assert failed_jobs[0].id == sample_job_run.id + + def test_get_failed_jobs_integration_no_failed_jobs( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of failed jobs when there are no failed jobs.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = JobStatus.SUCCEEDED + sample_dependent_job_run.status = JobStatus.PENDING + session.commit() + + with ( + TransactionSpy.spy(session), + ): + failed_jobs = manager.get_failed_jobs() + + assert len(failed_jobs) == 0 + + +@pytest.mark.unit +class TestGetUnsuccessfulJobsUnit: + """Test retrieval of unsuccessful jobs.""" + + def test_get_unsuccessful_jobs_success(self, mock_pipeline_manager): + """Test successful retrieval of unsuccessful jobs.""" + + with ( + patch.object(mock_pipeline_manager, "get_jobs_by_status") as mock_get_jobs_by_status, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.get_unsuccessful_jobs() + mock_get_jobs_by_status.assert_called_once_with([JobStatus.CANCELLED, JobStatus.SKIPPED, JobStatus.FAILED]) + + +@pytest.mark.integration +class TestGetUnsuccessfulJobsIntegration: + """Integration tests for retrieval of unsuccessful jobs.""" + + def test_get_unsuccessful_jobs_integration( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of unsuccessful jobs.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = JobStatus.FAILED + sample_dependent_job_run.status = JobStatus.CANCELLED + session.commit() + + with ( + TransactionSpy.spy(session), + ): + unsuccessful_jobs = manager.get_unsuccessful_jobs() + + assert len(unsuccessful_jobs) == 2 + job_ids = {job.id for job in unsuccessful_jobs} + assert sample_job_run.id in job_ids + assert sample_dependent_job_run.id in job_ids + + def test_get_unsuccessful_jobs_integration_no_unsuccessful_jobs( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of unsuccessful jobs when there are no unsuccessful jobs.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = JobStatus.SUCCEEDED + sample_dependent_job_run.status = JobStatus.PENDING + session.commit() + + with ( + TransactionSpy.spy(session), + ): + unsuccessful_jobs = manager.get_unsuccessful_jobs() + + assert len(unsuccessful_jobs) == 0 + + +@pytest.mark.unit +class TestGetAllJobsUnit: + """Test retrieval of all jobs.""" + + def test_get_all_jobs_wraps_sqlalchemy_errors_with_database_error(self, mock_pipeline_manager): + """Test database error handling during retrieval of all jobs.""" + + with ( + patch.object(mock_pipeline_manager.db, "execute", side_effect=SQLAlchemyError("DB error")), + pytest.raises(DatabaseConnectionError, match="Failed to get all jobs"), + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.get_all_jobs() + + +@pytest.mark.integration +class TestGetAllJobsIntegration: + """Integration tests for retrieval of all jobs.""" + + def test_get_all_jobs_integration( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of all jobs in a pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + all_jobs = manager.get_all_jobs() + + assert len(all_jobs) == 2 + job_ids = {job.id for job in all_jobs} + assert sample_job_run.id in job_ids + assert sample_dependent_job_run.id in job_ids + + def test_get_all_jobs_integration_no_jobs( + self, + session, + arq_redis, + setup_worker_db, + sample_empty_pipeline, + ): + """Test retrieval of all jobs when there are no jobs in the pipeline.""" + manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + all_jobs = manager.get_all_jobs() + + assert len(all_jobs) == 0 + + def test_get_all_jobs_integration_multiple_jobs( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of all jobs when there are multiple jobs in the pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Add an additional job to the pipeline + new_job = JobRun( + id=99, + urn="job:additional_job:999", + pipeline_id=sample_pipeline.id, + job_type="Additional Job", + job_function="additional_function", + status=JobStatus.PENDING, + ) + session.add(new_job) + session.commit() + + with ( + TransactionSpy.spy(session), + ): + all_jobs = manager.get_all_jobs() + + assert len(all_jobs) == 3 + job_ids = {job.id for job in all_jobs} + assert sample_job_run.id in job_ids + assert sample_dependent_job_run.id in job_ids + assert new_job.id in job_ids + + # Assert jobs are ordered by created by timestamp + assert all_jobs[0].created_at <= all_jobs[1].created_at <= all_jobs[2].created_at + + +@pytest.mark.unit +class TestGetDependenciesForJobUnit: + """Test retrieval of job dependencies.""" + + def test_get_dependencies_for_job_wraps_sqlalchemy_error_with_database_error(self, mock_pipeline_manager): + """Test database error handling during retrieval of job dependencies.""" + mock_job = Mock(spec=JobRun) + + with ( + patch.object(mock_pipeline_manager.db, "execute", side_effect=SQLAlchemyError("DB error")), + pytest.raises(DatabaseConnectionError, match="Failed to get job dependencies for job"), + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.get_dependencies_for_job(mock_job) + + +@pytest.mark.integration +class TestGetDependenciesForJobIntegration: + """Integration tests for retrieval of job dependencies.""" + + def test_get_dependencies_for_job_integration( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + sample_job_dependency, + ): + """Test retrieval of job dependencies.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + dependencies = manager.get_dependencies_for_job(sample_dependent_job_run) + + assert len(dependencies) == 1 + dependency, job = dependencies[0] + assert dependency.id == sample_job_dependency.id + assert job.id == sample_job_run.id + + def test_get_dependencies_for_job_integration_no_dependencies( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + ): + """Test retrieval of job dependencies when there are no dependencies.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + dependencies = manager.get_dependencies_for_job(sample_job_run) + + assert len(dependencies) == 0 + + def test_get_dependencies_for_job_integration_multiple_dependencies( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of job dependencies when there are multiple dependencies.""" + # Create additional job and dependency + additional_job = JobRun( + id=99, + urn="job:additional_job:999", + pipeline_id=sample_pipeline.id, + job_type="Additional Job", + job_function="additional_function", + status=JobStatus.PENDING, + ) + session.add(additional_job) + session.commit() + + additional_dependency = JobDependency( + id=sample_dependent_job_run.id, + depends_on_job_id=additional_job.id, + dependency_type=DependencyType.COMPLETION_REQUIRED, + ) + session.add(additional_dependency) + session.commit() + + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + dependencies = manager.get_dependencies_for_job(sample_dependent_job_run) + + assert len(dependencies) == 2 + fetched_dependency_ids = {dep.id for dep, job in dependencies} + implicit_dependency_ids = {dep.id for dep in sample_dependent_job_run.job_dependencies} + assert fetched_dependency_ids == implicit_dependency_ids + + +@pytest.mark.unit +class TestGetPipelineUnit: + """Test retrieval of pipeline.""" + + def test_get_pipeline_wraps_sqlalchemy_errors_with_database_error(self, mock_pipeline): + """Test database error handling during retrieval of pipeline.""" + + # Prepare mock PipelineManager with mocked DB session that will raise SQLAlchemyError on query. + # We don't use the default fixture here since it usually wraps this function. + mock_db = Mock(spec=Session) + mock_redis = Mock(spec=ArqRedis) + manager = object.__new__(PipelineManager) + manager.db = mock_db + manager.redis = mock_redis + manager.pipeline_id = mock_pipeline.id + + with ( + patch.object(manager.db, "execute", side_effect=SQLAlchemyError("DB error")), + pytest.raises(DatabaseConnectionError, match="Failed to get pipeline"), + TransactionSpy.spy(manager.db), + ): + manager.get_pipeline() + + +@pytest.mark.integration +class TestGetPipelineIntegration: + """Integration tests for retrieval of pipeline.""" + + def test_get_pipeline_integration( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + ): + """Test retrieval of pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + pipeline = manager.get_pipeline() + + assert pipeline.id == sample_pipeline.id + assert pipeline.name == sample_pipeline.name + + def test_get_pipeline_integration_nonexistent_pipeline( + self, + session, + arq_redis, + setup_worker_db, + ): + """Test retrieval of a nonexistent pipeline raises PipelineNotFoundError.""" + with ( + pytest.raises(DatabaseConnectionError, match="Failed to get pipeline 9999"), + TransactionSpy.spy(session), + ): + # get_pipeline is called implicitly during PipelineManager initialization + PipelineManager(session, arq_redis, pipeline_id=9999) + + +@pytest.mark.unit +class TestGetJobCountsByStatusUnit: + """Test retrieval of job counts by status.""" + + def test_get_job_counts_by_status_wraps_sqlalchemy_errors_with_database_error(self, mock_pipeline_manager): + """Test database error handling during retrieval of job counts by status.""" + + with ( + patch.object(mock_pipeline_manager.db, "execute", side_effect=SQLAlchemyError("DB error")), + pytest.raises(DatabaseConnectionError, match="Failed to get job counts for pipeline"), + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.get_job_counts_by_status() + + +@pytest.mark.integration +class TestGetJobCountsByStatusIntegration: + """Integration tests for retrieval of job counts by status.""" + + def test_get_job_counts_by_status_integration( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test retrieval of job counts by status.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set job statuses + sample_job_run.status = JobStatus.RUNNING + sample_dependent_job_run.status = JobStatus.PENDING + session.commit() + + with ( + TransactionSpy.spy(session), + ): + counts = manager.get_job_counts_by_status() + + assert counts[JobStatus.RUNNING] == 1 + assert counts[JobStatus.PENDING] == 1 + assert counts.get(JobStatus.SUCCEEDED, 0) == 0 + + def test_get_job_counts_by_status_integration_no_jobs( + self, + session, + arq_redis, + setup_worker_db, + sample_empty_pipeline, + ): + """Test retrieval of job counts by status when there are no jobs in the pipeline.""" + manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + counts = manager.get_job_counts_by_status() + + assert counts == {} + + +@pytest.mark.unit +class TestGetPipelineProgressUnit: + """Test retrieval of pipeline progress.""" + + pass + + +@pytest.mark.integration +class TestGetPipelineProgressIntegration: + """Integration tests for retrieval of pipeline progress.""" + + pass + + +@pytest.mark.unit +class TestGetPipelineStatusUnit: + """Test retrieval of pipeline status.""" + + def test_get_pipeline_status_success(self, mock_pipeline_manager): + """Test successful retrieval of pipeline status.""" + with ( + TransactionSpy.spy(mock_pipeline_manager.db), + patch.object( + mock_pipeline_manager, + "get_pipeline", + wraps=mock_pipeline_manager.get_pipeline, + ) as mock_get_pipeline, + ): + mock_pipeline_manager.get_pipeline_status() + mock_get_pipeline.assert_called_once() + + +@pytest.mark.integration +class TestGetPipelineStatusIntegration: + """Integration tests for retrieval of pipeline status.""" + + def test_get_pipeline_status_integration( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + ): + """Test retrieval of pipeline status.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + status = manager.get_pipeline_status() + + assert status == sample_pipeline.status + + +@pytest.mark.unit +class TestSetPipelineStatusUnit: + """Test setting of pipeline status.""" + + @pytest.mark.parametrize("pipeline_status", [status for status in PipelineStatus._member_map_.values()]) + def test_set_pipeline_status_success(self, mock_pipeline_manager, pipeline_status): + """Test successful setting of pipeline status.""" + mock_pipeline = Mock(spec=Pipeline, status=None) + + with ( + patch.object( + mock_pipeline_manager, + "get_pipeline", + return_value=mock_pipeline, + ) as mock_get_pipeline, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + mock_pipeline_manager.set_pipeline_status(pipeline_status) + assert mock_pipeline.status == pipeline_status + + mock_get_pipeline.assert_called_once() + + @pytest.mark.parametrize( + "pipeline_status", + TERMINAL_PIPELINE_STATUSES, + ) + def test_set_pipeline_status_sets_finished_at_property_for_terminal_status( + self, mock_pipeline_manager, mock_pipeline, pipeline_status + ): + """Test that setting a terminal status updates the finished_at property.""" + # Set initial finished_at to None + mock_pipeline.finished_at = None + + with TransactionSpy.spy(mock_pipeline_manager.db): + before_update = datetime.datetime.now() + mock_pipeline_manager.set_pipeline_status(pipeline_status) + after_update = datetime.datetime.now() + + assert mock_pipeline.status == pipeline_status + assert mock_pipeline.finished_at is not None + assert before_update <= mock_pipeline.finished_at <= after_update + + def test_set_pipeline_status_clears_started_at_property_for_created_status( + self, mock_pipeline_manager, mock_pipeline + ): + """Test that setting status to CREATED clears the started_at property.""" + + with TransactionSpy.spy(mock_pipeline_manager.db): + mock_pipeline_manager.set_pipeline_status(PipelineStatus.CREATED) + assert mock_pipeline.status == PipelineStatus.CREATED + assert mock_pipeline.started_at is None + + @pytest.mark.parametrize( + "initial_started_at", + [None, datetime.datetime.now() - datetime.timedelta(hours=1)], + ) + def test_set_pipeline_status_sets_started_at_property_for_running_status( + self, mock_pipeline_manager, mock_pipeline, initial_started_at + ): + """Test that setting status to RUNNING sets the started_at property if not already set.""" + mock_pipeline.started_at = initial_started_at + with TransactionSpy.spy(mock_pipeline_manager.db): + before_update = datetime.datetime.now() + mock_pipeline_manager.set_pipeline_status(PipelineStatus.RUNNING) + after_update = datetime.datetime.now() + + assert mock_pipeline.status == PipelineStatus.RUNNING + + if initial_started_at is None: + assert mock_pipeline.started_at is not None + assert before_update <= mock_pipeline.started_at <= after_update + else: + assert mock_pipeline.started_at == initial_started_at + + @pytest.mark.parametrize( + "exception", + HANDLED_EXCEPTIONS_DURING_OBJECT_MANIPULATION, + ) + def test_set_pipeline_status_handled_exception_raises_pipeline_state_error(self, mock_pipeline_manager, exception): + """Test that handled exceptions during setting of pipeline status raise PipelineStateError.""" + + def get_or_error(*args): + if args: + raise exception + return PipelineStatus.CREATED + + with ( + patch.object(mock_pipeline_manager, "get_pipeline") as mock_pipeline, + pytest.raises(PipelineStateError, match="Failed to set pipeline status"), + TransactionSpy.spy(mock_pipeline_manager.db), + ): + # Mock exception when setting pipeline status + mock_pipeline.return_value = Mock(spec=Pipeline) + type(mock_pipeline.return_value).status = PropertyMock(side_effect=get_or_error) + + mock_pipeline_manager.set_pipeline_status(PipelineStatus.RUNNING) + + +@pytest.mark.integration +class TestSetPipelineStatusIntegration: + """Integration tests for setting of pipeline status.""" + + @pytest.mark.parametrize("pipeline_status", [status for status in PipelineStatus._member_map_.values()]) + def test_set_pipeline_status_integration( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + pipeline_status, + ): + """Test setting of pipeline status.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + manager.set_pipeline_status(pipeline_status) + + # Commit the transaction + session.commit() + + # Verify that the pipeline status is updated + updated_pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + assert updated_pipeline.status == pipeline_status + + @pytest.mark.parametrize( + "pipeline_status", + TERMINAL_PIPELINE_STATUSES, + ) + def test_set_pipeline_status_integration_terminal_status_sets_finished_at( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + pipeline_status, + ): + """Test that setting a terminal status updates the finished_at property.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + before_update = datetime.datetime.now(tz=datetime.timezone.utc) + manager.set_pipeline_status(pipeline_status) + after_update = datetime.datetime.now(tz=datetime.timezone.utc) + + # Commit the transaction + session.commit() + + # Verify that the pipeline status and finished_at are updated + updated_pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + assert updated_pipeline.status == pipeline_status + assert updated_pipeline.finished_at is not None + assert before_update <= updated_pipeline.finished_at <= after_update + + def test_set_pipeline_status_integration_created_status_clears_started_at( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + ): + """Test that setting status to CREATED clears the started_at property.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with TransactionSpy.spy(session): + manager.set_pipeline_status(PipelineStatus.CREATED) + + # Commit the transaction + session.commit() + + # Verify that the pipeline status is updated and started_at is None + updated_pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + assert updated_pipeline.status == PipelineStatus.CREATED + assert updated_pipeline.started_at is None + + @pytest.mark.parametrize( + "initial_started_at", + [None, datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(hours=1)], + ) + def test_set_pipeline_status_integration_running_status_sets_started_at( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + initial_started_at, + ): + """Test that setting status to RUNNING sets the started_at property if not already set.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Set initial started_at + sample_pipeline.started_at = initial_started_at + session.commit() + + with TransactionSpy.spy(session): + before_update = datetime.datetime.now(tz=datetime.timezone.utc) + manager.set_pipeline_status(PipelineStatus.RUNNING) + after_update = datetime.datetime.now(tz=datetime.timezone.utc) + + # Commit the transaction + session.commit() + + # Verify that the pipeline status and started_at are updated + updated_pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + assert updated_pipeline.status == PipelineStatus.RUNNING + + if initial_started_at is None: + assert before_update <= updated_pipeline.started_at <= after_update + else: + assert updated_pipeline.started_at == initial_started_at + + +@pytest.mark.unit +class TestEnqueueInArqUnit: + """Test enqueuing jobs in ARQ.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize("enqueud", [Mock(spec=ArqJob), None]) + @pytest.mark.parametrize("retry", [True, False]) + async def test_enqueue_in_arq_success(self, mock_pipeline_manager, retry, enqueud): + """Test successful enqueuing of a job in ARQ.""" + mock_job = Mock(spec=JobRun, job_function="test_func", id=1, urn="urn:example", retry_delay_seconds=10) + with ( + patch.object(mock_pipeline_manager.redis, "enqueue_job", return_value=enqueud) as mock_enqueue_job, + TransactionSpy.spy(mock_pipeline_manager.db), + ): + await mock_pipeline_manager._enqueue_in_arq(job=mock_job, is_retry=retry) + + mock_enqueue_job.assert_called_once_with( + mock_job.job_function, + mock_job.id, + _defer_by=datetime.timedelta(seconds=mock_job.retry_delay_seconds if retry else 0), + _job_id=mock_job.urn, + ) + + @pytest.mark.asyncio + async def test_any_enqueue_exception_raises_pipeline_coordination_error(self, mock_pipeline_manager): + """Test that any exception during enqueuing raises PipelineCoordinationError.""" + mock_job = Mock(spec=JobRun, job_function="test_func", id=1, urn="urn:example", retry_delay_seconds=10) + + with ( + patch.object( + mock_pipeline_manager.redis, + "enqueue_job", + side_effect=Exception("Test exception"), + ), + pytest.raises(PipelineCoordinationError, match="Failed to enqueue job in ARQ"), + TransactionSpy.spy(mock_pipeline_manager.db), + ): + await mock_pipeline_manager._enqueue_in_arq(job=mock_job, is_retry=False) + + +@pytest.mark.integration +class TestEnqueueInArqIntegration: + """Integration tests for enqueuing jobs in ARQ.""" + + @pytest.mark.asyncio + async def test_enqueue_in_arq_integration( + self, + session, + arq_redis: ArqRedis, + setup_worker_db, + sample_pipeline, + sample_job_run, + ): + """Test enqueuing of a job in ARQ.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + with ( + TransactionSpy.spy(session), + ): + await manager._enqueue_in_arq(job=sample_job_run, is_retry=False) + + queued_jobs = await arq_redis.queued_jobs() + assert len(queued_jobs) == 1 + assert queued_jobs[0].function == sample_job_run.job_function + + +@pytest.mark.integration +class TestPipelineManagerLifecycle: + """Integration tests for PipelineManager lifecycle.""" + + @pytest.mark.asyncio + async def test_full_pipeline_lifecycle( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + ): + """Test full lifecycle of PipelineManager including initialization and job retrieval.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # pipeline is created with pending jobs + pipeline = manager.get_pipeline() + all_jobs = manager.get_all_jobs() + + assert pipeline.id == sample_pipeline.id + assert len(all_jobs) == 2 + assert all_jobs[0].id == sample_job_run.id + assert all_jobs[0].status == JobStatus.PENDING + + # pipeline started + await manager.start_pipeline() + session.commit() + + # verify pipeline status is running + updated_pipeline = manager.get_pipeline() + assert updated_pipeline.status == PipelineStatus.RUNNING + + # Verify job status and enqueued in ARQ + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + queued_jobs = await arq_redis.queued_jobs() + assert len(queued_jobs) == 1 + assert queued_jobs[0].function == sample_job_run.job_function + + # Simulate pipeline lifecycle for a two job sample pipeline. The workflow here should be as follows: + # - Enter pipeline manager decorator. We don't make any calls when a pipeline begins + # - Enter the job manager decorator. This sets the job to RUNNING. + # - Job runs... + # - Exit the job manager decorator. This sets the job to some terminal state. + # - Exit the pipeline manager decorator. This coordinates the pipeline, either + # enqueuing any newly queueable jobs or terminating it. + + # enter pipeline manager decorator: no work + pass + + # enter job manager decorator: set job to RUNNING + job_manager = JobManager(session, arq_redis, sample_job_run.id) + job_manager.start_job() + session.commit() + + # job runs... Actual job execution is out of scope for this test. Instead, evict the job from redis to simulate completion. + await arq_redis.flushdb() + + # exit job manager decorator: set job to SUCCEEDED + job_manager.succeed_job({"output": "some result", "logs": "some logs", "metadata": {"key": "value"}}) + session.commit() + + # exit pipeline manager decorator: enqueue newly queueable jobs or terminate pipeline + await manager.coordinate_pipeline() + session.commit() + + # Verify pipeline status is still RUNNING (since there is a dependent job) + updated_pipeline = manager.get_pipeline() + assert updated_pipeline.status == PipelineStatus.RUNNING + + # Verify that the completed job is now SUCCEEDED in the database + completed_job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert completed_job.status == JobStatus.SUCCEEDED + + # Verify that the dependent job is now QUEUED in the database and ARQ + dependent_job = session.execute( + select(JobRun).where(JobRun.pipeline_id == sample_pipeline.id).filter(JobRun.id != sample_job_run.id) + ).scalar_one() + assert dependent_job.status == JobStatus.QUEUED + queued_jobs = await arq_redis.queued_jobs() + assert len(queued_jobs) == 1 + assert queued_jobs[0].function == dependent_job.job_function + + # Simulate the next iteration of pipeline lifecycle. We've now entered a new context manager with + # steps identical to those described above but executing in the context of a newly enqueued dependent job. + job_manager = JobManager(session, arq_redis, dependent_job.id) + + # enter pipeline manager decorator: no work + pass + + # enter job manager decorator: set dependent job to RUNNING + dependent_job_manager = JobManager(session, arq_redis, dependent_job.id) + dependent_job_manager.start_job() + session.commit() + + # job runs... Actual job execution is out of scope for this test. Instead, evict the job from redis to simulate completion. + await arq_redis.flushdb() + + # exit job manager decorator: set dependent job to SUCCEEDED + job_manager.succeed_job({"output": "some result", "logs": "some logs", "metadata": {"key": "value"}}) + session.commit() + + # exit pipeline manager decorator: enqueue newly queueable jobs or terminate pipeline + await manager.coordinate_pipeline() + session.commit() + + # Verify pipeline status is now SUCCEEDED + updated_pipeline = manager.get_pipeline() + assert updated_pipeline.status == PipelineStatus.SUCCEEDED + + # Verify that the dependent job is now SUCCEEDED in the database + dependent_job = session.execute(select(JobRun).where(JobRun.id == dependent_job.id)).scalar_one() + assert dependent_job.status == JobStatus.SUCCEEDED + + @pytest.mark.asyncio + async def test_paused_pipeline_lifecycle( + self, session, arq_redis, setup_worker_db, sample_pipeline, sample_job_run, sample_dependent_job_run + ): + """Test lifecycle of a paused pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Start the pipeline + await manager.start_pipeline() + session.commit() + + # Verify pipeline status is running + updated_pipeline = manager.get_pipeline() + assert updated_pipeline.status == PipelineStatus.RUNNING + + # Verify job status and enqueued in ARQ + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + queued_jobs = await arq_redis.queued_jobs() + assert len(queued_jobs) == 1 + assert queued_jobs[0].function == sample_job_run.job_function + + # Simulate job start + job_manager = JobManager(session, arq_redis, sample_job_run.id) + job_manager.start_job() + session.commit() + + # Pause the pipeline. Pausing the pipeline while a job is running DOES NOT affect the job. + await manager.pause_pipeline() + session.commit() + + # Verify that the pipeline is paused + updated_pipeline = manager.get_pipeline() + assert updated_pipeline.status == PipelineStatus.PAUSED + + # Evict the job from redis to simulate completion. + await arq_redis.flushdb() + + # Simulate job completion + job_manager.succeed_job({"output": "some result", "logs": "some logs", "metadata": {"key": "value"}}) + session.commit() + + # Coordinate the pipeline + await manager.coordinate_pipeline() + session.commit() + + # Verify that the pipeline remains paused + updated_pipeline = manager.get_pipeline() + assert updated_pipeline.status == PipelineStatus.PAUSED + + # Verify that no jobs were enqueued in ARQ + queued_jobs = await arq_redis.queued_jobs() + assert len(queued_jobs) == 0 + + # Verify that the dependent job remains pending in the database + dependent_job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert dependent_job.status == JobStatus.PENDING + + # Unpause the pipeline + await manager.unpause_pipeline() + session.commit() + + # Verify that the pipeline is now running + updated_pipeline = manager.get_pipeline() + assert updated_pipeline.status == PipelineStatus.RUNNING + + # Verify that the dependent job is is now queued in ARQ + dependent_job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert dependent_job.status == JobStatus.QUEUED + queued_jobs = await arq_redis.queued_jobs() + assert len(queued_jobs) == 1 + assert queued_jobs[0].function == sample_dependent_job_run.job_function + + # Simulate dependent job start + dependent_job_manager = JobManager(session, arq_redis, sample_dependent_job_run.id) + dependent_job_manager.start_job() + session.commit() + + # Evict the dependent job from redis to simulate completion. + await arq_redis.flushdb() + + # Simulate dependent job completion + dependent_job_manager.succeed_job({"output": "some result", "logs": "some logs", "metadata": {"key": "value"}}) + session.commit() + + # Coordinate the pipeline + await manager.coordinate_pipeline() + session.commit() + + # Verify that the pipeline is now succeeded + updated_pipeline = manager.get_pipeline() + assert updated_pipeline.status == PipelineStatus.SUCCEEDED + + # Verify that the dependent job is now succeeded in the database + dependent_job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert dependent_job.status == JobStatus.SUCCEEDED + + @pytest.mark.asyncio + async def test_cancelled_pipeline_lifecycle( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + sample_dependent_job_run, + ): + """Test lifecycle of a cancelled pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Start the pipeline + await manager.start_pipeline() + session.commit() + + # Verify pipeline status is running + updated_pipeline = manager.get_pipeline() + assert updated_pipeline.status == PipelineStatus.RUNNING + + # Verify job status and enqueued in ARQ + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + queued_jobs = await arq_redis.queued_jobs() + assert len(queued_jobs) == 1 + assert queued_jobs[0].function == sample_job_run.job_function + + # Simulate job start + job_manager = JobManager(session, arq_redis, sample_job_run.id) + job_manager.start_job() + session.commit() + + # Evict the job from redis to simulate completion. + await arq_redis.flushdb() + + # Cancel the pipeline. This DOES have an effect on the running job. + await manager.cancel_pipeline() + session.commit() + + # Verify that the pipeline is now cancelled + updated_pipeline = manager.get_pipeline() + assert updated_pipeline.status == PipelineStatus.CANCELLED + + # Verify that the job is now cancelled in the database + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.CANCELLED + + # Verify that the dependent job is now skipped in the database + dependent_job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert dependent_job.status == JobStatus.SKIPPED + + # Verify that no jobs were enqueued in ARQ + queued_jobs = await arq_redis.queued_jobs() + assert len(queued_jobs) == 0 + + @pytest.mark.asyncio + async def test_restart_pipeline_lifecycle( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + ): + """Test lifecycle of a restarted pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Start the pipeline + await manager.start_pipeline() + session.commit() + + # Verify pipeline status is running + updated_pipeline = manager.get_pipeline() + assert updated_pipeline.status == PipelineStatus.RUNNING + + # Verify job status and enqueued in ARQ + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + queued_jobs = await arq_redis.queued_jobs() + assert len(queued_jobs) == 1 + assert queued_jobs[0].function == sample_job_run.job_function + + # Start the job + job_manager = JobManager(session, arq_redis, sample_job_run.id) + job_manager.start_job() + session.commit() + + # Evict the job from redis to simulate completion. + await arq_redis.flushdb() + + job_manager.fail_job( + error=Exception("Simulated job failure"), result={"output": None, "logs": "some logs", "metadata": {}} + ) + session.commit() + + # Coordinate the pipeline + await manager.coordinate_pipeline() + session.commit() + + # Verify the pipeline failed + updated_pipeline = manager.get_pipeline() + assert updated_pipeline.status == PipelineStatus.FAILED + + # Verify that the job is now failed in the database + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.FAILED + + # Restart the pipeline + await manager.restart_pipeline() + session.commit() + + # Verify that the pipeline is now created + updated_pipeline = manager.get_pipeline() + assert updated_pipeline.status == PipelineStatus.RUNNING + + # Verify job status and enqueued in ARQ + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + queued_jobs = await arq_redis.queued_jobs() + assert len(queued_jobs) == 1 + assert queued_jobs[0].function == sample_job_run.job_function + + @pytest.mark.asyncio + async def test_retry_pipeline_lifecycle( + self, + session, + arq_redis, + setup_worker_db, + sample_pipeline, + sample_job_run, + ): + """Test lifecycle of a restarted pipeline.""" + manager = PipelineManager(session, arq_redis, sample_pipeline.id) + + # Add a cancelled job to the pipeline + cancelled_job = JobRun( + id=99, + pipeline_id=sample_pipeline.id, + job_function="cancelled_job_function", + job_type="CANCELLED_JOB", + status=JobStatus.CANCELLED, + urn="urn:cancelled_job", + ) + session.add(cancelled_job) + session.commit() + + # Start the pipeline + await manager.start_pipeline() + session.commit() + + # Verify pipeline status is running + updated_pipeline = manager.get_pipeline() + assert updated_pipeline.status == PipelineStatus.RUNNING + + # Verify job status and enqueued in ARQ + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + queued_jobs = await arq_redis.queued_jobs() + assert len(queued_jobs) == 1 + assert queued_jobs[0].function == sample_job_run.job_function + + # Start the job + job_manager = JobManager(session, arq_redis, sample_job_run.id) + job_manager.start_job() + session.commit() + + # Evict the job from redis to simulate completion. + await arq_redis.flushdb() + + job_manager.fail_job( + error=Exception("Simulated job failure"), result={"output": None, "logs": "some logs", "metadata": {}} + ) + session.commit() + + # Coordinate the pipeline + await manager.coordinate_pipeline() + session.commit() + + # Verify the pipeline failed + updated_pipeline = manager.get_pipeline() + assert updated_pipeline.status == PipelineStatus.FAILED + + # Verify that the job is now failed in the database + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.FAILED + + # Restart the pipeline + await manager.retry_pipeline() + session.commit() + + # Verify that the pipeline is now created + updated_pipeline = manager.get_pipeline() + assert updated_pipeline.status == PipelineStatus.RUNNING + + # Verify job status of failed job + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + + # Verify the previously cancelled job is now queued + job = session.execute(select(JobRun).where(JobRun.id == cancelled_job.id)).scalar_one() + assert job.status == JobStatus.QUEUED + queued_jobs = await arq_redis.queued_jobs() + assert len(queued_jobs) == 2 From 899ca84743f4be88bc8575cb19f4e78182d22374 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Fri, 16 Jan 2026 10:35:49 -0800 Subject: [PATCH 08/70] feat: add function to check if job dependencies are reachable --- src/mavedb/worker/lib/managers/utils.py | 46 ++++++++++++++++++++++--- 1 file changed, 42 insertions(+), 4 deletions(-) diff --git a/src/mavedb/worker/lib/managers/utils.py b/src/mavedb/worker/lib/managers/utils.py index b7448e1e..c607185c 100644 --- a/src/mavedb/worker/lib/managers/utils.py +++ b/src/mavedb/worker/lib/managers/utils.py @@ -7,10 +7,10 @@ import logging from datetime import datetime -from typing import Optional +from typing import Literal, Optional, Union from mavedb.models.enums.job_pipeline import DependencyType, JobStatus -from mavedb.worker.lib.managers.constants import TERMINAL_JOB_STATUSES +from mavedb.worker.lib.managers.constants import COMPLETED_JOB_STATUSES from mavedb.worker.lib.managers.types import JobResultData logger = logging.getLogger(__name__) @@ -60,10 +60,48 @@ def job_dependency_is_met(dependency_type: Optional[DependencyType], dependent_j return False if dependency_type == DependencyType.COMPLETION_REQUIRED: - if dependent_job_status not in TERMINAL_JOB_STATUSES: + if dependent_job_status not in COMPLETED_JOB_STATUSES: logger.debug( - f"Dependency not met: dependent job has not reached a terminal status ({dependent_job_status})." + f"Dependency not met: dependent job has not reached a completed status ({dependent_job_status})." ) return False return True + + +def job_should_be_skipped_due_to_unfulfillable_dependency( + dependency_type: Optional[DependencyType], dependent_job_status: JobStatus +) -> Union[tuple[Literal[False], None], tuple[Literal[True], str]]: + """Determine if a job should be skipped due to an unfulfillable dependency. + + Args: + dependency_type: Type of dependency ('hard' or 'soft') + dependent_job_status: Status of the dependent job + + Returns: + Union[tuple[Literal[False], None], tuple[Literal[True], str]]: Tuple indicating + if the job should be skipped and the reason + + Notes: + - A job should be skipped if it has a 'hard' dependency and the dependent job did not succeed. + """ + + # If dependency must have SUCCEEDED but is in a terminal non-success state, skip. + if dependency_type == DependencyType.SUCCESS_REQUIRED: + if dependent_job_status in (JobStatus.FAILED, JobStatus.SKIPPED, JobStatus.CANCELLED): + logger.debug( + f"Job should be skipped due to unfulfillable 'success_required' dependency " + f"({dependent_job_status})." + ) + return True, f"Dependency did not succeed ({dependent_job_status})" + + # If dependency requires 'completion' and you want CANCELLED to NOT qualify, skip here too. + if dependency_type == DependencyType.COMPLETION_REQUIRED: + if dependent_job_status in (JobStatus.CANCELLED, JobStatus.SKIPPED): + logger.debug( + f"Job should be skipped due to unfulfillable 'completion_required' dependency " + f"({dependent_job_status})." + ) + return True, f"Dependency was not completed successfully ({dependent_job_status})" + + return False, None From f34939cf135af6d1bb4edeb552855d01268d1b4f Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Fri, 16 Jan 2026 10:35:59 -0800 Subject: [PATCH 09/70] feat: add markers for test categorization in pytest --- pyproject.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index ca00ecf0..f9538bff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,6 +105,11 @@ asyncio_mode = 'strict' testpaths = "tests/" pythonpath = "." norecursedirs = "tests/helpers/" +markers = """ + integration: mark a test as an integration test. + unit: mark a test as a unit test. + slow: mark a test as slow-running. +""" # Uncomment the following lines to include application log output in Pytest logs. # log_cli = true # log_cli_level = "DEBUG" From 3ad046de7f3305f56a185706c343f7170a3cef41 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Fri, 16 Jan 2026 16:24:15 -0800 Subject: [PATCH 10/70] fix: mock job manager returning in fixture rather than yielding --- tests/worker/lib/conftest.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/worker/lib/conftest.py b/tests/worker/lib/conftest.py index fd707307..ddcd25bc 100644 --- a/tests/worker/lib/conftest.py +++ b/tests/worker/lib/conftest.py @@ -228,9 +228,7 @@ def mock_job_manager(mock_job_run): manager.job_id = mock_job_run.id with patch.object(manager, "get_job", return_value=mock_job_run): - manager.job_id = 123 - - return manager + yield manager @pytest.fixture From 1e447a7164113434c218c1c1438c6b4f083940a4 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Fri, 16 Jan 2026 16:35:27 -0800 Subject: [PATCH 11/70] fix: enhance error logging for job and pipeline state transitions --- src/mavedb/worker/lib/managers/job_manager.py | 4 ++++ src/mavedb/worker/lib/managers/pipeline_manager.py | 14 +++++++------- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/mavedb/worker/lib/managers/job_manager.py b/src/mavedb/worker/lib/managers/job_manager.py index 1da3e581..a3e8a430 100644 --- a/src/mavedb/worker/lib/managers/job_manager.py +++ b/src/mavedb/worker/lib/managers/job_manager.py @@ -185,6 +185,7 @@ def start_job(self) -> None: """ job_run = self.get_job() if job_run.status not in STARTABLE_JOB_STATUSES: + logger.error(f"Invalid job start attempt for job {self.job_id} in status {job_run.status}") raise JobTransitionError(f"Cannot start job {self.job_id} from status {job_run.status}") try: @@ -247,6 +248,7 @@ def complete_job(self, status: JobStatus, result: JobResultData, error: Optional """ # Validate terminal status if status not in TERMINAL_JOB_STATUSES: + logger.error(f"Invalid job completion status {status} for job {self.job_id}") raise JobTransitionError( f"Cannot commplete job to status: {status}. Must complete to a terminal status: {TERMINAL_JOB_STATUSES}" ) @@ -463,6 +465,7 @@ def prepare_retry(self, reason: str = "retry_requested") -> None: """ job_run = self.get_job() if job_run.status not in RETRYABLE_JOB_STATUSES: + logger.error(f"Invalid job retry attempt for job {self.job_id} in status {job_run.status}") raise JobTransitionError(f"Cannot retry job {self.job_id} due to invalid state ({job_run.status})") try: @@ -508,6 +511,7 @@ def prepare_queue(self) -> None: """ job_run = self.get_job() if job_run.status != JobStatus.PENDING: + logger.error(f"Invalid job queue attempt for job {self.job_id} in status {job_run.status}") raise JobTransitionError(f"Cannot queue job {self.job_id} from status {job_run.status}") try: diff --git a/src/mavedb/worker/lib/managers/pipeline_manager.py b/src/mavedb/worker/lib/managers/pipeline_manager.py index b05f9706..a81a2738 100644 --- a/src/mavedb/worker/lib/managers/pipeline_manager.py +++ b/src/mavedb/worker/lib/managers/pipeline_manager.py @@ -174,7 +174,7 @@ async def start_pipeline(self) -> None: status = self.get_pipeline_status() if status != PipelineStatus.CREATED: - logger.info( + logger.error( f"Pipeline {self.pipeline_id} is in a non-created state (current status: {status}) and may not be started" ) raise PipelineTransitionError(f"Pipeline {self.pipeline_id} is in state {status} and may not be started") @@ -364,7 +364,7 @@ async def enqueue_ready_jobs(self) -> None: """ current_status = self.get_pipeline_status() if current_status not in RUNNING_PIPELINE_STATUSES: - logger.debug(f"Pipeline {self.pipeline_id} is not running - skipping job enqueue") + logger.error(f"Pipeline {self.pipeline_id} is not running - skipping job enqueue") raise PipelineStateError( f"Pipeline {self.pipeline_id} is in status {current_status} and cannot enqueue jobs" ) @@ -388,7 +388,7 @@ async def enqueue_ready_jobs(self) -> None: "metadata": {"result": reason, "timestamp": datetime.now().isoformat()}, } ) - logger.info(f"Skipped job {job.urn} due to unmet dependencies: {reason}") + logger.info(f"Skipped job {job.urn} due to unreachable dependencies: {reason}") continue # Ensure enqueued jobs can view the status change and pipelines @@ -462,7 +462,7 @@ async def cancel_pipeline(self, reason: str = "Pipeline cancelled") -> None: current_status = self.get_pipeline_status() if current_status in TERMINAL_PIPELINE_STATUSES: - logger.info(f"Pipeline {self.pipeline_id} is already in terminal status {current_status}") + logger.error(f"Pipeline {self.pipeline_id} is already in terminal status {current_status}") raise PipelineTransitionError( f"Pipeline {self.pipeline_id} is in terminal state {current_status} and may not be cancelled" ) @@ -497,13 +497,13 @@ async def pause_pipeline(self, reason: str = "Pipeline paused") -> None: current_status = self.get_pipeline_status() if current_status in TERMINAL_PIPELINE_STATUSES: - logger.info(f"Pipeline {self.pipeline_id} cannot be paused (current status: {current_status})") + logger.error(f"Pipeline {self.pipeline_id} cannot be paused (current status: {current_status})") raise PipelineTransitionError( f"Pipeline {self.pipeline_id} is in terminal state {current_status} and may not be paused" ) if current_status == PipelineStatus.PAUSED: - logger.info(f"Pipeline {self.pipeline_id} is already paused") + logger.error(f"Pipeline {self.pipeline_id} is already paused") raise PipelineTransitionError(f"Pipeline {self.pipeline_id} is already paused") self.set_pipeline_status(PipelineStatus.PAUSED) @@ -536,7 +536,7 @@ async def unpause_pipeline(self, reason: str = "Pipeline unpaused") -> None: current_status = self.get_pipeline_status() if current_status != PipelineStatus.PAUSED: - logger.info( + logger.error( f"Pipeline {self.pipeline_id} is not paused (current status: {current_status}) and may not be unpaused" ) raise PipelineTransitionError( From a551f5d9b81a480019738402b04e9d7469763d40 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Fri, 16 Jan 2026 16:36:45 -0800 Subject: [PATCH 12/70] fix: re-order imports in job manager test file --- tests/worker/lib/managers/test_job_manager.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/worker/lib/managers/test_job_manager.py b/tests/worker/lib/managers/test_job_manager.py index 5950a10d..ca54c18e 100644 --- a/tests/worker/lib/managers/test_job_manager.py +++ b/tests/worker/lib/managers/test_job_manager.py @@ -7,12 +7,13 @@ """ import pytest -from arq import ArqRedis pytest.importorskip("arq") + import re from unittest.mock import Mock, PropertyMock, patch +from arq import ArqRedis from sqlalchemy import select from sqlalchemy.orm import Session From 8ff985c6ff717a1a00e0e568e9d70be59513bc79 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Fri, 16 Jan 2026 16:39:44 -0800 Subject: [PATCH 13/70] fix: use conftest_optional import structure in worker test module --- tests/worker/lib/conftest.py | 54 +++++---------------------- tests/worker/lib/conftest_optional.py | 44 ++++++++++++++++++++++ 2 files changed, 54 insertions(+), 44 deletions(-) create mode 100644 tests/worker/lib/conftest_optional.py diff --git a/tests/worker/lib/conftest.py b/tests/worker/lib/conftest.py index ddcd25bc..39d30f13 100644 --- a/tests/worker/lib/conftest.py +++ b/tests/worker/lib/conftest.py @@ -4,22 +4,24 @@ Test configuration and fixtures for worker lib tests. """ -import pytest - -pytest.importorskip("arq") # Skip tests if arq is not installed - from datetime import datetime -from unittest.mock import Mock, patch +from unittest.mock import Mock -from arq import ArqRedis -from sqlalchemy.orm import Session +import pytest from mavedb.models.enums.job_pipeline import DependencyType, JobStatus, PipelineStatus from mavedb.models.job_dependency import JobDependency from mavedb.models.job_run import JobRun from mavedb.models.pipeline import Pipeline from mavedb.worker.lib.managers.job_manager import JobManager -from mavedb.worker.lib.managers.pipeline_manager import PipelineManager + +# Attempt to import optional top level fixtures. If the modules they depend on are not installed, +# we won't have access to our full fixture suite and only a limited subset of tests can be run. +try: + from .conftest_optional import * # noqa: F401, F403 + +except ModuleNotFoundError: + pass @pytest.fixture @@ -213,39 +215,3 @@ def mock_job_run(mock_pipeline): metadata_={}, mavedb_version=None, ) - - -@pytest.fixture -def mock_job_manager(mock_job_run): - """Create a JobManager with mocked database and Redis dependencies.""" - mock_db = Mock(spec=Session) - mock_redis = Mock(spec=ArqRedis) - - # Don't call the real constructor since it tries to load the job from DB - manager = object.__new__(JobManager) - manager.db = mock_db - manager.redis = mock_redis - manager.job_id = mock_job_run.id - - with patch.object(manager, "get_job", return_value=mock_job_run): - yield manager - - -@pytest.fixture -def mock_pipeline_manager(mock_job_manager, mock_pipeline): - """Create a PipelineManager with mocked database, Redis dependencies, and job manager.""" - mock_db = Mock(spec=Session) - mock_redis = Mock(spec=ArqRedis) - - # Don't call the real constructor since it tries to validate the pipeline - manager = object.__new__(PipelineManager) - manager.db = mock_db - manager.redis = mock_redis - manager.pipeline_id = 123 - - with ( - patch("mavedb.worker.lib.managers.pipeline_manager.JobManager") as mock_job_manager_class, - patch.object(manager, "get_pipeline", return_value=mock_pipeline), - ): - mock_job_manager_class.return_value = mock_job_manager - yield manager diff --git a/tests/worker/lib/conftest_optional.py b/tests/worker/lib/conftest_optional.py new file mode 100644 index 00000000..3a9bb268 --- /dev/null +++ b/tests/worker/lib/conftest_optional.py @@ -0,0 +1,44 @@ +from unittest.mock import Mock, patch + +import pytest +from arq import ArqRedis +from sqlalchemy.orm import Session + +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.pipeline_manager import PipelineManager + + +@pytest.fixture +def mock_job_manager(mock_job_run): + """Create a JobManager with mocked database and Redis dependencies.""" + mock_db = Mock(spec=Session) + mock_redis = Mock(spec=ArqRedis) + + # Don't call the real constructor since it tries to load the job from DB + manager = object.__new__(JobManager) + manager.db = mock_db + manager.redis = mock_redis + manager.job_id = mock_job_run.id + + with patch.object(manager, "get_job", return_value=mock_job_run): + yield manager + + +@pytest.fixture +def mock_pipeline_manager(mock_job_manager, mock_pipeline): + """Create a PipelineManager with mocked database, Redis dependencies, and job manager.""" + mock_db = Mock(spec=Session) + mock_redis = Mock(spec=ArqRedis) + + # Don't call the real constructor since it tries to validate the pipeline + manager = object.__new__(PipelineManager) + manager.db = mock_db + manager.redis = mock_redis + manager.pipeline_id = 123 + + with ( + patch("mavedb.worker.lib.managers.pipeline_manager.JobManager") as mock_job_manager_class, + patch.object(manager, "get_pipeline", return_value=mock_pipeline), + ): + mock_job_manager_class.return_value = mock_job_manager + yield manager From c2100a204f8c5a1f59931b3150ff03e5a46b0f66 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Tue, 20 Jan 2026 13:18:32 -0800 Subject: [PATCH 14/70] feat: Add decorators for job and pipeline management Adds decorators for managed jobs and pipelines. These can be applied to async ARQ functions to automatically persist their state as they execute --- src/mavedb/worker/lib/decorators/__init__.py | 27 + .../worker/lib/decorators/job_management.py | 180 ++++++ .../lib/decorators/pipeline_management.py | 188 +++++++ src/mavedb/worker/lib/managers/types.py | 14 +- src/mavedb/worker/lib/managers/utils.py | 6 +- tests/worker/lib/conftest.py | 25 - tests/worker/lib/conftest_optional.py | 13 + .../lib/decorators/test_job_management.py | 293 ++++++++++ .../decorators/test_pipeline_management.py | 526 ++++++++++++++++++ 9 files changed, 1240 insertions(+), 32 deletions(-) create mode 100644 src/mavedb/worker/lib/decorators/__init__.py create mode 100644 src/mavedb/worker/lib/decorators/job_management.py create mode 100644 src/mavedb/worker/lib/decorators/pipeline_management.py create mode 100644 tests/worker/lib/decorators/test_job_management.py create mode 100644 tests/worker/lib/decorators/test_pipeline_management.py diff --git a/src/mavedb/worker/lib/decorators/__init__.py b/src/mavedb/worker/lib/decorators/__init__.py new file mode 100644 index 00000000..1f9ad803 --- /dev/null +++ b/src/mavedb/worker/lib/decorators/__init__.py @@ -0,0 +1,27 @@ +""" +Decorator utilities for job and pipeline management. + +This module exposes decorators for managing job and pipeline lifecycle hooks, error handling, +and logging in worker functions. Use these decorators to ensure consistent state management +and observability for background jobs and pipelines. + +Available decorators: +- with_job_management: Handles job context and state transitions +- with_pipeline_management: Handles pipeline context and coordination in addition to job management + +Example usage:: + from mavedb.worker.lib.decorators import managed_workflow + + @with_pipeline_management + async def my_worker_function_in_a_pipeline(...): + ... + + @with_job_management + async def my_standalone_job_function(...): + ... +""" + +from .job_management import with_job_management +from .pipeline_management import with_pipeline_management + +__all__ = ["with_job_management", "with_pipeline_management"] diff --git a/src/mavedb/worker/lib/decorators/job_management.py b/src/mavedb/worker/lib/decorators/job_management.py new file mode 100644 index 00000000..0da0e7fd --- /dev/null +++ b/src/mavedb/worker/lib/decorators/job_management.py @@ -0,0 +1,180 @@ +""" +Managed Job Decorator - Unified decorator for complete job lifecycle management. + +Provides automatic job lifecycle tracking with support for both sync and async functions. +Includes JobManager injection for advanced operations and robust error handling. +""" + +import functools +import inspect +import logging +from typing import Any, Awaitable, Callable, TypeVar, cast + +from arq import ArqRedis +from sqlalchemy.orm import Session + +from mavedb.worker.lib.managers import JobManager +from mavedb.worker.lib.managers.types import JobResultData + +logger = logging.getLogger(__name__) + +F = TypeVar("F", bound=Callable[..., Any]) + + +def with_job_management(func: F) -> F: + """ + Decorator that adds automatic job lifecycle management to ARQ worker functions. + + Features: + - Job start/completion tracking with error handling + - JobManager injection for advanced operations + - Robust error handling with guaranteed state persistence + + The decorator injects a 'job_manager' parameter into the function that provides + access to progress updates and the underlying JobManager. + + Example: + ``` + @with_job_management + async def my_job_function(ctx, param1, param2, job_manager: JobManager): + job_manager.update_progress(10, message="Starting work") + + # Access JobManager for advanced operations + job_info = job_manager.get_job_info() + + # Do work... + job_manager.update_progress(50, message="Halfway done") + + # More work... + job_manager.update_progress(100, message="Complete") + + return {"result": "success"} + ``` + + Args: + func: The async function to decorate + + Returns: + Decorated async function with lifecycle management + """ + if not inspect.iscoroutinefunction(func): # pragma: no cover + raise ValueError("with_job_management decorator can only be applied to async functions") + + @functools.wraps(func) + async def async_wrapper(*args, **kwargs): + return await _execute_managed_job(func, args, kwargs) + + return cast(F, async_wrapper) + + +async def _execute_managed_job(func: Callable[..., Awaitable[JobResultData]], args: tuple, kwargs: dict) -> Any: + """ + Execute a managed ARQ job with full lifecycle tracking. + + This function handles the complete job lifecycle including: + - JobManager initialization from context + - Job start tracking + - ProgressTracker injection + - Async function execution + - Job completion tracking + - Error handling and cleanup + + Args: + func: Async function to execute + args: Function arguments + kwargs: Function keyword arguments + + Returns: + Function result + + Raises: + Exception: Re-raises any exception after proper job failure tracking + """ + # Extract context (implicit first argument by ARQ convention) + if not args: + raise ValueError("Managed job functions must receive context as first argument") + ctx = args[0] + + # Get database session and job ID from context + if "db" not in ctx: + raise ValueError("DB session not found in job context") + if "redis" not in ctx: + raise ValueError("Redis connection not found in job context") + + # Extract job_id (second argument by MaveDB convention) + if not args or len(args) < 2 or not isinstance(args[1], int): + raise ValueError("Job ID not found in pipeline context") + job_id = args[1] + + db_session: Session = ctx["db"] + redis_pool: ArqRedis = ctx["redis"] + + try: + # Initialize JobManager + job_manager = JobManager(db_session, redis_pool, job_id) + + # Inject the job manager into kwargs for access within the function + kwargs["job_manager"] = job_manager + + # Mark job as started and persist state + job_manager.start_job() + db_session.commit() + + # Execute the async function + result = await func(*args, **kwargs) + + # Mark job as succeeded and persist state + job_manager.succeed_job(result=result) + db_session.commit() + + return result + + except Exception as e: + # Prioritize salvaging lifecycle state + try: + db_session.rollback() + + # Build failure result data + result = { + "status": "failed", + "data": {}, + "exception_details": { + "type": type(e).__name__, + "message": str(e), + "traceback": None, # Could be populated with actual traceback if needed + }, + } + + # Mark job as failed + job_manager.fail_job(result=result, error=e) + db_session.commit() + + # TODO: Decide on retry logic based on exception type and result. + if job_manager.should_retry(): + # Prepare job for retry and persist state + job_manager.prepare_retry(reason=str(e)) + db_session.commit() + + result["status"] = "retried" + + # short circuit raising the exception. We indicate to the caller + # we did encounter a terminal failure and coordination should proceed. + return result + + except Exception as inner_e: + logger.error(f"Failed to mark job {job_id} as failed: {inner_e}") + + # TODO: Notification hooks + + # Re-raise the outer exception immediately to prevent duplicate notifications + raise e + + logger.error(f"Job {job_id} failed: {e}") + + # TODO: Notification hooks + + raise # Re-raise the exception + + +# Export decorator at module level for easy import +__all__ = ["with_job_management"] diff --git a/src/mavedb/worker/lib/decorators/pipeline_management.py b/src/mavedb/worker/lib/decorators/pipeline_management.py new file mode 100644 index 00000000..09bca4c6 --- /dev/null +++ b/src/mavedb/worker/lib/decorators/pipeline_management.py @@ -0,0 +1,188 @@ +""" +Managed Job Decorator - Unified decorator for complete job lifecycle management. + +Provides automatic job lifecycle tracking with support for both sync and async functions. +Includes JobManager injection for advanced operations and robust error handling. +""" + +import functools +import inspect +import logging +from typing import Any, Awaitable, Callable, TypeVar, cast + +from arq import ArqRedis +from sqlalchemy import select +from sqlalchemy.orm import Session + +from mavedb.models.job_run import JobRun +from mavedb.worker.lib.decorators import with_job_management +from mavedb.worker.lib.managers import PipelineManager +from mavedb.worker.lib.managers.types import JobResultData + +logger = logging.getLogger(__name__) + +F = TypeVar("F", bound=Callable[..., Any]) + + +def with_pipeline_management(func: F) -> F: + """ + Decorator that adds automatic pipeline lifecycle management to ARQ worker functions. Practically, + this means calling `PipelineManager.coordinate_pipeline()` after the decorated function completes. + + This decorator performs no pipeline coordination prior to function execution; it only + coordinates the pipeline after the function has run (whether successfully or with failure). + As a result, this decorator is best suited for jobs that represent discrete steps within a pipeline. + Pipelines are expected to be pre-defined and associated with jobs prior to execution and should be transitioned + to a running state by other means (e.g. a dedicated pipeline starter job). Attempting to start pipelines + within this decorator is not supported, and doing so may lead to unexpected behavior. + + Because pipeline management depends on job management, this decorator is built on top of the + `with_job_management` decorator. + + This decorator may be added to jobs which may or may not belong to a pipeline. If the job does not + belong to a pipeline, the decorator will simply skip pipeline coordination steps. Although pipeline + membership is optional, the decorator still will always enforce job lifecycle management via + `with_job_management`. + + Features: + - Pipeline lifecycle tracking + - Job lifecycle tracking via with_job_management + - Robust error handling, logging, and TODO(alerting) on failures + + Example: + @with_pipeline_management + async def my_job_function(ctx, param1, param2): + ... job logic ... + + On decorator exit, pipeline coordination is attempted. + + Args: + func: The async function to decorate + + Returns: + Decorated async function with lifecycle management + """ + if not inspect.iscoroutinefunction(func): # pragma: no cover + raise ValueError("with_pipeline_management decorator can only be applied to async functions") + + # Wrap the function with job management. It isn't as simple as stacking decorators + # as we can only call job management after setting up pipeline management. + + @functools.wraps(func) + async def async_wrapper(*args, **kwargs): + return await _execute_managed_pipeline(func, args, kwargs) + + return cast(F, async_wrapper) + + +async def _execute_managed_pipeline(func: Callable[..., Awaitable[JobResultData]], args: tuple, kwargs: dict) -> Any: + """ + Execute the managed pipeline function with lifecycle management. + + Args: + func: The async function to execute. + args: Positional arguments for the function. + kwargs: Keyword arguments for the function. + + Returns: + Any: The result of the function execution. + + Raises: + Exception: Propagates any exception raised during function execution. + """ + # Extract context (first argument by ARQ convention) + if not args or len(args) < 1 or not isinstance(args[0], dict): + raise ValueError("Managed pipeline functions must receive context as first argument") + ctx = args[0] + + # Get database session and pipeline ID from context + if "db" not in ctx: + raise ValueError("DB session not found in pipeline context") + if "redis" not in ctx: + raise ValueError("Redis connection not found in pipeline context") + + db_session: Session = ctx["db"] + redis_pool: ArqRedis = ctx["redis"] + + # Extract job_id (second argument by MaveDB convention) + if not args or len(args) < 2 or not isinstance(args[1], int): + raise ValueError("Job ID not found in pipeline context") + job_id = args[1] + + pipeline_manager = None + pipeline_id = None + try: + # Attempt to load the pipeline ID from the job. + # - If pipeline_id is not None, initialize PipelineManager + # - If None, skip pipeline coordination. We do not enforce every job to belong to a pipeline. + # - If error occurs, handle below + pipeline_id = db_session.execute(select(JobRun.pipeline_id).where(JobRun.id == job_id)).scalar_one() + if pipeline_id: + pipeline_manager = PipelineManager(db=db_session, redis=redis_pool, pipeline_id=pipeline_id) + + logger.info(f"Pipeline ID for job {job_id} is {pipeline_id}. Coordinating pipeline after job execution.") + + # Wrap the function with job management, then execute. This ensures both: + # - Job lifecycle management is nested within pipeline management + # - Exceptions from the job management layer are caught here for pipeline coordination + job_managed_func = with_job_management(func) + result = await job_managed_func(*args, **kwargs) + + # Attempt to coordinate pipeline next steps after successful job execution + if pipeline_manager: + await pipeline_manager.coordinate_pipeline() + + # Commit any changes made during pipeline coordination + db_session.commit() + + logger.info(f"Pipeline {pipeline_id} associated with job {job_id} coordinated successfully") + else: + logger.info(f"No pipeline associated with job {job_id}; skipping coordination") + + return result + + except Exception as e: + try: + # Rollback any uncommitted changes + db_session.rollback() + + # Attempt one final coordination to clean up any stubborn pipeline state + if pipeline_manager: + await pipeline_manager.coordinate_pipeline() + + # Commit any changes made during final coordination + db_session.commit() + + except Exception as inner_e: + logger.error( + f"Unable to perform cleanup coordination on pipeline {pipeline_id} associated with job {job_id} after error: {inner_e}" + ) + + # No further work here. We can rely on the notification hooks below to alert on the original failure + # and should allow result generation to proceed as normal so the job can be logged. + + logger.error(f"Pipeline {pipeline_id} associated with job {job_id} failed to coordinate: {e}") + + # Build job result data for failure + result = { + "status": "failed", + "data": {}, + "exception_details": { + "type": type(e).__name__, + "message": str(e), + "traceback": None, # Could be populated with actual traceback if needed + }, + } + + # TODO: Notification hooks + + # Pipeline coordination represents the outermost operation. Swallow the exception after alerting + # so ARQ can finish the job cleanly and log results. We don't mind that we lose ARQs built in + # job marking, since we perform our own job lifecycle management via with_job_management. + return result + + # Note: No finally block needed - PipelineManager handles cleanup automatically + + +# Export decorator at module level for easy import +__all__ = ["with_pipeline_management"] diff --git a/src/mavedb/worker/lib/managers/types.py b/src/mavedb/worker/lib/managers/types.py index 68a5c217..e93b2ac2 100644 --- a/src/mavedb/worker/lib/managers/types.py +++ b/src/mavedb/worker/lib/managers/types.py @@ -1,10 +1,16 @@ -from typing import TypedDict +from typing import Optional, TypedDict + + +class ExceptionDetails(TypedDict): + type: str + message: str + traceback: Optional[str] class JobResultData(TypedDict): - output: dict - logs: str - metadata: dict + status: str + data: dict + exception_details: Optional[ExceptionDetails] class RetryHistoryEntry(TypedDict): diff --git a/src/mavedb/worker/lib/managers/utils.py b/src/mavedb/worker/lib/managers/utils.py index c607185c..91395d4a 100644 --- a/src/mavedb/worker/lib/managers/utils.py +++ b/src/mavedb/worker/lib/managers/utils.py @@ -26,12 +26,12 @@ def construct_bulk_cancellation_result(reason: str) -> JobResultData: JobResultData: Standardized result data with cancellation metadata """ return { - "output": {}, - "logs": "", - "metadata": { + "status": "cancelled", + "data": { "reason": reason, "timestamp": datetime.now().isoformat(), }, + "exception_details": None, } diff --git a/tests/worker/lib/conftest.py b/tests/worker/lib/conftest.py index 39d30f13..faf63e0e 100644 --- a/tests/worker/lib/conftest.py +++ b/tests/worker/lib/conftest.py @@ -13,7 +13,6 @@ from mavedb.models.job_dependency import JobDependency from mavedb.models.job_run import JobRun from mavedb.models.pipeline import Pipeline -from mavedb.worker.lib.managers.job_manager import JobManager # Attempt to import optional top level fixtures. If the modules they depend on are not installed, # we won't have access to our full fixture suite and only a limited subset of tests can be run. @@ -134,30 +133,6 @@ def setup_worker_db( session.commit() -@pytest.fixture -def job_manager_with_mocks(session, sample_job_run, sample_pipeline): - """Create a JobManager instance with mocked dependencies.""" - # Add test data to session - session.add(sample_job_run) - session.add(sample_pipeline) - session.commit() - - # Create JobManager instance - manager = JobManager(session, sample_job_run.id) - return manager - - -@pytest.fixture -def async_context(): - """Create a mock async context similar to ARQ worker context.""" - return { - "db": None, # Will be set by specific tests - "redis": None, # Will be set by specific tests - "job_id": 1, - "state": {}, - } - - @pytest.fixture def mock_pipeline(): """Create a mock Pipeline instance. By default, diff --git a/tests/worker/lib/conftest_optional.py b/tests/worker/lib/conftest_optional.py index 3a9bb268..badebab2 100644 --- a/tests/worker/lib/conftest_optional.py +++ b/tests/worker/lib/conftest_optional.py @@ -42,3 +42,16 @@ def mock_pipeline_manager(mock_job_manager, mock_pipeline): ): mock_job_manager_class.return_value = mock_job_manager yield manager + + +@pytest.fixture +def mock_worker_ctx(): + """Create a mock worker context dictionary for testing.""" + mock_db = Mock(spec=Session) + mock_redis = Mock(spec=ArqRedis) + + return { + "db": mock_db, + "redis": mock_redis, + "hdp": Mock(), # Mock HDP data provider + } diff --git a/tests/worker/lib/decorators/test_job_management.py b/tests/worker/lib/decorators/test_job_management.py new file mode 100644 index 00000000..2f689cbe --- /dev/null +++ b/tests/worker/lib/decorators/test_job_management.py @@ -0,0 +1,293 @@ +# ruff : noqa: E402 + +""" +Unit and integration tests for the with_job_management async decorator. +Covers status transitions, error handling, and JobManager interaction. +""" + +import pytest + +pytest.importorskip("arq") # Skip tests if arq is not installed + +import asyncio +from unittest.mock import patch + +from sqlalchemy import select + +from mavedb.models.enums.job_pipeline import JobStatus +from mavedb.models.job_run import JobRun +from mavedb.worker.lib.decorators.job_management import with_job_management +from mavedb.worker.lib.managers.constants import RETRYABLE_FAILURE_CATEGORIES +from mavedb.worker.lib.managers.exceptions import JobStateError +from mavedb.worker.lib.managers.job_manager import JobManager +from tests.helpers.transaction_spy import TransactionSpy + + +@pytest.mark.asyncio +@pytest.mark.unit +class TestManagedJobDecoratorUnit: + async def test_decorator_must_receive_ctx_as_first_argument(self, mock_job_manager): + @with_job_management + async def sample_job(not_ctx: dict, job_id: int, job_manager: JobManager): + return {"status": "ok"} + + with pytest.raises(ValueError) as exc_info, TransactionSpy.spy(mock_job_manager.db): + await sample_job() + + assert "Managed job functions must receive context as first argument" in str(exc_info.value) + + async def test_decorator_calls_wrapped_function_and_returns_result(self, mock_job_manager, mock_worker_ctx): + @with_job_management + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + return {"status": "ok"} + + with ( + patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, + patch.object(mock_job_manager, "start_job", return_value=None), + patch.object(mock_job_manager, "succeed_job", return_value=None), + TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), + ): + mock_job_manager_class.return_value = mock_job_manager + + result = await sample_job(mock_worker_ctx, 999, job_manager=mock_job_manager) + assert result == {"status": "ok"} + + async def test_decorator_calls_start_job_and_succeed_job_when_wrapped_function_succeeds( + self, mock_worker_ctx, mock_job_manager + ): + @with_job_management + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + return {"status": "ok"} + + with ( + patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, + patch.object(mock_job_manager, "start_job", return_value=None) as mock_start_job, + patch.object(mock_job_manager, "succeed_job", return_value=None) as mock_succeed_job, + TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), + ): + mock_job_manager_class.return_value = mock_job_manager + await sample_job(mock_worker_ctx, 999, job_manager=mock_job_manager) + + mock_start_job.assert_called_once() + mock_succeed_job.assert_called_once() + + async def test_decorator_calls_start_job_and_fail_job_when_wrapped_function_raises_and_no_retry( + self, mock_worker_ctx, mock_job_manager + ): + @with_job_management + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + raise RuntimeError("error in wrapped function") + + with ( + pytest.raises(RuntimeError), + patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, + patch.object(mock_job_manager, "start_job", return_value=None) as mock_start_job, + patch.object(mock_job_manager, "should_retry", return_value=False), + patch.object(mock_job_manager, "fail_job", return_value=None) as mock_fail_job, + TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True, expect_rollback=True), + ): + mock_job_manager_class.return_value = mock_job_manager + await sample_job(mock_worker_ctx, 999, job_manager=mock_job_manager) + + mock_start_job.assert_called_once() + mock_fail_job.assert_called_once() + + async def test_decorator_calls_start_job_and_retries_job_when_wrapped_function_raises_and_retry( + self, mock_worker_ctx, mock_job_manager + ): + @with_job_management + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + raise RuntimeError("error in wrapped function") + + with ( + patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, + patch.object(mock_job_manager, "start_job", return_value=None) as mock_start_job, + patch.object(mock_job_manager, "should_retry", return_value=True), + patch.object(mock_job_manager, "prepare_retry", return_value=None) as mock_prepare_retry, + TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True, expect_rollback=True), + ): + mock_job_manager_class.return_value = mock_job_manager + await sample_job(mock_worker_ctx, 999, job_manager=mock_job_manager) + + mock_start_job.assert_called_once() + mock_prepare_retry.assert_called_once_with(reason="error in wrapped function") + + @pytest.mark.parametrize("missing_key", ["db", "redis"]) + async def test_decorator_raises_value_error_if_required_context_missing( + self, mock_job_manager, mock_worker_ctx, missing_key + ): + @with_job_management + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + return {"status": "ok"} + + del mock_worker_ctx[missing_key] + + with pytest.raises(ValueError) as exc_info: + await sample_job(mock_worker_ctx, 999, job_manager=mock_job_manager) + + assert missing_key.replace("_", " ") in str(exc_info.value).lower() + assert "not found in job context" in str(exc_info.value).lower() + + async def test_decorator_propagates_exception_from_lifecycle_state_outside_except( + self, mock_job_manager, mock_worker_ctx + ): + @with_job_management + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + return {"status": "ok"} + + with ( + pytest.raises(JobStateError) as exc_info, + patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, + patch.object(mock_job_manager, "start_job", side_effect=JobStateError("error in job start")), + patch.object(mock_job_manager, "should_retry", return_value=False), + patch.object(mock_job_manager, "fail_job", return_value=None), + TransactionSpy.spy(mock_worker_ctx["db"], expect_rollback=True), + ): + mock_job_manager_class.return_value = mock_job_manager + await sample_job(mock_worker_ctx, 999, job_manager=mock_job_manager) + + assert "error in job start" in str(exc_info.value) + + async def test_decorator_raises_value_error_if_job_id_missing(self, mock_job_manager, mock_worker_ctx): + @with_job_management + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + return {"status": "ok"} + + # Remove job_id from args to simulate missing job_id + with pytest.raises(ValueError) as exc_info, TransactionSpy.spy(mock_worker_ctx["db"]): + await sample_job(mock_worker_ctx) + + assert "job id not found in pipeline context" in str(exc_info.value).lower() + + async def test_decorator_propagates_exception_from_wrapped_function_inside_except( + self, mock_job_manager, mock_worker_ctx + ): + @with_job_management + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + raise RuntimeError("error in wrapped function") + + with ( + pytest.raises(RuntimeError) as exc_info, + patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, + patch.object(mock_job_manager, "start_job", return_value=None), + patch.object(mock_job_manager, "should_retry", return_value=False), + patch.object(mock_job_manager, "fail_job", side_effect=JobStateError("error in job fail")), + TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=False, expect_rollback=True), + ): + mock_job_manager_class.return_value = mock_job_manager + await sample_job(mock_worker_ctx, 999, job_manager=mock_job_manager) + + # Errors within the main try block should take precedence + assert "error in wrapped function" in str(exc_info.value) + + async def test_decorator_passes_job_manager_to_wrapped(self, mock_job_manager, mock_worker_ctx): + @with_job_management + async def sample_job(ctx, job_id: int, job_manager): + assert isinstance(job_manager, JobManager) + return True + + with ( + patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, + patch.object(mock_job_manager, "start_job", return_value=None), + patch.object(mock_job_manager, "succeed_job", return_value=None), + TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), + ): + mock_job_manager_class.return_value = mock_job_manager + assert await sample_job(mock_worker_ctx, 999, job_manager=mock_job_manager) + + +@pytest.mark.asyncio +@pytest.mark.integration +class TestManagedJobDecoratorIntegration: + """Integration tests for with_job_management decorator.""" + + async def test_decorator_integrated_job_lifecycle_success( + self, session, arq_redis, sample_job_run, standalone_worker_context, setup_worker_db + ): + # Use an event to control when the job completes + event = asyncio.Event() + + @with_job_management + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + await event.wait() # Simulate async work, block until test signals + return {"status": "ok"} + + # Start the job (it will block at event.wait()) + job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id, job_manager=None)) + + # At this point, the job should be started but not completed + await asyncio.sleep(0.1) # Give the event loop a moment to start the job + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING + + # Now allow the job to complete + event.set() + await job_task + + # After completion, status should be SUCCEEDED + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.SUCCEEDED + + async def test_decorator_integrated_job_lifecycle_failure( + self, session, arq_redis, sample_job_run, standalone_worker_context, setup_worker_db + ): + # Use an event to control when the job completes + event = asyncio.Event() + + @with_job_management + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + await event.wait() # Simulate async work, block until test signals + raise RuntimeError("Simulated job failure") + + # Start the job (it will block at event.wait()) + job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id, job_manager=None)) + + # At this point, the job should be started but not in error + await asyncio.sleep(0.1) # Give the event loop a moment to start the job + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING + + # Now allow the job to complete with failure. This failure + # should be propagated out of the job_task. + with pytest.raises(RuntimeError): + event.set() + await job_task + + # After failure, status should be FAILED + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.FAILED + + async def test_decorator_integrated_job_lifecycle_retry( + self, session, arq_redis, sample_job_run, standalone_worker_context, setup_worker_db + ): + # Use an event to control when the job completes + event = asyncio.Event() + + @with_job_management + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + sample_job_run.failure_category = RETRYABLE_FAILURE_CATEGORIES[0] # Set a retryable failure category + await event.wait() # Simulate async work, block until test signals + raise RuntimeError("Simulated job failure for retry") + + # Start the job (it will block at event.wait()) + job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id, job_manager=None)) + + # At this point, the job should be started but not in error + await asyncio.sleep(0.1) # Give the event loop a moment to start the job + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING + + # TODO: We patch `should_retry` to return True to force a retry scenario. After implementing failure + # categorization in the worker, this patch can be removed and we should directly test retry logic based + # on failure categories. + # + # Now allow the job to complete with failure that triggers a retry. This failure + # should be swallowed by the job_task. + with patch.object(JobManager, "should_retry", return_value=True): + event.set() + await job_task + + # After failure with retry, status should be PENDING + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.PENDING + assert job.retry_count == 1 # Ensure it attempted once before retrying diff --git a/tests/worker/lib/decorators/test_pipeline_management.py b/tests/worker/lib/decorators/test_pipeline_management.py new file mode 100644 index 00000000..eb843aac --- /dev/null +++ b/tests/worker/lib/decorators/test_pipeline_management.py @@ -0,0 +1,526 @@ +# ruff : noqa: E402 + +""" +Unit tests for the with_pipeline_management async decorator. +Covers orchestration steps, error handling, and PipelineManager interaction. +""" + +import pytest + +pytest.importorskip("arq") # Skip tests if arq is not installed + +import asyncio +from unittest.mock import MagicMock, patch + +from sqlalchemy import select + +from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus +from mavedb.models.job_run import JobRun +from mavedb.models.pipeline import Pipeline +from mavedb.worker.lib.decorators.pipeline_management import with_pipeline_management +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.pipeline_manager import PipelineManager +from tests.helpers.transaction_spy import TransactionSpy + + +@pytest.mark.asyncio +@pytest.mark.unit +class TestPipelineManagementDecoratorUnit: + """Unit tests for the with_pipeline_management decorator.""" + + async def test_decorator_must_receive_ctx_as_first_argument(self, mock_pipeline_manager): + @with_pipeline_management + async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): + return {"status": "ok"} + + with pytest.raises(ValueError) as exc_info, TransactionSpy.spy(mock_pipeline_manager.db): + await sample_job() + + assert "Managed pipeline functions must receive context as first argument" in str(exc_info.value) + + @pytest.mark.parametrize("missing_key", ["db", "redis"]) + async def test_decorator_raises_value_error_if_required_context_missing( + self, mock_pipeline_manager, mock_worker_ctx, missing_key + ): + @with_pipeline_management + async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): + return {"status": "ok"} + + del mock_worker_ctx[missing_key] + + with pytest.raises(ValueError) as exc_info, TransactionSpy.spy(mock_pipeline_manager.db): + await sample_job(mock_worker_ctx, 999, mock_pipeline_manager) + + assert missing_key.replace("_", " ") in str(exc_info.value).lower() + assert "not found in pipeline context" in str(exc_info.value).lower() + + async def test_decorator_raises_value_error_if_job_id_missing(self, mock_pipeline_manager, mock_worker_ctx): + @with_pipeline_management + async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): + return {"status": "ok"} + + # Remove job_id from args to simulate missing job_id + with pytest.raises(ValueError) as exc_info, TransactionSpy.spy(mock_pipeline_manager.db): + await sample_job(mock_worker_ctx, mock_pipeline_manager) + + assert "job id not found in pipeline context" in str(exc_info.value).lower() + + async def test_decorator_swallows_exception_if_cant_fetch_pipeline_id(self, mock_pipeline_manager, mock_worker_ctx): + @with_pipeline_management + async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): + return {"status": "ok"} + + with ( + TransactionSpy.mock_database_execution_failure( + mock_worker_ctx["db"], + exception=ValueError("job id not found in pipeline context"), + expect_rollback=True, + ), + ): + await sample_job(mock_worker_ctx, 999) + + async def test_decorator_fetches_pipeline_from_db_and_constructs_pipeline_manager( + self, mock_pipeline_manager, mock_worker_ctx, mock_pipeline + ): + with ( + # patch the with_job_management decorator to be a no-op + patch("mavedb.worker.lib.decorators.pipeline_management.with_job_management", wraps=lambda f: f), + patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, + patch.object( + mock_worker_ctx["db"], "execute", return_value=MagicMock(scalar_one=MagicMock(return_value=123)) + ) as mock_execute, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None), + TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), + ): + mock_pipeline_manager_class.return_value = mock_pipeline_manager + + # Sample jobs should be defined within the with scope to mock the job management decorator + @with_pipeline_management + async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): + return {"status": "ok"} + + result = await sample_job(mock_worker_ctx, 999, pipeline_manager=mock_pipeline_manager) + + mock_execute.assert_called_once() + assert result == {"status": "ok"} + + async def test_decorator_skips_coordination_when_no_pipeline_exists(self, mock_pipeline_manager, mock_worker_ctx): + with ( + # patch the with_job_management decorator to be a no-op + patch("mavedb.worker.lib.decorators.pipeline_management.with_job_management", wraps=lambda f: f), + patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, + patch.object( + mock_worker_ctx["db"], "execute", return_value=MagicMock(scalar_one=MagicMock(return_value=None)) + ) as mock_execute, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate_pipeline, + # We shouldn't expect any commits since no pipeline coordination occurs + TransactionSpy.spy(mock_worker_ctx["db"]), + ): + mock_pipeline_manager_class.return_value = mock_pipeline_manager + + @with_pipeline_management + async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): + return {"status": "ok"} + + result = await sample_job(mock_worker_ctx, 999, pipeline_manager=mock_pipeline_manager) + + mock_execute.assert_called_once() + mock_coordinate_pipeline.assert_not_called() + assert result == {"status": "ok"} + + async def test_decorator_calls_wrapped_function_and_returns_result( + self, mock_pipeline_manager, mock_worker_ctx, mock_pipeline + ): + with ( + # patch the with_job_management decorator to be a no-op + patch( + "mavedb.worker.lib.decorators.pipeline_management.with_job_management", wraps=lambda f: f + ) as mock_with_job_mgmt, + patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, + patch.object( + mock_worker_ctx["db"], "execute", return_value=MagicMock(scalar_one=MagicMock(return_value=123)) + ), + patch.object(mock_pipeline_manager, "get_pipeline", return_value=mock_pipeline), + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None), + TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), + ): + mock_pipeline_manager_class.return_value = mock_pipeline_manager + + @with_pipeline_management + async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): + return {"status": "ok"} + + result = await sample_job(mock_worker_ctx, 999, pipeline_manager=mock_pipeline_manager) + + mock_with_job_mgmt.assert_called_once() + assert result == {"status": "ok"} + + async def test_decorator_calls_pipeline_manager_coordinate_pipeline_after_wrapped_function( + self, mock_pipeline_manager, mock_worker_ctx, mock_pipeline + ): + with ( + # patch the with_job_management decorator to be a no-op + patch( + "mavedb.worker.lib.decorators.pipeline_management.with_job_management", + wraps=lambda f: f, + ), + patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, + patch.object( + mock_worker_ctx["db"], "execute", return_value=MagicMock(scalar_one=MagicMock(return_value=123)) + ), + patch.object(mock_pipeline_manager, "get_pipeline", return_value=mock_pipeline), + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate_pipeline, + TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), + ): + mock_pipeline_manager_class.return_value = mock_pipeline_manager + + @with_pipeline_management + async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): + return {"status": "ok"} + + await sample_job(mock_worker_ctx, 999, pipeline_manager=mock_pipeline_manager) + + mock_coordinate_pipeline.assert_called_once() + + async def test_decorator_swallows_exception_from_wrapped_function(self, mock_pipeline_manager, mock_worker_ctx): + with ( + # patch the with_job_management decorator to be a no-op + patch( + "mavedb.worker.lib.decorators.pipeline_management.with_job_management", + wraps=lambda f: f, + ), + patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None), + TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True, expect_rollback=True), + ): + mock_pipeline_manager_class.return_value = mock_pipeline_manager + + @with_pipeline_management + async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): + raise RuntimeError("error in wrapped function") + + await sample_job(mock_worker_ctx, 999, pipeline_manager=mock_pipeline_manager) + + # TODO: Assert calls for notification hooks and job result data + + async def test_decorator_swallows_exception_from_pipeline_manager_coordinate_pipeline( + self, mock_pipeline_manager, mock_worker_ctx + ): + with ( + # patch the with_job_management decorator to be a no-op + patch( + "mavedb.worker.lib.decorators.pipeline_management.with_job_management", + wraps=lambda f: f, + ), + patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, + patch.object( + mock_pipeline_manager, + "coordinate_pipeline", + side_effect=RuntimeError("error in coordinate_pipeline"), + ), + # Exception raised from coordinate_pipeline should trigger rollback but prevent commit + TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=False, expect_rollback=True), + ): + mock_pipeline_manager_class.return_value = mock_pipeline_manager + + @with_pipeline_management + async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): + return {"status": "ok"} + + await sample_job(mock_worker_ctx, 999, pipeline_manager=mock_pipeline_manager) + + # TODO: Assert calls for notification hooks and job result data + + async def test_decorator_swallows_exception_from_job_management_decorator( + self, mock_pipeline_manager, mock_worker_ctx + ): + def passthrough_decorator(f): + return f + + with ( + # patch the with_job_management decorator to raise an error + patch( + "mavedb.worker.lib.decorators.pipeline_management.with_job_management", + wraps=passthrough_decorator, + side_effect=ValueError("error in job management decorator"), + ) as mock_with_job_mgmt, + patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, + TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=False, expect_rollback=True), + ): + mock_pipeline_manager_class.return_value = mock_pipeline_manager + + @with_pipeline_management + async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): + return {"status": "ok"} + + await sample_job(mock_worker_ctx, 999, pipeline_manager=mock_pipeline_manager) + + mock_with_job_mgmt.assert_called_once() + # TODO: Assert calls for notification hooks and job result data + + +@pytest.mark.asyncio +@pytest.mark.integration +class TestPipelineManagementDecoratorIntegration: + """Integration tests for the with_pipeline_management decorator.""" + + async def test_decorator_integrated_pipeline_lifecycle_success( + self, + session, + arq_redis, + sample_job_run, + sample_dependent_job_run, + standalone_worker_context, + setup_worker_db, + sample_pipeline, + ): + # Use an event to control when the job completes + event = asyncio.Event() + dep_event = asyncio.Event() + + # Transition pipeline to RUNNING to allow job execution. This step of pipeline management + # is intentionally not handled by the decorator. + sample_pipeline.status = PipelineStatus.RUNNING + session.commit() + + @with_pipeline_management + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + await event.wait() # Simulate async work, block until test signals + return {"status": "ok"} + + @with_pipeline_management + async def sample_dependent_job(ctx: dict, job_id: int, job_manager: JobManager): + await dep_event.wait() # Simulate async work, block until test signals + return {"status": "ok"} + + # Start the job (it will block at event.wait()) + job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id, job_manager=None)) + + # At this point, the job should be started but not completed + await asyncio.sleep(0.1) # Give the event loop a moment to start the job + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING + + pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + assert pipeline.status == PipelineStatus.RUNNING + + # Now allow the job to complete and flush the Redis queue. Flush the queue first to ensure + # we don't mistakenly flush our queued job. + await arq_redis.flushdb() + event.set() + await job_task + + # After completion, status should be SUCCEEDED + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.SUCCEEDED + + pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + + # Pipeline remains RUNNING after job success, another job was queued. + assert pipeline.status == PipelineStatus.RUNNING + + queued_jobs = await arq_redis.queued_jobs() + assert len(queued_jobs) == 1 # Ensure the next job was queued + + # Simulate execution of next job by running the dependent job. + # Start the job (it will block at event.wait()) + dependent_job_task = asyncio.create_task( + sample_dependent_job(standalone_worker_context, sample_dependent_job_run.id, job_manager=None) + ) + + # At this point, the job should be started but not completed + await asyncio.sleep(0.1) # Give the event loop a moment to start the job + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING + + pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + assert pipeline.status == PipelineStatus.RUNNING + + # Now allow the job to complete and flush the Redis queue. Flush the queue first to ensure + # we don't mistakenly flush our queued job. + await arq_redis.flushdb() + dep_event.set() + await dependent_job_task + + # After completion, status should be SUCCEEDED + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.SUCCEEDED + + # Now that all jobs are complete, the pipeline should be SUCCEEDED + pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + assert pipeline.status == PipelineStatus.SUCCEEDED + + # No further jobs should be queued + queued_jobs = await arq_redis.queued_jobs() + assert len(queued_jobs) == 0 + + async def test_decorator_integrated_pipeline_lifecycle_retryable_failure( + self, + session, + arq_redis, + sample_job_run, + sample_dependent_job_run, + standalone_worker_context, + setup_worker_db, + sample_pipeline, + ): + # Use an event to control when the job completes + event = asyncio.Event() + retry_event = asyncio.Event() + dep_event = asyncio.Event() + + # Transition pipeline to RUNNING to allow job execution. This step of pipeline management + # is intentionally not handled by the decorator. + sample_pipeline.status = PipelineStatus.RUNNING + session.commit() + + @with_pipeline_management + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + await event.wait() # Simulate async work, block until test signals + raise RuntimeError("Simulated job failure for retry") + + @with_pipeline_management + async def sample_retried_job(ctx: dict, job_id: int, job_manager: JobManager): + await retry_event.wait() # Simulate async work, block until test signals + return {"status": "ok"} + + @with_pipeline_management + async def sample_dependent_job(ctx: dict, job_id: int, job_manager: JobManager): + await dep_event.wait() # Simulate async work, block until test signals + return {"status": "ok"} + + # Start the job (it will block at event.wait()) + job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id, job_manager=None)) + + # At this point, the job should be started but not completed + await asyncio.sleep(0.1) # Give the event loop a moment to start the job + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING + + pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + assert pipeline.status == PipelineStatus.RUNNING + + # Now allow the job to complete with failure that triggers a retry. This failure + # should be swallowed by the job_task. + with patch.object(JobManager, "should_retry", return_value=True): + event.set() + await job_task + + # After failure with retry, status should be QUEUED + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.QUEUED + assert job.retry_count == 1 # Ensure it attempted once before retrying + + # Now start the retried job (it will block at retry_event.wait()) + retried_job_task = asyncio.create_task( + sample_retried_job(standalone_worker_context, sample_job_run.id, job_manager=None) + ) + await asyncio.sleep(0.1) # Give the event loop a moment to start the job + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING + + # The pipeline should remain running + pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + assert pipeline.status == PipelineStatus.RUNNING + + # Now allow the retried job to complete successfully + await arq_redis.flushdb() + retry_event.set() + await retried_job_task + + # After completion, status should be SUCCEEDED + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.SUCCEEDED + + pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + assert pipeline.status == PipelineStatus.RUNNING + + queued_jobs = await arq_redis.queued_jobs() + assert len(queued_jobs) == 1 # Ensure the next job was queued + + # Simulate execution of next job by running the dependent job. + # Start the job (it will block at event.wait()) + dependent_job_task = asyncio.create_task( + sample_dependent_job(standalone_worker_context, sample_dependent_job_run.id, job_manager=None) + ) + + # At this point, the job should be started but not completed + await asyncio.sleep(0.1) # Give the event loop a moment to start the job + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING + + pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + assert pipeline.status == PipelineStatus.RUNNING + + # Now allow the job to complete and flush the Redis queue. Flush the queue first to ensure + # we don't mistakenly flush our queued job. + await arq_redis.flushdb() + dep_event.set() + await dependent_job_task + + # After completion, status should be SUCCEEDED + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.SUCCEEDED + + # Now that all jobs are complete, the pipeline should be SUCCEEDED + pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + assert pipeline.status == PipelineStatus.SUCCEEDED + + queued_jobs = await arq_redis.queued_jobs() + assert len(queued_jobs) == 0 # Ensure no further jobs were queued + + async def test_decorator_integrated_pipeline_lifecycle_non_retryable_failure( + self, + session, + arq_redis, + sample_job_run, + sample_dependent_job_run, + standalone_worker_context, + setup_worker_db, + sample_pipeline, + ): + # Use an event to control when the job completes + event = asyncio.Event() + + # Transition pipeline to RUNNING to allow job execution. This step of pipeline management + # is intentionally not handled by the decorator. + sample_pipeline.status = PipelineStatus.RUNNING + session.commit() + + @with_pipeline_management + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + await event.wait() # Simulate async work, block until test signals + raise RuntimeError("Simulated job failure") + + # Start the job (it will block at event.wait()) + job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id, job_manager=None)) + + # At this point, the job should be started but not completed + await asyncio.sleep(0.1) # Give the event loop a moment to start the job + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING + + pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + assert pipeline.status == PipelineStatus.RUNNING + + # Now allow the job to complete with failure. This failure + # should be swallowed by the pipeline manager + event.set() + await job_task + + # After failure with no retry, status should be FAILED + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.FAILED + + pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + + # Pipeline should be marked FAILED after job failure + assert pipeline.status == PipelineStatus.FAILED + + # No further jobs should be queued + queued_jobs = await arq_redis.queued_jobs() + assert len(queued_jobs) == 0 + + # Dependent job should transition to skipped since it was never queued + job = session.execute(select(JobRun).where(JobRun.id == sample_dependent_job_run.id)).scalar_one() + assert job.status == JobStatus.SKIPPED From 3799d847acd21da29373bbda6c76e6a7fa4a55c4 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Tue, 20 Jan 2026 13:46:04 -0800 Subject: [PATCH 15/70] feat: use context for logging in job manager --- src/mavedb/worker/lib/managers/job_manager.py | 137 +++++++++++++----- 1 file changed, 102 insertions(+), 35 deletions(-) diff --git a/src/mavedb/worker/lib/managers/job_manager.py b/src/mavedb/worker/lib/managers/job_manager.py index a3e8a430..f89aecbb 100644 --- a/src/mavedb/worker/lib/managers/job_manager.py +++ b/src/mavedb/worker/lib/managers/job_manager.py @@ -34,7 +34,7 @@ import logging import traceback from datetime import datetime -from typing import Optional +from typing import Any, Optional from arq import ArqRedis from sqlalchemy import select @@ -42,6 +42,7 @@ from sqlalchemy.orm import Session from sqlalchemy.orm.attributes import flag_modified +from mavedb.lib.logging.context import format_raised_exception_info_as_dict from mavedb.models.enums.job_pipeline import FailureCategory, JobStatus from mavedb.models.job_run import JobRun from mavedb.worker.lib.managers.base_manager import BaseManager @@ -131,6 +132,8 @@ class JobManager(BaseManager): worker thread and should not be shared across concurrent operations. """ + context: dict[str, Any] = {} + def __init__(self, db: Session, redis: ArqRedis, job_id: int): """Initialize JobManager for a specific job. @@ -159,6 +162,19 @@ def __init__(self, db: Session, redis: ArqRedis, job_id: int): job = self.get_job() self.pipeline_id = job.pipeline_id if job else None + self.save_to_context( + {"job_id": str(self.job_id), "pipeline_id": str(self.pipeline_id) if self.pipeline_id else None} + ) + + def save_to_context(self, ctx: dict) -> dict[str, Any]: + for k, v in ctx.items(): + self.context[k] = v + + return self.context + + def logging_context(self) -> dict[str, Any]: + return self.context + def start_job(self) -> None: """Mark job as started and initialize execution tracking. This method does not flush or commit the database session; the caller is responsible for persisting changes. @@ -185,7 +201,10 @@ def start_job(self) -> None: """ job_run = self.get_job() if job_run.status not in STARTABLE_JOB_STATUSES: - logger.error(f"Invalid job start attempt for job {self.job_id} in status {job_run.status}") + self.save_to_context({"job_status": str(job_run.status)}) + logger.error( + "Invalid job start attempt: status not in STARTABLE_JOB_STATUSES", extra=self.logging_context() + ) raise JobTransitionError(f"Cannot start job {self.job_id} from status {job_run.status}") try: @@ -193,10 +212,12 @@ def start_job(self) -> None: job_run.started_at = datetime.now() job_run.progress_message = "Job began execution" except (AttributeError, TypeError, KeyError, ValueError) as e: - logger.debug(f"Failed to update job start state for job {self.job_id}: {e}") + self.save_to_context(format_raised_exception_info_as_dict(e)) + logger.debug("Encountered an unexpected error while updating job start state", extra=self.logging_context()) raise JobStateError(f"Failed to update job start state: {e}") - logger.info(f"Job {self.job_id} marked as started") + self.save_to_context({"job_status": str(job_run.status)}) + logger.info("Job marked as started", extra=self.logging_context()) def complete_job(self, status: JobStatus, result: JobResultData, error: Optional[Exception] = None) -> None: """Mark job as completed with the specified final status. This method does @@ -248,7 +269,8 @@ def complete_job(self, status: JobStatus, result: JobResultData, error: Optional """ # Validate terminal status if status not in TERMINAL_JOB_STATUSES: - logger.error(f"Invalid job completion status {status} for job {self.job_id}") + self.save_to_context({"job_status": str(status)}) + logger.error("Invalid job completion status: not in TERMINAL_JOB_STATUSES", extra=self.logging_context()) raise JobTransitionError( f"Cannot commplete job to status: {status}. Must complete to a terminal status: {TERMINAL_JOB_STATUSES}" ) @@ -275,11 +297,17 @@ def complete_job(self, status: JobStatus, result: JobResultData, error: Optional # TODO: Classify failure category based on error type job_run.failure_category = FailureCategory.UNKNOWN + self.save_to_context({"failure_category": str(job_run.failure_category)}) + except (AttributeError, TypeError, KeyError, ValueError) as e: - logger.debug(f"Failed to update job completion state for job {self.job_id}: {e}") + self.save_to_context(format_raised_exception_info_as_dict(e)) + logger.debug( + "Encountered an unexpected error while updating job completion state", extra=self.logging_context() + ) raise JobStateError(f"Failed to update job completion state: {e}") - logger.info(f"Job {self.job_id} marked as {status.value}") + self.save_to_context({"job_status": str(job_run.status)}) + logger.info("Job marked as completed", extra=self.logging_context()) def fail_job(self, error: Exception, result: JobResultData) -> None: """Mark job as failed and record error details. This method does @@ -305,7 +333,7 @@ def fail_job(self, error: Exception, result: JobResultData) -> None: >>> try: ... validate_data(input_data) ... except ValidationError as e: - ... manager.fail_job(error=e) + ... manager.fail_job(error=e, result={}) Failure with partial results: >>> try: @@ -465,7 +493,8 @@ def prepare_retry(self, reason: str = "retry_requested") -> None: """ job_run = self.get_job() if job_run.status not in RETRYABLE_JOB_STATUSES: - logger.error(f"Invalid job retry attempt for job {self.job_id} in status {job_run.status}") + self.save_to_context({"job_status": str(job_run.status)}) + logger.error("Invalid job retry status: status not in RETRYABLE_JOB_STATUSES", extra=self.logging_context()) raise JobTransitionError(f"Cannot retry job {self.job_id} due to invalid state ({job_run.status})") try: @@ -493,10 +522,12 @@ def prepare_retry(self, reason: str = "retry_requested") -> None: flag_modified(job_run, "metadata_") except (AttributeError, TypeError, KeyError, ValueError) as e: - logger.debug(f"Failed to update job retry state for job {self.job_id}: {e}") + self.save_to_context(format_raised_exception_info_as_dict(e)) + logger.debug("Encountered an unexpected error while updating job retry state", extra=self.logging_context()) raise JobStateError(f"Failed to update job retry state: {e}") - logger.info(f"Job {self.job_id} successfully prepared for retry (attempt {job_run.retry_count})") + self.save_to_context({"job_status": str(job_run.status), "retry_attempt": job_run.retry_count}) + logger.info("Job successfully prepared for retry", extra=self.logging_context()) def prepare_queue(self) -> None: """Prepare job for enqueueing by setting QUEUED status. This method does @@ -511,17 +542,20 @@ def prepare_queue(self) -> None: """ job_run = self.get_job() if job_run.status != JobStatus.PENDING: - logger.error(f"Invalid job queue attempt for job {self.job_id} in status {job_run.status}") + self.save_to_context({"job_status": str(job_run.status)}) + logger.error("Invalid job queue attempt: status not PENDING", extra=self.logging_context()) raise JobTransitionError(f"Cannot queue job {self.job_id} from status {job_run.status}") try: job_run.status = JobStatus.QUEUED job_run.progress_message = "Job queued for execution" except (AttributeError, TypeError, KeyError, ValueError) as e: - logger.debug(f"Failed to prepare job {self.job_id} for queueing: {e}") + self.save_to_context(format_raised_exception_info_as_dict(e)) + logger.debug("Encountered an unexpected error while updating job queue state", extra=self.logging_context()) raise JobStateError(f"Failed to update job queue state: {e}") - logger.debug(f"Job {self.job_id} prepared for queueing") + self.save_to_context({"job_status": str(job_run.status)}) + logger.debug("Job successfully prepared for queueing", extra=self.logging_context()) def reset_job(self) -> None: """Reset job to initial state for re-execution. This method does @@ -562,10 +596,12 @@ def reset_job(self) -> None: job_run.metadata_ = {} except (AttributeError, TypeError, KeyError, ValueError) as e: - logger.debug(f"Failed to update job reset state for job {self.job_id}: {e}") + self.save_to_context(format_raised_exception_info_as_dict(e)) + logger.debug("Encountered an unexpected error while resetting job state", extra=self.logging_context()) raise JobStateError(f"Failed to reset job state: {e}") - logger.info(f"Job {self.job_id} successfully reset to initial state") + self.save_to_context({"job_status": str(job_run.status), "retry_attempt": job_run.retry_count}) + logger.info("Job successfully reset to initial state", extra=self.logging_context()) def update_progress(self, current: int, total: int = 100, message: Optional[str] = None) -> None: """Update job progress information during execution. This method does @@ -617,10 +653,14 @@ def update_progress(self, current: int, total: int = 100, message: Optional[str] job_run.progress_message = message except (AttributeError, TypeError, KeyError, ValueError) as e: - logger.debug(f"Failed to update job progress for job {self.job_id}: {e}") + self.save_to_context(format_raised_exception_info_as_dict(e)) + logger.debug("Encountered an unexpected error while updating job progress", extra=self.logging_context()) raise JobStateError(f"Failed to update job progress state: {e}") - logger.debug(f"Updated progress for job {self.job_id}: {current}/{total}") + self.save_to_context( + {"job_progress_current": current, "job_progress_total": total, "job_progress_message": message} + ) + logger.debug("Updated progress successfully for job", extra=self.logging_context()) def update_status_message(self, message: str) -> None: """Update job status message without changing progress. This method does @@ -646,10 +686,14 @@ def update_status_message(self, message: str) -> None: try: job_run.progress_message = message except (AttributeError, TypeError, KeyError, ValueError) as e: - logger.debug(f"Failed to update job status message for job {self.job_id}: {e}") + self.save_to_context(format_raised_exception_info_as_dict(e)) + logger.debug( + "Encountered an unexpected error while updating job status message", extra=self.logging_context() + ) raise JobStateError(f"Failed to update job status message state: {e}") - logger.debug(f"Updated status message for job {self.job_id}: {message}") + self.save_to_context({"job_progress_message": message}) + logger.debug("Updated status message successfully for job", extra=self.logging_context()) def increment_progress(self, amount: int = 1, message: Optional[str] = None) -> None: """Increment job progress by a specified amount. This method does @@ -685,10 +729,20 @@ def increment_progress(self, amount: int = 1, message: Optional[str] = None) -> if message: job_run.progress_message = message except (AttributeError, TypeError, KeyError, ValueError) as e: - logger.debug(f"Failed to increment job progress for job {self.job_id}: {e}") + self.save_to_context(format_raised_exception_info_as_dict(e)) + logger.debug( + "Encountered an unexpected error while incrementing job progress", extra=self.logging_context() + ) raise JobStateError(f"Failed to increment job progress state: {e}") - logger.debug(f"Incremented progress for job {self.job_id} by {amount} to {job_run.progress_current}") + self.save_to_context( + { + "job_progress_current": current, + "job_progress_total": job_run.progress_total, + "job_progress_message": message or "", + } + ) + logger.debug("Incremented progress successfully for job", extra=self.logging_context()) def set_progress_total(self, total: int, message: Optional[str] = None) -> None: """Update the total progress value, useful when total becomes known during execution. This method does @@ -717,10 +771,14 @@ def set_progress_total(self, total: int, message: Optional[str] = None) -> None: if message: job_run.progress_message = message except (AttributeError, TypeError, KeyError, ValueError) as e: - logger.debug(f"Failed to update job progress total for job {self.job_id}: {e}") + self.save_to_context(format_raised_exception_info_as_dict(e)) + logger.debug( + "Encountered an unexpected error while updating job progress total", extra=self.logging_context() + ) raise JobStateError(f"Failed to update job progress total state: {e}") - logger.debug(f"Updated progress total for job {self.job_id} to {total}") + self.save_to_context({"job_progress_total": total, "job_progress_message": message}) + logger.debug("Updated progress total successfully for job", extra=self.logging_context()) def is_cancelled(self) -> bool: """Check if job has been cancelled or should stop execution. This method does @@ -770,29 +828,37 @@ def should_retry(self) -> bool: """ job_run = self.get_job() try: + self.save_to_context( + { + "job_retry_count": job_run.retry_count, + "job_max_retries": job_run.max_retries, + "job_failure_category": str(job_run.failure_category) if job_run.failure_category else None, + "job_status": str(job_run.status), + } + ) + # Check if job is in FAILED state if job_run.status != JobStatus.FAILED: - logger.debug(f"Job {self.job_id} not in FAILED state ({job_run.status}), cannot retry") + logger.debug("Job cannot be retried: not in FAILED state", extra=self.logging_context()) return False # Check retry count current_retries = job_run.retry_count or 0 if current_retries >= job_run.max_retries: - logger.debug(f"Job {self.job_id} has reached max retries ({current_retries}/{job_run.max_retries})") + logger.debug("Job cannot be retried: max retries reached", extra=self.logging_context()) return False # Check if failure category is retryable - if job_run.failure_category in RETRYABLE_FAILURE_CATEGORIES: - logger.debug( - f"Job {self.job_id} error {job_run.failure_category} is retryable ({current_retries}/{job_run.max_retries})" - ) - return True + if job_run.failure_category not in RETRYABLE_FAILURE_CATEGORIES: + logger.debug("Job cannot be retried: failure category not retryable", extra=self.logging_context()) + return False - logger.debug(f"Job {self.job_id} error {job_run.failure_category} is not retryable") - return False + logger.debug("Job is retryable", extra=self.logging_context()) + return True except (AttributeError, TypeError, KeyError, ValueError) as e: - logger.debug(f"Failed to check retry eligibility for job {self.job_id}: {e}") + self.save_to_context(format_raised_exception_info_as_dict(e)) + logger.debug("Unexpected error checking retry eligibility", extra=self.logging_context()) raise JobStateError(f"Failed to check retry eligibility state: {e}") def get_job_status(self) -> JobStatus: # pragma: no cover @@ -840,5 +906,6 @@ def get_job(self) -> JobRun: try: return self.db.execute(select(JobRun).where(JobRun.id == self.job_id)).scalar_one() except SQLAlchemyError as e: - logger.debug(f"SQL query failed getting job info for {self.job_id}: {e}") + self.save_to_context(format_raised_exception_info_as_dict(e)) + logger.debug("Unexpected error fetching job info", extra=self.logging_context()) raise DatabaseConnectionError(f"Failed to fetch job {self.job_id}: {e}") From 155e5491b9d9006823fd53005e8c6a1e3c0a5f70 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 21 Jan 2026 13:06:20 -0800 Subject: [PATCH 16/70] feat: decorator for job run record guarantees In certain instances (cron jobs in particular), worker processes are invoked from contexts where we have not yet added a job run record to the database. In such cases, it becomes useful to first guarantee a minimal record is added to the database such that the job run can be tracked via existing managed job decorators. This feature adds such a decorator and associated tests.` --- src/mavedb/worker/lib/decorators/__init__.py | 3 +- .../worker/lib/decorators/job_guarantee.py | 97 +++++++++++++++++++ src/mavedb/worker/lib/decorators/py.typed | 0 .../lib/decorators/test_job_guarantee.py | 96 ++++++++++++++++++ 4 files changed, 195 insertions(+), 1 deletion(-) create mode 100644 src/mavedb/worker/lib/decorators/job_guarantee.py create mode 100644 src/mavedb/worker/lib/decorators/py.typed create mode 100644 tests/worker/lib/decorators/test_job_guarantee.py diff --git a/src/mavedb/worker/lib/decorators/__init__.py b/src/mavedb/worker/lib/decorators/__init__.py index 1f9ad803..4bef68d5 100644 --- a/src/mavedb/worker/lib/decorators/__init__.py +++ b/src/mavedb/worker/lib/decorators/__init__.py @@ -21,7 +21,8 @@ async def my_standalone_job_function(...): ... """ +from .job_guarantee import with_guaranteed_job_run_record from .job_management import with_job_management from .pipeline_management import with_pipeline_management -__all__ = ["with_job_management", "with_pipeline_management"] +__all__ = ["with_job_management", "with_pipeline_management", "with_guaranteed_job_run_record"] diff --git a/src/mavedb/worker/lib/decorators/job_guarantee.py b/src/mavedb/worker/lib/decorators/job_guarantee.py new file mode 100644 index 00000000..fb118b3a --- /dev/null +++ b/src/mavedb/worker/lib/decorators/job_guarantee.py @@ -0,0 +1,97 @@ +""" +Job Guarantee Decorator - Ensures a JobRun record is persisted before job execution. + +This decorator guarantees that a corresponding JobRun record is created and tracked for the decorated +function in the database before execution begins. It is designed to be stacked before managed job +decorators (such as with_job_management) to provide a consistent audit trail and robust error handling +for all job entrypoints, including cron-triggered jobs. + +Features: +- Persists JobRun with job_type, function name, and parameters +- Integrates cleanly with managed job and pipeline decorators + +Example: + @with_guaranteed_job_run_record("cron_job") + @with_job_management + async def my_cron_job(ctx, ...): + ... +""" + +import functools +from typing import Any, Awaitable, Callable, TypeVar + +from sqlalchemy.orm import Session + +from mavedb import __version__ +from mavedb.models.enums.job_pipeline import JobStatus +from mavedb.models.job_run import JobRun +from mavedb.worker.lib.managers.types import JobResultData + +F = TypeVar("F", bound=Callable[..., Awaitable[Any]]) + + +def with_guaranteed_job_run_record(job_type: str) -> Callable[[F], F]: + """ + Async decorator to ensure a JobRun record is created and persisted before executing the job function. + Should be applied before the managed job decorator. + + Args: + job_type (str): The type/category of the job (e.g., "cron_job", "data_processing"). + + Returns: + Decorated async function with job run persistence guarantee. + + Example: + ``` + @with_guaranteed_job_run_record("cron_job") + @with_job_management + async def my_cron_job(ctx, ...): + ... + ``` + """ + + def decorator(func: F) -> F: + @functools.wraps(func) + async def async_wrapper(*args, **kwargs): + # No-op in test mode + if is_test_mode(): + return await func(*args, **kwargs) + + # The job id must be passed as the second argument to the wrapped function. + job = _create_job_run(job_type, func, args, kwargs) + args = list(args) + args.insert(1, job.id) + args = tuple(args) + + return await func(*args, **kwargs) + + return async_wrapper # type: ignore + + return decorator + + +def _create_job_run(job_type: str, func: Callable[..., Awaitable[JobResultData]], args: tuple, kwargs: dict) -> None: + """ + Creates and persists a JobRun record for a function before job execution. + """ + # Extract context (implicit first argument by ARQ convention) + if not args: + raise ValueError("Managed job functions must receive context as first argument") + ctx = args[0] + + # Get database session from context + if "db" not in ctx: + raise ValueError("DB session not found in job context") + + db: Session = ctx["db"] + + job_run = JobRun( + job_type=job_type, + job_function=func.__name__, + status=JobStatus.PENDING, + mavedb_version=__version__, + ) + db.add(job_run) + db.commit() + + return job_run diff --git a/src/mavedb/worker/lib/decorators/py.typed b/src/mavedb/worker/lib/decorators/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/tests/worker/lib/decorators/test_job_guarantee.py b/tests/worker/lib/decorators/test_job_guarantee.py new file mode 100644 index 00000000..3da60c87 --- /dev/null +++ b/tests/worker/lib/decorators/test_job_guarantee.py @@ -0,0 +1,96 @@ +# ruff: noqa: E402 +""" +Unit and integration tests for the with_guaranteed_job_run_record async decorator. +Covers JobRun creation, status transitions, error handling, and DB persistence. +""" + +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy import select + +from mavedb import __version__ +from mavedb.models.enums.job_pipeline import JobStatus +from mavedb.models.job_run import JobRun +from mavedb.worker.lib.decorators.job_guarantee import with_guaranteed_job_run_record +from tests.helpers.transaction_spy import TransactionSpy + + +@pytest.mark.asyncio +@pytest.mark.unit +class TestJobGuaranteeDecoratorUnit: + async def test_decorator_must_receive_ctx_as_first_argument(self, mock_worker_ctx): + @with_guaranteed_job_run_record("test_job") + async def sample_job(not_ctx: dict): + return {"status": "ok"} + + with pytest.raises(ValueError) as exc_info: + await sample_job() + + assert "Managed job functions must receive context as first argument" in str(exc_info.value) + + async def test_decorator_must_receive_db_in_ctx(self, mock_worker_ctx): + del mock_worker_ctx["db"] + + @with_guaranteed_job_run_record("test_job") + async def sample_job(not_ctx: dict): + return {"status": "ok"} + + with pytest.raises(ValueError) as exc_info: + await sample_job(mock_worker_ctx) + + assert "DB session not found in job context" in str(exc_info.value) + + async def test_decorator_calls_wrapped_function(self, mock_worker_ctx): + @with_guaranteed_job_run_record("test_job") + async def sample_job(ctx: dict): + return {"status": "ok"} + + with patch("mavedb.worker.lib.decorators.job_guarantee.JobRun") as MockJobRunClass: + MockJobRunClass.return_value = MagicMock(spec=JobRun) + + result = await sample_job(mock_worker_ctx) + + assert result == {"status": "ok"} + + async def test_decorator_creates_job_run(self, mock_worker_ctx, mock_job_run): + @with_guaranteed_job_run_record("test_job") + async def sample_job(ctx: dict): + return {"status": "ok"} + + with ( + TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), + patch("mavedb.worker.lib.decorators.job_guarantee.JobRun") as mock_job_run_class, + ): + mock_job_run_class.return_value = MagicMock(spec=JobRun) + + await sample_job(mock_worker_ctx) + + mock_job_run_class.assert_called_with( + job_type="test_job", + job_function="sample_job", + status=JobStatus.PENDING, + mavedb_version=__version__, + ) + mock_worker_ctx["db"].add.assert_called() + + +@pytest.mark.asyncio +@pytest.mark.integration +class TestJobGuaranteeDecoratorIntegration: + async def test_decorator_persists_job_run_record(self, session, standalone_worker_context): + @with_guaranteed_job_run_record("integration_job") + async def sample_job(ctx: dict): + return {"status": "ok"} + + # Flush called implicitly by commit + with TransactionSpy.spy(session, expect_flush=True, expect_commit=True): + job_task = await sample_job(standalone_worker_context) + + assert job_task == {"status": "ok"} + + job_run = session.execute(select(JobRun).order_by(JobRun.id.desc())).scalars().first() + assert job_run.status == JobStatus.PENDING + assert job_run.job_type == "integration_job" + assert job_run.job_function == "sample_job" + assert job_run.mavedb_version is not None From 4a4055d62e2a80b944dff39fd1c10b666abd1c94 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 21 Jan 2026 14:30:05 -0800 Subject: [PATCH 17/70] feat: add test mode support to job and pipeline decorators Since decorators are applied at import time, this test mode path is a pragmatic solution to run decorators without side effects during unit tests. It's more straightforward and maintainable than other solutions, and still lets us import job definitions up front to register with ARQ. --- .../worker/lib/decorators/job_guarantee.py | 1 + .../worker/lib/decorators/job_management.py | 5 +++++ .../lib/decorators/pipeline_management.py | 5 +++++ src/mavedb/worker/lib/decorators/utils.py | 20 +++++++++++++++++++ tests/conftest.py | 11 ++++++++++ 5 files changed, 42 insertions(+) create mode 100644 src/mavedb/worker/lib/decorators/utils.py diff --git a/src/mavedb/worker/lib/decorators/job_guarantee.py b/src/mavedb/worker/lib/decorators/job_guarantee.py index fb118b3a..2f464e47 100644 --- a/src/mavedb/worker/lib/decorators/job_guarantee.py +++ b/src/mavedb/worker/lib/decorators/job_guarantee.py @@ -25,6 +25,7 @@ async def my_cron_job(ctx, ...): from mavedb import __version__ from mavedb.models.enums.job_pipeline import JobStatus from mavedb.models.job_run import JobRun +from mavedb.worker.lib.decorators.utils import is_test_mode from mavedb.worker.lib.managers.types import JobResultData F = TypeVar("F", bound=Callable[..., Awaitable[Any]]) diff --git a/src/mavedb/worker/lib/decorators/job_management.py b/src/mavedb/worker/lib/decorators/job_management.py index 0da0e7fd..86068a40 100644 --- a/src/mavedb/worker/lib/decorators/job_management.py +++ b/src/mavedb/worker/lib/decorators/job_management.py @@ -13,6 +13,7 @@ from arq import ArqRedis from sqlalchemy.orm import Session +from mavedb.worker.lib.decorators.utils import is_test_mode from mavedb.worker.lib.managers import JobManager from mavedb.worker.lib.managers.types import JobResultData @@ -62,6 +63,10 @@ async def my_job_function(ctx, param1, param2, job_manager: JobManager): @functools.wraps(func) async def async_wrapper(*args, **kwargs): + # No-op in test mode + if is_test_mode(): + return await func(*args, **kwargs) + return await _execute_managed_job(func, args, kwargs) return cast(F, async_wrapper) diff --git a/src/mavedb/worker/lib/decorators/pipeline_management.py b/src/mavedb/worker/lib/decorators/pipeline_management.py index 09bca4c6..0e8944bc 100644 --- a/src/mavedb/worker/lib/decorators/pipeline_management.py +++ b/src/mavedb/worker/lib/decorators/pipeline_management.py @@ -16,6 +16,7 @@ from mavedb.models.job_run import JobRun from mavedb.worker.lib.decorators import with_job_management +from mavedb.worker.lib.decorators.utils import is_test_mode from mavedb.worker.lib.managers import PipelineManager from mavedb.worker.lib.managers.types import JobResultData @@ -70,6 +71,10 @@ async def my_job_function(ctx, param1, param2): @functools.wraps(func) async def async_wrapper(*args, **kwargs): + # No-op in test mode + if is_test_mode(): + return await func(*args, **kwargs) + return await _execute_managed_pipeline(func, args, kwargs) return cast(F, async_wrapper) diff --git a/src/mavedb/worker/lib/decorators/utils.py b/src/mavedb/worker/lib/decorators/utils.py new file mode 100644 index 00000000..373d72b3 --- /dev/null +++ b/src/mavedb/worker/lib/decorators/utils.py @@ -0,0 +1,20 @@ +import os + + +def is_test_mode() -> bool: + """Check if the application is running in test mode based on the MAVEDB_TEST_MODE environment variable. + + Returns: + bool: True if in test mode, False otherwise. + """ + # Although not ideal, we use an environment variable to detect whether + # the application is in test mode. In the context of decorators, test + # mode makes them no-ops to facilitate unit testing without side effects. + # + # This is necessary because decorators are applied at import time, making + # it difficult to mock their behavior in tests when they must be imported + # up front and provided to the ARQ worker. + # + # This pattern allows us to control decorator behavior in tests without + # altering production code paths. + return os.getenv("MAVEDB_TEST_MODE") == "1" diff --git a/tests/conftest.py b/tests/conftest.py index b11f728c..f93804b2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import logging # noqa: F401 +import os import sys from datetime import datetime from unittest import mock @@ -336,3 +337,13 @@ def test_needing_publication_identifier_mock(mock_publication_fetch, ...): mocked_publications.append(publication_to_mock) # Return a single dict (original behavior) if only one was provided; otherwise the list. return mocked_publications[0] if len(mocked_publications) == 1 else mocked_publications + + +# Automatically set MAVEDB_TEST_MODE=1 for unit tests, unset for integration tests. +@pytest.fixture(autouse=True) +def set_mavedb_test_mode_flag(request): + # If 'unit' marker is present, set the flag; otherwise, unset it. + if request.node.get_closest_marker("unit"): + os.environ["MAVEDB_TEST_MODE"] = "1" + else: + os.environ.pop("MAVEDB_TEST_MODE", None) From 3c4e6b982d15daa9340ee6baadcf021075de24a0 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 21 Jan 2026 15:23:52 -0800 Subject: [PATCH 18/70] fix: simplify exc handling in job management decorator Additionally contains some small updates to how decorator unit tests handle the new test mode flag. --- .../worker/lib/decorators/job_guarantee.py | 6 +++ .../worker/lib/decorators/job_management.py | 14 ++--- .../lib/decorators/pipeline_management.py | 44 ++++++++-------- .../lib/decorators/test_job_guarantee.py | 51 ++++++++++--------- .../lib/decorators/test_job_management.py | 35 +++++++------ .../decorators/test_pipeline_management.py | 8 +++ 6 files changed, 90 insertions(+), 68 deletions(-) diff --git a/src/mavedb/worker/lib/decorators/job_guarantee.py b/src/mavedb/worker/lib/decorators/job_guarantee.py index 2f464e47..5dabf8ff 100644 --- a/src/mavedb/worker/lib/decorators/job_guarantee.py +++ b/src/mavedb/worker/lib/decorators/job_guarantee.py @@ -6,6 +6,12 @@ decorators (such as with_job_management) to provide a consistent audit trail and robust error handling for all job entrypoints, including cron-triggered jobs. +NOTE +- This decorator must be applied before any job management decorators. +- This decorator is not supported as part of pipeline management; stacking it + with pipeline management decorators is not allowed and it should only be used with + standalone jobs. + Features: - Persists JobRun with job_type, function name, and parameters - Integrates cleanly with managed job and pipeline decorators diff --git a/src/mavedb/worker/lib/decorators/job_management.py b/src/mavedb/worker/lib/decorators/job_management.py index 86068a40..37120929 100644 --- a/src/mavedb/worker/lib/decorators/job_management.py +++ b/src/mavedb/worker/lib/decorators/job_management.py @@ -167,18 +167,20 @@ async def _execute_managed_job(func: Callable[..., Awaitable[JobResultData]], ar return result except Exception as inner_e: - logger.error(f"Failed to mark job {job_id} as failed: {inner_e}") + logger.critical(f"Failed to mark job {job_id} as failed: {inner_e}") # TODO: Notification hooks # Re-raise the outer exception immediately to prevent duplicate notifications - raise e + finally: + logger.error(f"Job {job_id} failed: {e}") - logger.error(f"Job {job_id} failed: {e}") - - # TODO: Notification hooks + # TODO: Notification hooks - raise # Re-raise the exception + # Swallow the exception after alerting so ARQ can finish the job cleanly and log results. + # We don't mind that we lose ARQs built in job marking, since we perform our own job + # lifecycle management via with_job_management. + return result # Export decorator at module level for easy import diff --git a/src/mavedb/worker/lib/decorators/pipeline_management.py b/src/mavedb/worker/lib/decorators/pipeline_management.py index 0e8944bc..a254e043 100644 --- a/src/mavedb/worker/lib/decorators/pipeline_management.py +++ b/src/mavedb/worker/lib/decorators/pipeline_management.py @@ -159,34 +159,32 @@ async def _execute_managed_pipeline(func: Callable[..., Awaitable[JobResultData] db_session.commit() except Exception as inner_e: - logger.error( + logger.critical( f"Unable to perform cleanup coordination on pipeline {pipeline_id} associated with job {job_id} after error: {inner_e}" ) # No further work here. We can rely on the notification hooks below to alert on the original failure # and should allow result generation to proceed as normal so the job can be logged. - - logger.error(f"Pipeline {pipeline_id} associated with job {job_id} failed to coordinate: {e}") - - # Build job result data for failure - result = { - "status": "failed", - "data": {}, - "exception_details": { - "type": type(e).__name__, - "message": str(e), - "traceback": None, # Could be populated with actual traceback if needed - }, - } - - # TODO: Notification hooks - - # Pipeline coordination represents the outermost operation. Swallow the exception after alerting - # so ARQ can finish the job cleanly and log results. We don't mind that we lose ARQs built in - # job marking, since we perform our own job lifecycle management via with_job_management. - return result - - # Note: No finally block needed - PipelineManager handles cleanup automatically + finally: + logger.error(f"Pipeline {pipeline_id} associated with job {job_id} failed to coordinate: {e}") + + # Build job result data for failure + result = { + "status": "failed", + "data": {}, + "exception_details": { + "type": type(e).__name__, + "message": str(e), + "traceback": None, # Could be populated with actual traceback if needed + }, + } + + # TODO: Notification hooks + + # Swallow the exception after alerting so ARQ can finish the job cleanly and log results. + # We don't mind that we lose ARQs built in job marking, since we perform our own job + # lifecycle management via with_job_management. + return result # Export decorator at module level for easy import diff --git a/tests/worker/lib/decorators/test_job_guarantee.py b/tests/worker/lib/decorators/test_job_guarantee.py index 3da60c87..cfdc40a1 100644 --- a/tests/worker/lib/decorators/test_job_guarantee.py +++ b/tests/worker/lib/decorators/test_job_guarantee.py @@ -4,9 +4,13 @@ Covers JobRun creation, status transitions, error handling, and DB persistence. """ +import pytest + +pytest.importorskip("arq") # Skip tests if arq is not installed + +import os from unittest.mock import MagicMock, patch -import pytest from sqlalchemy import select from mavedb import __version__ @@ -16,14 +20,31 @@ from tests.helpers.transaction_spy import TransactionSpy +# Unset test mode flag before each test to ensure decorator logic is executed +# during unit testing of the decorator itself. +@pytest.fixture(autouse=True) +def unset_test_mode_flag(): + os.environ.pop("MAVEDB_TEST_MODE", None) + + +@with_guaranteed_job_run_record("test_job") +async def sample_job(ctx: dict, job_id: int): + """Sample job function to test the decorator. + + NOTE: The job_id parameter is injected by the decorator + and is not passed explicitly when calling the function. + + Args: + ctx (dict): Worker context dictionary. + job_id (int): ID of the JobRun record created by the decorator. + """ + return {"status": "ok"} + + @pytest.mark.asyncio @pytest.mark.unit class TestJobGuaranteeDecoratorUnit: async def test_decorator_must_receive_ctx_as_first_argument(self, mock_worker_ctx): - @with_guaranteed_job_run_record("test_job") - async def sample_job(not_ctx: dict): - return {"status": "ok"} - with pytest.raises(ValueError) as exc_info: await sample_job() @@ -32,38 +53,24 @@ async def sample_job(not_ctx: dict): async def test_decorator_must_receive_db_in_ctx(self, mock_worker_ctx): del mock_worker_ctx["db"] - @with_guaranteed_job_run_record("test_job") - async def sample_job(not_ctx: dict): - return {"status": "ok"} - with pytest.raises(ValueError) as exc_info: await sample_job(mock_worker_ctx) assert "DB session not found in job context" in str(exc_info.value) async def test_decorator_calls_wrapped_function(self, mock_worker_ctx): - @with_guaranteed_job_run_record("test_job") - async def sample_job(ctx: dict): - return {"status": "ok"} - with patch("mavedb.worker.lib.decorators.job_guarantee.JobRun") as MockJobRunClass: MockJobRunClass.return_value = MagicMock(spec=JobRun) - result = await sample_job(mock_worker_ctx) assert result == {"status": "ok"} async def test_decorator_creates_job_run(self, mock_worker_ctx, mock_job_run): - @with_guaranteed_job_run_record("test_job") - async def sample_job(ctx: dict): - return {"status": "ok"} - with ( TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), patch("mavedb.worker.lib.decorators.job_guarantee.JobRun") as mock_job_run_class, ): mock_job_run_class.return_value = MagicMock(spec=JobRun) - await sample_job(mock_worker_ctx) mock_job_run_class.assert_called_with( @@ -79,10 +86,6 @@ async def sample_job(ctx: dict): @pytest.mark.integration class TestJobGuaranteeDecoratorIntegration: async def test_decorator_persists_job_run_record(self, session, standalone_worker_context): - @with_guaranteed_job_run_record("integration_job") - async def sample_job(ctx: dict): - return {"status": "ok"} - # Flush called implicitly by commit with TransactionSpy.spy(session, expect_flush=True, expect_commit=True): job_task = await sample_job(standalone_worker_context) @@ -91,6 +94,6 @@ async def sample_job(ctx: dict): job_run = session.execute(select(JobRun).order_by(JobRun.id.desc())).scalars().first() assert job_run.status == JobStatus.PENDING - assert job_run.job_type == "integration_job" + assert job_run.job_type == "test_job" assert job_run.job_function == "sample_job" assert job_run.mavedb_version is not None diff --git a/tests/worker/lib/decorators/test_job_management.py b/tests/worker/lib/decorators/test_job_management.py index 2f689cbe..6a60199b 100644 --- a/tests/worker/lib/decorators/test_job_management.py +++ b/tests/worker/lib/decorators/test_job_management.py @@ -10,6 +10,7 @@ pytest.importorskip("arq") # Skip tests if arq is not installed import asyncio +import os from unittest.mock import patch from sqlalchemy import select @@ -23,6 +24,13 @@ from tests.helpers.transaction_spy import TransactionSpy +# Unset test mode flag before each test to ensure decorator logic is executed +# during unit testing of the decorator itself. +@pytest.fixture(autouse=True) +def unset_test_mode_flag(): + os.environ.pop("MAVEDB_TEST_MODE", None) + + @pytest.mark.asyncio @pytest.mark.unit class TestManagedJobDecoratorUnit: @@ -79,7 +87,6 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): raise RuntimeError("error in wrapped function") with ( - pytest.raises(RuntimeError), patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, patch.object(mock_job_manager, "start_job", return_value=None) as mock_start_job, patch.object(mock_job_manager, "should_retry", return_value=False), @@ -128,7 +135,7 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): assert missing_key.replace("_", " ") in str(exc_info.value).lower() assert "not found in job context" in str(exc_info.value).lower() - async def test_decorator_propagates_exception_from_lifecycle_state_outside_except( + async def test_decorator_swallows_exception_from_lifecycle_state_outside_except( self, mock_job_manager, mock_worker_ctx ): @with_job_management @@ -136,17 +143,16 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): return {"status": "ok"} with ( - pytest.raises(JobStateError) as exc_info, patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, patch.object(mock_job_manager, "start_job", side_effect=JobStateError("error in job start")), patch.object(mock_job_manager, "should_retry", return_value=False), patch.object(mock_job_manager, "fail_job", return_value=None), - TransactionSpy.spy(mock_worker_ctx["db"], expect_rollback=True), + TransactionSpy.spy(mock_worker_ctx["db"], expect_rollback=True, expect_commit=True), ): mock_job_manager_class.return_value = mock_job_manager - await sample_job(mock_worker_ctx, 999, job_manager=mock_job_manager) + result = await sample_job(mock_worker_ctx, 999, job_manager=mock_job_manager) - assert "error in job start" in str(exc_info.value) + assert "error in job start" in result["exception_details"]["message"] async def test_decorator_raises_value_error_if_job_id_missing(self, mock_job_manager, mock_worker_ctx): @with_job_management @@ -159,7 +165,7 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): assert "job id not found in pipeline context" in str(exc_info.value).lower() - async def test_decorator_propagates_exception_from_wrapped_function_inside_except( + async def test_decorator_swallows_exception_from_wrapped_function_inside_except( self, mock_job_manager, mock_worker_ctx ): @with_job_management @@ -167,18 +173,17 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): raise RuntimeError("error in wrapped function") with ( - pytest.raises(RuntimeError) as exc_info, patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, patch.object(mock_job_manager, "start_job", return_value=None), patch.object(mock_job_manager, "should_retry", return_value=False), patch.object(mock_job_manager, "fail_job", side_effect=JobStateError("error in job fail")), - TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=False, expect_rollback=True), + TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True, expect_rollback=True), ): mock_job_manager_class.return_value = mock_job_manager - await sample_job(mock_worker_ctx, 999, job_manager=mock_job_manager) + result = await sample_job(mock_worker_ctx, 999, job_manager=mock_job_manager) # Errors within the main try block should take precedence - assert "error in wrapped function" in str(exc_info.value) + assert "error in wrapped function" in result["exception_details"]["message"] async def test_decorator_passes_job_manager_to_wrapped(self, mock_job_manager, mock_worker_ctx): @with_job_management @@ -248,14 +253,14 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): assert job.status == JobStatus.RUNNING # Now allow the job to complete with failure. This failure - # should be propagated out of the job_task. - with pytest.raises(RuntimeError): - event.set() - await job_task + # should be swallowed by the job_task. + event.set() + await job_task # After failure, status should be FAILED job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() assert job.status == JobStatus.FAILED + assert job.error_message == "Simulated job failure" async def test_decorator_integrated_job_lifecycle_retry( self, session, arq_redis, sample_job_run, standalone_worker_context, setup_worker_db diff --git a/tests/worker/lib/decorators/test_pipeline_management.py b/tests/worker/lib/decorators/test_pipeline_management.py index eb843aac..738d2ca3 100644 --- a/tests/worker/lib/decorators/test_pipeline_management.py +++ b/tests/worker/lib/decorators/test_pipeline_management.py @@ -10,6 +10,7 @@ pytest.importorskip("arq") # Skip tests if arq is not installed import asyncio +import os from unittest.mock import MagicMock, patch from sqlalchemy import select @@ -23,6 +24,13 @@ from tests.helpers.transaction_spy import TransactionSpy +# Unset test mode flag before each test to ensure decorator logic is executed +# during unit testing of the decorator itself. +@pytest.fixture(autouse=True) +def unset_test_mode_flag(): + os.environ.pop("MAVEDB_TEST_MODE", None) + + @pytest.mark.asyncio @pytest.mark.unit class TestPipelineManagementDecoratorUnit: From 9dd71ff63b0dc28fa1441e7be1559fe521a1c37a Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 21 Jan 2026 21:44:46 -0800 Subject: [PATCH 19/70] feat: allow pipelines to be started by decorated jobs --- .../lib/decorators/pipeline_management.py | 10 +- .../decorators/test_pipeline_management.py | 105 ++++++++++++++---- 2 files changed, 94 insertions(+), 21 deletions(-) diff --git a/src/mavedb/worker/lib/decorators/pipeline_management.py b/src/mavedb/worker/lib/decorators/pipeline_management.py index a254e043..3bede53f 100644 --- a/src/mavedb/worker/lib/decorators/pipeline_management.py +++ b/src/mavedb/worker/lib/decorators/pipeline_management.py @@ -14,6 +14,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session +from mavedb.models.enums.job_pipeline import PipelineStatus from mavedb.models.job_run import JobRun from mavedb.worker.lib.decorators import with_job_management from mavedb.worker.lib.decorators.utils import is_test_mode @@ -125,7 +126,14 @@ async def _execute_managed_pipeline(func: Callable[..., Awaitable[JobResultData] if pipeline_id: pipeline_manager = PipelineManager(db=db_session, redis=redis_pool, pipeline_id=pipeline_id) - logger.info(f"Pipeline ID for job {job_id} is {pipeline_id}. Coordinating pipeline after job execution.") + logger.info(f"Pipeline ID for job {job_id} is {pipeline_id}. Coordinating pipeline.") + + # If the pipeline is still in the created state, start it now + if pipeline_manager and pipeline_manager.get_pipeline_status() == PipelineStatus.CREATED: + await pipeline_manager.start_pipeline() + db_session.commit() + + logger.info(f"Pipeline {pipeline_id} associated with job {job_id} started successfully") # Wrap the function with job management, then execute. This ensures both: # - Job lifecycle management is nested within pipeline management diff --git a/tests/worker/lib/decorators/test_pipeline_management.py b/tests/worker/lib/decorators/test_pipeline_management.py index 738d2ca3..33e33713 100644 --- a/tests/worker/lib/decorators/test_pipeline_management.py +++ b/tests/worker/lib/decorators/test_pipeline_management.py @@ -98,6 +98,7 @@ async def test_decorator_fetches_pipeline_from_db_and_constructs_pipeline_manage mock_worker_ctx["db"], "execute", return_value=MagicMock(scalar_one=MagicMock(return_value=123)) ) as mock_execute, patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None), + patch.object(mock_pipeline_manager, "start_pipeline", return_value=None), TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), ): mock_pipeline_manager_class.return_value = mock_pipeline_manager @@ -112,7 +113,9 @@ async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): mock_execute.assert_called_once() assert result == {"status": "ok"} - async def test_decorator_skips_coordination_when_no_pipeline_exists(self, mock_pipeline_manager, mock_worker_ctx): + async def test_decorator_skips_coordination_and_start_when_no_pipeline_exists( + self, mock_pipeline_manager, mock_worker_ctx + ): with ( # patch the with_job_management decorator to be a no-op patch("mavedb.worker.lib.decorators.pipeline_management.with_job_management", wraps=lambda f: f), @@ -121,6 +124,7 @@ async def test_decorator_skips_coordination_when_no_pipeline_exists(self, mock_p mock_worker_ctx["db"], "execute", return_value=MagicMock(scalar_one=MagicMock(return_value=None)) ) as mock_execute, patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate_pipeline, + patch.object(mock_pipeline_manager, "start_pipeline", return_value=None) as mock_start_pipeline, # We shouldn't expect any commits since no pipeline coordination occurs TransactionSpy.spy(mock_worker_ctx["db"]), ): @@ -134,6 +138,65 @@ async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): mock_execute.assert_called_once() mock_coordinate_pipeline.assert_not_called() + mock_start_pipeline.assert_not_called() + assert result == {"status": "ok"} + + async def test_decorator_starts_pipeline_when_in_created_state( + self, mock_pipeline_manager, mock_worker_ctx, mock_pipeline + ): + with ( + # patch the with_job_management decorator to be a no-op + patch("mavedb.worker.lib.decorators.pipeline_management.with_job_management", wraps=lambda f: f), + patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, + patch.object( + mock_worker_ctx["db"], "execute", return_value=MagicMock(scalar_one=MagicMock(return_value=123)) + ) as mock_execute, + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED), + patch.object(mock_pipeline_manager, "start_pipeline", return_value=None) as mock_start_pipeline, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None), + TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), + ): + mock_pipeline_manager_class.return_value = mock_pipeline_manager + + @with_pipeline_management + async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): + return {"status": "ok"} + + result = await sample_job(mock_worker_ctx, 999, pipeline_manager=mock_pipeline_manager) + + mock_execute.assert_called_once() + mock_start_pipeline.assert_called_once() + assert result == {"status": "ok"} + + @pytest.mark.parametrize( + "pipeline_state", + [status for status in PipelineStatus._member_map_.values() if status != PipelineStatus.CREATED], + ) + async def test_decorator_does_not_start_pipeline_when_in_not_in_created_state( + self, mock_pipeline_manager, mock_worker_ctx, mock_pipeline, pipeline_state + ): + with ( + # patch the with_job_management decorator to be a no-op + patch("mavedb.worker.lib.decorators.pipeline_management.with_job_management", wraps=lambda f: f), + patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, + patch.object( + mock_worker_ctx["db"], "execute", return_value=MagicMock(scalar_one=MagicMock(return_value=123)) + ) as mock_execute, + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=pipeline_state), + patch.object(mock_pipeline_manager, "start_pipeline", return_value=None) as mock_start_pipeline, + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None), + TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), + ): + mock_pipeline_manager_class.return_value = mock_pipeline_manager + + @with_pipeline_management + async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): + return {"status": "ok"} + + result = await sample_job(mock_worker_ctx, 999, pipeline_manager=mock_pipeline_manager) + + mock_execute.assert_called_once() + mock_start_pipeline.assert_not_called() assert result == {"status": "ok"} async def test_decorator_calls_wrapped_function_and_returns_result( @@ -148,7 +211,8 @@ async def test_decorator_calls_wrapped_function_and_returns_result( patch.object( mock_worker_ctx["db"], "execute", return_value=MagicMock(scalar_one=MagicMock(return_value=123)) ), - patch.object(mock_pipeline_manager, "get_pipeline", return_value=mock_pipeline), + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED), + patch.object(mock_pipeline_manager, "start_pipeline", return_value=None), patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None), TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), ): @@ -176,8 +240,9 @@ async def test_decorator_calls_pipeline_manager_coordinate_pipeline_after_wrappe patch.object( mock_worker_ctx["db"], "execute", return_value=MagicMock(scalar_one=MagicMock(return_value=123)) ), - patch.object(mock_pipeline_manager, "get_pipeline", return_value=mock_pipeline), + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED), patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate_pipeline, + patch.object(mock_pipeline_manager, "start_pipeline", return_value=None), TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), ): mock_pipeline_manager_class.return_value = mock_pipeline_manager @@ -199,6 +264,8 @@ async def test_decorator_swallows_exception_from_wrapped_function(self, mock_pip ), patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None), + patch.object(mock_pipeline_manager, "start_pipeline", return_value=None), + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED), TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True, expect_rollback=True), ): mock_pipeline_manager_class.return_value = mock_pipeline_manager @@ -226,8 +293,11 @@ async def test_decorator_swallows_exception_from_pipeline_manager_coordinate_pip "coordinate_pipeline", side_effect=RuntimeError("error in coordinate_pipeline"), ), - # Exception raised from coordinate_pipeline should trigger rollback but prevent commit - TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=False, expect_rollback=True), + patch.object(mock_pipeline_manager, "start_pipeline", return_value=None), + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED), + # Exception raised from coordinate_pipeline should trigger rollback, + # and commit will be called when pipeline status is set to running + TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True, expect_rollback=True), ): mock_pipeline_manager_class.return_value = mock_pipeline_manager @@ -252,8 +322,10 @@ def passthrough_decorator(f): wraps=passthrough_decorator, side_effect=ValueError("error in job management decorator"), ) as mock_with_job_mgmt, + patch.object(mock_pipeline_manager, "start_pipeline", return_value=None), + patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED), patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, - TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=False, expect_rollback=True), + TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True, expect_rollback=True), ): mock_pipeline_manager_class.return_value = mock_pipeline_manager @@ -272,6 +344,7 @@ async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): class TestPipelineManagementDecoratorIntegration: """Integration tests for the with_pipeline_management decorator.""" + @pytest.mark.parametrize("initial_status", [PipelineStatus.CREATED, PipelineStatus.RUNNING]) async def test_decorator_integrated_pipeline_lifecycle_success( self, session, @@ -281,14 +354,15 @@ async def test_decorator_integrated_pipeline_lifecycle_success( standalone_worker_context, setup_worker_db, sample_pipeline, + initial_status, ): # Use an event to control when the job completes event = asyncio.Event() dep_event = asyncio.Event() - # Transition pipeline to RUNNING to allow job execution. This step of pipeline management - # is intentionally not handled by the decorator. - sample_pipeline.status = PipelineStatus.RUNNING + # Set initial pipeline status to the parameterized value. + # This allows testing both CREATED and RUNNING start states. + sample_pipeline.status = initial_status session.commit() @with_pipeline_management @@ -377,11 +451,6 @@ async def test_decorator_integrated_pipeline_lifecycle_retryable_failure( retry_event = asyncio.Event() dep_event = asyncio.Event() - # Transition pipeline to RUNNING to allow job execution. This step of pipeline management - # is intentionally not handled by the decorator. - sample_pipeline.status = PipelineStatus.RUNNING - session.commit() - @with_pipeline_management async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): await event.wait() # Simulate async work, block until test signals @@ -490,11 +559,6 @@ async def test_decorator_integrated_pipeline_lifecycle_non_retryable_failure( # Use an event to control when the job completes event = asyncio.Event() - # Transition pipeline to RUNNING to allow job execution. This step of pipeline management - # is intentionally not handled by the decorator. - sample_pipeline.status = PipelineStatus.RUNNING - session.commit() - @with_pipeline_management async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): await event.wait() # Simulate async work, block until test signals @@ -511,8 +575,9 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() assert pipeline.status == PipelineStatus.RUNNING - # Now allow the job to complete with failure. This failure + # Now allow the job to complete with failure and flush the Redis queue. This failure # should be swallowed by the pipeline manager + await arq_redis.flushdb() event.set() await job_task From 1fe076aba2e8e3bab4dbc49469afddc0321f034c Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 21 Jan 2026 22:54:29 -0800 Subject: [PATCH 20/70] tests: unit tests for worker manager utilities --- tests/worker/lib/managers/test_utils.py | 90 +++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 tests/worker/lib/managers/test_utils.py diff --git a/tests/worker/lib/managers/test_utils.py b/tests/worker/lib/managers/test_utils.py new file mode 100644 index 00000000..a33285b4 --- /dev/null +++ b/tests/worker/lib/managers/test_utils.py @@ -0,0 +1,90 @@ +import pytest + +from mavedb.models.enums.job_pipeline import DependencyType, JobStatus +from mavedb.worker.lib.managers.constants import COMPLETED_JOB_STATUSES +from mavedb.worker.lib.managers.utils import ( + construct_bulk_cancellation_result, + job_dependency_is_met, + job_should_be_skipped_due_to_unfulfillable_dependency, +) + + +@pytest.mark.unit +class TestConstructBulkCancellationResultUnit: + def test_construct_bulk_cancellation_result(self): + reason = "Test cancellation reason" + result = construct_bulk_cancellation_result(reason) + + assert result["status"] == "cancelled" + assert result["data"]["reason"] == reason + assert "timestamp" in result["data"] + assert result["exception_details"] is None + + +@pytest.mark.unit +class TestJobDependencyIsMetUnit: + @pytest.mark.parametrize( + "dependency_type, dependent_job_status, expected", + [ + (None, "any_status", True), + # success required dependencies-- should only be met if dependent job succeeded + (DependencyType.SUCCESS_REQUIRED, JobStatus.SUCCEEDED, True), + *[ + (DependencyType.SUCCESS_REQUIRED, dependent_job_status, False) + for dependent_job_status in JobStatus._member_map_.values() + if dependent_job_status != JobStatus.SUCCEEDED + ], + # completion required dependencies-- should be met if dependent job is in any terminal state + *[ + ( + DependencyType.COMPLETION_REQUIRED, + dependent_job_status, + dependent_job_status in COMPLETED_JOB_STATUSES, + ) + for dependent_job_status in JobStatus._member_map_.values() + ], + ], + ) + def test_job_dependency_is_met(self, dependency_type, dependent_job_status, expected): + result = job_dependency_is_met(dependency_type, dependent_job_status) + assert result == expected + + +@pytest.mark.unit +class TestJobShouldBeSkippedDueToUnfulfillableDependencyUnit: + @pytest.mark.parametrize( + "dependency_type, dependent_job_status, expected", + [ + # No dependency-- should not be skipped + (None, "any_status", False), + # success required dependencies-- should be skipped if dependent job in terminal non-success state + (DependencyType.SUCCESS_REQUIRED, JobStatus.SUCCEEDED, False), + *[ + ( + DependencyType.SUCCESS_REQUIRED, + dependent_job_status, + dependent_job_status in (JobStatus.FAILED, JobStatus.SKIPPED, JobStatus.CANCELLED), + ) + for dependent_job_status in JobStatus._member_map_.values() + ], + # completion required dependencies-- should be skipped if dependent job is not in a terminal state + *[ + ( + DependencyType.COMPLETION_REQUIRED, + dependent_job_status, + dependent_job_status in (JobStatus.CANCELLED, JobStatus.SKIPPED), + ) + for dependent_job_status in JobStatus._member_map_.values() + ], + ], + ) + def test_job_should_be_skipped_due_to_unfulfillable_dependency( + self, dependency_type, dependent_job_status, expected + ): + result = job_should_be_skipped_due_to_unfulfillable_dependency(dependency_type, dependent_job_status) + + if expected: + assert result[0] is True + assert isinstance(result[1], str) + else: + assert result == (False, None) From 16a5a509cd6cce35f5677fdd7adf81f9a9a1142e Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Thu, 22 Jan 2026 10:28:32 -0800 Subject: [PATCH 21/70] feat: add network test marker and control socket access in pytest --- pyproject.toml | 3 ++- tests/conftest.py | 27 ++++++++++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f9538bff..f1217384 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,7 +100,7 @@ plugins = [ mypy_path = "mypy_stubs" [tool.pytest.ini_options] -addopts = "-v -rP --import-mode=importlib --disable-socket --allow-unix-socket --allow-hosts localhost,::1,127.0.0.1" +addopts = "-v -rP --import-mode=importlib" asyncio_mode = 'strict' testpaths = "tests/" pythonpath = "." @@ -108,6 +108,7 @@ norecursedirs = "tests/helpers/" markers = """ integration: mark a test as an integration test. unit: mark a test as a unit test. + network: mark a test that requires network access. slow: mark a test as slow-running. """ # Uncomment the following lines to include application log output in Pytest logs. diff --git a/tests/conftest.py b/tests/conftest.py index f93804b2..0cb869fd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,8 @@ import email_validator import pytest import pytest_postgresql -from sqlalchemy import create_engine +import pytest_socket +from sqlalchemy import create_engine, text from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import NullPool @@ -58,6 +59,21 @@ email_validator.TEST_ENVIRONMENT = True +def pytest_runtest_setup(item): + # Only block sockets for tests not marked with 'network' + if "network" not in item.keywords: + try: + pytest_socket.socket_allow_hosts(["localhost", "127.0.0.1", "::1"], allow_unix_socket=True) + except ImportError: + pass + + else: + try: + pytest_socket.enable_socket() + except ImportError: + pass + + @pytest.fixture() def session(postgresql): # Un-comment this line to log all database queries: @@ -73,6 +89,15 @@ def session(postgresql): Base.metadata.create_all(bind=engine) + # Create a unique index for the published_variants_materialized_view to + # enforce uniqueness on (variant_id, mapped_variant_id, score_set_id). This + # allows us to test mat view refreshes that require this constraint. + session.execute( + text("""CREATE UNIQUE INDEX IF NOT EXISTS published_variants_mv_unique_idx + ON published_variants_materialized_view (variant_id, mapped_variant_id, score_set_id)"""), + ) + session.commit() + try: yield session finally: From a701d53b899edfdb90ab50e32bc33971d4887fc6 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Thu, 22 Jan 2026 13:40:33 -0800 Subject: [PATCH 22/70] Refactor test setup by replacing `setup_worker_db` with `with_populated_job_data` - Updated test files to use `with_populated_job_data` fixture for populating the database with sample job and pipeline data. - Removed the `setup_worker_db` fixture from various test cases in job and pipeline management tests. - Added new sample job and pipeline fixtures in `conftest.py` to streamline test data creation. - Improved clarity and maintainability of tests by consolidating data setup logic. --- tests/worker/conftest.py | 173 +++++++++++++++++- .../lib/decorators/test_job_management.py | 6 +- .../decorators/test_pipeline_management.py | 6 +- tests/worker/lib/managers/test_job_manager.py | 84 +++++---- .../lib/managers/test_pipeline_manager.py | 142 +++++++------- 5 files changed, 289 insertions(+), 122 deletions(-) diff --git a/tests/worker/conftest.py b/tests/worker/conftest.py index cf996c1d..eef66d03 100644 --- a/tests/worker/conftest.py +++ b/tests/worker/conftest.py @@ -1,3 +1,7 @@ +""" +Test configuration and fixtures for worker lib tests. +""" + from datetime import datetime from pathlib import Path from shutil import copytree @@ -5,7 +9,8 @@ import pytest -from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus +from mavedb.models.enums.job_pipeline import DependencyType, JobStatus, PipelineStatus +from mavedb.models.job_dependency import JobDependency from mavedb.models.job_run import JobRun from mavedb.models.license import License from mavedb.models.pipeline import Pipeline @@ -15,14 +20,111 @@ EXTRA_USER, TEST_INACTIVE_LICENSE, TEST_LICENSE, - TEST_MAVEDB_ATHENA_ROW, TEST_SAVED_TAXONOMY, TEST_USER, ) +# Attempt to import optional top level fixtures. If the modules they depend on are not installed, +# we won't have access to our full fixture suite and only a limited subset of tests can be run. +try: + from .conftest_optional import * # noqa: F401, F403 + +except ModuleNotFoundError: + pass + + +@pytest.fixture +def sample_job_run(): + """Create a sample JobRun instance for testing.""" + return JobRun( + id=1, + urn="test:job:1", + job_type="test_job", + job_function="test_function", + status=JobStatus.PENDING, + pipeline_id=1, + progress_current=0, + progress_total=100, + progress_message="Ready to start", + created_at=datetime.now(), + ) + + +@pytest.fixture +def sample_dependent_job_run(): + """Create a sample dependent JobRun instance for testing.""" + return JobRun( + id=2, + urn="test:job:2", + job_type="dependent_job", + job_function="dependent_function", + status=JobStatus.PENDING, + pipeline_id=1, + progress_current=0, + progress_total=100, + progress_message="Waiting for dependency", + created_at=datetime.now(), + ) + + +@pytest.fixture +def sample_independent_job_run(): + """Create a sample independent JobRun instance for testing.""" + return JobRun( + id=3, + urn="test:job:3", + job_type="independent_job", + job_function="independent_function", + status=JobStatus.PENDING, + pipeline_id=None, + progress_current=0, + progress_total=100, + progress_message="Ready to start", + created_at=datetime.now(), + ) + + +@pytest.fixture +def sample_pipeline(): + """Create a sample Pipeline instance for testing.""" + return Pipeline( + id=1, + urn="test:pipeline:1", + name="Test Pipeline", + description="A test pipeline", + status=PipelineStatus.CREATED, + correlation_id="test_correlation_123", + created_at=datetime.now(), + ) + + +@pytest.fixture +def sample_empty_pipeline(): + """Create a sample Pipeline instance with no jobs for testing.""" + return Pipeline( + id=999, + urn="test:pipeline:999", + name="Empty Pipeline", + description="A pipeline with no jobs", + status=PipelineStatus.CREATED, + correlation_id="empty_correlation_456", + created_at=datetime.now(), + ) + + +@pytest.fixture +def sample_job_dependency(): + """Create a sample JobDependency instance for testing.""" + return JobDependency( + id=2, # dependent job + depends_on_job_id=1, # depends on job 1 + dependency_type=DependencyType.SUCCESS_REQUIRED, + created_at=datetime.now(), + ) + @pytest.fixture -def setup_worker_db(session): +def with_populated_domain_data(session): db = session db.add(User(**TEST_USER)) db.add(User(**EXTRA_USER)) @@ -116,10 +218,65 @@ def data_files(tmp_path): @pytest.fixture -def mocked_gnomad_variant_row(): - gnomad_variant = Mock() +def mock_pipeline(): + """Create a mock Pipeline instance. By default, + properties are identical to a default new Pipeline entered into the db + with sensible defaults for non-nullable but unset fields. + """ + return Mock( + spec=Pipeline, + id=1, + urn="test:pipeline:1", + name="Test Pipeline", + description="A test pipeline", + status=PipelineStatus.CREATED, + correlation_id="test_correlation_123", + metadata_={}, + created_at=datetime.now(), + started_at=None, + finished_at=None, + created_by_user_id=None, + mavedb_version=None, + ) + + +@pytest.fixture +def mock_job_run(mock_pipeline): + """Create a mock JobRun instance. By default, + properties are identical to a default new JobRun entered into the db + with sensible defaults for non-nullable but unset fields. + """ + return Mock( + spec=JobRun, + id=123, + urn="test:job:123", + job_type="test_job", + job_function="test_function", + status=JobStatus.PENDING, + pipeline_id=mock_pipeline.id, + priority=0, + max_retries=3, + retry_count=0, + retry_delay_seconds=None, + scheduled_at=datetime.now(), + started_at=None, + finished_at=None, + created_at=datetime.now(), + error_message=None, + error_traceback=None, + failure_category=None, + worker_id=None, + worker_host=None, + progress_current=None, + progress_total=None, + progress_message=None, + correlation_id=None, + metadata_={}, + mavedb_version=None, + ) - for key, value in TEST_MAVEDB_ATHENA_ROW.items(): - setattr(gnomad_variant, key, value) - return gnomad_variant +@pytest.fixture +def data_files(tmp_path): + copytree(Path(__file__).absolute().parent / "data", tmp_path / "data") + return tmp_path / "data" diff --git a/tests/worker/lib/decorators/test_job_management.py b/tests/worker/lib/decorators/test_job_management.py index 6a60199b..d22a37ee 100644 --- a/tests/worker/lib/decorators/test_job_management.py +++ b/tests/worker/lib/decorators/test_job_management.py @@ -207,7 +207,7 @@ class TestManagedJobDecoratorIntegration: """Integration tests for with_job_management decorator.""" async def test_decorator_integrated_job_lifecycle_success( - self, session, arq_redis, sample_job_run, standalone_worker_context, setup_worker_db + self, session, arq_redis, sample_job_run, standalone_worker_context, with_populated_job_data ): # Use an event to control when the job completes event = asyncio.Event() @@ -234,7 +234,7 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): assert job.status == JobStatus.SUCCEEDED async def test_decorator_integrated_job_lifecycle_failure( - self, session, arq_redis, sample_job_run, standalone_worker_context, setup_worker_db + self, session, arq_redis, sample_job_run, standalone_worker_context, with_populated_job_data ): # Use an event to control when the job completes event = asyncio.Event() @@ -263,7 +263,7 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): assert job.error_message == "Simulated job failure" async def test_decorator_integrated_job_lifecycle_retry( - self, session, arq_redis, sample_job_run, standalone_worker_context, setup_worker_db + self, session, arq_redis, sample_job_run, standalone_worker_context, with_populated_job_data ): # Use an event to control when the job completes event = asyncio.Event() diff --git a/tests/worker/lib/decorators/test_pipeline_management.py b/tests/worker/lib/decorators/test_pipeline_management.py index 33e33713..f7b2bc1e 100644 --- a/tests/worker/lib/decorators/test_pipeline_management.py +++ b/tests/worker/lib/decorators/test_pipeline_management.py @@ -352,7 +352,7 @@ async def test_decorator_integrated_pipeline_lifecycle_success( sample_job_run, sample_dependent_job_run, standalone_worker_context, - setup_worker_db, + with_populated_job_data, sample_pipeline, initial_status, ): @@ -443,7 +443,7 @@ async def test_decorator_integrated_pipeline_lifecycle_retryable_failure( sample_job_run, sample_dependent_job_run, standalone_worker_context, - setup_worker_db, + with_populated_job_data, sample_pipeline, ): # Use an event to control when the job completes @@ -553,7 +553,7 @@ async def test_decorator_integrated_pipeline_lifecycle_non_retryable_failure( sample_job_run, sample_dependent_job_run, standalone_worker_context, - setup_worker_db, + with_populated_job_data, sample_pipeline, ): # Use an event to control when the job completes diff --git a/tests/worker/lib/managers/test_job_manager.py b/tests/worker/lib/managers/test_job_manager.py index ca54c18e..3806ac68 100644 --- a/tests/worker/lib/managers/test_job_manager.py +++ b/tests/worker/lib/managers/test_job_manager.py @@ -46,7 +46,7 @@ class TestJobManagerInitialization: """Test JobManager initialization and setup.""" - def test_init_with_valid_job(self, session, arq_redis, setup_worker_db, sample_job_run): + def test_init_with_valid_job(self, session, arq_redis, with_populated_job_data, sample_job_run): """Test successful initialization with valid job ID.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -54,7 +54,7 @@ def test_init_with_valid_job(self, session, arq_redis, setup_worker_db, sample_j assert manager.job_id == sample_job_run.id assert manager.pipeline_id == sample_job_run.pipeline_id - def test_init_with_no_pipeline(self, session, arq_redis, setup_worker_db, sample_independent_job_run): + def test_init_with_no_pipeline(self, session, arq_redis, with_populated_job_data, sample_independent_job_run): """Test initialization with job that has no pipeline.""" manager = JobManager(session, arq_redis, sample_independent_job_run.id) @@ -164,7 +164,7 @@ class TestJobStartIntegration: [status for status in JobStatus._member_map_.values() if status not in STARTABLE_JOB_STATUSES], ) def test_job_exception_is_raised_when_job_has_invalid_status( - self, session, arq_redis, setup_worker_db, sample_job_run, invalid_status + self, session, arq_redis, with_populated_job_data, sample_job_run, invalid_status ): """Test job start failure due to invalid job status.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -191,7 +191,7 @@ def test_job_exception_is_raised_when_job_has_invalid_status( "valid_status", [status for status in JobStatus._member_map_.values() if status in STARTABLE_JOB_STATUSES], ) - def test_job_updated_successfully(self, session, arq_redis, setup_worker_db, sample_job_run, valid_status): + def test_job_updated_successfully(self, session, arq_redis, with_populated_job_data, sample_job_run, valid_status): """Test successful job start.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -351,7 +351,7 @@ class TestJobCompletionIntegration: [status for status in JobStatus._member_map_.values() if status not in TERMINAL_JOB_STATUSES], ) def test_job_exception_is_raised_when_job_has_invalid_status( - self, session, arq_redis, setup_worker_db, sample_job_run, invalid_status + self, session, arq_redis, with_populated_job_data, sample_job_run, invalid_status ): """Test job completion failure due to invalid job status.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -376,7 +376,7 @@ def test_job_exception_is_raised_when_job_has_invalid_status( [status for status in JobStatus._member_map_.values() if status in TERMINAL_JOB_STATUSES], ) def test_job_updated_successfully_without_error( - self, session, arq_redis, setup_worker_db, sample_job_run, valid_status + self, session, arq_redis, with_populated_job_data, sample_job_run, valid_status ): """Test successful job completion.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -409,7 +409,7 @@ def test_job_updated_successfully_without_error( [status for status in JobStatus._member_map_.values() if status in TERMINAL_JOB_STATUSES], ) def test_job_updated_successfully_with_error( - self, session, arq_redis, setup_worker_db, sample_job_run, valid_status + self, session, arq_redis, with_populated_job_data, sample_job_run, valid_status ): """Test successful job completion.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -466,7 +466,7 @@ def test_fail_job_success(self, mock_job_manager, mock_job_run): class TestJobFailureIntegration: """Test job failure lifecycle management.""" - def test_job_updated_successfully(self, session, arq_redis, setup_worker_db, sample_job_run): + def test_job_updated_successfully(self, session, arq_redis, with_populated_job_data, sample_job_run): """Test successful job failure.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -519,7 +519,7 @@ def test_succeed_job_success(self, mock_job_manager, mock_job_run): class TestJobSuccessIntegration: """Test job success lifecycle management.""" - def test_job_updated_successfully(self, session, arq_redis, setup_worker_db, sample_job_run): + def test_job_updated_successfully(self, session, arq_redis, with_populated_job_data, sample_job_run): """Test successful job succeeding.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -572,7 +572,7 @@ def test_cancel_job_success(self, mock_job_manager, mock_job_run): class TestJobCancellationIntegration: """Test job cancellation lifecycle management.""" - def test_job_updated_successfully(self, session, arq_redis, setup_worker_db, sample_job_run): + def test_job_updated_successfully(self, session, arq_redis, with_populated_job_data, sample_job_run): """Test successful job cancellation.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -626,7 +626,7 @@ def test_skip_job_success(self, mock_job_manager, mock_job_run): class TestJobSkipIntegration: """Test job skip lifecycle management.""" - def test_job_updated_successfully(self, session, arq_redis, setup_worker_db, sample_job_run): + def test_job_updated_successfully(self, session, arq_redis, with_populated_job_data, sample_job_run): """Test successful job skipping.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -768,7 +768,7 @@ class TestPrepareRetryIntegration: [status for status in JobStatus._member_map_.values() if status not in RETRYABLE_JOB_STATUSES], ) def test_prepare_retry_failed_due_to_invalid_status( - self, session, arq_redis, setup_worker_db, sample_job_run, job_status + self, session, arq_redis, with_populated_job_data, sample_job_run, job_status ): """Test job retry failure due to invalid job status.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -786,7 +786,7 @@ def test_prepare_retry_failed_due_to_invalid_status( ): manager.prepare_retry() - def test_prepare_retry_success(self, session, arq_redis, setup_worker_db, sample_job_run): + def test_prepare_retry_success(self, session, arq_redis, with_populated_job_data, sample_job_run): """Test successful job retry.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -908,7 +908,7 @@ class TestPrepareQueue: [status for status in JobStatus._member_map_.values() if status != JobStatus.PENDING], ) def test_prepare_queue_failed_due_to_invalid_status( - self, session, arq_redis, setup_worker_db, sample_job_run, job_status + self, session, arq_redis, with_populated_job_data, sample_job_run, job_status ): """Test job prepare for queue failure due to invalid job status.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -929,7 +929,7 @@ def test_prepare_queue_failed_due_to_invalid_status( ): manager.prepare_queue() - def test_prepare_queue_success(self, session, arq_redis, setup_worker_db, sample_job_run): + def test_prepare_queue_success(self, session, arq_redis, with_populated_job_data, sample_job_run): """Test successful job prepare for queue.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -1028,7 +1028,7 @@ def test_reset_job_success(self, mock_job_manager, mock_job_run): class TestResetJobIntegration: """Test job reset lifecycle management.""" - def test_reset_job_success(self, session, arq_redis, setup_worker_db, sample_job_run): + def test_reset_job_success(self, session, arq_redis, with_populated_job_data, sample_job_run): """Test successful job reset.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -1141,7 +1141,7 @@ def test_update_progress_does_not_overwrite_old_message_when_no_new_message_is_p class TestJobProgressUpdateIntegration: """Test job progress update lifecycle management.""" - def test_update_progress_success(self, session, arq_redis, setup_worker_db, sample_job_run): + def test_update_progress_success(self, session, arq_redis, with_populated_job_data, sample_job_run): """Test successful progress update.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -1166,7 +1166,7 @@ def test_update_progress_success(self, session, arq_redis, setup_worker_db, samp assert job.progress_message == "Halfway done" def test_update_progress_success_does_not_overwrite_old_message_when_no_new_message_is_provided( - self, session, arq_redis, setup_worker_db, sample_job_run + self, session, arq_redis, with_populated_job_data, sample_job_run ): """Test successful progress update without message.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -1243,7 +1243,7 @@ def test_update_status_message_success(self, mock_job_manager, mock_job_run): class TestJobProgressStatusUpdate: """Test job progress status update lifecycle management.""" - def test_update_status_message_success(self, session, arq_redis, setup_worker_db, sample_job_run): + def test_update_status_message_success(self, session, arq_redis, with_populated_job_data, sample_job_run): """Test successful status message update.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -1338,7 +1338,7 @@ class TestJobProgressIncrementationIntegration: "msg", [None, "Incremented progress successfully"], ) - def test_increment_progress_success(self, session, arq_redis, setup_worker_db, sample_job_run, msg): + def test_increment_progress_success(self, session, arq_redis, with_populated_job_data, sample_job_run, msg): """Test successful progress incrementation.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -1364,7 +1364,9 @@ def test_increment_progress_success(self, session, arq_redis, setup_worker_db, s msg if msg else "Test incrementation message" ) # Message should remain unchanged if None - def test_increment_progress_success_multiple_times(self, session, arq_redis, setup_worker_db, sample_job_run): + def test_increment_progress_success_multiple_times( + self, session, arq_redis, with_populated_job_data, sample_job_run + ): """Test successful progress incrementation multiple times.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -1387,7 +1389,9 @@ def test_increment_progress_success_multiple_times(self, session, arq_redis, set assert job.progress_current == 50 assert job.progress_total == 100 - def test_increment_progress_success_exceeding_total(self, session, arq_redis, setup_worker_db, sample_job_run): + def test_increment_progress_success_exceeding_total( + self, session, arq_redis, with_populated_job_data, sample_job_run + ): """Test successful progress incrementation exceeding total.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -1477,7 +1481,7 @@ def test_set_progress_total_does_not_overwrite_old_message_when_no_new_message_i class TestJobProgressTotalUpdateIntegration: """Test job progress total update lifecycle management.""" - def test_set_progress_total_success(self, session, arq_redis, setup_worker_db, sample_job_run): + def test_set_progress_total_success(self, session, arq_redis, with_populated_job_data, sample_job_run): """Test successful progress total update.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -1528,7 +1532,9 @@ class TestJobIsCancelledIntegration: "job_status", [status for status in JobStatus._member_map_.values() if status in CANCELLED_JOB_STATUSES], ) - def test_is_cancelled_success_cancelled(self, session, arq_redis, setup_worker_db, sample_job_run, job_status): + def test_is_cancelled_success_cancelled( + self, session, arq_redis, with_populated_job_data, sample_job_run, job_status + ): """Test successful is_cancelled check when cancelled.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -1548,7 +1554,9 @@ def test_is_cancelled_success_cancelled(self, session, arq_redis, setup_worker_d "job_status", [status for status in JobStatus._member_map_.values() if status not in CANCELLED_JOB_STATUSES], ) - def test_is_cancelled_success_not_cancelled(self, session, arq_redis, setup_worker_db, sample_job_run, job_status): + def test_is_cancelled_success_not_cancelled( + self, session, arq_redis, with_populated_job_data, sample_job_run, job_status + ): """Test successful is_cancelled check when not cancelled.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -1687,7 +1695,7 @@ class TestJobShouldRetryIntegration: [status for status in JobStatus._member_map_.values() if status != JobStatus.FAILED], ) def test_should_retry_success_non_failed_jobs_should_not_retry( - self, session, arq_redis, setup_worker_db, sample_job_run, job_status + self, session, arq_redis, with_populated_job_data, sample_job_run, job_status ): """Test successful should_retry check (only jobs in failed states may retry).""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -1705,7 +1713,7 @@ def test_should_retry_success_non_failed_jobs_should_not_retry( assert result is False def test_should_retry_success_exceeded_retry_attempts_should_not_retry( - self, session, arq_redis, setup_worker_db, sample_job_run + self, session, arq_redis, with_populated_job_data, sample_job_run ): """Test successful should_retry check with no retry attempts left.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -1725,7 +1733,7 @@ def test_should_retry_success_exceeded_retry_attempts_should_not_retry( assert result is False def test_should_retry_success_failure_category_is_not_retryable( - self, session, arq_redis, setup_worker_db, sample_job_run + self, session, arq_redis, with_populated_job_data, sample_job_run ): """Test successful should_retry check with non-retryable failure category.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -1745,7 +1753,7 @@ def test_should_retry_success_failure_category_is_not_retryable( # Verify the job should not retry. This method requires no persistance. assert result is False - def test_should_retry_success(self, session, arq_redis, setup_worker_db, sample_job_run): + def test_should_retry_success(self, session, arq_redis, with_populated_job_data, sample_job_run): """Test successful should_retry check with retryable failure category.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -1792,7 +1800,7 @@ def test_get_job_wraps_database_connection_error_when_encounters_sqlalchemy_erro class TestGetJobIntegration: """Test job retrieval.""" - def test_get_job_success(self, session, arq_redis, setup_worker_db, sample_job_run): + def test_get_job_success(self, session, arq_redis, with_populated_job_data, sample_job_run): """Test successful job retrieval.""" manager = JobManager(session, arq_redis, sample_job_run.id) @@ -1804,7 +1812,9 @@ def test_get_job_success(self, session, arq_redis, setup_worker_db, sample_job_r assert job.id == sample_job_run.id assert job.status == JobStatus.PENDING - def test_get_job_raises_job_not_found_error_when_job_does_not_exist(self, session, arq_redis, setup_worker_db): + def test_get_job_raises_job_not_found_error_when_job_does_not_exist( + self, session, arq_redis, with_populated_job_data + ): """Test job retrieval failure when job does not exist.""" with pytest.raises(DatabaseConnectionError, match="Failed to fetch job 9999"), TransactionSpy.spy(session): JobManager(session, arq_redis, job_id=9999) # Non-existent job ID @@ -1814,7 +1824,7 @@ def test_get_job_raises_job_not_found_error_when_job_does_not_exist(self, sessio class TestJobManagerJob: """Test overall job lifecycle management.""" - def test_full_successful_job_lifecycle(self, session, arq_redis, setup_worker_db, sample_job_run): + def test_full_successful_job_lifecycle(self, session, arq_redis, with_populated_job_data, sample_job_run): """Test full job lifecycle from start to completion.""" # Pre-manager: Job is created in DB in Pending state. Verify initial state. job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() @@ -1904,7 +1914,7 @@ def test_full_successful_job_lifecycle(self, session, arq_redis, setup_worker_db assert final_job.progress_total == 200 assert final_job.progress_message == "Job completed successfully" - def test_full_cancelled_job_lifecycle(self, session, arq_redis, setup_worker_db, sample_job_run): + def test_full_cancelled_job_lifecycle(self, session, arq_redis, with_populated_job_data, sample_job_run): """Test full job lifecycle for a cancelled job.""" # Pre-manager: Job is created in DB in Pending state. Verify initial state. job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() @@ -1941,7 +1951,7 @@ def test_full_cancelled_job_lifecycle(self, session, arq_redis, setup_worker_db, assert job.finished_at is not None assert job.progress_message == "Job cancelled" - def test_full_skipped_job_lifecycle(self, session, arq_redis, setup_worker_db, sample_job_run): + def test_full_skipped_job_lifecycle(self, session, arq_redis, with_populated_job_data, sample_job_run): """Test full job lifecycle for a skipped job.""" # Pre-manager: Job is created in DB in Pending state. Verify initial state. job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() @@ -1959,7 +1969,7 @@ def test_full_skipped_job_lifecycle(self, session, arq_redis, setup_worker_db, s assert job.finished_at is not None assert job.progress_message == "Job skipped" - def test_full_failed_job_lifecycle(self, session, arq_redis, setup_worker_db, sample_job_run): + def test_full_failed_job_lifecycle(self, session, arq_redis, with_populated_job_data, sample_job_run): """Test full job lifecycle for a failed job.""" # Pre-manager: Job is created in DB in Pending state. Verify initial state. job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() @@ -1997,7 +2007,7 @@ def test_full_failed_job_lifecycle(self, session, arq_redis, setup_worker_db, sa assert job.error_message == "An error occurred" assert job.error_traceback is not None - def test_full_retried_job_lifecycle(self, session, arq_redis, setup_worker_db, sample_job_run): + def test_full_retried_job_lifecycle(self, session, arq_redis, with_populated_job_data, sample_job_run): """Test full job lifecycle for a retried job.""" # Pre-manager: Job is created in DB in Pending state. Verify initial state. job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() @@ -2049,7 +2059,7 @@ def test_full_retried_job_lifecycle(self, session, arq_redis, setup_worker_db, s assert job.status == JobStatus.PENDING assert job.retry_count == 1 - def test_full_reset_job_lifecycle(self, session, arq_redis, setup_worker_db, sample_job_run): + def test_full_reset_job_lifecycle(self, session, arq_redis, with_populated_job_data, sample_job_run): """Test full job lifecycle for a reset job.""" # Pre-manager: Job is created in DB in Pending state. Verify initial state. job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() diff --git a/tests/worker/lib/managers/test_pipeline_manager.py b/tests/worker/lib/managers/test_pipeline_manager.py index aedeffb3..5c57ba3f 100644 --- a/tests/worker/lib/managers/test_pipeline_manager.py +++ b/tests/worker/lib/managers/test_pipeline_manager.py @@ -52,7 +52,7 @@ class TestPipelineManagerInitialization: """Test PipelineManager initialization and setup.""" - def test_init_with_valid_pipeline(self, session, arq_redis, setup_worker_db, sample_pipeline): + def test_init_with_valid_pipeline(self, session, arq_redis, with_populated_job_data, sample_pipeline): """Test successful initialization with valid pipeline ID.""" manager = PipelineManager(session, arq_redis, sample_pipeline.id) @@ -66,7 +66,7 @@ def test_init_with_invalid_pipeline_id(self, session, arq_redis): with pytest.raises(DatabaseConnectionError, match=f"Failed to get pipeline {pipeline_id}"): PipelineManager(session, arq_redis, pipeline_id) - def test_init_with_database_error(self, session, arq_redis, setup_worker_db, sample_pipeline): + def test_init_with_database_error(self, session, arq_redis, with_populated_job_data, sample_pipeline): """Test initialization failure with database connection error.""" pipeline_id = sample_pipeline.id @@ -132,7 +132,7 @@ class TestStartPipelineIntegration: @pytest.mark.asyncio async def test_start_pipeline_successful( - self, session, arq_redis, setup_worker_db, sample_pipeline, sample_job_run + self, session, arq_redis, with_populated_job_data, sample_pipeline, sample_job_run ): """Test successful pipeline start from CREATED state.""" manager = PipelineManager(session, arq_redis, sample_pipeline.id) @@ -156,7 +156,7 @@ async def test_start_pipeline_successful( assert jobs[0].function == sample_job_run.job_function @pytest.mark.asyncio - async def test_start_pipeline_no_jobs(self, session, arq_redis, setup_worker_db, sample_empty_pipeline): + async def test_start_pipeline_no_jobs(self, session, arq_redis, with_populated_job_data, sample_empty_pipeline): """Test pipeline start when there are no jobs in the pipeline.""" manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id) @@ -259,7 +259,7 @@ class TestCoordinatePipelineIntegration: @pytest.mark.asyncio async def test_coordinate_pipeline_transitions_pipeline_to_failed_after_job_failure( - self, session, arq_redis, setup_worker_db, sample_pipeline, sample_job_run, sample_dependent_job_run + self, session, arq_redis, with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run ): """Test successful pipeline coordination and job enqueuing after job completion.""" manager = PipelineManager(session, arq_redis, sample_pipeline.id) @@ -292,7 +292,7 @@ async def test_coordinate_pipeline_transitions_pipeline_to_failed_after_job_fail @pytest.mark.asyncio async def test_coordinate_pipeline_transitions_pipeline_to_cancelled_after_pipeline_is_cancelled( - self, session, arq_redis, setup_worker_db, sample_pipeline, sample_job_run, sample_dependent_job_run + self, session, arq_redis, with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run ): """Test successful pipeline coordination and job enqueuing after pipeline cancellation .""" manager = PipelineManager(session, arq_redis, sample_pipeline.id) @@ -329,7 +329,7 @@ async def test_coordinate_pipeline_transitions_pipeline_to_cancelled_after_pipel @pytest.mark.asyncio async def test_coordinate_running_pipeline_enqueues_ready_jobs( - self, session, arq_redis, setup_worker_db, sample_pipeline, sample_job_run, sample_dependent_job_run + self, session, arq_redis, with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run ): """Test successful pipeline coordination and job enqueuing when jobs are still pending.""" manager = PipelineManager(session, arq_redis, sample_pipeline.id) @@ -366,7 +366,7 @@ async def test_coordinate_pipeline_noop( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -594,7 +594,7 @@ def test_pipeline_status_transition_noop_when_status_is_terminal( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, initial_status, ): @@ -619,7 +619,7 @@ def test_pipeline_status_transition_noop_when_status_is_paused( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, ): """Test that pipeline status remains unchanged when in PAUSED state.""" @@ -653,7 +653,7 @@ def test_pipeline_status_transition_when_no_jobs_in_pipeline( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, initial_status, expected_status, sample_empty_pipeline, @@ -705,7 +705,7 @@ def test_pipeline_status_transitions( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, initial_status, job_updates, @@ -842,7 +842,7 @@ async def test_enqueue_ready_jobs_integration( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -878,7 +878,7 @@ async def test_enqueue_ready_jobs_integration_with_unreachable_job( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -911,7 +911,7 @@ async def test_enqueue_ready_jobs_integration_with_unreachable_job( @pytest.mark.asyncio async def test_enqueue_ready_jobs_with_empty_pipeline( - self, session, arq_redis, setup_worker_db, sample_empty_pipeline + self, session, arq_redis, with_populated_job_data, sample_empty_pipeline ): """Test enqueuing of ready jobs in an empty pipeline.""" manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id) @@ -935,7 +935,7 @@ async def test_enqueue_ready_jobs_bubbles_pipeline_coordination_error_for_any_ex self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, ): @@ -1044,7 +1044,7 @@ def test_cancel_remaining_jobs_integration( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -1077,7 +1077,7 @@ def test_cancel_remaining_jobs_integration_no_active_jobs( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_empty_pipeline, ): """Test cancellation of remaining jobs when there are no active jobs.""" @@ -1152,7 +1152,7 @@ async def test_cancel_pipeline_integration( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -1193,7 +1193,7 @@ async def test_cancel_pipeline_integration_already_terminal( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, ): @@ -1308,7 +1308,7 @@ async def test_pause_pipeline_integration( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, ): """Test successful pausing of a pipeline.""" @@ -1379,7 +1379,7 @@ class TestUnpausePipelineIntegration: @pytest.mark.asyncio async def test_unpause_pipeline_integration( - self, session, arq_redis, setup_worker_db, sample_pipeline, sample_job_run, sample_dependent_job_run + self, session, arq_redis, with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run ): """Test successful unpausing of a pipeline.""" manager = PipelineManager(session, arq_redis, sample_pipeline.id) @@ -1460,7 +1460,7 @@ async def test_restart_pipeline_integration( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -1497,7 +1497,7 @@ async def test_restart_pipeline_integration_skips_if_no_jobs( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_empty_pipeline, ): """Test that restarting a pipeline with no jobs skips without error.""" @@ -1615,7 +1615,7 @@ def test_can_enqueue_job_integration_with_no_dependencies( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, ): @@ -1633,7 +1633,7 @@ def test_can_enqueue_job_integration_with_unmet_dependencies( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_dependent_job_run, ): @@ -1651,7 +1651,7 @@ def test_can_enqueue_job_integration_with_met_dependencies( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -1781,7 +1781,7 @@ def test_should_not_skip_job_with_no_dependencies( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, ): @@ -1800,7 +1800,7 @@ def test_should_skip_job_with_unreachable_dependency( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -1824,7 +1824,7 @@ def test_should_not_skip_job_with_reachable_dependency( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -1906,7 +1906,7 @@ async def test_retry_failed_jobs_integration( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -1947,7 +1947,7 @@ async def test_retry_failed_jobs_integration_no_failed_jobs( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_empty_pipeline, ): """Test that retrying failed jobs skips if there are no failed jobs.""" @@ -2030,7 +2030,7 @@ async def test_retry_unsuccessful_jobs_integration( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -2071,7 +2071,7 @@ async def test_retry_unsuccessful_jobs_integration_no_unsuccessful_jobs( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_empty_pipeline, ): """Test that retrying unsuccessful jobs skips if there are no unsuccessful jobs.""" @@ -2122,7 +2122,7 @@ async def test_retry_pipeline_integration( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -2185,7 +2185,7 @@ def test_get_jobs_by_status_integration( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -2211,7 +2211,7 @@ def test_get_jobs_by_status_integration_no_matching_jobs( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, ): """Test retrieval of jobs by status when no jobs match.""" @@ -2228,7 +2228,7 @@ def test_get_jobs_by_status_integration_multiple_matching_jobs( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -2255,7 +2255,7 @@ def test_get_jobs_by_status_integration_no_jobs_in_pipeline( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_empty_pipeline, ): """Test retrieval of jobs by status when there are no jobs in the pipeline.""" @@ -2272,7 +2272,7 @@ def test_get_jobs_by_status_multiple_statuses( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -2326,7 +2326,7 @@ def test_get_pending_jobs_integration( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -2351,7 +2351,7 @@ def test_get_pending_jobs_integration_no_pending_jobs( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -2410,7 +2410,7 @@ def test_get_active_jobs_integration( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -2437,7 +2437,7 @@ def test_get_active_jobs_integration_no_active_jobs( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -2466,7 +2466,7 @@ def test_get_running_jobs_integration( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -2491,7 +2491,7 @@ def test_get_running_jobs_integration_no_running_jobs( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -2536,7 +2536,7 @@ def test_get_failed_jobs_integration( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -2561,7 +2561,7 @@ def test_get_failed_jobs_integration_no_failed_jobs( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -2605,7 +2605,7 @@ def test_get_unsuccessful_jobs_integration( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -2632,7 +2632,7 @@ def test_get_unsuccessful_jobs_integration_no_unsuccessful_jobs( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -2676,7 +2676,7 @@ def test_get_all_jobs_integration( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -2698,7 +2698,7 @@ def test_get_all_jobs_integration_no_jobs( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_empty_pipeline, ): """Test retrieval of all jobs when there are no jobs in the pipeline.""" @@ -2715,7 +2715,7 @@ def test_get_all_jobs_integration_multiple_jobs( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -2774,7 +2774,7 @@ def test_get_dependencies_for_job_integration( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -2797,7 +2797,7 @@ def test_get_dependencies_for_job_integration_no_dependencies( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, ): @@ -2815,7 +2815,7 @@ def test_get_dependencies_for_job_integration_multiple_dependencies( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -2886,7 +2886,7 @@ def test_get_pipeline_integration( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, ): """Test retrieval of pipeline.""" @@ -2904,7 +2904,7 @@ def test_get_pipeline_integration_nonexistent_pipeline( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, ): """Test retrieval of a nonexistent pipeline raises PipelineNotFoundError.""" with ( @@ -2938,7 +2938,7 @@ def test_get_job_counts_by_status_integration( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -2964,7 +2964,7 @@ def test_get_job_counts_by_status_integration_no_jobs( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_empty_pipeline, ): """Test retrieval of job counts by status when there are no jobs in the pipeline.""" @@ -3018,7 +3018,7 @@ def test_get_pipeline_status_integration( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, ): """Test retrieval of pipeline status.""" @@ -3139,7 +3139,7 @@ def test_set_pipeline_status_integration( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, pipeline_status, ): @@ -3166,7 +3166,7 @@ def test_set_pipeline_status_integration_terminal_status_sets_finished_at( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, pipeline_status, ): @@ -3193,7 +3193,7 @@ def test_set_pipeline_status_integration_created_status_clears_started_at( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, ): """Test that setting status to CREATED clears the started_at property.""" @@ -3218,7 +3218,7 @@ def test_set_pipeline_status_integration_running_status_sets_started_at( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, initial_started_at, ): @@ -3296,7 +3296,7 @@ async def test_enqueue_in_arq_integration( self, session, arq_redis: ArqRedis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, ): @@ -3322,7 +3322,7 @@ async def test_full_pipeline_lifecycle( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, ): @@ -3430,7 +3430,7 @@ async def test_full_pipeline_lifecycle( @pytest.mark.asyncio async def test_paused_pipeline_lifecycle( - self, session, arq_redis, setup_worker_db, sample_pipeline, sample_job_run, sample_dependent_job_run + self, session, arq_redis, with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run ): """Test lifecycle of a paused pipeline.""" manager = PipelineManager(session, arq_redis, sample_pipeline.id) @@ -3530,7 +3530,7 @@ async def test_cancelled_pipeline_lifecycle( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, sample_dependent_job_run, @@ -3586,7 +3586,7 @@ async def test_restart_pipeline_lifecycle( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, ): @@ -3653,7 +3653,7 @@ async def test_retry_pipeline_lifecycle( self, session, arq_redis, - setup_worker_db, + with_populated_job_data, sample_pipeline, sample_job_run, ): From ce893a443449ddd0a80729a10f247e6d02a4da3f Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Thu, 22 Jan 2026 13:42:25 -0800 Subject: [PATCH 23/70] wip: refactor jobs to use job management system feat(wip): upload files to S3 prior to job invocation, localstack emulation in dev environment --- bin/localstack-init.sh | 4 + docker-compose-dev.yml | 13 + poetry.lock | 835 ++++++----- pyproject.toml | 2 +- settings/.env.template | 9 + src/mavedb/data_providers/services.py | 19 +- src/mavedb/lib/clingen/constants.py | 2 - src/mavedb/lib/exceptions.py | 6 + src/mavedb/routers/score_sets.py | 37 +- src/mavedb/worker/jobs/__init__.py | 2 - .../worker/jobs/data_management/py.typed | 0 .../worker/jobs/data_management/views.py | 114 +- .../worker/jobs/external_services/clingen.py | 858 ++++------- .../worker/jobs/external_services/gnomad.py | 198 ++- .../worker/jobs/external_services/py.typed | 0 .../worker/jobs/external_services/uniprot.py | 412 +++--- src/mavedb/worker/jobs/registry.py | 2 - src/mavedb/worker/jobs/utils/__init__.py | 6 +- src/mavedb/worker/jobs/utils/job_state.py | 35 - src/mavedb/worker/jobs/utils/py.typed | 0 src/mavedb/worker/jobs/utils/retry.py | 61 - src/mavedb/worker/jobs/utils/setup.py | 24 + .../jobs/variant_processing/__init__.py | 2 - .../jobs/variant_processing/creation.py | 225 +-- .../worker/jobs/variant_processing/mapping.py | 738 ++++------ .../worker/jobs/variant_processing/py.typed | 0 src/mavedb/worker/lib/managers/py.typed | 0 tests/network/worker/test_clingen.py | 0 tests/network/worker/test_gnomad.py | 0 tests/network/worker/test_uniprot.py | 0 tests/worker/{lib => }/conftest_optional.py | 0 .../worker/jobs/data_management/test_views.py | 288 ++++ .../jobs/external_services/test_clingen.py | 1289 ++++++----------- .../jobs/external_services/test_gnomad.py | 206 --- .../jobs/external_services/test_uniprot.py | 603 -------- tests/worker/jobs/utils/test_setup.py | 30 + .../jobs/variant_processing/test_creation.py | 557 ------- .../jobs/variant_processing/test_mapping.py | 710 --------- tests/worker/lib/conftest.py | 192 --- 39 files changed, 2415 insertions(+), 5064 deletions(-) create mode 100755 bin/localstack-init.sh create mode 100644 src/mavedb/worker/jobs/data_management/py.typed create mode 100644 src/mavedb/worker/jobs/external_services/py.typed delete mode 100644 src/mavedb/worker/jobs/utils/job_state.py create mode 100644 src/mavedb/worker/jobs/utils/py.typed delete mode 100644 src/mavedb/worker/jobs/utils/retry.py create mode 100644 src/mavedb/worker/jobs/utils/setup.py create mode 100644 src/mavedb/worker/jobs/variant_processing/py.typed create mode 100644 src/mavedb/worker/lib/managers/py.typed create mode 100644 tests/network/worker/test_clingen.py create mode 100644 tests/network/worker/test_gnomad.py create mode 100644 tests/network/worker/test_uniprot.py rename tests/worker/{lib => }/conftest_optional.py (100%) create mode 100644 tests/worker/jobs/data_management/test_views.py create mode 100644 tests/worker/jobs/utils/test_setup.py delete mode 100644 tests/worker/lib/conftest.py diff --git a/bin/localstack-init.sh b/bin/localstack-init.sh new file mode 100755 index 00000000..1a00cfcb --- /dev/null +++ b/bin/localstack-init.sh @@ -0,0 +1,4 @@ +#!/bin/sh +echo "Initializing local S3 bucket..." +awslocal s3 mb s3://score-set-csv-uploads-dev +echo "S3 bucket 'score-set-csv-uploads-dev' created." \ No newline at end of file diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml index d9d430af..972eb410 100644 --- a/docker-compose-dev.yml +++ b/docker-compose-dev.yml @@ -95,6 +95,18 @@ services: volumes: - mavedb-redis-dev:/data + localstack: + image: localstack/localstack:latest + ports: + - "4566:4566" + env_file: + - settings/.env.dev + environment: + - SERVICES=s3:4566 # We only need S3 for MaveDB + volumes: + - mavedb-localstack-dev:/var/lib/localstack + - "./bin/localstack-init.sh:/etc/localstack/init/ready.d/localstack-init.sh" + seqrepo: image: biocommons/seqrepo:2024-12-20 volumes: @@ -104,3 +116,4 @@ volumes: mavedb-data-dev: mavedb-redis-dev: mavedb-seqrepo-dev: + mavedb-localstack-dev: diff --git a/poetry.lock b/poetry.lock index 18ecdd5e..35c2477c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -301,411 +301,441 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "boto3-stubs" -version = "1.34.162" -description = "Type annotations for boto3 1.34.162 generated with mypy-boto3-builder 7.26.0" +version = "1.42.33" +description = "Type annotations for boto3 1.42.33 generated with mypy-boto3-builder 8.12.0" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" groups = ["dev"] files = [ - {file = "boto3_stubs-1.34.162-py3-none-any.whl", hash = "sha256:47c651272782a2e894082087eeaeb87a7e809e7e282748560cf39c155031abef"}, - {file = "boto3_stubs-1.34.162.tar.gz", hash = "sha256:6d60b7b9652e1c99f3caba00779e1b94ba7062b0431147a00543af8b1f5252f4"}, + {file = "boto3_stubs-1.42.33-py3-none-any.whl", hash = "sha256:ea2887aaab8b29db446a8260a19069ad8ad614d7a9ffe34ae87b9a2396c7a57e"}, + {file = "boto3_stubs-1.42.33.tar.gz", hash = "sha256:c6b508b3541d48d63892a3eb2a7b36ec4d24435e8cf8233b6ae3f8f2122f0b61"}, ] [package.dependencies] botocore-stubs = "*" +mypy-boto3-s3 = {version = ">=1.42.0,<1.43.0", optional = true, markers = "extra == \"s3\""} types-s3transfer = "*" typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.12\""} [package.extras] -accessanalyzer = ["mypy-boto3-accessanalyzer (>=1.34.0,<1.35.0)"] -account = ["mypy-boto3-account (>=1.34.0,<1.35.0)"] -acm = ["mypy-boto3-acm (>=1.34.0,<1.35.0)"] -acm-pca = ["mypy-boto3-acm-pca (>=1.34.0,<1.35.0)"] -all = ["mypy-boto3-accessanalyzer (>=1.34.0,<1.35.0)", "mypy-boto3-account (>=1.34.0,<1.35.0)", "mypy-boto3-acm (>=1.34.0,<1.35.0)", "mypy-boto3-acm-pca (>=1.34.0,<1.35.0)", "mypy-boto3-amp (>=1.34.0,<1.35.0)", "mypy-boto3-amplify (>=1.34.0,<1.35.0)", "mypy-boto3-amplifybackend (>=1.34.0,<1.35.0)", "mypy-boto3-amplifyuibuilder (>=1.34.0,<1.35.0)", "mypy-boto3-apigateway (>=1.34.0,<1.35.0)", "mypy-boto3-apigatewaymanagementapi (>=1.34.0,<1.35.0)", "mypy-boto3-apigatewayv2 (>=1.34.0,<1.35.0)", "mypy-boto3-appconfig (>=1.34.0,<1.35.0)", "mypy-boto3-appconfigdata (>=1.34.0,<1.35.0)", "mypy-boto3-appfabric (>=1.34.0,<1.35.0)", "mypy-boto3-appflow (>=1.34.0,<1.35.0)", "mypy-boto3-appintegrations (>=1.34.0,<1.35.0)", "mypy-boto3-application-autoscaling (>=1.34.0,<1.35.0)", "mypy-boto3-application-insights (>=1.34.0,<1.35.0)", "mypy-boto3-application-signals (>=1.34.0,<1.35.0)", "mypy-boto3-applicationcostprofiler (>=1.34.0,<1.35.0)", "mypy-boto3-appmesh (>=1.34.0,<1.35.0)", "mypy-boto3-apprunner (>=1.34.0,<1.35.0)", "mypy-boto3-appstream (>=1.34.0,<1.35.0)", "mypy-boto3-appsync (>=1.34.0,<1.35.0)", "mypy-boto3-apptest (>=1.34.0,<1.35.0)", "mypy-boto3-arc-zonal-shift (>=1.34.0,<1.35.0)", "mypy-boto3-artifact (>=1.34.0,<1.35.0)", "mypy-boto3-athena (>=1.34.0,<1.35.0)", "mypy-boto3-auditmanager (>=1.34.0,<1.35.0)", "mypy-boto3-autoscaling (>=1.34.0,<1.35.0)", "mypy-boto3-autoscaling-plans (>=1.34.0,<1.35.0)", "mypy-boto3-b2bi (>=1.34.0,<1.35.0)", "mypy-boto3-backup (>=1.34.0,<1.35.0)", "mypy-boto3-backup-gateway (>=1.34.0,<1.35.0)", "mypy-boto3-batch (>=1.34.0,<1.35.0)", "mypy-boto3-bcm-data-exports (>=1.34.0,<1.35.0)", "mypy-boto3-bedrock (>=1.34.0,<1.35.0)", "mypy-boto3-bedrock-agent (>=1.34.0,<1.35.0)", "mypy-boto3-bedrock-agent-runtime (>=1.34.0,<1.35.0)", "mypy-boto3-bedrock-runtime (>=1.34.0,<1.35.0)", "mypy-boto3-billingconductor (>=1.34.0,<1.35.0)", "mypy-boto3-braket (>=1.34.0,<1.35.0)", "mypy-boto3-budgets (>=1.34.0,<1.35.0)", "mypy-boto3-ce (>=1.34.0,<1.35.0)", "mypy-boto3-chatbot (>=1.34.0,<1.35.0)", "mypy-boto3-chime (>=1.34.0,<1.35.0)", "mypy-boto3-chime-sdk-identity (>=1.34.0,<1.35.0)", "mypy-boto3-chime-sdk-media-pipelines (>=1.34.0,<1.35.0)", "mypy-boto3-chime-sdk-meetings (>=1.34.0,<1.35.0)", "mypy-boto3-chime-sdk-messaging (>=1.34.0,<1.35.0)", "mypy-boto3-chime-sdk-voice (>=1.34.0,<1.35.0)", "mypy-boto3-cleanrooms (>=1.34.0,<1.35.0)", "mypy-boto3-cleanroomsml (>=1.34.0,<1.35.0)", "mypy-boto3-cloud9 (>=1.34.0,<1.35.0)", "mypy-boto3-cloudcontrol (>=1.34.0,<1.35.0)", "mypy-boto3-clouddirectory (>=1.34.0,<1.35.0)", "mypy-boto3-cloudformation (>=1.34.0,<1.35.0)", "mypy-boto3-cloudfront (>=1.34.0,<1.35.0)", "mypy-boto3-cloudfront-keyvaluestore (>=1.34.0,<1.35.0)", "mypy-boto3-cloudhsm (>=1.34.0,<1.35.0)", "mypy-boto3-cloudhsmv2 (>=1.34.0,<1.35.0)", "mypy-boto3-cloudsearch (>=1.34.0,<1.35.0)", "mypy-boto3-cloudsearchdomain (>=1.34.0,<1.35.0)", "mypy-boto3-cloudtrail (>=1.34.0,<1.35.0)", "mypy-boto3-cloudtrail-data (>=1.34.0,<1.35.0)", "mypy-boto3-cloudwatch (>=1.34.0,<1.35.0)", "mypy-boto3-codeartifact (>=1.34.0,<1.35.0)", "mypy-boto3-codebuild (>=1.34.0,<1.35.0)", "mypy-boto3-codecatalyst (>=1.34.0,<1.35.0)", "mypy-boto3-codecommit (>=1.34.0,<1.35.0)", "mypy-boto3-codeconnections (>=1.34.0,<1.35.0)", "mypy-boto3-codedeploy (>=1.34.0,<1.35.0)", "mypy-boto3-codeguru-reviewer (>=1.34.0,<1.35.0)", "mypy-boto3-codeguru-security (>=1.34.0,<1.35.0)", "mypy-boto3-codeguruprofiler (>=1.34.0,<1.35.0)", "mypy-boto3-codepipeline (>=1.34.0,<1.35.0)", "mypy-boto3-codestar (>=1.34.0,<1.35.0)", "mypy-boto3-codestar-connections (>=1.34.0,<1.35.0)", "mypy-boto3-codestar-notifications (>=1.34.0,<1.35.0)", "mypy-boto3-cognito-identity (>=1.34.0,<1.35.0)", "mypy-boto3-cognito-idp (>=1.34.0,<1.35.0)", "mypy-boto3-cognito-sync (>=1.34.0,<1.35.0)", "mypy-boto3-comprehend (>=1.34.0,<1.35.0)", "mypy-boto3-comprehendmedical (>=1.34.0,<1.35.0)", "mypy-boto3-compute-optimizer (>=1.34.0,<1.35.0)", "mypy-boto3-config (>=1.34.0,<1.35.0)", "mypy-boto3-connect (>=1.34.0,<1.35.0)", "mypy-boto3-connect-contact-lens (>=1.34.0,<1.35.0)", "mypy-boto3-connectcampaigns (>=1.34.0,<1.35.0)", "mypy-boto3-connectcases (>=1.34.0,<1.35.0)", "mypy-boto3-connectparticipant (>=1.34.0,<1.35.0)", "mypy-boto3-controlcatalog (>=1.34.0,<1.35.0)", "mypy-boto3-controltower (>=1.34.0,<1.35.0)", "mypy-boto3-cost-optimization-hub (>=1.34.0,<1.35.0)", "mypy-boto3-cur (>=1.34.0,<1.35.0)", "mypy-boto3-customer-profiles (>=1.34.0,<1.35.0)", "mypy-boto3-databrew (>=1.34.0,<1.35.0)", "mypy-boto3-dataexchange (>=1.34.0,<1.35.0)", "mypy-boto3-datapipeline (>=1.34.0,<1.35.0)", "mypy-boto3-datasync (>=1.34.0,<1.35.0)", "mypy-boto3-datazone (>=1.34.0,<1.35.0)", "mypy-boto3-dax (>=1.34.0,<1.35.0)", "mypy-boto3-deadline (>=1.34.0,<1.35.0)", "mypy-boto3-detective (>=1.34.0,<1.35.0)", "mypy-boto3-devicefarm (>=1.34.0,<1.35.0)", "mypy-boto3-devops-guru (>=1.34.0,<1.35.0)", "mypy-boto3-directconnect (>=1.34.0,<1.35.0)", "mypy-boto3-discovery (>=1.34.0,<1.35.0)", "mypy-boto3-dlm (>=1.34.0,<1.35.0)", "mypy-boto3-dms (>=1.34.0,<1.35.0)", "mypy-boto3-docdb (>=1.34.0,<1.35.0)", "mypy-boto3-docdb-elastic (>=1.34.0,<1.35.0)", "mypy-boto3-drs (>=1.34.0,<1.35.0)", "mypy-boto3-ds (>=1.34.0,<1.35.0)", "mypy-boto3-dynamodb (>=1.34.0,<1.35.0)", "mypy-boto3-dynamodbstreams (>=1.34.0,<1.35.0)", "mypy-boto3-ebs (>=1.34.0,<1.35.0)", "mypy-boto3-ec2 (>=1.34.0,<1.35.0)", "mypy-boto3-ec2-instance-connect (>=1.34.0,<1.35.0)", "mypy-boto3-ecr (>=1.34.0,<1.35.0)", "mypy-boto3-ecr-public (>=1.34.0,<1.35.0)", "mypy-boto3-ecs (>=1.34.0,<1.35.0)", "mypy-boto3-efs (>=1.34.0,<1.35.0)", "mypy-boto3-eks (>=1.34.0,<1.35.0)", "mypy-boto3-eks-auth (>=1.34.0,<1.35.0)", "mypy-boto3-elastic-inference (>=1.34.0,<1.35.0)", "mypy-boto3-elasticache (>=1.34.0,<1.35.0)", "mypy-boto3-elasticbeanstalk (>=1.34.0,<1.35.0)", "mypy-boto3-elastictranscoder (>=1.34.0,<1.35.0)", "mypy-boto3-elb (>=1.34.0,<1.35.0)", "mypy-boto3-elbv2 (>=1.34.0,<1.35.0)", "mypy-boto3-emr (>=1.34.0,<1.35.0)", "mypy-boto3-emr-containers (>=1.34.0,<1.35.0)", "mypy-boto3-emr-serverless (>=1.34.0,<1.35.0)", "mypy-boto3-entityresolution (>=1.34.0,<1.35.0)", "mypy-boto3-es (>=1.34.0,<1.35.0)", "mypy-boto3-events (>=1.34.0,<1.35.0)", "mypy-boto3-evidently (>=1.34.0,<1.35.0)", "mypy-boto3-finspace (>=1.34.0,<1.35.0)", "mypy-boto3-finspace-data (>=1.34.0,<1.35.0)", "mypy-boto3-firehose (>=1.34.0,<1.35.0)", "mypy-boto3-fis (>=1.34.0,<1.35.0)", "mypy-boto3-fms (>=1.34.0,<1.35.0)", "mypy-boto3-forecast (>=1.34.0,<1.35.0)", "mypy-boto3-forecastquery (>=1.34.0,<1.35.0)", "mypy-boto3-frauddetector (>=1.34.0,<1.35.0)", "mypy-boto3-freetier (>=1.34.0,<1.35.0)", "mypy-boto3-fsx (>=1.34.0,<1.35.0)", "mypy-boto3-gamelift (>=1.34.0,<1.35.0)", "mypy-boto3-glacier (>=1.34.0,<1.35.0)", "mypy-boto3-globalaccelerator (>=1.34.0,<1.35.0)", "mypy-boto3-glue (>=1.34.0,<1.35.0)", "mypy-boto3-grafana (>=1.34.0,<1.35.0)", "mypy-boto3-greengrass (>=1.34.0,<1.35.0)", "mypy-boto3-greengrassv2 (>=1.34.0,<1.35.0)", "mypy-boto3-groundstation (>=1.34.0,<1.35.0)", "mypy-boto3-guardduty (>=1.34.0,<1.35.0)", "mypy-boto3-health (>=1.34.0,<1.35.0)", "mypy-boto3-healthlake (>=1.34.0,<1.35.0)", "mypy-boto3-iam (>=1.34.0,<1.35.0)", "mypy-boto3-identitystore (>=1.34.0,<1.35.0)", "mypy-boto3-imagebuilder (>=1.34.0,<1.35.0)", "mypy-boto3-importexport (>=1.34.0,<1.35.0)", "mypy-boto3-inspector (>=1.34.0,<1.35.0)", "mypy-boto3-inspector-scan (>=1.34.0,<1.35.0)", "mypy-boto3-inspector2 (>=1.34.0,<1.35.0)", "mypy-boto3-internetmonitor (>=1.34.0,<1.35.0)", "mypy-boto3-iot (>=1.34.0,<1.35.0)", "mypy-boto3-iot-data (>=1.34.0,<1.35.0)", "mypy-boto3-iot-jobs-data (>=1.34.0,<1.35.0)", "mypy-boto3-iot1click-devices (>=1.34.0,<1.35.0)", "mypy-boto3-iot1click-projects (>=1.34.0,<1.35.0)", "mypy-boto3-iotanalytics (>=1.34.0,<1.35.0)", "mypy-boto3-iotdeviceadvisor (>=1.34.0,<1.35.0)", "mypy-boto3-iotevents (>=1.34.0,<1.35.0)", "mypy-boto3-iotevents-data (>=1.34.0,<1.35.0)", "mypy-boto3-iotfleethub (>=1.34.0,<1.35.0)", "mypy-boto3-iotfleetwise (>=1.34.0,<1.35.0)", "mypy-boto3-iotsecuretunneling (>=1.34.0,<1.35.0)", "mypy-boto3-iotsitewise (>=1.34.0,<1.35.0)", "mypy-boto3-iotthingsgraph (>=1.34.0,<1.35.0)", "mypy-boto3-iottwinmaker (>=1.34.0,<1.35.0)", "mypy-boto3-iotwireless (>=1.34.0,<1.35.0)", "mypy-boto3-ivs (>=1.34.0,<1.35.0)", "mypy-boto3-ivs-realtime (>=1.34.0,<1.35.0)", "mypy-boto3-ivschat (>=1.34.0,<1.35.0)", "mypy-boto3-kafka (>=1.34.0,<1.35.0)", "mypy-boto3-kafkaconnect (>=1.34.0,<1.35.0)", "mypy-boto3-kendra (>=1.34.0,<1.35.0)", "mypy-boto3-kendra-ranking (>=1.34.0,<1.35.0)", "mypy-boto3-keyspaces (>=1.34.0,<1.35.0)", "mypy-boto3-kinesis (>=1.34.0,<1.35.0)", "mypy-boto3-kinesis-video-archived-media (>=1.34.0,<1.35.0)", "mypy-boto3-kinesis-video-media (>=1.34.0,<1.35.0)", "mypy-boto3-kinesis-video-signaling (>=1.34.0,<1.35.0)", "mypy-boto3-kinesis-video-webrtc-storage (>=1.34.0,<1.35.0)", "mypy-boto3-kinesisanalytics (>=1.34.0,<1.35.0)", "mypy-boto3-kinesisanalyticsv2 (>=1.34.0,<1.35.0)", "mypy-boto3-kinesisvideo (>=1.34.0,<1.35.0)", "mypy-boto3-kms (>=1.34.0,<1.35.0)", "mypy-boto3-lakeformation (>=1.34.0,<1.35.0)", "mypy-boto3-lambda (>=1.34.0,<1.35.0)", "mypy-boto3-launch-wizard (>=1.34.0,<1.35.0)", "mypy-boto3-lex-models (>=1.34.0,<1.35.0)", "mypy-boto3-lex-runtime (>=1.34.0,<1.35.0)", "mypy-boto3-lexv2-models (>=1.34.0,<1.35.0)", "mypy-boto3-lexv2-runtime (>=1.34.0,<1.35.0)", "mypy-boto3-license-manager (>=1.34.0,<1.35.0)", "mypy-boto3-license-manager-linux-subscriptions (>=1.34.0,<1.35.0)", "mypy-boto3-license-manager-user-subscriptions (>=1.34.0,<1.35.0)", "mypy-boto3-lightsail (>=1.34.0,<1.35.0)", "mypy-boto3-location (>=1.34.0,<1.35.0)", "mypy-boto3-logs (>=1.34.0,<1.35.0)", "mypy-boto3-lookoutequipment (>=1.34.0,<1.35.0)", "mypy-boto3-lookoutmetrics (>=1.34.0,<1.35.0)", "mypy-boto3-lookoutvision (>=1.34.0,<1.35.0)", "mypy-boto3-m2 (>=1.34.0,<1.35.0)", "mypy-boto3-machinelearning (>=1.34.0,<1.35.0)", "mypy-boto3-macie2 (>=1.34.0,<1.35.0)", "mypy-boto3-mailmanager (>=1.34.0,<1.35.0)", "mypy-boto3-managedblockchain (>=1.34.0,<1.35.0)", "mypy-boto3-managedblockchain-query (>=1.34.0,<1.35.0)", "mypy-boto3-marketplace-agreement (>=1.34.0,<1.35.0)", "mypy-boto3-marketplace-catalog (>=1.34.0,<1.35.0)", "mypy-boto3-marketplace-deployment (>=1.34.0,<1.35.0)", "mypy-boto3-marketplace-entitlement (>=1.34.0,<1.35.0)", "mypy-boto3-marketplacecommerceanalytics (>=1.34.0,<1.35.0)", "mypy-boto3-mediaconnect (>=1.34.0,<1.35.0)", "mypy-boto3-mediaconvert (>=1.34.0,<1.35.0)", "mypy-boto3-medialive (>=1.34.0,<1.35.0)", "mypy-boto3-mediapackage (>=1.34.0,<1.35.0)", "mypy-boto3-mediapackage-vod (>=1.34.0,<1.35.0)", "mypy-boto3-mediapackagev2 (>=1.34.0,<1.35.0)", "mypy-boto3-mediastore (>=1.34.0,<1.35.0)", "mypy-boto3-mediastore-data (>=1.34.0,<1.35.0)", "mypy-boto3-mediatailor (>=1.34.0,<1.35.0)", "mypy-boto3-medical-imaging (>=1.34.0,<1.35.0)", "mypy-boto3-memorydb (>=1.34.0,<1.35.0)", "mypy-boto3-meteringmarketplace (>=1.34.0,<1.35.0)", "mypy-boto3-mgh (>=1.34.0,<1.35.0)", "mypy-boto3-mgn (>=1.34.0,<1.35.0)", "mypy-boto3-migration-hub-refactor-spaces (>=1.34.0,<1.35.0)", "mypy-boto3-migrationhub-config (>=1.34.0,<1.35.0)", "mypy-boto3-migrationhuborchestrator (>=1.34.0,<1.35.0)", "mypy-boto3-migrationhubstrategy (>=1.34.0,<1.35.0)", "mypy-boto3-mq (>=1.34.0,<1.35.0)", "mypy-boto3-mturk (>=1.34.0,<1.35.0)", "mypy-boto3-mwaa (>=1.34.0,<1.35.0)", "mypy-boto3-neptune (>=1.34.0,<1.35.0)", "mypy-boto3-neptune-graph (>=1.34.0,<1.35.0)", "mypy-boto3-neptunedata (>=1.34.0,<1.35.0)", "mypy-boto3-network-firewall (>=1.34.0,<1.35.0)", "mypy-boto3-networkmanager (>=1.34.0,<1.35.0)", "mypy-boto3-networkmonitor (>=1.34.0,<1.35.0)", "mypy-boto3-nimble (>=1.34.0,<1.35.0)", "mypy-boto3-oam (>=1.34.0,<1.35.0)", "mypy-boto3-omics (>=1.34.0,<1.35.0)", "mypy-boto3-opensearch (>=1.34.0,<1.35.0)", "mypy-boto3-opensearchserverless (>=1.34.0,<1.35.0)", "mypy-boto3-opsworks (>=1.34.0,<1.35.0)", "mypy-boto3-opsworkscm (>=1.34.0,<1.35.0)", "mypy-boto3-organizations (>=1.34.0,<1.35.0)", "mypy-boto3-osis (>=1.34.0,<1.35.0)", "mypy-boto3-outposts (>=1.34.0,<1.35.0)", "mypy-boto3-panorama (>=1.34.0,<1.35.0)", "mypy-boto3-payment-cryptography (>=1.34.0,<1.35.0)", "mypy-boto3-payment-cryptography-data (>=1.34.0,<1.35.0)", "mypy-boto3-pca-connector-ad (>=1.34.0,<1.35.0)", "mypy-boto3-pca-connector-scep (>=1.34.0,<1.35.0)", "mypy-boto3-personalize (>=1.34.0,<1.35.0)", "mypy-boto3-personalize-events (>=1.34.0,<1.35.0)", "mypy-boto3-personalize-runtime (>=1.34.0,<1.35.0)", "mypy-boto3-pi (>=1.34.0,<1.35.0)", "mypy-boto3-pinpoint (>=1.34.0,<1.35.0)", "mypy-boto3-pinpoint-email (>=1.34.0,<1.35.0)", "mypy-boto3-pinpoint-sms-voice (>=1.34.0,<1.35.0)", "mypy-boto3-pinpoint-sms-voice-v2 (>=1.34.0,<1.35.0)", "mypy-boto3-pipes (>=1.34.0,<1.35.0)", "mypy-boto3-polly (>=1.34.0,<1.35.0)", "mypy-boto3-pricing (>=1.34.0,<1.35.0)", "mypy-boto3-privatenetworks (>=1.34.0,<1.35.0)", "mypy-boto3-proton (>=1.34.0,<1.35.0)", "mypy-boto3-qapps (>=1.34.0,<1.35.0)", "mypy-boto3-qbusiness (>=1.34.0,<1.35.0)", "mypy-boto3-qconnect (>=1.34.0,<1.35.0)", "mypy-boto3-qldb (>=1.34.0,<1.35.0)", "mypy-boto3-qldb-session (>=1.34.0,<1.35.0)", "mypy-boto3-quicksight (>=1.34.0,<1.35.0)", "mypy-boto3-ram (>=1.34.0,<1.35.0)", "mypy-boto3-rbin (>=1.34.0,<1.35.0)", "mypy-boto3-rds (>=1.34.0,<1.35.0)", "mypy-boto3-rds-data (>=1.34.0,<1.35.0)", "mypy-boto3-redshift (>=1.34.0,<1.35.0)", "mypy-boto3-redshift-data (>=1.34.0,<1.35.0)", "mypy-boto3-redshift-serverless (>=1.34.0,<1.35.0)", "mypy-boto3-rekognition (>=1.34.0,<1.35.0)", "mypy-boto3-repostspace (>=1.34.0,<1.35.0)", "mypy-boto3-resiliencehub (>=1.34.0,<1.35.0)", "mypy-boto3-resource-explorer-2 (>=1.34.0,<1.35.0)", "mypy-boto3-resource-groups (>=1.34.0,<1.35.0)", "mypy-boto3-resourcegroupstaggingapi (>=1.34.0,<1.35.0)", "mypy-boto3-robomaker (>=1.34.0,<1.35.0)", "mypy-boto3-rolesanywhere (>=1.34.0,<1.35.0)", "mypy-boto3-route53 (>=1.34.0,<1.35.0)", "mypy-boto3-route53-recovery-cluster (>=1.34.0,<1.35.0)", "mypy-boto3-route53-recovery-control-config (>=1.34.0,<1.35.0)", "mypy-boto3-route53-recovery-readiness (>=1.34.0,<1.35.0)", "mypy-boto3-route53domains (>=1.34.0,<1.35.0)", "mypy-boto3-route53profiles (>=1.34.0,<1.35.0)", "mypy-boto3-route53resolver (>=1.34.0,<1.35.0)", "mypy-boto3-rum (>=1.34.0,<1.35.0)", "mypy-boto3-s3 (>=1.34.0,<1.35.0)", "mypy-boto3-s3control (>=1.34.0,<1.35.0)", "mypy-boto3-s3outposts (>=1.34.0,<1.35.0)", "mypy-boto3-sagemaker (>=1.34.0,<1.35.0)", "mypy-boto3-sagemaker-a2i-runtime (>=1.34.0,<1.35.0)", "mypy-boto3-sagemaker-edge (>=1.34.0,<1.35.0)", "mypy-boto3-sagemaker-featurestore-runtime (>=1.34.0,<1.35.0)", "mypy-boto3-sagemaker-geospatial (>=1.34.0,<1.35.0)", "mypy-boto3-sagemaker-metrics (>=1.34.0,<1.35.0)", "mypy-boto3-sagemaker-runtime (>=1.34.0,<1.35.0)", "mypy-boto3-savingsplans (>=1.34.0,<1.35.0)", "mypy-boto3-scheduler (>=1.34.0,<1.35.0)", "mypy-boto3-schemas (>=1.34.0,<1.35.0)", "mypy-boto3-sdb (>=1.34.0,<1.35.0)", "mypy-boto3-secretsmanager (>=1.34.0,<1.35.0)", "mypy-boto3-securityhub (>=1.34.0,<1.35.0)", "mypy-boto3-securitylake (>=1.34.0,<1.35.0)", "mypy-boto3-serverlessrepo (>=1.34.0,<1.35.0)", "mypy-boto3-service-quotas (>=1.34.0,<1.35.0)", "mypy-boto3-servicecatalog (>=1.34.0,<1.35.0)", "mypy-boto3-servicecatalog-appregistry (>=1.34.0,<1.35.0)", "mypy-boto3-servicediscovery (>=1.34.0,<1.35.0)", "mypy-boto3-ses (>=1.34.0,<1.35.0)", "mypy-boto3-sesv2 (>=1.34.0,<1.35.0)", "mypy-boto3-shield (>=1.34.0,<1.35.0)", "mypy-boto3-signer (>=1.34.0,<1.35.0)", "mypy-boto3-simspaceweaver (>=1.34.0,<1.35.0)", "mypy-boto3-sms (>=1.34.0,<1.35.0)", "mypy-boto3-sms-voice (>=1.34.0,<1.35.0)", "mypy-boto3-snow-device-management (>=1.34.0,<1.35.0)", "mypy-boto3-snowball (>=1.34.0,<1.35.0)", "mypy-boto3-sns (>=1.34.0,<1.35.0)", "mypy-boto3-sqs (>=1.34.0,<1.35.0)", "mypy-boto3-ssm (>=1.34.0,<1.35.0)", "mypy-boto3-ssm-contacts (>=1.34.0,<1.35.0)", "mypy-boto3-ssm-incidents (>=1.34.0,<1.35.0)", "mypy-boto3-ssm-quicksetup (>=1.34.0,<1.35.0)", "mypy-boto3-ssm-sap (>=1.34.0,<1.35.0)", "mypy-boto3-sso (>=1.34.0,<1.35.0)", "mypy-boto3-sso-admin (>=1.34.0,<1.35.0)", "mypy-boto3-sso-oidc (>=1.34.0,<1.35.0)", "mypy-boto3-stepfunctions (>=1.34.0,<1.35.0)", "mypy-boto3-storagegateway (>=1.34.0,<1.35.0)", "mypy-boto3-sts (>=1.34.0,<1.35.0)", "mypy-boto3-supplychain (>=1.34.0,<1.35.0)", "mypy-boto3-support (>=1.34.0,<1.35.0)", "mypy-boto3-support-app (>=1.34.0,<1.35.0)", "mypy-boto3-swf (>=1.34.0,<1.35.0)", "mypy-boto3-synthetics (>=1.34.0,<1.35.0)", "mypy-boto3-taxsettings (>=1.34.0,<1.35.0)", "mypy-boto3-textract (>=1.34.0,<1.35.0)", "mypy-boto3-timestream-influxdb (>=1.34.0,<1.35.0)", "mypy-boto3-timestream-query (>=1.34.0,<1.35.0)", "mypy-boto3-timestream-write (>=1.34.0,<1.35.0)", "mypy-boto3-tnb (>=1.34.0,<1.35.0)", "mypy-boto3-transcribe (>=1.34.0,<1.35.0)", "mypy-boto3-transfer (>=1.34.0,<1.35.0)", "mypy-boto3-translate (>=1.34.0,<1.35.0)", "mypy-boto3-trustedadvisor (>=1.34.0,<1.35.0)", "mypy-boto3-verifiedpermissions (>=1.34.0,<1.35.0)", "mypy-boto3-voice-id (>=1.34.0,<1.35.0)", "mypy-boto3-vpc-lattice (>=1.34.0,<1.35.0)", "mypy-boto3-waf (>=1.34.0,<1.35.0)", "mypy-boto3-waf-regional (>=1.34.0,<1.35.0)", "mypy-boto3-wafv2 (>=1.34.0,<1.35.0)", "mypy-boto3-wellarchitected (>=1.34.0,<1.35.0)", "mypy-boto3-wisdom (>=1.34.0,<1.35.0)", "mypy-boto3-workdocs (>=1.34.0,<1.35.0)", "mypy-boto3-worklink (>=1.34.0,<1.35.0)", "mypy-boto3-workmail (>=1.34.0,<1.35.0)", "mypy-boto3-workmailmessageflow (>=1.34.0,<1.35.0)", "mypy-boto3-workspaces (>=1.34.0,<1.35.0)", "mypy-boto3-workspaces-thin-client (>=1.34.0,<1.35.0)", "mypy-boto3-workspaces-web (>=1.34.0,<1.35.0)", "mypy-boto3-xray (>=1.34.0,<1.35.0)"] -amp = ["mypy-boto3-amp (>=1.34.0,<1.35.0)"] -amplify = ["mypy-boto3-amplify (>=1.34.0,<1.35.0)"] -amplifybackend = ["mypy-boto3-amplifybackend (>=1.34.0,<1.35.0)"] -amplifyuibuilder = ["mypy-boto3-amplifyuibuilder (>=1.34.0,<1.35.0)"] -apigateway = ["mypy-boto3-apigateway (>=1.34.0,<1.35.0)"] -apigatewaymanagementapi = ["mypy-boto3-apigatewaymanagementapi (>=1.34.0,<1.35.0)"] -apigatewayv2 = ["mypy-boto3-apigatewayv2 (>=1.34.0,<1.35.0)"] -appconfig = ["mypy-boto3-appconfig (>=1.34.0,<1.35.0)"] -appconfigdata = ["mypy-boto3-appconfigdata (>=1.34.0,<1.35.0)"] -appfabric = ["mypy-boto3-appfabric (>=1.34.0,<1.35.0)"] -appflow = ["mypy-boto3-appflow (>=1.34.0,<1.35.0)"] -appintegrations = ["mypy-boto3-appintegrations (>=1.34.0,<1.35.0)"] -application-autoscaling = ["mypy-boto3-application-autoscaling (>=1.34.0,<1.35.0)"] -application-insights = ["mypy-boto3-application-insights (>=1.34.0,<1.35.0)"] -application-signals = ["mypy-boto3-application-signals (>=1.34.0,<1.35.0)"] -applicationcostprofiler = ["mypy-boto3-applicationcostprofiler (>=1.34.0,<1.35.0)"] -appmesh = ["mypy-boto3-appmesh (>=1.34.0,<1.35.0)"] -apprunner = ["mypy-boto3-apprunner (>=1.34.0,<1.35.0)"] -appstream = ["mypy-boto3-appstream (>=1.34.0,<1.35.0)"] -appsync = ["mypy-boto3-appsync (>=1.34.0,<1.35.0)"] -apptest = ["mypy-boto3-apptest (>=1.34.0,<1.35.0)"] -arc-zonal-shift = ["mypy-boto3-arc-zonal-shift (>=1.34.0,<1.35.0)"] -artifact = ["mypy-boto3-artifact (>=1.34.0,<1.35.0)"] -athena = ["mypy-boto3-athena (>=1.34.0,<1.35.0)"] -auditmanager = ["mypy-boto3-auditmanager (>=1.34.0,<1.35.0)"] -autoscaling = ["mypy-boto3-autoscaling (>=1.34.0,<1.35.0)"] -autoscaling-plans = ["mypy-boto3-autoscaling-plans (>=1.34.0,<1.35.0)"] -b2bi = ["mypy-boto3-b2bi (>=1.34.0,<1.35.0)"] -backup = ["mypy-boto3-backup (>=1.34.0,<1.35.0)"] -backup-gateway = ["mypy-boto3-backup-gateway (>=1.34.0,<1.35.0)"] -batch = ["mypy-boto3-batch (>=1.34.0,<1.35.0)"] -bcm-data-exports = ["mypy-boto3-bcm-data-exports (>=1.34.0,<1.35.0)"] -bedrock = ["mypy-boto3-bedrock (>=1.34.0,<1.35.0)"] -bedrock-agent = ["mypy-boto3-bedrock-agent (>=1.34.0,<1.35.0)"] -bedrock-agent-runtime = ["mypy-boto3-bedrock-agent-runtime (>=1.34.0,<1.35.0)"] -bedrock-runtime = ["mypy-boto3-bedrock-runtime (>=1.34.0,<1.35.0)"] -billingconductor = ["mypy-boto3-billingconductor (>=1.34.0,<1.35.0)"] -boto3 = ["boto3 (==1.34.162)", "botocore (==1.34.162)"] -braket = ["mypy-boto3-braket (>=1.34.0,<1.35.0)"] -budgets = ["mypy-boto3-budgets (>=1.34.0,<1.35.0)"] -ce = ["mypy-boto3-ce (>=1.34.0,<1.35.0)"] -chatbot = ["mypy-boto3-chatbot (>=1.34.0,<1.35.0)"] -chime = ["mypy-boto3-chime (>=1.34.0,<1.35.0)"] -chime-sdk-identity = ["mypy-boto3-chime-sdk-identity (>=1.34.0,<1.35.0)"] -chime-sdk-media-pipelines = ["mypy-boto3-chime-sdk-media-pipelines (>=1.34.0,<1.35.0)"] -chime-sdk-meetings = ["mypy-boto3-chime-sdk-meetings (>=1.34.0,<1.35.0)"] -chime-sdk-messaging = ["mypy-boto3-chime-sdk-messaging (>=1.34.0,<1.35.0)"] -chime-sdk-voice = ["mypy-boto3-chime-sdk-voice (>=1.34.0,<1.35.0)"] -cleanrooms = ["mypy-boto3-cleanrooms (>=1.34.0,<1.35.0)"] -cleanroomsml = ["mypy-boto3-cleanroomsml (>=1.34.0,<1.35.0)"] -cloud9 = ["mypy-boto3-cloud9 (>=1.34.0,<1.35.0)"] -cloudcontrol = ["mypy-boto3-cloudcontrol (>=1.34.0,<1.35.0)"] -clouddirectory = ["mypy-boto3-clouddirectory (>=1.34.0,<1.35.0)"] -cloudformation = ["mypy-boto3-cloudformation (>=1.34.0,<1.35.0)"] -cloudfront = ["mypy-boto3-cloudfront (>=1.34.0,<1.35.0)"] -cloudfront-keyvaluestore = ["mypy-boto3-cloudfront-keyvaluestore (>=1.34.0,<1.35.0)"] -cloudhsm = ["mypy-boto3-cloudhsm (>=1.34.0,<1.35.0)"] -cloudhsmv2 = ["mypy-boto3-cloudhsmv2 (>=1.34.0,<1.35.0)"] -cloudsearch = ["mypy-boto3-cloudsearch (>=1.34.0,<1.35.0)"] -cloudsearchdomain = ["mypy-boto3-cloudsearchdomain (>=1.34.0,<1.35.0)"] -cloudtrail = ["mypy-boto3-cloudtrail (>=1.34.0,<1.35.0)"] -cloudtrail-data = ["mypy-boto3-cloudtrail-data (>=1.34.0,<1.35.0)"] -cloudwatch = ["mypy-boto3-cloudwatch (>=1.34.0,<1.35.0)"] -codeartifact = ["mypy-boto3-codeartifact (>=1.34.0,<1.35.0)"] -codebuild = ["mypy-boto3-codebuild (>=1.34.0,<1.35.0)"] -codecatalyst = ["mypy-boto3-codecatalyst (>=1.34.0,<1.35.0)"] -codecommit = ["mypy-boto3-codecommit (>=1.34.0,<1.35.0)"] -codeconnections = ["mypy-boto3-codeconnections (>=1.34.0,<1.35.0)"] -codedeploy = ["mypy-boto3-codedeploy (>=1.34.0,<1.35.0)"] -codeguru-reviewer = ["mypy-boto3-codeguru-reviewer (>=1.34.0,<1.35.0)"] -codeguru-security = ["mypy-boto3-codeguru-security (>=1.34.0,<1.35.0)"] -codeguruprofiler = ["mypy-boto3-codeguruprofiler (>=1.34.0,<1.35.0)"] -codepipeline = ["mypy-boto3-codepipeline (>=1.34.0,<1.35.0)"] -codestar = ["mypy-boto3-codestar (>=1.34.0,<1.35.0)"] -codestar-connections = ["mypy-boto3-codestar-connections (>=1.34.0,<1.35.0)"] -codestar-notifications = ["mypy-boto3-codestar-notifications (>=1.34.0,<1.35.0)"] -cognito-identity = ["mypy-boto3-cognito-identity (>=1.34.0,<1.35.0)"] -cognito-idp = ["mypy-boto3-cognito-idp (>=1.34.0,<1.35.0)"] -cognito-sync = ["mypy-boto3-cognito-sync (>=1.34.0,<1.35.0)"] -comprehend = ["mypy-boto3-comprehend (>=1.34.0,<1.35.0)"] -comprehendmedical = ["mypy-boto3-comprehendmedical (>=1.34.0,<1.35.0)"] -compute-optimizer = ["mypy-boto3-compute-optimizer (>=1.34.0,<1.35.0)"] -config = ["mypy-boto3-config (>=1.34.0,<1.35.0)"] -connect = ["mypy-boto3-connect (>=1.34.0,<1.35.0)"] -connect-contact-lens = ["mypy-boto3-connect-contact-lens (>=1.34.0,<1.35.0)"] -connectcampaigns = ["mypy-boto3-connectcampaigns (>=1.34.0,<1.35.0)"] -connectcases = ["mypy-boto3-connectcases (>=1.34.0,<1.35.0)"] -connectparticipant = ["mypy-boto3-connectparticipant (>=1.34.0,<1.35.0)"] -controlcatalog = ["mypy-boto3-controlcatalog (>=1.34.0,<1.35.0)"] -controltower = ["mypy-boto3-controltower (>=1.34.0,<1.35.0)"] -cost-optimization-hub = ["mypy-boto3-cost-optimization-hub (>=1.34.0,<1.35.0)"] -cur = ["mypy-boto3-cur (>=1.34.0,<1.35.0)"] -customer-profiles = ["mypy-boto3-customer-profiles (>=1.34.0,<1.35.0)"] -databrew = ["mypy-boto3-databrew (>=1.34.0,<1.35.0)"] -dataexchange = ["mypy-boto3-dataexchange (>=1.34.0,<1.35.0)"] -datapipeline = ["mypy-boto3-datapipeline (>=1.34.0,<1.35.0)"] -datasync = ["mypy-boto3-datasync (>=1.34.0,<1.35.0)"] -datazone = ["mypy-boto3-datazone (>=1.34.0,<1.35.0)"] -dax = ["mypy-boto3-dax (>=1.34.0,<1.35.0)"] -deadline = ["mypy-boto3-deadline (>=1.34.0,<1.35.0)"] -detective = ["mypy-boto3-detective (>=1.34.0,<1.35.0)"] -devicefarm = ["mypy-boto3-devicefarm (>=1.34.0,<1.35.0)"] -devops-guru = ["mypy-boto3-devops-guru (>=1.34.0,<1.35.0)"] -directconnect = ["mypy-boto3-directconnect (>=1.34.0,<1.35.0)"] -discovery = ["mypy-boto3-discovery (>=1.34.0,<1.35.0)"] -dlm = ["mypy-boto3-dlm (>=1.34.0,<1.35.0)"] -dms = ["mypy-boto3-dms (>=1.34.0,<1.35.0)"] -docdb = ["mypy-boto3-docdb (>=1.34.0,<1.35.0)"] -docdb-elastic = ["mypy-boto3-docdb-elastic (>=1.34.0,<1.35.0)"] -drs = ["mypy-boto3-drs (>=1.34.0,<1.35.0)"] -ds = ["mypy-boto3-ds (>=1.34.0,<1.35.0)"] -dynamodb = ["mypy-boto3-dynamodb (>=1.34.0,<1.35.0)"] -dynamodbstreams = ["mypy-boto3-dynamodbstreams (>=1.34.0,<1.35.0)"] -ebs = ["mypy-boto3-ebs (>=1.34.0,<1.35.0)"] -ec2 = ["mypy-boto3-ec2 (>=1.34.0,<1.35.0)"] -ec2-instance-connect = ["mypy-boto3-ec2-instance-connect (>=1.34.0,<1.35.0)"] -ecr = ["mypy-boto3-ecr (>=1.34.0,<1.35.0)"] -ecr-public = ["mypy-boto3-ecr-public (>=1.34.0,<1.35.0)"] -ecs = ["mypy-boto3-ecs (>=1.34.0,<1.35.0)"] -efs = ["mypy-boto3-efs (>=1.34.0,<1.35.0)"] -eks = ["mypy-boto3-eks (>=1.34.0,<1.35.0)"] -eks-auth = ["mypy-boto3-eks-auth (>=1.34.0,<1.35.0)"] -elastic-inference = ["mypy-boto3-elastic-inference (>=1.34.0,<1.35.0)"] -elasticache = ["mypy-boto3-elasticache (>=1.34.0,<1.35.0)"] -elasticbeanstalk = ["mypy-boto3-elasticbeanstalk (>=1.34.0,<1.35.0)"] -elastictranscoder = ["mypy-boto3-elastictranscoder (>=1.34.0,<1.35.0)"] -elb = ["mypy-boto3-elb (>=1.34.0,<1.35.0)"] -elbv2 = ["mypy-boto3-elbv2 (>=1.34.0,<1.35.0)"] -emr = ["mypy-boto3-emr (>=1.34.0,<1.35.0)"] -emr-containers = ["mypy-boto3-emr-containers (>=1.34.0,<1.35.0)"] -emr-serverless = ["mypy-boto3-emr-serverless (>=1.34.0,<1.35.0)"] -entityresolution = ["mypy-boto3-entityresolution (>=1.34.0,<1.35.0)"] -es = ["mypy-boto3-es (>=1.34.0,<1.35.0)"] -essential = ["mypy-boto3-cloudformation (>=1.34.0,<1.35.0)", "mypy-boto3-dynamodb (>=1.34.0,<1.35.0)", "mypy-boto3-ec2 (>=1.34.0,<1.35.0)", "mypy-boto3-lambda (>=1.34.0,<1.35.0)", "mypy-boto3-rds (>=1.34.0,<1.35.0)", "mypy-boto3-s3 (>=1.34.0,<1.35.0)", "mypy-boto3-sqs (>=1.34.0,<1.35.0)"] -events = ["mypy-boto3-events (>=1.34.0,<1.35.0)"] -evidently = ["mypy-boto3-evidently (>=1.34.0,<1.35.0)"] -finspace = ["mypy-boto3-finspace (>=1.34.0,<1.35.0)"] -finspace-data = ["mypy-boto3-finspace-data (>=1.34.0,<1.35.0)"] -firehose = ["mypy-boto3-firehose (>=1.34.0,<1.35.0)"] -fis = ["mypy-boto3-fis (>=1.34.0,<1.35.0)"] -fms = ["mypy-boto3-fms (>=1.34.0,<1.35.0)"] -forecast = ["mypy-boto3-forecast (>=1.34.0,<1.35.0)"] -forecastquery = ["mypy-boto3-forecastquery (>=1.34.0,<1.35.0)"] -frauddetector = ["mypy-boto3-frauddetector (>=1.34.0,<1.35.0)"] -freetier = ["mypy-boto3-freetier (>=1.34.0,<1.35.0)"] -fsx = ["mypy-boto3-fsx (>=1.34.0,<1.35.0)"] -gamelift = ["mypy-boto3-gamelift (>=1.34.0,<1.35.0)"] -glacier = ["mypy-boto3-glacier (>=1.34.0,<1.35.0)"] -globalaccelerator = ["mypy-boto3-globalaccelerator (>=1.34.0,<1.35.0)"] -glue = ["mypy-boto3-glue (>=1.34.0,<1.35.0)"] -grafana = ["mypy-boto3-grafana (>=1.34.0,<1.35.0)"] -greengrass = ["mypy-boto3-greengrass (>=1.34.0,<1.35.0)"] -greengrassv2 = ["mypy-boto3-greengrassv2 (>=1.34.0,<1.35.0)"] -groundstation = ["mypy-boto3-groundstation (>=1.34.0,<1.35.0)"] -guardduty = ["mypy-boto3-guardduty (>=1.34.0,<1.35.0)"] -health = ["mypy-boto3-health (>=1.34.0,<1.35.0)"] -healthlake = ["mypy-boto3-healthlake (>=1.34.0,<1.35.0)"] -iam = ["mypy-boto3-iam (>=1.34.0,<1.35.0)"] -identitystore = ["mypy-boto3-identitystore (>=1.34.0,<1.35.0)"] -imagebuilder = ["mypy-boto3-imagebuilder (>=1.34.0,<1.35.0)"] -importexport = ["mypy-boto3-importexport (>=1.34.0,<1.35.0)"] -inspector = ["mypy-boto3-inspector (>=1.34.0,<1.35.0)"] -inspector-scan = ["mypy-boto3-inspector-scan (>=1.34.0,<1.35.0)"] -inspector2 = ["mypy-boto3-inspector2 (>=1.34.0,<1.35.0)"] -internetmonitor = ["mypy-boto3-internetmonitor (>=1.34.0,<1.35.0)"] -iot = ["mypy-boto3-iot (>=1.34.0,<1.35.0)"] -iot-data = ["mypy-boto3-iot-data (>=1.34.0,<1.35.0)"] -iot-jobs-data = ["mypy-boto3-iot-jobs-data (>=1.34.0,<1.35.0)"] -iot1click-devices = ["mypy-boto3-iot1click-devices (>=1.34.0,<1.35.0)"] -iot1click-projects = ["mypy-boto3-iot1click-projects (>=1.34.0,<1.35.0)"] -iotanalytics = ["mypy-boto3-iotanalytics (>=1.34.0,<1.35.0)"] -iotdeviceadvisor = ["mypy-boto3-iotdeviceadvisor (>=1.34.0,<1.35.0)"] -iotevents = ["mypy-boto3-iotevents (>=1.34.0,<1.35.0)"] -iotevents-data = ["mypy-boto3-iotevents-data (>=1.34.0,<1.35.0)"] -iotfleethub = ["mypy-boto3-iotfleethub (>=1.34.0,<1.35.0)"] -iotfleetwise = ["mypy-boto3-iotfleetwise (>=1.34.0,<1.35.0)"] -iotsecuretunneling = ["mypy-boto3-iotsecuretunneling (>=1.34.0,<1.35.0)"] -iotsitewise = ["mypy-boto3-iotsitewise (>=1.34.0,<1.35.0)"] -iotthingsgraph = ["mypy-boto3-iotthingsgraph (>=1.34.0,<1.35.0)"] -iottwinmaker = ["mypy-boto3-iottwinmaker (>=1.34.0,<1.35.0)"] -iotwireless = ["mypy-boto3-iotwireless (>=1.34.0,<1.35.0)"] -ivs = ["mypy-boto3-ivs (>=1.34.0,<1.35.0)"] -ivs-realtime = ["mypy-boto3-ivs-realtime (>=1.34.0,<1.35.0)"] -ivschat = ["mypy-boto3-ivschat (>=1.34.0,<1.35.0)"] -kafka = ["mypy-boto3-kafka (>=1.34.0,<1.35.0)"] -kafkaconnect = ["mypy-boto3-kafkaconnect (>=1.34.0,<1.35.0)"] -kendra = ["mypy-boto3-kendra (>=1.34.0,<1.35.0)"] -kendra-ranking = ["mypy-boto3-kendra-ranking (>=1.34.0,<1.35.0)"] -keyspaces = ["mypy-boto3-keyspaces (>=1.34.0,<1.35.0)"] -kinesis = ["mypy-boto3-kinesis (>=1.34.0,<1.35.0)"] -kinesis-video-archived-media = ["mypy-boto3-kinesis-video-archived-media (>=1.34.0,<1.35.0)"] -kinesis-video-media = ["mypy-boto3-kinesis-video-media (>=1.34.0,<1.35.0)"] -kinesis-video-signaling = ["mypy-boto3-kinesis-video-signaling (>=1.34.0,<1.35.0)"] -kinesis-video-webrtc-storage = ["mypy-boto3-kinesis-video-webrtc-storage (>=1.34.0,<1.35.0)"] -kinesisanalytics = ["mypy-boto3-kinesisanalytics (>=1.34.0,<1.35.0)"] -kinesisanalyticsv2 = ["mypy-boto3-kinesisanalyticsv2 (>=1.34.0,<1.35.0)"] -kinesisvideo = ["mypy-boto3-kinesisvideo (>=1.34.0,<1.35.0)"] -kms = ["mypy-boto3-kms (>=1.34.0,<1.35.0)"] -lakeformation = ["mypy-boto3-lakeformation (>=1.34.0,<1.35.0)"] -lambda = ["mypy-boto3-lambda (>=1.34.0,<1.35.0)"] -launch-wizard = ["mypy-boto3-launch-wizard (>=1.34.0,<1.35.0)"] -lex-models = ["mypy-boto3-lex-models (>=1.34.0,<1.35.0)"] -lex-runtime = ["mypy-boto3-lex-runtime (>=1.34.0,<1.35.0)"] -lexv2-models = ["mypy-boto3-lexv2-models (>=1.34.0,<1.35.0)"] -lexv2-runtime = ["mypy-boto3-lexv2-runtime (>=1.34.0,<1.35.0)"] -license-manager = ["mypy-boto3-license-manager (>=1.34.0,<1.35.0)"] -license-manager-linux-subscriptions = ["mypy-boto3-license-manager-linux-subscriptions (>=1.34.0,<1.35.0)"] -license-manager-user-subscriptions = ["mypy-boto3-license-manager-user-subscriptions (>=1.34.0,<1.35.0)"] -lightsail = ["mypy-boto3-lightsail (>=1.34.0,<1.35.0)"] -location = ["mypy-boto3-location (>=1.34.0,<1.35.0)"] -logs = ["mypy-boto3-logs (>=1.34.0,<1.35.0)"] -lookoutequipment = ["mypy-boto3-lookoutequipment (>=1.34.0,<1.35.0)"] -lookoutmetrics = ["mypy-boto3-lookoutmetrics (>=1.34.0,<1.35.0)"] -lookoutvision = ["mypy-boto3-lookoutvision (>=1.34.0,<1.35.0)"] -m2 = ["mypy-boto3-m2 (>=1.34.0,<1.35.0)"] -machinelearning = ["mypy-boto3-machinelearning (>=1.34.0,<1.35.0)"] -macie2 = ["mypy-boto3-macie2 (>=1.34.0,<1.35.0)"] -mailmanager = ["mypy-boto3-mailmanager (>=1.34.0,<1.35.0)"] -managedblockchain = ["mypy-boto3-managedblockchain (>=1.34.0,<1.35.0)"] -managedblockchain-query = ["mypy-boto3-managedblockchain-query (>=1.34.0,<1.35.0)"] -marketplace-agreement = ["mypy-boto3-marketplace-agreement (>=1.34.0,<1.35.0)"] -marketplace-catalog = ["mypy-boto3-marketplace-catalog (>=1.34.0,<1.35.0)"] -marketplace-deployment = ["mypy-boto3-marketplace-deployment (>=1.34.0,<1.35.0)"] -marketplace-entitlement = ["mypy-boto3-marketplace-entitlement (>=1.34.0,<1.35.0)"] -marketplacecommerceanalytics = ["mypy-boto3-marketplacecommerceanalytics (>=1.34.0,<1.35.0)"] -mediaconnect = ["mypy-boto3-mediaconnect (>=1.34.0,<1.35.0)"] -mediaconvert = ["mypy-boto3-mediaconvert (>=1.34.0,<1.35.0)"] -medialive = ["mypy-boto3-medialive (>=1.34.0,<1.35.0)"] -mediapackage = ["mypy-boto3-mediapackage (>=1.34.0,<1.35.0)"] -mediapackage-vod = ["mypy-boto3-mediapackage-vod (>=1.34.0,<1.35.0)"] -mediapackagev2 = ["mypy-boto3-mediapackagev2 (>=1.34.0,<1.35.0)"] -mediastore = ["mypy-boto3-mediastore (>=1.34.0,<1.35.0)"] -mediastore-data = ["mypy-boto3-mediastore-data (>=1.34.0,<1.35.0)"] -mediatailor = ["mypy-boto3-mediatailor (>=1.34.0,<1.35.0)"] -medical-imaging = ["mypy-boto3-medical-imaging (>=1.34.0,<1.35.0)"] -memorydb = ["mypy-boto3-memorydb (>=1.34.0,<1.35.0)"] -meteringmarketplace = ["mypy-boto3-meteringmarketplace (>=1.34.0,<1.35.0)"] -mgh = ["mypy-boto3-mgh (>=1.34.0,<1.35.0)"] -mgn = ["mypy-boto3-mgn (>=1.34.0,<1.35.0)"] -migration-hub-refactor-spaces = ["mypy-boto3-migration-hub-refactor-spaces (>=1.34.0,<1.35.0)"] -migrationhub-config = ["mypy-boto3-migrationhub-config (>=1.34.0,<1.35.0)"] -migrationhuborchestrator = ["mypy-boto3-migrationhuborchestrator (>=1.34.0,<1.35.0)"] -migrationhubstrategy = ["mypy-boto3-migrationhubstrategy (>=1.34.0,<1.35.0)"] -mq = ["mypy-boto3-mq (>=1.34.0,<1.35.0)"] -mturk = ["mypy-boto3-mturk (>=1.34.0,<1.35.0)"] -mwaa = ["mypy-boto3-mwaa (>=1.34.0,<1.35.0)"] -neptune = ["mypy-boto3-neptune (>=1.34.0,<1.35.0)"] -neptune-graph = ["mypy-boto3-neptune-graph (>=1.34.0,<1.35.0)"] -neptunedata = ["mypy-boto3-neptunedata (>=1.34.0,<1.35.0)"] -network-firewall = ["mypy-boto3-network-firewall (>=1.34.0,<1.35.0)"] -networkmanager = ["mypy-boto3-networkmanager (>=1.34.0,<1.35.0)"] -networkmonitor = ["mypy-boto3-networkmonitor (>=1.34.0,<1.35.0)"] -nimble = ["mypy-boto3-nimble (>=1.34.0,<1.35.0)"] -oam = ["mypy-boto3-oam (>=1.34.0,<1.35.0)"] -omics = ["mypy-boto3-omics (>=1.34.0,<1.35.0)"] -opensearch = ["mypy-boto3-opensearch (>=1.34.0,<1.35.0)"] -opensearchserverless = ["mypy-boto3-opensearchserverless (>=1.34.0,<1.35.0)"] -opsworks = ["mypy-boto3-opsworks (>=1.34.0,<1.35.0)"] -opsworkscm = ["mypy-boto3-opsworkscm (>=1.34.0,<1.35.0)"] -organizations = ["mypy-boto3-organizations (>=1.34.0,<1.35.0)"] -osis = ["mypy-boto3-osis (>=1.34.0,<1.35.0)"] -outposts = ["mypy-boto3-outposts (>=1.34.0,<1.35.0)"] -panorama = ["mypy-boto3-panorama (>=1.34.0,<1.35.0)"] -payment-cryptography = ["mypy-boto3-payment-cryptography (>=1.34.0,<1.35.0)"] -payment-cryptography-data = ["mypy-boto3-payment-cryptography-data (>=1.34.0,<1.35.0)"] -pca-connector-ad = ["mypy-boto3-pca-connector-ad (>=1.34.0,<1.35.0)"] -pca-connector-scep = ["mypy-boto3-pca-connector-scep (>=1.34.0,<1.35.0)"] -personalize = ["mypy-boto3-personalize (>=1.34.0,<1.35.0)"] -personalize-events = ["mypy-boto3-personalize-events (>=1.34.0,<1.35.0)"] -personalize-runtime = ["mypy-boto3-personalize-runtime (>=1.34.0,<1.35.0)"] -pi = ["mypy-boto3-pi (>=1.34.0,<1.35.0)"] -pinpoint = ["mypy-boto3-pinpoint (>=1.34.0,<1.35.0)"] -pinpoint-email = ["mypy-boto3-pinpoint-email (>=1.34.0,<1.35.0)"] -pinpoint-sms-voice = ["mypy-boto3-pinpoint-sms-voice (>=1.34.0,<1.35.0)"] -pinpoint-sms-voice-v2 = ["mypy-boto3-pinpoint-sms-voice-v2 (>=1.34.0,<1.35.0)"] -pipes = ["mypy-boto3-pipes (>=1.34.0,<1.35.0)"] -polly = ["mypy-boto3-polly (>=1.34.0,<1.35.0)"] -pricing = ["mypy-boto3-pricing (>=1.34.0,<1.35.0)"] -privatenetworks = ["mypy-boto3-privatenetworks (>=1.34.0,<1.35.0)"] -proton = ["mypy-boto3-proton (>=1.34.0,<1.35.0)"] -qapps = ["mypy-boto3-qapps (>=1.34.0,<1.35.0)"] -qbusiness = ["mypy-boto3-qbusiness (>=1.34.0,<1.35.0)"] -qconnect = ["mypy-boto3-qconnect (>=1.34.0,<1.35.0)"] -qldb = ["mypy-boto3-qldb (>=1.34.0,<1.35.0)"] -qldb-session = ["mypy-boto3-qldb-session (>=1.34.0,<1.35.0)"] -quicksight = ["mypy-boto3-quicksight (>=1.34.0,<1.35.0)"] -ram = ["mypy-boto3-ram (>=1.34.0,<1.35.0)"] -rbin = ["mypy-boto3-rbin (>=1.34.0,<1.35.0)"] -rds = ["mypy-boto3-rds (>=1.34.0,<1.35.0)"] -rds-data = ["mypy-boto3-rds-data (>=1.34.0,<1.35.0)"] -redshift = ["mypy-boto3-redshift (>=1.34.0,<1.35.0)"] -redshift-data = ["mypy-boto3-redshift-data (>=1.34.0,<1.35.0)"] -redshift-serverless = ["mypy-boto3-redshift-serverless (>=1.34.0,<1.35.0)"] -rekognition = ["mypy-boto3-rekognition (>=1.34.0,<1.35.0)"] -repostspace = ["mypy-boto3-repostspace (>=1.34.0,<1.35.0)"] -resiliencehub = ["mypy-boto3-resiliencehub (>=1.34.0,<1.35.0)"] -resource-explorer-2 = ["mypy-boto3-resource-explorer-2 (>=1.34.0,<1.35.0)"] -resource-groups = ["mypy-boto3-resource-groups (>=1.34.0,<1.35.0)"] -resourcegroupstaggingapi = ["mypy-boto3-resourcegroupstaggingapi (>=1.34.0,<1.35.0)"] -robomaker = ["mypy-boto3-robomaker (>=1.34.0,<1.35.0)"] -rolesanywhere = ["mypy-boto3-rolesanywhere (>=1.34.0,<1.35.0)"] -route53 = ["mypy-boto3-route53 (>=1.34.0,<1.35.0)"] -route53-recovery-cluster = ["mypy-boto3-route53-recovery-cluster (>=1.34.0,<1.35.0)"] -route53-recovery-control-config = ["mypy-boto3-route53-recovery-control-config (>=1.34.0,<1.35.0)"] -route53-recovery-readiness = ["mypy-boto3-route53-recovery-readiness (>=1.34.0,<1.35.0)"] -route53domains = ["mypy-boto3-route53domains (>=1.34.0,<1.35.0)"] -route53profiles = ["mypy-boto3-route53profiles (>=1.34.0,<1.35.0)"] -route53resolver = ["mypy-boto3-route53resolver (>=1.34.0,<1.35.0)"] -rum = ["mypy-boto3-rum (>=1.34.0,<1.35.0)"] -s3 = ["mypy-boto3-s3 (>=1.34.0,<1.35.0)"] -s3control = ["mypy-boto3-s3control (>=1.34.0,<1.35.0)"] -s3outposts = ["mypy-boto3-s3outposts (>=1.34.0,<1.35.0)"] -sagemaker = ["mypy-boto3-sagemaker (>=1.34.0,<1.35.0)"] -sagemaker-a2i-runtime = ["mypy-boto3-sagemaker-a2i-runtime (>=1.34.0,<1.35.0)"] -sagemaker-edge = ["mypy-boto3-sagemaker-edge (>=1.34.0,<1.35.0)"] -sagemaker-featurestore-runtime = ["mypy-boto3-sagemaker-featurestore-runtime (>=1.34.0,<1.35.0)"] -sagemaker-geospatial = ["mypy-boto3-sagemaker-geospatial (>=1.34.0,<1.35.0)"] -sagemaker-metrics = ["mypy-boto3-sagemaker-metrics (>=1.34.0,<1.35.0)"] -sagemaker-runtime = ["mypy-boto3-sagemaker-runtime (>=1.34.0,<1.35.0)"] -savingsplans = ["mypy-boto3-savingsplans (>=1.34.0,<1.35.0)"] -scheduler = ["mypy-boto3-scheduler (>=1.34.0,<1.35.0)"] -schemas = ["mypy-boto3-schemas (>=1.34.0,<1.35.0)"] -sdb = ["mypy-boto3-sdb (>=1.34.0,<1.35.0)"] -secretsmanager = ["mypy-boto3-secretsmanager (>=1.34.0,<1.35.0)"] -securityhub = ["mypy-boto3-securityhub (>=1.34.0,<1.35.0)"] -securitylake = ["mypy-boto3-securitylake (>=1.34.0,<1.35.0)"] -serverlessrepo = ["mypy-boto3-serverlessrepo (>=1.34.0,<1.35.0)"] -service-quotas = ["mypy-boto3-service-quotas (>=1.34.0,<1.35.0)"] -servicecatalog = ["mypy-boto3-servicecatalog (>=1.34.0,<1.35.0)"] -servicecatalog-appregistry = ["mypy-boto3-servicecatalog-appregistry (>=1.34.0,<1.35.0)"] -servicediscovery = ["mypy-boto3-servicediscovery (>=1.34.0,<1.35.0)"] -ses = ["mypy-boto3-ses (>=1.34.0,<1.35.0)"] -sesv2 = ["mypy-boto3-sesv2 (>=1.34.0,<1.35.0)"] -shield = ["mypy-boto3-shield (>=1.34.0,<1.35.0)"] -signer = ["mypy-boto3-signer (>=1.34.0,<1.35.0)"] -simspaceweaver = ["mypy-boto3-simspaceweaver (>=1.34.0,<1.35.0)"] -sms = ["mypy-boto3-sms (>=1.34.0,<1.35.0)"] -sms-voice = ["mypy-boto3-sms-voice (>=1.34.0,<1.35.0)"] -snow-device-management = ["mypy-boto3-snow-device-management (>=1.34.0,<1.35.0)"] -snowball = ["mypy-boto3-snowball (>=1.34.0,<1.35.0)"] -sns = ["mypy-boto3-sns (>=1.34.0,<1.35.0)"] -sqs = ["mypy-boto3-sqs (>=1.34.0,<1.35.0)"] -ssm = ["mypy-boto3-ssm (>=1.34.0,<1.35.0)"] -ssm-contacts = ["mypy-boto3-ssm-contacts (>=1.34.0,<1.35.0)"] -ssm-incidents = ["mypy-boto3-ssm-incidents (>=1.34.0,<1.35.0)"] -ssm-quicksetup = ["mypy-boto3-ssm-quicksetup (>=1.34.0,<1.35.0)"] -ssm-sap = ["mypy-boto3-ssm-sap (>=1.34.0,<1.35.0)"] -sso = ["mypy-boto3-sso (>=1.34.0,<1.35.0)"] -sso-admin = ["mypy-boto3-sso-admin (>=1.34.0,<1.35.0)"] -sso-oidc = ["mypy-boto3-sso-oidc (>=1.34.0,<1.35.0)"] -stepfunctions = ["mypy-boto3-stepfunctions (>=1.34.0,<1.35.0)"] -storagegateway = ["mypy-boto3-storagegateway (>=1.34.0,<1.35.0)"] -sts = ["mypy-boto3-sts (>=1.34.0,<1.35.0)"] -supplychain = ["mypy-boto3-supplychain (>=1.34.0,<1.35.0)"] -support = ["mypy-boto3-support (>=1.34.0,<1.35.0)"] -support-app = ["mypy-boto3-support-app (>=1.34.0,<1.35.0)"] -swf = ["mypy-boto3-swf (>=1.34.0,<1.35.0)"] -synthetics = ["mypy-boto3-synthetics (>=1.34.0,<1.35.0)"] -taxsettings = ["mypy-boto3-taxsettings (>=1.34.0,<1.35.0)"] -textract = ["mypy-boto3-textract (>=1.34.0,<1.35.0)"] -timestream-influxdb = ["mypy-boto3-timestream-influxdb (>=1.34.0,<1.35.0)"] -timestream-query = ["mypy-boto3-timestream-query (>=1.34.0,<1.35.0)"] -timestream-write = ["mypy-boto3-timestream-write (>=1.34.0,<1.35.0)"] -tnb = ["mypy-boto3-tnb (>=1.34.0,<1.35.0)"] -transcribe = ["mypy-boto3-transcribe (>=1.34.0,<1.35.0)"] -transfer = ["mypy-boto3-transfer (>=1.34.0,<1.35.0)"] -translate = ["mypy-boto3-translate (>=1.34.0,<1.35.0)"] -trustedadvisor = ["mypy-boto3-trustedadvisor (>=1.34.0,<1.35.0)"] -verifiedpermissions = ["mypy-boto3-verifiedpermissions (>=1.34.0,<1.35.0)"] -voice-id = ["mypy-boto3-voice-id (>=1.34.0,<1.35.0)"] -vpc-lattice = ["mypy-boto3-vpc-lattice (>=1.34.0,<1.35.0)"] -waf = ["mypy-boto3-waf (>=1.34.0,<1.35.0)"] -waf-regional = ["mypy-boto3-waf-regional (>=1.34.0,<1.35.0)"] -wafv2 = ["mypy-boto3-wafv2 (>=1.34.0,<1.35.0)"] -wellarchitected = ["mypy-boto3-wellarchitected (>=1.34.0,<1.35.0)"] -wisdom = ["mypy-boto3-wisdom (>=1.34.0,<1.35.0)"] -workdocs = ["mypy-boto3-workdocs (>=1.34.0,<1.35.0)"] -worklink = ["mypy-boto3-worklink (>=1.34.0,<1.35.0)"] -workmail = ["mypy-boto3-workmail (>=1.34.0,<1.35.0)"] -workmailmessageflow = ["mypy-boto3-workmailmessageflow (>=1.34.0,<1.35.0)"] -workspaces = ["mypy-boto3-workspaces (>=1.34.0,<1.35.0)"] -workspaces-thin-client = ["mypy-boto3-workspaces-thin-client (>=1.34.0,<1.35.0)"] -workspaces-web = ["mypy-boto3-workspaces-web (>=1.34.0,<1.35.0)"] -xray = ["mypy-boto3-xray (>=1.34.0,<1.35.0)"] +accessanalyzer = ["mypy-boto3-accessanalyzer (>=1.42.0,<1.43.0)"] +account = ["mypy-boto3-account (>=1.42.0,<1.43.0)"] +acm = ["mypy-boto3-acm (>=1.42.0,<1.43.0)"] +acm-pca = ["mypy-boto3-acm-pca (>=1.42.0,<1.43.0)"] +aiops = ["mypy-boto3-aiops (>=1.42.0,<1.43.0)"] +all = ["mypy-boto3-accessanalyzer (>=1.42.0,<1.43.0)", "mypy-boto3-account (>=1.42.0,<1.43.0)", "mypy-boto3-acm (>=1.42.0,<1.43.0)", "mypy-boto3-acm-pca (>=1.42.0,<1.43.0)", "mypy-boto3-aiops (>=1.42.0,<1.43.0)", "mypy-boto3-amp (>=1.42.0,<1.43.0)", "mypy-boto3-amplify (>=1.42.0,<1.43.0)", "mypy-boto3-amplifybackend (>=1.42.0,<1.43.0)", "mypy-boto3-amplifyuibuilder (>=1.42.0,<1.43.0)", "mypy-boto3-apigateway (>=1.42.0,<1.43.0)", "mypy-boto3-apigatewaymanagementapi (>=1.42.0,<1.43.0)", "mypy-boto3-apigatewayv2 (>=1.42.0,<1.43.0)", "mypy-boto3-appconfig (>=1.42.0,<1.43.0)", "mypy-boto3-appconfigdata (>=1.42.0,<1.43.0)", "mypy-boto3-appfabric (>=1.42.0,<1.43.0)", "mypy-boto3-appflow (>=1.42.0,<1.43.0)", "mypy-boto3-appintegrations (>=1.42.0,<1.43.0)", "mypy-boto3-application-autoscaling (>=1.42.0,<1.43.0)", "mypy-boto3-application-insights (>=1.42.0,<1.43.0)", "mypy-boto3-application-signals (>=1.42.0,<1.43.0)", "mypy-boto3-applicationcostprofiler (>=1.42.0,<1.43.0)", "mypy-boto3-appmesh (>=1.42.0,<1.43.0)", "mypy-boto3-apprunner (>=1.42.0,<1.43.0)", "mypy-boto3-appstream (>=1.42.0,<1.43.0)", "mypy-boto3-appsync (>=1.42.0,<1.43.0)", "mypy-boto3-arc-region-switch (>=1.42.0,<1.43.0)", "mypy-boto3-arc-zonal-shift (>=1.42.0,<1.43.0)", "mypy-boto3-artifact (>=1.42.0,<1.43.0)", "mypy-boto3-athena (>=1.42.0,<1.43.0)", "mypy-boto3-auditmanager (>=1.42.0,<1.43.0)", "mypy-boto3-autoscaling (>=1.42.0,<1.43.0)", "mypy-boto3-autoscaling-plans (>=1.42.0,<1.43.0)", "mypy-boto3-b2bi (>=1.42.0,<1.43.0)", "mypy-boto3-backup (>=1.42.0,<1.43.0)", "mypy-boto3-backup-gateway (>=1.42.0,<1.43.0)", "mypy-boto3-backupsearch (>=1.42.0,<1.43.0)", "mypy-boto3-batch (>=1.42.0,<1.43.0)", "mypy-boto3-bcm-dashboards (>=1.42.0,<1.43.0)", "mypy-boto3-bcm-data-exports (>=1.42.0,<1.43.0)", "mypy-boto3-bcm-pricing-calculator (>=1.42.0,<1.43.0)", "mypy-boto3-bcm-recommended-actions (>=1.42.0,<1.43.0)", "mypy-boto3-bedrock (>=1.42.0,<1.43.0)", "mypy-boto3-bedrock-agent (>=1.42.0,<1.43.0)", "mypy-boto3-bedrock-agent-runtime (>=1.42.0,<1.43.0)", "mypy-boto3-bedrock-agentcore (>=1.42.0,<1.43.0)", "mypy-boto3-bedrock-agentcore-control (>=1.42.0,<1.43.0)", "mypy-boto3-bedrock-data-automation (>=1.42.0,<1.43.0)", "mypy-boto3-bedrock-data-automation-runtime (>=1.42.0,<1.43.0)", "mypy-boto3-bedrock-runtime (>=1.42.0,<1.43.0)", "mypy-boto3-billing (>=1.42.0,<1.43.0)", "mypy-boto3-billingconductor (>=1.42.0,<1.43.0)", "mypy-boto3-braket (>=1.42.0,<1.43.0)", "mypy-boto3-budgets (>=1.42.0,<1.43.0)", "mypy-boto3-ce (>=1.42.0,<1.43.0)", "mypy-boto3-chatbot (>=1.42.0,<1.43.0)", "mypy-boto3-chime (>=1.42.0,<1.43.0)", "mypy-boto3-chime-sdk-identity (>=1.42.0,<1.43.0)", "mypy-boto3-chime-sdk-media-pipelines (>=1.42.0,<1.43.0)", "mypy-boto3-chime-sdk-meetings (>=1.42.0,<1.43.0)", "mypy-boto3-chime-sdk-messaging (>=1.42.0,<1.43.0)", "mypy-boto3-chime-sdk-voice (>=1.42.0,<1.43.0)", "mypy-boto3-cleanrooms (>=1.42.0,<1.43.0)", "mypy-boto3-cleanroomsml (>=1.42.0,<1.43.0)", "mypy-boto3-cloud9 (>=1.42.0,<1.43.0)", "mypy-boto3-cloudcontrol (>=1.42.0,<1.43.0)", "mypy-boto3-clouddirectory (>=1.42.0,<1.43.0)", "mypy-boto3-cloudformation (>=1.42.0,<1.43.0)", "mypy-boto3-cloudfront (>=1.42.0,<1.43.0)", "mypy-boto3-cloudfront-keyvaluestore (>=1.42.0,<1.43.0)", "mypy-boto3-cloudhsm (>=1.42.0,<1.43.0)", "mypy-boto3-cloudhsmv2 (>=1.42.0,<1.43.0)", "mypy-boto3-cloudsearch (>=1.42.0,<1.43.0)", "mypy-boto3-cloudsearchdomain (>=1.42.0,<1.43.0)", "mypy-boto3-cloudtrail (>=1.42.0,<1.43.0)", "mypy-boto3-cloudtrail-data (>=1.42.0,<1.43.0)", "mypy-boto3-cloudwatch (>=1.42.0,<1.43.0)", "mypy-boto3-codeartifact (>=1.42.0,<1.43.0)", "mypy-boto3-codebuild (>=1.42.0,<1.43.0)", "mypy-boto3-codecatalyst (>=1.42.0,<1.43.0)", "mypy-boto3-codecommit (>=1.42.0,<1.43.0)", "mypy-boto3-codeconnections (>=1.42.0,<1.43.0)", "mypy-boto3-codedeploy (>=1.42.0,<1.43.0)", "mypy-boto3-codeguru-reviewer (>=1.42.0,<1.43.0)", "mypy-boto3-codeguru-security (>=1.42.0,<1.43.0)", "mypy-boto3-codeguruprofiler (>=1.42.0,<1.43.0)", "mypy-boto3-codepipeline (>=1.42.0,<1.43.0)", "mypy-boto3-codestar-connections (>=1.42.0,<1.43.0)", "mypy-boto3-codestar-notifications (>=1.42.0,<1.43.0)", "mypy-boto3-cognito-identity (>=1.42.0,<1.43.0)", "mypy-boto3-cognito-idp (>=1.42.0,<1.43.0)", "mypy-boto3-cognito-sync (>=1.42.0,<1.43.0)", "mypy-boto3-comprehend (>=1.42.0,<1.43.0)", "mypy-boto3-comprehendmedical (>=1.42.0,<1.43.0)", "mypy-boto3-compute-optimizer (>=1.42.0,<1.43.0)", "mypy-boto3-compute-optimizer-automation (>=1.42.0,<1.43.0)", "mypy-boto3-config (>=1.42.0,<1.43.0)", "mypy-boto3-connect (>=1.42.0,<1.43.0)", "mypy-boto3-connect-contact-lens (>=1.42.0,<1.43.0)", "mypy-boto3-connectcampaigns (>=1.42.0,<1.43.0)", "mypy-boto3-connectcampaignsv2 (>=1.42.0,<1.43.0)", "mypy-boto3-connectcases (>=1.42.0,<1.43.0)", "mypy-boto3-connectparticipant (>=1.42.0,<1.43.0)", "mypy-boto3-controlcatalog (>=1.42.0,<1.43.0)", "mypy-boto3-controltower (>=1.42.0,<1.43.0)", "mypy-boto3-cost-optimization-hub (>=1.42.0,<1.43.0)", "mypy-boto3-cur (>=1.42.0,<1.43.0)", "mypy-boto3-customer-profiles (>=1.42.0,<1.43.0)", "mypy-boto3-databrew (>=1.42.0,<1.43.0)", "mypy-boto3-dataexchange (>=1.42.0,<1.43.0)", "mypy-boto3-datapipeline (>=1.42.0,<1.43.0)", "mypy-boto3-datasync (>=1.42.0,<1.43.0)", "mypy-boto3-datazone (>=1.42.0,<1.43.0)", "mypy-boto3-dax (>=1.42.0,<1.43.0)", "mypy-boto3-deadline (>=1.42.0,<1.43.0)", "mypy-boto3-detective (>=1.42.0,<1.43.0)", "mypy-boto3-devicefarm (>=1.42.0,<1.43.0)", "mypy-boto3-devops-guru (>=1.42.0,<1.43.0)", "mypy-boto3-directconnect (>=1.42.0,<1.43.0)", "mypy-boto3-discovery (>=1.42.0,<1.43.0)", "mypy-boto3-dlm (>=1.42.0,<1.43.0)", "mypy-boto3-dms (>=1.42.0,<1.43.0)", "mypy-boto3-docdb (>=1.42.0,<1.43.0)", "mypy-boto3-docdb-elastic (>=1.42.0,<1.43.0)", "mypy-boto3-drs (>=1.42.0,<1.43.0)", "mypy-boto3-ds (>=1.42.0,<1.43.0)", "mypy-boto3-ds-data (>=1.42.0,<1.43.0)", "mypy-boto3-dsql (>=1.42.0,<1.43.0)", "mypy-boto3-dynamodb (>=1.42.0,<1.43.0)", "mypy-boto3-dynamodbstreams (>=1.42.0,<1.43.0)", "mypy-boto3-ebs (>=1.42.0,<1.43.0)", "mypy-boto3-ec2 (>=1.42.0,<1.43.0)", "mypy-boto3-ec2-instance-connect (>=1.42.0,<1.43.0)", "mypy-boto3-ecr (>=1.42.0,<1.43.0)", "mypy-boto3-ecr-public (>=1.42.0,<1.43.0)", "mypy-boto3-ecs (>=1.42.0,<1.43.0)", "mypy-boto3-efs (>=1.42.0,<1.43.0)", "mypy-boto3-eks (>=1.42.0,<1.43.0)", "mypy-boto3-eks-auth (>=1.42.0,<1.43.0)", "mypy-boto3-elasticache (>=1.42.0,<1.43.0)", "mypy-boto3-elasticbeanstalk (>=1.42.0,<1.43.0)", "mypy-boto3-elb (>=1.42.0,<1.43.0)", "mypy-boto3-elbv2 (>=1.42.0,<1.43.0)", "mypy-boto3-emr (>=1.42.0,<1.43.0)", "mypy-boto3-emr-containers (>=1.42.0,<1.43.0)", "mypy-boto3-emr-serverless (>=1.42.0,<1.43.0)", "mypy-boto3-entityresolution (>=1.42.0,<1.43.0)", "mypy-boto3-es (>=1.42.0,<1.43.0)", "mypy-boto3-events (>=1.42.0,<1.43.0)", "mypy-boto3-evidently (>=1.42.0,<1.43.0)", "mypy-boto3-evs (>=1.42.0,<1.43.0)", "mypy-boto3-finspace (>=1.42.0,<1.43.0)", "mypy-boto3-finspace-data (>=1.42.0,<1.43.0)", "mypy-boto3-firehose (>=1.42.0,<1.43.0)", "mypy-boto3-fis (>=1.42.0,<1.43.0)", "mypy-boto3-fms (>=1.42.0,<1.43.0)", "mypy-boto3-forecast (>=1.42.0,<1.43.0)", "mypy-boto3-forecastquery (>=1.42.0,<1.43.0)", "mypy-boto3-frauddetector (>=1.42.0,<1.43.0)", "mypy-boto3-freetier (>=1.42.0,<1.43.0)", "mypy-boto3-fsx (>=1.42.0,<1.43.0)", "mypy-boto3-gamelift (>=1.42.0,<1.43.0)", "mypy-boto3-gameliftstreams (>=1.42.0,<1.43.0)", "mypy-boto3-geo-maps (>=1.42.0,<1.43.0)", "mypy-boto3-geo-places (>=1.42.0,<1.43.0)", "mypy-boto3-geo-routes (>=1.42.0,<1.43.0)", "mypy-boto3-glacier (>=1.42.0,<1.43.0)", "mypy-boto3-globalaccelerator (>=1.42.0,<1.43.0)", "mypy-boto3-glue (>=1.42.0,<1.43.0)", "mypy-boto3-grafana (>=1.42.0,<1.43.0)", "mypy-boto3-greengrass (>=1.42.0,<1.43.0)", "mypy-boto3-greengrassv2 (>=1.42.0,<1.43.0)", "mypy-boto3-groundstation (>=1.42.0,<1.43.0)", "mypy-boto3-guardduty (>=1.42.0,<1.43.0)", "mypy-boto3-health (>=1.42.0,<1.43.0)", "mypy-boto3-healthlake (>=1.42.0,<1.43.0)", "mypy-boto3-iam (>=1.42.0,<1.43.0)", "mypy-boto3-identitystore (>=1.42.0,<1.43.0)", "mypy-boto3-imagebuilder (>=1.42.0,<1.43.0)", "mypy-boto3-importexport (>=1.42.0,<1.43.0)", "mypy-boto3-inspector (>=1.42.0,<1.43.0)", "mypy-boto3-inspector-scan (>=1.42.0,<1.43.0)", "mypy-boto3-inspector2 (>=1.42.0,<1.43.0)", "mypy-boto3-internetmonitor (>=1.42.0,<1.43.0)", "mypy-boto3-invoicing (>=1.42.0,<1.43.0)", "mypy-boto3-iot (>=1.42.0,<1.43.0)", "mypy-boto3-iot-data (>=1.42.0,<1.43.0)", "mypy-boto3-iot-jobs-data (>=1.42.0,<1.43.0)", "mypy-boto3-iot-managed-integrations (>=1.42.0,<1.43.0)", "mypy-boto3-iotanalytics (>=1.42.0,<1.43.0)", "mypy-boto3-iotdeviceadvisor (>=1.42.0,<1.43.0)", "mypy-boto3-iotevents (>=1.42.0,<1.43.0)", "mypy-boto3-iotevents-data (>=1.42.0,<1.43.0)", "mypy-boto3-iotfleetwise (>=1.42.0,<1.43.0)", "mypy-boto3-iotsecuretunneling (>=1.42.0,<1.43.0)", "mypy-boto3-iotsitewise (>=1.42.0,<1.43.0)", "mypy-boto3-iotthingsgraph (>=1.42.0,<1.43.0)", "mypy-boto3-iottwinmaker (>=1.42.0,<1.43.0)", "mypy-boto3-iotwireless (>=1.42.0,<1.43.0)", "mypy-boto3-ivs (>=1.42.0,<1.43.0)", "mypy-boto3-ivs-realtime (>=1.42.0,<1.43.0)", "mypy-boto3-ivschat (>=1.42.0,<1.43.0)", "mypy-boto3-kafka (>=1.42.0,<1.43.0)", "mypy-boto3-kafkaconnect (>=1.42.0,<1.43.0)", "mypy-boto3-kendra (>=1.42.0,<1.43.0)", "mypy-boto3-kendra-ranking (>=1.42.0,<1.43.0)", "mypy-boto3-keyspaces (>=1.42.0,<1.43.0)", "mypy-boto3-keyspacesstreams (>=1.42.0,<1.43.0)", "mypy-boto3-kinesis (>=1.42.0,<1.43.0)", "mypy-boto3-kinesis-video-archived-media (>=1.42.0,<1.43.0)", "mypy-boto3-kinesis-video-media (>=1.42.0,<1.43.0)", "mypy-boto3-kinesis-video-signaling (>=1.42.0,<1.43.0)", "mypy-boto3-kinesis-video-webrtc-storage (>=1.42.0,<1.43.0)", "mypy-boto3-kinesisanalytics (>=1.42.0,<1.43.0)", "mypy-boto3-kinesisanalyticsv2 (>=1.42.0,<1.43.0)", "mypy-boto3-kinesisvideo (>=1.42.0,<1.43.0)", "mypy-boto3-kms (>=1.42.0,<1.43.0)", "mypy-boto3-lakeformation (>=1.42.0,<1.43.0)", "mypy-boto3-lambda (>=1.42.0,<1.43.0)", "mypy-boto3-launch-wizard (>=1.42.0,<1.43.0)", "mypy-boto3-lex-models (>=1.42.0,<1.43.0)", "mypy-boto3-lex-runtime (>=1.42.0,<1.43.0)", "mypy-boto3-lexv2-models (>=1.42.0,<1.43.0)", "mypy-boto3-lexv2-runtime (>=1.42.0,<1.43.0)", "mypy-boto3-license-manager (>=1.42.0,<1.43.0)", "mypy-boto3-license-manager-linux-subscriptions (>=1.42.0,<1.43.0)", "mypy-boto3-license-manager-user-subscriptions (>=1.42.0,<1.43.0)", "mypy-boto3-lightsail (>=1.42.0,<1.43.0)", "mypy-boto3-location (>=1.42.0,<1.43.0)", "mypy-boto3-logs (>=1.42.0,<1.43.0)", "mypy-boto3-lookoutequipment (>=1.42.0,<1.43.0)", "mypy-boto3-m2 (>=1.42.0,<1.43.0)", "mypy-boto3-machinelearning (>=1.42.0,<1.43.0)", "mypy-boto3-macie2 (>=1.42.0,<1.43.0)", "mypy-boto3-mailmanager (>=1.42.0,<1.43.0)", "mypy-boto3-managedblockchain (>=1.42.0,<1.43.0)", "mypy-boto3-managedblockchain-query (>=1.42.0,<1.43.0)", "mypy-boto3-marketplace-agreement (>=1.42.0,<1.43.0)", "mypy-boto3-marketplace-catalog (>=1.42.0,<1.43.0)", "mypy-boto3-marketplace-deployment (>=1.42.0,<1.43.0)", "mypy-boto3-marketplace-entitlement (>=1.42.0,<1.43.0)", "mypy-boto3-marketplace-reporting (>=1.42.0,<1.43.0)", "mypy-boto3-marketplacecommerceanalytics (>=1.42.0,<1.43.0)", "mypy-boto3-mediaconnect (>=1.42.0,<1.43.0)", "mypy-boto3-mediaconvert (>=1.42.0,<1.43.0)", "mypy-boto3-medialive (>=1.42.0,<1.43.0)", "mypy-boto3-mediapackage (>=1.42.0,<1.43.0)", "mypy-boto3-mediapackage-vod (>=1.42.0,<1.43.0)", "mypy-boto3-mediapackagev2 (>=1.42.0,<1.43.0)", "mypy-boto3-mediastore (>=1.42.0,<1.43.0)", "mypy-boto3-mediastore-data (>=1.42.0,<1.43.0)", "mypy-boto3-mediatailor (>=1.42.0,<1.43.0)", "mypy-boto3-medical-imaging (>=1.42.0,<1.43.0)", "mypy-boto3-memorydb (>=1.42.0,<1.43.0)", "mypy-boto3-meteringmarketplace (>=1.42.0,<1.43.0)", "mypy-boto3-mgh (>=1.42.0,<1.43.0)", "mypy-boto3-mgn (>=1.42.0,<1.43.0)", "mypy-boto3-migration-hub-refactor-spaces (>=1.42.0,<1.43.0)", "mypy-boto3-migrationhub-config (>=1.42.0,<1.43.0)", "mypy-boto3-migrationhuborchestrator (>=1.42.0,<1.43.0)", "mypy-boto3-migrationhubstrategy (>=1.42.0,<1.43.0)", "mypy-boto3-mpa (>=1.42.0,<1.43.0)", "mypy-boto3-mq (>=1.42.0,<1.43.0)", "mypy-boto3-mturk (>=1.42.0,<1.43.0)", "mypy-boto3-mwaa (>=1.42.0,<1.43.0)", "mypy-boto3-mwaa-serverless (>=1.42.0,<1.43.0)", "mypy-boto3-neptune (>=1.42.0,<1.43.0)", "mypy-boto3-neptune-graph (>=1.42.0,<1.43.0)", "mypy-boto3-neptunedata (>=1.42.0,<1.43.0)", "mypy-boto3-network-firewall (>=1.42.0,<1.43.0)", "mypy-boto3-networkflowmonitor (>=1.42.0,<1.43.0)", "mypy-boto3-networkmanager (>=1.42.0,<1.43.0)", "mypy-boto3-networkmonitor (>=1.42.0,<1.43.0)", "mypy-boto3-notifications (>=1.42.0,<1.43.0)", "mypy-boto3-notificationscontacts (>=1.42.0,<1.43.0)", "mypy-boto3-nova-act (>=1.42.0,<1.43.0)", "mypy-boto3-oam (>=1.42.0,<1.43.0)", "mypy-boto3-observabilityadmin (>=1.42.0,<1.43.0)", "mypy-boto3-odb (>=1.42.0,<1.43.0)", "mypy-boto3-omics (>=1.42.0,<1.43.0)", "mypy-boto3-opensearch (>=1.42.0,<1.43.0)", "mypy-boto3-opensearchserverless (>=1.42.0,<1.43.0)", "mypy-boto3-organizations (>=1.42.0,<1.43.0)", "mypy-boto3-osis (>=1.42.0,<1.43.0)", "mypy-boto3-outposts (>=1.42.0,<1.43.0)", "mypy-boto3-panorama (>=1.42.0,<1.43.0)", "mypy-boto3-partnercentral-account (>=1.42.0,<1.43.0)", "mypy-boto3-partnercentral-benefits (>=1.42.0,<1.43.0)", "mypy-boto3-partnercentral-channel (>=1.42.0,<1.43.0)", "mypy-boto3-partnercentral-selling (>=1.42.0,<1.43.0)", "mypy-boto3-payment-cryptography (>=1.42.0,<1.43.0)", "mypy-boto3-payment-cryptography-data (>=1.42.0,<1.43.0)", "mypy-boto3-pca-connector-ad (>=1.42.0,<1.43.0)", "mypy-boto3-pca-connector-scep (>=1.42.0,<1.43.0)", "mypy-boto3-pcs (>=1.42.0,<1.43.0)", "mypy-boto3-personalize (>=1.42.0,<1.43.0)", "mypy-boto3-personalize-events (>=1.42.0,<1.43.0)", "mypy-boto3-personalize-runtime (>=1.42.0,<1.43.0)", "mypy-boto3-pi (>=1.42.0,<1.43.0)", "mypy-boto3-pinpoint (>=1.42.0,<1.43.0)", "mypy-boto3-pinpoint-email (>=1.42.0,<1.43.0)", "mypy-boto3-pinpoint-sms-voice (>=1.42.0,<1.43.0)", "mypy-boto3-pinpoint-sms-voice-v2 (>=1.42.0,<1.43.0)", "mypy-boto3-pipes (>=1.42.0,<1.43.0)", "mypy-boto3-polly (>=1.42.0,<1.43.0)", "mypy-boto3-pricing (>=1.42.0,<1.43.0)", "mypy-boto3-proton (>=1.42.0,<1.43.0)", "mypy-boto3-qapps (>=1.42.0,<1.43.0)", "mypy-boto3-qbusiness (>=1.42.0,<1.43.0)", "mypy-boto3-qconnect (>=1.42.0,<1.43.0)", "mypy-boto3-quicksight (>=1.42.0,<1.43.0)", "mypy-boto3-ram (>=1.42.0,<1.43.0)", "mypy-boto3-rbin (>=1.42.0,<1.43.0)", "mypy-boto3-rds (>=1.42.0,<1.43.0)", "mypy-boto3-rds-data (>=1.42.0,<1.43.0)", "mypy-boto3-redshift (>=1.42.0,<1.43.0)", "mypy-boto3-redshift-data (>=1.42.0,<1.43.0)", "mypy-boto3-redshift-serverless (>=1.42.0,<1.43.0)", "mypy-boto3-rekognition (>=1.42.0,<1.43.0)", "mypy-boto3-repostspace (>=1.42.0,<1.43.0)", "mypy-boto3-resiliencehub (>=1.42.0,<1.43.0)", "mypy-boto3-resource-explorer-2 (>=1.42.0,<1.43.0)", "mypy-boto3-resource-groups (>=1.42.0,<1.43.0)", "mypy-boto3-resourcegroupstaggingapi (>=1.42.0,<1.43.0)", "mypy-boto3-rolesanywhere (>=1.42.0,<1.43.0)", "mypy-boto3-route53 (>=1.42.0,<1.43.0)", "mypy-boto3-route53-recovery-cluster (>=1.42.0,<1.43.0)", "mypy-boto3-route53-recovery-control-config (>=1.42.0,<1.43.0)", "mypy-boto3-route53-recovery-readiness (>=1.42.0,<1.43.0)", "mypy-boto3-route53domains (>=1.42.0,<1.43.0)", "mypy-boto3-route53globalresolver (>=1.42.0,<1.43.0)", "mypy-boto3-route53profiles (>=1.42.0,<1.43.0)", "mypy-boto3-route53resolver (>=1.42.0,<1.43.0)", "mypy-boto3-rtbfabric (>=1.42.0,<1.43.0)", "mypy-boto3-rum (>=1.42.0,<1.43.0)", "mypy-boto3-s3 (>=1.42.0,<1.43.0)", "mypy-boto3-s3control (>=1.42.0,<1.43.0)", "mypy-boto3-s3outposts (>=1.42.0,<1.43.0)", "mypy-boto3-s3tables (>=1.42.0,<1.43.0)", "mypy-boto3-s3vectors (>=1.42.0,<1.43.0)", "mypy-boto3-sagemaker (>=1.42.0,<1.43.0)", "mypy-boto3-sagemaker-a2i-runtime (>=1.42.0,<1.43.0)", "mypy-boto3-sagemaker-edge (>=1.42.0,<1.43.0)", "mypy-boto3-sagemaker-featurestore-runtime (>=1.42.0,<1.43.0)", "mypy-boto3-sagemaker-geospatial (>=1.42.0,<1.43.0)", "mypy-boto3-sagemaker-metrics (>=1.42.0,<1.43.0)", "mypy-boto3-sagemaker-runtime (>=1.42.0,<1.43.0)", "mypy-boto3-savingsplans (>=1.42.0,<1.43.0)", "mypy-boto3-scheduler (>=1.42.0,<1.43.0)", "mypy-boto3-schemas (>=1.42.0,<1.43.0)", "mypy-boto3-sdb (>=1.42.0,<1.43.0)", "mypy-boto3-secretsmanager (>=1.42.0,<1.43.0)", "mypy-boto3-security-ir (>=1.42.0,<1.43.0)", "mypy-boto3-securityhub (>=1.42.0,<1.43.0)", "mypy-boto3-securitylake (>=1.42.0,<1.43.0)", "mypy-boto3-serverlessrepo (>=1.42.0,<1.43.0)", "mypy-boto3-service-quotas (>=1.42.0,<1.43.0)", "mypy-boto3-servicecatalog (>=1.42.0,<1.43.0)", "mypy-boto3-servicecatalog-appregistry (>=1.42.0,<1.43.0)", "mypy-boto3-servicediscovery (>=1.42.0,<1.43.0)", "mypy-boto3-ses (>=1.42.0,<1.43.0)", "mypy-boto3-sesv2 (>=1.42.0,<1.43.0)", "mypy-boto3-shield (>=1.42.0,<1.43.0)", "mypy-boto3-signer (>=1.42.0,<1.43.0)", "mypy-boto3-signin (>=1.42.0,<1.43.0)", "mypy-boto3-simspaceweaver (>=1.42.0,<1.43.0)", "mypy-boto3-snow-device-management (>=1.42.0,<1.43.0)", "mypy-boto3-snowball (>=1.42.0,<1.43.0)", "mypy-boto3-sns (>=1.42.0,<1.43.0)", "mypy-boto3-socialmessaging (>=1.42.0,<1.43.0)", "mypy-boto3-sqs (>=1.42.0,<1.43.0)", "mypy-boto3-ssm (>=1.42.0,<1.43.0)", "mypy-boto3-ssm-contacts (>=1.42.0,<1.43.0)", "mypy-boto3-ssm-guiconnect (>=1.42.0,<1.43.0)", "mypy-boto3-ssm-incidents (>=1.42.0,<1.43.0)", "mypy-boto3-ssm-quicksetup (>=1.42.0,<1.43.0)", "mypy-boto3-ssm-sap (>=1.42.0,<1.43.0)", "mypy-boto3-sso (>=1.42.0,<1.43.0)", "mypy-boto3-sso-admin (>=1.42.0,<1.43.0)", "mypy-boto3-sso-oidc (>=1.42.0,<1.43.0)", "mypy-boto3-stepfunctions (>=1.42.0,<1.43.0)", "mypy-boto3-storagegateway (>=1.42.0,<1.43.0)", "mypy-boto3-sts (>=1.42.0,<1.43.0)", "mypy-boto3-supplychain (>=1.42.0,<1.43.0)", "mypy-boto3-support (>=1.42.0,<1.43.0)", "mypy-boto3-support-app (>=1.42.0,<1.43.0)", "mypy-boto3-swf (>=1.42.0,<1.43.0)", "mypy-boto3-synthetics (>=1.42.0,<1.43.0)", "mypy-boto3-taxsettings (>=1.42.0,<1.43.0)", "mypy-boto3-textract (>=1.42.0,<1.43.0)", "mypy-boto3-timestream-influxdb (>=1.42.0,<1.43.0)", "mypy-boto3-timestream-query (>=1.42.0,<1.43.0)", "mypy-boto3-timestream-write (>=1.42.0,<1.43.0)", "mypy-boto3-tnb (>=1.42.0,<1.43.0)", "mypy-boto3-transcribe (>=1.42.0,<1.43.0)", "mypy-boto3-transfer (>=1.42.0,<1.43.0)", "mypy-boto3-translate (>=1.42.0,<1.43.0)", "mypy-boto3-trustedadvisor (>=1.42.0,<1.43.0)", "mypy-boto3-verifiedpermissions (>=1.42.0,<1.43.0)", "mypy-boto3-voice-id (>=1.42.0,<1.43.0)", "mypy-boto3-vpc-lattice (>=1.42.0,<1.43.0)", "mypy-boto3-waf (>=1.42.0,<1.43.0)", "mypy-boto3-waf-regional (>=1.42.0,<1.43.0)", "mypy-boto3-wafv2 (>=1.42.0,<1.43.0)", "mypy-boto3-wellarchitected (>=1.42.0,<1.43.0)", "mypy-boto3-wickr (>=1.42.0,<1.43.0)", "mypy-boto3-wisdom (>=1.42.0,<1.43.0)", "mypy-boto3-workdocs (>=1.42.0,<1.43.0)", "mypy-boto3-workmail (>=1.42.0,<1.43.0)", "mypy-boto3-workmailmessageflow (>=1.42.0,<1.43.0)", "mypy-boto3-workspaces (>=1.42.0,<1.43.0)", "mypy-boto3-workspaces-instances (>=1.42.0,<1.43.0)", "mypy-boto3-workspaces-thin-client (>=1.42.0,<1.43.0)", "mypy-boto3-workspaces-web (>=1.42.0,<1.43.0)", "mypy-boto3-xray (>=1.42.0,<1.43.0)"] +amp = ["mypy-boto3-amp (>=1.42.0,<1.43.0)"] +amplify = ["mypy-boto3-amplify (>=1.42.0,<1.43.0)"] +amplifybackend = ["mypy-boto3-amplifybackend (>=1.42.0,<1.43.0)"] +amplifyuibuilder = ["mypy-boto3-amplifyuibuilder (>=1.42.0,<1.43.0)"] +apigateway = ["mypy-boto3-apigateway (>=1.42.0,<1.43.0)"] +apigatewaymanagementapi = ["mypy-boto3-apigatewaymanagementapi (>=1.42.0,<1.43.0)"] +apigatewayv2 = ["mypy-boto3-apigatewayv2 (>=1.42.0,<1.43.0)"] +appconfig = ["mypy-boto3-appconfig (>=1.42.0,<1.43.0)"] +appconfigdata = ["mypy-boto3-appconfigdata (>=1.42.0,<1.43.0)"] +appfabric = ["mypy-boto3-appfabric (>=1.42.0,<1.43.0)"] +appflow = ["mypy-boto3-appflow (>=1.42.0,<1.43.0)"] +appintegrations = ["mypy-boto3-appintegrations (>=1.42.0,<1.43.0)"] +application-autoscaling = ["mypy-boto3-application-autoscaling (>=1.42.0,<1.43.0)"] +application-insights = ["mypy-boto3-application-insights (>=1.42.0,<1.43.0)"] +application-signals = ["mypy-boto3-application-signals (>=1.42.0,<1.43.0)"] +applicationcostprofiler = ["mypy-boto3-applicationcostprofiler (>=1.42.0,<1.43.0)"] +appmesh = ["mypy-boto3-appmesh (>=1.42.0,<1.43.0)"] +apprunner = ["mypy-boto3-apprunner (>=1.42.0,<1.43.0)"] +appstream = ["mypy-boto3-appstream (>=1.42.0,<1.43.0)"] +appsync = ["mypy-boto3-appsync (>=1.42.0,<1.43.0)"] +arc-region-switch = ["mypy-boto3-arc-region-switch (>=1.42.0,<1.43.0)"] +arc-zonal-shift = ["mypy-boto3-arc-zonal-shift (>=1.42.0,<1.43.0)"] +artifact = ["mypy-boto3-artifact (>=1.42.0,<1.43.0)"] +athena = ["mypy-boto3-athena (>=1.42.0,<1.43.0)"] +auditmanager = ["mypy-boto3-auditmanager (>=1.42.0,<1.43.0)"] +autoscaling = ["mypy-boto3-autoscaling (>=1.42.0,<1.43.0)"] +autoscaling-plans = ["mypy-boto3-autoscaling-plans (>=1.42.0,<1.43.0)"] +b2bi = ["mypy-boto3-b2bi (>=1.42.0,<1.43.0)"] +backup = ["mypy-boto3-backup (>=1.42.0,<1.43.0)"] +backup-gateway = ["mypy-boto3-backup-gateway (>=1.42.0,<1.43.0)"] +backupsearch = ["mypy-boto3-backupsearch (>=1.42.0,<1.43.0)"] +batch = ["mypy-boto3-batch (>=1.42.0,<1.43.0)"] +bcm-dashboards = ["mypy-boto3-bcm-dashboards (>=1.42.0,<1.43.0)"] +bcm-data-exports = ["mypy-boto3-bcm-data-exports (>=1.42.0,<1.43.0)"] +bcm-pricing-calculator = ["mypy-boto3-bcm-pricing-calculator (>=1.42.0,<1.43.0)"] +bcm-recommended-actions = ["mypy-boto3-bcm-recommended-actions (>=1.42.0,<1.43.0)"] +bedrock = ["mypy-boto3-bedrock (>=1.42.0,<1.43.0)"] +bedrock-agent = ["mypy-boto3-bedrock-agent (>=1.42.0,<1.43.0)"] +bedrock-agent-runtime = ["mypy-boto3-bedrock-agent-runtime (>=1.42.0,<1.43.0)"] +bedrock-agentcore = ["mypy-boto3-bedrock-agentcore (>=1.42.0,<1.43.0)"] +bedrock-agentcore-control = ["mypy-boto3-bedrock-agentcore-control (>=1.42.0,<1.43.0)"] +bedrock-data-automation = ["mypy-boto3-bedrock-data-automation (>=1.42.0,<1.43.0)"] +bedrock-data-automation-runtime = ["mypy-boto3-bedrock-data-automation-runtime (>=1.42.0,<1.43.0)"] +bedrock-runtime = ["mypy-boto3-bedrock-runtime (>=1.42.0,<1.43.0)"] +billing = ["mypy-boto3-billing (>=1.42.0,<1.43.0)"] +billingconductor = ["mypy-boto3-billingconductor (>=1.42.0,<1.43.0)"] +boto3 = ["boto3 (==1.42.33)"] +braket = ["mypy-boto3-braket (>=1.42.0,<1.43.0)"] +budgets = ["mypy-boto3-budgets (>=1.42.0,<1.43.0)"] +ce = ["mypy-boto3-ce (>=1.42.0,<1.43.0)"] +chatbot = ["mypy-boto3-chatbot (>=1.42.0,<1.43.0)"] +chime = ["mypy-boto3-chime (>=1.42.0,<1.43.0)"] +chime-sdk-identity = ["mypy-boto3-chime-sdk-identity (>=1.42.0,<1.43.0)"] +chime-sdk-media-pipelines = ["mypy-boto3-chime-sdk-media-pipelines (>=1.42.0,<1.43.0)"] +chime-sdk-meetings = ["mypy-boto3-chime-sdk-meetings (>=1.42.0,<1.43.0)"] +chime-sdk-messaging = ["mypy-boto3-chime-sdk-messaging (>=1.42.0,<1.43.0)"] +chime-sdk-voice = ["mypy-boto3-chime-sdk-voice (>=1.42.0,<1.43.0)"] +cleanrooms = ["mypy-boto3-cleanrooms (>=1.42.0,<1.43.0)"] +cleanroomsml = ["mypy-boto3-cleanroomsml (>=1.42.0,<1.43.0)"] +cloud9 = ["mypy-boto3-cloud9 (>=1.42.0,<1.43.0)"] +cloudcontrol = ["mypy-boto3-cloudcontrol (>=1.42.0,<1.43.0)"] +clouddirectory = ["mypy-boto3-clouddirectory (>=1.42.0,<1.43.0)"] +cloudformation = ["mypy-boto3-cloudformation (>=1.42.0,<1.43.0)"] +cloudfront = ["mypy-boto3-cloudfront (>=1.42.0,<1.43.0)"] +cloudfront-keyvaluestore = ["mypy-boto3-cloudfront-keyvaluestore (>=1.42.0,<1.43.0)"] +cloudhsm = ["mypy-boto3-cloudhsm (>=1.42.0,<1.43.0)"] +cloudhsmv2 = ["mypy-boto3-cloudhsmv2 (>=1.42.0,<1.43.0)"] +cloudsearch = ["mypy-boto3-cloudsearch (>=1.42.0,<1.43.0)"] +cloudsearchdomain = ["mypy-boto3-cloudsearchdomain (>=1.42.0,<1.43.0)"] +cloudtrail = ["mypy-boto3-cloudtrail (>=1.42.0,<1.43.0)"] +cloudtrail-data = ["mypy-boto3-cloudtrail-data (>=1.42.0,<1.43.0)"] +cloudwatch = ["mypy-boto3-cloudwatch (>=1.42.0,<1.43.0)"] +codeartifact = ["mypy-boto3-codeartifact (>=1.42.0,<1.43.0)"] +codebuild = ["mypy-boto3-codebuild (>=1.42.0,<1.43.0)"] +codecatalyst = ["mypy-boto3-codecatalyst (>=1.42.0,<1.43.0)"] +codecommit = ["mypy-boto3-codecommit (>=1.42.0,<1.43.0)"] +codeconnections = ["mypy-boto3-codeconnections (>=1.42.0,<1.43.0)"] +codedeploy = ["mypy-boto3-codedeploy (>=1.42.0,<1.43.0)"] +codeguru-reviewer = ["mypy-boto3-codeguru-reviewer (>=1.42.0,<1.43.0)"] +codeguru-security = ["mypy-boto3-codeguru-security (>=1.42.0,<1.43.0)"] +codeguruprofiler = ["mypy-boto3-codeguruprofiler (>=1.42.0,<1.43.0)"] +codepipeline = ["mypy-boto3-codepipeline (>=1.42.0,<1.43.0)"] +codestar-connections = ["mypy-boto3-codestar-connections (>=1.42.0,<1.43.0)"] +codestar-notifications = ["mypy-boto3-codestar-notifications (>=1.42.0,<1.43.0)"] +cognito-identity = ["mypy-boto3-cognito-identity (>=1.42.0,<1.43.0)"] +cognito-idp = ["mypy-boto3-cognito-idp (>=1.42.0,<1.43.0)"] +cognito-sync = ["mypy-boto3-cognito-sync (>=1.42.0,<1.43.0)"] +comprehend = ["mypy-boto3-comprehend (>=1.42.0,<1.43.0)"] +comprehendmedical = ["mypy-boto3-comprehendmedical (>=1.42.0,<1.43.0)"] +compute-optimizer = ["mypy-boto3-compute-optimizer (>=1.42.0,<1.43.0)"] +compute-optimizer-automation = ["mypy-boto3-compute-optimizer-automation (>=1.42.0,<1.43.0)"] +config = ["mypy-boto3-config (>=1.42.0,<1.43.0)"] +connect = ["mypy-boto3-connect (>=1.42.0,<1.43.0)"] +connect-contact-lens = ["mypy-boto3-connect-contact-lens (>=1.42.0,<1.43.0)"] +connectcampaigns = ["mypy-boto3-connectcampaigns (>=1.42.0,<1.43.0)"] +connectcampaignsv2 = ["mypy-boto3-connectcampaignsv2 (>=1.42.0,<1.43.0)"] +connectcases = ["mypy-boto3-connectcases (>=1.42.0,<1.43.0)"] +connectparticipant = ["mypy-boto3-connectparticipant (>=1.42.0,<1.43.0)"] +controlcatalog = ["mypy-boto3-controlcatalog (>=1.42.0,<1.43.0)"] +controltower = ["mypy-boto3-controltower (>=1.42.0,<1.43.0)"] +cost-optimization-hub = ["mypy-boto3-cost-optimization-hub (>=1.42.0,<1.43.0)"] +cur = ["mypy-boto3-cur (>=1.42.0,<1.43.0)"] +customer-profiles = ["mypy-boto3-customer-profiles (>=1.42.0,<1.43.0)"] +databrew = ["mypy-boto3-databrew (>=1.42.0,<1.43.0)"] +dataexchange = ["mypy-boto3-dataexchange (>=1.42.0,<1.43.0)"] +datapipeline = ["mypy-boto3-datapipeline (>=1.42.0,<1.43.0)"] +datasync = ["mypy-boto3-datasync (>=1.42.0,<1.43.0)"] +datazone = ["mypy-boto3-datazone (>=1.42.0,<1.43.0)"] +dax = ["mypy-boto3-dax (>=1.42.0,<1.43.0)"] +deadline = ["mypy-boto3-deadline (>=1.42.0,<1.43.0)"] +detective = ["mypy-boto3-detective (>=1.42.0,<1.43.0)"] +devicefarm = ["mypy-boto3-devicefarm (>=1.42.0,<1.43.0)"] +devops-guru = ["mypy-boto3-devops-guru (>=1.42.0,<1.43.0)"] +directconnect = ["mypy-boto3-directconnect (>=1.42.0,<1.43.0)"] +discovery = ["mypy-boto3-discovery (>=1.42.0,<1.43.0)"] +dlm = ["mypy-boto3-dlm (>=1.42.0,<1.43.0)"] +dms = ["mypy-boto3-dms (>=1.42.0,<1.43.0)"] +docdb = ["mypy-boto3-docdb (>=1.42.0,<1.43.0)"] +docdb-elastic = ["mypy-boto3-docdb-elastic (>=1.42.0,<1.43.0)"] +drs = ["mypy-boto3-drs (>=1.42.0,<1.43.0)"] +ds = ["mypy-boto3-ds (>=1.42.0,<1.43.0)"] +ds-data = ["mypy-boto3-ds-data (>=1.42.0,<1.43.0)"] +dsql = ["mypy-boto3-dsql (>=1.42.0,<1.43.0)"] +dynamodb = ["mypy-boto3-dynamodb (>=1.42.0,<1.43.0)"] +dynamodbstreams = ["mypy-boto3-dynamodbstreams (>=1.42.0,<1.43.0)"] +ebs = ["mypy-boto3-ebs (>=1.42.0,<1.43.0)"] +ec2 = ["mypy-boto3-ec2 (>=1.42.0,<1.43.0)"] +ec2-instance-connect = ["mypy-boto3-ec2-instance-connect (>=1.42.0,<1.43.0)"] +ecr = ["mypy-boto3-ecr (>=1.42.0,<1.43.0)"] +ecr-public = ["mypy-boto3-ecr-public (>=1.42.0,<1.43.0)"] +ecs = ["mypy-boto3-ecs (>=1.42.0,<1.43.0)"] +efs = ["mypy-boto3-efs (>=1.42.0,<1.43.0)"] +eks = ["mypy-boto3-eks (>=1.42.0,<1.43.0)"] +eks-auth = ["mypy-boto3-eks-auth (>=1.42.0,<1.43.0)"] +elasticache = ["mypy-boto3-elasticache (>=1.42.0,<1.43.0)"] +elasticbeanstalk = ["mypy-boto3-elasticbeanstalk (>=1.42.0,<1.43.0)"] +elb = ["mypy-boto3-elb (>=1.42.0,<1.43.0)"] +elbv2 = ["mypy-boto3-elbv2 (>=1.42.0,<1.43.0)"] +emr = ["mypy-boto3-emr (>=1.42.0,<1.43.0)"] +emr-containers = ["mypy-boto3-emr-containers (>=1.42.0,<1.43.0)"] +emr-serverless = ["mypy-boto3-emr-serverless (>=1.42.0,<1.43.0)"] +entityresolution = ["mypy-boto3-entityresolution (>=1.42.0,<1.43.0)"] +es = ["mypy-boto3-es (>=1.42.0,<1.43.0)"] +essential = ["mypy-boto3-cloudformation (>=1.42.0,<1.43.0)", "mypy-boto3-dynamodb (>=1.42.0,<1.43.0)", "mypy-boto3-ec2 (>=1.42.0,<1.43.0)", "mypy-boto3-lambda (>=1.42.0,<1.43.0)", "mypy-boto3-rds (>=1.42.0,<1.43.0)", "mypy-boto3-s3 (>=1.42.0,<1.43.0)", "mypy-boto3-sqs (>=1.42.0,<1.43.0)"] +events = ["mypy-boto3-events (>=1.42.0,<1.43.0)"] +evidently = ["mypy-boto3-evidently (>=1.42.0,<1.43.0)"] +evs = ["mypy-boto3-evs (>=1.42.0,<1.43.0)"] +finspace = ["mypy-boto3-finspace (>=1.42.0,<1.43.0)"] +finspace-data = ["mypy-boto3-finspace-data (>=1.42.0,<1.43.0)"] +firehose = ["mypy-boto3-firehose (>=1.42.0,<1.43.0)"] +fis = ["mypy-boto3-fis (>=1.42.0,<1.43.0)"] +fms = ["mypy-boto3-fms (>=1.42.0,<1.43.0)"] +forecast = ["mypy-boto3-forecast (>=1.42.0,<1.43.0)"] +forecastquery = ["mypy-boto3-forecastquery (>=1.42.0,<1.43.0)"] +frauddetector = ["mypy-boto3-frauddetector (>=1.42.0,<1.43.0)"] +freetier = ["mypy-boto3-freetier (>=1.42.0,<1.43.0)"] +fsx = ["mypy-boto3-fsx (>=1.42.0,<1.43.0)"] +full = ["boto3-stubs-full (>=1.42.0,<1.43.0)"] +gamelift = ["mypy-boto3-gamelift (>=1.42.0,<1.43.0)"] +gameliftstreams = ["mypy-boto3-gameliftstreams (>=1.42.0,<1.43.0)"] +geo-maps = ["mypy-boto3-geo-maps (>=1.42.0,<1.43.0)"] +geo-places = ["mypy-boto3-geo-places (>=1.42.0,<1.43.0)"] +geo-routes = ["mypy-boto3-geo-routes (>=1.42.0,<1.43.0)"] +glacier = ["mypy-boto3-glacier (>=1.42.0,<1.43.0)"] +globalaccelerator = ["mypy-boto3-globalaccelerator (>=1.42.0,<1.43.0)"] +glue = ["mypy-boto3-glue (>=1.42.0,<1.43.0)"] +grafana = ["mypy-boto3-grafana (>=1.42.0,<1.43.0)"] +greengrass = ["mypy-boto3-greengrass (>=1.42.0,<1.43.0)"] +greengrassv2 = ["mypy-boto3-greengrassv2 (>=1.42.0,<1.43.0)"] +groundstation = ["mypy-boto3-groundstation (>=1.42.0,<1.43.0)"] +guardduty = ["mypy-boto3-guardduty (>=1.42.0,<1.43.0)"] +health = ["mypy-boto3-health (>=1.42.0,<1.43.0)"] +healthlake = ["mypy-boto3-healthlake (>=1.42.0,<1.43.0)"] +iam = ["mypy-boto3-iam (>=1.42.0,<1.43.0)"] +identitystore = ["mypy-boto3-identitystore (>=1.42.0,<1.43.0)"] +imagebuilder = ["mypy-boto3-imagebuilder (>=1.42.0,<1.43.0)"] +importexport = ["mypy-boto3-importexport (>=1.42.0,<1.43.0)"] +inspector = ["mypy-boto3-inspector (>=1.42.0,<1.43.0)"] +inspector-scan = ["mypy-boto3-inspector-scan (>=1.42.0,<1.43.0)"] +inspector2 = ["mypy-boto3-inspector2 (>=1.42.0,<1.43.0)"] +internetmonitor = ["mypy-boto3-internetmonitor (>=1.42.0,<1.43.0)"] +invoicing = ["mypy-boto3-invoicing (>=1.42.0,<1.43.0)"] +iot = ["mypy-boto3-iot (>=1.42.0,<1.43.0)"] +iot-data = ["mypy-boto3-iot-data (>=1.42.0,<1.43.0)"] +iot-jobs-data = ["mypy-boto3-iot-jobs-data (>=1.42.0,<1.43.0)"] +iot-managed-integrations = ["mypy-boto3-iot-managed-integrations (>=1.42.0,<1.43.0)"] +iotanalytics = ["mypy-boto3-iotanalytics (>=1.42.0,<1.43.0)"] +iotdeviceadvisor = ["mypy-boto3-iotdeviceadvisor (>=1.42.0,<1.43.0)"] +iotevents = ["mypy-boto3-iotevents (>=1.42.0,<1.43.0)"] +iotevents-data = ["mypy-boto3-iotevents-data (>=1.42.0,<1.43.0)"] +iotfleetwise = ["mypy-boto3-iotfleetwise (>=1.42.0,<1.43.0)"] +iotsecuretunneling = ["mypy-boto3-iotsecuretunneling (>=1.42.0,<1.43.0)"] +iotsitewise = ["mypy-boto3-iotsitewise (>=1.42.0,<1.43.0)"] +iotthingsgraph = ["mypy-boto3-iotthingsgraph (>=1.42.0,<1.43.0)"] +iottwinmaker = ["mypy-boto3-iottwinmaker (>=1.42.0,<1.43.0)"] +iotwireless = ["mypy-boto3-iotwireless (>=1.42.0,<1.43.0)"] +ivs = ["mypy-boto3-ivs (>=1.42.0,<1.43.0)"] +ivs-realtime = ["mypy-boto3-ivs-realtime (>=1.42.0,<1.43.0)"] +ivschat = ["mypy-boto3-ivschat (>=1.42.0,<1.43.0)"] +kafka = ["mypy-boto3-kafka (>=1.42.0,<1.43.0)"] +kafkaconnect = ["mypy-boto3-kafkaconnect (>=1.42.0,<1.43.0)"] +kendra = ["mypy-boto3-kendra (>=1.42.0,<1.43.0)"] +kendra-ranking = ["mypy-boto3-kendra-ranking (>=1.42.0,<1.43.0)"] +keyspaces = ["mypy-boto3-keyspaces (>=1.42.0,<1.43.0)"] +keyspacesstreams = ["mypy-boto3-keyspacesstreams (>=1.42.0,<1.43.0)"] +kinesis = ["mypy-boto3-kinesis (>=1.42.0,<1.43.0)"] +kinesis-video-archived-media = ["mypy-boto3-kinesis-video-archived-media (>=1.42.0,<1.43.0)"] +kinesis-video-media = ["mypy-boto3-kinesis-video-media (>=1.42.0,<1.43.0)"] +kinesis-video-signaling = ["mypy-boto3-kinesis-video-signaling (>=1.42.0,<1.43.0)"] +kinesis-video-webrtc-storage = ["mypy-boto3-kinesis-video-webrtc-storage (>=1.42.0,<1.43.0)"] +kinesisanalytics = ["mypy-boto3-kinesisanalytics (>=1.42.0,<1.43.0)"] +kinesisanalyticsv2 = ["mypy-boto3-kinesisanalyticsv2 (>=1.42.0,<1.43.0)"] +kinesisvideo = ["mypy-boto3-kinesisvideo (>=1.42.0,<1.43.0)"] +kms = ["mypy-boto3-kms (>=1.42.0,<1.43.0)"] +lakeformation = ["mypy-boto3-lakeformation (>=1.42.0,<1.43.0)"] +lambda = ["mypy-boto3-lambda (>=1.42.0,<1.43.0)"] +launch-wizard = ["mypy-boto3-launch-wizard (>=1.42.0,<1.43.0)"] +lex-models = ["mypy-boto3-lex-models (>=1.42.0,<1.43.0)"] +lex-runtime = ["mypy-boto3-lex-runtime (>=1.42.0,<1.43.0)"] +lexv2-models = ["mypy-boto3-lexv2-models (>=1.42.0,<1.43.0)"] +lexv2-runtime = ["mypy-boto3-lexv2-runtime (>=1.42.0,<1.43.0)"] +license-manager = ["mypy-boto3-license-manager (>=1.42.0,<1.43.0)"] +license-manager-linux-subscriptions = ["mypy-boto3-license-manager-linux-subscriptions (>=1.42.0,<1.43.0)"] +license-manager-user-subscriptions = ["mypy-boto3-license-manager-user-subscriptions (>=1.42.0,<1.43.0)"] +lightsail = ["mypy-boto3-lightsail (>=1.42.0,<1.43.0)"] +location = ["mypy-boto3-location (>=1.42.0,<1.43.0)"] +logs = ["mypy-boto3-logs (>=1.42.0,<1.43.0)"] +lookoutequipment = ["mypy-boto3-lookoutequipment (>=1.42.0,<1.43.0)"] +m2 = ["mypy-boto3-m2 (>=1.42.0,<1.43.0)"] +machinelearning = ["mypy-boto3-machinelearning (>=1.42.0,<1.43.0)"] +macie2 = ["mypy-boto3-macie2 (>=1.42.0,<1.43.0)"] +mailmanager = ["mypy-boto3-mailmanager (>=1.42.0,<1.43.0)"] +managedblockchain = ["mypy-boto3-managedblockchain (>=1.42.0,<1.43.0)"] +managedblockchain-query = ["mypy-boto3-managedblockchain-query (>=1.42.0,<1.43.0)"] +marketplace-agreement = ["mypy-boto3-marketplace-agreement (>=1.42.0,<1.43.0)"] +marketplace-catalog = ["mypy-boto3-marketplace-catalog (>=1.42.0,<1.43.0)"] +marketplace-deployment = ["mypy-boto3-marketplace-deployment (>=1.42.0,<1.43.0)"] +marketplace-entitlement = ["mypy-boto3-marketplace-entitlement (>=1.42.0,<1.43.0)"] +marketplace-reporting = ["mypy-boto3-marketplace-reporting (>=1.42.0,<1.43.0)"] +marketplacecommerceanalytics = ["mypy-boto3-marketplacecommerceanalytics (>=1.42.0,<1.43.0)"] +mediaconnect = ["mypy-boto3-mediaconnect (>=1.42.0,<1.43.0)"] +mediaconvert = ["mypy-boto3-mediaconvert (>=1.42.0,<1.43.0)"] +medialive = ["mypy-boto3-medialive (>=1.42.0,<1.43.0)"] +mediapackage = ["mypy-boto3-mediapackage (>=1.42.0,<1.43.0)"] +mediapackage-vod = ["mypy-boto3-mediapackage-vod (>=1.42.0,<1.43.0)"] +mediapackagev2 = ["mypy-boto3-mediapackagev2 (>=1.42.0,<1.43.0)"] +mediastore = ["mypy-boto3-mediastore (>=1.42.0,<1.43.0)"] +mediastore-data = ["mypy-boto3-mediastore-data (>=1.42.0,<1.43.0)"] +mediatailor = ["mypy-boto3-mediatailor (>=1.42.0,<1.43.0)"] +medical-imaging = ["mypy-boto3-medical-imaging (>=1.42.0,<1.43.0)"] +memorydb = ["mypy-boto3-memorydb (>=1.42.0,<1.43.0)"] +meteringmarketplace = ["mypy-boto3-meteringmarketplace (>=1.42.0,<1.43.0)"] +mgh = ["mypy-boto3-mgh (>=1.42.0,<1.43.0)"] +mgn = ["mypy-boto3-mgn (>=1.42.0,<1.43.0)"] +migration-hub-refactor-spaces = ["mypy-boto3-migration-hub-refactor-spaces (>=1.42.0,<1.43.0)"] +migrationhub-config = ["mypy-boto3-migrationhub-config (>=1.42.0,<1.43.0)"] +migrationhuborchestrator = ["mypy-boto3-migrationhuborchestrator (>=1.42.0,<1.43.0)"] +migrationhubstrategy = ["mypy-boto3-migrationhubstrategy (>=1.42.0,<1.43.0)"] +mpa = ["mypy-boto3-mpa (>=1.42.0,<1.43.0)"] +mq = ["mypy-boto3-mq (>=1.42.0,<1.43.0)"] +mturk = ["mypy-boto3-mturk (>=1.42.0,<1.43.0)"] +mwaa = ["mypy-boto3-mwaa (>=1.42.0,<1.43.0)"] +mwaa-serverless = ["mypy-boto3-mwaa-serverless (>=1.42.0,<1.43.0)"] +neptune = ["mypy-boto3-neptune (>=1.42.0,<1.43.0)"] +neptune-graph = ["mypy-boto3-neptune-graph (>=1.42.0,<1.43.0)"] +neptunedata = ["mypy-boto3-neptunedata (>=1.42.0,<1.43.0)"] +network-firewall = ["mypy-boto3-network-firewall (>=1.42.0,<1.43.0)"] +networkflowmonitor = ["mypy-boto3-networkflowmonitor (>=1.42.0,<1.43.0)"] +networkmanager = ["mypy-boto3-networkmanager (>=1.42.0,<1.43.0)"] +networkmonitor = ["mypy-boto3-networkmonitor (>=1.42.0,<1.43.0)"] +notifications = ["mypy-boto3-notifications (>=1.42.0,<1.43.0)"] +notificationscontacts = ["mypy-boto3-notificationscontacts (>=1.42.0,<1.43.0)"] +nova-act = ["mypy-boto3-nova-act (>=1.42.0,<1.43.0)"] +oam = ["mypy-boto3-oam (>=1.42.0,<1.43.0)"] +observabilityadmin = ["mypy-boto3-observabilityadmin (>=1.42.0,<1.43.0)"] +odb = ["mypy-boto3-odb (>=1.42.0,<1.43.0)"] +omics = ["mypy-boto3-omics (>=1.42.0,<1.43.0)"] +opensearch = ["mypy-boto3-opensearch (>=1.42.0,<1.43.0)"] +opensearchserverless = ["mypy-boto3-opensearchserverless (>=1.42.0,<1.43.0)"] +organizations = ["mypy-boto3-organizations (>=1.42.0,<1.43.0)"] +osis = ["mypy-boto3-osis (>=1.42.0,<1.43.0)"] +outposts = ["mypy-boto3-outposts (>=1.42.0,<1.43.0)"] +panorama = ["mypy-boto3-panorama (>=1.42.0,<1.43.0)"] +partnercentral-account = ["mypy-boto3-partnercentral-account (>=1.42.0,<1.43.0)"] +partnercentral-benefits = ["mypy-boto3-partnercentral-benefits (>=1.42.0,<1.43.0)"] +partnercentral-channel = ["mypy-boto3-partnercentral-channel (>=1.42.0,<1.43.0)"] +partnercentral-selling = ["mypy-boto3-partnercentral-selling (>=1.42.0,<1.43.0)"] +payment-cryptography = ["mypy-boto3-payment-cryptography (>=1.42.0,<1.43.0)"] +payment-cryptography-data = ["mypy-boto3-payment-cryptography-data (>=1.42.0,<1.43.0)"] +pca-connector-ad = ["mypy-boto3-pca-connector-ad (>=1.42.0,<1.43.0)"] +pca-connector-scep = ["mypy-boto3-pca-connector-scep (>=1.42.0,<1.43.0)"] +pcs = ["mypy-boto3-pcs (>=1.42.0,<1.43.0)"] +personalize = ["mypy-boto3-personalize (>=1.42.0,<1.43.0)"] +personalize-events = ["mypy-boto3-personalize-events (>=1.42.0,<1.43.0)"] +personalize-runtime = ["mypy-boto3-personalize-runtime (>=1.42.0,<1.43.0)"] +pi = ["mypy-boto3-pi (>=1.42.0,<1.43.0)"] +pinpoint = ["mypy-boto3-pinpoint (>=1.42.0,<1.43.0)"] +pinpoint-email = ["mypy-boto3-pinpoint-email (>=1.42.0,<1.43.0)"] +pinpoint-sms-voice = ["mypy-boto3-pinpoint-sms-voice (>=1.42.0,<1.43.0)"] +pinpoint-sms-voice-v2 = ["mypy-boto3-pinpoint-sms-voice-v2 (>=1.42.0,<1.43.0)"] +pipes = ["mypy-boto3-pipes (>=1.42.0,<1.43.0)"] +polly = ["mypy-boto3-polly (>=1.42.0,<1.43.0)"] +pricing = ["mypy-boto3-pricing (>=1.42.0,<1.43.0)"] +proton = ["mypy-boto3-proton (>=1.42.0,<1.43.0)"] +qapps = ["mypy-boto3-qapps (>=1.42.0,<1.43.0)"] +qbusiness = ["mypy-boto3-qbusiness (>=1.42.0,<1.43.0)"] +qconnect = ["mypy-boto3-qconnect (>=1.42.0,<1.43.0)"] +quicksight = ["mypy-boto3-quicksight (>=1.42.0,<1.43.0)"] +ram = ["mypy-boto3-ram (>=1.42.0,<1.43.0)"] +rbin = ["mypy-boto3-rbin (>=1.42.0,<1.43.0)"] +rds = ["mypy-boto3-rds (>=1.42.0,<1.43.0)"] +rds-data = ["mypy-boto3-rds-data (>=1.42.0,<1.43.0)"] +redshift = ["mypy-boto3-redshift (>=1.42.0,<1.43.0)"] +redshift-data = ["mypy-boto3-redshift-data (>=1.42.0,<1.43.0)"] +redshift-serverless = ["mypy-boto3-redshift-serverless (>=1.42.0,<1.43.0)"] +rekognition = ["mypy-boto3-rekognition (>=1.42.0,<1.43.0)"] +repostspace = ["mypy-boto3-repostspace (>=1.42.0,<1.43.0)"] +resiliencehub = ["mypy-boto3-resiliencehub (>=1.42.0,<1.43.0)"] +resource-explorer-2 = ["mypy-boto3-resource-explorer-2 (>=1.42.0,<1.43.0)"] +resource-groups = ["mypy-boto3-resource-groups (>=1.42.0,<1.43.0)"] +resourcegroupstaggingapi = ["mypy-boto3-resourcegroupstaggingapi (>=1.42.0,<1.43.0)"] +rolesanywhere = ["mypy-boto3-rolesanywhere (>=1.42.0,<1.43.0)"] +route53 = ["mypy-boto3-route53 (>=1.42.0,<1.43.0)"] +route53-recovery-cluster = ["mypy-boto3-route53-recovery-cluster (>=1.42.0,<1.43.0)"] +route53-recovery-control-config = ["mypy-boto3-route53-recovery-control-config (>=1.42.0,<1.43.0)"] +route53-recovery-readiness = ["mypy-boto3-route53-recovery-readiness (>=1.42.0,<1.43.0)"] +route53domains = ["mypy-boto3-route53domains (>=1.42.0,<1.43.0)"] +route53globalresolver = ["mypy-boto3-route53globalresolver (>=1.42.0,<1.43.0)"] +route53profiles = ["mypy-boto3-route53profiles (>=1.42.0,<1.43.0)"] +route53resolver = ["mypy-boto3-route53resolver (>=1.42.0,<1.43.0)"] +rtbfabric = ["mypy-boto3-rtbfabric (>=1.42.0,<1.43.0)"] +rum = ["mypy-boto3-rum (>=1.42.0,<1.43.0)"] +s3 = ["mypy-boto3-s3 (>=1.42.0,<1.43.0)"] +s3control = ["mypy-boto3-s3control (>=1.42.0,<1.43.0)"] +s3outposts = ["mypy-boto3-s3outposts (>=1.42.0,<1.43.0)"] +s3tables = ["mypy-boto3-s3tables (>=1.42.0,<1.43.0)"] +s3vectors = ["mypy-boto3-s3vectors (>=1.42.0,<1.43.0)"] +sagemaker = ["mypy-boto3-sagemaker (>=1.42.0,<1.43.0)"] +sagemaker-a2i-runtime = ["mypy-boto3-sagemaker-a2i-runtime (>=1.42.0,<1.43.0)"] +sagemaker-edge = ["mypy-boto3-sagemaker-edge (>=1.42.0,<1.43.0)"] +sagemaker-featurestore-runtime = ["mypy-boto3-sagemaker-featurestore-runtime (>=1.42.0,<1.43.0)"] +sagemaker-geospatial = ["mypy-boto3-sagemaker-geospatial (>=1.42.0,<1.43.0)"] +sagemaker-metrics = ["mypy-boto3-sagemaker-metrics (>=1.42.0,<1.43.0)"] +sagemaker-runtime = ["mypy-boto3-sagemaker-runtime (>=1.42.0,<1.43.0)"] +savingsplans = ["mypy-boto3-savingsplans (>=1.42.0,<1.43.0)"] +scheduler = ["mypy-boto3-scheduler (>=1.42.0,<1.43.0)"] +schemas = ["mypy-boto3-schemas (>=1.42.0,<1.43.0)"] +sdb = ["mypy-boto3-sdb (>=1.42.0,<1.43.0)"] +secretsmanager = ["mypy-boto3-secretsmanager (>=1.42.0,<1.43.0)"] +security-ir = ["mypy-boto3-security-ir (>=1.42.0,<1.43.0)"] +securityhub = ["mypy-boto3-securityhub (>=1.42.0,<1.43.0)"] +securitylake = ["mypy-boto3-securitylake (>=1.42.0,<1.43.0)"] +serverlessrepo = ["mypy-boto3-serverlessrepo (>=1.42.0,<1.43.0)"] +service-quotas = ["mypy-boto3-service-quotas (>=1.42.0,<1.43.0)"] +servicecatalog = ["mypy-boto3-servicecatalog (>=1.42.0,<1.43.0)"] +servicecatalog-appregistry = ["mypy-boto3-servicecatalog-appregistry (>=1.42.0,<1.43.0)"] +servicediscovery = ["mypy-boto3-servicediscovery (>=1.42.0,<1.43.0)"] +ses = ["mypy-boto3-ses (>=1.42.0,<1.43.0)"] +sesv2 = ["mypy-boto3-sesv2 (>=1.42.0,<1.43.0)"] +shield = ["mypy-boto3-shield (>=1.42.0,<1.43.0)"] +signer = ["mypy-boto3-signer (>=1.42.0,<1.43.0)"] +signin = ["mypy-boto3-signin (>=1.42.0,<1.43.0)"] +simspaceweaver = ["mypy-boto3-simspaceweaver (>=1.42.0,<1.43.0)"] +snow-device-management = ["mypy-boto3-snow-device-management (>=1.42.0,<1.43.0)"] +snowball = ["mypy-boto3-snowball (>=1.42.0,<1.43.0)"] +sns = ["mypy-boto3-sns (>=1.42.0,<1.43.0)"] +socialmessaging = ["mypy-boto3-socialmessaging (>=1.42.0,<1.43.0)"] +sqs = ["mypy-boto3-sqs (>=1.42.0,<1.43.0)"] +ssm = ["mypy-boto3-ssm (>=1.42.0,<1.43.0)"] +ssm-contacts = ["mypy-boto3-ssm-contacts (>=1.42.0,<1.43.0)"] +ssm-guiconnect = ["mypy-boto3-ssm-guiconnect (>=1.42.0,<1.43.0)"] +ssm-incidents = ["mypy-boto3-ssm-incidents (>=1.42.0,<1.43.0)"] +ssm-quicksetup = ["mypy-boto3-ssm-quicksetup (>=1.42.0,<1.43.0)"] +ssm-sap = ["mypy-boto3-ssm-sap (>=1.42.0,<1.43.0)"] +sso = ["mypy-boto3-sso (>=1.42.0,<1.43.0)"] +sso-admin = ["mypy-boto3-sso-admin (>=1.42.0,<1.43.0)"] +sso-oidc = ["mypy-boto3-sso-oidc (>=1.42.0,<1.43.0)"] +stepfunctions = ["mypy-boto3-stepfunctions (>=1.42.0,<1.43.0)"] +storagegateway = ["mypy-boto3-storagegateway (>=1.42.0,<1.43.0)"] +sts = ["mypy-boto3-sts (>=1.42.0,<1.43.0)"] +supplychain = ["mypy-boto3-supplychain (>=1.42.0,<1.43.0)"] +support = ["mypy-boto3-support (>=1.42.0,<1.43.0)"] +support-app = ["mypy-boto3-support-app (>=1.42.0,<1.43.0)"] +swf = ["mypy-boto3-swf (>=1.42.0,<1.43.0)"] +synthetics = ["mypy-boto3-synthetics (>=1.42.0,<1.43.0)"] +taxsettings = ["mypy-boto3-taxsettings (>=1.42.0,<1.43.0)"] +textract = ["mypy-boto3-textract (>=1.42.0,<1.43.0)"] +timestream-influxdb = ["mypy-boto3-timestream-influxdb (>=1.42.0,<1.43.0)"] +timestream-query = ["mypy-boto3-timestream-query (>=1.42.0,<1.43.0)"] +timestream-write = ["mypy-boto3-timestream-write (>=1.42.0,<1.43.0)"] +tnb = ["mypy-boto3-tnb (>=1.42.0,<1.43.0)"] +transcribe = ["mypy-boto3-transcribe (>=1.42.0,<1.43.0)"] +transfer = ["mypy-boto3-transfer (>=1.42.0,<1.43.0)"] +translate = ["mypy-boto3-translate (>=1.42.0,<1.43.0)"] +trustedadvisor = ["mypy-boto3-trustedadvisor (>=1.42.0,<1.43.0)"] +verifiedpermissions = ["mypy-boto3-verifiedpermissions (>=1.42.0,<1.43.0)"] +voice-id = ["mypy-boto3-voice-id (>=1.42.0,<1.43.0)"] +vpc-lattice = ["mypy-boto3-vpc-lattice (>=1.42.0,<1.43.0)"] +waf = ["mypy-boto3-waf (>=1.42.0,<1.43.0)"] +waf-regional = ["mypy-boto3-waf-regional (>=1.42.0,<1.43.0)"] +wafv2 = ["mypy-boto3-wafv2 (>=1.42.0,<1.43.0)"] +wellarchitected = ["mypy-boto3-wellarchitected (>=1.42.0,<1.43.0)"] +wickr = ["mypy-boto3-wickr (>=1.42.0,<1.43.0)"] +wisdom = ["mypy-boto3-wisdom (>=1.42.0,<1.43.0)"] +workdocs = ["mypy-boto3-workdocs (>=1.42.0,<1.43.0)"] +workmail = ["mypy-boto3-workmail (>=1.42.0,<1.43.0)"] +workmailmessageflow = ["mypy-boto3-workmailmessageflow (>=1.42.0,<1.43.0)"] +workspaces = ["mypy-boto3-workspaces (>=1.42.0,<1.43.0)"] +workspaces-instances = ["mypy-boto3-workspaces-instances (>=1.42.0,<1.43.0)"] +workspaces-thin-client = ["mypy-boto3-workspaces-thin-client (>=1.42.0,<1.43.0)"] +workspaces-web = ["mypy-boto3-workspaces-web (>=1.42.0,<1.43.0)"] +xray = ["mypy-boto3-xray (>=1.42.0,<1.43.0)"] [[package]] name = "botocore" @@ -2496,6 +2526,21 @@ install-types = ["pip"] mypyc = ["setuptools (>=50)"] reports = ["lxml"] +[[package]] +name = "mypy-boto3-s3" +version = "1.42.21" +description = "Type annotations for boto3 S3 1.42.21 service generated with mypy-boto3-builder 8.12.0" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "mypy_boto3_s3-1.42.21-py3-none-any.whl", hash = "sha256:f5b7d1ed718ba5b00f67e95a9a38c6a021159d3071ea235e6cf496e584115ded"}, + {file = "mypy_boto3_s3-1.42.21.tar.gz", hash = "sha256:cab71c918aac7d98c4d742544c722e37d8e7178acb8bc88a0aead7b1035026d2"}, +] + +[package.dependencies] +typing-extensions = {version = "*", markers = "python_version < \"3.12\""} + [[package]] name = "mypy-extensions" version = "1.1.0" @@ -4799,4 +4844,4 @@ server = ["alembic", "alembic-utils", "arq", "authlib", "biocommons", "boto3", " [metadata] lock-version = "2.1" python-versions = "^3.11" -content-hash = "cb94d5f7faedc07aa0e3457fdb0735b6526b2f40f02c6d438cab46b733123fd6" +content-hash = "a92cfae921a52b547c08ab74fd06a60427d5ac28601c68f4ca6d740e2059dfb2" diff --git a/pyproject.toml b/pyproject.toml index f1217384..bb55a121 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,7 @@ watchtower = { version = "~3.2.0", optional = true } optional = true [tool.poetry.group.dev.dependencies] -boto3-stubs = "~1.34.97" +boto3-stubs = { extras = ["s3"], version = "~1.42.33" } mypy = "~1.10.0" pre-commit = "*" jsonschema = "*" diff --git a/settings/.env.template b/settings/.env.template index fbb5b861..a11bbbbb 100644 --- a/settings/.env.template +++ b/settings/.env.template @@ -98,3 +98,12 @@ AWS_REGION_NAME=us-west-2 ATHENA_SCHEMA_NAME=default ATHENA_S3_STAGING_DIR=s3://your-bucket/path/to/staging/ GNOMAD_DATA_VERSION=v4.1 + +#################################################################################################### +# Environment variables for S3 connection +#################################################################################################### + +AWS_ACCESS_KEY_ID=test +AWS_SECRET_ACCESS_KEY=test +S3_ENDPOINT_URL=http://localstack:4566 +UPLOAD_S3_BUCKET_NAME=score-set-csv-uploads-dev \ No newline at end of file diff --git a/src/mavedb/data_providers/services.py b/src/mavedb/data_providers/services.py index eed9b01d..a94c16d6 100644 --- a/src/mavedb/data_providers/services.py +++ b/src/mavedb/data_providers/services.py @@ -1,10 +1,14 @@ import os -from typing import Optional +from typing import TYPE_CHECKING, Optional -from cdot.hgvs.dataproviders import SeqFetcher, ChainedSeqFetcher, FastaSeqFetcher, RESTDataProvider +import boto3 +from cdot.hgvs.dataproviders import ChainedSeqFetcher, FastaSeqFetcher, RESTDataProvider, SeqFetcher from mavedb.lib.mapping import VRSMap +if TYPE_CHECKING: + from mypy_boto3_s3.client import S3Client + GENOMIC_FASTA_FILES = [ "/data/GCF_000001405.39_GRCh38.p13_genomic.fna.gz", "/data/GCF_000001405.25_GRCh37.p13_genomic.fna.gz", @@ -12,6 +16,7 @@ DCD_MAP_URL = os.environ.get("DCD_MAPPING_URL", "http://dcd-mapping:8000") CDOT_URL = os.environ.get("CDOT_URL", "http://cdot-rest:8000") +CSV_UPLOAD_S3_BUCKET_NAME = os.getenv("UPLOAD_S3_BUCKET_NAME", "score-set-csv-uploads-dev") def seqfetcher() -> ChainedSeqFetcher: @@ -24,3 +29,13 @@ def cdot_rest() -> RESTDataProvider: def vrs_mapper(url: Optional[str] = None) -> VRSMap: return VRSMap(DCD_MAP_URL) if not url else VRSMap(url) + + +def s3_client() -> "S3Client": + return boto3.client( + "s3", + endpoint_url=os.getenv("S3_ENDPOINT_URL"), + aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"), + aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"), + region_name=os.getenv("AWS_REGION_NAME", "us-west-2"), + ) diff --git a/src/mavedb/lib/clingen/constants.py b/src/mavedb/lib/clingen/constants.py index 2bc6979b..77a33a53 100644 --- a/src/mavedb/lib/clingen/constants.py +++ b/src/mavedb/lib/clingen/constants.py @@ -17,5 +17,3 @@ LDH_SUBMISSION_ENDPOINT = f"https://genboree.org/mq/brdg/pulsar/{CLIN_GEN_TENANT}/ldh/submissions/{LDH_ENTITY_ENDPOINT}" LDH_ACCESS_ENDPOINT = os.getenv("LDH_ACCESS_ENDPOINT", "https://genboree.org/ldh") LDH_MAVE_ACCESS_ENDPOINT = f"{LDH_ACCESS_ENDPOINT}/{LDH_ENTITY_NAME}/id" - -LINKED_DATA_RETRY_THRESHOLD = 0.95 diff --git a/src/mavedb/lib/exceptions.py b/src/mavedb/lib/exceptions.py index 8734becb..aae550d4 100644 --- a/src/mavedb/lib/exceptions.py +++ b/src/mavedb/lib/exceptions.py @@ -168,6 +168,12 @@ class NonexistentMappingResultsError(ValueError): pass +class NonexistentMappingScoresError(ValueError): + """Raised when score set mapping results do not contain mapping scores""" + + pass + + class NonexistentMappingReferenceError(ValueError): """Raised when score set mapping results do not contain a valid reference sequence""" diff --git a/src/mavedb/routers/score_sets.py b/src/mavedb/routers/score_sets.py index 959f9133..a20f5829 100644 --- a/src/mavedb/routers/score_sets.py +++ b/src/mavedb/routers/score_sets.py @@ -1,3 +1,4 @@ +import io import json import logging import time @@ -20,6 +21,7 @@ from sqlalchemy.orm import Session, contains_eager from mavedb import deps +from mavedb.data_providers.services import CSV_UPLOAD_S3_BUCKET_NAME, s3_client from mavedb.lib.annotation.annotate import ( variant_functional_impact_statement, variant_pathogenicity_evidence, @@ -136,6 +138,37 @@ async def enqueue_variant_creation( variants_to_csv_rows(item.variants, columns=count_columns, namespaced=False) ).replace("NA", np.NaN) + scores_file_to_upload = existing_scores_df if new_scores_df is None else new_scores_df + counts_file_to_upload = existing_counts_df if new_counts_df is None else new_counts_df + + scores_file_key = None + counts_file_key = None + if scores_file_to_upload is not None or counts_file_to_upload is not None: + timestamp = date.today().isoformat() + unique_id = str(int(time.time() * 1000)) + user_id = user_data.user.id + score_set_id = item.id + + s3 = s3_client() + + if scores_file_to_upload is not None: + save_to_logging_context({"num_scores": len(scores_file_to_upload)}) + scores_file_key = f"{score_set_id}/{user_id}/{timestamp}-{unique_id}-scores.csv" + s3.upload_fileobj( + Fileobj=io.BytesIO(scores_file_to_upload.to_csv(index=False).encode("utf-8")), + Bucket=CSV_UPLOAD_S3_BUCKET_NAME, + Key=scores_file_key, + ) + + if counts_file_to_upload is not None: + save_to_logging_context({"num_counts": len(counts_file_to_upload)}) + counts_file_key = f"{score_set_id}/{user_id}/{timestamp}-{unique_id}-counts.csv" + s3.upload_fileobj( + Fileobj=io.BytesIO(counts_file_to_upload.to_csv(index=False).encode("utf-8")), + Bucket=CSV_UPLOAD_S3_BUCKET_NAME, + Key=counts_file_key, + ) + # Await the insertion of this job into the worker queue, not the job itself. # Uses provided score and counts dataframes and metadata files, or falls back to existing data on the score set if not provided. job = await worker.enqueue_job( @@ -143,8 +176,8 @@ async def enqueue_variant_creation( correlation_id_for_context(), item.id, user_data.user.id, - existing_scores_df if new_scores_df is None else new_scores_df, - existing_counts_df if new_counts_df is None else new_counts_df, + scores_file_to_upload, + counts_file_to_upload, item.dataset_columns.get("score_columns_metadata") if new_score_columns_metadata is None else new_score_columns_metadata, diff --git a/src/mavedb/worker/jobs/__init__.py b/src/mavedb/worker/jobs/__init__.py index 15614fd0..a7a86a58 100644 --- a/src/mavedb/worker/jobs/__init__.py +++ b/src/mavedb/worker/jobs/__init__.py @@ -32,14 +32,12 @@ from mavedb.worker.jobs.variant_processing.creation import create_variants_for_score_set from mavedb.worker.jobs.variant_processing.mapping import ( map_variants_for_score_set, - variant_mapper_manager, ) __all__ = [ # Variant processing jobs "create_variants_for_score_set", "map_variants_for_score_set", - "variant_mapper_manager", # External service integration jobs "link_clingen_variants", "submit_score_set_mappings_to_car", diff --git a/src/mavedb/worker/jobs/data_management/py.typed b/src/mavedb/worker/jobs/data_management/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/mavedb/worker/jobs/data_management/views.py b/src/mavedb/worker/jobs/data_management/views.py index a6ddb2d6..24e5fac8 100644 --- a/src/mavedb/worker/jobs/data_management/views.py +++ b/src/mavedb/worker/jobs/data_management/views.py @@ -10,25 +10,105 @@ from mavedb.db.view import refresh_all_mat_views from mavedb.models.published_variant import PublishedVariantsMV -from mavedb.worker.jobs.utils.job_state import setup_job_state +from mavedb.worker.jobs.utils.setup import validate_job_params +from mavedb.worker.lib.decorators.job_guarantee import with_guaranteed_job_run_record +from mavedb.worker.lib.decorators.job_management import with_job_management +from mavedb.worker.lib.decorators.pipeline_management import with_pipeline_management +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.types import JobResultData logger = logging.getLogger(__name__) # TODO#405: Refresh materialized views within an executor. -async def refresh_materialized_views(ctx: dict): - logging_context = setup_job_state(ctx, None, None, None) - logger.debug(msg="Began refresh materialized views.", extra=logging_context) - refresh_all_mat_views(ctx["db"]) - ctx["db"].commit() - logger.debug(msg="Done refreshing materialized views.", extra=logging_context) - return {"success": True} - - -async def refresh_published_variants_view(ctx: dict, correlation_id: str): - logging_context = setup_job_state(ctx, None, None, correlation_id) - logger.debug(msg="Began refresh of published variants materialized view.", extra=logging_context) - PublishedVariantsMV.refresh(ctx["db"]) - ctx["db"].commit() - logger.debug(msg="Done refreshing published variants materialized view.", extra=logging_context) - return {"success": True} +@with_guaranteed_job_run_record("cron_job") +@with_job_management +async def refresh_materialized_views(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData: + """Refresh all materialized views in the database. + + This job refreshes all materialized views to ensure that they are up-to-date + with the latest data. It is typically run as a scheduled cron job and meant + to be invoked indirectly via a job queue system. + + Args: + ctx (dict): The job context dictionary. + job_id (int): The ID of the job run. + job_manager (JobManager): Manager for job lifecycle and DB operations. + + Side Effects: + - Refreshes all materialized views in the database. + + Returns: + dict: Result indicating success and any exception details + """ + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "refresh_materialized_views", + "resource": "all_materialized_views", + "correlation_id": None, + } + ) + job_manager.update_progress(0, 100, "Starting refresh of all materialized views.") + logger.debug(msg="Began refresh of all materialized views.", extra=job_manager.logging_context()) + + # Do refresh + refresh_all_mat_views(job_manager.db) + job_manager.db.commit() + + # Finalize job state + job_manager.update_progress(100, 100, "Completed refresh of all materialized views.") + logger.debug(msg="Done refreshing materialized views.", extra=job_manager.logging_context()) + + return {"status": "ok", "data": {}, "exception_details": None} + + +@with_pipeline_management +async def refresh_published_variants_view(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData: + """Refresh the published variants materialized view. + + This job refreshes the PublishedVariantsMV materialized view to ensure that it + is up-to-date with the latest data. It is meant to be invoked as part of a job queue system. + + Args: + ctx (dict): The job context dictionary. + job_id (int): The ID of the job run. + job_manager (JobManager): Manager for job lifecycle and DB operations. + + Side Effects: + - Refreshes the PublishedVariantsMV materialized view in the database. + + Returns: + dict: Result indicating success and any exception details + """ + # Get the job definition we are working on + job = job_manager.get_job() + + _job_required_params = ["correlation_id"] + validate_job_params(_job_required_params, job) + + # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. + correlation_id = job.job_params["correlation_id"] # type: ignore + + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "refresh_published_variants_view", + "resource": "published_variants_materialized_view", + "correlation_id": correlation_id, + } + ) + job_manager.update_progress(0, 100, "Starting refresh of published variants materialized view.") + logger.info(msg="Started refresh of published variants materialized view", extra=job_manager.logging_context()) + + # Do refresh + PublishedVariantsMV.refresh(job_manager.db) + job_manager.db.commit() + + # Finalize job state + job_manager.update_progress(100, 100, "Completed refresh of published variants materialized view.") + logger.debug(msg="Done refreshing published variants materialized view.", extra=job_manager.logging_context()) + + return {"status": "ok", "data": {}, "exception_details": None} diff --git a/src/mavedb/worker/jobs/external_services/clingen.py b/src/mavedb/worker/jobs/external_services/clingen.py index 06a7c53d..56b7a5f9 100644 --- a/src/mavedb/worker/jobs/external_services/clingen.py +++ b/src/mavedb/worker/jobs/external_services/clingen.py @@ -12,17 +12,13 @@ import asyncio import functools import logging -from datetime import timedelta -from arq import ArqRedis from sqlalchemy import select -from sqlalchemy.orm import Session from mavedb.lib.clingen.constants import ( CAR_SUBMISSION_ENDPOINT, DEFAULT_LDH_SUBMISSION_BATCH_SIZE, LDH_SUBMISSION_ENDPOINT, - LINKED_DATA_RETRY_THRESHOLD, ) from mavedb.lib.clingen.content_constructors import construct_ldh_submission from mavedb.lib.clingen.services import ( @@ -32,606 +28,388 @@ get_allele_registry_associations, get_clingen_variation, ) -from mavedb.lib.exceptions import LinkingEnqueueError, SubmissionEnqueueError -from mavedb.lib.logging.context import format_raised_exception_info_as_dict -from mavedb.lib.slack import send_slack_error, send_slack_message from mavedb.lib.variants import get_hgvs_from_post_mapped from mavedb.models.mapped_variant import MappedVariant from mavedb.models.score_set import ScoreSet from mavedb.models.variant import Variant -from mavedb.worker.jobs.utils.constants import ENQUEUE_BACKOFF_ATTEMPT_LIMIT, LINKING_BACKOFF_IN_SECONDS -from mavedb.worker.jobs.utils.job_state import setup_job_state -from mavedb.worker.jobs.utils.retry import enqueue_job_with_backoff +from mavedb.worker.jobs.utils.setup import validate_job_params +from mavedb.worker.lib.decorators.pipeline_management import with_pipeline_management +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.types import JobResultData logger = logging.getLogger(__name__) -async def submit_score_set_mappings_to_car(ctx: dict, correlation_id: str, score_set_id: int): - logging_context = {} - score_set = None - text = "Could not submit mappings to ClinGen Allele Registry for score set %s. Mappings for this score set should be submitted manually." - try: - db: Session = ctx["db"] - redis: ArqRedis = ctx["redis"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() +@with_pipeline_management +async def submit_score_set_mappings_to_car(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData: + """ + Submit mapped variants for a score set to the ClinGen Allele Registry (CAR). - logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) - logger.info(msg="Started CAR mapped resource submission", extra=logging_context) + This job registers mapped variants with CAR, assigns ClinGen Allele IDs (CAIDs), + and updates the database with the results. Progress is tracked throughout the submission. - submission_urn = score_set.urn - assert submission_urn, "A valid URN is needed to submit CAR objects for this score set." + Required job_params in the JobRun: + - score_set_id (int): ID of the ScoreSet to process + - correlation_id (str): Correlation ID for tracking - logging_context["current_car_submission_resource"] = submission_urn - logger.debug(msg="Fetched score set metadata for CAR mapped resource submission.", extra=logging_context) + Args: + ctx (dict): Worker context containing DB and Redis connections + job_manager (JobManager): Manager for job lifecycle and DB operations - except Exception as e: - send_slack_error(e) - if score_set: - send_slack_message(text=text % score_set.urn) - else: - send_slack_message(text=text % score_set_id) + Side Effects: + - Updates MappedVariant records with ClinGen Allele IDs + - Submits data to ClinGen Allele Registry - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="CAR mapped resource submission encountered an unexpected error during setup. This job will not be retried.", - extra=logging_context, - ) + Returns: + dict: Result indicating success and any exception details + """ + # Get the job definition we are working on + job = job_manager.get_job() - return {"success": False, "retried": False, "enqueued_job": None} + _job_required_params = ["score_set_id", "correlation_id"] + validate_job_params(_job_required_params, job) - try: - variant_post_mapped_objects = db.execute( - select(MappedVariant.id, MappedVariant.post_mapped) - .join(Variant) - .join(ScoreSet) - .where(ScoreSet.urn == score_set.urn) - .where(MappedVariant.post_mapped.is_not(None)) - .where(MappedVariant.current.is_(True)) - ).all() + # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. + score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one() # type: ignore + correlation_id = job.job_params["correlation_id"] # type: ignore - if not variant_post_mapped_objects: - logger.warning( - msg="No current mapped variants with post mapped metadata were found for this score set. Skipping CAR submission.", - extra=logging_context, - ) - return {"success": True, "retried": False, "enqueued_job": None} - - variant_post_mapped_hgvs: dict[str, list[int]] = {} - for mapped_variant_id, post_mapped in variant_post_mapped_objects: - hgvs_for_post_mapped = get_hgvs_from_post_mapped(post_mapped) - - if not hgvs_for_post_mapped: - logger.warning( - msg=f"Could not construct a valid HGVS string for mapped variant {mapped_variant_id}. Skipping submission of this variant.", - extra=logging_context, - ) - continue - - if hgvs_for_post_mapped in variant_post_mapped_hgvs: - variant_post_mapped_hgvs[hgvs_for_post_mapped].append(mapped_variant_id) - else: - variant_post_mapped_hgvs[hgvs_for_post_mapped] = [mapped_variant_id] - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error while attempting to construct post mapped HGVS strings. This job will not be retried.", - extra=logging_context, + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "submit_score_set_mappings_to_car", + "resource": score_set.urn, + "correlation_id": correlation_id, + } + ) + job_manager.update_progress(0, 100, "Starting CAR mapped resource submission.") + logger.info(msg="Started CAR mapped resource submission", extra=job_manager.logging_context()) + + # Fetch mapped variants with post-mapped data for the score set + variant_post_mapped_objects = job_manager.db.execute( + select(MappedVariant.id, MappedVariant.post_mapped) + .join(Variant) + .join(ScoreSet) + .where(ScoreSet.urn == score_set.urn) + .where(MappedVariant.post_mapped.is_not(None)) + .where(MappedVariant.current.is_(True)) + ).all() + + # Track total variants to submit + job_manager.save_to_context({"total_variants_to_submit_car": len(variant_post_mapped_objects)}) + if not variant_post_mapped_objects: + job_manager.update_progress(100, 100, "No mapped variants to submit to CAR. Skipped submission.") + logger.warning( + msg="No current mapped variants with post mapped metadata were found for this score set. Skipping CAR submission.", + extra=job_manager.logging_context(), ) + return {"status": "ok", "data": {}, "exception_details": None} + job_manager.update_progress( + 10, 100, f"Preparing {len(variant_post_mapped_objects)} mapped variants for CAR submission." + ) - return {"success": False, "retried": False, "enqueued_job": None} + # Build HGVS strings for submission + variant_post_mapped_hgvs: dict[str, list[int]] = {} + for mapped_variant_id, post_mapped in variant_post_mapped_objects: + hgvs_for_post_mapped = get_hgvs_from_post_mapped(post_mapped) - try: - if not CAR_SUBMISSION_ENDPOINT: + if not hgvs_for_post_mapped: logger.warning( - msg="ClinGen Allele Registry submission is disabled (no submission endpoint), skipping submission of mapped variants to CAR.", - extra=logging_context, + msg=f"Could not construct a valid HGVS string for mapped variant {mapped_variant_id}. Skipping submission of this variant.", + extra=job_manager.logging_context(), ) - return {"success": False, "retried": False, "enqueued_job": None} - - car_service = ClinGenAlleleRegistryService(url=CAR_SUBMISSION_ENDPOINT) - registered_alleles = car_service.dispatch_submissions(list(variant_post_mapped_hgvs.keys())) - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error while attempting to authenticate to the LDH. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - linked_alleles = get_allele_registry_associations(list(variant_post_mapped_hgvs.keys()), registered_alleles) - for hgvs_string, caid in linked_alleles.items(): - mapped_variant_ids = variant_post_mapped_hgvs[hgvs_string] - mapped_variants = db.scalars(select(MappedVariant).where(MappedVariant.id.in_(mapped_variant_ids))).all() - - for mapped_variant in mapped_variants: - mapped_variant.clingen_allele_id = caid - db.add(mapped_variant) - - db.commit() - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error while attempting to authenticate to the LDH. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - new_job_id = None - try: - new_job = await redis.enqueue_job( - "submit_score_set_mappings_to_ldh", - correlation_id, - score_set.id, - ) - - if new_job: - new_job_id = new_job.job_id - - logging_context["submit_clingen_ldh_variants_job_id"] = new_job_id - logger.info(msg="Queued a new ClinGen submission job.", extra=logging_context) + continue + if hgvs_for_post_mapped in variant_post_mapped_hgvs: + variant_post_mapped_hgvs[hgvs_for_post_mapped].append(mapped_variant_id) else: - raise SubmissionEnqueueError() - - except Exception as e: - send_slack_error(e) - send_slack_message( - f"Could not submit mappings to LDH for score set {score_set.urn}. Mappings for this score set should be submitted manually." - ) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="Mapped variant ClinGen submission encountered an unexpected error while attempting to enqueue a submission job. This job will not be retried.", - extra=logging_context, + variant_post_mapped_hgvs[hgvs_for_post_mapped] = [mapped_variant_id] + job_manager.save_to_context({"unique_variants_to_submit_car": len(variant_post_mapped_hgvs)}) + job_manager.update_progress(15, 100, "Submitting mapped variants to CAR.") + + # Check for CAR submission endpoint + if not CAR_SUBMISSION_ENDPOINT: + job_manager.update_progress(100, 100, "CAR submission endpoint not configured. Skipping submission.") + logger.warning( + msg="ClinGen Allele Registry submission is disabled (no submission endpoint), skipping submission of mapped variants to CAR.", + extra=job_manager.logging_context(), ) + raise ValueError("ClinGen Allele Registry submission endpoint is not configured.") + + # Do submission + car_service = ClinGenAlleleRegistryService(url=CAR_SUBMISSION_ENDPOINT) + registered_alleles = car_service.dispatch_submissions(list(variant_post_mapped_hgvs.keys())) + job_manager.update_progress(50, 100, "Processing registered alleles from CAR.") + + # Process registered alleles and update mapped variants + linked_alleles = get_allele_registry_associations(list(variant_post_mapped_hgvs.keys()), registered_alleles) + processed = 0 + total = len(linked_alleles) + for hgvs_string, caid in linked_alleles.items(): + mapped_variant_ids = variant_post_mapped_hgvs[hgvs_string] + mapped_variants = job_manager.db.scalars( + select(MappedVariant).where(MappedVariant.id.in_(mapped_variant_ids)) + ).all() - return {"success": False, "retried": False, "enqueued_job": new_job_id} - - ctx["state"][ctx["job_id"]] = logging_context.copy() - return {"success": True, "retried": False, "enqueued_job": new_job_id} - - -async def submit_score_set_mappings_to_ldh(ctx: dict, correlation_id: str, score_set_id: int): - logging_context = {} - score_set = None - text = ( - "Could not submit mappings to LDH for score set %s. Mappings for this score set should be submitted manually." + # TODO: Track annotation progress. + for mapped_variant in mapped_variants: + mapped_variant.clingen_allele_id = caid + job_manager.db.add(mapped_variant) + processed += 1 + + # Calculate progress: 50% + (processed/total_mapped)*50, rounded to nearest 5% + if total % 20 == 0 or processed == total: + progress = 50 + round((processed / total) * 50 / 5) * 5 + job_manager.update_progress(progress, 100, f"Processed {processed} of {total} registered alleles.") + + # Finalize progress + job_manager.update_progress(100, 100, "Completed CAR mapped resource submission.") + job_manager.db.commit() + logger.info(msg="Completed CAR mapped resource submission", extra=job_manager.logging_context()) + return {"status": "ok", "data": {}, "exception_details": None} + + +@with_pipeline_management +async def submit_score_set_mappings_to_ldh(ctx: dict, job_manager: JobManager) -> JobResultData: + """ + Submit mapped variants for a score set to the ClinGen Linked Data Hub (LDH). + + This job submits mapped variant data to LDH for a given score set, handling authentication, + submission batching, and error reporting. Progress and errors are logged and reported to Slack. + + Required job_params in the JobRun: + - score_set_id (int): ID of the ScoreSet to process + - correlation_id (str): Correlation ID for tracking + + Args: + ctx (dict): Worker context containing DB and Redis connections + job_manager (JobManager): Manager for job lifecycle and DB operations + + Side Effects: + - Submits data to ClinGen Linked Data Hub + + Returns: + dict: Result indicating success and any exception details + """ + # Get the job definition we are working on + job = job_manager.get_job() + + _job_required_params = ["score_set_id", "correlation_id"] + validate_job_params(_job_required_params, job) + + # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. + score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one() # type: ignore + correlation_id = job.job_params["correlation_id"] # type: ignore + + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "submit_score_set_mappings_to_ldh", + "resource": score_set.urn, + "correlation_id": correlation_id, + } ) - try: - db: Session = ctx["db"] - redis: ArqRedis = ctx["redis"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() - - logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) - logger.info(msg="Started LDH mapped resource submission", extra=logging_context) - - submission_urn = score_set.urn - assert submission_urn, "A valid URN is needed to submit LDH objects for this score set." - - logging_context["current_ldh_submission_resource"] = submission_urn - logger.debug(msg="Fetched score set metadata for ldh mapped resource submission.", extra=logging_context) - - except Exception as e: - send_slack_error(e) - if score_set: - send_slack_message(text=text % score_set.urn) - else: - send_slack_message(text=text % score_set_id) - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error during setup. This job will not be retried.", - extra=logging_context, + job_manager.update_progress(0, 100, "Starting LDH mapped resource submission.") + logger.info(msg="Started LDH mapped resource submission", extra=job_manager.logging_context()) + + # Connect to LDH service + ldh_service = ClinGenLdhService(url=LDH_SUBMISSION_ENDPOINT) + ldh_service.authenticate() + + # Fetch mapped variants with post-mapped data for the score set + variant_objects = job_manager.db.execute( + select(Variant, MappedVariant) + .join(MappedVariant) + .join(ScoreSet) + .where(ScoreSet.urn == score_set.urn) + .where(MappedVariant.post_mapped.is_not(None)) + .where(MappedVariant.current.is_(True)) + ).all() + + # Track total variants to submit + job_manager.save_to_context({"total_variants_to_submit_ldh": len(variant_objects)}) + if not variant_objects: + job_manager.update_progress(100, 100, "No mapped variants to submit to LDH. Skipping submission.") + logger.warning( + msg="No current mapped variants with post mapped metadata were found for this score set. Skipping LDH submission.", + extra=job_manager.logging_context(), ) + return {"status": "ok", "data": {}, "exception_details": None} + job_manager.update_progress(10, 100, f"Submitting {len(variant_objects)} mapped variants to LDH.") - return {"success": False, "retried": False, "enqueued_job": None} - - try: - ldh_service = ClinGenLdhService(url=LDH_SUBMISSION_ENDPOINT) - ldh_service.authenticate() - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error while attempting to authenticate to the LDH. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - variant_objects = db.execute( - select(Variant, MappedVariant) - .join(MappedVariant) - .join(ScoreSet) - .where(ScoreSet.urn == score_set.urn) - .where(MappedVariant.post_mapped.is_not(None)) - .where(MappedVariant.current.is_(True)) - ).all() + # Build submission content + variant_content = [] + for variant, mapped_variant in variant_objects: + variation = get_hgvs_from_post_mapped(mapped_variant.post_mapped) - if not variant_objects: + if not variation: logger.warning( - msg="No current mapped variants with post mapped metadata were found for this score set. Skipping LDH submission.", - extra=logging_context, + msg=f"Could not construct a valid HGVS string for mapped variant {mapped_variant.id}. Skipping submission of this variant.", + extra=job_manager.logging_context(), ) - return {"success": True, "retried": False, "enqueued_job": None} + continue - variant_content = [] - for variant, mapped_variant in variant_objects: - variation = get_hgvs_from_post_mapped(mapped_variant.post_mapped) + variant_content.append((variation, variant, mapped_variant)) - if not variation: - logger.warning( - msg=f"Could not construct a valid HGVS string for mapped variant {mapped_variant.id}. Skipping submission of this variant.", - extra=logging_context, - ) - continue + job_manager.save_to_context({"unique_variants_to_submit_ldh": len(variant_content)}) + job_manager.update_progress(30, 100, f"Dispatching submissions for {len(variant_content)} unique variants to LDH.") + submission_content = construct_ldh_submission(variant_content) - variant_content.append((variation, variant, mapped_variant)) - - submission_content = construct_ldh_submission(variant_content) - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error while attempting to construct submission objects. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - blocking = functools.partial( - ldh_service.dispatch_submissions, submission_content, DEFAULT_LDH_SUBMISSION_BATCH_SIZE - ) - loop = asyncio.get_running_loop() - submission_successes, submission_failures = await loop.run_in_executor(ctx["pool"], blocking) - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission encountered an unexpected error while dispatching submissions. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - assert not submission_failures, f"{len(submission_failures)} submissions failed to be dispatched to the LDH." - logger.info(msg="Dispatched all variant mapping submissions to the LDH.", extra=logging_context) - except AssertionError as e: - send_slack_error(e) - send_slack_message( - text=f"{len(submission_failures)} submissions failed to be dispatched to the LDH for score set {score_set.urn}." - ) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource submission failed to submit all mapping resources. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - new_job_id = None - try: - new_job = await redis.enqueue_job( - "link_clingen_variants", - correlation_id, - score_set.id, - 1, - _defer_by=timedelta(seconds=LINKING_BACKOFF_IN_SECONDS), - ) - - if new_job: - new_job_id = new_job.job_id - - logging_context["link_clingen_variants_job_id"] = new_job_id - logger.info(msg="Queued a new ClinGen linking job.", extra=logging_context) - - else: - raise LinkingEnqueueError() + blocking = functools.partial( + ldh_service.dispatch_submissions, submission_content, DEFAULT_LDH_SUBMISSION_BATCH_SIZE + ) + loop = asyncio.get_running_loop() + submission_successes, submission_failures = await loop.run_in_executor(ctx["pool"], blocking) + job_manager.update_progress(90, 100, "Finalizing LDH mapped resource submission.") - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} + # TODO: Track submission successes and failures, add as annotation features. + if submission_failures: + job_manager.save_to_context({"ldh_submission_failures": len(submission_failures)}) logger.error( - msg="LDH mapped resource submission encountered an unexpected error while attempting to enqueue a linking job. This job will not be retried.", - extra=logging_context, + msg=f"LDH mapped resource submission encountered {len(submission_failures)} failures.", + extra=job_manager.logging_context(), ) - return {"success": False, "retried": False, "enqueued_job": new_job_id} - - return {"success": True, "retried": False, "enqueued_job": new_job_id} + # Finalize progress + job_manager.update_progress(100, 100, "Finalized LDH mapped resource submission.") + job_manager.db.commit() + return {"status": "ok", "data": {}, "exception_details": None} def do_clingen_fetch(variant_urns): return [(variant_urn, get_clingen_variation(variant_urn)) for variant_urn in variant_urns] -async def link_clingen_variants(ctx: dict, correlation_id: str, score_set_id: int, attempt: int) -> dict: - logging_context = {} - score_set = None - text = "Could not link mappings to LDH for score set %s. Mappings for this score set should be linked manually." - try: - db: Session = ctx["db"] - redis: ArqRedis = ctx["redis"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() +@with_pipeline_management +async def link_clingen_variants(ctx: dict, job_manager: JobManager) -> JobResultData: + """ + Link mapped variants to ClinGen Linked Data Hub (LDH) submissions. - logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) - logging_context["linkage_retry_threshold"] = LINKED_DATA_RETRY_THRESHOLD - logging_context["attempt"] = attempt - logging_context["max_attempts"] = ENQUEUE_BACKOFF_ATTEMPT_LIMIT - logger.info(msg="Started LDH mapped resource linkage", extra=logging_context) + This job links mapped variant data to existing LDH data for a given score set. It fetches + LDH variations for each mapped variant and updates the database accordingly. Progress + and errors are logged throughout the process. - submission_urn = score_set.urn - assert submission_urn, "A valid URN is needed to link LDH objects for this score set." + Required job_params in the JobRun: + - score_set_id (int): ID of the ScoreSet to process + - correlation_id (str): Correlation ID for tracking - logging_context["current_ldh_linking_resource"] = submission_urn - logger.debug(msg="Fetched score set metadata for ldh mapped resource linkage.", extra=logging_context) - - except Exception as e: - send_slack_error(e) - if score_set: - send_slack_message(text=text % score_set.urn) - else: - send_slack_message(text=text % score_set_id) - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error during setup. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - try: - variant_urns = db.scalars( - select(Variant.urn) - .join(MappedVariant) - .join(ScoreSet) - .where( - ScoreSet.urn == score_set.urn, MappedVariant.current.is_(True), MappedVariant.post_mapped.is_not(None) - ) - ).all() - num_variant_urns = len(variant_urns) - - logging_context["variants_to_link_ldh"] = num_variant_urns - - if not variant_urns: - logger.warning( - msg="No current mapped variants with post mapped metadata were found for this score set. Skipping LDH linkage (nothing to do). A gnomAD linkage job will not be enqueued, as no variants will have a CAID.", - extra=logging_context, - ) - - return {"success": True, "retried": False, "enqueued_job": None} - - logger.info( - msg="Found current mapped variants with post mapped metadata for this score set. Attempting to link them to LDH submissions.", - extra=logging_context, - ) - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error while attempting to build linkage urn list. This job will not be retried.", - extra=logging_context, - ) + Args: + ctx (dict): Worker context containing DB and Redis connections + job_manager (JobManager): Manager for job lifecycle and DB operations - return {"success": False, "retried": False, "enqueued_job": None} + Side Effects: + - Updates MappedVariant records with ClinGen Allele IDs from LDH objects - try: - logger.info(msg="Attempting to link mapped variants to LDH submissions.", extra=logging_context) + Returns: + dict: Result indicating success and any exception details + """ + # Get the job definition we are working on + job = job_manager.get_job() - # TODO#372: Non-nullable variant urns. - blocking = functools.partial( - do_clingen_fetch, - variant_urns, # type: ignore - ) - loop = asyncio.get_running_loop() - linked_data = await loop.run_in_executor(ctx["pool"], blocking) + _job_required_params = ["score_set_id", "correlation_id"] + validate_job_params(_job_required_params, job) - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error while attempting to link LDH submissions. This job will not be retried.", - extra=logging_context, - ) + # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. + score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one() # type: ignore + correlation_id = job.job_params["correlation_id"] # type: ignore - return {"success": False, "retried": False, "enqueued_job": None} - - try: - linked_allele_ids = [ - (variant_urn, clingen_allele_id_from_ldh_variation(clingen_variation)) - for variant_urn, clingen_variation in linked_data - ] - - linkage_failures = [] - for variant_urn, ldh_variation in linked_allele_ids: - # XXX: Should we unlink variation if it is not found? Does this constitute a failure? - if not ldh_variation: - logger.warning( - msg=f"Failed to link mapped variant {variant_urn} to LDH submission. No LDH variation found.", - extra=logging_context, - ) - linkage_failures.append(variant_urn) - continue - - mapped_variant = db.scalars( - select(MappedVariant).join(Variant).where(Variant.urn == variant_urn, MappedVariant.current.is_(True)) - ).one_or_none() - - if not mapped_variant: - logger.warning( - msg=f"Failed to link mapped variant {variant_urn} to LDH submission. No mapped variant found.", - extra=logging_context, - ) - linkage_failures.append(variant_urn) - continue - - mapped_variant.clingen_allele_id = ldh_variation - db.add(mapped_variant) - - db.commit() - - except Exception as e: - db.rollback() - - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error while attempting to link LDH submissions. This job will not be retried.", - extra=logging_context, + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "link_clingen_variants", + "resource": score_set.urn, + "correlation_id": correlation_id, + } + ) + job_manager.update_progress(0, 100, "Starting LDH mapped resource linkage.") + logger.info(msg="Started LDH mapped resource linkage", extra=job_manager.logging_context()) + + # Fetch mapped variants with post-mapped data for the score set + variant_urns = job_manager.db.scalars( + select(Variant.urn) + .join(MappedVariant) + .join(ScoreSet) + .where(ScoreSet.urn == score_set.urn, MappedVariant.current.is_(True), MappedVariant.post_mapped.is_not(None)) + ).all() + num_variant_urns = len(variant_urns) + + job_manager.save_to_context({"total_variants_to_link_ldh": num_variant_urns}) + job_manager.update_progress(10, 100, f"Found {num_variant_urns} mapped variants to link to LDH submissions.") + + if not variant_urns: + job_manager.update_progress(100, 100, "No mapped variants to link to LDH submissions. Skipping linkage.") + logger.warning( + msg="No current mapped variants with post mapped metadata were found for this score set. Skipping LDH linkage (nothing to do). A gnomAD linkage job will not be enqueued, as no variants will have a CAID.", + extra=job_manager.logging_context(), ) + return {"status": "ok", "data": {}, "exception_details": None} - return {"success": False, "retried": False, "enqueued_job": None} - - try: - num_linkage_failures = len(linkage_failures) - ratio_failed_linking = round(num_linkage_failures / num_variant_urns, 3) - logging_context["linkage_failure_rate"] = ratio_failed_linking - logging_context["linkage_failures"] = num_linkage_failures - logging_context["linkage_successes"] = num_variant_urns - num_linkage_failures - - assert ( - len(linked_allele_ids) == num_variant_urns - ), f"{num_variant_urns - len(linked_allele_ids)} appear to not have been attempted to be linked." + logger.info(msg="Attempting to link mapped variants to LDH submissions.", extra=job_manager.logging_context()) - job_succeeded = False - if not linkage_failures: - logger.info( - msg="Successfully linked all mapped variants to LDH submissions.", - extra=logging_context, - ) - - job_succeeded = True - - elif ratio_failed_linking < LINKED_DATA_RETRY_THRESHOLD: + # TODO#372: Non-nullable variant urns. + # Fetch linked data from LDH for each variant URN + blocking = functools.partial( + do_clingen_fetch, + variant_urns, # type: ignore + ) + loop = asyncio.get_running_loop() + linked_data = await loop.run_in_executor(ctx["pool"], blocking) + + linked_allele_ids = [ + (variant_urn, clingen_allele_id_from_ldh_variation(clingen_variation)) + for variant_urn, clingen_variation in linked_data + ] + job_manager.save_to_context({"ldh_variants_fetched": len(linked_allele_ids)}) + job_manager.update_progress(70, 100, "Fetched existing LDH variant data.") + logger.info(msg="Fetched existing LDH variant data.", extra=job_manager.logging_context()) + + # Link mapped variants to fetched LDH data + linkage_failures = [] + for variant_urn, ldh_variation in linked_allele_ids: + # XXX: Should we unlink variation if it is not found? Does this constitute a failure? + if not ldh_variation: logger.warning( - msg="Linkage failures exist, but did not exceed the retry threshold.", - extra=logging_context, - ) - send_slack_message( - text=f"Failed to link {len(linkage_failures)} mapped variants to LDH submissions for score set {score_set.urn}." - f"The retry threshold was not exceeded and this job will not be retried. URNs failed to link: {', '.join(linkage_failures)}." + msg=f"Failed to link mapped variant {variant_urn} to LDH submission. No LDH variation found.", + extra=job_manager.logging_context(), ) + linkage_failures.append(variant_urn) + continue - job_succeeded = True + mapped_variant = job_manager.db.scalars( + select(MappedVariant).join(Variant).where(Variant.urn == variant_urn, MappedVariant.current.is_(True)) + ).one_or_none() - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error while attempting to finalize linkage. This job will not be retried.", - extra=logging_context, - ) - - return {"success": False, "retried": False, "enqueued_job": None} - - if job_succeeded: - gnomad_linking_job_id = None - try: - new_job = await redis.enqueue_job( - "link_gnomad_variants", - correlation_id, - score_set.id, + if not mapped_variant: + logger.warning( + msg=f"Failed to link mapped variant {variant_urn} to LDH submission. No mapped variant found.", + extra=job_manager.logging_context(), ) + linkage_failures.append(variant_urn) + continue - if new_job: - gnomad_linking_job_id = new_job.job_id - - logging_context["link_gnomad_variants_job_id"] = gnomad_linking_job_id - logger.info(msg="Queued a new gnomAD linking job.", extra=logging_context) + mapped_variant.clingen_allele_id = ldh_variation + job_manager.db.add(mapped_variant) - else: - raise LinkingEnqueueError() + # TODO: Track annotation progress. Given the new progress model, we can better understand what linked and what didn't and + # can move away from the retry threshold model. - except Exception as e: - job_succeeded = False - - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error while attempting to enqueue a gnomAD linking job. GnomAD variants should be linked manually for this score set. This job will not be retried.", - extra=logging_context, + # Calculate progress: 70% + (linked/total_variants)*30, rounded to nearest 5% + if len(linked_allele_ids) % 20 == 0 or len(linked_allele_ids) == num_variant_urns: + progress = 70 + round((len(linked_allele_ids) / num_variant_urns) * 30 / 5) * 5 + job_manager.update_progress( + progress, 100, f"Linked {len(linked_allele_ids)} of {num_variant_urns} variants." ) - finally: - return {"success": job_succeeded, "retried": False, "enqueued_job": gnomad_linking_job_id} - - # If we reach this point, we should consider the job failed (there were failures which exceeded our retry threshold). - new_job_id = None - max_retries_exceeded = None - try: - new_job_id, max_retries_exceeded, backoff_time = await enqueue_job_with_backoff( - ctx["redis"], "variant_mapper_manager", attempt, LINKING_BACKOFF_IN_SECONDS, correlation_id - ) - logging_context["backoff_limit_exceeded"] = max_retries_exceeded - logging_context["backoff_deferred_in_seconds"] = backoff_time - logging_context["backoff_job_id"] = new_job_id - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.critical( - msg="LDH mapped resource linkage encountered an unexpected error while attempting to retry a failed linkage job. This job will not be retried.", - extra=logging_context, + job_manager.save_to_context({"ldh_linkage_failures": len(linkage_failures)}) + if linkage_failures: + logger.warning( + msg=f"LDH mapped resource linkage encountered {len(linkage_failures)} failures.", + extra=job_manager.logging_context(), ) - else: - if new_job_id and not max_retries_exceeded: - logger.info( - msg="After a failure condition while linking mapped variants to LDH submissions, another linkage job was queued.", - extra=logging_context, - ) - send_slack_message( - text=f"Failed to link {len(linkage_failures)} ({ratio_failed_linking * 100}% of total mapped variants for {score_set.urn})." - f"This job was successfully retried. This was attempt {attempt}. Retry will occur in {backoff_time} seconds. URNs failed to link: {', '.join(linkage_failures)}." - ) - elif new_job_id is None and not max_retries_exceeded: - logger.error( - msg="After a failure condition while linking mapped variants to LDH submissions, another linkage job was unable to be queued.", - extra=logging_context, - ) - send_slack_message( - text=f"Failed to link {len(linkage_failures)} ({ratio_failed_linking} of total mapped variants for {score_set.urn})." - f"This job could not be retried due to an unexpected issue while attempting to enqueue another linkage job. This was attempt {attempt}. URNs failed to link: {', '.join(linkage_failures)}." - ) - else: - logger.error( - msg="After a failure condition while linking mapped variants to LDH submissions, the maximum retries for this job were exceeded. The reamining linkage failures will not be retried.", - extra=logging_context, - ) - send_slack_message( - text=f"Failed to link {len(linkage_failures)} ({ratio_failed_linking} of total mapped variants for {score_set.urn})." - f"The retry threshold was exceeded and this job will not be retried. URNs failed to link: {', '.join(linkage_failures)}." - ) - finally: - return { - "success": False, - "retried": (not max_retries_exceeded and new_job_id is not None), - "enqueued_job": new_job_id, - } + # Finalize progress + job_manager.update_progress(100, 100, "Finalized LDH mapped resource linkage.") + job_manager.db.commit() + return {"status": "ok", "data": {}, "exception_details": None} diff --git a/src/mavedb/worker/jobs/external_services/gnomad.py b/src/mavedb/worker/jobs/external_services/gnomad.py index 66be8fd9..e045d247 100644 --- a/src/mavedb/worker/jobs/external_services/gnomad.py +++ b/src/mavedb/worker/jobs/external_services/gnomad.py @@ -10,131 +10,115 @@ from typing import Sequence from sqlalchemy import select -from sqlalchemy.orm import Session from mavedb.lib.gnomad import gnomad_variant_data_for_caids, link_gnomad_variants_to_mapped_variants -from mavedb.lib.logging.context import format_raised_exception_info_as_dict -from mavedb.lib.slack import send_slack_error, send_slack_message from mavedb.models.mapped_variant import MappedVariant from mavedb.models.score_set import ScoreSet from mavedb.models.variant import Variant -from mavedb.worker.jobs.utils.job_state import setup_job_state +from mavedb.worker.jobs.utils.setup import validate_job_params +from mavedb.worker.lib.decorators.pipeline_management import with_pipeline_management +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.types import JobResultData logger = logging.getLogger(__name__) -async def link_gnomad_variants(ctx: dict, correlation_id: str, score_set_id: int) -> dict: - logging_context = {} - score_set = None - text = "Could not link mappings to gnomAD variants for score set %s. Mappings for this score set should be linked manually." - try: - db: Session = ctx["db"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() - - logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) - logger.info(msg="Started gnomAD variant linkage", extra=logging_context) - - submission_urn = score_set.urn - assert submission_urn, "A valid URN is needed to link gnomAD objects for this score set." - - logging_context["current_gnomad_linking_resource"] = submission_urn - logger.debug(msg="Fetched score set metadata for gnomAD mapped resource linkage.", extra=logging_context) - - except Exception as e: - send_slack_error(e) - if score_set: - send_slack_message(text=text % score_set.urn) - else: - send_slack_message(text=text % score_set_id) - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error during setup. This job will not be retried.", - extra=logging_context, +@with_pipeline_management +async def link_gnomad_variants(ctx: dict, job_manager: JobManager) -> JobResultData: + """ + Link mapped variants to gnomAD variants based on ClinGen Allele IDs (CAIDs). + This job fetches mapped variants associated with a given score set that have CAIDs, + retrieves corresponding gnomAD variant data, and establishes links between them + in the database. + + Job Parameters: + - score_set_id (int): The ID of the ScoreSet containing mapped variants to process. + - correlation_id (str): Correlation ID for tracing requests across services. + + Args: + ctx (dict): The job context dictionary. + job_manager (JobManager): Manager for job lifecycle and DB operations. + + Side Effects: + - Updates MappedVariant records to link to gnomAD variants. + + Returns: + dict: Result indicating success and any exception details + """ + # Get the job definition we are working on + job = job_manager.get_job() + + _job_required_params = ["score_set_id", "correlation_id"] + validate_job_params(job_manager, _job_required_params, job) + + # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. + score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one() # type: ignore + correlation_id = job.job_params["correlation_id"] # type: ignore + + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "link_gnomad_variants", + "resource": score_set.urn, + "correlation_id": correlation_id, + } + ) + job_manager.update_progress(0, 100, "Starting gnomAD mapped resource linkage.") + logger.info(msg="Started gnomAD mapped resource linkage", extra=job_manager.logging_context()) + + # We filter out mapped variants that do not have a CAID, so this query is typed # as a Sequence[str]. Ignore MyPy's type checking here. + variant_caids: Sequence[str] = job_manager.db.scalars( + select(MappedVariant.clingen_allele_id) + .join(Variant) + .join(ScoreSet) + .where( + ScoreSet.urn == score_set.urn, + MappedVariant.current.is_(True), + MappedVariant.clingen_allele_id.is_not(None), ) + ).all() # type: ignore - return {"success": False, "retried": False, "enqueued_job": None} - - try: - # We filter out mapped variants that do not have a CAID, so this query is typed # as a Sequence[str]. Ignore MyPy's type checking here. - variant_caids: Sequence[str] = db.scalars( - select(MappedVariant.clingen_allele_id) - .join(Variant) - .join(ScoreSet) - .where( - ScoreSet.urn == score_set.urn, - MappedVariant.current.is_(True), - MappedVariant.clingen_allele_id.is_not(None), - ) - ).all() # type: ignore - num_variant_caids = len(variant_caids) - - logging_context["num_variants_to_link_gnomad"] = num_variant_caids - - if not variant_caids: - logger.warning( - msg="No current mapped variants with CAIDs were found for this score set. Skipping gnomAD linkage (nothing to do).", - extra=logging_context, - ) - - return {"success": True, "retried": False, "enqueued_job": None} - - logger.info( - msg="Found current mapped variants with CAIDs for this score set. Attempting to link them to gnomAD variants.", - extra=logging_context, - ) + num_variant_caids = len(variant_caids) + job_manager.save_to_context({"num_variants_to_link_gnomad": num_variant_caids}) - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="gnomAD mapped resource linkage encountered an unexpected error while attempting to build linkage urn list. This job will not be retried.", - extra=logging_context, + if not variant_caids: + job_manager.update_progress(100, 100, "No variants with CAIDs found to link to gnomAD variants. Nothing to do.") + logger.warning( + msg="No current mapped variants with CAIDs were found for this score set. Skipping gnomAD linkage (nothing to do).", + extra=job_manager.logging_context(), ) + return {"status": "ok", "data": {}, "exception_details": None} - return {"success": False, "retried": False, "enqueued_job": None} + job_manager.update_progress(10, 100, f"Found {num_variant_caids} variants with CAIDs to link to gnomAD variants.") + logger.info( + msg="Found current mapped variants with CAIDs for this score set. Attempting to link them to gnomAD variants.", + extra=job_manager.logging_context(), + ) - try: - gnomad_variant_data = gnomad_variant_data_for_caids(variant_caids) - num_gnomad_variants_with_caid_match = len(gnomad_variant_data) - logging_context["num_gnomad_variants_with_caid_match"] = num_gnomad_variants_with_caid_match + # Fetch gnomAD variant data for the CAIDs + gnomad_variant_data = gnomad_variant_data_for_caids(variant_caids) + num_gnomad_variants_with_caid_match = len(gnomad_variant_data) - if not gnomad_variant_data: - logger.warning( - msg="No gnomAD variants with CAID matches were found for this score set. Skipping gnomAD linkage (nothing to do).", - extra=logging_context, - ) + job_manager.save_to_context({"num_gnomad_variants_with_caid_match": num_gnomad_variants_with_caid_match}) - return {"success": True, "retried": False, "enqueued_job": None} - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="gnomAD mapped resource linkage encountered an unexpected error while attempting to fetch gnomAD variant data from S3 via Athena. This job will not be retried.", - extra=logging_context, + if not gnomad_variant_data: + job_manager.update_progress(100, 100, "No gnomAD variants with CAID matches found. Nothing to link.") + logger.warning( + msg="No gnomAD variants with CAID matches were found for this score set. Skipping gnomAD linkage (nothing to do).", + extra=job_manager.logging_context(), ) - return {"success": False, "retried": False, "enqueued_job": None} - - try: - logger.info(msg="Attempting to link mapped variants to gnomAD variants.", extra=logging_context) - num_linked_gnomad_variants = link_gnomad_variants_to_mapped_variants(db, gnomad_variant_data) - db.commit() - logging_context["num_mapped_variants_linked_to_gnomad_variants"] = num_linked_gnomad_variants - - except Exception as e: - send_slack_error(e) - send_slack_message(text=text % score_set.urn) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="LDH mapped resource linkage encountered an unexpected error while attempting to link LDH submissions. This job will not be retried.", - extra=logging_context, - ) + return {"status": "ok", "data": {}, "exception_details": None} + job_manager.update_progress(75, 100, f"Found {num_gnomad_variants_with_caid_match} gnomAD variants matching CAIDs.") - return {"success": False, "retried": False, "enqueued_job": None} + # Link mapped variants to gnomAD variants + logger.info(msg="Attempting to link mapped variants to gnomAD variants.", extra=job_manager.logging_context()) + num_linked_gnomad_variants = link_gnomad_variants_to_mapped_variants(job_manager.db, gnomad_variant_data) + job_manager.db.commit() - logger.info(msg="Done linking gnomAD variants to mapped variants.", extra=logging_context) - return {"success": True, "retried": False, "enqueued_job": None} + # Save final context and progress + job_manager.save_to_context({"num_mapped_variants_linked_to_gnomad_variants": num_linked_gnomad_variants}) + job_manager.update_progress(100, 100, f"Linked {num_linked_gnomad_variants} mapped variants to gnomAD variants.") + logger.info(msg="Done linking gnomAD variants to mapped variants.", extra=job_manager.logging_context()) + return {"status": "ok", "data": {}, "exception_details": None} diff --git a/src/mavedb/worker/jobs/external_services/py.typed b/src/mavedb/worker/jobs/external_services/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/mavedb/worker/jobs/external_services/uniprot.py b/src/mavedb/worker/jobs/external_services/uniprot.py index a72cf9e2..713cd60f 100644 --- a/src/mavedb/worker/jobs/external_services/uniprot.py +++ b/src/mavedb/worker/jobs/external_services/uniprot.py @@ -9,222 +9,236 @@ """ import logging -from typing import Optional -from arq import ArqRedis from sqlalchemy import select -from sqlalchemy.orm import Session from mavedb.lib.exceptions import UniProtPollingEnqueueError -from mavedb.lib.logging.context import format_raised_exception_info_as_dict from mavedb.lib.mapping import extract_ids_from_post_mapped_metadata -from mavedb.lib.slack import log_and_send_slack_message, send_slack_error +from mavedb.lib.slack import log_and_send_slack_message from mavedb.lib.uniprot.id_mapping import UniProtIDMappingAPI from mavedb.lib.uniprot.utils import infer_db_name_from_sequence_accession +from mavedb.models.job_dependency import JobDependency from mavedb.models.score_set import ScoreSet -from mavedb.worker.jobs.utils.job_state import setup_job_state +from mavedb.worker.jobs.utils.setup import validate_job_params +from mavedb.worker.lib.decorators.pipeline_management import with_pipeline_management +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.types import JobResultData logger = logging.getLogger(__name__) -async def submit_uniprot_mapping_jobs_for_score_set(ctx, score_set_id: int, correlation_id: Optional[str] = None): - logging_context = {} - score_set = None - spawned_mapping_jobs: dict[int, Optional[str]] = {} - text = "Could not submit mapping jobs to UniProt for this score set %s. Mapping jobs for this score set should be submitted manually." - try: - db: Session = ctx["db"] - redis: ArqRedis = ctx["redis"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() - logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) - logger.info(msg="Started UniProt mapping job", extra=logging_context) - - if not score_set or not score_set.target_genes: - msg = f"No target genes for score set {score_set_id}. Skipped mapping targets to UniProt." - log_and_send_slack_message(msg=msg, ctx=logging_context, level=logging.WARNING) - - return {"success": True, "retried": False, "enqueued_jobs": []} - - except Exception as e: - send_slack_error(e) - if score_set: - msg = text % score_set.urn - else: - msg = text % score_set_id - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - log_and_send_slack_message(msg=msg, ctx=logging_context, level=logging.ERROR) - - return {"success": False, "retried": False, "enqueued_jobs": []} - - try: - uniprot_api = UniProtIDMappingAPI() - logging_context["total_target_genes_to_map_to_uniprot"] = len(score_set.target_genes) - for target_gene in score_set.target_genes: - spawned_mapping_jobs[target_gene.id] = None # type: ignore - - acs = extract_ids_from_post_mapped_metadata(target_gene.post_mapped_metadata) # type: ignore - if not acs: - msg = f"No accession IDs found in post_mapped_metadata for target gene {target_gene.id} in score set {score_set.urn}. This target will be skipped." - log_and_send_slack_message(msg, logging_context, logging.WARNING) - continue - - if len(acs) != 1: - msg = f"More than one accession ID is associated with target gene {target_gene.id} in score set {score_set.urn}. This target will be skipped." - log_and_send_slack_message(msg, logging_context, logging.WARNING) - continue - - ac_to_map = acs[0] - from_db = infer_db_name_from_sequence_accession(ac_to_map) - - try: - spawned_mapping_jobs[target_gene.id] = uniprot_api.submit_id_mapping(from_db, "UniProtKB", [ac_to_map]) # type: ignore - except Exception as e: - log_and_send_slack_message( - msg=f"Failed to submit UniProt mapping job for target gene {target_gene.id}: {e}. This target will be skipped.", - ctx=logging_context, - level=logging.WARNING, - ) - - except Exception as e: - send_slack_error(e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - log_and_send_slack_message( - msg=f"UniProt mapping job encountered an unexpected error while attempting to submit mapping jobs for score set {score_set.urn}. This job will not be retried.", - ctx=logging_context, - level=logging.ERROR, +@with_pipeline_management +async def submit_uniprot_mapping_jobs_for_score_set(ctx: dict, job_manager: JobManager) -> JobResultData: + """Submit UniProt ID mapping jobs for all target genes in a given ScoreSet. + + Job Parameters: + - score_set_id (int): The ID of the ScoreSet containing target genes to map. + - correlation_id (str): Correlation ID for tracing requests across services. + + Args: + ctx (dict): The job context dictionary. + job_manager (JobManager): Manager for job lifecycle and DB operations. + + Side Effects: + - Submits UniProt ID mapping jobs for each target gene in the ScoreSet. + - Fetches the dependent job for this function, which is the polling job for UniProt results. + Sets the parameter `mapping_jobs` on the polling job with a dictionary of target gene IDs to UniProt job IDs. + TODO#XXX: Split mapping jobs into one per target gene so that polling can be more granular. + + Returns: + dict: Result indicating success and any exception details + """ + # Get the job definition we are working on + job = job_manager.get_job() + + _job_required_params = ["score_set_id", "correlation_id"] + validate_job_params(job_manager, _job_required_params, job) + + # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. + score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one() # type: ignore + correlation_id = job.job_params["correlation_id"] # type: ignore + + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "submit_uniprot_mapping_jobs_for_score_set", + "resource": score_set.urn, + "correlation_id": correlation_id, + } + ) + job_manager.update_progress(0, 100, "Starting UniProt mapping job submission.") + logger.info(msg="Started UniProt mapping job submission", extra=job_manager.logging_context()) + + if not score_set or not score_set.target_genes: + job_manager.update_progress(100, 100, "No target genes found. Skipped UniProt mapping job submission.") + msg = f"No target genes for score set {score_set.id}. Skipped mapping targets to UniProt." + log_and_send_slack_message(msg=msg, ctx=job_manager.logging_context(), level=logging.WARNING) + return {"status": "ok", "data": {}, "exception_details": None} + + uniprot_api = UniProtIDMappingAPI() + job_manager.save_to_context({"total_target_genes_to_map_to_uniprot": len(score_set.target_genes)}) + + mapping_jobs = {} + for idx, target_gene in enumerate(score_set.target_genes): + acs = extract_ids_from_post_mapped_metadata(target_gene.post_mapped_metadata) # type: ignore + if not acs: + msg = f"No accession IDs found in post_mapped_metadata for target gene {target_gene.id} in score set {score_set.urn}. This target will be skipped." + log_and_send_slack_message(msg, job_manager.logging_context(), logging.WARNING) + continue + + if len(acs) != 1: + msg = f"More than one accession ID is associated with target gene {target_gene.id} in score set {score_set.urn}. This target will be skipped." + log_and_send_slack_message(msg, job_manager.logging_context(), logging.WARNING) + continue + + ac_to_map = acs[0] + from_db = infer_db_name_from_sequence_accession(ac_to_map) + spawned_job = uniprot_api.submit_id_mapping(from_db, "UniProtKB", [ac_to_map]) # type: ignore + mapping_jobs[target_gene.id] = {"job_id": spawned_job, "accession_mapped": ac_to_map} + + job_manager.save_to_context( + { + "submitted_uniprot_mapping_jobs": { + **job_manager.logging_context().get("submitted_uniprot_mapping_jobs", {}), + target_gene.id: mapping_jobs[target_gene.id], + } + } ) - - return {"success": False, "retried": False, "enqueued_jobs": []} - - new_job_id = None - try: - successfully_spawned_mapping_jobs = sum(1 for job in spawned_mapping_jobs.values() if job is not None) - logging_context["successfully_spawned_mapping_jobs"] = successfully_spawned_mapping_jobs - - if not successfully_spawned_mapping_jobs: - msg = f"No UniProt mapping jobs were successfully spawned for score set {score_set.urn}. Skipped enqueuing polling job." - log_and_send_slack_message(msg, logging_context, logging.WARNING) - return {"success": True, "retried": False, "enqueued_jobs": []} - - new_job = await redis.enqueue_job( - "poll_uniprot_mapping_jobs_for_score_set", - spawned_mapping_jobs, - score_set_id, - correlation_id, + logger.info( + msg=f"Submitted UniProt ID mapping job for target gene {target_gene.id}.", + extra=job_manager.logging_context(), + ) + job_manager.update_progress( + int((idx + 1 / len(score_set.target_genes)) * 100), + 100, + f"Submitted UniProt mapping job for target gene {target_gene.name}.", ) - if new_job: - new_job_id = new_job.job_id - - logging_context["poll_uniprot_mapping_job_id"] = new_job_id - logger.info(msg="Enqueued polling jobs for UniProt mapping jobs.", extra=logging_context) - - else: - raise UniProtPollingEnqueueError() + # Set mapping jobs on dependent polling job. Only one polling job per score set should be created. + dependent_polling_job = job_manager.db.scalars( + select(JobDependency).where(JobDependency.depends_on_job_id == job.id) + ).all() - except Exception as e: - send_slack_error(e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - log_and_send_slack_message( - msg="UniProt mapping job encountered an unexpected error while attempting to enqueue polling jobs for mapping jobs. This job will not be retried.", - ctx=logging_context, - level=logging.ERROR, + if not dependent_polling_job or len(dependent_polling_job) != 1: + raise UniProtPollingEnqueueError( + f"Could not find unique dependent polling job for UniProt mapping job {job.id}." ) - return {"success": False, "retried": False, "enqueued_jobs": [job for job in [new_job_id] if job]} - - return {"success": True, "retried": False, "enqueued_jobs": [job for job in [new_job_id] if job]} - - -async def poll_uniprot_mapping_jobs_for_score_set( - ctx, mapping_jobs: dict[int, Optional[str]], score_set_id: int, correlation_id: Optional[str] = None -): - logging_context = {} - score_set = None - text = "Could not poll mapping jobs from UniProt for this Target %s. Mapping jobs for this score set should be submitted manually." - try: - db: Session = ctx["db"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() - logging_context = setup_job_state(ctx, None, score_set.urn, correlation_id) - logger.info(msg="Started UniProt polling job", extra=logging_context) - - if not score_set or not score_set.target_genes: - msg = f"No target genes for score set {score_set_id}. Skipped polling targets for UniProt mapping results." - log_and_send_slack_message(msg=msg, ctx=logging_context, level=logging.WARNING) - - return {"success": True, "retried": False, "enqueued_jobs": []} - - except Exception as e: - send_slack_error(e) - if score_set: - msg = text % score_set.urn - else: - msg = text % score_set_id - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - log_and_send_slack_message(msg=msg, ctx=logging_context, level=logging.ERROR) - - return {"success": False, "retried": False, "enqueued_jobs": []} - - try: - uniprot_api = UniProtIDMappingAPI() - for target_gene in score_set.target_genes: - acs = extract_ids_from_post_mapped_metadata(target_gene.post_mapped_metadata) # type: ignore - if not acs: - msg = f"No accession IDs found in post_mapped_metadata for target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target." - log_and_send_slack_message(msg, logging_context, logging.WARNING) - continue - - if len(acs) != 1: - msg = f"More than one accession ID is associated with target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target." - log_and_send_slack_message(msg, logging_context, logging.WARNING) - continue - - mapped_ac = acs[0] - job_id = mapping_jobs.get(target_gene.id) # type: ignore - - if not job_id: - msg = f"No job ID found for target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target." - # This issue has already been sent to Slack in the job submission function, so we just log it here. - logger.debug(msg=msg, extra=logging_context) - continue - - if not uniprot_api.check_id_mapping_results_ready(job_id): - msg = f"Job {job_id} not ready for target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target" - log_and_send_slack_message(msg, logging_context, logging.WARNING) - continue - - results = uniprot_api.get_id_mapping_results(job_id) - mapped_ids = uniprot_api.extract_uniprot_id_from_results(results) - - if not mapped_ids: - msg = f"No UniProt ID found for target gene {target_gene.id} in score set {score_set.urn}. Cannot add UniProt ID for this target." - log_and_send_slack_message(msg, logging_context, logging.WARNING) - continue - - if len(mapped_ids) != 1: - msg = f"Found ambiguous Uniprot ID mapping results for target gene {target_gene.id} in score set {score_set.urn}. Cannot add UniProt ID for this target." - log_and_send_slack_message(msg, logging_context, logging.WARNING) - continue - - mapped_uniprot_id = mapped_ids[0][mapped_ac]["uniprot_id"] - target_gene.uniprot_id_from_mapped_metadata = mapped_uniprot_id - db.add(target_gene) - logger.info( - msg=f"Updated target gene {target_gene.id} with UniProt ID {mapped_uniprot_id}", extra=logging_context - ) - - except Exception as e: - send_slack_error(e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - log_and_send_slack_message( - msg="UniProt mapping job encountered an unexpected error while attempting to poll mapping jobs. This job will not be retried.", - ctx=logging_context, - level=logging.ERROR, + polling_job = dependent_polling_job[0].job_run + polling_job.job_params = { + **(polling_job.job_params or {}), + "mapping_jobs": { + target_gene_id: mapping_info["job_id"] for target_gene_id, mapping_info in mapping_jobs.items() + }, + } + job_manager.db.add(polling_job) + job_manager.update_progress(100, 100, "Completed submission of UniProt mapping jobs.") + job_manager.db.commit() + return {"status": "ok", "data": {}, "exception_details": None} + + +@with_pipeline_management +async def poll_uniprot_mapping_jobs_for_score_set(ctx: dict, job_manager: JobManager) -> JobResultData: + """Submit UniProt ID mapping jobs for all target genes in a given ScoreSet. + + Job Parameters: + - score_set_id (int): The ID of the ScoreSet containing target genes to map. + - correlation_id (str): Correlation ID for tracing requests across services. + - mapping_jobs (dict): Dictionary of target gene IDs to UniProt job IDs. + + Args: + ctx (dict): The job context dictionary. + job_manager (JobManager): Manager for job lifecycle and DB operations. + + TODO#XXX: Split mapping jobs into one per target gene so that polling can be more granular. + + Returns: + dict: Result indicating success and any exception details + """ + # Get the job definition we are working on + job = job_manager.get_job() + + _job_required_params = ["score_set_id", "correlation_id", "mapping_jobs"] + validate_job_params(job_manager, _job_required_params, job) + + # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. + score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one() # type: ignore + correlation_id = job.job_params["correlation_id"] # type: ignore + mapping_jobs = job.job_params.get("mapping_jobs", {}) # type: ignore + + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "poll_uniprot_mapping_jobs_for_score_set", + "resource": score_set.urn, + "correlation_id": correlation_id, + } + ) + job_manager.update_progress(0, 100, "Starting UniProt mapping job polling.") + logger.info(msg="Started UniProt mapping job polling", extra=job_manager.logging_context()) + + if not score_set or not score_set.target_genes: + msg = f"No target genes for score set {score_set.id}. Skipped polling targets for UniProt mapping results." + log_and_send_slack_message(msg=msg, ctx=job_manager.logging_context(), level=logging.WARNING) + + return {"status": "ok", "data": {}, "exception_details": None} + + # Poll each mapping job and update target genes with UniProt IDs + uniprot_api = UniProtIDMappingAPI() + for target_gene in score_set.target_genes: + acs = extract_ids_from_post_mapped_metadata(target_gene.post_mapped_metadata) # type: ignore + if not acs: + msg = f"No accession IDs found in post_mapped_metadata for target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target." + log_and_send_slack_message(msg, job_manager.logging_context(), logging.WARNING) + continue + + if len(acs) != 1: + msg = f"More than one accession ID is associated with target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target." + log_and_send_slack_message(msg, job_manager.logging_context(), logging.WARNING) + continue + + mapped_ac = acs[0] + job_id = mapping_jobs.get(target_gene.id) # type: ignore + + if not job_id: + msg = f"No job ID found for target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target." + # This issue has already been sent to Slack in the job submission function, so we just log it here. + logger.debug(msg=msg, extra=job_manager.logging_context()) + continue + + if not uniprot_api.check_id_mapping_results_ready(job_id): + msg = f"Job {job_id} not ready for target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target" + log_and_send_slack_message(msg, job_manager.logging_context(), logging.WARNING) + continue + + results = uniprot_api.get_id_mapping_results(job_id) + mapped_ids = uniprot_api.extract_uniprot_id_from_results(results) + + if not mapped_ids: + msg = f"No UniProt ID found for target gene {target_gene.id} in score set {score_set.urn}. Cannot add UniProt ID for this target." + log_and_send_slack_message(msg, job_manager.logging_context(), logging.WARNING) + continue + + if len(mapped_ids) != 1: + msg = f"Found ambiguous Uniprot ID mapping results for target gene {target_gene.id} in score set {score_set.urn}. Cannot add UniProt ID for this target." + log_and_send_slack_message(msg, job_manager.logging_context(), logging.WARNING) + continue + + mapped_uniprot_id = mapped_ids[0][mapped_ac]["uniprot_id"] + target_gene.uniprot_id_from_mapped_metadata = mapped_uniprot_id + job_manager.db.add(target_gene) + logger.info( + msg=f"Updated target gene {target_gene.id} with UniProt ID {mapped_uniprot_id}", + extra=job_manager.logging_context(), + ) + job_manager.update_progress( + int((list(score_set.target_genes).index(target_gene) + 1 / len(score_set.target_genes)) * 100), + 100, + f"Polled UniProt mapping job for target gene {target_gene.name}.", ) - return {"success": False, "retried": False, "enqueued_jobs": []} - - db.commit() - return {"success": True, "retried": False, "enqueued_jobs": []} + job_manager.update_progress(100, 100, "Completed polling of UniProt mapping jobs.") + job_manager.db.commit() + return {"status": "ok", "data": {}, "exception_details": None} diff --git a/src/mavedb/worker/jobs/registry.py b/src/mavedb/worker/jobs/registry.py index a79ed3fa..06ae2b29 100644 --- a/src/mavedb/worker/jobs/registry.py +++ b/src/mavedb/worker/jobs/registry.py @@ -24,7 +24,6 @@ from mavedb.worker.jobs.variant_processing import ( create_variants_for_score_set, map_variants_for_score_set, - variant_mapper_manager, ) # All job functions for ARQ worker @@ -32,7 +31,6 @@ # Variant processing jobs create_variants_for_score_set, map_variants_for_score_set, - variant_mapper_manager, # External service jobs submit_score_set_mappings_to_car, submit_score_set_mappings_to_ldh, diff --git a/src/mavedb/worker/jobs/utils/__init__.py b/src/mavedb/worker/jobs/utils/__init__.py index a63687b8..4bdb3409 100644 --- a/src/mavedb/worker/jobs/utils/__init__.py +++ b/src/mavedb/worker/jobs/utils/__init__.py @@ -16,12 +16,10 @@ MAPPING_CURRENT_ID_NAME, MAPPING_QUEUE_NAME, ) -from .job_state import setup_job_state -from .retry import enqueue_job_with_backoff +from .setup import validate_job_params __all__ = [ - "setup_job_state", - "enqueue_job_with_backoff", + "validate_job_params", "MAPPING_QUEUE_NAME", "MAPPING_CURRENT_ID_NAME", "MAPPING_BACKOFF_IN_SECONDS", diff --git a/src/mavedb/worker/jobs/utils/job_state.py b/src/mavedb/worker/jobs/utils/job_state.py deleted file mode 100644 index 33c6887b..00000000 --- a/src/mavedb/worker/jobs/utils/job_state.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Job state management utilities. - -This module provides utilities for managing job state and context across -the worker job lifecycle. It handles setup of logging context, correlation -IDs, and other state information needed for job traceability and monitoring. -""" - -import logging -from typing import Any, Optional - -logger = logging.getLogger(__name__) - - -def setup_job_state( - ctx, invoker: Optional[int], resource: Optional[str], correlation_id: Optional[str] -) -> dict[str, Any]: - """ - Initialize and store job state information in the context dictionary for traceability. - - Args: - ctx: The job context dictionary, must contain 'state' and 'job_id' keys. - invoker: The user ID or identifier who initiated the job (may be None). - resource: The resource string associated with the job (may be None). - correlation_id: Optional correlation ID for tracing requests across services. - - Returns: - dict[str, Any]: The job state dictionary for the current job_id. - """ - ctx["state"][ctx["job_id"]] = { - "application": "mavedb-worker", - "user": invoker, - "resource": resource, - "correlation_id": correlation_id, - } - return ctx["state"][ctx["job_id"]] diff --git a/src/mavedb/worker/jobs/utils/py.typed b/src/mavedb/worker/jobs/utils/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/mavedb/worker/jobs/utils/retry.py b/src/mavedb/worker/jobs/utils/retry.py deleted file mode 100644 index 5150d95b..00000000 --- a/src/mavedb/worker/jobs/utils/retry.py +++ /dev/null @@ -1,61 +0,0 @@ -"""Retry and backoff utilities for job error handling. - -This module provides utilities for implementing exponential backoff and -retry logic for failed jobs. It helps ensure reliable job execution -by automatically retrying transient failures with appropriate delays. -""" - -import logging -from datetime import timedelta -from typing import Any, Optional - -from arq import ArqRedis - -from mavedb.worker.jobs.utils.constants import ENQUEUE_BACKOFF_ATTEMPT_LIMIT - -logger = logging.getLogger(__name__) - - -async def enqueue_job_with_backoff( - redis: ArqRedis, job_name: str, attempt: int, backoff: int, *args -) -> tuple[Optional[str], bool, Any]: - """ - Enqueue a job with exponential backoff and attempt tracking, for robust retry logic. - - Args: - redis (ArqRedis): The Redis connection for job queueing. - job_name (str): The name of the job to enqueue. - attempt (int): The current attempt number (used for backoff calculation). - backoff (int): The base backoff time in seconds. - *args: Additional arguments to pass to the job. - - Returns: - tuple[Optional[str], bool, Any]: - - The new job ID if enqueued, else None. - - Boolean indicating if the backoff limit was NOT reached (True if retry scheduled). - - The updated backoff value (seconds). - - Notes: - - If the attempt exceeds ENQUEUE_BACKOFF_ATTEMPT_LIMIT, no job is enqueued and limit is considered reached. - - The attempt value is incremented and passed as the last argument to the job. - - The job is deferred by the calculated backoff time. - """ - new_job_id = None - limit_reached = attempt > ENQUEUE_BACKOFF_ATTEMPT_LIMIT - if not limit_reached: - limit_reached = True - backoff = backoff * (2**attempt) - attempt = attempt + 1 - - # NOTE: for jobs supporting backoff, `attempt` should be the final argument. - new_job = await redis.enqueue_job( - job_name, - *args, - attempt, - _defer_by=timedelta(seconds=backoff), - ) - - if new_job: - new_job_id = new_job.job_id - - return (new_job_id, not limit_reached, backoff) diff --git a/src/mavedb/worker/jobs/utils/setup.py b/src/mavedb/worker/jobs/utils/setup.py new file mode 100644 index 00000000..b569bb0e --- /dev/null +++ b/src/mavedb/worker/jobs/utils/setup.py @@ -0,0 +1,24 @@ +"""Job state management utilities. + +This module provides utilities for managing job state and context across +the worker job lifecycle. It handles setup of logging context, correlation +IDs, and other state information needed for job traceability and monitoring. +""" + +import logging + +from mavedb.models.job_run import JobRun + +logger = logging.getLogger(__name__) + + +def validate_job_params(required_params: list[str], job: JobRun) -> None: + """ + Validate that the given job has all required parameters present in its job_params. + """ + if not job.job_params: + raise ValueError("Job has no job_params defined.") + + for param in required_params: + if param not in job.job_params: + raise ValueError(f"Missing required job param: {param}") diff --git a/src/mavedb/worker/jobs/variant_processing/__init__.py b/src/mavedb/worker/jobs/variant_processing/__init__.py index b9085659..a6df0975 100644 --- a/src/mavedb/worker/jobs/variant_processing/__init__.py +++ b/src/mavedb/worker/jobs/variant_processing/__init__.py @@ -9,11 +9,9 @@ from .creation import create_variants_for_score_set from .mapping import ( map_variants_for_score_set, - variant_mapper_manager, ) __all__ = [ "create_variants_for_score_set", "map_variants_for_score_set", - "variant_mapper_manager", ] diff --git a/src/mavedb/worker/jobs/variant_processing/creation.py b/src/mavedb/worker/jobs/variant_processing/creation.py index 3064581b..f71c5ed8 100644 --- a/src/mavedb/worker/jobs/variant_processing/creation.py +++ b/src/mavedb/worker/jobs/variant_processing/creation.py @@ -6,73 +6,113 @@ """ import logging -from typing import Optional -import pandas as pd -from arq import ArqRedis from sqlalchemy import delete, null, select -from sqlalchemy.orm import Session from mavedb.data_providers.services import RESTDataProvider from mavedb.lib.logging.context import format_raised_exception_info_as_dict from mavedb.lib.score_sets import columns_for_dataset, create_variants, create_variants_data -from mavedb.lib.slack import send_slack_error from mavedb.lib.validation.dataframe.dataframe import validate_and_standardize_dataframe_pair -from mavedb.lib.validation.exceptions import ValidationError from mavedb.models.enums.mapping_state import MappingState from mavedb.models.enums.processing_state import ProcessingState from mavedb.models.mapped_variant import MappedVariant from mavedb.models.score_set import ScoreSet from mavedb.models.user import User from mavedb.models.variant import Variant -from mavedb.view_models.score_set_dataset_columns import DatasetColumnMetadata -from mavedb.worker.jobs.utils.constants import MAPPING_QUEUE_NAME -from mavedb.worker.jobs.utils.job_state import setup_job_state +from mavedb.worker.jobs.utils.setup import validate_job_params +from mavedb.worker.lib.decorators.pipeline_management import with_pipeline_management +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.types import JobResultData logger = logging.getLogger(__name__) -async def create_variants_for_score_set( - ctx, - correlation_id: str, - score_set_id: int, - updater_id: int, - scores: pd.DataFrame, - counts: pd.DataFrame, - score_columns_metadata: Optional[dict[str, DatasetColumnMetadata]] = None, - count_columns_metadata: Optional[dict[str, DatasetColumnMetadata]] = None, -): +@with_pipeline_management +async def create_variants_for_score_set(ctx, job_manager: JobManager) -> JobResultData: """ - Create variants for a score set. Intended to be run within a worker. - On any raised exception, ensure ProcessingState of score set is set to `failed` prior - to exiting. + Create variants for a given ScoreSet based on uploaded score and count data. + + Args: + ctx: The job context dictionary. + job_manager: Manager for job lifecycle and DB operations. + + Job Parameters: + - score_set_id (int): The ID of the ScoreSet to create variants for. + - correlation_id (str): Correlation ID for tracing requests across services. + - updater_id (int): The ID of the user performing the update. + - scores (pd.DataFrame): DataFrame containing score data. + - counts (pd.DataFrame): DataFrame containing count data. + - score_columns_metadata (dict): Metadata for score columns. + - count_columns_metadata (dict): Metadata for count columns. + + Side Effects: + - Creates Variant and MappedVariant records in the database. + + Returns: + dict: Result indicating success and any exception details """ - logging_context = {} - try: - db: Session = ctx["db"] - hdp: RESTDataProvider = ctx["hdp"] - redis: ArqRedis = ctx["redis"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() - - logging_context = setup_job_state(ctx, updater_id, score_set.urn, correlation_id) - logger.info(msg="Began processing of score set variants.", extra=logging_context) + hdp: RESTDataProvider = ctx["hdp"] + + # Get the job definition we are working on + job = job_manager.get_job() + + _job_required_params = [ + "score_set_id", + "correlation_id", + "updater_id", + "scores", + "counts", + "score_columns_metadata", + "count_columns_metadata", + ] + validate_job_params(job_manager, _job_required_params, job) + + # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. + score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one() # type: ignore + correlation_id = job.job_params["correlation_id"] # type: ignore + updater_id = job.job_params["updater_id"] # type: ignore + scores = job.job_params["scores"] # type: ignore + counts = job.job_params["counts"] # type: ignore + score_columns_metadata = job.job_params["score_columns_metadata"] # type: ignore + count_columns_metadata = job.job_params["count_columns_metadata"] # type: ignore + + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "create_variants_for_score_set", + "resource": score_set.urn, + "correlation_id": correlation_id, + } + ) + job_manager.update_progress(0, 100, "Starting variant creation job.") + logger.info(msg="Started variant creation job", extra=job_manager.logging_context()) - updated_by = db.scalars(select(User).where(User.id == updater_id)).one() + updated_by = job_manager.db.scalars(select(User).where(User.id == updater_id)).one() + # Main processing block. Handled in a try/except to ensure we can set score set state appropriately, + # which is handled independently of the job state. + # TODO:XXX In a future iteration, we may want to move this logic into the job manager itself for better cohesion. + try: score_set.modified_by = updated_by score_set.processing_state = ProcessingState.processing score_set.mapping_state = MappingState.pending_variant_processing - logging_context["processing_state"] = score_set.processing_state.name - logging_context["mapping_state"] = score_set.mapping_state.name - db.add(score_set) - db.commit() - db.refresh(score_set) + job_manager.save_to_context( + {"processing_state": score_set.processing_state.name, "mapping_state": score_set.mapping_state.name} + ) + + job_manager.db.add(score_set) + job_manager.db.commit() + job_manager.db.refresh(score_set) + + job_manager.update_progress(10, 100, "Validated score set metadata and beginning data validation.") if not score_set.target_genes: + job_manager.update_progress(100, 100, "Score set has no targets; cannot create variants.") logger.warning( msg="No targets are associated with this score set; could not create variants.", - extra=logging_context, + extra=job_manager.logging_context(), ) raise ValueError("Can't create variants when score set has no targets.") @@ -87,6 +127,8 @@ async def create_variants_for_score_set( ) ) + job_manager.update_progress(80, 100, "Data validation complete; creating variants in database.") + score_set.dataset_columns = { "score_columns": columns_for_dataset(validated_scores), "count_columns": columns_for_dataset(validated_counts), @@ -98,47 +140,31 @@ async def create_variants_for_score_set( else {}, } + job_manager.update_progress(90, 100, "Creating variants in database.") + # Delete variants after validation occurs so we don't overwrite them in the case of a bad update. if score_set.variants: - existing_variants = db.scalars(select(Variant.id).where(Variant.score_set_id == score_set.id)).all() - db.execute(delete(MappedVariant).where(MappedVariant.variant_id.in_(existing_variants))) - db.execute(delete(Variant).where(Variant.id.in_(existing_variants))) - logging_context["deleted_variants"] = score_set.num_variants + existing_variants = job_manager.db.scalars( + select(Variant.id).where(Variant.score_set_id == score_set.id) + ).all() + job_manager.db.execute(delete(MappedVariant).where(MappedVariant.variant_id.in_(existing_variants))) + job_manager.db.execute(delete(Variant).where(Variant.id.in_(existing_variants))) + + job_manager.save_to_context({"deleted_variants": len(existing_variants)}) score_set.num_variants = 0 - logger.info(msg="Deleted existing variants from score set.", extra=logging_context) + logger.info(msg="Deleted existing variants from score set.", extra=job_manager.logging_context()) - db.flush() - db.refresh(score_set) + job_manager.db.flush() + job_manager.db.refresh(score_set) variants_data = create_variants_data(validated_scores, validated_counts, None) - create_variants(db, score_set, variants_data) - - # Validation errors arise from problematic user data. These should be inserted into the database so failures can - # be persisted to them. - except ValidationError as e: - db.rollback() - score_set.processing_state = ProcessingState.failed - score_set.processing_errors = {"exception": str(e), "detail": e.triggering_exceptions} - score_set.mapping_state = MappingState.not_attempted - - if score_set.num_variants: - score_set.processing_errors["exception"] = ( - f"Update failed, variants were not updated. {score_set.processing_errors.get('exception', '')}" - ) - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logging_context["processing_state"] = score_set.processing_state.name - logging_context["mapping_state"] = score_set.mapping_state.name - logging_context["created_variants"] = 0 - logger.warning(msg="Encountered a validation error while processing variants.", extra=logging_context) - - return {"success": False} + create_variants(job_manager.db, score_set, variants_data) # NOTE: Since these are likely to be internal errors, it makes less sense to add them to the DB and surface them to the end user. - # Catch all non-system exiting exceptions. + # Catch all exceptions so we can log them and set score set state appropriately. except Exception as e: - db.rollback() + job_manager.db.rollback() score_set.processing_state = ProcessingState.failed score_set.processing_errors = {"exception": str(e), "detail": []} score_set.mapping_state = MappingState.not_attempted @@ -148,49 +174,40 @@ async def create_variants_for_score_set( f"Update failed, variants were not updated. {score_set.processing_errors.get('exception', '')}" ) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logging_context["processing_state"] = score_set.processing_state.name - logging_context["mapping_state"] = score_set.mapping_state.name - logging_context["created_variants"] = 0 - logger.warning(msg="Encountered an internal exception while processing variants.", extra=logging_context) - - send_slack_error(err=e) - return {"success": False} - - # Catch all other exceptions. The exceptions caught here were intented to be system exiting. - except BaseException as e: - db.rollback() - score_set.processing_state = ProcessingState.failed - score_set.mapping_state = MappingState.not_attempted - db.commit() - - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logging_context["processing_state"] = score_set.processing_state.name - logging_context["mapping_state"] = score_set.mapping_state.name - logging_context["created_variants"] = 0 + job_manager.save_to_context( + { + "processing_state": score_set.processing_state.name, + "mapping_state": score_set.mapping_state.name, + **format_raised_exception_info_as_dict(e), + "created_variants": 0, + } + ) + job_manager.update_progress(100, 100, "Variant creation job failed due to an internal error.") logger.error( - msg="Encountered an unhandled exception while creating variants for score set.", extra=logging_context + msg="Encountered an internal exception while processing variants.", extra=job_manager.logging_context() ) - # Don't raise BaseExceptions so we may emit canonical logs (TODO: Perhaps they are so problematic we want to raise them anyway). - return {"success": False} + raise e else: score_set.processing_state = ProcessingState.success + score_set.mapping_state = MappingState.queued score_set.processing_errors = null() - logging_context["created_variants"] = score_set.num_variants - logging_context["processing_state"] = score_set.processing_state.name - logger.info(msg="Finished creating variants in score set.", extra=logging_context) + job_manager.save_to_context( + { + "processing_state": score_set.processing_state.name, + "mapping_state": score_set.mapping_state.name, + "created_variants": score_set.num_variants, + } + ) - await redis.lpush(MAPPING_QUEUE_NAME, score_set.id) # type: ignore - await redis.enqueue_job("variant_mapper_manager", correlation_id, updater_id) - score_set.mapping_state = MappingState.queued finally: - db.add(score_set) - db.commit() - db.refresh(score_set) - logger.info(msg="Committed new variants to score set.", extra=logging_context) + job_manager.db.add(score_set) + job_manager.db.commit() + job_manager.db.refresh(score_set) + + job_manager.update_progress(100, 100, "Completed variant creation job.") + logger.info(msg="Committed new variants to score set.", extra=job_manager.logging_context()) - ctx["state"][ctx["job_id"]] = logging_context.copy() - return {"success": True} + return {"status": "ok", "data": {}, "exception_details": None} diff --git a/src/mavedb/worker/jobs/variant_processing/mapping.py b/src/mavedb/worker/jobs/variant_processing/mapping.py index 91c6f0fe..848c7b06 100644 --- a/src/mavedb/worker/jobs/variant_processing/mapping.py +++ b/src/mavedb/worker/jobs/variant_processing/mapping.py @@ -8,562 +8,308 @@ import asyncio import functools import logging -from contextlib import asynccontextmanager -from datetime import date, timedelta +from datetime import date from typing import Any -from arq import ArqRedis -from arq.jobs import Job, JobStatus from sqlalchemy import cast, null, select from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.orm import Session from mavedb.data_providers.services import vrs_mapper -from mavedb.lib.clingen.constants import CLIN_GEN_SUBMISSION_ENABLED from mavedb.lib.exceptions import ( - MappingEnqueueError, NonexistentMappingReferenceError, NonexistentMappingResultsError, - SubmissionEnqueueError, - UniProtIDMappingEnqueueError, + NonexistentMappingScoresError, ) from mavedb.lib.logging.context import format_raised_exception_info_as_dict from mavedb.lib.mapping import ANNOTATION_LAYERS -from mavedb.lib.slack import send_slack_error, send_slack_message -from mavedb.lib.uniprot.constants import UNIPROT_ID_MAPPING_ENABLED +from mavedb.lib.slack import send_slack_error from mavedb.models.enums.mapping_state import MappingState from mavedb.models.mapped_variant import MappedVariant from mavedb.models.score_set import ScoreSet +from mavedb.models.user import User from mavedb.models.variant import Variant -from mavedb.worker.jobs.utils.constants import MAPPING_BACKOFF_IN_SECONDS, MAPPING_CURRENT_ID_NAME, MAPPING_QUEUE_NAME -from mavedb.worker.jobs.utils.job_state import setup_job_state -from mavedb.worker.jobs.utils.retry import enqueue_job_with_backoff +from mavedb.worker.jobs.utils.setup import validate_job_params +from mavedb.worker.lib.decorators.pipeline_management import with_pipeline_management +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.types import JobResultData logger = logging.getLogger(__name__) -@asynccontextmanager -async def mapping_in_execution(redis: ArqRedis, job_id: str): - await redis.set(MAPPING_CURRENT_ID_NAME, job_id) - try: - yield - finally: - await redis.set(MAPPING_CURRENT_ID_NAME, "") +@with_pipeline_management +async def map_variants_for_score_set(ctx: dict, job_manager: JobManager) -> JobResultData: + """Map variants for a given score set using VRS.""" + # Get the job definition we are working on + job = job_manager.get_job() + + _job_required_params = [ + "score_set_id", + "correlation_id", + "updater_id", + ] + validate_job_params(job_manager, _job_required_params, job) + + # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. + score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one() # type: ignore + correlation_id = job.job_params["correlation_id"] # type: ignore + updater_id = job.job_params["updater_id"] # type: ignore + updated_by = job_manager.db.scalars(select(User).where(User.id == updater_id)).one() + + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "map_variants_for_score_set", + "resource": score_set.urn, + "correlation_id": correlation_id, + } + ) + job_manager.update_progress(0, 100, "Starting variant mapping job.") + logger.info(msg="Started variant mapping job", extra=job_manager.logging_context()) + # TODO#372: non-nullable URNs + if not score_set.urn: + raise ValueError("Score set URN is required for variant mapping.") -async def variant_mapper_manager(ctx: dict, correlation_id: str, updater_id: int, attempt: int = 1) -> dict: - logging_context = {} - mapping_job_id = None - mapping_job_status = None - queued_score_set = None + # Handle everything within try/except to persist appropriate mapping state try: - redis: ArqRedis = ctx["redis"] - db: Session = ctx["db"] - - logging_context = setup_job_state(ctx, updater_id, None, correlation_id) - logging_context["attempt"] = attempt - logger.debug(msg="Variant mapping manager began execution", extra=logging_context) - - queue_length = await redis.llen(MAPPING_QUEUE_NAME) # type: ignore - queued_id = await redis.rpop(MAPPING_QUEUE_NAME) # type: ignore - logging_context["variant_mapping_queue_length"] = queue_length - - # Setup the job id cache if it does not already exist. - if not await redis.exists(MAPPING_CURRENT_ID_NAME): - await redis.set(MAPPING_CURRENT_ID_NAME, "") - - if not queued_id: - logger.debug(msg="No mapping jobs exist in the queue.", extra=logging_context) - return {"success": True, "enqueued_job": None} - else: - queued_id = queued_id.decode("utf-8") - queued_score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == queued_id)).one() + # Setup score set state for mapping + score_set.mapping_state = MappingState.processing + score_set.mapping_errors = null() + score_set.modified_by = updated_by + score_set.modification_date = date.today() - logging_context["upcoming_mapping_resource"] = queued_score_set.urn - logger.debug(msg="Found mapping job(s) still in queue.", extra=logging_context) + job_manager.db.add(score_set) + job_manager.db.commit() - mapping_job_id = await redis.get(MAPPING_CURRENT_ID_NAME) - if mapping_job_id: - mapping_job_id = mapping_job_id.decode("utf-8") - mapping_job_status = (await Job(job_id=mapping_job_id, redis=redis).status()).value + job_manager.save_to_context({"mapping_state": score_set.mapping_state.name}) + job_manager.update_progress(10, 100, "Score set prepared for variant mapping.") + logger.debug(msg="Score set prepared for variant mapping.", extra=job_manager.logging_context()) - logging_context["existing_mapping_job_status"] = mapping_job_status - logging_context["existing_mapping_job_id"] = mapping_job_id + # Do not block Worker event loop during mapping, see: https://arq-docs.helpmanual.io/#synchronous-jobs. + vrs = vrs_mapper() + blocking = functools.partial(vrs.map_score_set, score_set.urn) + loop = asyncio.get_running_loop() - except Exception as e: - send_slack_error(e) + mapping_results = None - # Attempt to remove this item from the mapping queue. - try: - await redis.lrem(MAPPING_QUEUE_NAME, 1, queued_id) # type: ignore - logger.warning(msg="Removed un-queueable score set from the queue.", extra=logging_context) - except Exception: - pass + logger.debug(msg="Mapping variants using VRS mapping service.", extra=job_manager.logging_context()) + job_manager.update_progress(30, 100, "Mapping variants using VRS mapping service.") + mapping_results = await loop.run_in_executor(ctx["pool"], blocking) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error(msg="Variant mapper manager encountered an unexpected error during setup.", extra=logging_context) + logger.debug(msg="Done mapping variants.", extra=job_manager.logging_context()) + job_manager.update_progress(80, 100, "Processing mapped variants and updating database.") - return {"success": False, "enqueued_job": None} + ## Check our assumptions about mapping results and handle errors appropriately. Don't raise exceptions directly, + ## the try/except handling is intended for unexpected errors only. - new_job = None - new_job_id = None - try: - if not mapping_job_id or mapping_job_status in (JobStatus.not_found, JobStatus.complete): - logger.debug(msg="No mapping jobs are running, queuing a new one.", extra=logging_context) + # Ensure we have mapping results + if not mapping_results: + score_set.mapping_state = MappingState.failed + score_set.mapping_errors = {"error_message": "Mapping results were not returned from VRS mapping service."} + job_manager.db.add(score_set) + job_manager.db.commit() - new_job = await redis.enqueue_job( - "map_variants_for_score_set", correlation_id, queued_score_set.id, updater_id, attempt + job_manager.update_progress(100, 100, "Variant mapping failed due to missing results.") + job_manager.save_to_context({"mapping_state": score_set.mapping_state.name}) + logger.error( + msg="Mapping results were not returned from VRS mapping service.", extra=job_manager.logging_context() ) + return { + "status": "error", + "data": {}, + "exception_details": { + "message": "Mapping results were not returned from VRS mapping service.", + "type": NonexistentMappingResultsError.__name__, + "traceback": None, + }, + } - if new_job: - new_job_id = new_job.job_id - - logging_context["new_mapping_job_id"] = new_job_id - logger.info(msg="Queued a new mapping job.", extra=logging_context) - - return {"success": True, "enqueued_job": new_job_id} - - logger.info( - msg="A mapping job is already running, or a new job was unable to be enqueued. Deferring mapping by 5 minutes.", - extra=logging_context, - ) - - new_job = await redis.enqueue_job( - "variant_mapper_manager", - correlation_id, - updater_id, - attempt, - _defer_by=timedelta(minutes=5), - ) - - if new_job: - # Ensure this score set remains in the front of the queue. - queued_id = await redis.rpush(MAPPING_QUEUE_NAME, queued_score_set.id) # type: ignore - new_job_id = new_job.job_id - - logging_context["new_mapping_manager_job_id"] = new_job_id - logger.info(msg="Deferred a new mapping manager job.", extra=logging_context) - - # Our persistent Redis queue and ARQ's execution rules ensure that even if the worker is stopped and not restarted - # before the deferred time, these deferred jobs will still run once able. - return {"success": True, "enqueued_job": new_job_id} - - raise MappingEnqueueError() + # Ensure we have mapped scores + mapped_scores = mapping_results.get("mapped_scores") + if not mapped_scores: + score_set.mapping_state = MappingState.failed + score_set.mapping_errors = {"error_message": mapping_results.get("error_message")} + job_manager.db.add(score_set) + job_manager.db.commit() + + job_manager.update_progress(100, 100, "Variant mapping failed; no variants were mapped.") + job_manager.save_to_context({"mapping_state": score_set.mapping_state.name}) + logger.error(msg="No variants were mapped for this score set.", extra=job_manager.logging_context()) + return { + "status": "error", + "data": {}, + "exception_details": { + "message": "No variants were mapped for this score set.", + "type": NonexistentMappingScoresError.__name__, + "traceback": None, + }, + } - except Exception as e: - send_slack_error(e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="Variant mapper manager encountered an unexpected error while enqueing a mapping job. This job will not be retried.", - extra=logging_context, - ) + # Ensure we have reference metadata + reference_metadata = mapping_results.get("reference_sequences") + if not reference_metadata: + score_set.mapping_state = MappingState.failed + score_set.mapping_errors = {"error_message": "Reference metadata missing from mapping results."} + job_manager.db.add(score_set) + job_manager.db.commit() + + job_manager.update_progress(100, 100, "Variant mapping failed due to missing reference metadata.") + job_manager.save_to_context({"mapping_state": score_set.mapping_state.name}) + logger.error(msg="Reference metadata missing from mapping results.", extra=job_manager.logging_context()) + return { + "status": "error", + "data": {}, + "exception_details": { + "message": "Reference metadata missing from mapping results.", + "type": NonexistentMappingReferenceError.__name__, + "traceback": None, + }, + } - db.rollback() - - # We shouldn't rely on the passed score set id matching the score set we are operating upon. - if not queued_score_set: - return {"success": False, "enqueued_job": new_job_id} - - # Attempt to remove this item from the mapping queue. - try: - await redis.lrem(MAPPING_QUEUE_NAME, 1, queued_id) # type: ignore - logger.warning(msg="Removed un-queueable score set from the queue.", extra=logging_context) - except Exception: - pass - - score_set_exc = db.scalars(select(ScoreSet).where(ScoreSet.id == queued_score_set.id)).one_or_none() - if score_set_exc: - score_set_exc.mapping_state = MappingState.failed - score_set_exc.mapping_errors = "Unable to queue a new mapping job or defer score set mapping." - db.add(score_set_exc) - db.commit() - - return {"success": False, "enqueued_job": new_job_id} - - -async def map_variants_for_score_set( - ctx: dict, correlation_id: str, score_set_id: int, updater_id: int, attempt: int = 1 -) -> dict: - async with mapping_in_execution(redis=ctx["redis"], job_id=ctx["job_id"]): - logging_context = {} - score_set = None - try: - db: Session = ctx["db"] - redis: ArqRedis = ctx["redis"] - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one() - - logging_context = setup_job_state(ctx, updater_id, score_set.urn, correlation_id) - logging_context["attempt"] = attempt - logger.info(msg="Started variant mapping", extra=logging_context) - - score_set.mapping_state = MappingState.processing - score_set.mapping_errors = null() - db.add(score_set) - db.commit() - - mapping_urn = score_set.urn - assert mapping_urn, "A valid URN is needed to map this score set." - - logging_context["current_mapping_resource"] = mapping_urn - logging_context["mapping_state"] = score_set.mapping_state - logger.debug(msg="Fetched score set metadata for mapping job.", extra=logging_context) - - # Do not block Worker event loop during mapping, see: https://arq-docs.helpmanual.io/#synchronous-jobs. - vrs = vrs_mapper() - blocking = functools.partial(vrs.map_score_set, mapping_urn) - loop = asyncio.get_running_loop() - - except Exception as e: - send_slack_error(e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="Variant mapper encountered an unexpected error during setup. This job will not be retried.", - extra=logging_context, + # Process and store mapped variants + for target_gene_identifier in reference_metadata: + target_gene = next( + (target_gene for target_gene in score_set.target_genes if target_gene.name == target_gene_identifier), + None, ) - db.rollback() - if score_set: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} - db.add(score_set) - db.commit() + if not target_gene: + raise ValueError( + f"Target gene {target_gene_identifier} not found in database for score set {score_set.urn}." + ) - return {"success": False, "retried": False, "enqueued_jobs": []} + job_manager.save_to_context({"processing_target_gene": target_gene.id}) + logger.debug(f"Processing target gene {target_gene.name}.", extra=job_manager.logging_context()) - mapping_results = None - try: - mapping_results = await loop.run_in_executor(ctx["pool"], blocking) - logger.debug(msg="Done mapping variants.", extra=logging_context) + # allow for multiple annotation layers + pre_mapped_metadata: dict[str, Any] = {} + post_mapped_metadata: dict[str, Any] = {} + excluded_pre_mapped_keys = {"sequence"} - except Exception as e: - db.rollback() - score_set.mapping_errors = { - "error_message": f"Encountered an internal server error during mapping. Mapping will be automatically retried up to 5 times for this score set (attempt {attempt}/5)." - } - db.add(score_set) - db.commit() - - send_slack_error(e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.warning( - msg="Variant mapper encountered an unexpected error while mapping variants. This job will be retried.", - extra=logging_context, - ) + # add gene-level info + gene_info = reference_metadata[target_gene_identifier].get("gene_info") + if gene_info: + target_gene.mapped_hgnc_name = gene_info.get("hgnc_symbol") + post_mapped_metadata["hgnc_name_selection_method"] = gene_info.get("selection_method") + + job_manager.save_to_context({"mapped_hgnc_name": target_gene.mapped_hgnc_name}) + logger.debug("Added mapped HGNC name to target gene.", extra=job_manager.logging_context()) - new_job_id = None - max_retries_exceeded = None - try: - await redis.lpush(MAPPING_QUEUE_NAME, score_set.id) # type: ignore - new_job_id, max_retries_exceeded, backoff_time = await enqueue_job_with_backoff( - redis, "variant_mapper_manager", attempt, MAPPING_BACKOFF_IN_SECONDS, correlation_id, updater_id + # add annotation layer info + for annotation_layer in reference_metadata[target_gene_identifier]["layers"]: + layer_premapped = reference_metadata[target_gene_identifier]["layers"][annotation_layer].get( + "computed_reference_sequence" ) - # If we fail to enqueue a mapping manager for this score set, evict it from the queue. - if new_job_id is None: - await redis.lpop(MAPPING_QUEUE_NAME, score_set.id) # type: ignore - - logging_context["backoff_limit_exceeded"] = max_retries_exceeded - logging_context["backoff_deferred_in_seconds"] = backoff_time - logging_context["backoff_job_id"] = new_job_id - - except Exception as backoff_e: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} - db.add(score_set) - db.commit() - send_slack_error(backoff_e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(backoff_e)} - logger.critical( - msg="While attempting to re-enqueue a mapping job that exited in error, another exception was encountered. This score set will not be mapped.", - extra=logging_context, + if layer_premapped: + pre_mapped_metadata[ANNOTATION_LAYERS[annotation_layer]] = { + k: layer_premapped[k] for k in set(list(layer_premapped.keys())) - excluded_pre_mapped_keys + } + job_manager.save_to_context({"pre_mapped_layer_exists": True}) + + layer_postmapped = reference_metadata[target_gene_identifier]["layers"][annotation_layer].get( + "mapped_reference_sequence" ) - else: - if new_job_id and not max_retries_exceeded: - score_set.mapping_state = MappingState.queued - db.add(score_set) - db.commit() - logger.info( - msg="After encountering an error while mapping variants, another mapping job was queued.", - extra=logging_context, - ) - elif new_job_id is None and not max_retries_exceeded: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} - db.add(score_set) - db.commit() - logger.error( - msg="After encountering an error while mapping variants, another mapping job was unable to be queued. This score set will not be mapped.", - extra=logging_context, - ) - else: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} - db.add(score_set) - db.commit() - logger.error( - msg="After encountering an error while mapping variants, the maximum retries for this job were exceeded. This score set will not be mapped.", - extra=logging_context, - ) - finally: - return { - "success": False, - "retried": (not max_retries_exceeded and new_job_id is not None), - "enqueued_jobs": [job for job in [new_job_id] if job], - } - - try: - if mapping_results: - mapped_scores = mapping_results.get("mapped_scores") - if not mapped_scores: - # if there are no mapped scores, the score set failed to map. - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": mapping_results.get("error_message")} - else: - reference_metadata = mapping_results.get("reference_sequences") - if not reference_metadata: - raise NonexistentMappingReferenceError() - - for target_gene_identifier in reference_metadata: - target_gene = next( - ( - target_gene - for target_gene in score_set.target_genes - if target_gene.name == target_gene_identifier - ), - None, - ) - if not target_gene: - raise ValueError( - f"Target gene {target_gene_identifier} not found in database for score set {score_set.urn}." - ) - # allow for multiple annotation layers - pre_mapped_metadata: dict[str, Any] = {} - post_mapped_metadata: dict[str, Any] = {} - excluded_pre_mapped_keys = {"sequence"} - - gene_info = reference_metadata[target_gene_identifier].get("gene_info") - if gene_info: - target_gene.mapped_hgnc_name = gene_info.get("hgnc_symbol") - post_mapped_metadata["hgnc_name_selection_method"] = gene_info.get("selection_method") - - for annotation_layer in reference_metadata[target_gene_identifier]["layers"]: - layer_premapped = reference_metadata[target_gene_identifier]["layers"][ - annotation_layer - ].get("computed_reference_sequence") - if layer_premapped: - pre_mapped_metadata[ANNOTATION_LAYERS[annotation_layer]] = { - k: layer_premapped[k] - for k in set(list(layer_premapped.keys())) - excluded_pre_mapped_keys - } - layer_postmapped = reference_metadata[target_gene_identifier]["layers"][ - annotation_layer - ].get("mapped_reference_sequence") - if layer_postmapped: - post_mapped_metadata[ANNOTATION_LAYERS[annotation_layer]] = layer_postmapped - target_gene.pre_mapped_metadata = cast(pre_mapped_metadata, JSONB) - target_gene.post_mapped_metadata = cast(post_mapped_metadata, JSONB) - - total_variants = 0 - successful_mapped_variants = 0 - for mapped_score in mapped_scores: - total_variants += 1 - variant_urn = mapped_score.get("mavedb_id") - variant = db.scalars(select(Variant).where(Variant.urn == variant_urn)).one() - - # there should only be one current mapped variant per variant id, so update old mapped variant to current = false - existing_mapped_variant = ( - db.query(MappedVariant) - .filter(MappedVariant.variant_id == variant.id, MappedVariant.current.is_(True)) - .one_or_none() - ) - - if existing_mapped_variant: - existing_mapped_variant.current = False - db.add(existing_mapped_variant) - - if mapped_score.get("pre_mapped") and mapped_score.get("post_mapped"): - successful_mapped_variants += 1 - - mapped_variant = MappedVariant( - pre_mapped=mapped_score.get("pre_mapped", null()), - post_mapped=mapped_score.get("post_mapped", null()), - variant_id=variant.id, - modification_date=date.today(), - mapped_date=mapping_results["mapped_date_utc"], - vrs_version=mapped_score.get("vrs_version", null()), - mapping_api_version=mapping_results["dcd_mapping_version"], - error_message=mapped_score.get("error_message", null()), - current=True, - ) - db.add(mapped_variant) - - if successful_mapped_variants == 0: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "All variants failed to map"} - elif successful_mapped_variants < total_variants: - score_set.mapping_state = MappingState.incomplete - else: - score_set.mapping_state = MappingState.complete - - logging_context["mapped_variants_inserted_db"] = len(mapped_scores) - logging_context["variants_successfully_mapped"] = successful_mapped_variants - logging_context["mapping_state"] = score_set.mapping_state.name - logging_context["mapping_errors"] = score_set.mapping_errors - logger.info(msg="Inserted mapped variants into db.", extra=logging_context) - - else: - raise NonexistentMappingResultsError() - - db.add(score_set) - db.commit() - - except Exception as e: - db.rollback() - score_set.mapping_errors = { - "error_message": f"Encountered an unexpected error while parsing mapped variants. Mapping will be automatically retried up to 5 times for this score set (attempt {attempt}/5)." - } - db.add(score_set) - db.commit() - - send_slack_error(e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.warning( - msg="An unexpected error occurred during variant mapping. This job will be attempted again.", - extra=logging_context, - ) + if layer_postmapped: + post_mapped_metadata[ANNOTATION_LAYERS[annotation_layer]] = layer_postmapped + job_manager.save_to_context({"post_mapped_layer_exists": True}) - new_job_id = None - max_retries_exceeded = None - try: - await redis.lpush(MAPPING_QUEUE_NAME, score_set.id) # type: ignore - new_job_id, max_retries_exceeded, backoff_time = await enqueue_job_with_backoff( - redis, "variant_mapper_manager", attempt, MAPPING_BACKOFF_IN_SECONDS, correlation_id, updater_id - ) - # If we fail to enqueue a mapping manager for this score set, evict it from the queue. - if new_job_id is None: - await redis.lpop(MAPPING_QUEUE_NAME, score_set.id) # type: ignore - - logging_context["backoff_limit_exceeded"] = max_retries_exceeded - logging_context["backoff_deferred_in_seconds"] = backoff_time - logging_context["backoff_job_id"] = new_job_id - - except Exception as backoff_e: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} - send_slack_error(backoff_e) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(backoff_e)} - logger.critical( - msg="While attempting to re-enqueue a mapping job that exited in error, another exception was encountered. This score set will not be mapped.", - extra=logging_context, + logger.debug( + f"Added annotation layer mapping metadata for {annotation_layer}.", + extra=job_manager.logging_context(), ) - else: - if new_job_id and not max_retries_exceeded: - score_set.mapping_state = MappingState.queued - logger.info( - msg="After encountering an error while parsing mapped variants, another mapping job was queued.", - extra=logging_context, - ) - elif new_job_id is None and not max_retries_exceeded: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} - logger.error( - msg="After encountering an error while parsing mapped variants, another mapping job was unable to be queued. This score set will not be mapped.", - extra=logging_context, - ) - else: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "Encountered an internal server error during mapping"} - logger.error( - msg="After encountering an error while parsing mapped variants, the maximum retries for this job were exceeded. This score set will not be mapped.", - extra=logging_context, - ) - finally: - db.add(score_set) - db.commit() - return { - "success": False, - "retried": (not max_retries_exceeded and new_job_id is not None), - "enqueued_jobs": [job for job in [new_job_id] if job], - } - - new_uniprot_job_id = None - try: - if UNIPROT_ID_MAPPING_ENABLED: - new_job = await redis.enqueue_job( - "submit_uniprot_mapping_jobs_for_score_set", - score_set.id, - correlation_id, - ) - if new_job: - new_uniprot_job_id = new_job.job_id + target_gene.pre_mapped_metadata = cast(pre_mapped_metadata, JSONB) + target_gene.post_mapped_metadata = cast(post_mapped_metadata, JSONB) + job_manager.db.add(target_gene) + logger.debug("Added mapping metadata to target gene.", extra=job_manager.logging_context()) - logging_context["submit_uniprot_mapping_job_id"] = new_uniprot_job_id - logger.info(msg="Queued a new UniProt mapping job.", extra=logging_context) + total_variants = len(mapped_scores) + job_manager.save_to_context({"total_variants_to_process": total_variants}) + job_manager.update_progress(90, 100, "Storing mapped variants in database.") - else: - raise UniProtIDMappingEnqueueError() - else: - logger.warning( - msg="UniProt ID mapping is disabled, skipped submission of UniProt mapping jobs.", - extra=logging_context, - ) + successful_mapped_variants = 0 + for mapped_score in mapped_scores: + variant_urn = mapped_score.get("mavedb_id") + variant = job_manager.db.scalars(select(Variant).where(Variant.urn == variant_urn)).one() - except Exception as e: - send_slack_error(e) - send_slack_message( - f"Could not enqueue UniProt mapping job for score set {score_set.urn}. UniProt mappings for this score set should be submitted manually." - ) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="Mapped variant UniProt submission encountered an unexpected error while attempting to enqueue a mapping job. This job will not be retried.", - extra=logging_context, - ) + job_manager.save_to_context({"processing_variant": variant.id}) + logger.debug(f"Processing variant {variant.id}.", extra=job_manager.logging_context()) - return {"success": False, "retried": False, "enqueued_jobs": [job for job in [new_uniprot_job_id] if job]} - - new_clingen_job_id = None - try: - if CLIN_GEN_SUBMISSION_ENABLED: - new_job = await redis.enqueue_job( - "submit_score_set_mappings_to_car", - correlation_id, - score_set.id, + # there should only be one current mapped variant per variant id, so update old mapped variant to current = false + existing_mapped_variant = ( + job_manager.db.query(MappedVariant) + .filter(MappedVariant.variant_id == variant.id, MappedVariant.current.is_(True)) + .one_or_none() ) - if new_job: - new_clingen_job_id = new_job.job_id + if existing_mapped_variant: + job_manager.save_to_context({"existing_mapped_variant": existing_mapped_variant.id}) + existing_mapped_variant.current = False + job_manager.db.add(existing_mapped_variant) + logger.debug(msg="Set existing mapped variant to current = false.", extra=job_manager.logging_context()) + + if mapped_score.get("pre_mapped") and mapped_score.get("post_mapped"): + successful_mapped_variants += 1 + job_manager.save_to_context({"successful_mapped_variants": successful_mapped_variants}) + + mapped_variant = MappedVariant( + pre_mapped=mapped_score.get("pre_mapped", null()), + post_mapped=mapped_score.get("post_mapped", null()), + variant_id=variant.id, + modification_date=date.today(), + mapped_date=mapping_results["mapped_date_utc"], + vrs_version=mapped_score.get("vrs_version", null()), + mapping_api_version=mapping_results["dcd_mapping_version"], + error_message=mapped_score.get("error_message", null()), + current=True, + ) - logging_context["submit_clingen_variants_job_id"] = new_clingen_job_id - logger.info(msg="Queued a new ClinGen submission job.", extra=logging_context) + job_manager.db.add(mapped_variant) + logger.debug(msg="Added new mapped variant to session.", extra=job_manager.logging_context()) - else: - raise SubmissionEnqueueError() + if successful_mapped_variants == 0: + score_set.mapping_state = MappingState.failed + score_set.mapping_errors = {"error_message": "All variants failed to map"} + elif successful_mapped_variants < total_variants: + score_set.mapping_state = MappingState.incomplete else: - logger.warning( - msg="ClinGen submission is disabled, skipped submission of mapped variants to CAR and LDH.", - extra=logging_context, - ) + score_set.mapping_state = MappingState.complete + + job_manager.save_to_context( + { + "successful_mapped_variants": successful_mapped_variants, + "mapping_state": score_set.mapping_state.name, + "mapping_errors": score_set.mapping_errors, + "inserted_mapped_variants": len(mapped_scores), + } + ) + + job_manager.update_progress(100, 100, "Completed processing of mapped variants.") + logger.info(msg="Inserted mapped variants into db.", extra=job_manager.logging_context()) except Exception as e: send_slack_error(e) - send_slack_message( - f"Could not submit mappings to CAR and/or LDH mappings for score set {score_set.urn}. Mappings for this score set should be submitted manually." - ) - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error( - msg="Mapped variant ClinGen submission encountered an unexpected error while attempting to enqueue a submission job. This job will not be retried.", - extra=logging_context, - ) + logging_context = {**job_manager.logging_context(), **format_raised_exception_info_as_dict(e)} + logger.error(msg="Encountered an unexpected error while parsing mapped variants.", extra=logging_context) + + job_manager.db.rollback() + + score_set.mapping_state = MappingState.failed + if not score_set.mapping_errors: + score_set.mapping_errors = { + "error_message": f"Encountered an unexpected error while parsing mapped variants. This job will be retried up to {job.max_retries} times (this was attempt {job.retry_count})." + } + job_manager.update_progress(100, 100, "Variant mapping failed due to an unexpected error.") return { - "success": False, - "retried": False, - "enqueued_jobs": [job for job in [new_uniprot_job_id, new_clingen_job_id] if job], + "status": "error", + "data": {}, + "exception_details": {"message": str(e), "type": type(e).__name__, "traceback": None}, } - ctx["state"][ctx["job_id"]] = logging_context.copy() - return { - "success": True, - "retried": False, - "enqueued_jobs": [job for job in [new_uniprot_job_id, new_clingen_job_id] if job], - } + finally: + job_manager.db.add(score_set) + job_manager.db.commit() + + return {"status": "ok" if successful_mapped_variants > 0 else "error", "data": {}, "exception_details": None} diff --git a/src/mavedb/worker/jobs/variant_processing/py.typed b/src/mavedb/worker/jobs/variant_processing/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/mavedb/worker/lib/managers/py.typed b/src/mavedb/worker/lib/managers/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/tests/network/worker/test_clingen.py b/tests/network/worker/test_clingen.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/network/worker/test_gnomad.py b/tests/network/worker/test_gnomad.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/network/worker/test_uniprot.py b/tests/network/worker/test_uniprot.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/worker/lib/conftest_optional.py b/tests/worker/conftest_optional.py similarity index 100% rename from tests/worker/lib/conftest_optional.py rename to tests/worker/conftest_optional.py diff --git a/tests/worker/jobs/data_management/test_views.py b/tests/worker/jobs/data_management/test_views.py new file mode 100644 index 00000000..b9962163 --- /dev/null +++ b/tests/worker/jobs/data_management/test_views.py @@ -0,0 +1,288 @@ +# ruff: noqa: E402 + +import pytest + +from mavedb.models.pipeline import Pipeline +from mavedb.models.published_variant import PublishedVariantsMV + +pytest.importorskip("arq") # Skip tests if arq is not installed + +from unittest.mock import call, patch + +from sqlalchemy import select + +from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus +from mavedb.models.job_run import JobRun +from mavedb.worker.jobs.data_management.views import refresh_materialized_views, refresh_published_variants_view +from tests.helpers.transaction_spy import TransactionSpy + +############################################################################################################################################ +# refresh_materialized_views +############################################################################################################################################ + + +@pytest.mark.asyncio +@pytest.mark.unit +class TestRefreshMaterializedViewsUnit: + """Unit tests for the refresh_materialized_views function.""" + + async def test_refresh_materialized_views_calls_refresh_function(self, mock_worker_ctx, mock_job_manager): + """Test that refresh_materialized_views calls the refresh function.""" + with ( + patch("mavedb.worker.jobs.data_management.views.refresh_all_mat_views") as mock_refresh, + TransactionSpy.spy(mock_job_manager.db, expect_commit=True), + ): + result = await refresh_materialized_views(mock_worker_ctx, 999, job_manager=mock_job_manager) + + mock_refresh.assert_called_once_with(mock_job_manager.db) + assert result == {"status": "ok", "data": {}, "exception_details": None} + + async def test_refresh_materialized_views_updates_progress(self, mock_worker_ctx, mock_job_manager): + """Test that refresh_materialized_views updates progress correctly.""" + with ( + patch("mavedb.worker.jobs.data_management.views.refresh_all_mat_views"), + patch.object(mock_job_manager, "update_progress", return_value=None) as mock_update_progress, + TransactionSpy.spy(mock_job_manager.db, expect_commit=True), + ): + result = await refresh_materialized_views(mock_worker_ctx, 999, job_manager=mock_job_manager) + + expected_calls = [ + call(0, 100, "Starting refresh of all materialized views."), + call(100, 100, "Completed refresh of all materialized views."), + ] + mock_update_progress.assert_has_calls(expected_calls) + assert result == {"status": "ok", "data": {}, "exception_details": None} + + +@pytest.mark.asyncio +@pytest.mark.integration +class TestRefreshMaterializedViewsIntegration: + """Integration tests for the refresh_materialized_views function and decorator logic.""" + + async def test_refresh_materialized_views_integration(self, standalone_worker_context, session): + """Integration test that runs refresh_materialized_views end-to-end.""" + + # Flush will be called implicitly when the transaction is committed + with TransactionSpy.spy(session, expect_flush=True, expect_commit=True): + result = await refresh_materialized_views(standalone_worker_context) + + job = session.execute( + select(JobRun).where(JobRun.job_function == "refresh_materialized_views") + ).scalar_one_or_none() + assert job is not None + assert job.status == JobStatus.SUCCEEDED + assert job.job_type == "cron_job" + + assert result == {"status": "ok", "data": {}, "exception_details": None} + + async def test_refresh_materialized_views_handles_exceptions(self, standalone_worker_context, session): + """Integration test that ensures exceptions during refresh are handled properly.""" + + with ( + patch( + "mavedb.worker.jobs.data_management.views.refresh_all_mat_views", + side_effect=Exception("Test exception during refresh"), + ), + TransactionSpy.spy(session, expect_rollback=True, expect_flush=True, expect_commit=True), + ): + result = await refresh_materialized_views(standalone_worker_context) + + job = session.execute( + select(JobRun).where(JobRun.job_function == "refresh_materialized_views") + ).scalar_one_or_none() + + assert job is not None + assert job.status == JobStatus.FAILED + assert job.job_type == "cron_job" + assert job.error_message == "Test exception during refresh" + assert result["exception_details"]["message"] == "Test exception during refresh" + + +@pytest.mark.asyncio +@pytest.mark.integration +class TestRefreshMaterializedViewsArqContext: + """Integration tests for refresh_materialized_views within an ARQ worker context.""" + + async def test_refresh_materialized_views_arq_integration( + self, arq_redis, arq_worker, standalone_worker_context, session + ): + """Integration test that runs refresh_materialized_views end-to-end using ARQ context.""" + await arq_redis.enqueue_job("refresh_materialized_views") + await arq_worker.async_run() + await arq_worker.run_check() + + job = session.execute( + select(JobRun).where(JobRun.job_function == "refresh_materialized_views") + ).scalar_one_or_none() + assert job is not None + assert job.status == JobStatus.SUCCEEDED + assert job.job_type == "cron_job" + + +############################################################################################################################################ +# refresh_published_variants_view +############################################################################################################################################ + + +@pytest.mark.asyncio +@pytest.mark.unit +class TestRefreshPublishedVariantsViewUnit: + """Unit tests for the refresh_published_variants_view function.""" + + async def test_refresh_published_variants_view_calls_refresh_function( + self, mock_worker_ctx, mock_job_manager, mock_job_run + ): + """Test that refresh_published_variants_view calls the refresh function.""" + mock_job_run.job_params = {"correlation_id": "test-corr-id"} + + with ( + patch.object(PublishedVariantsMV, "refresh") as mock_refresh, + patch("mavedb.worker.jobs.data_management.views.validate_job_params"), + TransactionSpy.spy(mock_job_manager.db, expect_commit=True), + ): + result = await refresh_published_variants_view(mock_worker_ctx, 999, job_manager=mock_job_manager) + + mock_refresh.assert_called_once_with(mock_job_manager.db) + assert result == {"status": "ok", "data": {}, "exception_details": None} + + async def test_refresh_published_variants_view_updates_progress( + self, mock_worker_ctx, mock_job_manager, mock_job_run + ): + """Test that refresh_published_variants_view updates progress correctly.""" + mock_job_run.job_params = {"correlation_id": "test-corr-id"} + + with ( + patch.object(PublishedVariantsMV, "refresh"), + patch("mavedb.worker.jobs.data_management.views.validate_job_params"), + patch.object(mock_job_manager, "update_progress", return_value=None) as mock_update_progress, + TransactionSpy.spy(mock_job_manager.db, expect_commit=True), + ): + result = await refresh_published_variants_view(mock_worker_ctx, 999, job_manager=mock_job_manager) + + expected_calls = [ + call(0, 100, "Starting refresh of published variants materialized view."), + call(100, 100, "Completed refresh of published variants materialized view."), + ] + mock_update_progress.assert_has_calls(expected_calls) + assert result == {"status": "ok", "data": {}, "exception_details": None} + + +@pytest.mark.asyncio +@pytest.mark.integration +class TestRefreshPublishedVariantsViewIntegration: + """Integration tests for the refresh_published_variants_view function and decorator logic.""" + + @pytest.fixture() + def setup_refresh_job_run(self, session): + """Add a refresh_published_variants_view job run to the DB before each test.""" + job_run = JobRun( + job_type="data_management", + job_function="refresh_published_variants_view", + status=JobStatus.PENDING, + job_params={"correlation_id": "test-corr-id"}, + ) + session.add(job_run) + session.commit() + return job_run + + async def test_refresh_published_variants_view_integration_standalone( + self, standalone_worker_context, session, setup_refresh_job_run + ): + """Integration test that runs refresh_published_variants_view end-to-end.""" + # Flush will be called implicitly when the transaction is committed + with TransactionSpy.spy(session, expect_flush=True, expect_commit=True): + result = await refresh_published_variants_view(standalone_worker_context, setup_refresh_job_run.id) + + session.refresh(setup_refresh_job_run) + assert setup_refresh_job_run.status == JobStatus.SUCCEEDED + assert result == {"status": "ok", "data": {}, "exception_details": None} + + async def test_refresh_published_variants_view_integration_pipeline( + self, standalone_worker_context, session, setup_refresh_job_run + ): + """Integration test that runs refresh_published_variants_view end-to-end.""" + # Create a pipeline for the job run and associate it + pipeline = Pipeline( + name="Test Pipeline for Published Variants View Refresh", + ) + session.add(pipeline) + session.commit() + session.refresh(pipeline) + setup_refresh_job_run.pipeline_id = pipeline.id + session.add(setup_refresh_job_run) + session.commit() + + # Flush will be called implicitly when the transaction is committed + with TransactionSpy.spy(session, expect_flush=True, expect_commit=True): + result = await refresh_published_variants_view(standalone_worker_context, setup_refresh_job_run.id) + + session.refresh(setup_refresh_job_run) + assert setup_refresh_job_run.status == JobStatus.SUCCEEDED + assert result == {"status": "ok", "data": {}, "exception_details": None} + session.refresh(pipeline) + assert pipeline.status == PipelineStatus.SUCCEEDED + + async def test_refresh_published_variants_view_handles_exceptions( + self, standalone_worker_context, session, setup_refresh_job_run + ): + """Integration test that ensures exceptions during refresh are handled properly.""" + with ( + patch.object( + PublishedVariantsMV, + "refresh", + side_effect=Exception("Test exception during published variants view refresh"), + ), + TransactionSpy.spy(session, expect_rollback=True, expect_flush=True, expect_commit=True), + ): + result = await refresh_published_variants_view(standalone_worker_context, setup_refresh_job_run.id) + + session.refresh(setup_refresh_job_run) + assert setup_refresh_job_run.status == JobStatus.FAILED + assert setup_refresh_job_run.error_message == "Test exception during published variants view refresh" + assert result["exception_details"]["message"] == "Test exception during published variants view refresh" + + async def test_refresh_published_variants_view_requires_params( + self, setup_refresh_job_run, standalone_worker_context, session + ): + """Integration test that ensures required job params are validated.""" + setup_refresh_job_run.job_params = {} # Clear required params + session.add(setup_refresh_job_run) + session.commit() + + with TransactionSpy.spy(session, expect_rollback=True, expect_flush=True, expect_commit=True): + result = await refresh_published_variants_view(standalone_worker_context, setup_refresh_job_run.id) + + session.refresh(setup_refresh_job_run) + assert setup_refresh_job_run.status == JobStatus.FAILED + assert "Job has no job_params defined" in setup_refresh_job_run.error_message + assert "Job has no job_params defined" in result["exception_details"]["message"] + + +@pytest.mark.asyncio +@pytest.mark.integration +class TestRefreshPublishedVariantsViewArqContext: + """Integration tests for refresh_published_variants_view within an ARQ worker context.""" + + @pytest.fixture() + def setup_refresh_job_run(self, session): + """Add a refresh_published_variants_view job run to the DB before each test.""" + job_run = JobRun( + job_type="data_management", + job_function="refresh_published_variants_view", + status=JobStatus.PENDING, + job_params={"correlation_id": "test-corr-id"}, + ) + session.add(job_run) + session.commit() + return job_run + + async def test_refresh_published_variants_view_arq_integration( + self, arq_redis, arq_worker, standalone_worker_context, session, setup_refresh_job_run + ): + """Integration test that runs refresh_published_variants_view end-to-end using ARQ context.""" + await arq_redis.enqueue_job("refresh_published_variants_view", setup_refresh_job_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + session.refresh(setup_refresh_job_run) + assert setup_refresh_job_run.status == JobStatus.SUCCEEDED diff --git a/tests/worker/jobs/external_services/test_clingen.py b/tests/worker/jobs/external_services/test_clingen.py index 28432297..add6d0b1 100644 --- a/tests/worker/jobs/external_services/test_clingen.py +++ b/tests/worker/jobs/external_services/test_clingen.py @@ -1,38 +1,31 @@ # ruff: noqa: E402 -from asyncio.unix_events import _UnixSelectorEventLoop -from unittest.mock import patch +from unittest.mock import MagicMock, call, patch from uuid import uuid4 import pytest -from sqlalchemy import select + +from mavedb.models.enums.job_pipeline import JobStatus +from mavedb.models.job_run import JobRun +from mavedb.worker.lib.managers.job_manager import JobManager arq = pytest.importorskip("arq") +from sqlalchemy.exc import NoResultFound + from mavedb.lib.clingen.services import ( ClinGenAlleleRegistryService, - ClinGenLdhService, - clingen_allele_id_from_ldh_variation, ) from mavedb.models.mapped_variant import MappedVariant from mavedb.models.score_set import ScoreSet as ScoreSetDbModel -from mavedb.models.variant import Variant from mavedb.worker.jobs import ( - link_clingen_variants, submit_score_set_mappings_to_car, - submit_score_set_mappings_to_ldh, ) from tests.helpers.constants import ( TEST_CLINGEN_ALLELE_OBJECT, - TEST_CLINGEN_LDH_LINKING_RESPONSE, - TEST_CLINGEN_SUBMISSION_BAD_RESQUEST_RESPONSE, - TEST_CLINGEN_SUBMISSION_RESPONSE, - TEST_CLINGEN_SUBMISSION_UNAUTHORIZED_RESPONSE, TEST_MINIMAL_SEQ_SCORESET, ) -from tests.helpers.util.exceptions import awaitable_exception from tests.helpers.util.setup.worker import ( - setup_records_files_and_variants, setup_records_files_and_variants_with_mapping, ) @@ -42,838 +35,484 @@ @pytest.mark.asyncio -async def test_submit_score_set_mappings_to_car_success( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]), - patch( - "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", - "https://reg.test.genome.network/pytest", - ), - ): - result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) - - mapped_variants_with_caid_for_score_set = session.scalars( - select(MappedVariant) - .join(Variant) - .join(ScoreSetDbModel) - .filter(ScoreSetDbModel.urn == score_set.urn, MappedVariant.clingen_allele_id.is_not(None)) - ).all() - - assert len(mapped_variants_with_caid_for_score_set) == score_set.num_variants - - assert result["success"] - assert not result["retried"] - assert result["enqueued_job"] is not None - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_car_exception_in_setup( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch( - "mavedb.worker.jobs.external_services.clingen.setup_job_state", - side_effect=Exception(), - ): - result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_car_no_variants_exist( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_car_exception_in_hgvs_dict_creation( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch( - "mavedb.worker.jobs.external_services.clingen.get_hgvs_from_post_mapped", - side_effect=Exception(), - ): - result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_car_exception_during_submission( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", side_effect=Exception()), - patch( - "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", - "https://reg.test.genome.network/pytest", - ), - ): - result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_car_exception_in_allele_association( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch("mavedb.worker.jobs.external_services.clingen.get_allele_registry_associations", side_effect=Exception()), - patch( - "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", - "https://reg.test.genome.network/pytest", - ), - ): - result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_car_exception_during_ldh_enqueue( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch( - "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", - "https://reg.test.genome.network/pytest", - ), - patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]), - patch.object(arq.ArqRedis, "enqueue_job", side_effect=Exception()), - ): - result = await submit_score_set_mappings_to_car(standalone_worker_context, uuid4().hex, score_set.id) - - mapped_variants_with_caid_for_score_set = session.scalars( - select(MappedVariant) - .join(Variant) - .join(ScoreSetDbModel) - .filter(ScoreSetDbModel.urn == score_set.urn, MappedVariant.clingen_allele_id.is_not(None)) - ).all() - - assert len(mapped_variants_with_caid_for_score_set) == score_set.num_variants - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -############################################################################################################################################ -# ClinGen LDH Submission -############################################################################################################################################ - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_success( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_submission_job(): - return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_submission_job(), - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert result["success"] - assert not result["retried"] - assert result["enqueued_job"] is not None - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_exception_in_setup( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch( - "mavedb.worker.jobs.external_services.clingen.setup_job_state", - side_effect=Exception(), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_exception_in_auth( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch.object( - ClinGenLdhService, - "_existing_jwt", - side_effect=Exception(), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_no_variants_exist( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_exception_in_hgvs_generation( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch( - "mavedb.worker.jobs.external_services.clingen.get_hgvs_from_post_mapped", - side_effect=Exception(), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_exception_in_ldh_submission_construction( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch( - "mavedb.lib.clingen.content_constructors.construct_ldh_submission", - side_effect=Exception(), +@pytest.mark.unit +class TestSubmitScoreSetMappingsToCARUnit: + """Tests for the submit_score_set_mappings_to_car function.""" + + @pytest.mark.parametrize("missing_param", ["score_set_id", "correlation_id"]) + async def test_submit_score_set_mappings_to_car_required_params( + self, + mock_job_manager, + mock_job_run, + mock_worker_ctx, + missing_param, ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) + """Test that submitting a non-existent score set raises an exception.""" - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] + mock_job_run.job_params = {"score_set_id": 99, "correlation_id": uuid4().hex} + del mock_job_run.job_params[missing_param] -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_exception_during_submission( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def failed_submission_job(): - return Exception() - - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - side_effect=failed_submission_job(), - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - + with pytest.raises(ValueError): + await submit_score_set_mappings_to_car(mock_worker_ctx, 99, job_manager=mock_job_manager) -@pytest.mark.asyncio -@pytest.mark.parametrize( - "error_response", [TEST_CLINGEN_SUBMISSION_BAD_RESQUEST_RESPONSE, TEST_CLINGEN_SUBMISSION_UNAUTHORIZED_RESPONSE] -) -async def test_submit_score_set_mappings_to_ldh_submission_failures_exist( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis, error_response -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_submission_job(): - return [None, error_response] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_submission_job(), - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), + async def test_submit_score_set_mappings_to_car_raises_when_no_score_set( + self, + mock_job_manager, + mock_job_run, + mock_worker_ctx, ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_exception_during_linking_enqueue( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_submission_job(): - return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_submission_job(), - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - patch.object(arq.ArqRedis, "enqueue_job", side_effect=Exception()), + """Test that submitting a non-existent score set raises an exception.""" + + mock_job_run.job_params = {"score_set_id": 99, "correlation_id": uuid4().hex} + + with ( + pytest.raises(NoResultFound), + patch.object(mock_job_manager.db, "scalars", side_effect=NoResultFound()), + patch.object(mock_job_manager, "update_progress", return_value=None), + patch("mavedb.worker.jobs.external_services.clingen.validate_job_params", return_value=None), + ): + await submit_score_set_mappings_to_car(mock_worker_ctx, 99, job_manager=mock_job_manager) + + async def test_submit_score_set_mappings_to_car_no_mapped_variants( + self, + mock_job_manager, + mock_job_run, + mock_worker_ctx, ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_submit_score_set_mappings_to_ldh_linking_not_queued_when_expected( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_submission_job(): - return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_submission_job(), - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - patch.object(arq.ArqRedis, "enqueue_job", return_value=None), + """Test that submitting a score set with no mapped variants completes successfully.""" + + mock_job_run.job_params = {"score_set_id": 1, "correlation_id": uuid4().hex} + + with ( + patch.object( + mock_job_manager.db, + "scalars", + return_value=MagicMock(one=MagicMock(spec=ScoreSetDbModel, urn="urn:1", num_variants=0)), + ), + patch.object( + mock_job_manager.db, + "execute", + return_value=MagicMock(all=lambda: []), + ), + patch("mavedb.worker.jobs.external_services.clingen.validate_job_params", return_value=None), + patch.object(mock_job_manager, "update_progress", return_value=None), + ): + result = await submit_score_set_mappings_to_car(mock_worker_ctx, 1, job_manager=mock_job_manager) + + assert result["status"] == "ok" + + async def test_submit_score_set_mappings_to_car_no_variants_updates_progress( + self, + mock_job_manager, + mock_job_run, + mock_worker_ctx, ): - result = await submit_score_set_mappings_to_ldh(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -############################################################################################################################################## -## ClinGen Linkage -############################################################################################################################################## - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_success( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_linking_job(): - return [ - (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() + """Test that submitting a score set with no variants updates progress to 100%.""" + + mock_job_run.job_params = {"score_set_id": 1, "correlation_id": uuid4().hex} + + with ( + patch.object( + mock_job_manager.db, + "scalars", + return_value=MagicMock(one=MagicMock(spec=ScoreSetDbModel, urn="urn:1", num_variants=0)), + ), + patch.object( + mock_job_manager.db, + "execute", + return_value=MagicMock(all=lambda: []), + ), + patch("mavedb.worker.jobs.external_services.clingen.validate_job_params", return_value=None), + patch.object(mock_job_manager, "update_progress", return_value=None) as mock_update_progress, + ): + await submit_score_set_mappings_to_car(mock_worker_ctx, 1, job_manager=mock_job_manager) + + expected_calls = [ + call(0, 100, "Starting CAR mapped resource submission."), + call(100, 100, "No mapped variants to submit to CAR. Skipped submission."), ] + mock_update_progress.assert_has_calls(expected_calls) - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_linking_job(), - ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert result["success"] - assert not result["retried"] - assert result["enqueued_job"] - - for variant in session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) + async def test_submit_score_set_mappings_to_car_no_submission_endpoint( + self, + mock_job_manager, + mock_job_run, + mock_worker_ctx, ): - assert variant.clingen_allele_id == clingen_allele_id_from_ldh_variation(TEST_CLINGEN_LDH_LINKING_RESPONSE) - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_exception_in_setup( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch( - "mavedb.worker.jobs.external_services.clingen.setup_job_state", - side_effect=Exception(), - ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - for variant in session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ): - assert variant.clingen_allele_id is None - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_no_variants_to_link( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_exception_during_linkage( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - side_effect=Exception(), + """Test that submitting a score set with no CAR submission endpoint configured raises an exception.""" + + mock_job_run.job_params = {"score_set_id": 1, "correlation_id": uuid4().hex} + + with ( + patch.object( + mock_job_manager.db, + "scalars", + return_value=MagicMock(one=MagicMock(spec=ScoreSetDbModel, urn="urn:1", num_variants=1)), + ), + patch.object( + mock_job_manager.db, + "execute", + return_value=MagicMock(all=lambda: [(999, {}), (1000, {})]), + ), + patch("mavedb.worker.jobs.external_services.clingen.validate_job_params", return_value=None), + patch.object(mock_job_manager, "update_progress", return_value=None), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", None), + pytest.raises(ValueError), + ): + await submit_score_set_mappings_to_car(mock_worker_ctx, 1, job_manager=mock_job_manager) + + async def test_submit_score_set_mappings_to_car_no_variants_associated( + self, + mock_job_manager, + mock_job_run, + mock_worker_ctx, ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_exception_while_parsing_linkages( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_linking_job(): - return [ - (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_linking_job(), - ), - patch( - "mavedb.worker.jobs.external_services.clingen.clingen_allele_id_from_ldh_variation", - side_effect=Exception(), - ), + """Test that submitting a score set with no variants associated completes successfully.""" + + mock_job_run.job_params = {"score_set_id": 1, "correlation_id": uuid4().hex} + + mocked_score_set = MagicMock(spec=ScoreSetDbModel, urn="urn:1", num_variants=2) + mocked_mapped_variant_with_hgvs = MagicMock(spec=MappedVariant, id=1000, clingen_allele_id=None) + + with ( + # db.scalars is called twice in this function: once to get the score set (one), once to get the mapped variants (all) + patch.object( + mock_job_manager.db, + "scalars", + return_value=MagicMock( + one=mocked_score_set, + all=lambda: [mocked_mapped_variant_with_hgvs], + ), + ), + # db.execute is called to get the mapped variant IDs and post mapped data + patch.object(mock_job_manager.db, "execute", return_value=MagicMock(all=lambda: [(999, {}), (1000, {})])), + # get_hgvs_from_post_mapped is called twice, once for each mapped variant. mock that both + # calls return valid HGVS strings. + patch( + "mavedb.worker.jobs.external_services.clingen.get_hgvs_from_post_mapped", + side_effect=["c.122G>C", "c.123A>T"], + ), + # validate_job_params is called to validate job parameters + patch("mavedb.worker.jobs.external_services.clingen.validate_job_params", return_value=None), + # update_progress is called multiple times to update job progress + patch.object(mock_job_manager, "update_progress", return_value=None), + # CAR_SUBMISSION_ENDPOINT is patched to a test URL + patch( + "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", + "https://reg.test.genome.network/pytest", + ), + # Mock the dispatch_submissions method to return a test ClinGen allele object, which we should associate with the variant + patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[]), + # Mock the get_allele_registry_associations function to return a mapping from HGVS to CAID + patch( + "mavedb.worker.jobs.external_services.clingen.get_allele_registry_associations", + return_value={}, + ), + patch.object(mock_job_manager.db, "add", return_value=None) as mock_db_add, + ): + result = await submit_score_set_mappings_to_car(mock_worker_ctx, 1, job_manager=mock_job_manager) + + # Assert no CAID was not added to the variant + mock_db_add.assert_not_called() + assert mocked_mapped_variant_with_hgvs.clingen_allele_id is None + assert result["status"] == "ok" + + async def test_submit_score_set_mappings_to_car_no_variants_found_in_db( + self, + mock_job_manager, + mock_job_run, + mock_worker_ctx, ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_failures_exist_but_do_not_eclipse_retry_threshold( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_linking_job(): - return [ - (variant_urn, None) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_linking_job(), - ), - patch( - "mavedb.worker.jobs.external_services.clingen.LINKED_DATA_RETRY_THRESHOLD", - 2, - ), + """Test that submitting a score set with no mapped variants found in the db completes successfully.""" + + mock_job_run.job_params = {"score_set_id": 1, "correlation_id": uuid4().hex} + + mocked_score_set = MagicMock(spec=ScoreSetDbModel, urn="urn:1", num_variants=2) + mocked_mapped_variant_with_hgvs = MagicMock(spec=MappedVariant, id=1000, clingen_allele_id=None) + + with ( + # db.scalars is called twice in this function: once to get the score set (one), twice to get the mapped variants (all) + patch.object( + mock_job_manager.db, + "scalars", + return_value=MagicMock( + one=mocked_score_set, + all=lambda: [], + ), + ), + # db.execute is called to get the mapped variant IDs and post mapped data + patch.object(mock_job_manager.db, "execute", return_value=MagicMock(all=lambda: [(999, {}), (1000, {})])), + # get_hgvs_from_post_mapped is called twice, once for each mapped variant. mock that both + # calls return valid HGVS strings. + patch( + "mavedb.worker.jobs.external_services.clingen.get_hgvs_from_post_mapped", + side_effect=["c.122G>C", "c.123A>T"], + ), + # validate_job_params is called to validate job parameters + patch("mavedb.worker.jobs.external_services.clingen.validate_job_params", return_value=None), + # update_progress is called multiple times to update job progress + patch.object(mock_job_manager, "update_progress", return_value=None), + # CAR_SUBMISSION_ENDPOINT is patched to a test URL + patch( + "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", + "https://reg.test.genome.network/pytest", + ), + # Mock the dispatch_submissions method to return a test ClinGen allele object, which we should associate with the variant + patch.object( + ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT] + ), + # Mock the get_allele_registry_associations function to return a mapping from HGVS to CAID + patch( + "mavedb.worker.jobs.external_services.clingen.get_allele_registry_associations", + return_value={"c.122G>C": "CAID:0000000", "c.123A>T": "CAID:0000001"}, + ), + patch.object(mock_job_manager.db, "add", return_value=None) as mock_db_add, + ): + result = await submit_score_set_mappings_to_car(mock_worker_ctx, 1, job_manager=mock_job_manager) + + # Assert no CAID was not added to the variant + mock_db_add.assert_not_called() + assert mocked_mapped_variant_with_hgvs.clingen_allele_id is None + assert result["status"] == "ok" + + async def test_submit_score_set_mappings_to_car_skips_submission_for_variants_without_hgvs_string( + self, + mock_job_manager, + mock_job_run, + mock_worker_ctx, ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert result["success"] - assert not result["retried"] - assert result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_failures_exist_and_eclipse_retry_threshold( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_linking_job(): - return [ - (variant_urn, None) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_linking_job(), - ), - patch( - "mavedb.worker.jobs.external_services.clingen.LINKED_DATA_RETRY_THRESHOLD", - 1, - ), - patch( - "mavedb.worker.jobs.external_services.clingen.LINKING_BACKOFF_IN_SECONDS", - 0, - ), + """Test that submitting a score set with mapped variants completes successfully but skips variants without an HGVS string.""" + + mock_job_run.job_params = {"score_set_id": 1, "correlation_id": uuid4().hex} + + mocked_score_set = MagicMock(spec=ScoreSetDbModel, urn="urn:1", num_variants=2) + mocked_mapped_variant_with_hgvs = MagicMock(spec=MappedVariant, id=1000) + + with ( + # db.scalars is called twice in this function: once to get the score set (one), once to get the mapped variants (all) + patch.object( + mock_job_manager.db, + "scalars", + return_value=MagicMock( + one=mocked_score_set, + all=lambda: [mocked_mapped_variant_with_hgvs], + ), + ), + # db.execute is called to get the mapped variant IDs and post mapped data + patch.object(mock_job_manager.db, "execute", return_value=MagicMock(all=lambda: [(999, {}), (1000, {})])), + # get_hgvs_from_post_mapped is called twice, once for each mapped variant. mock that the first + # call returns None (no HGVS), the second returns a valid HGVS string. + patch( + "mavedb.worker.jobs.external_services.clingen.get_hgvs_from_post_mapped", + side_effect=[None, "c.123A>T"], + ), + # validate_job_params is called to validate job parameters + patch("mavedb.worker.jobs.external_services.clingen.validate_job_params", return_value=None), + # update_progress is called multiple times to update job progress + patch.object(mock_job_manager, "update_progress", return_value=None), + # CAR_SUBMISSION_ENDPOINT is patched to a test URL + patch( + "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", + "https://reg.test.genome.network/pytest", + ), + # Mock the dispatch_submissions method to return a test ClinGen allele object, which we should associate with the variant + patch.object( + ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT] + ), + # Mock the get_allele_registry_associations function to return a mapping from HGVS to CAID + patch( + "mavedb.worker.jobs.external_services.clingen.get_allele_registry_associations", + return_value={"c.123A>T": "CAID:0000001"}, + ), + patch.object(mock_job_manager.db, "add", return_value=None) as mock_db_add, + ): + result = await submit_score_set_mappings_to_car(mock_worker_ctx, 1, job_manager=mock_job_manager) + + # Assert the variant without an HGVS string was skipped, and the other variant was updated with the CAID + mock_db_add.assert_has_calls([call(mocked_mapped_variant_with_hgvs)]) + assert mocked_mapped_variant_with_hgvs.clingen_allele_id == "CAID:0000001" + assert result["status"] == "ok" + + async def test_submit_score_set_mappings_to_car_success( + self, + mock_job_manager, + mock_job_run, + mock_worker_ctx, ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert not result["success"] - assert result["retried"] - assert result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_failures_exist_and_eclipse_retry_threshold_cant_enqueue( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_linking_job(): - return [ - (variant_urn, None) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_linking_job(), - ), - patch( - "mavedb.worker.jobs.external_services.clingen.LINKED_DATA_RETRY_THRESHOLD", - 1, - ), - patch.object(arq.ArqRedis, "enqueue_job", return_value=awaitable_exception()), + """Test that submitting a score set with mapped variants completes successfully.""" + + mock_job_run.job_params = {"score_set_id": 1, "correlation_id": uuid4().hex} + + mocked_score_set = MagicMock(spec=ScoreSetDbModel, urn="urn:1", num_variants=2) + mocked_mapped_variant_with_hgvs_999 = MagicMock(spec=MappedVariant, id=999) + mocked_mapped_variant_with_hgvs_1000 = MagicMock(spec=MappedVariant, id=1000) + + with ( + # db.scalars is called three times in this function: once to get the score set (one), twice to get the mapped variants (all) + patch.object( + mock_job_manager.db, + "scalars", + return_value=MagicMock( + one=mocked_score_set, + all=MagicMock( + side_effect=[[mocked_mapped_variant_with_hgvs_999], [mocked_mapped_variant_with_hgvs_1000]] + ), + ), + ), + # db.execute is called to get the mapped variant IDs and post mapped data + patch.object(mock_job_manager.db, "execute", return_value=MagicMock(all=lambda: [(999, {}), (1000, {})])), + # get_hgvs_from_post_mapped is called twice, once for each mapped variant. mock that both + # calls return valid HGVS strings. + patch( + "mavedb.worker.jobs.external_services.clingen.get_hgvs_from_post_mapped", + side_effect=["c.122G>C", "c.123A>T"], + ), + # validate_job_params is called to validate job parameters + patch("mavedb.worker.jobs.external_services.clingen.validate_job_params", return_value=None), + # update_progress is called multiple times to update job progress + patch.object(mock_job_manager, "update_progress", return_value=None), + # CAR_SUBMISSION_ENDPOINT is patched to a test URL + patch( + "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", + "https://reg.test.genome.network/pytest", + ), + # Mock the dispatch_submissions method to return a test ClinGen allele object, which we should associate with the variant + patch.object( + ClinGenAlleleRegistryService, + "dispatch_submissions", + return_value=[TEST_CLINGEN_ALLELE_OBJECT, TEST_CLINGEN_ALLELE_OBJECT], + ), + # Mock the get_allele_registry_associations function to return a mapping from HGVS to CAID + patch( + "mavedb.worker.jobs.external_services.clingen.get_allele_registry_associations", + return_value={"c.122G>C": "CAID:0000000", "c.123A>T": "CAID:0000001"}, + ), + patch.object(mock_job_manager.db, "add", return_value=None) as mock_db_add, + ): + result = await submit_score_set_mappings_to_car(mock_worker_ctx, 1, job_manager=mock_job_manager) + + # Assert the variant without an HGVS string was skipped, and the other variant was updated with the CAID + mock_db_add.assert_has_calls( + [call(mocked_mapped_variant_with_hgvs_999), call(mocked_mapped_variant_with_hgvs_1000)] + ) + assert mocked_mapped_variant_with_hgvs_999.clingen_allele_id == "CAID:0000000" + assert mocked_mapped_variant_with_hgvs_1000.clingen_allele_id == "CAID:0000001" + assert result["status"] == "ok" + + async def test_submit_score_set_mappings_to_car_updates_progress( + self, + mock_job_manager, + mock_job_run, + mock_worker_ctx, ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] + """Test that submitting a score set with mapped variants updates progress correctly.""" + + mock_job_run.job_params = {"score_set_id": 1, "correlation_id": uuid4().hex} + + mocked_score_set = MagicMock(spec=ScoreSetDbModel, urn="urn:1", num_variants=2) + mocked_mapped_variant_with_hgvs_999 = MagicMock(spec=MappedVariant, id=999) + mocked_mapped_variant_with_hgvs_1000 = MagicMock(spec=MappedVariant, id=1000) + + with ( + # db.scalars is called three times in this function: once to get the score set (one), twice to get the mapped variants (all) + patch.object( + mock_job_manager.db, + "scalars", + return_value=MagicMock( + one=mocked_score_set, + all=MagicMock( + side_effect=[[mocked_mapped_variant_with_hgvs_999], [mocked_mapped_variant_with_hgvs_1000]] + ), + ), + ), + # db.execute is called to get the mapped variant IDs and post mapped data + patch.object(mock_job_manager.db, "execute", return_value=MagicMock(all=lambda: [(999, {}), (1000, {})])), + # get_hgvs_from_post_mapped is called twice, once for each mapped variant. mock that both + # calls return valid HGVS strings. + patch( + "mavedb.worker.jobs.external_services.clingen.get_hgvs_from_post_mapped", + side_effect=["c.122G>C", "c.123A>T"], + ), + # validate_job_params is called to validate job parameters + patch("mavedb.worker.jobs.external_services.clingen.validate_job_params", return_value=None), + # update_progress is called multiple times to update job progress + patch.object(mock_job_manager, "update_progress", return_value=None) as mock_update_progress, + # CAR_SUBMISSION_ENDPOINT is patched to a test URL + patch( + "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", + "https://reg.test.genome.network/pytest", + ), + # Mock the dispatch_submissions method to return a test ClinGen allele object, which we should associate with the variant + patch.object( + ClinGenAlleleRegistryService, + "dispatch_submissions", + return_value=[TEST_CLINGEN_ALLELE_OBJECT], + ), + ): + result = await submit_score_set_mappings_to_car(mock_worker_ctx, 1, job_manager=mock_job_manager) + + # Assert the variant without an HGVS string was skipped, and the other variant was updated with the CAID + mock_update_progress.assert_has_calls( + [ + call(0, 100, "Starting CAR mapped resource submission."), + call(10, 100, "Preparing 2 mapped variants for CAR submission."), + call(15, 100, "Submitting mapped variants to CAR."), + call(50, 100, "Processing registered alleles from CAR."), + call(100, 100, "Completed CAR mapped resource submission."), + ] + ) + assert result["status"] == "ok" @pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_failures_exist_and_eclipse_retry_threshold_retries_exceeded( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, +@pytest.mark.integration +class TestSubmitScoreSetMappingsToCARIntegration: + """Integration tests for the submit_score_set_mappings_to_car function.""" + + @pytest.fixture() + def setup_car_submission_job_run(self, session): + """Add a submit_score_set_mappings_to_car job run to the DB before each test.""" + job_run = JobRun( + job_type="external_service", + job_function="submit_score_set_mappings_to_car", + status=JobStatus.PENDING, + job_params={"correlation_id": "test-corr-id"}, + ) + session.add(job_run) + session.commit() + return job_run + + async def test_submit_score_set_mappings_to_car_no_submission_endpoint( + self, standalone_worker_context, - ) - - async def dummy_linking_job(): - return [ - (variant_urn, None) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_linking_job(), - ), - patch( - "mavedb.worker.jobs.external_services.clingen.LINKED_DATA_RETRY_THRESHOLD", - 1, - ), - patch( - "mavedb.worker.jobs.external_services.clingen.LINKING_BACKOFF_IN_SECONDS", - 0, - ), - patch( - "mavedb.worker.jobs.utils.retry.ENQUEUE_BACKOFF_ATTEMPT_LIMIT", - 1, - ), - ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 2) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_ldh_objects_error_in_gnomad_job_enqueue( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( session, + with_populated_test_data, + setup_car_submission_job_run, async_client, data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_linking_job(): - return [ - (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We are unable to mock requests via requests_mock that occur inside another event loop. Instead, patch the return - # value of the EventLoop itself, which would have made the request. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - return_value=dummy_linking_job(), - ), - patch.object(arq.ArqRedis, "enqueue_job", return_value=awaitable_exception()), + arq_redis, ): - result = await link_clingen_variants(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] + """Test that submitting a score set with no CAR submission endpoint configured raises an exception.""" + score_set = await setup_records_files_and_variants_with_mapping( + session, + async_client, + data_files, + TEST_MINIMAL_SEQ_SCORESET, + standalone_worker_context, + ) + + with patch( + "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", + None, + ): + with pytest.raises(ValueError): + await submit_score_set_mappings_to_car( + standalone_worker_context, + score_set.id, + JobManager( + session, + arq_redis, + setup_car_submission_job_run.id, + ), + ) diff --git a/tests/worker/jobs/external_services/test_gnomad.py b/tests/worker/jobs/external_services/test_gnomad.py index c407462b..e69de29b 100644 --- a/tests/worker/jobs/external_services/test_gnomad.py +++ b/tests/worker/jobs/external_services/test_gnomad.py @@ -1,206 +0,0 @@ -# ruff: noqa: E402 - -from unittest.mock import patch -from uuid import uuid4 - -import pytest -from sqlalchemy import select - -arq = pytest.importorskip("arq") - -from mavedb.models.mapped_variant import MappedVariant -from mavedb.models.score_set import ScoreSet as ScoreSetDbModel -from mavedb.models.variant import Variant -from mavedb.worker.jobs import ( - link_gnomad_variants, -) -from tests.helpers.constants import ( - TEST_GNOMAD_DATA_VERSION, - TEST_MINIMAL_SEQ_SCORESET, - VALID_CLINGEN_CA_ID, -) -from tests.helpers.util.setup.worker import ( - setup_records_files_and_variants, - setup_records_files_and_variants_with_mapping, -) - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_gnomad_variants_success( - setup_worker_db, - standalone_worker_context, - session, - async_client, - data_files, - arq_worker, - arq_redis, - mocked_gnomad_variant_row, -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - # We need to set the ClinGen Allele ID for the Mapped Variants, so that the gnomAD job can link them. - mapped_variants = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - - for mapped_variant in mapped_variants: - mapped_variant.clingen_allele_id = VALID_CLINGEN_CA_ID - session.commit() - - # Patch Athena connection with mock object which returns a mocked gnomAD variant row w/ CAID=VALID_CLINGEN_CA_ID. - with ( - patch( - "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", - return_value=[mocked_gnomad_variant_row], - ), - patch("mavedb.lib.gnomad.GNOMAD_DATA_VERSION", TEST_GNOMAD_DATA_VERSION), - ): - result = await link_gnomad_variants(standalone_worker_context, uuid4().hex, score_set.id) - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - for variant in session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ): - assert variant.gnomad_variants - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_gnomad_variants_exception_in_setup( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch( - "mavedb.worker.jobs.external_services.gnomad.setup_job_state", - side_effect=Exception(), - ): - result = await link_gnomad_variants(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - for variant in session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ): - assert not variant.gnomad_variants - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_gnomad_variants_no_variants_to_link( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - result = await link_gnomad_variants(standalone_worker_context, uuid4().hex, score_set.id) - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - for variant in session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ): - assert not variant.gnomad_variants - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_gnomad_variants_exception_while_fetching_variant_data( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch( - "mavedb.worker.jobs.external_services.gnomad.setup_job_state", - side_effect=Exception(), - ), - patch("mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", side_effect=Exception()), - ): - result = await link_gnomad_variants(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - for variant in session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ): - assert not variant.gnomad_variants - - -@pytest.mark.asyncio -async def test_link_score_set_mappings_to_gnomad_variants_exception_while_linking_variants( - setup_worker_db, - standalone_worker_context, - session, - async_client, - data_files, - arq_worker, - arq_redis, - mocked_gnomad_variant_row, -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - # We need to set the ClinGen Allele ID for the Mapped Variants, so that the gnomAD job can link them. - mapped_variants = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - - for mapped_variant in mapped_variants: - mapped_variant.clingen_allele_id = VALID_CLINGEN_CA_ID - session.commit() - - with ( - patch( - "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", - return_value=[mocked_gnomad_variant_row], - ), - patch( - "mavedb.worker.jobs.external_services.gnomad.link_gnomad_variants_to_mapped_variants", - side_effect=Exception(), - ), - ): - result = await link_gnomad_variants(standalone_worker_context, uuid4().hex, score_set.id) - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_job"] - - for variant in session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ): - assert not variant.gnomad_variants diff --git a/tests/worker/jobs/external_services/test_uniprot.py b/tests/worker/jobs/external_services/test_uniprot.py index e3833f14..e69de29b 100644 --- a/tests/worker/jobs/external_services/test_uniprot.py +++ b/tests/worker/jobs/external_services/test_uniprot.py @@ -1,603 +0,0 @@ -# ruff: noqa: E402 - -from unittest.mock import patch -from uuid import uuid4 - -import pytest -from requests import HTTPError -from sqlalchemy import select - -arq = pytest.importorskip("arq") - - -from mavedb.lib.uniprot.id_mapping import UniProtIDMappingAPI -from mavedb.models.score_set import ScoreSet as ScoreSetDbModel -from mavedb.worker.jobs import ( - poll_uniprot_mapping_jobs_for_score_set, - submit_uniprot_mapping_jobs_for_score_set, -) -from tests.helpers.constants import ( - TEST_MINIMAL_SEQ_SCORESET, - TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, - TEST_UNIPROT_JOB_SUBMISSION_RESPONSE, - TEST_UNIPROT_SWISS_PROT_TYPE, - VALID_CHR_ACCESSION, - VALID_UNIPROT_ACCESSION, -) -from tests.helpers.util.setup.worker import ( - setup_records_files_and_variants, - setup_records_files_and_variants_with_mapping, -) - -### Test Submission - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_success( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch.object(UniProtIDMappingAPI, "submit_id_mapping", return_value=TEST_UNIPROT_JOB_SUBMISSION_RESPONSE): - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - - assert result["success"] - assert not result["retried"] - assert result["enqueued_jobs"] is not None - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_no_targets( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - score_set.target_genes = [] - session.add(score_set) - session.commit() - - with patch( - "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None - ) as mock_slack_message: - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called_once() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_exception_while_spawning_jobs( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(UniProtIDMappingAPI, "submit_id_mapping", side_effect=HTTPError()), - patch( - "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None - ) as mock_slack_message, - ): - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_too_many_accessions( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch( - "mavedb.worker.jobs.external_services.uniprot.extract_ids_from_post_mapped_metadata", - return_value=["AC1", "AC2"], - ), - patch( - "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None - ) as mock_slack_message, - ): - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_no_accessions( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with patch( - "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None - ) as mock_slack_message: - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_error_in_setup( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch("mavedb.worker.jobs.external_services.uniprot.setup_job_state", side_effect=Exception()), - patch( - "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None - ) as mock_slack_message, - ): - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called() - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_exception_during_submission_generation( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch( - "mavedb.worker.jobs.external_services.uniprot.extract_ids_from_post_mapped_metadata", - side_effect=Exception(), - ), - patch( - "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None - ) as mock_slack_message, - ): - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called() - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_no_spawned_jobs( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(UniProtIDMappingAPI, "submit_id_mapping", return_value=None), - patch( - "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None - ) as mock_slack_message, - ): - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_submit_uniprot_id_mapping_exception_during_enqueue( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(UniProtIDMappingAPI, "submit_id_mapping", return_value=TEST_UNIPROT_JOB_SUBMISSION_RESPONSE), - patch.object(arq.ArqRedis, "enqueue_job", side_effect=Exception()), - patch( - "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None - ) as mock_slack_message, - ): - result = await submit_uniprot_mapping_jobs_for_score_set(standalone_worker_context, score_set.id, uuid4().hex) - mock_slack_message.assert_called() - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -### Test Polling - - -@pytest.mark.asyncio -async def test_poll_uniprot_id_mapping_success( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=True), - patch.object( - UniProtIDMappingAPI, "get_id_mapping_results", return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE - ), - ): - result = await poll_uniprot_mapping_jobs_for_score_set( - standalone_worker_context, - {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, - score_set.id, - uuid4().hex, - ) - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - for target_gene in score_set.target_genes: - assert target_gene.uniprot_id_from_mapped_metadata == VALID_UNIPROT_ACCESSION - - -@pytest.mark.asyncio -async def test_poll_uniprot_id_mapping_no_targets( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - score_set.target_genes = [] - session.add(score_set) - session.commit() - - with patch( - "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None - ) as mock_slack_message: - result = await poll_uniprot_mapping_jobs_for_score_set( - standalone_worker_context, - {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, - score_set.id, - uuid4().hex, - ) - mock_slack_message.assert_called_once() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - for target_gene in score_set.target_genes: - assert target_gene.uniprot_id_from_mapped_metadata is None - - -@pytest.mark.asyncio -async def test_poll_uniprot_id_mapping_too_many_accessions( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch( - "mavedb.worker.jobs.external_services.uniprot.extract_ids_from_post_mapped_metadata", - return_value=["AC1", "AC2"], - ), - patch( - "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None - ) as mock_slack_message, - ): - result = await poll_uniprot_mapping_jobs_for_score_set( - standalone_worker_context, - {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, - score_set.id, - uuid4().hex, - ) - mock_slack_message.assert_called() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_poll_uniprot_id_mapping_no_accessions( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch("mavedb.worker.jobs.external_services.uniprot.extract_ids_from_post_mapped_metadata", return_value=[]), - patch( - "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None - ) as mock_slack_message, - ): - result = await poll_uniprot_mapping_jobs_for_score_set( - standalone_worker_context, - {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, - score_set.id, - uuid4().hex, - ) - mock_slack_message.assert_called() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_poll_uniprot_id_mapping_jobs_not_ready( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=False), - patch( - "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None - ) as mock_slack_message, - ): - result = await poll_uniprot_mapping_jobs_for_score_set( - standalone_worker_context, - {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, - score_set.id, - uuid4().hex, - ) - mock_slack_message.assert_called() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - for target_gene in score_set.target_genes: - assert target_gene.uniprot_id_from_mapped_metadata is None - - -@pytest.mark.asyncio -async def test_poll_uniprot_id_mapping_no_jobs( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - # This case does not get sent to slack - result = await poll_uniprot_mapping_jobs_for_score_set( - standalone_worker_context, - {}, - score_set.id, - uuid4().hex, - ) - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - for target_gene in score_set.target_genes: - assert target_gene.uniprot_id_from_mapped_metadata is None - - -@pytest.mark.asyncio -async def test_poll_uniprot_id_mapping_no_ids_mapped( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=True), - patch.object(UniProtIDMappingAPI, "get_id_mapping_results", return_value={"failedIDs": [VALID_CHR_ACCESSION]}), - patch( - "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None - ) as mock_slack_message, - ): - result = await poll_uniprot_mapping_jobs_for_score_set( - standalone_worker_context, - {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, - score_set.id, - uuid4().hex, - ) - mock_slack_message.assert_called() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - for target_gene in score_set.target_genes: - assert target_gene.uniprot_id_from_mapped_metadata is None - - -@pytest.mark.asyncio -async def test_poll_uniprot_id_mapping_too_many_mapped_accessions( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - # Simulate a response with too many mapped IDs - too_many_mapped_ids_response = TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE.copy() - too_many_mapped_ids_response["results"].append( - {"from": "AC3", "to": {"primaryAccession": "AC3", "entryType": TEST_UNIPROT_SWISS_PROT_TYPE}} - ) - - with ( - patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=True), - patch.object(UniProtIDMappingAPI, "get_id_mapping_results", return_value=too_many_mapped_ids_response), - patch( - "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None - ) as mock_slack_message, - ): - result = await poll_uniprot_mapping_jobs_for_score_set( - standalone_worker_context, - {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, - score_set.id, - uuid4().hex, - ) - mock_slack_message.assert_called() - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_poll_uniprot_id_mapping_error_in_setup( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch("mavedb.worker.jobs.external_services.uniprot.setup_job_state", side_effect=Exception()), - patch( - "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None - ) as mock_slack_message, - ): - result = await poll_uniprot_mapping_jobs_for_score_set( - standalone_worker_context, - {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, - score_set.id, - uuid4().hex, - ) - mock_slack_message.assert_called_once() - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - - -@pytest.mark.asyncio -async def test_poll_uniprot_id_mapping_exception_during_polling( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants_with_mapping( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - with ( - patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", side_effect=Exception()), - patch( - "mavedb.worker.jobs.external_services.uniprot.log_and_send_slack_message", return_value=None - ) as mock_slack_message, - ): - result = await poll_uniprot_mapping_jobs_for_score_set( - standalone_worker_context, - {tg.id: f"job_{idx}" for idx, tg in enumerate(score_set.target_genes)}, - score_set.id, - uuid4().hex, - ) - mock_slack_message.assert_called_once() - - assert not result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] diff --git a/tests/worker/jobs/utils/test_setup.py b/tests/worker/jobs/utils/test_setup.py new file mode 100644 index 00000000..096abd2d --- /dev/null +++ b/tests/worker/jobs/utils/test_setup.py @@ -0,0 +1,30 @@ +from unittest.mock import Mock + +import pytest + +from mavedb.models.job_run import JobRun +from mavedb.worker.jobs.utils.setup import validate_job_params + + +@pytest.mark.unit +def test_validate_job_params_success(): + job = Mock(spec=JobRun, job_params={"foo": 1, "bar": 2}) + + # Should not raise + validate_job_params(["foo", "bar"], job) + + +@pytest.mark.unit +def test_validate_job_params_missing_param(): + job = Mock(spec=JobRun, job_params={"foo": 1}) + + with pytest.raises(ValueError, match="Missing required job param: bar"): + validate_job_params(["foo", "bar"], job) + + +@pytest.mark.unit +def test_validate_job_params_no_params(): + job = Mock(spec=JobRun, job_params=None) + + with pytest.raises(ValueError, match="Job has no job_params defined."): + validate_job_params(["foo"], job) diff --git a/tests/worker/jobs/variant_processing/test_creation.py b/tests/worker/jobs/variant_processing/test_creation.py index b5addb76..e69de29b 100644 --- a/tests/worker/jobs/variant_processing/test_creation.py +++ b/tests/worker/jobs/variant_processing/test_creation.py @@ -1,557 +0,0 @@ -# ruff: noqa: E402 - -from asyncio.unix_events import _UnixSelectorEventLoop -from unittest.mock import patch -from uuid import uuid4 - -import pandas as pd -import pytest -from sqlalchemy import select - -arq = pytest.importorskip("arq") -cdot = pytest.importorskip("cdot") - -from mavedb.lib.clingen.services import ( - ClinGenLdhService, -) -from mavedb.lib.mave.constants import HGVS_NT_COLUMN -from mavedb.lib.validation.exceptions import ValidationError -from mavedb.models.enums.mapping_state import MappingState -from mavedb.models.enums.processing_state import ProcessingState -from mavedb.models.mapped_variant import MappedVariant -from mavedb.models.score_set import ScoreSet as ScoreSetDbModel -from mavedb.models.variant import Variant -from mavedb.worker.jobs import ( - create_variants_for_score_set, -) -from mavedb.worker.jobs.utils.constants import MAPPING_CURRENT_ID_NAME, MAPPING_QUEUE_NAME -from tests.helpers.constants import ( - TEST_CLINGEN_ALLELE_OBJECT, - TEST_CLINGEN_LDH_LINKING_RESPONSE, - TEST_CLINGEN_SUBMISSION_RESPONSE, - TEST_MINIMAL_ACC_SCORESET, - TEST_MINIMAL_MULTI_TARGET_SCORESET, - TEST_MINIMAL_SEQ_SCORESET, - TEST_NT_CDOT_TRANSCRIPT, - VALID_NT_ACCESSION, -) -from tests.helpers.util.mapping import sanitize_mapping_queue -from tests.helpers.util.setup.worker import setup_mapping_output, setup_records_and_files - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set,validation_error", - [ - ( - TEST_MINIMAL_SEQ_SCORESET, - { - "exception": "encountered 1 invalid variant strings.", - "detail": ["target sequence mismatch for 'c.1T>A' at row 0 for sequence TEST1"], - }, - ), - ( - TEST_MINIMAL_ACC_SCORESET, - { - "exception": "encountered 1 invalid variant strings.", - "detail": [ - "Failed to parse row 0 with HGVS exception: NM_001637.3:c.1T>A: Variant reference (T) does not agree with reference sequence (G)." - ], - }, - ), - ( - TEST_MINIMAL_MULTI_TARGET_SCORESET, - { - "exception": "encountered 1 invalid variant strings.", - "detail": ["target sequence mismatch for 'n.1T>A' at row 0 for sequence TEST3"], - }, - ), - ], -) -async def test_create_variants_for_score_set_with_validation_error( - input_score_set, - validation_error, - setup_worker_db, - async_client, - standalone_worker_context, - session, - data_files, -): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - if input_score_set == TEST_MINIMAL_SEQ_SCORESET: - scores.loc[:, HGVS_NT_COLUMN].iloc[0] = "c.1T>A" - elif input_score_set == TEST_MINIMAL_ACC_SCORESET: - scores.loc[:, HGVS_NT_COLUMN].iloc[0] = f"{VALID_NT_ACCESSION}:c.1T>A" - elif input_score_set == TEST_MINIMAL_MULTI_TARGET_SCORESET: - scores.loc[:, HGVS_NT_COLUMN].iloc[0] = "TEST3:n.1T>A" - - with ( - patch.object( - cdot.hgvs.dataproviders.RESTDataProvider, - "_get_transcript", - return_value=TEST_NT_CDOT_TRANSCRIPT, - ) as hdp, - ): - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - - # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. - if all(["targetSequence" in target for target in input_score_set["targetGenes"]]): - hdp.assert_not_called() - else: - hdp.assert_called_once() - - db_variants = session.scalars(select(Variant)).all() - - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - assert score_set.num_variants == 0 - assert len(db_variants) == 0 - assert score_set.processing_state == ProcessingState.failed - assert score_set.processing_errors == validation_error - assert not result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) -) -async def test_create_variants_for_score_set_with_caught_exception( - input_score_set, - setup_worker_db, - async_client, - standalone_worker_context, - session, - data_files, -): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - # This is somewhat dumb and wouldn't actually happen like this, but it serves as an effective way to guarantee - # some exception will be raised no matter what in the async job. - with ( - patch.object(pd.DataFrame, "isnull", side_effect=Exception) as mocked_exc, - ): - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - mocked_exc.assert_called() - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - - assert score_set.num_variants == 0 - assert len(db_variants) == 0 - assert score_set.processing_state == ProcessingState.failed - assert score_set.processing_errors == {"detail": [], "exception": ""} - assert not result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) -) -async def test_create_variants_for_score_set_with_caught_base_exception( - input_score_set, - setup_worker_db, - async_client, - standalone_worker_context, - session, - data_files, -): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - # This is somewhat (extra) dumb and wouldn't actually happen like this, but it serves as an effective way to guarantee - # some base exception will be handled no matter what in the async job. - with ( - patch.object(pd.DataFrame, "isnull", side_effect=BaseException), - ): - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - - assert score_set.num_variants == 0 - assert len(db_variants) == 0 - assert score_set.processing_state == ProcessingState.failed - assert score_set.processing_errors is None - assert not result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) -) -async def test_create_variants_for_score_set_with_existing_variants( - input_score_set, - setup_worker_db, - async_client, - standalone_worker_context, - session, - data_files, -): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - with patch.object( - cdot.hgvs.dataproviders.RESTDataProvider, - "_get_transcript", - return_value=TEST_NT_CDOT_TRANSCRIPT, - ) as hdp: - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - - # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. - if all(["targetSequence" in target for target in input_score_set["targetGenes"]]): - hdp.assert_not_called() - else: - hdp.assert_called_once() - - await sanitize_mapping_queue(standalone_worker_context, score_set) - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - - assert score_set.num_variants == 3 - assert len(db_variants) == 3 - assert score_set.processing_state == ProcessingState.success - - with patch.object( - cdot.hgvs.dataproviders.RESTDataProvider, - "_get_transcript", - return_value=TEST_NT_CDOT_TRANSCRIPT, - ) as hdp: - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - - assert score_set.num_variants == 3 - assert len(db_variants) == 3 - assert score_set.processing_state == ProcessingState.success - assert score_set.processing_errors is None - assert result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 1 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) -) -async def test_create_variants_for_score_set_with_existing_exceptions( - input_score_set, - setup_worker_db, - async_client, - standalone_worker_context, - session, - data_files, -): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - # This is somewhat dumb and wouldn't actually happen like this, but it serves as an effective way to guarantee - # some exception will be raised no matter what in the async job. - with ( - patch.object( - pd.DataFrame, - "isnull", - side_effect=ValidationError("Test Exception", triggers=["exc_1", "exc_2"]), - ) as mocked_exc, - ): - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - mocked_exc.assert_called() - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - - assert score_set.num_variants == 0 - assert len(db_variants) == 0 - assert score_set.processing_state == ProcessingState.failed - assert score_set.processing_errors == { - "exception": "Test Exception", - "detail": ["exc_1", "exc_2"], - } - - with patch.object( - cdot.hgvs.dataproviders.RESTDataProvider, - "_get_transcript", - return_value=TEST_NT_CDOT_TRANSCRIPT, - ) as hdp: - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - - # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. - if all(["targetSequence" in target for target in input_score_set["targetGenes"]]): - hdp.assert_not_called() - else: - hdp.assert_called_once() - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - - assert score_set.num_variants == 3 - assert len(db_variants) == 3 - assert score_set.processing_state == ProcessingState.success - assert score_set.processing_errors is None - assert result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 1 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) -) -async def test_create_variants_for_score_set( - input_score_set, - setup_worker_db, - async_client, - standalone_worker_context, - session, - data_files, -): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - with patch.object( - cdot.hgvs.dataproviders.RESTDataProvider, - "_get_transcript", - return_value=TEST_NT_CDOT_TRANSCRIPT, - ) as hdp: - result = await create_variants_for_score_set( - standalone_worker_context, - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - - # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. - if all(["targetSequence" in target for target in input_score_set["targetGenes"]]): - hdp.assert_not_called() - else: - hdp.assert_called_once() - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - - assert score_set.num_variants == 3 - assert len(db_variants) == 3 - assert score_set.processing_state == ProcessingState.success - assert result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 1 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) -) -async def test_create_variants_for_score_set_enqueues_manager_and_successful_mapping( - input_score_set, - setup_worker_db, - session, - async_client, - data_files, - arq_worker, - arq_redis, -): - score_set_is_seq = all(["targetSequence" in target for target in input_score_set["targetGenes"]]) - score_set_is_multi_target = len(input_score_set["targetGenes"]) > 1 - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - async def dummy_mapping_job(): - return await setup_mapping_output(async_client, session, score_set, score_set_is_seq, score_set_is_multi_target) - - async def dummy_car_submission_job(): - return TEST_CLINGEN_ALLELE_OBJECT - - async def dummy_ldh_submission_job(): - return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] - - # Variants have not yet been created, so infer their URNs. - async def dummy_linking_job(): - return [(f"{score_set_urn}#{i}", TEST_CLINGEN_LDH_LINKING_RESPONSE) for i in range(1, len(scores) + 1)] - - with ( - patch.object( - cdot.hgvs.dataproviders.RESTDataProvider, - "_get_transcript", - return_value=TEST_NT_CDOT_TRANSCRIPT, - ) as hdp, - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - side_effect=[ - dummy_mapping_job(), - dummy_car_submission_job(), - dummy_ldh_submission_job(), - dummy_linking_job(), - ], - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - patch("mavedb.worker.jobs.variant_processing.mapping.MAPPING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.external_services.clingen.LINKING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.variant_processing.mapping.CLIN_GEN_SUBMISSION_ENABLED", True), - ): - await arq_redis.enqueue_job( - "create_variants_for_score_set", - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - await arq_worker.async_run() - await arq_worker.run_check() - - # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. - if score_set_is_seq: - hdp.assert_not_called() - else: - hdp.assert_called_once() - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert score_set.num_variants == 3 - assert len(db_variants) == 3 - assert score_set.processing_state == ProcessingState.success - assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 - assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert len(mapped_variants_for_score_set) == score_set.num_variants - assert score_set.mapping_state == MappingState.complete - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input_score_set", (TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET) -) -async def test_create_variants_for_score_set_exception_skips_mapping( - input_score_set, - setup_worker_db, - session, - async_client, - data_files, - arq_worker, - arq_redis, -): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - with patch.object(pd.DataFrame, "isnull", side_effect=Exception) as mocked_exc: - await arq_redis.enqueue_job( - "create_variants_for_score_set", - uuid4().hex, - score_set.id, - 1, - scores, - counts, - score_columns_metadata, - count_columns_metadata, - ) - await arq_worker.async_run() - await arq_worker.run_check() - - mocked_exc.assert_called() - - db_variants = session.scalars(select(Variant)).all() - score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set_urn).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - - assert score_set.num_variants == 0 - assert len(db_variants) == 0 - assert score_set.processing_state == ProcessingState.failed - assert score_set.processing_errors == {"detail": [], "exception": ""} - assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 - assert len(mapped_variants_for_score_set) == 0 - assert score_set.mapping_state == MappingState.not_attempted - assert score_set.mapping_errors is None diff --git a/tests/worker/jobs/variant_processing/test_mapping.py b/tests/worker/jobs/variant_processing/test_mapping.py index 9606e2e0..e69de29b 100644 --- a/tests/worker/jobs/variant_processing/test_mapping.py +++ b/tests/worker/jobs/variant_processing/test_mapping.py @@ -1,710 +0,0 @@ -# ruff: noqa: E402 - -from asyncio.unix_events import _UnixSelectorEventLoop -from unittest.mock import patch -from uuid import uuid4 - -import pytest -from sqlalchemy import select - -arq = pytest.importorskip("arq") - -from mavedb.lib.clingen.services import ( - ClinGenAlleleRegistryService, - ClinGenLdhService, -) -from mavedb.lib.uniprot.id_mapping import UniProtIDMappingAPI -from mavedb.models.enums.mapping_state import MappingState -from mavedb.models.mapped_variant import MappedVariant -from mavedb.models.score_set import ScoreSet as ScoreSetDbModel -from mavedb.models.variant import Variant -from mavedb.worker.jobs import ( - variant_mapper_manager, -) -from mavedb.worker.jobs.utils.constants import MAPPING_CURRENT_ID_NAME, MAPPING_QUEUE_NAME -from tests.helpers.constants import ( - TEST_CLINGEN_ALLELE_OBJECT, - TEST_CLINGEN_LDH_LINKING_RESPONSE, - TEST_CLINGEN_SUBMISSION_RESPONSE, - TEST_GNOMAD_DATA_VERSION, - TEST_MINIMAL_SEQ_SCORESET, - TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, - TEST_UNIPROT_JOB_SUBMISSION_RESPONSE, -) -from tests.helpers.util.exceptions import awaitable_exception -from tests.helpers.util.setup.worker import setup_mapping_output, setup_records_files_and_variants - - -@pytest.mark.asyncio -async def test_mapping_manager_empty_queue(setup_worker_db, standalone_worker_context): - result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # No new jobs should have been created if nothing is in the queue, and the queue should remain empty. - assert result["enqueued_job"] is None - assert result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - - -@pytest.mark.asyncio -async def test_mapping_manager_empty_queue_error_during_setup(setup_worker_db, standalone_worker_context): - await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "") - with patch.object(arq.ArqRedis, "rpop", Exception()): - result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # No new jobs should have been created if nothing is in the queue, and the queue should remain empty. - assert result["enqueued_job"] is None - assert not result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - - -@pytest.mark.asyncio -async def test_mapping_manager_occupied_queue_mapping_in_progress( - setup_worker_db, standalone_worker_context, session, async_client, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "5") - with patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.in_progress): - result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # Execution should be deferred if a job is in progress, and the queue should contain one entry which is the deferred ID. - assert result["enqueued_job"] is not None - assert ( - await arq.jobs.Job(result["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.deferred - assert result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 1 - assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set.id) - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "5" - assert score_set.mapping_state == MappingState.queued - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -async def test_mapping_manager_occupied_queue_mapping_not_in_progress( - setup_worker_db, standalone_worker_context, session, async_client, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "") - with patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.not_found): - result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # Mapping job should be queued if none is currently running, and the queue should now be empty. - assert result["enqueued_job"] is not None - assert ( - await arq.jobs.Job(result["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.queued - assert result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - # We don't actually start processing these score sets. - assert score_set.mapping_state == MappingState.queued - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -async def test_mapping_manager_occupied_queue_mapping_in_progress_error_during_enqueue( - setup_worker_db, standalone_worker_context, session, async_client, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "5") - with ( - patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.in_progress), - patch.object(arq.ArqRedis, "enqueue_job", return_value=awaitable_exception()), - ): - result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # Execution should be deferred if a job is in progress, and the queue should contain one entry which is the deferred ID. - assert result["enqueued_job"] is None - assert not result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert (await standalone_worker_context["redis"].get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "5" - assert score_set.mapping_state == MappingState.failed - assert score_set.mapping_errors is not None - - -@pytest.mark.asyncio -async def test_mapping_manager_occupied_queue_mapping_not_in_progress_error_during_enqueue( - setup_worker_db, standalone_worker_context, session, async_client, data_files -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "") - with ( - patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.not_found), - patch.object(arq.ArqRedis, "enqueue_job", return_value=awaitable_exception()), - ): - result = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # Enqueue would have failed, the job is unsuccessful, and we remove the queued item. - assert result["enqueued_job"] is None - assert not result["success"] - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 0 - assert score_set.mapping_state == MappingState.failed - assert score_set.mapping_errors is not None - - -@pytest.mark.asyncio -async def test_mapping_manager_multiple_score_sets_occupy_queue_mapping_in_progress( - setup_worker_db, standalone_worker_context, session, async_client, data_files -): - score_set_id_1 = ( - await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - ).id - score_set_id_2 = ( - await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - ).id - score_set_id_3 = ( - await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - ).id - - await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "5") - with patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.in_progress): - result1 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - result2 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - result3 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # All three jobs should complete successfully... - assert result1["success"] - assert result2["success"] - assert result3["success"] - - # ...with a new job enqueued... - assert result1["enqueued_job"] is not None - assert result2["enqueued_job"] is not None - assert result3["enqueued_job"] is not None - - # ...of which all should be deferred jobs of the "variant_mapper_manager" variety... - assert ( - await arq.jobs.Job(result1["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.deferred - assert ( - await arq.jobs.Job(result2["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.deferred - assert ( - await arq.jobs.Job(result3["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.deferred - - assert ( - await arq.jobs.Job(result1["enqueued_job"], standalone_worker_context["redis"]).info() - ).function == "variant_mapper_manager" - assert ( - await arq.jobs.Job(result2["enqueued_job"], standalone_worker_context["redis"]).info() - ).function == "variant_mapper_manager" - assert ( - await arq.jobs.Job(result3["enqueued_job"], standalone_worker_context["redis"]).info() - ).function == "variant_mapper_manager" - - # ...and the queue state should have three jobs, each of our three created score sets. - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 3 - assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set_id_1) - assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set_id_2) - assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set_id_3) - - score_set1 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_1)).one() - score_set2 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_2)).one() - score_set3 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_3)).one() - # Each score set should remain queued with no mapping errors. - assert score_set1.mapping_state == MappingState.queued - assert score_set2.mapping_state == MappingState.queued - assert score_set3.mapping_state == MappingState.queued - assert score_set1.mapping_errors is None - assert score_set2.mapping_errors is None - assert score_set3.mapping_errors is None - - -@pytest.mark.asyncio -async def test_mapping_manager_multiple_score_sets_occupy_queue_mapping_not_in_progress( - setup_worker_db, standalone_worker_context, session, async_client, data_files -): - score_set_id_1 = ( - await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - ).id - score_set_id_2 = ( - await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - ).id - score_set_id_3 = ( - await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - ).id - - await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, "") - with patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.not_found): - result1 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # Mock the first job being in-progress - await standalone_worker_context["redis"].set(MAPPING_CURRENT_ID_NAME, str(score_set_id_1)) - with patch.object(arq.jobs.Job, "status", return_value=arq.jobs.JobStatus.in_progress): - result2 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - result3 = await variant_mapper_manager(standalone_worker_context, uuid4().hex, 1) - - # All three jobs should complete successfully... - assert result1["success"] - assert result2["success"] - assert result3["success"] - - # ...with a new job enqueued... - assert result1["enqueued_job"] is not None - assert result2["enqueued_job"] is not None - assert result3["enqueued_job"] is not None - - # ...of which the first should be a queued job of the "map_variants_for_score_set" variety and the other two should be - # deferred jobs of the "variant_mapper_manager" variety... - assert ( - await arq.jobs.Job(result1["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.queued - assert ( - await arq.jobs.Job(result2["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.deferred - assert ( - await arq.jobs.Job(result3["enqueued_job"], standalone_worker_context["redis"]).status() - ) == arq.jobs.JobStatus.deferred - - assert ( - await arq.jobs.Job(result1["enqueued_job"], standalone_worker_context["redis"]).info() - ).function == "map_variants_for_score_set" - assert ( - await arq.jobs.Job(result2["enqueued_job"], standalone_worker_context["redis"]).info() - ).function == "variant_mapper_manager" - assert ( - await arq.jobs.Job(result3["enqueued_job"], standalone_worker_context["redis"]).info() - ).function == "variant_mapper_manager" - - # ...and the queue state should have two jobs, neither of which should be the first score set. - assert (await standalone_worker_context["redis"].llen(MAPPING_QUEUE_NAME)) == 2 - assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set_id_2) - assert (await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME)).decode("utf-8") == str(score_set_id_3) - - score_set1 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_1)).one() - score_set2 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_2)).one() - score_set3 = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.id == score_set_id_3)).one() - # We don't actually process any score sets in the manager job, and each should have no mapping errors. - assert score_set1.mapping_state == MappingState.queued - assert score_set2.mapping_state == MappingState.queued - assert score_set3.mapping_state == MappingState.queued - assert score_set1.mapping_errors is None - assert score_set2.mapping_errors is None - assert score_set3.mapping_errors is None - - -@pytest.mark.asyncio -async def test_mapping_manager_enqueues_mapping_process_with_successful_mapping( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_mapping_job(): - return await setup_mapping_output(async_client, session, score_set) - - async def dummy_ldh_submission_job(): - return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] - - async def dummy_linking_job(): - return [ - (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mapping output. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - side_effect=[dummy_mapping_job(), dummy_ldh_submission_job(), dummy_linking_job()], - ), - patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - patch.object(UniProtIDMappingAPI, "submit_id_mapping", return_value=TEST_UNIPROT_JOB_SUBMISSION_RESPONSE), - patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=True), - patch.object( - UniProtIDMappingAPI, "get_id_mapping_results", return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE - ), - patch("mavedb.worker.jobs.variant_processing.mapping.MAPPING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.external_services.clingen.LINKING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.variant_processing.mapping.UNIPROT_ID_MAPPING_ENABLED", True), - patch("mavedb.worker.jobs.variant_processing.mapping.CLIN_GEN_SUBMISSION_ENABLED", True), - patch( - "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", - "https://reg.test.genome.network/pytest", - ), - patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_NAME", "testuser"), - patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_PASSWORD", "testpassword"), - patch("mavedb.lib.gnomad.GNOMAD_DATA_VERSION", TEST_GNOMAD_DATA_VERSION), - patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]), - ): - await arq_worker.async_run() - num_completed_jobs = await arq_worker.run_check() - - # We should have completed all jobs exactly once. - assert num_completed_jobs == 8 - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 - assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert len(mapped_variants_for_score_set) == score_set.num_variants - assert score_set.mapping_state == MappingState.complete - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -async def test_mapping_manager_enqueues_mapping_process_with_successful_mapping_linking_disabled_uniprot_disabled( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_mapping_job(): - return await setup_mapping_output(async_client, session, score_set) - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mapping output. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - side_effect=[dummy_mapping_job()], - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - patch("mavedb.worker.jobs.variant_processing.mapping.MAPPING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.external_services.clingen.LINKING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.variant_processing.mapping.UNIPROT_ID_MAPPING_ENABLED", False), - patch("mavedb.worker.jobs.variant_processing.mapping.CLIN_GEN_SUBMISSION_ENABLED", False), - ): - await arq_worker.async_run() - num_completed_jobs = await arq_worker.run_check() - - # We should have completed the manager and mapping jobs, but not the submission, linking, or uniprot mapping jobs. - assert num_completed_jobs == 2 - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 - assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert len(mapped_variants_for_score_set) == score_set.num_variants - assert score_set.mapping_state == MappingState.complete - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -async def test_mapping_manager_enqueues_mapping_process_with_successful_mapping_linking_disabled_uniprot_enabled( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_mapping_job(): - return await setup_mapping_output(async_client, session, score_set) - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mapping output. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - side_effect=[dummy_mapping_job()], - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - patch.object(UniProtIDMappingAPI, "submit_id_mapping", return_value=TEST_UNIPROT_JOB_SUBMISSION_RESPONSE), - patch.object(UniProtIDMappingAPI, "check_id_mapping_results_ready", return_value=True), - patch.object( - UniProtIDMappingAPI, "get_id_mapping_results", return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE - ), - patch("mavedb.worker.jobs.variant_processing.mapping.MAPPING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.external_services.clingen.LINKING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.variant_processing.mapping.UNIPROT_ID_MAPPING_ENABLED", True), - patch("mavedb.worker.jobs.variant_processing.mapping.CLIN_GEN_SUBMISSION_ENABLED", False), - ): - await arq_worker.async_run() - num_completed_jobs = await arq_worker.run_check() - - # We should have completed the manager, mapping, and uniprot jobs, but not the submission or linking jobs. - assert num_completed_jobs == 4 - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 - assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert len(mapped_variants_for_score_set) == score_set.num_variants - assert score_set.mapping_state == MappingState.complete - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -async def test_mapping_manager_enqueues_mapping_process_with_successful_mapping_linking_enabled_uniprot_disabled( - setup_worker_db, - standalone_worker_context, - session, - async_client, - data_files, - arq_worker, - arq_redis, - mocked_gnomad_variant_row, -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def dummy_mapping_job(): - return await setup_mapping_output(async_client, session, score_set) - - async def dummy_submission_job(): - return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] - - async def dummy_linking_job(): - return [ - (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mapping output. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - side_effect=[dummy_mapping_job(), dummy_submission_job(), dummy_linking_job()], - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - patch("mavedb.worker.jobs.variant_processing.mapping.MAPPING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.external_services.clingen.LINKING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.variant_processing.mapping.UNIPROT_ID_MAPPING_ENABLED", False), - patch("mavedb.worker.jobs.variant_processing.mapping.CLIN_GEN_SUBMISSION_ENABLED", True), - patch( - "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", - "https://reg.test.genome.network/pytest", - ), - patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_NAME", "testuser"), - patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_PASSWORD", "testpassword"), - patch("mavedb.lib.gnomad.GNOMAD_DATA_VERSION", TEST_GNOMAD_DATA_VERSION), - patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]), - patch( - "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", - return_value=[mocked_gnomad_variant_row], - ), - ): - await arq_worker.async_run() - num_completed_jobs = await arq_worker.run_check() - - # We should have completed the manager, mapping, submission, and linking jobs, but not the uniprot jobs. - assert num_completed_jobs == 6 - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 - assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert len(mapped_variants_for_score_set) == score_set.num_variants - assert score_set.mapping_state == MappingState.complete - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -async def test_mapping_manager_enqueues_mapping_process_with_retried_mapping_successful_mapping_on_retry( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def failed_mapping_job(): - return Exception() - - async def dummy_mapping_job(): - return await setup_mapping_output(async_client, session, score_set) - - async def dummy_ldh_submission_job(): - return [TEST_CLINGEN_SUBMISSION_RESPONSE, None] - - async def dummy_linking_job(): - return [ - (variant_urn, TEST_CLINGEN_LDH_LINKING_RESPONSE) - for variant_urn in session.scalars( - select(Variant.urn).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) - ).all() - ] - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mapping output. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - side_effect=[failed_mapping_job(), dummy_mapping_job(), dummy_ldh_submission_job(), dummy_linking_job()], - ), - patch.object(ClinGenLdhService, "_existing_jwt", return_value="test_jwt"), - patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]), - patch("mavedb.worker.jobs.variant_processing.mapping.MAPPING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.external_services.clingen.LINKING_BACKOFF_IN_SECONDS", 0), - patch("mavedb.worker.jobs.variant_processing.mapping.UNIPROT_ID_MAPPING_ENABLED", False), - patch("mavedb.worker.jobs.variant_processing.mapping.CLIN_GEN_SUBMISSION_ENABLED", True), - patch( - "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", - "https://reg.test.genome.network/pytest", - ), - patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_NAME", "testuser"), - patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_PASSWORD", "testpassword"), - patch("mavedb.lib.gnomad.GNOMAD_DATA_VERSION", TEST_GNOMAD_DATA_VERSION), - patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT]), - ): - await arq_worker.async_run() - num_completed_jobs = await arq_worker.run_check() - - # We should have completed the mapping manager job twice, the mapping job twice, the two submission jobs, and both linking jobs. - assert num_completed_jobs == 8 - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 - assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert len(mapped_variants_for_score_set) == score_set.num_variants - assert score_set.mapping_state == MappingState.complete - assert score_set.mapping_errors is None - - -@pytest.mark.asyncio -async def test_mapping_manager_enqueues_mapping_process_with_unsuccessful_mapping( - setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis -): - score_set = await setup_records_files_and_variants( - session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, - standalone_worker_context, - ) - - async def failed_mapping_job(): - return Exception() - - # We seem unable to mock requests via requests_mock that occur inside another event loop. Workaround - # this limitation by instead patching the _UnixSelectorEventLoop 's executor function, with a coroutine - # object that sets up test mapping output. - with ( - patch.object( - _UnixSelectorEventLoop, - "run_in_executor", - side_effect=[failed_mapping_job()] * 5, - ), - patch("mavedb.worker.jobs.variant_processing.mapping.MAPPING_BACKOFF_IN_SECONDS", 0), - ): - await arq_worker.async_run() - num_completed_jobs = await arq_worker.run_check() - - # We should have completed 6 mapping jobs and 6 management jobs. - assert num_completed_jobs == 12 - - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - mapped_variants_for_score_set = session.scalars( - select(MappedVariant).join(Variant).join(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set.urn) - ).all() - assert (await arq_redis.llen(MAPPING_QUEUE_NAME)) == 0 - assert (await arq_redis.get(MAPPING_CURRENT_ID_NAME)).decode("utf-8") == "" - assert len(mapped_variants_for_score_set) == 0 - assert score_set.mapping_state == MappingState.failed - assert score_set.mapping_errors is not None diff --git a/tests/worker/lib/conftest.py b/tests/worker/lib/conftest.py deleted file mode 100644 index faf63e0e..00000000 --- a/tests/worker/lib/conftest.py +++ /dev/null @@ -1,192 +0,0 @@ -# ruff: noqa: E402 - -""" -Test configuration and fixtures for worker lib tests. -""" - -from datetime import datetime -from unittest.mock import Mock - -import pytest - -from mavedb.models.enums.job_pipeline import DependencyType, JobStatus, PipelineStatus -from mavedb.models.job_dependency import JobDependency -from mavedb.models.job_run import JobRun -from mavedb.models.pipeline import Pipeline - -# Attempt to import optional top level fixtures. If the modules they depend on are not installed, -# we won't have access to our full fixture suite and only a limited subset of tests can be run. -try: - from .conftest_optional import * # noqa: F401, F403 - -except ModuleNotFoundError: - pass - - -@pytest.fixture -def sample_job_run(): - """Create a sample JobRun instance for testing.""" - return JobRun( - id=1, - urn="test:job:1", - job_type="test_job", - job_function="test_function", - status=JobStatus.PENDING, - pipeline_id=1, - progress_current=0, - progress_total=100, - progress_message="Ready to start", - created_at=datetime.now(), - ) - - -@pytest.fixture -def sample_dependent_job_run(): - """Create a sample dependent JobRun instance for testing.""" - return JobRun( - id=2, - urn="test:job:2", - job_type="dependent_job", - job_function="dependent_function", - status=JobStatus.PENDING, - pipeline_id=1, - progress_current=0, - progress_total=100, - progress_message="Waiting for dependency", - created_at=datetime.now(), - ) - - -@pytest.fixture -def sample_independent_job_run(): - """Create a sample independent JobRun instance for testing.""" - return JobRun( - id=3, - urn="test:job:3", - job_type="independent_job", - job_function="independent_function", - status=JobStatus.PENDING, - pipeline_id=None, - progress_current=0, - progress_total=100, - progress_message="Ready to start", - created_at=datetime.now(), - ) - - -@pytest.fixture -def sample_pipeline(): - """Create a sample Pipeline instance for testing.""" - return Pipeline( - id=1, - urn="test:pipeline:1", - name="Test Pipeline", - description="A test pipeline", - status=PipelineStatus.CREATED, - correlation_id="test_correlation_123", - created_at=datetime.now(), - ) - - -@pytest.fixture -def sample_empty_pipeline(): - """Create a sample Pipeline instance with no jobs for testing.""" - return Pipeline( - id=999, - urn="test:pipeline:999", - name="Empty Pipeline", - description="A pipeline with no jobs", - status=PipelineStatus.CREATED, - correlation_id="empty_correlation_456", - created_at=datetime.now(), - ) - - -@pytest.fixture -def sample_job_dependency(): - """Create a sample JobDependency instance for testing.""" - return JobDependency( - id=2, # dependent job - depends_on_job_id=1, # depends on job 1 - dependency_type=DependencyType.SUCCESS_REQUIRED, - created_at=datetime.now(), - ) - - -@pytest.fixture -def setup_worker_db( - session, - sample_job_run, - sample_pipeline, - sample_empty_pipeline, - sample_job_dependency, - sample_dependent_job_run, - sample_independent_job_run, -): - """Set up the database with sample data for worker tests.""" - session.add(sample_pipeline) - session.add(sample_empty_pipeline) - session.add(sample_job_run) - session.add(sample_dependent_job_run) - session.add(sample_independent_job_run) - session.add(sample_job_dependency) - session.commit() - - -@pytest.fixture -def mock_pipeline(): - """Create a mock Pipeline instance. By default, - properties are identical to a default new Pipeline entered into the db - with sensible defaults for non-nullable but unset fields. - """ - return Mock( - spec=Pipeline, - id=1, - urn="test:pipeline:1", - name="Test Pipeline", - description="A test pipeline", - status=PipelineStatus.CREATED, - correlation_id="test_correlation_123", - metadata_={}, - created_at=datetime.now(), - started_at=None, - finished_at=None, - created_by_user_id=None, - mavedb_version=None, - ) - - -@pytest.fixture -def mock_job_run(mock_pipeline): - """Create a mock JobRun instance. By default, - properties are identical to a default new JobRun entered into the db - with sensible defaults for non-nullable but unset fields. - """ - return Mock( - spec=JobRun, - id=123, - urn="test:job:123", - job_type="test_job", - job_function="test_function", - status=JobStatus.PENDING, - pipeline_id=mock_pipeline.id, - priority=0, - max_retries=3, - retry_count=0, - retry_delay_seconds=None, - scheduled_at=datetime.now(), - started_at=None, - finished_at=None, - created_at=datetime.now(), - error_message=None, - error_traceback=None, - failure_category=None, - worker_id=None, - worker_host=None, - progress_current=None, - progress_total=None, - progress_message=None, - correlation_id=None, - metadata_={}, - mavedb_version=None, - ) From 8c5e225e81b659d257cc19fcbd7400b72fb7a3da Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Fri, 23 Jan 2026 11:46:21 -0800 Subject: [PATCH 24/70] refactor: reduce mocking of database across worker tests --- tests/worker/conftest_optional.py | 12 ++-- .../lib/decorators/test_job_guarantee.py | 25 +++----- .../decorators/test_pipeline_management.py | 64 +++++++------------ 3 files changed, 40 insertions(+), 61 deletions(-) diff --git a/tests/worker/conftest_optional.py b/tests/worker/conftest_optional.py index badebab2..a3a00f54 100644 --- a/tests/worker/conftest_optional.py +++ b/tests/worker/conftest_optional.py @@ -2,6 +2,7 @@ import pytest from arq import ArqRedis +from cdot.hgvs.dataproviders import RESTDataProvider from sqlalchemy.orm import Session from mavedb.worker.lib.managers.job_manager import JobManager @@ -45,13 +46,16 @@ def mock_pipeline_manager(mock_job_manager, mock_pipeline): @pytest.fixture -def mock_worker_ctx(): +def mock_worker_ctx(session): """Create a mock worker context dictionary for testing.""" - mock_db = Mock(spec=Session) mock_redis = Mock(spec=ArqRedis) + mock_hdp = Mock(spec=RESTDataProvider) + # Don't mock the session itself to allow real DB interactions in tests + # It's generally more pain than it's worth to mock out SQLAlchemy sessions, + # although it can sometimes be useful when raising specific exceptions. return { - "db": mock_db, + "db": session, "redis": mock_redis, - "hdp": Mock(), # Mock HDP data provider + "hdp": mock_hdp, } diff --git a/tests/worker/lib/decorators/test_job_guarantee.py b/tests/worker/lib/decorators/test_job_guarantee.py index cfdc40a1..2e1faf70 100644 --- a/tests/worker/lib/decorators/test_job_guarantee.py +++ b/tests/worker/lib/decorators/test_job_guarantee.py @@ -9,7 +9,6 @@ pytest.importorskip("arq") # Skip tests if arq is not installed import os -from unittest.mock import MagicMock, patch from sqlalchemy import select @@ -59,27 +58,21 @@ async def test_decorator_must_receive_db_in_ctx(self, mock_worker_ctx): assert "DB session not found in job context" in str(exc_info.value) async def test_decorator_calls_wrapped_function(self, mock_worker_ctx): - with patch("mavedb.worker.lib.decorators.job_guarantee.JobRun") as MockJobRunClass: - MockJobRunClass.return_value = MagicMock(spec=JobRun) - result = await sample_job(mock_worker_ctx) - + result = await sample_job(mock_worker_ctx) assert result == {"status": "ok"} - async def test_decorator_creates_job_run(self, mock_worker_ctx, mock_job_run): + async def test_decorator_creates_job_run(self, mock_worker_ctx): with ( - TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), - patch("mavedb.worker.lib.decorators.job_guarantee.JobRun") as mock_job_run_class, + TransactionSpy.spy(mock_worker_ctx["db"], expect_flush=True, expect_commit=True), ): - mock_job_run_class.return_value = MagicMock(spec=JobRun) await sample_job(mock_worker_ctx) - mock_job_run_class.assert_called_with( - job_type="test_job", - job_function="sample_job", - status=JobStatus.PENDING, - mavedb_version=__version__, - ) - mock_worker_ctx["db"].add.assert_called() + job_run = mock_worker_ctx["db"].execute(select(JobRun)).scalars().first() + assert job_run is not None + assert job_run.status == JobStatus.PENDING + assert job_run.job_type == "test_job" + assert job_run.job_function == "sample_job" + assert job_run.mavedb_version == __version__ @pytest.mark.asyncio diff --git a/tests/worker/lib/decorators/test_pipeline_management.py b/tests/worker/lib/decorators/test_pipeline_management.py index f7b2bc1e..ec947080 100644 --- a/tests/worker/lib/decorators/test_pipeline_management.py +++ b/tests/worker/lib/decorators/test_pipeline_management.py @@ -11,7 +11,7 @@ import asyncio import os -from unittest.mock import MagicMock, patch +from unittest.mock import patch from sqlalchemy import select @@ -88,15 +88,12 @@ async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): await sample_job(mock_worker_ctx, 999) async def test_decorator_fetches_pipeline_from_db_and_constructs_pipeline_manager( - self, mock_pipeline_manager, mock_worker_ctx, mock_pipeline + self, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data ): with ( # patch the with_job_management decorator to be a no-op patch("mavedb.worker.lib.decorators.pipeline_management.with_job_management", wraps=lambda f: f), patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, - patch.object( - mock_worker_ctx["db"], "execute", return_value=MagicMock(scalar_one=MagicMock(return_value=123)) - ) as mock_execute, patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None), patch.object(mock_pipeline_manager, "start_pipeline", return_value=None), TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), @@ -108,21 +105,17 @@ async def test_decorator_fetches_pipeline_from_db_and_constructs_pipeline_manage async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): return {"status": "ok"} - result = await sample_job(mock_worker_ctx, 999, pipeline_manager=mock_pipeline_manager) + result = await sample_job(mock_worker_ctx, sample_job_run.id, pipeline_manager=mock_pipeline_manager) - mock_execute.assert_called_once() assert result == {"status": "ok"} async def test_decorator_skips_coordination_and_start_when_no_pipeline_exists( - self, mock_pipeline_manager, mock_worker_ctx + self, mock_pipeline_manager, mock_worker_ctx, sample_independent_job_run, with_populated_job_data ): with ( # patch the with_job_management decorator to be a no-op patch("mavedb.worker.lib.decorators.pipeline_management.with_job_management", wraps=lambda f: f), patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, - patch.object( - mock_worker_ctx["db"], "execute", return_value=MagicMock(scalar_one=MagicMock(return_value=None)) - ) as mock_execute, patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate_pipeline, patch.object(mock_pipeline_manager, "start_pipeline", return_value=None) as mock_start_pipeline, # We shouldn't expect any commits since no pipeline coordination occurs @@ -134,23 +127,21 @@ async def test_decorator_skips_coordination_and_start_when_no_pipeline_exists( async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): return {"status": "ok"} - result = await sample_job(mock_worker_ctx, 999, pipeline_manager=mock_pipeline_manager) + result = await sample_job( + mock_worker_ctx, sample_independent_job_run.id, pipeline_manager=mock_pipeline_manager + ) - mock_execute.assert_called_once() mock_coordinate_pipeline.assert_not_called() mock_start_pipeline.assert_not_called() assert result == {"status": "ok"} async def test_decorator_starts_pipeline_when_in_created_state( - self, mock_pipeline_manager, mock_worker_ctx, mock_pipeline + self, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data ): with ( # patch the with_job_management decorator to be a no-op patch("mavedb.worker.lib.decorators.pipeline_management.with_job_management", wraps=lambda f: f), patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, - patch.object( - mock_worker_ctx["db"], "execute", return_value=MagicMock(scalar_one=MagicMock(return_value=123)) - ) as mock_execute, patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED), patch.object(mock_pipeline_manager, "start_pipeline", return_value=None) as mock_start_pipeline, patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None), @@ -162,9 +153,8 @@ async def test_decorator_starts_pipeline_when_in_created_state( async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): return {"status": "ok"} - result = await sample_job(mock_worker_ctx, 999, pipeline_manager=mock_pipeline_manager) + result = await sample_job(mock_worker_ctx, sample_job_run.id, pipeline_manager=mock_pipeline_manager) - mock_execute.assert_called_once() mock_start_pipeline.assert_called_once() assert result == {"status": "ok"} @@ -173,15 +163,12 @@ async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): [status for status in PipelineStatus._member_map_.values() if status != PipelineStatus.CREATED], ) async def test_decorator_does_not_start_pipeline_when_in_not_in_created_state( - self, mock_pipeline_manager, mock_worker_ctx, mock_pipeline, pipeline_state + self, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data, pipeline_state ): with ( # patch the with_job_management decorator to be a no-op patch("mavedb.worker.lib.decorators.pipeline_management.with_job_management", wraps=lambda f: f), patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, - patch.object( - mock_worker_ctx["db"], "execute", return_value=MagicMock(scalar_one=MagicMock(return_value=123)) - ) as mock_execute, patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=pipeline_state), patch.object(mock_pipeline_manager, "start_pipeline", return_value=None) as mock_start_pipeline, patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None), @@ -193,14 +180,13 @@ async def test_decorator_does_not_start_pipeline_when_in_not_in_created_state( async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): return {"status": "ok"} - result = await sample_job(mock_worker_ctx, 999, pipeline_manager=mock_pipeline_manager) + result = await sample_job(mock_worker_ctx, sample_job_run.id, pipeline_manager=mock_pipeline_manager) - mock_execute.assert_called_once() mock_start_pipeline.assert_not_called() assert result == {"status": "ok"} async def test_decorator_calls_wrapped_function_and_returns_result( - self, mock_pipeline_manager, mock_worker_ctx, mock_pipeline + self, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data ): with ( # patch the with_job_management decorator to be a no-op @@ -208,9 +194,6 @@ async def test_decorator_calls_wrapped_function_and_returns_result( "mavedb.worker.lib.decorators.pipeline_management.with_job_management", wraps=lambda f: f ) as mock_with_job_mgmt, patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, - patch.object( - mock_worker_ctx["db"], "execute", return_value=MagicMock(scalar_one=MagicMock(return_value=123)) - ), patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED), patch.object(mock_pipeline_manager, "start_pipeline", return_value=None), patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None), @@ -222,13 +205,13 @@ async def test_decorator_calls_wrapped_function_and_returns_result( async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): return {"status": "ok"} - result = await sample_job(mock_worker_ctx, 999, pipeline_manager=mock_pipeline_manager) + result = await sample_job(mock_worker_ctx, sample_job_run.id, pipeline_manager=mock_pipeline_manager) mock_with_job_mgmt.assert_called_once() assert result == {"status": "ok"} async def test_decorator_calls_pipeline_manager_coordinate_pipeline_after_wrapped_function( - self, mock_pipeline_manager, mock_worker_ctx, mock_pipeline + self, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data ): with ( # patch the with_job_management decorator to be a no-op @@ -237,9 +220,6 @@ async def test_decorator_calls_pipeline_manager_coordinate_pipeline_after_wrappe wraps=lambda f: f, ), patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, - patch.object( - mock_worker_ctx["db"], "execute", return_value=MagicMock(scalar_one=MagicMock(return_value=123)) - ), patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED), patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate_pipeline, patch.object(mock_pipeline_manager, "start_pipeline", return_value=None), @@ -251,11 +231,13 @@ async def test_decorator_calls_pipeline_manager_coordinate_pipeline_after_wrappe async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): return {"status": "ok"} - await sample_job(mock_worker_ctx, 999, pipeline_manager=mock_pipeline_manager) + await sample_job(mock_worker_ctx, sample_job_run.id, pipeline_manager=mock_pipeline_manager) mock_coordinate_pipeline.assert_called_once() - async def test_decorator_swallows_exception_from_wrapped_function(self, mock_pipeline_manager, mock_worker_ctx): + async def test_decorator_swallows_exception_from_wrapped_function( + self, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data + ): with ( # patch the with_job_management decorator to be a no-op patch( @@ -274,12 +256,12 @@ async def test_decorator_swallows_exception_from_wrapped_function(self, mock_pip async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): raise RuntimeError("error in wrapped function") - await sample_job(mock_worker_ctx, 999, pipeline_manager=mock_pipeline_manager) + await sample_job(mock_worker_ctx, sample_job_run.id, pipeline_manager=mock_pipeline_manager) # TODO: Assert calls for notification hooks and job result data async def test_decorator_swallows_exception_from_pipeline_manager_coordinate_pipeline( - self, mock_pipeline_manager, mock_worker_ctx + self, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data ): with ( # patch the with_job_management decorator to be a no-op @@ -305,12 +287,12 @@ async def test_decorator_swallows_exception_from_pipeline_manager_coordinate_pip async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): return {"status": "ok"} - await sample_job(mock_worker_ctx, 999, pipeline_manager=mock_pipeline_manager) + await sample_job(mock_worker_ctx, sample_job_run.id, pipeline_manager=mock_pipeline_manager) # TODO: Assert calls for notification hooks and job result data async def test_decorator_swallows_exception_from_job_management_decorator( - self, mock_pipeline_manager, mock_worker_ctx + self, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data ): def passthrough_decorator(f): return f @@ -333,7 +315,7 @@ def passthrough_decorator(f): async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): return {"status": "ok"} - await sample_job(mock_worker_ctx, 999, pipeline_manager=mock_pipeline_manager) + await sample_job(mock_worker_ctx, sample_job_run.id, pipeline_manager=mock_pipeline_manager) mock_with_job_mgmt.assert_called_once() # TODO: Assert calls for notification hooks and job result data From b0397b485e5990097af143117710a772f2840454 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Fri, 23 Jan 2026 12:19:48 -0800 Subject: [PATCH 25/70] refactor: simplify job definition in job management tests --- .../lib/decorators/test_job_management.py | 88 +++++++++---------- 1 file changed, 40 insertions(+), 48 deletions(-) diff --git a/tests/worker/lib/decorators/test_job_management.py b/tests/worker/lib/decorators/test_job_management.py index d22a37ee..ba8320f7 100644 --- a/tests/worker/lib/decorators/test_job_management.py +++ b/tests/worker/lib/decorators/test_job_management.py @@ -31,24 +31,44 @@ def unset_test_mode_flag(): os.environ.pop("MAVEDB_TEST_MODE", None) +@with_job_management +async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + """Sample job function to test the decorator. + + NOTE: The job_manager parameter is injected by the decorator + and is not passed explicitly when calling the function. + + Args: + ctx (dict): Worker context dictionary. + job_id (int): ID of the JobRun record created by the decorator. + """ + return {"status": "ok"} + + +@with_job_management +async def sample_raise(ctx: dict, job_id: int, job_manager: JobManager): + """Sample job function to test the decorator in cases where the wrapped function raises an exception. + + NOTE: The job_manager parameter is injected by the decorator + and is not passed explicitly when calling the function. + + Args: + ctx (dict): Worker context dictionary. + job_id (int): ID of the JobRun record created by the decorator. + """ + raise RuntimeError("error in wrapped function") + + @pytest.mark.asyncio @pytest.mark.unit class TestManagedJobDecoratorUnit: async def test_decorator_must_receive_ctx_as_first_argument(self, mock_job_manager): - @with_job_management - async def sample_job(not_ctx: dict, job_id: int, job_manager: JobManager): - return {"status": "ok"} - with pytest.raises(ValueError) as exc_info, TransactionSpy.spy(mock_job_manager.db): await sample_job() assert "Managed job functions must receive context as first argument" in str(exc_info.value) async def test_decorator_calls_wrapped_function_and_returns_result(self, mock_job_manager, mock_worker_ctx): - @with_job_management - async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): - return {"status": "ok"} - with ( patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, patch.object(mock_job_manager, "start_job", return_value=None), @@ -57,16 +77,12 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): ): mock_job_manager_class.return_value = mock_job_manager - result = await sample_job(mock_worker_ctx, 999, job_manager=mock_job_manager) + result = await sample_job(mock_worker_ctx, 999) assert result == {"status": "ok"} async def test_decorator_calls_start_job_and_succeed_job_when_wrapped_function_succeeds( self, mock_worker_ctx, mock_job_manager ): - @with_job_management - async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): - return {"status": "ok"} - with ( patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, patch.object(mock_job_manager, "start_job", return_value=None) as mock_start_job, @@ -74,7 +90,7 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), ): mock_job_manager_class.return_value = mock_job_manager - await sample_job(mock_worker_ctx, 999, job_manager=mock_job_manager) + await sample_job(mock_worker_ctx, 999) mock_start_job.assert_called_once() mock_succeed_job.assert_called_once() @@ -82,10 +98,6 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): async def test_decorator_calls_start_job_and_fail_job_when_wrapped_function_raises_and_no_retry( self, mock_worker_ctx, mock_job_manager ): - @with_job_management - async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): - raise RuntimeError("error in wrapped function") - with ( patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, patch.object(mock_job_manager, "start_job", return_value=None) as mock_start_job, @@ -94,7 +106,7 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True, expect_rollback=True), ): mock_job_manager_class.return_value = mock_job_manager - await sample_job(mock_worker_ctx, 999, job_manager=mock_job_manager) + await sample_raise(mock_worker_ctx, 999) mock_start_job.assert_called_once() mock_fail_job.assert_called_once() @@ -102,10 +114,6 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): async def test_decorator_calls_start_job_and_retries_job_when_wrapped_function_raises_and_retry( self, mock_worker_ctx, mock_job_manager ): - @with_job_management - async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): - raise RuntimeError("error in wrapped function") - with ( patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, patch.object(mock_job_manager, "start_job", return_value=None) as mock_start_job, @@ -114,7 +122,7 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True, expect_rollback=True), ): mock_job_manager_class.return_value = mock_job_manager - await sample_job(mock_worker_ctx, 999, job_manager=mock_job_manager) + await sample_raise(mock_worker_ctx, 999) mock_start_job.assert_called_once() mock_prepare_retry.assert_called_once_with(reason="error in wrapped function") @@ -123,14 +131,10 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): async def test_decorator_raises_value_error_if_required_context_missing( self, mock_job_manager, mock_worker_ctx, missing_key ): - @with_job_management - async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): - return {"status": "ok"} - del mock_worker_ctx[missing_key] with pytest.raises(ValueError) as exc_info: - await sample_job(mock_worker_ctx, 999, job_manager=mock_job_manager) + await sample_job(mock_worker_ctx, 999) assert missing_key.replace("_", " ") in str(exc_info.value).lower() assert "not found in job context" in str(exc_info.value).lower() @@ -138,10 +142,6 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): async def test_decorator_swallows_exception_from_lifecycle_state_outside_except( self, mock_job_manager, mock_worker_ctx ): - @with_job_management - async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): - return {"status": "ok"} - with ( patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, patch.object(mock_job_manager, "start_job", side_effect=JobStateError("error in job start")), @@ -150,15 +150,11 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): TransactionSpy.spy(mock_worker_ctx["db"], expect_rollback=True, expect_commit=True), ): mock_job_manager_class.return_value = mock_job_manager - result = await sample_job(mock_worker_ctx, 999, job_manager=mock_job_manager) + result = await sample_job(mock_worker_ctx, 999) assert "error in job start" in result["exception_details"]["message"] async def test_decorator_raises_value_error_if_job_id_missing(self, mock_job_manager, mock_worker_ctx): - @with_job_management - async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): - return {"status": "ok"} - # Remove job_id from args to simulate missing job_id with pytest.raises(ValueError) as exc_info, TransactionSpy.spy(mock_worker_ctx["db"]): await sample_job(mock_worker_ctx) @@ -168,10 +164,6 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): async def test_decorator_swallows_exception_from_wrapped_function_inside_except( self, mock_job_manager, mock_worker_ctx ): - @with_job_management - async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): - raise RuntimeError("error in wrapped function") - with ( patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, patch.object(mock_job_manager, "start_job", return_value=None), @@ -180,14 +172,14 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True, expect_rollback=True), ): mock_job_manager_class.return_value = mock_job_manager - result = await sample_job(mock_worker_ctx, 999, job_manager=mock_job_manager) + result = await sample_raise(mock_worker_ctx, 999) # Errors within the main try block should take precedence assert "error in wrapped function" in result["exception_details"]["message"] async def test_decorator_passes_job_manager_to_wrapped(self, mock_job_manager, mock_worker_ctx): @with_job_management - async def sample_job(ctx, job_id: int, job_manager): + async def assert_manager_passed_job(ctx, job_id: int, job_manager): assert isinstance(job_manager, JobManager) return True @@ -198,7 +190,7 @@ async def sample_job(ctx, job_id: int, job_manager): TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), ): mock_job_manager_class.return_value = mock_job_manager - assert await sample_job(mock_worker_ctx, 999, job_manager=mock_job_manager) + assert await assert_manager_passed_job(mock_worker_ctx, 999) @pytest.mark.asyncio @@ -218,7 +210,7 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): return {"status": "ok"} # Start the job (it will block at event.wait()) - job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id, job_manager=None)) + job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id)) # At this point, the job should be started but not completed await asyncio.sleep(0.1) # Give the event loop a moment to start the job @@ -245,7 +237,7 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): raise RuntimeError("Simulated job failure") # Start the job (it will block at event.wait()) - job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id, job_manager=None)) + job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id)) # At this point, the job should be started but not in error await asyncio.sleep(0.1) # Give the event loop a moment to start the job @@ -275,7 +267,7 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): raise RuntimeError("Simulated job failure for retry") # Start the job (it will block at event.wait()) - job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id, job_manager=None)) + job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id)) # At this point, the job should be started but not in error await asyncio.sleep(0.1) # Give the event loop a moment to start the job From a716cc99087f4b9c2276e812a43bdc40bda8d6d8 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Fri, 23 Jan 2026 12:34:31 -0800 Subject: [PATCH 26/70] refactor: simplify job definition in job management tests --- .../decorators/test_pipeline_management.py | 205 +++++++----------- 1 file changed, 79 insertions(+), 126 deletions(-) diff --git a/tests/worker/lib/decorators/test_pipeline_management.py b/tests/worker/lib/decorators/test_pipeline_management.py index ec947080..1b8ae22f 100644 --- a/tests/worker/lib/decorators/test_pipeline_management.py +++ b/tests/worker/lib/decorators/test_pipeline_management.py @@ -7,6 +7,8 @@ import pytest +from mavedb.worker.lib.managers.job_manager import JobManager + pytest.importorskip("arq") # Skip tests if arq is not installed import asyncio @@ -19,7 +21,6 @@ from mavedb.models.job_run import JobRun from mavedb.models.pipeline import Pipeline from mavedb.worker.lib.decorators.pipeline_management import with_pipeline_management -from mavedb.worker.lib.managers.job_manager import JobManager from mavedb.worker.lib.managers.pipeline_manager import PipelineManager from tests.helpers.transaction_spy import TransactionSpy @@ -31,16 +32,68 @@ def unset_test_mode_flag(): os.environ.pop("MAVEDB_TEST_MODE", None) +async def sample_job(ctx=None, job_id=None): + """Sample job function to test the decorator. When called, it patches + the with_job_management decorator to be a no-op so we can test the + with_pipeline_management decorator in isolation. + + NOTE: The job_manager parameter is normally injected by the with_job_management + decorator. Since we are patching that decorator to be a no-op here, + we do not include it in the function signature. + + Args: + ctx (dict): Worker context dictionary. + job_id (int): ID of the JobRun record created by the decorator. + """ + # patch the with_job_management decorator to be a no-op + with patch( + "mavedb.worker.lib.decorators.pipeline_management.with_job_management", wraps=lambda f: f + ) as mock_job_mgmt: + + @with_pipeline_management + async def patched_sample_job(ctx: dict, job_id: int): + return {"status": "ok"} + + return await patched_sample_job(ctx, job_id) + + # Ensure the mock was called + mock_job_mgmt.assert_called_once() + + +async def sample_raise(ctx: dict, job_id: int): + """Sample job function to test the decorator when a job raises. + When called, it patches the with_job_management decorator to be + a no-op so we can test the with_pipeline_management decorator in isolation. + + NOTE: The job_manager parameter is normally injected by the with_job_management + decorator. Since we are patching that decorator to be a no-op here, + we do not include it in the function signature. + + Args: + ctx (dict): Worker context dictionary. + job_id (int): ID of the JobRun record created by the decorator. + """ + # patch the with_job_management decorator to be a no-op + with patch( + "mavedb.worker.lib.decorators.pipeline_management.with_job_management", wraps=lambda f: f + ) as mock_job_mgmt: + + @with_pipeline_management + async def patched_sample_job(ctx: dict, job_id: int): + raise RuntimeError("error in wrapped function") + + return await patched_sample_job(ctx, job_id) + + # Ensure the mock was called + mock_job_mgmt.assert_called_once() + + @pytest.mark.asyncio @pytest.mark.unit class TestPipelineManagementDecoratorUnit: """Unit tests for the with_pipeline_management decorator.""" async def test_decorator_must_receive_ctx_as_first_argument(self, mock_pipeline_manager): - @with_pipeline_management - async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): - return {"status": "ok"} - with pytest.raises(ValueError) as exc_info, TransactionSpy.spy(mock_pipeline_manager.db): await sample_job() @@ -50,34 +103,22 @@ async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): async def test_decorator_raises_value_error_if_required_context_missing( self, mock_pipeline_manager, mock_worker_ctx, missing_key ): - @with_pipeline_management - async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): - return {"status": "ok"} - del mock_worker_ctx[missing_key] with pytest.raises(ValueError) as exc_info, TransactionSpy.spy(mock_pipeline_manager.db): - await sample_job(mock_worker_ctx, 999, mock_pipeline_manager) + await sample_job(mock_worker_ctx, 999) assert missing_key.replace("_", " ") in str(exc_info.value).lower() assert "not found in pipeline context" in str(exc_info.value).lower() async def test_decorator_raises_value_error_if_job_id_missing(self, mock_pipeline_manager, mock_worker_ctx): - @with_pipeline_management - async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): - return {"status": "ok"} - # Remove job_id from args to simulate missing job_id with pytest.raises(ValueError) as exc_info, TransactionSpy.spy(mock_pipeline_manager.db): - await sample_job(mock_worker_ctx, mock_pipeline_manager) + await sample_job(mock_worker_ctx) assert "job id not found in pipeline context" in str(exc_info.value).lower() async def test_decorator_swallows_exception_if_cant_fetch_pipeline_id(self, mock_pipeline_manager, mock_worker_ctx): - @with_pipeline_management - async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): - return {"status": "ok"} - with ( TransactionSpy.mock_database_execution_failure( mock_worker_ctx["db"], @@ -91,21 +132,13 @@ async def test_decorator_fetches_pipeline_from_db_and_constructs_pipeline_manage self, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data ): with ( - # patch the with_job_management decorator to be a no-op - patch("mavedb.worker.lib.decorators.pipeline_management.with_job_management", wraps=lambda f: f), patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None), patch.object(mock_pipeline_manager, "start_pipeline", return_value=None), TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), ): mock_pipeline_manager_class.return_value = mock_pipeline_manager - - # Sample jobs should be defined within the with scope to mock the job management decorator - @with_pipeline_management - async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): - return {"status": "ok"} - - result = await sample_job(mock_worker_ctx, sample_job_run.id, pipeline_manager=mock_pipeline_manager) + result = await sample_job(mock_worker_ctx, sample_job_run.id) assert result == {"status": "ok"} @@ -113,8 +146,6 @@ async def test_decorator_skips_coordination_and_start_when_no_pipeline_exists( self, mock_pipeline_manager, mock_worker_ctx, sample_independent_job_run, with_populated_job_data ): with ( - # patch the with_job_management decorator to be a no-op - patch("mavedb.worker.lib.decorators.pipeline_management.with_job_management", wraps=lambda f: f), patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate_pipeline, patch.object(mock_pipeline_manager, "start_pipeline", return_value=None) as mock_start_pipeline, @@ -122,14 +153,7 @@ async def test_decorator_skips_coordination_and_start_when_no_pipeline_exists( TransactionSpy.spy(mock_worker_ctx["db"]), ): mock_pipeline_manager_class.return_value = mock_pipeline_manager - - @with_pipeline_management - async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): - return {"status": "ok"} - - result = await sample_job( - mock_worker_ctx, sample_independent_job_run.id, pipeline_manager=mock_pipeline_manager - ) + result = await sample_job(mock_worker_ctx, sample_independent_job_run.id) mock_coordinate_pipeline.assert_not_called() mock_start_pipeline.assert_not_called() @@ -139,8 +163,6 @@ async def test_decorator_starts_pipeline_when_in_created_state( self, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data ): with ( - # patch the with_job_management decorator to be a no-op - patch("mavedb.worker.lib.decorators.pipeline_management.with_job_management", wraps=lambda f: f), patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED), patch.object(mock_pipeline_manager, "start_pipeline", return_value=None) as mock_start_pipeline, @@ -148,12 +170,7 @@ async def test_decorator_starts_pipeline_when_in_created_state( TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), ): mock_pipeline_manager_class.return_value = mock_pipeline_manager - - @with_pipeline_management - async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): - return {"status": "ok"} - - result = await sample_job(mock_worker_ctx, sample_job_run.id, pipeline_manager=mock_pipeline_manager) + result = await sample_job(mock_worker_ctx, sample_job_run.id) mock_start_pipeline.assert_called_once() assert result == {"status": "ok"} @@ -166,8 +183,6 @@ async def test_decorator_does_not_start_pipeline_when_in_not_in_created_state( self, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data, pipeline_state ): with ( - # patch the with_job_management decorator to be a no-op - patch("mavedb.worker.lib.decorators.pipeline_management.with_job_management", wraps=lambda f: f), patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=pipeline_state), patch.object(mock_pipeline_manager, "start_pipeline", return_value=None) as mock_start_pipeline, @@ -175,50 +190,15 @@ async def test_decorator_does_not_start_pipeline_when_in_not_in_created_state( TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), ): mock_pipeline_manager_class.return_value = mock_pipeline_manager - - @with_pipeline_management - async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): - return {"status": "ok"} - - result = await sample_job(mock_worker_ctx, sample_job_run.id, pipeline_manager=mock_pipeline_manager) + result = await sample_job(mock_worker_ctx, sample_job_run.id) mock_start_pipeline.assert_not_called() assert result == {"status": "ok"} - async def test_decorator_calls_wrapped_function_and_returns_result( - self, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data - ): - with ( - # patch the with_job_management decorator to be a no-op - patch( - "mavedb.worker.lib.decorators.pipeline_management.with_job_management", wraps=lambda f: f - ) as mock_with_job_mgmt, - patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, - patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED), - patch.object(mock_pipeline_manager, "start_pipeline", return_value=None), - patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None), - TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), - ): - mock_pipeline_manager_class.return_value = mock_pipeline_manager - - @with_pipeline_management - async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): - return {"status": "ok"} - - result = await sample_job(mock_worker_ctx, sample_job_run.id, pipeline_manager=mock_pipeline_manager) - - mock_with_job_mgmt.assert_called_once() - assert result == {"status": "ok"} - async def test_decorator_calls_pipeline_manager_coordinate_pipeline_after_wrapped_function( self, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data ): with ( - # patch the with_job_management decorator to be a no-op - patch( - "mavedb.worker.lib.decorators.pipeline_management.with_job_management", - wraps=lambda f: f, - ), patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED), patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate_pipeline, @@ -226,12 +206,7 @@ async def test_decorator_calls_pipeline_manager_coordinate_pipeline_after_wrappe TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), ): mock_pipeline_manager_class.return_value = mock_pipeline_manager - - @with_pipeline_management - async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): - return {"status": "ok"} - - await sample_job(mock_worker_ctx, sample_job_run.id, pipeline_manager=mock_pipeline_manager) + await sample_job(mock_worker_ctx, sample_job_run.id) mock_coordinate_pipeline.assert_called_once() @@ -239,11 +214,6 @@ async def test_decorator_swallows_exception_from_wrapped_function( self, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data ): with ( - # patch the with_job_management decorator to be a no-op - patch( - "mavedb.worker.lib.decorators.pipeline_management.with_job_management", - wraps=lambda f: f, - ), patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None), patch.object(mock_pipeline_manager, "start_pipeline", return_value=None), @@ -251,12 +221,7 @@ async def test_decorator_swallows_exception_from_wrapped_function( TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True, expect_rollback=True), ): mock_pipeline_manager_class.return_value = mock_pipeline_manager - - @with_pipeline_management - async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): - raise RuntimeError("error in wrapped function") - - await sample_job(mock_worker_ctx, sample_job_run.id, pipeline_manager=mock_pipeline_manager) + await sample_raise(mock_worker_ctx, sample_job_run.id) # TODO: Assert calls for notification hooks and job result data @@ -264,11 +229,6 @@ async def test_decorator_swallows_exception_from_pipeline_manager_coordinate_pip self, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data ): with ( - # patch the with_job_management decorator to be a no-op - patch( - "mavedb.worker.lib.decorators.pipeline_management.with_job_management", - wraps=lambda f: f, - ), patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, patch.object( mock_pipeline_manager, @@ -282,12 +242,7 @@ async def test_decorator_swallows_exception_from_pipeline_manager_coordinate_pip TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True, expect_rollback=True), ): mock_pipeline_manager_class.return_value = mock_pipeline_manager - - @with_pipeline_management - async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): - return {"status": "ok"} - - await sample_job(mock_worker_ctx, sample_job_run.id, pipeline_manager=mock_pipeline_manager) + await sample_job(mock_worker_ctx, sample_job_run.id) # TODO: Assert calls for notification hooks and job result data @@ -348,17 +303,17 @@ async def test_decorator_integrated_pipeline_lifecycle_success( session.commit() @with_pipeline_management - async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + async def sample_job(ctx: dict, job_id: int): await event.wait() # Simulate async work, block until test signals return {"status": "ok"} @with_pipeline_management - async def sample_dependent_job(ctx: dict, job_id: int, job_manager: JobManager): + async def sample_dependent_job(ctx: dict, job_id: int): await dep_event.wait() # Simulate async work, block until test signals return {"status": "ok"} # Start the job (it will block at event.wait()) - job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id, job_manager=None)) + job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id)) # At this point, the job should be started but not completed await asyncio.sleep(0.1) # Give the event loop a moment to start the job @@ -389,7 +344,7 @@ async def sample_dependent_job(ctx: dict, job_id: int, job_manager: JobManager): # Simulate execution of next job by running the dependent job. # Start the job (it will block at event.wait()) dependent_job_task = asyncio.create_task( - sample_dependent_job(standalone_worker_context, sample_dependent_job_run.id, job_manager=None) + sample_dependent_job(standalone_worker_context, sample_dependent_job_run.id) ) # At this point, the job should be started but not completed @@ -434,22 +389,22 @@ async def test_decorator_integrated_pipeline_lifecycle_retryable_failure( dep_event = asyncio.Event() @with_pipeline_management - async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + async def sample_job(ctx: dict, job_id: int): await event.wait() # Simulate async work, block until test signals raise RuntimeError("Simulated job failure for retry") @with_pipeline_management - async def sample_retried_job(ctx: dict, job_id: int, job_manager: JobManager): + async def sample_retried_job(ctx: dict, job_id: int): await retry_event.wait() # Simulate async work, block until test signals return {"status": "ok"} @with_pipeline_management - async def sample_dependent_job(ctx: dict, job_id: int, job_manager: JobManager): + async def sample_dependent_job(ctx: dict, job_id: int): await dep_event.wait() # Simulate async work, block until test signals return {"status": "ok"} # Start the job (it will block at event.wait()) - job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id, job_manager=None)) + job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id)) # At this point, the job should be started but not completed await asyncio.sleep(0.1) # Give the event loop a moment to start the job @@ -471,9 +426,7 @@ async def sample_dependent_job(ctx: dict, job_id: int, job_manager: JobManager): assert job.retry_count == 1 # Ensure it attempted once before retrying # Now start the retried job (it will block at retry_event.wait()) - retried_job_task = asyncio.create_task( - sample_retried_job(standalone_worker_context, sample_job_run.id, job_manager=None) - ) + retried_job_task = asyncio.create_task(sample_retried_job(standalone_worker_context, sample_job_run.id)) await asyncio.sleep(0.1) # Give the event loop a moment to start the job job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() assert job.status == JobStatus.RUNNING @@ -500,7 +453,7 @@ async def sample_dependent_job(ctx: dict, job_id: int, job_manager: JobManager): # Simulate execution of next job by running the dependent job. # Start the job (it will block at event.wait()) dependent_job_task = asyncio.create_task( - sample_dependent_job(standalone_worker_context, sample_dependent_job_run.id, job_manager=None) + sample_dependent_job(standalone_worker_context, sample_dependent_job_run.id) ) # At this point, the job should be started but not completed @@ -542,12 +495,12 @@ async def test_decorator_integrated_pipeline_lifecycle_non_retryable_failure( event = asyncio.Event() @with_pipeline_management - async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + async def sample_job(ctx: dict, job_id: int): await event.wait() # Simulate async work, block until test signals raise RuntimeError("Simulated job failure") # Start the job (it will block at event.wait()) - job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id, job_manager=None)) + job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id)) # At this point, the job should be started but not completed await asyncio.sleep(0.1) # Give the event loop a moment to start the job From 8a2230662836d1f81091a7c04a1a34e25caa65da Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Fri, 23 Jan 2026 12:40:21 -0800 Subject: [PATCH 27/70] refactor: centralize decorator test mode flag fixture --- tests/worker/lib/decorators/conftest.py | 10 ++++++++++ .../lib/decorators/test_job_guarantee.py | 9 --------- .../lib/decorators/test_job_management.py | 8 -------- .../decorators/test_pipeline_management.py | 20 ++++++------------- 4 files changed, 16 insertions(+), 31 deletions(-) create mode 100644 tests/worker/lib/decorators/conftest.py diff --git a/tests/worker/lib/decorators/conftest.py b/tests/worker/lib/decorators/conftest.py new file mode 100644 index 00000000..851d7497 --- /dev/null +++ b/tests/worker/lib/decorators/conftest.py @@ -0,0 +1,10 @@ +import os + +import pytest + + +# Unset test mode flag before each test to ensure decorator logic is executed +# during unit testing of the decorator itself. +@pytest.fixture(autouse=True) +def unset_test_mode_flag(): + os.environ.pop("MAVEDB_TEST_MODE", None) diff --git a/tests/worker/lib/decorators/test_job_guarantee.py b/tests/worker/lib/decorators/test_job_guarantee.py index 2e1faf70..1371fed3 100644 --- a/tests/worker/lib/decorators/test_job_guarantee.py +++ b/tests/worker/lib/decorators/test_job_guarantee.py @@ -8,8 +8,6 @@ pytest.importorskip("arq") # Skip tests if arq is not installed -import os - from sqlalchemy import select from mavedb import __version__ @@ -19,13 +17,6 @@ from tests.helpers.transaction_spy import TransactionSpy -# Unset test mode flag before each test to ensure decorator logic is executed -# during unit testing of the decorator itself. -@pytest.fixture(autouse=True) -def unset_test_mode_flag(): - os.environ.pop("MAVEDB_TEST_MODE", None) - - @with_guaranteed_job_run_record("test_job") async def sample_job(ctx: dict, job_id: int): """Sample job function to test the decorator. diff --git a/tests/worker/lib/decorators/test_job_management.py b/tests/worker/lib/decorators/test_job_management.py index ba8320f7..261bdcaa 100644 --- a/tests/worker/lib/decorators/test_job_management.py +++ b/tests/worker/lib/decorators/test_job_management.py @@ -10,7 +10,6 @@ pytest.importorskip("arq") # Skip tests if arq is not installed import asyncio -import os from unittest.mock import patch from sqlalchemy import select @@ -24,13 +23,6 @@ from tests.helpers.transaction_spy import TransactionSpy -# Unset test mode flag before each test to ensure decorator logic is executed -# during unit testing of the decorator itself. -@pytest.fixture(autouse=True) -def unset_test_mode_flag(): - os.environ.pop("MAVEDB_TEST_MODE", None) - - @with_job_management async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): """Sample job function to test the decorator. diff --git a/tests/worker/lib/decorators/test_pipeline_management.py b/tests/worker/lib/decorators/test_pipeline_management.py index 1b8ae22f..d951a67b 100644 --- a/tests/worker/lib/decorators/test_pipeline_management.py +++ b/tests/worker/lib/decorators/test_pipeline_management.py @@ -12,7 +12,6 @@ pytest.importorskip("arq") # Skip tests if arq is not installed import asyncio -import os from unittest.mock import patch from sqlalchemy import select @@ -25,13 +24,6 @@ from tests.helpers.transaction_spy import TransactionSpy -# Unset test mode flag before each test to ensure decorator logic is executed -# during unit testing of the decorator itself. -@pytest.fixture(autouse=True) -def unset_test_mode_flag(): - os.environ.pop("MAVEDB_TEST_MODE", None) - - async def sample_job(ctx=None, job_id=None): """Sample job function to test the decorator. When called, it patches the with_job_management decorator to be a no-op so we can test the @@ -303,12 +295,12 @@ async def test_decorator_integrated_pipeline_lifecycle_success( session.commit() @with_pipeline_management - async def sample_job(ctx: dict, job_id: int): + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): await event.wait() # Simulate async work, block until test signals return {"status": "ok"} @with_pipeline_management - async def sample_dependent_job(ctx: dict, job_id: int): + async def sample_dependent_job(ctx: dict, job_id: int, job_manager: JobManager): await dep_event.wait() # Simulate async work, block until test signals return {"status": "ok"} @@ -389,17 +381,17 @@ async def test_decorator_integrated_pipeline_lifecycle_retryable_failure( dep_event = asyncio.Event() @with_pipeline_management - async def sample_job(ctx: dict, job_id: int): + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): await event.wait() # Simulate async work, block until test signals raise RuntimeError("Simulated job failure for retry") @with_pipeline_management - async def sample_retried_job(ctx: dict, job_id: int): + async def sample_retried_job(ctx: dict, job_id: int, job_manager: JobManager): await retry_event.wait() # Simulate async work, block until test signals return {"status": "ok"} @with_pipeline_management - async def sample_dependent_job(ctx: dict, job_id: int): + async def sample_dependent_job(ctx: dict, job_id: int, job_manager: JobManager): await dep_event.wait() # Simulate async work, block until test signals return {"status": "ok"} @@ -495,7 +487,7 @@ async def test_decorator_integrated_pipeline_lifecycle_non_retryable_failure( event = asyncio.Event() @with_pipeline_management - async def sample_job(ctx: dict, job_id: int): + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): await event.wait() # Simulate async work, block until test signals raise RuntimeError("Simulated job failure") From 92ab08188927890179370aa0dd525c77b32e2714 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Fri, 23 Jan 2026 17:08:46 -0800 Subject: [PATCH 28/70] feat: enhance pipeline start logic with controllable coordination From certain decorator contexts, we wish to not coordinate the pipeline after starting it. This prevents jobs from being double enqueued mistakenly. --- .../lib/decorators/pipeline_management.py | 7 ++-- .../worker/lib/managers/pipeline_manager.py | 13 +++++-- .../lib/managers/test_pipeline_manager.py | 36 +++++++++++++------ 3 files changed, 40 insertions(+), 16 deletions(-) diff --git a/src/mavedb/worker/lib/decorators/pipeline_management.py b/src/mavedb/worker/lib/decorators/pipeline_management.py index 3bede53f..d5ece4f6 100644 --- a/src/mavedb/worker/lib/decorators/pipeline_management.py +++ b/src/mavedb/worker/lib/decorators/pipeline_management.py @@ -128,9 +128,12 @@ async def _execute_managed_pipeline(func: Callable[..., Awaitable[JobResultData] logger.info(f"Pipeline ID for job {job_id} is {pipeline_id}. Coordinating pipeline.") - # If the pipeline is still in the created state, start it now + # If the pipeline is still in the created state, start it now. From this context, + # we do not wish to coordinate the pipeline. Doing so would result in the current + # job being re-queued before it has been marked as running, leading to potential state + # inconsistencies. if pipeline_manager and pipeline_manager.get_pipeline_status() == PipelineStatus.CREATED: - await pipeline_manager.start_pipeline() + await pipeline_manager.start_pipeline(coordinate=False) db_session.commit() logger.info(f"Pipeline {pipeline_id} associated with job {job_id} started successfully") diff --git a/src/mavedb/worker/lib/managers/pipeline_manager.py b/src/mavedb/worker/lib/managers/pipeline_manager.py index a81a2738..74f6d344 100644 --- a/src/mavedb/worker/lib/managers/pipeline_manager.py +++ b/src/mavedb/worker/lib/managers/pipeline_manager.py @@ -156,11 +156,11 @@ def __init__(self, db: Session, redis: ArqRedis, pipeline_id: int): self.pipeline_id = pipeline_id self.get_pipeline() # Validate pipeline exists on init - async def start_pipeline(self) -> None: + async def start_pipeline(self, coordinate: bool = True) -> None: """Start the pipeline Entry point to start pipeline execution. Sets pipeline status to RUNNING - and enqueues independent jobs using coordinate pipeline. + and enqueues independent jobs using coordinate pipeline if coordinate is True. Raises: DatabaseConnectionError: Cannot query or update pipeline @@ -183,7 +183,14 @@ async def start_pipeline(self) -> None: self.db.flush() logger.info(f"Pipeline {self.pipeline_id} started successfully") - await self.coordinate_pipeline() + + # Allow controllable coordination logic. By default, we want to coordinate + # immediately after starting to enqueue independent jobs. However, if a job + # has already been enqueued and is beginning execution and starts the pipeline, + # as a result of its job management decorator, we want to skip coordination here + # so we do not double-enqueue jobs. + if coordinate: + await self.coordinate_pipeline() async def coordinate_pipeline(self) -> None: """Coordinate pipeline after a job completes. diff --git a/tests/worker/lib/managers/test_pipeline_manager.py b/tests/worker/lib/managers/test_pipeline_manager.py index 5c57ba3f..cb7de415 100644 --- a/tests/worker/lib/managers/test_pipeline_manager.py +++ b/tests/worker/lib/managers/test_pipeline_manager.py @@ -82,7 +82,11 @@ class TestStartPipelineUnit: """Unit tests for starting a pipeline.""" @pytest.mark.asyncio - async def test_start_pipeline_successful(self, mock_pipeline_manager): + @pytest.mark.parametrize( + "coordinate_after_start", + [True, False], + ) + async def test_start_pipeline_successful(self, mock_pipeline_manager, coordinate_after_start): """Test successful pipeline start from CREATED state.""" with ( patch.object( @@ -94,10 +98,13 @@ async def test_start_pipeline_successful(self, mock_pipeline_manager): patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate, TransactionSpy.spy(mock_pipeline_manager.db, expect_flush=True), ): - await mock_pipeline_manager.start_pipeline() + await mock_pipeline_manager.start_pipeline(coordinate=coordinate_after_start) mock_set_status.assert_called_once_with(PipelineStatus.RUNNING) - mock_coordinate.assert_called_once() + if coordinate_after_start: + mock_coordinate.assert_called_once() + else: + mock_coordinate.assert_not_called() @pytest.mark.asyncio @pytest.mark.parametrize( @@ -131,14 +138,18 @@ class TestStartPipelineIntegration: """Integration tests for starting a pipeline.""" @pytest.mark.asyncio + @pytest.mark.parametrize( + "coordinate_after_start", + [True, False], + ) async def test_start_pipeline_successful( - self, session, arq_redis, with_populated_job_data, sample_pipeline, sample_job_run + self, session, arq_redis, with_populated_job_data, sample_pipeline, sample_job_run, coordinate_after_start ): """Test successful pipeline start from CREATED state.""" manager = PipelineManager(session, arq_redis, sample_pipeline.id) with TransactionSpy.spy(session, expect_flush=True): - await manager.start_pipeline() + await manager.start_pipeline(coordinate=coordinate_after_start) # Commit the session to persist changes session.commit() @@ -147,13 +158,16 @@ async def test_start_pipeline_successful( pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() assert pipeline.status == PipelineStatus.RUNNING - # Verify the initial job was queued + # Verify the initial job was queued if we are coordinating after start job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() - assert job.status == JobStatus.QUEUED - - # Verify the job was enqueued in Redis jobs = await arq_redis.queued_jobs() - assert jobs[0].function == sample_job_run.job_function + + if coordinate_after_start: + assert job.status == JobStatus.QUEUED + assert jobs[0].function == sample_job_run.job_function + else: + assert job.status == JobStatus.PENDING + assert len(jobs) == 0 @pytest.mark.asyncio async def test_start_pipeline_no_jobs(self, session, arq_redis, with_populated_job_data, sample_empty_pipeline): @@ -161,7 +175,7 @@ async def test_start_pipeline_no_jobs(self, session, arq_redis, with_populated_j manager = PipelineManager(session, arq_redis, sample_empty_pipeline.id) with TransactionSpy.spy(session, expect_flush=True): - await manager.start_pipeline() + await manager.start_pipeline(coordinate=True) # Commit the session to persist changes session.commit() From a06aa21581423f2516e50fa66fb62f80560cc002 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Sat, 24 Jan 2026 13:37:20 -0800 Subject: [PATCH 29/70] feat: logic fixups and comprehensive test cases for variant processing jobs --- src/mavedb/lib/mapping.py | 2 + .../jobs/variant_processing/creation.py | 112 +- .../worker/jobs/variant_processing/mapping.py | 130 +- tests/conftest_optional.py | 9 +- tests/helpers/constants.py | 57 +- tests/helpers/util/mapping.py | 6 - tests/helpers/util/setup/worker.py | 193 +- tests/worker/conftest.py | 176 +- tests/worker/conftest_optional.py | 3 + tests/worker/data/counts.csv | 9 +- tests/worker/data/scores.csv | 9 +- .../jobs/variant_processing/conftest.py | 191 ++ .../jobs/variant_processing/test_creation.py | 1404 ++++++++++++++ .../jobs/variant_processing/test_mapping.py | 1650 +++++++++++++++++ 14 files changed, 3585 insertions(+), 366 deletions(-) delete mode 100644 tests/helpers/util/mapping.py create mode 100644 tests/worker/jobs/variant_processing/conftest.py diff --git a/src/mavedb/lib/mapping.py b/src/mavedb/lib/mapping.py index d3915f53..0f601e85 100644 --- a/src/mavedb/lib/mapping.py +++ b/src/mavedb/lib/mapping.py @@ -9,6 +9,8 @@ "c": "cdna", } +EXCLUDED_PREMAPPED_ANNOTATION_KEYS = {"sequence"} + class VRSMap: url: str diff --git a/src/mavedb/worker/jobs/variant_processing/creation.py b/src/mavedb/worker/jobs/variant_processing/creation.py index f71c5ed8..27a5a1aa 100644 --- a/src/mavedb/worker/jobs/variant_processing/creation.py +++ b/src/mavedb/worker/jobs/variant_processing/creation.py @@ -5,14 +5,17 @@ pipeline including data validation, standardization, and database persistence. """ +import io import logging +import pandas as pd from sqlalchemy import delete, null, select -from mavedb.data_providers.services import RESTDataProvider +from mavedb.data_providers.services import CSV_UPLOAD_S3_BUCKET_NAME, RESTDataProvider, s3_client from mavedb.lib.logging.context import format_raised_exception_info_as_dict from mavedb.lib.score_sets import columns_for_dataset, create_variants, create_variants_data from mavedb.lib.validation.dataframe.dataframe import validate_and_standardize_dataframe_pair +from mavedb.lib.validation.exceptions import ValidationError from mavedb.models.enums.mapping_state import MappingState from mavedb.models.enums.processing_state import ProcessingState from mavedb.models.mapped_variant import MappedVariant @@ -28,20 +31,21 @@ @with_pipeline_management -async def create_variants_for_score_set(ctx, job_manager: JobManager) -> JobResultData: +async def create_variants_for_score_set(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData: """ Create variants for a given ScoreSet based on uploaded score and count data. Args: ctx: The job context dictionary. + job_id: The ID of the job being executed. job_manager: Manager for job lifecycle and DB operations. Job Parameters: - score_set_id (int): The ID of the ScoreSet to create variants for. - correlation_id (str): Correlation ID for tracing requests across services. - updater_id (int): The ID of the user performing the update. - - scores (pd.DataFrame): DataFrame containing score data. - - counts (pd.DataFrame): DataFrame containing count data. + - scores_file_key (str): S3 key for the uploaded scores CSV file. + - counts_file_key (str): S3 key for the uploaded counts CSV file. - score_columns_metadata (dict): Metadata for score columns. - count_columns_metadata (dict): Metadata for count columns. @@ -51,6 +55,10 @@ async def create_variants_for_score_set(ctx, job_manager: JobManager) -> JobResu Returns: dict: Result indicating success and any exception details """ + # Handle everything prior to score set fetch in an outer layer. Any issues prior to + # fetching the score set should fail the job outright and we will be unable to set + # a processing state on the score set itself. + logger.info(msg="Starting create_variants_for_score_set job", extra=job_manager.logging_context()) hdp: RESTDataProvider = ctx["hdp"] # Get the job definition we are working on @@ -60,40 +68,68 @@ async def create_variants_for_score_set(ctx, job_manager: JobManager) -> JobResu "score_set_id", "correlation_id", "updater_id", - "scores", - "counts", + "scores_file_key", + "counts_file_key", "score_columns_metadata", "count_columns_metadata", ] - validate_job_params(job_manager, _job_required_params, job) + validate_job_params(_job_required_params, job) # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one() # type: ignore - correlation_id = job.job_params["correlation_id"] # type: ignore - updater_id = job.job_params["updater_id"] # type: ignore - scores = job.job_params["scores"] # type: ignore - counts = job.job_params["counts"] # type: ignore - score_columns_metadata = job.job_params["score_columns_metadata"] # type: ignore - count_columns_metadata = job.job_params["count_columns_metadata"] # type: ignore - - # Setup initial context and progress - job_manager.save_to_context( - { - "application": "mavedb-worker", - "function": "create_variants_for_score_set", - "resource": score_set.urn, - "correlation_id": correlation_id, - } - ) - job_manager.update_progress(0, 100, "Starting variant creation job.") - logger.info(msg="Started variant creation job", extra=job_manager.logging_context()) - - updated_by = job_manager.db.scalars(select(User).where(User.id == updater_id)).one() # Main processing block. Handled in a try/except to ensure we can set score set state appropriately, # which is handled independently of the job state. - # TODO:XXX In a future iteration, we may want to move this logic into the job manager itself for better cohesion. + # TODO:XXX In a future iteration, we should rely on the job manager itself for maintaining processing + # state for better cohesion. This try/except is redundant in it's duties with the job manager. try: + correlation_id = job.job_params["correlation_id"] # type: ignore + updater_id = job.job_params["updater_id"] # type: ignore + score_file_key = job.job_params["scores_file_key"] # type: ignore + count_file_key = job.job_params["counts_file_key"] # type: ignore + score_columns_metadata = job.job_params["score_columns_metadata"] # type: ignore + count_columns_metadata = job.job_params["count_columns_metadata"] # type: ignore + + job_manager.save_to_context( + { + "score_set_id": score_set.id, + "updater_id": updater_id, + "correlation_id": correlation_id, + "score_file_key": score_file_key, + "count_file_key": count_file_key, + "bucket_name": CSV_UPLOAD_S3_BUCKET_NAME, + } + ) + logger.debug(msg="Fetching file resources from S3 for variant creation", extra=job_manager.logging_context()) + + s3 = s3_client() + scores = io.BytesIO() + s3.download_fileobj(Bucket=CSV_UPLOAD_S3_BUCKET_NAME, Key=score_file_key, Fileobj=scores) + scores_df = pd.read_csv(scores) + + # Counts file is optional + counts_df = None + if count_file_key: + counts = io.BytesIO() + s3.download_fileobj(Bucket=CSV_UPLOAD_S3_BUCKET_NAME, Key=count_file_key, Fileobj=counts) + counts_df = pd.read_csv(counts) + + logger.debug(msg="Successfully fetched file resources from S3", extra=job_manager.logging_context()) + + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "create_variants_for_score_set", + "resource": score_set.urn, + "correlation_id": correlation_id, + } + ) + job_manager.update_progress(0, 100, "Starting variant creation job.") + logger.info(msg="Started variant creation job", extra=job_manager.logging_context()) + + updated_by = job_manager.db.scalars(select(User).where(User.id == updater_id)).one() + score_set.modified_by = updated_by score_set.processing_state = ProcessingState.processing score_set.mapping_state = MappingState.pending_variant_processing @@ -118,8 +154,8 @@ async def create_variants_for_score_set(ctx, job_manager: JobManager) -> JobResu validated_scores, validated_counts, validated_score_columns_metadata, validated_count_columns_metadata = ( validate_and_standardize_dataframe_pair( - scores_df=scores, - counts_df=counts, + scores_df=scores_df, + counts_df=counts_df, score_columns_metadata=score_columns_metadata, count_columns_metadata=count_columns_metadata, targets=score_set.target_genes, @@ -140,8 +176,6 @@ async def create_variants_for_score_set(ctx, job_manager: JobManager) -> JobResu else {}, } - job_manager.update_progress(90, 100, "Creating variants in database.") - # Delete variants after validation occurs so we don't overwrite them in the case of a bad update. if score_set.variants: existing_variants = job_manager.db.scalars( @@ -161,14 +195,17 @@ async def create_variants_for_score_set(ctx, job_manager: JobManager) -> JobResu variants_data = create_variants_data(validated_scores, validated_counts, None) create_variants(job_manager.db, score_set, variants_data) - # NOTE: Since these are likely to be internal errors, it makes less sense to add them to the DB and surface them to the end user. - # Catch all exceptions so we can log them and set score set state appropriately. except Exception as e: job_manager.db.rollback() score_set.processing_state = ProcessingState.failed - score_set.processing_errors = {"exception": str(e), "detail": []} score_set.mapping_state = MappingState.not_attempted + # Capture exception details in score set processing errors for all exceptions. + score_set.processing_errors = {"exception": str(e), "detail": []} + # ValidationErrors arise from problematic input data; capture their details specifically. + if isinstance(e, ValidationError): + score_set.processing_errors["detail"] = e.triggering_exceptions + if score_set.num_variants: score_set.processing_errors["exception"] = ( f"Update failed, variants were not updated. {score_set.processing_errors.get('exception', '')}" @@ -207,7 +244,6 @@ async def create_variants_for_score_set(ctx, job_manager: JobManager) -> JobResu job_manager.db.commit() job_manager.db.refresh(score_set) - job_manager.update_progress(100, 100, "Completed variant creation job.") - logger.info(msg="Committed new variants to score set.", extra=job_manager.logging_context()) - + job_manager.update_progress(100, 100, "Completed variant creation job.") + logger.info(msg="Committed new variants to score set.", extra=job_manager.logging_context()) return {"status": "ok", "data": {}, "exception_details": None} diff --git a/src/mavedb/worker/jobs/variant_processing/mapping.py b/src/mavedb/worker/jobs/variant_processing/mapping.py index 848c7b06..184041ea 100644 --- a/src/mavedb/worker/jobs/variant_processing/mapping.py +++ b/src/mavedb/worker/jobs/variant_processing/mapping.py @@ -21,7 +21,7 @@ NonexistentMappingScoresError, ) from mavedb.lib.logging.context import format_raised_exception_info_as_dict -from mavedb.lib.mapping import ANNOTATION_LAYERS +from mavedb.lib.mapping import ANNOTATION_LAYERS, EXCLUDED_PREMAPPED_ANNOTATION_KEYS from mavedb.lib.slack import send_slack_error from mavedb.models.enums.mapping_state import MappingState from mavedb.models.mapped_variant import MappedVariant @@ -37,9 +37,12 @@ @with_pipeline_management -async def map_variants_for_score_set(ctx: dict, job_manager: JobManager) -> JobResultData: +async def map_variants_for_score_set(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData: """Map variants for a given score set using VRS.""" - # Get the job definition we are working on + # Handle everything prior to score set fetch in an outer layer. Any issues prior to + # fetching the score set should fail the job outright and we will be unable to set + # a processing state on the score set itself. + job = job_manager.get_job() _job_required_params = [ @@ -47,32 +50,33 @@ async def map_variants_for_score_set(ctx: dict, job_manager: JobManager) -> JobR "correlation_id", "updater_id", ] - validate_job_params(job_manager, _job_required_params, job) + validate_job_params(_job_required_params, job) # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one() # type: ignore - correlation_id = job.job_params["correlation_id"] # type: ignore - updater_id = job.job_params["updater_id"] # type: ignore - updated_by = job_manager.db.scalars(select(User).where(User.id == updater_id)).one() - - # Setup initial context and progress - job_manager.save_to_context( - { - "application": "mavedb-worker", - "function": "map_variants_for_score_set", - "resource": score_set.urn, - "correlation_id": correlation_id, - } - ) - job_manager.update_progress(0, 100, "Starting variant mapping job.") - logger.info(msg="Started variant mapping job", extra=job_manager.logging_context()) - - # TODO#372: non-nullable URNs - if not score_set.urn: - raise ValueError("Score set URN is required for variant mapping.") # Handle everything within try/except to persist appropriate mapping state try: + correlation_id = job.job_params["correlation_id"] # type: ignore + updater_id = job.job_params["updater_id"] # type: ignore + updated_by = job_manager.db.scalars(select(User).where(User.id == updater_id)).one() + + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "map_variants_for_score_set", + "resource": score_set.urn, + "correlation_id": correlation_id, + } + ) + job_manager.update_progress(0, 100, "Starting variant mapping job.") + logger.info(msg="Started variant mapping job", extra=job_manager.logging_context()) + + # TODO#372: non-nullable URNs + if not score_set.urn: # pragma: no cover + raise ValueError("Score set URN is required for variant mapping.") + # Setup score set state for mapping score_set.mapping_state = MappingState.processing score_set.mapping_errors = null() @@ -98,74 +102,37 @@ async def map_variants_for_score_set(ctx: dict, job_manager: JobManager) -> JobR mapping_results = await loop.run_in_executor(ctx["pool"], blocking) logger.debug(msg="Done mapping variants.", extra=job_manager.logging_context()) - job_manager.update_progress(80, 100, "Processing mapped variants and updating database.") + job_manager.update_progress(80, 100, "Processing mapped variants.") - ## Check our assumptions about mapping results and handle errors appropriately. Don't raise exceptions directly, - ## the try/except handling is intended for unexpected errors only. + ## Check our assumptions about mapping results and handle errors appropriately. # Ensure we have mapping results if not mapping_results: - score_set.mapping_state = MappingState.failed + job_manager.db.rollback() score_set.mapping_errors = {"error_message": "Mapping results were not returned from VRS mapping service."} - job_manager.db.add(score_set) - job_manager.db.commit() - job_manager.update_progress(100, 100, "Variant mapping failed due to missing results.") - job_manager.save_to_context({"mapping_state": score_set.mapping_state.name}) logger.error( msg="Mapping results were not returned from VRS mapping service.", extra=job_manager.logging_context() ) - return { - "status": "error", - "data": {}, - "exception_details": { - "message": "Mapping results were not returned from VRS mapping service.", - "type": NonexistentMappingResultsError.__name__, - "traceback": None, - }, - } + raise NonexistentMappingResultsError("Mapping results were not returned from VRS mapping service.") # Ensure we have mapped scores mapped_scores = mapping_results.get("mapped_scores") if not mapped_scores: - score_set.mapping_state = MappingState.failed + job_manager.db.rollback() score_set.mapping_errors = {"error_message": mapping_results.get("error_message")} - job_manager.db.add(score_set) - job_manager.db.commit() - job_manager.update_progress(100, 100, "Variant mapping failed; no variants were mapped.") - job_manager.save_to_context({"mapping_state": score_set.mapping_state.name}) logger.error(msg="No variants were mapped for this score set.", extra=job_manager.logging_context()) - return { - "status": "error", - "data": {}, - "exception_details": { - "message": "No variants were mapped for this score set.", - "type": NonexistentMappingScoresError.__name__, - "traceback": None, - }, - } + raise NonexistentMappingScoresError("No variants were mapped for this score set.") # Ensure we have reference metadata reference_metadata = mapping_results.get("reference_sequences") if not reference_metadata: - score_set.mapping_state = MappingState.failed + job_manager.db.rollback() score_set.mapping_errors = {"error_message": "Reference metadata missing from mapping results."} - job_manager.db.add(score_set) - job_manager.db.commit() - job_manager.update_progress(100, 100, "Variant mapping failed due to missing reference metadata.") - job_manager.save_to_context({"mapping_state": score_set.mapping_state.name}) logger.error(msg="Reference metadata missing from mapping results.", extra=job_manager.logging_context()) - return { - "status": "error", - "data": {}, - "exception_details": { - "message": "Reference metadata missing from mapping results.", - "type": NonexistentMappingReferenceError.__name__, - "traceback": None, - }, - } + raise NonexistentMappingReferenceError("Reference metadata missing from mapping results.") # Process and store mapped variants for target_gene_identifier in reference_metadata: @@ -185,7 +152,6 @@ async def map_variants_for_score_set(ctx: dict, job_manager: JobManager) -> JobR # allow for multiple annotation layers pre_mapped_metadata: dict[str, Any] = {} post_mapped_metadata: dict[str, Any] = {} - excluded_pre_mapped_keys = {"sequence"} # add gene-level info gene_info = reference_metadata[target_gene_identifier].get("gene_info") @@ -203,7 +169,8 @@ async def map_variants_for_score_set(ctx: dict, job_manager: JobManager) -> JobR ) if layer_premapped: pre_mapped_metadata[ANNOTATION_LAYERS[annotation_layer]] = { - k: layer_premapped[k] for k in set(list(layer_premapped.keys())) - excluded_pre_mapped_keys + k: layer_premapped[k] + for k in set(list(layer_premapped.keys())) - EXCLUDED_PREMAPPED_ANNOTATION_KEYS } job_manager.save_to_context({"pre_mapped_layer_exists": True}) @@ -226,7 +193,7 @@ async def map_variants_for_score_set(ctx: dict, job_manager: JobManager) -> JobR total_variants = len(mapped_scores) job_manager.save_to_context({"total_variants_to_process": total_variants}) - job_manager.update_progress(90, 100, "Storing mapped variants in database.") + job_manager.update_progress(90, 100, "Saving mapped variants.") successful_mapped_variants = 0 for mapped_score in mapped_scores: @@ -270,7 +237,7 @@ async def map_variants_for_score_set(ctx: dict, job_manager: JobManager) -> JobR if successful_mapped_variants == 0: score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "All variants failed to map"} + score_set.mapping_errors = {"error_message": "All variants failed to map."} elif successful_mapped_variants < total_variants: score_set.mapping_state = MappingState.incomplete else: @@ -284,9 +251,15 @@ async def map_variants_for_score_set(ctx: dict, job_manager: JobManager) -> JobR "inserted_mapped_variants": len(mapped_scores), } ) + except (NonexistentMappingResultsError, NonexistentMappingScoresError, NonexistentMappingReferenceError) as e: + send_slack_error(e) + logging_context = {**job_manager.logging_context(), **format_raised_exception_info_as_dict(e)} + logger.error(msg="Known error during variant mapping.", extra=logging_context) + + score_set.mapping_state = MappingState.failed + # These exceptions have already set mapping_errors appropriately - job_manager.update_progress(100, 100, "Completed processing of mapped variants.") - logger.info(msg="Inserted mapped variants into db.", extra=job_manager.logging_context()) + raise e # Re-raise to be handled by the job management system except Exception as e: send_slack_error(e) @@ -302,14 +275,13 @@ async def map_variants_for_score_set(ctx: dict, job_manager: JobManager) -> JobR } job_manager.update_progress(100, 100, "Variant mapping failed due to an unexpected error.") - return { - "status": "error", - "data": {}, - "exception_details": {"message": str(e), "type": type(e).__name__, "traceback": None}, - } + # Raise unexpected exceptions to be handled by the job management system + raise e finally: job_manager.db.add(score_set) job_manager.db.commit() + logger.info(msg="Inserted mapped variants into db.", extra=job_manager.logging_context()) + job_manager.update_progress(100, 100, "Finished processing mapped variants.") return {"status": "ok" if successful_mapped_variants > 0 else "error", "data": {}, "exception_details": None} diff --git a/tests/conftest_optional.py b/tests/conftest_optional.py index 028a4e05..acbeec63 100644 --- a/tests/conftest_optional.py +++ b/tests/conftest_optional.py @@ -20,6 +20,7 @@ from mavedb.models.user import User from mavedb.server_main import app from mavedb.worker.jobs import BACKGROUND_CRONJOBS, BACKGROUND_FUNCTIONS +from mavedb.worker.lib.managers.types import JobResultData from tests.helpers.constants import ADMIN_USER, EXTRA_USER, TEST_SEQREPO_INITIAL_STATE, TEST_USER #################################################################################################### @@ -77,6 +78,10 @@ def some_test(client, arq_redis): await redis_.aclose(close_connection_pool=True) +async def dummy_arq_function(ctx, *args, **kwargs) -> JobResultData: + return {"status": "ok", "data": {}, "exception_details": None} + + @pytest_asyncio.fixture() async def arq_worker(data_provider, session, arq_redis): """ @@ -86,7 +91,7 @@ async def arq_worker(data_provider, session, arq_redis): ``` async def worker_test(arq_redis, arq_worker): - await arq_redis.enqueue_job('some_job') + await arq_redis.enqueue_job('dummy_arq_function') await arq_worker.async_run() await arq_worker.run_check() ``` @@ -102,7 +107,7 @@ async def on_job(ctx): ctx["pool"] = futures.ProcessPoolExecutor() worker_ = Worker( - functions=BACKGROUND_FUNCTIONS, + functions=BACKGROUND_FUNCTIONS + [dummy_arq_function], cron_jobs=BACKGROUND_CRONJOBS, redis_pool=arq_redis, burst=True, diff --git a/tests/helpers/constants.py b/tests/helpers/constants.py index 32918235..e46c2c2a 100644 --- a/tests/helpers/constants.py +++ b/tests/helpers/constants.py @@ -1209,52 +1209,35 @@ }, } -TEST_CODING_LAYER = { +TEST_PROTEIN_LAYER = { + "computed_reference_sequence": { + "sequence_type": "protein", + "sequence_id": "ga4gh:SQ.ref_protein_test", + "sequence": "MKTIIALSYIFCLVFADYKDDDDK", + }, "mapped_reference_sequence": { - "sequence_accessions": [VALID_NT_ACCESSION], + "sequence_type": "protein", + "sequence_id": "ga4gh:SQ.map_protein_test", + "sequence_accessions": [VALID_PRO_ACCESSION], }, } -TEST_SEQ_SCORESET_VARIANT_MAPPING_SCAFFOLD = { - "metadata": {}, - "reference_sequences": { - "TEST1": { - "gene_info": TEST_GENE_INFO, - "layers": {"g": TEST_GENOMIC_LAYER, "c": TEST_CODING_LAYER}, - } +TEST_CODING_LAYER = { + "computed_reference_sequence": { + "sequence_type": "coding", + "sequence_id": "ga4gh:SQ.ref_coding_test", + "sequence": "ATGAAGACGATTATTGCTCTTATCTTTCCTCTTTTGCTGATATACGACGACGACAAA", }, - "mapped_scores": [], - "vrs_version": "2.0", - "dcd_mapping_version": "pytest.0.0", - "mapped_date_utc": datetime.isoformat(datetime.now()), -} - -TEST_ACC_SCORESET_VARIANT_MAPPING_SCAFFOLD = { - "metadata": {}, - "reference_sequences": { - "TEST2": { - "gene_info": TEST_GENE_INFO, - "layers": {"g": TEST_GENOMIC_LAYER, "c": TEST_CODING_LAYER}, - } + "mapped_reference_sequence": { + "sequence_type": "coding", + "sequence_id": "ga4gh:SQ.map_coding_test", + "sequence_accessions": [VALID_NT_ACCESSION], }, - "mapped_scores": [], - "vrs_version": "2.0", - "dcd_mapping_version": "pytest.0.0", - "mapped_date_utc": datetime.isoformat(datetime.now()), } -TEST_MULTI_TARGET_SCORESET_VARIANT_MAPPING_SCAFFOLD = { +TEST_MAPPING_SCAFFOLD = { "metadata": {}, - "reference_sequences": { - "TEST3": { - "gene_info": TEST_GENE_INFO, - "layers": {"g": TEST_GENOMIC_LAYER, "c": TEST_CODING_LAYER}, - }, - "TEST4": { - "gene_info": TEST_GENE_INFO, - "layers": {"g": TEST_GENOMIC_LAYER, "c": TEST_CODING_LAYER}, - }, - }, + "reference_sequences": {}, "mapped_scores": [], "vrs_version": "2.0", "dcd_mapping_version": "pytest.0.0", diff --git a/tests/helpers/util/mapping.py b/tests/helpers/util/mapping.py deleted file mode 100644 index 828e7df8..00000000 --- a/tests/helpers/util/mapping.py +++ /dev/null @@ -1,6 +0,0 @@ -from mavedb.worker.jobs.utils.constants import MAPPING_QUEUE_NAME - - -async def sanitize_mapping_queue(standalone_worker_context, score_set): - queued_job = await standalone_worker_context["redis"].rpop(MAPPING_QUEUE_NAME) - assert int(queued_job.decode("utf-8")) == score_set.id diff --git a/tests/helpers/util/setup/worker.py b/tests/helpers/util/setup/worker.py index 50eee000..91aadb81 100644 --- a/tests/helpers/util/setup/worker.py +++ b/tests/helpers/util/setup/worker.py @@ -1,110 +1,52 @@ -import json from asyncio.unix_events import _UnixSelectorEventLoop from copy import deepcopy from unittest.mock import patch -from uuid import uuid4 -import cdot -import jsonschema from sqlalchemy import select -from mavedb.lib.score_sets import csv_data_to_df -from mavedb.models.enums.processing_state import ProcessingState from mavedb.models.score_set import ScoreSet as ScoreSetDbModel from mavedb.models.variant import Variant -from mavedb.view_models.experiment import Experiment, ExperimentCreate -from mavedb.view_models.score_set import ScoreSet, ScoreSetCreate from mavedb.worker.jobs import ( create_variants_for_score_set, map_variants_for_score_set, ) from tests.helpers.constants import ( - TEST_ACC_SCORESET_VARIANT_MAPPING_SCAFFOLD, - TEST_MINIMAL_EXPERIMENT, - TEST_MULTI_TARGET_SCORESET_VARIANT_MAPPING_SCAFFOLD, - TEST_NT_CDOT_TRANSCRIPT, - TEST_SEQ_SCORESET_VARIANT_MAPPING_SCAFFOLD, + TEST_CODING_LAYER, + TEST_GENE_INFO, + TEST_GENOMIC_LAYER, + TEST_MAPPING_SCAFFOLD, + TEST_PROTEIN_LAYER, TEST_VALID_POST_MAPPED_VRS_ALLELE_VRS2_X, TEST_VALID_PRE_MAPPED_VRS_ALLELE_VRS2_X, ) -from tests.helpers.util.mapping import sanitize_mapping_queue - - -async def setup_records_and_files(async_client, data_files, input_score_set): - experiment_payload = deepcopy(TEST_MINIMAL_EXPERIMENT) - jsonschema.validate(instance=experiment_payload, schema=ExperimentCreate.model_json_schema()) - experiment_response = await async_client.post("/api/v1/experiments/", json=experiment_payload) - assert experiment_response.status_code == 200 - experiment = experiment_response.json() - jsonschema.validate(instance=experiment, schema=Experiment.model_json_schema()) - - score_set_payload = deepcopy(input_score_set) - score_set_payload["experimentUrn"] = experiment["urn"] - jsonschema.validate(instance=score_set_payload, schema=ScoreSetCreate.model_json_schema()) - score_set_response = await async_client.post("/api/v1/score-sets/", json=score_set_payload) - assert score_set_response.status_code == 200 - score_set = score_set_response.json() - jsonschema.validate(instance=score_set, schema=ScoreSet.model_json_schema()) - - scores_fp = ( - "scores_multi_target.csv" - if len(score_set["targetGenes"]) > 1 - else ("scores.csv" if "targetSequence" in score_set["targetGenes"][0] else "scores_acc.csv") - ) - counts_fp = ( - "counts_multi_target.csv" - if len(score_set["targetGenes"]) > 1 - else ("counts.csv" if "targetSequence" in score_set["targetGenes"][0] else "counts_acc.csv") - ) - with ( - open(data_files / scores_fp, "rb") as score_file, - open(data_files / counts_fp, "rb") as count_file, - open(data_files / "score_columns_metadata.json", "rb") as score_columns_file, - open(data_files / "count_columns_metadata.json", "rb") as count_columns_file, - ): - scores = csv_data_to_df(score_file) - counts = csv_data_to_df(count_file) - score_columns_metadata = json.load(score_columns_file) - count_columns_metadata = json.load(count_columns_file) - return score_set["urn"], scores, counts, score_columns_metadata, count_columns_metadata - -async def setup_records_files_and_variants(session, async_client, data_files, input_score_set, worker_ctx): - score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( - async_client, data_files, input_score_set - ) - score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - # Patch CDOT `_get_transcript`, in the event this function is called on an accesssion based scoreset. - with patch.object( - cdot.hgvs.dataproviders.RESTDataProvider, - "_get_transcript", - return_value=TEST_NT_CDOT_TRANSCRIPT, +async def create_variants_in_score_set( + session, mock_s3_client, score_df, count_df, mock_worker_ctx, variant_creation_run +): + """Add variants to a given score set in the database.""" + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[score_df, count_df], + ), ): - result = await create_variants_for_score_set( - worker_ctx, uuid4().hex, score_set.id, 1, scores, counts, score_columns_metadata, count_columns_metadata - ) - - score_set_with_variants = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() - - assert result["success"] - assert score_set.processing_state is ProcessingState.success - assert score_set_with_variants.num_variants == 3 + result = await create_variants_for_score_set(mock_worker_ctx, variant_creation_run.id) - return score_set_with_variants + assert result["status"] == "ok" + session.commit() -async def setup_records_files_and_variants_with_mapping( - session, async_client, data_files, input_score_set, standalone_worker_context +async def create_mappings_in_score_set( + session, mock_s3_client, mock_worker_ctx, score_df, count_df, variant_creation_run, variant_mapping_run ): - score_set = await setup_records_files_and_variants( - session, async_client, data_files, input_score_set, standalone_worker_context + score_set = await create_variants_in_score_set( + session, mock_s3_client, score_df, count_df, mock_worker_ctx, variant_creation_run ) - await sanitize_mapping_queue(standalone_worker_context, score_set) async def dummy_mapping_job(): - return await setup_mapping_output(async_client, session, score_set) + return await construct_mock_mapping_output(session, score_set, with_layers={"g", "c", "p"}) with ( patch.object( @@ -114,41 +56,60 @@ async def dummy_mapping_job(): ), patch("mavedb.worker.jobs.variant_processing.mapping.CLIN_GEN_SUBMISSION_ENABLED", False), ): - result = await map_variants_for_score_set(standalone_worker_context, uuid4().hex, score_set.id, 1) - - assert result["success"] - assert not result["retried"] - assert not result["enqueued_jobs"] - return session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).one() - - -async def setup_mapping_output( - async_client, session, score_set, score_set_is_seq_based=True, score_set_is_multi_target=False, empty=False + result = await map_variants_for_score_set(mock_worker_ctx, variant_mapping_run.id) + + assert result["status"] == "ok" + session.commit() + + +async def construct_mock_mapping_output( + session, + score_set, + with_layers, + with_gene_info=True, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, ): - score_set_response = await async_client.get(f"/api/v1/score-sets/{score_set.urn}") - - if score_set_is_seq_based: - if score_set_is_multi_target: - # If this is a multi-target sequence based score set, use the scaffold for that. - mapping_output = deepcopy(TEST_MULTI_TARGET_SCORESET_VARIANT_MAPPING_SCAFFOLD) - else: - mapping_output = deepcopy(TEST_SEQ_SCORESET_VARIANT_MAPPING_SCAFFOLD) - else: - # there is not currently a multi-target accession-based score set test - mapping_output = deepcopy(TEST_ACC_SCORESET_VARIANT_MAPPING_SCAFFOLD) - mapping_output["metadata"] = score_set_response.json() - - if empty: - return mapping_output - - variants = session.scalars(select(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn)).all() - for variant in variants: - mapped_score = { - "pre_mapped": TEST_VALID_PRE_MAPPED_VRS_ALLELE_VRS2_X, - "post_mapped": TEST_VALID_POST_MAPPED_VRS_ALLELE_VRS2_X, - "mavedb_id": variant.urn, - } - - mapping_output["mapped_scores"].append(mapped_score) + """Construct mapping output for a given score set in the database.""" + mapping_output = deepcopy(TEST_MAPPING_SCAFFOLD) + + if with_reference_metadata: + for target in score_set.target_genes: + mapping_output["reference_sequences"][target.name] = { + "gene_info": TEST_GENE_INFO if with_gene_info else {}, + } + + for target in score_set.target_genes: + mapping_output["reference_sequences"][target.name]["layers"] = {} + if "g" in with_layers: + mapping_output["reference_sequences"][target.name]["layers"]["g"] = TEST_GENOMIC_LAYER + if "c" in with_layers: + mapping_output["reference_sequences"][target.name]["layers"]["c"] = TEST_CODING_LAYER + if "p" in with_layers: + mapping_output["reference_sequences"][target.name]["layers"]["p"] = TEST_PROTEIN_LAYER + + if with_mapped_scores: + variants = session.scalars( + select(Variant).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set.urn) + ).all() + + for idx, variant in enumerate(variants): + mapped_score = { + "pre_mapped": TEST_VALID_PRE_MAPPED_VRS_ALLELE_VRS2_X if with_pre_mapped else {}, + "post_mapped": TEST_VALID_POST_MAPPED_VRS_ALLELE_VRS2_X if with_post_mapped else {}, + "mavedb_id": variant.urn, + } + + # Skip every other variant if not with_all_variants + if not with_all_variants and idx % 2 == 0: + mapped_score["post_mapped"] = {} + + mapping_output["mapped_scores"].append(mapped_score) + + if not mapping_output["mapped_scores"]: + mapping_output["error_message"] = "test error: no mapped scores" return mapping_output diff --git a/tests/worker/conftest.py b/tests/worker/conftest.py index eef66d03..4f1f32e3 100644 --- a/tests/worker/conftest.py +++ b/tests/worker/conftest.py @@ -7,22 +7,21 @@ from shutil import copytree from unittest.mock import Mock +import pandas as pd import pytest from mavedb.models.enums.job_pipeline import DependencyType, JobStatus, PipelineStatus +from mavedb.models.experiment import Experiment +from mavedb.models.experiment_set import ExperimentSet from mavedb.models.job_dependency import JobDependency from mavedb.models.job_run import JobRun from mavedb.models.license import License from mavedb.models.pipeline import Pipeline -from mavedb.models.taxonomy import Taxonomy +from mavedb.models.score_set import ScoreSet +from mavedb.models.target_gene import TargetGene +from mavedb.models.target_sequence import TargetSequence from mavedb.models.user import User -from tests.helpers.constants import ( - EXTRA_USER, - TEST_INACTIVE_LICENSE, - TEST_LICENSE, - TEST_SAVED_TAXONOMY, - TEST_USER, -) +from tests.helpers.constants import EXTRA_USER, TEST_LICENSE, TEST_USER # Attempt to import optional top level fixtures. If the modules they depend on are not installed, # we won't have access to our full fixture suite and only a limited subset of tests can be run. @@ -34,7 +33,7 @@ @pytest.fixture -def sample_job_run(): +def sample_job_run(sample_pipeline): """Create a sample JobRun instance for testing.""" return JobRun( id=1, @@ -42,7 +41,7 @@ def sample_job_run(): job_type="test_job", job_function="test_function", status=JobStatus.PENDING, - pipeline_id=1, + pipeline_id=sample_pipeline.id, progress_current=0, progress_total=100, progress_message="Ready to start", @@ -51,7 +50,7 @@ def sample_job_run(): @pytest.fixture -def sample_dependent_job_run(): +def sample_dependent_job_run(sample_pipeline): """Create a sample dependent JobRun instance for testing.""" return JobRun( id=2, @@ -59,7 +58,7 @@ def sample_dependent_job_run(): job_type="dependent_job", job_function="dependent_function", status=JobStatus.PENDING, - pipeline_id=1, + pipeline_id=sample_pipeline.id, progress_current=0, progress_total=100, progress_message="Waiting for dependency", @@ -113,24 +112,96 @@ def sample_empty_pipeline(): @pytest.fixture -def sample_job_dependency(): +def sample_job_dependency(sample_dependent_job_run, sample_job_run): """Create a sample JobDependency instance for testing.""" return JobDependency( - id=2, # dependent job - depends_on_job_id=1, # depends on job 1 + id=sample_dependent_job_run.id, # dependent job + depends_on_job_id=sample_job_run.id, # depends on job 1 dependency_type=DependencyType.SUCCESS_REQUIRED, created_at=datetime.now(), ) @pytest.fixture -def with_populated_domain_data(session): +def sample_user(): + """Create a sample User instance for testing.""" + return User(**TEST_USER) + + +@pytest.fixture +def sample_extra_user(): + """Create an extra sample User instance for testing.""" + return User(**EXTRA_USER) + + +@pytest.fixture +def sample_license(): + """Create a sample License instance for testing.""" + return License(**TEST_LICENSE) + + +@pytest.fixture +def sample_experiment_set(sample_user): + """Create a sample ExperimentSet instance for testing.""" + return ExperimentSet( + extra_metadata={}, + created_by=sample_user, + ) + + +@pytest.fixture +def sample_experiment(sample_experiment_set, sample_user): + """Create a sample Experiment instance for testing.""" + return Experiment( + title="Sample Experiment", + short_description="A sample experiment for testing purposes", + abstract_text="This is an abstract for the sample experiment.", + method_text="This is a method description for the sample experiment.", + extra_metadata={}, + experiment_set=sample_experiment_set, + created_by=sample_user, + ) + + +@pytest.fixture +def sample_score_set(sample_experiment, sample_user, sample_license): + """Create a sample ScoreSet instance for testing.""" + return ScoreSet( + title="Sample Score Set", + short_description="A sample score set for testing purposes", + abstract_text="This is an abstract for the sample score set.", + method_text="This is a method description for the sample score set.", + extra_metadata={}, + experiment=sample_experiment, + created_by=sample_user, + license=sample_license, + target_genes=[ + TargetGene( + name="Sample Gene", + category="protein_coding", + target_sequence=TargetSequence(label="testsequence", sequence_type="dna", sequence="ATGCAT"), + ) + ], + ) + + +@pytest.fixture +def with_populated_domain_data( + session, + sample_user, + sample_extra_user, + sample_experiment_set, + sample_experiment, + sample_score_set, + sample_license, +): db = session - db.add(User(**TEST_USER)) - db.add(User(**EXTRA_USER)) - db.add(Taxonomy(**TEST_SAVED_TAXONOMY)) - db.add(License(**TEST_LICENSE)) - db.add(License(**TEST_INACTIVE_LICENSE)) + db.add(sample_user) + db.add(sample_extra_user) + db.add(sample_experiment_set) + db.add(sample_experiment) + db.add(sample_score_set) + db.add(sample_license) db.commit() @@ -218,65 +289,10 @@ def data_files(tmp_path): @pytest.fixture -def mock_pipeline(): - """Create a mock Pipeline instance. By default, - properties are identical to a default new Pipeline entered into the db - with sensible defaults for non-nullable but unset fields. - """ - return Mock( - spec=Pipeline, - id=1, - urn="test:pipeline:1", - name="Test Pipeline", - description="A test pipeline", - status=PipelineStatus.CREATED, - correlation_id="test_correlation_123", - metadata_={}, - created_at=datetime.now(), - started_at=None, - finished_at=None, - created_by_user_id=None, - mavedb_version=None, - ) - - -@pytest.fixture -def mock_job_run(mock_pipeline): - """Create a mock JobRun instance. By default, - properties are identical to a default new JobRun entered into the db - with sensible defaults for non-nullable but unset fields. - """ - return Mock( - spec=JobRun, - id=123, - urn="test:job:123", - job_type="test_job", - job_function="test_function", - status=JobStatus.PENDING, - pipeline_id=mock_pipeline.id, - priority=0, - max_retries=3, - retry_count=0, - retry_delay_seconds=None, - scheduled_at=datetime.now(), - started_at=None, - finished_at=None, - created_at=datetime.now(), - error_message=None, - error_traceback=None, - failure_category=None, - worker_id=None, - worker_host=None, - progress_current=None, - progress_total=None, - progress_message=None, - correlation_id=None, - metadata_={}, - mavedb_version=None, - ) +def sample_score_dataframe(data_files): + return pd.read_csv(data_files / "scores.csv") @pytest.fixture -def data_files(tmp_path): - copytree(Path(__file__).absolute().parent / "data", tmp_path / "data") - return tmp_path / "data" +def sample_count_dataframe(data_files): + return pd.read_csv(data_files / "counts.csv") diff --git a/tests/worker/conftest_optional.py b/tests/worker/conftest_optional.py index a3a00f54..9848fe51 100644 --- a/tests/worker/conftest_optional.py +++ b/tests/worker/conftest_optional.py @@ -1,3 +1,4 @@ +from concurrent.futures import ProcessPoolExecutor from unittest.mock import Mock, patch import pytest @@ -50,6 +51,7 @@ def mock_worker_ctx(session): """Create a mock worker context dictionary for testing.""" mock_redis = Mock(spec=ArqRedis) mock_hdp = Mock(spec=RESTDataProvider) + mock_pool = Mock(spec=ProcessPoolExecutor) # Don't mock the session itself to allow real DB interactions in tests # It's generally more pain than it's worth to mock out SQLAlchemy sessions, @@ -58,4 +60,5 @@ def mock_worker_ctx(session): "db": session, "redis": mock_redis, "hdp": mock_hdp, + "pool": mock_pool, } diff --git a/tests/worker/data/counts.csv b/tests/worker/data/counts.csv index 0cc1e742..4821232a 100644 --- a/tests/worker/data/counts.csv +++ b/tests/worker/data/counts.csv @@ -1,4 +1,5 @@ -hgvs_nt,hgvs_pro,c_0,c_1 -c.1A>T,p.Thr1Ser,10,20 -c.2C>T,p.Thr1Met,8,8 -c.6T>A,p.Phe2Leu,90,2 +hgvs_nt,hgvs_splice,hgvs_pro,c_0,c_1 +c.1A>T,NA,p.Met1Leu,10,20 +c.2T>A,NA,p.Met1Lys,8,8 +c.3G>C,NA,p.Met1Ile,90,2 +c.4C>G,NA,p.His2Asp,12,1 diff --git a/tests/worker/data/scores.csv b/tests/worker/data/scores.csv index 11fce498..bd8e3bae 100644 --- a/tests/worker/data/scores.csv +++ b/tests/worker/data/scores.csv @@ -1,4 +1,5 @@ -hgvs_nt,hgvs_pro,score,s_0,s_1 -c.1A>T,p.Thr1Ser,0.3,val1,val1 -c.2C>T,p.Thr1Met,0.0,val2,val2 -c.6T>A,p.Phe2Leu,-1.65,val3,val3 +hgvs_nt,hgvs_splice,hgvs_pro,score,s_0,s_1 +c.1A>T,NA,p.Met1Leu,0.3,val1,val1 +c.2T>A,NA,p.Met1Lys,0,val2,val2 +c.3G>C,NA,p.Met1Ile,-1.65,val3,val3 +c.4C>G,NA,p.His2Asp,NA,val5,val4 diff --git a/tests/worker/jobs/variant_processing/conftest.py b/tests/worker/jobs/variant_processing/conftest.py new file mode 100644 index 00000000..1b88df2d --- /dev/null +++ b/tests/worker/jobs/variant_processing/conftest.py @@ -0,0 +1,191 @@ +from unittest import mock + +import pytest +from mypy_boto3_s3 import S3Client + +from mavedb.models.job_run import JobRun +from mavedb.models.pipeline import Pipeline + + +@pytest.fixture +def create_variants_sample_params(with_populated_domain_data, sample_score_set, sample_user): + """Provide sample parameters for create_variants_for_score_set job.""" + + return { + "scores_file_key": "sample_scores.csv", + "counts_file_key": "sample_counts.csv", + "correlation_id": "sample-correlation-id", + "updater_id": sample_user.id, + "score_set_id": sample_score_set.id, + "score_columns_metadata": {"s_0": {"description": "metadataS", "details": "detailsS"}}, + "count_columns_metadata": {"c_0": {"description": "metadataC", "details": "detailsC"}}, + } + + +@pytest.fixture +def map_variants_sample_params(with_populated_domain_data, sample_score_set, sample_user): + """Provide sample parameters for map_variants_for_score_set job.""" + + return { + "score_set_id": sample_score_set.id, + "correlation_id": "sample-mapping-correlation-id", + "updater_id": sample_user.id, + } + + +@pytest.fixture +def mock_s3_client(): + """Mock S3 client for tests that interact with S3.""" + + with mock.patch("mavedb.worker.jobs.variant_processing.creation.s3_client") as mock_s3_client_func: + mock_s3 = mock.MagicMock(spec=S3Client) + mock_s3_client_func.return_value = mock_s3 + yield mock_s3 + + +@pytest.fixture +def sample_independent_variant_creation_run(create_variants_sample_params): + """Create a JobRun instance for variant creation job.""" + + return JobRun( + urn="test:create_variants_for_score_set", + job_type="create_variants_for_score_set", + job_function="create_variants_for_score_set", + max_retries=3, + retry_count=0, + job_params=create_variants_sample_params, + ) + + +@pytest.fixture +def sample_independent_variant_mapping_run(map_variants_sample_params): + """Create a JobRun instance for variant mapping job.""" + + return JobRun( + urn="test:map_variants_for_score_set", + job_type="map_variants_for_score_set", + job_function="map_variants_for_score_set", + max_retries=3, + retry_count=0, + job_params=map_variants_sample_params, + ) + + +@pytest.fixture +def dummy_pipeline_step(): + """Create a dummy pipeline step function for testing.""" + + return JobRun( + urn="test:dummy_pipeline_step", + job_type="dummy_pipeline_step", + job_function="dummy_arq_function", + max_retries=3, + retry_count=0, + ) + + +@pytest.fixture +def sample_pipeline_variant_creation_run( + session, + with_variant_creation_pipeline, + sample_variant_creation_pipeline, + sample_independent_variant_creation_run, +): + """Create a JobRun instance for variant creation job.""" + + sample_independent_variant_creation_run.pipeline_id = sample_variant_creation_pipeline.id + session.add(sample_independent_variant_creation_run) + session.commit() + return sample_independent_variant_creation_run + + +@pytest.fixture +def sample_pipeline_variant_mapping_run( + session, + with_variant_mapping_pipeline, + sample_independent_variant_mapping_run, + sample_variant_mapping_pipeline, +): + """Create a JobRun instance for variant mapping job.""" + + sample_independent_variant_mapping_run.pipeline_id = sample_variant_mapping_pipeline.id + session.add(sample_independent_variant_mapping_run) + session.commit() + return sample_independent_variant_mapping_run + + +@pytest.fixture +def sample_variant_creation_pipeline(): + """Create a Pipeline instance.""" + + return Pipeline( + name="variant_creation_pipeline", + description="Pipeline for creating variants", + ) + + +@pytest.fixture +def sample_variant_mapping_pipeline(): + """Create a Pipeline instance.""" + + return Pipeline( + name="variant_mapping_pipeline", + description="Pipeline for mapping variants", + ) + + +@pytest.fixture +def with_independent_processing_runs( + session, + sample_independent_variant_creation_run, + sample_independent_variant_mapping_run, +): + """Fixture to ensure independent variant processing runs exist in the database.""" + + session.add(sample_independent_variant_creation_run) + session.add(sample_independent_variant_mapping_run) + session.commit() + + +@pytest.fixture +def with_variant_creation_pipeline(session, sample_variant_creation_pipeline): + """Fixture to ensure variant creation pipeline and its runs exist in the database.""" + session.add(sample_variant_creation_pipeline) + session.commit() + + +@pytest.fixture +def with_variant_creation_pipeline_runs( + session, + with_variant_creation_pipeline, + sample_variant_creation_pipeline, + sample_pipeline_variant_creation_run, + dummy_pipeline_step, +): + """Fixture to ensure pipeline variant processing runs exist in the database.""" + session.add(sample_pipeline_variant_creation_run) + dummy_pipeline_step.pipeline_id = sample_variant_creation_pipeline.id + session.add(dummy_pipeline_step) + session.commit() + + +@pytest.fixture +def with_variant_mapping_pipeline(session, sample_variant_mapping_pipeline): + """Fixture to ensure variant mapping pipeline and its runs exist in the database.""" + session.add(sample_variant_mapping_pipeline) + session.commit() + + +@pytest.fixture +def with_variant_mapping_pipeline_runs( + session, + with_variant_mapping_pipeline, + sample_variant_mapping_pipeline, + sample_pipeline_variant_mapping_run, + dummy_pipeline_step, +): + """Fixture to ensure pipeline variant processing runs exist in the database.""" + session.add(sample_pipeline_variant_mapping_run) + dummy_pipeline_step.pipeline_id = sample_variant_mapping_pipeline.id + session.add(dummy_pipeline_step) + session.commit() diff --git a/tests/worker/jobs/variant_processing/test_creation.py b/tests/worker/jobs/variant_processing/test_creation.py index e69de29b..a034ebeb 100644 --- a/tests/worker/jobs/variant_processing/test_creation.py +++ b/tests/worker/jobs/variant_processing/test_creation.py @@ -0,0 +1,1404 @@ +import math +from unittest.mock import ANY, MagicMock, call, patch + +import pytest + +from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus +from mavedb.models.enums.mapping_state import MappingState +from mavedb.models.enums.processing_state import ProcessingState +from mavedb.models.job_run import JobRun +from mavedb.models.pipeline import Pipeline +from mavedb.models.variant import Variant +from mavedb.worker.jobs.variant_processing.creation import create_variants_for_score_set +from mavedb.worker.lib.managers.job_manager import JobManager + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestCreateVariantsForScoreSetUnit: + """Unit tests for create_variants_for_score_set job.""" + + async def test_create_variants_for_score_set_raises_key_error_on_missing_hdp_from_ctx( + self, + mock_job_manager, + ): + ctx = {} # Missing 'hdp' key + + with pytest.raises(KeyError) as exc_info: + await create_variants_for_score_set(ctx=ctx, job_id=999, job_manager=mock_job_manager) + + assert str(exc_info.value) == "'hdp'" + + async def test_create_variants_for_score_set_calls_s3_client_with_correct_parameters( + self, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None) as mock_download_fileobj, + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + return_value=( + sample_score_dataframe, + sample_count_dataframe, + create_variants_sample_params["score_columns_metadata"], + create_variants_sample_params["count_columns_metadata"], + ), + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.create_variants_data", + return_value=[MagicMock(spec=Variant)], + ), + patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), + ): + await create_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_creation_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id + ), + ) + + # Use ANY for dynamically created Fileobj parameters. + mock_download_fileobj.assert_has_calls( + [ + call(Bucket="score-set-csv-uploads-dev", Key="sample_scores.csv", Fileobj=ANY), + call(Bucket="score-set-csv-uploads-dev", Key="sample_counts.csv", Fileobj=ANY), + ] + ) + + async def test_create_variants_for_score_set_s3_file_not_found( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + sample_score_set, + sample_independent_variant_creation_run, + ): + with ( + patch.object( + mock_s3_client, + "download_fileobj", + side_effect=Exception("The specified key does not exist."), + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + pytest.raises(Exception) as exc_info, + ): + await create_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_creation_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id + ), + ) + + mock_update_progress.assert_any_call(100, 100, "Variant creation job failed due to an internal error.") + assert str(exc_info.value) == "The specified key does not exist." + session.refresh(sample_score_set) + assert sample_score_set.processing_state == ProcessingState.failed + assert sample_score_set.mapping_state == MappingState.not_attempted + + async def test_create_variants_for_score_set_counts_file_can_be_optional( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + # Remove counts_file_key to test optional behavior + create_variants_sample_params_without_counts = create_variants_sample_params.copy() + create_variants_sample_params_without_counts["counts_file_key"] = None + create_variants_sample_params_without_counts["count_columns_metadata"] = None + sample_independent_variant_creation_run.job_params = create_variants_sample_params_without_counts + session.add(sample_independent_variant_creation_run) + session.commit() + + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample score dataframe only + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + return_value=( + sample_score_dataframe, + None, + create_variants_sample_params_without_counts["score_columns_metadata"], + None, + ), + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.create_variants_data", + return_value=[MagicMock(spec=Variant)], + ), + patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), + ): + await create_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_creation_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id + ), + ) + + async def test_create_variants_for_score_set_raises_when_no_targets_exist( + self, + session, + with_independent_processing_runs, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + # Remove all TargetGene entries to simulate no targets existing + sample_score_set.target_genes = [] + session.commit() + + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + pytest.raises(ValueError) as exc_info, + ): + await create_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_creation_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id + ), + ) + + mock_update_progress.assert_any_call(100, 100, "Score set has no targets; cannot create variants.") + assert str(exc_info.value) == "Can't create variants when score set has no targets." + + async def test_create_variants_for_score_set_calls_validate_standardize_dataframe_with_correct_parameters( + self, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + return_value=( + sample_score_dataframe, + sample_count_dataframe, + create_variants_sample_params["score_columns_metadata"], + create_variants_sample_params["count_columns_metadata"], + ), + ) as mock_validate, + patch( + "mavedb.worker.jobs.variant_processing.creation.create_variants_data", + return_value=[MagicMock(spec=Variant)], + ), + patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), + ): + await create_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_creation_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id + ), + ) + + mock_validate.assert_called_once_with( + scores_df=sample_score_dataframe, + counts_df=sample_count_dataframe, + score_columns_metadata=create_variants_sample_params["score_columns_metadata"], + count_columns_metadata=create_variants_sample_params["count_columns_metadata"], + targets=sample_score_set.target_genes, + hdp=mock_worker_ctx["hdp"], + ) + + async def test_create_variants_for_score_set_calls_create_variants_data_with_correct_parameters( + self, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + return_value=( + sample_score_dataframe, + sample_count_dataframe, + create_variants_sample_params["score_columns_metadata"], + create_variants_sample_params["count_columns_metadata"], + ), + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.create_variants_data", + return_value=[MagicMock(spec=Variant)], + ) as mock_create_variants_data, + patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), + ): + await create_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_creation_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id + ), + ) + + mock_create_variants_data.assert_called_once_with(sample_score_dataframe, sample_count_dataframe, None) + + async def test_create_variants_for_score_set_calls_create_variants_with_correct_parameters( + self, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + mock_variant = MagicMock(spec=Variant) + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + return_value=( + sample_score_dataframe, + sample_count_dataframe, + create_variants_sample_params["score_columns_metadata"], + create_variants_sample_params["count_columns_metadata"], + ), + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.create_variants_data", + return_value=[mock_variant], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.create_variants", + return_value=None, + ) as mock_create_variants, + ): + await create_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_creation_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id + ), + ) + + mock_create_variants.assert_called_once_with(mock_worker_ctx["db"], sample_score_set, [mock_variant]) + + async def test_create_variants_for_score_set_handles_empty_variant_data( + self, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + return_value=( + sample_score_dataframe, + sample_count_dataframe, + create_variants_sample_params["score_columns_metadata"], + create_variants_sample_params["count_columns_metadata"], + ), + ), + patch("mavedb.worker.jobs.variant_processing.creation.create_variants_data", return_value=[]), + patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), + ): + await create_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_creation_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id + ), + ) + # If no exceptions are raised, the test passes for handling empty variant data. + + async def test_create_variants_for_score_set_removes_existing_variants_before_creation( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + # Add existing variants to the score set to test removal + sample_score_set.num_variants = 1 + variant = Variant(data={}, score_set_id=sample_score_set.id) + session.add(variant) + session.commit() + + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + return_value=( + sample_score_dataframe, + sample_count_dataframe, + create_variants_sample_params["score_columns_metadata"], + create_variants_sample_params["count_columns_metadata"], + ), + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.create_variants_data", + return_value=[MagicMock(spec=Variant)], + ), + patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), + ): + await create_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_creation_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id + ), + ) + + # Verify that existing variants have been removed + remaining_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(remaining_variants) == 0 + session.refresh(sample_score_set) + assert sample_score_set.num_variants == 0 # Updated after creation + + async def test_create_variants_for_score_set_updates_processing_state( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + return_value=( + sample_score_dataframe, + sample_count_dataframe, + create_variants_sample_params["score_columns_metadata"], + create_variants_sample_params["count_columns_metadata"], + ), + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.create_variants_data", + return_value=[MagicMock(spec=Variant)], + ), + patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), + ): + await create_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_creation_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id + ), + ) + + session.refresh(sample_score_set) + assert sample_score_set.processing_state == ProcessingState.success + assert sample_score_set.mapping_state == MappingState.queued + assert sample_score_set.processing_errors is None + + async def test_create_variants_for_score_set_updates_progress( + self, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + return_value=( + sample_score_dataframe, + sample_count_dataframe, + create_variants_sample_params["score_columns_metadata"], + create_variants_sample_params["count_columns_metadata"], + ), + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.create_variants_data", + return_value=[MagicMock(spec=Variant)], + ), + patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + await create_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_creation_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id + ), + ) + + mock_update_progress.assert_has_calls( + [ + call(0, 100, "Starting variant creation job."), + call(10, 100, "Validated score set metadata and beginning data validation."), + call(80, 100, "Data validation complete; creating variants in database."), + call(100, 100, "Completed variant creation job."), + ] + ) + + async def test_create_variants_for_score_set_retains_existing_variants_when_exception_occurs( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + # Add existing variants to the score set to test retention on failure + sample_score_set.num_variants = 1 + variant = Variant(data={}, score_set_id=sample_score_set.id) + session.add(variant) + session.commit() + + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + side_effect=Exception("Test exception during data validation"), + ), + pytest.raises(Exception) as exc_info, + ): + await create_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_creation_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id + ), + ) + + assert str(exc_info.value) == "Test exception during data validation" + + # Verify that existing variants are still present + remaining_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(remaining_variants) == 1 + session.refresh(sample_score_set) + assert sample_score_set.num_variants == 1 # Should remain unchanged + + async def test_create_variants_for_score_set_handles_exception_and_updates_state( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + side_effect=Exception("Test exception during data validation"), + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + pytest.raises(Exception) as exc_info, + ): + await create_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_creation_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id + ), + ) + + assert str(exc_info.value) == "Test exception during data validation" + + # Verify that the score set's processing state is updated to failed + session.refresh(sample_score_set) + assert sample_score_set.processing_state == ProcessingState.failed + assert sample_score_set.mapping_state == MappingState.not_attempted + assert "Test exception during data validation" in sample_score_set.processing_errors["exception"] + mock_update_progress.assert_any_call(100, 100, "Variant creation job failed due to an internal error.") + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestCreateVariantsForScoreSetIntegration: + """Integration tests for create_variants_for_score_set job.""" + + ## Common success workflows + + async def test_create_variants_for_score_set_independent_job( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + with ( + # Assume the S3 client works as expected. + # + # Moto is omitted here for brevity since this + # function doesn't have S3 side effects. We assume the file is already in S3 for this test, + # and any cases where the file is not present will be handled by the job manager and tested + # in unit tests. + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes. + # + # A side effect of not mocking S3 more thoroughly + # is that our S3 download has no return value and just side effects data into a file-like object, + # so we mock pd.read_csv directly to avoid it trying to read from an empty file. + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + ): + await create_variants_for_score_set(mock_worker_ctx, sample_independent_variant_creation_run.id) + + # Verify that variants have been created in the database + created_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(created_variants) == sample_score_dataframe.shape[0] + session.refresh(sample_score_set) + assert sample_score_set.num_variants == len(created_variants) + assert sample_score_set.processing_state == ProcessingState.success + assert sample_score_set.mapping_state == MappingState.queued + + # Verify that the created variants have expected data + for variant in created_variants: + assert variant.data # Ensure data is not empty + assert "score_data" in variant.data # Ensure score_data is present + expected_score = sample_score_dataframe.loc[ + sample_score_dataframe["hgvs_nt"] == variant.hgvs_nt, "score" + ].values[0] + actual_score = variant.data["score_data"]["score"] + if actual_score is None and (isinstance(expected_score, float) and math.isnan(expected_score)): + pass # None in variant, NaN in DataFrame: OK + else: + assert actual_score == expected_score # Ensure score matches + assert "count_data" in variant.data # Ensure count_data is present + expected_count = sample_count_dataframe.loc[ + sample_count_dataframe["hgvs_nt"] == variant.hgvs_nt, "c_0" + ].values[0] + actual_count = variant.data["count_data"]["c_0"] + if actual_count is None and (isinstance(expected_count, float) and math.isnan(expected_count)): + pass # None in variant, NaN in DataFrame: OK + else: + assert actual_count == expected_count # Ensure count matches + + # Verify that no extra variants were created + all_variants = session.query(Variant).all() + assert len(all_variants) == len(created_variants) + + # Verify that job state is as expected + job_run = ( + session.query(sample_independent_variant_creation_run.__class__) + .filter(sample_independent_variant_creation_run.__class__.id == sample_independent_variant_creation_run.id) + .one() + ) + assert job_run.progress_current == 100 + assert job_run.status == JobStatus.SUCCEEDED + + async def test_create_variants_for_score_set_pipeline_job( + self, + session, + with_variant_creation_pipeline_runs, + sample_variant_creation_pipeline, + sample_pipeline_variant_creation_run, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes. + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + ): + await create_variants_for_score_set(mock_worker_ctx, sample_pipeline_variant_creation_run.id) + + # Verify that variants have been created in the database + created_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(created_variants) == sample_score_dataframe.shape[0] + session.refresh(sample_score_set) + assert sample_score_set.num_variants == len(created_variants) + assert sample_score_set.processing_state == ProcessingState.success + assert sample_score_set.mapping_state == MappingState.queued + + # Verify that the created variants have expected data + for variant in created_variants: + assert variant.data # Ensure data is not empty + assert "score_data" in variant.data # Ensure score_data is present + expected_score = sample_score_dataframe.loc[ + sample_score_dataframe["hgvs_nt"] == variant.hgvs_nt, "score" + ].values[0] + actual_score = variant.data["score_data"]["score"] + if actual_score is None and (isinstance(expected_score, float) and math.isnan(expected_score)): + pass # None in variant, NaN in DataFrame: OK + else: + assert actual_score == expected_score # Ensure score matches + assert "count_data" in variant.data # Ensure count_data is present + expected_count = sample_count_dataframe.loc[ + sample_count_dataframe["hgvs_nt"] == variant.hgvs_nt, "c_0" + ].values[0] + actual_count = variant.data["count_data"]["c_0"] + if actual_count is None and (isinstance(expected_count, float) and math.isnan(expected_count)): + pass # None in variant, NaN in DataFrame: OK + else: + assert actual_count == expected_count # Ensure count matches + + # Verify that no extra variants were created + all_variants = session.query(Variant).all() + assert len(all_variants) == len(created_variants) + + # Verify that pipeline job state is as expected + job_run = ( + session.query(sample_pipeline_variant_creation_run.__class__) + .filter(sample_pipeline_variant_creation_run.__class__.id == sample_pipeline_variant_creation_run.id) + .one() + ) + assert job_run.progress_current == 100 + assert job_run.status == JobStatus.SUCCEEDED + + # Verify that pipeline status is updated. Pipeline will remain RUNNING + # as our default test pipeline includes the mapping job as well. + session.refresh(sample_variant_creation_pipeline) + assert sample_variant_creation_pipeline.status == PipelineStatus.RUNNING + + ## Common edge cases + + async def test_create_variants_for_score_set_replaces_variants( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + # First run to create initial variants + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + ): + await create_variants_for_score_set(mock_worker_ctx, sample_independent_variant_creation_run.id) + + initial_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(initial_variants) == sample_score_dataframe.shape[0] + + # Modify dataframes to simulate updated data + updated_score_dataframe = sample_score_dataframe.copy() + updated_score_dataframe["score"] += 10 # Increment scores by 10 + + updated_count_dataframe = sample_count_dataframe.copy() + updated_count_dataframe["c_0"] += 5 # Increment counts by 5 + + # Mock a second run with updated dataframes + sample_independent_variant_creation_run.status = JobStatus.PENDING + session.commit() + + # Second run to replace existing variants + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[updated_score_dataframe, updated_count_dataframe], + ), + ): + await create_variants_for_score_set(mock_worker_ctx, sample_independent_variant_creation_run.id) + + replaced_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(replaced_variants) == sample_score_dataframe.shape[0] + + # Verify that the variants have been replaced with updated data + for variant in replaced_variants: + assert variant.data # Ensure data is not empty + assert "score_data" in variant.data # Ensure score_data is present + expected_score = updated_score_dataframe.loc[ + updated_score_dataframe["hgvs_nt"] == variant.hgvs_nt, "score" + ].values[0] + actual_score = variant.data["score_data"]["score"] + if actual_score is None and (isinstance(expected_score, float) and math.isnan(expected_score)): + pass # None in variant, NaN in DataFrame: OK + else: + assert actual_score == expected_score # Ensure score matches + assert "count_data" in variant.data # Ensure count_data is present + expected_count = updated_count_dataframe.loc[ + updated_count_dataframe["hgvs_nt"] == variant.hgvs_nt, "c_0" + ].values[0] + actual_count = variant.data["count_data"]["c_0"] + if actual_count is None and (isinstance(expected_count, float) and math.isnan(expected_count)): + pass # None in variant, NaN in DataFrame: OK + else: + assert actual_count == expected_count # Ensure count matches + + # Verify that no extra variants were created + all_variants = session.query(Variant).all() + assert len(all_variants) == len(replaced_variants) + + # Verify that job state is as expected + job_run = ( + session.query(sample_independent_variant_creation_run.__class__) + .filter(sample_independent_variant_creation_run.__class__.id == sample_independent_variant_creation_run.id) + .one() + ) + assert job_run.progress_current == 100 + assert job_run.status == JobStatus.SUCCEEDED + + async def test_create_variants_for_score_set_handles_missing_counts_file( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + sample_score_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + sample_independent_variant_creation_run.job_params["counts_file_key"] = None + sample_independent_variant_creation_run.job_params["count_columns_metadata"] = {} + session.commit() + + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return only the score dataframe + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe], + ), + ): + await create_variants_for_score_set(mock_worker_ctx, sample_independent_variant_creation_run.id) + + # Verify that variants have been created in the database + created_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(created_variants) == sample_score_dataframe.shape[0] + session.refresh(sample_score_set) + assert sample_score_set.num_variants == len(created_variants) + assert sample_score_set.processing_state == ProcessingState.success + assert sample_score_set.mapping_state == MappingState.queued + + # Verify that the created variants have expected data + for variant in created_variants: + assert variant.data # Ensure data is not empty + assert "score_data" in variant.data # Ensure score_data is present + expected_score = sample_score_dataframe.loc[ + sample_score_dataframe["hgvs_nt"] == variant.hgvs_nt, "score" + ].values[0] + actual_score = variant.data["score_data"]["score"] + if actual_score is None and (isinstance(expected_score, float) and math.isnan(expected_score)): + pass # None in variant, NaN in DataFrame: OK + else: + assert actual_score == expected_score # Ensure score matches + assert "count_data" in variant.data # Ensure count_data is present but... + assert variant.data["count_data"] == {} # ...ensure count_data is empty since no counts file was provided + + # Verify that no extra variants were created + all_variants = session.query(Variant).all() + assert len(all_variants) == len(created_variants) + + # Verify that job state is as expected + job_run = ( + session.query(sample_independent_variant_creation_run.__class__) + .filter(sample_independent_variant_creation_run.__class__.id == sample_independent_variant_creation_run.id) + .one() + ) + assert job_run.progress_current == 100 + assert job_run.status == JobStatus.SUCCEEDED + + ## Common failure workflows + + async def test_create_variants_for_score_set_validation_error_during_creation( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + sample_score_dataframe.loc[0, "hgvs_nt"] = "c.G>X" # Introduce invalid value to trigger validation error + + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + ): + await create_variants_for_score_set(mock_worker_ctx, sample_independent_variant_creation_run.id) + + # Verify that the score set's processing state is updated to failed + session.refresh(sample_score_set) + assert sample_score_set.processing_state == ProcessingState.failed + assert sample_score_set.mapping_state == MappingState.not_attempted + assert "encountered 1 invalid variant strings" in sample_score_set.processing_errors["exception"] + assert len(sample_score_set.processing_errors["detail"]) > 0 + + # Verify that no variants were created + created_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(created_variants) == 0 + + # Verify that job state is as expected + job_run = ( + session.query(sample_independent_variant_creation_run.__class__) + .filter(sample_independent_variant_creation_run.__class__.id == sample_independent_variant_creation_run.id) + .one() + ) + assert job_run.progress_current == 100 + assert job_run.status == JobStatus.FAILED + + async def test_create_variants_for_score_set_generic_exception_handling_during_creation( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + side_effect=Exception("Generic exception during data validation"), + ), + ): + await create_variants_for_score_set(mock_worker_ctx, sample_independent_variant_creation_run.id) + + # Verify that the score set's processing state is updated to failed + session.refresh(sample_score_set) + assert sample_score_set.processing_state == ProcessingState.failed + assert sample_score_set.mapping_state == MappingState.not_attempted + assert "Generic exception during data validation" in sample_score_set.processing_errors["exception"] + + # Verify that job state is as expected + job_run = ( + session.query(sample_independent_variant_creation_run.__class__) + .filter(sample_independent_variant_creation_run.__class__.id == sample_independent_variant_creation_run.id) + .one() + ) + assert job_run.progress_current == 100 + assert job_run.status == JobStatus.FAILED + + async def test_create_variants_for_score_set_generic_exception_handling_during_replacement( + self, + session, + with_independent_processing_runs, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + # First run to create initial variants + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + ): + await create_variants_for_score_set(mock_worker_ctx, sample_independent_variant_creation_run.id) + + initial_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(initial_variants) == sample_score_dataframe.shape[0] + + # Mock a second run to replace existing variants + sample_independent_variant_creation_run.status = JobStatus.PENDING + session.commit() + + # Second run to replace existing variants but trigger a generic exception + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + side_effect=Exception("Generic exception during data validation"), + ), + ): + await create_variants_for_score_set(mock_worker_ctx, sample_independent_variant_creation_run.id) + + # Verify that the score set's processing state is updated to failed + session.refresh(sample_score_set) + assert sample_score_set.processing_state == ProcessingState.failed + assert sample_score_set.mapping_state == MappingState.not_attempted + assert "Generic exception during data validation" in sample_score_set.processing_errors["exception"] + + # Verify that initial variants are still present + remaining_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(remaining_variants) == len(initial_variants) + + # Verify that job state is as expected + job_run = ( + session.query(sample_independent_variant_creation_run.__class__) + .filter(sample_independent_variant_creation_run.__class__.id == sample_independent_variant_creation_run.id) + .one() + ) + assert job_run.progress_current == 100 + assert job_run.status == JobStatus.FAILED + + ## Pipeline failure workflow + + async def test_create_variants_for_score_set_pipeline_job_generic_exception_handling( + self, + session, + with_variant_creation_pipeline_runs, + sample_variant_creation_pipeline, + sample_pipeline_variant_creation_run, + with_populated_domain_data, + mock_worker_ctx, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + side_effect=Exception("Generic exception during data validation"), + ), + ): + await create_variants_for_score_set(mock_worker_ctx, sample_pipeline_variant_creation_run.id) + + # Verify that the score set's processing state is updated to failed + session.refresh(sample_score_set) + assert sample_score_set.processing_state == ProcessingState.failed + assert sample_score_set.mapping_state == MappingState.not_attempted + assert "Generic exception during data validation" in sample_score_set.processing_errors["exception"] + + # Verify that job state is as expected + job_run = ( + session.query(sample_pipeline_variant_creation_run.__class__) + .filter(sample_pipeline_variant_creation_run.__class__.id == sample_pipeline_variant_creation_run.id) + .one() + ) + assert job_run.progress_current == 100 + assert job_run.status == JobStatus.FAILED + + # Verify that pipeline status is updated. + session.refresh(sample_variant_creation_pipeline) + assert sample_variant_creation_pipeline.status == PipelineStatus.FAILED + + # Verify other pipeline runs are marked as failed + other_runs = ( + session.query(Pipeline) + .filter( + JobRun.pipeline_id == sample_variant_creation_pipeline.id, + Pipeline.id != sample_pipeline_variant_creation_run.id, + ) + .all() + ) + for run in other_runs: + assert run.status == PipelineStatus.CANCELLED + + +@pytest.mark.asyncio +@pytest.mark.integration +class TestCreateVariantsForScoreSetArqContext: + """Integration tests for create_variants_for_score_set job using ARQ worker context.""" + + async def test_create_variants_for_score_set_with_arq_context_independent_ctx( + self, + session, + arq_redis, + arq_worker, + with_independent_processing_runs, + with_populated_domain_data, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes. + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + ): + await arq_redis.enqueue_job("create_variants_for_score_set", sample_independent_variant_creation_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that variants have been created in the database + created_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(created_variants) == sample_score_dataframe.shape[0] + session.refresh(sample_score_set) + assert sample_score_set.num_variants == len(created_variants) + assert sample_score_set.processing_state == ProcessingState.success + assert sample_score_set.mapping_state == MappingState.queued + + # Verify that the created variants have expected data + for variant in created_variants: + assert variant.data # Ensure data is not empty + assert "score_data" in variant.data # Ensure score_data is present + expected_score = sample_score_dataframe.loc[ + sample_score_dataframe["hgvs_nt"] == variant.hgvs_nt, "score" + ].values[0] + actual_score = variant.data["score_data"]["score"] + if actual_score is None and (isinstance(expected_score, float) and math.isnan(expected_score)): + pass # None in variant, NaN in DataFrame: OK + else: + assert actual_score == expected_score # Ensure score matches + assert "count_data" in variant.data # Ensure count_data is present + expected_count = sample_count_dataframe.loc[ + sample_count_dataframe["hgvs_nt"] == variant.hgvs_nt, "c_0" + ].values[0] + actual_count = variant.data["count_data"]["c_0"] + if actual_count is None and (isinstance(expected_count, float) and math.isnan(expected_count)): + pass # None in variant, NaN in DataFrame: OK + else: + assert actual_count == expected_count # Ensure count matches + + # Verify that no extra variants were created + all_variants = session.query(Variant).all() + assert len(all_variants) == len(created_variants) + + # Verify that job state is as expected + job_run = ( + session.query(sample_independent_variant_creation_run.__class__) + .filter(sample_independent_variant_creation_run.__class__.id == sample_independent_variant_creation_run.id) + .one() + ) + assert job_run.progress_current == 100 + assert job_run.status == JobStatus.SUCCEEDED + + async def test_create_variants_for_score_set_with_arq_context_pipeline_ctx( + self, + session, + arq_redis, + arq_worker, + with_variant_creation_pipeline_runs, + sample_variant_creation_pipeline, + sample_pipeline_variant_creation_run, + with_populated_domain_data, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes. + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + ): + await arq_redis.enqueue_job( + "create_variants_for_score_set", + sample_pipeline_variant_creation_run.id, + _job_id=sample_pipeline_variant_creation_run.urn, + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that variants have been created in the database + created_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + assert len(created_variants) == sample_score_dataframe.shape[0] + session.refresh(sample_score_set) + assert sample_score_set.num_variants == len(created_variants) + assert sample_score_set.processing_state == ProcessingState.success + assert sample_score_set.mapping_state == MappingState.queued + + # Verify that the created variants have expected data + for variant in created_variants: + assert variant.data # Ensure data is not empty + assert "score_data" in variant.data # Ensure score_data is present + expected_score = sample_score_dataframe.loc[ + sample_score_dataframe["hgvs_nt"] == variant.hgvs_nt, "score" + ].values[0] + actual_score = variant.data["score_data"]["score"] + if actual_score is None and (isinstance(expected_score, float) and math.isnan(expected_score)): + pass # None in variant, NaN in DataFrame: OK + else: + assert actual_score == expected_score # Ensure score matches + assert "count_data" in variant.data # Ensure count_data is present + expected_count = sample_count_dataframe.loc[ + sample_count_dataframe["hgvs_nt"] == variant.hgvs_nt, "c_0" + ].values[0] + actual_count = variant.data["count_data"]["c_0"] + if actual_count is None and (isinstance(expected_count, float) and math.isnan(expected_count)): + pass # None in variant, NaN in DataFrame: OK + else: + assert actual_count == expected_count # Ensure count matches + + # Verify that no extra variants were created + all_variants = session.query(Variant).all() + assert len(all_variants) == len(created_variants) + + # Verify that pipeline job state is as expected + job_run = ( + session.query(sample_pipeline_variant_creation_run.__class__) + .filter(sample_pipeline_variant_creation_run.__class__.id == sample_pipeline_variant_creation_run.id) + .one() + ) + assert job_run.progress_current == 100 + assert job_run.status == JobStatus.SUCCEEDED + + # Verify that pipeline status is updated. Pipeline will remain RUNNING + # as our default test pipeline includes the mapping job as well. + session.refresh(sample_variant_creation_pipeline) + assert sample_variant_creation_pipeline.status == PipelineStatus.RUNNING + + async def test_create_variants_for_score_set_with_arq_context_generic_exception_handling_independent_ctx( + self, + session, + arq_redis, + arq_worker, + with_variant_creation_pipeline_runs, + sample_variant_creation_pipeline, + sample_independent_variant_creation_run, + with_populated_domain_data, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + side_effect=Exception("Generic exception during data validation"), + ), + ): + await arq_redis.enqueue_job("create_variants_for_score_set", sample_independent_variant_creation_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that the score set's processing state is updated to failed + session.refresh(sample_score_set) + assert sample_score_set.processing_state == ProcessingState.failed + assert sample_score_set.mapping_state == MappingState.not_attempted + assert "Generic exception during data validation" in sample_score_set.processing_errors["exception"] + + # Verify that job state is as expected + job_run = ( + session.query(sample_independent_variant_creation_run.__class__) + .filter(sample_independent_variant_creation_run.__class__.id == sample_independent_variant_creation_run.id) + .one() + ) + assert job_run.progress_current == 100 + assert job_run.status == JobStatus.FAILED + + async def test_create_variants_for_score_set_with_arq_context_generic_exception_handling_pipeline_ctx( + self, + session, + arq_redis, + arq_worker, + with_variant_creation_pipeline_runs, + sample_variant_creation_pipeline, + sample_pipeline_variant_creation_run, + with_populated_domain_data, + mock_s3_client, + create_variants_sample_params, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + ): + with ( + patch.object(mock_s3_client, "download_fileobj", return_value=None), + # Mock pd.read_csv to return sample dataframes + patch( + "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", + side_effect=[sample_score_dataframe, sample_count_dataframe], + ), + patch( + "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", + side_effect=Exception("Generic exception during data validation"), + ), + ): + await arq_redis.enqueue_job("create_variants_for_score_set", sample_pipeline_variant_creation_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that the score set's processing state is updated to failed + session.refresh(sample_score_set) + assert sample_score_set.processing_state == ProcessingState.failed + assert sample_score_set.mapping_state == MappingState.not_attempted + assert "Generic exception during data validation" in sample_score_set.processing_errors["exception"] + + # Verify that job state is as expected + job_run = ( + session.query(sample_pipeline_variant_creation_run.__class__) + .filter(sample_pipeline_variant_creation_run.__class__.id == sample_pipeline_variant_creation_run.id) + .one() + ) + assert job_run.progress_current == 100 + assert job_run.status == JobStatus.FAILED + + # Verify that pipeline status is updated. + session.refresh(sample_variant_creation_pipeline) + assert sample_variant_creation_pipeline.status == PipelineStatus.FAILED + + # Verify other pipeline runs are marked as cancelled + other_runs = ( + session.query(Pipeline) + .filter( + JobRun.pipeline_id == sample_variant_creation_pipeline.id, + Pipeline.id != sample_pipeline_variant_creation_run.id, + ) + .all() + ) + for run in other_runs: + assert run.status == PipelineStatus.CANCELLED diff --git a/tests/worker/jobs/variant_processing/test_mapping.py b/tests/worker/jobs/variant_processing/test_mapping.py index e69de29b..74a1c050 100644 --- a/tests/worker/jobs/variant_processing/test_mapping.py +++ b/tests/worker/jobs/variant_processing/test_mapping.py @@ -0,0 +1,1650 @@ +from asyncio.unix_events import _UnixSelectorEventLoop +from unittest.mock import MagicMock, call, patch + +import pytest +from sqlalchemy.exc import NoResultFound + +from mavedb.lib.exceptions import ( + NonexistentMappingReferenceError, + NonexistentMappingResultsError, + NonexistentMappingScoresError, +) +from mavedb.lib.mapping import EXCLUDED_PREMAPPED_ANNOTATION_KEYS +from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus +from mavedb.models.enums.mapping_state import MappingState +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.variant import Variant +from mavedb.worker.jobs.variant_processing.mapping import map_variants_for_score_set +from mavedb.worker.lib.managers.job_manager import JobManager +from tests.helpers.constants import TEST_CODING_LAYER, TEST_GENOMIC_LAYER, TEST_PROTEIN_LAYER +from tests.helpers.util.setup.worker import construct_mock_mapping_output, create_variants_in_score_set + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestMapVariantsForScoreSetUnit: + """Unit tests for map_variants_for_score_set job.""" + + async def dummy_mapping_output(self, output_data={}): + return output_data + + async def test_map_variants_for_score_set_no_mapping_results( + self, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test mapping variants when no mapping results are found.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + with ( + patch.object(_UnixSelectorEventLoop, "run_in_executor", return_value=self.dummy_mapping_output({})), + patch.object(JobManager, "update_progress") as mock_update_progress, + pytest.raises(NonexistentMappingResultsError), + ): + await map_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_mapping_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id + ), + ) + + mock_update_progress.assert_any_call(100, 100, "Variant mapping failed due to missing results.") + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + assert ( + "Mapping results were not returned from VRS mapping service" + in sample_score_set.mapping_errors["error_message"] + ) + + async def test_map_variants_for_score_set_no_mapped_scores( + self, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test mapping variants when no scores are mapped.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=self.dummy_mapping_output( + {"mapped_scores": [], "error_message": "No variants were mapped for this score set"} + ), + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + pytest.raises(NonexistentMappingScoresError), + ): + await map_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_mapping_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id + ), + ) + + mock_update_progress.assert_any_call(100, 100, "Variant mapping failed; no variants were mapped.") + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + assert "No variants were mapped for this score set" in sample_score_set.mapping_errors["error_message"] + + async def test_map_variants_for_score_set_no_reference_data( + self, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test mapping variants when no reference data is available.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=self.dummy_mapping_output( + {"mapped_scores": [MagicMock()], "error_message": "Reference metadata missing from mapping results"} + ), + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + pytest.raises(NonexistentMappingReferenceError), + ): + await map_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_mapping_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id + ), + ) + + mock_update_progress.assert_any_call(100, 100, "Variant mapping failed due to missing reference metadata.") + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + assert "Reference metadata missing from mapping results" in sample_score_set.mapping_errors["error_message"] + + async def test_map_variants_for_score_set_nonexistent_target_gene( + self, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test mapping variants when the target gene does not exist.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=self.dummy_mapping_output( + { + "mapped_scores": [MagicMock()], + "reference_sequences": {"some_key": "some_value"}, + } + ), + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + pytest.raises(ValueError), + ): + await map_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_mapping_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id + ), + ) + + mock_update_progress.assert_any_call(100, 100, "Variant mapping failed due to an unexpected error.") + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + assert ( + "Encountered an unexpected error while parsing mapped variants" + in sample_score_set.mapping_errors["error_message"] + ) + + async def test_map_variants_for_score_set_returns_variants_not_in_score_set( + self, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test mapping variants when variants not in score set are returned.""" + # Add a non-existent variant to the mapped output to ensure at least one invalid mapping + mapping_output = await construct_mock_mapping_output( + session=mock_worker_ctx["db"], score_set=sample_score_set, with_layers={"g", "c", "p"} + ) + mapping_output["mapped_scores"].append({"variant_id": "not_in_score_set", "some_other_data": "value"}) + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=self.dummy_mapping_output(mapping_output), + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + pytest.raises(NoResultFound), + ): + await map_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_mapping_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id + ), + ) + + mock_update_progress.assert_any_call(100, 100, "Variant mapping failed due to an unexpected error.") + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + assert ( + "Encountered an unexpected error while parsing mapped variants" + in sample_score_set.mapping_errors["error_message"] + ) + + async def test_map_variants_for_score_set_success_missing_gene_info( + self, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test successful mapping variants with missing gene info.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=mock_worker_ctx["db"], + score_set=sample_score_set, + with_gene_info=False, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, + ) + + # Create a variant in the score set to be mapped + variant = Variant( + score_set_id=sample_score_set.id, hgvs_nt="NM_000000.1:c.1A>G", hgvs_pro="NP_000000.1:p.Met1Val", data={} + ) + mock_worker_ctx["db"].add(variant) + mock_worker_ctx["db"].commit() + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + result = await map_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_mapping_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id + ), + ) + + assert result["status"] == "ok" + assert result["data"] == {} + assert result["exception_details"] is None + + assert sample_score_set.mapping_state == MappingState.complete + assert sample_score_set.mapping_errors is None + + # Verify the gene info is missing from the target gene reference sequence + for target in sample_score_set.target_genes: + assert target.mapped_hgnc_name is None + + # Verify that a mapped variant was created + mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + assert len(mapped_variants) == 1 + + @pytest.mark.parametrize( + "with_layers", + [ + {"g"}, + {"c"}, + {"p"}, + {"g", "c"}, + {"g", "p"}, + {"c", "p"}, + {"g", "c", "p"}, + ], + ) + async def test_map_variants_for_score_set_success_layer_permutations( + self, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + with_layers, + ): + """Test successful mapping variants with annotation layer permutations.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=mock_worker_ctx["db"], + score_set=sample_score_set, + with_gene_info=True, + with_layers=with_layers, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, + ) + + # Create a variant in the score set to be mapped + variant = Variant( + score_set_id=sample_score_set.id, hgvs_nt="NM_000000.1:c.1A>G", hgvs_pro="NP_000000.1:p.Met1Val", data={} + ) + mock_worker_ctx["db"].add(variant) + mock_worker_ctx["db"].commit() + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + result = await map_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_mapping_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id + ), + ) + + assert result["status"] == "ok" + assert result["data"] == {} + assert result["exception_details"] is None + + assert sample_score_set.mapping_state == MappingState.complete + assert sample_score_set.mapping_errors is None + + # Verify the annotation layers presence/absence + for target in sample_score_set.target_genes: + if "g" in with_layers: + assert target.pre_mapped_metadata["genomic"] is not None + assert target.post_mapped_metadata["genomic"] is not None + pre_mapped_comparator = TEST_GENOMIC_LAYER["computed_reference_sequence"].copy() + for key in EXCLUDED_PREMAPPED_ANNOTATION_KEYS: + pre_mapped_comparator.pop(key, None) + + assert target.pre_mapped_metadata["genomic"] == pre_mapped_comparator + assert target.post_mapped_metadata["genomic"] == TEST_GENOMIC_LAYER["mapped_reference_sequence"] + else: + assert target.post_mapped_metadata.get("genomic") is None + + if "c" in with_layers: + assert target.pre_mapped_metadata["cdna"] is not None + assert target.post_mapped_metadata["cdna"] is not None + pre_mapped_comparator = TEST_CODING_LAYER["computed_reference_sequence"].copy() + for key in EXCLUDED_PREMAPPED_ANNOTATION_KEYS: + pre_mapped_comparator.pop(key, None) + + assert target.pre_mapped_metadata["cdna"] == pre_mapped_comparator + assert target.post_mapped_metadata["cdna"] == TEST_CODING_LAYER["mapped_reference_sequence"] + else: + assert target.post_mapped_metadata.get("cdna") is None + + if "p" in with_layers: + assert target.pre_mapped_metadata["protein"] is not None + assert target.post_mapped_metadata["protein"] is not None + pre_mapped_comparator = TEST_PROTEIN_LAYER["computed_reference_sequence"].copy() + for key in EXCLUDED_PREMAPPED_ANNOTATION_KEYS: + pre_mapped_comparator.pop(key, None) + + assert target.pre_mapped_metadata["protein"] == pre_mapped_comparator + assert target.post_mapped_metadata["protein"] == TEST_PROTEIN_LAYER["mapped_reference_sequence"] + else: + assert target.post_mapped_metadata.get("protein") is None + + # Verify that a mapped variant was created + mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + assert len(mapped_variants) == 1 + + async def test_map_variants_for_score_set_success_no_successful_mapping( + self, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test successful mapping variants with no successful mapping.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=mock_worker_ctx["db"], + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=False, # Missing post-mapped + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, + ) + + # Create a variant in the score set to be mapped + variant = Variant( + score_set_id=sample_score_set.id, hgvs_nt="NM_000000.1:c.1A>G", hgvs_pro="NP_000000.1:p.Met1Val", data={} + ) + mock_worker_ctx["db"].add(variant) + mock_worker_ctx["db"].commit() + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + result = await map_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_mapping_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id + ), + ) + + assert result["status"] == "error" + assert result["data"] == {} + assert result["exception_details"] is None + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors["error_message"] == "All variants failed to map." + + # Verify that one mapped variant was created. Although no successful mapping, an entry is still created. + mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + assert len(mapped_variants) == 1 + + # Verify that the mapped variant has no post-mapped data + mapped_variant = mapped_variants[0] + assert mapped_variant.post_mapped == {} + + async def test_map_variants_for_score_set_incomplete_mapping( + self, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test successful mapping variants with incomplete mapping.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=mock_worker_ctx["db"], + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=False, # Only some variants mapped + ) + + # Create two variants in the score set to be mapped + variant1 = Variant( + score_set_id=sample_score_set.id, + hgvs_nt="NM_000000.1:c.1A>G", + hgvs_pro="NP_000000.1:p.Met1Val", + data={}, + urn="variant:1", + ) + variant2 = Variant( + score_set_id=sample_score_set.id, + hgvs_nt="NM_000000.1:c.2G>T", + hgvs_pro="NP_000000.1:p.Val2Leu", + data={}, + urn="variant:2", + ) + mock_worker_ctx["db"].add_all([variant1, variant2]) + mock_worker_ctx["db"].commit() + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + result = await map_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_mapping_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id + ), + ) + + assert result["status"] == "ok" + assert result["data"] == {} + assert result["exception_details"] is None + + assert sample_score_set.mapping_state == MappingState.incomplete + assert sample_score_set.mapping_errors is None + + # Although only one variant was successfully mapped, verify that an entity was created + # for each variant in the score set + mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + assert len(mapped_variants) == 2 + + # Verify that only one variant has post-mapped data + mapped_variant_with_post_data = ( + mock_worker_ctx["db"].query(MappedVariant).filter(MappedVariant.post_mapped != {}).one_or_none() + ) + assert mapped_variant_with_post_data is not None + + mapped_variant_without_post_data = ( + mock_worker_ctx["db"].query(MappedVariant).filter(MappedVariant.post_mapped == {}).one_or_none() + ) + assert mapped_variant_without_post_data is not None + + async def test_map_variants_for_score_set_complete_mapping( + self, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test successful mapping variants with complete mapping.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=mock_worker_ctx["db"], + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, # All variants mapped + ) + + # Create two variants in the score set to be mapped + variant1 = Variant( + score_set_id=sample_score_set.id, + hgvs_nt="NM_000000.1:c.1A>G", + hgvs_pro="NP_000000.1:p.Met1Val", + data={}, + urn="variant:1", + ) + variant2 = Variant( + score_set_id=sample_score_set.id, + hgvs_nt="NM_000000.1:c.2G>T", + hgvs_pro="NP_000000.1:p.Val2Leu", + data={}, + urn="variant:2", + ) + mock_worker_ctx["db"].add_all([variant1, variant2]) + mock_worker_ctx["db"].commit() + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + result = await map_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_mapping_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id + ), + ) + + assert result["status"] == "ok" + assert result["data"] == {} + assert result["exception_details"] is None + + assert sample_score_set.mapping_state == MappingState.complete + assert sample_score_set.mapping_errors is None + + # Verify that mapped variants were created + mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + assert len(mapped_variants) == 2 + + # Verify that both variants have post-mapped data. I'm comfortable assuming the + # data is correct given our layer permutation tests above. + for urn in ["variant:1", "variant:2"]: + mapped_variant = ( + mock_worker_ctx["db"].query(MappedVariant).filter(MappedVariant.variant.has(urn=urn)).one_or_none() + ) + assert mapped_variant is not None + assert mapped_variant.post_mapped != {} + + async def test_map_variants_for_score_set_updates_existing_mapped_variants( + self, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test mapping variants updates existing mapped variants.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=mock_worker_ctx["db"], + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, + ) + + # Create a variant and associated mapped data in the score set to be updated + variant = Variant( + score_set_id=sample_score_set.id, hgvs_nt="NM_000000.1:c.1A>G", hgvs_pro="NP_000000.1:p.Met1Val", data={} + ) + mock_worker_ctx["db"].add(variant) + mock_worker_ctx["db"].commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + current=True, + mapped_date="2023-01-01T00:00:00Z", + mapping_api_version="v1.0.0", + ) + mock_worker_ctx["db"].add(mapped_variant) + mock_worker_ctx["db"].commit() + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + result = await map_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_mapping_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id + ), + ) + + assert result["status"] == "ok" + assert result["data"] == {} + assert result["exception_details"] is None + + assert sample_score_set.mapping_state == MappingState.complete + assert sample_score_set.mapping_errors is None + + # Verify the existing mapped variant was marked as non-current + non_current_mapped_variant = ( + mock_worker_ctx["db"] + .query(MappedVariant) + .filter(MappedVariant.id == mapped_variant.id, MappedVariant.current.is_(False)) + .one_or_none() + ) + assert non_current_mapped_variant is not None + + # Verify a new mapped variant entry was created + new_mapped_variant = ( + mock_worker_ctx["db"] + .query(MappedVariant) + .filter(MappedVariant.variant_id == variant.id, MappedVariant.current.is_(True)) + .one_or_none() + ) + assert new_mapped_variant is not None + + # Verify that the new mapped variant has updated mapping data + assert new_mapped_variant.mapped_date != "2023-01-01T00:00:00Z" + assert new_mapped_variant.mapping_api_version != "v1.0.0" + + async def test_map_variants_for_score_set_progress_updates( + self, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test mapping variants reports progress updates.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=mock_worker_ctx["db"], + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, + ) + + # Create a variant in the score set to be mapped + variant = Variant( + score_set_id=sample_score_set.id, hgvs_nt="NM_000000.1:c.1A>G", hgvs_pro="NP_000000.1:p.Met1Val", data={} + ) + mock_worker_ctx["db"].add(variant) + mock_worker_ctx["db"].commit() + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + result = await map_variants_for_score_set( + ctx=mock_worker_ctx, + job_id=sample_independent_variant_mapping_run.id, + job_manager=JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id + ), + ) + + assert result["status"] == "ok" + assert result["data"] == {} + assert result["exception_details"] is None + + assert sample_score_set.mapping_state == MappingState.complete + assert sample_score_set.mapping_errors is None + + # Verify progress updates were reported + mock_update_progress.assert_has_calls( + [ + call(0, 100, "Starting variant mapping job."), + call(10, 100, "Score set prepared for variant mapping."), + call(30, 100, "Mapping variants using VRS mapping service."), + call(80, 100, "Processing mapped variants."), + call(90, 100, "Saving mapped variants."), + call(100, 100, "Finished processing mapped variants."), + ] + ) + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestMapVariantsForScoreSetIntegration: + """Integration tests for map_variants_for_score_set job.""" + + async def test_map_variants_for_score_set_independent_job( + self, + session, + with_independent_processing_runs, + mock_s3_client, + mock_worker_ctx, + sample_independent_variant_creation_run, + sample_independent_variant_mapping_run, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + ): + """Test mapping variants for an independent processing run.""" + + # First, create variants in the score set + await create_variants_in_score_set( + session, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + mock_worker_ctx, + sample_independent_variant_creation_run, + ) + + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=mock_worker_ctx["db"], + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, + ) + + # Mock mapping output + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + # Now, map variants for the score set + result = await map_variants_for_score_set(mock_worker_ctx, sample_independent_variant_mapping_run.id) + + assert result["status"] == "ok" + assert result["data"] == {} + assert result["exception_details"] is None + + # Verify that mapped variants were created + mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + assert len(mapped_variants) == 4 + + # Verify score set mapping state + assert sample_score_set.mapping_state == MappingState.complete + assert sample_score_set.mapping_errors is None + + # Verify that target gene info was updated + for target in sample_score_set.target_genes: + assert target.mapped_hgnc_name is not None + assert target.post_mapped_metadata is not None + + # Verify that each variant has a corresponding mapped variant + variants = ( + mock_worker_ctx["db"] + .query(Variant) + .join(MappedVariant, MappedVariant.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id, MappedVariant.current.is_(True)) + .all() + ) + assert len(variants) == 4 + + # Verify that the job status was updated + processing_run = ( + mock_worker_ctx["db"] + .query(sample_independent_variant_mapping_run.__class__) + .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) + .one() + ) + assert processing_run.status == JobStatus.SUCCEEDED + + async def test_map_variants_for_score_set_pipeline_context( + self, + session, + with_variant_creation_pipeline_runs, + with_variant_mapping_pipeline_runs, + mock_s3_client, + mock_worker_ctx, + sample_pipeline_variant_creation_run, + sample_pipeline_variant_mapping_run, + sample_score_set, + sample_score_dataframe, + sample_count_dataframe, + ): + """Test mapping variants for a pipeline processing run.""" + + # First, create variants in the score set + await create_variants_in_score_set( + session, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + mock_worker_ctx, + sample_pipeline_variant_creation_run, + ) + + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=mock_worker_ctx["db"], + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, + ) + + # Mock mapping output + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + # Now, map variants for the score set + result = await map_variants_for_score_set(mock_worker_ctx, sample_pipeline_variant_mapping_run.id) + + assert result["status"] == "ok" + assert result["data"] == {} + assert result["exception_details"] is None + + # Verify that mapped variants were created + mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + assert len(mapped_variants) == 4 + + # Verify score set mapping state + assert sample_score_set.mapping_state == MappingState.complete + assert sample_score_set.mapping_errors is None + + # Verify that target gene info was updated + for target in sample_score_set.target_genes: + assert target.mapped_hgnc_name is not None + assert target.post_mapped_metadata is not None + + # Verify that each variant has a corresponding mapped variant + variants = ( + mock_worker_ctx["db"] + .query(Variant) + .join(MappedVariant, MappedVariant.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id, MappedVariant.current.is_(True)) + .all() + ) + assert len(variants) == 4 + + # Verify that the job status was updated + processing_run = ( + mock_worker_ctx["db"] + .query(sample_pipeline_variant_mapping_run.__class__) + .filter(sample_pipeline_variant_mapping_run.__class__.id == sample_pipeline_variant_mapping_run.id) + .one() + ) + assert processing_run.status == JobStatus.SUCCEEDED + + # Verify that the pipeline run status was updated. We expect RUNNING here because + # the mapping job is not the only job in our dummy pipeline. + pipeline_run = ( + mock_worker_ctx["db"] + .query(sample_pipeline_variant_mapping_run.pipeline.__class__) + .filter( + sample_pipeline_variant_mapping_run.pipeline.__class__.id + == sample_pipeline_variant_mapping_run.pipeline.id + ) + .one() + ) + assert pipeline_run.status == PipelineStatus.RUNNING + + async def test_map_variants_for_score_set_empty_mapping_results( + self, + session, + mock_s3_client, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + sample_score_dataframe, + sample_count_dataframe, + sample_independent_variant_creation_run, + ): + """Test mapping variants when no mapping results are returned.""" + + # First, create variants in the score set + await create_variants_in_score_set( + session, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + mock_worker_ctx, + sample_independent_variant_creation_run, + ) + + async def dummy_mapping_job(): + return {} + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + with ( + patch.object(_UnixSelectorEventLoop, "run_in_executor", return_value=dummy_mapping_job()), + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + ) + + assert result["status"] == "failed" + assert result["exception_details"]["type"] == "NonexistentMappingResultsError" + assert result["data"] == {} + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + assert ( + "Mapping results were not returned from VRS mapping service" + in sample_score_set.mapping_errors["error_message"] + ) + + # Verify that no mapped variants were created + mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + assert len(mapped_variants) == 0 + + # Verify that the job status was updated. + processing_run = ( + mock_worker_ctx["db"] + .query(sample_independent_variant_mapping_run.__class__) + .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) + .one() + ) + assert processing_run.status == JobStatus.FAILED + + async def test_map_variants_for_score_set_no_mapped_scores( + self, + session, + mock_s3_client, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + sample_score_dataframe, + sample_count_dataframe, + sample_independent_variant_creation_run, + ): + """Test mapping variants when no variants are mapped.""" + + # First, create variants in the score set + await create_variants_in_score_set( + session, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + mock_worker_ctx, + sample_independent_variant_creation_run, + ) + + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=mock_worker_ctx["db"], + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=False, + with_reference_metadata=True, + with_mapped_scores=False, # No mapped scores + with_all_variants=True, + ) + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + ) + + assert result["status"] == "failed" + assert result["exception_details"]["type"] == "NonexistentMappingScoresError" + assert result["data"] == {} + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + # Error message originates from our mock mapping construction function + assert "test error: no mapped scores" in sample_score_set.mapping_errors["error_message"] + + # Verify that no mapped variants were created + mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + assert len(mapped_variants) == 0 + + # Verify that the job status was updated. + processing_run = ( + mock_worker_ctx["db"] + .query(sample_independent_variant_mapping_run.__class__) + .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) + .one() + ) + assert processing_run.status == JobStatus.FAILED + + async def test_map_variants_for_score_set_no_reference_data( + self, + session, + mock_s3_client, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + sample_score_dataframe, + sample_count_dataframe, + sample_independent_variant_creation_run, + ): + """Test mapping variants when no reference data is provided.""" + + # First, create variants in the score set + await create_variants_in_score_set( + session, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + mock_worker_ctx, + sample_independent_variant_creation_run, + ) + + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=mock_worker_ctx["db"], + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=False, # No reference metadata + with_mapped_scores=True, + with_all_variants=True, + ) + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + ) + + assert result["status"] == "failed" + assert result["exception_details"]["type"] == "NonexistentMappingReferenceError" + assert result["data"] == {} + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + assert "Reference metadata missing from mapping results" in sample_score_set.mapping_errors["error_message"] + + # Verify that no mapped variants were created + mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + assert len(mapped_variants) == 0 + + # Verify that the job status was updated. + processing_run = ( + mock_worker_ctx["db"] + .query(sample_independent_variant_mapping_run.__class__) + .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) + .one() + ) + assert processing_run.status == JobStatus.FAILED + + async def test_map_variants_for_score_set_updates_current_mapped_variants( + self, + session, + mock_s3_client, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + sample_score_dataframe, + sample_count_dataframe, + sample_independent_variant_creation_run, + ): + """Test mapping variants updates current mapped variants even if no changes occur.""" + + # First, create variants in the score set + await create_variants_in_score_set( + session, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + mock_worker_ctx, + sample_independent_variant_creation_run, + ) + + # Associate mapped variants with all variants just created in the score set + variants = mock_worker_ctx["db"].query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + for variant in variants: + mapped_variant = MappedVariant( + variant_id=variant.id, + current=True, + mapped_date="2023-01-01T00:00:00Z", + mapping_api_version="v1.0.0", + ) + mock_worker_ctx["db"].add(mapped_variant) + mock_worker_ctx["db"].commit() + + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=mock_worker_ctx["db"], + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, + ) + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + ) + + assert result["status"] == "ok" + assert result["data"] == {} + assert result["exception_details"] is None + + assert sample_score_set.mapping_state == MappingState.complete + assert sample_score_set.mapping_errors is None + + # Verify that mapped variants were marked as non-current and new entries created + mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + assert len(mapped_variants) == len(variants) * 2 # Each variant has two mapped entries now + for variant in variants: + non_current_mapped_variant = ( + mock_worker_ctx["db"] + .query(MappedVariant) + .filter(MappedVariant.variant_id == variant.id, MappedVariant.current.is_(False)) + .one_or_none() + ) + assert non_current_mapped_variant is not None + + new_mapped_variant = ( + mock_worker_ctx["db"] + .query(MappedVariant) + .filter(MappedVariant.variant_id == variant.id, MappedVariant.current.is_(True)) + .one_or_none() + ) + assert new_mapped_variant is not None + + # Verify that the new mapped variant has updated mapping data + assert new_mapped_variant.mapped_date != "2023-01-01T00:00:00Z" + assert new_mapped_variant.mapping_api_version != "v1.0.0" + + # Verify that the job status was updated. + processing_run = ( + mock_worker_ctx["db"] + .query(sample_independent_variant_mapping_run.__class__) + .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) + .one() + ) + assert processing_run.status == JobStatus.SUCCEEDED + + async def test_map_variants_for_score_set_no_variants( + self, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test mapping variants when no variants exist in the score set.""" + + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=mock_worker_ctx["db"], + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, + ) + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + ) + + assert result["status"] == "failed" + assert result["data"] == {} + assert result["exception_details"] is not None + assert result["exception_details"]["type"] == "NonexistentMappingScoresError" + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + assert "test error: no mapped scores" in sample_score_set.mapping_errors["error_message"] + + # Verify that no mapped variants were created + mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + assert len(mapped_variants) == 0 + + # Verify that the job status was updated. + processing_run = ( + mock_worker_ctx["db"] + .query(sample_independent_variant_mapping_run.__class__) + .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) + .one() + ) + assert processing_run.status == JobStatus.FAILED + + async def test_map_variants_for_score_set_exception_in_mapping( + self, + with_independent_processing_runs, + mock_worker_ctx, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test mapping variants when an exception occurs during mapping.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + async def dummy_mapping_job(): + raise ValueError("test exception during mapping") + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + result = await map_variants_for_score_set( + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + ) + + assert result["status"] == "failed" + assert result["data"] == {} + assert result["exception_details"]["type"] == "ValueError" + # exception messages are persisted in internal properties + assert "test exception during mapping" in result["exception_details"]["message"] + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + # but replaced with generic error message for external visibility + assert ( + "Encountered an unexpected error while parsing mapped variants" + in sample_score_set.mapping_errors["error_message"] + ) + + # Verify that no mapped variants were created + mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + assert len(mapped_variants) == 0 + + # Verify that the job status was updated. + processing_run = ( + mock_worker_ctx["db"] + .query(sample_independent_variant_mapping_run.__class__) + .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) + .one() + ) + assert processing_run.status == JobStatus.FAILED + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestMapVariantsForScoreSetArqContext: + """Integration tests for map_variants_for_score_set job using ARQ worker context.""" + + async def test_create_variants_for_score_set_with_arq_context_independent_ctx( + self, + session, + arq_redis, + arq_worker, + standalone_worker_context, + with_independent_processing_runs, + with_populated_domain_data, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_independent_variant_creation_run, + sample_independent_variant_mapping_run, + ): + await create_variants_in_score_set( + session, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + standalone_worker_context, + sample_independent_variant_creation_run, + ) + + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=standalone_worker_context["db"], + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, + ) + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + await arq_redis.enqueue_job("map_variants_for_score_set", sample_independent_variant_mapping_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that mapped variants were created + mapped_variants = standalone_worker_context["db"].query(MappedVariant).all() + assert len(mapped_variants) == 4 + + # Verify score set mapping state + assert sample_score_set.mapping_state == MappingState.complete + assert sample_score_set.mapping_errors is None + + # Verify that each variant has a corresponding mapped variant + variants = ( + standalone_worker_context["db"] + .query(Variant) + .join(MappedVariant, MappedVariant.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id, MappedVariant.current.is_(True)) + .all() + ) + assert len(variants) == 4 + + # Verify that the job status was updated + processing_run = ( + standalone_worker_context["db"] + .query(sample_independent_variant_mapping_run.__class__) + .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) + .one() + ) + assert processing_run.status == JobStatus.SUCCEEDED + + async def test_map_variants_for_score_set_with_arq_context_pipeline_ctx( + self, + session, + arq_redis, + arq_worker, + standalone_worker_context, + with_variant_creation_pipeline_runs, + with_variant_mapping_pipeline_runs, + with_populated_domain_data, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + sample_score_set, + sample_pipeline_variant_creation_run, + sample_pipeline_variant_mapping_run, + ): + """Test mapping variants for a pipeline processing run using ARQ context.""" + + # First, create variants in the score set + await create_variants_in_score_set( + session, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + standalone_worker_context, + sample_pipeline_variant_creation_run, + ) + + async def dummy_mapping_job(): + return await construct_mock_mapping_output( + session=standalone_worker_context["db"], + score_set=sample_score_set, + with_gene_info=True, + with_layers={"g", "c", "p"}, + with_pre_mapped=True, + with_post_mapped=True, + with_reference_metadata=True, + with_mapped_scores=True, + with_all_variants=True, + ) + + # Mock mapping output + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + # Now, map variants for the score set + await arq_redis.enqueue_job("map_variants_for_score_set", sample_pipeline_variant_mapping_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that mapped variants were created + mapped_variants = standalone_worker_context["db"].query(MappedVariant).all() + assert len(mapped_variants) == 4 + + # Verify score set mapping state + assert sample_score_set.mapping_state == MappingState.complete + assert sample_score_set.mapping_errors is None + + # Verify that each variant has a corresponding mapped variant + variants = ( + standalone_worker_context["db"] + .query(Variant) + .join(MappedVariant, MappedVariant.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id, MappedVariant.current.is_(True)) + .all() + ) + assert len(variants) == 4 + + # Verify that the job status was updated + processing_run = ( + standalone_worker_context["db"] + .query(sample_pipeline_variant_mapping_run.__class__) + .filter(sample_pipeline_variant_mapping_run.__class__.id == sample_pipeline_variant_mapping_run.id) + .one() + ) + assert processing_run.status == JobStatus.SUCCEEDED + + # Verify that the pipeline run status was updated. We expect RUNNING here because + # the mapping job is not the only job in our dummy pipeline. + pipeline_run = ( + standalone_worker_context["db"] + .query(sample_pipeline_variant_mapping_run.pipeline.__class__) + .filter( + sample_pipeline_variant_mapping_run.pipeline.__class__.id + == sample_pipeline_variant_mapping_run.pipeline.id + ) + .one() + ) + assert pipeline_run.status == PipelineStatus.RUNNING + + async def test_map_variants_for_score_set_with_arq_context_generic_exception_handling( + self, + arq_redis, + arq_worker, + standalone_worker_context, + with_independent_processing_runs, + sample_independent_variant_mapping_run, + sample_score_set, + ): + """Test mapping variants with ARQ context when an exception occurs during mapping.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + async def dummy_mapping_job(): + raise ValueError("test exception during mapping") + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + await arq_redis.enqueue_job("map_variants_for_score_set", sample_independent_variant_mapping_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + # but replaced with generic error message for external visibility + assert ( + "Encountered an unexpected error while parsing mapped variants" + in sample_score_set.mapping_errors["error_message"] + ) + + # Verify that no mapped variants were created + mapped_variants = standalone_worker_context["db"].query(MappedVariant).all() + assert len(mapped_variants) == 0 + + # Verify that the job status was updated. + processing_run = ( + standalone_worker_context["db"] + .query(sample_independent_variant_mapping_run.__class__) + .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) + .one() + ) + assert processing_run.status == JobStatus.FAILED + + async def test_map_variants_for_score_set_with_arq_context_generic_exception_in_pipeline_ctx( + self, + arq_redis, + arq_worker, + standalone_worker_context, + with_variant_mapping_pipeline_runs, + sample_pipeline_variant_mapping_run, + sample_score_set, + ): + """Test mapping variants with ARQ context in pipeline when an exception occurs during mapping.""" + + # Network requests occur within an event loop. Mock result of mapping call + # with return value from run_in_executor. + async def dummy_mapping_job(): + raise ValueError("test exception during mapping") + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_mapping_job(), + ), + ): + await arq_redis.enqueue_job("map_variants_for_score_set", sample_pipeline_variant_mapping_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + assert sample_score_set.mapping_state == MappingState.failed + assert sample_score_set.mapping_errors is not None + # but replaced with generic error message for external visibility + assert ( + "Encountered an unexpected error while parsing mapped variants" + in sample_score_set.mapping_errors["error_message"] + ) + + # Verify that no mapped variants were created + mapped_variants = standalone_worker_context["db"].query(MappedVariant).all() + assert len(mapped_variants) == 0 + + # Verify that the job status was updated. + processing_run = ( + standalone_worker_context["db"] + .query(sample_pipeline_variant_mapping_run.__class__) + .filter(sample_pipeline_variant_mapping_run.__class__.id == sample_pipeline_variant_mapping_run.id) + .one() + ) + assert processing_run.status == JobStatus.FAILED + + # Verify that the pipeline run status was updated to FAILED. + pipeline_run = ( + standalone_worker_context["db"] + .query(sample_pipeline_variant_mapping_run.pipeline.__class__) + .filter( + sample_pipeline_variant_mapping_run.pipeline.__class__.id + == sample_pipeline_variant_mapping_run.pipeline.id + ) + .one() + ) + assert pipeline_run.status == PipelineStatus.FAILED + + # Verify that other jobs in the pipeline were skipped + for job_run in pipeline_run.job_runs: + if job_run.id != sample_pipeline_variant_mapping_run.id: + assert job_run.status == JobStatus.SKIPPED From 9603334a117246691c07560e64150e5f586985ae Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Sat, 24 Jan 2026 14:52:00 -0800 Subject: [PATCH 30/70] feat: add start_pipeline job and related tests for pipeline management --- .../jobs/pipeline_management/__init__.py | 12 + .../pipeline_management/start_pipeline.py | 59 ++++ src/mavedb/worker/jobs/registry.py | 3 + .../jobs/pipeline_management/conftest.py | 62 ++++ .../test_start_pipeline.py | 300 ++++++++++++++++++ 5 files changed, 436 insertions(+) create mode 100644 src/mavedb/worker/jobs/pipeline_management/__init__.py create mode 100644 src/mavedb/worker/jobs/pipeline_management/start_pipeline.py create mode 100644 tests/worker/jobs/pipeline_management/conftest.py create mode 100644 tests/worker/jobs/pipeline_management/test_start_pipeline.py diff --git a/src/mavedb/worker/jobs/pipeline_management/__init__.py b/src/mavedb/worker/jobs/pipeline_management/__init__.py new file mode 100644 index 00000000..95470f75 --- /dev/null +++ b/src/mavedb/worker/jobs/pipeline_management/__init__.py @@ -0,0 +1,12 @@ +""" +Pipeline management job entrypoints. + +This module exposes job functions for pipeline management, such as starting a pipeline. +Import job functions here and add them to __all__ for job discovery and import convenience. +""" + +from .start_pipeline import start_pipeline + +__all__ = [ + "start_pipeline", +] diff --git a/src/mavedb/worker/jobs/pipeline_management/start_pipeline.py b/src/mavedb/worker/jobs/pipeline_management/start_pipeline.py new file mode 100644 index 00000000..c67472e5 --- /dev/null +++ b/src/mavedb/worker/jobs/pipeline_management/start_pipeline.py @@ -0,0 +1,59 @@ +import logging + +from mavedb.worker.lib.decorators.pipeline_management import with_pipeline_management +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.pipeline_manager import PipelineManager +from mavedb.worker.lib.managers.types import JobResultData + +logger = logging.getLogger(__name__) + + +@with_pipeline_management +async def start_pipeline(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData: + """Start the pipeline associated with the given job. + + This job initializes and starts the pipeline execution process. + It sets up the necessary pipeline management context and triggers + the pipeline coordination. + + NOTE: This function requires a dedicated 'start_pipeline' job run record + in the database. This job run must be created prior to invoking this function + and should be associated with the pipeline to be started. + + Args: + ctx (dict): The job context dictionary. + job_id (int): The ID of the job run. + job_manager (JobManager): Manager for job lifecycle and DB operations. + + Side Effects: + - Initializes and starts the pipeline execution. + + Returns: + dict: Result indicating success and any exception details + """ + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "start_pipeline", + "resource": f"pipeline_for_job_{job_id}", + "correlation_id": None, + } + ) + job_manager.update_progress(0, 100, "Coordinating pipeline for the first time.") + logger.debug(msg="Coordinating pipeline for the first time.", extra=job_manager.logging_context()) + + if not job_manager.pipeline_id: + raise ValueError(f"No pipeline associated with job {job_id}") + + # Initialize PipelineManager and coordinate pipeline. The pipeline manager decorator + # will have started the pipeline for us already, but doesn't coordinate on start automatically. + pipeline_manager = PipelineManager(job_manager.db, job_manager.redis, job_manager.pipeline_id) + await pipeline_manager.coordinate_pipeline() + + # Finalize job state + job_manager.db.commit() + job_manager.update_progress(100, 100, "Initial pipeline coordination complete.") + logger.debug(msg="Done starting pipeline.", extra=job_manager.logging_context()) + + return {"status": "ok", "data": {}, "exception_details": None} diff --git a/src/mavedb/worker/jobs/registry.py b/src/mavedb/worker/jobs/registry.py index 06ae2b29..60654170 100644 --- a/src/mavedb/worker/jobs/registry.py +++ b/src/mavedb/worker/jobs/registry.py @@ -21,6 +21,7 @@ submit_score_set_mappings_to_ldh, submit_uniprot_mapping_jobs_for_score_set, ) +from mavedb.worker.jobs.pipeline_management import start_pipeline from mavedb.worker.jobs.variant_processing import ( create_variants_for_score_set, map_variants_for_score_set, @@ -41,6 +42,8 @@ # Data management jobs refresh_materialized_views, refresh_published_variants_view, + # Pipeline management jobs + start_pipeline, ] # Cron job definitions for ARQ worker diff --git a/tests/worker/jobs/pipeline_management/conftest.py b/tests/worker/jobs/pipeline_management/conftest.py new file mode 100644 index 00000000..d7d2a239 --- /dev/null +++ b/tests/worker/jobs/pipeline_management/conftest.py @@ -0,0 +1,62 @@ +import pytest + +from mavedb.models.job_run import JobRun +from mavedb.models.pipeline import Pipeline + + +@pytest.fixture +def sample_dummy_pipeline(): + """Create a sample Pipeline instance for testing.""" + + return Pipeline( + name="Dummy Pipeline", + description="A dummy pipeline for testing purposes", + ) + + +@pytest.fixture +def with_dummy_pipeline(session, sample_dummy_pipeline): + """Fixture to ensure dummy pipeline exists in the database.""" + session.add(sample_dummy_pipeline) + session.commit() + + +@pytest.fixture +def sample_dummy_pipeline_start(session, with_dummy_pipeline, sample_dummy_pipeline): + """Create a sample JobRun instance for starting the dummy pipeline.""" + start_job_run = JobRun( + pipeline_id=sample_dummy_pipeline.id, + job_type="start_pipeline", + job_function="start_pipeline", + ) + session.add(start_job_run) + session.commit() + + return start_job_run + + +@pytest.fixture +def with_dummy_pipeline_start(session, with_dummy_pipeline, sample_dummy_pipeline_start): + """Fixture to ensure a start pipeline job run for the dummy pipeline exists in the database.""" + session.add(sample_dummy_pipeline_start) + session.commit() + + +@pytest.fixture +def sample_dummy_pipeline_step(session, sample_dummy_pipeline): + """Create a sample PipelineStep instance for the dummy pipeline.""" + step = JobRun( + pipeline_id=sample_dummy_pipeline.id, + job_type="dummy_step", + job_function="dummy_arq_function", + ) + session.add(step) + session.commit() + return step + + +@pytest.fixture +def with_full_dummy_pipeline(session, with_dummy_pipeline_start, sample_dummy_pipeline, sample_dummy_pipeline_step): + """Fixture to ensure dummy pipeline steps exist in the database.""" + session.add(sample_dummy_pipeline_step) + session.commit() diff --git a/tests/worker/jobs/pipeline_management/test_start_pipeline.py b/tests/worker/jobs/pipeline_management/test_start_pipeline.py new file mode 100644 index 00000000..12eb9675 --- /dev/null +++ b/tests/worker/jobs/pipeline_management/test_start_pipeline.py @@ -0,0 +1,300 @@ +from unittest.mock import call, patch + +import pytest +from sqlalchemy import select + +from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus +from mavedb.models.job_run import JobRun +from mavedb.worker.jobs.pipeline_management.start_pipeline import start_pipeline +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.pipeline_manager import PipelineManager + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestStartPipelineUnit: + """Unit tests for starting pipelines.""" + + @pytest.fixture(autouse=True) + def setup_start_pipeline_job_run(self, session, with_dummy_pipeline, sample_dummy_pipeline): + """Fixture to ensure a start pipeline job run exists in the database.""" + job_run = JobRun( + pipeline_id=sample_dummy_pipeline.id, + job_type="start_pipeline", + job_function="start_pipeline", + ) + session.add(job_run) + session.commit() + + return job_run + + async def test_start_pipeline_raises_exception_when_no_pipeline_associated_with_job( + self, + session, + mock_worker_ctx, + setup_start_pipeline_job_run, + ): + """Test that starting a pipeline raises an exception when no pipeline is associated with the job.""" + + # Remove pipeline association from job run + setup_start_pipeline_job_run.pipeline_id = None + session.commit() + + with pytest.raises(ValueError, match="No pipeline associated with job"): + await start_pipeline( + mock_worker_ctx, + setup_start_pipeline_job_run.id, + JobManager(mock_worker_ctx["db"], mock_worker_ctx["redis"], setup_start_pipeline_job_run.id), + ) + + async def test_start_pipeline_starts_pipeline_successfully( + self, + session, + mock_worker_ctx, + mock_pipeline_manager, + setup_start_pipeline_job_run, + ): + """Test that starting a pipeline completes successfully.""" + + with ( + patch("mavedb.worker.lib.managers.pipeline_manager.PipelineManager") as mock_pipeline_manager_class, + patch.object(PipelineManager, "coordinate_pipeline", return_value=None) as mock_coordinate_pipeline, + ): + mock_pipeline_manager_class.return_value = mock_pipeline_manager + + result = await start_pipeline( + mock_worker_ctx, + setup_start_pipeline_job_run.id, + JobManager(mock_worker_ctx["db"], mock_worker_ctx["redis"], setup_start_pipeline_job_run.id), + ) + + assert result["status"] == "ok" + mock_coordinate_pipeline.assert_called_once() + + async def test_start_pipeline_updates_progress( + self, + session, + mock_worker_ctx, + mock_pipeline_manager, + setup_start_pipeline_job_run, + ): + """Test that starting a pipeline updates job progress.""" + + with ( + patch("mavedb.worker.lib.managers.pipeline_manager.PipelineManager") as mock_pipeline_manager_class, + patch.object(PipelineManager, "coordinate_pipeline", return_value=None), + patch.object( + JobManager, + "update_progress", + return_value=None, + ) as mock_update_progress, + ): + mock_pipeline_manager_class.return_value = mock_pipeline_manager + + result = await start_pipeline( + mock_worker_ctx, + setup_start_pipeline_job_run.id, + JobManager(mock_worker_ctx["db"], mock_worker_ctx["redis"], setup_start_pipeline_job_run.id), + ) + + assert result["status"] == "ok" + + mock_update_progress.assert_has_calls( + [ + call(0, 100, "Coordinating pipeline for the first time."), + call(100, 100, "Initial pipeline coordination complete."), + ] + ) + + async def test_start_pipeline_raises_exception( + self, + session, + mock_worker_ctx, + mock_pipeline_manager, + setup_start_pipeline_job_run, + ): + """Test that starting a pipeline raises an exception.""" + + with ( + patch("mavedb.worker.lib.managers.pipeline_manager.PipelineManager") as mock_pipeline_manager_class, + patch.object( + PipelineManager, + "coordinate_pipeline", + side_effect=Exception("Simulated pipeline start failure"), + ), + pytest.raises(Exception, match="Simulated pipeline start failure"), + ): + mock_pipeline_manager_class.return_value = mock_pipeline_manager + + await start_pipeline( + mock_worker_ctx, + setup_start_pipeline_job_run.id, + JobManager(mock_worker_ctx["db"], mock_worker_ctx["redis"], setup_start_pipeline_job_run.id), + ) + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestStartPipelineIntegration: + """Integration tests for starting pipelines.""" + + async def test_start_pipeline_on_job_without_pipeline_fails( + self, + session, + mock_worker_ctx, + with_full_dummy_pipeline, + sample_dummy_pipeline_start, + ): + """Test that starting a pipeline on a job without an associated pipeline fails.""" + + sample_dummy_pipeline_start.pipeline_id = None + session.commit() + + result = await start_pipeline(mock_worker_ctx, sample_dummy_pipeline_start.id) + assert result["status"] == "failed" + + # Verify the start job run status + session.refresh(sample_dummy_pipeline_start) + assert sample_dummy_pipeline_start.status == JobStatus.FAILED + + async def test_start_pipeline_on_valid_job_succeeds_and_coordinates_pipeline( + self, session, mock_worker_ctx, with_full_dummy_pipeline, sample_dummy_pipeline_start, sample_dummy_pipeline + ): + """Test that starting a pipeline on a valid job succeeds and coordinates the pipeline.""" + + result = await start_pipeline(mock_worker_ctx, sample_dummy_pipeline_start.id) + assert result["status"] == "ok" + + # Verify the start job run status + session.refresh(sample_dummy_pipeline_start) + assert sample_dummy_pipeline_start.status == JobStatus.SUCCEEDED + + # Verify that the pipeline state is updated appropriately + session.refresh(sample_dummy_pipeline) + assert sample_dummy_pipeline.status == PipelineStatus.RUNNING + + async def test_start_pipeline_handles_exceptions_gracefully( + self, + session, + mock_worker_ctx, + with_full_dummy_pipeline, + sample_dummy_pipeline, + sample_dummy_pipeline_start, + ): + """Test that starting a pipeline handles exceptions gracefully.""" + # Mock a coordination failure during pipeline start. Realistically if this failed in pipeline start + # it would likely also fail during the final coordination attempt in the exception handler, but for testing purposes + # we only mock the initial failure here. In a real-world scenario, we'd likely have to rely on our alerting here and + # intervene manually or via a separate recovery job to fix the pipeline state. + real_coordinate_pipeline = PipelineManager.coordinate_pipeline + call_count = {"n": 0} + + async def custom_side_effect(*args, **kwargs): + if call_count["n"] == 0: + call_count["n"] += 1 + raise Exception("Simulated pipeline start failure") + return await real_coordinate_pipeline( + PipelineManager(session, mock_worker_ctx["db"], sample_dummy_pipeline.id), *args, **kwargs + ) # Allow the final coordination attempt to proceed 'normally' + + with patch( + "mavedb.worker.lib.managers.pipeline_manager.PipelineManager.coordinate_pipeline", + side_effect=custom_side_effect, + ): + result = await start_pipeline(mock_worker_ctx, sample_dummy_pipeline_start.id) + assert result["status"] == "failed" + + # Verify the start job run status + session.refresh(sample_dummy_pipeline_start) + assert sample_dummy_pipeline_start.status == JobStatus.FAILED + + # Verify that the pipeline state is updated to CANCELLED + session.refresh(sample_dummy_pipeline) + assert sample_dummy_pipeline.status == PipelineStatus.FAILED + + async def test_start_pipeline_no_jobs_in_pipeline( + self, + session, + mock_worker_ctx, + with_dummy_pipeline, + sample_dummy_pipeline_start, + sample_dummy_pipeline, + ): + """Test starting a pipeline that has no jobs defined.""" + + result = await start_pipeline(mock_worker_ctx, sample_dummy_pipeline_start.id) + assert result["status"] == "ok" + + # Verify that a JobRun was created for the start_pipeline job and it succeeded + session.refresh(sample_dummy_pipeline_start) + assert sample_dummy_pipeline_start.status == JobStatus.SUCCEEDED + + # Verify that the pipeline state is updated appropriately + session.refresh(sample_dummy_pipeline) + assert sample_dummy_pipeline.status == PipelineStatus.SUCCEEDED + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestStartPipelineArqContext: + """Test starting pipelines using an ARQ worker context.""" + + async def test_start_pipeline_with_arq_context( + self, + session, + arq_redis, + arq_worker, + with_full_dummy_pipeline, + sample_dummy_pipeline_start, + sample_dummy_pipeline, + ): + """Test starting a pipeline using an ARQ worker context.""" + + await arq_redis.enqueue_job("start_pipeline", sample_dummy_pipeline_start.id) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify the start job run status + session.refresh(sample_dummy_pipeline_start) + assert sample_dummy_pipeline_start.status == JobStatus.SUCCEEDED + + # Verify that the pipeline state is updated appropriately + session.refresh(sample_dummy_pipeline) + assert sample_dummy_pipeline.status == PipelineStatus.RUNNING + + # Verify that other pipeline steps have been queued + pipeline_steps = ( + session.execute( + select(JobRun).where( + JobRun.pipeline_id == sample_dummy_pipeline.id, JobRun.id != sample_dummy_pipeline_start.id + ) + ) + .scalars() + .all() + ) + assert len(pipeline_steps) == 1 + assert pipeline_steps[0].job_type == "dummy_step" + assert pipeline_steps[0].status == JobStatus.QUEUED + + async def test_start_pipeline_with_arq_context_no_jobs_in_pipeline( + self, + session, + arq_redis, + arq_worker, + with_dummy_pipeline, + sample_dummy_pipeline_start, + sample_dummy_pipeline, + ): + """Test starting a pipeline with no jobs using an ARQ worker context.""" + + await arq_redis.enqueue_job("start_pipeline", sample_dummy_pipeline_start.id) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that a JobRun was created for the start_pipeline job and it succeeded + session.refresh(sample_dummy_pipeline_start) + assert sample_dummy_pipeline_start.status == JobStatus.SUCCEEDED + + # Verify that the pipeline state is updated appropriately + session.refresh(sample_dummy_pipeline) + assert sample_dummy_pipeline.status == PipelineStatus.SUCCEEDED From fcfb060a26fcd3e563d3c56d1458fc8639539af1 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Sat, 24 Jan 2026 16:26:09 -0800 Subject: [PATCH 31/70] feat: gnomAD managed job tests and enhancements - Adds comprehensive test cases for gnomAD managed job - Enhances athena engine in test cases with mocked db fixture --- src/mavedb/lib/gnomad.py | 64 ++- src/mavedb/scripts/link_gnomad_variants.py | 8 +- .../worker/jobs/external_services/gnomad.py | 13 +- tests/conftest.py | 53 +- .../worker/jobs/external_services/conftest.py | 99 ++++ .../external_services/network}/test_gnomad.py | 0 .../jobs/external_services/test_gnomad.py | 461 ++++++++++++++++++ 7 files changed, 657 insertions(+), 41 deletions(-) create mode 100644 tests/worker/jobs/external_services/conftest.py rename tests/{network/worker => worker/jobs/external_services/network}/test_gnomad.py (100%) diff --git a/src/mavedb/lib/gnomad.py b/src/mavedb/lib/gnomad.py index 02a7da2d..937471b8 100644 --- a/src/mavedb/lib/gnomad.py +++ b/src/mavedb/lib/gnomad.py @@ -1,19 +1,18 @@ +import logging import os import re -import logging from typing import Any, Sequence, Union -from sqlalchemy import text, select, Row +from sqlalchemy import Connection, Row, select, text from sqlalchemy.orm import Session from mavedb.lib.logging.context import logging_context, save_to_logging_context from mavedb.lib.utils import batched -from mavedb.db.athena import engine as athena_engine from mavedb.models.gnomad_variant import GnomADVariant from mavedb.models.mapped_variant import MappedVariant GNOMAD_DB_NAME = "gnomAD" -GNOMAD_DATA_VERSION = os.getenv("GNOMAD_DATA_VERSION") +GNOMAD_DATA_VERSION = os.getenv("GNOMAD_DATA_VERSION", "v4.1") # e.g., "v4.1" logger = logging.getLogger(__name__) @@ -66,7 +65,9 @@ def allele_list_from_list_like_string(alleles_string: str) -> list[str]: return alleles -def gnomad_variant_data_for_caids(caids: Sequence[str]) -> Sequence[Row[Any]]: # pragma: no cover +def gnomad_variant_data_for_caids( + athena_session: Connection, caids: Sequence[str] +) -> Sequence[Row[Any]]: # pragma: no cover """ Fetches variant rows from the gnomAD table for a list of CAIDs. Athena has a maximum character limit of 262144 in queries. CAIDs are about 12 characters long on average + 4 for two quotes, a comma and a space. Chunk our list @@ -94,36 +95,33 @@ def gnomad_variant_data_for_caids(caids: Sequence[str]) -> Sequence[Row[Any]]: caid_strs = [",".join(f"'{caid}'" for caid in chunk) for chunk in chunked_caids] save_to_logging_context({"num_caids": len(caids), "num_chunks": len(caid_strs)}) - with athena_engine.connect() as athena_connection: - logger.debug(msg="Connected to Athena", extra=logging_context()) - - result_rows: list[Row[Any]] = [] - for chunk_index, caid_str in enumerate(caid_strs): - athena_query = f""" - SELECT - "locus.contig", - "locus.position", - "alleles", - "caid", - "joint.freq.all.ac", - "joint.freq.all.an", - "joint.fafmax.faf95_max_gen_anc", - "joint.fafmax.faf95_max" - FROM - {gnomad_table_name()} - WHERE - caid IN ({caid_str}) - """ - logger.debug( - msg=f"Fetching gnomAD variants from Athena (batch {chunk_index}) with query:\n{athena_query}", - extra=logging_context(), - ) + result_rows: list[Row[Any]] = [] + for chunk_index, caid_str in enumerate(caid_strs): + athena_query = f""" + SELECT + "locus.contig", + "locus.position", + "alleles", + "caid", + "joint.freq.all.ac", + "joint.freq.all.an", + "joint.fafmax.faf95_max_gen_anc", + "joint.fafmax.faf95_max" + FROM + {gnomad_table_name()} + WHERE + caid IN ({caid_str}) + """ + logger.debug( + msg=f"Fetching gnomAD variants from Athena (batch {chunk_index}) with query:\n{athena_query}", + extra=logging_context(), + ) - result = athena_connection.execute(text(athena_query)) - rows = result.fetchall() - result_rows.extend(rows) + result = athena_session.execute(text(athena_query)) + rows = result.fetchall() + result_rows.extend(rows) - logger.debug(f"Fetched {len(rows)} gnomAD variants from Athena (batch {chunk_index}).") + logger.debug(f"Fetched {len(rows)} gnomAD variants from Athena (batch {chunk_index}).") save_to_logging_context({"num_gnomad_variant_rows_fetched": len(result_rows)}) logger.debug(msg="Done fetching gnomAD variants from Athena", extra=logging_context()) diff --git a/src/mavedb/scripts/link_gnomad_variants.py b/src/mavedb/scripts/link_gnomad_variants.py index e7f0fa49..d910ea59 100644 --- a/src/mavedb/scripts/link_gnomad_variants.py +++ b/src/mavedb/scripts/link_gnomad_variants.py @@ -5,13 +5,13 @@ from sqlalchemy import select from sqlalchemy.orm import Session +from mavedb.db import athena from mavedb.lib.gnomad import gnomad_variant_data_for_caids, link_gnomad_variants_to_mapped_variants -from mavedb.models.score_set import ScoreSet from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.score_set import ScoreSet from mavedb.models.variant import Variant from mavedb.scripts.environment import with_database_session - logger = logging.getLogger(__name__) @@ -62,7 +62,9 @@ def link_gnomad_variants(db: Session, score_set_urn: list[str], all_score_sets: logger.info(f"Found {len(caids)} CAIDs for the selected score sets to link to gnomAD variants.") # 2. Query Athena for gnomAD variants matching the CAIDs - gnomad_variant_data = gnomad_variant_data_for_caids(caids) + with athena.engine.connect() as athena_session: + logger.debug("Fetching gnomAD variants from Athena.") + gnomad_variant_data = gnomad_variant_data_for_caids(athena_session, caids) if not gnomad_variant_data: logger.error("No gnomAD records found for the provided CAIDs.") diff --git a/src/mavedb/worker/jobs/external_services/gnomad.py b/src/mavedb/worker/jobs/external_services/gnomad.py index e045d247..b63b1be6 100644 --- a/src/mavedb/worker/jobs/external_services/gnomad.py +++ b/src/mavedb/worker/jobs/external_services/gnomad.py @@ -11,6 +11,7 @@ from sqlalchemy import select +from mavedb.db import athena from mavedb.lib.gnomad import gnomad_variant_data_for_caids, link_gnomad_variants_to_mapped_variants from mavedb.models.mapped_variant import MappedVariant from mavedb.models.score_set import ScoreSet @@ -24,7 +25,7 @@ @with_pipeline_management -async def link_gnomad_variants(ctx: dict, job_manager: JobManager) -> JobResultData: +async def link_gnomad_variants(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData: """ Link mapped variants to gnomAD variants based on ClinGen Allele IDs (CAIDs). This job fetches mapped variants associated with a given score set that have CAIDs, @@ -37,7 +38,8 @@ async def link_gnomad_variants(ctx: dict, job_manager: JobManager) -> JobResultD Args: ctx (dict): The job context dictionary. - job_manager (JobManager): Manager for job lifecycle and DB operations. + job_id (int): The ID of the job being executed. + job_manager (JobManager): The job manager instance for database and logging operations. Side Effects: - Updates MappedVariant records to link to gnomAD variants. @@ -49,7 +51,7 @@ async def link_gnomad_variants(ctx: dict, job_manager: JobManager) -> JobResultD job = job_manager.get_job() _job_required_params = ["score_set_id", "correlation_id"] - validate_job_params(job_manager, _job_required_params, job) + validate_job_params(_job_required_params, job) # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one() # type: ignore @@ -97,7 +99,10 @@ async def link_gnomad_variants(ctx: dict, job_manager: JobManager) -> JobResultD ) # Fetch gnomAD variant data for the CAIDs - gnomad_variant_data = gnomad_variant_data_for_caids(variant_caids) + with athena.engine.connect() as athena_session: + logger.debug("Fetching gnomAD variants from Athena.") + gnomad_variant_data = gnomad_variant_data_for_caids(athena_session, variant_caids) + num_gnomad_variants_with_caid_match = len(gnomad_variant_data) job_manager.save_to_context({"num_gnomad_variants_with_caid_match": num_gnomad_variants_with_caid_match}) diff --git a/tests/conftest.py b/tests/conftest.py index 0cb869fd..f745fe20 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,11 +8,12 @@ import pytest import pytest_postgresql import pytest_socket -from sqlalchemy import create_engine, text +from sqlalchemy import Column, Float, Integer, MetaData, String, Table, create_engine, text from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import NullPool from mavedb.db.base import Base +from mavedb.lib.gnomad import gnomad_table_name from mavedb.models import * # noqa: F403 from mavedb.models.experiment import Experiment from mavedb.models.experiment_set import ExperimentSet @@ -105,6 +106,56 @@ def session(postgresql): Base.metadata.drop_all(bind=engine) +@pytest.fixture +def athena_engine(): + """Create and yield a SQLAlchemy engine connected to a mock Athena database.""" + engine = create_engine("sqlite:///:memory:") + metadata = MetaData() + + # TODO: Define your table schema here + my_table = Table( + gnomad_table_name(), + metadata, + Column("id", Integer, primary_key=True), + Column("locus.contig", String), + Column("locus.position", Integer), + Column("alleles", String), + Column("caid", String), + Column("joint.freq.all.ac", Integer), + Column("joint.freq.all.an", Integer), + Column("joint.fafmax.faf95_max_gen_anc", String), + Column("joint.fafmax.faf95_max", Float), + ) + metadata.create_all(engine) + + session = sessionmaker(autocommit=False, autoflush=False, bind=engine)() + + # Insert test data + session.execute( + my_table.insert(), + [ + { + "id": 1, + "locus.contig": "chr1", + "locus.position": 12345, + "alleles": "[G, A]", + "caid": "CA123", + "joint.freq.all.ac": 23, + "joint.freq.all.an": 32432423, + "joint.fafmax.faf95_max_gen_anc": "anc1", + "joint.fafmax.faf95_max": 0.000006763700000000002, + } + ], + ) + session.commit() + session.close() + + try: + yield engine + finally: + engine.dispose() + + @pytest.fixture def setup_lib_db(session): """ diff --git a/tests/worker/jobs/external_services/conftest.py b/tests/worker/jobs/external_services/conftest.py new file mode 100644 index 00000000..ff275357 --- /dev/null +++ b/tests/worker/jobs/external_services/conftest.py @@ -0,0 +1,99 @@ +import pytest + +from mavedb.models.job_run import JobRun +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.pipeline import Pipeline +from mavedb.models.score_set import ScoreSet +from mavedb.models.variant import Variant + + +@pytest.fixture +def link_gnomad_variants_sample_params(with_populated_domain_data, sample_score_set): + """Provide sample parameters for create_variants_for_score_set job.""" + + return { + "correlation_id": "sample-correlation-id", + "score_set_id": sample_score_set.id, + } + + +@pytest.fixture +def sample_link_gnomad_variants_pipeline(): + """Create a pipeline instance for link_gnomad_variants job.""" + + return Pipeline( + urn="test:link_gnomad_variants_pipeline", + name="Link gnomAD Variants Pipeline", + ) + + +@pytest.fixture +def sample_link_gnomad_variants_run(link_gnomad_variants_sample_params): + """Create a JobRun instance for link_gnomad_variants job.""" + + return JobRun( + urn="test:link_gnomad_variants", + job_type="link_gnomad_variants", + job_function="link_gnomad_variants", + max_retries=3, + retry_count=0, + job_params=link_gnomad_variants_sample_params, + ) + + +@pytest.fixture +def with_gnomad_linking_job(session, sample_link_gnomad_variants_run): + """Add a link_gnomad_variants job run to the session.""" + + session.add(sample_link_gnomad_variants_run) + session.commit() + + +@pytest.fixture +def with_gnomad_linking_pipeline(session, sample_link_gnomad_variants_pipeline): + """Add a link_gnomad_variants pipeline to the session.""" + + session.add(sample_link_gnomad_variants_pipeline) + session.commit() + + +@pytest.fixture +def sample_link_gnomad_variants_run_pipeline( + session, + with_gnomad_linking_job, + with_gnomad_linking_pipeline, + sample_link_gnomad_variants_run, + sample_link_gnomad_variants_pipeline, +): + """Provide a context with a link_gnomad_variants job run and pipeline.""" + + sample_link_gnomad_variants_run.pipeline_id = sample_link_gnomad_variants_pipeline.id + session.commit() + return sample_link_gnomad_variants_run + + +@pytest.fixture +def setup_sample_variants_with_caid(with_populated_domain_data, mock_worker_ctx, sample_link_gnomad_variants_run): + """Setup variants and mapped variants in the database for testing.""" + session = mock_worker_ctx["db"] + score_set = session.get(ScoreSet, sample_link_gnomad_variants_run.job_params["score_set_id"]) + + # Add a variant and mapped variant to the database with a CAID + variant = Variant( + urn="urn:variant:test-variant-with-caid", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.1A>G", + hgvs_pro="NP_000000.1:p.Met1Val", + data={"hgvs_c": "NM_000000.1:c.1A>G", "hgvs_p": "NP_000000.1:p.Met1Val"}, + ) + session.add(variant) + session.commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + clingen_allele_id="CA123", + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant) + session.commit() diff --git a/tests/network/worker/test_gnomad.py b/tests/worker/jobs/external_services/network/test_gnomad.py similarity index 100% rename from tests/network/worker/test_gnomad.py rename to tests/worker/jobs/external_services/network/test_gnomad.py diff --git a/tests/worker/jobs/external_services/test_gnomad.py b/tests/worker/jobs/external_services/test_gnomad.py index e69de29b..81b4e3ae 100644 --- a/tests/worker/jobs/external_services/test_gnomad.py +++ b/tests/worker/jobs/external_services/test_gnomad.py @@ -0,0 +1,461 @@ +from unittest.mock import MagicMock, call, patch + +import pytest + +from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus +from mavedb.models.gnomad_variant import GnomADVariant +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.score_set import ScoreSet +from mavedb.models.variant import Variant +from mavedb.worker.jobs.external_services.gnomad import link_gnomad_variants +from mavedb.worker.lib.managers.job_manager import JobManager + + +@pytest.mark.asyncio +@pytest.mark.unit +class TestLinkGnomadVariantsUnit: + """Unit tests for the link_gnomad_variants job.""" + + @pytest.fixture + def setup_sample_variants_with_caid( + self, with_populated_domain_data, mock_worker_ctx, sample_link_gnomad_variants_run + ): + """Setup variants and mapped variants in the database for testing.""" + session = mock_worker_ctx["db"] + score_set = session.get(ScoreSet, sample_link_gnomad_variants_run.job_params["score_set_id"]) + + # Add a variant and mapped variant to the database with a CAID + variant = Variant( + urn="urn:variant:test-variant-with-caid", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.1A>G", + hgvs_pro="NP_000000.1:p.Met1Val", + data={"hgvs_c": "NM_000000.1:c.1A>G", "hgvs_p": "NP_000000.1:p.Met1Val"}, + ) + session.add(variant) + session.commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + clingen_allele_id="CA123", + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant) + session.commit() + + async def test_link_gnomad_variants_no_variants_with_caids( + self, + with_populated_domain_data, + with_gnomad_linking_job, + mock_worker_ctx, + sample_link_gnomad_variants_run, + ): + """Test linking gnomAD variants when no mapped variants have CAIDs.""" + with patch.object(JobManager, "update_progress") as mock_update_progress: + result = await link_gnomad_variants( + mock_worker_ctx, + 1, + JobManager(mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_link_gnomad_variants_run.id), + ) + + assert result["status"] == "ok" + mock_update_progress.assert_any_call( + 100, 100, "No variants with CAIDs found to link to gnomAD variants. Nothing to do." + ) + + async def test_link_gnomad_variants_no_gnomad_matches( + self, + with_populated_domain_data, + with_gnomad_linking_job, + mock_worker_ctx, + sample_link_gnomad_variants_run, + setup_sample_variants_with_caid, + ): + """Test linking gnomAD variants when no gnomAD variants match the CAIDs.""" + + with ( + patch.object(JobManager, "update_progress") as mock_update_progress, + patch( + "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", + return_value={}, + ), + ): + result = await link_gnomad_variants( + mock_worker_ctx, + 1, + JobManager(mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_link_gnomad_variants_run.id), + ) + + assert result["status"] == "ok" + mock_update_progress.assert_any_call(100, 100, "No gnomAD variants with CAID matches found. Nothing to link.") + + async def test_link_gnomad_variants_call_linking_method( + self, + with_populated_domain_data, + with_gnomad_linking_job, + mock_worker_ctx, + sample_link_gnomad_variants_run, + setup_sample_variants_with_caid, + ): + """Test that the linking method is called when gnomAD variants match CAIDs.""" + + with ( + patch.object(JobManager, "update_progress") as mock_update_progress, + patch( + "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", + return_value=[MagicMock()], + ), + patch( + "mavedb.worker.jobs.external_services.gnomad.link_gnomad_variants_to_mapped_variants", + return_value=1, + ) as mock_linking_method, + ): + result = await link_gnomad_variants( + mock_worker_ctx, + 1, + JobManager(mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_link_gnomad_variants_run.id), + ) + + assert result["status"] == "ok" + mock_linking_method.assert_called_once() + mock_update_progress.assert_any_call(100, 100, "Linked 1 mapped variants to gnomAD variants.") + + async def test_link_gnomad_variants_updates_progress( + self, + with_populated_domain_data, + with_gnomad_linking_job, + mock_worker_ctx, + sample_link_gnomad_variants_run, + setup_sample_variants_with_caid, + ): + """Test that progress updates are made during the linking process.""" + + with ( + patch.object(JobManager, "update_progress") as mock_update_progress, + patch( + "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", + return_value=[MagicMock()], + ), + patch( + "mavedb.worker.jobs.external_services.gnomad.link_gnomad_variants_to_mapped_variants", + return_value=1, + ), + ): + result = await link_gnomad_variants( + mock_worker_ctx, + 1, + JobManager(mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_link_gnomad_variants_run.id), + ) + + assert result["status"] == "ok" + mock_update_progress.assert_has_calls( + [ + call(0, 100, "Starting gnomAD mapped resource linkage."), + call(10, 100, "Found 1 variants with CAIDs to link to gnomAD variants."), + call(75, 100, "Found 1 gnomAD variants matching CAIDs."), + call(100, 100, "Linked 1 mapped variants to gnomAD variants."), + ] + ) + + async def test_link_gnomad_variants_propagates_exceptions( + self, + with_populated_domain_data, + with_gnomad_linking_job, + mock_worker_ctx, + sample_link_gnomad_variants_run, + setup_sample_variants_with_caid, + ): + """Test that exceptions during the linking process are propagated.""" + with patch( + "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", + side_effect=Exception("Test exception"), + ): + with pytest.raises(Exception) as exc_info: + await link_gnomad_variants( + mock_worker_ctx, + 1, + JobManager(mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_link_gnomad_variants_run.id), + ) + + assert str(exc_info.value) == "Test exception" + + +@pytest.mark.asyncio +@pytest.mark.integration +class TestLinkGnomadVariantsIntegration: + """Integration tests for the link_gnomad_variants job.""" + + async def test_link_gnomad_variants_no_variants_with_caids( + self, + with_populated_domain_data, + with_gnomad_linking_job, + mock_worker_ctx, + sample_link_gnomad_variants_run, + ): + """Test the end-to-end functionality of the link_gnomad_variants job when no variants have CAIDs.""" + + result = await link_gnomad_variants(mock_worker_ctx, sample_link_gnomad_variants_run.id) + assert result["status"] == "ok" + + # Verify that no gnomAD variants were linked + session = mock_worker_ctx["db"] + gnomad_variants = session.query(GnomADVariant).all() + assert len(gnomad_variants) == 0 + + # Verify job status updates + session.refresh(sample_link_gnomad_variants_run) + assert sample_link_gnomad_variants_run.status == JobStatus.SUCCEEDED + + async def test_link_gnomad_variants_no_matching_caids( + self, + with_populated_domain_data, + with_gnomad_linking_job, + mock_worker_ctx, + sample_link_gnomad_variants_run, + setup_sample_variants_with_caid, + athena_engine, + ): + """Test the end-to-end functionality of the link_gnomad_variants job when no matching CAIDs are found.""" + # Update the created mapped variant to have a CAID that won't match any gnomAD data + session = mock_worker_ctx["db"] + mapped_variant = session.query(MappedVariant).first() + mapped_variant.clingen_allele_id = "NON_MATCHING_CAID" + session.commit() + + # Patch the athena engine to use the mock athena_engine fixture + with patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine): + result = await link_gnomad_variants(mock_worker_ctx, sample_link_gnomad_variants_run.id) + + assert result["status"] == "ok" + + # Verify that no gnomAD variants were linked + session = mock_worker_ctx["db"] + gnomad_variants = session.query(GnomADVariant).all() + assert len(gnomad_variants) == 0 + + # Verify job status updates + session.refresh(sample_link_gnomad_variants_run) + assert sample_link_gnomad_variants_run.status == JobStatus.SUCCEEDED + + async def test_link_gnomad_variants_successful_linking_independent( + self, + with_populated_domain_data, + with_gnomad_linking_job, + mock_worker_ctx, + sample_link_gnomad_variants_run, + setup_sample_variants_with_caid, + athena_engine, + ): + """Test the end-to-end functionality of the link_gnomad_variants job with successful linking.""" + + # Patch the athena engine to use the mock athena_engine fixture + with patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine): + result = await link_gnomad_variants(mock_worker_ctx, sample_link_gnomad_variants_run.id) + + assert result["status"] == "ok" + + # Verify that gnomAD variants were linked + session = mock_worker_ctx["db"] + gnomad_variants = session.query(GnomADVariant).all() + assert len(gnomad_variants) > 0 + + # Verify job status updates + session.refresh(sample_link_gnomad_variants_run) + assert sample_link_gnomad_variants_run.status == JobStatus.SUCCEEDED + + async def test_link_gnomad_variants_successful_linking_pipeline( + self, + with_populated_domain_data, + mock_worker_ctx, + sample_link_gnomad_variants_run_pipeline, + sample_link_gnomad_variants_pipeline, + setup_sample_variants_with_caid, + athena_engine, + ): + """Test the end-to-end functionality of the link_gnomad_variants job with successful linking in a pipeline.""" + + # Patch the athena engine to use the mock athena_engine fixture + with patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine): + result = await link_gnomad_variants(mock_worker_ctx, sample_link_gnomad_variants_run_pipeline.id) + + assert result["status"] == "ok" + + # Verify that gnomAD variants were linked + session = mock_worker_ctx["db"] + gnomad_variants = session.query(GnomADVariant).all() + assert len(gnomad_variants) > 0 + + # Verify job status updates + session.refresh(sample_link_gnomad_variants_run_pipeline) + assert sample_link_gnomad_variants_run_pipeline.status == JobStatus.SUCCEEDED + + # Verify pipeline status updates + session.refresh(sample_link_gnomad_variants_pipeline) + assert sample_link_gnomad_variants_pipeline.status == PipelineStatus.SUCCEEDED + + async def test_link_gnomad_variants_exceptions_handled_by_decorators( + self, + with_populated_domain_data, + with_gnomad_linking_job, + mock_worker_ctx, + sample_link_gnomad_variants_run, + setup_sample_variants_with_caid, + athena_engine, + ): + """Test that exceptions during the linking process are handled by decorators.""" + + # Patch the athena engine to use the mock athena_engine fixture + with ( + patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine), + patch( + "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", + side_effect=Exception("Test exception"), + ), + ): + result = await link_gnomad_variants( + mock_worker_ctx, + sample_link_gnomad_variants_run.id, + ) + + assert result["status"] == "failed" + assert "Test exception" in result["exception_details"]["message"] + + # Verify job status updates + session = mock_worker_ctx["db"] + session.refresh(sample_link_gnomad_variants_run) + assert sample_link_gnomad_variants_run.status == JobStatus.FAILED + + +@pytest.mark.asyncio +@pytest.mark.integration +class TestLinkGnomadVariantsArqContext: + """Tests for link_gnomad_variants job using the ARQ context fixture.""" + + async def test_link_gnomad_variants_with_arq_context_independent( + self, + arq_redis, + arq_worker, + session, + with_populated_domain_data, + with_gnomad_linking_job, + athena_engine, + sample_link_gnomad_variants_run, + setup_sample_variants_with_caid, + ): + """Test that the link_gnomad_variants job works with the ARQ context fixture.""" + + with ( + patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine), + ): + await arq_redis.enqueue_job("link_gnomad_variants", sample_link_gnomad_variants_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that gnomAD variants were linked + gnomad_variants = session.query(GnomADVariant).all() + assert len(gnomad_variants) > 0 + + # Verify that the job completed successfully + session.refresh(sample_link_gnomad_variants_run) + assert sample_link_gnomad_variants_run.status == JobStatus.SUCCEEDED + + async def test_link_gnomad_variants_with_arq_context_pipeline( + self, + arq_redis, + arq_worker, + session, + with_populated_domain_data, + athena_engine, + sample_link_gnomad_variants_run_pipeline, + sample_link_gnomad_variants_pipeline, + setup_sample_variants_with_caid, + ): + """Test that the link_gnomad_variants job works with the ARQ context fixture in a pipeline.""" + + with ( + patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine), + ): + await arq_redis.enqueue_job("link_gnomad_variants", sample_link_gnomad_variants_run_pipeline.id) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that gnomAD variants were linked + gnomad_variants = session.query(GnomADVariant).all() + assert len(gnomad_variants) > 0 + + # Verify that the job completed successfully + session.refresh(sample_link_gnomad_variants_run_pipeline) + assert sample_link_gnomad_variants_run_pipeline.status == JobStatus.SUCCEEDED + + # Verify pipeline status updates + session.refresh(sample_link_gnomad_variants_pipeline) + assert sample_link_gnomad_variants_pipeline.status == PipelineStatus.SUCCEEDED + + async def test_link_gnomad_variants_with_arq_context_exception_handling_independent( + self, + arq_redis, + arq_worker, + session, + with_populated_domain_data, + with_gnomad_linking_job, + athena_engine, + sample_link_gnomad_variants_run, + setup_sample_variants_with_caid, + ): + """Test that exceptions in the link_gnomad_variants job are handled with the ARQ context fixture.""" + + with ( + patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine), + patch( + "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", + side_effect=Exception("Test exception"), + ), + ): + await arq_redis.enqueue_job("link_gnomad_variants", sample_link_gnomad_variants_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that no gnomAD variants were linked + gnomad_variants = session.query(GnomADVariant).all() + assert len(gnomad_variants) == 0 + + # Verify that the job failed + session.refresh(sample_link_gnomad_variants_run) + assert sample_link_gnomad_variants_run.status == JobStatus.FAILED + + async def test_link_gnomad_variants_with_arq_context_exception_handling_pipeline( + self, + arq_redis, + arq_worker, + session, + with_populated_domain_data, + athena_engine, + sample_link_gnomad_variants_pipeline, + sample_link_gnomad_variants_run_pipeline, + setup_sample_variants_with_caid, + ): + """Test that exceptions in the link_gnomad_variants job are handled with the ARQ context fixture.""" + + with ( + patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine), + patch( + "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", + side_effect=Exception("Test exception"), + ), + ): + await arq_redis.enqueue_job("link_gnomad_variants", sample_link_gnomad_variants_run_pipeline.id) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that no gnomAD variants were linked + gnomad_variants = session.query(GnomADVariant).all() + assert len(gnomad_variants) == 0 + + # Verify that the job failed + session.refresh(sample_link_gnomad_variants_run_pipeline) + assert sample_link_gnomad_variants_run_pipeline.status == JobStatus.FAILED + + # Verify that the pipeline failed + session.refresh(sample_link_gnomad_variants_pipeline) + assert sample_link_gnomad_variants_pipeline.status == PipelineStatus.FAILED From a301f2d205cde6659775e35b3d0e9e97a86fd473 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Mon, 26 Jan 2026 20:09:19 -0800 Subject: [PATCH 32/70] feat: uniprot managed job tests and enhancements Adds comprehensive test cases for uniprot managed jobs and tweaks logic to support testing. Adds e2e testing for API methods with limited and marked network tests. --- src/mavedb/lib/exceptions.py | 18 + .../worker/jobs/external_services/uniprot.py | 198 +- tests/network/worker/test_uniprot.py | 0 .../worker/jobs/external_services/conftest.py | 266 +++ .../external_services/network/test_uniprot.py | 60 + .../jobs/external_services/test_uniprot.py | 2014 +++++++++++++++++ 6 files changed, 2493 insertions(+), 63 deletions(-) delete mode 100644 tests/network/worker/test_uniprot.py create mode 100644 tests/worker/jobs/external_services/network/test_uniprot.py diff --git a/src/mavedb/lib/exceptions.py b/src/mavedb/lib/exceptions.py index aae550d4..db7458f1 100644 --- a/src/mavedb/lib/exceptions.py +++ b/src/mavedb/lib/exceptions.py @@ -208,3 +208,21 @@ class UniProtPollingEnqueueError(ValueError): """Raised when a UniProt ID polling job fails to be enqueued despite appearing as if it should have been""" pass + + +class UniprotMappingResultNotFoundError(ValueError): + """Raised when no UniProt ID is found in the mapping results for a target gene.""" + + pass + + +class UniprotAmbiguousMappingResultError(ValueError): + """Raised when ambiguous UniProt IDs are found in the mapping results for a target gene.""" + + pass + + +class NonExistentTargetGeneError(ValueError): + """Raised when a target gene does not exist in the database.""" + + pass diff --git a/src/mavedb/worker/jobs/external_services/uniprot.py b/src/mavedb/worker/jobs/external_services/uniprot.py index 713cd60f..fccfdadf 100644 --- a/src/mavedb/worker/jobs/external_services/uniprot.py +++ b/src/mavedb/worker/jobs/external_services/uniprot.py @@ -9,12 +9,18 @@ """ import logging +from typing import Optional, TypedDict from sqlalchemy import select - -from mavedb.lib.exceptions import UniProtPollingEnqueueError +from sqlalchemy.orm.attributes import flag_modified + +from mavedb.lib.exceptions import ( + NonExistentTargetGeneError, + UniprotAmbiguousMappingResultError, + UniprotMappingResultNotFoundError, + UniProtPollingEnqueueError, +) from mavedb.lib.mapping import extract_ids_from_post_mapped_metadata -from mavedb.lib.slack import log_and_send_slack_message from mavedb.lib.uniprot.id_mapping import UniProtIDMappingAPI from mavedb.lib.uniprot.utils import infer_db_name_from_sequence_accession from mavedb.models.job_dependency import JobDependency @@ -27,16 +33,30 @@ logger = logging.getLogger(__name__) +class MappingJob(TypedDict): + job_id: Optional[str] + accession: str + + @with_pipeline_management -async def submit_uniprot_mapping_jobs_for_score_set(ctx: dict, job_manager: JobManager) -> JobResultData: +async def submit_uniprot_mapping_jobs_for_score_set(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData: """Submit UniProt ID mapping jobs for all target genes in a given ScoreSet. + NOTE: This function assumes that a dependent polling job has already been created + for the same ScoreSet. It is the responsibility of this function to ensure that + the polling job exists and to set the `mapping_jobs` parameter on the polling job. + + Without running the polling job, the results of the submitted UniProt mapping jobs + will never be retrieved or processed, so running this function alone is insufficient + to complete the UniProt mapping workflow. + Job Parameters: - score_set_id (int): The ID of the ScoreSet containing target genes to map. - correlation_id (str): Correlation ID for tracing requests across services. Args: ctx (dict): The job context dictionary. + job_id (int): The ID of the job being executed. job_manager (JobManager): Manager for job lifecycle and DB operations. Side Effects: @@ -45,6 +65,9 @@ async def submit_uniprot_mapping_jobs_for_score_set(ctx: dict, job_manager: JobM Sets the parameter `mapping_jobs` on the polling job with a dictionary of target gene IDs to UniProt job IDs. TODO#XXX: Split mapping jobs into one per target gene so that polling can be more granular. + Raises: + - UniProtPollingEnqueueError: If the dependent polling job cannot be found. + Returns: dict: Result indicating success and any exception details """ @@ -52,7 +75,7 @@ async def submit_uniprot_mapping_jobs_for_score_set(ctx: dict, job_manager: JobM job = job_manager.get_job() _job_required_params = ["score_set_id", "correlation_id"] - validate_job_params(job_manager, _job_required_params, job) + validate_job_params(_job_required_params, job) # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one() # type: ignore @@ -70,76 +93,107 @@ async def submit_uniprot_mapping_jobs_for_score_set(ctx: dict, job_manager: JobM job_manager.update_progress(0, 100, "Starting UniProt mapping job submission.") logger.info(msg="Started UniProt mapping job submission", extra=job_manager.logging_context()) - if not score_set or not score_set.target_genes: + # Preset submitted jobs metadata so it persists even if no jobs are submitted. + job.metadata_["submitted_jobs"] = {} + job_manager.db.commit() + + if not score_set.target_genes: job_manager.update_progress(100, 100, "No target genes found. Skipped UniProt mapping job submission.") - msg = f"No target genes for score set {score_set.id}. Skipped mapping targets to UniProt." - log_and_send_slack_message(msg=msg, ctx=job_manager.logging_context(), level=logging.WARNING) + logger.error( + msg=f"No target genes found for score set {score_set.urn}. Skipped UniProt mapping job submission.", + extra=job_manager.logging_context(), + ) + return {"status": "ok", "data": {}, "exception_details": None} uniprot_api = UniProtIDMappingAPI() job_manager.save_to_context({"total_target_genes_to_map_to_uniprot": len(score_set.target_genes)}) - mapping_jobs = {} + mapping_jobs: dict[str, MappingJob] = {} for idx, target_gene in enumerate(score_set.target_genes): acs = extract_ids_from_post_mapped_metadata(target_gene.post_mapped_metadata) # type: ignore if not acs: - msg = f"No accession IDs found in post_mapped_metadata for target gene {target_gene.id} in score set {score_set.urn}. This target will be skipped." - log_and_send_slack_message(msg, job_manager.logging_context(), logging.WARNING) + logger.warning( + msg=f"No accession IDs found in post_mapped_metadata for target gene {target_gene.id} in score set {score_set.urn}. Skipped mapping this target.", + extra=job_manager.logging_context(), + ) continue if len(acs) != 1: - msg = f"More than one accession ID is associated with target gene {target_gene.id} in score set {score_set.urn}. This target will be skipped." - log_and_send_slack_message(msg, job_manager.logging_context(), logging.WARNING) + logger.warning( + msg=f"More than one accession ID is associated with target gene {target_gene.id} in score set {score_set.urn}. Skipped mapping this target.", + extra=job_manager.logging_context(), + ) continue ac_to_map = acs[0] from_db = infer_db_name_from_sequence_accession(ac_to_map) spawned_job = uniprot_api.submit_id_mapping(from_db, "UniProtKB", [ac_to_map]) # type: ignore - mapping_jobs[target_gene.id] = {"job_id": spawned_job, "accession_mapped": ac_to_map} + + # Explicitly cast ints to strs in mapping job keys. These are converted to strings internally + # by SQLAlchemy when storing job_params as JSON, so be explicit here to avoid confusion. + mapping_jobs[str(target_gene.id)] = {"job_id": spawned_job, "accession": ac_to_map} job_manager.save_to_context( { "submitted_uniprot_mapping_jobs": { **job_manager.logging_context().get("submitted_uniprot_mapping_jobs", {}), - target_gene.id: mapping_jobs[target_gene.id], + str(target_gene.id): mapping_jobs[str(target_gene.id)], } } ) - logger.info( - msg=f"Submitted UniProt ID mapping job for target gene {target_gene.id}.", - extra=job_manager.logging_context(), - ) job_manager.update_progress( - int((idx + 1 / len(score_set.target_genes)) * 100), + int((idx + 1 / len(score_set.target_genes)) * 95), 100, f"Submitted UniProt mapping job for target gene {target_gene.name}.", ) + logger.info( + msg=f"Submitted UniProt ID mapping job for target gene {target_gene.id}.", + extra=job_manager.logging_context(), + ) - # Set mapping jobs on dependent polling job. Only one polling job per score set should be created. + # Save submitted jobs to job metadata for auditing purposes + job.metadata_["submitted_jobs"] = mapping_jobs + flag_modified(job, "metadata_") + job_manager.db.commit() + + # If no mapping jobs were submitted, log and exit early. + if not mapping_jobs or not any((job_info["job_id"] for job_info in mapping_jobs.values())): + job_manager.update_progress(100, 100, "No UniProt mapping jobs were submitted.") + logger.warning(msg="No UniProt mapping jobs were submitted.", extra=job_manager.logging_context()) + + return {"status": "ok", "data": {}, "exception_details": None} + + # It's an essential responsibility of the submit job (when submissions exist) to ensure that the polling job exists. dependent_polling_job = job_manager.db.scalars( select(JobDependency).where(JobDependency.depends_on_job_id == job.id) ).all() - if not dependent_polling_job or len(dependent_polling_job) != 1: + job_manager.update_progress(100, 100, "Failed to submit UniProt mapping jobs.") + logger.error( + msg=f"Could not find unique dependent polling job for UniProt mapping job {job.id}.", + extra=job_manager.logging_context(), + ) + raise UniProtPollingEnqueueError( f"Could not find unique dependent polling job for UniProt mapping job {job.id}." ) + # Set mapping jobs on dependent polling job. Only one polling job per score set should be created. polling_job = dependent_polling_job[0].job_run polling_job.job_params = { **(polling_job.job_params or {}), - "mapping_jobs": { - target_gene_id: mapping_info["job_id"] for target_gene_id, mapping_info in mapping_jobs.items() - }, + "mapping_jobs": mapping_jobs, } - job_manager.db.add(polling_job) + job_manager.update_progress(100, 100, "Completed submission of UniProt mapping jobs.") + logger.info(msg="Completed UniProt mapping job submission", extra=job_manager.logging_context()) job_manager.db.commit() return {"status": "ok", "data": {}, "exception_details": None} @with_pipeline_management -async def poll_uniprot_mapping_jobs_for_score_set(ctx: dict, job_manager: JobManager) -> JobResultData: +async def poll_uniprot_mapping_jobs_for_score_set(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData: """Submit UniProt ID mapping jobs for all target genes in a given ScoreSet. Job Parameters: @@ -149,8 +203,13 @@ async def poll_uniprot_mapping_jobs_for_score_set(ctx: dict, job_manager: JobMan Args: ctx (dict): The job context dictionary. + job_id (int): The ID of the job being processed. job_manager (JobManager): Manager for job lifecycle and DB operations. + Side Effects: + - Polls UniProt ID mapping jobs for each target gene in the ScoreSet. + - Updates target genes with mapped UniProt IDs in the database. + TODO#XXX: Split mapping jobs into one per target gene so that polling can be more granular. Returns: @@ -160,12 +219,12 @@ async def poll_uniprot_mapping_jobs_for_score_set(ctx: dict, job_manager: JobMan job = job_manager.get_job() _job_required_params = ["score_set_id", "correlation_id", "mapping_jobs"] - validate_job_params(job_manager, _job_required_params, job) + validate_job_params(_job_required_params, job) # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one() # type: ignore correlation_id = job.job_params["correlation_id"] # type: ignore - mapping_jobs = job.job_params.get("mapping_jobs", {}) # type: ignore + mapping_jobs: dict[str, MappingJob] = job.job_params.get("mapping_jobs", {}) # type: ignore # Setup initial context and progress job_manager.save_to_context( @@ -179,54 +238,67 @@ async def poll_uniprot_mapping_jobs_for_score_set(ctx: dict, job_manager: JobMan job_manager.update_progress(0, 100, "Starting UniProt mapping job polling.") logger.info(msg="Started UniProt mapping job polling", extra=job_manager.logging_context()) - if not score_set or not score_set.target_genes: - msg = f"No target genes for score set {score_set.id}. Skipped polling targets for UniProt mapping results." - log_and_send_slack_message(msg=msg, ctx=job_manager.logging_context(), level=logging.WARNING) - + if not mapping_jobs or not any(mapping_jobs.values()): + job_manager.update_progress(100, 100, "No mapping jobs found to poll.") + logger.warning( + msg=f"No mapping jobs found in job parameters for polling UniProt mapping jobs for score set {score_set.urn}.", + extra=job_manager.logging_context(), + ) return {"status": "ok", "data": {}, "exception_details": None} # Poll each mapping job and update target genes with UniProt IDs uniprot_api = UniProtIDMappingAPI() - for target_gene in score_set.target_genes: - acs = extract_ids_from_post_mapped_metadata(target_gene.post_mapped_metadata) # type: ignore - if not acs: - msg = f"No accession IDs found in post_mapped_metadata for target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target." - log_and_send_slack_message(msg, job_manager.logging_context(), logging.WARNING) - continue - - if len(acs) != 1: - msg = f"More than one accession ID is associated with target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target." - log_and_send_slack_message(msg, job_manager.logging_context(), logging.WARNING) - continue - - mapped_ac = acs[0] - job_id = mapping_jobs.get(target_gene.id) # type: ignore - - if not job_id: - msg = f"No job ID found for target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target." - # This issue has already been sent to Slack in the job submission function, so we just log it here. - logger.debug(msg=msg, extra=job_manager.logging_context()) + for target_gene_id, mapping_job in mapping_jobs.items(): + mapping_job_id = mapping_job["job_id"] + + if not mapping_job_id: + logger.warning( + msg=f"No UniProt mapping job ID found for target gene ID {target_gene_id}. Skipped polling this job.", + extra=job_manager.logging_context(), + ) continue - if not uniprot_api.check_id_mapping_results_ready(job_id): - msg = f"Job {job_id} not ready for target gene {target_gene.id} in score set {score_set.urn}. Skipped polling this target" - log_and_send_slack_message(msg, job_manager.logging_context(), logging.WARNING) + # Check if the mapping job is ready + if not uniprot_api.check_id_mapping_results_ready(mapping_job_id): + logger.warning( + msg=f"Job {mapping_job_id} not ready. Skipped polling this job.", + extra=job_manager.logging_context(), + ) + # TODO#XXX: When results are not ready, we want to signal to the manager a desire to retry + # this polling job later. For now, we just skip and log. continue - results = uniprot_api.get_id_mapping_results(job_id) + # Extract mapped UniProt IDs from results + results = uniprot_api.get_id_mapping_results(mapping_job_id) mapped_ids = uniprot_api.extract_uniprot_id_from_results(results) + mapped_ac = mapping_job["accession"] + # Handle cases where no or ambiguous results are found if not mapped_ids: - msg = f"No UniProt ID found for target gene {target_gene.id} in score set {score_set.urn}. Cannot add UniProt ID for this target." - log_and_send_slack_message(msg, job_manager.logging_context(), logging.WARNING) - continue + msg = f"No UniProt ID found for accession {mapped_ac}. Cannot add UniProt ID." + job_manager.update_progress(100, 100, msg) + logger.error(msg=msg, extra=job_manager.logging_context()) + raise UniprotMappingResultNotFoundError() if len(mapped_ids) != 1: - msg = f"Found ambiguous Uniprot ID mapping results for target gene {target_gene.id} in score set {score_set.urn}. Cannot add UniProt ID for this target." - log_and_send_slack_message(msg, job_manager.logging_context(), logging.WARNING) - continue + msg = f"Ambiguous UniProt ID mapping results for accession {mapped_ac}. Cannot add UniProt ID." + job_manager.update_progress(100, 100, msg) + logger.error(msg=msg, extra=job_manager.logging_context()) + raise UniprotAmbiguousMappingResultError() mapped_uniprot_id = mapped_ids[0][mapped_ac]["uniprot_id"] + + # Update target gene with mapped UniProt ID + target_gene = next( + (tg for tg in score_set.target_genes if str(tg.id) == str(target_gene_id)), + None, + ) + if not target_gene: + msg = f"Target gene ID {target_gene_id} not found in score set {score_set.urn}. Cannot add UniProt ID." + job_manager.update_progress(100, 100, msg) + logger.error(msg=msg, extra=job_manager.logging_context()) + raise NonExistentTargetGeneError() + target_gene.uniprot_id_from_mapped_metadata = mapped_uniprot_id job_manager.db.add(target_gene) logger.info( @@ -234,7 +306,7 @@ async def poll_uniprot_mapping_jobs_for_score_set(ctx: dict, job_manager: JobMan extra=job_manager.logging_context(), ) job_manager.update_progress( - int((list(score_set.target_genes).index(target_gene) + 1 / len(score_set.target_genes)) * 100), + int((list(score_set.target_genes).index(target_gene) + 1) / len(score_set.target_genes) * 95), 100, f"Polled UniProt mapping job for target gene {target_gene.name}.", ) diff --git a/tests/network/worker/test_uniprot.py b/tests/network/worker/test_uniprot.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/worker/jobs/external_services/conftest.py b/tests/worker/jobs/external_services/conftest.py index ff275357..2f422506 100644 --- a/tests/worker/jobs/external_services/conftest.py +++ b/tests/worker/jobs/external_services/conftest.py @@ -1,11 +1,15 @@ import pytest +from mavedb.models.enums.job_pipeline import DependencyType +from mavedb.models.job_dependency import JobDependency from mavedb.models.job_run import JobRun from mavedb.models.mapped_variant import MappedVariant from mavedb.models.pipeline import Pipeline from mavedb.models.score_set import ScoreSet from mavedb.models.variant import Variant +## Gnomad Linkage Job Fixtures ## + @pytest.fixture def link_gnomad_variants_sample_params(with_populated_domain_data, sample_score_set): @@ -97,3 +101,265 @@ def setup_sample_variants_with_caid(with_populated_domain_data, mock_worker_ctx, ) session.add(mapped_variant) session.commit() + + +## Uniprot Job Fixtures ## + + +@pytest.fixture +def submit_uniprot_mapping_jobs_sample_params(with_populated_domain_data, sample_score_set): + """Provide sample parameters for submit_uniprot_mapping_jobs_for_score_set job.""" + + return { + "correlation_id": "sample-correlation-id", + "score_set_id": sample_score_set.id, + } + + +@pytest.fixture +def poll_uniprot_mapping_jobs_sample_params( + submit_uniprot_mapping_jobs_sample_params, + with_dependent_polling_job_for_submission_run, +): + """Provide sample parameters for poll_uniprot_mapping_jobs_for_score_set job.""" + + return { + "correlation_id": submit_uniprot_mapping_jobs_sample_params["correlation_id"], + "score_set_id": submit_uniprot_mapping_jobs_sample_params["score_set_id"], + "mapping_jobs": {}, + } + + +@pytest.fixture +def sample_submit_uniprot_mapping_jobs_pipeline(): + """Create a pipeline instance for submit_uniprot_mapping_jobs_for_score_set job.""" + + return Pipeline( + urn="test:submit_uniprot_mapping_jobs_pipeline", + name="Submit UniProt Mapping Jobs Pipeline", + ) + + +@pytest.fixture +def sample_poll_uniprot_mapping_jobs_pipeline(): + """Create a pipeline instance for poll_uniprot_mapping_jobs_for_score_set job.""" + + return Pipeline( + urn="test:poll_uniprot_mapping_jobs_pipeline", + name="Poll UniProt Mapping Jobs Pipeline", + ) + + +@pytest.fixture +def sample_submit_uniprot_mapping_jobs_run(submit_uniprot_mapping_jobs_sample_params): + """Create a JobRun instance for submit_uniprot_mapping_jobs_for_score_set job.""" + + return JobRun( + urn="test:submit_uniprot_mapping_jobs", + job_type="submit_uniprot_mapping_jobs", + job_function="submit_uniprot_mapping_jobs_for_score_set", + max_retries=3, + retry_count=0, + job_params=submit_uniprot_mapping_jobs_sample_params, + ) + + +@pytest.fixture +def sample_dummy_polling_job_for_submission_run( + session, + with_submit_uniprot_mapping_job, + sample_submit_uniprot_mapping_jobs_run, +): + """Create a sample dummy dependent polling job for the submission run.""" + + dependent_job = JobRun( + urn="test:dummy_poll_uniprot_mapping_jobs", + job_type="dummy_poll_uniprot_mapping_jobs", + job_function="dummy_arq_function", + max_retries=3, + retry_count=0, + job_params={ + "correlation_id": sample_submit_uniprot_mapping_jobs_run.job_params["correlation_id"], + "score_set_id": sample_submit_uniprot_mapping_jobs_run.job_params["score_set_id"], + "mapping_jobs": {}, + }, + ) + + return dependent_job + + +@pytest.fixture +def sample_polling_job_for_submission_run( + session, + with_submit_uniprot_mapping_job, + sample_submit_uniprot_mapping_jobs_run, +): + """Create a sample dependent polling job for the submission run.""" + + dependent_job = JobRun( + urn="test:dependent_poll_uniprot_mapping_jobs", + job_type="dependent_poll_uniprot_mapping_jobs", + job_function="poll_uniprot_mapping_jobs_for_score_set", + max_retries=3, + retry_count=0, + job_params={ + "correlation_id": sample_submit_uniprot_mapping_jobs_run.job_params["correlation_id"], + "score_set_id": sample_submit_uniprot_mapping_jobs_run.job_params["score_set_id"], + "mapping_jobs": {}, + }, + ) + + return dependent_job + + +@pytest.fixture +def with_dummy_polling_job_for_submission_run( + session, + with_submit_uniprot_mapping_job, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, +): + """Create a sample dummy dependent polling job for the submission run.""" + session.add(sample_dummy_polling_job_for_submission_run) + session.commit() + + dependency = JobDependency( + id=sample_dummy_polling_job_for_submission_run.id, + depends_on_job_id=sample_submit_uniprot_mapping_jobs_run.id, + dependency_type=DependencyType.SUCCESS_REQUIRED, + ) + session.add(dependency) + session.commit() + + +@pytest.fixture +def with_dependent_polling_job_for_submission_run( + session, + with_submit_uniprot_mapping_job, + sample_submit_uniprot_mapping_jobs_run, + sample_polling_job_for_submission_run, +): + """Create a sample dependent polling job for the submission run.""" + session.add(sample_polling_job_for_submission_run) + session.commit() + + dependency = JobDependency( + id=sample_polling_job_for_submission_run.id, + depends_on_job_id=sample_submit_uniprot_mapping_jobs_run.id, + dependency_type=DependencyType.SUCCESS_REQUIRED, + ) + session.add(dependency) + session.commit() + + +@pytest.fixture +def with_independent_polling_job_for_submission_run( + session, + sample_polling_job_for_submission_run, +): + """Create a sample dependent polling job for the submission run.""" + session.add(sample_polling_job_for_submission_run) + session.commit() + + +@pytest.fixture +def with_submit_uniprot_mapping_job(session, sample_submit_uniprot_mapping_jobs_run): + """Add a submit_uniprot_mapping_jobs job run to the session.""" + + session.add(sample_submit_uniprot_mapping_jobs_run) + session.commit() + + +@pytest.fixture +def with_poll_uniprot_mapping_job(session, sample_poll_uniprot_mapping_jobs_run): + """Add a poll_uniprot_mapping_jobs job run to the session.""" + + session.add(sample_poll_uniprot_mapping_jobs_run) + session.commit() + + +@pytest.fixture +def sample_submit_uniprot_mapping_jobs_run_in_pipeline( + session, + with_submit_uniprot_mapping_job, + with_submit_uniprot_mapping_jobs_pipeline, + sample_submit_uniprot_mapping_jobs_run, + sample_submit_uniprot_mapping_jobs_pipeline, +): + """Provide a context with a submit_uniprot_mapping_jobs job run and pipeline.""" + + sample_submit_uniprot_mapping_jobs_run.pipeline_id = sample_submit_uniprot_mapping_jobs_pipeline.id + session.commit() + return sample_submit_uniprot_mapping_jobs_run + + +@pytest.fixture +def sample_poll_uniprot_mapping_jobs_run_in_pipeline( + session, + with_independent_polling_job_for_submission_run, + with_poll_uniprot_mapping_jobs_pipeline, + sample_polling_job_for_submission_run, + sample_poll_uniprot_mapping_jobs_pipeline, +): + """Provide a context with a poll_uniprot_mapping_jobs job run and pipeline.""" + + sample_polling_job_for_submission_run.pipeline_id = sample_poll_uniprot_mapping_jobs_pipeline.id + session.commit() + return sample_polling_job_for_submission_run + + +@pytest.fixture +def sample_dummy_polling_job_for_submission_run_in_pipeline( + session, + with_dummy_polling_job_for_submission_run, + with_submit_uniprot_mapping_jobs_pipeline, + with_submit_uniprot_mapping_job, + sample_submit_uniprot_mapping_jobs_pipeline, + sample_submit_uniprot_mapping_jobs_run_in_pipeline, + sample_dummy_polling_job_for_submission_run, +): + """Provide a context with a dependent polling job run in the pipeline.""" + + dependent_job = sample_dummy_polling_job_for_submission_run + dependent_job.pipeline_id = sample_submit_uniprot_mapping_jobs_pipeline.id + session.commit() + return dependent_job + + +@pytest.fixture +def sample_polling_job_for_submission_run_in_pipeline( + session, + with_dependent_polling_job_for_submission_run, + with_submit_uniprot_mapping_jobs_pipeline, + with_submit_uniprot_mapping_job, + sample_submit_uniprot_mapping_jobs_pipeline, + sample_submit_uniprot_mapping_jobs_run_in_pipeline, + sample_polling_job_for_submission_run, +): + """Provide a context with a dependent polling job run in the pipeline.""" + + dependent_job = sample_polling_job_for_submission_run + dependent_job.pipeline_id = sample_submit_uniprot_mapping_jobs_pipeline.id + session.commit() + return dependent_job + + +@pytest.fixture +def with_submit_uniprot_mapping_jobs_pipeline( + session, + sample_submit_uniprot_mapping_jobs_pipeline, +): + """Add a submit_uniprot_mapping_jobs pipeline to the session.""" + + session.add(sample_submit_uniprot_mapping_jobs_pipeline) + session.commit() + + +@pytest.fixture +def with_poll_uniprot_mapping_jobs_pipeline( + session, + sample_poll_uniprot_mapping_jobs_pipeline, +): + """Add a poll_uniprot_mapping_jobs pipeline to the session.""" + session.add(sample_poll_uniprot_mapping_jobs_pipeline) + session.commit() diff --git a/tests/worker/jobs/external_services/network/test_uniprot.py b/tests/worker/jobs/external_services/network/test_uniprot.py new file mode 100644 index 00000000..249a412c --- /dev/null +++ b/tests/worker/jobs/external_services/network/test_uniprot.py @@ -0,0 +1,60 @@ +import pytest + +from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus +from tests.helpers.constants import TEST_REFSEQ_IDENTIFIER + + +@pytest.mark.asyncio +@pytest.mark.integration +@pytest.mark.network +class TestE2EUniprotMappingJobs: + """End-to-end tests for UniProt mapping jobs.""" + + async def test_uniprot_mapping_jobs_e2e( + self, + session, + arq_redis, + arq_worker, + sample_score_set, + with_submit_uniprot_mapping_jobs_pipeline, + sample_submit_uniprot_mapping_jobs_pipeline, + sample_submit_uniprot_mapping_jobs_run_in_pipeline, + sample_polling_job_for_submission_run_in_pipeline, + ): + """Test the end-to-end flow of submitting and polling UniProt mapping jobs.""" + + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [TEST_REFSEQ_IDENTIFIER]}} + session.commit() + + await arq_redis.enqueue_job( + "submit_uniprot_mapping_jobs_for_score_set", sample_submit_uniprot_mapping_jobs_run_in_pipeline.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that the job metadata contains the submitted job + session.refresh(sample_submit_uniprot_mapping_jobs_run_in_pipeline) + submitted_jobs = sample_submit_uniprot_mapping_jobs_run_in_pipeline.metadata_["submitted_jobs"] + assert "1" in submitted_jobs + assert submitted_jobs["1"]["job_id"] is not None + assert submitted_jobs["1"]["accession"] == TEST_REFSEQ_IDENTIFIER + + # Verify that polling job params have been updated correctly + session.refresh(sample_polling_job_for_submission_run_in_pipeline) + assert sample_polling_job_for_submission_run_in_pipeline.job_params["mapping_jobs"] == { + "1": {"job_id": submitted_jobs["1"]["job_id"], "accession": TEST_REFSEQ_IDENTIFIER} + } + + # Verify that the submission job was completed successfully + session.refresh(sample_submit_uniprot_mapping_jobs_run_in_pipeline) + assert sample_submit_uniprot_mapping_jobs_run_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify that the dependent polling job has run and is succeeded (pipeline ctx) + session.refresh(sample_polling_job_for_submission_run_in_pipeline) + assert sample_polling_job_for_submission_run_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify that the pipeline run status is running + session.refresh(sample_submit_uniprot_mapping_jobs_pipeline) + assert sample_submit_uniprot_mapping_jobs_pipeline.status == PipelineStatus.SUCCEEDED diff --git a/tests/worker/jobs/external_services/test_uniprot.py b/tests/worker/jobs/external_services/test_uniprot.py index e69de29b..fc0f9fa5 100644 --- a/tests/worker/jobs/external_services/test_uniprot.py +++ b/tests/worker/jobs/external_services/test_uniprot.py @@ -0,0 +1,2014 @@ +from unittest.mock import call, patch + +import pytest + +from mavedb.lib.exceptions import ( + NonExistentTargetGeneError, + UniprotAmbiguousMappingResultError, + UniprotMappingResultNotFoundError, + UniProtPollingEnqueueError, +) +from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus +from mavedb.models.target_gene import TargetGene +from mavedb.models.target_sequence import TargetSequence +from mavedb.worker.jobs.external_services.uniprot import ( + poll_uniprot_mapping_jobs_for_score_set, + submit_uniprot_mapping_jobs_for_score_set, +) +from mavedb.worker.lib.managers.job_manager import JobManager +from tests.helpers.constants import ( + TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, + TEST_UNIPROT_SWISS_PROT_TYPE, + VALID_NT_ACCESSION, + VALID_UNIPROT_ACCESSION, +) + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestSubmitUniprotMappingJobsForScoreSetUnit: + """Unit tests for submit_uniprot_mapping_jobs_for_score_set function.""" + + async def test_submit_uniprot_mapping_jobs_no_targets( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + ): + """Test submitting UniProt mapping jobs when no target genes are present.""" + + # Ensure the sample score set has no target genes + sample_score_set.target_genes = [] + mock_worker_ctx["db"].commit() + + with ( + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=mock_worker_ctx["db"], + redis=mock_worker_ctx["redis"], + job_id=sample_submit_uniprot_mapping_jobs_run.id, + ), + ) + + mock_update_progress.assert_called_with( + 100, 100, "No target genes found. Skipped UniProt mapping job submission." + ) + assert job_result["status"] == "ok" + + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == {} + + async def test_submit_uniprot_mapping_jobs_no_acs_in_post_mapped_metadata( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + ): + """Test submitting UniProt mapping jobs when no ACs are present in post mapped metadata.""" + + with ( + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=mock_worker_ctx["db"], + redis=mock_worker_ctx["redis"], + job_id=sample_submit_uniprot_mapping_jobs_run.id, + ), + ) + + mock_update_progress.assert_called_with(100, 100, "No UniProt mapping jobs were submitted.") + assert job_result["status"] == "ok" + + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == {} + + async def test_submit_uniprot_mapping_jobs_too_many_acs_in_post_mapped_metadata( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + ): + """Test submitting UniProt mapping jobs when too many ACs are present in post mapped metadata.""" + + # Arrange the post mapped metadata to have multiple ACs + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION, "P67890"]}} + session.commit() + + with ( + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=mock_worker_ctx["db"], + redis=mock_worker_ctx["redis"], + job_id=sample_submit_uniprot_mapping_jobs_run.id, + ), + ) + + mock_update_progress.assert_called_with(100, 100, "No UniProt mapping jobs were submitted.") + assert job_result["status"] == "ok" + + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == {} + + async def test_submit_uniprot_mapping_jobs_no_jobs_submitted( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + ): + """Test submitting UniProt mapping jobs when no jobs are submitted.""" + + # Arrange the post mapped metadata to have a single AC + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value=None, + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=mock_worker_ctx["db"], + redis=mock_worker_ctx["redis"], + job_id=sample_submit_uniprot_mapping_jobs_run.id, + ), + ) + + mock_update_progress.assert_called_with(100, 100, "No UniProt mapping jobs were submitted.") + assert job_result["status"] == "ok" + + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == { + "1": {"job_id": None, "accession": VALID_NT_ACCESSION} + } + + async def test_submit_uniprot_mapping_jobs_api_failure_raises( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + ): + """Test handling of UniProt API failure during job submission.""" + + # Arrange the post mapped metadata to have a single AC + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + side_effect=Exception("UniProt API failure"), + ), + patch.object(JobManager, "update_progress"), + pytest.raises(Exception, match="UniProt API failure"), + ): + await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=mock_worker_ctx["db"], + redis=mock_worker_ctx["redis"], + job_id=sample_submit_uniprot_mapping_jobs_run.id, + ), + ) + + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == {} + + async def test_submit_uniprot_mapping_jobs_raises_dependent_job_not_available( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + ): + """Test handling when dependent polling job is not available.""" + + # Arrange the post mapped metadata to have a single AC + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + pytest.raises(UniProtPollingEnqueueError), + ): + await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=mock_worker_ctx["db"], + redis=mock_worker_ctx["redis"], + job_id=sample_submit_uniprot_mapping_jobs_run.id, + ), + ) + + mock_update_progress.assert_called_with(100, 100, "Failed to submit UniProt mapping jobs.") + + # Verify that the job metadata contains the submitted jobs (which were submitted before the error) + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + + async def test_submit_uniprot_mapping_jobs_successful_submission( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, + ): + """Test successful submission of UniProt mapping jobs.""" + + # Arrange the post mapped metadata to have a single AC + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ), + patch.object(JobManager, "update_progress"), + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=mock_worker_ctx["db"], + redis=mock_worker_ctx["redis"], + job_id=sample_submit_uniprot_mapping_jobs_run.id, + ), + ) + + assert job_result["status"] == "ok" + + expected_submitted_jobs = {"1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION}} + + # Verify that the job metadata contains the submitted job + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == expected_submitted_jobs + + # Verify that polling job params have been updated correctly + session.refresh(sample_dummy_polling_job_for_submission_run) + assert sample_dummy_polling_job_for_submission_run.job_params["mapping_jobs"] == expected_submitted_jobs + + async def test_submit_uniprot_mapping_jobs_partial_submission( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, + ): + """Test partial submission of UniProt mapping jobs.""" + + # Add another target gene to the score set to simulate multiple submissions + new_target_gene = TargetGene( + score_set_id=sample_score_set.id, + name="TP53", + category="protein_coding", + target_sequence=TargetSequence(sequence="MEEPQSDPSV", sequence_type="protein"), + ) + mock_worker_ctx["db"].add(new_target_gene) + mock_worker_ctx["db"].commit() + + # Arrange the post mapped metadata to have a single AC for both target genes + target_gene_1 = sample_score_set.target_genes[0] + target_gene_1.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + target_gene_2 = new_target_gene + target_gene_2.post_mapped_metadata = {"protein": {"sequence_accessions": ["NM_000546"]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + side_effect=["job_12345", None], + ), + patch.object(JobManager, "update_progress"), + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=mock_worker_ctx["db"], + redis=mock_worker_ctx["redis"], + job_id=sample_submit_uniprot_mapping_jobs_run.id, + ), + ) + + assert job_result["status"] == "ok" + + expected_submitted_jobs = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION}, + "2": {"job_id": None, "accession": "NM_000546"}, + } + + # Verify that the job metadata contains both submitted and failed jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == expected_submitted_jobs + + # Verify that polling job params have been updated correctly + session.refresh(sample_dummy_polling_job_for_submission_run) + assert sample_dummy_polling_job_for_submission_run.job_params["mapping_jobs"] == expected_submitted_jobs + + async def test_submit_uniprot_mapping_jobs_updates_progress( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + ): + """Test that progress updates are made during UniProt mapping job submission.""" + + # Arrange the post mapped metadata to have a single AC + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=mock_worker_ctx["db"], + redis=mock_worker_ctx["redis"], + job_id=sample_submit_uniprot_mapping_jobs_run.id, + ), + ) + + assert job_result["status"] == "ok" + + # Verify that progress updates were made + mock_update_progress.assert_has_calls( + [ + call(0, 100, "Starting UniProt mapping job submission."), + call( + 95, 100, f"Submitted UniProt mapping job for target gene {sample_score_set.target_genes[0].name}." + ), + call(100, 100, "Completed submission of UniProt mapping jobs."), + ] + ) + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestSubmitUniprotMappingJobsForScoreSetIntegration: + """Integration tests for submit_uniprot_mapping_jobs_for_score_set function.""" + + async def test_submit_uniprot_mapping_jobs_success_independent_ctx( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, + ): + """Integration test for submitting UniProt mapping jobs.""" + + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ) as mock_submit_id_mapping, + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_submit_uniprot_mapping_jobs_run.id + ) + + mock_submit_id_mapping.assert_called_once() + assert job_result["status"] == "ok" + + expected_submitted_jobs = {"1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION}} + + # Verify that the job metadata contains the submitted job + session.refresh(sample_submit_uniprot_mapping_jobs_run) + sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == expected_submitted_jobs + + # Verify that polling job params have been updated correctly + session.refresh(sample_dummy_polling_job_for_submission_run) + assert sample_dummy_polling_job_for_submission_run.job_params["mapping_jobs"] == expected_submitted_jobs + + # Verify that the submission job was completed successfully + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.SUCCEEDED + + # Verify that the dependent polling job is still pending (non-pipeline ctx) + session.refresh(sample_dummy_polling_job_for_submission_run) + assert sample_dummy_polling_job_for_submission_run.status == JobStatus.PENDING + + async def test_submit_uniprot_mapping_jobs_success_pipeline_ctx( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_jobs_pipeline, + with_dummy_polling_job_for_submission_run, + sample_submit_uniprot_mapping_jobs_run_in_pipeline, + sample_submit_uniprot_mapping_jobs_pipeline, + sample_dummy_polling_job_for_submission_run_in_pipeline, + sample_score_set, + ): + """Integration test for submitting UniProt mapping jobs in a pipeline context.""" + + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ) as mock_submit_id_mapping, + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_submit_uniprot_mapping_jobs_run_in_pipeline.id + ) + + mock_submit_id_mapping.assert_called_once() + assert job_result["status"] == "ok" + + expected_submitted_jobs = {"1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION}} + + # Verify that the job metadata contains the submitted job + session.refresh(sample_submit_uniprot_mapping_jobs_run_in_pipeline) + sample_submit_uniprot_mapping_jobs_run_in_pipeline.metadata_["submitted_jobs"] == expected_submitted_jobs + + # Verify that polling job params have been updated correctly + session.refresh(sample_dummy_polling_job_for_submission_run_in_pipeline) + assert ( + sample_dummy_polling_job_for_submission_run_in_pipeline.job_params["mapping_jobs"] + == expected_submitted_jobs + ) + + # Verify that the submission job was completed successfully + session.refresh(sample_submit_uniprot_mapping_jobs_run_in_pipeline) + assert sample_submit_uniprot_mapping_jobs_run_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify that the dependent polling job is now queued (pipeline ctx) + session.refresh(sample_dummy_polling_job_for_submission_run_in_pipeline) + assert sample_dummy_polling_job_for_submission_run_in_pipeline.status == JobStatus.QUEUED + + # Verify that the pipeline run status is running + session.refresh(sample_submit_uniprot_mapping_jobs_pipeline) + assert sample_submit_uniprot_mapping_jobs_pipeline.status == PipelineStatus.RUNNING + + async def test_submit_uniprot_mapping_jobs_no_targets( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, + ): + """Integration test for submitting UniProt mapping jobs when no target genes are present.""" + + # Ensure the sample score set has no target genes + sample_score_set.target_genes = [] + mock_worker_ctx["db"].commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ) as mock_submit_id_mapping, + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_submit_uniprot_mapping_jobs_run.id + ) + + mock_submit_id_mapping.assert_not_called() + assert job_result["status"] == "ok" + + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == {} + + # Verify that the submission job was completed successfully + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.SUCCEEDED + + # Verify that the dependent polling job is still pending and no param changes were made + assert sample_dummy_polling_job_for_submission_run.status == JobStatus.PENDING + assert sample_dummy_polling_job_for_submission_run.job_params.get("mapping_jobs") == {} + + async def test_submit_uniprot_mapping_jobs_no_acs_in_post_mapped_metadata( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, + ): + """Integration test for submitting UniProt mapping jobs when no ACs are present in post mapped metadata.""" + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ) as mock_submit_id_mapping, + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_submit_uniprot_mapping_jobs_run.id + ) + + mock_submit_id_mapping.assert_not_called() + assert job_result["status"] == "ok" + + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == {} + + # Verify that the submission job was completed successfully + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.SUCCEEDED + + # Verify that the dependent polling job is still pending and no param changes were made + assert sample_dummy_polling_job_for_submission_run.status == JobStatus.PENDING + assert sample_dummy_polling_job_for_submission_run.job_params.get("mapping_jobs") == {} + + async def test_submit_uniprot_mapping_jobs_too_many_acs_in_post_mapped_metadata( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, + ): + """Integration test for submitting UniProt mapping jobs when too many ACs are present in post mapped metadata.""" + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ) as mock_submit_id_mapping, + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_submit_uniprot_mapping_jobs_run.id + ) + + mock_submit_id_mapping.assert_not_called() + assert job_result["status"] == "ok" + + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == {} + + # Verify that the submission job was completed successfully + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.SUCCEEDED + + # Verify that the dependent polling job is still pending and no param changes were made + assert sample_dummy_polling_job_for_submission_run.status == JobStatus.PENDING + assert sample_dummy_polling_job_for_submission_run.job_params.get("mapping_jobs") == {} + + async def test_submit_uniprot_mapping_jobs_propagates_exceptions( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, + ): + """Integration test to ensure exceptions during UniProt mapping job submission are propagated to decorators.""" + + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + side_effect=Exception("UniProt API failure"), + ): + result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_submit_uniprot_mapping_jobs_run.id + ) + + assert result["status"] == "failed" + assert "UniProt API failure" in result["exception_details"]["message"] + + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == {} + + # Verify that the submission job failed + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.FAILED + + # Verify that the dependent polling job is still pending and no param changes were made + assert sample_dummy_polling_job_for_submission_run.status == JobStatus.PENDING + assert sample_dummy_polling_job_for_submission_run.job_params.get("mapping_jobs") == {} + + async def test_submit_uniprot_mapping_jobs_no_jobs_submitted( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, + ): + """Integration test for submitting UniProt mapping jobs when no jobs are submitted.""" + + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value=None, + ), + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_submit_uniprot_mapping_jobs_run.id + ) + + assert job_result["status"] == "ok" + + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == { + "1": {"job_id": None, "accession": VALID_NT_ACCESSION} + } + + # Verify that the submission job was completed successfully + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.SUCCEEDED + + # Verify that the dependent polling job is still pending and no param changes were made + assert sample_dummy_polling_job_for_submission_run.status == JobStatus.PENDING + assert sample_dummy_polling_job_for_submission_run.job_params.get("mapping_jobs") == {} + + async def test_submit_uniprot_mapping_jobs_partial_submission( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, + ): + """Integration test for partial submission of UniProt mapping jobs.""" + + # Add another target gene to the score set to simulate multiple submissions + new_target_gene = TargetGene( + score_set_id=sample_score_set.id, + name="TP53", + category="protein_coding", + target_sequence=TargetSequence(sequence="MEEPQSDPSV", sequence_type="protein"), + ) + mock_worker_ctx["db"].add(new_target_gene) + mock_worker_ctx["db"].commit() + + # Add accessions to both target genes' post mapped metadata + for idx, tg in enumerate(sample_score_set.target_genes): + tg.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION + f"{idx:05d}"]}} + mock_worker_ctx["db"].commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + side_effect=["job_12345", None], + ), + ): + job_result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_submit_uniprot_mapping_jobs_run.id + ) + + assert job_result["status"] == "ok" + + expected_submitted_jobs = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION + "00000"}, + "2": {"job_id": None, "accession": VALID_NT_ACCESSION + "00001"}, + } + + # Verify that the job metadata contains both submitted and failed jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == expected_submitted_jobs + + # Verify that the submission job was completed successfully + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.SUCCEEDED + + # Verify that the dependent polling job is still pending and params were updated correctly + assert sample_dummy_polling_job_for_submission_run.status == JobStatus.PENDING + assert sample_dummy_polling_job_for_submission_run.job_params.get("mapping_jobs") == expected_submitted_jobs + + async def test_submit_uniprot_mapping_jobs_no_dependent_job_raises( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + ): + """Integration test to ensure error is raised to the decorator when dependent polling job is not available.""" + + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ): + result = await submit_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_submit_uniprot_mapping_jobs_run.id + ) + + assert result["status"] == "failed" + assert ( + "Could not find unique dependent polling job for UniProt mapping job" + in result["exception_details"]["message"] + ) + + # Verify that the job metadata contains the job we submitted before the error + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + + # Verify that the submission job failed + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.FAILED + + # nothing to verify for dependent polling job since it does not exist + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestSubmitUniprotMappingJobsArqContext: + """Integration tests for submit_uniprot_mapping_jobs_for_score_set function in ARQ context.""" + + async def test_submit_uniprot_mapping_jobs_with_arq_context_independent( + self, + session, + arq_redis, + arq_worker, + athena_engine, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, + ): + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ), + ): + await arq_redis.enqueue_job( + "submit_uniprot_mapping_jobs_for_score_set", sample_submit_uniprot_mapping_jobs_run.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + expected_submitted_jobs = {"1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION}} + + # Verify that the job metadata contains the submitted job + session.refresh(sample_submit_uniprot_mapping_jobs_run) + sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == expected_submitted_jobs + + # Verify that polling job params have been updated correctly + session.refresh(sample_dummy_polling_job_for_submission_run) + assert sample_dummy_polling_job_for_submission_run.job_params["mapping_jobs"] == expected_submitted_jobs + + # Verify that the submission job was completed successfully + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.SUCCEEDED + + # Verify that the dependent polling job is still pending (non-pipeline ctx) + session.refresh(sample_dummy_polling_job_for_submission_run) + assert sample_dummy_polling_job_for_submission_run.status == JobStatus.PENDING + + async def test_submit_uniprot_mapping_jobs_with_arq_context_pipeline( + self, + session, + arq_redis, + arq_worker, + athena_engine, + with_populated_domain_data, + with_submit_uniprot_mapping_jobs_pipeline, + with_dummy_polling_job_for_submission_run, + sample_submit_uniprot_mapping_jobs_run_in_pipeline, + sample_submit_uniprot_mapping_jobs_pipeline, + sample_dummy_polling_job_for_submission_run_in_pipeline, + sample_score_set, + ): + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ), + ): + await arq_redis.enqueue_job( + "submit_uniprot_mapping_jobs_for_score_set", sample_submit_uniprot_mapping_jobs_run_in_pipeline.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + expected_submitted_jobs = {"1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION}} + + # Verify that the job metadata contains the submitted job + session.refresh(sample_submit_uniprot_mapping_jobs_run_in_pipeline) + sample_submit_uniprot_mapping_jobs_run_in_pipeline.metadata_["submitted_jobs"] == expected_submitted_jobs + + # Verify that polling job params have been updated correctly + session.refresh(sample_dummy_polling_job_for_submission_run_in_pipeline) + assert ( + sample_dummy_polling_job_for_submission_run_in_pipeline.job_params["mapping_jobs"] + == expected_submitted_jobs + ) + + # Verify that the submission job was completed successfully + session.refresh(sample_submit_uniprot_mapping_jobs_run_in_pipeline) + assert sample_submit_uniprot_mapping_jobs_run_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify that the dependent polling job is now queued (pipeline ctx) + session.refresh(sample_dummy_polling_job_for_submission_run_in_pipeline) + assert sample_dummy_polling_job_for_submission_run_in_pipeline.status == JobStatus.QUEUED + + # Verify that the pipeline run status is running + session.refresh(sample_submit_uniprot_mapping_jobs_pipeline) + assert sample_submit_uniprot_mapping_jobs_pipeline.status == PipelineStatus.RUNNING + + async def test_submit_uniprot_mapping_jobs_with_arq_context_exception_handling_independent( + self, + session, + arq_redis, + arq_worker, + athena_engine, + with_populated_domain_data, + with_submit_uniprot_mapping_job, + with_dummy_polling_job_for_submission_run, + sample_score_set, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, + ): + """Integration test to ensure exceptions during UniProt mapping job submission are propagated to decorators.""" + + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + side_effect=Exception("UniProt API failure"), + ): + await arq_redis.enqueue_job( + "submit_uniprot_mapping_jobs_for_score_set", sample_submit_uniprot_mapping_jobs_run.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == {} + + # Verify that the submission job failed + session.refresh(sample_submit_uniprot_mapping_jobs_run) + assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.FAILED + + # Verify that the dependent polling job is still pending and no param changes were made + assert sample_dummy_polling_job_for_submission_run.status == JobStatus.PENDING + assert sample_dummy_polling_job_for_submission_run.job_params.get("mapping_jobs") == {} + + async def test_submit_uniprot_mapping_jobs_with_arq_context_exception_handling_pipeline( + self, + session, + arq_redis, + arq_worker, + athena_engine, + with_populated_domain_data, + with_submit_uniprot_mapping_jobs_pipeline, + with_dummy_polling_job_for_submission_run, + sample_submit_uniprot_mapping_jobs_run_in_pipeline, + sample_submit_uniprot_mapping_jobs_pipeline, + sample_dummy_polling_job_for_submission_run_in_pipeline, + sample_score_set, + ): + """Integration test to ensure exceptions during UniProt mapping job submission are propagated to decorators.""" + + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + with patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + side_effect=Exception("UniProt API failure"), + ): + await arq_redis.enqueue_job( + "submit_uniprot_mapping_jobs_for_score_set", sample_submit_uniprot_mapping_jobs_run_in_pipeline.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that the job metadata contains no submitted jobs + session.refresh(sample_submit_uniprot_mapping_jobs_run_in_pipeline) + assert sample_submit_uniprot_mapping_jobs_run_in_pipeline.metadata_["submitted_jobs"] == {} + + # Verify that the submission job failed + session.refresh(sample_submit_uniprot_mapping_jobs_run_in_pipeline) + assert sample_submit_uniprot_mapping_jobs_run_in_pipeline.status == JobStatus.FAILED + + # Verify that the dependent polling job is now cancelled and no param changes were made + assert sample_dummy_polling_job_for_submission_run_in_pipeline.status == JobStatus.SKIPPED + assert sample_dummy_polling_job_for_submission_run_in_pipeline.job_params.get("mapping_jobs") == {} + + # Verify that the pipeline run status is failed + session.refresh(sample_submit_uniprot_mapping_jobs_pipeline) + assert sample_submit_uniprot_mapping_jobs_pipeline.status == PipelineStatus.FAILED + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestPollUniprotMappingJobsForScoreSetUnit: + """Unit tests for poll_uniprot_mapping_jobs_for_score_set function.""" + + async def test_poll_uniprot_mapping_jobs_no_mapping_jobs( + self, + session, + mock_worker_ctx, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Ensure there are no mapping jobs in the polling job params + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = {} + session.commit() + + with ( + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + job_result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=mock_worker_ctx["db"], + redis=mock_worker_ctx["redis"], + job_id=sample_polling_job_for_submission_run.id, + ), + ) + + mock_update_progress.assert_called_with(100, 100, "No mapping jobs found to poll.") + assert job_result["status"] == "ok" + + # Verify the target gene uniprot id remains unchanged + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata is None + + # TODO:XXX -- We will eventually want to make sure the job indicates to the manager + # its desire to be retried. For now, we just verify that no changes are made + # when results are not ready. + async def test_poll_uniprot_mapping_jobs_results_not_ready( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=False, + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + job_result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=mock_worker_ctx["db"], + redis=mock_worker_ctx["redis"], + job_id=sample_polling_job_for_submission_run.id, + ), + ) + + assert job_result["status"] == "ok" + + # Verify that progress updates were made + mock_update_progress.assert_called_with(100, 100, "Completed polling of UniProt mapping jobs.") + + # Verify the target gene uniprot id remains unchanged + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata is None + + async def test_poll_uniprot_mapping_jobs_no_results( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=True, + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + return_value={"results": []}, # minimal response with no results + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + pytest.raises(UniprotMappingResultNotFoundError), + ): + await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=mock_worker_ctx["db"], + redis=mock_worker_ctx["redis"], + job_id=sample_polling_job_for_submission_run.id, + ), + ) + + mock_update_progress.assert_called_with( + 100, 100, f"No UniProt ID found for accession {VALID_NT_ACCESSION}. Cannot add UniProt ID." + ) + + async def test_poll_uniprot_mapping_jobs_ambiguous_results( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=True, + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + return_value={ + "results": [ + { + "from": VALID_NT_ACCESSION, + "to": { + "primaryAccession": f"{VALID_UNIPROT_ACCESSION}", + "entryType": TEST_UNIPROT_SWISS_PROT_TYPE, + }, + }, + { + "from": VALID_NT_ACCESSION, + "to": { + "primaryAccession": "P67890", + "entryType": TEST_UNIPROT_SWISS_PROT_TYPE, + }, + }, + ] + }, + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + pytest.raises(UniprotAmbiguousMappingResultError), + ): + await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=mock_worker_ctx["db"], + redis=mock_worker_ctx["redis"], + job_id=sample_polling_job_for_submission_run.id, + ), + ) + + mock_update_progress.assert_called_with( + 100, + 100, + f"Ambiguous UniProt ID mapping results for accession {VALID_NT_ACCESSION}. Cannot add UniProt ID.", + ) + + async def test_poll_uniprot_mapping_jobs_nonexistent_target( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job with a non-existent target gene ID + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "999": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=True, + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + pytest.raises(NonExistentTargetGeneError), + ): + await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=mock_worker_ctx["db"], + redis=mock_worker_ctx["redis"], + job_id=sample_polling_job_for_submission_run.id, + ), + ) + + mock_update_progress.assert_called_with( + 100, + 100, + f"Target gene ID 999 not found in score set {sample_score_set.urn}. Cannot add UniProt ID.", + ) + + async def test_poll_uniprot_mapping_jobs_successful_update( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=True, + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + job_result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=mock_worker_ctx["db"], + redis=mock_worker_ctx["redis"], + job_id=sample_polling_job_for_submission_run.id, + ), + ) + + assert job_result["status"] == "ok" + + # Verify that progress updates were made + mock_update_progress.assert_called_with(100, 100, "Completed polling of UniProt mapping jobs.") + + # Verify the target gene uniprot id has been updated + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata == VALID_UNIPROT_ACCESSION + + async def test_poll_uniprot_mapping_jobs_partial_success( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have two mapping jobs + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION}, + "2": {"job_id": "job_67890", "accession": "NONEXISTENT_AC"}, + } + session.commit() + + # Add another target gene to the score set to correspond to the second mapping job + new_target_gene = TargetGene( + score_set_id=sample_score_set.id, + name="TP53", + category="protein_coding", + target_sequence=TargetSequence(sequence="MEEPQSDPSV", sequence_type="protein"), + ) + mock_worker_ctx["db"].add(new_target_gene) + mock_worker_ctx["db"].commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + side_effect=[True, False], + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + side_effect=[ + TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, # Successful result for the first mapping job + {"results": []}, # No results for the second mapping job + ], + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + job_result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=mock_worker_ctx["db"], + redis=mock_worker_ctx["redis"], + job_id=sample_polling_job_for_submission_run.id, + ), + ) + + assert job_result["status"] == "ok" + + # Verify that progress updates were made + mock_update_progress.assert_called_with(100, 100, "Completed polling of UniProt mapping jobs.") + + # Verify the target gene uniprot id has been updated for the successful mapping and + # remains None for the failed mapping + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata == VALID_UNIPROT_ACCESSION + assert sample_score_set.target_genes[1].uniprot_id_from_mapped_metadata is None + + async def test_poll_uniprot_mapping_jobs_updates_progress( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have one mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_11111", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + side_effect=[True, True, True], + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + side_effect=[TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE], + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + job_result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=mock_worker_ctx["db"], + redis=mock_worker_ctx["redis"], + job_id=sample_polling_job_for_submission_run.id, + ), + ) + + assert job_result["status"] == "ok" + + # Verify that progress updates were made incrementally + mock_update_progress.assert_has_calls( + [ + call(0, 100, "Starting UniProt mapping job polling."), + call(95, 100, "Polled UniProt mapping job for target gene Sample Gene."), + call(100, 100, "Completed polling of UniProt mapping jobs."), + ] + ) + + # Verify the target gene uniprot ids have been updated + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata == VALID_UNIPROT_ACCESSION + + async def test_poll_uniprot_mapping_jobs_propagates_exceptions( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + side_effect=Exception("UniProt API failure"), + ), + pytest.raises(Exception) as exc_info, + ): + await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, + 1, + JobManager( + db=mock_worker_ctx["db"], + redis=mock_worker_ctx["redis"], + job_id=sample_polling_job_for_submission_run.id, + ), + ) + + assert str(exc_info.value) == "UniProt API failure" + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestPollUniprotMappingJobsForScoreSetIntegration: + """Integration tests for poll_uniprot_mapping_jobs_for_score_set function.""" + + async def test_poll_uniprot_mapping_jobs_success_independent_ctx( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + with_submit_uniprot_mapping_job, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=True, + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, + ), + ): + job_result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_polling_job_for_submission_run.id + ) + + assert job_result["status"] == "ok" + + # Verify the target gene uniprot id has been updated + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata == VALID_UNIPROT_ACCESSION + + # Verify that the polling job was completed successfully + session.refresh(sample_polling_job_for_submission_run) + assert sample_polling_job_for_submission_run.status == JobStatus.SUCCEEDED + + async def test_poll_uniprot_mapping_jobs_success_pipeline_ctx( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_poll_uniprot_mapping_jobs_pipeline, + sample_score_set, + sample_poll_uniprot_mapping_jobs_run_in_pipeline, + sample_poll_uniprot_mapping_jobs_pipeline, + ): + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + # Arrange the polling job params to have a single mapping job + sample_poll_uniprot_mapping_jobs_run_in_pipeline.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=True, + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, + ), + ): + job_result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_poll_uniprot_mapping_jobs_run_in_pipeline.id + ) + + assert job_result["status"] == "ok" + + # Verify the target gene uniprot id has been updated + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata == VALID_UNIPROT_ACCESSION + + # Verify that the polling job was completed successfully + session.refresh(sample_poll_uniprot_mapping_jobs_run_in_pipeline) + assert sample_poll_uniprot_mapping_jobs_run_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify that the pipeline run status is succeeded (this is the only job in the test pipeline) + session.refresh(sample_poll_uniprot_mapping_jobs_pipeline) + assert sample_poll_uniprot_mapping_jobs_pipeline.status == PipelineStatus.SUCCEEDED + + async def test_poll_uniprot_mapping_jobs_no_mapping_jobs( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Ensure there are no mapping jobs in the polling job params + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = {} + session.commit() + + job_result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_polling_job_for_submission_run.id + ) + + assert job_result["status"] == "ok" + + # Verify the target gene uniprot id remains unchanged + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata is None + + # Verify that the polling job succeeded + session.refresh(sample_polling_job_for_submission_run) + assert sample_polling_job_for_submission_run.status == JobStatus.SUCCEEDED + + async def test_poll_uniprot_mapping_jobs_partial_mapping_jobs( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have two mapping jobs + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION}, + "2": {"job_id": None, "accession": "NONEXISTENT_AC"}, + } + session.commit() + + # Add another target gene to the score set to correspond to the second mapping job + new_target_gene = TargetGene( + score_set_id=sample_score_set.id, + name="TP53", + category="protein_coding", + target_sequence=TargetSequence(sequence="MEEPQSDPSV", sequence_type="protein"), + ) + mock_worker_ctx["db"].add(new_target_gene) + mock_worker_ctx["db"].commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + side_effect=[True], + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + side_effect=[TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE], + ), + ): + job_result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_polling_job_for_submission_run.id + ) + + assert job_result["status"] == "ok" + + # Verify the target gene uniprot id has been updated for the successful mapping and + # remains None for the mapping with no job id + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata == VALID_UNIPROT_ACCESSION + assert sample_score_set.target_genes[1].uniprot_id_from_mapped_metadata is None + + # Verify that the polling job succeeded + session.refresh(sample_polling_job_for_submission_run) + assert sample_polling_job_for_submission_run.status == JobStatus.SUCCEEDED + + async def test_poll_uniprot_mapping_jobs_results_not_ready( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=False, + ): + job_result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_polling_job_for_submission_run.id + ) + + assert job_result["status"] == "ok" + + # Verify the target gene uniprot id remains unchanged + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata is None + + # Verify that the polling job succeeded + # TODO#XXX -- For now, we mark the job as succeeded even if no updates were made. + # In the future, we may want to have the job indicate it should be retried. + session.refresh(sample_polling_job_for_submission_run) + assert sample_polling_job_for_submission_run.status == JobStatus.SUCCEEDED + + async def test_poll_uniprot_mapping_jobs_no_results( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=True, + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + return_value={"results": []}, # minimal response with no results + ), + ): + result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_polling_job_for_submission_run.id + ) + + assert result["status"] == "failed" + assert result["exception_details"]["type"] == "UniprotMappingResultNotFoundError" + + # Verify the target gene uniprot id remains unchanged + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata is None + + # Verify that the polling job failed + session.refresh(sample_polling_job_for_submission_run) + assert sample_polling_job_for_submission_run.status == JobStatus.FAILED + + async def test_poll_uniprot_mapping_jobs_ambiguous_results( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=True, + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + return_value={ + "results": [ + { + "from": VALID_NT_ACCESSION, + "to": { + "primaryAccession": f"{VALID_UNIPROT_ACCESSION}", + "entryType": TEST_UNIPROT_SWISS_PROT_TYPE, + }, + }, + { + "from": VALID_NT_ACCESSION, + "to": { + "primaryAccession": "P67890", + "entryType": TEST_UNIPROT_SWISS_PROT_TYPE, + }, + }, + ] + }, + ), + ): + result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_polling_job_for_submission_run.id + ) + + assert result["status"] == "failed" + assert result["exception_details"]["type"] == "UniprotAmbiguousMappingResultError" + + # Verify the target gene uniprot id remains unchanged + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata is None + + # Verify that the polling job failed + session.refresh(sample_polling_job_for_submission_run) + assert sample_polling_job_for_submission_run.status == JobStatus.FAILED + + async def test_poll_uniprot_mapping_jobs_nonexistent_target( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job with a non-existent target gene ID + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "999": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=True, + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, + ), + ): + result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_polling_job_for_submission_run.id + ) + + assert result["status"] == "failed" + assert result["exception_details"]["type"] == "NonExistentTargetGeneError" + + # Verify the target gene uniprot id remains unchanged + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata is None + + # Verify that the polling job failed + session.refresh(sample_polling_job_for_submission_run) + assert sample_polling_job_for_submission_run.status == JobStatus.FAILED + + async def test_poll_uniprot_mapping_jobs_propagates_exceptions_to_decorator( + self, + session, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + side_effect=Exception("UniProt API failure"), + ): + result = await poll_uniprot_mapping_jobs_for_score_set( + mock_worker_ctx, sample_polling_job_for_submission_run.id + ) + + assert result["status"] == "failed" + assert result["exception_details"]["message"] == "UniProt API failure" + + # Verify the target gene uniprot id remains unchanged + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata is None + + # Verify that the polling job failed + session.refresh(sample_polling_job_for_submission_run) + assert sample_polling_job_for_submission_run.status == JobStatus.FAILED + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestPollUniprotMappingJobsForScoreSetArqContext: + """Integration tests for poll_uniprot_mapping_jobs_for_score_set function with ARQ context.""" + + async def test_poll_uniprot_mapping_jobs_with_arq_context_independent( + self, + session, + arq_worker, + arq_redis, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + with_submit_uniprot_mapping_job, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=True, + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, + ), + ): + await arq_redis.enqueue_job( + "poll_uniprot_mapping_jobs_for_score_set", sample_polling_job_for_submission_run.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify the target gene uniprot id has been updated + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata == VALID_UNIPROT_ACCESSION + + # Verify that the polling job was completed successfully + session.refresh(sample_polling_job_for_submission_run) + assert sample_polling_job_for_submission_run.status == JobStatus.SUCCEEDED + + async def test_poll_uniprot_mapping_jobs_with_arq_context_pipeline( + self, + session, + arq_worker, + arq_redis, + with_populated_domain_data, + with_poll_uniprot_mapping_jobs_pipeline, + sample_score_set, + sample_poll_uniprot_mapping_jobs_run_in_pipeline, + sample_poll_uniprot_mapping_jobs_pipeline, + ): + # Add an accession to the target gene's post mapped metadata + target_gene = sample_score_set.target_genes[0] + target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} + session.commit() + + # Arrange the polling job params to have a single mapping job + sample_poll_uniprot_mapping_jobs_run_in_pipeline.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + return_value=True, + ), + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", + return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, + ), + ): + await arq_redis.enqueue_job( + "poll_uniprot_mapping_jobs_for_score_set", + sample_poll_uniprot_mapping_jobs_run_in_pipeline.id, + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify the target gene uniprot id has been updated + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata == VALID_UNIPROT_ACCESSION + + # Verify that the polling job was completed successfully + session.refresh(sample_poll_uniprot_mapping_jobs_run_in_pipeline) + assert sample_poll_uniprot_mapping_jobs_run_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify that the pipeline run status is succeeded (this is the only job in the test pipeline) + session.refresh(sample_poll_uniprot_mapping_jobs_pipeline) + assert sample_poll_uniprot_mapping_jobs_pipeline.status == PipelineStatus.SUCCEEDED + + async def test_poll_uniprot_mapping_jobs_with_arq_context_exception_handling_independent( + self, + session, + arq_worker, + arq_redis, + mock_worker_ctx, + with_populated_domain_data, + with_independent_polling_job_for_submission_run, + sample_score_set, + sample_polling_job_for_submission_run, + ): + # Arrange the polling job params to have a single mapping job + sample_polling_job_for_submission_run.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + side_effect=Exception("UniProt API failure"), + ), + ): + await arq_redis.enqueue_job( + "poll_uniprot_mapping_jobs_for_score_set", sample_polling_job_for_submission_run.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that the polling job failed + session.refresh(sample_polling_job_for_submission_run) + assert sample_polling_job_for_submission_run.status == JobStatus.FAILED + + # Verify the target gene uniprot id remains unchanged + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata is None + + async def test_poll_uniprot_mapping_jobs_with_arq_context_exception_handling_pipeline( + self, + session, + arq_worker, + arq_redis, + mock_worker_ctx, + with_populated_domain_data, + with_poll_uniprot_mapping_jobs_pipeline, + sample_score_set, + sample_poll_uniprot_mapping_jobs_run_in_pipeline, + sample_poll_uniprot_mapping_jobs_pipeline, + ): + # Arrange the polling job params to have a single mapping job + sample_poll_uniprot_mapping_jobs_run_in_pipeline.job_params["mapping_jobs"] = { + "1": {"job_id": "job_12345", "accession": VALID_NT_ACCESSION} + } + session.commit() + + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + side_effect=Exception("UniProt API failure"), + ), + ): + await arq_redis.enqueue_job( + "poll_uniprot_mapping_jobs_for_score_set", + sample_poll_uniprot_mapping_jobs_run_in_pipeline.id, + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that the polling job failed + session.refresh(sample_poll_uniprot_mapping_jobs_run_in_pipeline) + assert sample_poll_uniprot_mapping_jobs_run_in_pipeline.status == JobStatus.FAILED + + # Verify that the pipeline run status is failed + session.refresh(sample_poll_uniprot_mapping_jobs_pipeline) + assert sample_poll_uniprot_mapping_jobs_pipeline.status == PipelineStatus.FAILED + + # Verify the target gene uniprot id remains unchanged + session.refresh(sample_score_set) + assert sample_score_set.target_genes[0].uniprot_id_from_mapped_metadata is None From a06f351f1cc3cc767a855a11f0d2f651982faa14 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Tue, 27 Jan 2026 19:44:45 -0800 Subject: [PATCH 33/70] feat: clingen managed job enhancements - Adds comprehensive test cases for clingen managed jobs - Removes clingen linking via LDH. These IDs will always be linked via the CAR in future versions --- src/mavedb/lib/clingen/services.py | 55 +- src/mavedb/lib/exceptions.py | 6 + src/mavedb/scripts/link_clingen_variants.py | 75 - src/mavedb/worker/jobs/__init__.py | 2 - .../worker/jobs/external_services/__init__.py | 2 - .../worker/jobs/external_services/clingen.py | 203 +- src/mavedb/worker/jobs/registry.py | 2 - tests/helpers/util/setup/worker.py | 42 +- tests/lib/clingen/test_services.py | 74 +- tests/network/worker/test_clingen.py | 0 tests/worker/jobs/conftest.py | 807 ++++++ .../worker/jobs/external_services/conftest.py | 365 --- .../external_services/network/test_clingen.py | 134 + .../external_services/network/test_gnomad.py | 0 .../jobs/external_services/test_clingen.py | 2259 ++++++++++++++--- .../jobs/pipeline_management/conftest.py | 62 - .../jobs/variant_processing/conftest.py | 191 -- 17 files changed, 2919 insertions(+), 1360 deletions(-) delete mode 100644 src/mavedb/scripts/link_clingen_variants.py delete mode 100644 tests/network/worker/test_clingen.py create mode 100644 tests/worker/jobs/conftest.py delete mode 100644 tests/worker/jobs/external_services/conftest.py create mode 100644 tests/worker/jobs/external_services/network/test_clingen.py delete mode 100644 tests/worker/jobs/external_services/network/test_gnomad.py delete mode 100644 tests/worker/jobs/pipeline_management/conftest.py delete mode 100644 tests/worker/jobs/variant_processing/conftest.py diff --git a/src/mavedb/lib/clingen/services.py b/src/mavedb/lib/clingen/services.py index 1bcb7778..a9e41fcb 100644 --- a/src/mavedb/lib/clingen/services.py +++ b/src/mavedb/lib/clingen/services.py @@ -1,19 +1,16 @@ import hashlib import logging -import requests import os import time from datetime import datetime from typing import Optional -from urllib import parse - +import requests from jose import jwt -from mavedb.lib.logging.context import logging_context, save_to_logging_context, format_raised_exception_info_as_dict -from mavedb.lib.clingen.constants import GENBOREE_ACCOUNT_NAME, GENBOREE_ACCOUNT_PASSWORD, LDH_MAVE_ACCESS_ENDPOINT - -from mavedb.lib.types.clingen import LdhSubmission, ClinGenAllele +from mavedb.lib.clingen.constants import GENBOREE_ACCOUNT_NAME, GENBOREE_ACCOUNT_PASSWORD +from mavedb.lib.logging.context import format_raised_exception_info_as_dict, logging_context, save_to_logging_context +from mavedb.lib.types.clingen import ClinGenAllele, LdhSubmission from mavedb.lib.utils import batched logger = logging.getLogger(__name__) @@ -279,50 +276,6 @@ def _existing_jwt(self) -> Optional[str]: return None -def get_clingen_variation(urn: str) -> Optional[dict]: - """ - Fetches ClinGen variation data for a given URN (Uniform Resource Name) from the Linked Data Hub. - - Args: - urn (str): The URN of the variation to fetch. - - Returns: - Optional[dict]: A dictionary containing the variation data if the request is successful, - or None if the request fails. - """ - response = requests.get( - f"{LDH_MAVE_ACCESS_ENDPOINT}/{parse.quote_plus(urn)}", - headers={"Accept": "application/json"}, - ) - - if response.status_code == 200: - return response.json() - else: - logger.error(f"Failed to fetch data for URN {urn}: {response.status_code} - {response.text}") - return None - - -def clingen_allele_id_from_ldh_variation(variation: Optional[dict]) -> Optional[str]: - """ - Extracts the ClinGen allele ID from a given variation dictionary. - - Args: - variation (Optional[dict]): A dictionary containing variation data, otherwise None. - - Returns: - Optional[str]: The ClinGen allele ID if found, otherwise None. - """ - if not variation: - return None - - try: - return variation["data"]["ldFor"]["Variant"][0]["entId"] - except (KeyError, IndexError) as exc: - save_to_logging_context(format_raised_exception_info_as_dict(exc)) - logger.error("Failed to extract ClinGen allele ID from variation data.", extra=logging_context()) - return None - - def get_allele_registry_associations( content_submissions: list[str], submission_response: list[ClinGenAllele] ) -> dict[str, str]: diff --git a/src/mavedb/lib/exceptions.py b/src/mavedb/lib/exceptions.py index db7458f1..63e891a3 100644 --- a/src/mavedb/lib/exceptions.py +++ b/src/mavedb/lib/exceptions.py @@ -226,3 +226,9 @@ class NonExistentTargetGeneError(ValueError): """Raised when a target gene does not exist in the database.""" pass + + +class LDHSubmissionFailureError(Exception): + """Raised when submission to ClinGen Linked Data Hub (LDH) fails for all submissions.""" + + pass diff --git a/src/mavedb/scripts/link_clingen_variants.py b/src/mavedb/scripts/link_clingen_variants.py deleted file mode 100644 index 2ca3c069..00000000 --- a/src/mavedb/scripts/link_clingen_variants.py +++ /dev/null @@ -1,75 +0,0 @@ -import click -import logging -from typing import Sequence - -from sqlalchemy import and_, select -from sqlalchemy.orm import Session - -from mavedb.lib.clingen.services import get_clingen_variation, clingen_allele_id_from_ldh_variation -from mavedb.models.score_set import ScoreSet -from mavedb.models.variant import Variant -from mavedb.models.mapped_variant import MappedVariant -from mavedb.scripts.environment import with_database_session - -logger = logging.getLogger(__name__) - - -@click.command() -@with_database_session -@click.argument("urns", nargs=-1) -@click.option("--score-sets/--variants", default=False) -@click.option("--unlinked", default=False, is_flag=True) -def link_clingen_variants(db: Session, urns: Sequence[str], score_sets: bool, unlinked: bool) -> None: - """ - Submit data to ClinGen for mapped variant allele ID generation for the given URNs. - """ - if not urns: - logger.error("No URNs provided. Please provide at least one URN.") - return - - # Convert score set URNs to variant URNs. - if score_sets: - query = ( - select(Variant.urn) - .join(MappedVariant) - .join(ScoreSet) - .where(MappedVariant.current.is_(True), MappedVariant.post_mapped.is_not(None)) - ) - - if unlinked: - query = query.where(MappedVariant.clingen_allele_id.is_(None)) - - variants = [db.scalars(query.where(ScoreSet.urn == urn)).all() for urn in urns] - urns = [variant for sublist in variants for variant in sublist if variant is not None] - - failed_urns = [] - for urn in urns: - ldh_variation = get_clingen_variation(urn) - allele_id = clingen_allele_id_from_ldh_variation(ldh_variation) - - if not allele_id: - failed_urns.append(urn) - continue - - mapped_variant = db.scalar( - select(MappedVariant).join(Variant).where(and_(Variant.urn == urn, MappedVariant.current.is_(True))) - ) - - if not mapped_variant: - logger.warning(f"No mapped variant found for URN {urn}.") - failed_urns.append(urn) - continue - - mapped_variant.clingen_allele_id = allele_id - db.add(mapped_variant) - - logger.info(f"Successfully linked URN {urn} to ClinGen variation {allele_id}.") - - if failed_urns: - logger.warning(f"Failed to link the following {len(failed_urns)} URNs: {', '.join(failed_urns)}") - - logger.info(f"Linking process completed. Linked {len(urns) - len(failed_urns)}/{len(urns)} URNs successfully.") - - -if __name__ == "__main__": - link_clingen_variants() diff --git a/src/mavedb/worker/jobs/__init__.py b/src/mavedb/worker/jobs/__init__.py index a7a86a58..6a52927c 100644 --- a/src/mavedb/worker/jobs/__init__.py +++ b/src/mavedb/worker/jobs/__init__.py @@ -16,7 +16,6 @@ refresh_published_variants_view, ) from mavedb.worker.jobs.external_services.clingen import ( - link_clingen_variants, submit_score_set_mappings_to_car, submit_score_set_mappings_to_ldh, ) @@ -39,7 +38,6 @@ "create_variants_for_score_set", "map_variants_for_score_set", # External service integration jobs - "link_clingen_variants", "submit_score_set_mappings_to_car", "submit_score_set_mappings_to_ldh", "poll_uniprot_mapping_jobs_for_score_set", diff --git a/src/mavedb/worker/jobs/external_services/__init__.py b/src/mavedb/worker/jobs/external_services/__init__.py index 60135efe..eabe8ebe 100644 --- a/src/mavedb/worker/jobs/external_services/__init__.py +++ b/src/mavedb/worker/jobs/external_services/__init__.py @@ -8,7 +8,6 @@ # External services job functions from .clingen import ( - link_clingen_variants, submit_score_set_mappings_to_car, submit_score_set_mappings_to_ldh, ) @@ -19,7 +18,6 @@ ) __all__ = [ - "link_clingen_variants", "submit_score_set_mappings_to_car", "submit_score_set_mappings_to_ldh", "link_gnomad_variants", diff --git a/src/mavedb/worker/jobs/external_services/clingen.py b/src/mavedb/worker/jobs/external_services/clingen.py index 56b7a5f9..5d0de7f7 100644 --- a/src/mavedb/worker/jobs/external_services/clingen.py +++ b/src/mavedb/worker/jobs/external_services/clingen.py @@ -17,6 +17,7 @@ from mavedb.lib.clingen.constants import ( CAR_SUBMISSION_ENDPOINT, + CLIN_GEN_SUBMISSION_ENABLED, DEFAULT_LDH_SUBMISSION_BATCH_SIZE, LDH_SUBMISSION_ENDPOINT, ) @@ -24,10 +25,9 @@ from mavedb.lib.clingen.services import ( ClinGenAlleleRegistryService, ClinGenLdhService, - clingen_allele_id_from_ldh_variation, get_allele_registry_associations, - get_clingen_variation, ) +from mavedb.lib.exceptions import LDHSubmissionFailureError from mavedb.lib.variants import get_hgvs_from_post_mapped from mavedb.models.mapped_variant import MappedVariant from mavedb.models.score_set import ScoreSet @@ -85,6 +85,24 @@ async def submit_score_set_mappings_to_car(ctx: dict, job_id: int, job_manager: job_manager.update_progress(0, 100, "Starting CAR mapped resource submission.") logger.info(msg="Started CAR mapped resource submission", extra=job_manager.logging_context()) + # Ensure we've enabled ClinGen submission + if not CLIN_GEN_SUBMISSION_ENABLED: + job_manager.update_progress(100, 100, "ClinGen submission is disabled. Skipping CAR submission.") + logger.warning( + msg="ClinGen submission is disabled via configuration, skipping submission of mapped variants to CAR.", + extra=job_manager.logging_context(), + ) + return {"status": "ok", "data": {}, "exception_details": None} + + # Check for CAR submission endpoint + if not CAR_SUBMISSION_ENDPOINT: + job_manager.update_progress(100, 100, "CAR submission endpoint not configured. Can't complete submission.") + logger.warning( + msg="ClinGen Allele Registry submission is disabled (no submission endpoint), unable to complete submission of mapped variants to CAR.", + extra=job_manager.logging_context(), + ) + raise ValueError("ClinGen Allele Registry submission endpoint is not configured.") + # Fetch mapped variants with post-mapped data for the score set variant_post_mapped_objects = job_manager.db.execute( select(MappedVariant.id, MappedVariant.post_mapped) @@ -104,11 +122,12 @@ async def submit_score_set_mappings_to_car(ctx: dict, job_id: int, job_manager: extra=job_manager.logging_context(), ) return {"status": "ok", "data": {}, "exception_details": None} + job_manager.update_progress( 10, 100, f"Preparing {len(variant_post_mapped_objects)} mapped variants for CAR submission." ) - # Build HGVS strings for submission + # Build HGVS strings for submission. Don't do duplicate submissions-- store mapped variant IDs by HGVS. variant_post_mapped_hgvs: dict[str, list[int]] = {} for mapped_variant_id, post_mapped in variant_post_mapped_objects: hgvs_for_post_mapped = get_hgvs_from_post_mapped(post_mapped) @@ -124,22 +143,14 @@ async def submit_score_set_mappings_to_car(ctx: dict, job_id: int, job_manager: variant_post_mapped_hgvs[hgvs_for_post_mapped].append(mapped_variant_id) else: variant_post_mapped_hgvs[hgvs_for_post_mapped] = [mapped_variant_id] + job_manager.save_to_context({"unique_variants_to_submit_car": len(variant_post_mapped_hgvs)}) job_manager.update_progress(15, 100, "Submitting mapped variants to CAR.") - # Check for CAR submission endpoint - if not CAR_SUBMISSION_ENDPOINT: - job_manager.update_progress(100, 100, "CAR submission endpoint not configured. Skipping submission.") - logger.warning( - msg="ClinGen Allele Registry submission is disabled (no submission endpoint), skipping submission of mapped variants to CAR.", - extra=job_manager.logging_context(), - ) - raise ValueError("ClinGen Allele Registry submission endpoint is not configured.") - # Do submission car_service = ClinGenAlleleRegistryService(url=CAR_SUBMISSION_ENDPOINT) registered_alleles = car_service.dispatch_submissions(list(variant_post_mapped_hgvs.keys())) - job_manager.update_progress(50, 100, "Processing registered alleles from CAR.") + job_manager.update_progress(60, 100, "Processing registered alleles from CAR.") # Process registered alleles and update mapped variants linked_alleles = get_allele_registry_associations(list(variant_post_mapped_hgvs.keys()), registered_alleles) @@ -159,7 +170,7 @@ async def submit_score_set_mappings_to_car(ctx: dict, job_id: int, job_manager: # Calculate progress: 50% + (processed/total_mapped)*50, rounded to nearest 5% if total % 20 == 0 or processed == total: - progress = 50 + round((processed / total) * 50 / 5) * 5 + progress = 50 + round((processed / total) * 45 / 5) * 5 job_manager.update_progress(progress, 100, f"Processed {processed} of {total} registered alleles.") # Finalize progress @@ -170,7 +181,7 @@ async def submit_score_set_mappings_to_car(ctx: dict, job_id: int, job_manager: @with_pipeline_management -async def submit_score_set_mappings_to_ldh(ctx: dict, job_manager: JobManager) -> JobResultData: +async def submit_score_set_mappings_to_ldh(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData: """ Submit mapped variants for a score set to the ClinGen Linked Data Hub (LDH). @@ -252,6 +263,14 @@ async def submit_score_set_mappings_to_ldh(ctx: dict, job_manager: JobManager) - variant_content.append((variation, variant, mapped_variant)) + if not variant_content: + job_manager.update_progress(100, 100, "No valid mapped variants to submit to LDH. Skipping submission.") + logger.warning( + msg="No valid mapped variants with post mapped metadata were found for this score set. Skipping LDH submission.", + extra=job_manager.logging_context(), + ) + return {"status": "ok", "data": {}, "exception_details": None} + job_manager.save_to_context({"unique_variants_to_submit_ldh": len(variant_content)}) job_manager.update_progress(30, 100, f"Dispatching submissions for {len(variant_content)} unique variants to LDH.") submission_content = construct_ldh_submission(variant_content) @@ -262,154 +281,40 @@ async def submit_score_set_mappings_to_ldh(ctx: dict, job_manager: JobManager) - loop = asyncio.get_running_loop() submission_successes, submission_failures = await loop.run_in_executor(ctx["pool"], blocking) job_manager.update_progress(90, 100, "Finalizing LDH mapped resource submission.") - - # TODO: Track submission successes and failures, add as annotation features. - if submission_failures: - job_manager.save_to_context({"ldh_submission_failures": len(submission_failures)}) - logger.error( - msg=f"LDH mapped resource submission encountered {len(submission_failures)} failures.", - extra=job_manager.logging_context(), - ) - - # Finalize progress - job_manager.update_progress(100, 100, "Finalized LDH mapped resource submission.") - job_manager.db.commit() - return {"status": "ok", "data": {}, "exception_details": None} - - -def do_clingen_fetch(variant_urns): - return [(variant_urn, get_clingen_variation(variant_urn)) for variant_urn in variant_urns] - - -@with_pipeline_management -async def link_clingen_variants(ctx: dict, job_manager: JobManager) -> JobResultData: - """ - Link mapped variants to ClinGen Linked Data Hub (LDH) submissions. - - This job links mapped variant data to existing LDH data for a given score set. It fetches - LDH variations for each mapped variant and updates the database accordingly. Progress - and errors are logged throughout the process. - - Required job_params in the JobRun: - - score_set_id (int): ID of the ScoreSet to process - - correlation_id (str): Correlation ID for tracking - - Args: - ctx (dict): Worker context containing DB and Redis connections - job_manager (JobManager): Manager for job lifecycle and DB operations - - Side Effects: - - Updates MappedVariant records with ClinGen Allele IDs from LDH objects - - Returns: - dict: Result indicating success and any exception details - """ - # Get the job definition we are working on - job = job_manager.get_job() - - _job_required_params = ["score_set_id", "correlation_id"] - validate_job_params(_job_required_params, job) - - # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. - score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one() # type: ignore - correlation_id = job.job_params["correlation_id"] # type: ignore - - # Setup initial context and progress job_manager.save_to_context( { - "application": "mavedb-worker", - "function": "link_clingen_variants", - "resource": score_set.urn, - "correlation_id": correlation_id, + "ldh_submission_successes": len(submission_successes), + "ldh_submission_failures": len(submission_failures), } ) - job_manager.update_progress(0, 100, "Starting LDH mapped resource linkage.") - logger.info(msg="Started LDH mapped resource linkage", extra=job_manager.logging_context()) - - # Fetch mapped variants with post-mapped data for the score set - variant_urns = job_manager.db.scalars( - select(Variant.urn) - .join(MappedVariant) - .join(ScoreSet) - .where(ScoreSet.urn == score_set.urn, MappedVariant.current.is_(True), MappedVariant.post_mapped.is_not(None)) - ).all() - num_variant_urns = len(variant_urns) - - job_manager.save_to_context({"total_variants_to_link_ldh": num_variant_urns}) - job_manager.update_progress(10, 100, f"Found {num_variant_urns} mapped variants to link to LDH submissions.") - if not variant_urns: - job_manager.update_progress(100, 100, "No mapped variants to link to LDH submissions. Skipping linkage.") + # TODO: Track submission successes and failures, add as annotation features. + if submission_failures: logger.warning( - msg="No current mapped variants with post mapped metadata were found for this score set. Skipping LDH linkage (nothing to do). A gnomAD linkage job will not be enqueued, as no variants will have a CAID.", + msg=f"LDH mapped resource submission encountered {len(submission_failures)} failures.", extra=job_manager.logging_context(), ) - return {"status": "ok", "data": {}, "exception_details": None} - logger.info(msg="Attempting to link mapped variants to LDH submissions.", extra=job_manager.logging_context()) - - # TODO#372: Non-nullable variant urns. - # Fetch linked data from LDH for each variant URN - blocking = functools.partial( - do_clingen_fetch, - variant_urns, # type: ignore - ) - loop = asyncio.get_running_loop() - linked_data = await loop.run_in_executor(ctx["pool"], blocking) - - linked_allele_ids = [ - (variant_urn, clingen_allele_id_from_ldh_variation(clingen_variation)) - for variant_urn, clingen_variation in linked_data - ] - job_manager.save_to_context({"ldh_variants_fetched": len(linked_allele_ids)}) - job_manager.update_progress(70, 100, "Fetched existing LDH variant data.") - logger.info(msg="Fetched existing LDH variant data.", extra=job_manager.logging_context()) - - # Link mapped variants to fetched LDH data - linkage_failures = [] - for variant_urn, ldh_variation in linked_allele_ids: - # XXX: Should we unlink variation if it is not found? Does this constitute a failure? - if not ldh_variation: - logger.warning( - msg=f"Failed to link mapped variant {variant_urn} to LDH submission. No LDH variation found.", + if not submission_successes: + job_manager.update_progress(100, 100, "All mapped variant submissions to LDH failed.") + error_message = f"All LDH submissions failed for score set {score_set.urn}." + logger.error( + msg=error_message, extra=job_manager.logging_context(), ) - linkage_failures.append(variant_urn) - continue - mapped_variant = job_manager.db.scalars( - select(MappedVariant).join(Variant).where(Variant.urn == variant_urn, MappedVariant.current.is_(True)) - ).one_or_none() + raise LDHSubmissionFailureError(error_message) - if not mapped_variant: - logger.warning( - msg=f"Failed to link mapped variant {variant_urn} to LDH submission. No mapped variant found.", - extra=job_manager.logging_context(), - ) - linkage_failures.append(variant_urn) - continue - - mapped_variant.clingen_allele_id = ldh_variation - job_manager.db.add(mapped_variant) - - # TODO: Track annotation progress. Given the new progress model, we can better understand what linked and what didn't and - # can move away from the retry threshold model. - - # Calculate progress: 70% + (linked/total_variants)*30, rounded to nearest 5% - if len(linked_allele_ids) % 20 == 0 or len(linked_allele_ids) == num_variant_urns: - progress = 70 + round((len(linked_allele_ids) / num_variant_urns) * 30 / 5) * 5 - job_manager.update_progress( - progress, 100, f"Linked {len(linked_allele_ids)} of {num_variant_urns} variants." - ) - - job_manager.save_to_context({"ldh_linkage_failures": len(linkage_failures)}) - if linkage_failures: - logger.warning( - msg=f"LDH mapped resource linkage encountered {len(linkage_failures)} failures.", - extra=job_manager.logging_context(), - ) + logger.info( + msg="Completed LDH mapped resource submission", + extra=job_manager.logging_context(), + ) # Finalize progress - job_manager.update_progress(100, 100, "Finalized LDH mapped resource linkage.") + job_manager.update_progress( + 100, + 100, + f"Finalized LDH mapped resource submission ({len(submission_successes)} successes, {len(submission_failures)} failures).", + ) job_manager.db.commit() return {"status": "ok", "data": {}, "exception_details": None} diff --git a/src/mavedb/worker/jobs/registry.py b/src/mavedb/worker/jobs/registry.py index 60654170..251d87c8 100644 --- a/src/mavedb/worker/jobs/registry.py +++ b/src/mavedb/worker/jobs/registry.py @@ -14,7 +14,6 @@ refresh_published_variants_view, ) from mavedb.worker.jobs.external_services import ( - link_clingen_variants, link_gnomad_variants, poll_uniprot_mapping_jobs_for_score_set, submit_score_set_mappings_to_car, @@ -35,7 +34,6 @@ # External service jobs submit_score_set_mappings_to_car, submit_score_set_mappings_to_ldh, - link_clingen_variants, submit_uniprot_mapping_jobs_for_score_set, poll_uniprot_mapping_jobs_for_score_set, link_gnomad_variants, diff --git a/tests/helpers/util/setup/worker.py b/tests/helpers/util/setup/worker.py index 91aadb81..dd4473bc 100644 --- a/tests/helpers/util/setup/worker.py +++ b/tests/helpers/util/setup/worker.py @@ -10,6 +10,7 @@ create_variants_for_score_set, map_variants_for_score_set, ) +from mavedb.worker.lib.managers.job_manager import JobManager from tests.helpers.constants import ( TEST_CODING_LAYER, TEST_GENE_INFO, @@ -32,7 +33,19 @@ async def create_variants_in_score_set( side_effect=[score_df, count_df], ), ): - result = await create_variants_for_score_set(mock_worker_ctx, variant_creation_run.id) + # Guard against both possible function signatures, with some uses of this function coming from + # integration tests that need not pass a JobManager. + try: + result = await create_variants_for_score_set( + mock_worker_ctx, + variant_creation_run.id, + ) + except TypeError: + result = await create_variants_for_score_set( + mock_worker_ctx, + variant_creation_run.id, + JobManager(mock_worker_ctx["db"], mock_worker_ctx["redis"], variant_creation_run.id), + ) assert result["status"] == "ok" session.commit() @@ -41,10 +54,14 @@ async def create_variants_in_score_set( async def create_mappings_in_score_set( session, mock_s3_client, mock_worker_ctx, score_df, count_df, variant_creation_run, variant_mapping_run ): - score_set = await create_variants_in_score_set( + await create_variants_in_score_set( session, mock_s3_client, score_df, count_df, mock_worker_ctx, variant_creation_run ) + score_set = session.execute( + select(ScoreSetDbModel).where(ScoreSetDbModel.id == variant_creation_run.job_params["score_set_id"]) + ).scalar_one() + async def dummy_mapping_job(): return await construct_mock_mapping_output(session, score_set, with_layers={"g", "c", "p"}) @@ -54,9 +71,17 @@ async def dummy_mapping_job(): "run_in_executor", return_value=dummy_mapping_job(), ), - patch("mavedb.worker.jobs.variant_processing.mapping.CLIN_GEN_SUBMISSION_ENABLED", False), ): - result = await map_variants_for_score_set(mock_worker_ctx, variant_mapping_run.id) + # Guard against both possible function signatures, with some uses of this function coming from + # integration tests that need not pass a JobManager. + try: + result = await map_variants_for_score_set(mock_worker_ctx, variant_mapping_run.id) + except TypeError: + result = await map_variants_for_score_set( + mock_worker_ctx, + variant_mapping_run.id, + JobManager(mock_worker_ctx["db"], mock_worker_ctx["redis"], variant_mapping_run.id), + ) assert result["status"] == "ok" session.commit() @@ -98,11 +123,16 @@ async def construct_mock_mapping_output( for idx, variant in enumerate(variants): mapped_score = { - "pre_mapped": TEST_VALID_PRE_MAPPED_VRS_ALLELE_VRS2_X if with_pre_mapped else {}, - "post_mapped": TEST_VALID_POST_MAPPED_VRS_ALLELE_VRS2_X if with_post_mapped else {}, + "pre_mapped": deepcopy(TEST_VALID_PRE_MAPPED_VRS_ALLELE_VRS2_X) if with_pre_mapped else {}, + "post_mapped": deepcopy(TEST_VALID_POST_MAPPED_VRS_ALLELE_VRS2_X) if with_post_mapped else {}, "mavedb_id": variant.urn, } + # Don't alter HGVS strings in post mapped output. This makes it considerably + # easier to assert correctness in tests. + if with_post_mapped: + mapped_score["post_mapped"]["expressions"][0]["value"] = variant.hgvs_nt or variant.hgvs_pro + # Skip every other variant if not with_all_variants if not with_all_variants and idx % 2 == 0: mapped_score["post_mapped"] = {} diff --git a/tests/lib/clingen/test_services.py b/tests/lib/clingen/test_services.py index 34828649..7141eea3 100644 --- a/tests/lib/clingen/test_services.py +++ b/tests/lib/clingen/test_services.py @@ -1,27 +1,23 @@ # ruff: noqa: E402 import os +from datetime import datetime +from unittest.mock import MagicMock, patch + import pytest import requests -from datetime import datetime -from unittest.mock import patch, MagicMock -from urllib import parse arq = pytest.importorskip("arq") cdot = pytest.importorskip("cdot") fastapi = pytest.importorskip("fastapi") -from mavedb.lib.clingen.constants import LDH_MAVE_ACCESS_ENDPOINT, GENBOREE_ACCOUNT_NAME, GENBOREE_ACCOUNT_PASSWORD -from mavedb.lib.utils import batched +from mavedb.lib.clingen.constants import GENBOREE_ACCOUNT_NAME, GENBOREE_ACCOUNT_PASSWORD from mavedb.lib.clingen.services import ( ClinGenAlleleRegistryService, ClinGenLdhService, - get_clingen_variation, - clingen_allele_id_from_ldh_variation, get_allele_registry_associations, ) - -from tests.helpers.constants import VALID_CLINGEN_CA_ID +from mavedb.lib.utils import batched TEST_CLINGEN_URL = "https://pytest.clingen.com" TEST_CAR_URL = "https://pytest.car.clingen.com" @@ -219,66 +215,6 @@ def test_dispatch_submissions_no_batching(self, mock_batched, mock_authenticate, ) -@patch("mavedb.lib.clingen.services.requests.get") -def test_get_clingen_variation_success(mock_get): - mocked_response_json = {"data": {"ldFor": {"Variant": [{"id": "variant_1", "name": "Test Variant"}]}}} - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = mocked_response_json - mock_get.return_value = mock_response - - urn = "urn:example:variant" - result = get_clingen_variation(urn) - - assert result == mocked_response_json - mock_get.assert_called_once_with( - f"{LDH_MAVE_ACCESS_ENDPOINT}/{parse.quote_plus(urn)}", - headers={"Accept": "application/json"}, - ) - - -@patch("mavedb.lib.clingen.services.requests.get") -def test_get_clingen_variation_failure(mock_get): - mock_response = MagicMock() - mock_response.status_code = 404 - mock_response.text = "Not Found" - mock_get.return_value = mock_response - - urn = "urn:example:nonexistent_variant" - result = get_clingen_variation(urn) - - assert result is None - mock_get.assert_called_once_with( - f"{LDH_MAVE_ACCESS_ENDPOINT}/{parse.quote_plus(urn)}", - headers={"Accept": "application/json"}, - ) - - -def test_clingen_allele_id_from_ldh_variation_success(): - variation = {"data": {"ldFor": {"Variant": [{"entId": VALID_CLINGEN_CA_ID}]}}} - result = clingen_allele_id_from_ldh_variation(variation) - assert result == VALID_CLINGEN_CA_ID - - -def test_clingen_allele_id_from_ldh_variation_missing_key(): - variation = {"data": {"ldFor": {"Variant": []}}} - - result = clingen_allele_id_from_ldh_variation(variation) - assert result is None - - -def test_clingen_allele_id_from_ldh_variation_no_variation(): - result = clingen_allele_id_from_ldh_variation(None) - assert result is None - - -def test_clingen_allele_id_from_ldh_variation_key_error(): - variation = {"data": {}} - - result = clingen_allele_id_from_ldh_variation(variation) - assert result is None - - class TestClinGenAlleleRegistryService: def test_init(self, car_service): assert car_service.url == TEST_CAR_URL diff --git a/tests/network/worker/test_clingen.py b/tests/network/worker/test_clingen.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/worker/jobs/conftest.py b/tests/worker/jobs/conftest.py new file mode 100644 index 00000000..7310d9d6 --- /dev/null +++ b/tests/worker/jobs/conftest.py @@ -0,0 +1,807 @@ +from unittest import mock + +import pytest +from mypy_boto3_s3 import S3Client + +from mavedb.models.enums.job_pipeline import DependencyType +from mavedb.models.job_dependency import JobDependency +from mavedb.models.job_run import JobRun +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.pipeline import Pipeline +from mavedb.models.score_set import ScoreSet +from mavedb.models.variant import Variant + + +@pytest.fixture +def mock_s3_client(): + """Mock S3 client for tests that interact with S3.""" + + with mock.patch("mavedb.worker.jobs.variant_processing.creation.s3_client") as mock_s3_client_func: + mock_s3 = mock.MagicMock(spec=S3Client) + mock_s3_client_func.return_value = mock_s3 + yield mock_s3 + + +## param fixtures for job runs ## + + +@pytest.fixture +def create_variants_sample_params(with_populated_domain_data, sample_score_set, sample_user): + """Provide sample parameters for create_variants_for_score_set job.""" + + return { + "scores_file_key": "sample_scores.csv", + "counts_file_key": "sample_counts.csv", + "correlation_id": "sample-correlation-id", + "updater_id": sample_user.id, + "score_set_id": sample_score_set.id, + "score_columns_metadata": {"s_0": {"description": "metadataS", "details": "detailsS"}}, + "count_columns_metadata": {"c_0": {"description": "metadataC", "details": "detailsC"}}, + } + + +@pytest.fixture +def map_variants_sample_params(with_populated_domain_data, sample_score_set, sample_user): + """Provide sample parameters for map_variants_for_score_set job.""" + + return { + "score_set_id": sample_score_set.id, + "correlation_id": "sample-mapping-correlation-id", + "updater_id": sample_user.id, + } + + +@pytest.fixture +def link_gnomad_variants_sample_params(with_populated_domain_data, sample_score_set): + """Provide sample parameters for create_variants_for_score_set job.""" + + return { + "correlation_id": "sample-correlation-id", + "score_set_id": sample_score_set.id, + } + + +@pytest.fixture +def submit_uniprot_mapping_jobs_sample_params(with_populated_domain_data, sample_score_set): + """Provide sample parameters for submit_uniprot_mapping_jobs_for_score_set job.""" + + return { + "correlation_id": "sample-correlation-id", + "score_set_id": sample_score_set.id, + } + + +@pytest.fixture +def poll_uniprot_mapping_jobs_sample_params( + submit_uniprot_mapping_jobs_sample_params, + with_dependent_polling_job_for_submission_run, +): + """Provide sample parameters for poll_uniprot_mapping_jobs_for_score_set job.""" + + return { + "correlation_id": submit_uniprot_mapping_jobs_sample_params["correlation_id"], + "score_set_id": submit_uniprot_mapping_jobs_sample_params["score_set_id"], + "mapping_jobs": {}, + } + + +@pytest.fixture +def submit_score_set_mappings_to_car_params(with_populated_domain_data, sample_score_set): + """Provide sample parameters for submit_score_set_mappings_to_car job.""" + + return { + "correlation_id": "sample-correlation-id", + "score_set_id": sample_score_set.id, + } + + +## Sample pipeline + + +@pytest.fixture +def sample_pipeline(): + """Create a sample Pipeline instance for testing.""" + + return Pipeline( + name="Sample Pipeline", + description="A sample pipeline for testing purposes", + ) + + +@pytest.fixture +def with_sample_pipeline(session, sample_pipeline): + """Fixture to ensure sample pipeline exists in the database.""" + session.add(sample_pipeline) + session.commit() + + +## Variant creation job fixtures + + +@pytest.fixture +def dummy_variant_creation_job_run(create_variants_sample_params): + """Create a dummy variant creation job run for testing.""" + + return JobRun( + urn="test:dummy_variant_creation_job", + job_type="dummy_variant_creation", + job_function="dummy_variant_creation_function", + max_retries=3, + retry_count=0, + job_params=create_variants_sample_params, + ) + + +@pytest.fixture +def dummy_variant_mapping_job_run(map_variants_sample_params): + """Create a dummy variant mapping job run for testing.""" + + return JobRun( + urn="test:dummy_variant_mapping_job", + job_type="dummy_variant_mapping", + job_function="dummy_variant_mapping_function", + max_retries=3, + retry_count=0, + job_params=map_variants_sample_params, + ) + + +@pytest.fixture +def with_dummy_setup_jobs( + session, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, +): + """Add dummy variant creation and mapping job runs to the session.""" + + session.add(dummy_variant_creation_job_run) + session.add(dummy_variant_mapping_job_run) + session.commit() + + +## Gnomad Linkage Job Fixtures ## + + +@pytest.fixture +def sample_link_gnomad_variants_pipeline(): + """Create a pipeline instance for link_gnomad_variants job.""" + + return Pipeline( + urn="test:link_gnomad_variants_pipeline", + name="Link gnomAD Variants Pipeline", + ) + + +@pytest.fixture +def sample_link_gnomad_variants_run(link_gnomad_variants_sample_params): + """Create a JobRun instance for link_gnomad_variants job.""" + + return JobRun( + urn="test:link_gnomad_variants", + job_type="link_gnomad_variants", + job_function="link_gnomad_variants", + max_retries=3, + retry_count=0, + job_params=link_gnomad_variants_sample_params, + ) + + +@pytest.fixture +def with_gnomad_linking_job(session, sample_link_gnomad_variants_run): + """Add a link_gnomad_variants job run to the session.""" + + session.add(sample_link_gnomad_variants_run) + session.commit() + + +@pytest.fixture +def with_gnomad_linking_pipeline(session, sample_link_gnomad_variants_pipeline): + """Add a link_gnomad_variants pipeline to the session.""" + + session.add(sample_link_gnomad_variants_pipeline) + session.commit() + + +@pytest.fixture +def sample_link_gnomad_variants_run_pipeline( + session, + with_gnomad_linking_job, + with_gnomad_linking_pipeline, + sample_link_gnomad_variants_run, + sample_link_gnomad_variants_pipeline, +): + """Provide a context with a link_gnomad_variants job run and pipeline.""" + + sample_link_gnomad_variants_run.pipeline_id = sample_link_gnomad_variants_pipeline.id + session.commit() + return sample_link_gnomad_variants_run + + +@pytest.fixture +def setup_sample_variants_with_caid(with_populated_domain_data, mock_worker_ctx, sample_link_gnomad_variants_run): + """Setup variants and mapped variants in the database for testing.""" + session = mock_worker_ctx["db"] + score_set = session.get(ScoreSet, sample_link_gnomad_variants_run.job_params["score_set_id"]) + + # Add a variant and mapped variant to the database with a CAID + variant = Variant( + urn="urn:variant:test-variant-with-caid", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.1A>G", + hgvs_pro="NP_000000.1:p.Met1Val", + data={"hgvs_c": "NM_000000.1:c.1A>G", "hgvs_p": "NP_000000.1:p.Met1Val"}, + ) + session.add(variant) + session.commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + clingen_allele_id="CA123", + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant) + session.commit() + + +## Uniprot Job Fixtures ## + + +@pytest.fixture +def sample_submit_uniprot_mapping_jobs_pipeline(): + """Create a pipeline instance for submit_uniprot_mapping_jobs_for_score_set job.""" + + return Pipeline( + urn="test:submit_uniprot_mapping_jobs_pipeline", + name="Submit UniProt Mapping Jobs Pipeline", + ) + + +@pytest.fixture +def sample_poll_uniprot_mapping_jobs_pipeline(): + """Create a pipeline instance for poll_uniprot_mapping_jobs_for_score_set job.""" + + return Pipeline( + urn="test:poll_uniprot_mapping_jobs_pipeline", + name="Poll UniProt Mapping Jobs Pipeline", + ) + + +@pytest.fixture +def sample_submit_uniprot_mapping_jobs_run(submit_uniprot_mapping_jobs_sample_params): + """Create a JobRun instance for submit_uniprot_mapping_jobs_for_score_set job.""" + + return JobRun( + urn="test:submit_uniprot_mapping_jobs", + job_type="submit_uniprot_mapping_jobs", + job_function="submit_uniprot_mapping_jobs_for_score_set", + max_retries=3, + retry_count=0, + job_params=submit_uniprot_mapping_jobs_sample_params, + ) + + +@pytest.fixture +def sample_dummy_polling_job_for_submission_run( + session, + with_submit_uniprot_mapping_job, + sample_submit_uniprot_mapping_jobs_run, +): + """Create a sample dummy dependent polling job for the submission run.""" + + dependent_job = JobRun( + urn="test:dummy_poll_uniprot_mapping_jobs", + job_type="dummy_poll_uniprot_mapping_jobs", + job_function="dummy_arq_function", + max_retries=3, + retry_count=0, + job_params={ + "correlation_id": sample_submit_uniprot_mapping_jobs_run.job_params["correlation_id"], + "score_set_id": sample_submit_uniprot_mapping_jobs_run.job_params["score_set_id"], + "mapping_jobs": {}, + }, + ) + + return dependent_job + + +@pytest.fixture +def sample_polling_job_for_submission_run( + session, + with_submit_uniprot_mapping_job, + sample_submit_uniprot_mapping_jobs_run, +): + """Create a sample dependent polling job for the submission run.""" + + dependent_job = JobRun( + urn="test:dependent_poll_uniprot_mapping_jobs", + job_type="dependent_poll_uniprot_mapping_jobs", + job_function="poll_uniprot_mapping_jobs_for_score_set", + max_retries=3, + retry_count=0, + job_params={ + "correlation_id": sample_submit_uniprot_mapping_jobs_run.job_params["correlation_id"], + "score_set_id": sample_submit_uniprot_mapping_jobs_run.job_params["score_set_id"], + "mapping_jobs": {}, + }, + ) + + return dependent_job + + +@pytest.fixture +def with_dummy_polling_job_for_submission_run( + session, + with_submit_uniprot_mapping_job, + sample_submit_uniprot_mapping_jobs_run, + sample_dummy_polling_job_for_submission_run, +): + """Create a sample dummy dependent polling job for the submission run.""" + session.add(sample_dummy_polling_job_for_submission_run) + session.commit() + + dependency = JobDependency( + id=sample_dummy_polling_job_for_submission_run.id, + depends_on_job_id=sample_submit_uniprot_mapping_jobs_run.id, + dependency_type=DependencyType.SUCCESS_REQUIRED, + ) + session.add(dependency) + session.commit() + + +@pytest.fixture +def with_dependent_polling_job_for_submission_run( + session, + with_submit_uniprot_mapping_job, + sample_submit_uniprot_mapping_jobs_run, + sample_polling_job_for_submission_run, +): + """Create a sample dependent polling job for the submission run.""" + session.add(sample_polling_job_for_submission_run) + session.commit() + + dependency = JobDependency( + id=sample_polling_job_for_submission_run.id, + depends_on_job_id=sample_submit_uniprot_mapping_jobs_run.id, + dependency_type=DependencyType.SUCCESS_REQUIRED, + ) + session.add(dependency) + session.commit() + + +@pytest.fixture +def with_independent_polling_job_for_submission_run( + session, + sample_polling_job_for_submission_run, +): + """Create a sample dependent polling job for the submission run.""" + session.add(sample_polling_job_for_submission_run) + session.commit() + + +@pytest.fixture +def with_submit_uniprot_mapping_job(session, sample_submit_uniprot_mapping_jobs_run): + """Add a submit_uniprot_mapping_jobs job run to the session.""" + + session.add(sample_submit_uniprot_mapping_jobs_run) + session.commit() + + +@pytest.fixture +def with_poll_uniprot_mapping_job(session, sample_poll_uniprot_mapping_jobs_run): + """Add a poll_uniprot_mapping_jobs job run to the session.""" + + session.add(sample_poll_uniprot_mapping_jobs_run) + session.commit() + + +@pytest.fixture +def sample_submit_uniprot_mapping_jobs_run_in_pipeline( + session, + with_submit_uniprot_mapping_job, + with_submit_uniprot_mapping_jobs_pipeline, + sample_submit_uniprot_mapping_jobs_run, + sample_submit_uniprot_mapping_jobs_pipeline, +): + """Provide a context with a submit_uniprot_mapping_jobs job run and pipeline.""" + + sample_submit_uniprot_mapping_jobs_run.pipeline_id = sample_submit_uniprot_mapping_jobs_pipeline.id + session.commit() + return sample_submit_uniprot_mapping_jobs_run + + +@pytest.fixture +def sample_poll_uniprot_mapping_jobs_run_in_pipeline( + session, + with_independent_polling_job_for_submission_run, + with_poll_uniprot_mapping_jobs_pipeline, + sample_polling_job_for_submission_run, + sample_poll_uniprot_mapping_jobs_pipeline, +): + """Provide a context with a poll_uniprot_mapping_jobs job run and pipeline.""" + + sample_polling_job_for_submission_run.pipeline_id = sample_poll_uniprot_mapping_jobs_pipeline.id + session.commit() + return sample_polling_job_for_submission_run + + +@pytest.fixture +def sample_dummy_polling_job_for_submission_run_in_pipeline( + session, + with_dummy_polling_job_for_submission_run, + with_submit_uniprot_mapping_jobs_pipeline, + with_submit_uniprot_mapping_job, + sample_submit_uniprot_mapping_jobs_pipeline, + sample_submit_uniprot_mapping_jobs_run_in_pipeline, + sample_dummy_polling_job_for_submission_run, +): + """Provide a context with a dependent polling job run in the pipeline.""" + + dependent_job = sample_dummy_polling_job_for_submission_run + dependent_job.pipeline_id = sample_submit_uniprot_mapping_jobs_pipeline.id + session.commit() + return dependent_job + + +@pytest.fixture +def sample_polling_job_for_submission_run_in_pipeline( + session, + with_dependent_polling_job_for_submission_run, + with_submit_uniprot_mapping_jobs_pipeline, + with_submit_uniprot_mapping_job, + sample_submit_uniprot_mapping_jobs_pipeline, + sample_submit_uniprot_mapping_jobs_run_in_pipeline, + sample_polling_job_for_submission_run, +): + """Provide a context with a dependent polling job run in the pipeline.""" + + dependent_job = sample_polling_job_for_submission_run + dependent_job.pipeline_id = sample_submit_uniprot_mapping_jobs_pipeline.id + session.commit() + return dependent_job + + +@pytest.fixture +def with_submit_uniprot_mapping_jobs_pipeline( + session, + sample_submit_uniprot_mapping_jobs_pipeline, +): + """Add a submit_uniprot_mapping_jobs pipeline to the session.""" + + session.add(sample_submit_uniprot_mapping_jobs_pipeline) + session.commit() + + +@pytest.fixture +def with_poll_uniprot_mapping_jobs_pipeline( + session, + sample_poll_uniprot_mapping_jobs_pipeline, +): + """Add a poll_uniprot_mapping_jobs pipeline to the session.""" + session.add(sample_poll_uniprot_mapping_jobs_pipeline) + session.commit() + + +## Clingen Job Fixtures ## + + +@pytest.fixture +def submit_score_set_mappings_to_car_sample_pipeline(): + """Create a pipeline instance for submit_score_set_mappings_to_car job.""" + + return Pipeline( + urn="test:submit_score_set_mappings_to_car_pipeline", + name="Submit Score Set Mappings to ClinGen Allele Registry Pipeline", + ) + + +@pytest.fixture +def submit_score_set_mappings_to_ldh_sample_pipeline(): + """Create a pipeline instance for submit_score_set_mappings_to_ldh job.""" + + return Pipeline( + urn="test:submit_score_set_mappings_to_ldh_pipeline", + name="Submit Score Set Mappings to ClinGen Allele Registry Pipeline", + ) + + +@pytest.fixture +def submit_score_set_mappings_to_car_sample_job_run(submit_score_set_mappings_to_car_params): + """Create a JobRun instance for submit_score_set_mappings_to_car job.""" + + return JobRun( + urn="test:submit_score_set_mappings_to_car", + job_type="submit_score_set_mappings_to_car", + job_function="submit_score_set_mappings_to_car", + max_retries=3, + retry_count=0, + job_params=submit_score_set_mappings_to_car_params, + ) + + +@pytest.fixture +def submit_score_set_mappings_to_ldh_sample_job_run(submit_score_set_mappings_to_car_params): + """Create a JobRun instance for submit_score_set_mappings_to_car job.""" + + return JobRun( + urn="test:submit_score_set_mappings_to_car", + job_type="submit_score_set_mappings_to_car", + job_function="submit_score_set_mappings_to_car", + max_retries=3, + retry_count=0, + job_params=submit_score_set_mappings_to_car_params, + ) + + +@pytest.fixture +def submit_score_set_mappings_to_car_sample_job_run_in_pipeline( + session, + with_submit_score_set_mappings_to_car_pipeline, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_pipeline, + submit_score_set_mappings_to_car_sample_job_run, +): + """Provide a context with a submit_score_set_mappings_to_car job run and pipeline.""" + + submit_score_set_mappings_to_car_sample_job_run.pipeline_id = submit_score_set_mappings_to_car_sample_pipeline.id + session.commit() + return submit_score_set_mappings_to_car_sample_job_run + + +@pytest.fixture +def submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline( + session, + with_submit_score_set_mappings_to_ldh_pipeline, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_pipeline, + submit_score_set_mappings_to_ldh_sample_job_run, +): + """Provide a context with a submit_score_set_mappings_to_ldh job run and pipeline.""" + + submit_score_set_mappings_to_ldh_sample_job_run.pipeline_id = submit_score_set_mappings_to_ldh_sample_pipeline.id + session.commit() + return submit_score_set_mappings_to_ldh_sample_job_run + + +@pytest.fixture +def with_submit_score_set_mappings_to_car_job(session, submit_score_set_mappings_to_car_sample_job_run): + """Add a submit_score_set_mappings_to_car job run to the session.""" + + session.add(submit_score_set_mappings_to_car_sample_job_run) + session.commit() + + +@pytest.fixture +def with_submit_score_set_mappings_to_ldh_job(session, submit_score_set_mappings_to_ldh_sample_job_run): + """Add a submit_score_set_mappings_to_ldh job run to the session.""" + + session.add(submit_score_set_mappings_to_ldh_sample_job_run) + session.commit() + + +@pytest.fixture +def with_submit_score_set_mappings_to_car_pipeline( + session, + submit_score_set_mappings_to_car_sample_pipeline, +): + """Add a submit_score_set_mappings_to_car pipeline to the session.""" + + session.add(submit_score_set_mappings_to_car_sample_pipeline) + session.commit() + + +@pytest.fixture +def with_submit_score_set_mappings_to_ldh_pipeline( + session, + submit_score_set_mappings_to_ldh_sample_pipeline, +): + """Add a submit_score_set_mappings_to_ldh pipeline to the session.""" + + session.add(submit_score_set_mappings_to_ldh_sample_pipeline) + session.commit() + + +@pytest.fixture +def sample_independent_variant_creation_run(create_variants_sample_params): + """Create a JobRun instance for variant creation job.""" + + return JobRun( + urn="test:create_variants_for_score_set", + job_type="create_variants_for_score_set", + job_function="create_variants_for_score_set", + max_retries=3, + retry_count=0, + job_params=create_variants_sample_params, + ) + + +@pytest.fixture +def sample_independent_variant_mapping_run(map_variants_sample_params): + """Create a JobRun instance for variant mapping job.""" + + return JobRun( + urn="test:map_variants_for_score_set", + job_type="map_variants_for_score_set", + job_function="map_variants_for_score_set", + max_retries=3, + retry_count=0, + job_params=map_variants_sample_params, + ) + + +@pytest.fixture +def dummy_pipeline_step(): + """Create a dummy pipeline step function for testing.""" + + return JobRun( + urn="test:dummy_pipeline_step", + job_type="dummy_pipeline_step", + job_function="dummy_arq_function", + max_retries=3, + retry_count=0, + ) + + +@pytest.fixture +def sample_pipeline_variant_creation_run( + session, + with_variant_creation_pipeline, + sample_variant_creation_pipeline, + sample_independent_variant_creation_run, +): + """Create a JobRun instance for variant creation job.""" + + sample_independent_variant_creation_run.pipeline_id = sample_variant_creation_pipeline.id + session.add(sample_independent_variant_creation_run) + session.commit() + return sample_independent_variant_creation_run + + +@pytest.fixture +def sample_pipeline_variant_mapping_run( + session, + with_variant_mapping_pipeline, + sample_independent_variant_mapping_run, + sample_variant_mapping_pipeline, +): + """Create a JobRun instance for variant mapping job.""" + + sample_independent_variant_mapping_run.pipeline_id = sample_variant_mapping_pipeline.id + session.add(sample_independent_variant_mapping_run) + session.commit() + return sample_independent_variant_mapping_run + + +@pytest.fixture +def sample_variant_creation_pipeline(): + """Create a Pipeline instance.""" + + return Pipeline( + name="variant_creation_pipeline", + description="Pipeline for creating variants", + ) + + +@pytest.fixture +def sample_variant_mapping_pipeline(): + """Create a Pipeline instance.""" + + return Pipeline( + name="variant_mapping_pipeline", + description="Pipeline for mapping variants", + ) + + +@pytest.fixture +def with_independent_processing_runs( + session, + sample_independent_variant_creation_run, + sample_independent_variant_mapping_run, +): + """Fixture to ensure independent variant processing runs exist in the database.""" + + session.add(sample_independent_variant_creation_run) + session.add(sample_independent_variant_mapping_run) + session.commit() + + +@pytest.fixture +def with_variant_creation_pipeline(session, sample_variant_creation_pipeline): + """Fixture to ensure variant creation pipeline and its runs exist in the database.""" + session.add(sample_variant_creation_pipeline) + session.commit() + + +@pytest.fixture +def with_variant_creation_pipeline_runs( + session, + with_variant_creation_pipeline, + sample_variant_creation_pipeline, + sample_pipeline_variant_creation_run, + dummy_pipeline_step, +): + """Fixture to ensure pipeline variant processing runs exist in the database.""" + session.add(sample_pipeline_variant_creation_run) + dummy_pipeline_step.pipeline_id = sample_variant_creation_pipeline.id + session.add(dummy_pipeline_step) + session.commit() + + +@pytest.fixture +def with_variant_mapping_pipeline(session, sample_variant_mapping_pipeline): + """Fixture to ensure variant mapping pipeline and its runs exist in the database.""" + session.add(sample_variant_mapping_pipeline) + session.commit() + + +@pytest.fixture +def with_variant_mapping_pipeline_runs( + session, + with_variant_mapping_pipeline, + sample_variant_mapping_pipeline, + sample_pipeline_variant_mapping_run, + dummy_pipeline_step, +): + """Fixture to ensure pipeline variant processing runs exist in the database.""" + session.add(sample_pipeline_variant_mapping_run) + dummy_pipeline_step.pipeline_id = sample_variant_mapping_pipeline.id + session.add(dummy_pipeline_step) + session.commit() + + +@pytest.fixture +def sample_dummy_pipeline(): + """Create a sample Pipeline instance for testing.""" + + return Pipeline( + name="Dummy Pipeline", + description="A dummy pipeline for testing purposes", + ) + + +@pytest.fixture +def with_dummy_pipeline(session, sample_dummy_pipeline): + """Fixture to ensure dummy pipeline exists in the database.""" + session.add(sample_dummy_pipeline) + session.commit() + + +@pytest.fixture +def sample_dummy_pipeline_start(session, with_dummy_pipeline, sample_dummy_pipeline): + """Create a sample JobRun instance for starting the dummy pipeline.""" + start_job_run = JobRun( + pipeline_id=sample_dummy_pipeline.id, + job_type="start_pipeline", + job_function="start_pipeline", + ) + session.add(start_job_run) + session.commit() + + return start_job_run + + +@pytest.fixture +def with_dummy_pipeline_start(session, with_dummy_pipeline, sample_dummy_pipeline_start): + """Fixture to ensure a start pipeline job run for the dummy pipeline exists in the database.""" + session.add(sample_dummy_pipeline_start) + session.commit() + + +@pytest.fixture +def sample_dummy_pipeline_step(session, sample_dummy_pipeline): + """Create a sample PipelineStep instance for the dummy pipeline.""" + step = JobRun( + pipeline_id=sample_dummy_pipeline.id, + job_type="dummy_step", + job_function="dummy_arq_function", + ) + session.add(step) + session.commit() + return step + + +@pytest.fixture +def with_full_dummy_pipeline(session, with_dummy_pipeline_start, sample_dummy_pipeline, sample_dummy_pipeline_step): + """Fixture to ensure dummy pipeline steps exist in the database.""" + session.add(sample_dummy_pipeline_step) + session.commit() diff --git a/tests/worker/jobs/external_services/conftest.py b/tests/worker/jobs/external_services/conftest.py deleted file mode 100644 index 2f422506..00000000 --- a/tests/worker/jobs/external_services/conftest.py +++ /dev/null @@ -1,365 +0,0 @@ -import pytest - -from mavedb.models.enums.job_pipeline import DependencyType -from mavedb.models.job_dependency import JobDependency -from mavedb.models.job_run import JobRun -from mavedb.models.mapped_variant import MappedVariant -from mavedb.models.pipeline import Pipeline -from mavedb.models.score_set import ScoreSet -from mavedb.models.variant import Variant - -## Gnomad Linkage Job Fixtures ## - - -@pytest.fixture -def link_gnomad_variants_sample_params(with_populated_domain_data, sample_score_set): - """Provide sample parameters for create_variants_for_score_set job.""" - - return { - "correlation_id": "sample-correlation-id", - "score_set_id": sample_score_set.id, - } - - -@pytest.fixture -def sample_link_gnomad_variants_pipeline(): - """Create a pipeline instance for link_gnomad_variants job.""" - - return Pipeline( - urn="test:link_gnomad_variants_pipeline", - name="Link gnomAD Variants Pipeline", - ) - - -@pytest.fixture -def sample_link_gnomad_variants_run(link_gnomad_variants_sample_params): - """Create a JobRun instance for link_gnomad_variants job.""" - - return JobRun( - urn="test:link_gnomad_variants", - job_type="link_gnomad_variants", - job_function="link_gnomad_variants", - max_retries=3, - retry_count=0, - job_params=link_gnomad_variants_sample_params, - ) - - -@pytest.fixture -def with_gnomad_linking_job(session, sample_link_gnomad_variants_run): - """Add a link_gnomad_variants job run to the session.""" - - session.add(sample_link_gnomad_variants_run) - session.commit() - - -@pytest.fixture -def with_gnomad_linking_pipeline(session, sample_link_gnomad_variants_pipeline): - """Add a link_gnomad_variants pipeline to the session.""" - - session.add(sample_link_gnomad_variants_pipeline) - session.commit() - - -@pytest.fixture -def sample_link_gnomad_variants_run_pipeline( - session, - with_gnomad_linking_job, - with_gnomad_linking_pipeline, - sample_link_gnomad_variants_run, - sample_link_gnomad_variants_pipeline, -): - """Provide a context with a link_gnomad_variants job run and pipeline.""" - - sample_link_gnomad_variants_run.pipeline_id = sample_link_gnomad_variants_pipeline.id - session.commit() - return sample_link_gnomad_variants_run - - -@pytest.fixture -def setup_sample_variants_with_caid(with_populated_domain_data, mock_worker_ctx, sample_link_gnomad_variants_run): - """Setup variants and mapped variants in the database for testing.""" - session = mock_worker_ctx["db"] - score_set = session.get(ScoreSet, sample_link_gnomad_variants_run.job_params["score_set_id"]) - - # Add a variant and mapped variant to the database with a CAID - variant = Variant( - urn="urn:variant:test-variant-with-caid", - score_set_id=score_set.id, - hgvs_nt="NM_000000.1:c.1A>G", - hgvs_pro="NP_000000.1:p.Met1Val", - data={"hgvs_c": "NM_000000.1:c.1A>G", "hgvs_p": "NP_000000.1:p.Met1Val"}, - ) - session.add(variant) - session.commit() - mapped_variant = MappedVariant( - variant_id=variant.id, - clingen_allele_id="CA123", - current=True, - mapped_date="2024-01-01T00:00:00Z", - mapping_api_version="1.0.0", - ) - session.add(mapped_variant) - session.commit() - - -## Uniprot Job Fixtures ## - - -@pytest.fixture -def submit_uniprot_mapping_jobs_sample_params(with_populated_domain_data, sample_score_set): - """Provide sample parameters for submit_uniprot_mapping_jobs_for_score_set job.""" - - return { - "correlation_id": "sample-correlation-id", - "score_set_id": sample_score_set.id, - } - - -@pytest.fixture -def poll_uniprot_mapping_jobs_sample_params( - submit_uniprot_mapping_jobs_sample_params, - with_dependent_polling_job_for_submission_run, -): - """Provide sample parameters for poll_uniprot_mapping_jobs_for_score_set job.""" - - return { - "correlation_id": submit_uniprot_mapping_jobs_sample_params["correlation_id"], - "score_set_id": submit_uniprot_mapping_jobs_sample_params["score_set_id"], - "mapping_jobs": {}, - } - - -@pytest.fixture -def sample_submit_uniprot_mapping_jobs_pipeline(): - """Create a pipeline instance for submit_uniprot_mapping_jobs_for_score_set job.""" - - return Pipeline( - urn="test:submit_uniprot_mapping_jobs_pipeline", - name="Submit UniProt Mapping Jobs Pipeline", - ) - - -@pytest.fixture -def sample_poll_uniprot_mapping_jobs_pipeline(): - """Create a pipeline instance for poll_uniprot_mapping_jobs_for_score_set job.""" - - return Pipeline( - urn="test:poll_uniprot_mapping_jobs_pipeline", - name="Poll UniProt Mapping Jobs Pipeline", - ) - - -@pytest.fixture -def sample_submit_uniprot_mapping_jobs_run(submit_uniprot_mapping_jobs_sample_params): - """Create a JobRun instance for submit_uniprot_mapping_jobs_for_score_set job.""" - - return JobRun( - urn="test:submit_uniprot_mapping_jobs", - job_type="submit_uniprot_mapping_jobs", - job_function="submit_uniprot_mapping_jobs_for_score_set", - max_retries=3, - retry_count=0, - job_params=submit_uniprot_mapping_jobs_sample_params, - ) - - -@pytest.fixture -def sample_dummy_polling_job_for_submission_run( - session, - with_submit_uniprot_mapping_job, - sample_submit_uniprot_mapping_jobs_run, -): - """Create a sample dummy dependent polling job for the submission run.""" - - dependent_job = JobRun( - urn="test:dummy_poll_uniprot_mapping_jobs", - job_type="dummy_poll_uniprot_mapping_jobs", - job_function="dummy_arq_function", - max_retries=3, - retry_count=0, - job_params={ - "correlation_id": sample_submit_uniprot_mapping_jobs_run.job_params["correlation_id"], - "score_set_id": sample_submit_uniprot_mapping_jobs_run.job_params["score_set_id"], - "mapping_jobs": {}, - }, - ) - - return dependent_job - - -@pytest.fixture -def sample_polling_job_for_submission_run( - session, - with_submit_uniprot_mapping_job, - sample_submit_uniprot_mapping_jobs_run, -): - """Create a sample dependent polling job for the submission run.""" - - dependent_job = JobRun( - urn="test:dependent_poll_uniprot_mapping_jobs", - job_type="dependent_poll_uniprot_mapping_jobs", - job_function="poll_uniprot_mapping_jobs_for_score_set", - max_retries=3, - retry_count=0, - job_params={ - "correlation_id": sample_submit_uniprot_mapping_jobs_run.job_params["correlation_id"], - "score_set_id": sample_submit_uniprot_mapping_jobs_run.job_params["score_set_id"], - "mapping_jobs": {}, - }, - ) - - return dependent_job - - -@pytest.fixture -def with_dummy_polling_job_for_submission_run( - session, - with_submit_uniprot_mapping_job, - sample_submit_uniprot_mapping_jobs_run, - sample_dummy_polling_job_for_submission_run, -): - """Create a sample dummy dependent polling job for the submission run.""" - session.add(sample_dummy_polling_job_for_submission_run) - session.commit() - - dependency = JobDependency( - id=sample_dummy_polling_job_for_submission_run.id, - depends_on_job_id=sample_submit_uniprot_mapping_jobs_run.id, - dependency_type=DependencyType.SUCCESS_REQUIRED, - ) - session.add(dependency) - session.commit() - - -@pytest.fixture -def with_dependent_polling_job_for_submission_run( - session, - with_submit_uniprot_mapping_job, - sample_submit_uniprot_mapping_jobs_run, - sample_polling_job_for_submission_run, -): - """Create a sample dependent polling job for the submission run.""" - session.add(sample_polling_job_for_submission_run) - session.commit() - - dependency = JobDependency( - id=sample_polling_job_for_submission_run.id, - depends_on_job_id=sample_submit_uniprot_mapping_jobs_run.id, - dependency_type=DependencyType.SUCCESS_REQUIRED, - ) - session.add(dependency) - session.commit() - - -@pytest.fixture -def with_independent_polling_job_for_submission_run( - session, - sample_polling_job_for_submission_run, -): - """Create a sample dependent polling job for the submission run.""" - session.add(sample_polling_job_for_submission_run) - session.commit() - - -@pytest.fixture -def with_submit_uniprot_mapping_job(session, sample_submit_uniprot_mapping_jobs_run): - """Add a submit_uniprot_mapping_jobs job run to the session.""" - - session.add(sample_submit_uniprot_mapping_jobs_run) - session.commit() - - -@pytest.fixture -def with_poll_uniprot_mapping_job(session, sample_poll_uniprot_mapping_jobs_run): - """Add a poll_uniprot_mapping_jobs job run to the session.""" - - session.add(sample_poll_uniprot_mapping_jobs_run) - session.commit() - - -@pytest.fixture -def sample_submit_uniprot_mapping_jobs_run_in_pipeline( - session, - with_submit_uniprot_mapping_job, - with_submit_uniprot_mapping_jobs_pipeline, - sample_submit_uniprot_mapping_jobs_run, - sample_submit_uniprot_mapping_jobs_pipeline, -): - """Provide a context with a submit_uniprot_mapping_jobs job run and pipeline.""" - - sample_submit_uniprot_mapping_jobs_run.pipeline_id = sample_submit_uniprot_mapping_jobs_pipeline.id - session.commit() - return sample_submit_uniprot_mapping_jobs_run - - -@pytest.fixture -def sample_poll_uniprot_mapping_jobs_run_in_pipeline( - session, - with_independent_polling_job_for_submission_run, - with_poll_uniprot_mapping_jobs_pipeline, - sample_polling_job_for_submission_run, - sample_poll_uniprot_mapping_jobs_pipeline, -): - """Provide a context with a poll_uniprot_mapping_jobs job run and pipeline.""" - - sample_polling_job_for_submission_run.pipeline_id = sample_poll_uniprot_mapping_jobs_pipeline.id - session.commit() - return sample_polling_job_for_submission_run - - -@pytest.fixture -def sample_dummy_polling_job_for_submission_run_in_pipeline( - session, - with_dummy_polling_job_for_submission_run, - with_submit_uniprot_mapping_jobs_pipeline, - with_submit_uniprot_mapping_job, - sample_submit_uniprot_mapping_jobs_pipeline, - sample_submit_uniprot_mapping_jobs_run_in_pipeline, - sample_dummy_polling_job_for_submission_run, -): - """Provide a context with a dependent polling job run in the pipeline.""" - - dependent_job = sample_dummy_polling_job_for_submission_run - dependent_job.pipeline_id = sample_submit_uniprot_mapping_jobs_pipeline.id - session.commit() - return dependent_job - - -@pytest.fixture -def sample_polling_job_for_submission_run_in_pipeline( - session, - with_dependent_polling_job_for_submission_run, - with_submit_uniprot_mapping_jobs_pipeline, - with_submit_uniprot_mapping_job, - sample_submit_uniprot_mapping_jobs_pipeline, - sample_submit_uniprot_mapping_jobs_run_in_pipeline, - sample_polling_job_for_submission_run, -): - """Provide a context with a dependent polling job run in the pipeline.""" - - dependent_job = sample_polling_job_for_submission_run - dependent_job.pipeline_id = sample_submit_uniprot_mapping_jobs_pipeline.id - session.commit() - return dependent_job - - -@pytest.fixture -def with_submit_uniprot_mapping_jobs_pipeline( - session, - sample_submit_uniprot_mapping_jobs_pipeline, -): - """Add a submit_uniprot_mapping_jobs pipeline to the session.""" - - session.add(sample_submit_uniprot_mapping_jobs_pipeline) - session.commit() - - -@pytest.fixture -def with_poll_uniprot_mapping_jobs_pipeline( - session, - sample_poll_uniprot_mapping_jobs_pipeline, -): - """Add a poll_uniprot_mapping_jobs pipeline to the session.""" - session.add(sample_poll_uniprot_mapping_jobs_pipeline) - session.commit() diff --git a/tests/worker/jobs/external_services/network/test_clingen.py b/tests/worker/jobs/external_services/network/test_clingen.py new file mode 100644 index 00000000..95ce0135 --- /dev/null +++ b/tests/worker/jobs/external_services/network/test_clingen.py @@ -0,0 +1,134 @@ +from unittest.mock import patch + +import pytest +from sqlalchemy import select + +from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus +from mavedb.models.mapped_variant import MappedVariant +from tests.helpers.util.setup.worker import create_mappings_in_score_set + + +# TODO#XXX: Connect with ClinGen to resolve the invalid credentials issue on test site. +@pytest.mark.skip(reason="invalid credentials, despite what is provided in documentation.") +@pytest.mark.asyncio +@pytest.mark.integration +@pytest.mark.network +class TestE2EClingenSubmitScoreSetMappingsToCar: + """End-to-end tests for ClinGen CAR submission jobs.""" + + async def test_clingen_car_submission_e2e( + self, + session, + arq_redis, + arq_worker, + standalone_worker_context, + mock_s3_client, + sample_score_set, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_pipeline, + submit_score_set_mappings_to_car_sample_job_run_in_pipeline, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + """Test the end-to-end flow of submitting score set mappings to ClinGen CAR.""" + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + with ( + patch( + "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", + "https://reg.test.genome.network", + ), + patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_NAME", "testuser"), + patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_PASSWORD", "testuser"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + ): + await arq_redis.enqueue_job( + "submit_score_set_mappings_to_car", submit_score_set_mappings_to_car_sample_job_run_in_pipeline.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that the submission job was completed successfully + session.refresh(submit_score_set_mappings_to_car_sample_job_run_in_pipeline) + assert submit_score_set_mappings_to_car_sample_job_run_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify that the pipeline run status is succeeded + session.refresh(submit_score_set_mappings_to_car_sample_pipeline) + assert submit_score_set_mappings_to_car_sample_pipeline.status == PipelineStatus.SUCCEEDED + + # Verify that variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 4 + for variant in variants: + assert variant.clingen_allele_id is not None + + +# TODO#XXX: Connect with ClinGen to resolve the invalid credentials issue on test site. +@pytest.mark.skip(reason="invalid credentials, despite what is provided in documentation.") +@pytest.mark.integration +@pytest.mark.asyncio +@pytest.mark.network +class TestE2EClingenSubmitScoreSetMappingsToLdh: + """End-to-end tests for ClinGen LDH submission jobs.""" + + async def test_clingen_ldh_submission_e2e( + self, + session, + arq_redis, + arq_worker, + standalone_worker_context, + mock_s3_client, + sample_score_set, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_pipeline, + submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + """Test the end-to-end flow of submitting score set mappings to ClinGen LDH.""" + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenLdhService to simulate all submissions failing + with ( + patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_NAME", "testuser"), + patch("mavedb.lib.clingen.services.GENBOREE_ACCOUNT_PASSWORD", "testpassword"), + patch("mavedb.lib.clingen.constants.LDH_ACCESS_ENDPOINT", "https://genboree.org/ldh-stg/srvc"), + patch("mavedb.lib.clingen.constants.CLIN_GEN_TENANT", "dev-clingen"), + ): + await arq_redis.enqueue_job( + "submit_score_set_mappings_to_ldh", submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that the submission job succeeded + session.refresh(submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline) + assert submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify that the pipeline run status is succeeded + session.refresh(submit_score_set_mappings_to_ldh_sample_pipeline) + assert submit_score_set_mappings_to_ldh_sample_pipeline.status == PipelineStatus.SUCCEEDED diff --git a/tests/worker/jobs/external_services/network/test_gnomad.py b/tests/worker/jobs/external_services/network/test_gnomad.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/worker/jobs/external_services/test_clingen.py b/tests/worker/jobs/external_services/test_clingen.py index add6d0b1..614e53e5 100644 --- a/tests/worker/jobs/external_services/test_clingen.py +++ b/tests/worker/jobs/external_services/test_clingen.py @@ -1,518 +1,2005 @@ -# ruff: noqa: E402 - -from unittest.mock import MagicMock, call, patch -from uuid import uuid4 +from asyncio.unix_events import _UnixSelectorEventLoop +from unittest.mock import call, patch import pytest +from sqlalchemy import select -from mavedb.models.enums.job_pipeline import JobStatus -from mavedb.models.job_run import JobRun -from mavedb.worker.lib.managers.job_manager import JobManager - -arq = pytest.importorskip("arq") - -from sqlalchemy.exc import NoResultFound - -from mavedb.lib.clingen.services import ( - ClinGenAlleleRegistryService, -) +from mavedb.lib.exceptions import LDHSubmissionFailureError +from mavedb.lib.variants import get_hgvs_from_post_mapped +from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus from mavedb.models.mapped_variant import MappedVariant -from mavedb.models.score_set import ScoreSet as ScoreSetDbModel -from mavedb.worker.jobs import ( +from mavedb.models.variant import Variant +from mavedb.worker.jobs.external_services.clingen import ( submit_score_set_mappings_to_car, + submit_score_set_mappings_to_ldh, ) -from tests.helpers.constants import ( - TEST_CLINGEN_ALLELE_OBJECT, - TEST_MINIMAL_SEQ_SCORESET, -) -from tests.helpers.util.setup.worker import ( - setup_records_files_and_variants_with_mapping, -) - -############################################################################################################################################ -# ClinGen CAR Submission -############################################################################################################################################ +from mavedb.worker.lib.managers.job_manager import JobManager +from tests.helpers.util.setup.worker import create_mappings_in_score_set -@pytest.mark.asyncio @pytest.mark.unit -class TestSubmitScoreSetMappingsToCARUnit: - """Tests for the submit_score_set_mappings_to_car function.""" +@pytest.mark.asyncio +class TestClingenSubmitScoreSetMappingsToCarUnit: + """Tests for the Clingen submit_score_set_mappings_to_car function.""" - @pytest.mark.parametrize("missing_param", ["score_set_id", "correlation_id"]) - async def test_submit_score_set_mappings_to_car_required_params( + async def test_submit_score_set_mappings_to_car_submission_disabled( self, - mock_job_manager, - mock_job_run, mock_worker_ctx, - missing_param, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, ): - """Test that submitting a non-existent score set raises an exception.""" - - mock_job_run.job_params = {"score_set_id": 99, "correlation_id": uuid4().hex} + # Patch to disable ClinGen submission endpoint + with ( + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", False), + patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, + ): + result = await submit_score_set_mappings_to_car( + mock_worker_ctx, + submit_score_set_mappings_to_car_sample_job_run.id, + JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id + ), + ) - del mock_job_run.job_params[missing_param] + mock_update_progress.assert_called_with(100, 100, "ClinGen submission is disabled. Skipping CAR submission.") + assert result["status"] == "ok" - with pytest.raises(ValueError): - await submit_score_set_mappings_to_car(mock_worker_ctx, 99, job_manager=mock_job_manager) + # Verify no variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 0 - async def test_submit_score_set_mappings_to_car_raises_when_no_score_set( + async def test_submit_score_set_mappings_to_car_no_mappings( self, - mock_job_manager, - mock_job_run, mock_worker_ctx, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, ): - """Test that submitting a non-existent score set raises an exception.""" + """Test submitting score set mappings to ClinGen when there are no mappings.""" + with ( + patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + ): + result = await submit_score_set_mappings_to_car( + mock_worker_ctx, + submit_score_set_mappings_to_car_sample_job_run.id, + JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id + ), + ) + + mock_update_progress.assert_called_with(100, 100, "No mapped variants to submit to CAR. Skipped submission.") + assert result["status"] == "ok" - mock_job_run.job_params = {"score_set_id": 99, "correlation_id": uuid4().hex} + # Verify no variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 0 + async def test_submit_score_set_mappings_to_car_submission_endpoint_not_set( + self, + mock_worker_ctx, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + ): + # Patch to disable ClinGen submission endpoint with ( - pytest.raises(NoResultFound), - patch.object(mock_job_manager.db, "scalars", side_effect=NoResultFound()), - patch.object(mock_job_manager, "update_progress", return_value=None), - patch("mavedb.worker.jobs.external_services.clingen.validate_job_params", return_value=None), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", ""), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, + pytest.raises(ValueError), ): - await submit_score_set_mappings_to_car(mock_worker_ctx, 99, job_manager=mock_job_manager) + await submit_score_set_mappings_to_car( + mock_worker_ctx, + submit_score_set_mappings_to_car_sample_job_run.id, + JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id + ), + ) + + mock_update_progress.assert_called_with( + 100, 100, "CAR submission endpoint not configured. Can't complete submission." + ) - async def test_submit_score_set_mappings_to_car_no_mapped_variants( + # Verify no variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 0 + + async def test_submit_score_set_mappings_to_car_no_registered_alleles( self, - mock_job_manager, - mock_job_run, mock_worker_ctx, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, ): - """Test that submitting a score set with no mapped variants completes successfully.""" - - mock_job_run.job_params = {"score_set_id": 1, "correlation_id": uuid4().hex} + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + mock_worker_ctx, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + # Patch ClinGenAlleleRegistryService to return no registered alleles with ( - patch.object( - mock_job_manager.db, - "scalars", - return_value=MagicMock(one=MagicMock(spec=ScoreSetDbModel, urn="urn:1", num_variants=0)), - ), - patch.object( - mock_job_manager.db, - "execute", - return_value=MagicMock(all=lambda: []), + patch( + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=[], ), - patch("mavedb.worker.jobs.external_services.clingen.validate_job_params", return_value=None), - patch.object(mock_job_manager, "update_progress", return_value=None), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, ): - result = await submit_score_set_mappings_to_car(mock_worker_ctx, 1, job_manager=mock_job_manager) + result = await submit_score_set_mappings_to_car( + mock_worker_ctx, + submit_score_set_mappings_to_car_sample_job_run.id, + JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id + ), + ) + mock_update_progress.assert_called_with(100, 100, "Completed CAR mapped resource submission.") assert result["status"] == "ok" - async def test_submit_score_set_mappings_to_car_no_variants_updates_progress( + # Verify no variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 0 + + async def test_submit_score_set_mappings_to_car_no_linked_alleles( self, - mock_job_manager, - mock_job_run, mock_worker_ctx, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, ): - """Test that submitting a score set with no variants updates progress to 100%.""" + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + mock_worker_ctx, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) - mock_job_run.job_params = {"score_set_id": 1, "correlation_id": uuid4().hex} + # Patch ClinGenAlleleRegistryService to return registered alleles that do not match submitted HGVS + registered_alleles_mock = [ + {"@id": "CA123456", "type": "nucleotide", "genomicAlleles": [{"hgvs": "NC_000007.14:g.140453136A>C"}]}, + {"@id": "CA234567", "type": "nucleotide", "genomicAlleles": [{"hgvs": "NC_000007.14:g.140453136A>G"}]}, + ] with ( - patch.object( - mock_job_manager.db, - "scalars", - return_value=MagicMock(one=MagicMock(spec=ScoreSetDbModel, urn="urn:1", num_variants=0)), - ), - patch.object( - mock_job_manager.db, - "execute", - return_value=MagicMock(all=lambda: []), + patch( + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=registered_alleles_mock, ), - patch("mavedb.worker.jobs.external_services.clingen.validate_job_params", return_value=None), - patch.object(mock_job_manager, "update_progress", return_value=None) as mock_update_progress, + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, ): - await submit_score_set_mappings_to_car(mock_worker_ctx, 1, job_manager=mock_job_manager) + result = await submit_score_set_mappings_to_car( + mock_worker_ctx, + submit_score_set_mappings_to_car_sample_job_run.id, + JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id + ), + ) - expected_calls = [ - call(0, 100, "Starting CAR mapped resource submission."), - call(100, 100, "No mapped variants to submit to CAR. Skipped submission."), - ] - mock_update_progress.assert_has_calls(expected_calls) + mock_update_progress.assert_called_with(100, 100, "Completed CAR mapped resource submission.") + assert result["status"] == "ok" - async def test_submit_score_set_mappings_to_car_no_submission_endpoint( + # Verify no variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 0 + + async def test_submit_score_set_mappings_to_car_repeated_hgvs( self, - mock_job_manager, - mock_job_run, mock_worker_ctx, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, ): - """Test that submitting a score set with no CAR submission endpoint configured raises an exception.""" + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + mock_worker_ctx, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) - mock_job_run.job_params = {"score_set_id": 1, "correlation_id": uuid4().hex} + # Patch ClinGenAlleleRegistryService to return registered alleles with repeated HGVS + mapped_variants = session.scalars(select(MappedVariant)).all() + registered_alleles_mock = [ + { + "@id": "CA_DUPLICATE", + "type": "nucleotide", + "genomicAlleles": [{"hgvs": get_hgvs_from_post_mapped(mapped_variants[0].post_mapped)}], + } + ] with ( - patch.object( - mock_job_manager.db, - "scalars", - return_value=MagicMock(one=MagicMock(spec=ScoreSetDbModel, urn="urn:1", num_variants=1)), + patch( + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=registered_alleles_mock, ), - patch.object( - mock_job_manager.db, - "execute", - return_value=MagicMock(all=lambda: [(999, {}), (1000, {})]), + # Patch get_hgvs_from_post_mapped to return the same HGVS for all variants + patch( + "mavedb.worker.jobs.external_services.clingen.get_hgvs_from_post_mapped", + return_value=get_hgvs_from_post_mapped(mapped_variants[0].post_mapped), ), - patch("mavedb.worker.jobs.external_services.clingen.validate_job_params", return_value=None), - patch.object(mock_job_manager, "update_progress", return_value=None), - patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", None), - pytest.raises(ValueError), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, ): - await submit_score_set_mappings_to_car(mock_worker_ctx, 1, job_manager=mock_job_manager) + result = await submit_score_set_mappings_to_car( + mock_worker_ctx, + submit_score_set_mappings_to_car_sample_job_run.id, + JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id + ), + ) + + mock_update_progress.assert_called_with(100, 100, "Completed CAR mapped resource submission.") + assert result["status"] == "ok" + + # Verify variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 4 + for variant in variants: + assert variant.clingen_allele_id == "CA_DUPLICATE" - async def test_submit_score_set_mappings_to_car_no_variants_associated( + async def test_submit_score_set_mappings_to_car_hgvs_not_found( self, - mock_job_manager, - mock_job_run, mock_worker_ctx, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, ): - """Test that submitting a score set with no variants associated completes successfully.""" + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + mock_worker_ctx, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) - mock_job_run.job_params = {"score_set_id": 1, "correlation_id": uuid4().hex} + # Get the mapped variants from score set before submission + mapped_variants = session.scalars( + select(MappedVariant) + .join(Variant) + .where(Variant.score_set_id == submit_score_set_mappings_to_car_sample_job_run.job_params["score_set_id"]) + ).all() - mocked_score_set = MagicMock(spec=ScoreSetDbModel, urn="urn:1", num_variants=2) - mocked_mapped_variant_with_hgvs = MagicMock(spec=MappedVariant, id=1000, clingen_allele_id=None) + # Patch ClinGenAlleleRegistryService to return registered alleles + registered_alleles_mock = [ + { + "@id": f"CA{mv.id}", + "type": "nucleotide", + "genomicAlleles": [{"hgvs": get_hgvs_from_post_mapped(mv.post_mapped)}], + } + for mv in mapped_variants + ] with ( - # db.scalars is called twice in this function: once to get the score set (one), once to get the mapped variants (all) - patch.object( - mock_job_manager.db, - "scalars", - return_value=MagicMock( - one=mocked_score_set, - all=lambda: [mocked_mapped_variant_with_hgvs], - ), - ), - # db.execute is called to get the mapped variant IDs and post mapped data - patch.object(mock_job_manager.db, "execute", return_value=MagicMock(all=lambda: [(999, {}), (1000, {})])), - # get_hgvs_from_post_mapped is called twice, once for each mapped variant. mock that both - # calls return valid HGVS strings. - patch( - "mavedb.worker.jobs.external_services.clingen.get_hgvs_from_post_mapped", - side_effect=["c.122G>C", "c.123A>T"], - ), - # validate_job_params is called to validate job parameters - patch("mavedb.worker.jobs.external_services.clingen.validate_job_params", return_value=None), - # update_progress is called multiple times to update job progress - patch.object(mock_job_manager, "update_progress", return_value=None), - # CAR_SUBMISSION_ENDPOINT is patched to a test URL patch( - "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", - "https://reg.test.genome.network/pytest", + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=registered_alleles_mock, ), - # Mock the dispatch_submissions method to return a test ClinGen allele object, which we should associate with the variant - patch.object(ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[]), - # Mock the get_allele_registry_associations function to return a mapping from HGVS to CAID + # Patch get_hgvs_from_post_mapped to not find any HGVS in registered alleles + patch("mavedb.worker.jobs.external_services.clingen.get_hgvs_from_post_mapped", return_value=None), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, + ): + result = await submit_score_set_mappings_to_car( + mock_worker_ctx, + submit_score_set_mappings_to_car_sample_job_run.id, + JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id + ), + ) + + mock_update_progress.assert_called_with(100, 100, "Completed CAR mapped resource submission.") + assert result["status"] == "ok" + + # Verify no variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 0 + + async def test_submit_score_set_mappings_to_car_propagates_exception( + self, + mock_worker_ctx, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + mock_worker_ctx, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenAlleleRegistryService to raise an exception + with ( patch( - "mavedb.worker.jobs.external_services.clingen.get_allele_registry_associations", - return_value={}, + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + side_effect=Exception("ClinGen service error"), ), - patch.object(mock_job_manager.db, "add", return_value=None) as mock_db_add, + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + pytest.raises(Exception) as exc_info, ): - result = await submit_score_set_mappings_to_car(mock_worker_ctx, 1, job_manager=mock_job_manager) + await submit_score_set_mappings_to_car( + mock_worker_ctx, + submit_score_set_mappings_to_car_sample_job_run.id, + JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id + ), + ) - # Assert no CAID was not added to the variant - mock_db_add.assert_not_called() - assert mocked_mapped_variant_with_hgvs.clingen_allele_id is None - assert result["status"] == "ok" + assert str(exc_info.value) == "ClinGen service error" - async def test_submit_score_set_mappings_to_car_no_variants_found_in_db( + async def test_submit_score_set_mappings_to_car_success( self, - mock_job_manager, - mock_job_run, mock_worker_ctx, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + sample_score_set, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, ): - """Test that submitting a score set with no mapped variants found in the db completes successfully.""" + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + mock_worker_ctx, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) - mock_job_run.job_params = {"score_set_id": 1, "correlation_id": uuid4().hex} + # Get the mapped variants from score set before submission + mapped_variants = session.scalars( + select(MappedVariant).join(Variant).where(Variant.score_set_id == sample_score_set.id) + ).all() + assert len(mapped_variants) == 4 - mocked_score_set = MagicMock(spec=ScoreSetDbModel, urn="urn:1", num_variants=2) - mocked_mapped_variant_with_hgvs = MagicMock(spec=MappedVariant, id=1000, clingen_allele_id=None) + # Patch ClinGenAlleleRegistryService to return registered alleles + registered_alleles_mock = [ + { + "@id": f"CA{mv.id}", + "type": "nucleotide", + "genomicAlleles": [{"hgvs": get_hgvs_from_post_mapped(mv.post_mapped)}], + } + for mv in mapped_variants + ] with ( - # db.scalars is called twice in this function: once to get the score set (one), twice to get the mapped variants (all) - patch.object( - mock_job_manager.db, - "scalars", - return_value=MagicMock( - one=mocked_score_set, - all=lambda: [], - ), - ), - # db.execute is called to get the mapped variant IDs and post mapped data - patch.object(mock_job_manager.db, "execute", return_value=MagicMock(all=lambda: [(999, {}), (1000, {})])), - # get_hgvs_from_post_mapped is called twice, once for each mapped variant. mock that both - # calls return valid HGVS strings. patch( - "mavedb.worker.jobs.external_services.clingen.get_hgvs_from_post_mapped", - side_effect=["c.122G>C", "c.123A>T"], - ), - # validate_job_params is called to validate job parameters - patch("mavedb.worker.jobs.external_services.clingen.validate_job_params", return_value=None), - # update_progress is called multiple times to update job progress - patch.object(mock_job_manager, "update_progress", return_value=None), - # CAR_SUBMISSION_ENDPOINT is patched to a test URL - patch( - "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", - "https://reg.test.genome.network/pytest", - ), - # Mock the dispatch_submissions method to return a test ClinGen allele object, which we should associate with the variant - patch.object( - ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT] - ), - # Mock the get_allele_registry_associations function to return a mapping from HGVS to CAID - patch( - "mavedb.worker.jobs.external_services.clingen.get_allele_registry_associations", - return_value={"c.122G>C": "CAID:0000000", "c.123A>T": "CAID:0000001"}, + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=registered_alleles_mock, ), - patch.object(mock_job_manager.db, "add", return_value=None) as mock_db_add, + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, ): - result = await submit_score_set_mappings_to_car(mock_worker_ctx, 1, job_manager=mock_job_manager) + result = await submit_score_set_mappings_to_car( + mock_worker_ctx, + submit_score_set_mappings_to_car_sample_job_run.id, + JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id + ), + ) - # Assert no CAID was not added to the variant - mock_db_add.assert_not_called() - assert mocked_mapped_variant_with_hgvs.clingen_allele_id is None + mock_update_progress.assert_called_with(100, 100, "Completed CAR mapped resource submission.") assert result["status"] == "ok" - async def test_submit_score_set_mappings_to_car_skips_submission_for_variants_without_hgvs_string( + # Verify variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 4 + for variant in variants: + assert variant.clingen_allele_id == f"CA{variant.id}" + + async def test_submit_score_set_mappings_to_car_updates_progress( self, - mock_job_manager, - mock_job_run, mock_worker_ctx, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + sample_score_set, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, ): - """Test that submitting a score set with mapped variants completes successfully but skips variants without an HGVS string.""" + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + mock_worker_ctx, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) - mock_job_run.job_params = {"score_set_id": 1, "correlation_id": uuid4().hex} + # Get the mapped variants from score set before submission + mapped_variants = session.scalars( + select(MappedVariant).join(Variant).where(Variant.score_set_id == sample_score_set.id) + ).all() + assert len(mapped_variants) == 4 - mocked_score_set = MagicMock(spec=ScoreSetDbModel, urn="urn:1", num_variants=2) - mocked_mapped_variant_with_hgvs = MagicMock(spec=MappedVariant, id=1000) + # Patch ClinGenAlleleRegistryService to return registered alleles + registered_alleles_mock = [ + { + "@id": f"CA{mv.id}", + "type": "nucleotide", + "genomicAlleles": [{"hgvs": get_hgvs_from_post_mapped(mv.post_mapped)}], + } + for mv in mapped_variants + ] with ( - # db.scalars is called twice in this function: once to get the score set (one), once to get the mapped variants (all) - patch.object( - mock_job_manager.db, - "scalars", - return_value=MagicMock( - one=mocked_score_set, - all=lambda: [mocked_mapped_variant_with_hgvs], - ), - ), - # db.execute is called to get the mapped variant IDs and post mapped data - patch.object(mock_job_manager.db, "execute", return_value=MagicMock(all=lambda: [(999, {}), (1000, {})])), - # get_hgvs_from_post_mapped is called twice, once for each mapped variant. mock that the first - # call returns None (no HGVS), the second returns a valid HGVS string. patch( - "mavedb.worker.jobs.external_services.clingen.get_hgvs_from_post_mapped", - side_effect=[None, "c.123A>T"], + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=registered_alleles_mock, ), - # validate_job_params is called to validate job parameters - patch("mavedb.worker.jobs.external_services.clingen.validate_job_params", return_value=None), - # update_progress is called multiple times to update job progress - patch.object(mock_job_manager, "update_progress", return_value=None), - # CAR_SUBMISSION_ENDPOINT is patched to a test URL + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + ): + await submit_score_set_mappings_to_car( + mock_worker_ctx, + submit_score_set_mappings_to_car_sample_job_run.id, + JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id + ), + ) + + mock_update_progress.assert_has_calls( + [ + call(0, 100, "Starting CAR mapped resource submission."), + call(10, 100, "Preparing 4 mapped variants for CAR submission."), + call(15, 100, "Submitting mapped variants to CAR."), + call(60, 100, "Processing registered alleles from CAR."), + call(95, 100, "Processed 4 of 4 registered alleles."), + call(100, 100, "Completed CAR mapped resource submission."), + ] + ) + + # Verify variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 4 + for variant in variants: + assert variant.clingen_allele_id == f"CA{variant.id}" + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestClingenSubmitScoreSetMappingsToCarIntegration: + """Integration tests for the Clingen submit_score_set_mappings_to_car function.""" + + async def test_submit_score_set_mappings_to_car_independent_ctx( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenAlleleRegistryService to return registered alleles + mapped_variants = session.scalars(select(MappedVariant)).all() + registered_alleles_mock = [ + { + "@id": f"CA{mv.id}", + "type": "nucleotide", + "genomicAlleles": [{"hgvs": get_hgvs_from_post_mapped(mv.post_mapped)}], + } + for mv in mapped_variants + ] + + with ( patch( - "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", - "https://reg.test.genome.network/pytest", - ), - # Mock the dispatch_submissions method to return a test ClinGen allele object, which we should associate with the variant - patch.object( - ClinGenAlleleRegistryService, "dispatch_submissions", return_value=[TEST_CLINGEN_ALLELE_OBJECT] + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=registered_alleles_mock, ), - # Mock the get_allele_registry_associations function to return a mapping from HGVS to CAID + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + ): + result = await submit_score_set_mappings_to_car( + standalone_worker_context, submit_score_set_mappings_to_car_sample_job_run.id + ) + + assert result["status"] == "ok" + + # Verify variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == len(mapped_variants) + for variant in variants: + assert variant.clingen_allele_id == f"CA{variant.id}" + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_job_run) + assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.SUCCEEDED + + async def test_submit_score_set_mappings_to_car_pipeline_ctx( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run_in_pipeline, + submit_score_set_mappings_to_car_sample_pipeline, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenAlleleRegistryService to return registered alleles + mapped_variants = session.scalars(select(MappedVariant)).all() + registered_alleles_mock = [ + { + "@id": f"CA{mv.id}", + "type": "nucleotide", + "genomicAlleles": [{"hgvs": get_hgvs_from_post_mapped(mv.post_mapped)}], + } + for mv in mapped_variants + ] + + with ( patch( - "mavedb.worker.jobs.external_services.clingen.get_allele_registry_associations", - return_value={"c.123A>T": "CAID:0000001"}, + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=registered_alleles_mock, ), - patch.object(mock_job_manager.db, "add", return_value=None) as mock_db_add, + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), ): - result = await submit_score_set_mappings_to_car(mock_worker_ctx, 1, job_manager=mock_job_manager) + result = await submit_score_set_mappings_to_car( + standalone_worker_context, submit_score_set_mappings_to_car_sample_job_run_in_pipeline.id + ) - # Assert the variant without an HGVS string was skipped, and the other variant was updated with the CAID - mock_db_add.assert_has_calls([call(mocked_mapped_variant_with_hgvs)]) - assert mocked_mapped_variant_with_hgvs.clingen_allele_id == "CAID:0000001" assert result["status"] == "ok" - async def test_submit_score_set_mappings_to_car_success( + # Verify variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == len(mapped_variants) + for variant in variants: + assert variant.clingen_allele_id == f"CA{variant.id}" + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_job_run_in_pipeline) + assert submit_score_set_mappings_to_car_sample_job_run_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify the pipeline status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_pipeline) + assert submit_score_set_mappings_to_car_sample_pipeline.status == PipelineStatus.SUCCEEDED + + async def test_submit_score_set_mappings_to_car_submission_disabled( self, - mock_job_manager, - mock_job_run, - mock_worker_ctx, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, ): - """Test that submitting a score set with mapped variants completes successfully.""" + # Patch to disable ClinGen submission endpoint + with ( + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", False), + ): + result = await submit_score_set_mappings_to_car( + standalone_worker_context, submit_score_set_mappings_to_car_sample_job_run.id + ) + + assert result["status"] == "ok" - mock_job_run.job_params = {"score_set_id": 1, "correlation_id": uuid4().hex} + # Verify no variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 0 - mocked_score_set = MagicMock(spec=ScoreSetDbModel, urn="urn:1", num_variants=2) - mocked_mapped_variant_with_hgvs_999 = MagicMock(spec=MappedVariant, id=999) - mocked_mapped_variant_with_hgvs_1000 = MagicMock(spec=MappedVariant, id=1000) + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_job_run) + assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.SUCCEEDED + async def test_submit_score_set_mappings_to_car_no_submission_endpoint( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Patch to disable ClinGen submission endpoint with ( - # db.scalars is called three times in this function: once to get the score set (one), twice to get the mapped variants (all) - patch.object( - mock_job_manager.db, - "scalars", - return_value=MagicMock( - one=mocked_score_set, - all=MagicMock( - side_effect=[[mocked_mapped_variant_with_hgvs_999], [mocked_mapped_variant_with_hgvs_1000]] - ), - ), - ), - # db.execute is called to get the mapped variant IDs and post mapped data - patch.object(mock_job_manager.db, "execute", return_value=MagicMock(all=lambda: [(999, {}), (1000, {})])), - # get_hgvs_from_post_mapped is called twice, once for each mapped variant. mock that both - # calls return valid HGVS strings. - patch( - "mavedb.worker.jobs.external_services.clingen.get_hgvs_from_post_mapped", - side_effect=["c.122G>C", "c.123A>T"], - ), - # validate_job_params is called to validate job parameters - patch("mavedb.worker.jobs.external_services.clingen.validate_job_params", return_value=None), - # update_progress is called multiple times to update job progress - patch.object(mock_job_manager, "update_progress", return_value=None), - # CAR_SUBMISSION_ENDPOINT is patched to a test URL - patch( - "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", - "https://reg.test.genome.network/pytest", - ), - # Mock the dispatch_submissions method to return a test ClinGen allele object, which we should associate with the variant - patch.object( - ClinGenAlleleRegistryService, - "dispatch_submissions", - return_value=[TEST_CLINGEN_ALLELE_OBJECT, TEST_CLINGEN_ALLELE_OBJECT], - ), - # Mock the get_allele_registry_associations function to return a mapping from HGVS to CAID - patch( - "mavedb.worker.jobs.external_services.clingen.get_allele_registry_associations", - return_value={"c.122G>C": "CAID:0000000", "c.123A>T": "CAID:0000001"}, - ), - patch.object(mock_job_manager.db, "add", return_value=None) as mock_db_add, + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", ""), ): - result = await submit_score_set_mappings_to_car(mock_worker_ctx, 1, job_manager=mock_job_manager) + result = await submit_score_set_mappings_to_car( + standalone_worker_context, submit_score_set_mappings_to_car_sample_job_run.id + ) - # Assert the variant without an HGVS string was skipped, and the other variant was updated with the CAID - mock_db_add.assert_has_calls( - [call(mocked_mapped_variant_with_hgvs_999), call(mocked_mapped_variant_with_hgvs_1000)] + assert result["status"] == "failed" + assert ( + result["exception_details"]["message"] == "ClinGen Allele Registry submission endpoint is not configured." ) - assert mocked_mapped_variant_with_hgvs_999.clingen_allele_id == "CAID:0000000" - assert mocked_mapped_variant_with_hgvs_1000.clingen_allele_id == "CAID:0000001" - assert result["status"] == "ok" - async def test_submit_score_set_mappings_to_car_updates_progress( + # Verify no variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 0 + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_job_run) + assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.FAILED + + async def test_submit_score_set_mappings_to_car_no_mappings( self, - mock_job_manager, - mock_job_run, - mock_worker_ctx, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, ): - """Test that submitting a score set with mapped variants updates progress correctly.""" + """Test submitting score set mappings to ClinGen when there are no mappings.""" + with ( + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + ): + result = await submit_score_set_mappings_to_car( + standalone_worker_context, submit_score_set_mappings_to_car_sample_job_run.id + ) + + assert result["status"] == "ok" + + # Verify no variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 0 - mock_job_run.job_params = {"score_set_id": 1, "correlation_id": uuid4().hex} + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_job_run) + assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.SUCCEEDED - mocked_score_set = MagicMock(spec=ScoreSetDbModel, urn="urn:1", num_variants=2) - mocked_mapped_variant_with_hgvs_999 = MagicMock(spec=MappedVariant, id=999) - mocked_mapped_variant_with_hgvs_1000 = MagicMock(spec=MappedVariant, id=1000) + async def test_submit_score_set_mappings_to_car_no_registered_alleles( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + # Patch ClinGenAlleleRegistryService to return no registered alleles with ( - # db.scalars is called three times in this function: once to get the score set (one), twice to get the mapped variants (all) - patch.object( - mock_job_manager.db, - "scalars", - return_value=MagicMock( - one=mocked_score_set, - all=MagicMock( - side_effect=[[mocked_mapped_variant_with_hgvs_999], [mocked_mapped_variant_with_hgvs_1000]] - ), - ), - ), - # db.execute is called to get the mapped variant IDs and post mapped data - patch.object(mock_job_manager.db, "execute", return_value=MagicMock(all=lambda: [(999, {}), (1000, {})])), - # get_hgvs_from_post_mapped is called twice, once for each mapped variant. mock that both - # calls return valid HGVS strings. patch( - "mavedb.worker.jobs.external_services.clingen.get_hgvs_from_post_mapped", - side_effect=["c.122G>C", "c.123A>T"], + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=[], ), - # validate_job_params is called to validate job parameters - patch("mavedb.worker.jobs.external_services.clingen.validate_job_params", return_value=None), - # update_progress is called multiple times to update job progress - patch.object(mock_job_manager, "update_progress", return_value=None) as mock_update_progress, - # CAR_SUBMISSION_ENDPOINT is patched to a test URL + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + ): + result = await submit_score_set_mappings_to_car( + standalone_worker_context, submit_score_set_mappings_to_car_sample_job_run.id + ) + + assert result["status"] == "ok" + + # Verify no variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 0 + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_job_run) + assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.SUCCEEDED + + async def test_submit_score_set_mappings_to_car_no_linked_alleles( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenAlleleRegistryService to return registered alleles that do not match submitted HGVS + registered_alleles_mock = [ + {"@id": "CA123456", "type": "nucleotide", "genomicAlleles": [{"hgvs": "NC_000007.14:g.140453136A>C"}]}, + {"@id": "CA234567", "type": "nucleotide", "genomicAlleles": [{"hgvs": "NC_000007.14:g.140453136A>G"}]}, + ] + + with ( patch( - "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", - "https://reg.test.genome.network/pytest", - ), - # Mock the dispatch_submissions method to return a test ClinGen allele object, which we should associate with the variant - patch.object( - ClinGenAlleleRegistryService, - "dispatch_submissions", - return_value=[TEST_CLINGEN_ALLELE_OBJECT], + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=registered_alleles_mock, ), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), ): - result = await submit_score_set_mappings_to_car(mock_worker_ctx, 1, job_manager=mock_job_manager) + result = await submit_score_set_mappings_to_car( + standalone_worker_context, submit_score_set_mappings_to_car_sample_job_run.id + ) - # Assert the variant without an HGVS string was skipped, and the other variant was updated with the CAID - mock_update_progress.assert_has_calls( - [ - call(0, 100, "Starting CAR mapped resource submission."), - call(10, 100, "Preparing 2 mapped variants for CAR submission."), - call(15, 100, "Submitting mapped variants to CAR."), - call(50, 100, "Processing registered alleles from CAR."), - call(100, 100, "Completed CAR mapped resource submission."), - ] - ) assert result["status"] == "ok" + # Verify no variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 0 -@pytest.mark.asyncio -@pytest.mark.integration -class TestSubmitScoreSetMappingsToCARIntegration: - """Integration tests for the submit_score_set_mappings_to_car function.""" - - @pytest.fixture() - def setup_car_submission_job_run(self, session): - """Add a submit_score_set_mappings_to_car job run to the DB before each test.""" - job_run = JobRun( - job_type="external_service", - job_function="submit_score_set_mappings_to_car", - status=JobStatus.PENDING, - job_params={"correlation_id": "test-corr-id"}, - ) - session.add(job_run) - session.commit() - return job_run + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_job_run) + assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.SUCCEEDED - async def test_submit_score_set_mappings_to_car_no_submission_endpoint( + async def test_submit_score_set_mappings_to_car_propagates_exception_to_decorator( self, standalone_worker_context, session, - with_populated_test_data, - setup_car_submission_job_run, - async_client, - data_files, - arq_redis, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, ): - """Test that submitting a score set with no CAR submission endpoint configured raises an exception.""" - score_set = await setup_records_files_and_variants_with_mapping( + # Create mappings in the score set + await create_mappings_in_score_set( session, - async_client, - data_files, - TEST_MINIMAL_SEQ_SCORESET, + mock_s3_client, standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, ) - with patch( - "mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", - None, - ): - with pytest.raises(ValueError): - await submit_score_set_mappings_to_car( - standalone_worker_context, - score_set.id, - JobManager( - session, - arq_redis, - setup_car_submission_job_run.id, - ), - ) + # Patch ClinGenAlleleRegistryService to raise an exception + with ( + patch( + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + side_effect=Exception("ClinGen service error"), + ), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + ): + result = await submit_score_set_mappings_to_car( + standalone_worker_context, submit_score_set_mappings_to_car_sample_job_run.id + ) + + assert result["status"] == "failed" + assert result["exception_details"]["message"] == "ClinGen service error" + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_job_run) + assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.FAILED + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestClingenSubmitScoreSetMappingsToCarArqContext: + """Tests for the Clingen submit_score_set_mappings_to_car function with ARQ context.""" + + async def test_submit_score_set_mappings_to_car_with_arq_context_independent( + self, + standalone_worker_context, + session, + arq_redis, + arq_worker, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenAlleleRegistryService to return registered alleles + mapped_variants = session.scalars(select(MappedVariant)).all() + registered_alleles_mock = [ + { + "@id": f"CA{mv.id}", + "type": "nucleotide", + "genomicAlleles": [{"hgvs": get_hgvs_from_post_mapped(mv.post_mapped)}], + } + for mv in mapped_variants + ] + + with ( + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch( + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=registered_alleles_mock, + ), + ): + await arq_redis.enqueue_job( + "submit_score_set_mappings_to_car", submit_score_set_mappings_to_car_sample_job_run.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_job_run) + assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.SUCCEEDED + + # Verify variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == len(mapped_variants) + for variant in variants: + assert variant.clingen_allele_id == f"CA{variant.id}" + + async def test_submit_score_set_mappings_to_car_with_arq_context_pipeline( + self, + standalone_worker_context, + session, + arq_redis, + arq_worker, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run_in_pipeline, + submit_score_set_mappings_to_car_sample_pipeline, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenAlleleRegistryService to return registered alleles + mapped_variants = session.scalars(select(MappedVariant)).all() + registered_alleles_mock = [ + { + "@id": f"CA{mv.id}", + "type": "nucleotide", + "genomicAlleles": [{"hgvs": get_hgvs_from_post_mapped(mv.post_mapped)}], + } + for mv in mapped_variants + ] + + with ( + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch( + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + return_value=registered_alleles_mock, + ), + ): + await arq_redis.enqueue_job( + "submit_score_set_mappings_to_car", submit_score_set_mappings_to_car_sample_job_run_in_pipeline.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_job_run_in_pipeline) + assert submit_score_set_mappings_to_car_sample_job_run_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify the pipeline status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_pipeline) + assert submit_score_set_mappings_to_car_sample_pipeline.status == PipelineStatus.SUCCEEDED + + # Verify variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == len(mapped_variants) + for variant in variants: + assert variant.clingen_allele_id == f"CA{variant.id}" + + async def test_submit_score_set_mappings_to_car_with_arq_context_exception_handling_independent( + self, + standalone_worker_context, + session, + arq_redis, + arq_worker, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenAlleleRegistryService to raise an exception + with ( + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch( + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + side_effect=Exception("ClinGen service error"), + ), + ): + await arq_redis.enqueue_job( + "submit_score_set_mappings_to_car", submit_score_set_mappings_to_car_sample_job_run.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_job_run) + assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.FAILED + assert submit_score_set_mappings_to_car_sample_job_run.error_message == "ClinGen service error" + + # Verify no variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 0 + + async def test_submit_score_set_mappings_to_car_with_arq_context_exception_handling_pipeline( + self, + standalone_worker_context, + session, + arq_redis, + arq_worker, + with_submit_score_set_mappings_to_car_job, + submit_score_set_mappings_to_car_sample_job_run_in_pipeline, + submit_score_set_mappings_to_car_sample_pipeline, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenAlleleRegistryService to raise an exception + with ( + patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch( + "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", + side_effect=Exception("ClinGen service error"), + ), + ): + await arq_redis.enqueue_job( + "submit_score_set_mappings_to_car", submit_score_set_mappings_to_car_sample_job_run_in_pipeline.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_job_run_in_pipeline) + assert submit_score_set_mappings_to_car_sample_job_run_in_pipeline.status == JobStatus.FAILED + assert submit_score_set_mappings_to_car_sample_job_run_in_pipeline.error_message == "ClinGen service error" + + # Verify the pipeline status is updated in the database + session.refresh(submit_score_set_mappings_to_car_sample_pipeline) + assert submit_score_set_mappings_to_car_sample_pipeline.status == PipelineStatus.FAILED + + # Verify no variants have CAIDs assigned + variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() + assert len(variants) == 0 + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestClingenSubmitScoreSetMappingsToLdhUnit: + """Unit tests for the Clingen submit_score_set_mappings_to_car function.""" + + async def test_submit_score_set_mappings_to_ldh_no_variants( + self, + mock_worker_ctx, + session, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + with ( + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + patch("mavedb.worker.jobs.external_services.clingen.LDH_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, + ): + result = await submit_score_set_mappings_to_ldh( + mock_worker_ctx, + submit_score_set_mappings_to_ldh_sample_job_run.id, + JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id + ), + ) + + mock_update_progress.assert_called_with(100, 100, "No mapped variants to submit to LDH. Skipping submission.") + assert result["status"] == "ok" + + async def test_submit_score_set_mappings_to_ldh_all_submissions_failed( + self, + mock_worker_ctx, + session, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + mock_worker_ctx, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + async def dummy_submission_failure(*args, **kwargs): + return ([], ["Submission failed"]) + + # Patch ClinGenLdhService to simulate all submissions failing + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_submission_failure(), + ), + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + patch("mavedb.worker.jobs.external_services.clingen.LDH_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, + pytest.raises(LDHSubmissionFailureError), + ): + await submit_score_set_mappings_to_ldh( + mock_worker_ctx, + submit_score_set_mappings_to_ldh_sample_job_run.id, + JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id + ), + ) + + mock_update_progress.assert_called_with(100, 100, "All mapped variant submissions to LDH failed.") + + async def test_submit_score_set_mappings_to_ldh_hgvs_not_found( + self, + mock_worker_ctx, + session, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + mock_worker_ctx, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenLdhService to raise HGVS not found exception + with ( + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + patch("mavedb.worker.jobs.external_services.clingen.LDH_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch("mavedb.worker.jobs.external_services.clingen.get_hgvs_from_post_mapped", return_value=None), + patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, + ): + result = await submit_score_set_mappings_to_ldh( + mock_worker_ctx, + submit_score_set_mappings_to_ldh_sample_job_run.id, + JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id + ), + ) + + mock_update_progress.assert_called_with( + 100, 100, "No valid mapped variants to submit to LDH. Skipping submission." + ) + assert result["status"] == "ok" + + async def test_submit_score_set_mappings_to_ldh_propagates_exception( + self, + mock_worker_ctx, + session, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + mock_worker_ctx, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenLdhService to raise an exception + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + side_effect=Exception("LDH service error"), + ), + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + patch("mavedb.worker.jobs.external_services.clingen.LDH_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + pytest.raises(Exception) as exc_info, + ): + await submit_score_set_mappings_to_ldh( + mock_worker_ctx, + submit_score_set_mappings_to_ldh_sample_job_run.id, + JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id + ), + ) + + assert str(exc_info.value) == "LDH service error" + + async def test_submit_score_set_mappings_to_ldh_partial_submission( + self, + mock_worker_ctx, + session, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + mock_worker_ctx, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + async def dummy_partial_submission(*args, **kwargs): + return ( + [{"@id": "LDH12345"}, {"@id": "LDH23456"}], + ["Submission failed for some variants"], + ) + + # Patch ClinGenLdhService to simulate partial submission success + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_partial_submission(), + ), + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + patch("mavedb.worker.jobs.external_services.clingen.LDH_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, + ): + result = await submit_score_set_mappings_to_ldh( + mock_worker_ctx, + submit_score_set_mappings_to_ldh_sample_job_run.id, + JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id + ), + ) + + assert result["status"] == "ok" + mock_update_progress.assert_called_with( + 100, 100, "Finalized LDH mapped resource submission (2 successes, 1 failures)." + ) + + async def test_submit_score_set_mappings_to_ldh_all_successful_submission( + self, + mock_worker_ctx, + session, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + mock_worker_ctx, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + async def dummy_successful_submission(*args, **kwargs): + return ( + [{"@id": "LDH12345"}, {"@id": "LDH23456"}], + [], + ) + + # Patch ClinGenLdhService to simulate all submissions succeeding + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_successful_submission(), + ), + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + patch("mavedb.worker.jobs.external_services.clingen.LDH_SUBMISSION_ENDPOINT", "http://fake-endpoint"), + patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, + ): + result = await submit_score_set_mappings_to_ldh( + mock_worker_ctx, + submit_score_set_mappings_to_ldh_sample_job_run.id, + JobManager( + mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id + ), + ) + + assert result["status"] == "ok" + mock_update_progress.assert_called_with( + 100, 100, "Finalized LDH mapped resource submission (2 successes, 0 failures)." + ) + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestClingenSubmitScoreSetMappingsToLdhIntegration: + """Integration tests for the Clingen submit_score_set_mappings_to_ldh function.""" + + async def test_submit_score_set_mappings_to_ldh_independent( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + async def dummy_ldh_submission(*args, **kwargs): + return ( + [{"@id": "LDH12345"}, {"@id": "LDH23456"}], + [], + ) + + # Patch to disable ClinGen submission endpoint + with ( + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_ldh_submission(), + ), + ): + result = await submit_score_set_mappings_to_ldh( + standalone_worker_context, submit_score_set_mappings_to_ldh_sample_job_run.id + ) + + assert result["status"] == "ok" + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_ldh_sample_job_run) + assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.SUCCEEDED + + async def test_submit_score_set_mappings_to_ldh_pipeline_ctx( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline, + submit_score_set_mappings_to_ldh_sample_pipeline, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + async def dummy_ldh_submission(*args, **kwargs): + return ( + [{"@id": "LDH12345"}, {"@id": "LDH23456"}], + [], + ) + + # Patch to disable ClinGen submission endpoint + with ( + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_ldh_submission(), + ), + ): + result = await submit_score_set_mappings_to_ldh( + standalone_worker_context, submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline.id + ) + + assert result["status"] == "ok" + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline) + assert submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify the pipeline status is updated in the database + session.refresh(submit_score_set_mappings_to_ldh_sample_pipeline) + assert submit_score_set_mappings_to_ldh_sample_pipeline.status == PipelineStatus.SUCCEEDED + + async def test_submit_score_set_mappings_to_ldh_propagates_exception_to_decorator( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenLdhService to raise an exception + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + side_effect=Exception("LDH service error"), + ), + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + ): + result = await submit_score_set_mappings_to_ldh( + standalone_worker_context, submit_score_set_mappings_to_ldh_sample_job_run.id + ) + + assert result["status"] == "failed" + assert result["exception_details"]["message"] == "LDH service error" + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_ldh_sample_job_run) + assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.FAILED + + async def test_submit_score_set_mappings_to_ldh_no_linked_alleles( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + async def dummy_no_linked_alleles_submission(*args, **kwargs): + return ([], []) + + # Patch ClinGenLdhService to simulate no linked alleles found + with ( + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_no_linked_alleles_submission(), + ), + ): + result = await submit_score_set_mappings_to_ldh( + standalone_worker_context, submit_score_set_mappings_to_ldh_sample_job_run.id + ) + + assert result["status"] == "ok" + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_ldh_sample_job_run) + assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.SUCCEEDED + + async def test_submit_score_set_mappings_to_ldh_hgvs_not_found( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenLdhService to raise HGVS not found exception + with ( + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + patch("mavedb.worker.jobs.external_services.clingen.get_hgvs_from_post_mapped", return_value=None), + ): + result = await submit_score_set_mappings_to_ldh( + standalone_worker_context, submit_score_set_mappings_to_ldh_sample_job_run.id + ) + + assert result["status"] == "ok" + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_ldh_sample_job_run) + assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.SUCCEEDED + + async def test_submit_score_set_mappings_to_ldh_all_submissions_failed( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + async def dummy_submission_failure(*args, **kwargs): + return ([], ["Submission failed"]) + + # Patch ClinGenLdhService to simulate all submissions failing + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_submission_failure(), + ), + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + ): + result = await submit_score_set_mappings_to_ldh( + standalone_worker_context, submit_score_set_mappings_to_ldh_sample_job_run.id + ) + + assert result["status"] == "failed" + assert "All LDH submissions failed for score set" in result["exception_details"]["message"] + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_ldh_sample_job_run) + assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.FAILED + + async def test_submit_score_set_mappings_to_ldh_partial_submission( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + async def dummy_partial_submission(*args, **kwargs): + return ( + [{"@id": "LDH12345"}], + ["Submission failed for some variants"], + ) + + # Patch ClinGenLdhService to simulate partial submission success + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_partial_submission(), + ), + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + ): + result = await submit_score_set_mappings_to_ldh( + standalone_worker_context, submit_score_set_mappings_to_ldh_sample_job_run.id + ) + + assert result["status"] == "ok" + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_ldh_sample_job_run) + assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.SUCCEEDED + + async def test_submit_score_set_mappings_to_ldh_all_successful_submission( + self, + standalone_worker_context, + session, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + async def dummy_successful_submission(*args, **kwargs): + return ( + [{"@id": "LDH12345"}, {"@id": "LDH23456"}], + [], + ) + + # Patch ClinGenLdhService to simulate all submissions succeeding + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_successful_submission(), + ), + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + ): + result = await submit_score_set_mappings_to_ldh( + standalone_worker_context, submit_score_set_mappings_to_ldh_sample_job_run.id + ) + + assert result["status"] == "ok" + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_ldh_sample_job_run) + assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.SUCCEEDED + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestClingenSubmitScoreSetMappingsToLdhArqIntegration: + """ARQ Integration tests for the Clingen submit_score_set_mappings_to_ldh function.""" + + async def test_submit_score_set_mappings_to_ldh_independent( + self, + standalone_worker_context, + session, + arq_redis, + arq_worker, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + async def dummy_ldh_submission(*args, **kwargs): + return ( + [{"@id": "LDH12345"}, {"@id": "LDH23456"}], + [], + ) + + # Patch to disable ClinGen submission endpoint + with ( + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_ldh_submission(), + ), + ): + await arq_redis.enqueue_job( + "submit_score_set_mappings_to_ldh", submit_score_set_mappings_to_ldh_sample_job_run.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_ldh_sample_job_run) + assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.SUCCEEDED + + async def test_submit_score_set_mappings_to_ldh_with_arq_context_in_pipeline( + self, + standalone_worker_context, + session, + arq_redis, + arq_worker, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline, + submit_score_set_mappings_to_ldh_sample_pipeline, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + async def dummy_ldh_submission(*args, **kwargs): + return ( + [{"@id": "LDH12345"}, {"@id": "LDH23456"}], + [], + ) + + # Patch to disable ClinGen submission endpoint + with ( + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=dummy_ldh_submission(), + ), + ): + await arq_redis.enqueue_job( + "submit_score_set_mappings_to_ldh", submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline) + assert submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify the pipeline status is updated in the database + session.refresh(submit_score_set_mappings_to_ldh_sample_pipeline) + assert submit_score_set_mappings_to_ldh_sample_pipeline.status == PipelineStatus.SUCCEEDED + + async def test_submit_score_set_mappings_to_ldh_with_arq_context_exception_handling( + self, + standalone_worker_context, + session, + arq_redis, + arq_worker, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenLdhService to raise an exception + with ( + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + side_effect=Exception("LDH service error"), + ), + ): + await arq_redis.enqueue_job( + "submit_score_set_mappings_to_ldh", submit_score_set_mappings_to_ldh_sample_job_run.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_ldh_sample_job_run) + assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.FAILED + assert submit_score_set_mappings_to_ldh_sample_job_run.error_message == "LDH service error" + + async def test_submit_score_set_mappings_to_ldh_with_arq_context_exception_handling_pipeline_ctx( + self, + standalone_worker_context, + session, + arq_redis, + arq_worker, + with_submit_score_set_mappings_to_ldh_job, + submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline, + submit_score_set_mappings_to_ldh_sample_pipeline, + mock_s3_client, + sample_score_dataframe, + sample_count_dataframe, + with_dummy_setup_jobs, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ): + # Create mappings in the score set + await create_mappings_in_score_set( + session, + mock_s3_client, + standalone_worker_context, + sample_score_dataframe, + sample_count_dataframe, + dummy_variant_creation_job_run, + dummy_variant_mapping_job_run, + ) + + # Patch ClinGenLdhService to raise an exception + with ( + patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + side_effect=Exception("LDH service error"), + ), + ): + await arq_redis.enqueue_job( + "submit_score_set_mappings_to_ldh", submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline.id + ) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify the job status is updated in the database + session.refresh(submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline) + assert submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline.status == JobStatus.FAILED + assert submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline.error_message == "LDH service error" + + # Verify the pipeline status is updated in the database + session.refresh(submit_score_set_mappings_to_ldh_sample_pipeline) + assert submit_score_set_mappings_to_ldh_sample_pipeline.status == PipelineStatus.FAILED diff --git a/tests/worker/jobs/pipeline_management/conftest.py b/tests/worker/jobs/pipeline_management/conftest.py deleted file mode 100644 index d7d2a239..00000000 --- a/tests/worker/jobs/pipeline_management/conftest.py +++ /dev/null @@ -1,62 +0,0 @@ -import pytest - -from mavedb.models.job_run import JobRun -from mavedb.models.pipeline import Pipeline - - -@pytest.fixture -def sample_dummy_pipeline(): - """Create a sample Pipeline instance for testing.""" - - return Pipeline( - name="Dummy Pipeline", - description="A dummy pipeline for testing purposes", - ) - - -@pytest.fixture -def with_dummy_pipeline(session, sample_dummy_pipeline): - """Fixture to ensure dummy pipeline exists in the database.""" - session.add(sample_dummy_pipeline) - session.commit() - - -@pytest.fixture -def sample_dummy_pipeline_start(session, with_dummy_pipeline, sample_dummy_pipeline): - """Create a sample JobRun instance for starting the dummy pipeline.""" - start_job_run = JobRun( - pipeline_id=sample_dummy_pipeline.id, - job_type="start_pipeline", - job_function="start_pipeline", - ) - session.add(start_job_run) - session.commit() - - return start_job_run - - -@pytest.fixture -def with_dummy_pipeline_start(session, with_dummy_pipeline, sample_dummy_pipeline_start): - """Fixture to ensure a start pipeline job run for the dummy pipeline exists in the database.""" - session.add(sample_dummy_pipeline_start) - session.commit() - - -@pytest.fixture -def sample_dummy_pipeline_step(session, sample_dummy_pipeline): - """Create a sample PipelineStep instance for the dummy pipeline.""" - step = JobRun( - pipeline_id=sample_dummy_pipeline.id, - job_type="dummy_step", - job_function="dummy_arq_function", - ) - session.add(step) - session.commit() - return step - - -@pytest.fixture -def with_full_dummy_pipeline(session, with_dummy_pipeline_start, sample_dummy_pipeline, sample_dummy_pipeline_step): - """Fixture to ensure dummy pipeline steps exist in the database.""" - session.add(sample_dummy_pipeline_step) - session.commit() diff --git a/tests/worker/jobs/variant_processing/conftest.py b/tests/worker/jobs/variant_processing/conftest.py deleted file mode 100644 index 1b88df2d..00000000 --- a/tests/worker/jobs/variant_processing/conftest.py +++ /dev/null @@ -1,191 +0,0 @@ -from unittest import mock - -import pytest -from mypy_boto3_s3 import S3Client - -from mavedb.models.job_run import JobRun -from mavedb.models.pipeline import Pipeline - - -@pytest.fixture -def create_variants_sample_params(with_populated_domain_data, sample_score_set, sample_user): - """Provide sample parameters for create_variants_for_score_set job.""" - - return { - "scores_file_key": "sample_scores.csv", - "counts_file_key": "sample_counts.csv", - "correlation_id": "sample-correlation-id", - "updater_id": sample_user.id, - "score_set_id": sample_score_set.id, - "score_columns_metadata": {"s_0": {"description": "metadataS", "details": "detailsS"}}, - "count_columns_metadata": {"c_0": {"description": "metadataC", "details": "detailsC"}}, - } - - -@pytest.fixture -def map_variants_sample_params(with_populated_domain_data, sample_score_set, sample_user): - """Provide sample parameters for map_variants_for_score_set job.""" - - return { - "score_set_id": sample_score_set.id, - "correlation_id": "sample-mapping-correlation-id", - "updater_id": sample_user.id, - } - - -@pytest.fixture -def mock_s3_client(): - """Mock S3 client for tests that interact with S3.""" - - with mock.patch("mavedb.worker.jobs.variant_processing.creation.s3_client") as mock_s3_client_func: - mock_s3 = mock.MagicMock(spec=S3Client) - mock_s3_client_func.return_value = mock_s3 - yield mock_s3 - - -@pytest.fixture -def sample_independent_variant_creation_run(create_variants_sample_params): - """Create a JobRun instance for variant creation job.""" - - return JobRun( - urn="test:create_variants_for_score_set", - job_type="create_variants_for_score_set", - job_function="create_variants_for_score_set", - max_retries=3, - retry_count=0, - job_params=create_variants_sample_params, - ) - - -@pytest.fixture -def sample_independent_variant_mapping_run(map_variants_sample_params): - """Create a JobRun instance for variant mapping job.""" - - return JobRun( - urn="test:map_variants_for_score_set", - job_type="map_variants_for_score_set", - job_function="map_variants_for_score_set", - max_retries=3, - retry_count=0, - job_params=map_variants_sample_params, - ) - - -@pytest.fixture -def dummy_pipeline_step(): - """Create a dummy pipeline step function for testing.""" - - return JobRun( - urn="test:dummy_pipeline_step", - job_type="dummy_pipeline_step", - job_function="dummy_arq_function", - max_retries=3, - retry_count=0, - ) - - -@pytest.fixture -def sample_pipeline_variant_creation_run( - session, - with_variant_creation_pipeline, - sample_variant_creation_pipeline, - sample_independent_variant_creation_run, -): - """Create a JobRun instance for variant creation job.""" - - sample_independent_variant_creation_run.pipeline_id = sample_variant_creation_pipeline.id - session.add(sample_independent_variant_creation_run) - session.commit() - return sample_independent_variant_creation_run - - -@pytest.fixture -def sample_pipeline_variant_mapping_run( - session, - with_variant_mapping_pipeline, - sample_independent_variant_mapping_run, - sample_variant_mapping_pipeline, -): - """Create a JobRun instance for variant mapping job.""" - - sample_independent_variant_mapping_run.pipeline_id = sample_variant_mapping_pipeline.id - session.add(sample_independent_variant_mapping_run) - session.commit() - return sample_independent_variant_mapping_run - - -@pytest.fixture -def sample_variant_creation_pipeline(): - """Create a Pipeline instance.""" - - return Pipeline( - name="variant_creation_pipeline", - description="Pipeline for creating variants", - ) - - -@pytest.fixture -def sample_variant_mapping_pipeline(): - """Create a Pipeline instance.""" - - return Pipeline( - name="variant_mapping_pipeline", - description="Pipeline for mapping variants", - ) - - -@pytest.fixture -def with_independent_processing_runs( - session, - sample_independent_variant_creation_run, - sample_independent_variant_mapping_run, -): - """Fixture to ensure independent variant processing runs exist in the database.""" - - session.add(sample_independent_variant_creation_run) - session.add(sample_independent_variant_mapping_run) - session.commit() - - -@pytest.fixture -def with_variant_creation_pipeline(session, sample_variant_creation_pipeline): - """Fixture to ensure variant creation pipeline and its runs exist in the database.""" - session.add(sample_variant_creation_pipeline) - session.commit() - - -@pytest.fixture -def with_variant_creation_pipeline_runs( - session, - with_variant_creation_pipeline, - sample_variant_creation_pipeline, - sample_pipeline_variant_creation_run, - dummy_pipeline_step, -): - """Fixture to ensure pipeline variant processing runs exist in the database.""" - session.add(sample_pipeline_variant_creation_run) - dummy_pipeline_step.pipeline_id = sample_variant_creation_pipeline.id - session.add(dummy_pipeline_step) - session.commit() - - -@pytest.fixture -def with_variant_mapping_pipeline(session, sample_variant_mapping_pipeline): - """Fixture to ensure variant mapping pipeline and its runs exist in the database.""" - session.add(sample_variant_mapping_pipeline) - session.commit() - - -@pytest.fixture -def with_variant_mapping_pipeline_runs( - session, - with_variant_mapping_pipeline, - sample_variant_mapping_pipeline, - sample_pipeline_variant_mapping_run, - dummy_pipeline_step, -): - """Fixture to ensure pipeline variant processing runs exist in the database.""" - session.add(sample_pipeline_variant_mapping_run) - dummy_pipeline_step.pipeline_id = sample_variant_mapping_pipeline.id - session.add(dummy_pipeline_step) - session.commit() From 3aedeade0cd72f47fab31d3270f76d0cdbe7e212 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Tue, 27 Jan 2026 21:33:32 -0800 Subject: [PATCH 34/70] fixup(variant creation) --- src/mavedb/worker/jobs/variant_processing/creation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/mavedb/worker/jobs/variant_processing/creation.py b/src/mavedb/worker/jobs/variant_processing/creation.py index 27a5a1aa..37b7605e 100644 --- a/src/mavedb/worker/jobs/variant_processing/creation.py +++ b/src/mavedb/worker/jobs/variant_processing/creation.py @@ -105,6 +105,7 @@ async def create_variants_for_score_set(ctx: dict, job_id: int, job_manager: Job s3 = s3_client() scores = io.BytesIO() s3.download_fileobj(Bucket=CSV_UPLOAD_S3_BUCKET_NAME, Key=score_file_key, Fileobj=scores) + scores.seek(0) scores_df = pd.read_csv(scores) # Counts file is optional @@ -112,6 +113,7 @@ async def create_variants_for_score_set(ctx: dict, job_id: int, job_manager: Job if count_file_key: counts = io.BytesIO() s3.download_fileobj(Bucket=CSV_UPLOAD_S3_BUCKET_NAME, Key=count_file_key, Fileobj=counts) + counts.seek(0) counts_df = pd.read_csv(counts) logger.debug(msg="Successfully fetched file resources from S3", extra=job_manager.logging_context()) From 2af66dd5ed01249ff8954f8447d17a720103e3b2 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Tue, 27 Jan 2026 23:23:14 -0800 Subject: [PATCH 35/70] feat: implement job and pipeline factories with definitions and tests --- src/mavedb/lib/types/workflow.py | 16 ++ src/mavedb/lib/workflow/__init__.py | 9 + src/mavedb/lib/workflow/definitions.py | 82 +++++++ src/mavedb/lib/workflow/job_factory.py | 62 +++++ src/mavedb/lib/workflow/pipeline_factory.py | 116 ++++++++++ src/mavedb/lib/workflow/py.typed | 0 src/mavedb/models/enums/job_pipeline.py | 10 + tests/lib/workflow/conftest.py | 89 ++++++++ tests/lib/workflow/test_job_factory.py | 191 ++++++++++++++++ tests/lib/workflow/test_pipeline_factory.py | 238 ++++++++++++++++++++ 10 files changed, 813 insertions(+) create mode 100644 src/mavedb/lib/types/workflow.py create mode 100644 src/mavedb/lib/workflow/__init__.py create mode 100644 src/mavedb/lib/workflow/definitions.py create mode 100644 src/mavedb/lib/workflow/job_factory.py create mode 100644 src/mavedb/lib/workflow/pipeline_factory.py create mode 100644 src/mavedb/lib/workflow/py.typed create mode 100644 tests/lib/workflow/conftest.py create mode 100644 tests/lib/workflow/test_job_factory.py create mode 100644 tests/lib/workflow/test_pipeline_factory.py diff --git a/src/mavedb/lib/types/workflow.py b/src/mavedb/lib/types/workflow.py new file mode 100644 index 00000000..b0e6413e --- /dev/null +++ b/src/mavedb/lib/types/workflow.py @@ -0,0 +1,16 @@ +from typing import Any, TypedDict + +from mavedb.models.enums.job_pipeline import DependencyType + + +class JobDefinition(TypedDict): + key: str + type: str + function: str + params: dict[str, Any] + dependencies: list[tuple[str, DependencyType]] + + +class PipelineDefinition(TypedDict): + description: str + job_definitions: list[JobDefinition] diff --git a/src/mavedb/lib/workflow/__init__.py b/src/mavedb/lib/workflow/__init__.py new file mode 100644 index 00000000..65be1386 --- /dev/null +++ b/src/mavedb/lib/workflow/__init__.py @@ -0,0 +1,9 @@ +from .definitions import PIPELINE_DEFINITIONS +from .job_factory import JobFactory +from .pipeline_factory import PipelineFactory + +__all__ = [ + "JobFactory", + "PipelineFactory", + "PIPELINE_DEFINITIONS", +] diff --git a/src/mavedb/lib/workflow/definitions.py b/src/mavedb/lib/workflow/definitions.py new file mode 100644 index 00000000..49aa4dd7 --- /dev/null +++ b/src/mavedb/lib/workflow/definitions.py @@ -0,0 +1,82 @@ +from mavedb.lib.types.workflow import PipelineDefinition +from mavedb.models.enums.job_pipeline import DependencyType, JobType + +# As a general rule, job keys should match function names for clarity. In some cases of +# repeated jobs, a suffix may be added to the key for uniqueness. + +PIPELINE_DEFINITIONS: dict[str, PipelineDefinition] = { + "validate_map_annotate_score_set": { + "description": "Pipeline to validate, map, and annotate variants for a score set.", + "job_definitions": [ + { + "key": "create_variants_for_score_set", + "function": "create_variants_for_score_set", + "type": JobType.VARIANT_CREATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "updater_id": None, # Required param to be filled in at runtime + "scores_file_key": None, # Required param to be filled in at runtime + "counts_file_key": None, # Required param to be filled in at runtime + "score_columns_metadata": None, # Required param to be filled in at runtime + "count_columns_metadata": None, # Required param to be filled in at runtime + }, + "dependencies": [], + }, + { + "key": "map_variants_for_score_set", + "function": "map_variants_for_score_set", + "type": JobType.VARIANT_MAPPING, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "updater_id": None, # Required param to be filled in at runtime + }, + "dependencies": [("create_variants_for_score_set", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "submit_score_set_mappings_to_car", + "function": "submit_score_set_mappings_to_car", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "updater_id": None, # Required param to be filled in at runtime + }, + "dependencies": [("map_variants_for_score_set", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "link_gnomad_variants", + "function": "link_gnomad_variants", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "submit_uniprot_mapping_jobs_for_score_set", + "function": "submit_uniprot_mapping_jobs_for_score_set", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + }, + "dependencies": [("map_variants_for_score_set", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "poll_uniprot_mapping_jobs_for_score_set", + "function": "poll_uniprot_mapping_jobs_for_score_set", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "mapping_jobs": {}, # Required param to be filled in at runtime by previous job + }, + "dependencies": [("submit_uniprot_mapping_jobs_for_score_set", DependencyType.SUCCESS_REQUIRED)], + }, + ], + }, + # Add more pipelines here +} diff --git a/src/mavedb/lib/workflow/job_factory.py b/src/mavedb/lib/workflow/job_factory.py new file mode 100644 index 00000000..a5aa4dfa --- /dev/null +++ b/src/mavedb/lib/workflow/job_factory.py @@ -0,0 +1,62 @@ +from copy import deepcopy +from typing import Optional + +from sqlalchemy.orm import Session + +from mavedb import __version__ as mavedb_version +from mavedb.lib.types.workflow import JobDefinition +from mavedb.models.job_run import JobRun + + +class JobFactory: + """ + JobFactory is responsible for creating and persisting JobRun instances based on + provided job definitions and pipeline parameters. + + Attributes: + session (Session): The SQLAlchemy session used for database operations. + + Methods: + create_job_run(job_def: JobDefinition, pipeline_id: Optional[int], user_id: int, correlation_id: str, pipeline_params: dict) -> JobRun:""" + + def __init__(self, session: Session): + self.session = session + + def create_job_run( + self, job_def: JobDefinition, correlation_id: str, pipeline_params: dict, pipeline_id: Optional[int] = None + ) -> JobRun: + """ + Creates and persists a new JobRun instance based on the provided job definition and pipeline parameters. + + Args: + job_def (JobDefinition): The job definition containing job type, function, and parameter template. + pipeline_id (Optional[int]): The ID of the pipeline this job run is associated with. + correlation_id (str): A unique identifier for correlating this job run with external systems or logs. + pipeline_params (dict): A dictionary of parameters to fill in required job parameters and allow for extensibility. + + Returns: + JobRun: The newly created JobRun instance (not yet committed to the database). + + Raises: + ValueError: If any required parameter defined in the job definition is missing from pipeline_params. + """ + job_params = deepcopy(job_def["params"]) + + # Fill in required params from pipeline_params + for key in job_params: + if job_params[key] is None: + if key not in pipeline_params: + raise ValueError(f"Missing required param: {key}") + job_params[key] = pipeline_params[key] + + job_run = JobRun( + job_type=job_def["type"], + job_function=job_def["function"], + job_params=job_params, + pipeline_id=pipeline_id, + mavedb_version=mavedb_version, + correlation_id=correlation_id, + ) # type: ignore[call-arg] + + self.session.add(job_run) + return job_run diff --git a/src/mavedb/lib/workflow/pipeline_factory.py b/src/mavedb/lib/workflow/pipeline_factory.py new file mode 100644 index 00000000..42ec1e00 --- /dev/null +++ b/src/mavedb/lib/workflow/pipeline_factory.py @@ -0,0 +1,116 @@ +from sqlalchemy.orm import Session + +from mavedb import __version__ as mavedb_version +from mavedb.lib.logging.context import correlation_id_for_context +from mavedb.lib.workflow.definitions import PIPELINE_DEFINITIONS +from mavedb.lib.workflow.job_factory import JobFactory +from mavedb.models.enums.job_pipeline import JobType +from mavedb.models.job_dependency import JobDependency +from mavedb.models.job_run import JobRun +from mavedb.models.pipeline import Pipeline +from mavedb.models.user import User + + +class PipelineFactory: + """ + PipelineFactory is responsible for creating Pipeline instances and their associated JobRun and JobDependency records in the database. + + Attributes: + session (Session): The SQLAlchemy session used for database operations. + + Methods: + __init__(session: Session): + Initializes the PipelineFactory with a database session. + + create_pipeline( + pipeline_name: str, + pipeline_description: Optional[str], + creating_user: User, + pipeline_params: dict + ) -> Pipeline: + Creates a new Pipeline along with its JobRun and JobDependency records, + commits them to the database, and returns the created Pipeline object. + """ + + def __init__(self, session: Session): + self.session = session + + def create_pipeline( + self, pipeline_name: str, creating_user: User, pipeline_params: dict + ) -> tuple[Pipeline, JobRun]: + """ + Creates a new Pipeline instance along with its associated JobRun and JobDependency records. + + Args: + pipeline_name (str): The name of the pipeline to create. + pipeline_description (Optional[str]): A description for the pipeline. + creating_user (User): The user object representing the user creating the pipeline. + pipeline_params (dict): Additional parameters for pipeline creation, such as correlation_id. + + Returns: + Pipeline: The created Pipeline object. + JobRun: The JobRun object representing the start of the pipeline. + + Raises: + KeyError: If the specified pipeline_name is not found in PIPELINE_DEFINITIONS. + Exception: If there is an error during database operations. + + Side Effects: + - Adds and commits new Pipeline, JobRun, and JobDependency records to the database session. + """ + pipeline_def = PIPELINE_DEFINITIONS[pipeline_name] + jobs = pipeline_def["job_definitions"] + job_runs: dict[str, JobRun] = {} + + correlation_id = pipeline_params.get("correlation_id", correlation_id_for_context()) + + pipeline = Pipeline( + name=pipeline_name, + description=pipeline_def["description"], + correlation_id=correlation_id, + created_by_user_id=creating_user.id, + mavedb_version=mavedb_version, + ) # type: ignore[call-arg] + self.session.add(pipeline) + self.session.flush() # To get pipeline.id + + start_pipeline_job = JobRun( + job_type=JobType.PIPELINE_MANAGEMENT, + job_function="start_pipeline", + job_params={}, + pipeline_id=pipeline.id, + mavedb_version=mavedb_version, + correlation_id=correlation_id, + ) # type: ignore[call-arg] + self.session.add(start_pipeline_job) + self.session.flush() # to get start_pipeline_job.id + + job_factory = JobFactory(self.session) + for job_def in jobs: + job_run = job_factory.create_job_run( + job_def=job_def, + pipeline_id=pipeline.id, + correlation_id=correlation_id, + pipeline_params=pipeline_params, + ) + job_runs[job_def["key"]] = job_run + + self.session.flush() # to get job_run IDs + + for job_def in jobs: + job_deps = job_def["dependencies"] + + job_run = job_runs[job_def["key"]] + for dep_key, dependency_type in job_deps: + dep_job_run = job_runs[dep_key] + + dep_job = JobDependency( + id=job_run.id, + depends_on_job_id=dep_job_run.id, + dependency_type=dependency_type, + ) # type: ignore[call-arg] + + self.session.add(dep_job) + + self.session.commit() + return pipeline, start_pipeline_job diff --git a/src/mavedb/lib/workflow/py.typed b/src/mavedb/lib/workflow/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/mavedb/models/enums/job_pipeline.py b/src/mavedb/models/enums/job_pipeline.py index 0900b580..8a70eb3f 100644 --- a/src/mavedb/models/enums/job_pipeline.py +++ b/src/mavedb/models/enums/job_pipeline.py @@ -81,3 +81,13 @@ class AnnotationStatus(str, Enum): SUCCESS = "success" FAILED = "failed" SKIPPED = "skipped" + + +class JobType(str, Enum): + """Types of jobs in the pipeline.""" + + VARIANT_CREATION = "variant_creation" + VARIANT_MAPPING = "variant_mapping" + MAPPED_VARIANT_ANNOTATION = "mapped_variant_annotation" + PIPELINE_MANAGEMENT = "pipeline_management" + DATA_MANAGEMENT = "data_management" diff --git a/tests/lib/workflow/conftest.py b/tests/lib/workflow/conftest.py new file mode 100644 index 00000000..d88789a4 --- /dev/null +++ b/tests/lib/workflow/conftest.py @@ -0,0 +1,89 @@ +from unittest.mock import patch + +import pytest + +from mavedb.lib.workflow.job_factory import JobFactory +from mavedb.lib.workflow.pipeline_factory import PipelineFactory +from mavedb.models.enums.job_pipeline import DependencyType +from mavedb.models.user import User +from tests.helpers.constants import TEST_USER + + +@pytest.fixture +def job_factory(session): + """Fixture to provide a mocked JobFactory instance.""" + yield JobFactory(session) + + +@pytest.fixture +def pipeline_factory(session): + """Fixture to provide a mocked PipelineFactory instance.""" + yield PipelineFactory(session) + + +@pytest.fixture +def sample_job_definition(): + """Provides a sample job definition for testing.""" + return { + "key": "sample_job", + "type": "data_processing", + "function": "process_data", + "params": {"param1": "value1", "param2": "value2", "required_param": None}, + "dependencies": [], + } + + +@pytest.fixture +def sample_independent_pipeline_definition(sample_job_definition): + """Provides a sample pipeline definition for testing.""" + return { + "name": "sample_pipeline", + "description": "A sample pipeline for testing purposes.", + "job_definitions": [sample_job_definition], + } + + +@pytest.fixture +def sample_dependent_pipeline_definition(): + """Provides a sample pipeline definition with job dependencies for testing.""" + job_def_1 = { + "key": "job_1", + "type": "data_processing", + "function": "process_data_1", + "params": {"paramA": None}, + "dependencies": [], + } + job_def_2 = { + "key": "job_2", + "type": "data_processing", + "function": "process_data_2", + "params": {"paramB": None}, + "dependencies": [("job_1", DependencyType.SUCCESS_REQUIRED)], + } + return { + "name": "dependent_pipeline", + "description": "A sample pipeline with job dependencies for testing.", + "job_definitions": [job_def_1, job_def_2], + } + + +@pytest.fixture +def with_test_pipeline_definition_ctx(sample_dependent_pipeline_definition, sample_independent_pipeline_definition): + """Fixture to temporarily add a test pipeline definition.""" + test_pipeline_definitions = { + sample_dependent_pipeline_definition["name"]: sample_dependent_pipeline_definition, + sample_independent_pipeline_definition["name"]: sample_independent_pipeline_definition, + } + + with patch("mavedb.lib.workflow.pipeline_factory.PIPELINE_DEFINITIONS", test_pipeline_definitions): + yield + + +@pytest.fixture +def test_user(session): + """Fixture to create and provide a test user in the database.""" + db = session + user = User(**TEST_USER) + db.add(user) + db.commit() + yield user diff --git a/tests/lib/workflow/test_job_factory.py b/tests/lib/workflow/test_job_factory.py new file mode 100644 index 00000000..c34b6ca0 --- /dev/null +++ b/tests/lib/workflow/test_job_factory.py @@ -0,0 +1,191 @@ +from unittest.mock import patch + +import pytest + +from mavedb.models.pipeline import Pipeline + + +@pytest.mark.unit +class TestJobFactoryUnit: + """Unit tests for the JobFactory class.""" + + def test_create_job_run_persists_preset_params_from_definition(self, job_factory, sample_job_definition): + existing_params = {"param1": "new_value1", "param2": "new_value2", "required_param": "required_value"} + job_run = job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="test-correlation-id", + pipeline_params=existing_params, + pipeline_id=1, + ) + + assert job_run.job_params["param1"] == "value1" + assert job_run.job_params["param2"] == "value2" + + def test_create_job_run_raises_error_for_missing_params(self, job_factory, sample_job_definition): + incomplete_params = {"param1": "new_value1"} # Missing param2 + + with pytest.raises(ValueError) as exc_info: + job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="test-correlation-id", + pipeline_params=incomplete_params, + pipeline_id=1, + ) + + assert "Missing required param: required_param" in str(exc_info.value) + + def test_create_job_run_fills_in_required_params(self, job_factory, sample_job_definition): + pipeline_params = {"required_param": "required_value"} + job_run = job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="test-correlation-id", + pipeline_params=pipeline_params, + pipeline_id=1, + ) + + assert job_run.job_params["param1"] == "value1" + assert job_run.job_params["param2"] == "value2" + assert job_run.job_params["required_param"] == "required_value" + + def test_create_job_run_persists_correlation_id(self, job_factory, sample_job_definition): + job_run = job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="test-correlation-id", + pipeline_params={"param1": "value1", "param2": "value2", "required_param": "required_value"}, + pipeline_id=1, + ) + + assert job_run.correlation_id == "test-correlation-id" + + def test_create_job_run_persists_mavedb_version(self, job_factory, sample_job_definition): + with patch("mavedb.lib.workflow.job_factory.mavedb_version", "1.2.3"): + job_run = job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="test-correlation-id", + pipeline_params={"param1": "value1", "param2": "value2", "required_param": "required_value"}, + pipeline_id=1, + ) + + assert job_run.mavedb_version == "1.2.3" + + def test_create_job_run_persists_job_type_and_function(self, job_factory, sample_job_definition): + job_run = job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="test-correlation-id", + pipeline_params={"param1": "value1", "param2": "value2", "required_param": "required_value"}, + pipeline_id=1, + ) + + assert job_run.job_type == sample_job_definition["type"] + assert job_run.job_function == sample_job_definition["function"] + + def test_create_job_run_ignores_extra_pipeline_params(self, job_factory, sample_job_definition): + pipeline_params = { + "param1": "new_value1", + "param2": "new_value2", + "required_param": "required_value", + "extra_param": "should_be_ignored", + } + job_run = job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="test-correlation-id", + pipeline_params=pipeline_params, + pipeline_id=1, + ) + + assert "extra_param" not in job_run.job_params + + def test_create_job_run_with_no_pipeline_id(self, job_factory, sample_job_definition): + job_run = job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="test-correlation-id", + pipeline_params={"param1": "value1", "param2": "value2", "required_param": "required_value"}, + ) + + assert job_run.pipeline_id is None + + def test_create_job_run_associates_with_pipeline(self, job_factory, sample_job_definition): + job_run = job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="test-correlation-id", + pipeline_params={"param1": "value1", "param2": "value2", "required_param": "required_value"}, + pipeline_id=42, + ) + + assert job_run.pipeline_id == 42 + + def test_create_job_run_adds_to_session(self, job_factory, sample_job_definition): + job_run = job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="test-correlation-id", + pipeline_params={"param1": "value1", "param2": "value2", "required_param": "required_value"}, + pipeline_id=1, + ) + + assert job_run in job_factory.session.new + + +@pytest.mark.integration +class TestJobFactoryIntegration: + """Integration tests for the JobFactory class within pipeline execution.""" + + def test_create_job_run_independent(self, job_factory, sample_job_definition): + pipeline_params = {"required_param": "required_value"} + job_run = job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="integration-correlation-id", + pipeline_params=pipeline_params, + ) + job_factory.session.commit() + + retrieved_job_run = job_factory.session.get(type(job_run), job_run.id) + + assert retrieved_job_run is not None + assert retrieved_job_run.job_type == sample_job_definition["type"] + assert retrieved_job_run.job_function == sample_job_definition["function"] + assert retrieved_job_run.job_params["param1"] == "value1" + assert retrieved_job_run.job_params["param2"] == "value2" + assert retrieved_job_run.job_params["required_param"] == "required_value" + assert retrieved_job_run.correlation_id == "integration-correlation-id" + assert retrieved_job_run.pipeline_id is None + + def test_create_job_run_with_pipeline(self, job_factory, sample_job_definition): + pipeline = Pipeline( + name="Test Pipeline", + description="A pipeline for testing JobFactory integration.", + ) + job_factory.session.add(pipeline) + job_factory.session.flush() + + pipeline_params = {"required_param": "required_value"} + job_run = job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="integration-correlation-id", + pipeline_params=pipeline_params, + pipeline_id=pipeline.id, + ) + job_factory.session.commit() + + retrieved_job_run = job_factory.session.get(type(job_run), job_run.id) + + assert retrieved_job_run is not None + assert retrieved_job_run.job_type == sample_job_definition["type"] + assert retrieved_job_run.job_function == sample_job_definition["function"] + assert retrieved_job_run.job_params["param1"] == "value1" + assert retrieved_job_run.job_params["param2"] == "value2" + assert retrieved_job_run.job_params["required_param"] == "required_value" + assert retrieved_job_run.correlation_id == "integration-correlation-id" + assert retrieved_job_run.pipeline_id == pipeline.id + + def test_create_job_run_missing_params_raises_error(self, job_factory, sample_job_definition): + incomplete_params = {"param1": "new_value1"} # Missing required_param + + with pytest.raises(ValueError) as exc_info: + job_factory.create_job_run( + job_def=sample_job_definition, + correlation_id="integration-correlation-id", + pipeline_params=incomplete_params, + pipeline_id=100, + ) + + assert "Missing required param: required_param" in str(exc_info.value) diff --git a/tests/lib/workflow/test_pipeline_factory.py b/tests/lib/workflow/test_pipeline_factory.py new file mode 100644 index 00000000..e585666f --- /dev/null +++ b/tests/lib/workflow/test_pipeline_factory.py @@ -0,0 +1,238 @@ +import pytest +from sqlalchemy import select + +from mavedb.lib.workflow.pipeline_factory import PipelineFactory +from mavedb.models.job_run import JobRun + + +@pytest.mark.unit +class TestPipelineFactoryUnit: + """Unit tests for the PipelineFactory class.""" + + def test_create_pipeline_raises_if_pipeline_not_found(self, session, test_user): + """Test that creating a pipeline with an unknown name raises a KeyError.""" + pipeline_factory = PipelineFactory(session=session) + + with pytest.raises(KeyError) as exc_info: + pipeline_factory.create_pipeline( + pipeline_name="unknown_pipeline", + creating_user=test_user, + pipeline_params={}, + ) + + assert "unknown_pipeline" in str(exc_info.value) + + def test_create_pipeline_prioritizes_correlation_id_from_params( + self, + session, + with_test_pipeline_definition_ctx, + pipeline_factory, + sample_independent_pipeline_definition, + test_user, + ): + """Test that the correlation_id from pipeline_params is used when creating a pipeline.""" + pipeline_name = sample_independent_pipeline_definition["name"] + test_correlation_id = "test-correlation-id-123" + + pipeline, job_run = pipeline_factory.create_pipeline( + pipeline_name=pipeline_name, + creating_user=test_user, + pipeline_params={"correlation_id": test_correlation_id, "required_param": "some_value"}, + ) + + assert job_run.correlation_id == test_correlation_id + + def test_create_pipeline_creates_start_pipeline_job( + self, + session, + with_test_pipeline_definition_ctx, + pipeline_factory, + sample_independent_pipeline_definition, + test_user, + ): + """Test that creating a pipeline results in a JobRun of type 'start_pipeline'.""" + pipeline_name = sample_independent_pipeline_definition["name"] + + pipeline, job_run = pipeline_factory.create_pipeline( + pipeline_name=pipeline_name, + creating_user=test_user, + pipeline_params={"required_param": "some_value"}, + ) + + stmt = select(JobRun).where(JobRun.pipeline_id == pipeline.id) + job_runs = session.execute(stmt).scalars().all() + + start_pipeline_jobs = [jr for jr in job_runs if jr.job_function == "start_pipeline"] + assert len(start_pipeline_jobs) == 1 + assert start_pipeline_jobs[0].id == job_run.id + + def test_create_pipeline_creates_job_runs( + self, + session, + with_test_pipeline_definition_ctx, + pipeline_factory, + sample_independent_pipeline_definition, + test_user, + ): + """Test that creating a pipeline results in the correct number of JobRun instances.""" + pipeline_name = sample_independent_pipeline_definition["name"] + expected_job_count = len(sample_independent_pipeline_definition["job_definitions"]) + + pipeline, job_run = pipeline_factory.create_pipeline( + pipeline_name=pipeline_name, + creating_user=test_user, + pipeline_params={"required_param": "some_value"}, + ) + + stmt = select(JobRun).where(JobRun.pipeline_id == pipeline.id) + job_runs = session.execute(stmt).scalars().all() + + # One additional job run for the start_pipeline job + assert len(job_runs) == expected_job_count + 1 + + def test_create_pipeline_creates_job_dependencies( + self, + session, + with_test_pipeline_definition_ctx, + pipeline_factory, + sample_dependent_pipeline_definition, + test_user, + ): + """Test that creating a pipeline with job dependencies results in correct JobDependency records.""" + pipeline_name = sample_dependent_pipeline_definition["name"] + jobs = sample_dependent_pipeline_definition["job_definitions"] + + pipeline, job_run = pipeline_factory.create_pipeline( + pipeline_name=pipeline_name, + creating_user=test_user, + pipeline_params={"paramA": "valueA", "paramB": "valueB", "required_param": "some_value"}, + ) + + stmt = select(JobRun).where(JobRun.pipeline_id == pipeline.id) + job_runs = session.execute(stmt).scalars().all() + job_run_dict = {jr.job_function: jr for jr in job_runs} + + # Verify dependencies + for job_def in jobs: + job_deps = job_def["dependencies"] + job_run = job_run_dict[job_def["function"]] + + # For each dependency, check that a JobDependency record exists + # and verify its properties + for dep_key, dependency_type in job_deps: + dep_job_run = job_run_dict[[jd for jd in jobs if jd["key"] == dep_key][0]["function"]] + + assert len(job_run.job_dependencies) == 1 + for jd in job_run.job_dependencies: + assert jd.depends_on_job_id == dep_job_run.id + assert jd.dependency_type == dependency_type + + def test_create_pipeline_creates_pipeline( + self, + session, + with_test_pipeline_definition_ctx, + pipeline_factory, + sample_independent_pipeline_definition, + test_user, + ): + """Test that creating a pipeline results in a Pipeline record in the database.""" + pipeline_name = sample_independent_pipeline_definition["name"] + + pipeline, job_run = pipeline_factory.create_pipeline( + pipeline_name=pipeline_name, + creating_user=test_user, + pipeline_params={"required_param": "some_value"}, + ) + + stmt = select(pipeline.__class__).where(pipeline.__class__.id == pipeline.id) + retrieved_pipeline = session.execute(stmt).scalars().first() + + assert retrieved_pipeline is not None + assert retrieved_pipeline.id == pipeline.id + + +@pytest.mark.integration +class TestPipelineFactoryIntegration: + """Integration tests for the PipelineFactory class.""" + + def test_create_pipeline_independent( + self, + session, + with_test_pipeline_definition_ctx, + pipeline_factory, + sample_independent_pipeline_definition, + test_user, + ): + """Integration test for creating an independent pipeline.""" + pipeline_name = sample_independent_pipeline_definition["name"] + + pipeline, job_run = pipeline_factory.create_pipeline( + pipeline_name=pipeline_name, + creating_user=test_user, + pipeline_params={"required_param": "some_value"}, + ) + + assert pipeline.name == pipeline_name + assert job_run.job_function == "start_pipeline" + + for job_def in sample_independent_pipeline_definition["job_definitions"]: + stmt = select(JobRun).where( + JobRun.pipeline_id == pipeline.id, + JobRun.job_function == job_def["function"], + ) + job_run = session.execute(stmt).scalars().first() + assert job_run is not None + assert job_run.job_params["param1"] == "value1" + assert job_run.job_params["param2"] == "value2" + assert job_run.pipeline_id == pipeline.id + assert job_run.job_dependencies == [] + + def test_create_pipeline_dependent( + self, + session, + with_test_pipeline_definition_ctx, + pipeline_factory, + sample_dependent_pipeline_definition, + test_user, + ): + """Integration test for creating a dependent pipeline.""" + pipeline_name = sample_dependent_pipeline_definition["name"] + + passed_params = {"paramA": "valueA", "paramB": "valueB", "required_param": "some_value"} + pipeline, job_run = pipeline_factory.create_pipeline( + pipeline_name=pipeline_name, + creating_user=test_user, + pipeline_params=passed_params, + ) + + assert pipeline.name == pipeline_name + assert job_run.job_function == "start_pipeline" + + job_runs = {} + for job_def in sample_dependent_pipeline_definition["job_definitions"]: + stmt = select(JobRun).where( + JobRun.pipeline_id == pipeline.id, + JobRun.job_function == job_def["function"], + ) + jr = session.execute(stmt).scalars().first() + assert jr is not None + assert jr.pipeline_id == pipeline.id + for param_key, param_value in job_def["params"].items(): + if param_value is not None: + assert jr.job_params[param_key] == param_value + else: + assert jr.job_params[param_key] == passed_params[param_key] + + job_runs[job_def["key"]] = jr + + # Verify dependencies + for job_def in sample_dependent_pipeline_definition["job_definitions"]: + job_deps = job_def["dependencies"] + job_run = job_runs[job_def["key"]] + for dep_key, dependency_type in job_deps: + dep_job_run = job_runs[dep_key] + + assert len(job_run.job_dependencies) == 1 + for jd in job_run.job_dependencies: + assert jd.depends_on_job_id == dep_job_run.id + assert jd.dependency_type == dependency_type From 987b38ad2caf484f18f35a6432e458cf1f4bb06c Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Tue, 27 Jan 2026 23:31:49 -0800 Subject: [PATCH 36/70] feat: integrate PipelineFactory for variant creation and update processes --- src/mavedb/routers/score_sets.py | 48 ++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/src/mavedb/routers/score_sets.py b/src/mavedb/routers/score_sets.py index a20f5829..395baedf 100644 --- a/src/mavedb/routers/score_sets.py +++ b/src/mavedb/routers/score_sets.py @@ -68,6 +68,7 @@ generate_experiment_urn, generate_score_set_urn, ) +from mavedb.lib.workflow.pipeline_factory import PipelineFactory from mavedb.models.clinical_control import ClinicalControl from mavedb.models.contributor import Contributor from mavedb.models.enums.processing_state import ProcessingState @@ -113,6 +114,7 @@ async def enqueue_variant_creation( new_score_columns_metadata: Optional[dict[str, DatasetColumnMetadata]] = None, new_count_columns_metadata: Optional[dict[str, DatasetColumnMetadata]] = None, worker: ArqRedis, + db: Session, ) -> None: assert item.dataset_columns is not None @@ -169,25 +171,36 @@ async def enqueue_variant_creation( Key=counts_file_key, ) + pipeline_factory = PipelineFactory(session=db) + pipeline, pipeline_entrypoint = pipeline_factory.create_pipeline( + pipeline_name="validate_map_annotate_score_set", + creating_user=user_data.user, + pipeline_params={ + "correlation_id": correlation_id_for_context(), + "score_set_id": item.id, + "updater_id": user_data.user.id, + "scores_file_key": scores_file_key, + "counts_file_key": counts_file_key, + "score_columns_metadata": item.dataset_columns.get("score_columns_metadata") + if new_score_columns_metadata is None + else new_score_columns_metadata, + "count_columns_metadata": item.dataset_columns.get("count_columns_metadata") + if new_count_columns_metadata is None + else new_count_columns_metadata, + }, + ) + # Await the insertion of this job into the worker queue, not the job itself. # Uses provided score and counts dataframes and metadata files, or falls back to existing data on the score set if not provided. job = await worker.enqueue_job( - "create_variants_for_score_set", - correlation_id_for_context(), - item.id, - user_data.user.id, - scores_file_to_upload, - counts_file_to_upload, - item.dataset_columns.get("score_columns_metadata") - if new_score_columns_metadata is None - else new_score_columns_metadata, - item.dataset_columns.get("count_columns_metadata") - if new_count_columns_metadata is None - else new_count_columns_metadata, + pipeline_entrypoint.job_function, pipeline_entrypoint.id, _job_id=pipeline_entrypoint.urn ) if job is not None: save_to_logging_context({"worker_job_id": job.job_id}) - logger.info(msg="Enqueued variant creation job.", extra=logging_context()) + logger.info( + msg="Enqueued validate_map_annotate_score_set pipeline (job_id: {}).".format(job.job_id), + extra=logging_context(), + ) class ScoreSetUpdateResult(TypedDict): @@ -1780,6 +1793,7 @@ async def upload_score_set_variant_data( new_score_columns_metadata=dataset_column_metadata.get("score_columns_metadata", {}), new_count_columns_metadata=dataset_column_metadata.get("count_columns_metadata", {}), worker=worker, + db=db, ) db.add(item) @@ -1904,6 +1918,7 @@ async def update_score_set_with_variants( new_count_columns_metadata=dataset_column_metadata.get("count_columns_metadata") if did_count_columns_metadata_change else existing_count_columns_metadata, + db=db, ) db.add(updatedItem) @@ -1951,7 +1966,12 @@ async def update_score_set( updatedItem.processing_state = ProcessingState.processing logger.info(msg="Enqueuing variant creation job.", extra=logging_context()) - await enqueue_variant_creation(item=updatedItem, user_data=user_data, worker=worker) + await enqueue_variant_creation( + item=updatedItem, + user_data=user_data, + worker=worker, + db=db, + ) db.add(updatedItem) db.commit() From 3ca697ac35a1fadf7762dac4c7eeeab1586bcbd0 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 28 Jan 2026 11:05:05 -0800 Subject: [PATCH 37/70] feat: add context manager for database session management --- src/mavedb/db/session.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/mavedb/db/session.py b/src/mavedb/db/session.py index ab75604a..4fe2baa1 100644 --- a/src/mavedb/db/session.py +++ b/src/mavedb/db/session.py @@ -1,4 +1,5 @@ import os +from contextlib import contextmanager from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker @@ -15,8 +16,23 @@ engine = create_engine( # For PostgreSQL: - DB_URL + DB_URL, + pool_size=10, # For SQLite: # DB_URL, connect_args={"check_same_thread": False} ) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + +@contextmanager +def db_session(): + """Provide a transactional scope around a series of operations.""" + session = SessionLocal() + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() From c61bd41b09cefe365cfa54a7d13e4b4550ff3803 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 28 Jan 2026 11:32:12 -0800 Subject: [PATCH 38/70] feat: use session context manager in worker decorators rather than injecting in lifecycle hooks This contextmanager method ensures sessions are closed in a more consistent and guaranteed manner. --- .../worker/lib/decorators/job_guarantee.py | 25 ++++++++++--------- .../worker/lib/decorators/job_management.py | 15 +++++------ .../lib/decorators/pipeline_management.py | 15 +++++------ src/mavedb/worker/lib/decorators/utils.py | 15 +++++++++++ src/mavedb/worker/settings/lifecycle.py | 8 +----- 5 files changed, 41 insertions(+), 37 deletions(-) diff --git a/src/mavedb/worker/lib/decorators/job_guarantee.py b/src/mavedb/worker/lib/decorators/job_guarantee.py index 5dabf8ff..81dc62b5 100644 --- a/src/mavedb/worker/lib/decorators/job_guarantee.py +++ b/src/mavedb/worker/lib/decorators/job_guarantee.py @@ -31,7 +31,7 @@ async def my_cron_job(ctx, ...): from mavedb import __version__ from mavedb.models.enums.job_pipeline import JobStatus from mavedb.models.job_run import JobRun -from mavedb.worker.lib.decorators.utils import is_test_mode +from mavedb.worker.lib.decorators.utils import ensure_session_ctx, is_test_mode from mavedb.worker.lib.managers.types import JobResultData F = TypeVar("F", bound=Callable[..., Awaitable[Any]]) @@ -60,24 +60,25 @@ async def my_cron_job(ctx, ...): def decorator(func: F) -> F: @functools.wraps(func) async def async_wrapper(*args, **kwargs): - # No-op in test mode - if is_test_mode(): - return await func(*args, **kwargs) + with ensure_session_ctx(ctx=args[0]): + # No-op in test mode + if is_test_mode(): + return await func(*args, **kwargs) - # The job id must be passed as the second argument to the wrapped function. - job = _create_job_run(job_type, func, args, kwargs) - args = list(args) - args.insert(1, job.id) - args = tuple(args) + # The job id must be passed as the second argument to the wrapped function. + job = _create_job_run(job_type, func, args, kwargs) + args = list(args) + args.insert(1, job.id) + args = tuple(args) - return await func(*args, **kwargs) + return await func(*args, **kwargs) return async_wrapper # type: ignore return decorator -def _create_job_run(job_type: str, func: Callable[..., Awaitable[JobResultData]], args: tuple, kwargs: dict) -> None: +def _create_job_run(job_type: str, func: Callable[..., Awaitable[JobResultData]], args: tuple, kwargs: dict) -> JobRun: """ Creates and persists a JobRun record for a function before job execution. """ @@ -97,7 +98,7 @@ def _create_job_run(job_type: str, func: Callable[..., Awaitable[JobResultData]] job_function=func.__name__, status=JobStatus.PENDING, mavedb_version=__version__, - ) + ) # type: ignore[call-arg] db.add(job_run) db.commit() diff --git a/src/mavedb/worker/lib/decorators/job_management.py b/src/mavedb/worker/lib/decorators/job_management.py index 37120929..8822410e 100644 --- a/src/mavedb/worker/lib/decorators/job_management.py +++ b/src/mavedb/worker/lib/decorators/job_management.py @@ -13,7 +13,7 @@ from arq import ArqRedis from sqlalchemy.orm import Session -from mavedb.worker.lib.decorators.utils import is_test_mode +from mavedb.worker.lib.decorators.utils import ensure_session_ctx, is_test_mode from mavedb.worker.lib.managers import JobManager from mavedb.worker.lib.managers.types import JobResultData @@ -63,11 +63,12 @@ async def my_job_function(ctx, param1, param2, job_manager: JobManager): @functools.wraps(func) async def async_wrapper(*args, **kwargs): - # No-op in test mode - if is_test_mode(): - return await func(*args, **kwargs) + with ensure_session_ctx(ctx=args[0]): + # No-op in test mode + if is_test_mode(): + return await func(*args, **kwargs) - return await _execute_managed_job(func, args, kwargs) + return await _execute_managed_job(func, args, kwargs) return cast(F, async_wrapper) @@ -181,7 +182,3 @@ async def _execute_managed_job(func: Callable[..., Awaitable[JobResultData]], ar # We don't mind that we lose ARQs built in job marking, since we perform our own job # lifecycle management via with_job_management. return result - - -# Export decorator at module level for easy import -__all__ = ["with_job_management"] diff --git a/src/mavedb/worker/lib/decorators/pipeline_management.py b/src/mavedb/worker/lib/decorators/pipeline_management.py index d5ece4f6..3ba91020 100644 --- a/src/mavedb/worker/lib/decorators/pipeline_management.py +++ b/src/mavedb/worker/lib/decorators/pipeline_management.py @@ -17,7 +17,7 @@ from mavedb.models.enums.job_pipeline import PipelineStatus from mavedb.models.job_run import JobRun from mavedb.worker.lib.decorators import with_job_management -from mavedb.worker.lib.decorators.utils import is_test_mode +from mavedb.worker.lib.decorators.utils import ensure_session_ctx, is_test_mode from mavedb.worker.lib.managers import PipelineManager from mavedb.worker.lib.managers.types import JobResultData @@ -72,11 +72,12 @@ async def my_job_function(ctx, param1, param2): @functools.wraps(func) async def async_wrapper(*args, **kwargs): - # No-op in test mode - if is_test_mode(): - return await func(*args, **kwargs) + with ensure_session_ctx(ctx=args[0]): + # No-op in test mode + if is_test_mode(): + return await func(*args, **kwargs) - return await _execute_managed_pipeline(func, args, kwargs) + return await _execute_managed_pipeline(func, args, kwargs) return cast(F, async_wrapper) @@ -196,7 +197,3 @@ async def _execute_managed_pipeline(func: Callable[..., Awaitable[JobResultData] # We don't mind that we lose ARQs built in job marking, since we perform our own job # lifecycle management via with_job_management. return result - - -# Export decorator at module level for easy import -__all__ = ["with_pipeline_management"] diff --git a/src/mavedb/worker/lib/decorators/utils.py b/src/mavedb/worker/lib/decorators/utils.py index 373d72b3..7bfb1a4b 100644 --- a/src/mavedb/worker/lib/decorators/utils.py +++ b/src/mavedb/worker/lib/decorators/utils.py @@ -1,4 +1,7 @@ import os +from contextlib import contextmanager + +from mavedb.db.session import db_session def is_test_mode() -> bool: @@ -18,3 +21,15 @@ def is_test_mode() -> bool: # This pattern allows us to control decorator behavior in tests without # altering production code paths. return os.getenv("MAVEDB_TEST_MODE") == "1" + + +@contextmanager +def ensure_session_ctx(ctx): + if "db" in ctx and ctx["db"] is not None: + # No-op context manager + yield ctx["db"] + else: + with db_session() as session: + ctx["db"] = session + yield session + ctx["db"] = None # Optionally clean up diff --git a/src/mavedb/worker/settings/lifecycle.py b/src/mavedb/worker/settings/lifecycle.py index 7288c691..18e301f9 100644 --- a/src/mavedb/worker/settings/lifecycle.py +++ b/src/mavedb/worker/settings/lifecycle.py @@ -3,7 +3,6 @@ This module defines the startup, shutdown, and job lifecycle hooks for the ARQ worker. These hooks manage: - Process pool for CPU-intensive tasks -- Database session management per job - HGVS data provider setup - Job state initialization and cleanup """ @@ -11,7 +10,6 @@ from concurrent import futures from mavedb.data_providers.services import cdot_rest -from mavedb.db.session import SessionLocal async def startup(ctx): @@ -23,13 +21,9 @@ async def shutdown(ctx): async def on_job_start(ctx): - db = SessionLocal() - db.current_user_id = None - ctx["db"] = db ctx["hdp"] = cdot_rest() ctx["state"] = {} async def on_job_end(ctx): - db = ctx["db"] - db.close() + pass From 010f15cf75bf23df24d46c8680ec2429489deac6 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 28 Jan 2026 11:44:44 -0800 Subject: [PATCH 39/70] refactor: streamline context handling in job and pipeline decorators --- .../worker/lib/decorators/job_guarantee.py | 13 +- .../worker/lib/decorators/job_management.py | 23 +- .../lib/decorators/pipeline_management.py | 23 +- src/mavedb/worker/lib/decorators/utils.py | 18 ++ tests/conftest.py | 22 ++ tests/helpers/util/setup/worker.py | 4 +- tests/worker/conftest_optional.py | 3 +- tests/worker/jobs/conftest.py | 5 +- .../worker/jobs/data_management/test_views.py | 2 + .../external_services/network/test_clingen.py | 2 + .../external_services/network/test_uniprot.py | 2 + .../jobs/external_services/test_clingen.py | 66 ++--- .../jobs/external_services/test_gnomad.py | 31 ++- .../jobs/external_services/test_uniprot.py | 60 ++-- .../test_start_pipeline.py | 12 +- .../jobs/variant_processing/test_creation.py | 127 ++++----- .../jobs/variant_processing/test_mapping.py | 257 ++++++++---------- .../lib/decorators/test_job_guarantee.py | 18 +- .../lib/decorators/test_job_management.py | 42 +-- .../decorators/test_pipeline_management.py | 46 ++-- 20 files changed, 365 insertions(+), 411 deletions(-) diff --git a/src/mavedb/worker/lib/decorators/job_guarantee.py b/src/mavedb/worker/lib/decorators/job_guarantee.py index 81dc62b5..d93c08d6 100644 --- a/src/mavedb/worker/lib/decorators/job_guarantee.py +++ b/src/mavedb/worker/lib/decorators/job_guarantee.py @@ -31,7 +31,7 @@ async def my_cron_job(ctx, ...): from mavedb import __version__ from mavedb.models.enums.job_pipeline import JobStatus from mavedb.models.job_run import JobRun -from mavedb.worker.lib.decorators.utils import ensure_session_ctx, is_test_mode +from mavedb.worker.lib.decorators.utils import ensure_ctx, ensure_session_ctx, is_test_mode from mavedb.worker.lib.managers.types import JobResultData F = TypeVar("F", bound=Callable[..., Awaitable[Any]]) @@ -60,7 +60,7 @@ async def my_cron_job(ctx, ...): def decorator(func: F) -> F: @functools.wraps(func) async def async_wrapper(*args, **kwargs): - with ensure_session_ctx(ctx=args[0]): + with ensure_session_ctx(ctx=ensure_ctx(args)): # No-op in test mode if is_test_mode(): return await func(*args, **kwargs) @@ -83,14 +83,7 @@ def _create_job_run(job_type: str, func: Callable[..., Awaitable[JobResultData]] Creates and persists a JobRun record for a function before job execution. """ # Extract context (implicit first argument by ARQ convention) - if not args: - raise ValueError("Managed job functions must receive context as first argument") - ctx = args[0] - - # Get database session from context - if "db" not in ctx: - raise ValueError("DB session not found in job context") - + ctx = ensure_ctx(args) db: Session = ctx["db"] job_run = JobRun( diff --git a/src/mavedb/worker/lib/decorators/job_management.py b/src/mavedb/worker/lib/decorators/job_management.py index 8822410e..272c96bf 100644 --- a/src/mavedb/worker/lib/decorators/job_management.py +++ b/src/mavedb/worker/lib/decorators/job_management.py @@ -13,7 +13,7 @@ from arq import ArqRedis from sqlalchemy.orm import Session -from mavedb.worker.lib.decorators.utils import ensure_session_ctx, is_test_mode +from mavedb.worker.lib.decorators.utils import ensure_ctx, ensure_job_id, ensure_session_ctx, is_test_mode from mavedb.worker.lib.managers import JobManager from mavedb.worker.lib.managers.types import JobResultData @@ -63,7 +63,7 @@ async def my_job_function(ctx, param1, param2, job_manager: JobManager): @functools.wraps(func) async def async_wrapper(*args, **kwargs): - with ensure_session_ctx(ctx=args[0]): + with ensure_session_ctx(ctx=ensure_ctx(args)): # No-op in test mode if is_test_mode(): return await func(*args, **kwargs) @@ -96,23 +96,12 @@ async def _execute_managed_job(func: Callable[..., Awaitable[JobResultData]], ar Raises: Exception: Re-raises any exception after proper job failure tracking """ - # Extract context (implicit first argument by ARQ convention) - if not args: - raise ValueError("Managed job functions must receive context as first argument") - ctx = args[0] - - # Get database session and job ID from context - if "db" not in ctx: - raise ValueError("DB session not found in job context") + ctx = ensure_ctx(args) + db_session: Session = ctx["db"] + job_id = ensure_job_id(args) + if "redis" not in ctx: raise ValueError("Redis connection not found in job context") - - # Extract job_id (second argument by MaveDB convention) - if not args or len(args) < 2 or not isinstance(args[1], int): - raise ValueError("Job ID not found in pipeline context") - job_id = args[1] - - db_session: Session = ctx["db"] redis_pool: ArqRedis = ctx["redis"] try: diff --git a/src/mavedb/worker/lib/decorators/pipeline_management.py b/src/mavedb/worker/lib/decorators/pipeline_management.py index 3ba91020..b0659a90 100644 --- a/src/mavedb/worker/lib/decorators/pipeline_management.py +++ b/src/mavedb/worker/lib/decorators/pipeline_management.py @@ -17,7 +17,7 @@ from mavedb.models.enums.job_pipeline import PipelineStatus from mavedb.models.job_run import JobRun from mavedb.worker.lib.decorators import with_job_management -from mavedb.worker.lib.decorators.utils import ensure_session_ctx, is_test_mode +from mavedb.worker.lib.decorators.utils import ensure_ctx, ensure_job_id, ensure_session_ctx, is_test_mode from mavedb.worker.lib.managers import PipelineManager from mavedb.worker.lib.managers.types import JobResultData @@ -72,7 +72,7 @@ async def my_job_function(ctx, param1, param2): @functools.wraps(func) async def async_wrapper(*args, **kwargs): - with ensure_session_ctx(ctx=args[0]): + with ensure_session_ctx(ctx=ensure_ctx(args)): # No-op in test mode if is_test_mode(): return await func(*args, **kwargs) @@ -97,25 +97,14 @@ async def _execute_managed_pipeline(func: Callable[..., Awaitable[JobResultData] Raises: Exception: Propagates any exception raised during function execution. """ - # Extract context (first argument by ARQ convention) - if not args or len(args) < 1 or not isinstance(args[0], dict): - raise ValueError("Managed pipeline functions must receive context as first argument") - ctx = args[0] - - # Get database session and pipeline ID from context - if "db" not in ctx: - raise ValueError("DB session not found in pipeline context") + ctx = ensure_ctx(args) + job_id = ensure_job_id(args) + db_session: Session = ctx["db"] + if "redis" not in ctx: raise ValueError("Redis connection not found in pipeline context") - - db_session: Session = ctx["db"] redis_pool: ArqRedis = ctx["redis"] - # Extract job_id (second argument by MaveDB convention) - if not args or len(args) < 2 or not isinstance(args[1], int): - raise ValueError("Job ID not found in pipeline context") - job_id = args[1] - pipeline_manager = None pipeline_id = None try: diff --git a/src/mavedb/worker/lib/decorators/utils.py b/src/mavedb/worker/lib/decorators/utils.py index 7bfb1a4b..4315b6e0 100644 --- a/src/mavedb/worker/lib/decorators/utils.py +++ b/src/mavedb/worker/lib/decorators/utils.py @@ -33,3 +33,21 @@ def ensure_session_ctx(ctx): ctx["db"] = session yield session ctx["db"] = None # Optionally clean up + + +def ensure_ctx(args) -> dict: + # Extract context (first argument by ARQ convention) + if not args or len(args) < 1 or not isinstance(args[0], dict): + raise ValueError("Managed functions must receive context as first argument") + + ctx = args[0] + return ctx + + +def ensure_job_id(args) -> int: + # Extract job_id (second argument by MaveDB convention) + if not args or len(args) < 2 or not isinstance(args[1], int): + raise ValueError("Job ID not found in function arguments") + + job_id = args[1] + return job_id diff --git a/tests/conftest.py b/tests/conftest.py index f745fe20..dd6ee6bd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ import logging # noqa: F401 import os import sys +from contextlib import contextmanager from datetime import datetime from unittest import mock @@ -106,6 +107,27 @@ def session(postgresql): Base.metadata.drop_all(bind=engine) +@pytest.fixture +def db_session_fixture(session): + @contextmanager + def _db_session_cm(): + yield session + + return _db_session_cm + + +# ALL locations which use the db_session fixture need to be patched to use +# the test version. +@pytest.fixture +def patch_db_session_ctxmgr(db_session_fixture): + with ( + mock.patch("mavedb.db.session.db_session", db_session_fixture), + mock.patch("mavedb.worker.lib.decorators.utils.db_session", db_session_fixture), + # Add other modules that use db_session here as needed + ): + yield + + @pytest.fixture def athena_engine(): """Create and yield a SQLAlchemy engine connected to a mock Athena database.""" diff --git a/tests/helpers/util/setup/worker.py b/tests/helpers/util/setup/worker.py index dd4473bc..2723b90f 100644 --- a/tests/helpers/util/setup/worker.py +++ b/tests/helpers/util/setup/worker.py @@ -44,7 +44,7 @@ async def create_variants_in_score_set( result = await create_variants_for_score_set( mock_worker_ctx, variant_creation_run.id, - JobManager(mock_worker_ctx["db"], mock_worker_ctx["redis"], variant_creation_run.id), + JobManager(session, mock_worker_ctx["redis"], variant_creation_run.id), ) assert result["status"] == "ok" @@ -80,7 +80,7 @@ async def dummy_mapping_job(): result = await map_variants_for_score_set( mock_worker_ctx, variant_mapping_run.id, - JobManager(mock_worker_ctx["db"], mock_worker_ctx["redis"], variant_mapping_run.id), + JobManager(session, mock_worker_ctx["redis"], variant_mapping_run.id), ) assert result["status"] == "ok" diff --git a/tests/worker/conftest_optional.py b/tests/worker/conftest_optional.py index 9848fe51..f6da4b7c 100644 --- a/tests/worker/conftest_optional.py +++ b/tests/worker/conftest_optional.py @@ -47,7 +47,7 @@ def mock_pipeline_manager(mock_job_manager, mock_pipeline): @pytest.fixture -def mock_worker_ctx(session): +def mock_worker_ctx(): """Create a mock worker context dictionary for testing.""" mock_redis = Mock(spec=ArqRedis) mock_hdp = Mock(spec=RESTDataProvider) @@ -57,7 +57,6 @@ def mock_worker_ctx(session): # It's generally more pain than it's worth to mock out SQLAlchemy sessions, # although it can sometimes be useful when raising specific exceptions. return { - "db": session, "redis": mock_redis, "hdp": mock_hdp, "pool": mock_pool, diff --git a/tests/worker/jobs/conftest.py b/tests/worker/jobs/conftest.py index 7310d9d6..a98d27ae 100644 --- a/tests/worker/jobs/conftest.py +++ b/tests/worker/jobs/conftest.py @@ -218,9 +218,10 @@ def sample_link_gnomad_variants_run_pipeline( @pytest.fixture -def setup_sample_variants_with_caid(with_populated_domain_data, mock_worker_ctx, sample_link_gnomad_variants_run): +def setup_sample_variants_with_caid( + session, with_populated_domain_data, mock_worker_ctx, sample_link_gnomad_variants_run +): """Setup variants and mapped variants in the database for testing.""" - session = mock_worker_ctx["db"] score_set = session.get(ScoreSet, sample_link_gnomad_variants_run.job_params["score_set_id"]) # Add a variant and mapped variant to the database with a CAID diff --git a/tests/worker/jobs/data_management/test_views.py b/tests/worker/jobs/data_management/test_views.py index b9962163..2038eaf7 100644 --- a/tests/worker/jobs/data_management/test_views.py +++ b/tests/worker/jobs/data_management/test_views.py @@ -16,6 +16,8 @@ from mavedb.worker.jobs.data_management.views import refresh_materialized_views, refresh_published_variants_view from tests.helpers.transaction_spy import TransactionSpy +pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + ############################################################################################################################################ # refresh_materialized_views ############################################################################################################################################ diff --git a/tests/worker/jobs/external_services/network/test_clingen.py b/tests/worker/jobs/external_services/network/test_clingen.py index 95ce0135..1a401e8e 100644 --- a/tests/worker/jobs/external_services/network/test_clingen.py +++ b/tests/worker/jobs/external_services/network/test_clingen.py @@ -7,6 +7,8 @@ from mavedb.models.mapped_variant import MappedVariant from tests.helpers.util.setup.worker import create_mappings_in_score_set +pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + # TODO#XXX: Connect with ClinGen to resolve the invalid credentials issue on test site. @pytest.mark.skip(reason="invalid credentials, despite what is provided in documentation.") diff --git a/tests/worker/jobs/external_services/network/test_uniprot.py b/tests/worker/jobs/external_services/network/test_uniprot.py index 249a412c..288fb23b 100644 --- a/tests/worker/jobs/external_services/network/test_uniprot.py +++ b/tests/worker/jobs/external_services/network/test_uniprot.py @@ -3,6 +3,8 @@ from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus from tests.helpers.constants import TEST_REFSEQ_IDENTIFIER +pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + @pytest.mark.asyncio @pytest.mark.integration diff --git a/tests/worker/jobs/external_services/test_clingen.py b/tests/worker/jobs/external_services/test_clingen.py index 614e53e5..dff03917 100644 --- a/tests/worker/jobs/external_services/test_clingen.py +++ b/tests/worker/jobs/external_services/test_clingen.py @@ -16,6 +16,8 @@ from mavedb.worker.lib.managers.job_manager import JobManager from tests.helpers.util.setup.worker import create_mappings_in_score_set +pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + @pytest.mark.unit @pytest.mark.asyncio @@ -37,9 +39,7 @@ async def test_submit_score_set_mappings_to_car_submission_disabled( result = await submit_score_set_mappings_to_car( mock_worker_ctx, submit_score_set_mappings_to_car_sample_job_run.id, - JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id - ), + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id), ) mock_update_progress.assert_called_with(100, 100, "ClinGen submission is disabled. Skipping CAR submission.") @@ -65,9 +65,7 @@ async def test_submit_score_set_mappings_to_car_no_mappings( result = await submit_score_set_mappings_to_car( mock_worker_ctx, submit_score_set_mappings_to_car_sample_job_run.id, - JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id - ), + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id), ) mock_update_progress.assert_called_with(100, 100, "No mapped variants to submit to CAR. Skipped submission.") @@ -94,9 +92,7 @@ async def test_submit_score_set_mappings_to_car_submission_endpoint_not_set( await submit_score_set_mappings_to_car( mock_worker_ctx, submit_score_set_mappings_to_car_sample_job_run.id, - JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id - ), + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id), ) mock_update_progress.assert_called_with( @@ -144,9 +140,7 @@ async def test_submit_score_set_mappings_to_car_no_registered_alleles( result = await submit_score_set_mappings_to_car( mock_worker_ctx, submit_score_set_mappings_to_car_sample_job_run.id, - JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id - ), + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id), ) mock_update_progress.assert_called_with(100, 100, "Completed CAR mapped resource submission.") @@ -198,9 +192,7 @@ async def test_submit_score_set_mappings_to_car_no_linked_alleles( result = await submit_score_set_mappings_to_car( mock_worker_ctx, submit_score_set_mappings_to_car_sample_job_run.id, - JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id - ), + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id), ) mock_update_progress.assert_called_with(100, 100, "Completed CAR mapped resource submission.") @@ -261,9 +253,7 @@ async def test_submit_score_set_mappings_to_car_repeated_hgvs( result = await submit_score_set_mappings_to_car( mock_worker_ctx, submit_score_set_mappings_to_car_sample_job_run.id, - JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id - ), + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id), ) mock_update_progress.assert_called_with(100, 100, "Completed CAR mapped resource submission.") @@ -330,9 +320,7 @@ async def test_submit_score_set_mappings_to_car_hgvs_not_found( result = await submit_score_set_mappings_to_car( mock_worker_ctx, submit_score_set_mappings_to_car_sample_job_run.id, - JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id - ), + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id), ) mock_update_progress.assert_called_with(100, 100, "Completed CAR mapped resource submission.") @@ -379,9 +367,7 @@ async def test_submit_score_set_mappings_to_car_propagates_exception( await submit_score_set_mappings_to_car( mock_worker_ctx, submit_score_set_mappings_to_car_sample_job_run.id, - JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id - ), + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id), ) assert str(exc_info.value) == "ClinGen service error" @@ -439,9 +425,7 @@ async def test_submit_score_set_mappings_to_car_success( result = await submit_score_set_mappings_to_car( mock_worker_ctx, submit_score_set_mappings_to_car_sample_job_run.id, - JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id - ), + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id), ) mock_update_progress.assert_called_with(100, 100, "Completed CAR mapped resource submission.") @@ -506,9 +490,7 @@ async def test_submit_score_set_mappings_to_car_updates_progress( await submit_score_set_mappings_to_car( mock_worker_ctx, submit_score_set_mappings_to_car_sample_job_run.id, - JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id - ), + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id), ) mock_update_progress.assert_has_calls( @@ -1157,9 +1139,7 @@ async def test_submit_score_set_mappings_to_ldh_no_variants( result = await submit_score_set_mappings_to_ldh( mock_worker_ctx, submit_score_set_mappings_to_ldh_sample_job_run.id, - JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id - ), + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id), ) mock_update_progress.assert_called_with(100, 100, "No mapped variants to submit to LDH. Skipping submission.") @@ -1207,9 +1187,7 @@ async def dummy_submission_failure(*args, **kwargs): await submit_score_set_mappings_to_ldh( mock_worker_ctx, submit_score_set_mappings_to_ldh_sample_job_run.id, - JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id - ), + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id), ) mock_update_progress.assert_called_with(100, 100, "All mapped variant submissions to LDH failed.") @@ -1248,9 +1226,7 @@ async def test_submit_score_set_mappings_to_ldh_hgvs_not_found( result = await submit_score_set_mappings_to_ldh( mock_worker_ctx, submit_score_set_mappings_to_ldh_sample_job_run.id, - JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id - ), + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id), ) mock_update_progress.assert_called_with( @@ -1296,9 +1272,7 @@ async def test_submit_score_set_mappings_to_ldh_propagates_exception( await submit_score_set_mappings_to_ldh( mock_worker_ctx, submit_score_set_mappings_to_ldh_sample_job_run.id, - JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id - ), + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id), ) assert str(exc_info.value) == "LDH service error" @@ -1347,9 +1321,7 @@ async def dummy_partial_submission(*args, **kwargs): result = await submit_score_set_mappings_to_ldh( mock_worker_ctx, submit_score_set_mappings_to_ldh_sample_job_run.id, - JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id - ), + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id), ) assert result["status"] == "ok" @@ -1401,9 +1373,7 @@ async def dummy_successful_submission(*args, **kwargs): result = await submit_score_set_mappings_to_ldh( mock_worker_ctx, submit_score_set_mappings_to_ldh_sample_job_run.id, - JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id - ), + JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id), ) assert result["status"] == "ok" diff --git a/tests/worker/jobs/external_services/test_gnomad.py b/tests/worker/jobs/external_services/test_gnomad.py index 81b4e3ae..935c5fe8 100644 --- a/tests/worker/jobs/external_services/test_gnomad.py +++ b/tests/worker/jobs/external_services/test_gnomad.py @@ -10,6 +10,8 @@ from mavedb.worker.jobs.external_services.gnomad import link_gnomad_variants from mavedb.worker.lib.managers.job_manager import JobManager +pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + @pytest.mark.asyncio @pytest.mark.unit @@ -18,10 +20,9 @@ class TestLinkGnomadVariantsUnit: @pytest.fixture def setup_sample_variants_with_caid( - self, with_populated_domain_data, mock_worker_ctx, sample_link_gnomad_variants_run + self, session, with_populated_domain_data, mock_worker_ctx, sample_link_gnomad_variants_run ): """Setup variants and mapped variants in the database for testing.""" - session = mock_worker_ctx["db"] score_set = session.get(ScoreSet, sample_link_gnomad_variants_run.job_params["score_set_id"]) # Add a variant and mapped variant to the database with a CAID @@ -46,6 +47,7 @@ def setup_sample_variants_with_caid( async def test_link_gnomad_variants_no_variants_with_caids( self, + session, with_populated_domain_data, with_gnomad_linking_job, mock_worker_ctx, @@ -56,7 +58,7 @@ async def test_link_gnomad_variants_no_variants_with_caids( result = await link_gnomad_variants( mock_worker_ctx, 1, - JobManager(mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_link_gnomad_variants_run.id), + JobManager(session, mock_worker_ctx["redis"], sample_link_gnomad_variants_run.id), ) assert result["status"] == "ok" @@ -66,6 +68,7 @@ async def test_link_gnomad_variants_no_variants_with_caids( async def test_link_gnomad_variants_no_gnomad_matches( self, + session, with_populated_domain_data, with_gnomad_linking_job, mock_worker_ctx, @@ -84,7 +87,7 @@ async def test_link_gnomad_variants_no_gnomad_matches( result = await link_gnomad_variants( mock_worker_ctx, 1, - JobManager(mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_link_gnomad_variants_run.id), + JobManager(session, mock_worker_ctx["redis"], sample_link_gnomad_variants_run.id), ) assert result["status"] == "ok" @@ -92,6 +95,7 @@ async def test_link_gnomad_variants_no_gnomad_matches( async def test_link_gnomad_variants_call_linking_method( self, + session, with_populated_domain_data, with_gnomad_linking_job, mock_worker_ctx, @@ -114,7 +118,7 @@ async def test_link_gnomad_variants_call_linking_method( result = await link_gnomad_variants( mock_worker_ctx, 1, - JobManager(mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_link_gnomad_variants_run.id), + JobManager(session, mock_worker_ctx["redis"], sample_link_gnomad_variants_run.id), ) assert result["status"] == "ok" @@ -123,6 +127,7 @@ async def test_link_gnomad_variants_call_linking_method( async def test_link_gnomad_variants_updates_progress( self, + session, with_populated_domain_data, with_gnomad_linking_job, mock_worker_ctx, @@ -145,7 +150,7 @@ async def test_link_gnomad_variants_updates_progress( result = await link_gnomad_variants( mock_worker_ctx, 1, - JobManager(mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_link_gnomad_variants_run.id), + JobManager(session, mock_worker_ctx["redis"], sample_link_gnomad_variants_run.id), ) assert result["status"] == "ok" @@ -160,6 +165,7 @@ async def test_link_gnomad_variants_updates_progress( async def test_link_gnomad_variants_propagates_exceptions( self, + session, with_populated_domain_data, with_gnomad_linking_job, mock_worker_ctx, @@ -175,7 +181,7 @@ async def test_link_gnomad_variants_propagates_exceptions( await link_gnomad_variants( mock_worker_ctx, 1, - JobManager(mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_link_gnomad_variants_run.id), + JobManager(session, mock_worker_ctx["redis"], sample_link_gnomad_variants_run.id), ) assert str(exc_info.value) == "Test exception" @@ -188,6 +194,7 @@ class TestLinkGnomadVariantsIntegration: async def test_link_gnomad_variants_no_variants_with_caids( self, + session, with_populated_domain_data, with_gnomad_linking_job, mock_worker_ctx, @@ -199,7 +206,6 @@ async def test_link_gnomad_variants_no_variants_with_caids( assert result["status"] == "ok" # Verify that no gnomAD variants were linked - session = mock_worker_ctx["db"] gnomad_variants = session.query(GnomADVariant).all() assert len(gnomad_variants) == 0 @@ -209,6 +215,7 @@ async def test_link_gnomad_variants_no_variants_with_caids( async def test_link_gnomad_variants_no_matching_caids( self, + session, with_populated_domain_data, with_gnomad_linking_job, mock_worker_ctx, @@ -218,7 +225,6 @@ async def test_link_gnomad_variants_no_matching_caids( ): """Test the end-to-end functionality of the link_gnomad_variants job when no matching CAIDs are found.""" # Update the created mapped variant to have a CAID that won't match any gnomAD data - session = mock_worker_ctx["db"] mapped_variant = session.query(MappedVariant).first() mapped_variant.clingen_allele_id = "NON_MATCHING_CAID" session.commit() @@ -230,7 +236,6 @@ async def test_link_gnomad_variants_no_matching_caids( assert result["status"] == "ok" # Verify that no gnomAD variants were linked - session = mock_worker_ctx["db"] gnomad_variants = session.query(GnomADVariant).all() assert len(gnomad_variants) == 0 @@ -240,6 +245,7 @@ async def test_link_gnomad_variants_no_matching_caids( async def test_link_gnomad_variants_successful_linking_independent( self, + session, with_populated_domain_data, with_gnomad_linking_job, mock_worker_ctx, @@ -256,7 +262,6 @@ async def test_link_gnomad_variants_successful_linking_independent( assert result["status"] == "ok" # Verify that gnomAD variants were linked - session = mock_worker_ctx["db"] gnomad_variants = session.query(GnomADVariant).all() assert len(gnomad_variants) > 0 @@ -266,6 +271,7 @@ async def test_link_gnomad_variants_successful_linking_independent( async def test_link_gnomad_variants_successful_linking_pipeline( self, + session, with_populated_domain_data, mock_worker_ctx, sample_link_gnomad_variants_run_pipeline, @@ -282,7 +288,6 @@ async def test_link_gnomad_variants_successful_linking_pipeline( assert result["status"] == "ok" # Verify that gnomAD variants were linked - session = mock_worker_ctx["db"] gnomad_variants = session.query(GnomADVariant).all() assert len(gnomad_variants) > 0 @@ -296,6 +301,7 @@ async def test_link_gnomad_variants_successful_linking_pipeline( async def test_link_gnomad_variants_exceptions_handled_by_decorators( self, + session, with_populated_domain_data, with_gnomad_linking_job, mock_worker_ctx, @@ -322,7 +328,6 @@ async def test_link_gnomad_variants_exceptions_handled_by_decorators( assert "Test exception" in result["exception_details"]["message"] # Verify job status updates - session = mock_worker_ctx["db"] session.refresh(sample_link_gnomad_variants_run) assert sample_link_gnomad_variants_run.status == JobStatus.FAILED diff --git a/tests/worker/jobs/external_services/test_uniprot.py b/tests/worker/jobs/external_services/test_uniprot.py index fc0f9fa5..ea714664 100644 --- a/tests/worker/jobs/external_services/test_uniprot.py +++ b/tests/worker/jobs/external_services/test_uniprot.py @@ -23,6 +23,8 @@ VALID_UNIPROT_ACCESSION, ) +pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + @pytest.mark.unit @pytest.mark.asyncio @@ -42,7 +44,7 @@ async def test_submit_uniprot_mapping_jobs_no_targets( # Ensure the sample score set has no target genes sample_score_set.target_genes = [] - mock_worker_ctx["db"].commit() + session.commit() with ( patch.object(JobManager, "update_progress") as mock_update_progress, @@ -51,7 +53,7 @@ async def test_submit_uniprot_mapping_jobs_no_targets( mock_worker_ctx, 1, JobManager( - db=mock_worker_ctx["db"], + db=session, redis=mock_worker_ctx["redis"], job_id=sample_submit_uniprot_mapping_jobs_run.id, ), @@ -85,7 +87,7 @@ async def test_submit_uniprot_mapping_jobs_no_acs_in_post_mapped_metadata( mock_worker_ctx, 1, JobManager( - db=mock_worker_ctx["db"], + db=session, redis=mock_worker_ctx["redis"], job_id=sample_submit_uniprot_mapping_jobs_run.id, ), @@ -122,7 +124,7 @@ async def test_submit_uniprot_mapping_jobs_too_many_acs_in_post_mapped_metadata( mock_worker_ctx, 1, JobManager( - db=mock_worker_ctx["db"], + db=session, redis=mock_worker_ctx["redis"], job_id=sample_submit_uniprot_mapping_jobs_run.id, ), @@ -163,7 +165,7 @@ async def test_submit_uniprot_mapping_jobs_no_jobs_submitted( mock_worker_ctx, 1, JobManager( - db=mock_worker_ctx["db"], + db=session, redis=mock_worker_ctx["redis"], job_id=sample_submit_uniprot_mapping_jobs_run.id, ), @@ -207,7 +209,7 @@ async def test_submit_uniprot_mapping_jobs_api_failure_raises( mock_worker_ctx, 1, JobManager( - db=mock_worker_ctx["db"], + db=session, redis=mock_worker_ctx["redis"], job_id=sample_submit_uniprot_mapping_jobs_run.id, ), @@ -245,7 +247,7 @@ async def test_submit_uniprot_mapping_jobs_raises_dependent_job_not_available( mock_worker_ctx, 1, JobManager( - db=mock_worker_ctx["db"], + db=session, redis=mock_worker_ctx["redis"], job_id=sample_submit_uniprot_mapping_jobs_run.id, ), @@ -288,7 +290,7 @@ async def test_submit_uniprot_mapping_jobs_successful_submission( mock_worker_ctx, 1, JobManager( - db=mock_worker_ctx["db"], + db=session, redis=mock_worker_ctx["redis"], job_id=sample_submit_uniprot_mapping_jobs_run.id, ), @@ -326,8 +328,8 @@ async def test_submit_uniprot_mapping_jobs_partial_submission( category="protein_coding", target_sequence=TargetSequence(sequence="MEEPQSDPSV", sequence_type="protein"), ) - mock_worker_ctx["db"].add(new_target_gene) - mock_worker_ctx["db"].commit() + session.add(new_target_gene) + session.commit() # Arrange the post mapped metadata to have a single AC for both target genes target_gene_1 = sample_score_set.target_genes[0] @@ -347,7 +349,7 @@ async def test_submit_uniprot_mapping_jobs_partial_submission( mock_worker_ctx, 1, JobManager( - db=mock_worker_ctx["db"], + db=session, redis=mock_worker_ctx["redis"], job_id=sample_submit_uniprot_mapping_jobs_run.id, ), @@ -396,7 +398,7 @@ async def test_submit_uniprot_mapping_jobs_updates_progress( mock_worker_ctx, 1, JobManager( - db=mock_worker_ctx["db"], + db=session, redis=mock_worker_ctx["redis"], job_id=sample_submit_uniprot_mapping_jobs_run.id, ), @@ -542,7 +544,7 @@ async def test_submit_uniprot_mapping_jobs_no_targets( # Ensure the sample score set has no target genes sample_score_set.target_genes = [] - mock_worker_ctx["db"].commit() + session.commit() with ( patch( @@ -750,13 +752,13 @@ async def test_submit_uniprot_mapping_jobs_partial_submission( category="protein_coding", target_sequence=TargetSequence(sequence="MEEPQSDPSV", sequence_type="protein"), ) - mock_worker_ctx["db"].add(new_target_gene) - mock_worker_ctx["db"].commit() + session.add(new_target_gene) + session.commit() # Add accessions to both target genes' post mapped metadata for idx, tg in enumerate(sample_score_set.target_genes): tg.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION + f"{idx:05d}"]}} - mock_worker_ctx["db"].commit() + session.commit() with ( patch( @@ -1053,7 +1055,7 @@ async def test_poll_uniprot_mapping_jobs_no_mapping_jobs( mock_worker_ctx, 1, JobManager( - db=mock_worker_ctx["db"], + db=session, redis=mock_worker_ctx["redis"], job_id=sample_polling_job_for_submission_run.id, ), @@ -1095,7 +1097,7 @@ async def test_poll_uniprot_mapping_jobs_results_not_ready( mock_worker_ctx, 1, JobManager( - db=mock_worker_ctx["db"], + db=session, redis=mock_worker_ctx["redis"], job_id=sample_polling_job_for_submission_run.id, ), @@ -1141,7 +1143,7 @@ async def test_poll_uniprot_mapping_jobs_no_results( mock_worker_ctx, 1, JobManager( - db=mock_worker_ctx["db"], + db=session, redis=mock_worker_ctx["redis"], job_id=sample_polling_job_for_submission_run.id, ), @@ -1199,7 +1201,7 @@ async def test_poll_uniprot_mapping_jobs_ambiguous_results( mock_worker_ctx, 1, JobManager( - db=mock_worker_ctx["db"], + db=session, redis=mock_worker_ctx["redis"], job_id=sample_polling_job_for_submission_run.id, ), @@ -1242,7 +1244,7 @@ async def test_poll_uniprot_mapping_jobs_nonexistent_target( mock_worker_ctx, 1, JobManager( - db=mock_worker_ctx["db"], + db=session, redis=mock_worker_ctx["redis"], job_id=sample_polling_job_for_submission_run.id, ), @@ -1284,7 +1286,7 @@ async def test_poll_uniprot_mapping_jobs_successful_update( mock_worker_ctx, 1, JobManager( - db=mock_worker_ctx["db"], + db=session, redis=mock_worker_ctx["redis"], job_id=sample_polling_job_for_submission_run.id, ), @@ -1322,8 +1324,8 @@ async def test_poll_uniprot_mapping_jobs_partial_success( category="protein_coding", target_sequence=TargetSequence(sequence="MEEPQSDPSV", sequence_type="protein"), ) - mock_worker_ctx["db"].add(new_target_gene) - mock_worker_ctx["db"].commit() + session.add(new_target_gene) + session.commit() with ( patch( @@ -1343,7 +1345,7 @@ async def test_poll_uniprot_mapping_jobs_partial_success( mock_worker_ctx, 1, JobManager( - db=mock_worker_ctx["db"], + db=session, redis=mock_worker_ctx["redis"], job_id=sample_polling_job_for_submission_run.id, ), @@ -1390,7 +1392,7 @@ async def test_poll_uniprot_mapping_jobs_updates_progress( mock_worker_ctx, 1, JobManager( - db=mock_worker_ctx["db"], + db=session, redis=mock_worker_ctx["redis"], job_id=sample_polling_job_for_submission_run.id, ), @@ -1437,7 +1439,7 @@ async def test_poll_uniprot_mapping_jobs_propagates_exceptions( mock_worker_ctx, 1, JobManager( - db=mock_worker_ctx["db"], + db=session, redis=mock_worker_ctx["redis"], job_id=sample_polling_job_for_submission_run.id, ), @@ -1595,8 +1597,8 @@ async def test_poll_uniprot_mapping_jobs_partial_mapping_jobs( category="protein_coding", target_sequence=TargetSequence(sequence="MEEPQSDPSV", sequence_type="protein"), ) - mock_worker_ctx["db"].add(new_target_gene) - mock_worker_ctx["db"].commit() + session.add(new_target_gene) + session.commit() with ( patch( diff --git a/tests/worker/jobs/pipeline_management/test_start_pipeline.py b/tests/worker/jobs/pipeline_management/test_start_pipeline.py index 12eb9675..9f70d9f1 100644 --- a/tests/worker/jobs/pipeline_management/test_start_pipeline.py +++ b/tests/worker/jobs/pipeline_management/test_start_pipeline.py @@ -9,6 +9,8 @@ from mavedb.worker.lib.managers.job_manager import JobManager from mavedb.worker.lib.managers.pipeline_manager import PipelineManager +pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + @pytest.mark.unit @pytest.mark.asyncio @@ -44,7 +46,7 @@ async def test_start_pipeline_raises_exception_when_no_pipeline_associated_with_ await start_pipeline( mock_worker_ctx, setup_start_pipeline_job_run.id, - JobManager(mock_worker_ctx["db"], mock_worker_ctx["redis"], setup_start_pipeline_job_run.id), + JobManager(session, mock_worker_ctx["redis"], setup_start_pipeline_job_run.id), ) async def test_start_pipeline_starts_pipeline_successfully( @@ -65,7 +67,7 @@ async def test_start_pipeline_starts_pipeline_successfully( result = await start_pipeline( mock_worker_ctx, setup_start_pipeline_job_run.id, - JobManager(mock_worker_ctx["db"], mock_worker_ctx["redis"], setup_start_pipeline_job_run.id), + JobManager(session, mock_worker_ctx["redis"], setup_start_pipeline_job_run.id), ) assert result["status"] == "ok" @@ -94,7 +96,7 @@ async def test_start_pipeline_updates_progress( result = await start_pipeline( mock_worker_ctx, setup_start_pipeline_job_run.id, - JobManager(mock_worker_ctx["db"], mock_worker_ctx["redis"], setup_start_pipeline_job_run.id), + JobManager(session, mock_worker_ctx["redis"], setup_start_pipeline_job_run.id), ) assert result["status"] == "ok" @@ -129,7 +131,7 @@ async def test_start_pipeline_raises_exception( await start_pipeline( mock_worker_ctx, setup_start_pipeline_job_run.id, - JobManager(mock_worker_ctx["db"], mock_worker_ctx["redis"], setup_start_pipeline_job_run.id), + JobManager(session, mock_worker_ctx["redis"], setup_start_pipeline_job_run.id), ) @@ -194,7 +196,7 @@ async def custom_side_effect(*args, **kwargs): call_count["n"] += 1 raise Exception("Simulated pipeline start failure") return await real_coordinate_pipeline( - PipelineManager(session, mock_worker_ctx["db"], sample_dummy_pipeline.id), *args, **kwargs + PipelineManager(session, session, sample_dummy_pipeline.id), *args, **kwargs ) # Allow the final coordination attempt to proceed 'normally' with patch( diff --git a/tests/worker/jobs/variant_processing/test_creation.py b/tests/worker/jobs/variant_processing/test_creation.py index a034ebeb..6f94ae58 100644 --- a/tests/worker/jobs/variant_processing/test_creation.py +++ b/tests/worker/jobs/variant_processing/test_creation.py @@ -12,25 +12,31 @@ from mavedb.worker.jobs.variant_processing.creation import create_variants_for_score_set from mavedb.worker.lib.managers.job_manager import JobManager +pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + @pytest.mark.unit @pytest.mark.asyncio +@pytest.mark.usefixtures("patch_db_session_ctxmgr") class TestCreateVariantsForScoreSetUnit: """Unit tests for create_variants_for_score_set job.""" async def test_create_variants_for_score_set_raises_key_error_on_missing_hdp_from_ctx( self, + mock_worker_ctx, mock_job_manager, ): - ctx = {} # Missing 'hdp' key + ctx = mock_worker_ctx.copy() + del ctx["hdp"] with pytest.raises(KeyError) as exc_info: - await create_variants_for_score_set(ctx=ctx, job_id=999, job_manager=mock_job_manager) + await create_variants_for_score_set(ctx, 999, mock_job_manager) assert str(exc_info.value) == "'hdp'" async def test_create_variants_for_score_set_calls_s3_client_with_correct_parameters( self, + session, with_independent_processing_runs, with_populated_domain_data, mock_worker_ctx, @@ -64,11 +70,9 @@ async def test_create_variants_for_score_set_calls_s3_client_with_correct_parame patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), ): await create_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_creation_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id - ), + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), ) # Use ANY for dynamically created Fileobj parameters. @@ -99,11 +103,9 @@ async def test_create_variants_for_score_set_s3_file_not_found( pytest.raises(Exception) as exc_info, ): await create_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_creation_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id - ), + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), ) mock_update_progress.assert_any_call(100, 100, "Variant creation job failed due to an internal error.") @@ -155,11 +157,9 @@ async def test_create_variants_for_score_set_counts_file_can_be_optional( patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), ): await create_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_creation_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id - ), + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), ) async def test_create_variants_for_score_set_raises_when_no_targets_exist( @@ -189,11 +189,9 @@ async def test_create_variants_for_score_set_raises_when_no_targets_exist( pytest.raises(ValueError) as exc_info, ): await create_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_creation_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id - ), + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), ) mock_update_progress.assert_any_call(100, 100, "Score set has no targets; cannot create variants.") @@ -201,6 +199,7 @@ async def test_create_variants_for_score_set_raises_when_no_targets_exist( async def test_create_variants_for_score_set_calls_validate_standardize_dataframe_with_correct_parameters( self, + session, with_independent_processing_runs, with_populated_domain_data, mock_worker_ctx, @@ -234,11 +233,9 @@ async def test_create_variants_for_score_set_calls_validate_standardize_datafram patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), ): await create_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_creation_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id - ), + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), ) mock_validate.assert_called_once_with( @@ -252,6 +249,7 @@ async def test_create_variants_for_score_set_calls_validate_standardize_datafram async def test_create_variants_for_score_set_calls_create_variants_data_with_correct_parameters( self, + session, with_independent_processing_runs, with_populated_domain_data, mock_worker_ctx, @@ -285,17 +283,16 @@ async def test_create_variants_for_score_set_calls_create_variants_data_with_cor patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), ): await create_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_creation_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id - ), + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), ) mock_create_variants_data.assert_called_once_with(sample_score_dataframe, sample_count_dataframe, None) async def test_create_variants_for_score_set_calls_create_variants_with_correct_parameters( self, + session, with_independent_processing_runs, with_populated_domain_data, mock_worker_ctx, @@ -333,17 +330,16 @@ async def test_create_variants_for_score_set_calls_create_variants_with_correct_ ) as mock_create_variants, ): await create_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_creation_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id - ), + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), ) - mock_create_variants.assert_called_once_with(mock_worker_ctx["db"], sample_score_set, [mock_variant]) + mock_create_variants.assert_called_once_with(session, sample_score_set, [mock_variant]) async def test_create_variants_for_score_set_handles_empty_variant_data( self, + session, with_independent_processing_runs, with_populated_domain_data, mock_worker_ctx, @@ -374,11 +370,9 @@ async def test_create_variants_for_score_set_handles_empty_variant_data( patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), ): await create_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_creation_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id - ), + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), ) # If no exceptions are raised, the test passes for handling empty variant data. @@ -424,11 +418,9 @@ async def test_create_variants_for_score_set_removes_existing_variants_before_cr patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), ): await create_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_creation_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id - ), + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), ) # Verify that existing variants have been removed @@ -473,11 +465,9 @@ async def test_create_variants_for_score_set_updates_processing_state( patch("mavedb.worker.jobs.variant_processing.creation.create_variants", return_value=None), ): await create_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_creation_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id - ), + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), ) session.refresh(sample_score_set) @@ -487,6 +477,7 @@ async def test_create_variants_for_score_set_updates_processing_state( async def test_create_variants_for_score_set_updates_progress( self, + session, with_independent_processing_runs, with_populated_domain_data, mock_worker_ctx, @@ -521,11 +512,9 @@ async def test_create_variants_for_score_set_updates_progress( patch.object(JobManager, "update_progress") as mock_update_progress, ): await create_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_creation_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id - ), + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), ) mock_update_progress.assert_has_calls( @@ -570,11 +559,9 @@ async def test_create_variants_for_score_set_retains_existing_variants_when_exce pytest.raises(Exception) as exc_info, ): await create_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_creation_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id - ), + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), ) assert str(exc_info.value) == "Test exception during data validation" @@ -613,11 +600,9 @@ async def test_create_variants_for_score_set_handles_exception_and_updates_state pytest.raises(Exception) as exc_info, ): await create_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_creation_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_creation_run.id - ), + mock_worker_ctx, + sample_independent_variant_creation_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), ) assert str(exc_info.value) == "Test exception during data validation" @@ -1239,11 +1224,7 @@ async def test_create_variants_for_score_set_with_arq_context_pipeline_ctx( side_effect=[sample_score_dataframe, sample_count_dataframe], ), ): - await arq_redis.enqueue_job( - "create_variants_for_score_set", - sample_pipeline_variant_creation_run.id, - _job_id=sample_pipeline_variant_creation_run.urn, - ) + await arq_redis.enqueue_job("create_variants_for_score_set", sample_pipeline_variant_creation_run.id) await arq_worker.async_run() await arq_worker.run_check() diff --git a/tests/worker/jobs/variant_processing/test_mapping.py b/tests/worker/jobs/variant_processing/test_mapping.py index 74a1c050..fa0c3dc8 100644 --- a/tests/worker/jobs/variant_processing/test_mapping.py +++ b/tests/worker/jobs/variant_processing/test_mapping.py @@ -19,6 +19,8 @@ from tests.helpers.constants import TEST_CODING_LAYER, TEST_GENOMIC_LAYER, TEST_PROTEIN_LAYER from tests.helpers.util.setup.worker import construct_mock_mapping_output, create_variants_in_score_set +pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + @pytest.mark.unit @pytest.mark.asyncio @@ -30,6 +32,7 @@ async def dummy_mapping_output(self, output_data={}): async def test_map_variants_for_score_set_no_mapping_results( self, + session, with_independent_processing_runs, mock_worker_ctx, sample_independent_variant_mapping_run, @@ -45,11 +48,9 @@ async def test_map_variants_for_score_set_no_mapping_results( pytest.raises(NonexistentMappingResultsError), ): await map_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_mapping_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id - ), + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), ) mock_update_progress.assert_any_call(100, 100, "Variant mapping failed due to missing results.") @@ -63,6 +64,7 @@ async def test_map_variants_for_score_set_no_mapping_results( async def test_map_variants_for_score_set_no_mapped_scores( self, + session, with_independent_processing_runs, mock_worker_ctx, sample_independent_variant_mapping_run, @@ -84,11 +86,9 @@ async def test_map_variants_for_score_set_no_mapped_scores( pytest.raises(NonexistentMappingScoresError), ): await map_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_mapping_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id - ), + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), ) mock_update_progress.assert_any_call(100, 100, "Variant mapping failed; no variants were mapped.") @@ -99,6 +99,7 @@ async def test_map_variants_for_score_set_no_mapped_scores( async def test_map_variants_for_score_set_no_reference_data( self, + session, with_independent_processing_runs, mock_worker_ctx, sample_independent_variant_mapping_run, @@ -120,11 +121,9 @@ async def test_map_variants_for_score_set_no_reference_data( pytest.raises(NonexistentMappingReferenceError), ): await map_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_mapping_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id - ), + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), ) mock_update_progress.assert_any_call(100, 100, "Variant mapping failed due to missing reference metadata.") @@ -135,6 +134,7 @@ async def test_map_variants_for_score_set_no_reference_data( async def test_map_variants_for_score_set_nonexistent_target_gene( self, + session, with_independent_processing_runs, mock_worker_ctx, sample_independent_variant_mapping_run, @@ -159,11 +159,9 @@ async def test_map_variants_for_score_set_nonexistent_target_gene( pytest.raises(ValueError), ): await map_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_mapping_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id - ), + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), ) mock_update_progress.assert_any_call(100, 100, "Variant mapping failed due to an unexpected error.") @@ -177,6 +175,7 @@ async def test_map_variants_for_score_set_nonexistent_target_gene( async def test_map_variants_for_score_set_returns_variants_not_in_score_set( self, + session, with_independent_processing_runs, mock_worker_ctx, sample_independent_variant_mapping_run, @@ -185,7 +184,7 @@ async def test_map_variants_for_score_set_returns_variants_not_in_score_set( """Test mapping variants when variants not in score set are returned.""" # Add a non-existent variant to the mapped output to ensure at least one invalid mapping mapping_output = await construct_mock_mapping_output( - session=mock_worker_ctx["db"], score_set=sample_score_set, with_layers={"g", "c", "p"} + session=session, score_set=sample_score_set, with_layers={"g", "c", "p"} ) mapping_output["mapped_scores"].append({"variant_id": "not_in_score_set", "some_other_data": "value"}) @@ -201,11 +200,9 @@ async def test_map_variants_for_score_set_returns_variants_not_in_score_set( pytest.raises(NoResultFound), ): await map_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_mapping_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id - ), + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), ) mock_update_progress.assert_any_call(100, 100, "Variant mapping failed due to an unexpected error.") @@ -219,6 +216,7 @@ async def test_map_variants_for_score_set_returns_variants_not_in_score_set( async def test_map_variants_for_score_set_success_missing_gene_info( self, + session, with_independent_processing_runs, mock_worker_ctx, sample_independent_variant_mapping_run, @@ -230,7 +228,7 @@ async def test_map_variants_for_score_set_success_missing_gene_info( # with return value from run_in_executor. async def dummy_mapping_job(): return await construct_mock_mapping_output( - session=mock_worker_ctx["db"], + session=session, score_set=sample_score_set, with_gene_info=False, with_layers={"g", "c", "p"}, @@ -245,8 +243,8 @@ async def dummy_mapping_job(): variant = Variant( score_set_id=sample_score_set.id, hgvs_nt="NM_000000.1:c.1A>G", hgvs_pro="NP_000000.1:p.Met1Val", data={} ) - mock_worker_ctx["db"].add(variant) - mock_worker_ctx["db"].commit() + session.add(variant) + session.commit() with ( patch.object( @@ -256,11 +254,9 @@ async def dummy_mapping_job(): ), ): result = await map_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_mapping_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id - ), + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), ) assert result["status"] == "ok" @@ -275,7 +271,7 @@ async def dummy_mapping_job(): assert target.mapped_hgnc_name is None # Verify that a mapped variant was created - mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 1 @pytest.mark.parametrize( @@ -292,6 +288,7 @@ async def dummy_mapping_job(): ) async def test_map_variants_for_score_set_success_layer_permutations( self, + session, with_independent_processing_runs, mock_worker_ctx, sample_independent_variant_mapping_run, @@ -304,7 +301,7 @@ async def test_map_variants_for_score_set_success_layer_permutations( # with return value from run_in_executor. async def dummy_mapping_job(): return await construct_mock_mapping_output( - session=mock_worker_ctx["db"], + session=session, score_set=sample_score_set, with_gene_info=True, with_layers=with_layers, @@ -319,8 +316,8 @@ async def dummy_mapping_job(): variant = Variant( score_set_id=sample_score_set.id, hgvs_nt="NM_000000.1:c.1A>G", hgvs_pro="NP_000000.1:p.Met1Val", data={} ) - mock_worker_ctx["db"].add(variant) - mock_worker_ctx["db"].commit() + session.add(variant) + session.commit() with ( patch.object( @@ -330,11 +327,9 @@ async def dummy_mapping_job(): ), ): result = await map_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_mapping_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id - ), + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), ) assert result["status"] == "ok" @@ -383,11 +378,12 @@ async def dummy_mapping_job(): assert target.post_mapped_metadata.get("protein") is None # Verify that a mapped variant was created - mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 1 async def test_map_variants_for_score_set_success_no_successful_mapping( self, + session, with_independent_processing_runs, mock_worker_ctx, sample_independent_variant_mapping_run, @@ -399,7 +395,7 @@ async def test_map_variants_for_score_set_success_no_successful_mapping( # with return value from run_in_executor. async def dummy_mapping_job(): return await construct_mock_mapping_output( - session=mock_worker_ctx["db"], + session=session, score_set=sample_score_set, with_gene_info=True, with_layers={"g", "c", "p"}, @@ -414,8 +410,8 @@ async def dummy_mapping_job(): variant = Variant( score_set_id=sample_score_set.id, hgvs_nt="NM_000000.1:c.1A>G", hgvs_pro="NP_000000.1:p.Met1Val", data={} ) - mock_worker_ctx["db"].add(variant) - mock_worker_ctx["db"].commit() + session.add(variant) + session.commit() with ( patch.object( @@ -425,11 +421,9 @@ async def dummy_mapping_job(): ), ): result = await map_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_mapping_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id - ), + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), ) assert result["status"] == "error" @@ -440,7 +434,7 @@ async def dummy_mapping_job(): assert sample_score_set.mapping_errors["error_message"] == "All variants failed to map." # Verify that one mapped variant was created. Although no successful mapping, an entry is still created. - mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 1 # Verify that the mapped variant has no post-mapped data @@ -449,6 +443,7 @@ async def dummy_mapping_job(): async def test_map_variants_for_score_set_incomplete_mapping( self, + session, with_independent_processing_runs, mock_worker_ctx, sample_independent_variant_mapping_run, @@ -460,7 +455,7 @@ async def test_map_variants_for_score_set_incomplete_mapping( # with return value from run_in_executor. async def dummy_mapping_job(): return await construct_mock_mapping_output( - session=mock_worker_ctx["db"], + session=session, score_set=sample_score_set, with_gene_info=True, with_layers={"g", "c", "p"}, @@ -486,8 +481,8 @@ async def dummy_mapping_job(): data={}, urn="variant:2", ) - mock_worker_ctx["db"].add_all([variant1, variant2]) - mock_worker_ctx["db"].commit() + session.add_all([variant1, variant2]) + session.commit() with ( patch.object( @@ -497,11 +492,9 @@ async def dummy_mapping_job(): ), ): result = await map_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_mapping_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id - ), + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), ) assert result["status"] == "ok" @@ -513,22 +506,23 @@ async def dummy_mapping_job(): # Although only one variant was successfully mapped, verify that an entity was created # for each variant in the score set - mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 2 # Verify that only one variant has post-mapped data mapped_variant_with_post_data = ( - mock_worker_ctx["db"].query(MappedVariant).filter(MappedVariant.post_mapped != {}).one_or_none() + session.query(MappedVariant).filter(MappedVariant.post_mapped != {}).one_or_none() ) assert mapped_variant_with_post_data is not None mapped_variant_without_post_data = ( - mock_worker_ctx["db"].query(MappedVariant).filter(MappedVariant.post_mapped == {}).one_or_none() + session.query(MappedVariant).filter(MappedVariant.post_mapped == {}).one_or_none() ) assert mapped_variant_without_post_data is not None async def test_map_variants_for_score_set_complete_mapping( self, + session, with_independent_processing_runs, mock_worker_ctx, sample_independent_variant_mapping_run, @@ -540,7 +534,7 @@ async def test_map_variants_for_score_set_complete_mapping( # with return value from run_in_executor. async def dummy_mapping_job(): return await construct_mock_mapping_output( - session=mock_worker_ctx["db"], + session=session, score_set=sample_score_set, with_gene_info=True, with_layers={"g", "c", "p"}, @@ -566,8 +560,8 @@ async def dummy_mapping_job(): data={}, urn="variant:2", ) - mock_worker_ctx["db"].add_all([variant1, variant2]) - mock_worker_ctx["db"].commit() + session.add_all([variant1, variant2]) + session.commit() with ( patch.object( @@ -577,11 +571,9 @@ async def dummy_mapping_job(): ), ): result = await map_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_mapping_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id - ), + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), ) assert result["status"] == "ok" @@ -592,21 +584,20 @@ async def dummy_mapping_job(): assert sample_score_set.mapping_errors is None # Verify that mapped variants were created - mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 2 # Verify that both variants have post-mapped data. I'm comfortable assuming the # data is correct given our layer permutation tests above. for urn in ["variant:1", "variant:2"]: - mapped_variant = ( - mock_worker_ctx["db"].query(MappedVariant).filter(MappedVariant.variant.has(urn=urn)).one_or_none() - ) + mapped_variant = session.query(MappedVariant).filter(MappedVariant.variant.has(urn=urn)).one_or_none() assert mapped_variant is not None assert mapped_variant.post_mapped != {} async def test_map_variants_for_score_set_updates_existing_mapped_variants( self, with_independent_processing_runs, + session, mock_worker_ctx, sample_independent_variant_mapping_run, sample_score_set, @@ -617,7 +608,7 @@ async def test_map_variants_for_score_set_updates_existing_mapped_variants( # with return value from run_in_executor. async def dummy_mapping_job(): return await construct_mock_mapping_output( - session=mock_worker_ctx["db"], + session=session, score_set=sample_score_set, with_gene_info=True, with_layers={"g", "c", "p"}, @@ -632,16 +623,16 @@ async def dummy_mapping_job(): variant = Variant( score_set_id=sample_score_set.id, hgvs_nt="NM_000000.1:c.1A>G", hgvs_pro="NP_000000.1:p.Met1Val", data={} ) - mock_worker_ctx["db"].add(variant) - mock_worker_ctx["db"].commit() + session.add(variant) + session.commit() mapped_variant = MappedVariant( variant_id=variant.id, current=True, mapped_date="2023-01-01T00:00:00Z", mapping_api_version="v1.0.0", ) - mock_worker_ctx["db"].add(mapped_variant) - mock_worker_ctx["db"].commit() + session.add(mapped_variant) + session.commit() with ( patch.object( @@ -651,11 +642,9 @@ async def dummy_mapping_job(): ), ): result = await map_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_mapping_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id - ), + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), ) assert result["status"] == "ok" @@ -667,8 +656,7 @@ async def dummy_mapping_job(): # Verify the existing mapped variant was marked as non-current non_current_mapped_variant = ( - mock_worker_ctx["db"] - .query(MappedVariant) + session.query(MappedVariant) .filter(MappedVariant.id == mapped_variant.id, MappedVariant.current.is_(False)) .one_or_none() ) @@ -676,8 +664,7 @@ async def dummy_mapping_job(): # Verify a new mapped variant entry was created new_mapped_variant = ( - mock_worker_ctx["db"] - .query(MappedVariant) + session.query(MappedVariant) .filter(MappedVariant.variant_id == variant.id, MappedVariant.current.is_(True)) .one_or_none() ) @@ -689,6 +676,7 @@ async def dummy_mapping_job(): async def test_map_variants_for_score_set_progress_updates( self, + session, with_independent_processing_runs, mock_worker_ctx, sample_independent_variant_mapping_run, @@ -700,7 +688,7 @@ async def test_map_variants_for_score_set_progress_updates( # with return value from run_in_executor. async def dummy_mapping_job(): return await construct_mock_mapping_output( - session=mock_worker_ctx["db"], + session=session, score_set=sample_score_set, with_gene_info=True, with_layers={"g", "c", "p"}, @@ -715,8 +703,8 @@ async def dummy_mapping_job(): variant = Variant( score_set_id=sample_score_set.id, hgvs_nt="NM_000000.1:c.1A>G", hgvs_pro="NP_000000.1:p.Met1Val", data={} ) - mock_worker_ctx["db"].add(variant) - mock_worker_ctx["db"].commit() + session.add(variant) + session.commit() with ( patch.object( @@ -727,11 +715,9 @@ async def dummy_mapping_job(): patch.object(JobManager, "update_progress") as mock_update_progress, ): result = await map_variants_for_score_set( - ctx=mock_worker_ctx, - job_id=sample_independent_variant_mapping_run.id, - job_manager=JobManager( - mock_worker_ctx["db"], mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id - ), + mock_worker_ctx, + sample_independent_variant_mapping_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), ) assert result["status"] == "ok" @@ -785,7 +771,7 @@ async def test_map_variants_for_score_set_independent_job( async def dummy_mapping_job(): return await construct_mock_mapping_output( - session=mock_worker_ctx["db"], + session=session, score_set=sample_score_set, with_gene_info=True, with_layers={"g", "c", "p"}, @@ -812,7 +798,7 @@ async def dummy_mapping_job(): assert result["exception_details"] is None # Verify that mapped variants were created - mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 4 # Verify score set mapping state @@ -826,8 +812,7 @@ async def dummy_mapping_job(): # Verify that each variant has a corresponding mapped variant variants = ( - mock_worker_ctx["db"] - .query(Variant) + session.query(Variant) .join(MappedVariant, MappedVariant.variant_id == Variant.id) .filter(Variant.score_set_id == sample_score_set.id, MappedVariant.current.is_(True)) .all() @@ -836,8 +821,7 @@ async def dummy_mapping_job(): # Verify that the job status was updated processing_run = ( - mock_worker_ctx["db"] - .query(sample_independent_variant_mapping_run.__class__) + session.query(sample_independent_variant_mapping_run.__class__) .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) .one() ) @@ -870,7 +854,7 @@ async def test_map_variants_for_score_set_pipeline_context( async def dummy_mapping_job(): return await construct_mock_mapping_output( - session=mock_worker_ctx["db"], + session=session, score_set=sample_score_set, with_gene_info=True, with_layers={"g", "c", "p"}, @@ -897,7 +881,7 @@ async def dummy_mapping_job(): assert result["exception_details"] is None # Verify that mapped variants were created - mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 4 # Verify score set mapping state @@ -911,8 +895,7 @@ async def dummy_mapping_job(): # Verify that each variant has a corresponding mapped variant variants = ( - mock_worker_ctx["db"] - .query(Variant) + session.query(Variant) .join(MappedVariant, MappedVariant.variant_id == Variant.id) .filter(Variant.score_set_id == sample_score_set.id, MappedVariant.current.is_(True)) .all() @@ -921,8 +904,7 @@ async def dummy_mapping_job(): # Verify that the job status was updated processing_run = ( - mock_worker_ctx["db"] - .query(sample_pipeline_variant_mapping_run.__class__) + session.query(sample_pipeline_variant_mapping_run.__class__) .filter(sample_pipeline_variant_mapping_run.__class__.id == sample_pipeline_variant_mapping_run.id) .one() ) @@ -931,8 +913,7 @@ async def dummy_mapping_job(): # Verify that the pipeline run status was updated. We expect RUNNING here because # the mapping job is not the only job in our dummy pipeline. pipeline_run = ( - mock_worker_ctx["db"] - .query(sample_pipeline_variant_mapping_run.pipeline.__class__) + session.query(sample_pipeline_variant_mapping_run.pipeline.__class__) .filter( sample_pipeline_variant_mapping_run.pipeline.__class__.id == sample_pipeline_variant_mapping_run.pipeline.id @@ -990,13 +971,12 @@ async def dummy_mapping_job(): ) # Verify that no mapped variants were created - mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 0 # Verify that the job status was updated. processing_run = ( - mock_worker_ctx["db"] - .query(sample_independent_variant_mapping_run.__class__) + session.query(sample_independent_variant_mapping_run.__class__) .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) .one() ) @@ -1028,7 +1008,7 @@ async def test_map_variants_for_score_set_no_mapped_scores( async def dummy_mapping_job(): return await construct_mock_mapping_output( - session=mock_worker_ctx["db"], + session=session, score_set=sample_score_set, with_gene_info=True, with_layers={"g", "c", "p"}, @@ -1063,13 +1043,12 @@ async def dummy_mapping_job(): assert "test error: no mapped scores" in sample_score_set.mapping_errors["error_message"] # Verify that no mapped variants were created - mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 0 # Verify that the job status was updated. processing_run = ( - mock_worker_ctx["db"] - .query(sample_independent_variant_mapping_run.__class__) + session.query(sample_independent_variant_mapping_run.__class__) .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) .one() ) @@ -1101,7 +1080,7 @@ async def test_map_variants_for_score_set_no_reference_data( async def dummy_mapping_job(): return await construct_mock_mapping_output( - session=mock_worker_ctx["db"], + session=session, score_set=sample_score_set, with_gene_info=True, with_layers={"g", "c", "p"}, @@ -1135,13 +1114,12 @@ async def dummy_mapping_job(): assert "Reference metadata missing from mapping results" in sample_score_set.mapping_errors["error_message"] # Verify that no mapped variants were created - mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 0 # Verify that the job status was updated. processing_run = ( - mock_worker_ctx["db"] - .query(sample_independent_variant_mapping_run.__class__) + session.query(sample_independent_variant_mapping_run.__class__) .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) .one() ) @@ -1172,7 +1150,7 @@ async def test_map_variants_for_score_set_updates_current_mapped_variants( ) # Associate mapped variants with all variants just created in the score set - variants = mock_worker_ctx["db"].query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() + variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() for variant in variants: mapped_variant = MappedVariant( variant_id=variant.id, @@ -1180,12 +1158,12 @@ async def test_map_variants_for_score_set_updates_current_mapped_variants( mapped_date="2023-01-01T00:00:00Z", mapping_api_version="v1.0.0", ) - mock_worker_ctx["db"].add(mapped_variant) - mock_worker_ctx["db"].commit() + session.add(mapped_variant) + session.commit() async def dummy_mapping_job(): return await construct_mock_mapping_output( - session=mock_worker_ctx["db"], + session=session, score_set=sample_score_set, with_gene_info=True, with_layers={"g", "c", "p"}, @@ -1218,20 +1196,18 @@ async def dummy_mapping_job(): assert sample_score_set.mapping_errors is None # Verify that mapped variants were marked as non-current and new entries created - mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == len(variants) * 2 # Each variant has two mapped entries now for variant in variants: non_current_mapped_variant = ( - mock_worker_ctx["db"] - .query(MappedVariant) + session.query(MappedVariant) .filter(MappedVariant.variant_id == variant.id, MappedVariant.current.is_(False)) .one_or_none() ) assert non_current_mapped_variant is not None new_mapped_variant = ( - mock_worker_ctx["db"] - .query(MappedVariant) + session.query(MappedVariant) .filter(MappedVariant.variant_id == variant.id, MappedVariant.current.is_(True)) .one_or_none() ) @@ -1243,8 +1219,7 @@ async def dummy_mapping_job(): # Verify that the job status was updated. processing_run = ( - mock_worker_ctx["db"] - .query(sample_independent_variant_mapping_run.__class__) + session.query(sample_independent_variant_mapping_run.__class__) .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) .one() ) @@ -1252,6 +1227,7 @@ async def dummy_mapping_job(): async def test_map_variants_for_score_set_no_variants( self, + session, with_independent_processing_runs, mock_worker_ctx, sample_independent_variant_mapping_run, @@ -1261,7 +1237,7 @@ async def test_map_variants_for_score_set_no_variants( async def dummy_mapping_job(): return await construct_mock_mapping_output( - session=mock_worker_ctx["db"], + session=session, score_set=sample_score_set, with_gene_info=True, with_layers={"g", "c", "p"}, @@ -1296,13 +1272,12 @@ async def dummy_mapping_job(): assert "test error: no mapped scores" in sample_score_set.mapping_errors["error_message"] # Verify that no mapped variants were created - mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 0 # Verify that the job status was updated. processing_run = ( - mock_worker_ctx["db"] - .query(sample_independent_variant_mapping_run.__class__) + session.query(sample_independent_variant_mapping_run.__class__) .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) .one() ) @@ -1310,6 +1285,7 @@ async def dummy_mapping_job(): async def test_map_variants_for_score_set_exception_in_mapping( self, + session, with_independent_processing_runs, mock_worker_ctx, sample_independent_variant_mapping_run, @@ -1349,13 +1325,12 @@ async def dummy_mapping_job(): ) # Verify that no mapped variants were created - mapped_variants = mock_worker_ctx["db"].query(MappedVariant).all() + mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 0 # Verify that the job status was updated. processing_run = ( - mock_worker_ctx["db"] - .query(sample_independent_variant_mapping_run.__class__) + session.query(sample_independent_variant_mapping_run.__class__) .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) .one() ) diff --git a/tests/worker/lib/decorators/test_job_guarantee.py b/tests/worker/lib/decorators/test_job_guarantee.py index 1371fed3..23db1d94 100644 --- a/tests/worker/lib/decorators/test_job_guarantee.py +++ b/tests/worker/lib/decorators/test_job_guarantee.py @@ -16,6 +16,8 @@ from mavedb.worker.lib.decorators.job_guarantee import with_guaranteed_job_run_record from tests.helpers.transaction_spy import TransactionSpy +pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + @with_guaranteed_job_run_record("test_job") async def sample_job(ctx: dict, job_id: int): @@ -38,27 +40,19 @@ async def test_decorator_must_receive_ctx_as_first_argument(self, mock_worker_ct with pytest.raises(ValueError) as exc_info: await sample_job() - assert "Managed job functions must receive context as first argument" in str(exc_info.value) - - async def test_decorator_must_receive_db_in_ctx(self, mock_worker_ctx): - del mock_worker_ctx["db"] - - with pytest.raises(ValueError) as exc_info: - await sample_job(mock_worker_ctx) - - assert "DB session not found in job context" in str(exc_info.value) + assert "Managed functions must receive context as first argument" in str(exc_info.value) async def test_decorator_calls_wrapped_function(self, mock_worker_ctx): result = await sample_job(mock_worker_ctx) assert result == {"status": "ok"} - async def test_decorator_creates_job_run(self, mock_worker_ctx): + async def test_decorator_creates_job_run(self, mock_worker_ctx, session): with ( - TransactionSpy.spy(mock_worker_ctx["db"], expect_flush=True, expect_commit=True), + TransactionSpy.spy(session, expect_flush=True, expect_commit=True), ): await sample_job(mock_worker_ctx) - job_run = mock_worker_ctx["db"].execute(select(JobRun)).scalars().first() + job_run = session.execute(select(JobRun)).scalars().first() assert job_run is not None assert job_run.status == JobStatus.PENDING assert job_run.job_type == "test_job" diff --git a/tests/worker/lib/decorators/test_job_management.py b/tests/worker/lib/decorators/test_job_management.py index 261bdcaa..2462b4b6 100644 --- a/tests/worker/lib/decorators/test_job_management.py +++ b/tests/worker/lib/decorators/test_job_management.py @@ -22,6 +22,8 @@ from mavedb.worker.lib.managers.job_manager import JobManager from tests.helpers.transaction_spy import TransactionSpy +pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + @with_job_management async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): @@ -58,14 +60,16 @@ async def test_decorator_must_receive_ctx_as_first_argument(self, mock_job_manag with pytest.raises(ValueError) as exc_info, TransactionSpy.spy(mock_job_manager.db): await sample_job() - assert "Managed job functions must receive context as first argument" in str(exc_info.value) + assert "Managed functions must receive context as first argument" in str(exc_info.value) - async def test_decorator_calls_wrapped_function_and_returns_result(self, mock_job_manager, mock_worker_ctx): + async def test_decorator_calls_wrapped_function_and_returns_result( + self, session, mock_job_manager, mock_worker_ctx + ): with ( patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, patch.object(mock_job_manager, "start_job", return_value=None), patch.object(mock_job_manager, "succeed_job", return_value=None), - TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), + TransactionSpy.spy(session, expect_commit=True), ): mock_job_manager_class.return_value = mock_job_manager @@ -73,13 +77,13 @@ async def test_decorator_calls_wrapped_function_and_returns_result(self, mock_jo assert result == {"status": "ok"} async def test_decorator_calls_start_job_and_succeed_job_when_wrapped_function_succeeds( - self, mock_worker_ctx, mock_job_manager + self, session, mock_worker_ctx, mock_job_manager ): with ( patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, patch.object(mock_job_manager, "start_job", return_value=None) as mock_start_job, patch.object(mock_job_manager, "succeed_job", return_value=None) as mock_succeed_job, - TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), + TransactionSpy.spy(session, expect_commit=True), ): mock_job_manager_class.return_value = mock_job_manager await sample_job(mock_worker_ctx, 999) @@ -88,14 +92,14 @@ async def test_decorator_calls_start_job_and_succeed_job_when_wrapped_function_s mock_succeed_job.assert_called_once() async def test_decorator_calls_start_job_and_fail_job_when_wrapped_function_raises_and_no_retry( - self, mock_worker_ctx, mock_job_manager + self, session, mock_worker_ctx, mock_job_manager ): with ( patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, patch.object(mock_job_manager, "start_job", return_value=None) as mock_start_job, patch.object(mock_job_manager, "should_retry", return_value=False), patch.object(mock_job_manager, "fail_job", return_value=None) as mock_fail_job, - TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True, expect_rollback=True), + TransactionSpy.spy(session, expect_commit=True, expect_rollback=True), ): mock_job_manager_class.return_value = mock_job_manager await sample_raise(mock_worker_ctx, 999) @@ -104,14 +108,14 @@ async def test_decorator_calls_start_job_and_fail_job_when_wrapped_function_rais mock_fail_job.assert_called_once() async def test_decorator_calls_start_job_and_retries_job_when_wrapped_function_raises_and_retry( - self, mock_worker_ctx, mock_job_manager + self, session, mock_worker_ctx, mock_job_manager ): with ( patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, patch.object(mock_job_manager, "start_job", return_value=None) as mock_start_job, patch.object(mock_job_manager, "should_retry", return_value=True), patch.object(mock_job_manager, "prepare_retry", return_value=None) as mock_prepare_retry, - TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True, expect_rollback=True), + TransactionSpy.spy(session, expect_commit=True, expect_rollback=True), ): mock_job_manager_class.return_value = mock_job_manager await sample_raise(mock_worker_ctx, 999) @@ -119,7 +123,7 @@ async def test_decorator_calls_start_job_and_retries_job_when_wrapped_function_r mock_start_job.assert_called_once() mock_prepare_retry.assert_called_once_with(reason="error in wrapped function") - @pytest.mark.parametrize("missing_key", ["db", "redis"]) + @pytest.mark.parametrize("missing_key", ["redis"]) async def test_decorator_raises_value_error_if_required_context_missing( self, mock_job_manager, mock_worker_ctx, missing_key ): @@ -132,36 +136,36 @@ async def test_decorator_raises_value_error_if_required_context_missing( assert "not found in job context" in str(exc_info.value).lower() async def test_decorator_swallows_exception_from_lifecycle_state_outside_except( - self, mock_job_manager, mock_worker_ctx + self, session, mock_job_manager, mock_worker_ctx ): with ( patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, patch.object(mock_job_manager, "start_job", side_effect=JobStateError("error in job start")), patch.object(mock_job_manager, "should_retry", return_value=False), patch.object(mock_job_manager, "fail_job", return_value=None), - TransactionSpy.spy(mock_worker_ctx["db"], expect_rollback=True, expect_commit=True), + TransactionSpy.spy(session, expect_rollback=True, expect_commit=True), ): mock_job_manager_class.return_value = mock_job_manager result = await sample_job(mock_worker_ctx, 999) assert "error in job start" in result["exception_details"]["message"] - async def test_decorator_raises_value_error_if_job_id_missing(self, mock_job_manager, mock_worker_ctx): + async def test_decorator_raises_value_error_if_job_id_missing(self, session, mock_job_manager, mock_worker_ctx): # Remove job_id from args to simulate missing job_id - with pytest.raises(ValueError) as exc_info, TransactionSpy.spy(mock_worker_ctx["db"]): + with pytest.raises(ValueError) as exc_info, TransactionSpy.spy(session): await sample_job(mock_worker_ctx) - assert "job id not found in pipeline context" in str(exc_info.value).lower() + assert "job id not found in function arguments" in str(exc_info.value).lower() async def test_decorator_swallows_exception_from_wrapped_function_inside_except( - self, mock_job_manager, mock_worker_ctx + self, session, mock_job_manager, mock_worker_ctx ): with ( patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, patch.object(mock_job_manager, "start_job", return_value=None), patch.object(mock_job_manager, "should_retry", return_value=False), patch.object(mock_job_manager, "fail_job", side_effect=JobStateError("error in job fail")), - TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True, expect_rollback=True), + TransactionSpy.spy(session, expect_commit=True, expect_rollback=True), ): mock_job_manager_class.return_value = mock_job_manager result = await sample_raise(mock_worker_ctx, 999) @@ -169,7 +173,7 @@ async def test_decorator_swallows_exception_from_wrapped_function_inside_except( # Errors within the main try block should take precedence assert "error in wrapped function" in result["exception_details"]["message"] - async def test_decorator_passes_job_manager_to_wrapped(self, mock_job_manager, mock_worker_ctx): + async def test_decorator_passes_job_manager_to_wrapped(self, session, mock_job_manager, mock_worker_ctx): @with_job_management async def assert_manager_passed_job(ctx, job_id: int, job_manager): assert isinstance(job_manager, JobManager) @@ -179,7 +183,7 @@ async def assert_manager_passed_job(ctx, job_id: int, job_manager): patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, patch.object(mock_job_manager, "start_job", return_value=None), patch.object(mock_job_manager, "succeed_job", return_value=None), - TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), + TransactionSpy.spy(session, expect_commit=True), ): mock_job_manager_class.return_value = mock_job_manager assert await assert_manager_passed_job(mock_worker_ctx, 999) diff --git a/tests/worker/lib/decorators/test_pipeline_management.py b/tests/worker/lib/decorators/test_pipeline_management.py index d951a67b..721bb0c8 100644 --- a/tests/worker/lib/decorators/test_pipeline_management.py +++ b/tests/worker/lib/decorators/test_pipeline_management.py @@ -23,6 +23,8 @@ from mavedb.worker.lib.managers.pipeline_manager import PipelineManager from tests.helpers.transaction_spy import TransactionSpy +pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + async def sample_job(ctx=None, job_id=None): """Sample job function to test the decorator. When called, it patches @@ -89,9 +91,9 @@ async def test_decorator_must_receive_ctx_as_first_argument(self, mock_pipeline_ with pytest.raises(ValueError) as exc_info, TransactionSpy.spy(mock_pipeline_manager.db): await sample_job() - assert "Managed pipeline functions must receive context as first argument" in str(exc_info.value) + assert "Managed functions must receive context as first argument" in str(exc_info.value) - @pytest.mark.parametrize("missing_key", ["db", "redis"]) + @pytest.mark.parametrize("missing_key", ["redis"]) async def test_decorator_raises_value_error_if_required_context_missing( self, mock_pipeline_manager, mock_worker_ctx, missing_key ): @@ -108,12 +110,14 @@ async def test_decorator_raises_value_error_if_job_id_missing(self, mock_pipelin with pytest.raises(ValueError) as exc_info, TransactionSpy.spy(mock_pipeline_manager.db): await sample_job(mock_worker_ctx) - assert "job id not found in pipeline context" in str(exc_info.value).lower() + assert "job id not found in function arguments" in str(exc_info.value).lower() - async def test_decorator_swallows_exception_if_cant_fetch_pipeline_id(self, mock_pipeline_manager, mock_worker_ctx): + async def test_decorator_swallows_exception_if_cant_fetch_pipeline_id( + self, session, mock_pipeline_manager, mock_worker_ctx + ): with ( TransactionSpy.mock_database_execution_failure( - mock_worker_ctx["db"], + session, exception=ValueError("job id not found in pipeline context"), expect_rollback=True, ), @@ -121,13 +125,13 @@ async def test_decorator_swallows_exception_if_cant_fetch_pipeline_id(self, mock await sample_job(mock_worker_ctx, 999) async def test_decorator_fetches_pipeline_from_db_and_constructs_pipeline_manager( - self, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data + self, session, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data ): with ( patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None), patch.object(mock_pipeline_manager, "start_pipeline", return_value=None), - TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), + TransactionSpy.spy(session, expect_commit=True), ): mock_pipeline_manager_class.return_value = mock_pipeline_manager result = await sample_job(mock_worker_ctx, sample_job_run.id) @@ -135,14 +139,14 @@ async def test_decorator_fetches_pipeline_from_db_and_constructs_pipeline_manage assert result == {"status": "ok"} async def test_decorator_skips_coordination_and_start_when_no_pipeline_exists( - self, mock_pipeline_manager, mock_worker_ctx, sample_independent_job_run, with_populated_job_data + self, session, mock_pipeline_manager, mock_worker_ctx, sample_independent_job_run, with_populated_job_data ): with ( patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate_pipeline, patch.object(mock_pipeline_manager, "start_pipeline", return_value=None) as mock_start_pipeline, # We shouldn't expect any commits since no pipeline coordination occurs - TransactionSpy.spy(mock_worker_ctx["db"]), + TransactionSpy.spy(session), ): mock_pipeline_manager_class.return_value = mock_pipeline_manager result = await sample_job(mock_worker_ctx, sample_independent_job_run.id) @@ -152,14 +156,14 @@ async def test_decorator_skips_coordination_and_start_when_no_pipeline_exists( assert result == {"status": "ok"} async def test_decorator_starts_pipeline_when_in_created_state( - self, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data + self, session, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data ): with ( patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED), patch.object(mock_pipeline_manager, "start_pipeline", return_value=None) as mock_start_pipeline, patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None), - TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), + TransactionSpy.spy(session, expect_commit=True), ): mock_pipeline_manager_class.return_value = mock_pipeline_manager result = await sample_job(mock_worker_ctx, sample_job_run.id) @@ -172,14 +176,14 @@ async def test_decorator_starts_pipeline_when_in_created_state( [status for status in PipelineStatus._member_map_.values() if status != PipelineStatus.CREATED], ) async def test_decorator_does_not_start_pipeline_when_in_not_in_created_state( - self, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data, pipeline_state + self, session, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data, pipeline_state ): with ( patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=pipeline_state), patch.object(mock_pipeline_manager, "start_pipeline", return_value=None) as mock_start_pipeline, patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None), - TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), + TransactionSpy.spy(session, expect_commit=True), ): mock_pipeline_manager_class.return_value = mock_pipeline_manager result = await sample_job(mock_worker_ctx, sample_job_run.id) @@ -188,14 +192,14 @@ async def test_decorator_does_not_start_pipeline_when_in_not_in_created_state( assert result == {"status": "ok"} async def test_decorator_calls_pipeline_manager_coordinate_pipeline_after_wrapped_function( - self, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data + self, session, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data ): with ( patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED), patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None) as mock_coordinate_pipeline, patch.object(mock_pipeline_manager, "start_pipeline", return_value=None), - TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True), + TransactionSpy.spy(session, expect_commit=True), ): mock_pipeline_manager_class.return_value = mock_pipeline_manager await sample_job(mock_worker_ctx, sample_job_run.id) @@ -203,14 +207,14 @@ async def test_decorator_calls_pipeline_manager_coordinate_pipeline_after_wrappe mock_coordinate_pipeline.assert_called_once() async def test_decorator_swallows_exception_from_wrapped_function( - self, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data + self, session, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data ): with ( patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None), patch.object(mock_pipeline_manager, "start_pipeline", return_value=None), patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED), - TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True, expect_rollback=True), + TransactionSpy.spy(session, expect_commit=True, expect_rollback=True), ): mock_pipeline_manager_class.return_value = mock_pipeline_manager await sample_raise(mock_worker_ctx, sample_job_run.id) @@ -218,7 +222,7 @@ async def test_decorator_swallows_exception_from_wrapped_function( # TODO: Assert calls for notification hooks and job result data async def test_decorator_swallows_exception_from_pipeline_manager_coordinate_pipeline( - self, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data + self, session, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data ): with ( patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, @@ -231,7 +235,7 @@ async def test_decorator_swallows_exception_from_pipeline_manager_coordinate_pip patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED), # Exception raised from coordinate_pipeline should trigger rollback, # and commit will be called when pipeline status is set to running - TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True, expect_rollback=True), + TransactionSpy.spy(session, expect_commit=True, expect_rollback=True), ): mock_pipeline_manager_class.return_value = mock_pipeline_manager await sample_job(mock_worker_ctx, sample_job_run.id) @@ -239,7 +243,7 @@ async def test_decorator_swallows_exception_from_pipeline_manager_coordinate_pip # TODO: Assert calls for notification hooks and job result data async def test_decorator_swallows_exception_from_job_management_decorator( - self, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data + self, session, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data ): def passthrough_decorator(f): return f @@ -254,7 +258,7 @@ def passthrough_decorator(f): patch.object(mock_pipeline_manager, "start_pipeline", return_value=None), patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED), patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, - TransactionSpy.spy(mock_worker_ctx["db"], expect_commit=True, expect_rollback=True), + TransactionSpy.spy(session, expect_commit=True, expect_rollback=True), ): mock_pipeline_manager_class.return_value = mock_pipeline_manager From b2c5fe752b42a758f74a5b4673ccd644241171b3 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 28 Jan 2026 16:00:01 -0800 Subject: [PATCH 40/70] feat: add new job definitions for score set annotation pipeline --- src/mavedb/lib/workflow/definitions.py | 97 ++++++++++++++------------ 1 file changed, 54 insertions(+), 43 deletions(-) diff --git a/src/mavedb/lib/workflow/definitions.py b/src/mavedb/lib/workflow/definitions.py index 49aa4dd7..54a7b645 100644 --- a/src/mavedb/lib/workflow/definitions.py +++ b/src/mavedb/lib/workflow/definitions.py @@ -1,9 +1,57 @@ -from mavedb.lib.types.workflow import PipelineDefinition +from mavedb.lib.types.workflow import JobDefinition, PipelineDefinition from mavedb.models.enums.job_pipeline import DependencyType, JobType # As a general rule, job keys should match function names for clarity. In some cases of # repeated jobs, a suffix may be added to the key for uniqueness. + +def annotation_pipeline_job_definitions() -> list[JobDefinition]: + return [ + { + "key": "submit_score_set_mappings_to_car", + "function": "submit_score_set_mappings_to_car", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "updater_id": None, # Required param to be filled in at runtime + }, + "dependencies": [("map_variants_for_score_set", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "link_gnomad_variants", + "function": "link_gnomad_variants", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "submit_uniprot_mapping_jobs_for_score_set", + "function": "submit_uniprot_mapping_jobs_for_score_set", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + }, + "dependencies": [("map_variants_for_score_set", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "poll_uniprot_mapping_jobs_for_score_set", + "function": "poll_uniprot_mapping_jobs_for_score_set", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "mapping_jobs": {}, # Required param to be filled in at runtime by previous job + }, + "dependencies": [("submit_uniprot_mapping_jobs_for_score_set", DependencyType.SUCCESS_REQUIRED)], + }, + ] + + PIPELINE_DEFINITIONS: dict[str, PipelineDefinition] = { "validate_map_annotate_score_set": { "description": "Pipeline to validate, map, and annotate variants for a score set.", @@ -34,49 +82,12 @@ }, "dependencies": [("create_variants_for_score_set", DependencyType.SUCCESS_REQUIRED)], }, - { - "key": "submit_score_set_mappings_to_car", - "function": "submit_score_set_mappings_to_car", - "type": JobType.MAPPED_VARIANT_ANNOTATION, - "params": { - "correlation_id": None, # Required param to be filled in at runtime - "score_set_id": None, # Required param to be filled in at runtime - "updater_id": None, # Required param to be filled in at runtime - }, - "dependencies": [("map_variants_for_score_set", DependencyType.SUCCESS_REQUIRED)], - }, - { - "key": "link_gnomad_variants", - "function": "link_gnomad_variants", - "type": JobType.MAPPED_VARIANT_ANNOTATION, - "params": { - "correlation_id": None, # Required param to be filled in at runtime - "score_set_id": None, # Required param to be filled in at runtime - }, - "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], - }, - { - "key": "submit_uniprot_mapping_jobs_for_score_set", - "function": "submit_uniprot_mapping_jobs_for_score_set", - "type": JobType.MAPPED_VARIANT_ANNOTATION, - "params": { - "correlation_id": None, # Required param to be filled in at runtime - "score_set_id": None, # Required param to be filled in at runtime - }, - "dependencies": [("map_variants_for_score_set", DependencyType.SUCCESS_REQUIRED)], - }, - { - "key": "poll_uniprot_mapping_jobs_for_score_set", - "function": "poll_uniprot_mapping_jobs_for_score_set", - "type": JobType.MAPPED_VARIANT_ANNOTATION, - "params": { - "correlation_id": None, # Required param to be filled in at runtime - "score_set_id": None, # Required param to be filled in at runtime - "mapping_jobs": {}, # Required param to be filled in at runtime by previous job - }, - "dependencies": [("submit_uniprot_mapping_jobs_for_score_set", DependencyType.SUCCESS_REQUIRED)], - }, + *annotation_pipeline_job_definitions(), ], }, + "annotate_score_set": { + "description": "Pipeline to annotate variants for a score set.", + "job_definitions": annotation_pipeline_job_definitions(), + }, # Add more pipelines here } From 5ca9d3f546f6cb8af81e8ca9a655930abf502843 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 28 Jan 2026 19:01:16 -0800 Subject: [PATCH 41/70] feat: implement AnnotationStatusManager for managing variant annotation statuses --- src/mavedb/lib/annotation_status_manager.py | 146 ++++++ tests/lib/test_annotation_status_manager.py | 495 ++++++++++++++++++++ 2 files changed, 641 insertions(+) create mode 100644 src/mavedb/lib/annotation_status_manager.py create mode 100644 tests/lib/test_annotation_status_manager.py diff --git a/src/mavedb/lib/annotation_status_manager.py b/src/mavedb/lib/annotation_status_manager.py new file mode 100644 index 00000000..628846da --- /dev/null +++ b/src/mavedb/lib/annotation_status_manager.py @@ -0,0 +1,146 @@ +"""Manage annotation statuses for variants. + +This module provides functionality to insert and retrieve annotation statuses +for genetic variants, ensuring that only one current status exists per +(variant, annotation type, version) combination. +""" + +import logging +from typing import Optional + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from mavedb.models.enums.annotation_type import AnnotationType +from mavedb.models.enums.job_pipeline import AnnotationStatus +from mavedb.models.variant_annotation_status import VariantAnnotationStatus + +logger = logging.getLogger(__name__) + + +class AnnotationStatusManager: + """ + Manager for handling variant annotation statuses. + + Attributes: + session (Session): The SQLAlchemy session used for database operations. + + Methods: + add_annotation( + variant_id: int, + annotation_type: AnnotationType, + version: Optional[str], + annotation_data: dict, + current: bool = True + ) -> VariantAnnotationStatus: + Inserts a new annotation status and marks previous ones as not current. + + get_current_annotation( + variant_id: int, + annotation_type: AnnotationType, + version: Optional[str] = None + ) -> Optional[VariantAnnotationStatus]: + Retrieves the current annotation status for a given variant/type/version. + """ + + def __init__(self, session: Session): + self.session = session + + def add_annotation( + self, + variant_id: int, + annotation_type: AnnotationType, + status: AnnotationStatus, + version: Optional[str] = None, + annotation_data: dict = {}, + current: bool = True, + ) -> VariantAnnotationStatus: + """ + Insert a new annotation and mark previous ones as not current for the same (variant, type, version). + Callers should take care to ensure only one current annotation exists per (variant, type, version). Note + + Args: + variant_id (int): The ID of the variant being annotated. + annotation_type (AnnotationType): The type of annotation (e.g., 'vrs', 'clinvar'). + version (Optional[str]): The version of the annotation source. + annotation_data (dict): Additional data for the annotation status. + current (bool): Whether this annotation is the current one. + + Returns: + VariantAnnotationStatus: The newly created annotation status record. + + Side Effects: + - Updates existing records to set current=False for the same (variant, type, version). + - Adds a new VariantAnnotationStatus record to the database session. + + NOTE: + - This method does not commit the session and only flushes to the database. The caller + is responsible for persisting any changes (e.g., by calling session.commit()). + """ + logger.debug( + f"Adding annotation for variant_id={variant_id}, annotation_type={annotation_type.value}, version={version}" + ) + + # Find existing current annotations to be replaced + existing_current = ( + self.session.execute( + select(VariantAnnotationStatus).where( + VariantAnnotationStatus.variant_id == variant_id, + VariantAnnotationStatus.annotation_type == annotation_type.value, + VariantAnnotationStatus.version == version, + VariantAnnotationStatus.current.is_(True), + ) + ) + .scalars() + .all() + ) + for var_ann in existing_current: + logger.debug( + f"Replacing current annotation {var_ann.id} for variant_id={variant_id}, annotation_type={annotation_type.value}, version={version}" + ) + var_ann.current = False + + self.session.flush() + + new_status = VariantAnnotationStatus( + variant_id=variant_id, + annotation_type=annotation_type.value, + status=status.value, + version=version, + current=current, + **annotation_data, + ) # type: ignore[call-arg] + + self.session.add(new_status) + self.session.flush() + + logger.info( + f"Successfully added annotation for variant_id={variant_id}, annotation_type={annotation_type.value}, version={version}" + ) + return new_status + + def get_current_annotation( + self, variant_id: int, annotation_type: AnnotationType, version: Optional[str] = None + ) -> Optional[VariantAnnotationStatus]: + """ + Retrieve the current annotation for a given variant/type/version. + + Args: + variant_id (int): The ID of the variant. + annotation_type (AnnotationType): The type of annotation. + version (Optional[str]): The version of the annotation source. + + Returns: + Optional[VariantAnnotationStatus]: The current annotation status record, or None if not found. + """ + stmt = select(VariantAnnotationStatus).where( + VariantAnnotationStatus.variant_id == variant_id, + VariantAnnotationStatus.annotation_type == annotation_type.value, + VariantAnnotationStatus.current.is_(True), + ) + + if version is not None: + stmt = stmt.where(VariantAnnotationStatus.version == version) + + result = self.session.execute(stmt) + return result.scalar_one_or_none() diff --git a/tests/lib/test_annotation_status_manager.py b/tests/lib/test_annotation_status_manager.py new file mode 100644 index 00000000..633cc848 --- /dev/null +++ b/tests/lib/test_annotation_status_manager.py @@ -0,0 +1,495 @@ +import pytest + +from mavedb.lib.annotation_status_manager import AnnotationStatusManager +from mavedb.models.enums.annotation_type import AnnotationType +from mavedb.models.enums.job_pipeline import AnnotationStatus +from mavedb.models.variant import Variant + + +@pytest.fixture +def annotation_status_manager(session): + """Fixture to provide an AnnotationStatusManager instance.""" + return AnnotationStatusManager(session) + + +@pytest.fixture +def existing_annotation_status(session, annotation_status_manager, setup_lib_db_with_variant): + """Fixture to create an existing annotation status in the database.""" + + # Add initial annotation + annotation = annotation_status_manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version="v1", + annotation_data={}, + status=AnnotationStatus.SUCCESS, + current=True, + ) + session.commit() + + assert annotation.id is not None + assert annotation.current is True + + return annotation + + +@pytest.fixture +def existing_unversioned_annotation_status(session, annotation_status_manager, setup_lib_db_with_variant): + """Fixture to create an existing annotation status in the database.""" + + # Add initial annotation + annotation = annotation_status_manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version=None, + annotation_data={}, + status=AnnotationStatus.SUCCESS, + current=True, + ) + session.commit() + + assert annotation.id is not None + assert annotation.current is True + + return annotation + + +@pytest.mark.unit +class TestAnnotationStatusManagerCreateAnnotationUnit: + """Unit tests for AnnotationStatusManager.add_annotation method.""" + + @pytest.mark.parametrize( + "annotation_type", + AnnotationType._member_map_.values(), + ) + @pytest.mark.parametrize( + "status", + AnnotationStatus._member_map_.values(), + ) + def test_add_annotation_creates_entry_with_annotation_type_version_status( + self, session, annotation_status_manager, annotation_type, status, setup_lib_db_with_variant + ): + """Test that adding an annotation creates a new entry with correct type and version.""" + annotation = annotation_status_manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=annotation_type, + version="v1.0", + annotation_data={}, + current=True, + status=status, + ) + session.commit() + + assert annotation.annotation_type == annotation_type.value + assert annotation.status == status.value + assert annotation.version == "v1.0" + + def test_add_annotation_persists_annotation_data( + self, session, annotation_status_manager, setup_lib_db_with_variant + ): + """Test that adding an annotation persists the provided annotation data.""" + annotation_data = { + "success_data": {"some_key": "some_value"}, + "error_message": None, + "failure_category": None, + } + annotation = annotation_status_manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + status=AnnotationStatus.SUCCESS, + version="v1.0", + annotation_data=annotation_data, + current=True, + ) + session.commit() + + for key, value in annotation_data.items(): + assert getattr(annotation, key) == value + + def test_add_annotation_creates_entry_and_marks_previous_not_current( + self, session, existing_annotation_status, setup_lib_db_with_variant + ): + """Test that adding an annotation creates a new entry and marks previous ones as not current.""" + manager = AnnotationStatusManager(session) + + # Add second annotation for same (variant, type, version) + annotation = manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version="v1", + annotation_data={}, + status=AnnotationStatus.FAILED, + current=True, + ) + session.commit() + + assert annotation.id is not None + assert annotation.current is True + + # Refresh first annotation from DB + session.refresh(existing_annotation_status) + assert existing_annotation_status.current is False + + def test_add_annotation_with_different_version_keeps_previous_current( + self, session, existing_annotation_status, setup_lib_db_with_variant + ): + """Test that adding an annotation with a different version keeps previous current.""" + manager = AnnotationStatusManager(session) + + # Add second annotation for same (variant, type) but different version + annotation = manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version="v2", + annotation_data={}, + status=AnnotationStatus.SUCCESS, + current=True, + ) + session.commit() + + assert annotation.id is not None + assert annotation.current is True + + # Refresh first annotation from DB + session.refresh(existing_annotation_status) + assert existing_annotation_status.current is True + + def test_add_annotation_with_different_type_keeps_previous_current( + self, session, existing_annotation_status, setup_lib_db_with_variant + ): + """Test that adding an annotation with a different type keeps previous current.""" + manager = AnnotationStatusManager(session) + + # Add second annotation for same variant but different type + annotation = manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.CLINGEN_ALLELE_ID, + version="v1", + annotation_data={}, + status=AnnotationStatus.SUCCESS, + current=True, + ) + session.commit() + + assert annotation.id is not None + assert annotation.current is True + + # Refresh first annotation from DB + session.refresh(existing_annotation_status) + assert existing_annotation_status.current is True + + def test_add_annotation_without_version(self, session, annotation_status_manager, setup_lib_db_with_variant): + """Test that adding an annotation without specifying version works correctly.""" + annotation = annotation_status_manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VEP_FUNCTIONAL_CONSEQUENCE, + version=None, + annotation_data={}, + status=AnnotationStatus.SKIPPED, + current=True, + ) + session.commit() + + assert annotation.id is not None + assert annotation.version is None + assert annotation.current is True + + def test_add_annotation_multiple_without_version_marks_previous_not_current( + self, session, annotation_status_manager, existing_unversioned_annotation_status, setup_lib_db_with_variant + ): + """Test that adding multiple annotations without version marks previous ones as not current.""" + + # Add second annotation without version + second_annotation = annotation_status_manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version=None, + annotation_data={}, + status=AnnotationStatus.FAILED, + current=True, + ) + session.commit() + + assert second_annotation.id is not None + assert second_annotation.current is True + + # Refresh first annotation from DB + session.refresh(existing_unversioned_annotation_status) + assert existing_unversioned_annotation_status.current is False + + def test_add_annotation_different_type_without_version_keeps_previous_current( + self, session, annotation_status_manager, existing_unversioned_annotation_status, setup_lib_db_with_variant + ): + """Test that adding an annotation of different type without version keeps previous current.""" + + # Add second annotation of different type without version + second_annotation = annotation_status_manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.CLINGEN_ALLELE_ID, + version=None, + annotation_data={}, + status=AnnotationStatus.SUCCESS, + current=True, + ) + session.commit() + + assert second_annotation.id is not None + assert second_annotation.current is True + + # Refresh first annotation from DB + session.refresh(existing_unversioned_annotation_status) + assert existing_unversioned_annotation_status.current is True + + def test_add_annotation_multiple_variants_independent_current_flags( + self, session, annotation_status_manager, setup_lib_db_with_score_set + ): + """Test that adding annotations for different variants maintains independent current flags.""" + + variant1 = Variant(score_set_id=1, hgvs_nt="NM_000000.1:c.1A>G", hgvs_pro="NP_000000.1:p.Met1Val", data={}) + variant2 = Variant(score_set_id=1, hgvs_nt="NM_000000.1:c.2A>T", hgvs_pro="NP_000000.1:p.Met2Val", data={}) + session.add_all([variant1, variant2]) + session.commit() + session.refresh(variant1) + session.refresh(variant2) + + # Add annotation for variant 1 + annotation1 = annotation_status_manager.add_annotation( + variant_id=variant1.id, + annotation_type=AnnotationType.VRS_MAPPING, + version="v1", + annotation_data={}, + status=AnnotationStatus.SUCCESS, + current=True, + ) + session.commit() + + # Add annotation for variant 2 + annotation2 = annotation_status_manager.add_annotation( + variant_id=variant2.id, + annotation_type=AnnotationType.VRS_MAPPING, + version="v1", + annotation_data={}, + status=AnnotationStatus.SUCCESS, + current=True, + ) + session.commit() + + assert annotation1.id is not None + assert annotation1.current is True + + assert annotation2.id is not None + assert annotation2.current is True + + +class TestAnnotationStatusManagerGetCurrentAnnotationUnit: + """Unit tests for AnnotationStatusManager.get_current_annotation method.""" + + def test_get_current_annotation_returns_none_when_no_entry( + self, annotation_status_manager, setup_lib_db_with_variant + ): + """Test that getting current annotation returns None when no entry exists.""" + annotation = annotation_status_manager.get_current_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version="v1", + ) + assert annotation is None + + def test_get_current_annotation_returns_correct_entry( + self, session, annotation_status_manager, existing_annotation_status, setup_lib_db_with_variant + ): + """Test that getting current annotation returns the correct entry.""" + annotation = annotation_status_manager.get_current_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version="v1", + ) + assert annotation.id == existing_annotation_status.id + assert annotation.current is True + + def test_get_current_annotation_returns_none_for_non_current( + self, session, annotation_status_manager, existing_annotation_status, setup_lib_db_with_variant + ): + """Test that getting current annotation returns None when the entry is not current.""" + # Mark existing annotation as not current + existing_annotation_status.current = False + session.commit() + + annotation = annotation_status_manager.get_current_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version="v1", + ) + assert annotation is None + + def test_get_current_annotation_with_different_version_returns_none( + self, session, annotation_status_manager, existing_annotation_status, setup_lib_db_with_variant + ): + """Test that getting current annotation with different version returns None.""" + annotation = annotation_status_manager.get_current_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version="v2", + ) + assert annotation is None + + def test_get_current_annotation_with_different_type_returns_none( + self, session, annotation_status_manager, existing_annotation_status, setup_lib_db_with_variant + ): + """Test that getting current annotation with different type returns None.""" + annotation = annotation_status_manager.get_current_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.CLINGEN_ALLELE_ID, + version="v1", + ) + assert annotation is None + + def test_get_current_annotation_without_version_returns_correct_entry( + self, session, annotation_status_manager, existing_unversioned_annotation_status, setup_lib_db_with_variant + ): + """Test that getting current annotation without version returns the correct entry.""" + annotation = annotation_status_manager.get_current_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version=None, + ) + assert annotation.id == existing_unversioned_annotation_status.id + assert annotation.current is True + + +class TestAnnotationStatusManagerIntegration: + """Integration tests for AnnotationStatusManager methods.""" + + def test_add_and_get_current_annotation_work_together( + self, session, annotation_status_manager, setup_lib_db_with_variant + ): + """Test that adding and getting current annotation work together correctly.""" + # Add annotation + added_annotation = annotation_status_manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version="v1", + annotation_data={}, + status=AnnotationStatus.SUCCESS, + current=True, + ) + session.commit() + + # Get current annotation + retrieved_annotation = annotation_status_manager.get_current_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version="v1", + ) + + assert retrieved_annotation is not None + assert retrieved_annotation.id == added_annotation.id + assert retrieved_annotation.current is True + assert retrieved_annotation.status == AnnotationStatus.SUCCESS + + @pytest.mark.parametrize( + "version", + ["v1.0", "v2.0", None], + ) + def test_add_multiple_and_get_current_returns_latest( + self, session, annotation_status_manager, version, setup_lib_db_with_variant + ): + """Test that adding multiple annotations and getting current returns the latest one.""" + # Add first annotation + annotation_status_manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version=version, + annotation_data={}, + status=AnnotationStatus.FAILED, + current=True, + ) + session.commit() + + # Add second annotation + second_annotation = annotation_status_manager.add_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version=version, + annotation_data={}, + status=AnnotationStatus.SUCCESS, + current=True, + ) + session.commit() + + # Get current annotation + retrieved_annotation = annotation_status_manager.get_current_annotation( + variant_id=setup_lib_db_with_variant.id, + annotation_type=AnnotationType.VRS_MAPPING, + version=version, + ) + + assert retrieved_annotation is not None + assert retrieved_annotation.id == second_annotation.id + assert retrieved_annotation.current is True + assert retrieved_annotation.version == version + assert retrieved_annotation.status == AnnotationStatus.SUCCESS + + @pytest.mark.parametrize( + "version", + ["v1.0", "v2.0", None], + ) + def test_add_annotations_for_different_variants_and_get_current_independent( + self, session, annotation_status_manager, version, setup_lib_db_with_score_set + ): + """Test that adding annotations for different variants and getting current works independently.""" + + variant1 = Variant(score_set_id=1, hgvs_nt="NM_000000.1:c.1A>G", hgvs_pro="NP_000000.1:p.Met1Val", data={}) + variant2 = Variant(score_set_id=1, hgvs_nt="NM_000000.1:c.2A>T", hgvs_pro="NP_000000.1:p.Met2Val", data={}) + session.add_all([variant1, variant2]) + session.commit() + session.refresh(variant1) + session.refresh(variant2) + + # Add annotation for variant 1 + annotation1 = annotation_status_manager.add_annotation( + variant_id=variant1.id, + annotation_type=AnnotationType.VRS_MAPPING, + version=version, + annotation_data={}, + status=AnnotationStatus.SUCCESS, + current=True, + ) + session.commit() + + # Add annotation for variant 2 + annotation2 = annotation_status_manager.add_annotation( + variant_id=variant2.id, + annotation_type=AnnotationType.VRS_MAPPING, + version=version, + annotation_data={}, + status=AnnotationStatus.FAILED, + current=True, + ) + session.commit() + + # Get current annotation for variant 1 + retrieved_annotation1 = annotation_status_manager.get_current_annotation( + variant_id=variant1.id, + annotation_type=AnnotationType.VRS_MAPPING, + version=version, + ) + + assert retrieved_annotation1 is not None + assert retrieved_annotation1.id == annotation1.id + assert retrieved_annotation1.current is True + assert retrieved_annotation1.status == AnnotationStatus.SUCCESS + assert retrieved_annotation1.version == version + + # Get current annotation for variant 2 + retrieved_annotation2 = annotation_status_manager.get_current_annotation( + variant_id=variant2.id, + annotation_type=AnnotationType.VRS_MAPPING, + version=version, + ) + + assert retrieved_annotation2 is not None + assert retrieved_annotation2.id == annotation2.id + assert retrieved_annotation2.current is True + assert retrieved_annotation2.status == AnnotationStatus.FAILED + assert retrieved_annotation2.version == version From 806f8ed5c5cce47691077e05bc81534894e9a228 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 28 Jan 2026 19:02:04 -0800 Subject: [PATCH 42/70] feat: add annotation status tracking to jobs --- src/mavedb/lib/gnomad.py | 16 + .../worker/jobs/data_management/views.py | 4 +- .../worker/jobs/external_services/clingen.py | 106 +++++- .../worker/jobs/external_services/gnomad.py | 48 ++- .../worker/jobs/external_services/uniprot.py | 22 +- src/mavedb/worker/jobs/jobs.md | 1 + .../pipeline_management/start_pipeline.py | 2 +- .../jobs/variant_processing/creation.py | 17 +- .../worker/jobs/variant_processing/mapping.py | 40 +- .../worker/lib/decorators/job_management.py | 4 +- tests/conftest_optional.py | 3 +- .../worker/jobs/data_management/test_views.py | 8 +- .../jobs/external_services/test_clingen.py | 347 ++++++++++++++++-- .../jobs/external_services/test_gnomad.py | 45 ++- .../jobs/external_services/test_uniprot.py | 9 +- .../jobs/variant_processing/test_creation.py | 45 ++- .../jobs/variant_processing/test_mapping.py | 287 +++++++++++++-- 17 files changed, 869 insertions(+), 135 deletions(-) create mode 100644 src/mavedb/worker/jobs/jobs.md diff --git a/src/mavedb/lib/gnomad.py b/src/mavedb/lib/gnomad.py index 937471b8..ea76d613 100644 --- a/src/mavedb/lib/gnomad.py +++ b/src/mavedb/lib/gnomad.py @@ -6,8 +6,11 @@ from sqlalchemy import Connection, Row, select, text from sqlalchemy.orm import Session +from mavedb.lib.annotation_status_manager import AnnotationStatusManager from mavedb.lib.logging.context import logging_context, save_to_logging_context from mavedb.lib.utils import batched +from mavedb.models.enums.annotation_type import AnnotationType +from mavedb.models.enums.job_pipeline import AnnotationStatus from mavedb.models.gnomad_variant import GnomADVariant from mavedb.models.mapped_variant import MappedVariant @@ -168,6 +171,7 @@ def link_gnomad_variants_to_mapped_variants( if faf95_max is not None: faf95_max = float(faf95_max) + annotation_manager = AnnotationStatusManager(db) for mapped_variant in mapped_variants_with_caids: # Remove any existing gnomAD variants for this mapped variant that match the current gnomAD data version to avoid data duplication. # There should only be one gnomAD variant per mapped variant per gnomAD data version, since each gnomAD variant can only match to one @@ -215,6 +219,18 @@ def link_gnomad_variants_to_mapped_variants( linked_gnomad_variants += 1 db.add(gnomad_variant) + annotation_manager.add_annotation( + variant_id=mapped_variant.variant_id, # type: ignore + annotation_type=AnnotationType.GNOMAD_ALLELE_FREQUENCY, + version=GNOMAD_DATA_VERSION, + status=AnnotationStatus.SUCCESS, + annotation_data={ + "success_data": { + "gnomad_db_identifier": gnomad_variant.db_identifier, + } + }, + current=True, + ) logger.debug( msg=f"Linked gnomAD variant {gnomad_variant.db_identifier} to mapped variant {mapped_variant.id} ({mapped_variant.clingen_allele_id})", diff --git a/src/mavedb/worker/jobs/data_management/views.py b/src/mavedb/worker/jobs/data_management/views.py index 24e5fac8..d93c38a2 100644 --- a/src/mavedb/worker/jobs/data_management/views.py +++ b/src/mavedb/worker/jobs/data_management/views.py @@ -55,7 +55,7 @@ async def refresh_materialized_views(ctx: dict, job_id: int, job_manager: JobMan # Do refresh refresh_all_mat_views(job_manager.db) - job_manager.db.commit() + job_manager.db.flush() # Finalize job state job_manager.update_progress(100, 100, "Completed refresh of all materialized views.") @@ -105,7 +105,7 @@ async def refresh_published_variants_view(ctx: dict, job_id: int, job_manager: J # Do refresh PublishedVariantsMV.refresh(job_manager.db) - job_manager.db.commit() + job_manager.db.flush() # Finalize job state job_manager.update_progress(100, 100, "Completed refresh of published variants materialized view.") diff --git a/src/mavedb/worker/jobs/external_services/clingen.py b/src/mavedb/worker/jobs/external_services/clingen.py index 5d0de7f7..4fe61a6d 100644 --- a/src/mavedb/worker/jobs/external_services/clingen.py +++ b/src/mavedb/worker/jobs/external_services/clingen.py @@ -15,6 +15,7 @@ from sqlalchemy import select +from mavedb.lib.annotation_status_manager import AnnotationStatusManager from mavedb.lib.clingen.constants import ( CAR_SUBMISSION_ENDPOINT, CLIN_GEN_SUBMISSION_ENABLED, @@ -29,6 +30,8 @@ ) from mavedb.lib.exceptions import LDHSubmissionFailureError from mavedb.lib.variants import get_hgvs_from_post_mapped +from mavedb.models.enums.annotation_type import AnnotationType +from mavedb.models.enums.job_pipeline import AnnotationStatus from mavedb.models.mapped_variant import MappedVariant from mavedb.models.score_set import ScoreSet from mavedb.models.variant import Variant @@ -154,18 +157,33 @@ async def submit_score_set_mappings_to_car(ctx: dict, job_id: int, job_manager: # Process registered alleles and update mapped variants linked_alleles = get_allele_registry_associations(list(variant_post_mapped_hgvs.keys()), registered_alleles) - processed = 0 total = len(linked_alleles) + processed = 0 + # Setup annotation manager + annotation_manager = AnnotationStatusManager(job_manager.db) + registered_mapped_variant_ids = [] for hgvs_string, caid in linked_alleles.items(): mapped_variant_ids = variant_post_mapped_hgvs[hgvs_string] + registered_mapped_variant_ids.extend(mapped_variant_ids) mapped_variants = job_manager.db.scalars( select(MappedVariant).where(MappedVariant.id.in_(mapped_variant_ids)) ).all() - # TODO: Track annotation progress. for mapped_variant in mapped_variants: mapped_variant.clingen_allele_id = caid job_manager.db.add(mapped_variant) + + annotation_manager.add_annotation( + variant_id=mapped_variant.variant_id, # type: ignore + annotation_type=AnnotationType.CLINGEN_ALLELE_ID, + version=None, + status=AnnotationStatus.SUCCESS, + annotation_data={ + "success_data": {"clingen_allele_id": caid}, + }, + current=True, + ) + processed += 1 # Calculate progress: 50% + (processed/total_mapped)*50, rounded to nearest 5% @@ -173,9 +191,27 @@ async def submit_score_set_mappings_to_car(ctx: dict, job_id: int, job_manager: progress = 50 + round((processed / total) * 45 / 5) * 5 job_manager.update_progress(progress, 100, f"Processed {processed} of {total} registered alleles.") + # For mapped variants which did not get a CAID, log failure annotation + failed_submissions = set(obj[0] for obj in variant_post_mapped_objects) - set(registered_mapped_variant_ids) + for mapped_variant_id in failed_submissions: + mapped_variant = job_manager.db.scalars( + select(MappedVariant).where(MappedVariant.id == mapped_variant_id) + ).one() + + annotation_manager.add_annotation( + variant_id=mapped_variant.variant_id, # type: ignore + annotation_type=AnnotationType.CLINGEN_ALLELE_ID, + version=None, + status=AnnotationStatus.FAILED, + annotation_data={ + "error_message": "Failed to register variant with ClinGen Allele Registry.", + }, + current=True, + ) + # Finalize progress job_manager.update_progress(100, 100, "Completed CAR mapped resource submission.") - job_manager.db.commit() + job_manager.db.flush() logger.info(msg="Completed CAR mapped resource submission", extra=job_manager.logging_context()) return {"status": "ok", "data": {}, "exception_details": None} @@ -251,6 +287,7 @@ async def submit_score_set_mappings_to_ldh(ctx: dict, job_id: int, job_manager: # Build submission content variant_content = [] + variant_for_urn = {} for variant, mapped_variant in variant_objects: variation = get_hgvs_from_post_mapped(mapped_variant.post_mapped) @@ -262,6 +299,7 @@ async def submit_score_set_mappings_to_ldh(ctx: dict, job_id: int, job_manager: continue variant_content.append((variation, variant, mapped_variant)) + variant_for_urn[variant.urn] = variant if not variant_content: job_manager.update_progress(100, 100, "No valid mapped variants to submit to LDH. Skipping submission.") @@ -288,7 +326,53 @@ async def submit_score_set_mappings_to_ldh(ctx: dict, job_id: int, job_manager: } ) - # TODO: Track submission successes and failures, add as annotation features. + # TODO prior to finalizing: Verify typing of ClinGen submission responses. See https://reg.clinicalgenome.org/doc/AlleleRegistry_1.01.xx_api_v1.pdf + annotation_manager = AnnotationStatusManager(job_manager.db) + submitted_variant_urns = set() + for success in submission_successes: + logger.debug( + msg=f"Successfully submitted mapped variant to LDH: {success}", + extra=job_manager.logging_context(), + ) + + submitted_urn = success["data"]["entId"] + submitted_variant = variant_for_urn[submitted_urn] + + annotation_manager.add_annotation( + variant_id=submitted_variant.id, + annotation_type=AnnotationType.LDH_SUBMISSION, + version=None, + status=AnnotationStatus.SUCCESS, + annotation_data={ + "success_data": {"ldh_iri": success["data"]["ldhIri"], "ldh_id": success["data"]["ldhId"]}, + }, + current=True, + ) + submitted_variant_urns.add(submitted_urn) + + # It isn't trivial to map individual failures back to their corresponding variants, + # especially when submission occurred in batch. Save all failures generically here. + # Note that failures may not be present in the submission failures list, but they are + # guaranteed to be absent from the successes list. + for failure_urn in set(variant_for_urn.keys()) - submitted_variant_urns: + logger.error( + msg=f"Failed to submit mapped variant to LDH: {failure_urn}", + extra=job_manager.logging_context(), + ) + + failed_variant = variant_for_urn[failure_urn] + + annotation_manager.add_annotation( + variant_id=failed_variant.id, + annotation_type=AnnotationType.LDH_SUBMISSION, + version=None, + status=AnnotationStatus.FAILED, + annotation_data={ + "error_message": "Failed to submit variant to ClinGen Linked Data Hub.", + }, + current=True, + ) + if submission_failures: logger.warning( msg=f"LDH mapped resource submission encountered {len(submission_failures)} failures.", @@ -303,7 +387,17 @@ async def submit_score_set_mappings_to_ldh(ctx: dict, job_id: int, job_manager: extra=job_manager.logging_context(), ) - raise LDHSubmissionFailureError(error_message) + # Return a failure state here rather than raising to indicate to the manager + # we should still commit any successful annotations. + return { + "status": "failed", + "data": {}, + "exception_details": { + "message": error_message, + "type": LDHSubmissionFailureError.__name__, + "traceback": None, + }, + } logger.info( msg="Completed LDH mapped resource submission", @@ -316,5 +410,5 @@ async def submit_score_set_mappings_to_ldh(ctx: dict, job_id: int, job_manager: 100, f"Finalized LDH mapped resource submission ({len(submission_successes)} successes, {len(submission_failures)} failures).", ) - job_manager.db.commit() + job_manager.db.flush() return {"status": "ok", "data": {}, "exception_details": None} diff --git a/src/mavedb/worker/jobs/external_services/gnomad.py b/src/mavedb/worker/jobs/external_services/gnomad.py index b63b1be6..87d6bf69 100644 --- a/src/mavedb/worker/jobs/external_services/gnomad.py +++ b/src/mavedb/worker/jobs/external_services/gnomad.py @@ -12,7 +12,14 @@ from sqlalchemy import select from mavedb.db import athena -from mavedb.lib.gnomad import gnomad_variant_data_for_caids, link_gnomad_variants_to_mapped_variants +from mavedb.lib.annotation_status_manager import AnnotationStatusManager +from mavedb.lib.gnomad import ( + GNOMAD_DATA_VERSION, + gnomad_variant_data_for_caids, + link_gnomad_variants_to_mapped_variants, +) +from mavedb.models.enums.annotation_type import AnnotationType +from mavedb.models.enums.job_pipeline import AnnotationStatus from mavedb.models.mapped_variant import MappedVariant from mavedb.models.score_set import ScoreSet from mavedb.models.variant import Variant @@ -105,22 +112,41 @@ async def link_gnomad_variants(ctx: dict, job_id: int, job_manager: JobManager) num_gnomad_variants_with_caid_match = len(gnomad_variant_data) - job_manager.save_to_context({"num_gnomad_variants_with_caid_match": num_gnomad_variants_with_caid_match}) - - if not gnomad_variant_data: - job_manager.update_progress(100, 100, "No gnomAD variants with CAID matches found. Nothing to link.") - logger.warning( - msg="No gnomAD variants with CAID matches were found for this score set. Skipping gnomAD linkage (nothing to do).", - extra=job_manager.logging_context(), - ) + # NOTE: Proceed intentionally with linking even if no matches were found, to record skipped annotations. - return {"status": "ok", "data": {}, "exception_details": None} + job_manager.save_to_context({"num_gnomad_variants_with_caid_match": num_gnomad_variants_with_caid_match}) job_manager.update_progress(75, 100, f"Found {num_gnomad_variants_with_caid_match} gnomAD variants matching CAIDs.") # Link mapped variants to gnomAD variants logger.info(msg="Attempting to link mapped variants to gnomAD variants.", extra=job_manager.logging_context()) num_linked_gnomad_variants = link_gnomad_variants_to_mapped_variants(job_manager.db, gnomad_variant_data) - job_manager.db.commit() + job_manager.db.flush() + + # For variants which are not linked, create annotation status records indicating skipped linkage + mapped_variants_with_caids = job_manager.db.scalars( + select(MappedVariant) + .join(Variant) + .join(ScoreSet) + .where( + ScoreSet.urn == score_set.urn, + MappedVariant.current.is_(True), + MappedVariant.clingen_allele_id.is_not(None), + ) + ).all() + annotation_manager = AnnotationStatusManager(job_manager.db) + for mapped_variant in mapped_variants_with_caids: + if not mapped_variant.gnomad_variants: + annotation_manager.add_annotation( + variant_id=mapped_variant.variant_id, # type: ignore + annotation_type=AnnotationType.GNOMAD_ALLELE_FREQUENCY, + version=GNOMAD_DATA_VERSION, + status=AnnotationStatus.SKIPPED, + annotation_data={ + "error_message": "No gnomAD variant could be linked for this mapped variant.", + "failure_category": "not_found", + }, + current=True, + ) # Save final context and progress job_manager.save_to_context({"num_mapped_variants_linked_to_gnomad_variants": num_linked_gnomad_variants}) diff --git a/src/mavedb/worker/jobs/external_services/uniprot.py b/src/mavedb/worker/jobs/external_services/uniprot.py index fccfdadf..ac99c5ed 100644 --- a/src/mavedb/worker/jobs/external_services/uniprot.py +++ b/src/mavedb/worker/jobs/external_services/uniprot.py @@ -95,7 +95,7 @@ async def submit_uniprot_mapping_jobs_for_score_set(ctx: dict, job_id: int, job_ # Preset submitted jobs metadata so it persists even if no jobs are submitted. job.metadata_["submitted_jobs"] = {} - job_manager.db.commit() + job_manager.db.flush() if not score_set.target_genes: job_manager.update_progress(100, 100, "No target genes found. Skipped UniProt mapping job submission.") @@ -155,7 +155,7 @@ async def submit_uniprot_mapping_jobs_for_score_set(ctx: dict, job_id: int, job_ # Save submitted jobs to job metadata for auditing purposes job.metadata_["submitted_jobs"] = mapping_jobs flag_modified(job, "metadata_") - job_manager.db.commit() + job_manager.db.flush() # If no mapping jobs were submitted, log and exit early. if not mapping_jobs or not any((job_info["job_id"] for job_info in mapping_jobs.values())): @@ -175,9 +175,17 @@ async def submit_uniprot_mapping_jobs_for_score_set(ctx: dict, job_id: int, job_ extra=job_manager.logging_context(), ) - raise UniProtPollingEnqueueError( - f"Could not find unique dependent polling job for UniProt mapping job {job.id}." - ) + # Return a failure state here rather than raising to indicate to the manager + # we should still commit any successful annotations. + return { + "status": "failed", + "data": {}, + "exception_details": { + "type": UniProtPollingEnqueueError.__name__, + "message": f"Could not find unique dependent polling job for UniProt mapping job {job.id}.", + "traceback": None, + }, + } # Set mapping jobs on dependent polling job. Only one polling job per score set should be created. polling_job = dependent_polling_job[0].job_run @@ -188,7 +196,7 @@ async def submit_uniprot_mapping_jobs_for_score_set(ctx: dict, job_id: int, job_ job_manager.update_progress(100, 100, "Completed submission of UniProt mapping jobs.") logger.info(msg="Completed UniProt mapping job submission", extra=job_manager.logging_context()) - job_manager.db.commit() + job_manager.db.flush() return {"status": "ok", "data": {}, "exception_details": None} @@ -312,5 +320,5 @@ async def poll_uniprot_mapping_jobs_for_score_set(ctx: dict, job_id: int, job_ma ) job_manager.update_progress(100, 100, "Completed polling of UniProt mapping jobs.") - job_manager.db.commit() + job_manager.db.flush() return {"status": "ok", "data": {}, "exception_details": None} diff --git a/src/mavedb/worker/jobs/jobs.md b/src/mavedb/worker/jobs/jobs.md new file mode 100644 index 00000000..30404ce4 --- /dev/null +++ b/src/mavedb/worker/jobs/jobs.md @@ -0,0 +1 @@ +TODO \ No newline at end of file diff --git a/src/mavedb/worker/jobs/pipeline_management/start_pipeline.py b/src/mavedb/worker/jobs/pipeline_management/start_pipeline.py index c67472e5..ddd28f7c 100644 --- a/src/mavedb/worker/jobs/pipeline_management/start_pipeline.py +++ b/src/mavedb/worker/jobs/pipeline_management/start_pipeline.py @@ -52,7 +52,7 @@ async def start_pipeline(ctx: dict, job_id: int, job_manager: JobManager) -> Job await pipeline_manager.coordinate_pipeline() # Finalize job state - job_manager.db.commit() + job_manager.db.flush() job_manager.update_progress(100, 100, "Initial pipeline coordination complete.") logger.debug(msg="Done starting pipeline.", extra=job_manager.logging_context()) diff --git a/src/mavedb/worker/jobs/variant_processing/creation.py b/src/mavedb/worker/jobs/variant_processing/creation.py index 37b7605e..87f1aecf 100644 --- a/src/mavedb/worker/jobs/variant_processing/creation.py +++ b/src/mavedb/worker/jobs/variant_processing/creation.py @@ -140,8 +140,9 @@ async def create_variants_for_score_set(ctx: dict, job_id: int, job_manager: Job {"processing_state": score_set.processing_state.name, "mapping_state": score_set.mapping_state.name} ) + # Flush initial score set state job_manager.db.add(score_set) - job_manager.db.commit() + job_manager.db.flush() job_manager.db.refresh(score_set) job_manager.update_progress(10, 100, "Validated score set metadata and beginning data validation.") @@ -226,7 +227,15 @@ async def create_variants_for_score_set(ctx: dict, job_id: int, job_manager: Job msg="Encountered an internal exception while processing variants.", extra=job_manager.logging_context() ) - raise e + return { + "status": "failed", + "data": {}, + "exception_details": { + "message": str(e), + "type": e.__class__.__name__, + "traceback": format_raised_exception_info_as_dict(e).get("traceback", ""), + }, + } else: score_set.processing_state = ProcessingState.success @@ -243,9 +252,9 @@ async def create_variants_for_score_set(ctx: dict, job_id: int, job_manager: Job finally: job_manager.db.add(score_set) - job_manager.db.commit() + job_manager.db.flush() job_manager.db.refresh(score_set) job_manager.update_progress(100, 100, "Completed variant creation job.") - logger.info(msg="Committed new variants to score set.", extra=job_manager.logging_context()) + logger.info(msg="Added new variants to score set.", extra=job_manager.logging_context()) return {"status": "ok", "data": {}, "exception_details": None} diff --git a/src/mavedb/worker/jobs/variant_processing/mapping.py b/src/mavedb/worker/jobs/variant_processing/mapping.py index 184041ea..bb43a43e 100644 --- a/src/mavedb/worker/jobs/variant_processing/mapping.py +++ b/src/mavedb/worker/jobs/variant_processing/mapping.py @@ -15,6 +15,7 @@ from sqlalchemy.dialects.postgresql import JSONB from mavedb.data_providers.services import vrs_mapper +from mavedb.lib.annotation_status_manager import AnnotationStatusManager from mavedb.lib.exceptions import ( NonexistentMappingReferenceError, NonexistentMappingResultsError, @@ -23,6 +24,9 @@ from mavedb.lib.logging.context import format_raised_exception_info_as_dict from mavedb.lib.mapping import ANNOTATION_LAYERS, EXCLUDED_PREMAPPED_ANNOTATION_KEYS from mavedb.lib.slack import send_slack_error +from mavedb.lib.variants import get_hgvs_from_post_mapped +from mavedb.models.enums.annotation_type import AnnotationType +from mavedb.models.enums.job_pipeline import AnnotationStatus from mavedb.models.enums.mapping_state import MappingState from mavedb.models.mapped_variant import MappedVariant from mavedb.models.score_set import ScoreSet @@ -84,7 +88,7 @@ async def map_variants_for_score_set(ctx: dict, job_id: int, job_manager: JobMan score_set.modification_date = date.today() job_manager.db.add(score_set) - job_manager.db.commit() + job_manager.db.flush() job_manager.save_to_context({"mapping_state": score_set.mapping_state.name}) job_manager.update_progress(10, 100, "Score set prepared for variant mapping.") @@ -196,6 +200,7 @@ async def map_variants_for_score_set(ctx: dict, job_id: int, job_manager: JobMan job_manager.update_progress(90, 100, "Saving mapped variants.") successful_mapped_variants = 0 + annotation_manager = AnnotationStatusManager(job_manager.db) for mapped_score in mapped_scores: variant_urn = mapped_score.get("mavedb_id") variant = job_manager.db.scalars(select(Variant).where(Variant.urn == variant_urn)).one() @@ -216,7 +221,8 @@ async def map_variants_for_score_set(ctx: dict, job_id: int, job_manager: JobMan job_manager.db.add(existing_mapped_variant) logger.debug(msg="Set existing mapped variant to current = false.", extra=job_manager.logging_context()) - if mapped_score.get("pre_mapped") and mapped_score.get("post_mapped"): + annotation_was_successful = mapped_score.get("pre_mapped") and mapped_score.get("post_mapped") + if annotation_was_successful: successful_mapped_variants += 1 job_manager.save_to_context({"successful_mapped_variants": successful_mapped_variants}) @@ -232,6 +238,21 @@ async def map_variants_for_score_set(ctx: dict, job_id: int, job_manager: JobMan current=True, ) + annotation_manager.add_annotation( + variant_id=variant.id, # type: ignore + annotation_type=AnnotationType.VRS_MAPPING, + version=mapped_score.get("vrs_version", null()), + status=AnnotationStatus.SUCCESS if annotation_was_successful else AnnotationStatus.FAILED, + annotation_data={ + "error_message": mapped_score.get("error_message", null()), + "job_run_id": job.id, + "success_data": { + "mapped_assay_level_hgvs": get_hgvs_from_post_mapped(mapped_score.get("post_mapped", {})), + }, + }, + current=True, + ) + job_manager.db.add(mapped_variant) logger.debug(msg="Added new mapped variant to session.", extra=job_manager.logging_context()) @@ -259,7 +280,11 @@ async def map_variants_for_score_set(ctx: dict, job_id: int, job_manager: JobMan score_set.mapping_state = MappingState.failed # These exceptions have already set mapping_errors appropriately - raise e # Re-raise to be handled by the job management system + return { + "status": "error", + "data": {}, + "exception_details": {"message": str(e), "type": e.__class__.__name__, "traceback": None}, + } except Exception as e: send_slack_error(e) @@ -275,12 +300,15 @@ async def map_variants_for_score_set(ctx: dict, job_id: int, job_manager: JobMan } job_manager.update_progress(100, 100, "Variant mapping failed due to an unexpected error.") - # Raise unexpected exceptions to be handled by the job management system - raise e + return { + "status": "error", + "data": {}, + "exception_details": {"message": str(e), "type": e.__class__.__name__, "traceback": None}, + } finally: job_manager.db.add(score_set) - job_manager.db.commit() + job_manager.db.flush() logger.info(msg="Inserted mapped variants into db.", extra=job_manager.logging_context()) job_manager.update_progress(100, 100, "Finished processing mapped variants.") diff --git a/src/mavedb/worker/lib/decorators/job_management.py b/src/mavedb/worker/lib/decorators/job_management.py index 272c96bf..7adee374 100644 --- a/src/mavedb/worker/lib/decorators/job_management.py +++ b/src/mavedb/worker/lib/decorators/job_management.py @@ -118,7 +118,9 @@ async def _execute_managed_job(func: Callable[..., Awaitable[JobResultData]], ar # Execute the async function result = await func(*args, **kwargs) - # Mark job as succeeded and persist state + # Mark job as succeeded and persist state. As a general rule, jobs do not + # commit their own state and we do not persist their state until we mark + # them as succeeded. job_manager.succeed_job(result=result) db_session.commit() diff --git a/tests/conftest_optional.py b/tests/conftest_optional.py index acbeec63..d5a1bbd8 100644 --- a/tests/conftest_optional.py +++ b/tests/conftest_optional.py @@ -124,9 +124,8 @@ async def on_job(ctx): @pytest.fixture -def standalone_worker_context(session, data_provider, arq_redis): +def standalone_worker_context(data_provider, arq_redis): yield { - "db": session, "hdp": data_provider, "state": {}, "job_id": "test_job", diff --git a/tests/worker/jobs/data_management/test_views.py b/tests/worker/jobs/data_management/test_views.py index 2038eaf7..119bafc3 100644 --- a/tests/worker/jobs/data_management/test_views.py +++ b/tests/worker/jobs/data_management/test_views.py @@ -32,7 +32,7 @@ async def test_refresh_materialized_views_calls_refresh_function(self, mock_work """Test that refresh_materialized_views calls the refresh function.""" with ( patch("mavedb.worker.jobs.data_management.views.refresh_all_mat_views") as mock_refresh, - TransactionSpy.spy(mock_job_manager.db, expect_commit=True), + TransactionSpy.spy(mock_job_manager.db, expect_commit=False, expect_flush=True), ): result = await refresh_materialized_views(mock_worker_ctx, 999, job_manager=mock_job_manager) @@ -44,7 +44,7 @@ async def test_refresh_materialized_views_updates_progress(self, mock_worker_ctx with ( patch("mavedb.worker.jobs.data_management.views.refresh_all_mat_views"), patch.object(mock_job_manager, "update_progress", return_value=None) as mock_update_progress, - TransactionSpy.spy(mock_job_manager.db, expect_commit=True), + TransactionSpy.spy(mock_job_manager.db, expect_commit=False, expect_flush=True), ): result = await refresh_materialized_views(mock_worker_ctx, 999, job_manager=mock_job_manager) @@ -140,7 +140,7 @@ async def test_refresh_published_variants_view_calls_refresh_function( with ( patch.object(PublishedVariantsMV, "refresh") as mock_refresh, patch("mavedb.worker.jobs.data_management.views.validate_job_params"), - TransactionSpy.spy(mock_job_manager.db, expect_commit=True), + TransactionSpy.spy(mock_job_manager.db, expect_commit=False, expect_flush=True), ): result = await refresh_published_variants_view(mock_worker_ctx, 999, job_manager=mock_job_manager) @@ -157,7 +157,7 @@ async def test_refresh_published_variants_view_updates_progress( patch.object(PublishedVariantsMV, "refresh"), patch("mavedb.worker.jobs.data_management.views.validate_job_params"), patch.object(mock_job_manager, "update_progress", return_value=None) as mock_update_progress, - TransactionSpy.spy(mock_job_manager.db, expect_commit=True), + TransactionSpy.spy(mock_job_manager.db, expect_commit=False, expect_flush=True), ): result = await refresh_published_variants_view(mock_worker_ctx, 999, job_manager=mock_job_manager) diff --git a/tests/worker/jobs/external_services/test_clingen.py b/tests/worker/jobs/external_services/test_clingen.py index dff03917..1b042a76 100644 --- a/tests/worker/jobs/external_services/test_clingen.py +++ b/tests/worker/jobs/external_services/test_clingen.py @@ -4,16 +4,17 @@ import pytest from sqlalchemy import select -from mavedb.lib.exceptions import LDHSubmissionFailureError from mavedb.lib.variants import get_hgvs_from_post_mapped from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus from mavedb.models.mapped_variant import MappedVariant from mavedb.models.variant import Variant +from mavedb.models.variant_annotation_status import VariantAnnotationStatus from mavedb.worker.jobs.external_services.clingen import ( submit_score_set_mappings_to_car, submit_score_set_mappings_to_ldh, ) from mavedb.worker.lib.managers.job_manager import JobManager +from tests.helpers.constants import TEST_CLINGEN_LDH_LINKING_RESPONSE_BAD_REQUEST from tests.helpers.util.setup.worker import create_mappings_in_score_set pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") @@ -150,6 +151,15 @@ async def test_submit_score_set_mappings_to_car_no_registered_alleles( variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() assert len(variants) == 0 + # Verify annotation statuses were rendered as failed + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == 4 + for ann in annotation_statuses: + assert ann.status == "failed" + assert ann.annotation_type == "clingen_allele_id" + async def test_submit_score_set_mappings_to_car_no_linked_alleles( self, mock_worker_ctx, @@ -202,6 +212,15 @@ async def test_submit_score_set_mappings_to_car_no_linked_alleles( variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() assert len(variants) == 0 + # Verify annotation statuses were rendered as failed + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == 4 + for ann in annotation_statuses: + assert ann.status == "failed" + assert ann.annotation_type == "clingen_allele_id" + async def test_submit_score_set_mappings_to_car_repeated_hgvs( self, mock_worker_ctx, @@ -265,6 +284,15 @@ async def test_submit_score_set_mappings_to_car_repeated_hgvs( for variant in variants: assert variant.clingen_allele_id == "CA_DUPLICATE" + # Verify annotation statuses were rendered as success + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == 4 + for ann in annotation_statuses: + assert ann.status == "success" + assert ann.annotation_type == "clingen_allele_id" + async def test_submit_score_set_mappings_to_car_hgvs_not_found( self, mock_worker_ctx, @@ -330,6 +358,15 @@ async def test_submit_score_set_mappings_to_car_hgvs_not_found( variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() assert len(variants) == 0 + # Verify annotation statuses were rendered as failed + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == 4 + for ann in annotation_statuses: + assert ann.status == "failed" + assert ann.annotation_type == "clingen_allele_id" + async def test_submit_score_set_mappings_to_car_propagates_exception( self, mock_worker_ctx, @@ -437,6 +474,15 @@ async def test_submit_score_set_mappings_to_car_success( for variant in variants: assert variant.clingen_allele_id == f"CA{variant.id}" + # Verify annotation statuses were rendered as success + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == 4 + for ann in annotation_statuses: + assert ann.status == "success" + assert ann.annotation_type == "clingen_allele_id" + async def test_submit_score_set_mappings_to_car_updates_progress( self, mock_worker_ctx, @@ -504,12 +550,6 @@ async def test_submit_score_set_mappings_to_car_updates_progress( ] ) - # Verify variants have CAIDs assigned - variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() - assert len(variants) == 4 - for variant in variants: - assert variant.clingen_allele_id == f"CA{variant.id}" - @pytest.mark.integration @pytest.mark.asyncio @@ -571,6 +611,14 @@ async def test_submit_score_set_mappings_to_car_independent_ctx( for variant in variants: assert variant.clingen_allele_id == f"CA{variant.id}" + # Verify annotation statuses were rendered as success + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == len(mapped_variants) + for ann in annotation_statuses: + assert ann.status == "success" + # Verify the job status is updated in the database session.refresh(submit_score_set_mappings_to_car_sample_job_run) assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.SUCCEEDED @@ -631,6 +679,14 @@ async def test_submit_score_set_mappings_to_car_pipeline_ctx( for variant in variants: assert variant.clingen_allele_id == f"CA{variant.id}" + # Verify annotation statuses were rendered as success + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == len(mapped_variants) + for ann in annotation_statuses: + assert ann.status == "success" + # Verify the job status is updated in the database session.refresh(submit_score_set_mappings_to_car_sample_job_run_in_pipeline) assert submit_score_set_mappings_to_car_sample_job_run_in_pipeline.status == JobStatus.SUCCEEDED @@ -666,6 +722,10 @@ async def test_submit_score_set_mappings_to_car_submission_disabled( variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() assert len(variants) == 0 + # Verify no annotation statuses were created + annotation_statuses = session.scalars(select(VariantAnnotationStatus)).all() + assert len(annotation_statuses) == 0 + # Verify the job status is updated in the database session.refresh(submit_score_set_mappings_to_car_sample_job_run) assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.SUCCEEDED @@ -701,6 +761,10 @@ async def test_submit_score_set_mappings_to_car_no_submission_endpoint( variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() assert len(variants) == 0 + # Verify no annotation statuses were created + annotation_statuses = session.scalars(select(VariantAnnotationStatus)).all() + assert len(annotation_statuses) == 0 + # Verify the job status is updated in the database session.refresh(submit_score_set_mappings_to_car_sample_job_run) assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.FAILED @@ -727,6 +791,10 @@ async def test_submit_score_set_mappings_to_car_no_mappings( variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() assert len(variants) == 0 + # Verify no annotation statuses were created + annotation_statuses = session.scalars(select(VariantAnnotationStatus)).all() + assert len(annotation_statuses) == 0 + # Verify the job status is updated in the database session.refresh(submit_score_set_mappings_to_car_sample_job_run) assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.SUCCEEDED @@ -774,6 +842,12 @@ async def test_submit_score_set_mappings_to_car_no_registered_alleles( variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() assert len(variants) == 0 + # Verify annotation statuses were rendered as failed + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == 4 + # Verify the job status is updated in the database session.refresh(submit_score_set_mappings_to_car_sample_job_run) assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.SUCCEEDED @@ -826,6 +900,12 @@ async def test_submit_score_set_mappings_to_car_no_linked_alleles( variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() assert len(variants) == 0 + # Verify annotation statuses were rendered as failed + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == 4 + # Verify the job status is updated in the database session.refresh(submit_score_set_mappings_to_car_sample_job_run) assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.SUCCEEDED @@ -941,6 +1021,14 @@ async def test_submit_score_set_mappings_to_car_with_arq_context_independent( for variant in variants: assert variant.clingen_allele_id == f"CA{variant.id}" + # Verify annotation statuses were rendered as success + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == 4 + for ann in annotation_statuses: + assert ann.status == "success" + async def test_submit_score_set_mappings_to_car_with_arq_context_pipeline( self, standalone_worker_context, @@ -1007,6 +1095,14 @@ async def test_submit_score_set_mappings_to_car_with_arq_context_pipeline( for variant in variants: assert variant.clingen_allele_id == f"CA{variant.id}" + # Verify annotation statuses were rendered as success + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == 4 + for ann in annotation_statuses: + assert ann.status == "success" + async def test_submit_score_set_mappings_to_car_with_arq_context_exception_handling_independent( self, standalone_worker_context, @@ -1057,6 +1153,12 @@ async def test_submit_score_set_mappings_to_car_with_arq_context_exception_handl variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() assert len(variants) == 0 + # Verify no annotation statuses were created + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == 0 + async def test_submit_score_set_mappings_to_car_with_arq_context_exception_handling_pipeline( self, standalone_worker_context, @@ -1112,6 +1214,12 @@ async def test_submit_score_set_mappings_to_car_with_arq_context_exception_handl variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() assert len(variants) == 0 + # Verify no annotation statuses were created + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "clingen_allele_id") + ).all() + assert len(annotation_statuses) == 0 + @pytest.mark.unit @pytest.mark.asyncio @@ -1170,7 +1278,7 @@ async def test_submit_score_set_mappings_to_ldh_all_submissions_failed( ) async def dummy_submission_failure(*args, **kwargs): - return ([], ["Submission failed"]) + return ([], [TEST_CLINGEN_LDH_LINKING_RESPONSE_BAD_REQUEST] * 4) # Patch ClinGenLdhService to simulate all submissions failing with ( @@ -1182,14 +1290,15 @@ async def dummy_submission_failure(*args, **kwargs): patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), patch("mavedb.worker.jobs.external_services.clingen.LDH_SUBMISSION_ENDPOINT", "http://fake-endpoint"), patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, - pytest.raises(LDHSubmissionFailureError), ): - await submit_score_set_mappings_to_ldh( + result = await submit_score_set_mappings_to_ldh( mock_worker_ctx, submit_score_set_mappings_to_ldh_sample_job_run.id, JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_ldh_sample_job_run.id), ) + assert result["status"] == "failed" + assert "All LDH submissions failed for score set" in result["exception_details"]["message"] mock_update_progress.assert_called_with(100, 100, "All mapped variant submissions to LDH failed.") async def test_submit_score_set_mappings_to_ldh_hgvs_not_found( @@ -1301,10 +1410,22 @@ async def test_submit_score_set_mappings_to_ldh_partial_submission( dummy_variant_mapping_job_run, ) + variants = session.scalars(select(Variant)).all() + async def dummy_partial_submission(*args, **kwargs): return ( - [{"@id": "LDH12345"}, {"@id": "LDH23456"}], - ["Submission failed for some variants"], + [ + { + "data": { + "entId": v.urn, + "ldhId": f"LDH123400{idx}", + "ldhIri": f"https://10.15.55.128/ldh-stg/MaveDBMapping/id/LDH123400{idx}", + }, + "status": {"code": 200, "name": "OK"}, + } + for idx, v in enumerate(variants[2:], start=1) + ], + [TEST_CLINGEN_LDH_LINKING_RESPONSE_BAD_REQUEST] * 2, ) # Patch ClinGenLdhService to simulate partial submission success @@ -1326,7 +1447,7 @@ async def dummy_partial_submission(*args, **kwargs): assert result["status"] == "ok" mock_update_progress.assert_called_with( - 100, 100, "Finalized LDH mapped resource submission (2 successes, 1 failures)." + 100, 100, "Finalized LDH mapped resource submission (2 successes, 2 failures)." ) async def test_submit_score_set_mappings_to_ldh_all_successful_submission( @@ -1353,9 +1474,21 @@ async def test_submit_score_set_mappings_to_ldh_all_successful_submission( dummy_variant_mapping_job_run, ) + variants = session.scalars(select(Variant)).all() + async def dummy_successful_submission(*args, **kwargs): return ( - [{"@id": "LDH12345"}, {"@id": "LDH23456"}], + [ + { + "data": { + "entId": v.urn, + "ldhId": f"LDH123400{idx}", + "ldhIri": f"https://10.15.55.128/ldh-stg/MaveDBMapping/id/LDH123400{idx}", + }, + "status": {"code": 200, "name": "OK"}, + } + for idx, v in enumerate(variants, start=1) + ], [], ) @@ -1378,7 +1511,7 @@ async def dummy_successful_submission(*args, **kwargs): assert result["status"] == "ok" mock_update_progress.assert_called_with( - 100, 100, "Finalized LDH mapped resource submission (2 successes, 0 failures)." + 100, 100, "Finalized LDH mapped resource submission (4 successes, 0 failures)." ) @@ -1411,9 +1544,21 @@ async def test_submit_score_set_mappings_to_ldh_independent( dummy_variant_mapping_job_run, ) + variants = session.scalars(select(Variant)).all() + async def dummy_ldh_submission(*args, **kwargs): return ( - [{"@id": "LDH12345"}, {"@id": "LDH23456"}], + [ + { + "data": { + "entId": v.urn, + "ldhId": f"LDH123400{idx}", + "ldhIri": f"https://10.15.55.128/ldh-stg/MaveDBMapping/id/LDH123400{idx}", + }, + "status": {"code": 200, "name": "OK"}, + } + for idx, v in enumerate(variants, start=1) + ], [], ) @@ -1432,6 +1577,14 @@ async def dummy_ldh_submission(*args, **kwargs): assert result["status"] == "ok" + # Verify annotation statuses were created + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission") + ).all() + assert len(annotation_statuses) == 4 + for ann in annotation_statuses: + assert ann.status == "success" + # Verify the job status is updated in the database session.refresh(submit_score_set_mappings_to_ldh_sample_job_run) assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.SUCCEEDED @@ -1461,9 +1614,21 @@ async def test_submit_score_set_mappings_to_ldh_pipeline_ctx( dummy_variant_mapping_job_run, ) + variants = session.scalars(select(Variant)).all() + async def dummy_ldh_submission(*args, **kwargs): return ( - [{"@id": "LDH12345"}, {"@id": "LDH23456"}], + [ + { + "data": { + "entId": v.urn, + "ldhId": f"LDH123400{idx}", + "ldhIri": f"https://10.15.55.128/ldh-stg/MaveDBMapping/id/LDH123400{idx}", + }, + "status": {"code": 200, "name": "OK"}, + } + for idx, v in enumerate(variants, start=1) + ], [], ) @@ -1482,6 +1647,14 @@ async def dummy_ldh_submission(*args, **kwargs): assert result["status"] == "ok" + # Verify annotation statuses were created + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission") + ).all() + assert len(annotation_statuses) == 4 + for ann in annotation_statuses: + assert ann.status == "success" + # Verify the job status is updated in the database session.refresh(submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline) assert submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline.status == JobStatus.SUCCEEDED @@ -1576,6 +1749,14 @@ async def dummy_no_linked_alleles_submission(*args, **kwargs): assert result["status"] == "ok" + # Verify annotation statuses were created with failures + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission") + ).all() + assert len(annotation_statuses) == 4 + for ann in annotation_statuses: + assert ann.status == "failed" + # Verify the job status is updated in the database session.refresh(submit_score_set_mappings_to_ldh_sample_job_run) assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.SUCCEEDED @@ -1615,6 +1796,12 @@ async def test_submit_score_set_mappings_to_ldh_hgvs_not_found( assert result["status"] == "ok" + # Verify no annotation statuses were created + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission") + ).all() + assert len(annotation_statuses) == 0 + # Verify the job status is updated in the database session.refresh(submit_score_set_mappings_to_ldh_sample_job_run) assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.SUCCEEDED @@ -1644,7 +1831,7 @@ async def test_submit_score_set_mappings_to_ldh_all_submissions_failed( ) async def dummy_submission_failure(*args, **kwargs): - return ([], ["Submission failed"]) + return ([], [TEST_CLINGEN_LDH_LINKING_RESPONSE_BAD_REQUEST] * 4) # Patch ClinGenLdhService to simulate all submissions failing with ( @@ -1662,9 +1849,18 @@ async def dummy_submission_failure(*args, **kwargs): assert result["status"] == "failed" assert "All LDH submissions failed for score set" in result["exception_details"]["message"] + # Verify annotation statuses were created with failures + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission") + ).all() + assert len(annotation_statuses) == 4 + for ann in annotation_statuses: + assert ann.status == "failed" + # Verify the job status is updated in the database + # TODO:XXX: Change status to 'failed' once decorator supports it session.refresh(submit_score_set_mappings_to_ldh_sample_job_run) - assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.FAILED + assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.SUCCEEDED async def test_submit_score_set_mappings_to_ldh_partial_submission( self, @@ -1690,10 +1886,21 @@ async def test_submit_score_set_mappings_to_ldh_partial_submission( dummy_variant_mapping_job_run, ) + variants = session.scalars(select(Variant)).all() + async def dummy_partial_submission(*args, **kwargs): return ( - [{"@id": "LDH12345"}], - ["Submission failed for some variants"], + [ + { + "data": { + "entId": variants[0].urn, + "ldhId": f"LDH123400{1}", + "ldhIri": f"https://10.15.55.128/ldh-stg/MaveDBMapping/id/LDH123400{1}", + }, + "status": {"code": 200, "name": "OK"}, + } + ], + [TEST_CLINGEN_LDH_LINKING_RESPONSE_BAD_REQUEST] * 3, ) # Patch ClinGenLdhService to simulate partial submission success @@ -1711,6 +1918,22 @@ async def dummy_partial_submission(*args, **kwargs): assert result["status"] == "ok" + # Verify annotation statuses were created + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission") + ).all() + assert len(annotation_statuses) == 4 + success_count = 0 + failure_count = 0 + for ann in annotation_statuses: + if ann.status == "success": + success_count += 1 + elif ann.status == "failed": + failure_count += 1 + + assert success_count == 1 + assert failure_count == 3 + # Verify the job status is updated in the database session.refresh(submit_score_set_mappings_to_ldh_sample_job_run) assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.SUCCEEDED @@ -1739,9 +1962,21 @@ async def test_submit_score_set_mappings_to_ldh_all_successful_submission( dummy_variant_mapping_job_run, ) - async def dummy_successful_submission(*args, **kwargs): + variants = session.scalars(select(Variant)).all() + + async def dummy_ldh_submission(*args, **kwargs): return ( - [{"@id": "LDH12345"}, {"@id": "LDH23456"}], + [ + { + "data": { + "entId": v.urn, + "ldhId": f"LDH123400{idx}", + "ldhIri": f"https://10.15.55.128/ldh-stg/MaveDBMapping/id/LDH123400{idx}", + }, + "status": {"code": 200, "name": "OK"}, + } + for idx, v in enumerate(variants, start=1) + ], [], ) @@ -1750,7 +1985,7 @@ async def dummy_successful_submission(*args, **kwargs): patch.object( _UnixSelectorEventLoop, "run_in_executor", - return_value=dummy_successful_submission(), + return_value=dummy_ldh_submission(), ), patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), ): @@ -1760,6 +1995,14 @@ async def dummy_successful_submission(*args, **kwargs): assert result["status"] == "ok" + # Verify annotation statuses were created + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission") + ).all() + assert len(annotation_statuses) == 4 + for ann in annotation_statuses: + assert ann.status == "success" + # Verify the job status is updated in the database session.refresh(submit_score_set_mappings_to_ldh_sample_job_run) assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.SUCCEEDED @@ -1796,9 +2039,21 @@ async def test_submit_score_set_mappings_to_ldh_independent( dummy_variant_mapping_job_run, ) + variants = session.scalars(select(Variant)).all() + async def dummy_ldh_submission(*args, **kwargs): return ( - [{"@id": "LDH12345"}, {"@id": "LDH23456"}], + [ + { + "data": { + "entId": v.urn, + "ldhId": f"LDH123400{idx}", + "ldhIri": f"https://10.15.55.128/ldh-stg/MaveDBMapping/id/LDH123400{idx}", + }, + "status": {"code": 200, "name": "OK"}, + } + for idx, v in enumerate(variants, start=1) + ], [], ) @@ -1817,6 +2072,14 @@ async def dummy_ldh_submission(*args, **kwargs): await arq_worker.async_run() await arq_worker.run_check() + # Verify annotation statuses were created + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission") + ).all() + assert len(annotation_statuses) == 4 + for ann in annotation_statuses: + assert ann.status == "success" + # Verify the job status is updated in the database session.refresh(submit_score_set_mappings_to_ldh_sample_job_run) assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.SUCCEEDED @@ -1848,9 +2111,21 @@ async def test_submit_score_set_mappings_to_ldh_with_arq_context_in_pipeline( dummy_variant_mapping_job_run, ) + variants = session.scalars(select(Variant)).all() + async def dummy_ldh_submission(*args, **kwargs): return ( - [{"@id": "LDH12345"}, {"@id": "LDH23456"}], + [ + { + "data": { + "entId": v.urn, + "ldhId": f"LDH123400{idx}", + "ldhIri": f"https://10.15.55.128/ldh-stg/MaveDBMapping/id/LDH123400{idx}", + }, + "status": {"code": 200, "name": "OK"}, + } + for idx, v in enumerate(variants, start=1) + ], [], ) @@ -1869,6 +2144,14 @@ async def dummy_ldh_submission(*args, **kwargs): await arq_worker.async_run() await arq_worker.run_check() + # Verify annotation statuses were created + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission") + ).all() + assert len(annotation_statuses) == 4 + for ann in annotation_statuses: + assert ann.status == "success" + # Verify the job status is updated in the database session.refresh(submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline) assert submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline.status == JobStatus.SUCCEEDED @@ -1918,6 +2201,12 @@ async def test_submit_score_set_mappings_to_ldh_with_arq_context_exception_handl await arq_worker.async_run() await arq_worker.run_check() + # Verify no annotation statuses were created + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission") + ).all() + assert len(annotation_statuses) == 0 + # Verify the job status is updated in the database session.refresh(submit_score_set_mappings_to_ldh_sample_job_run) assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.FAILED @@ -1965,6 +2254,12 @@ async def test_submit_score_set_mappings_to_ldh_with_arq_context_exception_handl await arq_worker.async_run() await arq_worker.run_check() + # Verify no annotation statuses were created + annotation_statuses = session.scalars( + select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission") + ).all() + assert len(annotation_statuses) == 0 + # Verify the job status is updated in the database session.refresh(submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline) assert submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline.status == JobStatus.FAILED diff --git a/tests/worker/jobs/external_services/test_gnomad.py b/tests/worker/jobs/external_services/test_gnomad.py index 935c5fe8..17fb3ec1 100644 --- a/tests/worker/jobs/external_services/test_gnomad.py +++ b/tests/worker/jobs/external_services/test_gnomad.py @@ -7,6 +7,7 @@ from mavedb.models.mapped_variant import MappedVariant from mavedb.models.score_set import ScoreSet from mavedb.models.variant import Variant +from mavedb.models.variant_annotation_status import VariantAnnotationStatus from mavedb.worker.jobs.external_services.gnomad import link_gnomad_variants from mavedb.worker.lib.managers.job_manager import JobManager @@ -91,7 +92,7 @@ async def test_link_gnomad_variants_no_gnomad_matches( ) assert result["status"] == "ok" - mock_update_progress.assert_any_call(100, 100, "No gnomAD variants with CAID matches found. Nothing to link.") + mock_update_progress.assert_any_call(100, 100, "Linked 0 mapped variants to gnomAD variants.") async def test_link_gnomad_variants_call_linking_method( self, @@ -209,6 +210,10 @@ async def test_link_gnomad_variants_no_variants_with_caids( gnomad_variants = session.query(GnomADVariant).all() assert len(gnomad_variants) == 0 + # Verify no annotations were rendered (since there were no variants with CAIDs) + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + # Verify job status updates session.refresh(sample_link_gnomad_variants_run) assert sample_link_gnomad_variants_run.status == JobStatus.SUCCEEDED @@ -239,6 +244,12 @@ async def test_link_gnomad_variants_no_matching_caids( gnomad_variants = session.query(GnomADVariant).all() assert len(gnomad_variants) == 0 + # Verify a skipped annotation status was rendered (since there were variants with CAIDs) + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 1 + assert annotation_statuses[0].status == "skipped" + assert annotation_statuses[0].annotation_type == "gnomad_allele_frequency" + # Verify job status updates session.refresh(sample_link_gnomad_variants_run) assert sample_link_gnomad_variants_run.status == JobStatus.SUCCEEDED @@ -265,6 +276,12 @@ async def test_link_gnomad_variants_successful_linking_independent( gnomad_variants = session.query(GnomADVariant).all() assert len(gnomad_variants) > 0 + # Verify annotation status was rendered + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 1 + assert annotation_statuses[0].status == "success" + assert annotation_statuses[0].annotation_type == "gnomad_allele_frequency" + # Verify job status updates session.refresh(sample_link_gnomad_variants_run) assert sample_link_gnomad_variants_run.status == JobStatus.SUCCEEDED @@ -291,6 +308,12 @@ async def test_link_gnomad_variants_successful_linking_pipeline( gnomad_variants = session.query(GnomADVariant).all() assert len(gnomad_variants) > 0 + # Verify annotation status was rendered + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 1 + assert annotation_statuses[0].status == "success" + assert annotation_statuses[0].annotation_type == "gnomad_allele_frequency" + # Verify job status updates session.refresh(sample_link_gnomad_variants_run_pipeline) assert sample_link_gnomad_variants_run_pipeline.status == JobStatus.SUCCEEDED @@ -361,6 +384,12 @@ async def test_link_gnomad_variants_with_arq_context_independent( gnomad_variants = session.query(GnomADVariant).all() assert len(gnomad_variants) > 0 + # Verify annotation status was rendered + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 1 + assert annotation_statuses[0].status == "success" + assert annotation_statuses[0].annotation_type == "gnomad_allele_frequency" + # Verify that the job completed successfully session.refresh(sample_link_gnomad_variants_run) assert sample_link_gnomad_variants_run.status == JobStatus.SUCCEEDED @@ -389,6 +418,12 @@ async def test_link_gnomad_variants_with_arq_context_pipeline( gnomad_variants = session.query(GnomADVariant).all() assert len(gnomad_variants) > 0 + # Verify annotation status was rendered + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 1 + assert annotation_statuses[0].status == "success" + assert annotation_statuses[0].annotation_type == "gnomad_allele_frequency" + # Verify that the job completed successfully session.refresh(sample_link_gnomad_variants_run_pipeline) assert sample_link_gnomad_variants_run_pipeline.status == JobStatus.SUCCEEDED @@ -425,6 +460,10 @@ async def test_link_gnomad_variants_with_arq_context_exception_handling_independ gnomad_variants = session.query(GnomADVariant).all() assert len(gnomad_variants) == 0 + # Verify no annotations were rendered + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + # Verify that the job failed session.refresh(sample_link_gnomad_variants_run) assert sample_link_gnomad_variants_run.status == JobStatus.FAILED @@ -457,6 +496,10 @@ async def test_link_gnomad_variants_with_arq_context_exception_handling_pipeline gnomad_variants = session.query(GnomADVariant).all() assert len(gnomad_variants) == 0 + # Verify no annotations were rendered + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + # Verify that the job failed session.refresh(sample_link_gnomad_variants_run_pipeline) assert sample_link_gnomad_variants_run_pipeline.status == JobStatus.FAILED diff --git a/tests/worker/jobs/external_services/test_uniprot.py b/tests/worker/jobs/external_services/test_uniprot.py index ea714664..3a543544 100644 --- a/tests/worker/jobs/external_services/test_uniprot.py +++ b/tests/worker/jobs/external_services/test_uniprot.py @@ -678,7 +678,7 @@ async def test_submit_uniprot_mapping_jobs_propagates_exceptions( # Verify that the job metadata contains no submitted jobs session.refresh(sample_submit_uniprot_mapping_jobs_run) - assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == {} + assert sample_submit_uniprot_mapping_jobs_run.metadata_.get("submitted_jobs") is None # Verify that the submission job failed session.refresh(sample_submit_uniprot_mapping_jobs_run) @@ -827,7 +827,8 @@ async def test_submit_uniprot_mapping_jobs_no_dependent_job_raises( # Verify that the submission job failed session.refresh(sample_submit_uniprot_mapping_jobs_run) - assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.FAILED + # TODO#XXX: Should be failed when supported by decorator + assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.SUCCEEDED # nothing to verify for dependent polling job since it does not exist @@ -973,7 +974,7 @@ async def test_submit_uniprot_mapping_jobs_with_arq_context_exception_handling_i # Verify that the job metadata contains no submitted jobs session.refresh(sample_submit_uniprot_mapping_jobs_run) - assert sample_submit_uniprot_mapping_jobs_run.metadata_["submitted_jobs"] == {} + assert sample_submit_uniprot_mapping_jobs_run.metadata_.get("submitted_jobs") is None # Verify that the submission job failed session.refresh(sample_submit_uniprot_mapping_jobs_run) @@ -1016,7 +1017,7 @@ async def test_submit_uniprot_mapping_jobs_with_arq_context_exception_handling_p # Verify that the job metadata contains no submitted jobs session.refresh(sample_submit_uniprot_mapping_jobs_run_in_pipeline) - assert sample_submit_uniprot_mapping_jobs_run_in_pipeline.metadata_["submitted_jobs"] == {} + assert sample_submit_uniprot_mapping_jobs_run_in_pipeline.metadata_.get("submitted_jobs") is None # Verify that the submission job failed session.refresh(sample_submit_uniprot_mapping_jobs_run_in_pipeline) diff --git a/tests/worker/jobs/variant_processing/test_creation.py b/tests/worker/jobs/variant_processing/test_creation.py index 6f94ae58..5b93e15a 100644 --- a/tests/worker/jobs/variant_processing/test_creation.py +++ b/tests/worker/jobs/variant_processing/test_creation.py @@ -100,16 +100,16 @@ async def test_create_variants_for_score_set_s3_file_not_found( side_effect=Exception("The specified key does not exist."), ), patch.object(JobManager, "update_progress") as mock_update_progress, - pytest.raises(Exception) as exc_info, ): - await create_variants_for_score_set( + result = await create_variants_for_score_set( mock_worker_ctx, sample_independent_variant_creation_run.id, JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), ) mock_update_progress.assert_any_call(100, 100, "Variant creation job failed due to an internal error.") - assert str(exc_info.value) == "The specified key does not exist." + assert result["status"] == "failed" + assert "The specified key does not exist." in result["exception_details"]["message"] session.refresh(sample_score_set) assert sample_score_set.processing_state == ProcessingState.failed assert sample_score_set.mapping_state == MappingState.not_attempted @@ -186,16 +186,16 @@ async def test_create_variants_for_score_set_raises_when_no_targets_exist( side_effect=[sample_score_dataframe, sample_count_dataframe], ), patch.object(JobManager, "update_progress") as mock_update_progress, - pytest.raises(ValueError) as exc_info, ): - await create_variants_for_score_set( + result = await create_variants_for_score_set( mock_worker_ctx, sample_independent_variant_creation_run.id, JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), ) mock_update_progress.assert_any_call(100, 100, "Score set has no targets; cannot create variants.") - assert str(exc_info.value) == "Can't create variants when score set has no targets." + assert result["status"] == "failed" + assert "Can't create variants when score set has no targets." in result["exception_details"]["message"] async def test_create_variants_for_score_set_calls_validate_standardize_dataframe_with_correct_parameters( self, @@ -556,15 +556,15 @@ async def test_create_variants_for_score_set_retains_existing_variants_when_exce "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", side_effect=Exception("Test exception during data validation"), ), - pytest.raises(Exception) as exc_info, ): - await create_variants_for_score_set( + result = await create_variants_for_score_set( mock_worker_ctx, sample_independent_variant_creation_run.id, JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), ) - assert str(exc_info.value) == "Test exception during data validation" + assert result["status"] == "failed" + assert "Test exception during data validation" in result["exception_details"]["message"] # Verify that existing variants are still present remaining_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() @@ -597,15 +597,15 @@ async def test_create_variants_for_score_set_handles_exception_and_updates_state side_effect=Exception("Test exception during data validation"), ), patch.object(JobManager, "update_progress") as mock_update_progress, - pytest.raises(Exception) as exc_info, ): - await create_variants_for_score_set( + result = await create_variants_for_score_set( mock_worker_ctx, sample_independent_variant_creation_run.id, JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), ) - assert str(exc_info.value) == "Test exception during data validation" + assert result["status"] == "failed" + assert "Test exception during data validation" in result["exception_details"]["message"] # Verify that the score set's processing state is updated to failed session.refresh(sample_score_set) @@ -960,7 +960,7 @@ async def test_create_variants_for_score_set_validation_error_during_creation( .one() ) assert job_run.progress_current == 100 - assert job_run.status == JobStatus.FAILED + assert job_run.status == JobStatus.SUCCEEDED async def test_create_variants_for_score_set_generic_exception_handling_during_creation( self, @@ -1002,7 +1002,7 @@ async def test_create_variants_for_score_set_generic_exception_handling_during_c .one() ) assert job_run.progress_current == 100 - assert job_run.status == JobStatus.FAILED + assert job_run.status == JobStatus.SUCCEEDED async def test_create_variants_for_score_set_generic_exception_handling_during_replacement( self, @@ -1065,7 +1065,7 @@ async def test_create_variants_for_score_set_generic_exception_handling_during_r .one() ) assert job_run.progress_current == 100 - assert job_run.status == JobStatus.FAILED + assert job_run.status == JobStatus.SUCCEEDED ## Pipeline failure workflow @@ -1110,12 +1110,11 @@ async def test_create_variants_for_score_set_pipeline_job_generic_exception_hand .one() ) assert job_run.progress_current == 100 - assert job_run.status == JobStatus.FAILED + assert job_run.status == JobStatus.SUCCEEDED # Verify that pipeline status is updated. session.refresh(sample_variant_creation_pipeline) - assert sample_variant_creation_pipeline.status == PipelineStatus.FAILED - + assert sample_variant_creation_pipeline.status == PipelineStatus.RUNNING # Verify other pipeline runs are marked as failed other_runs = ( session.query(Pipeline) @@ -1126,7 +1125,7 @@ async def test_create_variants_for_score_set_pipeline_job_generic_exception_hand .all() ) for run in other_runs: - assert run.status == PipelineStatus.CANCELLED + assert run.status == JobStatus.PENDING @pytest.mark.asyncio @@ -1320,7 +1319,7 @@ async def test_create_variants_for_score_set_with_arq_context_generic_exception_ .one() ) assert job_run.progress_current == 100 - assert job_run.status == JobStatus.FAILED + assert job_run.status == JobStatus.SUCCEEDED async def test_create_variants_for_score_set_with_arq_context_generic_exception_handling_pipeline_ctx( self, @@ -1366,11 +1365,11 @@ async def test_create_variants_for_score_set_with_arq_context_generic_exception_ .one() ) assert job_run.progress_current == 100 - assert job_run.status == JobStatus.FAILED + assert job_run.status == JobStatus.SUCCEEDED # Verify that pipeline status is updated. session.refresh(sample_variant_creation_pipeline) - assert sample_variant_creation_pipeline.status == PipelineStatus.FAILED + assert sample_variant_creation_pipeline.status == PipelineStatus.RUNNING # Verify other pipeline runs are marked as cancelled other_runs = ( @@ -1382,4 +1381,4 @@ async def test_create_variants_for_score_set_with_arq_context_generic_exception_ .all() ) for run in other_runs: - assert run.status == PipelineStatus.CANCELLED + assert run.status == JobStatus.PENDING diff --git a/tests/worker/jobs/variant_processing/test_mapping.py b/tests/worker/jobs/variant_processing/test_mapping.py index fa0c3dc8..a7cc1412 100644 --- a/tests/worker/jobs/variant_processing/test_mapping.py +++ b/tests/worker/jobs/variant_processing/test_mapping.py @@ -14,6 +14,7 @@ from mavedb.models.enums.mapping_state import MappingState from mavedb.models.mapped_variant import MappedVariant from mavedb.models.variant import Variant +from mavedb.models.variant_annotation_status import VariantAnnotationStatus from mavedb.worker.jobs.variant_processing.mapping import map_variants_for_score_set from mavedb.worker.lib.managers.job_manager import JobManager from tests.helpers.constants import TEST_CODING_LAYER, TEST_GENOMIC_LAYER, TEST_PROTEIN_LAYER @@ -62,6 +63,15 @@ async def test_map_variants_for_score_set_no_mapping_results( in sample_score_set.mapping_errors["error_message"] ) + # Verify no annotations were created + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 0 + async def test_map_variants_for_score_set_no_mapped_scores( self, session, @@ -97,6 +107,15 @@ async def test_map_variants_for_score_set_no_mapped_scores( assert sample_score_set.mapping_errors is not None assert "No variants were mapped for this score set" in sample_score_set.mapping_errors["error_message"] + # Verify no annotations were created + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 0 + async def test_map_variants_for_score_set_no_reference_data( self, session, @@ -132,6 +151,15 @@ async def test_map_variants_for_score_set_no_reference_data( assert sample_score_set.mapping_errors is not None assert "Reference metadata missing from mapping results" in sample_score_set.mapping_errors["error_message"] + # Verify no annotations were created + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 0 + async def test_map_variants_for_score_set_nonexistent_target_gene( self, session, @@ -173,6 +201,15 @@ async def test_map_variants_for_score_set_nonexistent_target_gene( in sample_score_set.mapping_errors["error_message"] ) + # Verify no annotations were created + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 0 + async def test_map_variants_for_score_set_returns_variants_not_in_score_set( self, session, @@ -214,6 +251,15 @@ async def test_map_variants_for_score_set_returns_variants_not_in_score_set( in sample_score_set.mapping_errors["error_message"] ) + # Verify no annotations were created + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 0 + async def test_map_variants_for_score_set_success_missing_gene_info( self, session, @@ -274,6 +320,17 @@ async def dummy_mapping_job(): mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 1 + # Verify that annotation statuses were created and correct + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 1 + assert annotation_statuses[0].annotation_type == "vrs_mapping" + assert annotation_statuses[0].status == "success" + @pytest.mark.parametrize( "with_layers", [ @@ -381,6 +438,17 @@ async def dummy_mapping_job(): mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 1 + # Verify that annotation statuses were created and correct + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 1 + assert annotation_statuses[0].annotation_type == "vrs_mapping" + assert annotation_statuses[0].status == "success" + async def test_map_variants_for_score_set_success_no_successful_mapping( self, session, @@ -441,6 +509,17 @@ async def dummy_mapping_job(): mapped_variant = mapped_variants[0] assert mapped_variant.post_mapped == {} + # Verify that annotation statuses were created and correct + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 1 + assert annotation_statuses[0].annotation_type == "vrs_mapping" + assert annotation_statuses[0].status == "failed" + async def test_map_variants_for_score_set_incomplete_mapping( self, session, @@ -520,6 +599,24 @@ async def dummy_mapping_job(): ) assert mapped_variant_without_post_data is not None + # Verify that annotation statuses were created and correct + annotation_status_success = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id, VariantAnnotationStatus.status == "success") + .all() + ) + assert len(annotation_status_success) == 1 + assert annotation_status_success[0].annotation_type == "vrs_mapping" + annotation_status_failed = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id, VariantAnnotationStatus.status == "failed") + .all() + ) + assert len(annotation_status_failed) == 1 + assert annotation_status_failed[0].annotation_type == "vrs_mapping" + async def test_map_variants_for_score_set_complete_mapping( self, session, @@ -594,6 +691,18 @@ async def dummy_mapping_job(): assert mapped_variant is not None assert mapped_variant.post_mapped != {} + # Verify that annotation statuses were created and correct + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 2 + for status in annotation_statuses: + assert status.annotation_type == "vrs_mapping" + assert status.status == "success" + async def test_map_variants_for_score_set_updates_existing_mapped_variants( self, with_independent_processing_runs, @@ -619,7 +728,7 @@ async def dummy_mapping_job(): with_all_variants=True, ) - # Create a variant and associated mapped data in the score set to be updated + # Create a variant and associated mapped data/annotation status in the score set to be updated variant = Variant( score_set_id=sample_score_set.id, hgvs_nt="NM_000000.1:c.1A>G", hgvs_pro="NP_000000.1:p.Met1Val", data={} ) @@ -633,6 +742,11 @@ async def dummy_mapping_job(): ) session.add(mapped_variant) session.commit() + variant_annotation_status = VariantAnnotationStatus( + variant_id=variant.id, current=True, annotation_type="vrs_mapping", status="success" + ) + session.add(variant_annotation_status) + session.commit() with ( patch.object( @@ -674,6 +788,25 @@ async def dummy_mapping_job(): assert new_mapped_variant.mapped_date != "2023-01-01T00:00:00Z" assert new_mapped_variant.mapping_api_version != "v1.0.0" + # Verify the non-current annotation status still exists + old_annotation_status = ( + session.query(VariantAnnotationStatus) + .filter( + VariantAnnotationStatus.variant_id == non_current_mapped_variant.variant_id, + VariantAnnotationStatus.current.is_(False), + ) + .one_or_none() + ) + assert old_annotation_status is not None + + # Verify that a new annotation status was created + new_annotation_status = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == variant.id, VariantAnnotationStatus.current.is_(True)) + .one_or_none() + ) + assert new_annotation_status is not None + async def test_map_variants_for_score_set_progress_updates( self, session, @@ -819,6 +952,15 @@ async def dummy_mapping_job(): ) assert len(variants) == 4 + # Verify that each variant has an annotation status + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 4 + # Verify that the job status was updated processing_run = ( session.query(sample_independent_variant_mapping_run.__class__) @@ -902,6 +1044,15 @@ async def dummy_mapping_job(): ) assert len(variants) == 4 + # Verify that each variant has an annotation status + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 4 + # Verify that the job status was updated processing_run = ( session.query(sample_pipeline_variant_mapping_run.__class__) @@ -959,7 +1110,7 @@ async def dummy_mapping_job(): sample_independent_variant_mapping_run.id, ) - assert result["status"] == "failed" + assert result["status"] == "error" assert result["exception_details"]["type"] == "NonexistentMappingResultsError" assert result["data"] == {} @@ -974,13 +1125,17 @@ async def dummy_mapping_job(): mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 0 + # Verify that no annotation statuses were created + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + # Verify that the job status was updated. processing_run = ( session.query(sample_independent_variant_mapping_run.__class__) .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) .one() ) - assert processing_run.status == JobStatus.FAILED + assert processing_run.status == JobStatus.SUCCEEDED async def test_map_variants_for_score_set_no_mapped_scores( self, @@ -1033,7 +1188,7 @@ async def dummy_mapping_job(): sample_independent_variant_mapping_run.id, ) - assert result["status"] == "failed" + assert result["status"] == "error" assert result["exception_details"]["type"] == "NonexistentMappingScoresError" assert result["data"] == {} @@ -1046,13 +1201,17 @@ async def dummy_mapping_job(): mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 0 + # Verify that no annotation statuses were created + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + # Verify that the job status was updated. processing_run = ( session.query(sample_independent_variant_mapping_run.__class__) .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) .one() ) - assert processing_run.status == JobStatus.FAILED + assert processing_run.status == JobStatus.SUCCEEDED async def test_map_variants_for_score_set_no_reference_data( self, @@ -1105,7 +1264,7 @@ async def dummy_mapping_job(): sample_independent_variant_mapping_run.id, ) - assert result["status"] == "failed" + assert result["status"] == "error" assert result["exception_details"]["type"] == "NonexistentMappingReferenceError" assert result["data"] == {} @@ -1117,13 +1276,17 @@ async def dummy_mapping_job(): mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 0 + # Verify that no annotation statuses were created + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + # Verify that the job status was updated. processing_run = ( session.query(sample_independent_variant_mapping_run.__class__) .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) .one() ) - assert processing_run.status == JobStatus.FAILED + assert processing_run.status == JobStatus.SUCCEEDED async def test_map_variants_for_score_set_updates_current_mapped_variants( self, @@ -1158,6 +1321,10 @@ async def test_map_variants_for_score_set_updates_current_mapped_variants( mapped_date="2023-01-01T00:00:00Z", mapping_api_version="v1.0.0", ) + annotation_status = VariantAnnotationStatus( + variant_id=variant.id, current=True, annotation_type="vrs_mapping", status="success" + ) + session.add(annotation_status) session.add(mapped_variant) session.commit() @@ -1217,6 +1384,24 @@ async def dummy_mapping_job(): assert new_mapped_variant.mapped_date != "2023-01-01T00:00:00Z" assert new_mapped_variant.mapping_api_version != "v1.0.0" + # Verify that annotation statuses where marked as non-current and new entries created + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == len(variants) * 2 # Each variant has two annotation statuses now + for variant in variants: + old_annotation_status = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == variant.id, VariantAnnotationStatus.current.is_(False)) + .one_or_none() + ) + assert old_annotation_status is not None + + new_annotation_status = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == variant.id, VariantAnnotationStatus.current.is_(True)) + .one_or_none() + ) + assert new_annotation_status is not None + # Verify that the job status was updated. processing_run = ( session.query(sample_independent_variant_mapping_run.__class__) @@ -1262,7 +1447,7 @@ async def dummy_mapping_job(): sample_independent_variant_mapping_run.id, ) - assert result["status"] == "failed" + assert result["status"] == "error" assert result["data"] == {} assert result["exception_details"] is not None assert result["exception_details"]["type"] == "NonexistentMappingScoresError" @@ -1275,13 +1460,17 @@ async def dummy_mapping_job(): mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 0 + # Verify that no annotation statuses were created + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + # Verify that the job status was updated. processing_run = ( session.query(sample_independent_variant_mapping_run.__class__) .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) .one() ) - assert processing_run.status == JobStatus.FAILED + assert processing_run.status == JobStatus.SUCCEEDED async def test_map_variants_for_score_set_exception_in_mapping( self, @@ -1310,7 +1499,7 @@ async def dummy_mapping_job(): sample_independent_variant_mapping_run.id, ) - assert result["status"] == "failed" + assert result["status"] == "error" assert result["data"] == {} assert result["exception_details"]["type"] == "ValueError" # exception messages are persisted in internal properties @@ -1328,13 +1517,17 @@ async def dummy_mapping_job(): mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 0 + # Verify that no annotation statuses were created + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + # Verify that the job status was updated. processing_run = ( session.query(sample_independent_variant_mapping_run.__class__) .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) .one() ) - assert processing_run.status == JobStatus.FAILED + assert processing_run.status == JobStatus.SUCCEEDED @pytest.mark.integration @@ -1368,7 +1561,7 @@ async def test_create_variants_for_score_set_with_arq_context_independent_ctx( async def dummy_mapping_job(): return await construct_mock_mapping_output( - session=standalone_worker_context["db"], + session=session, score_set=sample_score_set, with_gene_info=True, with_layers={"g", "c", "p"}, @@ -1391,7 +1584,7 @@ async def dummy_mapping_job(): await arq_worker.run_check() # Verify that mapped variants were created - mapped_variants = standalone_worker_context["db"].query(MappedVariant).all() + mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 4 # Verify score set mapping state @@ -1400,18 +1593,25 @@ async def dummy_mapping_job(): # Verify that each variant has a corresponding mapped variant variants = ( - standalone_worker_context["db"] - .query(Variant) + session.query(Variant) .join(MappedVariant, MappedVariant.variant_id == Variant.id) .filter(Variant.score_set_id == sample_score_set.id, MappedVariant.current.is_(True)) .all() ) assert len(variants) == 4 + # Verify that each variant has an annotation status + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 4 + # Verify that the job status was updated processing_run = ( - standalone_worker_context["db"] - .query(sample_independent_variant_mapping_run.__class__) + session.query(sample_independent_variant_mapping_run.__class__) .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) .one() ) @@ -1447,7 +1647,7 @@ async def test_map_variants_for_score_set_with_arq_context_pipeline_ctx( async def dummy_mapping_job(): return await construct_mock_mapping_output( - session=standalone_worker_context["db"], + session=session, score_set=sample_score_set, with_gene_info=True, with_layers={"g", "c", "p"}, @@ -1472,7 +1672,7 @@ async def dummy_mapping_job(): await arq_worker.run_check() # Verify that mapped variants were created - mapped_variants = standalone_worker_context["db"].query(MappedVariant).all() + mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 4 # Verify score set mapping state @@ -1481,18 +1681,25 @@ async def dummy_mapping_job(): # Verify that each variant has a corresponding mapped variant variants = ( - standalone_worker_context["db"] - .query(Variant) + session.query(Variant) .join(MappedVariant, MappedVariant.variant_id == Variant.id) .filter(Variant.score_set_id == sample_score_set.id, MappedVariant.current.is_(True)) .all() ) assert len(variants) == 4 + # Verify that each variant has an annotation status + annotation_statuses = ( + session.query(VariantAnnotationStatus) + .join(Variant, VariantAnnotationStatus.variant_id == Variant.id) + .filter(Variant.score_set_id == sample_score_set.id) + .all() + ) + assert len(annotation_statuses) == 4 + # Verify that the job status was updated processing_run = ( - standalone_worker_context["db"] - .query(sample_pipeline_variant_mapping_run.__class__) + session.query(sample_pipeline_variant_mapping_run.__class__) .filter(sample_pipeline_variant_mapping_run.__class__.id == sample_pipeline_variant_mapping_run.id) .one() ) @@ -1501,8 +1708,7 @@ async def dummy_mapping_job(): # Verify that the pipeline run status was updated. We expect RUNNING here because # the mapping job is not the only job in our dummy pipeline. pipeline_run = ( - standalone_worker_context["db"] - .query(sample_pipeline_variant_mapping_run.pipeline.__class__) + session.query(sample_pipeline_variant_mapping_run.pipeline.__class__) .filter( sample_pipeline_variant_mapping_run.pipeline.__class__.id == sample_pipeline_variant_mapping_run.pipeline.id @@ -1513,6 +1719,7 @@ async def dummy_mapping_job(): async def test_map_variants_for_score_set_with_arq_context_generic_exception_handling( self, + session, arq_redis, arq_worker, standalone_worker_context, @@ -1547,20 +1754,24 @@ async def dummy_mapping_job(): ) # Verify that no mapped variants were created - mapped_variants = standalone_worker_context["db"].query(MappedVariant).all() + mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 0 + # Verify that no annotation statuses were created + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + # Verify that the job status was updated. processing_run = ( - standalone_worker_context["db"] - .query(sample_independent_variant_mapping_run.__class__) + session.query(sample_independent_variant_mapping_run.__class__) .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) .one() ) - assert processing_run.status == JobStatus.FAILED + assert processing_run.status == JobStatus.SUCCEEDED async def test_map_variants_for_score_set_with_arq_context_generic_exception_in_pipeline_ctx( self, + session, arq_redis, arq_worker, standalone_worker_context, @@ -1595,31 +1806,33 @@ async def dummy_mapping_job(): ) # Verify that no mapped variants were created - mapped_variants = standalone_worker_context["db"].query(MappedVariant).all() + mapped_variants = session.query(MappedVariant).all() assert len(mapped_variants) == 0 + # Verify that no annotation statuses were created + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + # Verify that the job status was updated. processing_run = ( - standalone_worker_context["db"] - .query(sample_pipeline_variant_mapping_run.__class__) + session.query(sample_pipeline_variant_mapping_run.__class__) .filter(sample_pipeline_variant_mapping_run.__class__.id == sample_pipeline_variant_mapping_run.id) .one() ) - assert processing_run.status == JobStatus.FAILED + assert processing_run.status == JobStatus.SUCCEEDED # Verify that the pipeline run status was updated to FAILED. pipeline_run = ( - standalone_worker_context["db"] - .query(sample_pipeline_variant_mapping_run.pipeline.__class__) + session.query(sample_pipeline_variant_mapping_run.pipeline.__class__) .filter( sample_pipeline_variant_mapping_run.pipeline.__class__.id == sample_pipeline_variant_mapping_run.pipeline.id ) .one() ) - assert pipeline_run.status == PipelineStatus.FAILED + assert pipeline_run.status == PipelineStatus.RUNNING # Verify that other jobs in the pipeline were skipped for job_run in pipeline_run.job_runs: if job_run.id != sample_pipeline_variant_mapping_run.id: - assert job_run.status == JobStatus.SKIPPED + assert job_run.status == JobStatus.QUEUED From 011522c536d6bcc24f52c8048cf1afc5bea73944 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 28 Jan 2026 20:45:32 -0800 Subject: [PATCH 43/70] feat: streamline job results and exception handling in tests - Updated test assertions to check for "exception" status instead of "failed" in variant creation and mapping tests. - Enhanced exception handling in job management decorators to return structured results with "status", "data", and "exception" fields. - Modified job manager methods to align with new result structure, ensuring consistent handling of job outcomes across success, failure, and cancellation scenarios. - Adjusted integration tests to validate the new result format and ensure proper job state transitions. - Improved clarity in test cases by asserting the presence of exception details where applicable. --- src/mavedb/lib/exceptions.py | 12 ++ .../worker/jobs/data_management/views.py | 4 +- .../worker/jobs/external_services/clingen.py | 24 +-- .../worker/jobs/external_services/gnomad.py | 4 +- .../worker/jobs/external_services/uniprot.py | 18 +-- .../pipeline_management/start_pipeline.py | 9 +- .../jobs/variant_processing/creation.py | 12 +- .../worker/jobs/variant_processing/mapping.py | 25 +-- .../worker/lib/decorators/job_management.py | 29 ++-- .../lib/decorators/pipeline_management.py | 10 +- src/mavedb/worker/lib/managers/job_manager.py | 8 +- .../worker/lib/managers/pipeline_manager.py | 6 +- src/mavedb/worker/lib/managers/types.py | 12 +- src/mavedb/worker/lib/managers/utils.py | 2 +- .../worker/jobs/data_management/test_views.py | 23 +-- .../jobs/external_services/test_clingen.py | 32 ++-- .../jobs/external_services/test_gnomad.py | 4 +- .../jobs/external_services/test_uniprot.py | 32 ++-- .../test_start_pipeline.py | 19 ++- .../jobs/variant_processing/test_creation.py | 36 ++--- .../jobs/variant_processing/test_mapping.py | 94 +++++++----- .../lib/decorators/test_job_management.py | 89 ++++++++++- .../decorators/test_pipeline_management.py | 8 +- tests/worker/lib/managers/test_job_manager.py | 143 +++++++++++++----- .../lib/managers/test_pipeline_manager.py | 18 +-- tests/worker/lib/managers/test_utils.py | 2 +- 26 files changed, 414 insertions(+), 261 deletions(-) diff --git a/src/mavedb/lib/exceptions.py b/src/mavedb/lib/exceptions.py index 63e891a3..2dadeb95 100644 --- a/src/mavedb/lib/exceptions.py +++ b/src/mavedb/lib/exceptions.py @@ -232,3 +232,15 @@ class LDHSubmissionFailureError(Exception): """Raised when submission to ClinGen Linked Data Hub (LDH) fails for all submissions.""" pass + + +class PipelineNotFoundError(Exception): + """Raised when a pipeline associated with a job is not found.""" + + pass + + +class NoMappedVariantsError(Exception): + """Raised when no variants were mapped during the variant mapping process.""" + + pass diff --git a/src/mavedb/worker/jobs/data_management/views.py b/src/mavedb/worker/jobs/data_management/views.py index d93c38a2..abf787c2 100644 --- a/src/mavedb/worker/jobs/data_management/views.py +++ b/src/mavedb/worker/jobs/data_management/views.py @@ -61,7 +61,7 @@ async def refresh_materialized_views(ctx: dict, job_id: int, job_manager: JobMan job_manager.update_progress(100, 100, "Completed refresh of all materialized views.") logger.debug(msg="Done refreshing materialized views.", extra=job_manager.logging_context()) - return {"status": "ok", "data": {}, "exception_details": None} + return {"status": "ok", "data": {}, "exception": None} @with_pipeline_management @@ -111,4 +111,4 @@ async def refresh_published_variants_view(ctx: dict, job_id: int, job_manager: J job_manager.update_progress(100, 100, "Completed refresh of published variants materialized view.") logger.debug(msg="Done refreshing published variants materialized view.", extra=job_manager.logging_context()) - return {"status": "ok", "data": {}, "exception_details": None} + return {"status": "ok", "data": {}, "exception": None} diff --git a/src/mavedb/worker/jobs/external_services/clingen.py b/src/mavedb/worker/jobs/external_services/clingen.py index 4fe61a6d..e67e4337 100644 --- a/src/mavedb/worker/jobs/external_services/clingen.py +++ b/src/mavedb/worker/jobs/external_services/clingen.py @@ -95,7 +95,7 @@ async def submit_score_set_mappings_to_car(ctx: dict, job_id: int, job_manager: msg="ClinGen submission is disabled via configuration, skipping submission of mapped variants to CAR.", extra=job_manager.logging_context(), ) - return {"status": "ok", "data": {}, "exception_details": None} + return {"status": "skipped", "data": {}, "exception": None} # Check for CAR submission endpoint if not CAR_SUBMISSION_ENDPOINT: @@ -104,7 +104,11 @@ async def submit_score_set_mappings_to_car(ctx: dict, job_id: int, job_manager: msg="ClinGen Allele Registry submission is disabled (no submission endpoint), unable to complete submission of mapped variants to CAR.", extra=job_manager.logging_context(), ) - raise ValueError("ClinGen Allele Registry submission endpoint is not configured.") + return { + "status": "failed", + "data": {}, + "exception": ValueError("ClinGen Allele Registry submission endpoint is not configured."), + } # Fetch mapped variants with post-mapped data for the score set variant_post_mapped_objects = job_manager.db.execute( @@ -124,7 +128,7 @@ async def submit_score_set_mappings_to_car(ctx: dict, job_id: int, job_manager: msg="No current mapped variants with post mapped metadata were found for this score set. Skipping CAR submission.", extra=job_manager.logging_context(), ) - return {"status": "ok", "data": {}, "exception_details": None} + return {"status": "ok", "data": {}, "exception": None} job_manager.update_progress( 10, 100, f"Preparing {len(variant_post_mapped_objects)} mapped variants for CAR submission." @@ -213,7 +217,7 @@ async def submit_score_set_mappings_to_car(ctx: dict, job_id: int, job_manager: job_manager.update_progress(100, 100, "Completed CAR mapped resource submission.") job_manager.db.flush() logger.info(msg="Completed CAR mapped resource submission", extra=job_manager.logging_context()) - return {"status": "ok", "data": {}, "exception_details": None} + return {"status": "ok", "data": {}, "exception": None} @with_pipeline_management @@ -282,7 +286,7 @@ async def submit_score_set_mappings_to_ldh(ctx: dict, job_id: int, job_manager: msg="No current mapped variants with post mapped metadata were found for this score set. Skipping LDH submission.", extra=job_manager.logging_context(), ) - return {"status": "ok", "data": {}, "exception_details": None} + return {"status": "ok", "data": {}, "exception": None} job_manager.update_progress(10, 100, f"Submitting {len(variant_objects)} mapped variants to LDH.") # Build submission content @@ -307,7 +311,7 @@ async def submit_score_set_mappings_to_ldh(ctx: dict, job_id: int, job_manager: msg="No valid mapped variants with post mapped metadata were found for this score set. Skipping LDH submission.", extra=job_manager.logging_context(), ) - return {"status": "ok", "data": {}, "exception_details": None} + return {"status": "ok", "data": {}, "exception": None} job_manager.save_to_context({"unique_variants_to_submit_ldh": len(variant_content)}) job_manager.update_progress(30, 100, f"Dispatching submissions for {len(variant_content)} unique variants to LDH.") @@ -392,11 +396,7 @@ async def submit_score_set_mappings_to_ldh(ctx: dict, job_id: int, job_manager: return { "status": "failed", "data": {}, - "exception_details": { - "message": error_message, - "type": LDHSubmissionFailureError.__name__, - "traceback": None, - }, + "exception": LDHSubmissionFailureError(error_message), } logger.info( @@ -411,4 +411,4 @@ async def submit_score_set_mappings_to_ldh(ctx: dict, job_id: int, job_manager: f"Finalized LDH mapped resource submission ({len(submission_successes)} successes, {len(submission_failures)} failures).", ) job_manager.db.flush() - return {"status": "ok", "data": {}, "exception_details": None} + return {"status": "ok", "data": {}, "exception": None} diff --git a/src/mavedb/worker/jobs/external_services/gnomad.py b/src/mavedb/worker/jobs/external_services/gnomad.py index 87d6bf69..b1e33785 100644 --- a/src/mavedb/worker/jobs/external_services/gnomad.py +++ b/src/mavedb/worker/jobs/external_services/gnomad.py @@ -97,7 +97,7 @@ async def link_gnomad_variants(ctx: dict, job_id: int, job_manager: JobManager) msg="No current mapped variants with CAIDs were found for this score set. Skipping gnomAD linkage (nothing to do).", extra=job_manager.logging_context(), ) - return {"status": "ok", "data": {}, "exception_details": None} + return {"status": "ok", "data": {}, "exception": None} job_manager.update_progress(10, 100, f"Found {num_variant_caids} variants with CAIDs to link to gnomAD variants.") logger.info( @@ -152,4 +152,4 @@ async def link_gnomad_variants(ctx: dict, job_id: int, job_manager: JobManager) job_manager.save_to_context({"num_mapped_variants_linked_to_gnomad_variants": num_linked_gnomad_variants}) job_manager.update_progress(100, 100, f"Linked {num_linked_gnomad_variants} mapped variants to gnomAD variants.") logger.info(msg="Done linking gnomAD variants to mapped variants.", extra=job_manager.logging_context()) - return {"status": "ok", "data": {}, "exception_details": None} + return {"status": "ok", "data": {}, "exception": None} diff --git a/src/mavedb/worker/jobs/external_services/uniprot.py b/src/mavedb/worker/jobs/external_services/uniprot.py index ac99c5ed..bfd89a0d 100644 --- a/src/mavedb/worker/jobs/external_services/uniprot.py +++ b/src/mavedb/worker/jobs/external_services/uniprot.py @@ -104,7 +104,7 @@ async def submit_uniprot_mapping_jobs_for_score_set(ctx: dict, job_id: int, job_ extra=job_manager.logging_context(), ) - return {"status": "ok", "data": {}, "exception_details": None} + return {"status": "ok", "data": {}, "exception": None} uniprot_api = UniProtIDMappingAPI() job_manager.save_to_context({"total_target_genes_to_map_to_uniprot": len(score_set.target_genes)}) @@ -162,7 +162,7 @@ async def submit_uniprot_mapping_jobs_for_score_set(ctx: dict, job_id: int, job_ job_manager.update_progress(100, 100, "No UniProt mapping jobs were submitted.") logger.warning(msg="No UniProt mapping jobs were submitted.", extra=job_manager.logging_context()) - return {"status": "ok", "data": {}, "exception_details": None} + return {"status": "ok", "data": {}, "exception": None} # It's an essential responsibility of the submit job (when submissions exist) to ensure that the polling job exists. dependent_polling_job = job_manager.db.scalars( @@ -180,11 +180,9 @@ async def submit_uniprot_mapping_jobs_for_score_set(ctx: dict, job_id: int, job_ return { "status": "failed", "data": {}, - "exception_details": { - "type": UniProtPollingEnqueueError.__name__, - "message": f"Could not find unique dependent polling job for UniProt mapping job {job.id}.", - "traceback": None, - }, + "exception": UniProtPollingEnqueueError( + f"Could not find unique dependent polling job for UniProt mapping job {job.id}." + ), } # Set mapping jobs on dependent polling job. Only one polling job per score set should be created. @@ -197,7 +195,7 @@ async def submit_uniprot_mapping_jobs_for_score_set(ctx: dict, job_id: int, job_ job_manager.update_progress(100, 100, "Completed submission of UniProt mapping jobs.") logger.info(msg="Completed UniProt mapping job submission", extra=job_manager.logging_context()) job_manager.db.flush() - return {"status": "ok", "data": {}, "exception_details": None} + return {"status": "ok", "data": {}, "exception": None} @with_pipeline_management @@ -252,7 +250,7 @@ async def poll_uniprot_mapping_jobs_for_score_set(ctx: dict, job_id: int, job_ma msg=f"No mapping jobs found in job parameters for polling UniProt mapping jobs for score set {score_set.urn}.", extra=job_manager.logging_context(), ) - return {"status": "ok", "data": {}, "exception_details": None} + return {"status": "ok", "data": {}, "exception": None} # Poll each mapping job and update target genes with UniProt IDs uniprot_api = UniProtIDMappingAPI() @@ -321,4 +319,4 @@ async def poll_uniprot_mapping_jobs_for_score_set(ctx: dict, job_id: int, job_ma job_manager.update_progress(100, 100, "Completed polling of UniProt mapping jobs.") job_manager.db.flush() - return {"status": "ok", "data": {}, "exception_details": None} + return {"status": "ok", "data": {}, "exception": None} diff --git a/src/mavedb/worker/jobs/pipeline_management/start_pipeline.py b/src/mavedb/worker/jobs/pipeline_management/start_pipeline.py index ddd28f7c..e2d80f38 100644 --- a/src/mavedb/worker/jobs/pipeline_management/start_pipeline.py +++ b/src/mavedb/worker/jobs/pipeline_management/start_pipeline.py @@ -1,5 +1,6 @@ import logging +from mavedb.lib.exceptions import PipelineNotFoundError from mavedb.worker.lib.decorators.pipeline_management import with_pipeline_management from mavedb.worker.lib.managers.job_manager import JobManager from mavedb.worker.lib.managers.pipeline_manager import PipelineManager @@ -44,7 +45,11 @@ async def start_pipeline(ctx: dict, job_id: int, job_manager: JobManager) -> Job logger.debug(msg="Coordinating pipeline for the first time.", extra=job_manager.logging_context()) if not job_manager.pipeline_id: - raise ValueError(f"No pipeline associated with job {job_id}") + return { + "status": "exception", + "data": {}, + "exception": PipelineNotFoundError("No pipeline associated with this job."), + } # Initialize PipelineManager and coordinate pipeline. The pipeline manager decorator # will have started the pipeline for us already, but doesn't coordinate on start automatically. @@ -56,4 +61,4 @@ async def start_pipeline(ctx: dict, job_id: int, job_manager: JobManager) -> Job job_manager.update_progress(100, 100, "Initial pipeline coordination complete.") logger.debug(msg="Done starting pipeline.", extra=job_manager.logging_context()) - return {"status": "ok", "data": {}, "exception_details": None} + return {"status": "ok", "data": {}, "exception": None} diff --git a/src/mavedb/worker/jobs/variant_processing/creation.py b/src/mavedb/worker/jobs/variant_processing/creation.py index 87f1aecf..3774782a 100644 --- a/src/mavedb/worker/jobs/variant_processing/creation.py +++ b/src/mavedb/worker/jobs/variant_processing/creation.py @@ -227,15 +227,7 @@ async def create_variants_for_score_set(ctx: dict, job_id: int, job_manager: Job msg="Encountered an internal exception while processing variants.", extra=job_manager.logging_context() ) - return { - "status": "failed", - "data": {}, - "exception_details": { - "message": str(e), - "type": e.__class__.__name__, - "traceback": format_raised_exception_info_as_dict(e).get("traceback", ""), - }, - } + return {"status": "failed" if isinstance(e, ValidationError) else "exception", "data": {}, "exception": e} else: score_set.processing_state = ProcessingState.success @@ -257,4 +249,4 @@ async def create_variants_for_score_set(ctx: dict, job_id: int, job_manager: Job job_manager.update_progress(100, 100, "Completed variant creation job.") logger.info(msg="Added new variants to score set.", extra=job_manager.logging_context()) - return {"status": "ok", "data": {}, "exception_details": None} + return {"status": "ok", "data": {}, "exception": None} diff --git a/src/mavedb/worker/jobs/variant_processing/mapping.py b/src/mavedb/worker/jobs/variant_processing/mapping.py index bb43a43e..eee55a32 100644 --- a/src/mavedb/worker/jobs/variant_processing/mapping.py +++ b/src/mavedb/worker/jobs/variant_processing/mapping.py @@ -17,6 +17,7 @@ from mavedb.data_providers.services import vrs_mapper from mavedb.lib.annotation_status_manager import AnnotationStatusManager from mavedb.lib.exceptions import ( + NoMappedVariantsError, NonexistentMappingReferenceError, NonexistentMappingResultsError, NonexistentMappingScoresError, @@ -280,11 +281,7 @@ async def map_variants_for_score_set(ctx: dict, job_id: int, job_manager: JobMan score_set.mapping_state = MappingState.failed # These exceptions have already set mapping_errors appropriately - return { - "status": "error", - "data": {}, - "exception_details": {"message": str(e), "type": e.__class__.__name__, "traceback": None}, - } + return {"status": "exception", "data": {}, "exception": e} except Exception as e: send_slack_error(e) @@ -300,11 +297,7 @@ async def map_variants_for_score_set(ctx: dict, job_id: int, job_manager: JobMan } job_manager.update_progress(100, 100, "Variant mapping failed due to an unexpected error.") - return { - "status": "error", - "data": {}, - "exception_details": {"message": str(e), "type": e.__class__.__name__, "traceback": None}, - } + return {"status": "exception", "data": {}, "exception": e} finally: job_manager.db.add(score_set) @@ -312,4 +305,14 @@ async def map_variants_for_score_set(ctx: dict, job_id: int, job_manager: JobMan logger.info(msg="Inserted mapped variants into db.", extra=job_manager.logging_context()) job_manager.update_progress(100, 100, "Finished processing mapped variants.") - return {"status": "ok" if successful_mapped_variants > 0 else "error", "data": {}, "exception_details": None} + + if successful_mapped_variants == 0: + logger.error(msg="No variants were successfully mapped.", extra=job_manager.logging_context()) + return { + "status": "failed", + "data": {}, + "exception": NoMappedVariantsError("No variants were successfully mapped."), + } + + logger.info(msg="Variant mapping job completed successfully.", extra=job_manager.logging_context()) + return {"status": "ok", "data": {}, "exception": None} diff --git a/src/mavedb/worker/lib/decorators/job_management.py b/src/mavedb/worker/lib/decorators/job_management.py index 7adee374..74867556 100644 --- a/src/mavedb/worker/lib/decorators/job_management.py +++ b/src/mavedb/worker/lib/decorators/job_management.py @@ -13,6 +13,7 @@ from arq import ArqRedis from sqlalchemy.orm import Session +from mavedb.models.enums.job_pipeline import JobStatus from mavedb.worker.lib.decorators.utils import ensure_ctx, ensure_job_id, ensure_session_ctx, is_test_mode from mavedb.worker.lib.managers import JobManager from mavedb.worker.lib.managers.types import JobResultData @@ -118,12 +119,20 @@ async def _execute_managed_job(func: Callable[..., Awaitable[JobResultData]], ar # Execute the async function result = await func(*args, **kwargs) - # Mark job as succeeded and persist state. As a general rule, jobs do not - # commit their own state and we do not persist their state until we mark - # them as succeeded. - job_manager.succeed_job(result=result) + # Move job to final state based on result + if result.get("status") == "failed" or result.get("exception"): + job_manager.fail_job(result=result, error=result["exception"]) + elif result.get("status") == "skipped": + job_manager.skip_job(result=result) + else: + job_manager.succeed_job(result=result) db_session.commit() + # If the job is not marked as succeeded, check if we should retry + if job_manager.get_job_status() != JobStatus.SUCCEEDED and job_manager.should_retry(): + job_manager.prepare_retry(reason="Job did not complete successfully") + db_session.commit() + return result except Exception as e: @@ -132,15 +141,7 @@ async def _execute_managed_job(func: Callable[..., Awaitable[JobResultData]], ar db_session.rollback() # Build failure result data - result = { - "status": "failed", - "data": {}, - "exception_details": { - "type": type(e).__name__, - "message": str(e), - "traceback": None, # Could be populated with actual traceback if needed - }, - } + result = {"status": "exception", "data": {}, "exception": e} # Mark job as failed job_manager.fail_job(result=result, error=e) @@ -152,8 +153,6 @@ async def _execute_managed_job(func: Callable[..., Awaitable[JobResultData]], ar job_manager.prepare_retry(reason=str(e)) db_session.commit() - result["status"] = "retried" - # short circuit raising the exception. We indicate to the caller # we did encounter a terminal failure and coordination should proceed. return result diff --git a/src/mavedb/worker/lib/decorators/pipeline_management.py b/src/mavedb/worker/lib/decorators/pipeline_management.py index b0659a90..ac35ce38 100644 --- a/src/mavedb/worker/lib/decorators/pipeline_management.py +++ b/src/mavedb/worker/lib/decorators/pipeline_management.py @@ -170,15 +170,7 @@ async def _execute_managed_pipeline(func: Callable[..., Awaitable[JobResultData] logger.error(f"Pipeline {pipeline_id} associated with job {job_id} failed to coordinate: {e}") # Build job result data for failure - result = { - "status": "failed", - "data": {}, - "exception_details": { - "type": type(e).__name__, - "message": str(e), - "traceback": None, # Could be populated with actual traceback if needed - }, - } + result = {"status": "failed", "data": {}, "exception": e} # TODO: Notification hooks diff --git a/src/mavedb/worker/lib/managers/job_manager.py b/src/mavedb/worker/lib/managers/job_manager.py index f89aecbb..b2269398 100644 --- a/src/mavedb/worker/lib/managers/job_manager.py +++ b/src/mavedb/worker/lib/managers/job_manager.py @@ -278,7 +278,13 @@ def complete_job(self, status: JobStatus, result: JobResultData, error: Optional job_run = self.get_job() try: job_run.status = status - job_run.metadata_["result"] = result + job_run.metadata_["result"] = { + "status": result["status"], + "data": result["data"], + "exception_details": format_raised_exception_info_as_dict(result["exception"]) # type: ignore + if result.get("exception") + else None, + } job_run.finished_at = datetime.now() if status == JobStatus.SUCCEEDED: diff --git a/src/mavedb/worker/lib/managers/pipeline_manager.py b/src/mavedb/worker/lib/managers/pipeline_manager.py index 74f6d344..0fffe94d 100644 --- a/src/mavedb/worker/lib/managers/pipeline_manager.py +++ b/src/mavedb/worker/lib/managers/pipeline_manager.py @@ -390,9 +390,9 @@ async def enqueue_ready_jobs(self) -> None: if should_skip: job_manager.skip_job( { - "output": {}, - "logs": "", - "metadata": {"result": reason, "timestamp": datetime.now().isoformat()}, + "status": "skipped", + "exception": None, + "data": {"result": reason, "timestamp": datetime.now().isoformat()}, } ) logger.info(f"Skipped job {job.urn} due to unreachable dependencies: {reason}") diff --git a/src/mavedb/worker/lib/managers/types.py b/src/mavedb/worker/lib/managers/types.py index e93b2ac2..475b28a2 100644 --- a/src/mavedb/worker/lib/managers/types.py +++ b/src/mavedb/worker/lib/managers/types.py @@ -1,16 +1,10 @@ -from typing import Optional, TypedDict - - -class ExceptionDetails(TypedDict): - type: str - message: str - traceback: Optional[str] +from typing import Literal, Optional, TypedDict class JobResultData(TypedDict): - status: str + status: Literal["ok", "failed", "skipped", "exception", "cancelled"] data: dict - exception_details: Optional[ExceptionDetails] + exception: Optional[Exception] class RetryHistoryEntry(TypedDict): diff --git a/src/mavedb/worker/lib/managers/utils.py b/src/mavedb/worker/lib/managers/utils.py index 91395d4a..975fc7d6 100644 --- a/src/mavedb/worker/lib/managers/utils.py +++ b/src/mavedb/worker/lib/managers/utils.py @@ -31,7 +31,7 @@ def construct_bulk_cancellation_result(reason: str) -> JobResultData: "reason": reason, "timestamp": datetime.now().isoformat(), }, - "exception_details": None, + "exception": None, } diff --git a/tests/worker/jobs/data_management/test_views.py b/tests/worker/jobs/data_management/test_views.py index 119bafc3..564c24cb 100644 --- a/tests/worker/jobs/data_management/test_views.py +++ b/tests/worker/jobs/data_management/test_views.py @@ -37,7 +37,7 @@ async def test_refresh_materialized_views_calls_refresh_function(self, mock_work result = await refresh_materialized_views(mock_worker_ctx, 999, job_manager=mock_job_manager) mock_refresh.assert_called_once_with(mock_job_manager.db) - assert result == {"status": "ok", "data": {}, "exception_details": None} + assert result == {"status": "ok", "data": {}, "exception": None} async def test_refresh_materialized_views_updates_progress(self, mock_worker_ctx, mock_job_manager): """Test that refresh_materialized_views updates progress correctly.""" @@ -53,7 +53,7 @@ async def test_refresh_materialized_views_updates_progress(self, mock_worker_ctx call(100, 100, "Completed refresh of all materialized views."), ] mock_update_progress.assert_has_calls(expected_calls) - assert result == {"status": "ok", "data": {}, "exception_details": None} + assert result == {"status": "ok", "data": {}, "exception": None} @pytest.mark.asyncio @@ -75,7 +75,7 @@ async def test_refresh_materialized_views_integration(self, standalone_worker_co assert job.status == JobStatus.SUCCEEDED assert job.job_type == "cron_job" - assert result == {"status": "ok", "data": {}, "exception_details": None} + assert result == {"status": "ok", "data": {}, "exception": None} async def test_refresh_materialized_views_handles_exceptions(self, standalone_worker_context, session): """Integration test that ensures exceptions during refresh are handled properly.""" @@ -97,7 +97,8 @@ async def test_refresh_materialized_views_handles_exceptions(self, standalone_wo assert job.status == JobStatus.FAILED assert job.job_type == "cron_job" assert job.error_message == "Test exception during refresh" - assert result["exception_details"]["message"] == "Test exception during refresh" + assert result["status"] == "exception" + assert isinstance(result["exception"], Exception) @pytest.mark.asyncio @@ -145,7 +146,7 @@ async def test_refresh_published_variants_view_calls_refresh_function( result = await refresh_published_variants_view(mock_worker_ctx, 999, job_manager=mock_job_manager) mock_refresh.assert_called_once_with(mock_job_manager.db) - assert result == {"status": "ok", "data": {}, "exception_details": None} + assert result == {"status": "ok", "data": {}, "exception": None} async def test_refresh_published_variants_view_updates_progress( self, mock_worker_ctx, mock_job_manager, mock_job_run @@ -166,7 +167,7 @@ async def test_refresh_published_variants_view_updates_progress( call(100, 100, "Completed refresh of published variants materialized view."), ] mock_update_progress.assert_has_calls(expected_calls) - assert result == {"status": "ok", "data": {}, "exception_details": None} + assert result == {"status": "ok", "data": {}, "exception": None} @pytest.mark.asyncio @@ -197,7 +198,7 @@ async def test_refresh_published_variants_view_integration_standalone( session.refresh(setup_refresh_job_run) assert setup_refresh_job_run.status == JobStatus.SUCCEEDED - assert result == {"status": "ok", "data": {}, "exception_details": None} + assert result == {"status": "ok", "data": {}, "exception": None} async def test_refresh_published_variants_view_integration_pipeline( self, standalone_worker_context, session, setup_refresh_job_run @@ -220,7 +221,7 @@ async def test_refresh_published_variants_view_integration_pipeline( session.refresh(setup_refresh_job_run) assert setup_refresh_job_run.status == JobStatus.SUCCEEDED - assert result == {"status": "ok", "data": {}, "exception_details": None} + assert result == {"status": "ok", "data": {}, "exception": None} session.refresh(pipeline) assert pipeline.status == PipelineStatus.SUCCEEDED @@ -241,7 +242,8 @@ async def test_refresh_published_variants_view_handles_exceptions( session.refresh(setup_refresh_job_run) assert setup_refresh_job_run.status == JobStatus.FAILED assert setup_refresh_job_run.error_message == "Test exception during published variants view refresh" - assert result["exception_details"]["message"] == "Test exception during published variants view refresh" + assert result["status"] == "exception" + assert isinstance(result["exception"], Exception) async def test_refresh_published_variants_view_requires_params( self, setup_refresh_job_run, standalone_worker_context, session @@ -257,7 +259,8 @@ async def test_refresh_published_variants_view_requires_params( session.refresh(setup_refresh_job_run) assert setup_refresh_job_run.status == JobStatus.FAILED assert "Job has no job_params defined" in setup_refresh_job_run.error_message - assert "Job has no job_params defined" in result["exception_details"]["message"] + assert result["status"] == "exception" + assert isinstance(result["exception"], Exception) @pytest.mark.asyncio diff --git a/tests/worker/jobs/external_services/test_clingen.py b/tests/worker/jobs/external_services/test_clingen.py index 1b042a76..aaa813ed 100644 --- a/tests/worker/jobs/external_services/test_clingen.py +++ b/tests/worker/jobs/external_services/test_clingen.py @@ -4,6 +4,7 @@ import pytest from sqlalchemy import select +from mavedb.lib.exceptions import LDHSubmissionFailureError from mavedb.lib.variants import get_hgvs_from_post_mapped from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus from mavedb.models.mapped_variant import MappedVariant @@ -44,7 +45,7 @@ async def test_submit_score_set_mappings_to_car_submission_disabled( ) mock_update_progress.assert_called_with(100, 100, "ClinGen submission is disabled. Skipping CAR submission.") - assert result["status"] == "ok" + assert result["status"] == "skipped" # Verify no variants have CAIDs assigned variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() @@ -88,9 +89,8 @@ async def test_submit_score_set_mappings_to_car_submission_endpoint_not_set( patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", ""), patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), patch.object(JobManager, "update_progress", return_value=None) as mock_update_progress, - pytest.raises(ValueError), ): - await submit_score_set_mappings_to_car( + result = await submit_score_set_mappings_to_car( mock_worker_ctx, submit_score_set_mappings_to_car_sample_job_run.id, JobManager(session, mock_worker_ctx["redis"], submit_score_set_mappings_to_car_sample_job_run.id), @@ -99,6 +99,8 @@ async def test_submit_score_set_mappings_to_car_submission_endpoint_not_set( mock_update_progress.assert_called_with( 100, 100, "CAR submission endpoint not configured. Can't complete submission." ) + assert result["status"] == "failed" + assert isinstance(result["exception"], ValueError) # Verify no variants have CAIDs assigned variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() @@ -716,7 +718,7 @@ async def test_submit_score_set_mappings_to_car_submission_disabled( standalone_worker_context, submit_score_set_mappings_to_car_sample_job_run.id ) - assert result["status"] == "ok" + assert result["status"] == "skipped" # Verify no variants have CAIDs assigned variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() @@ -728,7 +730,7 @@ async def test_submit_score_set_mappings_to_car_submission_disabled( # Verify the job status is updated in the database session.refresh(submit_score_set_mappings_to_car_sample_job_run) - assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.SUCCEEDED + assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.SKIPPED async def test_submit_score_set_mappings_to_car_no_submission_endpoint( self, @@ -753,9 +755,7 @@ async def test_submit_score_set_mappings_to_car_no_submission_endpoint( ) assert result["status"] == "failed" - assert ( - result["exception_details"]["message"] == "ClinGen Allele Registry submission endpoint is not configured." - ) + assert isinstance(result["exception"], ValueError) # Verify no variants have CAIDs assigned variants = session.scalars(select(MappedVariant).where(MappedVariant.clingen_allele_id.isnot(None))).all() @@ -947,8 +947,9 @@ async def test_submit_score_set_mappings_to_car_propagates_exception_to_decorato standalone_worker_context, submit_score_set_mappings_to_car_sample_job_run.id ) - assert result["status"] == "failed" - assert result["exception_details"]["message"] == "ClinGen service error" + assert result["status"] == "exception" + assert isinstance(result["exception"], Exception) + assert str(result["exception"]) == "ClinGen service error" # Verify the job status is updated in the database session.refresh(submit_score_set_mappings_to_car_sample_job_run) @@ -1298,7 +1299,7 @@ async def dummy_submission_failure(*args, **kwargs): ) assert result["status"] == "failed" - assert "All LDH submissions failed for score set" in result["exception_details"]["message"] + assert isinstance(result["exception"], LDHSubmissionFailureError) mock_update_progress.assert_called_with(100, 100, "All mapped variant submissions to LDH failed.") async def test_submit_score_set_mappings_to_ldh_hgvs_not_found( @@ -1700,8 +1701,9 @@ async def test_submit_score_set_mappings_to_ldh_propagates_exception_to_decorato standalone_worker_context, submit_score_set_mappings_to_ldh_sample_job_run.id ) - assert result["status"] == "failed" - assert result["exception_details"]["message"] == "LDH service error" + assert result["status"] == "exception" + assert isinstance(result["exception"], Exception) + assert str(result["exception"]) == "LDH service error" # Verify the job status is updated in the database session.refresh(submit_score_set_mappings_to_ldh_sample_job_run) @@ -1847,7 +1849,7 @@ async def dummy_submission_failure(*args, **kwargs): ) assert result["status"] == "failed" - assert "All LDH submissions failed for score set" in result["exception_details"]["message"] + assert isinstance(result["exception"], LDHSubmissionFailureError) # Verify annotation statuses were created with failures annotation_statuses = session.scalars( @@ -1860,7 +1862,7 @@ async def dummy_submission_failure(*args, **kwargs): # Verify the job status is updated in the database # TODO:XXX: Change status to 'failed' once decorator supports it session.refresh(submit_score_set_mappings_to_ldh_sample_job_run) - assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.SUCCEEDED + assert submit_score_set_mappings_to_ldh_sample_job_run.status == JobStatus.FAILED async def test_submit_score_set_mappings_to_ldh_partial_submission( self, diff --git a/tests/worker/jobs/external_services/test_gnomad.py b/tests/worker/jobs/external_services/test_gnomad.py index 17fb3ec1..eac1086a 100644 --- a/tests/worker/jobs/external_services/test_gnomad.py +++ b/tests/worker/jobs/external_services/test_gnomad.py @@ -347,8 +347,8 @@ async def test_link_gnomad_variants_exceptions_handled_by_decorators( sample_link_gnomad_variants_run.id, ) - assert result["status"] == "failed" - assert "Test exception" in result["exception_details"]["message"] + assert result["status"] == "exception" + assert isinstance(result["exception"], Exception) # Verify job status updates session.refresh(sample_link_gnomad_variants_run) diff --git a/tests/worker/jobs/external_services/test_uniprot.py b/tests/worker/jobs/external_services/test_uniprot.py index 3a543544..a12534d2 100644 --- a/tests/worker/jobs/external_services/test_uniprot.py +++ b/tests/worker/jobs/external_services/test_uniprot.py @@ -241,9 +241,8 @@ async def test_submit_uniprot_mapping_jobs_raises_dependent_job_not_available( return_value="job_12345", ), patch.object(JobManager, "update_progress") as mock_update_progress, - pytest.raises(UniProtPollingEnqueueError), ): - await submit_uniprot_mapping_jobs_for_score_set( + result = await submit_uniprot_mapping_jobs_for_score_set( mock_worker_ctx, 1, JobManager( @@ -254,6 +253,8 @@ async def test_submit_uniprot_mapping_jobs_raises_dependent_job_not_available( ) mock_update_progress.assert_called_with(100, 100, "Failed to submit UniProt mapping jobs.") + assert result["status"] == "failed" + assert isinstance(result["exception"], UniProtPollingEnqueueError) # Verify that the job metadata contains the submitted jobs (which were submitted before the error) session.refresh(sample_submit_uniprot_mapping_jobs_run) @@ -673,8 +674,8 @@ async def test_submit_uniprot_mapping_jobs_propagates_exceptions( mock_worker_ctx, sample_submit_uniprot_mapping_jobs_run.id ) - assert result["status"] == "failed" - assert "UniProt API failure" in result["exception_details"]["message"] + assert result["status"] == "exception" + assert isinstance(result["exception"], Exception) # Verify that the job metadata contains no submitted jobs session.refresh(sample_submit_uniprot_mapping_jobs_run) @@ -814,10 +815,7 @@ async def test_submit_uniprot_mapping_jobs_no_dependent_job_raises( ) assert result["status"] == "failed" - assert ( - "Could not find unique dependent polling job for UniProt mapping job" - in result["exception_details"]["message"] - ) + assert isinstance(result["exception"], UniProtPollingEnqueueError) # Verify that the job metadata contains the job we submitted before the error session.refresh(sample_submit_uniprot_mapping_jobs_run) @@ -828,7 +826,7 @@ async def test_submit_uniprot_mapping_jobs_no_dependent_job_raises( # Verify that the submission job failed session.refresh(sample_submit_uniprot_mapping_jobs_run) # TODO#XXX: Should be failed when supported by decorator - assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.SUCCEEDED + assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.FAILED # nothing to verify for dependent polling job since it does not exist @@ -1691,8 +1689,8 @@ async def test_poll_uniprot_mapping_jobs_no_results( mock_worker_ctx, sample_polling_job_for_submission_run.id ) - assert result["status"] == "failed" - assert result["exception_details"]["type"] == "UniprotMappingResultNotFoundError" + assert result["status"] == "exception" + assert isinstance(result["exception"], UniprotMappingResultNotFoundError) # Verify the target gene uniprot id remains unchanged session.refresh(sample_score_set) @@ -1748,8 +1746,8 @@ async def test_poll_uniprot_mapping_jobs_ambiguous_results( mock_worker_ctx, sample_polling_job_for_submission_run.id ) - assert result["status"] == "failed" - assert result["exception_details"]["type"] == "UniprotAmbiguousMappingResultError" + assert result["status"] == "exception" + assert isinstance(result["exception"], UniprotAmbiguousMappingResultError) # Verify the target gene uniprot id remains unchanged session.refresh(sample_score_set) @@ -1788,8 +1786,8 @@ async def test_poll_uniprot_mapping_jobs_nonexistent_target( mock_worker_ctx, sample_polling_job_for_submission_run.id ) - assert result["status"] == "failed" - assert result["exception_details"]["type"] == "NonExistentTargetGeneError" + assert result["status"] == "exception" + assert isinstance(result["exception"], NonExistentTargetGeneError) # Verify the target gene uniprot id remains unchanged session.refresh(sample_score_set) @@ -1822,8 +1820,8 @@ async def test_poll_uniprot_mapping_jobs_propagates_exceptions_to_decorator( mock_worker_ctx, sample_polling_job_for_submission_run.id ) - assert result["status"] == "failed" - assert result["exception_details"]["message"] == "UniProt API failure" + assert result["status"] == "exception" + assert isinstance(result["exception"], Exception) # Verify the target gene uniprot id remains unchanged session.refresh(sample_score_set) diff --git a/tests/worker/jobs/pipeline_management/test_start_pipeline.py b/tests/worker/jobs/pipeline_management/test_start_pipeline.py index 9f70d9f1..5f2d88ac 100644 --- a/tests/worker/jobs/pipeline_management/test_start_pipeline.py +++ b/tests/worker/jobs/pipeline_management/test_start_pipeline.py @@ -3,6 +3,7 @@ import pytest from sqlalchemy import select +from mavedb.lib.exceptions import PipelineNotFoundError from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus from mavedb.models.job_run import JobRun from mavedb.worker.jobs.pipeline_management.start_pipeline import start_pipeline @@ -42,12 +43,14 @@ async def test_start_pipeline_raises_exception_when_no_pipeline_associated_with_ setup_start_pipeline_job_run.pipeline_id = None session.commit() - with pytest.raises(ValueError, match="No pipeline associated with job"): - await start_pipeline( - mock_worker_ctx, - setup_start_pipeline_job_run.id, - JobManager(session, mock_worker_ctx["redis"], setup_start_pipeline_job_run.id), - ) + result = await start_pipeline( + mock_worker_ctx, + setup_start_pipeline_job_run.id, + JobManager(session, mock_worker_ctx["redis"], setup_start_pipeline_job_run.id), + ) + + assert result["status"] == "exception" + assert isinstance(result["exception"], PipelineNotFoundError) async def test_start_pipeline_starts_pipeline_successfully( self, @@ -153,7 +156,7 @@ async def test_start_pipeline_on_job_without_pipeline_fails( session.commit() result = await start_pipeline(mock_worker_ctx, sample_dummy_pipeline_start.id) - assert result["status"] == "failed" + assert result["status"] == "exception" # Verify the start job run status session.refresh(sample_dummy_pipeline_start) @@ -204,7 +207,7 @@ async def custom_side_effect(*args, **kwargs): side_effect=custom_side_effect, ): result = await start_pipeline(mock_worker_ctx, sample_dummy_pipeline_start.id) - assert result["status"] == "failed" + assert result["status"] == "exception" # Verify the start job run status session.refresh(sample_dummy_pipeline_start) diff --git a/tests/worker/jobs/variant_processing/test_creation.py b/tests/worker/jobs/variant_processing/test_creation.py index 5b93e15a..dadb74db 100644 --- a/tests/worker/jobs/variant_processing/test_creation.py +++ b/tests/worker/jobs/variant_processing/test_creation.py @@ -108,8 +108,8 @@ async def test_create_variants_for_score_set_s3_file_not_found( ) mock_update_progress.assert_any_call(100, 100, "Variant creation job failed due to an internal error.") - assert result["status"] == "failed" - assert "The specified key does not exist." in result["exception_details"]["message"] + assert result["status"] == "exception" + assert isinstance(result["exception"], Exception) session.refresh(sample_score_set) assert sample_score_set.processing_state == ProcessingState.failed assert sample_score_set.mapping_state == MappingState.not_attempted @@ -194,8 +194,8 @@ async def test_create_variants_for_score_set_raises_when_no_targets_exist( ) mock_update_progress.assert_any_call(100, 100, "Score set has no targets; cannot create variants.") - assert result["status"] == "failed" - assert "Can't create variants when score set has no targets." in result["exception_details"]["message"] + assert result["status"] == "exception" + assert isinstance(result["exception"], ValueError) async def test_create_variants_for_score_set_calls_validate_standardize_dataframe_with_correct_parameters( self, @@ -563,8 +563,8 @@ async def test_create_variants_for_score_set_retains_existing_variants_when_exce JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), ) - assert result["status"] == "failed" - assert "Test exception during data validation" in result["exception_details"]["message"] + assert result["status"] == "exception" + assert isinstance(result["exception"], Exception) # Verify that existing variants are still present remaining_variants = session.query(Variant).filter(Variant.score_set_id == sample_score_set.id).all() @@ -604,8 +604,8 @@ async def test_create_variants_for_score_set_handles_exception_and_updates_state JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_creation_run.id), ) - assert result["status"] == "failed" - assert "Test exception during data validation" in result["exception_details"]["message"] + assert result["status"] == "exception" + assert isinstance(result["exception"], Exception) # Verify that the score set's processing state is updated to failed session.refresh(sample_score_set) @@ -960,7 +960,7 @@ async def test_create_variants_for_score_set_validation_error_during_creation( .one() ) assert job_run.progress_current == 100 - assert job_run.status == JobStatus.SUCCEEDED + assert job_run.status == JobStatus.FAILED async def test_create_variants_for_score_set_generic_exception_handling_during_creation( self, @@ -1002,7 +1002,7 @@ async def test_create_variants_for_score_set_generic_exception_handling_during_c .one() ) assert job_run.progress_current == 100 - assert job_run.status == JobStatus.SUCCEEDED + assert job_run.status == JobStatus.FAILED async def test_create_variants_for_score_set_generic_exception_handling_during_replacement( self, @@ -1065,7 +1065,7 @@ async def test_create_variants_for_score_set_generic_exception_handling_during_r .one() ) assert job_run.progress_current == 100 - assert job_run.status == JobStatus.SUCCEEDED + assert job_run.status == JobStatus.FAILED ## Pipeline failure workflow @@ -1110,11 +1110,11 @@ async def test_create_variants_for_score_set_pipeline_job_generic_exception_hand .one() ) assert job_run.progress_current == 100 - assert job_run.status == JobStatus.SUCCEEDED + assert job_run.status == JobStatus.FAILED # Verify that pipeline status is updated. session.refresh(sample_variant_creation_pipeline) - assert sample_variant_creation_pipeline.status == PipelineStatus.RUNNING + assert sample_variant_creation_pipeline.status == PipelineStatus.FAILED # Verify other pipeline runs are marked as failed other_runs = ( session.query(Pipeline) @@ -1125,7 +1125,7 @@ async def test_create_variants_for_score_set_pipeline_job_generic_exception_hand .all() ) for run in other_runs: - assert run.status == JobStatus.PENDING + assert run.status == JobStatus.SKIPPED @pytest.mark.asyncio @@ -1319,7 +1319,7 @@ async def test_create_variants_for_score_set_with_arq_context_generic_exception_ .one() ) assert job_run.progress_current == 100 - assert job_run.status == JobStatus.SUCCEEDED + assert job_run.status == JobStatus.FAILED async def test_create_variants_for_score_set_with_arq_context_generic_exception_handling_pipeline_ctx( self, @@ -1365,11 +1365,11 @@ async def test_create_variants_for_score_set_with_arq_context_generic_exception_ .one() ) assert job_run.progress_current == 100 - assert job_run.status == JobStatus.SUCCEEDED + assert job_run.status == JobStatus.FAILED # Verify that pipeline status is updated. session.refresh(sample_variant_creation_pipeline) - assert sample_variant_creation_pipeline.status == PipelineStatus.RUNNING + assert sample_variant_creation_pipeline.status == PipelineStatus.FAILED # Verify other pipeline runs are marked as cancelled other_runs = ( @@ -1381,4 +1381,4 @@ async def test_create_variants_for_score_set_with_arq_context_generic_exception_ .all() ) for run in other_runs: - assert run.status == JobStatus.PENDING + assert run.status == JobStatus.SKIPPED diff --git a/tests/worker/jobs/variant_processing/test_mapping.py b/tests/worker/jobs/variant_processing/test_mapping.py index a7cc1412..79e763f0 100644 --- a/tests/worker/jobs/variant_processing/test_mapping.py +++ b/tests/worker/jobs/variant_processing/test_mapping.py @@ -5,6 +5,7 @@ from sqlalchemy.exc import NoResultFound from mavedb.lib.exceptions import ( + NoMappedVariantsError, NonexistentMappingReferenceError, NonexistentMappingResultsError, NonexistentMappingScoresError, @@ -46,15 +47,17 @@ async def test_map_variants_for_score_set_no_mapping_results( with ( patch.object(_UnixSelectorEventLoop, "run_in_executor", return_value=self.dummy_mapping_output({})), patch.object(JobManager, "update_progress") as mock_update_progress, - pytest.raises(NonexistentMappingResultsError), ): - await map_variants_for_score_set( + result = await map_variants_for_score_set( mock_worker_ctx, sample_independent_variant_mapping_run.id, JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), ) mock_update_progress.assert_any_call(100, 100, "Variant mapping failed due to missing results.") + assert result["status"] == "exception" + assert result["data"] == {} + assert isinstance(result["exception"], NonexistentMappingResultsError) assert sample_score_set.mapping_state == MappingState.failed assert sample_score_set.mapping_errors is not None @@ -93,15 +96,17 @@ async def test_map_variants_for_score_set_no_mapped_scores( ), ), patch.object(JobManager, "update_progress") as mock_update_progress, - pytest.raises(NonexistentMappingScoresError), ): - await map_variants_for_score_set( + result = await map_variants_for_score_set( mock_worker_ctx, sample_independent_variant_mapping_run.id, JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), ) mock_update_progress.assert_any_call(100, 100, "Variant mapping failed; no variants were mapped.") + assert result["status"] == "exception" + assert result["data"] == {} + assert isinstance(result["exception"], NonexistentMappingScoresError) assert sample_score_set.mapping_state == MappingState.failed assert sample_score_set.mapping_errors is not None @@ -137,15 +142,17 @@ async def test_map_variants_for_score_set_no_reference_data( ), ), patch.object(JobManager, "update_progress") as mock_update_progress, - pytest.raises(NonexistentMappingReferenceError), ): - await map_variants_for_score_set( + result = await map_variants_for_score_set( mock_worker_ctx, sample_independent_variant_mapping_run.id, JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), ) mock_update_progress.assert_any_call(100, 100, "Variant mapping failed due to missing reference metadata.") + assert result["status"] == "exception" + assert result["data"] == {} + assert isinstance(result["exception"], NonexistentMappingReferenceError) assert sample_score_set.mapping_state == MappingState.failed assert sample_score_set.mapping_errors is not None @@ -184,15 +191,17 @@ async def test_map_variants_for_score_set_nonexistent_target_gene( ), ), patch.object(JobManager, "update_progress") as mock_update_progress, - pytest.raises(ValueError), ): - await map_variants_for_score_set( + result = await map_variants_for_score_set( mock_worker_ctx, sample_independent_variant_mapping_run.id, JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), ) mock_update_progress.assert_any_call(100, 100, "Variant mapping failed due to an unexpected error.") + assert result["status"] == "exception" + assert result["data"] == {} + assert isinstance(result["exception"], ValueError) assert sample_score_set.mapping_state == MappingState.failed assert sample_score_set.mapping_errors is not None @@ -234,15 +243,17 @@ async def test_map_variants_for_score_set_returns_variants_not_in_score_set( return_value=self.dummy_mapping_output(mapping_output), ), patch.object(JobManager, "update_progress") as mock_update_progress, - pytest.raises(NoResultFound), ): - await map_variants_for_score_set( + result = await map_variants_for_score_set( mock_worker_ctx, sample_independent_variant_mapping_run.id, JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), ) mock_update_progress.assert_any_call(100, 100, "Variant mapping failed due to an unexpected error.") + assert result["status"] == "exception" + assert result["data"] == {} + assert isinstance(result["exception"], NoResultFound) assert sample_score_set.mapping_state == MappingState.failed assert sample_score_set.mapping_errors is not None @@ -307,7 +318,7 @@ async def dummy_mapping_job(): assert result["status"] == "ok" assert result["data"] == {} - assert result["exception_details"] is None + assert result["exception"] is None assert sample_score_set.mapping_state == MappingState.complete assert sample_score_set.mapping_errors is None @@ -391,7 +402,7 @@ async def dummy_mapping_job(): assert result["status"] == "ok" assert result["data"] == {} - assert result["exception_details"] is None + assert result["exception"] is None assert sample_score_set.mapping_state == MappingState.complete assert sample_score_set.mapping_errors is None @@ -494,9 +505,9 @@ async def dummy_mapping_job(): JobManager(session, mock_worker_ctx["redis"], sample_independent_variant_mapping_run.id), ) - assert result["status"] == "error" + assert result["status"] == "failed" assert result["data"] == {} - assert result["exception_details"] is None + assert isinstance(result["exception"], NoMappedVariantsError) assert sample_score_set.mapping_state == MappingState.failed assert sample_score_set.mapping_errors["error_message"] == "All variants failed to map." @@ -578,7 +589,7 @@ async def dummy_mapping_job(): assert result["status"] == "ok" assert result["data"] == {} - assert result["exception_details"] is None + assert result["exception"] is None assert sample_score_set.mapping_state == MappingState.incomplete assert sample_score_set.mapping_errors is None @@ -675,7 +686,7 @@ async def dummy_mapping_job(): assert result["status"] == "ok" assert result["data"] == {} - assert result["exception_details"] is None + assert result["exception"] is None assert sample_score_set.mapping_state == MappingState.complete assert sample_score_set.mapping_errors is None @@ -763,7 +774,7 @@ async def dummy_mapping_job(): assert result["status"] == "ok" assert result["data"] == {} - assert result["exception_details"] is None + assert result["exception"] is None assert sample_score_set.mapping_state == MappingState.complete assert sample_score_set.mapping_errors is None @@ -855,7 +866,7 @@ async def dummy_mapping_job(): assert result["status"] == "ok" assert result["data"] == {} - assert result["exception_details"] is None + assert result["exception"] is None assert sample_score_set.mapping_state == MappingState.complete assert sample_score_set.mapping_errors is None @@ -928,7 +939,7 @@ async def dummy_mapping_job(): assert result["status"] == "ok" assert result["data"] == {} - assert result["exception_details"] is None + assert result["exception"] is None # Verify that mapped variants were created mapped_variants = session.query(MappedVariant).all() @@ -1020,7 +1031,7 @@ async def dummy_mapping_job(): assert result["status"] == "ok" assert result["data"] == {} - assert result["exception_details"] is None + assert result["exception"] is None # Verify that mapped variants were created mapped_variants = session.query(MappedVariant).all() @@ -1110,8 +1121,8 @@ async def dummy_mapping_job(): sample_independent_variant_mapping_run.id, ) - assert result["status"] == "error" - assert result["exception_details"]["type"] == "NonexistentMappingResultsError" + assert result["status"] == "exception" + assert isinstance(result["exception"], NonexistentMappingResultsError) assert result["data"] == {} assert sample_score_set.mapping_state == MappingState.failed @@ -1135,7 +1146,7 @@ async def dummy_mapping_job(): .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) .one() ) - assert processing_run.status == JobStatus.SUCCEEDED + assert processing_run.status == JobStatus.FAILED async def test_map_variants_for_score_set_no_mapped_scores( self, @@ -1188,8 +1199,8 @@ async def dummy_mapping_job(): sample_independent_variant_mapping_run.id, ) - assert result["status"] == "error" - assert result["exception_details"]["type"] == "NonexistentMappingScoresError" + assert result["status"] == "exception" + assert isinstance(result["exception"], NonexistentMappingScoresError) assert result["data"] == {} assert sample_score_set.mapping_state == MappingState.failed @@ -1211,7 +1222,7 @@ async def dummy_mapping_job(): .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) .one() ) - assert processing_run.status == JobStatus.SUCCEEDED + assert processing_run.status == JobStatus.FAILED async def test_map_variants_for_score_set_no_reference_data( self, @@ -1264,8 +1275,8 @@ async def dummy_mapping_job(): sample_independent_variant_mapping_run.id, ) - assert result["status"] == "error" - assert result["exception_details"]["type"] == "NonexistentMappingReferenceError" + assert result["status"] == "exception" + assert isinstance(result["exception"], NonexistentMappingReferenceError) assert result["data"] == {} assert sample_score_set.mapping_state == MappingState.failed @@ -1286,7 +1297,7 @@ async def dummy_mapping_job(): .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) .one() ) - assert processing_run.status == JobStatus.SUCCEEDED + assert processing_run.status == JobStatus.FAILED async def test_map_variants_for_score_set_updates_current_mapped_variants( self, @@ -1357,7 +1368,7 @@ async def dummy_mapping_job(): assert result["status"] == "ok" assert result["data"] == {} - assert result["exception_details"] is None + assert result["exception"] is None assert sample_score_set.mapping_state == MappingState.complete assert sample_score_set.mapping_errors is None @@ -1447,10 +1458,9 @@ async def dummy_mapping_job(): sample_independent_variant_mapping_run.id, ) - assert result["status"] == "error" + assert result["status"] == "exception" assert result["data"] == {} - assert result["exception_details"] is not None - assert result["exception_details"]["type"] == "NonexistentMappingScoresError" + assert isinstance(result["exception"], NonexistentMappingScoresError) assert sample_score_set.mapping_state == MappingState.failed assert sample_score_set.mapping_errors is not None @@ -1470,7 +1480,7 @@ async def dummy_mapping_job(): .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) .one() ) - assert processing_run.status == JobStatus.SUCCEEDED + assert processing_run.status == JobStatus.FAILED async def test_map_variants_for_score_set_exception_in_mapping( self, @@ -1499,11 +1509,11 @@ async def dummy_mapping_job(): sample_independent_variant_mapping_run.id, ) - assert result["status"] == "error" + assert result["status"] == "exception" assert result["data"] == {} - assert result["exception_details"]["type"] == "ValueError" + assert isinstance(result["exception"], ValueError) # exception messages are persisted in internal properties - assert "test exception during mapping" in result["exception_details"]["message"] + assert "test exception during mapping" in str(result["exception"]) assert sample_score_set.mapping_state == MappingState.failed assert sample_score_set.mapping_errors is not None @@ -1527,7 +1537,7 @@ async def dummy_mapping_job(): .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) .one() ) - assert processing_run.status == JobStatus.SUCCEEDED + assert processing_run.status == JobStatus.FAILED @pytest.mark.integration @@ -1767,7 +1777,7 @@ async def dummy_mapping_job(): .filter(sample_independent_variant_mapping_run.__class__.id == sample_independent_variant_mapping_run.id) .one() ) - assert processing_run.status == JobStatus.SUCCEEDED + assert processing_run.status == JobStatus.FAILED async def test_map_variants_for_score_set_with_arq_context_generic_exception_in_pipeline_ctx( self, @@ -1819,7 +1829,7 @@ async def dummy_mapping_job(): .filter(sample_pipeline_variant_mapping_run.__class__.id == sample_pipeline_variant_mapping_run.id) .one() ) - assert processing_run.status == JobStatus.SUCCEEDED + assert processing_run.status == JobStatus.FAILED # Verify that the pipeline run status was updated to FAILED. pipeline_run = ( @@ -1830,9 +1840,9 @@ async def dummy_mapping_job(): ) .one() ) - assert pipeline_run.status == PipelineStatus.RUNNING + assert pipeline_run.status == PipelineStatus.FAILED # Verify that other jobs in the pipeline were skipped for job_run in pipeline_run.job_runs: if job_run.id != sample_pipeline_variant_mapping_run.id: - assert job_run.status == JobStatus.QUEUED + assert job_run.status == JobStatus.SKIPPED diff --git a/tests/worker/lib/decorators/test_job_management.py b/tests/worker/lib/decorators/test_job_management.py index 2462b4b6..aa80fc6e 100644 --- a/tests/worker/lib/decorators/test_job_management.py +++ b/tests/worker/lib/decorators/test_job_management.py @@ -91,6 +91,51 @@ async def test_decorator_calls_start_job_and_succeed_job_when_wrapped_function_s mock_start_job.assert_called_once() mock_succeed_job.assert_called_once() + @pytest.mark.parametrize( + "status", + [ + "failed", + "exception", + ], + ) + async def test_decorator_calls_start_job_and_fail_job_when_wrapped_function_returns_failed_status( + self, session, mock_worker_ctx, mock_job_manager, status + ): + @with_job_management + async def sample_fail(ctx: dict, job_id: int, job_manager: JobManager): + return {"status": status, "data": {}, "exception": RuntimeError("simulated failure")} + + with ( + patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, + patch.object(mock_job_manager, "start_job", return_value=None) as mock_start_job, + patch.object(mock_job_manager, "fail_job", return_value=None) as mock_fail_job, + TransactionSpy.spy(session, expect_commit=True), + ): + mock_job_manager_class.return_value = mock_job_manager + await sample_fail(mock_worker_ctx, 999) + + mock_start_job.assert_called_once() + mock_fail_job.assert_called_once() + + async def test_decorator_calls_start_job_and_skip_job_when_wrapped_function_returns_skipped_status( + self, session, mock_worker_ctx, mock_job_manager + ): + @with_job_management + async def sample_skip(ctx: dict, job_id: int, job_manager: JobManager): + return {"status": "skipped", "data": {}, "exception": None} + + with ( + patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, + patch.object(mock_job_manager, "start_job", return_value=None) as mock_start_job, + patch.object(mock_job_manager, "skip_job", return_value=None) as mock_skip_job, + TransactionSpy.spy(session, expect_commit=True), + ): + mock_job_manager_class.return_value = mock_job_manager + await sample_skip(mock_worker_ctx, 999) + + mock_start_job.assert_called_once() + mock_skip_job.assert_called_once() + async def test_decorator_calls_start_job_and_fail_job_when_wrapped_function_raises_and_no_retry( self, session, mock_worker_ctx, mock_job_manager ): @@ -138,9 +183,10 @@ async def test_decorator_raises_value_error_if_required_context_missing( async def test_decorator_swallows_exception_from_lifecycle_state_outside_except( self, session, mock_job_manager, mock_worker_ctx ): + raised_exc = JobStateError("error in job start") with ( patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, - patch.object(mock_job_manager, "start_job", side_effect=JobStateError("error in job start")), + patch.object(mock_job_manager, "start_job", side_effect=raised_exc), patch.object(mock_job_manager, "should_retry", return_value=False), patch.object(mock_job_manager, "fail_job", return_value=None), TransactionSpy.spy(session, expect_rollback=True, expect_commit=True), @@ -148,7 +194,8 @@ async def test_decorator_swallows_exception_from_lifecycle_state_outside_except( mock_job_manager_class.return_value = mock_job_manager result = await sample_job(mock_worker_ctx, 999) - assert "error in job start" in result["exception_details"]["message"] + assert result["status"] == "exception" + assert raised_exc == result["exception"] async def test_decorator_raises_value_error_if_job_id_missing(self, session, mock_job_manager, mock_worker_ctx): # Remove job_id from args to simulate missing job_id @@ -171,13 +218,14 @@ async def test_decorator_swallows_exception_from_wrapped_function_inside_except( result = await sample_raise(mock_worker_ctx, 999) # Errors within the main try block should take precedence - assert "error in wrapped function" in result["exception_details"]["message"] + assert result["status"] == "exception" + assert str(result["exception"]) == "error in wrapped function" async def test_decorator_passes_job_manager_to_wrapped(self, session, mock_job_manager, mock_worker_ctx): @with_job_management async def assert_manager_passed_job(ctx, job_id: int, job_manager): assert isinstance(job_manager, JobManager) - return True + return {"status": "ok", "data": {}, "exception": None} with ( patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, @@ -203,7 +251,7 @@ async def test_decorator_integrated_job_lifecycle_success( @with_job_management async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): await event.wait() # Simulate async work, block until test signals - return {"status": "ok"} + return {"status": "ok", "data": {}, "exception": None} # Start the job (it will block at event.wait()) job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id)) @@ -221,7 +269,36 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() assert job.status == JobStatus.SUCCEEDED - async def test_decorator_integrated_job_lifecycle_failure( + async def test_decorator_integrated_job_lifecycle_skipped( + self, session, arq_redis, sample_job_run, standalone_worker_context, with_populated_job_data + ): + @with_job_management + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + return {"status": "skipped", "data": {}, "exception": None} + + # Run the job + await sample_job(standalone_worker_context, sample_job_run.id) + + # After completion, status should be SKIPPED + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.SKIPPED + + async def test_decorator_integrated_job_lifecycle_failed( + self, session, arq_redis, sample_job_run, standalone_worker_context, with_populated_job_data + ): + @with_job_management + async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): + return {"status": "failed", "data": {}, "exception": RuntimeError("Simulated job failure")} + + # Run the job + await sample_job(standalone_worker_context, sample_job_run.id) + + # After completion, status should be FAILED + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.FAILED + assert job.error_message == "Simulated job failure" + + async def test_decorator_integrated_job_lifecycle_raised_exception( self, session, arq_redis, sample_job_run, standalone_worker_context, with_populated_job_data ): # Use an event to control when the job completes diff --git a/tests/worker/lib/decorators/test_pipeline_management.py b/tests/worker/lib/decorators/test_pipeline_management.py index 721bb0c8..dcd5862c 100644 --- a/tests/worker/lib/decorators/test_pipeline_management.py +++ b/tests/worker/lib/decorators/test_pipeline_management.py @@ -301,12 +301,12 @@ async def test_decorator_integrated_pipeline_lifecycle_success( @with_pipeline_management async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): await event.wait() # Simulate async work, block until test signals - return {"status": "ok"} + return {"status": "ok", "data": {}, "exception": None} @with_pipeline_management async def sample_dependent_job(ctx: dict, job_id: int, job_manager: JobManager): await dep_event.wait() # Simulate async work, block until test signals - return {"status": "ok"} + return {"status": "ok", "data": {}, "exception": None} # Start the job (it will block at event.wait()) job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id)) @@ -392,12 +392,12 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): @with_pipeline_management async def sample_retried_job(ctx: dict, job_id: int, job_manager: JobManager): await retry_event.wait() # Simulate async work, block until test signals - return {"status": "ok"} + return {"status": "ok", "data": {}, "exception": None} @with_pipeline_management async def sample_dependent_job(ctx: dict, job_id: int, job_manager: JobManager): await dep_event.wait() # Simulate async work, block until test signals - return {"status": "ok"} + return {"status": "ok", "data": {}, "exception": None} # Start the job (it will block at event.wait()) job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id)) diff --git a/tests/worker/lib/managers/test_job_manager.py b/tests/worker/lib/managers/test_job_manager.py index 3806ac68..4b3cde68 100644 --- a/tests/worker/lib/managers/test_job_manager.py +++ b/tests/worker/lib/managers/test_job_manager.py @@ -8,6 +8,8 @@ import pytest +from mavedb.lib.logging.context import format_raised_exception_info_as_dict + pytest.importorskip("arq") import re @@ -296,12 +298,20 @@ def test_complete_job_sets_default_failure_category_when_job_failed(self, mock_j # Complete job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. with TransactionSpy.spy(mock_job_manager.db): - mock_job_manager.complete_job(status=JobStatus.FAILED, result={}) + mock_job_manager.complete_job( + status=JobStatus.FAILED, result={"status": "failed", "data": {}, "exception": Exception()} + ) # Verify job state was updated on our mock object with expected values. assert mock_job_run.status == JobStatus.FAILED assert mock_job_run.finished_at is not None - assert mock_job_run.metadata_ == {"result": {}} + assert mock_job_run.metadata_ == { + "result": { + "status": "failed", + "data": {}, + "exception_details": format_raised_exception_info_as_dict(Exception()), + } + } assert mock_job_run.progress_message == "Job failed" assert mock_job_run.error_message is None assert mock_job_run.error_traceback is None @@ -320,12 +330,20 @@ def test_complete_job_success(self, mock_job_manager, valid_status, exception, m # Complete job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. with TransactionSpy.spy(mock_job_manager.db): - mock_job_manager.complete_job(status=valid_status, result={"output": "test"}, error=exception) + mock_job_manager.complete_job( + status=valid_status, + result={"status": "ok", "data": {"output": "test"}, "exception": exception}, + error=exception, + ) # Verify job state was updated on our mock object with expected values. assert mock_job_run.status == valid_status assert mock_job_run.finished_at is not None - assert mock_job_run.metadata_["result"] == {"output": "test"} + assert mock_job_run.metadata_["result"] == { + "status": "ok", + "data": {"output": "test"}, + "exception_details": format_raised_exception_info_as_dict(exception) if exception else None, + } assert mock_job_run.progress_message is not None # If an exception was provided, verify error fields are set appropriately. @@ -383,7 +401,9 @@ def test_job_updated_successfully_without_error( # Complete job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. with TransactionSpy.spy(manager.db): - manager.complete_job(status=valid_status, result={"output": "test"}) + manager.complete_job( + status=valid_status, result={"status": "ok", "data": {"output": "test"}, "exception": None} + ) # Commit pending changes made by start job. session.flush() @@ -393,7 +413,7 @@ def test_job_updated_successfully_without_error( assert job.status == valid_status assert job.finished_at is not None - assert job.metadata_ == {"result": {"output": "test"}} + assert job.metadata_ == {"result": {"status": "ok", "data": {"output": "test"}, "exception_details": None}} assert job.error_message is None assert job.error_traceback is None @@ -416,7 +436,15 @@ def test_job_updated_successfully_with_error( # Complete job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. with TransactionSpy.spy(manager.db): - manager.complete_job(status=valid_status, result={"output": "test"}, error=ValueError("Test error")) + manager.complete_job( + status=valid_status, + result={ + "status": "ok", + "data": {"output": "test"}, + "exception": ValueError("Test error"), + }, + error=ValueError("Test error"), + ) # Commit pending changes made by start job. session.flush() @@ -426,7 +454,13 @@ def test_job_updated_successfully_with_error( assert job.status == valid_status assert job.finished_at is not None - assert job.metadata_ == {"result": {"output": "test"}} + assert job.metadata_ == { + "result": { + "status": "ok", + "data": {"output": "test"}, + "exception_details": format_raised_exception_info_as_dict(ValueError("Test error")), + } + } assert job.error_message == "Test error" assert job.error_traceback is not None assert job.failure_category == FailureCategory.UNKNOWN @@ -446,17 +480,28 @@ def test_fail_job_success(self, mock_job_manager, mock_job_run): patch.object(mock_job_manager, "complete_job", wraps=mock_job_manager.complete_job) as mock_complete_job, TransactionSpy.spy(mock_job_manager.db), ): - mock_job_manager.fail_job(error=test_exception, result={"output": "test"}) + mock_job_manager.fail_job( + error=test_exception, + result={"status": "failed", "data": {"output": "test"}, "exception": test_exception}, + ) # Verify this function is a thin wrapper around complete_job with expected parameters. mock_complete_job.assert_called_once_with( - status=JobStatus.FAILED, result={"output": "test"}, error=test_exception + status=JobStatus.FAILED, + result={"status": "failed", "data": {"output": "test"}, "exception": test_exception}, + error=test_exception, ) # Verify job state was updated on our mock object with expected values. assert mock_job_run.status == JobStatus.FAILED assert mock_job_run.finished_at is not None - assert mock_job_run.metadata_ == {"result": {"output": "test"}} + assert mock_job_run.metadata_ == { + "result": { + "status": "failed", + "data": {"output": "test"}, + "exception_details": format_raised_exception_info_as_dict(test_exception), + } + } assert mock_job_run.progress_message == "Job failed" assert mock_job_run.error_message == str(test_exception) assert mock_job_run.error_traceback is not None @@ -471,8 +516,9 @@ def test_job_updated_successfully(self, session, arq_redis, with_populated_job_d manager = JobManager(session, arq_redis, sample_job_run.id) # Fail job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. + exc = ValueError("Test error") with TransactionSpy.spy(manager.db): - manager.fail_job(result={"output": "test"}, error=ValueError("Test error")) + manager.fail_job(result={"status": "failed", "data": {}, "exception": exc}, error=exc) # Commit pending changes made by fail job. session.flush() @@ -482,7 +528,9 @@ def test_job_updated_successfully(self, session, arq_redis, with_populated_job_d assert job.status == JobStatus.FAILED assert job.finished_at is not None - assert job.metadata_ == {"result": {"output": "test"}} + assert job.metadata_ == { + "result": {"status": "failed", "data": {}, "exception_details": format_raised_exception_info_as_dict(exc)} + } assert job.progress_message == "Job failed" assert job.error_message == "Test error" assert job.error_traceback is not None @@ -501,15 +549,19 @@ def test_succeed_job_success(self, mock_job_manager, mock_job_run): patch.object(mock_job_manager, "complete_job", wraps=mock_job_manager.complete_job) as mock_complete_job, TransactionSpy.spy(mock_job_manager.db), ): - mock_job_manager.succeed_job(result={"output": "test"}) + mock_job_manager.succeed_job(result={"status": "ok", "data": {"output": "test"}, "exception": None}) # Verify this function is a thin wrapper around complete_job with expected parameters. - mock_complete_job.assert_called_once_with(status=JobStatus.SUCCEEDED, result={"output": "test"}) + mock_complete_job.assert_called_once_with( + status=JobStatus.SUCCEEDED, result={"status": "ok", "data": {"output": "test"}, "exception": None} + ) # Verify job state was updated on our mock object with expected values. assert mock_job_run.status == JobStatus.SUCCEEDED assert mock_job_run.finished_at is not None - assert mock_job_run.metadata_ == {"result": {"output": "test"}} + assert mock_job_run.metadata_ == { + "result": {"status": "ok", "data": {"output": "test"}, "exception_details": None} + } assert mock_job_run.progress_message == "Job completed successfully" assert mock_job_run.error_message is None assert mock_job_run.error_traceback is None @@ -525,7 +577,7 @@ def test_job_updated_successfully(self, session, arq_redis, with_populated_job_d # Complete job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. with TransactionSpy.spy(manager.db): - manager.succeed_job(result={"output": "test"}) + manager.succeed_job(result={"status": "ok", "data": {"output": "test"}, "exception": None}) # Commit pending changes made by start job. session.flush() @@ -536,7 +588,7 @@ def test_job_updated_successfully(self, session, arq_redis, with_populated_job_d assert job.status == JobStatus.SUCCEEDED assert job.finished_at is not None assert job.progress_message == "Job completed successfully" - assert job.metadata_ == {"result": {"output": "test"}} + assert job.metadata_ == {"result": {"status": "ok", "data": {"output": "test"}, "exception_details": None}} assert job.error_message is None assert job.error_traceback is None assert job.failure_category is None @@ -554,15 +606,19 @@ def test_cancel_job_success(self, mock_job_manager, mock_job_run): patch.object(mock_job_manager, "complete_job", wraps=mock_job_manager.complete_job) as mock_complete_job, TransactionSpy.spy(mock_job_manager.db), ): - mock_job_manager.cancel_job(result={"error": "Job was cancelled"}) + mock_job_manager.cancel_job(result={"status": "ok", "data": {"output": "test"}, "exception": None}) # Verify this function is a thin wrapper around complete_job with expected parameters. - mock_complete_job.assert_called_once_with(status=JobStatus.CANCELLED, result={"error": "Job was cancelled"}) + mock_complete_job.assert_called_once_with( + status=JobStatus.CANCELLED, result={"status": "ok", "data": {"output": "test"}, "exception": None} + ) # Verify job state was updated on our mock object with expected values. assert mock_job_run.status == JobStatus.CANCELLED assert mock_job_run.finished_at is not None - assert mock_job_run.metadata_ == {"result": {"error": "Job was cancelled"}} + assert mock_job_run.metadata_ == { + "result": {"status": "ok", "data": {"output": "test"}, "exception_details": None} + } assert mock_job_run.progress_message == "Job cancelled" assert mock_job_run.error_message is None assert mock_job_run.error_traceback is None @@ -578,7 +634,7 @@ def test_job_updated_successfully(self, session, arq_redis, with_populated_job_d # Complete job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. with TransactionSpy.spy(manager.db): - manager.cancel_job(result={"output": "test"}) + manager.cancel_job(result={"status": "ok", "data": {"output": "test"}, "exception": None}) # Commit pending changes made by start job. session.flush() @@ -589,7 +645,7 @@ def test_job_updated_successfully(self, session, arq_redis, with_populated_job_d assert job.status == JobStatus.CANCELLED assert job.progress_message == "Job cancelled" assert job.finished_at is not None - assert job.metadata_ == {"result": {"output": "test"}} + assert job.metadata_ == {"result": {"status": "ok", "data": {"output": "test"}, "exception_details": None}} assert job.error_message is None assert job.error_traceback is None assert job.failure_category is None @@ -607,15 +663,19 @@ def test_skip_job_success(self, mock_job_manager, mock_job_run): patch.object(mock_job_manager, "complete_job", wraps=mock_job_manager.complete_job) as mock_complete_job, TransactionSpy.spy(mock_job_manager.db), ): - mock_job_manager.skip_job(result={"output": "test"}) + mock_job_manager.skip_job(result={"status": "ok", "data": {"output": "test"}, "exception": None}) # Verify this function is a thin wrapper around complete_job with expected parameters. - mock_complete_job.assert_called_once_with(status=JobStatus.SKIPPED, result={"output": "test"}) + mock_complete_job.assert_called_once_with( + status=JobStatus.SKIPPED, result={"status": "ok", "data": {"output": "test"}, "exception": None} + ) # Verify job state was updated on our mock object with expected values. assert mock_job_run.status == JobStatus.SKIPPED assert mock_job_run.finished_at is not None - assert mock_job_run.metadata_ == {"result": {"output": "test"}} + assert mock_job_run.metadata_ == { + "result": {"status": "ok", "data": {"output": "test"}, "exception_details": None} + } assert mock_job_run.progress_message == "Job skipped" assert mock_job_run.error_message is None assert mock_job_run.error_traceback is None @@ -632,7 +692,7 @@ def test_job_updated_successfully(self, session, arq_redis, with_populated_job_d # Skip job. Spy on transaction to ensure nothing is flushed/rolled back/committed prematurely. with TransactionSpy.spy(manager.db): - manager.skip_job(result={"output": "test"}) + manager.skip_job(result={"status": "ok", "data": {"output": "test"}, "exception": None}) # Commit pending changes made by start job. session.flush() @@ -643,7 +703,7 @@ def test_job_updated_successfully(self, session, arq_redis, with_populated_job_d assert job.status == JobStatus.SKIPPED assert job.progress_message == "Job skipped" assert job.finished_at is not None - assert job.metadata_ == {"result": {"output": "test"}} + assert job.metadata_ == {"result": {"status": "ok", "data": {"output": "test"}, "exception_details": None}} assert job.error_message is None assert job.error_traceback is None assert job.failure_category is None @@ -1896,7 +1956,7 @@ def test_full_successful_job_lifecycle(self, session, arq_redis, with_populated_ # Complete job with TransactionSpy.spy(manager.db): - manager.succeed_job(result={"output": "success"}) + manager.succeed_job(result={"status": "ok", "data": {"output": "test"}, "exception": None}) session.flush() job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() @@ -1940,7 +2000,7 @@ def test_full_cancelled_job_lifecycle(self, session, arq_redis, with_populated_j # Cancel job with TransactionSpy.spy(manager.db): - manager.cancel_job({"reason": "User requested cancellation"}) + manager.cancel_job({"status": "ok", "data": {"reason": "User requested cancellation"}, "exception": None}) session.flush() # Verify job is cancelled @@ -1961,7 +2021,7 @@ def test_full_skipped_job_lifecycle(self, session, arq_redis, with_populated_job # Skip job with TransactionSpy.spy(manager.db): - manager.skip_job(result={"reason": "Precondition not met"}) + manager.skip_job(result={"status": "ok", "data": {"reason": "Job not needed"}, "exception": None}) session.flush() job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() @@ -1994,11 +2054,9 @@ def test_full_failed_job_lifecycle(self, session, arq_redis, with_populated_job_ assert job.status == JobStatus.RUNNING # Fail job + exc = Exception("An error occurred") with TransactionSpy.spy(manager.db): - manager.fail_job( - error=Exception("An error occurred"), - result={"details": "Traceback details here"}, - ) + manager.fail_job(error=exc, result={"status": "failed", "data": {}, "exception": exc}) session.flush() job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() @@ -2032,10 +2090,11 @@ def test_full_retried_job_lifecycle(self, session, arq_redis, with_populated_job assert job.status == JobStatus.RUNNING # Fail job + exc = Exception("Temporary error") with TransactionSpy.spy(manager.db): manager.fail_job( - error=Exception("Temporary error"), - result={"details": "Traceback details here"}, + error=exc, + result={"status": "failed", "data": {}, "exception": exc}, ) session.flush() @@ -2084,10 +2143,11 @@ def test_full_reset_job_lifecycle(self, session, arq_redis, with_populated_job_d assert job.status == JobStatus.RUNNING # Fail job + exc = Exception("Some error") with TransactionSpy.spy(manager.db): manager.fail_job( - error=Exception("Some error"), - result={"details": "Traceback details here"}, + error=exc, + result={"status": "failed", "data": {}, "exception": exc}, ) session.flush() @@ -2120,10 +2180,11 @@ def test_full_reset_job_lifecycle(self, session, arq_redis, with_populated_job_d assert job.status == JobStatus.RUNNING # Fail job again + exc = Exception("Another error") with TransactionSpy.spy(manager.db): manager.fail_job( - error=Exception("Another error"), - result={"details": "Traceback details here"}, + error=exc, + result={"status": "failed", "data": {}, "exception": exc}, ) session.flush() diff --git a/tests/worker/lib/managers/test_pipeline_manager.py b/tests/worker/lib/managers/test_pipeline_manager.py index cb7de415..4f892824 100644 --- a/tests/worker/lib/managers/test_pipeline_manager.py +++ b/tests/worker/lib/managers/test_pipeline_manager.py @@ -3387,7 +3387,7 @@ async def test_full_pipeline_lifecycle( await arq_redis.flushdb() # exit job manager decorator: set job to SUCCEEDED - job_manager.succeed_job({"output": "some result", "logs": "some logs", "metadata": {"key": "value"}}) + job_manager.succeed_job({"status": "ok", "data": {}, "exception": None}) session.commit() # exit pipeline manager decorator: enqueue newly queueable jobs or terminate pipeline @@ -3427,7 +3427,7 @@ async def test_full_pipeline_lifecycle( await arq_redis.flushdb() # exit job manager decorator: set dependent job to SUCCEEDED - job_manager.succeed_job({"output": "some result", "logs": "some logs", "metadata": {"key": "value"}}) + job_manager.succeed_job({"status": "ok", "data": {}, "exception": None}) session.commit() # exit pipeline manager decorator: enqueue newly queueable jobs or terminate pipeline @@ -3481,7 +3481,7 @@ async def test_paused_pipeline_lifecycle( await arq_redis.flushdb() # Simulate job completion - job_manager.succeed_job({"output": "some result", "logs": "some logs", "metadata": {"key": "value"}}) + job_manager.succeed_job({"status": "ok", "data": {}, "exception": None}) session.commit() # Coordinate the pipeline @@ -3524,7 +3524,7 @@ async def test_paused_pipeline_lifecycle( await arq_redis.flushdb() # Simulate dependent job completion - dependent_job_manager.succeed_job({"output": "some result", "logs": "some logs", "metadata": {"key": "value"}}) + dependent_job_manager.succeed_job({"status": "ok", "data": {}, "exception": None}) session.commit() # Coordinate the pipeline @@ -3630,9 +3630,8 @@ async def test_restart_pipeline_lifecycle( # Evict the job from redis to simulate completion. await arq_redis.flushdb() - job_manager.fail_job( - error=Exception("Simulated job failure"), result={"output": None, "logs": "some logs", "metadata": {}} - ) + exc = Exception("Simulated job failure") + job_manager.fail_job(error=exc, result={"status": "error", "data": {}, "exception": exc}) session.commit() # Coordinate the pipeline @@ -3709,9 +3708,8 @@ async def test_retry_pipeline_lifecycle( # Evict the job from redis to simulate completion. await arq_redis.flushdb() - job_manager.fail_job( - error=Exception("Simulated job failure"), result={"output": None, "logs": "some logs", "metadata": {}} - ) + exc = Exception("Simulated job failure") + job_manager.fail_job(error=exc, result={"status": "error", "data": {}, "exception": exc}) session.commit() # Coordinate the pipeline diff --git a/tests/worker/lib/managers/test_utils.py b/tests/worker/lib/managers/test_utils.py index a33285b4..fdb46e40 100644 --- a/tests/worker/lib/managers/test_utils.py +++ b/tests/worker/lib/managers/test_utils.py @@ -18,7 +18,7 @@ def test_construct_bulk_cancellation_result(self): assert result["status"] == "cancelled" assert result["data"]["reason"] == reason assert "timestamp" in result["data"] - assert result["exception_details"] is None + assert result["exception"] is None @pytest.mark.unit From 7b4434660a8c3d8e2460c5993493aa721511eb7c Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 28 Jan 2026 21:00:22 -0800 Subject: [PATCH 44/70] feat: less prescriptive status messages in complete job functions Alters the `complete_job` method to remove default updates to the progress message. This allows the job to set its final progress message, which results in generally more useful messages than the generic options we have at our disposal in the complete job method. --- src/mavedb/worker/lib/managers/job_manager.py | 9 +-------- src/mavedb/worker/lib/managers/pipeline_manager.py | 1 + tests/worker/lib/managers/test_job_manager.py | 13 ------------- 3 files changed, 2 insertions(+), 21 deletions(-) diff --git a/src/mavedb/worker/lib/managers/job_manager.py b/src/mavedb/worker/lib/managers/job_manager.py index b2269398..b02cde18 100644 --- a/src/mavedb/worker/lib/managers/job_manager.py +++ b/src/mavedb/worker/lib/managers/job_manager.py @@ -287,14 +287,7 @@ def complete_job(self, status: JobStatus, result: JobResultData, error: Optional } job_run.finished_at = datetime.now() - if status == JobStatus.SUCCEEDED: - job_run.progress_message = "Job completed successfully" - elif status == JobStatus.CANCELLED: - job_run.progress_message = "Job cancelled" - elif status == JobStatus.SKIPPED: - job_run.progress_message = "Job skipped" - elif status == JobStatus.FAILED: - job_run.progress_message = "Job failed" + if status == JobStatus.FAILED: job_run.failure_category = FailureCategory.UNKNOWN if error: diff --git a/src/mavedb/worker/lib/managers/pipeline_manager.py b/src/mavedb/worker/lib/managers/pipeline_manager.py index 0fffe94d..d5b69b80 100644 --- a/src/mavedb/worker/lib/managers/pipeline_manager.py +++ b/src/mavedb/worker/lib/managers/pipeline_manager.py @@ -388,6 +388,7 @@ async def enqueue_ready_jobs(self) -> None: should_skip, reason = self.should_skip_job_due_to_dependencies(job) if should_skip: + job_manager.update_status_message(f"Job skipped: {reason}") job_manager.skip_job( { "status": "skipped", diff --git a/tests/worker/lib/managers/test_job_manager.py b/tests/worker/lib/managers/test_job_manager.py index 4b3cde68..e9a11954 100644 --- a/tests/worker/lib/managers/test_job_manager.py +++ b/tests/worker/lib/managers/test_job_manager.py @@ -312,7 +312,6 @@ def test_complete_job_sets_default_failure_category_when_job_failed(self, mock_j "exception_details": format_raised_exception_info_as_dict(Exception()), } } - assert mock_job_run.progress_message == "Job failed" assert mock_job_run.error_message is None assert mock_job_run.error_traceback is None assert mock_job_run.failure_category == FailureCategory.UNKNOWN @@ -344,7 +343,6 @@ def test_complete_job_success(self, mock_job_manager, valid_status, exception, m "data": {"output": "test"}, "exception_details": format_raised_exception_info_as_dict(exception) if exception else None, } - assert mock_job_run.progress_message is not None # If an exception was provided, verify error fields are set appropriately. if exception: @@ -502,7 +500,6 @@ def test_fail_job_success(self, mock_job_manager, mock_job_run): "exception_details": format_raised_exception_info_as_dict(test_exception), } } - assert mock_job_run.progress_message == "Job failed" assert mock_job_run.error_message == str(test_exception) assert mock_job_run.error_traceback is not None assert mock_job_run.failure_category == FailureCategory.UNKNOWN @@ -531,7 +528,6 @@ def test_job_updated_successfully(self, session, arq_redis, with_populated_job_d assert job.metadata_ == { "result": {"status": "failed", "data": {}, "exception_details": format_raised_exception_info_as_dict(exc)} } - assert job.progress_message == "Job failed" assert job.error_message == "Test error" assert job.error_traceback is not None assert job.failure_category == FailureCategory.UNKNOWN @@ -562,7 +558,6 @@ def test_succeed_job_success(self, mock_job_manager, mock_job_run): assert mock_job_run.metadata_ == { "result": {"status": "ok", "data": {"output": "test"}, "exception_details": None} } - assert mock_job_run.progress_message == "Job completed successfully" assert mock_job_run.error_message is None assert mock_job_run.error_traceback is None assert mock_job_run.failure_category is None @@ -587,7 +582,6 @@ def test_job_updated_successfully(self, session, arq_redis, with_populated_job_d assert job.status == JobStatus.SUCCEEDED assert job.finished_at is not None - assert job.progress_message == "Job completed successfully" assert job.metadata_ == {"result": {"status": "ok", "data": {"output": "test"}, "exception_details": None}} assert job.error_message is None assert job.error_traceback is None @@ -619,7 +613,6 @@ def test_cancel_job_success(self, mock_job_manager, mock_job_run): assert mock_job_run.metadata_ == { "result": {"status": "ok", "data": {"output": "test"}, "exception_details": None} } - assert mock_job_run.progress_message == "Job cancelled" assert mock_job_run.error_message is None assert mock_job_run.error_traceback is None assert mock_job_run.failure_category is None @@ -643,7 +636,6 @@ def test_job_updated_successfully(self, session, arq_redis, with_populated_job_d job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() assert job.status == JobStatus.CANCELLED - assert job.progress_message == "Job cancelled" assert job.finished_at is not None assert job.metadata_ == {"result": {"status": "ok", "data": {"output": "test"}, "exception_details": None}} assert job.error_message is None @@ -676,7 +668,6 @@ def test_skip_job_success(self, mock_job_manager, mock_job_run): assert mock_job_run.metadata_ == { "result": {"status": "ok", "data": {"output": "test"}, "exception_details": None} } - assert mock_job_run.progress_message == "Job skipped" assert mock_job_run.error_message is None assert mock_job_run.error_traceback is None assert mock_job_run.failure_category is None @@ -701,7 +692,6 @@ def test_job_updated_successfully(self, session, arq_redis, with_populated_job_d job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() assert job.status == JobStatus.SKIPPED - assert job.progress_message == "Job skipped" assert job.finished_at is not None assert job.metadata_ == {"result": {"status": "ok", "data": {"output": "test"}, "exception_details": None}} assert job.error_message is None @@ -1972,7 +1962,6 @@ def test_full_successful_job_lifecycle(self, session, arq_redis, with_populated_ assert final_job.status == JobStatus.SUCCEEDED assert final_job.progress_current == 200 assert final_job.progress_total == 200 - assert final_job.progress_message == "Job completed successfully" def test_full_cancelled_job_lifecycle(self, session, arq_redis, with_populated_job_data, sample_job_run): """Test full job lifecycle for a cancelled job.""" @@ -2009,7 +1998,6 @@ def test_full_cancelled_job_lifecycle(self, session, arq_redis, with_populated_j job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() assert job.status == JobStatus.CANCELLED assert job.finished_at is not None - assert job.progress_message == "Job cancelled" def test_full_skipped_job_lifecycle(self, session, arq_redis, with_populated_job_data, sample_job_run): """Test full job lifecycle for a skipped job.""" @@ -2027,7 +2015,6 @@ def test_full_skipped_job_lifecycle(self, session, arq_redis, with_populated_job job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() assert job.status == JobStatus.SKIPPED assert job.finished_at is not None - assert job.progress_message == "Job skipped" def test_full_failed_job_lifecycle(self, session, arq_redis, with_populated_job_data, sample_job_run): """Test full job lifecycle for a failed job.""" From c8b0a02286ff459cbd7dccd28822c20bb6d9b68a Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 28 Jan 2026 21:02:03 -0800 Subject: [PATCH 45/70] fix: ensure exception info is always present for failed jobs in job management --- src/mavedb/worker/lib/decorators/job_management.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mavedb/worker/lib/decorators/job_management.py b/src/mavedb/worker/lib/decorators/job_management.py index 74867556..534c0336 100644 --- a/src/mavedb/worker/lib/decorators/job_management.py +++ b/src/mavedb/worker/lib/decorators/job_management.py @@ -121,7 +121,8 @@ async def _execute_managed_job(func: Callable[..., Awaitable[JobResultData]], ar # Move job to final state based on result if result.get("status") == "failed" or result.get("exception"): - job_manager.fail_job(result=result, error=result["exception"]) + # Exception info should always be present for failed jobs + job_manager.fail_job(result=result, error=result["exception"]) # type: ignore[keyword-arg] elif result.get("status") == "skipped": job_manager.skip_job(result=result) else: From cedb42dbb2fd13e6681967e75672fbccf93adf7a Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 28 Jan 2026 21:05:17 -0800 Subject: [PATCH 46/70] fix: move Athena engine fixture to optional conftest for core dependency compatibility --- tests/conftest.py | 53 +--------------------------------- tests/conftest_optional.py | 58 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 52 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index dd6ee6bd..82b43aeb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,12 +9,11 @@ import pytest import pytest_postgresql import pytest_socket -from sqlalchemy import Column, Float, Integer, MetaData, String, Table, create_engine, text +from sqlalchemy import create_engine, text from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import NullPool from mavedb.db.base import Base -from mavedb.lib.gnomad import gnomad_table_name from mavedb.models import * # noqa: F403 from mavedb.models.experiment import Experiment from mavedb.models.experiment_set import ExperimentSet @@ -128,56 +127,6 @@ def patch_db_session_ctxmgr(db_session_fixture): yield -@pytest.fixture -def athena_engine(): - """Create and yield a SQLAlchemy engine connected to a mock Athena database.""" - engine = create_engine("sqlite:///:memory:") - metadata = MetaData() - - # TODO: Define your table schema here - my_table = Table( - gnomad_table_name(), - metadata, - Column("id", Integer, primary_key=True), - Column("locus.contig", String), - Column("locus.position", Integer), - Column("alleles", String), - Column("caid", String), - Column("joint.freq.all.ac", Integer), - Column("joint.freq.all.an", Integer), - Column("joint.fafmax.faf95_max_gen_anc", String), - Column("joint.fafmax.faf95_max", Float), - ) - metadata.create_all(engine) - - session = sessionmaker(autocommit=False, autoflush=False, bind=engine)() - - # Insert test data - session.execute( - my_table.insert(), - [ - { - "id": 1, - "locus.contig": "chr1", - "locus.position": 12345, - "alleles": "[G, A]", - "caid": "CA123", - "joint.freq.all.ac": 23, - "joint.freq.all.an": 32432423, - "joint.fafmax.faf95_max_gen_anc": "anc1", - "joint.fafmax.faf95_max": 0.000006763700000000002, - } - ], - ) - session.commit() - session.close() - - try: - yield engine - finally: - engine.dispose() - - @pytest.fixture def setup_lib_db(session): """ diff --git a/tests/conftest_optional.py b/tests/conftest_optional.py index d5a1bbd8..3735634e 100644 --- a/tests/conftest_optional.py +++ b/tests/conftest_optional.py @@ -13,10 +13,13 @@ from biocommons.seqrepo import SeqRepo from fastapi.testclient import TestClient from httpx import AsyncClient +from sqlalchemy import Column, Float, Integer, MetaData, String, Table +from mavedb.db.session import create_engine, sessionmaker from mavedb.deps import get_db, get_seqrepo, get_worker, hgvs_data_provider from mavedb.lib.authentication import UserData, get_current_user from mavedb.lib.authorization import require_current_user +from mavedb.lib.gnomad import gnomad_table_name from mavedb.models.user import User from mavedb.server_main import app from mavedb.worker.jobs import BACKGROUND_CRONJOBS, BACKGROUND_FUNCTIONS @@ -404,3 +407,58 @@ def client(app_): async def async_client(app_): async with AsyncClient(app=app_, base_url="http://testserver") as ac: yield ac + + +##################################################################################################### +# Athena +##################################################################################################### + + +@pytest.fixture +def athena_engine(): + """Create and yield a SQLAlchemy engine connected to a mock Athena database.""" + engine = create_engine("sqlite:///:memory:") + metadata = MetaData() + + # TODO: Define your table schema here + my_table = Table( + gnomad_table_name(), + metadata, + Column("id", Integer, primary_key=True), + Column("locus.contig", String), + Column("locus.position", Integer), + Column("alleles", String), + Column("caid", String), + Column("joint.freq.all.ac", Integer), + Column("joint.freq.all.an", Integer), + Column("joint.fafmax.faf95_max_gen_anc", String), + Column("joint.fafmax.faf95_max", Float), + ) + metadata.create_all(engine) + + session = sessionmaker(autocommit=False, autoflush=False, bind=engine)() + + # Insert test data + session.execute( + my_table.insert(), + [ + { + "id": 1, + "locus.contig": "chr1", + "locus.position": 12345, + "alleles": "[G, A]", + "caid": "CA123", + "joint.freq.all.ac": 23, + "joint.freq.all.an": 32432423, + "joint.fafmax.faf95_max_gen_anc": "anc1", + "joint.fafmax.faf95_max": 0.000006763700000000002, + } + ], + ) + session.commit() + session.close() + + try: + yield engine + finally: + engine.dispose() From 60ef67ddbb1a3cef52abb3f2b07d9e1f66821544 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 28 Jan 2026 21:07:52 -0800 Subject: [PATCH 47/70] feat: add standalone context creation for worker lifecycle management --- src/mavedb/worker/settings/lifecycle.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/mavedb/worker/settings/lifecycle.py b/src/mavedb/worker/settings/lifecycle.py index 18e301f9..3866b461 100644 --- a/src/mavedb/worker/settings/lifecycle.py +++ b/src/mavedb/worker/settings/lifecycle.py @@ -12,6 +12,20 @@ from mavedb.data_providers.services import cdot_rest +def standalone_ctx(): + """Create a standalone worker context dictionary.""" + ctx = {} + ctx["pool"] = futures.ProcessPoolExecutor() + ctx["hdp"] = cdot_rest() + ctx["state"] = {} + + # Additional context setup can be added here as needed. + # This function should not drift from the lifecycle hooks + # below and is useful for invoking worker jobs outside of ARQ. + + return ctx + + async def startup(ctx): ctx["pool"] = futures.ProcessPoolExecutor() From a3f36d1668eca8277c901170e24453252e7e2aa9 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Thu, 29 Jan 2026 09:18:09 -0800 Subject: [PATCH 48/70] feat: add asyncclick dependency and update environment script to use it This update will support using job definitions directly in scripts. --- poetry.lock | 19 +++++++++++++++++-- pyproject.toml | 1 + src/mavedb/scripts/environment.py | 4 +--- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/poetry.lock b/poetry.lock index 35c2477c..0fc68c24 100644 --- a/poetry.lock +++ b/poetry.lock @@ -154,6 +154,21 @@ files = [ {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"}, ] +[[package]] +name = "asyncclick" +version = "8.3.0.7" +description = "Composable command line interface toolkit, async fork" +optional = false +python-versions = ">=3.11" +groups = ["main"] +files = [ + {file = "asyncclick-8.3.0.7-py3-none-any.whl", hash = "sha256:7607046de39a3f315867cad818849f973e29d350c10d92f251db3ff7600c6c7d"}, + {file = "asyncclick-8.3.0.7.tar.gz", hash = "sha256:8a80d8ac613098ee6a9a8f0248f60c66c273e22402cf3f115ed7f071acfc71d3"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + [[package]] name = "attrs" version = "25.3.0" @@ -1045,7 +1060,7 @@ files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] -markers = {main = "extra == \"server\" and (platform_system == \"Windows\" or sys_platform == \"win32\")", dev = "sys_platform == \"win32\""} +markers = {main = "platform_system == \"Windows\" or extra == \"server\" and sys_platform == \"win32\"", dev = "sys_platform == \"win32\""} [[package]] name = "coloredlogs" @@ -4844,4 +4859,4 @@ server = ["alembic", "alembic-utils", "arq", "authlib", "biocommons", "boto3", " [metadata] lock-version = "2.1" python-versions = "^3.11" -content-hash = "a92cfae921a52b547c08ab74fd06a60427d5ac28601c68f4ca6d740e2059dfb2" +content-hash = "4be857a91855622d543b3eb008624fc9bb57b605d17e5aec00a0e1c8bef5ed3c" diff --git a/pyproject.toml b/pyproject.toml index bb55a121..4bf083e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ starlette-context = { version = "^0.3.6", optional = true } slack-sdk = { version = "~3.21.3", optional = true } uvicorn = { extras = ["standard"], version = "*", optional = true } watchtower = { version = "~3.2.0", optional = true } +asyncclick = "^8.3.0.7" [tool.poetry.group.dev] optional = true diff --git a/src/mavedb/scripts/environment.py b/src/mavedb/scripts/environment.py index 66bdbb78..831da7a4 100644 --- a/src/mavedb/scripts/environment.py +++ b/src/mavedb/scripts/environment.py @@ -4,16 +4,14 @@ import enum import logging -import click from functools import wraps - +import asyncclick as click from sqlalchemy.orm import configure_mappers from mavedb import deps from mavedb.models import * # noqa: F403 - logger = logging.getLogger(__name__) From 0416b2d7a53a39e9e959ec0489e549f75478a4d7 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Thu, 29 Jan 2026 09:18:31 -0800 Subject: [PATCH 49/70] feat: add standalone job definitions and update lifecycle context for job submission --- src/mavedb/worker/jobs/__init__.py | 2 + src/mavedb/worker/jobs/registry.py | 83 +++++++++++++++++++++++++ src/mavedb/worker/settings/lifecycle.py | 1 + 3 files changed, 86 insertions(+) diff --git a/src/mavedb/worker/jobs/__init__.py b/src/mavedb/worker/jobs/__init__.py index 6a52927c..e421bbad 100644 --- a/src/mavedb/worker/jobs/__init__.py +++ b/src/mavedb/worker/jobs/__init__.py @@ -27,6 +27,7 @@ from mavedb.worker.jobs.registry import ( BACKGROUND_CRONJOBS, BACKGROUND_FUNCTIONS, + STANDALONE_JOB_DEFINITIONS, ) from mavedb.worker.jobs.variant_processing.creation import create_variants_for_score_set from mavedb.worker.jobs.variant_processing.mapping import ( @@ -49,4 +50,5 @@ # Job registry and utilities "BACKGROUND_FUNCTIONS", "BACKGROUND_CRONJOBS", + "STANDALONE_JOB_DEFINITIONS", ] diff --git a/src/mavedb/worker/jobs/registry.py b/src/mavedb/worker/jobs/registry.py index 251d87c8..af1e9836 100644 --- a/src/mavedb/worker/jobs/registry.py +++ b/src/mavedb/worker/jobs/registry.py @@ -9,6 +9,8 @@ from arq.cron import CronJob, cron +from mavedb.lib.types.workflow import JobDefinition +from mavedb.models.enums.job_pipeline import JobType from mavedb.worker.jobs.data_management import ( refresh_materialized_views, refresh_published_variants_view, @@ -56,7 +58,88 @@ ] +STANDALONE_JOB_DEFINITIONS: dict[Callable, JobDefinition] = { + create_variants_for_score_set: { + "dependencies": [], + "params": { + "score_set_id": None, + "updater_id": None, + "correlation_id": None, + "scores_file_key": None, + "counts_file_key": None, + "score_columns_metadata": None, + "count_columns_metadata": None, + }, + "function": "create_variants_for_score_set", + "key": "create_variants_for_score_set", + "type": JobType.VARIANT_CREATION, + }, + map_variants_for_score_set: { + "dependencies": [], + "params": {"score_set_id": None, "updater_id": None, "correlation_id": None}, + "function": "map_variants_for_score_set", + "key": "map_variants_for_score_set", + "type": JobType.VARIANT_MAPPING, + }, + submit_score_set_mappings_to_car: { + "dependencies": [], + "params": {"score_set_id": None, "correlation_id": None}, + "function": "submit_score_set_mappings_to_car", + "key": "submit_score_set_mappings_to_car", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + }, + submit_score_set_mappings_to_ldh: { + "dependencies": [], + "params": {"score_set_id": None, "correlation_id": None}, + "function": "submit_score_set_mappings_to_ldh", + "key": "submit_score_set_mappings_to_ldh", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + }, + submit_uniprot_mapping_jobs_for_score_set: { + "dependencies": [], + "params": {"score_set_id": None, "correlation_id": None}, + "function": "submit_uniprot_mapping_jobs_for_score_set", + "key": "submit_uniprot_mapping_jobs_for_score_set", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + }, + poll_uniprot_mapping_jobs_for_score_set: { + "dependencies": [], + "params": {"score_set_id": None, "correlation_id": None}, + "function": "poll_uniprot_mapping_jobs_for_score_set", + "key": "poll_uniprot_mapping_jobs_for_score_set", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + }, + link_gnomad_variants: { + "dependencies": [], + "params": {"score_set_id": None, "correlation_id": None}, + "function": "link_gnomad_variants", + "key": "link_gnomad_variants", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + }, + refresh_materialized_views: { + "dependencies": [], + "params": {"correlation_id": None}, + "function": "refresh_materialized_views", + "key": "refresh_materialized_views", + "type": JobType.DATA_MANAGEMENT, + }, + refresh_published_variants_view: { + "dependencies": [], + "params": {"correlation_id": None}, + "function": "refresh_published_variants_view", + "key": "refresh_published_variants_view", + "type": JobType.DATA_MANAGEMENT, + }, +} +""" +Standalone job definitions for direct job submission outside of pipelines. +All job definitions in this dict must correspond to a job function in BACKGROUND_FUNCTIONS +and must not have any dependencies on other jobs. +""" + + __all__ = [ "BACKGROUND_FUNCTIONS", "BACKGROUND_CRONJOBS", + "STANDALONE_JOB_DEFINITIONS", ] diff --git a/src/mavedb/worker/settings/lifecycle.py b/src/mavedb/worker/settings/lifecycle.py index 3866b461..7e5f933f 100644 --- a/src/mavedb/worker/settings/lifecycle.py +++ b/src/mavedb/worker/settings/lifecycle.py @@ -16,6 +16,7 @@ def standalone_ctx(): """Create a standalone worker context dictionary.""" ctx = {} ctx["pool"] = futures.ProcessPoolExecutor() + ctx["redis"] = None # Redis connection can be set up here if needed. ctx["hdp"] = cdot_rest() ctx["state"] = {} From a013cc04f95b1e0a0314a066e787495db3cca3d2 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Thu, 29 Jan 2026 09:18:53 -0800 Subject: [PATCH 50/70] feat: refactor populate_mapped_variant_data to use async and job submission for score sets --- .../scripts/populate_mapped_variants.py | 201 ++++-------------- 1 file changed, 46 insertions(+), 155 deletions(-) diff --git a/src/mavedb/scripts/populate_mapped_variants.py b/src/mavedb/scripts/populate_mapped_variants.py index de9eedbd..72b4b449 100644 --- a/src/mavedb/scripts/populate_mapped_variants.py +++ b/src/mavedb/scripts/populate_mapped_variants.py @@ -1,178 +1,69 @@ +import datetime import logging -from datetime import date -from typing import Optional, Sequence, Union +from typing import Optional, Sequence -import click -from sqlalchemy import cast, select -from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.orm import Session +import asyncclick as click # using asyncclick to allow async commands +from sqlalchemy import select -from mavedb.data_providers.services import vrs_mapper -from mavedb.lib.exceptions import NonexistentMappingReferenceError -from mavedb.lib.logging.context import format_raised_exception_info_as_dict -from mavedb.lib.mapping import ANNOTATION_LAYERS -from mavedb.models.enums.mapping_state import MappingState -from mavedb.models.mapped_variant import MappedVariant +from mavedb.db.session import SessionLocal +from mavedb.lib.workflow.job_factory import JobFactory from mavedb.models.score_set import ScoreSet -from mavedb.models.variant import Variant -from mavedb.scripts.environment import script_environment, with_database_session +from mavedb.scripts.environment import script_environment +from mavedb.worker.jobs import STANDALONE_JOB_DEFINITIONS, map_variants_for_score_set +from mavedb.worker.settings.lifecycle import standalone_ctx logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -def variant_from_mapping(db: Session, mapping: dict, dcd_mapping_version: str) -> MappedVariant: - variant_urn = mapping.get("mavedb_id") - variant = db.scalars(select(Variant).where(Variant.urn == variant_urn)).one() - - return MappedVariant( - variant_id=variant.id, - pre_mapped=mapping.get("pre_mapped"), - post_mapped=mapping.get("post_mapped"), - modification_date=date.today(), - mapped_date=date.today(), # since this is a one-time script, assume mapping was done today - vrs_version=mapping.get("vrs_version"), - mapping_api_version=dcd_mapping_version, - error_message=mapping.get("error_message"), - current=True, - ) - - @script_environment.command() -@with_database_session @click.argument("urns", nargs=-1) @click.option("--all", help="Populate mapped variants for every score set in MaveDB.", is_flag=True) -def populate_mapped_variant_data(db: Session, urns: Sequence[Optional[str]], all: bool): +@click.option("--as-user-id", type=int, help="User ID to attribute as the updater of the mapped variants.") +async def populate_mapped_variant_data(urns: Sequence[Optional[str]], all: bool, as_user_id: Optional[int]): score_set_ids: Sequence[Optional[int]] + db = SessionLocal() + if all: score_set_ids = db.scalars(select(ScoreSet.id)).all() logger.info( - f"Command invoked with --all. Routine will populate mapped variant data for {len(urns)} score sets." + f"Command invoked with --all. Routine will populate mapped variant data for {len(score_set_ids)} score sets." ) else: score_set_ids = db.scalars(select(ScoreSet.id).where(ScoreSet.urn.in_(urns))).all() - logger.info(f"Populating mapped variant data for the provided score sets ({len(urns)}).") - - vrs = vrs_mapper() - - for idx, ss_id in enumerate(score_set_ids): - if not ss_id: - continue - - score_set = db.scalar(select(ScoreSet).where(ScoreSet.id == ss_id)) - if not score_set: - logger.warning(f"Could not fetch score set with id={ss_id}.") - continue - - try: - existing_mapped_variants = ( - db.query(MappedVariant) - .join(Variant) - .join(ScoreSet) - .filter(ScoreSet.id == ss_id, MappedVariant.current.is_(True)) - .all() - ) - - for variant in existing_mapped_variants: - variant.current = False - - assert score_set.urn - logger.info(f"Mapping score set {score_set.urn}.") - mapped_scoreset = vrs.map_score_set(score_set.urn) - logger.info(f"Done mapping score set {score_set.urn}.") - - dcd_mapping_version = mapped_scoreset["dcd_mapping_version"] - mapped_scores = mapped_scoreset.get("mapped_scores") - - if not mapped_scores: - # if there are no mapped scores, the score set failed to map. - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": mapped_scoreset.get("error_message")} - db.commit() - logger.info(f"No mapped variants available for {score_set.urn}.") - else: - reference_metadata = mapped_scoreset.get("reference_sequences") - if not reference_metadata: - raise NonexistentMappingReferenceError() - - for target_gene_identifier in reference_metadata: - target_gene = next( - ( - target_gene - for target_gene in score_set.target_genes - if target_gene.name == target_gene_identifier - ), - None, - ) - if not target_gene: - raise ValueError( - f"Target gene {target_gene_identifier} not found in database for score set {score_set.urn}." - ) - # allow for multiple annotation layers - pre_mapped_metadata = {} - post_mapped_metadata: dict[str, Union[Optional[str], dict[str, dict[str, str | list[str]]]]] = {} - excluded_pre_mapped_keys = {"sequence"} - - gene_info = reference_metadata[target_gene_identifier].get("gene_info") - if gene_info: - target_gene.mapped_hgnc_name = gene_info.get("hgnc_symbol") - post_mapped_metadata["hgnc_name_selection_method"] = gene_info.get("selection_method") - - for annotation_layer in reference_metadata[target_gene_identifier]["layers"]: - layer_premapped = reference_metadata[target_gene_identifier]["layers"][annotation_layer].get( - "computed_reference_sequence" - ) - if layer_premapped: - pre_mapped_metadata[ANNOTATION_LAYERS[annotation_layer]] = { - k: layer_premapped[k] - for k in set(list(layer_premapped.keys())) - excluded_pre_mapped_keys - } - layer_postmapped = reference_metadata[target_gene_identifier]["layers"][annotation_layer].get( - "mapped_reference_sequence" - ) - if layer_postmapped: - post_mapped_metadata[ANNOTATION_LAYERS[annotation_layer]] = layer_postmapped - target_gene.pre_mapped_metadata = cast(pre_mapped_metadata, JSONB) - target_gene.post_mapped_metadata = cast(post_mapped_metadata, JSONB) - - mapped_variants = [ - variant_from_mapping(db=db, mapping=mapped_score, dcd_mapping_version=dcd_mapping_version) - for mapped_score in mapped_scores - ] - logger.debug(f"Done constructing {len(mapped_variants)} mapped variant objects.") - - num_successful_variants = len( - [variant for variant in mapped_variants if variant.post_mapped is not None] - ) - logger.debug( - f"{num_successful_variants}/{len(mapped_variants)} variants generated a post-mapped VRS object." - ) - - if num_successful_variants == 0: - score_set.mapping_state = MappingState.failed - score_set.mapping_errors = {"error_message": "All variants failed to map"} - elif num_successful_variants < len(mapped_variants): - score_set.mapping_state = MappingState.incomplete - else: - score_set.mapping_state = MappingState.complete - - db.bulk_save_objects(mapped_variants) - db.commit() - logger.info(f"Done populating {len(mapped_variants)} mapped variants for {score_set.urn}.") - - except Exception as e: - logging_context = { - "mapped_score_sets": urns[:idx], - "unmapped_score_sets": urns[idx:], - } - logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)} - logger.error(f"Score set {score_set.urn} failed to map.", extra=logging_context) - logger.info(f"Rolling back all changes for scoreset {score_set.urn}") - db.rollback() - - logger.info(f"Done with score set {score_set.urn}. ({idx+1}/{len(urns)}).") + logger.info(f"Populating mapped variant data for the provided score sets ({len(score_set_ids)}).") + + # Unique correlation ID for this batch run + correlation_id = f"populate_mapped_variants_{datetime.datetime.now().isoformat()}" + + # Job definition for mapping variants + job_def = STANDALONE_JOB_DEFINITIONS[map_variants_for_score_set] + job_factory = JobFactory(db) + + # Use a standalone context for job execution outside of ARQ worker. + ctx = standalone_ctx() + ctx["db"] = db + + for score_set_id in score_set_ids: + logger.info(f"Populating mapped variant data for score set ID {score_set_id}...") + + job_run = job_factory.create_job_run( + job_def=job_def, + pipeline_id=None, + correlation_id=correlation_id, + pipeline_params={ + "score_set_id": score_set_id, + "updater_id": as_user_id + if as_user_id is not None + else 1, # Use provided user ID or default to System user + "correlation_id": correlation_id, + }, + ) + db.add(job_run) + db.flush() + logger.info(f"Submitted job run ID {job_run.id} for score set ID {score_set_id}.") - logger.info("Done populating mapped variant data.") + await map_variants_for_score_set(ctx, job_run.id) if __name__ == "__main__": From f3a7d6a0eae20f9d19e31e70c9d94ddd72a3db26 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Thu, 29 Jan 2026 11:07:17 -0800 Subject: [PATCH 51/70] chore: test cleanup --- tests/helpers/util/variant.py | 5 ++ tests/lib/test_annotation_status_manager.py | 4 ++ tests/lib/test_gnomad.py | 69 ++++++++++++------- tests/lib/workflow/conftest.py | 17 ++--- tests/lib/workflow/conftest_optional.py | 16 +++++ tests/lib/workflow/test_job_factory.py | 7 +- tests/lib/workflow/test_pipeline_factory.py | 4 ++ tests/routers/conftest.py | 24 ++++--- tests/routers/conftest_optional.py | 14 ++++ tests/routers/test_score_set.py | 44 +++++++++--- tests/worker/jobs/conftest.py | 16 ++--- tests/worker/jobs/conftest_optional.py | 14 ++++ .../worker/jobs/data_management/test_views.py | 5 +- .../external_services/network/test_clingen.py | 7 +- .../external_services/network/test_uniprot.py | 4 ++ .../jobs/external_services/test_clingen.py | 7 +- .../jobs/external_services/test_gnomad.py | 6 +- .../jobs/external_services/test_uniprot.py | 6 +- .../test_start_pipeline.py | 7 +- tests/worker/jobs/utils/test_setup.py | 6 +- .../jobs/variant_processing/test_creation.py | 8 ++- .../jobs/variant_processing/test_mapping.py | 7 +- .../decorators/test_pipeline_management.py | 3 +- tests/worker/lib/managers/test_job_manager.py | 3 +- tests/worker/lib/managers/test_utils.py | 4 ++ 25 files changed, 218 insertions(+), 89 deletions(-) create mode 100644 tests/lib/workflow/conftest_optional.py create mode 100644 tests/routers/conftest_optional.py create mode 100644 tests/worker/jobs/conftest_optional.py diff --git a/tests/helpers/util/variant.py b/tests/helpers/util/variant.py index 5fcc05db..eede1e61 100644 --- a/tests/helpers/util/variant.py +++ b/tests/helpers/util/variant.py @@ -36,7 +36,11 @@ def mock_worker_variant_insertion( with ( open(scores_csv_path, "rb") as score_file, patch.object(ArqRedis, "enqueue_job", return_value=None) as worker_queue, + patch("mavedb.routers.score_sets.s3_client") as mock_s3_client, ): + mock_s3 = mock_s3_client.return_value + mock_s3.upload_fileobj.return_value = None # or whatever you want + files = {"scores_file": (scores_csv_path.name, score_file, "rb")} if counts_csv_path is not None: @@ -69,6 +73,7 @@ def mock_worker_variant_insertion( # Assert we have mocked a job being added to the queue, and that the request succeeded. The # response value here isn't important- we will add variants to the score set manually. + mock_s3.upload_fileobj.assert_called() worker_queue.assert_called_once() assert response.status_code == 200 diff --git a/tests/lib/test_annotation_status_manager.py b/tests/lib/test_annotation_status_manager.py index 633cc848..98980f00 100644 --- a/tests/lib/test_annotation_status_manager.py +++ b/tests/lib/test_annotation_status_manager.py @@ -1,5 +1,9 @@ +# ruff: noqa: E402 + import pytest +pytest.importorskip("psycopg2") + from mavedb.lib.annotation_status_manager import AnnotationStatusManager from mavedb.models.enums.annotation_type import AnnotationType from mavedb.models.enums.job_pipeline import AnnotationStatus diff --git a/tests/lib/test_gnomad.py b/tests/lib/test_gnomad.py index 043c6c56..14dde952 100644 --- a/tests/lib/test_gnomad.py +++ b/tests/lib/test_gnomad.py @@ -1,25 +1,26 @@ # ruff: noqa: E402 -import pytest -import importlib from unittest.mock import patch +import pytest + +from mavedb.models.variant_annotation_status import VariantAnnotationStatus + pyathena = pytest.importorskip("pyathena") fastapi = pytest.importorskip("fastapi") from mavedb.lib.gnomad import ( - gnomad_identifier, allele_list_from_list_like_string, + gnomad_identifier, + gnomad_table_name, link_gnomad_variants_to_mapped_variants, ) -from mavedb.models.mapped_variant import MappedVariant from mavedb.models.gnomad_variant import GnomADVariant - +from mavedb.models.mapped_variant import MappedVariant from tests.helpers.constants import ( - TEST_GNOMAD_ALLELE_NUMBER, + TEST_GNOMAD_DATA_VERSION, TEST_GNOMAD_VARIANT, TEST_MINIMAL_MAPPED_VARIANT, - TEST_GNOMAD_DATA_VERSION, ) ### Tests for gnomad_identifier function ### @@ -63,22 +64,17 @@ def test_gnomad_identifier_raises_with_no_alleles(): ### Tests for gnomad_table_name function ### -def test_gnomad_table_name_returns_expected(monkeypatch): - monkeypatch.setenv("GNOMAD_DATA_VERSION", TEST_GNOMAD_DATA_VERSION) - # Reload the module to update GNOMAD_DATA_VERSION global - import mavedb.lib.gnomad as gnomad_mod - - importlib.reload(gnomad_mod) - assert gnomad_mod.gnomad_table_name() == TEST_GNOMAD_DATA_VERSION.replace(".", "_") - +def test_gnomad_table_name_returns_expected(): + with patch("mavedb.lib.gnomad.GNOMAD_DATA_VERSION", TEST_GNOMAD_DATA_VERSION): + assert gnomad_table_name() == TEST_GNOMAD_DATA_VERSION.replace(".", "_") -def test_gnomad_table_name_raises_if_env_not_set(monkeypatch): - monkeypatch.delenv("GNOMAD_DATA_VERSION", raising=False) - import mavedb.lib.gnomad as gnomad_mod - importlib.reload(gnomad_mod) - with pytest.raises(ValueError, match="GNOMAD_DATA_VERSION environment variable is not set."): - gnomad_mod.gnomad_table_name() +def test_gnomad_table_name_raises_if_env_not_set(): + with ( + pytest.raises(ValueError, match="GNOMAD_DATA_VERSION environment variable is not set."), + patch("mavedb.lib.gnomad.GNOMAD_DATA_VERSION", None), + ): + gnomad_table_name() ### Tests for allele_list_from_list_like_string function ### @@ -125,6 +121,16 @@ def test_allele_list_from_list_like_string_invalid_format_not_list(): ### Tests for link_gnomad_variants_to_mapped_variants function ### +def _verify_annotation_status(session, mapped_variants, expected_version): + annotations = session.query(VariantAnnotationStatus).all() + assert len(annotations) == len(mapped_variants) + + for mapped_variant, annotation in zip(mapped_variants, annotations): + assert annotation.variant_id == mapped_variant.variant_id + assert annotation.annotation_type == "gnomad_allele_frequency" + assert annotation.version == expected_version + + def test_links_new_gnomad_variant_to_mapped_variant( session, mocked_gnomad_variant_row, setup_lib_db_with_mapped_variant ): @@ -148,6 +154,8 @@ def test_links_new_gnomad_variant_to_mapped_variant( for attr in edited_saved_gnomad_variant: assert getattr(mapped_variant.gnomad_variants[0], attr) == edited_saved_gnomad_variant[attr] + _verify_annotation_status(session, [mapped_variant], TEST_GNOMAD_DATA_VERSION) + def test_can_link_gnomad_variants_with_none_type_faf_fields( session, mocked_gnomad_variant_row, setup_lib_db_with_mapped_variant @@ -175,6 +183,8 @@ def test_can_link_gnomad_variants_with_none_type_faf_fields( for attr in gnomad_variant_comparator: assert getattr(mapped_variant.gnomad_variants[0], attr) == gnomad_variant_comparator[attr] + _verify_annotation_status(session, [mapped_variant], TEST_GNOMAD_DATA_VERSION) + def test_links_existing_gnomad_variant(session, mocked_gnomad_variant_row, setup_lib_db_with_mapped_variant): gnomad_variant = GnomADVariant(**TEST_GNOMAD_VARIANT) @@ -199,8 +209,10 @@ def test_links_existing_gnomad_variant(session, mocked_gnomad_variant_row, setup for attr in edited_saved_gnomad_variant: assert getattr(mapped_variant.gnomad_variants[0], attr) == edited_saved_gnomad_variant[attr] + _verify_annotation_status(session, [mapped_variant], TEST_GNOMAD_DATA_VERSION) -def test_removes_existing_gnomad_variant_with_same_version( + +def test_adding_existing_gnomad_variant_with_same_version_does_not_result_in_duplication( session, mocked_gnomad_variant_row, setup_lib_db_with_mapped_variant ): mapped_variant = setup_lib_db_with_mapped_variant @@ -212,7 +224,6 @@ def test_removes_existing_gnomad_variant_with_same_version( result = link_gnomad_variants_to_mapped_variants(session, [mocked_gnomad_variant_row]) assert result == 1 - setattr(mocked_gnomad_variant_row, "joint.freq.all.ac", "1234") with patch("mavedb.lib.gnomad.GNOMAD_DATA_VERSION", TEST_GNOMAD_DATA_VERSION): result = link_gnomad_variants_to_mapped_variants(session, [mocked_gnomad_variant_row]) assert result == 1 @@ -221,8 +232,6 @@ def test_removes_existing_gnomad_variant_with_same_version( session.refresh(mapped_variant) edited_saved_gnomad_variant = TEST_GNOMAD_VARIANT.copy() - edited_saved_gnomad_variant["allele_count"] = 1234 - edited_saved_gnomad_variant["allele_frequency"] = float(1234 / int(TEST_GNOMAD_ALLELE_NUMBER)) edited_saved_gnomad_variant.pop("creation_date") edited_saved_gnomad_variant.pop("modification_date") @@ -230,6 +239,8 @@ def test_removes_existing_gnomad_variant_with_same_version( for attr in edited_saved_gnomad_variant: assert getattr(mapped_variant.gnomad_variants[0], attr) == edited_saved_gnomad_variant[attr] + _verify_annotation_status(session, [mapped_variant, mapped_variant], TEST_GNOMAD_DATA_VERSION) + def test_links_multiple_rows_and_variants(session, mocked_gnomad_variant_row, setup_lib_db_with_mapped_variant): mapped_variant1 = setup_lib_db_with_mapped_variant @@ -256,11 +267,15 @@ def test_links_multiple_rows_and_variants(session, mocked_gnomad_variant_row, se for attr in gnomad_variant_comparator: assert getattr(mv.gnomad_variants[0], attr) == gnomad_variant_comparator[attr] + _verify_annotation_status(session, [mapped_variant1, mapped_variant2], TEST_GNOMAD_DATA_VERSION) + def test_returns_zero_when_no_mapped_variants(session, mocked_gnomad_variant_row): result = link_gnomad_variants_to_mapped_variants(session, [mocked_gnomad_variant_row]) assert result == 0 + _verify_annotation_status(session, [], TEST_GNOMAD_DATA_VERSION) + def test_only_current_flag_filters_variants(session, mocked_gnomad_variant_row, setup_lib_db_with_mapped_variant): mapped_variant1 = setup_lib_db_with_mapped_variant @@ -287,6 +302,8 @@ def test_only_current_flag_filters_variants(session, mocked_gnomad_variant_row, for attr in gnomad_variant_comparator: assert getattr(mapped_variant2.gnomad_variants[0], attr) == gnomad_variant_comparator[attr] + _verify_annotation_status(session, [mapped_variant2], TEST_GNOMAD_DATA_VERSION) + def test_only_current_flag_is_false_operates_on_all_variants( session, mocked_gnomad_variant_row, setup_lib_db_with_mapped_variant @@ -313,3 +330,5 @@ def test_only_current_flag_is_false_operates_on_all_variants( assert len(mv.gnomad_variants) == 1 for attr in gnomad_variant_comparator: assert getattr(mv.gnomad_variants[0], attr) == gnomad_variant_comparator[attr] + + _verify_annotation_status(session, [mapped_variant1, mapped_variant2], TEST_GNOMAD_DATA_VERSION) diff --git a/tests/lib/workflow/conftest.py b/tests/lib/workflow/conftest.py index d88789a4..dad72098 100644 --- a/tests/lib/workflow/conftest.py +++ b/tests/lib/workflow/conftest.py @@ -2,23 +2,14 @@ import pytest -from mavedb.lib.workflow.job_factory import JobFactory -from mavedb.lib.workflow.pipeline_factory import PipelineFactory from mavedb.models.enums.job_pipeline import DependencyType from mavedb.models.user import User from tests.helpers.constants import TEST_USER - -@pytest.fixture -def job_factory(session): - """Fixture to provide a mocked JobFactory instance.""" - yield JobFactory(session) - - -@pytest.fixture -def pipeline_factory(session): - """Fixture to provide a mocked PipelineFactory instance.""" - yield PipelineFactory(session) +try: + from .conftest_optional import * # noqa: F403, F401 +except ImportError: + pass @pytest.fixture diff --git a/tests/lib/workflow/conftest_optional.py b/tests/lib/workflow/conftest_optional.py new file mode 100644 index 00000000..f165cc74 --- /dev/null +++ b/tests/lib/workflow/conftest_optional.py @@ -0,0 +1,16 @@ +import pytest + +from mavedb.lib.workflow.job_factory import JobFactory +from mavedb.lib.workflow.pipeline_factory import PipelineFactory + + +@pytest.fixture +def job_factory(session): + """Fixture to provide a mocked JobFactory instance.""" + yield JobFactory(session) + + +@pytest.fixture +def pipeline_factory(session): + """Fixture to provide a mocked PipelineFactory instance.""" + yield PipelineFactory(session) diff --git a/tests/lib/workflow/test_job_factory.py b/tests/lib/workflow/test_job_factory.py index c34b6ca0..6b730299 100644 --- a/tests/lib/workflow/test_job_factory.py +++ b/tests/lib/workflow/test_job_factory.py @@ -1,7 +1,10 @@ -from unittest.mock import patch - +# ruff: noqa: E402 import pytest +pytest.importorskip("fastapi") + +from unittest.mock import patch + from mavedb.models.pipeline import Pipeline diff --git a/tests/lib/workflow/test_pipeline_factory.py b/tests/lib/workflow/test_pipeline_factory.py index e585666f..b944e469 100644 --- a/tests/lib/workflow/test_pipeline_factory.py +++ b/tests/lib/workflow/test_pipeline_factory.py @@ -1,4 +1,8 @@ +# ruff: noqa: E402 import pytest + +pytest.importorskip("fastapi") + from sqlalchemy import select from mavedb.lib.workflow.pipeline_factory import PipelineFactory diff --git a/tests/routers/conftest.py b/tests/routers/conftest.py index d54b18d8..ba34c548 100644 --- a/tests/routers/conftest.py +++ b/tests/routers/conftest.py @@ -4,32 +4,36 @@ import pytest from mavedb.models.clinical_control import ClinicalControl -from mavedb.models.controlled_keyword import ControlledKeyword from mavedb.models.contributor import Contributor +from mavedb.models.controlled_keyword import ControlledKeyword from mavedb.models.enums.user_role import UserRole -from mavedb.models.publication_identifier import PublicationIdentifier from mavedb.models.gnomad_variant import GnomADVariant from mavedb.models.license import License +from mavedb.models.publication_identifier import PublicationIdentifier from mavedb.models.role import Role from mavedb.models.taxonomy import Taxonomy from mavedb.models.user import User - from tests.helpers.constants import ( ADMIN_USER, - TEST_CLINVAR_CONTROL, - TEST_GENERIC_CLINICAL_CONTROL, - EXTRA_USER, EXTRA_CONTRIBUTOR, + EXTRA_LICENSE, + EXTRA_USER, + TEST_CLINVAR_CONTROL, TEST_DB_KEYWORDS, - TEST_LICENSE, + TEST_GENERIC_CLINICAL_CONTROL, + TEST_GNOMAD_VARIANT, TEST_INACTIVE_LICENSE, - EXTRA_LICENSE, + TEST_LICENSE, + TEST_PUBMED_PUBLICATION, TEST_SAVED_TAXONOMY, TEST_USER, - TEST_PUBMED_PUBLICATION, - TEST_GNOMAD_VARIANT, ) +try: + from .conftest_optional import * # noqa: F403, F401 +except ImportError: + pass + @pytest.fixture def setup_router_db(session): diff --git a/tests/routers/conftest_optional.py b/tests/routers/conftest_optional.py new file mode 100644 index 00000000..efbd119b --- /dev/null +++ b/tests/routers/conftest_optional.py @@ -0,0 +1,14 @@ +from unittest import mock + +import pytest +from mypy_boto3_s3 import S3Client + + +@pytest.fixture +def mock_s3_client(): + """Mock S3 client for tests that interact with S3.""" + + with mock.patch("mavedb.routers.score_sets.s3_client") as mock_s3_client_func: + mock_s3 = mock.MagicMock(spec=S3Client) + mock_s3_client_func.return_value = mock_s3 + yield mock_s3 diff --git a/tests/routers/test_score_set.py b/tests/routers/test_score_set.py index 86234392..5cb29ab6 100644 --- a/tests/routers/test_score_set.py +++ b/tests/routers/test_score_set.py @@ -448,7 +448,7 @@ def test_can_patch_score_set_data_before_publication( indirect=["mock_publication_fetch"], ) def test_can_patch_score_set_data_with_files_before_publication( - client, setup_router_db, form_field, filename, mime_type, data_files, mock_publication_fetch + client, setup_router_db, form_field, filename, mime_type, data_files, mock_publication_fetch, mock_s3_client ): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) @@ -460,7 +460,10 @@ def test_can_patch_score_set_data_with_files_before_publication( if form_field == "counts_file" or form_field == "scores_file": data_file_path = data_files / filename files = {form_field: (filename, open(data_file_path, "rb"), mime_type)} - with patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as worker_queue: + with ( + patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as worker_queue, + patch.object(mock_s3_client, "upload_fileobj", return_value=None), + ): response = client.patch(f"/api/v1/score-sets-with-variants/{score_set['urn']}", files=files) worker_queue.assert_called_once() assert response.status_code == 200 @@ -871,13 +874,14 @@ def test_creating_user_can_view_all_score_calibrations_in_score_set(client, setu ######################################################################################################################## -def test_add_score_set_variants_scores_only_endpoint(client, setup_router_db, data_files): +def test_add_score_set_variants_scores_only_endpoint(client, setup_router_db, data_files, mock_s3_client): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) scores_csv_path = data_files / "scores.csv" with ( open(scores_csv_path, "rb") as scores_file, patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as queue, + patch.object(mock_s3_client, "upload_fileobj", return_value=None), ): response = client.post( f"/api/v1/score-sets/{score_set['urn']}/variants/data", @@ -895,7 +899,9 @@ def test_add_score_set_variants_scores_only_endpoint(client, setup_router_db, da assert score_set == response_data -def test_add_score_set_variants_scores_and_counts_endpoint(session, client, setup_router_db, data_files): +def test_add_score_set_variants_scores_and_counts_endpoint( + session, client, setup_router_db, data_files, mock_s3_client +): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) scores_csv_path = data_files / "scores.csv" @@ -904,6 +910,7 @@ def test_add_score_set_variants_scores_and_counts_endpoint(session, client, setu open(scores_csv_path, "rb") as scores_file, open(counts_csv_path, "rb") as counts_file, patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as queue, + patch.object(mock_s3_client, "upload_fileobj", return_value=None), ): response = client.post( f"/api/v1/score-sets/{score_set['urn']}/variants/data", @@ -925,7 +932,7 @@ def test_add_score_set_variants_scores_and_counts_endpoint(session, client, setu def test_add_score_set_variants_scores_counts_and_column_metadata_endpoint( - session, client, setup_router_db, data_files + session, client, setup_router_db, data_files, mock_s3_client ): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) @@ -939,6 +946,7 @@ def test_add_score_set_variants_scores_counts_and_column_metadata_endpoint( open(score_columns_metadata_path, "rb") as score_columns_metadata_file, open(count_columns_metadata_path, "rb") as count_columns_metadata_file, patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as queue, + patch.object(mock_s3_client, "upload_fileobj", return_value=None), ): score_columns_metadata = json.load(score_columns_metadata_file) count_columns_metadata = json.load(count_columns_metadata_file) @@ -965,13 +973,14 @@ def test_add_score_set_variants_scores_counts_and_column_metadata_endpoint( assert score_set == response_data -def test_add_score_set_variants_scores_only_endpoint_utf8_encoded(client, setup_router_db, data_files): +def test_add_score_set_variants_scores_only_endpoint_utf8_encoded(client, setup_router_db, data_files, mock_s3_client): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) scores_csv_path = data_files / "scores_utf8_encoded.csv" with ( open(scores_csv_path, "rb") as scores_file, patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as queue, + patch.object(mock_s3_client, "upload_fileobj", return_value=None), ): response = client.post( f"/api/v1/score-sets/{score_set['urn']}/variants/data", @@ -989,7 +998,9 @@ def test_add_score_set_variants_scores_only_endpoint_utf8_encoded(client, setup_ assert score_set == response_data -def test_add_score_set_variants_scores_and_counts_endpoint_utf8_encoded(session, client, setup_router_db, data_files): +def test_add_score_set_variants_scores_and_counts_endpoint_utf8_encoded( + session, client, setup_router_db, data_files, mock_s3_client +): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) scores_csv_path = data_files / "scores_utf8_encoded.csv" @@ -998,6 +1009,7 @@ def test_add_score_set_variants_scores_and_counts_endpoint_utf8_encoded(session, open(scores_csv_path, "rb") as scores_file, open(counts_csv_path, "rb") as counts_file, patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as queue, + patch.object(mock_s3_client, "upload_fileobj", return_value=None), ): response = client.post( f"/api/v1/score-sets/{score_set['urn']}/variants/data", @@ -1073,7 +1085,9 @@ def test_anonymous_cannot_add_scores_to_other_user_score_set( assert "Could not validate credentials" in response_data["detail"] -def test_contributor_can_add_scores_to_other_user_score_set(session, client, setup_router_db, data_files): +def test_contributor_can_add_scores_to_other_user_score_set( + session, client, setup_router_db, data_files, mock_s3_client +): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) change_ownership(session, score_set["urn"], ScoreSetDbModel) @@ -1090,6 +1104,7 @@ def test_contributor_can_add_scores_to_other_user_score_set(session, client, set with ( open(scores_csv_path, "rb") as scores_file, patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as queue, + patch.object(mock_s3_client, "upload_fileobj", return_value=None), ): response = client.post( f"/api/v1/score-sets/{score_set['urn']}/variants/data", @@ -1127,7 +1142,9 @@ def test_contributor_can_add_scores_to_other_user_score_set(session, client, set assert score_set == response_data -def test_contributor_can_add_scores_and_counts_to_other_user_score_set(session, client, setup_router_db, data_files): +def test_contributor_can_add_scores_and_counts_to_other_user_score_set( + session, client, setup_router_db, data_files, mock_s3_client +): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) change_ownership(session, score_set["urn"], ScoreSetDbModel) @@ -1146,6 +1163,7 @@ def test_contributor_can_add_scores_and_counts_to_other_user_score_set(session, open(scores_csv_path, "rb") as scores_file, open(counts_csv_path, "rb") as counts_file, patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as queue, + patch.object(mock_s3_client, "upload_fileobj", return_value=None), ): response = client.post( f"/api/v1/score-sets/{score_set['urn']}/variants/data", @@ -1187,7 +1205,7 @@ def test_contributor_can_add_scores_and_counts_to_other_user_score_set(session, def test_admin_can_add_scores_to_other_user_score_set( - session, client, setup_router_db, data_files, admin_app_overrides + session, client, setup_router_db, data_files, mock_s3_client, admin_app_overrides ): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) @@ -1197,6 +1215,7 @@ def test_admin_can_add_scores_to_other_user_score_set( open(scores_csv_path, "rb") as scores_file, DependencyOverrider(admin_app_overrides), patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as queue, + patch.object(mock_s3_client, "upload_fileobj", return_value=None), ): response = client.post( f"/api/v1/score-sets/{score_set['urn']}/variants/data", @@ -1214,7 +1233,9 @@ def test_admin_can_add_scores_to_other_user_score_set( assert score_set == response_data -def test_admin_can_add_scores_and_counts_to_other_user_score_set(session, client, setup_router_db, data_files): +def test_admin_can_add_scores_and_counts_to_other_user_score_set( + session, client, setup_router_db, data_files, mock_s3_client +): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) scores_csv_path = data_files / "scores.csv" @@ -1223,6 +1244,7 @@ def test_admin_can_add_scores_and_counts_to_other_user_score_set(session, client open(scores_csv_path, "rb") as scores_file, open(counts_csv_path, "rb") as counts_file, patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as queue, + patch.object(mock_s3_client, "upload_fileobj", return_value=None), ): response = client.post( f"/api/v1/score-sets/{score_set['urn']}/variants/data", diff --git a/tests/worker/jobs/conftest.py b/tests/worker/jobs/conftest.py index a98d27ae..4a41aaab 100644 --- a/tests/worker/jobs/conftest.py +++ b/tests/worker/jobs/conftest.py @@ -1,7 +1,4 @@ -from unittest import mock - import pytest -from mypy_boto3_s3 import S3Client from mavedb.models.enums.job_pipeline import DependencyType from mavedb.models.job_dependency import JobDependency @@ -11,15 +8,10 @@ from mavedb.models.score_set import ScoreSet from mavedb.models.variant import Variant - -@pytest.fixture -def mock_s3_client(): - """Mock S3 client for tests that interact with S3.""" - - with mock.patch("mavedb.worker.jobs.variant_processing.creation.s3_client") as mock_s3_client_func: - mock_s3 = mock.MagicMock(spec=S3Client) - mock_s3_client_func.return_value = mock_s3 - yield mock_s3 +try: + from .conftest_optional import * # noqa: F403, F401 +except ImportError: + pass ## param fixtures for job runs ## diff --git a/tests/worker/jobs/conftest_optional.py b/tests/worker/jobs/conftest_optional.py new file mode 100644 index 00000000..3ca408cb --- /dev/null +++ b/tests/worker/jobs/conftest_optional.py @@ -0,0 +1,14 @@ +from unittest import mock + +import pytest +from mypy_boto3_s3 import S3Client + + +@pytest.fixture +def mock_s3_client(): + """Mock S3 client for tests that interact with S3.""" + + with mock.patch("mavedb.worker.jobs.variant_processing.creation.s3_client") as mock_s3_client_func: + mock_s3 = mock.MagicMock(spec=S3Client) + mock_s3_client_func.return_value = mock_s3 + yield mock_s3 diff --git a/tests/worker/jobs/data_management/test_views.py b/tests/worker/jobs/data_management/test_views.py index 564c24cb..d5011ec9 100644 --- a/tests/worker/jobs/data_management/test_views.py +++ b/tests/worker/jobs/data_management/test_views.py @@ -2,9 +2,6 @@ import pytest -from mavedb.models.pipeline import Pipeline -from mavedb.models.published_variant import PublishedVariantsMV - pytest.importorskip("arq") # Skip tests if arq is not installed from unittest.mock import call, patch @@ -13,6 +10,8 @@ from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus from mavedb.models.job_run import JobRun +from mavedb.models.pipeline import Pipeline +from mavedb.models.published_variant import PublishedVariantsMV from mavedb.worker.jobs.data_management.views import refresh_materialized_views, refresh_published_variants_view from tests.helpers.transaction_spy import TransactionSpy diff --git a/tests/worker/jobs/external_services/network/test_clingen.py b/tests/worker/jobs/external_services/network/test_clingen.py index 1a401e8e..5587925e 100644 --- a/tests/worker/jobs/external_services/network/test_clingen.py +++ b/tests/worker/jobs/external_services/network/test_clingen.py @@ -1,6 +1,11 @@ -from unittest.mock import patch +# ruff: noqa: E402 import pytest + +pytest.importorskip("arq") + +from unittest.mock import patch + from sqlalchemy import select from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus diff --git a/tests/worker/jobs/external_services/network/test_uniprot.py b/tests/worker/jobs/external_services/network/test_uniprot.py index 288fb23b..506eb20f 100644 --- a/tests/worker/jobs/external_services/network/test_uniprot.py +++ b/tests/worker/jobs/external_services/network/test_uniprot.py @@ -1,5 +1,9 @@ +# ruff: noqa: E402 + import pytest +pytest.importorskip("arq") + from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus from tests.helpers.constants import TEST_REFSEQ_IDENTIFIER diff --git a/tests/worker/jobs/external_services/test_clingen.py b/tests/worker/jobs/external_services/test_clingen.py index aaa813ed..26fb88c9 100644 --- a/tests/worker/jobs/external_services/test_clingen.py +++ b/tests/worker/jobs/external_services/test_clingen.py @@ -1,7 +1,12 @@ +# ruff: noqa: E402 + +import pytest + +pytest.importorskip("arq") + from asyncio.unix_events import _UnixSelectorEventLoop from unittest.mock import call, patch -import pytest from sqlalchemy import select from mavedb.lib.exceptions import LDHSubmissionFailureError diff --git a/tests/worker/jobs/external_services/test_gnomad.py b/tests/worker/jobs/external_services/test_gnomad.py index eac1086a..16a88f5c 100644 --- a/tests/worker/jobs/external_services/test_gnomad.py +++ b/tests/worker/jobs/external_services/test_gnomad.py @@ -1,7 +1,11 @@ -from unittest.mock import MagicMock, call, patch +# ruff: noqa: E402 import pytest +pytest.importorskip("arq") + +from unittest.mock import MagicMock, call, patch + from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus from mavedb.models.gnomad_variant import GnomADVariant from mavedb.models.mapped_variant import MappedVariant diff --git a/tests/worker/jobs/external_services/test_uniprot.py b/tests/worker/jobs/external_services/test_uniprot.py index a12534d2..e40371d4 100644 --- a/tests/worker/jobs/external_services/test_uniprot.py +++ b/tests/worker/jobs/external_services/test_uniprot.py @@ -1,7 +1,11 @@ -from unittest.mock import call, patch +# ruff: noqa: E402 import pytest +pytest.importorskip("arq") + +from unittest.mock import call, patch + from mavedb.lib.exceptions import ( NonExistentTargetGeneError, UniprotAmbiguousMappingResultError, diff --git a/tests/worker/jobs/pipeline_management/test_start_pipeline.py b/tests/worker/jobs/pipeline_management/test_start_pipeline.py index 5f2d88ac..b5605de1 100644 --- a/tests/worker/jobs/pipeline_management/test_start_pipeline.py +++ b/tests/worker/jobs/pipeline_management/test_start_pipeline.py @@ -1,6 +1,11 @@ -from unittest.mock import call, patch +# ruff: noqa: E402 import pytest + +pytest.importorskip("arq") + +from unittest.mock import call, patch + from sqlalchemy import select from mavedb.lib.exceptions import PipelineNotFoundError diff --git a/tests/worker/jobs/utils/test_setup.py b/tests/worker/jobs/utils/test_setup.py index 096abd2d..70c40759 100644 --- a/tests/worker/jobs/utils/test_setup.py +++ b/tests/worker/jobs/utils/test_setup.py @@ -1,7 +1,11 @@ -from unittest.mock import Mock +# ruff: noqa: E402 import pytest +pytest.importorskip("arq") + +from unittest.mock import Mock + from mavedb.models.job_run import JobRun from mavedb.worker.jobs.utils.setup import validate_job_params diff --git a/tests/worker/jobs/variant_processing/test_creation.py b/tests/worker/jobs/variant_processing/test_creation.py index dadb74db..66e64c85 100644 --- a/tests/worker/jobs/variant_processing/test_creation.py +++ b/tests/worker/jobs/variant_processing/test_creation.py @@ -1,8 +1,12 @@ -import math -from unittest.mock import ANY, MagicMock, call, patch +# ruff: noqa: E402 import pytest +pytest.importorskip("arq") + +import math +from unittest.mock import ANY, MagicMock, call, patch + from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus from mavedb.models.enums.mapping_state import MappingState from mavedb.models.enums.processing_state import ProcessingState diff --git a/tests/worker/jobs/variant_processing/test_mapping.py b/tests/worker/jobs/variant_processing/test_mapping.py index 79e763f0..5546f4d7 100644 --- a/tests/worker/jobs/variant_processing/test_mapping.py +++ b/tests/worker/jobs/variant_processing/test_mapping.py @@ -1,7 +1,12 @@ +# ruff: noqa: E402 + +import pytest + +pytest.importorskip("arq") + from asyncio.unix_events import _UnixSelectorEventLoop from unittest.mock import MagicMock, call, patch -import pytest from sqlalchemy.exc import NoResultFound from mavedb.lib.exceptions import ( diff --git a/tests/worker/lib/decorators/test_pipeline_management.py b/tests/worker/lib/decorators/test_pipeline_management.py index dcd5862c..0cfd4a69 100644 --- a/tests/worker/lib/decorators/test_pipeline_management.py +++ b/tests/worker/lib/decorators/test_pipeline_management.py @@ -7,8 +7,6 @@ import pytest -from mavedb.worker.lib.managers.job_manager import JobManager - pytest.importorskip("arq") # Skip tests if arq is not installed import asyncio @@ -20,6 +18,7 @@ from mavedb.models.job_run import JobRun from mavedb.models.pipeline import Pipeline from mavedb.worker.lib.decorators.pipeline_management import with_pipeline_management +from mavedb.worker.lib.managers.job_manager import JobManager from mavedb.worker.lib.managers.pipeline_manager import PipelineManager from tests.helpers.transaction_spy import TransactionSpy diff --git a/tests/worker/lib/managers/test_job_manager.py b/tests/worker/lib/managers/test_job_manager.py index e9a11954..ad6b6ef1 100644 --- a/tests/worker/lib/managers/test_job_manager.py +++ b/tests/worker/lib/managers/test_job_manager.py @@ -8,8 +8,6 @@ import pytest -from mavedb.lib.logging.context import format_raised_exception_info_as_dict - pytest.importorskip("arq") import re @@ -19,6 +17,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session +from mavedb.lib.logging.context import format_raised_exception_info_as_dict from mavedb.models.enums.job_pipeline import FailureCategory, JobStatus from mavedb.models.job_run import JobRun from mavedb.worker.lib.managers.constants import ( diff --git a/tests/worker/lib/managers/test_utils.py b/tests/worker/lib/managers/test_utils.py index fdb46e40..eb5adb81 100644 --- a/tests/worker/lib/managers/test_utils.py +++ b/tests/worker/lib/managers/test_utils.py @@ -1,5 +1,9 @@ +# ruff: noqa: E402 + import pytest +pytest.importorskip("arq") + from mavedb.models.enums.job_pipeline import DependencyType, JobStatus from mavedb.worker.lib.managers.constants import COMPLETED_JOB_STATUSES from mavedb.worker.lib.managers.utils import ( From da0e2ceb4fa694f0b8e469a13cd129688ebaf62d Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Thu, 22 Jan 2026 17:35:20 -0800 Subject: [PATCH 52/70] fix: remove ga4gh packages from server group --- poetry.lock | 11 ++--------- pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/poetry.lock b/poetry.lock index 0fc68c24..2bd65bd7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -180,7 +180,6 @@ files = [ {file = "attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3"}, {file = "attrs-25.3.0.tar.gz", hash = "sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b"}, ] -markers = {main = "extra == \"server\""} [package.extras] benchmark = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] @@ -278,7 +277,6 @@ description = "miscellaneous simple bioinformatics utilities and lookup tables" optional = false python-versions = ">=3.10" groups = ["main"] -markers = "extra == \"server\"" files = [ {file = "bioutils-0.6.1-py3-none-any.whl", hash = "sha256:9928297331b9fc0a4fd4235afdef9a80a0916d8b5c2811ab781bded0dad4b9b6"}, {file = "bioutils-0.6.1.tar.gz", hash = "sha256:6ad7a9b6da73beea798a935499339d8b60a434edc37dfc803474d2e93e0e64aa"}, @@ -814,7 +812,6 @@ description = "Canonical JSON" optional = false python-versions = ">=3.7" groups = ["main"] -markers = "extra == \"server\"" files = [ {file = "canonicaljson-2.0.0-py3-none-any.whl", hash = "sha256:c38a315de3b5a0532f1ec1f9153cd3d716abfc565a558d00a4835428a34fca5b"}, {file = "canonicaljson-2.0.0.tar.gz", hash = "sha256:e2fdaef1d7fadc5d9cb59bd3d0d41b064ddda697809ac4325dced721d12f113f"}, @@ -1546,7 +1543,6 @@ description = "GA4GH Categorical Variation Representation (Cat-VRS) reference im optional = false python-versions = ">=3.10" groups = ["main"] -markers = "extra == \"server\"" files = [ {file = "ga4gh_cat_vrs-0.7.1-py3-none-any.whl", hash = "sha256:549e726182d9fdc28d049b9adc6a8c65189bbade06b2ceed8cb20a35cbdefc45"}, {file = "ga4gh_cat_vrs-0.7.1.tar.gz", hash = "sha256:ac8d11ea5f474e8a9745107673d4e8b6949819ccdc9debe2ab8ad8e5f853f87c"}, @@ -1568,7 +1564,6 @@ description = "GA4GH Variant Annotation (VA) reference implementation" optional = false python-versions = ">=3.10" groups = ["main"] -markers = "extra == \"server\"" files = [ {file = "ga4gh_va_spec-0.4.2-py3-none-any.whl", hash = "sha256:c165a96dfa225845b5d63740d3ad40c9f2dcb26808cf759b73bc122a68a9a60e"}, {file = "ga4gh_va_spec-0.4.2.tar.gz", hash = "sha256:13eda6a8cfc7a2baa395e33d17e3296c2ec1c63ec85fe38085751c112cf1c902"}, @@ -1591,7 +1586,6 @@ description = "GA4GH Variation Representation Specification (VRS) reference impl optional = false python-versions = ">=3.10" groups = ["main"] -markers = "extra == \"server\"" files = [ {file = "ga4gh_vrs-2.1.3-py3-none-any.whl", hash = "sha256:15b20363d9d4a4604be0930b41b14c9b4e6dc15a6e8be813544f0775b873bc5b"}, {file = "ga4gh_vrs-2.1.3.tar.gz", hash = "sha256:48af6de1eb40e00aa68ed5a935061917b4017468ef366e8e68bbbc17ffaa60f3"}, @@ -3875,7 +3869,6 @@ files = [ {file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"}, {file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"}, ] -markers = {main = "extra == \"server\""} [package.extras] check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""] @@ -4854,9 +4847,9 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more_it type = ["pytest-mypy"] [extras] -server = ["alembic", "alembic-utils", "arq", "authlib", "biocommons", "boto3", "cdot", "cryptography", "fastapi", "ga4gh-va-spec", "hgvs", "orcid", "psycopg2", "pyathena", "python-jose", "python-multipart", "requests", "slack-sdk", "starlette", "starlette-context", "uvicorn", "watchtower"] +server = ["alembic", "alembic-utils", "arq", "authlib", "biocommons", "boto3", "cdot", "cryptography", "fastapi", "hgvs", "orcid", "psycopg2", "pyathena", "python-jose", "python-multipart", "requests", "slack-sdk", "starlette", "starlette-context", "uvicorn", "watchtower"] [metadata] lock-version = "2.1" python-versions = "^3.11" -content-hash = "4be857a91855622d543b3eb008624fc9bb57b605d17e5aec00a0e1c8bef5ed3c" +content-hash = "452148c0c5ee1b9cbb12087a27c8d6d3e650ad1eb4fed99b4470b4db16f041c6" diff --git a/pyproject.toml b/pyproject.toml index 4bf083e7..cc4e938c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,7 +89,7 @@ SQLAlchemy = { extras = ["mypy"], version = "~2.0.0" } [tool.poetry.extras] -server = ["alembic", "alembic-utils", "arq", "authlib", "biocommons", "boto3", "cdot", "cryptography", "fastapi", "hgvs", "ga4gh-va-spec", "orcid", "psycopg2", "python-jose", "python-multipart", "pyathena", "requests", "starlette", "starlette-context", "slack-sdk", "uvicorn", "watchtower"] +server = ["alembic", "alembic-utils", "arq", "authlib", "biocommons", "boto3", "cdot", "cryptography", "fastapi", "hgvs", "orcid", "psycopg2", "python-jose", "python-multipart", "pyathena", "requests", "starlette", "starlette-context", "slack-sdk", "uvicorn", "watchtower"] [tool.mypy] From 1abe4c6fdff88b14b3cf94db2d8765b5f1b23755 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Thu, 29 Jan 2026 11:54:48 -0800 Subject: [PATCH 53/70] docs: minimal developer docs via copilot for worker jobs --- src/mavedb/worker/README.md | 12 ++++++ src/mavedb/worker/best_practices.md | 31 +++++++++++++++ src/mavedb/worker/job_decorators.md | 48 ++++++++++++++++++++++++ src/mavedb/worker/job_managers.md | 36 ++++++++++++++++++ src/mavedb/worker/job_registry.md | 39 +++++++++++++++++++ src/mavedb/worker/jobs/jobs.md | 1 - src/mavedb/worker/jobs_overview.md | 32 ++++++++++++++++ src/mavedb/worker/pipeline_management.md | 29 ++++++++++++++ 8 files changed, 227 insertions(+), 1 deletion(-) create mode 100644 src/mavedb/worker/README.md create mode 100644 src/mavedb/worker/best_practices.md create mode 100644 src/mavedb/worker/job_decorators.md create mode 100644 src/mavedb/worker/job_managers.md create mode 100644 src/mavedb/worker/job_registry.md delete mode 100644 src/mavedb/worker/jobs/jobs.md create mode 100644 src/mavedb/worker/jobs_overview.md create mode 100644 src/mavedb/worker/pipeline_management.md diff --git a/src/mavedb/worker/README.md b/src/mavedb/worker/README.md new file mode 100644 index 00000000..45745205 --- /dev/null +++ b/src/mavedb/worker/README.md @@ -0,0 +1,12 @@ +# ARQ Worker Jobs Developer Documentation + +This documentation provides an overview and detailed guidance for developers working with the ARQ worker jobs, decorators, and managers in the MaveDB API codebase. It is organized into the following sections: + +- [Job System Overview](jobs_overview.md) +- [Job Decorators](job_decorators.md) +- [Job Managers](job_managers.md) +- [Pipeline Management](pipeline_management.md) +- [Job Registry and Configuration](job_registry.md) +- [Best Practices & Patterns](best_practices.md) + +Each section is a separate markdown file for clarity and maintainability. Start with `jobs_overview.md` for a high-level understanding, then refer to the other files for implementation details and usage patterns. diff --git a/src/mavedb/worker/best_practices.md b/src/mavedb/worker/best_practices.md new file mode 100644 index 00000000..65301284 --- /dev/null +++ b/src/mavedb/worker/best_practices.md @@ -0,0 +1,31 @@ +# Best Practices & Patterns + +## General Principles +- Use decorators to ensure all jobs are tracked, auditable, and robust to errors. +- Keep job functions focused and stateless; use the database and JobManager for state. +- Prefer async functions for jobs to maximize concurrency. +- Use the appropriate manager (JobManager or PipelineManager) for state transitions and coordination. +- Write unit tests for job logic and integration tests for job orchestration. + +## Error Handling +- Always handle exceptions at the job or pipeline boundary. Legacy score set and mapping jobs track status at the +item level, but this will be remedied in a future update. +- Use custom exception types for clarity and recovery strategies. +- Log all errors with sufficient context for debugging and audit. + +## Job Design +- Use `with_guaranteed_job_run_record` for standalone jobs that require audit. +- Use `with_pipeline_management` for jobs that are part of a pipeline. +- Avoid side effects outside the job context; use dependency injection for testability. + +## Testing +- Mock external services in unit tests. +- Use integration tests to verify job and pipeline orchestration. +- Test error paths and recovery logic. + +## Documentation +- Document each job's purpose, parameters, and expected side effects. +- Update the registry and README when adding new jobs. + +## References +- See the other markdown files in this directory for detailed usage and examples. diff --git a/src/mavedb/worker/job_decorators.md b/src/mavedb/worker/job_decorators.md new file mode 100644 index 00000000..c3511b07 --- /dev/null +++ b/src/mavedb/worker/job_decorators.md @@ -0,0 +1,48 @@ +# Job Decorators + +Job decorators provide lifecycle management, error handling, and audit guarantees for ARQ worker jobs. They are essential for ensuring that jobs are tracked, failures are handled robustly, and pipelines are coordinated correctly. + +## Key Decorators + +### `with_guaranteed_job_run_record(job_type)` +- Ensures a `JobRun` record is created and persisted before job execution begins. +- Should be applied before any job management decorators. +- Not supported for pipeline jobs. +- Example: + ```python + @with_guaranteed_job_run_record("cron_job") + @with_job_management + async def my_cron_job(ctx, ...): + ... + ``` + +### `with_job_management` +- Adds automatic job lifecycle management to ARQ worker functions. +- Tracks job start/completion, injects a `JobManager` for progress and state updates, and handles errors robustly. +- Supports both sync and async functions. +- Example: + ```python + @with_job_management + async def my_job(ctx, job_manager: JobManager): + job_manager.update_progress(10, message="Starting work") + ... + ``` + +### `with_pipeline_management` +- Adds pipeline lifecycle management to jobs that are part of a pipeline. +- Coordinates the pipeline after the job completes (success or failure). +- Built on top of `with_job_management`. +- Example: + ```python + @with_pipeline_management + async def my_pipeline_job(ctx, ...): + ... + ``` + +## Stacking Order +- If using both `with_guaranteed_job_run_record` and `with_job_management`, always apply `with_guaranteed_job_run_record` first. +- For pipeline jobs, use only `with_pipeline_management` (which includes job management). + +## See Also +- [Job Managers](job_managers.md) +- [Pipeline Management](pipeline_management.md) diff --git a/src/mavedb/worker/job_managers.md b/src/mavedb/worker/job_managers.md new file mode 100644 index 00000000..b099b4de --- /dev/null +++ b/src/mavedb/worker/job_managers.md @@ -0,0 +1,36 @@ +# Job Managers + +Job managers are responsible for the lifecycle, state transitions, and progress tracking of jobs and pipelines. They provide atomic operations, robust error handling, and ensure data consistency. + +## JobManager +- Manages the lifecycle of a single job (start, progress, success, failure, retry, cancel). +- Ensures atomic state transitions and safe rollback on failure. +- Does not commit database changes (only flushes); the caller is responsible for commits. +- Handles progress tracking, retry logic, and session cleanup. +- Example usage: + ```python + manager = JobManager(db, redis, job_id=123) + manager.start_job() + manager.update_progress(25, message="Starting validation") + manager.succeed_job(result={"count": 100}) + ``` + +## PipelineManager +- Coordinates pipeline execution, manages job dependencies, and updates pipeline status. +- Handles pausing, unpausing, and cancellation of pipelines. +- Uses the same exception hierarchy as JobManager for consistency. +- Example usage: + ```python + pipeline_manager = PipelineManager(db, redis, pipeline_id=456) + await pipeline_manager.coordinate_pipeline() + new_status = pipeline_manager.transition_pipeline_status() + cancelled_count = pipeline_manager.cancel_remaining_jobs(reason="Dependency failed") + ``` + +## Exception Handling +- Both managers use custom exceptions for database errors, state errors, and coordination errors. +- Always handle exceptions at the job or pipeline boundary to ensure robust recovery and logging. + +## See Also +- [Job Decorators](job_decorators.md) +- [Pipeline Management](pipeline_management.md) diff --git a/src/mavedb/worker/job_registry.md b/src/mavedb/worker/job_registry.md new file mode 100644 index 00000000..c470c1ed --- /dev/null +++ b/src/mavedb/worker/job_registry.md @@ -0,0 +1,39 @@ +# Job Registry and Configuration + +All ARQ worker jobs must be registered for execution and scheduling. The registry provides a centralized list of available jobs and cron jobs for ARQ configuration. + +## Job Registry +- Located in `jobs/registry.py`. +- Lists all job functions in `BACKGROUND_FUNCTIONS` for ARQ worker discovery. +- Defines scheduled (cron) jobs in `BACKGROUND_CRONJOBS` using ARQ's `cron` utility. + +## Example +```python +from mavedb.worker.jobs.data_management import refresh_materialized_views +from mavedb.worker.jobs.external_services import submit_score_set_mappings_to_car + +BACKGROUND_FUNCTIONS = [ + refresh_materialized_views, + submit_score_set_mappings_to_car, + ... +] + +BACKGROUND_CRONJOBS = [ + cron( + refresh_materialized_views, + name="refresh_all_materialized_views", + hour=20, + minute=0, + keep_result=timedelta(minutes=2).total_seconds(), + ), +] +``` + +## Adding a New Job +1. Implement the job function in the appropriate submodule. +2. Add the function to `BACKGROUND_FUNCTIONS` in `registry.py`. +3. (Optional) Add a cron job to `BACKGROUND_CRONJOBS` if scheduling is needed. + +## See Also +- [Job System Overview](jobs_overview.md) +- [Best Practices](best_practices.md) diff --git a/src/mavedb/worker/jobs/jobs.md b/src/mavedb/worker/jobs/jobs.md deleted file mode 100644 index 30404ce4..00000000 --- a/src/mavedb/worker/jobs/jobs.md +++ /dev/null @@ -1 +0,0 @@ -TODO \ No newline at end of file diff --git a/src/mavedb/worker/jobs_overview.md b/src/mavedb/worker/jobs_overview.md new file mode 100644 index 00000000..ec14b421 --- /dev/null +++ b/src/mavedb/worker/jobs_overview.md @@ -0,0 +1,32 @@ +# Job System Overview + +The ARQ worker job system in MaveDB provides a robust, scalable, and auditable framework for background processing, data management, and integration with external services. It is designed to support both simple jobs and complex pipelines with dependency management, error handling, and progress tracking. + +## Key Concepts + +- **Job**: A discrete unit of work, typically implemented as an async function, executed by the ARQ worker. +- **Pipeline**: A sequence of jobs with defined dependencies, managed as a single workflow. +- **JobRun**: A database record tracking the execution state, progress, and results of a job. +- **JobManager**: A class responsible for managing the lifecycle and state transitions of a single job. +- **PipelineManager**: A class responsible for coordinating pipelines, managing dependencies, and updating pipeline status. +- **Decorators**: Utilities that add lifecycle management, error handling, and audit guarantees to job functions. + +## Directory Structure + +- `jobs/` — Entrypoints and registry for all ARQ worker jobs. +- `jobs/data_management/`, `jobs/external_services/`, `jobs/variant_processing/`, etc. — Job implementations grouped by domain. +- `lib/decorators/` — Decorators for job and pipeline management. +- `lib/managers/` — JobManager, PipelineManager, and related utilities. + +## Job Lifecycle + +1. **Job Registration**: All available jobs are registered in `jobs/registry.py` for ARQ configuration. +2. **Job Execution**: Jobs are executed by the ARQ worker, with decorators ensuring audit, error handling, and state management. +3. **State Tracking**: Each job run is tracked in the database via a `JobRun` record. +4. **Pipeline Coordination**: For jobs that are part of a pipeline, the `PipelineManager` coordinates dependencies and status. + +## When to Add a Job +- When you need background processing, integration with external APIs, or scheduled/cron tasks. +- When you want robust error handling, progress tracking, and auditability for long-running or critical operations. + +See the following sections for details on decorators, managers, and best practices. diff --git a/src/mavedb/worker/pipeline_management.md b/src/mavedb/worker/pipeline_management.md new file mode 100644 index 00000000..02ee5694 --- /dev/null +++ b/src/mavedb/worker/pipeline_management.md @@ -0,0 +1,29 @@ +# Pipeline Management + +Pipeline management in the ARQ worker system allows for the orchestration of complex workflows composed of multiple dependent jobs. Pipelines are coordinated using the `PipelineManager` and the `with_pipeline_management` decorator. + +## Key Concepts +- **Pipeline**: A collection of jobs with defined dependencies and a shared execution context. +- **PipelineManager**: Handles pipeline status, job dependencies, pausing/unpausing, and cancellation. +- **with_pipeline_management**: Decorator that ensures pipeline coordination after job completion. + +## Usage Patterns +- Use pipelines for workflows that require multiple jobs to run in sequence or with dependencies. +- Each job in a pipeline should be decorated with `with_pipeline_management`. +- Pipelines are defined and started outside the decorator; the decorator only coordinates after job completion. + +## Example +```python +@with_pipeline_management +async def validate_and_map_variants(ctx, ...): + ... +``` + +## Features +- Automatic pipeline status updates +- Dependency management and job coordination +- Robust error handling and logging + +## See Also +- [Job Managers](job_managers.md) +- [Job Decorators](job_decorators.md) From 797ea39c6d8248b1e2f1809de74fd45c6ee9762b Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Thu, 29 Jan 2026 12:07:24 -0800 Subject: [PATCH 54/70] fix: mypy typing --- src/mavedb/scripts/populate_mapped_variants.py | 5 ++++- src/mavedb/worker/lib/decorators/job_management.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/mavedb/scripts/populate_mapped_variants.py b/src/mavedb/scripts/populate_mapped_variants.py index 72b4b449..759026bf 100644 --- a/src/mavedb/scripts/populate_mapped_variants.py +++ b/src/mavedb/scripts/populate_mapped_variants.py @@ -63,7 +63,10 @@ async def populate_mapped_variant_data(urns: Sequence[Optional[str]], all: bool, db.flush() logger.info(f"Submitted job run ID {job_run.id} for score set ID {score_set_id}.") - await map_variants_for_score_set(ctx, job_run.id) + # Despite accepting a third argument for the job manager and MyPy expecting it, this + # argument will be injected automatically by the decorator. We only need to pass + # the ctx and job_run.id here for the decorator to generate the job manager. + await map_variants_for_score_set(ctx, job_run.id) # type: ignore[call-arg] if __name__ == "__main__": diff --git a/src/mavedb/worker/lib/decorators/job_management.py b/src/mavedb/worker/lib/decorators/job_management.py index 534c0336..3829cdc6 100644 --- a/src/mavedb/worker/lib/decorators/job_management.py +++ b/src/mavedb/worker/lib/decorators/job_management.py @@ -122,7 +122,7 @@ async def _execute_managed_job(func: Callable[..., Awaitable[JobResultData]], ar # Move job to final state based on result if result.get("status") == "failed" or result.get("exception"): # Exception info should always be present for failed jobs - job_manager.fail_job(result=result, error=result["exception"]) # type: ignore[keyword-arg] + job_manager.fail_job(result=result, error=result["exception"]) # type: ignore[arg-type] elif result.get("status") == "skipped": job_manager.skip_job(result=result) else: From a1d3150c6ffbe3c285e3300cfd92b82c91ab4174 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Thu, 29 Jan 2026 12:51:21 -0800 Subject: [PATCH 55/70] fix: test attempting to connect via socket to athena --- .../worker/jobs/external_services/test_gnomad.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/worker/jobs/external_services/test_gnomad.py b/tests/worker/jobs/external_services/test_gnomad.py index 16a88f5c..40a7f115 100644 --- a/tests/worker/jobs/external_services/test_gnomad.py +++ b/tests/worker/jobs/external_services/test_gnomad.py @@ -79,6 +79,7 @@ async def test_link_gnomad_variants_no_gnomad_matches( mock_worker_ctx, sample_link_gnomad_variants_run, setup_sample_variants_with_caid, + athena_engine, ): """Test linking gnomAD variants when no gnomAD variants match the CAIDs.""" @@ -88,6 +89,7 @@ async def test_link_gnomad_variants_no_gnomad_matches( "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", return_value={}, ), + patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine), ): result = await link_gnomad_variants( mock_worker_ctx, @@ -106,6 +108,7 @@ async def test_link_gnomad_variants_call_linking_method( mock_worker_ctx, sample_link_gnomad_variants_run, setup_sample_variants_with_caid, + athena_engine, ): """Test that the linking method is called when gnomAD variants match CAIDs.""" @@ -119,6 +122,7 @@ async def test_link_gnomad_variants_call_linking_method( "mavedb.worker.jobs.external_services.gnomad.link_gnomad_variants_to_mapped_variants", return_value=1, ) as mock_linking_method, + patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine), ): result = await link_gnomad_variants( mock_worker_ctx, @@ -138,6 +142,7 @@ async def test_link_gnomad_variants_updates_progress( mock_worker_ctx, sample_link_gnomad_variants_run, setup_sample_variants_with_caid, + athena_engine, ): """Test that progress updates are made during the linking process.""" @@ -151,6 +156,7 @@ async def test_link_gnomad_variants_updates_progress( "mavedb.worker.jobs.external_services.gnomad.link_gnomad_variants_to_mapped_variants", return_value=1, ), + patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine), ): result = await link_gnomad_variants( mock_worker_ctx, @@ -176,11 +182,15 @@ async def test_link_gnomad_variants_propagates_exceptions( mock_worker_ctx, sample_link_gnomad_variants_run, setup_sample_variants_with_caid, + athena_engine, ): """Test that exceptions during the linking process are propagated.""" - with patch( - "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", - side_effect=Exception("Test exception"), + with ( + patch( + "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", + side_effect=Exception("Test exception"), + ), + patch("mavedb.worker.jobs.external_services.gnomad.athena.engine", athena_engine), ): with pytest.raises(Exception) as exc_info: await link_gnomad_variants( From fcccb9aee108678b6fbe58d5f259f9c2af7f8c21 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Thu, 29 Jan 2026 12:52:38 -0800 Subject: [PATCH 56/70] feat: add Slack error notifications to job/pipeline decorators - Integrated `send_slack_error` calls in multiple test cases across different modules to ensure error notifications are sent when exceptions occur. - Updated tests for materialized views, published variants, Clingen submissions, GnomAD linking, UniProt mappings, pipeline management, and variant processing to assert that Slack notifications are triggered on failures. - Enhanced error handling in job management decorators to include Slack notifications for missing context and job failures. --- .../worker/lib/decorators/job_management.py | 26 ++++-- .../lib/decorators/pipeline_management.py | 24 +++-- .../worker/lib/managers/pipeline_manager.py | 6 +- .../worker/jobs/data_management/test_views.py | 10 ++- .../jobs/external_services/test_clingen.py | 16 ++++ .../jobs/external_services/test_gnomad.py | 6 ++ .../jobs/external_services/test_uniprot.py | 60 +++++++++---- .../test_start_pipeline.py | 16 ++-- .../jobs/variant_processing/test_creation.py | 12 +++ .../jobs/variant_processing/test_mapping.py | 14 +++ .../lib/decorators/test_job_management.py | 87 ++++++++++++------- .../decorators/test_pipeline_management.py | 86 +++++++++++------- 12 files changed, 265 insertions(+), 98 deletions(-) diff --git a/src/mavedb/worker/lib/decorators/job_management.py b/src/mavedb/worker/lib/decorators/job_management.py index 3829cdc6..5b8a8ca0 100644 --- a/src/mavedb/worker/lib/decorators/job_management.py +++ b/src/mavedb/worker/lib/decorators/job_management.py @@ -13,6 +13,7 @@ from arq import ArqRedis from sqlalchemy.orm import Session +from mavedb.lib.slack import send_slack_error from mavedb.models.enums.job_pipeline import JobStatus from mavedb.worker.lib.decorators.utils import ensure_ctx, ensure_job_id, ensure_session_ctx, is_test_mode from mavedb.worker.lib.managers import JobManager @@ -97,13 +98,18 @@ async def _execute_managed_job(func: Callable[..., Awaitable[JobResultData]], ar Raises: Exception: Re-raises any exception after proper job failure tracking """ - ctx = ensure_ctx(args) - db_session: Session = ctx["db"] - job_id = ensure_job_id(args) + try: + ctx = ensure_ctx(args) + db_session: Session = ctx["db"] + job_id = ensure_job_id(args) - if "redis" not in ctx: - raise ValueError("Redis connection not found in job context") - redis_pool: ArqRedis = ctx["redis"] + if "redis" not in ctx: + raise ValueError("Redis connection not found in job context") + redis_pool: ArqRedis = ctx["redis"] + except Exception as e: + logger.critical(f"Failed to initialize job management context: {e}") + send_slack_error(e) + raise try: # Initialize JobManager @@ -123,6 +129,8 @@ async def _execute_managed_job(func: Callable[..., Awaitable[JobResultData]], ar if result.get("status") == "failed" or result.get("exception"): # Exception info should always be present for failed jobs job_manager.fail_job(result=result, error=result["exception"]) # type: ignore[arg-type] + send_slack_error(result["exception"]) + elif result.get("status") == "skipped": job_manager.skip_job(result=result) else: @@ -161,13 +169,15 @@ async def _execute_managed_job(func: Callable[..., Awaitable[JobResultData]], ar except Exception as inner_e: logger.critical(f"Failed to mark job {job_id} as failed: {inner_e}") - # TODO: Notification hooks + # Notify separately about inner failure, which affects job persistence + send_slack_error(inner_e) # Re-raise the outer exception immediately to prevent duplicate notifications finally: logger.error(f"Job {job_id} failed: {e}") - # TODO: Notification hooks + # Notify about the original exception + send_slack_error(e) # Swallow the exception after alerting so ARQ can finish the job cleanly and log results. # We don't mind that we lose ARQs built in job marking, since we perform our own job diff --git a/src/mavedb/worker/lib/decorators/pipeline_management.py b/src/mavedb/worker/lib/decorators/pipeline_management.py index ac35ce38..5bcf3a15 100644 --- a/src/mavedb/worker/lib/decorators/pipeline_management.py +++ b/src/mavedb/worker/lib/decorators/pipeline_management.py @@ -14,6 +14,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session +from mavedb.lib.slack import send_slack_error from mavedb.models.enums.job_pipeline import PipelineStatus from mavedb.models.job_run import JobRun from mavedb.worker.lib.decorators import with_job_management @@ -97,13 +98,18 @@ async def _execute_managed_pipeline(func: Callable[..., Awaitable[JobResultData] Raises: Exception: Propagates any exception raised during function execution. """ - ctx = ensure_ctx(args) - job_id = ensure_job_id(args) - db_session: Session = ctx["db"] + try: + ctx = ensure_ctx(args) + job_id = ensure_job_id(args) + db_session: Session = ctx["db"] - if "redis" not in ctx: - raise ValueError("Redis connection not found in pipeline context") - redis_pool: ArqRedis = ctx["redis"] + if "redis" not in ctx: + raise ValueError("Redis connection not found in pipeline context") + redis_pool: ArqRedis = ctx["redis"] + except Exception as e: + logger.critical(f"Failed to initialize pipeline management context: {e}") + send_slack_error(e) + raise pipeline_manager = None pipeline_id = None @@ -164,6 +170,9 @@ async def _execute_managed_pipeline(func: Callable[..., Awaitable[JobResultData] f"Unable to perform cleanup coordination on pipeline {pipeline_id} associated with job {job_id} after error: {inner_e}" ) + # Notify about the internal error, as it indicates a serious problem with pipeline state persistence + send_slack_error(inner_e) + # No further work here. We can rely on the notification hooks below to alert on the original failure # and should allow result generation to proceed as normal so the job can be logged. finally: @@ -172,7 +181,8 @@ async def _execute_managed_pipeline(func: Callable[..., Awaitable[JobResultData] # Build job result data for failure result = {"status": "failed", "data": {}, "exception": e} - # TODO: Notification hooks + # Notify about the original failure + send_slack_error(e) # Swallow the exception after alerting so ARQ can finish the job cleanly and log results. # We don't mind that we lose ARQs built in job marking, since we perform our own job diff --git a/src/mavedb/worker/lib/managers/pipeline_manager.py b/src/mavedb/worker/lib/managers/pipeline_manager.py index d5b69b80..eda91c61 100644 --- a/src/mavedb/worker/lib/managers/pipeline_manager.py +++ b/src/mavedb/worker/lib/managers/pipeline_manager.py @@ -42,6 +42,7 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import Session +from mavedb.lib.slack import send_slack_message from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus from mavedb.models.job_dependency import JobDependency from mavedb.models.job_run import JobRun @@ -312,7 +313,10 @@ def transition_pipeline_status(self) -> PipelineStatus: else: new_status = PipelineStatus.PARTIAL logger.warning(f"Inconsistent job counts detected for pipeline {self.pipeline_id}: {status_counts}") - # TODO: Notification hooks + send_slack_message( + f"Inconsistent job counts detected for pipeline {self.pipeline_id}: {status_counts}" + ) + else: new_status = PipelineStatus.CANCELLED diff --git a/tests/worker/jobs/data_management/test_views.py b/tests/worker/jobs/data_management/test_views.py index d5011ec9..26ab0426 100644 --- a/tests/worker/jobs/data_management/test_views.py +++ b/tests/worker/jobs/data_management/test_views.py @@ -85,8 +85,10 @@ async def test_refresh_materialized_views_handles_exceptions(self, standalone_wo side_effect=Exception("Test exception during refresh"), ), TransactionSpy.spy(session, expect_rollback=True, expect_flush=True, expect_commit=True), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): result = await refresh_materialized_views(standalone_worker_context) + mock_send_slack_error.assert_called_once() job = session.execute( select(JobRun).where(JobRun.job_function == "refresh_materialized_views") @@ -235,8 +237,10 @@ async def test_refresh_published_variants_view_handles_exceptions( side_effect=Exception("Test exception during published variants view refresh"), ), TransactionSpy.spy(session, expect_rollback=True, expect_flush=True, expect_commit=True), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): result = await refresh_published_variants_view(standalone_worker_context, setup_refresh_job_run.id) + mock_send_slack_error.assert_called_once() session.refresh(setup_refresh_job_run) assert setup_refresh_job_run.status == JobStatus.FAILED @@ -252,8 +256,12 @@ async def test_refresh_published_variants_view_requires_params( session.add(setup_refresh_job_run) session.commit() - with TransactionSpy.spy(session, expect_rollback=True, expect_flush=True, expect_commit=True): + with ( + TransactionSpy.spy(session, expect_rollback=True, expect_flush=True, expect_commit=True), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): result = await refresh_published_variants_view(standalone_worker_context, setup_refresh_job_run.id) + mock_send_slack_error.assert_called_once() session.refresh(setup_refresh_job_run) assert setup_refresh_job_run.status == JobStatus.FAILED diff --git a/tests/worker/jobs/external_services/test_clingen.py b/tests/worker/jobs/external_services/test_clingen.py index 26fb88c9..365f9483 100644 --- a/tests/worker/jobs/external_services/test_clingen.py +++ b/tests/worker/jobs/external_services/test_clingen.py @@ -754,11 +754,13 @@ async def test_submit_score_set_mappings_to_car_no_submission_endpoint( with ( patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", ""), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): result = await submit_score_set_mappings_to_car( standalone_worker_context, submit_score_set_mappings_to_car_sample_job_run.id ) + mock_send_slack_error.assert_called_once() assert result["status"] == "failed" assert isinstance(result["exception"], ValueError) @@ -947,11 +949,13 @@ async def test_submit_score_set_mappings_to_car_propagates_exception_to_decorato ), patch("mavedb.worker.jobs.external_services.clingen.CAR_SUBMISSION_ENDPOINT", "http://fake-endpoint"), patch("mavedb.worker.jobs.external_services.clingen.CLIN_GEN_SUBMISSION_ENABLED", True), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): result = await submit_score_set_mappings_to_car( standalone_worker_context, submit_score_set_mappings_to_car_sample_job_run.id ) + mock_send_slack_error.assert_called_once() assert result["status"] == "exception" assert isinstance(result["exception"], Exception) assert str(result["exception"]) == "ClinGen service error" @@ -1143,6 +1147,7 @@ async def test_submit_score_set_mappings_to_car_with_arq_context_exception_handl "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", side_effect=Exception("ClinGen service error"), ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): await arq_redis.enqueue_job( "submit_score_set_mappings_to_car", submit_score_set_mappings_to_car_sample_job_run.id @@ -1150,6 +1155,7 @@ async def test_submit_score_set_mappings_to_car_with_arq_context_exception_handl await arq_worker.async_run() await arq_worker.run_check() + mock_send_slack_error.assert_called_once() # Verify the job status is updated in the database session.refresh(submit_score_set_mappings_to_car_sample_job_run) assert submit_score_set_mappings_to_car_sample_job_run.status == JobStatus.FAILED @@ -1200,6 +1206,7 @@ async def test_submit_score_set_mappings_to_car_with_arq_context_exception_handl "mavedb.worker.jobs.external_services.clingen.ClinGenAlleleRegistryService.dispatch_submissions", side_effect=Exception("ClinGen service error"), ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): await arq_redis.enqueue_job( "submit_score_set_mappings_to_car", submit_score_set_mappings_to_car_sample_job_run_in_pipeline.id @@ -1207,6 +1214,7 @@ async def test_submit_score_set_mappings_to_car_with_arq_context_exception_handl await arq_worker.async_run() await arq_worker.run_check() + mock_send_slack_error.assert_called_once() # Verify the job status is updated in the database session.refresh(submit_score_set_mappings_to_car_sample_job_run_in_pipeline) assert submit_score_set_mappings_to_car_sample_job_run_in_pipeline.status == JobStatus.FAILED @@ -1701,11 +1709,13 @@ async def test_submit_score_set_mappings_to_ldh_propagates_exception_to_decorato side_effect=Exception("LDH service error"), ), patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): result = await submit_score_set_mappings_to_ldh( standalone_worker_context, submit_score_set_mappings_to_ldh_sample_job_run.id ) + mock_send_slack_error.assert_called_once() assert result["status"] == "exception" assert isinstance(result["exception"], Exception) assert str(result["exception"]) == "LDH service error" @@ -1848,11 +1858,13 @@ async def dummy_submission_failure(*args, **kwargs): return_value=dummy_submission_failure(), ), patch("mavedb.worker.jobs.external_services.clingen.ClinGenLdhService.authenticate", return_value=None), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): result = await submit_score_set_mappings_to_ldh( standalone_worker_context, submit_score_set_mappings_to_ldh_sample_job_run.id ) + mock_send_slack_error.assert_called_once() assert result["status"] == "failed" assert isinstance(result["exception"], LDHSubmissionFailureError) @@ -2201,6 +2213,7 @@ async def test_submit_score_set_mappings_to_ldh_with_arq_context_exception_handl "run_in_executor", side_effect=Exception("LDH service error"), ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): await arq_redis.enqueue_job( "submit_score_set_mappings_to_ldh", submit_score_set_mappings_to_ldh_sample_job_run.id @@ -2208,6 +2221,7 @@ async def test_submit_score_set_mappings_to_ldh_with_arq_context_exception_handl await arq_worker.async_run() await arq_worker.run_check() + mock_send_slack_error.assert_called_once() # Verify no annotation statuses were created annotation_statuses = session.scalars( select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission") @@ -2254,6 +2268,7 @@ async def test_submit_score_set_mappings_to_ldh_with_arq_context_exception_handl "run_in_executor", side_effect=Exception("LDH service error"), ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): await arq_redis.enqueue_job( "submit_score_set_mappings_to_ldh", submit_score_set_mappings_to_ldh_sample_job_run_in_pipeline.id @@ -2261,6 +2276,7 @@ async def test_submit_score_set_mappings_to_ldh_with_arq_context_exception_handl await arq_worker.async_run() await arq_worker.run_check() + mock_send_slack_error.assert_called_once() # Verify no annotation statuses were created annotation_statuses = session.scalars( select(VariantAnnotationStatus).where(VariantAnnotationStatus.annotation_type == "ldh_submission") diff --git a/tests/worker/jobs/external_services/test_gnomad.py b/tests/worker/jobs/external_services/test_gnomad.py index 40a7f115..a3e379e9 100644 --- a/tests/worker/jobs/external_services/test_gnomad.py +++ b/tests/worker/jobs/external_services/test_gnomad.py @@ -355,12 +355,14 @@ async def test_link_gnomad_variants_exceptions_handled_by_decorators( "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", side_effect=Exception("Test exception"), ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): result = await link_gnomad_variants( mock_worker_ctx, sample_link_gnomad_variants_run.id, ) + mock_send_slack_error.assert_called_once() assert result["status"] == "exception" assert isinstance(result["exception"], Exception) @@ -465,11 +467,13 @@ async def test_link_gnomad_variants_with_arq_context_exception_handling_independ "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", side_effect=Exception("Test exception"), ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): await arq_redis.enqueue_job("link_gnomad_variants", sample_link_gnomad_variants_run.id) await arq_worker.async_run() await arq_worker.run_check() + mock_send_slack_error.assert_called_once() # Verify that no gnomAD variants were linked gnomad_variants = session.query(GnomADVariant).all() assert len(gnomad_variants) == 0 @@ -501,11 +505,13 @@ async def test_link_gnomad_variants_with_arq_context_exception_handling_pipeline "mavedb.worker.jobs.external_services.gnomad.gnomad_variant_data_for_caids", side_effect=Exception("Test exception"), ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): await arq_redis.enqueue_job("link_gnomad_variants", sample_link_gnomad_variants_run_pipeline.id) await arq_worker.async_run() await arq_worker.run_check() + mock_send_slack_error.assert_called_once() # Verify that no gnomAD variants were linked gnomad_variants = session.query(GnomADVariant).all() assert len(gnomad_variants) == 0 diff --git a/tests/worker/jobs/external_services/test_uniprot.py b/tests/worker/jobs/external_services/test_uniprot.py index e40371d4..dd9e0990 100644 --- a/tests/worker/jobs/external_services/test_uniprot.py +++ b/tests/worker/jobs/external_services/test_uniprot.py @@ -670,14 +670,18 @@ async def test_submit_uniprot_mapping_jobs_propagates_exceptions( target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} session.commit() - with patch( - "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", - side_effect=Exception("UniProt API failure"), + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + side_effect=Exception("UniProt API failure"), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): result = await submit_uniprot_mapping_jobs_for_score_set( mock_worker_ctx, sample_submit_uniprot_mapping_jobs_run.id ) + mock_send_slack_error.assert_called_once() assert result["status"] == "exception" assert isinstance(result["exception"], Exception) @@ -810,14 +814,18 @@ async def test_submit_uniprot_mapping_jobs_no_dependent_job_raises( target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} session.commit() - with patch( - "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", - return_value="job_12345", + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + return_value="job_12345", + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): result = await submit_uniprot_mapping_jobs_for_score_set( mock_worker_ctx, sample_submit_uniprot_mapping_jobs_run.id ) + mock_send_slack_error.assert_called_once() assert result["status"] == "failed" assert isinstance(result["exception"], UniProtPollingEnqueueError) @@ -964,9 +972,12 @@ async def test_submit_uniprot_mapping_jobs_with_arq_context_exception_handling_i target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} session.commit() - with patch( - "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", - side_effect=Exception("UniProt API failure"), + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + side_effect=Exception("UniProt API failure"), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): await arq_redis.enqueue_job( "submit_uniprot_mapping_jobs_for_score_set", sample_submit_uniprot_mapping_jobs_run.id @@ -974,6 +985,7 @@ async def test_submit_uniprot_mapping_jobs_with_arq_context_exception_handling_i await arq_worker.async_run() await arq_worker.run_check() + mock_send_slack_error.assert_called_once() # Verify that the job metadata contains no submitted jobs session.refresh(sample_submit_uniprot_mapping_jobs_run) assert sample_submit_uniprot_mapping_jobs_run.metadata_.get("submitted_jobs") is None @@ -1007,9 +1019,12 @@ async def test_submit_uniprot_mapping_jobs_with_arq_context_exception_handling_p target_gene.post_mapped_metadata = {"protein": {"sequence_accessions": [VALID_NT_ACCESSION]}} session.commit() - with patch( - "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", - side_effect=Exception("UniProt API failure"), + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.submit_id_mapping", + side_effect=Exception("UniProt API failure"), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): await arq_redis.enqueue_job( "submit_uniprot_mapping_jobs_for_score_set", sample_submit_uniprot_mapping_jobs_run_in_pipeline.id @@ -1017,6 +1032,7 @@ async def test_submit_uniprot_mapping_jobs_with_arq_context_exception_handling_p await arq_worker.async_run() await arq_worker.run_check() + mock_send_slack_error.assert_called_once() # Verify that the job metadata contains no submitted jobs session.refresh(sample_submit_uniprot_mapping_jobs_run_in_pipeline) assert sample_submit_uniprot_mapping_jobs_run_in_pipeline.metadata_.get("submitted_jobs") is None @@ -1688,11 +1704,13 @@ async def test_poll_uniprot_mapping_jobs_no_results( "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", return_value={"results": []}, # minimal response with no results ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): result = await poll_uniprot_mapping_jobs_for_score_set( mock_worker_ctx, sample_polling_job_for_submission_run.id ) + mock_send_slack_error.assert_called_once() assert result["status"] == "exception" assert isinstance(result["exception"], UniprotMappingResultNotFoundError) @@ -1745,11 +1763,13 @@ async def test_poll_uniprot_mapping_jobs_ambiguous_results( ] }, ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): result = await poll_uniprot_mapping_jobs_for_score_set( mock_worker_ctx, sample_polling_job_for_submission_run.id ) + mock_send_slack_error.assert_called_once() assert result["status"] == "exception" assert isinstance(result["exception"], UniprotAmbiguousMappingResultError) @@ -1785,11 +1805,13 @@ async def test_poll_uniprot_mapping_jobs_nonexistent_target( "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.get_id_mapping_results", return_value=TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): result = await poll_uniprot_mapping_jobs_for_score_set( mock_worker_ctx, sample_polling_job_for_submission_run.id ) + mock_send_slack_error.assert_called_once() assert result["status"] == "exception" assert isinstance(result["exception"], NonExistentTargetGeneError) @@ -1816,14 +1838,18 @@ async def test_poll_uniprot_mapping_jobs_propagates_exceptions_to_decorator( } session.commit() - with patch( - "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", - side_effect=Exception("UniProt API failure"), + with ( + patch( + "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", + side_effect=Exception("UniProt API failure"), + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): result = await poll_uniprot_mapping_jobs_for_score_set( mock_worker_ctx, sample_polling_job_for_submission_run.id ) + mock_send_slack_error.assert_called_once() assert result["status"] == "exception" assert isinstance(result["exception"], Exception) @@ -1960,6 +1986,7 @@ async def test_poll_uniprot_mapping_jobs_with_arq_context_exception_handling_ind "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", side_effect=Exception("UniProt API failure"), ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): await arq_redis.enqueue_job( "poll_uniprot_mapping_jobs_for_score_set", sample_polling_job_for_submission_run.id @@ -1967,6 +1994,7 @@ async def test_poll_uniprot_mapping_jobs_with_arq_context_exception_handling_ind await arq_worker.async_run() await arq_worker.run_check() + mock_send_slack_error.assert_called_once() # Verify that the polling job failed session.refresh(sample_polling_job_for_submission_run) assert sample_polling_job_for_submission_run.status == JobStatus.FAILED @@ -1998,6 +2026,7 @@ async def test_poll_uniprot_mapping_jobs_with_arq_context_exception_handling_pip "mavedb.worker.jobs.external_services.uniprot.UniProtIDMappingAPI.check_id_mapping_results_ready", side_effect=Exception("UniProt API failure"), ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): await arq_redis.enqueue_job( "poll_uniprot_mapping_jobs_for_score_set", @@ -2006,6 +2035,7 @@ async def test_poll_uniprot_mapping_jobs_with_arq_context_exception_handling_pip await arq_worker.async_run() await arq_worker.run_check() + mock_send_slack_error.assert_called_once() # Verify that the polling job failed session.refresh(sample_poll_uniprot_mapping_jobs_run_in_pipeline) assert sample_poll_uniprot_mapping_jobs_run_in_pipeline.status == JobStatus.FAILED diff --git a/tests/worker/jobs/pipeline_management/test_start_pipeline.py b/tests/worker/jobs/pipeline_management/test_start_pipeline.py index b5605de1..08179374 100644 --- a/tests/worker/jobs/pipeline_management/test_start_pipeline.py +++ b/tests/worker/jobs/pipeline_management/test_start_pipeline.py @@ -160,8 +160,10 @@ async def test_start_pipeline_on_job_without_pipeline_fails( sample_dummy_pipeline_start.pipeline_id = None session.commit() - result = await start_pipeline(mock_worker_ctx, sample_dummy_pipeline_start.id) - assert result["status"] == "exception" + with patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error: + result = await start_pipeline(mock_worker_ctx, sample_dummy_pipeline_start.id) + assert result["status"] == "exception" + mock_send_slack_error.assert_called_once() # Verify the start job run status session.refresh(sample_dummy_pipeline_start) @@ -207,12 +209,16 @@ async def custom_side_effect(*args, **kwargs): PipelineManager(session, session, sample_dummy_pipeline.id), *args, **kwargs ) # Allow the final coordination attempt to proceed 'normally' - with patch( - "mavedb.worker.lib.managers.pipeline_manager.PipelineManager.coordinate_pipeline", - side_effect=custom_side_effect, + with ( + patch( + "mavedb.worker.lib.managers.pipeline_manager.PipelineManager.coordinate_pipeline", + side_effect=custom_side_effect, + ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): result = await start_pipeline(mock_worker_ctx, sample_dummy_pipeline_start.id) assert result["status"] == "exception" + mock_send_slack_error.assert_called_once() # Verify the start job run status session.refresh(sample_dummy_pipeline_start) diff --git a/tests/worker/jobs/variant_processing/test_creation.py b/tests/worker/jobs/variant_processing/test_creation.py index 66e64c85..b2b15fca 100644 --- a/tests/worker/jobs/variant_processing/test_creation.py +++ b/tests/worker/jobs/variant_processing/test_creation.py @@ -943,9 +943,11 @@ async def test_create_variants_for_score_set_validation_error_during_creation( "mavedb.worker.jobs.variant_processing.creation.pd.read_csv", side_effect=[sample_score_dataframe, sample_count_dataframe], ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): await create_variants_for_score_set(mock_worker_ctx, sample_independent_variant_creation_run.id) + mock_send_slack_error.assert_called_once() # Verify that the score set's processing state is updated to failed session.refresh(sample_score_set) assert sample_score_set.processing_state == ProcessingState.failed @@ -990,9 +992,11 @@ async def test_create_variants_for_score_set_generic_exception_handling_during_c "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", side_effect=Exception("Generic exception during data validation"), ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): await create_variants_for_score_set(mock_worker_ctx, sample_independent_variant_creation_run.id) + mock_send_slack_error.assert_called_once() # Verify that the score set's processing state is updated to failed session.refresh(sample_score_set) assert sample_score_set.processing_state == ProcessingState.failed @@ -1049,9 +1053,11 @@ async def test_create_variants_for_score_set_generic_exception_handling_during_r "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", side_effect=Exception("Generic exception during data validation"), ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): await create_variants_for_score_set(mock_worker_ctx, sample_independent_variant_creation_run.id) + mock_send_slack_error.assert_called_once() # Verify that the score set's processing state is updated to failed session.refresh(sample_score_set) assert sample_score_set.processing_state == ProcessingState.failed @@ -1098,9 +1104,11 @@ async def test_create_variants_for_score_set_pipeline_job_generic_exception_hand "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", side_effect=Exception("Generic exception during data validation"), ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): await create_variants_for_score_set(mock_worker_ctx, sample_pipeline_variant_creation_run.id) + mock_send_slack_error.assert_called_once() # Verify that the score set's processing state is updated to failed session.refresh(sample_score_set) assert sample_score_set.processing_state == ProcessingState.failed @@ -1305,11 +1313,13 @@ async def test_create_variants_for_score_set_with_arq_context_generic_exception_ "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", side_effect=Exception("Generic exception during data validation"), ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): await arq_redis.enqueue_job("create_variants_for_score_set", sample_independent_variant_creation_run.id) await arq_worker.async_run() await arq_worker.run_check() + mock_send_slack_error.assert_called_once() # Verify that the score set's processing state is updated to failed session.refresh(sample_score_set) assert sample_score_set.processing_state == ProcessingState.failed @@ -1351,11 +1361,13 @@ async def test_create_variants_for_score_set_with_arq_context_generic_exception_ "mavedb.worker.jobs.variant_processing.creation.validate_and_standardize_dataframe_pair", side_effect=Exception("Generic exception during data validation"), ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): await arq_redis.enqueue_job("create_variants_for_score_set", sample_pipeline_variant_creation_run.id) await arq_worker.async_run() await arq_worker.run_check() + mock_send_slack_error.assert_called_once() # Verify that the score set's processing state is updated to failed session.refresh(sample_score_set) assert sample_score_set.processing_state == ProcessingState.failed diff --git a/tests/worker/jobs/variant_processing/test_mapping.py b/tests/worker/jobs/variant_processing/test_mapping.py index 5546f4d7..61357984 100644 --- a/tests/worker/jobs/variant_processing/test_mapping.py +++ b/tests/worker/jobs/variant_processing/test_mapping.py @@ -1120,12 +1120,14 @@ async def dummy_mapping_job(): # with return value from run_in_executor. with ( patch.object(_UnixSelectorEventLoop, "run_in_executor", return_value=dummy_mapping_job()), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): result = await map_variants_for_score_set( mock_worker_ctx, sample_independent_variant_mapping_run.id, ) + mock_send_slack_error.assert_called_once() assert result["status"] == "exception" assert isinstance(result["exception"], NonexistentMappingResultsError) assert result["data"] == {} @@ -1198,12 +1200,14 @@ async def dummy_mapping_job(): "run_in_executor", return_value=dummy_mapping_job(), ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): result = await map_variants_for_score_set( mock_worker_ctx, sample_independent_variant_mapping_run.id, ) + mock_send_slack_error.assert_called_once() assert result["status"] == "exception" assert isinstance(result["exception"], NonexistentMappingScoresError) assert result["data"] == {} @@ -1274,12 +1278,14 @@ async def dummy_mapping_job(): "run_in_executor", return_value=dummy_mapping_job(), ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): result = await map_variants_for_score_set( mock_worker_ctx, sample_independent_variant_mapping_run.id, ) + mock_send_slack_error.assert_called_once() assert result["status"] == "exception" assert isinstance(result["exception"], NonexistentMappingReferenceError) assert result["data"] == {} @@ -1457,12 +1463,14 @@ async def dummy_mapping_job(): "run_in_executor", return_value=dummy_mapping_job(), ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): result = await map_variants_for_score_set( mock_worker_ctx, sample_independent_variant_mapping_run.id, ) + mock_send_slack_error.assert_called_once() assert result["status"] == "exception" assert result["data"] == {} assert isinstance(result["exception"], NonexistentMappingScoresError) @@ -1508,12 +1516,14 @@ async def dummy_mapping_job(): "run_in_executor", return_value=dummy_mapping_job(), ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): result = await map_variants_for_score_set( mock_worker_ctx, sample_independent_variant_mapping_run.id, ) + mock_send_slack_error.assert_called_once() assert result["status"] == "exception" assert result["data"] == {} assert isinstance(result["exception"], ValueError) @@ -1755,11 +1765,13 @@ async def dummy_mapping_job(): "run_in_executor", return_value=dummy_mapping_job(), ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): await arq_redis.enqueue_job("map_variants_for_score_set", sample_independent_variant_mapping_run.id) await arq_worker.async_run() await arq_worker.run_check() + mock_send_slack_error.assert_called_once() assert sample_score_set.mapping_state == MappingState.failed assert sample_score_set.mapping_errors is not None # but replaced with generic error message for external visibility @@ -1807,11 +1819,13 @@ async def dummy_mapping_job(): "run_in_executor", return_value=dummy_mapping_job(), ), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): await arq_redis.enqueue_job("map_variants_for_score_set", sample_pipeline_variant_mapping_run.id) await arq_worker.async_run() await arq_worker.run_check() + mock_send_slack_error.assert_called_once() assert sample_score_set.mapping_state == MappingState.failed assert sample_score_set.mapping_errors is not None # but replaced with generic error message for external visibility diff --git a/tests/worker/lib/decorators/test_job_management.py b/tests/worker/lib/decorators/test_job_management.py index aa80fc6e..c887588f 100644 --- a/tests/worker/lib/decorators/test_job_management.py +++ b/tests/worker/lib/decorators/test_job_management.py @@ -7,6 +7,7 @@ import pytest + pytest.importorskip("arq") # Skip tests if arq is not installed import asyncio @@ -141,6 +142,7 @@ async def test_decorator_calls_start_job_and_fail_job_when_wrapped_function_rais ): with ( patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, patch.object(mock_job_manager, "start_job", return_value=None) as mock_start_job, patch.object(mock_job_manager, "should_retry", return_value=False), patch.object(mock_job_manager, "fail_job", return_value=None) as mock_fail_job, @@ -151,12 +153,14 @@ async def test_decorator_calls_start_job_and_fail_job_when_wrapped_function_rais mock_start_job.assert_called_once() mock_fail_job.assert_called_once() + mock_send_slack_error.assert_called_once() async def test_decorator_calls_start_job_and_retries_job_when_wrapped_function_raises_and_retry( self, session, mock_worker_ctx, mock_job_manager ): with ( patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, patch.object(mock_job_manager, "start_job", return_value=None) as mock_start_job, patch.object(mock_job_manager, "should_retry", return_value=True), patch.object(mock_job_manager, "prepare_retry", return_value=None) as mock_prepare_retry, @@ -167,6 +171,7 @@ async def test_decorator_calls_start_job_and_retries_job_when_wrapped_function_r mock_start_job.assert_called_once() mock_prepare_retry.assert_called_once_with(reason="error in wrapped function") + mock_send_slack_error.assert_called_once() @pytest.mark.parametrize("missing_key", ["redis"]) async def test_decorator_raises_value_error_if_required_context_missing( @@ -174,9 +179,13 @@ async def test_decorator_raises_value_error_if_required_context_missing( ): del mock_worker_ctx[missing_key] - with pytest.raises(ValueError) as exc_info: + with ( + pytest.raises(ValueError) as exc_info, + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): await sample_job(mock_worker_ctx, 999) + mock_send_slack_error.assert_called_once() assert missing_key.replace("_", " ") in str(exc_info.value).lower() assert "not found in job context" in str(exc_info.value).lower() @@ -186,6 +195,7 @@ async def test_decorator_swallows_exception_from_lifecycle_state_outside_except( raised_exc = JobStateError("error in job start") with ( patch("mavedb.worker.lib.decorators.job_management.JobManager") as mock_job_manager_class, + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, patch.object(mock_job_manager, "start_job", side_effect=raised_exc), patch.object(mock_job_manager, "should_retry", return_value=False), patch.object(mock_job_manager, "fail_job", return_value=None), @@ -196,12 +206,18 @@ async def test_decorator_swallows_exception_from_lifecycle_state_outside_except( assert result["status"] == "exception" assert raised_exc == result["exception"] + mock_send_slack_error.assert_called_once() async def test_decorator_raises_value_error_if_job_id_missing(self, session, mock_job_manager, mock_worker_ctx): # Remove job_id from args to simulate missing job_id - with pytest.raises(ValueError) as exc_info, TransactionSpy.spy(session): + with ( + pytest.raises(ValueError) as exc_info, + TransactionSpy.spy(session), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, + ): await sample_job(mock_worker_ctx) + mock_send_slack_error.assert_called_once() assert "job id not found in function arguments" in str(exc_info.value).lower() async def test_decorator_swallows_exception_from_wrapped_function_inside_except( @@ -213,10 +229,13 @@ async def test_decorator_swallows_exception_from_wrapped_function_inside_except( patch.object(mock_job_manager, "should_retry", return_value=False), patch.object(mock_job_manager, "fail_job", side_effect=JobStateError("error in job fail")), TransactionSpy.spy(session, expect_commit=True, expect_rollback=True), + patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error, ): mock_job_manager_class.return_value = mock_job_manager result = await sample_raise(mock_worker_ctx, 999) + # Should notify for internal and job error + assert mock_send_slack_error.call_count == 2 # Errors within the main try block should take precedence assert result["status"] == "exception" assert str(result["exception"]) == "error in wrapped function" @@ -290,9 +309,11 @@ async def test_decorator_integrated_job_lifecycle_failed( async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): return {"status": "failed", "data": {}, "exception": RuntimeError("Simulated job failure")} - # Run the job - await sample_job(standalone_worker_context, sample_job_run.id) + with patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error: + # Run the job + await sample_job(standalone_worker_context, sample_job_run.id) + mock_send_slack_error.assert_called_once() # After completion, status should be FAILED job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() assert job.status == JobStatus.FAILED @@ -310,17 +331,20 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): raise RuntimeError("Simulated job failure") # Start the job (it will block at event.wait()) - job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id)) + with patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error: + job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id)) - # At this point, the job should be started but not in error - await asyncio.sleep(0.1) # Give the event loop a moment to start the job - job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() - assert job.status == JobStatus.RUNNING + # At this point, the job should be started but not in error + await asyncio.sleep(0.1) # Give the event loop a moment to start the job + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING - # Now allow the job to complete with failure. This failure - # should be swallowed by the job_task. - event.set() - await job_task + # Now allow the job to complete with failure. This failure + # should be swallowed by the job_task. + event.set() + await job_task + + mock_send_slack_error.assert_called_once() # After failure, status should be FAILED job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() @@ -339,23 +363,26 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): await event.wait() # Simulate async work, block until test signals raise RuntimeError("Simulated job failure for retry") - # Start the job (it will block at event.wait()) - job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id)) - - # At this point, the job should be started but not in error - await asyncio.sleep(0.1) # Give the event loop a moment to start the job - job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() - assert job.status == JobStatus.RUNNING - - # TODO: We patch `should_retry` to return True to force a retry scenario. After implementing failure - # categorization in the worker, this patch can be removed and we should directly test retry logic based - # on failure categories. - # - # Now allow the job to complete with failure that triggers a retry. This failure - # should be swallowed by the job_task. - with patch.object(JobManager, "should_retry", return_value=True): - event.set() - await job_task + with patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error: + # Start the job (it will block at event.wait()) + job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id)) + + # At this point, the job should be started but not in error + await asyncio.sleep(0.1) # Give the event loop a moment to start the job + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING + + # TODO: We patch `should_retry` to return True to force a retry scenario. After implementing failure + # categorization in the worker, this patch can be removed and we should directly test retry logic based + # on failure categories. + # + # Now allow the job to complete with failure that triggers a retry. This failure + # should be swallowed by the job_task. + with patch.object(JobManager, "should_retry", return_value=True): + event.set() + await job_task + + mock_send_slack_error.assert_called_once() # After failure with retry, status should be PENDING job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() diff --git a/tests/worker/lib/decorators/test_pipeline_management.py b/tests/worker/lib/decorators/test_pipeline_management.py index 0cfd4a69..45c7c3d2 100644 --- a/tests/worker/lib/decorators/test_pipeline_management.py +++ b/tests/worker/lib/decorators/test_pipeline_management.py @@ -98,18 +98,28 @@ async def test_decorator_raises_value_error_if_required_context_missing( ): del mock_worker_ctx[missing_key] - with pytest.raises(ValueError) as exc_info, TransactionSpy.spy(mock_pipeline_manager.db): + with ( + pytest.raises(ValueError) as exc_info, + TransactionSpy.spy(mock_pipeline_manager.db), + patch("mavedb.worker.lib.decorators.pipeline_management.send_slack_error") as mock_send_slack_error, + ): await sample_job(mock_worker_ctx, 999) assert missing_key.replace("_", " ") in str(exc_info.value).lower() assert "not found in pipeline context" in str(exc_info.value).lower() + mock_send_slack_error.assert_called_once() async def test_decorator_raises_value_error_if_job_id_missing(self, mock_pipeline_manager, mock_worker_ctx): # Remove job_id from args to simulate missing job_id - with pytest.raises(ValueError) as exc_info, TransactionSpy.spy(mock_pipeline_manager.db): + with ( + pytest.raises(ValueError) as exc_info, + TransactionSpy.spy(mock_pipeline_manager.db), + patch("mavedb.worker.lib.decorators.pipeline_management.send_slack_error") as mock_send_slack_error, + ): await sample_job(mock_worker_ctx) assert "job id not found in function arguments" in str(exc_info.value).lower() + mock_send_slack_error.assert_called_once() async def test_decorator_swallows_exception_if_cant_fetch_pipeline_id( self, session, mock_pipeline_manager, mock_worker_ctx @@ -120,8 +130,10 @@ async def test_decorator_swallows_exception_if_cant_fetch_pipeline_id( exception=ValueError("job id not found in pipeline context"), expect_rollback=True, ), + patch("mavedb.worker.lib.decorators.pipeline_management.send_slack_error") as mock_send_slack_error, ): await sample_job(mock_worker_ctx, 999) + mock_send_slack_error.assert_called_once() async def test_decorator_fetches_pipeline_from_db_and_constructs_pipeline_manager( self, session, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data @@ -214,11 +226,12 @@ async def test_decorator_swallows_exception_from_wrapped_function( patch.object(mock_pipeline_manager, "start_pipeline", return_value=None), patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED), TransactionSpy.spy(session, expect_commit=True, expect_rollback=True), + patch("mavedb.worker.lib.decorators.pipeline_management.send_slack_error") as mock_send_slack_error, ): mock_pipeline_manager_class.return_value = mock_pipeline_manager await sample_raise(mock_worker_ctx, sample_job_run.id) - # TODO: Assert calls for notification hooks and job result data + mock_send_slack_error.assert_called_once() async def test_decorator_swallows_exception_from_pipeline_manager_coordinate_pipeline( self, session, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data @@ -235,11 +248,12 @@ async def test_decorator_swallows_exception_from_pipeline_manager_coordinate_pip # Exception raised from coordinate_pipeline should trigger rollback, # and commit will be called when pipeline status is set to running TransactionSpy.spy(session, expect_commit=True, expect_rollback=True), + patch("mavedb.worker.lib.decorators.pipeline_management.send_slack_error") as mock_send_slack_error, ): mock_pipeline_manager_class.return_value = mock_pipeline_manager await sample_job(mock_worker_ctx, sample_job_run.id) - # TODO: Assert calls for notification hooks and job result data + assert mock_send_slack_error.call_count == 2 async def test_decorator_swallows_exception_from_job_management_decorator( self, session, mock_pipeline_manager, mock_worker_ctx, sample_job_run, with_populated_job_data @@ -256,8 +270,10 @@ def passthrough_decorator(f): ) as mock_with_job_mgmt, patch.object(mock_pipeline_manager, "start_pipeline", return_value=None), patch.object(mock_pipeline_manager, "get_pipeline_status", return_value=PipelineStatus.CREATED), + patch.object(mock_pipeline_manager, "coordinate_pipeline", return_value=None), patch("mavedb.worker.lib.decorators.pipeline_management.PipelineManager") as mock_pipeline_manager_class, TransactionSpy.spy(session, expect_commit=True, expect_rollback=True), + patch("mavedb.worker.lib.decorators.pipeline_management.send_slack_error") as mock_send_slack_error, ): mock_pipeline_manager_class.return_value = mock_pipeline_manager @@ -268,7 +284,7 @@ async def sample_job(ctx: dict, job_id: int, pipeline_manager: PipelineManager): await sample_job(mock_worker_ctx, sample_job_run.id, pipeline_manager=mock_pipeline_manager) mock_with_job_mgmt.assert_called_once() - # TODO: Assert calls for notification hooks and job result data + mock_send_slack_error.assert_called_once() @pytest.mark.asyncio @@ -398,22 +414,26 @@ async def sample_dependent_job(ctx: dict, job_id: int, job_manager: JobManager): await dep_event.wait() # Simulate async work, block until test signals return {"status": "ok", "data": {}, "exception": None} - # Start the job (it will block at event.wait()) - job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id)) + # job management handles slack alerting in this context + with patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error: + # Start the job (it will block at event.wait()) + job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id)) - # At this point, the job should be started but not completed - await asyncio.sleep(0.1) # Give the event loop a moment to start the job - job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() - assert job.status == JobStatus.RUNNING + # At this point, the job should be started but not completed + await asyncio.sleep(0.1) # Give the event loop a moment to start the job + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING - pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() - assert pipeline.status == PipelineStatus.RUNNING + pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + assert pipeline.status == PipelineStatus.RUNNING - # Now allow the job to complete with failure that triggers a retry. This failure - # should be swallowed by the job_task. - with patch.object(JobManager, "should_retry", return_value=True): - event.set() - await job_task + # Now allow the job to complete with failure that triggers a retry. This failure + # should be swallowed by the job_task. + with patch.object(JobManager, "should_retry", return_value=True): + event.set() + await job_task + + mock_send_slack_error.assert_called_once() # After failure with retry, status should be QUEUED job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() @@ -494,22 +514,26 @@ async def sample_job(ctx: dict, job_id: int, job_manager: JobManager): await event.wait() # Simulate async work, block until test signals raise RuntimeError("Simulated job failure") - # Start the job (it will block at event.wait()) - job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id)) + # job management handles slack alerting in this context + with patch("mavedb.worker.lib.decorators.job_management.send_slack_error") as mock_send_slack_error: + # Start the job (it will block at event.wait()) + job_task = asyncio.create_task(sample_job(standalone_worker_context, sample_job_run.id)) - # At this point, the job should be started but not completed - await asyncio.sleep(0.1) # Give the event loop a moment to start the job - job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() - assert job.status == JobStatus.RUNNING + # At this point, the job should be started but not completed + await asyncio.sleep(0.1) # Give the event loop a moment to start the job + job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() + assert job.status == JobStatus.RUNNING - pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() - assert pipeline.status == PipelineStatus.RUNNING + pipeline = session.execute(select(Pipeline).where(Pipeline.id == sample_pipeline.id)).scalar_one() + assert pipeline.status == PipelineStatus.RUNNING - # Now allow the job to complete with failure and flush the Redis queue. This failure - # should be swallowed by the pipeline manager - await arq_redis.flushdb() - event.set() - await job_task + # Now allow the job to complete with failure and flush the Redis queue. This failure + # should be swallowed by the pipeline manager + await arq_redis.flushdb() + event.set() + await job_task + + mock_send_slack_error.assert_called_once() # After failure with no retry, status should be FAILED job = session.execute(select(JobRun).where(JobRun.id == sample_job_run.id)).scalar_one() From e85312ad334e94498eec1fe18153ee38aff29078 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Thu, 29 Jan 2026 13:18:29 -0800 Subject: [PATCH 57/70] fix: update TODO comments for clarity and specificity in UniProt and ClinGen tests --- src/mavedb/worker/jobs/external_services/uniprot.py | 4 ++-- src/mavedb/worker/jobs/variant_processing/creation.py | 2 +- src/mavedb/worker/lib/decorators/pipeline_management.py | 2 +- tests/worker/jobs/external_services/network/test_clingen.py | 4 ++-- tests/worker/jobs/external_services/test_uniprot.py | 1 - 5 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/mavedb/worker/jobs/external_services/uniprot.py b/src/mavedb/worker/jobs/external_services/uniprot.py index bfd89a0d..637ff162 100644 --- a/src/mavedb/worker/jobs/external_services/uniprot.py +++ b/src/mavedb/worker/jobs/external_services/uniprot.py @@ -63,7 +63,7 @@ async def submit_uniprot_mapping_jobs_for_score_set(ctx: dict, job_id: int, job_ - Submits UniProt ID mapping jobs for each target gene in the ScoreSet. - Fetches the dependent job for this function, which is the polling job for UniProt results. Sets the parameter `mapping_jobs` on the polling job with a dictionary of target gene IDs to UniProt job IDs. - TODO#XXX: Split mapping jobs into one per target gene so that polling can be more granular. + TODO#646: Split mapping jobs into one per target gene so that polling can be more granular. Raises: - UniProtPollingEnqueueError: If the dependent polling job cannot be found. @@ -216,7 +216,7 @@ async def poll_uniprot_mapping_jobs_for_score_set(ctx: dict, job_id: int, job_ma - Polls UniProt ID mapping jobs for each target gene in the ScoreSet. - Updates target genes with mapped UniProt IDs in the database. - TODO#XXX: Split mapping jobs into one per target gene so that polling can be more granular. + TODO#646: Split mapping jobs into one per target gene so that polling can be more granular. Returns: dict: Result indicating success and any exception details diff --git a/src/mavedb/worker/jobs/variant_processing/creation.py b/src/mavedb/worker/jobs/variant_processing/creation.py index 3774782a..cee4ff5f 100644 --- a/src/mavedb/worker/jobs/variant_processing/creation.py +++ b/src/mavedb/worker/jobs/variant_processing/creation.py @@ -80,7 +80,7 @@ async def create_variants_for_score_set(ctx: dict, job_id: int, job_manager: Job # Main processing block. Handled in a try/except to ensure we can set score set state appropriately, # which is handled independently of the job state. - # TODO:XXX In a future iteration, we should rely on the job manager itself for maintaining processing + # TODO:647 In a future iteration, we should rely on the job manager itself for maintaining processing # state for better cohesion. This try/except is redundant in it's duties with the job manager. try: correlation_id = job.job_params["correlation_id"] # type: ignore diff --git a/src/mavedb/worker/lib/decorators/pipeline_management.py b/src/mavedb/worker/lib/decorators/pipeline_management.py index 5bcf3a15..a181c72e 100644 --- a/src/mavedb/worker/lib/decorators/pipeline_management.py +++ b/src/mavedb/worker/lib/decorators/pipeline_management.py @@ -50,7 +50,7 @@ def with_pipeline_management(func: F) -> F: Features: - Pipeline lifecycle tracking - Job lifecycle tracking via with_job_management - - Robust error handling, logging, and TODO(alerting) on failures + - Robust error handling, logging, and alerting on failures Example: @with_pipeline_management diff --git a/tests/worker/jobs/external_services/network/test_clingen.py b/tests/worker/jobs/external_services/network/test_clingen.py index 5587925e..2bd8645a 100644 --- a/tests/worker/jobs/external_services/network/test_clingen.py +++ b/tests/worker/jobs/external_services/network/test_clingen.py @@ -15,7 +15,7 @@ pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") -# TODO#XXX: Connect with ClinGen to resolve the invalid credentials issue on test site. +# XXX: Connect with ClinGen to resolve the invalid credentials issue on test site. @pytest.mark.skip(reason="invalid credentials, despite what is provided in documentation.") @pytest.mark.asyncio @pytest.mark.integration @@ -82,7 +82,7 @@ async def test_clingen_car_submission_e2e( assert variant.clingen_allele_id is not None -# TODO#XXX: Connect with ClinGen to resolve the invalid credentials issue on test site. +# XXX: Connect with ClinGen to resolve the invalid credentials issue on test site. @pytest.mark.skip(reason="invalid credentials, despite what is provided in documentation.") @pytest.mark.integration @pytest.mark.asyncio diff --git a/tests/worker/jobs/external_services/test_uniprot.py b/tests/worker/jobs/external_services/test_uniprot.py index dd9e0990..99ab3a07 100644 --- a/tests/worker/jobs/external_services/test_uniprot.py +++ b/tests/worker/jobs/external_services/test_uniprot.py @@ -837,7 +837,6 @@ async def test_submit_uniprot_mapping_jobs_no_dependent_job_raises( # Verify that the submission job failed session.refresh(sample_submit_uniprot_mapping_jobs_run) - # TODO#XXX: Should be failed when supported by decorator assert sample_submit_uniprot_mapping_jobs_run.status == JobStatus.FAILED # nothing to verify for dependent polling job since it does not exist From 893473f4e66e1a328687a3b7d30f7061059f978e Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Thu, 29 Jan 2026 15:12:51 -0800 Subject: [PATCH 58/70] feat: make Redis client optional in managers and add error handling for missing Redis in PipelineManager --- .../jobs/pipeline_management/start_pipeline.py | 3 ++- src/mavedb/worker/lib/managers/base_manager.py | 5 +++-- src/mavedb/worker/lib/managers/job_manager.py | 4 ++-- src/mavedb/worker/lib/managers/pipeline_manager.py | 8 +++++++- tests/worker/lib/managers/test_pipeline_manager.py | 14 ++++++++++++++ 5 files changed, 28 insertions(+), 6 deletions(-) diff --git a/src/mavedb/worker/jobs/pipeline_management/start_pipeline.py b/src/mavedb/worker/jobs/pipeline_management/start_pipeline.py index e2d80f38..7dbed7d4 100644 --- a/src/mavedb/worker/jobs/pipeline_management/start_pipeline.py +++ b/src/mavedb/worker/jobs/pipeline_management/start_pipeline.py @@ -53,7 +53,8 @@ async def start_pipeline(ctx: dict, job_id: int, job_manager: JobManager) -> Job # Initialize PipelineManager and coordinate pipeline. The pipeline manager decorator # will have started the pipeline for us already, but doesn't coordinate on start automatically. - pipeline_manager = PipelineManager(job_manager.db, job_manager.redis, job_manager.pipeline_id) + redis = job_manager.redis or ctx["redis"] + pipeline_manager = PipelineManager(job_manager.db, redis, job_manager.pipeline_id) await pipeline_manager.coordinate_pipeline() # Finalize job state diff --git a/src/mavedb/worker/lib/managers/base_manager.py b/src/mavedb/worker/lib/managers/base_manager.py index 08da4670..de0fe67f 100644 --- a/src/mavedb/worker/lib/managers/base_manager.py +++ b/src/mavedb/worker/lib/managers/base_manager.py @@ -6,6 +6,7 @@ import logging from abc import ABC +from typing import Optional from arq import ArqRedis from sqlalchemy.orm import Session @@ -27,12 +28,12 @@ class BaseManager(ABC): redis: ARQ Redis client for job queue operations """ - def __init__(self, db: Session, redis: ArqRedis): + def __init__(self, db: Session, redis: Optional[ArqRedis]): """Initialize base manager with database and Redis connections. Args: db: SQLAlchemy database session for job and pipeline queries - redis: ARQ Redis client for job queue operations + redis(Optional[ArqRedis]): ARQ Redis client for job queue operations Raises: DatabaseConnectionError: Cannot connect to database diff --git a/src/mavedb/worker/lib/managers/job_manager.py b/src/mavedb/worker/lib/managers/job_manager.py index b02cde18..e762ada0 100644 --- a/src/mavedb/worker/lib/managers/job_manager.py +++ b/src/mavedb/worker/lib/managers/job_manager.py @@ -134,7 +134,7 @@ class JobManager(BaseManager): context: dict[str, Any] = {} - def __init__(self, db: Session, redis: ArqRedis, job_id: int): + def __init__(self, db: Session, redis: Optional[ArqRedis], job_id: int): """Initialize JobManager for a specific job. Args: @@ -142,7 +142,7 @@ def __init__(self, db: Session, redis: ArqRedis, job_id: int): be configured for the appropriate database and have proper transaction isolation. redis: ARQ Redis client for job queue operations. Must be connected - and ready for enqueue operations. + and ready for enqueue operations. Optional; can be None if Redis is not used. job_id: Unique identifier of the job to manage. Must correspond to an existing JobRun record in the database. diff --git a/src/mavedb/worker/lib/managers/pipeline_manager.py b/src/mavedb/worker/lib/managers/pipeline_manager.py index eda91c61..b0ecfcf1 100644 --- a/src/mavedb/worker/lib/managers/pipeline_manager.py +++ b/src/mavedb/worker/lib/managers/pipeline_manager.py @@ -142,7 +142,9 @@ def __init__(self, db: Session, redis: ArqRedis, pipeline_id: int): Args: db: SQLAlchemy database session for job and pipeline queries - redis: ARQ Redis client for job queue operations + redis: ARQ Redis client for job queue operations. Note that although the Redis + client is optional for base managers, PipelineManager requires it for + job coordination. pipeline_id: ID of the pipeline this manager instance will coordinate Raises: @@ -1126,6 +1128,10 @@ async def _enqueue_in_arq(self, job: JobRun, is_retry: bool) -> None: Raises: PipelineCoordinationError: If ARQ enqueuing fails """ + if not self.redis: + logger.error(f"Redis client is not configured for PipelineManager; cannot enqueue job {job.urn}") + raise PipelineCoordinationError("Redis client is not configured for job enqueueing; cannot proceed.") + try: defer_by = timedelta(seconds=job.retry_delay_seconds if is_retry and job.retry_delay_seconds else 0) arq_success = await self.redis.enqueue_job(job.job_function, job.id, _defer_by=defer_by, _job_id=job.urn) diff --git a/tests/worker/lib/managers/test_pipeline_manager.py b/tests/worker/lib/managers/test_pipeline_manager.py index 4f892824..7cb7931e 100644 --- a/tests/worker/lib/managers/test_pipeline_manager.py +++ b/tests/worker/lib/managers/test_pipeline_manager.py @@ -3265,6 +3265,20 @@ def test_set_pipeline_status_integration_running_status_sets_started_at( class TestEnqueueInArqUnit: """Test enqueuing jobs in ARQ.""" + @pytest.mark.asyncio + async def test_enqueue_in_arq_without_redis_raises_pipeline_coordination_error(self, mock_pipeline_manager): + """Test that attempting to enqueue a job without a Redis connection raises PipelineCoordinationError.""" + mock_job = Mock(spec=JobRun, job_function="test_func", id=1, urn="urn:example", retry_delay_seconds=10) + mock_pipeline_manager.redis = None + + with ( + pytest.raises( + PipelineCoordinationError, match="Redis client is not configured for job enqueueing; cannot proceed." + ), + TransactionSpy.spy(mock_pipeline_manager.db), + ): + await mock_pipeline_manager._enqueue_in_arq(job=mock_job, is_retry=False) + @pytest.mark.asyncio @pytest.mark.parametrize("enqueud", [Mock(spec=ArqJob), None]) @pytest.mark.parametrize("retry", [True, False]) From 5db3561d53d6ce58e4d70ec4cf59ad7314ebf2ea Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Thu, 29 Jan 2026 15:19:44 -0800 Subject: [PATCH 59/70] feat: implement create_job_dependency method in JobFactory with validation and error handling --- src/mavedb/lib/workflow/job_factory.py | 40 ++++++++ tests/lib/workflow/conftest.py | 31 ++++++ tests/lib/workflow/test_job_factory.py | 130 ++++++++++++++++++++++++- 3 files changed, 197 insertions(+), 4 deletions(-) diff --git a/src/mavedb/lib/workflow/job_factory.py b/src/mavedb/lib/workflow/job_factory.py index a5aa4dfa..556c9c09 100644 --- a/src/mavedb/lib/workflow/job_factory.py +++ b/src/mavedb/lib/workflow/job_factory.py @@ -5,6 +5,8 @@ from mavedb import __version__ as mavedb_version from mavedb.lib.types.workflow import JobDefinition +from mavedb.models.enums.job_pipeline import DependencyType +from mavedb.models.job_dependency import JobDependency from mavedb.models.job_run import JobRun @@ -60,3 +62,41 @@ def create_job_run( self.session.add(job_run) return job_run + + def create_job_dependency( + self, + parent_job_run_id: int, + child_job_run_id: int, + dependency_type: DependencyType = DependencyType.SUCCESS_REQUIRED, + ) -> JobDependency: + """ + Creates and persists a JobDependency instance linking a parent job run to a child job run. + + Args: + parent_job_run_id (int): The ID of the parent job run. + child_job_run_id (int): The ID of the child job run. + dependency_type (DependencyType): The type of dependency (default is SUCCESS_REQUIRED). + + Returns: + JobDependency: The newly created JobDependency instance (not yet committed to the database). + + Raises: + ValueError: If the parent or child job run IDs do not exist in the database. + """ + + # Validate that the parent and child job runs exist + parent_exists = self.session.query(JobRun.id).filter(JobRun.id == parent_job_run_id).first() is not None + child_exists = self.session.query(JobRun.id).filter(JobRun.id == child_job_run_id).first() is not None + if not parent_exists: + raise ValueError(f"Parent job run ID {parent_job_run_id} does not exist.") + if not child_exists: + raise ValueError(f"Child job run ID {child_job_run_id} does not exist.") + + job_dependency = JobDependency( + id=child_job_run_id, + depends_on_job_id=parent_job_run_id, + dependency_type=dependency_type, + ) # type: ignore[call-arg] + + self.session.add(job_dependency) + return job_dependency diff --git a/tests/lib/workflow/conftest.py b/tests/lib/workflow/conftest.py index dad72098..0f9d9e50 100644 --- a/tests/lib/workflow/conftest.py +++ b/tests/lib/workflow/conftest.py @@ -3,6 +3,7 @@ import pytest from mavedb.models.enums.job_pipeline import DependencyType +from mavedb.models.job_run import JobRun from mavedb.models.user import User from tests.helpers.constants import TEST_USER @@ -78,3 +79,33 @@ def test_user(session): db.add(user) db.commit() yield user + + +@pytest.fixture +def test_workflow_parent_job_run(session, test_user): + """Fixture to create and provide a test parent job run for workflow tests.""" + parent_job_run = JobRun( + job_type="test_type", + job_function="test_function", + job_params={}, + correlation_id="test_correlation_id", + ) + session.add(parent_job_run) + session.commit() + + yield parent_job_run + + +@pytest.fixture +def test_workflow_child_job_run(session, test_user, test_workflow_parent_job_run): + """Fixture to create and provide a test child job run for workflow tests.""" + child_job_run = JobRun( + job_type="test_type", + job_function="test_function", + job_params={}, + correlation_id="test_correlation_id", + ) + session.add(child_job_run) + session.commit() + + yield child_job_run diff --git a/tests/lib/workflow/test_job_factory.py b/tests/lib/workflow/test_job_factory.py index 6b730299..bf2e13ba 100644 --- a/tests/lib/workflow/test_job_factory.py +++ b/tests/lib/workflow/test_job_factory.py @@ -1,6 +1,8 @@ # ruff: noqa: E402 import pytest +from mavedb.models.job_dependency import JobDependency + pytest.importorskip("fastapi") from unittest.mock import patch @@ -9,8 +11,8 @@ @pytest.mark.unit -class TestJobFactoryUnit: - """Unit tests for the JobFactory class.""" +class TestJobFactoryCreateJobRunUnit: + """Unit tests for the JobFactory create_job_run method.""" def test_create_job_run_persists_preset_params_from_definition(self, job_factory, sample_job_definition): existing_params = {"param1": "new_value1", "param2": "new_value2", "required_param": "required_value"} @@ -129,8 +131,8 @@ def test_create_job_run_adds_to_session(self, job_factory, sample_job_definition @pytest.mark.integration -class TestJobFactoryIntegration: - """Integration tests for the JobFactory class within pipeline execution.""" +class TestJobFactoryCreateJobRunIntegration: + """Integration tests for the JobFactory create_job_run method within pipeline execution.""" def test_create_job_run_independent(self, job_factory, sample_job_definition): pipeline_params = {"required_param": "required_value"} @@ -192,3 +194,123 @@ def test_create_job_run_missing_params_raises_error(self, job_factory, sample_jo ) assert "Missing required param: required_param" in str(exc_info.value) + + +@pytest.mark.unit +class TestJobFactoryCreateJobDependencyUnit: + """Unit tests for the JobFactory create_job_dependency method.""" + + def test_create_job_dependency_persists_fields( + self, job_factory, test_workflow_parent_job_run, test_workflow_child_job_run + ): + parent_job_run_id = test_workflow_parent_job_run.id + child_job_run_id = test_workflow_child_job_run.id + dependency_type = "success_required" + + job_dependency = job_factory.create_job_dependency( + parent_job_run_id=parent_job_run_id, + child_job_run_id=child_job_run_id, + dependency_type=dependency_type, + ) + + assert job_dependency.id == child_job_run_id + assert job_dependency.depends_on_job_id == parent_job_run_id + assert job_dependency.dependency_type == dependency_type + + def test_create_job_dependency_defaults_dependency_type( + self, job_factory, test_workflow_parent_job_run, test_workflow_child_job_run + ): + parent_job_run_id = test_workflow_parent_job_run.id + child_job_run_id = test_workflow_child_job_run.id + + job_dependency = job_factory.create_job_dependency( + parent_job_run_id=parent_job_run_id, + child_job_run_id=child_job_run_id, + ) + + assert job_dependency.id == child_job_run_id + assert job_dependency.depends_on_job_id == parent_job_run_id + assert job_dependency.dependency_type == "success_required" + + def test_create_job_dependency_raises_error_for_nonexistent_parent(self, job_factory, test_workflow_child_job_run): + parent_job_run_id = 9999 # Assuming this ID does not exist + child_job_run_id = test_workflow_child_job_run.id + + with pytest.raises(ValueError) as exc_info: + job_factory.create_job_dependency( + parent_job_run_id=parent_job_run_id, + child_job_run_id=child_job_run_id, + ) + + assert f"Parent job run ID {parent_job_run_id} does not exist." in str(exc_info.value) + + def test_create_job_dependency_raises_error_for_nonexistent_child(self, job_factory, test_workflow_parent_job_run): + parent_job_run_id = test_workflow_parent_job_run.id + child_job_run_id = 9999 # Assuming this ID does not exist + + with pytest.raises(ValueError) as exc_info: + job_factory.create_job_dependency( + parent_job_run_id=parent_job_run_id, + child_job_run_id=child_job_run_id, + ) + + assert f"Child job run ID {child_job_run_id} does not exist." in str(exc_info.value) + + +@pytest.mark.integration +class TestJobFactoryCreateJobDependencyIntegration: + """Integration tests for the JobFactory create_job_dependency method within job execution.""" + + def test_create_job_dependency(self, job_factory, test_workflow_parent_job_run, test_workflow_child_job_run): + parent_job_run_id = test_workflow_parent_job_run.id + child_job_run_id = test_workflow_child_job_run.id + dependency_type = "success_required" + + job_dependency = job_factory.create_job_dependency( + parent_job_run_id=parent_job_run_id, + child_job_run_id=child_job_run_id, + dependency_type=dependency_type, + ) + job_factory.session.commit() + + retrieved_dependency = ( + job_factory.session.query(type(job_dependency)) + .filter( + type(job_dependency).id == child_job_run_id, + type(job_dependency).depends_on_job_id == parent_job_run_id, + ) + .first() + ) + + assert retrieved_dependency is not None + assert retrieved_dependency.id == child_job_run_id + assert retrieved_dependency.depends_on_job_id == parent_job_run_id + assert retrieved_dependency.dependency_type == dependency_type + + def test_create_job_dependency_missing_parent_raises_error(self, session, job_factory, test_workflow_child_job_run): + parent_job_run_id = 9999 # Assuming this ID does not exist + child_job_run_id = test_workflow_child_job_run.id + + with pytest.raises(ValueError) as exc_info: + job_factory.create_job_dependency( + parent_job_run_id=parent_job_run_id, + child_job_run_id=child_job_run_id, + ) + + assert f"Parent job run ID {parent_job_run_id} does not exist." in str(exc_info.value) + job_dependencies = session.query(JobDependency).all() + assert not job_dependencies + + def test_create_job_dependency_missing_child_raises_error(self, session, job_factory, test_workflow_parent_job_run): + parent_job_run_id = test_workflow_parent_job_run.id + child_job_run_id = 9999 # Assuming this ID does not exist + + with pytest.raises(ValueError) as exc_info: + job_factory.create_job_dependency( + parent_job_run_id=parent_job_run_id, + child_job_run_id=child_job_run_id, + ) + + assert f"Child job run ID {child_job_run_id} does not exist." in str(exc_info.value) + job_dependencies = session.query(JobDependency).all() + assert not job_dependencies From 85a426823e519643b540acd366fe0544992a5419 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Thu, 29 Jan 2026 15:20:13 -0800 Subject: [PATCH 60/70] feat: refactor UniProt ID mapping script to use async commands and job management --- .../map_to_uniprot_id_from_mapped_metadata.py | 209 +++++++++--------- 1 file changed, 106 insertions(+), 103 deletions(-) diff --git a/src/mavedb/scripts/map_to_uniprot_id_from_mapped_metadata.py b/src/mavedb/scripts/map_to_uniprot_id_from_mapped_metadata.py index c681babc..1e37b103 100644 --- a/src/mavedb/scripts/map_to_uniprot_id_from_mapped_metadata.py +++ b/src/mavedb/scripts/map_to_uniprot_id_from_mapped_metadata.py @@ -1,126 +1,129 @@ -import click +import asyncio +import datetime import logging -from typing import Optional -from sqlalchemy.orm import Session +import asyncclick as click # using asyncclick to allow async commands -from mavedb.scripts.environment import with_database_session +from mavedb.db.session import SessionLocal +from mavedb.lib.workflow.job_factory import JobFactory +from mavedb.models.enums.job_pipeline import JobStatus from mavedb.models.score_set import ScoreSet -from mavedb.lib.uniprot.id_mapping import UniProtIDMappingAPI -from mavedb.lib.uniprot.utils import infer_db_name_from_sequence_accession -from mavedb.lib.mapping import extract_ids_from_post_mapped_metadata - -VALID_UNIPROT_DBS = [ - "UniProtKB", - "UniProtKB_AC-ID", - "UniProtKB-Swiss-Prot", - "UniParc", - "UniRef50", - "UniRef90", - "UniRef100", -] +from mavedb.worker.jobs.external_services.uniprot import ( + poll_uniprot_mapping_jobs_for_score_set, + submit_uniprot_mapping_jobs_for_score_set, +) +from mavedb.worker.jobs.registry import STANDALONE_JOB_DEFINITIONS +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.types import JobResultData +from mavedb.worker.settings.lifecycle import standalone_ctx logger = logging.getLogger(__name__) @click.command() -@with_database_session -@click.option("--score-set-urn", type=str, default=None, help="Score set URN to process. If not provided, process all.") +@click.argument("score_set_urn", type=str, required=True) @click.option("--polling-interval", type=int, default=30, help="Polling interval in seconds for checking job status.") @click.option("--polling-attempts", type=int, default=5, help="Number of tries to poll for job completion.") -@click.option("--to-db", type=str, default="UniProtKB", help="Target UniProt database for ID mapping.") -@click.option( - "--prefer-swiss-prot", is_flag=True, default=True, help="Prefer Swiss-Prot entries in the mapping results." -) @click.option( - "--refresh-mapped-identifier", + "--refresh", is_flag=True, default=False, help="Refresh the existing mapped identifier, if one exists.", ) -def main( - db: Session, - score_set_urn: Optional[str], +async def main( + score_set_urn: str, polling_interval: int, polling_attempts: int, - to_db: str, - prefer_swiss_prot: bool = True, - refresh_mapped_identifier: bool = False, + refresh: bool = False, ) -> None: - if to_db not in VALID_UNIPROT_DBS: - raise ValueError(f"Invalid target database: {to_db}. Must be one of {VALID_UNIPROT_DBS}.") + db = SessionLocal() + if score_set_urn: - score_sets = db.query(ScoreSet).filter(ScoreSet.urn == score_set_urn).all() - else: - score_sets = db.query(ScoreSet).all() - - api = UniProtIDMappingAPI(polling_interval=polling_interval, polling_tries=polling_attempts) - - logger.info(f"Processing {len(score_sets)} score sets.") - for score_set in score_sets: - logger.info(f"Processing score set: {score_set.urn}") - - if not score_set.target_genes: - logger.warning(f"No target gene for score set {score_set.urn}. Skipped mapping this score set.") - continue - - for target_gene in score_set.target_genes: - if target_gene.uniprot_id_from_mapped_metadata and not refresh_mapped_identifier: - logger.debug( - f"Target gene {target_gene.id} already has UniProt ID {target_gene.uniprot_id_from_mapped_metadata} and refresh_mapped_identifier is False. Skipped mapping this target." - ) - continue - - if not target_gene.post_mapped_metadata: - logger.warning( - f"No post-mapped metadata for target gene {target_gene.id}. Skipped mapping this target." - ) - continue - - ids = extract_ids_from_post_mapped_metadata(target_gene.post_mapped_metadata) # type: ignore - if not ids: - logger.warning( - f"No IDs found in post_mapped_metadata for target gene {target_gene.id}. Skipped mapping this target." - ) - continue - if len(ids) > 1: - logger.warning( - f"More than one accession ID found in post_mapped_metadata for target gene {target_gene.id}. Skipped mapping this target." - ) - continue - - id_to_map = ids[0] - from_db = infer_db_name_from_sequence_accession(id_to_map) - job_id = api.submit_id_mapping(from_db, to_db=to_db, ids=[id_to_map]) - - if not job_id: - logger.warning(f"Failed to submit job for target gene {target_gene.id}. Skipped mapping this target.") - continue - if not api.check_id_mapping_results_ready(job_id): - logger.warning(f"Job {job_id} not ready for target gene {target_gene.id}. Skipped mapping this target.") - continue - - results = api.get_id_mapping_results(job_id) - mapped_results = api.extract_uniprot_id_from_results(results, prefer_swiss_prot=prefer_swiss_prot) - - if not mapped_results: - logger.warning(f"No UniProt ID found for target gene {target_gene.id}. Skipped mapping this target.") - continue - if len(mapped_results) > 1: - logger.warning( - f"Could not unambiguously map target gene {target_gene.id}. Found multiple UniProt IDs ({len(mapped_results)})." - ) - continue - - uniprot_id = mapped_results[0][id_to_map]["uniprot_id"] - target_gene.uniprot_id_from_mapped_metadata = uniprot_id - db.add(target_gene) - - logger.info(f"Updated target gene {target_gene.id} with UniProt ID {uniprot_id}.") - - logger.info(f"Processed score set {score_set.urn} with {len(score_set.target_genes)} target genes.") - - logger.info(f"Done processing {len(score_sets)} score sets.") + score_set = db.query(ScoreSet).filter(ScoreSet.urn == score_set_urn).one() + + score_set_id = score_set.id + if not refresh and any(tg.uniprot_id_from_mapped_metadata for tg in score_set.target_genes): + logger.info(f"Score set {score_set_urn} already has mapped UniProt IDs. Use --refresh to re-map.") + return + + # Unique correlation ID for this batch run + correlation_id = f"populate_mapped_variants_{datetime.datetime.now().isoformat()}" + + # Job definitions + submission_def = STANDALONE_JOB_DEFINITIONS[submit_uniprot_mapping_jobs_for_score_set] + polling_def = STANDALONE_JOB_DEFINITIONS[poll_uniprot_mapping_jobs_for_score_set] + job_factory = JobFactory(db) + + # Use a standalone context for job execution outside of ARQ worker. + ctx = standalone_ctx() + ctx["db"] = db + + submission_run = job_factory.create_job_run( + job_def=submission_def, + pipeline_id=None, + correlation_id=correlation_id, + pipeline_params={ + "score_set_id": score_set_id, + "correlation_id": correlation_id, + }, + ) + db.add(submission_run) + db.flush() + + polling_run = job_factory.create_job_run( + job_def=polling_def, + pipeline_id=None, + correlation_id=correlation_id, + pipeline_params={ + "score_set_id": score_set_id, + "correlation_id": correlation_id, + "mapping_jobs": {}, # Will be filled in by the submission job + }, + ) + db.add(polling_run) + db.flush() + + # Dependencies are still valid outside of pipeline contexts, but we must invoke + # dependent jobs manually. + polling_dependency = job_factory.create_job_dependency( + parent_job_run_id=submission_run.id, child_job_run_id=polling_run.id + ) + db.add(polling_dependency) + db.flush() + + logger.info( + f"Submitted UniProt ID mapping submission job run ID {submission_run.id} for score set URN {score_set_urn}." + ) + + # Despite accepting a third argument for the job manager and MyPy expecting it, this + # argument will be injected automatically by the decorator. We only need to pass + # the ctx and job_run.id here for the decorator to generate the job manager. + await submit_uniprot_mapping_jobs_for_score_set(ctx, submission_run.id) # type: ignore[call-arg] + + job_manager = JobManager(db, None, submission_run.id) + for i in range(polling_attempts): + logger.info( + f"Submitted UniProt ID mapping polling job run ID {polling_run.id} for score set URN {score_set_urn}, attempt {i + 1}." + ) + + # Despite accepting a third argument for the job manager and MyPy expecting it, this + # argument will be injected automatically by the decorator. We only need to pass + # the ctx and job_run.id here for the decorator to generate the job manager. + polling_result: JobResultData = await poll_uniprot_mapping_jobs_for_score_set(ctx, polling_run.id) # type: ignore[call-arg] + db.refresh(polling_run) + + if polling_run.status == JobStatus.SUCCEEDED: + logger.info(f"Polling job for score set URN {score_set_urn} succeeded on attempt {i + 1}.") + break + + logger.info( + f"Polling job for score set URN {score_set_urn} failed on attempt {i + 1} with error: {polling_result.get('exception')}" + ) + db.refresh(polling_run) + job_manager.prepare_retry(f"Polling job failed. Attempting retry in {polling_interval} seconds.") + await asyncio.sleep(polling_interval) + + logger.info(f"Completed UniProt ID mapping for score set URN {score_set_urn}. Polling result : {polling_result}") if __name__ == "__main__": From 9a4dcfca5498c2f0510dac2b6e12b6f164cb894d Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Thu, 29 Jan 2026 16:22:08 -0800 Subject: [PATCH 61/70] feat: refactor link_gnomad_variants script to use async commands and job management --- src/mavedb/scripts/link_gnomad_variants.py | 112 +++++++++------------ 1 file changed, 48 insertions(+), 64 deletions(-) diff --git a/src/mavedb/scripts/link_gnomad_variants.py b/src/mavedb/scripts/link_gnomad_variants.py index d910ea59..af684683 100644 --- a/src/mavedb/scripts/link_gnomad_variants.py +++ b/src/mavedb/scripts/link_gnomad_variants.py @@ -1,82 +1,66 @@ +import datetime import logging -from typing import Sequence -import click -from sqlalchemy import select -from sqlalchemy.orm import Session +import asyncclick as click -from mavedb.db import athena -from mavedb.lib.gnomad import gnomad_variant_data_for_caids, link_gnomad_variants_to_mapped_variants -from mavedb.models.mapped_variant import MappedVariant +from mavedb.db.session import SessionLocal +from mavedb.lib.workflow.job_factory import JobFactory from mavedb.models.score_set import ScoreSet -from mavedb.models.variant import Variant -from mavedb.scripts.environment import with_database_session +from mavedb.worker.jobs.external_services.gnomad import link_gnomad_variants +from mavedb.worker.jobs.registry import STANDALONE_JOB_DEFINITIONS +from mavedb.worker.settings.lifecycle import standalone_ctx logger = logging.getLogger(__name__) @click.command() -@with_database_session -@click.option( - "--score-set-urn", multiple=True, type=str, help="Score set URN(s) to process. Can be used multiple times." -) +@click.argument("urns", nargs=-1) @click.option("--all", "all_score_sets", is_flag=True, help="Process all score sets in the database.", default=False) -@click.option("--only-current", is_flag=True, help="Only process current mapped variants.", default=True) -def link_gnomad_variants(db: Session, score_set_urn: list[str], all_score_sets: bool, only_current: bool) -> None: +async def main(urns: list[str], all_score_sets: bool) -> None: """ Query AWS Athena for gnomAD variants matching mapped variant CAIDs for one or more score sets. """ - # 1. Collect all CAIDs for mapped variants in the selected score sets + db = SessionLocal() + if all_score_sets: - score_sets = db.query(ScoreSet.id).all() - score_set_ids = [s.id for s in score_sets] + logger.info("Processing all score sets in the database.") + score_sets = db.query(ScoreSet).all() else: - if not score_set_urn: - logger.error("No score set URNs specified.") - return - - score_sets = db.query(ScoreSet.id).filter(ScoreSet.urn.in_(score_set_urn)).all() - score_set_ids = [s.id for s in score_sets] - if len(score_set_ids) != len(score_set_urn): - logger.warning("Some provided URNs were not found in the database.") - - if not score_set_ids: - logger.error("No score sets found.") - return - - caid_query = ( - select(MappedVariant.clingen_allele_id) - .join(Variant) - .where(Variant.score_set_id.in_(score_set_ids), MappedVariant.clingen_allele_id.is_not(None)) - ) - - if only_current: - caid_query = caid_query.where(MappedVariant.current.is_(True)) - - # We filter out Nonetype CAIDs to avoid issues with Athena queries, so we can type this as Sequence[str] and ignore MyPy warnings - caids: Sequence[str] = db.scalars(caid_query.distinct()).all() # type: ignore - if not caids: - logger.error("No CAIDs found for the selected score sets.") - return - - logger.info(f"Found {len(caids)} CAIDs for the selected score sets to link to gnomAD variants.") - - # 2. Query Athena for gnomAD variants matching the CAIDs - with athena.engine.connect() as athena_session: - logger.debug("Fetching gnomAD variants from Athena.") - gnomad_variant_data = gnomad_variant_data_for_caids(athena_session, caids) - - if not gnomad_variant_data: - logger.error("No gnomAD records found for the provided CAIDs.") - return - - logger.info(f"Fetched {len(gnomad_variant_data)} gnomAD records from Athena.") - - # 3. Link gnomAD variants to mapped variants in the database - link_gnomad_variants_to_mapped_variants(db, gnomad_variant_data, only_current=only_current) - - logger.info("Done linking gnomAD variants.") + logger.info(f"Processing score sets with URNs: {urns}") + score_sets = db.query(ScoreSet).filter(ScoreSet.urn.in_(urns)).all() + + # Unique correlation ID for this batch run + correlation_id = f"populate_mapped_variants_{datetime.datetime.now().isoformat()}" + + # Job definition for gnomAD linking + job_def = STANDALONE_JOB_DEFINITIONS[link_gnomad_variants] + job_factory = JobFactory(db) + + # Use a standalone context for job execution outside of ARQ worker. + ctx = standalone_ctx() + ctx["db"] = db + + for score_set in score_sets: + logger.info(f"Linking gnomAD variants for score set ID {score_set.id} (URN: {score_set.urn})...") + + job_run = job_factory.create_job_run( + job_def=job_def, + pipeline_id=None, + correlation_id=correlation_id, + pipeline_params={ + "score_set_id": score_set.id, + "correlation_id": correlation_id, + }, + ) + db.add(job_run) + db.flush() + logger.info(f"Submitted job run ID {job_run.id} for score set ID {score_set.id}.") + + # Despite accepting a third argument for the job manager and MyPy expecting it, this + # argument will be injected automatically by the decorator. We only need to pass + # the ctx and job_run.id here for the decorator to generate the job manager. + await link_gnomad_variants(ctx, job_run.id) # type: ignore if __name__ == "__main__": - link_gnomad_variants() + main() From 86b2478bf427b7c12116138f2c5f58b7ae56e062 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Thu, 29 Jan 2026 16:27:36 -0800 Subject: [PATCH 62/70] feat: refactor clingen_car_submission script to use async commands and job management --- src/mavedb/scripts/clingen_car_submission.py | 159 ++++++------------- 1 file changed, 49 insertions(+), 110 deletions(-) diff --git a/src/mavedb/scripts/clingen_car_submission.py b/src/mavedb/scripts/clingen_car_submission.py index 29ea5fd8..492c6c3e 100644 --- a/src/mavedb/scripts/clingen_car_submission.py +++ b/src/mavedb/scripts/clingen_car_submission.py @@ -1,133 +1,72 @@ -import click +import datetime import logging from typing import Sequence + +import asyncclick as click from sqlalchemy import select -from sqlalchemy.orm import Session +from mavedb.db.session import SessionLocal +from mavedb.lib.workflow.job_factory import JobFactory from mavedb.models.score_set import ScoreSet -from mavedb.models.variant import Variant -from mavedb.models.mapped_variant import MappedVariant -from mavedb.scripts.environment import with_database_session -from mavedb.lib.clingen.services import ClinGenAlleleRegistryService, get_allele_registry_associations -from mavedb.lib.clingen.constants import CAR_SUBMISSION_ENDPOINT -from mavedb.lib.variants import get_hgvs_from_post_mapped +from mavedb.worker.jobs.external_services.clingen import submit_score_set_mappings_to_car +from mavedb.worker.jobs.registry import STANDALONE_JOB_DEFINITIONS +from mavedb.worker.settings.lifecycle import standalone_ctx logger = logging.getLogger(__name__) -def submit_urns_to_car(db: Session, urns: Sequence[str], debug: bool) -> list[str]: - if not CAR_SUBMISSION_ENDPOINT: - logger.error("`CAR_SUBMISSION_ENDPOINT` is not set. Please check your configuration.") - return [] - - car_service = ClinGenAlleleRegistryService(url=CAR_SUBMISSION_ENDPOINT) - submitted_entities = [] - - if debug: - logger.debug("Debug mode enabled. Submitting only one request to ClinGen CAR.") - urns = urns[:1] - - for idx, urn in enumerate(urns): - logger.info(f"Processing URN: {urn}. (Scoreset {idx + 1}/{len(urns)})") - try: - score_set = db.scalars(select(ScoreSet).where(ScoreSet.urn == urn)).one_or_none() - if not score_set: - logger.warning(f"No score set found for URN: {urn}") - continue - - logger.info(f"Submitting mapped variants to CAR service for score set with URN: {urn}") - variant_objects = db.execute( - select(Variant, MappedVariant) - .join(MappedVariant, MappedVariant.variant_id == Variant.id) - .join(ScoreSet) - .where(ScoreSet.urn == urn) - .where(MappedVariant.post_mapped.is_not(None)) - .where(MappedVariant.current.is_(True)) - ).all() - - if not variant_objects: - logger.warning(f"No mapped variants found for score set with URN: {urn}") - continue - - if debug: - logger.debug(f"Debug mode enabled. Submitting only one variant to ClinGen CAR for URN: {urn}") - variant_objects = variant_objects[:1] - - logger.debug(f"Preparing {len(variant_objects)} mapped variants for CAR submission") - hgvs_to_mapped_variant: dict[str, list[int]] = {} - for variant, mapped_variant in variant_objects: - hgvs = get_hgvs_from_post_mapped(mapped_variant.post_mapped) - if hgvs and hgvs not in hgvs_to_mapped_variant: - hgvs_to_mapped_variant[hgvs] = [mapped_variant.id] - elif hgvs and hgvs in hgvs_to_mapped_variant: - hgvs_to_mapped_variant[hgvs].append(mapped_variant.id) - else: - logger.warning(f"No HGVS string found for mapped variant {variant.urn}") - - if not hgvs_to_mapped_variant: - logger.warning(f"No HGVS strings to submit for URN: {urn}") - continue - - logger.info(f"Submitting {len(hgvs_to_mapped_variant)} HGVS strings to CAR service for URN: {urn}") - response = car_service.dispatch_submissions(list(hgvs_to_mapped_variant.keys())) - - if not response: - logger.error(f"CAR submission failed for URN: {urn}") - else: - logger.info(f"Successfully submitted to CAR for URN: {urn}") - # Associate CAIDs with mapped variants - associations = get_allele_registry_associations(list(hgvs_to_mapped_variant.keys()), response) - for hgvs, caid in associations.items(): - mapped_variant_ids = hgvs_to_mapped_variant.get(hgvs, []) - for mv_id in mapped_variant_ids: - mapped_variant = db.scalar(select(MappedVariant).where(MappedVariant.id == mv_id)) - if not mapped_variant: - logger.warning(f"Mapped variant with ID {mv_id} not found for HGVS {hgvs}.") - continue - - mapped_variant.clingen_allele_id = caid - db.add(mapped_variant) - - submitted_entities.extend([variant.urn for variant, _ in variant_objects]) - - except Exception as e: - logger.error(f"Error processing URN {urn}", exc_info=e) - - return submitted_entities - - @click.command() -@with_database_session @click.argument("urns", nargs=-1) @click.option("--all", help="Submit variants for every score set in MaveDB.", is_flag=True) -@click.option("--suppress-output", help="Suppress final print output to the console.", is_flag=True) -@click.option("--debug", help="Enable debug mode. This will send only one request at most to ClinGen CAR", is_flag=True) -def submit_car_urns_command( - db: Session, - urns: Sequence[str], - all: bool, - suppress_output: bool, - debug: bool, -) -> None: +async def main(urns: Sequence[str], all: bool) -> None: """ Submit data to ClinGen Allele Registry for mapped variant CAID generation for the given URNs. """ + db = SessionLocal() + if urns and all: logger.error("Cannot provide both URNs and --all option.") return if all: - urns = db.scalars(select(ScoreSet.urn)).all() # type: ignore - - if not urns: - logger.error("No URNs provided. Please provide at least one URN.") - return - - submitted_variant_urns = submit_urns_to_car(db, urns, debug) - - if not suppress_output: - print(", ".join(submitted_variant_urns)) + score_set_ids = db.scalars(select(ScoreSet.id)).all() + logger.info(f"Command invoked with --all. Routine will submit CAR data for {len(score_set_ids)} score sets.") + else: + score_set_ids = db.scalars(select(ScoreSet.id).where(ScoreSet.urn.in_(urns))).all() + logger.info(f"Submitting CAR data for the provided score sets ({len(score_set_ids)}).") + + # Unique correlation ID for this batch run + correlation_id = f"populate_mapped_variants_{datetime.datetime.now().isoformat()}" + + # Job definition for CAR submission + job_def = STANDALONE_JOB_DEFINITIONS[submit_score_set_mappings_to_car] + job_factory = JobFactory(db) + + # Use a standalone context for job execution outside of ARQ worker. + ctx = standalone_ctx() + ctx["db"] = db + + for score_set_id in score_set_ids: + logger.info(f"Submitting CAR data for score set ID {score_set_id}...") + + job_run = job_factory.create_job_run( + job_def=job_def, + pipeline_id=None, + correlation_id=correlation_id, + pipeline_params={ + "score_set_id": score_set_id, + "correlation_id": correlation_id, + }, + ) + db.add(job_run) + db.flush() + logger.info(f"Submitted job run ID {job_run.id} for score set ID {score_set_id}.") + + # Despite accepting a third argument for the job manager and MyPy expecting it, this + # argument will be injected automatically by the decorator. We only need to pass + # the ctx and job_run.id here for the decorator to generate the job manager. + await submit_score_set_mappings_to_car(ctx, job_run.id) # type: ignore if __name__ == "__main__": - submit_car_urns_command() + main() From 25bc7da35f9ef6d91d379233daebaae2c7f352b8 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Thu, 29 Jan 2026 16:30:27 -0800 Subject: [PATCH 63/70] feat: refactor clingen_ldh_submission script to streamline job submission process and enhance logging --- src/mavedb/scripts/clingen_ldh_submission.py | 222 +++++-------------- 1 file changed, 51 insertions(+), 171 deletions(-) diff --git a/src/mavedb/scripts/clingen_ldh_submission.py b/src/mavedb/scripts/clingen_ldh_submission.py index 94f16520..17178287 100644 --- a/src/mavedb/scripts/clingen_ldh_submission.py +++ b/src/mavedb/scripts/clingen_ldh_submission.py @@ -1,19 +1,18 @@ -import click +import datetime import logging import re -from typing import Optional, Sequence +from typing import Sequence -from sqlalchemy import and_, select +import click +from sqlalchemy import select from sqlalchemy.orm import Session +from mavedb.db.session import SessionLocal +from mavedb.lib.workflow.job_factory import JobFactory from mavedb.models.score_set import ScoreSet -from mavedb.models.variant import Variant -from mavedb.models.mapped_variant import MappedVariant -from mavedb.scripts.environment import with_database_session -from mavedb.lib.clingen.services import ClinGenLdhService -from mavedb.lib.clingen.constants import DEFAULT_LDH_SUBMISSION_BATCH_SIZE, LDH_SUBMISSION_ENDPOINT -from mavedb.lib.clingen.content_constructors import construct_ldh_submission -from mavedb.lib.variants import get_hgvs_from_post_mapped +from mavedb.worker.jobs.external_services.clingen import submit_score_set_mappings_to_ldh +from mavedb.worker.jobs.registry import STANDALONE_JOB_DEFINITIONS +from mavedb.worker.settings.lifecycle import standalone_ctx logger = logging.getLogger(__name__) @@ -21,177 +20,58 @@ variant_with_reference_regex = re.compile(r":") -def submit_urns_to_clingen( - db: Session, urns: Sequence[str], unlinked_only: bool, prefer_unmapped_hgvs: bool, debug: bool -) -> list[str]: - ldh_service = ClinGenLdhService(url=LDH_SUBMISSION_ENDPOINT) - ldh_service.authenticate() - - submitted_entities = [] - - if debug: - logger.debug("Debug mode enabled. Submitting only one request to ClinGen.") - urns = urns[:1] - - for idx, urn in enumerate(urns): - logger.info(f"Processing URN: {urn}. (Scoreset {idx + 1}/{len(urns)})") - - try: - score_set = db.scalars(select(ScoreSet).where(ScoreSet.urn == urn)).one_or_none() - if not score_set: - logger.warning(f"No score set found for URN: {urn}") - continue - - logger.info(f"Submitting mapped variants to LDH service for score set with URN: {urn}") - mapped_variant_join_clause = and_( - MappedVariant.variant_id == Variant.id, - MappedVariant.post_mapped.is_not(None), - MappedVariant.current.is_(True), - ) - variant_objects = db.execute( - select(Variant, MappedVariant) - .join(MappedVariant, mapped_variant_join_clause, isouter=True) - .join(ScoreSet) - .where(ScoreSet.urn == urn) - ).all() - - if not variant_objects: - logger.warning(f"No mapped variants found for score set with URN: {urn}") - continue - - logger.debug(f"Preparing {len(variant_objects)} mapped variants for submission") - - variant_content: list[tuple[str, Variant, Optional[MappedVariant]]] = [] - for variant, mapped_variant in variant_objects: - if mapped_variant is None: - if variant.hgvs_nt is not None and intronic_variant_with_reference_regex.search(variant.hgvs_nt): - # Use the hgvs_nt string for unmapped intronic variants. This is because our mapper does not yet - # support mapping intronic variants. - variation = variant.hgvs_nt - if variation: - logger.info(f"Using hgvs_nt for unmapped intronic variant {variant.urn}: {variation}") - elif variant.hgvs_nt is not None and variant_with_reference_regex.search(variant.hgvs_nt): - # Use the hgvs_nt string for other unmapped NT variants in accession-based score sets. - variation = variant.hgvs_nt - if variation: - logger.info(f"Using hgvs_nt for unmapped non-intronic variant {variant.urn}: {variation}") - elif variant.hgvs_pro is not None and variant_with_reference_regex.search(variant.hgvs_pro): - # Use the hgvs_pro string for unmapped PRO variants in accession-based score sets. - variation = variant.hgvs_pro - if variation: - logger.info(f"Using hgvs_pro for unmapped non-intronic variant {variant.urn}: {variation}") - else: - logger.warning( - f"No variation found for unmapped variant {variant.urn} (nt: {variant.hgvs_nt}, aa: {variant.hgvs_pro}, splice: {variant.hgvs_splice})." - ) - continue - else: - if unlinked_only and mapped_variant.clingen_allele_id: - continue - # If the script was run with the --prefer-unmapped-hgvs flag, use the hgvs_nt string rather than the - # mapped variant, as long as the variant is accession-based. - if ( - prefer_unmapped_hgvs - and variant.hgvs_nt is not None - and variant_with_reference_regex.search(variant.hgvs_nt) - ): - variation = variant.hgvs_nt - if variation: - logger.info(f"Using hgvs_nt for mapped variant {variant.urn}: {variation}") - elif ( - prefer_unmapped_hgvs - and variant.hgvs_pro is not None - and variant_with_reference_regex.search(variant.hgvs_pro) - ): - variation = variant.hgvs_pro - if variation: - logger.info( - f"Using hgvs_pro for mapped variant {variant.urn}: {variation}" - ) # continue # TEMPORARY. Only submit unmapped variants. - else: - variation = get_hgvs_from_post_mapped(mapped_variant) - if variation: - logger.info(f"Using mapped variant for {variant.urn}: {variation}") - - if not variation: - logger.warning( - f"No variation found for mapped variant {variant.urn} (nt: {variant.hgvs_nt}, aa: {variant.hgvs_pro}, splice: {variant.hgvs_splice})." - ) - continue - - variant_content.append((variation, variant, mapped_variant)) - - if debug: - logger.debug("Debug mode enabled. Submitting only one request to ClinGen.") - variant_content = variant_content[:1] - - logger.debug(f"Constructing LDH submission for {len(variant_content)} variants") - submission_content = construct_ldh_submission(variant_content) - submission_successes, submission_failures = ldh_service.dispatch_submissions( - submission_content, DEFAULT_LDH_SUBMISSION_BATCH_SIZE - ) - - if submission_failures: - logger.error(f"Failed to submit some variants for URN: {urn}") - else: - logger.info(f"Successfully submitted all variants for URN: {urn}") - - submitted_entities.extend([variant.urn for _, variant, _ in variant_content]) - - except Exception as e: - logger.error(f"Error processing URN {urn}", exc_info=e) - - # TODO#372: non-nullable urns. - return submitted_entities # type: ignore - - @click.command() -@with_database_session @click.argument("urns", nargs=-1) @click.option("--all", help="Submit variants for every score set in MaveDB.", is_flag=True) -@click.option( - "--unlinked", - default=False, - help="Only submit variants that have not already been linked to ClinGen alleles.", - is_flag=True, -) -@click.option( - "--prefer-unmapped-hgvs", - default=False, - help="If the unmapped HGVS string is accession-based, use it in the submission instead of the mapped variant.", - is_flag=True, -) -@click.option("--suppress-output", help="Suppress final print output to the console.", is_flag=True) -@click.option("--debug", help="Enable debug mode. This will send only one request at most to ClinGen", is_flag=True) -def submit_clingen_urns_command( - db: Session, - urns: Sequence[str], - all: bool, - unlinked: bool, - prefer_unmapped_hgvs: bool, - suppress_output: bool, - debug: bool, -) -> None: +def main(db: Session, urns: Sequence[str], all: bool) -> None: """ - Submit data to ClinGen for mapped variant allele ID generation for the given URNs. + Submit data to ClinGen LDH for mapped variant allele ID generation for the given URNs. """ + db = SessionLocal() + if urns and all: logger.error("Cannot provide both URNs and --all option.") return if all: - # TODO#372: non-nullable urns. - urns = db.scalars(select(ScoreSet.urn)).all() # type: ignore - - if not urns: - logger.error("No URNs provided. Please provide at least one URN.") - return - - submitted_variant_urns = submit_urns_to_clingen(db, urns, unlinked, prefer_unmapped_hgvs, debug) - - if not suppress_output: - print(", ".join(submitted_variant_urns)) + score_set_ids = db.scalars(select(ScoreSet.id)).all() + logger.info(f"Command invoked with --all. Routine will submit LDH data for {len(score_set_ids)} score sets.") + else: + score_set_ids = db.scalars(select(ScoreSet.id).where(ScoreSet.urn.in_(urns))).all() + logger.info(f"Submitting LDH data for the provided score sets ({len(score_set_ids)}).") + + # Unique correlation ID for this batch run + correlation_id = f"populate_mapped_variants_{datetime.datetime.now().isoformat()}" + + # Job definition for ldh submission + job_def = STANDALONE_JOB_DEFINITIONS[submit_score_set_mappings_to_ldh] + job_factory = JobFactory(db) + + # Use a standalone context for job execution outside of ARQ worker. + ctx = standalone_ctx() + ctx["db"] = db + + for score_set_id in score_set_ids: + logger.info(f"Submitting LDH data for score set ID {score_set_id}...") + + job_run = job_factory.create_job_run( + job_def=job_def, + pipeline_id=None, + correlation_id=correlation_id, + pipeline_params={ + "score_set_id": score_set_id, + "correlation_id": correlation_id, + }, + ) + db.add(job_run) + db.flush() + logger.info(f"Submitted job run ID {job_run.id} for score set ID {score_set_id}.") + + # Despite accepting a third argument for the job manager and MyPy expecting it, this + # argument will be injected automatically by the decorator. We only need to pass + # the ctx and job_run.id here for the decorator to generate the job manager. + await submit_score_set_mappings_to_ldh(ctx, job_run.id) # type: ignore if __name__ == "__main__": - submit_clingen_urns_command() + main() From 3c7449b99acd8cc3c6489b30c7a43c8b282adee3 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Fri, 30 Jan 2026 13:33:58 -0800 Subject: [PATCH 64/70] feat: clinvar clinical control refresh job + script --- src/mavedb/lib/clinvar/__init__.py | 0 src/mavedb/lib/clinvar/constants.py | 1 + src/mavedb/lib/clinvar/utils.py | 112 ++ .../scripts/refresh_clinvar_variant_data.py | 224 +-- .../worker/jobs/external_services/__init__.py | 2 + .../worker/jobs/external_services/clinvar.py | 266 +++ src/mavedb/worker/jobs/registry.py | 9 + tests/conftest.py | 9 + tests/conftest_optional.py | 4 +- tests/helpers/constants.py | 1 + tests/lib/clinvar/network/test_utils.py | 23 + tests/lib/clinvar/test_utils.py | 148 ++ tests/worker/jobs/conftest.py | 74 +- .../external_services/network/test_clinvar.py | 48 + .../jobs/external_services/test_clinvar.py | 1470 +++++++++++++++++ 15 files changed, 2229 insertions(+), 162 deletions(-) create mode 100644 src/mavedb/lib/clinvar/__init__.py create mode 100644 src/mavedb/lib/clinvar/constants.py create mode 100644 src/mavedb/lib/clinvar/utils.py create mode 100644 src/mavedb/worker/jobs/external_services/clinvar.py create mode 100644 tests/lib/clinvar/network/test_utils.py create mode 100644 tests/lib/clinvar/test_utils.py create mode 100644 tests/worker/jobs/external_services/network/test_clinvar.py create mode 100644 tests/worker/jobs/external_services/test_clinvar.py diff --git a/src/mavedb/lib/clinvar/__init__.py b/src/mavedb/lib/clinvar/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/mavedb/lib/clinvar/constants.py b/src/mavedb/lib/clinvar/constants.py new file mode 100644 index 00000000..b0d5397f --- /dev/null +++ b/src/mavedb/lib/clinvar/constants.py @@ -0,0 +1 @@ +TSV_VARIANT_ARCHIVE_BASE_URL = "https://ftp.ncbi.nlm.nih.gov/pub/clinvar/tab_delimited/archive" diff --git a/src/mavedb/lib/clinvar/utils.py b/src/mavedb/lib/clinvar/utils.py new file mode 100644 index 00000000..845dcec9 --- /dev/null +++ b/src/mavedb/lib/clinvar/utils.py @@ -0,0 +1,112 @@ +import csv +import gzip +import io +import sys +from datetime import datetime +from typing import Dict + +import requests + +from mavedb.lib.clinvar.constants import TSV_VARIANT_ARCHIVE_BASE_URL + + +def validate_clinvar_variant_summary_date(month: int, year: int) -> None: + """ + Validates the provided month and year for fetching ClinVar variant summary data. + + Ensures that: + - The year is not earlier than 2015 (ClinVar archived data is only available from 2015 onwards). + - The year is not in the future. + - If the year is the current year, the month is not in the future. + + Raises: + ValueError: If the provided year is before 2015, in the future, or if the month is in the future for the current year. + + Args: + month (int): The month to validate (1-12). + year (int): The year to validate. + """ + current_year = datetime.now().year + current_month = datetime.now().month + + if month < 1 or month > 12: + raise ValueError("Month must be an integer between 1 and 12.") + + if year < 2015 or (year == 2015 and month < 2): + raise ValueError("ClinVar archived data is only available from February 2015 onwards.") + elif year > current_year: + raise ValueError("Cannot fetch ClinVar data for future years.") + elif year == current_year and month > current_month: + raise ValueError("Cannot fetch ClinVar data for future months.") + + +def fetch_clinvar_variant_summary_tsv(month: int, year: int) -> bytes: + """ + Fetches the ClinVar variant summary TSV file for a specified month and year. + + This function attempts to download the variant summary file from the ClinVar FTP archive. + It first tries the top-level directory for recent files, and if not found, falls back to the year-based subdirectory. + The function validates the provided month and year before attempting the download. + + Args: + month (int): The month for which to fetch the variant summary (as an integer). + year (int): The year for which to fetch the variant summary. + + Returns: + bytes: The contents of the downloaded variant summary TSV file (gzipped). + + Raises: + requests.RequestException: If the file cannot be downloaded from either location. + ValueError: If the provided month or year is invalid. + """ + validate_clinvar_variant_summary_date(month, year) + + # Construct URLs for the variant summary TSV file. ClinVar stores recent files at the top level and older files in year-based subdirectories. + # The cadence at which files are moved is not documented, so we try both locations with a preference for the top-level URL. + url_top_level = f"{TSV_VARIANT_ARCHIVE_BASE_URL}/variant_summary_{year}-{month:02d}.txt.gz" + url_archive = f"{TSV_VARIANT_ARCHIVE_BASE_URL}/{year}/variant_summary_{year}-{month:02d}.txt.gz" + + try: + response = requests.get(url_top_level, stream=True) + response.raise_for_status() + return response.content + except requests.exceptions.HTTPError: + response = requests.get(url_archive, stream=True) + response.raise_for_status() + return response.content + + +def parse_clinvar_variant_summary(tsv_content: bytes) -> Dict[str, Dict[str, str]]: + """ + Parses a gzipped TSV file content and returns a dictionary mapping Allele IDs to row data. + + Args: + tsv_content (bytes): The gzipped TSV file content as bytes. + + Returns: + Dict[str, Dict[str, str]]: A dictionary where each key is a string Allele ID (from the '#AlleleID' column), + and each value is a dictionary representing the corresponding row with column names as keys. + + Raises: + KeyError: If the '#AlleleID' column is missing in any row. + ValueError: If the '#AlleleID' value cannot be converted to an integer. + csv.Error: If there is an error parsing the TSV content. + + Note: + The function temporarily increases the CSV field size limit to handle large fields in the TSV file. Some old ClinVar + variant summary files may have fields larger than the default limit. + """ + default_csv_field_size_limit = csv.field_size_limit() + + try: + csv.field_size_limit(sys.maxsize) + + with gzip.open(filename=io.BytesIO(tsv_content), mode="rt") as f: + # This readlines object will only be a list of bytes if the file is opened in "rb" mode. + reader = csv.DictReader(f.readlines(), delimiter="\t") # type: ignore + data = {str(row["#AlleleID"]): row for row in reader} + + finally: + csv.field_size_limit(default_csv_field_size_limit) + + return data diff --git a/src/mavedb/scripts/refresh_clinvar_variant_data.py b/src/mavedb/scripts/refresh_clinvar_variant_data.py index b043272c..5505aa15 100644 --- a/src/mavedb/scripts/refresh_clinvar_variant_data.py +++ b/src/mavedb/scripts/refresh_clinvar_variant_data.py @@ -1,172 +1,78 @@ -import click -from mavedb.models.score_set import ScoreSet -from mavedb.models.variant import Variant -import requests -import csv -import time +import datetime import logging -import gzip -import random -import io -import sys - -from typing import Dict, Any, Optional, Sequence -from datetime import date +from typing import Sequence -from sqlalchemy import and_, select, distinct -from sqlalchemy.orm import Session +import asyncclick as click +from sqlalchemy import select -from mavedb.models.mapped_variant import MappedVariant -from mavedb.models.clinical_control import ClinicalControl -from mavedb.scripts.environment import with_database_session +from mavedb.db.session import SessionLocal +from mavedb.lib.workflow.job_factory import JobFactory +from mavedb.models.score_set import ScoreSet +from mavedb.worker.jobs.external_services.clinvar import refresh_clinvar_controls +from mavedb.worker.jobs.registry import STANDALONE_JOB_DEFINITIONS +from mavedb.worker.settings.lifecycle import standalone_ctx logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - - -# Some older variant summary files have larger field sizes than the default CSV reader can handle. -csv.field_size_limit(sys.maxsize) - - -def fetch_clinvar_variant_summary_tsv(month: Optional[str], year: str) -> bytes: - if month is None and year is None: - url = "https://ftp.ncbi.nlm.nih.gov/pub/clinvar/tab_delimited/variant_summary.txt.gz" - else: - if int(year) <= 2023: - url = f"https://ftp.ncbi.nlm.nih.gov/pub/clinvar/tab_delimited/archive/{year}/variant_summary_{year}-{month}.txt.gz" - else: - url = ( - f"https://ftp.ncbi.nlm.nih.gov/pub/clinvar/tab_delimited/archive/variant_summary_{year}-{month}.txt.gz" - ) - - response = requests.get(url, stream=True) - response.raise_for_status() - return response.content - - -def parse_tsv(tsv_content: bytes) -> Dict[int, Dict[str, str]]: - with gzip.open(filename=io.BytesIO(tsv_content), mode="rt") as f: - # This readlines object will only be a list of bytes if the file is opened in "rb" mode. - reader = csv.DictReader(f.readlines(), delimiter="\t") # type: ignore - data = {int(row["#AlleleID"]): row for row in reader} - - return data - - -def query_clingen_allele_api(allele_id: str) -> Dict[str, Any]: - url = f"https://reg.clinicalgenome.org/allele/{allele_id}" - retries = 5 - for i in range(retries): - try: - response = requests.get(url) - response.raise_for_status() - break - except requests.RequestException as e: - if i < retries - 1: - wait_time = (2**i) + random.uniform(0, 1) - logger.warning(f"Request failed ({e}), retrying in {wait_time:.2f} seconds...") - time.sleep(wait_time) - else: - logger.error(f"Request failed after {retries} attempts: {e}") - raise - - logger.debug(f"Fetched ClinGen data for allele ID {allele_id}.") - return response.json() - -def refresh_clinvar_variants(db: Session, month: Optional[str], year: str, urns: Sequence[str]) -> None: - tsv_content = fetch_clinvar_variant_summary_tsv(month, year) - tsv_data = parse_tsv(tsv_content) - version = f"{month}_{year}" if month and year else f"{date.today().month}_{date.today().year}" - logger.info(f"Fetched TSV variant data for ClinVar for {version}.") - if urns: - clingen_ids = db.scalars( - select(distinct(MappedVariant.clingen_allele_id)) - .join(Variant) - .join(ScoreSet) - .where( - and_( - MappedVariant.clingen_allele_id.is_not(None), - MappedVariant.current.is_(True), - ScoreSet.urn.in_(urns), - ) - ) - ).all() +@click.command() +@click.argument("urns", nargs=-1) +@click.option("--all", help="Refresh ClinVar variant data for all score sets.", is_flag=True) +@click.option("--month", type=int, help="Month of the ClinVar data release to use (1-12).", required=True) +@click.option("--year", type=int, help="Year of the ClinVar data release to use (e.g., 2024).", required=True) +async def main(urns: Sequence[str], all: bool, month: int, year: int) -> None: + """ + Refresh ClinVar variant data for mapped variants in the given score sets. + """ + db = SessionLocal() + + if urns and all: + logger.error("Cannot provide both URNs and --all option.") + return + + if all: + score_set_ids = db.scalars(select(ScoreSet.id)).all() + logger.info( + f"Command invoked with --all. Routine will refresh ClinVar variant data for {len(score_set_ids)} score sets." + ) else: - clingen_ids = db.scalars( - select(distinct(MappedVariant.clingen_allele_id)).where(MappedVariant.clingen_allele_id.is_not(None)) - ).all() - total_variants_with_clingen_ids = len(clingen_ids) - - logger.info(f"Fetching ClinGen data for {total_variants_with_clingen_ids} variants.") - for index, clingen_id in enumerate(clingen_ids): - if total_variants_with_clingen_ids > 0 and index % (max(total_variants_with_clingen_ids // 100, 1)) == 0: - logger.info(f"Progress: {index / total_variants_with_clingen_ids:.0%}") - - if clingen_id is not None and "," in clingen_id: - logger.debug("Detected a multi-variant ClinGen allele ID, skipping.") - continue - - # Guaranteed based on our query filters. - clingen_data = query_clingen_allele_api(clingen_id) # type: ignore - clinvar_allele_id = clingen_data.get("externalRecords", {}).get("ClinVarAlleles", [{}])[0].get("alleleId") - - if not clinvar_allele_id or clinvar_allele_id not in tsv_data: - logger.debug( - f"No ClinVar variant data found for ClinGen allele ID {clingen_id}. ({index + 1}/{total_variants_with_clingen_ids})." - ) - continue - - variant_data = tsv_data[clinvar_allele_id] - identifier = str(clinvar_allele_id) - - clinvar_variant = db.scalars( - select(ClinicalControl).where( - ClinicalControl.db_identifier == identifier, - ClinicalControl.db_version == version, - ClinicalControl.db_name == "ClinVar", - ) - ).one_or_none() - if clinvar_variant: - clinvar_variant.gene_symbol = variant_data.get("GeneSymbol") - clinvar_variant.clinical_significance = variant_data.get("ClinicalSignificance") - clinvar_variant.clinical_review_status = variant_data.get("ReviewStatus") - else: - clinvar_variant = ClinicalControl( - db_identifier=identifier, - gene_symbol=variant_data.get("GeneSymbol"), - clinical_significance=variant_data.get("ClinicalSignificance"), - clinical_review_status=variant_data.get("ReviewStatus"), - db_version=version, - db_name="ClinVar", - ) - - db.add(clinvar_variant) - - variants_with_clingen_allele_id = db.scalars( - select(MappedVariant).where(MappedVariant.clingen_allele_id == clingen_id) - ).all() - for mapped_variant in variants_with_clingen_allele_id: - if clinvar_variant.id in [c.id for c in mapped_variant.clinical_controls]: - continue - mapped_variant.clinical_controls.append(clinvar_variant) - db.add(mapped_variant) - - db.commit() - logger.debug( - f"Added ClinVar variant data ({identifier}) for ClinGen allele ID {clingen_id}. ({index + 1}/{total_variants_with_clingen_ids})." + score_set_ids = db.scalars(select(ScoreSet.id).where(ScoreSet.urn.in_(urns))).all() + logger.info(f"Refreshing ClinVar variant data for the provided score sets ({len(score_set_ids)}).") + + # Unique correlation ID for this batch run + correlation_id = f"populate_mapped_variants_{datetime.datetime.now().isoformat()}" + + # Job definition for ClinVar controls refresh + job_def = STANDALONE_JOB_DEFINITIONS[refresh_clinvar_controls] + job_factory = JobFactory(db) + + # Use a standalone context for job execution outside of ARQ worker. + ctx = standalone_ctx() + ctx["db"] = db + + for score_set_id in score_set_ids: + logger.info(f"Refreshing ClinVar variant data for score set ID {score_set_id}...") + + job_run = job_factory.create_job_run( + job_def=job_def, + pipeline_id=None, + correlation_id=correlation_id, + pipeline_params={ + "score_set_id": score_set_id, + "correlation_id": correlation_id, + "month": month, + "year": year, + }, ) + db.add(job_run) + db.flush() + logger.info(f"Submitted job run ID {job_run.id} for score set ID {score_set_id}.") - -@click.command() -@with_database_session -@click.argument("urns", nargs=-1) -@click.option("--month", default=None, help="Populate mapped variants for every score set in MaveDB.") -@click.option("--year", required=True, help="Populate mapped variants for every score set in MaveDB.") -def refresh_clinvar_variants_command(db: Session, month: Optional[str], year: str, urns: Sequence[str]) -> None: - refresh_clinvar_variants(db, month, year, urns) + # Despite accepting a third argument for the job manager and MyPy expecting it, this + # argument will be injected automatically by the decorator. We only need to pass + # the ctx and job_run.id here for the decorator to generate the job manager. + await refresh_clinvar_controls(ctx, job_run.id) # type: ignore if __name__ == "__main__": - refresh_clinvar_variants_command() + main() diff --git a/src/mavedb/worker/jobs/external_services/__init__.py b/src/mavedb/worker/jobs/external_services/__init__.py index eabe8ebe..eb88b7e9 100644 --- a/src/mavedb/worker/jobs/external_services/__init__.py +++ b/src/mavedb/worker/jobs/external_services/__init__.py @@ -11,6 +11,7 @@ submit_score_set_mappings_to_car, submit_score_set_mappings_to_ldh, ) +from .clinvar import refresh_clinvar_controls from .gnomad import link_gnomad_variants from .uniprot import ( poll_uniprot_mapping_jobs_for_score_set, @@ -20,6 +21,7 @@ __all__ = [ "submit_score_set_mappings_to_car", "submit_score_set_mappings_to_ldh", + "refresh_clinvar_controls", "link_gnomad_variants", "poll_uniprot_mapping_jobs_for_score_set", "submit_uniprot_mapping_jobs_for_score_set", diff --git a/src/mavedb/worker/jobs/external_services/clinvar.py b/src/mavedb/worker/jobs/external_services/clinvar.py new file mode 100644 index 00000000..1f1b3140 --- /dev/null +++ b/src/mavedb/worker/jobs/external_services/clinvar.py @@ -0,0 +1,266 @@ +"""ClinVar integration jobs for variant annotation + +This module contains job definitions and utility functions for integrating ClinVar +variant data into MaveDB. It includes functions to fetch and parse ClinVar variant +summary data, and update MaveDB records with the latest ClinVar annotations. +""" + +import asyncio +import functools +import logging + +import requests +from sqlalchemy import select + +from mavedb.lib.annotation_status_manager import AnnotationStatusManager +from mavedb.lib.clingen.allele_registry import get_associated_clinvar_allele_id +from mavedb.lib.clinvar.utils import ( + fetch_clinvar_variant_summary_tsv, + parse_clinvar_variant_summary, + validate_clinvar_variant_summary_date, +) +from mavedb.models.clinical_control import ClinicalControl +from mavedb.models.enums.annotation_type import AnnotationType +from mavedb.models.enums.job_pipeline import AnnotationStatus +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.score_set import ScoreSet +from mavedb.models.variant import Variant +from mavedb.worker.jobs.utils.setup import validate_job_params +from mavedb.worker.lib.decorators.pipeline_management import with_pipeline_management +from mavedb.worker.lib.managers.job_manager import JobManager +from mavedb.worker.lib.managers.types import JobResultData + +logger = logging.getLogger(__name__) + + +@with_pipeline_management +async def refresh_clinvar_controls(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData: + """ + Job to refresh ClinVar clinical control data in MaveDB. + + This job fetches the latest ClinVar variant summary data and updates + the clinical control records in MaveDB accordingly. + + Args: + ctx (dict): The job context containing necessary information. + job_id (int): The ID of the job being executed. + job_manager (JobManager): The job manager instance for managing job state. + + Returns: + JobResultData: The result of the job execution. + """ + # Get the job definition we are working on + job = job_manager.get_job() + + _job_required_params = ["score_set_id", "correlation_id", "year", "month"] + validate_job_params(_job_required_params, job) + + # Fetch required resources based on param inputs. Safely ignore mypy warnings here, as they were checked above. + score_set = job_manager.db.scalars(select(ScoreSet).where(ScoreSet.id == job.job_params["score_set_id"])).one() # type: ignore + correlation_id = job.job_params["correlation_id"] # type: ignore + year = int(job.job_params["year"]) # type: ignore + month = int(job.job_params["month"]) # type: ignore + + validate_clinvar_variant_summary_date(month, year) + # Version must be in MM_YYYY format + clinvar_version = f"{month:02d}_{year}" + + # Setup initial context and progress + job_manager.save_to_context( + { + "application": "mavedb-worker", + "function": "refresh_clinvar_controls", + "resource": score_set.urn, + "correlation_id": correlation_id, + "clinvar_year": year, + "clinvar_month": month, + } + ) + job_manager.update_progress(0, 100, f"Starting ClinVar clinical control refresh for version {clinvar_version}.") + logger.info(msg="Started ClinVar clinical control refresh", extra=job_manager.logging_context()) + + job_manager.update_progress(1, 100, "Fetching ClinVar variant summary TSV data.") + logger.debug("Fetching ClinVar variant summary TSV data.", extra=job_manager.logging_context()) + + # Fetch and parse ClinVar variant summary TSV data + blocking = functools.partial(fetch_clinvar_variant_summary_tsv, month, year) + loop = asyncio.get_running_loop() + tsv_content = await loop.run_in_executor(ctx["pool"], blocking) + tsv_data = parse_clinvar_variant_summary(tsv_content) + + job_manager.update_progress(10, 100, "Fetched and parsed ClinVar variant summary TSV data.") + logger.debug("Fetched and parsed ClinVar variant summary TSV data.", extra=job_manager.logging_context()) + + variants_to_refresh = job_manager.db.scalars( + select(MappedVariant) + .join(Variant) + .where( + Variant.score_set_id == score_set.id, + MappedVariant.current.is_(True), + ) + ).all() + total_variants_to_refresh = len(variants_to_refresh) + job_manager.save_to_context({"total_variants_to_refresh": total_variants_to_refresh}) + + logger.info( + f"Refreshing ClinVar data for {total_variants_to_refresh} variants.", extra=job_manager.logging_context() + ) + annotation_manager = AnnotationStatusManager(job_manager.db) + for index, mapped_variant in enumerate(variants_to_refresh): + job_manager.save_to_context({"mapped_variant_id": mapped_variant.id, "progress_index": index}) + if total_variants_to_refresh > 0 and index % (max(total_variants_to_refresh // 100, 1)) == 0: + job_manager.update_progress( + 10 + int((index / total_variants_to_refresh) * 90), + 100, + f"Refreshing ClinVar data for {total_variants_to_refresh} variants ({index} completed).", + ) + + clingen_id = mapped_variant.clingen_allele_id + job_manager.save_to_context({"clingen_allele_id": clingen_id}) + + if clingen_id is None: + annotation_manager.add_annotation( + variant_id=mapped_variant.variant_id, # type: ignore + annotation_type=AnnotationType.CLINVAR_CONTROL, + version=clinvar_version, + status=AnnotationStatus.SKIPPED, + annotation_data={ + "job_run_id": job_manager.job_id, + "error_message": "Mapped variant does not have an associated ClinGen allele ID.", + "failure_category": "missing_clingen_allele_id", + }, + ) + logger.debug( + "Mapped variant does not have an associated ClinGen allele ID.", extra=job_manager.logging_context() + ) + continue + + if clingen_id is not None and "," in clingen_id: + annotation_manager.add_annotation( + variant_id=mapped_variant.variant_id, # type: ignore + annotation_type=AnnotationType.CLINVAR_CONTROL, + version=clinvar_version, + status=AnnotationStatus.SKIPPED, + annotation_data={ + "job_run_id": job_manager.job_id, + "error_message": "Multi-variant ClinGen allele IDs cannot be associated with ClinVar data.", + "failure_category": "multi_variant_clingen_allele_id", + }, + ) + logger.debug("Detected a multi-variant ClinGen allele ID, skipping.", extra=job_manager.logging_context()) + continue + + # Fetch associated ClinVar Allele ID from ClinGen API + try: + # Guaranteed based on our query filters. + clinvar_allele_id = get_associated_clinvar_allele_id(clingen_id) # type: ignore + except requests.exceptions.RequestException as exc: + annotation_manager.add_annotation( + variant_id=mapped_variant.variant_id, # type: ignore + annotation_type=AnnotationType.CLINVAR_CONTROL, + version=clinvar_version, + status=AnnotationStatus.FAILED, + annotation_data={ + "job_run_id": job_manager.job_id, + "error_message": f"Failed to retrieve ClinVar allele ID from ClinGen API: {str(exc)}", + "failure_category": "clingen_api_error", + }, + ) + logger.error( + f"Failed to retrieve ClinVar allele ID from ClinGen API for ClinGen allele ID {clingen_id}.", + extra=job_manager.logging_context(), + exc_info=exc, + ) + continue + + job_manager.save_to_context({"clinvar_allele_id": clinvar_allele_id}) + + if clinvar_allele_id is None: + annotation_manager.add_annotation( + variant_id=mapped_variant.variant_id, # type: ignore + annotation_type=AnnotationType.CLINVAR_CONTROL, + version=clinvar_version, + status=AnnotationStatus.SKIPPED, + annotation_data={ + "job_run_id": job_manager.job_id, + "error_message": "No ClinVar allele ID found for ClinGen allele ID.", + "failure_category": "no_associated_clinvar_allele_id", + }, + current=True, + ) + logger.debug("No ClinVar allele ID found for ClinGen allele ID.", extra=job_manager.logging_context()) + continue + + if clinvar_allele_id not in tsv_data: + annotation_manager.add_annotation( + variant_id=mapped_variant.variant_id, # type: ignore + annotation_type=AnnotationType.CLINVAR_CONTROL, + version=clinvar_version, + status=AnnotationStatus.SKIPPED, + annotation_data={ + "job_run_id": job_manager.job_id, + "error_message": "No ClinVar data found for ClinVar allele ID.", + "failure_category": "no_clinvar_variant_data", + }, + ) + logger.debug("No ClinVar variant data found for ClinGen allele ID.", extra=job_manager.logging_context()) + continue + + variant_data = tsv_data[clinvar_allele_id] + identifier = str(clinvar_allele_id) + + clinvar_variant = job_manager.db.scalars( + select(ClinicalControl).where( + ClinicalControl.db_identifier == identifier, + ClinicalControl.db_version == clinvar_version, + ClinicalControl.db_name == "ClinVar", + ) + ).one_or_none() + if clinvar_variant is None: + job_manager.save_to_context({"creating_new_clinvar_variant": True}) + clinvar_variant = ClinicalControl( + db_identifier=identifier, + gene_symbol=variant_data.get("GeneSymbol"), + clinical_significance=variant_data.get("ClinicalSignificance"), + clinical_review_status=variant_data.get("ReviewStatus"), + db_version=clinvar_version, + db_name="ClinVar", + ) + else: + job_manager.save_to_context({"creating_new_clinvar_variant": False}) + clinvar_variant.gene_symbol = variant_data.get("GeneSymbol") + clinvar_variant.clinical_significance = variant_data.get("ClinicalSignificance") + clinvar_variant.clinical_review_status = variant_data.get("ReviewStatus") + + # Add and flush the updated/new clinical control + job_manager.db.add(clinvar_variant) + job_manager.db.flush() + + # Link the clinical control to the mapped variant if not already linked + if clinvar_variant not in mapped_variant.clinical_controls: + mapped_variant.clinical_controls.append(clinvar_variant) + job_manager.db.add(mapped_variant) + logger.debug("Linked ClinicalControl to MappedVariant.", extra=job_manager.logging_context()) + + annotation_manager.add_annotation( + variant_id=mapped_variant.variant_id, # type: ignore + annotation_type=AnnotationType.CLINVAR_CONTROL, + version=clinvar_version, + status=AnnotationStatus.SUCCESS, + annotation_data={ + "job_run_id": job_manager.job_id, + "success_data": { + "clinvar_allele_id": clinvar_allele_id, + }, + }, + current=True, + ) + + logger.debug("Updated ClinVar data for ClinGen allele ID.", extra=job_manager.logging_context()) + + logger.info( + msg=f"Fetched ClinVar variant summary data version {clinvar_version}", extra=job_manager.logging_context() + ) + job_manager.update_progress(100, 100, "Completed ClinVar clinical control refresh.") + + return {"status": "ok", "data": {}, "exception": None} diff --git a/src/mavedb/worker/jobs/registry.py b/src/mavedb/worker/jobs/registry.py index af1e9836..d2aab06b 100644 --- a/src/mavedb/worker/jobs/registry.py +++ b/src/mavedb/worker/jobs/registry.py @@ -18,6 +18,7 @@ from mavedb.worker.jobs.external_services import ( link_gnomad_variants, poll_uniprot_mapping_jobs_for_score_set, + refresh_clinvar_controls, submit_score_set_mappings_to_car, submit_score_set_mappings_to_ldh, submit_uniprot_mapping_jobs_for_score_set, @@ -36,6 +37,7 @@ # External service jobs submit_score_set_mappings_to_car, submit_score_set_mappings_to_ldh, + refresh_clinvar_controls, submit_uniprot_mapping_jobs_for_score_set, poll_uniprot_mapping_jobs_for_score_set, link_gnomad_variants, @@ -95,6 +97,13 @@ "key": "submit_score_set_mappings_to_ldh", "type": JobType.MAPPED_VARIANT_ANNOTATION, }, + refresh_clinvar_controls: { + "dependencies": [], + "params": {"score_set_id": None, "correlation_id": None, "year": None, "month": None}, + "function": "refresh_clinvar_controls", + "key": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + }, submit_uniprot_mapping_jobs_for_score_set: { "dependencies": [], "params": {"score_set_id": None, "correlation_id": None}, diff --git a/tests/conftest.py b/tests/conftest.py index 82b43aeb..acebc569 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -119,6 +119,15 @@ def _db_session_cm(): # the test version. @pytest.fixture def patch_db_session_ctxmgr(db_session_fixture): + """Patches all known locations of the db_session fixture to use the test version. + + To use this fixture, add it to the pytestmark list of a test module: + pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + + If you see an error about a test being unable to connect to the database, you + likely need to add another patch here for the module that is trying to use + db_session or include the above mark in your test module. + """ with ( mock.patch("mavedb.db.session.db_session", db_session_fixture), mock.patch("mavedb.worker.lib.decorators.utils.db_session", db_session_fixture), diff --git a/tests/conftest_optional.py b/tests/conftest_optional.py index 3735634e..579fbd5c 100644 --- a/tests/conftest_optional.py +++ b/tests/conftest_optional.py @@ -24,7 +24,7 @@ from mavedb.server_main import app from mavedb.worker.jobs import BACKGROUND_CRONJOBS, BACKGROUND_FUNCTIONS from mavedb.worker.lib.managers.types import JobResultData -from tests.helpers.constants import ADMIN_USER, EXTRA_USER, TEST_SEQREPO_INITIAL_STATE, TEST_USER +from tests.helpers.constants import ADMIN_USER, EXTRA_USER, TEST_SEQREPO_INITIAL_STATE, TEST_USER, VALID_CAID #################################################################################################### # REDIS @@ -447,7 +447,7 @@ def athena_engine(): "locus.contig": "chr1", "locus.position": 12345, "alleles": "[G, A]", - "caid": "CA123", + "caid": VALID_CAID, "joint.freq.all.ac": 23, "joint.freq.all.an": 32432423, "joint.fafmax.faf95_max_gen_anc": "anc1", diff --git a/tests/helpers/constants.py b/tests/helpers/constants.py index e46c2c2a..208a61e2 100644 --- a/tests/helpers/constants.py +++ b/tests/helpers/constants.py @@ -43,6 +43,7 @@ VALID_PRO_ACCESSION = "NP_001637.4" VALID_GENE = "BRCA1" VALID_UNIPROT_ACCESSION = "P05067" +VALID_CAID = "CA9765210" VALID_ENSEMBL_IDENTIFIER = "ENST00000530893.6" diff --git a/tests/lib/clinvar/network/test_utils.py b/tests/lib/clinvar/network/test_utils.py new file mode 100644 index 00000000..6bbf3650 --- /dev/null +++ b/tests/lib/clinvar/network/test_utils.py @@ -0,0 +1,23 @@ +from datetime import datetime + +import pytest + +from mavedb.lib.clinvar.utils import fetch_clinvar_variant_summary_tsv + + +@pytest.mark.network +@pytest.mark.slow +class TestFetchClinvarVariantSummaryTSVIntegration: + def test_fetch_recent_variant_summary(self): + now = datetime.now() + # Attempt to fetch the most recent available month (previous month) + month = now.month - 1 if now.month > 1 else 12 + year = now.year if now.month > 1 else now.year - 1 + + content = fetch_clinvar_variant_summary_tsv(month, year) + assert content.startswith(b"\x1f\x8b") # Gzip magic number + + def test_fetch_older_variant_summary(self): + # Fetch an older known date + content = fetch_clinvar_variant_summary_tsv(2, 2015) + assert content.startswith(b"\x1f\x8b") # Gzip magic number diff --git a/tests/lib/clinvar/test_utils.py b/tests/lib/clinvar/test_utils.py new file mode 100644 index 00000000..7dd19089 --- /dev/null +++ b/tests/lib/clinvar/test_utils.py @@ -0,0 +1,148 @@ +import csv +import gzip +import io +from datetime import datetime + +import pytest +import requests + +from mavedb.lib.clinvar.utils import ( + fetch_clinvar_variant_summary_tsv, + parse_clinvar_variant_summary, + validate_clinvar_variant_summary_date, +) + + +@pytest.mark.unit +class TestValidateClinvarVariantSummaryDate: + def test_valid_past_date(self): + # Should not raise for a valid past date + validate_clinvar_variant_summary_date(2, 2015) + + def test_valid_current_month_and_year(self): + now = datetime.now() + # Should not raise for current month and year + validate_clinvar_variant_summary_date(now.month, now.year) + + def test_invalid_month_low(self): + with pytest.raises(ValueError, match="Month must be an integer between 1 and 12."): + validate_clinvar_variant_summary_date(0, 2020) + + def test_invalid_month_high(self): + with pytest.raises(ValueError, match="Month must be an integer between 1 and 12."): + validate_clinvar_variant_summary_date(13, 2020) + + def test_year_before_2015(self): + with pytest.raises(ValueError, match="ClinVar archived data is only available from February 2015 onwards."): + validate_clinvar_variant_summary_date(6, 2014) + + def test_year_2015_before_february(self): + with pytest.raises(ValueError, match="ClinVar archived data is only available from February 2015 onwards."): + validate_clinvar_variant_summary_date(1, 2015) + + def test_year_in_future(self): + future_year = datetime.now().year + 1 + with pytest.raises(ValueError, match="Cannot fetch ClinVar data for future years."): + validate_clinvar_variant_summary_date(6, future_year) + + def test_month_in_future_for_current_year(self): + now = datetime.now() + if now.month == 12: + pytest.skip("December, no future month in current year") + return # December, no future month in current year + + future_month = now.month + 1 if now.month < 12 else 12 + with pytest.raises(ValueError, match="Cannot fetch ClinVar data for future months."): + validate_clinvar_variant_summary_date(future_month, now.year) + + +@pytest.mark.unit +class TestFetchClinvarVariantSummaryTSV: + class MockResponse: + def __init__(self, content, status_code=200, raise_exc=None): + self.content = content + self.status_code = status_code + self._raise_exc = raise_exc + + def raise_for_status(self): + if self._raise_exc: + raise self._raise_exc + + def test_fetch_clinvar_variant_summary_tsv_top_level_success(self, monkeypatch): + # Simulate successful fetch from top-level URL + mock_content = b"mock gzipped content" + + def mock_get(url, stream=True): + return self.MockResponse(mock_content) + + monkeypatch.setattr("requests.get", mock_get) + result = fetch_clinvar_variant_summary_tsv(1, 2016) + assert result == mock_content + + def test_fetch_clinvar_variant_summary_tsv_archive_success(self, monkeypatch): + # Simulate top-level fails, archive succeeds + mock_content = b"archive gzipped content" + + def mock_get(url, stream=True): + if "variant_summary_2015-01.txt.gz" in url and "/2015/" not in url: + raise requests.RequestException("Top-level not found") + return self.MockResponse(mock_content) + + monkeypatch.setattr("requests.get", mock_get) + result = fetch_clinvar_variant_summary_tsv(1, 2016) + assert result == mock_content + + def test_fetch_clinvar_variant_summary_tsv_both_fail(self, monkeypatch): + # Simulate both URLs failing + def mock_get(url, stream=True): + raise requests.RequestException("Not found") + + monkeypatch.setattr("requests.get", mock_get) + with pytest.raises(requests.RequestException, match="Not found"): + fetch_clinvar_variant_summary_tsv(1, 2016) + + def test_fetch_clinvar_variant_summary_tsv_invalid_date(self, monkeypatch): + # Should raise ValueError before any network call + with pytest.raises(ValueError, match="Month must be an integer between 1 and 12."): + fetch_clinvar_variant_summary_tsv(0, 2020) + + +class TestParseClinvarVariantSummary: + def make_gzipped_tsv(self, text: str) -> bytes: + buf = io.BytesIO() + with gzip.GzipFile(fileobj=buf, mode="wb") as gz: + gz.write(text.encode("utf-8")) + return buf.getvalue() + + def test_parse_clinvar_variant_summary_basic(self): + tsv = "#AlleleID\tGeneSymbol\tClinicalSignificance\n" "123\tBRCA1\tPathogenic\n" "456\tTP53\tBenign\n" + gzipped = self.make_gzipped_tsv(tsv) + result = parse_clinvar_variant_summary(gzipped) + assert "123" in result + assert "456" in result + assert result["123"]["GeneSymbol"] == "BRCA1" + assert result["456"]["ClinicalSignificance"] == "Benign" + + def test_parse_clinvar_variant_summary_missing_alleleid_column(self): + tsv = "GeneSymbol\tClinicalSignificance\n" "BRCA1\tPathogenic\n" + gzipped = self.make_gzipped_tsv(tsv) + with pytest.raises(KeyError): + parse_clinvar_variant_summary(gzipped) + + def test_parse_clinvar_variant_summary_empty_content(self): + gzipped = self.make_gzipped_tsv("") + parse_clinvar_variant_summary(gzipped) + + def test_parse_clinvar_variant_summary_large_field(self): + large_field = "A" * (csv.field_size_limit() + 100) + tsv = f"#AlleleID\tGeneSymbol\n999\t{large_field}\n" + gzipped = self.make_gzipped_tsv(tsv) + result = parse_clinvar_variant_summary(gzipped) + assert result["999"]["GeneSymbol"] == large_field + + def test_parse_clinvar_variant_summary_does_not_alter_field_size_limit(self): + default_limit = csv.field_size_limit() + tsv = "#AlleleID\tGeneSymbol\n1\tBRCA1\n" + gzipped = self.make_gzipped_tsv(tsv) + parse_clinvar_variant_summary(gzipped) + assert csv.field_size_limit() == default_limit diff --git a/tests/worker/jobs/conftest.py b/tests/worker/jobs/conftest.py index 4a41aaab..677b4955 100644 --- a/tests/worker/jobs/conftest.py +++ b/tests/worker/jobs/conftest.py @@ -7,6 +7,7 @@ from mavedb.models.pipeline import Pipeline from mavedb.models.score_set import ScoreSet from mavedb.models.variant import Variant +from tests.helpers.constants import VALID_CAID try: from .conftest_optional import * # noqa: F403, F401 @@ -87,6 +88,18 @@ def submit_score_set_mappings_to_car_params(with_populated_domain_data, sample_s } +@pytest.fixture +def refresh_clinvar_controls_sample_params(with_populated_domain_data, sample_score_set): + """Provide sample parameters for refresh_clinvar_controls job.""" + + return { + "correlation_id": "sample-correlation-id", + "score_set_id": sample_score_set.id, + "month": 1, + "year": 2026, + } + + ## Sample pipeline @@ -228,13 +241,14 @@ def setup_sample_variants_with_caid( session.commit() mapped_variant = MappedVariant( variant_id=variant.id, - clingen_allele_id="CA123", + clingen_allele_id=VALID_CAID, current=True, mapped_date="2024-01-01T00:00:00Z", mapping_api_version="1.0.0", ) session.add(mapped_variant) session.commit() + return variant, mapped_variant ## Uniprot Job Fixtures ## @@ -798,3 +812,61 @@ def with_full_dummy_pipeline(session, with_dummy_pipeline_start, sample_dummy_pi """Fixture to ensure dummy pipeline steps exist in the database.""" session.add(sample_dummy_pipeline_step) session.commit() + + +@pytest.fixture +def sample_refresh_clinvar_controls_job_run(refresh_clinvar_controls_sample_params): + """Create a JobRun instance for refresh_clinvar_controls job.""" + + return JobRun( + urn="test:refresh_clinvar_controls", + job_type="refresh_clinvar_controls", + job_function="refresh_clinvar_controls", + max_retries=3, + retry_count=0, + job_params=refresh_clinvar_controls_sample_params, + ) + + +@pytest.fixture +def with_refresh_clinvar_controls_job(session, sample_refresh_clinvar_controls_job_run): + """Add a refresh_clinvar_controls job run to the session.""" + + session.add(sample_refresh_clinvar_controls_job_run) + session.commit() + + +@pytest.fixture +def sample_refresh_clinvar_controls_pipeline(): + """Create a pipeline instance for refresh_clinvar_controls job.""" + + return Pipeline( + urn="test:refresh_clinvar_controls_pipeline", + name="Refresh ClinVar Controls Pipeline", + ) + + +@pytest.fixture +def with_refresh_clinvar_controls_pipeline( + session, + sample_refresh_clinvar_controls_pipeline, +): + """Add a refresh_clinvar_controls pipeline to the session.""" + + session.add(sample_refresh_clinvar_controls_pipeline) + session.commit() + + +@pytest.fixture +def sample_refresh_clinvar_controls_job_in_pipeline( + session, + with_refresh_clinvar_controls_job, + with_refresh_clinvar_controls_pipeline, + sample_refresh_clinvar_controls_job_run, + sample_refresh_clinvar_controls_pipeline, +): + """Provide a context with a refresh_clinvar_controls job run and pipeline.""" + + sample_refresh_clinvar_controls_job_run.pipeline_id = sample_refresh_clinvar_controls_pipeline.id + session.commit() + return sample_refresh_clinvar_controls_job_run diff --git a/tests/worker/jobs/external_services/network/test_clinvar.py b/tests/worker/jobs/external_services/network/test_clinvar.py new file mode 100644 index 00000000..54ae2fff --- /dev/null +++ b/tests/worker/jobs/external_services/network/test_clinvar.py @@ -0,0 +1,48 @@ +# ruff: noqa: E402 + +import pytest + +pytest.importorskip("arq") + +from sqlalchemy import select + +from mavedb.models.clinical_control import ClinicalControl +from mavedb.models.enums.annotation_type import AnnotationType +from mavedb.models.enums.job_pipeline import AnnotationStatus, JobStatus +from mavedb.models.variant_annotation_status import VariantAnnotationStatus + + +@pytest.mark.asyncio +@pytest.mark.integration +@pytest.mark.network +@pytest.mark.slow +class TestE2ERefreshClinvarControls: + async def test_refresh_clinvar_controls_e2e( + self, + session, + arq_redis, + arq_worker, + standalone_worker_context, + setup_sample_variants_with_caid, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + ): + """Test the end-to-end flow of refreshing ClinVar clinical controls.""" + await arq_redis.enqueue_job("refresh_clinvar_controls", sample_refresh_clinvar_controls_job_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that clinical controls were added successfully + clinical_controls = session.scalars(select(ClinicalControl)).all() + assert len(clinical_controls) == 1 + assert clinical_controls[0].db_identifier == "3045425" + + # Verify that annotation status was added + annotation_statuses = session.scalars(select(VariantAnnotationStatus)).all() + assert len(annotation_statuses) == 1 + assert annotation_statuses[0].status == AnnotationStatus.SUCCESS + assert annotation_statuses[0].annotation_type == AnnotationType.CLINVAR_CONTROL + + # Verify that the job run was completed successfully + session.refresh(sample_refresh_clinvar_controls_job_run) + assert sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED diff --git a/tests/worker/jobs/external_services/test_clinvar.py b/tests/worker/jobs/external_services/test_clinvar.py new file mode 100644 index 00000000..a7eeb6f2 --- /dev/null +++ b/tests/worker/jobs/external_services/test_clinvar.py @@ -0,0 +1,1470 @@ +# ruff: noqa: E402 + +import pytest +import requests + +from mavedb.models.clinical_control import ClinicalControl +from mavedb.models.enums.annotation_type import AnnotationType +from mavedb.models.enums.job_pipeline import AnnotationStatus, JobStatus, PipelineStatus +from mavedb.models.variant_annotation_status import VariantAnnotationStatus + +pytest.importorskip("arq") + +import gzip +from asyncio.unix_events import _UnixSelectorEventLoop +from unittest.mock import call, patch + +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.score_set import ScoreSet +from mavedb.models.variant import Variant +from mavedb.worker.jobs.external_services.clinvar import refresh_clinvar_controls +from mavedb.worker.lib.managers.job_manager import JobManager + +pytestmark = pytest.mark.usefixtures("patch_db_session_ctxmgr") + + +async def mock_fetch_tsv(*args, **kwargs): + data = b"#AlleleID\tClinicalSignificance\tGeneSymbol\tReviewStatus\nVCV000000123\tbenign\tTEST\treviewed by expert panel" + return gzip.compress(data) + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestRefreshClinvarControlsUnit: + """Tests for the refresh_clinvar_controls job function.""" + + async def test_refresh_clinvar_controls_invalid_month_raises( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + ): + # edit the job run to have an invalid month + sample_refresh_clinvar_controls_job_run.job_params["month"] = 13 + session.commit() + + with pytest.raises(ValueError, match="Month must be an integer between 1 and 12."): + await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + async def test_refresh_clinvar_controls_invalid_year_raises( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + ): + # edit the job run to have an invalid year + sample_refresh_clinvar_controls_job_run.job_params["year"] = 1999 + session.commit() + + with pytest.raises(ValueError, match="ClinVar archived data is only available from February 2015 onwards."): + await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + async def test_refresh_clinvar_controls_propagates_exception_during_fetch( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + ): + # Mock the fetch_clinvar_variant_data function to raise an exception + async def awaitable_exception(*args, **kwargs): + raise Exception("Network error") + + with ( + pytest.raises(Exception, match="Network error"), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=awaitable_exception(), + ), + ): + await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + async def test_refresh_clinvar_controls_no_mapped_variants( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + ): + """Test that the job completes successfully when there are no mapped variants.""" + + async def awaitable_noop(*args, **kwargs): + return {} + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=awaitable_noop(), + ), + patch("mavedb.worker.jobs.external_services.clinvar.parse_clinvar_variant_summary"), + ): + result = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result["status"] == "ok" + + async def test_refresh_clinvar_controls_no_variants_have_caids( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + ): + """Test that the job completes successfully when no variants have CAIDs.""" + # Add a variant without a CAID + score_set = session.get(ScoreSet, sample_refresh_clinvar_controls_job_run.job_params["score_set_id"]) + variant = Variant( + urn="urn:variant:test-variant-no-caid", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.2G>A", + hgvs_pro="NP_000000.1:p.Val2Ile", + data={"hgvs_c": "NM_000000.1:c.2G>A", "hgvs_p": "NP_000000.1:p.Val2Ile"}, + ) + session.add(variant) + session.commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant) + session.commit() + + with patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ): + result = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant without a CAID + variant_no_caid = ( + session.query(VariantAnnotationStatus).filter(VariantAnnotationStatus.variant_id == variant.id).one() + ) + assert variant_no_caid.status == AnnotationStatus.SKIPPED + assert variant_no_caid.annotation_type == AnnotationType.CLINVAR_CONTROL + assert variant_no_caid.error_message == "Mapped variant does not have an associated ClinGen allele ID." + + async def test_refresh_clinvar_controls_variants_are_multivariants( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Test that the job completes successfully when all variants are multi-variant CAIDs.""" + # Update the mapped variant to have a multi-variant CAID + mapped_variant = session.query(MappedVariant).first() + mapped_variant.clingen_allele_id = "CA-MULTI-001,CA-MULTI-002" + session.commit() + + with patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ): + result = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the multi-variant CAID + variant_with_multicid = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert variant_with_multicid.status == AnnotationStatus.SKIPPED + assert variant_with_multicid.annotation_type == AnnotationType.CLINVAR_CONTROL + assert ( + variant_with_multicid.error_message + == "Multi-variant ClinGen allele IDs cannot be associated with ClinVar data." + ) + + async def test_refresh_clinvar_controls_clingen_api_failure( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Test that the job handles ClinGen API failures gracefully.""" + + # Mock the get_associated_clinvar_allele_id function to raise an exception + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + side_effect=requests.exceptions.RequestException("ClinGen API error"), + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant due to ClinGen API failure + mapped_variant = session.query(MappedVariant).first() + variant_with_api_failure = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert variant_with_api_failure.status == AnnotationStatus.FAILED + assert variant_with_api_failure.annotation_type == AnnotationType.CLINVAR_CONTROL + assert "Failed to retrieve ClinVar allele ID from ClinGen API" in variant_with_api_failure.error_message + + async def test_refresh_clinvar_controls_no_associated_clinvar_allele_id( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Test that the job handles no associated ClinVar Allele ID gracefully.""" + + # Mock the get_associated_clinvar_allele_id function to return None + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value=None, + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant due to no associated ClinVar Allele ID + mapped_variant = session.query(MappedVariant).first() + variant_no_clinvar_allele = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert variant_no_clinvar_allele.status == AnnotationStatus.SKIPPED + assert variant_no_clinvar_allele.annotation_type == AnnotationType.CLINVAR_CONTROL + assert "No ClinVar allele ID found for ClinGen allele ID" in variant_no_clinvar_allele.error_message + + async def test_refresh_clinvar_controls_no_clinvar_data_found( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Test that the job handles no ClinVar data found for the associated ClinVar Allele ID.""" + + async def mock_fetch_tsv(*args, **kwargs): + data = b"#AlleleID\tClinicalSignificance\tGeneSymbol\tReviewStatus\nVCV000000001\tbenign\tTEST\treviewed by expert panel" + return gzip.compress(data) + + # Mock the get_associated_clinvar_allele_id function to return a ClinVar Allele ID + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value="VCV000000123", + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant due to no ClinVar data found + mapped_variant = session.query(MappedVariant).first() + variant_no_clinvar_data = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert variant_no_clinvar_data.status == AnnotationStatus.SKIPPED + assert variant_no_clinvar_data.annotation_type == AnnotationType.CLINVAR_CONTROL + assert "No ClinVar data found for ClinVar allele ID" in variant_no_clinvar_data.error_message + + async def test_refresh_clinvar_controls_successful_annotation_existing_control( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Test that the job successfully annotates a variant with ClinVar control data.""" + + # Mock the get_associated_clinvar_allele_id function to return a ClinVar Allele ID + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value="VCV000000123", + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant with successful annotation + mapped_variant = session.query(MappedVariant).first() + annotated_variant = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert annotated_variant.status == AnnotationStatus.SUCCESS + assert annotated_variant.annotation_type == AnnotationType.CLINVAR_CONTROL + assert annotated_variant.error_message is None + + async def test_refresh_clinvar_controls_successful_annotation_new_control( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + ): + """Test that the job successfully annotates a variant with ClinVar control data when no prior status exists.""" + # Add a variant and mapped variant to the database with a CAID + score_set = session.get(ScoreSet, sample_refresh_clinvar_controls_job_run.job_params["score_set_id"]) + variant = Variant( + urn="urn:variant:test-variant-with-caid-2", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.3C>T", + hgvs_pro="NP_000000.1:p.Ala3Val", + data={"hgvs_c": "NM_000000.1:c.3C>T", "hgvs_p": "NP_000000.1:p.Ala3Val"}, + ) + session.add(variant) + session.commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + clingen_allele_id="CA124", + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant) + session.commit() + + # Mock the get_associated_clinvar_allele_id function to return a ClinVar Allele ID + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value="VCV000000123", + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant with successful annotation + annotated_variant = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert annotated_variant.status == AnnotationStatus.SUCCESS + assert annotated_variant.annotation_type == AnnotationType.CLINVAR_CONTROL + assert annotated_variant.error_message is None + + async def test_refresh_clinvar_controls_idempotent_run( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Test that running the job multiple times does not create duplicate annotation statuses.""" + + # Mock the get_associated_clinvar_allele_id function to return a ClinVar Allele ID + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value="VCV000000123", + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + side_effect=[mock_fetch_tsv(), mock_fetch_tsv()], + ), + ): + # First run + result1 = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + session.commit() + + # Second run + result2 = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result1["status"] == "ok" + assert result2["status"] == "ok" + + # Verify only one clinical control annotation exists for the variant + clinical_controls = session.query(ClinicalControl).all() + assert len(clinical_controls) == 1 + + # Verify two annotated variants exist but both reflect the same successful annotation, and only + # one is current + annotated_variants = session.query(VariantAnnotationStatus).all() + assert len(annotated_variants) == 2 + statuses = [av.status for av in annotated_variants] + assert statuses.count(AnnotationStatus.SUCCESS) == 2 + current_statuses = [av for av in annotated_variants if av.current] + assert len(current_statuses) == 1 + + async def test_refresh_clinvar_controls_partial_failure( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Test that the job handles partial failures gracefully.""" + + variant1, mapped_variant1 = setup_sample_variants_with_caid + + # Add an additional mapped variant to the database with a CAID + score_set = session.get(ScoreSet, sample_refresh_clinvar_controls_job_run.job_params["score_set_id"]) + variant2 = Variant( + urn="urn:variant:test-variant-with-caid-2", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.4G>C", + hgvs_pro="NP_000000.1:p.Gly4Ala", + data={"hgvs_c": "NM_000000.1:c.4G>C", "hgvs_p": "NP_000000.1:p.Gly4Ala"}, + ) + session.add(variant2) + session.commit() + mapped_variant2 = MappedVariant( + variant_id=variant2.id, + clingen_allele_id="CA125", + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant2) + session.commit() + + # Mock the get_associated_clinvar_allele_id function to raise an exception for the first call + def side_effect_get_associated_clinvar_allele_id(clingen_allele_id): + if clingen_allele_id == "CA125": + raise requests.exceptions.RequestException("ClinGen API error") + return "VCV000000123" + + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + side_effect=side_effect_get_associated_clinvar_allele_id, + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result["status"] == "ok" + + # Verify annotation statuses for both variants + variant_with_api_failure = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant2.variant_id) + .one() + ) + assert variant_with_api_failure.status == AnnotationStatus.FAILED + assert variant_with_api_failure.annotation_type == AnnotationType.CLINVAR_CONTROL + assert "Failed to retrieve ClinVar allele ID from ClinGen API" in variant_with_api_failure.error_message + + annotated_variant2 = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant1.variant_id) + .one() + ) + assert annotated_variant2.status == AnnotationStatus.SUCCESS + assert annotated_variant2.annotation_type == AnnotationType.CLINVAR_CONTROL + assert annotated_variant2.error_message is None + + async def test_refresh_clinvar_controls_updates_progress( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Test that the job updates progress correctly.""" + + # Mock the get_associated_clinvar_allele_id function to return a ClinVar Allele ID + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value="VCV000000123", + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + patch.object(JobManager, "update_progress") as mock_update_progress, + ): + result = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result["status"] == "ok" + + mock_update_progress.assert_has_calls( + [ + call(0, 100, "Starting ClinVar clinical control refresh for version 01_2026."), + call(1, 100, "Fetching ClinVar variant summary TSV data."), + call(10, 100, "Fetched and parsed ClinVar variant summary TSV data."), + call(10, 100, "Refreshing ClinVar data for 1 variants (0 completed)."), + call(100, 100, "Completed ClinVar clinical control refresh."), + ] + ) + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestRefreshClinvarControlsIntegration: + """Integration tests for the refresh_clinvar_controls job function.""" + + async def test_refresh_clinvar_controls_no_mapped_variants( + self, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run, + ): + """Integration test: job completes successfully when there are no mapped variants.""" + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls(mock_worker_ctx, sample_refresh_clinvar_controls_job_run.id) + + assert result["status"] == "ok" + + # Verify no controls were added + clinical_controls = session.query(ClinicalControl).all() + assert len(clinical_controls) == 0 + + # Verify no annotation statuses were created + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + + # Verify job run status is marked as completed + session.refresh(sample_refresh_clinvar_controls_job_run) + assert sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED + + async def test_refresh_clinvar_controls_no_variants_with_caid( + self, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run, + ): + """Integration test: job completes successfully when no variants have CAIDs.""" + # Add a variant without a CAID + score_set = session.get(ScoreSet, sample_refresh_clinvar_controls_job_run.job_params["score_set_id"]) + variant = Variant( + urn="urn:variant:integration-test-variant-no-caid", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.5T>A", + hgvs_pro="NP_000000.1:p.Leu5Gln", + data={"hgvs_c": "NM_000000.1:c.5T>A", "hgvs_p": "NP_000000.1:p.Leu5Gln"}, + ) + session.add(variant) + session.commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant) + session.commit() + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls(mock_worker_ctx, sample_refresh_clinvar_controls_job_run.id) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant without a CAID + variant_no_caid = ( + session.query(VariantAnnotationStatus).filter(VariantAnnotationStatus.variant_id == variant.id).one() + ) + assert variant_no_caid.status == AnnotationStatus.SKIPPED + assert variant_no_caid.annotation_type == AnnotationType.CLINVAR_CONTROL + assert variant_no_caid.error_message == "Mapped variant does not have an associated ClinGen allele ID." + + # Verify no clinical controls were added + clinical_controls = session.query(ClinicalControl).all() + assert len(clinical_controls) == 0 + + # Verify job run status is marked as completed + session.refresh(sample_refresh_clinvar_controls_job_run) + assert sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED + + async def test_refresh_clinvar_controlsvariants_are_multivariants( + self, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run, + ): + """Integration test: job completes successfully when all variants are multi-variant CAIDs.""" + # Add a variant with a multi-variant CAID + score_set = session.get(ScoreSet, sample_refresh_clinvar_controls_job_run.job_params["score_set_id"]) + variant = Variant( + urn="urn:variant:integration-test-variant-multicid", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.6A>G", + hgvs_pro="NP_000000.1:p.Thr6Ala", + data={"hgvs_c": "NM_000000.1:c.6A>G", "hgvs_p": "NP_000000.1:p.Thr6Ala"}, + ) + session.add(variant) + session.commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + clingen_allele_id="CA-MULTI-003,CA-MULTI-004", + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant) + session.commit() + + with ( + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls(mock_worker_ctx, sample_refresh_clinvar_controls_job_run.id) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the multi-variant CAID + variant_with_multicid = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert variant_with_multicid.status == AnnotationStatus.SKIPPED + assert variant_with_multicid.annotation_type == AnnotationType.CLINVAR_CONTROL + assert ( + variant_with_multicid.error_message + == "Multi-variant ClinGen allele IDs cannot be associated with ClinVar data." + ) + + # Verify no clinical controls were added + clinical_controls = session.query(ClinicalControl).all() + assert len(clinical_controls) == 0 + + # Verify job run status is marked as completed + session.refresh(sample_refresh_clinvar_controls_job_run) + assert sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED + + async def test_refresh_clinvar_controls_no_associated_clinvar_allele_id( + self, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run, + ): + """Integration test: job handles no associated ClinVar Allele ID gracefully.""" + # Add a variant with a CAID + score_set = session.get(ScoreSet, sample_refresh_clinvar_controls_job_run.job_params["score_set_id"]) + variant = Variant( + urn="urn:variant:integration-test-variant-with-caid", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.7C>A", + hgvs_pro="NP_000000.1:p.Ser7Tyr", + data={"hgvs_c": "NM_000000.1:c.7C>A", "hgvs_p": "NP_000000.1:p.Ser7Tyr"}, + ) + session.add(variant) + session.commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + clingen_allele_id="CA126", + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant) + session.commit() + + # Mock the get_associated_clinvar_allele_id function to return None + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value=None, + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls(mock_worker_ctx, sample_refresh_clinvar_controls_job_run.id) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant due to no associated ClinVar Allele ID + variant_no_clinvar_allele = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert variant_no_clinvar_allele.status == AnnotationStatus.SKIPPED + assert variant_no_clinvar_allele.annotation_type == AnnotationType.CLINVAR_CONTROL + assert "No ClinVar allele ID found for ClinGen allele ID" in variant_no_clinvar_allele.error_message + + # Verify no clinical controls were added + clinical_controls = session.query(ClinicalControl).all() + assert len(clinical_controls) == 0 + + # Verify job run status is marked as completed + session.refresh(sample_refresh_clinvar_controls_job_run) + assert sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED + + async def test_refresh_clinvar_controls_no_clinvar_data( + self, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run, + ): + """Integration test: job handles no ClinVar data found for the associated ClinVar Allele ID.""" + # Add a variant with a CAID + score_set = session.get(ScoreSet, sample_refresh_clinvar_controls_job_run.job_params["score_set_id"]) + variant = Variant( + urn="urn:variant:integration-test-variant-with-caid", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.8G>T", + hgvs_pro="NP_000000.1:p.Val8Phe", + data={"hgvs_c": "NM_000000.1:c.8G>T", "hgvs_p": "NP_000000.1:p.Val8Phe"}, + ) + session.add(variant) + session.commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + clingen_allele_id="CA127", + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant) + session.commit() + + # Mock the get_associated_clinvar_allele_id function to return a ClinVar Allele ID + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value="VCV000000001", + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls(mock_worker_ctx, sample_refresh_clinvar_controls_job_run.id) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant due to no ClinVar data found + variant_no_clinvar_data = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert variant_no_clinvar_data.status == AnnotationStatus.SKIPPED + assert variant_no_clinvar_data.annotation_type == AnnotationType.CLINVAR_CONTROL + assert "No ClinVar data found for ClinVar allele ID" in variant_no_clinvar_data.error_message + + # Verify no clinical controls were added + clinical_controls = session.query(ClinicalControl).all() + assert len(clinical_controls) == 0 + + # Verify job run status is marked as completed + session.refresh(sample_refresh_clinvar_controls_job_run) + assert sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED + + async def test_refresh_clinvar_controls_successful_annotation_existing_control( + self, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run, + ): + """Integration test: job successfully annotates a variant with ClinVar control data.""" + # Add a variant with a CAID + score_set = session.get(ScoreSet, sample_refresh_clinvar_controls_job_run.job_params["score_set_id"]) + variant = Variant( + urn="urn:variant:integration-test-variant-with-caid", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.9A>C", + hgvs_pro="NP_000000.1:p.Lys9Thr", + data={"hgvs_c": "NM_000000.1:c.9A>C", "hgvs_p": "NP_000000.1:p.Lys9Thr"}, + ) + session.add(variant) + session.commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + clingen_allele_id="CA128", + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant) + session.commit() + clinical_control = ClinicalControl( + db_name="ClinVar", + db_identifier="VCV000000123", + clinical_significance="likely pathogenic", + gene_symbol="TEST", + clinical_review_status="criteria provided, single submitter", + db_version="01_2026", + ) + session.add(clinical_control) + session.commit() + + mapped_variant.clinical_controls.append(clinical_control) + session.commit() + + # Mock the get_associated_clinvar_allele_id function to return a ClinVar Allele ID + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value="VCV000000123", + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls(mock_worker_ctx, sample_refresh_clinvar_controls_job_run.id) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant with successful annotation + annotated_variant = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert annotated_variant.status == AnnotationStatus.SUCCESS + assert annotated_variant.annotation_type == AnnotationType.CLINVAR_CONTROL + assert annotated_variant.error_message is None + + # Verify the clinical control was updated + session.refresh(clinical_control) + assert clinical_control.clinical_significance == "benign" + assert clinical_control.clinical_review_status == "reviewed by expert panel" + assert mapped_variant in clinical_control.mapped_variants + + # Verify job run status is marked as completed + session.refresh(sample_refresh_clinvar_controls_job_run) + assert sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED + + async def test_refresh_clinvar_controls_successful_annotation_new_control( + self, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run, + ): + """Integration test: job successfully annotates a variant with ClinVar control data when no prior status exists.""" + # Add a variant with a CAID + score_set = session.get(ScoreSet, sample_refresh_clinvar_controls_job_run.job_params["score_set_id"]) + variant = Variant( + urn="urn:variant:integration-test-variant-with-caid", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.10C>G", + hgvs_pro="NP_000000.1:p.Pro10Arg", + data={"hgvs_c": "NM_000000.1:c.10C>G", "hgvs_p": "NP_000000.1:p.Pro10Arg"}, + ) + session.add(variant) + session.commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + clingen_allele_id="CA129", + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant) + session.commit() + + # Mock the get_associated_clinvar_allele_id function to return a ClinVar Allele ID + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value="VCV000000123", + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls(mock_worker_ctx, sample_refresh_clinvar_controls_job_run.id) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant with successful annotation + annotated_variant = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert annotated_variant.status == AnnotationStatus.SUCCESS + assert annotated_variant.annotation_type == AnnotationType.CLINVAR_CONTROL + assert annotated_variant.error_message is None + + # Verify the clinical control was added + clinical_control = ( + session.query(ClinicalControl).filter(ClinicalControl.mapped_variants.contains(mapped_variant)).one() + ) + assert clinical_control.db_identifier == "VCV000000123" + + # Verify job run status is marked as completed + session.refresh(sample_refresh_clinvar_controls_job_run) + assert sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED + + async def test_refresh_clinvar_controls_successful_annotation_pipeline_context( + self, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + mock_worker_ctx, + sample_refresh_clinvar_controls_pipeline, + sample_refresh_clinvar_controls_job_in_pipeline, + ): + """Integration test: job successfully annotates a variant with ClinVar control data in a pipeline context.""" + # Add a variant with a CAID + score_set = session.get(ScoreSet, sample_refresh_clinvar_controls_job_in_pipeline.job_params["score_set_id"]) + variant = Variant( + urn="urn:variant:integration-test-variant-with-caid", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.12G>A", + hgvs_pro="NP_000000.1:p.Met12Ile", + data={"hgvs_c": "NM_000000.1:c.12G>A", "hgvs_p": "NP_000000.1:p.Met12Ile"}, + ) + session.add(variant) + session.commit() + mapped_variant = MappedVariant( + variant_id=variant.id, + clingen_allele_id="CA130", + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant) + session.commit() + + # Mock the get_associated_clinvar_allele_id function to return a ClinVar Allele ID + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value="VCV000000123", + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls(mock_worker_ctx, sample_refresh_clinvar_controls_job_in_pipeline.id) + + assert result["status"] == "ok" + + # Verify an annotation status was created for the variant with successful annotation + annotated_variant = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant.variant_id) + .one() + ) + assert annotated_variant.status == AnnotationStatus.SUCCESS + assert annotated_variant.annotation_type == AnnotationType.CLINVAR_CONTROL + assert annotated_variant.error_message is None + + # Verify the clinical control was added + clinical_control = ( + session.query(ClinicalControl).filter(ClinicalControl.mapped_variants.contains(mapped_variant)).one() + ) + assert clinical_control.db_identifier == "VCV000000123" + + # Verify job run status is marked as completed + session.refresh(sample_refresh_clinvar_controls_job_in_pipeline) + assert sample_refresh_clinvar_controls_job_in_pipeline.status == JobStatus.SUCCEEDED + + # Verify the pipeline is marked as completed + session.refresh(sample_refresh_clinvar_controls_pipeline) + assert sample_refresh_clinvar_controls_pipeline.status == PipelineStatus.SUCCEEDED + + async def test_refresh_clinvar_controls_idempotent_run( + self, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Integration test: running the job multiple times does not create duplicate annotation statuses.""" + + # Mock the get_associated_clinvar_allele_id function to return a ClinVar Allele ID + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value="VCV000000123", + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + side_effect=[mock_fetch_tsv(), mock_fetch_tsv()], + ), + ): + # First run + result1 = await refresh_clinvar_controls(mock_worker_ctx, sample_refresh_clinvar_controls_job_run.id) + + session.commit() + # reset the job run status to pending for the second run + sample_refresh_clinvar_controls_job_run.status = JobStatus.PENDING + session.commit() + + # Second run + result2 = await refresh_clinvar_controls(mock_worker_ctx, sample_refresh_clinvar_controls_job_run.id) + + assert result1["status"] == "ok" + assert result2["status"] == "ok" + + # Verify only one clinical control annotation exists for the variant + clinical_controls = session.query(ClinicalControl).all() + assert len(clinical_controls) == 1 + + # Verify two annotated variants exist but both reflect the same successful annotation, and only + # one is current + annotated_variants = session.query(VariantAnnotationStatus).all() + assert len(annotated_variants) == 2 + statuses = [av.status for av in annotated_variants] + assert statuses.count(AnnotationStatus.SUCCESS) == 2 + current_statuses = [av for av in annotated_variants if av.current] + assert len(current_statuses) == 1 + + # Verify job run status is marked as completed + session.refresh(sample_refresh_clinvar_controls_job_run) + assert sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED + + async def test_refresh_clinvar_controls_partial_failure( + self, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Integration test: job handles partial failures gracefully.""" + + variant1, mapped_variant1 = setup_sample_variants_with_caid + # Add an additional mapped variant to the database with a CAID + score_set = session.get(ScoreSet, sample_refresh_clinvar_controls_job_run.job_params["score_set_id"]) + variant2 = Variant( + urn="urn:variant:integration-test-variant-with-caid-2", + score_set_id=score_set.id, + hgvs_nt="NM_000000.1:c.11G>C", + hgvs_pro="NP_000000.1:p.Gly11Ala", + data={"hgvs_c": "NM_000000.1:c.11G>C", "hgvs_p": "NP_000000.1:p.Gly11Ala"}, + ) + session.add(variant2) + session.commit() + mapped_variant2 = MappedVariant( + variant_id=variant2.id, + clingen_allele_id="CA130", + current=True, + mapped_date="2024-01-01T00:00:00Z", + mapping_api_version="1.0.0", + ) + session.add(mapped_variant2) + session.commit() + + # Mock the get_associated_clinvar_allele_id function to raise an exception for the first call + def side_effect_get_associated_clinvar_allele_id(clingen_allele_id): + if clingen_allele_id == "CA130": + raise requests.exceptions.RequestException("ClinGen API error") + return "VCV000000123" + + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + side_effect=side_effect_get_associated_clinvar_allele_id, + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls(mock_worker_ctx, sample_refresh_clinvar_controls_job_run.id) + + assert result["status"] == "ok" + + # Verify annotation statuses for both variants + variant_with_api_failure = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant2.variant_id) + .one() + ) + assert variant_with_api_failure.status == AnnotationStatus.FAILED + assert variant_with_api_failure.annotation_type == AnnotationType.CLINVAR_CONTROL + assert "Failed to retrieve ClinVar allele ID from ClinGen API" in variant_with_api_failure.error_message + + annotated_variant2 = ( + session.query(VariantAnnotationStatus) + .filter(VariantAnnotationStatus.variant_id == mapped_variant1.variant_id) + .one() + ) + assert annotated_variant2.status == AnnotationStatus.SUCCESS + assert annotated_variant2.annotation_type == AnnotationType.CLINVAR_CONTROL + assert annotated_variant2.error_message is None + + # Verify a clinical control was added for the successfully annotated variant and not the unsuccessful one + clinical_control1 = ( + session.query(ClinicalControl).filter(ClinicalControl.mapped_variants.contains(mapped_variant1)).one() + ) + assert clinical_control1.db_identifier == "VCV000000123" + + clinical_control2 = ( + session.query(ClinicalControl).filter(ClinicalControl.mapped_variants.contains(mapped_variant2)).all() + ) + assert len(clinical_control2) == 0 + + # Verify job run status is marked as completed + session.refresh(sample_refresh_clinvar_controls_job_run) + assert sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED + + async def test_refresh_clinvar_controls_propagates_exceptions_to_decorator( + self, + mock_worker_ctx, + session, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Test that unexpected exceptions are propagated.""" + + # Mock the get_associated_clinvar_allele_id function to raise an unexpected exception + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + side_effect=ValueError("Unexpected error"), + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + result = await refresh_clinvar_controls( + mock_worker_ctx, + sample_refresh_clinvar_controls_job_run.id, + JobManager(session, mock_worker_ctx["redis"], sample_refresh_clinvar_controls_job_run.id), + ) + + assert result["status"] == "exception" + + # Verify no annotation statuses were created + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + + # Verify no clinical controls were added + clinical_controls = session.query(ClinicalControl).all() + assert len(clinical_controls) == 0 + + # Verify job run status is marked as failed + session.refresh(sample_refresh_clinvar_controls_job_run) + assert sample_refresh_clinvar_controls_job_run.status == JobStatus.FAILED + + +@pytest.mark.asyncio +@pytest.mark.integration +class TestRefreshClinvarControlsArqContext: + """Tests for running the refresh_clinvar_controls job function within an ARQ worker context.""" + + async def test_refresh_clinvar_controls_with_arq_context_independent( + self, + arq_redis, + arq_worker, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Integration test: job completes successfully within an ARQ worker context.""" + + # Patch external service calls + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value="VCV000000123", + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + await arq_redis.enqueue_job("refresh_clinvar_controls", sample_refresh_clinvar_controls_job_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that clinical controls were added + clinical_controls = session.query(ClinicalControl).all() + assert len(clinical_controls) > 0 + + # Verify annotation status was created + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 1 + assert annotation_statuses[0].status == AnnotationStatus.SUCCESS + assert annotation_statuses[0].annotation_type == AnnotationType.CLINVAR_CONTROL + + # Verify that the job completed successfully + session.refresh(sample_refresh_clinvar_controls_job_run) + assert sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED + + async def test_refresh_clinvar_controls_with_arq_context_pipeline( + self, + arq_redis, + arq_worker, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Integration test: job completes successfully within an ARQ worker context in a pipeline context.""" + + # Patch external service calls + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + return_value="VCV000000123", + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + await arq_redis.enqueue_job("refresh_clinvar_controls", sample_refresh_clinvar_controls_job_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify that clinical controls were added + clinical_controls = session.query(ClinicalControl).all() + assert len(clinical_controls) > 0 + + # Verify annotation status was created + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 1 + assert annotation_statuses[0].status == AnnotationStatus.SUCCESS + assert annotation_statuses[0].annotation_type == AnnotationType.CLINVAR_CONTROL + + # Verify that the job completed successfully + session.refresh(sample_refresh_clinvar_controls_job_run) + assert sample_refresh_clinvar_controls_job_run.status == JobStatus.SUCCEEDED + + # Verify the pipeline is marked as completed + pass + + async def test_refresh_clinvar_controls_with_arq_context_exception_handling_independent( + self, + arq_redis, + arq_worker, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Integration test: job handles exceptions properly within an ARQ worker context.""" + # Patch external service calls to raise an exception + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + side_effect=ValueError("Unexpected error"), + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + await arq_redis.enqueue_job("refresh_clinvar_controls", sample_refresh_clinvar_controls_job_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify no annotation statuses were created + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + + # Verify no clinical controls were added + clinical_controls = session.query(ClinicalControl).all() + assert len(clinical_controls) == 0 + + # Verify job run status is marked as failed + session.refresh(sample_refresh_clinvar_controls_job_run) + assert sample_refresh_clinvar_controls_job_run.status == JobStatus.FAILED + + async def test_refresh_clinvar_controls_with_arq_context_exception_handling_pipeline( + self, + arq_redis, + arq_worker, + session, + with_populated_domain_data, + with_refresh_clinvar_controls_job, + sample_refresh_clinvar_controls_job_run, + setup_sample_variants_with_caid, + ): + """Integration test: job handles exceptions properly within an ARQ worker context in a pipeline context.""" + # Patch external service calls to raise an exception + with ( + patch( + "mavedb.worker.jobs.external_services.clinvar.get_associated_clinvar_allele_id", + side_effect=ValueError("Unexpected error"), + ), + patch.object( + _UnixSelectorEventLoop, + "run_in_executor", + return_value=mock_fetch_tsv(), + ), + ): + await arq_redis.enqueue_job("refresh_clinvar_controls", sample_refresh_clinvar_controls_job_run.id) + await arq_worker.async_run() + await arq_worker.run_check() + + # Verify no annotation statuses were created + annotation_statuses = session.query(VariantAnnotationStatus).all() + assert len(annotation_statuses) == 0 + + # Verify no clinical controls were added + clinical_controls = session.query(ClinicalControl).all() + assert len(clinical_controls) == 0 + + # Verify job run status is marked as failed + session.refresh(sample_refresh_clinvar_controls_job_run) + assert sample_refresh_clinvar_controls_job_run.status == JobStatus.FAILED + + # Verify the pipeline is marked as failed + pass From f2b57a43a42e57cecf30d5ab933389006483553d Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Tue, 3 Feb 2026 16:00:03 -0800 Subject: [PATCH 65/70] feat: update annotation type handling to use enum directly and switch enum to str inheritance --- src/mavedb/lib/annotation_status_manager.py | 14 +++++++------- src/mavedb/models/enums/annotation_type.py | 6 +++--- tests/lib/test_annotation_status_manager.py | 4 ++-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/mavedb/lib/annotation_status_manager.py b/src/mavedb/lib/annotation_status_manager.py index 628846da..29b17bc0 100644 --- a/src/mavedb/lib/annotation_status_manager.py +++ b/src/mavedb/lib/annotation_status_manager.py @@ -78,7 +78,7 @@ def add_annotation( is responsible for persisting any changes (e.g., by calling session.commit()). """ logger.debug( - f"Adding annotation for variant_id={variant_id}, annotation_type={annotation_type.value}, version={version}" + f"Adding annotation for variant_id={variant_id}, annotation_type={annotation_type}, version={version}" ) # Find existing current annotations to be replaced @@ -86,7 +86,7 @@ def add_annotation( self.session.execute( select(VariantAnnotationStatus).where( VariantAnnotationStatus.variant_id == variant_id, - VariantAnnotationStatus.annotation_type == annotation_type.value, + VariantAnnotationStatus.annotation_type == annotation_type, VariantAnnotationStatus.version == version, VariantAnnotationStatus.current.is_(True), ) @@ -96,7 +96,7 @@ def add_annotation( ) for var_ann in existing_current: logger.debug( - f"Replacing current annotation {var_ann.id} for variant_id={variant_id}, annotation_type={annotation_type.value}, version={version}" + f"Replacing current annotation {var_ann.id} for variant_id={variant_id}, annotation_type={annotation_type}, version={version}" ) var_ann.current = False @@ -104,8 +104,8 @@ def add_annotation( new_status = VariantAnnotationStatus( variant_id=variant_id, - annotation_type=annotation_type.value, - status=status.value, + annotation_type=annotation_type, + status=status, version=version, current=current, **annotation_data, @@ -115,7 +115,7 @@ def add_annotation( self.session.flush() logger.info( - f"Successfully added annotation for variant_id={variant_id}, annotation_type={annotation_type.value}, version={version}" + f"Successfully added annotation for variant_id={variant_id}, annotation_type={annotation_type}, version={version}" ) return new_status @@ -135,7 +135,7 @@ def get_current_annotation( """ stmt = select(VariantAnnotationStatus).where( VariantAnnotationStatus.variant_id == variant_id, - VariantAnnotationStatus.annotation_type == annotation_type.value, + VariantAnnotationStatus.annotation_type == annotation_type, VariantAnnotationStatus.current.is_(True), ) diff --git a/src/mavedb/models/enums/annotation_type.py b/src/mavedb/models/enums/annotation_type.py index 773f056e..b1595347 100644 --- a/src/mavedb/models/enums/annotation_type.py +++ b/src/mavedb/models/enums/annotation_type.py @@ -1,12 +1,12 @@ -import enum +from enum import Enum -class AnnotationType(enum.Enum): +class AnnotationType(str, Enum): VRS_MAPPING = "vrs_mapping" CLINGEN_ALLELE_ID = "clingen_allele_id" MAPPED_HGVS = "mapped_hgvs" VARIANT_TRANSLATION = "variant_translation" GNOMAD_ALLELE_FREQUENCY = "gnomad_allele_frequency" - CLINVAR_CONTROLS = "clinvar_control" + CLINVAR_CONTROL = "clinvar_control" VEP_FUNCTIONAL_CONSEQUENCE = "vep_functional_consequence" LDH_SUBMISSION = "ldh_submission" diff --git a/tests/lib/test_annotation_status_manager.py b/tests/lib/test_annotation_status_manager.py index 98980f00..df78ce69 100644 --- a/tests/lib/test_annotation_status_manager.py +++ b/tests/lib/test_annotation_status_manager.py @@ -84,8 +84,8 @@ def test_add_annotation_creates_entry_with_annotation_type_version_status( ) session.commit() - assert annotation.annotation_type == annotation_type.value - assert annotation.status == status.value + assert annotation.annotation_type == annotation_type + assert annotation.status == status assert annotation.version == "v1.0" def test_add_annotation_persists_annotation_data( From ba70e179cbc98f050027d0b691ed8c991a200ba3 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 4 Feb 2026 10:56:50 -0800 Subject: [PATCH 66/70] feat: add functions to retrieve associated ClinVar Allele IDs and enhance test coverage --- src/mavedb/lib/clingen/allele_registry.py | 16 ++ .../clingen/network/test_allele_registry.py | 72 +++++++ tests/lib/clingen/test_allele_registry.py | 189 ++++++++++++++++++ 3 files changed, 277 insertions(+) create mode 100644 tests/lib/clingen/network/test_allele_registry.py create mode 100644 tests/lib/clingen/test_allele_registry.py diff --git a/src/mavedb/lib/clingen/allele_registry.py b/src/mavedb/lib/clingen/allele_registry.py index 5e025b14..a7951255 100644 --- a/src/mavedb/lib/clingen/allele_registry.py +++ b/src/mavedb/lib/clingen/allele_registry.py @@ -1,4 +1,5 @@ import logging + import requests logger = logging.getLogger(__name__) @@ -43,3 +44,18 @@ def get_matching_registered_ca_ids(clingen_pa_id: str) -> list[str]: ca_ids.extend([allele["@id"].split("/")[-1] for allele in allele["matchingRegisteredTranscripts"]]) return ca_ids + + +def get_associated_clinvar_allele_id(clingen_allele_id: str) -> str | None: + """Retrieve the associated ClinVar Allele ID for a given ClinGen Allele ID from the ClinGen API.""" + response = requests.get(f"{CLINGEN_API_URL}/{clingen_allele_id}") + if response.status_code != 200: + logger.error(f"Failed to query ClinGen API for {clingen_allele_id}: {response.status_code}") + return None + + data = response.json() + clinvar_allele_id = data.get("externalRecords", {}).get("ClinVarAlleles", [{}])[0].get("alleleId") + if clinvar_allele_id: + return str(clinvar_allele_id) + + return None diff --git a/tests/lib/clingen/network/test_allele_registry.py b/tests/lib/clingen/network/test_allele_registry.py new file mode 100644 index 00000000..f2ab2bff --- /dev/null +++ b/tests/lib/clingen/network/test_allele_registry.py @@ -0,0 +1,72 @@ +import pytest + +from mavedb.lib.clingen.allele_registry import ( + get_associated_clinvar_allele_id, + get_canonical_pa_ids, + get_matching_registered_ca_ids, +) + + +@pytest.mark.network +class TestGetCanonicalPaIdsNetwork: + def test_get_canonical_pa_ids_known_caid(self): + # Using a known ClinGen Allele ID with MANE transcripts + clingen_allele_id = "CA321211" # Example ClinGen Allele ID + result = get_canonical_pa_ids(clingen_allele_id) + assert isinstance(result, list) + assert result == ["PA2573050890", "PA321212"] # Expected MANE PA ID + + def test_get_canonical_pa_ids_known_no_mane(self): + # Using a ClinGen Allele ID for protein change, as this will not have mane transcripts + clingen_allele_id = "PA102264" # Example ClinGen Allele ID with no MANE + result = get_canonical_pa_ids(clingen_allele_id) + assert result == [] + + def test_get_canonical_pa_ids_invalid_id(self): + # Using an invalid ClinGen Allele ID + clingen_allele_id = "INVALID_ID" + result = get_canonical_pa_ids(clingen_allele_id) + assert result == [] + + +@pytest.mark.network +class TestGetMatchingRegisteredCaIdsNetwork: + def test_get_matching_registered_ca_ids_known_paid(self): + # Using a known ClinGen PA ID with registered CA IDs + clingen_pa_id = "PA2573050890" # Example ClinGen PA ID + result = get_matching_registered_ca_ids(clingen_pa_id) + assert isinstance(result, list) + assert "CA321211" in result # Expected registered CA ID + + def test_get_matching_registered_ca_ids_known_no_caids(self): + # Using a ClinGen PA ID with no registered CA IDs + clingen_pa_id = "PA3051398879" # Example ClinGen PA ID with no registered CA IDs + result = get_matching_registered_ca_ids(clingen_pa_id) + assert result == [] + + def test_get_matching_registered_ca_ids_invalid_id(self): + # Using an invalid ClinGen PA ID + clingen_pa_id = "INVALID_ID" + result = get_matching_registered_ca_ids(clingen_pa_id) + assert result == [] + + +@pytest.mark.network +class TestGetAssociatedClinvarAlleleIdNetwork: + def test_get_associated_clinvar_allele_id_known_caid(self): + # Using a known ClinGen Allele ID with associated ClinVar Allele ID + clingen_allele_id = "CA321211" # Example ClinGen Allele ID + result = get_associated_clinvar_allele_id(clingen_allele_id) + assert result == "211565" # Expected ClinVar Allele ID + + def test_get_associated_clinvar_allele_id_no_association(self): + # Using a ClinGen Allele ID with no associated ClinVar Allele ID + clingen_allele_id = "CA9532274" # Example ClinGen Allele ID with no association + result = get_associated_clinvar_allele_id(clingen_allele_id) + assert result is None + + def test_get_associated_clinvar_allele_id_invalid_id(self): + # Using an invalid ClinGen Allele ID + clingen_allele_id = "INVALID_ID" + result = get_associated_clinvar_allele_id(clingen_allele_id) + assert result is None diff --git a/tests/lib/clingen/test_allele_registry.py b/tests/lib/clingen/test_allele_registry.py new file mode 100644 index 00000000..d54b6d4a --- /dev/null +++ b/tests/lib/clingen/test_allele_registry.py @@ -0,0 +1,189 @@ +from unittest import mock + +import pytest + +from mavedb.lib.clingen.allele_registry import ( + get_associated_clinvar_allele_id, + get_canonical_pa_ids, + get_matching_registered_ca_ids, +) + + +@pytest.mark.unit +@mock.patch("mavedb.lib.clingen.allele_registry.requests.get") +class TestGetCanonicalPaIds: + def test_get_canonical_pa_ids_success(self, mock_request): + # Mock response object + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "transcriptAlleles": [ + {"MANE": True, "@id": "https://reg.genome.network/allele/PA12345"}, + {"MANE": False, "@id": "https://reg.genome.network/allele/PA54321"}, + {"MANE": True, "@id": "https://reg.genome.network/allele/PA67890"}, + {"@id": "https://reg.genome.network/allele/PA00000"}, # No MANE + ] + } + mock_request.return_value = mock_response + + result = get_canonical_pa_ids("CA00001") + assert result == ["PA12345", "PA67890"] + + def test_get_canonical_pa_ids_no_transcript_alleles(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {} + mock_request.return_value = mock_response + + result = get_canonical_pa_ids("CA00002") + assert result == [] + + def test_get_canonical_pa_ids_empty_transcript_alleles(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"transcriptAlleles": []} + mock_request.return_value = mock_response + + result = get_canonical_pa_ids("CA00003") + assert result == [] + + def test_get_canonical_pa_ids_missing_mane_or_id(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "transcriptAlleles": [ + {"MANE": True}, # Missing @id + {"@id": "https://reg.genome.network/allele/PA99999"}, # Missing MANE + {}, # Missing both + ] + } + mock_request.return_value = mock_response + + result = get_canonical_pa_ids("CA00004") + assert result == [] + + def test_get_canonical_pa_ids_api_error(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 404 + mock_request.return_value = mock_response + + result = get_canonical_pa_ids("CA404") + assert result == [] + + +@pytest.mark.unit +@mock.patch("mavedb.lib.clingen.allele_registry.requests.get") +class TestGetMatchingRegisteredCaIds: + def test_get_matching_registered_ca_ids_success(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "aminoAcidAlleles": [ + { + "matchingRegisteredTranscripts": [ + {"@id": "https://reg.genome.network/allele/CA11111"}, + {"@id": "https://reg.genome.network/allele/CA22222"}, + ] + }, + { + "matchingRegisteredTranscripts": [ + {"@id": "https://reg.genome.network/allele/CA33333"}, + ] + }, + { + # No matchingRegisteredTranscripts + }, + ] + } + mock_request.return_value = mock_response + + result = get_matching_registered_ca_ids("PA12345") + assert result == ["CA11111", "CA22222", "CA33333"] + + def test_get_matching_registered_ca_ids_no_amino_acid_alleles(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {} + mock_request.return_value = mock_response + + result = get_matching_registered_ca_ids("PA00000") + assert result == [] + + def test_get_matching_registered_ca_ids_empty_amino_acid_alleles(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"aminoAcidAlleles": []} + mock_request.return_value = mock_response + + result = get_matching_registered_ca_ids("PA00001") + assert result == [] + + def test_get_matching_registered_ca_ids_missing_matching_registered_transcripts(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "aminoAcidAlleles": [ + {}, # No matchingRegisteredTranscripts + {"matchingRegisteredTranscripts": []}, # Empty list + ] + } + mock_request.return_value = mock_response + + result = get_matching_registered_ca_ids("PA00002") + assert result == [] + + def test_get_matching_registered_ca_ids_api_error(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 500 + mock_request.return_value = mock_response + + result = get_matching_registered_ca_ids("PAERROR") + assert result == [] + + +@pytest.mark.unit +@mock.patch("mavedb.lib.clingen.allele_registry.requests.get") +class TestGetAssociatedClinvarAlleleId: + def test_get_associated_clinvar_allele_id_success(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"externalRecords": {"ClinVarAlleles": [{"alleleId": "123456"}]}} + mock_request.return_value = mock_response + + result = get_associated_clinvar_allele_id("CA00001") + assert result == "123456" + + def test_get_associated_clinvar_allele_id_no_external_records(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {} + mock_request.return_value = mock_response + + result = get_associated_clinvar_allele_id("CA00002") + assert result is None + + def test_get_associated_clinvar_allele_id_no_clinvar_alleles(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"externalRecords": {}} + mock_request.return_value = mock_response + + result = get_associated_clinvar_allele_id("CA00003") + assert result is None + + def test_get_associated_clinvar_allele_id_missing_allele_id(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"externalRecords": {"ClinVarAlleles": [{}]}} + mock_request.return_value = mock_response + + result = get_associated_clinvar_allele_id("CA00004") + assert result is None + + def test_get_associated_clinvar_allele_id_api_error(self, mock_request): + mock_response = mock.Mock() + mock_response.status_code = 404 + mock_request.return_value = mock_response + + result = get_associated_clinvar_allele_id("CA404") + assert result is None From c7bf7f702bbfc867b67bf2f3574d2d96aae548df Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 4 Feb 2026 11:32:24 -0800 Subject: [PATCH 67/70] refactor: remove redundant fixture for setting up sample variants in gnomad tests --- .../jobs/external_services/test_gnomad.py | 29 ------------------- 1 file changed, 29 deletions(-) diff --git a/tests/worker/jobs/external_services/test_gnomad.py b/tests/worker/jobs/external_services/test_gnomad.py index a3e379e9..92f515c1 100644 --- a/tests/worker/jobs/external_services/test_gnomad.py +++ b/tests/worker/jobs/external_services/test_gnomad.py @@ -9,8 +9,6 @@ from mavedb.models.enums.job_pipeline import JobStatus, PipelineStatus from mavedb.models.gnomad_variant import GnomADVariant from mavedb.models.mapped_variant import MappedVariant -from mavedb.models.score_set import ScoreSet -from mavedb.models.variant import Variant from mavedb.models.variant_annotation_status import VariantAnnotationStatus from mavedb.worker.jobs.external_services.gnomad import link_gnomad_variants from mavedb.worker.lib.managers.job_manager import JobManager @@ -23,33 +21,6 @@ class TestLinkGnomadVariantsUnit: """Unit tests for the link_gnomad_variants job.""" - @pytest.fixture - def setup_sample_variants_with_caid( - self, session, with_populated_domain_data, mock_worker_ctx, sample_link_gnomad_variants_run - ): - """Setup variants and mapped variants in the database for testing.""" - score_set = session.get(ScoreSet, sample_link_gnomad_variants_run.job_params["score_set_id"]) - - # Add a variant and mapped variant to the database with a CAID - variant = Variant( - urn="urn:variant:test-variant-with-caid", - score_set_id=score_set.id, - hgvs_nt="NM_000000.1:c.1A>G", - hgvs_pro="NP_000000.1:p.Met1Val", - data={"hgvs_c": "NM_000000.1:c.1A>G", "hgvs_p": "NP_000000.1:p.Met1Val"}, - ) - session.add(variant) - session.commit() - mapped_variant = MappedVariant( - variant_id=variant.id, - clingen_allele_id="CA123", - current=True, - mapped_date="2024-01-01T00:00:00Z", - mapping_api_version="1.0.0", - ) - session.add(mapped_variant) - session.commit() - async def test_link_gnomad_variants_no_variants_with_caids( self, session, From 547be35864cd762ada63a8317a1a7d003d6bca19 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 4 Feb 2026 12:30:58 -0800 Subject: [PATCH 68/70] chore: add TODO for caching ClinVar control data to improve performance --- src/mavedb/worker/jobs/external_services/clinvar.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/mavedb/worker/jobs/external_services/clinvar.py b/src/mavedb/worker/jobs/external_services/clinvar.py index 1f1b3140..e66de3e5 100644 --- a/src/mavedb/worker/jobs/external_services/clinvar.py +++ b/src/mavedb/worker/jobs/external_services/clinvar.py @@ -33,6 +33,11 @@ logger = logging.getLogger(__name__) +# TODO#649: This function is currently called multiple times to fill in controls for each month/year. +# We should consider caching both fetched TSV data and/or ClinGen API results. This would +# significantly speed up large jobs annotating many variants. + + @with_pipeline_management async def refresh_clinvar_controls(ctx: dict, job_id: int, job_manager: JobManager) -> JobResultData: """ From 4a878b0c50a36cb09b09abe8377fd9b618494182 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 4 Feb 2026 12:41:40 -0800 Subject: [PATCH 69/70] feat: add multiple refresh job definitions for ClinVar controls with year and month parameters --- src/mavedb/lib/workflow/definitions.py | 145 +++++++++++++++++++++++++ 1 file changed, 145 insertions(+) diff --git a/src/mavedb/lib/workflow/definitions.py b/src/mavedb/lib/workflow/definitions.py index 54a7b645..72c83e42 100644 --- a/src/mavedb/lib/workflow/definitions.py +++ b/src/mavedb/lib/workflow/definitions.py @@ -49,6 +49,151 @@ def annotation_pipeline_job_definitions() -> list[JobDefinition]: }, "dependencies": [("submit_uniprot_mapping_jobs_for_score_set", DependencyType.SUCCESS_REQUIRED)], }, + # TODO#650: Simplify or automate the generation of these repetitive job definitions + { + "key": "refresh_clinvar_controls_201502", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2015, + "month": 2, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "refresh_clinvar_controls_201601", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2016, + "month": 1, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "refresh_clinvar_controls_201701", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2017, + "month": 1, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "refresh_clinvar_controls_201801", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2018, + "month": 1, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "refresh_clinvar_controls_201901", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2019, + "month": 1, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "refresh_clinvar_controls_202001", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2020, + "month": 1, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "refresh_clinvar_controls_202101", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2021, + "month": 1, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "refresh_clinvar_controls_202201", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2022, + "month": 1, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "refresh_clinvar_controls_202301", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2023, + "month": 1, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "refresh_clinvar_controls_202401", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2024, + "month": 1, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "refresh_clinvar_controls_202501", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2025, + "month": 1, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, + { + "key": "refresh_clinvar_controls_202601", + "function": "refresh_clinvar_controls", + "type": JobType.MAPPED_VARIANT_ANNOTATION, + "params": { + "correlation_id": None, # Required param to be filled in at runtime + "score_set_id": None, # Required param to be filled in at runtime + "year": 2026, + "month": 1, + }, + "dependencies": [("submit_score_set_mappings_to_car", DependencyType.SUCCESS_REQUIRED)], + }, ] From f3ea5ce04ae68de3b5a396074268f425686df100 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Wed, 4 Feb 2026 14:57:05 -0800 Subject: [PATCH 70/70] feat: enhance test workflow to run fast tests on pull requests and full tests on main branch --- .github/workflows/run-tests-on-push.yml | 31 +++++++++++++++++++++---- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/.github/workflows/run-tests-on-push.yml b/.github/workflows/run-tests-on-push.yml index 6cb7d18e..f07da233 100644 --- a/.github/workflows/run-tests-on-push.yml +++ b/.github/workflows/run-tests-on-push.yml @@ -1,6 +1,7 @@ -name: Run Tests (On Push) +name: Run Tests on: push: + # Run all tests on main, fast tests on other branches env: LOG_CONFIG: test @@ -50,7 +51,12 @@ jobs: - run: pip install --upgrade pip - run: pip install poetry - run: poetry install --with dev - - run: poetry run pytest tests/ + - name: Run fast tests on non-main branches + if: github.event_name == 'push' && github.ref != 'refs/heads/main' + run: poetry run pytest tests/ -m "not network and not slow" + - name: Run full tests on main + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + run: poetry run pytest tests/ run-tests-3_11: runs-on: ubuntu-latest @@ -66,7 +72,12 @@ jobs: - run: pip install --upgrade pip - run: pip install poetry - run: poetry install --with dev --extras server - - run: poetry run pytest tests/ --show-capture=stdout --cov=src + - name: Run fast tests on non-main branches + if: github.ref != 'refs/heads/main' + run: poetry run pytest tests/ -m "not network and not slow" --show-capture=stdout + - name: Run all tests with coverage on main branch + if: github.ref == 'refs/heads/main' + run: poetry run pytest tests/ --show-capture=stdout --cov=src run-tests-3_12-core-dependencies: runs-on: ubuntu-latest @@ -80,7 +91,12 @@ jobs: - run: pip install --upgrade pip - run: pip install poetry - run: poetry install --with dev - - run: poetry run pytest tests/ + - name: Run fast tests on non-main branches + if: github.ref != 'refs/heads/main' + run: poetry run pytest tests/ -m "not network and not slow" + - name: Run all tests on main branch + if: github.ref == 'refs/heads/main' + run: poetry run pytest tests/ run-tests-3_12: runs-on: ubuntu-latest @@ -96,4 +112,9 @@ jobs: - run: pip install --upgrade pip - run: pip install poetry - run: poetry install --with dev --extras server - - run: poetry run pytest tests/ --show-capture=stdout --cov=src + - name: Run fast tests on non-main branches + if: github.ref != 'refs/heads/main' + run: poetry run pytest tests/ -m "not network and not slow" --show-capture=stdout + - name: Run all tests with coverage on main branch + if: github.ref == 'refs/heads/main' + run: poetry run pytest tests/ --show-capture=stdout --cov=src