diff --git a/README.md b/README.md index c0f837b..b0b2887 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,10 @@ See the `pyproject.toml` file for the complete list of dependencies. ### Hardware Requirements - GPU (A100 40GB recommended) for efficient inference and embedding extraction. - Can also use a GPU with a lower amount of VRAM (16GB) by setting the inference batch size to 1-4. +- **Multi-GPU support**: For faster inference on large datasets, use multiple GPUs with the `--num-gpus` parameter. + - Recommended for datasets with >100k cells + - Scales batch processing across available GPUs using Distributed Data Parallel (DDP) + - Best performance with matched GPU types and sufficient inter-GPU bandwidth ## Using the TranscriptFormer CLI @@ -234,6 +238,13 @@ transcriptformer inference \ --data-file test/data/human_val.h5ad \ --emb-type cge \ --batch-size 8 + +# Multi-GPU inference using 4 GPUs (-1 will use all available on the system) +transcriptformer inference \ + --checkpoint-path ./checkpoints/tf_sapiens \ + --data-file test/data/human_val.h5ad \ + --num-gpus 4 \ + --batch-size 32 ``` You can also use the CLI it run inference on the ESM2-CE baseline model discussed in the paper: @@ -281,6 +292,9 @@ transcriptformer download-data --help - `--embedding-layer-index INT`: Index of the transformer layer to extract embeddings from (-1 for last layer, default: -1). Use with `transcriptformer` model type. - `--model-type {transcriptformer,esm2ce}`: Type of model to use (default: `transcriptformer`). Use `esm2ce` to extract raw ESM2-CE gene embeddings. - `--emb-type {cell,cge}`: Type of embeddings to extract (default: `cell`). Use `cell` for mean-pooled cell embeddings or `cge` for contextual gene embeddings. +- `--num-gpus INT`: Number of GPUs to use for inference (default: 1). Use -1 for all available GPUs, or specify a specific number. +- `--oom-dataloader`: Use the OOM-safe map-style DataLoader (uses backed reads and per-item densification; DistributedSampler-friendly). +- `--n-data-workers INT`: Number of DataLoader workers per process (default: 0). Order is preserved with the map-style dataset and DistributedSampler. - `--config-override key.path=value`: Override any configuration value directly. ### Input Data Format and Preprocessing: @@ -301,6 +315,18 @@ Input data files should be in H5AD format (AnnData objects) with the following r - `True`: Use only `adata.raw.X` - `False`: Use only `adata.X` + - **OOM-safe Data Loading**: + - To reduce peak memory usage on large datasets, enable the OOM-safe dataloader: + ```bash + transcriptformer inference \ + --checkpoint-path ./checkpoints/tf_sapiens \ + --data-file ./data/huge.h5ad \ + --oom-dataloader \ + --n-data-workers 4 \ + --num-gpus 8 + ``` + - This uses a map-style dataset with backed reads and per-row densification. It is compatible with `DistributedSampler`, so multiple workers are safe and ordering is preserved. + - **Count Processing**: - Count values are clipped at 30 by default (as was done in training) - If this seems too low, you can either: diff --git a/download_artifacts.py b/download_artifacts.py deleted file mode 100644 index f424263..0000000 --- a/download_artifacts.py +++ /dev/null @@ -1,163 +0,0 @@ -#!/usr/bin/env python3 - -""" -Download and extract TranscriptFormer model artifacts from a public S3 bucket. - -This script provides a convenient way to download and extract TranscriptFormer model weights -from a public S3 bucket. 
It supports downloading individual models or all models at once, -with progress indicators for both download and extraction processes. - -Usage: - python download_artifacts.py [model] [--checkpoint-dir DIR] - - model: The model to download. Options are: - - tf-sapiens: Download the sapiens model - - tf-exemplar: Download the exemplar model - - tf-metazoa: Download the metazoa model - - all: Download all models and embeddings - - all-embeddings: Download only the embedding files - - --checkpoint-dir: Optional directory to store the downloaded checkpoints. - Defaults to './checkpoints' - -Examples --------- - # Download the sapiens model - python download_artifacts.py tf-sapiens - - # Download all models and embeddings - python download_artifacts.py all - - # Download only the embeddings file - python download_artifacts.py all-embeddings - - # Download the exemplar model to a custom directory - python download_artifacts.py tf-exemplar --checkpoint-dir /path/to/models - -The downloaded models will be extracted to: - ./checkpoints/tf_sapiens/ - ./checkpoints/tf_exemplar/ - ./checkpoints/tf_metazoa/ - ./checkpoints/all_embeddings/ -""" - -import argparse -import math -import sys -import tarfile -import tempfile -import urllib.error -import urllib.request -from pathlib import Path - - -def print_progress(current, total, prefix="", suffix="", length=50): - """Print a simple progress bar.""" - filled = int(length * current / total) - bar = "█" * filled + "░" * (length - filled) - percent = math.floor(100 * current / total) - print(f"\r{prefix} |{bar}| {percent}% {suffix}", end="", flush=True) - if current == total: - print() - - -def download_and_extract(model_name: str, checkpoint_dir: str = "./checkpoints"): - """Download and extract a model artifact from S3.""" - s3_path = f"https://czi-transcriptformer.s3.amazonaws.com/weights/{model_name}.tar.gz" - output_dir = Path(checkpoint_dir) / model_name - - # Create checkpoint directory if it doesn't exist - output_dir.parent.mkdir(parents=True, exist_ok=True) - - print(f"Downloading {model_name} from {s3_path}...") - - try: - # Create a temporary file to store the tar.gz - with tempfile.NamedTemporaryFile(suffix=".tar.gz") as tmp_file: - # Download the file using urllib with progress bar - try: - - def report_hook(count, block_size, total_size): - """Callback function to report download progress.""" - if total_size > 0 and (count % 100 == 0 or count * block_size >= total_size): - print_progress( - count * block_size, - total_size, - prefix=f"Downloading {model_name}", - ) - - urllib.request.urlretrieve(s3_path, filename=tmp_file.name, reporthook=report_hook) - print() # New line after download completes - except urllib.error.HTTPError as e: - if e.code == 404: - print(f"Error: The model {model_name} was not found at {s3_path}") - else: - print(f"Error downloading file: HTTP {e.code}") - sys.exit(1) - except urllib.error.URLError as e: - print(f"Error downloading file: {str(e)}") - sys.exit(1) - - # Reset the file pointer to the beginning - tmp_file.seek(0) - - # Extract the tar.gz file - try: - print(f"Extracting {model_name}...") - with tarfile.open(fileobj=tmp_file, mode="r:gz") as tar: - members = tar.getmembers() - total_files = len(members) - for i, member in enumerate(members, 1): - tar.extract(member, path=str(output_dir.parent)) - print_progress( - i, - total_files, - prefix=f"Extracting {model_name}", - ) - print() # New line after extraction completes - except tarfile.ReadError: - print(f"Error: The downloaded file for {model_name} is not a 
valid tar.gz archive") - sys.exit(1) - - print(f"Successfully downloaded and extracted {model_name} to {output_dir}") - - except Exception as e: - print(f"Error: {str(e)}") - sys.exit(1) - - -def main(): - parser = argparse.ArgumentParser(description="Download and extract TranscriptFormer model artifacts") - parser.add_argument( - "model", - choices=["tf-sapiens", "tf-exemplar", "tf-metazoa", "all", "all-embeddings"], - help="Model to download (or 'all' for all models and embeddings, 'all-embeddings' for just embeddings)", - ) - parser.add_argument( - "--checkpoint-dir", - default="./checkpoints", - help="Directory to store the downloaded checkpoints (default: ./checkpoints)", - ) - - args = parser.parse_args() - - models = { - "tf-sapiens": "tf_sapiens", - "tf-exemplar": "tf_exemplar", - "tf-metazoa": "tf_metazoa", - "all-embeddings": "all_embeddings", - } - - if args.model == "all": - # Download all models and embeddings - for model in ["tf_sapiens", "tf_exemplar", "tf_metazoa", "all_embeddings"]: - download_and_extract(model, args.checkpoint_dir) - elif args.model == "all-embeddings": - # Download only embeddings - download_and_extract("all_embeddings", args.checkpoint_dir) - else: - download_and_extract(models[args.model], args.checkpoint_dir) - - -if __name__ == "__main__": - main() diff --git a/inference.py b/inference.py deleted file mode 100644 index 2992fe3..0000000 --- a/inference.py +++ /dev/null @@ -1,63 +0,0 @@ -""" -Script to perform inference with Transcriptformer models. - -Example usage: - python inference.py --config-name=inference_config.yaml \ - model.checkpoint_path=./checkpoints/tf_sapiens \ - model.inference_config.data_files.0=test/data/human_val.h5ad \ - model.inference_config.output_path=./custom_results_dir \ - model.inference_config.output_filename=custom_embeddings.h5ad \ - model.inference_config.batch_size=8 -""" - -import json -import logging -import os - -import hydra -from omegaconf import DictConfig, OmegaConf - -from transcriptformer.model.inference import run_inference - -# Set up logging -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") - - -@hydra.main( - config_path=os.path.join(os.path.dirname(__file__), "conf"), - config_name="inference_config.yaml", - version_base=None, -) -def main(cfg: DictConfig): - logging.debug(OmegaConf.to_yaml(cfg)) - - config_path = os.path.join(cfg.model.checkpoint_path, "config.json") - with open(config_path) as f: - config_dict = json.load(f) - mlflow_cfg = OmegaConf.create(config_dict) - - # Merge the MLflow config with the main config - cfg = OmegaConf.merge(mlflow_cfg, cfg) - - # Set the checkpoint paths based on the unified checkpoint_path - cfg.model.inference_config.load_checkpoint = os.path.join(cfg.model.checkpoint_path, "model_weights.pt") - cfg.model.data_config.aux_vocab_path = os.path.join(cfg.model.checkpoint_path, "vocabs") - cfg.model.data_config.esm2_mappings_path = os.path.join(cfg.model.checkpoint_path, "vocabs") - - adata_output = run_inference(cfg, data_files=cfg.model.inference_config.data_files) - - # Save the output adata - output_path = cfg.model.inference_config.output_path - if not os.path.exists(output_path): - os.makedirs(output_path) - - # Get output filename from config or use default - output_filename = getattr(cfg.model.inference_config, "output_filename", "embeddings.h5ad") - save_file = os.path.join(output_path, output_filename) - - adata_output.write_h5ad(save_file) - logging.info(f"Saved embeddings to {save_file}") - - -if __name__ == 
"__main__": - main() diff --git a/src/transcriptformer/cli/__init__.py b/src/transcriptformer/cli/__init__.py index d1b66f1..f3fefdc 100644 --- a/src/transcriptformer/cli/__init__.py +++ b/src/transcriptformer/cli/__init__.py @@ -24,6 +24,7 @@ --gene-col-name Column in AnnData.var with gene identifiers --precision Numerical precision (16-mixed or 32) --pretrained-embedding Path to embedding file for out-of-distribution species + --num-gpus Number of GPUs to use (1=single, -1=all available, >1=specific number) Advanced Configuration: Use --config-override for any configuration options not exposed as arguments above. @@ -52,10 +53,17 @@ """ import argparse +import json import logging +import os import sys import warnings +import torch +from omegaconf import OmegaConf + +from transcriptformer.model.inference import run_inference + # Suppress annoying warnings warnings.filterwarnings("ignore", category=FutureWarning, module="anndata") warnings.filterwarnings("ignore", category=FutureWarning, message=".*read_.*from.*anndata.*deprecated.*") @@ -164,6 +172,24 @@ def setup_inference_parser(subparsers): default=False, help="Remove duplicate genes if found instead of raising an error (default: False)", ) + parser.add_argument( + "--num-gpus", + type=int, + default=1, + help="Number of GPUs to use for inference (1 = single GPU, -1 = all available GPUs, >1 = specific number) (default: 1)", + ) + parser.add_argument( + "--oom-dataloader", + action="store_true", + default=False, + help="Use map-style out-of-memory DataLoader (DistributedSampler-friendly)", + ) + parser.add_argument( + "--n-data-workers", + type=int, + default=0, + help="Number of DataLoader workers per process (map-style dataset is order-safe).", + ) # Allow arbitrary config overrides parser.add_argument( @@ -242,45 +268,103 @@ def setup_download_data_parser(subparsers): def run_inference_cli(args): """Run inference using command line arguments.""" - # Import the inference module directly - from transcriptformer.cli.inference import main as inference_main - - # Create a hydra-compatible config dictionary for direct use with inference.py - cmd = [ - "--config-name=inference_config.yaml", - f"model.checkpoint_path={args.checkpoint_path}", - f"model.inference_config.data_files.0={args.data_file}", - f"model.inference_config.batch_size={args.batch_size}", - f"model.data_config.gene_col_name={args.gene_col_name}", - f"model.inference_config.output_path={args.output_path}", - f"model.inference_config.output_filename={args.output_filename}", - f"model.inference_config.precision={args.precision}", - f"model.model_type={args.model_type}", - f"model.inference_config.emb_type={args.emb_type}", - f"model.data_config.remove_duplicate_genes={args.remove_duplicate_genes}", - ] + # Only print logo if not in distributed mode (avoids duplicates) + is_distributed = args.num_gpus != 1 + if not is_distributed: + print(TF_LOGO) + + # Load the config + config_path = os.path.join(os.path.dirname(__file__), "conf", "inference_config.yaml") + cfg = OmegaConf.load(config_path) + + # Load model config from checkpoint + model_config_path = os.path.join(args.checkpoint_path, "config.json") + with open(model_config_path) as f: + config_dict = json.load(f) + mlflow_cfg = OmegaConf.create(config_dict) + + # Merge the MLflow config with the main config + cfg = OmegaConf.merge(mlflow_cfg, cfg) + + # Override config values with CLI arguments + cfg.model.checkpoint_path = args.checkpoint_path + cfg.model.inference_config.data_files = [args.data_file] + 
cfg.model.inference_config.batch_size = args.batch_size + cfg.model.data_config.gene_col_name = args.gene_col_name + cfg.model.inference_config.output_path = args.output_path + cfg.model.inference_config.output_filename = args.output_filename + cfg.model.inference_config.precision = args.precision + cfg.model.model_type = args.model_type + cfg.model.inference_config.emb_type = args.emb_type + cfg.model.data_config.remove_duplicate_genes = args.remove_duplicate_genes + cfg.model.data_config.use_raw = args.use_raw + cfg.model.inference_config.num_gpus = args.num_gpus + cfg.model.inference_config.use_oom_dataloader = args.oom_dataloader + cfg.model.data_config.clip_counts = args.clip_counts + cfg.model.data_config.filter_to_vocabs = args.filter_to_vocabs + cfg.model.data_config.n_data_workers = args.n_data_workers # Add pretrained embedding if specified if args.pretrained_embedding: - cmd.append(f"model.inference_config.pretrained_embedding={args.pretrained_embedding}") + cfg.model.inference_config.pretrained_embedding = args.pretrained_embedding - # Add any arbitrary config overrides + # Apply any arbitrary config overrides for override in args.config_override: - cmd.append(override) - - # Print logo - print(TF_LOGO) - - # Override sys.argv for Hydra to pick up - saved_argv = sys.argv - sys.argv = [sys.argv[0]] + cmd + if "=" not in override: + continue + key, value = override.split("=", 1) + # Convert value to appropriate type + try: + # Try to parse as a number or boolean + if value.lower() in ["true", "false"]: + value = value.lower() == "true" + elif value.lower() in ["none", "null"]: + value = None + elif value.isdigit(): + value = int(value) + elif "." in value and all(part.isdigit() for part in value.split(".")): + value = float(value) + except Exception: + # Keep as string if conversion fails + pass + + # Use OmegaConf.update to set nested keys like "a.b.c" or list indices like "a.list.0" + OmegaConf.update(cfg, key, value) + + # Set the checkpoint paths based on the unified checkpoint_path + cfg.model.inference_config.load_checkpoint = os.path.join(cfg.model.checkpoint_path, "model_weights.pt") + cfg.model.data_config.aux_vocab_path = os.path.join(cfg.model.checkpoint_path, "vocabs") + cfg.model.data_config.esm2_mappings_path = os.path.join(cfg.model.checkpoint_path, "vocabs") + + # Run inference directly + adata_output = run_inference(cfg, data_files=cfg.model.inference_config.data_files) + + # Save the output adata + output_path = cfg.model.inference_config.output_path + if not os.path.exists(output_path): + os.makedirs(output_path) + + # Get output filename from config or use default + output_filename = getattr(cfg.model.inference_config, "output_filename", "embeddings.h5ad") + if not output_filename.endswith(".h5ad"): + output_filename = f"{output_filename}.h5ad" + save_file = os.path.join(output_path, output_filename) + + # Check if we're in a distributed environment + if is_distributed: + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + else: + rank = 0 - try: - # Call the main function directly - inference_main() - finally: - # Restore original sys.argv - sys.argv = saved_argv + # Split the filename and add rank before extension + rank_file = save_file.replace(".h5ad", f"_{rank}.h5ad") + adata_output.write_h5ad(rank_file) + print(f"Rank {rank} completed processing, saved partial results to {rank_file}") + else: + # Single GPU mode - save normally + adata_output.write_h5ad(save_file) + print(f"Inference completed! 
Saved embeddings to {save_file}") def run_download_cli(args): @@ -319,9 +403,6 @@ def run_download_data_cli(args): # Parse species list species_list = [s.strip() for s in args.species.split(",")] if args.species else [] - # Print logo - print(TF_LOGO) - # Run the download try: successful_downloads = download_data_main( diff --git a/src/transcriptformer/cli/conf/inference_config.yaml b/src/transcriptformer/cli/conf/inference_config.yaml index 4523c5b..a1aafe6 100644 --- a/src/transcriptformer/cli/conf/inference_config.yaml +++ b/src/transcriptformer/cli/conf/inference_config.yaml @@ -23,6 +23,8 @@ model: pretrained_embedding: null # Path to pretrained embeddings for out-of-distribution species precision: 16-mixed # Numerical precision for inference (16-mixed, 32, etc.) emb_type: cell # Type of embeddings to extract: "cell" for mean-pooled cell embeddings or "cge" for contextual gene embeddings + num_gpus: 1 # Number of GPUs to use for inference (1 = single GPU, -1 = all available GPUs, >1 = specific number) + use_oom_dataloader: false # Use OOM-safe map-style DataLoader with DistributedSampler data_config: _target_: transcriptformer.data.dataclasses.DataConfig @@ -36,3 +38,4 @@ model: min_expressed_genes: 0 # Minimum number of expressed genes required per cell use_raw: "auto" # Whether to use .raw.X (True), .X (False), or auto-detect (auto/null) remove_duplicate_genes: false # Whether to remove duplicate genes instead of raising an error + n_data_workers: 0 # Leave as 0 to ensure deterministic behavior diff --git a/src/transcriptformer/data/dataclasses.py b/src/transcriptformer/data/dataclasses.py index 6cf8887..f98c8a3 100644 --- a/src/transcriptformer/data/dataclasses.py +++ b/src/transcriptformer/data/dataclasses.py @@ -177,7 +177,7 @@ class InferenceConfig: load_checkpoint (str): Path to checkpoint to load output_path (str): Path to save outputs output_filename (str): Filename for the output embeddings (default: embeddings.h5ad) - num_gpus_per_node (int): GPUs per node (default: 1) + num_gpus (int): Number of GPUs to use for inference (1 = single GPU, -1 = all available GPUs, >1 = specific number) (default: 1) special_tokens (list): Special tokens to use emb_type (str): Type of embeddings to extract - "cell" for mean-pooled cell embeddings or "cge" for contextual gene embeddings (default: "cell") """ @@ -189,8 +189,9 @@ class InferenceConfig: load_checkpoint: str | None output_path: str | None output_filename: str | None = "embeddings.h5ad" - num_gpus_per_node: int = 1 + num_gpus: int = 1 num_nodes: int = 1 + use_oom_dataloader: bool = False precision: str = "16-mixed" special_tokens: list = field(default_factory=list) pretrained_embedding: list = field(default_factory=list) diff --git a/src/transcriptformer/data/dataloader.py b/src/transcriptformer/data/dataloader.py index 58371f5..a436be4 100644 --- a/src/transcriptformer/data/dataloader.py +++ b/src/transcriptformer/data/dataloader.py @@ -18,10 +18,18 @@ ) -def load_data(file_path): - """Load H5AD file.""" +def load_data(file_path, *, backed: bool = False): + """Load H5AD file. 
+ + Args: + file_path: Path to .h5ad file + backed: If True, use memory-mapped backed='r' mode (for streaming); otherwise fully load into memory + """ try: - adata = sc.read_h5ad(file_path) + if backed: + adata = anndata.read_h5ad(file_path, backed="r") + else: + adata = sc.read_h5ad(file_path) return adata, True except Exception as e: logging.error(f"Failed to read file {file_path}: {e}") @@ -39,9 +47,12 @@ def apply_filters( min_expressed_genes, ): """Apply filters to the data.""" + n_cells = X.shape[0] + if filter_to_vocab: filter_idx = [i for i, name in enumerate(gene_names) if name in vocab] X = X[:, filter_idx] + logging.info(f"Filtered {len(gene_names)} genes to {len(filter_idx)} genes in vocab") gene_names = gene_names[filter_idx] if X.shape[1] == 0: logging.warning(f"Warning: Filtered all genes from {file_path}") @@ -64,6 +75,8 @@ def apply_filters( X = X[filter_idx] obs = obs.iloc[filter_idx] + logging.info(f"Filtered {n_cells} cells to {X.shape[0]} cells") + return X, obs, gene_names @@ -155,6 +168,121 @@ def process_batch( return result +def get_counts_layer(adata: anndata.AnnData, use_raw: bool | None): + if use_raw is True: + if adata.raw is not None: + logging.info("Using 'raw.X' layer from AnnData object") + return adata.raw.X + else: + raise ValueError("raw.X not found in AnnData object") + elif use_raw is False: + if adata.X is not None: + logging.info("Using 'X' layer from AnnData object") + return adata.X + else: + raise ValueError("X not found in AnnData object") + else: # None - try raw first, then fallback to X + if adata.raw is not None: + logging.info("Using 'raw.X' layer from AnnData object") + return adata.raw.X + elif adata.X is not None: + logging.info("Using 'X' layer from AnnData object") + return adata.X + else: + raise ValueError("No valid data layer found in AnnData object") + + +def to_dense(X: np.ndarray | csr_matrix | csc_matrix) -> np.ndarray: + if isinstance(X, csr_matrix | csc_matrix): + return X.toarray() + elif isinstance(X, np.ndarray): + return X + else: + raise TypeError(f"Expected numpy array or sparse matrix, got {type(X)}") + + +def is_raw_counts(X: np.ndarray | csr_matrix | csc_matrix) -> bool: + """Check if a matrix looks like raw counts (integer-valued where non-zero). + + Handles both dense numpy arrays and sparse CSR/CSC matrices without densifying the full matrix. 
+ """ + # Sparse path: operate on non-zero data directly + if isinstance(X, csr_matrix | csc_matrix): + data = X.data + if data.size == 0: + return False + # Sample if very large + if data.size > 1000: + idx = np.random.choice(data.size, 1000, replace=False) + data = data[idx] + return np.all(np.abs(data - np.round(data)) < 1e-6) + + # Dense path + non_zero_mask = X > 0 + if not np.any(non_zero_mask): + return False + non_zero_values = X[non_zero_mask] + if non_zero_values.size > 1000: + idx = np.random.choice(non_zero_values.size, 1000, replace=False) + non_zero_values = non_zero_values.flatten()[idx] + return np.all(np.abs(non_zero_values - np.round(non_zero_values)) < 1e-6) + + +def load_gene_features( + adata: anndata.AnnData, gene_col_name: str, remove_duplicate_genes: bool, use_raw: bool | None = None +): + try: + # Select the appropriate var depending on which matrix will be used + using_raw = bool(use_raw is True or (use_raw is None and getattr(adata, "raw", None) is not None)) + has_raw = getattr(adata, "raw", None) is not None + using_raw = bool(use_raw is True or (use_raw is None and has_raw)) + var_df = adata.raw.var if using_raw and has_raw else adata.var + + # Prefer requested column; otherwise use index which aligns with matrix columns for that layer + if gene_col_name in var_df.columns: + gene_names = np.array(list(var_df[gene_col_name].values)) + else: + raise ValueError( + f"Gene column '{gene_col_name}' not found in var DataFrame columns: {list(var_df.columns)}" + ) + + # Remove version numbers from gene names + gene_names = np.array([id.split(".")[0] for id in gene_names]) + + gene_counts = Counter(gene_names) + duplicates = {gene for gene, count in gene_counts.items() if count > 1} + if len(duplicates) > 0: + if remove_duplicate_genes: + seen = set() + unique_indices = [] + for i, gene in enumerate(gene_names): + if gene not in seen: + seen.add(gene) + unique_indices.append(i) + adata = adata[:, unique_indices].copy() + gene_names = gene_names[unique_indices] + logging.warning( + f"Removed {len(duplicates)} duplicate genes after removing version numbers. Kept first occurrence." + ) + else: + raise ValueError( + "Found duplicate genes after removing version numbers. " + "Remove duplicates or pass --remove-duplicate-genes." + ) + + return gene_names, True, adata + except KeyError: + return None, False, adata + + +def validate_gene_dimension(X: np.ndarray, gene_names: np.ndarray, gene_col_name: str): + if X.shape[1] != len(gene_names): + raise ValueError( + f"Mismatch between expression matrix columns ({X.shape[1]}) and gene names length ({len(gene_names)}). " + f"Ensure 'adata.var[{gene_col_name}]' exists and aligns with the matrix columns." 
+ ) + + class AnnDataset(Dataset): def __init__( self, @@ -210,91 +338,6 @@ def __init__( logging.info("Loading and processing all data") self.data = self.load_and_process_all_data() - def _get_counts_layer(self, adata: anndata.AnnData) -> str: - if self.use_raw is True: - if adata.raw is not None: - logging.info("Using 'raw.X' layer from AnnData object") - return adata.raw.X - else: - raise ValueError("raw.X not found in AnnData object") - elif self.use_raw is False: - if adata.X is not None: - logging.info("Using 'X' layer from AnnData object") - return adata.X - else: - raise ValueError("X not found in AnnData object") - else: # None - try raw first, then fallback to X - if adata.raw is not None: - logging.info("Using 'raw.X' layer from AnnData object") - return adata.raw.X - elif adata.X is not None: - logging.info("Using 'X' layer from AnnData object") - return adata.X - else: - raise ValueError("No valid data layer found in AnnData object") - - def _to_dense(self, X: np.ndarray | csr_matrix | csc_matrix) -> np.ndarray: - if isinstance(X, csr_matrix | csc_matrix): - return X.toarray() - elif isinstance(X, np.ndarray): - return X - else: - raise TypeError(f"Expected numpy array or sparse matrix, got {type(X)}") - - def _is_raw_counts(self, X: np.ndarray) -> bool: - # Get non-zero values - non_zero_mask = X > 0 - if not np.any(non_zero_mask): - return False - - # Sample up to 1000 non-zero values - non_zero_values = X[non_zero_mask] - if len(non_zero_values) > 1000: - non_zero_values = np.random.choice(non_zero_values, 1000, replace=False) - - # Check if values are roughly integers (within float32 precision) - # float32 has ~7 decimal digits of precision - is_integer = np.all(np.abs(non_zero_values - np.round(non_zero_values)) < 1e-6) - - return is_integer - - def _load_gene_features(self, adata): - """Load gene features and remove version numbers from ensembl ids""" - try: - gene_names = np.array(list(adata.var[self.gene_col_name].values)) - gene_names = np.array([id.split(".")[0] for id in gene_names]) - - # Check for duplicates after removing version numbers - gene_counts = Counter(gene_names) - duplicates = {gene for gene, count in gene_counts.items() if count > 1} - if len(duplicates) > 0: - if self.remove_duplicate_genes: - # Remove duplicates by keeping only the first occurrence - seen = set() - unique_indices = [] - for i, gene in enumerate(gene_names): - if gene not in seen: - seen.add(gene) - unique_indices.append(i) - - # Filter adata to keep only unique genes - adata = adata[:, unique_indices].copy() - gene_names = gene_names[unique_indices] - - logging.warning( - f"Removed {len(duplicates)} duplicate genes after removing version numbers. " - f"Kept first occurrence of each gene. " - ) - else: - raise ValueError( - f"Found {len(duplicates)} duplicate genes after removing version numbers. " - f"Please remove duplicate genes from your data or use --remove-duplicate-genes flag. 
" - ) - - return gene_names, True, adata - except KeyError: - return None, False, adata - def _get_batch_from_file(self, file: str | anndata.AnnData) -> BatchData | None: if isinstance(file, str): file_path = file @@ -313,22 +356,30 @@ def _get_batch_from_file(self, file: str | anndata.AnnData) -> BatchData | None: logging.error(f"Failed to load data from {file_path}") return None - gene_names, success, adata = self._load_gene_features(adata) + gene_names, success, adata = load_gene_features( + adata, self.gene_col_name, self.remove_duplicate_genes, use_raw=self.use_raw + ) if not success: logging.error(f"Failed to load gene features from {file_path}") return None - X = self._get_counts_layer(adata) - X = self._to_dense(X) + X = get_counts_layer(adata, self.use_raw) + # AnnDataset loads and processes all data in-memory; convert to dense for batching + X = to_dense(X) obs = adata.obs + # Validate that gene dimension matches number of gene names + validate_gene_dimension(X, gene_names, self.gene_col_name) + # Check if the data appears to be raw counts - if not self._is_raw_counts(X): + logging.info("Checking if data is raw counts") + if not is_raw_counts(X): logging.warning( "Data does not appear to be raw counts. TranscriptFormer expects unnormalized count data. " "If your data is normalized, consider using the original count matrix instead." ) + logging.info("Applying filters") vocab = self.gene_vocab X, obs, gene_names = apply_filters( X, @@ -340,10 +391,12 @@ def _get_batch_from_file(self, file: str | anndata.AnnData) -> BatchData | None: self.filter_outliers, self.min_expressed_genes, ) + if X is None: logging.warning(f"Data was filtered out completely for {file_path}") return None + logging.info("Processing data") batch = process_batch( X, obs, @@ -447,3 +500,153 @@ def collate_fn(batch: BatchData | list[BatchData]) -> BatchData: ), ) return collated_batch + + +class AnnDatasetOOM(Dataset): + """Map-style OOM-safe dataset using backed reads and per-item processing. + + Designed to provide OOM-safe iteration while leveraging PyTorch's + DistributedSampler for automatic sharding across DDP ranks. 
+ """ + + collate_fn = staticmethod(AnnDataset.collate_fn) + + def __init__( + self, + files_list: list[str], + gene_vocab: dict[str, str], + data_dir: str | None = None, + aux_vocab: dict[str, dict[str, str]] | None = None, + max_len: int = 2048, + normalize_to_scale: float | None = None, + sort_genes: bool = False, + randomize_order: bool = False, + pad_zeros: bool = True, + pad_token: str = "[PAD]", + gene_col_name: str = "ensembl_id", + filter_to_vocab: bool = True, + clip_counts: float = 1e10, + obs_keys: list[str] | None = None, + use_raw: bool | None = None, + remove_duplicate_genes: bool = False, + ): + super().__init__() + self.files_list = files_list + self.data_dir = data_dir + self.gene_vocab = gene_vocab + self.aux_vocab = aux_vocab + self.max_len = max_len + self.normalize_to_scale = normalize_to_scale + self.sort_genes = sort_genes + self.randomize_order = randomize_order + self.pad_zeros = pad_zeros + self.pad_token = pad_token + self.gene_col_name = gene_col_name + self.filter_to_vocab = filter_to_vocab + self.clip_counts = clip_counts + self.obs_keys = obs_keys + self.use_raw = use_raw + self.remove_duplicate_genes = remove_duplicate_genes + + self.gene_tokenizer = BatchGeneTokenizer(gene_vocab) + if aux_vocab is not None: + self.aux_tokenizer = BatchObsTokenizer(aux_vocab) + + # Open backed handles and build cumulative row offsets + self._handles: list[anndata.AnnData] = [] + self._gene_names_per_file: list[np.ndarray] = [] + self._filter_idx_per_file: list[list[int] | None] = [] + self._X_per_file: list = [] + self._n_rows: list[int] = [] + for file in self.files_list: + file_path = file if self.data_dir is None else os.path.join(self.data_dir, file) + adata = anndata.read_h5ad(file_path, backed="r") + gene_names, success, adata = load_gene_features( + adata, self.gene_col_name, self.remove_duplicate_genes, use_raw=self.use_raw + ) + if not success: + raise ValueError(f"Failed to load gene features from {file_path}") + # Optional vocab filtering at token level + filter_idx = None + if self.filter_to_vocab: + original_gene_count = len(gene_names) + filter_idx = [i for i, name in enumerate(gene_names) if name in self.gene_vocab] + gene_names = gene_names[filter_idx] + logging.info( + f"Filtered {original_gene_count} genes to {len(gene_names)} genes in vocab for file {file_path}" + ) + if len(gene_names) == 0: + raise ValueError(f"No genes remaining after filtering for file {file_path}") + + self._handles.append(adata) + self._gene_names_per_file.append(gene_names) + self._filter_idx_per_file.append(filter_idx) + X_layer = get_counts_layer(adata, self.use_raw) + self._X_per_file.append(X_layer) + self._n_rows.append(int(adata.n_obs)) + + self._offsets = np.cumsum([0] + self._n_rows) + + def __len__(self) -> int: + return int(self._offsets[-1]) + + def _loc(self, idx: int) -> tuple[int, int]: + file_id = int(np.searchsorted(self._offsets, idx, side="right") - 1) + row = int(idx - self._offsets[file_id]) + return file_id, row + + def __getitem__(self, idx: int) -> BatchData: + file_id, row = self._loc(idx) + adata = self._handles[file_id] + gene_names = self._gene_names_per_file[file_id] + filter_idx = self._filter_idx_per_file[file_id] + + X = self._X_per_file[file_id] + # Some backed sparse implementations use __getitem__ returning 2D; ensure 1D + x_row = X[row] + # Only convert to dense if the row is actually sparse + if isinstance(x_row, csr_matrix | csc_matrix): + x_row = x_row.toarray().ravel() + else: + x_row = np.asarray(x_row).ravel() + if filter_idx is not None: + 
x_row = x_row[filter_idx] + + obs_row = adata.obs.iloc[row : row + 1] + + # Build a 1-row batch and reuse existing processing pipeline + x_batch = np.expand_dims(x_row, axis=0) + batch = process_batch( + x_batch, + obs_row, + gene_names, + self.gene_tokenizer, + getattr(self, "aux_tokenizer", None), + self.sort_genes, + self.randomize_order, + self.max_len, + self.pad_zeros, + self.pad_token, + self.gene_vocab, + self.normalize_to_scale, + self.clip_counts, + self.aux_vocab, + ) + + # Convert to BatchData for collate_fn compatibility + obs_dict = None + if self.obs_keys is not None: + obs_dict = {} + cols = list(obs_row.columns) if "all" in self.obs_keys else list(self.obs_keys or []) + for col in cols: + obs_dict[col] = np.array(obs_row[col].tolist())[:, None] + + return BatchData( + gene_counts=batch["gene_counts"][0], + gene_token_indices=batch["gene_token_indices"][0], + file_path=None, + aux_token_indices=( + batch.get("aux_token_indices")[0] if batch.get("aux_token_indices") is not None else None + ), + obs=obs_dict, + ) diff --git a/src/transcriptformer/model/inference.py b/src/transcriptformer/model/inference.py index 98605e3..edb89e3 100644 --- a/src/transcriptformer/model/inference.py +++ b/src/transcriptformer/model/inference.py @@ -9,7 +9,7 @@ from pytorch_lightning.loggers import CSVLogger from torch.utils.data import DataLoader -from transcriptformer.data.dataloader import AnnDataset +from transcriptformer.data.dataloader import AnnDataset, AnnDatasetOOM from transcriptformer.model.embedding_surgery import change_embedding_layer from transcriptformer.tokenizer.vocab import load_vocabs_and_embeddings from transcriptformer.utils.utils import stack_dict @@ -37,6 +37,12 @@ def run_inference(cfg, data_files: list[str] | list[anndata.AnnData]): warnings.filterwarnings( "ignore", message="The 'predict_dataloader' does not have many workers which may be a bottleneck" ) + warnings.filterwarnings( + "ignore", + message="Your `IterableDataset` has `__len__` defined. 
In combination with multi-process data loading", + category=UserWarning, + ) + warnings.filterwarnings("ignore", message="Transforming to str index.", category=UserWarning) # Load vocabs and embeddings (gene_vocab, aux_vocab), emb_matrix = load_vocabs_and_embeddings(cfg) @@ -112,8 +118,28 @@ def run_inference(cfg, data_files: list[str] | list[anndata.AnnData]): "clip_counts": cfg.model.data_config.clip_counts, "obs_keys": cfg.model.inference_config.obs_keys, "remove_duplicate_genes": cfg.model.data_config.remove_duplicate_genes, + "use_raw": cfg.model.data_config.use_raw, } - dataset = AnnDataset(data_files, **data_kwargs) + if getattr(cfg.model.inference_config, "use_oom_dataloader", False): + # Use OOM-safe map-style dataset + dataset = AnnDatasetOOM( + data_files, + gene_vocab, + aux_vocab=aux_vocab, + max_len=cfg.model.model_config.seq_len, + normalize_to_scale=cfg.model.data_config.normalize_to_scale, + sort_genes=cfg.model.data_config.sort_genes, + randomize_order=cfg.model.data_config.randomize_genes, + pad_zeros=cfg.model.data_config.pad_zeros, + gene_col_name=cfg.model.data_config.gene_col_name, + filter_to_vocab=cfg.model.data_config.filter_to_vocabs, + clip_counts=cfg.model.data_config.clip_counts, + obs_keys=cfg.model.inference_config.obs_keys, + use_raw=cfg.model.data_config.use_raw, + remove_duplicate_genes=cfg.model.data_config.remove_duplicate_genes, + ) + else: + dataset = AnnDataset(data_files, **data_kwargs) # Create dataloader dataloader = DataLoader( @@ -125,14 +151,42 @@ def run_inference(cfg, data_files: list[str] | list[anndata.AnnData]): collate_fn=dataset.collate_fn, ) + # Determine number of GPUs to use + num_gpus = getattr(cfg.model.inference_config, "num_gpus", 1) + + # Handle special cases for num_gpus + if num_gpus == -1: + # Use all available GPUs + devices = torch.cuda.device_count() if torch.cuda.is_available() else 1 + accelerator = "gpu" if torch.cuda.is_available() else "cpu" + elif num_gpus > 1: + # Use specified number of GPUs + available_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 + if available_gpus < num_gpus: + logging.warning( + f"Requested {num_gpus} GPUs but only {available_gpus} available. Using {available_gpus} GPUs." 
+ ) + devices = available_gpus if available_gpus > 0 else 1 + accelerator = "gpu" if available_gpus > 0 else "cpu" + else: + devices = num_gpus + accelerator = "gpu" + else: + # Use single GPU or CPU + devices = 1 + accelerator = "gpu" if torch.cuda.is_available() else "cpu" + + logging.info(f"Using {devices} device(s) with accelerator: {accelerator}") + # Create Trainer trainer = pl.Trainer( - accelerator="gpu", - devices=1, # Multiple GPUs/nodes not supported for inference + accelerator=accelerator, + devices=devices, num_nodes=1, precision=cfg.model.inference_config.precision, limit_predict_batches=None, logger=CSVLogger("logs", name="inference"), + strategy="ddp" if devices > 1 else "auto", # Use DDP for multi-GPU ) # Run prediction diff --git a/test/test_cli.py b/test/test_cli.py index b496d74..428f73b 100644 --- a/test/test_cli.py +++ b/test/test_cli.py @@ -8,7 +8,6 @@ from transcriptformer.cli import ( main, run_download_cli, - run_inference_cli, setup_download_parser, setup_inference_parser, ) @@ -63,52 +62,6 @@ def test_inference_command(self, mock_run_inference, monkeypatch): main() mock_run_inference.assert_called_once() - @mock.patch("transcriptformer.cli.inference.main") - def test_run_inference_cli(self, mock_inference_main, monkeypatch): - """Test run_inference_cli function properly calls inference.main.""" - args = mock.MagicMock() - args.checkpoint_path = "/path/to/checkpoint" - args.data_file = "/path/to/data.h5ad" - args.output_path = "./inference_results" - args.output_filename = "embeddings.h5ad" - args.batch_size = 8 - args.gene_col_name = "ensembl_id" - args.precision = "16-mixed" - args.pretrained_embedding = None - args.config_override = [] - args.model_type = "transcriptformer" - args.emb_type = "cell" - - # Test that the function properly sets up Hydra config - original_argv = sys.argv.copy() - run_inference_cli(args) - mock_inference_main.assert_called_once() - # Check sys.argv was restored - assert sys.argv == original_argv - - @mock.patch("transcriptformer.cli.inference.main") - def test_run_inference_cli_with_cge(self, mock_inference_main, monkeypatch): - """Test run_inference_cli function with CGE embedding type.""" - args = mock.MagicMock() - args.checkpoint_path = "/path/to/checkpoint" - args.data_file = "/path/to/data.h5ad" - args.output_path = "./inference_results" - args.output_filename = "embeddings.h5ad" - args.batch_size = 8 - args.gene_col_name = "ensembl_id" - args.precision = "16-mixed" - args.pretrained_embedding = None - args.config_override = [] - args.model_type = "transcriptformer" - args.emb_type = "cge" - - # Test that the function properly sets up Hydra config - original_argv = sys.argv.copy() - run_inference_cli(args) - mock_inference_main.assert_called_once() - # Check sys.argv was restored - assert sys.argv == original_argv - class TestDownloadCommand: """Tests for the download command.""" diff --git a/test/test_compare_emb.py b/test/test_compare_emb.py index 2c9178a..5410d07 100644 --- a/test/test_compare_emb.py +++ b/test/test_compare_emb.py @@ -3,7 +3,7 @@ import anndata as ad import numpy as np -from scipy.stats import pearsonr +from scipy.stats import pearsonr, ttest_rel def compare_embeddings(file1, file2, tolerance=1e-5): @@ -31,18 +31,18 @@ def compare_embeddings(file1, file2, tolerance=1e-5): adata2 = ad.read_h5ad(file2) # Check if embeddings exist - if "emb" not in adata1.obsm or "emb" not in adata2.obsm: + if "embeddings" not in adata1.obsm or "embeddings" not in adata2.obsm: missing = [] - if "emb" not in adata1.obsm: - 
missing.append(f"'emb' not found in {file1}") - if "emb" not in adata2.obsm: - missing.append(f"'emb' not found in {file2}") + if "embeddings" not in adata1.obsm: + missing.append(f"'embeddings' not found in {file1}") + if "embeddings" not in adata2.obsm: + missing.append(f"'embeddings' not found in {file2}") print(f"Error: {', '.join(missing)}") return False # Get embeddings - emb1 = adata1.obsm["emb"] - emb2 = adata2.obsm["emb"] + emb1 = adata1.obsm["embeddings"] + emb2 = adata2.obsm["embeddings"] # Check shapes if emb1.shape != emb2.shape: @@ -64,10 +64,20 @@ def compare_embeddings(file1, file2, tolerance=1e-5): emb2_flat = emb2.flatten() corr, _ = pearsonr(emb1_flat, emb2_flat) + # Calculate z-score and p-value for difference test + # Using paired t-test to test null hypothesis that embeddings are the same + diff = emb1_flat - emb2_flat + t_stat, p_value = ttest_rel(emb1_flat, emb2_flat) + + # Calculate z-score manually from the differences + z_score = np.mean(diff) / (np.std(diff) / np.sqrt(len(diff))) + print("Embeddings differ:") print(f" Max absolute difference: {max_diff:.6e}") print(f" Mean absolute difference: {mean_diff:.6e}") print(f" Pearson correlation: {corr:.6f}") + print(f" Z-score (difference): {z_score:.6f}") + print(f" P-value (paired t-test): {p_value:.6e}") return False diff --git a/test/test_compare_umap.py b/test/test_compare_umap.py new file mode 100644 index 0000000..949c466 --- /dev/null +++ b/test/test_compare_umap.py @@ -0,0 +1,240 @@ +import argparse +import os +import warnings + +import anndata as ad +import numpy as np +from scipy.spatial import procrustes +from scipy.stats import pearsonr + + +def _require_umap(): + try: + import umap.umap_ as umap # type: ignore + except Exception as exc: # pragma: no cover - informative error path + raise RuntimeError( + "This script requires the 'umap-learn' package. Install with 'pip install umap-learn'." + ) from exc + return umap + + +def compute_umap( + embeddings: np.ndarray, + n_neighbors: int = 15, + min_dist: float = 0.1, + metric: str = "euclidean", + random_state: int | None = 42, +) -> np.ndarray: + """ + Compute a 2D UMAP from input embeddings. + + Parameters + ---------- + embeddings : np.ndarray + 2D array of shape (n_samples, n_features) + n_neighbors : int + UMAP n_neighbors parameter + min_dist : float + UMAP min_dist parameter + metric : str + Distance metric for UMAP + random_state : int | None + Random seed for reproducibility + + Returns + ------- + np.ndarray + 2D array of shape (n_samples, 2) with UMAP coordinates + """ + umap = _require_umap() + + # Suppress noisy warnings sometimes emitted by numba/umap during fit + warnings.filterwarnings("ignore", message=".*cannot cache compiled function.*") + reducer = umap.UMAP( + n_neighbors=n_neighbors, + min_dist=min_dist, + metric=metric, + n_components=2, + random_state=random_state, + ) + return reducer.fit_transform(embeddings) + + +def compare_umaps( + file1: str, + file2: str, + n_neighbors: int = 15, + min_dist: float = 0.1, + metric: str = "euclidean", + random_state: int | None = 42, + procrustes_tolerance: float = 1e-2, + save_plot: str | None = None, + obs_key: str | None = "cell_type", +) -> bool: + """ + Load embeddings from two AnnData files, compute UMAPs independently, and compare via Procrustes. + + Returns True if the Procrustes disparity <= procrustes_tolerance. 
+ """ + # Load AnnData files + adata1 = ad.read_h5ad(file1) + adata2 = ad.read_h5ad(file2) + + if "embeddings" not in adata1.obsm or "embeddings" not in adata2.obsm: + missing = [] + if "embeddings" not in adata1.obsm: + missing.append(f"'embeddings' not found in {file1}") + if "embeddings" not in adata2.obsm: + missing.append(f"'embeddings' not found in {file2}") + raise ValueError("; ".join(missing)) + + emb1 = np.asarray(adata1.obsm["embeddings"]) # (n, d) + emb2 = np.asarray(adata2.obsm["embeddings"]) # (n, d) + + if emb1.shape[0] != emb2.shape[0]: + raise ValueError(f"Number of rows differ between embeddings: {emb1.shape[0]} vs {emb2.shape[0]}") + + # Compute UMAPs independently + umap1 = compute_umap(emb1, n_neighbors=n_neighbors, min_dist=min_dist, metric=metric, random_state=random_state) + umap2 = compute_umap(emb2, n_neighbors=n_neighbors, min_dist=min_dist, metric=metric, random_state=random_state) + + # Procrustes analysis aligns scale/rotation/translation; returns disparity (lower is better) + mtx1, mtx2, disparity = procrustes(umap1, umap2) + + # Simple additional similarity: correlation of flattened coordinates after alignment + corr, _ = pearsonr(mtx1.ravel(), mtx2.ravel()) + + print("UMAP comparison:") + print(f" Procrustes disparity: {disparity:.6e}") + print(f" Pearson correlation (aligned): {corr:.6f}") + + if save_plot is not None: + try: + import matplotlib.pyplot as plt # Lazy import for optional plotting + + # Prepare colors using a consistent palette across both datasets + labels1 = None + labels2 = None + color_map = None + if obs_key is not None: + try: + series1 = adata1.obs[obs_key] + series2 = adata2.obs[obs_key] + labels1 = series1.astype(str).to_numpy() + labels2 = series2.astype(str).to_numpy() + categories = sorted(set(labels1).union(set(labels2))) + + import matplotlib as mpl + + if len(categories) <= 20: + cmap = mpl.cm.get_cmap("tab20", len(categories)) + palette = [mpl.colors.to_hex(cmap(i)) for i in range(len(categories))] + else: + # fallback palette for many categories + cmap = mpl.cm.get_cmap("hsv", len(categories)) + palette = [mpl.colors.to_hex(cmap(i)) for i in range(len(categories))] + color_map = {cat: palette[i] for i, cat in enumerate(categories)} + except Exception as e: + print(f"Warning: could not color by '{obs_key}': {e}") + + fig, axes = plt.subplots(1, 2, figsize=(10, 4)) + if labels1 is not None and color_map is not None: + colors1 = [color_map.get(lbl, "#000000") for lbl in labels1] + axes[0].scatter(mtx1[:, 0], mtx1[:, 1], s=3, alpha=0.6, c=colors1) + else: + axes[0].scatter(mtx1[:, 0], mtx1[:, 1], s=3, alpha=0.6) + axes[0].set_title("UMAP 1 (aligned)") + axes[0].set_xticks([]) + axes[0].set_yticks([]) + + if labels2 is not None and color_map is not None: + colors2 = [color_map.get(lbl, "#000000") for lbl in labels2] + axes[1].scatter(mtx2[:, 0], mtx2[:, 1], s=3, alpha=0.6, c=colors2) + else: + axes[1].scatter(mtx2[:, 0], mtx2[:, 1], s=3, alpha=0.6, color="tab:orange") + axes[1].set_title("UMAP 2 (aligned)") + axes[1].set_xticks([]) + axes[1].set_yticks([]) + + # Optional legend if number of categories is reasonable + if color_map is not None and len(color_map) > 0 and len(color_map) <= 20: + import matplotlib.patches as mpatches + + handles = [mpatches.Patch(color=clr, label=cat) for cat, clr in color_map.items()] + fig.legend(handles=handles, loc="lower center", ncol=min(5, len(handles)), frameon=False) + + plt.tight_layout() + out_dir = os.path.dirname(save_plot) + if out_dir and not os.path.exists(out_dir): + 
os.makedirs(out_dir, exist_ok=True) + plt.savefig(save_plot, dpi=150) + plt.close(fig) + print(f"Saved comparison plot to {save_plot}") + except Exception as e: # pragma: no cover - plotting not essential in tests + print(f"Warning: failed to save plot: {e}") + + # Pass criterion + return disparity <= procrustes_tolerance + + +def main(): + parser = argparse.ArgumentParser( + description=( + "Compute UMAPs from two AnnData files' obsm['embeddings'] independently and compare them via Procrustes." + ) + ) + parser.add_argument("file1", type=str, help="Path to first AnnData .h5ad file") + parser.add_argument("file2", type=str, help="Path to second AnnData .h5ad file") + parser.add_argument("--n-neighbors", type=int, default=15, help="UMAP n_neighbors (default: 15)") + parser.add_argument("--min-dist", type=float, default=0.1, help="UMAP min_dist (default: 0.1)") + parser.add_argument("--metric", type=str, default="euclidean", help="UMAP metric (default: euclidean)") + parser.add_argument("--random-state", type=int, default=42, help="Random seed for UMAP (default: 42)") + parser.add_argument( + "--procrustes-tolerance", + type=float, + default=1e-2, + help="Maximum acceptable Procrustes disparity to consider UMAPs similar (default: 1e-2)", + ) + parser.add_argument( + "--save-plot", + type=str, + default="./umap_comparison.png", + help="Optional path to save a side-by-side aligned UMAP comparison plot (PNG)", + ) + parser.add_argument( + "--obs-key", + type=str, + default="cell_type", + help="Column in .obs to color points by (default: cell_type). Use 'none' to disable.", + ) + + args = parser.parse_args() + + # Existence check + for fp in [args.file1, args.file2]: + if not os.path.exists(fp): + print(f"Error: File not found: {fp}") + raise SystemExit(1) + + ok = compare_umaps( + args.file1, + args.file2, + n_neighbors=args.n_neighbors, + min_dist=args.min_dist, + metric=args.metric, + random_state=args.random_state, + procrustes_tolerance=args.procrustes_tolerance, + save_plot=args.save_plot, + obs_key=(None if (args.obs_key is None or str(args.obs_key).lower() == "none") else args.obs_key), + ) + + if ok: + print("UMAPs are similar within tolerance.") + raise SystemExit(0) + else: + print("UMAPs differ beyond tolerance.") + raise SystemExit(2) + + +if __name__ == "__main__": + main()
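
A note on the multi-GPU output produced by the CLI changes above: when `--num-gpus` is not 1, each DDP rank writes a partial result named by inserting its rank before the extension (`save_file.replace(".h5ad", f"_{rank}.h5ad")`), e.g. `embeddings_0.h5ad`, `embeddings_1.h5ad`, and so on. The diff does not include a step that combines these shards, so merging is left to the user. The sketch below is not part of the PR; it assumes the default `embeddings.h5ad` output filename and uses a placeholder output directory.

```python
# Minimal sketch (not part of the PR): merge per-rank embedding shards written by
# multi-GPU inference into a single AnnData file.
import glob
import os

import anndata as ad

output_dir = "./inference_results"  # placeholder: whatever was passed to --output-path
shard_paths = sorted(glob.glob(os.path.join(output_dir, "embeddings_*.h5ad")))

# Read each rank's partial AnnData and concatenate along the cell (obs) axis.
shards = [ad.read_h5ad(p) for p in shard_paths]
merged = ad.concat(shards, join="outer")

# With a DistributedSampler the global cell order is interleaved across ranks, so
# re-sort by the original obs names here if downstream code expects the input order.
merged.write_h5ad(os.path.join(output_dir, "embeddings_merged.h5ad"))
```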