diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 8532bf0..b3e8ab8 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -1,14 +1,15 @@ import importlib import os +from abc import ABC from typing import Callable, List, Optional import pandas as pd import torch import tqdm -from chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.preprocessing.datasets.chebi import ( ChEBIOver50, ChEBIOver100, + ChEBIOverX, ChEBIOverXPartial, ) from lightning_utilities.core.rank_zero import rank_zero_info @@ -48,7 +49,7 @@ def _resolve_property( return getattr(graph_properties, property)() -class GraphPropertiesMixIn(XYBaseDataModule): +class GraphPropertiesMixIn(ChEBIOverX, ABC): READER = GraphPropertyReader def __init__( @@ -107,10 +108,12 @@ def enc_if_not_none(encode, value): if not os.path.isfile(self.get_property_path(property)): rank_zero_info(f"Processing property {property.name}") # read all property values first, then encode + rank_zero_info("\tReading property valeus...") property_values = [ self.reader.read_property(feat, property) for feat in tqdm.tqdm(features) ] + rank_zero_info("\tEncoding property values...") property.encoder.on_start(property_values=property_values) encoded_values = [ enc_if_not_none(property.encoder.encode, value) @@ -166,7 +169,11 @@ def _merge_props_into_base(self, row): if isinstance(property, AtomProperty): x = torch.cat([x, property_values], dim=1) elif isinstance(property, BondProperty): - edge_attr = torch.cat([edge_attr, property_values], dim=1) + # Concat/Duplicate properties values for undirected graph as `edge_index` has first src to tgt edges, then tgt to src edges + edge_attr = torch.cat( + [edge_attr, torch.cat([property_values, property_values], dim=0)], + dim=1, + ) else: molecule_attr = torch.cat([molecule_attr, property_values], dim=1) return GeomData( diff --git 
a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index e71fcfe..f0b995a 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -54,14 +54,14 @@ def _read_data(self, raw_data): x = torch.zeros((mol.GetNumAtoms(), 0)) - edge_attr = torch.zeros((mol.GetNumBonds(), 0)) - - edge_index = torch.tensor( - [ - [bond.GetBeginAtomIdx() for bond in mol.GetBonds()], - [bond.GetEndAtomIdx() for bond in mol.GetBonds()], - ] - ) + # First source to target edges, then target to source edges + src = [bond.GetBeginAtomIdx() for bond in mol.GetBonds()] + tgt = [bond.GetEndAtomIdx() for bond in mol.GetBonds()] + edge_index = torch.tensor([src + tgt, tgt + src], dtype=torch.long) + + # edge_index.shape == [2, num_edges]; edge_attr.shape == [num_edges, num_edge_features] + edge_attr = torch.zeros((edge_index.size(1), 0)) + return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr) def on_finish(self): diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/readers/__init__.py b/tests/unit/readers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/readers/testGraphPropertyReader.py b/tests/unit/readers/testGraphPropertyReader.py new file mode 100644 index 0000000..0222fa4 --- /dev/null +++ b/tests/unit/readers/testGraphPropertyReader.py @@ -0,0 +1,86 @@ +import unittest + +import torch +from torch_geometric.data import Data as GeomData + +from chebai_graph.preprocessing.reader import GraphPropertyReader +from tests.unit.test_data import MoleculeGraph + + +class TestGraphPropertyReader(unittest.TestCase): + """Unit tests for the GraphPropertyReader class, which converts SMILES strings to torch_geometric Data objects.""" + + def setUp(self) -> None: + """Initialize the reader and the reference molecule graph.""" + 
self.reader: GraphPropertyReader = GraphPropertyReader() + self.molecule_graph: MoleculeGraph = MoleculeGraph() + + def test_read_data(self) -> None: + """Test that the reader correctly parses a SMILES string into a graph and matches expected aspirin structure.""" + smiles: str = "CC(=O)OC1=CC=CC=C1C(=O)O" # Aspirin + + data: GeomData = self.reader._read_data(smiles) # noqa + + self.assertIsInstance( + data, + GeomData, + msg="The output should be an instance of torch_geometric.data.Data.", + ) + + self.assertEqual( + data.edge_index.shape[0], + 2, + msg=f"Expected edge_index to have shape [2, num_edges], but got shape {data.edge_index.shape}", + ) + + self.assertEqual( + data.edge_index.shape[1], + data.edge_attr.shape[0], + msg=f"Mismatch between number of edges in edge_index ({data.edge_index.shape[1]}) and edge_attr ({data.edge_attr.shape[0]})", + ) + + self.assertEqual( + len(set(data.edge_index[0].tolist())), + data.x.shape[0], + msg=f"Number of unique source nodes in edge_index ({len(set(data.edge_index[0].tolist()))}) does not match number of nodes in x ({data.x.shape[0]})", + ) + + # Check for duplicates by checking if the rows are the same (direction matters) + _, counts = torch.unique(data.edge_index.t(), dim=0, return_counts=True) + self.assertFalse( + torch.any(counts > 1), + msg="There are duplicates of directed edge in edge_index", + ) + + expected_data: GeomData = self.molecule_graph.get_aspirin_graph() + self.assertTrue( + torch.equal(data.edge_index, expected_data.edge_index), + msg=( + "edge_index tensors do not match.\n" + f"Differences at indices: {(data.edge_index != expected_data.edge_index).nonzero()}.\n" + f"Parsed edge_index:\n{data.edge_index}\nExpected edge_index:\n{expected_data.edge_index}" + f"If fails in future, check if there is change in RDKIT version, the expected graph is generated with RDKIT 2024.9.6" + ), + ) + + self.assertEqual( + data.x.shape[0], + expected_data.x.shape[0], + msg=( + "The number of atoms (nodes) in the 
parsed graph does not match the reference graph.\n" + f"Parsed: {data.x.shape[0]}, Expected: {expected_data.x.shape[0]}" + ), + ) + + self.assertEqual( + data.edge_attr.shape[0], + expected_data.edge_attr.shape[0], + msg=( + "The number of edge attributes does not match the expected value.\n" + f"Parsed: {data.edge_attr.shape[0]}, Expected: {expected_data.edge_attr.shape[0]}" + ), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_data.py b/tests/unit/test_data.py new file mode 100644 index 0000000..4acf41d --- /dev/null +++ b/tests/unit/test_data.py @@ -0,0 +1,102 @@ +import torch +from torch_geometric.data import Data + + +class MoleculeGraph: + """Class representing molecular graph data.""" + + def get_aspirin_graph(self): + """ + Constructs and returns a PyTorch Geometric Data object representing the molecular graph of Aspirin. + + Aspirin -> CC(=O)OC1=CC=CC=C1C(=O)O ; CHEBI:15365 + + Node labels (atom indices): + O2 C5———C6 + \ / \ + C1———O3———C4 C7 + / \ / + C0 C9———C8 + / + C10 + / \ + O12 O11 + + + Returns: + torch_geometric.data.Data: A Data object with attributes: + - x (FloatTensor): Node feature matrix of shape (num_nodes, 1). + - edge_index (LongTensor): Graph connectivity in COO format of shape (2, num_edges). + - edge_attr (FloatTensor): Edge feature matrix of shape (num_edges, 1). 
+ + Refer: + For graph construction: https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html + """ + + # --- Node features: atomic numbers (C=6, O=8) --- + # Shape of x : num_nodes x num_of_node_features + # fmt: off + x = torch.tensor( + [ + [6], # C0 - This feature belongs to atom/node with 0 value in edge_index + [6], # C1 - This feature belongs to atom/node with 1 value in edge_index + [8], # O2 - This feature belongs to atom/node with 2 value in edge_index + [8], # O3 - This feature belongs to atom/node with 3 value in edge_index + [6], # C4 - This feature belongs to atom/node with 4 value in edge_index + [6], # C5 - This feature belongs to atom/node with 5 value in edge_index + [6], # C6 - This feature belongs to atom/node with 6 value in edge_index + [6], # C7 - This feature belongs to atom/node with 7 value in edge_index + [6], # C8 - This feature belongs to atom/node with 8 value in edge_index + [6], # C9 - This feature belongs to atom/node with 9 value in edge_index + [6], # C10 - This feature belongs to atom/node with 10 value in edge_index + [8], # O11 - This feature belongs to atom/node with 11 value in edge_index + [8], # O12 - This feature belongs to atom/node with 12 value in edge_index + ], + dtype=torch.float, + ) + # fmt: on + + # --- Edge list (bidirectional) --- + # Shape of edge_index for undirected graph: 2 x num_of_edges; (2x26) + # Generated using RDKIT 2024.9.6 + # fmt: off + _edge_index = torch.tensor([ + [0, 1, 1, 3, 4, 5, 6, 7, 8, 9, 10, 10, 9], # Start atoms (u) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 4] # End atoms (v) + ], dtype=torch.long) + # fmt: on + + # Reverse the edges + reversed_edge_index = _edge_index[[1, 0], :] + + # First all directed edges from source to target are placed, + # then all directed edges from target to source are placed --- this is needed + undirected_edge_index = torch.cat([_edge_index, reversed_edge_index], dim=1) + + # --- Dummy edge features --- + # Shape of 
undirected_edge_attr: num_of_edges x num_of_edges_features (26 x 1) + # fmt: off + _edge_attr = torch.tensor([ + [1], # C0 - C1, These two features belong to elements at index 0 in `edge_index` + [2], # C1 - O2, These two features belong to elements at index 1 in `edge_index` + [2], # C1 - O3, These two features belong to elements at index 2 in `edge_index` + [2], # O3 - C4, These two features belong to elements at index 3 in `edge_index` + [1], # C4 - C5, These two features belong to elements at index 4 in `edge_index` + [1], # C5 - C6, These two features belong to elements at index 5 in `edge_index` + [1], # C6 - C7, These two features belong to elements at index 6 in `edge_index` + [1], # C7 - C8, These two features belong to elements at index 7 in `edge_index` + [1], # C8 - C9, These two features belong to elements at index 8 in `edge_index` + [1], # C9 - C10, These two features belong to elements at index 9 in `edge_index` + [1], # C10 - O11, These two features belong to elements at index 10 in `edge_index` + [1], # C10 - O12, These two features belong to elements at index 11 in `edge_index` + [1], # C9 - C4, These two features belong to elements at index 12 in `edge_index` + ], dtype=torch.float) + # fmt: on + + # Edge attributes must be aligned in the same order as edge_index + undirected_edge_attr = torch.cat([_edge_attr, _edge_attr], dim=0) + + # Create graph data object + return Data( + x=x, edge_index=undirected_edge_index, edge_attr=undirected_edge_attr + )