diff --git a/.gitignore b/.gitignore index 3f0c0ab..d01eecf 100644 --- a/.gitignore +++ b/.gitignore @@ -169,3 +169,8 @@ cython_debug/ electra_pretrained.ckpt .isort.cfg /.vscode + +*.err +*.out +*.sh +*.ckpt diff --git a/chebai_graph/models/__init__.py b/chebai_graph/models/__init__.py index e69de29..9e20b2d 100644 --- a/chebai_graph/models/__init__.py +++ b/chebai_graph/models/__init__.py @@ -0,0 +1,19 @@ +from .augmented import ( + GATAugNodePoolGraphPred, + GATGraphNodeFGNodePoolGraphPred, + ResGatedAugNodePoolGraphPred, + ResGatedGraphNodeFGNodePoolGraphPred, +) +from .dynamic_gni import ResGatedDynamicGNIGraphPred +from .gat import GATGraphPred +from .resgated import ResGatedGraphPred + +__all__ = [ + "ResGatedGraphPred", + "ResGatedAugNodePoolGraphPred", + "ResGatedGraphNodeFGNodePoolGraphPred", + "GATGraphPred", + "GATAugNodePoolGraphPred", + "GATGraphNodeFGNodePoolGraphPred", + "ResGatedDynamicGNIGraphPred", +] diff --git a/chebai_graph/models/augmented.py b/chebai_graph/models/augmented.py new file mode 100644 index 0000000..fdb5388 --- /dev/null +++ b/chebai_graph/models/augmented.py @@ -0,0 +1,45 @@ +from .base import AugmentedNodePoolingNet, GraphNodeFGNodePoolingNet +from .gat import GATGraphPred +from .resgated import ResGatedGraphPred + + +class ResGatedAugNodePoolGraphPred(AugmentedNodePoolingNet, ResGatedGraphPred): + """ + Combines: + - AugmentedNodePoolingNet: Pools atom and augmented node embeddings (optionally with molecule attributes). + - ResGatedGraphPred: Residual gated network for final graph prediction. + """ + + ... + + +class GATAugNodePoolGraphPred(AugmentedNodePoolingNet, GATGraphPred): + """ + Combines: + - AugmentedNodePoolingNet: Pools atom and augmented node embeddings (optionally with molecule attributes). + - GATGraphPred: Graph attention network for final graph prediction. + """ + + ... + + +class ResGatedGraphNodeFGNodePoolGraphPred( + GraphNodeFGNodePoolingNet, ResGatedGraphPred +): + """ + Combines: + - GraphNodeFGNodePoolingNet: Pools atom, functional group, and graph nodes (optionally with molecule attributes). + - ResGatedGraphPred: Residual gated network for final graph prediction. + """ + + ... + + +class GATGraphNodeFGNodePoolGraphPred(GraphNodeFGNodePoolingNet, GATGraphPred): + """ + Combines: + - GraphNodeFGNodePoolingNet: Pools atom, functional group, and graph nodes (optionally with molecule attributes). + - GATGraphPred: Graph attention network for final graph prediction. + """ + + ... diff --git a/chebai_graph/models/base.py b/chebai_graph/models/base.py new file mode 100644 index 0000000..3dd7d57 --- /dev/null +++ b/chebai_graph/models/base.py @@ -0,0 +1,706 @@ +from abc import ABC, abstractmethod +from typing import Optional + +import torch +from chebai.models.base import ChebaiBaseNet +from chebai.preprocessing.structures import XYData +from torch_geometric.data import Data as GraphData +from torch_scatter import scatter_add + + +class GraphBaseNet(ChebaiBaseNet, ABC): + """ + Base class for graph-based prediction networks. + """ + + def _get_prediction_and_labels( + self, data: XYData, labels: torch.Tensor, output: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply sigmoid activation to outputs and return processed labels. + + Args: + data (XYData): Input batch data. + labels (torch.Tensor): Ground-truth labels. + output (torch.Tensor): Raw model output. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Tuple of (predictions, labels). + """ + return torch.sigmoid(output), labels.int() + + def _process_labels_in_batch(self, batch: XYData) -> torch.Tensor | None: + """ + Process labels from XYData batch. + + Returns: + torch.Tensor | None: Processed labels if present, else None. + """ + return batch.y.float() if batch.y is not None else None + + +class GraphModelBase(torch.nn.Module, ABC): + """ + Abstract base class for graph models with configurable architecture. + """ + + def __init__(self, config: dict, **kwargs) -> None: + """ + Initialize model hyperparameters from configuration. + + Args: + config (dict): Configuration dictionary with keys: + - 'num_layers' + - 'in_channels' + - 'hidden_channels' + - 'out_channels' + - 'edge_dim' + - 'dropout' + **kwargs: Additional keyword arguments for torch.nn.Module. + """ + super().__init__(**kwargs) + self.num_layers = int(config["num_layers"]) + assert self.num_layers > 1, "Need atleast two convolution layers" + self.in_channels = int(config["in_channels"]) # number of node/atom properties + self.hidden_channels = int(config["hidden_channels"]) + self.out_channels = int(config["out_channels"]) + self.edge_dim = int(config["edge_dim"]) # number of bond properties + self.dropout = float(config["dropout"]) + + +class GraphNetWrapper(GraphBaseNet, ABC): + """ + Base wrapper class for GNNs with linear layers for property prediction. + """ + + def __init__( + self, + config: dict, + n_linear_layers: int, + n_molecule_properties: Optional[int] = 0, + use_batch_norm: bool = False, + **kwargs, + ): + """ + Args: + config (dict): Model configuration. + n_linear_layers (int): Number of linear layers. + n_molecule_properties (int): Number of molecular-level features. + **kwargs: Additional arguments. + """ + super().__init__(**kwargs) + self.gnn = self._get_gnn(config) + gnn_out_dim = int(config["out_channels"]) + self.activation = torch.nn.ELU + self.lin_input_dim = self._get_lin_seq_input_dim( + gnn_out_dim=gnn_out_dim, + n_molecule_properties=( + n_molecule_properties if n_molecule_properties is not None else 0 + ), + ) + self.use_batch_norm = use_batch_norm + if self.use_batch_norm: + self.batch_norm = torch.nn.BatchNorm1d(self.lin_input_dim) + + lin_hidden_dim = kwargs.get("lin_hidden_dim", gnn_out_dim) + self.lin_sequential: torch.nn.Sequential = self._get_linear_module_list( + n_linear_layers=n_linear_layers, + in_dim=self.lin_input_dim, + hidden_dim=lin_hidden_dim, + out_dim=self.out_dim, + ) + + @abstractmethod + def _get_gnn(self, config: dict) -> torch.nn.Module: + """ + Create the graph neural network. + + Args: + config (dict): Configuration dictionary. + + Returns: + torch.nn.Module: Instantiated GNN module. + """ + pass + + def _get_lin_seq_input_dim( + self, gnn_out_dim: int, n_molecule_properties: int + ) -> int: + """ + Compute input dimension for the linear layers. + + Args: + gnn_out_dim (int): Output dimension of GNN. + n_molecule_properties (int): Number of molecule-level features. + + Returns: + int: Total input dimension. + """ + return gnn_out_dim + n_molecule_properties + + def _get_linear_module_list( + self, + n_linear_layers: int, + in_dim: int, + hidden_dim: int, + out_dim: int, + ) -> torch.nn.Sequential: + """ + Construct a sequential module of linear layers. + + Args: + n_linear_layers (int): Number of linear layers. + in_dim (int): Input dimension. + hidden_dim (int): Hidden dimension. + out_dim (int): Output dimension. + + Returns: + torch.nn.Sequential: Linear layers with activations. + """ + if n_linear_layers < 1: + raise ValueError("n_linear_layers must be at least 1") + + layers = [] + if n_linear_layers == 1: + layers.append(torch.nn.Linear(in_dim, out_dim)) + else: + layers.append(torch.nn.Linear(in_dim, hidden_dim)) + layers.append(self.activation()) + for _ in range(n_linear_layers - 2): + layers.append(torch.nn.Linear(hidden_dim, hidden_dim)) + layers.append(self.activation()) + layers.append(torch.nn.Linear(hidden_dim, out_dim)) + + return torch.nn.Sequential(*layers) + + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass through GNN, pooling and linear layers. + + Args: + batch (dict): Input batch with graph features. + + Returns: + torch.Tensor: Predicted output. + """ + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + a = self.gnn(batch) + a = scatter_add(a, graph_data.batch, dim=0) + a = torch.cat([a, graph_data.molecule_attr], dim=1) + if self.use_batch_norm: + a = self.batch_norm(a) + return self.lin_sequential(a) + + +class AugmentedNodePoolingNet(GraphNetWrapper, ABC): + """ + A pooling network that aggregates: + - Atom node embeddings + - Molecular attributes (if provided else skipped) + - Augmented node embeddings (FG nodes and graph node) + + The concatenated vector is then passed through a linear sequential block. + """ + + def _get_lin_seq_input_dim( + self, gnn_out_dim: int, n_molecule_properties: int + ) -> int: + """ + Compute the input dimension for the final linear sequential block. + + Includes: + - Atom embeddings + - Molecular attributes (if any) + - Augmented node embeddings + + Args: + gnn_out_dim (int): Dimension of the GNN output per node. + n_molecule_properties (int): Number of molecule-level attributes. + + Returns: + int: Total input dimension for the linear sequential block. + """ + return gnn_out_dim + n_molecule_properties + gnn_out_dim + + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass for pooling node embeddings. + + Steps: + 1. Identify atom nodes and augmented nodes. + 2. Compute node embeddings with the GNN. + 3. Aggregate embeddings for atoms and augmented nodes separately using scatter add. + 4. Concatenate: + - Atom nodes vector + - Molecular attributes + - Augmented nodes vector + 5. Pass the concatenated vector through the linear sequential block. + + Args: + batch (dict): Input batch containing graph data and features. + + Returns: + torch.Tensor: Output tensor after pooling and linear transformation. + """ + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + + is_atom_node = graph_data.is_atom_node.bool() + is_augmented_node = ~is_atom_node + + node_embeddings = self.gnn(batch) + + atoms_embeddings = node_embeddings[is_atom_node] + atoms_batch = graph_data.batch[is_atom_node] + + augmented_nodes_embeddings = node_embeddings[is_augmented_node] + augmented_nodes_batch = graph_data.batch[is_augmented_node] + + # Scatter add separately + atoms_vec = scatter_add(atoms_embeddings, atoms_batch, dim=0) + aug_nodes_vec = scatter_add( + augmented_nodes_embeddings, augmented_nodes_batch, dim=0 + ) + + # Concatenate all + graph_vector = torch.cat( + [atoms_vec, graph_data.molecule_attr, aug_nodes_vec], dim=1 + ) + + return self.lin_sequential(graph_vector) + + +class FGNodePoolingNet(GraphNetWrapper, ABC): + """ + A pooling network that pools node embeddings by aggregating: + - All non-functional-group nodes' embeddings (atom and graph node) + - Molecular attributes + - Functional group node embeddings + + The concatenated vector is then passed through a linear sequential block. + """ + + def _get_lin_seq_input_dim( + self, gnn_out_dim: int, n_molecule_properties: int + ) -> int: + """ + Computes the input dimension for the final linear sequential block. + + Combines: + - All nodes embeddings except functional group nodes + - Molecular attributes + - Functional group node embeddings + + Args: + gnn_out_dim (int): Dimension of the GNN output per node. + n_molecule_properties (int): Number of molecule-level attributes. + + Returns: + int: Total input dimension for the linear sequential block. + """ + return gnn_out_dim + n_molecule_properties + gnn_out_dim + + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass for pooling node embeddings. + + Steps: + 1. Identify graph, atom, and functional group nodes. + 2. Aggregate embeddings for remaining nodes and functional group nodes separately. + 3. Concatenate: + - Remaining nodes vector + - Molecular attributes + - Functional group nodes vector + 4. Pass the concatenated vector through the linear sequential block. + + Args: + batch (dict): Batch containing graph data and features. + + Returns: + torch.Tensor: Output tensor after pooling and linear transformation. + """ + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + + is_graph_node = graph_data.is_graph_node.bool() + is_atom_node = graph_data.is_atom_node.bool() + is_fg_node = (~is_atom_node) & (~is_graph_node) + is_remaining_node = ~is_fg_node + + node_embeddings = self.gnn(batch) + + remaining_nodes_embedding = node_embeddings[is_remaining_node] + remaining_nodes_batch = graph_data.batch[is_remaining_node] + + fg_nodes_embeddings = node_embeddings[is_fg_node] + fg_nodes_batch = graph_data.batch[is_fg_node] + + # Scatter add separately + remaining_nodes_vec = scatter_add( + remaining_nodes_embedding, remaining_nodes_batch, dim=0 + ) + fg_nodes_vec = scatter_add(fg_nodes_embeddings, fg_nodes_batch, dim=0) + + # Concatenate all + graph_vector = torch.cat( + [remaining_nodes_vec, graph_data.molecule_attr, fg_nodes_vec], dim=1 + ) + + return self.lin_sequential(graph_vector) + + +class GraphNodeFGNodePoolingNet(GraphNetWrapper, ABC): + """ + A pooling network that pools node embeddings by aggregating: + - Atom nodes + - Molecular attributes + - Functional group node embeddings + - Graph node embeddings + + The concatenated vector is then passed through a linear sequential block. + """ + + def _get_lin_seq_input_dim( + self, gnn_out_dim: int, n_molecule_properties: int + ) -> int: + """ + Computes the input dimension for the final linear sequential block. + + Combines: + - Atom embeddings + - Molecular attributes + - Functional group node embeddings + - Graph node embeddings + + Args: + gnn_out_dim (int): Dimension of the GNN output per node. + n_molecule_properties (int): Number of molecule-level attributes. + + Returns: + int: Total input dimension for the linear sequential block. + """ + return gnn_out_dim + n_molecule_properties + gnn_out_dim + gnn_out_dim + + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass for pooling node embeddings. + + Steps: + 1. Identify graph, atom, and functional group nodes. + 2. Aggregate embeddings for each node type separately. + 3. Concatenate: + - Atom nodes vector + - Molecular attributes + - Functional group nodes vector + - Graph node vector + 4. Pass the concatenated vector through the linear sequential block. + + Args: + batch (dict): Batch containing graph data and features. + + Returns: + torch.Tensor: Output tensor after pooling and linear transformation. + """ + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + + is_graph_node = graph_data.is_graph_node.bool() + is_atom_node = graph_data.is_atom_node.bool() + is_fg_node = (~is_atom_node) & (~is_graph_node) + + node_embeddings = self.gnn(batch) + + graph_node_embedding = node_embeddings[is_graph_node] + graph_node_batch = graph_data.batch[is_graph_node] + + atoms_embeddings = node_embeddings[is_atom_node] + atoms_batch = graph_data.batch[is_atom_node] + + fg_nodes_embeddings = node_embeddings[is_fg_node] + fg_nodes_batch = graph_data.batch[is_fg_node] + + # Scatter add separately + graph_node_vec = scatter_add(graph_node_embedding, graph_node_batch, dim=0) + atoms_vec = scatter_add(atoms_embeddings, atoms_batch, dim=0) + fg_nodes_vec = scatter_add(fg_nodes_embeddings, fg_nodes_batch, dim=0) + + # Concatenate all + graph_vector = torch.cat( + [atoms_vec, graph_data.molecule_attr, fg_nodes_vec, graph_node_vec], dim=1 + ) + + return self.lin_sequential(graph_vector) + + +class GraphNodePoolingNet(GraphNetWrapper, ABC): + """ + Pooling using non-graph nodes and graph node embeddings. + """ + + def _get_lin_seq_input_dim( + self, gnn_out_dim: int, n_molecule_properties: int + ) -> int: + """ + Return input dimension including graph node embeddings. + - all_nodes_embeddings_except_graph_node + molecule attributes + graph_node_embedding + + Returns: + int: Total input dimension. + """ + return gnn_out_dim + n_molecule_properties + gnn_out_dim + + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass with separate pooling for graph and other nodes. + + Args: + batch (dict): Input batch. + + Returns: + torch.Tensor: Predicted output. + """ + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + is_graph_node = graph_data.is_graph_node.bool() + is_not_graph_node = ~is_graph_node + + node_embeddings = self.gnn(batch) + graph_node_embedding = node_embeddings[is_graph_node] + graph_node_batch = graph_data.batch[is_graph_node] + + remaining_nodes_embedding = node_embeddings[is_not_graph_node] + remaining_nodes_batch = graph_data.batch[is_not_graph_node] + + graph_node_vec = scatter_add(graph_node_embedding, graph_node_batch, dim=0) + remaining_nodes_vec = scatter_add( + remaining_nodes_embedding, remaining_nodes_batch, dim=0 + ) + + graph_vector = torch.cat( + [remaining_nodes_vec, graph_data.molecule_attr, graph_node_vec], dim=1 + ) + return self.lin_sequential(graph_vector) + + +class FGNodePoolingNoGraphNodeNet(GraphNetWrapper, ABC): + """ + Graph Node not considered here in any computation. + """ + + def _get_lin_seq_input_dim( + self, gnn_out_dim: int, n_molecule_properties: int + ) -> int: + """ + Compute input dimension including: + - atom_embeddings + - molecule attributes + - functional_group_node_embeddings + + Returns: + int: Total input dimension. + """ + return gnn_out_dim + n_molecule_properties + gnn_out_dim + + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass pooling atoms and functional group nodes. + Graph nodes are ignored. + + Args: + batch (dict): Input batch. + + Returns: + torch.Tensor: Predicted output. + """ + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + is_graph_node = graph_data.is_graph_node.bool() + is_atom_node = graph_data.is_atom_node.bool() + is_fg_node = (~is_atom_node) & (~is_graph_node) + + node_embeddings = self.gnn(batch) + + atoms_embeddings = node_embeddings[is_atom_node] + atoms_batch = graph_data.batch[is_atom_node] + + fg_nodes_embeddings = node_embeddings[is_fg_node] + fg_nodes_batch = graph_data.batch[is_fg_node] + + atoms_vec = scatter_add(atoms_embeddings, atoms_batch, dim=0) + fg_nodes_vec = scatter_add(fg_nodes_embeddings, fg_nodes_batch, dim=0) + + graph_vector = torch.cat( + [atoms_vec, graph_data.molecule_attr, fg_nodes_vec], dim=1 + ) + + return self.lin_sequential(graph_vector) + + +class GraphNodeNoFGNodePoolingNet(GraphNetWrapper, ABC): + """ + Functional Group Nodes not considered here in any computation. + """ + + def _get_lin_seq_input_dim( + self, gnn_out_dim: int, n_molecule_properties: int + ) -> int: + """ + Compute input dimension including: + - atom_embeddings + - molecule attributes + - graph_node_embeddings + + Returns: + int: Total input dimension. + """ + return gnn_out_dim + n_molecule_properties + gnn_out_dim + + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass pooling atoms and graph nodes. + Functional group nodes are ignored. + + Args: + batch (dict): Input batch. + + Returns: + torch.Tensor: Predicted output. + """ + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + is_graph_node = graph_data.is_graph_node.bool() + is_atom_node = graph_data.is_atom_node.bool() + + node_embeddings = self.gnn(batch) + + graph_node_embedding = node_embeddings[is_graph_node] + graph_node_batch = graph_data.batch[is_graph_node] + + atoms_embeddings = node_embeddings[is_atom_node] + atoms_batch = graph_data.batch[is_atom_node] + + graph_node_vec = scatter_add(graph_node_embedding, graph_node_batch, dim=0) + atoms_vec = scatter_add(atoms_embeddings, atoms_batch, dim=0) + + graph_vector = torch.cat( + [atoms_vec, graph_data.molecule_attr, graph_node_vec], dim=1 + ) + + return self.lin_sequential(graph_vector) + + +class AugmentedOnlyPoolingNet(GraphNetWrapper, ABC): + """ + Only augmented node embeddings are pooled. + """ + + def _get_lin_seq_input_dim( + self, gnn_out_dim: int, n_molecule_properties: int + ) -> int: + """ + Return input dimension using only augmented node embeddings. + + Returns: + int: Total input dimension. + """ + return gnn_out_dim + n_molecule_properties + + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass pooling only augmented nodes. + + Args: + batch (dict): Input batch. + + Returns: + torch.Tensor: Predicted output. + """ + graph_data = batch["features"][0] + is_atom_node = graph_data.is_atom_node.bool() + augmented_nodes_embeddings = self.gnn(batch)[~is_atom_node] + augmented_nodes_batch = graph_data.batch[~is_atom_node] + + aug_nodes_vec = scatter_add( + augmented_nodes_embeddings, augmented_nodes_batch, dim=0 + ) + graph_vector = torch.cat([aug_nodes_vec, graph_data.molecule_attr], dim=1) + + return self.lin_sequential(graph_vector) + + +class FGOnlyPoolingNet(GraphNetWrapper, ABC): + """ + Only functional group node embeddings are pooled. + """ + + def _get_lin_seq_input_dim( + self, gnn_out_dim: int, n_molecule_properties: int + ) -> int: + """ + Return input dimension using only FG node embeddings. + + Returns: + int: Total input dimension. + """ + return gnn_out_dim + n_molecule_properties + + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass pooling only functional group nodes. + + Args: + batch (dict): Input batch. + + Returns: + torch.Tensor: Predicted output. + """ + graph_data = batch["features"][0] + is_graph_node = graph_data.is_graph_node.bool() + is_atom_node = graph_data.is_atom_node.bool() + is_fg_node = (~is_atom_node) & (~is_graph_node) + fg_nodes_embeddings = self.gnn(batch)[~is_fg_node] + fg_nodes_batch = graph_data.batch[~is_fg_node] + + fg_nodes_vec = scatter_add(fg_nodes_embeddings, fg_nodes_batch, dim=0) + graph_vector = torch.cat([fg_nodes_vec, graph_data.molecule_attr], dim=1) + + return self.lin_sequential(graph_vector) + + +class GraphNodeOnlyPoolingNet(GraphNetWrapper, ABC): + """ + Only graph node embeddings are pooled. + """ + + def _get_lin_seq_input_dim( + self, gnn_out_dim: int, n_molecule_properties: int + ) -> int: + """ + Return input dimension using only graph node embeddings. + + Returns: + int: Total input dimension. + """ + return gnn_out_dim + n_molecule_properties + + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass pooling only graph nodes. + + Args: + batch (dict): Input batch. + + Returns: + torch.Tensor: Predicted output. + """ + graph_data = batch["features"][0] + is_graph_node = graph_data.is_graph_node.bool() + + graph_node_embedding = self.gnn(batch)[~is_graph_node] + graph_node_batch = graph_data.batch[~is_graph_node] + + graph_node_vec = scatter_add(graph_node_embedding, graph_node_batch, dim=0) + graph_vector = torch.cat([graph_node_vec, graph_data.molecule_attr], dim=1) + + return self.lin_sequential(graph_vector) diff --git a/chebai_graph/models/dynamic_gni.py b/chebai_graph/models/dynamic_gni.py new file mode 100644 index 0000000..20a6c6a --- /dev/null +++ b/chebai_graph/models/dynamic_gni.py @@ -0,0 +1,165 @@ +from typing import Any + +import torch +from torch import Tensor +from torch.nn import ELU +from torch_geometric.data import Data as GraphData +from torch_geometric.nn.models.basic_gnn import BasicGNN + +from chebai_graph.preprocessing.reader import RandomFeatureInitializationReader + +from .base import GraphModelBase, GraphNetWrapper +from .resgated import ResGatedModel + + +class ResGatedDynamicGNI(GraphModelBase): + """ + Base model class for applying ResGatedGraphConv layers to graph-structured data + with dynamic initialization of features for nodes and edges. + + Args: + config (dict): Configuration dictionary containing model hyperparameters. + **kwargs: Additional keyword arguments for parent class. + """ + + def __init__(self, config: dict[str, Any], **kwargs: Any): + super().__init__(config=config, **kwargs) + self.activation = ELU() # Instantiate ELU once for reuse. + + distribution = config.get("distribution", "normal") + assert distribution in RandomFeatureInitializationReader.DISTRIBUTIONS, ( + f"Unsupported distribution: {distribution}. " + f"Choose from {RandomFeatureInitializationReader.DISTRIBUTIONS}." + ) + self.distribution = distribution + + self.complete_randomness = ( + str(config.get("complete_randomness", "True")).lower() == "true" + ) + + print("Using complete randomness: ", self.complete_randomness) + + if not self.complete_randomness: + assert ( + "pad_node_features" in config or "pad_edge_features" in config + ), "Missing 'pad_node_features' or 'pad_edge_features' in config when complete_randomness is False" + self.pad_node_features = ( + int(config["pad_node_features"]) + if config.get("pad_node_features") is not None + else None + ) + if self.pad_node_features is not None: + print( + f"[Info] Node features will be padded with {self.pad_node_features} " + f"new set of random features from distribution {self.distribution} " + f"in each forward pass." + ) + + self.pad_edge_features = ( + int(config["pad_edge_features"]) + if config.get("pad_edge_features") is not None + else None + ) + if self.pad_edge_features is not None: + print( + f"[Info] Edge features will be padded with {self.pad_edge_features} " + f"new set of random features from distribution {self.distribution} " + f"in each forward pass." + ) + + assert ( + self.pad_node_features > 0 or self.pad_edge_features > 0 + ), "'pad_node_features' or 'pad_edge_features' must be positive integers" + + self.resgated: BasicGNN = ResGatedModel( + in_channels=self.in_channels, + hidden_channels=self.hidden_channels, + out_channels=self.out_channels, + num_layers=self.num_layers, + edge_dim=self.edge_dim, + act=self.activation, + ) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + def forward(self, batch: dict[str, Any]) -> Tensor: + """ + Forward pass of the model. + + Args: + batch (dict): A batch containing graph input features under the key "features". + + Returns: + Tensor: The output node-level embeddings after the final activation. + """ + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData), "Expected GraphData instance" + + new_x = None + new_edge_attr = None + if self.complete_randomness: + new_x = torch.empty( + graph_data.x.shape[0], graph_data.x.shape[1], device=self.device + ) + RandomFeatureInitializationReader.random_gni(new_x, self.distribution) + + new_edge_attr = torch.empty( + graph_data.edge_attr.shape[0], + graph_data.edge_attr.shape[1], + device=self.device, + ) + RandomFeatureInitializationReader.random_gni( + new_edge_attr, self.distribution + ) + else: + if self.pad_node_features is not None: + pad_node = torch.empty( + graph_data.x.shape[0], + self.pad_node_features, + device=self.device, + ) + RandomFeatureInitializationReader.random_gni( + pad_node, self.distribution + ) + new_x = torch.cat((graph_data.x, pad_node), dim=1) + + if self.pad_edge_features is not None: + pad_edge = torch.empty( + graph_data.edge_attr.shape[0], + self.pad_edge_features, + device=self.device, + ) + RandomFeatureInitializationReader.random_gni( + pad_edge, self.distribution + ) + new_edge_attr = torch.cat((graph_data.edge_attr, pad_edge), dim=1) + + assert ( + new_x is not None and new_edge_attr is not None + ), "Feature initialization failed" + out = self.resgated( + x=new_x.float(), + edge_index=graph_data.edge_index.long(), + edge_attr=new_edge_attr.float(), + ) + + return self.activation(out) + + +class ResGatedDynamicGNIGraphPred(GraphNetWrapper): + """ + Wrapper for graph-level prediction using ResGatedDynamicGNI. + + This class instantiates the core GNN model using the provided config. + """ + + def _get_gnn(self, config: dict[str, Any]) -> ResGatedDynamicGNI: + """ + Returns the core ResGated GNN model. + + Args: + config (dict): Configuration dictionary for the GNN model. + + Returns: + ResGatedDynamicGNI: The core graph convolutional network. + """ + return ResGatedDynamicGNI(config=config) diff --git a/chebai_graph/models/gat.py b/chebai_graph/models/gat.py new file mode 100644 index 0000000..13dcabc --- /dev/null +++ b/chebai_graph/models/gat.py @@ -0,0 +1,94 @@ +import torch +from torch.nn import ELU +from torch_geometric.data import Data as GraphData +from torch_geometric.nn.models import GAT + +from .base import GraphModelBase, GraphNetWrapper + + +class GATGraphConvNetBase(GraphModelBase): + """ + Graph Attention Network (GAT) base module for graph convolution. + + Uses PyTorch Geometric's `GAT` implementation to process atomic node features + and bond edge attributes through multiple attention heads and layers. + """ + + def __init__(self, config: dict, **kwargs): + """ + Initialize the GATGraphConvNetBase. + + Args: + config (dict): Model configuration containing: + - 'heads' (int): Number of attention heads. + - 'v2' (bool): Whether to use the GATv2 variant. + - Other required GraphModelBase parameters. + **kwargs: Additional arguments for the base class. + """ + super().__init__(config=config, **kwargs) + self.heads = int(config["heads"]) + self.v2 = bool(config["v2"]) + local_kwargs = {} + if self.v2: + local_kwargs["share_weights"] = bool(config.get("share_weights", False)) + self.activation = ELU() # Instantiate ELU once for reuse. + self.gat = GAT( + in_channels=self.in_channels, + hidden_channels=self.hidden_channels, + out_channels=self.out_channels, + num_layers=self.num_layers, + dropout=self.dropout, + edge_dim=self.edge_dim, + heads=self.heads, + v2=self.v2, + act=self.activation, + **local_kwargs, + ) + + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass through the GAT network. + + Processes atomic node features and edge attributes, and applies + an ELU activation to the output. + + Args: + batch (dict): Input batch containing: + - 'features': A list with a `GraphData` object as its first element. + + Returns: + torch.Tensor: Node embeddings after GAT and activation. + """ + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + + out = self.gat( + x=graph_data.x.float(), + edge_index=graph_data.edge_index.long(), + edge_attr=graph_data.edge_attr, + ) + + return self.activation(out) + + +class GATGraphPred(GraphNetWrapper): + """ + GAT-based graph prediction model. + + Uses a `GATGraphConvNetBase` as the GNN backbone for generating node embeddings, + which are then pooled by the GraphNetWrapper for final prediction. + """ + + NAME = "GATGraphPred" + + def _get_gnn(self, config: dict) -> GATGraphConvNetBase: + """ + Instantiate the GAT graph convolutional network base. + + Args: + config (dict): Model configuration. + + Returns: + GATGraphConvNetBase: The initialized GNN. + """ + return GATGraphConvNetBase(config=config) diff --git a/chebai_graph/models/graph.py b/chebai_graph/models/graph.py index 5da9a62..3be0fdd 100644 --- a/chebai_graph/models/graph.py +++ b/chebai_graph/models/graph.py @@ -3,8 +3,6 @@ import torch import torch.nn.functional as F -from chebai.models.base import ChebaiBaseNet -from chebai.preprocessing.structures import XYData from torch import nn from torch_geometric import nn as tgnn from torch_geometric.data import Data as GraphData @@ -12,15 +10,9 @@ from chebai_graph.loss.pretraining import MaskPretrainingLoss -logging.getLogger("pysmiles").setLevel(logging.CRITICAL) - +from .base import GraphBaseNet -class GraphBaseNet(ChebaiBaseNet): - def _get_prediction_and_labels(self, data, labels, output): - return torch.sigmoid(output), labels.int() - - def _process_labels_in_batch(self, batch: XYData) -> torch.Tensor: - return batch.y.float() if batch.y is not None else None +logging.getLogger("pysmiles").setLevel(logging.CRITICAL) class JCIGraphNet(GraphBaseNet): @@ -188,6 +180,67 @@ def forward(self, batch): return a +class ResGatedAugmentedGraphPred(GraphBaseNet): + """GNN for graph-level prediction for augmented graphs""" + + NAME = "ResGatedAugmentedGraphPred" + + def __init__( + self, + config: typing.Dict, + n_linear_layers=2, + **kwargs, + ): + super().__init__(**kwargs) + self.gnn = ResGatedGraphConvNetBase(config, **kwargs) + self.linear_layers = torch.nn.ModuleList( + [ + torch.nn.Linear( + self.gnn.hidden_length + + (i == 0) * self.gnn.n_molecule_properties + + (i == 0) * self.gnn.hidden_length, + self.gnn.hidden_length, + ) + for i in range(n_linear_layers - 1) + ] + ) + self.final_layer = nn.Linear(self.gnn.hidden_length, self.out_dim) + + def forward(self, batch): + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + is_atom_node = graph_data.is_atom_node.bool() # Boolean mask: shape [num_nodes] + is_augmented_node = ~is_atom_node + + node_embeddings = self.gnn(batch) + + atom_embeddings = node_embeddings[is_atom_node] + atom_batch = graph_data.batch[is_atom_node] + + augmented_node_embeddings = node_embeddings[is_augmented_node] + augmented_node_batch = graph_data.batch[is_augmented_node] + + # Scatter add separately + graph_vec_atoms = scatter_add(atom_embeddings, atom_batch, dim=0) + graph_vec_augmented_nodes = scatter_add( + augmented_node_embeddings, augmented_node_batch, dim=0 + ) + + # Concatenate both + graph_vector = torch.cat( + [ + graph_vec_atoms, + graph_data.molecule_attr, + graph_vec_augmented_nodes, + ], + dim=1, + ) + + for lin in self.linear_layers: + a = self.gnn.activation(lin(graph_vector)) + return self.final_layer(a) + + class ResGatedGraphConvNetPretrain(GraphBaseNet): """For pretraining. BaseNet with an additional output layer for predicting atom properties""" diff --git a/chebai_graph/models/resgated.py b/chebai_graph/models/resgated.py new file mode 100644 index 0000000..521a853 --- /dev/null +++ b/chebai_graph/models/resgated.py @@ -0,0 +1,109 @@ +from typing import Any, Final + +from torch import Tensor +from torch.nn import ELU +from torch_geometric import nn as tgnn +from torch_geometric.data import Data as GraphData +from torch_geometric.nn.conv import MessagePassing +from torch_geometric.nn.models.basic_gnn import BasicGNN + +from .base import GraphModelBase, GraphNetWrapper + + +class ResGatedModel(BasicGNN): + """ + A residual gated GNN model based on PyG's BasicGNN using ResGatedGraphConv layers. + + Attributes: + supports_edge_weight (bool): Indicates edge weights are not supported. + supports_edge_attr (bool): Indicates edge attributes are supported. + supports_norm_batch (bool): Indicates if batch normalization is supported. + """ + + supports_edge_weight: Final[bool] = False + supports_edge_attr: Final[bool] = True + supports_norm_batch: Final[bool] + + def init_conv( + self, in_channels: int | tuple[int, int], out_channels: int, **kwargs: Any + ) -> MessagePassing: + """ + Initializes a ResGatedGraphConv layer. + + Args: + in_channels (int or Tuple[int, int]): Number of input channels. + out_channels (int): Number of output channels. + **kwargs: Additional keyword arguments for the convolution layer. + + Returns: + MessagePassing: A ResGatedGraphConv layer instance. + """ + return tgnn.ResGatedGraphConv( + in_channels, + out_channels, + **kwargs, + ) + + +class ResGatedGraphConvNetBase(GraphModelBase): + """ + Base model class for applying ResGatedGraphConv layers to graph-structured data. + + Args: + config (dict): Configuration dictionary containing model hyperparameters. + **kwargs: Additional keyword arguments for parent class. + """ + + def __init__(self, config: dict[str, Any], **kwargs: Any): + super().__init__(config=config, **kwargs) + self.activation = ELU() # Instantiate ELU once for reuse. + + self.resgated: BasicGNN = ResGatedModel( + in_channels=self.in_channels, + hidden_channels=self.hidden_channels, + out_channels=self.out_channels, + num_layers=self.num_layers, + edge_dim=self.edge_dim, + act=self.activation, + ) + + def forward(self, batch: dict[str, Any]) -> Tensor: + """ + Forward pass of the model. + + Args: + batch (dict): A batch containing graph input features under the key "features". + + Returns: + Tensor: The output node-level embeddings after the final activation. + """ + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData), "Expected GraphData instance" + + out = self.resgated( + x=graph_data.x.float(), + edge_index=graph_data.edge_index.long(), + edge_attr=graph_data.edge_attr, + ) + + return self.activation(out) + + +class ResGatedGraphPred(GraphNetWrapper): + """ + Wrapper for graph-level prediction using ResGatedGraphConvNetBase. + + This class instantiates the core GNN model using the provided config. + """ + + def _get_gnn(self, config: dict[str, Any]) -> ResGatedGraphConvNetBase: + """ + Returns the core ResGated GNN model. + + Args: + config (dict): Configuration dictionary for the GNN model. + + Returns: + ResGatedGraphConvNetBase: The core graph convolutional network. + """ + return ResGatedGraphConvNetBase(config=config) diff --git a/chebai_graph/preprocessing/__init__.py b/chebai_graph/preprocessing/__init__.py index 80488cc..e69de29 100644 --- a/chebai_graph/preprocessing/__init__.py +++ b/chebai_graph/preprocessing/__init__.py @@ -1,37 +0,0 @@ -from chebai_graph.preprocessing.properties import ( - AtomAromaticity, - AtomCharge, - AtomChirality, - AtomHybridization, - AtomNumHs, - AtomProperty, - AtomType, - BondAromaticity, - BondInRing, - BondProperty, - BondType, - MolecularProperty, - MoleculeNumRings, - MoleculeProperty, - NumAtomBonds, - RDKit2DNormalized, -) - -__all__ = [ - "AtomAromaticity", - "AtomCharge", - "AtomChirality", - "AtomHybridization", - "AtomNumHs", - "AtomProperty", - "AtomType", - "BondAromaticity", - "BondInRing", - "BondProperty", - "BondType", - "MolecularProperty", - "MoleculeNumRings", - "MoleculeProperty", - "NumAtomBonds", - "RDKit2DNormalized", -] diff --git a/chebai_graph/preprocessing/bin/AtomFunctionalGroup/indices_one_hot.txt b/chebai_graph/preprocessing/bin/AtomFunctionalGroup/indices_one_hot.txt new file mode 100644 index 0000000..adc6d60 --- /dev/null +++ b/chebai_graph/preprocessing/bin/AtomFunctionalGroup/indices_one_hot.txt @@ -0,0 +1,158 @@ +NO_FG +graph_fg +alkyl +4_ammonium_ion +fluoro +chloro +ether +RING_6 +nitrile +primary_amine +secondary_amine +tertiary_carbon +RING_5 +alkene +RING_7 +tertiary_amine +ketone +hydroxyl +sulfide +sulfonyl +amide +ester +carboxyl +quaternary_carbon +secondary_ketimine +alkyne +alkene_carbon +nitro +trifluoromethyl +RING_4 +secondary_aldimine +primary_ketimine +acetal +carbonate_ester +azo +carbamate +bromo +RING_3 +phosphono +RING_40 +imide +RING_19 +nitroso +RING_8 +sulfhydryl +aldehyde +ketoxime +borono +difluoromethyl +thiolester +aldoxime +carboxylate +thiocyanate +isothiocyanate +phosphate +phosphodiester +iodo +sulfonic_acid +carbodithio +sulfinyl +azide +thioketone +RING_23 +trimethylsilyl +RING_21 +sulfonate_ester +RING_16 +RING_15 +RING_9 +RING_18 +RING_20 +RING_32 +RING_14 +RING_22 +nitrate +primary_aldimine +disulfide +RING_10 +hemiacetal +dichloromethyl +trichloromethyl +haloformyl +hydroperoxy +hemiketal +isonitrile +RING_24 +RING_11 +RING_12 +nitrosooxy +RING_33 +phosphino +sulfino +RING_17 +cyanate +RING_13 +carbothioic_O-acid +RING_36 +RING_72 +hydrazone +RING_31 +RING_25 +RING_28 +RING_35 +RING_34 +RING_42 +RING_38 +RING_29 +RING_26 +RING_51 +RING_75 +RING_27 +RING_45 +RING_60 +RING_30 +RING_43 +RING_41 +RING_56 +RING_81 +RING_39 +RING_37 +RING_47 +RING_49 +RING_44 +phosphoryl +RING_54 +RING_53 +RING_63 +carboxylic_anhydride +carbothioic_S-acid +silyl_ether +borinate +carbodithioic_acid +peroxy +tribromomethyl +chlorobromomethyl +amidine +dibromomethyl +isocyanate +bromodichloromethyl +ketal +boronate +borino +chloroiodomethyl +bromoiodomethyl +thionoester +fluorochloromethyl +RING_68 +bromodifluoromethyl +RING_58 +difluorochloromethyl +thial +RING_52 +RING_48 +RING_62 +RING_164 +RING_71 +RING_46 +orthoester diff --git a/chebai_graph/preprocessing/bin/AtomNodeLevel/indices_one_hot.txt b/chebai_graph/preprocessing/bin/AtomNodeLevel/indices_one_hot.txt new file mode 100644 index 0000000..2d65776 --- /dev/null +++ b/chebai_graph/preprocessing/bin/AtomNodeLevel/indices_one_hot.txt @@ -0,0 +1,3 @@ +atom_node_lvl +fg_node_lvl +graph_node_level diff --git a/chebai_graph/preprocessing/bin/BondLevel/indices_one_hot.txt b/chebai_graph/preprocessing/bin/BondLevel/indices_one_hot.txt new file mode 100644 index 0000000..b389485 --- /dev/null +++ b/chebai_graph/preprocessing/bin/BondLevel/indices_one_hot.txt @@ -0,0 +1,4 @@ +atom_fg_lvl +to_graphNode_lvl +within_atoms_lvl +within_fg_lvl diff --git a/chebai_graph/preprocessing/collate.py b/chebai_graph/preprocessing/collate.py index 4be36cf..53bdde5 100644 --- a/chebai_graph/preprocessing/collate.py +++ b/chebai_graph/preprocessing/collate.py @@ -1,5 +1,3 @@ -from typing import Dict - import torch from chebai.preprocessing.collate import RaggedCollator from torch_geometric.data import Data as GeomData @@ -9,19 +7,27 @@ class GraphCollator(RaggedCollator): + """Collates a batch of molecular graph data with label handling and edge consistency.""" + def __call__(self, data): - loss_kwargs: Dict = dict() + loss_kwargs: dict = {} + # Unpack labels and optional identifiers y, idents = zip(*((d["labels"], d.get("ident")) for d in data)) + + # Replace labels with `y` inside graph features and collect them merged_data = [] for row in data: row["features"].y = row["labels"] merged_data.append(row["features"]) - # add empty edge_attr to avoid problems during collate (only relevant for molecules without edges) + + # Add empty edge_attr for graphs with no edges to prevent PyG errors for mdata in merged_data: - for i, store in enumerate(mdata.stores): + for store in mdata.stores: if "edge_attr" not in store: store["edge_attr"] = torch.tensor([]) + + # Ensure all attributes are float tensors to prevent torch.cat dtype issues for attr in merged_data[0].keys(): for data in merged_data: for store in data.stores: @@ -34,26 +40,35 @@ def __call__(self, data): else: store[attr] = torch.tensor(store[attr], dtype=torch.float32) + # Use PyG's batch collate for graph data x = graph_collate( GeomData, merged_data, follow_batch=["x", "edge_attr", "edge_index", "label"], ) + + # Handle various combinations of missing or available labels if any(x is not None for x in y): + # If any label is not None: (None, None, `1`, None) if any(x is None for x in y): + # If any label is None: (`None`, `None`, 1, `None`) non_null_labels = [i for i, r in enumerate(y) if r is not None] y = self.process_label_rows( tuple(ye for i, ye in enumerate(y) if i in non_null_labels) ) loss_kwargs["non_null_labels"] = non_null_labels else: + # If all labels are not None: (`0`, `2`, `1`, `3`) y = self.process_label_rows(y) else: + # If all labels are None: e.g., (None, None, None, None) y = None loss_kwargs["non_null_labels"] = [] + # Set node features (x) to long dtype (e.g., for categorical features) x[0].x = x[0].x.to(dtype=torch.int64) # x is a Tuple[BaseData, Mapping, Mapping] + return XYGraphData( x, y, diff --git a/chebai_graph/preprocessing/datasets/__init__.py b/chebai_graph/preprocessing/datasets/__init__.py new file mode 100644 index 0000000..d13b1c3 --- /dev/null +++ b/chebai_graph/preprocessing/datasets/__init__.py @@ -0,0 +1,34 @@ +from .chebi import ( + ChEBI50_Atom_WGNOnly_GraphProp, + ChEBI50_GN_WithAllNodes_FG_WithAtoms_FGE, + ChEBI50_GN_WithAllNodes_FG_WithAtoms_NoFGE, + ChEBI50_GN_WithAtoms_FG_WithAtoms_FGE, + ChEBI50_GN_WithAtoms_FG_WithAtoms_NoFGE, + ChEBI50_NFGE_NGN_GraphProp, + ChEBI50_NFGE_WGN_GraphProp, + ChEBI50_StaticGNI, + ChEBI50_WFGE_NGN_GraphProp, + ChEBI50_WFGE_WGN_AsPerNodeType, + ChEBI50_WFGE_WGN_GraphProp, + ChEBI50GraphData, + ChEBI50GraphProperties, +) +from .pubchem import PubChemGraphProperties + +__all__ = [ + "ChEBI50GraphFGAugmentorReader", + "ChEBI50GraphProperties", + "ChEBI50GraphData", + "PubChemGraphProperties", + "ChEBI50_Atom_WGNOnly_GraphProp", + "ChEBI50_NFGE_NGN_GraphProp", + "ChEBI50_NFGE_WGN_GraphProp", + "ChEBI50_WFGE_NGN_GraphProp", + "ChEBI50_WFGE_WGN_GraphProp", + "ChEBI50_StaticGNI", + "ChEBI50_WFGE_WGN_AsPerNodeType", + "ChEBI50_GN_WithAllNodes_FG_WithAtoms_FGE", + "ChEBI50_GN_WithAllNodes_FG_WithAtoms_NoFGE", + "ChEBI50_GN_WithAtoms_FG_WithAtoms_FGE", + "ChEBI50_GN_WithAtoms_FG_WithAtoms_NoFGE", +] diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 8532bf0..c94b772 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -1,78 +1,99 @@ -import importlib import os -from typing import Callable, List, Optional +from abc import ABC +from collections.abc import Callable +from pprint import pformat import pandas as pd import torch import tqdm -from chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.preprocessing.datasets.chebi import ( ChEBIOver50, ChEBIOver100, + ChEBIOverX, ChEBIOverXPartial, ) from lightning_utilities.core.rank_zero import rank_zero_info from torch_geometric.data.data import Data as GeomData -import chebai_graph.preprocessing.properties as graph_properties from chebai_graph.preprocessing.properties import ( + AllNodeTypeProperty, + AtomNodeTypeProperty, AtomProperty, BondProperty, + FGNodeTypeProperty, MolecularProperty, + MoleculeProperty, ) -from chebai_graph.preprocessing.reader import GraphPropertyReader, GraphReader +from chebai_graph.preprocessing.reader import ( + AtomFGReader_NoFGEdges_WithGraphNode, + AtomFGReader_WithFGEdges_NoGraphNode, + AtomFGReader_WithFGEdges_WithGraphNode, + AtomReader_WithGraphNodeOnly, + AtomsFGReader_NoFGEdges_NoGraphNode, + GN_WithAllNodes_FG_WithAtoms_FGE, + GN_WithAllNodes_FG_WithAtoms_NoFGE, + GN_WithAtoms_FG_WithAtoms_FGE, + GN_WithAtoms_FG_WithAtoms_NoFGE, + GraphPropertyReader, + GraphReader, + RandomFeatureInitializationReader, +) + +from .utils import resolve_property class ChEBI50GraphData(ChEBIOver50): + """ChEBI dataset with at least 50 samples per class, using GraphReader.""" + READER = GraphReader def __init__(self, **kwargs): super().__init__(**kwargs) -def _resolve_property( - property, #: str | properties.MolecularProperty -) -> MolecularProperty: - # if property is given as a string, try to resolve as a class path - if isinstance(property, MolecularProperty): - return property - try: - # split class_path into module-part and class name - last_dot = property.rindex(".") - module_name = property[:last_dot] - class_name = property[last_dot + 1 :] - module = importlib.import_module(module_name) - return getattr(module, class_name)() - except ValueError: - # if only a class name is given, assume the module is chebai_graph.processing.properties - return getattr(graph_properties, property)() - - -class GraphPropertiesMixIn(XYBaseDataModule): +class DataPropertiesSetter(ChEBIOverX, ABC): + """Mixin for adding molecular property encodings to graph-based ChEBI datasets.""" + READER = GraphPropertyReader def __init__( - self, properties: Optional[List], transform: Optional[Callable] = None, **kwargs + self, + properties: list | None = None, + transform: Callable | None = None, + **kwargs, ): + """ + Initialize GraphPropertiesMixIn. + + Args: + properties: Optional list of MolecularProperty class paths or instances. + transform: Optional transformation applied to each data sample. + """ super().__init__(**kwargs) # atom_properties and bond_properties are given as lists containing class_paths if properties is not None: - properties = [_resolve_property(prop) for prop in properties] - properties = sorted( - properties, key=lambda prop: self.get_property_path(prop) - ) + properties = [resolve_property(prop) for prop in properties] + properties = self._sort_properties(properties) else: properties = [] self.properties = properties assert isinstance(self.properties, list) and all( isinstance(p, MolecularProperty) for p in self.properties ) - rank_zero_info( - f"Data module uses these properties (ordered): {', '.join([str(p) for p in properties])}" - ) self.transform = transform - def _setup_properties(self): + def _sort_properties( + self, properties: list[MolecularProperty] + ) -> list[MolecularProperty]: + return sorted(properties, key=lambda prop: self.get_property_path(prop)) + + def _setup_properties(self) -> None: + """ + Process and cache molecular properties to disk. + + Returns: + None + """ raw_data = [] os.makedirs(self.processed_properties_dir, exist_ok=True) @@ -92,6 +113,7 @@ def _setup_properties(self): file, ) raw_data += list(self._load_dict(path)) + idents = [row["ident"] for row in raw_data] features = [row["features"] for row in raw_data] @@ -107,10 +129,12 @@ def enc_if_not_none(encode, value): if not os.path.isfile(self.get_property_path(property)): rank_zero_info(f"Processing property {property.name}") # read all property values first, then encode + rank_zero_info(f"\tReading property values of {property.name}...") property_values = [ self.reader.read_property(feat, property) for feat in tqdm.tqdm(features) ] + rank_zero_info(f"\tEncoding property values of {property.name}...") property.encoder.on_start(property_values=property_values) encoded_values = [ enc_if_not_none(property.encoder.encode, value) @@ -128,30 +152,89 @@ def enc_if_not_none(encode, value): property.on_finish() @property - def processed_properties_dir(self): + def processed_properties_dir(self) -> str: return os.path.join(self.processed_dir, "properties") - def get_property_path(self, property: MolecularProperty): + def get_property_path(self, property: MolecularProperty) -> str: + """ + Construct the cache path for a given molecular property. + + Args: + property: Instance of a MolecularProperty. + + Returns: + Path to the cached property file. + """ return os.path.join( self.processed_properties_dir, f"{property.name}_{property.encoder.name}.pt", ) - def _after_setup(self, **kwargs): + def _after_setup(self, **kwargs) -> None: """ - Finalize the setup process after ensuring the processed data is available. + Finalize setup after ensuring properties are processed. - This method performs post-setup tasks like finalizing the reader and setting internal properties. + Args: + **kwargs: Additional keyword arguments passed to superclass. + + Returns: + None """ self._setup_properties() super()._after_setup(**kwargs) - def _merge_props_into_base(self, row): + +class GraphPropertiesMixIn(DataPropertiesSetter, ABC): + def __init__( + self, + properties=None, + transform=None, + pad_node_features: int = None, + pad_edge_features: int = None, + distribution: str = "normal", + **kwargs, + ): + super().__init__(properties, transform, **kwargs) + self.pad_edge_features = int(pad_edge_features) if pad_edge_features else None + self.pad_node_features = int(pad_node_features) if pad_node_features else None + if self.pad_node_features or self.pad_edge_features: + assert ( + distribution is not None + and distribution in RandomFeatureInitializationReader.DISTRIBUTIONS + ), "When using padding for features, a valid distribution must be specified." + self.distribution = distribution + if self.pad_node_features: + print( + f"[Info] Node-level features will be padded with random" + f"{self.pad_node_features} values from {self.distribution} distribution." + ) + if self.pad_edge_features: + print( + f"[Info] Edge-level features will be padded with random" + f"{self.pad_edge_features} values from {self.distribution} distribution." + ) + + if self.properties: + print( + f"Data module uses these properties (ordered): {', '.join([str(p) for p in self.properties])}" + ) + + def _merge_props_into_base(self, row: pd.Series) -> GeomData: + """ + Merge encoded molecular properties into the GeomData object. + + Args: + row: A dictionary containing 'features' and encoded properties. + + Returns: + A GeomData object with merged features. + """ geom_data = row["features"] + assert isinstance(geom_data, GeomData) edge_attr = geom_data.edge_attr x = geom_data.x molecule_attr = torch.empty((1, 0)) - assert isinstance(geom_data, GeomData) + for property in self.properties: property_values = row[f"{property.name}"] if isinstance(property_values, torch.Tensor): @@ -163,12 +246,34 @@ def _merge_props_into_base(self, row): property_values = torch.zeros( (0, property.encoder.get_encoding_length()) ) + if isinstance(property, AtomProperty): x = torch.cat([x, property_values], dim=1) elif isinstance(property, BondProperty): - edge_attr = torch.cat([edge_attr, property_values], dim=1) - else: + # Concat/Duplicate properties values for undirected graph as `edge_index` has first src to tgt edges, then tgt to src edges + edge_attr = torch.cat( + [edge_attr, torch.cat([property_values, property_values], dim=0)], + dim=1, + ) + elif isinstance(property, MoleculeProperty): molecule_attr = torch.cat([molecule_attr, property_values], dim=1) + else: + raise TypeError(f"Unsupported property type: {type(property).__name__}") + + if self.pad_node_features: + padding_values = torch.empty((x.shape[0], self.pad_node_features)) + RandomFeatureInitializationReader.random_gni( + padding_values, self.distribution + ) + x = torch.cat([x, padding_values], dim=1) + + if self.pad_edge_features: + padding_values = torch.empty((edge_attr.shape[0], self.pad_edge_features)) + RandomFeatureInitializationReader.random_gni( + padding_values, self.distribution + ) + edge_attr = torch.cat([edge_attr, padding_values], dim=1) + return GeomData( x=x, edge_index=geom_data.edge_index, @@ -176,9 +281,19 @@ def _merge_props_into_base(self, row): molecule_attr=molecule_attr, ) - def load_processed_data_from_file(self, filename): + def load_processed_data_from_file(self, filename: str) -> list[dict]: + """ + Load dataset and merge cached properties into base features. + + Args: + filename: The path to the file to load. + + Returns: + List of data entries, each a dictionary. + """ base_data = super().load_processed_data_from_file(filename) base_df = pd.DataFrame(base_data) + for property in self.properties: property_data = torch.load( self.get_property_path(property), weights_only=False @@ -206,26 +321,409 @@ def load_processed_data_from_file(self, filename): (prop.name, prop.encoder.get_encoding_length()) for prop in self.properties ] + # -------------------------- Count total node properties + n_node_properties = sum( + p.encoder.get_encoding_length() + for p in self.properties + if isinstance(p, AtomProperty) + ) + + in_channels_str = "" + if self.pad_node_features: + n_node_properties += self.pad_node_features + in_channels_str += f" (with {self.pad_node_features} padded random values from {self.distribution} distribution)" + + in_channels_str = f"in_channels: {n_node_properties}" + in_channels_str + + # -------------------------- Count total edge properties + n_edge_properties = sum( + p.encoder.get_encoding_length() + for p in self.properties + if isinstance(p, BondProperty) + ) + edge_dim_str = "" + if self.pad_edge_features: + n_edge_properties += self.pad_edge_features + edge_dim_str += f" (with {self.pad_edge_features} padded random values from {self.distribution} distribution)" + + edge_dim_str = f"edge_dim: {n_edge_properties}" + edge_dim_str + + rank_zero_info( + f"Finished loading dataset from properties.\nEncoding lengths: {prop_lengths}\n" + f"Use following values for given parameters for model configuration: \n\t" + f"{in_channels_str} \n\t" + f"{edge_dim_str} \n\t" + f"n_molecule_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, MoleculeProperty))}" + ) + + return base_df[base_data[0].keys()].to_dict("records") + + +class GraphPropAsPerNodeType(DataPropertiesSetter, ABC): + def __init__(self, properties=None, transform=None, **kwargs): + super().__init__(properties, transform, **kwargs) + # Sort properties so that AllNodeTypeProperty instances come first, rest of the properties order remain same + first = self._sort_properties( + [prop for prop in self.properties if isinstance(prop, AllNodeTypeProperty)] + ) + rest = self._sort_properties( + [ + prop + for prop in self.properties + if not isinstance(prop, AllNodeTypeProperty) + ] + ) + self.properties = first + rest + print( + "Properties are sorted so that `AllNodeTypeProperty` properties are first in sequence and rest of the order remains same\n", + f"Data module uses these properties (ordered): {', '.join([str(p) for p in self.properties])}", + ) + + def load_processed_data_from_file(self, filename: str) -> list[dict]: + """ + Load dataset and merge cached properties into base features. + + Args: + filename: The path to the file to load. + + Returns: + List of data entries, each a dictionary. + """ + base_data = super().load_processed_data_from_file(filename) + base_df = pd.DataFrame(base_data) + + props_categories = { + "AllNodeTypeProperties": [], + "FGNodeTypeProperties": [], + "AtomNodeTypeProperties": [], + "GraphNodeTypeProperties": [], + "BondProperties": [], + } + n_atom_node_properties, n_fg_node_properties = 0, 0 + n_bond_properties, n_graph_node_properties = 0, 0 + prop_lengths = [] + for prop in self.properties: + prop_length = prop.encoder.get_encoding_length() + prop_name = prop.name + prop_lengths.append((prop_name, prop_length)) + if isinstance(prop, AllNodeTypeProperty): + n_atom_node_properties += prop_length + n_fg_node_properties += prop_length + n_graph_node_properties += prop_length + props_categories["AllNodeTypeProperties"].append(prop_name) + elif isinstance(prop, FGNodeTypeProperty): + n_fg_node_properties += prop_length + props_categories["FGNodeTypeProperties"].append(prop_name) + elif isinstance(prop, AtomNodeTypeProperty): + n_atom_node_properties += prop_length + props_categories["AtomNodeTypeProperties"].append(prop_name) + elif isinstance(prop, BondProperty): + n_bond_properties += prop_length + props_categories["BondProperties"].append(prop_name) + elif isinstance(prop, MoleculeProperty): + # molecule props will be used as graph node props + n_graph_node_properties += prop_length + props_categories["GraphNodeTypeProperties"].append(prop_name) + else: + raise TypeError(f"Unsupported property type: {type(prop).__name__}") + + n_node_properties = max( + n_atom_node_properties, n_fg_node_properties, n_graph_node_properties + ) rank_zero_info( - f"Finished loading dataset from properties." - f"\nEncoding lengths are: " - f"{prop_lengths}" - f"\nIf you train a model with these properties and encodings, " - f"use n_atom_properties: {sum([prop.encoder.get_encoding_length() for prop in self.properties if isinstance(prop, AtomProperty)])}, " - f"n_bond_properties: {sum([prop.encoder.get_encoding_length() for prop in self.properties if isinstance(prop, BondProperty)])} " - f"and n_molecule_properties: {sum([prop.encoder.get_encoding_length() for prop in self.properties if not (isinstance(prop, AtomProperty) or isinstance(prop, BondProperty))])}" + f"\nFinished loading dataset from properties.\nEncoding lengths: {prop_lengths}\n\n" + f"Properties Categories:\n{pformat(props_categories)}\n\n" + f"n_atom_node_properties: {n_atom_node_properties}, " + f"n_fg_node_properties: {n_fg_node_properties}, " + f"n_bond_properties: {n_bond_properties}, " + f"n_graph_node_properties: {n_graph_node_properties}\n\n" + f"Use following values for given parameters for model configuration: \n\t" + f"in_channels: {n_node_properties}, edge_dim: {n_bond_properties}, n_molecule_properties: 0\n" + ) + + for property in self.properties: + property_data = torch.load( + self.get_property_path(property), weights_only=False + ) + if len(property_data[0][property.name].shape) > 1: + property.encoder.set_encoding_length( + property_data[0][property.name].shape[1] + ) + + property_df = pd.DataFrame(property_data) + property_df.rename( + columns={property.name: f"{property.name}"}, inplace=True + ) + base_df = base_df.merge(property_df, on="ident", how="left") + + base_df["features"] = base_df.apply( + lambda row: self._merge_props_into_base( + row, + max_len_node_properties=n_node_properties, + ), + axis=1, + ) + + # apply transformation, e.g. masking for pretraining task + if self.transform is not None: + base_df["features"] = base_df["features"].apply(self.transform) + + return base_df[base_data[0].keys()].to_dict("records") + + def _merge_props_into_base( + self, row: pd.Series, max_len_node_properties: int + ) -> GeomData: + """ + Merge encoded molecular properties into the GeomData object. + + Args: + row: A dictionary containing 'features' and encoded properties. + + Returns: + A GeomData object with merged features. + """ + geom_data = row["features"] + assert isinstance(geom_data, GeomData) + + is_atom_node = geom_data.is_atom_node + assert is_atom_node is not None, "`is_atom_node` must be set in the geom_data" + is_graph_node = geom_data.is_graph_node + assert is_graph_node is not None, "`is_graph_node` must be set in the geom_data" + + is_fg_node = ~is_atom_node & ~is_graph_node + num_nodes = geom_data.x.size(0) + edge_attr = geom_data.edge_attr + + # Initialize node feature matrix + assert ( + max_len_node_properties is not None + ), "Maximum len of node properties should not be None" + x = torch.zeros((num_nodes, max_len_node_properties)) + + # Track column offsets for each node type + atom_offset, fg_offset, graph_offset = 0, 0, 0 + + for property in self.properties: + property_values = row[f"{property.name}"].to(dtype=torch.float32) + if isinstance(property_values, torch.Tensor): + if len(property_values.size()) == 0: + property_values = property_values.unsqueeze(0) + if len(property_values.size()) == 1: + property_values = property_values.unsqueeze(1) + else: + property_values = torch.zeros( + (0, property.encoder.get_encoding_length()) + ) + + enc_len = property_values.shape[1] + # -------------- Node properties --------------- + if isinstance(property, AllNodeTypeProperty): + x[:, atom_offset : atom_offset + enc_len] = property_values + atom_offset += enc_len + fg_offset += enc_len + graph_offset += enc_len + + elif isinstance(property, AtomNodeTypeProperty): + x[is_atom_node, atom_offset : atom_offset + enc_len] = property_values[ + is_atom_node + ] + atom_offset += enc_len + + elif isinstance(property, FGNodeTypeProperty): + x[is_fg_node, fg_offset : fg_offset + enc_len] = property_values[ + is_fg_node + ] + fg_offset += enc_len + + elif isinstance(property, MoleculeProperty): + x[is_graph_node, graph_offset : graph_offset + enc_len] = ( + property_values + ) + graph_offset += enc_len + + # ------------- Bond Properties -------------- + elif isinstance(property, BondProperty): + # Concat/Duplicate properties values for undirected graph as `edge_index` has first src to tgt edges, then tgt to src edges + edge_attr = torch.cat( + [edge_attr, torch.cat([property_values, property_values], dim=0)], + dim=1, + ) + else: + raise TypeError(f"Unsupported property type: {type(property).__name__}") + + total_used_columns = max(atom_offset, fg_offset, graph_offset) + assert ( + total_used_columns <= max_len_node_properties + ), f"Used {total_used_columns} columns, but max allowed is {max_len_node_properties}" + + return GeomData( + x=x, + edge_index=geom_data.edge_index, + edge_attr=edge_attr, + molecule_attr=torch.empty((1, 0)), # empty as not used for this class + is_atom_node=is_atom_node, + is_fg_node=is_fg_node, + is_graph_node=is_graph_node, ) + +class ChEBI50_StaticGNI(DataPropertiesSetter, ChEBIOver50): + READER = RandomFeatureInitializationReader + + def _setup_properties(self): ... + + def load_processed_data_from_file(self, filename): + base_data = super().load_processed_data_from_file(filename) + base_df = pd.DataFrame(base_data) + + rank_zero_info( + f"Use following values for given parameters for model configuration: \n\t" + f"in_channels: {self.reader.num_node_properties} , " + f"edge_dim: {self.reader.num_bond_properties}, " + f"n_molecule_properties: {self.reader.num_molecule_properties}" + ) return base_df[base_data[0].keys()].to_dict("records") class ChEBI50GraphProperties(GraphPropertiesMixIn, ChEBIOver50): + """ChEBIOver50 dataset with molecular property encodings.""" + pass class ChEBI100GraphProperties(GraphPropertiesMixIn, ChEBIOver100): + """ChEBIOver100 dataset with molecular property encodings.""" + pass class ChEBI50GraphPropertiesPartial(ChEBI50GraphProperties, ChEBIOverXPartial): + """Partial version of ChEBIOver50 with molecular properties.""" + pass + + +class AugGraphPropMixIn_NoGraphNode(GraphPropertiesMixIn, ABC): + """Mixin for augmented graph data without additional graph nodes.""" + + READER = None + + def _merge_props_into_base(self, row: pd.Series) -> GeomData: + data = super()._merge_props_into_base(row) + geom_data = row["features"] + assert isinstance(geom_data, GeomData) and isinstance(data, GeomData) + + is_atom_node = geom_data.is_atom_node + assert is_atom_node is not None, "is_atom_node must be set in the geom_data" + data.is_atom_node = is_atom_node + return data + + +class AugGraphPropMixIn_WithGraphNode(AugGraphPropMixIn_NoGraphNode, ABC): + """Mixin for augmented graph data with graph-level nodes.""" + + READER = None + + def _merge_props_into_base(self, row: pd.Series) -> GeomData: + data = super()._merge_props_into_base(row) + return self._add_graph_node_mask(data, row) + + def _add_graph_node_mask(self, data: GeomData, row: pd.Series) -> GeomData: + """ + Add a graph node mask to the GeomData object. + + Args: + data: A GeomData object with features. + row: A dictionary containing 'features' and other metadata. + + Returns: + Modified GeomData with graph node mask added. + """ + geom_data = row["features"] + assert isinstance(geom_data, GeomData) and isinstance(data, GeomData) + is_graph_node = geom_data.is_graph_node + assert is_graph_node is not None, "is_graph_node must be set in the geom_data" + data.is_graph_node = is_graph_node + return data + + +class ChEBI50_WFGE_WGN_GraphProp(AugGraphPropMixIn_WithGraphNode, ChEBIOver50): + """ChEBIOver50 with with FG nodes and FG edges and graph node.""" + + READER = AtomFGReader_WithFGEdges_WithGraphNode + + +class ChEBI50_GN_WithAllNodes_FG_WithAtoms_FGE( + AugGraphPropMixIn_WithGraphNode, ChEBIOver50 +): + """ + ChEBIOver50 with FG nodes (connected to their respective atom nodes) with functional group + edges, and adds a graph-level node connected to all nodes (fg + atoms). + """ + + READER = GN_WithAllNodes_FG_WithAtoms_FGE + + +class ChEBI50_GN_WithAllNodes_FG_WithAtoms_NoFGE( + AugGraphPropMixIn_WithGraphNode, ChEBIOver50 +): + """ + ChEBIOver50 with FG nodes (connected to their respective atom nodes) without functional group + edges, and adds a graph-level node connected to all nodes (fg + atoms). + """ + + READER = GN_WithAllNodes_FG_WithAtoms_NoFGE + + +class ChEBI50_GN_WithAtoms_FG_WithAtoms_FGE( + AugGraphPropMixIn_WithGraphNode, ChEBIOver50 +): + """ + ChEBIOver50 with FG nodes (connected to their respective atom nodes) with functional group + edges, and adds a graph-level node connected to all atom nodes. + """ + + READER = GN_WithAtoms_FG_WithAtoms_FGE + + +class ChEBI50_GN_WithAtoms_FG_WithAtoms_NoFGE( + AugGraphPropMixIn_WithGraphNode, ChEBIOver50 +): + """ + ChEBIOver50 with FG nodes (connected to their respective atom nodes) without functional group + edges, and adds a graph-level node connected to all atom nodes. + """ + + READER = GN_WithAtoms_FG_WithAtoms_NoFGE + + +class ChEBI50_NFGE_WGN_GraphProp(AugGraphPropMixIn_WithGraphNode, ChEBIOver50): + """ChEBIOver50 with FG nodes but without FG edges, with graph node.""" + + READER = AtomFGReader_NoFGEdges_WithGraphNode + + +class ChEBI50_WFGE_NGN_GraphProp(AugGraphPropMixIn_NoGraphNode, ChEBIOver50): + """ChEBIOver50 with FG nodes and FG edges, no graph node.""" + + READER = AtomFGReader_WithFGEdges_NoGraphNode + + +class ChEBI50_NFGE_NGN_GraphProp(AugGraphPropMixIn_NoGraphNode, ChEBIOver50): + """ChEBIOver50 with FG nodes but without FG edges or graph node.""" + + READER = AtomsFGReader_NoFGEdges_NoGraphNode + + +class ChEBI50_Atom_WGNOnly_GraphProp(AugGraphPropMixIn_WithGraphNode, ChEBIOver50): + """ChEBIOver50 with atom-level nodes and graph node only.""" + + READER = AtomReader_WithGraphNodeOnly + + +class ChEBI50_WFGE_WGN_AsPerNodeType(GraphPropAsPerNodeType, ChEBIOver50): + READER = AtomFGReader_WithFGEdges_WithGraphNode + + +class ChEBI100_WFGE_WGN_AsPerNodeType(GraphPropAsPerNodeType, ChEBIOver100): + READER = AtomFGReader_WithFGEdges_WithGraphNode diff --git a/chebai_graph/preprocessing/datasets/utils.py b/chebai_graph/preprocessing/datasets/utils.py new file mode 100644 index 0000000..3ec6515 --- /dev/null +++ b/chebai_graph/preprocessing/datasets/utils.py @@ -0,0 +1,43 @@ +import importlib + +import chebai_graph.preprocessing.properties as graph_properties +from chebai_graph.preprocessing.properties import MolecularProperty + + +def resolve_property(property: str | MolecularProperty) -> MolecularProperty: + """ + Resolves a molecular property specification (either as a class instance or class path string) + into a MolecularProperty instance. + + This utility is designed to support flexible specification of molecular properties + in dataset configurations. It handles: + - Direct instances of MolecularProperty + - Full class paths as strings (e.g., "my_module.MyProperty") + - Shorthand class names assumed to exist in chebai_graph.preprocessing.properties + + Args: + property (str | MolecularProperty): The property to resolve. Can be a class instance, + a fully qualified class name (e.g. "module.ClassName"), or a class name assumed + to be in `chebai_graph.preprocessing.properties`. + + Returns: + MolecularProperty: An instance of the resolved MolecularProperty. + + Raises: + AttributeError: If the class name cannot be found in the target module. + ModuleNotFoundError: If the module cannot be imported. + TypeError: If the resolved object is not a MolecularProperty instance. + """ + # if property is given as a string, try to resolve as a class path + if isinstance(property, MolecularProperty): + return property + try: + # split class_path into module-part and class name + last_dot = property.rindex(".") + module_name = property[:last_dot] + class_name = property[last_dot + 1 :] + module = importlib.import_module(module_name) + return getattr(module, class_name)() + except ValueError: + # if only a class name is given, assume the module is chebai_graph.processing.properties + return getattr(graph_properties, property)() diff --git a/chebai_graph/preprocessing/fg_detection/__init__.py b/chebai_graph/preprocessing/fg_detection/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py new file mode 100644 index 0000000..f60f580 --- /dev/null +++ b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py @@ -0,0 +1,1936 @@ +# This code file is taken from https://github.com/thaonguyen217/farm_molecular_representation +# Reference Paper: Nguyen, Thao, et al. "FARM: Functional Group-Aware Representations for Small Molecules." +# arXiv preprint arXiv:2410.02082 (2024). + +import re + +from rdkit import Chem +from rdkit.Chem import AllChem +from rdkit.Chem import MolToSmiles as m2s + +from .fg_constants import ELEMENTS, FLAG_NO_FG + + +def ring_size_processing(ring_size): + if ring_size[0] > ring_size[-1]: + return list(reversed(ring_size)) + else: + return ring_size + + +# Function to find all rings connected to a given ring +def find_connected_rings(ring, remaining_rings) -> list[set[int]]: + connected_rings = [ring] + merged = True + while merged: + merged = False + for other_ring in remaining_rings: + if ring & other_ring: # If there is a shared atom, they are connected + connected_rings.append(other_ring) + remaining_rings.remove(other_ring) + ring = ring.union(other_ring) + merged = True + return connected_rings + + +def set_ring_properties(mol: Chem.Mol) -> list[list[set[int]]] | None: + if mol is None: + return + + for atom in mol.GetAtoms(): + atom.SetProp("RING", "") + + AllChem.GetSymmSSSR(mol) + + ######## SET RING PROP ######## + # Get ring information + ring_info = mol.GetRingInfo() + fused_rings_groups: list[list[set[int]]] = [] + + if ring_info.NumRings() > 0: + # Get list of atom rings + atom_rings = ring_info.AtomRings() + + # Initialize a list to hold fused ring blocks and their sizes + fused_ring_blocks = [] + ring_sizes = [] + + # Set of rings to process + remaining_rings = [set(ring) for ring in atom_rings] + + # Process each ring block + while remaining_rings: + ring = remaining_rings.pop(0) + connected_rings: list[set[int]] = find_connected_rings( + ring, remaining_rings + ) + + if len(connected_rings) > 1: + fused_rings_groups.append(connected_rings) + + for ring in connected_rings: + ring_size = len(ring) + for atom_idx in ring: + atom = mol.GetAtomWithIdx(atom_idx) + if not atom.HasProp("RING"): + atom.SetProp("RING", f"{ring_size}") + else: + # An atom shared across multiple rings in fused-ring will have Ring size like "5-6-6" + atom.SetProp("RING", f"{atom.GetProp('RING')}-{ring_size}") + + # Merge all connected rings into one fused block + fused_block = set().union(*connected_rings) + fused_ring_blocks.append(sorted(fused_block)) + ring_sizes.append([len(r) for r in connected_rings]) + + # Display the fused ring blocks and their ring sizes + for i, block in enumerate(fused_ring_blocks): + rs = "-".join(str(size) for size in ring_size_processing(ring_sizes[i])) + for idx in block: + atom = mol.GetAtomWithIdx(idx) + atom.SetProp("RING", rs) + + return fused_rings_groups + + +def detect_functional_group(mol: Chem.Mol): + if mol is None: + return + + for atom in mol.GetAtoms(): + atom.SetProp("FG", "") + + ######## SET FUNCTIONAL GROUP PROP ######## + for atom in mol.GetAtoms(): + atom_symbol = atom.GetSymbol() + atom_neighbors = atom.GetNeighbors() + atom_num_neighbors = len(atom_neighbors) + num_H = atom.GetTotalNumHs() + in_ring = atom.IsInRing() + atom_idx = atom.GetIdx() + charge = atom.GetFormalCharge() + + ########################### Groups containing oxygen ########################### + if atom_symbol in ["C", "*"] and charge == 0: # and atom.GetProp('FG') == '': + num_O, num_X, num_C, num_N, num_S = 0, 0, 0, 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["F", "Cl", "Br", "I"]: + num_X += 1 + if neighbor.GetSymbol() == "O": + num_O += 1 + if neighbor.GetSymbol() in ["C", "*"]: + num_C += 1 + if neighbor.GetSymbol() == "N": + num_N += 1 + if neighbor.GetSymbol() == "S": + num_S += 1 + + if ( + num_H == 1 + and atom_num_neighbors == 3 + and charge == 0 + and atom.GetProp("FG") == "" + ): + atom.SetProp("FG", "tertiary_carbon") + if atom_num_neighbors == 4 and charge == 0 and atom.GetProp("FG") == "": + atom.SetProp("FG", "quaternary_carbon") + if ( + num_H == 0 + and atom_num_neighbors == 3 + and charge == 0 + and atom.GetProp("FG") == "" + and not in_ring + ): + atom.SetProp("FG", "alkene_carbon") + + if ( + num_O == 1 + and atom_symbol == "C" + and atom.GetProp("FG") + not in [ + "hemiacetal", + "hemiketal", + "acetal", + "ketal", + "orthoester", + "orthocarbonate_ester", + "carbonate_ester", + ] + ): + if num_N == 1: # Cyanate and Isocyanate + condition1, condition2 = False, False + condition3, condition4 = False, False + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.TRIPLE + and neighbor.GetFormalCharge() == 0 + ): + condition1 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): + condition2 = True + + if ( + neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + condition3 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): + condition4 = True + + if condition1 and condition2 and not in_ring: # Cyanate + atom.SetProp("FG", "cyanate") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "cyanate") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + for C_neighbor in neighbor.GetNeighbors(): + if ( + C_neighbor.GetSymbol() in ["C", "*"] + and C_neighbor.GetIdx() != atom_idx + ): + C_neighbor.SetProp("FG", "") + + if condition3 and condition4 and not in_ring: # Isocyanate + atom.SetProp("FG", "isocyanate") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "isocyanate") + + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + bond = mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()) + bondtype = bond.GetBondType() + if ( + bondtype == Chem.BondType.SINGLE + ): # and not neighbor.IsInRing(): # [C-O]: Alcohol (COH) or Ether [COC] or Hydroperoxy [C-O-O-H] or Peroxide [C-O-O-C] + if neighbor.GetTotalNumHs() == 1: # Alcohol [COH] + neighbor.SetProp("FG", "hydroxyl") + else: + for O_neighbor in neighbor.GetNeighbors(): + # if not O_neighbor.IsInRing(): + if ( + O_neighbor.GetIdx() != atom_idx + and O_neighbor.GetSymbol() in ["C", "*"] + and neighbor.GetProp("FG") == "" + ): # Ether [COC] + neighbor.SetProp("FG", "ether") + if O_neighbor.GetSymbol() == "O": + if ( + O_neighbor.GetTotalNumHs() == 1 + ): # Hydroperoxy [C-O-O-H] + neighbor.SetProp("FG", "hydroperoxy") + O_neighbor.SetProp("FG", "hydroperoxy") + else: + neighbor.SetProp("FG", "peroxy") + O_neighbor.SetProp("FG", "peroxy") + + if ( + bondtype == Chem.BondType.DOUBLE + ): # [C=O]: Ketone [CC(=0)C] or Aldehyde [CC(=O)H] or Acyl halide [C(=O)X] + if ( + num_X == 1 and not neighbor.IsInRing() + ): # Acyl halide [C(=O)X] + atom.SetProp("FG", "haloformyl") + for neighbor_ in atom_neighbors: + if neighbor_.GetSymbol() in [ + "O", + "F", + "Cl", + "Br", + "I", + ]: + neighbor_.SetProp("FG", "haloformyl") + + if ( + (num_C == 1 and num_H == 1) + or num_H == 2 + and not in_ring + ): # Aldehyde [C(=O)H] + atom.SetProp("FG", "aldehyde") + neighbor.SetProp("FG", "aldehyde") + + if atom_num_neighbors == 3 and atom.GetProp("FG") not in [ + "haloformyl", + "amide", + ]: # Ketone [C(=0)C] + atom.SetProp("FG", "ketone") + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and not neighbor.IsInRing() + ): + neighbor.SetProp("FG", "ketone") + + if num_O == 2: # and atom.GetProp('FG') == '': + if atom_num_neighbors == 3: + if num_H == 0: + condition1, condition2, condition3, condition4 = ( + False, + False, + False, + False, + ) + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + and not neighbor.IsInRing() + ): + condition1 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == -1 + and not neighbor.IsInRing() + ): + condition2 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 1 + and not neighbor.IsInRing() + ): + condition3 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 0 + and atom.GetProp("FG") != "carbamate" + ): + condition4 = True + + if condition1 and condition2: + atom.SetProp("FG", "carboxylate") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "carboxylate") + if condition1 and condition3: + atom.SetProp("FG", "carboxyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "carboxyl") + if ( + condition1 + and condition4 + and atom.GetProp("FG") + not in ["carbamate", "carbonate_ester"] + ): + atom.SetProp("FG", "ester") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "ester") + for O_neighbor in neighbor.GetNeighbors(): + O_neighbor.SetProp("FG", "ester") + + if num_H == 1 and not in_ring: + condition1, condition2 = False, False + cnt = 0 + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 1 + ): + condition1 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 0 + ): + condition2 = True + cnt += 1 + + if condition1 and condition2: + atom.SetProp("FG", "hemiacetal") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "hemiacetal") + if cnt == 2: + atom.SetProp("FG", "acetal") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "acetal") + + if atom_num_neighbors == 4 and not in_ring: + condition1, condition2 = False, False + cnt = 0 + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 1 + and not neighbor.IsInRing() + ): + condition1 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 0 + and not neighbor.IsInRing() + ): + condition2 = True + cnt += 1 + + if condition1 and condition2: + atom.SetProp("FG", "hemiketal") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "hemiketal") + if cnt == 2: + atom.SetProp("FG", "ketal") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "ketal") + + if num_O == 3 and atom_num_neighbors == 4 and not in_ring: + n_C = 0 + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 0 + ): + n_C += 1 + if n_C == 3: + atom.SetProp("FG", "orthoester") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "orthoester") + + if num_O == 3 and atom_num_neighbors == 3 and charge == 0 and not in_ring: + condition1 = False + n_O = 0 + for neighbor in atom_neighbors: + if ( + mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + condition1 = True + if ( + mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 0 + ): + n_O += 1 + if condition1 and n_O == 2: + atom.SetProp("FG", "carbonate_ester") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "carbonate_ester") + + if num_O == 4 and not in_ring: + n_C = 0 + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 0 + ): + n_C += 1 + if n_C == 4: + atom.SetProp("FG", "orthocarbonate_ester") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "orthocarbonate_ester") + + ########################### Groups containing nitrogen ########################### + #### Amidine #### + if num_N == 2 and atom_num_neighbors == 3: + condition1, condition2 = False, False + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 2 + and neighbor.GetFormalCharge() == 0 + and not neighbor.IsInRing() + ): + condition1 = True + if ( + neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and len(neighbor.GetNeighbors()) == 3 + and neighbor.GetFormalCharge() == 0 + and not neighbor.IsInRing() + ): + condition2 = True + if condition1 and condition2: + atom.SetProp("FG", "amidine") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "N": + neighbor.SetProp("FG", "amidine") + + if num_N == 1 and num_O == 2 and atom_num_neighbors == 3: + condition1, condition2, condition3 = False, False, False + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + condition1 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and len(neighbor.GetNeighbors()) == 2 + ): + condition2 = True + if ( + neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and len(neighbor.GetNeighbors()) == 3 + and not neighbor.IsInRing() + ): + condition3 = True + if condition1 and condition2 and condition3: + atom.SetProp("FG", "carbamate") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "carbamate") + + if num_N == 1 and num_S == 1: + condition1, condition2 = False, False + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 2 + and not neighbor.IsInRing() + ): + condition1 = True + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 1 + and neighbor.GetTotalNumHs() == 0 + and not neighbor.IsInRing() + ): + condition2 = True + if condition1 and condition2: + atom.SetProp("FG", "isothiocyanate") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "isothiocyanate") + + if num_S == 1 and atom_num_neighbors == 3: + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 1 + and neighbor.GetTotalNumHs() == 0 + and not neighbor.IsInRing() + ): + atom.SetProp("FG", "thioketone") + neighbor.SetProp("FG", "thioketone") + + if num_S == 1 and num_H == 1 and atom_num_neighbors == 2: + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 1 + and neighbor.GetTotalNumHs() == 0 + and not neighbor.IsInRing() + ): + atom.SetProp("FG", "thial") + neighbor.SetProp("FG", "thial") + + if num_S == 1 and num_O == 1 and atom_num_neighbors == 3: + condition1, condition2 = False, False + condition3, condition4 = False, False + condition5, condition6 = False, False + condition7, condition8 = False, False + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.SINGLE + and len(neighbor.GetNeighbors()) == 1 + and neighbor.GetTotalNumHs() == 1 + and not neighbor.IsInRing() + ): + condition1 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and not neighbor.IsInRing() + ): + condition2 = True + + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 1 + and not neighbor.IsInRing() + ): + condition3 = True + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetTotalNumHs() == 0 + and not len(neighbor.GetNeighbors()) == 1 + ): + condition4 = True + + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.SINGLE + and len(neighbor.GetNeighbors()) == 2 + and neighbor.GetTotalNumHs() == 0 + and not neighbor.IsInRing() + ): + flag = True + for bond in neighbor.GetBonds(): + if bond.GetBondType() != Chem.BondType.SINGLE: + flag = False + if flag: + condition5 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and not neighbor.IsInRing() + ): + condition6 = True + + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.SINGLE + and len(neighbor.GetNeighbors()) == 2 + and neighbor.GetFormalCharge() == 0 + and not neighbor.IsInRing() + ): + condition7 = True + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetTotalNumHs() == 0 + and len(neighbor.GetNeighbors()) == 1 + and not neighbor.IsInRing() + ): + condition8 = True + + if condition1 and condition2: + atom.SetProp("FG", "carbothioic_S-acid") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["S", "O"]: + neighbor.SetProp("FG", "carbothioic_S-acid") + if condition3 and condition4: + atom.SetProp("FG", "carbothioic_O-acid") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["S", "O"]: + neighbor.SetProp("FG", "carbothioic_O-acid") + if condition5 and condition6: + atom.SetProp("FG", "thiolester") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["S", "O"]: + neighbor.SetProp("FG", "thiolester") + if condition7 and condition8: + atom.SetProp("FG", "thionoester") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["S", "O"]: + neighbor.SetProp("FG", "thionoester") + + if num_S == 2 and atom_num_neighbors == 3: + condition1, condition2, condition3 = False, False, False + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 1 + and len(neighbor.GetNeighbors()) == 1 + and not neighbor.IsInRing() + ): + condition1 = True + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetTotalNumHs() == 0 + and len(neighbor.GetNeighbors()) == 1 + and not neighbor.IsInRing() + ): + condition2 = True + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 0 + and len(neighbor.GetNeighbors()) == 2 + and not neighbor.IsInRing() + ): + flag = True + for bond in neighbor.GetBonds(): + if bond.GetBondType() != Chem.BondType.SINGLE: + flag = False + if flag: + condition3 = True + + if condition1 and condition2: + atom.SetProp("FG", "carbodithioic_acid") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "S": + neighbor.SetProp("FG", "carbodithioic_acid") + if condition3 and condition2: + atom.SetProp("FG", "carbodithio") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "S": + neighbor.SetProp("FG", "carbodithio") + + if num_X == 3 and charge == 0 and atom_num_neighbors == 4: + num_F, num_Cl, num_Br, num_I = 0, 0, 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "F": + num_F += 1 + if neighbor.GetSymbol() == "Cl": + num_Cl += 1 + if neighbor.GetSymbol() == "Br": + num_Br += 1 + if neighbor.GetSymbol() == "I": + num_I += 1 + if num_F == 3: + atom.SetProp("FG", "trifluoromethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "F": + neighbor.SetProp("FG", "trifluoromethyl") + if num_F == 2 and num_Cl == 1: + atom.SetProp("FG", "difluorochloromethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["F", "Cl"]: + neighbor.SetProp("FG", "difluorochloromethyl") + if num_F == 2 and num_Br == 1: + atom.SetProp("FG", "bromodifluoromethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["F", "Br"]: + neighbor.SetProp("FG", "bromodifluoromethyl") + + if num_Cl == 3: + atom.SetProp("FG", "trichloromethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "Cl": + neighbor.SetProp("FG", "trichloromethyl") + if num_Cl == 2 and num_Br == 1: + atom.SetProp("FG", "bromodichloromethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["Cl", "Br"]: + neighbor.SetProp("FG", "bromodichloromethyl") + + if num_Br == 3: + atom.SetProp("FG", "tribromomethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "Br": + neighbor.SetProp("FG", "tribromomethyl") + if num_Br == 2 and num_F == 1: + atom.SetProp("FG", "dibromofluoromethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["F", "Br"]: + neighbor.SetProp("FG", "dibromofluoromethyl") + + if num_I == 3: + atom.SetProp("FG", "triiodomethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "I": + neighbor.SetProp("FG", "triiodomethyl") + + if num_X == 2 and charge == 0 and atom_num_neighbors == 3 and num_H == 1: + num_F, num_Cl, num_Br, num_I = 0, 0, 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "F": + num_F += 1 + if neighbor.GetSymbol() == "Cl": + num_Cl += 1 + if neighbor.GetSymbol() == "Br": + num_Br += 1 + if neighbor.GetSymbol() == "I": + num_I += 1 + + if num_F == 2: + atom.SetProp("FG", "difluoromethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "F": + neighbor.SetProp("FG", "difluoromethyl") + if num_F == 1 and num_Cl == 1: + atom.SetProp("FG", "fluorochloromethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["F", "Cl"]: + neighbor.SetProp("FG", "fluorochloromethyl") + + if num_Cl == 2: + atom.SetProp("FG", "dichloromethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "Cl": + neighbor.SetProp("FG", "dichloromethyl") + if num_Cl == 1 and num_Br == 1: + atom.SetProp("FG", "chlorobromomethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["Cl", "Br"]: + neighbor.SetProp("FG", "chlorobromomethyl") + if num_Cl == 1 and num_I == 1: + atom.SetProp("FG", "chloroiodomethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["Cl", "I"]: + neighbor.SetProp("FG", "chloroiodomethyl") + + if num_Br == 2: + atom.SetProp("FG", "dibromomethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "Br": + neighbor.SetProp("FG", "dibromomethyl") + if num_Br == 1 and num_I == 1: + atom.SetProp("FG", "bromoiodomethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["Br", "I"]: + neighbor.SetProp("FG", "bromoiodomethyl") + + if num_I == 2: + atom.SetProp("FG", "diiodomethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "I": + neighbor.SetProp("FG", "diiodomethyl") + + if ( + (atom_num_neighbors == 2 or atom_num_neighbors == 1) + and not in_ring + and atom.GetProp("FG") == "" + ): + bonds = atom.GetBonds() + ns, nd, nt = 0, 0, 0 + for bond in bonds: + if bond.GetBondType() == Chem.BondType.SINGLE: + ns += 1 + elif bond.GetBondType() == Chem.BondType.DOUBLE: + nd += 1 + else: + nt += 1 + if ns >= 1 and nd == 0 and nt == 0: + atom.SetProp("FG", "alkyl") + if nd >= 1: + atom.SetProp("FG", "alkene") + if nt == 1: + atom.SetProp("FG", "alkyne") + + elif ( + atom_symbol == "O" and not in_ring and charge == 0 and num_H == 0 + ): # Carboxylic anhydride [C(CO)O(CO)C] + num_C = 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["C", "*"]: + num_C += 1 + if num_C == 2: + cnt = 0 + for neighbor in atom_neighbors: + for C_neighbor in neighbor.GetNeighbors(): + if ( + C_neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), C_neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 3 + ): + cnt += 1 + if cnt == 2: + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "carboxylic_anhydride") + for C_neighbor in neighbor.GetNeighbors(): + if C_neighbor.GetSymbol() == "O": + C_neighbor.SetProp("FG", "carboxylic_anhydride") + + elif atom_symbol == "N": # and atom.GetProp('FG') == '': + num_C, num_O, num_N = 0, 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["C", "*"]: + num_C += 1 + if neighbor.GetSymbol() == "O": + num_O += 1 + if neighbor.GetSymbol() == "N": + num_N += 1 + + #### Amines #### + if ( + charge == 0 + and num_H == 2 + and atom_num_neighbors == 1 + and atom.GetProp("FG") != "hydrazone" + ): # Primary amine [RNH2] + atom.SetProp("FG", "primary_amine") + + if ( + charge == 0 and num_H == 1 and atom_num_neighbors == 2 + ): # Secondary amine [R'R"NH] + atom.SetProp("FG", "secondary_amine") + + if ( + charge == 0 + and atom_num_neighbors == 3 + and atom.GetProp("FG") != "carbamate" + ): + cnt = 0 + C_idx = [] + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["C", "*"]: + for C_neighbor in neighbor.GetNeighbors(): + if ( + C_neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), C_neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 3 + and neighbor.GetFormalCharge() == 0 + and atom.GetProp("FG") != "imide" + ): + atom.SetProp("FG", "amide") + neighbor.SetProp("FG", "amide") + C_neighbor.SetProp("FG", "amide") + cnt += 1 + C_idx.append(neighbor.GetIdx()) + + if cnt == 2: + for neighbor in atom_neighbors: + if neighbor.GetIdx() in C_idx: + for C_neighbor in neighbor.GetNeighbors(): + if C_neighbor.GetSymbol() in ["O", "N"]: + neighbor.SetProp("FG", "imide") + C_neighbor.SetProp("FG", "imide") + + if atom.GetProp("FG") not in [ + "imide", + "amide", + "amidine", + "carbamate", + ]: # Tertiary amine [R3N] + atom.SetProp("FG", "tertiary_amine") + + if charge == 1 and atom_num_neighbors == 4: # 4° ammonium ion [R3N] + atom.SetProp("FG", "4_ammonium_ion") + + if ( + charge == 0 + and num_C == 1 + and num_N == 1 + and num_H == 0 + and atom_num_neighbors == 2 + ): # Hydrazone [R'R"CN2H2] + condition1, condition2 = False, False + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() in ["C", "*"] + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 3 + and neighbor.GetFormalCharge() == 0 + ): + condition1 = True + if ( + neighbor.GetSymbol() == "N" + and neighbor.GetTotalNumHs() == 2 + and neighbor.GetFormalCharge() == 0 + ): + condition2 = True + if condition1 and condition2: + atom.SetProp("FG", "hydrazone") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "hydrazone") + + #### Imine #### + if ( + charge == 0 + and num_C == 1 + and num_H == 1 + and num_N == 0 + and atom_num_neighbors == 1 + ): # Primary ketimine [RC(=NH)R'] + for neighbor in atom_neighbors: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 3 + and neighbor.GetFormalCharge() == 0 + ): + atom.SetProp("FG", "primary_ketimine") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "primary_ketimine") + + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 2 + and neighbor.GetTotalNumHs() == 1 + and neighbor.GetFormalCharge() == 0 + ): + atom.SetProp("FG", "primary_aldimine") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "primary_aldimine") + + if ( + charge == 0 + and atom_num_neighbors == 1 + and atom.GetProp("FG") not in ["thiocyanate", "cyanate"] + ): # Nitrile + for neighbor in atom_neighbors: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.TRIPLE + ): + atom.SetProp("FG", "nitrile") + + if ( + charge == 0 + and num_C >= 1 + and atom_num_neighbors == 2 + and atom.GetProp("FG") != "hydrazone" + ): # Secondary ketimine [RC(=NR'')R'] + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() in ["C", "*"] + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 3 + and neighbor.GetFormalCharge() == 0 + ): + atom.SetProp("FG", "secondary_ketimine") + for neighbor in atom_neighbors: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): + neighbor.SetProp("FG", "secondary_ketimine") + + if ( + neighbor.GetSymbol() in ["C", "*"] + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 2 + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 1 + ): + atom.SetProp("FG", "secondary_aldimine") + for neighbor in atom_neighbors: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): + neighbor.SetProp("FG", "secondary_aldimine") + + if charge == 1 and num_N == 2 and atom_num_neighbors == 2: # Azide [RN3] + condition1, condition2 = False, False + for neighbor in atom_neighbors: + if ( + neighbor.GetFormalCharge() == 0 + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): + condition1 = True + if ( + neighbor.GetFormalCharge() == -1 + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): + condition2 = True + if condition1 and condition2 and not in_ring: + atom.SetProp("FG", "azide") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "azide") + + if ( + charge == 0 and num_N == 1 and atom_num_neighbors == 2 and not in_ring + ): # Azo [RN2R'] + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + atom.SetProp("FG", "azo") + neighbor.SetProp("FG", "azo") + break + + if ( + charge == 1 and num_O == 3 and atom_num_neighbors == 3 + ): # Nitrate [RONO2] + condition1, condition2, condition3 = False, False, False + for neighbor in atom_neighbors: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + condition1 = True + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == -1 + ): + condition2 = True + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + ): + condition3 = True + + if condition1 and condition2 and condition3 and not in_ring: + atom.SetProp("FG", "nitrate") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "nitrate") + + if charge == 1 and num_C >= 1 and atom_num_neighbors == 2: # Isonitrile + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() in ["C", "*"] + and neighbor.GetFormalCharge() == -1 + and len(neighbor.GetNeighbors()) == 1 + ): + atom.SetProp("FG", "isonitrile") + neighbor.SetProp("FG", "isonitrile") + + if ( + charge == 0 and num_O == 2 and atom_num_neighbors == 2 and not in_ring + ): # Nitrite + for neighbor in atom_neighbors: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and len(neighbor.GetNeighbors()) == 2 + ): + atom.SetProp("FG", "nitrosooxy") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "nitrosooxy") + + if ( + charge == 1 and num_O == 2 and atom_num_neighbors == 3 and not in_ring + ): # Nitro compound + condition1, condition2 = False, False + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + condition1 = True + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == -1 + ): + condition2 = True + if condition1 and condition2 and not in_ring: + atom.SetProp("FG", "nitro") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "nitro") + + if charge == 0 and num_O == 1 and atom_num_neighbors == 2 and not in_ring: + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): # Nitroso compound + atom.SetProp("FG", "nitroso") + neighbor.SetProp("FG", "nitroso") + + if charge == 0 and num_O == 1 and num_C == 1 and atom_num_neighbors == 2: + condition1, condition2, condition3 = False, False, False + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 1 + ): + condition1 = True + if ( + neighbor.GetSymbol() in ["C", "*"] + and neighbor.GetTotalNumHs() == 1 + and neighbor.GetFormalCharge() == 0 + ): + condition2 = True + if ( + neighbor.GetSymbol() in ["C", "*"] + and neighbor.GetTotalNumHs() == 0 + and neighbor.GetFormalCharge() == 0 + and len(neighbor.GetNeighbors()) == 3 + ): + condition3 = True + + if condition1 and condition2 and not in_ring: + atom.SetProp("FG", "aldoxime") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "aldoxime") + if condition1 and condition3 and not in_ring: + atom.SetProp("FG", "ketoxime") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "ketoxime") + + ########################### Groups containing sulfur ########################### + elif atom_symbol == "S" and charge == 0: + num_C, num_S, num_O = 0, 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["C", "*"]: + num_C += 1 + if neighbor.GetSymbol() == "S": + num_S += 1 + if neighbor.GetSymbol() == "O": + num_O += 1 + + if ( + num_H == 1 + and atom_num_neighbors == 1 + and atom.GetProp("FG") + not in ["carbothioic_S-acid", "carbodithioic_acid"] + ): + neighbor = atom_neighbors[0] + if ( + mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() + == Chem.BondType.SINGLE + ): + atom.SetProp("FG", "sulfhydryl") + + if ( + num_H == 0 + and atom_num_neighbors == 2 + and atom.GetProp("FG") not in ["sulfhydrylester", "carbodithio"] + ): + cnt = 0 + for neighbor in atom_neighbors: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): + cnt += 1 + if cnt == 2: + atom.SetProp("FG", "sulfide") + + if num_H == 0 and num_S == 1 and atom_num_neighbors == 2: + condition1, condition2 = False, False + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and len(neighbor.GetNeighbors()) == 2 + ): + condition1 = True + if ( + neighbor.GetSymbol() != "S" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): + condition2 = True + if condition1 and condition2: + atom.SetProp("FG", "disulfide") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "S": + neighbor.SetProp("FG", "disulfide") + + if num_H == 0 and num_O >= 1 and atom_num_neighbors == 3: + condition = False + cnt = 0 + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + condition = True + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): + cnt += 1 + if condition and cnt == 2: + atom.SetProp("FG", "sulfinyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "sulfinyl") + + if num_H == 0 and num_O >= 2 and atom_num_neighbors == 4: + cnt1 = 0 + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): + cnt1 += 1 + if cnt1 == 2: + atom.SetProp("FG", "sulfonyl") + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): + neighbor.SetProp("FG", "sulfonyl") + + if num_H == 0 and num_O == 2 and atom_num_neighbors == 3: + condition1, condition2, condition3 = False, False, False + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + condition1 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 1 + and neighbor.GetFormalCharge() == 0 + ): + condition2 = True + if ( + neighbor.GetSymbol() != "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): + condition3 = True + if condition1 and condition2 and condition3 and not in_ring: + atom.SetProp("FG", "sulfino") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "sulfino") + + if num_H == 0 and num_O == 3 and atom_num_neighbors == 4: + condition1, condition2 = False, False + cnt = 0 + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + cnt += 1 + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 1 + and neighbor.GetFormalCharge() == 0 + ): + condition1 = True + if ( + neighbor.GetSymbol() != "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): + condition2 = True + if condition1 and condition2 and cnt == 2 and not in_ring: + atom.SetProp("FG", "sulfonic_acid") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "sulfonic_acid") + + if num_H == 0 and num_O == 3 and atom_num_neighbors == 4: + condition1, condition2 = False, False + cnt = 0 + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + cnt += 1 + if ( + neighbor.GetSymbol() != "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): + condition1 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 0 + and neighbor.GetFormalCharge() == 0 + ): + condition2 = True + if condition1 and condition2 and cnt == 2: + atom.SetProp("FG", "sulfonate_ester") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "sulfonate_ester") + + if num_H == 0 and atom_num_neighbors == 2: + for neighbor in atom_neighbors: + for C_neighbor in neighbor.GetNeighbors(): + if ( + C_neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + C_neighbor.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.TRIPLE + and not in_ring + ): + atom.SetProp("FG", "thiocyanate") + neighbor.SetProp("FG", "thiocyanate") + C_neighbor.SetProp("FG", "thiocyanate") + + ########################### Groups containing phosphorus ########################### + elif atom_symbol == "P" and not in_ring and charge == 0: + num_C, num_O = 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["C", "*"]: + num_C += 1 + if neighbor.GetSymbol() == "O": + num_O += 1 + + if atom_num_neighbors == 3: + cnt = 0 + for neighbor in atom_neighbors: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): + cnt += 1 + if cnt == 3: + atom.SetProp("FG", "phosphino") + + if num_O == 3 and atom_num_neighbors == 4: + condition1, condition2 = False, False + cnt = 0 + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + condition1 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 1 + and neighbor.GetFormalCharge() == 0 + ): + cnt += 1 + if ( + neighbor.GetSymbol() != "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): + condition2 = True + if condition1 and condition2 and cnt == 2: + atom.SetProp("FG", "phosphono") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "phosphono") + + if num_O == 4 and atom_num_neighbors == 4: + condition1 = False + cnt1, cnt2 = 0, 0 + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): + condition1 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 1 + and neighbor.GetFormalCharge() == 0 + ): + cnt1 += 1 + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 0 + and neighbor.GetFormalCharge() == 0 + ): + cnt2 += 1 + + if condition1 and cnt1 == 2 and cnt2 == 1: + atom.SetProp("FG", "phosphate") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "phosphate") + if condition1 and cnt1 == 1 and cnt2 == 2: + atom.SetProp("FG", "phosphodiester") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "phosphodiester") + + if num_O == 1 and atom_num_neighbors == 4: + condition = False + cnt = 0 + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + condition = True + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): + cnt += 1 + if condition and cnt == 3: + atom.SetProp("FG", "phosphoryl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "phosphoryl") + + ########################### Groups containing boron ########################### + elif atom_symbol == "B" and not in_ring and charge == 0: + num_C, num_O = 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["C", "*"]: + num_C += 1 + if neighbor.GetSymbol() == "O": + num_O += 1 + + if num_O == 2 and atom_num_neighbors == 3: + cnt1, cnt2 = 0, 0 + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and neighbor.GetTotalNumHs() == 1 + and neighbor.GetFormalCharge() == 0 + ): + cnt1 += 1 + if ( + neighbor.GetSymbol() == "O" + and neighbor.GetFormalCharge() == 0 + and len(neighbor.GetNeighbors()) == 2 + ): + cnt2 += 1 + if cnt1 == 2: + atom.SetProp("FG", "borono") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "borono") + if cnt2 == 2: + atom.SetProp("FG", "boronate") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "boronate") + + if num_O == 1 and atom_num_neighbors == 3: + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O" and neighbor.GetFormalCharge() == 0: + if neighbor.GetTotalNumHs() == 1: + atom.SetProp("FG", "borino") + neighbor.SetProp("FG", "borino") + if len(neighbor.GetNeighbors()) == 2: + atom.SetProp("FG", "borinate") + neighbor.SetProp("FG", "borinate") + + ########################### Groups containing silicon ########################### + elif atom_symbol == "Si" and not in_ring and charge == 0: + num_O, num_Cl, num_C = 0, 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + num_O += 1 + if neighbor.GetSymbol() == "Cl": + num_Cl += 1 + if neighbor.GetSymbol() in ["C", "*"]: + num_C += 1 + if num_O == 1 and charge == 0 and atom_num_neighbors == 4: + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and len(neighbor.GetNeighbors()) == 2 + and neighbor.GetFormalCharge() == 0 + ): + atom.SetProp("FG", "silyl_ether") + neighbor.SetProp("FG", "silyl_ether") + if num_Cl == 2 and charge == 0 and atom_num_neighbors == 4: + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "Cl" and neighbor.GetFormalCharge() == 0: + atom.SetProp("FG", "dichlorosilane") + neighbor.SetProp("FG", "dichlorosilane") + if ( + num_C >= 3 + and charge == 0 + and atom_num_neighbors == 4 + and atom.GetProp("FG") != "silyl_ether" + ): + cnt = 0 + C_idx = [] + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() in ["C", "*"] + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 3 + ): + cnt += 1 + C_idx.append(neighbor.GetIdx()) + if cnt == 3: + atom.SetProp("FG", "trimethylsilyl") + for idx in C_idx: + mol.GetAtomWithIdx(idx).SetProp("FG", "trimethylsilyl") + + ########################### Groups containing halogen ########################### + elif ( + atom_symbol == "F" + and not in_ring + and charge == 0 + and atom.GetProp("FG") == "" + ): + atom.SetProp("FG", "fluoro") + elif ( + atom_symbol == "Cl" + and not in_ring + and charge == 0 + and atom.GetProp("FG") == "" + ): + atom.SetProp("FG", "chloro") + elif ( + atom_symbol == "Br" + and not in_ring + and charge == 0 + and atom.GetProp("FG") == "" + ): + atom.SetProp("FG", "bromo") + elif ( + atom_symbol == "I" + and not in_ring + and charge == 0 + and atom.GetProp("FG") == "" + ): + atom.SetProp("FG", "iodo") + else: + pass + + ########################### Groups containing other elements ########################### + if atom.GetProp("FG") == "" and atom_symbol in ELEMENTS and not in_ring: + if charge == 0: + atom.SetProp("FG", atom_symbol) + atom.SetProp(FLAG_NO_FG, "") + # atom.SetProp("FG", NO_FG) # changes the fg-detection algo (num of fg dectected reduces) + else: + atom.SetProp("FG", f"{atom_symbol}[{charge}]") + atom.SetProp(FLAG_NO_FG, "") + # atom.SetProp("FG", NO_FG) # changes the fg-detection algo (num of fg dectected reduces) + else: + pass + + if atom_symbol == "*": + atom.SetProp("FG", "") + + +def set_atom_map_num(mol): + if mol is not None: + for atom in mol.GetAtoms(): + idx = atom.GetIdx() + if idx != 0: + atom.SetAtomMapNum(idx) + else: + atom.SetAtomMapNum(-9) + + +def find_neighbor_map(smiles): + matches = re.findall(r"\[(\d+)\*\]", smiles) + idx = [int(match) for match in matches] + if smiles.startswith("*"): + return set(idx) | {0} + else: + return set(idx) + + +def find_atom_map(smiles): + matches = re.findall(r"\[[^\]]*:(\d+)\]", smiles) + idx = [int(match) for match in matches] + return set(idx) + + +def get_structure(mol): + set_atom_map_num(mol) + fused_rings_groups: list[list[set[int]]] = set_ring_properties(mol) + detect_functional_group(mol) + rings = mol.GetRingInfo().AtomRings() + + splitting_bonds = set() + for bond in mol.GetBonds(): + begin_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom() + begin_atom_prop = begin_atom.GetProp("FG") + end_atom_prop = end_atom.GetProp("FG") + begin_atom_symbol = begin_atom.GetSymbol() + end_atom_symbol = end_atom.GetSymbol() + + if ( + (begin_atom.IsInRing() and not end_atom.IsInRing()) + or (not begin_atom.IsInRing() and end_atom.IsInRing()) + or (begin_atom.IsInRing() and end_atom.IsInRing()) + ): + flag = True + for ring in rings: + if begin_atom.GetIdx() in ring and end_atom.GetIdx() in ring: + flag = False + break + if flag: + splitting_bonds.add(bond) + else: + if begin_atom_prop != end_atom_prop: + splitting_bonds.add(bond) + if begin_atom_prop == "" and end_atom_prop == "": + if (begin_atom_symbol in ["C", "*"] and end_atom_symbol != "C") or ( + begin_atom_symbol != "C" and end_atom_symbol in ["C", "*"] + ): + splitting_bonds.add(bond) + + splitting_bonds = list(splitting_bonds) + if splitting_bonds != []: + fragments = Chem.FragmentOnBonds( + mol, [bond.GetIdx() for bond in splitting_bonds], addDummies=True + ) + BONDS = set() + for bond in splitting_bonds: + BONDS.add( + (bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), bond.GetBondType()) + ) + else: + fragments = mol + BONDS = set() + smiles = m2s(fragments).replace("-9", "0").split(".") + + structure = {} + for frag in smiles: + atom_idx: set = find_atom_map(frag) + structure[frag] = {"atom": atom_idx, "is_ring_fg": False} + + # Convert fragment SMILES back to mol to match with fused ring atom indices + frag_mol = Chem.MolFromSmiles(frag) + frag_rings = frag_mol.GetRingInfo().AtomRings() + if len(frag_rings) >= 1: + structure[frag]["is_ring_fg"] = True + + if len(frag_rings) <= 1: + continue + + if not fused_rings_groups: + continue + + for group in fused_rings_groups: + flat_atoms = set().union(*group) + if flat_atoms.issubset(atom_idx) and len(flat_atoms) == len(atom_idx): + for idx, ring_atoms in enumerate(group): + structure[f"{frag}_{idx + 1}"] = { + "atom": ring_atoms, + "is_ring_fg": True, + } + structure.pop(frag) + break + + return structure, BONDS diff --git a/chebai_graph/preprocessing/fg_detection/fg_constants.py b/chebai_graph/preprocessing/fg_detection/fg_constants.py new file mode 100644 index 0000000..5431264 --- /dev/null +++ b/chebai_graph/preprocessing/fg_detection/fg_constants.py @@ -0,0 +1,14 @@ +# fmt: off +ELEMENTS = { + "Ac", "Ag", "Al", "Am", "As", "At", "Au", "B", "Ba", "Be", "Bi", "Bk", "Br", "Ca", + "Cd", "Ce", "Cf", "Cl", "Cm", "Co", "Cr", "Cs", "Cu", "Dy", "Er", "Es", "Eu", "F", + "Fe", "Fm", "Fr", "Ga", "Gd", "Ge", "He", "Hf", "Hg", "Ho", "I", "In", "Ir", "K", + "Kr", "La", "Li", "Lr", "Lu", "Md", "Mg", "Mn", "Mo", "N", "Na", "Nb", "Nd", "Ne", + "Ni", "Np", "O", "Os", "P", "Pa", "Pb", "Pd", "Pm", "Po", "Pr", "Pt", "Pu", "Ra", + "Rb", "Re", "Rh", "Rn", "Ru", "S", "Sb", "Sc", "Se", "Si", "Sm", "Sn", "Sr", "Ta", + "Tb", "Tc", "Te", "Th", "Ti", "Tl", "Tm", "U", "V", "W", "Xe", "Y", "Yb", "Zn", "Zr" +} +# fmt: on + + +FLAG_NO_FG = "flag_no_fg" diff --git a/chebai_graph/preprocessing/properties.py b/chebai_graph/preprocessing/properties.py deleted file mode 100644 index 2b3acf8..0000000 --- a/chebai_graph/preprocessing/properties.py +++ /dev/null @@ -1,159 +0,0 @@ -import abc -from typing import Optional - -import numpy as np -import rdkit.Chem as Chem -from descriptastorus.descriptors import rdNormalizedDescriptors - -from chebai_graph.preprocessing.property_encoder import ( - AsIsEncoder, - BoolEncoder, - IndexEncoder, - OneHotEncoder, - PropertyEncoder, -) - - -class MolecularProperty(abc.ABC): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - if encoder is None: - encoder = IndexEncoder(self) - self.encoder = encoder - - @property - def name(self): - """Unique identifier for this property.""" - return self.__class__.__name__ - - def on_finish(self): - """Called after dataset processing is done.""" - self.encoder.on_finish() - - def __str__(self): - return self.name - - def get_property_value(self, mol: Chem.rdchem.Mol): - raise NotImplementedError - - -class AtomProperty(MolecularProperty): - """Property of an atom.""" - - def get_property_value(self, mol: Chem.rdchem.Mol): - return [self.get_atom_value(atom) for atom in mol.GetAtoms()] - - def get_atom_value(self, atom: Chem.rdchem.Atom): - return NotImplementedError - - -class BondProperty(MolecularProperty): - def get_property_value(self, mol: Chem.rdchem.Mol): - return [self.get_bond_value(bond) for bond in mol.GetBonds()] - - def get_bond_value(self, bond: Chem.rdchem.Bond): - return NotImplementedError - - -class MoleculeProperty(MolecularProperty): - """Global property of a molecule.""" - - -class AtomType(AtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or OneHotEncoder(self)) - - def get_atom_value(self, atom: Chem.rdchem.Atom): - return atom.GetAtomicNum() - - -class NumAtomBonds(AtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or OneHotEncoder(self)) - - def get_atom_value(self, atom: Chem.rdchem.Atom): - return atom.GetDegree() - - -class AtomCharge(AtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or OneHotEncoder(self)) - - def get_atom_value(self, atom: Chem.rdchem.Atom): - return atom.GetFormalCharge() - - -class AtomChirality(AtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or OneHotEncoder(self)) - - def get_atom_value(self, atom: Chem.rdchem.Atom): - return atom.GetChiralTag() - - -class AtomHybridization(AtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or OneHotEncoder(self)) - - def get_atom_value(self, atom: Chem.rdchem.Atom): - return atom.GetHybridization() - - -class AtomNumHs(AtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or OneHotEncoder(self)) - - def get_atom_value(self, atom: Chem.rdchem.Atom): - return atom.GetTotalNumHs() - - -class AtomAromaticity(AtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or BoolEncoder(self)) - - def get_atom_value(self, atom: Chem.rdchem.Atom): - return atom.GetIsAromatic() - - -class BondAromaticity(BondProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or BoolEncoder(self)) - - def get_bond_value(self, bond: Chem.rdchem.Bond): - return bond.GetIsAromatic() - - -class BondType(BondProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or OneHotEncoder(self)) - - def get_bond_value(self, bond: Chem.rdchem.Bond): - return bond.GetBondType() - - -class BondInRing(BondProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or BoolEncoder(self)) - - def get_bond_value(self, bond: Chem.rdchem.Bond): - return bond.IsInRing() - - -class MoleculeNumRings(MolecularProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or OneHotEncoder(self)) - - def get_property_value(self, mol: Chem.rdchem.Mol): - return [mol.GetRingInfo().NumRings()] - - -class RDKit2DNormalized(MolecularProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or AsIsEncoder(self)) - - def get_property_value(self, mol: Chem.rdchem.Mol): - generator_normalized = rdNormalizedDescriptors.RDKit2DNormalized() - features_normalized = generator_normalized.processMol( - mol, Chem.MolToSmiles(mol) - ) - features_normalized = np.nan_to_num(features_normalized) - return [features_normalized[1:]] diff --git a/chebai_graph/preprocessing/properties/__init__.py b/chebai_graph/preprocessing/properties/__init__.py new file mode 100644 index 0000000..a0de30d --- /dev/null +++ b/chebai_graph/preprocessing/properties/__init__.py @@ -0,0 +1,89 @@ +# Formating is turned off here, because isort sorts the augmented properties imports in first order, +# but it has to be imported after properties module, to avoid circular imports +# This is because augmented properties module imports from properties module +# isort: off + +from .base import ( + MolecularProperty, + AtomProperty, + BondProperty, + MoleculeProperty, + AllNodeTypeProperty, + AtomNodeTypeProperty, + FGNodeTypeProperty, +) + +from .properties import ( + AtomType, + NumAtomBonds, + AtomCharge, + AtomChirality, + AtomHybridization, + AtomNumHs, + AtomAromaticity, + BondAromaticity, + BondType, + BondInRing, + RDKit2DNormalized, +) + +from .augmented_properties import ( + AtomNodeLevel, + AtomFunctionalGroup, + IsHydrogenBondDonorFG, + IsHydrogenBondAcceptorFG, + IsFGAlkyl, + BondLevel, + AugAtomType, + AugNumAtomBonds, + AugAtomCharge, + AugAtomHybridization, + AugAtomNumHs, + AugAtomAromaticity, + AugBondAromaticity, + AugBondType, + AugBondInRing, + AugRDKit2DNormalized, +) + +# isort: on + +__all__ = [ + # -------------- Properties Base classes -------------- + "MolecularProperty", + "MoleculeProperty", + "AtomProperty", + "BondProperty", + "AllNodeTypeProperty", + "AtomNodeTypeProperty", + "FGNodeTypeProperty", + # -------------- Regular Properties ----------------- + "AtomType", + "NumAtomBonds", + "AtomCharge", + "AtomChirality", + "AtomHybridization", + "AtomNumHs", + "AtomAromaticity", + "BondAromaticity", + "BondType", + "BondInRing", + "RDKit2DNormalized", + # -------- Augmented Molecular Properties ---------- + "AtomNodeLevel", + "AtomFunctionalGroup", + "IsHydrogenBondDonorFG", + "IsHydrogenBondAcceptorFG", + "IsFGAlkyl", + "BondLevel", + "AugAtomType", + "AugNumAtomBonds", + "AugAtomCharge", + "AugAtomHybridization", + "AugAtomNumHs", + "AugAtomAromaticity", + "AugBondAromaticity", + "AugBondType", + "AugBondInRing", + "AugRDKit2DNormalized", +] diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py new file mode 100644 index 0000000..f5f7b1d --- /dev/null +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -0,0 +1,411 @@ +from abc import ABC + +from rdkit import Chem + +from chebai_graph.preprocessing.property_encoder import ( + BoolEncoder, + OneHotEncoder, + PropertyEncoder, +) + +from . import constants as k +from . import properties as pr +from .base import ( + AllNodeTypeProperty, + AtomNodeTypeProperty, + AugmentedBondProperty, + AugmentedMoleculeProperty, + FGNodeTypeProperty, + FrozenPropertyAlias, +) + +# --------------------- Atom Properties ----------------------------- + + +class AtomNodeLevel(AllNodeTypeProperty): + def __init__(self, encoder: PropertyEncoder | None = None): + """ + Initialize AtomNodeLevel with an optional encoder. + + Args: + encoder (PropertyEncoder | None): Property encoder to use. Defaults to OneHotEncoder. + """ + super().__init__(encoder or OneHotEncoder(self)) + + def get_atom_value(self, atom: Chem.rdchem.Atom | dict) -> str | int | bool: + """ + Get the node level property for a given atom/node. + + Args: + atom (Chem.rdchem.Atom | dict): Atom or node. + + Returns: + str | int | bool: Property value. + """ + return self._check_modify_atom_prop_value(atom, k.NODE_LEVEL) + + +class AtomFunctionalGroup(FGNodeTypeProperty): + def __init__(self, encoder: PropertyEncoder | None = None): + """ + Initialize AtomFunctionalGroup with an optional encoder. + + Args: + encoder (PropertyEncoder | None): Property encoder to use. Defaults to OneHotEncoder. + """ + super().__init__(encoder or OneHotEncoder(self)) + + def get_atom_value(self, atom: Chem.rdchem.Atom | dict) -> str | int | bool: + """ + Get the functional group property for a given atom/node. + + Args: + atom (Chem.rdchem.Atom | dict): Atom or node. + + Returns: + str | int | bool: Property value. + """ + return self._check_modify_atom_prop_value(atom, "FG") + + +class AtomRingSize(FGNodeTypeProperty): + def __init__(self, encoder: PropertyEncoder | None = None): + """ + Initialize AtomRingSize with an optional encoder. + + Args: + encoder (PropertyEncoder | None): Property encoder to use. Defaults to OneHotEncoder. + """ + super().__init__(encoder or OneHotEncoder(self)) + + def get_atom_value(self, atom: Chem.rdchem.Atom | dict) -> int: + """ + Get the ring size for a given atom/node. + + Args: + atom (Chem.rdchem.Atom | dict): Atom or node. + + Returns: + int: Maximum ring size the atom belongs to, or 0 if none. + """ + return self._check_modify_atom_prop_value(atom, "RING") + + def _check_modify_atom_prop_value( + self, atom: Chem.rdchem.Atom | dict, prop: str + ) -> int: + """ + Override to parse and return maximum ring size from a property string. + + Args: + atom (Chem.rdchem.Atom | dict): Atom or node. + prop (str): Property name. + + Returns: + int: Maximum ring size or 0. + """ + ring_size_str = self._get_atom_prop_value(atom, prop) + if ring_size_str: + ring_sizes = list(map(int, str(ring_size_str).split("-"))) + # TODO: Decide ring size for atoms belongs to fused rings, rn only max ring size taken + return max(ring_sizes) + else: + return 0 + + +class IsHydrogenBondDonorFG(FGNodeTypeProperty): + def __init__(self, encoder: PropertyEncoder | None = None): + """ + Initialize IsHydrogenBondDonorFG with an optional encoder. + + Args: + encoder (PropertyEncoder | None): Property encoder to use. Defaults to BoolEncoder. + """ + super().__init__(encoder or BoolEncoder(self)) + # fmt: off + # https://github.com/thaonguyen217/farm_molecular_representation/blob/main/src/(6)gen_FG_KG.py#L26-L31 + self._hydrogen_bond_donor: set[str] = { + 'hydroxyl', 'hydroperoxy', 'primary_amine', 'secondary_amine', + 'hydrazone', 'primary_ketimine', 'secondary_ketimine', 'primary_aldimine', + 'amide', 'sulfhydryl', 'sulfonic_acid', 'thiolester', 'hemiacetal', + 'hemiketal', 'carboxyl', 'aldoxime', 'ketoxime' + } + # fmt: on + + def get_atom_value(self, atom: Chem.rdchem.Atom | dict) -> bool: + """ + Check if the atom's functional group is a hydrogen bond donor. + + Args: + atom (Chem.rdchem.Atom | dict): Atom or node. + + Returns: + bool: True if hydrogen bond donor, else False. + """ + fg = self._check_modify_atom_prop_value(atom, "FG") + return fg in self._hydrogen_bond_donor + + +class IsHydrogenBondAcceptorFG(FGNodeTypeProperty): + def __init__(self, encoder: PropertyEncoder | None = None): + """ + Initialize IsHydrogenBondAcceptorFG with an optional encoder. + + Args: + encoder (PropertyEncoder | None): Property encoder to use. Defaults to BoolEncoder. + """ + super().__init__(encoder or BoolEncoder(self)) + # fmt: off + # https://github.com/thaonguyen217/farm_molecular_representation/blob/main/src/(6)gen_FG_KG.py#L33-L39 + self._hydrogen_bond_acceptor: set[str] = { + 'ether', 'peroxy', 'haloformyl', 'ketone', 'aldehyde', 'carboxylate', + 'carboxyl', 'ester', 'ketal', 'carbonate_ester', 'carboxylic_anhydride', + 'primary_amine', 'secondary_amine', 'tertiary_amine', '4_ammonium_ion', + 'hydrazone', 'primary_ketimine', 'secondary_ketimine', 'primary_aldimine', + 'amide', 'sulfhydryl', 'sulfonic_acid', 'thiolester', 'aldoxime', 'ketoxime' + } + # fmt: on + + def get_atom_value(self, atom: Chem.rdchem.Atom | dict) -> bool: + """ + Determine if the atom is a hydrogen bond acceptor. + + Args: + atom (Chem.rdchem.Atom | dict): The atom object or a dictionary of atom properties. + + Returns: + bool: True if the atom is a hydrogen bond acceptor, False otherwise. + """ + fg = self._check_modify_atom_prop_value(atom, "FG") + return fg in self._hydrogen_bond_acceptor + + +class IsFGAlkyl(FGNodeTypeProperty): + def __init__(self, encoder: PropertyEncoder | None = None): + """ + Args: + encoder (PropertyEncoder | None): Optional encoder to use for this property. + Defaults to BoolEncoder if not provided. + """ + super().__init__(encoder or BoolEncoder(self)) + + def get_atom_value(self, atom: Chem.rdchem.Atom | dict) -> int: + """ + Get the alkyl group status of the given atom. + + Args: + atom (Chem.rdchem.Atom | dict): Atom object or atom property dictionary. + + Returns: + int: 1 if alkyl, 0 otherwise. + """ + return int(self._check_modify_atom_prop_value(atom, "is_alkyl")) + + +class AugNodeValueDefaulter(AtomNodeTypeProperty, FrozenPropertyAlias, ABC): + def get_atom_value(self, atom: Chem.rdchem.Atom | dict) -> int | None: + """ + Get the property value for an atom or dict node. + + Args: + atom (Chem.rdchem.Atom | dict): Atom object or dict representing node properties. + + Returns: + int | None: Property value or None for dict nodes. + + Raises: + TypeError: If input is neither Chem.rdchem.Atom nor dict. + """ + if isinstance(atom, Chem.rdchem.Atom): + # Delegate to superclass method for atom + return super().get_atom_value(atom) + elif isinstance(atom, dict): + return None + else: + raise TypeError( + f"Expected Chem.rdchem.Atom or dict, got {type(atom).__name__}" + ) + + +class AugAtomType(AugNodeValueDefaulter, pr.AtomType): + """ + This property uses OneHotEncoder as default encoder + + TODO: Can we return 0 for augmented Nodes for this property? which will lead to use of one hot tensor for augmented nodes + Currently, we return None which leads to zero-tensor for augmented nodes + + RDKit uses 0 as the atomic number for a "dummy atom", which usually means: + - A placeholder atom (e.g. [*], R#, or attachment points in SMARTS/SMILES). + - An undefined or wildcard atom. + - A pseudoatom (e.g., for certain fragments or placeholders in reaction centers). + """ + + ... + + +class AugNumAtomBonds(AugNodeValueDefaulter, pr.NumAtomBonds): + """ + This property uses OneHotEncoder as default encoder + + Default return value for this property can't be zero, 0 is used for isolated atoms in molecule. + It has to be None or actual node degree. + + TODO: Can return actual node degree/num of connections for augmented Nodes for this property? + which will lead to use of one hot tensor for augmented nodes + + Currently, we return None which leads to zero-tensor for augmented nodes + + But then the question aries shall we count only the atoms connected to a fg node, or all nodes including atoms. + Consider graph node too. + """ + + ... + + +class AugAtomCharge(AugNodeValueDefaulter, pr.AtomCharge): + """ + This property uses OneHotEncoder as default encoder + + Default return value for this property can't be zero, as atoms can have 0 charge. + + TODO: Can return some `unk` value for augmented Nodes for this property? + which will lead to use of one hot tensor for augmented nodes + + Currently, we return None which leads to zero-tensor for augmented nodes + """ + + ... + + +class AugAtomHybridization(AugNodeValueDefaulter, pr.AtomHybridization): + """ + This property uses OneHotEncoder as default encoder + + TODO: Can return some `HybridizationType.UNSPECIFIED` value which is 0 for augmented Nodes for this property? + which will lead to use of one hot tensor for augmented nodes + + Check: https://www.rdkit.org/docs/source/rdkit.Chem.rdchem.html#rdkit.Chem.rdchem.HybridizationType + + Currently, we return None which leads to zero-tensor for augmented nodes + """ + + ... + + +class AugAtomNumHs(AugNodeValueDefaulter, pr.AtomNumHs): + """ + This property uses OneHotEncoder as default encoder + + Default return value for this property can't be zero, as atoms can have 0 Hydrogen atoms attached + which mean atoms is full balanced by bonding with other non-hydrogen atoms. + + TODO: Can return some `unk` value for augmented Nodes for this property? + which will lead to use of one hot tensor for augmented nodes + + Currently, we return None which leads to zero-tensor for augmented nodes + """ + + ... + + +class AugAtomAromaticity(AugNodeValueDefaulter, pr.AtomAromaticity): + """ + This property uses BoolEncoder as default encoder + + Currently, we return None for augmented nodes which leads to BoolEncoder setting 0 internally. + + This is None is right value for augmented nodes its not part of any kind of aromatic ring. + """ + + ... + + +# --------------------- Bond Properties ------------------------------ + + +class BondLevel(AugmentedBondProperty): + def __init__(self, encoder: PropertyEncoder | None = None): + """ + Args: + encoder (PropertyEncoder | None): Optional encoder to use. Defaults to OneHotEncoder. + """ + super().__init__(encoder or OneHotEncoder(self)) + + def get_bond_value(self, bond: Chem.rdchem.Bond | dict) -> str: + """ + Get the bond level property value. + + Args: + bond (Chem.rdchem.Bond | dict): Bond or bond dict. + + Returns: + str: Bond level property. + """ + return self._check_modify_bond_prop_value(bond, k.EDGE_LEVEL) + + +class AugBondValueDefaulter(AugmentedBondProperty, FrozenPropertyAlias, ABC): + def get_bond_value(self, bond: Chem.rdchem.Bond | dict) -> str | None: + """ + Get bond property value or None for dict bonds. + + Args: + bond (Chem.rdchem.Bond | dict): Bond or bond dict. + + Returns: + str | None: Property value or None for dict. + + Raises: + TypeError: If input type is invalid. + """ + if isinstance(bond, Chem.rdchem.Bond): + # Delegate to superclass method for bond + return super().get_bond_value(bond) + elif isinstance(bond, dict): + return None + else: + raise TypeError("Bond/Edge should be of type `Chem.rdchem.Bond` or `dict`.") + + +class AugBondAromaticity(AugBondValueDefaulter, pr.BondAromaticity): + """ + This property uses BoolEncoder as default encoder + + Currently, we return None for augmented nodes which leads to BoolEncoder setting 0 internally. + + This is None is right value for augmented nodes its not part of any kind of aromatic ring. + """ + + ... + + +class AugBondType(AugBondValueDefaulter, pr.BondType): + """ + This property uses OneHotEncoder as default encoder + + TODO: Can return some `BondType.UNSPECIFIED` value which is 0 for augmented Nodes for this property? + which will lead to use of one hot tensor for augmented nodes + + Check: https://www.rdkit.org/docs/source/rdkit.Chem.rdchem.html#rdkit.Chem.rdchem.BondType + + Currently, we return None which leads to zero-tensor for augmented nodes + """ + + ... + + +class AugBondInRing(AugBondValueDefaulter, pr.BondInRing): + """ + This property uses BoolEncoder as default encoder + + Currently, we return None for augmented nodes which leads to BoolEncoder setting 0 internally. + + This is None is right value for augmented nodes its not part of any kind of aromatic ring. + """ + + ... + + +# --------------------- Molecule Properties ------------------------------ + + +class AugRDKit2DNormalized(AugmentedMoleculeProperty, pr.RDKit2DNormalized): ... diff --git a/chebai_graph/preprocessing/properties/base.py b/chebai_graph/preprocessing/properties/base.py new file mode 100644 index 0000000..da5d9c2 --- /dev/null +++ b/chebai_graph/preprocessing/properties/base.py @@ -0,0 +1,455 @@ +import sys +from abc import ABC, abstractmethod +from types import MappingProxyType + +import rdkit.Chem as Chem + +from chebai_graph.preprocessing.property_encoder import IndexEncoder, PropertyEncoder + +from . import constants as k + +# For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order +# https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights +# https://mail.python.org/pipermail/python-dev/2017-December/151283.html +assert sys.version_info >= ( + 3, + 7, +), "This code requires Python 3.7 or higher." +# Order preservation is necessary to to create `prop_list`in Augmented properties + + +class MolecularProperty(ABC): + """ + Abstract base class representing a molecular property. + + Properties can be atom-level, bond-level, or molecule-level. + Each property is associated with a PropertyEncoder that encodes + the raw property values into suitable feature representations. + + Args: + encoder: Optional encoder instance to encode property values. + Defaults to IndexEncoder if not provided. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: + if encoder is None: + encoder = IndexEncoder(self) + self.encoder: PropertyEncoder = encoder + + @property + def name(self) -> str: + """ + Unique identifier for this property, typically the class name. + + Returns: + The class name as the property's unique name. + """ + return self.__class__.__name__ + + def on_finish(self) -> None: + """ + Called after dataset processing is complete. + + Typically used to finalize encoder states, e.g., saving cache. + """ + self.encoder.on_finish() + + def __str__(self) -> str: + """ + String representation of the property. + + Returns: + The property name. + """ + return self.name + + @abstractmethod + def get_property_value(self, mol: Chem.rdchem.Mol | dict) -> list: + """ + Abstract method to extract the raw property value(s) from a molecule. + + Args: + mol: RDKit molecule object or a dictionary representation. + + Returns: + A list of raw property values for the molecule. + """ + ... + + +class AtomProperty(MolecularProperty, ABC): + """ + Abstract base class representing an atom-level molecular property. + + Subclasses must implement get_atom_value to extract property per atom. + """ + + def get_property_value(self, mol: Chem.rdchem.Mol) -> list: + """ + Extract the property value for each atom in the molecule. + + Args: + mol: RDKit molecule object. + + Returns: + List of property values, one per atom. + """ + return [self.get_atom_value(atom) for atom in mol.GetAtoms()] + + @abstractmethod + def get_atom_value(self, atom: Chem.rdchem.Atom) -> object: + """ + Abstract method to extract the property value of a single atom. + + Args: + atom: RDKit atom object. + + Returns: + The property value for the atom. + """ + pass + + +class BondProperty(MolecularProperty, ABC): + """ + Abstract base class representing a bond-level molecular property. + + Subclasses must implement get_bond_value to extract property per bond. + """ + + def get_property_value(self, mol: Chem.rdchem.Mol) -> list: + """ + Extract the property value for each bond in the molecule. + + Args: + mol: RDKit molecule object. + + Returns: + List of property values, one per bond. + """ + return [self.get_bond_value(bond) for bond in mol.GetBonds()] + + @abstractmethod + def get_bond_value(self, bond: Chem.rdchem.Bond) -> object: + """ + Abstract method to extract the property value of a single bond. + + Args: + bond: RDKit bond object. + + Returns: + The property value for the bond. + """ + pass + + +class MoleculeProperty(MolecularProperty, ABC): + """ + Class representing a global (molecule-level) property. + + Subclasses should override get_property_value for molecule-wide values. + """ + + pass + + +class FrozenPropertyAlias(MolecularProperty, ABC): + """ + Wrapper base class for augmented graph properties that reuse existing molecular properties. + + This allows an augmented property class (with an 'Aug' prefix in its name) to: + - Reuse the encoder and index files of the base property by removing the 'Aug' prefix from its name. + - Prevent adding new tokens to the encoder cache by freezing it (using MappingProxyType). + + Usage: + Inherit from FrozenPropertyAlias and the desired base molecular property class, + and name the class with an 'Aug' prefix (e.g., 'AugAtomType'). + + Example: + ```python + class AugAtomType(FrozenPropertyAlias, AtomType): ... + ``` + + Raises: + ValueError: If new tokens are added to the frozen encoder during processing. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: + super().__init__(encoder) + # Lock the encoder's cache to prevent adding new tokens + if hasattr(self.encoder, "cache") and isinstance(self.encoder.cache, dict): + self.encoder.cache = MappingProxyType(self.encoder.cache) + + @property + def name(self) -> str: + """ + Unique identifier for this property. + + Returns: + The class name with the 'Aug' prefix removed if present, + allowing reuse of the base property encoder/index files. + """ + class_name = self.__class__.__name__ + return class_name[3:] if class_name.startswith("Aug") else class_name + + def on_finish(self) -> None: + """ + Called after dataset processing. + + Ensures no new tokens were added to the frozen encoder cache. + Raises an error if this condition is violated. + """ + if ( + hasattr(self.encoder, "cache") + and len(self.encoder.cache) > self.encoder.index_length_start + ): + raise ValueError( + f"{self.__class__.__name__} attempted to add new tokens " + f"to a frozen encoder at {self.encoder.index_path}" + ) + super().on_finish() + + +class AugmentedAtomProperty(AtomProperty, ABC): + MAIN_KEY = "nodes" + + def get_property_value(self, augmented_mol: dict) -> list: + """ + Extract property values for atoms from the augmented molecule dictionary. + + Args: + augmented_mol (dict): Dictionary representing the augmented molecule. + + Raises: + KeyError: If required keys are missing in the dictionary. + TypeError: If types of contained objects are incorrect. + AssertionError: If the number of property values does not match number of nodes. + + Returns: + list: List of property values for all atoms, functional groups, and graph nodes. + """ + if self.MAIN_KEY not in augmented_mol: + raise KeyError( + f"Key `{self.MAIN_KEY}` should be present in augmented molecule dict" + ) + + missing_keys = {"atom_nodes"} - augmented_mol[self.MAIN_KEY].keys() + if missing_keys: + raise KeyError(f"Missing keys {missing_keys} in augmented molecule nodes") + + atom_molecule: Chem.Mol = augmented_mol[self.MAIN_KEY]["atom_nodes"] + if not isinstance(atom_molecule, Chem.Mol): + raise TypeError( + f'augmented_mol["{self.MAIN_KEY}"]["atom_nodes"] must be an instance of rdkit.Chem.Mol' + ) + prop_list = [self.get_atom_value(atom) for atom in atom_molecule.GetAtoms()] + + if "fg_nodes" in augmented_mol[self.MAIN_KEY]: + fg_nodes = augmented_mol[self.MAIN_KEY]["fg_nodes"] + if not isinstance(fg_nodes, dict): + raise TypeError( + f'augmented_mol["{self.MAIN_KEY}"](["fg_nodes"]) must be an instance of dict ' + f"containing its properties" + ) + prop_list.extend([self.get_atom_value(atom) for atom in fg_nodes.values()]) + + if "graph_node" in augmented_mol[self.MAIN_KEY]: + graph_node = augmented_mol[self.MAIN_KEY]["graph_node"] + if not isinstance(graph_node, dict): + raise TypeError( + f'augmented_mol["{self.MAIN_KEY}"](["graph_node"]) must be an instance of dict ' + f"containing its properties" + ) + prop_list.append(self.get_atom_value(graph_node)) + + assert ( + len(prop_list) == augmented_mol[self.MAIN_KEY]["num_nodes"] + ), "Number of property values should be equal to number of nodes" + return prop_list + + def _check_modify_atom_prop_value( + self, atom: Chem.rdchem.Atom | dict, prop: str + ) -> str | int | bool: + """ + Check that the property value for the atom/node exists and is not empty. + + Args: + atom (Chem.rdchem.Atom | dict): Atom or node representation. + prop (str): Property name. + + Raises: + ValueError: If the property is empty. + + Returns: + str | int | bool: The property value. + """ + value = self._get_atom_prop_value(atom, prop) + if not value: + # Every atom/node should have given value + raise ValueError(f"'{prop}' is set but empty.") + return value + + def _get_atom_prop_value( + self, atom: Chem.rdchem.Atom | dict, prop: str + ) -> str | int | bool: + """ + Retrieve a property value from an atom or dict node. + + Args: + atom (Chem.rdchem.Atom | dict): Atom or node. + prop (str): Property name. + + Raises: + TypeError: If atom is not an expected type. + + Returns: + str | int | bool: The property value. + """ + if isinstance(atom, Chem.rdchem.Atom): + return atom.GetProp(prop) + elif isinstance(atom, dict): + return atom[prop] + else: + raise TypeError( + f"Atom/Node in key `{self.MAIN_KEY}` should be of type `Chem.rdchem.Atom` or `dict`." + ) + + +class AtomNodeTypeProperty(AugmentedAtomProperty, ABC): ... + + +class FGNodeTypeProperty(AugmentedAtomProperty, ABC): ... + + +class AllNodeTypeProperty(AugmentedAtomProperty, ABC): ... + + +class AugmentedBondProperty(BondProperty, ABC): + MAIN_KEY = "edges" + + def get_property_value(self, augmented_mol: dict) -> list: + """ + Get bond property values from augmented molecule dict. + + Args: + augmented_mol (dict): Augmented molecule dictionary containing edges. + + Returns: + list: List of property values for bonds in the augmented molecule. + + Raises: + KeyError: If required keys are missing in augmented_mol. + TypeError: If the expected objects are not of correct types. + AssertionError: If number of property values does not match expected edge count. + """ + if self.MAIN_KEY not in augmented_mol: + raise KeyError( + f"Key `{self.MAIN_KEY}` should be present in augmented molecule dict" + ) + + missing_keys = {k.WITHIN_ATOMS_EDGE} - augmented_mol[self.MAIN_KEY].keys() + if missing_keys: + raise KeyError(f"Missing keys {missing_keys} in augmented molecule nodes") + + atom_molecule: Chem.Mol = augmented_mol[self.MAIN_KEY][k.WITHIN_ATOMS_EDGE] + if not isinstance(atom_molecule, Chem.Mol): + raise TypeError( + f'augmented_mol["{self.MAIN_KEY}"]["{k.WITHIN_ATOMS_EDGE}"] must be an instance of rdkit.Chem.Mol' + ) + prop_list = [self.get_bond_value(bond) for bond in atom_molecule.GetBonds()] + + if k.ATOM_FG_EDGE in augmented_mol[self.MAIN_KEY]: + fg_atom_edges = augmented_mol[self.MAIN_KEY][k.ATOM_FG_EDGE] + if not isinstance(fg_atom_edges, dict): + raise TypeError( + f"augmented_mol['{self.MAIN_KEY}'](['{k.ATOM_FG_EDGE}'])" + f"must be an instance of dict containing its properties" + ) + prop_list.extend( + [self.get_bond_value(bond) for bond in fg_atom_edges.values()] + ) + + if k.WITHIN_FG_EDGE in augmented_mol[self.MAIN_KEY]: + fg_edges = augmented_mol[self.MAIN_KEY][k.WITHIN_FG_EDGE] + if not isinstance(fg_edges, dict): + raise TypeError( + f"augmented_mol['{self.MAIN_KEY}'](['{k.WITHIN_FG_EDGE}'])" + f"must be an instance of dict containing its properties" + ) + prop_list.extend([self.get_bond_value(bond) for bond in fg_edges.values()]) + + if k.TO_GRAPHNODE_EDGE in augmented_mol[self.MAIN_KEY]: + fg_graph_node_edges = augmented_mol[self.MAIN_KEY][k.TO_GRAPHNODE_EDGE] + if not isinstance(fg_graph_node_edges, dict): + raise TypeError( + f"augmented_mol['{self.MAIN_KEY}'](['{k.TO_GRAPHNODE_EDGE}'])" + f"must be an instance of dict containing its properties" + ) + prop_list.extend( + [self.get_bond_value(bond) for bond in fg_graph_node_edges.values()] + ) + + num_directed_edges = augmented_mol[self.MAIN_KEY][k.NUM_EDGES] // 2 + assert ( + len(prop_list) == num_directed_edges + ), f"Number of property values ({len(prop_list)}) should be equal to number of half the number of undirected edges i.e. must be equal to {num_directed_edges} " + + return prop_list + + def _check_modify_bond_prop_value( + self, bond: Chem.rdchem.Bond | dict, prop: str + ) -> str: + """ + Helper to check and get bond property value. + + Args: + bond (Chem.rdchem.Bond | dict): Bond object or bond property dict. + prop (str): Property key to get. + + Returns: + str: Property value. + + Raises: + ValueError: If value is empty or falsy. + """ + value = self._get_bond_prop_value(bond, prop) + if not value: + # Every atom/node should have given value + raise ValueError(f"'{prop}' is set but empty.") + return value + + @staticmethod + def _get_bond_prop_value(bond: Chem.rdchem.Bond | dict, prop: str) -> str: + """ + Extract bond property value from bond or dict. + + Args: + bond (Chem.rdchem.Bond | dict): Bond object or dict. + prop (str): Property key. + + Returns: + str: Property value. + + Raises: + TypeError: If bond is not the expected type. + """ + if isinstance(bond, Chem.rdchem.Bond): + return bond.GetProp(prop) + elif isinstance(bond, dict): + return bond[prop] + else: + raise TypeError("Bond/Edge should be of type `Chem.rdchem.Bond` or `dict`.") + + +class AugmentedMoleculeProperty(MoleculeProperty, ABC): + def get_property_value(self, augmented_mol: dict) -> list: + """ + Get molecular property values from augmented molecule dict. + Args: + augmented_mol (dict): Augmented molecule dict. + Returns: + list: Property values of molecule. + """ + mol: Chem.Mol = augmented_mol[AugmentedAtomProperty.MAIN_KEY]["atom_nodes"] + assert isinstance(mol, Chem.Mol), "Molecule should be instance of `Chem.Mol`" + return super().get_property_value(mol) diff --git a/chebai_graph/preprocessing/properties/constants.py b/chebai_graph/preprocessing/properties/constants.py new file mode 100644 index 0000000..e4cd2b9 --- /dev/null +++ b/chebai_graph/preprocessing/properties/constants.py @@ -0,0 +1,13 @@ +ATOM_NODE_LEVEL = "atom_node_lvl" +FG_NODE_LEVEL = "fg_node_lvl" +GRAPH_NODE_LEVEL = "graph_node_level" +NODE_LEVEL = "node_level" +NODE_LEVELS = {ATOM_NODE_LEVEL, FG_NODE_LEVEL, GRAPH_NODE_LEVEL} + +EDGE_LEVEL = "edge_level" +WITHIN_ATOMS_EDGE = "within_atoms_lvl" +WITHIN_FG_EDGE = "within_fg_lvl" +ATOM_FG_EDGE = "atom_fg_lvl" +TO_GRAPHNODE_EDGE = "to_graphNode_lvl" +EDGE_LEVELS = {WITHIN_ATOMS_EDGE, WITHIN_FG_EDGE, ATOM_FG_EDGE, TO_GRAPHNODE_EDGE} +NUM_EDGES = "num_undirected_edges" diff --git a/chebai_graph/preprocessing/properties/properties.py b/chebai_graph/preprocessing/properties/properties.py new file mode 100644 index 0000000..b76f244 --- /dev/null +++ b/chebai_graph/preprocessing/properties/properties.py @@ -0,0 +1,299 @@ +import numpy as np +import rdkit.Chem as Chem +from descriptastorus.descriptors import rdNormalizedDescriptors + +from chebai_graph.preprocessing.property_encoder import ( + AsIsEncoder, + BoolEncoder, + OneHotEncoder, + PropertyEncoder, +) + +from .base import AtomProperty, BondProperty, MoleculeProperty + + +class AtomType(AtomProperty): + """ + Atom property representing the atomic number (type) of an atom. + + Uses a one-hot encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: + super().__init__(encoder or OneHotEncoder(self)) + + def get_atom_value(self, atom: Chem.rdchem.Atom) -> int: + """ + Get the atomic number of the atom. + + Args: + atom (Chem.rdchem.Atom): RDKit atom object. + + Returns: + int: Atomic number of the atom. + """ + return atom.GetAtomicNum() + + +class NumAtomBonds(AtomProperty): + """ + Atom property representing the number of bonds (degree) of an atom. + + Uses a one-hot encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: + super().__init__(encoder or OneHotEncoder(self)) + + def get_atom_value(self, atom: Chem.rdchem.Atom) -> int: + """ + Get the number of bonds for the atom. + + Args: + atom (Chem.rdchem.Atom): RDKit atom object. + + Returns: + int: Number of bonds (degree). + """ + return atom.GetDegree() + + +class AtomCharge(AtomProperty): + """ + Atom property representing the formal charge of an atom. + + Uses a one-hot encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: + super().__init__(encoder or OneHotEncoder(self)) + + def get_atom_value(self, atom: Chem.rdchem.Atom) -> int: + """ + Get the formal charge of the atom. + + Args: + atom (Chem.rdchem.Atom): RDKit atom object. + + Returns: + int: Formal charge. + """ + return atom.GetFormalCharge() + + +class AtomChirality(AtomProperty): + """ + Atom property representing the chirality tag of an atom. + + Uses a one-hot encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: + super().__init__(encoder or OneHotEncoder(self)) + + def get_atom_value(self, atom: Chem.rdchem.Atom) -> Chem.rdchem.ChiralType: + """ + Get the chirality tag of the atom. + + Args: + atom (Chem.rdchem.Atom): RDKit atom object. + + Returns: + Chem.rdchem.ChiralType: Chirality tag. + """ + return atom.GetChiralTag() + + +class AtomHybridization(AtomProperty): + """ + Atom property representing the hybridization state of an atom. + + Uses a one-hot encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: + super().__init__(encoder or OneHotEncoder(self)) + + def get_atom_value(self, atom: Chem.rdchem.Atom) -> Chem.rdchem.HybridizationType: + """ + Get the hybridization state of the atom. + + Args: + atom (Chem.rdchem.Atom): RDKit atom object. + + Returns: + Chem.rdchem.HybridizationType: Hybridization state. + """ + return atom.GetHybridization() + + +class AtomNumHs(AtomProperty): + """ + Atom property representing the total number of hydrogens bonded to an atom. + + Uses a one-hot encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: + super().__init__(encoder or OneHotEncoder(self)) + + def get_atom_value(self, atom: Chem.rdchem.Atom) -> int: + """ + Get the total number of hydrogens attached to the atom. + + Args: + atom (Chem.rdchem.Atom): RDKit atom object. + + Returns: + int: Number of attached hydrogens. + """ + return atom.GetTotalNumHs() + + +class AtomAromaticity(AtomProperty): + """ + Atom property representing whether an atom is aromatic. + + Uses a boolean encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: + super().__init__(encoder or BoolEncoder(self)) + + def get_atom_value(self, atom: Chem.rdchem.Atom) -> bool: + """ + Check if the atom is aromatic. + + Args: + atom (Chem.rdchem.Atom): RDKit atom object. + + Returns: + bool: True if aromatic, else False. + """ + return atom.GetIsAromatic() + + +class BondAromaticity(BondProperty): + """ + Bond property representing whether a bond is aromatic. + + Uses a boolean encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: + super().__init__(encoder or BoolEncoder(self)) + + def get_bond_value(self, bond: Chem.rdchem.Bond) -> bool: + """ + Check if the bond is aromatic. + + Args: + bond (Chem.rdchem.Bond): RDKit bond object. + + Returns: + bool: True if aromatic, else False. + """ + return bond.GetIsAromatic() + + +class BondType(BondProperty): + """ + Bond property representing the bond type (single, double, etc.). + + Uses a one-hot encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: + super().__init__(encoder or OneHotEncoder(self)) + + def get_bond_value(self, bond: Chem.rdchem.Bond) -> Chem.rdchem.BondType: + """ + Get the bond type. + + Args: + bond (Chem.rdchem.Bond): RDKit bond object. + + Returns: + Chem.rdchem.BondType: Type of bond. + """ + return bond.GetBondType() + + +class BondInRing(BondProperty): + """ + Bond property indicating whether a bond is in a ring. + + Uses a boolean encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: + super().__init__(encoder or BoolEncoder(self)) + + def get_bond_value(self, bond: Chem.rdchem.Bond) -> bool: + """ + Check if the bond is part of a ring. + + Args: + bond (Chem.rdchem.Bond): RDKit bond object. + + Returns: + bool: True if in a ring, else False. + """ + return bond.IsInRing() + + +class MoleculeNumRings(MoleculeProperty): + """ + Molecule-level property representing the number of rings in the molecule. + + Uses a one-hot encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: + super().__init__(encoder or OneHotEncoder(self)) + + def get_property_value(self, mol: Chem.rdchem.Mol) -> list[int]: + """ + Get the number of rings in the molecule. + + Args: + mol (Chem.rdchem.Mol): RDKit molecule object. + + Returns: + list[int]: List with single integer representing number of rings. + """ + return [mol.GetRingInfo().NumRings()] + + +class RDKit2DNormalized(MoleculeProperty): + """ + Molecule-level property representing normalized 2D descriptors from RDKit. + + Uses an identity encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: + super().__init__(encoder or AsIsEncoder(self)) + self.generator_normalized = rdNormalizedDescriptors.RDKit2DNormalized() + # Create a dummy molecule (e.g., methane) to extract the length of descriptor vector + dummy_mol = Chem.MolFromSmiles("C") + descr_values = self.generator_normalized.processMol( + dummy_mol, Chem.MolToSmiles(dummy_mol) + ) + self.encoder.set_encoding_length(len(descr_values) - 1) + + def get_property_value(self, mol: Chem.rdchem.Mol) -> list[np.ndarray]: + """ + Compute normalized RDKit 2D descriptors for the molecule. + + Args: + mol (Chem.rdchem.Mol): RDKit molecule object. + + Returns: + list[np.ndarray]: List containing the descriptor numpy array (excluding first element). + """ + features_normalized = self.generator_normalized.processMol( + mol, Chem.MolToSmiles(mol) + ) + features_normalized = np.nan_to_num(features_normalized) + return [features_normalized[1:]] diff --git a/chebai_graph/preprocessing/property_encoder.py b/chebai_graph/preprocessing/property_encoder.py index 532c91f..1487163 100644 --- a/chebai_graph/preprocessing/property_encoder.py +++ b/chebai_graph/preprocessing/property_encoder.py @@ -3,41 +3,69 @@ import os import sys from itertools import islice -from typing import Optional import torch class PropertyEncoder(abc.ABC): - def __init__(self, property, **kwargs): + """ + Abstract base class for encoding property values. + + Args: + property: The property object associated with this encoder. + **kwargs: Additional keyword arguments. + """ + + def __init__(self, property, **kwargs) -> None: self.property = property - self._encoding_length = 1 + self._encoding_length: int = 1 @property - def name(self): + def name(self) -> str: + """Name of the encoder.""" return "" def get_encoding_length(self) -> int: + """Return the length of the encoding vector.""" return self._encoding_length def set_encoding_length(self, encoding_length: int) -> None: + """Set the length of the encoding vector.""" self._encoding_length = encoding_length - def encode(self, value): + def encode(self, value) -> torch.Tensor: + """ + Encode the given value. + + Args: + value: The value to encode. + + Returns: + Encoded tensor. + """ return value - def on_start(self, **kwargs): + def on_start(self, **kwargs) -> None: + """Hook called at the start of encoding process.""" pass - def on_finish(self): + def on_finish(self) -> None: + """Hook called at the end of encoding process.""" return class IndexEncoder(PropertyEncoder): - """Encodes property values as indices. For that purpose, compiles a dynamic list of different values that have - occurred. Stores this list in a file for later reference.""" + """ + Encodes property values as indices. For that purpose, compiles a dynamic list of different values that have + occurred. Stores this list in a file for later reference. + + Args: + property: The property object. + indices_dir: Optional directory to store index files. + **kwargs: Additional keyword arguments. + """ - def __init__(self, property, indices_dir=None, **kwargs): + def __init__(self, property, indices_dir: str | None = None, **kwargs) -> None: super().__init__(property, **kwargs) if indices_dir is None: indices_dir = os.path.dirname(inspect.getfile(self.__class__)) @@ -48,15 +76,22 @@ def __init__(self, property, indices_dir=None, **kwargs): token.strip(): idx for idx, token in enumerate(pk) } self.index_length_start = len(self.cache) - self.offset = 0 + self._unk_token_idx = 0 + self._count_for_unk_token = 0 + self.offset = 1 @property - def name(self): + def name(self) -> str: + """Name of this encoder.""" return "index" @property - def index_path(self): - """Get path to store indices of property values, create file if it does not exist yet""" + def index_path(self) -> str: + """Get path to store indices of property values, create file if it does not exist yet + + Returns: + Path to index file. + """ index_path = os.path.join( self.dirname, "bin", self.property.name, f"indices_{self.name}.txt" ) @@ -68,8 +103,12 @@ def index_path(self): pass return index_path - def on_finish(self): - """Save cache""" + def on_finish(self) -> None: + """ + Save cache + + Saves new tokens added to the cache to the index file and logs count of unknown tokens. + """ total_tokens = len(self.cache) if total_tokens > self.index_length_start: print("New tokens added to the cache, Saving them to index token file.....") @@ -92,29 +131,63 @@ def on_finish(self): f"Now, the total length of the index of property {self.property.name} is {total_tokens}" ) - def encode(self, token): - """Returns a unique number for each token, automatically adds new tokens to the cache.""" + if self._count_for_unk_token > 0: + print( + f"{self.__class__.__name__} Encountered {self._count_for_unk_token} unknown tokens" + ) + + def encode(self, token: str | None) -> torch.Tensor: + """ + Returns a unique number for each token, automatically adds new tokens to the cache. + + Args: + token: The token to encode. + + Returns: + A tensor containing the encoded index. + """ + if token is None: + self._count_for_unk_token += 1 + return torch.tensor([self._unk_token_idx]) + if str(token) not in self.cache: - self.cache[(str(token))] = len(self.cache) + self.cache[str(token)] = len(self.cache) return torch.tensor([self.cache[str(token)] + self.offset]) class OneHotEncoder(IndexEncoder): - """Returns one-hot encoding of the value (position in one-hot vector is defined by index).""" + """ + Returns one-hot encoding of the value (position in one-hot vector is defined by index). + + Args: + property: The property object. + n_labels: Optional number of labels for encoding. + **kwargs: Additional keyword arguments. + """ - def __init__(self, property, n_labels: Optional[int] = None, **kwargs): + def __init__(self, property, n_labels: int | None = None, **kwargs) -> None: super().__init__(property, **kwargs) self._encoding_length = n_labels + # To undo any offset set by index encoder as its not relevant for one-hot-encoder (no offset needed for some unknown/reserved token) + # Also, `torch.nn.functional.one_hot` that class values must be smaller than num_classes. + self.offset = 0 def get_encoding_length(self) -> int: + """Return the number of classes for one-hot encoding.""" return self._encoding_length or len(self.cache) @property - def name(self): + def name(self) -> str: + """Name of this encoder.""" return "one_hot" - def on_start(self, property_values): - """To get correct number of classes during encoding, cache unique tokens beforehand""" + def on_start(self, property_values: list[list[str | None]]) -> None: + """ + To get correct number of classes during encoding, cache unique tokens beforehand + + Args: + property_values: List of property value sequences. + """ unique_tokens = list( dict.fromkeys( [ @@ -126,31 +199,80 @@ def on_start(self, property_values): ] ) ) - self.tokens_dict = {} + self.tokens_dict: dict[str, torch.Tensor] = {} for token in unique_tokens: self.tokens_dict[token] = super().encode(token) - def encode(self, token): + def encode(self, token: str | None) -> torch.Tensor: + """ + Returns one-hot encoded tensor for the token. + + Args: + token: The token to encode. + + Returns: + One-hot encoded tensor of shape (1, encoding_length). + """ + if token not in self.tokens_dict: + self._count_for_unk_token += 1 + return torch.zeros(1, self.get_encoding_length(), dtype=torch.int64) + return torch.nn.functional.one_hot( self.tokens_dict[token], num_classes=self.get_encoding_length() ) class AsIsEncoder(PropertyEncoder): - """Returns the input value as it is, useful e.g. for float values.""" + """ + Returns the input value as it is, useful e.g. for float values. + """ @property - def name(self): + def name(self) -> str: + """Name of this encoder.""" return "asis" - def encode(self, token): - return torch.tensor([token]) + def encode(self, token: float | int | None) -> torch.Tensor: + """ + Return the input value as tensor, or zero tensor if None. + + Args: + token: The value to encode. + + Returns: + Tensor of shape (1,) containing the input value or zero. + """ + if token is None: + return torch.zeros(1, self.get_encoding_length()) + assert ( + len(token) == self.get_encoding_length() + ), "Length of token should be equal to encoding length" + # return torch.tensor([token]) # token is an ndarray, no need to create list of ndarray due to below warning + # UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. + # Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. + # (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\pytorch\torch\csrc\utils\tensor_new.cpp:257.) + # ----- fix: for above warning + return torch.tensor(token).unsqueeze(0) # shape: (1, len(token)) class BoolEncoder(PropertyEncoder): + """ + Encodes boolean values as 0 or 1. + """ + @property - def name(self): + def name(self) -> str: + """Name of this encoder.""" return "bool" - def encode(self, token: bool): + def encode(self, token: bool) -> torch.Tensor: + """ + Encode boolean token as tensor. + + Args: + token: Boolean value. + + Returns: + Tensor with 1 if True else 0. + """ return torch.tensor([1 if token else 0]) diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py deleted file mode 100644 index e71fcfe..0000000 --- a/chebai_graph/preprocessing/reader.py +++ /dev/null @@ -1,134 +0,0 @@ -import os -from typing import List, Optional - -import chebai.preprocessing.reader as dr -import pysmiles as ps -import rdkit.Chem as Chem -import torch -from lightning_utilities.core.rank_zero import rank_zero_info, rank_zero_warn -from torch_geometric.data import Data as GeomData -from torch_geometric.utils import from_networkx - -import chebai_graph.preprocessing.properties as properties -from chebai_graph.preprocessing.collate import GraphCollator - - -class GraphPropertyReader(dr.DataReader): - COLLATOR = GraphCollator - - def __init__( - self, - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.failed_counter = 0 - self.mol_object_buffer = {} - - @classmethod - def name(cls): - return "graph_properties" - - def _smiles_to_mol(self, smiles: str) -> Optional[Chem.rdchem.Mol]: - """Load smiles into rdkit, store object in buffer""" - if smiles in self.mol_object_buffer: - return self.mol_object_buffer[smiles] - - mol = Chem.MolFromSmiles(smiles) - if mol is None: - rank_zero_warn(f"RDKit failed to at parsing {smiles} (returned None)") - self.failed_counter += 1 - else: - try: - Chem.SanitizeMol(mol) - except Exception: - rank_zero_warn(f"Rdkit failed at sanitizing {smiles}") - self.failed_counter += 1 - self.mol_object_buffer[smiles] = mol - return mol - - def _read_data(self, raw_data): - mol = self._smiles_to_mol(raw_data) - if mol is None: - return None - - x = torch.zeros((mol.GetNumAtoms(), 0)) - - edge_attr = torch.zeros((mol.GetNumBonds(), 0)) - - edge_index = torch.tensor( - [ - [bond.GetBeginAtomIdx() for bond in mol.GetBonds()], - [bond.GetEndAtomIdx() for bond in mol.GetBonds()], - ] - ) - return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr) - - def on_finish(self): - rank_zero_info(f"Failed to read {self.failed_counter} SMILES in total") - self.mol_object_buffer = {} - - def read_property( - self, smiles: str, property: properties.MolecularProperty - ) -> Optional[List]: - mol = self._smiles_to_mol(smiles) - if mol is None: - return None - return property.get_property_value(mol) - - -class GraphReader(dr.ChemDataReader): - """Reads each atom as one token (atom symbol + charge), reads bond order as edge attribute. - Creates nx Graph from SMILES.""" - - COLLATOR = GraphCollator - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.dirname = os.path.dirname(__file__) - - @classmethod - def name(cls): - return "graph" - - def _read_data(self, raw_data) -> Optional[GeomData]: - import networkx as nx - - # raw_data is a SMILES string - try: - mol = ps.read_smiles(raw_data) - except ValueError: - return None - assert isinstance(mol, nx.Graph) - d = {} - de = {} - for node in mol.nodes: - n = mol.nodes[node] - try: - m = n["element"] - charge = n["charge"] - if charge: - if charge > 0: - m += "+" - else: - m += "-" - charge *= -1 - if charge > 1: - m += str(charge) - m = f"[{m}]" - except KeyError: - m = "*" - d[node] = self._get_token_index(m) - for attr in list(mol.nodes[node].keys()): - del mol.nodes[node][attr] - for edge in mol.edges: - de[edge] = mol.edges[edge]["order"] - for attr in list(mol.edges[edge].keys()): - del mol.edges[edge][attr] - nx.set_node_attributes(mol, d, "x") - nx.set_edge_attributes(mol, de, "edge_attr") - data = from_networkx(mol) - return data - - def collate(self, list_of_tuples): - return self.collator(list_of_tuples) diff --git a/chebai_graph/preprocessing/reader/__init__.py b/chebai_graph/preprocessing/reader/__init__.py new file mode 100644 index 0000000..3bc71ed --- /dev/null +++ b/chebai_graph/preprocessing/reader/__init__.py @@ -0,0 +1,28 @@ +from .augmented_reader import ( + AtomFGReader_NoFGEdges_WithGraphNode, + AtomFGReader_WithFGEdges_NoGraphNode, + AtomFGReader_WithFGEdges_WithGraphNode, + AtomReader_WithGraphNodeOnly, + AtomsFGReader_NoFGEdges_NoGraphNode, + GN_WithAllNodes_FG_WithAtoms_FGE, + GN_WithAllNodes_FG_WithAtoms_NoFGE, + GN_WithAtoms_FG_WithAtoms_FGE, + GN_WithAtoms_FG_WithAtoms_NoFGE, +) +from .reader import GraphPropertyReader, GraphReader +from .static_gni import RandomFeatureInitializationReader + +__all__ = [ + "GraphReader", + "GraphPropertyReader", + "AtomReader_WithGraphNodeOnly", + "AtomsFGReader_NoFGEdges_NoGraphNode", + "AtomFGReader_NoFGEdges_WithGraphNode", + "AtomFGReader_WithFGEdges_NoGraphNode", + "AtomFGReader_WithFGEdges_WithGraphNode", + "RandomFeatureInitializationReader", + "GN_WithAtoms_FG_WithAtoms_FGE", + "GN_WithAtoms_FG_WithAtoms_NoFGE", + "GN_WithAllNodes_FG_WithAtoms_FGE", + "GN_WithAllNodes_FG_WithAtoms_NoFGE", +] diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py new file mode 100644 index 0000000..b6b2d90 --- /dev/null +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -0,0 +1,943 @@ +import re +import sys +from abc import ABC + +import torch +from chebai.preprocessing.reader import DataReader +from rdkit import Chem +from torch_geometric.data import Data as GeomData + +from chebai_graph.preprocessing.collate import GraphCollator +from chebai_graph.preprocessing.fg_detection.fg_aware_rule_based import get_structure +from chebai_graph.preprocessing.fg_detection.fg_constants import FLAG_NO_FG +from chebai_graph.preprocessing.properties import MolecularProperty +from chebai_graph.preprocessing.properties import constants as k + +assert sys.version_info >= ( + 3, + 7, +), "This code requires Python 3.7 or higher." +# For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order +# https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights +# https://mail.python.org/pipermail/python-dev/2017-December/151283.html +# Order preservation is necessary to to create `is_atom_node` mask + + +class _AugmentorReader(DataReader, ABC): + """ + Abstract base class for augmentor readers that extend ChemDataReader. + Handles reading molecular data and augmenting molecules with functional group + information. + """ + + COLLATOR = GraphCollator + + def __init__(self, *args, **kwargs) -> None: + """ + Initializes the augmentor reader and sets up the failure counter and molecule cache. + + Args: + *args: Additional arguments passed to the ChemDataReader. + **kwargs: Additional keyword arguments passed to the ChemDataReader. + """ + super().__init__(*args, **kwargs) + # Record number of failures when constructing molecule from smiles + self.f_cnt_for_smiles: int = 0 + # Record number of failure during augmented graph construction + self.f_cnt_for_aug_graph: int = 0 + self.mol_object_buffer: dict[str, dict] = {} + self._idx_of_node: int = 0 + self._idx_of_edge: int = 0 + + @classmethod + def name(cls) -> str: + """ + Returns the name of the augmentor. + + Returns: + str: Name of the augmentor. + """ + return f"{cls.__name__}".lower() + + def _read_data(self, smiles: str) -> GeomData | None: + """ + Reads and augments molecular data from a SMILES string. + + Args: + smiles (str): SMILES representation of the molecule. + + Returns: + GeomData | None: A PyTorch Geometric Data object with augmented nodes and edges, + or None if parsing or augmentation fails. + + Raises: + RuntimeError: If an unexpected error occurs during graph augmentation. + """ + mol = self._smiles_to_mol(smiles) + if mol is None: + return None + + try: + returned_result = self._create_augmented_graph(mol) + except Exception as e: + raise RuntimeError( + f"Error has occurred for following SMILES: {smiles}\n\t {e}" + ) from e + + # If the returned result is None, it indicates that the graph augmentation failed + if returned_result is None: + print(f"Failed to construct augmented graph for smiles {smiles}") + self.f_cnt_for_aug_graph += 1 + return None + + edge_index, augmented_molecule = returned_result + self.mol_object_buffer[smiles] = augmented_molecule + + # Empty features initialized; node and edge features can be added later + NUM_NODES = augmented_molecule["nodes"]["num_nodes"] + assert ( + NUM_NODES is not None and NUM_NODES > 1 + ), "Num of nodes in augmented graph should be more than 1" + + x = torch.zeros((NUM_NODES, 0)) + edge_attr = torch.zeros((augmented_molecule["edges"][k.NUM_EDGES], 0)) + + assert ( + edge_index.shape[0] == 2 + ), f"Expected edge_index to have shape [2, num_edges], but got shape {edge_index.shape}" + + assert ( + edge_index.shape[1] == edge_attr.shape[0] + ), f"Mismatch between number of edges in edge_index ({edge_index.shape[1]}) and edge_attr ({edge_attr.shape[0]})" + + assert ( + len(set(edge_index[0].tolist())) == x.shape[0] + ), f"Number of unique source nodes in edge_index ({len(set(edge_index[0].tolist()))}) does not match number of nodes in x ({x.shape[0]})" + + # Create a boolean mask: True for atom, False for augmented + is_atom_mask = torch.zeros(NUM_NODES, dtype=torch.bool) + NUM_ATOM_NODES = augmented_molecule["nodes"]["atom_nodes"].GetNumAtoms() + is_atom_mask[:NUM_ATOM_NODES] = True + + return GeomData( + x=x, + edge_index=edge_index, + edge_attr=edge_attr, + is_atom_node=is_atom_mask, + ) + + def _smiles_to_mol(self, smiles: str) -> Chem.Mol | None: + """ + Converts a SMILES string to an RDKit molecule object. Sanitizes the molecule. + + Args: + smiles (str): SMILES string representing the molecule. + + Returns: + Chem.Mol | None: RDKit molecule object if successful, else None. + """ + mol = Chem.MolFromSmiles(smiles) + if mol is None: + print(f"RDKit failed to parse {smiles} (returned None)") + self.f_cnt_for_smiles += 1 + else: + try: + Chem.SanitizeMol(mol) + except Exception as e: + print(f"RDKit failed at sanitizing {smiles}, Error {e}") + self.f_cnt_for_smiles += 1 + mol = None + return mol + + def _create_augmented_graph( + self, mol: Chem.Mol + ) -> tuple[torch.Tensor, dict] | None: + """ + Generates an augmented graph from a molecule. + + Args: + mol (Chem.Mol): A molecule generated by RDKit. + + Returns: + tuple[torch.Tensor, dict] | None: + - Augmented graph edge index tensor, + - Augmented graph data dictionary (nodes and edges), + or None if augmentation fails. + + Raises: + ValueError: If directed_edge_index shape is incorrect. + """ + augmented_mol = self._augment_graph_structure(mol) + + directed_edge_index = augmented_mol["directed_edge_index"] + if directed_edge_index is None or directed_edge_index.shape[0] != 2: + raise ValueError( + f"Expected directed_edge_index to have shape [2, num_edges], but got shape {directed_edge_index.shape}" + ) + + # First all directed edges from source to target are placed, then all directed edges from target to source + # are placed --- this is needed as it is easier to align the property values in same way + undirected_edge_index = torch.cat( + [ + directed_edge_index, + directed_edge_index[[1, 0], :], + ], + dim=1, + ) + + augmented_mol["edge_info"][k.NUM_EDGES] *= 2 # Undirected edges + augmented_molecule = { + "nodes": augmented_mol["node_info"], + "edges": augmented_mol["edge_info"], + } + + return undirected_edge_index, augmented_molecule + + def _augment_graph_structure(self, mol: Chem.Mol) -> dict: + """ + Constructs the full augmented graph structure from a molecule. + + Args: + mol (Chem.Mol): RDKit molecule object. + + Returns: + dict: A dictionary containing: + - "directed_edge_index" (torch.Tensor): Directed edge index tensor, + - "node_info" (dict): Node attributes dictionary, + - "edge_info" (dict): Edge attributes dictionary, + - "graph_meta_info" (dict): Additional meta information. + """ + self._idx_of_node = mol.GetNumAtoms() + self._idx_of_edge = mol.GetNumBonds() + + self._annotate_atoms_and_bonds(mol) + atom_edge_index = self._generate_atom_level_edge_index(mol) + + total_atoms = mol.GetNumAtoms() + assert ( + self._idx_of_node == total_atoms + ), f"Mismatch in number of nodes: expected {total_atoms}, got {self._idx_of_node}" + + node_info = { + "atom_nodes": mol, + "num_nodes": self._idx_of_node, + } + + total_edges = mol.GetNumBonds() + assert ( + self._idx_of_edge == total_edges + ), f"Mismatch in number of edges: expected {total_edges}, got {self._idx_of_edge}" + edge_info = { + k.WITHIN_ATOMS_EDGE: mol, + k.NUM_EDGES: self._idx_of_edge, + } + + return { + "directed_edge_index": atom_edge_index, + "node_info": node_info, + "edge_info": edge_info, + "graph_meta_info": {}, + } + + @staticmethod + def _annotate_atoms_and_bonds(mol: Chem.Mol) -> None: + """ + Annotates each atom and bond with node and edge with certain properties. + + Args: + mol (Chem.Mol): RDKit molecule. + """ + for atom in mol.GetAtoms(): + atom.SetProp(k.NODE_LEVEL, k.ATOM_NODE_LEVEL) + for bond in mol.GetBonds(): + bond.SetProp(k.EDGE_LEVEL, k.WITHIN_ATOMS_EDGE) + + @staticmethod + def _generate_atom_level_edge_index(mol: Chem.Mol) -> torch.Tensor: + """ + Generates bidirectional atom-level edge index tensor. + + Args: + mol (Chem.Mol): RDKit molecule. + + Returns: + torch.Tensor: Directed edge index tensor with shape [2, num_edges]. + """ + # We need to ensure that directed edges which form a undirected edge are adjacent to each other + edge_index_list: list[list[int]] = [[], []] + for bond in mol.GetBonds(): + edge_index_list[0].append(bond.GetBeginAtomIdx()) + edge_index_list[1].append(bond.GetEndAtomIdx()) + return torch.tensor(edge_index_list, dtype=torch.long) + + def on_finish(self) -> None: + """ + Finalizes the reading process and logs the number of failed SMILES and failed augmentation. + """ + print(f"Failed to read {self.f_cnt_for_smiles} SMILES in total") + print( + f"Failed to construct augmented graph for {self.f_cnt_for_aug_graph} number of SMILES" + ) + self.mol_object_buffer = {} + + def read_property(self, smiles: str, property: MolecularProperty) -> list | None: + """ + Reads a specific property from a molecule represented by a SMILES string. + + Args: + smiles (str): SMILES string representing the molecule. + property (MolecularProperty): Molecular property object for which the value needs to be extracted. + + Returns: + list | None: Property values if molecule parsing is successful, else None. + """ + if smiles in self.mol_object_buffer: + return property.get_property_value(self.mol_object_buffer[smiles]) + + mol = self._smiles_to_mol(smiles) + if mol is None: + return None + + returned_result = self._create_augmented_graph(mol) + if returned_result is None: + return None + + _, augmented_mol = returned_result + return property.get_property_value(augmented_mol) + + +class AtomsFGReader_NoFGEdges_NoGraphNode(_AugmentorReader): + """ + Adds FG nodes (connected to their respective atom nodes) without + intra-functional group edges, and without introducing a graph-level node. + """ + + def _augment_graph_structure( + self, mol: Chem.Mol + ) -> tuple[torch.Tensor, dict, dict]: + """ + Constructs the full augmented graph structure from a molecule by adding + fg nodes to their respective atom nodes. + + Args: + mol (Chem.Mol): RDKit molecule object. + + Returns: + Tuple[Tensor, dict, dict]: A tuple containing: + - Augmented graph edge index (Tensor), + - Augmented graph node attributes (dict), + - Augmented graph edge attributes (dict). + """ + augmented_mol = super()._augment_graph_structure(mol) + atom_edge_index = augmented_mol["directed_edge_index"] + + # Create FG-level structure and edges + fg_atom_edge_index, fg_nodes, atom_fg_edges, fg_to_atoms_map, fg_bonds = ( + self._construct_fg_to_atom_structure(mol) + ) + + # Merge all edge types + directed_edge_index = torch.cat( + [ + atom_edge_index, + torch.tensor(fg_atom_edge_index, dtype=torch.long), + ], + dim=1, + ) + augmented_mol["directed_edge_index"] = directed_edge_index + + total_atoms = sum([mol.GetNumAtoms(), len(fg_nodes)]) + assert ( + self._idx_of_node == total_atoms + ), f"Mismatch in number of nodes: expected {total_atoms}, got {self._idx_of_node}" + augmented_mol["node_info"]["fg_nodes"] = fg_nodes + augmented_mol["node_info"]["num_nodes"] = self._idx_of_node + + total_edges = sum([mol.GetNumBonds(), len(atom_fg_edges)]) + assert ( + self._idx_of_edge == total_edges + ), f"Mismatch in number of edges: expected {total_edges}, got {self._idx_of_edge}" + augmented_mol["edge_info"][k.ATOM_FG_EDGE] = atom_fg_edges + augmented_mol["edge_info"][k.NUM_EDGES] = self._idx_of_edge + + augmented_mol["graph_meta_info"]["fg_to_atoms_map"] = fg_to_atoms_map + augmented_mol["graph_meta_info"]["fg_bonds"] = fg_bonds + + return augmented_mol + + def _construct_fg_to_atom_structure( + self, mol: Chem.Mol + ) -> tuple[list[list[int]], dict, dict, dict, list]: + """ + Constructs edges between functional group (FG) nodes and atom nodes. + This method detects functional groups in the molecule and creates edges + between FG nodes and their connected atom nodes. + + Args: + mol (Chem.Mol): RDKit molecule. + + Returns: + Tuple[list[list[int]], dict, dict, dict, list]: A tuple containing: + - Edge index for FG to atom connections. + - FG node info. + - FG-atom edge attributes. + - FG to atoms mapping. + - Bonds between FG nodes. + + Raises: + ValueError: If functional groups span multiple ring sizes or if no functional group is assigned to atoms. + """ + # Rule-based algorithm to detect functional groups + structure, bonds = get_structure(mol) + assert structure is not None, "Failed to detect functional groups." + + fg_atom_edge_index = [[], []] + fg_nodes, atom_fg_edges = {}, {} + # Contains augmented fg-nodes and connected atoms indices + fg_to_atoms_map = {} + + molecule_atoms_set = set() + for fg_smiles, fg_group in structure.items(): + fg_to_atoms_map[self._idx_of_node] = fg_group + is_ring_fg = fg_group["is_ring_fg"] + + connected_atoms = [] + # Build edge index for fg to atom nodes connections + for atom_idx in fg_group["atom"]: + # Fused rings can have an atom which belong to more than one ring + if atom_idx in molecule_atoms_set and not is_ring_fg: + raise ValueError( + f"An atom {atom_idx} cannot belong to more than one functional group" + ) + molecule_atoms_set.add(atom_idx) + + fg_atom_edge_index[0].append(self._idx_of_node) + fg_atom_edge_index[1].append(atom_idx) + atom_fg_edges[f"{self._idx_of_node}_{atom_idx}"] = { + k.EDGE_LEVEL: k.ATOM_FG_EDGE + } + self._idx_of_edge += 1 + + atom = mol.GetAtomWithIdx(atom_idx) + connected_atoms.append(atom) + + if is_ring_fg: + self._set_ring_fg_prop(connected_atoms, fg_nodes) + else: + self._set_fg_prop(connected_atoms, fg_nodes, fg_smiles) + + self._idx_of_node += 1 + + return fg_atom_edge_index, fg_nodes, atom_fg_edges, fg_to_atoms_map, bonds + + def _set_ring_fg_prop(self, connected_atoms: list, fg_nodes: dict) -> None: + """ + Sets ring functional group properties. + + Args: + connected_atoms (list): List of atoms in the ring. + fg_nodes (dict): Dictionary to store FG node attributes. + + Raises: + ValueError: If an atom in the ring does not have a ring size set. + """ + # FG atoms have ring size, which indicates the FG is a Ring or Fused Rings + ring_size = len(connected_atoms) + fg_nodes[self._idx_of_node] = { + k.NODE_LEVEL: k.FG_NODE_LEVEL, + "FG": f"RING_{ring_size}", + "RING": ring_size, + "is_alkyl": "0", + } + # In this case, all atoms of Ring/Fused Ring are assigned the ring size as functional group + for atom in connected_atoms: + ring_prop = atom.GetProp("RING") + if not ring_prop: + raise ValueError("Atom does not have a ring size set") + # TODO: discuss the case, should it be max or average + # An atom belonging to multiple rings in fused Ring has size "5-6", indicating size of each ring + max_ring_size = max(list(map(int, ring_prop.split("-")))) + atom.SetProp("FG", f"RING_{max_ring_size}") + atom.SetProp("is_alkyl", "0") + + def _set_fg_prop( + self, connected_atoms: list, fg_nodes: dict, fg_smiles: str + ) -> None: + """ + Sets non-ring functional group properties. + + Args: + connected_atoms (list): Atoms in the FG. + fg_nodes (dict): Dictionary to store FG node attributes. + fg_smiles (str): SMILES of the FG. + + Raises: + ValueError: If functional group assignment is inconsistent or missing. + AssertionError: If no representative atom is found. + """ + NO_FG = "NO_FG" + representative_atom = None + + # Check if the functional group SMILES corresponds to an alkyl group + # by removing common alkyl characters and checking if anything remains. + check = re.sub(r"[CH\-\(\)\[\]/\\@]", "", fg_smiles) + is_alkyl = "1" if len(check) == 0 else "0" + + fg_set = set() + for atom in connected_atoms: + atom.SetProp("is_alkyl", is_alkyl) + + # Set FG to NO_FG if this atom's fg is marked to be ignored + if atom.HasProp(FLAG_NO_FG): + atom.SetProp("FG", NO_FG) + + fg = atom.GetProp("FG") + fg_set.add(fg) + + # Store the last seen valid FG atom as representative + if fg and fg != NO_FG: + representative_atom = atom + + # Raise error if no FG at all was found (likely unexpected state) + if not fg_set: + raise ValueError( + "No functional group assigned to atoms in the functional group." + ) + + # Determine how many valid functional groups are present + valid_fgs = fg_set - {"", NO_FG} + num_of_valid_fgs = len(valid_fgs) + + if num_of_valid_fgs == 0: + # fg_set = {"", NO_FG} or {""} or {NO_FG} + for atom in connected_atoms: + atom.SetProp("FG", NO_FG) + node_fg = NO_FG + + elif num_of_valid_fgs > 1: + # fg_set = {"FG1", "FG2", ...} or {"FG1", "FG2", ..., NO_FG} or + # {"FG1", "FG2", ..., ""} or {"FG1", FG2, ..., "", NO_FG} + # Inconsistent FG assignments; possibly a bug in FG detection + raise ValueError( + "Connected atoms have different functional groups assigned.\n" + "All connected atoms must belong to one functional group or None." + ) + + elif num_of_valid_fgs == 1: + # fg_set = {"FG1"} or {"FG1", ""} or {"FG1", NO_FG} or {"FG1", "", NO_FG} + # Exactly one valid FG; ensure we have an atom to extract it from + if representative_atom is None: + raise AssertionError( + "Expected at least one atom with a valid functional group." + ) + node_fg = representative_atom.GetProp("FG") + # If any atom had FG as an empty string (""), backfill it with node_fg + for atom in connected_atoms: + atom.SetProp("FG", node_fg) + + else: + # This branch is unreachable but kept for safety + raise AssertionError("Unexpected state in functional group detection.") + + # Assign the final FG node metadata + fg_nodes[self._idx_of_node] = { + k.NODE_LEVEL: k.FG_NODE_LEVEL, + "FG": node_fg, + "RING": 0, + "is_alkyl": is_alkyl, + } + + +class AtomFGReader_WithFGEdges_NoGraphNode(AtomsFGReader_NoFGEdges_NoGraphNode): + """ + Adds FG nodes (connected to their respective atom nodes) with intra-functional group + edges, and without introducing a graph-level node. + """ + + def _augment_graph_structure( + self, mol: Chem.Mol + ) -> tuple[torch.Tensor, dict, dict]: + """ + Augments the molecule graph with intra-functional group edges. + + Args: + mol (Chem.Mol): RDKit molecule object. + + Returns: + tuple[torch.Tensor, dict, dict]: Updated graph with FG-level edges. + """ + augmented_struct = super()._augment_graph_structure(mol) + graph_meta_info = augmented_struct["graph_meta_info"] + + fg_to_atoms_map = graph_meta_info["fg_to_atoms_map"] + fg_bonds = graph_meta_info["fg_bonds"] + + fg_internal_edge_index, internal_fg_edges = self._construct_fg_level_structure( + fg_to_atoms_map, fg_bonds + ) + + augmented_struct["edge_info"][k.WITHIN_FG_EDGE] = internal_fg_edges + augmented_struct["edge_info"][k.NUM_EDGES] += len(internal_fg_edges) + + assert ( + self._idx_of_edge == augmented_struct["edge_info"][k.NUM_EDGES] + ), f"Mismatch in number of edges: expected {self._idx_of_edge}, got {augmented_struct['edge_info'][k.NUM_EDGES]}" + assert ( + self._idx_of_node == augmented_struct["node_info"]["num_nodes"] + ), f"Mismatch in number of nodes: expected {self._idx_of_node}, got {augmented_struct['node_info']['num_nodes']}" + + augmented_struct["directed_edge_index"] = torch.cat( + [ + augmented_struct["directed_edge_index"], + torch.tensor(fg_internal_edge_index, dtype=torch.long), + ], + dim=1, + ) + return augmented_struct + + def _construct_fg_level_structure( + self, fg_to_atoms_map: dict, bonds: list + ) -> tuple[list[list[int]], dict]: + """ + Constructs internal edges between functional group nodes based on bond connections. + + Args: + fg_to_atoms_map (dict): Mapping from FG ID to atom indices. + bonds (list): List of bond tuples (source, target, ...). + + Returns: + tuple[list[list[int]], dict]: + - Edge index within FG nodes. + - Edge attributes for edges within FG nodes. + """ + internal_fg_edges = {} + internal_edge_index = [[], []] + + def add_fg_internal_edge(source_fg: int, target_fg: int) -> None: + assert ( + source_fg is not None and target_fg is not None + ), "Each bond should have a fg node on both end" + assert source_fg != target_fg, "Source and Target FG should be different" + + edge_key = tuple(sorted((source_fg, target_fg))) + edge_str = f"{edge_key[0]}_{edge_key[1]}" + if edge_str not in internal_fg_edges: + # If two atoms of a FG point to atom(s) belonging to another FG, only one edge is counted. + # Eg. In CHEBI:52723, atom idx 13 and 16 of a FG points to atom idx 18 of another FG + internal_edge_index[0].append(source_fg) + internal_edge_index[1].append(target_fg) + internal_fg_edges[edge_str] = {k.EDGE_LEVEL: k.WITHIN_FG_EDGE} + self._idx_of_edge += 1 + + for bond in bonds: + source_atom, target_atom = bond[:2] + source_fg, target_fg = None, None + for fg_id, data in fg_to_atoms_map.items(): + if source_fg is None and source_atom in data["atom"]: + source_fg = fg_id + if target_fg is None and target_atom in data["atom"]: + target_fg = fg_id + if source_fg is not None and target_fg is not None: + break + add_fg_internal_edge(source_fg, target_fg) + + # For Rings belonging to fused rings + fg_nodes = list(fg_to_atoms_map.keys()) + for i, fg_node_1 in enumerate(fg_nodes): + fg_map_1 = fg_to_atoms_map[fg_node_1] + for fg_node_2 in fg_nodes[i + 1 :]: + fg_map_2 = fg_to_atoms_map[fg_node_2] + if ( + (fg_node_1 == fg_node_2) + or not fg_map_1["is_ring_fg"] + or not fg_map_2["is_ring_fg"] + ): + continue + if fg_map_1["atom"] & fg_map_2["atom"]: + add_fg_internal_edge(fg_node_1, fg_node_2) + + return internal_edge_index, internal_fg_edges + + +class _AddGraphNode(_AugmentorReader): + """Adds a graph-level node and connects it to selected/given nodes.""" + + def _read_data(self, smiles: str) -> GeomData | None: + """ + Reads data and adds a graph-level node annotation. + + Args: + smiles (str): SMILES string. + + Returns: + Data | None: Geometric data object with is_graph_node annotation. + """ + geom_data = super()._read_data(smiles) + if geom_data is None: + return None + NUM_NODES = geom_data.x.shape[0] + is_graph_node = torch.zeros(NUM_NODES, dtype=torch.bool) + is_graph_node[-1] = True + geom_data.is_graph_node = is_graph_node + return geom_data + + def _add_graph_node_and_edges_to_nodes( + self, + augmented_struct: dict, + nodes_ids: dict[int, object] | set[int], + ) -> tuple[torch.Tensor, dict, dict]: + """ + Adds a graph-level node and connects it to given nodes. + + Args: + augmented_struct (dict): Current graph structure. + nodes_ids (dict[int, object] | set[int]): Node indices to connect to the graph-level node. + + Returns: + tuple[torch.Tensor, dict, dict]: Updated graph structure with graph node edges and metadata. + """ + nodes_graph_edge_index, graph_node, nodes_to_graph_edges = ( + self._construct_nodes_to_graph_node_structure(nodes_ids) + ) + + augmented_struct["edge_info"][k.TO_GRAPHNODE_EDGE] = nodes_to_graph_edges + augmented_struct["edge_info"][k.NUM_EDGES] += len(nodes_to_graph_edges) + assert ( + self._idx_of_edge == augmented_struct["edge_info"][k.NUM_EDGES] + ), f"Mismatch in number of edges: expected {self._idx_of_edge}, got {augmented_struct['edge_info'][k.NUM_EDGES]}" + + augmented_struct["node_info"]["graph_node"] = graph_node + augmented_struct["node_info"]["num_nodes"] += 1 + assert ( + self._idx_of_node == augmented_struct["node_info"]["num_nodes"] + ), f"Mismatch in number of nodes: expected {self._idx_of_node}, got {augmented_struct['node_info']['num_nodes']}" + + augmented_struct["directed_edge_index"] = torch.cat( + [ + augmented_struct["directed_edge_index"], + torch.tensor(nodes_graph_edge_index, dtype=torch.long), + ], + dim=1, + ) + return augmented_struct + + def _construct_nodes_to_graph_node_structure( + self, nodes_ids: dict[int, object] | set[int] + ) -> tuple[list[list[int]], dict, dict]: + """ + Constructs edges between selected nodes and a global graph-level node. + + Args: + nodes_ids (dict[int, object] | set[int]): IDs of nodes to connect to the graph-level node. + + Returns: + tuple[list[list[int]], dict, dict]: + - Edge index connecting nodes to graph node. + - Graph-level node attributes. + - Edge attributes for graph-level connections. + """ + graph_node = { + k.NODE_LEVEL: k.GRAPH_NODE_LEVEL, + "FG": "graph_fg", + "RING": "0", + "is_alkyl": "0", + } + + graph_to_nodes_edges = {} + graph_edge_index = [[], []] + + for fg_id in nodes_ids: + graph_edge_index[0].append(self._idx_of_node) + graph_edge_index[1].append(fg_id) + graph_to_nodes_edges[f"{self._idx_of_node}_{fg_id}"] = { + k.EDGE_LEVEL: k.TO_GRAPHNODE_EDGE + } + self._idx_of_edge += 1 + self._idx_of_node += 1 + + return graph_edge_index, graph_node, graph_to_nodes_edges + + +class AtomFGReader_WithFGEdges_WithGraphNode( + AtomFGReader_WithFGEdges_NoGraphNode, _AddGraphNode +): + """ + Adds FG nodes (connected to their respective atom nodes) with intra-functional group + edges, and adds a graph-level node connected to all FG nodes. + """ + + def _augment_graph_structure( + self, mol: Chem.Mol + ) -> tuple[torch.Tensor, dict, dict]: + """ + Augments the graph with a global graph-level node. + + Args: + mol (Chem.Mol): RDKit molecule object. + + Returns: + tuple[torch.Tensor, dict, dict]: Updated graph structure. + """ + augmented_struct = super()._augment_graph_structure(mol) + fg_to_atoms_map = augmented_struct["graph_meta_info"]["fg_to_atoms_map"] + return self._add_graph_node_and_edges_to_nodes( + augmented_struct, fg_to_atoms_map + ) + + +class AtomFGReader_NoFGEdges_WithGraphNode( + AtomsFGReader_NoFGEdges_NoGraphNode, _AddGraphNode +): + """ + Adds FG nodes (connected to their respective atom nodes) without functional group + edges, and adds a graph-level node connected to all FG nodes. + """ + + def _augment_graph_structure( + self, mol: Chem.Mol + ) -> tuple[torch.Tensor, dict, dict]: + """ + Augments the graph with only a global graph-level node. + + Args: + mol (Chem.Mol): RDKit molecule object. + + Returns: + tuple[torch.Tensor, dict, dict]: Updated graph structure. + """ + augmented_struct = super()._augment_graph_structure(mol) + fg_to_atoms_map = augmented_struct["graph_meta_info"]["fg_to_atoms_map"] + return self._add_graph_node_and_edges_to_nodes( + augmented_struct, fg_to_atoms_map + ) + + +class AtomReader_WithGraphNodeOnly(_AddGraphNode): + """Adds a graph-level node and connects it to all atom nodes.""" + + def _augment_graph_structure( + self, mol: Chem.Mol + ) -> tuple[torch.Tensor, dict, dict]: + """ + Augments the graph by adding a graph-level node connected to all atoms. + + Args: + mol (Chem.Mol): RDKit molecule object. + + Returns: + tuple[torch.Tensor, dict, dict]: Updated graph structure. + """ + augmented_struct = super()._augment_graph_structure(mol) + molecule: Chem.Mol = augmented_struct["node_info"]["atom_nodes"] + atom_ids = {atom.GetIdx() for atom in molecule.GetAtoms()} + return self._add_graph_node_and_edges_to_nodes(augmented_struct, atom_ids) + + +class GN_WithAtoms_FG_WithAtoms_NoFGE( + AtomsFGReader_NoFGEdges_NoGraphNode, _AddGraphNode +): + """ + Adds FG nodes (connected to their respective atom nodes) without functional group + edges, and adds a graph-level node connected to all atom nodes. + """ + + def _augment_graph_structure( + self, mol: Chem.Mol + ) -> tuple[torch.Tensor, dict, dict]: + """ + Augments the graph with a global graph-level node. + + Args: + mol (Chem.Mol): RDKit molecule object. + + Returns: + tuple[torch.Tensor, dict, dict]: Updated graph structure. + """ + augmented_struct = super()._augment_graph_structure(mol) + molecule: Chem.Mol = augmented_struct["node_info"]["atom_nodes"] + atom_ids = {atom.GetIdx() for atom in molecule.GetAtoms()} + return self._add_graph_node_and_edges_to_nodes(augmented_struct, atom_ids) + + +class GN_WithAtoms_FG_WithAtoms_FGE( + AtomFGReader_WithFGEdges_NoGraphNode, _AddGraphNode +): + """ + Adds FG nodes (connected to their respective atom nodes) with functional group + edges, and adds a graph-level node connected to all atom nodes. + """ + + def _augment_graph_structure( + self, mol: Chem.Mol + ) -> tuple[torch.Tensor, dict, dict]: + """ + Augments the graph with a global graph-level node. + + Args: + mol (Chem.Mol): RDKit molecule object. + + Returns: + tuple[torch.Tensor, dict, dict]: Updated graph structure. + """ + augmented_struct = super()._augment_graph_structure(mol) + molecule: Chem.Mol = augmented_struct["node_info"]["atom_nodes"] + atom_ids = {atom.GetIdx() for atom in molecule.GetAtoms()} + return self._add_graph_node_and_edges_to_nodes(augmented_struct, atom_ids) + + +class GN_WithAllNodes_FG_WithAtoms_FGE( + AtomFGReader_WithFGEdges_NoGraphNode, _AddGraphNode +): + """ + Adds FG nodes (connected to their respective atom nodes) with functional group + edges, and adds a graph-level node connected to all nodes (fg + atoms). + """ + + def _augment_graph_structure( + self, mol: Chem.Mol + ) -> tuple[torch.Tensor, dict, dict]: + """ + Augments the graph with a global graph-level node. + + Args: + mol (Chem.Mol): RDKit molecule object. + + Returns: + tuple[torch.Tensor, dict, dict]: Updated graph structure. + """ + augmented_struct = super()._augment_graph_structure(mol) + molecule: Chem.Mol = augmented_struct["node_info"]["atom_nodes"] + fg_to_atoms_map = augmented_struct["graph_meta_info"]["fg_to_atoms_map"] + atom_ids = {atom.GetIdx() for atom in molecule.GetAtoms()} + return self._add_graph_node_and_edges_to_nodes( + augmented_struct, atom_ids | fg_to_atoms_map.keys() + ) + + +class GN_WithAllNodes_FG_WithAtoms_NoFGE( + AtomsFGReader_NoFGEdges_NoGraphNode, _AddGraphNode +): + """ + Adds FG nodes (connected to their respective atom nodes) without functional group + edges, and adds a graph-level node connected to all nodes (fg + atoms). + """ + + def _augment_graph_structure( + self, mol: Chem.Mol + ) -> tuple[torch.Tensor, dict, dict]: + """ + Augments the graph with a global graph-level node. + + Args: + mol (Chem.Mol): RDKit molecule object. + + Returns: + tuple[torch.Tensor, dict, dict]: Updated graph structure. + """ + augmented_struct = super()._augment_graph_structure(mol) + molecule: Chem.Mol = augmented_struct["node_info"]["atom_nodes"] + fg_to_atoms_map = augmented_struct["graph_meta_info"]["fg_to_atoms_map"] + atom_ids = {atom.GetIdx() for atom in molecule.GetAtoms()} + return self._add_graph_node_and_edges_to_nodes( + augmented_struct, atom_ids | fg_to_atoms_map.keys() + ) diff --git a/chebai_graph/preprocessing/reader/reader.py b/chebai_graph/preprocessing/reader/reader.py new file mode 100644 index 0000000..a63b8a1 --- /dev/null +++ b/chebai_graph/preprocessing/reader/reader.py @@ -0,0 +1,203 @@ +import os + +import chebai.preprocessing.reader as dr +import networkx as nx +import pysmiles as ps +import rdkit.Chem as Chem +import torch +from torch_geometric.data import Data as GeomData +from torch_geometric.utils import from_networkx + +from chebai_graph.preprocessing.collate import GraphCollator +from chebai_graph.preprocessing.properties import MolecularProperty + + +class GraphPropertyReader(dr.DataReader): + COLLATOR = GraphCollator + + def __init__( + self, + *args, + **kwargs, + ) -> None: + """ + Initialize GraphPropertyReader. + + Args: + *args: Positional arguments forwarded to the base class. + **kwargs: Keyword arguments forwarded to the base class. + """ + super().__init__(*args, **kwargs) + self.failed_counter = 0 + self.mol_object_buffer: dict[str, Chem.rdchem.Mol | None] = {} + + @classmethod + def name(cls) -> str: + """ + Get the name identifier of the reader. + + Returns: + str: The name of the reader. + """ + return "graph_properties" + + def _smiles_to_mol(self, smiles: str) -> Chem.rdchem.Mol | None: + """ + Load SMILES string into an RDKit molecule object and cache it. + + Args: + smiles (str): The SMILES string to parse. + + Returns: + Chem.rdchem.Mol | None: Parsed molecule object or None if parsing failed. + """ + if smiles in self.mol_object_buffer: + return self.mol_object_buffer[smiles] + + mol = Chem.MolFromSmiles(smiles) + if mol is None: + print(f"RDKit failed to at parsing {smiles} (returned None)") + self.failed_counter += 1 + else: + try: + Chem.SanitizeMol(mol) + except Exception as e: + print(f"Rdkit failed at sanitizing {smiles}, \n Error: {e}") + self.failed_counter += 1 + self.mol_object_buffer[smiles] = mol + return mol + + def _read_data(self, raw_data: str) -> GeomData | None: + """ + Convert raw SMILES string data into a PyTorch Geometric Data object. + + Args: + raw_data (str): SMILES string. + + Returns: + GeomData | None: Graph data object or None if molecule parsing failed. + """ + mol = self._smiles_to_mol(raw_data) + if mol is None: + return None + + x = torch.zeros((mol.GetNumAtoms(), 0)) + + # First source to target edges, then target to source edges + src = [bond.GetBeginAtomIdx() for bond in mol.GetBonds()] + tgt = [bond.GetEndAtomIdx() for bond in mol.GetBonds()] + edge_index = torch.tensor([src + tgt, tgt + src], dtype=torch.long) + + # edge_index.shape == [2, num_edges]; edge_attr.shape == [num_edges, num_edge_features] + edge_attr = torch.zeros((edge_index.size(1), 0)) + + return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr) + + def on_finish(self) -> None: + """ + Called after reading is done to log information and clean up. + """ + print(f"Failed to read {self.failed_counter} SMILES in total") + self.mol_object_buffer = {} + + def read_property(self, smiles: str, property: MolecularProperty) -> list | None: + """ + Read a molecular property for a given SMILES string. + + Args: + smiles (str): SMILES string of the molecule. + property (MolecularProperty): Property extractor to apply. + + Returns: + list | None: Property values or None if molecule parsing failed. + """ + mol = self._smiles_to_mol(smiles) + if mol is None: + return None + return property.get_property_value(mol) + + +class GraphReader(dr.ChemDataReader): + """Reads each atom as one token (atom symbol + charge), reads bond order as edge attribute. + Creates nx Graph from SMILES.""" + + COLLATOR = GraphCollator + + def __init__(self, *args, **kwargs) -> None: + """ + Initialize GraphReader. + + Args: + *args: Positional arguments forwarded to the base class. + **kwargs: Keyword arguments forwarded to the base class. + """ + super().__init__(*args, **kwargs) + self.dirname = os.path.dirname(__file__) + + @classmethod + def name(cls) -> str: + """ + Get the name identifier of the reader. + + Returns: + str: The name of the reader. + """ + return "graph" + + def _read_data(self, raw_data: str) -> GeomData | None: + """ + Convert a SMILES string into a PyTorch Geometric Data object with atom tokens and bond order attributes. + + Args: + raw_data (str): SMILES string. + + Returns: + GeomData | None: Graph data object or None if parsing failed. + """ + # raw_data is a SMILES string + try: + mol = ps.read_smiles(raw_data) + except ValueError: + return None + assert isinstance(mol, nx.Graph) + d: dict[int, int] = {} + de: dict[tuple[int, int], int] = {} + for node in mol.nodes: + n = mol.nodes[node] + try: + m = n["element"] + charge = n["charge"] + if charge: + if charge > 0: + m += "+" + else: + m += "-" + charge *= -1 + if charge > 1: + m += str(charge) + m = f"[{m}]" + except KeyError: + m = "*" + d[node] = self._get_token_index(m) + for attr in list(mol.nodes[node].keys()): + del mol.nodes[node][attr] + for edge in mol.edges: + de[edge] = mol.edges[edge]["order"] + for attr in list(mol.edges[edge].keys()): + del mol.edges[edge][attr] + nx.set_node_attributes(mol, d, "x") + nx.set_edge_attributes(mol, de, "edge_attr") + data = from_networkx(mol) + return data + + def collate(self, list_of_tuples: list) -> any: + """ + Collate a list of samples into a batch. + + Args: + list_of_tuples (list): List of data tuples to collate. + + Returns: + Any: Collated batch (type depends on collator). + """ + return self.collator(list_of_tuples) diff --git a/chebai_graph/preprocessing/reader/static_gni.py b/chebai_graph/preprocessing/reader/static_gni.py new file mode 100644 index 0000000..106c528 --- /dev/null +++ b/chebai_graph/preprocessing/reader/static_gni.py @@ -0,0 +1,80 @@ +""" +Abboud, Ralph, et al. +"The surprising power of graph neural networks with random node initialization." +arXiv preprint arXiv:2010.01179 (2020). + +Code Reference: https://github.com/ralphabb/GNN-RNI/blob/main/GNNHyb.py +""" + +import torch +from torch_geometric.data import Data as GeomData + +from .reader import GraphPropertyReader + + +class RandomFeatureInitializationReader(GraphPropertyReader): + DISTRIBUTIONS = ["normal", "uniform", "xavier_normal", "xavier_uniform", "zeros"] + + def __init__( + self, + num_node_properties: int, + num_bond_properties: int, + num_molecule_properties: int, + distribution: str = "normal", + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.num_node_properties = num_node_properties + self.num_bond_properties = num_bond_properties + self.num_molecule_properties = num_molecule_properties + assert distribution in self.DISTRIBUTIONS + self.distribution = distribution + + def name(self) -> str: + """ + Get the name identifier of the reader. + + Returns: + str: The name of the reader. + """ + return f"gni-{self.distribution}-node{self.num_node_properties}-bond{self.num_bond_properties}-mol{self.num_molecule_properties}" + + def _read_data(self, raw_data): + data: GeomData = super()._read_data(raw_data) + if data is None: + return None + + random_x = torch.empty(data.x.shape[0], self.num_node_properties) + random_edge_attr = torch.empty( + data.edge_attr.shape[0], self.num_bond_properties + ) + random_molecule_properties = torch.empty(1, self.num_molecule_properties) + + self.random_gni(random_x, self.distribution) + self.random_gni(random_edge_attr, self.distribution) + self.random_gni(random_molecule_properties, self.distribution) + + data.x = random_x + data.edge_attr = random_edge_attr + data.molecule_attr = random_molecule_properties + return data + + def read_property(self, *args, **kwargs) -> Exception: + """This reader does not support reading specific properties.""" + raise NotImplementedError("This reader only performs random initialization.") + + @staticmethod + def random_gni(tensor: torch.Tensor, distribution: str) -> None: + if distribution == "normal": + torch.nn.init.normal_(tensor) + elif distribution == "uniform": + torch.nn.init.uniform_(tensor, a=-1.0, b=1.0) + elif distribution == "xavier_normal": + torch.nn.init.xavier_normal_(tensor) + elif distribution == "xavier_uniform": + torch.nn.init.xavier_uniform_(tensor) + elif distribution == "zeros": + torch.nn.init.zeros_(tensor) + else: + raise ValueError("Unknown distribution type") diff --git a/chebai_graph/preprocessing/structures.py b/chebai_graph/preprocessing/structures.py index b417593..665c605 100644 --- a/chebai_graph/preprocessing/structures.py +++ b/chebai_graph/preprocessing/structures.py @@ -1,11 +1,34 @@ +import torch from chebai.preprocessing.structures import XYData class XYGraphData(XYData): - def __len__(self): + """ + Extension of XYData supporting `.to(device)` for potentially complex `x` structures. + + `x` can be: + - a tensor, + - a tuple of tensors or dicts of tensors, + and this class recursively sends all tensors to the specified device. + + Args: + Inherits from XYData. + """ + + def __len__(self) -> int: + """Return the length of y.""" return len(self.y) - def to_x(self, device): + def to_x(self, device: torch.device | str) -> object: + """ + Move the input features `x` to the given device. + + Args: + device: torch device or device string (e.g. 'cpu' or 'cuda'). + + Returns: + The input `x` moved to the specified device, preserving structure. + """ if isinstance(self.x, tuple): res = [] for elem in self.x: diff --git a/configs/data/chebi100_graph_properties.yml b/configs/data/chebi100_graph_properties.yml new file mode 100644 index 0000000..7e78711 --- /dev/null +++ b/configs/data/chebi100_graph_properties.yml @@ -0,0 +1,14 @@ +class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI100GraphProperties +init_args: + properties: + - chebai_graph.preprocessing.properties.AtomType + - chebai_graph.preprocessing.properties.NumAtomBonds + - chebai_graph.preprocessing.properties.AtomCharge + - chebai_graph.preprocessing.properties.AtomAromaticity + - chebai_graph.preprocessing.properties.AtomHybridization + - chebai_graph.preprocessing.properties.AtomNumHs + - chebai_graph.preprocessing.properties.BondType + - chebai_graph.preprocessing.properties.BondInRing + - chebai_graph.preprocessing.properties.BondAromaticity + #- chebai_graph.preprocessing.properties.MoleculeNumRings + - chebai_graph.preprocessing.properties.RDKit2DNormalized diff --git a/configs/data/chebi50_aug_prop_as_per_node.yml b/configs/data/chebi50_aug_prop_as_per_node.yml new file mode 100644 index 0000000..576cf75 --- /dev/null +++ b/configs/data/chebi50_aug_prop_as_per_node.yml @@ -0,0 +1,24 @@ +class_path: chebai_graph.preprocessing.datasets.ChEBI50_WFGE_WGN_AsPerNodeType +init_args: + properties: + # All Node type properties + - chebai_graph.preprocessing.properties.AtomNodeLevel + # Atom Node type properties + - chebai_graph.preprocessing.properties.AugAtomAromaticity + - chebai_graph.preprocessing.properties.AugAtomCharge + - chebai_graph.preprocessing.properties.AugAtomHybridization + - chebai_graph.preprocessing.properties.AugAtomNumHs + - chebai_graph.preprocessing.properties.AugAtomType + - chebai_graph.preprocessing.properties.AugNumAtomBonds + # FG Node type properties + - chebai_graph.preprocessing.properties.AtomFunctionalGroup + - chebai_graph.preprocessing.properties.IsHydrogenBondDonorFG + - chebai_graph.preprocessing.properties.IsHydrogenBondAcceptorFG + - chebai_graph.preprocessing.properties.IsFGAlkyl + # Graph Node type properties + - chebai_graph.preprocessing.properties.AugRDKit2DNormalized + # Bond properties + - chebai_graph.preprocessing.properties.BondLevel + - chebai_graph.preprocessing.properties.AugBondAromaticity + - chebai_graph.preprocessing.properties.AugBondInRing + - chebai_graph.preprocessing.properties.AugBondType diff --git a/configs/data/chebi50_augmented_baseline.yml b/configs/data/chebi50_augmented_baseline.yml new file mode 100644 index 0000000..e8520f3 --- /dev/null +++ b/configs/data/chebi50_augmented_baseline.yml @@ -0,0 +1,12 @@ +class_path: chebai_graph.preprocessing.datasets.ChEBI50_WFGE_WGN_GraphProp +init_args: + properties: + - chebai_graph.preprocessing.properties.AugAtomType + - chebai_graph.preprocessing.properties.AugNumAtomBonds + - chebai_graph.preprocessing.properties.AugAtomCharge + - chebai_graph.preprocessing.properties.AugAtomAromaticity + - chebai_graph.preprocessing.properties.AugAtomHybridization + - chebai_graph.preprocessing.properties.AugAtomNumHs + - chebai_graph.preprocessing.properties.AugBondType + - chebai_graph.preprocessing.properties.AugBondInRing + - chebai_graph.preprocessing.properties.AugBondAromaticity diff --git a/configs/data/chebi50_graph.yml b/configs/data/chebi50_graph.yml index 19c8753..12e8abd 100644 --- a/configs/data/chebi50_graph.yml +++ b/configs/data/chebi50_graph.yml @@ -1 +1 @@ -class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphData +class_path: chebai_graph.preprocessing.datasets.ChEBI50GraphData diff --git a/configs/data/chebi50_graph_properties.yml b/configs/data/chebi50_graph_properties.yml index c84ac3e..0b770b2 100644 --- a/configs/data/chebi50_graph_properties.yml +++ b/configs/data/chebi50_graph_properties.yml @@ -1,4 +1,4 @@ -class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphProperties +class_path: chebai_graph.preprocessing.datasets.ChEBI50GraphProperties init_args: properties: - chebai_graph.preprocessing.properties.AtomType diff --git a/configs/data/chebi50_static_gni.yml b/configs/data/chebi50_static_gni.yml new file mode 100644 index 0000000..12096cb --- /dev/null +++ b/configs/data/chebi50_static_gni.yml @@ -0,0 +1,6 @@ +class_path: chebai_graph.preprocessing.datasets.ChEBI50_StaticGNI +init_args: + reader_kwargs: + num_node_properties: 158 + num_bond_properties: 7 + num_molecule_properties: 0 diff --git a/configs/data/pubchem_graph.yml b/configs/data/pubchem_graph.yml index c21f188..855f93a 100644 --- a/configs/data/pubchem_graph.yml +++ b/configs/data/pubchem_graph.yml @@ -1,4 +1,4 @@ -class_path: chebai_graph.preprocessing.datasets.pubchem.PubChemGraphProperties +class_path: chebai_graph.preprocessing.datasets.PubChemGraphProperties init_args: transform: class_path: chebai_graph.preprocessing.transform_unlabeled.MaskAtom # mask atoms / bonds diff --git a/configs/model/gat.yml b/configs/model/gat.yml new file mode 100644 index 0000000..d72b4a7 --- /dev/null +++ b/configs/model/gat.yml @@ -0,0 +1,15 @@ +class_path: chebai_graph.models.GATGraphPred +init_args: + optimizer_kwargs: + lr: 1e-3 + config: + in_channels: 158 # number of node/atom properties + hidden_channels: 256 + out_channels: 512 + num_layers: 4 + edge_dim: 7 # number of bond properties + heads: 8 # the number of heads should be divisible by output channels (hidden channels if output channel not given) + v2: False # set True to use `torch_geometric.nn.conv.GATv2Conv` convolution layers, default is GATConv + dropout: 0 + n_molecule_properties: 0 + n_linear_layers: 1 diff --git a/configs/model/gat_aug_aapool.yml b/configs/model/gat_aug_aapool.yml new file mode 100644 index 0000000..05ef4e8 --- /dev/null +++ b/configs/model/gat_aug_aapool.yml @@ -0,0 +1,15 @@ +class_path: chebai_graph.models.GATAugNodePoolGraphPred +init_args: + optimizer_kwargs: + lr: 1e-3 + config: + in_channels: 203 # number of node/atom properties + hidden_channels: 256 + out_channels: 512 + num_layers: 4 + edge_dim: 11 # number of bond properties + heads: 8 # the number of heads should be divisible by output channels (hidden channels if output channel not given) + v2: False # set True to use `torch_geometric.nn.conv.GATv2Conv` convolution layers, default is GATConv + dropout: 0 + n_molecule_properties: 0 + n_linear_layers: 1 diff --git a/configs/model/gat_aug_amgpool.yml b/configs/model/gat_aug_amgpool.yml new file mode 100644 index 0000000..b5ce89d --- /dev/null +++ b/configs/model/gat_aug_amgpool.yml @@ -0,0 +1,15 @@ +class_path: chebai_graph.models.GATGraphNodeFGNodePoolGraphPred +init_args: + optimizer_kwargs: + lr: 1e-3 + config: + in_channels: 203 # number of node/atom properties + hidden_channels: 256 + out_channels: 512 + num_layers: 4 + edge_dim: 11 # number of bond properties + heads: 8 # the number of heads should be divisible by output channels (hidden channels if output channel not given) + v2: True # set True to use `torch_geometric.nn.conv.GATv2Conv` convolution layers, default is GATConv + dropout: 0 + n_molecule_properties: 0 + n_linear_layers: 1 diff --git a/configs/model/res_aug_aapool.yml b/configs/model/res_aug_aapool.yml new file mode 100644 index 0000000..9e364f9 --- /dev/null +++ b/configs/model/res_aug_aapool.yml @@ -0,0 +1,13 @@ +class_path: chebai_graph.models.ResGatedAugNodePoolGraphPred +init_args: + optimizer_kwargs: + lr: 1e-3 + config: + in_channels: 203 # number of node/atom properties + hidden_channels: 256 + out_channels: 512 + num_layers: 4 + edge_dim: 11 # number of bond properties + dropout: 0 + n_molecule_properties: 0 + n_linear_layers: 1 diff --git a/configs/model/res_aug_amgpool.yml b/configs/model/res_aug_amgpool.yml new file mode 100644 index 0000000..2aba5ea --- /dev/null +++ b/configs/model/res_aug_amgpool.yml @@ -0,0 +1,13 @@ +class_path: chebai_graph.models.ResGatedGraphNodeFGNodePoolGraphPred +init_args: + optimizer_kwargs: + lr: 1e-3 + config: + in_channels: 203 # number of node/atom properties + hidden_channels: 256 + out_channels: 512 + num_layers: 4 + edge_dim: 11 # number of bond properties + dropout: 0 + n_molecule_properties: 0 + n_linear_layers: 1 diff --git a/configs/model/resgated.yml b/configs/model/resgated.yml new file mode 100644 index 0000000..ccc6615 --- /dev/null +++ b/configs/model/resgated.yml @@ -0,0 +1,13 @@ +class_path: chebai_graph.models.ResGatedGraphPred +init_args: + optimizer_kwargs: + lr: 1e-3 + config: + in_channels: 158 # number of node/atom properties + hidden_channels: 256 + out_channels: 512 + num_layers: 4 + edge_dim: 7 # number of bond properties + dropout: 0 + n_molecule_properties: 0 + n_linear_layers: 1 diff --git a/configs/model/resgated_dynamic_gni.yml b/configs/model/resgated_dynamic_gni.yml new file mode 100644 index 0000000..4749795 --- /dev/null +++ b/configs/model/resgated_dynamic_gni.yml @@ -0,0 +1,13 @@ +class_path: chebai_graph.models.ResGatedDynamicGNIGraphPred +init_args: + optimizer_kwargs: + lr: 1e-3 + config: + in_channels: 158 # number of node/atom properties + hidden_channels: 256 + out_channels: 512 + num_layers: 4 + edge_dim: 7 # number of bond properties + dropout: 0 + n_molecule_properties: 0 + n_linear_layers: 1 diff --git a/pyproject.toml b/pyproject.toml index 1e8745d..8a56a81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,8 @@ dependencies = [ # torch-geometric # torch_scatter ] +requires-python = ">=3.8" + [project.optional-dependencies] dev = [ diff --git a/results/visualize_augmented_molecule.py b/results/visualize_augmented_molecule.py new file mode 100644 index 0000000..fcc406b --- /dev/null +++ b/results/visualize_augmented_molecule.py @@ -0,0 +1,517 @@ +import io + +import matplotlib +import matplotlib.pyplot as plt +import networkx as nx +import numpy as np +from jsonargparse import CLI +from PIL import Image +from rdkit.Chem import AllChem, BondType, Mol, rdDepictor +from rdkit.Chem.Draw import rdMolDraw2D +from torch import Tensor + +from chebai_graph.preprocessing.properties import constants as k +from chebai_graph.preprocessing.reader import ( + AtomFGReader_NoFGEdges_WithGraphNode, + AtomFGReader_WithFGEdges_NoGraphNode, + AtomFGReader_WithFGEdges_WithGraphNode, + AtomReader_WithGraphNodeOnly, + AtomsFGReader_NoFGEdges_NoGraphNode, + GN_WithAllNodes_FG_WithAtoms_FGE, + GN_WithAllNodes_FG_WithAtoms_NoFGE, + GN_WithAtoms_FG_WithAtoms_FGE, + GN_WithAtoms_FG_WithAtoms_NoFGE, +) + +matplotlib.use("TkAgg") + +EDGE_COLOR_MAP = { + k.WITHIN_ATOMS_EDGE: "#1f77b4", + k.ATOM_FG_EDGE: "#9467bd", + k.WITHIN_FG_EDGE: "#ff7f0e", + k.TO_GRAPHNODE_EDGE: "#2ca02c", +} + +NODE_COLOR_MAP = { + "atom": "#9ecae1", + "fg": "#fdae6b", + "graph": "#d62728", +} + + +BOND_COLOR_MAP = { + BondType.SINGLE: "black", + BondType.DOUBLE: "blue", + BondType.TRIPLE: "green", + BondType.DATIVE: "red", + BondType.DATIVEL: "red", + BondType.DATIVER: "red", + BondType.DATIVEONE: "red", +} + +READER = { + "n_fge_w_gn": AtomFGReader_NoFGEdges_WithGraphNode, + "w_fge_w_gn": AtomFGReader_WithFGEdges_WithGraphNode, + "w_fge_n_gn": AtomFGReader_WithFGEdges_NoGraphNode, + "n_fge_n_gn": AtomsFGReader_NoFGEdges_NoGraphNode, + "atom_w_gn": AtomReader_WithGraphNodeOnly, + "gnwa_fgwa_nfge": GN_WithAtoms_FG_WithAtoms_NoFGE, + "gnwa_fgwa_wfge": GN_WithAtoms_FG_WithAtoms_FGE, + "gn_wall_fgwa_nfge": GN_WithAllNodes_FG_WithAtoms_NoFGE, + "gn_wall_fgwa_wfge": GN_WithAllNodes_FG_WithAtoms_FGE, +} + + +def _create_graph( + edge_index: Tensor, augmented_graph_nodes: dict, augmented_graph_edges: dict +) -> nx.Graph: + """ + Create a NetworkX graph from augmented molecular information. + + Args: + edge_index (torch.Tensor): Tensor of shape (2, num_edges) with source and target indices. + augmented_graph_nodes (dict): Dictionary of node attributes grouped by type ('atom_nodes', 'fg_nodes', etc.). + augmented_graph_edges (dict): Dictionary of edges grouped by predefined edge type. + + Returns: + nx.Graph: Constructed NetworkX graph with annotated nodes and edges. + """ + G = nx.Graph() + + # Add atom nodes + atom_nodes = augmented_graph_nodes["atom_nodes"] + for atom in atom_nodes.GetAtoms(): + idx = atom.GetIdx() + G.add_node( + idx, + node_name=atom.GetSymbol(), + node_type="atom", + node_color=NODE_COLOR_MAP["atom"], + ) + + # Add functional group (FG) nodes + if "fg_nodes" in augmented_graph_nodes: + fg_nodes = augmented_graph_nodes["fg_nodes"] + for fg_idx, fg_props in fg_nodes.items(): + G.add_node( + fg_idx, + node_name=f"FG:{fg_props['FG']}", + node_type="fg", + node_color=NODE_COLOR_MAP["fg"], + ) + + if "graph_node" in augmented_graph_nodes: + graph_node_idx = augmented_graph_nodes["num_nodes"] - 1 + G.add_node( + graph_node_idx, + node_name="Graph Node", + node_type="graph", + node_color=NODE_COLOR_MAP["graph"], + ) + + # Decode edge types and add edges with proper color and type + src_nodes, tgt_nodes = edge_index.tolist() + with_atom_edges = { + f"{bond.GetBeginAtomIdx()}_{bond.GetEndAtomIdx()}" + for bond in augmented_graph_edges[k.WITHIN_ATOMS_EDGE].GetBonds() + } + + atom_fg_edges = ( + set(augmented_graph_edges[k.ATOM_FG_EDGE]) + if k.ATOM_FG_EDGE in augmented_graph_edges + else set() + ) + within_fg_edges = ( + set(augmented_graph_edges[k.WITHIN_FG_EDGE]) + if k.WITHIN_FG_EDGE in augmented_graph_edges + else set() + ) + fg_graph_edges = ( + set(augmented_graph_edges[k.TO_GRAPHNODE_EDGE]) + if k.TO_GRAPHNODE_EDGE in augmented_graph_edges + else set() + ) + + for src, tgt in zip(src_nodes, tgt_nodes): + undirected_edge_set = {f"{src}_{tgt}", f"{tgt}_{src}"} + if undirected_edge_set & with_atom_edges: + edge_type = k.WITHIN_ATOMS_EDGE + elif undirected_edge_set & atom_fg_edges: + edge_type = k.ATOM_FG_EDGE + elif undirected_edge_set & within_fg_edges: + edge_type = k.WITHIN_FG_EDGE + elif undirected_edge_set & fg_graph_edges: + edge_type = k.TO_GRAPHNODE_EDGE + else: + raise ValueError("Unexpected edge type") + G.add_edge(src, tgt, edge_type=edge_type, edge_color=EDGE_COLOR_MAP[edge_type]) + + return G + + +def _get_subgraph_by_node_type(G: nx.Graph, node_type: str) -> nx.Graph: + """ + Extract a subgraph containing only nodes of the given type. + + Args: + G (nx.Graph): Full graph with node_type attributes. + node_type (str): Type of node to extract ('atom', 'fg', or 'graph'). + + Returns: + nx.Graph: Subgraph with selected node type. + """ + selected_nodes = [ + n for n, attr in G.nodes(data=True) if attr.get("node_type") == node_type + ] + return G.subgraph(selected_nodes).copy() if selected_nodes else None + + +def _draw_hierarchy(G: nx.Graph, mol: Mol) -> None: + """ + Draw a hierarchical layout combining RDKit 2D coordinates for atoms and spring layout for FG/graph nodes. + + Args: + G (nx.Graph): Augmented molecular graph. + mol (Chem.Mol): RDKit molecule object with atom layout. + """ + AllChem.Compute2DCoords(mol) + + # Get 2D positions from RDKit + atom_pos = {} + max_atom_pos_y = 0 + for atom in mol.GetAtoms(): + idx = atom.GetIdx() + pos = mol.GetConformer().GetAtomPosition(idx) + atom_pos[idx] = (pos.x, pos.y) + max_atom_pos_y = max(max_atom_pos_y, pos.y) + + # Position FG nodes above atoms + fg_graph = _get_subgraph_by_node_type(G, "fg") + fg_pos = { + node: (x, y + max_atom_pos_y + 2) + for node, (x, y) in nx.spring_layout(fg_graph, seed=42).items() + } + + # Position the graph node further above + graph_node_graph = _get_subgraph_by_node_type(G, "graph") + graph_pos = { + node: (x, y + max_atom_pos_y + 3) + for node, (x, y) in nx.spring_layout(graph_node_graph, seed=123).items() + } + + # Merge all positions + pos = {**atom_pos, **fg_pos, **graph_pos} + node_colors = [G.nodes[n]["node_color"] for n in G.nodes] + node_labels = {n: G.nodes[n]["node_name"] for n in G.nodes} + edge_colors = [G.edges[e]["edge_color"] for e in G.edges] + + plt.figure(figsize=(10, 8)) + nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=600) + nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10) + nx.draw_networkx_edges(G, pos, edge_color=edge_colors, width=2) + plt.title("Augmented Graph with RDKit Atom Layout + FG/Graph Clusters") + plt.axis("off") + plt.show() + + +def _draw_simple(G: nx.Graph) -> None: + """ + Draw the graph using a simple spring layout. + + Args: + G (nx.Graph): Augmented molecular graph. + """ + plt.figure(figsize=(10, 8)) + pos = nx.spring_layout(G, seed=42) + node_colors = [G.nodes[n]["node_color"] for n in G.nodes] + node_labels = {n: G.nodes[n]["node_name"] for n in G.nodes} + edge_colors = [G.edges[e]["edge_color"] for e in G.edges] + nx.draw( + G, + pos, + with_labels=False, + node_color=node_colors, + node_size=600, + edge_color=edge_colors, + width=2, + ) + nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10) + plt.title("Augmented Graph with simple layout") + plt.axis("off") + plt.show() + + +def _draw_3d(G: nx.Graph, mol: Mol) -> None: + """ + Visualize the graph in 3D using Plotly. + + Args: + G (nx.Graph): Augmented molecular graph. + mol (Chem.Mol): RDKit molecule object for 3D coordinates. + + Raises: + ImportError: If Plotly is not installed. + """ + try: + from plotly import graph_objects as go + except ImportError as e: + raise ImportError( + "Plotly is required for 3D plotting. Install it with `pip install plotly`." + ) from e + + # Generate 3D coordinates for atoms + AllChem.EmbedMolecule(mol) + conf = mol.GetConformer() + + atom_pos = { + atom.GetIdx(): (pos.x, pos.y, 0) + for atom in mol.GetAtoms() + for pos in [conf.GetAtomPosition(atom.GetIdx())] + } + + # Dictionary to store functional group node positions + fg_pos = {} + fg_subgraph = _get_subgraph_by_node_type(G, "fg") + if fg_subgraph: + # Loop through each functional group node in the graph + for fg_node in fg_subgraph.nodes(): + # Get connected atom nodes (assuming edges are between fg and atom nodes) + connected_atoms = [ + nbr + for nbr in G.neighbors(fg_node) + if G.nodes[nbr].get("node_type") == "atom" + ] + + # Get the 2D positions of the connected atoms + positions = np.array([atom_pos[atom] for atom in connected_atoms]) + x_mean, y_mean = positions[:, 0].mean(), positions[:, 1].mean() + fg_pos[fg_node] = (x_mean, y_mean, 2) # z = 2 for elevation + + graph_pos = {} + graph_subgraph = _get_subgraph_by_node_type(G, "graph") + if graph_subgraph: + graph_node = next(iter(graph_subgraph.nodes())) + neighbor_type = { + G.nodes[nbr].get("node_type") for nbr in G.neighbors(graph_node) + } + + if "fg" in neighbor_type: + graph_pos_arr = np.array( + [ + fg_pos[nbr] + for nbr in G.neighbors(graph_node) + if G.nodes[nbr].get("node_type") == "fg" + ] + ) + else: + graph_pos_arr = np.array( + [ + atom_pos[nbr] + for nbr in G.neighbors(graph_node) + if G.nodes[nbr].get("node_type") == "atom" + ] + ) + graph_pos = { + graph_node: (graph_pos_arr[:, 0].mean(), graph_pos_arr[:, 1].mean(), 4) + } + + pos = {**atom_pos, **fg_pos, **graph_pos} + + # Collect edges by type + edge_type_to_edges = { + k.WITHIN_ATOMS_EDGE: [], + k.ATOM_FG_EDGE: [], + k.WITHIN_FG_EDGE: [], + k.TO_GRAPHNODE_EDGE: [], + } + for src, tgt, data in G.edges(data=True): + edge_type_to_edges[data["edge_type"]].append((src, tgt)) + + edge_traces = [] + for edge_type, edge_list in edge_type_to_edges.items(): + xs, ys, zs = [], [], [] + for src, tgt in edge_list: + x0, y0, z0 = pos[src] + x1, y1, z1 = pos[tgt] + xs += [x0, x1, None] + ys += [y0, y1, None] + zs += [z0, z1, None] + + trace = go.Scatter3d( + x=xs, + y=ys, + z=zs, + mode="lines", + line=dict(color=EDGE_COLOR_MAP[edge_type], width=4), + name=edge_type, + hoverinfo="none", + ) + edge_traces.append(trace) + + # Collect node attributes for visualization + pos_x, pos_y, pos_z, node_colors, node_names, node_ids = zip( + *[ + (pos[n][0], pos[n][1], pos[n][2], attr["node_color"], attr["node_name"], n) + for n, attr in G.nodes(data=True) + ] + ) + + node_trace = go.Scatter3d( + x=pos_x, + y=pos_y, + z=pos_z, + mode="markers+text", + marker=dict(size=8, color=node_colors, opacity=0.9), + text=node_names, + textposition="top center", + hovertext=node_ids, + hoverinfo="text", + ) + + fig = go.Figure(data=edge_traces + [node_trace]) + fig.update_layout( + title="3D Augmented Molecule Graph", + showlegend=True, + scene=dict( + xaxis=dict(visible=False), + yaxis=dict(visible=False), + zaxis=dict(visible=False), + ), + margin=dict(l=0, r=0, b=0, t=40), + ) + fig.show() + + +def plot_augmented_graph( + edge_index: Tensor, + augmented_molecule: dict, + mol: Mol, + plot_type: str, +) -> None: + """ + Main plotting function to visualize the augmented graph. + + Args: + edge_index (torch.Tensor): Edge indices tensor (2, num_edges). + augmented_molecule (dict): Augmented Molecule. + mol (Chem.Mol): RDKit molecule object. + plot_type (str): One of ["simple", "h", "3d"]. + """ + G = _create_graph( + edge_index, augmented_molecule["nodes"], augmented_molecule["edges"] + ) + + if plot_type == "h": + _draw_hierarchy(G, mol) + elif plot_type == "simple": + _draw_simple(G) + elif plot_type == "3d": + _draw_3d(G, mol) + else: + raise ValueError(f"Unknown plot type: {plot_type}") + + +def plot_nonaugment_molecule_graph(mol: Mol, size=(800, 800)) -> None: + """ + Visualize a molecule using rdkit. + """ + + print(f"Number of atoms: {mol.GetNumAtoms()}") + print(f"Number of bonds: {mol.GetNumBonds()}") + print("\nAtoms:") + for atom in mol.GetAtoms(): + print(f"\tAtom index: {atom.GetIdx()}, Symbol: {atom.GetSymbol()}") + + print("\nBonds:") + for bond in mol.GetBonds(): + a1 = bond.GetBeginAtomIdx() + a2 = bond.GetEndAtomIdx() + btype = bond.GetBondType() + print(f"\tBond index: {bond.GetIdx()}, Atoms: ({a1}, {a2}), Type: {btype}") + + # Generate 2D coordinates + rdDepictor.Compute2DCoords(mol) + + # Set up drawer with high resolution + drawer = rdMolDraw2D.MolDraw2DCairo(*size) + options = drawer.drawOptions() + + # Display atom indices and symbols + options.addAtomIndices = True + # Show bond indices + options.addBondIndices = True + options.addStereoAnnotation = True + options.padding = 0.05 # Less whitespace + options.fixedBondLength = 25 # for visual clarity + + drawer.DrawMolecule(mol) + drawer.FinishDrawing() + + # Convert to image + png = drawer.GetDrawingText() + img = Image.open(io.BytesIO(png)) + + # Show using matplotlib + dpi = 300 + plt.figure(figsize=(size[0] / dpi, size[1] / dpi), dpi=dpi) + plt.imshow(img) + plt.axis("off") + plt.tight_layout(pad=0) + # plt.title("RDKit 2D Molecule with Atom Indices", fontsize=14) + plt.show() + + +class Main: + """ + Command-line wrapper class for plotting augmented molecular graphs. + """ + + @staticmethod + def plot( + smiles: str = "OC(=O)c1ccccc1O", + plot_type: str = "simple", + reader: str = "w_fge_w_gn", + ) -> None: + """ + Plot an augmented molecular graph from SMILES. + + Args: + smiles (str): SMILES string to parse. + plot_type (str): Type of plot. One of ['simple', 'h', '3d']. + - simple : 2D graph with all nodes on same plane + - h: Hierarchical 2D-graph with separate plane for each node type + - 3d: Hierarchical 3D-graph + reader (str): Reader type for graph augmentation. Options: + - 'n_fge_w_gn': FG nodes without FG edges, with a graph node connected to fg nodes. + - 'w_fge_w_gn': FG nodes with FG edges, with a graph node connected to fg nodes. + - 'w_fge_n_gn': FG nodes with FG edges, no graph node. + - 'n_fge_n_gn': FG nodes without FG edges, no graph node. + - 'atom_w_gn': Atom nodes only, connected to a graph node. + - 'gnwa_fgwa_nfge': Graph node connected to atoms and FG nodes connected to atoms, no FG edges. + - 'gnwa_fgwa_wfge': Graph node connected to atoms and FG nodes connected to atoms, with FG edges. + - 'gn_wall_fgwa_nfge': Graph node connected to all atoms and FG nodes, no FG edges. + - 'gn_wall_fgwa_wfge': Graph node connected to all atoms and FG nodes, with FG edges. + """ + fg_reader = READER[reader]() + mol = fg_reader._smiles_to_mol(smiles) + if mol is None: + raise ValueError(f"Invalid SMILES: {smiles}") + + if plot_type == "molecule_only": + plot_nonaugment_molecule_graph(mol) + return + + edge_index, augmented_molecule = fg_reader._create_augmented_graph(mol) + plot_augmented_graph(edge_index, augmented_molecule, mol, plot_type) + + +if __name__ == "__main__": + # Example: python visualize_augmented_molecule.py plot --smiles="OC(=O)c1ccccc1O" --plot_type="h" + # Aspirin -> CC(=O)OC1=CC=CC=C1C(=O)O ; CHEBI:15365, acetylsalicylic acid + # Salicylic acid -> OC(=O)c1ccccc1O ; CHEBI:16914 + # 1-hydroxy-2-naphthoic acid -> OC(=O)c1ccc2ccccc2c1O ; CHEBI:36108 ; Fused Rings + # 3-nitrobenzoic acid -> OC(=O)C1=CC(=CC=C1)[N+]([O-])=O ; CHEBI:231494 ; Ring + Novel atom (Nitrogen) + # nile blue A -> [Cl-].CCN(CC)c1ccc2nc3c(cc(N)c4ccccc34)[o+]c2c1 ; CHEBI:52163 ; Fused rings + Novel atoms + # CHEBI:52723; Complicated molecule; 'O.O.[Cl-].[Cl-].C1=Cc2ccc3C=CC=[N]4c3c2[N](=C1)[Ru++]4123[N]4=CC=Cc5ccc6C=CC=[N]1c6c45.C1=Cc4ccc5C=CC=[N]2c5c4[N]3=C1' + # CHEBI:87627; Complicated molecule; "[Cl-].[Cl-].[Cl-].[Cl-].[Zn++].COc1cc(ccc1[N+]#N)-c1ccc([N+]#N)c(OC)c1" + CLI(Main, as_positional=False) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/dataclasses/testGraphPropAsPerNodeType.py b/tests/unit/dataclasses/testGraphPropAsPerNodeType.py new file mode 100644 index 0000000..df9a52f --- /dev/null +++ b/tests/unit/dataclasses/testGraphPropAsPerNodeType.py @@ -0,0 +1,151 @@ +import unittest + +import pandas as pd +import torch +from torch_geometric.data.data import Data as GeomData + +from chebai_graph.preprocessing.datasets.chebi import ChEBI50_WFGE_WGN_AsPerNodeType +from chebai_graph.preprocessing.properties import ( + AtomNodeLevel, + AugAtomCharge, + AugAtomHybridization, + BondLevel, + IsFGAlkyl, + IsHydrogenBondAcceptorFG, + RDKit2DNormalized, +) + + +class TestGraphPropAsPerNodeType(unittest.TestCase): + def test_merge_properties(self): + num_nodes = 4 + dummy_x = torch.zeros((num_nodes, 0)) # Initial dummy x + dummy_edge_index = torch.tensor([[0, 1], [1, 0]]) + dummy_edge_attr = torch.zeros((4, 0)) # 4 edges, each with 0 feature + + # Masks + is_atom_node = torch.tensor([1, 0, 1, 0], dtype=torch.bool) + is_graph_node = torch.tensor([0, 0, 0, 1], dtype=torch.bool) + + # GeomData + geom_data = GeomData( + x=dummy_x, + edge_index=dummy_edge_index, + edge_attr=dummy_edge_attr, + is_atom_node=is_atom_node, + is_graph_node=is_graph_node, + ) + + # Define properties + # atom props = 5, fg_props = 4, graph_node_props = 6; max = 6 + all_node_prop = AtomNodeLevel(DummyEncoder(2)) + atom_prop = AugAtomCharge(DummyEncoder(1)) + atom_prop_2 = AugAtomHybridization(DummyEncoder(2)) + fg_prop = IsFGAlkyl(DummyEncoder(1)) + fg_prop_2 = IsHydrogenBondAcceptorFG(DummyEncoder(1)) + mol_prop = RDKit2DNormalized(DummyEncoder(4)) + bond_prop = BondLevel(DummyEncoder(2)) + + properties = [ + atom_prop, + atom_prop_2, + fg_prop, + fg_prop_2, + bond_prop, + all_node_prop, + mol_prop, + ] + + merger = ChEBI50_WFGE_WGN_AsPerNodeType(properties) + + # Define encoded property values for the row + row = pd.Series( + { + "features": geom_data, + "AtomNodeLevel": torch.tensor( + [ + [1.0, 0.0], # atom + [0.0, 1.0], # fg + [0.0, 0.0], # atom + [0.0, 1.0], # graph + ] + ), + "AtomCharge": torch.tensor( + [ + [1.0], # atom + [2.0], # fg + [6.0], # atom + [3.0], # graph + ] + ), + "AtomHybridization": torch.tensor( + [ + [11.0, 9.0], # atom + [7.0, 3.0], # fg + [3.0, 1.0], # atom + [7.0, 43.0], # graph + ] + ), + "IsFGAlkyl": torch.tensor( + [ + [5.0], # atom + [55.0], # fg + [13.0], # atom + [14.0], # graph + ] + ), # values for fg at idx 1 + "IsHydrogenBondAcceptorFG": torch.tensor( + [ + [3.0], # atom + [5.0], # fg + [17.0], # atom + [15.0], # graph + ] + ), + "RDKit2DNormalized": torch.tensor( + [ + [65.0, 23.0, 6.0, 8.0], # atom + [2.0, 8.0, 55.0, 77.0], # fg + [3.0, 51.0, 55.0, 3.0], # atom + [33.0, 6.0, 10.0, 10.0], # graph + ] + ), # only idx 3 + "BondLevel": torch.tensor( + [ + [0.1, 0.2], + [0.3, 0.4], + ] + ), # will be duplicated to 4x2 + } + ) + expected_result = torch.tensor( + [ + # all node # ap1 # ap2 # 0 concat + [1.0, 0.0] + [1.0] + [11.0, 9.0] + [0.0], # atom node + # all node # fg1 # fg2 # 0 concat + [0.0, 1.0] + [55.0] + [5.0] + [0.0, 0.0], # fg node + # all node # ap1 # ap2 # 0 concat + [0.0, 0.0] + [6.0] + [3.0, 1.0] + [0.0], # atom node + # all node # mol props + [0.0, 1.0] + [33.0, 6.0, 10.0, 10.0], # graph node + ] + ) + + result = merger._merge_props_into_base(row, max_len_node_properties=6) + self.assertTrue(torch.equal(result.x, expected_result)) + + +class DummyEncoder: + def __init__(self, length): + self.length = length + + @property + def name(self): + return self.__class__.__name__.replace("DummyEncoder", "") + + def get_encoding_length(self): + return self.length + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/readers/__init__.py b/tests/unit/readers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/readers/testGraphPropertyReader.py b/tests/unit/readers/testGraphPropertyReader.py new file mode 100644 index 0000000..0222fa4 --- /dev/null +++ b/tests/unit/readers/testGraphPropertyReader.py @@ -0,0 +1,86 @@ +import unittest + +import torch +from torch_geometric.data import Data as GeomData + +from chebai_graph.preprocessing.reader import GraphPropertyReader +from tests.unit.test_data import MoleculeGraph + + +class TestGraphPropertyReader(unittest.TestCase): + """Unit tests for the GraphPropertyReader class, which converts SMILES strings to torch_geometric Data objects.""" + + def setUp(self) -> None: + """Initialize the reader and the reference molecule graph.""" + self.reader: GraphPropertyReader = GraphPropertyReader() + self.molecule_graph: MoleculeGraph = MoleculeGraph() + + def test_read_data(self) -> None: + """Test that the reader correctly parses a SMILES string into a graph and matches expected aspirin structure.""" + smiles: str = "CC(=O)OC1=CC=CC=C1C(=O)O" # Aspirin + + data: GeomData = self.reader._read_data(smiles) # noqa + + self.assertIsInstance( + data, + GeomData, + msg="The output should be an instance of torch_geometric.data.Data.", + ) + + self.assertEqual( + data.edge_index.shape[0], + 2, + msg=f"Expected edge_index to have shape [2, num_edges], but got shape {data.edge_index.shape}", + ) + + self.assertEqual( + data.edge_index.shape[1], + data.edge_attr.shape[0], + msg=f"Mismatch between number of edges in edge_index ({data.edge_index.shape[1]}) and edge_attr ({data.edge_attr.shape[0]})", + ) + + self.assertEqual( + len(set(data.edge_index[0].tolist())), + data.x.shape[0], + msg=f"Number of unique source nodes in edge_index ({len(set(data.edge_index[0].tolist()))}) does not match number of nodes in x ({data.x.shape[0]})", + ) + + # Check for duplicates by checking if the rows are the same (direction matters) + _, counts = torch.unique(data.edge_index.t(), dim=0, return_counts=True) + self.assertFalse( + torch.any(counts > 1), + msg="There are duplicates of directed edge in edge_index", + ) + + expected_data: GeomData = self.molecule_graph.get_aspirin_graph() + self.assertTrue( + torch.equal(data.edge_index, expected_data.edge_index), + msg=( + "edge_index tensors do not match.\n" + f"Differences at indices: {(data.edge_index != expected_data.edge_index).nonzero()}.\n" + f"Parsed edge_index:\n{data.edge_index}\nExpected edge_index:\n{expected_data.edge_index}" + f"If fails in future, check if there is change in RDKIT version, the expected graph is generated with RDKIT 2024.9.6" + ), + ) + + self.assertEqual( + data.x.shape[0], + expected_data.x.shape[0], + msg=( + "The number of atoms (nodes) in the parsed graph does not match the reference graph.\n" + f"Parsed: {data.x.shape[0]}, Expected: {expected_data.x.shape[0]}" + ), + ) + + self.assertEqual( + data.edge_attr.shape[0], + expected_data.edge_attr.shape[0], + msg=( + "The number of edge attributes does not match the expected value.\n" + f"Parsed: {data.edge_attr.shape[0]}, Expected: {expected_data.edge_attr.shape[0]}" + ), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_data.py b/tests/unit/test_data.py new file mode 100644 index 0000000..4acf41d --- /dev/null +++ b/tests/unit/test_data.py @@ -0,0 +1,102 @@ +import torch +from torch_geometric.data import Data + + +class MoleculeGraph: + """Class representing molecular graph data.""" + + def get_aspirin_graph(self): + """ + Constructs and returns a PyTorch Geometric Data object representing the molecular graph of Aspirin. + + Aspirin -> CC(=O)OC1=CC=CC=C1C(=O)O ; CHEBI:15365 + + Node labels (atom indices): + O2 C5———C6 + \ / \ + C1———O3———C4 C7 + / \ / + C0 C9———C8 + / + C10 + / \ + O12 O11 + + + Returns: + torch_geometric.data.Data: A Data object with attributes: + - x (FloatTensor): Node feature matrix of shape (num_nodes, 1). + - edge_index (LongTensor): Graph connectivity in COO format of shape (2, num_edges). + - edge_attr (FloatTensor): Edge feature matrix of shape (num_edges, 1). + + Refer: + For graph construction: https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html + """ + + # --- Node features: atomic numbers (C=6, O=8) --- + # Shape of x : num_nodes x num_of_node_features + # fmt: off + x = torch.tensor( + [ + [6], # C0 - This feature belongs to atom/node with 0 value in edge_index + [6], # C1 - This feature belongs to atom/node with 1 value in edge_index + [8], # O2 - This feature belongs to atom/node with 2 value in edge_index + [8], # O3 - This feature belongs to atom/node with 3 value in edge_index + [6], # C4 - This feature belongs to atom/node with 4 value in edge_index + [6], # C5 - This feature belongs to atom/node with 5 value in edge_index + [6], # C6 - This feature belongs to atom/node with 6 value in edge_index + [6], # C7 - This feature belongs to atom/node with 7 value in edge_index + [6], # C8 - This feature belongs to atom/node with 8 value in edge_index + [6], # C9 - This feature belongs to atom/node with 9 value in edge_index + [6], # C10 - This feature belongs to atom/node with 10 value in edge_index + [8], # O11 - This feature belongs to atom/node with 11 value in edge_index + [8], # O12 - This feature belongs to atom/node with 12 value in edge_index + ], + dtype=torch.float, + ) + # fmt: on + + # --- Edge list (bidirectional) --- + # Shape of edge_index for undirected graph: 2 x num_of_edges; (2x26) + # Generated using RDKIT 2024.9.6 + # fmt: off + _edge_index = torch.tensor([ + [0, 1, 1, 3, 4, 5, 6, 7, 8, 9, 10, 10, 9], # Start atoms (u) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 4] # End atoms (v) + ], dtype=torch.long) + # fmt: on + + # Reverse the edges + reversed_edge_index = _edge_index[[1, 0], :] + + # First all directed edges from source to target are placed, + # then all directed edges from target to source are placed --- this is needed + undirected_edge_index = torch.cat([_edge_index, reversed_edge_index], dim=1) + + # --- Dummy edge features --- + # Shape of undirected_edge_attr: num_of_edges x num_of_edges_features (26 x 1) + # fmt: off + _edge_attr = torch.tensor([ + [1], # C0 - C1, This two features belong to elements at index 0 in `edge_index` + [2], # C1 - C2, This two features belong to elements at index 1 in `edge_index` + [2], # C1 - O3, This two features belong to elements at index 2 in `edge_index` + [2], # O3 - C4, This two features belong to elements at index 3 in `edge_index` + [1], # C4 - C5, This two features belong to elements at index 4 in `edge_index` + [1], # C5 - C6, This two features belong to elements at index 5 in `edge_index` + [1], # C6 - C7, This two features belong to elements at index 6 in `edge_index` + [1], # C7 - C8, This two features belong to elements at index 7 in `edge_index` + [1], # C8 - C9, This two features belong to elements at index 8 in `edge_index` + [1], # C9 - C10, This two features belong to elements at index 9 in `edge_index` + [1], # C10 - O11, This two features belong to elements at index 10 in `edge_index` + [1], # C10 - O12, This two features belong to elements at index 11 in `edge_index` + [1], # C9 - C4, This two features belong to elements at index 12 in `edge_index` + ], dtype=torch.float) + # fmt: on + + # Alignement of edge attributes should in same order as of edge_index + undirected_edge_attr = torch.cat([_edge_attr, _edge_attr], dim=0) + + # Create graph data object + return Data( + x=x, edge_index=undirected_edge_index, edge_attr=undirected_edge_attr + ) diff --git a/tutorials/classwise_stats.ipynb b/tutorials/classwise_stats.ipynb new file mode 100644 index 0000000..072a128 --- /dev/null +++ b/tutorials/classwise_stats.ipynb @@ -0,0 +1,767 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 21, + "id": "07f2eb38", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "eb64bfee", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
labelres_baseline_propsres_wfgn_wfge_wgn_wprops_apoolgatv2_wfgn_wfge_wgn_wprops_apoolgatv2 - baselineres - baseline
01461800.00.00000.00000.00000.0000
1600040.00.11760.00000.00000.1176
2167330.00.23530.00000.00000.2353
3222990.00.28570.00000.00000.2857
4178150.00.12500.11110.11110.1250
\n", + "
" + ], + "text/plain": [ + " label res_baseline_props res_wfgn_wfge_wgn_wprops_apool \\\n", + "0 146180 0.0 0.0000 \n", + "1 60004 0.0 0.1176 \n", + "2 16733 0.0 0.2353 \n", + "3 22299 0.0 0.2857 \n", + "4 17815 0.0 0.1250 \n", + "\n", + " gatv2_wfgn_wfge_wgn_wprops_apool gatv2 - baseline res - baseline \n", + "0 0.0000 0.0000 0.0000 \n", + "1 0.0000 0.0000 0.1176 \n", + "2 0.0000 0.0000 0.2353 \n", + "3 0.0000 0.0000 0.2857 \n", + "4 0.1111 0.1111 0.1250 " + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.read_excel(\"G:\\github-aditya0by0\\python-chebai-graph\\classwise_f1_score.xlsx\")\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "14c10198", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1528, 6)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "32062eb6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Index(['label', 'res_baseline_props', 'res_wfgn_wfge_wgn_wprops_apool',\n", + " 'gatv2_wfgn_wfge_wgn_wprops_apool', 'gatv2 - baseline ',\n", + " 'res - baseline'],\n", + " dtype='object')" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.columns" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "5cd9265a", + "metadata": {}, + "outputs": [], + "source": [ + "BASELINE = \"res_baseline_props\"\n", + "RES = \"res_wfgn_wfge_wgn_wprops_apool\"\n", + "GATv2 = \"gatv2_wfgn_wfge_wgn_wprops_apool\"\n", + "GATv2_BASELINE_DIFF = \"gatv2 - baseline \"\n", + "RES_BASELINE_DIFF = \"res - baseline\"" + ] + }, + { + "cell_type": "markdown", + "id": "f679dfee", + "metadata": {}, + "source": [ + "### Summary statistics\n", + "\n", + "Look at the overall performance:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "b38fa321", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean improvement over baseline:\n", + " res - baseline 0.054853\n", + "gatv2 - baseline 0.068136\n", + "dtype: float64\n", + "\n", + "Median improvement over baseline:\n", + " res - baseline 0.03330\n", + "gatv2 - baseline 0.04535\n", + "dtype: float64\n", + "\n", + "Variation in improvement:\n", + " res - baseline 0.144831\n", + "gatv2 - baseline 0.139299\n", + "dtype: float64\n" + ] + } + ], + "source": [ + "# mean improvement\n", + "mean_diff = df[[RES_BASELINE_DIFF, GATv2_BASELINE_DIFF]].mean()\n", + "print(\"Mean improvement over baseline:\\n\", mean_diff)\n", + "print()\n", + "\n", + "# median improvement\n", + "median_diff = df[[RES_BASELINE_DIFF, GATv2_BASELINE_DIFF]].median()\n", + "print(\"Median improvement over baseline:\\n\", median_diff)\n", + "print()\n", + "\n", + "\n", + "# standard deviation of improvements\n", + "std_diff = df[[RES_BASELINE_DIFF, GATv2_BASELINE_DIFF]].std()\n", + "print(\"Variation in improvement:\\n\", std_diff)" + ] + }, + { + "cell_type": "markdown", + "id": "6359ca34", + "metadata": {}, + "source": [ + "### Count of labels improved / worsened" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "d046db01", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Labels improved in ResGated Final model over baseline: 1010\n", + "Labels improved in GATv2 Final model over baseline: 1091\n" + ] + } + ], + "source": [ + "res_improvement = (df[RES_BASELINE_DIFF] > 0).sum()\n", + "gat_v2_improvement = (df[GATv2_BASELINE_DIFF] > 0).sum()\n", + "\n", + "print(f\"Labels improved in ResGated Final model over baseline: {res_improvement}\")\n", + "print(f\"Labels improved in GATv2 Final model over baseline: {gat_v2_improvement}\")" + ] + }, + { + "cell_type": "markdown", + "id": "91dd094e", + "metadata": {}, + "source": [ + "### Ranking per label" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "3c3972a1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "best_model\n", + "gatv2_wfgn_wfge_wgn_wprops_apool 678\n", + "res_wfgn_wfge_wgn_wprops_apool 524\n", + "res_baseline_props 326\n", + "Name: count, dtype: int64\n" + ] + } + ], + "source": [ + "df[\"best_model\"] = df[[BASELINE, RES, GATv2]].idxmax(axis=1)\n", + "print(df[\"best_model\"].value_counts())" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "02ecb78e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Top 10 classes with highest improvement for res - baseline:\n", + " 25 0.8889\n", + "24 0.8889\n", + "37 0.7692\n", + "34 0.7143\n", + "35 0.7143\n", + "29 0.7143\n", + "77 0.6889\n", + "41 0.6487\n", + "36 0.6400\n", + "23 0.6364\n", + "Name: res - baseline, dtype: float64\n", + "\n", + "Top 10 classes with highest improvement for gatv2 - baseline :\n", + " 25 0.8889\n", + "24 0.8889\n", + "37 0.7692\n", + "34 0.7143\n", + "35 0.7143\n", + "29 0.7143\n", + "77 0.6889\n", + "41 0.6487\n", + "36 0.6400\n", + "23 0.6364\n", + "Name: res - baseline, dtype: float64\n" + ] + } + ], + "source": [ + "# Classes with highest difference\n", + "top_diff_res = df[RES_BASELINE_DIFF].sort_values(ascending=False).head(10)\n", + "print(\n", + " f\"\\nTop 10 classes with highest improvement for {RES_BASELINE_DIFF}:\\n\",\n", + " top_diff_res,\n", + ")\n", + "\n", + "top_diff_gat = df[GATv2_BASELINE_DIFF].sort_values(ascending=False).head(10)\n", + "print(\n", + " f\"\\nTop 10 classes with highest improvement for {GATv2_BASELINE_DIFF}:\\n\",\n", + " top_diff_res,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "524a01d8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
labelres_baseline_propsres_wfgn_wfge_wgn_wprops_apoolgatv2_wfgn_wfge_wgn_wprops_apoolgatv2 - baselineres - baselinebest_model
01461800.00.00.00000.00000.0res_baseline_props
9257110.00.00.25000.25000.0gatv2_wfgn_wfge_wgn_wprops_apool
10390930.00.00.25000.25000.0gatv2_wfgn_wfge_wgn_wprops_apool
1841940.00.00.46150.46150.0gatv2_wfgn_wfge_wgn_wprops_apool
19176080.00.00.46150.46150.0gatv2_wfgn_wfge_wgn_wprops_apool
........................
1523766291.01.01.00000.00000.0res_baseline_props
1524787761.01.01.00000.00000.0res_baseline_props
1525792041.01.01.00000.00000.0res_baseline_props
15261327421.01.01.00000.00000.0res_baseline_props
15271366371.01.01.00000.00000.0res_baseline_props
\n", + "

87 rows × 7 columns

\n", + "
" + ], + "text/plain": [ + " label res_baseline_props res_wfgn_wfge_wgn_wprops_apool \\\n", + "0 146180 0.0 0.0 \n", + "9 25711 0.0 0.0 \n", + "10 39093 0.0 0.0 \n", + "18 4194 0.0 0.0 \n", + "19 17608 0.0 0.0 \n", + "... ... ... ... \n", + "1523 76629 1.0 1.0 \n", + "1524 78776 1.0 1.0 \n", + "1525 79204 1.0 1.0 \n", + "1526 132742 1.0 1.0 \n", + "1527 136637 1.0 1.0 \n", + "\n", + " gatv2_wfgn_wfge_wgn_wprops_apool gatv2 - baseline res - baseline \\\n", + "0 0.0000 0.0000 0.0 \n", + "9 0.2500 0.2500 0.0 \n", + "10 0.2500 0.2500 0.0 \n", + "18 0.4615 0.4615 0.0 \n", + "19 0.4615 0.4615 0.0 \n", + "... ... ... ... \n", + "1523 1.0000 0.0000 0.0 \n", + "1524 1.0000 0.0000 0.0 \n", + "1525 1.0000 0.0000 0.0 \n", + "1526 1.0000 0.0000 0.0 \n", + "1527 1.0000 0.0000 0.0 \n", + "\n", + " best_model \n", + "0 res_baseline_props \n", + "9 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "10 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "18 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "19 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "... ... \n", + "1523 res_baseline_props \n", + "1524 res_baseline_props \n", + "1525 res_baseline_props \n", + "1526 res_baseline_props \n", + "1527 res_baseline_props \n", + "\n", + "[87 rows x 7 columns]" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Classes with minimal or no difference\n", + "minimal_diff_threshold = 0.01 # you can adjust\n", + "minimal_diff = df[df[RES_BASELINE_DIFF].abs() <= minimal_diff_threshold]\n", + "minimal_diff" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "5f15e872", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Classes where baseline performs better than our model:\n", + " label res_baseline_props res_wfgn_wfge_wgn_wprops_apool \\\n", + "45 59999 0.1176 0.0541 \n", + "47 59266 0.1176 0.1143 \n", + "50 25095 0.1333 0.0000 \n", + "70 48927 0.2000 0.0000 \n", + "72 83943 0.2000 0.0000 \n", + "... ... ... ... \n", + "1508 24846 1.0000 0.9333 \n", + "1509 26187 1.0000 0.9333 \n", + "1510 78231 1.0000 0.9333 \n", + "1511 26377 1.0000 0.9565 \n", + "1512 38834 1.0000 0.9677 \n", + "\n", + " gatv2_wfgn_wfge_wgn_wprops_apool gatv2 - baseline res - baseline \\\n", + "45 0.1250 0.0074 -0.0635 \n", + "47 0.5385 0.4209 -0.0033 \n", + "50 0.2667 0.1334 -0.1333 \n", + "70 0.0000 -0.2000 -0.2000 \n", + "72 0.2000 0.0000 -0.2000 \n", + "... ... ... ... \n", + "1508 1.0000 0.0000 -0.0667 \n", + "1509 1.0000 0.0000 -0.0667 \n", + "1510 1.0000 0.0000 -0.0667 \n", + "1511 1.0000 0.0000 -0.0435 \n", + "1512 1.0000 0.0000 -0.0323 \n", + "\n", + " best_model \n", + "45 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "47 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "50 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "70 res_baseline_props \n", + "72 res_baseline_props \n", + "... ... \n", + "1508 res_baseline_props \n", + "1509 res_baseline_props \n", + "1510 res_baseline_props \n", + "1511 res_baseline_props \n", + "1512 res_baseline_props \n", + "\n", + "[431 rows x 7 columns]\n" + ] + } + ], + "source": [ + "# Classes where baseline performs better\n", + "baseline_better = df[df[RES_BASELINE_DIFF] < 0]\n", + "print(\"\\nClasses where baseline performs better than our model:\\n\", baseline_better)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "0c4292bf", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Classes where both baseline and res_wfgn_wfge_wgn_wprops_apool have F1 < 0.5:\n", + " label res_baseline_props res_wfgn_wfge_wgn_wprops_apool \\\n", + "0 146180 0.0000 0.0000 \n", + "1 60004 0.0000 0.1176 \n", + "2 16733 0.0000 0.2353 \n", + "3 22299 0.0000 0.2857 \n", + "4 17815 0.0000 0.1250 \n", + ".. ... ... ... \n", + "265 32877 0.4615 0.4706 \n", + "271 139358 0.4667 0.4000 \n", + "283 35346 0.4762 0.4706 \n", + "284 35507 0.4762 0.4737 \n", + "295 24586 0.4878 0.4762 \n", + "\n", + " gatv2_wfgn_wfge_wgn_wprops_apool gatv2 - baseline res - baseline \\\n", + "0 0.0000 0.0000 0.0000 \n", + "1 0.0000 0.0000 0.1176 \n", + "2 0.0000 0.0000 0.2353 \n", + "3 0.0000 0.0000 0.2857 \n", + "4 0.1111 0.1111 0.1250 \n", + ".. ... ... ... \n", + "265 0.5333 0.0718 0.0091 \n", + "271 0.5517 0.0850 -0.0667 \n", + "283 0.3529 -0.1233 -0.0056 \n", + "284 0.3871 -0.0891 -0.0025 \n", + "295 0.6667 0.1789 -0.0116 \n", + "\n", + " best_model \n", + "0 res_baseline_props \n", + "1 res_wfgn_wfge_wgn_wprops_apool \n", + "2 res_wfgn_wfge_wgn_wprops_apool \n", + "3 res_wfgn_wfge_wgn_wprops_apool \n", + "4 res_wfgn_wfge_wgn_wprops_apool \n", + ".. ... \n", + "265 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "271 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "283 res_baseline_props \n", + "284 res_baseline_props \n", + "295 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "\n", + "[148 rows x 7 columns]\n" + ] + } + ], + "source": [ + "# Classes where both baseline and our model perform worse\n", + "# Assuming \"worse\" means below a threshold, e.g., f1 < 0.5\n", + "threshold = 0.5\n", + "both_worse = df[(df[BASELINE] < threshold) & (df[RES] < threshold)]\n", + "print(f\"\\nClasses where both baseline and {RES} have F1 < {threshold}:\\n\", both_worse)" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "8948b7bb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Classes where baseline performs better than our model:\n", + " label res_baseline_props res_wfgn_wfge_wgn_wprops_apool \\\n", + "45 59999 0.1176 0.0541 \n", + "47 59266 0.1176 0.1143 \n", + "50 25095 0.1333 0.0000 \n", + "70 48927 0.2000 0.0000 \n", + "72 83943 0.2000 0.0000 \n", + "... ... ... ... \n", + "1508 24846 1.0000 0.9333 \n", + "1509 26187 1.0000 0.9333 \n", + "1510 78231 1.0000 0.9333 \n", + "1511 26377 1.0000 0.9565 \n", + "1512 38834 1.0000 0.9677 \n", + "\n", + " gatv2_wfgn_wfge_wgn_wprops_apool gatv2 - baseline res - baseline \\\n", + "45 0.1250 0.0074 -0.0635 \n", + "47 0.5385 0.4209 -0.0033 \n", + "50 0.2667 0.1334 -0.1333 \n", + "70 0.0000 -0.2000 -0.2000 \n", + "72 0.2000 0.0000 -0.2000 \n", + "... ... ... ... \n", + "1508 1.0000 0.0000 -0.0667 \n", + "1509 1.0000 0.0000 -0.0667 \n", + "1510 1.0000 0.0000 -0.0667 \n", + "1511 1.0000 0.0000 -0.0435 \n", + "1512 1.0000 0.0000 -0.0323 \n", + "\n", + " best_model \n", + "45 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "47 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "50 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "70 res_baseline_props \n", + "72 res_baseline_props \n", + "... ... \n", + "1508 res_baseline_props \n", + "1509 res_baseline_props \n", + "1510 res_baseline_props \n", + "1511 res_baseline_props \n", + "1512 res_baseline_props \n", + "\n", + "[431 rows x 7 columns]\n" + ] + } + ], + "source": [ + "# Classes where baseline performs better\n", + "baseline_better = df[df[RES_BASELINE_DIFF] < 0]\n", + "print(\"\\nClasses where baseline performs better than our model:\\n\", baseline_better)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "gnn3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}