feat: support for multi-gpu inference and OOM datasets #53

Merged
jdpearce4 merged 12 commits into main from jpearce-oom-datasets
Aug 20, 2025

@jdpearce4
Collaborator

Summary

This PR enables scalable inference across multiple GPUs and adds an out-of-memory (OOM) safe map-style dataloader for very large .h5ad files.

Key Changes

  • Multi-GPU inference (DDP)

    • Trainer now uses devices and accelerator=gpu based on inference_config.num_gpus.
    • DistributedSampler is used with the new map-style dataset when devices > 1.
  • OOM-safe map-style dataloader

    • New AnnDatasetOOM:
      • Opens .h5ad in backed='r' mode.
      • Applies gene filtering once per file; caches layer handle; densifies per-row in __getitem__.
      • Compatible with DistributedSampler (order-safe with multiple workers).
    • Densification is moved to the row level to avoid materializing large matrices.
  • CLI and config

    • New flags:
      • --oom-dataloader: enable OOM-safe map-style dataloader.
      • --n-data-workers: number of DataLoader workers per process.
    • Added use_oom_dataloader to InferenceConfig and inference_config.yaml.
    • Wiring for n_data_workers via CLI → DataConfig.
  • Sparse-aware data handling

    • get_counts_layer safely selects raw.X or X with clear logging.
    • is_raw_counts supports sparse inputs via sampling of non-zero data.
    • AnnDataset explicitly densifies before batch processing to satisfy downstream expectations.
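
The per-row densification described for AnnDatasetOOM can be sketched as follows. This is a minimal illustration, not the code in this PR: the class name `RowDensifyingDataset` and its constructor arguments are hypothetical, and in-memory NumPy arrays in CSR layout (`data`, `indices`, `indptr`) stand in for the layer handle of a file opened with backed='r'. Only `__len__` and `__getitem__` are defined, which is the map-style protocol that a DataLoader and DistributedSampler rely on.

```python
import numpy as np


class RowDensifyingDataset:
    """Map-style dataset over CSR arrays that densifies one row per lookup.

    Hypothetical stand-in for AnnDatasetOOM: the CSR arrays play the role of
    the cached layer handle, which in a backed file would stay on disk.
    """

    def __init__(self, data, indices, indptr, n_genes, keep_genes=None):
        self.data = data
        self.indices = indices
        self.indptr = indptr
        # Gene filtering is resolved once, up front, into a column remapping:
        # remap[j] is the output position of gene j, or -1 if gene j is dropped.
        keep = np.arange(n_genes) if keep_genes is None else np.asarray(keep_genes)
        self.remap = np.full(n_genes, -1, dtype=np.int64)
        self.remap[keep] = np.arange(len(keep))
        self.n_out = len(keep)

    def __len__(self):
        # One CSR row per cell.
        return len(self.indptr) - 1

    def __getitem__(self, i):
        # Read only the non-zero entries of row i, then scatter them into a
        # dense vector; the full matrix is never materialized.
        start, end = self.indptr[i], self.indptr[i + 1]
        cols = self.remap[self.indices[start:end]]
        vals = self.data[start:end]
        kept = cols >= 0
        row = np.zeros(self.n_out, dtype=np.float32)
        row[cols[kept]] = vals[kept]
        return row


# Toy matrix with 3 cells and 4 genes:
# [[0, 2, 0, 1],
#  [0, 0, 0, 0],
#  [5, 0, 3, 0]]
data = np.array([2, 1, 5, 3], dtype=np.float32)
indices = np.array([1, 3, 0, 2])
indptr = np.array([0, 2, 2, 4])

ds = RowDensifyingDataset(data, indices, indptr, n_genes=4, keep_genes=[0, 1, 3])
first_row = ds[0]  # dense vector over the three kept genes
```

Because each lookup touches only one row's non-zero entries, peak memory scales with batch size rather than with the size of the file, which is the property the bullets above attribute to the new dataset.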

Testing

  • Verified multi-GPU prediction with 8 GPUs on a large .h5ad.
  • Confirmed memory stability and no full-matrix densification.
  • Pre-commit lint passes.

Documentation

  • Updated README.md with:
    • New flags (--oom-dataloader, --n-data-workers)
    • OOM-safe usage example
    • Removal of legacy streaming options
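
As a rough illustration of how the two documented flags could be wired into configuration objects, here is a small argparse sketch. It is not the CLI code from this PR: the dataclasses below are simplified, hypothetical versions of InferenceConfig and DataConfig that carry only the fields named in this description, and the default values are assumptions.

```python
import argparse
from dataclasses import dataclass


@dataclass
class DataConfig:
    # Hypothetical, reduced version: only the field mentioned in this PR.
    n_data_workers: int = 0


@dataclass
class InferenceConfig:
    # Hypothetical, reduced version: only the field mentioned in this PR.
    use_oom_dataloader: bool = False


def build_parser():
    parser = argparse.ArgumentParser(description="inference flag sketch")
    parser.add_argument(
        "--oom-dataloader",
        action="store_true",
        help="enable the OOM-safe map-style dataloader",
    )
    parser.add_argument(
        "--n-data-workers",
        type=int,
        default=0,  # assumed default, not taken from the PR
        help="number of DataLoader workers per process",
    )
    return parser


def configs_from_args(argv):
    """Parse CLI arguments and copy them into the two config objects."""
    args = build_parser().parse_args(argv)
    inference_cfg = InferenceConfig(use_oom_dataloader=args.oom_dataloader)
    data_cfg = DataConfig(n_data_workers=args.n_data_workers)
    return inference_cfg, data_cfg


inference_cfg, data_cfg = configs_from_args(["--oom-dataloader", "--n-data-workers", "4"])
```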

@jdpearce4 jdpearce4 changed the title Support for multi-gpu inference and OOM datasets feat: support for multi-gpu inference and OOM datasets Aug 15, 2025
@jdpearce4 jdpearce4 requested a review from Copilot August 19, 2025 17:44
Contributor

Copilot AI left a comment

Pull Request Overview

This PR enables scalable inference across multiple GPUs and adds an out-of-memory (OOM) safe map-style dataloader for very large .h5ad files. This addresses memory bottlenecks and processing speed limitations when working with large single-cell datasets.

  • Multi-GPU inference support using DistributedDataParallel (DDP) with configurable GPU allocation
  • OOM-safe map-style dataset that uses backed reads and per-row densification to handle large files
  • New CLI flags for enabling OOM dataloader, specifying data workers, and controlling GPU usage
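
To make the multi-GPU data split concrete, the sketch below reproduces in plain Python how `torch.utils.data.DistributedSampler` assigns indices when shuffling is off and `drop_last=False`: the index list is padded by wrapping around to the start until it divides evenly, then each rank takes every `world_size`-th index. The helper names `shard_indices` and `restore_order` are hypothetical, and the PR does not describe how per-rank predictions are merged, so the reassembly step is an assumption about one reasonable approach.

```python
import math


def shard_indices(n_samples, world_size, rank):
    """Indices one rank would see (no shuffling, drop_last=False)."""
    per_rank = math.ceil(n_samples / world_size)
    total = per_rank * world_size
    indices = list(range(n_samples))
    pad = total - n_samples
    if pad > 0:
        # Pad by repeating indices from the start so every rank gets the
        # same number of samples; these repeats are duplicates.
        repeats = math.ceil(pad / len(indices))
        indices += (indices * repeats)[:pad]
    return indices[rank:total:world_size]


def restore_order(shards, n_samples):
    """Interleave equal-length per-rank outputs and drop the padded tail."""
    interleaved = [item for group in zip(*shards) for item in group]
    return interleaved[:n_samples]


world_size = 4
shards = [shard_indices(10, world_size, rank) for rank in range(world_size)]
merged = restore_order(shards, 10)
```

With 10 samples on 4 ranks, two indices are duplicated by the padding, so any code that gathers predictions across ranks needs to interleave and truncate (or otherwise deduplicate) to recover one output per cell in the original order.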

Reviewed Changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| test/test_compare_umap.py | New test utility for comparing UMAPs between embeddings with Procrustes analysis |
| test/test_compare_emb.py | Updated embedding comparison to use "embeddings" key and added statistical tests |
| test/test_cli.py | Removed deprecated CLI test functions |
| src/transcriptformer/model/inference.py | Added multi-GPU support and OOM dataloader integration |
| src/transcriptformer/data/dataloader.py | Major refactoring with new AnnDatasetOOM class and extracted utility functions |
| src/transcriptformer/data/dataclasses.py | Updated InferenceConfig with num_gpus and use_oom_dataloader fields |
| src/transcriptformer/cli/conf/inference_config.yaml | Added new configuration options for multi-GPU and OOM handling |
| src/transcriptformer/cli/__init__.py | Complete CLI rewrite with direct inference execution instead of Hydra delegation |
| inference.py | Removed legacy inference script |
| download_artifacts.py | Removed legacy download script |
| README.md | Updated documentation with new CLI flags and usage examples |

jdpearce4 and others added 5 commits August 19, 2025 11:26
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Collaborator

@SESDNA SESDNA left a comment

LGTM!

@jdpearce4 jdpearce4 merged commit 1d0012c into main Aug 20, 2025
6 checks passed
@jdpearce4 jdpearce4 deleted the jpearce-oom-datasets branch August 20, 2025 16:47

2 participants