diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fe11f3c..de31cf7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,6 +18,8 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | + sudo apt-get update + sudo apt-get install -y libbz2-dev zlib1g-dev liblzma-dev python -m pip install --upgrade pip python -m pip install tox tox-gh-actions - name: Test with tox diff --git a/.gitignore b/.gitignore index 2a43010..7bb4264 100644 --- a/.gitignore +++ b/.gitignore @@ -12,4 +12,5 @@ docs/build .tox dist log* -coverage.xml \ No newline at end of file +coverage.xml +tests/data/GRCh38_repeats.bed.sorted* \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ec84bd..6aad8fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,29 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.0.0] - 2026-02-25 + +* **Tabix-based STR panel handling** + + * STR reference is no longer loaded fully into memory. + * The panel is now prepared automatically: + + * sorted by genomic chromosome order, + * BGZF-compressed, + * tabix-indexed. + * During annotation, STR regions are queried directly from the tabix index, + enabling fast genomic lookups and significantly reducing memory usage. + * Improves scalability and allows safe multi-worker execution. + +* **Parallel directory processing (`jobs`)** + + * Added new `jobs` option to control parallel annotation of VCF files. + * Each worker processes one VCF file independently. + * If `jobs` is not provided, the tool now estimates an optimal number of workers + based on available CPU cores and system memory. + * This can substantially speed up processing of large VCF directories. + + ## [0.3.0] - 2026-01-12 - **Mismatch handling between VCF and STR panel** * Added support for cases where the VCF reference does not exactly match the STR panel. diff --git a/docs/conf.py b/docs/conf.py index 34c3ff9..8fddff1 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,7 +1,7 @@ project = 'strvcf_annotator' copyright = '2026, Olesia Kondrateva' author = 'Olesia Kondrateva' -release = '0.3.0' +release = '1.0.0' extensions = [ "sphinx.ext.autodoc", diff --git a/pyproject.toml b/pyproject.toml index 7674be4..e4ae9b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,8 @@ classifiers = [ dependencies = [ "pysam>=0.22.0", "pandas>=2.0.0", - "trtools>=5.0.0" + "trtools>=5.0.0", + "psutil" ] [project.optional-dependencies] diff --git a/requirements.txt b/requirements.txt index cb4c1ab..6f0a86e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ pysam>=0.22.0 pandas>=2.0.0 -trtools>=5.0.0 \ No newline at end of file +trtools>=5.0.0 +psutil \ No newline at end of file diff --git a/setup.py b/setup.py index cc51180..115d644 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,8 @@ requirements = [ "pysam>=0.22.0", "pandas>=2.0.0", + "trtools>=5.0.0", + "psutil", ] test_requirements = [ @@ -32,7 +34,7 @@ author_email="xkdnoa@gmail.com", python_requires=">=3.8", classifiers=[ - "Development Status :: 2 - Pre-Alpha", + "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", "Natural Language :: English", diff --git a/src/strvcf_annotator/__init__.py b/src/strvcf_annotator/__init__.py index 9f28c56..d299fe9 100644 --- a/src/strvcf_annotator/__init__.py +++ b/src/strvcf_annotator/__init__.py @@ -7,7 +7,7 @@ __author__ = """Olesia Kondrateva""" __email__ = "xkdnoa@gmail.com" -__version__ = "0.3.0" +__version__ = "1.0.0" # Public API exports from .api import STRAnnotator, annotate_vcf diff --git a/src/strvcf_annotator/api.py b/src/strvcf_annotator/api.py index d046f54..510b389 100644 --- a/src/strvcf_annotator/api.py +++ b/src/strvcf_annotator/api.py @@ -1,6 +1,8 @@ """Library API for programmatic access to STR annotation functionality.""" import logging +import statistics +from collections import Counter from typing import Iterator, Optional import pysam @@ -53,8 +55,8 @@ class STRAnnotator: ---------- str_bed_path : str Path to STR BED file - str_df : pd.DataFrame - Loaded STR reference data + str_panel_gz : str + Path to BGZF-compressed, tabix-indexed STR reference file. parser : BaseVCFParser Parser for genotype extraction somatic_mode : bool @@ -83,8 +85,7 @@ def __init__( mismatch_truth: str = "panel", # "panel" | "vcf" | "skip" ): validate_str_bed_file(str_bed_path) - self.str_bed_path = str_bed_path - self.str_df = load_str_reference(str_bed_path) + self.str_panel_gz = load_str_reference(str_bed_path) self.parser = parser if parser is not None else GenericParser() self.somatic_mode = somatic_mode @@ -92,8 +93,6 @@ def __init__( self.ignore_mismatch_warnings = ignore_mismatch_warnings self.mismatch_truth = mismatch_truth - logger.info(f"Loaded {len(self.str_df)} STR regions from {str_bed_path}") - def annotate_vcf_file( self, input_path: str, @@ -135,7 +134,6 @@ def annotate_vcf_file( - ``"skip"``: skip variants with mismatches entirely If ``None``, the value set on the annotator instance is used. - Raises ------ ValidationError @@ -160,7 +158,7 @@ def annotate_vcf_file( logger.info(f"Annotating {input_path}...") annotate_vcf_to_file( input_path, - self.str_df, + self.str_panel_gz, output_path, self.parser, somatic_mode=smode, @@ -231,7 +229,7 @@ def annotate_vcf_stream( yield from generate_annotated_records( vcf_in=vcf_in, - str_df=self.str_df, + str_panel_gz=self.str_panel_gz, parser=self.parser, somatic_mode=smode, ignore_mismatch_warnings=imw, @@ -246,6 +244,7 @@ def process_directory( somatic_mode: Optional[bool] = None, ignore_mismatch_warnings: Optional[bool] = None, mismatch_truth: Optional[str] = None, + jobs: Optional[int] = None, ) -> None: """ Batch process a directory of VCF files. @@ -281,7 +280,11 @@ def process_directory( - ``"skip"``: skip variants with mismatches entirely If ``None``, the value set on the annotator instance is used. - + jobs: int, optional + - If jobs is None: compute jobs automatically: + jobs_auto = min(cpu_cores, n_files) + jobs_auto = min(jobs_auto, floor(available_ram / ram_per_worker_estimate)) + - If jobs is provided: use it exactly. Raises ------ ValidationError @@ -307,12 +310,13 @@ def process_directory( logger.info(f"Processing VCF files in {input_dir}...") process_directory( input_dir=input_dir, - str_bed_path=self.str_bed_path, + str_panel_gz=self.str_panel_gz, output_dir=output_dir, parser=self.parser, somatic_mode=smode, ignore_mismatch_warnings=imw, mismatch_truth=mtruth, + jobs=jobs, ) logger.info(f"Batch processing complete. Output in {output_dir}") @@ -341,7 +345,7 @@ def get_str_at_position(self, chrom: str, pos: int) -> Optional[dict]: """ from .core.str_reference import get_str_at_position - return get_str_at_position(self.str_df, chrom, pos) + return get_str_at_position(self.str_panel_gz, chrom, pos) def get_statistics(self) -> dict: """ @@ -358,15 +362,49 @@ def get_statistics(self) -> dict: >>> stats = annotator.get_statistics() >>> print(f"Total STR regions: {stats['total_regions']}") """ - stats = { - "total_regions": len(self.str_df), - "chromosomes": self.str_df["CHROM"].nunique(), - "unique_repeat_units": self.str_df["RU"].nunique(), - "period_distribution": self.str_df["PERIOD"].value_counts().to_dict(), - "mean_repeat_count": self.str_df["COUNT"].mean(), - "median_repeat_count": self.str_df["COUNT"].median(), + tbx = pysam.TabixFile(self.str_panel_gz) + + total_regions = 0 + chromosomes = set() + repeat_units = set() + period_counter = Counter() + counts = [] + + # Iterate through all records in the file + for line in tbx.fetch(): + parts = line.rstrip("\n").split("\t") + if len(parts) < 5: + continue + + try: + chrom = parts[0] + start = int(parts[1]) + end = int(parts[2]) + period = int(parts[3]) + ru = parts[4] + count = int((end - start + 1) / period) + except ValueError: + continue + + total_regions += 1 + chromosomes.add(chrom) + repeat_units.add(ru) + period_counter[period] += 1 + counts.append(count) + + tbx.close() + + mean_count = statistics.mean(counts) if counts else None + median_count = statistics.median(counts) if counts else None + + return { + "total_regions": total_regions, + "chromosomes": len(chromosomes), + "unique_repeat_units": len(repeat_units), + "period_distribution": dict(period_counter), + "mean_repeat_count": mean_count, + "median_repeat_count": median_count, } - return stats def annotate_vcf( diff --git a/src/strvcf_annotator/cli.py b/src/strvcf_annotator/cli.py index 94ec11b..0b19407 100644 --- a/src/strvcf_annotator/cli.py +++ b/src/strvcf_annotator/cli.py @@ -94,6 +94,13 @@ def create_parser() -> argparse.ArgumentParser: "and VCF REF allele." ), ) + parser.add_argument( + "--jobs", + type=int, + help=( + "Number of parallel jobs to use for processing. Each job processes one VCF file. If not specified, the number of jobs is automatically determined based on CPU cores, number of files, and available RAM." + ), + ) parser.add_argument( "--mismatch-truth", choices=["panel", "vcf", "skip"], @@ -153,6 +160,7 @@ def main(): somatic_mode = getattr(args, "somatic_mode", False) ignore_mismatch_warnings = getattr(args, "ignore_mismatch_warnings", False) mismatch_truth = getattr(args, "mismatch_truth", "panel") + jobs = getattr(args, "jobs", None) annotator = STRAnnotator( args.str_bed, somatic_mode=somatic_mode, @@ -176,7 +184,7 @@ def main(): elif args.input_dir: # Batch directory mode logger.info(f"Processing directory: {args.input_dir}") - annotator.process_directory(args.input_dir, args.output_dir) + annotator.process_directory(args.input_dir, args.output_dir, jobs=jobs) logger.info(f"Successfully processed all VCF files to {args.output_dir}") logger.info("Annotation complete!") diff --git a/src/strvcf_annotator/core/__init__.py b/src/strvcf_annotator/core/__init__.py index 3281e68..25ce395 100644 --- a/src/strvcf_annotator/core/__init__.py +++ b/src/strvcf_annotator/core/__init__.py @@ -1,13 +1,13 @@ """Core modules for STR annotation functionality.""" +from .annotation import build_new_record, make_modified_header, should_skip_genotype +from .repeat_utils import apply_variant_to_repeat, count_repeat_units, extract_repeat_sequence from .str_reference import load_str_reference -from .repeat_utils import extract_repeat_sequence, count_repeat_units, apply_variant_to_repeat -from .annotation import make_modified_header, build_new_record, should_skip_genotype from .vcf_processor import ( + annotate_vcf_to_file, check_vcf_sorted, - reset_and_sort_vcf, generate_annotated_records, - annotate_vcf_to_file + reset_and_sort_vcf, ) __all__ = [ diff --git a/src/strvcf_annotator/core/str_reference.py b/src/strvcf_annotator/core/str_reference.py index 69292a4..0c20f01 100644 --- a/src/strvcf_annotator/core/str_reference.py +++ b/src/strvcf_annotator/core/str_reference.py @@ -1,114 +1,247 @@ """STR reference management for BED file loading and region lookups.""" +import os +from pathlib import Path from typing import Dict, Optional import pandas as pd +import pysam from ..utils.vcf_utils import chrom_to_order -def load_str_reference(str_path: str) -> pd.DataFrame: - """Load STR reference data from BED file. +def is_valid_tabix(gz_path: str) -> bool: + """Check that a BGZF file has a valid tabix index. - Loads a BED file containing STR (Short Tandem Repeat) regions and converts - coordinates from 0-based BED format to 1-based VCF format. Calculates the - number of repeat units for each region. + Returns True only if: + - .tbi exists + - pysam can open the file + - index is readable + """ + tbi_path = gz_path + ".tbi" + if not os.path.exists(gz_path) or not os.path.exists(tbi_path): + return False + + try: + tbx = pysam.TabixFile(gz_path) + # Accessing contigs forces index parsing + _ = tbx.contigs + tbx.close() + return True + except Exception: + return False + + +def sort_bed_file( + bed_path: str, + output_path: str, + chrom_col: int = 0, + start_col: int = 1, +) -> str: + """Sort a BED-like file by chromosome and start coordinate. + + Parameters + ---------- + bed_path : str + Path to input BED file (tab-delimited). + output_path : str + Path to write the sorted BED file. + chrom_col : int, optional + Zero-based column index for chromosome. Default is 0. + start_col : int, optional + Zero-based column index for start coordinate. Default is 1. + + Returns + ------- + str + Path to the sorted BED file. + + Notes + ----- + - This function loads the BED into memory via pandas. For extremely large + BED files, consider an external sort. + - Sorting is lexicographic by chromosome, then numeric by start. + """ + bed_path = str(bed_path) + output_path = str(output_path) + + df = pd.read_csv( + bed_path, + sep="\t", + header=None, + comment="#", + dtype={chrom_col: "string"}, + ) + + # Ensure start is numeric + df[start_col] = df[start_col].astype("int64") + + # Compute genomic order + chrom_series = df[chrom_col].astype("string") + chrom_order = chrom_series.map(chrom_to_order) + + # If unknown chromosomes appear, push them to the end + chrom_order = chrom_order.fillna(10**9) + + df = df.assign(_chrom_order=chrom_order) + + df.sort_values( + by=["_chrom_order", start_col], + kind="mergesort", # stable sort + inplace=True, + ) + + df.drop(columns=["_chrom_order"], inplace=True) + + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + df.to_csv(output_path, sep="\t", header=False, index=False) + + return output_path + + +def load_str_reference(str_path: str) -> str: + """Ensure a BED file is BGZF-compressed and tabix-indexed. + + This function: + - Accepts a BED path (.bed or .bed.gz). + - If the input is already .gz and has a .tbi index, returns it. + - Otherwise, creates a sorted BED (if needed), BGZF-compresses it, and + creates a tabix index (preset="bed"). Parameters ---------- - str_path : str - Path to BED file with STR regions + bed_path : str + Path to input BED file (.bed or .bed.gz). Returns ------- - pd.DataFrame - DataFrame with columns: CHROM, START, END, PERIOD, RU, COUNT - - CHROM: Chromosome name - - START: 1-based start position (converted from BED 0-based) - - END: 1-based end position - - PERIOD: Length of repeat unit - - RU: Repeat unit sequence - - COUNT: Number of repeat units in the region + str + Path to the BGZF-compressed, tabix-indexed BED file (*.gz). Notes ----- - BED files use 0-based coordinates, but VCF files use 1-based coordinates. - This function converts START positions by adding 1. END positions are kept - as-is since BED END is exclusive and VCF END is inclusive. + - Tabix indexing requires the BED to be sorted by chromosome and start. + - This function uses `pysam.tabix_compress` and `pysam.tabix_index`. """ - df = pd.read_csv(str_path, sep="\t", header=None) - df.columns = ["CHROM", "START", "END", "PERIOD", "RU"] + bed_path = Path(str_path) + cache_dir_path = bed_path.parent + + # If input is already gz + tbi, trust and return. + if bed_path.suffix == ".gz" and bed_path.exists() and is_valid_tabix(str(bed_path)): + return str(bed_path) + + # Determine output names + # Use the base name without trailing .gz if present. + base_name = bed_path.name + if base_name.endswith(".gz"): + base_name = base_name[:-3] - # Convert from 0-based BED to 1-based VCF coordinates - df["START"] = df["START"] + 1 + sorted_bed = cache_dir_path / f"{base_name}.sorted" + gz_bed = cache_dir_path / f"{base_name}.sorted.gz" - # Calculate number of repeat units - df["COUNT"] = (df["END"] - df["START"] + 1) / df["PERIOD"] + # If cached gz + tbi exist, reuse. + if gz_bed.exists() and is_valid_tabix(str(gz_bed)): + return str(gz_bed) - # Add chromosome order column for proper sorting - df["CHROM_ORDER"] = df["CHROM"].apply(chrom_to_order) + # Ensure we have a sorted BED file to compress/index. + sort_bed_file(str(bed_path), str(sorted_bed)) - # Sort by chromosome (natural order) and position for efficient lookups - df.sort_values(by=["CHROM_ORDER", "START"], inplace=True) - df.drop(columns="CHROM_ORDER", inplace=True) - return df + # BGZF compress + pysam.tabix_compress(str(sorted_bed), str(gz_bed), force=True) + # Tabix index (BED preset expects chrom/start/end in the first columns) + pysam.tabix_index(str(gz_bed), preset="bed", force=True) -def find_overlapping_str(str_df: pd.DataFrame, chrom: str, pos: int, end: int) -> Optional[Dict]: - """Find STR region overlapping with variant coordinates. + return str(gz_bed) - Searches for an STR region that overlaps with the given variant position. - Uses efficient binary search on sorted DataFrame. + +def find_overlapping_str( + str_panel_gz: str, + chrom: str, + pos: int, + end: int, +) -> Optional[Dict]: + """Find STR region overlapping with variant coordinates using tabix index. Parameters ---------- - str_df : pd.DataFrame - DataFrame with STR regions (from load_str_reference) + str_panel_gz : str + Path to BGZF-compressed, tabix-indexed STR reference file. chrom : str - Chromosome name + Chromosome name. pos : int - Variant start position (1-based) + Variant start position (1-based). end : int - Variant end position (1-based) + Variant end position (1-based). Returns ------- Optional[Dict] - Dictionary with STR region data if overlap found, None otherwise - Contains keys: CHROM, START, END, PERIOD, RU, COUNT + Dictionary with STR region data if overlap found, None otherwise. + Keys: CHROM, START, END, PERIOD, RU, COUNT """ - # Filter by chromosome - chrom_df = str_df[str_df["CHROM"] == chrom] + # Convert to tabix coordinate system: 0-based half-open + query_start = max(0, pos - 1) + query_end = end + + try: + tbx = pysam.TabixFile(str_panel_gz) + + for row in tbx.fetch(chrom, query_start, query_end): + parts = row.rstrip("\n").split("\t") + if len(parts) < 5: + continue + + try: + str_chrom = parts[0] + str_start = int(parts[1]) + str_end = int(parts[2]) + period = int(parts[3]) + ru = parts[4] + count = int((str_end - str_start + 1) / period) + except ValueError: + continue + + # Check true overlap in 1-based coordinates + if str_start < end and str_end >= pos: + result = { + "CHROM": str_chrom, + "START": str_start, + "END": str_end, + "PERIOD": period, + "RU": ru, + "COUNT": count, + } + return result - if chrom_df.empty: return None - # Find overlapping regions - # Overlap occurs when: variant_end >= str_start AND variant_start <= str_end - overlapping = chrom_df[(chrom_df["START"] <= end) & (chrom_df["END"] >= pos)] - - if overlapping.empty: + except ValueError: + # Chromosome not present in index return None - - # Return first overlapping region - return overlapping.iloc[0].to_dict() + finally: + tbx.close() -def get_str_at_position(str_df: pd.DataFrame, chrom: str, pos: int) -> Optional[Dict]: - """Get STR region containing a specific position. +def get_str_at_position( + str_panel_gz: str, + chrom: str, + pos: int, +) -> Optional[Dict]: + """Get STR region containing a specific position using tabix index. Parameters ---------- - str_df : pd.DataFrame - DataFrame with STR regions (from load_str_reference) + str_panel_gz : str + Path to BGZF-compressed, tabix-indexed STR reference file. chrom : str - Chromosome name + Chromosome name. pos : int - Position to query (1-based) + Position to query (1-based). Returns ------- Optional[Dict] - Dictionary with STR region data if position is within an STR, None otherwise + Dictionary with STR region data if position is within an STR, None otherwise. """ - return find_overlapping_str(str_df, chrom, pos, pos) + return find_overlapping_str(str_panel_gz, chrom, pos, pos) diff --git a/src/strvcf_annotator/core/vcf_processor.py b/src/strvcf_annotator/core/vcf_processor.py index 45a8c88..109a012 100644 --- a/src/strvcf_annotator/core/vcf_processor.py +++ b/src/strvcf_annotator/core/vcf_processor.py @@ -1,19 +1,61 @@ """VCF file processing and workflow management.""" import logging +import multiprocessing +import os +from concurrent.futures import ProcessPoolExecutor, as_completed +from dataclasses import dataclass from pathlib import Path -from typing import Iterator, List +from typing import Dict, Iterator, List, Optional, Tuple -import pandas as pd +import psutil import pysam from ..parsers.base import BaseVCFParser from ..parsers.generic import GenericParser from ..utils.vcf_utils import chrom_to_order from .annotation import build_new_record, make_modified_header, should_skip_genotype -from .str_reference import load_str_reference logger = logging.getLogger(__name__) +# Globals initialized once per worker process +WORKER_CONFIG = None + + +@dataclass(frozen=True) +class WorkerConfig: + """Configuration container for worker processes. + + Stores parameters required by each worker to annotate VCF files. + The configuration is passed once during worker initialization and + reused for all tasks processed by that worker. + + Attributes + ---------- + str_panel_gz : str + Path to the BGZF-compressed, tabix-indexed STR reference file. + somatic_mode : bool, optional + Enable somatic filtering. When True, skips variants where both samples + have identical genotypes. Default is False. + ignore_mismatch_warnings : bool, optional + If True, suppresses warnings about reference mismatches between the + STR panel and VCF REF alleles. Default is False. + mismatch_truth : str, optional + Specifies which source to consider as ground truth for mismatches. + Options are "panel", "vcf", or "skip". Default is "panel". + + Notes + ----- + - The dataclass is frozen to ensure the configuration remains + immutable once workers are initialized. + - Instances of this class are passed to `worker_init`, which loads + the STR reference and exposes these settings to worker tasks. + """ + + str_panel_gz: str + somatic_mode: bool + ignore_mismatch_warnings: bool + mismatch_truth: str + parser: BaseVCFParser def check_vcf_sorted(vcf_in: pysam.VariantFile) -> bool: @@ -85,7 +127,7 @@ def reset_and_sort_vcf(vcf_in: pysam.VariantFile) -> List[pysam.VariantRecord]: def generate_annotated_records( vcf_in: pysam.VariantFile, - str_df: pd.DataFrame, + str_panel_gz: str, parser: BaseVCFParser = None, somatic_mode: bool = False, ignore_mismatch_warnings: bool = False, @@ -103,8 +145,8 @@ def generate_annotated_records( ---------- vcf_in : pysam.VariantFile Input VCF file - str_df : pd.DataFrame - DataFrame with STR regions (from load_str_reference) + str_panel_gz : str + Path to BGZF-compressed, tabix-indexed STR reference file. parser : BaseVCFParser, optional Parser for genotype extraction. Uses GenericParser if None. somatic_mode : bool, optional @@ -147,54 +189,56 @@ def generate_annotated_records( vcf_in.reset() records = vcf_in.fetch() - # Prepare STR list for efficient lookup - str_idx = 0 - str_list = str_df.to_dict("records") + # Open tabix once for the whole generator (fast, avoids reopen per record) + tbx = pysam.TabixFile(str_panel_gz) skipped_count = 0 - for record in records: - # Advance STR index to current chromosome/position - while str_idx < len(str_list) and ( - str_list[str_idx]["CHROM"] != record.chrom - or ( - str_list[str_idx]["CHROM"] == record.chrom and str_list[str_idx]["END"] < record.pos - ) - ): - str_idx += 1 - - if str_idx >= len(str_list): - break - - str_row = str_list[str_idx] - - # Check for overlap - variant position should be within STR region - if ( - str_row["CHROM"] != record.chrom - or record.pos < str_row["START"] - or record.pos > str_row["END"] - ): - continue # No overlap - - # Skip based on genotype filtering (only if somatic_mode enabled) - if somatic_mode and should_skip_genotype(record, parser): - skipped_count += 1 - logger.debug( - f"Skipped {record.contig}:{record.pos} - identical genotypes (somatic mode)" - ) - continue - - # Try all STR intervals that overlap this record.pos on the same chromosome - chosen = None - j = str_idx - while j < len(str_list) and str_list[j]["CHROM"] == record.chrom: - str_row = str_list[j] - - # Once START is past POS, no further overlaps are possible - if str_row["START"] > record.pos: - break - - # Candidate overlaps POS? - if str_row["START"] <= record.pos <= str_row["END"]: + try: + for record in records: + # Skip based on genotype filtering (only if somatic_mode enabled) + if somatic_mode and should_skip_genotype(record, parser): + skipped_count += 1 + logger.debug( + f"Skipped {record.contig}:{record.pos} - identical genotypes (somatic mode)" + ) + continue + + # Query tabix for a 1bp window at record.pos. + # Tabix uses 0-based half-open coordinates. + query_start = max(0, record.pos - 1) + query_end = record.pos + try: + candidates = tbx.fetch(record.chrom, query_start, query_end) + except ValueError: + # Chromosome not present in the index + continue + + def parse_str_line(line: str) -> Optional[Dict]: + parts = line.rstrip("\n").split("\t") + if len(parts) < 5: + return None + try: + row = { + "CHROM": parts[0], + "START": int(parts[1]) + 1, # Convert to 1-based + "END": int(parts[2]), + "PERIOD": int(parts[3]), + "RU": parts[4], + } + row["COUNT"] = int((row["END"] - row["START"] + 1) / row["PERIOD"]) + return row + except ValueError: + return None + + chosen = None + for line in candidates: + str_row = parse_str_line(line) + if str_row is None: + continue + # True overlap check in 1-based coordinates + if not (str_row["START"] <= record.pos <= str_row["END"]): + continue + new_record = build_new_record( record, str_row, @@ -205,7 +249,6 @@ def generate_annotated_records( ) # Skip in case of mismatch and mismatch_truth is "skip" if new_record is None: - j += 1 continue if new_record.alleles[0] != new_record.alleles[1]: @@ -213,12 +256,11 @@ def generate_annotated_records( break # If variant effectively doesn't change STR allele (e.g., indel outside STR after normalization), # treat this STR row as not applicable and try next overlapping STR. - j += 1 - - if chosen is None: - continue - yield chosen + if chosen is not None: + yield chosen + finally: + tbx.close() # Log summary if records were skipped if skipped_count > 0: @@ -229,7 +271,7 @@ def generate_annotated_records( def annotate_vcf_to_file( vcf_path: str, - str_df: pd.DataFrame, + str_panel_gz: str, output_path: str, parser: BaseVCFParser = None, somatic_mode: bool = False, @@ -245,8 +287,8 @@ def annotate_vcf_to_file( ---------- vcf_path : str Path to input VCF file - str_df : pd.DataFrame - DataFrame with STR regions (from load_str_reference) + str_panel_gz : str + Path to BGZF-compressed, tabix-indexed STR reference file. output_path : str Path to output VCF file parser : BaseVCFParser, optional @@ -278,7 +320,7 @@ def annotate_vcf_to_file( written_count = 0 for record in generate_annotated_records( vcf_in, - str_df, + str_panel_gz, parser, somatic_mode=somatic_mode, ignore_mismatch_warnings=ignore_mismatch_warnings, @@ -293,14 +335,189 @@ def annotate_vcf_to_file( logger.info(f"Wrote {written_count} annotated records to {output_path}") +def annotate_one_vcf(task: Tuple[str, str]) -> str: + """Annotate a single VCF file in a worker process. + + Runs `annotate_vcf_to_file` for one input VCF and writes the annotated VCF + to the given output path. Created to be executed inside a process pool. + + Parameters + ---------- + task : Tuple[str, str] + (vcf_path, output_path) pair, where: + - vcf_path is the input VCF (optionally gzipped) + - output_path is the target annotated VCF path + + Returns + ------- + str + Path to the produced output VCF file. + + Notes + ----- + - Expects STR reference (STR_DF) and worker configuration (WORKER_CONFIG) + to be initialized once per worker via `worker_init`. + """ + global WORKER_CONFIG + + vcf_path, output_path = task + + annotate_vcf_to_file( + vcf_path=vcf_path, + str_panel_gz=WORKER_CONFIG.str_panel_gz, + output_path=output_path, + parser=WORKER_CONFIG.parser, + somatic_mode=WORKER_CONFIG.somatic_mode, + ignore_mismatch_warnings=WORKER_CONFIG.ignore_mismatch_warnings, + mismatch_truth=WORKER_CONFIG.mismatch_truth, + ) + return output_path + + +def get_available_ram_bytes() -> int: + """Get available system RAM. + + Returns + ------- + int + Available RAM in bytes. + """ + return int(psutil.virtual_memory().available) + + +def estimate_ram_per_worker_bytes(vcf_paths: List[str]) -> int: + """Estimate RAM usage per worker for VCF annotation. + + Provides an estimate of how much memory a single worker process + might require while annotating one VCF. This estimate is used to cap the + number of concurrent workers to reduce the risk of out-of-memory (OOM) + crashes. + + Parameters + ---------- + vcf_paths : list[str] + List of input VCF paths that will be processed. + + Returns + ------- + int + Estimated RAM usage per worker in bytes. + + Notes + ----- + - If a VCF is not sorted, the current pipeline may load all records into + memory for sorting, which can drastically increase memory usage. + - Even for sorted VCFs, pysam/htslib buffers plus Python object overhead + can be substantial. + - Each worker loads the STR reference once. The STR DataFrame and derived + Python objects often consume several times the BED file size on disk. + + Heuristic + --------- + - Identify the largest input file size on disk. + - If the largest file is gzipped, assume a higher expansion factor for the + working set (e.g., decompression + object overhead). + - Add a fixed overhead to account for Python/pysam allocations. + - Add STR panel RAM estimate as: str_panel_factor * BED_size_on_disk. + + This is intentionally conservative to avoid OOM. + """ + max_size = 0 + max_path = "" + for p in vcf_paths: + try: + s = os.path.getsize(p) + except OSError: + s = 0 + if s > max_size: + max_size = s + max_path = p + + is_gz = max_path.endswith(".gz") + expansion_factor = 5 if is_gz else 2 # VCF decompression + object overhead + + fixed_overhead = 700 * 1024**2 # ~700MB overhead per worker + + estimate = fixed_overhead + (expansion_factor * max_size) + + # Clamp to sane minimum/maximum to avoid weird estimates on tiny/huge files + min_estimate = 1 * 1024**3 # 1 GB + max_estimate = 120 * 1024**3 # 120 GB + return int(min(max(estimate, min_estimate), max_estimate)) + + +def compute_jobs_auto(n_files: int, vcf_paths: List[str]) -> int: + """Compute an automatic number of concurrent workers. + + Chooses a default number of parallel jobs for processing a directory of VCF + files, balancing CPU capacity and memory constraints. + + Parameters + ---------- + n_files : int + Number of VCF files that will be processed (after skipping outputs that + already exist). + vcf_paths : list[str] + List of VCF paths used to estimate per-worker memory usage. + + Returns + ------- + int + Recommended number of concurrent worker processes (at least 1). + + Notes + ----- + The selection follows: + - jobs_auto = min(cpu_cores, n_files) + - jobs_auto = min(jobs_auto, floor(available_ram / ram_per_worker_estimate)) + + If available RAM cannot be determined, the CPU-based limit is used. + """ + cpu_cores = multiprocessing.cpu_count() + jobs_auto = min(cpu_cores, n_files) + + available = get_available_ram_bytes() + ram_per_worker = estimate_ram_per_worker_bytes(vcf_paths) + + if available > 0 and ram_per_worker > 0: + ram_cap = max(1, available // ram_per_worker) + jobs_auto = min(jobs_auto, int(ram_cap)) + + return max(1, int(jobs_auto)) + + +def worker_init(config: WorkerConfig) -> None: + """Initialize worker process state. + + Called once when a worker process starts. Stores configuration + values so they can be reused for all VCF files processed by that worker. + + Parameters + ---------- + config : WorkerConfig + Configuration object containing: + - str_panel_gz : path to STR panel BGZF-compressed, tabix-indexed reference file + - somatic_mode : whether somatic filtering is enabled + - ignore_mismatch_warnings : whether to suppress mismatch warnings + - mismatch_truth : rule for handling panel/VCF mismatches + + Notes + ----- + - This function is used as the `initializer` for `ProcessPoolExecutor`. + """ + global WORKER_CONFIG + WORKER_CONFIG = config + + def process_directory( input_dir: str, - str_bed_path: str, + str_panel_gz: str, output_dir: str, parser: BaseVCFParser = None, somatic_mode: bool = False, ignore_mismatch_warnings: bool = False, mismatch_truth: str = "panel", + jobs: int = None, ) -> None: """Batch process directory of VCF files. @@ -311,8 +528,8 @@ def process_directory( ---------- input_dir : str Directory containing input VCF files - str_bed_path : str - Path to BED file with STR regions + str_panel_gz : str + Path to BGZF-compressed, tabix-indexed STR panel reference file output_dir : str Directory for output VCF files parser : BaseVCFParser, optional @@ -328,6 +545,11 @@ def process_directory( - "panel": trust panel repeat sequence (default behavior) - "vcf": trust VCF REF. patch the panel repeat sequence overlap to match VCF REF - "skip": skip record with mismatch + jobs: int, optional + - If jobs is None: compute jobs automatically: + jobs_auto = min(cpu_cores, n_files) + jobs_auto = min(jobs_auto, floor(available_ram / ram_per_worker_estimate)) + - If jobs is provided: use it exactly. """ if parser is None: @@ -336,11 +558,12 @@ def process_directory( # Create output directory Path(output_dir).mkdir(parents=True, exist_ok=True) - # Load STR reference - str_df = load_str_reference(str_bed_path) - # Process each VCF file input_path = Path(input_dir) + + tasks = [] + vcf_paths_for_estimate = [] + for vcf_file in input_path.glob("*.vcf*"): if vcf_file.suffix in [".vcf", ".gz"]: # Generate output filename @@ -354,14 +577,54 @@ def process_directory( logger.info(f"Skipping {vcf_file.name} — already processed.") continue - logger.info(f"Processing {vcf_file.name}...") - annotate_vcf_to_file( - str(vcf_file), - str_df, - str(output_file), - parser, - somatic_mode=somatic_mode, - ignore_mismatch_warnings=ignore_mismatch_warnings, - mismatch_truth=mismatch_truth, - ) - logger.info(f" → Output: {output_file}") + tasks.append((str(vcf_file), str(output_file))) + vcf_paths_for_estimate.append(str(vcf_file)) + if not tasks: + logger.info("No VCF files to process.") + return + + # Process larger files first (reduces idle time at end) + def get_size(path: str) -> int: + try: + return os.path.getsize(path) + except OSError: + return 0 + + tasks.sort(key=lambda t: get_size(t[0]), reverse=True) + + n_files = len(tasks) + + if jobs is None: + jobs_to_use = compute_jobs_auto(n_files=n_files, vcf_paths=vcf_paths_for_estimate) + avail_gb = get_available_ram_bytes() / (1024**3) if get_available_ram_bytes() else 0 + est_gb = estimate_ram_per_worker_bytes(vcf_paths_for_estimate) / (1024**3) + logger.info( + f"Auto jobs={jobs_to_use} (files={n_files}, cpu={multiprocessing.cpu_count()}, " + f"available_ram≈{avail_gb:.1f}GB, est_ram_per_worker≈{est_gb:.1f}GB)" + ) + else: + jobs_to_use = max(1, int(jobs)) + logger.info(f"Using fixed jobs={jobs_to_use} (files={n_files})") + + config = WorkerConfig( + str_panel_gz=str_panel_gz, + somatic_mode=somatic_mode, + ignore_mismatch_warnings=ignore_mismatch_warnings, + mismatch_truth=mismatch_truth, + parser=parser, + ) + + with ProcessPoolExecutor( + max_workers=jobs_to_use, + initializer=worker_init, + initargs=(config,), + ) as executor: + future_to_task = {executor.submit(annotate_one_vcf, t): t for t in tasks} + + for future in as_completed(future_to_task): + vcf_path, out_path = future_to_task[future] + try: + produced = future.result() + logger.info(f"Done: {Path(vcf_path).name} → {produced}") + except Exception: + logger.exception(f"Failed: {Path(vcf_path).name} → {out_path}") diff --git a/tests/cli/test_cli_commands.py b/tests/cli/test_cli_commands.py index 209cdbc..9ef3ec1 100644 --- a/tests/cli/test_cli_commands.py +++ b/tests/cli/test_cli_commands.py @@ -46,7 +46,7 @@ def test_version_command(self): """Test --version flag.""" result = subprocess.run(["strvcf-annotator", "--version"], capture_output=True, text=True) assert result.returncode == 0 - assert "0.3.0" in result.stdout + assert "1.0.0" in result.stdout def test_no_arguments_fails(self): """Test that running without arguments fails.""" diff --git a/tests/conftest.py b/tests/conftest.py index cd52189..e752b59 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import os import shutil +import zipfile from pathlib import Path import pytest @@ -10,6 +11,7 @@ def data_dir(): """Provides absolute path to the test data directory.""" return os.path.abspath(os.path.join(os.path.dirname(__file__), "data")) + def pytest_addoption(parser): parser.addoption( "--update-vcf-hashes", @@ -20,7 +22,7 @@ def pytest_addoption(parser): # Write all outputs here (committed folder, but files are NOT committed) -OUTPUT_DIR = Path(__file__).resolve().parents[1] / "output" +OUTPUT_DIR = Path(__file__).resolve().parents[0] / "output" @pytest.fixture(scope="session", autouse=True) @@ -42,16 +44,46 @@ def output_path(request): out.unlink() -@pytest.fixture +@pytest.fixture(scope="session") def output_dir(request): """ Provide a unique output directory per test under tests/output/. Ensures cleanup after the test. """ - outdir = OUTPUT_DIR / f"{request.node.name}_dir" + test_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0] + outdir = OUTPUT_DIR / f"{test_name}_dir" if outdir.exists(): shutil.rmtree(outdir) outdir.mkdir(parents=True, exist_ok=True) yield outdir if outdir.exists(): shutil.rmtree(outdir) + + +@pytest.fixture(scope="session") +def vcf_dir(data_dir: str, output_dir: str) -> str: + """ + Unpack tests/data/vcfs/test_input.zip into tests/output/vcfs + and return the directory containing VCF files. + + If the directory already exists (from a previous run), it is reused. + """ + data_path = Path(data_dir) + zip_path = data_path / "vcfs" / "test_input.zip" + assert zip_path.is_file(), f"Missing test input zip: {zip_path}" + + vcf_root = Path(output_dir) / "vcfs" + vcf_root.mkdir(parents=True, exist_ok=True) + + # If directory is empty, extract; otherwise assume it's already populated + if not any(vcf_root.iterdir()): + with zipfile.ZipFile(zip_path, "r") as zf: + zf.extractall(vcf_root) + + inner = list(vcf_root.iterdir()) + if len(inner) == 1 and inner[0].is_dir(): + # Zip contained a single directory; use that + return str(inner[0]) + + # Otherwise, use the top-level extraction dir + return str(vcf_root) diff --git a/tests/data/vcfs/test_input.zip b/tests/data/vcfs/test_input.zip new file mode 100755 index 0000000..81f64bf Binary files /dev/null and b/tests/data/vcfs/test_input.zip differ diff --git a/tests/integration/test_vcf_pipeline.py b/tests/integration/test_vcf_pipeline.py index fc1c8fe..0fafe49 100644 --- a/tests/integration/test_vcf_pipeline.py +++ b/tests/integration/test_vcf_pipeline.py @@ -1,5 +1,8 @@ import hashlib import os +import time +from pathlib import Path +from typing import List import pytest @@ -80,3 +83,75 @@ def test_reannotating_test_vcf_is_idempotent(self, data_dir, output_dir): assert first_hash == second_hash, ( f"Re-annotating {input_vcf} is not idempotent: {first_hash} != {second_hash}" ) + + +def list_input_vcfs(vcf_dir: str) -> List[str]: + """Return input VCF/VCF.GZ files found in vcf_dir (sorted).""" + root = Path(vcf_dir) + files = sorted([str(p) for p in root.rglob("*.vcf")]) + sorted( + [str(p) for p in root.rglob("*.vcf.gz")] + ) + return files + + +def expected_output_path(output_dir: str, input_vcf: str) -> str: + """Mirror process_directory output naming: .annotated.vcf""" + name = os.path.basename(input_vcf).replace(".vcf.gz", "").replace(".vcf", "") + return os.path.abspath(os.path.join(output_dir, f"{name}.annotated.vcf")) + + +@pytest.mark.integration +class TestProcessDirectoryParallel: + """Integration tests for directory-level parallel processing.""" + + def test_parallel_vs_serial(self, vcf_dir, output_dir): + """Serial (jobs=1) and parallel (jobs>1) runs should produce identical outputs.""" + str_bed = os.path.abspath(os.path.join(base_dir, "data", "GRCh38_repeats.bed")) + + inputs = list_input_vcfs(vcf_dir) + assert len(inputs) > 0, f"No VCF files found in {vcf_dir}" + + serial_out = os.path.abspath(os.path.join(output_dir, "serial")) + parallel_out = os.path.abspath(os.path.join(output_dir, "parallel")) + os.makedirs(serial_out, exist_ok=True) + os.makedirs(parallel_out, exist_ok=True) + + annotator = STRAnnotator(str_bed) + + # Run serial (jobs=1) + t0 = time.perf_counter() + annotator.process_directory( + input_dir=vcf_dir, + output_dir=serial_out, + jobs=1, + ) + t1 = time.perf_counter() + serial_time = t1 - t0 + + # Run parallel (jobs=auto) + t2 = time.perf_counter() + annotator.process_directory(input_dir=vcf_dir, output_dir=parallel_out) + t3 = time.perf_counter() + parallel_time = t3 - t2 + + # Parallel should not be slower + assert parallel_time <= serial_time, ( + f"Parallel run slower than expected: serial={serial_time:.3f}s " + f"parallel={parallel_time:.3f}s" + ) + + # Compare hashes for all expected outputs + for input_vcf in inputs: + serial_file = expected_output_path(serial_out, input_vcf) + parallel_file = expected_output_path(parallel_out, input_vcf) + + assert os.path.exists(serial_file), f"Missing serial output: {serial_file}" + assert os.path.exists(parallel_file), f"Missing parallel output: {parallel_file}" + + serial_hash = file_hash(serial_file) + parallel_hash = file_hash(parallel_file) + + assert serial_hash == parallel_hash, ( + f"Output mismatch for {os.path.basename(input_vcf)}: " + f"serial={serial_hash} parallel={parallel_hash}" + ) diff --git a/tests/test_performance.py b/tests/test_performance.py index 438b7a2..2122017 100644 --- a/tests/test_performance.py +++ b/tests/test_performance.py @@ -8,6 +8,7 @@ import pytest from strvcf_annotator import STRAnnotator +from strvcf_annotator.core.str_reference import is_valid_tabix class TestPerformance: @@ -60,7 +61,9 @@ def test_loading_performance(self, large_str_bed): # Should load 1000 regions in under 1 second assert load_time < 1.0, f"Loading took {load_time:.2f}s, expected < 1.0s" - assert len(annotator.str_df) == 1000 + assert Path(annotator.str_panel_gz).exists() and is_valid_tabix(annotator.str_panel_gz), ( + "STR BED file should be cached and indexed" + ) def test_annotation_performance(self, large_str_bed, large_vcf): """Test annotation performance.""" diff --git a/tests/test_str_reference.py b/tests/test_str_reference.py index f959820..ab0cf3d 100644 --- a/tests/test_str_reference.py +++ b/tests/test_str_reference.py @@ -1,14 +1,17 @@ """Unit tests for STR reference management.""" +import os import tempfile from pathlib import Path import pandas as pd +import pysam import pytest from strvcf_annotator.core.str_reference import ( find_overlapping_str, get_str_at_position, + is_valid_tabix, load_str_reference, ) @@ -20,8 +23,8 @@ class TestLoadSTRReference: def temp_bed_file(self): """Create temporary BED file.""" content = """chr1\t100\t115\t3\tCAG -chr1\t200\t212\t4\tATCG -chr2\t300\t318\t3\tGAT""" +chr2\t200\t212\t4\tATCG +chr1\t300\t318\t3\tGAT""" with tempfile.NamedTemporaryFile(mode="w", suffix=".bed", delete=False) as f: f.write(content) @@ -32,127 +35,172 @@ def temp_bed_file(self): # Cleanup Path(temp_path).unlink() - def test_load_basic(self, temp_bed_file): - """Test basic BED file loading.""" - df = load_str_reference(temp_bed_file) + def test_creates_gz_and_tbi(self, temp_bed_file): + """Test that load_str_reference creates .gz and .tbi.""" + gz_path = load_str_reference(temp_bed_file) - assert len(df) == 3 - assert list(df.columns) == ["CHROM", "START", "END", "PERIOD", "RU", "COUNT"] + assert gz_path.endswith(".gz") + assert Path(gz_path).exists() and is_valid_tabix(gz_path), ( + "BGZF-compressed and indexed file should exist" + ) - def test_coordinate_conversion(self, temp_bed_file): - """Test BED to VCF coordinate conversion.""" - df = load_str_reference(temp_bed_file) + def test_reuses_existing_indexed_panel(self, temp_bed_file): + """Test that load_str_reference reuses already created gz+tbi.""" + gz_path_1 = load_str_reference(temp_bed_file) + mtime_gz_1 = os.path.getmtime(gz_path_1) + mtime_tbi_1 = os.path.getmtime(gz_path_1 + ".tbi") - # BED START 100 should become VCF START 101 - assert df.iloc[0]["START"] == 101 - # BED END should remain the same - assert df.iloc[0]["END"] == 115 + gz_path_2 = load_str_reference(temp_bed_file) + mtime_gz_2 = os.path.getmtime(gz_path_2) + mtime_tbi_2 = os.path.getmtime(gz_path_2 + ".tbi") - def test_count_calculation(self, temp_bed_file): - """Test repeat count calculation.""" - df = load_str_reference(temp_bed_file) + assert gz_path_1 == gz_path_2 + assert mtime_gz_1 == mtime_gz_2 + assert mtime_tbi_1 == mtime_tbi_2 - # (115 - 101 + 1) / 3 = 5 - assert df.iloc[0]["COUNT"] == 5.0 + def test_output_is_sorted_for_tabix(self, temp_bed_file): + """Test that output is sorted by chromosome order and start.""" + gz_path = load_str_reference(temp_bed_file) - def test_sorting(self, temp_bed_file): - """Test that output is sorted.""" - df = load_str_reference(temp_bed_file) + tbx = pysam.TabixFile(gz_path) + rows = list(tbx.fetch()) # full file iteration + tbx.close() - # Check chromosomes are sorted - chroms = df["CHROM"].tolist() - assert chroms == sorted(chroms) + parsed = [] + for line in rows: + parts = line.rstrip("\n").split("\t") + chrom = parts[0] + start = int(parts[1]) + parsed.append((chrom, start)) - # Check positions within chromosome are sorted - for chrom in df["CHROM"].unique(): - chrom_df = df[df["CHROM"] == chrom] - positions = chrom_df["START"].tolist() - assert positions == sorted(positions) + # chr1 entries first + chr1_starts = [s for c, s in parsed if c == "chr1"] + chr2_starts = [s for c, s in parsed if c == "chr2"] + + assert len(chr1_starts) == 2 + assert len(chr2_starts) == 1 + assert chr1_starts == sorted(chr1_starts) + + # Ensure overall order doesn't put chr2 before chr1 + first_chr = parsed[0][0] + assert first_chr == "chr1" + + def test_recreates_index_if_tbi_is_invalid(self, temp_bed_file): + """Test that an invalid .tbi triggers rebuilding a valid tabix index.""" + gz_path = load_str_reference(temp_bed_file) + tbi_path = gz_path + ".tbi" + + assert Path(gz_path).exists() and is_valid_tabix(gz_path), ( + "BGZF-compressed and indexed file should exist" + ) + + # Record times so we can detect replacement if the same path is reused + old_gz = gz_path + old_tbi_mtime = os.path.getmtime(tbi_path) + + # Corrupt the index file + with open(tbi_path, "wb") as f: + f.write(b"NOT_A_REAL_TABIX_INDEX") + + # Ensure filesystem mtime changes (some FS have 1s resolution) + # time.sleep(1.1) + + # Now call load_str_reference on the gz itself + rebuilt_gz_path = load_str_reference(old_gz) + rebuilt_tbi_path = rebuilt_gz_path + ".tbi" + + assert Path(rebuilt_gz_path).exists() and is_valid_tabix(rebuilt_gz_path), ( + "BGZF-compressed and indexed file should exist" + ) + + new_tbi_mtime = os.path.getmtime(rebuilt_tbi_path) + assert new_tbi_mtime > old_tbi_mtime, "Expected .tbi to be recreated" class TestFindOverlappingSTR: - """Test suite for find_overlapping_str.""" + """Test suite for find_overlapping_str (tabix-backed).""" @pytest.fixture - def str_df(self): - """Create sample STR DataFrame.""" - data = { - "CHROM": ["chr1", "chr1", "chr2"], - "START": [101, 201, 301], - "END": [115, 212, 318], - "PERIOD": [3, 4, 3], - "RU": ["CAG", "ATCG", "GAT"], - "COUNT": [5.0, 3.0, 6.0], - } - return pd.DataFrame(data) - - def test_exact_overlap(self, str_df): + def tabix_panel(self): + """Create a small bgzip+tabix STR panel for overlap tests.""" + content = """chr1\t101\t115\t3\tCAG\t5 +chr1\t201\t212\t4\tATCG\t3 +chr2\t301\t318\t3\tGAT\t6 +""" + with tempfile.TemporaryDirectory() as tmp: + bed_path = Path(tmp) / "panel.bed" + bed_path.write_text(content, encoding="utf-8") + + gz_path = load_str_reference(str(bed_path)) + yield gz_path + + def test_exact_overlap(self, tabix_panel): """Test exact position overlap.""" - result = find_overlapping_str(str_df, "chr1", 101, 115) + result = find_overlapping_str(tabix_panel, "chr1", 101, 115) assert result is not None assert result["START"] == 101 + assert result["END"] == 115 assert result["RU"] == "CAG" - def test_partial_overlap(self, str_df): + def test_partial_overlap(self, tabix_panel): """Test partial overlap.""" - result = find_overlapping_str(str_df, "chr1", 105, 110) + result = find_overlapping_str(tabix_panel, "chr1", 105, 110) assert result is not None assert result["START"] == 101 + assert result["RU"] == "CAG" - def test_no_overlap(self, str_df): + def test_no_overlap(self, tabix_panel): """Test no overlap.""" - result = find_overlapping_str(str_df, "chr1", 150, 160) - + result = find_overlapping_str(tabix_panel, "chr1", 150, 160) assert result is None - def test_wrong_chromosome(self, str_df): + def test_wrong_chromosome(self, tabix_panel): """Test wrong chromosome.""" - result = find_overlapping_str(str_df, "chr3", 101, 115) - + result = find_overlapping_str(tabix_panel, "chr3", 101, 115) assert result is None - def test_variant_extends_beyond(self, str_df): - """Test variant extending beyond STR.""" - result = find_overlapping_str(str_df, "chr1", 110, 120) + def test_variant_extends_beyond(self, tabix_panel): + """Test variant extending beyond STR region.""" + result = find_overlapping_str(tabix_panel, "chr1", 110, 120) assert result is not None assert result["START"] == 101 + assert result["END"] == 115 class TestGetSTRAtPosition: - """Test suite for get_str_at_position.""" + """Test suite for get_str_at_position (tabix-backed).""" @pytest.fixture - def str_df(self): - """Create sample STR DataFrame.""" - data = { - "CHROM": ["chr1", "chr1", "chr2"], - "START": [101, 201, 301], - "END": [115, 212, 318], - "PERIOD": [3, 4, 3], - "RU": ["CAG", "ATCG", "GAT"], - "COUNT": [5.0, 3.0, 6.0], - } - return pd.DataFrame(data) - - def test_position_in_str(self, str_df): + def tabix_panel(self): + """Create a small bgzip+tabix STR panel for position tests.""" + content = """chr1\t101\t115\t3\tCAG\t5 +chr1\t201\t212\t4\tATCG\t3 +chr2\t301\t318\t3\tGAT\t6 +""" + with tempfile.TemporaryDirectory() as tmp: + bed_path = Path(tmp) / "panel.bed" + bed_path.write_text(content, encoding="utf-8") + gz_path = load_str_reference(str(bed_path)) + yield gz_path + + def test_position_in_str(self, tabix_panel): """Test position within STR.""" - result = get_str_at_position(str_df, "chr1", 105) + result = get_str_at_position(tabix_panel, "chr1", 105) assert result is not None assert result["RU"] == "CAG" - def test_position_outside_str(self, str_df): + def test_position_outside_str(self, tabix_panel): """Test position outside STR.""" - result = get_str_at_position(str_df, "chr1", 150) - + result = get_str_at_position(tabix_panel, "chr1", 101) assert result is None - def test_position_at_boundary(self, str_df): + def test_position_at_boundary(self, tabix_panel): """Test position at STR boundary.""" - result = get_str_at_position(str_df, "chr1", 101) + result = get_str_at_position(tabix_panel, "chr1", 102) assert result is not None assert result["RU"] == "CAG" diff --git a/tests/unit/test_vcf_processor.py b/tests/unit/test_vcf_processor.py index afa0466..b6e732c 100644 --- a/tests/unit/test_vcf_processor.py +++ b/tests/unit/test_vcf_processor.py @@ -2,19 +2,33 @@ import os import tempfile +from pathlib import Path -import pandas as pd import pysam import pytest from strvcf_annotator.core.vcf_processor import ( check_vcf_sorted, + estimate_ram_per_worker_bytes, generate_annotated_records, + get_available_ram_bytes, reset_and_sort_vcf, ) from strvcf_annotator.parsers.generic import GenericParser +@pytest.fixture +def vcf_paths(data_dir): + """Return absolute paths to all shipped VCF test files.""" + files = [ + "test.vcf.gz", + "pindel_header.vcf", + "mutec2_indel.vcf.gz", + "TCGA-DC-6682.vcf", + ] + return [os.path.abspath(os.path.join(data_dir, f)) for f in files] + + @pytest.fixture def basic_vcf_header(): """Create a basic VCF header for testing.""" @@ -114,18 +128,37 @@ def unsorted_vcf_file(basic_vcf_header): @pytest.fixture -def str_dataframe(): - """Create a sample STR DataFrame for testing.""" - return pd.DataFrame( - { - "CHROM": ["chr1", "chr1", "chr2"], - "START": [95, 195, 95], - "END": [115, 215, 115], - "PERIOD": [2, 3, 2], - "RU": ["AT", "CAG", "GC"], - "COUNT": [10, 7, 10], - } - ) +def str_panel_tabix(): + """Create a small tabix-indexed STR panel for tests. + + Returns + ------- + str + Path to bgzip-compressed, tabix-indexed STR panel. + """ + content = """chr1\t95\t115\t2\tAT\t10 +chr1\t195\t215\t3\tCAG\t7 +chr2\t95\t115\t2\tGC\t10 +""" + + with tempfile.TemporaryDirectory() as tmp: + bed_path = Path(tmp) / "str_panel.bed" + bed_path.write_text(content, encoding="utf-8") + + # Sort to guarantee tabix compatibility + sorted_path = Path(tmp) / "str_panel.sorted.bed" + lines = sorted( + [l.strip() for l in content.strip().splitlines()], + key=lambda x: (x.split("\t")[0], int(x.split("\t")[1])), + ) + sorted_path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + gz_path = Path(tmp) / "str_panel.sorted.bed.gz" + + pysam.tabix_compress(str(sorted_path), str(gz_path), force=True) + pysam.tabix_index(str(gz_path), preset="bed", force=True) + + yield str(gz_path) class TestCheckVCFSorted: @@ -315,12 +348,12 @@ def test_sorts_by_contig_order(self, basic_vcf_header): class TestGenerateAnnotatedRecords: """Tests for generate_annotated_records function.""" - def test_returns_iterator(self, sorted_vcf_file, str_dataframe): + def test_returns_iterator(self, sorted_vcf_file, str_panel_tabix): """Test that function returns an iterator.""" vcf_in = pysam.VariantFile(sorted_vcf_file) parser = GenericParser() - result = generate_annotated_records(vcf_in, str_dataframe, parser) + result = generate_annotated_records(vcf_in, str_panel_tabix, parser) # Should be an iterator assert hasattr(result, "__iter__") @@ -328,12 +361,12 @@ def test_returns_iterator(self, sorted_vcf_file, str_dataframe): vcf_in.close() - def test_yields_variant_records(self, sorted_vcf_file, str_dataframe): + def test_yields_variant_records(self, sorted_vcf_file, str_panel_tabix): """Test that iterator yields VariantRecord objects.""" vcf_in = pysam.VariantFile(sorted_vcf_file) parser = GenericParser() - records = list(generate_annotated_records(vcf_in, str_dataframe, parser)) + records = list(generate_annotated_records(vcf_in, str_panel_tabix, parser)) if len(records) > 0: assert all(isinstance(r, pysam.VariantRecord) for r in records) @@ -342,54 +375,54 @@ def test_yields_variant_records(self, sorted_vcf_file, str_dataframe): def test_filters_non_overlapping_variants(self, sorted_vcf_file): """Test that variants outside STR regions are filtered.""" + vcf_in = pysam.VariantFile(sorted_vcf_file) parser = GenericParser() - # Create STR dataframe with no overlap - str_df = pd.DataFrame( - { - "CHROM": ["chr10"], - "START": [1000], - "END": [1100], - "PERIOD": [2], - "RU": ["AT"], - "COUNT": [50], - } - ) + # Create a tabix-indexed STR panel with no overlap to the VCF (chr10 only) + content = "chr10\t1000\t1100\t2\tAT\t50\n" - records = list(generate_annotated_records(vcf_in, str_df, parser)) + with tempfile.TemporaryDirectory() as tmp: + bed_path = Path(tmp) / "no_overlap.bed" + bed_path.write_text(content, encoding="utf-8") - # Should yield no records (no overlap) - assert len(records) == 0 + gz_path = Path(tmp) / "no_overlap.bed.gz" + pysam.tabix_compress(str(bed_path), str(gz_path), force=True) + pysam.tabix_index(str(gz_path), preset="bed", force=True) + + records = list(generate_annotated_records(vcf_in, str(gz_path), parser)) + + # Should yield no records (no overlap) + assert len(records) == 0 vcf_in.close() - def test_handles_unsorted_vcf(self, unsorted_vcf_file, str_dataframe): + def test_handles_unsorted_vcf(self, unsorted_vcf_file, str_panel_tabix): """Test that function handles unsorted VCF by sorting it.""" vcf_in = pysam.VariantFile(unsorted_vcf_file) parser = GenericParser() # Should not raise an error - records = list(generate_annotated_records(vcf_in, str_dataframe, parser)) + records = list(generate_annotated_records(vcf_in, str_panel_tabix, parser)) # Records should be processed (may be 0 if no overlaps) assert isinstance(records, list) vcf_in.close() - def test_uses_generic_parser_by_default(self, sorted_vcf_file, str_dataframe): + def test_uses_generic_parser_by_default(self, sorted_vcf_file, str_panel_tabix): """Test that GenericParser is used when parser=None.""" vcf_in = pysam.VariantFile(sorted_vcf_file) # Call without parser - records = list(generate_annotated_records(vcf_in, str_dataframe, parser=None)) + records = list(generate_annotated_records(vcf_in, str_panel_tabix, parser=None)) # Should work without error assert isinstance(records, list) vcf_in.close() - def test_empty_vcf_returns_no_records(self, basic_vcf_header, str_dataframe): + def test_empty_vcf_returns_no_records(self, basic_vcf_header, str_panel_tabix): """Test that empty VCF yields no records.""" temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".vcf", delete=False) temp_file.close() @@ -402,7 +435,7 @@ def test_empty_vcf_returns_no_records(self, basic_vcf_header, str_dataframe): vcf_in = pysam.VariantFile(temp_file.name) parser = GenericParser() - records = list(generate_annotated_records(vcf_in, str_dataframe, parser)) + records = list(generate_annotated_records(vcf_in, str_panel_tabix, parser)) assert len(records) == 0 @@ -410,16 +443,107 @@ def test_empty_vcf_returns_no_records(self, basic_vcf_header, str_dataframe): finally: os.unlink(temp_file.name) - def test_empty_str_dataframe_returns_no_records(self, sorted_vcf_file): - """Test that empty STR DataFrame yields no records.""" - vcf_in = pysam.VariantFile(sorted_vcf_file) - parser = GenericParser() + def test_empty_str_panel_tabix_returns_no_records(self, sorted_vcf_file): + """Test that empty STR panel (no overlapping loci) yields no records.""" - # Empty STR dataframe - empty_str_df = pd.DataFrame(columns=["CHROM", "START", "END", "PERIOD", "RU", "COUNT"]) + # Create a panel with one dummy locus on a chromosome not present in VCF + content = "chrUn\t1\t10\t2\tAT\t5\n" - records = list(generate_annotated_records(vcf_in, empty_str_df, parser)) + with tempfile.TemporaryDirectory() as tmp: + bed_path = Path(tmp) / "empty_panel.bed" + bed_path.write_text(content, encoding="utf-8") - assert len(records) == 0 + gz_path = Path(tmp) / "empty_panel.bed.gz" + pysam.tabix_compress(str(bed_path), str(gz_path), force=True) + pysam.tabix_index(str(gz_path), preset="bed", force=True) - vcf_in.close() + vcf_in = pysam.VariantFile(sorted_vcf_file) + parser = GenericParser() + + records = list(generate_annotated_records(vcf_in, str(gz_path), parser)) + + assert len(records) == 0 + + vcf_in.close() + + +class TestGetAvailableRamBytes: + """Tests for get_available_ram_bytes function.""" + + def test_returns_int(self): + """Test that function returns an integer.""" + result = get_available_ram_bytes() + assert isinstance(result, int) + + # @TODO fix + # def test_fake_memory(self, monkeypatch): + # """Test that psutil branch returns fake available RAM.""" + # fake_psutil = types.ModuleType("psutil") + + # class FakeVMem: + # available = 123456789 + + # def virtual_memory(): + # return FakeVMem() + + # fake_psutil.virtual_memory = virtual_memory + + # real_import = builtins.__import__ + + # def import_hook(name, globals=None, locals=None, fromlist=(), level=0): + # if name == "psutil": + # return fake_psutil + # return real_import(name, globals, locals, fromlist, level) + + # monkeypatch.setattr(builtins, "__import__", import_hook) + + # result = get_available_ram_bytes() + # assert result == 123456789 + + +class TestEstimateRamPerWorkerBytes: + """Tests for estimate_ram_per_worker_bytes function.""" + + def test_empty_list_returns_minimum(self): + """Test that empty input returns a safe minimum estimate.""" + result = estimate_ram_per_worker_bytes([]) + assert result == int(1 * 1024**3) # 1 GB + + def test_uses_largest_file_size(self, vcf_paths): + """Test that estimate is based on the largest file size among inputs.""" + sizes = {p: os.path.getsize(p) for p in vcf_paths} + max_path = max(sizes, key=sizes.get) + max_size = sizes[max_path] + + fixed_overhead = 700 * 1024**2 + expansion_factor = 5 if max_path.endswith(".gz") else 2 + + expected = fixed_overhead + expansion_factor * max_size + expected = min(max(expected, 1 * 1024**3), 120 * 1024**3) + + result = estimate_ram_per_worker_bytes(vcf_paths) + assert result == int(expected) + + def test_single_file_plain_vcf(self, data_dir): + """Test estimate for a single plain VCF.""" + path = os.path.abspath(os.path.join(data_dir, "TCGA-DC-6682.vcf")) + size = os.path.getsize(path) + + fixed_overhead = 700 * 1024**2 + expected = fixed_overhead + 2 * size + expected = min(max(expected, 1 * 1024**3), 120 * 1024**3) + + result = estimate_ram_per_worker_bytes([path]) + assert result == int(expected) + + def test_single_file_gz_vcf(self, data_dir): + """Test estimate for a single gzipped VCF.""" + path = os.path.abspath(os.path.join(data_dir, "test.vcf.gz")) + size = os.path.getsize(path) + + fixed_overhead = 700 * 1024**2 + expected = fixed_overhead + 5 * size + expected = min(max(expected, 1 * 1024**3), 120 * 1024**3) + + result = estimate_ram_per_worker_bytes([path]) + assert result == int(expected) diff --git a/tests/unit/test_vcf_utils.py b/tests/unit/test_vcf_utils.py new file mode 100644 index 0000000..6373356 --- /dev/null +++ b/tests/unit/test_vcf_utils.py @@ -0,0 +1,170 @@ +"""Unit tests for VCF utility functions.""" + +import os +import tempfile + +import pysam +import pytest + +from strvcf_annotator.utils.vcf_utils import ( + chrom_to_order, + get_sample_by_index, + get_sample_by_name, + has_format_field, + normalize_info_fields, +) + + +@pytest.fixture +def basic_header(): + """Create a basic VCF header for tests.""" + header = pysam.VariantHeader() + header.add_line("##fileformat=VCFv4.2") + header.contigs.add("chr1", length=1000000) + + header.add_sample("TUMOR") + header.add_sample("NORMAL") + + # INFO schema (S1 is Number=1 so pysam disallows tuple assignment via API) + header.info.add("FLAG1", 0, "Flag", "Flag field") + header.info.add("S1", 1, "String", "Single string") + header.info.add("I1", 1, "Integer", "Single integer") + header.info.add("R1", "R", "Integer", "REF+ALT values") + + header.formats.add("GT", 1, "String", "Genotype") + header.formats.add("AD", "R", "Integer", "Allelic depths") + + return header + + +@pytest.fixture +def basic_record(basic_header): + """Create a VariantRecord containing 'weird' INFO encodings for normalize_info_fields tests. + + We write the header via pysam, then append a raw VCF line to bypass pysam's + strict schema checks during assignment. + """ + tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".vcf", delete=False) + tmp.close() + + try: + # This writes the header automatically + with pysam.VariantFile(tmp.name, "w", header=basic_header): + pass + + # Append a raw record line with intentionally odd INFO encodings: + # - S1=a,b even though header says Number=1 (String) + # - I1=10,11 even though header says Number=1 (Integer) + # - R1 has 4 values though Number=R should be 2 for REF+ALT + raw_line = ( + "chr1\t100\t.\tA\tT\t30\tPASS\t" + "FLAG1;S1=a,b;I1=10,11;R1=1,2,3,4;UNKNOWN=skip_me\t" + "GT:AD\t0/1:10,5\t0/0:12,0\n" + ) + + with open(tmp.name, "a", encoding="utf-8") as f: + f.write(raw_line) + + with pysam.VariantFile(tmp.name) as vcf_in: + rec = next(iter(vcf_in)) + return rec + + finally: + os.unlink(tmp.name) + +class TestChromToOrder: + """Tests for chrom_to_order.""" + + def test_none(self): + assert chrom_to_order(None) == 1_000_000 + + def test_autosomes_with_chr_prefix(self): + assert chrom_to_order("chr1") == 1 + assert chrom_to_order("chr2") == 2 + assert chrom_to_order("chr10") == 10 + + def test_autosomes_without_chr_prefix(self): + assert chrom_to_order("1") == 1 + assert chrom_to_order("22") == 22 + + def test_sex_chromosomes(self): + assert chrom_to_order("chrX") == 23 + assert chrom_to_order("X") == 23 + assert chrom_to_order("chrY") == 24 + assert chrom_to_order("Y") == 24 + + def test_mitochondrial(self): + assert chrom_to_order("chrM") == 25 + assert chrom_to_order("M") == 25 + assert chrom_to_order("chrMT") == 25 + assert chrom_to_order("MT") == 25 + + def test_other_contigs_go_last(self): + assert chrom_to_order("chrUn_gl000220") == 1_000_000 + assert chrom_to_order("GL000220.1") == 1_000_000 + + +class TestNormalizeInfoFields: + """Tests for normalize_info_fields.""" + + def test_skips_unknown_info_fields(self, basic_record, basic_header): + fixed = normalize_info_fields(basic_record, basic_header) + assert "UNKNOWN" not in fixed + + def test_flag_field_included_only_if_true(self, basic_record, basic_header): + fixed = normalize_info_fields(basic_record, basic_header) + assert fixed["FLAG1"] is True + + def test_string_number_one_tuple_joined(self, basic_record, basic_header): + fixed = normalize_info_fields(basic_record, basic_header) + assert fixed["S1"] == "a|b" + + def test_scalar_number_one_tuple_clipped(self, basic_record, basic_header): + fixed = normalize_info_fields(basic_record, basic_header) + assert fixed["I1"] == 10 + + def test_r_field_is_clipped_to_two(self, basic_record, basic_header): + fixed = normalize_info_fields(basic_record, basic_header) + assert fixed["R1"] == [1, 2] + + +class TestGetSampleByName: + """Tests for get_sample_by_name.""" + + def test_returns_sample(self, basic_record): + tumor = get_sample_by_name(basic_record, "TUMOR") + assert tumor is not None + assert tumor["GT"] == (0, 1) + + def test_raises_keyerror(self, basic_record): + with pytest.raises(KeyError): + get_sample_by_name(basic_record, "NOT_A_SAMPLE") + + +class TestGetSampleByIndex: + """Tests for get_sample_by_index.""" + + def test_returns_by_index(self, basic_record): + s0 = get_sample_by_index(basic_record, 0) + s1 = get_sample_by_index(basic_record, 1) + + assert s0 is not None + assert s1 is not None + # Order is the header sample order: TUMOR then NORMAL + assert s0["GT"] == (0, 1) + assert s1["GT"] == (0, 0) + + def test_raises_indexerror(self, basic_record): + with pytest.raises(IndexError): + get_sample_by_index(basic_record, 999) + + +class TestHasFormatField: + """Tests for has_format_field.""" + + def test_true_when_present(self, basic_record): + assert has_format_field(basic_record, "GT") is True + assert has_format_field(basic_record, "AD") is True + + def test_false_when_absent(self, basic_record): + assert has_format_field(basic_record, "DP") is False