diff --git a/README.md b/README.md index d173dfc..c0f837b 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@

TranscriptFormer Overview
- Overview of TranscriptFormer pretraining data, model, outputs and downstream tasks. + Overview of TranscriptFormer pretraining data (A), model (B), outputs (C) and downstream tasks (D).

@@ -112,6 +112,49 @@ The command will download and extract the following files to the `./checkpoints` - `./checkpoints/tf_metazoa/`: Metazoa model weights - `./checkpoints/all_embeddings/`: Embedding files for out-of-distribution species +#### Available Protein Embeddings + +The following protein embeddings are available for download with `transcriptformer download all-embeddings`: + +| Scientific Name | Common Name | TF-Metazoa | TF-Exemplar | TF-Sapiens | Notes | +|-----------------|-------------|------------|-------------|------------|-------| +| *Homo sapiens* | Human | ✓ | ✓ | ✓ | Primary training species | +| *Mus musculus* | Mouse | ✓ | ✓ | - | Model organism | +| *Danio rerio* | Zebrafish | ✓ | ✓ | - | Model organism | +| *Drosophila melanogaster* | Fruit fly | ✓ | ✓ | - | Model organism | +| *Caenorhabditis elegans* | C. elegans | ✓ | ✓ | - | Model organism | +| *Oryctolagus cuniculus* | Rabbit | ✓ | - | - | Vertebrate | +| *Gallus gallus* | Chicken | ✓ | - | - | Vertebrate | +| *Xenopus laevis* | African clawed frog | ✓ | - | - | Vertebrate | +| *Lytechinus variegatus* | Sea urchin | ✓ | - | - | Invertebrate | +| *Spongilla lacustris* | Freshwater sponge | ✓ | - | - | Invertebrate | +| *Saccharomyces cerevisiae* | Yeast | ✓ | - | - | Fungus | +| *Plasmodium falciparum* | Malaria parasite | ✓ | - | - | Protist | +| *Rattus norvegicus* | Rat | - | - | - | Out-of-distribution | +| *Sus scrofa* | Pig | - | - | - | Out-of-distribution | +| *Pan troglodytes* | Chimpanzee | - | - | - | Out-of-distribution | +| *Gorilla gorilla* | Gorilla | - | - | - | Out-of-distribution | +| *Macaca mulatta* | Rhesus macaque | - | - | - | Out-of-distribution | +| *Callithrix jacchus* | Marmoset | - | - | - | Out-of-distribution | +| *Xenopus tropicalis* | Western clawed frog | - | - | - | Out-of-distribution | +| *Ornithorhynchus anatinus* | Platypus | - | - | - | Out-of-distribution | +| *Monodelphis domestica* | Opossum | - | - | - | Out-of-distribution | +| *Heterocephalus 
glaber* | Naked mole-rat | - | - | - | Out-of-distribution | +| *Petromyzon marinus* | Sea lamprey | - | - | - | Out-of-distribution | +| *Stylophora pistillata* | Coral | - | - | - | Out-of-distribution | + +**Legend:** +- ✓ = Species included in model training data +- \- = Species not included in model training (out-of-distribution) + +### Generating Protein Embeddings for New Species + +The pre-generated embeddings cover the most commonly used species. If you need to work with a species not included in the downloaded embeddings, you can generate protein embeddings using the ESM-2 models. + +**Note**: This is only necessary for new species that don't have pre-generated embeddings available for download. + +For detailed instructions on generating protein embeddings for additional species, see the [protein_embeddings/README.md](protein_embeddings/README.md) documentation. + ### Downloading Training Datasets Use the CLI to download single-cell RNA sequencing datasets from the CellxGene Discover portal: diff --git a/preprocess/README.md b/preprocess/README.md new file mode 100644 index 0000000..dd19de4 --- /dev/null +++ b/preprocess/README.md @@ -0,0 +1,194 @@ +# Protein Embeddings Generation + +This directory contains scripts for generating protein embeddings using Facebook's ESM-2 (Evolutionary Scale Modeling) models. The pipeline downloads protein sequences from Ensembl, processes them with pre-trained ESM-2 models, and outputs gene-level embeddings suitable for inputs to TranscriptFormer. + +## Overview + +The protein embedding pipeline consists of three main components: + +1. **`protein_embedding.py`** - Main script for generating protein embeddings using ESM-2 models +2. **`get_stable_id_mapping.py`** - Utility functions for mapping between gene, transcript, and protein stable IDs +3. 
**`fasta_manifest_pep.json`** - Configuration file containing download URLs for protein FASTA files from Ensembl + +## Installation + +### Using pip (traditional) + +Install the required dependencies: + +```bash +pip install -r requirements.txt +pip install fair-esm +``` + +For GPU acceleration (recommended): +```bash +# For CUDA-enabled PyTorch +pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 +``` + +### Using uv (recommended) + +[uv](https://github.com/astral-sh/uv) is a fast Python package installer and resolver: + +```bash +# Install uv if you haven't already +curl -LsSf https://astral.sh/uv/install.sh | sh + +# Create and activate virtual environment +uv venv protein-embeddings +source protein-embeddings/bin/activate # On Windows: protein-embeddings\Scripts\activate + +# Install dependencies +uv pip install -r requirements.txt +uv pip install fair-esm + +# For GPU acceleration +uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 +``` + +### System Requirements + +- **Memory**: At least 16GB RAM (32GB+ recommended for large models) +- **GPU**: NVIDIA GPU with 8GB+ VRAM (optional but highly recommended) +- **Storage**: Several GB for downloaded FASTA files and generated embeddings +- **Network**: Internet connection for downloading protein sequences from Ensembl + +## Usage + +### Basic Usage + +Generate protein embeddings for a single species: + +```bash +python protein_embedding.py --organism_key homo_sapiens +``` + +### Advanced Usage + +```bash +python protein_embedding.py \ + --organism_key mus_musculus \ + --output_dir /path/to/output \ + --batch_size 32 \ + --use_large_model true +``` + +### Command Line Arguments + +- `--organism_key`: Species to process (see [Supported Species](#supported-species)) +- `--output_dir`: Directory to save embeddings (default: current directory `./`) +- `--batch_size`: Batch size for processing (default: 16) +- `--use_large_model`: Use 
ESM-2 15B parameter model instead of 3B (default: false) + + +## Supported Species + +The pipeline supports the following species (from Ensembl release 110/113): + +| Species | Organism Key | Common Name | +|---------|-------------|-------------| +| Homo sapiens | `homo_sapiens` | Human | +| Mus musculus | `mus_musculus` | Mouse | +| Rattus norvegicus | `rattus_norvegicus` | Rat | +| Sus scrofa | `sus_scrofa` | Pig | +| Oryctolagus cuniculus | `oryctolagus_cuniculus` | Rabbit | +| Macaca mulatta | `macaca_mulatta` | Rhesus macaque | +| Pan troglodytes | `pan_troglodytes` | Chimpanzee | +| Gorilla gorilla | `gorilla_gorilla` | Gorilla | +| Callithrix jacchus | `callithrix_jacchus` | Marmoset | +| Microcebus murinus | `microcebus_murinus` | Mouse lemur | +| Gallus gallus | `gallus_gallus` | Chicken | +| Danio rerio | `danio_rerio` | Zebrafish | +| Xenopus tropicalis | `xenopus_tropicalis` | Frog | +| Drosophila melanogaster | `drosophila_melanogaster` | Fruit fly | +| Petromyzon marinus | `petromyzon_marinus` | Sea lamprey | +| Ornithorhynchus anatinus | `ornithorhynchus_anatinus` | Platypus | +| Monodelphis domestica | `monodelphis_domestica` | Opossum | +| Heterocephalus glaber | `heterocephalus_glaber` | Naked mole-rat | +| Stylophora pistillata | `stylophora_pistillata` | Coral | + +## Adding New Species + +To add support for a new species, you need to update the `fasta_manifest_pep.json` file with the appropriate Ensembl download URLs. + +### Step 1: Find Ensembl URLs + +1. Visit the [Ensembl FTP site](https://ftp.ensembl.org/pub/) or [Ensembl Genomes](https://ftp.ebi.ac.uk/ensemblgenomes/pub/) for non-vertebrates +2. Navigate to the latest release (e.g., `release-113/`) +3. Find your species under `fasta/{species_name}/pep/` +4. 
Copy the URL for the `.pep.all.fa.gz` file + +### Step 2: Update the Manifest + +Add an entry to `fasta_manifest_pep.json`: + +```json +{ + "new_species_name": { + "fa": "https://ftp.ensembl.org/pub/release-113/fasta/new_species/pep/New_species.Assembly.pep.all.fa.gz" + } +} +``` + +### Step 3: Generate Embeddings + +```bash +python protein_embedding.py --organism_key new_species_name +``` + +### Example: Adding Sheep (Ovis aries) + +```json +{ + "ovis_aries": { + "fa": "https://ftp.ensembl.org/pub/release-113/fasta/ovis_aries/pep/Ovis_aries_rambouillet.ARS-UI_Ramb_v2.0.pep.all.fa.gz" + } +} +``` + +### Notes + +- Use lowercase with underscores for organism keys (e.g., `ovis_aries`) +- Ensure the FASTA file contains protein sequences (`.pep.` not `.cdna.` or `.dna.`) +- Some species may be in Ensembl Genomes rather than main Ensembl +- Check that the assembly and release versions are current + +## Output Format + +The script generates embeddings in HDF5 format with the following structure: + +```python +import h5py + +# Load embeddings +with h5py.File('homo_sapiens_gene.h5', 'r') as f: + keys = f['keys'][:] # Gene IDs + embeddings = f['arrays'] # Group containing embedding arrays + + # Access specific gene embedding + gene_id = 'ENSG00000139618' # Example: BRCA2 + embedding = embeddings[gene_id][:] # Shape: (2560,) for ESM-2 3B model +``` + +### Output Files + +- **Standard model**: `{organism}_gene.h5` (d=2560, TranscriptFormer default) +- **Large model**: `{organism}_gene_large.h5` (d=5120, UCE default) + +## Pipeline Details + +### Processing Steps + +1. **Download**: Automatically downloads protein FASTA files from Ensembl FTP +2. **Parse**: Extracts gene IDs from protein sequence headers +3. **Deduplicate**: Removes duplicate sequences for the same gene +4. **Clean**: Replaces invalid amino acids (*) with `` tokens +5. **Embed**: Generates embeddings using ESM-2 model (layer 33 for 3B, layer 48 for 15B) +6. 
**Average**: Averages embeddings across all protein isoforms per gene +7. **Save**: Stores final gene-level embeddings in HDF5 format + +### Models Used + +- **ESM-2 3B** (`esm2_t36_3B_UR50D`): 36-layer, 3 billion parameter model +- **ESM-2 15B** (`esm2_t48_15B_UR50D`): 48-layer, 15 billion parameter model diff --git a/preprocess/fasta_manifest_pep.json b/preprocess/fasta_manifest_pep.json new file mode 100644 index 0000000..44673be --- /dev/null +++ b/preprocess/fasta_manifest_pep.json @@ -0,0 +1,63 @@ +{ + "homo_sapiens": { + "fa": "https://ftp.ensembl.org/pub/release-110/fasta/homo_sapiens/pep/Homo_sapiens.GRCh38.pep.all.fa.gz" + }, + "mus_musculus": { + "fa": "https://ftp.ensembl.org/pub/release-110/fasta/mus_musculus/pep/Mus_musculus.GRCm39.pep.all.fa.gz" + }, + "danio_rerio": { + "fa": "https://ftp.ensembl.org/pub/release-110/fasta/danio_rerio/pep/Danio_rerio.GRCz11.pep.all.fa.gz" + }, + "callithrix_jacchus": { + "fa": "https://ftp.ensembl.org/pub/release-110/fasta/callithrix_jacchus/pep/Callithrix_jacchus.mCalJac1.pat.X.pep.all.fa.gz" + }, + "gorilla_gorilla": { + "fa": "https://ftp.ensembl.org/pub/release-110/fasta/gorilla_gorilla/pep/Gorilla_gorilla.gorGor4.pep.all.fa.gz" + }, + "macaca_mulatta": { + "fa": "https://ftp.ensembl.org/pub/release-110/fasta/macaca_mulatta/pep/Macaca_mulatta.Mmul_10.pep.all.fa.gz" + }, + "sus_scrofa": { + "fa": "https://ftp.ensembl.org/pub/release-110/fasta/sus_scrofa/pep/Sus_scrofa.Sscrofa11.1.pep.all.fa.gz" + }, + "pan_troglodytes": { + "fa": "https://ftp.ensembl.org/pub/release-110/fasta/pan_troglodytes/pep/Pan_troglodytes.Pan_tro_3.0.pep.all.fa.gz" + }, + "gallus_gallus": { + "fa": "https://ftp.ensembl.org/pub/release-113/fasta/gallus_gallus/pep/Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.pep.all.fa.gz" + }, + "heterocephalus_glaber": { + "fa": "https://ftp.ensembl.org/pub/release-113/fasta/heterocephalus_glaber_female/pep/Heterocephalus_glaber_female.Naked_mole-rat_maternal.pep.all.fa.gz" + }, + "monodelphis_domestica": { + 
"fa": "https://ftp.ensembl.org/pub/release-113/fasta/monodelphis_domestica/pep/Monodelphis_domestica.ASM229v1.pep.all.fa.gz" + }, + "drosophila_melanogaster": { + "fa": "https://ftp.ensembl.org/pub/release-113/fasta/drosophila_melanogaster/pep/Drosophila_melanogaster.BDGP6.46.pep.all.fa.gz" + }, + "ornithorhynchus_anatinus": { + "fa": "https://ftp.ensembl.org/pub/release-113/fasta/ornithorhynchus_anatinus/pep/Ornithorhynchus_anatinus.mOrnAna1.p.v1.pep.all.fa.gz", + "gff3": "https://ftp.ensembl.org/pub/release-113/gff3/ornithorhynchus_anatinus/Ornithorhynchus_anatinus.mOrnAna1.p.v1.113.gff3.gz" + }, + "oryctolagus_cuniculus": { + "fa": "https://ftp.ensembl.org/pub/release-113/fasta/oryctolagus_cuniculus/pep/Oryctolagus_cuniculus.OryCun2.0.pep.all.fa.gz", + "gff3": "https://ftp.ensembl.org/pub/release-113/gff3/oryctolagus_cuniculus/Oryctolagus_cuniculus.OryCun2.0.113.gff3.gz" + }, + "xenopus_tropicalis": { + "fa": "https://ftp.ensembl.org/pub/release-113/fasta/xenopus_tropicalis/pep/Xenopus_tropicalis.UCB_Xtro_10.0.pep.all.fa.gz", + "gff3": "https://ftp.ensembl.org/pub/release-113/gff3/xenopus_tropicalis/Xenopus_tropicalis.UCB_Xtro_10.0.113.gff3.gz" + }, + "microcebus_murinus": { + "fa": "https://ftp.ensembl.org/pub/release-113/fasta/microcebus_murinus/pep/Microcebus_murinus.Mmur_3.0.pep.all.fa.gz", + "gff3": "https://ftp.ensembl.org/pub/release-113/gff3/microcebus_murinus/Microcebus_murinus.Mmur_3.0.113.chr.gff3.gz" + }, + "stylophora_pistillata": { + "fa": "https://ftp.ebi.ac.uk/ensemblgenomes/pub/release-60/metazoa/fasta/stylophora_pistillata_gca002571385v1/pep/Stylophora_pistillata_gca002571385v1.Stylophora_pistillata_v1.pep.all.fa.gz" + }, + "petromyzon_marinus": { + "fa": "https://ftp.ensembl.org/pub/release-113/fasta/petromyzon_marinus/pep/Petromyzon_marinus.Pmarinus_7.0.pep.all.fa.gz" + }, + "rattus_norvegicus": { + "fa": "https://ftp.ensembl.org/pub/release-113/fasta/rattus_norvegicus/pep/Rattus_norvegicus.mRatBN7.2.pep.all.fa.gz" + } +} diff --git 
import os
import re
from itertools import product

