diff --git a/pyproject.toml b/pyproject.toml index cefeb0eb..02f842b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -193,9 +193,6 @@ exclude = [ "ANN201", "ANN202", ] -"src/ga4gh/vrs/extras/annotator/vcf.py" = [ - "PTH123", # see https://github.com/ga4gh/vrs-python/issues/482 -] "src/ga4gh/vrs/extras/object_store.py" = [ "ANN", "D", diff --git a/src/ga4gh/vrs/extras/annotator/cli.py b/src/ga4gh/vrs/extras/annotator/cli.py index 6b3de705..80c80c8b 100644 --- a/src/ga4gh/vrs/extras/annotator/cli.py +++ b/src/ga4gh/vrs/extras/annotator/cli.py @@ -172,19 +172,15 @@ def _annotate_vcf_cli( annotator = VCFAnnotator( seqrepo_dp_type, seqrepo_base_url, str(seqrepo_root_dir.absolute()) ) - vcf_out_str = str(vcf_out.absolute()) if vcf_out is not None else vcf_out - vrs_pkl_out_str = ( - str(vrs_pickle_out.absolute()) if vrs_pickle_out is not None else vrs_pickle_out - ) start = timer() msg = f"Annotating {vcf_in} with the VCF Annotator..." _logger.info(msg) if not silent: click.echo(msg) annotator.annotate( - str(vcf_in.absolute()), - vcf_out=vcf_out_str, - vrs_pickle_out=vrs_pkl_out_str, + vcf_in.absolute(), + output_vcf_path=vcf_out, + output_pkl_path=vrs_pickle_out, vrs_attributes=vrs_attributes, assembly=assembly, compute_for_ref=(not skip_ref), diff --git a/src/ga4gh/vrs/extras/annotator/vcf.py b/src/ga4gh/vrs/extras/annotator/vcf.py index fe2a6069..721d530d 100644 --- a/src/ga4gh/vrs/extras/annotator/vcf.py +++ b/src/ga4gh/vrs/extras/annotator/vcf.py @@ -3,6 +3,7 @@ import logging import pickle from enum import Enum +from pathlib import Path import pysam from biocommons.seqrepo import SeqRepo @@ -78,9 +79,9 @@ def __init__( @use_ga4gh_compute_identifier_when(VrsObjectIdentifierIs.MISSING) def annotate( self, - vcf_in: str, - vcf_out: str | None = None, - vrs_pickle_out: str | None = None, + input_vcf_path: Path, + output_vcf_path: Path | None = None, + output_pkl_path: Path | None = None, vrs_attributes: bool = False, assembly: str = "GRCh38", compute_for_ref: bool = True, @@ -89,9 +90,9 @@ def annotate( """Given a VCF, produce an output VCF annotated with VRS allele IDs, and/or a pickle file containing the full VRS objects. - :param vcf_in: Location of input VCF - :param vcf_out: The path for the output VCF file - :param vrs_pickle_out: The path for the output VCF pickle file + :param input_vcf_path: Location of input VCF + :param output_vcf_path: The path for the output VCF file + :param output_pkl_path: The path for the output VCF pickle file :param vrs_attributes: If `True`, include VRS_Start, VRS_End, VRS_State properties in the VCF INFO field. If `False` will not include these properties. Only used if `vcf_out` is defined. @@ -102,17 +103,17 @@ def annotate( object for a record. If `False` then VRS object will be returned even if validation checks fail, although all instances of failed validation are logged as warnings regardless. + :raise VCFAnnotatorError: if no output formats are selected """ - if not any((vcf_out, vrs_pickle_out)): + if not any((output_vcf_path, output_pkl_path)): msg = "Must provide one of: `vcf_out` or `vrs_pickle_out`" raise VCFAnnotatorError(msg) info_field_num = "R" if compute_for_ref else "A" info_field_desc = "REF and ALT" if compute_for_ref else "ALT" - vrs_data = {} - vcf_in = pysam.VariantFile(filename=vcf_in) - vcf_in.header.info.add( + vcf = pysam.VariantFile(filename=str(input_vcf_path.absolute())) + vcf.header.info.add( self.VRS_ALLELE_IDS_FIELD, info_field_num, "String", @@ -121,7 +122,7 @@ def annotate( f"GT indexes of the {info_field_desc} alleles" ), ) - vcf_in.header.info.add( + vcf.header.info.add( self.VRS_ERROR_FIELD, ".", "String", @@ -129,7 +130,7 @@ def annotate( ) if vrs_attributes: - vcf_in.header.info.add( + vcf.header.info.add( self.VRS_STARTS_FIELD, info_field_num, "String", @@ -138,7 +139,7 @@ def annotate( f"VRS Alleles corresponding to the GT indexes of the {info_field_desc} alleles" ), ) - vcf_in.header.info.add( + vcf.header.info.add( self.VRS_ENDS_FIELD, info_field_num, "String", @@ -147,7 +148,7 @@ def annotate( f"Alleles corresponding to the GT indexes of the {info_field_desc} alleles" ), ) - vcf_in.header.info.add( + vcf.header.info.add( self.VRS_STATES_FIELD, info_field_num, "String", @@ -157,20 +158,26 @@ def annotate( ), ) - if vcf_out: - vcf_out = pysam.VariantFile(vcf_out, "w", header=vcf_in.header) - - output_vcf = bool(vcf_out) - output_pickle = bool(vrs_pickle_out) + vcf_out = ( + pysam.VariantFile(str(output_vcf_path.absolute()), "w", header=vcf.header) + if output_vcf_path + else None + ) - for record in vcf_in: - additional_info_fields = [self.VRS_ALLELE_IDS_FIELD] - if vrs_attributes: - additional_info_fields += [ - self.VRS_STARTS_FIELD, - self.VRS_ENDS_FIELD, - self.VRS_STATES_FIELD, - ] + # only retain raw data if dumping to pkl + vrs_data = {} if output_pkl_path else None + for record in vcf: + if vcf_out: + additional_info_fields = [self.VRS_ALLELE_IDS_FIELD] + if vrs_attributes: + additional_info_fields += [ + self.VRS_STARTS_FIELD, + self.VRS_ENDS_FIELD, + self.VRS_STATES_FIELD, + ] + else: + # no INFO field names need to be designated if not producing an annotated VCF + additional_info_fields = [] try: vrs_field_data = self._get_vrs_data( record, @@ -178,8 +185,6 @@ def annotate( assembly, additional_info_fields, vrs_attributes=vrs_attributes, - output_pickle=output_pickle, - output_vcf=output_vcf, compute_for_ref=compute_for_ref, require_validation=require_validation, ) @@ -198,29 +203,27 @@ def annotate( vrs_field_data, ) - if output_vcf: + if output_vcf_path and vcf_out: for k in additional_info_fields: record.info[k] = [value or "." for value in vrs_field_data[k]] vcf_out.write(record) - vcf_in.close() + vcf.close() - if output_vcf: + if vcf_out: vcf_out.close() - if vrs_pickle_out: - with open(vrs_pickle_out, "wb") as wf: + if output_pkl_path: + with output_pkl_path.open("wb") as wf: pickle.dump(vrs_data, wf) def _get_vrs_object( self, vcf_coords: str, - vrs_data: dict, + vrs_data: dict | None, vrs_field_data: dict, assembly: str, vrs_data_key: str | None = None, - output_pickle: bool = True, - output_vcf: bool = False, vrs_attributes: bool = False, require_validation: bool = True, ) -> None: @@ -228,14 +231,13 @@ def _get_vrs_object( be mutated. :param vcf_coords: Allele to get VRS object for. Format is chr-pos-ref-alt - :param vrs_data: Dictionary containing the VRS object information for the VCF - :param vrs_field_data: If `output_vcf`, keys are VRS Fields and values are list - of VRS data. Empty otherwise. + :param vrs_data: All constructed VRS objects. Can be `None` if no data dumps + will be created. + :param vrs_field_data: If `vrs_data`, keys are VRS Fields and values are list + of VRS data. Empty dict otherwise. :param assembly: The assembly used in `vcf_coords` :param vrs_data_key: The key to update in `vrs_data`. If not provided, will use `vcf_coords` as the key. - :param output_pickle: If `True`, VRS pickle file will be output. - :param output_vcf: If `True`, annotated VCF file will be output. :param vrs_attributes: If `True` will include VRS_Start, VRS_End, VRS_State fields in the INFO field. If `False` will not include these fields. Only used if `output_vcf` set to `True`. @@ -277,11 +279,11 @@ def _get_vrs_object( "None was returned when translating %s from gnomad", vcf_coords ) - if output_pickle and vrs_obj: + if vrs_data and vrs_obj: key = vrs_data_key if vrs_data_key else vcf_coords vrs_data[key] = str(vrs_obj.model_dump(exclude_none=True)) - if output_vcf: + if vrs_field_data: allele_id = vrs_obj.id if vrs_obj else "" vrs_field_data[self.VRS_ALLELE_IDS_FIELD].append(allele_id) @@ -295,9 +297,7 @@ def _get_vrs_object( else "" ) else: - start = "" - end = "" - alt = "" + start = end = alt = "" vrs_field_data[self.VRS_STARTS_FIELD].append(start) vrs_field_data[self.VRS_ENDS_FIELD].append(end) @@ -306,12 +306,10 @@ def _get_vrs_object( def _get_vrs_data( self, record: pysam.VariantRecord, - vrs_data: dict, + vrs_data: dict | None, assembly: str, additional_info_fields: list[str], vrs_attributes: bool = False, - output_pickle: bool = True, - output_vcf: bool = True, compute_for_ref: bool = True, require_validation: bool = True, ) -> dict: @@ -325,20 +323,15 @@ def _get_vrs_data( :param vrs_attributes: If `True` will include VRS_Start, VRS_End, VRS_State fields in the INFO field. If `False` will not include these fields. Only used if `output_vcf` set to `True`. - :param output_pickle: If `True`, VRS pickle file will be output. - :param output_vcf: If `True`, annotated VCF file will be output. :param compute_for_ref: If true, compute VRS IDs for the reference allele :param require_validation: If `True` then validation checks must pass in order to return a VRS object. A `DataProxyValidationError` will be raised if validation checks fail. If `False` then VRS object will be returned even if validation checks fail. Defaults to `True`. - :return: If `output_vcf = True`, a dictionary containing VRS Fields and list - of associated values. If `output_vcf = False`, an empty dictionary will be - returned. + :return: A dictionary mapping VRS-related INFO fields to lists of associated + values. Will be empty if `create_annotated_vcf` is false. """ - vrs_field_data = ( - {field: [] for field in additional_info_fields} if output_vcf else {} - ) + vrs_field_data = {field: [] for field in additional_info_fields} # Get VRS data for reference allele gnomad_loc = f"{record.chrom}-{record.pos}" @@ -349,8 +342,6 @@ def _get_vrs_data( vrs_data, vrs_field_data, assembly, - output_pickle=output_pickle, - output_vcf=output_vcf, vrs_attributes=vrs_attributes, require_validation=require_validation, ) @@ -362,9 +353,8 @@ def _get_vrs_data( for allele in alleles: if "*" in allele: _logger.debug("Star allele found: %s", allele) - if output_vcf: - for field in additional_info_fields: - vrs_field_data[field].append("") + for field in additional_info_fields: + vrs_field_data[field].append("") else: self._get_vrs_object( allele, @@ -372,8 +362,6 @@ def _get_vrs_data( vrs_field_data, assembly, vrs_data_key=data, - output_pickle=output_pickle, - output_vcf=output_vcf, vrs_attributes=vrs_attributes, require_validation=require_validation, ) diff --git a/tests/extras/test_annotate_vcf.py b/tests/extras/test_annotate_vcf.py index e006057d..efceae74 100644 --- a/tests/extras/test_annotate_vcf.py +++ b/tests/extras/test_annotate_vcf.py @@ -10,7 +10,7 @@ from ga4gh.vrs.dataproxy import DataProxyValidationError from ga4gh.vrs.extras.annotator.vcf import VCFAnnotator, VCFAnnotatorError -TEST_DATA_DIR = "tests/extras/data" +TEST_DATA_DIR = Path("tests/extras/data") @pytest.fixture @@ -18,14 +18,21 @@ def vcf_annotator(): return VCFAnnotator("rest") +@pytest.fixture(scope="session") +def input_vcf(): + """Provide fixture for sample input VCF""" + return TEST_DATA_DIR / "test_vcf_input.vcf" + + @pytest.mark.vcr -def test_annotate_vcf_grch38_noattrs(vcf_annotator, vcr_cassette): +def test_annotate_vcf_grch38_noattrs( + vcf_annotator: VCFAnnotator, input_vcf: Path, tmp_path: Path, vcr_cassette +): vcr_cassette.allow_playback_repeats = False - input_vcf = f"{TEST_DATA_DIR}/test_vcf_input.vcf" - output_vcf = f"{TEST_DATA_DIR}/test_vcf_output_grch38_noattrs.vcf.gz" - output_vrs_pkl = f"{TEST_DATA_DIR}/test_vcf_pkl_grch38_noattrs.pkl" + output_vcf = tmp_path / "test_vcf_output_grch38_noattrs.vcf.gz" + output_vrs_pkl = tmp_path / "test_vcf_pkl_grch38_noattrs.pkl" expected_vcf_no_vrs_attrs = ( - f"{TEST_DATA_DIR}/test_vcf_expected_output_no_vrs_attrs.vcf.gz" + TEST_DATA_DIR / "test_vcf_expected_output_no_vrs_attrs.vcf.gz" ) # Test GRCh38 assembly, which was used for input_vcf and no vrs attributes @@ -38,19 +45,18 @@ def test_annotate_vcf_grch38_noattrs(vcf_annotator, vcr_cassette): out_vcf_lines, expected_output_lines, strict=False ): assert actual_line == expected_line - assert Path(output_vrs_pkl).exists() + assert output_vrs_pkl.exists() assert vcr_cassette.all_played - Path(output_vcf).unlink() - Path(output_vrs_pkl).unlink() @pytest.mark.vcr -def test_annotate_vcf_grch38_attrs(vcf_annotator, vcr_cassette): +def test_annotate_vcf_grch38_attrs( + vcf_annotator: VCFAnnotator, input_vcf: Path, tmp_path: Path, vcr_cassette +): vcr_cassette.allow_playback_repeats = False - input_vcf = f"{TEST_DATA_DIR}/test_vcf_input.vcf" - output_vcf = f"{TEST_DATA_DIR}/test_vcf_output_grch38_attrs.vcf.gz" - output_vrs_pkl = f"{TEST_DATA_DIR}/test_vcf_pkl_grch38_attrs.pkl" - expected_vcf = f"{TEST_DATA_DIR}/test_vcf_expected_output.vcf.gz" + output_vcf = tmp_path / "test_vcf_output_grch38_attrs.vcf.gz" + output_vrs_pkl = tmp_path / "test_vcf_pkl_grch38_attrs.pkl" + expected_vcf = TEST_DATA_DIR / "test_vcf_expected_output.vcf.gz" # Test GRCh38 assembly, which was used for input_vcf and vrs attributes vcf_annotator.annotate(input_vcf, output_vcf, output_vrs_pkl, vrs_attributes=True) @@ -62,19 +68,18 @@ def test_annotate_vcf_grch38_attrs(vcf_annotator, vcr_cassette): out_vcf_lines, expected_output_lines, strict=False ): assert actual_line == expected_line - assert Path(output_vrs_pkl).exists() + assert output_vrs_pkl.exists() assert vcr_cassette.all_played - Path(output_vcf).unlink() - Path(output_vrs_pkl).unlink() @pytest.mark.vcr -def test_annotate_vcf_grch38_attrs_altsonly(vcf_annotator, vcr_cassette): +def test_annotate_vcf_grch38_attrs_altsonly( + vcf_annotator: VCFAnnotator, input_vcf: Path, tmp_path: Path, vcr_cassette +): vcr_cassette.allow_playback_repeats = False - input_vcf = f"{TEST_DATA_DIR}/test_vcf_input.vcf" - output_vcf = f"{TEST_DATA_DIR}/test_vcf_output_grch38_attrs_altsonly.vcf.gz" - output_vrs_pkl = f"{TEST_DATA_DIR}/test_vcf_pkl_grch38_attrs_altsonly.pkl" - expected_altsonly_vcf = f"{TEST_DATA_DIR}/test_vcf_expected_altsonly_output.vcf.gz" + output_vcf = tmp_path / "test_vcf_output_grch38_attrs_altsonly.vcf.gz" + output_vrs_pkl = tmp_path / "test_vcf_pkl_grch38_attrs_altsonly.pkl" + expected_altsonly_vcf = TEST_DATA_DIR / "test_vcf_expected_altsonly_output.vcf.gz" # Test GRCh38 assembly with VRS computed for ALTs only, which was used for input_vcf and vrs attributes vcf_annotator.annotate( @@ -92,19 +97,18 @@ def test_annotate_vcf_grch38_attrs_altsonly(vcf_annotator, vcr_cassette): out_vcf_lines, expected_output_lines, strict=False ): assert actual_line == expected_line - assert Path(output_vrs_pkl).exists() + assert output_vrs_pkl.exists() assert vcr_cassette.all_played - Path(output_vcf).unlink() - Path(output_vrs_pkl).unlink() @pytest.mark.vcr -def test_annotate_vcf_grch37_attrs(vcf_annotator, vcr_cassette): +def test_annotate_vcf_grch37_attrs( + vcf_annotator: VCFAnnotator, input_vcf: Path, tmp_path: Path, vcr_cassette +): vcr_cassette.allow_playback_repeats = False - input_vcf = f"{TEST_DATA_DIR}/test_vcf_input.vcf" - output_vcf = f"{TEST_DATA_DIR}/test_vcf_output_grch37_attrs.vcf.gz" - output_vrs_pkl = f"{TEST_DATA_DIR}/test_vcf_pkl_grch37_attrs.pkl" - expected_vcf = f"{TEST_DATA_DIR}/test_vcf_expected_output.vcf.gz" + output_vcf = tmp_path / "test_vcf_output_grch37_attrs.vcf.gz" + output_vrs_pkl = tmp_path / "test_vcf_pkl_grch37_attrs.pkl" + expected_vcf = TEST_DATA_DIR / "test_vcf_expected_output.vcf.gz" # Test GRCh37 assembly, which was not used for input_vcf vcf_annotator.annotate( @@ -115,39 +119,38 @@ def test_annotate_vcf_grch37_attrs(vcf_annotator, vcr_cassette): with gzip.open(expected_vcf, "rt") as expected_output: expected_output_lines = expected_output.readlines() assert out_vcf_lines != expected_output_lines - assert Path(output_vrs_pkl).exists() + assert output_vrs_pkl.exists() assert vcr_cassette.all_played - Path(output_vcf).unlink() - Path(output_vrs_pkl).unlink() @pytest.mark.vcr -def test_annotate_vcf_pickle_only(vcf_annotator, vcr_cassette): +def test_annotate_vcf_pickle_only( + vcf_annotator: VCFAnnotator, input_vcf: Path, tmp_path: Path, vcr_cassette +): vcr_cassette.allow_playback_repeats = False - input_vcf = f"{TEST_DATA_DIR}/test_vcf_input.vcf" - output_vcf = f"{TEST_DATA_DIR}/test_vcf_output_pickle_only.vcf.gz" - output_vrs_pkl = f"{TEST_DATA_DIR}/test_vcf_pkl_pickle_only.pkl" + output_vcf = tmp_path / "test_vcf_output_pickle_only.vcf.gz" + output_vrs_pkl = tmp_path / "test_vcf_pkl_pickle_only.pkl" # Test only pickle output vcf_annotator.annotate( - input_vcf, vrs_pickle_out=output_vrs_pkl, vrs_attributes=True + input_vcf, output_pkl_path=output_vrs_pkl, vrs_attributes=True ) - assert Path(output_vrs_pkl).exists() - assert not Path(output_vcf).exists() + assert output_vrs_pkl.exists() + assert not output_vcf.exists() assert vcr_cassette.all_played - Path(output_vrs_pkl).unlink() @pytest.mark.vcr -def test_annotate_vcf_vcf_only(vcf_annotator, vcr_cassette): +def test_annotate_vcf_vcf_only( + vcf_annotator: VCFAnnotator, input_vcf: Path, tmp_path: Path, vcr_cassette +): vcr_cassette.allow_playback_repeats = False - input_vcf = f"{TEST_DATA_DIR}/test_vcf_input.vcf" - output_vcf = f"{TEST_DATA_DIR}/test_vcf_output_vcf_only.vcf.gz" - output_vrs_pkl = f"{TEST_DATA_DIR}/test_vcf_pkl_vcf_only.pkl" - expected_vcf = f"{TEST_DATA_DIR}/test_vcf_expected_output.vcf.gz" + output_vcf = tmp_path / "test_vcf_output_vcf_only.vcf.gz" + output_vrs_pkl = tmp_path / "test_vcf_pkl_vcf_only.pkl" + expected_vcf = TEST_DATA_DIR / "test_vcf_expected_output.vcf.gz" # Test only VCF output - vcf_annotator.annotate(input_vcf, vcf_out=output_vcf, vrs_attributes=True) + vcf_annotator.annotate(input_vcf, output_vcf_path=output_vcf, vrs_attributes=True) with gzip.open(output_vcf, "rt") as out_vcf: out_vcf_lines = out_vcf.readlines() with gzip.open(expected_vcf, "rt") as expected_output: @@ -155,19 +158,16 @@ def test_annotate_vcf_vcf_only(vcf_annotator, vcr_cassette): assert out_vcf_lines == expected_output_lines assert vcr_cassette.all_played assert not Path(output_vrs_pkl).exists() - Path(output_vcf).unlink() - -def test_annotate_vcf_input_validation(vcf_annotator): - input_vcf = f"{TEST_DATA_DIR}/test_vcf_input.vcf" +def test_annotate_vcf_input_validation(vcf_annotator: VCFAnnotator, input_vcf: Path): with pytest.raises(VCFAnnotatorError) as e: vcf_annotator.annotate(input_vcf) assert str(e.value) == "Must provide one of: `vcf_out` or `vrs_pickle_out`" @pytest.mark.vcr -def test_get_vrs_object_invalid_input(vcf_annotator, caplog): +def test_get_vrs_object_invalid_input(vcf_annotator: VCFAnnotator, caplog): """Test that _get_vrs_object method works as expected with invalid input""" # some tests below are checking for debug logging statements caplog.set_level(logging.DEBUG)