Skip to content
Merged
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ STAMP is an **end‑to‑end, weakly‑supervised deep‑learning pipeline** tha
* 🎓 **Beginner‑friendly & expert‑ready**: Zero‑code CLI and YAML config for routine use; optional code‑level customization for advanced research.
* 🧩 **Model‑rich**: Out‑of‑the‑box support for **+20 foundation models** at [tile level](getting-started.md#feature-extraction) (e.g., *Virchow‑v2*, *UNI‑v2*) and [slide level](getting-started.md#slide-level-encoding) (e.g., *TITAN*, *COBRA*).
* 🔬 **Weakly‑supervised**: End‑to‑end MIL with Transformer aggregation for training, cross‑validation and external deployment; no pixel‑level labels required.
* 🧮 **Multi-task learning**: Unified framework for **classification**, **regression**, and **cox-based survival analysis**.
* 🧮 **Multi-task learning**: Unified framework for **classification**, **multi-target classification**, **regression**, and **Cox-based survival analysis**.
* 📊 **Stats & results**: Built‑in metrics and patient‑level predictions, ready for analysis and reporting.
* 🖼️ **Explainable**: Generates heatmaps and top‑tile exports out‑of‑the‑box for transparent model auditing and publication‑ready figures.
* 🤝 **Collaborative by design**: Clinicians drive hypothesis & interpretation while engineers handle compute; STAMP’s modular CLI mirrors real‑world workflows and tracks every step for full reproducibility.
Expand Down
4 changes: 2 additions & 2 deletions src/stamp/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

import yaml

from stamp.config import StampConfig
from stamp.modeling.config import (
AdvancedConfig,
MlpModelParams,
ModelParams,
VitModelParams,
)
from stamp.seed import Seed
from stamp.utils.config import StampConfig
from stamp.utils.seed import Seed

STAMP_FACTORY_SETTINGS = Path(__file__).with_name("config.yaml")

Expand Down
24 changes: 22 additions & 2 deletions src/stamp/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ preprocessing:
# Extractor to use for feature extraction. Possible options are "ctranspath",
# "uni", "conch", "chief-ctranspath", "conch1_5", "uni2", "dino-bloom",
# "gigapath", "h-optimus-0", "h-optimus-1", "virchow2", "virchow",
# "virchow-full", "musk", "mstar", "plip"
# "virchow-full", "musk", "mstar", "plip", "ticon"
# Some of them require requesting access from the respective authors beforehand.
extractor: "chief-ctranspath"

Expand Down Expand Up @@ -76,6 +76,8 @@ crossval:

# Name of the column from the clini table to train on.
ground_truth_label: "KRAS"
# For multi-target classification you may specify a list of columns,
# e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"]

# For survival (should be status and follow-up days columns in clini table)
# status_label: "event"
Expand Down Expand Up @@ -133,6 +135,8 @@ training:

# Name of the column from the clini table to train on.
ground_truth_label: "KRAS"
# For multi-target classification you may specify a list of columns,
# e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"]

# For survival (should be status and follow-up days columns in clini table)
# status_label: "event"
Expand Down Expand Up @@ -175,6 +179,8 @@ deployment:

# Name of the column from the clini to compare predictions to.
ground_truth_label: "KRAS"
# For multi-target classification you may specify a list of columns,
# e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"]

# For survival (should be status and follow-up days columns in clini table)
# status_label: "event"
Expand All @@ -200,6 +206,8 @@ statistics:

# Name of the target label.
ground_truth_label: "KRAS"
# For multi-target classification you may specify a list of columns,
# e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"]

# A lot of the statistics are computed "one-vs-all", i.e. there needs to be
# a positive class to calculate the statistics for.
Expand Down Expand Up @@ -319,7 +327,7 @@ advanced_config:
max_lr: 1e-4
div_factor: 25.
# Select a model regardless of task
model_name: "vit" # or mlp, trans_mil
model_name: "vit" # or mlp, trans_mil, barspoon

model_params:
vit: # Vision Transformer
Expand All @@ -338,3 +346,15 @@ advanced_config:
dim_hidden: 512
num_layers: 2
dropout: 0.25

# NOTE: Only the `barspoon` model supports multi-target classification
# (i.e. `ground_truth_label` can be a list of column names). Other
# models expect a single target column.
barspoon: # Encoder-Decoder Transformer for multi-target classification
d_model: 512
num_encoder_heads: 8
num_decoder_heads: 8
num_encoder_layers: 2
num_decoder_layers: 2
dim_feedforward: 2048
positional_encoding: true
2 changes: 1 addition & 1 deletion src/stamp/encoding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def init_slide_encoder_(
selected_encoder = encoder

case _ as unreachable:
assert_never(unreachable) # type: ignore
assert_never(unreachable)

selected_encoder.encode_slides_(
output_dir=output_dir,
Expand Down
6 changes: 4 additions & 2 deletions src/stamp/encoding/encoder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from abc import ABC, abstractmethod
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import cast

import h5py
import numpy as np
Expand All @@ -12,11 +13,11 @@
from tqdm import tqdm

import stamp
from stamp.cache import get_processing_code_hash
from stamp.encoding.config import EncoderName
from stamp.modeling.data import CoordsInfo, get_coords, read_table
from stamp.preprocessing.config import ExtractorName
from stamp.types import DeviceLikeType, PandasLabel
from stamp.utils.cache import get_processing_code_hash

__author__ = "Juan Pablo Ricapito"
__copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito"
Expand Down Expand Up @@ -183,7 +184,8 @@ def _read_h5(
elif not h5_path.endswith(".h5"):
raise ValueError(f"File is not of type .h5: {os.path.basename(h5_path)}")
with h5py.File(h5_path, "r") as f:
feats: Tensor = torch.tensor(f["feats"][:], dtype=self.precision) # type: ignore
feats_ds = cast(h5py.Dataset, f["feats"])
feats: Tensor = torch.tensor(feats_ds[:], dtype=self.precision)
coords: CoordsInfo = get_coords(f)
extractor: str = f.attrs.get("extractor", "")
if extractor == "":
Expand Down
2 changes: 1 addition & 1 deletion src/stamp/encoding/encoder/chief.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from numpy import ndarray
from tqdm import tqdm

from stamp.cache import STAMP_CACHE_DIR, file_digest, get_processing_code_hash
from stamp.encoding.config import EncoderName
from stamp.encoding.encoder import Encoder
from stamp.preprocessing.config import ExtractorName
from stamp.types import DeviceLikeType, PandasLabel
from stamp.utils.cache import STAMP_CACHE_DIR, file_digest, get_processing_code_hash

__author__ = "Juan Pablo Ricapito"
__copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito"
Expand Down
2 changes: 1 addition & 1 deletion src/stamp/encoding/encoder/eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
from torch import Tensor
from tqdm import tqdm

from stamp.cache import get_processing_code_hash
from stamp.encoding.config import EncoderName
from stamp.encoding.encoder import Encoder
from stamp.encoding.encoder.chief import CHIEF
from stamp.modeling.data import CoordsInfo
from stamp.preprocessing.config import ExtractorName
from stamp.types import DeviceLikeType, PandasLabel
from stamp.utils.cache import get_processing_code_hash

__author__ = "Juan Pablo Ricapito"
__copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito"
Expand Down
2 changes: 1 addition & 1 deletion src/stamp/encoding/encoder/gigapath.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from gigapath import slide_encoder
from tqdm import tqdm

from stamp.cache import get_processing_code_hash
from stamp.encoding.config import EncoderName
from stamp.encoding.encoder import Encoder
from stamp.modeling.data import CoordsInfo
from stamp.preprocessing.config import ExtractorName
from stamp.types import PandasLabel, SlideMPP
from stamp.utils.cache import get_processing_code_hash

__author__ = "Juan Pablo Ricapito"
__copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito"
Expand Down
2 changes: 1 addition & 1 deletion src/stamp/encoding/encoder/madeleine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import torch
from numpy import ndarray

from stamp.cache import STAMP_CACHE_DIR
from stamp.encoding.config import EncoderName
from stamp.encoding.encoder import Encoder
from stamp.preprocessing.config import ExtractorName
from stamp.utils.cache import STAMP_CACHE_DIR

try:
from madeleine.models.factory import create_model_from_pretrained
Expand Down
2 changes: 1 addition & 1 deletion src/stamp/encoding/encoder/titan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from tqdm import tqdm
from transformers import AutoModel

from stamp.cache import get_processing_code_hash
from stamp.encoding.config import EncoderName
from stamp.encoding.encoder import Encoder
from stamp.modeling.data import CoordsInfo
from stamp.preprocessing.config import ExtractorName
from stamp.types import DeviceLikeType, Microns, PandasLabel, SlideMPP
from stamp.utils.cache import get_processing_code_hash

__author__ = "Juan Pablo Ricapito"
__copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito"
Expand Down
92 changes: 50 additions & 42 deletions src/stamp/heatmaps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
from collections.abc import Collection, Iterable
from pathlib import Path
from typing import cast, no_type_check
from typing import cast

import h5py
import matplotlib.pyplot as plt
Expand All @@ -19,7 +19,7 @@
from packaging.version import Version
from PIL import Image
from torch import Tensor
from torch.func import jacrev # pyright: ignore[reportPrivateImportUsage]
from torch.func import jacrev

from stamp.modeling.data import get_coords, get_stride
from stamp.modeling.deploy import load_model_from_ckpt
Expand All @@ -29,6 +29,8 @@

_logger = logging.getLogger("stamp")

_SlideLike = openslide.OpenSlide | openslide.ImageSlide


def _gradcam_per_category(
model: torch.nn.Module,
Expand All @@ -37,23 +39,19 @@ def _gradcam_per_category(
) -> Float[Tensor, "tile category"]:
feat_dim = -1

cam = (
(
feats
* jacrev(
lambda bags: model.forward(
bags.unsqueeze(0),
coords=coords.unsqueeze(0),
mask=None,
).squeeze(0)
)(feats)
)
.mean(feat_dim) # type: ignore
.abs()
jac = cast(
Tensor,
jacrev(
lambda bags: model.forward(
bags.unsqueeze(0),
coords=coords.unsqueeze(0),
mask=None,
).squeeze(0)
)(feats),
)

cam = (feats * jac).mean(feat_dim).abs()
cam = torch.softmax(cam, dim=-1)

return cam.permute(-1, -2)


Expand All @@ -79,12 +77,19 @@ def _attention_rollout_single(

# --- 2. Rollout computation ---
attn_rollout: torch.Tensor | None = None
for layer in model.transformer.layers: # type: ignore
attn = getattr(layer[0], "attn_weights", None) # SelfAttention.attn_weights
transformer = getattr(model, "transformer", None)
if transformer is None:
raise RuntimeError("Model does not have a transformer attribute")
for layer in transformer.layers:
attn = getattr(layer, "attn_weights", None)
if attn is None:
first_child = next(iter(layer.children()), None)
if first_child is not None:
attn = getattr(first_child, "attn_weights", None)
if attn is None:
raise RuntimeError(
"SelfAttention.attn_weights not found. "
"Make sure SelfAttention stores them."
"Make sure SelfAttention stores them on the layer or its first child."
)

# attn: [heads, seq, seq]
Expand Down Expand Up @@ -117,15 +122,18 @@ def _gradcam_single(
"""
feat_dim = -1

jac = jacrev(
lambda bags: model.forward(
bags.unsqueeze(0),
coords=coords.unsqueeze(0),
mask=None,
).squeeze()
)(feats)
jac = cast(
Tensor,
jacrev(
lambda bags: model.forward(
bags.unsqueeze(0),
coords=coords.unsqueeze(0),
mask=None,
).squeeze()
)(feats),
)

cam = (feats * jac).mean(feat_dim).abs() # type: ignore # [tile]
cam = (feats * jac).mean(feat_dim).abs() # [tile]

return cam

Expand All @@ -148,17 +156,21 @@ def _vals_to_im(


def _show_thumb(
slide, thumb_ax: Axes, attention: Tensor, default_slide_mpp: SlideMPP | None
slide: _SlideLike,
thumb_ax: Axes,
attention: Tensor,
default_slide_mpp: SlideMPP | None,
) -> np.ndarray:
mpp = get_slide_mpp_(slide, default_mpp=default_slide_mpp)
dims_um = np.array(slide.dimensions) * mpp
thumb = slide.get_thumbnail(np.round(dims_um * 8 / 256).astype(int))
thumb_size = tuple(np.round(dims_um * 8 / 256).astype(int).tolist())
thumb = slide.get_thumbnail(thumb_size)
thumb_ax.imshow(np.array(thumb)[: attention.shape[0] * 8, : attention.shape[1] * 8])
return np.array(thumb)[: attention.shape[0] * 8, : attention.shape[1] * 8]


def _get_thumb_array(
slide,
slide: _SlideLike,
attention: torch.Tensor,
default_slide_mpp: SlideMPP | None,
) -> np.ndarray:
Expand All @@ -168,12 +180,12 @@ def _get_thumb_array(
"""
mpp = get_slide_mpp_(slide, default_mpp=default_slide_mpp)
dims_um = np.array(slide.dimensions) * mpp
thumb = np.array(slide.get_thumbnail(np.round(dims_um * 8 / 256).astype(int)))
thumb_size = tuple(np.round(dims_um * 8 / 256).astype(int).tolist())
thumb = np.array(slide.get_thumbnail(thumb_size))
thumb_crop = thumb[: attention.shape[0] * 8, : attention.shape[1] * 8]
return thumb_crop


@no_type_check # beartype<=0.19.0 breaks here for some reason
def _show_class_map(
class_ax: Axes,
top_score_indices: Integer[Tensor, "width height"],
Expand Down Expand Up @@ -298,13 +310,8 @@ def heatmaps_(
raise ValueError(
f"Feature file {h5_path} is a slide or patient level feature. Heatmaps are currently supported for tile-level features only."
)
feats = (
torch.tensor(
h5["feats"][:] # pyright: ignore[reportIndexIssue]
)
.float()
.to(device)
)
feats_np = np.asarray(h5["feats"])
feats = torch.from_numpy(feats_np).float().to(device)
coords_info = get_coords(h5)
coords_um = torch.from_numpy(coords_info.coords_um).float()
stride_um = Microns(get_stride(coords_um))
Expand All @@ -322,9 +329,10 @@ def heatmaps_(
model = load_model_from_ckpt(checkpoint_path).eval()

# TODO: Update version when a newer model logic breaks heatmaps.
if Version(model.stamp_version) < Version("2.4.0"):
stamp_version = str(getattr(model, "stamp_version", ""))
if Version(stamp_version) < Version("2.4.0"):
raise ValueError(
f"model has been built with stamp version {model.stamp_version} "
f"model has been built with stamp version {stamp_version} "
f"which is incompatible with the current version."
)

Expand Down Expand Up @@ -356,7 +364,7 @@ def heatmaps_(

with torch.no_grad():
scores = torch.softmax(
model.model.forward(
model.model(
feats.unsqueeze(-2),
coords=coords_um.unsqueeze(-2),
mask=torch.zeros(
Expand Down
Loading
Loading