import pandas as pd


def get_stable_id_mapping_from_gff3(
    gff3_file: str,
    organism_key: str,
    output_dir: str = "data/protein_embeddings/gene_protein_stable_ids",
) -> None:
    """Extract gene/transcript/protein stable-ID triples from an Ensembl GFF3 file.

    The GFF3 file is split into per-gene record groups delimited by ``###``
    lines (the Ensembl "resolution" directive). Within each group, every
    combination of the group's single gene ID with its transcript IDs and
    protein IDs becomes one row of a tab-separated mapping table written to
    ``{output_dir}/{organism_key}.tsv``.

    Args:
        gff3_file: Path to the (uncompressed) GFF3 annotation file.
        organism_key: Organism identifier used to name the output file.
        output_dir: Directory the mapping TSV is written to (created if missing).
    """
    # Compile once; previously these were re-compiled for every gene group.
    transcript_id_pattern = re.compile(r"transcript_id=([a-zA-Z0-9]+(?:\.[0-9]+)?)")
    protein_id_pattern = re.compile(r"protein_id=([a-zA-Z0-9]+(?:\.[0-9]+)?)")
    gene_id_pattern = re.compile(r"gene_id=([a-zA-Z0-9]+)")

    # Close the file handle deterministically (was a bare open().read()).
    with open(gff3_file) as fh:
        lines = fh.read().split("\n")

    mapping_table = []
    i = 0
    while i < len(lines):
        if lines[i] != "###":
            i += 1
            continue

        # A group runs from the line after this "###" up to the next "###".
        start_index = i + 1
        end_index = start_index
        while end_index < len(lines) and lines[end_index] != "###":
            end_index += 1

        data = "\n".join(lines[start_index:end_index])

        # sorted() instead of list(set()) so row order is deterministic
        # across runs (set iteration order is not).
        transcript_ids = sorted(set(transcript_id_pattern.findall(data)))
        protein_ids = sorted(set(protein_id_pattern.findall(data)))

        gene_match = gene_id_pattern.search(data)
        gene_id = gene_match.group(1) if gene_match else None

        # Cartesian product: rows only appear when the group has at least one
        # transcript ID and at least one protein ID.
        for g_id, transcript_id, protein_id in product([gene_id], transcript_ids, protein_ids):
            mapping_table.append(
                {
                    "Protein stable ID version": protein_id,
                    "Gene stable ID": g_id,
                    "Transcript stable ID": transcript_id,
                }
            )

        i = end_index

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    pd.DataFrame(mapping_table).to_csv(
        f"{output_dir}/{organism_key}.tsv", sep="\t", index=False
    )
