diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..3365cac --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,26 @@ +name: tests + +on: + push: + branches: [main] + pull_request: + +jobs: + unit-tests: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.12"] + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Byte-compile common.smk + # common.smk is plain Python and carries the memory/length logic. + run: python -m py_compile workflow/rules/common.smk + - name: Run resource/length unit tests + # The suite loads common.smk directly and needs no third-party deps; + # the SLURM-plugin integration test self-skips when the plugin is absent. + run: python test/test_memory_resources.py diff --git a/README.md b/README.md index 5db0fe0..efff6c1 100644 --- a/README.md +++ b/README.md @@ -241,6 +241,27 @@ you hit these. - **Restrict to one model** with `structure_inference_gpu_model` (e.g. `"A100"`) → the plugin emits `--gpus=:`. Accepts a single model name; leave `""` for any. +- **Route by complex size (VRAM)** with `structure_inference_gpu_tiers` → list your GPU pool as + tiers of `{min_vram_gb, nodes}`. A complex's estimated peak VRAM (≈ `per_token_sq·N²`) selects the + smallest tier that fits and all *smaller*-GPU nodes are excluded, so the job runs on **any** GPU at + or above that tier — using the whole pool, not one pinned model. A complex larger than every tier + uses the biggest tier and spills to host RAM via unified memory. + + ```yaml + # Example for EMBL gpu-el8 — replace nodes with your cluster's (nothing is hard-coded): + structure_inference_gpu_vram_headroom: 1.0 # <1.0 tolerates that fraction of host spill + structure_inference_gpu_tiers: + - {min_vram_gb: 24, nodes: "gpu21,gpu22,gpu29,gpu30,gpu31,gpu32,gpu33,gpu34,gpu35,gpu36,gpu37"} + - {min_vram_gb: 40, nodes: "gpu25,gpu26,gpu27,gpu28"} + - {min_vram_gb: 48, nodes: "gpu40,gpu41,gpu42,gpu43,gpu44,gpu45,gpu46,gpu47,gpu48"} + - {min_vram_gb: 80, nodes: "gpu38,gpu39"} + ``` + + When set this drives `--exclude` per job and **overrides** `structure_inference_gpu_model` (the two + would conflict). It's the practical "fit to GPU" lever: requested host RAM is a separate pool and + does not size GPU VRAM, but excluding too-small GPUs by length does. Use explicit comma node lists + (bracket ranges may be glob-expanded by the shell). Multi-partition routing (e.g. EMBL's bigger + `gpu-training` cards) is out of scope — keep one partition and let unified memory spill the tail. - **Exclude specific nodes** with `slurm_exclude_nodes` → passed verbatim to `sbatch --exclude` (e.g. `"gpu50,gpu51"`). Use it for nodes whose GPU the container can't use — e.g. a CUDA compute capability newer than the container's bundled `ptxas` (fails `ptxas too old` / `UNIMPLEMENTED`). @@ -301,6 +322,93 @@ exactly what the fraction is sized against. +
+Length-aware memory requests (sized automatically from the input sequences) + +Host RAM for both compute stages is requested **from the input sequence length**, so big +complexes get enough memory on the first attempt instead of failing and climbing the retry +ladder, while small jobs are not over-provisioned. The request is computed at scheduling +time by reading the per-chain FASTA(s) the pipeline already stages under +`/data/`: + +``` +create_features mem = safety * (feature_create_ram_bytes + per_residue * seq_len) +structure_inference mem = safety * (structure_inference_ram_bytes + per_token_sq * N^2) +``` + +- `seq_len` is the query length; `N` is the **total residues of the complex** (the + AlphaFold token count, summed over chains and copy numbers). AlphaFold's pair + representation is `O(N^2)`, hence the quadratic inference term. +- **The coefficients default by backend** (selected from `--data_pipeline` / `--fold_backend`). + AlphaFold-Multimer (AF2) is heavier than AlphaFold 3 — measured AF2 inference host RSS was + ~4× higher than AF3 at the same complex size, and AF2's feature stage runs HHblits (the + main OOM source), whereas the AF3 pipeline is lighter. Defaults: + + | backend | feature base | feature /residue | inference base | inference /N² | + |---|---|---|---|---| + | `alphafold2` | 64000 MB | 40 MB | 24000 MB | 0.0055 | + | `alphafold3` | 40000 MB | 25 MB | 16000 MB | 0.0045 | + + The AF3 inference quadratic is sized to the observed GPU-VRAM demand so that, with unified + memory, the host spill ceiling (`host_mem / gpu_vram`) covers large complexes instead of + OOM-ing. +- The first attempt already includes `mem_safety_factor` (default `1.25`) of head-room. + **OOM retries still escalate** on top, multiplying by `..._ram_scaling ** (attempt - 1)`, + so a bad estimate self-heals. +- Override any backend default by setting the matching key in `config/config.yaml` + (`feature_create_ram_bytes`, `feature_create_ram_per_residue_mb`, + `structure_inference_ram_bytes`, `structure_inference_ram_per_token_sq_mb`); an explicit + value applies to all backends. Also tune `mem_safety_factor`, the `..._ram_scaling` + factors, `structure_inference_runtime_minutes`, and `max_mem_mb` (set it to your largest + node's RAM where an over-estimate would otherwise never schedule; `0` = no cap). +- The `..._ram_bytes` keys are the **fixed base** of each model rather than a flat request; + raising a base only raises the floor. Setting `per_residue`/`per_token_sq` to `0` + reproduces the old length-blind behaviour (a flat base × retry scaling). +- **Precomputed features:** when a chain is supplied via `feature_directory`, no + `data/.fasta` is generated. Length is then recovered from the precomputed + `_af3_input.json` (AF3) or from the parse-time length cache written by the length + filter below (covers AF2 too). If neither is available the job falls back to the base + allocation plus retry escalation. AF3 ligand atoms are not counted (no sequence), a small + undercount absorbed by the safety margin. + +
+ +
+Skipping over-large complexes (length filtering) + +Folds that are too large to be worth submitting are **skipped before any job is created**, +so a single oversized complex (or one giant chain) doesn't waste a GPU/feature allocation +that will only OOM. Two configurable limits (in `config/config.yaml`): + +```yaml +# Max TOTAL complex length (sum of all chains), per backend — selected by --fold_backend. +max_total_length_alphafold2: 5000 # AF2-Multimer +max_total_length_alphafold3: 7000 # AF3 handles larger inputs +# max_total_length: 6000 # optional single override for both backends +# Max length of any SINGLE protein; 0 = off (issue #33). A protein over this drops every +# fold containing it, so it is never even downloaded. +max_protein_length: 0 +length_filter_fetch_uniprot: true # set false for fully offline runs +``` + +- Lengths are resolved at **parse time** from, in order: a local FASTA, an + already-downloaded `data/.fasta`, the persistent cache + `/.sequence_lengths.tsv`, and finally the UniProt REST API (cached for + next time). Set a limit to `0` to disable it; if both are `0`, no resolution/fetching + happens at all. +- Skipped folds are listed with reasons in `/skipped_folds.tsv` and logged + as a `[length-filter]` warning. **Unknown lengths fail open** (the fold is kept), so a + UniProt outage never silently drops work. +- First parse of a large all-UniProt sheet will fetch each unique length once (cached + afterwards); already-downloaded inputs and local FASTAs are read without any network call. +- **Applies to every profile, including local/workstation runs** (it runs during workflow + parsing, not in the executor). It's the only length-aware feature that does — the memory + and GPU-routing settings are SLURM resources that local runs ignore. To attempt a complex + larger than the caps on a big workstation, raise or zero the `max_total_length_*` values + (and set `length_filter_fetch_uniprot: false` for offline use). + +
+ ### Using precomputed features If you have precomputed protein features, specify the directory: diff --git a/config/config.yaml b/config/config.yaml index 430081f..68662e5 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -52,10 +52,65 @@ analyze_structure_arguments: # Memory allocation for feature creation and structure inference. # NOTE: despite the "_bytes" suffix these values are in MEGABYTES (used directly as -# the SLURM --mem request), so 64000 = 64000 MB ~= 64 GB. They scale with retries. -feature_create_ram_bytes: 64000 # MB -feature_create_ram_scaling: 1.1 -structure_inference_ram_bytes: 64000 # MB +# the SLURM --mem request), so 64000 = 64000 MB ~= 64 GB. +# +# Host RAM for both stages is sized automatically from the input sequence length +# with a safety margin, and still escalates on OOM retries. The requested memory is: +# create_features : safety * (feature_create_ram_bytes + per_residue * seq_len) +# structure_inference : safety * (structure_inference_ram_bytes + per_token_sq * N^2) +# where seq_len is the query length, N is the total residues of the complex, and the +# request is multiplied by (scaling ** (attempt - 1)) on each retry. AlphaFold's pair +# representation is O(N^2), hence the quadratic term for inference. +# +# The base / per-length coefficients DEFAULT BY BACKEND (AF2 is heavier than AF3), +# selected automatically from --data_pipeline / --fold_backend: +# feature base feature /res infer base infer /N^2 +# alphafold2 (AF2 multimer): 64000 MB 40 MB 24000 MB 0.0055 +# alphafold3 (AF3): 40000 MB 25 MB 16000 MB 0.0045 +# Leave the keys below commented to use those backend defaults. Uncomment any key +# to override it for ALL backends. + +# Safety margin applied to the first-attempt estimate of every length-aware stage. +mem_safety_factor: 1.25 +# Optional hard ceiling (MB) on any single memory request; 0 = no cap. Set this to +# your largest node's RAM on clusters where an over-estimate would never schedule. +max_mem_mb: 0 + +# create_features (CPU/MSA) — base is database/MSA-tool dominated (AF2 HHblits is the +# main OOM source), with a mild linear dependence on query length. +# feature_create_ram_bytes: 64000 # MB, fixed base (default: backend-specific) +# feature_create_ram_per_residue_mb: 30 # MB per query residue (default: backend-specific) +feature_create_ram_scaling: 1.1 # per-retry escalation on OOM + +# structure_inference (GPU host RAM) — base plus a quadratic term in complex size. +# With unified memory enabled the XLA spill fraction is derived from this host +# allocation, so this also sizes the effective GPU memory ceiling. +# structure_inference_ram_bytes: 24000 # MB, fixed base (default: backend-specific) +# structure_inference_ram_per_token_sq_mb: 0.0045 # MB per residue^2 (default: backend-specific) +structure_inference_ram_scaling: 1.1 # per-retry escalation on OOM +# Wall-time minutes per attempt for structure_inference (capped by +# structure_inference_max_runtime). Default 1440; AF3 on an adequate GPU finishes far +# sooner, but host-memory spilling can take many hours, so the default stays generous. +# structure_inference_runtime_minutes: 1440 + +# Length filtering: skip folds that are too large to be worth submitting. Lengths +# are resolved at parse time (local FASTA / already-downloaded data / cache / the +# UniProt REST API) and cached in /.sequence_lengths.tsv; skipped +# folds are listed in /skipped_folds.tsv. Unknown lengths fail +# open (the fold is kept). Set a limit to 0 to disable it. +# +# Max TOTAL complex length (sum of all chains, residues), per backend; selected by +# --fold_backend. AF3 handles larger inputs than AF2-Multimer. +max_total_length_alphafold2: 5000 +max_total_length_alphafold3: 7000 +# Optional single override applied to BOTH backends (takes precedence if set): +# max_total_length: 6000 +# Max length of any SINGLE protein (residues); 0 disables (issue #33). A protein +# over this drops every fold containing it, so it is never downloaded/predicted. +max_protein_length: 0 +# Whether parse-time length resolution may query the UniProt REST API for IDs that +# have no local FASTA / cached length yet. Set false for fully offline runs. +length_filter_fetch_uniprot: true # Number of threads for AlphaFold inference alphafold_inference_threads: 8 @@ -67,6 +122,23 @@ structure_inference_gpus_per_task: 1 # Restrict structure_inference to one GPU model (sbatch --gpus=:N), e.g. "3090". # Leave "" to let SLURM pick any GPU in the partition. structure_inference_gpu_model: "3090" +# Optional length-based GPU routing by VRAM (within this partition). List your GPU +# pool as tiers of {min_vram_gb, nodes}; a complex's estimated peak VRAM +# (~ structure_inference_ram_per_token_sq * N^2) selects the smallest tier that fits +# and all SMALLER-GPU nodes are excluded, so the job runs on ANY GPU >= that tier +# (the whole pool, not one pinned model). A complex larger than every tier uses the +# biggest tier and spills to host RAM via unified memory. When set, this drives +# --exclude and OVERRIDES structure_inference_gpu_model. Use explicit comma node lists +# (avoid bracket ranges, which the shell may glob). The example below is for the EMBL +# gpu-el8 partition - replace nodes with your cluster's; nothing is hard-coded. +# structure_inference_gpu_vram_headroom: 1.0 # <1.0 tolerates that fraction of host spill +# structure_inference_gpu_tiers: +# - {min_vram_gb: 24, nodes: "gpu21,gpu22,gpu29,gpu30,gpu31,gpu32,gpu33,gpu34,gpu35,gpu36,gpu37"} # RTX 3090 +# - {min_vram_gb: 40, nodes: "gpu25,gpu26,gpu27,gpu28"} # A100 40GB +# - {min_vram_gb: 48, nodes: "gpu40,gpu41,gpu42,gpu43,gpu44,gpu45,gpu46,gpu47,gpu48"} # L40s/A40 48GB +# - {min_vram_gb: 80, nodes: "gpu38,gpu39"} # H100 PCIe 80GB +# Note: RTX PRO 6000 (gpu50-53, 96GB) are ptxas-incompatible -> keep in slurm_exclude_nodes. +# H100-SXM/H200/B200 live on the separate gpu-training partition (not routed here). # Optional: comma-separated nodes to keep structure_inference OFF, passed to sbatch # as --exclude. Useful for GPUs the prediction container cannot use (e.g. a CUDA # compute capability the bundled ptxas is too old for). Example: diff --git a/test/test_memory_resources.py b/test/test_memory_resources.py new file mode 100644 index 0000000..498a03d --- /dev/null +++ b/test/test_memory_resources.py @@ -0,0 +1,346 @@ +"""Tests for the length-aware memory model and the linear_resources passthrough. + +``workflow/rules/common.smk`` is pure Python, so it is loaded directly here. The +``linear_resources`` callables are exercised exactly the way Snakemake invokes +them (wildcards positional; ``input``/``attempt`` by keyword) to guard the +input-passthrough refactor. + +Run standalone: python test/test_memory_resources.py +Or with pytest: pytest test/test_memory_resources.py +""" + +from __future__ import annotations + +import importlib.machinery +import importlib.util +import json +import os +import tempfile +from pathlib import Path + +_COMMON = Path(__file__).resolve().parents[1] / "workflow" / "rules" / "common.smk" +_loader = importlib.machinery.SourceFileLoader("aps_common", str(_COMMON)) +_spec = importlib.util.spec_from_loader("aps_common", _loader) +common = importlib.util.module_from_spec(_spec) +_loader.exec_module(common) + + +def _write_fasta(directory: str, name: str, length: int) -> str: + path = os.path.join(directory, f"{name}.fasta") + with open(path, "w") as handle: + handle.write(f">{name}\n") + # split across lines to confirm multi-line sequences are summed + seq = "A" * length + for i in range(0, length, 60): + handle.write(seq[i : i + 60] + "\n") + return path + + +def test_residue_count_counts_sequence_only(): + common._RESIDUE_COUNT_CACHE.clear() + with tempfile.TemporaryDirectory() as d: + p = _write_fasta(d, "X", 137) + assert common.residue_count(p) == 137 + # unreadable path degrades to 0 (dry-run safety) + assert common.residue_count(os.path.join(d, "does_not_exist.fasta")) == 0 + + +def test_residue_count_does_not_cache_missing_file(): + """Regression: an early read of a not-yet-created file must NOT cache 0, or + length-aware sizing collapses to the base once the file later appears.""" + common._RESIDUE_COUNT_CACHE.clear() + with tempfile.TemporaryDirectory() as d: + p = os.path.join(d, "late.fasta") + assert common.residue_count(p) == 0 # file absent now + _write_fasta(d, "late", 321) # produced by an upstream job later + assert common.residue_count(p) == 321 # re-read, not the stale 0 + + +def test_fold_total_tokens_sums_chains_and_copies(): + common._RESIDUE_COUNT_CACHE.clear() + with tempfile.TemporaryDirectory() as d: + _write_fasta(d, "A", 200) + _write_fasta(d, "B", 300) + assert common.fold_total_tokens("A+B", d, "+") == 500 + # copy number: homo-dimer counts twice + assert common.fold_total_tokens("A:2", d, "+") == 400 + # region selection is conservatively counted at full length + assert common.fold_total_tokens("A:1-100", d, "+") == 200 + # mixed + assert common.fold_total_tokens("A:2+B", d, "+") == 700 + + +def test_fold_total_tokens_af3_precomputed_feature_fallback(): + """When data/.fasta is absent (precomputed features), AF3 length comes + from the /_af3_input.json fallback.""" + common._RESIDUE_COUNT_CACHE.clear() + common._AF3_INPUT_COUNT_CACHE.clear() + with tempfile.TemporaryDirectory() as d, tempfile.TemporaryDirectory() as feat: + _write_fasta(d, "B", 300) # only B has a FASTA; A is "precomputed" + with open(os.path.join(feat, "A_af3_input.json"), "w") as fh: + json.dump({"sequences": [{"protein": {"id": "A", "sequence": "M" * 250}}]}, fh) + # without fallback A is unknown -> only B counted + assert common.fold_total_tokens("A+B", d, "+") == 300 + # with AF3 fallback A is recovered from the json + assert ( + common.fold_total_tokens("A+B", d, "+", features_dir=feat, is_af3=True) == 550 + ) + # fallback is AF3-only: AF2 precomputed stays at 0 for the missing chain + assert common.fold_total_tokens("A+B", d, "+", features_dir=feat, is_af3=False) == 300 + + +def test_feature_mem_model_math(): + # safety * (base + per_residue * L), attempt 1 has no extra escalation + val = common.estimate_feature_mem_mb( + 500, base_mb=64000, per_residue_mb=30, scaling=1.1, safety=1.25, attempt=1 + ) + assert val == int(1.25 * (64000 + 30 * 500)) # 98750 + # retry escalation multiplies by scaling ** (attempt - 1) + val2 = common.estimate_feature_mem_mb( + 500, base_mb=64000, per_residue_mb=30, scaling=1.1, safety=1.25, attempt=3 + ) + assert val2 == int(1.25 * (64000 + 30 * 500) * (1.1 ** 2)) + + +def test_inference_mem_model_math_and_cap(): + val = common.estimate_inference_mem_mb( + 1000, base_mb=24000, per_token_sq_mb=0.0045, scaling=1.1, safety=1.25, attempt=1 + ) + assert val == int(1.25 * (24000 + 0.0045 * 1000 ** 2)) # 35625 + # cap is honoured + capped = common.estimate_inference_mem_mb( + 5000, base_mb=24000, per_token_sq_mb=0.0045, scaling=1.1, safety=1.25, + attempt=1, cap_mb=50000, + ) + assert capped == 50000 + + +def test_af3_inference_defaults_cover_observed_gpu_demand_anchors(): + """AF3 host request (the unified-memory spill ceiling) must exceed the observed + AF3 GPU-VRAM demand for the pairs documented in the AlphaJudge handoff.""" + d = common.INFERENCE_RAM_DEFAULTS["alphafold3"] + kw = dict(base_mb=d["base_mb"], per_token_sq_mb=d["per_token_sq_mb"], + scaling=1.1, safety=1.25, attempt=1) + anchors = [(2066, 25), (4556, 82), (4836, 100)] # (tokens, observed GB) + for n, observed_gb in anchors: + req_gb = common.estimate_inference_mem_mb(n, **kw) / 1000.0 + assert req_gb >= 1.2 * observed_gb, (n, req_gb, observed_gb) + assert req_gb <= 2.7 * observed_gb, (n, req_gb, observed_gb) + sizes = [common.estimate_inference_mem_mb(n, **kw) for n in (200, 1000, 2000, 4000)] + assert sizes == sorted(sizes) + + +def test_af2_inference_defaults_cover_measured_host_rss(): + """AF2 host request must cover the measured AF2 inference host RSS (which IS the + consumed memory for AF2), with margin.""" + d = common.INFERENCE_RAM_DEFAULTS["alphafold2"] + kw = dict(base_mb=d["base_mb"], per_token_sq_mb=d["per_token_sq_mb"], + scaling=1.1, safety=1.25, attempt=1) + measured = [(1583, 16.9), (2256, 30.8), (2324, 30.8)] # (tokens, measured host RSS GB) + for n, rss_gb in measured: + req_gb = common.estimate_inference_mem_mb(n, **kw) / 1000.0 + assert req_gb >= 1.2 * rss_gb, (n, req_gb, rss_gb) + assert req_gb <= 3.2 * rss_gb, (n, req_gb, rss_gb) + + +def test_backend_defaults_af2_heavier_than_af3(): + assert common.normalize_backend("af3") == "alphafold3" + assert common.normalize_backend("AlphaFold2") == "alphafold2" + assert common.normalize_backend(None) == "alphafold2" + f2, f3 = common.FEATURE_RAM_DEFAULTS["alphafold2"], common.FEATURE_RAM_DEFAULTS["alphafold3"] + i2, i3 = common.INFERENCE_RAM_DEFAULTS["alphafold2"], common.INFERENCE_RAM_DEFAULTS["alphafold3"] + assert f2["base_mb"] > f3["base_mb"] + assert f2["per_residue_mb"] > f3["per_residue_mb"] + assert i2["base_mb"] > i3["base_mb"] + assert i2["per_token_sq_mb"] > i3["per_token_sq_mb"] + + +def test_linear_resources_forwards_input_to_new_style_callbacks(): + # Snakemake invokes the resource callable as f(wildcards, input=..., attempt=...) + res = common.linear_resources( + mem_fn=lambda wildcards, input, attempt: 1000 * len(input) + attempt + ) + assert res["mem_mb"]({}, input=["a", "b", "c"], attempt=1) == 3001 + assert res["avg_mem"]({}, input=["a", "b", "c"], attempt=1) == int(3001 * 0.75) + + +def test_linear_resources_still_supports_legacy_callbacks(): + res = common.linear_resources(mem_fn=lambda wc, attempt: 5000 * attempt) + assert res["mem_mb"]({}, input=[], attempt=2) == 10000 + + +def test_linear_resources_default_scaling_without_callbacks(): + res = common.linear_resources(mem=800, runtime=10) + assert res["mem_mb"]({}, input=[], attempt=3) == 2400 + assert res["runtime"]({}, input=[], attempt=2) == 20 + assert res["attempt"]({}, input=[], attempt=4) == 4 + + +# --- length filtering (issue #33 + total caps) ------------------------------- + +def test_parse_fold_chains(): + assert common.parse_fold_chains("A+B") == [("A", 1), ("B", 1)] + assert common.parse_fold_chains("A:2") == [("A", 2)] + assert common.parse_fold_chains("A:1-100") == [("A", 1)] # region, not a copy + assert common.parse_fold_chains("A:2:1-100+B") == [("A", 2), ("B", 1)] + + +def test_fold_length_violation(): + chains = [("A", 300, 1), ("B", 2000, 1)] # total 2300 + assert common.fold_length_violation(chains, 0, 0) is None # limits off + assert common.fold_length_violation(chains, 0, 5000) is None # under total cap + assert common.fold_length_violation(chains, 0, 1000) is not None # over total cap + assert common.fold_length_violation(chains, 1000, 0) is not None # B over per-protein + assert common.fold_length_violation(chains, 2500, 0) is None # under per-protein + # homo-dimer copies count toward the total + assert common.fold_length_violation([("A", 600, 3)], 0, 1500) is not None + # unknown length fails open (None treated as 0) + assert common.fold_length_violation([("A", None, 1)], 0, 10) is None + + +def test_default_total_length_caps_af3_gt_af2(): + assert common.MAX_TOTAL_LENGTH_DEFAULTS["alphafold2"] == 5000 + assert common.MAX_TOTAL_LENGTH_DEFAULTS["alphafold3"] == 7000 + + +def test_fetch_uniprot_length_parses_and_fails_open(): + import urllib.request as ur + + class _FakeResp: + def __init__(self, data): + self._data = data + + def read(self): + return self._data + + def __enter__(self): + return self + + def __exit__(self, *exc): + return False + + orig = ur.urlopen + try: + common.fetch_uniprot_length.cache_clear() + ur.urlopen = lambda url, timeout=30.0: _FakeResp(b">sp|X\nMKLVMK\nAA\n") + assert common.fetch_uniprot_length("FAKE_OK") == 8 # 6 + 2, header skipped + + def _boom(url, timeout=30.0): + raise OSError("offline") + + ur.urlopen = _boom + assert common.fetch_uniprot_length("FAKE_FAIL") == 0 # fail open, no crash + finally: + ur.urlopen = orig + + +def test_chain_residue_count_length_cache_fallback(): + """AF2 precomputed features: no FASTA and no AF3 JSON, but the parse-time + length cache supplies the length.""" + common._RESIDUE_COUNT_CACHE.clear() + with tempfile.TemporaryDirectory() as d: + # no data/A.fasta exists; cache provides it + assert common.chain_residue_count("A", d) == 0 + assert common.chain_residue_count("A", d, length_cache={"A": 412}) == 412 + + +def test_af3_input_residue_count_skips_ligands(): + with tempfile.TemporaryDirectory() as d: + p = os.path.join(d, "x.json") + with open(p, "w") as fh: + json.dump( + {"sequences": [ + {"protein": {"id": "A", "sequence": "M" * 100}}, + {"ligand": {"id": "L", "ccdCodes": ["ATP"]}}, # no 'sequence' + ]}, + fh, + ) + assert common.af3_input_residue_count(p) == 100 # ligand contributes 0 + + +def test_required_gpu_vram_gb(): + # 0.0045 MB/token^2: N=4836 -> ~105 GB; headroom scales it + assert round(common.required_gpu_vram_gb(4836, 0.0045)) == 105 + assert round(common.required_gpu_vram_gb(2066, 0.0045)) == 19 + assert common.required_gpu_vram_gb(4836, 0.0045, headroom=0.5) < 60 + + +def test_gpu_exclude_nodes_vram_routing(): + tiers = [ + {"min_vram_gb": 24, "nodes": "n24a,n24b"}, + {"min_vram_gb": 48, "nodes": "n48a,n48b"}, + {"min_vram_gb": 80, "nodes": "n80a"}, + ] + c = 0.0045 + # small complex (~3 GB) fits the smallest tier -> exclude nothing + assert common.gpu_exclude_nodes(800, tiers, c) == "" + # ~26 GB -> needs >=48 GB tier -> exclude the 24 GB nodes + assert common.gpu_exclude_nodes(2400, tiers, c) == "n24a,n24b" + # ~55 GB -> needs the 80 GB tier -> exclude 24 and 48 GB nodes + assert common.gpu_exclude_nodes(3500, tiers, c) == "n24a,n24b,n48a,n48b" + # bigger than every tier -> use largest tier (spill), exclude all smaller + assert common.gpu_exclude_nodes(20000, tiers, c) == "n24a,n24b,n48a,n48b" + # static extra excludes are always appended (e.g. ptxas-incompatible cards) + assert common.gpu_exclude_nodes(800, tiers, c, extra_exclude="gpu50,gpu51") == "gpu50,gpu51" + assert ( + common.gpu_exclude_nodes(2400, tiers, c, extra_exclude="gpu50") + == "n24a,n24b,gpu50" + ) + # unsorted tiers handled; no tiers -> only the static excludes + assert common.gpu_exclude_nodes(2400, [], c, extra_exclude="gpu50") == "gpu50" + assert common.gpu_exclude_nodes(2400, [], c) == "" + + +def test_mem_mb_reaches_sbatch_via_real_plugin(): + """Integration: the value our model computes is what the SLURM plugin turns + into `sbatch --mem`. Skips gracefully if the plugin isn't importable.""" + try: + from snakemake_executor_plugin_slurm.submit_string import get_submit_command + except Exception as exc: # pragma: no cover - depends on environment + print(f" (skipped: plugin not importable: {exc})") + return + + mem = common.estimate_inference_mem_mb( + 2300, base_mb=16000, per_token_sq_mb=0.0045, scaling=1.1, safety=1.25, attempt=1 + ) + + class _Res(dict): + def get(self, key, default=None): + return dict.get(self, key, default) + + def __getattr__(self, key): + try: + return self[key] + except KeyError as exc: + raise AttributeError(key) from exc + + class _Job: + threads = 8 + resources = _Res(mem_mb=mem, runtime=600, qos="normal") + + params = { + "run_uuid": "test", + "slurm_logfile": "/tmp/test.log", + "comment_str": "test", + "account": "", + "partition": "", + "workdir": "", + } + try: + cmd = get_submit_command(_Job(), params) + except Exception as exc: # pragma: no cover - plugin internals may change + print(f" (skipped: plugin API changed: {exc})") + return + assert f"--mem {mem}" in cmd, cmd + + +def _run_all(): + fns = [v for k, v in sorted(globals().items()) if k.startswith("test_") and callable(v)] + for fn in fns: + fn() + print(f"PASS {fn.__name__}") + print(f"\n{len(fns)} tests passed") + + +if __name__ == "__main__": + _run_all() diff --git a/workflow/Snakefile b/workflow/Snakefile index 064707b..a7abfc4 100644 --- a/workflow/Snakefile +++ b/workflow/Snakefile @@ -27,6 +27,14 @@ if isinstance(feature_directories, (str, Path)): DATA_PIPELINE = config["create_feature_arguments"].get("--data_pipeline", "alphafold2").lower() IS_AF3 = DATA_PIPELINE in ("alphafold3", "af3") +# Backend identity (feature stage = data pipeline; inference stage = fold backend, +# falling back to the data pipeline). Defined early because length filtering and +# memory sizing both depend on it. +FEATURE_BACKEND = normalize_backend(DATA_PIPELINE) +INFERENCE_BACKEND = normalize_backend( + config.get("structure_inference_arguments", {}).get("--fold_backend", DATA_PIPELINE) +) + if IS_AF3: FEATURE_COMPRESSION = None FEATURE_SUFFIX = None @@ -134,9 +142,124 @@ local_sequences = tuple(dataset.sequences_by_origin.get("local", ())) localrules: symlink_local_files, download_uniprot, symlink_features ruleorder: symlink_local_files > download_uniprot > symlink_features > create_features +# --- Length filtering (issue #33 + backend total-length caps) ---------------- +# Skip folds whose chains exceed a per-protein limit (max_protein_length) or whose +# total complex length exceeds the backend cap (max_total_length, default 5000 for +# AF2 / 7000 for AF3), so infeasible jobs are never submitted. Limits are +# configurable; 0 disables. Per-protein lengths are resolved at parse time from +# local FASTAs, already-downloaded data/.fasta, a persistent cache, and +# finally the UniProt REST API; the resulting cache also feeds memory sizing for +# precomputed-feature runs. Unknown lengths fail open (the fold is kept). +MAX_PROTEIN_LENGTH = int(config.get("max_protein_length", 0) or 0) +_max_total_length_cfg = config.get("max_total_length") +if _max_total_length_cfg is None: + _max_total_length_cfg = config.get( + f"max_total_length_{INFERENCE_BACKEND}", + MAX_TOTAL_LENGTH_DEFAULTS[INFERENCE_BACKEND], + ) +MAX_TOTAL_LENGTH = int(_max_total_length_cfg or 0) +LENGTH_FILTER_FETCH_UNIPROT = _as_bool( + config.get("length_filter_fetch_uniprot", True), default=True +) + +_sequence_length_cache_path = join(config["output_directory"], ".sequence_lengths.tsv") +sequence_length_cache = {} +if exists(_sequence_length_cache_path): + with open(_sequence_length_cache_path) as _cache_fh: + for _cache_line in _cache_fh: + _cache_row = _cache_line.rstrip("\n").split("\t") + if len(_cache_row) >= 2 and _cache_row[1].strip().isdigit(): + sequence_length_cache[_cache_row[0]] = int(_cache_row[1]) + +kept_folds = list(dataset.fold_specifications) +dropped_folds = [] + +if MAX_PROTEIN_LENGTH > 0 or MAX_TOTAL_LENGTH > 0: + _local_path_by_base = { + splitext(basename(p))[0]: p + for p in dataset.sequences_by_origin.get("local", ()) + } + _data_dir = join(config["output_directory"], "data") + _features_dir = join(config["output_directory"], "features") + + def _resolve_protein_length(name): + if name in sequence_length_cache: + return sequence_length_cache[name] + length = 0 + if name in _local_path_by_base: + length = residue_count(_local_path_by_base[name]) + if length == 0: + length = residue_count(join(_data_dir, f"{name}.fasta")) + if length == 0 and IS_AF3: + length = af3_input_residue_count( + join(_features_dir, f"{name}_af3_input.json") + ) + if length == 0 and LENGTH_FILTER_FETCH_UNIPROT: + length = fetch_uniprot_length(name) + if length > 0: + sequence_length_cache[name] = length + return length + return None + + kept_folds = [] + _unknown_proteins = set() + for fold in dataset.fold_specifications: + chain_lengths = [] + for name, copies in parse_fold_chains(fold, protein_delimiter): + length = _resolve_protein_length(name) + if length is None: + _unknown_proteins.add(name) + chain_lengths.append((name, length, copies)) + reason = fold_length_violation( + chain_lengths, MAX_PROTEIN_LENGTH, MAX_TOTAL_LENGTH + ) + if reason: + dropped_folds.append((fold, reason)) + else: + kept_folds.append(fold) + + try: + makedirs(config["output_directory"], exist_ok=True) + with open(_sequence_length_cache_path, "w") as _cache_fh: + for _name in sorted(sequence_length_cache): + _cache_fh.write(f"{_name}\t{sequence_length_cache[_name]}\n") + except OSError: + pass + + if dropped_folds: + _skipped_path = join(config["output_directory"], "skipped_folds.tsv") + try: + with open(_skipped_path, "w") as _skip_fh: + _skip_fh.write("fold\treason\n") + for _fold, _reason in dropped_folds: + _skip_fh.write(f"{_fold}\t{_reason}\n") + except OSError: + _skipped_path = "(could not be written)" + logger.warning( + f"[length-filter] skipping {len(dropped_folds)} of " + f"{len(dataset.fold_specifications)} folds over the length limits " + f"(max_protein_length={MAX_PROTEIN_LENGTH}, max_total_length=" + f"{MAX_TOTAL_LENGTH} for {INFERENCE_BACKEND}); see {_skipped_path}" + ) + if _unknown_proteins: + logger.warning( + f"[length-filter] length unknown for {len(_unknown_proteins)} protein(s); " + f"their folds were kept (fail-open). Examples: " + f"{sorted(_unknown_proteins)[:5]}" + ) + +# Proteins required by the surviving folds (matches FoldDataset dedup ordering). +kept_proteins = [] +_seen_proteins = set() +for fold in kept_folds: + for _name in dataset.sequences_by_fold.get(fold, ()): + if _name not in _seen_proteins: + _seen_proteins.add(_name) + kept_proteins.append(_name) + required_folds = [ join(config["output_directory"], "predictions", fold, "completed_fold.txt") - for fold in dataset.fold_specifications + for fold in kept_folds ] ENABLE_STRUCTURE_ANALYSIS = config.get("enable_structure_analysis", True) GENERATE_RECURSIVE_REPORT = config.get("generate_recursive_report", False) @@ -149,7 +272,7 @@ RECURSIVE_REPORT = ( required_feature_paths = [ join(config["output_directory"], "features", feature_name(fasta_basename)) - for fasta_basename in dataset.unique_sequences + for fasta_basename in kept_proteins ] if config.get("only_generate_features", False): required_targets = required_feature_paths @@ -160,7 +283,7 @@ else: else: required_targets = [ join(config["output_directory"], "predictions", fold, "interfaces.csv") - for fold in dataset.fold_specifications + for fold in kept_folds ] else: required_targets = required_folds @@ -217,10 +340,69 @@ rule download_uniprot: tail -n +2 "${{temp_file}}" >> {output} """ +# Length-aware memory model (all values in MB). Host RAM for both compute +# stages is sized from the input sequence length(s) with a safety margin, then +# escalated on OOM retries via `scaling ** (attempt - 1)`. Defaults differ by +# backend (AF2 is heavier than AF3); an explicit config value always wins. See README. +MEM_SAFETY_FACTOR = float(config.get("mem_safety_factor", 1.25)) +MAX_MEM_MB = int(config.get("max_mem_mb", 0) or 0) # 0 = no cap + +# Backend defaults (FEATURE_BACKEND / INFERENCE_BACKEND are defined near the top). +_feat_ram_defaults = FEATURE_RAM_DEFAULTS[FEATURE_BACKEND] +_infer_ram_defaults = INFERENCE_RAM_DEFAULTS[INFERENCE_BACKEND] + +# create_features (CPU/MSA): mem ~ safety * (base + per_residue * seq_len) feature_scaling = config.get("feature_create_ram_scaling", 1.1) -base_feature_ram = config.get("feature_create_ram_bytes", 64000) +base_feature_ram = config.get("feature_create_ram_bytes", _feat_ram_defaults["base_mb"]) +feature_ram_per_residue = config.get( + "feature_create_ram_per_residue_mb", _feat_ram_defaults["per_residue_mb"] +) feature_threads = config.get("feature_threads", 8) +# structure_inference (GPU host RAM): mem ~ safety * (base + per_token_sq * N^2), +# N = total residues of the complex (AlphaFold pair representation is O(N^2)). +structure_inference_base_ram = config.get( + "structure_inference_ram_bytes", _infer_ram_defaults["base_mb"] +) +structure_inference_ram_per_token_sq = config.get( + "structure_inference_ram_per_token_sq_mb", _infer_ram_defaults["per_token_sq_mb"] +) +structure_inference_ram_scaling = config.get("structure_inference_ram_scaling", 1.1) +structure_inference_runtime_minutes = config.get( + "structure_inference_runtime_minutes", _infer_ram_defaults["runtime_minutes"] +) + +# Optional length-based GPU routing by VRAM (within one partition). Each tier is +# {min_vram_gb, nodes}; a complex's estimated peak VRAM picks the smallest tier that +# fits, and all smaller-GPU nodes are excluded so the job runs on ANY GPU >= that +# tier (the whole pool, not one pinned model). Cluster-agnostic: list your own tiers. +# When set, this drives `slurm_extra=--exclude=...` and takes precedence over the +# fixed structure_inference_gpu_model (which is dropped to avoid an impossible +# "pin model X but exclude its nodes" request). +GPU_TIERS = config.get("structure_inference_gpu_tiers") or [] +GPU_VRAM_HEADROOM = float(config.get("structure_inference_gpu_vram_headroom", 1.0)) + + +def _inference_slurm_extra(wildcards): + """slurm_extra resource: --exclude small-GPU nodes for large complexes, merged + with the static slurm_exclude_nodes.""" + tokens = fold_total_tokens( + wildcards.fold, + join(config["output_directory"], "data"), + protein_delimiter, + features_dir=join(config["output_directory"], "features"), + is_af3=(INFERENCE_BACKEND == "alphafold3"), + length_cache=sequence_length_cache, + ) + nodes = gpu_exclude_nodes( + tokens, + GPU_TIERS, + structure_inference_ram_per_token_sq, + GPU_VRAM_HEADROOM, + DEFAULT_SLURM_EXCLUDE or "", + ) + return f"--exclude={nodes}" if nodes else "" + rule symlink_features: input: precomputed_features, @@ -256,7 +438,24 @@ rule create_features: resources: qos=DEFAULT_SLURM_QOS, **linear_resources( - mem_fn=lambda wc, attempt: base_feature_ram * (feature_scaling ** attempt), + # Length from the parse-time cache (populated before any job runs) or + # the staged FASTA, falling back to the rule input. Using the cache + # avoids depending on when Snakemake evaluates the resource relative to + # the upstream download. + mem_fn=lambda wildcards, input, attempt: estimate_feature_mem_mb( + chain_residue_count( + wildcards.fasta_basename, + join(config["output_directory"], "data"), + length_cache=sequence_length_cache, + ) + or (residue_count(str(input[0])) if input else 0), + base_mb=base_feature_ram, + per_residue_mb=feature_ram_per_residue, + scaling=feature_scaling, + safety=MEM_SAFETY_FACTOR, + attempt=attempt, + cap_mb=MAX_MEM_MB, + ), runtime_fn=lambda wc, attempt: 1440 * attempt, ), threads: feature_threads, @@ -304,13 +503,41 @@ rule structure_inference: slurm_partition=(DEFAULT_SLURM_PARTITION or "gpu"), qos=DEFAULT_SLURM_QOS, gpu=DEFAULT_STRUCTURE_INFERENCE_GPUS, - **({"gpu_model": DEFAULT_STRUCTURE_INFERENCE_GPU_MODEL} if DEFAULT_STRUCTURE_INFERENCE_GPU_MODEL is not None else {}), - **({"slurm_extra": f"--exclude={DEFAULT_SLURM_EXCLUDE}"} if DEFAULT_SLURM_EXCLUDE else {}), + # Pin a single GPU model only when VRAM-tier routing is NOT used (the two + # would conflict: tier routing may exclude the pinned model's nodes). + **( + {"gpu_model": DEFAULT_STRUCTURE_INFERENCE_GPU_MODEL} + if (DEFAULT_STRUCTURE_INFERENCE_GPU_MODEL is not None and not GPU_TIERS) + else {} + ), + # --exclude is dynamic (per-complex VRAM routing) when tiers are set, else the + # static slurm_exclude_nodes. + **( + {"slurm_extra": _inference_slurm_extra} + if (GPU_TIERS or DEFAULT_SLURM_EXCLUDE) + else {} + ), tasks_per_gpu=DEFAULT_STRUCTURE_INFERENCE_TASKS_PER_GPU, **linear_resources( - mem_fn=lambda wc, attempt: config.get("structure_inference_ram_bytes", 32000) - * (1.1 ** attempt), - runtime_fn=lambda wc, attempt: min(1440 * attempt, STRUCTURE_INFERENCE_MAX_RUNTIME), + mem_fn=lambda wildcards, attempt: estimate_inference_mem_mb( + fold_total_tokens( + wildcards.fold, + join(config["output_directory"], "data"), + protein_delimiter, + features_dir=join(config["output_directory"], "features"), + is_af3=(INFERENCE_BACKEND == "alphafold3"), + length_cache=sequence_length_cache, + ), + base_mb=structure_inference_base_ram, + per_token_sq_mb=structure_inference_ram_per_token_sq, + scaling=structure_inference_ram_scaling, + safety=MEM_SAFETY_FACTOR, + attempt=attempt, + cap_mb=MAX_MEM_MB, + ), + runtime_fn=lambda wc, attempt: min( + structure_inference_runtime_minutes * attempt, STRUCTURE_INFERENCE_MAX_RUNTIME + ), ), threads: config["alphafold_inference_threads"], diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk index cabcae3..c7da452 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -7,11 +7,336 @@ from __future__ import annotations +import functools +import inspect +import json import os +import urllib.request from collections.abc import Iterable from pathlib import Path from typing import Any, Callable +# Default maximum *total* complex length (residues) per backend, used to skip +# folds that are too large to be feasible. AF3 supports larger inputs than +# AF2-Multimer. Override via config (see Snakefile / config.yaml). +MAX_TOTAL_LENGTH_DEFAULTS = {"alphafold2": 5000, "alphafold3": 7000} + + +# Length lookups cache only *successful* (>0) reads. Caching a 0 from a not-yet- +# created file would be a correctness bug: Snakemake's scheduler evaluates resource +# functions early (before upstream download/symlink rules run), and a memoised 0 +# would then stick even after the file appears, collapsing length-aware sizing to +# the base allocation. Re-reading a small FASTA/JSON on later calls is cheap. +_RESIDUE_COUNT_CACHE: dict[str, int] = {} +_AF3_INPUT_COUNT_CACHE: dict[str, int] = {} + + +def residue_count(fasta_path: str) -> int: + """Number of residues in a (single-record) FASTA file. + + Counts sequence characters, ignoring header lines and whitespace. Returns 0 + when the file cannot be read yet (so estimation degrades to the base + allocation rather than crashing) and does not cache that 0 — see the note + above on why caching a missing-file result would be wrong. + """ + cached = _RESIDUE_COUNT_CACHE.get(fasta_path) + if cached: + return cached + try: + total = 0 + with open(fasta_path) as handle: + for line in handle: + if not line.startswith(">"): + total += len(line.strip()) + except OSError: + return 0 + if total > 0: + _RESIDUE_COUNT_CACHE[fasta_path] = total + return total + + +def af3_input_residue_count(json_path: str) -> int: + """Total polymer residues in an AlphaFold 3 ``*_af3_input.json`` feature file. + + Sums the ``sequence`` length of every protein/RNA/DNA entry under + ``sequences`` (ligands have no sequence and are skipped). Returns 0 if the + file is missing or not parseable (not cached); used as a fallback for the + chain length when no ``data/.fasta`` exists (e.g. precomputed features + supplied via ``feature_directory``). + """ + cached = _AF3_INPUT_COUNT_CACHE.get(json_path) + if cached: + return cached + try: + with open(json_path) as handle: + data = json.load(handle) + except (OSError, ValueError): + return 0 + total = 0 + for entry in data.get("sequences", []): + if not isinstance(entry, dict): + continue + for mol in ("protein", "rna", "dna"): + mol_entry = entry.get(mol) + if isinstance(mol_entry, dict): + total += len(mol_entry.get("sequence", "") or "") + if total > 0: + _AF3_INPUT_COUNT_CACHE[json_path] = total + return total + + +def parse_fold_chains(fold: str, delimiter: str = "+") -> list[tuple[str, int]]: + """Parse a fold spec into ``(chain_name, copies)`` pairs. + + Follows the AlphaPulldown ``name[:copies][:region...]`` convention: the copy + number, if present, is the first ``:`` token after the name and is a bare + integer (e.g. ``A:2`` = dimer, ``A:2:1-100`` = dimer of residues 1-100). + Region tokens such as ``1-100`` are never bare integers, so ``A:1-100`` is a + single copy. Handles ``A+B`` heteromers. + """ + chains: list[tuple[str, int]] = [] + for token in str(fold).split(delimiter): + parts = [part for part in token.split(":") if part] + if not parts: + continue + name = parts[0] + copies = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 1 + chains.append((name, copies)) + return chains + + +@functools.lru_cache(maxsize=None) +def fetch_uniprot_length(uniprot_id: str, timeout: float = 30.0) -> int: + """Residue length of a UniProt entry via the REST API; 0 on any failure. + + Mirrors the reference snippet in issue #33. Used at parse time for length + filtering when no local FASTA is available yet; failures return 0 so the + caller can fail open (keep the fold) rather than crash offline. + """ + url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.fasta" + try: + with urllib.request.urlopen(url, timeout=timeout) as response: + text = response.read().decode("utf-8", "replace") + except Exception: + return 0 + total = 0 + for line in text.splitlines(): + if not line.startswith(">"): + total += len(line.strip()) + return total + + +def chain_residue_count( + name: str, + data_dir: str, + features_dir: str | None = None, + is_af3: bool = False, + length_cache: dict | None = None, +) -> int: + """Residue length of a single chain. + + Resolution order: ``/.fasta`` -> AF3 precomputed + ``/_af3_input.json`` (AF3 only) -> ``length_cache`` (the + parse-time length table, which covers the AF2 precomputed-feature case where + neither a FASTA nor an AF3 JSON exists). Returns 0 when length is unknown so + sizing degrades to the base allocation plus retry escalation. + """ + length = residue_count(os.path.join(data_dir, f"{name}.fasta")) + if length == 0 and is_af3 and features_dir: + length = af3_input_residue_count( + os.path.join(features_dir, f"{name}_af3_input.json") + ) + if length == 0 and length_cache: + length = int(length_cache.get(name, 0) or 0) + return length + + +def fold_total_tokens( + fold: str, + data_dir: str, + delimiter: str = "+", + features_dir: str | None = None, + is_af3: bool = False, + length_cache: dict | None = None, +) -> int: + """Total residue (token) count of a fold specification. + + Sums the residue length of every chain in ``fold``, honouring copy numbers + such as ``A:2`` (a homo-dimer counts twice). Region selections such as + ``A:1-100`` are conservatively counted at the chain's full length, which + over- rather than under-estimates memory. Per-chain lengths come from + ``chain_residue_count`` (FASTA -> AF3 JSON -> length cache). + + Note: AF3 ligand atoms are not counted (no ``sequence`` field); for + protein/nucleic complexes this matches the token count, and the safety + margin plus retry escalation absorb any small ligand undercount. + """ + total = 0 + for name, copies in parse_fold_chains(fold, delimiter): + total += ( + chain_residue_count(name, data_dir, features_dir, is_af3, length_cache) + * copies + ) + return total + + +def fold_length_violation( + chain_lengths: list[tuple[str, int | None, int]], + max_protein_length: int = 0, + max_total_length: int = 0, +) -> str | None: + """Return a human-readable reason if a fold exceeds a length limit, else None. + + ``chain_lengths`` is a list of ``(name, length_or_None, copies)``. Limits of + 0 (or negative) are disabled. Unknown lengths (``None``) are treated as 0 so + the decision fails open (the fold is kept) rather than dropped on missing data. + """ + if max_protein_length and max_protein_length > 0: + for name, length, _copies in chain_lengths: + if length is not None and length > max_protein_length: + return ( + f"protein {name} length {length} exceeds " + f"max_protein_length {max_protein_length}" + ) + if max_total_length and max_total_length > 0: + total = sum((length or 0) * copies for _name, length, copies in chain_lengths) + if total > max_total_length: + return f"total length {total} exceeds max_total_length {max_total_length}" + return None + + +def required_gpu_vram_gb( + total_tokens: int, per_token_sq_mb: float, headroom: float = 1.0 +) -> float: + """Estimated peak GPU VRAM (GB) for a complex of ``total_tokens``. + + Uses the same O(N^2) coefficient as host-memory sizing as a proxy for on-device + peak demand, scaled by ``headroom`` (e.g. 0.8 tolerates ~20% spill to host). + """ + return headroom * per_token_sq_mb * (max(int(total_tokens), 0) ** 2) / 1000.0 + + +def gpu_exclude_nodes( + total_tokens: int, + tiers, + per_token_sq_mb: float, + headroom: float = 1.0, + extra_exclude: str = "", +) -> str: + """Comma-joined SLURM nodes to exclude so a complex lands on a big-enough GPU. + + ``tiers`` is an iterable of ``{"min_vram_gb": int, "nodes": ""}`` + describing the cluster's GPU pool. The complex's required VRAM + (:func:`required_gpu_vram_gb`) selects the smallest tier that satisfies it (the + largest tier if none does — the remainder spills to host via unified memory); + the nodes of every *smaller* tier are excluded, so the job may run on any GPU at + or above the chosen tier (the whole pool, not one pinned model). ``extra_exclude`` + (the static ``slurm_exclude_nodes``) is always appended. + + Cluster-agnostic: each site lists its own GPU tiers/nodes; nothing about a + specific cluster is hard-coded. + """ + parts: list[str] = [] + valid = [t for t in tiers if t and t.get("nodes")] + if valid: + ordered = sorted(valid, key=lambda t: int(t["min_vram_gb"])) + required = required_gpu_vram_gb(total_tokens, per_token_sq_mb, headroom) + chosen = len(ordered) - 1 + for index, tier in enumerate(ordered): + if int(tier["min_vram_gb"]) >= required: + chosen = index + break + parts.extend(str(tier["nodes"]) for tier in ordered[:chosen]) + if extra_exclude: + parts.append(str(extra_exclude)) + return ",".join(part for part in parts if part) + + +def _cap_mem(value_mb: float, cap_mb: int) -> int: + value = max(int(value_mb), 1) + if cap_mb and cap_mb > 0: + value = min(value, int(cap_mb)) + return value + + +def estimate_feature_mem_mb( + seq_len: int, + *, + base_mb: float, + per_residue_mb: float, + scaling: float, + safety: float, + attempt: int, + cap_mb: int = 0, +) -> int: + """Length-aware host RAM (MB) for the feature-generation (MSA) stage. + + Feature memory is dominated by a near-fixed database/MSA-tooling footprint + with only a mild dependence on query length, so the model is linear: + + mem = safety * (base_mb + per_residue_mb * seq_len) + + The first attempt already carries the ``safety`` margin; OOM retries + escalate further via ``scaling ** (attempt - 1)``. + """ + estimate = safety * (base_mb + per_residue_mb * max(int(seq_len), 0)) + value = estimate * (scaling ** max(int(attempt) - 1, 0)) + return _cap_mem(value, cap_mb) + + +def estimate_inference_mem_mb( + total_tokens: int, + *, + base_mb: float, + per_token_sq_mb: float, + scaling: float, + safety: float, + attempt: int, + cap_mb: int = 0, +) -> int: + """Length-aware host RAM (MB) for the structure-inference stage. + + AlphaFold's pair representation is O(N^2) in the number of tokens N (total + residues of the complex), so peak memory follows a quadratic: + + mem = safety * (base_mb + per_token_sq_mb * N**2) + + With unified memory enabled the XLA fraction is derived from this host + allocation, so sizing host RAM by N also sizes the GPU spill ceiling. The + first attempt carries the ``safety`` margin; OOM retries escalate via + ``scaling ** (attempt - 1)``. + """ + estimate = safety * (base_mb + per_token_sq_mb * (max(int(total_tokens), 0) ** 2)) + value = estimate * (scaling ** max(int(attempt) - 1, 0)) + return _cap_mem(value, cap_mb) + + +# Backend-specific memory defaults. AlphaFold-Multimer (AF2) inference carries a +# substantially heavier host-RAM footprint than AlphaFold 3 at the same complex +# size (measured host RSS ~4x higher around N~2300 in the benchmark campaign), and +# its feature stage runs HHblits, the dominant OOM source; the AF3 data pipeline +# (jackhmmer/nhmmer, no HHblits) is lighter. So AF2 gets larger bases and a larger +# quadratic term. These apply only when the matching config key is unset, so an +# explicit config value always wins. +FEATURE_RAM_DEFAULTS = { + "alphafold2": {"base_mb": 64000, "per_residue_mb": 40}, + "alphafold3": {"base_mb": 40000, "per_residue_mb": 25}, +} +INFERENCE_RAM_DEFAULTS = { + # base_mb: fixed floor; per_token_sq_mb: quadratic coeff in N^2 (total residues). + # AF2 base/coeff cover the measured AF2 host RSS with margin; the AF3 quadratic is + # sized to the observed GPU-VRAM demand so the unified-memory spill ceiling + # (host_mem / gpu_vram) covers large complexes instead of OOM-ing. + "alphafold2": {"base_mb": 24000, "per_token_sq_mb": 0.0055, "runtime_minutes": 1440}, + "alphafold3": {"base_mb": 16000, "per_token_sq_mb": 0.0045, "runtime_minutes": 1440}, +} + + +def normalize_backend(name, default: str = "alphafold2") -> str: + """Map a backend/data-pipeline string to 'alphafold2' or 'alphafold3'.""" + n = str(name if name is not None else default).strip().lower() + return "alphafold3" if n in ("alphafold3", "af3") else "alphafold2" + def feature_suffix(compression: str = "lzma") -> str: _compression = { @@ -105,30 +430,46 @@ def linear_resources( runtime_fn: Callable[[Any, int], float] | None = None, attempt_fn: Callable[[Any, int], int] | None = None, ) -> dict[str, Any]: - """Return a Snakemake resources dictionary scaling with retry attempts.""" + """Return a Snakemake resources dictionary scaling with retry attempts. + + User-supplied ``*_fn`` callbacks receive ``wildcards`` positionally and may + additionally declare ``input`` and/or ``attempt`` parameters; only the ones + they declare are forwarded. This keeps legacy ``f(wc, attempt)`` callbacks + working while letting length-aware callbacks read input files via + ``f(wildcards, input, attempt)``. + """ + + def _invoke(fn, wc, input, attempt): + params = inspect.signature(fn).parameters + kwargs = {} + if "input" in params: + kwargs["input"] = input + if "attempt" in params: + kwargs["attempt"] = attempt + return fn(wc, **kwargs) - def _mem_value(wc, attempt: int) -> float: + def _mem_value(wc, input, attempt: int) -> float: if mem_fn: - return float(mem_fn(wc, attempt)) + return float(_invoke(mem_fn, wc, input, attempt)) return float(mem * attempt) - def _runtime_value(wc, attempt: int) -> float: + def _runtime_value(wc, input, attempt: int) -> float: if runtime_fn: - return float(runtime_fn(wc, attempt)) + return float(_invoke(runtime_fn, wc, input, attempt)) return float(runtime * attempt) - def _avg_mem(wc, attempt: int) -> int: - return int(_mem_value(wc, attempt) * avg_factor) + def _avg_mem(wc, input, attempt: int) -> int: + return int(_mem_value(wc, input, attempt) * avg_factor) - def _mem_mb(wc, attempt: int) -> int: - return int(_mem_value(wc, attempt)) + def _mem_mb(wc, input, attempt: int) -> int: + return int(_mem_value(wc, input, attempt)) - def _runtime(wc, attempt: int) -> int: - return int(_runtime_value(wc, attempt)) + def _runtime(wc, input, attempt: int) -> int: + return int(_runtime_value(wc, input, attempt)) - def _attempt(wc, attempt: int) -> int: + def _attempt(wc, input, attempt: int) -> int: if attempt_fn: - return int(attempt_fn(wc, attempt)) + return int(_invoke(attempt_fn, wc, input, attempt)) return attempt return {