diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..d841539 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,120 @@ +name: CI + +on: + push: + branches: [master, main] + pull_request: + branches: [master, main] + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ruff mypy + + - name: Run Ruff linter + run: ruff check . + + - name: Run Ruff formatter check + run: ruff format --check . + + - name: Run mypy + run: mypy ribotricer --ignore-missing-imports + continue-on-error: true # Remove this once type hints are added + + test: + name: Test Python ${{ matrix.python-version }} on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + include: + - os: macos-latest + python-version: "3.12" + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install system dependencies (Ubuntu) + if: runner.os == 'Linux' + run: | + sudo apt-get update + sudo apt-get install -y samtools + + - name: Install system dependencies (macOS) + if: runner.os == 'macOS' + run: | + brew install samtools + + - name: Install package + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + + - name: Run unit tests + run: | + pytest tests/ -v --cov=ribotricer --cov-report=xml + continue-on-error: true # Remove once unit tests are added + + - name: Run integration tests + run: | + bash ./run_test.sh + + - name: Upload coverage to Codecov + uses: 
codecov/codecov-action@v4
+        if: matrix.python-version == '3.12' && matrix.os == 'ubuntu-latest'
+        with:
+          files: ./coverage.xml
+          fail_ci_if_error: false
+
+  build:
+    name: Build package
+    runs-on: ubuntu-latest
+    needs: [lint, test]
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.12"
+
+      - name: Install build dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install build twine
+
+      - name: Build package
+        run: python -m build
+
+      - name: Check package
+        run: twine check dist/*
+
+      - name: Upload artifact
+        uses: actions/upload-artifact@v4
+        with:
+          name: dist
+          path: dist/
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
new file mode 100644
index 0000000..c1e16bc
--- /dev/null
+++ b/.github/workflows/publish.yml
@@ -0,0 +1,87 @@
+name: Publish to PyPI
+
+on:
+  release:
+    types: [published]
+  workflow_dispatch:
+    inputs:
+      test_pypi:
+        description: 'Publish to Test PyPI instead'
+        required: false
+        default: false
+        type: boolean
+
+jobs:
+  build:
+    name: Build distribution
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.12"
+
+      - name: Install build dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install build twine
+
+      - name: Build package
+        run: python -m build
+
+      - name: Check package
+        run: twine check dist/*
+
+      - name: Upload artifact
+        uses: actions/upload-artifact@v4
+        with:
+          name: dist
+          path: dist/
+
+  publish-test-pypi:
+    name: Publish to Test PyPI
+    needs: build
+    runs-on: ubuntu-latest
+    # The input is `type: boolean`, so compare against the boolean literal:
+    # a boolean never equals the string 'true' in Actions expressions.
+    if: github.event_name == 'workflow_dispatch' && inputs.test_pypi == true
+    environment:
+      name: test-pypi
+      url: https://test.pypi.org/p/ribotricer
+    permissions:
+      id-token: write
+
+    steps:
+      - name: Download artifact
+        uses: actions/download-artifact@v4
+        with:
+          name: dist
+          path: dist/
+
+      - name: Publish to Test PyPI
+        uses: pypa/gh-action-pypi-publish@release/v1
+        with:
+          repository-url: https://test.pypi.org/legacy/
+
+  publish-pypi:
+    name: Publish to PyPI
+    needs: build
+    runs-on: ubuntu-latest
+    if: github.event_name == 'release' || (github.event_name == 'workflow_dispatch' && inputs.test_pypi == false)
+    environment:
+      name: pypi
+      url: https://pypi.org/p/ribotricer
+    permissions:
+      id-token: write
+
+    steps:
+      - name: Download artifact
+        uses: actions/download-artifact@v4
+        with:
+          name: dist
+          path: dist/
+
+      - name: Publish to PyPI
+        uses: pypa/gh-action-pypi-publish@release/v1
diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml
deleted file mode 100644
index c8e4c00..0000000
--- a/.github/workflows/pythonpackage.yml
+++ /dev/null
@@ -1,34 +0,0 @@
-name: Python package
-
-on: [push]
-
-jobs:
-  build:
-
-    runs-on: ubuntu-latest
-    strategy:
-      matrix:
-        python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]
-
-    steps:
-    - uses: actions/checkout@v2
-    - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@v1
-      with:
-        python-version: ${{ matrix.python-version }}
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        pip install wheel
-        pip install -r requirements.txt
-    - name: Lint with flake8
-      run: |
-        pip install flake8
-        # stop the build if there are Python syntax errors or undefined names
-        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
-    - name: Test with pytest
-      run: |
-        pip install --editable .
- bash ./run_test.sh diff --git a/.gitignore b/.gitignore index 7c3f0c0..a985fc6 100755 --- a/.gitignore +++ b/.gitignore @@ -74,10 +74,33 @@ snakemake/*.stats *.sqlite.gz *.sra -#pytest +# pytest .pytest_cache/ utils/* snakemake/configs/test* tests/data/.hg38* tests/data/hg38.fa.fai record.txt + +# Type checking +.mypy_cache/ +.dmypy.json +dmypy.json + +# Ruff +.ruff_cache/ + +# Virtual environments +.venv/ +venv/ +ENV/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# Pre-commit +.pre-commit-cache/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..b8e018f --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,38 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + args: ['--maxkb=1000'] + - id: check-merge-conflict + - id: check-toml + - id: debug-statements + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.4 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + - id: ruff-format + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.9.0 + hooks: + - id: mypy + additional_dependencies: + - types-tqdm + args: [--ignore-missing-imports] + exclude: ^(tests/|docs/) + +ci: + autofix_commit_msg: | + [pre-commit.ci] auto fixes from pre-commit.com hooks + autofix_prs: true + autoupdate_branch: '' + autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate' + autoupdate_schedule: weekly + skip: [] + submodules: false diff --git a/HISTORY.md b/HISTORY.md index cbc3b39..ffc076b 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,11 @@ +# v1.5.0 (2025-02-14) + +- Modernized packaging with `pyproject.toml` (PEP 517/518 compliant) +- Dropped support for Python 3.7 and 3.8 (end of life) +- Added support for Python 3.13 +- Added type checking +- Default to sans-serif fonts for plotting + # v1.4.0 (2024-04-14) - Added `meta_min_reads` parameter to 
control minimum coverage for metagene plots ([#155](https://github.com/smithlabcode/ribotricer/pull/155)) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ef1850b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,218 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "ribotricer" +version = "1.5.0" +description = "Python package to detect translating ORFs from Ribo-seq data" +readme = "README.md" +license = {text = "GPLv3"} +authors = [ + {name = "Saket Choudhary", email = "saketkc@gmail.com"}, + {name = "Wenzheng Li"}, +] +maintainers = [ + {name = "Saket Choudhary", email = "saketkc@gmail.com"}, +] +requires-python = ">=3.9" +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Environment :: Console", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", + "Natural Language :: English", + "Operating System :: POSIX :: Linux", + "Operating System :: MacOS", + "Operating System :: Microsoft :: Windows", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Scientific/Engineering :: Bio-Informatics", + "Topic :: Utilities", + "Typing :: Typed", +] +keywords = ["bioinformatics", "ribo-seq", "ribosome-profiling", "ORF", "translation"] +dependencies = [ + "click>=8.1.3", + "click-help-colors>=0.9.1", + "matplotlib>=3.5.3", + "numpy>=1.21.1", + "pandas>=1.3", + "pyfaidx>=0.7.1", + "pysam>=0.19.1", + "quicksect>=0.2.2", + "scipy>=1.7.0", + "tqdm>=4.64.1", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0", + "pytest-cov>=4.0", + "black>=23.0", + "isort>=5.12", + "ruff>=0.1.0", + "mypy>=1.0", + "pre-commit>=3.0", + "build>=1.0", + "twine>=4.0", +] +docs = [ + "sphinx>=6.0", 
+ "sphinx-rtd-theme>=1.0", + "sphinx-click>=4.0", +] + +[project.scripts] +ribotricer = "ribotricer.cli:cli" + +[project.urls] +Homepage = "https://github.com/smithlabcode/ribotricer" +Documentation = "https://github.com/smithlabcode/ribotricer#readme" +Repository = "https://github.com/smithlabcode/ribotricer" +Issues = "https://github.com/smithlabcode/ribotricer/issues" +Changelog = "https://github.com/smithlabcode/ribotricer/blob/master/HISTORY.md" + +[tool.setuptools.packages.find] +where = ["."] +include = ["ribotricer*"] + +[tool.black] +line-length = 88 +target-version = ["py39", "py310", "py311", "py312", "py313"] +include = '\.pyi?$' +exclude = ''' +/( + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist + | docs +)/ +''' + +[tool.isort] +profile = "black" +line_length = 88 +known_first_party = ["ribotricer"] +skip = [".git", ".venv", "build", "dist"] + +[tool.ruff] +target-version = "py39" +line-length = 88 +exclude = [ + ".git", + ".mypy_cache", + ".ruff_cache", + ".venv", + "build", + "dist", + "docs", + "notebooks", +] + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # Pyflakes + "I", # isort + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "UP", # pyupgrade + "ARG", # flake8-unused-arguments + "SIM", # flake8-simplify +] +ignore = [ + "E501", # line too long (handled by black) + "B008", # do not perform function calls in argument defaults + "B007", # loop control variable not used within loop body + "B028", # no explicit stacklevel keyword argument found + "B904", # raise exceptions with raise ... 
from err
+    "C901",   # too complex
+    "ARG001", # unused function argument
+    "ARG002", # unused method argument
+    "SIM102", # use single if statement instead of nested if
+    "SIM108", # use ternary operator
+    "SIM113", # use enumerate for index variable
+    "SIM117", # use single with statement with multiple contexts
+]
+
+[tool.ruff.lint.per-file-ignores]
+"__init__.py" = ["F401"]
+"tests/*" = ["ARG001", "ARG002"]
+
+[tool.mypy]
+python_version = "3.9"
+warn_return_any = true
+warn_unused_configs = true
+ignore_missing_imports = true
+exclude = [
+    "build",
+    "dist",
+    "docs",
+]
+
+[tool.pytest.ini_options]
+testpaths = ["tests"]
+python_files = ["test_*.py"]
+python_functions = ["test_*"]
+addopts = [
+    "-v",
+    "--tb=short",
+    "--strict-markers",
+]
+markers = [
+    "slow: marks tests as slow (deselect with '-m \"not slow\"')",
+    "integration: marks tests as integration tests",
+]
+
+[tool.coverage.run]
+source = ["ribotricer"]
+branch = true
+omit = [
+    "*/tests/*",
+    "*/__init__.py",
+]
+
+[tool.coverage.report]
+exclude_lines = [
+    "pragma: no cover",
+    "def __repr__",
+    "raise AssertionError",
+    "raise NotImplementedError",
+    "if __name__ == .__main__.:",
+    "if TYPE_CHECKING:",
+]
+fail_under = 0
+show_missing = true
+
+[tool.bumpversion]
+current_version = "1.5.0"
+commit = true
+tag = false
+parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)(\\-(?P<release>[a-z]+)(?P<build>\\d+))?"
+serialize = [ + "{major}.{minor}.{patch}-{release}{build}", + "{major}.{minor}.{patch}", +] + +[[tool.bumpversion.files]] +filename = "pyproject.toml" +search = 'version = "{current_version}"' +replace = 'version = "{new_version}"' + +[[tool.bumpversion.files]] +filename = "ribotricer/__init__.py" +search = '__version__ = "{current_version}"' +replace = '__version__ = "{new_version}"' diff --git a/ribotricer/__init__.py b/ribotricer/__init__.py index 4b6e54e..0e15a5a 100644 --- a/ribotricer/__init__.py +++ b/ribotricer/__init__.py @@ -1,4 +1,11 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- +""" +ribotricer - Accurate detection of short and long active ORFs using Ribo-seq data. + +This package provides tools for detecting translating Open Reading Frames (ORFs) +from ribosome profiling (Ribo-seq) data. +""" + __author__ = "Saket Choudhary, Wenzheng Li" -__version__ = "1.4.0" +__version__ = "1.5.0" +__all__ = ["__version__", "__author__"] diff --git a/ribotricer/bam.py b/ribotricer/bam.py index cd35456..36425b2 100644 --- a/ribotricer/bam.py +++ b/ribotricer/bam.py @@ -1,8 +1,8 @@ -"""Utilities for spliting bam file""" +"""Utilities for splitting bam file""" # Part of ribotricer software # -# Copyright (C) 2020 Saket Choudhary, Wenzheng Li, and Andrew D Smith +# Copyright (C) 2020-2026 Saket Choudhary, Wenzheng Li, and Andrew D Smith # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -14,48 +14,58 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. 
-from .common import is_read_uniq_mapping -from collections import Counter -from collections import defaultdict +from __future__ import annotations + +from collections import Counter, defaultdict import pysam from tqdm.autonotebook import tqdm +from .common import is_read_uniq_mapping + tqdm.pandas() +# Type aliases for complex nested types +AlignmentDict = defaultdict[int, defaultdict[str, Counter[tuple[str, int]]]] +ReadLengthCounts = defaultdict[int, int] + -def split_bam(bam_path, protocol, prefix, read_lengths=None): - """Split bam by read length and strand +def split_bam( + bam_path: str, + protocol: str, + prefix: str, + read_lengths: list[int] | None = None, +) -> tuple[AlignmentDict, ReadLengthCounts]: + """Split bam by read length and strand. Parameters ---------- bam_path : str - Path to bam file - protocol: str - Experiment protocol [forward, reverse] - prefix: str - prefix for output files - read_lengths: list[int] - read lengths to use - If None, it will be automatically determined by assessing - the periodicity of metagene profile of this read length + Path to bam file. + protocol : str + Experiment protocol ['forward', 'reverse']. + prefix : str + Prefix for output files. + read_lengths : list[int] | None, optional + Read lengths to use. If None, it will be automatically determined + by assessing the periodicity of metagene profile of this read length. 
Returns ------- - alignments: dict(dict(Counter)) - bam split by length, strand, (chrom, pos) - read_length_counts: dict - key is the length, value is the number of reads + tuple[AlignmentDict, ReadLengthCounts] + - alignments: dict(dict(Counter)) - bam split by length, strand, (chrom, pos) + - read_length_counts: dict - key is the length, value is the number of reads """ - alignments = defaultdict(lambda: defaultdict(Counter)) - read_length_counts = defaultdict(int) + alignments: AlignmentDict = defaultdict(lambda: defaultdict(Counter)) + read_length_counts: ReadLengthCounts = defaultdict(int) total_count = qcfail = duplicate = secondary = unmapped = multi = valid = 0 - # print('reading bam file...') + # First pass just counts the reads # this is required to display a progress bar bam = pysam.AlignmentFile(bam_path, "rb") total_reads = bam.count(until_eof=True) bam.close() + with tqdm(total=total_reads, unit="reads", leave=False) as pbar: bam = pysam.AlignmentFile(bam_path, "rb") for read in bam.fetch(until_eof=True): @@ -83,10 +93,11 @@ def split_bam(bam_path, protocol, prefix, read_lengths=None): if is_usable: map_strand = "-" if read.is_reverse else "+" ref_positions = read.get_reference_positions() - strand = None - pos = None + strand: str | None = None + pos: int | None = None chrom = read.reference_name length = len(ref_positions) + if read_lengths is not None and length not in read_lengths: # Do nothing pass @@ -119,21 +130,24 @@ def split_bam(bam_path, protocol, prefix, read_lengths=None): # The 5'end is the first position pos = ref_positions[0] - # convert bam coordinate to one-based - alignments[length][strand][(chrom, pos + 1)] += 1 - read_length_counts[length] += 1 - valid += 1 + if strand is not None and pos is not None and chrom is not None: + # convert bam coordinate to one-based + alignments[length][strand][(chrom, pos + 1)] += 1 + read_length_counts[length] += 1 + valid += 1 + bam.close() + summary = ( - "summary:\n\ttotal_reads: 
{}\n\tunique_mapped: {}\n" - "\tqcfail: {}\n\tduplicate: {}\n\tsecondary: {}\n" - "\tunmapped:{}\n\tmulti:{}\n\nlength dist:\n" - ).format(total_count, valid, qcfail, duplicate, secondary, unmapped, multi) + f"summary:\n\ttotal_reads: {total_count}\n\tunique_mapped: {valid}\n" + f"\tqcfail: {qcfail}\n\tduplicate: {duplicate}\n\tsecondary: {secondary}\n" + f"\tunmapped:{unmapped}\n\tmulti:{multi}\n\nlength dist:\n" + ) for length in sorted(read_length_counts): - summary += "\t{}: {}\n".format(length, read_length_counts[length]) + summary += f"\t{length}: {read_length_counts[length]}\n" - with open("{}_bam_summary.txt".format(prefix), "w") as output: + with open(f"{prefix}_bam_summary.txt", "w") as output: output.write(summary) return (alignments, read_length_counts) diff --git a/ribotricer/cli.py b/ribotricer/cli.py index eecbcd9..bd39be6 100644 --- a/ribotricer/cli.py +++ b/ribotricer/cli.py @@ -2,7 +2,7 @@ # Part of ribotricer software # -# Copyright (C) 2020 Saket Choudhary, Wenzheng Li, and Andrew D Smith +# Copyright (C) 2020-2026 Saket Choudhary, Wenzheng Li, and Andrew D Smith # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -14,38 +14,39 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -import click +from __future__ import annotations + import os import sys +from typing import Any + +import click +from click_help_colors import HelpColorsGroup from . 
import __version__ from .common import _clean_input -from .const import CUTOFF -from .const import MINIMUM_VALID_CODONS -from .const import MINIMUM_VALID_CODONS_RATIO -from .const import MINIMUM_READS_PER_CODON -from .const import MINIMUM_DENSITY_OVER_ORF -from .const import META_MIN_READS - -from .count_orfs import count_orfs -from .count_orfs import count_orfs_codon +from .const import ( + CUTOFF, + META_MIN_READS, + MINIMUM_DENSITY_OVER_ORF, + MINIMUM_READS_PER_CODON, + MINIMUM_VALID_CODONS, + MINIMUM_VALID_CODONS_RATIO, +) +from .count_orfs import count_orfs, count_orfs_codon from .detect_orfs import detect_orfs -from .learn_cutoff import determine_cutoff_bam -from .learn_cutoff import determine_cutoff_tsv - +from .learn_cutoff import determine_cutoff_bam, determine_cutoff_tsv from .orf_seq import orf_seq from .prepare_orfs import prepare_orfs -from click_help_colors import HelpColorsGroup - -CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"]) +CONTEXT_SETTINGS: dict[str, Any] = {"help_option_names": ["-h", "--help"]} @click.group( cls=HelpColorsGroup, help_headers_color="yellow", help_options_color="green" ) @click.version_option(version=__version__) -def cli(): +def cli() -> None: """ribotricer: Tool for detecting translating ORF from Ribo-seq data""" pass @@ -84,8 +85,15 @@ def cli(): is_flag=True, ) def prepare_orfs_cmd( - gtf, fasta, prefix, min_orf_length, start_codons, stop_codons, longest -): + gtf: str, + fasta: str, + prefix: str, + min_orf_length: int, + start_codons: str, + stop_codons: str, + longest: bool, +) -> None: + """Prepare ORFs command handler.""" if not os.path.isfile(gtf): sys.exit("Error: GTF file not found") @@ -95,20 +103,24 @@ def prepare_orfs_cmd( if min_orf_length <= 0: sys.exit("Error: min ORF length at least to be 1") - start_codons = set([x.strip().upper() for x in start_codons.strip().split(",")]) - if not start_codons: + start_codons_set = {x.strip().upper() for x in start_codons.strip().split(",")} + if not 
start_codons_set: sys.exit("Error: start codons cannot be empty") - if not all([len(x) == 3 and set(x) <= {"A", "C", "G", "T"} for x in start_codons]): + if not all( + len(x) == 3 and set(x) <= {"A", "C", "G", "T"} for x in start_codons_set + ): sys.exit("Error: invalid codon, only A, C, G, T allowed") - stop_codons = set([x.strip().upper() for x in stop_codons.strip().split(",")]) - if not stop_codons: + stop_codons_set = {x.strip().upper() for x in stop_codons.strip().split(",")} + if not stop_codons_set: sys.exit("Error: stop codons cannot be empty") - if not all([len(x) == 3 and set(x) <= {"A", "C", "G", "T"} for x in stop_codons]): + if not all(len(x) == 3 and set(x) <= {"A", "C", "G", "T"} for x in stop_codons_set): sys.exit("Error: invalid codon, only A, C, G, T allowed") - print("Using start codons: {}".format(",".join(start_codons))) - prepare_orfs(gtf, fasta, prefix, min_orf_length, start_codons, stop_codons, longest) + print("Using start codons: {}".format(",".join(start_codons_set))) + prepare_orfs( + gtf, fasta, prefix, min_orf_length, start_codons_set, stop_codons_set, longest + ) # detect-orfs function ######################################### @@ -195,7 +207,7 @@ def prepare_orfs_cmd( ) @click.option( "--report_all", - help=("Whether output all ORFs including those " "non-translating ones"), + help=("Whether output all ORFs including those non-translating ones"), is_flag=True, ) @click.option( @@ -206,63 +218,74 @@ def prepare_orfs_cmd( help="Minimum number of reads for a read length to be considered", ) def detect_orfs_cmd( - bam, - ribotricer_index, - prefix, - stranded, - read_lengths, - psite_offsets, - phase_score_cutoff, - min_valid_codons, - min_reads_per_codon, - min_valid_codons_ratio, - min_read_density, - report_all, - meta_min_reads, -): + bam: str, + ribotricer_index: str, + prefix: str, + stranded: str | None, + read_lengths: str | None, + psite_offsets: str | None, + phase_score_cutoff: float, + min_valid_codons: int, + 
min_reads_per_codon: int, + min_valid_codons_ratio: float, + min_read_density: float, + report_all: bool, + meta_min_reads: int, +) -> None: + """Detect ORFs command handler.""" if not os.path.isfile(bam): sys.exit("Error: BAM file not found") if not os.path.isfile(ribotricer_index): sys.exit("Error: ribotricer index file not found") + read_lengths_list: list[int] | None = None + psite_offsets_dict: dict[int, int] | None = None + if read_lengths is not None: try: - read_lengths = [int(x.strip()) for x in read_lengths.strip().split(",")] + read_lengths_list = [ + int(x.strip()) for x in read_lengths.strip().split(",") + ] except Exception: sys.exit("Error: cannot convert read_lengths into integers") - if not all([x > 0 for x in read_lengths]): + if not all(x > 0 for x in read_lengths_list): sys.exit("Error: read length must be positive") - if read_lengths is None and psite_offsets is not None: + if read_lengths_list is None and psite_offsets is not None: sys.exit("Error: psite_offsets only allowed when read_lengths is provided") - if read_lengths is not None and psite_offsets is not None: + if read_lengths_list is not None and psite_offsets is not None: try: - psite_offsets = [int(x.strip()) for x in psite_offsets.strip().split(",")] + psite_offsets_list = [ + int(x.strip()) for x in psite_offsets.strip().split(",") + ] except Exception: sys.exit("Error: cannot convert psite_offsets into integers") - if len(read_lengths) != len(psite_offsets): + if len(read_lengths_list) != len(psite_offsets_list): sys.exit("Error: psite_offsets must match read_lengths") - if not all(x >= 0 for x in psite_offsets): + if not all(x >= 0 for x in psite_offsets_list): sys.exit("Error: P-site offset must be >= 0") - if not all(x > y for (x, y) in zip(read_lengths, psite_offsets)): + if not all(x > y for (x, y) in zip(read_lengths_list, psite_offsets_list)): sys.exit("Error: P-site offset must be smaller than read length") - psite_offsets = dict(list(zip(read_lengths, psite_offsets))) + 
psite_offsets_dict = dict(list(zip(read_lengths_list, psite_offsets_list))) + if stranded == "yes": stranded = "forward" + detect_orfs( bam, ribotricer_index, prefix, stranded, - read_lengths, - psite_offsets, + read_lengths_list, + psite_offsets_dict, phase_score_cutoff, min_valid_codons, min_reads_per_codon, min_valid_codons_ratio, min_read_density, report_all, + meta_min_reads, ) @@ -292,20 +315,26 @@ def detect_orfs_cmd( @click.option("--out", help="Path to output file", required=True) @click.option( "--report_all", - help=("Whether output all ORFs including those " "non-translating ones"), + help=("Whether output all ORFs including those non-translating ones"), is_flag=True, ) -def count_orfs_cmd(ribotricer_index, detected_orfs, features, out, report_all): - +def count_orfs_cmd( + ribotricer_index: str, + detected_orfs: str, + features: str, + out: str, + report_all: bool, +) -> None: + """Count ORFs command handler.""" if not os.path.isfile(ribotricer_index): sys.exit("Error: ribotricer index file not found") if not os.path.isfile(detected_orfs): sys.exit("Error: detected orfs file not found") - features = set(x.strip() for x in features.strip().split(",")) + features_set = {x.strip() for x in features.strip().split(",")} - count_orfs(ribotricer_index, detected_orfs, features, out, report_all) + count_orfs(ribotricer_index, detected_orfs, features_set, out, report_all) # count-orfs-codon function ######################################### @@ -335,18 +364,18 @@ def count_orfs_cmd(ribotricer_index, detected_orfs, features, out, report_all): @click.option("--prefix", help="Prefix for output files", required=True) @click.option( "--report_all", - help=("Whether output all ORFs including those " "non-translating ones"), + help=("Whether output all ORFs including those non-translating ones"), is_flag=True, ) def count_orfs_codon_cmd( - ribotricer_index, - detected_orfs, - features, - ribotricer_index_fasta, - prefix, - report_all, -): - + ribotricer_index: str, + 
detected_orfs: str, + features: str, + ribotricer_index_fasta: str, + prefix: str, + report_all: bool, +) -> None: + """Count ORFs at codon level command handler.""" if not os.path.isfile(ribotricer_index): sys.exit("Error: ribotricer index file not found") @@ -356,12 +385,12 @@ def count_orfs_codon_cmd( if not os.path.isfile(ribotricer_index_fasta): sys.exit("Error: ribotricer_index_fasta file not found") - features = set(x.strip() for x in features.strip().split(",")) + features_set = {x.strip() for x in features.strip().split(",")} count_orfs_codon( ribotricer_index, detected_orfs, - features, + features_set, ribotricer_index_fasta, prefix, report_all, @@ -387,7 +416,13 @@ def count_orfs_codon_cmd( "--protein", help="Output protein sequence instead of nucleotide", is_flag=True ) @click.option("--saveto", help="Path to output file", required=True) -def orf_seq_cmd(ribotricer_index, fasta, saveto, protein): +def orf_seq_cmd( + ribotricer_index: str, + fasta: str, + saveto: str, + protein: bool, +) -> None: + """Generate ORF sequences command handler.""" if not os.path.isfile(ribotricer_index): sys.exit("Error: ribotricer index file not found") @@ -457,23 +492,23 @@ def orf_seq_cmd(ribotricer_index, fasta, saveto, protein): help="Number of bootstraps", ) def determine_cutoff_cmd( - ribo_bams, - rna_bams, - ribo_tsvs, - rna_tsvs, - ribotricer_index, - prefix, - filter_by_tx_annotation, - phase_score_cutoff, - min_valid_codons, - sampling_ratio, - n_bootstraps, -): - + ribo_bams: str | None, + rna_bams: str | None, + ribo_tsvs: str | None, + rna_tsvs: str | None, + ribotricer_index: str | None, + prefix: str | None, + filter_by_tx_annotation: str, + phase_score_cutoff: float, + min_valid_codons: int, + sampling_ratio: float, + n_bootstraps: int, +) -> None: + """Learn cutoff command handler.""" filter_by = _clean_input(filter_by_tx_annotation) - ribo_stranded_protocols = [] - rna_stranded_protocols = [] + ribo_stranded_protocols: list[str | None] = [] + 
rna_stranded_protocols: list[str | None] = [] if ribo_bams and ribo_tsvs: sys.exit("Error: --ribo-bams and --rna_bams cannot be specified together") @@ -484,25 +519,30 @@ def determine_cutoff_cmd( if (ribo_bams and rna_tsvs) or (rna_bams and ribo_tsvs): sys.exit("Error: BAM and TSV inputs cannot be specified together") - if ribotricer_index: - if not os.path.isfile(ribotricer_index): - sys.exit("Error: ribotricer index file not found") + if ribotricer_index and not os.path.isfile(ribotricer_index): + sys.exit("Error: ribotricer index file not found") + + ribo_bams_list: list[str] = [] + rna_bams_list: list[str] = [] + ribo_tsvs_list: list[str] = [] + rna_tsvs_list: list[str] = [] + if ribo_bams: - ribo_bams = _clean_input(ribo_bams) - rna_bams = _clean_input(rna_bams) + ribo_bams_list = _clean_input(ribo_bams) + rna_bams_list = _clean_input(rna_bams) if rna_bams else [] else: - ribo_tsvs = _clean_input(ribo_tsvs) - rna_tsvs = _clean_input(rna_tsvs) + ribo_tsvs_list = _clean_input(ribo_tsvs) if ribo_tsvs else [] + rna_tsvs_list = _clean_input(rna_tsvs) if rna_tsvs else [] - if ribo_bams and rna_bams: + if ribo_bams_list and rna_bams_list: if not prefix: sys.exit("Error: --prefix required with BAM inputs") elif not ribotricer_index: sys.exit("Error: --ribotricer_index required with BAM inputs") else: determine_cutoff_bam( - ribo_bams, - rna_bams, + ribo_bams_list, + rna_bams_list, ribotricer_index, prefix, ribo_stranded_protocols, @@ -516,5 +556,5 @@ def determine_cutoff_cmd( ) else: determine_cutoff_tsv( - ribo_tsvs, rna_tsvs, filter_by, sampling_ratio, n_bootstraps + ribo_tsvs_list, rna_tsvs_list, filter_by, sampling_ratio, n_bootstraps ) diff --git a/ribotricer/common.py b/ribotricer/common.py index 6de517c..4599b61 100644 --- a/ribotricer/common.py +++ b/ribotricer/common.py @@ -2,7 +2,7 @@ # Part of ribotricer software # -# Copyright (C) 2020 Saket Choudhary, Wenzheng Li, and Andrew D Smith +# Copyright (C) 2020-2026 Saket Choudhary, Wenzheng Li, and Andrew D 
Smith # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -14,23 +14,37 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. +from __future__ import annotations + import ntpath import pathlib import sys +from typing import TYPE_CHECKING + from .interval import Interval +if TYPE_CHECKING: + import pysam + # Source: https://broadinstitute.github.io/picard/explain-flags.html -__SAM_NOT_UNIQ_FLAGS__ = [4, 20, 256, 272, 2048] +__SAM_NOT_UNIQ_FLAGS__: list[int] = [4, 20, 256, 272, 2048] -def is_read_uniq_mapping(read): +def is_read_uniq_mapping(read: pysam.AlignedSegment) -> bool | None: """Check if read is uniquely mappable. Parameters ---------- - read : pysam.Alignment.fetch object + read : pysam.AlignedSegment + A pysam alignment object. + Returns + ------- + bool | None + True if uniquely mapping, False if not, None if unable to determine. + Notes + ----- Most reliable: ['NH'] tag """ # Filter out secondary alignments @@ -44,10 +58,7 @@ def is_read_uniq_mapping(read): # Reliable in case of STAR if read.mapping_quality == 255: return True - elif read.mapping_quality < 1: - return False - # NH tag not set so rely on flags - elif read.flag in __SAM_NOT_UNIQ_FLAGS__: + elif read.mapping_quality < 1 or read.flag in __SAM_NOT_UNIQ_FLAGS__: return False else: sys.stdout.write( @@ -55,22 +66,24 @@ def is_read_uniq_mapping(read): "determining multimapping status. All the reads will be " "treated as uniquely mapping\n" ) + return None -def merge_intervals(intervals): - """ +def merge_intervals(intervals: list[Interval]) -> list[Interval]: + """Merge overlapping intervals. + Parameters ---------- - intervals: List[Interval] + intervals : list[Interval] + List of intervals to merge. Returns ------- - merged_intervals: List[Interval] - sorted and merged intervals + list[Interval] + Sorted and merged intervals. 
""" - intervals = sorted(intervals, key=lambda x: x.start) - merged_intervals = [] + merged_intervals: list[Interval] = [] i = 0 while i < len(intervals): to_merge = Interval( @@ -87,44 +100,79 @@ def merge_intervals(intervals): return merged_intervals -def mkdir_p(path): +def mkdir_p(path: str) -> None: """Make directory even if it exists. Parameters ---------- - path: str + path : str + Path to directory to create. """ pathlib.Path(path).mkdir(parents=True, exist_ok=True) -def path_leaf(path): - """Get path's tail from a filepath""" +def path_leaf(path: str) -> str: + """Get path's tail from a filepath. + + Parameters + ---------- + path : str + File path. + + Returns + ------- + str + The tail (filename) portion of the path. + """ head, tail = ntpath.split(path) return tail or ntpath.basename(head) -def parent_dir(path): - """Get path's tail from a filepath""" +def parent_dir(path: str) -> str: + """Get path's parent directory from a filepath. + + Parameters + ---------- + path : str + File path. + + Returns + ------- + str + The parent directory portion of the path. + """ head, tail = ntpath.split(path) return head -def _clean_input(comma_string): - """Clean comma separated option inputs in CLI""" - return list(map(lambda term: term.strip(" "), comma_string.split(","))) +def _clean_input(comma_string: str) -> list[str]: + """Clean comma separated option inputs in CLI. + Parameters + ---------- + comma_string : str + Comma-separated string of values. -def collapse_coverage_to_codon(coverage): + Returns + ------- + list[str] + List of stripped string values. + """ + return [term.strip(" ") for term in comma_string.split(",")] + + +def collapse_coverage_to_codon(coverage: list[int]) -> list[int]: """Collapse nucleotide level coverage to codon level. Parameters ---------- - coverage: list - Nucleotide level counts + coverage : list[int] + Nucleotide level counts. 
+ Returns ------- - codon_coverage: list - Coverage collapsed to codon level + list[int] + Coverage collapsed to codon level. """ codon_coverage = [ sum(coverage[current : current + 3]) for current in range(0, len(coverage), 3) diff --git a/ribotricer/const.py b/ribotricer/const.py index 8d68402..e5336fc 100644 --- a/ribotricer/const.py +++ b/ribotricer/const.py @@ -2,7 +2,7 @@ # Part of ribotricer software # -# Copyright (C) 2020 Saket Choudhary, Wenzheng Li, and Andrew D Smith +# Copyright (C) 2020-2026 Saket Choudhary, Wenzheng Li, and Andrew D Smith # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -14,21 +14,29 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# ribotricer default cutoff for lavbeling ORFs 'translating' -CUTOFF = 0.428571428571 +from typing import Final + +# ribotricer default cutoff for labeling ORFs 'translating' +CUTOFF: Final[float] = 0.428571428571 + # p-site offset -TYPICAL_OFFSET = 12 +TYPICAL_OFFSET: Final[int] = 12 + # minimum number of valid codons required in an ORF to label # it 'translating' -MINIMUM_VALID_CODONS = 5 +MINIMUM_VALID_CODONS: Final[int] = 5 + # minimum number of reads required per codon in an ORF to label # it 'translating' # default: 0 (decided by CUTOFF and MINIMUM_VALID_CODONS) -MINIMUM_READS_PER_CODON = 0 +MINIMUM_READS_PER_CODON: Final[int] = 0 + # fraction of codons with non zero reads -MINIMUM_VALID_CODONS_RATIO = 0 +MINIMUM_VALID_CODONS_RATIO: Final[float] = 0 + # Minimum read density over ORF # defined as the number of reads per unit length of the ORF -MINIMUM_DENSITY_OVER_ORF = 0.0 +MINIMUM_DENSITY_OVER_ORF: Final[float] = 0.0 + # Minimum number of reads for a read length to be considered -META_MIN_READS = 100000 +META_MIN_READS: Final[int] = 100000 diff --git a/ribotricer/count_orfs.py b/ribotricer/count_orfs.py index 487beb0..dfc37d3 100644 --- 
a/ribotricer/count_orfs.py +++ b/ribotricer/count_orfs.py @@ -2,7 +2,7 @@ # Part of ribotricer software # -# Copyright (C) 2020 Saket Choudhary, Wenzheng Li, and Andrew D Smith +# Copyright (C) 2020-2026 Saket Choudhary, Wenzheng Li, and Andrew D Smith # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -14,39 +14,49 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. +from __future__ import annotations + from collections import defaultdict from textwrap import wrap -from .orf import ORF import numpy as np import pandas as pd +from .orf import ORF + + +def count_orfs( + ribotricer_index: str, + detected_orfs: str, + features: set[str], + outfile: str, + report_all: bool = False, +) -> None: + """Count ORFs from ribotricer output. -def count_orfs(ribotricer_index, detected_orfs, features, outfile, report_all=False): - """ Parameters ---------- - ribotricer_index: str - Path to the index file generated by ribotricer prepare_orfs - detected_orfs: str - Path to the detected orfs file generated by ribotricer detect_orfs - features: set - set of ORF types, such as {annotated} - prefix: str - prefix for output file - report_all: bool - if True, all coverages will be exported + ribotricer_index : str + Path to the index file generated by ribotricer prepare_orfs. + detected_orfs : str + Path to the detected ORFs file generated by ribotricer detect_orfs. + features : set[str] + Set of ORF types, such as {'annotated'}. + outfile : str + Path to output file. + report_all : bool, optional + If True, all coverages will be exported, by default False. 
""" orf_index = {} read_counts = defaultdict(dict) - with open(ribotricer_index, "r") as fin: + with open(ribotricer_index) as fin: # Skip header fin.readline() for line in fin: orf = ORF.from_string(line) if orf.category in features: orf_index[orf.oid] = orf - with open(detected_orfs, "r") as fin: + with open(detected_orfs) as fin: # Skip header fin.readline() for line in fin: @@ -62,7 +72,7 @@ def count_orfs(ribotricer_index, detected_orfs, features, outfile, report_all=Fa if strand == "-": coor = coor[::-1] profile_stripped = profile.strip()[1:-1].split(", ") - profile = list() + profile = [] if profile_stripped[0]: profile = list(map(int, profile_stripped)) for pos, cov in zip(coor, profile): @@ -76,46 +86,45 @@ def count_orfs(ribotricer_index, detected_orfs, features, outfile, report_all=Fa values = read_counts[gene_id, gene_name].values() length = len(values) total = sum(values) - fout.write("{}\t{}\t{}\n".format(gene_id, total, length)) + fout.write(f"{gene_id}\t{total}\t{length}\n") def count_orfs_codon( - ribotricer_index, - detected_orfs, - features, - ribotricer_index_fasta, - prefix, - report_all=False, -): - """ - Create genewise codon summaries + ribotricer_index: str, + detected_orfs: str, + features: set[str], + ribotricer_index_fasta: str, + prefix: str, + report_all: bool = False, +) -> None: + """Create genewise codon summaries. Parameters ---------- - ribotricer_index: str - Path to the index file generated by ribotricer prepare_orfs - detected_orfs: str - Path to the detected orfs file generated by ribotricer detect_orfs - features: set - set of ORF types, such as {annotated} - ribotricer_index_fasta: str - Path to fasta index generated using orf-seq - prefix: str - path to output file - report_all: bool - if True, all coverages will be exported + ribotricer_index : str + Path to the index file generated by ribotricer prepare_orfs. + detected_orfs : str + Path to the detected ORFs file generated by ribotricer detect_orfs. 
+ features : set[str] + Set of ORF types, such as {'annotated'}. + ribotricer_index_fasta : str + Path to FASTA index generated using orf-seq. + prefix : str + Path to output file. + report_all : bool, optional + If True, all coverages will be exported, by default False. """ orf_index = {} fasta_df = pd.read_csv(ribotricer_index_fasta, sep="\t").set_index("ORF_ID") read_counts = defaultdict(dict) - with open(ribotricer_index, "r") as fin: + with open(ribotricer_index) as fin: # Skip header fin.readline() for line in fin: orf = ORF.from_string(line) if orf.category in features: orf_index[orf.oid] = orf - with open(detected_orfs, "r") as fin: + with open(detected_orfs) as fin: # Skip header fin.readline() for line in fin: @@ -134,7 +143,7 @@ def count_orfs_codon( if strand == "-": coor = coor[::-1] profile_stripped = profile.strip()[1:-1].split(", ") - profile = list() + profile = [] if profile_stripped[0]: profile = list(map(int, profile_stripped)) # IMP: Skip profiles that are not 3n long to avoid errors @@ -147,7 +156,7 @@ def count_orfs_codon( ).tolist() assert sum(codon_profile) == sum(profile) codon_seq = str(fasta_df.loc[oid].sequence) - if not len(codon_seq) % 3 == 0: + if len(codon_seq) % 3 != 0: print(oid, len(codon_seq)) codon_seq_partitioned = wrap(codon_seq, 3) for pos, cov, codon_seq in zip( @@ -157,7 +166,7 @@ def count_orfs_codon( read_counts[gene_id, codon_seq][pos] = cov # Output count table - with open("{}_genewise.tsv".format(prefix), "w") as fout: + with open(f"{prefix}_genewise.tsv", "w") as fout: fout.write( "\t".join( "gene_id", @@ -179,18 +188,9 @@ def count_orfs_codon( median_codon_coverage = np.median(values) var_codon_coverage = np.var(values) fout.write( - "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n".format( - gene_id, - codon_seq, - values, - mean_codon_coverage, - median_codon_coverage, - var_codon_coverage, - codon_occurences, - total_codon_coverage, - ) + 
f"{gene_id}\t{codon_seq}\t{values}\t{mean_codon_coverage}\t{median_codon_coverage}\t{var_codon_coverage}\t{codon_occurences}\t{total_codon_coverage}\n" ) - fout_df = pd.read_csv("{}_genewise.tsv".format(prefix), sep="\t") + fout_df = pd.read_csv(f"{prefix}_genewise.tsv", sep="\t") fout_df["per_codon_enrichment(total/n_occur)"] = ( fout_df["total_codon_coverage"] / fout_df["codon_occurences"] ) @@ -199,7 +199,7 @@ def count_orfs_codon( / fout_df.groupby("gene_id")["total_codon_coverage"].transform("sum") ) # Overwrite - fout_df.to_csv("{}_genewise.tsv".format(prefix), sep="\t", index=False, header=True) + fout_df.to_csv(f"{prefix}_genewise.tsv", sep="\t", index=False, header=True) # Remove infs fout_df = fout_df.replace([np.inf, -np.inf], np.nan) fout_df = fout_df.dropna() @@ -229,5 +229,5 @@ def count_orfs_codon( relative_enrichment = relative_enrichment.reset_index() relative_enrichment.to_csv( - "{}_codonwise.tsv".format(prefix), sep="\t", index=False, header=True + f"{prefix}_codonwise.tsv", sep="\t", index=False, header=True ) diff --git a/ribotricer/detect_orfs.py b/ribotricer/detect_orfs.py index 9ccca65..9155149 100644 --- a/ribotricer/detect_orfs.py +++ b/ribotricer/detect_orfs.py @@ -2,7 +2,7 @@ # Part of ribotricer software # -# Copyright (C) 2020 Saket Choudhary, Wenzheng Li, and Andrew D Smith +# Copyright (C) 2020-2026 Saket Choudhary, Wenzheng Li, and Andrew D Smith # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -14,55 +14,62 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. 
-from .statistics import phasescore -from .plotting import plot_metagene -from .plotting import plot_read_lengths -from .orf import ORF -from .metagene import align_metagenes -from .metagene import metagene_coverage -from .infer_protocol import infer_protocol -from .const import MINIMUM_DENSITY_OVER_ORF -from .const import MINIMUM_READS_PER_CODON -from .const import MINIMUM_VALID_CODONS_RATIO -from .const import MINIMUM_VALID_CODONS -from .const import CUTOFF -from .common import parent_dir -from .common import mkdir_p -from .common import collapse_coverage_to_codon -from .bam import split_bam -from quicksect import Interval, IntervalTree -from collections import Counter -from collections import defaultdict +from __future__ import annotations + import datetime +from collections import Counter, defaultdict +from typing import Final import numpy as np +from quicksect import Interval, IntervalTree from tqdm.autonotebook import tqdm +from .bam import AlignmentDict, split_bam +from .common import collapse_coverage_to_codon, mkdir_p, parent_dir +from .const import ( + CUTOFF, + MINIMUM_DENSITY_OVER_ORF, + MINIMUM_READS_PER_CODON, + MINIMUM_VALID_CODONS, + MINIMUM_VALID_CODONS_RATIO, +) +from .infer_protocol import infer_protocol +from .metagene import align_metagenes, metagene_coverage +from .orf import ORF +from .plotting import plot_metagene, plot_read_lengths +from .statistics import phasescore + tqdm.pandas() # Required for IntervalTree -STRAND_TO_NUM = {"+": 1, "-": -1} +STRAND_TO_NUM: Final[dict[str, int]] = {"+": 1, "-": -1} +# Type aliases +MergedAlignments = defaultdict[str, Counter[tuple[str, int]]] +PsiteOffsets = dict[int, int] +RefSeq = defaultdict[str, IntervalTree] -def merge_read_lengths(alignments, psite_offsets): - """ - Merge read counts for different read lengths after - applying appropriate offset(s). 
+ +def merge_read_lengths( + alignments: AlignmentDict, + psite_offsets: PsiteOffsets, +) -> MergedAlignments: + """Merge read counts for different read lengths after applying appropriate offset(s). Parameters ---------- - alignments: dict(dict(Counter)) - bam split by length, strand - psite_offsets: dict - key is the length, value is the offset + alignments : AlignmentDict + BAM split by length, strand. + psite_offsets : PsiteOffsets + Key is the length, value is the offset. + Returns ------- - merged_alignments: dict(dict) - alignments by merging all lengths + MergedAlignments + Alignments by merging all lengths. """ - # print('merging different lengths...') - merged_alignments = defaultdict(Counter) + merged_alignments: MergedAlignments = defaultdict(Counter) for length, offset in list(psite_offsets.items()): for strand in alignments[length]: @@ -76,28 +83,21 @@ def merge_read_lengths(alignments, psite_offsets): return merged_alignments -def parse_ribotricer_index(ribotricer_index): - """ - Parse ribotricer index to get only 'annotated' - features. +def parse_ribotricer_index(ribotricer_index: str) -> tuple[list[ORF], RefSeq]: + """Parse ribotricer index to get only 'annotated' features. Parameters ---------- - ribotricer_index: str - Path to the index file generated by ribotricer prepare_orfs + ribotricer_index : str + Path to the index file generated by ribotricer prepare_orfs. Returns ------- - annotated: List[ORF] - ORFs of CDS annotated - novel: List[ORF] - list of non-annotated ORFs - refseq: defaultdict(IntervalTree) - chrom: (start, end, strand) + tuple[list[ORF], RefSeq] + Tuple of (annotated ORFs, refseq interval tree). """ - - annotated = [] - refseq = defaultdict(IntervalTree) + annotated: list[ORF] = [] + refseq: RefSeq = defaultdict(IntervalTree) # First count the number of # annotated regions to count. 
@@ -105,12 +105,12 @@ def parse_ribotricer_index(ribotricer_index): # so need to read only upto a point where the regions # no longer have the annotated tag. total_lines = 0 - with open(ribotricer_index, "r") as anno: + with open(ribotricer_index) as anno: # read header anno.readline() while "annotated" in anno.readline(): total_lines += 1 - with open(ribotricer_index, "r") as anno: + with open(ribotricer_index) as anno: with tqdm(total=total_lines, unit="lines", leave=False) as pbar: # read header anno.readline() @@ -131,25 +131,31 @@ def parse_ribotricer_index(ribotricer_index): return (annotated, refseq) -def orf_coverage(orf, alignments, offset_5p=0, offset_3p=0): - """ +def orf_coverage( + orf: ORF, + alignments: MergedAlignments, + offset_5p: int = 0, + offset_3p: int = 0, +) -> list[int]: + """Compute coverage for an ORF. + Parameters ---------- - orf: ORF - instance of ORF - alignments: dict(Counter) - alignments summarized from bam by merging lengths - offset_5p: int - the number of nts to include from 5'prime - offset_3p: int - the number of nts to include from 3'prime + orf : ORF + Instance of ORF. + alignments : MergedAlignments + Alignments summarized from BAM by merging lengths. + offset_5p : int, optional + The number of nts to include from 5'prime, by default 0. + offset_3p : int, optional + The number of nts to include from 3'prime, by default 0. Returns ------- - coverage: array - coverage for ORF + list[int] + Coverage for ORF. 
""" - coverage = [] + coverage: list[int] = [] chrom = orf.chrom strand = orf.strand if strand == "-": @@ -198,27 +204,38 @@ def orf_coverage(orf, alignments, offset_5p=0, offset_3p=0): def export_orf_coverages( - ribotricer_index, - merged_alignments, - prefix, - phase_score_cutoff=CUTOFF, - min_valid_codons=MINIMUM_VALID_CODONS, - min_reads_per_codon=MINIMUM_READS_PER_CODON, - min_valid_codons_ratio=MINIMUM_VALID_CODONS_RATIO, - min_density_over_orf=MINIMUM_DENSITY_OVER_ORF, - report_all=False, -): - """ + ribotricer_index: str, + merged_alignments: MergedAlignments, + prefix: str, + phase_score_cutoff: float = CUTOFF, + min_valid_codons: int = MINIMUM_VALID_CODONS, + min_reads_per_codon: float = MINIMUM_READS_PER_CODON, + min_valid_codons_ratio: float = MINIMUM_VALID_CODONS_RATIO, + min_density_over_orf: float = MINIMUM_DENSITY_OVER_ORF, + report_all: bool = False, +) -> None: + """Export ORF coverages to file. + Parameters ---------- - ribotricer_index: str - Path to the index file generated by ribotricer prepare_orfs - merged_alignments: dict(dict) - alignments by merging all lengths - prefix: str - prefix for output file - report_all: bool - if True, all coverages will be exported + ribotricer_index : str + Path to the index file generated by ribotricer prepare_orfs. + merged_alignments : MergedAlignments + Alignments by merging all lengths. + prefix : str + Prefix for output file. + phase_score_cutoff : float, optional + Phase score cutoff value, by default CUTOFF. + min_valid_codons : int, optional + Minimum valid codons, by default MINIMUM_VALID_CODONS. + min_reads_per_codon : float, optional + Minimum reads per codon, by default MINIMUM_READS_PER_CODON. + min_valid_codons_ratio : float, optional + Minimum valid codons ratio, by default MINIMUM_VALID_CODONS_RATIO. + min_density_over_orf : float, optional + Minimum density over ORF, by default MINIMUM_DENSITY_OVER_ORF. 
+ report_all : bool, optional + If True, all coverages will be exported, by default False. """ # print('exporting coverages for all ORFs...') columns = [ @@ -243,12 +260,13 @@ def export_orf_coverages( ] to_write = "\t".join(columns) formatter = "{}\t" * (len(columns) - 1) + "{}\n" - with open(ribotricer_index, "r") as anno: + with open(ribotricer_index) as anno: total_lines = len(["" for line in anno]) - with open(ribotricer_index, "r") as anno, open( - "{}_translating_ORFs.tsv".format(prefix), "w" - ) as output: + with ( + open(ribotricer_index) as anno, + open(f"{prefix}_translating_ORFs.tsv", "w") as output, + ): output.write(to_write) with tqdm(total=total_lines, unit="ORFs") as pbar: # Skip header @@ -306,14 +324,15 @@ def export_orf_coverages( output.write(to_write) -def export_wig(merged_alignments, prefix): - """ +def export_wig(merged_alignments: MergedAlignments, prefix: str) -> None: + """Export merged alignments to WIG files. + Parameters ---------- - merged_alignments: dict(dict) - alignments by merging all lengths - prefix: str - prefix of output wig files + merged_alignments : MergedAlignments + Alignments by merging all lengths. + prefix : str + Prefix of output WIG files. 
""" # print('exporting merged alignments to wig file...') for strand in merged_alignments: @@ -322,59 +341,62 @@ def export_wig(merged_alignments, prefix): for chrom, pos in sorted(merged_alignments[strand]): if chrom != cur_chrom: cur_chrom = chrom - to_write += "variableStep chrom={}\n".format(chrom) - to_write += "{}\t{}\n".format(pos, merged_alignments[strand][(chrom, pos)]) + to_write += f"variableStep chrom={chrom}\n" + to_write += f"{pos}\t{merged_alignments[strand][(chrom, pos)]}\n" if strand == "+": - fname = "{}_pos.wig".format(prefix) + fname = f"{prefix}_pos.wig" else: - fname = "{}_neg.wig".format(prefix) + fname = f"{prefix}_neg.wig" with open(fname, "w") as output: output.write(to_write) def detect_orfs( - bam, - ribotricer_index, - prefix, - protocol, - read_lengths, - psite_offsets, - phase_score_cutoff, - min_valid_codons, - min_reads_per_codon, - min_valid_codons_ratio, - min_density_over_orf, - report_all, - meta_min_reads=100000, -): - """ + bam: str, + ribotricer_index: str, + prefix: str, + protocol: str | None, + read_lengths: list[int] | None, + psite_offsets: PsiteOffsets | None, + phase_score_cutoff: float, + min_valid_codons: int, + min_reads_per_codon: float, + min_valid_codons_ratio: float, + min_density_over_orf: float, + report_all: bool, + meta_min_reads: int = 100000, +) -> None: + """Detect translating ORFs from Ribo-seq data. 
+ Parameters ---------- - bam: str - Path to the bam file - ribotricer_index: str - Path to the index file generated by ribotricer prepare_orfs - prefix: str - prefix for all output files - protocol: str - {'forward', 'no', 'reverse'} - If None, the protocolness will be automatically inferred - read_lengths: list[int] - read lengths to use - If None, it will be automatically determined by assessing - the periodicity of metagene profile of this read length - psite_offsets: dict - Psite offsets for each read lengths - If None, the profiles from different read lengths will be - automatically aligned using cross-correlation - phase_score_cutoff: float - Phase score cutoff value for tagging an ORF as translating o - or non-translating - report_all: bool - Whether to output all ORFs' scores regardless of translation - status - meta_min_reads: int - minimum number of reads for a read length to be considered. Passed to metagene_coverage(). + bam : str + Path to the BAM file. + ribotricer_index : str + Path to the index file generated by ribotricer prepare_orfs. + prefix : str + Prefix for all output files. + protocol : str | None + Protocol: 'forward', 'no', or 'reverse'. If None, will be inferred. + read_lengths : list[int] | None + Read lengths to use. If None, will be automatically determined. + psite_offsets : PsiteOffsets | None + P-site offsets for each read length. If None, will be aligned using + cross-correlation. + phase_score_cutoff : float + Phase score cutoff value for tagging an ORF as translating. + min_valid_codons : int + Minimum valid codons. + min_reads_per_codon : float + Minimum reads per codon. + min_valid_codons_ratio : float + Minimum valid codons ratio. + min_density_over_orf : float + Minimum density over ORF. + report_all : bool + Whether to output all ORFs' scores regardless of translation status. + meta_min_reads : int, optional + Minimum number of reads for a read length to be considered, by default 100000. 
""" now = datetime.datetime.now() print(now.strftime("%b %d %H:%M:%S ..... started ribotricer detect-orfs")) diff --git a/ribotricer/fasta.py b/ribotricer/fasta.py index 5d385e3..e8b0831 100644 --- a/ribotricer/fasta.py +++ b/ribotricer/fasta.py @@ -1,8 +1,8 @@ -"""process fasta files""" +"""Process fasta files""" # Part of ribotricer software # -# Copyright (C) 2020 Saket Choudhary, Wenzheng Li, and Andrew D Smith +# Copyright (C) 2020-2026 Saket Choudhary, Wenzheng Li, and Andrew D Smith # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -14,93 +14,109 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -from collections import OrderedDict +from __future__ import annotations + import os import warnings +from collections import OrderedDict +from typing import TYPE_CHECKING from pyfaidx import Fasta +if TYPE_CHECKING: + from .interval import Interval + class FastaReader: - """Class for reading and querying fasta file.""" + """Class for reading and querying fasta file. + + Attributes + ---------- + fasta_location : str + Path to the fasta file. + fasta : Fasta + pyfaidx Fasta object. + """ + + def __init__(self, fasta_location: str) -> None: + """Initialize FastaReader. - def __init__(self, fasta_location): - """ Parameters - --------- - fasta_location : string - Path to fasta file + ---------- + fasta_location : str + Path to fasta file. + Raises + ------ + Exception + If the fasta file cannot be read. 
""" self.fasta_location = fasta_location try: self.fasta = Fasta(fasta_location, as_raw=True, sequence_always_upper=True) except Exception as e: raise Exception( - "Error reading fasta file {} : {}".format( - os.path.abspath(self.fasta_location), e - ) + f"Error reading fasta file {os.path.abspath(self.fasta_location)} : {e}" ) - def query(self, intervals): + def query(self, intervals: list[Interval]) -> list[str]: """Query regions for sequence. Parameters ---------- - intervals: list of Interval - The intervals for fasta is one-based and full-closed + intervals : list[Interval] + The intervals for fasta (one-based and full-closed). Returns ------- - sequences: list(str) - An array containing scores for each Interval - This function is agnostic of the strand information, - the position in the scores is corresponding to the interval - + list[str] + An array containing sequences for each Interval. + This function is agnostic of the strand information, + the position in the scores corresponds to the interval. + + Raises + ------ + Exception + If start or end position exceeds chromosome length. 
""" - sequences = [] + sequences: list[str] = [] chrom_lengths = self.chromosomes for i in intervals: if i.chrom not in list(chrom_lengths.keys()): warnings.warn( - "Chromosome {} does not appear in the fasta".format(i.chrom), + f"Chromosome {i.chrom} does not appear in the fasta", UserWarning, ) else: chrom_length = chrom_lengths[i.chrom] if i.start > chrom_length: raise Exception( - "Chromsome start point exceeds chromosome length: {}>{}".format( - i.start, chrom_length - ) + f"Chromosome start point exceeds chromosome length: {i.start}>{chrom_length}" ) elif i.end > chrom_length: raise Exception( - "Chromsome end point exceeds chromosome length: {}>{}".format( - i.end, chrom_length - ) + f"Chromosome end point exceeds chromosome length: {i.end}>{chrom_length}" ) seq = self.fasta.get_seq(i.chrom, i.start, i.end) sequences.append(seq) return sequences - def complement(self, seq): + def complement(self, seq: str) -> str: """Complement a FASTA sequence. Parameters ---------- - seq: str - String fasta sequence - + seq : str + String fasta sequence. Returns ------- - complement_seq: str - complemenet of input fasta + str + Complement of input fasta. """ complement_letters = {"A": "T", "C": "G", "T": "A", "G": "C"} seq = seq.upper() - comp = [] + comp: list[str] = [] for nuc in seq: if nuc in complement_letters: comp.append(complement_letters[nuc]) @@ -108,39 +124,32 @@ def complement(self, seq): comp.append(nuc) return "".join(comp) - def reverse_complement(self, seq): - """Reverse-complment a FASTA sequence. + def reverse_complement(self, seq: str) -> str: + """Reverse-complement a FASTA sequence. Parameters ---------- - seq: str - String fasta sequence - + seq : str + String fasta sequence. Returns ------- - complement_seq: str - complemenet of input fasta + str + Reverse complement of input fasta. """ seq = seq.upper() return self.complement(seq)[::-1] @property - def chromosomes(self): - """Return list of chromsome and their sizes - as in the fasta file. 
+ def chromosomes(self) -> OrderedDict[str, int]: + """Return list of chromosome and their sizes as in the fasta file. Returns ------- - chroms : dict - Dictionary with {"chr": "Length"} format - - - .. currentmodule:: .FastaReader - .. autosummary:: - .FastaReader + OrderedDict[str, int] + Dictionary with {"chr": length} format. """ - chroms = OrderedDict() + chroms: OrderedDict[str, int] = OrderedDict() for chrom in list(self.fasta.keys()): chroms[chrom] = len(self.fasta[chrom]) return chroms diff --git a/ribotricer/gtf.py b/ribotricer/gtf.py index 5940e63..2c8780f 100644 --- a/ribotricer/gtf.py +++ b/ribotricer/gtf.py @@ -2,7 +2,7 @@ # Part of ribotricer software # -# Copyright (C) 2020 Saket Choudhary, Wenzheng Li, and Andrew D Smith +# Copyright (C) 2020-2026 Saket Choudhary, Wenzheng Li, and Andrew D Smith # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -14,20 +14,56 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. +from __future__ import annotations + from collections import defaultdict +from typing import ClassVar + from tqdm.autonotebook import tqdm tqdm.pandas() -class GTFTrack(object): - """Class for feature in GTF file.""" - - standards = {"gene_biotype": "gene_type", "transcript_biotype": "transcript_type"} +class GTFTrack: + """Class for feature in GTF file. + + Attributes + ---------- + chrom : str + Chromosome name. + source : str + Source of the annotation. + feature : str + Feature type (e.g., 'exon', 'cds'). + start : int + Start position (1-based). + end : int + End position (1-based). + score : str + Score field. + strand : str + Strand ('+' or '-'). + frame : str + Reading frame. 
+ """ + + standards: ClassVar[dict[str, str]] = { + "gene_biotype": "gene_type", + "transcript_biotype": "transcript_type", + } def __init__( - self, chrom, source, feature, start, end, score, strand, frame, attribute - ): + self, + chrom: str, + source: str, + feature: str, + start: int, + end: int, + score: str, + strand: str, + frame: str, + attribute: str, + ) -> None: self.chrom = chrom self.source = source self.feature = feature @@ -36,32 +72,51 @@ def __init__( self.score = score self.strand = strand self.frame = frame + + # Parse attributes for att in attribute.split(";"): if len(att.split()) == 2: k, v = att.strip().split() if k in GTFTrack.standards: k = GTFTrack.standards[k] setattr(self, k, v.strip('"')) + if not hasattr(self, "gene_name") and hasattr(self, "gene_id"): - setattr(self, "gene_name", self.gene_id) + self.gene_name = self.gene_id if not hasattr(self, "transcript_name") and hasattr(self, "transcript_id"): - setattr(self, "transcript_name", self.transcript_id) + self.transcript_name = self.transcript_id if not hasattr(self, "transcript_type") and not hasattr( self, GTFTrack.standards["transcript_biotype"] ): # transcript_type not set so set it to "assumed_protein_coding". - setattr(self, "transcript_type", "assumed_protein_coding") + self.transcript_type = "assumed_protein_coding" if not hasattr(self, "gene_type") and hasattr(self, "transcript_type"): - setattr(self, "gene_type", self.transcript_type) + self.gene_type = self.transcript_type + + # Type hints for dynamically set attributes + gene_id: str + gene_name: str + gene_type: str + transcript_id: str + transcript_name: str + transcript_type: str @classmethod - def from_string(cls, line): - """ + def from_string(cls, line: str) -> GTFTrack | None: + """Parse a GTF line into a GTFTrack. + Parameters ---------- - line: string - one line in gtf file + line : str + One line in GTF file. + + Returns + ------- + GTFTrack | None + Parsed track object, or None if line should be skipped. 
+ Notes + ----- This method follows the fails-fast strategy and hence uses multiple returns, ultimately returning a line from the GTF parsed into a feature (chrom, start end etc.) @@ -89,27 +144,40 @@ def from_string(cls, line): return cls(chrom, source, feature, start, end, score, strand, frame, attribute) - def __repr__(self): + def __repr__(self) -> str: return str(self.__dict__) -class GTFReader(object): - """Class for reading and parseing gtf file.""" +class GTFReader: + """Class for reading and parsing gtf file. + + Attributes + ---------- + gtf_location : str + Path to GTF file. + transcript : defaultdict[str, list[GTFTrack]] + Dictionary mapping transcript IDs to their exon tracks. + cds : defaultdict[str, defaultdict[str, list[GTFTrack]]] + Dictionary mapping gene IDs to transcript IDs to CDS tracks. + """ + + def __init__(self, gtf_location: str) -> None: + """Initialize GTFReader. - def __init__(self, gtf_location): - """ Parameters - --------- - gtf_location : string - Path to gtf file + ---------- + gtf_location : str + Path to GTF file. 
""" self.gtf_location = gtf_location - self.transcript = defaultdict(list) - self.cds = defaultdict(lambda: defaultdict(list)) - # print('reading GTF file...') - with open(self.gtf_location, "r") as gtf: + self.transcript: defaultdict[str, list[GTFTrack]] = defaultdict(list) + self.cds: defaultdict[str, defaultdict[str, list[GTFTrack]]] = defaultdict( + lambda: defaultdict(list) + ) + + with open(self.gtf_location) as gtf: total_lines = len(["" for line in gtf]) - with open(self.gtf_location, "r") as gtf: + with open(self.gtf_location) as gtf: with tqdm(total=total_lines, unit="lines", leave=False) as pbar: for line in gtf: pbar.update() @@ -120,9 +188,7 @@ def __init__(self, gtf_location): tid = track.transcript_id except AttributeError: print( - "missing gene or transcript id {}:{}-{}".format( - track.chrom, track.start, track.end - ) + f"missing gene or transcript id {track.chrom}:{track.start}-{track.end}" ) else: if track.feature == "exon": diff --git a/ribotricer/infer_protocol.py b/ribotricer/infer_protocol.py index 04face3..5cf9a26 100644 --- a/ribotricer/infer_protocol.py +++ b/ribotricer/infer_protocol.py @@ -1,8 +1,8 @@ -"""infer experimental protocol""" +"""Infer experimental protocol""" # Part of ribotricer software # -# Copyright (C) 2020 Saket Choudhary, Wenzheng Li, and Andrew D Smith +# Copyright (C) 2020-2026 Saket Choudhary, Wenzheng Li, and Andrew D Smith # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -14,35 +14,49 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. 
-from collections import Counter -from .common import is_read_uniq_mapping +from __future__ import annotations + +from collections import Counter, defaultdict +from typing import TYPE_CHECKING import pysam -from quicksect import Interval +from quicksect import Interval, IntervalTree + +from .common import is_read_uniq_mapping + +if TYPE_CHECKING: + pass # required to convert numeric strands to '+/-' -NUM_TO_STRAND = {1: "+", -1: "-"} +NUM_TO_STRAND: dict[int, str] = {1: "+", -1: "-"} -def infer_protocol(bam, gene_interval_tree, prefix, n_reads=20000): - """Infer strandedness protocol given a bam file +def infer_protocol( + bam: str, + gene_interval_tree: defaultdict[str, IntervalTree], + prefix: str, + n_reads: int = 20000, +) -> str: + """Infer strandedness protocol given a bam file. Parameters ---------- - bam: str - Path to bam file - gene_interval_tree: defaultdict(IntervalTree) - chrom: (start, end, strand) - prefix: str - Prefix for protocol file - n_reads: int - Number of reads to use (downsampled) + bam : str + Path to bam file. + gene_interval_tree : defaultdict[str, IntervalTree] + Dictionary mapping chrom to (start, end, strand) interval tree. + prefix : str + Prefix for protocol file. + n_reads : int, optional + Number of reads to use (downsampled), by default 20000. Returns ------- - protocol: string - forward/reverse + str + Protocol string: 'forward' or 'reverse'. + Notes + ----- The strategy to do this is simple: keep a track of mapped reads and their strand and then tally if the location of their mapping has a gene defined @@ -52,13 +66,13 @@ def infer_protocol(bam, gene_interval_tree, prefix, n_reads=20000): gene strand respectively: Higher proportion of (++, --) implies forward protocol Higher proportion of (+-, -+) implies reverse protocol - Equal proportion of the above two scenairos implies unstranded protocol. - + Equal proportion of the above two scenarios implies unstranded protocol. 
"""
     iteration = 0
-    bam = pysam.AlignmentFile(bam, "rb")
-    strandedness = Counter()
-    for read in bam.fetch(until_eof=True):
+    bam_file = pysam.AlignmentFile(bam, "rb")
+    strandedness: Counter[str] = Counter()
+
+    for read in bam_file.fetch(until_eof=True):
         if iteration <= n_reads:
             if is_read_uniq_mapping(read):
                 if read.is_reverse:
@@ -68,21 +82,26 @@ def infer_protocol(bam, gene_interval_tree, prefix, n_reads=20000):
                 mapped_start = read.reference_start
                 mapped_end = read.reference_end
                 chrom = read.reference_name
-                # get corresponding gene's strand
-                interval = list(
-                    set(
-                        gene_interval_tree[chrom].find(
-                            Interval(mapped_start, mapped_end)
+
+                if chrom is not None and mapped_end is not None:
+                    # get corresponding gene's strand
+                    interval = list(
+                        set(
+                            gene_interval_tree[chrom].find(
+                                Interval(mapped_start, mapped_end)
+                            )
                         )
                     )
-                )
-                if len(interval) == 1:
-                    # Filter out genes with ambiguous strand info
-                    # (those) that have a tx_start on opposite strands
-                    gene_strand = NUM_TO_STRAND[interval[0].data]
-                    # count table for mapped strand vs gene strand
-                    strandedness["{}{}".format(mapped_strand, gene_strand)] += 1
-                iteration += 1
+                    if len(interval) == 1:
+                        # Filter out genes with ambiguous strand info
+                        # (those) that have a tx_start on opposite strands
+                        gene_strand = NUM_TO_STRAND[interval[0].data]
+                        # count table for mapped strand vs gene strand
+                        strandedness[f"{mapped_strand}{gene_strand}"] += 1
+                iteration += 1
+
+    bam_file.close()
+
     # Add pseudocounts
     strandedness["++"] += 1
     strandedness["--"] += 1
@@ -93,17 +112,11 @@ def infer_protocol(bam, gene_interval_tree, prefix, n_reads=20000):
     forward_mapped_reads = strandedness["++"] + strandedness["--"]
     reverse_mapped_reads = strandedness["-+"] + strandedness["+-"]
     to_write = (
-        "In total {} reads checked:\n"
-        '\tNumber of reads explained by "++, --": {} ({:.4f})\n'
-        '\tNumber of reads explained by "+-, -+": {} ({:.4f})\n'
-    ).format(
-        total,
-        forward_mapped_reads,
-        forward_mapped_reads / total,
-        
reverse_mapped_reads, - reverse_mapped_reads / total, + f"In total {total} reads checked:\n" + f'\tNumber of reads explained by "++, --": {forward_mapped_reads} ({forward_mapped_reads / total:.4f})\n' + f'\tNumber of reads explained by "+-, -+": {reverse_mapped_reads} ({reverse_mapped_reads / total:.4f})\n' ) - with open("{}_protocol.txt".format(prefix), "w") as output: + with open(f"{prefix}_protocol.txt", "w") as output: output.write(to_write) protocol = "forward" if reverse_mapped_reads > forward_mapped_reads: diff --git a/ribotricer/interval.py b/ribotricer/interval.py index 0e1030b..b801f3e 100644 --- a/ribotricer/interval.py +++ b/ribotricer/interval.py @@ -2,7 +2,7 @@ # Part of ribotricer software # -# Copyright (C) 2020 Saket Choudhary, Wenzheng Li, and Andrew D Smith +# Copyright (C) 2020-2026 Saket Choudhary, Wenzheng Li, and Andrew D Smith # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -14,20 +14,44 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. +from __future__ import annotations + class Interval: - """Class for interval - All the intervals used in this project is 1-based and closed + """Class for interval. + + All the intervals used in this project are 1-based and closed. + + Attributes + ---------- + chrom : str | None + Chromosome name. + start : int + Start position (1-based, inclusive). + end : int + End position (1-based, inclusive). + strand : str + Strand ('+' or '-'). 
""" - def __init__(self, chrom=None, start=1, end=1, strand="+"): + __slots__ = ("chrom", "start", "end", "strand") + + def __init__( + self, + chrom: str | None = None, + start: int = 1, + end: int = 1, + strand: str = "+", + ) -> None: self.chrom = chrom self.start = int(start) self.end = int(end) self.strand = strand - def __eq__(self, other): - """Override the default Equals behavior""" + def __eq__(self, other: object) -> bool: + """Override the default Equals behavior.""" + if not isinstance(other, Interval): + return NotImplemented return ( self.chrom == other.chrom and self.start == other.start @@ -35,5 +59,13 @@ def __eq__(self, other): and self.strand == other.strand ) - def __repr__(self): - return "{}\t{}\t{}\t{}".format(self.chrom, self.start, self.end, self.strand) + def __repr__(self) -> str: + return f"{self.chrom}\t{self.start}\t{self.end}\t{self.strand}" + + def __len__(self) -> int: + """Return the length of the interval.""" + return self.end - self.start + 1 + + def __hash__(self) -> int: + """Make Interval hashable.""" + return hash((self.chrom, self.start, self.end, self.strand)) diff --git a/ribotricer/learn_cutoff.py b/ribotricer/learn_cutoff.py index 269adac..092edef 100644 --- a/ribotricer/learn_cutoff.py +++ b/ribotricer/learn_cutoff.py @@ -2,7 +2,7 @@ # Part of ribotricer software # -# Copyright (C) 2020 Saket Choudhary, Wenzheng Li, and Andrew D Smith +# Copyright (C) 2020-2026 Saket Choudhary, Wenzheng Li, and Andrew D Smith # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -14,39 +14,50 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. 
+from __future__ import annotations + import sys + import numpy as np import pandas as pd -from .common import mkdir_p -from .common import parent_dir -from .const import CUTOFF -from .const import MINIMUM_VALID_CODONS -from .const import MINIMUM_VALID_CODONS_RATIO -from .const import MINIMUM_READS_PER_CODON -from .const import MINIMUM_DENSITY_OVER_ORF +from .common import mkdir_p, parent_dir +from .const import ( + CUTOFF, + MINIMUM_DENSITY_OVER_ORF, + MINIMUM_READS_PER_CODON, + MINIMUM_VALID_CODONS, + MINIMUM_VALID_CODONS_RATIO, +) from .detect_orfs import detect_orfs def determine_cutoff_tsv( - ribo_tsvs, rna_tsvs, filter_by=["protein_coding"], sampling_ratio=0.33, reps=10000 -): - """Learn cutoff empirically from ribotricer generated ORF tsvs. + ribo_tsvs: list[str], + rna_tsvs: list[str], + filter_by: list[str] | None = None, + sampling_ratio: float = 0.33, + reps: int = 10000, +) -> None: + """Learn cutoff empirically from ribotricer generated ORF TSVs. Parameters ---------- - ribo_tsvs: list - List of filepath of ribotricer generated *translating_ORFs.tsv - for Ribo-seq samples - - rna_tsvs: list - List of filepath of ribotricer generated *translating_ORFs.tsv - for RNA-seq samples - Returns - ------- - cutoff: float - Suggested cutoff + ribo_tsvs : list[str] + List of filepaths of ribotricer generated *translating_ORFs.tsv + for Ribo-seq samples. + rna_tsvs : list[str] + List of filepaths of ribotricer generated *translating_ORFs.tsv + for RNA-seq samples. + filter_by : list[str] | None, optional + Transcript types to filter by, by default ['protein_coding']. + sampling_ratio : float, optional + Sampling ratio, by default 0.33. + reps : int, optional + Number of replicates, by default 10000. 
""" + if filter_by is None: + filter_by = ["protein_coding"] ribo_df = pd.DataFrame() for tsv in ribo_tsvs: df = pd.read_csv( @@ -65,7 +76,7 @@ def determine_cutoff_tsv( ) rna_df = pd.concat([rna_df, df]) - filter_by = list(map(lambda x: x.lower(), filter_by)) + filter_by = [x.lower() for x in filter_by] ribo_df_filtered = ribo_df.loc[ribo_df.ORF_type == "annotated"] ribo_df_filtered = ribo_df_filtered.loc[ @@ -111,69 +122,84 @@ def determine_cutoff_tsv( diff_all_mean = np.mean(diff_all) diff_all_std = np.std(diff_all) - print("sampling_ratio: {}".format(sampling_ratio)) - print("n_samples: {}".format(reps)) + print(f"sampling_ratio: {sampling_ratio}") + print(f"n_samples: {reps}") - print("ribo_phase_score_mean: {:.3f}".format(ribo_phase_score_mean)) - print("ribo_phase_score_median: {:.3f}".format(ribo_phase_score_median)) - print("ribo_phase_score_sd: {:.3f}".format(ribo_phase_score_sd)) + print(f"ribo_phase_score_mean: {ribo_phase_score_mean:.3f}") + print(f"ribo_phase_score_median: {ribo_phase_score_median:.3f}") + print(f"ribo_phase_score_sd: {ribo_phase_score_sd:.3f}") - print("rna_phase_score_mean: {:.3f}".format(rna_phase_score_mean)) - print("rna_phase_score_median: {:.3f}".format(rna_phase_score_median)) - print("rna_phase_score_sd: {:.3f}".format(rna_phase_score_sd)) + print(f"rna_phase_score_mean: {rna_phase_score_mean:.3f}") + print(f"rna_phase_score_median: {rna_phase_score_median:.3f}") + print(f"rna_phase_score_sd: {rna_phase_score_sd:.3f}") - print("diff_phase_score_sampled_mean: {:.3f}".format(diff_phase_score_mean)) - print("diff_phase_score_sampled_median: {:.3f}".format(diff_phase_score_median)) - print("diff_phase_score_sampled_sd: {:.3f}".format(diff_phase_score_sd)) + print(f"diff_phase_score_sampled_mean: {diff_phase_score_mean:.3f}") + print(f"diff_phase_score_sampled_median: {diff_phase_score_median:.3f}") + print(f"diff_phase_score_sampled_sd: {diff_phase_score_sd:.3f}") - print("diff_phase_score_all_mean: {:.3f}".format(diff_all_mean)) 
- print("diff_phase_score_all_median: {:.3f}".format(diff_all_median)) - print("diff_phase_score_all_sd: {:.3f}".format(diff_all_std)) + print(f"diff_phase_score_all_mean: {diff_all_mean:.3f}") + print(f"diff_phase_score_all_median: {diff_all_median:.3f}") + print(f"diff_phase_score_all_sd: {diff_all_std:.3f}") - print("recommended_cutoff: {:.3f}".format(diff_phase_score_median)) + print(f"recommended_cutoff: {diff_phase_score_median:.3f}") def determine_cutoff_bam( - ribo_bams, - rna_bams, - ribotricer_index, - prefix, - ribo_stranded_protocols=[], - rna_stranded_protocols=[], - filter_by=["protein_coding"], - sampling_ratio=0.33, - reps=10000, - phase_score_cutoff=CUTOFF, - min_valid_codons=MINIMUM_VALID_CODONS, - report_all=True, -): - """Learn cutoff emprically from the given data. + ribo_bams: list[str], + rna_bams: list[str], + ribotricer_index: str, + prefix: str, + ribo_stranded_protocols: list[str | None] | None = None, + rna_stranded_protocols: list[str | None] | None = None, + filter_by: list[str] | None = None, + sampling_ratio: float = 0.33, + reps: int = 10000, + phase_score_cutoff: float = CUTOFF, + min_valid_codons: int = MINIMUM_VALID_CODONS, + report_all: bool = True, +) -> None: + """Learn cutoff empirically from the given data. This uses the following steps: 1. Run ribotricer using a cutoff of 0 for both RNA and Ribo samples - 2. For each output of ribotricer, find the median difference between RNA and Ribo-seq - phase scores using the protein_coding annotated regions in the output. + 2. For each output of ribotricer, find the median difference between RNA + and Ribo-seq phase scores using the protein_coding annotated regions + in the output. 
Parameters
     ----------
-    ribo_bams: list
-        List of filepaths to Ribo-seq bams
-
-    rna_bams: list
-        List of filepaths to RNA-seq bams
-
-    ribo_stranded_protocols: list
-        List of 'yes/no/reverse'
-    rna_stranded_protocols: list
-        List of 'yes/no/reverse'
-
-
-    Returns
-    -------
-    cutoff: float
-        Suggested cutoff
+    ribo_bams : list[str]
+        List of filepaths to Ribo-seq BAMs.
+    rna_bams : list[str]
+        List of filepaths to RNA-seq BAMs.
+    ribotricer_index : str
+        Path to ribotricer index file.
+    prefix : str
+        Output prefix.
+    ribo_stranded_protocols : list[str | None] | None, optional
+        List of 'yes/no/reverse', by default None.
+    rna_stranded_protocols : list[str | None] | None, optional
+        List of 'yes/no/reverse', by default None.
+    filter_by : list[str] | None, optional
+        Transcript types to filter by, by default ['protein_coding'].
+    sampling_ratio : float, optional
+        Sampling ratio, by default 0.33.
+    reps : int, optional
+        Number of replicates, by default 10000.
+    phase_score_cutoff : float, optional
+        Phase score cutoff, by default CUTOFF.
+    min_valid_codons : int, optional
+        Minimum valid codons, by default MINIMUM_VALID_CODONS.
+    report_all : bool, optional
+        Whether to report all ORFs, by default True.
     """