def get_stable_id_mapping_from_fasta(
    fasta_file: str,
    organism_key: str,
    output_dir: str = "data/protein_embeddings/gene_protein_stable_ids",
) -> None:
    """Build a gene/transcript/protein/symbol mapping table from Ensembl FASTA headers.

    Each ``>`` header line is parsed with string splits on the ``gene:``,
    ``transcript:`` and ``gene_symbol:`` tags carried by Ensembl peptide FASTA
    headers; one TSV row is written per protein record to
    ``{output_dir}/{organism_key}.tsv``.

    NOTE(review): when a tag is absent from a header, ``split(tag)[-1]`` falls
    back to (a slice of) the whole line, so the corresponding column holds junk
    rather than an empty value — confirm all inputs carry the expected tags.

    Args:
        fasta_file: Path to the (uncompressed) peptide FASTA file.
        organism_key: Organism identifier used to name the output file.
        output_dir: Directory the mapping TSV is written to (created if missing).
    """
    mapping_table = []
    # Close the file handle deterministically (was a bare open().read()).
    with open(fasta_file) as fh:
        lines = fh.read().split("\n")

    for line in lines:
        if line.startswith(">"):
            gene_symbol = line.split("gene_symbol:")[-1].split(" ")[0].strip()
            gene_id = line.split("gene:")[-1].split(" ")[0].strip()
            protein_id = line.split(" ")[0].split(">")[-1].strip()
            transcript_id = line.split("transcript:")[-1].split(" ")[0].strip()
            mapping_table.append(
                {
                    "Protein stable ID version": protein_id,
                    "Gene stable ID": gene_id,
                    "Transcript stable ID": transcript_id,
                    "Gene symbol": gene_symbol,
                }
            )

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    pd.DataFrame(mapping_table).to_csv(
        f"{output_dir}/{organism_key}.tsv", sep="\t", index=False
    )


