From a0159325940c7cb13dca9ebf3cb147f38d8a2d93 Mon Sep 17 00:00:00 2001 From: Dima Molodenskiy Date: Sat, 23 May 2026 13:50:01 +0200 Subject: [PATCH 1/8] feat: size SLURM memory from input sequence length Host RAM for both compute stages is now requested from the input sequence length instead of a flat per-rule value, so large complexes get enough memory on the first attempt rather than failing and climbing the OOM-retry ladder, while small jobs are no longer over-provisioned. Model (host RAM, MB), evaluated at scheduling time from the FASTAs the pipeline already stages under /data/: create_features mem = safety * (base + per_residue * seq_len) structure_inference mem = safety * (base + per_token_sq * N**2) N is the total residues of the complex (the AlphaFold token count, summed over chains and copy numbers). AlphaFold's pair representation is O(N^2), hence the quadratic inference term. The first attempt carries a safety margin (mem_safety_factor, default 1.25) and OOM retries still escalate on top via `*_ram_scaling ** (attempt - 1)`, so a mis-estimate self-heals. Coefficients are conservative and anchor-calibrated (AF3 performance docs + documented OOM/demand pairs from the AlphaJudge benchmark), not a dense empirical fit; the quadratic passes through the observed ~25/82/100 GB pairs at N=2066/4556/4836 with ~1.2-2.5x head-room. They can be tightened later without changing the mechanism. Implementation: - common.smk: residue_count(), fold_total_tokens(), estimate_feature_mem_mb(), estimate_inference_mem_mb(); linear_resources() now forwards `input` to resource callbacks that declare it (legacy `(wc, attempt)` callbacks still work, wildcards stays positional per Snakemake's calling convention). - Snakefile: create_features / structure_inference use the length-aware model. - config.yaml: new knobs mem_safety_factor, max_mem_mb, feature_create_ram_per_residue_mb, structure_inference_ram_per_token_sq_mb, structure_inference_ram_scaling; the *_ram_bytes keys are now the model base. Setting the per-length terms to 0 reproduces the old length-blind behaviour. - README: documents the model and knobs. - test/test_memory_resources.py: covers the math, retry escalation, cap, the observed-OOM anchors, and the exact linear_resources calling convention. Co-Authored-By: Claude Opus 4.7 --- README.md | 31 ++++++++ config/config.yaml | 32 +++++++- test/test_memory_resources.py | 135 +++++++++++++++++++++++++++++++ workflow/Snakefile | 41 +++++++++- workflow/rules/common.smk | 146 +++++++++++++++++++++++++++++++--- 5 files changed, 365 insertions(+), 20 deletions(-) create mode 100644 test/test_memory_resources.py diff --git a/README.md b/README.md index 5db0fe0..af69369 100644 --- a/README.md +++ b/README.md @@ -301,6 +301,37 @@ exactly what the fraction is sized against. +
+Length-aware memory requests (sized automatically from the input sequences) + +Host RAM for both compute stages is requested **from the input sequence length**, so big +complexes get enough memory on the first attempt instead of failing and climbing the retry +ladder, while small jobs are not over-provisioned. The request is computed at scheduling +time by reading the per-chain FASTA(s) the pipeline already stages under +`/data/`: + +``` +create_features mem = safety * (feature_create_ram_bytes + per_residue * seq_len) +structure_inference mem = safety * (structure_inference_ram_bytes + per_token_sq * N^2) +``` + +- `seq_len` is the query length; `N` is the **total residues of the complex** (the + AlphaFold token count, summed over chains and copy numbers). AlphaFold's pair + representation is `O(N^2)`, hence the quadratic inference term — calibrated so a ~2,000-, + ~4,500- and ~4,800-residue complex request ≈25, ≈90 and ≈100+ GB respectively. +- The first attempt already includes `mem_safety_factor` (default `1.25`) of head-room. + **OOM retries still escalate** on top, multiplying by `..._ram_scaling ** (attempt - 1)`, + so a bad estimate self-heals. +- Tune the model via `mem_safety_factor`, `feature_create_ram_per_residue_mb`, + `structure_inference_ram_per_token_sq_mb`, the two `..._ram_bytes` bases, and the two + `..._ram_scaling` factors (all in `config/config.yaml`). Set `max_mem_mb` to your largest + node's RAM on clusters where an over-estimate would otherwise never schedule (`0` = no cap). +- The `..._ram_bytes` keys are now the **fixed base** of each model rather than a flat + request; raising a base only raises the floor. Setting `per_residue`/`per_token_sq` to `0` + reproduces the old length-blind behaviour (a flat base × retry scaling). + +
+ ### Using precomputed features If you have precomputed protein features, specify the directory: diff --git a/config/config.yaml b/config/config.yaml index 430081f..22396b9 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -52,10 +52,34 @@ analyze_structure_arguments: # Memory allocation for feature creation and structure inference. # NOTE: despite the "_bytes" suffix these values are in MEGABYTES (used directly as -# the SLURM --mem request), so 64000 = 64000 MB ~= 64 GB. They scale with retries. -feature_create_ram_bytes: 64000 # MB -feature_create_ram_scaling: 1.1 -structure_inference_ram_bytes: 64000 # MB +# the SLURM --mem request), so 64000 = 64000 MB ~= 64 GB. +# +# Host RAM for both stages is sized automatically from the input sequence length +# with a safety margin, and still escalates on OOM retries. The requested memory is: +# create_features : safety * (feature_create_ram_bytes + per_residue * seq_len) +# structure_inference : safety * (structure_inference_ram_bytes + per_token_sq * N^2) +# where seq_len is the query length, N is the total residues of the complex, and the +# request is multiplied by (scaling ** (attempt - 1)) on each retry. AlphaFold's pair +# representation is O(N^2), hence the quadratic term for inference. + +# Safety margin applied to the first-attempt estimate of every length-aware stage. +mem_safety_factor: 1.25 +# Optional hard ceiling (MB) on any single memory request; 0 = no cap. Set this to +# your largest node's RAM on clusters where an over-estimate would never schedule. +max_mem_mb: 0 + +# create_features (CPU/MSA) — base footprint is database/MSA-tool dominated, with a +# mild linear dependence on query length. +feature_create_ram_bytes: 64000 # MB, fixed base (DB/MSA tooling) +feature_create_ram_per_residue_mb: 30 # MB per query residue +feature_create_ram_scaling: 1.1 # per-retry escalation on OOM + +# structure_inference (GPU host RAM) — base plus a quadratic term in complex size. +# With unified memory enabled the XLA spill fraction is derived from this host +# allocation, so this also sizes the effective GPU memory ceiling. +structure_inference_ram_bytes: 24000 # MB, fixed base +structure_inference_ram_per_token_sq_mb: 0.0045 # MB per residue^2 (O(N^2)) +structure_inference_ram_scaling: 1.1 # per-retry escalation on OOM # Number of threads for AlphaFold inference alphafold_inference_threads: 8 diff --git a/test/test_memory_resources.py b/test/test_memory_resources.py new file mode 100644 index 0000000..e2f9390 --- /dev/null +++ b/test/test_memory_resources.py @@ -0,0 +1,135 @@ +"""Tests for the length-aware memory model and the linear_resources passthrough. + +``workflow/rules/common.smk`` is pure Python, so it is loaded directly here. The +``linear_resources`` callables are exercised exactly the way Snakemake invokes +them (wildcards positional; ``input``/``attempt`` by keyword) to guard the +input-passthrough refactor. + +Run standalone: python test/test_memory_resources.py +Or with pytest: pytest test/test_memory_resources.py +""" + +from __future__ import annotations + +import importlib.machinery +import importlib.util +import os +import tempfile +from pathlib import Path + +_COMMON = Path(__file__).resolve().parents[1] / "workflow" / "rules" / "common.smk" +_loader = importlib.machinery.SourceFileLoader("aps_common", str(_COMMON)) +_spec = importlib.util.spec_from_loader("aps_common", _loader) +common = importlib.util.module_from_spec(_spec) +_loader.exec_module(common) + + +def _write_fasta(directory: str, name: str, length: int) -> str: + path = os.path.join(directory, f"{name}.fasta") + with open(path, "w") as handle: + handle.write(f">{name}\n") + # split across lines to confirm multi-line sequences are summed + seq = "A" * length + for i in range(0, length, 60): + handle.write(seq[i : i + 60] + "\n") + return path + + +def test_residue_count_counts_sequence_only(): + common.residue_count.cache_clear() + with tempfile.TemporaryDirectory() as d: + p = _write_fasta(d, "X", 137) + assert common.residue_count(p) == 137 + # unreadable path degrades to 0 (dry-run safety) + assert common.residue_count(os.path.join(d, "does_not_exist.fasta")) == 0 + + +def test_fold_total_tokens_sums_chains_and_copies(): + common.residue_count.cache_clear() + with tempfile.TemporaryDirectory() as d: + _write_fasta(d, "A", 200) + _write_fasta(d, "B", 300) + assert common.fold_total_tokens("A+B", d, "+") == 500 + # copy number: homo-dimer counts twice + assert common.fold_total_tokens("A:2", d, "+") == 400 + # region selection is conservatively counted at full length + assert common.fold_total_tokens("A:1-100", d, "+") == 200 + # mixed + assert common.fold_total_tokens("A:2+B", d, "+") == 700 + + +def test_feature_mem_model_math(): + # safety * (base + per_residue * L), attempt 1 has no extra escalation + val = common.estimate_feature_mem_mb( + 500, base_mb=64000, per_residue_mb=30, scaling=1.1, safety=1.25, attempt=1 + ) + assert val == int(1.25 * (64000 + 30 * 500)) # 98750 + # retry escalation multiplies by scaling ** (attempt - 1) + val2 = common.estimate_feature_mem_mb( + 500, base_mb=64000, per_residue_mb=30, scaling=1.1, safety=1.25, attempt=3 + ) + assert val2 == int(1.25 * (64000 + 30 * 500) * (1.1 ** 2)) + + +def test_inference_mem_model_math_and_cap(): + val = common.estimate_inference_mem_mb( + 1000, base_mb=24000, per_token_sq_mb=0.0045, scaling=1.1, safety=1.25, attempt=1 + ) + assert val == int(1.25 * (24000 + 0.0045 * 1000 ** 2)) # 35625 + # cap is honoured + capped = common.estimate_inference_mem_mb( + 5000, base_mb=24000, per_token_sq_mb=0.0045, scaling=1.1, safety=1.25, + attempt=1, cap_mb=50000, + ) + assert capped == 50000 + + +def test_inference_model_covers_observed_oom_anchors_with_margin(): + """Requested RAM must exceed the empirically observed AF3 peak demand for the + pairs documented in the AlphaJudge handoff, with a sane (not wasteful) margin.""" + kw = dict(base_mb=24000, per_token_sq_mb=0.0045, scaling=1.1, safety=1.25, attempt=1) + anchors = [ # (total_tokens, observed_peak_GB) + (2066, 25), # O00194+Q9ULV0 + (4556, 82), # P02549+P11277 + (4836, 100), # Q01082+Q13813 + ] + for n, observed_gb in anchors: + req_gb = common.estimate_inference_mem_mb(n, **kw) / 1000.0 + assert req_gb >= 1.2 * observed_gb, (n, req_gb, observed_gb) + assert req_gb <= 2.5 * observed_gb, (n, req_gb, observed_gb) + # monotonic in size + sizes = [common.estimate_inference_mem_mb(n, **kw) for n in (200, 1000, 2000, 4000)] + assert sizes == sorted(sizes) + + +def test_linear_resources_forwards_input_to_new_style_callbacks(): + # Snakemake invokes the resource callable as f(wildcards, input=..., attempt=...) + res = common.linear_resources( + mem_fn=lambda wildcards, input, attempt: 1000 * len(input) + attempt + ) + assert res["mem_mb"]({}, input=["a", "b", "c"], attempt=1) == 3001 + assert res["avg_mem"]({}, input=["a", "b", "c"], attempt=1) == int(3001 * 0.75) + + +def test_linear_resources_still_supports_legacy_callbacks(): + res = common.linear_resources(mem_fn=lambda wc, attempt: 5000 * attempt) + assert res["mem_mb"]({}, input=[], attempt=2) == 10000 + + +def test_linear_resources_default_scaling_without_callbacks(): + res = common.linear_resources(mem=800, runtime=10) + assert res["mem_mb"]({}, input=[], attempt=3) == 2400 + assert res["runtime"]({}, input=[], attempt=2) == 20 + assert res["attempt"]({}, input=[], attempt=4) == 4 + + +def _run_all(): + fns = [v for k, v in sorted(globals().items()) if k.startswith("test_") and callable(v)] + for fn in fns: + fn() + print(f"PASS {fn.__name__}") + print(f"\n{len(fns)} tests passed") + + +if __name__ == "__main__": + _run_all() diff --git a/workflow/Snakefile b/workflow/Snakefile index 064707b..38182fc 100644 --- a/workflow/Snakefile +++ b/workflow/Snakefile @@ -217,10 +217,26 @@ rule download_uniprot: tail -n +2 "${{temp_file}}" >> {output} """ +# Length-aware memory model (all values in MB). Host RAM for both compute +# stages is sized from the input sequence length(s) with a safety margin, then +# escalated on OOM retries via `scaling ** (attempt - 1)`. See README. +MEM_SAFETY_FACTOR = float(config.get("mem_safety_factor", 1.25)) +MAX_MEM_MB = int(config.get("max_mem_mb", 0) or 0) # 0 = no cap + +# create_features (CPU/MSA): mem ~ safety * (base + per_residue * seq_len) feature_scaling = config.get("feature_create_ram_scaling", 1.1) base_feature_ram = config.get("feature_create_ram_bytes", 64000) +feature_ram_per_residue = config.get("feature_create_ram_per_residue_mb", 30) feature_threads = config.get("feature_threads", 8) +# structure_inference (GPU host RAM): mem ~ safety * (base + per_token_sq * N^2), +# N = total residues of the complex (AlphaFold pair representation is O(N^2)). +structure_inference_base_ram = config.get("structure_inference_ram_bytes", 24000) +structure_inference_ram_per_token_sq = config.get( + "structure_inference_ram_per_token_sq_mb", 0.0045 +) +structure_inference_ram_scaling = config.get("structure_inference_ram_scaling", 1.1) + rule symlink_features: input: precomputed_features, @@ -256,7 +272,15 @@ rule create_features: resources: qos=DEFAULT_SLURM_QOS, **linear_resources( - mem_fn=lambda wc, attempt: base_feature_ram * (feature_scaling ** attempt), + mem_fn=lambda wildcards, input, attempt: estimate_feature_mem_mb( + residue_count(str(input[0])) if input else 0, + base_mb=base_feature_ram, + per_residue_mb=feature_ram_per_residue, + scaling=feature_scaling, + safety=MEM_SAFETY_FACTOR, + attempt=attempt, + cap_mb=MAX_MEM_MB, + ), runtime_fn=lambda wc, attempt: 1440 * attempt, ), threads: feature_threads, @@ -308,8 +332,19 @@ rule structure_inference: **({"slurm_extra": f"--exclude={DEFAULT_SLURM_EXCLUDE}"} if DEFAULT_SLURM_EXCLUDE else {}), tasks_per_gpu=DEFAULT_STRUCTURE_INFERENCE_TASKS_PER_GPU, **linear_resources( - mem_fn=lambda wc, attempt: config.get("structure_inference_ram_bytes", 32000) - * (1.1 ** attempt), + mem_fn=lambda wildcards, attempt: estimate_inference_mem_mb( + fold_total_tokens( + wildcards.fold, + join(config["output_directory"], "data"), + protein_delimiter, + ), + base_mb=structure_inference_base_ram, + per_token_sq_mb=structure_inference_ram_per_token_sq, + scaling=structure_inference_ram_scaling, + safety=MEM_SAFETY_FACTOR, + attempt=attempt, + cap_mb=MAX_MEM_MB, + ), runtime_fn=lambda wc, attempt: min(1440 * attempt, STRUCTURE_INFERENCE_MAX_RUNTIME), ), threads: diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk index cabcae3..ba6f88b 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -7,12 +7,116 @@ from __future__ import annotations +import functools +import inspect import os from collections.abc import Iterable from pathlib import Path from typing import Any, Callable +@functools.lru_cache(maxsize=None) +def residue_count(fasta_path: str) -> int: + """Number of residues in a (single-record) FASTA file. + + Counts sequence characters, ignoring the header line(s) and whitespace. + Returns 0 when the file cannot be read yet (e.g. during a dry-run before + the upstream download/symlink rule has produced it) so that resource + estimation degrades gracefully to the base allocation instead of crashing. + Results are memoised because the structure-inference estimator may look up + the same chain repeatedly within a workflow. + """ + try: + total = 0 + with open(fasta_path) as handle: + for line in handle: + if line.startswith(">"): + continue + total += len(line.strip()) + return total + except OSError: + return 0 + + +def fold_total_tokens(fold: str, data_dir: str, delimiter: str = "+") -> int: + """Total residue (token) count of a fold specification. + + Sums the residue length of every chain in ``fold``, honouring copy numbers + such as ``A:2`` (a homo-dimer counts twice). Region selections such as + ``A:1-100`` are conservatively counted at the chain's full length, which + over- rather than under-estimates memory. Per-chain lengths are read from + ``/.fasta``. + """ + total = 0 + for token in str(fold).split(delimiter): + parts = [part for part in token.split(":") if part] + if not parts: + continue + name = parts[0] + copies = int(parts[-1]) if len(parts) > 1 and parts[-1].isdigit() else 1 + total += residue_count(os.path.join(data_dir, f"{name}.fasta")) * copies + return total + + +def _cap_mem(value_mb: float, cap_mb: int) -> int: + value = max(int(value_mb), 1) + if cap_mb and cap_mb > 0: + value = min(value, int(cap_mb)) + return value + + +def estimate_feature_mem_mb( + seq_len: int, + *, + base_mb: float, + per_residue_mb: float, + scaling: float, + safety: float, + attempt: int, + cap_mb: int = 0, +) -> int: + """Length-aware host RAM (MB) for the feature-generation (MSA) stage. + + Feature memory is dominated by a near-fixed database/MSA-tooling footprint + with only a mild dependence on query length, so the model is linear: + + mem = safety * (base_mb + per_residue_mb * seq_len) + + The first attempt already carries the ``safety`` margin; OOM retries + escalate further via ``scaling ** (attempt - 1)``. + """ + estimate = safety * (base_mb + per_residue_mb * max(int(seq_len), 0)) + value = estimate * (scaling ** max(int(attempt) - 1, 0)) + return _cap_mem(value, cap_mb) + + +def estimate_inference_mem_mb( + total_tokens: int, + *, + base_mb: float, + per_token_sq_mb: float, + scaling: float, + safety: float, + attempt: int, + cap_mb: int = 0, +) -> int: + """Length-aware host RAM (MB) for the structure-inference stage. + + AlphaFold's pair representation is O(N^2) in the number of tokens N (total + residues of the complex), so peak memory follows a quadratic: + + mem = safety * (base_mb + per_token_sq_mb * N**2) + + With unified memory enabled the XLA fraction is derived from this host + allocation, so sizing host RAM by N also sizes the GPU spill ceiling. The + first attempt carries the ``safety`` margin; OOM retries escalate via + ``scaling ** (attempt - 1)``. + """ + estimate = safety * (base_mb + per_token_sq_mb * (max(int(total_tokens), 0) ** 2)) + value = estimate * (scaling ** max(int(attempt) - 1, 0)) + return _cap_mem(value, cap_mb) + + def feature_suffix(compression: str = "lzma") -> str: _compression = { "lzma": "xz", @@ -105,30 +209,46 @@ def linear_resources( runtime_fn: Callable[[Any, int], float] | None = None, attempt_fn: Callable[[Any, int], int] | None = None, ) -> dict[str, Any]: - """Return a Snakemake resources dictionary scaling with retry attempts.""" + """Return a Snakemake resources dictionary scaling with retry attempts. + + User-supplied ``*_fn`` callbacks receive ``wildcards`` positionally and may + additionally declare ``input`` and/or ``attempt`` parameters; only the ones + they declare are forwarded. This keeps legacy ``f(wc, attempt)`` callbacks + working while letting length-aware callbacks read input files via + ``f(wildcards, input, attempt)``. + """ + + def _invoke(fn, wc, input, attempt): + params = inspect.signature(fn).parameters + kwargs = {} + if "input" in params: + kwargs["input"] = input + if "attempt" in params: + kwargs["attempt"] = attempt + return fn(wc, **kwargs) - def _mem_value(wc, attempt: int) -> float: + def _mem_value(wc, input, attempt: int) -> float: if mem_fn: - return float(mem_fn(wc, attempt)) + return float(_invoke(mem_fn, wc, input, attempt)) return float(mem * attempt) - def _runtime_value(wc, attempt: int) -> float: + def _runtime_value(wc, input, attempt: int) -> float: if runtime_fn: - return float(runtime_fn(wc, attempt)) + return float(_invoke(runtime_fn, wc, input, attempt)) return float(runtime * attempt) - def _avg_mem(wc, attempt: int) -> int: - return int(_mem_value(wc, attempt) * avg_factor) + def _avg_mem(wc, input, attempt: int) -> int: + return int(_mem_value(wc, input, attempt) * avg_factor) - def _mem_mb(wc, attempt: int) -> int: - return int(_mem_value(wc, attempt)) + def _mem_mb(wc, input, attempt: int) -> int: + return int(_mem_value(wc, input, attempt)) - def _runtime(wc, attempt: int) -> int: - return int(_runtime_value(wc, attempt)) + def _runtime(wc, input, attempt: int) -> int: + return int(_runtime_value(wc, input, attempt)) - def _attempt(wc, attempt: int) -> int: + def _attempt(wc, input, attempt: int) -> int: if attempt_fn: - return int(attempt_fn(wc, attempt)) + return int(_invoke(attempt_fn, wc, input, attempt)) return attempt return { From 1f5ccf8f52c313bc7a12758e0401fca3364317aa Mon Sep 17 00:00:00 2001 From: Dima Molodenskiy Date: Sat, 23 May 2026 14:37:32 +0200 Subject: [PATCH 2/8] feat: backend-specific memory defaults (AF2 vs AF3) AlphaFold-Multimer (AF2) and AlphaFold 3 have materially different memory and runtime profiles, so the length-aware coefficients now default by backend (selected from --data_pipeline for features and --fold_backend for inference; an explicit config value still overrides). Evidence (benchmark campaign, joined slurm-logs -> sacct -> sequence length): - AF2 inference host RSS is ~4x higher than AF3 at the same complex size (e.g. N~2300: AF2 ~31 GB vs AF3 ~7 GB) and rises quadratically. - AF2's feature stage runs HHblits, the dominant OOM source; the AF3 pipeline (jackhmmer/nhmmer, no HHblits) is lighter. Defaults: feature base feature/res infer base infer/N^2 alphafold2 (AF2): 64000 MB 40 MB 24000 MB 0.0055 alphafold3 (AF3): 40000 MB 25 MB 16000 MB 0.0045 The AF3 inference quadratic is sized to the observed GPU-VRAM demand so that, with unified memory, the host spill ceiling (host_mem/gpu_vram) covers large complexes instead of OOM-ing. Safety margin and OOM-retry escalation are unchanged. Runtime is now configurable per attempt (structure_inference_runtime_minutes, default 1440 for both backends) but kept generous because AF3 host-memory spilling can take many hours (measured ~8.5 h) despite AF3's faster on-GPU compute. - common.smk: FEATURE_RAM_DEFAULTS, INFERENCE_RAM_DEFAULTS, normalize_backend(). - Snakefile: resolve feature/inference backend; coefficients fall back to the backend default when the config key is unset; runtime knob wired in. - config.yaml: per-length coefficient keys commented out so backend defaults apply out of the box; documents the AF2/AF3 default table. - README: backend defaults table + override guidance. - tests: backend default ordering (AF2 >= AF3) and AF2 measured-host-RSS anchors alongside the AF3 GPU-demand anchors (10 tests pass). Co-Authored-By: Claude Opus 4.7 --- README.md | 30 +++++++++++++++++------- config/config.yaml | 26 +++++++++++++++------ test/test_memory_resources.py | 44 ++++++++++++++++++++++++++--------- workflow/Snakefile | 31 +++++++++++++++++++----- workflow/rules/common.smk | 27 +++++++++++++++++++++ 5 files changed, 126 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index af69369..7b1a5d7 100644 --- a/README.md +++ b/README.md @@ -317,17 +317,31 @@ structure_inference mem = safety * (structure_inference_ram_bytes + per_token_s - `seq_len` is the query length; `N` is the **total residues of the complex** (the AlphaFold token count, summed over chains and copy numbers). AlphaFold's pair - representation is `O(N^2)`, hence the quadratic inference term — calibrated so a ~2,000-, - ~4,500- and ~4,800-residue complex request ≈25, ≈90 and ≈100+ GB respectively. + representation is `O(N^2)`, hence the quadratic inference term. +- **The coefficients default by backend** (selected from `--data_pipeline` / `--fold_backend`). + AlphaFold-Multimer (AF2) is heavier than AlphaFold 3 — measured AF2 inference host RSS was + ~4× higher than AF3 at the same complex size, and AF2's feature stage runs HHblits (the + main OOM source), whereas the AF3 pipeline is lighter. Defaults: + + | backend | feature base | feature /residue | inference base | inference /N² | + |---|---|---|---|---| + | `alphafold2` | 64000 MB | 40 MB | 24000 MB | 0.0055 | + | `alphafold3` | 40000 MB | 25 MB | 16000 MB | 0.0045 | + + The AF3 inference quadratic is sized to the observed GPU-VRAM demand so that, with unified + memory, the host spill ceiling (`host_mem / gpu_vram`) covers large complexes instead of + OOM-ing. - The first attempt already includes `mem_safety_factor` (default `1.25`) of head-room. **OOM retries still escalate** on top, multiplying by `..._ram_scaling ** (attempt - 1)`, so a bad estimate self-heals. -- Tune the model via `mem_safety_factor`, `feature_create_ram_per_residue_mb`, - `structure_inference_ram_per_token_sq_mb`, the two `..._ram_bytes` bases, and the two - `..._ram_scaling` factors (all in `config/config.yaml`). Set `max_mem_mb` to your largest - node's RAM on clusters where an over-estimate would otherwise never schedule (`0` = no cap). -- The `..._ram_bytes` keys are now the **fixed base** of each model rather than a flat - request; raising a base only raises the floor. Setting `per_residue`/`per_token_sq` to `0` +- Override any backend default by setting the matching key in `config/config.yaml` + (`feature_create_ram_bytes`, `feature_create_ram_per_residue_mb`, + `structure_inference_ram_bytes`, `structure_inference_ram_per_token_sq_mb`); an explicit + value applies to all backends. Also tune `mem_safety_factor`, the `..._ram_scaling` + factors, `structure_inference_runtime_minutes`, and `max_mem_mb` (set it to your largest + node's RAM where an over-estimate would otherwise never schedule; `0` = no cap). +- The `..._ram_bytes` keys are the **fixed base** of each model rather than a flat request; + raising a base only raises the floor. Setting `per_residue`/`per_token_sq` to `0` reproduces the old length-blind behaviour (a flat base × retry scaling). diff --git a/config/config.yaml b/config/config.yaml index 22396b9..59e6f2f 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -61,6 +61,14 @@ analyze_structure_arguments: # where seq_len is the query length, N is the total residues of the complex, and the # request is multiplied by (scaling ** (attempt - 1)) on each retry. AlphaFold's pair # representation is O(N^2), hence the quadratic term for inference. +# +# The base / per-length coefficients DEFAULT BY BACKEND (AF2 is heavier than AF3), +# selected automatically from --data_pipeline / --fold_backend: +# feature base feature /res infer base infer /N^2 +# alphafold2 (AF2 multimer): 64000 MB 40 MB 24000 MB 0.0055 +# alphafold3 (AF3): 40000 MB 25 MB 16000 MB 0.0045 +# Leave the keys below commented to use those backend defaults. Uncomment any key +# to override it for ALL backends. # Safety margin applied to the first-attempt estimate of every length-aware stage. mem_safety_factor: 1.25 @@ -68,18 +76,22 @@ mem_safety_factor: 1.25 # your largest node's RAM on clusters where an over-estimate would never schedule. max_mem_mb: 0 -# create_features (CPU/MSA) — base footprint is database/MSA-tool dominated, with a -# mild linear dependence on query length. -feature_create_ram_bytes: 64000 # MB, fixed base (DB/MSA tooling) -feature_create_ram_per_residue_mb: 30 # MB per query residue +# create_features (CPU/MSA) — base is database/MSA-tool dominated (AF2 HHblits is the +# main OOM source), with a mild linear dependence on query length. +# feature_create_ram_bytes: 64000 # MB, fixed base (default: backend-specific) +# feature_create_ram_per_residue_mb: 30 # MB per query residue (default: backend-specific) feature_create_ram_scaling: 1.1 # per-retry escalation on OOM # structure_inference (GPU host RAM) — base plus a quadratic term in complex size. # With unified memory enabled the XLA spill fraction is derived from this host # allocation, so this also sizes the effective GPU memory ceiling. -structure_inference_ram_bytes: 24000 # MB, fixed base -structure_inference_ram_per_token_sq_mb: 0.0045 # MB per residue^2 (O(N^2)) -structure_inference_ram_scaling: 1.1 # per-retry escalation on OOM +# structure_inference_ram_bytes: 24000 # MB, fixed base (default: backend-specific) +# structure_inference_ram_per_token_sq_mb: 0.0045 # MB per residue^2 (default: backend-specific) +structure_inference_ram_scaling: 1.1 # per-retry escalation on OOM +# Wall-time minutes per attempt for structure_inference (capped by +# structure_inference_max_runtime). Default 1440; AF3 on an adequate GPU finishes far +# sooner, but host-memory spilling can take many hours, so the default stays generous. +# structure_inference_runtime_minutes: 1440 # Number of threads for AlphaFold inference alphafold_inference_threads: 8 diff --git a/test/test_memory_resources.py b/test/test_memory_resources.py index e2f9390..6244989 100644 --- a/test/test_memory_resources.py +++ b/test/test_memory_resources.py @@ -84,24 +84,46 @@ def test_inference_mem_model_math_and_cap(): assert capped == 50000 -def test_inference_model_covers_observed_oom_anchors_with_margin(): - """Requested RAM must exceed the empirically observed AF3 peak demand for the - pairs documented in the AlphaJudge handoff, with a sane (not wasteful) margin.""" - kw = dict(base_mb=24000, per_token_sq_mb=0.0045, scaling=1.1, safety=1.25, attempt=1) - anchors = [ # (total_tokens, observed_peak_GB) - (2066, 25), # O00194+Q9ULV0 - (4556, 82), # P02549+P11277 - (4836, 100), # Q01082+Q13813 - ] +def test_af3_inference_defaults_cover_observed_gpu_demand_anchors(): + """AF3 host request (the unified-memory spill ceiling) must exceed the observed + AF3 GPU-VRAM demand for the pairs documented in the AlphaJudge handoff.""" + d = common.INFERENCE_RAM_DEFAULTS["alphafold3"] + kw = dict(base_mb=d["base_mb"], per_token_sq_mb=d["per_token_sq_mb"], + scaling=1.1, safety=1.25, attempt=1) + anchors = [(2066, 25), (4556, 82), (4836, 100)] # (tokens, observed GB) for n, observed_gb in anchors: req_gb = common.estimate_inference_mem_mb(n, **kw) / 1000.0 assert req_gb >= 1.2 * observed_gb, (n, req_gb, observed_gb) - assert req_gb <= 2.5 * observed_gb, (n, req_gb, observed_gb) - # monotonic in size + assert req_gb <= 2.7 * observed_gb, (n, req_gb, observed_gb) sizes = [common.estimate_inference_mem_mb(n, **kw) for n in (200, 1000, 2000, 4000)] assert sizes == sorted(sizes) +def test_af2_inference_defaults_cover_measured_host_rss(): + """AF2 host request must cover the measured AF2 inference host RSS (which IS the + consumed memory for AF2), with margin.""" + d = common.INFERENCE_RAM_DEFAULTS["alphafold2"] + kw = dict(base_mb=d["base_mb"], per_token_sq_mb=d["per_token_sq_mb"], + scaling=1.1, safety=1.25, attempt=1) + measured = [(1583, 16.9), (2256, 30.8), (2324, 30.8)] # (tokens, measured host RSS GB) + for n, rss_gb in measured: + req_gb = common.estimate_inference_mem_mb(n, **kw) / 1000.0 + assert req_gb >= 1.2 * rss_gb, (n, req_gb, rss_gb) + assert req_gb <= 3.2 * rss_gb, (n, req_gb, rss_gb) + + +def test_backend_defaults_af2_heavier_than_af3(): + assert common.normalize_backend("af3") == "alphafold3" + assert common.normalize_backend("AlphaFold2") == "alphafold2" + assert common.normalize_backend(None) == "alphafold2" + f2, f3 = common.FEATURE_RAM_DEFAULTS["alphafold2"], common.FEATURE_RAM_DEFAULTS["alphafold3"] + i2, i3 = common.INFERENCE_RAM_DEFAULTS["alphafold2"], common.INFERENCE_RAM_DEFAULTS["alphafold3"] + assert f2["base_mb"] > f3["base_mb"] + assert f2["per_residue_mb"] > f3["per_residue_mb"] + assert i2["base_mb"] > i3["base_mb"] + assert i2["per_token_sq_mb"] > i3["per_token_sq_mb"] + + def test_linear_resources_forwards_input_to_new_style_callbacks(): # Snakemake invokes the resource callable as f(wildcards, input=..., attempt=...) res = common.linear_resources( diff --git a/workflow/Snakefile b/workflow/Snakefile index 38182fc..0482487 100644 --- a/workflow/Snakefile +++ b/workflow/Snakefile @@ -219,23 +219,40 @@ rule download_uniprot: # Length-aware memory model (all values in MB). Host RAM for both compute # stages is sized from the input sequence length(s) with a safety margin, then -# escalated on OOM retries via `scaling ** (attempt - 1)`. See README. +# escalated on OOM retries via `scaling ** (attempt - 1)`. Defaults differ by +# backend (AF2 is heavier than AF3); an explicit config value always wins. See README. MEM_SAFETY_FACTOR = float(config.get("mem_safety_factor", 1.25)) MAX_MEM_MB = int(config.get("max_mem_mb", 0) or 0) # 0 = no cap +# Feature backend = the data pipeline; inference backend = the fold backend +# (falls back to the data pipeline when --fold_backend is unset). +FEATURE_BACKEND = normalize_backend(DATA_PIPELINE) +INFERENCE_BACKEND = normalize_backend( + config.get("structure_inference_arguments", {}).get("--fold_backend", DATA_PIPELINE) +) +_feat_ram_defaults = FEATURE_RAM_DEFAULTS[FEATURE_BACKEND] +_infer_ram_defaults = INFERENCE_RAM_DEFAULTS[INFERENCE_BACKEND] + # create_features (CPU/MSA): mem ~ safety * (base + per_residue * seq_len) feature_scaling = config.get("feature_create_ram_scaling", 1.1) -base_feature_ram = config.get("feature_create_ram_bytes", 64000) -feature_ram_per_residue = config.get("feature_create_ram_per_residue_mb", 30) +base_feature_ram = config.get("feature_create_ram_bytes", _feat_ram_defaults["base_mb"]) +feature_ram_per_residue = config.get( + "feature_create_ram_per_residue_mb", _feat_ram_defaults["per_residue_mb"] +) feature_threads = config.get("feature_threads", 8) # structure_inference (GPU host RAM): mem ~ safety * (base + per_token_sq * N^2), # N = total residues of the complex (AlphaFold pair representation is O(N^2)). -structure_inference_base_ram = config.get("structure_inference_ram_bytes", 24000) +structure_inference_base_ram = config.get( + "structure_inference_ram_bytes", _infer_ram_defaults["base_mb"] +) structure_inference_ram_per_token_sq = config.get( - "structure_inference_ram_per_token_sq_mb", 0.0045 + "structure_inference_ram_per_token_sq_mb", _infer_ram_defaults["per_token_sq_mb"] ) structure_inference_ram_scaling = config.get("structure_inference_ram_scaling", 1.1) +structure_inference_runtime_minutes = config.get( + "structure_inference_runtime_minutes", _infer_ram_defaults["runtime_minutes"] +) rule symlink_features: input: @@ -345,7 +362,9 @@ rule structure_inference: attempt=attempt, cap_mb=MAX_MEM_MB, ), - runtime_fn=lambda wc, attempt: min(1440 * attempt, STRUCTURE_INFERENCE_MAX_RUNTIME), + runtime_fn=lambda wc, attempt: min( + structure_inference_runtime_minutes * attempt, STRUCTURE_INFERENCE_MAX_RUNTIME + ), ), threads: config["alphafold_inference_threads"], diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk index ba6f88b..468d906 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -117,6 +117,33 @@ def estimate_inference_mem_mb( return _cap_mem(value, cap_mb) +# Backend-specific memory defaults. AlphaFold-Multimer (AF2) inference carries a +# substantially heavier host-RAM footprint than AlphaFold 3 at the same complex +# size (measured host RSS ~4x higher around N~2300 in the benchmark campaign), and +# its feature stage runs HHblits, the dominant OOM source; the AF3 data pipeline +# (jackhmmer/nhmmer, no HHblits) is lighter. So AF2 gets larger bases and a larger +# quadratic term. These apply only when the matching config key is unset, so an +# explicit config value always wins. +FEATURE_RAM_DEFAULTS = { + "alphafold2": {"base_mb": 64000, "per_residue_mb": 40}, + "alphafold3": {"base_mb": 40000, "per_residue_mb": 25}, +} +INFERENCE_RAM_DEFAULTS = { + # base_mb: fixed floor; per_token_sq_mb: quadratic coeff in N^2 (total residues). + # AF2 base/coeff cover the measured AF2 host RSS with margin; the AF3 quadratic is + # sized to the observed GPU-VRAM demand so the unified-memory spill ceiling + # (host_mem / gpu_vram) covers large complexes instead of OOM-ing. + "alphafold2": {"base_mb": 24000, "per_token_sq_mb": 0.0055, "runtime_minutes": 1440}, + "alphafold3": {"base_mb": 16000, "per_token_sq_mb": 0.0045, "runtime_minutes": 1440}, +} + + +def normalize_backend(name, default: str = "alphafold2") -> str: + """Map a backend/data-pipeline string to 'alphafold2' or 'alphafold3'.""" + n = str(name if name is not None else default).strip().lower() + return "alphafold3" if n in ("alphafold3", "af3") else "alphafold2" + + def feature_suffix(compression: str = "lzma") -> str: _compression = { "lzma": "xz", From 2d80c2d30bfeaa465a089574eb82156ac25040fe Mon Sep 17 00:00:00 2001 From: Dima Molodenskiy Date: Sat, 23 May 2026 15:03:52 +0200 Subject: [PATCH 3/8] fix: keep length-aware inference sizing correct with precomputed features With features supplied via `feature_directory`, a chain is provided by `symlink_features` and its `download_uniprot`/`create_features` rules never run, so `data/.fasta` does not exist. The inference estimator then read length 0 for that chain and silently undercounted N (the default config ships a `feature_directory`, so this is a common path). Fix: `chain_residue_count()` falls back to the precomputed `/_af3_input.json` (via new `af3_input_residue_count()`) when the FASTA is absent and the backend is AF3; `fold_total_tokens()` and the structure_inference rule pass the features dir + backend through. AF2 precomputed pickles can't be read cheaply, so they fall back to the base allocation plus retry escalation (documented). Example: precomputed P0001(300)+P0002(2000) now sizes from N=2300 instead of N=2000. Adds a regression test for the fallback (11 tests pass). Co-Authored-By: Claude Opus 4.7 --- test/test_memory_resources.py | 20 ++++++++++++ workflow/Snakefile | 2 ++ workflow/rules/common.smk | 60 ++++++++++++++++++++++++++++++++--- 3 files changed, 78 insertions(+), 4 deletions(-) diff --git a/test/test_memory_resources.py b/test/test_memory_resources.py index 6244989..1a1d14d 100644 --- a/test/test_memory_resources.py +++ b/test/test_memory_resources.py @@ -13,6 +13,7 @@ import importlib.machinery import importlib.util +import json import os import tempfile from pathlib import Path @@ -58,6 +59,25 @@ def test_fold_total_tokens_sums_chains_and_copies(): assert common.fold_total_tokens("A:2+B", d, "+") == 700 +def test_fold_total_tokens_af3_precomputed_feature_fallback(): + """When data/.fasta is absent (precomputed features), AF3 length comes + from the /_af3_input.json fallback.""" + common.residue_count.cache_clear() + common.af3_input_residue_count.cache_clear() + with tempfile.TemporaryDirectory() as d, tempfile.TemporaryDirectory() as feat: + _write_fasta(d, "B", 300) # only B has a FASTA; A is "precomputed" + with open(os.path.join(feat, "A_af3_input.json"), "w") as fh: + json.dump({"sequences": [{"protein": {"id": "A", "sequence": "M" * 250}}]}, fh) + # without fallback A is unknown -> only B counted + assert common.fold_total_tokens("A+B", d, "+") == 300 + # with AF3 fallback A is recovered from the json + assert ( + common.fold_total_tokens("A+B", d, "+", features_dir=feat, is_af3=True) == 550 + ) + # fallback is AF3-only: AF2 precomputed stays at 0 for the missing chain + assert common.fold_total_tokens("A+B", d, "+", features_dir=feat, is_af3=False) == 300 + + def test_feature_mem_model_math(): # safety * (base + per_residue * L), attempt 1 has no extra escalation val = common.estimate_feature_mem_mb( diff --git a/workflow/Snakefile b/workflow/Snakefile index 0482487..724b2e6 100644 --- a/workflow/Snakefile +++ b/workflow/Snakefile @@ -354,6 +354,8 @@ rule structure_inference: wildcards.fold, join(config["output_directory"], "data"), protein_delimiter, + features_dir=join(config["output_directory"], "features"), + is_af3=(INFERENCE_BACKEND == "alphafold3"), ), base_mb=structure_inference_base_ram, per_token_sq_mb=structure_inference_ram_per_token_sq, diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk index 468d906..b8d52f4 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -9,6 +9,7 @@ from __future__ import annotations import functools import inspect +import json import os from collections.abc import Iterable from pathlib import Path @@ -38,14 +39,65 @@ def residue_count(fasta_path: str) -> int: return 0 -def fold_total_tokens(fold: str, data_dir: str, delimiter: str = "+") -> int: +@functools.lru_cache(maxsize=None) +def af3_input_residue_count(json_path: str) -> int: + """Total polymer residues in an AlphaFold 3 ``*_af3_input.json`` feature file. + + Sums the ``sequence`` length of every protein/RNA/DNA entry under + ``sequences`` (ligands have no sequence and are skipped). Returns 0 if the + file is missing or not parseable. Used as a fallback for the chain length + when no ``data/.fasta`` exists (e.g. precomputed features supplied via + ``feature_directory``, where the download/feature rules never run). + """ + try: + with open(json_path) as handle: + data = json.load(handle) + except (OSError, ValueError): + return 0 + total = 0 + for entry in data.get("sequences", []): + if not isinstance(entry, dict): + continue + for mol in ("protein", "rna", "dna"): + mol_entry = entry.get(mol) + if isinstance(mol_entry, dict): + total += len(mol_entry.get("sequence", "") or "") + return total + + +def chain_residue_count( + name: str, data_dir: str, features_dir: str | None = None, is_af3: bool = False +) -> int: + """Residue length of a single chain. + + Reads ``/.fasta`` first; if that is unavailable (returns 0) + and the run is AlphaFold 3, falls back to the precomputed + ``/_af3_input.json`` so length-aware sizing still works + when features are supplied via ``feature_directory`` rather than generated. + """ + length = residue_count(os.path.join(data_dir, f"{name}.fasta")) + if length == 0 and is_af3 and features_dir: + length = af3_input_residue_count( + os.path.join(features_dir, f"{name}_af3_input.json") + ) + return length + + +def fold_total_tokens( + fold: str, + data_dir: str, + delimiter: str = "+", + features_dir: str | None = None, + is_af3: bool = False, +) -> int: """Total residue (token) count of a fold specification. Sums the residue length of every chain in ``fold``, honouring copy numbers such as ``A:2`` (a homo-dimer counts twice). Region selections such as ``A:1-100`` are conservatively counted at the chain's full length, which - over- rather than under-estimates memory. Per-chain lengths are read from - ``/.fasta``. + over- rather than under-estimates memory. Per-chain lengths come from + ``/.fasta`` with an AF3 precomputed-feature fallback (see + ``chain_residue_count``). """ total = 0 for token in str(fold).split(delimiter): @@ -54,7 +106,7 @@ def fold_total_tokens(fold: str, data_dir: str, delimiter: str = "+") -> int: continue name = parts[0] copies = int(parts[-1]) if len(parts) > 1 and parts[-1].isdigit() else 1 - total += residue_count(os.path.join(data_dir, f"{name}.fasta")) * copies + total += chain_residue_count(name, data_dir, features_dir, is_af3) * copies return total From fe44a90af0b1f9ae2c9544ce889573d9a15f0e9b Mon Sep 17 00:00:00 2001 From: Dima Molodenskiy Date: Sat, 23 May 2026 15:21:55 +0200 Subject: [PATCH 4/8] feat: length filtering (issue #33 + backend total caps), CI, and follow-ups Adds configurable length limits, a CI workflow, an end-to-end SLURM wiring test, and closes the remaining review follow-ups. Length filtering (issue #33 + requested total caps): - Skip folds before any job is created when they exceed a length limit, so an oversized complex never wastes a feature/GPU allocation that would only OOM. - max_total_length_alphafold2: 5000 / max_total_length_alphafold3: 7000 (selected by --fold_backend; optional single `max_total_length` override). - max_protein_length (0=off, issue #33): a protein over the limit drops every fold containing it, so it is never downloaded. - Lengths resolved at parse time: local FASTA -> data/.fasta -> persistent cache /.sequence_lengths.tsv -> UniProt REST API (length_filter_fetch_uniprot, default true). Skipped folds + reasons go to /skipped_folds.tsv; unknown lengths fail open (kept). - Fixes copy-number parsing: AlphaPulldown spec is name:copies:region, so the copy count is the first ':' token after the name (A:2:1-100 = 2 copies), not the last. Review follow-ups: - #1: integration test feeds our computed mem_mb through the real SLURM plugin's get_submit_command and asserts it becomes `sbatch --mem ` (skips if the plugin is absent). - #2: GitHub Actions CI (.github/workflows/ci.yml) byte-compiles common.smk and runs the dependency-free unit suite on py3.10/3.12. - #5: AF2 precomputed-feature inference sizing now recovers length from the parse-time length cache (no FASTA / no AF3 JSON needed). - #6: AF3 ligand atoms are intentionally not counted (no sequence); documented, with a test asserting ligand entries contribute 0. The parse-time length cache is shared with memory sizing, so precomputed-feature runs get correct length-aware memory too. 18 unit tests pass; AF2/AF3/precomputed dry-runs build clean. Co-Authored-By: Claude Opus 4.7 --- .github/workflows/ci.yml | 26 +++++++ README.md | 37 +++++++++ config/config.yaml | 19 +++++ test/test_memory_resources.py | 125 +++++++++++++++++++++++++++++++ workflow/Snakefile | 137 +++++++++++++++++++++++++++++++--- workflow/rules/common.smk | 109 +++++++++++++++++++++++---- 6 files changed, 430 insertions(+), 23 deletions(-) create mode 100644 .github/workflows/ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..3365cac --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,26 @@ +name: tests + +on: + push: + branches: [main] + pull_request: + +jobs: + unit-tests: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.12"] + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Byte-compile common.smk + # common.smk is plain Python and carries the memory/length logic. + run: python -m py_compile workflow/rules/common.smk + - name: Run resource/length unit tests + # The suite loads common.smk directly and needs no third-party deps; + # the SLURM-plugin integration test self-skips when the plugin is absent. + run: python test/test_memory_resources.py diff --git a/README.md b/README.md index 7b1a5d7..c7c16d7 100644 --- a/README.md +++ b/README.md @@ -343,6 +343,43 @@ structure_inference mem = safety * (structure_inference_ram_bytes + per_token_s - The `..._ram_bytes` keys are the **fixed base** of each model rather than a flat request; raising a base only raises the floor. Setting `per_residue`/`per_token_sq` to `0` reproduces the old length-blind behaviour (a flat base × retry scaling). +- **Precomputed features:** when a chain is supplied via `feature_directory`, no + `data/.fasta` is generated. Length is then recovered from the precomputed + `_af3_input.json` (AF3) or from the parse-time length cache written by the length + filter below (covers AF2 too). If neither is available the job falls back to the base + allocation plus retry escalation. AF3 ligand atoms are not counted (no sequence), a small + undercount absorbed by the safety margin. + + + +
+Skipping over-large complexes (length filtering) + +Folds that are too large to be worth submitting are **skipped before any job is created**, +so a single oversized complex (or one giant chain) doesn't waste a GPU/feature allocation +that will only OOM. Two configurable limits (in `config/config.yaml`): + +```yaml +# Max TOTAL complex length (sum of all chains), per backend — selected by --fold_backend. +max_total_length_alphafold2: 5000 # AF2-Multimer +max_total_length_alphafold3: 7000 # AF3 handles larger inputs +# max_total_length: 6000 # optional single override for both backends +# Max length of any SINGLE protein; 0 = off (issue #33). A protein over this drops every +# fold containing it, so it is never even downloaded. +max_protein_length: 0 +length_filter_fetch_uniprot: true # set false for fully offline runs +``` + +- Lengths are resolved at **parse time** from, in order: a local FASTA, an + already-downloaded `data/.fasta`, the persistent cache + `/.sequence_lengths.tsv`, and finally the UniProt REST API (cached for + next time). Set a limit to `0` to disable it; if both are `0`, no resolution/fetching + happens at all. +- Skipped folds are listed with reasons in `/skipped_folds.tsv` and logged + as a `[length-filter]` warning. **Unknown lengths fail open** (the fold is kept), so a + UniProt outage never silently drops work. +- First parse of a large all-UniProt sheet will fetch each unique length once (cached + afterwards); already-downloaded inputs and local FASTAs are read without any network call.
diff --git a/config/config.yaml b/config/config.yaml index 59e6f2f..e8171b5 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -93,6 +93,25 @@ structure_inference_ram_scaling: 1.1 # per-retry escalation on OOM # sooner, but host-memory spilling can take many hours, so the default stays generous. # structure_inference_runtime_minutes: 1440 +# Length filtering: skip folds that are too large to be worth submitting. Lengths +# are resolved at parse time (local FASTA / already-downloaded data / cache / the +# UniProt REST API) and cached in /.sequence_lengths.tsv; skipped +# folds are listed in /skipped_folds.tsv. Unknown lengths fail +# open (the fold is kept). Set a limit to 0 to disable it. +# +# Max TOTAL complex length (sum of all chains, residues), per backend; selected by +# --fold_backend. AF3 handles larger inputs than AF2-Multimer. +max_total_length_alphafold2: 5000 +max_total_length_alphafold3: 7000 +# Optional single override applied to BOTH backends (takes precedence if set): +# max_total_length: 6000 +# Max length of any SINGLE protein (residues); 0 disables (issue #33). A protein +# over this drops every fold containing it, so it is never downloaded/predicted. +max_protein_length: 0 +# Whether parse-time length resolution may query the UniProt REST API for IDs that +# have no local FASTA / cached length yet. Set false for fully offline runs. +length_filter_fetch_uniprot: true + # Number of threads for AlphaFold inference alphafold_inference_threads: 8 diff --git a/test/test_memory_resources.py b/test/test_memory_resources.py index 1a1d14d..d032bb6 100644 --- a/test/test_memory_resources.py +++ b/test/test_memory_resources.py @@ -165,6 +165,131 @@ def test_linear_resources_default_scaling_without_callbacks(): assert res["attempt"]({}, input=[], attempt=4) == 4 +# --- length filtering (issue #33 + total caps) ------------------------------- + +def test_parse_fold_chains(): + assert common.parse_fold_chains("A+B") == [("A", 1), ("B", 1)] + assert common.parse_fold_chains("A:2") == [("A", 2)] + assert common.parse_fold_chains("A:1-100") == [("A", 1)] # region, not a copy + assert common.parse_fold_chains("A:2:1-100+B") == [("A", 2), ("B", 1)] + + +def test_fold_length_violation(): + chains = [("A", 300, 1), ("B", 2000, 1)] # total 2300 + assert common.fold_length_violation(chains, 0, 0) is None # limits off + assert common.fold_length_violation(chains, 0, 5000) is None # under total cap + assert common.fold_length_violation(chains, 0, 1000) is not None # over total cap + assert common.fold_length_violation(chains, 1000, 0) is not None # B over per-protein + assert common.fold_length_violation(chains, 2500, 0) is None # under per-protein + # homo-dimer copies count toward the total + assert common.fold_length_violation([("A", 600, 3)], 0, 1500) is not None + # unknown length fails open (None treated as 0) + assert common.fold_length_violation([("A", None, 1)], 0, 10) is None + + +def test_default_total_length_caps_af3_gt_af2(): + assert common.MAX_TOTAL_LENGTH_DEFAULTS["alphafold2"] == 5000 + assert common.MAX_TOTAL_LENGTH_DEFAULTS["alphafold3"] == 7000 + + +def test_fetch_uniprot_length_parses_and_fails_open(): + import urllib.request as ur + + class _FakeResp: + def __init__(self, data): + self._data = data + + def read(self): + return self._data + + def __enter__(self): + return self + + def __exit__(self, *exc): + return False + + orig = ur.urlopen + try: + common.fetch_uniprot_length.cache_clear() + ur.urlopen = lambda url, timeout=30.0: _FakeResp(b">sp|X\nMKLVMK\nAA\n") + assert common.fetch_uniprot_length("FAKE_OK") == 8 # 6 + 2, header skipped + + def _boom(url, timeout=30.0): + raise OSError("offline") + + ur.urlopen = _boom + assert common.fetch_uniprot_length("FAKE_FAIL") == 0 # fail open, no crash + finally: + ur.urlopen = orig + + +def test_chain_residue_count_length_cache_fallback(): + """AF2 precomputed features: no FASTA and no AF3 JSON, but the parse-time + length cache supplies the length.""" + common.residue_count.cache_clear() + with tempfile.TemporaryDirectory() as d: + # no data/A.fasta exists; cache provides it + assert common.chain_residue_count("A", d) == 0 + assert common.chain_residue_count("A", d, length_cache={"A": 412}) == 412 + + +def test_af3_input_residue_count_skips_ligands(): + with tempfile.TemporaryDirectory() as d: + p = os.path.join(d, "x.json") + with open(p, "w") as fh: + json.dump( + {"sequences": [ + {"protein": {"id": "A", "sequence": "M" * 100}}, + {"ligand": {"id": "L", "ccdCodes": ["ATP"]}}, # no 'sequence' + ]}, + fh, + ) + assert common.af3_input_residue_count(p) == 100 # ligand contributes 0 + + +def test_mem_mb_reaches_sbatch_via_real_plugin(): + """Integration: the value our model computes is what the SLURM plugin turns + into `sbatch --mem`. Skips gracefully if the plugin isn't importable.""" + try: + from snakemake_executor_plugin_slurm.submit_string import get_submit_command + except Exception as exc: # pragma: no cover - depends on environment + print(f" (skipped: plugin not importable: {exc})") + return + + mem = common.estimate_inference_mem_mb( + 2300, base_mb=16000, per_token_sq_mb=0.0045, scaling=1.1, safety=1.25, attempt=1 + ) + + class _Res(dict): + def get(self, key, default=None): + return dict.get(self, key, default) + + def __getattr__(self, key): + try: + return self[key] + except KeyError as exc: + raise AttributeError(key) from exc + + class _Job: + threads = 8 + resources = _Res(mem_mb=mem, runtime=600, qos="normal") + + params = { + "run_uuid": "test", + "slurm_logfile": "/tmp/test.log", + "comment_str": "test", + "account": "", + "partition": "", + "workdir": "", + } + try: + cmd = get_submit_command(_Job(), params) + except Exception as exc: # pragma: no cover - plugin internals may change + print(f" (skipped: plugin API changed: {exc})") + return + assert f"--mem {mem}" in cmd, cmd + + def _run_all(): fns = [v for k, v in sorted(globals().items()) if k.startswith("test_") and callable(v)] for fn in fns: diff --git a/workflow/Snakefile b/workflow/Snakefile index 724b2e6..b42aedb 100644 --- a/workflow/Snakefile +++ b/workflow/Snakefile @@ -27,6 +27,14 @@ if isinstance(feature_directories, (str, Path)): DATA_PIPELINE = config["create_feature_arguments"].get("--data_pipeline", "alphafold2").lower() IS_AF3 = DATA_PIPELINE in ("alphafold3", "af3") +# Backend identity (feature stage = data pipeline; inference stage = fold backend, +# falling back to the data pipeline). Defined early because length filtering and +# memory sizing both depend on it. +FEATURE_BACKEND = normalize_backend(DATA_PIPELINE) +INFERENCE_BACKEND = normalize_backend( + config.get("structure_inference_arguments", {}).get("--fold_backend", DATA_PIPELINE) +) + if IS_AF3: FEATURE_COMPRESSION = None FEATURE_SUFFIX = None @@ -134,9 +142,124 @@ local_sequences = tuple(dataset.sequences_by_origin.get("local", ())) localrules: symlink_local_files, download_uniprot, symlink_features ruleorder: symlink_local_files > download_uniprot > symlink_features > create_features +# --- Length filtering (issue #33 + backend total-length caps) ---------------- +# Skip folds whose chains exceed a per-protein limit (max_protein_length) or whose +# total complex length exceeds the backend cap (max_total_length, default 5000 for +# AF2 / 7000 for AF3), so infeasible jobs are never submitted. Limits are +# configurable; 0 disables. Per-protein lengths are resolved at parse time from +# local FASTAs, already-downloaded data/.fasta, a persistent cache, and +# finally the UniProt REST API; the resulting cache also feeds memory sizing for +# precomputed-feature runs. Unknown lengths fail open (the fold is kept). +MAX_PROTEIN_LENGTH = int(config.get("max_protein_length", 0) or 0) +_max_total_length_cfg = config.get("max_total_length") +if _max_total_length_cfg is None: + _max_total_length_cfg = config.get( + f"max_total_length_{INFERENCE_BACKEND}", + MAX_TOTAL_LENGTH_DEFAULTS[INFERENCE_BACKEND], + ) +MAX_TOTAL_LENGTH = int(_max_total_length_cfg or 0) +LENGTH_FILTER_FETCH_UNIPROT = _as_bool( + config.get("length_filter_fetch_uniprot", True), default=True +) + +_sequence_length_cache_path = join(config["output_directory"], ".sequence_lengths.tsv") +sequence_length_cache = {} +if exists(_sequence_length_cache_path): + with open(_sequence_length_cache_path) as _cache_fh: + for _cache_line in _cache_fh: + _cache_row = _cache_line.rstrip("\n").split("\t") + if len(_cache_row) >= 2 and _cache_row[1].strip().isdigit(): + sequence_length_cache[_cache_row[0]] = int(_cache_row[1]) + +kept_folds = list(dataset.fold_specifications) +dropped_folds = [] + +if MAX_PROTEIN_LENGTH > 0 or MAX_TOTAL_LENGTH > 0: + _local_path_by_base = { + splitext(basename(p))[0]: p + for p in dataset.sequences_by_origin.get("local", ()) + } + _data_dir = join(config["output_directory"], "data") + _features_dir = join(config["output_directory"], "features") + + def _resolve_protein_length(name): + if name in sequence_length_cache: + return sequence_length_cache[name] + length = 0 + if name in _local_path_by_base: + length = residue_count(_local_path_by_base[name]) + if length == 0: + length = residue_count(join(_data_dir, f"{name}.fasta")) + if length == 0 and IS_AF3: + length = af3_input_residue_count( + join(_features_dir, f"{name}_af3_input.json") + ) + if length == 0 and LENGTH_FILTER_FETCH_UNIPROT: + length = fetch_uniprot_length(name) + if length > 0: + sequence_length_cache[name] = length + return length + return None + + kept_folds = [] + _unknown_proteins = set() + for fold in dataset.fold_specifications: + chain_lengths = [] + for name, copies in parse_fold_chains(fold, protein_delimiter): + length = _resolve_protein_length(name) + if length is None: + _unknown_proteins.add(name) + chain_lengths.append((name, length, copies)) + reason = fold_length_violation( + chain_lengths, MAX_PROTEIN_LENGTH, MAX_TOTAL_LENGTH + ) + if reason: + dropped_folds.append((fold, reason)) + else: + kept_folds.append(fold) + + try: + makedirs(config["output_directory"], exist_ok=True) + with open(_sequence_length_cache_path, "w") as _cache_fh: + for _name in sorted(sequence_length_cache): + _cache_fh.write(f"{_name}\t{sequence_length_cache[_name]}\n") + except OSError: + pass + + if dropped_folds: + _skipped_path = join(config["output_directory"], "skipped_folds.tsv") + try: + with open(_skipped_path, "w") as _skip_fh: + _skip_fh.write("fold\treason\n") + for _fold, _reason in dropped_folds: + _skip_fh.write(f"{_fold}\t{_reason}\n") + except OSError: + _skipped_path = "(could not be written)" + logger.warning( + f"[length-filter] skipping {len(dropped_folds)} of " + f"{len(dataset.fold_specifications)} folds over the length limits " + f"(max_protein_length={MAX_PROTEIN_LENGTH}, max_total_length=" + f"{MAX_TOTAL_LENGTH} for {INFERENCE_BACKEND}); see {_skipped_path}" + ) + if _unknown_proteins: + logger.warning( + f"[length-filter] length unknown for {len(_unknown_proteins)} protein(s); " + f"their folds were kept (fail-open). Examples: " + f"{sorted(_unknown_proteins)[:5]}" + ) + +# Proteins required by the surviving folds (matches FoldDataset dedup ordering). +kept_proteins = [] +_seen_proteins = set() +for fold in kept_folds: + for _name in dataset.sequences_by_fold.get(fold, ()): + if _name not in _seen_proteins: + _seen_proteins.add(_name) + kept_proteins.append(_name) + required_folds = [ join(config["output_directory"], "predictions", fold, "completed_fold.txt") - for fold in dataset.fold_specifications + for fold in kept_folds ] ENABLE_STRUCTURE_ANALYSIS = config.get("enable_structure_analysis", True) GENERATE_RECURSIVE_REPORT = config.get("generate_recursive_report", False) @@ -149,7 +272,7 @@ RECURSIVE_REPORT = ( required_feature_paths = [ join(config["output_directory"], "features", feature_name(fasta_basename)) - for fasta_basename in dataset.unique_sequences + for fasta_basename in kept_proteins ] if config.get("only_generate_features", False): required_targets = required_feature_paths @@ -160,7 +283,7 @@ else: else: required_targets = [ join(config["output_directory"], "predictions", fold, "interfaces.csv") - for fold in dataset.fold_specifications + for fold in kept_folds ] else: required_targets = required_folds @@ -224,12 +347,7 @@ rule download_uniprot: MEM_SAFETY_FACTOR = float(config.get("mem_safety_factor", 1.25)) MAX_MEM_MB = int(config.get("max_mem_mb", 0) or 0) # 0 = no cap -# Feature backend = the data pipeline; inference backend = the fold backend -# (falls back to the data pipeline when --fold_backend is unset). -FEATURE_BACKEND = normalize_backend(DATA_PIPELINE) -INFERENCE_BACKEND = normalize_backend( - config.get("structure_inference_arguments", {}).get("--fold_backend", DATA_PIPELINE) -) +# Backend defaults (FEATURE_BACKEND / INFERENCE_BACKEND are defined near the top). _feat_ram_defaults = FEATURE_RAM_DEFAULTS[FEATURE_BACKEND] _infer_ram_defaults = INFERENCE_RAM_DEFAULTS[INFERENCE_BACKEND] @@ -356,6 +474,7 @@ rule structure_inference: protein_delimiter, features_dir=join(config["output_directory"], "features"), is_af3=(INFERENCE_BACKEND == "alphafold3"), + length_cache=sequence_length_cache, ), base_mb=structure_inference_base_ram, per_token_sq_mb=structure_inference_ram_per_token_sq, diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk index b8d52f4..226ef2d 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -11,10 +11,16 @@ import functools import inspect import json import os +import urllib.request from collections.abc import Iterable from pathlib import Path from typing import Any, Callable +# Default maximum *total* complex length (residues) per backend, used to skip +# folds that are too large to be feasible. AF3 supports larger inputs than +# AF2-Multimer. Override via config (see Snakefile / config.yaml). +MAX_TOTAL_LENGTH_DEFAULTS = {"alphafold2": 5000, "alphafold3": 7000} + @functools.lru_cache(maxsize=None) def residue_count(fasta_path: str) -> int: @@ -65,21 +71,69 @@ def af3_input_residue_count(json_path: str) -> int: return total +def parse_fold_chains(fold: str, delimiter: str = "+") -> list[tuple[str, int]]: + """Parse a fold spec into ``(chain_name, copies)`` pairs. + + Follows the AlphaPulldown ``name[:copies][:region...]`` convention: the copy + number, if present, is the first ``:`` token after the name and is a bare + integer (e.g. ``A:2`` = dimer, ``A:2:1-100`` = dimer of residues 1-100). + Region tokens such as ``1-100`` are never bare integers, so ``A:1-100`` is a + single copy. Handles ``A+B`` heteromers. + """ + chains: list[tuple[str, int]] = [] + for token in str(fold).split(delimiter): + parts = [part for part in token.split(":") if part] + if not parts: + continue + name = parts[0] + copies = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 1 + chains.append((name, copies)) + return chains + + +@functools.lru_cache(maxsize=None) +def fetch_uniprot_length(uniprot_id: str, timeout: float = 30.0) -> int: + """Residue length of a UniProt entry via the REST API; 0 on any failure. + + Mirrors the reference snippet in issue #33. Used at parse time for length + filtering when no local FASTA is available yet; failures return 0 so the + caller can fail open (keep the fold) rather than crash offline. + """ + url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.fasta" + try: + with urllib.request.urlopen(url, timeout=timeout) as response: + text = response.read().decode("utf-8", "replace") + except Exception: + return 0 + total = 0 + for line in text.splitlines(): + if not line.startswith(">"): + total += len(line.strip()) + return total + + def chain_residue_count( - name: str, data_dir: str, features_dir: str | None = None, is_af3: bool = False + name: str, + data_dir: str, + features_dir: str | None = None, + is_af3: bool = False, + length_cache: dict | None = None, ) -> int: """Residue length of a single chain. - Reads ``/.fasta`` first; if that is unavailable (returns 0) - and the run is AlphaFold 3, falls back to the precomputed - ``/_af3_input.json`` so length-aware sizing still works - when features are supplied via ``feature_directory`` rather than generated. + Resolution order: ``/.fasta`` -> AF3 precomputed + ``/_af3_input.json`` (AF3 only) -> ``length_cache`` (the + parse-time length table, which covers the AF2 precomputed-feature case where + neither a FASTA nor an AF3 JSON exists). Returns 0 when length is unknown so + sizing degrades to the base allocation plus retry escalation. """ length = residue_count(os.path.join(data_dir, f"{name}.fasta")) if length == 0 and is_af3 and features_dir: length = af3_input_residue_count( os.path.join(features_dir, f"{name}_af3_input.json") ) + if length == 0 and length_cache: + length = int(length_cache.get(name, 0) or 0) return length @@ -89,6 +143,7 @@ def fold_total_tokens( delimiter: str = "+", features_dir: str | None = None, is_af3: bool = False, + length_cache: dict | None = None, ) -> int: """Total residue (token) count of a fold specification. @@ -96,20 +151,46 @@ def fold_total_tokens( such as ``A:2`` (a homo-dimer counts twice). Region selections such as ``A:1-100`` are conservatively counted at the chain's full length, which over- rather than under-estimates memory. Per-chain lengths come from - ``/.fasta`` with an AF3 precomputed-feature fallback (see - ``chain_residue_count``). + ``chain_residue_count`` (FASTA -> AF3 JSON -> length cache). + + Note: AF3 ligand atoms are not counted (no ``sequence`` field); for + protein/nucleic complexes this matches the token count, and the safety + margin plus retry escalation absorb any small ligand undercount. """ total = 0 - for token in str(fold).split(delimiter): - parts = [part for part in token.split(":") if part] - if not parts: - continue - name = parts[0] - copies = int(parts[-1]) if len(parts) > 1 and parts[-1].isdigit() else 1 - total += chain_residue_count(name, data_dir, features_dir, is_af3) * copies + for name, copies in parse_fold_chains(fold, delimiter): + total += ( + chain_residue_count(name, data_dir, features_dir, is_af3, length_cache) + * copies + ) return total +def fold_length_violation( + chain_lengths: list[tuple[str, int | None, int]], + max_protein_length: int = 0, + max_total_length: int = 0, +) -> str | None: + """Return a human-readable reason if a fold exceeds a length limit, else None. + + ``chain_lengths`` is a list of ``(name, length_or_None, copies)``. Limits of + 0 (or negative) are disabled. Unknown lengths (``None``) are treated as 0 so + the decision fails open (the fold is kept) rather than dropped on missing data. + """ + if max_protein_length and max_protein_length > 0: + for name, length, _copies in chain_lengths: + if length is not None and length > max_protein_length: + return ( + f"protein {name} length {length} exceeds " + f"max_protein_length {max_protein_length}" + ) + if max_total_length and max_total_length > 0: + total = sum((length or 0) * copies for _name, length, copies in chain_lengths) + if total > max_total_length: + return f"total length {total} exceeds max_total_length {max_total_length}" + return None + + def _cap_mem(value_mb: float, cap_mb: int) -> int: value = max(int(value_mb), 1) if cap_mb and cap_mb > 0: From 03beeccc6138e615d5fe8068b351d5a61a52cf6d Mon Sep 17 00:00:00 2001 From: Dima Molodenskiy Date: Sat, 23 May 2026 21:33:59 +0200 Subject: [PATCH 5/8] fix: don't cache missing-file length reads (collapsed sizing to base) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Caught by a real SLURM submission: both create_features jobs requested mem=80000 (= safety*base, i.e. length 0) despite the FASTAs being present. Root cause: residue_count()/af3_input_residue_count() were lru_cached. Snakemake's scheduler evaluates resource functions early — before the upstream download_uniprot localrule has produced data/.fasta — so a 0 was memoised for the not-yet-existing file and then returned even after the file appeared, collapsing length-aware sizing to the base allocation. Fix: - Cache only successful (>0) reads; a missing/unreadable file returns 0 without caching, so it is re-read once produced. - create_features now sizes from the parse-time length cache (populated before any job runs) with the staged FASTA / rule input as fallback, so it no longer depends on when Snakemake evaluates the resource relative to the download. Verified on the real cluster: P01258 (141 aa) -> sbatch --mem 87050, P00533 (1210 aa) -> --mem 140500 (= 1.25*(64000+40*L)); two jobs, different length-aware memory. Adds a regression test that a missing file is not cached. 19 unit tests pass. Co-Authored-By: Claude Opus 4.7 --- test/test_memory_resources.py | 21 +++++++++++++---- workflow/Snakefile | 11 ++++++++- workflow/rules/common.smk | 44 +++++++++++++++++++++++------------ 3 files changed, 55 insertions(+), 21 deletions(-) diff --git a/test/test_memory_resources.py b/test/test_memory_resources.py index d032bb6..4d0e70c 100644 --- a/test/test_memory_resources.py +++ b/test/test_memory_resources.py @@ -37,7 +37,7 @@ def _write_fasta(directory: str, name: str, length: int) -> str: def test_residue_count_counts_sequence_only(): - common.residue_count.cache_clear() + common._RESIDUE_COUNT_CACHE.clear() with tempfile.TemporaryDirectory() as d: p = _write_fasta(d, "X", 137) assert common.residue_count(p) == 137 @@ -45,8 +45,19 @@ def test_residue_count_counts_sequence_only(): assert common.residue_count(os.path.join(d, "does_not_exist.fasta")) == 0 +def test_residue_count_does_not_cache_missing_file(): + """Regression: an early read of a not-yet-created file must NOT cache 0, or + length-aware sizing collapses to the base once the file later appears.""" + common._RESIDUE_COUNT_CACHE.clear() + with tempfile.TemporaryDirectory() as d: + p = os.path.join(d, "late.fasta") + assert common.residue_count(p) == 0 # file absent now + _write_fasta(d, "late", 321) # produced by an upstream job later + assert common.residue_count(p) == 321 # re-read, not the stale 0 + + def test_fold_total_tokens_sums_chains_and_copies(): - common.residue_count.cache_clear() + common._RESIDUE_COUNT_CACHE.clear() with tempfile.TemporaryDirectory() as d: _write_fasta(d, "A", 200) _write_fasta(d, "B", 300) @@ -62,8 +73,8 @@ def test_fold_total_tokens_sums_chains_and_copies(): def test_fold_total_tokens_af3_precomputed_feature_fallback(): """When data/.fasta is absent (precomputed features), AF3 length comes from the /_af3_input.json fallback.""" - common.residue_count.cache_clear() - common.af3_input_residue_count.cache_clear() + common._RESIDUE_COUNT_CACHE.clear() + common._AF3_INPUT_COUNT_CACHE.clear() with tempfile.TemporaryDirectory() as d, tempfile.TemporaryDirectory() as feat: _write_fasta(d, "B", 300) # only B has a FASTA; A is "precomputed" with open(os.path.join(feat, "A_af3_input.json"), "w") as fh: @@ -226,7 +237,7 @@ def _boom(url, timeout=30.0): def test_chain_residue_count_length_cache_fallback(): """AF2 precomputed features: no FASTA and no AF3 JSON, but the parse-time length cache supplies the length.""" - common.residue_count.cache_clear() + common._RESIDUE_COUNT_CACHE.clear() with tempfile.TemporaryDirectory() as d: # no data/A.fasta exists; cache provides it assert common.chain_residue_count("A", d) == 0 diff --git a/workflow/Snakefile b/workflow/Snakefile index b42aedb..2a09349 100644 --- a/workflow/Snakefile +++ b/workflow/Snakefile @@ -407,8 +407,17 @@ rule create_features: resources: qos=DEFAULT_SLURM_QOS, **linear_resources( + # Length from the parse-time cache (populated before any job runs) or + # the staged FASTA, falling back to the rule input. Using the cache + # avoids depending on when Snakemake evaluates the resource relative to + # the upstream download. mem_fn=lambda wildcards, input, attempt: estimate_feature_mem_mb( - residue_count(str(input[0])) if input else 0, + chain_residue_count( + wildcards.fasta_basename, + join(config["output_directory"], "data"), + length_cache=sequence_length_cache, + ) + or (residue_count(str(input[0])) if input else 0), base_mb=base_feature_ram, per_residue_mb=feature_ram_per_residue, scaling=feature_scaling, diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk index 226ef2d..b6fb80f 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -22,39 +22,51 @@ from typing import Any, Callable MAX_TOTAL_LENGTH_DEFAULTS = {"alphafold2": 5000, "alphafold3": 7000} -@functools.lru_cache(maxsize=None) +# Length lookups cache only *successful* (>0) reads. Caching a 0 from a not-yet- +# created file would be a correctness bug: Snakemake's scheduler evaluates resource +# functions early (before upstream download/symlink rules run), and a memoised 0 +# would then stick even after the file appears, collapsing length-aware sizing to +# the base allocation. Re-reading a small FASTA/JSON on later calls is cheap. +_RESIDUE_COUNT_CACHE: dict[str, int] = {} +_AF3_INPUT_COUNT_CACHE: dict[str, int] = {} + + def residue_count(fasta_path: str) -> int: """Number of residues in a (single-record) FASTA file. - Counts sequence characters, ignoring the header line(s) and whitespace. - Returns 0 when the file cannot be read yet (e.g. during a dry-run before - the upstream download/symlink rule has produced it) so that resource - estimation degrades gracefully to the base allocation instead of crashing. - Results are memoised because the structure-inference estimator may look up - the same chain repeatedly within a workflow. + Counts sequence characters, ignoring header lines and whitespace. Returns 0 + when the file cannot be read yet (so estimation degrades to the base + allocation rather than crashing) and does not cache that 0 — see the note + above on why caching a missing-file result would be wrong. """ + cached = _RESIDUE_COUNT_CACHE.get(fasta_path) + if cached: + return cached try: total = 0 with open(fasta_path) as handle: for line in handle: - if line.startswith(">"): - continue - total += len(line.strip()) - return total + if not line.startswith(">"): + total += len(line.strip()) except OSError: return 0 + if total > 0: + _RESIDUE_COUNT_CACHE[fasta_path] = total + return total -@functools.lru_cache(maxsize=None) def af3_input_residue_count(json_path: str) -> int: """Total polymer residues in an AlphaFold 3 ``*_af3_input.json`` feature file. Sums the ``sequence`` length of every protein/RNA/DNA entry under ``sequences`` (ligands have no sequence and are skipped). Returns 0 if the - file is missing or not parseable. Used as a fallback for the chain length - when no ``data/.fasta`` exists (e.g. precomputed features supplied via - ``feature_directory``, where the download/feature rules never run). + file is missing or not parseable (not cached); used as a fallback for the + chain length when no ``data/.fasta`` exists (e.g. precomputed features + supplied via ``feature_directory``). """ + cached = _AF3_INPUT_COUNT_CACHE.get(json_path) + if cached: + return cached try: with open(json_path) as handle: data = json.load(handle) @@ -68,6 +80,8 @@ def af3_input_residue_count(json_path: str) -> int: mol_entry = entry.get(mol) if isinstance(mol_entry, dict): total += len(mol_entry.get("sequence", "") or "") + if total > 0: + _AF3_INPUT_COUNT_CACHE[json_path] = total return total From 0553f5bcabe0615fc25cc074cffa62379e579ec2 Mon Sep 17 00:00:00 2001 From: Dima Molodenskiy Date: Sat, 23 May 2026 21:59:29 +0200 Subject: [PATCH 6/8] feat: length-based GPU model selection for structure_inference Route each complex to the smallest GPU tier that fits it, instead of pinning a single GPU model. Peak GPU VRAM is ~0.0045*N^2, so a small complex wastes a big card and a large one OOMs a small card; this picks the right tier per job. New config: structure_inference_gpu_model_by_tokens, a map of inclusive upper total-token bound -> GPU model, e.g. {2000: "3090", 4000: "A100", 99999: "H100"}. Each fold routes to the smallest tier it fits (larger than all -> last tier). When set it takes precedence over structure_inference_gpu_model; when unset the fixed model is used (unchanged behaviour). select_gpu_model() in common.smk; the gpu_model resource is now a callable that derives N via fold_total_tokens (same length source + cache as memory sizing) -> sbatch --gpus=:. Verified on the cluster (3 folds, AF3): N=800 -> --gpus=3090:1 (--mem 23600), N=3500 -> --gpus=A100:1 (--mem 88906), N=4600 -> --gpus=H100:1 (--mem 139025); GPU model and memory both scale with length in the real sbatch call. 20 unit tests pass (added select_gpu_model coverage). Co-Authored-By: Claude Opus 4.7 --- README.md | 15 +++++++++++++++ config/config.yaml | 9 +++++++++ test/test_memory_resources.py | 15 +++++++++++++++ workflow/Snakefile | 26 +++++++++++++++++++++++++- workflow/rules/common.smk | 19 +++++++++++++++++++ 5 files changed, 83 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index c7c16d7..cace858 100644 --- a/README.md +++ b/README.md @@ -241,6 +241,21 @@ you hit these. - **Restrict to one model** with `structure_inference_gpu_model` (e.g. `"A100"`) → the plugin emits `--gpus=:`. Accepts a single model name; leave `""` for any. +- **Route by complex size** with `structure_inference_gpu_model_by_tokens` → a map of (inclusive) + upper total-token bound to GPU model, so small complexes use small GPUs and large ones go to + big-VRAM cards (peak VRAM ≈ `0.0045·N²` MB): + + ```yaml + structure_inference_gpu_model_by_tokens: + 2000: "3090" # <=2000 tokens -> 24 GB + 4000: "A100" # <=4000 tokens -> 80 GB + 99999: "H100" # larger -> biggest available (spills via unified memory) + ``` + + Each fold is routed to the smallest tier it fits (larger than all bounds → the last tier). When + set this takes precedence over `structure_inference_gpu_model`. This is the practical "fit to GPU" + lever: requested host RAM is a separate pool and does not size GPU VRAM, but choosing the GPU model + by length does. - **Exclude specific nodes** with `slurm_exclude_nodes` → passed verbatim to `sbatch --exclude` (e.g. `"gpu50,gpu51"`). Use it for nodes whose GPU the container can't use — e.g. a CUDA compute capability newer than the container's bundled `ptxas` (fails `ptxas too old` / `UNIMPLEMENTED`). diff --git a/config/config.yaml b/config/config.yaml index e8171b5..a1e3468 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -122,6 +122,15 @@ structure_inference_gpus_per_task: 1 # Restrict structure_inference to one GPU model (sbatch --gpus=:N), e.g. "3090". # Leave "" to let SLURM pick any GPU in the partition. structure_inference_gpu_model: "3090" +# Optional length-based GPU routing: map an (inclusive) upper TOTAL-token bound to a +# GPU model, so small complexes go to small GPUs and large ones to big-VRAM cards. +# Each complex is routed to the smallest tier it fits (larger than all -> last tier). +# When set, this takes precedence over structure_inference_gpu_model. Example +# (peak VRAM ~ 0.0045 * N^2 MB, so ~2000 tok fits 24 GB, ~4000 tok needs ~80 GB): +# structure_inference_gpu_model_by_tokens: +# 2000: "3090" # <=2000 tokens -> 24 GB +# 4000: "A100" # <=4000 tokens -> 80 GB +# 99999: "H100" # larger -> biggest available (will spill via unified memory) # Optional: comma-separated nodes to keep structure_inference OFF, passed to sbatch # as --exclude. Useful for GPUs the prediction container cannot use (e.g. a CUDA # compute capability the bundled ptxas is too old for). Example: diff --git a/test/test_memory_resources.py b/test/test_memory_resources.py index 4d0e70c..ec8a784 100644 --- a/test/test_memory_resources.py +++ b/test/test_memory_resources.py @@ -258,6 +258,21 @@ def test_af3_input_residue_count_skips_ligands(): assert common.af3_input_residue_count(p) == 100 # ligand contributes 0 +def test_select_gpu_model(): + tiers = {2000: "3090", 4000: "A100", 99999: "H100"} + assert common.select_gpu_model(800, tiers) == "3090" # small -> small GPU + assert common.select_gpu_model(2000, tiers) == "3090" # boundary inclusive + assert common.select_gpu_model(2001, tiers) == "A100" + assert common.select_gpu_model(4000, tiers) == "A100" + assert common.select_gpu_model(4600, tiers) == "H100" + assert common.select_gpu_model(500000, tiers) == "H100" # bigger than all -> last + # unordered / string keys (as parsed from YAML) still work + assert common.select_gpu_model(3000, {"4000": "A100", "2000": "3090"}) == "A100" + # empty map -> fall back to the fixed default model + assert common.select_gpu_model(3000, {}, default_model="3090") == "3090" + assert common.select_gpu_model(3000, {}) is None + + def test_mem_mb_reaches_sbatch_via_real_plugin(): """Integration: the value our model computes is what the SLURM plugin turns into `sbatch --mem`. Skips gracefully if the plugin isn't importable.""" diff --git a/workflow/Snakefile b/workflow/Snakefile index 2a09349..33cfdee 100644 --- a/workflow/Snakefile +++ b/workflow/Snakefile @@ -372,6 +372,30 @@ structure_inference_runtime_minutes = config.get( "structure_inference_runtime_minutes", _infer_ram_defaults["runtime_minutes"] ) +# Optional length-based GPU routing: map an (inclusive) upper token bound to a GPU +# model, e.g. {2000: "3090", 4000: "A100", 99999: "H100"}. Each complex is routed +# to the smallest tier it fits (bigger-than-all -> largest tier). When unset, the +# fixed structure_inference_gpu_model is used (previous behaviour). +GPU_MODEL_BY_TOKENS = config.get("structure_inference_gpu_model_by_tokens") or {} +_GPU_MODEL_CONFIGURED = bool(GPU_MODEL_BY_TOKENS) or ( + DEFAULT_STRUCTURE_INFERENCE_GPU_MODEL is not None +) + + +def _inference_gpu_model(wildcards): + """gpu_model resource: length-tier model, or the fixed model, or '' (any GPU).""" + tokens = fold_total_tokens( + wildcards.fold, + join(config["output_directory"], "data"), + protein_delimiter, + features_dir=join(config["output_directory"], "features"), + is_af3=(INFERENCE_BACKEND == "alphafold3"), + length_cache=sequence_length_cache, + ) + return select_gpu_model( + tokens, GPU_MODEL_BY_TOKENS, DEFAULT_STRUCTURE_INFERENCE_GPU_MODEL + ) or "" + rule symlink_features: input: precomputed_features, @@ -472,7 +496,7 @@ rule structure_inference: slurm_partition=(DEFAULT_SLURM_PARTITION or "gpu"), qos=DEFAULT_SLURM_QOS, gpu=DEFAULT_STRUCTURE_INFERENCE_GPUS, - **({"gpu_model": DEFAULT_STRUCTURE_INFERENCE_GPU_MODEL} if DEFAULT_STRUCTURE_INFERENCE_GPU_MODEL is not None else {}), + **({"gpu_model": _inference_gpu_model} if _GPU_MODEL_CONFIGURED else {}), **({"slurm_extra": f"--exclude={DEFAULT_SLURM_EXCLUDE}"} if DEFAULT_SLURM_EXCLUDE else {}), tasks_per_gpu=DEFAULT_STRUCTURE_INFERENCE_TASKS_PER_GPU, **linear_resources( diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk index b6fb80f..797e77d 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -205,6 +205,25 @@ def fold_length_violation( return None +def select_gpu_model(total_tokens: int, tiers, default_model=None): + """Pick a GPU model for a complex of ``total_tokens`` from a token->model map. + + ``tiers`` maps an (inclusive) upper token bound to a GPU model name, e.g. + ``{2000: "3090", 4000: "A100", 99999: "H100"}``. Returns the model of the + smallest tier whose bound is >= ``total_tokens``; a complex larger than every + tier falls back to the largest tier's model (best available). When ``tiers`` + is empty, returns ``default_model`` (the fixed ``structure_inference_gpu_model``, + preserving the previous single-model behaviour). + """ + if not tiers: + return default_model + ordered = sorted((int(bound), str(model)) for bound, model in tiers.items()) + for bound, model in ordered: + if total_tokens <= bound: + return model + return ordered[-1][1] + + def _cap_mem(value_mb: float, cap_mb: int) -> int: value = max(int(value_mb), 1) if cap_mb and cap_mb > 0: From e5901b21a5c0966ae73dca78633f5c66643e69d8 Mon Sep 17 00:00:00 2001 From: Dima Molodenskiy Date: Sat, 23 May 2026 22:17:34 +0200 Subject: [PATCH 7/8] refactor: VRAM node-exclusion GPU routing (use the whole pool) Replaces the token->single-model GPU routing with VRAM-based node exclusion, so a complex runs on ANY GPU with enough memory instead of one pinned model. This matters where several GPU models share a VRAM tier (e.g. EMBL gpu-el8 has ~88 nodes at 48 GB across A40+L40s but only 2 at 80 GB) and mirrors the handoff's "exclude all <80 GB nodes" rescue. New config (cluster-agnostic; list your own tiers): structure_inference_gpu_tiers: [{min_vram_gb, nodes}, ...] structure_inference_gpu_vram_headroom: 1.0 # <1.0 tolerates that much host spill Each complex's estimated peak VRAM (~ structure_inference_ram_per_token_sq * N^2) picks the smallest tier that fits; nodes of all smaller tiers are excluded (largest tier if none fits -> spill via unified memory). Drives a per-job slurm_extra=--exclude=..., merged with the static slurm_exclude_nodes, and overrides structure_inference_gpu_model (the two would conflict). Stays within one partition (EMBL's bigger gpu-training cards are out of scope; the tail spills to host). common.smk: required_gpu_vram_gb(), gpu_exclude_nodes() (replaces select_gpu_model). config.yaml/README: EMBL gpu-el8 example (24/40/48/80 GB tiers) with notes that the RTX PRO 6000 (ptxas-incompatible) stays in slurm_exclude_nodes and gpu-training is separate; emphasises it is just an example. Verified live (3 AF3 folds): N=800 -> --exclude=gpu50 (any GPU); N=3500 and N=4600 -> --exclude all 24+48 GB nodes + gpu50 (80 GB pool, spill for the tail); --mem still 23600/88906/139025. 21 unit tests pass. Co-Authored-By: Claude Opus 4.7 --- README.md | 28 +++++++++------- config/config.yaml | 26 +++++++++------ test/test_memory_resources.py | 44 +++++++++++++++++-------- workflow/Snakefile | 53 +++++++++++++++++++++---------- workflow/rules/common.smk | 60 +++++++++++++++++++++++++---------- 5 files changed, 145 insertions(+), 66 deletions(-) diff --git a/README.md b/README.md index cace858..a41f6cc 100644 --- a/README.md +++ b/README.md @@ -241,21 +241,27 @@ you hit these. - **Restrict to one model** with `structure_inference_gpu_model` (e.g. `"A100"`) → the plugin emits `--gpus=:`. Accepts a single model name; leave `""` for any. -- **Route by complex size** with `structure_inference_gpu_model_by_tokens` → a map of (inclusive) - upper total-token bound to GPU model, so small complexes use small GPUs and large ones go to - big-VRAM cards (peak VRAM ≈ `0.0045·N²` MB): +- **Route by complex size (VRAM)** with `structure_inference_gpu_tiers` → list your GPU pool as + tiers of `{min_vram_gb, nodes}`. A complex's estimated peak VRAM (≈ `per_token_sq·N²`) selects the + smallest tier that fits and all *smaller*-GPU nodes are excluded, so the job runs on **any** GPU at + or above that tier — using the whole pool, not one pinned model. A complex larger than every tier + uses the biggest tier and spills to host RAM via unified memory. ```yaml - structure_inference_gpu_model_by_tokens: - 2000: "3090" # <=2000 tokens -> 24 GB - 4000: "A100" # <=4000 tokens -> 80 GB - 99999: "H100" # larger -> biggest available (spills via unified memory) + # Example for EMBL gpu-el8 — replace nodes with your cluster's (nothing is hard-coded): + structure_inference_gpu_vram_headroom: 1.0 # <1.0 tolerates that fraction of host spill + structure_inference_gpu_tiers: + - {min_vram_gb: 24, nodes: "gpu21,gpu22,gpu29,gpu30,gpu31,gpu32,gpu33,gpu34,gpu35,gpu36,gpu37"} + - {min_vram_gb: 40, nodes: "gpu25,gpu26,gpu27,gpu28"} + - {min_vram_gb: 48, nodes: "gpu40,gpu41,gpu42,gpu43,gpu44,gpu45,gpu46,gpu47,gpu48"} + - {min_vram_gb: 80, nodes: "gpu38,gpu39"} ``` - Each fold is routed to the smallest tier it fits (larger than all bounds → the last tier). When - set this takes precedence over `structure_inference_gpu_model`. This is the practical "fit to GPU" - lever: requested host RAM is a separate pool and does not size GPU VRAM, but choosing the GPU model - by length does. + When set this drives `--exclude` per job and **overrides** `structure_inference_gpu_model` (the two + would conflict). It's the practical "fit to GPU" lever: requested host RAM is a separate pool and + does not size GPU VRAM, but excluding too-small GPUs by length does. Use explicit comma node lists + (bracket ranges may be glob-expanded by the shell). Multi-partition routing (e.g. EMBL's bigger + `gpu-training` cards) is out of scope — keep one partition and let unified memory spill the tail. - **Exclude specific nodes** with `slurm_exclude_nodes` → passed verbatim to `sbatch --exclude` (e.g. `"gpu50,gpu51"`). Use it for nodes whose GPU the container can't use — e.g. a CUDA compute capability newer than the container's bundled `ptxas` (fails `ptxas too old` / `UNIMPLEMENTED`). diff --git a/config/config.yaml b/config/config.yaml index a1e3468..68662e5 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -122,15 +122,23 @@ structure_inference_gpus_per_task: 1 # Restrict structure_inference to one GPU model (sbatch --gpus=:N), e.g. "3090". # Leave "" to let SLURM pick any GPU in the partition. structure_inference_gpu_model: "3090" -# Optional length-based GPU routing: map an (inclusive) upper TOTAL-token bound to a -# GPU model, so small complexes go to small GPUs and large ones to big-VRAM cards. -# Each complex is routed to the smallest tier it fits (larger than all -> last tier). -# When set, this takes precedence over structure_inference_gpu_model. Example -# (peak VRAM ~ 0.0045 * N^2 MB, so ~2000 tok fits 24 GB, ~4000 tok needs ~80 GB): -# structure_inference_gpu_model_by_tokens: -# 2000: "3090" # <=2000 tokens -> 24 GB -# 4000: "A100" # <=4000 tokens -> 80 GB -# 99999: "H100" # larger -> biggest available (will spill via unified memory) +# Optional length-based GPU routing by VRAM (within this partition). List your GPU +# pool as tiers of {min_vram_gb, nodes}; a complex's estimated peak VRAM +# (~ structure_inference_ram_per_token_sq * N^2) selects the smallest tier that fits +# and all SMALLER-GPU nodes are excluded, so the job runs on ANY GPU >= that tier +# (the whole pool, not one pinned model). A complex larger than every tier uses the +# biggest tier and spills to host RAM via unified memory. When set, this drives +# --exclude and OVERRIDES structure_inference_gpu_model. Use explicit comma node lists +# (avoid bracket ranges, which the shell may glob). The example below is for the EMBL +# gpu-el8 partition - replace nodes with your cluster's; nothing is hard-coded. +# structure_inference_gpu_vram_headroom: 1.0 # <1.0 tolerates that fraction of host spill +# structure_inference_gpu_tiers: +# - {min_vram_gb: 24, nodes: "gpu21,gpu22,gpu29,gpu30,gpu31,gpu32,gpu33,gpu34,gpu35,gpu36,gpu37"} # RTX 3090 +# - {min_vram_gb: 40, nodes: "gpu25,gpu26,gpu27,gpu28"} # A100 40GB +# - {min_vram_gb: 48, nodes: "gpu40,gpu41,gpu42,gpu43,gpu44,gpu45,gpu46,gpu47,gpu48"} # L40s/A40 48GB +# - {min_vram_gb: 80, nodes: "gpu38,gpu39"} # H100 PCIe 80GB +# Note: RTX PRO 6000 (gpu50-53, 96GB) are ptxas-incompatible -> keep in slurm_exclude_nodes. +# H100-SXM/H200/B200 live on the separate gpu-training partition (not routed here). # Optional: comma-separated nodes to keep structure_inference OFF, passed to sbatch # as --exclude. Useful for GPUs the prediction container cannot use (e.g. a CUDA # compute capability the bundled ptxas is too old for). Example: diff --git a/test/test_memory_resources.py b/test/test_memory_resources.py index ec8a784..498a03d 100644 --- a/test/test_memory_resources.py +++ b/test/test_memory_resources.py @@ -258,19 +258,37 @@ def test_af3_input_residue_count_skips_ligands(): assert common.af3_input_residue_count(p) == 100 # ligand contributes 0 -def test_select_gpu_model(): - tiers = {2000: "3090", 4000: "A100", 99999: "H100"} - assert common.select_gpu_model(800, tiers) == "3090" # small -> small GPU - assert common.select_gpu_model(2000, tiers) == "3090" # boundary inclusive - assert common.select_gpu_model(2001, tiers) == "A100" - assert common.select_gpu_model(4000, tiers) == "A100" - assert common.select_gpu_model(4600, tiers) == "H100" - assert common.select_gpu_model(500000, tiers) == "H100" # bigger than all -> last - # unordered / string keys (as parsed from YAML) still work - assert common.select_gpu_model(3000, {"4000": "A100", "2000": "3090"}) == "A100" - # empty map -> fall back to the fixed default model - assert common.select_gpu_model(3000, {}, default_model="3090") == "3090" - assert common.select_gpu_model(3000, {}) is None +def test_required_gpu_vram_gb(): + # 0.0045 MB/token^2: N=4836 -> ~105 GB; headroom scales it + assert round(common.required_gpu_vram_gb(4836, 0.0045)) == 105 + assert round(common.required_gpu_vram_gb(2066, 0.0045)) == 19 + assert common.required_gpu_vram_gb(4836, 0.0045, headroom=0.5) < 60 + + +def test_gpu_exclude_nodes_vram_routing(): + tiers = [ + {"min_vram_gb": 24, "nodes": "n24a,n24b"}, + {"min_vram_gb": 48, "nodes": "n48a,n48b"}, + {"min_vram_gb": 80, "nodes": "n80a"}, + ] + c = 0.0045 + # small complex (~3 GB) fits the smallest tier -> exclude nothing + assert common.gpu_exclude_nodes(800, tiers, c) == "" + # ~26 GB -> needs >=48 GB tier -> exclude the 24 GB nodes + assert common.gpu_exclude_nodes(2400, tiers, c) == "n24a,n24b" + # ~55 GB -> needs the 80 GB tier -> exclude 24 and 48 GB nodes + assert common.gpu_exclude_nodes(3500, tiers, c) == "n24a,n24b,n48a,n48b" + # bigger than every tier -> use largest tier (spill), exclude all smaller + assert common.gpu_exclude_nodes(20000, tiers, c) == "n24a,n24b,n48a,n48b" + # static extra excludes are always appended (e.g. ptxas-incompatible cards) + assert common.gpu_exclude_nodes(800, tiers, c, extra_exclude="gpu50,gpu51") == "gpu50,gpu51" + assert ( + common.gpu_exclude_nodes(2400, tiers, c, extra_exclude="gpu50") + == "n24a,n24b,gpu50" + ) + # unsorted tiers handled; no tiers -> only the static excludes + assert common.gpu_exclude_nodes(2400, [], c, extra_exclude="gpu50") == "gpu50" + assert common.gpu_exclude_nodes(2400, [], c) == "" def test_mem_mb_reaches_sbatch_via_real_plugin(): diff --git a/workflow/Snakefile b/workflow/Snakefile index 33cfdee..a7abfc4 100644 --- a/workflow/Snakefile +++ b/workflow/Snakefile @@ -372,18 +372,20 @@ structure_inference_runtime_minutes = config.get( "structure_inference_runtime_minutes", _infer_ram_defaults["runtime_minutes"] ) -# Optional length-based GPU routing: map an (inclusive) upper token bound to a GPU -# model, e.g. {2000: "3090", 4000: "A100", 99999: "H100"}. Each complex is routed -# to the smallest tier it fits (bigger-than-all -> largest tier). When unset, the -# fixed structure_inference_gpu_model is used (previous behaviour). -GPU_MODEL_BY_TOKENS = config.get("structure_inference_gpu_model_by_tokens") or {} -_GPU_MODEL_CONFIGURED = bool(GPU_MODEL_BY_TOKENS) or ( - DEFAULT_STRUCTURE_INFERENCE_GPU_MODEL is not None -) - - -def _inference_gpu_model(wildcards): - """gpu_model resource: length-tier model, or the fixed model, or '' (any GPU).""" +# Optional length-based GPU routing by VRAM (within one partition). Each tier is +# {min_vram_gb, nodes}; a complex's estimated peak VRAM picks the smallest tier that +# fits, and all smaller-GPU nodes are excluded so the job runs on ANY GPU >= that +# tier (the whole pool, not one pinned model). Cluster-agnostic: list your own tiers. +# When set, this drives `slurm_extra=--exclude=...` and takes precedence over the +# fixed structure_inference_gpu_model (which is dropped to avoid an impossible +# "pin model X but exclude its nodes" request). +GPU_TIERS = config.get("structure_inference_gpu_tiers") or [] +GPU_VRAM_HEADROOM = float(config.get("structure_inference_gpu_vram_headroom", 1.0)) + + +def _inference_slurm_extra(wildcards): + """slurm_extra resource: --exclude small-GPU nodes for large complexes, merged + with the static slurm_exclude_nodes.""" tokens = fold_total_tokens( wildcards.fold, join(config["output_directory"], "data"), @@ -392,9 +394,14 @@ def _inference_gpu_model(wildcards): is_af3=(INFERENCE_BACKEND == "alphafold3"), length_cache=sequence_length_cache, ) - return select_gpu_model( - tokens, GPU_MODEL_BY_TOKENS, DEFAULT_STRUCTURE_INFERENCE_GPU_MODEL - ) or "" + nodes = gpu_exclude_nodes( + tokens, + GPU_TIERS, + structure_inference_ram_per_token_sq, + GPU_VRAM_HEADROOM, + DEFAULT_SLURM_EXCLUDE or "", + ) + return f"--exclude={nodes}" if nodes else "" rule symlink_features: input: @@ -496,8 +503,20 @@ rule structure_inference: slurm_partition=(DEFAULT_SLURM_PARTITION or "gpu"), qos=DEFAULT_SLURM_QOS, gpu=DEFAULT_STRUCTURE_INFERENCE_GPUS, - **({"gpu_model": _inference_gpu_model} if _GPU_MODEL_CONFIGURED else {}), - **({"slurm_extra": f"--exclude={DEFAULT_SLURM_EXCLUDE}"} if DEFAULT_SLURM_EXCLUDE else {}), + # Pin a single GPU model only when VRAM-tier routing is NOT used (the two + # would conflict: tier routing may exclude the pinned model's nodes). + **( + {"gpu_model": DEFAULT_STRUCTURE_INFERENCE_GPU_MODEL} + if (DEFAULT_STRUCTURE_INFERENCE_GPU_MODEL is not None and not GPU_TIERS) + else {} + ), + # --exclude is dynamic (per-complex VRAM routing) when tiers are set, else the + # static slurm_exclude_nodes. + **( + {"slurm_extra": _inference_slurm_extra} + if (GPU_TIERS or DEFAULT_SLURM_EXCLUDE) + else {} + ), tasks_per_gpu=DEFAULT_STRUCTURE_INFERENCE_TASKS_PER_GPU, **linear_resources( mem_fn=lambda wildcards, attempt: estimate_inference_mem_mb( diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk index 797e77d..c7da452 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -205,23 +205,51 @@ def fold_length_violation( return None -def select_gpu_model(total_tokens: int, tiers, default_model=None): - """Pick a GPU model for a complex of ``total_tokens`` from a token->model map. - - ``tiers`` maps an (inclusive) upper token bound to a GPU model name, e.g. - ``{2000: "3090", 4000: "A100", 99999: "H100"}``. Returns the model of the - smallest tier whose bound is >= ``total_tokens``; a complex larger than every - tier falls back to the largest tier's model (best available). When ``tiers`` - is empty, returns ``default_model`` (the fixed ``structure_inference_gpu_model``, - preserving the previous single-model behaviour). +def required_gpu_vram_gb( + total_tokens: int, per_token_sq_mb: float, headroom: float = 1.0 +) -> float: + """Estimated peak GPU VRAM (GB) for a complex of ``total_tokens``. + + Uses the same O(N^2) coefficient as host-memory sizing as a proxy for on-device + peak demand, scaled by ``headroom`` (e.g. 0.8 tolerates ~20% spill to host). + """ + return headroom * per_token_sq_mb * (max(int(total_tokens), 0) ** 2) / 1000.0 + + +def gpu_exclude_nodes( + total_tokens: int, + tiers, + per_token_sq_mb: float, + headroom: float = 1.0, + extra_exclude: str = "", +) -> str: + """Comma-joined SLURM nodes to exclude so a complex lands on a big-enough GPU. + + ``tiers`` is an iterable of ``{"min_vram_gb": int, "nodes": ""}`` + describing the cluster's GPU pool. The complex's required VRAM + (:func:`required_gpu_vram_gb`) selects the smallest tier that satisfies it (the + largest tier if none does — the remainder spills to host via unified memory); + the nodes of every *smaller* tier are excluded, so the job may run on any GPU at + or above the chosen tier (the whole pool, not one pinned model). ``extra_exclude`` + (the static ``slurm_exclude_nodes``) is always appended. + + Cluster-agnostic: each site lists its own GPU tiers/nodes; nothing about a + specific cluster is hard-coded. """ - if not tiers: - return default_model - ordered = sorted((int(bound), str(model)) for bound, model in tiers.items()) - for bound, model in ordered: - if total_tokens <= bound: - return model - return ordered[-1][1] + parts: list[str] = [] + valid = [t for t in tiers if t and t.get("nodes")] + if valid: + ordered = sorted(valid, key=lambda t: int(t["min_vram_gb"])) + required = required_gpu_vram_gb(total_tokens, per_token_sq_mb, headroom) + chosen = len(ordered) - 1 + for index, tier in enumerate(ordered): + if int(tier["min_vram_gb"]) >= required: + chosen = index + break + parts.extend(str(tier["nodes"]) for tier in ordered[:chosen]) + if extra_exclude: + parts.append(str(extra_exclude)) + return ",".join(part for part in parts if part) def _cap_mem(value_mb: float, cap_mb: int) -> int: From 665f362a3dce329ef41981e628ef984ccfaf5f1d Mon Sep 17 00:00:00 2001 From: Dima Molodenskiy Date: Sat, 23 May 2026 23:01:47 +0200 Subject: [PATCH 8/8] docs: note length filtering applies to local/workstation runs too Clarify that the length filter (and only the length filter) runs during workflow parsing, so it affects every profile including local execution; the memory and GPU-routing settings are SLURM resources that local runs ignore. Document raising /zeroing max_total_length_* (and disabling the UniProt fetch) to attempt very large folds on a workstation. Co-Authored-By: Claude Opus 4.7 --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index a41f6cc..efff6c1 100644 --- a/README.md +++ b/README.md @@ -401,6 +401,11 @@ length_filter_fetch_uniprot: true # set false for fully offline runs UniProt outage never silently drops work. - First parse of a large all-UniProt sheet will fetch each unique length once (cached afterwards); already-downloaded inputs and local FASTAs are read without any network call. +- **Applies to every profile, including local/workstation runs** (it runs during workflow + parsing, not in the executor). It's the only length-aware feature that does — the memory + and GPU-routing settings are SLURM resources that local runs ignore. To attempt a complex + larger than the caps on a big workstation, raise or zero the `max_total_length_*` values + (and set `length_filter_fetch_uniprot: false` for offline use).