+    if ribo_stranded_protocols is None:
+        ribo_stranded_protocols = []
+    if rna_stranded_protocols is None:
+        rna_stranded_protocols = []
+    if filter_by is None:
+        filter_by = ["protein_coding"]
     if len(ribo_stranded_protocols) > 1:
         if len(ribo_stranded_protocols) != len(ribo_bams):
             sys.exit("Error: Ribo-seq protocol and bam file length mismatch")
@@ -189,22 +215,18 @@ def determine_cutoff_bam(
     sample_str = "samples"
     if len(rna_bams) == 1:
         sample_str = "sample"
-    print(
-        "Running ribotricer on {} Ribo-seq {} ..... \n".format(
-            len(rna_bams), sample_str
-        )
-    )
+    print(f"Running ribotricer on {len(ribo_bams)} Ribo-seq {sample_str} ..... 
\n") ribo_bams_renamed = dict( - zip(ribo_bams, ["ribo_bam_{}".format(i + 1) for i in range(len(ribo_bams))]) + zip(ribo_bams, [f"ribo_bam_{i + 1}" for i in range(len(ribo_bams))]) ) rna_bams_renamed = dict( - zip(rna_bams, ["rna_bam_{}".format(i + 1) for i in range(len(rna_bams))]) + zip(rna_bams, [f"rna_bam_{i + 1}" for i in range(len(rna_bams))]) ) rna_tsvs = [] ribo_tsvs = [] for bam, stranded in zip(ribo_bams, ribo_stranded_protocols): - bam_prefix = "{}__{}".format(prefix, ribo_bams_renamed[bam]) + bam_prefix = f"{prefix}__{ribo_bams_renamed[bam]}" mkdir_p(parent_dir(bam_prefix)) detect_orfs( bam, @@ -220,12 +242,10 @@ def determine_cutoff_bam( min_density_over_orf=MINIMUM_DENSITY_OVER_ORF, report_all=report_all, ) - ribo_tsvs.append("{}_translating_ORFs.tsv".format(bam_prefix)) - print( - "Running ribotricer on {} RNA-seq {} ..... \n".format(len(rna_bams), sample_str) - ) + ribo_tsvs.append(f"{bam_prefix}_translating_ORFs.tsv") + print(f"Running ribotricer on {len(rna_bams)} RNA-seq {sample_str} ..... 
\n") for bam, stranded in zip(rna_bams, rna_stranded_protocols): - bam_prefix = "{}__{}".format(prefix, rna_bams_renamed[bam]) + bam_prefix = f"{prefix}__{rna_bams_renamed[bam]}" mkdir_p(parent_dir(bam_prefix)) detect_orfs( bam, @@ -241,5 +261,5 @@ def determine_cutoff_bam( min_density_over_orf=MINIMUM_DENSITY_OVER_ORF, report_all=report_all, ) - rna_tsvs.append("{}_translating_ORFs.tsv".format(bam_prefix)) + rna_tsvs.append(f"{bam_prefix}_translating_ORFs.tsv") determine_cutoff_tsv(ribo_tsvs, rna_tsvs, filter_by, sampling_ratio, reps) diff --git a/ribotricer/metagene.py b/ribotricer/metagene.py index 50b6a3d..adaeb72 100644 --- a/ribotricer/metagene.py +++ b/ribotricer/metagene.py @@ -2,7 +2,7 @@ # Part of ribotricer software # -# Copyright (C) 2020 Saket Choudhary, Wenzheng Li, and Andrew D Smith +# Copyright (C) 2020-2026 Saket Choudhary, Wenzheng Li, and Andrew D Smith # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -14,20 +14,58 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. 
-from .statistics import phasescore -from .interval import Interval -from .const import CUTOFF, TYPICAL_OFFSET +from __future__ import annotations + import sys from collections import Counter, OrderedDict +from collections.abc import Iterator +from typing import TYPE_CHECKING import numpy as np import pandas as pd from tqdm.autonotebook import tqdm +from .const import CUTOFF, TYPICAL_OFFSET +from .interval import Interval +from .statistics import phasescore + +if TYPE_CHECKING: + from .orf import ORF + tqdm.pandas() +# Type aliases +AlignmentDict = dict[int, dict[str, Counter[tuple[str, int]]]] +MetageneDict = dict[int, tuple[pd.Series, pd.Series, float, int, float, int]] + -def next_genome_pos(ivs, max_positions, leader, trailer, reverse=False): +def next_genome_pos( + ivs: list[Interval], + max_positions: int, + leader: int, + trailer: int, + reverse: bool = False, +) -> Iterator[int]: + """Generate genome positions from intervals. + + Parameters + ---------- + ivs : list[Interval] + List of intervals. + max_positions : int + Maximum number of positions to yield. + leader : int + Number of positions to include upstream. + trailer : int + Number of positions to include downstream. + reverse : bool, optional + Whether to iterate in reverse, by default False. + + Yields + ------ + int + Genome position. + """ if len(ivs) == 0: return iter([]) cnt = 0 @@ -55,32 +93,37 @@ def next_genome_pos(ivs, max_positions, leader, trailer, reverse=False): def orf_coverage_length( - orf, alignments, length, max_positions, offset_5p=20, offset_3p=0 -): - """ + orf: ORF, + alignments: AlignmentDict, + length: int, + max_positions: int, + offset_5p: int = 20, + offset_3p: int = 0, +) -> tuple[pd.Series, pd.Series]: + """Calculate ORF coverage for a specific read length. 
+ Parameters ---------- - orf: ORF - instance of ORF - alignments: dict(dict(Counter)) - alignments summarized from bam - length: int - the target length - max_positions: int - the number of nts to include - offset_5p: int - the number of nts to include from 5'prime - offset_3p: int - the number of nts to include from 3'prime + orf : ORF + Instance of ORF. + alignments : AlignmentDict + Alignments summarized from bam. + length : int + The target length. + max_positions : int + The number of nts to include. + offset_5p : int, optional + The number of nts to include from 5'prime, by default 20. + offset_3p : int, optional + The number of nts to include from 3'prime, by default 0. Returns ------- - from_start: Series - coverage for ORF for specific length aligned at start codon - from_stop: Series - coverage for ORF for specific length aligned at stop codon + tuple[pd.Series, pd.Series] + - from_start: Coverage for ORF for specific length aligned at start codon. + - from_stop: Coverage for ORF for specific length aligned at stop codon. """ - coverage = [] + coverage: list[int] = [] chrom = orf.chrom strand = orf.strand if strand == "-": @@ -115,43 +158,43 @@ def orf_coverage_length( def metagene_coverage( - cds, - alignments, - read_lengths, - prefix, - max_positions=600, - offset_5p=20, - offset_3p=0, - meta_min_reads=100000, -): - """ + cds: list[ORF], + alignments: AlignmentDict, + read_lengths: dict[int, int], + prefix: str, + max_positions: int = 600, + offset_5p: int = 20, + offset_3p: int = 0, + meta_min_reads: int = 100000, +) -> MetageneDict: + """Calculate metagene coverage profiles. 
+ Parameters ---------- - cds: List[ORF] - list of cds - alignments: dict(dict(Counter)) - alignments summarized from bam - read_lengths: dict - key is the length, value is the number reads - prefix: str - prefix for the output file - max_positions: int - the number of nts to include - offset_5p: int - the number of nts to include from the 5'prime - offset_3p: int - the number of nts to include from the 3'prime - meta_min_reads: int - minimum number of reads for a read length to be considered + cds : list[ORF] + List of CDS ORFs. + alignments : AlignmentDict + Alignments summarized from bam. + read_lengths : dict[int, int] + Dictionary where key is the length, value is the number reads. + prefix : str + Prefix for the output file. + max_positions : int, optional + The number of nts to include, by default 600. + offset_5p : int, optional + The number of nts to include from the 5'prime, by default 20. + offset_3p : int, optional + The number of nts to include from the 3'prime, by default 0. + meta_min_reads : int, optional + Minimum number of reads for a read length to be considered, by default 100000. Returns ------- - metagenes: dict - key is the length, value is (from_start, from_stop, phasescore, - pval) + MetageneDict + Dictionary where key is the length, value is tuple of + (from_start, from_stop, phasescore_5p, valid_5p, phasescore_3p, valid_3p). 
""" - # print('calculating metagene profiles...') - metagenes = {} + metagenes: MetageneDict = {} # remove read length whose read number is small for length, reads in list(read_lengths.items()): @@ -159,11 +202,10 @@ def metagene_coverage( del read_lengths[length] for length in tqdm(read_lengths, unit="read-length", leave=False): - - metagene_coverage_start = pd.Series(dtype=float) - position_counter_start = Counter() - metagene_coverage_stop = pd.Series(dtype=float) - position_counter_stop = Counter() + metagene_coverage_start: pd.Series = pd.Series(dtype=float) + position_counter_start: Counter[int] = Counter() + metagene_coverage_stop: pd.Series = pd.Series(dtype=float) + position_counter_stop: Counter[int] = Counter() for orf in tqdm(cds, position=1, unit="ORFs", leave=False): from_start, from_stop = orf_coverage_length( @@ -189,10 +231,14 @@ def metagene_coverage( position_counter_stop ) != len(metagene_coverage_stop): raise RuntimeError("Metagene coverage and counter mismatch") - position_counter_start = pd.Series(position_counter_start) - metagene_coverage_start = metagene_coverage_start.div(position_counter_start) - position_counter_stop = pd.Series(position_counter_stop) - metagene_coverage_stop = metagene_coverage_stop.div(position_counter_stop) + position_counter_start_series = pd.Series(position_counter_start) + metagene_coverage_start = metagene_coverage_start.div( + position_counter_start_series + ) + position_counter_stop_series = pd.Series(position_counter_stop) + metagene_coverage_stop = metagene_coverage_stop.div( + position_counter_stop_series + ) phasescore_5p, valid_5p = phasescore(metagene_coverage_start.tolist()) phasescore_3p, valid_3p = phasescore(metagene_coverage_stop.tolist()) @@ -208,53 +254,46 @@ def metagene_coverage( to_write_5p = "fragment_length\toffset_5p\tprofile\tphase_score\tvalid_codons\n" to_write_3p = "fragment_length\toffset_3p\tprofile\tphase_score\tvalid_codons\n" for length in sorted(metagenes): - to_write_5p += 
"{}\t{}\t{}\t{}\t{}\n".format( - length, - offset_5p, - metagenes[length][0].tolist(), - metagenes[length][2], - metagenes[length][3], - ) - to_write_3p += "{}\t{}\t{}\t{}\t{}\n".format( - length, - offset_3p, - metagenes[length][1].tolist(), - metagenes[length][4], - metagenes[length][5], - ) + to_write_5p += f"{length}\t{offset_5p}\t{metagenes[length][0].tolist()}\t{metagenes[length][2]}\t{metagenes[length][3]}\n" + to_write_3p += f"{length}\t{offset_3p}\t{metagenes[length][1].tolist()}\t{metagenes[length][4]}\t{metagenes[length][5]}\n" - with open("{}_metagene_profiles_5p.tsv".format(prefix), "w") as output: + with open(f"{prefix}_metagene_profiles_5p.tsv", "w") as output: output.write(to_write_5p) - with open("{}_metagene_profiles_3p.tsv".format(prefix), "w") as output: + with open(f"{prefix}_metagene_profiles_3p.tsv", "w") as output: output.write(to_write_3p) return metagenes def align_metagenes( - metagenes, read_lengths, prefix, phase_score_cutoff=CUTOFF, remove_nonperiodic=False -): - """align metagene coverages to determine the lag of the psites, the - non-periodic read length will be discarded in this step + metagenes: MetageneDict, + read_lengths: dict[int, int], + prefix: str, + phase_score_cutoff: float = CUTOFF, + remove_nonperiodic: bool = False, +) -> OrderedDict[int, int]: + """Align metagene coverages to determine the lag of the psites. + + The non-periodic read length will be discarded in this step. Parameters ---------- - metagenes: dict - key is the length, value is the metagene coverage - read_lengths: dict - key is the length, value is the number of reads - prefix: str - prefix for output files - remove_nonperiodic: bool - Whether remove non-periodic read lengths + metagenes : MetageneDict + Dictionary where key is the length, value is the metagene coverage. + read_lengths : dict[int, int] + Dictionary where key is the length, value is the number of reads. + prefix : str + Prefix for output files. 
+ phase_score_cutoff : float, optional + Phase score cutoff, by default CUTOFF. + remove_nonperiodic : bool, optional + Whether remove non-periodic read lengths, by default False. Returns ------- - psite_offsets: dict - key is the length, value is the offset + OrderedDict[int, int] + Dictionary where key is the length, value is the offset. """ - # print('aligning metagene profiles from different lengths...') - # discard non-periodic read lengths if remove_nonperiodic: for length, (_, _, coh, _, _, _) in list(metagenes.items()): @@ -264,28 +303,26 @@ def align_metagenes( if len(read_lengths) == 0: sys.exit( - "WARNING: no periodic read length found... using cutoff {}".format( - phase_score_cutoff - ) + f"WARNING: no periodic read length found... using cutoff {phase_score_cutoff}" ) - psite_offsets = OrderedDict() + psite_offsets: OrderedDict[int, int] = OrderedDict() base = n_reads = 0 for length, reads in list(read_lengths.items()): if reads > n_reads: base = length n_reads = reads reference = metagenes[base][0].values - to_write = "relative lag to base: {}\n".format(base) + to_write = f"relative lag to base: {base}\n" for length, (meta, _, _, _, _, _) in list(metagenes.items()): cov = meta.values xcorr = np.correlate(reference, cov, "full") origin = len(xcorr) // 2 bound = min(base, length) xcorr = xcorr[(origin - bound) : (origin + bound)] - lag = np.argmax(xcorr) - len(xcorr) // 2 + lag = int(np.argmax(xcorr) - len(xcorr) // 2) psite_offsets[length] = lag + TYPICAL_OFFSET - to_write += "\tlag of {}: {}\n".format(length, lag) - with open("{}_psite_offsets.txt".format(prefix), "w") as output: + to_write += f"\tlag of {length}: {lag}\n" + with open(f"{prefix}_psite_offsets.txt", "w") as output: output.write(to_write) return psite_offsets diff --git a/ribotricer/orf.py b/ribotricer/orf.py index 9ffdce5..afefe16 100644 --- a/ribotricer/orf.py +++ b/ribotricer/orf.py @@ -2,7 +2,7 @@ # Part of ribotricer software # -# Copyright (C) 2020 Saket Choudhary, Wenzheng Li, 
and Andrew D Smith +# Copyright (C) 2020-2026 Saket Choudhary, Wenzheng Li, and Andrew D Smith # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -14,28 +14,81 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. +from __future__ import annotations + import sys +from typing import TYPE_CHECKING + from .interval import Interval +if TYPE_CHECKING: + from .gtf import GTFTrack + class ORF: - """Class for candidate ORF.""" + """Class for candidate ORF. + + Attributes + ---------- + category : str + ORF category (e.g., 'annotated', 'uORF', 'dORF', etc.). + tid : str + Transcript ID. + ttype : str + Transcript type. + gid : str + Gene ID. + gname : str + Gene name. + gtype : str + Gene type. + chrom : str + Chromosome name. + strand : str + Strand ('+' or '-'). + intervals : list[Interval] + List of intervals comprising the ORF. + oid : str + ORF ID (auto-generated). + seq : str + ORF sequence. + leader : str + 5' UTR sequence. + trailer : str + 3' UTR sequence. 
+ """ + + __slots__ = ( + "category", + "tid", + "ttype", + "gid", + "gname", + "gtype", + "chrom", + "strand", + "intervals", + "oid", + "seq", + "leader", + "trailer", + ) def __init__( self, - category, - transcript_id, - transcript_type, - gene_id, - gene_name, - gene_type, - chrom, - strand, - intervals, - seq="", - leader="", - trailer="", - ): + category: str, + transcript_id: str, + transcript_type: str, + gene_id: str, + gene_name: str, + gene_type: str, + chrom: str, + strand: str, + intervals: list[Interval], + seq: str = "", + leader: str = "", + trailer: str = "", + ) -> None: self.category = category self.tid = transcript_id self.ttype = transcript_type @@ -47,34 +100,42 @@ def __init__( self.intervals = sorted(intervals, key=lambda x: x.start) start = self.intervals[0].start end = self.intervals[-1].end - self.oid = "{}_{}_{}_{}".format( - transcript_id, - start, - end, - sum([x.end - x.start + 1 for x in self.intervals]), - ) + self.oid = f"{transcript_id}_{start}_{end}_{sum([x.end - x.start + 1 for x in self.intervals])}" self.seq = seq self.leader = leader self.trailer = trailer @property - def start_codon(self): - """Return the first 3 bases from sequence""" + def start_codon(self) -> str | None: + """Return the first 3 bases from sequence. + + Returns + ------- + str | None + The start codon sequence, or None if sequence is too short. + """ if len(self.seq) < 3: return None return self.seq[:3] @classmethod - def from_string(cls, line): - """ + def from_string(cls, line: str) -> ORF | None: + """Create ORF from a ribotricer index file line. + Parameters ---------- - line: string - line for ribotricer index file generated by prepare_orfs + line : str + Line from ribotricer index file generated by prepare_orfs. + Returns + ------- + ORF | None + Parsed ORF object, or None if parsing fails. - This method uses a fail-fast stragy and hence multiple - returns. 
It ultimately retulrs an object correponding to the + Notes + ----- + This method uses a fail-fast strategy and hence multiple + returns. It ultimately returns an object corresponding to the parsed line. """ if not line: @@ -101,12 +162,12 @@ def from_string(cls, line): strand = fields[8] start_codon = fields[9] coordinate = fields[10] - intervals = [] + intervals: list[Interval] = [] for group in coordinate.split(","): start, end = group.split("-") - start = int(start) - end = int(end) - intervals.append(Interval(chrom, start, end, strand)) + start_pos = int(start) + end_pos = int(end) + intervals.append(Interval(chrom, start_pos, end_pos, strand)) return cls( category, tid, @@ -121,26 +182,50 @@ def from_string(cls, line): ) @classmethod - def from_tracks(cls, tracks, category, seq="", leader="", trailer=""): - """ + def from_tracks( + cls, + tracks: list[GTFTrack], + category: str, + seq: str = "", + leader: str = "", + trailer: str = "", + ) -> ORF | None: + """Create ORF from a list of GTF tracks. + Parameters ---------- - tracks: list of GTFTrack + tracks : list[GTFTrack] + List of GTF track objects. + category : str + ORF category. + seq : str, optional + ORF sequence. + leader : str, optional + 5' UTR sequence. + trailer : str, optional + 3' UTR sequence. + + Returns + ------- + ORF | None + Parsed ORF object, or None if parsing fails. - This method uses a fail-fast stragy and hence multiple - returns. It ultimately retulrs an object correponding to the + Notes + ----- + This method uses a fail-fast strategy and hence multiple + returns. It ultimately returns an object corresponding to the parsed line. 
""" if not tracks: return None - intervals = [] - tid = set() - ttype = set() - gid = set() - gname = set() - gtype = set() - chrom = set() - strand = set() + intervals: list[Interval] = [] + tid: set[str] = set() + ttype: set[str] = set() + gid: set[str] = set() + gname: set[str] = set() + gtype: set[str] = set() + chrom: set[str] = set() + strand: set[str] = set() required_attributes = [ "transcript_id", "transcript_type", @@ -165,7 +250,7 @@ def from_tracks(cls, tracks, category, seq="", leader="", trailer=""): except AttributeError: for attribute in required_attributes: if not hasattr(track, attribute): - print('missing attribute "{}" in {}'.format(attribute, track)) + print(f'missing attribute "{attribute}" in {track}') return None if ( len(tid) != 1 @@ -176,24 +261,24 @@ def from_tracks(cls, tracks, category, seq="", leader="", trailer=""): or len(chrom) != 1 or len(strand) != 1 ): - print("inconsistent tracks for ORF: {}".format(track)) + print(f"inconsistent tracks for ORF: {track}") return None - tid = list(tid)[0] - ttype = list(ttype)[0] - gid = list(gid)[0] - gname = list(gname)[0] - gtype = list(gtype)[0] - chrom = list(chrom)[0] - strand = list(strand)[0] + tid_val = list(tid)[0] + ttype_val = list(ttype)[0] + gid_val = list(gid)[0] + gname_val = list(gname)[0] + gtype_val = list(gtype)[0] + chrom_val = list(chrom)[0] + strand_val = list(strand)[0] return cls( category, - tid, - ttype, - gid, - gname, - gtype, - chrom, - strand, + tid_val, + ttype_val, + gid_val, + gname_val, + gtype_val, + chrom_val, + strand_val, intervals, seq, leader, diff --git a/ribotricer/orf_seq.py b/ribotricer/orf_seq.py index 8a33e5a..c4139b3 100644 --- a/ribotricer/orf_seq.py +++ b/ribotricer/orf_seq.py @@ -2,7 +2,7 @@ # Part of ribotricer software # -# Copyright (C) 2020 Saket Choudhary, Wenzheng Li, and Andrew D Smith +# Copyright (C) 2020-2026 Saket Choudhary, Wenzheng Li, and Andrew D Smith # # This program is free software; you can redistribute it and/or modify # it 
under the terms of the GNU General Public License as published by @@ -14,113 +14,139 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -from .fasta import FastaReader -from .interval import Interval -import pandas as pd +from __future__ import annotations + import sys +from typing import Final + +import pandas as pd from tqdm.autonotebook import tqdm +from .fasta import FastaReader +from .interval import Interval + tqdm.pandas() +# Codon to amino acid translation table +CODON_TABLE: Final[dict[str, str]] = { + "ATA": "I", + "ATC": "I", + "ATT": "I", + "ATG": "M", + "ACA": "T", + "ACC": "T", + "ACG": "T", + "ACT": "T", + "AAC": "N", + "AAT": "N", + "AAA": "K", + "AAG": "K", + "AGC": "S", + "AGT": "S", + "AGA": "R", + "AGG": "R", + "CTA": "L", + "CTC": "L", + "CTG": "L", + "CTT": "L", + "CCA": "P", + "CCC": "P", + "CCG": "P", + "CCT": "P", + "CAC": "H", + "CAT": "H", + "CAA": "Q", + "CAG": "Q", + "CGA": "R", + "CGC": "R", + "CGG": "R", + "CGT": "R", + "GTA": "V", + "GTC": "V", + "GTG": "V", + "GTT": "V", + "GCA": "A", + "GCC": "A", + "GCG": "A", + "GCT": "A", + "GAC": "D", + "GAT": "D", + "GAA": "E", + "GAG": "E", + "GGA": "G", + "GGC": "G", + "GGG": "G", + "GGT": "G", + "TCA": "S", + "TCC": "S", + "TCG": "S", + "TCT": "S", + "TTC": "F", + "TTT": "F", + "TTA": "L", + "TTG": "L", + "TAC": "Y", + "TAT": "Y", + "TAA": "_", + "TAG": "_", + "TGC": "C", + "TGT": "C", + "TGA": "_", + "TGG": "W", +} + -def translate_nt_to_aa(seq): - codon_table = { - "ATA": "I", - "ATC": "I", - "ATT": "I", - "ATG": "M", - "ACA": "T", - "ACC": "T", - "ACG": "T", - "ACT": "T", - "AAC": "N", - "AAT": "N", - "AAA": "K", - "AAG": "K", - "AGC": "S", - "AGT": "S", - "AGA": "R", - "AGG": "R", - "CTA": "L", - "CTC": "L", - "CTG": "L", - "CTT": "L", - "CCA": "P", - "CCC": "P", - "CCG": "P", - "CCT": "P", - "CAC": "H", - "CAT": "H", - "CAA": "Q", - "CAG": "Q", - "CGA": "R", - "CGC": "R", - "CGG": "R", - "CGT": "R", - "GTA": "V", - "GTC": 
"V", - "GTG": "V", - "GTT": "V", - "GCA": "A", - "GCC": "A", - "GCG": "A", - "GCT": "A", - "GAC": "D", - "GAT": "D", - "GAA": "E", - "GAG": "E", - "GGA": "G", - "GGC": "G", - "GGG": "G", - "GGT": "G", - "TCA": "S", - "TCC": "S", - "TCG": "S", - "TCT": "S", - "TTC": "F", - "TTT": "F", - "TTA": "L", - "TTG": "L", - "TAC": "Y", - "TAT": "Y", - "TAA": "_", - "TAG": "_", - "TGC": "C", - "TGT": "C", - "TGA": "_", - "TGG": "W", - } +def translate_nt_to_aa(seq: str) -> str: + """Translate nucleotide sequence to amino acid sequence. + + Parameters + ---------- + seq : str + Nucleotide sequence. + + Returns + ------- + str + Amino acid sequence. + """ protein = "" if len(seq) % 3 == 0: for i in range(0, len(seq), 3): codon = seq[i : i + 3] if "N" in codon: protein += "X" - elif codon not in codon_table: + elif codon not in CODON_TABLE: sys.stderr.write( - "Found unknown codon {}. Substituing with X..\n".format(codon) + f"Found unknown codon {codon}. Substituting with X..\n" ) + protein += "X" else: - protein += codon_table[codon] + protein += CODON_TABLE[codon] return protein -def orf_seq(ribotricer_index, genome_fasta, saveto, translate=False): +def orf_seq( + ribotricer_index: str, + genome_fasta: str, + saveto: str, + translate: bool = False, +) -> None: """Generate sequence for ribotricer annotation. Parameters - ----------- - - ribotricer_index: string - Path to ribotricer generate annotation - genome_Fasta: string - Path to genome fasta - - saveto: string - Path to output + ---------- + ribotricer_index : str + Path to ribotricer generated annotation. + genome_fasta : str + Path to genome fasta. + saveto : str + Path to output file. + translate : bool, optional + Whether to translate to protein sequence, by default False. 
""" fasta = FastaReader(genome_fasta) annotation_df = pd.read_csv(ribotricer_index, sep="\t") + with open(saveto, "w") as fh: fh.write("ORF_ID\tsequence\n") for idx, row in tqdm(annotation_df.iterrows(), total=annotation_df.shape[0]): @@ -128,13 +154,14 @@ def orf_seq(ribotricer_index, genome_fasta, saveto, translate=False): orf_id = row.ORF_ID coordinates = row.coordinate.split(",") strand = row.strand - intervals = [] + intervals: list[Interval] = [] seq = "" + for coordinate in coordinates: start, stop = coordinate.split("-") - start = int(start) - stop = int(stop) - interval = Interval(chrom, start, stop, strand) + start_pos = int(start) + stop_pos = int(stop) + interval = Interval(chrom, start_pos, stop_pos, strand) intervals.append(interval) seq = ("").join(fasta.query(intervals)) @@ -143,10 +170,10 @@ def orf_seq(ribotricer_index, genome_fasta, saveto, translate=False): if translate: if len(seq) % 3 != 0: sys.stderr.write( - "WARNING: Sequence length with ORF ID '{orf_id}' is not " + f"WARNING: Sequence length with ORF ID '{orf_id}' is not " "a multiple of three. Output sequence might be " "truncated.\n" ) seq = seq[0 : (len(seq) // 3) * 3] seq = translate_nt_to_aa(seq) - fh.write("{}\t{}\n".format(orf_id, seq)) + fh.write(f"{orf_id}\t{seq}\n") diff --git a/ribotricer/plotting.py b/ribotricer/plotting.py index 6aa3c00..944a4e4 100644 --- a/ribotricer/plotting.py +++ b/ribotricer/plotting.py @@ -2,7 +2,7 @@ # Part of ribotricer software # -# Copyright (C) 2020 Saket Choudhary, Wenzheng Li, and Andrew D Smith +# Copyright (C) 2020-2026 Saket Choudhary, Wenzheng Li, and Andrew D Smith # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -14,27 +14,42 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. 
-import numpy as np +from __future__ import annotations + import matplotlib +import numpy as np +import pandas as pd matplotlib.use("Agg") -# ADS: verify that matplotlib.use("Agg") must precede imports below -import matplotlib.pyplot as plt # noqa E402 -from matplotlib.backends.backend_pdf import PdfPages # noqa E402 +import matplotlib.pyplot as plt # noqa: E402 +from matplotlib.backends.backend_pdf import PdfPages # noqa: E402 +matplotlib.rcParams["font.family"] = "sans-serif" +matplotlib.rcParams["font.sans-serif"] = [ + "Arial", + "Helvetica", + "Liberation Sans", + "Nimbus Sans", + "FreeSans", + "DejaVu Sans", +] matplotlib.rcParams["pdf.fonttype"] = 42 matplotlib.rcParams["ps.fonttype"] = 42 +# Type alias for metagene data +MetageneDict = dict[int, tuple[pd.Series, pd.Series, float, int, float, int]] + + +def plot_read_lengths(read_lengths: dict[int, int], prefix: str) -> None: + """Plot read length distribution. -def plot_read_lengths(read_lengths, prefix): - """ Parameters ---------- - read_lengths: dict - key is the length, value is the number of reads - prefix: str - prefix for the output file + read_lengths : dict[int, int] + Dictionary where key is the length, value is the number of reads. + prefix : str + Prefix for the output file. """ fig, ax = plt.subplots() x = sorted(read_lengths.keys()) @@ -44,24 +59,33 @@ def plot_read_lengths(read_lengths, prefix): ax.set_ylabel("Number of reads") ax.set_title("Read length distribution") fig.tight_layout() - fig.savefig("{}_read_length_dist.pdf".format(prefix)) + fig.savefig(f"{prefix}_read_length_dist.pdf") plt.close() -def plot_metagene(metagenes, read_lengths, prefix, offset=200): - """ +def plot_metagene( + metagenes: MetageneDict, + read_lengths: dict[int, int], + prefix: str, + offset: int = 200, +) -> None: + """Plot metagene profiles. 
+ Parameters ---------- - metagenes: dict - key is the length, value is the metagene coverage - read_lengths: dict - key is the length, value is the number of reads - prefix: str - prefix for the output file + metagenes : MetageneDict + Dictionary where key is the length, value is the metagene coverage tuple. + read_lengths : dict[int, int] + Dictionary where key is the length, value is the number of reads. + prefix : str + Prefix for the output file. + offset : int, optional + Number of positions to show, by default 200. """ total_reads = sum(read_lengths.values()) frame_colors = ["#fc8d62", "#66c2a5", "#8da0cb"] - with PdfPages("{}_metagene_plots.pdf".format(prefix)) as pdf: + + with PdfPages(f"{prefix}_metagene_plots.pdf") as pdf: for length in sorted(metagenes): # TODO: This only consider the 5' end, should be generalized to 3' metagene_cov_start, metagene_cov_stop, coh, valid, _, _ = metagenes[length] @@ -86,9 +110,7 @@ def plot_metagene(metagenes, read_lengths, prefix, offset=200): ax.set_xlabel("Distance from start codon (nt)") ax.set_ylabel("Normalized mean reads") ax.set_title( - ("{} nt reads, proportion: {:.2%}\nphase_score: {:.2}").format( - length, ratio, coh - ) + f"{length} nt reads, proportion: {ratio:.2%}\nphase_score: {coh:.2}" ) # plot distance from stop codon diff --git a/ribotricer/prepare_orfs.py b/ribotricer/prepare_orfs.py index cb753ad..a6a452e 100644 --- a/ribotricer/prepare_orfs.py +++ b/ribotricer/prepare_orfs.py @@ -2,7 +2,7 @@ # Part of ribotricer software # -# Copyright (C) 2020 Saket Choudhary, Wenzheng Li, and Andrew D Smith +# Copyright (C) 2020-2026 Saket Choudhary, Wenzheng Li, and Andrew D Smith # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -14,41 +14,47 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. 
-from collections import defaultdict +from __future__ import annotations + import datetime import re +from collections import defaultdict + +from tqdm.autonotebook import tqdm from .common import merge_intervals from .fasta import FastaReader -from .gtf import GTFReader +from .gtf import GTFReader, GTFTrack from .interval import Interval from .orf import ORF -from tqdm.autonotebook import tqdm - tqdm.pandas() +# Type alias for CDS ORF dictionary +CDSOrfs = defaultdict[str, defaultdict[str, ORF]] + + +def tracks_to_ivs(tracks: list[GTFTrack]) -> list[Interval]: + """Convert GTF tracks to intervals. -def tracks_to_ivs(tracks): - """ Parameters ---------- - tracks: List[GTFTrack] - list of gtf tracks + tracks : list[GTFTrack] + List of GTF tracks. Returns ------- - intervals: List[Interval] - list of Interval + list[Interval] + List of Interval objects. """ - chrom = {track.chrom for track in tracks} - strand = {track.strand for track in tracks} - if len(chrom) != 1 or len(strand) != 1: + chrom_set = {track.chrom for track in tracks} + strand_set = {track.strand for track in tracks} + if len(chrom_set) != 1 or len(strand_set) != 1: print("fail to fetch seq: inconsistent chrom or strand") - intervals = [] + intervals: list[Interval] = [] else: - chrom = list(chrom)[0] - strand = list(strand)[0] + chrom = list(chrom_set)[0] + strand = list(strand_set)[0] intervals = [ Interval(chrom, track.start, track.end, strand) for track in tracks ] @@ -56,33 +62,36 @@ def tracks_to_ivs(tracks): return intervals -def transcript_to_genome_iv(start, end, intervals, reverse=False): - """ +def transcript_to_genome_iv( + start: int, + end: int, + intervals: list[Interval], + reverse: bool = False, +) -> list[Interval]: + """Convert transcript coordinates to genome coordinates. 
+ Parameters ---------- - start: int - start position in transcript - 0-based closed - end: int - end position in transcript - 0-based closed - intervals: List[Interval] - coordinate in genome - 1-based closed - reverse: bool - whether if it is on the reverse strand + start : int + Start position in transcript (0-based closed). + end : int + End position in transcript (0-based closed). + intervals : list[Interval] + Coordinate in genome (1-based closed). + reverse : bool, optional + Whether on the reverse strand, by default False. Returns ------- - ivs: List[Interval] - the coordinate for start, end in genome + list[Interval] + The coordinate for start, end in genome. """ total_len = sum(i.end - i.start + 1 for i in intervals) if reverse: start, end = total_len - end - 1, total_len - start - 1 - ivs = [] - start_genome = None - end_genome = None + ivs: list[Interval] = [] + start_genome: int | None = None + end_genome: int | None = None # find start in genome cur = 0 @@ -111,19 +120,20 @@ def transcript_to_genome_iv(start, end, intervals, reverse=False): return ivs -def fetch_seq(fasta, tracks): - """ +def fetch_seq(fasta: FastaReader | str, tracks: list[GTFTrack]) -> str: + """Fetch sequence for GTF tracks. + Parameters ---------- - fasta: FastaReader - instance of FastaReader - tracks: List[GTFTrack] - list of gtf track + fasta : FastaReader | str + Instance of FastaReader or path to FASTA file. + tracks : list[GTFTrack] + List of GTF tracks. Returns ------- - merged_seq: str - combined seqeunce for the region + str + Combined sequence for the region. 
""" intervals = tracks_to_ivs(tracks) if not isinstance(fasta, FastaReader): @@ -136,32 +146,36 @@ def fetch_seq(fasta, tracks): return merged_seq -def search_orfs(fasta, intervals, min_orf_length, start_codons, stop_codons, longest): - """ +def search_orfs( + fasta: FastaReader | str, + intervals: list[Interval], + min_orf_length: int, + start_codons: set[str], + stop_codons: set[str], + longest: bool, +) -> list[tuple[list[Interval], str, str, str]]: + """Search for ORFs in intervals. + Parameters ---------- - fasta: FastaReader - instance of FastaReader - intervals: List[Interval] - list of intervals - min_orf_length: int - minimum length (nts) of ORF to include - start_codons: set - set of start codons - stop_codons: set - set of stop codons - longest: bool - whether to choose the most upstream start codon when multiple in - frame ones exist + fasta : FastaReader | str + Instance of FastaReader or path to FASTA file. + intervals : list[Interval] + List of intervals. + min_orf_length : int + Minimum length (nts) of ORF to include. + start_codons : set[str] + Set of start codons. + stop_codons : set[str] + Set of stop codons. + longest : bool + Whether to choose the most upstream start codon when multiple + in-frame ones exist. Returns ------- - orfs: list - list of (List[Interval], seq, leader, trailer) - list of intervals for candidate ORF - seq: sequence for the candidate ORF - leader: sequence upstream of the ORF - trailer: sequence downstream of the ORF + list[tuple[list[Interval], str, str, str]] + List of (intervals, seq, leader, trailer) tuples. """ if not intervals: return [] @@ -215,22 +229,24 @@ def search_orfs(fasta, intervals, min_orf_length, start_codons, stop_codons, lon return orfs -def check_orf_type(orf, cds_orfs): - """ +def check_orf_type(orf: ORF, cds_orfs: CDSOrfs) -> str: + """Determine the ORF type relative to annotated CDS. 
+ Parameters ---------- - orf: GTFReader - instance of GTFReader - cds_orfs: FastaReader - instance of FastaReader + orf : ORF + Instance of ORF. + cds_orfs : CDSOrfs + Dictionary of annotated CDS ORFs. Returns ------- - otype: str - Type of the candidate ORF + str + Type of the candidate ORF. - This method uses a fail-fast strategy - and hence multiple returns. + Notes + ----- + This method uses a fail-fast strategy and hence multiple returns. """ if orf.gid not in cds_orfs: return "novel" @@ -260,26 +276,33 @@ def check_orf_type(orf, cds_orfs): def prepare_orfs( - gtf, fasta, prefix, min_orf_length, start_codons, stop_codons, longest -): - """ + gtf: GTFReader | str, + fasta: FastaReader | str, + prefix: str, + min_orf_length: int, + start_codons: set[str], + stop_codons: set[str], + longest: bool, +) -> None: + """Prepare candidate ORFs from GTF and FASTA files. + Parameters ---------- - gtf: GTFReader - instance of GTFReader - fasta: FastaReader - instance of FastaReader - prefix: str - prefix for output file - min_orf_length: int - minimum length (nts) of ORF to include - start_codons: set - set of start codons - stop_codons: set - set of stop codons - longest: bool - whether to choose the most upstream start codon when multiple in - frame ones exist + gtf : GTFReader | str + Instance of GTFReader or path to GTF file. + fasta : FastaReader | str + Instance of FastaReader or path to FASTA file. + prefix : str + Prefix for output file. + min_orf_length : int + Minimum length (nts) of ORF to include. + start_codons : set[str] + Set of start codons. + stop_codons : set[str] + Set of stop codons. + longest : bool + Whether to choose the most upstream start codon when multiple + in-frame ones exist. 
""" now = datetime.datetime.now() @@ -361,9 +384,7 @@ def prepare_orfs( to_write = "\t".join(columns) formatter = "{}\t" * (len(columns) - 1) + "{}\n" for orf in tqdm(candidate_orfs, unit="ORFs"): - coordinate = ",".join( - ["{}-{}".format(iv.start, iv.end) for iv in orf.intervals] - ) + coordinate = ",".join([f"{iv.start}-{iv.end}" for iv in orf.intervals]) if orf.start_codon in start_codons: to_write += formatter.format( orf.oid, @@ -379,7 +400,7 @@ def prepare_orfs( coordinate, ) - with open("{}_candidate_orfs.tsv".format(prefix), "w") as output: + with open(f"{prefix}_candidate_orfs.tsv", "w") as output: output.write(to_write) now = datetime.datetime.now() print(now.strftime("%b %d %H:%M:%S ... finished ribotricer prepare-orfs")) diff --git a/ribotricer/statistics.py b/ribotricer/statistics.py index 81c4619..d05a1ba 100644 --- a/ribotricer/statistics.py +++ b/ribotricer/statistics.py @@ -2,7 +2,7 @@ # Part of ribotricer software # -# Copyright (C) 2020 Saket Choudhary, Wenzheng Li, and Andrew D Smith +# Copyright (C) 2020-2026 Saket Choudhary, Wenzheng Li, and Andrew D Smith # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -14,56 +14,59 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -from math import sin, cos, pi, sqrt +from __future__ import annotations + import warnings +from collections.abc import Sequence +from math import cos, pi, sin, sqrt import numpy as np -from scipy import stats -from scipy import signal +from numpy.typing import ArrayLike +from scipy import signal, stats -def pvalue(x, N): - """Calculate p-value for phase score +def pvalue(x: float, N: int) -> float: + """Calculate p-value for phase score. Parameters ---------- - x: double - phase score - N: int - number of valid codons + x : float + Phase score. + N : int + Number of valid codons. 
Returns ------- - pval: double - p-value for the phase score + float + P-value for the phase score. """ df, nc = 2, 2.0 / (N - 1) x = 2 * N**2 * x / (N - 1) - return stats.ncx2.sf(x, df, nc) + return float(stats.ncx2.sf(x, df, nc)) -def phasescore(original_values): +def phasescore(original_values: Sequence[float] | ArrayLike) -> tuple[float, int]: """Calculate phase score of a given signal. Parameters ---------- - values : array like - List of value + original_values : Sequence[float] | ArrayLike + List of coverage values. Returns ------- - coh : float - Periodicity score calculated as - coherence between input and idea 1-0-0 signal - - valid: int - number of valid codons used for calculation - + tuple[float, int] + Tuple of (periodicity_score, valid_codons). + - coh: Periodicity score calculated as coherence between + input and ideal 1-0-0 signal. + - valid: Number of valid codons used for calculation. """ - coh, valid = 0.0, -1 + coh: float = 0.0 + valid: int = -1 + for frame in [0, 1, 2]: - values = original_values[frame:] - normalized_values = [] + values = list(original_values)[frame:] + normalized_values: list[float] = [] i = 0 while i + 2 < len(values): if values[i] == values[i + 1] == values[i + 2] == 0: @@ -91,12 +94,12 @@ def phasescore(original_values): if length == 0: coh, valid = (0.0, 0) else: - normalized_values = np.array(normalized_values[:length]) + normalized_arr = np.array(normalized_values[:length]) uniform_signal = np.array([1, 0, 0] * (len(normalized_values) // 3)) with warnings.catch_warnings(): warnings.simplefilter("ignore") f, Cxy = signal.coherence( - normalized_values, + normalized_arr, uniform_signal, window=np.array([1.0, 1.0, 1.0]), nperseg=3, @@ -108,4 +111,5 @@ def phasescore(original_values): valid = length // 3 if valid == -1: valid = length // 3 + return np.sqrt(coh), valid diff --git a/ribotricer/utils.py b/ribotricer/utils.py index 563ca44..76a9fc2 100644 --- a/ribotricer/utils.py +++ b/ribotricer/utils.py @@ -2,7 +2,7 @@ # 
Part of ribotricer software # -# Copyright (C) 2020 Saket Choudhary, Wenzheng Li, and Andrew D Smith +# Copyright (C) 2020-2026 Saket Choudhary, Wenzheng Li, and Andrew D Smith # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -14,15 +14,20 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. +from __future__ import annotations + from collections import defaultdict -from .statistics import phasescore +from typing import Final import numpy as np +from numpy.typing import NDArray from tqdm.autonotebook import tqdm +from .statistics import phasescore + tqdm.pandas() -CODON_TO_AA = { +CODON_TO_AA: Final[dict[str, str]] = { "ATA": "I", "ATC": "I", "ATT": "I", @@ -90,23 +95,24 @@ } -def parse_ccds(annotation, orfs, saveto): - """ +def parse_ccds(annotation: str, orfs: str, saveto: str) -> None: + """Parse CCDS annotations. + Parameters ---------- - annotation: str - Path for annotation files of candidate ORFs - orfs: str - Path for translating ORFs - saveto: str - output file name + annotation : str + Path for annotation files of candidate ORFs. + orfs : str + Path for translating ORFs. + saveto : str + Output file name. 
""" anno_oids = [] real_oids = [] ccds = defaultdict(list) - with open(annotation, "r") as anno: + with open(annotation) as anno: total_lines = len(["" for line in anno]) - with open(annotation, "r") as anno: + with open(annotation) as anno: with tqdm(total=total_lines) as pbar: # Skip header anno.readline() @@ -125,9 +131,9 @@ def parse_ccds(annotation, orfs, saveto): anno_oids.append(oid) ccds_orfs = {} - with open(orfs, "r") as orf: + with open(orfs) as orf: total_lines = len(["" for line in orf]) - with open(orfs, "r") as orf: + with open(orfs) as orf: with tqdm(total=total_lines) as pbar: # Skip header orf.readline() @@ -147,7 +153,7 @@ def parse_ccds(annotation, orfs, saveto): ccds_orfs[oid] = (count, corr, pval) real_oids.append(oid) - rename = {x: y for (x, y) in zip(anno_oids, real_oids)} + rename = dict(zip(anno_oids, real_oids)) to_write = "Gene_ID\tCount\tPeriodicity\tPval\n" n_genes = 0 for gid in ccds: @@ -158,21 +164,38 @@ def parse_ccds(annotation, orfs, saveto): t_cnt, t_corr, t_pval = ccds_orfs[oid] if t_corr >= corr: count, corr, pval = (t_cnt, t_corr, t_pval) - to_write += "{}\t{}\t{}\t{}\n".format(gid, count, corr, pval) + to_write += f"{gid}\t{count}\t{corr}\t{pval}\n" with open(saveto, "w") as output: output.write(to_write) -def benchmark(rna_file, ribo_file, prefix, cutoff=5): +def benchmark( + rna_file: str, + ribo_file: str, + prefix: str, + cutoff: int = 5, +) -> None: + """Benchmark RNA vs Ribo profiles. - rna = {} - ribo = {} + Parameters + ---------- + rna_file : str + Path to RNA profile file. + ribo_file : str + Path to Ribo profile file. + prefix : str + Output prefix. + cutoff : int, optional + Minimum coverage cutoff, by default 5. 
+ """ + rna: dict[str, list[int]] = {} + ribo: dict[str, list[int]] = {} print("reading RNA profiles") - with open(rna_file, "r") as orf: + with open(rna_file) as orf: total_lines = len(["" for line in orf]) - with open(rna_file, "r") as orf: + with open(rna_file) as orf: with tqdm(total=total_lines) as pbar: for line in orf: pbar.update() @@ -184,9 +207,9 @@ def benchmark(rna_file, ribo_file, prefix, cutoff=5): rna[ID] = cov print("reading Ribo profiles") - with open(ribo_file, "r") as orf: + with open(ribo_file) as orf: total_lines = len(["" for line in orf]) - with open(ribo_file, "r") as orf: + with open(ribo_file) as orf: with tqdm(total=total_lines) as pbar: for line in orf: pbar.update() @@ -206,15 +229,27 @@ def benchmark(rna_file, ribo_file, prefix, cutoff=5): ribo_coh, ribo_valid = phasescore(ribo[ID]) ribo_cov = ribo_valid / len(ribo[ID]) - to_write += "{}\t{}\t{}\t{}\t{}\n".format( - ID, ribo_coh, rna_coh, ribo_cov, rna_cov - ) - with open("{}_results.txt".format(prefix), "w") as output: + to_write += f"{ID}\t{ribo_coh}\t{rna_coh}\t{ribo_cov}\t{rna_cov}\n" + with open(f"{prefix}_results.txt", "w") as output: output.write(to_write) -def angle(cov, frame): - ans = [] +def angle(cov: list[int], frame: int) -> tuple[list[float], int]: + """Compute angles for coverage profile. + + Parameters + ---------- + cov : list[int] + Coverage profile. + frame : int + Frame offset. + + Returns + ------- + tuple[list[float], int] + Tuple of (angles, number of zero vectors). + """ + ans: list[float] = [] nzeros = 0 cov = cov[frame:] i = 0 @@ -233,16 +268,36 @@ def angle(cov, frame): return ans, nzeros -def theta_dist(rna_file, ribo_file, frame_file, prefix, cutoff=5): +def theta_dist( + rna_file: str, + ribo_file: str, + frame_file: str, + prefix: str, + cutoff: int = 5, +) -> None: + """Compute theta distribution from RNA and Ribo profiles. - rna = {} - ribo = {} - frame = {} + Parameters + ---------- + rna_file : str + Path to RNA profile file. 
+ ribo_file : str + Path to Ribo profile file. + frame_file : str + Path to frame file. + prefix : str + Output prefix. + cutoff : int, optional + Minimum coverage cutoff, by default 5. + """ + rna: dict[str, list[int]] = {} + ribo: dict[str, list[int]] = {} + frame: dict[str, int] = {} print("reading frame file") - with open(frame_file, "r") as frame_r: + with open(frame_file) as frame_r: total_lines = len(["" for line in frame_r]) - with open(frame_file, "r") as frame_r: + with open(frame_file) as frame_r: with tqdm(total=total_lines) as pbar: for line in frame_r: pbar.update() @@ -250,9 +305,9 @@ def theta_dist(rna_file, ribo_file, frame_file, prefix, cutoff=5): frame[name] = int(frame_n) print("reading RNA profiles") - with open(rna_file, "r") as orf: + with open(rna_file) as orf: total_lines = len(["" for line in orf]) - with open(rna_file, "r") as orf: + with open(rna_file) as orf: with tqdm(total=total_lines) as pbar: for line in orf: pbar.update() @@ -264,9 +319,9 @@ def theta_dist(rna_file, ribo_file, frame_file, prefix, cutoff=5): rna[ID] = cov print("reading Ribo profiles") - with open(ribo_file, "r") as orf: + with open(ribo_file) as orf: total_lines = len(["" for line in orf]) - with open(ribo_file, "r") as orf: + with open(ribo_file) as orf: with tqdm(total=total_lines) as pbar: for line in orf: pbar.update() @@ -302,31 +357,41 @@ def theta_dist(rna_file, ribo_file, frame_file, prefix, cutoff=5): mean = total_reads / total_length poisson_cov = np.random.poisson(mean, total_length) poisson_angles, poisson_zeros = angle(poisson_cov, 0) - with open("{}_angle_stats.txt".format(prefix), "w") as output: - output.write("total_rna_reads: {}\n".format(total_reads)) - output.write("total_rna_ccds_length: {}\n".format(total_length)) - output.write("total_ribo_reads: {}\n".format(total_ribo_reads)) - output.write("total_ribo_ccds_length: {}\n".format(total_ribo_length)) - output.write("mean reads: {}\n".format(mean)) - output.write("rna zero vectors: 
{}\n".format(rna_zeros)) - output.write("poisson zero vectors: {}\n".format(poisson_zeros)) - output.write("ribo zero vectors: {}\n".format(ribo_zeros)) - with open("{}_rna_angles.txt".format(prefix), "w") as output: + with open(f"{prefix}_angle_stats.txt", "w") as output: + output.write(f"total_rna_reads: {total_reads}\n") + output.write(f"total_rna_ccds_length: {total_length}\n") + output.write(f"total_ribo_reads: {total_ribo_reads}\n") + output.write(f"total_ribo_ccds_length: {total_ribo_length}\n") + output.write(f"mean reads: {mean}\n") + output.write(f"rna zero vectors: {rna_zeros}\n") + output.write(f"poisson zero vectors: {poisson_zeros}\n") + output.write(f"ribo zero vectors: {ribo_zeros}\n") + with open(f"{prefix}_rna_angles.txt", "w") as output: output.write("\n".join(map(str, rna_angles))) - with open("{}_ribo_angles.txt".format(prefix), "w") as output: + with open(f"{prefix}_ribo_angles.txt", "w") as output: output.write("\n".join(map(str, ribo_angles))) - with open("{}_poisson_angles.txt".format(prefix), "w") as output: + with open(f"{prefix}_poisson_angles.txt", "w") as output: output.write("\n".join(map(str, poisson_angles))) -def theta_rna(rna_file, prefix, cutoff=10): +def theta_rna(rna_file: str, prefix: str, cutoff: int = 10) -> None: + """Compute theta distribution from RNA profiles. - rna = {} + Parameters + ---------- + rna_file : str + Path to RNA profile file. + prefix : str + Output prefix. + cutoff : int, optional + Minimum coverage cutoff, by default 10. 
+ """ + rna: dict[str, list[int]] = {} print("reading RNA profiles") - with open(rna_file, "r") as orf: + with open(rna_file) as orf: total_lines = len(["" for line in orf]) - with open(rna_file, "r") as orf: + with open(rna_file) as orf: with tqdm(total=total_lines) as pbar: # Skip header orf.readline() @@ -343,32 +408,47 @@ def theta_rna(rna_file, prefix, cutoff=10): rna_angles = [] for ID in tqdm(list(rna.keys())): rna_angles += angle(rna[ID], 0) - with open("{}_raw_rna_angles.txt".format(prefix), "w") as output: + with open(f"{prefix}_raw_rna_angles.txt", "w") as output: output.write("\n".join(map(str, rna_angles))) -def _nucleotide_to_codon_profile(profile): - """Summarize nucleotid profile to a codon level profile""" +def _nucleotide_to_codon_profile( + profile: str | list[int], +) -> NDArray[np.int64]: + """Summarize nucleotide profile to a codon level profile. + + Parameters + ---------- + profile : str | list[int] + Nucleotide profile as string or list. + + Returns + ------- + NDArray[np.int64] + Codon level profile. + """ if isinstance(profile, str): profile = eval(profile) - profile = np.array(profile) - codon_profile = np.add.reduceat(profile, range(0, len(profile), 3)) + profile_arr = np.array(profile) + codon_profile: NDArray[np.int64] = np.add.reduceat( + profile_arr, range(0, len(profile_arr), 3) + ) return codon_profile -def summarize_profile_to_codon_level(detected_orfs, saveto): - """Collapse nucleotide level profiles in ribotricer to codon leve. +def summarize_profile_to_codon_level(detected_orfs: str, saveto: str) -> None: + """Collapse nucleotide level profiles in ribotricer to codon level. Parameters ---------- - ribotricer_output: string - Path to ribotricer detect-orfs output - saveto: string - Path to write output to + detected_orfs : str + Path to ribotricer detect-orfs output. + saveto : str + Path to write output to. 
""" with open(saveto, "w") as fout: fout.write("ORF_ID\tcodon_profile\n") - with open(detected_orfs, "r") as fin: + with open(detected_orfs) as fin: # Skip header fin.readline() for line in fin: @@ -381,23 +461,22 @@ def summarize_profile_to_codon_level(detected_orfs, saveto): if profile_stripped[0]: profile = np.array(list(map(int, profile_stripped))) codon_profile = np.add.reduceat(profile, range(0, len(profile), 3)) - fout.write("{}\t{}\n".format(oid, list(codon_profile))) + fout.write(f"{oid}\t{list(codon_profile)}\n") -def translate(seq): - """Translate a given nucleotide sequence to an amino acid sequence +def translate(seq: str) -> str: + """Translate a given nucleotide sequence to an amino acid sequence. Parameters ---------- - seq: str - Nucleotide seqeunce + seq : str + Nucleotide sequence. Returns ------- - protein: str - Translated sequence of amino acids + str + Translated sequence of amino acids. """ - protein = "" if len(seq) % 3 == 0: for i in range(0, len(seq), 3): @@ -406,41 +485,38 @@ def translate(seq): return protein -def learn_ribotricer_cutoff(roc_input_file): - """Learn ribotricer phase score cutoff +def learn_ribotricer_cutoff(roc_input_file: str) -> tuple[float, float]: + """Learn ribotricer phase score cutoff. Parameters ---------- - roc_input_file: str - Path to ROC file generated using ribotricer benchmark + roc_input_file : str + Path to ROC file generated using ribotricer benchmark. Returns ------- - cutoff: float - Recommended phase score cutoff - - fscore: float - Corresponding F1 score acheived at the determined cutoff + tuple[float, float] + Tuple of (cutoff, fscore). 
""" - from sklearn.metrics import precision_recall_fscore_support import pandas as pd + from sklearn.metrics import precision_recall_fscore_support data = pd.read_csv(roc_input_file, sep="\t") ribotricer_scores = data.ribotricer truth = data.truth - precision_recall_fscore_support_df = [] + precision_recall_fscore_support_list: list[list[float]] = [] cutoffs = np.linspace(0, 1, 1000) - for cutoff in cutoffs: - predicted = np.where(ribotricer_scores > cutoff, 1, 0) + for cutoff_val in cutoffs: + predicted = np.where(ribotricer_scores > cutoff_val, 1, 0) s = precision_recall_fscore_support( truth, predicted, average="binary", pos_label=1 ) - precision_recall_fscore_support_df.append([cutoff] + list(s)) + precision_recall_fscore_support_list.append([cutoff_val] + list(s)) precision_recall_fscore_support_df = pd.DataFrame( - precision_recall_fscore_support_df, - columns=["cutoff", "precision", "recall", "fscore", "cutoff"], + precision_recall_fscore_support_list, + columns=["cutoff", "precision", "recall", "fscore", "cutoff2"], ) cutoff = precision_recall_fscore_support_df.loc[ precision_recall_fscore_support_df["fscore"].idxmax() diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 80ada03..0000000 --- a/setup.cfg +++ /dev/null @@ -1,34 +0,0 @@ -[bumpversion] -current_version = 1.4.0 -commit = True -tag = False -parse = (?P\d+)\.(?P\d+)\.(?P\d+)(\-(?P[a-z]+)(?P\d+))? 
-serialize = - {major}.{minor}.{patch}-{release}{build} - {major}.{minor}.{patch} - -[bumpversion:part:release] -optional_value = prod -first_value = dev -values = - dev - prod - -[bumpversion:part:build] - -[bumpversion:file:setup.py] -search = version="{current_version}" -replace = version="{new_version}" - -[bumpversion:file:ribotricer/__init__.py] -search = __version__ = "{current_version}" -replace = __version__ = "{new_version}" - -[flake8] -exclude = docs - -[aliases] -test = pytest - -[tool:pytest] -collect_ignore = ["setup.py"] diff --git a/setup.py b/setup.py index 1de0a99..d0e570a 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,7 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- # Part of ribotricer software # -# Copyright (C) 2020 Saket Choudhary, Wenzheng Li, and Andrew D Smith +# Copyright (C) 2020-2026 Saket Choudhary, Wenzheng Li, and Andrew D Smith # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -14,44 +13,15 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -import setuptools +""" +Minimal setup.py for backward compatibility. + +All package configuration is now in pyproject.toml. +This file is kept for compatibility with older pip versions +and editable installs. 
+""" -with open("README.md") as readme_file: - readme = readme_file.read() -with open("requirements.txt") as req_file: - requirements = req_file.read() +import setuptools -setuptools.setup( - name="ribotricer", - version="1.4.0", - author="Saket Choudhary, Wenzheng Li", - author_email="saketkc@gmail.com", - maintainer="Saket Choudhary", - maintainer_email="saketkc@gmail.com", - description="Python package to detect translating ORFs from Ribo-seq data", - license="GPLv3", - long_description=readme, - long_description_content_type="text/markdown", - url="https://github.com/smithlabcode/ribotricer", - packages=setuptools.find_packages(), - entry_points={"console_scripts": ["ribotricer=ribotricer.cli:cli"]}, - python_requires=">=3.7", - install_requires=requirements, - classifiers=[ - "Development Status :: 5 - Production/Stable", - "Environment :: Console", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", - "Natural Language :: English", - "Operating System :: POSIX :: Linux", - "Operating System :: MacOS", - "Operating System :: Microsoft :: Windows", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Topic :: Scientific/Engineering :: Bio-Informatics", - "Topic :: Utilities", - ], -) +if __name__ == "__main__": + setuptools.setup() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..76f838e --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test suite for ribotricer.""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..cce0464 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,23 @@ +"""Pytest configuration and fixtures for ribotricer tests.""" + +from __future__ import annotations + +import pytest + +from ribotricer.interval import Interval + + +@pytest.fixture 
+def sample_interval() -> Interval: + """Create a sample interval for testing.""" + return Interval("chr1", 100, 200, "+") + + +@pytest.fixture +def sample_intervals() -> list[Interval]: + """Create a list of sample intervals for testing.""" + return [ + Interval("chr1", 100, 200, "+"), + Interval("chr1", 300, 400, "+"), + Interval("chr1", 500, 600, "+"), + ] diff --git a/tests/test_interval.py b/tests/test_interval.py new file mode 100644 index 0000000..0b82d28 --- /dev/null +++ b/tests/test_interval.py @@ -0,0 +1,49 @@ +"""Tests for the interval module.""" + +from ribotricer.interval import Interval + + +class TestInterval: + """Tests for the Interval class.""" + + def test_interval_creation(self): + """Test basic interval creation.""" + interval = Interval("chr1", 100, 200, "+") + assert interval.chrom == "chr1" + assert interval.start == 100 + assert interval.end == 200 + assert interval.strand == "+" + + def test_interval_length(self): + """Test interval length calculation.""" + interval = Interval("chr1", 100, 200, "+") + assert len(interval) == 101 # 1-based, closed interval + + def test_interval_equality(self): + """Test interval equality comparison.""" + int1 = Interval("chr1", 100, 200, "+") + int2 = Interval("chr1", 100, 200, "+") + int3 = Interval("chr1", 100, 200, "-") + + assert int1 == int2 + assert int1 != int3 + + def test_interval_str_representation(self): + """Test string representation of interval.""" + interval = Interval("chr1", 100, 200, "+") + str_repr = str(interval) + assert "chr1" in str_repr + assert "100" in str_repr + assert "200" in str_repr + + def test_interval_hash(self): + """Test that intervals can be hashed (for use in sets/dicts).""" + int1 = Interval("chr1", 100, 200, "+") + int2 = Interval("chr1", 100, 200, "+") + + # Same intervals should have same hash + assert hash(int1) == hash(int2) + + # Should be usable in a set + interval_set = {int1, int2} + assert len(interval_set) == 1 diff --git a/tests/test_orf.py 
b/tests/test_orf.py new file mode 100644 index 0000000..801566c --- /dev/null +++ b/tests/test_orf.py @@ -0,0 +1,104 @@ +"""Tests for the ORF module.""" + +from __future__ import annotations + +from ribotricer.interval import Interval +from ribotricer.orf import ORF + + +class TestORF: + """Tests for the ORF class.""" + + def test_orf_creation(self) -> None: + """Test basic ORF creation.""" + intervals = [ + Interval("chr1", 100, 200, "+"), + Interval("chr1", 300, 400, "+"), + ] + orf = ORF( + category="annotated", + transcript_id="tx1", + transcript_type="protein_coding", + gene_id="gene1", + gene_name="Gene1", + gene_type="protein_coding", + chrom="chr1", + strand="+", + intervals=intervals, + seq="ATGAAA", + ) + + assert orf.tid == "tx1" + assert orf.gid == "gene1" + assert orf.category == "annotated" + assert orf.start_codon == "ATG" + assert len(orf.intervals) == 2 + + def test_orf_strand(self) -> None: + """Test ORF strand property.""" + intervals = [Interval("chr1", 100, 200, "+")] + orf = ORF( + category="annotated", + transcript_id="tx1", + transcript_type="protein_coding", + gene_id="gene1", + gene_name="Gene1", + gene_type="protein_coding", + chrom="chr1", + strand="+", + intervals=intervals, + seq="ATG", + ) + + assert orf.strand == "+" + + def test_orf_chrom(self) -> None: + """Test ORF chromosome property.""" + intervals = [Interval("chr1", 100, 200, "+")] + orf = ORF( + category="annotated", + transcript_id="tx1", + transcript_type="protein_coding", + gene_id="gene1", + gene_name="Gene1", + gene_type="protein_coding", + chrom="chr1", + strand="+", + intervals=intervals, + seq="ATG", + ) + + assert orf.chrom == "chr1" + + def test_orf_from_string(self) -> None: + """Test ORF parsing from index file line.""" + line = ( + "tx1_100_200_101\tannotated\ttx1\tprotein_coding\t" + "gene1\tGene1\tprotein_coding\tchr1\t+\tATG\t100-200" + ) + orf = ORF.from_string(line) + + assert orf is not None + assert orf.tid == "tx1" + assert orf.category == "annotated" + 
assert orf.start_codon == "ATG" + assert orf.chrom == "chr1" + assert orf.strand == "+" + + def test_orf_start_codon_short_seq(self) -> None: + """Test that start_codon returns None for short sequences.""" + intervals = [Interval("chr1", 100, 102, "+")] + orf = ORF( + category="annotated", + transcript_id="tx1", + transcript_type="protein_coding", + gene_id="gene1", + gene_name="Gene1", + gene_type="protein_coding", + chrom="chr1", + strand="+", + intervals=intervals, + seq="AT", # Only 2 bases + ) + + assert orf.start_codon is None