# Paths used by the embedding pipeline (relative to the working directory).
STABLE_ID_DIR = "gene_protein_stable_ids/"
FASTA_MANIFEST = "fasta_manifest_pep.json"


def save_as_hdf5(data_dict, output_path):
    """Save a {key: array} dict as HDF5: a byte-string 'keys' dataset plus an
    'arrays' group holding one dataset per key."""
    with h5py.File(output_path, "w") as f:
        # Store the keys as a dataset for easy listing without opening the group.
        f.create_dataset("keys", data=np.array(list(data_dict.keys()), dtype="S"))

        arrays_group = f.create_group("arrays")
        for key, value in data_dict.items():
            arrays_group.create_dataset(str(key), data=value)


def clean_sequence(seq: str) -> str:
    """Remove stop-codon asterisk (*) characters from a protein sequence.

    (Docstring fixed: the previous text claimed asterisks were "replaced with
    the token", but the code deletes them outright.)

    Args:
        seq (str): The input protein sequence.

    Returns
    -------
    str: The sequence with all asterisks removed.
    """
    return seq.replace("*", "")


def pad_batch(toks, num_gpus):
    """
    Pads the batch with zero rows so its size is a multiple of the number of
    GPUs (torch.nn.DataParallel needs the batch to split evenly).

    Args:
        toks (torch.Tensor): The tokenized sequences, shape (batch, seq_len).
        num_gpus (int): The number of GPUs; must be >= 1 (callers only invoke
            this when CUDA is available).

    Returns
    -------
    torch.Tensor: The (possibly) padded tokenized sequences.
    """
    batch_size = toks.size(0)
    if batch_size % num_gpus != 0:
        padding_size = num_gpus - (batch_size % num_gpus)
        padding = torch.zeros((padding_size, toks.size(1)), dtype=toks.dtype)
        toks = torch.cat([toks, padding], dim=0)
    return toks
