From 5988c1c7858bdbbfa109c34488964e432e67703b Mon Sep 17 00:00:00 2001
From: jpearce
Date: Mon, 11 Aug 2025 16:52:07 -0700
Subject: [PATCH 01/11] feat: multi-gpu inference

---
 README.md                                 |  12 ++
 download_artifacts.py                     | 163 ------------------
 inference.py                              |  63 -------
 src/transcriptformer/cli/__init__.py      | 134 +++++++++-----
 .../cli/conf/inference_config.yaml        |   1 +
 src/transcriptformer/data/dataclasses.py  |   2 +
 src/transcriptformer/model/inference.py   |  30 +++-
 7 files changed, 138 insertions(+), 267 deletions(-)
 delete mode 100644 download_artifacts.py
 delete mode 100644 inference.py

diff --git a/README.md b/README.md
index c0f837b..c9a3485 100644
--- a/README.md
+++ b/README.md
@@ -80,6 +80,10 @@ See the `pyproject.toml` file for the complete list of dependencies.
 ### Hardware Requirements
 - GPU (A100 40GB recommended) for efficient inference and embedding extraction.
 - Can also use a GPU with a lower amount of VRAM (16GB) by setting the inference batch size to 1-4.
+- **Multi-GPU support**: For faster inference on large datasets, use multiple GPUs with the `--num-gpus` parameter.
+  - Recommended for datasets with >100k cells
+  - Scales batch processing across available GPUs using Distributed Data Parallel (DDP)
+  - Best performance with matched GPU types and sufficient inter-GPU bandwidth

 ## Using the TranscriptFormer CLI

@@ -234,6 +238,13 @@ transcriptformer inference \
   --data-file test/data/human_val.h5ad \
   --emb-type cge \
   --batch-size 8
+
+# Multi-GPU inference using 4 GPUs (-1 will use all GPUs available on the system)
+transcriptformer inference \
+  --checkpoint-path ./checkpoints/tf_sapiens \
+  --data-file test/data/human_val.h5ad \
+  --num-gpus 4 \
+  --batch-size 32
 ```

 You can also use the CLI to run inference on the ESM2-CE baseline model discussed in the paper:
@@ -281,6 +292,7 @@ transcriptformer download-data --help
 - `--embedding-layer-index INT`: Index of the transformer layer to extract embeddings from (-1 for last layer, default: -1). Use with `transcriptformer` model type.
 - `--model-type {transcriptformer,esm2ce}`: Type of model to use (default: `transcriptformer`). Use `esm2ce` to extract raw ESM2-CE gene embeddings.
 - `--emb-type {cell,cge}`: Type of embeddings to extract (default: `cell`). Use `cell` for mean-pooled cell embeddings or `cge` for contextual gene embeddings.
+- `--num-gpus INT`: Number of GPUs to use for inference (default: 1). Use -1 for all available GPUs, or give an explicit count.
 - `--config-override key.path=value`: Override any configuration value directly.

 ### Input Data Format and Preprocessing:
diff --git a/download_artifacts.py b/download_artifacts.py
deleted file mode 100644
index f424263..0000000
--- a/download_artifacts.py
+++ /dev/null
@@ -1,163 +0,0 @@
-#!/usr/bin/env python3
-
-"""
-Download and extract TranscriptFormer model artifacts from a public S3 bucket.
-
-This script provides a convenient way to download and extract TranscriptFormer model weights
-from a public S3 bucket. It supports downloading individual models or all models at once,
-with progress indicators for both download and extraction processes.
-
-Usage:
-    python download_artifacts.py [model] [--checkpoint-dir DIR]
-
-    model: The model to download. Options are:
-        - tf-sapiens: Download the sapiens model
-        - tf-exemplar: Download the exemplar model
-        - tf-metazoa: Download the metazoa model
-        - all: Download all models and embeddings
-        - all-embeddings: Download only the embedding files
-
-    --checkpoint-dir: Optional directory to store the downloaded checkpoints.
- Defaults to './checkpoints' - -Examples --------- - # Download the sapiens model - python download_artifacts.py tf-sapiens - - # Download all models and embeddings - python download_artifacts.py all - - # Download only the embeddings file - python download_artifacts.py all-embeddings - - # Download the exemplar model to a custom directory - python download_artifacts.py tf-exemplar --checkpoint-dir /path/to/models - -The downloaded models will be extracted to: - ./checkpoints/tf_sapiens/ - ./checkpoints/tf_exemplar/ - ./checkpoints/tf_metazoa/ - ./checkpoints/all_embeddings/ -""" - -import argparse -import math -import sys -import tarfile -import tempfile -import urllib.error -import urllib.request -from pathlib import Path - - -def print_progress(current, total, prefix="", suffix="", length=50): - """Print a simple progress bar.""" - filled = int(length * current / total) - bar = "█" * filled + "░" * (length - filled) - percent = math.floor(100 * current / total) - print(f"\r{prefix} |{bar}| {percent}% {suffix}", end="", flush=True) - if current == total: - print() - - -def download_and_extract(model_name: str, checkpoint_dir: str = "./checkpoints"): - """Download and extract a model artifact from S3.""" - s3_path = f"https://czi-transcriptformer.s3.amazonaws.com/weights/{model_name}.tar.gz" - output_dir = Path(checkpoint_dir) / model_name - - # Create checkpoint directory if it doesn't exist - output_dir.parent.mkdir(parents=True, exist_ok=True) - - print(f"Downloading {model_name} from {s3_path}...") - - try: - # Create a temporary file to store the tar.gz - with tempfile.NamedTemporaryFile(suffix=".tar.gz") as tmp_file: - # Download the file using urllib with progress bar - try: - - def report_hook(count, block_size, total_size): - """Callback function to report download progress.""" - if total_size > 0 and (count % 100 == 0 or count * block_size >= total_size): - print_progress( - count * block_size, - total_size, - prefix=f"Downloading {model_name}", - ) - - urllib.request.urlretrieve(s3_path, filename=tmp_file.name, reporthook=report_hook) - print() # New line after download completes - except urllib.error.HTTPError as e: - if e.code == 404: - print(f"Error: The model {model_name} was not found at {s3_path}") - else: - print(f"Error downloading file: HTTP {e.code}") - sys.exit(1) - except urllib.error.URLError as e: - print(f"Error downloading file: {str(e)}") - sys.exit(1) - - # Reset the file pointer to the beginning - tmp_file.seek(0) - - # Extract the tar.gz file - try: - print(f"Extracting {model_name}...") - with tarfile.open(fileobj=tmp_file, mode="r:gz") as tar: - members = tar.getmembers() - total_files = len(members) - for i, member in enumerate(members, 1): - tar.extract(member, path=str(output_dir.parent)) - print_progress( - i, - total_files, - prefix=f"Extracting {model_name}", - ) - print() # New line after extraction completes - except tarfile.ReadError: - print(f"Error: The downloaded file for {model_name} is not a valid tar.gz archive") - sys.exit(1) - - print(f"Successfully downloaded and extracted {model_name} to {output_dir}") - - except Exception as e: - print(f"Error: {str(e)}") - sys.exit(1) - - -def main(): - parser = argparse.ArgumentParser(description="Download and extract TranscriptFormer model artifacts") - parser.add_argument( - "model", - choices=["tf-sapiens", "tf-exemplar", "tf-metazoa", "all", "all-embeddings"], - help="Model to download (or 'all' for all models and embeddings, 'all-embeddings' for just embeddings)", - ) - parser.add_argument( - 
"--checkpoint-dir", - default="./checkpoints", - help="Directory to store the downloaded checkpoints (default: ./checkpoints)", - ) - - args = parser.parse_args() - - models = { - "tf-sapiens": "tf_sapiens", - "tf-exemplar": "tf_exemplar", - "tf-metazoa": "tf_metazoa", - "all-embeddings": "all_embeddings", - } - - if args.model == "all": - # Download all models and embeddings - for model in ["tf_sapiens", "tf_exemplar", "tf_metazoa", "all_embeddings"]: - download_and_extract(model, args.checkpoint_dir) - elif args.model == "all-embeddings": - # Download only embeddings - download_and_extract("all_embeddings", args.checkpoint_dir) - else: - download_and_extract(models[args.model], args.checkpoint_dir) - - -if __name__ == "__main__": - main() diff --git a/inference.py b/inference.py deleted file mode 100644 index 2992fe3..0000000 --- a/inference.py +++ /dev/null @@ -1,63 +0,0 @@ -""" -Script to perform inference with Transcriptformer models. - -Example usage: - python inference.py --config-name=inference_config.yaml \ - model.checkpoint_path=./checkpoints/tf_sapiens \ - model.inference_config.data_files.0=test/data/human_val.h5ad \ - model.inference_config.output_path=./custom_results_dir \ - model.inference_config.output_filename=custom_embeddings.h5ad \ - model.inference_config.batch_size=8 -""" - -import json -import logging -import os - -import hydra -from omegaconf import DictConfig, OmegaConf - -from transcriptformer.model.inference import run_inference - -# Set up logging -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") - - -@hydra.main( - config_path=os.path.join(os.path.dirname(__file__), "conf"), - config_name="inference_config.yaml", - version_base=None, -) -def main(cfg: DictConfig): - logging.debug(OmegaConf.to_yaml(cfg)) - - config_path = os.path.join(cfg.model.checkpoint_path, "config.json") - with open(config_path) as f: - config_dict = json.load(f) - mlflow_cfg = OmegaConf.create(config_dict) - - # Merge the MLflow config with the main config - cfg = OmegaConf.merge(mlflow_cfg, cfg) - - # Set the checkpoint paths based on the unified checkpoint_path - cfg.model.inference_config.load_checkpoint = os.path.join(cfg.model.checkpoint_path, "model_weights.pt") - cfg.model.data_config.aux_vocab_path = os.path.join(cfg.model.checkpoint_path, "vocabs") - cfg.model.data_config.esm2_mappings_path = os.path.join(cfg.model.checkpoint_path, "vocabs") - - adata_output = run_inference(cfg, data_files=cfg.model.inference_config.data_files) - - # Save the output adata - output_path = cfg.model.inference_config.output_path - if not os.path.exists(output_path): - os.makedirs(output_path) - - # Get output filename from config or use default - output_filename = getattr(cfg.model.inference_config, "output_filename", "embeddings.h5ad") - save_file = os.path.join(output_path, output_filename) - - adata_output.write_h5ad(save_file) - logging.info(f"Saved embeddings to {save_file}") - - -if __name__ == "__main__": - main() diff --git a/src/transcriptformer/cli/__init__.py b/src/transcriptformer/cli/__init__.py index d1b66f1..82a97d4 100644 --- a/src/transcriptformer/cli/__init__.py +++ b/src/transcriptformer/cli/__init__.py @@ -24,6 +24,7 @@ --gene-col-name Column in AnnData.var with gene identifiers --precision Numerical precision (16-mixed or 32) --pretrained-embedding Path to embedding file for out-of-distribution species + --num-gpus Number of GPUs to use (1=single, -1=all available, >1=specific number) Advanced Configuration: Use --config-override for 
any configuration options not exposed as arguments above.

@@ -52,10 +53,16 @@
 """

 import argparse
+import json
 import logging
+import os
 import sys
 import warnings

+import torch
+from omegaconf import OmegaConf
+from transcriptformer.model.inference import run_inference
+
 # Suppress annoying warnings
 warnings.filterwarnings("ignore", category=FutureWarning, module="anndata")
 warnings.filterwarnings("ignore", category=FutureWarning, message=".*read_.*from.*anndata.*deprecated.*")
@@ -75,7 +82,6 @@
 \033[38;2;108;113;131m |_| \033[0m"""

-
 def setup_inference_parser(subparsers):
     """Setup the parser for the inference command."""
     parser = subparsers.add_parser(
@@ -164,6 +170,12 @@ def setup_inference_parser(subparsers):
         default=False,
         help="Remove duplicate genes if found instead of raising an error (default: False)",
     )
+    parser.add_argument(
+        "--num-gpus",
+        type=int,
+        default=1,
+        help="Number of GPUs to use for inference (1 = single GPU, -1 = all available GPUs, >1 = specific number) (default: 1)",
+    )

     # Allow arbitrary config overrides
     parser.add_argument(
@@ -242,45 +254,92 @@ def setup_download_data_parser(subparsers):

 def run_inference_cli(args):
     """Run inference using command line arguments."""
-    # Import the inference module directly
-    from transcriptformer.cli.inference import main as inference_main
-
-    # Create a hydra-compatible config dictionary for direct use with inference.py
-    cmd = [
-        "--config-name=inference_config.yaml",
-        f"model.checkpoint_path={args.checkpoint_path}",
-        f"model.inference_config.data_files.0={args.data_file}",
-        f"model.inference_config.batch_size={args.batch_size}",
-        f"model.data_config.gene_col_name={args.gene_col_name}",
-        f"model.inference_config.output_path={args.output_path}",
-        f"model.inference_config.output_filename={args.output_filename}",
-        f"model.inference_config.precision={args.precision}",
-        f"model.model_type={args.model_type}",
-        f"model.inference_config.emb_type={args.emb_type}",
-        f"model.data_config.remove_duplicate_genes={args.remove_duplicate_genes}",
-    ]
+    # Only print logo if not in distributed mode (avoids duplicates)
+    is_distributed = args.num_gpus != 1
+    if not is_distributed:
+        print(TF_LOGO)
+
+    # Load the config
+    config_path = os.path.join(os.path.dirname(__file__), "conf", "inference_config.yaml")
+    cfg = OmegaConf.load(config_path)
+
+    # Load model config from checkpoint
+    model_config_path = os.path.join(args.checkpoint_path, "config.json")
+    with open(model_config_path) as f:
+        config_dict = json.load(f)
+    mlflow_cfg = OmegaConf.create(config_dict)
+
+    # Merge the MLflow config with the main config
+    cfg = OmegaConf.merge(mlflow_cfg, cfg)
+
+    # Override config values with CLI arguments
+    cfg.model.checkpoint_path = args.checkpoint_path
+    cfg.model.inference_config.data_files = [args.data_file]
+    cfg.model.inference_config.batch_size = args.batch_size
+    cfg.model.data_config.gene_col_name = args.gene_col_name
+    cfg.model.inference_config.output_path = args.output_path
+    cfg.model.inference_config.output_filename = args.output_filename
+    cfg.model.inference_config.precision = args.precision
+    cfg.model.model_type = args.model_type
+    cfg.model.inference_config.emb_type = args.emb_type
+    cfg.model.data_config.remove_duplicate_genes = args.remove_duplicate_genes
+    cfg.model.inference_config.num_gpus = args.num_gpus

     # Add pretrained embedding if specified
     if args.pretrained_embedding:
-        cmd.append(f"model.inference_config.pretrained_embedding={args.pretrained_embedding}")
-
-    # Add any arbitrary config overrides
+        cfg.model.inference_config.pretrained_embedding = args.pretrained_embedding
+
+    # Apply any arbitrary config overrides
     for override in args.config_override:
-        cmd.append(override)
-
-    # Print logo
-    print(TF_LOGO)
-
-    # Override sys.argv for Hydra to pick up
-    saved_argv = sys.argv
-    sys.argv = [sys.argv[0]] + cmd
-
-    try:
-        # Call the main function directly
-        inference_main()
-    finally:
-        # Restore original sys.argv
-        sys.argv = saved_argv
+        if '=' not in override:
+            continue
+        key, value = override.split('=', 1)
+        # Convert the value to the most specific type it parses as:
+        # bool, then int, then float; otherwise keep it as a string
+        if value.lower() in ['true', 'false']:
+            value = value.lower() == 'true'
+        else:
+            try:
+                value = int(value)
+            except ValueError:
+                try:
+                    value = float(value)
+                except ValueError:
+                    pass  # Keep as string if conversion fails
+
+        OmegaConf.update(cfg, key, value)
+
+    # Set the checkpoint paths based on the unified checkpoint_path
+    cfg.model.inference_config.load_checkpoint = os.path.join(cfg.model.checkpoint_path, "model_weights.pt")
+    cfg.model.data_config.aux_vocab_path = os.path.join(cfg.model.checkpoint_path, "vocabs")
+    cfg.model.data_config.esm2_mappings_path = os.path.join(cfg.model.checkpoint_path, "vocabs")
+
+    # Run inference directly
+    adata_output = run_inference(cfg, data_files=cfg.model.inference_config.data_files)
+
+    # Save the output adata
+    output_path = cfg.model.inference_config.output_path
+    if not os.path.exists(output_path):
+        os.makedirs(output_path)
+
+    # Get output filename from config or use default
+    output_filename = getattr(cfg.model.inference_config, "output_filename", "embeddings.h5ad")
+    if not output_filename.endswith(".h5ad"):
+        output_filename = f"{output_filename}.h5ad"
+    save_file = os.path.join(output_path, output_filename)
+
+    # Check if we're in a distributed environment (torch.distributed is only
+    # initialized once the worker processes have actually been launched)
+    if is_distributed and torch.distributed.is_available() and torch.distributed.is_initialized():
+        rank = torch.distributed.get_rank()
+
+        # Split the filename and add rank before extension
+        rank_file = save_file.replace('.h5ad', f'_{rank}.h5ad')
+        adata_output.write_h5ad(rank_file)
+        print(f"Rank {rank} completed processing, saved partial results to {rank_file}")
+    else:
+        # Single GPU mode - save normally
+        adata_output.write_h5ad(save_file)
+        print(f"Inference completed! Saved embeddings to {save_file}")

 def run_download_cli(args):
@@ -319,9 +378,6 @@ def run_download_data_cli(args):
     # Parse species list
     species_list = [s.strip() for s in args.species.split(",")] if args.species else []

-    # Print logo
-    print(TF_LOGO)
-
     # Run the download
     try:
         successful_downloads = download_data_main(
diff --git a/src/transcriptformer/cli/conf/inference_config.yaml b/src/transcriptformer/cli/conf/inference_config.yaml
index 4523c5b..b766bb9 100644
--- a/src/transcriptformer/cli/conf/inference_config.yaml
+++ b/src/transcriptformer/cli/conf/inference_config.yaml
@@ -23,6 +23,7 @@ model:
     pretrained_embedding: null # Path to pretrained embeddings for out-of-distribution species
     precision: 16-mixed # Numerical precision for inference (16-mixed, 32, etc.)
emb_type: cell # Type of embeddings to extract: "cell" for mean-pooled cell embeddings or "cge" for contextual gene embeddings
+    num_gpus: 1 # Number of GPUs to use for inference (1 = single GPU, -1 = all available GPUs, >1 = specific number)

   data_config:
     _target_: transcriptformer.data.dataclasses.DataConfig
diff --git a/src/transcriptformer/data/dataclasses.py b/src/transcriptformer/data/dataclasses.py
index 6cf8887..a611c2c 100644
--- a/src/transcriptformer/data/dataclasses.py
+++ b/src/transcriptformer/data/dataclasses.py
@@ -178,6 +178,7 @@ class InferenceConfig:
         output_path (str): Path to save outputs
         output_filename (str): Filename for the output embeddings (default: embeddings.h5ad)
         num_gpus_per_node (int): GPUs per node (default: 1)
+        num_gpus (int): Number of GPUs to use for inference (1 = single GPU, -1 = all available GPUs, >1 = specific number) (default: 1)
         special_tokens (list): Special tokens to use
         emb_type (str): Type of embeddings to extract - "cell" for mean-pooled cell embeddings or "cge" for contextual gene embeddings (default: "cell")
     """
@@ -190,6 +191,7 @@ class InferenceConfig:
     output_path: str | None
     output_filename: str | None = "embeddings.h5ad"
     num_gpus_per_node: int = 1
+    num_gpus: int = 1
     num_nodes: int = 1
     precision: str = "16-mixed"
     special_tokens: list = field(default_factory=list)
diff --git a/src/transcriptformer/model/inference.py b/src/transcriptformer/model/inference.py
index 98605e3..4092e31 100644
--- a/src/transcriptformer/model/inference.py
+++ b/src/transcriptformer/model/inference.py
@@ -125,14 +125,40 @@ def run_inference(cfg, data_files: list[str] | list[anndata.AnnData]):
         collate_fn=dataset.collate_fn,
     )

+    # Determine number of GPUs to use
+    num_gpus = getattr(cfg.model.inference_config, 'num_gpus', 1)
+
+    # Handle special cases for num_gpus
+    if num_gpus == -1:
+        # Use all available GPUs
+        devices = torch.cuda.device_count() if torch.cuda.is_available() else 1
+        accelerator = "gpu" if torch.cuda.is_available() else "cpu"
+    elif num_gpus > 1:
+        # Use specified number of GPUs
+        available_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
+        if available_gpus < num_gpus:
+            logging.warning(f"Requested {num_gpus} GPUs but only {available_gpus} available. Using {available_gpus} GPUs.")
+            devices = available_gpus if available_gpus > 0 else 1
+            accelerator = "gpu" if available_gpus > 0 else "cpu"
+        else:
+            devices = num_gpus
+            accelerator = "gpu"
+    else:
+        # Use single GPU or CPU
+        devices = 1
+        accelerator = "gpu" if torch.cuda.is_available() else "cpu"
+
+    logging.info(f"Using {devices} device(s) with accelerator: {accelerator}")
+
     # Create Trainer
     trainer = pl.Trainer(
-        accelerator="gpu",
-        devices=1,  # Multiple GPUs/nodes not supported for inference
+        accelerator=accelerator,
+        devices=devices,
         num_nodes=1,
         precision=cfg.model.inference_config.precision,
         limit_predict_batches=None,
         logger=CSVLogger("logs", name="inference"),
+        strategy="ddp" if devices > 1 else "auto",  # Use DDP for multi-GPU
     )

     # Run prediction

From c21652684db67b378b7f36e2b4b8a0526ffe2bc2 Mon Sep 17 00:00:00 2001
From: jpearce
Date: Tue, 12 Aug 2025 13:49:10 -0700
Subject: [PATCH 02/11] update

---
 src/transcriptformer/cli/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transcriptformer/cli/__init__.py b/src/transcriptformer/cli/__init__.py
index 82a97d4..b2ae19a 100644
--- a/src/transcriptformer/cli/__init__.py
+++ b/src/transcriptformer/cli/__init__.py
@@ -284,6 +284,7 @@ def run_inference_cli(args):
     cfg.model.model_type = args.model_type
     cfg.model.inference_config.emb_type = args.emb_type
     cfg.model.data_config.remove_duplicate_genes = args.remove_duplicate_genes
+    cfg.model.data_config.use_raw = args.use_raw
     cfg.model.inference_config.num_gpus = args.num_gpus

     # Add pretrained embedding if specified

From e56ac8ce5b2d41bde6b052b7e8a09cdb0fa2ffa2 Mon Sep 17 00:00:00 2001
From: jpearce
Date: Tue, 12 Aug 2025 16:29:21 -0700
Subject: [PATCH 03/11] feat: data streaming for oom datasets

---
 README.md                                 |  13 +
 src/transcriptformer/cli/__init__.py      |  22 +
 .../cli/conf/inference_config.yaml         |   2 +
 src/transcriptformer/data/dataclasses.py  |   3 +
 src/transcriptformer/data/dataloader.py   | 405 ++++++++++++++----
 src/transcriptformer/model/inference.py   |  12 +-
 6 files changed, 365 insertions(+), 92 deletions(-)

diff --git a/README.md b/README.md
index c0f837b..b563c81 100644
--- a/README.md
+++ b/README.md
@@ -277,6 +277,8 @@ transcriptformer download-data --help
 - `--pretrained-embedding PATH`: Path to pretrained embeddings for out-of-distribution species.
 - `--clip-counts INT`: Maximum count value (higher values will be clipped) (default: 30).
 - `--filter-to-vocabs`: Whether to filter genes to only those in the vocabulary (default: True).
+- `--use-iterable-dataset`: Use a streaming IterableDataset for low-memory processing; yields cells on-the-fly (default: False).
+- `--iterable-chunk-size INT`: Optional chunk size (number of cells) processed per step when using the iterable dataset.
 - `--use-raw {True,False,auto}`: Whether to use raw counts from `AnnData.raw.X` (True), `adata.X` (False), or auto-detect (auto/None) (default: None).
 - `--embedding-layer-index INT`: Index of the transformer layer to extract embeddings from (-1 for last layer, default: -1). Use with `transcriptformer` model type.
 - `--model-type {transcriptformer,esm2ce}`: Type of model to use (default: `transcriptformer`). Use `esm2ce` to extract raw ESM2-CE gene embeddings.
@@ -301,6 +303,17 @@ Input data files should be in H5AD format (AnnData objects) with the following r - `True`: Use only `adata.raw.X` - `False`: Use only `adata.X` + - **Streaming Inference**: + - To reduce peak memory usage on large datasets, enable streaming: + ```bash + transcriptformer inference \ + --checkpoint-path ./checkpoints/tf_sapiens \ + --data-file ./data/huge.h5ad \ + --use-iterable-dataset \ + --iterable-chunk-size 4096 + ``` + - This uses an IterableDataset that processes cells in chunks and yields items to the DataLoader without loading all cells into memory. + - **Count Processing**: - Count values are clipped at 30 by default (as was done in training) - If this seems too low, you can either: diff --git a/src/transcriptformer/cli/__init__.py b/src/transcriptformer/cli/__init__.py index d1b66f1..935c76d 100644 --- a/src/transcriptformer/cli/__init__.py +++ b/src/transcriptformer/cli/__init__.py @@ -164,6 +164,18 @@ def setup_inference_parser(subparsers): default=False, help="Remove duplicate genes if found instead of raising an error (default: False)", ) + parser.add_argument( + "--use-iterable-dataset", + action="store_true", + default=False, + help="Use streaming IterableDataset for low-memory processing (default: False)", + ) + parser.add_argument( + "--iterable-chunk-size", + type=int, + default=4096, + help="Chunk size of rows per processing step when using the iterable dataset (default: auto)", + ) # Allow arbitrary config overrides parser.add_argument( @@ -258,8 +270,17 @@ def run_inference_cli(args): f"model.model_type={args.model_type}", f"model.inference_config.emb_type={args.emb_type}", f"model.data_config.remove_duplicate_genes={args.remove_duplicate_genes}", + f"model.inference_config.use_iterable_dataset={args.use_iterable_dataset}", + f"model.data_config.use_raw={args.use_raw}", + f"model.data_config.clip_counts={args.clip_counts}", + f"model.data_config.filter_to_vocabs={args.filter_to_vocabs}", + ] + # Only pass iterable_chunk_size if explicitly provided (not None) + # Always pass explicit chunk size (has default) + cmd.append(f"model.inference_config.iterable_chunk_size={args.iterable_chunk_size}") + # Add pretrained embedding if specified if args.pretrained_embedding: cmd.append(f"model.inference_config.pretrained_embedding={args.pretrained_embedding}") @@ -268,6 +289,7 @@ def run_inference_cli(args): for override in args.config_override: cmd.append(override) + print("USE RAW", args.use_raw, type(args.use_raw)) # Print logo print(TF_LOGO) diff --git a/src/transcriptformer/cli/conf/inference_config.yaml b/src/transcriptformer/cli/conf/inference_config.yaml index 4523c5b..ad07f53 100644 --- a/src/transcriptformer/cli/conf/inference_config.yaml +++ b/src/transcriptformer/cli/conf/inference_config.yaml @@ -23,6 +23,8 @@ model: pretrained_embedding: null # Path to pretrained embeddings for out-of-distribution species precision: 16-mixed # Numerical precision for inference (16-mixed, 32, etc.) 
emb_type: cell # Type of embeddings to extract: "cell" for mean-pooled cell embeddings or "cge" for contextual gene embeddings + use_iterable_dataset: false # Use streaming IterableDataset for low-memory processing + iterable_chunk_size: 4096 # Chunk size (rows per step) when using iterable dataset data_config: _target_: transcriptformer.data.dataclasses.DataConfig diff --git a/src/transcriptformer/data/dataclasses.py b/src/transcriptformer/data/dataclasses.py index 6cf8887..1e4c34f 100644 --- a/src/transcriptformer/data/dataclasses.py +++ b/src/transcriptformer/data/dataclasses.py @@ -195,6 +195,9 @@ class InferenceConfig: special_tokens: list = field(default_factory=list) pretrained_embedding: list = field(default_factory=list) emb_type: str = "cell" + # Streaming/iterable dataset options (moved from data_config) + use_iterable_dataset: bool = False # If True, use StreamingAnnDataset instead of AnnDataset + iterable_chunk_size: int = 4096 # Per-file processing chunk size for streaming def __post_init__(self): if self.emb_type not in {"cell", "cge"}: diff --git a/src/transcriptformer/data/dataloader.py b/src/transcriptformer/data/dataloader.py index 58371f5..6c7792b 100644 --- a/src/transcriptformer/data/dataloader.py +++ b/src/transcriptformer/data/dataloader.py @@ -9,7 +9,7 @@ import torch from scipy.sparse import csc_matrix, csr_matrix from torch import tensor -from torch.utils.data import Dataset +from torch.utils.data import Dataset, IterableDataset from transcriptformer.data.dataclasses import BatchData from transcriptformer.tokenizer.tokenizer import ( @@ -39,9 +39,12 @@ def apply_filters( min_expressed_genes, ): """Apply filters to the data.""" + n_cells = X.shape[0] + if filter_to_vocab: filter_idx = [i for i, name in enumerate(gene_names) if name in vocab] X = X[:, filter_idx] + logging.info(f"Filtered {len(gene_names)} genes to {len(filter_idx)} genes in vocab") gene_names = gene_names[filter_idx] if X.shape[1] == 0: logging.warning(f"Warning: Filtered all genes from {file_path}") @@ -64,6 +67,8 @@ def apply_filters( X = X[filter_idx] obs = obs.iloc[filter_idx] + logging.info(f"Filtered {n_cells} cells to {X.shape[0]} cells") + return X, obs, gene_names @@ -154,6 +159,87 @@ def process_batch( return result +def get_counts_layer(adata: anndata.AnnData, use_raw: bool | None): + if use_raw is True: + if adata.raw is not None: + logging.info("Using 'raw.X' layer from AnnData object") + return adata.raw.X + else: + raise ValueError("raw.X not found in AnnData object") + elif use_raw is False: + if adata.X is not None: + logging.info("Using 'X' layer from AnnData object") + return adata.X + else: + raise ValueError("X not found in AnnData object") + else: # None - try raw first, then fallback to X + if adata.raw is not None: + logging.info("Using 'raw.X' layer from AnnData object") + return adata.raw.X + elif adata.X is not None: + logging.info("Using 'X' layer from AnnData object") + return adata.X + else: + raise ValueError("No valid data layer found in AnnData object") + + +def to_dense(X: np.ndarray | csr_matrix | csc_matrix) -> np.ndarray: + if isinstance(X, (csr_matrix, csc_matrix)): + return X.toarray() + elif isinstance(X, np.ndarray): + return X + else: + raise TypeError(f"Expected numpy array or sparse matrix, got {type(X)}") + + +def is_raw_counts(X: np.ndarray) -> bool: + non_zero_mask = X > 0 + if not np.any(non_zero_mask): + return False + non_zero_values = X[non_zero_mask] + if len(non_zero_values) > 1000: + non_zero_values = np.random.choice(non_zero_values, 1000, 
replace=False) + is_integer = np.all(np.abs(non_zero_values - np.round(non_zero_values)) < 1e-6) + return is_integer + + +def load_gene_features(adata: anndata.AnnData, gene_col_name: str, remove_duplicate_genes: bool): + try: + gene_names = np.array(list(adata.var[gene_col_name].values)) + gene_names = np.array([id.split(".")[0] for id in gene_names]) + + gene_counts = Counter(gene_names) + duplicates = {gene for gene, count in gene_counts.items() if count > 1} + if len(duplicates) > 0: + if remove_duplicate_genes: + seen = set() + unique_indices = [] + for i, gene in enumerate(gene_names): + if gene not in seen: + seen.add(gene) + unique_indices.append(i) + adata = adata[:, unique_indices].copy() + gene_names = gene_names[unique_indices] + logging.warning( + f"Removed {len(duplicates)} duplicate genes after removing version numbers. Kept first occurrence." + ) + else: + raise ValueError( + "Found duplicate genes after removing version numbers. " + "Remove duplicates or pass --remove-duplicate-genes." + ) + + return gene_names, True, adata + except KeyError: + return None, False, adata + + +def validate_gene_dimension(X: np.ndarray, gene_names: np.ndarray, gene_col_name: str): + if X.shape[1] != len(gene_names): + raise ValueError( + f"Mismatch between expression matrix columns ({X.shape[1]}) and gene names length ({len(gene_names)}). " + f"Ensure 'adata.var[{gene_col_name}]' exists and aligns with the matrix columns." + ) class AnnDataset(Dataset): def __init__( @@ -210,91 +296,6 @@ def __init__( logging.info("Loading and processing all data") self.data = self.load_and_process_all_data() - def _get_counts_layer(self, adata: anndata.AnnData) -> str: - if self.use_raw is True: - if adata.raw is not None: - logging.info("Using 'raw.X' layer from AnnData object") - return adata.raw.X - else: - raise ValueError("raw.X not found in AnnData object") - elif self.use_raw is False: - if adata.X is not None: - logging.info("Using 'X' layer from AnnData object") - return adata.X - else: - raise ValueError("X not found in AnnData object") - else: # None - try raw first, then fallback to X - if adata.raw is not None: - logging.info("Using 'raw.X' layer from AnnData object") - return adata.raw.X - elif adata.X is not None: - logging.info("Using 'X' layer from AnnData object") - return adata.X - else: - raise ValueError("No valid data layer found in AnnData object") - - def _to_dense(self, X: np.ndarray | csr_matrix | csc_matrix) -> np.ndarray: - if isinstance(X, csr_matrix | csc_matrix): - return X.toarray() - elif isinstance(X, np.ndarray): - return X - else: - raise TypeError(f"Expected numpy array or sparse matrix, got {type(X)}") - - def _is_raw_counts(self, X: np.ndarray) -> bool: - # Get non-zero values - non_zero_mask = X > 0 - if not np.any(non_zero_mask): - return False - - # Sample up to 1000 non-zero values - non_zero_values = X[non_zero_mask] - if len(non_zero_values) > 1000: - non_zero_values = np.random.choice(non_zero_values, 1000, replace=False) - - # Check if values are roughly integers (within float32 precision) - # float32 has ~7 decimal digits of precision - is_integer = np.all(np.abs(non_zero_values - np.round(non_zero_values)) < 1e-6) - - return is_integer - - def _load_gene_features(self, adata): - """Load gene features and remove version numbers from ensembl ids""" - try: - gene_names = np.array(list(adata.var[self.gene_col_name].values)) - gene_names = np.array([id.split(".")[0] for id in gene_names]) - - # Check for duplicates after removing version numbers - gene_counts = 
Counter(gene_names) - duplicates = {gene for gene, count in gene_counts.items() if count > 1} - if len(duplicates) > 0: - if self.remove_duplicate_genes: - # Remove duplicates by keeping only the first occurrence - seen = set() - unique_indices = [] - for i, gene in enumerate(gene_names): - if gene not in seen: - seen.add(gene) - unique_indices.append(i) - - # Filter adata to keep only unique genes - adata = adata[:, unique_indices].copy() - gene_names = gene_names[unique_indices] - - logging.warning( - f"Removed {len(duplicates)} duplicate genes after removing version numbers. " - f"Kept first occurrence of each gene. " - ) - else: - raise ValueError( - f"Found {len(duplicates)} duplicate genes after removing version numbers. " - f"Please remove duplicate genes from your data or use --remove-duplicate-genes flag. " - ) - - return gene_names, True, adata - except KeyError: - return None, False, adata - def _get_batch_from_file(self, file: str | anndata.AnnData) -> BatchData | None: if isinstance(file, str): file_path = file @@ -313,22 +314,29 @@ def _get_batch_from_file(self, file: str | anndata.AnnData) -> BatchData | None: logging.error(f"Failed to load data from {file_path}") return None - gene_names, success, adata = self._load_gene_features(adata) + gene_names, success, adata = load_gene_features( + adata, self.gene_col_name, self.remove_duplicate_genes + ) if not success: logging.error(f"Failed to load gene features from {file_path}") return None - X = self._get_counts_layer(adata) - X = self._to_dense(X) + X = get_counts_layer(adata, self.use_raw) + X = to_dense(X) obs = adata.obs + # Validate that gene dimension matches number of gene names + validate_gene_dimension(X, gene_names, self.gene_col_name) + # Check if the data appears to be raw counts - if not self._is_raw_counts(X): + logging.info(f"Checking if data is raw counts") + if not is_raw_counts(X): logging.warning( "Data does not appear to be raw counts. TranscriptFormer expects unnormalized count data. " "If your data is normalized, consider using the original count matrix instead." ) + logging.info(f"Applying filters") vocab = self.gene_vocab X, obs, gene_names = apply_filters( X, @@ -340,10 +348,12 @@ def _get_batch_from_file(self, file: str | anndata.AnnData) -> BatchData | None: self.filter_outliers, self.min_expressed_genes, ) + if X is None: logging.warning(f"Data was filtered out completely for {file_path}") return None + logging.info(f"Processing data") batch = process_batch( X, obs, @@ -447,3 +457,218 @@ def collate_fn(batch: BatchData | list[BatchData]) -> BatchData: ), ) return collated_batch + + +class StreamingAnnDataset(IterableDataset): + """Iterable variant of AnnDataset that loads and processes data iteratively. + + This class reuses the same top-level helpers in this module: + - load_data + - apply_filters + - process_batch + + It yields one cell at a time as a BatchData instance, allowing PyTorch's + DataLoader to batch items using the existing AnnDataset.collate_fn. 
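+
+    A minimal usage sketch (illustrative only; ``gene_vocab`` is assumed to be
+    loaded elsewhere, and the argument names follow this class's __init__):
+
+        dataset = StreamingAnnDataset(
+            files_list=["data/huge.h5ad"],
+            gene_vocab=gene_vocab,
+            inference=True,
+            iter_chunk_size=4096,
+        )
+        loader = DataLoader(dataset, batch_size=32, collate_fn=dataset.collate_fn)
+        for batch in loader:  # batches of BatchData assembled by collate_fn
+            ...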
+    """
+
+    # Reuse the same collate function for batching behavior identical to AnnDataset
+    collate_fn = staticmethod(AnnDataset.collate_fn)
+
+    def __init__(
+        self,
+        files_list: list[str] | list[anndata.AnnData],
+        gene_vocab: dict[str, str],
+        data_dir: str = None,
+        aux_vocab: dict[str, dict[str, str]] = None,
+        max_len: int = 2048,
+        normalize_to_scale: bool = None,
+        sort_genes: bool = False,
+        randomize_order: bool = False,
+        pad_zeros: bool = True,
+        gene_col_name: str = "ensembl_id",
+        filter_to_vocab: bool = True,
+        filter_outliers: float = 0.0,
+        min_expressed_genes: int = 0,
+        seed: int = 0,
+        pad_token: str = "[PAD]",
+        clip_counts: float = 1e10,
+        inference: bool = False,
+        obs_keys: list[str] = None,
+        use_raw: bool = None,
+        remove_duplicate_genes: bool = False,
+        iter_chunk_size: int | None = None,
+    ):
+        super().__init__()
+        self.files_list = files_list
+        self.data_dir = data_dir
+        self.gene_vocab = gene_vocab
+        self.aux_vocab = aux_vocab
+        self.max_len = max_len
+        self.normalize_to_scale = normalize_to_scale
+        self.sort_genes = sort_genes
+        self.randomize_order = randomize_order
+        self.pad_zeros = pad_zeros
+        self.gene_col_name = gene_col_name
+        self.filter_to_vocab = filter_to_vocab
+        self.filter_outliers = filter_outliers
+        self.min_expressed_genes = min_expressed_genes
+        self.seed = seed
+        self.pad_token = pad_token
+        self.clip_counts = clip_counts
+        self.inference = inference
+        self.obs_keys = obs_keys
+        self.use_raw = use_raw
+        self.remove_duplicate_genes = remove_duplicate_genes
+        self.iter_chunk_size = iter_chunk_size
+
+        self.gene_tokenizer = BatchGeneTokenizer(gene_vocab)
+        if aux_vocab is not None:
+            self.aux_tokenizer = BatchObsTokenizer(aux_vocab)
+
+        random.seed(self.seed)
+
+        # Estimate total cells for progress bars (approximate; ignores filtering)
+        self._estimated_total_cells = 0
+        for file in self.files_list:
+            try:
+                if isinstance(file, str):
+                    file_path = file if self.data_dir is None else os.path.join(self.data_dir, file)
+                    ad = anndata.read_h5ad(file_path, backed='r')
+                    self._estimated_total_cells += int(getattr(ad, 'n_obs', 0))
+                    if hasattr(ad, 'file') and ad.file is not None:
+                        ad.file.close()
+                elif isinstance(file, anndata.AnnData):
+                    self._estimated_total_cells += int(file.n_obs)
+            except Exception:
+                # If estimation fails for a file, skip it (progress bar may be undercounted)
+                continue
+
+    def _yield_cells_from_arrays(self, X: np.ndarray, obs, gene_names: np.ndarray, file_path: str | None):
+        """Yield BatchData objects one cell at a time from arrays, processing in chunks."""
+        num_cells = X.shape[0]
+        # Fall back to a single chunk when no chunk size is set; iter_chunk_size
+        # defaults to None, which range() would otherwise reject as a step
+        chunk = self.iter_chunk_size or max(num_cells, 1)
+        for start in range(0, num_cells, chunk):
+            end = min(start + chunk, num_cells)
+            x_chunk = X[start:end]
+            obs_chunk = obs.iloc[start:end]
+            gene_names_chunk = gene_names
+
+            batch = process_batch(
+                x_chunk,
+                obs_chunk,
+                gene_names_chunk,
+                self.gene_tokenizer,
+                getattr(self, "aux_tokenizer", None),
+                self.sort_genes,
+                self.randomize_order,
+                self.max_len,
+                self.pad_zeros,
+                self.pad_token,
+                self.gene_vocab,
+                self.normalize_to_scale,
+                self.clip_counts,
+                self.aux_vocab,
+            )
+
+            # Attach obs as needed to match AnnDataset behavior
+            if self.obs_keys is not None:
+                obs_data = {}
+                if "all" in self.obs_keys:
+                    cols = obs_chunk.columns
+                    for col in cols:
+                        obs_data[col] = np.array(obs_chunk[col].tolist())[:, None]
+                else:
+                    for col in self.obs_keys:
+                        obs_data[col] = np.array(obs_chunk[col].tolist())[:, None]
+            else:
+                obs_data = None
+
+            # Yield one cell at a time so DataLoader can batch them
+            for i in range(end - start):
+                yield BatchData(
+                    gene_counts=batch["gene_counts"][i],
+                    gene_token_indices=batch["gene_token_indices"][i],
+                    file_path=None,
+                    aux_token_indices=(
+                        batch.get("aux_token_indices")[i] if batch.get("aux_token_indices") is not None else None
+                    ),
+                    obs=(
+                        {col: obs_data[col][i][None, :] for col in obs_data}
+                        if obs_data is not None
+                        else None
+                    ),
+                )
+
+    def __iter__(self):
+        # Deterministic iteration
+        random.seed(self.seed)
+
+        for idx, file in enumerate(self.files_list):
+            logging.info(f"Streaming file {idx + 1} of {len(self.files_list)}")
+
+            if isinstance(file, str):
+                file_path = file
+                if self.data_dir is not None:
+                    file_path = os.path.join(self.data_dir, file_path)
+                adata, success = load_data(file_path)
+            elif isinstance(file, anndata.AnnData):
+                adata = file
+                success = True
+                file_path = None
+            else:
+                raise ValueError(f"Invalid file type: {type(file)}")
+
+            if not success:
+                logging.error(f"Failed to load data from {file_path}")
+                continue
+
+            gene_names, success, adata = load_gene_features(
+                adata, self.gene_col_name, self.remove_duplicate_genes
+            )
+            if not success:
+                logging.error(f"Failed to load gene features from {file_path}")
+                continue
+
+            X = get_counts_layer(adata, self.use_raw)
+            X = to_dense(X)
+            obs = adata.obs
+
+            # Validate that gene dimension matches number of gene names
+            validate_gene_dimension(X, gene_names, self.gene_col_name)
+
+            # Same checks as AnnDataset
+            logging.info("Checking if data is raw counts")
+            if not is_raw_counts(X):
+                logging.warning(
+                    "Data does not appear to be raw counts. TranscriptFormer expects unnormalized count data. "
+                    "If your data is normalized, consider using the original count matrix instead."
+                )
+
+            logging.info("Applying filters")
+            X, obs, gene_names = apply_filters(
+                X,
+                obs,
+                gene_names,
+                file_path,
+                self.filter_to_vocab,
+                self.gene_vocab,
+                self.filter_outliers,
+                self.min_expressed_genes,
+            )
+
+            if X is None:
+                logging.warning(f"Data was filtered out completely for {file_path}")
+                continue
+
+            logging.info("Processing and yielding cells")
+            yield from self._yield_cells_from_arrays(X, obs, gene_names, file_path)
+
+    def __len__(self) -> int:
+        """Return an estimated total number of samples for progress bars.
+
+        Note: This is an upper bound prior to filtering; actual yielded samples
+        may be fewer. It enables progress bar display for IterableDataset.
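+
+        The estimate is the sum of n_obs across the input files, collected once
+        in __init__ using backed reads, so no file is fully loaded just to count.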
+ """ + return int(self._estimated_total_cells) diff --git a/src/transcriptformer/model/inference.py b/src/transcriptformer/model/inference.py index 98605e3..c9eb4a2 100644 --- a/src/transcriptformer/model/inference.py +++ b/src/transcriptformer/model/inference.py @@ -9,7 +9,7 @@ from pytorch_lightning.loggers import CSVLogger from torch.utils.data import DataLoader -from transcriptformer.data.dataloader import AnnDataset +from transcriptformer.data.dataloader import AnnDataset, StreamingAnnDataset from transcriptformer.model.embedding_surgery import change_embedding_layer from transcriptformer.tokenizer.vocab import load_vocabs_and_embeddings from transcriptformer.utils.utils import stack_dict @@ -112,8 +112,16 @@ def run_inference(cfg, data_files: list[str] | list[anndata.AnnData]): "clip_counts": cfg.model.data_config.clip_counts, "obs_keys": cfg.model.inference_config.obs_keys, "remove_duplicate_genes": cfg.model.data_config.remove_duplicate_genes, + "use_raw": cfg.model.data_config.use_raw, } - dataset = AnnDataset(data_files, **data_kwargs) + if getattr(cfg.model.inference_config, "use_iterable_dataset", False): + # Optional chunk size + iter_kwargs = {} + if getattr(cfg.model.inference_config, "iterable_chunk_size", None) is not None: + iter_kwargs["iter_chunk_size"] = cfg.model.inference_config.iterable_chunk_size + dataset = StreamingAnnDataset(data_files, **data_kwargs, **iter_kwargs) + else: + dataset = AnnDataset(data_files, **data_kwargs) # Create dataloader dataloader = DataLoader( From 789f902d0880f087265256dce46af493cd7d723d Mon Sep 17 00:00:00 2001 From: jpearce Date: Wed, 13 Aug 2025 10:32:46 -0700 Subject: [PATCH 04/11] + test scripts; - old inf/dl scripts --- download_artifacts.py | 163 ------------------ inference.py | 63 ------- src/transcriptformer/cli/__init__.py | 5 +- test/test_compare_emb.py | 26 ++- test/test_compare_umap.py | 240 +++++++++++++++++++++++++++ 5 files changed, 259 insertions(+), 238 deletions(-) delete mode 100644 download_artifacts.py delete mode 100644 inference.py create mode 100644 test/test_compare_umap.py diff --git a/download_artifacts.py b/download_artifacts.py deleted file mode 100644 index f424263..0000000 --- a/download_artifacts.py +++ /dev/null @@ -1,163 +0,0 @@ -#!/usr/bin/env python3 - -""" -Download and extract TranscriptFormer model artifacts from a public S3 bucket. - -This script provides a convenient way to download and extract TranscriptFormer model weights -from a public S3 bucket. It supports downloading individual models or all models at once, -with progress indicators for both download and extraction processes. - -Usage: - python download_artifacts.py [model] [--checkpoint-dir DIR] - - model: The model to download. Options are: - - tf-sapiens: Download the sapiens model - - tf-exemplar: Download the exemplar model - - tf-metazoa: Download the metazoa model - - all: Download all models and embeddings - - all-embeddings: Download only the embedding files - - --checkpoint-dir: Optional directory to store the downloaded checkpoints. 
- Defaults to './checkpoints' - -Examples --------- - # Download the sapiens model - python download_artifacts.py tf-sapiens - - # Download all models and embeddings - python download_artifacts.py all - - # Download only the embeddings file - python download_artifacts.py all-embeddings - - # Download the exemplar model to a custom directory - python download_artifacts.py tf-exemplar --checkpoint-dir /path/to/models - -The downloaded models will be extracted to: - ./checkpoints/tf_sapiens/ - ./checkpoints/tf_exemplar/ - ./checkpoints/tf_metazoa/ - ./checkpoints/all_embeddings/ -""" - -import argparse -import math -import sys -import tarfile -import tempfile -import urllib.error -import urllib.request -from pathlib import Path - - -def print_progress(current, total, prefix="", suffix="", length=50): - """Print a simple progress bar.""" - filled = int(length * current / total) - bar = "█" * filled + "░" * (length - filled) - percent = math.floor(100 * current / total) - print(f"\r{prefix} |{bar}| {percent}% {suffix}", end="", flush=True) - if current == total: - print() - - -def download_and_extract(model_name: str, checkpoint_dir: str = "./checkpoints"): - """Download and extract a model artifact from S3.""" - s3_path = f"https://czi-transcriptformer.s3.amazonaws.com/weights/{model_name}.tar.gz" - output_dir = Path(checkpoint_dir) / model_name - - # Create checkpoint directory if it doesn't exist - output_dir.parent.mkdir(parents=True, exist_ok=True) - - print(f"Downloading {model_name} from {s3_path}...") - - try: - # Create a temporary file to store the tar.gz - with tempfile.NamedTemporaryFile(suffix=".tar.gz") as tmp_file: - # Download the file using urllib with progress bar - try: - - def report_hook(count, block_size, total_size): - """Callback function to report download progress.""" - if total_size > 0 and (count % 100 == 0 or count * block_size >= total_size): - print_progress( - count * block_size, - total_size, - prefix=f"Downloading {model_name}", - ) - - urllib.request.urlretrieve(s3_path, filename=tmp_file.name, reporthook=report_hook) - print() # New line after download completes - except urllib.error.HTTPError as e: - if e.code == 404: - print(f"Error: The model {model_name} was not found at {s3_path}") - else: - print(f"Error downloading file: HTTP {e.code}") - sys.exit(1) - except urllib.error.URLError as e: - print(f"Error downloading file: {str(e)}") - sys.exit(1) - - # Reset the file pointer to the beginning - tmp_file.seek(0) - - # Extract the tar.gz file - try: - print(f"Extracting {model_name}...") - with tarfile.open(fileobj=tmp_file, mode="r:gz") as tar: - members = tar.getmembers() - total_files = len(members) - for i, member in enumerate(members, 1): - tar.extract(member, path=str(output_dir.parent)) - print_progress( - i, - total_files, - prefix=f"Extracting {model_name}", - ) - print() # New line after extraction completes - except tarfile.ReadError: - print(f"Error: The downloaded file for {model_name} is not a valid tar.gz archive") - sys.exit(1) - - print(f"Successfully downloaded and extracted {model_name} to {output_dir}") - - except Exception as e: - print(f"Error: {str(e)}") - sys.exit(1) - - -def main(): - parser = argparse.ArgumentParser(description="Download and extract TranscriptFormer model artifacts") - parser.add_argument( - "model", - choices=["tf-sapiens", "tf-exemplar", "tf-metazoa", "all", "all-embeddings"], - help="Model to download (or 'all' for all models and embeddings, 'all-embeddings' for just embeddings)", - ) - parser.add_argument( - 
"--checkpoint-dir", - default="./checkpoints", - help="Directory to store the downloaded checkpoints (default: ./checkpoints)", - ) - - args = parser.parse_args() - - models = { - "tf-sapiens": "tf_sapiens", - "tf-exemplar": "tf_exemplar", - "tf-metazoa": "tf_metazoa", - "all-embeddings": "all_embeddings", - } - - if args.model == "all": - # Download all models and embeddings - for model in ["tf_sapiens", "tf_exemplar", "tf_metazoa", "all_embeddings"]: - download_and_extract(model, args.checkpoint_dir) - elif args.model == "all-embeddings": - # Download only embeddings - download_and_extract("all_embeddings", args.checkpoint_dir) - else: - download_and_extract(models[args.model], args.checkpoint_dir) - - -if __name__ == "__main__": - main() diff --git a/inference.py b/inference.py deleted file mode 100644 index 2992fe3..0000000 --- a/inference.py +++ /dev/null @@ -1,63 +0,0 @@ -""" -Script to perform inference with Transcriptformer models. - -Example usage: - python inference.py --config-name=inference_config.yaml \ - model.checkpoint_path=./checkpoints/tf_sapiens \ - model.inference_config.data_files.0=test/data/human_val.h5ad \ - model.inference_config.output_path=./custom_results_dir \ - model.inference_config.output_filename=custom_embeddings.h5ad \ - model.inference_config.batch_size=8 -""" - -import json -import logging -import os - -import hydra -from omegaconf import DictConfig, OmegaConf - -from transcriptformer.model.inference import run_inference - -# Set up logging -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") - - -@hydra.main( - config_path=os.path.join(os.path.dirname(__file__), "conf"), - config_name="inference_config.yaml", - version_base=None, -) -def main(cfg: DictConfig): - logging.debug(OmegaConf.to_yaml(cfg)) - - config_path = os.path.join(cfg.model.checkpoint_path, "config.json") - with open(config_path) as f: - config_dict = json.load(f) - mlflow_cfg = OmegaConf.create(config_dict) - - # Merge the MLflow config with the main config - cfg = OmegaConf.merge(mlflow_cfg, cfg) - - # Set the checkpoint paths based on the unified checkpoint_path - cfg.model.inference_config.load_checkpoint = os.path.join(cfg.model.checkpoint_path, "model_weights.pt") - cfg.model.data_config.aux_vocab_path = os.path.join(cfg.model.checkpoint_path, "vocabs") - cfg.model.data_config.esm2_mappings_path = os.path.join(cfg.model.checkpoint_path, "vocabs") - - adata_output = run_inference(cfg, data_files=cfg.model.inference_config.data_files) - - # Save the output adata - output_path = cfg.model.inference_config.output_path - if not os.path.exists(output_path): - os.makedirs(output_path) - - # Get output filename from config or use default - output_filename = getattr(cfg.model.inference_config, "output_filename", "embeddings.h5ad") - save_file = os.path.join(output_path, output_filename) - - adata_output.write_h5ad(save_file) - logging.info(f"Saved embeddings to {save_file}") - - -if __name__ == "__main__": - main() diff --git a/src/transcriptformer/cli/__init__.py b/src/transcriptformer/cli/__init__.py index 935c76d..7812a13 100644 --- a/src/transcriptformer/cli/__init__.py +++ b/src/transcriptformer/cli/__init__.py @@ -271,16 +271,13 @@ def run_inference_cli(args): f"model.inference_config.emb_type={args.emb_type}", f"model.data_config.remove_duplicate_genes={args.remove_duplicate_genes}", f"model.inference_config.use_iterable_dataset={args.use_iterable_dataset}", + f"model.inference_config.iterable_chunk_size={args.iterable_chunk_size}", 
f"model.data_config.use_raw={args.use_raw}", f"model.data_config.clip_counts={args.clip_counts}", f"model.data_config.filter_to_vocabs={args.filter_to_vocabs}", ] - # Only pass iterable_chunk_size if explicitly provided (not None) - # Always pass explicit chunk size (has default) - cmd.append(f"model.inference_config.iterable_chunk_size={args.iterable_chunk_size}") - # Add pretrained embedding if specified if args.pretrained_embedding: cmd.append(f"model.inference_config.pretrained_embedding={args.pretrained_embedding}") diff --git a/test/test_compare_emb.py b/test/test_compare_emb.py index 2c9178a..176fe70 100644 --- a/test/test_compare_emb.py +++ b/test/test_compare_emb.py @@ -3,7 +3,7 @@ import anndata as ad import numpy as np -from scipy.stats import pearsonr +from scipy.stats import pearsonr, ttest_rel def compare_embeddings(file1, file2, tolerance=1e-5): @@ -31,18 +31,18 @@ def compare_embeddings(file1, file2, tolerance=1e-5): adata2 = ad.read_h5ad(file2) # Check if embeddings exist - if "emb" not in adata1.obsm or "emb" not in adata2.obsm: + if "embeddings" not in adata1.obsm or "embeddings" not in adata2.obsm: missing = [] - if "emb" not in adata1.obsm: - missing.append(f"'emb' not found in {file1}") - if "emb" not in adata2.obsm: - missing.append(f"'emb' not found in {file2}") + if "embedding" not in adata1.obsm: + missing.append(f"'embedding' not found in {file1}") + if "embedding" not in adata2.obsm: + missing.append(f"'embedding' not found in {file2}") print(f"Error: {', '.join(missing)}") return False # Get embeddings - emb1 = adata1.obsm["emb"] - emb2 = adata2.obsm["emb"] + emb1 = adata1.obsm["embeddings"] + emb2 = adata2.obsm["embeddings"] # Check shapes if emb1.shape != emb2.shape: @@ -64,10 +64,20 @@ def compare_embeddings(file1, file2, tolerance=1e-5): emb2_flat = emb2.flatten() corr, _ = pearsonr(emb1_flat, emb2_flat) + # Calculate z-score and p-value for difference test + # Using paired t-test to test null hypothesis that embeddings are the same + diff = emb1_flat - emb2_flat + t_stat, p_value = ttest_rel(emb1_flat, emb2_flat) + + # Calculate z-score manually from the differences + z_score = np.mean(diff) / (np.std(diff) / np.sqrt(len(diff))) + print("Embeddings differ:") print(f" Max absolute difference: {max_diff:.6e}") print(f" Mean absolute difference: {mean_diff:.6e}") print(f" Pearson correlation: {corr:.6f}") + print(f" Z-score (difference): {z_score:.6f}") + print(f" P-value (paired t-test): {p_value:.6e}") return False diff --git a/test/test_compare_umap.py b/test/test_compare_umap.py new file mode 100644 index 0000000..1a561a9 --- /dev/null +++ b/test/test_compare_umap.py @@ -0,0 +1,240 @@ +import argparse +import os +import warnings + +import anndata as ad +import numpy as np +from scipy.spatial import procrustes +from scipy.stats import pearsonr + + +def _require_umap(): + try: + import umap.umap_ as umap # type: ignore + except Exception as exc: # pragma: no cover - informative error path + raise RuntimeError( + "This script requires the 'umap-learn' package. Install with 'pip install umap-learn'." + ) from exc + return umap + + +def compute_umap(embeddings: np.ndarray, n_neighbors: int = 15, min_dist: float = 0.1, metric: str = "euclidean", random_state: int | None = 42) -> np.ndarray: + """ + Compute a 2D UMAP from input embeddings. 
+ + Parameters + ---------- + embeddings : np.ndarray + 2D array of shape (n_samples, n_features) + n_neighbors : int + UMAP n_neighbors parameter + min_dist : float + UMAP min_dist parameter + metric : str + Distance metric for UMAP + random_state : int | None + Random seed for reproducibility + + Returns + ------- + np.ndarray + 2D array of shape (n_samples, 2) with UMAP coordinates + """ + umap = _require_umap() + + # Suppress noisy warnings sometimes emitted by numba/umap during fit + warnings.filterwarnings("ignore", message=".*cannot cache compiled function.*") + reducer = umap.UMAP( + n_neighbors=n_neighbors, + min_dist=min_dist, + metric=metric, + n_components=2, + random_state=random_state, + ) + return reducer.fit_transform(embeddings) + + +def compare_umaps( + file1: str, + file2: str, + n_neighbors: int = 15, + min_dist: float = 0.1, + metric: str = "euclidean", + random_state: int | None = 42, + procrustes_tolerance: float = 1e-2, + save_plot: str | None = None, + obs_key: str | None = "cell_type", +) -> bool: + """ + Load embeddings from two AnnData files, compute UMAPs independently, and compare via Procrustes. + + Returns True if the Procrustes disparity <= procrustes_tolerance. + """ + # Load AnnData files + adata1 = ad.read_h5ad(file1) + adata2 = ad.read_h5ad(file2) + + if "embeddings" not in adata1.obsm or "embeddings" not in adata2.obsm: + missing = [] + if "embeddings" not in adata1.obsm: + missing.append(f"'embeddings' not found in {file1}") + if "embeddings" not in adata2.obsm: + missing.append(f"'embeddings' not found in {file2}") + raise ValueError( + "; ".join(missing) + ) + + emb1 = np.asarray(adata1.obsm["embeddings"]) # (n, d) + emb2 = np.asarray(adata2.obsm["embeddings"]) # (n, d) + + if emb1.shape[0] != emb2.shape[0]: + raise ValueError( + f"Number of rows differ between embeddings: {emb1.shape[0]} vs {emb2.shape[0]}" + ) + + # Compute UMAPs independently + umap1 = compute_umap(emb1, n_neighbors=n_neighbors, min_dist=min_dist, metric=metric, random_state=random_state) + umap2 = compute_umap(emb2, n_neighbors=n_neighbors, min_dist=min_dist, metric=metric, random_state=random_state) + + # Procrustes analysis aligns scale/rotation/translation; returns disparity (lower is better) + mtx1, mtx2, disparity = procrustes(umap1, umap2) + + # Simple additional similarity: correlation of flattened coordinates after alignment + corr, _ = pearsonr(mtx1.ravel(), mtx2.ravel()) + + print("UMAP comparison:") + print(f" Procrustes disparity: {disparity:.6e}") + print(f" Pearson correlation (aligned): {corr:.6f}") + + if save_plot is not None: + try: + import matplotlib.pyplot as plt # Lazy import for optional plotting + + # Prepare colors using a consistent palette across both datasets + labels1 = None + labels2 = None + color_map = None + if obs_key is not None: + try: + series1 = adata1.obs[obs_key] + series2 = adata2.obs[obs_key] + labels1 = series1.astype(str).to_numpy() + labels2 = series2.astype(str).to_numpy() + categories = sorted(list(set(labels1).union(set(labels2)))) + + import matplotlib as mpl + + if len(categories) <= 20: + cmap = mpl.cm.get_cmap("tab20", len(categories)) + palette = [mpl.colors.to_hex(cmap(i)) for i in range(len(categories))] + else: + # fallback palette for many categories + cmap = mpl.cm.get_cmap("hsv", len(categories)) + palette = [mpl.colors.to_hex(cmap(i)) for i in range(len(categories))] + color_map = {cat: palette[i] for i, cat in enumerate(categories)} + except Exception as e: + print(f"Warning: could not color by '{obs_key}': {e}") + + 
fig, axes = plt.subplots(1, 2, figsize=(10, 4)) + if labels1 is not None and color_map is not None: + colors1 = [color_map.get(lbl, "#000000") for lbl in labels1] + axes[0].scatter(mtx1[:, 0], mtx1[:, 1], s=3, alpha=0.6, c=colors1) + else: + axes[0].scatter(mtx1[:, 0], mtx1[:, 1], s=3, alpha=0.6) + axes[0].set_title("UMAP 1 (aligned)") + axes[0].set_xticks([]) + axes[0].set_yticks([]) + + if labels2 is not None and color_map is not None: + colors2 = [color_map.get(lbl, "#000000") for lbl in labels2] + axes[1].scatter(mtx2[:, 0], mtx2[:, 1], s=3, alpha=0.6, c=colors2) + else: + axes[1].scatter(mtx2[:, 0], mtx2[:, 1], s=3, alpha=0.6, color="tab:orange") + axes[1].set_title("UMAP 2 (aligned)") + axes[1].set_xticks([]) + axes[1].set_yticks([]) + + # Optional legend if number of categories is reasonable + if color_map is not None and len(color_map) > 0 and len(color_map) <= 20: + import matplotlib.patches as mpatches + + handles = [mpatches.Patch(color=clr, label=cat) for cat, clr in color_map.items()] + fig.legend(handles=handles, loc="lower center", ncol=min(5, len(handles)), frameon=False) + + plt.tight_layout() + out_dir = os.path.dirname(save_plot) + if out_dir and not os.path.exists(out_dir): + os.makedirs(out_dir, exist_ok=True) + plt.savefig(save_plot, dpi=150) + plt.close(fig) + print(f"Saved comparison plot to {save_plot}") + except Exception as e: # pragma: no cover - plotting not essential in tests + print(f"Warning: failed to save plot: {e}") + + # Pass criterion + return disparity <= procrustes_tolerance + + +def main(): + parser = argparse.ArgumentParser( + description=( + "Compute UMAPs from two AnnData files' obsm['embeddings'] independently and compare them via Procrustes." + ) + ) + parser.add_argument("file1", type=str, help="Path to first AnnData .h5ad file") + parser.add_argument("file2", type=str, help="Path to second AnnData .h5ad file") + parser.add_argument("--n-neighbors", type=int, default=15, help="UMAP n_neighbors (default: 15)") + parser.add_argument("--min-dist", type=float, default=0.1, help="UMAP min_dist (default: 0.1)") + parser.add_argument("--metric", type=str, default="euclidean", help="UMAP metric (default: euclidean)") + parser.add_argument("--random-state", type=int, default=42, help="Random seed for UMAP (default: 42)") + parser.add_argument( + "--procrustes-tolerance", + type=float, + default=1e-2, + help="Maximum acceptable Procrustes disparity to consider UMAPs similar (default: 1e-2)", + ) + parser.add_argument( + "--save-plot", + type=str, + default='./umap_comparison.png', + help="Optional path to save a side-by-side aligned UMAP comparison plot (PNG)", + ) + parser.add_argument( + "--obs-key", + type=str, + default="cell_type", + help="Column in .obs to color points by (default: cell_type). 
Use 'none' to disable.", + ) + + args = parser.parse_args() + + # Existence check + for fp in [args.file1, args.file2]: + if not os.path.exists(fp): + print(f"Error: File not found: {fp}") + raise SystemExit(1) + + ok = compare_umaps( + args.file1, + args.file2, + n_neighbors=args.n_neighbors, + min_dist=args.min_dist, + metric=args.metric, + random_state=args.random_state, + procrustes_tolerance=args.procrustes_tolerance, + save_plot=args.save_plot, + obs_key=(None if (args.obs_key is None or str(args.obs_key).lower() == "none") else args.obs_key), + ) + + if ok: + print("UMAPs are similar within tolerance.") + raise SystemExit(0) + else: + print("UMAPs differ beyond tolerance.") + raise SystemExit(2) + + +if __name__ == "__main__": + main() + + From d730deafff6f948640a536438d031ae01609d473 Mon Sep 17 00:00:00 2001 From: jpearce Date: Fri, 15 Aug 2025 15:46:53 -0700 Subject: [PATCH 05/11] update --- README.md | 17 +- src/transcriptformer/cli/__init__.py | 63 ++-- .../cli/conf/inference_config.yaml | 5 +- src/transcriptformer/data/dataclasses.py | 4 +- src/transcriptformer/data/dataloader.py | 344 ++++++++---------- src/transcriptformer/model/inference.py | 46 ++- test/test_compare_emb.py | 2 +- test/test_compare_umap.py | 22 +- 8 files changed, 240 insertions(+), 263 deletions(-) diff --git a/README.md b/README.md index 488d067..b0b2887 100644 --- a/README.md +++ b/README.md @@ -243,7 +243,7 @@ transcriptformer inference \ transcriptformer inference \ --checkpoint-path ./checkpoints/tf_sapiens \ --data-file test/data/human_val.h5ad \ - --num-gpus 4 \ + --num-gpus 4 \ --batch-size 32 ``` @@ -288,13 +288,13 @@ transcriptformer download-data --help - `--pretrained-embedding PATH`: Path to pretrained embeddings for out-of-distribution species. - `--clip-counts INT`: Maximum count value (higher values will be clipped) (default: 30). - `--filter-to-vocabs`: Whether to filter genes to only those in the vocabulary (default: True). -- `--use-iterable-dataset`: Use a streaming IterableDataset for low-memory processing; yields cells on-the-fly (default: False). -- `--iterable-chunk-size INT`: Optional chunk size (number of cells) processed per step when using the iterable dataset. - `--use-raw {True,False,auto}`: Whether to use raw counts from `AnnData.raw.X` (True), `adata.X` (False), or auto-detect (auto/None) (default: None). - `--embedding-layer-index INT`: Index of the transformer layer to extract embeddings from (-1 for last layer, default: -1). Use with `transcriptformer` model type. - `--model-type {transcriptformer,esm2ce}`: Type of model to use (default: `transcriptformer`). Use `esm2ce` to extract raw ESM2-CE gene embeddings. - `--emb-type {cell,cge}`: Type of embeddings to extract (default: `cell`). Use `cell` for mean-pooled cell embeddings or `cge` for contextual gene embeddings. - `--num-gpus INT`: Number of GPUs to use for inference (default: 1). Use -1 for all available GPUs, or specify a specific number. +- `--oom-dataloader`: Use the OOM-safe map-style DataLoader (uses backed reads and per-item densification; DistributedSampler-friendly). +- `--n-data-workers INT`: Number of DataLoader workers per process (default: 0). Order is preserved with the map-style dataset and DistributedSampler. - `--config-override key.path=value`: Override any configuration value directly. 
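
Overrides passed via `--config-override` are applied with `OmegaConf.update`, so dotted key paths address nested config entries. A minimal illustrative sketch of that mechanism (the config keys here are examples only, not the full schema):

```python
from omegaconf import OmegaConf

# Toy config mirroring the nesting used by inference_config.yaml
cfg = OmegaConf.create({"model": {"inference_config": {"batch_size": 8}}})

override = "model.inference_config.batch_size=32"  # as received from the CLI
key, value = override.split("=", 1)
OmegaConf.update(cfg, key, int(value))  # dotted paths reach nested keys

assert cfg.model.inference_config.batch_size == 32
```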
### Input Data Format and Preprocessing: @@ -315,16 +315,17 @@ Input data files should be in H5AD format (AnnData objects) with the following r - `True`: Use only `adata.raw.X` - `False`: Use only `adata.X` - - **Streaming Inference**: - - To reduce peak memory usage on large datasets, enable streaming: + - **OOM-safe Data Loading**: + - To reduce peak memory usage on large datasets, enable the OOM-safe dataloader: ```bash transcriptformer inference \ --checkpoint-path ./checkpoints/tf_sapiens \ --data-file ./data/huge.h5ad \ - --use-iterable-dataset \ - --iterable-chunk-size 4096 + --oom-dataloader \ + --n-data-workers 4 \ + --num-gpus 8 ``` - - This uses an IterableDataset that processes cells in chunks and yields items to the DataLoader without loading all cells into memory. + - This uses a map-style dataset with backed reads and per-row densification. It is compatible with `DistributedSampler`, so multiple workers are safe and ordering is preserved. - **Count Processing**: - Count values are clipped at 30 by default (as was done in training) diff --git a/src/transcriptformer/cli/__init__.py b/src/transcriptformer/cli/__init__.py index 9cbd30f..3d7ce15 100644 --- a/src/transcriptformer/cli/__init__.py +++ b/src/transcriptformer/cli/__init__.py @@ -61,6 +61,7 @@ import torch from omegaconf import OmegaConf + from transcriptformer.model.inference import run_inference # Suppress annoying warnings @@ -82,6 +83,7 @@ \033[38;2;108;113;131m |_| \033[0m""" + def setup_inference_parser(subparsers): """Setup the parser for the inference command.""" parser = subparsers.add_parser( @@ -177,16 +179,16 @@ def setup_inference_parser(subparsers): help="Number of GPUs to use for inference (1 = single GPU, -1 = all available GPUs, >1 = specific number) (default: 1)", ) parser.add_argument( - "--use-iterable-dataset", + "--oom-dataloader", action="store_true", default=False, - help="Use streaming IterableDataset for low-memory processing (default: False)", + help="Use map-style out-of-memory DataLoader (DistributedSampler-friendly)", ) parser.add_argument( - "--iterable-chunk-size", + "--n-data-workers", type=int, - default=4096, - help="Chunk size of rows per processing step when using the iterable dataset (default: auto)", + default=0, + help="Number of DataLoader workers per process (map-style dataset is order-safe).", ) # Allow arbitrary config overrides @@ -266,25 +268,24 @@ def setup_download_data_parser(subparsers): def run_inference_cli(args): """Run inference using command line arguments.""" - # Only print logo if not in distributed mode (avoids duplicates) is_distributed = args.num_gpus != 1 if not is_distributed: print(TF_LOGO) - # Load the config + # Load the config config_path = os.path.join(os.path.dirname(__file__), "conf", "inference_config.yaml") cfg = OmegaConf.load(config_path) - + # Load model config from checkpoint model_config_path = os.path.join(args.checkpoint_path, "config.json") with open(model_config_path) as f: config_dict = json.load(f) mlflow_cfg = OmegaConf.create(config_dict) - + # Merge the MLflow config with the main config cfg = OmegaConf.merge(mlflow_cfg, cfg) - + # Override config values with CLI arguments cfg.model.checkpoint_path = args.checkpoint_path cfg.model.inference_config.data_files = [args.data_file] @@ -298,59 +299,63 @@ def run_inference_cli(args): cfg.model.data_config.remove_duplicate_genes = args.remove_duplicate_genes cfg.model.data_config.use_raw = args.use_raw cfg.model.inference_config.num_gpus = args.num_gpus - 
cfg.model.inference_config.use_iterable_dataset = args.use_iterable_dataset - cfg.model.inference_config.iterable_chunk_size = args.iterable_chunk_size + cfg.model.inference_config.use_oom_dataloader = args.oom_dataloader cfg.model.data_config.clip_counts = args.clip_counts cfg.model.data_config.filter_to_vocabs = args.filter_to_vocabs - + cfg.model.data_config.n_data_workers = args.n_data_workers + # Add pretrained embedding if specified if args.pretrained_embedding: cfg.model.inference_config.pretrained_embedding = args.pretrained_embedding - + # Apply any arbitrary config overrides for override in args.config_override: - if '=' not in override: + if "=" not in override: continue - key, value = override.split('=', 1) + key, value = override.split("=", 1) # Convert value to appropriate type try: # Try to parse as a number or boolean - if value.lower() in ['true', 'false']: - value = value.lower() == 'true' + if value.lower() in ["true", "false"]: + value = value.lower() == "true" + elif value.lower() in ["none", "null"]: + value = None elif value.isdigit(): value = int(value) - elif '.' in value and all(part.isdigit() for part in value.split('.')): + elif "." in value and all(part.isdigit() for part in value.split(".")): value = float(value) - except: - pass # Keep as string if conversion fails - - OmegaConf.set(cfg, key, value) - + except Exception: + # Keep as string if conversion fails + pass + + # Use OmegaConf.update to set nested keys like "a.b.c" or list indices like "a.list.0" + OmegaConf.update(cfg, key, value) + # Set the checkpoint paths based on the unified checkpoint_path cfg.model.inference_config.load_checkpoint = os.path.join(cfg.model.checkpoint_path, "model_weights.pt") cfg.model.data_config.aux_vocab_path = os.path.join(cfg.model.checkpoint_path, "vocabs") cfg.model.data_config.esm2_mappings_path = os.path.join(cfg.model.checkpoint_path, "vocabs") - + # Run inference directly adata_output = run_inference(cfg, data_files=cfg.model.inference_config.data_files) - + # Save the output adata output_path = cfg.model.inference_config.output_path if not os.path.exists(output_path): os.makedirs(output_path) - + # Get output filename from config or use default output_filename = getattr(cfg.model.inference_config, "output_filename", "embeddings.h5ad") if not output_filename.endswith(".h5ad"): output_filename = f"{output_filename}.h5ad" save_file = os.path.join(output_path, output_filename) - + # Check if we're in a distributed environment if is_distributed: rank = torch.distributed.get_rank() # Split the filename and add rank before extension - rank_file = save_file.replace('.h5ad', f'_{rank}.h5ad') + rank_file = save_file.replace(".h5ad", f"_{rank}.h5ad") adata_output.write_h5ad(rank_file) print(f"Rank {rank} completed processing, saved partial results to {rank_file}") else: diff --git a/src/transcriptformer/cli/conf/inference_config.yaml b/src/transcriptformer/cli/conf/inference_config.yaml index b519925..a1aafe6 100644 --- a/src/transcriptformer/cli/conf/inference_config.yaml +++ b/src/transcriptformer/cli/conf/inference_config.yaml @@ -23,9 +23,8 @@ model: pretrained_embedding: null # Path to pretrained embeddings for out-of-distribution species precision: 16-mixed # Numerical precision for inference (16-mixed, 32, etc.) 
emb_type: cell # Type of embeddings to extract: "cell" for mean-pooled cell embeddings or "cge" for contextual gene embeddings - use_iterable_dataset: false # Use streaming IterableDataset for low-memory processing - iterable_chunk_size: 4096 # Chunk size (rows per step) when using iterable dataset num_gpus: 1 # Number of GPUs to use for inference (1 = single GPU, -1 = all available GPUs, >1 = specific number) + use_oom_dataloader: false # Use OOM-safe map-style DataLoader with DistributedSampler data_config: _target_: transcriptformer.data.dataclasses.DataConfig @@ -39,4 +38,4 @@ model: min_expressed_genes: 0 # Minimum number of expressed genes required per cell use_raw: "auto" # Whether to use .raw.X (True), .X (False), or auto-detect (auto/null) remove_duplicate_genes: false # Whether to remove duplicate genes instead of raising an error - n_data_workers: 0 # Leave as 0 to ensure deterministic behavior \ No newline at end of file + n_data_workers: 0 # Leave as 0 to ensure deterministic behavior diff --git a/src/transcriptformer/data/dataclasses.py b/src/transcriptformer/data/dataclasses.py index dbd24e4..f98c8a3 100644 --- a/src/transcriptformer/data/dataclasses.py +++ b/src/transcriptformer/data/dataclasses.py @@ -191,13 +191,11 @@ class InferenceConfig: output_filename: str | None = "embeddings.h5ad" num_gpus: int = 1 num_nodes: int = 1 + use_oom_dataloader: bool = False precision: str = "16-mixed" special_tokens: list = field(default_factory=list) pretrained_embedding: list = field(default_factory=list) emb_type: str = "cell" - # Streaming/iterable dataset options (moved from data_config) - use_iterable_dataset: bool = False # If True, use StreamingAnnDataset instead of AnnDataset - iterable_chunk_size: int = 4096 # Per-file processing chunk size for streaming def __post_init__(self): if self.emb_type not in {"cell", "cge"}: diff --git a/src/transcriptformer/data/dataloader.py b/src/transcriptformer/data/dataloader.py index 24ca1de..860cb8c 100644 --- a/src/transcriptformer/data/dataloader.py +++ b/src/transcriptformer/data/dataloader.py @@ -9,7 +9,7 @@ import torch from scipy.sparse import csc_matrix, csr_matrix from torch import tensor -from torch.utils.data import Dataset, IterableDataset +from torch.utils.data import Dataset from transcriptformer.data.dataclasses import BatchData from transcriptformer.tokenizer.tokenizer import ( @@ -18,10 +18,18 @@ ) -def load_data(file_path): - """Load H5AD file.""" +def load_data(file_path, *, backed: bool = False): + """Load H5AD file. 
+
+    Args:
+        file_path: Path to .h5ad file
+        backed: If True, use memory-mapped backed='r' mode for low-memory access; otherwise fully load into memory
+    """
     try:
-        adata = sc.read_h5ad(file_path)
+        if backed:
+            adata = anndata.read_h5ad(file_path, backed="r")
+        else:
+            adata = sc.read_h5ad(file_path)
         return adata, True
     except Exception as e:
         logging.error(f"Failed to read file {file_path}: {e}")
@@ -159,6 +167,7 @@ def process_batch(
 
     return result
 
+
 def get_counts_layer(adata: anndata.AnnData, use_raw: bool | None):
     if use_raw is True:
         if adata.raw is not None:
@@ -184,7 +193,7 @@ def get_counts_layer(adata: anndata.AnnData, use_raw: bool | None):
 
 
 def to_dense(X: np.ndarray | csr_matrix | csc_matrix) -> np.ndarray:
-    if isinstance(X, (csr_matrix, csc_matrix)):
+    if isinstance(X, csr_matrix | csc_matrix):
         return X.toarray()
     elif isinstance(X, np.ndarray):
         return X
@@ -192,20 +201,50 @@ def to_dense(X: np.ndarray | csr_matrix | csc_matrix) -> np.ndarray:
         raise TypeError(f"Expected numpy array or sparse matrix, got {type(X)}")
 
 
-def is_raw_counts(X: np.ndarray) -> bool:
+def is_raw_counts(X: np.ndarray | csr_matrix | csc_matrix) -> bool:
+    """Check if a matrix looks like raw counts (integer-valued where non-zero).
+
+    Handles both dense numpy arrays and sparse CSR/CSC matrices without densifying the full matrix.
+    """
+    # Sparse path: operate on non-zero data directly
+    if isinstance(X, csr_matrix | csc_matrix):
+        data = X.data
+        if data.size == 0:
+            return False
+        # Sample if very large
+        if data.size > 1000:
+            idx = np.random.choice(data.size, 1000, replace=False)
+            data = data[idx]
+        return np.all(np.abs(data - np.round(data)) < 1e-6)
+
+    # Dense path
     non_zero_mask = X > 0
     if not np.any(non_zero_mask):
         return False
     non_zero_values = X[non_zero_mask]
-    if len(non_zero_values) > 1000:
-        non_zero_values = np.random.choice(non_zero_values, 1000, replace=False)
-    is_integer = np.all(np.abs(non_zero_values - np.round(non_zero_values)) < 1e-6)
-    return is_integer
+    if non_zero_values.size > 1000:
+        idx = np.random.choice(non_zero_values.size, 1000, replace=False)
+        non_zero_values = non_zero_values.flatten()[idx]
+    return np.all(np.abs(non_zero_values - np.round(non_zero_values)) < 1e-6)
 
 
-def load_gene_features(adata: anndata.AnnData, gene_col_name: str, remove_duplicate_genes: bool):
+def load_gene_features(
+    adata: anndata.AnnData, gene_col_name: str, remove_duplicate_genes: bool, use_raw: bool | None = None
+):
     try:
-        gene_names = np.array(list(adata.var[gene_col_name].values))
+        # Select the appropriate var depending on which matrix will be used
+        using_raw = bool(use_raw is True or (use_raw is None and getattr(adata, "raw", None) is not None))
+        var_df = adata.raw.var if using_raw and getattr(adata, "raw", None) is not None else adata.var
+
+        # Require the requested column; raising here surfaces a missing or misaligned
+        # gene column immediately rather than silently falling back to the index
+        if gene_col_name in var_df.columns:
+            gene_names = np.array(list(var_df[gene_col_name].values))
+        else:
+            raise ValueError(
+                f"Gene column '{gene_col_name}' not found in var DataFrame columns: {list(var_df.columns)}"
+            )
+
+        # Remove version numbers from gene names
         gene_names = np.array([id.split(".")[0] for id in gene_names])
 
         gene_counts = Counter(gene_names)
@@ -241,6 +280,7 @@ def validate_gene_dimension(X: np.ndarray, gene_names: np.ndarray, gene_col_name
             f"Ensure 'adata.var[{gene_col_name}]' exists and aligns with the matrix columns."
) + class AnnDataset(Dataset): def __init__( self, @@ -315,13 +355,14 @@ def _get_batch_from_file(self, file: str | anndata.AnnData) -> BatchData | None: return None gene_names, success, adata = load_gene_features( - adata, self.gene_col_name, self.remove_duplicate_genes + adata, self.gene_col_name, self.remove_duplicate_genes, use_raw=self.use_raw ) if not success: logging.error(f"Failed to load gene features from {file_path}") return None X = get_counts_layer(adata, self.use_raw) + # AnnDataset loads and processes all data in-memory; convert to dense for batching X = to_dense(X) obs = adata.obs @@ -329,14 +370,14 @@ def _get_batch_from_file(self, file: str | anndata.AnnData) -> BatchData | None: validate_gene_dimension(X, gene_names, self.gene_col_name) # Check if the data appears to be raw counts - logging.info(f"Checking if data is raw counts") + logging.info("Checking if data is raw counts") if not is_raw_counts(X): logging.warning( "Data does not appear to be raw counts. TranscriptFormer expects unnormalized count data. " "If your data is normalized, consider using the original count matrix instead." ) - logging.info(f"Applying filters") + logging.info("Applying filters") vocab = self.gene_vocab X, obs, gene_names = apply_filters( X, @@ -353,7 +394,7 @@ def _get_batch_from_file(self, file: str | anndata.AnnData) -> BatchData | None: logging.warning(f"Data was filtered out completely for {file_path}") return None - logging.info(f"Processing data") + logging.info("Processing data") batch = process_batch( X, obs, @@ -459,44 +500,33 @@ def collate_fn(batch: BatchData | list[BatchData]) -> BatchData: return collated_batch -class StreamingAnnDataset(IterableDataset): - """Iterable variant of AnnDataset that loads and processes data iteratively. +class AnnDatasetOOM(Dataset): + """Map-style OOM-safe dataset using backed reads and per-item processing. - This class reuses the same top-level helpers in this module: - - load_data - - apply_filters - - process_batch - - It yields one cell at a time as a BatchData instance, allowing PyTorch's - DataLoader to batch items using the existing AnnDataset.collate_fn. + Designed to provide OOM-safe iteration while leveraging PyTorch's + DistributedSampler for automatic sharding across DDP ranks. 
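+
+    Each __getitem__ call reads and densifies a single row, so peak memory
+    scales with one cell rather than with the full expression matrix.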
""" - # Reuse the same collate function for batching behavior identical to AnnDataset collate_fn = staticmethod(AnnDataset.collate_fn) def __init__( self, - files_list: list[str] | list[anndata.AnnData], + files_list: list[str], gene_vocab: dict[str, str], - data_dir: str = None, - aux_vocab: dict[str, dict[str, str]] = None, + data_dir: str | None = None, + aux_vocab: dict[str, dict[str, str]] | None = None, max_len: int = 2048, - normalize_to_scale: bool = None, + normalize_to_scale: float | None = None, sort_genes: bool = False, randomize_order: bool = False, pad_zeros: bool = True, + pad_token: str = "[PAD]", gene_col_name: str = "ensembl_id", filter_to_vocab: bool = True, - filter_outliers: float = 0.0, - min_expressed_genes: int = 0, - seed: int = 0, - pad_token: str = "[PAD]", clip_counts: float = 1e10, - inference: bool = False, - obs_keys: list[str] = None, - use_raw: bool = None, + obs_keys: list[str] | None = None, + use_raw: bool | None = None, remove_duplicate_genes: bool = False, - iter_chunk_size: int | None = None, ): super().__init__() self.files_list = files_list @@ -508,175 +538,109 @@ def __init__( self.sort_genes = sort_genes self.randomize_order = randomize_order self.pad_zeros = pad_zeros + self.pad_token = pad_token self.gene_col_name = gene_col_name self.filter_to_vocab = filter_to_vocab - self.filter_outliers = filter_outliers - self.min_expressed_genes = min_expressed_genes - self.seed = seed - self.pad_token = pad_token self.clip_counts = clip_counts - self.inference = inference self.obs_keys = obs_keys self.use_raw = use_raw self.remove_duplicate_genes = remove_duplicate_genes - self.iter_chunk_size = iter_chunk_size self.gene_tokenizer = BatchGeneTokenizer(gene_vocab) if aux_vocab is not None: self.aux_tokenizer = BatchObsTokenizer(aux_vocab) - random.seed(self.seed) - - # Estimate total cells for progress bars (approximate; ignores filtering) - self._estimated_total_cells = 0 + # Open backed handles and build cumulative row offsets + self._handles: list[anndata.AnnData] = [] + self._gene_names_per_file: list[np.ndarray] = [] + self._filter_idx_per_file: list[list[int] | None] = [] + self._X_per_file: list = [] + self._n_rows: list[int] = [] for file in self.files_list: - try: - if isinstance(file, str): - file_path = file if self.data_dir is None else os.path.join(self.data_dir, file) - ad = anndata.read_h5ad(file_path, backed='r') - self._estimated_total_cells += int(getattr(ad, 'n_obs', 0)) - if hasattr(ad, 'file') and ad.file is not None: - ad.file.close() - elif isinstance(file, anndata.AnnData): - self._estimated_total_cells += int(file.n_obs) - except Exception: - # If estimation fails for a file, skip it (progress bar may be undercounted) - continue - - - - def _yield_cells_from_arrays(self, X: np.ndarray, obs, gene_names: np.ndarray, file_path: str | None): - """Yield BatchData objects one cell at a time from arrays, processing in chunks.""" - num_cells = X.shape[0] - chunk = self.iter_chunk_size - for start in range(0, num_cells, chunk): - end = min(start + chunk, num_cells) - x_chunk = X[start:end] - obs_chunk = obs.iloc[start:end] - gene_names_chunk = gene_names - - batch = process_batch( - x_chunk, - obs_chunk, - gene_names_chunk, - self.gene_tokenizer, - getattr(self, "aux_tokenizer", None), - self.sort_genes, - self.randomize_order, - self.max_len, - self.pad_zeros, - self.pad_token, - self.gene_vocab, - self.normalize_to_scale, - self.clip_counts, - self.aux_vocab, - ) - - # Attach obs as needed to match AnnDataset behavior - if self.obs_keys 
is not None: - obs_data = {} - if "all" in self.obs_keys: - cols = obs_chunk.columns - for col in cols: - obs_data[col] = np.array(obs_chunk[col].tolist())[:, None] - else: - for col in self.obs_keys: - obs_data[col] = np.array(obs_chunk[col].tolist())[:, None] - else: - obs_data = None - - # Yield one cell at a time so DataLoader can batch them - for i in range(end - start): - yield BatchData( - gene_counts=batch["gene_counts"][i], - gene_token_indices=batch["gene_token_indices"][i], - file_path=None, - aux_token_indices=( - batch.get("aux_token_indices")[i] if batch.get("aux_token_indices") is not None else None - ), - obs=( - {col: obs_data[col][i][None, :] for col in obs_data} - if obs_data is not None - else None - ), - ) - - def __iter__(self): - # Deterministic iteration - random.seed(self.seed) - - for idx, file in enumerate(self.files_list): - logging.info(f"Streaming file {idx + 1} of {len(self.files_list)}") - - if isinstance(file, str): - file_path = file - if self.data_dir is not None: - file_path = os.path.join(self.data_dir, file_path) - adata, success = load_data(file_path) - elif isinstance(file, anndata.AnnData): - adata = file - success = True - file_path = None - else: - raise ValueError(f"Invalid file type: {type(file)}") - - if not success: - logging.error(f"Failed to load data from {file_path}") - continue - + file_path = file if self.data_dir is None else os.path.join(self.data_dir, file) + adata = anndata.read_h5ad(file_path, backed="r") gene_names, success, adata = load_gene_features( - adata, self.gene_col_name, self.remove_duplicate_genes + adata, self.gene_col_name, self.remove_duplicate_genes, use_raw=self.use_raw ) if not success: - logging.error(f"Failed to load gene features from {file_path}") - continue - - X = get_counts_layer(adata, self.use_raw) - X = to_dense(X) - obs = adata.obs - - # Validate that gene dimension matches number of gene names - validate_gene_dimension(X, gene_names, self.gene_col_name) - - # Same checks as AnnDataset - logging.info("Checking if data is raw counts") - if not is_raw_counts(X): - logging.warning( - "Data does not appear to be raw counts. TranscriptFormer expects unnormalized count data. " - "If your data is normalized, consider using the original count matrix instead." 
+ raise ValueError(f"Failed to load gene features from {file_path}") + # Optional vocab filtering at token level + filter_idx = None + if self.filter_to_vocab: + original_gene_count = len(gene_names) + filter_idx = [i for i, name in enumerate(gene_names) if name in self.gene_vocab] + gene_names = gene_names[filter_idx] + logging.info( + f"Filtered {original_gene_count} genes to {len(gene_names)} genes in vocab for file {file_path}" ) + if len(gene_names) == 0: + raise ValueError(f"No genes remaining after filtering for file {file_path}") - logging.info("Applying filters") - X, obs, gene_names = apply_filters( - X, - obs, - gene_names, - file_path, - self.filter_to_vocab, - self.gene_vocab, - self.filter_outliers, - self.min_expressed_genes, - ) + self._handles.append(adata) + self._gene_names_per_file.append(gene_names) + self._filter_idx_per_file.append(filter_idx) + X_layer = get_counts_layer(adata, self.use_raw) + self._X_per_file.append(X_layer) + self._n_rows.append(int(adata.n_obs)) - if X is None: - logging.warning(f"Data was filtered out completely for {file_path}") - continue - - logging.info("Processing and yielding cells") - yield from self._yield_cells_from_arrays(X, obs, gene_names, file_path) + self._offsets = np.cumsum([0] + self._n_rows) def __len__(self) -> int: - """Return an estimated total number of samples for progress bars. - - Note: This is an upper bound prior to filtering; actual yielded samples - may be fewer. It enables progress bar display for IterableDataset. - - For distributed training, this returns the estimated length divided by - the number of processes to avoid the PyTorch Lightning warning. - """ - # Check if we're in a distributed environment - if torch.distributed.is_initialized(): - world_size = torch.distributed.get_world_size() - return int(self._estimated_total_cells // world_size) - else: - return int(self._estimated_total_cells) + return int(self._offsets[-1]) + + def _loc(self, idx: int) -> tuple[int, int]: + file_id = int(np.searchsorted(self._offsets, idx, side="right") - 1) + row = int(idx - self._offsets[file_id]) + return file_id, row + + def __getitem__(self, idx: int) -> BatchData: + file_id, row = self._loc(idx) + adata = self._handles[file_id] + gene_names = self._gene_names_per_file[file_id] + filter_idx = self._filter_idx_per_file[file_id] + + X = self._X_per_file[file_id] + # Some backed sparse implementations use __getitem__ returning 2D; ensure 1D + x_row = X[row] + x_row = x_row.toarray().ravel() if hasattr(x_row, "toarray") else np.asarray(x_row).ravel() + if filter_idx is not None: + x_row = x_row[filter_idx] + + obs_row = adata.obs.iloc[row : row + 1] + + # Build a 1-row batch and reuse existing processing pipeline + x_batch = np.expand_dims(x_row, axis=0) + batch = process_batch( + x_batch, + obs_row, + gene_names, + self.gene_tokenizer, + getattr(self, "aux_tokenizer", None), + self.sort_genes, + self.randomize_order, + self.max_len, + self.pad_zeros, + self.pad_token, + self.gene_vocab, + self.normalize_to_scale, + self.clip_counts, + self.aux_vocab, + ) + + # Convert to BatchData for collate_fn compatibility + obs_dict = None + if self.obs_keys is not None: + obs_dict = {} + cols = list(obs_row.columns) if "all" in self.obs_keys else list(self.obs_keys or []) + for col in cols: + obs_dict[col] = np.array(obs_row[col].tolist())[:, None] + + return BatchData( + gene_counts=batch["gene_counts"][0], + gene_token_indices=batch["gene_token_indices"][0], + file_path=None, + aux_token_indices=( + batch.get("aux_token_indices")[0] 
if batch.get("aux_token_indices") is not None else None + ), + obs=obs_dict, + ) diff --git a/src/transcriptformer/model/inference.py b/src/transcriptformer/model/inference.py index 1922984..edb89e3 100644 --- a/src/transcriptformer/model/inference.py +++ b/src/transcriptformer/model/inference.py @@ -9,7 +9,7 @@ from pytorch_lightning.loggers import CSVLogger from torch.utils.data import DataLoader -from transcriptformer.data.dataloader import AnnDataset, StreamingAnnDataset +from transcriptformer.data.dataloader import AnnDataset, AnnDatasetOOM from transcriptformer.model.embedding_surgery import change_embedding_layer from transcriptformer.tokenizer.vocab import load_vocabs_and_embeddings from transcriptformer.utils.utils import stack_dict @@ -38,15 +38,11 @@ def run_inference(cfg, data_files: list[str] | list[anndata.AnnData]): "ignore", message="The 'predict_dataloader' does not have many workers which may be a bottleneck" ) warnings.filterwarnings( - "ignore", + "ignore", message="Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading", - category=UserWarning - ) - warnings.filterwarnings( - "ignore", - message="Transforming to str index.", - category=UserWarning + category=UserWarning, ) + warnings.filterwarnings("ignore", message="Transforming to str index.", category=UserWarning) # Load vocabs and embeddings (gene_vocab, aux_vocab), emb_matrix = load_vocabs_and_embeddings(cfg) @@ -124,12 +120,24 @@ def run_inference(cfg, data_files: list[str] | list[anndata.AnnData]): "remove_duplicate_genes": cfg.model.data_config.remove_duplicate_genes, "use_raw": cfg.model.data_config.use_raw, } - if getattr(cfg.model.inference_config, "use_iterable_dataset", False): - # Optional chunk size - iter_kwargs = {} - if getattr(cfg.model.inference_config, "iterable_chunk_size", None) is not None: - iter_kwargs["iter_chunk_size"] = cfg.model.inference_config.iterable_chunk_size - dataset = StreamingAnnDataset(data_files, **data_kwargs, **iter_kwargs) + if getattr(cfg.model.inference_config, "use_oom_dataloader", False): + # Use OOM-safe map-style dataset + dataset = AnnDatasetOOM( + data_files, + gene_vocab, + aux_vocab=aux_vocab, + max_len=cfg.model.model_config.seq_len, + normalize_to_scale=cfg.model.data_config.normalize_to_scale, + sort_genes=cfg.model.data_config.sort_genes, + randomize_order=cfg.model.data_config.randomize_genes, + pad_zeros=cfg.model.data_config.pad_zeros, + gene_col_name=cfg.model.data_config.gene_col_name, + filter_to_vocab=cfg.model.data_config.filter_to_vocabs, + clip_counts=cfg.model.data_config.clip_counts, + obs_keys=cfg.model.inference_config.obs_keys, + use_raw=cfg.model.data_config.use_raw, + remove_duplicate_genes=cfg.model.data_config.remove_duplicate_genes, + ) else: dataset = AnnDataset(data_files, **data_kwargs) @@ -144,8 +152,8 @@ def run_inference(cfg, data_files: list[str] | list[anndata.AnnData]): ) # Determine number of GPUs to use - num_gpus = getattr(cfg.model.inference_config, 'num_gpus', 1) - + num_gpus = getattr(cfg.model.inference_config, "num_gpus", 1) + # Handle special cases for num_gpus if num_gpus == -1: # Use all available GPUs @@ -155,7 +163,9 @@ def run_inference(cfg, data_files: list[str] | list[anndata.AnnData]): # Use specified number of GPUs available_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 if available_gpus < num_gpus: - logging.warning(f"Requested {num_gpus} GPUs but only {available_gpus} available. 
Using {available_gpus} GPUs.") + logging.warning( + f"Requested {num_gpus} GPUs but only {available_gpus} available. Using {available_gpus} GPUs." + ) devices = available_gpus if available_gpus > 0 else 1 accelerator = "gpu" if available_gpus > 0 else "cpu" else: @@ -165,7 +175,7 @@ def run_inference(cfg, data_files: list[str] | list[anndata.AnnData]): # Use single GPU or CPU devices = 1 accelerator = "gpu" if torch.cuda.is_available() else "cpu" - + logging.info(f"Using {devices} device(s) with accelerator: {accelerator}") # Create Trainer diff --git a/test/test_compare_emb.py b/test/test_compare_emb.py index 176fe70..a70c345 100644 --- a/test/test_compare_emb.py +++ b/test/test_compare_emb.py @@ -68,7 +68,7 @@ def compare_embeddings(file1, file2, tolerance=1e-5): # Using paired t-test to test null hypothesis that embeddings are the same diff = emb1_flat - emb2_flat t_stat, p_value = ttest_rel(emb1_flat, emb2_flat) - + # Calculate z-score manually from the differences z_score = np.mean(diff) / (np.std(diff) / np.sqrt(len(diff))) diff --git a/test/test_compare_umap.py b/test/test_compare_umap.py index 1a561a9..949c466 100644 --- a/test/test_compare_umap.py +++ b/test/test_compare_umap.py @@ -18,7 +18,13 @@ def _require_umap(): return umap -def compute_umap(embeddings: np.ndarray, n_neighbors: int = 15, min_dist: float = 0.1, metric: str = "euclidean", random_state: int | None = 42) -> np.ndarray: +def compute_umap( + embeddings: np.ndarray, + n_neighbors: int = 15, + min_dist: float = 0.1, + metric: str = "euclidean", + random_state: int | None = 42, +) -> np.ndarray: """ Compute a 2D UMAP from input embeddings. @@ -80,17 +86,13 @@ def compare_umaps( missing.append(f"'embeddings' not found in {file1}") if "embeddings" not in adata2.obsm: missing.append(f"'embeddings' not found in {file2}") - raise ValueError( - "; ".join(missing) - ) + raise ValueError("; ".join(missing)) emb1 = np.asarray(adata1.obsm["embeddings"]) # (n, d) emb2 = np.asarray(adata2.obsm["embeddings"]) # (n, d) if emb1.shape[0] != emb2.shape[0]: - raise ValueError( - f"Number of rows differ between embeddings: {emb1.shape[0]} vs {emb2.shape[0]}" - ) + raise ValueError(f"Number of rows differ between embeddings: {emb1.shape[0]} vs {emb2.shape[0]}") # Compute UMAPs independently umap1 = compute_umap(emb1, n_neighbors=n_neighbors, min_dist=min_dist, metric=metric, random_state=random_state) @@ -120,7 +122,7 @@ def compare_umaps( series2 = adata2.obs[obs_key] labels1 = series1.astype(str).to_numpy() labels2 = series2.astype(str).to_numpy() - categories = sorted(list(set(labels1).union(set(labels2)))) + categories = sorted(set(labels1).union(set(labels2))) import matplotlib as mpl @@ -196,7 +198,7 @@ def main(): parser.add_argument( "--save-plot", type=str, - default='./umap_comparison.png', + default="./umap_comparison.png", help="Optional path to save a side-by-side aligned UMAP comparison plot (PNG)", ) parser.add_argument( @@ -236,5 +238,3 @@ def main(): if __name__ == "__main__": main() - - From 54daed883818141b8e1e926b5eaf4be267d7d03a Mon Sep 17 00:00:00 2001 From: jpearce Date: Fri, 15 Aug 2025 16:08:14 -0700 Subject: [PATCH 06/11] rm unit test --- test/test_cli.py | 46 ---------------------------------------------- 1 file changed, 46 deletions(-) diff --git a/test/test_cli.py b/test/test_cli.py index b496d74..04c7113 100644 --- a/test/test_cli.py +++ b/test/test_cli.py @@ -63,52 +63,6 @@ def test_inference_command(self, mock_run_inference, monkeypatch): main() mock_run_inference.assert_called_once() - 
@mock.patch("transcriptformer.cli.inference.main") - def test_run_inference_cli(self, mock_inference_main, monkeypatch): - """Test run_inference_cli function properly calls inference.main.""" - args = mock.MagicMock() - args.checkpoint_path = "/path/to/checkpoint" - args.data_file = "/path/to/data.h5ad" - args.output_path = "./inference_results" - args.output_filename = "embeddings.h5ad" - args.batch_size = 8 - args.gene_col_name = "ensembl_id" - args.precision = "16-mixed" - args.pretrained_embedding = None - args.config_override = [] - args.model_type = "transcriptformer" - args.emb_type = "cell" - - # Test that the function properly sets up Hydra config - original_argv = sys.argv.copy() - run_inference_cli(args) - mock_inference_main.assert_called_once() - # Check sys.argv was restored - assert sys.argv == original_argv - - @mock.patch("transcriptformer.cli.inference.main") - def test_run_inference_cli_with_cge(self, mock_inference_main, monkeypatch): - """Test run_inference_cli function with CGE embedding type.""" - args = mock.MagicMock() - args.checkpoint_path = "/path/to/checkpoint" - args.data_file = "/path/to/data.h5ad" - args.output_path = "./inference_results" - args.output_filename = "embeddings.h5ad" - args.batch_size = 8 - args.gene_col_name = "ensembl_id" - args.precision = "16-mixed" - args.pretrained_embedding = None - args.config_override = [] - args.model_type = "transcriptformer" - args.emb_type = "cge" - - # Test that the function properly sets up Hydra config - original_argv = sys.argv.copy() - run_inference_cli(args) - mock_inference_main.assert_called_once() - # Check sys.argv was restored - assert sys.argv == original_argv - class TestDownloadCommand: """Tests for the download command.""" From ba282417b4a4c930183485cadedcb9dd0754c842 Mon Sep 17 00:00:00 2001 From: James Pearce <57334682+jdpearce4@users.noreply.github.com> Date: Tue, 19 Aug 2025 11:26:13 -0700 Subject: [PATCH 07/11] Update test/test_compare_emb.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- test/test_compare_emb.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_compare_emb.py b/test/test_compare_emb.py index a70c345..0247b73 100644 --- a/test/test_compare_emb.py +++ b/test/test_compare_emb.py @@ -36,7 +36,10 @@ def compare_embeddings(file1, file2, tolerance=1e-5): if "embedding" not in adata1.obsm: missing.append(f"'embedding' not found in {file1}") if "embedding" not in adata2.obsm: - missing.append(f"'embedding' not found in {file2}") + if "embeddings" not in adata1.obsm: + missing.append(f"'embeddings' not found in {file1}") + if "embeddings" not in adata2.obsm: + missing.append(f"'embeddings' not found in {file2}") print(f"Error: {', '.join(missing)}") return False From f9441dc060522b9641a4ca33f60429ed9289aaf1 Mon Sep 17 00:00:00 2001 From: James Pearce <57334682+jdpearce4@users.noreply.github.com> Date: Tue, 19 Aug 2025 11:26:55 -0700 Subject: [PATCH 08/11] Update src/transcriptformer/data/dataloader.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/transcriptformer/data/dataloader.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transcriptformer/data/dataloader.py b/src/transcriptformer/data/dataloader.py index 860cb8c..e5f9373 100644 --- a/src/transcriptformer/data/dataloader.py +++ b/src/transcriptformer/data/dataloader.py @@ -602,7 +602,11 @@ def __getitem__(self, idx: int) -> BatchData: X = self._X_per_file[file_id] # Some backed sparse implementations use __getitem__ 
returning 2D; ensure 1D x_row = X[row] - x_row = x_row.toarray().ravel() if hasattr(x_row, "toarray") else np.asarray(x_row).ravel() + # Only convert to dense if the row is actually sparse + if isinstance(x_row, (csr_matrix, csc_matrix)): + x_row = x_row.toarray().ravel() + else: + x_row = np.asarray(x_row).ravel() if filter_idx is not None: x_row = x_row[filter_idx] From d2ee925750fcf9f1dd89cfdd0fc0f79971d5bc24 Mon Sep 17 00:00:00 2001 From: James Pearce <57334682+jdpearce4@users.noreply.github.com> Date: Tue, 19 Aug 2025 11:28:04 -0700 Subject: [PATCH 09/11] Update src/transcriptformer/data/dataloader.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/transcriptformer/data/dataloader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transcriptformer/data/dataloader.py b/src/transcriptformer/data/dataloader.py index e5f9373..7d80dc5 100644 --- a/src/transcriptformer/data/dataloader.py +++ b/src/transcriptformer/data/dataloader.py @@ -234,7 +234,9 @@ def load_gene_features( try: # Select the appropriate var depending on which matrix will be used using_raw = bool(use_raw is True or (use_raw is None and getattr(adata, "raw", None) is not None)) - var_df = adata.raw.var if using_raw and getattr(adata, "raw", None) is not None else adata.var + has_raw = getattr(adata, "raw", None) is not None + using_raw = bool(use_raw is True or (use_raw is None and has_raw)) + var_df = adata.raw.var if using_raw and has_raw else adata.var # Prefer requested column; otherwise use index which aligns with matrix columns for that layer if gene_col_name in var_df.columns: From a0786ff1ecee610f0f6780a9c9d534150693ac6f Mon Sep 17 00:00:00 2001 From: James Pearce <57334682+jdpearce4@users.noreply.github.com> Date: Tue, 19 Aug 2025 11:28:34 -0700 Subject: [PATCH 10/11] Update src/transcriptformer/cli/__init__.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/transcriptformer/cli/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transcriptformer/cli/__init__.py b/src/transcriptformer/cli/__init__.py index 3d7ce15..f3fefdc 100644 --- a/src/transcriptformer/cli/__init__.py +++ b/src/transcriptformer/cli/__init__.py @@ -352,7 +352,10 @@ def run_inference_cli(args): # Check if we're in a distributed environment if is_distributed: - rank = torch.distributed.get_rank() + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + else: + rank = 0 # Split the filename and add rank before extension rank_file = save_file.replace(".h5ad", f"_{rank}.h5ad") From a7a21fd343d2ed939db9f995d3d03d666f58dd13 Mon Sep 17 00:00:00 2001 From: jpearce Date: Tue, 19 Aug 2025 11:49:43 -0700 Subject: [PATCH 11/11] formatting --- src/transcriptformer/data/dataloader.py | 2 +- test/test_cli.py | 1 - test/test_compare_emb.py | 3 --- 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/transcriptformer/data/dataloader.py b/src/transcriptformer/data/dataloader.py index 7d80dc5..a436be4 100644 --- a/src/transcriptformer/data/dataloader.py +++ b/src/transcriptformer/data/dataloader.py @@ -605,7 +605,7 @@ def __getitem__(self, idx: int) -> BatchData: # Some backed sparse implementations use __getitem__ returning 2D; ensure 1D x_row = X[row] # Only convert to dense if the row is actually sparse - if isinstance(x_row, (csr_matrix, csc_matrix)): + if isinstance(x_row, csr_matrix | csc_matrix): x_row = x_row.toarray().ravel() else: x_row = np.asarray(x_row).ravel() diff --git a/test/test_cli.py 
b/test/test_cli.py index 04c7113..428f73b 100644 --- a/test/test_cli.py +++ b/test/test_cli.py @@ -8,7 +8,6 @@ from transcriptformer.cli import ( main, run_download_cli, - run_inference_cli, setup_download_parser, setup_inference_parser, ) diff --git a/test/test_compare_emb.py b/test/test_compare_emb.py index 0247b73..5410d07 100644 --- a/test/test_compare_emb.py +++ b/test/test_compare_emb.py @@ -33,9 +33,6 @@ def compare_embeddings(file1, file2, tolerance=1e-5): # Check if embeddings exist if "embeddings" not in adata1.obsm or "embeddings" not in adata2.obsm: missing = [] - if "embedding" not in adata1.obsm: - missing.append(f"'embedding' not found in {file1}") - if "embedding" not in adata2.obsm: if "embeddings" not in adata1.obsm: missing.append(f"'embeddings' not found in {file1}") if "embeddings" not in adata2.obsm:
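
A note on multi-GPU output: in distributed mode each rank writes its own shard (`embeddings_0.h5ad`, `embeddings_1.h5ad`, ...) rather than a single file. Below is a minimal post-processing sketch, assuming the default `./inference_results` output directory (a hypothetical helper, not part of this patch series); because `DistributedSampler` interleaves cells across ranks, the concatenated row order will generally differ from the input file:

```python
import glob

import anndata as ad

# Gather the per-rank shards written by `transcriptformer inference --num-gpus N`
parts = sorted(glob.glob("./inference_results/embeddings_*.h5ad"))
merged = ad.concat([ad.read_h5ad(p) for p in parts], join="outer")
merged.write_h5ad("./inference_results/embeddings_merged.h5ad")
```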