Skip to content
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ See the `pyproject.toml` file for the complete list of dependencies.
### Hardware Requirements
- GPU (A100 40GB recommended) for efficient inference and embedding extraction.
- Can also use a GPU with a lower amount of VRAM (16GB) by setting the inference batch size to 1-4.
- **Multi-GPU support**: For faster inference on large datasets, use multiple GPUs with the `--num-gpus` parameter.
- Recommended for datasets with >100k cells
- Scales batch processing across available GPUs using Distributed Data Parallel (DDP)
- Best performance with matched GPU types and sufficient inter-GPU bandwidth


## Using the TranscriptFormer CLI
Expand Down Expand Up @@ -234,6 +238,13 @@ transcriptformer inference \
--data-file test/data/human_val.h5ad \
--emb-type cge \
--batch-size 8

# Multi-GPU inference using 4 GPUs (-1 will use all available on the system)
transcriptformer inference \
--checkpoint-path ./checkpoints/tf_sapiens \
--data-file test/data/human_val.h5ad \
--num-gpus 4 \
--batch-size 32
```

You can also use the CLI it run inference on the ESM2-CE baseline model discussed in the paper:
Expand Down Expand Up @@ -281,6 +292,9 @@ transcriptformer download-data --help
- `--embedding-layer-index INT`: Index of the transformer layer to extract embeddings from (-1 for last layer, default: -1). Use with `transcriptformer` model type.
- `--model-type {transcriptformer,esm2ce}`: Type of model to use (default: `transcriptformer`). Use `esm2ce` to extract raw ESM2-CE gene embeddings.
- `--emb-type {cell,cge}`: Type of embeddings to extract (default: `cell`). Use `cell` for mean-pooled cell embeddings or `cge` for contextual gene embeddings.
- `--num-gpus INT`: Number of GPUs to use for inference (default: 1). Use -1 for all available GPUs, or specify a specific number.
- `--oom-dataloader`: Use the OOM-safe map-style DataLoader (uses backed reads and per-item densification; DistributedSampler-friendly).
- `--n-data-workers INT`: Number of DataLoader workers per process (default: 0). Order is preserved with the map-style dataset and DistributedSampler.
- `--config-override key.path=value`: Override any configuration value directly.

### Input Data Format and Preprocessing:
Expand All @@ -301,6 +315,18 @@ Input data files should be in H5AD format (AnnData objects) with the following r
- `True`: Use only `adata.raw.X`
- `False`: Use only `adata.X`

- **OOM-safe Data Loading**:
- To reduce peak memory usage on large datasets, enable the OOM-safe dataloader:
```bash
transcriptformer inference \
--checkpoint-path ./checkpoints/tf_sapiens \
--data-file ./data/huge.h5ad \
--oom-dataloader \
--n-data-workers 4 \
--num-gpus 8
```
- This uses a map-style dataset with backed reads and per-row densification. It is compatible with `DistributedSampler`, so multiple workers are safe and ordering is preserved.

- **Count Processing**:
- Count values are clipped at 30 by default (as was done in training)
- If this seems too low, you can either:
Expand Down
163 changes: 0 additions & 163 deletions download_artifacts.py

This file was deleted.

63 changes: 0 additions & 63 deletions inference.py

This file was deleted.

Loading