def generate_embeddings(
    model: torch.nn.Module,
    alphabet: Alphabet,
    fasta: str,
    save_file: str,
    seq_length=1022,
    batch_size=16,
    repr_layer=33,
):
    """
    Generates embeddings for protein sequences from a given FASTA file using a pre-trained model.

    Args:
        model (torch.nn.Module): The pre-trained PyTorch model to use for generating embeddings.
        alphabet (Alphabet): The alphabet object used for encoding sequences.
        fasta (str): Path to the input FASTA file containing protein sequences.
        save_file (str): Path to save the generated embeddings (HDF5).
        seq_length (int, optional): Maximum sequence length for the embeddings. Defaults to 1022.
        batch_size (int, optional): Batch size for processing. Defaults to 16.
        repr_layer (int, optional): Model layer whose representations are extracted.
            Defaults to 33 (previously hard-coded, which silently used layer 33
            even for the 48-layer 15B model).

    Returns
    -------
    None
    """
    save_dir = os.path.dirname(save_file)
    if save_dir and not os.path.exists(save_dir):
        os.makedirs(save_dir)

    dataset = FastaBatchedDataset.from_file(fasta)

    # Token-count-based batching keeps GPU memory roughly constant per batch.
    num_tokens = 4096 * batch_size
    batches = dataset.get_batch_indices(num_tokens, extra_toks_per_seq=1)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(seq_length),
        batch_sampler=batches,
        num_workers=0,
    )

    # Strip stop-codon asterisks. The loader reads the dataset lazily, so
    # mutating sequence_strs here still affects every batch it yields.
    dataset.sequence_strs = [clean_sequence(seq) for seq in dataset.sequence_strs]

    if os.path.exists(save_file):
        os.remove(save_file)

    embeddings = {}
    num_gpus = torch.cuda.device_count()
    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(f"Processing batch {batch_idx + 1} of {len(batches)}")
            if torch.cuda.is_available():
                # DataParallel requires the batch to split evenly across GPUs;
                # padded rows are never indexed below (labels only covers the
                # original batch), so they are harmless.
                toks = pad_batch(toks, num_gpus).to(device="cuda", non_blocking=True)

            out = model(toks, repr_layers=[repr_layer], return_contacts=False)

            representations = {layer: t.to(device="cpu") for layer, t in out["representations"].items()}

            for i, label in enumerate(labels):
                # Mean-pool over residues, skipping the BOS token at index 0.
                truncate_len = min(seq_length, len(strs[i]))
                embedding = representations[repr_layer][i, 1 : truncate_len + 1].mean(0).numpy()

                entry_id = label.split()[0]

                # Accumulate one vector per record sharing the same entry ID;
                # they are averaged per key below.
                embeddings.setdefault(entry_id, []).append(embedding)

            # Dump as we go just in case pipeline crashes.
            # Fixed: the pickle file handle was previously never closed.
            with open(save_file + ".tmp", "wb") as tmp_fh:
                pickle.dump(embeddings, tmp_fh)

    # Average all vectors collected per key into one gene-level embedding.
    averaged_embeddings = {k: np.mean(v, axis=0) for k, v in embeddings.items()}

    save_as_hdf5(averaged_embeddings, save_file)


