diff --git a/README.md b/README.md
index d173dfc..c0f837b 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
- Overview of TranscriptFormer pretraining data, model, outputs and downstream tasks.
+ Overview of TranscriptFormer pretraining data (A), model (B), outputs (C) and downstream tasks (D).
@@ -112,6 +112,49 @@ The command will download and extract the following files to the `./checkpoints`
- `./checkpoints/tf_metazoa/`: Metazoa model weights
- `./checkpoints/all_embeddings/`: Embedding files for out-of-distribution species
+#### Available Protein Embeddings
+
+The following protein embeddings are available for download with `transcriptformer download all-embeddings`:
+
+| Scientific Name | Common Name | TF-Metazoa | TF-Exemplar | TF-Sapiens | Notes |
+|-----------------|-------------|------------|-------------|------------|-------|
+| *Homo sapiens* | Human | ✓ | ✓ | ✓ | Primary training species |
+| *Mus musculus* | Mouse | ✓ | ✓ | - | Model organism |
+| *Danio rerio* | Zebrafish | ✓ | ✓ | - | Model organism |
+| *Drosophila melanogaster* | Fruit fly | ✓ | ✓ | - | Model organism |
+| *Caenorhabditis elegans* | C. elegans | ✓ | ✓ | - | Model organism |
+| *Oryctolagus cuniculus* | Rabbit | ✓ | - | - | Vertebrate |
+| *Gallus gallus* | Chicken | ✓ | - | - | Vertebrate |
+| *Xenopus laevis* | African clawed frog | ✓ | - | - | Vertebrate |
+| *Lytechinus variegatus* | Sea urchin | ✓ | - | - | Invertebrate |
+| *Spongilla lacustris* | Freshwater sponge | ✓ | - | - | Invertebrate |
+| *Saccharomyces cerevisiae* | Yeast | ✓ | - | - | Fungus |
+| *Plasmodium falciparum* | Malaria parasite | ✓ | - | - | Protist |
+| *Rattus norvegicus* | Rat | - | - | - | Out-of-distribution |
+| *Sus scrofa* | Pig | - | - | - | Out-of-distribution |
+| *Pan troglodytes* | Chimpanzee | - | - | - | Out-of-distribution |
+| *Gorilla gorilla* | Gorilla | - | - | - | Out-of-distribution |
+| *Macaca mulatta* | Rhesus macaque | - | - | - | Out-of-distribution |
+| *Callithrix jacchus* | Marmoset | - | - | - | Out-of-distribution |
+| *Xenopus tropicalis* | Western clawed frog | - | - | - | Out-of-distribution |
+| *Ornithorhynchus anatinus* | Platypus | - | - | - | Out-of-distribution |
+| *Monodelphis domestica* | Opossum | - | - | - | Out-of-distribution |
+| *Heterocephalus glaber* | Naked mole-rat | - | - | - | Out-of-distribution |
+| *Petromyzon marinus* | Sea lamprey | - | - | - | Out-of-distribution |
+| *Stylophora pistillata* | Coral | - | - | - | Out-of-distribution |
+
+**Legend:**
+- ✓ = Species included in the corresponding model's training data
+- \- = Species not included in that model's training data (out-of-distribution)
+
+### Generating Protein Embeddings for New Species
+
+The pre-generated embeddings cover the most commonly used species. If you need to work with a species not included in the downloaded embeddings, you can generate protein embeddings using the ESM-2 models.
+
+**Note**: This is only necessary for new species that don't have pre-generated embeddings available for download.
+
+For detailed instructions on generating protein embeddings for additional species, see the [preprocess/README.md](preprocess/README.md) documentation.
+
### Downloading Training Datasets
Use the CLI to download single-cell RNA sequencing datasets from the CellxGene Discover portal:
diff --git a/preprocess/README.md b/preprocess/README.md
new file mode 100644
index 0000000..dd19de4
--- /dev/null
+++ b/preprocess/README.md
@@ -0,0 +1,194 @@
+# Protein Embeddings Generation
+
+This directory contains scripts for generating protein embeddings using Facebook's ESM-2 (Evolutionary Scale Modeling) models. The pipeline downloads protein sequences from Ensembl, processes them with pre-trained ESM-2 models, and outputs gene-level embeddings suitable as input to TranscriptFormer.
+
+## Overview
+
+The protein embedding pipeline consists of three main components:
+
+1. **`protein_embedding.py`** - Main script for generating protein embeddings using ESM-2 models
+2. **`get_stable_id_mapping.py`** - Utility functions for mapping between gene, transcript, and protein stable IDs
+3. **`fasta_manifest_pep.json`** - Configuration file containing download URLs for protein FASTA files from Ensembl
+
+## Installation
+
+### Using pip (traditional)
+
+Install the required dependencies:
+
+```bash
+pip install -r requirements.txt
+pip install fair-esm
+```
+
+For GPU acceleration (recommended):
+```bash
+# For CUDA-enabled PyTorch
+pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
+```
+
+### Using uv (recommended)
+
+[uv](https://github.com/astral-sh/uv) is a fast Python package installer and resolver:
+
+```bash
+# Install uv if you haven't already
+curl -LsSf https://astral.sh/uv/install.sh | sh
+
+# Create and activate virtual environment
+uv venv protein-embeddings
+source protein-embeddings/bin/activate # On Windows: protein-embeddings\Scripts\activate
+
+# Install dependencies
+uv pip install -r requirements.txt
+uv pip install fair-esm
+
+# For GPU acceleration
+uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
+```
+
+### System Requirements
+
+- **Memory**: At least 16GB RAM (32GB+ recommended for large models)
+- **GPU**: NVIDIA GPU with 8GB+ VRAM (optional but highly recommended)
+- **Storage**: Several GB for downloaded FASTA files and generated embeddings
+- **Network**: Internet connection for downloading protein sequences from Ensembl
+
+## Usage
+
+### Basic Usage
+
+Generate protein embeddings for a single species:
+
+```bash
+python protein_embedding.py --organism_key homo_sapiens
+```
+
+### Advanced Usage
+
+```bash
+python protein_embedding.py \
+ --organism_key mus_musculus \
+ --output_dir /path/to/output \
+ --batch_size 32 \
+  --use_large_model
+```
+
+### Command Line Arguments
+
+- `--organism_key`: Species to process (see [Supported Species](#supported-species))
+- `--output_dir`: Directory to save embeddings (default: current directory `./`)
+- `--batch_size`: Batch size for processing (default: 16)
+- `--use_large_model`: Flag; use the ESM-2 15B parameter model instead of the 3B model (off by default)
+
+
+## Supported Species
+
+The pipeline supports the following species (protein sequences from Ensembl releases 110 and 113):
+
+| Species | Organism Key | Common Name |
+|---------|-------------|-------------|
+| Homo sapiens | `homo_sapiens` | Human |
+| Mus musculus | `mus_musculus` | Mouse |
+| Rattus norvegicus | `rattus_norvegicus` | Rat |
+| Sus scrofa | `sus_scrofa` | Pig |
+| Oryctolagus cuniculus | `oryctolagus_cuniculus` | Rabbit |
+| Macaca mulatta | `macaca_mulatta` | Rhesus macaque |
+| Pan troglodytes | `pan_troglodytes` | Chimpanzee |
+| Gorilla gorilla | `gorilla_gorilla` | Gorilla |
+| Callithrix jacchus | `callithrix_jacchus` | Marmoset |
+| Microcebus murinus | `microcebus_murinus` | Mouse lemur |
+| Gallus gallus | `gallus_gallus` | Chicken |
+| Danio rerio | `danio_rerio` | Zebrafish |
+| Xenopus tropicalis | `xenopus_tropicalis` | Western clawed frog |
+| Drosophila melanogaster | `drosophila_melanogaster` | Fruit fly |
+| Petromyzon marinus | `petromyzon_marinus` | Sea lamprey |
+| Ornithorhynchus anatinus | `ornithorhynchus_anatinus` | Platypus |
+| Monodelphis domestica | `monodelphis_domestica` | Opossum |
+| Heterocephalus glaber | `heterocephalus_glaber` | Naked mole-rat |
+| Stylophora pistillata | `stylophora_pistillata` | Coral |
+
+## Adding New Species
+
+To add support for a new species, you need to update the `fasta_manifest_pep.json` file with the appropriate Ensembl download URLs.
+
+### Step 1: Find Ensembl URLs
+
+1. Visit the [Ensembl FTP site](https://ftp.ensembl.org/pub/) or [Ensembl Genomes](https://ftp.ebi.ac.uk/ensemblgenomes/pub/) for non-vertebrates
+2. Navigate to the latest release (e.g., `release-113/`)
+3. Find your species under `fasta/{species_name}/pep/`
+4. Copy the URL for the `.pep.all.fa.gz` file
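
For species hosted on the main Ensembl FTP site, the peptide FASTA URL follows a predictable pattern. A small helper (hypothetical, not part of the pipeline) can assemble it from the release, organism key, and assembly name:

```python
def pep_fasta_url(release: int, species: str, assembly: str) -> str:
    """Build the Ensembl peptide FASTA URL for a species (main Ensembl layout only)."""
    return (
        f"https://ftp.ensembl.org/pub/release-{release}/fasta/"
        f"{species}/pep/{species.capitalize()}.{assembly}.pep.all.fa.gz"
    )

# Matches the manifest entry for human (release 110, GRCh38)
print(pep_fasta_url(110, "homo_sapiens", "GRCh38"))
```

Ensembl Genomes species (e.g. *Stylophora pistillata*) use a different directory layout and species naming, so always verify the constructed URL in a browser before adding it to the manifest.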
+
+### Step 2: Update the Manifest
+
+Add an entry to `fasta_manifest_pep.json`:
+
+```json
+{
+ "new_species_name": {
+ "fa": "https://ftp.ensembl.org/pub/release-113/fasta/new_species/pep/New_species.Assembly.pep.all.fa.gz"
+ }
+}
+```
+
+### Step 3: Generate Embeddings
+
+```bash
+python protein_embedding.py --organism_key new_species_name
+```
+
+### Example: Adding Sheep (Ovis aries)
+
+```json
+{
+ "ovis_aries": {
+ "fa": "https://ftp.ensembl.org/pub/release-113/fasta/ovis_aries/pep/Ovis_aries_rambouillet.ARS-UI_Ramb_v2.0.pep.all.fa.gz"
+ }
+}
+```
+
+### Notes
+
+- Use lowercase with underscores for organism keys (e.g., `ovis_aries`)
+- Ensure the FASTA file contains protein sequences (`.pep.` not `.cdna.` or `.dna.`)
+- Some species may be in Ensembl Genomes rather than main Ensembl
+- Check that the assembly and release versions are current
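
The naming conventions above can be checked mechanically before running the pipeline. A sanity-check sketch (the validation rules encode the notes above; the function is illustrative, not part of the scripts):

```python
import re

def check_manifest_entry(organism_key: str, entry: dict) -> list:
    """Return a list of problems found in a single fasta_manifest_pep.json entry."""
    problems = []
    if not re.fullmatch(r"[a-z0-9_]+", organism_key):
        problems.append("organism key should be lowercase with underscores")
    fa = entry.get("fa", "")
    if ".pep." not in fa:
        problems.append("URL should point to a protein (.pep.) FASTA, not .cdna. or .dna.")
    if not fa.endswith(".fa.gz"):
        problems.append("URL should end with .fa.gz")
    return problems

# A well-formed entry passes; a cDNA URL with a malformed key is flagged twice
good = {"fa": "https://ftp.ensembl.org/pub/release-113/fasta/ovis_aries/pep/Ovis_aries_rambouillet.ARS-UI_Ramb_v2.0.pep.all.fa.gz"}
bad = {"fa": "https://ftp.ensembl.org/pub/release-113/fasta/ovis_aries/cdna/Ovis_aries_rambouillet.ARS-UI_Ramb_v2.0.cdna.all.fa.gz"}
print(check_manifest_entry("ovis_aries", good))  # []
print(check_manifest_entry("Ovis aries", bad))
```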
+
+## Output Format
+
+The script generates embeddings in HDF5 format with the following structure:
+
+```python
+import h5py
+
+# Load embeddings
+with h5py.File('homo_sapiens_gene.h5', 'r') as f:
+    keys = [k.decode() for k in f['keys'][:]]  # Gene IDs (stored as bytes in HDF5)
+    embeddings = f['arrays']  # Group containing embedding arrays
+
+    # Access a specific gene embedding
+    gene_id = 'ENSG00000139618'  # Example: BRCA2
+    embedding = embeddings[gene_id][:]  # Shape: (2560,) for ESM-2 3B model
+```
+
+### Output Files
+
+- **Standard model**: `{organism}_gene.h5` (d=2560, TranscriptFormer default)
+- **Large model**: `{organism}_gene_large.h5` (d=5120, UCE default)
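
The layout is easy to reproduce end-to-end. A self-contained round-trip sketch, using toy 4-dimensional vectors in place of the real 2560-dimensional embeddings:

```python
import h5py
import numpy as np

toy = {
    "ENSG00000139618": np.ones(4, dtype=np.float32),
    "ENSG00000141510": np.zeros(4, dtype=np.float32),
}

# Write using the same structure as the pipeline: a "keys" dataset plus an "arrays" group
with h5py.File("toy_gene.h5", "w") as f:
    f.create_dataset("keys", data=np.array(list(toy), dtype="S"))
    grp = f.create_group("arrays")
    for gene_id, vec in toy.items():
        grp.create_dataset(gene_id, data=vec)

# Read back; keys come out as bytes and need decoding
with h5py.File("toy_gene.h5", "r") as f:
    keys = [k.decode() for k in f["keys"][:]]
    first = f["arrays"][keys[0]][:]

print(keys, first.shape)
```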
+
+## Pipeline Details
+
+### Processing Steps
+
+1. **Download**: Automatically downloads protein FASTA files from Ensembl FTP
+2. **Parse**: Extracts gene IDs from protein sequence headers
+3. **Deduplicate**: Keeps the first protein sequence per gene ID
+4. **Clean**: Removes invalid amino acid characters (`*`) from sequences
+5. **Embed**: Generates embeddings using ESM-2 model (layer 33 for 3B, layer 48 for 15B)
+6. **Average**: Averages embeddings across any sequences that share a gene ID
+7. **Save**: Stores final gene-level embeddings in HDF5 format
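
The averaging step (6) is plain NumPy: per-sequence embeddings sharing a gene ID are stacked and mean-pooled. A sketch with toy 3-dimensional vectors:

```python
import numpy as np

# Per-sequence embeddings keyed by gene ID (two sequences map to GENE_A)
per_seq = {
    "GENE_A": [np.array([1.0, 2.0, 3.0]), np.array([3.0, 4.0, 5.0])],
    "GENE_B": [np.array([0.0, 0.0, 0.0])],
}

# Average across all entries that share a gene ID to get one vector per gene
gene_embeddings = {gene: np.mean(vecs, axis=0) for gene, vecs in per_seq.items()}

print(gene_embeddings["GENE_A"])  # [2. 3. 4.]
```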
+
+### Models Used
+
+- **ESM-2 3B** (`esm2_t36_3B_UR50D`): 36-layer, 3 billion parameter model
+- **ESM-2 15B** (`esm2_t48_15B_UR50D`): 48-layer, 15 billion parameter model
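
The representation layer and output width the pipeline uses for each variant can be summarized in a small lookup (dimensions are the published ESM-2 widths; note that layer 33 is an intermediate layer of the 36-layer 3B model):

```python
# Representation layer extracted and embedding width for the two ESM-2 variants used here
ESM2_MODELS = {
    "esm2_t36_3B_UR50D": {"num_layers": 36, "repr_layer": 33, "embed_dim": 2560},
    "esm2_t48_15B_UR50D": {"num_layers": 48, "repr_layer": 48, "embed_dim": 5120},
}

for name, info in ESM2_MODELS.items():
    print(f"{name}: extract layer {info['repr_layer']}, output dim {info['embed_dim']}")
```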
diff --git a/preprocess/fasta_manifest_pep.json b/preprocess/fasta_manifest_pep.json
new file mode 100644
index 0000000..44673be
--- /dev/null
+++ b/preprocess/fasta_manifest_pep.json
@@ -0,0 +1,63 @@
+{
+ "homo_sapiens": {
+ "fa": "https://ftp.ensembl.org/pub/release-110/fasta/homo_sapiens/pep/Homo_sapiens.GRCh38.pep.all.fa.gz"
+ },
+ "mus_musculus": {
+ "fa": "https://ftp.ensembl.org/pub/release-110/fasta/mus_musculus/pep/Mus_musculus.GRCm39.pep.all.fa.gz"
+ },
+ "danio_rerio": {
+ "fa": "https://ftp.ensembl.org/pub/release-110/fasta/danio_rerio/pep/Danio_rerio.GRCz11.pep.all.fa.gz"
+ },
+ "callithrix_jacchus": {
+ "fa": "https://ftp.ensembl.org/pub/release-110/fasta/callithrix_jacchus/pep/Callithrix_jacchus.mCalJac1.pat.X.pep.all.fa.gz"
+ },
+ "gorilla_gorilla": {
+ "fa": "https://ftp.ensembl.org/pub/release-110/fasta/gorilla_gorilla/pep/Gorilla_gorilla.gorGor4.pep.all.fa.gz"
+ },
+ "macaca_mulatta": {
+ "fa": "https://ftp.ensembl.org/pub/release-110/fasta/macaca_mulatta/pep/Macaca_mulatta.Mmul_10.pep.all.fa.gz"
+ },
+ "sus_scrofa": {
+ "fa": "https://ftp.ensembl.org/pub/release-110/fasta/sus_scrofa/pep/Sus_scrofa.Sscrofa11.1.pep.all.fa.gz"
+ },
+ "pan_troglodytes": {
+ "fa": "https://ftp.ensembl.org/pub/release-110/fasta/pan_troglodytes/pep/Pan_troglodytes.Pan_tro_3.0.pep.all.fa.gz"
+ },
+ "gallus_gallus": {
+ "fa": "https://ftp.ensembl.org/pub/release-113/fasta/gallus_gallus/pep/Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.pep.all.fa.gz"
+ },
+ "heterocephalus_glaber": {
+ "fa": "https://ftp.ensembl.org/pub/release-113/fasta/heterocephalus_glaber_female/pep/Heterocephalus_glaber_female.Naked_mole-rat_maternal.pep.all.fa.gz"
+ },
+ "monodelphis_domestica": {
+ "fa": "https://ftp.ensembl.org/pub/release-113/fasta/monodelphis_domestica/pep/Monodelphis_domestica.ASM229v1.pep.all.fa.gz"
+ },
+ "drosophila_melanogaster": {
+ "fa": "https://ftp.ensembl.org/pub/release-113/fasta/drosophila_melanogaster/pep/Drosophila_melanogaster.BDGP6.46.pep.all.fa.gz"
+ },
+ "ornithorhynchus_anatinus": {
+ "fa": "https://ftp.ensembl.org/pub/release-113/fasta/ornithorhynchus_anatinus/pep/Ornithorhynchus_anatinus.mOrnAna1.p.v1.pep.all.fa.gz",
+ "gff3": "https://ftp.ensembl.org/pub/release-113/gff3/ornithorhynchus_anatinus/Ornithorhynchus_anatinus.mOrnAna1.p.v1.113.gff3.gz"
+ },
+ "oryctolagus_cuniculus": {
+ "fa": "https://ftp.ensembl.org/pub/release-113/fasta/oryctolagus_cuniculus/pep/Oryctolagus_cuniculus.OryCun2.0.pep.all.fa.gz",
+ "gff3": "https://ftp.ensembl.org/pub/release-113/gff3/oryctolagus_cuniculus/Oryctolagus_cuniculus.OryCun2.0.113.gff3.gz"
+ },
+ "xenopus_tropicalis": {
+ "fa": "https://ftp.ensembl.org/pub/release-113/fasta/xenopus_tropicalis/pep/Xenopus_tropicalis.UCB_Xtro_10.0.pep.all.fa.gz",
+ "gff3": "https://ftp.ensembl.org/pub/release-113/gff3/xenopus_tropicalis/Xenopus_tropicalis.UCB_Xtro_10.0.113.gff3.gz"
+ },
+ "microcebus_murinus": {
+ "fa": "https://ftp.ensembl.org/pub/release-113/fasta/microcebus_murinus/pep/Microcebus_murinus.Mmur_3.0.pep.all.fa.gz",
+ "gff3": "https://ftp.ensembl.org/pub/release-113/gff3/microcebus_murinus/Microcebus_murinus.Mmur_3.0.113.chr.gff3.gz"
+ },
+ "stylophora_pistillata": {
+ "fa": "https://ftp.ebi.ac.uk/ensemblgenomes/pub/release-60/metazoa/fasta/stylophora_pistillata_gca002571385v1/pep/Stylophora_pistillata_gca002571385v1.Stylophora_pistillata_v1.pep.all.fa.gz"
+ },
+ "petromyzon_marinus": {
+ "fa": "https://ftp.ensembl.org/pub/release-113/fasta/petromyzon_marinus/pep/Petromyzon_marinus.Pmarinus_7.0.pep.all.fa.gz"
+ },
+ "rattus_norvegicus": {
+ "fa": "https://ftp.ensembl.org/pub/release-113/fasta/rattus_norvegicus/pep/Rattus_norvegicus.mRatBN7.2.pep.all.fa.gz"
+ }
+}
diff --git a/preprocess/get_stable_id_mapping.py b/preprocess/get_stable_id_mapping.py
new file mode 100644
index 0000000..a3e41ff
--- /dev/null
+++ b/preprocess/get_stable_id_mapping.py
@@ -0,0 +1,95 @@
+import os
+import re
+from itertools import product
+
+import pandas as pd
+
+
+def get_stable_id_mapping_from_gff3(
+ gff3_file: str,
+ organism_key: str,
+ output_dir: str = "data/protein_embeddings/gene_protein_stable_ids",
+):
+ mapping_table = []
+    with open(gff3_file) as f:
+        lines = f.read().split("\n")
+
+ start_index = None
+ end_index = None
+ i = 0
+ while i < len(lines):
+ line = lines[i]
+ if line == "###":
+ start_index = i + 1
+ end_index = start_index
+ else:
+ i += 1
+ continue
+
+ while end_index < len(lines) and lines[end_index] != "###":
+ end_index += 1
+
+ data = lines[start_index:end_index]
+ data = "\n".join(data)
+
+ transcript_id_pattern = re.compile(r"transcript_id=([a-zA-Z0-9]+(?:\.[0-9]+)?)")
+ protein_id_pattern = re.compile(r"protein_id=([a-zA-Z0-9]+(?:\.[0-9]+)?)")
+
+ transcript_matches = transcript_id_pattern.findall(data)
+ transcript_ids = transcript_matches if transcript_matches else []
+ transcript_ids = list(set(transcript_ids))
+ protein_matches = protein_id_pattern.findall(data)
+ protein_ids = protein_matches if protein_matches else []
+ protein_ids = list(set(protein_ids))
+
+ gene_id_pattern = re.compile(r"gene_id=([a-zA-Z0-9]+)")
+ gene_match = gene_id_pattern.search(data)
+ gene_id = gene_match.group(1) if gene_match else None
+
+ combinations = list(product([gene_id], transcript_ids, protein_ids))
+
+ for gene_id, transcript_id, protein_id in combinations:
+ mapping_table.append(
+ {
+ "Protein stable ID version": protein_id,
+ "Gene stable ID": gene_id,
+ "Transcript stable ID": transcript_id,
+ }
+ )
+
+ i = end_index
+
+ mapping_table = pd.DataFrame(mapping_table)
+
+ # Create output directory if it doesn't exist
+ os.makedirs(output_dir, exist_ok=True)
+ mapping_table.to_csv(f"{output_dir}/{organism_key}.tsv", sep="\t", index=False)
+
+
+def get_stable_id_mapping_from_fasta(
+ fasta_file: str,
+ organism_key: str,
+ output_dir: str = "data/protein_embeddings/gene_protein_stable_ids",
+):
+ mapping_table = []
+    with open(fasta_file) as f:
+        lines = f.read().split("\n")
+
+ for line in lines:
+ if line.startswith(">"):
+            # gene_symbol is optional in Ensembl headers; fall back to an empty string
+            gene_symbol = line.split("gene_symbol:")[1].split(" ")[0].strip() if "gene_symbol:" in line else ""
+ gene_id = line.split("gene:")[-1].split(" ")[0].strip()
+ protein_id = line.split(" ")[0].split(">")[-1].strip()
+ transcript_id = line.split("transcript:")[-1].split(" ")[0].strip()
+ mapping_table.append(
+ {
+ "Protein stable ID version": protein_id,
+ "Gene stable ID": gene_id,
+ "Transcript stable ID": transcript_id,
+ "Gene symbol": gene_symbol,
+ }
+ )
+
+ mapping_table = pd.DataFrame(mapping_table)
+
+ # Create output directory if it doesn't exist
+ os.makedirs(output_dir, exist_ok=True)
+ mapping_table.to_csv(f"{output_dir}/{organism_key}.tsv", sep="\t", index=False)
diff --git a/preprocess/protein_embedding.py b/preprocess/protein_embedding.py
new file mode 100644
index 0000000..a203b67
--- /dev/null
+++ b/preprocess/protein_embedding.py
@@ -0,0 +1,256 @@
+import argparse
+import gzip
+import json
+import logging
+import os
+import pickle
+import shutil
+import urllib.request
+from pathlib import Path
+
+import esm
+import h5py
+import numpy as np
+import torch
+from Bio import SeqIO
+from esm import FastaBatchedDataset
+from esm.data import Alphabet
+
+STABLE_ID_DIR = "gene_protein_stable_ids/"
+FASTA_MANIFEST = "fasta_manifest_pep.json"
+
+
+def save_as_hdf5(data_dict, output_path):
+ """Save dictionary as HDF5 file."""
+ with h5py.File(output_path, "w") as f:
+ # Store the keys as a dataset
+ keys = list(data_dict.keys())
+ f.create_dataset("keys", data=np.array(keys, dtype="S"))
+
+ # Create a group for the arrays
+ arrays_group = f.create_group("arrays")
+ for key, value in data_dict.items():
+ arrays_group.create_dataset(str(key), data=value)
+
+
+def clean_sequence(seq: str):
+    """
+    Cleans the input protein sequence by removing any asterisk (*) characters.
+
+    Args:
+        seq (str): The input protein sequence.
+
+    Returns
+    -------
+    str: The cleaned protein sequence with asterisks removed.
+    """
+    return seq.replace("*", "")
+
+
+def pad_batch(toks, num_gpus):
+ """
+ Pads the batch to ensure its size is a multiple of the number of GPUs.
+
+ Args:
+ toks (torch.Tensor): The tokenized sequences.
+ num_gpus (int): The number of GPUs.
+
+ Returns
+ -------
+ torch.Tensor: The padded tokenized sequences.
+ """
+ batch_size = toks.size(0)
+ if batch_size % num_gpus != 0:
+ padding_size = num_gpus - (batch_size % num_gpus)
+ padding = torch.zeros((padding_size, toks.size(1)), dtype=toks.dtype)
+ toks = torch.cat([toks, padding], dim=0)
+ return toks
+
+
+def generate_embeddings(
+ model: torch.nn.Module,
+ alphabet: Alphabet,
+ fasta: str,
+ save_file: str,
+ seq_length=1022,
+ batch_size=16,
+):
+ """
+ Generates embeddings for protein sequences from a given FASTA file using a pre-trained model.
+
+ Args:
+ model (torch.nn.Module): The pre-trained PyTorch model to use for generating embeddings.
+ alphabet (Alphabet): The alphabet object used for encoding sequences.
+ fasta (str): Path to the input FASTA file containing protein sequences.
+ save_file (str): Path to save the generated embeddings.
+ seq_length (int, optional): Maximum sequence length for the embeddings. Defaults to 1022.
+ batch_size (int, optional): Batch size for processing. Defaults to 16.
+
+ Returns
+ -------
+ None
+ """
+ save_dir = os.path.dirname(save_file)
+ if save_dir and not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+ dataset = FastaBatchedDataset.from_file(fasta)
+
+ num_tokens = 4096 * batch_size
+ batches = dataset.get_batch_indices(num_tokens, extra_toks_per_seq=1)
+
+ data_loader = torch.utils.data.DataLoader(
+ dataset,
+ collate_fn=alphabet.get_batch_converter(seq_length),
+ batch_sampler=batches,
+ num_workers=0,
+ )
+
+ dataset.sequence_strs = [clean_sequence(seq) for seq in dataset.sequence_strs]
+
+ if os.path.exists(save_file):
+ os.remove(save_file)
+
+    embeddings = {}
+    num_gpus = torch.cuda.device_count()
+    # Extract layer 33 for the 3B model and the final layer (48) for the 15B model
+    base_model = model.module if isinstance(model, torch.nn.DataParallel) else model
+    repr_layer = 48 if base_model.num_layers == 48 else 33
+    with torch.no_grad():
+        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
+            print(f"Processing batch {batch_idx + 1} of {len(batches)}")
+            if torch.cuda.is_available():
+                toks = pad_batch(toks, num_gpus).to(device="cuda", non_blocking=True)
+
+            out = model(toks, repr_layers=[repr_layer], return_contacts=False)
+
+            representations = {layer: t.to(device="cpu") for layer, t in out["representations"].items()}
+
+            for i, label in enumerate(labels):
+                truncate_len = min(seq_length, len(strs[i]))
+                embedding = representations[repr_layer][i, 1 : truncate_len + 1].mean(0).numpy()
+
+ entry_id = label.split()[0]
+
+ if entry_id in embeddings:
+ embeddings[entry_id].append(embedding)
+ else:
+ embeddings[entry_id] = [embedding]
+
+            # Dump as we go in case the pipeline crashes
+            temp_save_file = save_file + ".tmp"
+            with open(temp_save_file, "wb") as tmp_f:
+                pickle.dump(embeddings, tmp_f)
+
+ averaged_embeddings = {k: np.mean(v, axis=0) for k, v in embeddings.items()}
+
+ save_as_hdf5(averaged_embeddings, save_file)
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Generate protein embeddings and convert to gene embeddings.")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="./",
+ required=False,
+ help="Directory to save output files",
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=16,
+ help="Batch size for embedding generation",
+ )
+ parser.add_argument(
+ "--organism_key",
+ type=str,
+ default="homo_sapiens",
+ help="The organism key to generate protein embeddings for",
+ )
+ parser.add_argument(
+ "--use_large_model",
+ action="store_true",
+ help="Whether to use the large ESM-2 model",
+ )
+ args = parser.parse_args()
+
+ logging.basicConfig(level=logging.INFO)
+
+ # Load FASTA URL manifest
+ with open(FASTA_MANIFEST) as f:
+ fasta_urls = json.load(f)
+
+ if args.organism_key not in fasta_urls:
+ raise ValueError(f"Organism {args.organism_key} is not a valid organism in the fasta manifest")
+
+ # Create stable_id_dir if it doesn't exist
+ stable_id_dir = Path(STABLE_ID_DIR)
+ stable_id_dir.mkdir(parents=True, exist_ok=True)
+
+ output_dir = Path(args.output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Load ESM-2 model
+ if args.use_large_model:
+ model, alphabet = esm.pretrained.esm2_t48_15B_UR50D()
+ suffix = "_large"
+ else:
+ model, alphabet = esm.pretrained.esm2_t36_3B_UR50D()
+ suffix = ""
+
+ model.eval() # disables dropout for deterministic results
+ if torch.cuda.is_available():
+ model = model.cuda()
+
+ if torch.cuda.device_count() > 1:
+ model = torch.nn.DataParallel(model)
+
+ organism = args.organism_key
+ fasta_url = fasta_urls[organism]["fa"]
+
+ fasta_file = stable_id_dir / f"{organism}.fa"
+ if not fasta_file.exists():
+ logging.info(f"Downloading FASTA for {organism}")
+ # Download and decompress in one step using Python
+        with urllib.request.urlopen(fasta_url) as response:
+            # Ensembl serves gzipped FASTA; detect by URL suffix and write in binary mode
+            if fasta_url.endswith(".gz"):
+                with gzip.GzipFile(fileobj=response) as gz_file:
+                    with open(fasta_file, "wb") as out_file:
+                        shutil.copyfileobj(gz_file, out_file)
+            else:
+                with open(fasta_file, "wb") as out_file:
+                    shutil.copyfileobj(response, out_file)
+
+ # Convert to gene IDs
+ new_records = []
+ seen_names = set()
+ for record in SeqIO.parse(fasta_file, "fasta"):
+ if not args.use_large_model:
+ gene_id = record.description.split("gene:")[-1].split(" ")[0].strip().split(".")[0]
+ else:
+ if "gene_symbol" not in record.description:
+ gene_id = record.description.split("gene:")[-1].split(" ")[0].strip().split(".")[0]
+ else:
+ gene_id = record.description.split("gene_symbol:")[-1].split(" ")[0].strip()
+
+ if gene_id in seen_names:
+ continue
+ seen_names.add(gene_id)
+ record.id = gene_id
+ record.name = gene_id
+ new_records.append(record)
+
+ with open(fasta_file, "w") as f:
+ SeqIO.write(new_records, f, "fasta")
+
+ emb_file_name = output_dir / f"{organism}_gene{suffix}.h5"
+ logging.info(f"Processing {fasta_file} for {organism}")
+
+ generate_embeddings(
+ model,
+ alphabet,
+ str(fasta_file),
+ save_file=str(emb_file_name),
+ batch_size=args.batch_size,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/preprocess/requirements.txt b/preprocess/requirements.txt
new file mode 100644
index 0000000..857175a
--- /dev/null
+++ b/preprocess/requirements.txt
@@ -0,0 +1,5 @@
+numpy==1.24.4
+pandas==2.2.1
+scipy==1.13.0
+biopython
+h5py