diff --git a/docs/usage/blast.ipynb b/docs/usage/blast.ipynb
index b56140d7..d6cd57ef 100644
--- a/docs/usage/blast.ipynb
+++ b/docs/usage/blast.ipynb
@@ -1,363 +1,364 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# BLAST Search\n",
- "\n",
- "## Setup\n",
- "\n",
- "The BLAST service runs in a Docker container and requires:\n",
- "1. A local BLAST database\n",
- "2. The Docker service running"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "# change log level to INFO\n",
- "import sys\n",
- "from loguru import logger\n",
- "\n",
- "logger.remove()\n",
- "level = logger.add(sys.stderr, level=\"WARNING\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Basic Usage\n",
- "\n",
- "The `Blast` class provides an interface to search protein or nucleotide sequences against a local BLAST database."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "
\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " subject_id | \n",
- " identity | \n",
- " alignment_length | \n",
- " mismatches | \n",
- " gap_opens | \n",
- " query_start | \n",
- " query_end | \n",
- " subject_start | \n",
- " subject_end | \n",
- " evalue | \n",
- " bit_score | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " | 0 | \n",
- " seq7 | \n",
- " 81.818 | \n",
- " 22 | \n",
- " 3 | \n",
- " 1 | \n",
- " 31 | \n",
- " 51 | \n",
- " 11 | \n",
- " 32 | \n",
- " 0.003 | \n",
- " 22.3 | \n",
- "
\n",
- " \n",
- " | 1 | \n",
- " seq1 | \n",
- " 100.000 | \n",
- " 25 | \n",
- " 0 | \n",
- " 0 | \n",
- " 1 | \n",
- " 25 | \n",
- " 1 | \n",
- " 25 | \n",
- " 0.004 | \n",
- " 22.3 | \n",
- "
\n",
- " \n",
- " | 2 | \n",
- " seq2 | \n",
- " 61.538 | \n",
- " 26 | \n",
- " 10 | \n",
- " 0 | \n",
- " 20 | \n",
- " 45 | \n",
- " 5 | \n",
- " 30 | \n",
- " 0.038 | \n",
- " 19.2 | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " subject_id identity alignment_length mismatches gap_opens query_start \\\n",
- "0 seq7 81.818 22 3 1 31 \n",
- "1 seq1 100.000 25 0 0 1 \n",
- "2 seq2 61.538 26 10 0 20 \n",
- "\n",
- " query_end subject_start subject_end evalue bit_score \n",
- "0 51 11 32 0.003 22.3 \n",
- "1 25 1 25 0.004 22.3 \n",
- "2 45 5 30 0.038 19.2 "
- ]
- },
- "execution_count": 2,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from pyeed.tools import Blast\n",
- "\n",
- "# Example protein sequence\n",
- "sequence = \"MSEQVAAVAKLRAKASEAAKEAKAREAAKKLAEAAKKAKAKEAAKRAEAKLAEKAKAAKRAEAKAAKEAKRAAAKRAEAKLAEKAKAAK\"\n",
- "\n",
- "# Initialize BLAST search\n",
- "blast = Blast(\n",
- " # service_url=\"http://localhost:6001/blast\",\n",
- " mode=\"blastp\", # Use blastp for protein sequences\n",
- " db_path=\"/usr/local/bin/data/test_db\", # Path in Docker container\n",
- " db_name=\"protein_db\", # Name of your BLAST database\n",
- " evalue=0.1, # E-value threshold\n",
- " max_target_seqs=10, # Maximum number of hits to return\n",
- ")\n",
- "\n",
- "# Perform search\n",
- "results = blast.search(sequence)\n",
- "results"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The results are returned as a pandas DataFrame with the following columns:\n",
- "- subject_id: ID of the matched sequence\n",
- "- identity: Percentage identity\n",
- "- alignment_length: Length of the alignment\n",
- "- mismatches: Number of mismatches\n",
- "- gap_opens: Number of gap openings\n",
- "- query_start/end: Start/end positions in query sequence\n",
- "- subject_start/end: Start/end positions in subject sequence\n",
- "- evalue: Expectation value\n",
- "- bit_score: Bit score"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Creating a BLAST Database\n",
- "\n",
- "Before using BLAST, you need to create a local database. Here's how to create one from a FASTA file:"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "```bash\n",
- "# For protein sequences\n",
- "makeblastdb -in proteins.fasta -dbtype prot -out blast_db/my_proteins\n",
- "\n",
- "# For nucleotide sequences\n",
- "makeblastdb -in nucleotides.fasta -dbtype nucl -out blast_db/my_nucleotides\n",
- "```\n",
- "\n",
- "To access the BLAST Docker container shell and create databases:\n",
- "\n",
- "```bash\n",
- "# Enter the BLAST container shell\n",
- "docker compose exec blast bash\n",
- "# \n",
- "# Navigate to database directory\n",
- "cd /usr/local/bin/data/blast_db\n",
- "# \n",
- "# Create protein database\n",
- "makeblastdb -in proteins.fasta -dbtype prot -out my_proteins\n",
- "# \n",
- "# Create nucleotide database \n",
- "makeblastdb -in nucleotides.fasta -dbtype nucl -out my_nucleotides\n",
- "```\n",
- "Make sure your FASTA files are mounted in the container's `/usr/local/bin/data/blast_db` directory.\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Advanced Usage\n",
- "\n",
- "You can customize the BLAST search parameters:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " subject_id | \n",
- " identity | \n",
- " alignment_length | \n",
- " mismatches | \n",
- " gap_opens | \n",
- " query_start | \n",
- " query_end | \n",
- " subject_start | \n",
- " subject_end | \n",
- " evalue | \n",
- " bit_score | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " | 0 | \n",
- " seq7 | \n",
- " 81.818 | \n",
- " 22 | \n",
- " 3 | \n",
- " 1 | \n",
- " 31 | \n",
- " 51 | \n",
- " 11 | \n",
- " 32 | \n",
- " 0.003 | \n",
- " 22.3 | \n",
- "
\n",
- " \n",
- " | 1 | \n",
- " seq1 | \n",
- " 100.000 | \n",
- " 25 | \n",
- " 0 | \n",
- " 0 | \n",
- " 1 | \n",
- " 25 | \n",
- " 1 | \n",
- " 25 | \n",
- " 0.004 | \n",
- " 22.3 | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " subject_id identity alignment_length mismatches gap_opens query_start \\\n",
- "0 seq7 81.818 22 3 1 31 \n",
- "1 seq1 100.000 25 0 0 1 \n",
- "\n",
- " query_end subject_start subject_end evalue bit_score \n",
- "0 51 11 32 0.003 22.3 \n",
- "1 25 1 25 0.004 22.3 "
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# Configure BLAST for sensitive protein search\n",
- "blast = Blast(\n",
- " # service_url=\"http://localhost:6001/blast\",\n",
- " mode=\"blastp\",\n",
- " db_path=\"/usr/local/bin/data/test_db\",\n",
- " db_name=\"protein_db\",\n",
- " evalue=1e-1, # More stringent E-value\n",
- " max_target_seqs=100, # Return more hits\n",
- " num_threads=4, # Use 4 CPU threads\n",
- ")\n",
- "\n",
- "# Search with longer timeout\n",
- "results = blast.search(sequence, timeout=7200) # 2 hour timeout\n",
- "\n",
- "# Filter results\n",
- "significant_hits = results[results[\"identity\"] > 80] # Only hits with >90% identity\n",
- "significant_hits"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Thereafter, the ids of the hits can be added to the pyeed database, using the `fetch_from_primary_db` function."
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "pyeed",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.11.5"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# BLAST Search\n",
+ "\n",
+ "## Setup\n",
+ "\n",
+ "The BLAST service runs in a Docker container and requires:\n",
+ "1. A local BLAST database\n",
+ "2. The Docker service running"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# change log level to INFO\n",
+ "import sys\n",
+ "\n",
+ "from loguru import logger\n",
+ "\n",
+ "logger.remove()\n",
+ "level = logger.add(sys.stderr, level=\"WARNING\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Basic Usage\n",
+ "\n",
+ "The `Blast` class provides an interface to search protein or nucleotide sequences against a local BLAST database."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " subject_id | \n",
+ " identity | \n",
+ " alignment_length | \n",
+ " mismatches | \n",
+ " gap_opens | \n",
+ " query_start | \n",
+ " query_end | \n",
+ " subject_start | \n",
+ " subject_end | \n",
+ " evalue | \n",
+ " bit_score | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " seq7 | \n",
+ " 81.818 | \n",
+ " 22 | \n",
+ " 3 | \n",
+ " 1 | \n",
+ " 31 | \n",
+ " 51 | \n",
+ " 11 | \n",
+ " 32 | \n",
+ " 0.003 | \n",
+ " 22.3 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " seq1 | \n",
+ " 100.000 | \n",
+ " 25 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 25 | \n",
+ " 1 | \n",
+ " 25 | \n",
+ " 0.004 | \n",
+ " 22.3 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " seq2 | \n",
+ " 61.538 | \n",
+ " 26 | \n",
+ " 10 | \n",
+ " 0 | \n",
+ " 20 | \n",
+ " 45 | \n",
+ " 5 | \n",
+ " 30 | \n",
+ " 0.038 | \n",
+ " 19.2 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " subject_id identity alignment_length mismatches gap_opens query_start \\\n",
+ "0 seq7 81.818 22 3 1 31 \n",
+ "1 seq1 100.000 25 0 0 1 \n",
+ "2 seq2 61.538 26 10 0 20 \n",
+ "\n",
+ " query_end subject_start subject_end evalue bit_score \n",
+ "0 51 11 32 0.003 22.3 \n",
+ "1 25 1 25 0.004 22.3 \n",
+ "2 45 5 30 0.038 19.2 "
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from pyeed.tools import Blast\n",
+ "\n",
+ "# Example protein sequence\n",
+ "sequence = \"MSEQVAAVAKLRAKASEAAKEAKAREAAKKLAEAAKKAKAKEAAKRAEAKLAEKAKAAKRAEAKAAKEAKRAAAKRAEAKLAEKAKAAK\"\n",
+ "\n",
+ "# Initialize BLAST search\n",
+ "blast = Blast(\n",
+ " # service_url=\"http://localhost:6001/blast\",\n",
+ " mode=\"blastp\", # Use blastp for protein sequences\n",
+ " db_path=\"/usr/local/bin/data/test_db\", # Path in Docker container\n",
+ " db_name=\"protein_db\", # Name of your BLAST database\n",
+ " evalue=0.1, # E-value threshold\n",
+ " max_target_seqs=10, # Maximum number of hits to return\n",
+ ")\n",
+ "\n",
+ "# Perform search\n",
+ "results = blast.search(sequence)\n",
+ "results"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The results are returned as a pandas DataFrame with the following columns:\n",
+ "- subject_id: ID of the matched sequence\n",
+ "- identity: Percentage identity\n",
+ "- alignment_length: Length of the alignment\n",
+ "- mismatches: Number of mismatches\n",
+ "- gap_opens: Number of gap openings\n",
+ "- query_start/end: Start/end positions in query sequence\n",
+ "- subject_start/end: Start/end positions in subject sequence\n",
+ "- evalue: Expectation value\n",
+ "- bit_score: Bit score"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Creating a BLAST Database\n",
+ "\n",
+ "Before using BLAST, you need to create a local database. Here's how to create one from a FASTA file:"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "```bash\n",
+ "# For protein sequences\n",
+ "makeblastdb -in proteins.fasta -dbtype prot -out blast_db/my_proteins\n",
+ "\n",
+ "# For nucleotide sequences\n",
+ "makeblastdb -in nucleotides.fasta -dbtype nucl -out blast_db/my_nucleotides\n",
+ "```\n",
+ "\n",
+ "To access the BLAST Docker container shell and create databases:\n",
+ "\n",
+ "```bash\n",
+ "# Enter the BLAST container shell\n",
+ "docker compose exec blast bash\n",
+ "# \n",
+ "# Navigate to database directory\n",
+ "cd /usr/local/bin/data/blast_db\n",
+ "# \n",
+ "# Create protein database\n",
+ "makeblastdb -in proteins.fasta -dbtype prot -out my_proteins\n",
+ "# \n",
+ "# Create nucleotide database \n",
+ "makeblastdb -in nucleotides.fasta -dbtype nucl -out my_nucleotides\n",
+ "```\n",
+ "Make sure your FASTA files are mounted in the container's `/usr/local/bin/data/blast_db` directory.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Advanced Usage\n",
+ "\n",
+ "You can customize the BLAST search parameters:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " subject_id | \n",
+ " identity | \n",
+ " alignment_length | \n",
+ " mismatches | \n",
+ " gap_opens | \n",
+ " query_start | \n",
+ " query_end | \n",
+ " subject_start | \n",
+ " subject_end | \n",
+ " evalue | \n",
+ " bit_score | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " seq7 | \n",
+ " 81.818 | \n",
+ " 22 | \n",
+ " 3 | \n",
+ " 1 | \n",
+ " 31 | \n",
+ " 51 | \n",
+ " 11 | \n",
+ " 32 | \n",
+ " 0.003 | \n",
+ " 22.3 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " seq1 | \n",
+ " 100.000 | \n",
+ " 25 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 25 | \n",
+ " 1 | \n",
+ " 25 | \n",
+ " 0.004 | \n",
+ " 22.3 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " subject_id identity alignment_length mismatches gap_opens query_start \\\n",
+ "0 seq7 81.818 22 3 1 31 \n",
+ "1 seq1 100.000 25 0 0 1 \n",
+ "\n",
+ " query_end subject_start subject_end evalue bit_score \n",
+ "0 51 11 32 0.003 22.3 \n",
+ "1 25 1 25 0.004 22.3 "
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Configure BLAST for sensitive protein search\n",
+ "blast = Blast(\n",
+ " # service_url=\"http://localhost:6001/blast\",\n",
+ " mode=\"blastp\",\n",
+ " db_path=\"/usr/local/bin/data/test_db\",\n",
+ " db_name=\"protein_db\",\n",
+ " evalue=1e-1, # More stringent E-value\n",
+ " max_target_seqs=100, # Return more hits\n",
+ " num_threads=4, # Use 4 CPU threads\n",
+ ")\n",
+ "\n",
+ "# Search with longer timeout\n",
+ "results = blast.search(sequence, timeout=7200) # 2 hour timeout\n",
+ "\n",
+ "# Filter results\n",
+ "significant_hits = results[results[\"identity\"] > 80] # Only hits with >90% identity\n",
+ "significant_hits"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Thereafter, the ids of the hits can be added to the pyeed database, using the `fetch_from_primary_db` function."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "pyeed",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
}
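
The notebook above ends by noting that hit IDs can be fed back into pyeed. As a minimal end-to-end sketch (an assumption here: the BLAST database stores NCBI accessions as `subject_id`; connection parameters are the ones used elsewhere in these docs):

```python
from pyeed import Pyeed
from pyeed.tools import Blast

# Run the search as in the notebook above.
blast = Blast(
    mode="blastp",
    db_path="/usr/local/bin/data/test_db",
    db_name="protein_db",
)
results = blast.search(sequence)  # `sequence` as defined in the notebook

# Keep only confident hits, then import them into the graph database.
hit_ids = results[results["identity"] > 80]["subject_id"].tolist()

eedb = Pyeed("bolt://localhost:7687", user="neo4j", password="12345678")
eedb.fetch_from_primary_db(hit_ids, db="ncbi_protein")
```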
diff --git a/docs/usage/clustalo.ipynb b/docs/usage/clustalo.ipynb
index 64ed62ee..d3ba2fba 100644
--- a/docs/usage/clustalo.ipynb
+++ b/docs/usage/clustalo.ipynb
@@ -1,171 +1,171 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Multiple Sequence Alignment with Clustal Omega\n",
- "\n",
- "PyEED provides a convenient interface to Clustal Omega for multiple sequence alignment. This notebook demonstrates how to:\n",
- "1. Align sequences from a dictionary\n",
- "2. Align sequences directly from the database"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [],
- "source": [
- "from pyeed import Pyeed\n",
- "from pyeed.tools.clustalo import ClustalOmega\n",
- "\n",
- "# change log level to INFO\n",
- "import sys\n",
- "from loguru import logger\n",
- "\n",
- "logger.remove()\n",
- "level = logger.add(sys.stderr, level=\"INFO\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Direct Sequence Alignment\n",
- "\n",
- "You can align sequences directly by providing a dictionary of sequences:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Aligned sequences:\n",
- "seq1 AKFVMPDRAWHLYTGNECSKQRLYVWFHDGAPILKTQSDNMGAYRCPLFHVTKNWEI\n",
- "seq2 AKFVMPDRQWHLYTGQECSKQRLYVWFHDGAPILKTQSDNMGAYRCPLFHVTKNWEI\n",
- "seq3 AKFVMPDRQWHLYTGNECSKQRLYVWFHDGAPILKTQADNMGAYRCALFHVTK----\n"
- ]
- }
- ],
- "source": [
- "# Initialize ClustalOmega\n",
- "clustalo = ClustalOmega()\n",
- "\n",
- "# Example sequences\n",
- "sequences = {\n",
- " \"seq1\": \"AKFVMPDRAWHLYTGNECSKQRLYVWFHDGAPILKTQSDNMGAYRCPLFHVTKNWEI\",\n",
- " \"seq2\": \"AKFVMPDRQWHLYTGQECSKQRLYVWFHDGAPILKTQSDNMGAYRCPLFHVTKNWEI\",\n",
- " \"seq3\": \"AKFVMPDRQWHLYTGNECSKQRLYVWFHDGAPILKTQADNMGAYRCALFHVTK\",\n",
- "}\n",
- "\n",
- "# Perform alignment\n",
- "alignment = clustalo.align(sequences)\n",
- "print(\"Aligned sequences:\")\n",
- "print(alignment)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Database-based Alignment\n",
- "\n",
- "You can also align sequences directly from the database by providing a list of accession IDs:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Pyeed Graph Object Mapping constraints not defined. Use _install_labels() to set up model constraints.\n",
- "📡 Connected to database.\n",
- "Database alignment:\n",
- "AAP20891.1 MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW\n",
- "CAJ85677.1 MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDKLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW\n",
- "SAQ02853.1 MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDKLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVKYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGASERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW\n",
- "CDR98216.1 MSIQHFRVALIPFFAAFCFPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVKYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGASERGSRGIIAALGPDGKPSRIVVIYMTGSQATMDERNRQIAEIGASLIKHW\n",
- "WP_109963600.1 MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDSWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGTGKRGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW\n",
- "CAA41038.1 MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVKYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDHWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW\n",
- "WP_109874025.1 MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDKLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDSWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW\n",
- "CAA46344.1 MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDKLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVKYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDSWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGASERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW\n",
- "APG33178.1 MSIQHFRVALIPFFAAFCFPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVKYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDSWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYMTGSQATMDERNRQIAEIGASLIKHW\n",
- "AKC98298.1 MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDKLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDHWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW\n"
- ]
- }
- ],
- "source": [
- "# Connect to database\n",
- "pyeed = Pyeed(uri=\"bolt://129.69.129.130:7687\", user=\"neo4j\", password=\"12345678\")\n",
- "\n",
- "# Get protein IDs from database\n",
- "from pyeed.model import Protein\n",
- "\n",
- "accession_ids = [protein.accession_id for protein in Protein.nodes.all()][:10]\n",
- "\n",
- "# Align sequences from database\n",
- "alignment = clustalo.align_from_db(accession_ids, pyeed.db)\n",
- "print(\"Database alignment:\")\n",
- "print(alignment)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Understanding Alignment Results\n",
- "\n",
- "The alignment result is a `MultipleSequenceAlignment` object with:\n",
- "- List of `Sequence` objects\n",
- "- Each sequence has an ID and aligned sequence\n",
- "- Gaps are represented by '-' characters\n",
- "- Sequences are padded to equal length\n",
- "\n",
- "The alignment preserves sequence order and maintains sequence IDs from the input."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Configuration\n",
- "\n",
- "ClustalOmega requires the PyEED Docker service to be running. Make sure to:\n",
- "1. Have Docker installed\n",
- "2. Start the service with `docker-compose up -d`\n",
- "3. The service runs on port 5001 by default"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "pyeed_niklas",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.8"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Multiple Sequence Alignment with Clustal Omega\n",
+ "\n",
+ "PyEED provides a convenient interface to Clustal Omega for multiple sequence alignment. This notebook demonstrates how to:\n",
+ "1. Align sequences from a dictionary\n",
+ "2. Align sequences directly from the database"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# change log level to INFO\n",
+ "import sys\n",
+ "\n",
+ "from loguru import logger\n",
+ "\n",
+ "from pyeed import Pyeed\n",
+ "from pyeed.model import Protein\n",
+ "from pyeed.tools.clustalo import ClustalOmega\n",
+ "\n",
+ "logger.remove()\n",
+ "level = logger.add(sys.stderr, level=\"INFO\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Direct Sequence Alignment\n",
+ "\n",
+ "You can align sequences directly by providing a dictionary of sequences:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Aligned sequences:\n",
+ "seq1 AKFVMPDRAWHLYTGNECSKQRLYVWFHDGAPILKTQSDNMGAYRCPLFHVTKNWEI\n",
+ "seq2 AKFVMPDRQWHLYTGQECSKQRLYVWFHDGAPILKTQSDNMGAYRCPLFHVTKNWEI\n",
+ "seq3 AKFVMPDRQWHLYTGNECSKQRLYVWFHDGAPILKTQADNMGAYRCALFHVTK----\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Initialize ClustalOmega\n",
+ "clustalo = ClustalOmega()\n",
+ "\n",
+ "# Example sequences\n",
+ "sequences = {\n",
+ " \"seq1\": \"AKFVMPDRAWHLYTGNECSKQRLYVWFHDGAPILKTQSDNMGAYRCPLFHVTKNWEI\",\n",
+ " \"seq2\": \"AKFVMPDRQWHLYTGQECSKQRLYVWFHDGAPILKTQSDNMGAYRCPLFHVTKNWEI\",\n",
+ " \"seq3\": \"AKFVMPDRQWHLYTGNECSKQRLYVWFHDGAPILKTQADNMGAYRCALFHVTK\",\n",
+ "}\n",
+ "\n",
+ "# Perform alignment\n",
+ "alignment = clustalo.align(sequences)\n",
+ "print(\"Aligned sequences:\")\n",
+ "print(alignment)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Database-based Alignment\n",
+ "\n",
+ "You can also align sequences directly from the database by providing a list of accession IDs:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Pyeed Graph Object Mapping constraints not defined. Use _install_labels() to set up model constraints.\n",
+ "📡 Connected to database.\n",
+ "Database alignment:\n",
+ "AAP20891.1 MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW\n",
+ "CAJ85677.1 MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDKLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW\n",
+ "SAQ02853.1 MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDKLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVKYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGASERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW\n",
+ "CDR98216.1 MSIQHFRVALIPFFAAFCFPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVKYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGASERGSRGIIAALGPDGKPSRIVVIYMTGSQATMDERNRQIAEIGASLIKHW\n",
+ "WP_109963600.1 MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDSWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGTGKRGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW\n",
+ "CAA41038.1 MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVKYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDHWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW\n",
+ "WP_109874025.1 MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDKLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDSWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW\n",
+ "CAA46344.1 MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDKLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVKYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDSWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGASERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW\n",
+ "APG33178.1 MSIQHFRVALIPFFAAFCFPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVKYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDSWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYMTGSQATMDERNRQIAEIGASLIKHW\n",
+ "AKC98298.1 MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDKLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDHWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Connect to database\n",
+ "pyeed = Pyeed(uri=\"bolt://129.69.129.130:7687\", user=\"neo4j\", password=\"12345678\")\n",
+ "\n",
+ "# Get protein IDs from database\n",
+ "accession_ids = [protein.accession_id for protein in Protein.nodes.all()][:10]\n",
+ "\n",
+ "# Align sequences from database\n",
+ "alignment = clustalo.align_from_db(accession_ids, pyeed.db)\n",
+ "print(\"Database alignment:\")\n",
+ "print(alignment)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Understanding Alignment Results\n",
+ "\n",
+ "The alignment result is a `MultipleSequenceAlignment` object with:\n",
+ "- List of `Sequence` objects\n",
+ "- Each sequence has an ID and aligned sequence\n",
+ "- Gaps are represented by '-' characters\n",
+ "- Sequences are padded to equal length\n",
+ "\n",
+ "The alignment preserves sequence order and maintains sequence IDs from the input."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Configuration\n",
+ "\n",
+ "ClustalOmega requires the PyEED Docker service to be running. Make sure to:\n",
+ "1. Have Docker installed\n",
+ "2. Start the service with `docker-compose up -d`\n",
+ "3. The service runs on port 5001 by default"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "pyeed_niklas",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
}
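
The "Understanding Alignment Results" cell describes the returned `MultipleSequenceAlignment` as a list of `Sequence` objects, each carrying an ID and a gapped sequence. A hedged sketch of inspecting it (the attribute names `sequences`, `id`, and `sequence` are assumptions; the notebook only states what the object contains):

```python
from pyeed.tools.clustalo import ClustalOmega

clustalo = ClustalOmega()
alignment = clustalo.align({
    "seq1": "AKFVMPDRAWHLYTGNECSKQRLYVWFHDGAPILKTQSDNMGAYRCPLFHVTKNWEI",
    "seq3": "AKFVMPDRQWHLYTGNECSKQRLYVWFHDGAPILKTQADNMGAYRCALFHVTK",
})

# Assumed attribute names -- adapt to the actual Sequence fields.
for seq in alignment.sequences:
    print(seq.id, len(seq.sequence), seq.sequence.count("-"), "gap(s)")
```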
diff --git a/docs/usage/embeddings_analysis.ipynb b/docs/usage/embeddings_analysis.ipynb
index 65a2398c..0b72e743 100644
--- a/docs/usage/embeddings_analysis.ipynb
+++ b/docs/usage/embeddings_analysis.ipynb
@@ -24,9 +24,10 @@
"source": [
"import sys\n",
"\n",
- "from loguru import logger\n",
- "import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
+ "import pandas as pd\n",
+ "from loguru import logger\n",
+ "\n",
"from pyeed import Pyeed\n",
"from pyeed.analysis.embedding_analysis import EmbeddingTool\n",
"\n",
diff --git a/docs/usage/mmseqs.ipynb b/docs/usage/mmseqs.ipynb
index 1185c6fe..2253fd8a 100644
--- a/docs/usage/mmseqs.ipynb
+++ b/docs/usage/mmseqs.ipynb
@@ -20,6 +20,7 @@
"outputs": [],
"source": [
"from pyeed import Pyeed\n",
+ "from pyeed.model import Protein\n",
"from pyeed.tools.mmseqs import MMSeqs"
]
},
@@ -134,8 +135,6 @@
"pyeed = Pyeed(uri=\"bolt://localhost:7687\", user=\"neo4j\", password=\"12345678\")\n",
"\n",
"# Get first 100 protein IDs from database\n",
- "from pyeed.model import Protein\n",
- "\n",
"accession_ids = [protein.accession_id for protein in Protein.nodes.all()][:100]\n",
"\n",
"# Cluster sequences\n",
diff --git a/docs/usage/mutation_analysis.ipynb b/docs/usage/mutation_analysis.ipynb
index 9ccabc1c..7d10d360 100644
--- a/docs/usage/mutation_analysis.ipynb
+++ b/docs/usage/mutation_analysis.ipynb
@@ -11,11 +11,12 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
+ "\n",
"from loguru import logger\n",
"\n",
"from pyeed import Pyeed\n",
@@ -37,7 +38,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 10,
"metadata": {},
"outputs": [
{
@@ -56,7 +57,7 @@
"\n",
"eedb = Pyeed(uri, user=user, password=password)\n",
"\n",
- "eedb.db.wipe_database(date=\"2025-03-14\")"
+ "eedb.db.wipe_database(date=\"2025-03-19\")"
]
},
{
@@ -75,14 +76,15 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"ids = [\"AAM15527.1\", \"AAF05614.1\", \"AFN21551.1\", \"CAA76794.1\", \"AGQ50511.1\"]\n",
"\n",
"eedb.fetch_from_primary_db(ids, db=\"ncbi_protein\")\n",
- "eedb.fetch_dna_entries_for_proteins()"
+ "eedb.fetch_dna_entries_for_proteins()\n",
+ "eedb.create_coding_sequences_regions()"
]
},
{
@@ -100,9 +102,42 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 12,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "6ed852d438ab480fa4d1c6129eacfd26",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Region ids: [143, 129, 128, 69, 9]\n",
+ "len of ids: 5\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
"source": [
"sn_protein = StandardNumberingTool(name=\"test_standard_numbering_protein\")\n",
"\n",
@@ -111,11 +146,22 @@
" base_sequence_id=\"AAM15527.1\", db=eedb.db, list_of_seq_ids=ids\n",
")\n",
"\n",
- "sn_dna = StandardNumberingTool(name=\"test_standard_numbering_dna\")\n",
+ "sn_dna = StandardNumberingTool(name=\"test_standard_numbering_dna_pairwise\")\n",
"\n",
- "sn_dna.apply_standard_numbering(\n",
- " base_sequence_id=\"AF190695.1\", db=eedb.db, node_type=\"DNA\"\n",
- ")\n"
+ "query_get_region_ids = \"\"\"\n",
+ "MATCH (p:Protein)<-[rel:ENCODES]-(d:DNA)-[rel2:HAS_REGION]->(r:Region)\n",
+ "WHERE r.annotation = $region_annotation AND p.accession_id IN $protein_id\n",
+ "RETURN id(r)\n",
+ "\"\"\"\n",
+ "\n",
+ "region_ids = eedb.db.execute_read(query_get_region_ids, parameters={\"protein_id\": ids, \"region_annotation\": \"coding sequence\"})\n",
+ "region_ids = [id['id(r)'] for id in region_ids]\n",
+ "print(f\"Region ids: {region_ids}\")\n",
+ "print(f\"len of ids: {len(ids)}\")\n",
+ "\n",
+ "sn_dna.apply_standard_numbering_pairwise(\n",
+ " base_sequence_id=\"AF190695.1\", db=eedb.db, node_type=\"DNA\", region_ids_neo4j=region_ids\n",
+ ")"
]
},
{
@@ -136,7 +182,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@@ -153,18 +199,19 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"md = MutationDetection()\n",
"\n",
+ "\n",
"seq1 = \"AF190695.1\"\n",
"seq2 = \"JX042489.1\"\n",
- "name_of_standard_numbering_tool = \"test_standard_numbering_dna\"\n",
+ "name_of_standard_numbering_tool = \"test_standard_numbering_dna_pairwise\"\n",
"\n",
"mutations_dna = md.get_mutations_between_sequences(\n",
- " seq1, seq2, eedb.db, name_of_standard_numbering_tool, node_type=\"DNA\"\n",
+ " seq1, seq2, eedb.db, name_of_standard_numbering_tool, node_type=\"DNA\", region_ids_neo4j=region_ids\n",
")"
]
},
@@ -183,14 +230,14 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "{'from_positions': [241, 272, 125], 'to_positions': [241, 272, 125], 'from_monomers': ['R', 'D', 'V'], 'to_monomers': ['S', 'N', 'I']}\n"
+ "{'from_positions': [241, 125, 272], 'to_positions': [241, 125, 272], 'from_monomers': ['R', 'V', 'D'], 'to_monomers': ['S', 'I', 'N']}\n"
]
}
],
@@ -216,29 +263,21 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Mutation on position 682 -> 615 with a nucleotide change of T -> C\n",
- "Mutation on position 407 -> 340 with a nucleotide change of C -> A\n",
- "Mutation on position 92 -> 25 with a nucleotide change of C -> A\n",
- "Mutation on position 162 -> 95 with a nucleotide change of G -> T\n",
- "Mutation on position 929 -> 862 with a nucleotide change of A -> C\n",
- "Mutation on position 346 -> 279 with a nucleotide change of A -> G\n",
- "Mutation on position 87 -> 20 with a nucleotide change of C -> A\n",
- "Mutation on position 88 -> 21 with a nucleotide change of T -> C\n",
- "Mutation on position 130 -> 63 with a nucleotide change of C -> T\n",
- "Mutation on position 175 -> 108 with a nucleotide change of G -> A\n",
- "Mutation on position 131 -> 64 with a nucleotide change of T -> C\n",
- "Mutation on position 132 -> 65 with a nucleotide change of A -> T\n",
- "Mutation on position 914 -> 847 with a nucleotide change of G -> A\n",
- "Mutation on position 604 -> 537 with a nucleotide change of T -> G\n",
- "Mutation on position 925 -> 858 with a nucleotide change of G -> A\n",
- "Mutation on position 226 -> 159 with a nucleotide change of T -> C\n"
+ "Mutation on position 705 -> 705 with a nucleotide change of G -> A\n",
+ "Mutation on position 395 -> 395 with a nucleotide change of T -> G\n",
+ "Mutation on position 137 -> 137 with a nucleotide change of A -> G\n",
+ "Mutation on position 17 -> 17 with a nucleotide change of T -> C\n",
+ "Mutation on position 473 -> 473 with a nucleotide change of T -> C\n",
+ "Mutation on position 716 -> 716 with a nucleotide change of G -> A\n",
+ "Mutation on position 720 -> 720 with a nucleotide change of A -> C\n",
+ "Mutation on position 198 -> 198 with a nucleotide change of C -> A\n"
]
}
],
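
The updated notebook threads the same `region_ids_neo4j` list through both the standard-numbering and mutation-detection steps. Condensed, the new region-based flow looks like this (reusing `eedb` and `region_ids` from the cells above; all calls and dict keys appear in this diff):

```python
from pyeed.analysis.mutation_detection import MutationDetection

md = MutationDetection()

# Mutations between two DNA entries, restricted to their coding-sequence regions.
mutations_dna = md.get_mutations_between_sequences(
    "AF190695.1",
    "JX042489.1",
    eedb.db,
    "test_standard_numbering_dna_pairwise",
    node_type="DNA",
    region_ids_neo4j=region_ids,
)

# The result dict uses the keys shown in the notebook output.
print(mutations_dna["from_positions"], mutations_dna["to_monomers"])
```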
diff --git a/docs/usage/network_analysis.ipynb b/docs/usage/network_analysis.ipynb
index 4d45db71..0b254610 100644
--- a/docs/usage/network_analysis.ipynb
+++ b/docs/usage/network_analysis.ipynb
@@ -11,6 +11,7 @@
"import matplotlib.pyplot as plt\n",
"import networkx as nx\n",
"from loguru import logger\n",
+ "\n",
"from pyeed import Pyeed\n",
"from pyeed.analysis.network_analysis import NetworkAnalysis\n",
"from pyeed.analysis.sequence_alignment import PairwiseAligner\n",
diff --git a/docs/usage/standard_numbering.ipynb b/docs/usage/standard_numbering.ipynb
index d2132d3c..54374cd6 100644
--- a/docs/usage/standard_numbering.ipynb
+++ b/docs/usage/standard_numbering.ipynb
@@ -16,17 +16,17 @@
},
{
"cell_type": "code",
- "execution_count": 36,
+ "execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"%reload_ext autoreload\n",
"%autoreload 2\n",
"import sys\n",
+ "\n",
"from loguru import logger\n",
"\n",
"from pyeed import Pyeed\n",
- "from pyeed.analysis.mutation_detection import MutationDetection\n",
"from pyeed.analysis.standard_numbering import StandardNumberingTool\n",
"\n",
"logger.remove()\n",
@@ -35,7 +35,7 @@
},
{
"cell_type": "code",
- "execution_count": 37,
+ "execution_count": 20,
"metadata": {},
"outputs": [
{
@@ -94,32 +94,33 @@
"password = \"12345678\"\n",
"\n",
"eedb = Pyeed(uri, user=user, password=password)\n",
- "eedb.db.wipe_database(date=\"2025-03-14\")\n",
+ "eedb.db.wipe_database(date=\"2025-03-19\")\n",
"\n",
"eedb.db.initialize_db_constraints(user=user, password=password)\n"
]
},
{
"cell_type": "code",
- "execution_count": 38,
+ "execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"ids = [\"AAM15527.1\", \"AAF05614.1\", \"AFN21551.1\", \"CAA76794.1\", \"AGQ50511.1\"]\n",
"\n",
"eedb.fetch_from_primary_db(ids, db=\"ncbi_protein\")\n",
- "eedb.fetch_dna_entries_for_proteins()"
+ "eedb.fetch_dna_entries_for_proteins()\n",
+ "eedb.create_coding_sequences_regions()"
]
},
{
"cell_type": "code",
- "execution_count": 39,
+ "execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "6470d4b80eb648e1af401b2e59cbe95b",
+ "model_id": "0f961c177f1444fb8190669487a1cb89",
"version_major": 2,
"version_minor": 0
},
@@ -152,13 +153,13 @@
},
{
"cell_type": "code",
- "execution_count": 40,
+ "execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "47406b43c98e4b31ba41eb15f7cdd000",
+ "model_id": "b7c38c15de4c4fa2bcf3f0a223d527b0",
"version_major": 2,
"version_minor": 0
},
@@ -188,7 +189,7 @@
},
{
"cell_type": "code",
- "execution_count": 41,
+ "execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
@@ -201,7 +202,7 @@
},
{
"cell_type": "code",
- "execution_count": 42,
+ "execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
@@ -214,13 +215,13 @@
},
{
"cell_type": "code",
- "execution_count": 43,
+ "execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "6a3fb35a08714174b353558dedff592c",
+ "model_id": "b204fcf51571421b8fff36de4e9ba9dd",
"version_major": 2,
"version_minor": 0
},
@@ -250,6 +251,68 @@
")"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "526ca870c8fb4b76b2df332a4b06af18",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Region ids: [13, 0, 41, 38, 19]\n",
+ "len of ids: 5\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "sn_dna_region = StandardNumberingTool(name=\"test_standard_numbering_dna_pairwise_region\")\n",
+ "\n",
+ "\n",
+ "ids = [\"AAM15527.1\", \"AAF05614.1\", \"AFN21551.1\", \"CAA76794.1\", \"AGQ50511.1\"]\n",
+ "\n",
+ "\n",
+ "query_get_region_ids = \"\"\"\n",
+ "MATCH (p:Protein)<-[rel:ENCODES]-(d:DNA)-[rel2:HAS_REGION]->(r:Region)\n",
+ "WHERE r.annotation = $region_annotation AND p.accession_id IN $protein_id\n",
+ "RETURN id(r)\n",
+ "\"\"\"\n",
+ "\n",
+ "region_ids = eedb.db.execute_read(query_get_region_ids, parameters={\"protein_id\": ids, \"region_annotation\": \"coding sequence\"})\n",
+ "region_ids = [id['id(r)'] for id in region_ids]\n",
+ "print(f\"Region ids: {region_ids}\")\n",
+ "print(f\"len of ids: {len(ids)}\")\n",
+ "\n",
+ "\n",
+ "sn_dna_region.apply_standard_numbering_pairwise(\n",
+ " base_sequence_id=\"AF190695.1\", db=eedb.db, node_type=\"DNA\", region_ids_neo4j=region_ids\n",
+ ")"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
diff --git a/pyproject.toml b/pyproject.toml
index dd10629a..b9897071 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,6 +36,10 @@ esm = "^3.1.3"
rdflib = "^6.0.0"
docker = "5.0.0"
absl-py = "1.0.0"
+crc64iso = "0.0.2"
+SPARQLWrapper = "2.0.0"
+pysam = "0.23.0"
+types-requests = "2.32.0.20250328"
[tool.poetry.group.dev.dependencies]
mkdocstrings = {extras = ["python"], version = "^0.26.2"}
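
`crc64iso` and `pysam` back the new NCBI-to-UniParc mapping introduced below: UniParc lookups are keyed on a CRC64-ISO checksum of the raw sequence. A minimal sketch mirroring the usage in `ncbi_to_uniprot_mapper.py`:

```python
from crc64iso import crc64iso

# CRC64-ISO digest of an amino-acid string, as passed to the UniParc
# `sequencechecksum` query parameter below.
seq = "MSIQHFRVALIPFFAAFCLPVFA"
print(crc64iso.crc64(seq))
```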
diff --git a/src/pyeed/adapter/ncbi_protein_mapper.py b/src/pyeed/adapter/ncbi_protein_mapper.py
index e11d4fe7..3ecf485c 100644
--- a/src/pyeed/adapter/ncbi_protein_mapper.py
+++ b/src/pyeed/adapter/ncbi_protein_mapper.py
@@ -281,6 +281,8 @@ def add_to_db(self, response: Response) -> None:
protein = Protein(**protein_data)
protein.save()
+ if not isinstance(organism, Organism):
+ raise TypeError(f"Expected Organism, but got {type(organism)}")
protein.organism.connect(organism)
# Add features
diff --git a/src/pyeed/adapter/ncbi_to_uniprot_mapper.py b/src/pyeed/adapter/ncbi_to_uniprot_mapper.py
new file mode 100644
index 00000000..1373547a
--- /dev/null
+++ b/src/pyeed/adapter/ncbi_to_uniprot_mapper.py
@@ -0,0 +1,132 @@
+import json
+import logging
+import os
+import sys
+from typing import List
+
+import httpx
+from crc64iso import crc64iso
+from pysam import FastaFile
+
+logger = logging.getLogger(__name__)
+
+
+class NCBIToUniprotMapper:
+ def __init__(self, ids: List[str], file: str):
+ self.ids = ids
+ self.file = file
+ self.uniparc_url = "https://www.ebi.ac.uk/proteins/api/uniparc?offset=0&size=100&sequencechecksum="
+ self.ncbi_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
+
+ def download_fasta(self, refseq_id: str) -> None:
+ """
+ Downloads a FASTA file for a given RefSeq ID using httpx and saves it locally.
+
+ Args:
+ refseq_id (str): NCBI ID
+ """
+
+ params = {
+ "db": "protein",
+ "id": refseq_id,
+ "rettype": "fasta",
+ "retmode": "text",
+ }
+
+ try:
+ response = httpx.get(self.ncbi_url, params=params, timeout=10.0)
+
+ if response.status_code == 200:
+ filename = f"{refseq_id}.fasta"
+ with open(filename, "w") as f:
+ f.write(response.text)
+ print(f"✅ Downloaded: {filename}")
+ else:
+ print(
+ f"❌ Failed to download {refseq_id} (Status: {response.status_code})"
+ )
+
+ except httpx.HTTPError as e:
+ print(f"❌ HTTP error occurred while downloading {refseq_id}: {e}")
+
+ def get_checksum(self, refseq_id: str) -> str:
+ """Fetches and calculates the checksum for a given RefSeq ID.
+
+ Args:
+ refseq_id (str): NCBI ID
+
+ Returns:
+ str: checksum ID
+ """
+
+ self.download_fasta(refseq_id)
+ fa = FastaFile(f"{refseq_id}.fasta")
+ seq = fa.fetch(fa.references[0])
+ return f"{crc64iso.crc64(seq)}"
+
+ def checksum_list(self, refseq_ids: List[str]) -> List[str]:
+ """Creates a list of checksum IDs and deletes the FASTA files after processing.
+
+ Args:
+ refseq_ids (List[str]): NCBI IDs
+
+ Returns:
+ List[str]: checksum IDs
+ """
+
+ checksums = []
+ for refseq_id in refseq_ids:
+ checksums.append(self.get_checksum(refseq_id))
+ fasta_file_path = f"{refseq_id}.fasta"
+ fai_file_path = f"{refseq_id}.fasta.fai"
+
+ if os.path.exists(fasta_file_path):
+ os.remove(fasta_file_path) # Delete the fasta file
+
+ if os.path.exists(fai_file_path):
+ os.remove(fai_file_path)
+ return checksums
+
+ def execute_request(self) -> None:
+ """Fetches the uniparc and uniprot ids for the given refseq ids and saves them in a json file."""
+
+ checksum_list = self.checksum_list(self.ids)
+
+ id_mapping_uniprot = {}
+ id_mapping_uniparc = {}
+ counter = 0
+
+ for checksum in checksum_list:
+ url = f"{self.uniparc_url}{checksum}"
+
+ # perform request and get response as JSON
+ with httpx.Client() as client:
+ response = client.get(url, headers={"Accept": "application/json"})
+
+ # check if the request was successful
+ if response.status_code != 200:
+ print(f"Request failed with status code {response.status_code}")
+ response.raise_for_status() # Raise exception for any non-200 response
+ sys.exit()
+
+ # Check if the response body is empty
+ if not response.content.strip(): # Check if the body is empty
+ print("The response body is empty.")
+ sys.exit()
+
+ # extract the UniProt and UniParc IDs from the response and save them in dictionaries
+ response_body = response.json()
+ for item in response_body:
+ uniparc_id = item.get("accession", None)
+ uniprot_id = None  # stays None if no UniProtKB/TrEMBL cross-reference is found
+ for ref in item.get("dbReference", []):
+ if ref.get("type") == "UniProtKB/TrEMBL":
+ uniprot_id = ref.get("id", None)
+ id_mapping_uniparc[self.ids[counter]] = uniparc_id
+ id_mapping_uniprot[self.ids[counter]] = uniprot_id
+ counter += 1
+
+ with open(f"{self.file}_uniprot.json", "w") as f:
+ json.dump(id_mapping_uniprot, f)
+
+ with open(f"{self.file}_uniparc.json", "w") as f:
+ json.dump(id_mapping_uniparc, f)
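
A usage sketch for the new mapper: given a list of NCBI protein accessions and an output prefix, `execute_request` writes one JSON file for the UniProt mapping and one for the UniParc mapping (per the code above):

```python
from pyeed.adapter.ncbi_to_uniprot_mapper import NCBIToUniprotMapper

mapper = NCBIToUniprotMapper(
    ids=["AAM15527.1", "AAF05614.1"],  # NCBI protein accessions
    file="mapping",                    # output file prefix
)
mapper.execute_request()
# Produces mapping_uniprot.json and mapping_uniparc.json
```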
diff --git a/src/pyeed/adapter/uniprot_mapper.py b/src/pyeed/adapter/uniprot_mapper.py
index 5a285adb..f52d01b3 100644
--- a/src/pyeed/adapter/uniprot_mapper.py
+++ b/src/pyeed/adapter/uniprot_mapper.py
@@ -1,17 +1,21 @@
import json
from collections import defaultdict
-from typing import Any
+from typing import Any, List, Optional
+import requests
+from bs4 import BeautifulSoup, Tag
from httpx import Response
from loguru import logger
+from SPARQLWrapper import JSON, SPARQLWrapper
from pyeed.adapter.primary_db_adapter import PrimaryDBMapper
from pyeed.model import (
Annotation,
- CatalyticActivity,
GOAnnotation,
+ Molecule,
Organism,
Protein,
+ Reaction,
Site,
)
@@ -57,9 +61,9 @@ def add_to_db(self, response: Response) -> None:
return
protein.organism.connect(organism)
+ self.add_reaction(record, protein)
self.add_sites(record, protein)
- self.add_catalytic_activity(record, protein)
self.add_go(record, protein)
def add_sites(self, record: dict[str, Any], protein: Protein) -> None:
@@ -79,22 +83,137 @@ def add_sites(self, record: dict[str, Any], protein: Protein) -> None:
protein.site.connect(site, {"positions": positions})
- def add_catalytic_activity(self, record: dict[str, Any], protein: Protein) -> None:
- try:
- for reference in record["comments"]:
- if reference["type"] == "CATALYTIC_ACTIVITY":
- catalytic_annotation = CatalyticActivity.get_or_save(
- catalytic_id=int(reference["id"])
- if reference.get("id")
- else None,
- name=reference["reaction"]["name"],
- )
- protein.catalytic_annotation.connect(catalytic_annotation)
-
- except Exception as e:
- logger.error(
- f"Error saving catalytic activity for {protein.accession_id}: {e}"
+ def get_substrates_and_products_from_rhea(
+ self, rhea_id: str
+ ) -> dict[str, List[str]]:
+ """Fetch substrates and products from Rhea by parsing the side URI (_L = substrate, _R = product).
+
+ Args:
+ rhea_id (str): The Rhea reaction ID (e.g., "RHEA:49528" or "49528")
+
+ Returns:
+ dict: {
+ 'substrates': [list of chebi URIs],
+ 'products': [list of chebi URIs]
+ }
+ """
+ rhea_id = rhea_id.strip().replace("RHEA:", "")
+ rhea_id_str = str(rhea_id).strip()
+ sparql = SPARQLWrapper("https://sparql.rhea-db.org/sparql")
+ sparql.setQuery(f"""
+ PREFIX rh: <http://rdf.rhea-db.org/>
+ PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
+
+ SELECT DISTINCT ?participant ?compound ?chebi ?side
+ WHERE {{
+ rh:{rhea_id_str} rh:side ?side .
+ ?side rh:contains ?participant .
+ ?participant rh:compound ?compound .
+ OPTIONAL {{ ?compound rh:chebi ?chebi . }}
+ OPTIONAL {{ ?compound rh:underlyingChebi ?chebi . }}
+ OPTIONAL {{
+ ?compound rdfs:seeAlso ?chebi .
+ FILTER STRSTARTS(STR(?chebi), "http://purl.obolibrary.org/obo/CHEBI_")
+ }}
+ }}
+ """)
+ sparql.setReturnFormat(JSON)
+ sparql.addCustomHttpHeader("User-Agent", "MyPythonClient/1.0")
+
+ results_raw = sparql.query().convert()
+ if not isinstance(results_raw, dict):
+ raise TypeError("Expected dict from SPARQL query")
+
+ results: dict[str, Any] = results_raw
+
+ substrates = set()
+ products = set()
+
+ for r in results["results"]["bindings"]:
+ chebi_uri = r.get("chebi", {}).get("value")
+ if not chebi_uri:
+ logger.info(f"No ChEBI URI found for compound {r['compound']['value']}")
+ continue  # skip participants without a resolvable ChEBI URI
+
+ side_uri = r["side"]["value"]
+ if side_uri.endswith("_L"):
+ substrates.add(chebi_uri)
+ elif side_uri.endswith("_R"):
+ products.add(chebi_uri)
+
+ return {"substrates": sorted(substrates), "products": sorted(products)}
+
+ def get_smiles_from_chebi_web(self, chebi_url: str) -> Optional[str]:
+ """
+ Extract SMILES from the official ChEBI page using HTML scraping.
+ """
+ chebi_id = chebi_url.split("_")[-1]
+ url = f"https://www.ebi.ac.uk/chebi/searchId.do?chebiId=CHEBI:{chebi_id}"
+
+ response = requests.get(url)
+ soup = BeautifulSoup(response.text, "html.parser")
+
+ # Look for table rows that contain the SMILES label
+ for table in soup.find_all("table", class_="chebiTableContent"):
+ if not isinstance(table, Tag):
+ continue
+ for row in table.find_all("tr"):
+ if not isinstance(row, Tag):
+ continue
+ headers = row.find_all("td", class_="chebiDataHeader")
+ if (
+ headers
+ and isinstance(headers[0], Tag)
+ and "SMILES" in headers[0].text
+ ):
+ data_cells = row.find_all("td")
+ if data_cells:
+ return f"{data_cells[-1].text.strip()}"
+ return None
+
+ def add_reaction(self, record: dict[str, Any], protein: Protein) -> None:
+ for reference in record.get("comments", []): # Safe retrieval with .get()
+ if reference.get("type") == "CATALYTIC_ACTIVITY":
+ rhea_id = None # Default value
+
+ for db_ref in reference.get("reaction", {}).get("dbReferences", []):
+ if db_ref.get("id", "").startswith("RHEA:"):
+ rhea_id = db_ref["id"]
+ break # Stop after finding the first match
+
+ catalytic_annotation = Reaction.get_or_save(
+ rhea_id=rhea_id,
+ )
+ if rhea_id is not None:
+ self.add_molecule(rhea_id, catalytic_annotation)
+ protein.reaction.connect(catalytic_annotation)
+
+ def add_molecule(self, rhea_id: str, reaction: Reaction) -> None:
+ chebi = self.get_substrates_and_products_from_rhea(rhea_id)
+
+ substrate_ids = chebi["substrates"]
+ product_ids = chebi["products"]
+
+ for i in substrate_ids:
+ smiles = self.get_smiles_from_chebi_web(i)
+
+ chebi_id = i.split("_")[-1]
+ chebi_id = f"CHEBI:{chebi_id}"
+ substrate = Molecule.get_or_save(
+ chebi_id=chebi_id,
+ smiles=smiles,
+ )
+ reaction.substrate.connect(substrate)
+
+ for i in product_ids:
+ smiles = self.get_smiles_from_chebi_web(i)
+
+ chebi_id = i.split("_")[-1]
+ chebi_id = f"CHEBI:{chebi_id}"
+ product = Molecule.get_or_save(
+ chebi_id=chebi_id,
+ smiles=smiles,
)
+ reaction.product.connect(product)
def add_go(self, record: dict[str, Any], protein: Protein) -> None:
for reference in record["dbReferences"]:
diff --git a/src/pyeed/analysis/embedding_analysis.py b/src/pyeed/analysis/embedding_analysis.py
index fa9d6c0e..b3535f74 100644
--- a/src/pyeed/analysis/embedding_analysis.py
+++ b/src/pyeed/analysis/embedding_analysis.py
@@ -348,8 +348,8 @@ def create_embedding_vector_index_neo4j(
def find_nearest_neighbors_based_on_vector_index(
self,
+ query_id: str,
db: DatabaseConnector,
- query_protein_id: str,
index_name: str = "embedding_index",
number_of_neighbors: int = 50,
) -> list[tuple[str, float]]:
@@ -406,10 +406,11 @@ def find_nearest_neighbors_based_on_vector_index(
logger.info(f"Index {index_name} is populated, finding nearest neighbors")
query_find_nearest_neighbors = f"""
- MATCH (source:Protein {{accession_id: '{query_protein_id}'}})
+ MATCH (source:Protein {{accession_id: '{query_id}'}})
WITH source.embedding AS embedding
CALL db.index.vector.queryNodes('{index_name}', {number_of_neighbors}, embedding)
YIELD node AS fprotein, score
+ WHERE score > 0.95
RETURN fprotein.accession_id, score
"""
results = db.execute_read(query_find_nearest_neighbors)
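
Note the two behavioral changes in this hunk: the parameter is renamed to `query_id`, and the query now filters hits to `score > 0.95`. A call sketch (assuming `EmbeddingTool` needs no constructor arguments, as in the embeddings notebook):

```python
from pyeed.analysis.embedding_analysis import EmbeddingTool

et = EmbeddingTool()
neighbors = et.find_nearest_neighbors_based_on_vector_index(
    query_id="AAM15527.1",  # accession of the query protein
    db=eedb.db,
    index_name="embedding_index",
    number_of_neighbors=50,
)
# Only neighbors with similarity score > 0.95 are returned after this change.
for accession, score in neighbors:
    print(accession, round(score, 3))
```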
diff --git a/src/pyeed/analysis/mutation_detection.py b/src/pyeed/analysis/mutation_detection.py
index 082314f6..c2562ae1 100644
--- a/src/pyeed/analysis/mutation_detection.py
+++ b/src/pyeed/analysis/mutation_detection.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Optional
from loguru import logger
from pyeed.dbconnect import DatabaseConnector
@@ -15,6 +15,7 @@ def get_sequence_data(
db: DatabaseConnector,
standard_numbering_tool_name: str,
node_type: str = "Protein",
+ region_ids_neo4j: Optional[list[int]] = None,
) -> tuple[dict[str, str], dict[str, list[str]]]:
"""Fetch sequence and position data for two sequences from the database.
@@ -23,6 +24,8 @@ def get_sequence_data(
sequence_id2: Second sequence accession ID
db: Database connection instance
standard_numbering_tool_name: Name of standard numbering tool to use
+ node_type: Type of node to use (default: "Protein")
+ region_ids_neo4j: Neo4j node IDs of the regions used to cut sequences to their region boundaries.
Returns:
tuple containing:
@@ -32,23 +35,47 @@ def get_sequence_data(
Raises:
ValueError: If standard numbering positions not found for both sequences
"""
- query = f"""
- MATCH (p:{node_type})-[r:HAS_STANDARD_NUMBERING]->(s:StandardNumbering)
- WHERE p.accession_id IN ['{sequence_id1}', '{sequence_id2}']
- AND s.name = '{standard_numbering_tool_name}'
- RETURN p.accession_id as id, p.sequence as sequence, r.positions as positions
- """
- results = db.execute_read(query)
+ if region_ids_neo4j is not None:
+ query = f"""
+ MATCH (p:{node_type})-[rel:HAS_REGION]->(r:Region)
+ WHERE id(r) IN $region_ids_neo4j
+ MATCH (r)-[rel2:HAS_STANDARD_NUMBERING]->(s:StandardNumbering)
+ WHERE p.accession_id IN ['{sequence_id1}', '{sequence_id2}']
+ AND s.name = '{standard_numbering_tool_name}'
+ RETURN p.accession_id as id, p.sequence as sequence, rel2.positions as positions, rel.start as start, rel.end as end
+ """
+ results = db.execute_read(
+ query, parameters={"region_ids_neo4j": region_ids_neo4j}
+ )
+ else:
+ query = f"""
+ MATCH (p:{node_type})-[r:HAS_STANDARD_NUMBERING]->(s:StandardNumbering)
+ WHERE p.accession_id IN ['{sequence_id1}', '{sequence_id2}']
+ AND s.name = '{standard_numbering_tool_name}'
+ RETURN p.accession_id as id, p.sequence as sequence, r.positions as positions
+ """
+ results = db.execute_read(query)
if len(results) < 2:
raise ValueError(
f"Could not find standard numbering positions for both sequences {sequence_id1} and {sequence_id2}"
)
+ if region_ids_neo4j is not None:
+ sequences = {
+ results[i]["id"]: results[i]["sequence"][
+ results[i]["start"] : results[i]["end"]
+ ]
+ for i in range(len(results))
+ }
+ positions = {
+ results[i]["id"]: results[i]["positions"] for i in range(len(results))
+ }
- sequences = {result["id"]: result["sequence"] for result in results}
- positions = {result["id"]: result["positions"] for result in results}
-
- return sequences, positions
+ return sequences, positions
+ else:
+ sequences = {result["id"]: result["sequence"] for result in results}
+ positions = {result["id"]: result["positions"] for result in results}
+ return sequences, positions
def find_mutations(
self,
@@ -105,6 +132,7 @@ def save_mutations_to_db(
sequence_id1: str,
sequence_id2: str,
node_type: str = "Protein",
+ region_ids_neo4j: Optional[list[int]] = None,
) -> None:
"""Save detected mutations to the database.
@@ -117,41 +145,87 @@ def save_mutations_to_db(
db: Database connection instance
sequence_id1: First sequence accession ID
sequence_id2: Second sequence accession ID
+ node_type: Type of node to use (default: "Protein")
+ region_ids_neo4j: Neo4j node IDs of the regions used to cut sequences to their region boundaries.
"""
-
# Check if a mutation relationship already exists between these proteins
- existing_mutations = db.execute_read(
- f"""
- MATCH (p1:{node_type})-[r:MUTATION]->(p2:{node_type})
- WHERE p1.accession_id = $sequence_id1 AND p2.accession_id = $sequence_id2
- RETURN r
- """,
- {"sequence_id1": sequence_id1, "sequence_id2": sequence_id2},
- )
+ if region_ids_neo4j is not None:
+ query = f"""
+ MATCH (p1:{node_type} {{accession_id: $sequence_id1}})-[rel:HAS_REGION]->(r1:Region)
+ WHERE id(r1) IN $region_ids_neo4j
+ MATCH (r1)-[rel_mutation:MUTATION]->(r2:Region)
+ WHERE id(r2) IN $region_ids_neo4j
+ MATCH (r2)<-[:HAS_REGION]-(p2:{node_type} {{accession_id: $sequence_id2}})
+ RETURN rel_mutation
+ """
+ existing_mutations = db.execute_read(
+ query,
+ {
+ "sequence_id1": sequence_id1,
+ "sequence_id2": sequence_id2,
+ "region_ids_neo4j": region_ids_neo4j,
+ },
+ )
+ else:
+ existing_mutations = db.execute_read(
+ f"""
+ MATCH (p1:{node_type})-[r:MUTATION]->(p2:{node_type})
+ WHERE p1.accession_id = $sequence_id1 AND p2.accession_id = $sequence_id2
+ RETURN r
+ """,
+ {"sequence_id1": sequence_id1, "sequence_id2": sequence_id2},
+ )
if existing_mutations:
logger.debug(
f"Mutation relationship already exists between {sequence_id1} and {sequence_id2}"
)
return
- query = f"""
- MATCH (p1:{node_type}), (p2:{node_type})
- WHERE p1.accession_id = $sequence_id1 AND p2.accession_id = $sequence_id2
- CREATE (p1)-[r:MUTATION]->(p2)
- SET r.from_positions = $from_positions,
- r.to_positions = $to_positions,
- r.from_monomers = $from_monomers,
- r.to_monomers = $to_monomers
- """
- params = {
- "sequence_id1": sequence_id1,
- "sequence_id2": sequence_id2,
- "from_positions": mutations["from_positions"],
- "to_positions": mutations["to_positions"],
- "from_monomers": mutations["from_monomers"],
- "to_monomers": mutations["to_monomers"],
- }
- db.execute_write(query, params)
+ if region_ids_neo4j is not None:
+ # saving the mutation between the regions
+ query = f"""
+ MATCH (r1:Region)
+ WHERE id(r1) IN $region_ids_neo4j
+ MATCH (r1)<-[:HAS_REGION]-(p1:{node_type} {{accession_id: $sequence_id1}})
+ MATCH (r2:Region)
+ WHERE id(r2) IN $region_ids_neo4j
+ MATCH (r2)<-[:HAS_REGION]-(p2:{node_type} {{accession_id: $sequence_id2}})
+ CREATE (r1)-[r:MUTATION]->(r2)
+ SET r.from_positions = $from_positions,
+ r.to_positions = $to_positions,
+ r.from_monomers = $from_monomers,
+ r.to_monomers = $to_monomers
+ """
+ params = {
+ "sequence_id1": sequence_id1,
+ "sequence_id2": sequence_id2,
+ "region_ids_neo4j": region_ids_neo4j,
+ "from_positions": mutations["from_positions"],
+ "to_positions": mutations["to_positions"],
+ "from_monomers": mutations["from_monomers"],
+ "to_monomers": mutations["to_monomers"],
+ }
+ db.execute_write(query, params)
+ else:
+ query = f"""
+ MATCH (p1:{node_type}), (p2:{node_type})
+ WHERE p1.accession_id = $sequence_id1 AND p2.accession_id = $sequence_id2
+ CREATE (p1)-[r:MUTATION]->(p2)
+ SET r.from_positions = $from_positions,
+ r.to_positions = $to_positions,
+ r.from_monomers = $from_monomers,
+ r.to_monomers = $to_monomers
+ """
+ params = {
+ "sequence_id1": sequence_id1,
+ "sequence_id2": sequence_id2,
+ "from_positions": mutations["from_positions"],
+ "to_positions": mutations["to_positions"],
+ "from_monomers": mutations["from_monomers"],
+ "to_monomers": mutations["to_monomers"],
+ }
+ db.execute_write(query, params)
+
logger.debug(
f"Saved {len(list(params['from_positions']))} mutations to database"
)
@@ -165,6 +239,7 @@ def get_mutations_between_sequences(
save_to_db: bool = True,
debug: bool = False,
node_type: str = "Protein",
+ region_ids_neo4j: Optional[list[int]] = None,
) -> dict[str, list[int | str]]:
"""Get mutations between two sequences using standard numbering.
@@ -174,6 +249,8 @@ def get_mutations_between_sequences(
db: Database connection instance
standard_numbering_tool_name: Name of standard numbering tool to use
save_to_db: Whether to save mutations to database (default: True)
+ node_type: Type of node to use (default: "Protein")
+            region_ids_neo4j: Neo4j internal IDs of Region nodes used to slice the sequences for region-based analysis.
Returns:
dict containing mutation information:
@@ -186,7 +263,12 @@ def get_mutations_between_sequences(
ValueError: If standard numbering positions not found for both sequences
"""
sequences, positions = self.get_sequence_data(
- sequence_id1, sequence_id2, db, standard_numbering_tool_name, node_type
+ sequence_id1,
+ sequence_id2,
+ db,
+ standard_numbering_tool_name,
+ node_type,
+ region_ids_neo4j,
)
if debug:
@@ -201,7 +283,12 @@ def get_mutations_between_sequences(
if save_to_db:
self.save_mutations_to_db(
- mutations, db, sequence_id1, sequence_id2, node_type
+ mutations,
+ db,
+ sequence_id1,
+ sequence_id2,
+ node_type,
+ region_ids_neo4j,
)
return mutations
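End to end, the region-aware mutation detection can then be sketched as follows (same illustrative names as above; with `region_ids_neo4j` set, the MUTATION relationship is written between the Region nodes rather than the sequence nodes):

mutations = md.get_mutations_between_sequences(
    "ABC12345",
    "XYZ67890",
    db,
    standard_numbering_tool_name="test_numbering",
    save_to_db=True,              # persist as MUTATION relationships
    node_type="Protein",
    region_ids_neo4j=[101, 102],
)
print(mutations["from_positions"], mutations["to_monomers"])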
diff --git a/src/pyeed/analysis/network_analysis.py b/src/pyeed/analysis/network_analysis.py
index dab37fa0..dd66b45c 100644
--- a/src/pyeed/analysis/network_analysis.py
+++ b/src/pyeed/analysis/network_analysis.py
@@ -20,6 +20,27 @@ def __init__(self, db: DatabaseConnector):
self.db: DatabaseConnector = db
self.graph: nx.Graph = nx.Graph()
+ def check_indexes(self) -> list[dict[str, Any]]:
+ """
+ Checks all existing indexes in the Neo4j database.
+
+ Returns:
+ list[dict[str, Any]]: List of dictionaries containing index information including:
+ - name: The name of the index
+ - type: The type of index (e.g., "BTREE", "LOOKUP")
+ - labelsOrTypes: The labels or relationship types the index is on
+ - properties: The properties the index is on
+ - uniqueness: Whether the index is unique
+ - state: The state of the index (e.g., "ONLINE", "POPULATING")
+ """
+ query = """
+ SHOW INDEXES
+ """
+ logger.info("Checking existing indexes in the database")
+ indexes = self.db.execute_read(query)
+ logger.info(f"Found {len(indexes)} indexes")
+ return indexes
+
def create_graph(
self,
nodes: Optional[list[str]] = None,
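A quick way to exercise the new index check (the `NetworkAnalysis` class name is assumed from the module; `db` is a connected DatabaseConnector as in the earlier sketch):

from pyeed.analysis.network_analysis import NetworkAnalysis

na = NetworkAnalysis(db)
for index in na.check_indexes():
    print(index["name"], index["state"])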
@@ -37,60 +58,62 @@ def create_graph(
Returns:
networkx.Graph: The created graph.
"""
-
logger.info(
f"Creating graph with node types: {nodes} and relationships: {relationships} and ids: {ids}"
)
- # Query to fetch nodes with filters
- node_filter = ""
+ # Build the base query
+ base_query = """
+ MATCH (n)
+ """
+
+ # Add node filters
+ node_filters = []
if nodes:
- node_filter += "WHERE labels(n)[0] IN $node_types "
+ node_filters.append("labels(n)[0] IN $node_types")
if ids:
- if "WHERE" in node_filter:
- node_filter += "AND n.accession_id IN $accession_ids "
- else:
- node_filter += "WHERE n.accession_id IN $accession_ids "
+ node_filters.append("n.accession_id IN $accession_ids")
- query_nodes = f"""
- MATCH (n)
- {node_filter}
- RETURN ID(n) as id, labels(n) as labels, properties(n) as properties
+ if node_filters:
+ base_query += "WHERE " + " AND ".join(node_filters)
+
+ # Add relationship pattern and filters
+ base_query += """
+ OPTIONAL MATCH (n)-[r]->(m)
"""
- # Query to fetch relationships with filters
- relationship_filter = ""
+ # Add relationship type filter if specified
if relationships:
- relationship_filter += "WHERE type(r) IN $relationships "
+ base_query += "WHERE type(r) IN $relationships "
- query_relationships = f"""
- MATCH (n)-[r]->(m)
- {relationship_filter}
- RETURN ID(n) as source, ID(m) as target, type(r) as type, properties(r) as properties
+ # Return both nodes and relationships in a single query
+ base_query += """
+ RETURN
+ collect(DISTINCT {id: ID(n), labels: labels(n), properties: properties(n)}) as nodes,
+ collect(DISTINCT {source: ID(n), target: ID(m), type: type(r), properties: properties(r)}) as relationships
"""
- # Fetch nodes and relationships
- logger.debug(f"Executing query: {query_nodes}")
- nodes_results = self.db.execute_read(
- query_nodes, {"node_types": nodes, "accession_ids": ids}
- )
- logger.debug(f"Executing query: {query_relationships}")
- relationships_results = self.db.execute_read(
- query_relationships, {"relationships": relationships}
+ logger.info("Executing combined query for nodes and relationships")
+ results = self.db.execute_read(
+ base_query,
+ {"node_types": nodes, "accession_ids": ids, "relationships": relationships},
)
- logger.debug(f"Number of nodes: {len(nodes_results)}")
- logger.debug(f"Number of relationships: {len(relationships_results)}")
- # Add nodes
- for node in nodes_results:
+ if not results or not results[0]:
+ logger.warning("No results found in the database")
+ return self.graph
+
+ # Process nodes
+ nodes_data = results[0]["nodes"]
+ for node in nodes_data:
self.graph.add_node(
- node["id"],
- labels=node["labels"],
- properties=node["properties"],
+ node["id"], labels=node["labels"], properties=node["properties"]
)
+ logger.info(f"Added {len(nodes_data)} nodes to the graph")
- # Add relationships
- for rel in relationships_results:
+ # Process relationships
+ relationships_data = results[0]["relationships"]
+ for rel in relationships_data:
if rel["source"] in self.graph and rel["target"] in self.graph:
self.graph.add_edge(
rel["source"],
@@ -98,6 +121,7 @@ def create_graph(
type=rel["type"],
properties=rel["properties"],
)
+ logger.info(f"Added {len(relationships_data)} relationships to the graph")
return self.graph
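The combined query returns the node and relationship collections in a single round trip; the calling convention is unchanged (a sketch with illustrative filter values):

graph = na.create_graph(
    nodes=["Protein"],
    relationships=["PAIRWISE_ALIGNED"],
    ids=["ABC12345", "XYZ67890"],  # optional accession filter
)
print(graph.number_of_nodes(), "nodes,", graph.number_of_edges(), "edges")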
@@ -233,8 +257,8 @@ def calculate_positions_2d(
filtered_graph.remove_edges_from(self_referential_edges)
# Find isolated nodes
- isolated_nodes = self.find_isolated_nodes(filtered_graph)
- filtered_graph.remove_nodes_from(isolated_nodes)
+        # Isolated nodes are now kept in the layout; the previous removal is
+        # left commented out for reference:
+        # isolated_nodes = self.find_isolated_nodes(filtered_graph)
+        # filtered_graph.remove_nodes_from(isolated_nodes)
# Use spring layout for force-directed graph
weight_attr = attribute if attribute is not None else None
diff --git a/src/pyeed/analysis/sequence_alignment.py b/src/pyeed/analysis/sequence_alignment.py
index 8dd41553..440cbb1e 100644
--- a/src/pyeed/analysis/sequence_alignment.py
+++ b/src/pyeed/analysis/sequence_alignment.py
@@ -91,6 +91,8 @@ def align_multipairwise(
return_results: bool = True,
pairs: Optional[list[tuple[str, str]]] = None,
node_type: str = "Protein",
+ region_ids_neo4j: Optional[list[str]] = None,
+ num_cores: int = cpu_count() - 1,
) -> Optional[list[dict[str, Any]]]:
"""
Creates all possible pairwise alignments from a dictionary of sequences or from sequence IDs.
@@ -114,6 +116,7 @@ def align_multipairwise(
pairs (Optional[list[tuple[str, str]]]): A list of tuples, where each tuple contains two
sequence IDs to align. If provided, only these pairs will be aligned.
node_type (str): The type of node to align. Defaults to "Protein".
+            region_ids_neo4j (Optional[list[str]]): Neo4j internal IDs of Region nodes used to slice the sequences for region-based alignment.
+            num_cores (int): Number of worker processes for the alignments. Defaults to cpu_count() - 1.
Returns:
Optional[List[dict]]: A list of dictionaries containing the alignment results if
`return_results` is True. If False, returns None.
@@ -121,7 +124,7 @@ def align_multipairwise(
# Fetch sequences if ids are provided
if ids is not None and db is not None:
- sequences = self._get_id_sequence_dict(db, ids, node_type)
+ sequences = self._get_id_sequence_dict(db, ids, node_type, region_ids_neo4j)
if not sequences:
raise ValueError(
@@ -134,6 +137,28 @@ def align_multipairwise(
total_pairs = len(pairs)
all_alignments = []
+ query = """
+ MATCH (p1:Protein)-[:PAIRWISE_ALIGNED]->(p2:Protein)
+ RETURN p1.accession_id AS Protein1_ID, p2.accession_id AS Protein2_ID
+ """
+
+ # Fetch results properly as a list of tuples
+ existing_pairs = set()
+ if db is not None:
+ existing_pairs = set(
+ tuple(sorted((row["Protein1_ID"], row["Protein2_ID"])))
+ for row in db.execute_write(query)
+ )
+
+ # Filter new pairs that are not in existing_pairs
+ new_pairs = [
+ pair for pair in pairs if tuple(sorted(pair)) not in existing_pairs
+ ]
+
+ print(f"Number of existing pairs: {len(existing_pairs)}")
+ print(f"Number of total pairs: {len(pairs)}")
+ print(f"Number of pairs to align: {len(new_pairs)}")
+
with Progress() as progress:
align_task = progress.add_task(
f"⛓️ Aligning {total_pairs} sequence pairs...", total=total_pairs
@@ -142,9 +167,9 @@ def align_multipairwise(
"📥 Inserting alignment results to database...", total=total_pairs
)
- for pair_chunk in chunks(pairs, batch_size):
+ for pair_chunk in chunks(new_pairs, batch_size):
# Align the pairs in the current chunk
- alignments = Parallel(n_jobs=cpu_count(), prefer="processes")(
+ alignments = Parallel(n_jobs=num_cores, prefer="processes")(
delayed(self.align_pairwise)(
{pair[0]: sequences[pair[0]]},
{pair[1]: sequences[pair[1]]},
@@ -156,7 +181,7 @@ def align_multipairwise(
progress.update(align_task, advance=len(pair_chunk))
if db:
- self._to_db(alignments, db)
+ self._to_db(alignments, db, node_type, region_ids_neo4j)
progress.update(db_task, advance=len(pair_chunk))
if return_results:
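A sketch of the extended alignment call (the `PairwiseAligner` class name is assumed from the module; `db`, the accession ids, and the region ids are illustrative). Since already-aligned pairs are skipped, re-running the call only aligns the remainder:

from pyeed.analysis.sequence_alignment import PairwiseAligner

aligner = PairwiseAligner()
results = aligner.align_multipairwise(
    ids=["ABC12345", "XYZ67890"],
    db=db,
    node_type="Protein",
    region_ids_neo4j=[101, 102],  # align only the region-sliced sequences
    num_cores=4,                  # cap the number of worker processes
)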
@@ -168,28 +193,53 @@ def _to_db(
self,
alignments: list[dict[str, Any]],
db: DatabaseConnector,
+ node_type: str = "Protein",
+ region_ids_neo4j: Optional[list[str]] = None,
) -> None:
"""Inserts the alignment results to pyeed graph database.
Args:
alignments (list[dict]): A list of dictionaries containing the alignment results.
db (DatabaseConnector): A `DatabaseConnector` object.
+ node_type (str): The type of node to align. Defaults to "Protein".
+            region_ids_neo4j (Optional[list[str]]): Neo4j internal IDs of Region nodes used to slice the sequences for region-based alignment.
"""
- query = """
- UNWIND $alignments AS alignment
- MATCH (p1:Protein {accession_id: alignment.query_id})
- MATCH (p2:Protein {accession_id: alignment.target_id})
- MERGE (p1)-[r:PAIRWISE_ALIGNED]->(p2)
- SET r.similarity = alignment.identity,
+ if region_ids_neo4j is None:
+ query = f"""
+ UNWIND $alignments AS alignment
+ MATCH (p1:{node_type} {{accession_id: alignment.query_id}})
+ MATCH (p2:{node_type} {{accession_id: alignment.target_id}})
+ MERGE (p1)-[r:PAIRWISE_ALIGNED]->(p2)
+ SET r.similarity = alignment.identity,
r.mismatches = alignment.mismatches,
r.gaps = alignment.gaps,
r.score = alignment.score,
r.query_aligned = alignment.query_aligned,
r.target_aligned = alignment.target_aligned
- """
-
- db.execute_write(query, {"alignments": alignments})
+ """
+ db.execute_write(query, parameters={"alignments": alignments})
+ else:
+ query = f"""
+ UNWIND $alignments AS alignment
+ MATCH (p1:{node_type} {{accession_id: alignment.query_id}})-[rel1:HAS_REGION]->(r1:Region)
+ MATCH (p2:{node_type} {{accession_id: alignment.target_id}})-[rel2:HAS_REGION]->(r2:Region)
+ WHERE id(r1) IN $region_ids_neo4j AND id(r2) IN $region_ids_neo4j
+ MERGE (r1)-[r:PAIRWISE_ALIGNED]->(r2)
+ SET r.similarity = alignment.identity,
+ r.mismatches = alignment.mismatches,
+ r.gaps = alignment.gaps,
+ r.score = alignment.score,
+ r.query_aligned = alignment.query_aligned,
+ r.target_aligned = alignment.target_aligned
+ """
+ db.execute_write(
+ query,
+ parameters={
+ "alignments": alignments,
+ "region_ids_neo4j": region_ids_neo4j,
+ },
+ )
def _get_aligner(self) -> BioPairwiseAligner:
"""Creates a BioPython pairwise aligner object with the specified parameters
@@ -244,6 +294,7 @@ def _get_id_sequence_dict(
db: DatabaseConnector,
ids: list[str] = [],
node_type: str = "Protein",
+ region_ids_neo4j: Optional[list[str]] = None,
) -> dict[str, str]:
"""Gets all sequences from the database and returns them in a dictionary.
Key is the accession id and value is the sequence.
@@ -257,21 +308,52 @@ def _get_id_sequence_dict(
dict[str, str]: Dictionary of sequences with accession id as key.
"""
- if not ids:
- query = f"""
- MATCH (p:{node_type})
- RETURN p.accession_id AS accession_id, p.sequence AS sequence
- """
- nodes = db.execute_read(query)
- else:
- query = f"""
- MATCH (p:{node_type})
- WHERE p.accession_id IN $ids
- RETURN p.accession_id AS accession_id, p.sequence AS sequence
- """
- nodes = db.execute_read(query, {"ids": ids})
+        if ids:
+ if region_ids_neo4j is not None:
+ query = f"""
+ MATCH (p:{node_type})-[e:HAS_REGION]->(r:Region)
+ WHERE id(r) IN $region_ids_neo4j AND p.accession_id IN $ids
+ RETURN p.accession_id AS accession_id, e.start AS start, e.end AS end, p.sequence AS sequence
+ """
+ nodes = db.execute_read(
+ query,
+ parameters={"region_ids_neo4j": region_ids_neo4j, "ids": ids},
+ )
+ else:
+ query = f"""
+ MATCH (p:{node_type})
+ WHERE p.accession_id IN $ids
+ RETURN p.accession_id AS accession_id, p.sequence AS sequence
+ """
+ nodes = db.execute_read(query, parameters={"ids": ids})
- return {node["accession_id"]: node["sequence"] for node in nodes}
+ else:
+ if region_ids_neo4j is not None:
+ query = f"""
+ MATCH (p:{node_type})-[e:HAS_REGION]->(r:Region)
+ WHERE id(r) IN $region_ids_neo4j
+ RETURN p.accession_id AS accession_id, e.start AS start, e.end AS end, p.sequence AS sequence
+ """
+ nodes = db.execute_read(
+ query,
+ parameters={
+ "region_ids_neo4j": region_ids_neo4j,
+ },
+ )
+ else:
+ query = f"""
+ MATCH (p:{node_type})
+ RETURN p.accession_id AS accession_id, p.sequence AS sequence
+ """
+ nodes = db.execute_read(query)
+
+ if region_ids_neo4j is not None:
+ return {
+ node["accession_id"]: node["sequence"][node["start"] : node["end"]]
+ for node in nodes
+ }
+ else:
+ return {node["accession_id"]: node["sequence"] for node in nodes}
def _load_substitution_matrix(self) -> "BioSubstitutionMatrix":
from Bio.Align import substitution_matrices
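Note that the region slice relies on Python's half-open indexing, so `end` is exclusive. If the stored HAS_REGION coordinates were 1-based and inclusive (common for sequence annotations), they would need converting before being stored or used here; a one-line check of the assumed semantics:

sequence = "MSEQVAAVAK"
start, end = 2, 5           # values as stored on the HAS_REGION relationship
print(sequence[start:end])  # "EQV" -> three residues, end position excluded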
diff --git a/src/pyeed/analysis/standard_numbering.py b/src/pyeed/analysis/standard_numbering.py
index 04d78d96..6f81869f 100644
--- a/src/pyeed/analysis/standard_numbering.py
+++ b/src/pyeed/analysis/standard_numbering.py
@@ -41,7 +41,11 @@ def __init__(self, name: str) -> None:
self.name = name
def get_node_base_sequence(
- self, base_sequence_id: str, db: DatabaseConnector, node_type: str = "Protein"
+ self,
+ base_sequence_id: str,
+ db: DatabaseConnector,
+ node_type: str = "Protein",
+ region_ids_neo4j: Optional[list[str]] = None,
) -> dict[str, str]:
"""
Retrieve the base node sequence from the database for a given accession id.
@@ -52,21 +56,37 @@ def get_node_base_sequence(
Args:
base_sequence_id: The accession id of the base node sequence.
db: The database connector instance to perform the query.
-
+        node_type: The type of node to process. Default is "Protein".
+        region_ids_neo4j: Neo4j internal IDs of Region nodes used to slice the sequence for region-based analysis.
Returns:
A dictionary with keys 'id' and 'sequence' holding the node type id and its sequence.
"""
- query = f"""
- MATCH (p:{node_type})
- WHERE p.accession_id = '{base_sequence_id}'
- RETURN p.accession_id AS accession_id, p.sequence AS sequence
- """
+        if region_ids_neo4j:
+            query = f"""
+            MATCH (p:{node_type})-[e:HAS_REGION]->(r:Region)
+            WHERE id(r) IN $region_ids_neo4j
+            AND p.accession_id = '{base_sequence_id}'
+            RETURN p.accession_id AS accession_id, e.start AS start, e.end AS end, p.sequence AS sequence
+            """
+            base_sequence_read = db.execute_read(
+                query, parameters={"region_ids_neo4j": region_ids_neo4j}
+            )
+        else:
+            query = f"""
+            MATCH (p:{node_type})
+            WHERE p.accession_id = '{base_sequence_id}'
+            RETURN p.accession_id AS accession_id, p.sequence AS sequence
+            """
+            base_sequence_read = db.execute_read(query)
-        base_sequence_read = db.execute_read(query)
# Assume the first returned record is the desired base sequence
- base_sequence = {
- "id": base_sequence_read[0]["accession_id"],
- "sequence": base_sequence_read[0]["sequence"],
- }
+ if region_ids_neo4j:
+ base_sequence = {
+ "id": base_sequence_read[0]["accession_id"],
+ "sequence": base_sequence_read[0]["sequence"][
+ base_sequence_read[0]["start"] : base_sequence_read[0]["end"]
+ ],
+ }
+ else:
+ base_sequence = {
+ "id": base_sequence_read[0]["accession_id"],
+ "sequence": base_sequence_read[0]["sequence"],
+ }
return base_sequence
def save_positions(
@@ -74,6 +94,7 @@ def save_positions(
db: DatabaseConnector,
positions: dict[str, list[str]],
node_type: str = "Protein",
+ region_ids_neo4j: Optional[list[str]] = None,
) -> None:
"""
Save the calculated numbering positions for each protein into the database.
@@ -84,16 +105,30 @@ def save_positions(
Args:
db: The database connector instance used to execute the write queries.
+ positions: A dictionary mapping protein accession ids to lists of numbering positions.
+ node_type: The type of node to process. Default is "Protein".
+        region_ids_neo4j: Neo4j internal IDs of Region nodes; when given, the positions are attached to the Region nodes.
"""
for protein_id in positions:
- query = f"""
- MATCH (p:{node_type} {{accession_id: '{protein_id}'}})
- MATCH (s:StandardNumbering {{name: '{self.name}'}})
- MERGE (p)-[r:HAS_STANDARD_NUMBERING]->(s)
- SET r.positions = {str(positions[protein_id])}
- """
- # Execute the write query to update the standard numbering relationship.
- db.execute_write(query)
+ if region_ids_neo4j:
+ query = f"""
+ MATCH (p:{node_type} {{accession_id: '{protein_id}'}})-[e:HAS_REGION]->(r:Region)
+ WHERE id(r) IN $region_ids_neo4j
+ MATCH (s:StandardNumbering {{name: '{self.name}'}})
+ MERGE (r)-[rel:HAS_STANDARD_NUMBERING]->(s)
+ SET rel.positions = {str(positions[protein_id])}
+ """
+ db.execute_write(
+ query, parameters={"region_ids_neo4j": region_ids_neo4j}
+ )
+ else:
+ query = f"""
+ MATCH (p:{node_type} {{accession_id: '{protein_id}'}})
+ MATCH (s:StandardNumbering {{name: '{self.name}'}})
+ MERGE (p)-[rel:HAS_STANDARD_NUMBERING]->(s)
+ SET rel.positions = {str(positions[protein_id])}
+ """
+ db.execute_write(query)
def run_numbering_algorithm_clustalo(
self, base_sequence_id: str, alignment: Any
@@ -317,6 +352,7 @@ def apply_standard_numbering_pairwise(
list_of_seq_ids: Optional[List[str]] = None,
return_positions: bool = False,
node_type: str = "Protein",
+ region_ids_neo4j: Optional[list[str]] = None,
) -> Optional[Dict[str, List[str]]]:
"""
Apply standard numbering via pairwise alignment using a base sequence.
@@ -332,7 +368,7 @@ def apply_standard_numbering_pairwise(
list_of_seq_ids: An optional list of node type ids to process. If None, all node type ids are used.
return_positions: If True, the method returns the computed positions dictionary after processing.
node_type: The type of node to process. Default is "Protein".
-
+        region_ids_neo4j: Neo4j internal IDs of Region nodes used to slice the sequences for region-based numbering.
Raises:
ValueError: If the pairwise alignment fails and returns no results.
"""
@@ -362,15 +398,36 @@ def apply_standard_numbering_pairwise(
pairs.append((base_sequence_id, node_id))
# check if the pairs are already existing with the same name under the same standard numbering node
- query = f"""
- MATCH (s:StandardNumbering {{name: $name}})
- MATCH (p:{node_type})-[r:HAS_STANDARD_NUMBERING]->(s)
- WHERE p.accession_id IN $list_of_seq_ids
- RETURN p.accession_id AS accession_id
- """
- results = db.execute_read(
- query, parameters={"list_of_seq_ids": list_of_seq_ids, "name": self.name}
- )
+ if node_type == "DNA" and region_ids_neo4j is not None:
+ query = """
+ MATCH (s:StandardNumbering {name: $name})
+ MATCH (r:Region)
+ WHERE id(r) IN $region_ids_neo4j
+ MATCH (r:Region)<-[:HAS_STANDARD_NUMBERING]-(s)
+ WHERE r.accession_id IN $list_of_seq_ids
+ RETURN r.accession_id AS accession_id
+ """
+
+ results = db.execute_read(
+ query,
+ parameters={
+ "list_of_seq_ids": list_of_seq_ids,
+ "name": self.name,
+ "region_ids_neo4j": region_ids_neo4j,
+ },
+ )
+ else:
+ query = f"""
+ MATCH (s:StandardNumbering {{name: $name}})
+ MATCH (p:{node_type})-[rel:HAS_STANDARD_NUMBERING]->(s)
+ WHERE p.accession_id IN $list_of_seq_ids
+ RETURN p.accession_id AS accession_id
+ """
+ results = db.execute_read(
+ query,
+ parameters={"list_of_seq_ids": list_of_seq_ids, "name": self.name},
+ )
+
if results is not None:
for row in results:
if row is not None:
@@ -398,6 +455,7 @@ def apply_standard_numbering_pairwise(
db=db,
pairs=pairs, # List of sequence pairs to be aligned
node_type=node_type,
+ region_ids_neo4j=region_ids_neo4j,
)
logger.info(f"Pairwise alignment results: {results_pairwise}")
@@ -435,7 +493,7 @@ def apply_standard_numbering_pairwise(
)
# Update the database with the calculated positions.
- self.save_positions(db, positions, node_type)
+ self.save_positions(db, positions, node_type, region_ids_neo4j)
if return_positions:
return positions
@@ -447,6 +505,7 @@ def apply_standard_numbering(
db: DatabaseConnector,
list_of_seq_ids: Optional[List[str]] = None,
node_type: str = "Protein",
+ region_ids_neo4j: Optional[list[str]] = None,
) -> None:
"""
Apply a standard numbering scheme to a collection of nodes using multiple sequence alignment.
@@ -460,6 +519,7 @@ def apply_standard_numbering(
db: DatabaseConnector instance used for executing queries.
list_of_seq_ids: An optional list of specific node type ids to process. If None, all node type ids are used.
node_type: The type of node to process. Default is "Protein".
+            region_ids_neo4j: Neo4j internal IDs of Region nodes used to slice the sequences for region-based numbering.
"""
if list_of_seq_ids is None:
@@ -489,12 +549,37 @@ def apply_standard_numbering(
nodes_read = []
else:
nodes_read = query_result
- nodes_dict = {node["accession_id"]: node["sequence"] for node in nodes_read}
+
+ if node_type == "DNA" and region_ids_neo4j is not None:
+            # The sequences are region-based: fetch the HAS_REGION boundaries
+            # for each node so that each sequence can be sliced to its region.
+            query = f"""
+            MATCH (p:{node_type})-[e:HAS_REGION]->(r:Region)
+            WHERE id(r) IN $region_ids_neo4j
+            AND p.accession_id IN $list_of_seq_ids
+ RETURN p.accession_id AS accession_id, e.start AS start, e.end AS end, p.sequence AS sequence
+ """
+ region_read = db.execute_read(
+ query,
+ parameters={
+ "list_of_seq_ids": list_of_seq_ids,
+ "region_ids_neo4j": region_ids_neo4j,
+ },
+ )
+ nodes_dict = {
+ node["accession_id"]: node["sequence"][node["start"] : node["end"]]
+ for node in region_read
+ }
+
+ else:
+ nodes_dict = {node["accession_id"]: node["sequence"] for node in nodes_read}
logger.info(f"Using {len(nodes_dict)} sequences for standard numbering")
# Obtain the base sequence details from the database.
- base_sequence = self.get_node_base_sequence(base_sequence_id, db, node_type)
+ base_sequence = self.get_node_base_sequence(
+ base_sequence_id, db, node_type, region_ids_neo4j
+ )
# Remove the base sequence from the nodes list to prevent duplicate alignment.
if base_sequence_id in nodes_dict:
@@ -525,4 +610,4 @@ def apply_standard_numbering(
)
# Update the database with the relationships between nodes and standard numbering.
- self.save_positions(db, positions, node_type)
+ self.save_positions(db, positions, node_type, region_ids_neo4j)
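Putting the standard-numbering pieces together (the `StandardNumberingTool` class name is assumed from the module; `db`, the ids, and the region ids are illustrative). With `region_ids_neo4j` set, the positions are attached to the Region nodes via HAS_STANDARD_NUMBERING instead of the sequence nodes:

from pyeed.analysis.standard_numbering import StandardNumberingTool

numbering = StandardNumberingTool(name="test_numbering")
numbering.apply_standard_numbering_pairwise(
    base_sequence_id="ABC12345",
    db=db,
    list_of_seq_ids=["XYZ67890"],
    node_type="Protein",
    region_ids_neo4j=[101, 102],
)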
diff --git a/src/pyeed/dbconnect.py b/src/pyeed/dbconnect.py
index ec2df259..8abcab52 100644
--- a/src/pyeed/dbconnect.py
+++ b/src/pyeed/dbconnect.py
@@ -227,7 +227,12 @@ def _get_driver(uri: str, user: str | None, password: str | None) -> Driver:
Creates a new Neo4j driver instance.
"""
auth = (user, password) if user and password else None
- return GraphDatabase.driver(uri, auth=auth)
+ return GraphDatabase.driver(
+ uri,
+ auth=auth,
+ connection_timeout=60, # Increase initial connection timeout
+ max_connection_lifetime=86400, # Keep connections alive longer
+ )
@property
def node_properties(self) -> list[dict[str, str]]:
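For reference, the equivalent standalone driver construction (both keyword arguments are documented settings of the official Neo4j Python driver; the URI and credentials are placeholders):

from neo4j import GraphDatabase

driver = GraphDatabase.driver(
    "bolt://localhost:7687",
    auth=("neo4j", "password"),
    connection_timeout=60,          # initial connection timeout in seconds
    max_connection_lifetime=86400,  # recycle pooled connections after 24 h
)
driver.verify_connectivity()        # raises if the database is unreachable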
diff --git a/src/pyeed/embedding.py b/src/pyeed/embedding.py
index 82b5c6d7..28f66a1b 100644
--- a/src/pyeed/embedding.py
+++ b/src/pyeed/embedding.py
@@ -4,10 +4,13 @@
import numpy as np
import torch
+from esm.models.esm3 import ESM3
from esm.models.esmc import ESMC
-from esm.sdk.api import ESMProtein, LogitsConfig
+from esm.sdk.api import ESM3InferenceClient, ESMProtein, LogitsConfig, SamplingConfig
from huggingface_hub import HfFolder, login
+from loguru import logger
from numpy.typing import NDArray
+from torch.nn import DataParallel, Module
from transformers import EsmModel, EsmTokenizer
from pyeed.dbconnect import DatabaseConnector
@@ -30,35 +33,89 @@ def get_hf_token() -> str:
raise RuntimeError("Failed to get Hugging Face token")
+def process_batches_on_gpu(
+ data: list[tuple[str, str]],
+ batch_size: int,
+ model: Module,
+ tokenizer: EsmTokenizer,
+ db: DatabaseConnector,
+ device: torch.device,
+) -> None:
+ """
+ Splits data into batches and processes them on a single GPU.
+
+ Args:
+ data (list): List of (accession_id, sequence) tuples.
+ batch_size (int): Size of each batch.
+        model: The model instance for this GPU.
+        tokenizer: The tokenizer for the model.
+        db: Database connection used to store the embeddings.
+        device (torch.device): The assigned GPU (or CPU) device.
+ """
+ logger.debug(f"Processing {len(data)} sequences on {device}.")
+
+ model = model.to(device)
+
+ # Split data into smaller batches
+ for batch_start in range(0, len(data), batch_size):
+ batch_end = min(batch_start + batch_size, len(data))
+ batch = data[batch_start:batch_end]
+
+ accessions, sequences = zip(*batch)
+
+        # Walk through the batch with a moving offset so that sequences beyond
+        # a reduced sub-batch are still processed instead of silently dropped
+        offset = 0
+        current_batch_size = len(sequences)
+
+        while offset < len(sequences):
+            try:
+                # Compute embeddings for the current sub-batch
+                embeddings_batch = get_batch_embeddings(
+                    list(sequences[offset : offset + current_batch_size]),
+                    model,
+                    tokenizer,
+                    device,
+                )
+
+                # Update the database for the current sub-batch
+                update_protein_embeddings_in_db(
+                    db,
+                    list(accessions[offset : offset + current_batch_size]),
+                    embeddings_batch,
+                )
+
+                # Advance to the next sub-batch
+                offset += current_batch_size
+
+            except torch.cuda.OutOfMemoryError:
+                torch.cuda.empty_cache()
+                current_batch_size = max(
+                    1, current_batch_size // 2
+                )  # Halve the sub-batch size and retry
+                logger.warning(
+                    f"Reduced batch size to {current_batch_size} due to OOM error."
+                )
+
+ # Free memory
+ del model
+ torch.cuda.empty_cache()
+
+
def load_model_and_tokenizer(
model_name: str,
-) -> Tuple[
- Union[EsmModel, ESMC], # Changed from ESM3InferenceClient to ESMC
- Union[EsmTokenizer, None],
- torch.device,
-]:
+ device: torch.device,
+) -> Tuple[Any, Union[Any, None], torch.device]:
"""
- Loads either an ESM-3 (using ESMC) or an ESM-2 (using Transformers) model,
- depending on the `model_name` provided.
+ Loads the model and assigns it to a specific GPU.
Args:
- model_name (str): The model name or identifier (e.g., 'esmc' or 'esm2_t12_35M_UR50D').
+ model_name (str): The model name.
+ device (str): The specific GPU device.
Returns:
- Tuple of (model, tokenizer, device)
+ Tuple: (model, tokenizer, device)
"""
- # Get token only when loading model
token = get_hf_token()
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ tokenizer = None
- # Check if this is an ESM-3 variant
if "esmc" in model_name.lower():
- # Using ESMC from_pretrained
model = ESMC.from_pretrained(model_name)
- model = model.to(device)
- return model, None, device
+ elif "esm3-sm-open-v1" in model_name.lower():
+ model = ESM3.from_pretrained("esm3_sm_open_v1")
else:
- # Otherwise, assume it's an ESM-2 model on Hugging Face
full_model_name = (
model_name
if model_name.startswith("facebook/")
@@ -66,41 +123,52 @@ def load_model_and_tokenizer(
)
model = EsmModel.from_pretrained(full_model_name, use_auth_token=token)
tokenizer = EsmTokenizer.from_pretrained(full_model_name, use_auth_token=token)
- model = model.to(device)
- return model, tokenizer, device
+
+ model = model.to(device)
+ return model, tokenizer, device
def get_batch_embeddings(
batch_sequences: list[str],
- model: Union[EsmModel, ESMC],
+ model: Union[
+ EsmModel,
+ ESMC,
+ DataParallel[Module],
+ ESM3InferenceClient,
+ ESM3,
+ ],
tokenizer_or_alphabet: Union[EsmTokenizer, None],
device: torch.device,
pool_embeddings: bool = True,
) -> list[NDArray[np.float64]]:
"""
Generates mean-pooled embeddings for a batch of sequences.
+    Supports ESM-C, ESM-2 and ESM-3 models.
Args:
- batch_sequences (list[str]): List of sequence strings to be embedded.
- model (Union[EsmModel, ESMC]): Loaded model (ESM-2 or ESM-3).
- tokenizer_or_alphabet (Union[EsmTokenizer, None]): Tokenizer if ESM-2, None if ESM-3.
- device (torch.device): Device on which to run inference (CPU or GPU).
- pool_embeddings (bool): Whether to pool embeddings across sequence length.
+ batch_sequences (list[str]): List of sequence strings.
+ model: Loaded model (could be wrapped in DataParallel).
+ tokenizer_or_alphabet: Tokenizer if needed.
+ device: Inference device (CPU/GPU).
+ pool_embeddings (bool): Whether to average embeddings across the sequence length.
Returns:
- list[NDArray[np.float64]]: A list of embeddings as NumPy arrays.
+ List of embeddings as NumPy arrays.
"""
- if isinstance(model, ESMC):
+ # First, determine the base model type
+ base_model = model.module if isinstance(model, torch.nn.DataParallel) else model
+
+ if isinstance(base_model, ESMC):
+ # For ESMC models
+ embedding_list = []
with torch.no_grad():
- embedding_list = []
for sequence in batch_sequences:
- # Process each sequence individually
protein = ESMProtein(sequence=sequence)
- protein_tensor = model.encode(protein)
- logits_output = model.logits(
+ # Use the model directly - DataParallel handles internal distribution
+ protein_tensor = base_model.encode(protein)
+ logits_output = base_model.logits(
protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
)
- # Convert embeddings to numpy array - ensure embeddings is not None
if logits_output.embeddings is None:
raise ValueError(
"Model did not return embeddings. Check LogitsConfig settings."
@@ -109,9 +177,27 @@ def get_batch_embeddings(
if pool_embeddings:
embeddings = embeddings.mean(axis=1)
embedding_list.append(embeddings[0])
-
return embedding_list
-
+ elif isinstance(base_model, ESM3):
+ # For ESM3 models
+ embedding_list = []
+ with torch.no_grad():
+ for sequence in batch_sequences:
+ protein = ESMProtein(sequence=sequence)
+ sequence_encoding = base_model.encode(protein)
+ result = base_model.forward_and_sample(
+ sequence_encoding,
+ SamplingConfig(return_per_residue_embeddings=True),
+ )
+ if result is None or result.per_residue_embedding is None:
+ raise ValueError("Model did not return embeddings")
+ embeddings = (
+ result.per_residue_embedding.to(torch.float32).cpu().numpy()
+ )
+ if pool_embeddings:
+ embeddings = embeddings.mean(axis=0)
+ embedding_list.append(embeddings)
+ return embedding_list
else:
# ESM-2 logic
assert tokenizer_or_alphabet is not None, "Tokenizer required for ESM-2 models"
@@ -119,15 +205,21 @@ def get_batch_embeddings(
batch_sequences, padding=True, truncation=True, return_tensors="pt"
).to(device)
with torch.no_grad():
- outputs = model(**inputs)
- embeddings = outputs.last_hidden_state.cpu().numpy()
+ outputs = model(**inputs, output_hidden_states=True)
+
+ # Get last hidden state for each sequence
+ hidden_states = outputs.last_hidden_state.cpu().numpy()
+
if pool_embeddings:
- return [embedding.mean(axis=0) for embedding in embeddings]
- return list(embeddings)
+ # Mean pooling across sequence length
+ return [embedding.mean(axis=0) for embedding in hidden_states]
+ return list(hidden_states)
def calculate_single_sequence_embedding_last_hidden_state(
- sequence: str, model_name: str = "facebook/esm2_t33_650M_UR50D"
+ sequence: str,
+ device: torch.device,
+ model_name: str = "facebook/esm2_t33_650M_UR50D",
) -> NDArray[np.float64]:
"""
Calculates an embedding for a single sequence.
@@ -139,12 +231,14 @@ def calculate_single_sequence_embedding_last_hidden_state(
Returns:
NDArray[np.float64]: Normalized embedding vector for the sequence
"""
- model, tokenizer, device = load_model_and_tokenizer(model_name)
+ model, tokenizer, device = load_model_and_tokenizer(model_name, device)
return get_single_embedding_last_hidden_state(sequence, model, tokenizer, device)
def calculate_single_sequence_embedding_all_layers(
- sequence: str, model_name: str = "facebook/esm2_t33_650M_UR50D"
+ sequence: str,
+ device: torch.device,
+ model_name: str = "facebook/esm2_t33_650M_UR50D",
) -> NDArray[np.float64]:
"""
Calculates embeddings for a single sequence across all layers.
@@ -156,7 +250,7 @@ def calculate_single_sequence_embedding_all_layers(
Returns:
NDArray[np.float64]: A numpy array containing layer embeddings for the sequence.
"""
- model, tokenizer, device = load_model_and_tokenizer(model_name)
+ model, tokenizer, device = load_model_and_tokenizer(model_name, device)
return get_single_embedding_all_layers(sequence, model, tokenizer, device)
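A minimal sketch of single-device usage with the explicit-device API above (the model name and sequence are illustrative; a CUDA device is picked when available):

import torch

from pyeed.embedding import get_batch_embeddings, load_model_and_tokenizer

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model, tokenizer, device = load_model_and_tokenizer(
    "facebook/esm2_t33_650M_UR50D", device
)
embeddings = get_batch_embeddings(["MSEQVAAVAK"], model, tokenizer, device)
print(embeddings[0].shape)  # mean-pooled vector, (1280,) for this model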
diff --git a/src/pyeed/main.py b/src/pyeed/main.py
index 25a1b225..1189fcb3 100644
--- a/src/pyeed/main.py
+++ b/src/pyeed/main.py
@@ -1,20 +1,22 @@
import asyncio
+import time
+from concurrent.futures import ThreadPoolExecutor
from typing import Any, Literal
import nest_asyncio
+import torch
from loguru import logger
from pyeed.adapter.ncbi_dna_mapper import NCBIDNAToPyeed
from pyeed.adapter.ncbi_protein_mapper import NCBIProteinToPyeed
+from pyeed.adapter.ncbi_to_uniprot_mapper import NCBIToUniprotMapper
from pyeed.adapter.primary_db_adapter import PrimaryDBAdapter
from pyeed.adapter.uniprot_mapper import UniprotToPyeed
from pyeed.dbchat import DBChat
from pyeed.dbconnect import DatabaseConnector
from pyeed.embedding import (
- free_memory,
- get_batch_embeddings,
load_model_and_tokenizer,
- update_protein_embeddings_in_db,
+ process_batches_on_gpu,
)
@@ -189,64 +191,110 @@ def fetch_ncbi_nucleotide(self, ids: list[str]) -> None:
nest_asyncio.apply()
asyncio.get_event_loop().run_until_complete(adapter.execute_requests())
+ def database_id_mapper(self, ids: list[str], file: str) -> None:
+ """
+        Maps IDs from one database to another using the UniProt ID mapping service.
+
+        Args:
+            ids (list[str]): List of IDs to map.
+            file (str): Path of the file the mapping results are written to.
+        """
+
+ mapper = NCBIToUniprotMapper(ids, file)
+ mapper.execute_request()
+
+ nest_asyncio.apply()
+
def calculate_sequence_embeddings(
self,
batch_size: int = 16,
model_name: str = "facebook/esm2_t33_650M_UR50D",
+        num_gpus: int | None = 1,  # Number of GPUs to use; None selects all available GPUs
) -> None:
"""
- Calculates embeddings for all sequences in the database that do not have embeddings, processing in batches.
+ Calculates embeddings for all sequences in the database that do not have embeddings,
+ distributing the workload across available GPUs.
Args:
batch_size (int): Number of sequences to process in each batch.
- model_name (str): Name of the model to use for calculating embeddings.
- Defaults to "facebook/esm2_t33_650M_UR50D".
- Available models can be found at https://huggingface.co/facebook/esm2_t6_8M_UR50D.
+ model_name (str): Model used for calculating embeddings.
+            num_gpus (int | None): Number of GPUs to use. Defaults to 1; if None or larger than the number of available GPUs, all available GPUs are used.
"""
- # Load the model, tokenizer, and device
- model, tokenizer, device = load_model_and_tokenizer(model_name)
+ # Get the available GPUs
+ available_gpus = torch.cuda.device_count()
+ if num_gpus is None or num_gpus > available_gpus:
+ num_gpus = available_gpus
- # Cypher query to retrieve proteins without embeddings and with valid sequences
+ if num_gpus == 0:
+ logger.warning("No GPU available! Running on CPU.")
+
+ # Load separate models for each GPU
+ devices = (
+ [torch.device(f"cuda:{i}") for i in range(num_gpus)]
+ if num_gpus > 0
+ else [torch.device("cpu")]
+ )
+
+ models_and_tokenizers = [
+ load_model_and_tokenizer(model_name, device) for device in devices
+ ]
+
+ # Retrieve sequences without embeddings
query = """
MATCH (p:Protein)
WHERE p.embedding IS NULL AND p.sequence IS NOT NULL
RETURN p.accession_id AS accession, p.sequence AS sequence
"""
-
- # Execute the query and retrieve the results
results = self.db.execute_read(query)
data = [(result["accession"], result["sequence"]) for result in results]
+
if not data:
logger.info("No sequences to process.")
return
+
accessions, sequences = zip(*data)
total_sequences = len(sequences)
- logger.debug(f"Calculating embeddings for {total_sequences} sequences.")
-
- # Process and save embeddings batch by batch
- for batch_start in range(0, total_sequences, batch_size):
- batch_end = min(batch_start + batch_size, total_sequences)
- batch_sequences = sequences[batch_start:batch_end]
- batch_accessions = accessions[batch_start:batch_end]
- logger.debug(
- f"Processing batch {batch_start // batch_size + 1}/"
- f"{(total_sequences + batch_size - 1) // batch_size + 1}"
- )
+ logger.debug(f"Total sequences to process: {total_sequences}")
- # Get embeddings for the current batch
- embeddings_batch = get_batch_embeddings(
- list(batch_sequences), model, tokenizer, device
- )
+        # Split the data into one chunk per device; len(devices) is at least 1,
+        # so the CPU fallback still receives work when no GPU is available
+        num_workers = len(devices)
+        gpu_batches = [
+            list(zip(accessions[i::num_workers], sequences[i::num_workers]))
+            for i in range(num_workers)
+        ]
- # Update the database for the current batch
- update_protein_embeddings_in_db(
- self.db, list(batch_accessions), embeddings_batch
- )
+ start_time = time.time()
+
+ # Process batches in parallel across GPUs
+        with ThreadPoolExecutor(max_workers=num_workers) as executor:
+ futures = []
+ for i, gpu_data in enumerate(gpu_batches):
+ if not gpu_data:
+ continue # Skip empty GPU batches
+
+ model, tokenizer, device = models_and_tokenizers[i]
+ futures.append(
+ executor.submit(
+ process_batches_on_gpu,
+ gpu_data,
+ batch_size,
+ model,
+ tokenizer,
+ self.db,
+ device,
+ )
+ )
+
+ for future in futures:
+ future.result() # Wait for all threads to complete
+
+ end_time = time.time()
+ logger.info(
+ f"Total embedding calculation time: {end_time - start_time:.2f} seconds"
+ )
- # Free memory after processing all batches
- del model, tokenizer
- free_memory()
+        # Cleanup: drop the model references and release cached GPU memory
+        del models_and_tokenizers
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
def get_proteins(self, accession_ids: list[str]) -> list[dict[str, Any]]:
"""
@@ -426,3 +474,20 @@ def fetch_dna_entries_for_proteins(self, ids: list[str] | None = None) -> None:
f"Error processing relationship batch {i//BATCH_SIZE + 1}: {str(e)}"
)
continue
+
+ def create_coding_sequences_regions(self) -> None:
+ """
+        Creates coding-sequence regions for all proteins in the database.
+
+        For every protein with a known nucleotide id, it looks up the nucleotide start and end
+        positions, creates a Region node annotated as a coding sequence, and connects it to the
+        corresponding DNA node via a HAS_REGION relationship.
+ """
+ query = """
+ MATCH (p:Protein)
+ WHERE p.nucleotide_id IS NOT NULL
+ CREATE (r:Region {annotation: 'coding sequence', sequence_id: p.accession_id})
+ WITH p, r
+ MATCH (d:DNA {accession_id: p.nucleotide_id})
+ CREATE (d)-[:HAS_REGION {start: p.nucleotide_start, end: p.nucleotide_end}]->(r)
+ """
+ self.db.execute_write(query)
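Taken together, a typical session might look like this (a sketch; the `Pyeed` entry-point name, its constructor arguments, and the connection details are assumptions):

from pyeed import Pyeed

eedb = Pyeed(uri="bolt://localhost:7687", user="neo4j", password="password")
eedb.fetch_dna_entries_for_proteins()   # link proteins to their DNA records
eedb.create_coding_sequences_regions()  # add 'coding sequence' Region nodes
eedb.calculate_sequence_embeddings(batch_size=16, num_gpus=2)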
diff --git a/src/pyeed/model.py b/src/pyeed/model.py
index 869a4091..5a3bf188 100644
--- a/src/pyeed/model.py
+++ b/src/pyeed/model.py
@@ -1,9 +1,10 @@
from enum import Enum
-from typing import Any
+from typing import Any, cast
# from pyeed.nodes_and_relations import StrictStructuredNode
from neomodel import (
ArrayProperty,
+ BooleanProperty,
FloatProperty,
IntegerProperty,
RelationshipTo,
@@ -112,6 +113,12 @@ def save(self, *args: Any, **kwargs: Any) -> None:
if not all(isinstance(item, float) for item in prop):
raise TypeError(f"All items in '{field}' must be floats")
+        # Validate BooleanProperty
+ elif isinstance(neo_type, BooleanProperty) and not isinstance(prop, bool):
+ raise TypeError(
+ f"Expected a boolean for '{field}', got {type(prop).__name__}"
+ )
+
super().save(*args, **kwargs) # Don't return the result
@classmethod
@@ -146,6 +153,134 @@ class Organism(StrictStructuredNode):
taxonomy_id = IntegerProperty(required=True, unique_index=True)
name = StringProperty()
+ @classmethod
+ def get_or_save(cls, **kwargs: Any) -> "Organism":
+ taxonomy_id = kwargs.get("taxonomy_id")
+ name = kwargs.get("name")
+ try:
+ organism = cast(Organism, cls.nodes.get(taxonomy_id=taxonomy_id))
+ return organism
+ except cls.DoesNotExist:
+ try:
+ organism = cls(taxonomy_id=taxonomy_id, name=name)
+ organism.save()
+ return organism
+ except Exception as e:
+ print(f"Error during saving of the organism: {e}")
+ raise
+
+
+class Mutation(StructuredRel): # type: ignore
+ """A relationship representing mutations between two sequences."""
+
+ from_positions = ArrayProperty(IntegerProperty(), required=True)
+ to_positions = ArrayProperty(IntegerProperty(), required=True)
+ from_monomers = ArrayProperty(StringProperty(), required=True)
+ to_monomers = ArrayProperty(StringProperty(), required=True)
+
+ @classmethod
+ def validate_and_connect(
+ cls,
+ molecule1: Any,
+ molecule2: Any,
+ from_positions: list[int],
+ to_positions: list[int],
+ from_monomers: list[str],
+ to_monomers: list[str],
+ ) -> "Mutation":
+ """Validates the mutations and connects the two molecules, ensuring that no double mutations
+ occur – i.e. if a mutation affecting any of the same positions already exists between these proteins,
+ a new mutation cannot be created.
+
+ Raises:
+ ValueError: If input lists have different lengths or if a mutation for any of these positions
+ already exists.
+ """
+ # Instead of checking *any* mutation, retrieve all mutation relationships between these proteins.
+ # Here molecule1.mutation.relationship(molecule2) returns a list of mutation relationship instances.
+ existing_mutations = molecule1.mutation.relationship(molecule2)
+
+ if existing_mutations:
+ raise ValueError(
+ "A mutation relationship affecting one or more of these positions already exists between these proteins."
+ )
+
+ if (
+ len(from_positions) != len(to_positions)
+ or len(from_positions) != len(from_monomers)
+ or len(from_positions) != len(to_monomers)
+ ):
+ raise ValueError("All input lists must have the same length.")
+
+ for from_position, from_monomer in zip(from_positions, from_monomers):
+ if molecule1.sequence[from_position] != from_monomer:
+ raise ValueError(
+ f"Monomer '{from_monomer}' does not match the sequence {molecule1.accession_id} at position {from_position}"
+ )
+
+ for to_position, to_monomer in zip(to_positions, to_monomers):
+ if molecule2.sequence[to_position] != to_monomer:
+ raise ValueError(
+ f"Monomer '{to_monomer}' does not match the sequence {molecule2.accession_id} at position {to_position}"
+ )
+
+ molecule1.mutation.connect(
+ molecule2,
+ {
+ "from_positions": from_positions,
+ "to_positions": to_positions,
+ "from_monomers": from_monomers,
+ "to_monomers": to_monomers,
+ },
+ )
+
+ return cls(
+ from_positions=from_positions,
+ to_positions=to_positions,
+ from_monomers=from_monomers,
+ to_monomers=to_monomers,
+ )
+
+ @property
+ def label(self) -> str:
+ """The label of the mutation."""
+ return ",".join(
+ f"{from_monomer}{from_position}{to_monomer}"
+ for from_position, from_monomer, to_monomer in zip(
+ list(self.from_positions),
+ list(self.from_monomers),
+ list(self.to_monomers),
+ )
+ )
+
+
+class StandardNumberingRel(StructuredRel): # type: ignore
+ positions = ArrayProperty(StringProperty(), required=True)
+
+ @classmethod
+ def validate_and_connect(
+ cls,
+ molecule1: Any,
+ molecule2: Any,
+ positions: list[str],
+ ) -> "StandardNumberingRel":
+ """Validates the positions and connects the two molecules."""
+ molecule1.sequences_protein.connect(
+ molecule2,
+ {
+ "positions": positions,
+ },
+ )
+
+ return cls(
+ positions=positions,
+ )
+
+ @property
+ def label(self) -> str:
+ """The label of the standard numbering."""
+ return f"{self.positions}"
+
class SiteRel(StructuredRel): # type: ignore
positions = ArrayProperty(IntegerProperty(), required=True)
@@ -187,6 +322,13 @@ class Region(StrictStructuredNode):
annotation = StringProperty(
choices=[(e.value, e.name) for e in Annotation], required=True
)
+ sequence_id = StringProperty()
+
+ # Relationships
+ has_mutation_region = RelationshipTo("Region", "MUTATION", model=Mutation)
+ has_standard_numbering = RelationshipTo(
+ "StandardNumbering", "HAS_STANDARD_NUMBERING", model=StandardNumberingRel
+ )
class DNAProteinRel(StructuredRel): # type: ignore
@@ -250,46 +392,53 @@ def label(self) -> str:
return f"{self.start}-{self.end}"
-class CatalyticActivity(StrictStructuredNode):
+class Reaction(StrictStructuredNode):
"""
- A node representing a catalytic activity.
+ A node representing a reaction.
"""
- catalytic_id = IntegerProperty(required=False, unique_index=True)
- name = StringProperty()
+ rhea_id = StringProperty(unique_index=True, required=True)
+ chebi_id = ArrayProperty(StringProperty())
+
+ # Relationships
+ substrate = RelationshipTo("Molecule", "SUBSTRATE")
+ product = RelationshipTo("Molecule", "PRODUCT")
@property
def label(self) -> str:
- """The label of the catalytic activity."""
- return str(self.name)
+ """The label of the reaction."""
+ return f"{self.rhea_id}"
-class StandardNumberingRel(StructuredRel): # type: ignore
- positions = ArrayProperty(StringProperty(), required=True)
+class Molecule(StrictStructuredNode):
+ """
+ A node representing a molecule in the database.
+ """
- @classmethod
- def validate_and_connect(
- cls,
- molecule1: Any,
- molecule2: Any,
- positions: list[str],
- ) -> "StandardNumberingRel":
- """Validates the positions and connects the two molecules."""
- molecule1.sequences_protein.connect(
- molecule2,
- {
- "positions": positions,
- },
- )
+ chebi_id = StringProperty(unique_index=True, required=True)
+ rhea_compound_id = StringProperty()
+ smiles = StringProperty()
- return cls(
- positions=positions,
- )
+ @classmethod
+ def get_or_save(cls, **kwargs: Any) -> "Molecule":
+ chebi_id = kwargs.get("chebi_id")
+ smiles = kwargs.get("smiles")
+ try:
+ molecule = cast(Molecule, cls.nodes.get(chebi_id=chebi_id))
+ return molecule
+ except cls.DoesNotExist:
+ try:
+ molecule = cls(chebi_id=chebi_id, smiles=smiles)
+ molecule.save()
+ return molecule
+ except Exception as e:
+ print(f"Error during saving of the molecule: {e}")
+ raise
@property
def label(self) -> str:
- """The label of the standard numbering."""
- return f"{self.positions}"
+ """The label of the molecule."""
+ return f"{self.chebi_id}"
class StandardNumbering(StrictStructuredNode):
@@ -389,90 +538,6 @@ def label(self) -> str:
return str(self.term)
-class Mutation(StructuredRel): # type: ignore
- """A relationship representing mutations between two sequences."""
-
- from_positions = ArrayProperty(IntegerProperty(), required=True)
- to_positions = ArrayProperty(IntegerProperty(), required=True)
- from_monomers = ArrayProperty(StringProperty(), required=True)
- to_monomers = ArrayProperty(StringProperty(), required=True)
-
- @classmethod
- def validate_and_connect(
- cls,
- molecule1: Any,
- molecule2: Any,
- from_positions: list[int],
- to_positions: list[int],
- from_monomers: list[str],
- to_monomers: list[str],
- ) -> "Mutation":
- """Validates the mutations and connects the two molecules, ensuring that no double mutations
- occur – i.e. if a mutation affecting any of the same positions already exists between these proteins,
- a new mutation cannot be created.
-
- Raises:
- ValueError: If input lists have different lengths or if a mutation for any of these positions
- already exists.
- """
- # Instead of checking *any* mutation, retrieve all mutation relationships between these proteins.
- # Here molecule1.mutation.relationship(molecule2) returns a list of mutation relationship instances.
- existing_mutations = molecule1.mutation.relationship(molecule2)
-
- if existing_mutations:
- raise ValueError(
- "A mutation relationship affecting one or more of these positions already exists between these proteins."
- )
-
- if (
- len(from_positions) != len(to_positions)
- or len(from_positions) != len(from_monomers)
- or len(from_positions) != len(to_monomers)
- ):
- raise ValueError("All input lists must have the same length.")
-
- for from_position, from_monomer in zip(from_positions, from_monomers):
- if molecule1.sequence[from_position] != from_monomer:
- raise ValueError(
- f"Monomer '{from_monomer}' does not match the sequence {molecule1.accession_id} at position {from_position}"
- )
-
- for to_position, to_monomer in zip(to_positions, to_monomers):
- if molecule2.sequence[to_position] != to_monomer:
- raise ValueError(
- f"Monomer '{to_monomer}' does not match the sequence {molecule2.accession_id} at position {to_position}"
- )
-
- molecule1.mutation.connect(
- molecule2,
- {
- "from_positions": from_positions,
- "to_positions": to_positions,
- "from_monomers": from_monomers,
- "to_monomers": to_monomers,
- },
- )
-
- return cls(
- from_positions=from_positions,
- to_positions=to_positions,
- from_monomers=from_monomers,
- to_monomers=to_monomers,
- )
-
- @property
- def label(self) -> str:
- """The label of the mutation."""
- return ",".join(
- f"{from_monomer}{from_position}{to_monomer}"
- for from_position, from_monomer, to_monomer in zip(
- list(self.from_positions),
- list(self.from_monomers),
- list(self.to_monomers),
- )
- )
-
-
class Protein(StrictStructuredNode):
"""A protein sequence node in the database."""
@@ -488,25 +553,35 @@ class Protein(StrictStructuredNode):
locus_tag = StringProperty()
structure_ids = ArrayProperty(StringProperty())
go_terms = ArrayProperty(StringProperty())
- catalytic_name = ArrayProperty(StringProperty())
+ rhea_id = ArrayProperty(StringProperty())
+ chebi_id = ArrayProperty(StringProperty())
embedding = ArrayProperty(
FloatProperty(),
vector_index=VectorIndex(dimensions=1280),
index_type="hnsw",
distance_metric="COSINE",
)
+ TBT = StringProperty()
+ PCL = StringProperty()
+ BHET = StringProperty()
+ PET_powder = StringProperty()
# Relationships
organism = RelationshipTo("Organism", "ORIGINATES_FROM")
site = RelationshipTo("Site", "HAS_SITE", model=SiteRel)
region = RelationshipTo("Region", "HAS_REGION", model=RegionRel)
go_annotation = RelationshipTo("GOAnnotation", "ASSOCIATED_WITH")
- catalytic_annotation = RelationshipTo("CatalyticActivity", "CATALYTIC_ACTIVITY")
+ reaction = RelationshipTo("Reaction", "HAS_REACTION")
+ substrate = RelationshipTo("Molecule", "SUBSTRATE")
+ product = RelationshipTo("Molecule", "PRODUCT")
ontology_object = RelationshipTo("OntologyObject", "ASSOCIATED_WITH")
mutation = RelationshipTo("Protein", "MUTATION", model=Mutation)
pairwise_aligned = RelationshipTo(
"Protein", "PAIRWISE_ALIGNED", model=PairwiseAlignmentResult
)
+ has_standard_numbering = RelationshipTo(
+ "StandardNumbering", "HAS_STANDARD_NUMBERING", model=StandardNumberingRel
+ )
class DNA(StrictStructuredNode):
@@ -533,6 +608,9 @@ class DNA(StrictStructuredNode):
pairwise_aligned = RelationshipTo(
"DNA", "PAIRWISE_ALIGNED", model=PairwiseAlignmentResult
)
+ has_standard_numbering = RelationshipTo(
+ "StandardNumbering", "HAS_STANDARD_NUMBERING", model=StandardNumberingRel
+ )
class CustomRealationship(StructuredRel): # type: ignore
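The new `get_or_save` helpers make node creation idempotent; a short sketch (the ChEBI and taxonomy identifiers are real, the rest is illustrative, and a configured neomodel connection is assumed):

from pyeed.model import Molecule, Organism

# Returns the existing node if one with this taxonomy_id is already stored
organism = Organism.get_or_save(taxonomy_id=562, name="Escherichia coli")

# Same pattern for molecules, keyed on the ChEBI id (CHEBI:15377 is water)
water = Molecule.get_or_save(chebi_id="CHEBI:15377", smiles="O")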