def main():
    """CLI entry point: download a species' protein FASTA from the manifest,
    rewrite record IDs to gene IDs, and write gene-level ESM-2 embeddings
    to ``{output_dir}/{organism}_gene{suffix}.h5``."""
    parser = argparse.ArgumentParser(description="Generate protein embeddings and convert to gene embeddings.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./",
        required=False,
        help="Directory to save output files",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Batch size for embedding generation",
    )
    parser.add_argument(
        "--organism_key",
        type=str,
        default="homo_sapiens",
        help="The organism key to generate protein embeddings for",
    )
    parser.add_argument(
        "--use_large_model",
        action="store_true",
        help="Whether to use the large ESM-2 model",
    )
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    # Load FASTA URL manifest
    with open(FASTA_MANIFEST) as f:
        fasta_urls = json.load(f)

    if args.organism_key not in fasta_urls:
        raise ValueError(f"Organism {args.organism_key} is not a valid organism in the fasta manifest")

    # Create stable_id_dir if it doesn't exist
    stable_id_dir = Path(STABLE_ID_DIR)
    stable_id_dir.mkdir(parents=True, exist_ok=True)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load ESM-2 model. The representation layer follows the model depth:
    # layer 33 for the 3B model, layer 48 for the 48-layer 15B model (per the
    # accompanying README); previously layer 33 was hard-coded for both.
    if args.use_large_model:
        model, alphabet = esm.pretrained.esm2_t48_15B_UR50D()
        suffix = "_large"
        repr_layer = 48
    else:
        model, alphabet = esm.pretrained.esm2_t36_3B_UR50D()
        suffix = ""
        repr_layer = 33

    model.eval()  # disables dropout for deterministic results
    if torch.cuda.is_available():
        model = model.cuda()

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    organism = args.organism_key
    fasta_url = fasta_urls[organism]["fa"]

    fasta_file = stable_id_dir / f"{organism}.fa"
    if not fasta_file.exists():
        logging.info(f"Downloading FASTA for {organism}")
        # Download and decompress in one step. Two fixes here:
        #  1. The output must be opened in binary mode ("wb"); the previous
        #     text-mode open raised TypeError when copyfileobj wrote bytes.
        #  2. The manifest URLs end in ".gz" (gzip payload, not transport
        #     Content-Encoding), so detect gzip by URL suffix as well.
        with urllib.request.urlopen(fasta_url) as response:
            is_gzipped = fasta_url.endswith(".gz") or response.headers.get("Content-Encoding") == "gzip"
            with open(fasta_file, "wb") as out_file:
                if is_gzipped:
                    with gzip.GzipFile(fileobj=response) as gz_file:
                        shutil.copyfileobj(gz_file, out_file)
                else:
                    shutil.copyfileobj(response, out_file)

    # Rewrite FASTA records so they are keyed by gene ID (default model) or
    # gene symbol when present (large model), deduplicating by that key.
    # NOTE(review): deduplication keeps only the first record seen per gene,
    # so the per-key averaging in generate_embeddings typically sees a single
    # sequence — confirm this matches the intended isoform-averaging behavior.
    new_records = []
    seen_names = set()
    for record in SeqIO.parse(fasta_file, "fasta"):
        if not args.use_large_model:
            # Ensembl gene stable ID with the ".N" version suffix stripped.
            gene_id = record.description.split("gene:")[-1].split(" ")[0].strip().split(".")[0]
        else:
            if "gene_symbol" not in record.description:
                gene_id = record.description.split("gene:")[-1].split(" ")[0].strip().split(".")[0]
            else:
                gene_id = record.description.split("gene_symbol:")[-1].split(" ")[0].strip()

        if gene_id in seen_names:
            continue
        seen_names.add(gene_id)
        record.id = gene_id
        record.name = gene_id
        new_records.append(record)

    with open(fasta_file, "w") as f:
        SeqIO.write(new_records, f, "fasta")

    emb_file_name = output_dir / f"{organism}_gene{suffix}.h5"
    logging.info(f"Processing {fasta_file} for {organism}")

    generate_embeddings(
        model,
        alphabet,
        str(fasta_file),
        save_file=str(emb_file_name),
        batch_size=args.batch_size,
        repr_layer=repr_layer,
    )


if __name__ == "__main__":
    main()